diff --git a/.devcontainer/README.md b/.devcontainer/README.md new file mode 100644 index 000000000000..242cad8a839b --- /dev/null +++ b/.devcontainer/README.md @@ -0,0 +1,26 @@ +# Dev container configurations + +This directory contains the configuration for dev containers, which is used to +initialize the development environment in **Codespaces**, **Visual Studio +Code**, and **JetBrains IDEs**. The environment is installed with all the +necessary dependencies for development and is ready for linting, formatting, and +running tests. + +* **GitHub Codespaces**. Create a codespace for the repo by clicking + the "Code" button on the main page of the repo, selecting the "Codespaces" + tab, and clicking the "+". The configurations will automatically be used. + Follow + [this guide](https://docs.github.com/en/codespaces/developing-in-a-codespace/creating-a-codespace-for-a-repository) + for more details. + +* **Visual Studio Code**. Open the root folder of the repo in VS Code. A + notification will pop up to open it in a dev container with the + configuration. Follow + [this guide](https://code.visualstudio.com/docs/devcontainers/tutorial) + for more details. + +* **JetBrains IDEs**. Open the `.devcontainer/devcontainer.json` in your + JetBrains IDE. Click the docker icon to create a dev container. + Follow + [this guide](https://www.jetbrains.com/help/idea/connect-to-devcontainer.html) + for more details. \ No newline at end of file diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 000000000000..2cfd12938f1f --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,28 @@ +{ + "image": "mcr.microsoft.com/vscode/devcontainers/python:3.10", + "postCreateCommand": "sh ./.devcontainer/setup.sh && pip install -r requirements.txt", + "customizations": { + "vscode": { + "settings": { + "python.testing.pytestEnabled": true, + "editor.formatOnSave": true, + "editor.codeActionsOnSave": { + "source.organizeImports": true + }, + "[python]": { + "editor.defaultFormatter": "charliermarsh.ruff" + }, + "editor.rulers": [ + 80 + ] + }, + "extensions": [ + "charliermarsh.ruff", + "ms-python.python" + ] + } + }, + "features": { + "ghcr.io/devcontainers/features/github-cli:1": {} + } +} diff --git a/.devcontainer/setup.sh b/.devcontainer/setup.sh new file mode 100644 index 000000000000..dce0e8132710 --- /dev/null +++ b/.devcontainer/setup.sh @@ -0,0 +1,4 @@ +sudo pip install --upgrade pip +sudo pip install -r requirements.txt +echo "bash shell/lint.sh" > .git/hooks/pre-commit +chmod a+x .git/hooks/pre-commit diff --git a/.gemini/styleguide.md b/.gemini/styleguide.md new file mode 100644 index 000000000000..806c60d7948c --- /dev/null +++ b/.gemini/styleguide.md @@ -0,0 +1,205 @@ +# Keras API design guidelines + +These guidelines are meant to help focus design discussions and help us create delightful developer experiences. + +These are meant as guidelines, not rules: each decision should be debated in its own unique context. + +Some text remixed from external references: + +- [User experience design for APIs](https://blog.keras.io/user-experience-design-for-apis.html) +- [Notes to Myself on Software Engineering](https://medium.com/s/story/notes-to-myself-on-software-engineering-c890f16f4e4d) + +--- + +## Design end-to-end workflows, not individual functions and classes. + +When developing APIs, start by designing end-to-end workflows, and only sketch out specific function/class signatures at the end. + +- The goal is to arrive at workflows that feel like they are purposefully designed and well-optimized, rather than cobbled together to route around the features provided by the API. The workflows should come first, before atomic features. **Features only exist to support a workflow.** No feature should exist to provide a capability "just in case", "because we can". +- **Every design review document should prominently feature a code example of one or two end-to-end workflows showing the canonical use-case for the new API.** +- Every time we discuss choices surrounding a specific API feature, we should start by asking: **in what workflows will this be used?** Then we should make the choice that makes the most sense with respect to these workflows. We should not make API design decisions about features in isolation. +- This implies that we will often ask the question: **do users really need to configure this parameter?**, and in many cases, the answer will be "no", rather than being "yes" by default. + +--- + +## Carefully weigh whether a new feature should be included. + +It's okay to say no: just because someone asks for a feature doesn't mean we should do it. Every feature has a cost that goes beyond the initial CL: maintenance cost, documentation cost, and cognitive cost for our users (a sprawling API surface is a major usability issue). + +In particular, in the Keras API, every new feature has to be maintained in perpetuity. + +As such, our criteria for adding a new feature in the API is the following: + +- **It should be broadly useful to our users**, rather than a niche feature that is only relevant to a specific vertical of researchers. Niche features should be maintained independently by those who need them (e.g. by extending the API via subclassing), as third-party add-on packages. +- **It should be widely recognized as a machine learning best practice.** We will not add new layers/etc that were recently published to ArXiv.org, even in case of claims of increased accuracy/etc. We only add new objects that are already commonly used in the machine learning community. Presumably, a new technique that does result in meaningful gains would be broadly adopted after a few months anyway (like ResNet), and that's when we would be adding it to the core API. SIG-addons maintains a repository of significantly more volatile and independently maintained code to which the barriers to entry are lower. +- **It should have an owner committed to maintaining it in the long term.** In particular, the code should be maintainable by multiple people on the team, not just by one technical guru. + +In addition, when saying yes to a request for supporting a new use case, remember that **literally adding what the user/team requested is often not the optimal choice**. Users are focused on their own specific use case, and we must counter this with a holistic and principled vision of the whole project (see: designing end-to-end workflows, not atomic functions/classes). Often, the right answer is to extend an existing feature. **Find the natural place to integrate the new feature in existing APIs.** + +### Examples: + +- We should not have added the self-normalizing activation function to the API. It was added before passing the test of time, and that technique has shown later not to reach broad adoption. **Note that citation count is not a good metric of adoption**; that paper has a high citation count. +- We should not move to core an API that has debuted somewhere on GitHub or TF-Addons but has failed to gain more than a few users after a few months. + +--- + +## Seek to minimize cognitive load for our users. + +Always seek to minimize the cognitive load imposed on our users in the course of using our APIs. + +At a high level: + +- **Automate everything that can be automated.** +- **Minimize the actions & choices required from the user.** Make sure default values for arguments are sensible and reflect best practices (so that users usually wouldn't have to manually configure these). Don't expose options that are not important or do not match real use cases, "just in case". +- **Design simple and consistent workflows that reflect simple and consistent mental models.** + +Here are a few practical rules: + +- **No API should deal with internal implementation details.** An API is a language for our users to talk about the problem they care about -- and they don't care about our internal hacks. For instance, an option like `use_locking` in an optimizer should be avoided. If an argument requires users to understand the implementation (not just what the code is supposed to implement, like SGD in this case), then the argument should not be included in the public API. **An API is all about the problem it solves, not about how the code works in the background.** +- **Introduce as few new concepts as possible.** It's not just that additional data structures require more effort in order to learn about their methods and properties, it's that they multiply the number of **mental models** that are necessary to grok your API. Ideally, you should only need **a single universal mental model around which everything is organized** (in Keras, that's the `Layer`). Definitely avoid having more than 2 or 3 mental models underlying the workflows you design. Likewise, avoid having concepts that are mostly overlapping but subtly different, since the difference will be difficult to convey clearly and will confuse our users (like, say, `Network` and `Model` -- this is why we don't export `Network` as a public API). +- **Objects that do interchangeable things should have identical or very close APIs.** In particular they should have the same positional arguments. For example, it should be possible to swap one optimizer for another in user code (when leaving all arguments to their default value) without editing the arguments. +- **If you find yourself proposing a signature with more than 6-7 arguments, consider whether all of these arguments are useful.** How many people and use cases would be affected if you removed one argument? How much would they be affected -- would they be able to easily extend the API (e.g. via subclassing) to support their use case without that built-in argument? Could this API be broken up into smaller, modular objects? +- **Best-practices should come baked into your API.** The simplest way to use your API (leaving all arguments to their default value, using the most obvious tool for the task, etc) should be as close as possible to the best way of solving the problem. In particular, all arguments that can be given a default value should be given a default value, and that default should match the most common use case. +- **Plain Python types are preferable to custom types.** Use tuples, strings, ints... A custom type requires more knowledge and effort on the part of the user (e.g. `TensorShape`, which is also breaking established conventions of scientific Python). **When using enums, make sure that their values are strings**, so as to make it possible for users to pass plain strings (example: `data_format="channels_last"`, `padding="valid"`). +- **Explicit, single-level configuration arguments are preferable to nested, hidden configuration arguments.** Avoid something like: `MyLayer(hyperparameter_dict)`, instead use `MyLayer(units, activation=None, ...)`. + +In particular, naming is important and difficult: + +- **The meaning of an argument should be clear from its name and should not require knowledge that only the implementers have.** In particular, argument names should only involve recognized terms of art ("L1 norm" is a term of art), and should not involve implementation-related vocabulary (e.g. "fused batchnorm"). +- **Avoid `OverlyLongAndSpecificNamingPatterns`.** If you find yourself with argument names with involve more than 3 subparts (e.g. "squared_operator_norm"), reconsider. Argument names should be intuitive and easy to remember. +- Avoid overly generic names (`x`, `variable`, `parameter`). +- **Make sure you are consistent in your naming choices.** Naming consistency means both **internal naming consistency** (don't call `dim` what is called `axis` in other places, don't call `ndims` what is called `ndim` elsewhere) and **consistency with established conventions for the problem domain (terms of art)**. Before settling on a name, make sure to look up existing names used by domain experts (or other APIs). In our case, argument names should be consistent with the broader scientific Python conventions, in particular NumPy. + +Note that Keras uses the following naming rules: + +- We use the convention `num_*` for counters, though omitting an explicit counter is nicer when there is no ambiguity (e.g. `units`, `epochs`, `filters`). +- The rank of a tensor is its `ndim`. A specific dimension index is an `axis`. The number of dimensions in a linear projection (or similar) is `units`. +- By convention Keras layers are named with nouns rather than verbs (e.g. `Normalization` and not `Normalize`, `Convolution` and not `Convolve`). +- Following Python conventions, classes use capitalized parts (e.g. `ClassName`) and functions and methods use snake case (e.g. `function_name`). +- If an argument name has a numerical suffix (e.g. `alpha_1`), we put an underscore before the suffix in snake case. The capitalized equivalent would be e.g. `Alpha1`. +- We used fully spelled-out names, e.g. `attention_scores` and not `attn_scores`. There are a couple standardized exceptions to this rule, in particular `dim` for "dimension" and `num` for "number". These are sufficiently common that they are not ambiguous to a first-time reader. + +### Example: + +```python +MyConstructor( + per_variable_sparsity_config=[ + 'layer_1/kernel:0.8', 'layer_2/kernel:1.5']) +``` + +What's wrong with this? + +- Overly long argument name +- Too much cognitive load involved in preparing an appropriate argument value +- Preparing an argument value requires internal implementation knowledge +- Reliance on TF variable names (subject to changes at any time, thus breaking this code) +- Nested config adding indirection +- Incorrect typing (float values being passing as strings) + +Possible alternative: + +``` +obj = MyConstructor() +obj.configure_sparsity(some_layer.kernel, value=0.8) +obj.configure_sparsity(some_other_layer.kernel, value=1.5) +``` + +What's nice about this? + +- Object-based variable references. +- Modular, simple action, with a clear name. +- Plain Python types. + +--- + +## Balance expressivity vs. user-friendliness. + +### Simple use cases should be simple, advanced use cases should be possible: + +**Don't increase the cognitive load of common use cases for the sake of niche use cases**, even minimally. +**Make sure that advanced users have a path to support their use case**, even if this path requires the users to roll out plugins or other API extensions (in particular via subclassing). **It is ok for advanced use cases not to be directly supported in the built-in API options.** + +### Keep our APIs modular. + +**Complex objects should be achievable by composing simple objects with few arguments, that do one thing reliably.** There is a balance to strike between having complex signatures on fewer objects, and having more objects with simpler signatures. A good API has a reasonable number of objects, with reasonably simple signatures (see also: avoiding signatures with more than 6-7 arguments). + +**Things that create state or side-effects should be classes. Functions should be stateless.** +For instance, layers that create weights should not be cast as functions, since it makes the weights (and other elements of state) hard to access, impossible to update, and forces reliance on a global state capturing the side effects of layer-functions. + +### APIs should be strictly compartmentalized. + +For instance, the optimizer API or the layers API should not contain arguments for configuring distributed training. That should go into the distribution API. + +--- + +## Don't neglect error messages, docstrings, and documentation. + +Documentation and error messages are an integral part of the API. Good docs and helpful error messages are key to a delightful user experience. + +- **Catch user errors early and anticipate common mistakes.** Do user input validation as soon as possible. Actively keep track of common mistakes that people make (by screening GitHub and StackOverflow), and either solve them by simplifying our API, adding targeted error messages for these mistakes, or having a "solutions to common issues" page in our docs. Consider adding automated fallback behaviors (e.g. casting a wrongly-typed input) instead of raising errors, when applicable. Be nice to our users. +- **Provide detailed feedback messages upon user error.** Error messages should be contextual, informative, and actionable. Every error message that transparently provides the user with the solution to their problem means one less support ticket, multiplied by how many times users run into the same issue. A good error message should answer: + - What happened, in what context? + - What did the software expect? + - How can the user fix it? +- **A docstring should answer the question: what is this about, and why & how should I use it?** It should assume as little context as possible, and it shouldn't mention specialized terms without first introducing them (for example, "num_blocks: Number of blocks in the kernel" is not a good argument description if this is the first time you mention "blocks" in your docstring). +- **Show, don't tell: your documentation should not talk about how the software works, it should show how to use it.** Show code examples for end-to-end workflows; show code examples for each and every common use case and key feature of your API. **All docstrings should include code examples.** +- **Deliberately design the user onboarding process for your feature.** How are complete newcomers going to find out the best way to solve their use case with your tool? Have an answer ready. Make sure your onboarding material closely maps to what your users care about: don't teach newcomers how your framework is implemented, teach them how they can use it to solve their own problems. After shipping a CL and writing good docstrings, make sure to create a Colab guide / tutorial showcasing the target workflow, and post it on the docs website. +- The feature is not ready until: + - 1) Users know about it + - 2) They know how to use it + - 3) They're actually using it to solve the corresponding problem. + +Note that Keras uses the following rules for writing docstrings: + +- For class docstrings, document arguments in a `Arguments:` section in the class docstring, not in `__init__`. + - When a user creates a class, they are not calling the `MyLayer.__init__()` method as if it were a regular method, they are calling `MyLayer`. We don't want to generate documentation for the `__init__()` method as a standalone method that needs to be called directly, that would be confusing. We also don't need `__init__()` docstrings that always start with "Initializes a MyLayer class.", which is useless information. Leaving `__init__()` without a docstring is the best practice. + - If constructor arguments are documented in `__init__`, it forces us to programmatically copy the `__init__` docstring when generating docs and concatenate it to the class docstring. This means that the Arguments section becomes the last thing in the docstring, which is bad. +- The order of information in a class docstring should be: + - One-line description of the class, that gives initial context to the user. e.g. `Applies Dropout to the input.` Make sure the one-line description is useful. No `Intantiates an ObscureName class instance.` + - Paragraph(s) of more detailed information that tells the user what the object is for and when they need to use it. e.g. `The Dropout layer randomly sets input units to 0 with a frequency of "rate" at each step during training time, which helps prevent overfitting. Inputs not set to 0 are scaled up by "1/(1 - rate)" such that the sum over all inputs is unchanged. [...]` + - If there is a reference paper, cite it here. + - `Arguments` section. + - If it's a layer that has arguments in `call`, the `Call arguments` section. + - If it's a `Layer`, `Input shape` and `Output shape` sections. + - Example(s). + - Lastly, addendum. Information that isn't very important and that most users don't need, but that should be documented somewhere. + - e.g. the section "About the layer's `dtype` attribute" in the base Layer class. + - e.g. warnings about edge cases or compatibility issues. + - e.g. pointers to further guides and tutorials. + +### Error messages: a case study + +The following would be a very poor error message: + +``` +AssertionError: '1 != 3' +``` + +In general, to validate user input, always use `ValueError` and avoid `assert`. + +Also bad: + +``` +ValueError: 'Invalid target shape (600, 1).' +``` + +The following is better, but still not sufficient, because it does not tell the user what they passed, and does not quite say how to fix it: + +``` +ValueError: 'categorical_crossentropy requires target.shape[1] == classes' +``` + +Now, here's a good example, that says **what was passed**, **what was expected**, and **how to fix the issue**: + +``` +ValueError: '''You are passing a target array of shape (600, 1) while using as loss `categorical_crossentropy`. +`categorical_crossentropy` expects targets to be binary matrices (1s and 0s) of shape (samples, classes). +If your targets are integer classes, you can convert them to the expected format via: + +--- +from keras.utils import to_categorical +y_binary = to_categorical(y_int) +--- + +Alternatively, you can use the loss function `sparse_categorical_crossentropy` instead, which does expect integer targets. +``` diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 000000000000..9ca60e46ff1e --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,26 @@ +# To get started with Dependabot version updates, you'll need to specify which +# package ecosystems to update and where the package manifests are located. +# Please see the documentation for all configuration options: +# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates + +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "monthly" + groups: + github-actions: + patterns: + - "*" + - package-ecosystem: "pip" + directory: "/" + schedule: + interval: "monthly" + groups: + python: + patterns: + - "*" + ignore: + # TODO: ignore all updates for JAX GPU due to cuda version issue + - dependency-name: "jax[cuda12_pip]" \ No newline at end of file diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml new file mode 100644 index 000000000000..c07371e57e89 --- /dev/null +++ b/.github/workflows/actions.yml @@ -0,0 +1,150 @@ +name: Tests + +# TODO: Consider enabling all tests (pytest, applications, etc.) with NNX in the future +# Currently only basic flow tests run with NNX enabled + +on: + push: + branches: [ master ] + pull_request: + release: + types: [created] + +permissions: + contents: read + +jobs: + build: + strategy: + fail-fast: false + matrix: + python-version: ['3.10'] + backend: [tensorflow, jax, torch, numpy, openvino] + nnx_enabled: [false] + include: + - python-version: '3.11' + backend: jax + nnx_enabled: true + name: ${{ matrix.backend == 'jax' && format('Run tests ({0}, {1}, nnx_enabled = {2})', matrix.python-version, matrix.backend, matrix.nnx_enabled) || format('Run tests ({0}, {1})', matrix.python-version, matrix.backend) }} + runs-on: ubuntu-latest + env: + PYTHON: ${{ matrix.python-version }} + KERAS_HOME: .github/workflows/config/${{ matrix.backend }} + steps: + - uses: actions/checkout@v5 + - name: Check for changes in keras/src/applications + uses: dorny/paths-filter@v3 + id: filter + with: + filters: | + applications: + - 'keras/src/applications/**' + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: ${{ matrix.python-version }} + - name: Get pip cache dir + id: pip-cache + run: | + python -m pip install --upgrade pip setuptools + echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT + - name: pip cache + uses: actions/cache@v4 + with: + path: ${{ steps.pip-cache.outputs.dir }} + key: ${{ runner.os }}-pip-${{ hashFiles('pyproject.toml') }}-${{ hashFiles('requirements.txt') }} + - name: Install dependencies + run: | + pip install -r requirements.txt --progress-bar off --upgrade + if [ "${{ matrix.nnx_enabled }}" == "true" ]; then + pip install --upgrade flax>=0.11.1 + fi + pip uninstall -y keras keras-nightly + pip install -e "." --progress-bar off --upgrade + - name: Test applications with pytest + if: ${{ steps.filter.outputs.applications == 'true' && matrix.nnx_enabled == false }} + run: | + pytest keras/src/applications --cov=keras/src/applications --cov-config=pyproject.toml + coverage xml --include='keras/src/applications/*' -o apps-coverage.xml + - name: Codecov keras.applications + if: ${{ steps.filter.outputs.applications == 'true' && matrix.nnx_enabled == false }} + uses: codecov/codecov-action@v5 + with: + env_vars: PYTHON,KERAS_HOME + flags: keras.applications,keras.applications-${{ matrix.backend }} + files: apps-coverage.xml + token: ${{ secrets.CODECOV_TOKEN }} + fail_ci_if_error: false + - name: Test integrations + if: ${{ matrix.backend != 'numpy' && matrix.nnx_enabled == false }} + run: | + python integration_tests/import_test.py + python integration_tests/numerical_test.py + - name: Test JAX-specific integrations + if: ${{ matrix.backend == 'jax' && matrix.nnx_enabled == false }} + run: | + python integration_tests/jax_custom_fit_test.py + - name: Test basic flow with NNX + if: ${{ matrix.nnx_enabled == true }} + env: + KERAS_NNX_ENABLED: true + run: | + python integration_tests/import_test.py + python integration_tests/basic_full_flow.py + - name: Test TF-specific integrations + if: ${{ matrix.backend == 'tensorflow'}} + run: | + python integration_tests/tf_distribute_training_test.py + python integration_tests/tf_custom_fit_test.py + - name: Test Torch-specific integrations + if: ${{ matrix.backend == 'torch'}} + run: | + pytest integration_tests/torch_workflow_test.py + python integration_tests/torch_custom_fit_test.py + - name: Test with pytest + if: ${{ matrix.nnx_enabled == false }} + run: | + if [ "${{ matrix.backend }}" == "openvino" ]; then + IGNORE_FILE="keras/src/backend/openvino/excluded_tests.txt" + IGNORE_ARGS=$(awk '{print "--ignore=" $0}' "$IGNORE_FILE") + else + IGNORE_ARGS="" + fi + pytest keras --ignore keras/src/applications --cov=keras --cov-config=pyproject.toml $IGNORE_ARGS + coverage xml --omit='keras/src/applications/*,keras/api' -o core-coverage.xml + - name: Codecov keras + if: ${{ matrix.nnx_enabled == false }} + uses: codecov/codecov-action@v5 + with: + env_vars: PYTHON,KERAS_HOME,KERAS_NNX_ENABLED + flags: keras,keras-${{ matrix.backend }}${{ matrix.nnx_enabled == 'true' && '-nnx' || '' }} + files: core-coverage.xml + token: ${{ secrets.CODECOV_TOKEN }} + fail_ci_if_error: false + + format: + name: Check the code format + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v5 + - name: Set up Python 3.10 + uses: actions/setup-python@v6 + with: + python-version: '3.10' + - name: Get pip cache dir + id: pip-cache + run: | + python -m pip install --upgrade pip setuptools + echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT + - name: pip cache + uses: actions/cache@v4 + with: + path: ${{ steps.pip-cache.outputs.dir }} + key: ${{ runner.os }}-pip-${{ hashFiles('pyproject.toml') }}-${{ hashFiles('requirements.txt') }} + - name: Install dependencies + run: | + pip install -r requirements.txt --progress-bar off --upgrade + pip uninstall -y keras keras-nightly + pip install -e "." --progress-bar off --upgrade + - name: Run pre-commit + run: pre-commit run --all-files --hook-stage manual diff --git a/.github/workflows/auto-assignment.yaml b/.github/workflows/auto-assignment.yaml new file mode 100644 index 000000000000..32bfd7f564a7 --- /dev/null +++ b/.github/workflows/auto-assignment.yaml @@ -0,0 +1,21 @@ +name: auto-assignment +on: + issues: + types: + - opened + +permissions: + contents: read + issues: write + pull-requests: write + +jobs: + welcome: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v5 + - uses: actions/github-script@v8 + with: + script: | + const script = require('./\.github/workflows/scripts/auto-assignment.js') + script({github, context}) \ No newline at end of file diff --git a/.github/workflows/config/jax/keras.json b/.github/workflows/config/jax/keras.json new file mode 100644 index 000000000000..e20cd4ea7bfe --- /dev/null +++ b/.github/workflows/config/jax/keras.json @@ -0,0 +1,7 @@ +{ + "floatx": "float32", + "epsilon": 1e-07, + "backend": "jax", + "image_data_format": "channels_last", + "nnx_enabled": false +} diff --git a/.github/workflows/config/numpy/keras.json b/.github/workflows/config/numpy/keras.json new file mode 100644 index 000000000000..bc20704a8320 --- /dev/null +++ b/.github/workflows/config/numpy/keras.json @@ -0,0 +1,6 @@ +{ + "floatx": "float32", + "epsilon": 1e-07, + "backend": "numpy", + "image_data_format": "channels_last" +} diff --git a/.github/workflows/config/openvino/keras.json b/.github/workflows/config/openvino/keras.json new file mode 100644 index 000000000000..bc2ac8f1e344 --- /dev/null +++ b/.github/workflows/config/openvino/keras.json @@ -0,0 +1,6 @@ +{ + "floatx": "float32", + "epsilon": 1e-07, + "backend": "openvino", + "image_data_format": "channels_last" +} diff --git a/.github/workflows/config/tensorflow/keras.json b/.github/workflows/config/tensorflow/keras.json new file mode 100644 index 000000000000..dd7fb5b2d368 --- /dev/null +++ b/.github/workflows/config/tensorflow/keras.json @@ -0,0 +1,6 @@ +{ + "floatx": "float32", + "epsilon": 1e-07, + "backend": "tensorflow", + "image_data_format": "channels_last" +} diff --git a/.github/workflows/config/torch/keras.json b/.github/workflows/config/torch/keras.json new file mode 100644 index 000000000000..4d73d5171f0f --- /dev/null +++ b/.github/workflows/config/torch/keras.json @@ -0,0 +1,6 @@ +{ + "floatx": "float32", + "epsilon": 1e-07, + "backend": "torch", + "image_data_format": "channels_first" +} diff --git a/.github/workflows/labeler.yaml b/.github/workflows/labeler.yaml new file mode 100644 index 000000000000..350fd262c163 --- /dev/null +++ b/.github/workflows/labeler.yaml @@ -0,0 +1,42 @@ +# Copyright 2024 Google LLC. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# This workflow automatically identifies issues and pull requests (PRs) and add the +# appropriate label as per defined rules. +# First Labeler workflow: It searches for the keyword "Gemma" (case-insensitive) in both the title +# and description of the issue/PR. If a match is found, the workflow adds the label 'Gemma' to the issue/PR. + +name: 'Labeler' +on: + issues: + types: [edited,opened] + pull_request_target: + types: [opened, edited] + +permissions: + contents: read + issues: write + pull-requests: write + +jobs: + welcome: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v5 + - uses: actions/github-script@v8 + with: + script: | + const script = require('./\.github/workflows/scripts/labeler.js') + script({github, context}) \ No newline at end of file diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml new file mode 100644 index 000000000000..8a0a714d428b --- /dev/null +++ b/.github/workflows/nightly.yml @@ -0,0 +1,159 @@ +name: Nightly + +on: + workflow_dispatch: # To Generate wheels on demand outside of schedule. + schedule: + - cron: "0 3 * * *" # run at 3 AM UTC / 8 PM PDT + +permissions: + contents: read + +jobs: + build: + strategy: + fail-fast: false + matrix: + python-version: ["3.10"] + backend: [tensorflow, jax, torch, numpy] + name: Run tests (Python ${{ matrix.python-version }}) + runs-on: ubuntu-latest + env: + PYTHON: ${{ matrix.python-version }} + KERAS_BACKEND: ${{ matrix.backend }} + steps: + - uses: actions/checkout@v5 + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: ${{ matrix.python-version }} + - name: Get pip cache dir + id: pip-cache + run: | + python -m pip install --upgrade pip setuptools + echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT + - name: pip cache + uses: actions/cache@v4 + with: + path: ${{ steps.pip-cache.outputs.dir }} + key: ${{ runner.os }}-pip-${{ hashFiles('pyproject.toml') }}-${{ hashFiles('requirements.txt') }} + - name: Install dependencies + run: | + pip install -r requirements.txt --progress-bar off --upgrade + pip uninstall -y keras keras-nightly + pip install -e "." --progress-bar off --upgrade + - name: Test integrations + if: ${{ matrix.backend != 'numpy'}} + run: | + python integration_tests/import_test.py + - name: Test TF-specific integrations + if: ${{ matrix.backend == 'tensorflow'}} + run: | + python integration_tests/tf_distribute_training_test.py + - name: Test Torch-specific integrations + if: ${{ matrix.backend == 'torch'}} + run: | + pytest integration_tests/torch_workflow_test.py + - name: Test with pytest + run: | + pytest keras --ignore keras/src/applications --cov=keras --cov-config=pyproject.toml + + build-python-latest: + strategy: + fail-fast: false + matrix: + python-version: ["3.13"] + backend: [tensorflow, jax, torch, numpy] + name: Run tests (Python ${{ matrix.python-version }}) + runs-on: ubuntu-latest + env: + PYTHON: ${{ matrix.python-version }} + KERAS_BACKEND: ${{ matrix.backend }} + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Get pip cache dir + id: pip-cache + run: | + python -m pip install --upgrade pip setuptools + echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT + - name: pip cache + uses: actions/cache@v4 + with: + path: ${{ steps.pip-cache.outputs.dir }} + key: ${{ runner.os }}-pip-latest-${{ hashFiles('pyproject.toml') }}-${{ hashFiles('requirements.txt') }} + - name: Install dependencies + run: | + pip install -r requirements.txt --progress-bar off --upgrade + pip uninstall -y keras keras-nightly + pip install -e "." --progress-bar off --upgrade + - name: Test integrations + if: ${{ matrix.backend != 'numpy'}} + run: | + python integration_tests/import_test.py + - name: Test TF-specific integrations + if: ${{ matrix.backend == 'tensorflow'}} + run: | + python integration_tests/tf_distribute_training_test.py + - name: Test Torch-specific integrations + if: ${{ matrix.backend == 'torch'}} + run: | + pytest integration_tests/torch_workflow_test.py + - name: Test with pytest + run: | + pytest keras --ignore keras/src/applications --cov=keras --cov-config=pyproject.toml + + format: + name: Check the code format + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v5 + - name: Set up Python 3.10 + uses: actions/setup-python@v6 + with: + python-version: "3.10" + - name: Get pip cache dir + id: pip-cache + run: | + python -m pip install --upgrade pip setuptools + echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT + - name: pip cache + uses: actions/cache@v4 + with: + path: ${{ steps.pip-cache.outputs.dir }} + key: ${{ runner.os }}-pip-${{ hashFiles('pyproject.toml') }}-${{ hashFiles('requirements.txt') }} + - name: Install dependencies + run: | + pip install -r requirements.txt --progress-bar off --upgrade + pip uninstall -y keras keras-nightly + pip install -e "." --progress-bar off --upgrade + - name: Run pre-commit + run: pre-commit run --all-files --hook-stage manual + + nightly: + name: Build Wheel file and upload + needs: [build, build-python-latest, format] + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v5 + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: "3.10" + - name: Install dependencies + run: | + python -m pip install --upgrade pip setuptools + pip install twine + pip install -r requirements.txt --progress-bar off --upgrade + pip uninstall -y keras keras-nightly + - name: Build wheel file + run: | + python pip_build.py --nightly + - name: Publish to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + with: + password: ${{ secrets.PYPI_NIGHTLY_API_TOKEN }} + packages-dir: dist/ + verbose: true diff --git a/.github/workflows/scorecard.yml b/.github/workflows/scorecard.yml new file mode 100644 index 000000000000..ad04566a7b27 --- /dev/null +++ b/.github/workflows/scorecard.yml @@ -0,0 +1,61 @@ +name: Scorecard supply-chain security +on: + # For Branch-Protection check. Only the default branch is supported. See + # https://github.com/ossf/scorecard/blob/main/docs/checks.md#branch-protection + branch_protection_rule: + # To guarantee Maintained check is occasionally updated. See + # https://github.com/ossf/scorecard/blob/main/docs/checks.md#maintained + schedule: + - cron: '42 8 * * 2' + push: + branches: [ "master" ] + +# Declare default permissions as read only. +permissions: read-all + +jobs: + analysis: + name: Scorecard analysis + runs-on: ubuntu-latest + permissions: + # Needed to upload the results to code-scanning dashboard. + security-events: write + # Needed to publish results and get a badge (see publish_results below). + id-token: write + + steps: + - name: "Checkout code" + uses: actions/checkout@ff7abcd0c3c05ccf6adc123a8cd1fd4fb30fb493 # v4.1.1 + with: + persist-credentials: false + + - name: "Run analysis" + uses: ossf/scorecard-action@4eaacf0543bb3f2c246792bd56e8cdeffafb205a # v2.4.3 + with: + results_file: results.sarif + results_format: sarif + # (Optional) "write" PAT token. Uncomment the `repo_token` line below if: + # - you want to enable the Branch-Protection check on a *public* repository, or + # - you are installing Scorecard on a *private* repository + # To create the PAT, follow the steps in https://github.com/ossf/scorecard-action#authentication-with-pat. + # repo_token: ${{ secrets.SCORECARD_TOKEN }} + + # Publish results to OpenSSF REST API for easy access by consumers + # Allows the repository to include the Scorecard badge. + # See https://github.com/ossf/scorecard-action#publishing-results. + publish_results: true + + # Upload the results as artifacts (optional). Commenting out will disable uploads of run results in SARIF + # format to the repository Actions tab. + - name: "Upload artifact" + uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + with: + name: SARIF file + path: results.sarif + retention-days: 5 + + # Upload the results to GitHub's code scanning dashboard. + - name: "Upload to code-scanning" + uses: github/codeql-action/upload-sarif@3599b3baa15b485a2e49ef411a7a4bb2452e7f93 # v3.29.5 + with: + sarif_file: results.sarif diff --git a/.github/workflows/scripts/auto-assignment.js b/.github/workflows/scripts/auto-assignment.js new file mode 100644 index 000000000000..89398a373041 --- /dev/null +++ b/.github/workflows/scripts/auto-assignment.js @@ -0,0 +1,60 @@ +/** + * @license + * Copyright 2023 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +/** Automatically assign issues and PRs to users in the `assigneesList` + * on a rotating basis. + + @param {!object} + GitHub objects can call GitHub APIs using their built-in library functions. + The context object contains issue and PR details. +*/ + +module.exports = async ({ github, context }) => { + let issueNumber; + let assigneesList; + // Is this an issue? If so, assign the issue number. Otherwise, assign the PR number. + if (context.payload.issue) { + //assignee List for issues. + assigneesList = ["mehtamansi29", "sachinprasadhs"]; + issueNumber = context.payload.issue.number; + } else { + //assignee List for PRs. + assigneesList = []; + issueNumber = context.payload.number; + } + console.log("assignee list", assigneesList); + console.log("entered auto assignment for this issue: ", issueNumber); + if (!assigneesList.length) { + console.log("No assignees found for this repo."); + return; + } + let noOfAssignees = assigneesList.length; + let selection = issueNumber % noOfAssignees; + let assigneeForIssue = assigneesList[selection]; + + console.log( + "issue Number = ", + issueNumber + " , assigning to: ", + assigneeForIssue + ); + return github.rest.issues.addAssignees({ + issue_number: context.issue.number, + owner: context.repo.owner, + repo: context.repo.repo, + assignees: [assigneeForIssue], + }); +}; diff --git a/.github/workflows/scripts/labeler.js b/.github/workflows/scripts/labeler.js new file mode 100644 index 000000000000..769683174688 --- /dev/null +++ b/.github/workflows/scripts/labeler.js @@ -0,0 +1,49 @@ +/* +Copyright 2024 Google LLC. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + + +/** + * Invoked from labeler.yaml file to add + * label 'Gemma' to the issue and PR for which have gemma keyword present. + * @param {!Object.} github contains pre defined functions. + * context Information about the workflow run. + */ + +module.exports = async ({ github, context }) => { + const issue_title = context.payload.issue ? context.payload.issue.title : context.payload.pull_request.title + const issue_description = context.payload.issue ? context.payload.issue.body : context.payload.pull_request.body + const issue_number = context.payload.issue ? context.payload.issue.number : context.payload.pull_request.number + const keyword_label = { + gemma:'Gemma' + } + const labelsToAdd = [] + console.log(issue_title,issue_description,issue_number) + + for(const [keyword, label] of Object.entries(keyword_label)){ + if(issue_title.toLowerCase().indexOf(keyword) !=-1 || issue_description.toLowerCase().indexOf(keyword) !=-1 ){ + console.log(`'${keyword}'keyword is present inside the title or description. Pushing label '${label}' to row.`) + labelsToAdd.push(label) + } + } + if(labelsToAdd.length > 0){ + console.log(`Adding labels ${labelsToAdd} to the issue '#${issue_number}'.`) + github.rest.issues.addLabels({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + labels: labelsToAdd + }) + } +}; \ No newline at end of file diff --git a/.github/workflows/stale-issue-pr.yaml b/.github/workflows/stale-issue-pr.yaml new file mode 100644 index 000000000000..72c25057ed3f --- /dev/null +++ b/.github/workflows/stale-issue-pr.yaml @@ -0,0 +1,55 @@ +name: Close inactive issues +on: + schedule: + - cron: "30 1 * * *" +jobs: + close-issues: + # Don't do this in forks + if: github.repository == 'keras-team/keras' + runs-on: ubuntu-latest + permissions: + issues: write + pull-requests: write + actions: write + steps: + - name: Awaiting response issues + uses: actions/stale@v10 + with: + operations-per-run: 500 + days-before-issue-stale: 14 + days-before-issue-close: 14 + stale-issue-label: "stale" + # reason for closed the issue default value is not_planned + close-issue-reason: completed + only-labels: "stat:awaiting response from contributor" + stale-issue-message: > + This issue is stale because it has been open for 14 days with no activity. + It will be closed if no further activity occurs. Thank you. + # List of labels to remove when issues/PRs unstale. + labels-to-remove-when-unstale: "stat:awaiting response from contributor" + close-issue-message: > + This issue was closed because it has been inactive for 28 days. + Please reopen if you'd like to work on this further. + days-before-pr-stale: 14 + days-before-pr-close: 14 + stale-pr-message: "This PR is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you." + close-pr-message: "This PR was closed because it has been inactive for 28 days. Please reopen if you'd like to work on this further." + repo-token: ${{ secrets.GITHUB_TOKEN }} + - name: Contribution issues + uses: actions/stale@v10 + with: + operations-per-run: 500 + days-before-issue-stale: 180 + days-before-issue-close: 365 + stale-issue-label: "stale" + # reason for closed the issue default value is not_planned + close-issue-reason: not_planned + any-of-labels: "stat:contributions welcome,good first issue" + # List of labels to remove when issues/PRs unstale. + labels-to-remove-when-unstale: "stat:contributions welcome,good first issue" + stale-issue-message: > + This issue is stale because it has been open for 180 days with no activity. + It will be closed if no further activity occurs. Thank you. + close-issue-message: > + This issue was closed because it has been inactive for more than 1 year. + repo-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.gitignore b/.gitignore index 1c6e8077b138..afd700b49952 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,22 @@ -*.DS_Store +.DS_Store *.pyc -temp/* -build/* -keras/datasets/data/* -keras/datasets/temp/* \ No newline at end of file +.vscode-test +__pycache__ +**/.vscode-test/** +**/.vscode test/** +**/.vscode-smoke/** +**/.venv*/ +bin/** +build/** +obj/** +.pytest_cache +tmp/** +.vs/ +dist/** +**/*.egg-info/* +.vscode +examples/**/*.jpg +.python-version +.coverage +*coverage.xml +.ruff_cache \ No newline at end of file diff --git a/.kokoro/README.md b/.kokoro/README.md new file mode 100644 index 000000000000..2c7724d98822 --- /dev/null +++ b/.kokoro/README.md @@ -0,0 +1 @@ +CI to run on PR and merge to Master. \ No newline at end of file diff --git a/.kokoro/github/ubuntu/gpu/build.sh b/.kokoro/github/ubuntu/gpu/build.sh new file mode 100644 index 000000000000..d4118f977eea --- /dev/null +++ b/.kokoro/github/ubuntu/gpu/build.sh @@ -0,0 +1,80 @@ +set -e +set -x + +cd "${KOKORO_ROOT}/" + +sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1 + +PYTHON_BINARY="/usr/bin/python3.10" + +"${PYTHON_BINARY}" -m venv venv +source venv/bin/activate +# Check the python version +python --version +python3 --version + +# setting the LD_LIBRARY_PATH manually is causing segmentation fault +#export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:" +# Check cuda +nvidia-smi +nvcc --version + +cd "src/github/keras" +pip install -U pip setuptools +# psutil is used by background log reader +pip install -U psutil + +if [ "$KERAS_BACKEND" == "tensorflow" ] +then + echo "TensorFlow backend detected." + pip install -r requirements-tensorflow-cuda.txt --progress-bar off --timeout 1000 + pip uninstall -y keras keras-nightly + echo "Check that TensorFlow uses GPU" + python3 -c 'import tensorflow as tf;print(tf.__version__);print(tf.config.list_physical_devices("GPU"))' + # Raise error if GPU is not detected. + python3 -c 'import tensorflow as tf;assert len(tf.config.list_physical_devices("GPU")) > 0' + + # TODO: keras/layers/merging/merging_test.py::MergingLayersTest::test_sparse_dot_2d Fatal Python error: Aborted + pytest keras --ignore keras/src/applications \ + --ignore keras/src/layers/merging/merging_test.py \ + --cov=keras \ + --cov-config=pyproject.toml +fi + +if [ "$KERAS_BACKEND" == "jax" ] +then + echo "JAX backend detected." + pip install -r requirements-jax-cuda.txt --progress-bar off --timeout 1000 + pip uninstall -y keras keras-nightly + python3 -c 'import jax;print(jax.__version__);print(jax.default_backend())' + # Raise error if GPU is not detected. + python3 -c 'import jax;assert jax.default_backend().lower() == "gpu"' + + # TODO: keras/layers/merging/merging_test.py::MergingLayersTest::test_sparse_dot_2d Fatal Python error: Aborted + # TODO: keras/trainers/data_adapters/py_dataset_adapter_test.py::PyDatasetAdapterTest::test_basic_flow0 Fatal Python error: Aborted + # keras/backend/jax/distribution_lib_test.py is configured for CPU test for now. + pytest keras --ignore keras/src/applications \ + --ignore keras/src/layers/merging/merging_test.py \ + --ignore keras/src/trainers/data_adapters/py_dataset_adapter_test.py \ + --ignore keras/src/backend/jax/distribution_lib_test.py \ + --ignore keras/src/distribution/distribution_lib_test.py \ + --cov=keras \ + --cov-config=pyproject.toml + + pytest keras/src/distribution/distribution_lib_test.py --cov=keras --cov-config=pyproject.toml +fi + +if [ "$KERAS_BACKEND" == "torch" ] +then + echo "PyTorch backend detected." + pip install -r requirements-torch-cuda.txt --progress-bar off --timeout 1000 + pip uninstall -y keras keras-nightly + python3 -c 'import torch;print(torch.__version__);print(torch.cuda.is_available())' + # Raise error if GPU is not detected. + python3 -c 'import torch;assert torch.cuda.is_available()' + + pytest keras --ignore keras/src/applications \ + --cov=keras \ + --cov-config=pyproject.toml + +fi diff --git a/.kokoro/github/ubuntu/gpu/jax/continuous.cfg b/.kokoro/github/ubuntu/gpu/jax/continuous.cfg new file mode 100644 index 000000000000..0447221645c6 --- /dev/null +++ b/.kokoro/github/ubuntu/gpu/jax/continuous.cfg @@ -0,0 +1,16 @@ +build_file: "keras/.kokoro/github/ubuntu/gpu/build.sh" + +action { + define_artifacts { + regex: "**/sponge_log.log" + regex: "**/sponge_log.xml" + } +} + +env_vars: { + key: "KERAS_BACKEND" + value: "jax" +} + +# Set timeout to 60 mins from default 180 mins +timeout_mins: 60 \ No newline at end of file diff --git a/.kokoro/github/ubuntu/gpu/jax/presubmit.cfg b/.kokoro/github/ubuntu/gpu/jax/presubmit.cfg new file mode 100644 index 000000000000..0447221645c6 --- /dev/null +++ b/.kokoro/github/ubuntu/gpu/jax/presubmit.cfg @@ -0,0 +1,16 @@ +build_file: "keras/.kokoro/github/ubuntu/gpu/build.sh" + +action { + define_artifacts { + regex: "**/sponge_log.log" + regex: "**/sponge_log.xml" + } +} + +env_vars: { + key: "KERAS_BACKEND" + value: "jax" +} + +# Set timeout to 60 mins from default 180 mins +timeout_mins: 60 \ No newline at end of file diff --git a/.kokoro/github/ubuntu/gpu/tensorflow/continuous.cfg b/.kokoro/github/ubuntu/gpu/tensorflow/continuous.cfg new file mode 100644 index 000000000000..ab1accd6f920 --- /dev/null +++ b/.kokoro/github/ubuntu/gpu/tensorflow/continuous.cfg @@ -0,0 +1,16 @@ +build_file: "keras/.kokoro/github/ubuntu/gpu/build.sh" + +action { + define_artifacts { + regex: "**/sponge_log.log" + regex: "**/sponge_log.xml" + } +} + +env_vars: { + key: "KERAS_BACKEND" + value: "tensorflow" +} + +# Set timeout to 60 mins from default 180 mins +timeout_mins: 60 \ No newline at end of file diff --git a/.kokoro/github/ubuntu/gpu/tensorflow/presubmit.cfg b/.kokoro/github/ubuntu/gpu/tensorflow/presubmit.cfg new file mode 100644 index 000000000000..ab1accd6f920 --- /dev/null +++ b/.kokoro/github/ubuntu/gpu/tensorflow/presubmit.cfg @@ -0,0 +1,16 @@ +build_file: "keras/.kokoro/github/ubuntu/gpu/build.sh" + +action { + define_artifacts { + regex: "**/sponge_log.log" + regex: "**/sponge_log.xml" + } +} + +env_vars: { + key: "KERAS_BACKEND" + value: "tensorflow" +} + +# Set timeout to 60 mins from default 180 mins +timeout_mins: 60 \ No newline at end of file diff --git a/.kokoro/github/ubuntu/gpu/torch/continuous.cfg b/.kokoro/github/ubuntu/gpu/torch/continuous.cfg new file mode 100644 index 000000000000..6742451f36d5 --- /dev/null +++ b/.kokoro/github/ubuntu/gpu/torch/continuous.cfg @@ -0,0 +1,16 @@ +build_file: "keras/.kokoro/github/ubuntu/gpu/build.sh" + +action { + define_artifacts { + regex: "**/sponge_log.log" + regex: "**/sponge_log.xml" + } +} + +env_vars: { + key: "KERAS_BACKEND" + value: "torch" +} + +# Set timeout to 60 mins from default 180 mins +timeout_mins: 60 \ No newline at end of file diff --git a/.kokoro/github/ubuntu/gpu/torch/presubmit.cfg b/.kokoro/github/ubuntu/gpu/torch/presubmit.cfg new file mode 100644 index 000000000000..6742451f36d5 --- /dev/null +++ b/.kokoro/github/ubuntu/gpu/torch/presubmit.cfg @@ -0,0 +1,16 @@ +build_file: "keras/.kokoro/github/ubuntu/gpu/build.sh" + +action { + define_artifacts { + regex: "**/sponge_log.log" + regex: "**/sponge_log.xml" + } +} + +env_vars: { + key: "KERAS_BACKEND" + value: "torch" +} + +# Set timeout to 60 mins from default 180 mins +timeout_mins: 60 \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 000000000000..6003a890ce0c --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,31 @@ +repos: + - repo: local + hooks: + - id: api-gen + name: api_gen + entry: | + bash shell/api_gen.sh + git status + clean=$(git status | grep "nothing to commit") + if [ -z "$clean" ]; then + echo "Please run shell/api_gen.sh to generate API." + exit 1 + fi + language: system + stages: [pre-commit, manual] + require_serial: true + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.9.2 + hooks: + - id: ruff + args: [--config, pyproject.toml, --fix, .] + stages: [pre-commit] + - id: ruff-format + args: [--config, pyproject.toml, .] + stages: [pre-commit] + - id: ruff + args: [--config, pyproject.toml, .] + stages: [manual] + - id: ruff-format + args: ["--check", --config, pyproject.toml, .] + stages: [manual] \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 000000000000..61b18ac7ed3d --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,221 @@ +Keras 3 is a high-velocity open-source project. We welcome contributions! + +Contributions can be made in a variety of ways, including coding, enriching documentation, refining docstrings, and providing code examples. + + +## Current items open for contributions +At [this link](https://github.com/keras-team/keras/issues/18442), you'll find a list of items where you help is needed! + + +## How to contribute code + +Follow these steps to submit your code contribution. + +### Step 1. Open an issue + +Before making any changes, we recommend opening an issue (if one doesn't already +exist) and discussing your proposed changes. This way, we can give you feedback +and validate the proposed changes. + +If the changes are minor (simple bug fix or documentation fix), then feel free +to open a Pull Request (PR) without discussion. + +### Step 2. Make code changes + +To make code changes, you need to fork the repository. You will need to setup a +development environment and run the unit tests. This is covered in the section +"Setup environment". + +### Step 3. Create a pull request + +Once the change is ready, open a pull request from your branch in your fork to +the master branch in [keras-team/keras](https://github.com/keras-team/keras). + +### Step 4. Sign the Contributor License Agreement + +After creating the pull request, the `cla/google` check will be performed and, +if you haven't signed the Contributor License Agreement (CLA), it will fail with +instructions on how to do so. Please follow the instructions to sign the CLA and +the check will pass. + +![CLA signed](https://github.com/keras-team/keras/assets/1091026/71c26353-e3b5-4135-8bae-64693c717775) + + +### Step 5. Code review + +If the tests fail, look into the error messages and try to fix them. + +![CI tests](https://github.com/keras-team/keras/assets/1091026/6f6c17ef-6bd7-4e95-9fbc-1906cde37380) + +A reviewer will review the pull request and provide comments. There may be +several rounds of comments and code changes before the pull request gets +approved by the reviewer. + +![Approval from reviewer](https://github.com/keras-team/keras/assets/1091026/8d28f74c-21e9-4146-b0ff-62d649a552a8) + +### Step 6. Merging + +Once the pull request is approved, a `ready to pull` tag will be added to the +pull request. A team member will take care of the merging. + +![Ready to pull and merged](https://github.com/keras-team/keras/assets/1091026/c3908345-d7ae-44ee-a428-01f3b448b46b) + +Here is an [example pull request](https://github.com/keras-team/keras/pull/18848) +for your reference. + +## Setup environment + +We provide two ways of setting up a development environment. One is to use a +dev container, and the other one is to set up a local environment by installing +the dev tools needed. + +### Option 1: GitHub Codespace or dev container + +We support GitHub Codespaces, Visual Studio Code dev containers and JetBrain dev +containers. Please see the +[Dev container documentation](https://github.com/keras-team/keras/tree/master/.devcontainer). + +### Option 2: Set up a local environment + +To set up your local dev environment, you will need the following tools. + +1. [git](https://github.com/) for code repository management. +2. [python](https://www.python.org/) to build and code in Keras. + +The following commands check the tools above are successfully installed. Note +that Keras requires at least Python 3.10 to run. + +```shell +git --version +python --version +``` + +Clone your forked repo to your local machine. Go to the cloned directory to +install the dependencies. + +```shell +git clone https://github.com/YOUR_GITHUB_USERNAME/keras.git +cd keras +pip install -r requirements.txt +``` + +You then need to configure the backend to use, see the +[Configuring your backend](https://github.com/keras-team/keras/blob/master/README.md#configuring-your-backend) +section of the README. + +You can also add GPU support to your environment, see the +[Adding GPU support](https://github.com/keras-team/keras/blob/master/README.md#adding-gpu-support) +section of the README. + +## Generating public API and formatting the code + +For the first time you are setting up the repo, please run `pre-commit install`. +Note that this needs to be done only once at the beginning. + +Now, whenever you run `git commit -m ""`, three things are +automatically done: + +- Public API generation +- Code formatting +- Code linting + +If there's any error, the commit will not go through. Please fix the error ( +most of the times, the error is fixed automatically by the formatter/linter) and +re-run the following: + +``` +git add . +git commit -m "" # This will not get logged as a duplicate commit. +``` + +In case you want to run the above manually on all files, you can do the +following: + +``` +pre-commit run --all-files +``` + +KerasHub uses [Ruff](https://docs.astral.sh/ruff/) to format the code. + +### Docstrings + +We do not have an automated way to check docstring style, so if you write +or edit any docstring, please make sure to check them manually. +Keras docstrings follow the conventions below: + +A **class docstring** may contain the following items: + +* A one-line description of the class. +* Paragraph(s) of more detailed information. +* Optional `Examples` section. +* `Args` section for arguments in `__init__()`. +* If it's a layer: + * `Call arguments` section for arguments in `Layer.call()`. + * `Returns` section for the return values of `Layer.call()`. + * Optional `Raises` section for possible errors. + +You can check out `MultiHeadAttention` as an example +[(link)](https://github.com/keras-team/keras/blob/v3.0.0/keras/layers/attention/multi_head_attention.py#L20). + +A **function docstring** may contain the following items: + +* One-line description of the function. +* Paragraph(s) of more detailed information. +* Optional `Examples` section. +* `Args` section for the function arguments. +* `Returns` section for the return values. +* Optional `Raises` section for possible errors. + +You can check out `text_dataset_from_directory` as an example +[(link)](https://github.com/keras-team/keras/blob/v3.0.0/keras/utils/text_dataset_utils.py#L27). + +## Run tests + +We use [pytest](https://pytest.org/) to run the tests. + +### Run a test file + +To run the tests in `keras/src/losses/losses_test.py`, use the following command +at the root directory of the repo. + +```shell +pytest keras/src/losses/losses_test.py +``` + +### Run a single test case + +You can specify a single test class to run within a file. + +```shell +pytest keras/src/losses/losses_test.py::MeanSquaredErrorTest +``` + +You can also specify a single test method to run within a class. + +```shell +pytest keras/src/losses/losses_test.py::MeanSquaredErrorTest::test_sample_weighted +``` + +### Run all tests + +You can run all the tests locally by running the following command in the repo +root directory. + +```shell +pytest keras +``` + +Note that you can skip the Keras applications tests using the +`SKIP_APPLICATIONS_TESTS` environment variable. This will cut down the testing +time significantly. + +```shell +SKIP_APPLICATIONS_TESTS=True pytest keras +``` + +To run all tests using a different backend, you can simply specify it on the +command line. + +```shell +KERAS_BACKEND=jax SKIP_APPLICATIONS_TESTS=True pytest keras +``` diff --git a/LICENSE b/LICENSE index 20efd1b3e976..f49a4e16e68b 100644 --- a/LICENSE +++ b/LICENSE @@ -1,22 +1,201 @@ -The MIT License (MIT) + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ -Copyright (c) 2015 + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: + 1. Definitions. -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/README.md b/README.md index 21f8c071ea4d..09eefc83741d 100644 --- a/README.md +++ b/README.md @@ -1,267 +1,123 @@ -# Keras: Theano-based Deep Learning library +# Keras 3: Deep Learning for Humans -## You have just found Keras. +Keras 3 is a multi-backend deep learning framework, with support for JAX, TensorFlow, PyTorch, and OpenVINO (for inference-only). +Effortlessly build and train models for computer vision, natural language processing, audio processing, +timeseries forecasting, recommender systems, etc. -Keras is a minimalist, highly modular neural network library in the spirit of Torch, written in Python / Theano so as not to have to deal with the dearth of ecosystem in Lua. It was developed with a focus on enabling fast experimentation. Being able to go from idea to result with the least possible delay is key to doing good research. +- **Accelerated model development**: Ship deep learning solutions faster thanks to the high-level UX of Keras +and the availability of easy-to-debug runtimes like PyTorch or JAX eager execution. +- **State-of-the-art performance**: By picking the backend that is the fastest for your model architecture (often JAX!), +leverage speedups ranging from 20% to 350% compared to other frameworks. [Benchmark here](https://keras.io/getting_started/benchmarks/). +- **Datacenter-scale training**: Scale confidently from your laptop to large clusters of GPUs or TPUs. -Use Keras if you need a deep learning library that: -- allows for easy and fast prototyping (through total modularity, minimalism, and extensibility). -- supports both convolutional networks (for vision) and recurrent networks (for sequence data). As well as combinations of the two. -- runs seamlessly on the CPU and the GPU. +Join nearly three million developers, from burgeoning startups to global enterprises, in harnessing the power of Keras 3. -## Guiding principles -- __Modularity.__ A model is understood as a sequence of standalone, fully-configurable modules that can be plugged together with as little restrictions as possible. In particular, neural layers, cost functions, optimizers, initialization schemes, activation functions and dropout are all standalone modules that you can combine to create new models. +## Installation -- __Minimalism.__ Each module should be kept short and simple (<100 lines of code). Every piece of code should be transparent upon first reading. No black magic: it hurts iteration speed and ability to innovate. +### Install with pip -- __Easy extensibility.__ A new feature (a new module, per the above definition, or a new way to combine modules together) are dead simple to add (as new classes/functions), and existing modules provide ample examples. +Keras 3 is available on PyPI as `keras`. Note that Keras 2 remains available as the `tf-keras` package. -- __Work with Python__. No separate models configuration files in a declarative format (like in Caffe or PyLearn2). Models are described in Python code, which is compact, easier to debug, benefits from syntax highlighting, and most of all, allows for ease of extensibility. See for yourself with the examples below. +1. Install `keras`: -## Examples +``` +pip install keras --upgrade +``` -### Multilayer Perceptron (MLP): +2. Install backend package(s). -```python -from keras.models import Sequential -from keras.layers.core import Dense, Dropout, Activation -from keras.optimizers import SGD - -model = Sequential() -model.add(Dense(20, 64, init='uniform')) -model.add(Activation('tanh')) -model.add(Dropout(0.5)) -model.add(Dense(64, 64, init='uniform')) -model.add(Activation('tanh')) -model.add(Dropout(0.5)) -model.add(Dense(64, 1, init='uniform')) -model.add(Activation('softmax')) - -sgd = SGD(lr=0.1, decay=1e-6, momentum=0.9, nesterov=True) -model.compile(loss='mean_squared_error', optimizer=sgd) - -model.fit(X_train, y_train, nb_epoch=20, batch_size=16) -score = model.evaluate(X_test, y_test, batch_size=16) -``` +To use `keras`, you should also install the backend of choice: `tensorflow`, `jax`, or `torch`. +Note that `tensorflow` is required for using certain Keras 3 features: certain preprocessing layers +as well as `tf.data` pipelines. -### Alternative implementation of MLP: +### Local installation -```python -model = Sequential() -model.add(Dense(20, 64, init='uniform', activation='tanh')) -model.add(Dropout(0.5)) -model.add(Dense(64, 64, init='uniform', activation='tanh')) -model.add(Dropout(0.5)) -model.add(Dense(64, 1, init='uniform', activation='softmax') - -sgd = SGD(lr=0.1, decay=1e-6, momentum=0.9, nesterov=True) -model.compile(loss='mean_squared_error', optimizer=sgd) -``` +#### Minimal installation -### VGG-like convnet: +Keras 3 is compatible with Linux and macOS systems. For Windows users, we recommend using WSL2 to run Keras. +To install a local development version: -```python -from keras.models import Sequential -from keras.layers.core import Dense, Dropout, Activation, Flatten -from keras.layers.convolutional import Convolution2D, MaxPooling2D -from keras.optimizers import SGD - -model = Sequential() -model.add(Convolution2D(32, 3, 3, 3, border_mode='full')) -model.add(Activation('relu')) -model.add(Convolution2D(32, 32, 3, 3)) -model.add(Activation('relu')) -model.add(MaxPooling2D(poolsize=(2, 2))) -model.add(Dropout(0.25)) - -model.add(Convolution2D(64, 32, 3, 3, border_mode='full')) -model.add(Activation('relu')) -model.add(Convolution2D(64, 64, 3, 3)) -model.add(Activation('relu')) -model.add(MaxPooling2D(poolsize=(2, 2))) -model.add(Dropout(0.25)) - -model.add(Flatten(64*8*8)) -model.add(Dense(64*8*8, 256)) -model.add(Activation('relu')) -model.add(Dropout(0.5)) - -model.add(Dense(256, 10)) -model.add(Activation('softmax')) - -sgd = SGD(lr=0.1, decay=1e-6, momentum=0.9, nesterov=True) -model.compile(loss='categorical_crossentropy', optimizer=sgd) - -model.fit(X_train, Y_train, batch_size=32, nb_epoch=1) +1. Install dependencies: +``` +pip install -r requirements.txt ``` -### Sequence classification with LSTM: +2. Run installation command from the root directory. -```python -from keras.models import Sequential -from keras.layers.core import Dense, Dropout, Activation, Embedding -from keras.layers.recurrent import LSTM +``` +python pip_build.py --install +``` -model = Sequential() -model.add(Embedding(max_features, 256)) -model.add(LSTM(256, 128, activation='sigmoid', inner_activation='hard_sigmoid')) -model.add(Dropout(0.5)) -model.add(Dense(128, 1)) -model.add(Activation('sigmoid')) +3. Run API generation script when creating PRs that update `keras_export` public APIs: -model.compile(loss='binary_crossentropy', optimizer='rmsprop') +``` +./shell/api_gen.sh +``` + +#### Adding GPU support -model.fit(X_train, Y_train, batch_size=16, nb_epoch=10) -score = model.evaluate(X_test, Y_test, batch_size=16) +The `requirements.txt` file will install a CPU-only version of TensorFlow, JAX, and PyTorch. For GPU support, we also +provide a separate `requirements-{backend}-cuda.txt` for TensorFlow, JAX, and PyTorch. These install all CUDA +dependencies via `pip` and expect a NVIDIA driver to be pre-installed. We recommend a clean Python environment for each +backend to avoid CUDA version mismatches. As an example, here is how to create a JAX GPU environment with `conda`: + +```shell +conda create -y -n keras-jax python=3.10 +conda activate keras-jax +pip install -r requirements-jax-cuda.txt +python pip_build.py --install ``` -### Architecture for learning image captions with a convnet and a Gated Recurrent Unit: -(word-level embedding, caption of maximum length 16 words). +## Configuring your backend -Note that getting this to actually "work" will require using a bigger convnet, initialized with pre-trained weights. -Displaying readable results will also require an embedding decoder. +You can export the environment variable `KERAS_BACKEND` or you can edit your local config file at `~/.keras/keras.json` +to configure your backend. Available backend options are: `"tensorflow"`, `"jax"`, `"torch"`, `"openvino"`. Example: -```python -max_caption_len = 16 - -model = Sequential() -model.add(Convolution2D(32, 3, 3, 3, border_mode='full')) -model.add(Activation('relu')) -model.add(Convolution2D(32, 32, 3, 3)) -model.add(Activation('relu')) -model.add(MaxPooling2D(poolsize=(2, 2))) - -model.add(Convolution2D(64, 32, 3, 3, border_mode='full')) -model.add(Activation('relu')) -model.add(Convolution2D(64, 64, 3, 3)) -model.add(Activation('relu')) -model.add(MaxPooling2D(poolsize=(2, 2))) - -model.add(Convolution2D(128, 64, 3, 3, border_mode='full')) -model.add(Activation('relu')) -model.add(Convolution2D(128, 128, 3, 3)) -model.add(Activation('relu')) -model.add(MaxPooling2D(poolsize=(2, 2))) - -model.add(Flatten(128*4*4)) -model.add(Dense(128*4*4, 256)) -model.add(Activation('relu')) -model.add(Dropout(0.5)) - -model.add(Repeat(max_caption_len)) -# the GRU below returns sequences of max_caption_len vectors of size 256 (our word embedding size) -model.add(GRU(256, 256, return_sequences=True)) - -model.compile(loss='mean_squared_error', optimizer='rmsprop') - -# "images" is a numpy array of shape (nb_samples, nb_channels=3, width, height) -# "captions" is a numpy array of shape (nb_samples, max_caption_len=16, embedding_dim=256) -# captions are supposed already embedded (dense vectors). -model.fit(images, captions, batch_size=16, nb_epoch=100) - +``` +export KERAS_BACKEND="jax" ``` -In the examples folder, you will find example models for real datasets: -- CIFAR10 small images classification: Convnet with realtime data augmentation -- IMDB movie review sentiment classification: LSTM over sequences of words -- Reuters newswires topic classification: Multilayer Perceptron - -## Warning - -This is a 0.0.1 alpha release. Feature scope is limited, and wild bugs may appear. - -## Current capabilities - -- model architectures: - - Sequential (pipeline of layers) - -- layers: - - layers.core: - - Dense - - Dropout - - Activation - - Embedding - - Reshape - - Flatten - - RepeatVector - - layers.convolutional: - - Convolution2D - - MaxPooling2D - - layers.recurrent: - - SimpleRNN - - SimpleDeepRNN - - GRU - - LSTM - - layers.advanced_activations: - - LeakyReLU - - PReLU - - layers.normalization: - - BatchNormalization - -- optimizers: - - SGD (supports decay, momentum, Nesterov momentum) - - RMSprop - - Adagrad - - Adadelta - -- datasets: - - CIFAR10: thumbnail image classification - - Reuters: newswire topic classification - - IMDB: sentiment classification - -- preprocessing: - - image: - - ImageDataGenerator: realtime image data augmentation and preprocessing (normalization, ZCA whitening) - - random_rotation - - random_shift - - horizontal_flip - - vertical_flip - - text: - - Tokenizer - - one_hot - - sequence: - - pad_sequences - -- objectives: - - mean_squared_error - - mean_absolute_error - - hinge - - squared_hinge - - binary_crossentropy - - categorical_crossentropy - -- activation functions: - softmax, softplus, relu, sigmoid, hard_sigmoid, linear - -- initialization functions: - normal, uniform, lecun_uniform, orthogonal +In Colab, you can do: +```python +import os +os.environ["KERAS_BACKEND"] = "jax" -## Installation +import keras +``` -Keras uses the following dependencies: +**Note:** The backend must be configured before importing `keras`, and the backend cannot be changed after +the package has been imported. -- numpy, scipy +**Note:** The OpenVINO backend is an inference-only backend, meaning it is designed only for running model +predictions using `model.predict()` method. -- Theano - - See installation instructions: http://deeplearning.net/software/theano/install.html#install +## Backwards compatibility -- PIL (optional, required if you use preprocessing.image) +Keras 3 is intended to work as a drop-in replacement for `tf.keras` (when using the TensorFlow backend). Just take your +existing `tf.keras` code, make sure that your calls to `model.save()` are using the up-to-date `.keras` format, and you're +done. -- Optional but recommended if you use CNNs: cuDNN. +If your `tf.keras` model does not include custom components, you can start running it on top of JAX or PyTorch immediately. -Once you have the dependencies installed, cd to the Keras folder and run the install command: -``` -sudo python setup.py install -``` +If it does include custom components (e.g. custom layers or a custom `train_step()`), it is usually possible to convert it +to a backend-agnostic implementation in just a few minutes. -## Why this name, Keras? +In addition, Keras models can consume datasets in any format, regardless of the backend you're using: +you can train your models with your existing `tf.data.Dataset` pipelines or PyTorch `DataLoaders`. -Keras (κέρας) means _horn_ in Greek. It is a reference to a literary image from ancient Greek and Latin literature, first found in the _Odyssey_, where dream spirits (_Oneiroi_, singular _Oneiros_) are divided between those who deceive men with false visions, who arrive to Earth through a gate of ivory, and those who announce a future that will come to pass, who arrive through a gate of horn. It's a play on the words κέρας (horn) / κραίνω (fulfill), and ἐλέφας (ivory) / ἐλεφαίρομαι (deceive). +## Why use Keras 3? -Keras was developed as part of the research effort of project ONEIROS (Open-ended Neuro-Electronic Intelligent Robot Operating System). +- Run your high-level Keras workflows on top of any framework -- benefiting at will from the advantages of each framework, +e.g. the scalability and performance of JAX or the production ecosystem options of TensorFlow. +- Write custom components (e.g. layers, models, metrics) that you can use in low-level workflows in any framework. + - You can take a Keras model and train it in a training loop written from scratch in native TF, JAX, or PyTorch. + - You can take a Keras model and use it as part of a PyTorch-native `Module` or as part of a JAX-native model function. +- Make your ML code future-proof by avoiding framework lock-in. +- As a PyTorch user: get access to power and usability of Keras, at last! +- As a JAX user: get access to a fully-featured, battle-tested, well-documented modeling and training library. -_"Oneiroi are beyond our unravelling --who can be sure what tale they tell? Not all that men look for comes to pass. Two gates there are that give passage to fleeting Oneiroi; one is made of horn, one of ivory. The Oneiroi that pass through sawn ivory are deceitful, bearing a message that will not be fulfilled; those that come out through polished horn have truth behind them, to be accomplished for men who see them."_ Homer, Odyssey 19. 562 ff (Shewring translation). +Read more in the [Keras 3 release announcement](https://keras.io/keras_3/). diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 000000000000..6850a69606a3 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,80 @@ +# Security Policy + + - [**Using Keras Securely**](#using-keras-securely) + - [Untrusted inputs](#untrusted-inputs) + - [Data privacy](#data-privacy) + - [Untrusted environments or networks](#untrusted-environments-or-networks) + - [Multi-Tenant environments](#multi-tenant-environments) + - [**Reporting a Vulnerability**](#reporting-a-vulnerability) + +## Using Keras Securely + +### Untrusted inputs + +Some models accept various input formats (text, images, audio, etc.). The libraries converting these inputs have varying security levels, so it's crucial to isolate the model and carefully pre-process inputs to mitigate script injection risks. + +For maximum security when handling untrusted inputs, you may need to employ the following: + +* Sandboxing: Isolate the model process. +* Pre-analysis: check how the model performs by default when exposed to prompt injection (e.g. using [fuzzing for prompt injection](https://github.com/FonduAI/awesome-prompt-injection?tab=readme-ov-file#tools)). This will give you leads on how hard you will have to work on the next topics. +* Updates: Keep your model and libraries updated with the latest security patches. +* Input Sanitation: Before feeding data to the model, sanitize inputs rigorously. This involves techniques such as: + * Validation: Enforce strict rules on allowed characters and data types. + * Filtering: Remove potentially malicious scripts or code fragments. + * Encoding: Convert special characters into safe representations. + * Verification: Run tooling that identifies potential script injections (e.g. [models that detect prompt injection attempts](https://python.langchain.com/docs/guides/safety/hugging_face_prompt_injection)). + +### Data privacy +To protect sensitive data from potential leaks or unauthorized access, it is essential to sandbox the model execution. This means running the model in a secure, isolated environment, which helps mitigate many attack vectors. + +When training the model with sensitive data, expose your newly-trained model to tests to identify potential sensitive data leaks. + +### Untrusted environments or networks + +If you can't run your models in a secure and isolated environment or if it must be exposed to an untrusted network, make sure to take the following security precautions: +* Confirm the hash of any downloaded artifact (i.e. pre-trained model weights) matches a known-good value +* Encrypt your data while sending it over the network. + +### Multi-Tenant environments + +If you intend to run multiple models in parallel with shared memory, it is your responsibility to ensure the models do not interact or access each other's data. The primary areas of concern are tenant isolation, resource allocation, model sharing and hardware attacks. + +#### Tenant Isolation + +You must make sure that models run separately. Since models can run code, it's important to use strong isolation methods to prevent unwanted access to the data from other tenants. + +Separating networks is also a big part of isolation. If you keep model network traffic separate, you not only prevent unauthorized access to data or models, but also prevent malicious users or tenants sending graphs to execute under another tenant’s identity. + +#### Resource Allocation + +A denial of service caused by one model can impact the overall system health. Implement safeguards like rate limits, access controls, and health monitoring. + +#### Model Sharing + +In a multitenant design that allows sharing models, make sure that tenants and users fully understand the potential security risks involved. They must be aware that they will essentially be running code provided by other users. Unfortunately, there are no reliable methods available to detect malicious models, graphs, or checkpoints. To mitigate this risk, the recommended approach is to sandbox the model execution, effectively isolating it from the rest of the system. + +#### Hardware Attacks + +Besides the virtual environment, the hardware (GPUs or TPUs) can also be attacked. [Research](https://scholar.google.com/scholar?q=gpu+side+channel) has shown that side channel attacks on GPUs are possible, which can make data leak from other models or processes running on the same system at the same time. + +## Reporting a Vulnerability + +Beware that none of the topics under [Using Keras Securely](#using-keras-securely) are considered vulnerabilities of Keras. + +If you have discovered a security vulnerability in this project, please report it +privately. **Do not disclose it as a public issue.** This gives us time to work with you +to fix the issue before public exposure, reducing the chance that the exploit will be +used before a patch is released. + +You may submit the report in the following ways: + +- send an email to francois.chollet@gmail.com; and/or +- send a [private vulnerability report](https://github.com/keras-team/keras/security/advisories/new) + +Please provide the following information in your report: + +- A description of the vulnerability and its impact +- How to reproduce the issue + +This project is maintained by volunteers on a reasonable-effort basis. As such, +please give us 90 days to work on a fix before public exposure. diff --git a/api_gen.py b/api_gen.py new file mode 100644 index 000000000000..daa4e9f2d579 --- /dev/null +++ b/api_gen.py @@ -0,0 +1,187 @@ +"""Script to generate keras public API in `keras/api` directory. + +Usage: + +Run via `./shell/api_gen.sh`. +It generates API and formats user and generated APIs. +""" + +import os +import re +import shutil + +import namex + +PACKAGE = "keras" +BUILD_DIR_NAME = "tmp_build_dir" + + +def ignore_files(_, filenames): + return [f for f in filenames if f.endswith("_test.py")] + + +def copy_source_to_build_directory(root_path): + # Copy sources (`keras/` directory and setup files) to build dir + build_dir = os.path.join(root_path, BUILD_DIR_NAME) + build_package_dir = os.path.join(build_dir, PACKAGE) + build_src_dir = os.path.join(build_package_dir, "src") + root_src_dir = os.path.join(root_path, PACKAGE, "src") + if os.path.exists(build_dir): + shutil.rmtree(build_dir) + os.makedirs(build_package_dir) + shutil.copytree(root_src_dir, build_src_dir) + return build_dir + + +def create_legacy_directory(package_dir): + src_dir = os.path.join(package_dir, "src") + # Make keras/_tf_keras/ by copying keras/ + tf_keras_dirpath_parent = os.path.join(package_dir, "_tf_keras") + tf_keras_dirpath = os.path.join(tf_keras_dirpath_parent, "keras") + os.makedirs(tf_keras_dirpath, exist_ok=True) + with open(os.path.join(tf_keras_dirpath_parent, "__init__.py"), "w") as f: + f.write("from keras._tf_keras import keras\n") + with open(os.path.join(package_dir, "__init__.py")) as f: + init_file = f.read() + init_file = init_file.replace( + "from keras import _legacy as _legacy", + "from keras import _tf_keras as _tf_keras", + ) + with open(os.path.join(package_dir, "__init__.py"), "w") as f: + f.write(init_file) + # Remove the import of `_tf_keras` in `keras/_tf_keras/keras/__init__.py` + init_file = init_file.replace("from keras import _tf_keras\n", "\n") + with open(os.path.join(tf_keras_dirpath, "__init__.py"), "w") as f: + f.write(init_file) + for dirname in os.listdir(package_dir): + dirpath = os.path.join(package_dir, dirname) + if os.path.isdir(dirpath) and dirname not in ( + "_legacy", + "_tf_keras", + "src", + ): + destpath = os.path.join(tf_keras_dirpath, dirname) + if os.path.exists(destpath): + shutil.rmtree(destpath) + shutil.copytree( + dirpath, + destpath, + ignore=ignore_files, + ) + + # Copy keras/_legacy/ file contents to keras/_tf_keras/keras + legacy_submodules = [ + path[:-3] + for path in os.listdir(os.path.join(src_dir, "legacy")) + if path.endswith(".py") + ] + legacy_submodules += [ + path + for path in os.listdir(os.path.join(src_dir, "legacy")) + if os.path.isdir(os.path.join(src_dir, "legacy", path)) + ] + for root, _, fnames in os.walk(os.path.join(package_dir, "_legacy")): + for fname in fnames: + if fname.endswith(".py"): + legacy_fpath = os.path.join(root, fname) + tf_keras_root = root.replace( + os.path.join(os.path.sep, "_legacy"), + os.path.join(os.path.sep, "_tf_keras", "keras"), + ) + core_api_fpath = os.path.join( + root.replace(os.path.join(os.path.sep, "_legacy"), ""), + fname, + ) + if not os.path.exists(tf_keras_root): + os.makedirs(tf_keras_root) + tf_keras_fpath = os.path.join(tf_keras_root, fname) + with open(legacy_fpath) as f: + legacy_contents = f.read() + legacy_contents = legacy_contents.replace( + "keras._legacy", "keras._tf_keras.keras" + ) + if os.path.exists(core_api_fpath): + with open(core_api_fpath) as f: + core_api_contents = f.read() + core_api_contents = core_api_contents.replace( + "from keras import _tf_keras as _tf_keras\n", "" + ) + for legacy_submodule in legacy_submodules: + core_api_contents = core_api_contents.replace( + f"from keras import {legacy_submodule} as {legacy_submodule}\n", # noqa: E501 + "", + ) + core_api_contents = core_api_contents.replace( + f"keras.{legacy_submodule}", + f"keras._tf_keras.keras.{legacy_submodule}", + ) + # Remove duplicate generated comments string. + legacy_contents = re.sub(r"\n", r"\\n", legacy_contents) + legacy_contents = re.sub('""".*"""', "", legacy_contents) + legacy_contents = re.sub(r"\\n", r"\n", legacy_contents) + # If the same module is in legacy and core_api, use legacy + legacy_imports = re.findall( + r"import (\w+)", legacy_contents + ) + for import_name in legacy_imports: + core_api_contents = re.sub( + f"\n.* import {import_name} as {import_name}\n", + r"\n", + core_api_contents, + ) + legacy_contents = f"{core_api_contents}\n{legacy_contents}" + with open(tf_keras_fpath, "w") as f: + f.write(legacy_contents) + + # Delete keras/api/_legacy/ + shutil.rmtree(os.path.join(package_dir, "_legacy")) + + +def export_version_string(api_init_fname): + with open(api_init_fname) as f: + contents = f.read() + with open(api_init_fname, "w") as f: + contents += "from keras.src.version import __version__ as __version__\n" + f.write(contents) + + +def build(): + root_path = os.path.dirname(os.path.abspath(__file__)) + code_api_dir = os.path.join(root_path, PACKAGE, "api") + # Create temp build dir + build_dir = copy_source_to_build_directory(root_path) + build_api_dir = os.path.join(build_dir, PACKAGE) + build_src_dir = os.path.join(build_api_dir, "src") + build_api_init_fname = os.path.join(build_api_dir, "__init__.py") + try: + os.chdir(build_dir) + open(build_api_init_fname, "w").close() + namex.generate_api_files( + "keras", + code_directory="src", + exclude_directories=[ + os.path.join("src", "backend", "jax"), + os.path.join("src", "backend", "openvino"), + os.path.join("src", "backend", "tensorflow"), + os.path.join("src", "backend", "torch"), + ], + ) + # Add __version__ to `api/`. + export_version_string(build_api_init_fname) + # Creates `_tf_keras` with full keras API + create_legacy_directory(package_dir=os.path.join(build_dir, PACKAGE)) + # Copy back the keras/api and keras/__init__.py from build directory + if os.path.exists(build_src_dir): + shutil.rmtree(build_src_dir) + if os.path.exists(code_api_dir): + shutil.rmtree(code_api_dir) + shutil.copytree( + build_api_dir, code_api_dir, ignore=shutil.ignore_patterns("src/") + ) + finally: + # Clean up: remove the build directory (no longer needed) + shutil.rmtree(build_dir) + + +if __name__ == "__main__": + build() diff --git a/keras/datasets/__init__.py b/benchmarks/__init__.py similarity index 100% rename from keras/datasets/__init__.py rename to benchmarks/__init__.py diff --git a/benchmarks/layer_benchmark/README.md b/benchmarks/layer_benchmark/README.md new file mode 100644 index 000000000000..6ca51d1fd23f --- /dev/null +++ b/benchmarks/layer_benchmark/README.md @@ -0,0 +1,16 @@ +# Benchmark the layer performance + +This directory contains benchmarks to compare the performance of +`keras.layers.XXX` and `tf.keras.layers.XXX`. We compare the performance of +both the forward pass and train step (forward & backward pass). + +To run the benchmark, use the command below and change the flags according to +your target: + +```shell +python3 -m benchmarks.layer_benchmark.conv_benchmark \ + --benchmark_name=benchmark_conv2D \ + --num_samples=2048 \ + --batch_size=256 \ + --jit_compile=True +``` \ No newline at end of file diff --git a/keras/layers/__init__.py b/benchmarks/layer_benchmark/__init__.py similarity index 100% rename from keras/layers/__init__.py rename to benchmarks/layer_benchmark/__init__.py diff --git a/benchmarks/layer_benchmark/activation_benchmark.py b/benchmarks/layer_benchmark/activation_benchmark.py new file mode 100644 index 000000000000..52d443f0c42e --- /dev/null +++ b/benchmarks/layer_benchmark/activation_benchmark.py @@ -0,0 +1,173 @@ +"""Benchmark activation layers. + +To run benchmarks, see the following command for an example, please change the +flag to your custom value: + +``` +python3 -m benchmarks.layer_benchmark.activation_benchmark \ + --benchmark_name=benchmark_elu \ + --num_samples=2048 \ + --batch_size=256 \ + --jit_compile=True +``` +""" + +from absl import app +from absl import flags + +from benchmarks.layer_benchmark.base_benchmark import LayerBenchmark + +FLAGS = flags.FLAGS + + +def benchmark_elu( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "ELU" + init_args = {} + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[256, 256], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_prelu( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "PReLU" + init_args = {} + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[256, 256], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_relu( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "ReLU" + init_args = {} + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[256, 256], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_leaky_relu( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "LeakyReLU" + init_args = {} + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[256, 256], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_softmax( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "Softmax" + init_args = {} + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[256, 256], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +BENCHMARK_NAMES = { + "benchmark_elu": benchmark_elu, + "benchmark_relu": benchmark_relu, + "benchmark_leaky_relu": benchmark_leaky_relu, + "benchmark_prelu": benchmark_prelu, + "benchmark_softmax": benchmark_softmax, +} + + +def main(_): + benchmark_name = FLAGS.benchmark_name + num_samples = FLAGS.num_samples + batch_size = FLAGS.batch_size + jit_compile = FLAGS.jit_compile + + if benchmark_name is None: + for name, benchmark_fn in BENCHMARK_NAMES.items(): + benchmark_fn(num_samples, batch_size, jit_compile) + return + + if benchmark_name not in BENCHMARK_NAMES: + raise ValueError( + f"Invalid benchmark name: {benchmark_name}, `benchmark_name` must " + f"be one of {BENCHMARK_NAMES.keys()}" + ) + benchmark_fn = BENCHMARK_NAMES[benchmark_name] + benchmark_fn(num_samples, batch_size, jit_compile) + + +if __name__ == "__main__": + app.run(main) diff --git a/benchmarks/layer_benchmark/attention_benchmark.py b/benchmarks/layer_benchmark/attention_benchmark.py new file mode 100644 index 000000000000..ab18c443ca06 --- /dev/null +++ b/benchmarks/layer_benchmark/attention_benchmark.py @@ -0,0 +1,132 @@ +"""Benchmark attention layers. + +To run benchmarks, see the following command for an example, please change the +flag to your custom value: + +``` +python3 -m benchmarks.layer_benchmark.attention_benchmark \ + --benchmark_name=benchmark_attention \ + --num_samples=2048 \ + --batch_size=256 \ + --jit_compile=True +``` +""" + +from absl import app +from absl import flags + +from benchmarks.layer_benchmark.base_benchmark import LayerBenchmark + +FLAGS = flags.FLAGS + + +def benchmark_attention( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "Attention" + init_args = {} + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[[256, 64], [256, 64]], + flat_call_inputs=False, + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_multi_head_attention( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "MultiHeadAttention" + init_args = { + "num_heads": 4, + "key_dim": 16, + } + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[[256, 64], [256, 64], [256, 64]], + flat_call_inputs=True, + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_additive_attention( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "AdditiveAttention" + init_args = {} + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[[256, 64], [256, 64], [256, 64]], + flat_call_inputs=False, + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +BENCHMARK_NAMES = { + "benchmark_attention": benchmark_attention, + "benchmark_multi_head_attention": benchmark_multi_head_attention, + "benchmark_additive_attention": benchmark_additive_attention, +} + + +def main(_): + benchmark_name = FLAGS.benchmark_name + num_samples = FLAGS.num_samples + batch_size = FLAGS.batch_size + jit_compile = FLAGS.jit_compile + + if benchmark_name is None: + for name, benchmark_fn in BENCHMARK_NAMES.items(): + benchmark_fn(num_samples, batch_size, jit_compile) + return + + if benchmark_name not in BENCHMARK_NAMES: + raise ValueError( + f"Invalid benchmark name: {benchmark_name}, `benchmark_name` must " + f"be one of {BENCHMARK_NAMES.keys()}" + ) + benchmark_fn = BENCHMARK_NAMES[benchmark_name] + benchmark_fn(num_samples, batch_size, jit_compile) + + +if __name__ == "__main__": + app.run(main) diff --git a/benchmarks/layer_benchmark/base_benchmark.py b/benchmarks/layer_benchmark/base_benchmark.py new file mode 100644 index 000000000000..e7e04d5cfd5a --- /dev/null +++ b/benchmarks/layer_benchmark/base_benchmark.py @@ -0,0 +1,280 @@ +import time + +import numpy as np +import tensorflow as tf +from absl import flags + +import keras + +FLAGS = flags.FLAGS + +flags.DEFINE_string( + "benchmark_name", + None, + "The name of benchmark to run. If None, all benchmarks in the file will be " + "run.", +) + +flags.DEFINE_integer( + "num_samples", + 1000, + "Number of input data samples.", +) + +flags.DEFINE_integer( + "batch_size", + 20, + "Batch size of data.", +) + +flags.DEFINE_bool( + "jit_compile", + True, + "If True, the benchmark will run with XLA compilation.", +) + + +class BenchmarkMetricsCallback: + def __init__(self, start_batch=1, stop_batch=None): + self.start_batch = start_batch + self.stop_batch = stop_batch + + self.state = {} + + def on_train_batch_begin(self, batch, logs=None): + if batch == self.start_batch: + self.state["benchmark_begin"] = time.time() + + def on_train_batch_end(self, batch, logs=None): + if batch == self.stop_batch: + self.state["benchmark_end"] = time.time() + throughput = (self.stop_batch - self.start_batch + 1) / ( + self.state["benchmark_end"] - self.state["benchmark_begin"] + ) + self.state["throughput"] = throughput + + def on_predict_batch_begin(self, batch, logs=None): + if batch == self.start_batch: + self.state["benchmark_begin"] = time.time() + + def on_predict_batch_end(self, batch, logs=None): + if batch == self.stop_batch: + self.state["benchmark_end"] = time.time() + throughput = (self.stop_batch - self.start_batch + 1) / ( + self.state["benchmark_end"] - self.state["benchmark_begin"] + ) + self.state["throughput"] = throughput + + +class KerasCoreBenchmarkMetricsCallback(keras.callbacks.Callback): + def __init__(self, start_batch=1, stop_batch=None): + self._callback = BenchmarkMetricsCallback(start_batch, stop_batch) + + def on_train_batch_begin(self, batch, logs=None): + self._callback.on_train_batch_begin(batch, logs) + + def on_train_batch_end(self, batch, logs=None): + self._callback.on_train_batch_end(batch, logs) + + def on_predict_batch_begin(self, batch, logs=None): + self._callback.on_predict_batch_begin(batch, logs) + + def on_predict_batch_end(self, batch, logs=None): + self._callback.on_predict_batch_end(batch, logs) + + +class TFKerasBenchmarkMetricsCallback(tf.keras.callbacks.Callback): + def __init__(self, start_batch=1, stop_batch=None): + self._callback = BenchmarkMetricsCallback(start_batch, stop_batch) + + def on_train_batch_begin(self, batch, logs=None): + self._callback.on_train_batch_begin(batch, logs) + + def on_train_batch_end(self, batch, logs=None): + self._callback.on_train_batch_end(batch, logs) + + def on_predict_batch_begin(self, batch, logs=None): + self._callback.on_predict_batch_begin(batch, logs) + + def on_predict_batch_end(self, batch, logs=None): + self._callback.on_predict_batch_end(batch, logs) + + +class LayerBenchmark: + def __init__( + self, + layer_name, + init_args, + input_shape, + flat_call_inputs=True, + jit_compile=True, + keras_layer=None, + tf_keras_layer=None, + ): + self.layer_name = layer_name + _keras_layer_class = getattr(keras.layers, layer_name) + _tf_keras_layer_class = getattr(tf.keras.layers, layer_name) + + if keras_layer is None: + # Sometimes you want to initialize the keras layer and tf_keras + # layer in a different way. For example, `Bidirectional` layer, + # which takes in `keras.layers.Layer` and + # `tf.keras.layer.Layer` separately. + self._keras_layer = _keras_layer_class(**init_args) + else: + self._keras_layer = keras_layer + + if tf_keras_layer is None: + self._tf_keras_layer = _tf_keras_layer_class(**init_args) + else: + self._tf_keras_layer = tf_keras_layer + + self.input_shape = input_shape + self._keras_model = self._build_keras_model( + input_shape, flat_call_inputs + ) + self._tf_keras_model = self._build_tf_keras_model( + input_shape, flat_call_inputs + ) + + self._keras_model.compile( + loss="mse", optimizer="sgd", jit_compile=jit_compile + ) + self._tf_keras_model.compile( + loss="mse", optimizer="sgd", jit_compile=jit_compile + ) + + self.flat_call_inputs = flat_call_inputs + self.jit_compile = jit_compile + self.input_shape = input_shape + + def _build_keras_model(self, input_shape, flat_call_inputs=True): + inputs = [] + if not isinstance(input_shape[0], (tuple, list)): + input_shape = [input_shape] + + for shape in input_shape: + inputs.append(keras.Input(shape=shape)) + + if flat_call_inputs: + outputs = self._keras_layer(*inputs) + else: + outputs = self._keras_layer(inputs) + return keras.Model(inputs=inputs, outputs=outputs) + + def _build_tf_keras_model(self, input_shape, flat_call_inputs=True): + inputs = [] + if not isinstance(input_shape[0], (tuple, list)): + input_shape = [input_shape] + + for shape in input_shape: + inputs.append(tf.keras.Input(shape=shape)) + + if flat_call_inputs: + outputs = self._tf_keras_layer(*inputs) + else: + outputs = self._tf_keras_layer(inputs) + return tf.keras.Model(inputs=inputs, outputs=outputs) + + def benchmark_predict(self, num_samples, batch_size, data=None): + if data is None: + # Generate default data if not provided. + if isinstance(self.input_shape[0], (tuple, list)): + # The layer has multiple inputs. + data = [] + for data_shape in self.input_shape: + data_shape = [num_samples] + list(data_shape) + data.append(np.random.normal(size=data_shape)) + else: + data_shape = [num_samples] + list(self.input_shape) + data = np.random.normal(size=data_shape) + + num_iterations = num_samples // batch_size - 1 + callback = KerasCoreBenchmarkMetricsCallback(stop_batch=num_iterations) + tf_keras_callback = TFKerasBenchmarkMetricsCallback( + stop_batch=num_iterations + ) + + self._keras_model.predict( + data, + batch_size=batch_size, + callbacks=[callback], + ) + + self._tf_keras_model.predict( + data, + batch_size=batch_size, + callbacks=[tf_keras_callback], + ) + + keras_throughput = callback._callback.state["throughput"] * batch_size + tf_keras_throughput = ( + tf_keras_callback._callback.state["throughput"] * batch_size + ) + print( + f"Keras 3 throughput of forward pass of {self.layer_name}: " + f"{keras_throughput:.2f} samples/sec." + ) + print( + f"TF Keras throughput of forward pass of {self.layer_name}: " + f"{tf_keras_throughput:.2f} samples/sec." + ) + + def benchmark_train(self, num_samples, batch_size, data=None, label=None): + if data is None: + # Generate default data if not provided. + if isinstance(self.input_shape[0], (tuple, list)): + # The layer has multiple inputs. + data = [] + for data_shape in self.input_shape: + data_shape = [num_samples] + list(data_shape) + data.append(np.random.normal(size=data_shape)) + else: + data_shape = [num_samples] + list(self.input_shape) + data = [np.random.normal(size=data_shape)] + + if label is None: + # Generate default label if not provided. + if self.flat_call_inputs: + # Scale by a small factor to avoid zero gradients. + label = ( + keras.backend.convert_to_numpy(self._keras_layer(*data)) + * 1.001 + ) + else: + label = ( + keras.backend.convert_to_numpy(self._keras_layer(data)) + * 1.001 + ) + + num_iterations = num_samples // batch_size - 1 + callback = KerasCoreBenchmarkMetricsCallback(stop_batch=num_iterations) + tf_keras_callback = TFKerasBenchmarkMetricsCallback( + stop_batch=num_iterations + ) + + self._keras_model.fit( + data, + label, + batch_size=batch_size, + callbacks=[callback], + ) + self._tf_keras_model.fit( + data, + label, + batch_size=batch_size, + callbacks=[tf_keras_callback], + ) + + keras_throughput = callback._callback.state["throughput"] * batch_size + tf_keras_throughput = ( + tf_keras_callback._callback.state["throughput"] * batch_size + ) + print( + f"Keras 3 throughput of forward & backward pass of " + f"{self.layer_name}: {keras_throughput:.2f} samples/sec." + ) + print( + f"TF Keras throughput of forward & backward pass of " + f"{self.layer_name}: {tf_keras_throughput:.2f} samples/sec." + ) diff --git a/benchmarks/layer_benchmark/conv_benchmark.py b/benchmarks/layer_benchmark/conv_benchmark.py new file mode 100644 index 000000000000..32d19f282fb3 --- /dev/null +++ b/benchmarks/layer_benchmark/conv_benchmark.py @@ -0,0 +1,340 @@ +"""Benchmark conv layers. + +To run benchmarks, see the following command for an example, please change the +flag to your custom value: + +``` +python3 -m benchmarks.layer_benchmark.conv_benchmark \ + --benchmark_name=benchmark_conv2D \ + --num_samples=2046 \ + --batch_size=256 \ + --jit_compile=True +``` +""" + +from absl import app +from absl import flags + +from benchmarks.layer_benchmark.base_benchmark import LayerBenchmark + +FLAGS = flags.FLAGS + + +def benchmark_conv1D( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "Conv1D" + init_args = { + "filters": 64, + "kernel_size": 2, + } + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[1024, 256], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_conv2D( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "Conv2D" + init_args = { + "filters": 16, + "kernel_size": 2, + } + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[128, 128, 4], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_conv3D( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "Conv3D" + init_args = { + "filters": 16, + "kernel_size": 2, + } + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[32, 32, 32, 4], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_depthwise_conv1D( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "DepthwiseConv1D" + init_args = { + "kernel_size": 16, + "depth_multiplier": 2, + } + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[256, 64], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_depthwise_conv2D( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "DepthwiseConv2D" + init_args = { + "kernel_size": 16, + "depth_multiplier": 2, + } + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[128, 128, 4], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_separable_conv1D( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "SeparableConv1D" + init_args = { + "kernel_size": 16, + "depth_multiplier": 2, + "filters": 3, + } + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[256, 64], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_separable_conv2D( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "SeparableConv2D" + init_args = { + "kernel_size": 16, + "depth_multiplier": 2, + "filters": 3, + } + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[128, 128, 4], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_conv1D_transpose( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "Conv1DTranspose" + init_args = { + "filters": 32, + "kernel_size": 4, + } + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[256, 256], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_conv2D_transpose( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "Conv2DTranspose" + init_args = { + "filters": 16, + "kernel_size": 2, + } + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[128, 128, 4], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_conv3D_transpose( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "Conv3DTranspose" + init_args = { + "filters": 16, + "kernel_size": 2, + } + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[32, 32, 32, 4], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +BENCHMARK_NAMES = { + "benchmark_conv1D": benchmark_conv1D, + "benchmark_conv2D": benchmark_conv2D, + "benchmark_conv3D": benchmark_conv3D, + "benchmark_depthwise_conv1D": benchmark_depthwise_conv1D, + "benchmark_depthwise_conv2D": benchmark_depthwise_conv2D, + "benchmark_separable_conv1D": benchmark_separable_conv1D, + "benchmark_separable_conv2D": benchmark_separable_conv2D, + "benchmark_conv1D_transpose": benchmark_conv1D_transpose, + "benchmark_conv2D_transpose": benchmark_conv2D_transpose, + "benchmark_conv3D_transpose": benchmark_conv3D_transpose, +} + + +def main(_): + benchmark_name = FLAGS.benchmark_name + num_samples = FLAGS.num_samples + batch_size = FLAGS.batch_size + jit_compile = FLAGS.jit_compile + + if benchmark_name is None: + for name, benchmark_fn in BENCHMARK_NAMES: + benchmark_fn(num_samples, batch_size, jit_compile) + return + + if benchmark_name not in BENCHMARK_NAMES: + raise ValueError( + f"Invalid benchmark name: {benchmark_name}, `benchmark_name` must " + f"be one of {BENCHMARK_NAMES.keys()}" + ) + benchmark_fn = BENCHMARK_NAMES[benchmark_name] + benchmark_fn(num_samples, batch_size, jit_compile) + + +if __name__ == "__main__": + app.run(main) diff --git a/benchmarks/layer_benchmark/core_benchmark.py b/benchmarks/layer_benchmark/core_benchmark.py new file mode 100644 index 000000000000..1291cbd082d7 --- /dev/null +++ b/benchmarks/layer_benchmark/core_benchmark.py @@ -0,0 +1,138 @@ +"""Benchmark core layers. + +To run benchmarks, see the following command for an example, please change the +flag to your custom value: + +``` +python3 -m benchmarks.layer_benchmark.core_benchmark \ + --benchmark_name=benchmark_dense \ + --num_samples=2048 \ + --batch_size=256 \ + --jit_compile=True +``` +""" + +import numpy as np +from absl import app +from absl import flags + +from benchmarks.layer_benchmark.base_benchmark import LayerBenchmark + +FLAGS = flags.FLAGS + + +def benchmark_dense( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "Dense" + init_args = {"units": 256} + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[256, 256], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_einsum_dense( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "EinsumDense" + init_args = { + "equation": "abc,cd->abd", + "output_shape": (None, 256), + } + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[256, 256], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_embedding( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "Embedding" + init_args = { + "input_dim": 128, + "output_dim": 256, + } + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[ + 256, + ], + jit_compile=jit_compile, + ) + + data = [np.random.randint(30, size=(num_samples, 256))] + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + data=data, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + data=data, + ) + + +BENCHMARK_NAMES = { + "benchmark_dense": benchmark_dense, + "benchmark_einsum_dense": benchmark_einsum_dense, + "benchmark_embedding": benchmark_embedding, +} + + +def main(_): + benchmark_name = FLAGS.benchmark_name + num_samples = FLAGS.num_samples + batch_size = FLAGS.batch_size + jit_compile = FLAGS.jit_compile + + if benchmark_name is None: + for name, benchmark_fn in BENCHMARK_NAMES.items(): + benchmark_fn(num_samples, batch_size, jit_compile) + return + + if benchmark_name not in BENCHMARK_NAMES: + raise ValueError( + f"Invalid benchmark name: {benchmark_name}, `benchmark_name` must " + f"be one of {BENCHMARK_NAMES.keys()}" + ) + benchmark_fn = BENCHMARK_NAMES[benchmark_name] + benchmark_fn(num_samples, batch_size, jit_compile) + + +if __name__ == "__main__": + app.run(main) diff --git a/benchmarks/layer_benchmark/merge_benchmark.py b/benchmarks/layer_benchmark/merge_benchmark.py new file mode 100644 index 000000000000..81795fb02e14 --- /dev/null +++ b/benchmarks/layer_benchmark/merge_benchmark.py @@ -0,0 +1,264 @@ +"""Benchmark merge layers. + +To run benchmarks, see the following command for an example, please change the +flag to your custom value: + +``` +python3 -m benchmarks.layer_benchmark.merge_benchmark \ + --benchmark_name=benchmark_add \ + --num_samples=2048 \ + --batch_size=256 \ + --jit_compile=True +``` +""" + +from absl import app +from absl import flags + +from benchmarks.layer_benchmark.base_benchmark import LayerBenchmark + +FLAGS = flags.FLAGS + + +def benchmark_add( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "Add" + init_args = {} + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[[256, 256], [256, 256]], + flat_call_inputs=False, + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_average( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "Average" + init_args = {} + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[[256, 256], [256, 256]], + flat_call_inputs=False, + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_concatenate( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "Concatenate" + init_args = {} + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[[256, 256], [256, 256]], + flat_call_inputs=False, + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_dot( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "Dot" + init_args = {"axes": [2, 1]} + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[[256, 32], [32, 64]], + flat_call_inputs=False, + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_maximum( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "Maximum" + init_args = {} + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[[256, 256], [256, 256]], + flat_call_inputs=False, + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_minimum( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "Minimum" + init_args = {} + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[[256, 256], [256, 256]], + flat_call_inputs=False, + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_multiply( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "Multiply" + init_args = {} + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[[256, 64], [256, 64]], + flat_call_inputs=False, + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_subtract( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "Subtract" + init_args = {} + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[[256, 256], [256, 256]], + flat_call_inputs=False, + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +BENCHMARK_NAMES = { + "benchmark_add": benchmark_add, + "benchmark_average": benchmark_average, + "benchmark_concatenate": benchmark_concatenate, + "benchmark_dot": benchmark_dot, + "benchmark_maximum": benchmark_maximum, + "benchmark_minimum": benchmark_minimum, + "benchmark_multiply": benchmark_multiply, + "benchmark_subtract": benchmark_subtract, +} + + +def main(_): + benchmark_name = FLAGS.benchmark_name + num_samples = FLAGS.num_samples + batch_size = FLAGS.batch_size + jit_compile = FLAGS.jit_compile + + if benchmark_name is None: + for name, benchmark_fn in BENCHMARK_NAMES.items(): + benchmark_fn(num_samples, batch_size, jit_compile) + return + + if benchmark_name not in BENCHMARK_NAMES: + raise ValueError( + f"Invalid benchmark name: {benchmark_name}, `benchmark_name` must " + f"be one of {BENCHMARK_NAMES.keys()}" + ) + benchmark_fn = BENCHMARK_NAMES[benchmark_name] + benchmark_fn(num_samples, batch_size, jit_compile) + + +if __name__ == "__main__": + app.run(main) diff --git a/benchmarks/layer_benchmark/normalization_benchmark.py b/benchmarks/layer_benchmark/normalization_benchmark.py new file mode 100644 index 000000000000..82e83a88c4d4 --- /dev/null +++ b/benchmarks/layer_benchmark/normalization_benchmark.py @@ -0,0 +1,154 @@ +"""Benchmark normalization layers. + +To run benchmarks, see the following command for an example, please change the +flag to your custom value: + +``` +python3 -m benchmarks.layer_benchmark.normalization_benchmark \ + --benchmark_name=benchmark_batch_normalization \ + --num_samples=2048 \ + --batch_size=256 \ + --jit_compile=True +``` +""" + +from absl import app +from absl import flags + +from benchmarks.layer_benchmark.base_benchmark import LayerBenchmark + +FLAGS = flags.FLAGS + + +def benchmark_batch_normalization( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "BatchNormalization" + init_args = {} + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[256, 256, 4], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_group_normalization( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "GroupNormalization" + init_args = { + "groups": 2, + } + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[256, 256, 4], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_layer_normalization( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "LayerNormalization" + init_args = {} + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[256, 128, 4], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_unit_normalization( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "UnitNormalization" + init_args = {} + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[256, 128, 4], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +BENCHMARK_NAMES = { + "benchmark_batch_normalization": benchmark_batch_normalization, + "benchmark_group_normalization": benchmark_group_normalization, + "benchmark_layer_normalization": benchmark_layer_normalization, + "benchmark_unit_normalization": benchmark_unit_normalization, +} + + +def main(_): + benchmark_name = FLAGS.benchmark_name + num_samples = FLAGS.num_samples + batch_size = FLAGS.batch_size + jit_compile = FLAGS.jit_compile + + if benchmark_name is None: + for name, benchmark_fn in BENCHMARK_NAMES.items(): + benchmark_fn(num_samples, batch_size, jit_compile) + return + + if benchmark_name not in BENCHMARK_NAMES: + raise ValueError( + f"Invalid benchmark name: {benchmark_name}, `benchmark_name` must " + f"be one of {BENCHMARK_NAMES.keys()}" + ) + benchmark_fn = BENCHMARK_NAMES[benchmark_name] + benchmark_fn(num_samples, batch_size, jit_compile) + + +if __name__ == "__main__": + app.run(main) diff --git a/benchmarks/layer_benchmark/pooling_benchmark.py b/benchmarks/layer_benchmark/pooling_benchmark.py new file mode 100644 index 000000000000..c64c986927d1 --- /dev/null +++ b/benchmarks/layer_benchmark/pooling_benchmark.py @@ -0,0 +1,372 @@ +"""Benchmark pooling layers. + +To run benchmarks, see the following command for an example, please change the +flag to your custom value: + +``` +python3 -m benchmarks.layer_benchmark.pooling_benchmark \ + --benchmark_name=benchmark_max_pooling1d \ + --num_samples=2048 \ + --batch_size=256 \ + --jit_compile=True +``` +""" + +from absl import app +from absl import flags + +from benchmarks.layer_benchmark.base_benchmark import LayerBenchmark + +FLAGS = flags.FLAGS + + +def benchmark_average_pooling1d( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "AveragePooling1D" + init_args = { + "pool_size": 2, + } + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[1024, 256], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_average_pooling2d( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "AveragePooling2D" + init_args = { + "pool_size": 2, + } + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[256, 256, 3], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_average_pooling3d( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "AveragePooling3D" + init_args = { + "pool_size": 2, + } + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[64, 64, 32, 3], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_max_pooling1d( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "MaxPooling1D" + init_args = { + "pool_size": 2, + } + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[1024, 256], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_max_pooling2d( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "MaxPooling2D" + init_args = { + "pool_size": 2, + } + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[256, 256, 3], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_max_pooling3d( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "MaxPooling3D" + init_args = { + "pool_size": 2, + } + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[64, 64, 32, 3], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_global_average_pooling1d( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "GlobalAveragePooling1D" + init_args = {} + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[1024, 256], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_global_average_pooling2d( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "GlobalAveragePooling2D" + init_args = {} + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[256, 256, 3], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_global_average_pooling3d( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "GlobalAveragePooling3D" + init_args = {} + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[64, 64, 32, 3], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_global_max_pooling1d( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "GlobalMaxPooling1D" + init_args = {} + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[1024, 256], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_global_max_pooling2d( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "GlobalMaxPooling2D" + init_args = {} + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[256, 256, 3], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_global_max_pooling3d( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "GlobalMaxPooling3D" + init_args = {} + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[64, 64, 32, 3], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +BENCHMARK_NAMES = { + "benchmark_average_pooling1d": benchmark_average_pooling1d, + "benchmark_average_pooling2d": benchmark_average_pooling2d, + "benchmark_average_pooling3d": benchmark_average_pooling3d, + "benchmark_max_pooling1d": benchmark_max_pooling1d, + "benchmark_max_pooling2d": benchmark_max_pooling2d, + "benchmark_max_pooling3d": benchmark_max_pooling3d, + "benchmark_global_average_pooling1d": benchmark_global_average_pooling1d, + "benchmark_global_average_pooling2d": benchmark_global_average_pooling2d, + "benchmark_global_average_pooling3d": benchmark_global_average_pooling3d, + "benchmark_global_max_pooling1d": benchmark_global_max_pooling1d, + "benchmark_global_max_pooling2d": benchmark_global_max_pooling2d, + "benchmark_global_max_pooling3d": benchmark_global_max_pooling3d, +} + + +def main(_): + benchmark_name = FLAGS.benchmark_name + num_samples = FLAGS.num_samples + batch_size = FLAGS.batch_size + jit_compile = FLAGS.jit_compile + + if benchmark_name is None: + for name, benchmark_fn in BENCHMARK_NAMES.items(): + benchmark_fn(num_samples, batch_size, jit_compile) + return + + if benchmark_name not in BENCHMARK_NAMES: + raise ValueError( + f"Invalid benchmark name: {benchmark_name}, `benchmark_name` must " + f"be one of {BENCHMARK_NAMES.keys()}" + ) + benchmark_fn = BENCHMARK_NAMES[benchmark_name] + benchmark_fn(num_samples, batch_size, jit_compile) + + +if __name__ == "__main__": + app.run(main) diff --git a/benchmarks/layer_benchmark/regularization_benchmark.py b/benchmarks/layer_benchmark/regularization_benchmark.py new file mode 100644 index 000000000000..9e15e92752a0 --- /dev/null +++ b/benchmarks/layer_benchmark/regularization_benchmark.py @@ -0,0 +1,216 @@ +"""Benchmark regularization layers. + +To run benchmarks, see the following command for an example, please change the +flag to your custom value: + +``` +python3 -m benchmarks.layer_benchmark.regularization_benchmark \ + --benchmark_name=benchmark_dropout\ + --num_samples=2048 \ + --batch_size=256 \ + --jit_compile=True +``` +""" + +from absl import app +from absl import flags + +from benchmarks.layer_benchmark.base_benchmark import LayerBenchmark + +FLAGS = flags.FLAGS + + +def benchmark_dropout( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "Dropout" + init_args = { + "rate": 0.5, + } + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[256, 256, 4], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_gaussian_dropout( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "GaussianDropout" + init_args = { + "rate": 0.5, + } + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[256, 256, 4], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_gaussian_noise( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "GaussianNoise" + init_args = { + "stddev": 0.5, + } + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[256, 256, 4], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_spatial_dropout1D( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "SpatialDropout1D" + init_args = { + "rate": 0.5, + } + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[256, 3], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_spatial_dropout2D( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "SpatialDropout2D" + init_args = { + "rate": 0.5, + } + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[256, 256, 3], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_spatial_dropout3D( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "SpatialDropout3D" + init_args = { + "rate": 0.5, + } + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[32, 32, 32, 3], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +BENCHMARK_NAMES = { + "benchmark_dropout": benchmark_dropout, + "benchmark_gaussian_dropout": benchmark_gaussian_dropout, + "benchmark_gaussian_noise": benchmark_gaussian_noise, + "benchmark_spatial_dropout1D": benchmark_spatial_dropout1D, + "benchmark_spatial_dropout2D": benchmark_spatial_dropout2D, + "benchmark_spatial_dropout3D": benchmark_spatial_dropout3D, +} + + +def main(_): + benchmark_name = FLAGS.benchmark_name + num_samples = FLAGS.num_samples + batch_size = FLAGS.batch_size + jit_compile = FLAGS.jit_compile + + if benchmark_name is None: + for name, benchmark_fn in BENCHMARK_NAMES.items(): + benchmark_fn(num_samples, batch_size, jit_compile) + return + + if benchmark_name not in BENCHMARK_NAMES: + raise ValueError( + f"Invalid benchmark name: {benchmark_name}, `benchmark_name` must " + f"be one of {BENCHMARK_NAMES.keys()}" + ) + benchmark_fn = BENCHMARK_NAMES[benchmark_name] + benchmark_fn(num_samples, batch_size, jit_compile) + + +if __name__ == "__main__": + app.run(main) diff --git a/benchmarks/layer_benchmark/reshaping_benchmark.py b/benchmarks/layer_benchmark/reshaping_benchmark.py new file mode 100644 index 000000000000..e0336d5ff72d --- /dev/null +++ b/benchmarks/layer_benchmark/reshaping_benchmark.py @@ -0,0 +1,336 @@ +"""Benchmark reshaping layers. + +To run benchmarks, see the following command for an example, please change the +flag to your custom value: + +``` +python3 -m benchmarks.layer_benchmark.reshaping_benchmark \ + --benchmark_name=benchmark_cropping2d \ + --num_samples=2048 \ + --batch_size=256 \ + --jit_compile=True +``` +""" + +from absl import app +from absl import flags + +from benchmarks.layer_benchmark.base_benchmark import LayerBenchmark + +FLAGS = flags.FLAGS + + +def benchmark_cropping1d( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "Cropping1D" + init_args = {} + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[1024, 256], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_cropping2d( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "Cropping2D" + init_args = {} + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[256, 256, 3], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_cropping3d( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "Cropping3D" + init_args = {} + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[32, 32, 32, 3], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_flatten( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "Flatten" + init_args = {} + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[256, 256, 3], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_permute( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "Permute" + init_args = { + "dims": (2, 1), + } + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[256, 256], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_up_sampling1d( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "UpSampling1D" + init_args = {} + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[256, 3], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_up_sampling2d( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "UpSampling2D" + init_args = {} + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[128, 128, 3], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_up_sampling3d( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "UpSampling3D" + init_args = {} + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[32, 16, 16, 3], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_zero_padding1d( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "ZeroPadding1D" + init_args = {} + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[256, 3], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_zero_padding2d( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "ZeroPadding2D" + init_args = {} + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[256, 256, 3], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_zero_padding3d( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "ZeroPadding3D" + init_args = {} + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[32, 32, 32, 3], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +BENCHMARK_NAMES = { + "benchmark_cropping1d": benchmark_cropping1d, + "benchmark_cropping2d": benchmark_cropping2d, + "benchmark_cropping3d": benchmark_cropping3d, + "benchmark_flatten": benchmark_flatten, + "benchmark_permute": benchmark_permute, + "benchmark_up_sampling1d": benchmark_up_sampling1d, + "benchmark_up_sampling2d": benchmark_up_sampling2d, + "benchmark_up_sampling3d": benchmark_up_sampling3d, + "benchmark_zero_padding1d": benchmark_zero_padding1d, + "benchmark_zero_padding2d": benchmark_zero_padding2d, + "benchmark_zero_padding3d": benchmark_zero_padding3d, +} + + +def main(_): + benchmark_name = FLAGS.benchmark_name + num_samples = FLAGS.num_samples + batch_size = FLAGS.batch_size + jit_compile = FLAGS.jit_compile + + if benchmark_name is None: + for name, benchmark_fn in BENCHMARK_NAMES.items(): + benchmark_fn(num_samples, batch_size, jit_compile) + return + + if benchmark_name not in BENCHMARK_NAMES: + raise ValueError( + f"Invalid benchmark name: {benchmark_name}, `benchmark_name` must " + f"be one of {BENCHMARK_NAMES.keys()}" + ) + benchmark_fn = BENCHMARK_NAMES[benchmark_name] + benchmark_fn(num_samples, batch_size, jit_compile) + + +if __name__ == "__main__": + app.run(main) diff --git a/benchmarks/layer_benchmark/rnn_benchmark.py b/benchmarks/layer_benchmark/rnn_benchmark.py new file mode 100644 index 000000000000..321399c2dc2f --- /dev/null +++ b/benchmarks/layer_benchmark/rnn_benchmark.py @@ -0,0 +1,283 @@ +"""Benchmark rnn layers. + +To run benchmarks, see the following command for an example, please change the +flag to your custom value: + +``` +python3 -m benchmarks.layer_benchmark.rnn_benchmark \ + --benchmark_name=benchmark_lstm \ + --num_samples=2048 \ + --batch_size=256 \ + --jit_compile=True +``` +""" + +import tensorflow as tf +from absl import app +from absl import flags + +import keras +from benchmarks.layer_benchmark.base_benchmark import LayerBenchmark + +FLAGS = flags.FLAGS + + +def benchmark_conv_lstm1d( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "ConvLSTM1D" + init_args = { + "filters": 16, + "kernel_size": 2, + } + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[32, 256, 3], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_conv_lstm2d( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "ConvLSTM2D" + init_args = { + "filters": 16, + "kernel_size": 2, + } + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[32, 32, 32, 3], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_conv_lstm3d( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "ConvLSTM3D" + init_args = { + "filters": 8, + "kernel_size": 2, + } + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[8, 16, 16, 16, 3], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_gru( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "GRU" + init_args = { + "units": 32, + } + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[256, 256], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_lstm( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "LSTM" + init_args = { + "units": 32, + } + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[256, 256], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_simple_rnn( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "SimpleRNN" + init_args = { + "units": 32, + } + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[256, 256], + jit_compile=jit_compile, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_bidirectional( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "Bidirectional" + init_args = {} + keras_layer = keras.layers.Bidirectional(keras.layers.LSTM(32)) + tf_keras_layer = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(32)) + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[256, 256], + jit_compile=jit_compile, + keras_layer=keras_layer, + tf_keras_layer=tf_keras_layer, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +def benchmark_time_distributed( + num_samples, + batch_size, + jit_compile=True, +): + layer_name = "TimeDistributed" + init_args = {} + keras_layer = keras.layers.TimeDistributed(keras.layers.Conv2D(16, (3, 3))) + tf_keras_layer = tf.keras.layers.TimeDistributed( + tf.keras.layers.Conv2D(16, (3, 3)) + ) + benchmark = LayerBenchmark( + layer_name, + init_args, + input_shape=[10, 32, 32, 3], + jit_compile=jit_compile, + keras_layer=keras_layer, + tf_keras_layer=tf_keras_layer, + ) + + benchmark.benchmark_predict( + num_samples=num_samples, + batch_size=batch_size, + ) + + benchmark.benchmark_train( + num_samples=num_samples, + batch_size=batch_size, + ) + + +BENCHMARK_NAMES = { + "benchmark_conv_lstm1d": benchmark_conv_lstm1d, + "benchmark_conv_lstm2d": benchmark_conv_lstm2d, + "benchmark_conv_lstm3d": benchmark_conv_lstm3d, + "benchmark_gru": benchmark_gru, + "benchmark_lstm": benchmark_lstm, + "benchmark_simple_rnn": benchmark_simple_rnn, + "benchmark_bidirectional": benchmark_bidirectional, + "benchmark_time_distributed": benchmark_time_distributed, +} + + +def main(_): + benchmark_name = FLAGS.benchmark_name + num_samples = FLAGS.num_samples + batch_size = FLAGS.batch_size + jit_compile = FLAGS.jit_compile + + if benchmark_name is None: + for name, benchmark_fn in BENCHMARK_NAMES.items(): + benchmark_fn(num_samples, batch_size, jit_compile) + return + + if benchmark_name not in BENCHMARK_NAMES: + raise ValueError( + f"Invalid benchmark name: {benchmark_name}, `benchmark_name` must " + f"be one of {BENCHMARK_NAMES.keys()}" + ) + benchmark_fn = BENCHMARK_NAMES[benchmark_name] + benchmark_fn(num_samples, batch_size, jit_compile) + + +if __name__ == "__main__": + app.run(main) diff --git a/keras/preprocessing/__init__.py b/benchmarks/model_benchmark/__init__.py similarity index 100% rename from keras/preprocessing/__init__.py rename to benchmarks/model_benchmark/__init__.py diff --git a/benchmarks/model_benchmark/benchmark_utils.py b/benchmarks/model_benchmark/benchmark_utils.py new file mode 100644 index 000000000000..dafba9205669 --- /dev/null +++ b/benchmarks/model_benchmark/benchmark_utils.py @@ -0,0 +1,24 @@ +import time + +import keras + + +class BenchmarkMetricsCallback(keras.callbacks.Callback): + def __init__(self, start_batch=1, stop_batch=None): + self.start_batch = start_batch + self.stop_batch = stop_batch + + # Store the throughput of each epoch. + self.state = {"throughput": []} + + def on_train_batch_begin(self, batch, logs=None): + if batch == self.start_batch: + self.state["epoch_begin_time"] = time.time() + + def on_train_batch_end(self, batch, logs=None): + if batch == self.stop_batch: + epoch_end_time = time.time() + throughput = (self.stop_batch - self.start_batch + 1) / ( + epoch_end_time - self.state["epoch_begin_time"] + ) + self.state["throughput"].append(throughput) diff --git a/benchmarks/model_benchmark/bert_benchmark.py b/benchmarks/model_benchmark/bert_benchmark.py new file mode 100644 index 000000000000..d589baa52d98 --- /dev/null +++ b/benchmarks/model_benchmark/bert_benchmark.py @@ -0,0 +1,161 @@ +"""Benchmark BERT model on GLUE/MRPC task. + +To run the script, make sure you are in benchmarks/ directory, abd run the +command below: +``` +python3 -m model_benchmark.bert_benchmark \ + --epochs 2 \ + --batch_size 32 +``` + +""" + +import time + +import keras_nlp +import numpy as np +import tensorflow as tf +import tensorflow_datasets as tfds +from absl import app +from absl import flags +from absl import logging +from model_benchmark.benchmark_utils import BenchmarkMetricsCallback + +import keras + +flags.DEFINE_string("model_size", "small", "The size of model to benchmark.") +flags.DEFINE_string( + "mixed_precision_policy", + "mixed_float16", + "The global precision policy to use, e.g., 'mixed_float16' or 'float32'.", +) +flags.DEFINE_integer("epochs", 2, "The number of epochs.") +flags.DEFINE_integer("batch_size", 8, "Batch Size.") + + +FLAGS = flags.FLAGS + + +MODEL_SIZE_MAP = { + "tiny": "bert_tiny_en_uncased", + "small": "bert_small_en_uncased", + "base": "bert_base_en_uncased", + "large": "bert_large_en_uncased", +} + + +def load_data(): + """Load data. + + Load GLUE/MRPC dataset, and convert the dictionary format to + (features, label), where `features` is a tuple of all input sentences. + """ + feature_names = ("sentence1", "sentence2") + + def split_features(x): + # GLUE comes with dictionary data, we convert it to a uniform format + # (features, label), where features is a tuple consisting of all + # features. This format is necessary for using KerasNLP preprocessors. + features = tuple([x[name] for name in feature_names]) + label = x["label"] + return (features, label) + + train_ds, test_ds, validation_ds = tfds.load( + "glue/mrpc", + split=["train", "test", "validation"], + ) + + train_ds = ( + train_ds.map(split_features, num_parallel_calls=tf.data.AUTOTUNE) + .batch(FLAGS.batch_size) + .prefetch(tf.data.AUTOTUNE) + ) + test_ds = ( + test_ds.map(split_features, num_parallel_calls=tf.data.AUTOTUNE) + .batch(FLAGS.batch_size) + .prefetch(tf.data.AUTOTUNE) + ) + validation_ds = ( + validation_ds.map(split_features, num_parallel_calls=tf.data.AUTOTUNE) + .batch(FLAGS.batch_size) + .prefetch(tf.data.AUTOTUNE) + ) + return train_ds, test_ds, validation_ds + + +def load_model(): + if FLAGS.model_size not in MODEL_SIZE_MAP.keys(): + raise KeyError( + f"`model_size` must be one of {MODEL_SIZE_MAP.keys()}, but " + f"received {FLAGS.model_size}." + ) + return keras_nlp.models.BertClassifier.from_preset( + MODEL_SIZE_MAP[FLAGS.model_size], num_classes=2 + ) + + +def main(_): + keras.mixed_precision.set_dtype_policy(FLAGS.mixed_precision_policy) + + logging.info( + "Benchmarking configs...\n" + "=========================\n" + f"MODEL: BERT {FLAGS.model_size}\n" + f"TASK: glue/mrpc \n" + f"BATCH_SIZE: {FLAGS.batch_size}\n" + f"EPOCHS: {FLAGS.epochs}\n" + "=========================\n" + ) + + # Load datasets. + train_ds, test_ds, validation_ds = load_data() + + # Load the model. + model = load_model() + # Set loss and metrics. + loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True) + metrics = [keras.metrics.SparseCategoricalAccuracy()] + # Configure optimizer. + lr = keras.optimizers.schedules.PolynomialDecay( + 5e-4, + decay_steps=train_ds.cardinality() * FLAGS.epochs, + end_learning_rate=0.0, + ) + optimizer = keras.optimizers.AdamW(lr, weight_decay=0.01) + optimizer.exclude_from_weight_decay( + var_names=["LayerNorm", "layer_norm", "bias"] + ) + + model.compile(optimizer=optimizer, loss=loss, metrics=metrics) + + benchmark_metrics_callback = BenchmarkMetricsCallback( + start_batch=1, + stop_batch=train_ds.cardinality().numpy() - 1, + ) + + # Start training. + logging.info("Starting Training...") + + st = time.time() + history = model.fit( + train_ds, + validation_data=validation_ds, + epochs=FLAGS.epochs, + callbacks=[benchmark_metrics_callback], + ) + + wall_time = time.time() - st + validation_accuracy = history.history["val_sparse_categorical_accuracy"][-1] + examples_per_second = ( + np.mean(np.array(benchmark_metrics_callback.state["throughput"])) + * FLAGS.batch_size + ) + + logging.info("Training Finished!") + logging.info(f"Wall Time: {wall_time:.4f} seconds.") + logging.info(f"Validation Accuracy: {validation_accuracy:.4f}") + logging.info(f"examples_per_second: {examples_per_second:.4f}") + + +if __name__ == "__main__": + app.run(main) diff --git a/benchmarks/model_benchmark/image_classification_benchmark.py b/benchmarks/model_benchmark/image_classification_benchmark.py new file mode 100644 index 000000000000..5fb6ec9fa71c --- /dev/null +++ b/benchmarks/model_benchmark/image_classification_benchmark.py @@ -0,0 +1,163 @@ +"""Image classification benchmark. + +This script runs image classification benchmark with "dogs vs cats" datasets. +It supports the following 3 models: + +- EfficientNetV2B0 +- Xception +- ResNet50V2 + +To run the benchmark, make sure you are in model_benchmark/ directory, and run +the command below: + +python3 -m model_benchmark.image_classification_benchmark \ + --model="EfficientNetV2B0" \ + --epochs=2 \ + --batch_size=32 \ + --mixed_precision_policy="mixed_float16" +""" + +import time + +import numpy as np +import tensorflow as tf +import tensorflow_datasets as tfds +from absl import app +from absl import flags +from absl import logging +from model_benchmark.benchmark_utils import BenchmarkMetricsCallback + +import keras + +flags.DEFINE_string("model", "EfficientNetV2B0", "The model to benchmark.") +flags.DEFINE_integer("epochs", 1, "The number of epochs.") +flags.DEFINE_integer("batch_size", 4, "Batch Size.") +flags.DEFINE_string( + "mixed_precision_policy", + "mixed_float16", + "The global precision policy to use, e.g., 'mixed_float16' or 'float32'.", +) + +FLAGS = flags.FLAGS + +BATCH_SIZE = 32 +IMAGE_SIZE = (224, 224) +CHANNELS = 3 + +MODEL_MAP = { + "EfficientNetV2B0": keras.applications.EfficientNetV2B0, + "Xception": keras.applications.Xception, + "ResNet50V2": keras.applications.ResNet50V2, +} + + +def load_data(): + # Load cats vs dogs dataset, and split into train and validation sets. + train_dataset, val_dataset = tfds.load( + "cats_vs_dogs", split=["train[:90%]", "train[90%:]"], as_supervised=True + ) + + resizing = keras.layers.Resizing( + IMAGE_SIZE[0], IMAGE_SIZE[1], crop_to_aspect_ratio=True + ) + + def preprocess_inputs(image, label): + image = tf.cast(image, "float32") + return resizing(image), label + + train_dataset = ( + train_dataset.map( + preprocess_inputs, num_parallel_calls=tf.data.AUTOTUNE + ) + .batch(FLAGS.batch_size) + .prefetch(tf.data.AUTOTUNE) + ) + val_dataset = ( + val_dataset.map(preprocess_inputs, num_parallel_calls=tf.data.AUTOTUNE) + .batch(FLAGS.batch_size) + .cache() + .prefetch(tf.data.AUTOTUNE) + ) + return train_dataset, val_dataset + + +def load_model(): + model_class = MODEL_MAP[FLAGS.model] + # Load the EfficientNetV2B0 model and add a classification head. + model = model_class(include_top=False, weights="imagenet") + classifier = keras.models.Sequential( + [ + keras.Input([IMAGE_SIZE[0], IMAGE_SIZE[1], CHANNELS]), + model, + keras.layers.GlobalAveragePooling2D(), + keras.layers.Dense(2), + ] + ) + return classifier + + +def main(_): + keras.mixed_precision.set_dtype_policy(FLAGS.mixed_precision_policy) + + logging.info( + "Benchmarking configs...\n" + "=========================\n" + f"MODEL: {FLAGS.model}\n" + f"TASK: image classification/dogs-vs-cats \n" + f"BATCH_SIZE: {FLAGS.batch_size}\n" + f"EPOCHS: {FLAGS.epochs}\n" + "=========================\n" + ) + + # Load datasets. + train_ds, validation_ds = load_data() + + # Load the model. + classifier = load_model() + + lr = keras.optimizers.schedules.PolynomialDecay( + 5e-4, + decay_steps=train_ds.cardinality() * FLAGS.epochs, + end_learning_rate=0.0, + ) + optimizer = keras.optimizers.Adam(lr) + loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True) + + benchmark_metrics_callback = BenchmarkMetricsCallback( + start_batch=1, + stop_batch=train_ds.cardinality().numpy() - 1, + ) + + classifier.compile( + optimizer=optimizer, + loss=loss, + metrics=["sparse_categorical_accuracy"], + ) + # Start training. + logging.info("Starting Training...") + + st = time.time() + + history = classifier.fit( + train_ds, + validation_data=validation_ds, + epochs=FLAGS.epochs, + callbacks=[benchmark_metrics_callback], + ) + + wall_time = time.time() - st + validation_accuracy = history.history["val_sparse_categorical_accuracy"][-1] + + examples_per_second = ( + np.mean(np.array(benchmark_metrics_callback.state["throughput"])) + * FLAGS.batch_size + ) + + logging.info("Training Finished!") + logging.info(f"Wall Time: {wall_time:.4f} seconds.") + logging.info(f"Validation Accuracy: {validation_accuracy:.4f}") + logging.info(f"examples_per_second: {examples_per_second:.4f}") + + +if __name__ == "__main__": + app.run(main) diff --git a/benchmarks/torch_ctl_benchmark/README.md b/benchmarks/torch_ctl_benchmark/README.md new file mode 100644 index 000000000000..fa7cc6566a65 --- /dev/null +++ b/benchmarks/torch_ctl_benchmark/README.md @@ -0,0 +1,13 @@ +# Benchmark the performance of torch custom training loop + +This directory contains benchmarks to compare the performance of a Keras model +and a equivalent Torch model while using the same Torch custom training loop. + +The benchmark purpose is to understand the performance diff resulting from the +modeling API choice (Keras or Torch). + +To run the benchmark, use the command below and change to your target: + +```shell +python3 -m benchmarks.torch_ctl_benchmark.conv_model_benchmark +``` \ No newline at end of file diff --git a/keras/utils/__init__.py b/benchmarks/torch_ctl_benchmark/__init__.py similarity index 100% rename from keras/utils/__init__.py rename to benchmarks/torch_ctl_benchmark/__init__.py diff --git a/benchmarks/torch_ctl_benchmark/benchmark_utils.py b/benchmarks/torch_ctl_benchmark/benchmark_utils.py new file mode 100644 index 000000000000..f60b24954474 --- /dev/null +++ b/benchmarks/torch_ctl_benchmark/benchmark_utils.py @@ -0,0 +1,36 @@ +import time + +import numpy as np +import torch + + +def train_loop(model, train_loader, num_epochs, optimizer, loss_fn, framework): + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + model.to(device) + start = None + average_batch_time_per_epoch = [] + for _ in range(num_epochs): + running_loss = 0.0 + for batch_idx, (inputs, targets) in enumerate(train_loader): + if batch_idx == 1: + start = time.time() + inputs = inputs.to(device) + targets = targets.to(device) + # Forward pass + outputs = model(inputs) + loss = loss_fn(outputs, targets) + + # Backward and optimize + optimizer.zero_grad() + loss.backward() + optimizer.step() + + running_loss += loss.item() + + end = time.time() + average_batch_time_per_epoch.append( + (end - start) / (len(train_loader) - 1) + ) + average_time = np.mean(average_batch_time_per_epoch) + + print(f"Time per batch in {framework}: {average_time:.2f}") diff --git a/benchmarks/torch_ctl_benchmark/conv_model_benchmark.py b/benchmarks/torch_ctl_benchmark/conv_model_benchmark.py new file mode 100644 index 000000000000..9ac0c2c56313 --- /dev/null +++ b/benchmarks/torch_ctl_benchmark/conv_model_benchmark.py @@ -0,0 +1,98 @@ +"""Benchmark Keras performance with torch custom training loop. + +In this file we use a convolution model. Training loop is written in the +vanilla torch way, and we compare the performance between building model with +Keras and torch. +""" + +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim + +import keras +from benchmarks.torch_ctl_benchmark.benchmark_utils import train_loop +from keras import layers + +num_classes = 2 +input_shape = (3, 256, 256) +batch_size = 128 +num_batches = 20 +num_epochs = 1 + +x_train = np.random.normal( + size=(num_batches * batch_size, *input_shape) +).astype(np.float32) +y_train = np.random.randint(0, num_classes, size=(num_batches * batch_size,)) + +# Create a TensorDataset +dataset = torch.utils.data.TensorDataset( + torch.from_numpy(x_train), torch.from_numpy(y_train) +) +# Create a DataLoader +train_loader = torch.utils.data.DataLoader( + dataset, batch_size=batch_size, shuffle=False +) + + +class TorchModel(torch.nn.Module): + def __init__(self): + super().__init__() + + self.conv = torch.nn.Conv2d(3, 32, kernel_size=(3, 3)) + self.activation = torch.nn.ReLU() + self.max_pool = torch.nn.MaxPool2d((2, 2)) + self.flatten = torch.nn.Flatten() + self.dense = torch.nn.LazyLinear(num_classes) + self.softmax = torch.nn.Softmax(dim=1) + + def forward(self, x): + x = self.conv(x) + x = self.activation(x) + x = self.max_pool(x) + x = self.flatten(x) + x = self.dense(x) + x = self.softmax(x) + return x + + +def run_keras_custom_training_loop(): + keras_model = keras.Sequential( + [ + layers.Input(shape=input_shape), + layers.Conv2D(32, kernel_size=(3, 3), activation="relu"), + layers.MaxPooling2D(pool_size=(2, 2)), + layers.Flatten(), + layers.Dense(num_classes), + layers.Softmax(), + ] + ) + optimizer = optim.Adam(keras_model.parameters(), lr=0.001) + loss_fn = nn.CrossEntropyLoss() + train_loop( + keras_model, + train_loader, + num_epochs=num_epochs, + optimizer=optimizer, + loss_fn=loss_fn, + framework="keras", + ) + + +def run_torch_custom_training_loop(): + torch_model = TorchModel() + optimizer = optim.Adam(torch_model.parameters(), lr=0.001) + loss_fn = nn.CrossEntropyLoss() + train_loop( + torch_model, + train_loader, + num_epochs=num_epochs, + optimizer=optimizer, + loss_fn=loss_fn, + framework="torch", + ) + + +if __name__ == "__main__": + run_keras_custom_training_loop() + run_torch_custom_training_loop() diff --git a/benchmarks/torch_ctl_benchmark/dense_model_benchmark.py b/benchmarks/torch_ctl_benchmark/dense_model_benchmark.py new file mode 100644 index 000000000000..4ecb382f4f34 --- /dev/null +++ b/benchmarks/torch_ctl_benchmark/dense_model_benchmark.py @@ -0,0 +1,97 @@ +"""Benchmark Keras performance with torch custom training loop. + +In this file we use a model with 3 dense layers. Training loop is written in the +vanilla torch way, and we compare the performance between building model with +Keras and torch. +""" + +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim + +import keras +from benchmarks.torch_ctl_benchmark.benchmark_utils import train_loop +from keras import layers + +num_classes = 2 +input_shape = (8192,) +batch_size = 4096 +num_batches = 20 +num_epochs = 1 + +x_train = np.random.normal( + size=(num_batches * batch_size, *input_shape) +).astype(np.float32) +y_train = np.random.randint(0, num_classes, size=(num_batches * batch_size,)) + +# Create a TensorDataset +dataset = torch.utils.data.TensorDataset( + torch.from_numpy(x_train), torch.from_numpy(y_train) +) +# Create a DataLoader +train_loader = torch.utils.data.DataLoader( + dataset, batch_size=batch_size, shuffle=False +) + + +class TorchModel(torch.nn.Module): + def __init__(self): + super().__init__() + + self.dense1 = torch.nn.Linear(8192, 64) + self.activation1 = torch.nn.ReLU() + self.dense2 = torch.nn.Linear(64, 8) + self.activation2 = torch.nn.ReLU() + self.dense3 = torch.nn.Linear(8, num_classes) + self.softmax = torch.nn.Softmax(dim=1) + + def forward(self, x): + x = self.dense1(x) + x = self.activation1(x) + x = self.dense2(x) + x = self.activation2(x) + x = self.dense3(x) + x = self.softmax(x) + return x + + +def run_keras_custom_training_loop(): + keras_model = keras.Sequential( + [ + layers.Input(shape=input_shape), + layers.Dense(64, activation="relu"), + layers.Dense(8, activation="relu"), + layers.Dense(num_classes), + layers.Softmax(), + ] + ) + optimizer = optim.Adam(keras_model.parameters(), lr=0.001) + loss_fn = nn.CrossEntropyLoss() + train_loop( + keras_model, + train_loader, + num_epochs=num_epochs, + optimizer=optimizer, + loss_fn=loss_fn, + framework="keras", + ) + + +def run_torch_custom_training_loop(): + torch_model = TorchModel() + optimizer = optim.Adam(torch_model.parameters(), lr=0.001) + loss_fn = nn.CrossEntropyLoss() + train_loop( + torch_model, + train_loader, + num_epochs=num_epochs, + optimizer=optimizer, + loss_fn=loss_fn, + framework="torch", + ) + + +if __name__ == "__main__": + run_keras_custom_training_loop() + run_torch_custom_training_loop() diff --git a/codecov.yml b/codecov.yml new file mode 100644 index 000000000000..d6453385fe4f --- /dev/null +++ b/codecov.yml @@ -0,0 +1,35 @@ +coverage: + status: + project: + default: + # `auto` compares coverage with the base-commit + target: auto + + patch: + default: + target:auto + +comment: + layout: "header, reach, diff, flags, files" + behavior: default + require_changes: no + require_base: no + require_head: yes + show_carryforward_flags: yes + +flag_management: + default_rules: + carryforward: false + statuses: + - type: project + target: auto + - type: patch + target: auto + individual_flags: + - name: keras + paths: + - keras + - name: keras.applications + paths: + - keras/applications + carryforward: true diff --git a/conftest.py b/conftest.py new file mode 100644 index 000000000000..9853ff86baf1 --- /dev/null +++ b/conftest.py @@ -0,0 +1,55 @@ +try: + # When using torch and tensorflow, torch needs to be imported first, + # otherwise it will segfault upon import. This should force the torch + # import to happen first for all tests. + import torch # noqa: F401 +except ImportError: + torch = None + +import pytest # noqa: E402 + +from keras.src.backend import backend # noqa: E402 + + +def pytest_configure(config): + config.addinivalue_line( + "markers", + "requires_trainable_backend: mark test for trainable backend only", + ) + + +def pytest_collection_modifyitems(config, items): + openvino_skipped_tests = [] + if backend() == "openvino": + with open( + "keras/src/backend/openvino/excluded_concrete_tests.txt", "r" + ) as file: + openvino_skipped_tests = file.readlines() + # it is necessary to check if stripped line is not empty + # and exclude such lines + openvino_skipped_tests = [ + line.strip() for line in openvino_skipped_tests if line.strip() + ] + + requires_trainable_backend = pytest.mark.skipif( + backend() in ["numpy", "openvino"], + reason="Trainer not implemented for NumPy and OpenVINO backend.", + ) + for item in items: + if "requires_trainable_backend" in item.keywords: + item.add_marker(requires_trainable_backend) + # also, skip concrete tests for openvino, listed in the special file + # this is more granular mechanism to exclude tests rather + # than using --ignore option + for skipped_test in openvino_skipped_tests: + if skipped_test in item.nodeid: + item.add_marker( + skip_if_backend( + "openvino", + "Not supported operation by openvino backend", + ) + ) + + +def skip_if_backend(given_backend, reason): + return pytest.mark.skipif(backend() == given_backend, reason=reason) diff --git a/examples/cifar10_cnn.py b/examples/cifar10_cnn.py deleted file mode 100644 index ab878d88b465..000000000000 --- a/examples/cifar10_cnn.py +++ /dev/null @@ -1,117 +0,0 @@ -from keras.datasets import cifar10 -from keras.preprocessing.image import ImageDataGenerator -from keras.models import Sequential -from keras.layers.core import Dense, Dropout, Activation, Flatten -from keras.layers.convolutional import Convolution2D, MaxPooling2D -from keras.optimizers import SGD, Adadelta, Adagrad -from keras.utils import np_utils, generic_utils - -''' - Train a (fairly simple) deep CNN on the CIFAR10 small images dataset. - - GPU run command: - THEANO_FLAGS=mode=FAST_RUN,device=gpu,floatX=float32 python cifar10_cnn.py - - It gets down to 0.65 test logloss in 25 epochs, and down to 0.55 after 50 epochs. - (it's still underfitting at that point, though). -''' - -batch_size = 32 -nb_classes = 10 -nb_epoch = 200 -data_augmentation = True - -# the data, shuffled and split between tran and test sets -(X_train, y_train), (X_test, y_test) = cifar10.load_data(test_split=0.1) -print X_train.shape[0], 'train samples' -print X_test.shape[0], 'test samples' - -# convert class vectors to binary class matrices -Y_train = np_utils.to_categorical(y_train, nb_classes) -Y_test = np_utils.to_categorical(y_test, nb_classes) - -model = Sequential() - -model.add(Convolution2D(32, 3, 3, 3, border_mode='full')) -model.add(Activation('relu')) -model.add(Convolution2D(32, 32, 3, 3)) -model.add(Activation('relu')) -model.add(MaxPooling2D(poolsize=(2, 2))) -model.add(Dropout(0.25)) - -model.add(Convolution2D(64, 32, 3, 3, border_mode='full')) -model.add(Activation('relu')) -model.add(Convolution2D(64, 64, 3, 3)) -model.add(Activation('relu')) -model.add(MaxPooling2D(poolsize=(2, 2))) -model.add(Dropout(0.25)) - -model.add(Flatten(64*8*8)) -model.add(Dense(64*8*8, 512, init='normal')) -model.add(Activation('relu')) -model.add(Dropout(0.5)) - -model.add(Dense(512, nb_classes, init='normal')) -model.add(Activation('softmax')) - -# let's train the model using SGD + momentum (how original). -sgd = SGD(lr=0.01, decay=1e-7, momentum=0.9, nesterov=True) -model.compile(loss='categorical_crossentropy', optimizer=sgd) - -if not data_augmentation: - print "Not using data augmentation or normalization" - - X_train = X_train.astype("float32") - X_test = X_train.astype("float32") - X_train /= 255 - X_test /= 255 - print X_train[0] - model.fit(X_train, Y_train, batch_size=batch_size, nb_epoch=10) - score = model.evaluate(X_test, Y_test, batch_size=batch_size) - print 'Test score:', score - -else: - print "Using real time data augmentation" - - # this will do preprocessing and realtime data augmentation - datagen = ImageDataGenerator( - featurewise_center=True, # set input mean to 0 over the dataset - samplewise_center=False, # set each sample mean to 0 - featurewise_std_normalization=True, # divide inputs by std of the dataset - samplewise_std_normalization=False, # divide each input by its std - zca_whitening=False, # apply ZCA whitening - rotation_range=20, # randomly rotate images in the range (degrees, 0 to 180) - width_shift_range=0.2, # randomly shift images horizontally (fraction of total width) - height_shift_range=0.2, # randomly shift images vertically (fraction of total height) - horizontal_flip=True, # randomly flip images - vertical_flip=False) # randomly flip images - - # compute quantities required for featurewise normalization - # (std, mean, and principal components if ZCA whitening is applied) - datagen.fit(X_train) - - for e in range(nb_epoch): - print '-'*40 - print 'Epoch', e - print '-'*40 - print "Training..." - # batch train with realtime data augmentation - progbar = generic_utils.Progbar(X_train.shape[0]) - for X_batch, Y_batch in datagen.flow(X_train, Y_train): - loss = model.train(X_batch, Y_batch) - progbar.add(X_batch.shape[0], values=[("train loss", loss)]) - - print "Testing..." - # test time! - progbar = generic_utils.Progbar(X_test.shape[0]) - for X_batch, Y_batch in datagen.flow(X_test, Y_test): - score = model.test(X_batch, Y_batch) - progbar.add(X_batch.shape[0], values=[("test loss", score)]) - - - - - - - - diff --git a/examples/demo_custom_jax_workflow.py b/examples/demo_custom_jax_workflow.py new file mode 100644 index 000000000000..196262888d32 --- /dev/null +++ b/examples/demo_custom_jax_workflow.py @@ -0,0 +1,121 @@ +# flake8: noqa +import os + +# Set backend env to JAX +os.environ["KERAS_BACKEND"] = "jax" + +import jax +import numpy as np + +from keras import Model +from keras import backend +from keras import initializers +from keras import layers +from keras import ops +from keras import optimizers + + +class MyDense(layers.Layer): + def __init__(self, units, name=None): + super().__init__(name=name) + self.units = units + + def build(self, input_shape): + input_dim = input_shape[-1] + w_shape = (input_dim, self.units) + w_value = initializers.GlorotUniform()(w_shape) + self.w = backend.Variable(w_value, name="kernel") + + b_shape = (self.units,) + b_value = initializers.Zeros()(b_shape) + self.b = backend.Variable(b_value, name="bias") + + def call(self, inputs): + return ops.matmul(inputs, self.w) + self.b + + +class MyModel(Model): + def __init__(self, hidden_dim, output_dim): + super().__init__() + self.dense1 = MyDense(hidden_dim) + self.dense2 = MyDense(hidden_dim) + self.dense3 = MyDense(output_dim) + + def call(self, x): + x = jax.nn.relu(self.dense1(x)) + x = jax.nn.relu(self.dense2(x)) + return self.dense3(x) + + +def Dataset(): + for _ in range(20): + yield (np.random.random((32, 128)), np.random.random((32, 4))) + + +def loss_fn(y_true, y_pred): + return ops.sum((y_true - y_pred) ** 2) + + +model = MyModel(hidden_dim=256, output_dim=4) + +optimizer = optimizers.SGD(learning_rate=0.001) +dataset = Dataset() + +# Build model +x = np.random.random((1, 128)) +model(x) +# Build optimizer +optimizer.build(model.trainable_variables) + + +######### Custom JAX workflow ############### + + +def compute_loss_and_updates( + trainable_variables, non_trainable_variables, x, y +): + y_pred, non_trainable_variables = model.stateless_call( + trainable_variables, non_trainable_variables, x + ) + loss = loss_fn(y, y_pred) + return loss, non_trainable_variables + + +grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True) + + +@jax.jit +def train_step(state, data): + trainable_variables, non_trainable_variables, optimizer_variables = state + x, y = data + (loss, non_trainable_variables), grads = grad_fn( + trainable_variables, non_trainable_variables, x, y + ) + trainable_variables, optimizer_variables = optimizer.stateless_apply( + optimizer_variables, grads, trainable_variables + ) + # Return updated state + return loss, ( + trainable_variables, + non_trainable_variables, + optimizer_variables, + ) + + +trainable_variables = model.trainable_variables +non_trainable_variables = model.non_trainable_variables +optimizer_variables = optimizer.variables +state = trainable_variables, non_trainable_variables, optimizer_variables +# Training loop +for data in dataset: + loss, state = train_step(state, data) + print("Loss:", loss) + +# Post-processing model state update +trainable_variables, non_trainable_variables, optimizer_variables = state +for variable, value in zip(model.trainable_variables, trainable_variables): + variable.assign(value) +for variable, value in zip( + model.non_trainable_variables, non_trainable_variables +): + variable.assign(value) diff --git a/examples/demo_custom_layer_backend_agnostic.py b/examples/demo_custom_layer_backend_agnostic.py new file mode 100644 index 000000000000..b3849c20cb50 --- /dev/null +++ b/examples/demo_custom_layer_backend_agnostic.py @@ -0,0 +1,87 @@ +import numpy as np + +import keras +from keras import Model +from keras import initializers +from keras import layers +from keras import losses +from keras import metrics +from keras import ops +from keras import optimizers + + +class MyDense(layers.Layer): + def __init__(self, units, name=None): + super().__init__(name=name) + self.units = units + + def build(self, input_shape): + input_dim = input_shape[-1] + self.w = self.add_weight( + shape=(input_dim, self.units), + initializer=initializers.GlorotNormal(), + name="kernel", + trainable=True, + ) + + self.b = self.add_weight( + shape=(self.units,), + initializer=initializers.Zeros(), + name="bias", + trainable=True, + ) + + def call(self, inputs): + # Use Keras ops to create backend-agnostic layers/metrics/etc. + return ops.matmul(inputs, self.w) + self.b + + +class MyDropout(layers.Layer): + def __init__(self, rate, name=None): + super().__init__(name=name) + self.rate = rate + # Use seed_generator for managing RNG state. + # It is a state element and its seed variable is + # tracked as part of `layer.variables`. + self.seed_generator = keras.random.SeedGenerator(1337) + + def call(self, inputs): + # Use `keras.random` for random ops. + return keras.random.dropout(inputs, self.rate, seed=self.seed_generator) + + +class MyModel(Model): + def __init__(self, hidden_dim, output_dim): + super().__init__() + self.dense1 = MyDense(hidden_dim) + self.dense2 = MyDense(hidden_dim) + self.dense3 = MyDense(output_dim) + self.dp = MyDropout(0.5) + + def call(self, x): + x1 = self.dense1(x) + x2 = self.dense2(x) + # Why not use some ops here as well + x = ops.concatenate([x1, x2], axis=-1) + x = self.dp(x) + return self.dense3(x) + + +model = MyModel(hidden_dim=256, output_dim=16) + +x = np.random.random((50000, 128)) +y = np.random.random((50000, 16)) +batch_size = 32 +epochs = 5 + +model.compile( + optimizer=optimizers.SGD(learning_rate=0.001), + loss=losses.MeanSquaredError(), + metrics=[metrics.MeanSquaredError()], +) +history = model.fit(x, y, batch_size=batch_size, epochs=epochs) + +model.summary() + +print("History:") +print(history.history) diff --git a/examples/demo_custom_tf_workflow.py b/examples/demo_custom_tf_workflow.py new file mode 100644 index 000000000000..b8fc2b7b6d78 --- /dev/null +++ b/examples/demo_custom_tf_workflow.py @@ -0,0 +1,84 @@ +# flake8: noqa +import os + +# Set backend env to tensorflow +os.environ["KERAS_BACKEND"] = "tensorflow" + +import numpy as np +import tensorflow as tf + +from keras import Model +from keras import backend +from keras import initializers +from keras import layers +from keras import ops +from keras import optimizers + + +class MyDense(layers.Layer): + def __init__(self, units, name=None): + super().__init__(name=name) + self.units = units + + def build(self, input_shape): + input_dim = input_shape[-1] + w_shape = (input_dim, self.units) + w_value = initializers.GlorotUniform()(w_shape) + self.w = backend.Variable(w_value, name="kernel") + + b_shape = (self.units,) + b_value = initializers.Zeros()(b_shape) + self.b = backend.Variable(b_value, name="bias") + + def call(self, inputs): + return ops.matmul(inputs, self.w) + self.b + + +class MyModel(Model): + def __init__(self, hidden_dim, output_dim): + super().__init__() + self.dense1 = MyDense(hidden_dim) + self.dense2 = MyDense(hidden_dim) + self.dense3 = MyDense(output_dim) + + def call(self, x): + x = tf.nn.relu(self.dense1(x)) + x = tf.nn.relu(self.dense2(x)) + return self.dense3(x) + + +def Dataset(): + for _ in range(20): + yield ( + np.random.random((32, 128)).astype("float32"), + np.random.random((32, 4)).astype("float32"), + ) + + +def loss_fn(y_true, y_pred): + return ops.sum((y_true - y_pred) ** 2) + + +model = MyModel(hidden_dim=256, output_dim=4) + +optimizer = optimizers.SGD(learning_rate=0.001) +dataset = Dataset() + + +######### Custom TF workflow ############### + + +@tf.function(jit_compile=True) +def train_step(data): + x, y = data + with tf.GradientTape() as tape: + y_pred = model(x) + loss = loss_fn(y, y_pred) + gradients = tape.gradient(loss, model.trainable_variables) + optimizer.apply_gradients(zip(gradients, model.trainable_variables)) + return loss + + +for data in dataset: + loss = train_step(data) + print("Loss:", float(loss)) diff --git a/examples/demo_custom_torch_workflow.py b/examples/demo_custom_torch_workflow.py new file mode 100644 index 000000000000..ebd0b51a26c8 --- /dev/null +++ b/examples/demo_custom_torch_workflow.py @@ -0,0 +1,130 @@ +# flake8: noqa +import os + +# Set backend env to torch +os.environ["KERAS_BACKEND"] = "torch" + +import torch +import torch.nn as nn +import torch.optim as optim +from keras import layers +import keras +import numpy as np + +# Model / data parameters +num_classes = 10 +input_shape = (28, 28, 1) +learning_rate = 0.01 +batch_size = 64 +num_epochs = 1 + +# Load the data and split it between train and test sets +(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data() + +# Scale images to the [0, 1] range +x_train = x_train.astype("float32") / 255 +x_test = x_test.astype("float32") / 255 +# Make sure images have shape (28, 28, 1) +x_train = np.expand_dims(x_train, -1) +x_test = np.expand_dims(x_test, -1) +print("x_train shape:", x_train.shape) +print(x_train.shape[0], "train samples") +print(x_test.shape[0], "test samples") + +# Create the Keras model +model = keras.Sequential( + [ + layers.Input(shape=(28, 28, 1)), + layers.Conv2D(32, kernel_size=(3, 3), activation="relu"), + layers.MaxPooling2D(pool_size=(2, 2)), + layers.Conv2D(64, kernel_size=(3, 3), activation="relu"), + layers.MaxPooling2D(pool_size=(2, 2)), + layers.Flatten(), + layers.Dropout(0.5), + layers.Dense(num_classes), + ] +) + +################################################################# +######## Writing a torch training loop for a Keras model ######## +################################################################# + +# Instantiate the torch optimizer +optimizer = optim.Adam(model.parameters(), lr=learning_rate) + +# Instantiate the torch loss function +loss_fn = nn.CrossEntropyLoss() + + +def train(model, train_loader, num_epochs, optimizer, loss_fn): + for epoch in range(num_epochs): + running_loss = 0.0 + for batch_idx, (inputs, targets) in enumerate(train_loader): + # Forward pass + outputs = model(inputs) + loss = loss_fn(outputs, targets) + + # Backward and optimize + optimizer.zero_grad() + loss.backward() + optimizer.step() + + running_loss += loss.item() + + # Print loss statistics + if (batch_idx + 1) % 10 == 0: + print( + f"Epoch [{epoch + 1}/{num_epochs}], " + f"Batch [{batch_idx + 1}/{len(train_loader)}], " + f"Loss: {running_loss / 10}" + ) + running_loss = 0.0 + + +# Create a TensorDataset +dataset = torch.utils.data.TensorDataset( + torch.from_numpy(x_train), torch.from_numpy(y_train) +) + +# Create a DataLoader +train_loader = torch.utils.data.DataLoader( + dataset, batch_size=batch_size, shuffle=False +) + +train(model, train_loader, num_epochs, optimizer, loss_fn) + + +################################################################ +######## Using a Keras model or layer in a torch Module ######## +################################################################ + + +class MyModel(nn.Module): + def __init__(self): + super().__init__() + self.model = keras.Sequential( + [ + layers.Input(shape=(28, 28, 1)), + layers.Conv2D(32, kernel_size=(3, 3), activation="relu"), + layers.MaxPooling2D(pool_size=(2, 2)), + layers.Conv2D(64, kernel_size=(3, 3), activation="relu"), + layers.MaxPooling2D(pool_size=(2, 2)), + layers.Flatten(), + layers.Dropout(0.5), + layers.Dense(num_classes), + ] + ) + + def forward(self, x): + return self.model(x) + + +torch_module = MyModel() + +# Instantiate the torch optimizer +optimizer = optim.Adam(torch_module.parameters(), lr=learning_rate) + +# Instantiate the torch loss function +loss_fn = nn.CrossEntropyLoss() + +train(torch_module, train_loader, num_epochs, optimizer, loss_fn) diff --git a/examples/demo_functional.py b/examples/demo_functional.py new file mode 100644 index 000000000000..0c7f7ce487e6 --- /dev/null +++ b/examples/demo_functional.py @@ -0,0 +1,59 @@ +import numpy as np + +from keras import Model +from keras import layers +from keras import losses +from keras import metrics +from keras import optimizers +import keras + +keras.config.disable_traceback_filtering() + +inputs = layers.Input((100,)) +x = layers.Dense(512, activation="relu")(inputs) +residual = x +x = layers.Dense(512, activation="relu")(x) +x = layers.Dense(512, activation="relu")(x) +x += residual +x = layers.Dense(512, activation="relu")(x) +residual = x +x = layers.Dense(512, activation="relu")(x) +x = layers.Dense(512, activation="relu")(x) +x += residual +residual = x +x = layers.Dense(512, activation="relu")(x) +x = layers.Dense(512, activation="relu")(x) +x += residual +outputs = layers.Dense(16)(x) +model = Model(inputs, outputs) + +model.summary() + +x = np.random.random((50000, 100)) +y = np.random.random((50000, 16)) +batch_size = 32 +epochs = 5 + +model.compile( + optimizer=optimizers.Adam(learning_rate=0.001), + loss=losses.MeanSquaredError(), + metrics=[ + metrics.CategoricalAccuracy(name="acc"), + metrics.MeanSquaredError(name="mse"), + ], +) + +print("\nTrain model") +history = model.fit( + x, y, batch_size=batch_size, epochs=epochs, validation_split=0.2 +) +print("\nHistory:") +print(history.history) + +print("\nEvaluate model") +scores = model.evaluate(x, y, return_dict=True) +print(scores) + +print("\nRun inference") +pred = model.predict(x) +print(f"Inferred output shape {pred.shape}") diff --git a/examples/demo_jax_distributed.py b/examples/demo_jax_distributed.py new file mode 100644 index 000000000000..906dc47563de --- /dev/null +++ b/examples/demo_jax_distributed.py @@ -0,0 +1,342 @@ +# To run this demo, you will need to spin up a "TPU VM" on Google Cloud. +# Please follow instructions here: https://cloud.google.com/tpu/docs/run-calculation-jax + +# Force a JAX backend +import os, pprint, collections + +os.environ["KERAS_BACKEND"] = "jax" + +pp = pprint.PrettyPrinter() + +import jax +import jax.numpy as jnp +import tensorflow as tf # just for tf.data +import keras # Keras multi-backend + +import numpy as np +from tqdm import tqdm + +from jax.experimental import mesh_utils +from jax.sharding import Mesh +from jax.sharding import NamedSharding +from jax.sharding import PartitionSpec as P + +""" Dataset +Classic MNIST, loaded using tf.data +""" + +BATCH_SIZE = 192 + +( + (x_train, train_labels), + (x_eval, eval_labels), +) = keras.datasets.mnist.load_data() +x_train = np.expand_dims(x_train, axis=-1).astype( + np.float32 +) # from 28x28 to 28x28 x 1 color channel (B&W) +x_eval = np.expand_dims(x_eval, axis=-1).astype(np.float32) + +train_data = tf.data.Dataset.from_tensor_slices((x_train, train_labels)) +train_data = train_data.shuffle(5000, reshuffle_each_iteration=True) +train_data = train_data.batch(BATCH_SIZE, drop_remainder=True) +train_data = train_data.repeat() + +eval_data = tf.data.Dataset.from_tensor_slices((x_eval, eval_labels)) +eval_data = eval_data.batch(10000) # everything as one batch + +STEPS_PER_EPOCH = len(train_labels) // BATCH_SIZE + +""" Keras model +Simple but non-trivial model with: +* Batch Normalization (non-trainable state updated during training, different training-time and inference behavior) +* Dropout (randomness, different training time and inference behavior) +""" + + +# Keras "sequential" model building style +def make_backbone(): + return keras.Sequential( + [ + keras.layers.Rescaling( + 1.0 / 255.0 + ), # input images are in the range [0, 255] + keras.layers.Conv2D( + filters=12, kernel_size=3, padding="same", use_bias=False + ), + keras.layers.BatchNormalization(scale=False, center=True), + keras.layers.Activation("relu"), + keras.layers.Conv2D( + filters=24, + kernel_size=6, + padding="same", + use_bias=False, + strides=2, + ), + keras.layers.BatchNormalization(scale=False, center=True), + keras.layers.Activation("relu"), + keras.layers.Conv2D( + filters=32, + kernel_size=6, + padding="same", + use_bias=False, + strides=2, + name="large_k", + ), + keras.layers.BatchNormalization(scale=False, center=True), + keras.layers.Activation("relu"), + ], + name="backbone", + ) + + +def make_model(): + input = keras.Input(shape=[28, 28, 1]) + y = make_backbone()(input) + y = keras.layers.Flatten()(y) + y = keras.layers.Dense(200, activation="relu")(y) + y = keras.layers.Dropout(0.4)(y) + y = keras.layers.Dense(10, activation="softmax")(y) + model = keras.Model(inputs=input, outputs=y) + return model + + +""" JAX-native distribution with a Keras model +For now, you have to write a custom training loop for this +Note: The features required by jax.sharding are not supported by the Colab TPU +runtime at this time, but are available on Cloud TPU VMs and Kaggle TPU VMs. +""" + +if len(jax.local_devices()) < 8: + raise Exception("This part requires 8 devices to run") +else: + print("\nIdentified local devices:") + pp.pprint(jax.local_devices()) + +# ----------------- Keras --------------------- + +# instantiate the model +model = make_model() + +# learning rate +lr = keras.optimizers.schedules.ExponentialDecay(0.01, STEPS_PER_EPOCH, 0.6) + +# optimizer +optimizer = keras.optimizers.Adam(lr) + +# initialize all state with .build() +(one_batch, one_batch_labels) = next(iter(train_data)) +model.build(one_batch) +optimizer.build(model.trainable_variables) + +""" Distribution settings + +* Sharding the data on the batch axis +* Replicating all model variables + +Note: this implements standard "data parallel" distributed training + +* Just for show, sharding the largest convolutional kernel along the + "channels" axis 4-ways and replicating 2-ways + +Note: this does not reflect a best practice but is intended to show + that you can split a very large kernel across multiple devices + if you have to +""" + +print( + "\nMostly data-parallel distribution. " + "Data is sharded across devices while the model is replicated. " + "For demo purposes, we split the largest kernel 4-ways " + "(and replicate 2-ways since we have 8 devices)." +) + +# ------------------ Jax ---------------------- + +devices = mesh_utils.create_device_mesh((8,)) + +# data will be split along the batch axis +data_mesh = Mesh(devices, axis_names=("batch",)) # naming axes of the mesh +# naming axes of the sharded partition +data_sharding = NamedSharding( + data_mesh, + P( + "batch", + ), +) +# all variables will be replicated on all devices +var_mesh = Mesh(devices, axis_names=("_")) +# in NamedSharding, axes that are not mentioned are replicated (all axes here) +var_replication = NamedSharding(var_mesh, P()) + +# for the demo, we will split the largest kernel 4-ways (and replicate 2-ways since we have 8 devices) +large_kernel_mesh = Mesh( + devices.reshape((-1, 4)), axis_names=(None, "out_chan") +) # naming axes of the mesh +large_kernel_sharding = NamedSharding( + large_kernel_mesh, P(None, None, None, "out_chan") +) # naming axes of the sharded partition + +# ----------------- Keras --------------------- + +# Use Keras APIs to find the variable of a specific layer (we will be sharding this one in a special way) +# In a Conv2D or Dense layer, the variables are 'kernel' and 'bias' +special_layer_var = model.get_layer("backbone").get_layer("large_k").kernel + +# ------------------ Jax ---------------------- +# - accessing variables in Keras lists model.trainable_variables, +# - model.non_trainable_variables and optimizer.variables + +# Apply the distribution settings to the model variables +non_trainable_variables = jax.device_put( + model.non_trainable_variables, var_replication +) +optimizer_variables = jax.device_put(optimizer.variables, var_replication) +# this is what you would do replicate all trainable variables: +# trainable_variables = jax.device_put(model.trainable_variables, var_replication) + +# For the demo, we split the largest kernel 4-ways instead of replicating it. +# We still replicate all other trainable variables as in standard "data-parallel" +# distributed training. +print_once = True +trainable_variables = model.trainable_variables +for i, v in enumerate(trainable_variables): + if v is special_layer_var: + # Apply distribution settings: sharding + sharded_v = jax.device_put(v, large_kernel_sharding) + trainable_variables[i] = sharded_v + + print("Sharding of convolutional", v.name, v.shape) + jax.debug.visualize_array_sharding( + jnp.reshape(sharded_v, [-1, v.shape[-1]]) + ) + else: + # Apply distribution settings: replication + replicated_v = jax.device_put(v, var_replication) + trainable_variables[i] = replicated_v + + if print_once: + print_once = False + print( + "\nSharding of all other model variables (they are replicated)" + ) + jax.debug.visualize_array_sharding( + jnp.reshape(replicated_v, [-1, v.shape[-1]]) + ) + +# collect state in a handy named tuple +TrainingState = collections.namedtuple( + "TrainingState", + ["trainable_variables", "non_trainable_variables", "optimizer_variables"], +) +device_train_state = TrainingState( + trainable_variables=trainable_variables, + non_trainable_variables=non_trainable_variables, + optimizer_variables=optimizer_variables, +) +# display data sharding +x, y = next(iter(train_data)) +sharded_x = jax.device_put(x.numpy(), data_sharding) +print("Data sharding") +jax.debug.visualize_array_sharding(jnp.reshape(sharded_x, [-1, 28 * 28])) + +# ------------------ Jax ---------------------- +# - Using Keras-provided stateless APIs +# - model.stateless_call +# - optimizer.stateless_apply +# These functions also work on other backends. + +# define loss +loss = keras.losses.SparseCategoricalCrossentropy() + + +# This is the loss function that will be differentiated. +# Keras provides a pure functional forward pass: model.stateless_call +def compute_loss(trainable_variables, non_trainable_variables, x, y): + y_pred, updated_non_trainable_variables = model.stateless_call( + trainable_variables, non_trainable_variables, x + ) + loss_value = loss(y, y_pred) + return loss_value, updated_non_trainable_variables + + +# function to compute gradients +compute_gradients = jax.value_and_grad(compute_loss, has_aux=True) + + +# Training step: Keras provides a pure functional optimizer.stateless_apply +@jax.jit +def train_step(train_state, x, y): + (loss_value, non_trainable_variables), grads = compute_gradients( + train_state.trainable_variables, + train_state.non_trainable_variables, + x, + y, + ) + + trainable_variables, optimizer_variables = optimizer.stateless_apply( + train_state.optimizer_variables, grads, train_state.trainable_variables + ) + + return loss_value, TrainingState( + trainable_variables, non_trainable_variables, optimizer_variables + ) + + +# training loop +EPOCHS = 5 +print("\nTraining:") +data_iter = iter(train_data) +for epoch in range(EPOCHS): + loss_value = None # default + for i in tqdm(range(STEPS_PER_EPOCH)): + x, y = next(data_iter) + sharded_x = jax.device_put(x.numpy(), data_sharding) + loss_value, device_train_state = train_step( + device_train_state, sharded_x, y.numpy() + ) + print("Epoch", epoch, "loss:", loss_value) + +# The output of the model is still sharded. Sharding follows the data. + +data, labels = next(iter(eval_data)) +sharded_data = jax.device_put(data.numpy(), data_sharding) + + +@jax.jit +def predict(data): + predictions, updated_non_trainable_variables = model.stateless_call( + device_train_state.trainable_variables, + device_train_state.non_trainable_variables, + data, + ) + return predictions + + +predictions = predict(sharded_data) +print("\nModel output sharding follows data sharding:") +jax.debug.visualize_array_sharding(predictions) + +# Post-processing model state update to write them back into the model +update = lambda variable, value: variable.assign(value) + +jax.tree_map( + update, model.trainable_variables, device_train_state.trainable_variables +) +jax.tree_map( + update, + model.non_trainable_variables, + device_train_state.non_trainable_variables, +) +jax.tree_map( + update, optimizer.variables, device_train_state.optimizer_variables +) + +# check that the model has the new state by running an eval +# known issue: the optimizer should not be required here +model.compile( + loss=keras.losses.SparseCategoricalCrossentropy(), + metrics=[keras.metrics.SparseCategoricalAccuracy()], +) +print("\nUpdating model and running an eval:") +loss, accuracy = model.evaluate(eval_data) +print("The model achieved an evaluation accuracy of:", accuracy) diff --git a/examples/demo_mnist_convnet.py b/examples/demo_mnist_convnet.py new file mode 100644 index 000000000000..ce08b2b92efb --- /dev/null +++ b/examples/demo_mnist_convnet.py @@ -0,0 +1,56 @@ +import numpy as np +import keras +from keras import layers +from keras.utils import to_categorical + +# Model / data parameters +num_classes = 10 +input_shape = (28, 28, 1) + +# Load the data and split it between train and test sets +(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data() + +# Scale images to the [0, 1] range +x_train = x_train.astype("float32") / 255 +x_test = x_test.astype("float32") / 255 +# Make sure images have shape (28, 28, 1) +x_train = np.expand_dims(x_train, -1) +x_test = np.expand_dims(x_test, -1) +print("x_train shape:", x_train.shape) +print(x_train.shape[0], "train samples") +print(x_test.shape[0], "test samples") + + +# convert class vectors to binary class matrices +y_train = to_categorical(y_train, num_classes) +y_test = to_categorical(y_test, num_classes) + +batch_size = 128 +epochs = 3 + +model = keras.Sequential( + [ + layers.Input(shape=input_shape), + layers.Conv2D(32, kernel_size=(3, 3), activation="relu"), + layers.MaxPooling2D(pool_size=(2, 2)), + layers.Conv2D(64, kernel_size=(3, 3), activation="relu"), + layers.MaxPooling2D(pool_size=(2, 2)), + layers.Flatten(), + layers.Dropout(0.5), + layers.Dense(num_classes, activation="softmax"), + ] +) + +model.summary() + +model.compile( + loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"] +) + +model.fit( + x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.1 +) + +score = model.evaluate(x_test, y_test, verbose=0) +print("Test loss:", score[0]) +print("Test accuracy:", score[1]) diff --git a/examples/demo_subclass.py b/examples/demo_subclass.py new file mode 100644 index 000000000000..ea22f063a5d8 --- /dev/null +++ b/examples/demo_subclass.py @@ -0,0 +1,42 @@ +import numpy as np + +from keras import Model +from keras import layers +from keras import losses +from keras import metrics +from keras import optimizers + + +class MyModel(Model): + def __init__(self, hidden_dim, output_dim): + super().__init__() + self.dense1 = layers.Dense(hidden_dim, activation="relu") + self.dense2 = layers.Dense(hidden_dim, activation="relu") + self.dense3 = layers.Dense(output_dim) + + def call(self, x): + x = self.dense1(x) + x = self.dense2(x) + return self.dense3(x) + + +model = MyModel(hidden_dim=256, output_dim=16) + +x = np.random.random((50000, 128)) +y = np.random.random((50000, 16)) +batch_size = 32 +epochs = 6 + +model.compile( + optimizer=optimizers.SGD(learning_rate=0.001), + loss=losses.MeanSquaredError(), + metrics=[metrics.MeanSquaredError()], +) +history = model.fit( + x, y, batch_size=batch_size, epochs=epochs, validation_split=0.2 +) + +print("History:") +print(history.history) + +model.summary() diff --git a/examples/demo_torch_multi_gpu.py b/examples/demo_torch_multi_gpu.py new file mode 100644 index 000000000000..8a42ab7d621e --- /dev/null +++ b/examples/demo_torch_multi_gpu.py @@ -0,0 +1,213 @@ +# flake8: noqa +import os + +# Set backend env to torch +os.environ["KERAS_BACKEND"] = "torch" + +import torch +import torch.nn as nn +import torch.optim as optim +from keras import layers +import keras +import numpy as np + +import torch.multiprocessing as mp +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data import DataLoader, Dataset +from torch.utils.data.distributed import DistributedSampler + +# Model / data parameters +num_classes = 10 +input_shape = (28, 28, 1) +learning_rate = 0.01 +batch_size = 128 +num_epochs = 1 + + +def get_data(): + # Load the data and split it between train and test sets + (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data() + + # Scale images to the [0, 1] range + x_train = x_train.astype("float32") / 255 + x_test = x_test.astype("float32") / 255 + # Make sure images have shape (28, 28, 1) + x_train = np.expand_dims(x_train, -1) + x_test = np.expand_dims(x_test, -1) + print("x_train shape:", x_train.shape) + print(x_train.shape[0], "train samples") + print(x_test.shape[0], "test samples") + + # Create a TensorDataset + dataset = torch.utils.data.TensorDataset( + torch.from_numpy(x_train), torch.from_numpy(y_train) + ) + return dataset + + +def get_model(): + # Create the Keras model + model = keras.Sequential( + [ + layers.Input(shape=(28, 28, 1)), + layers.Conv2D(32, kernel_size=(3, 3), activation="relu"), + layers.MaxPooling2D(pool_size=(2, 2)), + layers.Conv2D(64, kernel_size=(3, 3), activation="relu"), + layers.MaxPooling2D(pool_size=(2, 2)), + layers.Flatten(), + layers.Dropout(0.5), + layers.Dense(num_classes), + ] + ) + return model + + +class MyModel(nn.Module): + def __init__(self): + super().__init__() + self.model = keras.Sequential( + [ + layers.Input(shape=(28, 28, 1)), + layers.Conv2D(32, kernel_size=(3, 3), activation="relu"), + layers.MaxPooling2D(pool_size=(2, 2)), + layers.Conv2D(64, kernel_size=(3, 3), activation="relu"), + layers.MaxPooling2D(pool_size=(2, 2)), + layers.Flatten(), + layers.Dropout(0.5), + layers.Dense(num_classes), + ] + ) + + def forward(self, x): + return self.model(x) + + +def train(model, train_loader, num_epochs, optimizer, loss_fn): + for epoch in range(num_epochs): + running_loss = 0.0 + for batch_idx, (inputs, targets) in enumerate(train_loader): + inputs = inputs.cuda(non_blocking=True) + targets = targets.cuda(non_blocking=True) + + # Forward pass + outputs = model(inputs) + loss = loss_fn(outputs, targets) + + # Backward and optimize + optimizer.zero_grad() + loss.backward() + optimizer.step() + + running_loss += loss.item() + + # Print loss statistics + if (batch_idx + 1) % 10 == 0: + print( + f"Epoch [{epoch + 1}/{num_epochs}], " + f"Batch [{batch_idx + 1}/{len(train_loader)}], " + f"Loss: {running_loss / 10}" + ) + running_loss = 0.0 + + +def setup(current_gpu_index, num_gpu): + # Device setup + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "56492" + device = torch.device("cuda:{}".format(current_gpu_index)) + dist.init_process_group( + backend="nccl", + init_method="env://", + world_size=num_gpu, + rank=current_gpu_index, + ) + torch.cuda.set_device(device) + + +def prepare(dataset, current_gpu_index, num_gpu, batch_size): + sampler = DistributedSampler( + dataset, + num_replicas=num_gpu, + rank=current_gpu_index, + shuffle=False, + ) + + # Create a DataLoader + train_loader = DataLoader( + dataset, + sampler=sampler, + batch_size=batch_size, + shuffle=False, + ) + + return train_loader + + +def cleanup(): + # Cleanup + dist.destroy_process_group() + + +def main(current_gpu_index, num_gpu): + # setup the process groups + setup(current_gpu_index, num_gpu) + + ################################################################# + ######## Writing a torch training loop for a Keras model ######## + ################################################################# + + dataset = get_data() + model = get_model() + + # prepare the dataloader + dataloader = prepare(dataset, current_gpu_index, num_gpu, batch_size) + + # Instantiate the torch optimizer + optimizer = optim.Adam(model.parameters(), lr=learning_rate) + + # Instantiate the torch loss function + loss_fn = nn.CrossEntropyLoss() + + # Put model on device + model = model.to(current_gpu_index) + ddp_model = DDP( + model, device_ids=[current_gpu_index], output_device=current_gpu_index + ) + + train(ddp_model, dataloader, num_epochs, optimizer, loss_fn) + + ################################################################ + ######## Using a Keras model or layer in a torch Module ######## + ################################################################ + + torch_module = MyModel().to(current_gpu_index) + ddp_torch_module = DDP( + torch_module, + device_ids=[current_gpu_index], + output_device=current_gpu_index, + ) + + # Instantiate the torch optimizer + optimizer = optim.Adam(torch_module.parameters(), lr=learning_rate) + + # Instantiate the torch loss function + loss_fn = nn.CrossEntropyLoss() + + train(ddp_torch_module, dataloader, num_epochs, optimizer, loss_fn) + + cleanup() + + +if __name__ == "__main__": + # GPU parameters + num_gpu = torch.cuda.device_count() + + print(f"Running on {num_gpu} GPUs") + + torch.multiprocessing.spawn( + main, + args=(num_gpu,), + nprocs=num_gpu, + join=True, + ) diff --git a/examples/imdb_lstm.py b/examples/imdb_lstm.py deleted file mode 100644 index 5d9ed0b313a0..000000000000 --- a/examples/imdb_lstm.py +++ /dev/null @@ -1,62 +0,0 @@ -import numpy as np - -from keras.preprocessing import sequence -from keras.optimizers import SGD, RMSprop, Adagrad -from keras.utils import np_utils -from keras.models import Sequential -from keras.layers.core import Dense, Dropout, Activation, Embedding -from keras.layers.recurrent import LSTM, GRU -from keras.datasets import imdb - -''' - Train a LSTM on the IMDB sentiment classification task. - - The dataset is actually too small for LSTM to be of any advantage - compared to simpler, much faster methods such as TF-IDF+LogReg. - - Notes: - - - RNNs are tricky. Choice of batch size is important, - choice of loss and optimizer is critical, etc. - Most configurations won't converge. - - - LSTM loss decrease during training can be quite different - from what you see with CNNs/MLPs/etc. It's more or less a sigmoid - instead of an inverse exponential. -''' - -max_features=20000 -maxlen = 100 # cut texts after this number of words (among top max_features most common words) -batch_size = 16 - -print "Loading data..." -(X_train, y_train), (X_test, y_test) = imdb.load_data(nb_words=max_features, test_split=0.2) -print len(X_train), 'train sequences' -print len(X_test), 'test sequences' - -print "Pad sequences (samples x time)" -X_train = sequence.pad_sequences(X_train, maxlen=maxlen) -X_test = sequence.pad_sequences(X_test, maxlen=maxlen) -print 'X_train shape:', X_train.shape -print 'X_test shape:', X_test.shape - -print 'Build model...' -model = Sequential() -model.add(Embedding(max_features, 256)) -model.add(LSTM(256, 128)) # try using a GRU instead, for fun -model.add(Dropout(0.5)) -model.add(Dense(128, 1)) -model.add(Activation('sigmoid')) - -# try using different optimizers and different optimizer configs -model.compile(loss='binary_crossentropy', optimizer='rmsprop') - -print "Train..." -model.fit(X_train, y_train, batch_size=batch_size, nb_epoch=10, verbose=1) -score = model.evaluate(X_test, y_test, batch_size=batch_size) -print 'Test score:', score - -classes = model.predict_classes(X_test, batch_size=batch_size) -acc = np_utils.accuracy(classes, y_test) -print 'Test accuracy:', acc - diff --git a/examples/reuters_mlp.py b/examples/reuters_mlp.py deleted file mode 100644 index d081ca06cfa1..000000000000 --- a/examples/reuters_mlp.py +++ /dev/null @@ -1,63 +0,0 @@ -import numpy as np - -from keras.datasets import reuters -from keras.models import Sequential -from keras.layers.core import Dense, Dropout, Activation -from keras.layers.normalization import BatchNormalization -from keras.utils import np_utils -from keras.preprocessing.text import Tokenizer - -''' - Train and evaluate a simple MLP on the Reuters newswire topic classification task. - - GPU run command: - THEANO_FLAGS=mode=FAST_RUN,device=gpu,floatX=float32 python examples/reuters_mlp.py - - CPU run command: - python examples/reuters_mlp.py -''' - -max_words = 10000 -batch_size = 16 - -print "Loading data..." -(X_train, y_train), (X_test, y_test) = reuters.load_data(nb_words=max_words, test_split=0.2) -print len(X_train), 'train sequences' -print len(X_test), 'test sequences' - -nb_classes = np.max(y_train)+1 -print nb_classes, 'classes' - -print "Vectorizing sequence data..." -tokenizer = Tokenizer(nb_words=max_words) -X_train = tokenizer.sequences_to_matrix(X_train, mode="binary") -X_test = tokenizer.sequences_to_matrix(X_test, mode="binary") -print 'X_train shape:', X_train.shape -print 'X_test shape:', X_test.shape - -print "Convert class vector to binary class matrix (for use with categorical_crossentropy)" -Y_train = np_utils.to_categorical(y_train, nb_classes) -Y_test = np_utils.to_categorical(y_test, nb_classes) -print 'Y_train shape:', Y_train.shape -print 'Y_test shape:', Y_test.shape - -print "Building model..." -model = Sequential() -model.add(Dense(max_words, 256, init='normal')) -model.add(Activation('relu')) -#model.add(BatchNormalization(input_shape=(256,))) # try without batch normalization (doesn't work as well!) -model.add(Dropout(0.5)) -model.add(Dense(256, nb_classes, init='normal')) -model.add(Activation('softmax')) - -model.compile(loss='categorical_crossentropy', optimizer='adadelta') - -print "Training..." -model.fit(X_train, Y_train, nb_epoch=5, batch_size=batch_size) -score = model.evaluate(X_test, Y_test, batch_size=batch_size) -print 'Test score:', score - -classes = model.predict_classes(X_test, batch_size=batch_size) -acc = np_utils.accuracy(classes, y_test) -print 'Test accuracy:', acc - diff --git a/guides/custom_train_step_in_jax.py b/guides/custom_train_step_in_jax.py new file mode 100644 index 000000000000..2085b2028680 --- /dev/null +++ b/guides/custom_train_step_in_jax.py @@ -0,0 +1,357 @@ +""" +Title: Customizing what happens in `fit()` with JAX +Author: [fchollet](https://twitter.com/fchollet) +Date created: 2023/06/27 +Last modified: 2023/06/27 +Description: Overriding the training step of the Model class with JAX. +Accelerator: GPU +""" + +""" +## Introduction + +When you're doing supervised learning, you can use `fit()` and everything works +smoothly. + +When you need to take control of every little detail, you can write your own training +loop entirely from scratch. + +But what if you need a custom training algorithm, but you still want to benefit from +the convenient features of `fit()`, such as callbacks, built-in distribution support, +or step fusing? + +A core principle of Keras is **progressive disclosure of complexity**. You should +always be able to get into lower-level workflows in a gradual way. You shouldn't fall +off a cliff if the high-level functionality doesn't exactly match your use case. You +should be able to gain more control over the small details while retaining a +commensurate amount of high-level convenience. + +When you need to customize what `fit()` does, you should **override the training step +function of the `Model` class**. This is the function that is called by `fit()` for +every batch of data. You will then be able to call `fit()` as usual -- and it will be +running your own learning algorithm. + +Note that this pattern does not prevent you from building models with the Functional +API. You can do this whether you're building `Sequential` models, Functional API +models, or subclassed models. + +Let's see how that works. +""" + +""" +## Setup +""" + +import os + +# This guide can only be run with the JAX backend. +os.environ["KERAS_BACKEND"] = "jax" + +import jax +import keras +import numpy as np + +""" +## A first simple example + +Let's start from a simple example: + +- We create a new class that subclasses `keras.Model`. +- We implement a fully-stateless `compute_loss_and_updates()` method +to compute the loss as well as the updated values for the non-trainable +variables of the model. Internally, it calls `stateless_call()` and +the built-in `compute_loss()`. +- We implement a fully-stateless `train_step()` method to compute current +metric values (including the loss) as well as updated values for the +trainable variables, the optimizer variables, and the metric variables. + +Note that you can also take into account the `sample_weight` argument by: + +- Unpacking the data as `x, y, sample_weight = data` +- Passing `sample_weight` to `compute_loss()` +- Passing `sample_weight` alongside `y` and `y_pred` +to metrics in `stateless_update_state()` +""" + + +class CustomModel(keras.Model): + def compute_loss_and_updates( + self, + trainable_variables, + non_trainable_variables, + x, + y, + training=False, + ): + y_pred, non_trainable_variables = self.stateless_call( + trainable_variables, + non_trainable_variables, + x, + training=training, + ) + loss = self.compute_loss(x, y, y_pred) + return loss, (y_pred, non_trainable_variables) + + def train_step(self, state, data): + ( + trainable_variables, + non_trainable_variables, + optimizer_variables, + metrics_variables, + ) = state + x, y = data + + # Get the gradient function. + grad_fn = jax.value_and_grad( + self.compute_loss_and_updates, has_aux=True + ) + + # Compute the gradients. + (loss, (y_pred, non_trainable_variables)), grads = grad_fn( + trainable_variables, + non_trainable_variables, + x, + y, + training=True, + ) + + # Update trainable variables and optimizer variables. + ( + trainable_variables, + optimizer_variables, + ) = self.optimizer.stateless_apply( + optimizer_variables, grads, trainable_variables + ) + + # Update metrics. + new_metrics_vars, logs = [], [] + for metric in self.metrics: + this_metric_vars = metrics_variables[ + len(new_metrics_vars) : len(new_metrics_vars) + + len(metric.variables) + ] + if metric.name == "loss": + this_metric_vars = metric.stateless_update_state( + this_metric_vars, loss + ) + else: + this_metric_vars = metric.stateless_update_state( + this_metric_vars, y, y_pred + ) + logs = metric.stateless_result(this_metric_vars) + new_metrics_vars += this_metric_vars + + # Return metric logs and updated state variables. + state = ( + trainable_variables, + non_trainable_variables, + optimizer_variables, + new_metrics_vars, + ) + return logs, state + + +""" +Let's try this out: +""" + +# Construct and compile an instance of CustomModel +inputs = keras.Input(shape=(32,)) +outputs = keras.layers.Dense(1)(inputs) +model = CustomModel(inputs, outputs) +model.compile(optimizer="adam", loss="mse", metrics=["mae"]) + +# Just use `fit` as usual +x = np.random.random((1000, 32)) +y = np.random.random((1000, 1)) +model.fit(x, y, epochs=3) + + +""" +## Going lower-level + +Naturally, you could just skip passing a loss function in `compile()`, and instead do +everything *manually* in `train_step`. Likewise for metrics. + +Here's a lower-level example, that only uses `compile()` to configure the optimizer: +""" + + +class CustomModel(keras.Model): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.loss_tracker = keras.metrics.Mean(name="loss") + self.mae_metric = keras.metrics.MeanAbsoluteError(name="mae") + self.loss_fn = keras.losses.MeanSquaredError() + + def compute_loss_and_updates( + self, + trainable_variables, + non_trainable_variables, + x, + y, + training=False, + ): + y_pred, non_trainable_variables = self.stateless_call( + trainable_variables, + non_trainable_variables, + x, + training=training, + ) + loss = self.loss_fn(y, y_pred) + return loss, (y_pred, non_trainable_variables) + + def train_step(self, state, data): + ( + trainable_variables, + non_trainable_variables, + optimizer_variables, + metrics_variables, + ) = state + x, y = data + + # Get the gradient function. + grad_fn = jax.value_and_grad( + self.compute_loss_and_updates, has_aux=True + ) + + # Compute the gradients. + (loss, (y_pred, non_trainable_variables)), grads = grad_fn( + trainable_variables, + non_trainable_variables, + x, + y, + training=True, + ) + + # Update trainable variables and optimizer variables. + ( + trainable_variables, + optimizer_variables, + ) = self.optimizer.stateless_apply( + optimizer_variables, grads, trainable_variables + ) + + # Update metrics. + loss_tracker_vars = metrics_variables[ + : len(self.loss_tracker.variables) + ] + mae_metric_vars = metrics_variables[len(self.loss_tracker.variables) :] + + loss_tracker_vars = self.loss_tracker.stateless_update_state( + loss_tracker_vars, loss + ) + mae_metric_vars = self.mae_metric.stateless_update_state( + mae_metric_vars, y, y_pred + ) + + logs = {} + logs[self.loss_tracker.name] = self.loss_tracker.stateless_result( + loss_tracker_vars + ) + logs[self.mae_metric.name] = self.mae_metric.stateless_result( + mae_metric_vars + ) + + new_metrics_vars = loss_tracker_vars + mae_metric_vars + + # Return metric logs and updated state variables. + state = ( + trainable_variables, + non_trainable_variables, + optimizer_variables, + new_metrics_vars, + ) + return logs, state + + @property + def metrics(self): + # We list our `Metric` objects here so that `reset_states()` can be + # called automatically at the start of each epoch + # or at the start of `evaluate()`. + return [self.loss_tracker, self.mae_metric] + + +# Construct an instance of CustomModel +inputs = keras.Input(shape=(32,)) +outputs = keras.layers.Dense(1)(inputs) +model = CustomModel(inputs, outputs) + +# We don't pass a loss or metrics here. +model.compile(optimizer="adam") + +# Just use `fit` as usual -- you can use callbacks, etc. +x = np.random.random((1000, 32)) +y = np.random.random((1000, 1)) +model.fit(x, y, epochs=5) + + +""" +## Providing your own evaluation step + +What if you want to do the same for calls to `model.evaluate()`? Then you would +override `test_step` in exactly the same way. Here's what it looks like: +""" + + +class CustomModel(keras.Model): + def test_step(self, state, data): + # Unpack the data. + x, y = data + ( + trainable_variables, + non_trainable_variables, + metrics_variables, + ) = state + + # Compute predictions and loss. + y_pred, non_trainable_variables = self.stateless_call( + trainable_variables, + non_trainable_variables, + x, + training=False, + ) + loss = self.compute_loss(x, y, y_pred) + + # Update metrics. + new_metrics_vars, logs = [], [] + for metric in self.metrics: + this_metric_vars = metrics_variables[ + len(new_metrics_vars) : len(new_metrics_vars) + + len(metric.variables) + ] + if metric.name == "loss": + this_metric_vars = metric.stateless_update_state( + this_metric_vars, loss + ) + else: + this_metric_vars = metric.stateless_update_state( + this_metric_vars, y, y_pred + ) + logs = metric.stateless_result(this_metric_vars) + new_metrics_vars += this_metric_vars + + # Return metric logs and updated state variables. + state = ( + trainable_variables, + non_trainable_variables, + new_metrics_vars, + ) + return logs, state + + +# Construct an instance of CustomModel +inputs = keras.Input(shape=(32,)) +outputs = keras.layers.Dense(1)(inputs) +model = CustomModel(inputs, outputs) +model.compile(loss="mse", metrics=["mae"]) + +# Evaluate with our custom test_step +x = np.random.random((1000, 32)) +y = np.random.random((1000, 1)) +model.evaluate(x, y) + + +""" +That's it! +""" diff --git a/guides/custom_train_step_in_tensorflow.py b/guides/custom_train_step_in_tensorflow.py new file mode 100644 index 000000000000..6959da6989ab --- /dev/null +++ b/guides/custom_train_step_in_tensorflow.py @@ -0,0 +1,463 @@ +""" +Title: Customizing what happens in `fit()` with TensorFlow +Author: [fchollet](https://twitter.com/fchollet) +Date created: 2020/04/15 +Last modified: 2023/06/27 +Description: Overriding the training step of the Model class with TensorFlow. +Accelerator: GPU +""" + +""" +## Introduction + +When you're doing supervised learning, you can use `fit()` and everything works +smoothly. + +When you need to take control of every little detail, you can write your own training +loop entirely from scratch. + +But what if you need a custom training algorithm, but you still want to benefit from +the convenient features of `fit()`, such as callbacks, built-in distribution support, +or step fusing? + +A core principle of Keras is **progressive disclosure of complexity**. You should +always be able to get into lower-level workflows in a gradual way. You shouldn't fall +off a cliff if the high-level functionality doesn't exactly match your use case. You +should be able to gain more control over the small details while retaining a +commensurate amount of high-level convenience. + +When you need to customize what `fit()` does, you should **override the training step +function of the `Model` class**. This is the function that is called by `fit()` for +every batch of data. You will then be able to call `fit()` as usual -- and it will be +running your own learning algorithm. + +Note that this pattern does not prevent you from building models with the Functional +API. You can do this whether you're building `Sequential` models, Functional API +models, or subclassed models. + +Let's see how that works. +""" + +""" +## Setup +""" + +import os + +# This guide can only be run with the TF backend. +os.environ["KERAS_BACKEND"] = "tensorflow" + +import tensorflow as tf +import keras +from keras import layers +import numpy as np + +""" +## A first simple example + +Let's start from a simple example: + +- We create a new class that subclasses `keras.Model`. +- We just override the method `train_step(self, data)`. +- We return a dictionary mapping metric names (including the loss) to their current +value. + +The input argument `data` is what gets passed to fit as training data: + +- If you pass NumPy arrays, by calling `fit(x, y, ...)`, then `data` will be the tuple +`(x, y)` +- If you pass a `tf.data.Dataset`, by calling `fit(dataset, ...)`, then `data` will be +what gets yielded by `dataset` at each batch. + +In the body of the `train_step()` method, we implement a regular training update, +similar to what you are already familiar with. Importantly, **we compute the loss via +`self.compute_loss()`**, which wraps the loss(es) function(s) that were passed to +`compile()`. + +Similarly, we call `metric.update_state(y, y_pred)` on metrics from `self.metrics`, +to update the state of the metrics that were passed in `compile()`, +and we query results from `self.metrics` at the end to retrieve their current value. +""" + + +class CustomModel(keras.Model): + def train_step(self, data): + # Unpack the data. Its structure depends on your model and + # on what you pass to `fit()`. + x, y = data + + with tf.GradientTape() as tape: + y_pred = self(x, training=True) # Forward pass + # Compute the loss value + # (the loss function is configured in `compile()`) + loss = self.compute_loss(y=y, y_pred=y_pred) + + # Compute gradients + trainable_vars = self.trainable_variables + gradients = tape.gradient(loss, trainable_vars) + + # Update weights + self.optimizer.apply(gradients, trainable_vars) + + # Update metrics (includes the metric that tracks the loss) + for metric in self.metrics: + if metric.name == "loss": + metric.update_state(loss) + else: + metric.update_state(y, y_pred) + + # Return a dict mapping metric names to current value + return {m.name: m.result() for m in self.metrics} + + +""" +Let's try this out: +""" + +# Construct and compile an instance of CustomModel +inputs = keras.Input(shape=(32,)) +outputs = keras.layers.Dense(1)(inputs) +model = CustomModel(inputs, outputs) +model.compile(optimizer="adam", loss="mse", metrics=["mae"]) + +# Just use `fit` as usual +x = np.random.random((1000, 32)) +y = np.random.random((1000, 1)) +model.fit(x, y, epochs=3) + +""" +## Going lower-level + +Naturally, you could just skip passing a loss function in `compile()`, and instead do +everything *manually* in `train_step`. Likewise for metrics. + +Here's a lower-level example, that only uses `compile()` to configure the optimizer: + +- We start by creating `Metric` instances to track our loss and a MAE score (in `__init__()`). +- We implement a custom `train_step()` that updates the state of these metrics +(by calling `update_state()` on them), then query them (via `result()`) to return their current average value, +to be displayed by the progress bar and to be pass to any callback. +- Note that we would need to call `reset_states()` on our metrics between each epoch! Otherwise +calling `result()` would return an average since the start of training, whereas we usually work +with per-epoch averages. Thankfully, the framework can do that for us: just list any metric +you want to reset in the `metrics` property of the model. The model will call `reset_states()` +on any object listed here at the beginning of each `fit()` epoch or at the beginning of a call to +`evaluate()`. +""" + + +class CustomModel(keras.Model): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.loss_tracker = keras.metrics.Mean(name="loss") + self.mae_metric = keras.metrics.MeanAbsoluteError(name="mae") + self.loss_fn = keras.losses.MeanSquaredError() + + def train_step(self, data): + x, y = data + + with tf.GradientTape() as tape: + y_pred = self(x, training=True) # Forward pass + # Compute our own loss + loss = self.loss_fn(y, y_pred) + + # Compute gradients + trainable_vars = self.trainable_variables + gradients = tape.gradient(loss, trainable_vars) + + # Update weights + self.optimizer.apply(gradients, trainable_vars) + + # Compute our own metrics + self.loss_tracker.update_state(loss) + self.mae_metric.update_state(y, y_pred) + return { + "loss": self.loss_tracker.result(), + "mae": self.mae_metric.result(), + } + + @property + def metrics(self): + # We list our `Metric` objects here so that `reset_states()` can be + # called automatically at the start of each epoch + # or at the start of `evaluate()`. + return [self.loss_tracker, self.mae_metric] + + +# Construct an instance of CustomModel +inputs = keras.Input(shape=(32,)) +outputs = keras.layers.Dense(1)(inputs) +model = CustomModel(inputs, outputs) + +# We don't pass a loss or metrics here. +model.compile(optimizer="adam") + +# Just use `fit` as usual -- you can use callbacks, etc. +x = np.random.random((1000, 32)) +y = np.random.random((1000, 1)) +model.fit(x, y, epochs=5) + + +""" +## Supporting `sample_weight` & `class_weight` + +You may have noticed that our first basic example didn't make any mention of sample +weighting. If you want to support the `fit()` arguments `sample_weight` and +`class_weight`, you'd simply do the following: + +- Unpack `sample_weight` from the `data` argument +- Pass it to `compute_loss` & `update_state` (of course, you could also just apply +it manually if you don't rely on `compile()` for losses & metrics) +- That's it. +""" + + +class CustomModel(keras.Model): + def train_step(self, data): + # Unpack the data. Its structure depends on your model and + # on what you pass to `fit()`. + if len(data) == 3: + x, y, sample_weight = data + else: + sample_weight = None + x, y = data + + with tf.GradientTape() as tape: + y_pred = self(x, training=True) # Forward pass + # Compute the loss value. + # The loss function is configured in `compile()`. + loss = self.compute_loss( + y=y, + y_pred=y_pred, + sample_weight=sample_weight, + ) + + # Compute gradients + trainable_vars = self.trainable_variables + gradients = tape.gradient(loss, trainable_vars) + + # Update weights + self.optimizer.apply(gradients, trainable_vars) + + # Update the metrics. + # Metrics are configured in `compile()`. + for metric in self.metrics: + if metric.name == "loss": + metric.update_state(loss) + else: + metric.update_state(y, y_pred, sample_weight=sample_weight) + + # Return a dict mapping metric names to current value. + # Note that it will include the loss (tracked in self.metrics). + return {m.name: m.result() for m in self.metrics} + + +# Construct and compile an instance of CustomModel +inputs = keras.Input(shape=(32,)) +outputs = keras.layers.Dense(1)(inputs) +model = CustomModel(inputs, outputs) +model.compile(optimizer="adam", loss="mse", metrics=["mae"]) + +# You can now use sample_weight argument +x = np.random.random((1000, 32)) +y = np.random.random((1000, 1)) +sw = np.random.random((1000, 1)) +model.fit(x, y, sample_weight=sw, epochs=3) + +""" +## Providing your own evaluation step + +What if you want to do the same for calls to `model.evaluate()`? Then you would +override `test_step` in exactly the same way. Here's what it looks like: +""" + + +class CustomModel(keras.Model): + def test_step(self, data): + # Unpack the data + x, y = data + # Compute predictions + y_pred = self(x, training=False) + # Updates the metrics tracking the loss + loss = self.compute_loss(y=y, y_pred=y_pred) + # Update the metrics. + for metric in self.metrics: + if metric.name == "loss": + metric.update_state(loss) + else: + metric.update_state(y, y_pred) + # Return a dict mapping metric names to current value. + # Note that it will include the loss (tracked in self.metrics). + return {m.name: m.result() for m in self.metrics} + + +# Construct an instance of CustomModel +inputs = keras.Input(shape=(32,)) +outputs = keras.layers.Dense(1)(inputs) +model = CustomModel(inputs, outputs) +model.compile(loss="mse", metrics=["mae"]) + +# Evaluate with our custom test_step +x = np.random.random((1000, 32)) +y = np.random.random((1000, 1)) +model.evaluate(x, y) + +""" +## Wrapping up: an end-to-end GAN example + +Let's walk through an end-to-end example that leverages everything you just learned. + +Let's consider: + +- A generator network meant to generate 28x28x1 images. +- A discriminator network meant to classify 28x28x1 images into two classes ("fake" and +"real"). +- One optimizer for each. +- A loss function to train the discriminator. +""" + +# Create the discriminator +discriminator = keras.Sequential( + [ + keras.Input(shape=(28, 28, 1)), + layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"), + layers.LeakyReLU(negative_slope=0.2), + layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"), + layers.LeakyReLU(negative_slope=0.2), + layers.GlobalMaxPooling2D(), + layers.Dense(1), + ], + name="discriminator", +) + +# Create the generator +latent_dim = 128 +generator = keras.Sequential( + [ + keras.Input(shape=(latent_dim,)), + # We want to generate 128 coefficients to reshape into a 7x7x128 map + layers.Dense(7 * 7 * 128), + layers.LeakyReLU(negative_slope=0.2), + layers.Reshape((7, 7, 128)), + layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"), + layers.LeakyReLU(negative_slope=0.2), + layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"), + layers.LeakyReLU(negative_slope=0.2), + layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"), + ], + name="generator", +) + +""" +Here's a feature-complete GAN class, overriding `compile()` to use its own signature, +and implementing the entire GAN algorithm in 17 lines in `train_step`: +""" + + +class GAN(keras.Model): + def __init__(self, discriminator, generator, latent_dim): + super().__init__() + self.discriminator = discriminator + self.generator = generator + self.latent_dim = latent_dim + self.d_loss_tracker = keras.metrics.Mean(name="d_loss") + self.g_loss_tracker = keras.metrics.Mean(name="g_loss") + self.seed_generator = keras.random.SeedGenerator(1337) + + @property + def metrics(self): + return [self.d_loss_tracker, self.g_loss_tracker] + + def compile(self, d_optimizer, g_optimizer, loss_fn): + super().compile() + self.d_optimizer = d_optimizer + self.g_optimizer = g_optimizer + self.loss_fn = loss_fn + + def train_step(self, real_images): + if isinstance(real_images, tuple): + real_images = real_images[0] + # Sample random points in the latent space + batch_size = tf.shape(real_images)[0] + random_latent_vectors = keras.random.normal( + shape=(batch_size, self.latent_dim), seed=self.seed_generator + ) + + # Decode them to fake images + generated_images = self.generator(random_latent_vectors) + + # Combine them with real images + combined_images = tf.concat([generated_images, real_images], axis=0) + + # Assemble labels discriminating real from fake images + labels = tf.concat( + [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0 + ) + # Add random noise to the labels - important trick! + labels += 0.05 * keras.random.uniform( + tf.shape(labels), seed=self.seed_generator + ) + + # Train the discriminator + with tf.GradientTape() as tape: + predictions = self.discriminator(combined_images) + d_loss = self.loss_fn(labels, predictions) + grads = tape.gradient(d_loss, self.discriminator.trainable_weights) + self.d_optimizer.apply(grads, self.discriminator.trainable_weights) + + # Sample random points in the latent space + random_latent_vectors = keras.random.normal( + shape=(batch_size, self.latent_dim), seed=self.seed_generator + ) + + # Assemble labels that say "all real images" + misleading_labels = tf.zeros((batch_size, 1)) + + # Train the generator (note that we should *not* update the weights + # of the discriminator)! + with tf.GradientTape() as tape: + predictions = self.discriminator( + self.generator(random_latent_vectors) + ) + g_loss = self.loss_fn(misleading_labels, predictions) + grads = tape.gradient(g_loss, self.generator.trainable_weights) + self.g_optimizer.apply(grads, self.generator.trainable_weights) + + # Update metrics and return their value. + self.d_loss_tracker.update_state(d_loss) + self.g_loss_tracker.update_state(g_loss) + return { + "d_loss": self.d_loss_tracker.result(), + "g_loss": self.g_loss_tracker.result(), + } + + +""" +Let's test-drive it: +""" + +# Prepare the dataset. We use both the training & test MNIST digits. +batch_size = 64 +(x_train, _), (x_test, _) = keras.datasets.mnist.load_data() +all_digits = np.concatenate([x_train, x_test]) +all_digits = all_digits.astype("float32") / 255.0 +all_digits = np.reshape(all_digits, (-1, 28, 28, 1)) +dataset = tf.data.Dataset.from_tensor_slices(all_digits) +dataset = dataset.shuffle(buffer_size=1024).batch(batch_size) + +gan = GAN( + discriminator=discriminator, generator=generator, latent_dim=latent_dim +) +gan.compile( + d_optimizer=keras.optimizers.Adam(learning_rate=0.0003), + g_optimizer=keras.optimizers.Adam(learning_rate=0.0003), + loss_fn=keras.losses.BinaryCrossentropy(from_logits=True), +) + +# To limit the execution time, we only train on 100 batches. You can train on +# the entire dataset. You will need about 20 epochs to get nice results. +gan.fit(dataset.take(100), epochs=1) + +""" +The ideas behind deep learning are simple, so why should their implementation be painful? +""" diff --git a/guides/custom_train_step_in_torch.py b/guides/custom_train_step_in_torch.py new file mode 100644 index 000000000000..665190fe8fe8 --- /dev/null +++ b/guides/custom_train_step_in_torch.py @@ -0,0 +1,492 @@ +""" +Title: Customizing what happens in `fit()` with PyTorch +Author: [fchollet](https://twitter.com/fchollet) +Date created: 2023/06/27 +Last modified: 2023/06/27 +Description: Overriding the training step of the Model class with PyTorch. +Accelerator: GPU +""" + +""" +## Introduction + +When you're doing supervised learning, you can use `fit()` and everything works +smoothly. + +When you need to take control of every little detail, you can write your own training +loop entirely from scratch. + +But what if you need a custom training algorithm, but you still want to benefit from +the convenient features of `fit()`, such as callbacks, built-in distribution support, +or step fusing? + +A core principle of Keras is **progressive disclosure of complexity**. You should +always be able to get into lower-level workflows in a gradual way. You shouldn't fall +off a cliff if the high-level functionality doesn't exactly match your use case. You +should be able to gain more control over the small details while retaining a +commensurate amount of high-level convenience. + +When you need to customize what `fit()` does, you should **override the training step +function of the `Model` class**. This is the function that is called by `fit()` for +every batch of data. You will then be able to call `fit()` as usual -- and it will be +running your own learning algorithm. + +Note that this pattern does not prevent you from building models with the Functional +API. You can do this whether you're building `Sequential` models, Functional API +models, or subclassed models. + +Let's see how that works. +""" + +""" +## Setup +""" + +import os + +# This guide can only be run with the torch backend. +os.environ["KERAS_BACKEND"] = "torch" + +import torch +import keras +from keras import layers +import numpy as np + +""" +## A first simple example + +Let's start from a simple example: + +- We create a new class that subclasses `keras.Model`. +- We just override the method `train_step(self, data)`. +- We return a dictionary mapping metric names (including the loss) to their current +value. + +The input argument `data` is what gets passed to fit as training data: + +- If you pass NumPy arrays, by calling `fit(x, y, ...)`, then `data` will be the tuple +`(x, y)` +- If you pass a `torch.utils.data.DataLoader` or a `tf.data.Dataset`, +by calling `fit(dataset, ...)`, then `data` will be what gets yielded +by `dataset` at each batch. + +In the body of the `train_step()` method, we implement a regular training update, +similar to what you are already familiar with. Importantly, **we compute the loss via +`self.compute_loss()`**, which wraps the loss(es) function(s) that were passed to +`compile()`. + +Similarly, we call `metric.update_state(y, y_pred)` on metrics from `self.metrics`, +to update the state of the metrics that were passed in `compile()`, +and we query results from `self.metrics` at the end to retrieve their current value. +""" + + +class CustomModel(keras.Model): + def train_step(self, data): + # Unpack the data. Its structure depends on your model and + # on what you pass to `fit()`. + x, y = data + + # Call torch.nn.Module.zero_grad() to clear the leftover gradients + # for the weights from the previous train step. + self.zero_grad() + + # Compute loss + y_pred = self(x, training=True) # Forward pass + loss = self.compute_loss(y=y, y_pred=y_pred) + + # Call torch.Tensor.backward() on the loss to compute gradients + # for the weights. + loss.backward() + + trainable_weights = [v for v in self.trainable_weights] + gradients = [v.value.grad for v in trainable_weights] + + # Update weights + with torch.no_grad(): + self.optimizer.apply(gradients, trainable_weights) + + # Update metrics (includes the metric that tracks the loss) + for metric in self.metrics: + if metric.name == "loss": + metric.update_state(loss) + else: + metric.update_state(y, y_pred) + + # Return a dict mapping metric names to current value + # Note that it will include the loss (tracked in self.metrics). + return {m.name: m.result() for m in self.metrics} + + +""" +Let's try this out: +""" + +# Construct and compile an instance of CustomModel +inputs = keras.Input(shape=(32,)) +outputs = keras.layers.Dense(1)(inputs) +model = CustomModel(inputs, outputs) +model.compile(optimizer="adam", loss="mse", metrics=["mae"]) + +# Just use `fit` as usual +x = np.random.random((1000, 32)) +y = np.random.random((1000, 1)) +model.fit(x, y, epochs=3) + +""" +## Going lower-level + +Naturally, you could just skip passing a loss function in `compile()`, and instead do +everything *manually* in `train_step`. Likewise for metrics. + +Here's a lower-level example, that only uses `compile()` to configure the optimizer: + +- We start by creating `Metric` instances to track our loss and a MAE score (in `__init__()`). +- We implement a custom `train_step()` that updates the state of these metrics +(by calling `update_state()` on them), then query them (via `result()`) to return their current average value, +to be displayed by the progress bar and to be pass to any callback. +- Note that we would need to call `reset_states()` on our metrics between each epoch! Otherwise +calling `result()` would return an average since the start of training, whereas we usually work +with per-epoch averages. Thankfully, the framework can do that for us: just list any metric +you want to reset in the `metrics` property of the model. The model will call `reset_states()` +on any object listed here at the beginning of each `fit()` epoch or at the beginning of a call to +`evaluate()`. +""" + + +class CustomModel(keras.Model): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.loss_tracker = keras.metrics.Mean(name="loss") + self.mae_metric = keras.metrics.MeanAbsoluteError(name="mae") + self.loss_fn = keras.losses.MeanSquaredError() + + def train_step(self, data): + x, y = data + + # Call torch.nn.Module.zero_grad() to clear the leftover gradients + # for the weights from the previous train step. + self.zero_grad() + + # Compute loss + y_pred = self(x, training=True) # Forward pass + loss = self.loss_fn(y, y_pred) + + # Call torch.Tensor.backward() on the loss to compute gradients + # for the weights. + loss.backward() + + trainable_weights = [v for v in self.trainable_weights] + gradients = [v.value.grad for v in trainable_weights] + + # Update weights + with torch.no_grad(): + self.optimizer.apply(gradients, trainable_weights) + + # Compute our own metrics + self.loss_tracker.update_state(loss) + self.mae_metric.update_state(y, y_pred) + return { + "loss": self.loss_tracker.result(), + "mae": self.mae_metric.result(), + } + + @property + def metrics(self): + # We list our `Metric` objects here so that `reset_states()` can be + # called automatically at the start of each epoch + # or at the start of `evaluate()`. + return [self.loss_tracker, self.mae_metric] + + +# Construct an instance of CustomModel +inputs = keras.Input(shape=(32,)) +outputs = keras.layers.Dense(1)(inputs) +model = CustomModel(inputs, outputs) + +# We don't pass a loss or metrics here. +model.compile(optimizer="adam") + +# Just use `fit` as usual -- you can use callbacks, etc. +x = np.random.random((1000, 32)) +y = np.random.random((1000, 1)) +model.fit(x, y, epochs=5) + + +""" +## Supporting `sample_weight` & `class_weight` + +You may have noticed that our first basic example didn't make any mention of sample +weighting. If you want to support the `fit()` arguments `sample_weight` and +`class_weight`, you'd simply do the following: + +- Unpack `sample_weight` from the `data` argument +- Pass it to `compute_loss` & `update_state` (of course, you could also just apply +it manually if you don't rely on `compile()` for losses & metrics) +- That's it. +""" + + +class CustomModel(keras.Model): + def train_step(self, data): + # Unpack the data. Its structure depends on your model and + # on what you pass to `fit()`. + if len(data) == 3: + x, y, sample_weight = data + else: + sample_weight = None + x, y = data + + # Call torch.nn.Module.zero_grad() to clear the leftover gradients + # for the weights from the previous train step. + self.zero_grad() + + # Compute loss + y_pred = self(x, training=True) # Forward pass + loss = self.compute_loss( + y=y, + y_pred=y_pred, + sample_weight=sample_weight, + ) + + # Call torch.Tensor.backward() on the loss to compute gradients + # for the weights. + loss.backward() + + trainable_weights = [v for v in self.trainable_weights] + gradients = [v.value.grad for v in trainable_weights] + + # Update weights + with torch.no_grad(): + self.optimizer.apply(gradients, trainable_weights) + + # Update metrics (includes the metric that tracks the loss) + for metric in self.metrics: + if metric.name == "loss": + metric.update_state(loss) + else: + metric.update_state(y, y_pred, sample_weight=sample_weight) + + # Return a dict mapping metric names to current value + # Note that it will include the loss (tracked in self.metrics). + return {m.name: m.result() for m in self.metrics} + + +# Construct and compile an instance of CustomModel +inputs = keras.Input(shape=(32,)) +outputs = keras.layers.Dense(1)(inputs) +model = CustomModel(inputs, outputs) +model.compile(optimizer="adam", loss="mse", metrics=["mae"]) + +# You can now use sample_weight argument +x = np.random.random((1000, 32)) +y = np.random.random((1000, 1)) +sw = np.random.random((1000, 1)) +model.fit(x, y, sample_weight=sw, epochs=3) + +""" +## Providing your own evaluation step + +What if you want to do the same for calls to `model.evaluate()`? Then you would +override `test_step` in exactly the same way. Here's what it looks like: +""" + + +class CustomModel(keras.Model): + def test_step(self, data): + # Unpack the data + x, y = data + # Compute predictions + y_pred = self(x, training=False) + # Updates the metrics tracking the loss + loss = self.compute_loss(y=y, y_pred=y_pred) + # Update the metrics. + for metric in self.metrics: + if metric.name == "loss": + metric.update_state(loss) + else: + metric.update_state(y, y_pred) + # Return a dict mapping metric names to current value. + # Note that it will include the loss (tracked in self.metrics). + return {m.name: m.result() for m in self.metrics} + + +# Construct an instance of CustomModel +inputs = keras.Input(shape=(32,)) +outputs = keras.layers.Dense(1)(inputs) +model = CustomModel(inputs, outputs) +model.compile(loss="mse", metrics=["mae"]) + +# Evaluate with our custom test_step +x = np.random.random((1000, 32)) +y = np.random.random((1000, 1)) +model.evaluate(x, y) + +""" +## Wrapping up: an end-to-end GAN example + +Let's walk through an end-to-end example that leverages everything you just learned. + +Let's consider: + +- A generator network meant to generate 28x28x1 images. +- A discriminator network meant to classify 28x28x1 images into two classes ("fake" and +"real"). +- One optimizer for each. +- A loss function to train the discriminator. +""" + +# Create the discriminator +discriminator = keras.Sequential( + [ + keras.Input(shape=(28, 28, 1)), + layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"), + layers.LeakyReLU(negative_slope=0.2), + layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"), + layers.LeakyReLU(negative_slope=0.2), + layers.GlobalMaxPooling2D(), + layers.Dense(1), + ], + name="discriminator", +) + +# Create the generator +latent_dim = 128 +generator = keras.Sequential( + [ + keras.Input(shape=(latent_dim,)), + # We want to generate 128 coefficients to reshape into a 7x7x128 map + layers.Dense(7 * 7 * 128), + layers.LeakyReLU(negative_slope=0.2), + layers.Reshape((7, 7, 128)), + layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"), + layers.LeakyReLU(negative_slope=0.2), + layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"), + layers.LeakyReLU(negative_slope=0.2), + layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"), + ], + name="generator", +) + +""" +Here's a feature-complete GAN class, overriding `compile()` to use its own signature, +and implementing the entire GAN algorithm in 17 lines in `train_step`: +""" + + +class GAN(keras.Model): + def __init__(self, discriminator, generator, latent_dim): + super().__init__() + self.discriminator = discriminator + self.generator = generator + self.latent_dim = latent_dim + self.d_loss_tracker = keras.metrics.Mean(name="d_loss") + self.g_loss_tracker = keras.metrics.Mean(name="g_loss") + self.seed_generator = keras.random.SeedGenerator(1337) + self.built = True + + @property + def metrics(self): + return [self.d_loss_tracker, self.g_loss_tracker] + + def compile(self, d_optimizer, g_optimizer, loss_fn): + super().compile() + self.d_optimizer = d_optimizer + self.g_optimizer = g_optimizer + self.loss_fn = loss_fn + + def train_step(self, real_images): + if isinstance(real_images, tuple): + real_images = real_images[0] + # Sample random points in the latent space + batch_size = real_images.shape[0] + random_latent_vectors = keras.random.normal( + shape=(batch_size, self.latent_dim), seed=self.seed_generator + ) + + # Decode them to fake images + generated_images = self.generator(random_latent_vectors) + + # Combine them with real images + real_images = torch.tensor(real_images) + combined_images = torch.concat([generated_images, real_images], axis=0) + + # Assemble labels discriminating real from fake images + labels = torch.concat( + [torch.ones((batch_size, 1)), torch.zeros((batch_size, 1))], axis=0 + ) + # Add random noise to the labels - important trick! + labels += 0.05 * keras.random.uniform( + labels.shape, seed=self.seed_generator + ) + + # Train the discriminator + self.zero_grad() + predictions = self.discriminator(combined_images) + d_loss = self.loss_fn(labels, predictions) + d_loss.backward() + grads = [v.value.grad for v in self.discriminator.trainable_weights] + with torch.no_grad(): + self.d_optimizer.apply(grads, self.discriminator.trainable_weights) + + # Sample random points in the latent space + random_latent_vectors = keras.random.normal( + shape=(batch_size, self.latent_dim), seed=self.seed_generator + ) + + # Assemble labels that say "all real images" + misleading_labels = torch.zeros((batch_size, 1)) + + # Train the generator (note that we should *not* update the weights + # of the discriminator)! + self.zero_grad() + predictions = self.discriminator(self.generator(random_latent_vectors)) + g_loss = self.loss_fn(misleading_labels, predictions) + grads = g_loss.backward() + grads = [v.value.grad for v in self.generator.trainable_weights] + with torch.no_grad(): + self.g_optimizer.apply(grads, self.generator.trainable_weights) + + # Update metrics and return their value. + self.d_loss_tracker.update_state(d_loss) + self.g_loss_tracker.update_state(g_loss) + return { + "d_loss": self.d_loss_tracker.result(), + "g_loss": self.g_loss_tracker.result(), + } + + +""" +Let's test-drive it: +""" + +# Prepare the dataset. We use both the training & test MNIST digits. +batch_size = 64 +(x_train, _), (x_test, _) = keras.datasets.mnist.load_data() +all_digits = np.concatenate([x_train, x_test]) +all_digits = all_digits.astype("float32") / 255.0 +all_digits = np.reshape(all_digits, (-1, 28, 28, 1)) + +# Create a TensorDataset +dataset = torch.utils.data.TensorDataset( + torch.from_numpy(all_digits), torch.from_numpy(all_digits) +) +# Create a DataLoader +dataloader = torch.utils.data.DataLoader( + dataset, batch_size=batch_size, shuffle=True +) + +gan = GAN( + discriminator=discriminator, generator=generator, latent_dim=latent_dim +) +gan.compile( + d_optimizer=keras.optimizers.Adam(learning_rate=0.0003), + g_optimizer=keras.optimizers.Adam(learning_rate=0.0003), + loss_fn=keras.losses.BinaryCrossentropy(from_logits=True), +) + +gan.fit(dataloader, epochs=1) + +""" +The ideas behind deep learning are simple, so why should their implementation be painful? +""" diff --git a/guides/distributed_training_with_jax.py b/guides/distributed_training_with_jax.py new file mode 100644 index 000000000000..6f6dbbf25d78 --- /dev/null +++ b/guides/distributed_training_with_jax.py @@ -0,0 +1,273 @@ +""" +Title: Multi-GPU distributed training with JAX +Author: [fchollet](https://twitter.com/fchollet) +Date created: 2023/07/11 +Last modified: 2023/07/11 +Description: Guide to multi-GPU/TPU training for Keras models with JAX. +Accelerator: GPU +""" + +""" +## Introduction + +There are generally two ways to distribute computation across multiple devices: + +**Data parallelism**, where a single model gets replicated on multiple devices or +multiple machines. Each of them processes different batches of data, then they merge +their results. There exist many variants of this setup, that differ in how the different +model replicas merge results, in whether they stay in sync at every batch or whether they +are more loosely coupled, etc. + +**Model parallelism**, where different parts of a single model run on different devices, +processing a single batch of data together. This works best with models that have a +naturally-parallel architecture, such as models that feature multiple branches. + +This guide focuses on data parallelism, in particular **synchronous data parallelism**, +where the different replicas of the model stay in sync after each batch they process. +Synchronicity keeps the model convergence behavior identical to what you would see for +single-device training. + +Specifically, this guide teaches you how to use `jax.sharding` APIs to train Keras +models, with minimal changes to your code, on multiple GPUs or TPUS (typically 2 to 16) +installed on a single machine (single host, multi-device training). This is the +most common setup for researchers and small-scale industry workflows. +""" + +""" +## Setup + +Let's start by defining the function that creates the model that we will train, +and the function that creates the dataset we will train on (MNIST in this case). +""" + +import os + +os.environ["KERAS_BACKEND"] = "jax" + +import jax +import numpy as np +import tensorflow as tf +import keras + +from jax.experimental import mesh_utils +from jax.sharding import Mesh +from jax.sharding import NamedSharding +from jax.sharding import PartitionSpec as P + + +def get_model(): + # Make a simple convnet with batch normalization and dropout. + inputs = keras.Input(shape=(28, 28, 1)) + x = keras.layers.Rescaling(1.0 / 255.0)(inputs) + x = keras.layers.Conv2D( + filters=12, kernel_size=3, padding="same", use_bias=False + )(x) + x = keras.layers.BatchNormalization(scale=False, center=True)(x) + x = keras.layers.ReLU()(x) + x = keras.layers.Conv2D( + filters=24, + kernel_size=6, + use_bias=False, + strides=2, + )(x) + x = keras.layers.BatchNormalization(scale=False, center=True)(x) + x = keras.layers.ReLU()(x) + x = keras.layers.Conv2D( + filters=32, + kernel_size=6, + padding="same", + strides=2, + name="large_k", + )(x) + x = keras.layers.BatchNormalization(scale=False, center=True)(x) + x = keras.layers.ReLU()(x) + x = keras.layers.GlobalAveragePooling2D()(x) + x = keras.layers.Dense(256, activation="relu")(x) + x = keras.layers.Dropout(0.5)(x) + outputs = keras.layers.Dense(10)(x) + model = keras.Model(inputs, outputs) + return model + + +def get_datasets(): + # Load the data and split it between train and test sets + (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data() + + # Scale images to the [0, 1] range + x_train = x_train.astype("float32") + x_test = x_test.astype("float32") + # Make sure images have shape (28, 28, 1) + x_train = np.expand_dims(x_train, -1) + x_test = np.expand_dims(x_test, -1) + print("x_train shape:", x_train.shape) + print(x_train.shape[0], "train samples") + print(x_test.shape[0], "test samples") + + # Create TF Datasets + train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train)) + eval_data = tf.data.Dataset.from_tensor_slices((x_test, y_test)) + return train_data, eval_data + + +""" +## Single-host, multi-device synchronous training + +In this setup, you have one machine with several GPUs or TPUs on it (typically 2 to 16). +Each device will run a copy of your model (called a **replica**). For simplicity, in +what follows, we'll assume we're dealing with 8 GPUs, at no loss of generality. + +**How it works** + +At each step of training: + +- The current batch of data (called **global batch**) is split into 8 different + sub-batches (called **local batches**). For instance, if the global batch has 512 + samples, each of the 8 local batches will have 64 samples. +- Each of the 8 replicas independently processes a local batch: they run a forward pass, + then a backward pass, outputting the gradient of the weights with respect to the loss of + the model on the local batch. +- The weight updates originating from local gradients are efficiently merged across the 8 + replicas. Because this is done at the end of every step, the replicas always stay in + sync. + +In practice, the process of synchronously updating the weights of the model replicas is +handled at the level of each individual weight variable. This is done through a using +a `jax.sharding.NamedSharding` that is configured to replicate the variables. + +**How to use it** + +To do single-host, multi-device synchronous training with a Keras model, you +would use the `jax.sharding` features. Here's how it works: + +- We first create a device mesh using `mesh_utils.create_device_mesh`. +- We use `jax.sharding.Mesh`, `jax.sharding.NamedSharding` and + `jax.sharding.PartitionSpec` to define how to partition JAX arrays. + - We specify that we want to replicate the model and optimizer variables + across all devices by using a spec with no axis. + - We specify that we want to shard the data across devices by using a spec + that splits along the batch dimension. +- We use `jax.device_put` to replicate the model and optimizer variables across + devices. This happens once at the beginning. +- In the training loop, for each batch that we process, we use `jax.device_put` + to split the batch across devices before invoking the train step. + +Here's the flow, where each step is split into its own utility function: +""" + +# Config +num_epochs = 2 +batch_size = 64 + +train_data, eval_data = get_datasets() +train_data = train_data.batch(batch_size, drop_remainder=True) + +model = get_model() +optimizer = keras.optimizers.Adam(1e-3) +loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True) + +# Initialize all state with .build() +(one_batch, one_batch_labels) = next(iter(train_data)) +model.build(one_batch) +optimizer.build(model.trainable_variables) + + +# This is the loss function that will be differentiated. +# Keras provides a pure functional forward pass: model.stateless_call +def compute_loss(trainable_variables, non_trainable_variables, x, y): + y_pred, updated_non_trainable_variables = model.stateless_call( + trainable_variables, non_trainable_variables, x + ) + loss_value = loss(y, y_pred) + return loss_value, updated_non_trainable_variables + + +# Function to compute gradients +compute_gradients = jax.value_and_grad(compute_loss, has_aux=True) + + +# Training step, Keras provides a pure functional optimizer.stateless_apply +@jax.jit +def train_step(train_state, x, y): + ( + trainable_variables, + non_trainable_variables, + optimizer_variables, + ) = train_state + (loss_value, non_trainable_variables), grads = compute_gradients( + trainable_variables, non_trainable_variables, x, y + ) + + trainable_variables, optimizer_variables = optimizer.stateless_apply( + optimizer_variables, grads, trainable_variables + ) + + return loss_value, ( + trainable_variables, + non_trainable_variables, + optimizer_variables, + ) + + +# Replicate the model and optimizer variable on all devices +def get_replicated_train_state(devices): + # All variables will be replicated on all devices + var_mesh = Mesh(devices, axis_names=("_")) + # In NamedSharding, axes not mentioned are replicated (all axes here) + var_replication = NamedSharding(var_mesh, P()) + + # Apply the distribution settings to the model variables + trainable_variables = jax.device_put( + model.trainable_variables, var_replication + ) + non_trainable_variables = jax.device_put( + model.non_trainable_variables, var_replication + ) + optimizer_variables = jax.device_put(optimizer.variables, var_replication) + + # Combine all state in a tuple + return (trainable_variables, non_trainable_variables, optimizer_variables) + + +num_devices = len(jax.local_devices()) +print(f"Running on {num_devices} devices: {jax.local_devices()}") +devices = mesh_utils.create_device_mesh((num_devices,)) + +# Data will be split along the batch axis +data_mesh = Mesh(devices, axis_names=("batch",)) # naming axes of the mesh +data_sharding = NamedSharding( + data_mesh, + P( + "batch", + ), +) # naming axes of the sharded partition + +# Display data sharding +x, y = next(iter(train_data)) +sharded_x = jax.device_put(x.numpy(), data_sharding) +print("Data sharding") +jax.debug.visualize_array_sharding(jax.numpy.reshape(sharded_x, [-1, 28 * 28])) + +train_state = get_replicated_train_state(devices) + +# Custom training loop +for epoch in range(num_epochs): + data_iter = iter(train_data) + loss_value = None # default + for data in data_iter: + x, y = data + sharded_x = jax.device_put(x.numpy(), data_sharding) + loss_value, train_state = train_step(train_state, sharded_x, y.numpy()) + print("Epoch", epoch, "loss:", loss_value) + +# Post-processing model state update to write them back into the model +trainable_variables, non_trainable_variables, optimizer_variables = train_state +for variable, value in zip(model.trainable_variables, trainable_variables): + variable.assign(value) +for variable, value in zip( + model.non_trainable_variables, non_trainable_variables +): + variable.assign(value) + +""" +That's it! +""" diff --git a/guides/distributed_training_with_tensorflow.py b/guides/distributed_training_with_tensorflow.py new file mode 100644 index 000000000000..0207eed0f1dd --- /dev/null +++ b/guides/distributed_training_with_tensorflow.py @@ -0,0 +1,276 @@ +""" +Title: Multi-GPU distributed training with TensorFlow +Author: [fchollet](https://twitter.com/fchollet) +Date created: 2020/04/28 +Last modified: 2023/06/29 +Description: Guide to multi-GPU training for Keras models with TensorFlow. +Accelerator: GPU +""" + +""" +## Introduction + +There are generally two ways to distribute computation across multiple devices: + +**Data parallelism**, where a single model gets replicated on multiple devices or +multiple machines. Each of them processes different batches of data, then they merge +their results. There exist many variants of this setup, that differ in how the different +model replicas merge results, in whether they stay in sync at every batch or whether they +are more loosely coupled, etc. + +**Model parallelism**, where different parts of a single model run on different devices, +processing a single batch of data together. This works best with models that have a +naturally-parallel architecture, such as models that feature multiple branches. + +This guide focuses on data parallelism, in particular **synchronous data parallelism**, +where the different replicas of the model stay in sync after each batch they process. +Synchronicity keeps the model convergence behavior identical to what you would see for +single-device training. + +Specifically, this guide teaches you how to use the `tf.distribute` API to train Keras +models on multiple GPUs, with minimal changes to your code, +on multiple GPUs (typically 2 to 16) installed on a single machine (single host, +multi-device training). This is the most common setup for researchers and small-scale +industry workflows. +""" + +""" +## Setup +""" + +import os + +os.environ["KERAS_BACKEND"] = "tensorflow" + +import tensorflow as tf +import keras + +""" +## Single-host, multi-device synchronous training + +In this setup, you have one machine with several GPUs on it (typically 2 to 16). Each +device will run a copy of your model (called a **replica**). For simplicity, in what +follows, we'll assume we're dealing with 8 GPUs, at no loss of generality. + +**How it works** + +At each step of training: + +- The current batch of data (called **global batch**) is split into 8 different +sub-batches (called **local batches**). For instance, if the global batch has 512 +samples, each of the 8 local batches will have 64 samples. +- Each of the 8 replicas independently processes a local batch: they run a forward pass, +then a backward pass, outputting the gradient of the weights with respect to the loss of +the model on the local batch. +- The weight updates originating from local gradients are efficiently merged across the 8 +replicas. Because this is done at the end of every step, the replicas always stay in +sync. + +In practice, the process of synchronously updating the weights of the model replicas is +handled at the level of each individual weight variable. This is done through a **mirrored +variable** object. + +**How to use it** + +To do single-host, multi-device synchronous training with a Keras model, you would use +the [`tf.distribute.MirroredStrategy` API]( + https://www.tensorflow.org/api_docs/python/tf/distribute/MirroredStrategy). +Here's how it works: + +- Instantiate a `MirroredStrategy`, optionally configuring which specific devices you +want to use (by default the strategy will use all GPUs available). +- Use the strategy object to open a scope, and within this scope, create all the Keras +objects you need that contain variables. Typically, that means **creating & compiling the +model** inside the distribution scope. In some cases, the first call to `fit()` may also +create variables, so it's a good idea to put your `fit()` call in the scope as well. +- Train the model via `fit()` as usual. + +Importantly, we recommend that you use `tf.data.Dataset` objects to load data +in a multi-device or distributed workflow. + +Schematically, it looks like this: + +```python +# Create a MirroredStrategy. +strategy = tf.distribute.MirroredStrategy() +print('Number of devices: {}'.format(strategy.num_replicas_in_sync)) + +# Open a strategy scope. +with strategy.scope(): + # Everything that creates variables should be under the strategy scope. + # In general this is only model construction & `compile()`. + model = Model(...) + model.compile(...) + + # Train the model on all available devices. + model.fit(train_dataset, validation_data=val_dataset, ...) + + # Test the model on all available devices. + model.evaluate(test_dataset) +``` + +Here's a simple end-to-end runnable example: +""" + + +def get_compiled_model(): + # Make a simple 2-layer densely-connected neural network. + inputs = keras.Input(shape=(784,)) + x = keras.layers.Dense(256, activation="relu")(inputs) + x = keras.layers.Dense(256, activation="relu")(x) + outputs = keras.layers.Dense(10)(x) + model = keras.Model(inputs, outputs) + model.compile( + optimizer=keras.optimizers.Adam(), + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + metrics=[keras.metrics.SparseCategoricalAccuracy()], + ) + return model + + +def get_dataset(): + batch_size = 32 + num_val_samples = 10000 + + # Return the MNIST dataset in the form of a `tf.data.Dataset`. + (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data() + + # Preprocess the data (these are Numpy arrays) + x_train = x_train.reshape(-1, 784).astype("float32") / 255 + x_test = x_test.reshape(-1, 784).astype("float32") / 255 + y_train = y_train.astype("float32") + y_test = y_test.astype("float32") + + # Reserve num_val_samples samples for validation + x_val = x_train[-num_val_samples:] + y_val = y_train[-num_val_samples:] + x_train = x_train[:-num_val_samples] + y_train = y_train[:-num_val_samples] + return ( + tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch( + batch_size + ), + tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(batch_size), + tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size), + ) + + +# Create a MirroredStrategy. +strategy = tf.distribute.MirroredStrategy() +print("Number of devices: {}".format(strategy.num_replicas_in_sync)) + +# Open a strategy scope. +with strategy.scope(): + # Everything that creates variables should be under the strategy scope. + # In general this is only model construction & `compile()`. + model = get_compiled_model() + + # Train the model on all available devices. + train_dataset, val_dataset, test_dataset = get_dataset() + model.fit(train_dataset, epochs=2, validation_data=val_dataset) + + # Test the model on all available devices. + model.evaluate(test_dataset) + +""" +## Using callbacks to ensure fault tolerance + +When using distributed training, you should always make sure you have a strategy to +recover from failure (fault tolerance). The simplest way to handle this is to pass +`ModelCheckpoint` callback to `fit()`, to save your model +at regular intervals (e.g. every 100 batches or every epoch). You can then restart +training from your saved model. + +Here's a simple example: +""" + +# Prepare a directory to store all the checkpoints. +checkpoint_dir = "./ckpt" +if not os.path.exists(checkpoint_dir): + os.makedirs(checkpoint_dir) + + +def make_or_restore_model(): + # Either restore the latest model, or create a fresh one + # if there is no checkpoint available. + checkpoints = [ + os.path.join(checkpoint_dir, name) + for name in os.listdir(checkpoint_dir) + ] + if checkpoints: + latest_checkpoint = max(checkpoints, key=os.path.getctime) + print("Restoring from", latest_checkpoint) + return keras.models.load_model(latest_checkpoint) + print("Creating a new model") + return get_compiled_model() + + +def run_training(epochs=1): + # Create a MirroredStrategy. + strategy = tf.distribute.MirroredStrategy() + + # Open a strategy scope and create/restore the model + with strategy.scope(): + model = make_or_restore_model() + + callbacks = [ + # This callback saves a SavedModel every epoch + # We include the current epoch in the folder name. + keras.callbacks.ModelCheckpoint( + filepath=os.path.join(checkpoint_dir, "ckpt-{epoch}.keras"), + save_freq="epoch", + ) + ] + model.fit( + train_dataset, + epochs=epochs, + callbacks=callbacks, + validation_data=val_dataset, + verbose=2, + ) + + +# Running the first time creates the model +run_training(epochs=1) + +# Calling the same function again will resume from where we left off +run_training(epochs=1) + +""" +## `tf.data` performance tips + +When doing distributed training, the efficiency with which you load data can often become +critical. Here are a few tips to make sure your `tf.data` pipelines +run as fast as possible. + +**Note about dataset batching** + +When creating your dataset, make sure it is batched with the global batch size. +For instance, if each of your 8 GPUs is capable of running a batch of 64 samples, you +call use a global batch size of 512. + +**Calling `dataset.cache()`** + +If you call `.cache()` on a dataset, its data will be cached after running through the +first iteration over the data. Every subsequent iteration will use the cached data. The +cache can be in memory (default) or to a local file you specify. + +This can improve performance when: + +- Your data is not expected to change from iteration to iteration +- You are reading data from a remote distributed filesystem +- You are reading data from local disk, but your data would fit in memory and your +workflow is significantly IO-bound (e.g. reading & decoding image files). + +**Calling `dataset.prefetch(buffer_size)`** + +You should almost always call `.prefetch(buffer_size)` after creating a dataset. It means +your data pipeline will run asynchronously from your model, +with new samples being preprocessed and stored in a buffer while the current batch +samples are used to train the model. The next batch will be prefetched in GPU memory by +the time the current batch is over. +""" + +""" +That's it! +""" diff --git a/guides/distributed_training_with_torch.py b/guides/distributed_training_with_torch.py new file mode 100644 index 000000000000..cf39419b29ba --- /dev/null +++ b/guides/distributed_training_with_torch.py @@ -0,0 +1,272 @@ +""" +Title: Multi-GPU distributed training with PyTorch +Author: [fchollet](https://twitter.com/fchollet) +Date created: 2023/06/29 +Last modified: 2023/06/29 +Description: Guide to multi-GPU training for Keras models with PyTorch. +Accelerator: GPU +""" + +""" +## Introduction + +There are generally two ways to distribute computation across multiple devices: + +**Data parallelism**, where a single model gets replicated on multiple devices or +multiple machines. Each of them processes different batches of data, then they merge +their results. There exist many variants of this setup, that differ in how the different +model replicas merge results, in whether they stay in sync at every batch or whether they +are more loosely coupled, etc. + +**Model parallelism**, where different parts of a single model run on different devices, +processing a single batch of data together. This works best with models that have a +naturally-parallel architecture, such as models that feature multiple branches. + +This guide focuses on data parallelism, in particular **synchronous data parallelism**, +where the different replicas of the model stay in sync after each batch they process. +Synchronicity keeps the model convergence behavior identical to what you would see for +single-device training. + +Specifically, this guide teaches you how to use PyTorch's `DistributedDataParallel` +module wrapper to train Keras, with minimal changes to your code, +on multiple GPUs (typically 2 to 16) installed on a single machine (single host, +multi-device training). This is the most common setup for researchers and small-scale +industry workflows. +""" + +""" +## Setup + +Let's start by defining the function that creates the model that we will train, +and the function that creates the dataset we will train on (MNIST in this case). +""" + +import os + +os.environ["KERAS_BACKEND"] = "torch" + +import torch +import numpy as np +import keras + + +def get_model(): + # Make a simple convnet with batch normalization and dropout. + inputs = keras.Input(shape=(28, 28, 1)) + x = keras.layers.Rescaling(1.0 / 255.0)(inputs) + x = keras.layers.Conv2D( + filters=12, kernel_size=3, padding="same", use_bias=False + )(x) + x = keras.layers.BatchNormalization(scale=False, center=True)(x) + x = keras.layers.ReLU()(x) + x = keras.layers.Conv2D( + filters=24, + kernel_size=6, + use_bias=False, + strides=2, + )(x) + x = keras.layers.BatchNormalization(scale=False, center=True)(x) + x = keras.layers.ReLU()(x) + x = keras.layers.Conv2D( + filters=32, + kernel_size=6, + padding="same", + strides=2, + name="large_k", + )(x) + x = keras.layers.BatchNormalization(scale=False, center=True)(x) + x = keras.layers.ReLU()(x) + x = keras.layers.GlobalAveragePooling2D()(x) + x = keras.layers.Dense(256, activation="relu")(x) + x = keras.layers.Dropout(0.5)(x) + outputs = keras.layers.Dense(10)(x) + model = keras.Model(inputs, outputs) + return model + + +def get_dataset(): + # Load the data and split it between train and test sets + (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data() + + # Scale images to the [0, 1] range + x_train = x_train.astype("float32") + x_test = x_test.astype("float32") + # Make sure images have shape (28, 28, 1) + x_train = np.expand_dims(x_train, -1) + x_test = np.expand_dims(x_test, -1) + print("x_train shape:", x_train.shape) + + # Create a TensorDataset + dataset = torch.utils.data.TensorDataset( + torch.from_numpy(x_train), torch.from_numpy(y_train) + ) + return dataset + + +""" +Next, let's define a simple PyTorch training loop that targets +a GPU (note the calls to `.cuda()`). +""" + + +def train_model(model, dataloader, num_epochs, optimizer, loss_fn): + for epoch in range(num_epochs): + running_loss = 0.0 + running_loss_count = 0 + for batch_idx, (inputs, targets) in enumerate(dataloader): + inputs = inputs.cuda(non_blocking=True) + targets = targets.cuda(non_blocking=True) + + # Forward pass + outputs = model(inputs) + loss = loss_fn(outputs, targets) + + # Backward and optimize + optimizer.zero_grad() + loss.backward() + optimizer.step() + + running_loss += loss.item() + running_loss_count += 1 + + # Print loss statistics + print( + f"Epoch {epoch + 1}/{num_epochs}, " + f"Loss: {running_loss / running_loss_count}" + ) + + +""" +## Single-host, multi-device synchronous training + +In this setup, you have one machine with several GPUs on it (typically 2 to 16). Each +device will run a copy of your model (called a **replica**). For simplicity, in what +follows, we'll assume we're dealing with 8 GPUs, at no loss of generality. + +**How it works** + +At each step of training: + +- The current batch of data (called **global batch**) is split into 8 different +sub-batches (called **local batches**). For instance, if the global batch has 512 +samples, each of the 8 local batches will have 64 samples. +- Each of the 8 replicas independently processes a local batch: they run a forward pass, +then a backward pass, outputting the gradient of the weights with respect to the loss of +the model on the local batch. +- The weight updates originating from local gradients are efficiently merged across the 8 +replicas. Because this is done at the end of every step, the replicas always stay in +sync. + +In practice, the process of synchronously updating the weights of the model replicas is +handled at the level of each individual weight variable. This is done through a **mirrored +variable** object. + +**How to use it** + +To do single-host, multi-device synchronous training with a Keras model, you would use +the `torch.nn.parallel.DistributedDataParallel` module wrapper. +Here's how it works: + +- We use `torch.multiprocessing.start_processes` to start multiple Python processes, one +per device. Each process will run the `per_device_launch_fn` function. +- The `per_device_launch_fn` function does the following: + - It uses `torch.distributed.init_process_group` and `torch.cuda.set_device` + to configure the device to be used for that process. + - It uses `torch.utils.data.distributed.DistributedSampler` + and `torch.utils.data.DataLoader` to turn our data into a distributed data loader. + - It also uses `torch.nn.parallel.DistributedDataParallel` to turn our model into + a distributed PyTorch module. + - It then calls the `train_model` function. +- The `train_model` function will then run in each process, with the model using +a separate device in each process. + +Here's the flow, where each step is split into its own utility function: +""" + +# Config +num_gpu = torch.cuda.device_count() +num_epochs = 2 +batch_size = 64 +print(f"Running on {num_gpu} GPUs") + + +def setup_device(current_gpu_index, num_gpus): + # Device setup + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "56492" + device = torch.device("cuda:{}".format(current_gpu_index)) + torch.distributed.init_process_group( + backend="nccl", + init_method="env://", + world_size=num_gpus, + rank=current_gpu_index, + ) + torch.cuda.set_device(device) + + +def cleanup(): + torch.distributed.destroy_process_group() + + +def prepare_dataloader(dataset, current_gpu_index, num_gpus, batch_size): + sampler = torch.utils.data.distributed.DistributedSampler( + dataset, + num_replicas=num_gpus, + rank=current_gpu_index, + shuffle=False, + ) + dataloader = torch.utils.data.DataLoader( + dataset, + sampler=sampler, + batch_size=batch_size, + shuffle=False, + ) + return dataloader + + +def per_device_launch_fn(current_gpu_index, num_gpu): + # Setup the process groups + setup_device(current_gpu_index, num_gpu) + + dataset = get_dataset() + model = get_model() + + # prepare the dataloader + dataloader = prepare_dataloader( + dataset, current_gpu_index, num_gpu, batch_size + ) + + # Instantiate the torch optimizer + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + + # Instantiate the torch loss function + loss_fn = torch.nn.CrossEntropyLoss() + + # Put model on device + model = model.to(current_gpu_index) + ddp_model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[current_gpu_index], output_device=current_gpu_index + ) + + train_model(ddp_model, dataloader, num_epochs, optimizer, loss_fn) + + cleanup() + + +""" +Time to start multiple processes: +""" + +if __name__ == "__main__": + # We use the "fork" method rather than "spawn" to support notebooks + torch.multiprocessing.start_processes( + per_device_launch_fn, + args=(num_gpu,), + nprocs=num_gpu, + join=True, + start_method="fork", + ) + +""" +That's it! +""" diff --git a/guides/functional_api.py b/guides/functional_api.py new file mode 100644 index 000000000000..c174953179e0 --- /dev/null +++ b/guides/functional_api.py @@ -0,0 +1,881 @@ +""" +Title: The Functional API +Author: [fchollet](https://twitter.com/fchollet) +Date created: 2019/03/01 +Last modified: 2020/04/12 +Description: Complete guide to the functional API. +Accelerator: GPU +""" + +""" +## Setup +""" + +import numpy as np +import keras +from keras import layers +from keras import ops + +""" +## Introduction + +The Keras *functional API* is a way to create models that are more flexible +than the `keras.Sequential` API. The functional API can handle models +with non-linear topology, shared layers, and even multiple inputs or outputs. + +The main idea is that a deep learning model is usually +a directed acyclic graph (DAG) of layers. +So the functional API is a way to build *graphs of layers*. + +Consider the following model: + +
+``` +(input: 784-dimensional vectors) + ↧ +[Dense (64 units, relu activation)] + ↧ +[Dense (64 units, relu activation)] + ↧ +[Dense (10 units, softmax activation)] + ↧ +(output: logits of a probability distribution over 10 classes) +``` +
+ +This is a basic graph with three layers. +To build this model using the functional API, start by creating an input node: +""" + +inputs = keras.Input(shape=(784,)) + +""" +The shape of the data is set as a 784-dimensional vector. +The batch size is always omitted since only the shape of each sample is specified. + +If, for example, you have an image input with a shape of `(32, 32, 3)`, +you would use: +""" + +# Just for demonstration purposes. +img_inputs = keras.Input(shape=(32, 32, 3)) + +""" +The `inputs` that is returned contains information about the shape and `dtype` +of the input data that you feed to your model. +Here's the shape: +""" + +inputs.shape + +""" +Here's the dtype: +""" + +inputs.dtype + +""" +You create a new node in the graph of layers by calling a layer on this `inputs` +object: +""" + +dense = layers.Dense(64, activation="relu") +x = dense(inputs) + +""" +The "layer call" action is like drawing an arrow from "inputs" to this layer +you created. +You're "passing" the inputs to the `dense` layer, and you get `x` as the output. + +Let's add a few more layers to the graph of layers: +""" + +x = layers.Dense(64, activation="relu")(x) +outputs = layers.Dense(10)(x) + +""" +At this point, you can create a `Model` by specifying its inputs and outputs +in the graph of layers: +""" + +model = keras.Model(inputs=inputs, outputs=outputs, name="mnist_model") + +""" +Let's check out what the model summary looks like: +""" + +model.summary() + +""" +You can also plot the model as a graph: +""" + +keras.utils.plot_model(model, "my_first_model.png") + +""" +And, optionally, display the input and output shapes of each layer +in the plotted graph: +""" + +keras.utils.plot_model( + model, "my_first_model_with_shape_info.png", show_shapes=True +) + +""" +This figure and the code are almost identical. In the code version, +the connection arrows are replaced by the call operation. + +A "graph of layers" is an intuitive mental image for a deep learning model, +and the functional API is a way to create models that closely mirrors this. +""" + +""" +## Training, evaluation, and inference + +Training, evaluation, and inference work exactly in the same way for models +built using the functional API as for `Sequential` models. + +The `Model` class offers a built-in training loop (the `fit()` method) +and a built-in evaluation loop (the `evaluate()` method). Note +that you can easily [customize these loops](/guides/customizing_what_happens_in_fit/) +to implement training routines beyond supervised learning +(e.g. [GANs](https://keras.io/examples/generative/dcgan_overriding_train_step/)). + +Here, load the MNIST image data, reshape it into vectors, +fit the model on the data (while monitoring performance on a validation split), +then evaluate the model on the test data: +""" + +(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data() + +x_train = x_train.reshape(60000, 784).astype("float32") / 255 +x_test = x_test.reshape(10000, 784).astype("float32") / 255 + +model.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + optimizer=keras.optimizers.RMSprop(), + metrics=["accuracy"], +) + +history = model.fit( + x_train, y_train, batch_size=64, epochs=2, validation_split=0.2 +) + +test_scores = model.evaluate(x_test, y_test, verbose=2) +print("Test loss:", test_scores[0]) +print("Test accuracy:", test_scores[1]) + +""" +For further reading, see the [training and evaluation](/guides/training_with_built_in_methods/) guide. +""" + +""" +## Save and serialize + +Saving the model and serialization work the same way for models built using +the functional API as they do for `Sequential` models. The standard way +to save a functional model is to call `model.save()` +to save the entire model as a single file. You can later recreate the same model +from this file, even if the code that built the model is no longer available. + +This saved file includes the: + +- model architecture +- model weight values (that were learned during training) +- model training config, if any (as passed to `compile()`) +- optimizer and its state, if any (to restart training where you left off) +""" + +model.save("my_model.keras") +del model +# Recreate the exact same model purely from the file: +model = keras.models.load_model("my_model.keras") + +""" +For details, read the model [serialization & saving]( + /guides/serialization_and_saving/) guide. +""" + +""" +## Use the same graph of layers to define multiple models + +In the functional API, models are created by specifying their inputs +and outputs in a graph of layers. That means that a single +graph of layers can be used to generate multiple models. + +In the example below, you use the same stack of layers to instantiate two models: +an `encoder` model that turns image inputs into 16-dimensional vectors, +and an end-to-end `autoencoder` model for training. +""" + +encoder_input = keras.Input(shape=(28, 28, 1), name="img") +x = layers.Conv2D(16, 3, activation="relu")(encoder_input) +x = layers.Conv2D(32, 3, activation="relu")(x) +x = layers.MaxPooling2D(3)(x) +x = layers.Conv2D(32, 3, activation="relu")(x) +x = layers.Conv2D(16, 3, activation="relu")(x) +encoder_output = layers.GlobalMaxPooling2D()(x) + +encoder = keras.Model(encoder_input, encoder_output, name="encoder") +encoder.summary() + +x = layers.Reshape((4, 4, 1))(encoder_output) +x = layers.Conv2DTranspose(16, 3, activation="relu")(x) +x = layers.Conv2DTranspose(32, 3, activation="relu")(x) +x = layers.UpSampling2D(3)(x) +x = layers.Conv2DTranspose(16, 3, activation="relu")(x) +decoder_output = layers.Conv2DTranspose(1, 3, activation="relu")(x) + +autoencoder = keras.Model(encoder_input, decoder_output, name="autoencoder") +autoencoder.summary() + +""" +Here, the decoding architecture is strictly symmetrical +to the encoding architecture, so the output shape is the same as +the input shape `(28, 28, 1)`. + +The reverse of a `Conv2D` layer is a `Conv2DTranspose` layer, +and the reverse of a `MaxPooling2D` layer is an `UpSampling2D` layer. +""" + +""" +## All models are callable, just like layers + +You can treat any model as if it were a layer by invoking it on an `Input` or +on the output of another layer. By calling a model you aren't just reusing +the architecture of the model, you're also reusing its weights. + +To see this in action, here's a different take on the autoencoder example that +creates an encoder model, a decoder model, and chains them in two calls +to obtain the autoencoder model: +""" + +encoder_input = keras.Input(shape=(28, 28, 1), name="original_img") +x = layers.Conv2D(16, 3, activation="relu")(encoder_input) +x = layers.Conv2D(32, 3, activation="relu")(x) +x = layers.MaxPooling2D(3)(x) +x = layers.Conv2D(32, 3, activation="relu")(x) +x = layers.Conv2D(16, 3, activation="relu")(x) +encoder_output = layers.GlobalMaxPooling2D()(x) + +encoder = keras.Model(encoder_input, encoder_output, name="encoder") +encoder.summary() + +decoder_input = keras.Input(shape=(16,), name="encoded_img") +x = layers.Reshape((4, 4, 1))(decoder_input) +x = layers.Conv2DTranspose(16, 3, activation="relu")(x) +x = layers.Conv2DTranspose(32, 3, activation="relu")(x) +x = layers.UpSampling2D(3)(x) +x = layers.Conv2DTranspose(16, 3, activation="relu")(x) +decoder_output = layers.Conv2DTranspose(1, 3, activation="relu")(x) + +decoder = keras.Model(decoder_input, decoder_output, name="decoder") +decoder.summary() + +autoencoder_input = keras.Input(shape=(28, 28, 1), name="img") +encoded_img = encoder(autoencoder_input) +decoded_img = decoder(encoded_img) +autoencoder = keras.Model(autoencoder_input, decoded_img, name="autoencoder") +autoencoder.summary() + +""" +As you can see, the model can be nested: a model can contain sub-models +(since a model is just like a layer). +A common use case for model nesting is *ensembling*. +For example, here's how to ensemble a set of models into a single model +that averages their predictions: +""" + + +def get_model(): + inputs = keras.Input(shape=(128,)) + outputs = layers.Dense(1)(inputs) + return keras.Model(inputs, outputs) + + +model1 = get_model() +model2 = get_model() +model3 = get_model() + +inputs = keras.Input(shape=(128,)) +y1 = model1(inputs) +y2 = model2(inputs) +y3 = model3(inputs) +outputs = layers.average([y1, y2, y3]) +ensemble_model = keras.Model(inputs=inputs, outputs=outputs) + +""" +## Manipulate complex graph topologies + +### Models with multiple inputs and outputs + +The functional API makes it easy to manipulate multiple inputs and outputs. +This cannot be handled with the `Sequential` API. + +For example, if you're building a system for ranking customer issue tickets by +priority and routing them to the correct department, +then the model will have three inputs: + +- the title of the ticket (text input), +- the text body of the ticket (text input), and +- any tags added by the user (categorical input) + +This model will have two outputs: + +- the priority score between 0 and 1 (scalar sigmoid output), and +- the department that should handle the ticket (softmax output +over the set of departments). + +You can build this model in a few lines with the functional API: +""" + +num_tags = 12 # Number of unique issue tags +num_words = 10000 # Size of vocabulary obtained when preprocessing text data +num_departments = 4 # Number of departments for predictions + +title_input = keras.Input( + shape=(None,), name="title" +) # Variable-length sequence of ints +body_input = keras.Input( + shape=(None,), name="body" +) # Variable-length sequence of ints +tags_input = keras.Input( + shape=(num_tags,), name="tags" +) # Binary vectors of size `num_tags` + +# Embed each word in the title into a 64-dimensional vector +title_features = layers.Embedding(num_words, 64)(title_input) +# Embed each word in the text into a 64-dimensional vector +body_features = layers.Embedding(num_words, 64)(body_input) + +# Reduce sequence of embedded words in the title into a single 128-dimensional vector +title_features = layers.LSTM(128)(title_features) +# Reduce sequence of embedded words in the body into a single 32-dimensional vector +body_features = layers.LSTM(32)(body_features) + +# Merge all available features into a single large vector via concatenation +x = layers.concatenate([title_features, body_features, tags_input]) + +# Stick a logistic regression for priority prediction on top of the features +priority_pred = layers.Dense(1, name="priority")(x) +# Stick a department classifier on top of the features +department_pred = layers.Dense(num_departments, name="department")(x) + +# Instantiate an end-to-end model predicting both priority and department +model = keras.Model( + inputs=[title_input, body_input, tags_input], + outputs={"priority": priority_pred, "department": department_pred}, +) + +""" +Now plot the model: +""" + +keras.utils.plot_model( + model, "multi_input_and_output_model.png", show_shapes=True +) + +""" +When compiling this model, you can assign different losses to each output. +You can even assign different weights to each loss -- to modulate +their contribution to the total training loss. +""" + +model.compile( + optimizer=keras.optimizers.RMSprop(1e-3), + loss=[ + keras.losses.BinaryCrossentropy(from_logits=True), + keras.losses.CategoricalCrossentropy(from_logits=True), + ], + loss_weights=[1.0, 0.2], +) + +""" +Since the output layers have different names, you could also specify +the losses and loss weights with the corresponding layer names: +""" + +model.compile( + optimizer=keras.optimizers.RMSprop(1e-3), + loss={ + "priority": keras.losses.BinaryCrossentropy(from_logits=True), + "department": keras.losses.CategoricalCrossentropy(from_logits=True), + }, + loss_weights={"priority": 1.0, "department": 0.2}, +) + +""" +Train the model by passing lists of NumPy arrays of inputs and targets: +""" + +# Dummy input data +title_data = np.random.randint(num_words, size=(1280, 10)) +body_data = np.random.randint(num_words, size=(1280, 100)) +tags_data = np.random.randint(2, size=(1280, num_tags)).astype("float32") + +# Dummy target data +priority_targets = np.random.random(size=(1280, 1)) +dept_targets = np.random.randint(2, size=(1280, num_departments)) + +model.fit( + {"title": title_data, "body": body_data, "tags": tags_data}, + {"priority": priority_targets, "department": dept_targets}, + epochs=2, + batch_size=32, +) + +""" +When calling fit with a `Dataset` object, it should yield either a +tuple of lists like `([title_data, body_data, tags_data], [priority_targets, dept_targets])` +or a tuple of dictionaries like +`({'title': title_data, 'body': body_data, 'tags': tags_data}, {'priority': priority_targets, 'department': dept_targets})`. + +For more detailed explanation, refer to the [training and evaluation](/guides/training_with_built_in_methods/) guide. +""" + +""" +### A toy ResNet model + +In addition to models with multiple inputs and outputs, +the functional API makes it easy to manipulate non-linear connectivity +topologies -- these are models with layers that are not connected sequentially, +which the `Sequential` API cannot handle. + +A common use case for this is residual connections. +Let's build a toy ResNet model for CIFAR10 to demonstrate this: +""" + +inputs = keras.Input(shape=(32, 32, 3), name="img") +x = layers.Conv2D(32, 3, activation="relu")(inputs) +x = layers.Conv2D(64, 3, activation="relu")(x) +block_1_output = layers.MaxPooling2D(3)(x) + +x = layers.Conv2D(64, 3, activation="relu", padding="same")(block_1_output) +x = layers.Conv2D(64, 3, activation="relu", padding="same")(x) +block_2_output = layers.add([x, block_1_output]) + +x = layers.Conv2D(64, 3, activation="relu", padding="same")(block_2_output) +x = layers.Conv2D(64, 3, activation="relu", padding="same")(x) +block_3_output = layers.add([x, block_2_output]) + +x = layers.Conv2D(64, 3, activation="relu")(block_3_output) +x = layers.GlobalAveragePooling2D()(x) +x = layers.Dense(256, activation="relu")(x) +x = layers.Dropout(0.5)(x) +outputs = layers.Dense(10)(x) + +model = keras.Model(inputs, outputs, name="toy_resnet") +model.summary() + +""" +Plot the model: +""" + +keras.utils.plot_model(model, "mini_resnet.png", show_shapes=True) + +""" +Now train the model: +""" + +(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data() + +x_train = x_train.astype("float32") / 255.0 +x_test = x_test.astype("float32") / 255.0 +y_train = keras.utils.to_categorical(y_train, 10) +y_test = keras.utils.to_categorical(y_test, 10) + +model.compile( + optimizer=keras.optimizers.RMSprop(1e-3), + loss=keras.losses.CategoricalCrossentropy(from_logits=True), + metrics=["acc"], +) +# We restrict the data to the first 1000 samples so as to limit execution time +# on Colab. Try to train on the entire dataset until convergence! +model.fit( + x_train[:1000], + y_train[:1000], + batch_size=64, + epochs=1, + validation_split=0.2, +) + +""" +## Shared layers + +Another good use for the functional API are models that use *shared layers*. +Shared layers are layer instances that are reused multiple times in the same model -- +they learn features that correspond to multiple paths in the graph-of-layers. + +Shared layers are often used to encode inputs from similar spaces +(say, two different pieces of text that feature similar vocabulary). +They enable sharing of information across these different inputs, +and they make it possible to train such a model on less data. +If a given word is seen in one of the inputs, +that will benefit the processing of all inputs that pass through the shared layer. + +To share a layer in the functional API, call the same layer instance multiple times. +For instance, here's an `Embedding` layer shared across two different text inputs: +""" + +# Embedding for 1000 unique words mapped to 128-dimensional vectors +shared_embedding = layers.Embedding(1000, 128) + +# Variable-length sequence of integers +text_input_a = keras.Input(shape=(None,), dtype="int32") + +# Variable-length sequence of integers +text_input_b = keras.Input(shape=(None,), dtype="int32") + +# Reuse the same layer to encode both inputs +encoded_input_a = shared_embedding(text_input_a) +encoded_input_b = shared_embedding(text_input_b) + +""" +## Extract and reuse nodes in the graph of layers + +Because the graph of layers you are manipulating is a static data structure, +it can be accessed and inspected. And this is how you are able to plot +functional models as images. + +This also means that you can access the activations of intermediate layers +("nodes" in the graph) and reuse them elsewhere -- +which is very useful for something like feature extraction. + +Let's look at an example. This is a VGG19 model with weights pretrained on ImageNet: +""" + +vgg19 = keras.applications.VGG19() + +""" +And these are the intermediate activations of the model, +obtained by querying the graph data structure: +""" + +features_list = [layer.output for layer in vgg19.layers] + +""" +Use these features to create a new feature-extraction model that returns +the values of the intermediate layer activations: +""" + +feat_extraction_model = keras.Model(inputs=vgg19.input, outputs=features_list) + +img = np.random.random((1, 224, 224, 3)).astype("float32") +extracted_features = feat_extraction_model(img) + +""" +This comes in handy for tasks like +[neural style transfer](https://keras.io/examples/generative/neural_style_transfer/), +among other things. +""" + +""" +## Extend the API using custom layers + +`keras` includes a wide range of built-in layers, for example: + +- Convolutional layers: `Conv1D`, `Conv2D`, `Conv3D`, `Conv2DTranspose` +- Pooling layers: `MaxPooling1D`, `MaxPooling2D`, `MaxPooling3D`, `AveragePooling1D` +- RNN layers: `GRU`, `LSTM`, `ConvLSTM2D` +- `BatchNormalization`, `Dropout`, `Embedding`, etc. + +But if you don't find what you need, it's easy to extend the API by creating +your own layers. All layers subclass the `Layer` class and implement: + +- `call` method, that specifies the computation done by the layer. +- `build` method, that creates the weights of the layer (this is just a style +convention since you can create weights in `__init__`, as well). + +To learn more about creating layers from scratch, read +[custom layers and models](/guides/making_new_layers_and_models_via_subclassing) guide. + +The following is a basic implementation of `keras.layers.Dense`: +""" + + +class CustomDense(layers.Layer): + def __init__(self, units=32): + super().__init__() + self.units = units + + def build(self, input_shape): + self.w = self.add_weight( + shape=(input_shape[-1], self.units), + initializer="random_normal", + trainable=True, + ) + self.b = self.add_weight( + shape=(self.units,), initializer="random_normal", trainable=True + ) + + def call(self, inputs): + return ops.matmul(inputs, self.w) + self.b + + +inputs = keras.Input((4,)) +outputs = CustomDense(10)(inputs) + +model = keras.Model(inputs, outputs) + +""" +For serialization support in your custom layer, define a `get_config()` +method that returns the constructor arguments of the layer instance: +""" + + +class CustomDense(layers.Layer): + def __init__(self, units=32): + super().__init__() + self.units = units + + def build(self, input_shape): + self.w = self.add_weight( + shape=(input_shape[-1], self.units), + initializer="random_normal", + trainable=True, + ) + self.b = self.add_weight( + shape=(self.units,), initializer="random_normal", trainable=True + ) + + def call(self, inputs): + return ops.matmul(inputs, self.w) + self.b + + def get_config(self): + return {"units": self.units} + + +inputs = keras.Input((4,)) +outputs = CustomDense(10)(inputs) + +model = keras.Model(inputs, outputs) +config = model.get_config() + +new_model = keras.Model.from_config( + config, custom_objects={"CustomDense": CustomDense} +) + +""" +Optionally, implement the class method `from_config(cls, config)` which is used +when recreating a layer instance given its config dictionary. +The default implementation of `from_config` is: + +```python +def from_config(cls, config): + return cls(**config) +``` +""" + +""" +## When to use the functional API + +Should you use the Keras functional API to create a new model, +or just subclass the `Model` class directly? In general, the functional API +is higher-level, easier and safer, and has a number of +features that subclassed models do not support. + +However, model subclassing provides greater flexibility when building models +that are not easily expressible as directed acyclic graphs of layers. +For example, you could not implement a Tree-RNN with the functional API +and would have to subclass `Model` directly. + +For an in-depth look at the differences between the functional API and +model subclassing, read +[What are Symbolic and Imperative APIs in TensorFlow 2.0?](https://blog.tensorflow.org/2019/01/what-are-symbolic-and-imperative-apis.html). + +### Functional API strengths: + +The following properties are also true for Sequential models +(which are also data structures), but are not true for subclassed models +(which are Python bytecode, not data structures). + +#### Less verbose + +There is no `super().__init__(...)`, no `def call(self, ...):`, etc. + +Compare: + +```python +inputs = keras.Input(shape=(32,)) +x = layers.Dense(64, activation='relu')(inputs) +outputs = layers.Dense(10)(x) +mlp = keras.Model(inputs, outputs) +``` + +With the subclassed version: + +```python +class MLP(keras.Model): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.dense_1 = layers.Dense(64, activation='relu') + self.dense_2 = layers.Dense(10) + + def call(self, inputs): + x = self.dense_1(inputs) + return self.dense_2(x) + +# Instantiate the model. +mlp = MLP() +# Necessary to create the model's state. +# The model doesn't have a state until it's called at least once. +_ = mlp(ops.zeros((1, 32))) +``` + +#### Model validation while defining its connectivity graph + +In the functional API, the input specification (shape and dtype) is created +in advance (using `Input`). Every time you call a layer, +the layer checks that the specification passed to it matches its assumptions, +and it will raise a helpful error message if not. + +This guarantees that any model you can build with the functional API will run. +All debugging -- other than convergence-related debugging -- +happens statically during the model construction and not at execution time. +This is similar to type checking in a compiler. + +#### A functional model is plottable and inspectable + +You can plot the model as a graph, and you can easily access intermediate nodes +in this graph. For example, to extract and reuse the activations of intermediate +layers (as seen in a previous example): + +```python +features_list = [layer.output for layer in vgg19.layers] +feat_extraction_model = keras.Model(inputs=vgg19.input, outputs=features_list) +``` + +#### A functional model can be serialized or cloned + +Because a functional model is a data structure rather than a piece of code, +it is safely serializable and can be saved as a single file +that allows you to recreate the exact same model +without having access to any of the original code. +See the [serialization & saving guide](/guides/serialization_and_saving/). + +To serialize a subclassed model, it is necessary for the implementer +to specify a `get_config()` +and `from_config()` method at the model level. + + +### Functional API weakness: + +#### It does not support dynamic architectures + +The functional API treats models as DAGs of layers. +This is true for most deep learning architectures, but not all -- for example, +recursive networks or Tree RNNs do not follow this assumption and cannot +be implemented in the functional API. +""" + +""" +## Mix-and-match API styles + +Choosing between the functional API or Model subclassing isn't a +binary decision that restricts you into one category of models. +All models in the `keras` API can interact with each other, whether they're +`Sequential` models, functional models, or subclassed models that are written +from scratch. + +You can always use a functional model or `Sequential` model +as part of a subclassed model or layer: +""" + +units = 32 +timesteps = 10 +input_dim = 5 + +# Define a Functional model +inputs = keras.Input((None, units)) +x = layers.GlobalAveragePooling1D()(inputs) +outputs = layers.Dense(1)(x) +model = keras.Model(inputs, outputs) + + +class CustomRNN(layers.Layer): + def __init__(self): + super().__init__() + self.units = units + self.projection_1 = layers.Dense(units=units, activation="tanh") + self.projection_2 = layers.Dense(units=units, activation="tanh") + # Our previously-defined Functional model + self.classifier = model + + def call(self, inputs): + outputs = [] + state = ops.zeros(shape=(inputs.shape[0], self.units)) + for t in range(inputs.shape[1]): + x = inputs[:, t, :] + h = self.projection_1(x) + y = h + self.projection_2(state) + state = y + outputs.append(y) + features = ops.stack(outputs, axis=1) + print(features.shape) + return self.classifier(features) + + +rnn_model = CustomRNN() +_ = rnn_model(ops.zeros((1, timesteps, input_dim))) + +""" +You can use any subclassed layer or model in the functional API +as long as it implements a `call` method that follows one of the following patterns: + +- `call(self, inputs, **kwargs)` -- +Where `inputs` is a tensor or a nested structure of tensors (e.g. a list of tensors), +and where `**kwargs` are non-tensor arguments (non-inputs). +- `call(self, inputs, training=None, **kwargs)` -- +Where `training` is a boolean indicating whether the layer should behave +in training mode and inference mode. +- `call(self, inputs, mask=None, **kwargs)` -- +Where `mask` is a boolean mask tensor (useful for RNNs, for instance). +- `call(self, inputs, training=None, mask=None, **kwargs)` -- +Of course, you can have both masking and training-specific behavior at the same time. + +Additionally, if you implement the `get_config` method on your custom Layer or model, +the functional models you create will still be serializable and cloneable. + +Here's a quick example of a custom RNN, written from scratch, +being used in a functional model: +""" + +units = 32 +timesteps = 10 +input_dim = 5 +batch_size = 16 + + +class CustomRNN(layers.Layer): + def __init__(self): + super().__init__() + self.units = units + self.projection_1 = layers.Dense(units=units, activation="tanh") + self.projection_2 = layers.Dense(units=units, activation="tanh") + self.classifier = layers.Dense(1) + + def call(self, inputs): + outputs = [] + state = ops.zeros(shape=(inputs.shape[0], self.units)) + for t in range(inputs.shape[1]): + x = inputs[:, t, :] + h = self.projection_1(x) + y = h + self.projection_2(state) + state = y + outputs.append(y) + features = ops.stack(outputs, axis=1) + return self.classifier(features) + + +# Note that you specify a static batch size for the inputs with the `batch_shape` +# arg, because the inner computation of `CustomRNN` requires a static batch size +# (when you create the `state` zeros tensor). +inputs = keras.Input(batch_shape=(batch_size, timesteps, input_dim)) +x = layers.Conv1D(32, 3)(inputs) +outputs = CustomRNN()(x) + +model = keras.Model(inputs, outputs) + +rnn_model = CustomRNN() +_ = rnn_model(ops.zeros((1, 10, 5))) diff --git a/guides/making_new_layers_and_models_via_subclassing.py b/guides/making_new_layers_and_models_via_subclassing.py new file mode 100644 index 000000000000..76766763320a --- /dev/null +++ b/guides/making_new_layers_and_models_via_subclassing.py @@ -0,0 +1,679 @@ +""" +Title: Making new layers and models via subclassing +Author: [fchollet](https://twitter.com/fchollet) +Date created: 2019/03/01 +Last modified: 2023/06/25 +Description: Complete guide to writing `Layer` and `Model` objects from scratch. +Accelerator: None +""" + +""" +## Introduction + +This guide will cover everything you need to know to build your own +subclassed layers and models. In particular, you'll learn about the following features: + +- The `Layer` class +- The `add_weight()` method +- Trainable and non-trainable weights +- The `build()` method +- Making sure your layers can be used with any backend +- The `add_loss()` method +- The `training` argument in `call()` +- The `mask` argument in `call()` +- Making sure your layers can be serialized + +Let's dive in. +""" +""" +## Setup +""" + +import numpy as np +import keras +from keras import ops +from keras import layers + +""" +## The `Layer` class: the combination of state (weights) and some computation + +One of the central abstractions in Keras is the `Layer` class. A layer +encapsulates both a state (the layer's "weights") and a transformation from +inputs to outputs (a "call", the layer's forward pass). + +Here's a densely-connected layer. It has two state variables: +the variables `w` and `b`. +""" + + +class Linear(keras.layers.Layer): + def __init__(self, units=32, input_dim=32): + super().__init__() + self.w = self.add_weight( + shape=(input_dim, units), + initializer="random_normal", + trainable=True, + ) + self.b = self.add_weight( + shape=(units,), initializer="zeros", trainable=True + ) + + def call(self, inputs): + return ops.matmul(inputs, self.w) + self.b + + +""" +You would use a layer by calling it on some tensor input(s), much like a Python +function. +""" + +x = ops.ones((2, 2)) +linear_layer = Linear(4, 2) +y = linear_layer(x) +print(y) + +""" +Note that the weights `w` and `b` are automatically tracked by the layer upon +being set as layer attributes: +""" + +assert linear_layer.weights == [linear_layer.w, linear_layer.b] + +""" +## Layers can have non-trainable weights + +Besides trainable weights, you can add non-trainable weights to a layer as +well. Such weights are meant not to be taken into account during +backpropagation, when you are training the layer. + +Here's how to add and use a non-trainable weight: +""" + + +class ComputeSum(keras.layers.Layer): + def __init__(self, input_dim): + super().__init__() + self.total = self.add_weight( + initializer="zeros", shape=(input_dim,), trainable=False + ) + + def call(self, inputs): + self.total.assign_add(ops.sum(inputs, axis=0)) + return self.total + + +x = ops.ones((2, 2)) +my_sum = ComputeSum(2) +y = my_sum(x) +print(y.numpy()) +y = my_sum(x) +print(y.numpy()) + +""" +It's part of `layer.weights`, but it gets categorized as a non-trainable weight: +""" + +print("weights:", len(my_sum.weights)) +print("non-trainable weights:", len(my_sum.non_trainable_weights)) + +# It's not included in the trainable weights: +print("trainable_weights:", my_sum.trainable_weights) + +""" +## Best practice: deferring weight creation until the shape of the inputs is known + +Our `Linear` layer above took an `input_dim` argument that was used to compute +the shape of the weights `w` and `b` in `__init__()`: +""" + + +class Linear(keras.layers.Layer): + def __init__(self, units=32, input_dim=32): + super().__init__() + self.w = self.add_weight( + shape=(input_dim, units), + initializer="random_normal", + trainable=True, + ) + self.b = self.add_weight( + shape=(units,), initializer="zeros", trainable=True + ) + + def call(self, inputs): + return ops.matmul(inputs, self.w) + self.b + + +""" +In many cases, you may not know in advance the size of your inputs, and you +would like to lazily create weights when that value becomes known, some time +after instantiating the layer. + +In the Keras API, we recommend creating layer weights in the +`build(self, inputs_shape)` method of your layer. Like this: +""" + + +class Linear(keras.layers.Layer): + def __init__(self, units=32): + super().__init__() + self.units = units + + def build(self, input_shape): + self.w = self.add_weight( + shape=(input_shape[-1], self.units), + initializer="random_normal", + trainable=True, + ) + self.b = self.add_weight( + shape=(self.units,), initializer="random_normal", trainable=True + ) + + def call(self, inputs): + return ops.matmul(inputs, self.w) + self.b + + +""" +The `__call__()` method of your layer will automatically run build the first time +it is called. You now have a layer that's lazy and thus easier to use: +""" + +# At instantiation, we don't know on what inputs this is going to get called +linear_layer = Linear(32) + +# The layer's weights are created dynamically the first time the layer is called +y = linear_layer(x) + +""" +Implementing `build()` separately as shown above nicely separates creating weights +only once from using weights in every call. +""" + +""" +## Layers are recursively composable + +If you assign a Layer instance as an attribute of another Layer, the outer layer +will start tracking the weights created by the inner layer. + +We recommend creating such sublayers in the `__init__()` method and leave it to +the first `__call__()` to trigger building their weights. +""" + + +class MLPBlock(keras.layers.Layer): + def __init__(self): + super().__init__() + self.linear_1 = Linear(32) + self.linear_2 = Linear(32) + self.linear_3 = Linear(1) + + def call(self, inputs): + x = self.linear_1(inputs) + x = keras.activations.relu(x) + x = self.linear_2(x) + x = keras.activations.relu(x) + return self.linear_3(x) + + +mlp = MLPBlock() +y = mlp( + ops.ones(shape=(3, 64)) +) # The first call to the `mlp` will create the weights +print("weights:", len(mlp.weights)) +print("trainable weights:", len(mlp.trainable_weights)) + +""" +## Backend-agnostic layers and backend-specific layers + +As long as a layer only uses APIs from the `keras.ops` namespace +(or other Keras namespaces such as `keras.activations`, `keras.random`, or `keras.layers`), +then it can be used with any backend -- TensorFlow, JAX, or PyTorch. + +All layers you've seen so far in this guide work with all Keras backends. + +The `keras.ops` namespace gives you access to: + +- The NumPy API, e.g. `ops.matmul`, `ops.sum`, `ops.reshape`, `ops.stack`, etc. +- Neural networks-specific APIs such as `ops.softmax`, `ops.conv`, `ops.binary_crossentropy`, `ops.relu`, etc. + +You can also use backend-native APIs in your layers (such as `tf.nn` functions), +but if you do this, then your layer will only be usable with the backend in question. +For instance, you could write the following JAX-specific layer using `jax.numpy`: + +```python +import jax + +class Linear(keras.layers.Layer): + ... + + def call(self, inputs): + return jax.numpy.matmul(inputs, self.w) + self.b +``` + +This would be the equivalent TensorFlow-specific layer: + +```python +import tensorflow as tf + +class Linear(keras.layers.Layer): + ... + + def call(self, inputs): + return tf.matmul(inputs, self.w) + self.b +``` + +And this would be the equivalent PyTorch-specific layer: + +```python +import torch + +class Linear(keras.layers.Layer): + ... + + def call(self, inputs): + return torch.matmul(inputs, self.w) + self.b +``` + +Because cross-backend compatibility is a tremendously useful property, we strongly +recommend that you seek to always make your layers backend-agnostic by leveraging +only Keras APIs. +""" + +""" +## The `add_loss()` method + +When writing the `call()` method of a layer, you can create loss tensors that +you will want to use later, when writing your training loop. This is doable by +calling `self.add_loss(value)`: +""" + + +# A layer that creates an activity regularization loss +class ActivityRegularizationLayer(keras.layers.Layer): + def __init__(self, rate=1e-2): + super().__init__() + self.rate = rate + + def call(self, inputs): + self.add_loss(self.rate * ops.mean(inputs)) + return inputs + + +""" +These losses (including those created by any inner layer) can be retrieved via +`layer.losses`. This property is reset at the start of every `__call__()` to +the top-level layer, so that `layer.losses` always contains the loss values +created during the last forward pass. +""" + + +class OuterLayer(keras.layers.Layer): + def __init__(self): + super().__init__() + self.activity_reg = ActivityRegularizationLayer(1e-2) + + def call(self, inputs): + return self.activity_reg(inputs) + + +layer = OuterLayer() +assert ( + len(layer.losses) == 0 +) # No losses yet since the layer has never been called + +_ = layer(ops.zeros((1, 1))) +assert len(layer.losses) == 1 # We created one loss value + +# `layer.losses` gets reset at the start of each __call__ +_ = layer(ops.zeros((1, 1))) +assert len(layer.losses) == 1 # This is the loss created during the call above + +""" +In addition, the `loss` property also contains regularization losses created +for the weights of any inner layer: +""" + + +class OuterLayerWithKernelRegularizer(keras.layers.Layer): + def __init__(self): + super().__init__() + self.dense = keras.layers.Dense( + 32, kernel_regularizer=keras.regularizers.l2(1e-3) + ) + + def call(self, inputs): + return self.dense(inputs) + + +layer = OuterLayerWithKernelRegularizer() +_ = layer(ops.zeros((1, 1))) + +# This is `1e-3 * sum(layer.dense.kernel ** 2)`, +# created by the `kernel_regularizer` above. +print(layer.losses) + +""" +These losses are meant to be taken into account when writing custom training loops. + +They also work seamlessly with `fit()` (they get automatically summed and added to the main loss, if any): +""" + +inputs = keras.Input(shape=(3,)) +outputs = ActivityRegularizationLayer()(inputs) +model = keras.Model(inputs, outputs) + +# If there is a loss passed in `compile`, the regularization +# losses get added to it +model.compile(optimizer="adam", loss="mse") +model.fit(np.random.random((2, 3)), np.random.random((2, 3))) + +# It's also possible not to pass any loss in `compile`, +# since the model already has a loss to minimize, via the `add_loss` +# call during the forward pass! +model.compile(optimizer="adam") +model.fit(np.random.random((2, 3)), np.random.random((2, 3))) + +""" +## You can optionally enable serialization on your layers + +If you need your custom layers to be serializable as part of a +[Functional model](/guides/functional_api/), you can optionally implement a `get_config()` +method: +""" + + +class Linear(keras.layers.Layer): + def __init__(self, units=32): + super().__init__() + self.units = units + + def build(self, input_shape): + self.w = self.add_weight( + shape=(input_shape[-1], self.units), + initializer="random_normal", + trainable=True, + ) + self.b = self.add_weight( + shape=(self.units,), initializer="random_normal", trainable=True + ) + + def call(self, inputs): + return ops.matmul(inputs, self.w) + self.b + + def get_config(self): + return {"units": self.units} + + +# Now you can recreate the layer from its config: +layer = Linear(64) +config = layer.get_config() +print(config) +new_layer = Linear.from_config(config) + +""" +Note that the `__init__()` method of the base `Layer` class takes some keyword +arguments, in particular a `name` and a `dtype`. It's good practice to pass +these arguments to the parent class in `__init__()` and to include them in the +layer config: +""" + + +class Linear(keras.layers.Layer): + def __init__(self, units=32, **kwargs): + super().__init__(**kwargs) + self.units = units + + def build(self, input_shape): + self.w = self.add_weight( + shape=(input_shape[-1], self.units), + initializer="random_normal", + trainable=True, + ) + self.b = self.add_weight( + shape=(self.units,), initializer="random_normal", trainable=True + ) + + def call(self, inputs): + return ops.matmul(inputs, self.w) + self.b + + def get_config(self): + config = super().get_config() + config.update({"units": self.units}) + return config + + +layer = Linear(64) +config = layer.get_config() +print(config) +new_layer = Linear.from_config(config) + +""" +If you need more flexibility when deserializing the layer from its config, you +can also override the `from_config()` class method. This is the base +implementation of `from_config()`: + +```python +def from_config(cls, config): + return cls(**config) +``` + +To learn more about serialization and saving, see the complete +[guide to saving and serializing models](/guides/serialization_and_saving/). +""" + +""" +## Privileged `training` argument in the `call()` method + +Some layers, in particular the `BatchNormalization` layer and the `Dropout` +layer, have different behaviors during training and inference. For such +layers, it is standard practice to expose a `training` (boolean) argument in +the `call()` method. + +By exposing this argument in `call()`, you enable the built-in training and +evaluation loops (e.g. `fit()`) to correctly use the layer in training and +inference. +""" + + +class CustomDropout(keras.layers.Layer): + def __init__(self, rate, **kwargs): + super().__init__(**kwargs) + self.rate = rate + + def call(self, inputs, training=None): + if training: + return keras.random.dropout(inputs, rate=self.rate) + return inputs + + +""" +## Privileged `mask` argument in the `call()` method + +The other privileged argument supported by `call()` is the `mask` argument. + +You will find it in all Keras RNN layers. A mask is a boolean tensor (one +boolean value per timestep in the input) used to skip certain input timesteps +when processing timeseries data. + +Keras will automatically pass the correct `mask` argument to `__call__()` for +layers that support it, when a mask is generated by a prior layer. +Mask-generating layers are the `Embedding` +layer configured with `mask_zero=True`, and the `Masking` layer. +""" + +""" +## The `Model` class + +In general, you will use the `Layer` class to define inner computation blocks, +and will use the `Model` class to define the outer model -- the object you +will train. + +For instance, in a ResNet50 model, you would have several ResNet blocks +subclassing `Layer`, and a single `Model` encompassing the entire ResNet50 +network. + +The `Model` class has the same API as `Layer`, with the following differences: + +- It exposes built-in training, evaluation, and prediction loops +(`model.fit()`, `model.evaluate()`, `model.predict()`). +- It exposes the list of its inner layers, via the `model.layers` property. +- It exposes saving and serialization APIs (`save()`, `save_weights()`...) + +Effectively, the `Layer` class corresponds to what we refer to in the +literature as a "layer" (as in "convolution layer" or "recurrent layer") or as +a "block" (as in "ResNet block" or "Inception block"). + +Meanwhile, the `Model` class corresponds to what is referred to in the +literature as a "model" (as in "deep learning model") or as a "network" (as in +"deep neural network"). + +So if you're wondering, "should I use the `Layer` class or the `Model` class?", +ask yourself: will I need to call `fit()` on it? Will I need to call `save()` +on it? If so, go with `Model`. If not (either because your class is just a block +in a bigger system, or because you are writing training & saving code yourself), +use `Layer`. + +For instance, we could take our mini-resnet example above, and use it to build +a `Model` that we could train with `fit()`, and that we could save with +`save_weights()`: +""" + +""" +```python +class ResNet(keras.Model): + + def __init__(self, num_classes=1000): + super().__init__() + self.block_1 = ResNetBlock() + self.block_2 = ResNetBlock() + self.global_pool = layers.GlobalAveragePooling2D() + self.classifier = Dense(num_classes) + + def call(self, inputs): + x = self.block_1(inputs) + x = self.block_2(x) + x = self.global_pool(x) + return self.classifier(x) + + +resnet = ResNet() +dataset = ... +resnet.fit(dataset, epochs=10) +resnet.save(filepath.keras) +``` +""" + +""" +## Putting it all together: an end-to-end example + +Here's what you've learned so far: + +- A `Layer` encapsulate a state (created in `__init__()` or `build()`) and some +computation (defined in `call()`). +- Layers can be recursively nested to create new, bigger computation blocks. +- Layers are backend-agnostic as long as they only use Keras APIs. You can use +backend-native APIs (such as `jax.numpy`, `torch.nn` or `tf.nn`), but then +your layer will only be usable with that specific backend. +- Layers can create and track losses (typically regularization losses) +via `add_loss()`. +- The outer container, the thing you want to train, is a `Model`. A `Model` is +just like a `Layer`, but with added training and serialization utilities. + +Let's put all of these things together into an end-to-end example: we're going +to implement a Variational AutoEncoder (VAE) in a backend-agnostic fashion +-- so that it runs the same with TensorFlow, JAX, and PyTorch. +We'll train it on MNIST digits. + +Our VAE will be a subclass of `Model`, built as a nested composition of layers +that subclass `Layer`. It will feature a regularization loss (KL divergence). +""" + + +class Sampling(layers.Layer): + """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit.""" + + def call(self, inputs): + z_mean, z_log_var = inputs + batch = ops.shape(z_mean)[0] + dim = ops.shape(z_mean)[1] + epsilon = keras.random.normal(shape=(batch, dim)) + return z_mean + ops.exp(0.5 * z_log_var) * epsilon + + +class Encoder(layers.Layer): + """Maps MNIST digits to a triplet (z_mean, z_log_var, z).""" + + def __init__( + self, latent_dim=32, intermediate_dim=64, name="encoder", **kwargs + ): + super().__init__(name=name, **kwargs) + self.dense_proj = layers.Dense(intermediate_dim, activation="relu") + self.dense_mean = layers.Dense(latent_dim) + self.dense_log_var = layers.Dense(latent_dim) + self.sampling = Sampling() + + def call(self, inputs): + x = self.dense_proj(inputs) + z_mean = self.dense_mean(x) + z_log_var = self.dense_log_var(x) + z = self.sampling((z_mean, z_log_var)) + return z_mean, z_log_var, z + + +class Decoder(layers.Layer): + """Converts z, the encoded digit vector, back into a readable digit.""" + + def __init__( + self, original_dim, intermediate_dim=64, name="decoder", **kwargs + ): + super().__init__(name=name, **kwargs) + self.dense_proj = layers.Dense(intermediate_dim, activation="relu") + self.dense_output = layers.Dense(original_dim, activation="sigmoid") + + def call(self, inputs): + x = self.dense_proj(inputs) + return self.dense_output(x) + + +class VariationalAutoEncoder(keras.Model): + """Combines the encoder and decoder into an end-to-end model for training.""" + + def __init__( + self, + original_dim, + intermediate_dim=64, + latent_dim=32, + name="autoencoder", + **kwargs, + ): + super().__init__(name=name, **kwargs) + self.original_dim = original_dim + self.encoder = Encoder( + latent_dim=latent_dim, intermediate_dim=intermediate_dim + ) + self.decoder = Decoder(original_dim, intermediate_dim=intermediate_dim) + + def call(self, inputs): + z_mean, z_log_var, z = self.encoder(inputs) + reconstructed = self.decoder(z) + # Add KL divergence regularization loss. + kl_loss = -0.5 * ops.mean( + z_log_var - ops.square(z_mean) - ops.exp(z_log_var) + 1 + ) + self.add_loss(kl_loss) + return reconstructed + + +""" +Let's train it on MNIST using the `fit()` API: +""" + +(x_train, _), _ = keras.datasets.mnist.load_data() +x_train = x_train.reshape(60000, 784).astype("float32") / 255 + +original_dim = 784 +vae = VariationalAutoEncoder(784, 64, 32) + +optimizer = keras.optimizers.Adam(learning_rate=1e-3) +vae.compile(optimizer, loss=keras.losses.MeanSquaredError()) + +vae.fit(x_train, x_train, epochs=2, batch_size=64) diff --git a/guides/sequential_model.py b/guides/sequential_model.py new file mode 100644 index 000000000000..9b481f2d72d0 --- /dev/null +++ b/guides/sequential_model.py @@ -0,0 +1,371 @@ +""" +Title: The Sequential model +Author: [fchollet](https://twitter.com/fchollet) +Date created: 2020/04/12 +Last modified: 2023/06/25 +Description: Complete guide to the Sequential model. +Accelerator: GPU +""" + +""" +## Setup + +""" + +import keras +from keras import layers +from keras import ops + +""" +## When to use a Sequential model + +A `Sequential` model is appropriate for **a plain stack of layers** +where each layer has **exactly one input tensor and one output tensor**. + +Schematically, the following `Sequential` model: +""" + +# Define Sequential model with 3 layers +model = keras.Sequential( + [ + layers.Dense(2, activation="relu", name="layer1"), + layers.Dense(3, activation="relu", name="layer2"), + layers.Dense(4, name="layer3"), + ] +) +# Call model on a test input +x = ops.ones((3, 3)) +y = model(x) + +""" +is equivalent to this function: +""" + +# Create 3 layers +layer1 = layers.Dense(2, activation="relu", name="layer1") +layer2 = layers.Dense(3, activation="relu", name="layer2") +layer3 = layers.Dense(4, name="layer3") + +# Call layers on a test input +x = ops.ones((3, 3)) +y = layer3(layer2(layer1(x))) + +""" +A Sequential model is **not appropriate** when: + +- Your model has multiple inputs or multiple outputs +- Any of your layers has multiple inputs or multiple outputs +- You need to do layer sharing +- You want non-linear topology (e.g. a residual connection, a multi-branch +model) +""" + +""" +## Creating a Sequential model + +You can create a Sequential model by passing a list of layers to the Sequential +constructor: +""" + +model = keras.Sequential( + [ + layers.Dense(2, activation="relu"), + layers.Dense(3, activation="relu"), + layers.Dense(4), + ] +) + +""" +Its layers are accessible via the `layers` attribute: +""" + +model.layers + +""" +You can also create a Sequential model incrementally via the `add()` method: +""" + +model = keras.Sequential() +model.add(layers.Dense(2, activation="relu")) +model.add(layers.Dense(3, activation="relu")) +model.add(layers.Dense(4)) + +""" +Note that there's also a corresponding `pop()` method to remove layers: +a Sequential model behaves very much like a list of layers. +""" + +model.pop() +print(len(model.layers)) # 2 + +""" +Also note that the Sequential constructor accepts a `name` argument, just like +any layer or model in Keras. This is useful to annotate TensorBoard graphs +with semantically meaningful names. +""" + +model = keras.Sequential(name="my_sequential") +model.add(layers.Dense(2, activation="relu", name="layer1")) +model.add(layers.Dense(3, activation="relu", name="layer2")) +model.add(layers.Dense(4, name="layer3")) + +""" +## Specifying the input shape in advance + +Generally, all layers in Keras need to know the shape of their inputs +in order to be able to create their weights. So when you create a layer like +this, initially, it has no weights: +""" + +layer = layers.Dense(3) +layer.weights # Empty + +""" +It creates its weights the first time it is called on an input, since the shape +of the weights depends on the shape of the inputs: +""" + +# Call layer on a test input +x = ops.ones((1, 4)) +y = layer(x) +layer.weights # Now it has weights, of shape (4, 3) and (3,) + +""" +Naturally, this also applies to Sequential models. When you instantiate a +Sequential model without an input shape, it isn't "built": it has no weights +(and calling +`model.weights` results in an error stating just this). The weights are created +when the model first sees some input data: +""" + +model = keras.Sequential( + [ + layers.Dense(2, activation="relu"), + layers.Dense(3, activation="relu"), + layers.Dense(4), + ] +) # No weights at this stage! + +# At this point, you can't do this: +# model.weights + +# You also can't do this: +# model.summary() + +# Call the model on a test input +x = ops.ones((1, 4)) +y = model(x) +print("Number of weights after calling the model:", len(model.weights)) # 6 + +""" +Once a model is "built", you can call its `summary()` method to display its +contents: +""" + +model.summary() + +""" +However, it can be very useful when building a Sequential model incrementally +to be able to display the summary of the model so far, including the current +output shape. In this case, you should start your model by passing an `Input` +object to your model, so that it knows its input shape from the start: +""" + +model = keras.Sequential() +model.add(keras.Input(shape=(4,))) +model.add(layers.Dense(2, activation="relu")) + +model.summary() + +""" +Note that the `Input` object is not displayed as part of `model.layers`, since +it isn't a layer: +""" + +model.layers + +""" +Models built with a predefined input shape like this always have weights (even +before seeing any data) and always have a defined output shape. + +In general, it's a recommended best practice to always specify the input shape +of a Sequential model in advance if you know what it is. +""" + +""" +## A common debugging workflow: `add()` + `summary()` + +When building a new Sequential architecture, it's useful to incrementally stack +layers with `add()` and frequently print model summaries. For instance, this +enables you to monitor how a stack of `Conv2D` and `MaxPooling2D` layers is +downsampling image feature maps: +""" + +model = keras.Sequential() +model.add(keras.Input(shape=(250, 250, 3))) # 250x250 RGB images +model.add(layers.Conv2D(32, 5, strides=2, activation="relu")) +model.add(layers.Conv2D(32, 3, activation="relu")) +model.add(layers.MaxPooling2D(3)) + +# Can you guess what the current output shape is at this point? Probably not. +# Let's just print it: +model.summary() + +# The answer was: (40, 40, 32), so we can keep downsampling... + +model.add(layers.Conv2D(32, 3, activation="relu")) +model.add(layers.Conv2D(32, 3, activation="relu")) +model.add(layers.MaxPooling2D(3)) +model.add(layers.Conv2D(32, 3, activation="relu")) +model.add(layers.Conv2D(32, 3, activation="relu")) +model.add(layers.MaxPooling2D(2)) + +# And now? +model.summary() + +# Now that we have 4x4 feature maps, time to apply global max pooling. +model.add(layers.GlobalMaxPooling2D()) + +# Finally, we add a classification layer. +model.add(layers.Dense(10)) + +""" +Very practical, right? + + +""" + +""" +## What to do once you have a model + +Once your model architecture is ready, you will want to: + +- Train your model, evaluate it, and run inference. See our +[guide to training & evaluation with the built-in loops]( + /guides/training_with_built_in_methods/) +- Save your model to disk and restore it. See our +[guide to serialization & saving](/guides/serialization_and_saving/). +- Speed up model training by leveraging multiple GPUs. See our +[guide to multi-GPU and distributed training](https://keras.io/guides/distributed_training/). +""" + +""" +## Feature extraction with a Sequential model + +Once a Sequential model has been built, it behaves like a [Functional API +model](/guides/functional_api/). This means that every layer has an `input` +and `output` attribute. These attributes can be used to do neat things, like +quickly +creating a model that extracts the outputs of all intermediate layers in a +Sequential model: +""" + +initial_model = keras.Sequential( + [ + keras.Input(shape=(250, 250, 3)), + layers.Conv2D(32, 5, strides=2, activation="relu"), + layers.Conv2D(32, 3, activation="relu"), + layers.Conv2D(32, 3, activation="relu"), + ] +) +feature_extractor = keras.Model( + inputs=initial_model.inputs, + outputs=[layer.output for layer in initial_model.layers], +) + +# Call feature extractor on test input. +x = ops.ones((1, 250, 250, 3)) +features = feature_extractor(x) + +""" +Here's a similar example that only extract features from one layer: +""" + +initial_model = keras.Sequential( + [ + keras.Input(shape=(250, 250, 3)), + layers.Conv2D(32, 5, strides=2, activation="relu"), + layers.Conv2D(32, 3, activation="relu", name="my_intermediate_layer"), + layers.Conv2D(32, 3, activation="relu"), + ] +) +feature_extractor = keras.Model( + inputs=initial_model.inputs, + outputs=initial_model.get_layer(name="my_intermediate_layer").output, +) +# Call feature extractor on test input. +x = ops.ones((1, 250, 250, 3)) +features = feature_extractor(x) + +""" +## Transfer learning with a Sequential model + +Transfer learning consists of freezing the bottom layers in a model and only training +the top layers. If you aren't familiar with it, make sure to read our [guide +to transfer learning](/guides/transfer_learning/). + +Here are two common transfer learning blueprint involving Sequential models. + +First, let's say that you have a Sequential model, and you want to freeze all +layers except the last one. In this case, you would simply iterate over +`model.layers` and set `layer.trainable = False` on each layer, except the +last one. Like this: + +```python +model = keras.Sequential([ + keras.Input(shape=(784)), + layers.Dense(32, activation='relu'), + layers.Dense(32, activation='relu'), + layers.Dense(32, activation='relu'), + layers.Dense(10), +]) + +# Presumably you would want to first load pre-trained weights. +model.load_weights(...) + +# Freeze all layers except the last one. +for layer in model.layers[:-1]: + layer.trainable = False + +# Recompile and train (this will only update the weights of the last layer). +model.compile(...) +model.fit(...) +``` + +Another common blueprint is to use a Sequential model to stack a pre-trained +model and some freshly initialized classification layers. Like this: + +```python +# Load a convolutional base with pre-trained weights +base_model = keras.applications.Xception( + weights='imagenet', + include_top=False, + pooling='avg') + +# Freeze the base model +base_model.trainable = False + +# Use a Sequential model to add a trainable classifier on top +model = keras.Sequential([ + base_model, + layers.Dense(1000), +]) + +# Compile & train +model.compile(...) +model.fit(...) +``` + +If you do transfer learning, you will probably find yourself frequently using +these two patterns. +""" + +""" +That's about all you need to know about Sequential models! + +To find out more about building models in Keras, see: + +- [Guide to the Functional API](/guides/functional_api/) +- [Guide to making new Layers & Models via subclassing]( + /guides/making_new_layers_and_models_via_subclassing/) +""" diff --git a/guides/training_with_built_in_methods.py b/guides/training_with_built_in_methods.py new file mode 100644 index 000000000000..49a9dad1d8a9 --- /dev/null +++ b/guides/training_with_built_in_methods.py @@ -0,0 +1,1243 @@ +""" +Title: Training & evaluation with the built-in methods +Author: [fchollet](https://twitter.com/fchollet) +Date created: 2019/03/01 +Last modified: 2023/03/25 +Description: Complete guide to training & evaluation with `fit()` and `evaluate()`. +Accelerator: GPU +""" + +""" +## Setup +""" + +# We import torch & TF so as to use torch Dataloaders & tf.data.Datasets. +import torch +import tensorflow as tf + +import os +import numpy as np +import keras +from keras import layers +from keras import ops + +""" +## Introduction + +This guide covers training, evaluation, and prediction (inference) models +when using built-in APIs for training & validation (such as `Model.fit()`, +`Model.evaluate()` and `Model.predict()`). + +If you are interested in leveraging `fit()` while specifying your +own training step function, see the +[Customizing what happens in `fit()` guide](/guides/customizing_what_happens_in_fit/). + +If you are interested in writing your own training & evaluation loops from +scratch, see the guide +["writing a training loop from scratch"](/guides/writing_a_training_loop_from_scratch/). + +In general, whether you are using built-in loops or writing your own, model training & +evaluation works strictly in the same way across every kind of Keras model -- +Sequential models, models built with the Functional API, and models written from +scratch via model subclassing. + +This guide doesn't cover distributed training, which is covered in our +[guide to multi-GPU & distributed training](https://keras.io/guides/distributed_training/). +""" + +""" +## API overview: a first end-to-end example + +When passing data to the built-in training loops of a model, you should either use: + +- NumPy arrays (if your data is small and fits in memory) +- Subclasses of `keras.utils.PyDataset` +- `tf.data.Dataset` objects +- PyTorch `DataLoader` instances + +In the next few paragraphs, we'll use the MNIST dataset as NumPy arrays, in +order to demonstrate how to use optimizers, losses, and metrics. Afterwards, we'll +take a close look at each of the other options. + +Let's consider the following model (here, we build in with the Functional API, but it +could be a Sequential model or a subclassed model as well): +""" + +inputs = keras.Input(shape=(784,), name="digits") +x = layers.Dense(64, activation="relu", name="dense_1")(inputs) +x = layers.Dense(64, activation="relu", name="dense_2")(x) +outputs = layers.Dense(10, activation="softmax", name="predictions")(x) + +model = keras.Model(inputs=inputs, outputs=outputs) + +""" +Here's what the typical end-to-end workflow looks like, consisting of: + +- Training +- Validation on a holdout set generated from the original training data +- Evaluation on the test data + +We'll use MNIST data for this example. +""" + +(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data() + +# Preprocess the data (these are NumPy arrays) +x_train = x_train.reshape(60000, 784).astype("float32") / 255 +x_test = x_test.reshape(10000, 784).astype("float32") / 255 + +y_train = y_train.astype("float32") +y_test = y_test.astype("float32") + +# Reserve 10,000 samples for validation +x_val = x_train[-10000:] +y_val = y_train[-10000:] +x_train = x_train[:-10000] +y_train = y_train[:-10000] + +""" +We specify the training configuration (optimizer, loss, metrics): +""" + +model.compile( + optimizer=keras.optimizers.RMSprop(), # Optimizer + # Loss function to minimize + loss=keras.losses.SparseCategoricalCrossentropy(), + # List of metrics to monitor + metrics=[keras.metrics.SparseCategoricalAccuracy()], +) + +""" +We call `fit()`, which will train the model by slicing the data into "batches" of size +`batch_size`, and repeatedly iterating over the entire dataset for a given number of +`epochs`. +""" + +print("Fit model on training data") +history = model.fit( + x_train, + y_train, + batch_size=64, + epochs=2, + # We pass some validation for + # monitoring validation loss and metrics + # at the end of each epoch + validation_data=(x_val, y_val), +) + +""" +The returned `history` object holds a record of the loss values and metric values +during training: +""" + +history.history + +""" +We evaluate the model on the test data via `evaluate()`: +""" + +# Evaluate the model on the test data using `evaluate` +print("Evaluate on test data") +results = model.evaluate(x_test, y_test, batch_size=128) +print("test loss, test acc:", results) + +# Generate predictions (probabilities -- the output of the last layer) +# on new data using `predict` +print("Generate predictions for 3 samples") +predictions = model.predict(x_test[:3]) +print("predictions shape:", predictions.shape) + +""" +Now, let's review each piece of this workflow in detail. +""" + +""" +## The `compile()` method: specifying a loss, metrics, and an optimizer + +To train a model with `fit()`, you need to specify a loss function, an optimizer, and +optionally, some metrics to monitor. + +You pass these to the model as arguments to the `compile()` method: +""" + +model.compile( + optimizer=keras.optimizers.RMSprop(learning_rate=1e-3), + loss=keras.losses.SparseCategoricalCrossentropy(), + metrics=[keras.metrics.SparseCategoricalAccuracy()], +) + +""" +The `metrics` argument should be a list -- your model can have any number of metrics. + +If your model has multiple outputs, you can specify different losses and metrics for +each output, and you can modulate the contribution of each output to the total loss of +the model. You will find more details about this in the **Passing data to multi-input, +multi-output models** section. + +Note that if you're satisfied with the default settings, in many cases the optimizer, +loss, and metrics can be specified via string identifiers as a shortcut: +""" + +model.compile( + optimizer="rmsprop", + loss="sparse_categorical_crossentropy", + metrics=["sparse_categorical_accuracy"], +) + +""" +For later reuse, let's put our model definition and compile step in functions; we will +call them several times across different examples in this guide. +""" + + +def get_uncompiled_model(): + inputs = keras.Input(shape=(784,), name="digits") + x = layers.Dense(64, activation="relu", name="dense_1")(inputs) + x = layers.Dense(64, activation="relu", name="dense_2")(x) + outputs = layers.Dense(10, activation="softmax", name="predictions")(x) + model = keras.Model(inputs=inputs, outputs=outputs) + return model + + +def get_compiled_model(): + model = get_uncompiled_model() + model.compile( + optimizer="rmsprop", + loss="sparse_categorical_crossentropy", + metrics=["sparse_categorical_accuracy"], + ) + return model + + +""" +### Many built-in optimizers, losses, and metrics are available + +In general, you won't have to create your own losses, metrics, or optimizers +from scratch, because what you need is likely to be already part of the Keras API: + +Optimizers: + +- `SGD()` (with or without momentum) +- `RMSprop()` +- `Adam()` +- etc. + +Losses: + +- `MeanSquaredError()` +- `KLDivergence()` +- `CosineSimilarity()` +- etc. + +Metrics: + +- `AUC()` +- `Precision()` +- `Recall()` +- etc. +""" + +""" +### Custom losses + +If you need to create a custom loss, Keras provides three ways to do so. + +The first method involves creating a function that accepts inputs `y_true` and +`y_pred`. The following example shows a loss function that computes the mean squared +error between the real data and the predictions: +""" + + +def custom_mean_squared_error(y_true, y_pred): + return ops.mean(ops.square(y_true - y_pred), axis=-1) + + +model = get_uncompiled_model() +model.compile(optimizer=keras.optimizers.Adam(), loss=custom_mean_squared_error) + +# We need to one-hot encode the labels to use MSE +y_train_one_hot = ops.one_hot(y_train, num_classes=10) +model.fit(x_train, y_train_one_hot, batch_size=64, epochs=1) + +""" +If you need a loss function that takes in parameters beside `y_true` and `y_pred`, you +can subclass the `keras.losses.Loss` class and implement the following two methods: + +- `__init__(self)`: accept parameters to pass during the call of your loss function +- `call(self, y_true, y_pred)`: use the targets (y_true) and the model predictions +(y_pred) to compute the model's loss + +Let's say you want to use mean squared error, but with an added term that +will de-incentivize prediction values far from 0.5 (we assume that the categorical +targets are one-hot encoded and take values between 0 and 1). This +creates an incentive for the model not to be too confident, which may help +reduce overfitting (we won't know if it works until we try!). + +Here's how you would do it: +""" + + +class CustomMSE(keras.losses.Loss): + def __init__(self, regularization_factor=0.1, name="custom_mse"): + super().__init__(name=name) + self.regularization_factor = regularization_factor + + def call(self, y_true, y_pred): + mse = ops.mean(ops.square(y_true - y_pred), axis=-1) + reg = ops.mean(ops.square(0.5 - y_pred), axis=-1) + return mse + reg * self.regularization_factor + + +model = get_uncompiled_model() +model.compile(optimizer=keras.optimizers.Adam(), loss=CustomMSE()) + +y_train_one_hot = ops.one_hot(y_train, num_classes=10) +model.fit(x_train, y_train_one_hot, batch_size=64, epochs=1) + + +""" +### Custom metrics + +If you need a metric that isn't part of the API, you can easily create custom metrics +by subclassing the `keras.metrics.Metric` class. You will need to implement 4 +methods: + +- `__init__(self)`, in which you will create state variables for your metric. +- `update_state(self, y_true, y_pred, sample_weight=None)`, which uses the targets +y_true and the model predictions y_pred to update the state variables. +- `result(self)`, which uses the state variables to compute the final results. +- `reset_state(self)`, which reinitializes the state of the metric. + +State update and results computation are kept separate (in `update_state()` and +`result()`, respectively) because in some cases, the results computation might be very +expensive and would only be done periodically. + +Here's a simple example showing how to implement a `CategoricalTruePositives` metric +that counts how many samples were correctly classified as belonging to a given class: +""" + + +class CategoricalTruePositives(keras.metrics.Metric): + def __init__(self, name="categorical_true_positives", **kwargs): + super().__init__(name=name, **kwargs) + self.true_positives = self.add_variable( + shape=(), name="ctp", initializer="zeros" + ) + + def update_state(self, y_true, y_pred, sample_weight=None): + y_pred = ops.reshape(ops.argmax(y_pred, axis=1), (-1, 1)) + values = ops.cast(y_true, "int32") == ops.cast(y_pred, "int32") + values = ops.cast(values, "float32") + if sample_weight is not None: + sample_weight = ops.cast(sample_weight, "float32") + values = ops.multiply(values, sample_weight) + self.true_positives.assign_add(ops.sum(values)) + + def result(self): + return self.true_positives + + def reset_state(self): + # The state of the metric will be reset at the start of each epoch. + self.true_positives.assign(0) + + +model = get_uncompiled_model() +model.compile( + optimizer=keras.optimizers.RMSprop(learning_rate=1e-3), + loss=keras.losses.SparseCategoricalCrossentropy(), + metrics=[CategoricalTruePositives()], +) +model.fit(x_train, y_train, batch_size=64, epochs=3) + +""" +### Handling losses and metrics that don't fit the standard signature + +The overwhelming majority of losses and metrics can be computed from `y_true` and +`y_pred`, where `y_pred` is an output of your model -- but not all of them. For +instance, a regularization loss may only require the activation of a layer (there are +no targets in this case), and this activation may not be a model output. + +In such cases, you can call `self.add_loss(loss_value)` from inside the call method of +a custom layer. Losses added in this way get added to the "main" loss during training +(the one passed to `compile()`). Here's a simple example that adds activity +regularization (note that activity regularization is built-in in all Keras layers -- +this layer is just for the sake of providing a concrete example): +""" + + +class ActivityRegularizationLayer(layers.Layer): + def call(self, inputs): + self.add_loss(ops.sum(inputs) * 0.1) + return inputs # Pass-through layer. + + +inputs = keras.Input(shape=(784,), name="digits") +x = layers.Dense(64, activation="relu", name="dense_1")(inputs) + +# Insert activity regularization as a layer +x = ActivityRegularizationLayer()(x) + +x = layers.Dense(64, activation="relu", name="dense_2")(x) +outputs = layers.Dense(10, name="predictions")(x) + +model = keras.Model(inputs=inputs, outputs=outputs) +model.compile( + optimizer=keras.optimizers.RMSprop(learning_rate=1e-3), + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), +) + +# The displayed loss will be much higher than before +# due to the regularization component. +model.fit(x_train, y_train, batch_size=64, epochs=1) + +""" +Note that when you pass losses via `add_loss()`, it becomes possible to call +`compile()` without a loss function, since the model already has a loss to minimize. + +Consider the following `LogisticEndpoint` layer: it takes as inputs +targets & logits, and it tracks a crossentropy loss via `add_loss()`. +""" + + +class LogisticEndpoint(keras.layers.Layer): + def __init__(self, name=None): + super().__init__(name=name) + self.loss_fn = keras.losses.BinaryCrossentropy(from_logits=True) + + def call(self, targets, logits, sample_weights=None): + # Compute the training-time loss value and add it + # to the layer using `self.add_loss()`. + loss = self.loss_fn(targets, logits, sample_weights) + self.add_loss(loss) + + # Return the inference-time prediction tensor (for `.predict()`). + return ops.softmax(logits) + + +""" +You can use it in a model with two inputs (input data & targets), compiled without a +`loss` argument, like this: +""" + +inputs = keras.Input(shape=(3,), name="inputs") +targets = keras.Input(shape=(10,), name="targets") +logits = keras.layers.Dense(10)(inputs) +predictions = LogisticEndpoint(name="predictions")(targets, logits) + +model = keras.Model(inputs=[inputs, targets], outputs=predictions) +model.compile(optimizer="adam") # No loss argument! + +data = { + "inputs": np.random.random((3, 3)), + "targets": np.random.random((3, 10)), +} +model.fit(data) + +""" +For more information about training multi-input models, see the section **Passing data +to multi-input, multi-output models**. +""" + +""" +### Automatically setting apart a validation holdout set + +In the first end-to-end example you saw, we used the `validation_data` argument to pass +a tuple of NumPy arrays `(x_val, y_val)` to the model for evaluating a validation loss +and validation metrics at the end of each epoch. + +Here's another option: the argument `validation_split` allows you to automatically +reserve part of your training data for validation. The argument value represents the +fraction of the data to be reserved for validation, so it should be set to a number +higher than 0 and lower than 1. For instance, `validation_split=0.2` means "use 20% of +the data for validation", and `validation_split=0.6` means "use 60% of the data for +validation". + +The way the validation is computed is by taking the last x% samples of the arrays +received by the `fit()` call, before any shuffling. + +Note that you can only use `validation_split` when training with NumPy data. +""" + +model = get_compiled_model() +model.fit(x_train, y_train, batch_size=64, validation_split=0.2, epochs=1) + +""" +## Training & evaluation using `tf.data` Datasets + +In the past few paragraphs, you've seen how to handle losses, metrics, and optimizers, +and you've seen how to use the `validation_data` and `validation_split` arguments in +`fit()`, when your data is passed as NumPy arrays. + +Another option is to use an iterator-like, such as a `tf.data.Dataset`, a +PyTorch `DataLoader`, or a Keras `PyDataset`. Let's take look at the former. + +The `tf.data` API is a set of utilities in TensorFlow 2.0 for loading and preprocessing +data in a way that's fast and scalable. For a complete guide about creating `Datasets`, +see the [tf.data documentation](https://www.tensorflow.org/guide/data). + +**You can use `tf.data` to train your Keras +models regardless of the backend you're using -- +whether it's JAX, PyTorch, or TensorFlow.** +You can pass a `Dataset` instance directly to the methods `fit()`, `evaluate()`, and +`predict()`: +""" + +model = get_compiled_model() + +# First, let's create a training Dataset instance. +# For the sake of our example, we'll use the same MNIST data as before. +train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)) +# Shuffle and slice the dataset. +train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64) + +# Now we get a test dataset. +test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)) +test_dataset = test_dataset.batch(64) + +# Since the dataset already takes care of batching, +# we don't pass a `batch_size` argument. +model.fit(train_dataset, epochs=3) + +# You can also evaluate or predict on a dataset. +print("Evaluate") +result = model.evaluate(test_dataset) +dict(zip(model.metrics_names, result)) + +""" +Note that the Dataset is reset at the end of each epoch, so it can be reused of the +next epoch. + +If you want to run training only on a specific number of batches from this Dataset, you +can pass the `steps_per_epoch` argument, which specifies how many training steps the +model should run using this Dataset before moving on to the next epoch. +""" + +model = get_compiled_model() + +# Prepare the training dataset +train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)) +train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64) + +# Only use the 100 batches per epoch (that's 64 * 100 samples) +model.fit(train_dataset, epochs=3, steps_per_epoch=100) + +""" +You can also pass a `Dataset` instance as the `validation_data` argument in `fit()`: +""" + +model = get_compiled_model() + +# Prepare the training dataset +train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)) +train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64) + +# Prepare the validation dataset +val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val)) +val_dataset = val_dataset.batch(64) + +model.fit(train_dataset, epochs=1, validation_data=val_dataset) + +""" +At the end of each epoch, the model will iterate over the validation dataset and +compute the validation loss and validation metrics. + +If you want to run validation only on a specific number of batches from this dataset, +you can pass the `validation_steps` argument, which specifies how many validation +steps the model should run with the validation dataset before interrupting validation +and moving on to the next epoch: +""" + +model = get_compiled_model() + +# Prepare the training dataset +train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)) +train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64) + +# Prepare the validation dataset +val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val)) +val_dataset = val_dataset.batch(64) + +model.fit( + train_dataset, + epochs=1, + # Only run validation using the first 10 batches of the dataset + # using the `validation_steps` argument + validation_data=val_dataset, + validation_steps=10, +) + +""" +Note that the validation dataset will be reset after each use (so that you will always +be evaluating on the same samples from epoch to epoch). + +The argument `validation_split` (generating a holdout set from the training data) is +not supported when training from `Dataset` objects, since this feature requires the +ability to index the samples of the datasets, which is not possible in general with +the `Dataset` API. +""" + +""" +## Training & evaluation using `PyDataset` instances + +`keras.utils.PyDataset` is a utility that you can subclass to obtain +a Python generator with two important properties: + +- It works well with multiprocessing. +- It can be shuffled (e.g. when passing `shuffle=True` in `fit()`). + +A `PyDataset` must implement two methods: + +- `__getitem__` +- `__len__` + +The method `__getitem__` should return a complete batch. +If you want to modify your dataset between epochs, you may implement `on_epoch_end`. +You may also implement `on_epoch_begin` to be called at the start of each epoch. + +Here's a quick example: +""" + + +class ExamplePyDataset(keras.utils.PyDataset): + def __init__(self, x, y, batch_size, **kwargs): + super().__init__(**kwargs) + self.x = x + self.y = y + self.batch_size = batch_size + + def __len__(self): + return int(np.ceil(len(self.x) / float(self.batch_size))) + + def __getitem__(self, idx): + batch_x = self.x[idx * self.batch_size : (idx + 1) * self.batch_size] + batch_y = self.y[idx * self.batch_size : (idx + 1) * self.batch_size] + return batch_x, batch_y + + +train_py_dataset = ExamplePyDataset(x_train, y_train, batch_size=32) +val_py_dataset = ExamplePyDataset(x_val, y_val, batch_size=32) + +""" +To fit the model, pass the dataset instead as the `x` argument (no need for a `y` +argument since the dataset includes the targets), and pass the validation dataset +as the `validation_data` argument. And no need for the `validation_batch_size` +argument, since the dataset is already batched! +""" + +model = get_compiled_model() +model.fit( + train_py_dataset, batch_size=64, validation_data=val_py_dataset, epochs=1 +) + +""" +Evaluating the model is just as easy: +""" + +model.evaluate(val_py_dataset) + +""" +Importantly, `PyDataset` objects support three common constructor arguments +that handle the parallel processing configuration: + +- `workers`: Number of workers to use in multithreading or + multiprocessing. Typically, you'd set it to the number of + cores on your CPU. +- `use_multiprocessing`: Whether to use Python multiprocessing for + parallelism. Setting this to `True` means that your + dataset will be replicated in multiple forked processes. + This is necessary to gain compute-level (rather than I/O level) + benefits from parallelism. However it can only be set to + `True` if your dataset can be safely pickled. +- `max_queue_size`: Maximum number of batches to keep in the queue + when iterating over the dataset in a multithreaded or + multiprocessed setting. + You can reduce this value to reduce the CPU memory consumption of + your dataset. It defaults to 10. + +By default, multiprocessing is disabled (`use_multiprocessing=False`) and only +one thread is used. You should make sure to only turn on `use_multiprocessing` if +your code is running inside a Python `if __name__ == "__main__":` block in order +to avoid issues. + +Here's a 4-thread, non-multiprocessed example: +""" + +train_py_dataset = ExamplePyDataset(x_train, y_train, batch_size=32, workers=4) +val_py_dataset = ExamplePyDataset(x_val, y_val, batch_size=32, workers=4) + +model = get_compiled_model() +model.fit( + train_py_dataset, batch_size=64, validation_data=val_py_dataset, epochs=1 +) + +""" +## Training & evaluation using PyTorch `DataLoader` objects + +All built-in training and evaluation APIs are also compatible with `torch.utils.data.Dataset` and +`torch.utils.data.DataLoader` objects -- regardless of whether you're using the PyTorch backend, +or the JAX or TensorFlow backends. Let's take a look at a simple example. + +Unlike `PyDataset` which are batch-centric, PyTorch `Dataset` objects are sample-centric: +the `__len__` method returns the number of samples, +and the `__getitem__` method returns a specific sample. +""" + + +class ExampleTorchDataset(torch.utils.data.Dataset): + def __init__(self, x, y): + self.x = x + self.y = y + + def __len__(self): + return len(self.x) + + def __getitem__(self, idx): + return self.x[idx], self.y[idx] + + +train_torch_dataset = ExampleTorchDataset(x_train, y_train) +val_torch_dataset = ExampleTorchDataset(x_val, y_val) + +""" +To use a PyTorch Dataset, you need to wrap it into a `Dataloader` which takes care +of batching and shuffling: +""" + +train_dataloader = torch.utils.data.DataLoader( + train_torch_dataset, batch_size=32, shuffle=True +) +val_dataloader = torch.utils.data.DataLoader( + val_torch_dataset, batch_size=32, shuffle=True +) + +""" +Now you can use them in the Keras API just like any other iterator: +""" + +model = get_compiled_model() +model.fit( + train_dataloader, batch_size=64, validation_data=val_dataloader, epochs=1 +) +model.evaluate(val_dataloader) + +""" +## Using sample weighting and class weighting + +With the default settings the weight of a sample is decided by its frequency +in the dataset. There are two methods to weight the data, independent of +sample frequency: + +* Class weights +* Sample weights +""" + +""" +### Class weights + +This is set by passing a dictionary to the `class_weight` argument to +`Model.fit()`. This dictionary maps class indices to the weight that should +be used for samples belonging to this class. + +This can be used to balance classes without resampling, or to train a +model that gives more importance to a particular class. + +For instance, if class "0" is half as represented as class "1" in your data, +you could use `Model.fit(..., class_weight={0: 1., 1: 0.5})`. +""" + +""" +Here's a NumPy example where we use class weights or sample weights to +give more importance to the correct classification of class #5 (which +is the digit "5" in the MNIST dataset). +""" + +class_weight = { + 0: 1.0, + 1: 1.0, + 2: 1.0, + 3: 1.0, + 4: 1.0, + # Set weight "2" for class "5", + # making this class 2x more important + 5: 2.0, + 6: 1.0, + 7: 1.0, + 8: 1.0, + 9: 1.0, +} + +print("Fit with class weight") +model = get_compiled_model() +model.fit(x_train, y_train, class_weight=class_weight, batch_size=64, epochs=1) + +""" +### Sample weights + +For fine grained control, or if you are not building a classifier, +you can use "sample weights". + +- When training from NumPy data: Pass the `sample_weight` + argument to `Model.fit()`. +- When training from `tf.data` or any other sort of iterator: + Yield `(input_batch, label_batch, sample_weight_batch)` tuples. + +A "sample weights" array is an array of numbers that specify how much weight +each sample in a batch should have in computing the total loss. It is commonly +used in imbalanced classification problems (the idea being to give more weight +to rarely-seen classes). + +When the weights used are ones and zeros, the array can be used as a *mask* for +the loss function (entirely discarding the contribution of certain samples to +the total loss). +""" + +sample_weight = np.ones(shape=(len(y_train),)) +sample_weight[y_train == 5] = 2.0 + +print("Fit with sample weight") +model = get_compiled_model() +model.fit( + x_train, y_train, sample_weight=sample_weight, batch_size=64, epochs=1 +) + +""" +Here's a matching `Dataset` example: +""" + +sample_weight = np.ones(shape=(len(y_train),)) +sample_weight[y_train == 5] = 2.0 + +# Create a Dataset that includes sample weights +# (3rd element in the return tuple). +train_dataset = tf.data.Dataset.from_tensor_slices( + (x_train, y_train, sample_weight) +) + +# Shuffle and slice the dataset. +train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64) + +model = get_compiled_model() +model.fit(train_dataset, epochs=1) + +""" +## Passing data to multi-input, multi-output models + +In the previous examples, we were considering a model with a single input (a tensor of +shape `(764,)`) and a single output (a prediction tensor of shape `(10,)`). But what +about models that have multiple inputs or outputs? + +Consider the following model, which has an image input of shape `(32, 32, 3)` (that's +`(height, width, channels)`) and a time series input of shape `(None, 10)` (that's +`(timesteps, features)`). Our model will have two outputs computed from the +combination of these inputs: a "score" (of shape `(1,)`) and a probability +distribution over five classes (of shape `(5,)`). +""" + +image_input = keras.Input(shape=(32, 32, 3), name="img_input") +timeseries_input = keras.Input(shape=(None, 10), name="ts_input") + +x1 = layers.Conv2D(3, 3)(image_input) +x1 = layers.GlobalMaxPooling2D()(x1) + +x2 = layers.Conv1D(3, 3)(timeseries_input) +x2 = layers.GlobalMaxPooling1D()(x2) + +x = layers.concatenate([x1, x2]) + +score_output = layers.Dense(1, name="score_output")(x) +class_output = layers.Dense(5, name="class_output")(x) + +model = keras.Model( + inputs=[image_input, timeseries_input], outputs=[score_output, class_output] +) + +""" +Let's plot this model, so you can clearly see what we're doing here (note that the +shapes shown in the plot are batch shapes, rather than per-sample shapes). +""" + +keras.utils.plot_model( + model, "multi_input_and_output_model.png", show_shapes=True +) + +""" +At compilation time, we can specify different losses to different outputs, by passing +the loss functions as a list: +""" + +model.compile( + optimizer=keras.optimizers.RMSprop(1e-3), + loss=[ + keras.losses.MeanSquaredError(), + keras.losses.CategoricalCrossentropy(), + ], +) + +""" +If we only passed a single loss function to the model, the same loss function would be +applied to every output (which is not appropriate here). + +Likewise for metrics: +""" + +model.compile( + optimizer=keras.optimizers.RMSprop(1e-3), + loss=[ + keras.losses.MeanSquaredError(), + keras.losses.CategoricalCrossentropy(), + ], + metrics=[ + [ + keras.metrics.MeanAbsolutePercentageError(), + keras.metrics.MeanAbsoluteError(), + ], + [keras.metrics.CategoricalAccuracy()], + ], +) + +""" +Since we gave names to our output layers, we could also specify per-output losses and +metrics via a dict: +""" + +model.compile( + optimizer=keras.optimizers.RMSprop(1e-3), + loss={ + "score_output": keras.losses.MeanSquaredError(), + "class_output": keras.losses.CategoricalCrossentropy(), + }, + metrics={ + "score_output": [ + keras.metrics.MeanAbsolutePercentageError(), + keras.metrics.MeanAbsoluteError(), + ], + "class_output": [keras.metrics.CategoricalAccuracy()], + }, +) + +""" +We recommend the use of explicit names and dicts if you have more than 2 outputs. + +It's possible to give different weights to different output-specific losses (for +instance, one might wish to privilege the "score" loss in our example, by giving to 2x +the importance of the class loss), using the `loss_weights` argument: +""" + +model.compile( + optimizer=keras.optimizers.RMSprop(1e-3), + loss={ + "score_output": keras.losses.MeanSquaredError(), + "class_output": keras.losses.CategoricalCrossentropy(), + }, + metrics={ + "score_output": [ + keras.metrics.MeanAbsolutePercentageError(), + keras.metrics.MeanAbsoluteError(), + ], + "class_output": [keras.metrics.CategoricalAccuracy()], + }, + loss_weights={"score_output": 2.0, "class_output": 1.0}, +) + +""" +You could also choose not to compute a loss for certain outputs, if these outputs are +meant for prediction but not for training: +""" + +# List loss version +model.compile( + optimizer=keras.optimizers.RMSprop(1e-3), + loss=[None, keras.losses.CategoricalCrossentropy()], +) + +# Or dict loss version +model.compile( + optimizer=keras.optimizers.RMSprop(1e-3), + loss={"class_output": keras.losses.CategoricalCrossentropy()}, +) + +""" +Passing data to a multi-input or multi-output model in `fit()` works in a similar way as +specifying a loss function in compile: you can pass **lists of NumPy arrays** (with +1:1 mapping to the outputs that received a loss function) or **dicts mapping output +names to NumPy arrays**. +""" + +model.compile( + optimizer=keras.optimizers.RMSprop(1e-3), + loss=[ + keras.losses.MeanSquaredError(), + keras.losses.CategoricalCrossentropy(), + ], +) + +# Generate dummy NumPy data +img_data = np.random.random_sample(size=(100, 32, 32, 3)) +ts_data = np.random.random_sample(size=(100, 20, 10)) +score_targets = np.random.random_sample(size=(100, 1)) +class_targets = np.random.random_sample(size=(100, 5)) + +# Fit on lists +model.fit( + [img_data, ts_data], [score_targets, class_targets], batch_size=32, epochs=1 +) + +# Alternatively, fit on dicts +model.fit( + {"img_input": img_data, "ts_input": ts_data}, + {"score_output": score_targets, "class_output": class_targets}, + batch_size=32, + epochs=1, +) + +""" +Here's the `Dataset` use case: similarly as what we did for NumPy arrays, the `Dataset` +should return a tuple of dicts. +""" + +train_dataset = tf.data.Dataset.from_tensor_slices( + ( + {"img_input": img_data, "ts_input": ts_data}, + {"score_output": score_targets, "class_output": class_targets}, + ) +) +train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64) + +model.fit(train_dataset, epochs=1) + +""" +## Using callbacks + +Callbacks in Keras are objects that are called at different points during training (at +the start of an epoch, at the end of a batch, at the end of an epoch, etc.). They +can be used to implement certain behaviors, such as: + +- Doing validation at different points during training (beyond the built-in per-epoch +validation) +- Checkpointing the model at regular intervals or when it exceeds a certain accuracy +threshold +- Changing the learning rate of the model when training seems to be plateauing +- Doing fine-tuning of the top layers when training seems to be plateauing +- Sending email or instant message notifications when training ends or where a certain +performance threshold is exceeded +- Etc. + +Callbacks can be passed as a list to your call to `fit()`: +""" + +model = get_compiled_model() + +callbacks = [ + keras.callbacks.EarlyStopping( + # Stop training when `val_loss` is no longer improving + monitor="val_loss", + # "no longer improving" being defined as "no better than 1e-2 less" + min_delta=1e-2, + # "no longer improving" being further defined as "for at least 2 epochs" + patience=2, + verbose=1, + ) +] +model.fit( + x_train, + y_train, + epochs=20, + batch_size=64, + callbacks=callbacks, + validation_split=0.2, +) + +""" +### Many built-in callbacks are available + +There are many built-in callbacks already available in Keras, such as: + +- `ModelCheckpoint`: Periodically save the model. +- `EarlyStopping`: Stop training when training is no longer improving the validation +metrics. +- `TensorBoard`: periodically write model logs that can be visualized in +[TensorBoard](https://www.tensorflow.org/tensorboard) (more details in the section +"Visualization"). +- `CSVLogger`: streams loss and metrics data to a CSV file. +- etc. + +See the [callbacks documentation](/api/callbacks/) for the complete list. + +### Writing your own callback + +You can create a custom callback by extending the base class +`keras.callbacks.Callback`. A callback has access to its associated model through the +class property `self.model`. + +Make sure to read the +[complete guide to writing custom callbacks](/guides/writing_your_own_callbacks/). + +Here's a simple example saving a list of per-batch loss values during training: +""" + + +class LossHistory(keras.callbacks.Callback): + def on_train_begin(self, logs): + self.per_batch_losses = [] + + def on_batch_end(self, batch, logs): + self.per_batch_losses.append(logs.get("loss")) + + +""" +## Checkpointing models + +When you're training model on relatively large datasets, it's crucial to save +checkpoints of your model at frequent intervals. + +The easiest way to achieve this is with the `ModelCheckpoint` callback: +""" + +model = get_compiled_model() + +callbacks = [ + keras.callbacks.ModelCheckpoint( + # Path where to save the model + # The two parameters below mean that we will overwrite + # the current checkpoint if and only if + # the `val_loss` score has improved. + # The saved model name will include the current epoch. + filepath="mymodel_{epoch}.keras", + save_best_only=True, # Only save a model if `val_loss` has improved. + monitor="val_loss", + verbose=1, + ) +] +model.fit( + x_train, + y_train, + epochs=2, + batch_size=64, + callbacks=callbacks, + validation_split=0.2, +) + +""" +The `ModelCheckpoint` callback can be used to implement fault-tolerance: +the ability to restart training from the last saved state of the model in case training +gets randomly interrupted. Here's a basic example: +""" + +# Prepare a directory to store all the checkpoints. +checkpoint_dir = "./ckpt" +if not os.path.exists(checkpoint_dir): + os.makedirs(checkpoint_dir) + + +def make_or_restore_model(): + # Either restore the latest model, or create a fresh one + # if there is no checkpoint available. + checkpoints = [ + os.path.join(checkpoint_dir, name) + for name in os.listdir(checkpoint_dir) + ] + if checkpoints: + latest_checkpoint = max(checkpoints, key=os.path.getctime) + print("Restoring from", latest_checkpoint) + return keras.models.load_model(latest_checkpoint) + print("Creating a new model") + return get_compiled_model() + + +model = make_or_restore_model() +callbacks = [ + # This callback saves the model every 100 batches. + # We include the training loss in the saved model name. + keras.callbacks.ModelCheckpoint( + filepath=os.path.join(checkpoint_dir, "model-loss={loss:.2f}.keras"), + save_freq=100, + ) +] +model.fit(x_train, y_train, epochs=1, callbacks=callbacks) + +""" +You call also write your own callback for saving and restoring models. + +For a complete guide on serialization and saving, see the +[guide to saving and serializing Models](/guides/serialization_and_saving/). +""" + +""" +## Using learning rate schedules + +A common pattern when training deep learning models is to gradually reduce the learning +as training progresses. This is generally known as "learning rate decay". + +The learning decay schedule could be static (fixed in advance, as a function of the +current epoch or the current batch index), or dynamic (responding to the current +behavior of the model, in particular the validation loss). + +### Passing a schedule to an optimizer + +You can easily use a static learning rate decay schedule by passing a schedule object +as the `learning_rate` argument in your optimizer: +""" + +initial_learning_rate = 0.1 +lr_schedule = keras.optimizers.schedules.ExponentialDecay( + initial_learning_rate, decay_steps=100000, decay_rate=0.96, staircase=True +) + +optimizer = keras.optimizers.RMSprop(learning_rate=lr_schedule) + +""" +Several built-in schedules are available: `ExponentialDecay`, `PiecewiseConstantDecay`, +`PolynomialDecay`, and `InverseTimeDecay`. + +### Using callbacks to implement a dynamic learning rate schedule + +A dynamic learning rate schedule (for instance, decreasing the learning rate when the +validation loss is no longer improving) cannot be achieved with these schedule objects, +since the optimizer does not have access to validation metrics. + +However, callbacks do have access to all metrics, including validation metrics! You can +thus achieve this pattern by using a callback that modifies the current learning rate +on the optimizer. In fact, this is even built-in as the `ReduceLROnPlateau` callback. +""" + +""" +## Visualizing loss and metrics during training with TensorBoard + +The best way to keep an eye on your model during training is to use +[TensorBoard](https://www.tensorflow.org/tensorboard) -- a browser-based application +that you can run locally that provides you with: + +- Live plots of the loss and metrics for training and evaluation +- (optionally) Visualizations of the histograms of your layer activations +- (optionally) 3D visualizations of the embedding spaces learned by your `Embedding` +layers + +If you have installed TensorFlow with pip, you should be able to launch TensorBoard +from the command line: + +``` +tensorboard --logdir=/full_path_to_your_logs +``` +""" + +""" +### Using the TensorBoard callback + +The easiest way to use TensorBoard with a Keras model and the `fit()` method is the +`TensorBoard` callback. + +In the simplest case, just specify where you want the callback to write logs, and +you're good to go: +""" + +keras.callbacks.TensorBoard( + log_dir="/full_path_to_your_logs", + histogram_freq=0, # How often to log histogram visualizations + embeddings_freq=0, # How often to log embedding visualizations + update_freq="epoch", +) # How often to write logs (default: once per epoch) + +""" +For more information, see the +[documentation for the `TensorBoard` callback](https://keras.io/api/callbacks/tensorboard/). +""" diff --git a/guides/transfer_learning.py b/guides/transfer_learning.py new file mode 100644 index 000000000000..94716de6eb78 --- /dev/null +++ b/guides/transfer_learning.py @@ -0,0 +1,557 @@ +""" +Title: Transfer learning & fine-tuning +Author: [fchollet](https://twitter.com/fchollet) +Date created: 2020/04/15 +Last modified: 2023/06/25 +Description: Complete guide to transfer learning & fine-tuning in Keras. +Accelerator: GPU +""" + +""" +## Setup +""" + +import numpy as np +import keras +from keras import layers +import tensorflow_datasets as tfds +import matplotlib.pyplot as plt + +""" +## Introduction + +**Transfer learning** consists of taking features learned on one problem, and +leveraging them on a new, similar problem. For instance, features from a model that has +learned to identify raccoons may be useful to kick-start a model meant to identify + tanukis. + +Transfer learning is usually done for tasks where your dataset has too little data to + train a full-scale model from scratch. + +The most common incarnation of transfer learning in the context of deep learning is the + following workflow: + +1. Take layers from a previously trained model. +2. Freeze them, so as to avoid destroying any of the information they contain during + future training rounds. +3. Add some new, trainable layers on top of the frozen layers. They will learn to turn + the old features into predictions on a new dataset. +4. Train the new layers on your dataset. + +A last, optional step, is **fine-tuning**, which consists of unfreezing the entire +model you obtained above (or part of it), and re-training it on the new data with a +very low learning rate. This can potentially achieve meaningful improvements, by + incrementally adapting the pretrained features to the new data. + +First, we will go over the Keras `trainable` API in detail, which underlies most + transfer learning & fine-tuning workflows. + +Then, we'll demonstrate the typical workflow by taking a model pretrained on the +ImageNet dataset, and retraining it on the Kaggle "cats vs dogs" classification + dataset. + +This is adapted from +[Deep Learning with Python](https://www.manning.com/books/deep-learning-with-python) +and the 2016 blog post +["building powerful image classification models using very little data"](https://blog.keras.io/building-powerful-image-classification-models-using-very-little-data.html). +""" + +""" +## Freezing layers: understanding the `trainable` attribute + +Layers & models have three weight attributes: + +- `weights` is the list of all weights variables of the layer. +- `trainable_weights` is the list of those that are meant to be updated (via gradient + descent) to minimize the loss during training. +- `non_trainable_weights` is the list of those that aren't meant to be trained. + Typically they are updated by the model during the forward pass. + +**Example: the `Dense` layer has 2 trainable weights (kernel & bias)** +""" + +layer = keras.layers.Dense(3) +layer.build((None, 4)) # Create the weights + +print("weights:", len(layer.weights)) +print("trainable_weights:", len(layer.trainable_weights)) +print("non_trainable_weights:", len(layer.non_trainable_weights)) + +""" +In general, all weights are trainable weights. The only built-in layer that has +non-trainable weights is the `BatchNormalization` layer. It uses non-trainable weights + to keep track of the mean and variance of its inputs during training. +To learn how to use non-trainable weights in your own custom layers, see the +[guide to writing new layers from scratch](https://keras.io/guides/making_new_layers_and_models_via_subclassing/). + +**Example: the `BatchNormalization` layer has 2 trainable weights and 2 non-trainable + weights** +""" + +layer = keras.layers.BatchNormalization() +layer.build((None, 4)) # Create the weights + +print("weights:", len(layer.weights)) +print("trainable_weights:", len(layer.trainable_weights)) +print("non_trainable_weights:", len(layer.non_trainable_weights)) + +""" +Layers & models also feature a boolean attribute `trainable`. Its value can be changed. +Setting `layer.trainable` to `False` moves all the layer's weights from trainable to +non-trainable. This is called "freezing" the layer: the state of a frozen layer won't +be updated during training (either when training with `fit()` or when training with + any custom loop that relies on `trainable_weights` to apply gradient updates). + +**Example: setting `trainable` to `False`** +""" + +layer = keras.layers.Dense(3) +layer.build((None, 4)) # Create the weights +layer.trainable = False # Freeze the layer + +print("weights:", len(layer.weights)) +print("trainable_weights:", len(layer.trainable_weights)) +print("non_trainable_weights:", len(layer.non_trainable_weights)) + +""" +When a trainable weight becomes non-trainable, its value is no longer updated during + training. +""" + +# Make a model with 2 layers +layer1 = keras.layers.Dense(3, activation="relu") +layer2 = keras.layers.Dense(3, activation="sigmoid") +model = keras.Sequential([keras.Input(shape=(3,)), layer1, layer2]) + +# Freeze the first layer +layer1.trainable = False + +# Keep a copy of the weights of layer1 for later reference +initial_layer1_weights_values = layer1.get_weights() + +# Train the model +model.compile(optimizer="adam", loss="mse") +model.fit(np.random.random((2, 3)), np.random.random((2, 3))) + +# Check that the weights of layer1 have not changed during training +final_layer1_weights_values = layer1.get_weights() +np.testing.assert_allclose( + initial_layer1_weights_values[0], final_layer1_weights_values[0] +) +np.testing.assert_allclose( + initial_layer1_weights_values[1], final_layer1_weights_values[1] +) + +""" +Do not confuse the `layer.trainable` attribute with the argument `training` in +`layer.__call__()` (which controls whether the layer should run its forward pass in + inference mode or training mode). For more information, see the +[Keras FAQ]( + https://keras.io/getting_started/faq/#whats-the-difference-between-the-training-argument-in-call-and-the-trainable-attribute). +""" + +""" +## Recursive setting of the `trainable` attribute + +If you set `trainable = False` on a model or on any layer that has sublayers, +all children layers become non-trainable as well. + +**Example:** +""" + +inner_model = keras.Sequential( + [ + keras.Input(shape=(3,)), + keras.layers.Dense(3, activation="relu"), + keras.layers.Dense(3, activation="relu"), + ] +) + +model = keras.Sequential( + [ + keras.Input(shape=(3,)), + inner_model, + keras.layers.Dense(3, activation="sigmoid"), + ] +) + +model.trainable = False # Freeze the outer model + +assert inner_model.trainable == False # All layers in `model` are now frozen +assert ( + inner_model.layers[0].trainable == False +) # `trainable` is propagated recursively + +""" +## The typical transfer-learning workflow + +This leads us to how a typical transfer learning workflow can be implemented in Keras: + +1. Instantiate a base model and load pre-trained weights into it. +2. Freeze all layers in the base model by setting `trainable = False`. +3. Create a new model on top of the output of one (or several) layers from the base + model. +4. Train your new model on your new dataset. + +Note that an alternative, more lightweight workflow could also be: + +1. Instantiate a base model and load pre-trained weights into it. +2. Run your new dataset through it and record the output of one (or several) layers + from the base model. This is called **feature extraction**. +3. Use that output as input data for a new, smaller model. + +A key advantage of that second workflow is that you only run the base model once on + your data, rather than once per epoch of training. So it's a lot faster & cheaper. + +An issue with that second workflow, though, is that it doesn't allow you to dynamically +modify the input data of your new model during training, which is required when doing +data augmentation, for instance. Transfer learning is typically used for tasks when +your new dataset has too little data to train a full-scale model from scratch, and in +such scenarios data augmentation is very important. So in what follows, we will focus + on the first workflow. + +Here's what the first workflow looks like in Keras: + +First, instantiate a base model with pre-trained weights. + +```python +base_model = keras.applications.Xception( + weights='imagenet', # Load weights pre-trained on ImageNet. + input_shape=(150, 150, 3), + include_top=False) # Do not include the ImageNet classifier at the top. +``` + +Then, freeze the base model. + +```python +base_model.trainable = False +``` + +Create a new model on top. + +```python +inputs = keras.Input(shape=(150, 150, 3)) +# We make sure that the base_model is running in inference mode here, +# by passing `training=False`. This is important for fine-tuning, as you will +# learn in a few paragraphs. +x = base_model(inputs, training=False) +# Convert features of shape `base_model.output_shape[1:]` to vectors +x = keras.layers.GlobalAveragePooling2D()(x) +# A Dense classifier with a single unit (binary classification) +outputs = keras.layers.Dense(1)(x) +model = keras.Model(inputs, outputs) +``` + +Train the model on new data. + +```python +model.compile(optimizer=keras.optimizers.Adam(), + loss=keras.losses.BinaryCrossentropy(from_logits=True), + metrics=[keras.metrics.BinaryAccuracy()]) +model.fit(new_dataset, epochs=20, callbacks=..., validation_data=...) +``` + +""" + +""" +## Fine-tuning + +Once your model has converged on the new data, you can try to unfreeze all or part of + the base model and retrain the whole model end-to-end with a very low learning rate. + +This is an optional last step that can potentially give you incremental improvements. + It could also potentially lead to quick overfitting -- keep that in mind. + +It is critical to only do this step *after* the model with frozen layers has been +trained to convergence. If you mix randomly-initialized trainable layers with +trainable layers that hold pre-trained features, the randomly-initialized layers will +cause very large gradient updates during training, which will destroy your pre-trained + features. + +It's also critical to use a very low learning rate at this stage, because +you are training a much larger model than in the first round of training, on a dataset + that is typically very small. +As a result, you are at risk of overfitting very quickly if you apply large weight + updates. Here, you only want to readapt the pretrained weights in an incremental way. + +This is how to implement fine-tuning of the whole base model: + +```python +# Unfreeze the base model +base_model.trainable = True + +# It's important to recompile your model after you make any changes +# to the `trainable` attribute of any inner layer, so that your changes +# are take into account +model.compile(optimizer=keras.optimizers.Adam(1e-5), # Very low learning rate + loss=keras.losses.BinaryCrossentropy(from_logits=True), + metrics=[keras.metrics.BinaryAccuracy()]) + +# Train end-to-end. Be careful to stop before you overfit! +model.fit(new_dataset, epochs=10, callbacks=..., validation_data=...) +``` + +**Important note about `compile()` and `trainable`** + +Calling `compile()` on a model is meant to "freeze" the behavior of that model. This + implies that the `trainable` +attribute values at the time the model is compiled should be preserved throughout the + lifetime of that model, +until `compile` is called again. Hence, if you change any `trainable` value, make sure + to call `compile()` again on your +model for your changes to be taken into account. + +**Important notes about `BatchNormalization` layer** + +Many image models contain `BatchNormalization` layers. That layer is a special case on + every imaginable count. Here are a few things to keep in mind. + +- `BatchNormalization` contains 2 non-trainable weights that get updated during +training. These are the variables tracking the mean and variance of the inputs. +- When you set `bn_layer.trainable = False`, the `BatchNormalization` layer will +run in inference mode, and will not update its mean & variance statistics. This is not +the case for other layers in general, as +[weight trainability & inference/training modes are two orthogonal concepts]( + https://keras.io/getting_started/faq/#whats-the-difference-between-the-training-argument-in-call-and-the-trainable-attribute). +But the two are tied in the case of the `BatchNormalization` layer. +- When you unfreeze a model for finetuning by setting `base_model.trainable=True` that +contains `BatchNormalization` layers, then all layers of the base model become +trainable along with `BatchNormalization` layers. It's a good idea to keep +`BatchNormalization` either frozen during fine-tuning, or running in inference mode, +so remember to set `layer.trainable = False` +on those layers specifically after unfreezing the outer model, or otherwise +call the model with `training=False` to keep it inference mode. + +You'll see this pattern in action in the end-to-end example at the end of this guide. +""" + +""" +## An end-to-end example: fine-tuning an image classification model on a cats vs. dogs dataset + +To solidify these concepts, let's walk you through a concrete end-to-end transfer +learning & fine-tuning example. We will load the Xception model, pre-trained on + ImageNet, and use it on the Kaggle "cats vs. dogs" classification dataset. +""" + +""" +### Getting the data + +First, let's fetch the cats vs. dogs dataset using TFDS. If you have your own dataset, +you'll probably want to use the utility +`keras.utils.image_dataset_from_directory` to generate similar labeled + dataset objects from a set of images on disk filed into class-specific folders. + +Transfer learning is most useful when working with very small datasets. To keep our +dataset small, we will use 40% of the original training data (25,000 images) for + training, 10% for validation, and 10% for testing. +""" + +tfds.disable_progress_bar() + +train_ds, validation_ds, test_ds = tfds.load( + "cats_vs_dogs", + # Reserve 10% for validation and 10% for test + split=["train[:40%]", "train[40%:50%]", "train[50%:60%]"], + as_supervised=True, # Include labels +) + +print(f"Number of training samples: {train_ds.cardinality()}") +print(f"Number of validation samples: {validation_ds.cardinality()}") +print(f"Number of test samples: {test_ds.cardinality()}") + +""" +These are the first 9 images in the training dataset -- as you can see, they're all +different sizes. +""" + +plt.figure(figsize=(10, 10)) +for i, (image, label) in enumerate(train_ds.take(9)): + ax = plt.subplot(3, 3, i + 1) + plt.imshow(image) + plt.title(int(label)) + plt.axis("off") + +""" +We can also see that label 1 is "dog" and label 0 is "cat". +""" + +""" +### Standardizing the data + +Our raw images have a variety of sizes. In addition, each pixel consists of 3 integer +values between 0 and 255 (RGB level values). This isn't a great fit for feeding a +neural network. We need to do 2 things: + +- Standardize to a fixed image size. We pick 150x150. +- Normalize pixel values between -1 and 1. We'll do this using a `Normalization` layer as +part of the model itself. + +In general, it's a good practice to develop models that take raw data as input, as +opposed to models that take already-preprocessed data. The reason being that, if your +model expects preprocessed data, any time you export your model to use it elsewhere +(in a web browser, in a mobile app), you'll need to reimplement the exact same +preprocessing pipeline. This gets very tricky very quickly. So we should do the least + possible amount of preprocessing before hitting the model. + +Here, we'll do image resizing in the data pipeline (because a deep neural network can +only process contiguous batches of data), and we'll do the input value scaling as part +of the model, when we create it. + +Let's resize images to 150x150: +""" + +resize_fn = keras.layers.Resizing(150, 150) + +train_ds = train_ds.map(lambda x, y: (resize_fn(x), y)) +validation_ds = validation_ds.map(lambda x, y: (resize_fn(x), y)) +test_ds = test_ds.map(lambda x, y: (resize_fn(x), y)) + +""" +### Using random data augmentation + +When you don't have a large image dataset, it's a good practice to artificially +introduce sample diversity by applying random yet realistic transformations to +the training images, such as random horizontal flipping or small random rotations. This +helps expose the model to different aspects of the training data while slowing down +overfitting. +""" + +augmentation_layers = [ + layers.RandomFlip("horizontal"), + layers.RandomRotation(0.1), +] + + +def data_augmentation(x): + for layer in augmentation_layers: + x = layer(x) + return x + + +train_ds = train_ds.map(lambda x, y: (data_augmentation(x), y)) + +""" +Let's batch the data and use prefetching to optimize loading speed. +""" + +from tensorflow import data as tf_data + +batch_size = 64 + +train_ds = train_ds.batch(batch_size).prefetch(tf_data.AUTOTUNE).cache() +validation_ds = ( + validation_ds.batch(batch_size).prefetch(tf_data.AUTOTUNE).cache() +) +test_ds = test_ds.batch(batch_size).prefetch(tf_data.AUTOTUNE).cache() + +""" +Let's visualize what the first image of the first batch looks like after various random + transformations: +""" + +for images, labels in train_ds.take(1): + plt.figure(figsize=(10, 10)) + first_image = images[0] + for i in range(9): + ax = plt.subplot(3, 3, i + 1) + augmented_image = data_augmentation(np.expand_dims(first_image, 0)) + plt.imshow(np.array(augmented_image[0]).astype("int32")) + plt.title(int(labels[0])) + plt.axis("off") + +""" +## Build a model + +Now let's built a model that follows the blueprint we've explained earlier. + +Note that: + +- We add a `Rescaling` layer to scale input values (initially in the `[0, 255]` + range) to the `[-1, 1]` range. +- We add a `Dropout` layer before the classification layer, for regularization. +- We make sure to pass `training=False` when calling the base model, so that +it runs in inference mode, so that batchnorm statistics don't get updated +even after we unfreeze the base model for fine-tuning. +""" + +base_model = keras.applications.Xception( + weights="imagenet", # Load weights pre-trained on ImageNet. + input_shape=(150, 150, 3), + include_top=False, +) # Do not include the ImageNet classifier at the top. + +# Freeze the base_model +base_model.trainable = False + +# Create new model on top +inputs = keras.Input(shape=(150, 150, 3)) + +# Pre-trained Xception weights requires that input be scaled +# from (0, 255) to a range of (-1., +1.), the rescaling layer +# outputs: `(inputs * scale) + offset` +scale_layer = keras.layers.Rescaling(scale=1 / 127.5, offset=-1) +x = scale_layer(inputs) + +# The base model contains batchnorm layers. We want to keep them in inference mode +# when we unfreeze the base model for fine-tuning, so we make sure that the +# base_model is running in inference mode here. +x = base_model(x, training=False) +x = keras.layers.GlobalAveragePooling2D()(x) +x = keras.layers.Dropout(0.2)(x) # Regularize with dropout +outputs = keras.layers.Dense(1)(x) +model = keras.Model(inputs, outputs) + +model.summary(show_trainable=True) + +""" +## Train the top layer +""" + +model.compile( + optimizer=keras.optimizers.Adam(), + loss=keras.losses.BinaryCrossentropy(from_logits=True), + metrics=[keras.metrics.BinaryAccuracy()], +) + +epochs = 2 +print("Fitting the top layer of the model") +model.fit(train_ds, epochs=epochs, validation_data=validation_ds) + +""" +## Do a round of fine-tuning of the entire model + +Finally, let's unfreeze the base model and train the entire model end-to-end with a low + learning rate. + +Importantly, although the base model becomes trainable, it is still running in +inference mode since we passed `training=False` when calling it when we built the +model. This means that the batch normalization layers inside won't update their batch +statistics. If they did, they would wreck havoc on the representations learned by the + model so far. +""" + +# Unfreeze the base_model. Note that it keeps running in inference mode +# since we passed `training=False` when calling it. This means that +# the batchnorm layers will not update their batch statistics. +# This prevents the batchnorm layers from undoing all the training +# we've done so far. +base_model.trainable = True +model.summary(show_trainable=True) + +model.compile( + optimizer=keras.optimizers.Adam(1e-5), # Low learning rate + loss=keras.losses.BinaryCrossentropy(from_logits=True), + metrics=[keras.metrics.BinaryAccuracy()], +) + +epochs = 1 +print("Fitting the end-to-end model") +model.fit(train_ds, epochs=epochs, validation_data=validation_ds) + +""" +After 10 epochs, fine-tuning gains us a nice improvement here. +Let's evaluate the model on the test dataset: +""" + +print("Test dataset evaluation") +model.evaluate(test_ds) diff --git a/guides/understanding_masking_and_padding.py b/guides/understanding_masking_and_padding.py new file mode 100644 index 000000000000..5e5cad177b55 --- /dev/null +++ b/guides/understanding_masking_and_padding.py @@ -0,0 +1,380 @@ +""" +Title: Understanding masking & padding +Authors: Scott Zhu, Francois Chollet +Date created: 2019/07/16 +Last modified: 2023/06/25 +Description: Complete guide to using mask-aware sequence layers in Keras. +Accelerator: None +""" + +""" +## Setup +""" +import numpy as np +import keras +from keras import ops +from keras import layers + +""" +## Introduction + +**Masking** is a way to tell sequence-processing layers that certain timesteps +in an input are missing, and thus should be skipped when processing the data. + +**Padding** is a special form of masking where the masked steps are at the start or +the end of a sequence. Padding comes from the need to encode sequence data into +contiguous batches: in order to make all sequences in a batch fit a given standard +length, it is necessary to pad or truncate some sequences. + +Let's take a close look. +""" + +""" +## Padding sequence data + +When processing sequence data, it is very common for individual samples to have +different lengths. Consider the following example (text tokenized as words): + +``` +[ + ["Hello", "world", "!"], + ["How", "are", "you", "doing", "today"], + ["The", "weather", "will", "be", "nice", "tomorrow"], +] +``` + +After vocabulary lookup, the data might be vectorized as integers, e.g.: + +``` +[ + [71, 1331, 4231] + [73, 8, 3215, 55, 927], + [83, 91, 1, 645, 1253, 927], +] +``` + +The data is a nested list where individual samples have length 3, 5, and 6, +respectively. Since the input data for a deep learning model must be a single tensor +(of shape e.g. `(batch_size, 6, vocab_size)` in this case), samples that are shorter +than the longest item need to be padded with some placeholder value (alternatively, +one might also truncate long samples before padding short samples). + +Keras provides a utility function to truncate and pad Python lists to a common length: +`keras.utils.pad_sequences`. +""" + +raw_inputs = [ + [711, 632, 71], + [73, 8, 3215, 55, 927], + [83, 91, 1, 645, 1253, 927], +] + +# By default, this will pad using 0s; it is configurable via the +# "value" parameter. +# Note that you could use "pre" padding (at the beginning) or +# "post" padding (at the end). +# We recommend using "post" padding when working with RNN layers +# (in order to be able to use the +# CuDNN implementation of the layers). +padded_inputs = keras.utils.pad_sequences(raw_inputs, padding="post") +print(padded_inputs) + + +""" +## Masking + +Now that all samples have a uniform length, the model must be informed that some part +of the data is actually padding and should be ignored. That mechanism is **masking**. + +There are three ways to introduce input masks in Keras models: + +- Add a `keras.layers.Masking` layer. +- Configure a `keras.layers.Embedding` layer with `mask_zero=True`. +- Pass a `mask` argument manually when calling layers that support this argument (e.g. +RNN layers). +""" + +""" +## Mask-generating layers: `Embedding` and `Masking` + +Under the hood, these layers will create a mask tensor (2D tensor with shape `(batch, +sequence_length)`), and attach it to the tensor output returned by the `Masking` or +`Embedding` layer. +""" + +embedding = layers.Embedding(input_dim=5000, output_dim=16, mask_zero=True) +masked_output = embedding(padded_inputs) + +print(masked_output._keras_mask) + +masking_layer = layers.Masking() +# Simulate the embedding lookup by expanding the 2D input to 3D, +# with embedding dimension of 10. +unmasked_embedding = ops.cast( + ops.tile(ops.expand_dims(padded_inputs, axis=-1), [1, 1, 10]), + dtype="float32", +) + +masked_embedding = masking_layer(unmasked_embedding) +print(masked_embedding._keras_mask) + +""" +As you can see from the printed result, the mask is a 2D boolean tensor with shape +`(batch_size, sequence_length)`, where each individual `False` entry indicates that +the corresponding timestep should be ignored during processing. +""" + +""" +## Mask propagation in the Functional API and Sequential API + +When using the Functional API or the Sequential API, a mask generated by an `Embedding` +or `Masking` layer will be propagated through the network for any layer that is +capable of using them (for example, RNN layers). Keras will automatically fetch the +mask corresponding to an input and pass it to any layer that knows how to use it. + +For instance, in the following Sequential model, the `LSTM` layer will automatically +receive a mask, which means it will ignore padded values: +""" + +model = keras.Sequential( + [ + layers.Embedding(input_dim=5000, output_dim=16, mask_zero=True), + layers.LSTM(32), + ] +) + +""" +This is also the case for the following Functional API model: +""" + +inputs = keras.Input(shape=(None,), dtype="int32") +x = layers.Embedding(input_dim=5000, output_dim=16, mask_zero=True)(inputs) +outputs = layers.LSTM(32)(x) + +model = keras.Model(inputs, outputs) + +""" +## Passing mask tensors directly to layers +""" + +""" +Layers that can handle masks (such as the `LSTM` layer) have a `mask` argument in their +`__call__` method. + +Meanwhile, layers that produce a mask (e.g. `Embedding`) expose a `compute_mask(input, +previous_mask)` method which you can call. + +Thus, you can pass the output of the `compute_mask()` method of a mask-producing layer +to the `__call__` method of a mask-consuming layer, like this: + +""" + + +class MyLayer(layers.Layer): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.embedding = layers.Embedding( + input_dim=5000, output_dim=16, mask_zero=True + ) + self.lstm = layers.LSTM(32) + + def call(self, inputs): + x = self.embedding(inputs) + # Note that you could also prepare a `mask` tensor manually. + # It only needs to be a boolean tensor + # with the right shape, i.e. (batch_size, timesteps). + mask = self.embedding.compute_mask(inputs) + output = self.lstm( + x, mask=mask + ) # The layer will ignore the masked values + return output + + +layer = MyLayer() +x = np.random.random((32, 10)) * 100 +x = x.astype("int32") +layer(x) + +""" +## Supporting masking in your custom layers +""" + +""" +Sometimes, you may need to write layers that generate a mask (like `Embedding`), or +layers that need to modify the current mask. + +For instance, any layer that produces a tensor with a different time dimension than its +input, such as a `Concatenate` layer that concatenates on the time dimension, will +need to modify the current mask so that downstream layers will be able to properly +take masked timesteps into account. + +To do this, your layer should implement the `layer.compute_mask()` method, which +produces a new mask given the input and the current mask. + +Here is an example of a `TemporalSplit` layer that needs to modify the current mask. +""" + + +class TemporalSplit(keras.layers.Layer): + """Split the input tensor into 2 tensors along the time dimension.""" + + def call(self, inputs): + # Expect the input to be 3D and mask to be 2D, split the input tensor into 2 + # subtensors along the time axis (axis 1). + return ops.split(inputs, 2, axis=1) + + def compute_mask(self, inputs, mask=None): + # Also split the mask into 2 if it presents. + if mask is None: + return None + return ops.split(mask, 2, axis=1) + + +first_half, second_half = TemporalSplit()(masked_embedding) +print(first_half._keras_mask) +print(second_half._keras_mask) + +""" +Here is another example of a `CustomEmbedding` layer that is capable of generating a +mask from input values: +""" + + +class CustomEmbedding(keras.layers.Layer): + def __init__(self, input_dim, output_dim, mask_zero=False, **kwargs): + super().__init__(**kwargs) + self.input_dim = input_dim + self.output_dim = output_dim + self.mask_zero = mask_zero + + def build(self, input_shape): + self.embeddings = self.add_weight( + shape=(self.input_dim, self.output_dim), + initializer="random_normal", + dtype="float32", + ) + + def call(self, inputs): + inputs = ops.cast(inputs, "int32") + return ops.take(self.embeddings, inputs) + + def compute_mask(self, inputs, mask=None): + if not self.mask_zero: + return None + return ops.not_equal(inputs, 0) + + +layer = CustomEmbedding(10, 32, mask_zero=True) +x = np.random.random((3, 10)) * 9 +x = x.astype("int32") + +y = layer(x) +mask = layer.compute_mask(x) + +print(mask) + +""" +Note: For more details about format limitations related to masking, see the +[serialization guide](/guides/serialization_and_saving). +""" + +""" +## Opting-in to mask propagation on compatible layers + +Most layers don't modify the time dimension, so don't need to modify the current mask. +However, they may still want to be able to **propagate** the current mask, unchanged, +to the next layer. **This is an opt-in behavior.** By default, a custom layer will +destroy the current mask (since the framework has no way to tell whether propagating +the mask is safe to do). + +If you have a custom layer that does not modify the time dimension, and if you want it +to be able to propagate the current input mask, you should set `self.supports_masking += True` in the layer constructor. In this case, the default behavior of +`compute_mask()` is to just pass the current mask through. + +Here's an example of a layer that is whitelisted for mask propagation: + +""" + + +class MyActivation(keras.layers.Layer): + def __init__(self, **kwargs): + super().__init__(**kwargs) + # Signal that the layer is safe for mask propagation + self.supports_masking = True + + def call(self, inputs): + return ops.relu(inputs) + + +""" +You can now use this custom layer in-between a mask-generating layer (like `Embedding`) +and a mask-consuming layer (like `LSTM`), and it will pass the mask along so that it +reaches the mask-consuming layer. +""" + +inputs = keras.Input(shape=(None,), dtype="int32") +x = layers.Embedding(input_dim=5000, output_dim=16, mask_zero=True)(inputs) +x = MyActivation()(x) # Will pass the mask along +print("Mask found:", x._keras_mask) +outputs = layers.LSTM(32)(x) # Will receive the mask + +model = keras.Model(inputs, outputs) +y = model(np.random.randint(0, 5000, size=(32, 100))) + +""" +## Writing layers that need mask information + +Some layers are mask *consumers*: they accept a `mask` argument in `call` and use it to +determine whether to skip certain time steps. + +To write such a layer, you can simply add a `mask=None` argument in your `call` +signature. The mask associated with the inputs will be passed to your layer whenever +it is available. + +Here's a simple example below: a layer that computes a softmax over the time dimension +(axis 1) of an input sequence, while discarding masked timesteps. +""" + + +class TemporalSoftmax(keras.layers.Layer): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.supports_masking = True + + def call(self, inputs, mask=None): + assert mask is not None + broadcast_float_mask = ops.expand_dims(ops.cast(mask, "float32"), -1) + inputs_exp = ops.exp(inputs) * broadcast_float_mask + inputs_sum = ops.sum( + inputs_exp * broadcast_float_mask, axis=-1, keepdims=True + ) + return inputs_exp / inputs_sum + + +inputs = keras.Input(shape=(None,), dtype="int32") +x = layers.Embedding(input_dim=10, output_dim=32, mask_zero=True)(inputs) +x = layers.Dense(1)(x) +outputs = TemporalSoftmax()(x) + +model = keras.Model(inputs, outputs) +y = model(np.random.randint(0, 10, size=(32, 100))) + +""" +## Summary + +That is all you need to know about padding & masking in Keras. To recap: + +- "Masking" is how layers are able to know when to skip / ignore certain timesteps in +sequence inputs. +- Some layers are mask-generators: `Embedding` can generate a mask from input values +(if `mask_zero=True`), and so can the `Masking` layer. +- Some layers are mask-consumers: they expose a `mask` argument in their `__call__` +method. This is the case for RNN layers. +- In the Functional API and Sequential API, mask information is propagated +automatically. +- When using layers in a standalone way, you can pass the `mask` arguments to layers +manually. +- You can easily write layers that modify the current mask, that generate a new mask, +or that consume the mask associated with the inputs. +""" diff --git a/guides/writing_a_custom_training_loop_in_jax.py b/guides/writing_a_custom_training_loop_in_jax.py new file mode 100644 index 000000000000..614376386708 --- /dev/null +++ b/guides/writing_a_custom_training_loop_in_jax.py @@ -0,0 +1,528 @@ +""" +Title: Writing a training loop from scratch in JAX +Author: [fchollet](https://twitter.com/fchollet) +Date created: 2023/06/25 +Last modified: 2023/06/25 +Description: Writing low-level training & evaluation loops in JAX. +Accelerator: None +""" + +""" +## Setup +""" + +import os + +# This guide can only be run with the jax backend. +os.environ["KERAS_BACKEND"] = "jax" + +import jax + +# We import TF so we can use tf.data. +import tensorflow as tf +import keras +import numpy as np + +""" +## Introduction + +Keras provides default training and evaluation loops, `fit()` and `evaluate()`. +Their usage is covered in the guide +[Training & evaluation with the built-in methods](https://keras.io/guides/training_with_built_in_methods/). + +If you want to customize the learning algorithm of your model while still leveraging +the convenience of `fit()` +(for instance, to train a GAN using `fit()`), you can subclass the `Model` class and +implement your own `train_step()` method, which +is called repeatedly during `fit()`. + +Now, if you want very low-level control over training & evaluation, you should write +your own training & evaluation loops from scratch. This is what this guide is about. +""" + +""" +## A first end-to-end example + +To write a custom training loop, we need the following ingredients: + +- A model to train, of course. +- An optimizer. You could either use an optimizer from `keras.optimizers`, or +one from the `optax` package. +- A loss function. +- A dataset. The standard in the JAX ecosystem is to load data via `tf.data`, +so that's what we'll use. + +Let's line them up. + +First, let's get the model and the MNIST dataset: +""" + + +def get_model(): + inputs = keras.Input(shape=(784,), name="digits") + x1 = keras.layers.Dense(64, activation="relu")(inputs) + x2 = keras.layers.Dense(64, activation="relu")(x1) + outputs = keras.layers.Dense(10, name="predictions")(x2) + model = keras.Model(inputs=inputs, outputs=outputs) + return model + + +model = get_model() + +# Prepare the training dataset. +batch_size = 32 +(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data() +x_train = np.reshape(x_train, (-1, 784)).astype("float32") +x_test = np.reshape(x_test, (-1, 784)).astype("float32") +y_train = keras.utils.to_categorical(y_train) +y_test = keras.utils.to_categorical(y_test) + +# Reserve 10,000 samples for validation. +x_val = x_train[-10000:] +y_val = y_train[-10000:] +x_train = x_train[:-10000] +y_train = y_train[:-10000] + +# Prepare the training dataset. +train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)) +train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size) + +# Prepare the validation dataset. +val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val)) +val_dataset = val_dataset.batch(batch_size) + +""" +Next, here's the loss function and the optimizer. +We'll use a Keras optimizer in this case. +""" + +# Instantiate a loss function. +loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True) + +# Instantiate an optimizer. +optimizer = keras.optimizers.Adam(learning_rate=1e-3) + +""" +### Getting gradients in JAX + +Let's train our model using mini-batch gradient with a custom training loop. + +In JAX, gradients are computed via *metaprogramming*: you call the `jax.grad` (or +`jax.value_and_grad` on a function in order to create a gradient-computing function +for that first function. + +So the first thing we need is a function that returns the loss value. +That's the function we'll use to generate the gradient function. Something like this: + +```python +def compute_loss(x, y): + ... + return loss +``` + +Once you have such a function, you can compute gradients via metaprogramming as such: + +```python +grad_fn = jax.grad(compute_loss) +grads = grad_fn(x, y) +``` + +Typically, you don't just want to get the gradient values, you also want to get +the loss value. You can do this by using `jax.value_and_grad` instead of `jax.grad`: + +```python +grad_fn = jax.value_and_grad(compute_loss) +loss, grads = grad_fn(x, y) +``` + +### JAX computation is purely stateless + +In JAX, everything must be a stateless function -- so our loss computation function +must be stateless as well. That means that all Keras variables (e.g. weight tensors) +must be passed as function inputs, and any variable that has been updated during the +forward pass must be returned as function output. The function have no side effect. + +During the forward pass, the non-trainable variables of a Keras model might get +updated. These variables could be, for instance, RNG seed state variables or +BatchNormalization statistics. We're going to need to return those. So we need +something like this: + +```python +def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y): + ... + return loss, non_trainable_variables +``` + +Once you have such a function, you can get the gradient function by +specifying `hax_aux` in `value_and_grad`: it tells JAX that the loss +computation function returns more outputs than just the loss. Note that the loss +should always be the first output. + +```python +grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True) +(loss, non_trainable_variables), grads = grad_fn( + trainable_variables, non_trainable_variables, x, y +) +``` + +Now that we have established the basics, +let's implement this `compute_loss_and_updates` function. +Keras models have a `stateless_call` method which will come in handy here. +It works just like `model.__call__`, but it requires you to explicitly +pass the value of all the variables in the model, and it returns not just +the `__call__` outputs but also the (potentially updated) non-trainable +variables. +""" + + +def compute_loss_and_updates( + trainable_variables, non_trainable_variables, x, y +): + y_pred, non_trainable_variables = model.stateless_call( + trainable_variables, non_trainable_variables, x + ) + loss = loss_fn(y, y_pred) + return loss, non_trainable_variables + + +""" +Let's get the gradient function: +""" + +grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True) + +""" +### The training step function + +Next, let's implement the end-to-end training step, the function +that will both run the forward pass, compute the loss, compute the gradients, +but also use the optimizer to update the trainable variables. This function +also needs to be stateless, so it will get as input a `state` tuple that +includes every state element we're going to use: + +- `trainable_variables` and `non_trainable_variables`: the model's variables. +- `optimizer_variables`: the optimizer's state variables, +such as momentum accumulators. + +To update the trainable variables, we use the optimizer's stateless method +`stateless_apply`. It's equivalent to `optimizer.apply()`, but it requires +always passing `trainable_variables` and `optimizer_variables`. It returns +both the updated trainable variables and the updated optimizer_variables. +""" + + +def train_step(state, data): + trainable_variables, non_trainable_variables, optimizer_variables = state + x, y = data + (loss, non_trainable_variables), grads = grad_fn( + trainable_variables, non_trainable_variables, x, y + ) + trainable_variables, optimizer_variables = optimizer.stateless_apply( + optimizer_variables, grads, trainable_variables + ) + # Return updated state + return loss, ( + trainable_variables, + non_trainable_variables, + optimizer_variables, + ) + + +""" +### Make it fast with `jax.jit` + +By default, JAX operations run eagerly, +just like in TensorFlow eager mode and PyTorch eager mode. +And just like TensorFlow eager mode and PyTorch eager mode, it's pretty slow +-- eager mode is better used as a debugging environment, not as a way to do +any actual work. So let's make our `train_step` fast by compiling it. + +When you have a stateless JAX function, you can compile it to XLA via the +`@jax.jit` decorator. It will get traced during its first execution, and in +subsequent executions you will be executing the traced graph (this is just +like `@tf.function(jit_compile=True)`. Let's try it: +""" + + +@jax.jit +def train_step(state, data): + trainable_variables, non_trainable_variables, optimizer_variables = state + x, y = data + (loss, non_trainable_variables), grads = grad_fn( + trainable_variables, non_trainable_variables, x, y + ) + trainable_variables, optimizer_variables = optimizer.stateless_apply( + optimizer_variables, grads, trainable_variables + ) + # Return updated state + return loss, ( + trainable_variables, + non_trainable_variables, + optimizer_variables, + ) + + +""" +We're now ready to train our model. The training loop itself +is trivial: we just repeatedly call `loss, state = train_step(state, data)`. + +Note: + +- We convert the TF tensors yielded by the `tf.data.Dataset` to NumPy +before passing them to our JAX function. +- All variables must be built beforehand: +the model must be built and the optimizer must be built. Since we're using a +Functional API model, it's already built, but if it were a subclassed model +you'd need to call it on a batch of data to build it. +""" + +# Build optimizer variables. +optimizer.build(model.trainable_variables) + +trainable_variables = model.trainable_variables +non_trainable_variables = model.non_trainable_variables +optimizer_variables = optimizer.variables +state = trainable_variables, non_trainable_variables, optimizer_variables + +# Training loop +for step, data in enumerate(train_dataset): + data = (data[0].numpy(), data[1].numpy()) + loss, state = train_step(state, data) + # Log every 100 batches. + if step % 100 == 0: + print(f"Training loss (for 1 batch) at step {step}: {float(loss):.4f}") + print(f"Seen so far: {(step + 1) * batch_size} samples") + +""" +A key thing to notice here is that the loop is entirely stateless -- the variables +attached to the model (`model.weights`) are never getting updated during the loop. +Their new values are only stored in the `state` tuple. That means that at some point, +before saving the model, you should be attaching the new variable values back to the model. + +Just call `variable.assign(new_value)` on each model variable you want to update: +""" + +trainable_variables, non_trainable_variables, optimizer_variables = state +for variable, value in zip(model.trainable_variables, trainable_variables): + variable.assign(value) +for variable, value in zip( + model.non_trainable_variables, non_trainable_variables +): + variable.assign(value) + +""" +## Low-level handling of metrics + +Let's add metrics monitoring to this basic training loop. + +You can readily reuse built-in Keras metrics (or custom ones you wrote) in such training +loops written from scratch. Here's the flow: + +- Instantiate the metric at the start of the loop +- Include `metric_variables` in the `train_step` arguments +and `compute_loss_and_updates` arguments. +- Call `metric.stateless_update_state()` in the `compute_loss_and_updates` function. +It's equivalent to `update_state()` -- only stateless. +- When you need to display the current value of the metric, outside the `train_step` +(in the eager scope), attach the new metric variable values to the metric object +and vall `metric.result()`. +- Call `metric.reset_state()` when you need to clear the state of the metric +(typically at the end of an epoch) + +Let's use this knowledge to compute `CategoricalAccuracy` on training and +validation data at the end of training: +""" + +# Get a fresh model +model = get_model() + +# Instantiate an optimizer to train the model. +optimizer = keras.optimizers.Adam(learning_rate=1e-3) +# Instantiate a loss function. +loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True) + +# Prepare the metrics. +train_acc_metric = keras.metrics.CategoricalAccuracy() +val_acc_metric = keras.metrics.CategoricalAccuracy() + + +def compute_loss_and_updates( + trainable_variables, non_trainable_variables, metric_variables, x, y +): + y_pred, non_trainable_variables = model.stateless_call( + trainable_variables, non_trainable_variables, x + ) + loss = loss_fn(y, y_pred) + metric_variables = train_acc_metric.stateless_update_state( + metric_variables, y, y_pred + ) + return loss, (non_trainable_variables, metric_variables) + + +grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True) + + +@jax.jit +def train_step(state, data): + ( + trainable_variables, + non_trainable_variables, + optimizer_variables, + metric_variables, + ) = state + x, y = data + (loss, (non_trainable_variables, metric_variables)), grads = grad_fn( + trainable_variables, non_trainable_variables, metric_variables, x, y + ) + trainable_variables, optimizer_variables = optimizer.stateless_apply( + optimizer_variables, grads, trainable_variables + ) + # Return updated state + return loss, ( + trainable_variables, + non_trainable_variables, + optimizer_variables, + metric_variables, + ) + + +""" +We'll also prepare an evaluation step function: +""" + + +@jax.jit +def eval_step(state, data): + trainable_variables, non_trainable_variables, metric_variables = state + x, y = data + y_pred, non_trainable_variables = model.stateless_call( + trainable_variables, non_trainable_variables, x + ) + loss = loss_fn(y, y_pred) + metric_variables = val_acc_metric.stateless_update_state( + metric_variables, y, y_pred + ) + return loss, ( + trainable_variables, + non_trainable_variables, + metric_variables, + ) + + +""" +Here are our loops: +""" + +# Build optimizer variables. +optimizer.build(model.trainable_variables) + +trainable_variables = model.trainable_variables +non_trainable_variables = model.non_trainable_variables +optimizer_variables = optimizer.variables +metric_variables = train_acc_metric.variables +state = ( + trainable_variables, + non_trainable_variables, + optimizer_variables, + metric_variables, +) + +# Training loop +for step, data in enumerate(train_dataset): + data = (data[0].numpy(), data[1].numpy()) + loss, state = train_step(state, data) + # Log every 100 batches. + if step % 100 == 0: + print(f"Training loss (for 1 batch) at step {step}: {float(loss):.4f}") + _, _, _, metric_variables = state + for variable, value in zip( + train_acc_metric.variables, metric_variables + ): + variable.assign(value) + print(f"Training accuracy: {train_acc_metric.result()}") + print(f"Seen so far: {(step + 1) * batch_size} samples") + +metric_variables = val_acc_metric.variables +( + trainable_variables, + non_trainable_variables, + optimizer_variables, + metric_variables, +) = state +state = trainable_variables, non_trainable_variables, metric_variables + +# Eval loop +for step, data in enumerate(val_dataset): + data = (data[0].numpy(), data[1].numpy()) + loss, state = eval_step(state, data) + # Log every 100 batches. + if step % 100 == 0: + print( + f"Validation loss (for 1 batch) at step {step}: {float(loss):.4f}" + ) + _, _, metric_variables = state + for variable, value in zip(val_acc_metric.variables, metric_variables): + variable.assign(value) + print(f"Validation accuracy: {val_acc_metric.result()}") + print(f"Seen so far: {(step + 1) * batch_size} samples") + +""" +## Low-level handling of losses tracked by the model + +Layers & models recursively track any losses created during the forward pass +by layers that call `self.add_loss(value)`. The resulting list of scalar loss +values are available via the property `model.losses` +at the end of the forward pass. + +If you want to be using these loss components, you should sum them +and add them to the main loss in your training step. + +Consider this layer, that creates an activity regularization loss: +""" + + +class ActivityRegularizationLayer(keras.layers.Layer): + def call(self, inputs): + self.add_loss(1e-2 * jax.numpy.sum(inputs)) + return inputs + + +""" +Let's build a really simple model that uses it: +""" + +inputs = keras.Input(shape=(784,), name="digits") +x = keras.layers.Dense(64, activation="relu")(inputs) +# Insert activity regularization as a layer +x = ActivityRegularizationLayer()(x) +x = keras.layers.Dense(64, activation="relu")(x) +outputs = keras.layers.Dense(10, name="predictions")(x) + +model = keras.Model(inputs=inputs, outputs=outputs) + +""" +Here's what our `compute_loss_and_updates` function should look like now: + +- Pass `return_losses=True` to `model.stateless_call()`. +- Sum the resulting `losses` and add them to the main loss. +""" + + +def compute_loss_and_updates( + trainable_variables, non_trainable_variables, metric_variables, x, y +): + y_pred, non_trainable_variables, losses = model.stateless_call( + trainable_variables, non_trainable_variables, x, return_losses=True + ) + loss = loss_fn(y, y_pred) + if losses: + loss += jax.numpy.sum(losses) + metric_variables = train_acc_metric.stateless_update_state( + metric_variables, y, y_pred + ) + return loss, non_trainable_variables, metric_variables + + +""" +That's it! +""" diff --git a/guides/writing_a_custom_training_loop_in_tensorflow.py b/guides/writing_a_custom_training_loop_in_tensorflow.py new file mode 100644 index 000000000000..0e55a63e50db --- /dev/null +++ b/guides/writing_a_custom_training_loop_in_tensorflow.py @@ -0,0 +1,532 @@ +""" +Title: Writing a training loop from scratch in TensorFlow +Author: [fchollet](https://twitter.com/fchollet) +Date created: 2019/03/01 +Last modified: 2023/06/25 +Description: Writing low-level training & evaluation loops in TensorFlow. +Accelerator: None +""" + +""" +## Setup +""" + +import time +import os + +# This guide can only be run with the TensorFlow backend. +os.environ["KERAS_BACKEND"] = "tensorflow" + +import tensorflow as tf +import keras +import numpy as np + +""" +## Introduction + +Keras provides default training and evaluation loops, `fit()` and `evaluate()`. +Their usage is covered in the guide +[Training & evaluation with the built-in methods](https://keras.io/guides/training_with_built_in_methods/). + +If you want to customize the learning algorithm of your model while still leveraging +the convenience of `fit()` +(for instance, to train a GAN using `fit()`), you can subclass the `Model` class and +implement your own `train_step()` method, which +is called repeatedly during `fit()`. + +Now, if you want very low-level control over training & evaluation, you should write +your own training & evaluation loops from scratch. This is what this guide is about. +""" + +""" +## A first end-to-end example + +Let's consider a simple MNIST model: +""" + + +def get_model(): + inputs = keras.Input(shape=(784,), name="digits") + x1 = keras.layers.Dense(64, activation="relu")(inputs) + x2 = keras.layers.Dense(64, activation="relu")(x1) + outputs = keras.layers.Dense(10, name="predictions")(x2) + model = keras.Model(inputs=inputs, outputs=outputs) + return model + + +model = get_model() + +""" +Let's train it using mini-batch gradient with a custom training loop. + +First, we're going to need an optimizer, a loss function, and a dataset: +""" + +# Instantiate an optimizer. +optimizer = keras.optimizers.Adam(learning_rate=1e-3) +# Instantiate a loss function. +loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True) + +# Prepare the training dataset. +batch_size = 32 +(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data() +x_train = np.reshape(x_train, (-1, 784)) +x_test = np.reshape(x_test, (-1, 784)) + +# Reserve 10,000 samples for validation. +x_val = x_train[-10000:] +y_val = y_train[-10000:] +x_train = x_train[:-10000] +y_train = y_train[:-10000] + +# Prepare the training dataset. +train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)) +train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size) + +# Prepare the validation dataset. +val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val)) +val_dataset = val_dataset.batch(batch_size) + +""" +Calling a model inside a `GradientTape` scope enables you to retrieve the gradients of +the trainable weights of the layer with respect to a loss value. Using an optimizer +instance, you can use these gradients to update these variables (which you can +retrieve using `model.trainable_weights`). + +Here's our training loop, step by step: + +- We open a `for` loop that iterates over epochs +- For each epoch, we open a `for` loop that iterates over the dataset, in batches +- For each batch, we open a `GradientTape()` scope +- Inside this scope, we call the model (forward pass) and compute the loss +- Outside the scope, we retrieve the gradients of the weights +of the model with regard to the loss +- Finally, we use the optimizer to update the weights of the model based on the +gradients +""" + +epochs = 3 +for epoch in range(epochs): + print(f"\nStart of epoch {epoch}") + + # Iterate over the batches of the dataset. + for step, (x_batch_train, y_batch_train) in enumerate(train_dataset): + # Open a GradientTape to record the operations run + # during the forward pass, which enables auto-differentiation. + with tf.GradientTape() as tape: + # Run the forward pass of the layer. + # The operations that the layer applies + # to its inputs are going to be recorded + # on the GradientTape. + logits = model( + x_batch_train, training=True + ) # Logits for this minibatch + + # Compute the loss value for this minibatch. + loss_value = loss_fn(y_batch_train, logits) + + # Use the gradient tape to automatically retrieve + # the gradients of the trainable variables with respect to the loss. + grads = tape.gradient(loss_value, model.trainable_weights) + + # Run one step of gradient descent by updating + # the value of the variables to minimize the loss. + optimizer.apply(grads, model.trainable_weights) + + # Log every 100 batches. + if step % 100 == 0: + print( + f"Training loss (for 1 batch) at step {step}: {float(loss_value):.4f}" + ) + print(f"Seen so far: {(step + 1) * batch_size} samples") + +""" +## Low-level handling of metrics + +Let's add metrics monitoring to this basic loop. + +You can readily reuse the built-in metrics (or custom ones you wrote) in such training +loops written from scratch. Here's the flow: + +- Instantiate the metric at the start of the loop +- Call `metric.update_state()` after each batch +- Call `metric.result()` when you need to display the current value of the metric +- Call `metric.reset_state()` when you need to clear the state of the metric +(typically at the end of an epoch) + +Let's use this knowledge to compute `SparseCategoricalAccuracy` on training and +validation data at the end of each epoch: +""" + +# Get a fresh model +model = get_model() + +# Instantiate an optimizer to train the model. +optimizer = keras.optimizers.Adam(learning_rate=1e-3) +# Instantiate a loss function. +loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True) + +# Prepare the metrics. +train_acc_metric = keras.metrics.SparseCategoricalAccuracy() +val_acc_metric = keras.metrics.SparseCategoricalAccuracy() + +""" +Here's our training & evaluation loop: +""" + +epochs = 2 +for epoch in range(epochs): + print(f"\nStart of epoch {epoch}") + start_time = time.time() + + # Iterate over the batches of the dataset. + for step, (x_batch_train, y_batch_train) in enumerate(train_dataset): + with tf.GradientTape() as tape: + logits = model(x_batch_train, training=True) + loss_value = loss_fn(y_batch_train, logits) + grads = tape.gradient(loss_value, model.trainable_weights) + optimizer.apply(grads, model.trainable_weights) + + # Update training metric. + train_acc_metric.update_state(y_batch_train, logits) + + # Log every 100 batches. + if step % 100 == 0: + print( + f"Training loss (for 1 batch) at step {step}: {float(loss_value):.4f}" + ) + print(f"Seen so far: {(step + 1) * batch_size} samples") + + # Display metrics at the end of each epoch. + train_acc = train_acc_metric.result() + print(f"Training acc over epoch: {float(train_acc):.4f}") + + # Reset training metrics at the end of each epoch + train_acc_metric.reset_state() + + # Run a validation loop at the end of each epoch. + for x_batch_val, y_batch_val in val_dataset: + val_logits = model(x_batch_val, training=False) + # Update val metrics + val_acc_metric.update_state(y_batch_val, val_logits) + val_acc = val_acc_metric.result() + val_acc_metric.reset_state() + print(f"Validation acc: {float(val_acc):.4f}") + print(f"Time taken: {time.time() - start_time:.2f}s") + +""" +## Speeding-up your training step with `tf.function` + +The default runtime in TensorFlow is eager execution. +As such, our training loop above executes eagerly. + +This is great for debugging, but graph compilation has a definite performance +advantage. Describing your computation as a static graph enables the framework +to apply global performance optimizations. This is impossible when +the framework is constrained to greedily execute one operation after another, +with no knowledge of what comes next. + +You can compile into a static graph any function that takes tensors as input. +Just add a `@tf.function` decorator on it, like this: +""" + + +@tf.function +def train_step(x, y): + with tf.GradientTape() as tape: + logits = model(x, training=True) + loss_value = loss_fn(y, logits) + grads = tape.gradient(loss_value, model.trainable_weights) + optimizer.apply(grads, model.trainable_weights) + train_acc_metric.update_state(y, logits) + return loss_value + + +""" +Let's do the same with the evaluation step: +""" + + +@tf.function +def test_step(x, y): + val_logits = model(x, training=False) + val_acc_metric.update_state(y, val_logits) + + +""" +Now, let's re-run our training loop with this compiled training step: +""" + +epochs = 2 +for epoch in range(epochs): + print(f"\nStart of epoch {epoch}") + start_time = time.time() + + # Iterate over the batches of the dataset. + for step, (x_batch_train, y_batch_train) in enumerate(train_dataset): + loss_value = train_step(x_batch_train, y_batch_train) + + # Log every 100 batches. + if step % 100 == 0: + print( + f"Training loss (for 1 batch) at step {step}: {float(loss_value):.4f}" + ) + print(f"Seen so far: {(step + 1) * batch_size} samples") + + # Display metrics at the end of each epoch. + train_acc = train_acc_metric.result() + print(f"Training acc over epoch: {float(train_acc):.4f}") + + # Reset training metrics at the end of each epoch + train_acc_metric.reset_state() + + # Run a validation loop at the end of each epoch. + for x_batch_val, y_batch_val in val_dataset: + test_step(x_batch_val, y_batch_val) + + val_acc = val_acc_metric.result() + val_acc_metric.reset_state() + print(f"Validation acc: {float(val_acc):.4f}") + print(f"Time taken: {time.time() - start_time:.2f}s") + +""" +Much faster, isn't it? +""" + +""" +## Low-level handling of losses tracked by the model + +Layers & models recursively track any losses created during the forward pass +by layers that call `self.add_loss(value)`. The resulting list of scalar loss +values are available via the property `model.losses` +at the end of the forward pass. + +If you want to be using these loss components, you should sum them +and add them to the main loss in your training step. + +Consider this layer, that creates an activity regularization loss: + +""" + + +class ActivityRegularizationLayer(keras.layers.Layer): + def call(self, inputs): + self.add_loss(1e-2 * tf.reduce_sum(inputs)) + return inputs + + +""" +Let's build a really simple model that uses it: +""" + +inputs = keras.Input(shape=(784,), name="digits") +x = keras.layers.Dense(64, activation="relu")(inputs) +# Insert activity regularization as a layer +x = ActivityRegularizationLayer()(x) +x = keras.layers.Dense(64, activation="relu")(x) +outputs = keras.layers.Dense(10, name="predictions")(x) + +model = keras.Model(inputs=inputs, outputs=outputs) + +""" +Here's what our training step should look like now: +""" + + +@tf.function +def train_step(x, y): + with tf.GradientTape() as tape: + logits = model(x, training=True) + loss_value = loss_fn(y, logits) + # Add any extra losses created during the forward pass. + loss_value += sum(model.losses) + grads = tape.gradient(loss_value, model.trainable_weights) + optimizer.apply(grads, model.trainable_weights) + train_acc_metric.update_state(y, logits) + return loss_value + + +""" +## Summary + +Now you know everything there is to know about using built-in training loops and +writing your own from scratch. + +To conclude, here's a simple end-to-end example that ties together everything +you've learned in this guide: a DCGAN trained on MNIST digits. +""" + +""" +## End-to-end example: a GAN training loop from scratch + +You may be familiar with Generative Adversarial Networks (GANs). GANs can generate new +images that look almost real, by learning the latent distribution of a training +dataset of images (the "latent space" of the images). + +A GAN is made of two parts: a "generator" model that maps points in the latent +space to points in image space, a "discriminator" model, a classifier +that can tell the difference between real images (from the training dataset) +and fake images (the output of the generator network). + +A GAN training loop looks like this: + +1) Train the discriminator. +- Sample a batch of random points in the latent space. +- Turn the points into fake images via the "generator" model. +- Get a batch of real images and combine them with the generated images. +- Train the "discriminator" model to classify generated vs. real images. + +2) Train the generator. +- Sample random points in the latent space. +- Turn the points into fake images via the "generator" network. +- Get a batch of real images and combine them with the generated images. +- Train the "generator" model to "fool" the discriminator and classify the fake images +as real. + +For a much more detailed overview of how GANs works, see +[Deep Learning with Python](https://www.manning.com/books/deep-learning-with-python). + +Let's implement this training loop. First, create the discriminator meant to classify +fake vs real digits: +""" + +discriminator = keras.Sequential( + [ + keras.Input(shape=(28, 28, 1)), + keras.layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"), + keras.layers.LeakyReLU(negative_slope=0.2), + keras.layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"), + keras.layers.LeakyReLU(negative_slope=0.2), + keras.layers.GlobalMaxPooling2D(), + keras.layers.Dense(1), + ], + name="discriminator", +) +discriminator.summary() + +""" +Then let's create a generator network, +that turns latent vectors into outputs of shape `(28, 28, 1)` (representing +MNIST digits): +""" + +latent_dim = 128 + +generator = keras.Sequential( + [ + keras.Input(shape=(latent_dim,)), + # We want to generate 128 coefficients to reshape into a 7x7x128 map + keras.layers.Dense(7 * 7 * 128), + keras.layers.LeakyReLU(negative_slope=0.2), + keras.layers.Reshape((7, 7, 128)), + keras.layers.Conv2DTranspose( + 128, (4, 4), strides=(2, 2), padding="same" + ), + keras.layers.LeakyReLU(negative_slope=0.2), + keras.layers.Conv2DTranspose( + 128, (4, 4), strides=(2, 2), padding="same" + ), + keras.layers.LeakyReLU(negative_slope=0.2), + keras.layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"), + ], + name="generator", +) + +""" +Here's the key bit: the training loop. As you can see it is quite straightforward. The +training step function only takes 17 lines. +""" + +# Instantiate one optimizer for the discriminator and another for the generator. +d_optimizer = keras.optimizers.Adam(learning_rate=0.0003) +g_optimizer = keras.optimizers.Adam(learning_rate=0.0004) + +# Instantiate a loss function. +loss_fn = keras.losses.BinaryCrossentropy(from_logits=True) + + +@tf.function +def train_step(real_images): + # Sample random points in the latent space + random_latent_vectors = tf.random.normal(shape=(batch_size, latent_dim)) + # Decode them to fake images + generated_images = generator(random_latent_vectors) + # Combine them with real images + combined_images = tf.concat([generated_images, real_images], axis=0) + + # Assemble labels discriminating real from fake images + labels = tf.concat( + [tf.ones((batch_size, 1)), tf.zeros((real_images.shape[0], 1))], axis=0 + ) + # Add random noise to the labels - important trick! + labels += 0.05 * tf.random.uniform(labels.shape) + + # Train the discriminator + with tf.GradientTape() as tape: + predictions = discriminator(combined_images) + d_loss = loss_fn(labels, predictions) + grads = tape.gradient(d_loss, discriminator.trainable_weights) + d_optimizer.apply(grads, discriminator.trainable_weights) + + # Sample random points in the latent space + random_latent_vectors = tf.random.normal(shape=(batch_size, latent_dim)) + # Assemble labels that say "all real images" + misleading_labels = tf.zeros((batch_size, 1)) + + # Train the generator (note that we should *not* update the weights + # of the discriminator)! + with tf.GradientTape() as tape: + predictions = discriminator(generator(random_latent_vectors)) + g_loss = loss_fn(misleading_labels, predictions) + grads = tape.gradient(g_loss, generator.trainable_weights) + g_optimizer.apply(grads, generator.trainable_weights) + return d_loss, g_loss, generated_images + + +""" +Let's train our GAN, by repeatedly calling `train_step` on batches of images. + +Since our discriminator and generator are convnets, you're going to want to +run this code on a GPU. +""" + +# Prepare the dataset. We use both the training & test MNIST digits. +batch_size = 64 +(x_train, _), (x_test, _) = keras.datasets.mnist.load_data() +all_digits = np.concatenate([x_train, x_test]) +all_digits = all_digits.astype("float32") / 255.0 +all_digits = np.reshape(all_digits, (-1, 28, 28, 1)) +dataset = tf.data.Dataset.from_tensor_slices(all_digits) +dataset = dataset.shuffle(buffer_size=1024).batch(batch_size) + +epochs = 1 # In practice you need at least 20 epochs to generate nice digits. +save_dir = "./" + +for epoch in range(epochs): + print(f"\nStart epoch {epoch}") + + for step, real_images in enumerate(dataset): + # Train the discriminator & generator on one batch of real images. + d_loss, g_loss, generated_images = train_step(real_images) + + # Logging. + if step % 100 == 0: + # Print metrics + print(f"discriminator loss at step {step}: {d_loss:.2f}") + print(f"adversarial loss at step {step}: {g_loss:.2f}") + + # Save one generated image + img = keras.utils.array_to_img( + generated_images[0] * 255.0, scale=False + ) + img.save(os.path.join(save_dir, f"generated_img_{step}.png")) + + # To limit execution time we stop after 10 steps. + # Remove the lines below to actually train the model! + if step > 10: + break + +""" +That's it! You'll get nice-looking fake MNIST digits after just ~30s of training on the +Colab GPU. +""" diff --git a/guides/writing_a_custom_training_loop_in_torch.py b/guides/writing_a_custom_training_loop_in_torch.py new file mode 100644 index 000000000000..a3641d1a88ad --- /dev/null +++ b/guides/writing_a_custom_training_loop_in_torch.py @@ -0,0 +1,386 @@ +""" +Title: Writing a training loop from scratch in PyTorch +Author: [fchollet](https://twitter.com/fchollet) +Date created: 2023/06/25 +Last modified: 2023/06/25 +Description: Writing low-level training & evaluation loops in PyTorch. +Accelerator: None +""" + +""" +## Setup +""" + +import os + +# This guide can only be run with the torch backend. +os.environ["KERAS_BACKEND"] = "torch" + +import torch +import keras +import numpy as np + +""" +## Introduction + +Keras provides default training and evaluation loops, `fit()` and `evaluate()`. +Their usage is covered in the guide +[Training & evaluation with the built-in methods](https://keras.io/guides/training_with_built_in_methods/). + +If you want to customize the learning algorithm of your model while still leveraging +the convenience of `fit()` +(for instance, to train a GAN using `fit()`), you can subclass the `Model` class and +implement your own `train_step()` method, which +is called repeatedly during `fit()`. + +Now, if you want very low-level control over training & evaluation, you should write +your own training & evaluation loops from scratch. This is what this guide is about. +""" + +""" +## A first end-to-end example + +To write a custom training loop, we need the following ingredients: + +- A model to train, of course. +- An optimizer. You could either use a `keras.optimizers` optimizer, +or a native PyTorch optimizer from `torch.optim`. +- A loss function. You could either use a `keras.losses` loss, +or a native PyTorch loss from `torch.nn`. +- A dataset. You could use any format: a `tf.data.Dataset`, +a PyTorch `DataLoader`, a Python generator, etc. + +Let's line them up. We'll use torch-native objects in each case -- +except, of course, for the Keras model. + +First, let's get the model and the MNIST dataset: +""" + + +# Let's consider a simple MNIST model +def get_model(): + inputs = keras.Input(shape=(784,), name="digits") + x1 = keras.layers.Dense(64, activation="relu")(inputs) + x2 = keras.layers.Dense(64, activation="relu")(x1) + outputs = keras.layers.Dense(10, name="predictions")(x2) + model = keras.Model(inputs=inputs, outputs=outputs) + return model + + +# Create load up the MNIST dataset and put it in a torch DataLoader +# Prepare the training dataset. +batch_size = 32 +(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data() +x_train = np.reshape(x_train, (-1, 784)).astype("float32") +x_test = np.reshape(x_test, (-1, 784)).astype("float32") +y_train = keras.utils.to_categorical(y_train) +y_test = keras.utils.to_categorical(y_test) + +# Reserve 10,000 samples for validation. +x_val = x_train[-10000:] +y_val = y_train[-10000:] +x_train = x_train[:-10000] +y_train = y_train[:-10000] + +# Create torch Datasets +train_dataset = torch.utils.data.TensorDataset( + torch.from_numpy(x_train), torch.from_numpy(y_train) +) +val_dataset = torch.utils.data.TensorDataset( + torch.from_numpy(x_val), torch.from_numpy(y_val) +) + +# Create DataLoaders for the Datasets +train_dataloader = torch.utils.data.DataLoader( + train_dataset, batch_size=batch_size, shuffle=True +) +val_dataloader = torch.utils.data.DataLoader( + val_dataset, batch_size=batch_size, shuffle=False +) + +""" +Next, here's our PyTorch optimizer and our PyTorch loss function: +""" + +# Instantiate a torch optimizer +model = get_model() +optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + +# Instantiate a torch loss function +loss_fn = torch.nn.CrossEntropyLoss() + +""" +Let's train our model using mini-batch gradient with a custom training loop. + +Calling `loss.backward()` on a loss tensor triggers backpropagation. +Once that's done, your optimizer is magically aware of the gradients for each variable +and can update its variables, which is done via `optimizer.step()`. +Tensors, variables, optimizers are all interconnected to one another via hidden global state. +Also, don't forget to call `model.zero_grad()` before `loss.backward()`, or you won't +get the right gradients for your variables. + +Here's our training loop, step by step: + +- We open a `for` loop that iterates over epochs +- For each epoch, we open a `for` loop that iterates over the dataset, in batches +- For each batch, we call the model on the input data to retrieve the predictions, +then we use them to compute a loss value +- We call `loss.backward()` to +- Outside the scope, we retrieve the gradients of the weights +of the model with regard to the loss +- Finally, we use the optimizer to update the weights of the model based on the +gradients +""" + +epochs = 3 +for epoch in range(epochs): + for step, (inputs, targets) in enumerate(train_dataloader): + # Forward pass + logits = model(inputs) + loss = loss_fn(logits, targets) + + # Backward pass + model.zero_grad() + loss.backward() + + # Optimizer variable updates + optimizer.step() + + # Log every 100 batches. + if step % 100 == 0: + print( + f"Training loss (for 1 batch) at step {step}: {loss.detach().numpy():.4f}" + ) + print(f"Seen so far: {(step + 1) * batch_size} samples") + +""" +As an alternative, let's look at what the loop looks like when using a Keras optimizer +and a Keras loss function. + +Important differences: + +- You retrieve the gradients for the variables via `v.value.grad`, +called on each trainable variable. +- You update your variables via `optimizer.apply()`, which must be +called in a `torch.no_grad()` scope. + +**Also, a big gotcha:** while all NumPy/TensorFlow/JAX/Keras APIs +as well as Python `unittest` APIs use the argument order convention +`fn(y_true, y_pred)` (reference values first, predicted values second), +PyTorch actually uses `fn(y_pred, y_true)` for its losses. +So make sure to invert the order of `logits` and `targets`. +""" + +model = get_model() +optimizer = keras.optimizers.Adam(learning_rate=1e-3) +loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True) + +for epoch in range(epochs): + print(f"\nStart of epoch {epoch}") + for step, (inputs, targets) in enumerate(train_dataloader): + # Forward pass + logits = model(inputs) + loss = loss_fn(targets, logits) + + # Backward pass + model.zero_grad() + trainable_weights = [v for v in model.trainable_weights] + + # Call torch.Tensor.backward() on the loss to compute gradients + # for the weights. + loss.backward() + gradients = [v.value.grad for v in trainable_weights] + + # Update weights + with torch.no_grad(): + optimizer.apply(gradients, trainable_weights) + + # Log every 100 batches. + if step % 100 == 0: + print( + f"Training loss (for 1 batch) at step {step}: {loss.detach().numpy():.4f}" + ) + print(f"Seen so far: {(step + 1) * batch_size} samples") + +""" +## Low-level handling of metrics + +Let's add metrics monitoring to this basic training loop. + +You can readily reuse built-in Keras metrics (or custom ones you wrote) in such training +loops written from scratch. Here's the flow: + +- Instantiate the metric at the start of the loop +- Call `metric.update_state()` after each batch +- Call `metric.result()` when you need to display the current value of the metric +- Call `metric.reset_state()` when you need to clear the state of the metric +(typically at the end of an epoch) + +Let's use this knowledge to compute `CategoricalAccuracy` on training and +validation data at the end of each epoch: +""" + +# Get a fresh model +model = get_model() + +# Instantiate an optimizer to train the model. +optimizer = keras.optimizers.Adam(learning_rate=1e-3) +# Instantiate a loss function. +loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True) + +# Prepare the metrics. +train_acc_metric = keras.metrics.CategoricalAccuracy() +val_acc_metric = keras.metrics.CategoricalAccuracy() + +""" +Here's our training & evaluation loop: +""" + +for epoch in range(epochs): + print(f"\nStart of epoch {epoch}") + for step, (inputs, targets) in enumerate(train_dataloader): + # Forward pass + logits = model(inputs) + loss = loss_fn(targets, logits) + + # Backward pass + model.zero_grad() + trainable_weights = [v for v in model.trainable_weights] + + # Call torch.Tensor.backward() on the loss to compute gradients + # for the weights. + loss.backward() + gradients = [v.value.grad for v in trainable_weights] + + # Update weights + with torch.no_grad(): + optimizer.apply(gradients, trainable_weights) + + # Update training metric. + train_acc_metric.update_state(targets, logits) + + # Log every 100 batches. + if step % 100 == 0: + print( + f"Training loss (for 1 batch) at step {step}: {loss.detach().numpy():.4f}" + ) + print(f"Seen so far: {(step + 1) * batch_size} samples") + + # Display metrics at the end of each epoch. + train_acc = train_acc_metric.result() + print(f"Training acc over epoch: {float(train_acc):.4f}") + + # Reset training metrics at the end of each epoch + train_acc_metric.reset_state() + + # Run a validation loop at the end of each epoch. + for x_batch_val, y_batch_val in val_dataloader: + val_logits = model(x_batch_val, training=False) + # Update val metrics + val_acc_metric.update_state(y_batch_val, val_logits) + val_acc = val_acc_metric.result() + val_acc_metric.reset_state() + print(f"Validation acc: {float(val_acc):.4f}") + + +""" +## Low-level handling of losses tracked by the model + +Layers & models recursively track any losses created during the forward pass +by layers that call `self.add_loss(value)`. The resulting list of scalar loss +values are available via the property `model.losses` +at the end of the forward pass. + +If you want to be using these loss components, you should sum them +and add them to the main loss in your training step. + +Consider this layer, that creates an activity regularization loss: +""" + + +class ActivityRegularizationLayer(keras.layers.Layer): + def call(self, inputs): + self.add_loss(1e-2 * torch.sum(inputs)) + return inputs + + +""" +Let's build a really simple model that uses it: +""" + +inputs = keras.Input(shape=(784,), name="digits") +x = keras.layers.Dense(64, activation="relu")(inputs) +# Insert activity regularization as a layer +x = ActivityRegularizationLayer()(x) +x = keras.layers.Dense(64, activation="relu")(x) +outputs = keras.layers.Dense(10, name="predictions")(x) + +model = keras.Model(inputs=inputs, outputs=outputs) + +""" +Here's what our training loop should look like now: +""" + +# Get a fresh model +model = get_model() + +# Instantiate an optimizer to train the model. +optimizer = keras.optimizers.Adam(learning_rate=1e-3) +# Instantiate a loss function. +loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True) + +# Prepare the metrics. +train_acc_metric = keras.metrics.CategoricalAccuracy() +val_acc_metric = keras.metrics.CategoricalAccuracy() + +for epoch in range(epochs): + print(f"\nStart of epoch {epoch}") + for step, (inputs, targets) in enumerate(train_dataloader): + # Forward pass + logits = model(inputs) + loss = loss_fn(targets, logits) + if model.losses: + loss = loss + torch.sum(*model.losses) + + # Backward pass + model.zero_grad() + trainable_weights = [v for v in model.trainable_weights] + + # Call torch.Tensor.backward() on the loss to compute gradients + # for the weights. + loss.backward() + gradients = [v.value.grad for v in trainable_weights] + + # Update weights + with torch.no_grad(): + optimizer.apply(gradients, trainable_weights) + + # Update training metric. + train_acc_metric.update_state(targets, logits) + + # Log every 100 batches. + if step % 100 == 0: + print( + f"Training loss (for 1 batch) at step {step}: {loss.detach().numpy():.4f}" + ) + print(f"Seen so far: {(step + 1) * batch_size} samples") + + # Display metrics at the end of each epoch. + train_acc = train_acc_metric.result() + print(f"Training acc over epoch: {float(train_acc):.4f}") + + # Reset training metrics at the end of each epoch + train_acc_metric.reset_state() + + # Run a validation loop at the end of each epoch. + for x_batch_val, y_batch_val in val_dataloader: + val_logits = model(x_batch_val, training=False) + # Update val metrics + val_acc_metric.update_state(y_batch_val, val_logits) + val_acc = val_acc_metric.result() + val_acc_metric.reset_state() + print(f"Validation acc: {float(val_acc):.4f}") + +""" +That's it! +""" diff --git a/guides/writing_your_own_callbacks.py b/guides/writing_your_own_callbacks.py new file mode 100644 index 000000000000..17d2da1b00db --- /dev/null +++ b/guides/writing_your_own_callbacks.py @@ -0,0 +1,444 @@ +""" +Title: Writing your own callbacks +Authors: Rick Chao, Francois Chollet +Date created: 2019/03/20 +Last modified: 2023/06/25 +Description: Complete guide to writing new Keras callbacks. +Accelerator: GPU +""" + +""" +## Introduction + +A callback is a powerful tool to customize the behavior of a Keras model during +training, evaluation, or inference. Examples include `keras.callbacks.TensorBoard` +to visualize training progress and results with TensorBoard, or +`keras.callbacks.ModelCheckpoint` to periodically save your model during training. + +In this guide, you will learn what a Keras callback is, what it can do, and how you can +build your own. We provide a few demos of simple callback applications to get you +started. +""" + +""" +## Setup +""" + +import numpy as np +import keras + +""" +## Keras callbacks overview + +All callbacks subclass the `keras.callbacks.Callback` class, and +override a set of methods called at various stages of training, testing, and +predicting. Callbacks are useful to get a view on internal states and statistics of +the model during training. + +You can pass a list of callbacks (as the keyword argument `callbacks`) to the following +model methods: + +- `keras.Model.fit()` +- `keras.Model.evaluate()` +- `keras.Model.predict()` +""" + +""" +## An overview of callback methods + +### Global methods + +#### `on_(train|test|predict)_begin(self, logs=None)` + +Called at the beginning of `fit`/`evaluate`/`predict`. + +#### `on_(train|test|predict)_end(self, logs=None)` + +Called at the end of `fit`/`evaluate`/`predict`. + +### Batch-level methods for training/testing/predicting + +#### `on_(train|test|predict)_batch_begin(self, batch, logs=None)` + +Called right before processing a batch during training/testing/predicting. + +#### `on_(train|test|predict)_batch_end(self, batch, logs=None)` + +Called at the end of training/testing/predicting a batch. Within this method, `logs` is +a dict containing the metrics results. + +### Epoch-level methods (training only) + +#### `on_epoch_begin(self, epoch, logs=None)` + +Called at the beginning of an epoch during training. + +#### `on_epoch_end(self, epoch, logs=None)` + +Called at the end of an epoch during training. +""" + +""" +## A basic example + +Let's take a look at a concrete example. To get started, let's import tensorflow and +define a simple Sequential Keras model: +""" + + +# Define the Keras model to add callbacks to +def get_model(): + model = keras.Sequential() + model.add(keras.layers.Dense(1)) + model.compile( + optimizer=keras.optimizers.RMSprop(learning_rate=0.1), + loss="mean_squared_error", + metrics=["mean_absolute_error"], + ) + return model + + +""" +Then, load the MNIST data for training and testing from Keras datasets API: +""" + +# Load example MNIST data and pre-process it +(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data() +x_train = x_train.reshape(-1, 784).astype("float32") / 255.0 +x_test = x_test.reshape(-1, 784).astype("float32") / 255.0 + +# Limit the data to 1000 samples +x_train = x_train[:1000] +y_train = y_train[:1000] +x_test = x_test[:1000] +y_test = y_test[:1000] + +""" +Now, define a simple custom callback that logs: + +- When `fit`/`evaluate`/`predict` starts & ends +- When each epoch starts & ends +- When each training batch starts & ends +- When each evaluation (test) batch starts & ends +- When each inference (prediction) batch starts & ends +""" + + +class CustomCallback(keras.callbacks.Callback): + def on_train_begin(self, logs=None): + keys = list(logs.keys()) + print("Starting training; got log keys: {}".format(keys)) + + def on_train_end(self, logs=None): + keys = list(logs.keys()) + print("Stop training; got log keys: {}".format(keys)) + + def on_epoch_begin(self, epoch, logs=None): + keys = list(logs.keys()) + print( + "Start epoch {} of training; got log keys: {}".format(epoch, keys) + ) + + def on_epoch_end(self, epoch, logs=None): + keys = list(logs.keys()) + print("End epoch {} of training; got log keys: {}".format(epoch, keys)) + + def on_test_begin(self, logs=None): + keys = list(logs.keys()) + print("Start testing; got log keys: {}".format(keys)) + + def on_test_end(self, logs=None): + keys = list(logs.keys()) + print("Stop testing; got log keys: {}".format(keys)) + + def on_predict_begin(self, logs=None): + keys = list(logs.keys()) + print("Start predicting; got log keys: {}".format(keys)) + + def on_predict_end(self, logs=None): + keys = list(logs.keys()) + print("Stop predicting; got log keys: {}".format(keys)) + + def on_train_batch_begin(self, batch, logs=None): + keys = list(logs.keys()) + print( + "...Training: start of batch {}; got log keys: {}".format( + batch, keys + ) + ) + + def on_train_batch_end(self, batch, logs=None): + keys = list(logs.keys()) + print( + "...Training: end of batch {}; got log keys: {}".format(batch, keys) + ) + + def on_test_batch_begin(self, batch, logs=None): + keys = list(logs.keys()) + print( + "...Evaluating: start of batch {}; got log keys: {}".format( + batch, keys + ) + ) + + def on_test_batch_end(self, batch, logs=None): + keys = list(logs.keys()) + print( + "...Evaluating: end of batch {}; got log keys: {}".format( + batch, keys + ) + ) + + def on_predict_batch_begin(self, batch, logs=None): + keys = list(logs.keys()) + print( + "...Predicting: start of batch {}; got log keys: {}".format( + batch, keys + ) + ) + + def on_predict_batch_end(self, batch, logs=None): + keys = list(logs.keys()) + print( + "...Predicting: end of batch {}; got log keys: {}".format( + batch, keys + ) + ) + + +""" +Let's try it out: +""" + +model = get_model() +model.fit( + x_train, + y_train, + batch_size=128, + epochs=1, + verbose=0, + validation_split=0.5, + callbacks=[CustomCallback()], +) + +res = model.evaluate( + x_test, y_test, batch_size=128, verbose=0, callbacks=[CustomCallback()] +) + +res = model.predict(x_test, batch_size=128, callbacks=[CustomCallback()]) + +""" +### Usage of `logs` dict + +The `logs` dict contains the loss value, and all the metrics at the end of a batch or +epoch. Example includes the loss and mean absolute error. +""" + + +class LossAndErrorPrintingCallback(keras.callbacks.Callback): + def on_train_batch_end(self, batch, logs=None): + print( + "Up to batch {}, the average loss is {:7.2f}.".format( + batch, logs["loss"] + ) + ) + + def on_test_batch_end(self, batch, logs=None): + print( + "Up to batch {}, the average loss is {:7.2f}.".format( + batch, logs["loss"] + ) + ) + + def on_epoch_end(self, epoch, logs=None): + print( + "The average loss for epoch {} is {:7.2f} " + "and mean absolute error is {:7.2f}.".format( + epoch, logs["loss"], logs["mean_absolute_error"] + ) + ) + + +model = get_model() +model.fit( + x_train, + y_train, + batch_size=128, + epochs=2, + verbose=0, + callbacks=[LossAndErrorPrintingCallback()], +) + +res = model.evaluate( + x_test, + y_test, + batch_size=128, + verbose=0, + callbacks=[LossAndErrorPrintingCallback()], +) + +""" +## Usage of `self.model` attribute + +In addition to receiving log information when one of their methods is called, +callbacks have access to the model associated with the current round of +training/evaluation/inference: `self.model`. + +Here are a few of the things you can do with `self.model` in a callback: + +- Set `self.model.stop_training = True` to immediately interrupt training. +- Mutate hyperparameters of the optimizer (available as `self.model.optimizer`), +such as `self.model.optimizer.learning_rate`. +- Save the model at period intervals. +- Record the output of `model.predict()` on a few test samples at the end of each +epoch, to use as a sanity check during training. +- Extract visualizations of intermediate features at the end of each epoch, to monitor +what the model is learning over time. +- etc. + +Let's see this in action in a couple of examples. +""" + +""" +## Examples of Keras callback applications + +### Early stopping at minimum loss + +This first example shows the creation of a `Callback` that stops training when the +minimum of loss has been reached, by setting the attribute `self.model.stop_training` +(boolean). Optionally, you can provide an argument `patience` to specify how many +epochs we should wait before stopping after having reached a local minimum. + +`keras.callbacks.EarlyStopping` provides a more complete and general implementation. +""" + + +class EarlyStoppingAtMinLoss(keras.callbacks.Callback): + """Stop training when the loss is at its min, i.e. the loss stops decreasing. + + Arguments: + patience: Number of epochs to wait after min has been hit. After this + number of no improvement, training stops. + """ + + def __init__(self, patience=0): + super().__init__() + self.patience = patience + # best_weights to store the weights at which the minimum loss occurs. + self.best_weights = None + + def on_train_begin(self, logs=None): + # The number of epoch it has waited when loss is no longer minimum. + self.wait = 0 + # The epoch the training stops at. + self.stopped_epoch = 0 + # Initialize the best as infinity. + self.best = np.inf + + def on_epoch_end(self, epoch, logs=None): + current = logs.get("loss") + if np.less(current, self.best): + self.best = current + self.wait = 0 + # Record the best weights if current results is better (less). + self.best_weights = self.model.get_weights() + else: + self.wait += 1 + if self.wait >= self.patience: + self.stopped_epoch = epoch + self.model.stop_training = True + print("Restoring model weights from the end of the best epoch.") + self.model.set_weights(self.best_weights) + + def on_train_end(self, logs=None): + if self.stopped_epoch > 0: + print(f"Epoch {self.stopped_epoch + 1}: early stopping") + + +model = get_model() +model.fit( + x_train, + y_train, + batch_size=64, + epochs=30, + verbose=0, + callbacks=[LossAndErrorPrintingCallback(), EarlyStoppingAtMinLoss()], +) + +""" +### Learning rate scheduling + +In this example, we show how a custom Callback can be used to dynamically change the +learning rate of the optimizer during the course of training. + +See `callbacks.LearningRateScheduler` for a more general implementations. +""" + + +class CustomLearningRateScheduler(keras.callbacks.Callback): + """Learning rate scheduler which sets the learning rate according to schedule. + + Arguments: + schedule: a function that takes an epoch index + (integer, indexed from 0) and current learning rate + as inputs and returns a new learning rate as output (float). + """ + + def __init__(self, schedule): + super().__init__() + self.schedule = schedule + + def on_epoch_begin(self, epoch, logs=None): + if not hasattr(self.model.optimizer, "learning_rate"): + raise ValueError('Optimizer must have a "learning_rate" attribute.') + # Get the current learning rate from model's optimizer. + lr = self.model.optimizer.learning_rate + # Call schedule function to get the scheduled learning rate. + scheduled_lr = self.schedule(epoch, lr) + # Set the value back to the optimizer before this epoch starts + self.model.optimizer.learning_rate = scheduled_lr + print( + f"\nEpoch {epoch}: Learning rate is {float(np.array(scheduled_lr))}." + ) + + +LR_SCHEDULE = [ + # (epoch to start, learning rate) tuples + (3, 0.05), + (6, 0.01), + (9, 0.005), + (12, 0.001), +] + + +def lr_schedule(epoch, lr): + """Helper function to retrieve the scheduled learning rate based on epoch.""" + if epoch < LR_SCHEDULE[0][0] or epoch > LR_SCHEDULE[-1][0]: + return lr + for i in range(len(LR_SCHEDULE)): + if epoch == LR_SCHEDULE[i][0]: + return LR_SCHEDULE[i][1] + return lr + + +model = get_model() +model.fit( + x_train, + y_train, + batch_size=64, + epochs=15, + verbose=0, + callbacks=[ + LossAndErrorPrintingCallback(), + CustomLearningRateScheduler(lr_schedule), + ], +) + +""" +### Built-in Keras callbacks + +Be sure to check out the existing Keras callbacks by +reading the [API docs](https://keras.io/api/callbacks/). +Applications include logging to CSV, saving +the model, visualizing metrics in TensorBoard, and a lot more! +""" diff --git a/integration_tests/basic_full_flow.py b/integration_tests/basic_full_flow.py new file mode 100644 index 000000000000..ae5c7a4c0449 --- /dev/null +++ b/integration_tests/basic_full_flow.py @@ -0,0 +1,54 @@ +import numpy as np +import pytest + +import keras +from keras.src import layers +from keras.src import losses +from keras.src import metrics +from keras.src import optimizers +from keras.src import testing + + +class MyModel(keras.Model): + def __init__(self, hidden_dim, output_dim, **kwargs): + super().__init__(**kwargs) + self.hidden_dim = hidden_dim + self.output_dim = output_dim + self.dense1 = layers.Dense(hidden_dim, activation="relu") + self.dense2 = layers.Dense(hidden_dim, activation="relu") + self.dense3 = layers.Dense(output_dim) + + def call(self, x): + x = self.dense1(x) + x = self.dense2(x) + return self.dense3(x) + + +class BasicFlowTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_basic_fit(self): + model = MyModel(hidden_dim=2, output_dim=1) + + x = np.random.random((128, 4)) + y = np.random.random((128, 4)) + batch_size = 32 + epochs = 3 + + model.compile( + optimizer=optimizers.SGD(learning_rate=0.001), + loss=losses.MeanSquaredError(), + metrics=[metrics.MeanSquaredError()], + ) + output_before_fit = model(x) + model.fit( + x, y, batch_size=batch_size, epochs=epochs, validation_split=0.2 + ) + output_after_fit = model(x) + + self.assertNotAllClose(output_before_fit, output_after_fit) + + def test_basic_fit_no_training(self): + model = MyModel(hidden_dim=2, output_dim=1) + x = np.random.random((128, 4)) + model.predict(x) + model(x) diff --git a/integration_tests/dataset_tests/boston_housing_test.py b/integration_tests/dataset_tests/boston_housing_test.py new file mode 100644 index 000000000000..635738fe5f05 --- /dev/null +++ b/integration_tests/dataset_tests/boston_housing_test.py @@ -0,0 +1,22 @@ +from keras.src import testing +from keras.src.datasets import boston_housing + + +class BostonHousingTest(testing.TestCase): + def test_load_data(self): + (x_train, y_train), (x_test, y_test) = boston_housing.load_data() + self.assertEqual(x_train.shape[1], 13) + self.assertEqual(x_train.shape[0] + x_test.shape[0], 506) + + def test_seed_reproducibility(self): + seed = 123 + first_load = boston_housing.load_data(seed=seed) + second_load = boston_housing.load_data(seed=seed) + self.assertAllClose(first_load[0][0], second_load[0][0]) + self.assertAllClose(first_load[1][0], second_load[1][0]) + + def test_invalid_test_split(self): + with self.assertRaises(AssertionError): + boston_housing.load_data(test_split=-0.1) + with self.assertRaises(AssertionError): + boston_housing.load_data(test_split=1.0) diff --git a/integration_tests/dataset_tests/california_housing_test.py b/integration_tests/dataset_tests/california_housing_test.py new file mode 100644 index 000000000000..7f0cc4566177 --- /dev/null +++ b/integration_tests/dataset_tests/california_housing_test.py @@ -0,0 +1,32 @@ +from keras.src import testing +from keras.src.datasets import california_housing + + +class CaliforniaHousingTest(testing.TestCase): + def test_load_data_large(self): + (x_train, y_train), (x_test, y_test) = california_housing.load_data( + version="large" + ) + self.assertEqual(x_train.shape[1], 8) + # Ensure the dataset contains 20,640 samples as documented + self.assertEqual(x_train.shape[0] + x_test.shape[0], 20640) + + def test_load_data_small(self): + (x_train, y_train), (x_test, y_test) = california_housing.load_data( + version="small" + ) + self.assertEqual(x_train.shape[1], 8) + # Ensure the small dataset contains 600 samples as documented + self.assertEqual(x_train.shape[0] + x_test.shape[0], 600) + + def test_invalid_version(self): + with self.assertRaises(ValueError): + california_housing.load_data(version="invalid_version") + + def test_seed_reproducibility(self): + # Ensure the data is reproducible with the same seed + seed = 123 + first_load = california_housing.load_data(version="large", seed=seed) + second_load = california_housing.load_data(version="large", seed=seed) + self.assertAllClose(first_load[0][0], second_load[0][0]) + self.assertAllClose(first_load[1][0], second_load[1][0]) diff --git a/integration_tests/dataset_tests/cifar100_test.py b/integration_tests/dataset_tests/cifar100_test.py new file mode 100644 index 000000000000..7f2062e22403 --- /dev/null +++ b/integration_tests/dataset_tests/cifar100_test.py @@ -0,0 +1,35 @@ +import numpy as np + +from keras.src import testing +from keras.src.datasets import cifar100 + + +class Cifar100LoadDataTest(testing.TestCase): + def test_shapes_fine_label_mode(self): + (x_train, y_train), (x_test, y_test) = cifar100.load_data( + label_mode="fine" + ) + self.assertEqual(x_train.shape, (50000, 32, 32, 3)) + self.assertEqual(y_train.shape, (50000, 1)) + self.assertEqual(x_test.shape, (10000, 32, 32, 3)) + self.assertEqual(y_test.shape, (10000, 1)) + + def test_shapes_coarse_label_mode(self): + (x_train, y_train), (x_test, y_test) = cifar100.load_data( + label_mode="coarse" + ) + self.assertEqual(x_train.shape, (50000, 32, 32, 3)) + self.assertEqual(y_train.shape, (50000, 1)) + self.assertEqual(x_test.shape, (10000, 32, 32, 3)) + self.assertEqual(y_test.shape, (10000, 1)) + + def test_dtypes(self): + (x_train, y_train), (x_test, y_test) = cifar100.load_data() + self.assertEqual(x_train.dtype, np.uint8) + self.assertEqual(y_train.dtype, np.int64) + self.assertEqual(x_test.dtype, np.uint8) + self.assertEqual(y_test.dtype, np.int64) + + def test_invalid_label_mode(self): + with self.assertRaises(ValueError): + cifar100.load_data(label_mode="invalid") diff --git a/integration_tests/dataset_tests/cifar10_test.py b/integration_tests/dataset_tests/cifar10_test.py new file mode 100644 index 000000000000..fe1c20319b00 --- /dev/null +++ b/integration_tests/dataset_tests/cifar10_test.py @@ -0,0 +1,38 @@ +import numpy as np + +from keras.src import testing +from keras.src.datasets import cifar10 + + +class Cifar10LoadDataTest(testing.TestCase): + def test_x_train_shape(self): + (x_train, _), _ = cifar10.load_data() + self.assertEqual(x_train.shape, (50000, 32, 32, 3)) + + def test_y_train_shape(self): + (_, y_train), _ = cifar10.load_data() + self.assertEqual(y_train.shape, (50000, 1)) + + def test_x_test_shape(self): + _, (x_test, _) = cifar10.load_data() + self.assertEqual(x_test.shape, (10000, 32, 32, 3)) + + def test_y_test_shape(self): + _, (_, y_test) = cifar10.load_data() + self.assertEqual(y_test.shape, (10000, 1)) + + def test_x_train_dtype(self): + (x_train, _), _ = cifar10.load_data() + self.assertEqual(x_train.dtype, np.uint8) + + def test_y_train_dtype(self): + (_, y_train), _ = cifar10.load_data() + self.assertEqual(y_train.dtype, np.uint8) + + def test_x_test_dtype(self): + _, (x_test, _) = cifar10.load_data() + self.assertEqual(x_test.dtype, np.uint8) + + def test_y_test_dtype(self): + _, (_, y_test) = cifar10.load_data() + self.assertEqual(y_test.dtype, np.uint8) diff --git a/integration_tests/dataset_tests/fashion_mnist_test.py b/integration_tests/dataset_tests/fashion_mnist_test.py new file mode 100644 index 000000000000..92c43eeefe32 --- /dev/null +++ b/integration_tests/dataset_tests/fashion_mnist_test.py @@ -0,0 +1,38 @@ +import numpy as np + +from keras.src import testing +from keras.src.datasets import fashion_mnist + + +class FashionMnistLoadDataTest(testing.TestCase): + def test_x_train_shape(self): + (x_train, _), _ = fashion_mnist.load_data() + self.assertEqual(x_train.shape, (60000, 28, 28)) + + def test_y_train_shape(self): + (_, y_train), _ = fashion_mnist.load_data() + self.assertEqual(y_train.shape, (60000,)) + + def test_x_test_shape(self): + _, (x_test, _) = fashion_mnist.load_data() + self.assertEqual(x_test.shape, (10000, 28, 28)) + + def test_y_test_shape(self): + _, (_, y_test) = fashion_mnist.load_data() + self.assertEqual(y_test.shape, (10000,)) + + def test_x_train_dtype(self): + (x_train, _), _ = fashion_mnist.load_data() + self.assertEqual(x_train.dtype, np.uint8) + + def test_y_train_dtype(self): + (_, y_train), _ = fashion_mnist.load_data() + self.assertEqual(y_train.dtype, np.uint8) + + def test_x_test_dtype(self): + _, (x_test, _) = fashion_mnist.load_data() + self.assertEqual(x_test.dtype, np.uint8) + + def test_y_test_dtype(self): + _, (_, y_test) = fashion_mnist.load_data() + self.assertEqual(y_test.dtype, np.uint8) diff --git a/integration_tests/dataset_tests/imdb_test.py b/integration_tests/dataset_tests/imdb_test.py new file mode 100644 index 000000000000..a41bf6f971db --- /dev/null +++ b/integration_tests/dataset_tests/imdb_test.py @@ -0,0 +1,48 @@ +import numpy as np + +from keras.src import testing +from keras.src.datasets import imdb + + +class ImdbLoadDataTest(testing.TestCase): + def test_load_data_default(self): + (x_train, y_train), (x_test, y_test) = imdb.load_data() + self.assertIsInstance(x_train, np.ndarray) + self.assertIsInstance(y_train, np.ndarray) + self.assertIsInstance(x_test, np.ndarray) + self.assertIsInstance(y_test, np.ndarray) + + # Check lengths + self.assertEqual(len(x_train), 25000) + self.assertEqual(len(y_train), 25000) + self.assertEqual(len(x_test), 25000) + self.assertEqual(len(y_test), 25000) + + # Check types within lists for x + self.assertIsInstance(x_train[0], list) + self.assertIsInstance(x_test[0], list) + + def test_num_words(self): + # Only consider the top 1000 words + (x_train, _), _ = imdb.load_data(num_words=1000) + # Ensure that no word index exceeds 999 (0-based indexing) + max_index = max(max(sequence) for sequence in x_train if sequence) + self.assertLessEqual(max_index, 999) + + def test_skip_top(self): + # Skip the top 10 most frequent words + (x_train, _), _ = imdb.load_data(skip_top=10, num_words=1000) + # Check if top 10 words are skipped properly + self.assertNotIn(1, x_train[0]) # Assuming 1 is among top 10 + + def test_maxlen(self): + # Only consider sequences shorter than 100 + (x_train, _), _ = imdb.load_data(maxlen=100) + self.assertTrue(all(len(seq) <= 100 for seq in x_train)) + + def test_get_word_index(self): + word_index = imdb.get_word_index() + self.assertIsInstance(word_index, dict) + # Check if word_index contains specific known words + self.assertIn("the", word_index) + self.assertIn("and", word_index) diff --git a/integration_tests/dataset_tests/mnist_test.py b/integration_tests/dataset_tests/mnist_test.py new file mode 100644 index 000000000000..5aeaae4548bd --- /dev/null +++ b/integration_tests/dataset_tests/mnist_test.py @@ -0,0 +1,38 @@ +import numpy as np + +from keras.src import testing +from keras.src.datasets import mnist + + +class MnistLoadDataTest(testing.TestCase): + def test_x_train_shape(self): + (x_train, _), _ = mnist.load_data() + self.assertEqual(x_train.shape, (60000, 28, 28)) + + def test_y_train_shape(self): + (_, y_train), _ = mnist.load_data() + self.assertEqual(y_train.shape, (60000,)) + + def test_x_test_shape(self): + _, (x_test, _) = mnist.load_data() + self.assertEqual(x_test.shape, (10000, 28, 28)) + + def test_y_test_shape(self): + _, (_, y_test) = mnist.load_data() + self.assertEqual(y_test.shape, (10000,)) + + def test_x_train_dtype(self): + (x_train, _), _ = mnist.load_data() + self.assertEqual(x_train.dtype, np.uint8) + + def test_y_train_dtype(self): + (_, y_train), _ = mnist.load_data() + self.assertEqual(y_train.dtype, np.uint8) + + def test_x_test_dtype(self): + _, (x_test, _) = mnist.load_data() + self.assertEqual(x_test.dtype, np.uint8) + + def test_y_test_dtype(self): + _, (_, y_test) = mnist.load_data() + self.assertEqual(y_test.dtype, np.uint8) diff --git a/integration_tests/dataset_tests/reuters_test.py b/integration_tests/dataset_tests/reuters_test.py new file mode 100644 index 000000000000..3d83de560869 --- /dev/null +++ b/integration_tests/dataset_tests/reuters_test.py @@ -0,0 +1,51 @@ +import numpy as np + +from keras.src import testing +from keras.src.datasets import reuters + + +class ReutersLoadDataTest(testing.TestCase): + def test_load_data_default(self): + (x_train, y_train), (x_test, y_test) = reuters.load_data() + # Check types + self.assertIsInstance(x_train, np.ndarray) + self.assertIsInstance(y_train, np.ndarray) + self.assertIsInstance(x_test, np.ndarray) + self.assertIsInstance(y_test, np.ndarray) + + # Check shapes + self.assertGreater(len(x_train), 0) + self.assertEqual(len(x_train), len(y_train)) + self.assertGreater(len(x_test), 0) + self.assertEqual(len(x_test), len(y_test)) + + def test_num_words(self): + # Only consider the top 1000 words + (x_train, _), _ = reuters.load_data(num_words=1000) + # Ensure no word index exceeds 999 (0-based indexing) + max_index = max(max(sequence) for sequence in x_train if sequence) + self.assertLessEqual(max_index, 999) + + def test_skip_top(self): + # Skip the top 10 most frequent words + (x_train, _), _ = reuters.load_data(skip_top=10, num_words=1000) + # Assuming 1 is among top 10, check if it's skipped + self.assertNotIn(1, x_train[0]) + + def test_maxlen(self): + # Only consider sequences shorter than 50 + (x_train, _), _ = reuters.load_data(maxlen=50) + self.assertTrue(all(len(seq) <= 50 for seq in x_train)) + + def test_get_word_index(self): + word_index = reuters.get_word_index() + self.assertIsInstance(word_index, dict) + # Check if word_index contains specific known words + self.assertIn("the", word_index) + + def test_get_label_names(self): + label_names = reuters.get_label_names() + self.assertIsInstance(label_names, tuple) + # Check if the tuple contains specific known labels + self.assertIn("earn", label_names) + self.assertIn("acq", label_names) diff --git a/integration_tests/import_test.py b/integration_tests/import_test.py new file mode 100644 index 000000000000..f703797d5550 --- /dev/null +++ b/integration_tests/import_test.py @@ -0,0 +1,146 @@ +import os +import re +import subprocess + +from keras.src import backend +from keras.src.backend import config + +# For torch, use index url to avoid installing nvidia drivers for the test. +BACKEND_REQ = { + "tensorflow": ("tensorflow-cpu", ""), + "torch": ( + "torch", + "--extra-index-url https://download.pytorch.org/whl/cpu ", + ), + "jax": ("jax[cpu]", ""), + "openvino": ("openvino", ""), +} + + +def setup_package(): + subprocess.run("rm -rf tmp_build_dir", shell=True) + build_process = subprocess.run( + "python3 pip_build.py", + capture_output=True, + text=True, + shell=True, + ) + print(build_process.stdout) + whl_path = re.findall( + r"[^\s]*\.whl", + build_process.stdout, + ) + if not whl_path: + print(build_process.stdout) + print(build_process.stderr) + raise ValueError("Installing Keras package unsuccessful. ") + return whl_path[-1] + + +def create_virtualenv(): + env_setup = [ + # Create virtual environment + "python3 -m venv test_env", + ] + os.environ["PATH"] = os.pathsep.join( + ( + os.path.join(os.getcwd(), "test_env", "bin"), + os.environ.get("PATH", ""), + ) + ) + if os.name == "nt": + os.environ["PATH"] = os.pathsep.join( + ( + os.path.join(os.getcwd(), "test_env", "Scripts"), + os.environ["PATH"], + ) + ) + run_commands_local(env_setup) + + +def manage_venv_installs(whl_path): + other_backends = list(set(BACKEND_REQ.keys()) - {backend.backend()}) + backend_pkg, backend_extra_url = BACKEND_REQ[backend.backend()] + install_setup = [ + # Installs the backend's package and common requirements + f"pip install {backend_extra_url}{backend_pkg}", + "pip install -r requirements-common.txt", + "pip install pytest", + # Ensure other backends are uninstalled + "pip uninstall -y {0} {1} {2}".format( + BACKEND_REQ[other_backends[0]][0], + BACKEND_REQ[other_backends[1]][0], + BACKEND_REQ[other_backends[2]][0], + ), + # Install `.whl` package + f"pip install {whl_path}", + ] + # Install flax for JAX when NNX is enabled + if backend.backend() == "jax" and config.is_nnx_enabled(): + install_setup.append("pip install flax>=0.10.1") + run_commands_venv(install_setup) + + +def run_keras_flow(): + test_script = [ + # Runs the example script + "python -m pytest integration_tests/basic_full_flow.py", + ] + run_commands_venv(test_script) + + +def cleanup(): + cleanup_script = [ + # Exits virtual environment, deletes files, and any + # miscellaneous install logs + "exit", + "rm -rf test_env", + "rm -rf tmp_build_dir", + "rm -f *+cpu", + ] + run_commands_local(cleanup_script) + + +def run_commands_local(commands): + for command in commands: + print(f"Running command: {command}") + subprocess.run(command, shell=True) + + +def run_commands_venv(commands): + for command in commands: + print(f"Running command: {command}") + cmd_with_args = command.split(" ") + cmd_with_args[0] = os.path.join( + "test_env", + "Scripts" if os.name == "nt" else "bin", + cmd_with_args[0], + ) + p = subprocess.Popen(cmd_with_args) + assert p.wait() == 0 + + +def test_keras_imports(): + try: + # Ensures packages from all backends are installed. + # Builds Keras core package and returns package file path. + whl_path = setup_package() + + # Creates and activates a virtual environment. + create_virtualenv() + + # Ensures the backend's package is installed + # and the other backends are uninstalled. + manage_venv_installs(whl_path) + + # Runs test of basic flow in Keras Core. + # Tests for backend-specific imports and `model.fit()`. + run_keras_flow() + + # Removes virtual environment and associated files + finally: + cleanup() + + +if __name__ == "__main__": + test_keras_imports() diff --git a/integration_tests/jax_custom_fit_test.py b/integration_tests/jax_custom_fit_test.py new file mode 100644 index 000000000000..9c9eee59f114 --- /dev/null +++ b/integration_tests/jax_custom_fit_test.py @@ -0,0 +1,104 @@ +import jax +import numpy as np + +import keras + + +def test_custom_fit(): + class CustomModel(keras.Model): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.loss_tracker = keras.metrics.Mean(name="loss") + self.mae_metric = keras.metrics.MeanAbsoluteError(name="mae") + self.loss_fn = keras.losses.MeanSquaredError() + + def compute_loss_and_updates( + self, + trainable_variables, + non_trainable_variables, + x, + y, + training=False, + ): + y_pred, non_trainable_variables = self.stateless_call( + trainable_variables, + non_trainable_variables, + x, + training=training, + ) + loss = self.loss_fn(y, y_pred) + return loss, (y_pred, non_trainable_variables) + + def train_step(self, state, data): + ( + trainable_variables, + non_trainable_variables, + optimizer_variables, + metrics_variables, + ) = state + x, y = data + grad_fn = jax.value_and_grad( + self.compute_loss_and_updates, has_aux=True + ) + (loss, (y_pred, non_trainable_variables)), grads = grad_fn( + trainable_variables, + non_trainable_variables, + x, + y, + training=True, + ) + ( + trainable_variables, + optimizer_variables, + ) = self.optimizer.stateless_apply( + optimizer_variables, grads, trainable_variables + ) + loss_tracker_vars = metrics_variables[ + : len(self.loss_tracker.variables) + ] + mae_metric_vars = metrics_variables[ + len(self.loss_tracker.variables) : + ] + loss_tracker_vars = self.loss_tracker.stateless_update_state( + loss_tracker_vars, loss + ) + mae_metric_vars = self.mae_metric.stateless_update_state( + mae_metric_vars, y, y_pred + ) + logs = {} + logs[self.loss_tracker.name] = self.loss_tracker.stateless_result( + loss_tracker_vars + ) + logs[self.mae_metric.name] = self.mae_metric.stateless_result( + mae_metric_vars + ) + new_metrics_vars = loss_tracker_vars + mae_metric_vars + state = ( + trainable_variables, + non_trainable_variables, + optimizer_variables, + new_metrics_vars, + ) + return logs, state + + @property + def metrics(self): + return [self.loss_tracker, self.mae_metric] + + inputs = keras.Input(shape=(32,)) + outputs = keras.layers.Dense(1)(inputs) + model = CustomModel(inputs, outputs) + model.compile(optimizer="adam") + x = np.random.random((64, 32)) + y = np.random.random((64, 1)) + history = model.fit(x, y, epochs=1) + + assert "loss" in history.history + assert "mae" in history.history + + print("History:") + print(history.history) + + +if __name__ == "__main__": + test_custom_fit() diff --git a/integration_tests/model_visualization_test.py b/integration_tests/model_visualization_test.py new file mode 100644 index 000000000000..965734958fe0 --- /dev/null +++ b/integration_tests/model_visualization_test.py @@ -0,0 +1,802 @@ +import re + +import keras +from keras.src import testing +from keras.src.utils import model_to_dot +from keras.src.utils import plot_model + + +class SubclassModel(keras.models.Model): + def __init__(self, name): + super().__init__(name=name) + + def call(self, x): + return x + + +def parse_text_from_html(html): + pattern = r"]*>(.*?)" + matches = re.findall(pattern, html) + + for match in matches: + clean_text = re.sub(r"<[^>]*>", "", match) + return clean_text + return "" + + +def get_node_text(node): + attributes = node.get_attributes() + + if "label" in attributes: + html = node.get_attributes()["label"] + return parse_text_from_html(html) + else: + return None + + +def get_edge_dict(dot): + def get_node_dict(graph, path=""): + nodes = { + node.get_name(): path + get_node_text(node) + for node in graph.get_nodes() + if node.get_name() != "node" # Dummy node inserted by pydot? + } + + for subgraph in graph.get_subgraphs(): + sub_nodes = get_node_dict( + subgraph, path=f"{path}{subgraph.get_label()} > " + ) + nodes.update(sub_nodes) + + return nodes + + node_dict = get_node_dict(dot) + + def get_edges(graph): + edges = list(graph.get_edges()) + for subgraph in graph.get_subgraphs(): + edges.extend(get_edges(subgraph)) + return edges + + edge_dict = dict() + dangling_edges = [] + for edge in get_edges(dot): + source_node = node_dict.get(edge.get_source(), None) + destination_node = node_dict.get(edge.get_destination(), None) + if source_node is None or destination_node is None: + dangling_edges.append( + f"from '{source_node}'/'{edge.get_source()}' " + f"to '{destination_node}'/'{edge.get_destination()}'" + ) + if source_node in edge_dict: + destination_nodes = edge_dict[source_node] + if not isinstance(destination_nodes, set): + destination_nodes = set([destination_nodes]) + edge_dict[source_node] = destination_nodes + destination_nodes.add(destination_node) + else: + edge_dict[source_node] = destination_node + + if dangling_edges: + raise ValueError(f"Dangling edges found: {dangling_edges}") + return edge_dict + + +class ModelVisualizationTest(testing.TestCase): + def multi_plot_model(self, model, name, expand_nested=False): + if expand_nested: + name = f"{name}-expand_nested" + + TEST_CASES = [ + {}, + { + "show_shapes": True, + }, + { + "show_shapes": True, + "show_dtype": True, + }, + { + "show_shapes": True, + "show_dtype": True, + "show_layer_names": True, + }, + { + "show_shapes": True, + "show_dtype": True, + "show_layer_names": True, + "show_layer_activations": True, + }, + { + "show_shapes": True, + "show_dtype": True, + "show_layer_names": True, + "show_layer_activations": True, + "show_trainable": True, + }, + { + "show_shapes": True, + "show_dtype": True, + "show_layer_names": True, + "show_layer_activations": True, + "show_trainable": True, + "rankdir": "LR", + }, + { + "show_layer_activations": True, + "show_trainable": True, + }, + ] + + for test_case in TEST_CASES: + tags = [v if k == "rankdir" else k for k, v in test_case.items()] + file_name = f"{'-'.join([name] + tags)}.png" + plot_model( + model, file_name, expand_nested=expand_nested, **test_case + ) + self.assertFileExists(file_name) + + def test_plot_sequential_model(self): + model = keras.Sequential( + [ + keras.Input((3,), name="input"), + keras.layers.Dense(4, activation="relu", name="dense"), + keras.layers.Dense(1, activation="sigmoid", name="dense_1"), + ] + ) + + edge_dict = get_edge_dict(model_to_dot(model)) + self.assertEqual( + edge_dict, + { + "dense (Dense)": "dense_1 (Dense)", + }, + ) + self.multi_plot_model(model, "sequential") + + def test_plot_functional_model(self): + inputs = keras.Input((3,), name="input") + x = keras.layers.Dense( + 4, activation="relu", trainable=False, name="dense" + )(inputs) + residual = x + x = keras.layers.Dense(4, activation="relu", name="dense_1")(x) + x = keras.layers.Dense(4, activation="relu", name="dense_2")(x) + x = keras.layers.Dense(4, activation="relu", name="dense_3")(x) + x += residual + residual = x + x = keras.layers.Dense(4, activation="relu", name="dense_4")(x) + x = keras.layers.Dense(4, activation="relu", name="dense_5")(x) + x = keras.layers.Dense(4, activation="relu", name="dense_6")(x) + x += residual + x = keras.layers.Dropout(0.5, name="dropout")(x) + outputs = keras.layers.Dense(1, activation="sigmoid", name="dense_7")(x) + + model = keras.Model(inputs, outputs) + + edge_dict = get_edge_dict(model_to_dot(model)) + self.assertEqual( + edge_dict, + { + "input (InputLayer)": "dense (Dense)", + "dense (Dense)": {"dense_1 (Dense)", "add (Add)"}, + "dense_1 (Dense)": "dense_2 (Dense)", + "dense_2 (Dense)": "dense_3 (Dense)", + "dense_3 (Dense)": "add (Add)", + "add (Add)": {"dense_4 (Dense)", "add_1 (Add)"}, + "dense_4 (Dense)": "dense_5 (Dense)", + "dense_5 (Dense)": "dense_6 (Dense)", + "dense_6 (Dense)": "add_1 (Add)", + "add_1 (Add)": "dropout (Dropout)", + "dropout (Dropout)": "dense_7 (Dense)", + }, + ) + self.multi_plot_model(model, "functional") + + def test_plot_subclassed_model(self): + model = SubclassModel(name="subclass") + model.build((None, 3)) + + self.multi_plot_model(model, "subclassed") + + def test_plot_nested_functional_model(self): + inputs = keras.Input((3,), name="input") + x = keras.layers.Dense(4, activation="relu", name="dense")(inputs) + x = keras.layers.Dense(4, activation="relu", name="dense_1")(x) + outputs = keras.layers.Dense(3, activation="relu", name="dense_2")(x) + inner_model = keras.Model(inputs, outputs, name="inner_model") + + inputs = keras.Input((3,), name="input_1") + x = keras.layers.Dense( + 3, activation="relu", trainable=False, name="dense_3" + )(inputs) + residual = x + x = inner_model(x) + x = keras.layers.Add(name="add")([x, residual]) + residual = x + x = keras.layers.Dense(4, activation="relu", name="dense_4")(x) + x = keras.layers.Dense(4, activation="relu", name="dense_5")(x) + x = keras.layers.Dense(3, activation="relu", name="dense_6")(x) + x = keras.layers.Add(name="add_1")([x, residual]) + x = keras.layers.Dropout(0.5, name="dropout")(x) + outputs = keras.layers.Dense(1, activation="sigmoid", name="dense_7")(x) + model = keras.Model(inputs, outputs) + + edge_dict = get_edge_dict(model_to_dot(model)) + self.assertEqual( + edge_dict, + { + "input_1 (InputLayer)": "dense_3 (Dense)", + "dense_3 (Dense)": {"inner_model (Functional)", "add (Add)"}, + "inner_model (Functional)": "add (Add)", + "add (Add)": {"dense_4 (Dense)", "add_1 (Add)"}, + "dense_4 (Dense)": "dense_5 (Dense)", + "dense_5 (Dense)": "dense_6 (Dense)", + "dense_6 (Dense)": "add_1 (Add)", + "add_1 (Add)": "dropout (Dropout)", + "dropout (Dropout)": "dense_7 (Dense)", + }, + ) + self.multi_plot_model(model, "nested-functional") + + edge_dict = get_edge_dict(model_to_dot(model, expand_nested=True)) + self.assertEqual( + edge_dict, + { + "input_1 (InputLayer)": "dense_3 (Dense)", + "dense_3 (Dense)": { + "inner_model > input (InputLayer)", + "add (Add)", + }, + "inner_model > input (InputLayer)": "inner_model > dense (Dense)", # noqa: E501 + "inner_model > dense (Dense)": "inner_model > dense_1 (Dense)", # noqa: E501 + "inner_model > dense_1 (Dense)": "inner_model > dense_2 (Dense)", # noqa: E501 + "inner_model > dense_2 (Dense)": "add (Add)", + "add (Add)": {"dense_4 (Dense)", "add_1 (Add)"}, + "dense_4 (Dense)": "dense_5 (Dense)", + "dense_5 (Dense)": "dense_6 (Dense)", + "dense_6 (Dense)": "add_1 (Add)", + "add_1 (Add)": "dropout (Dropout)", + "dropout (Dropout)": "dense_7 (Dense)", + }, + ) + self.multi_plot_model(model, "nested-functional", expand_nested=True) + + def test_plot_functional_model_with_splits_and_merges(self): + class SplitLayer(keras.Layer): + def call(self, x): + return list(keras.ops.split(x, 2, axis=1)) + + class ConcatLayer(keras.Layer): + def call(self, xs): + return keras.ops.concatenate(xs, axis=1) + + inputs = keras.Input((2,), name="input") + a, b = SplitLayer()(inputs) + + a = keras.layers.Dense(2, name="dense")(a) + b = keras.layers.Dense(2, name="dense_1")(b) + + outputs = ConcatLayer(name="concat_layer")([a, b]) + model = keras.Model(inputs, outputs) + + edge_dict = get_edge_dict(model_to_dot(model)) + self.assertEqual( + edge_dict, + { + "input (InputLayer)": "split_layer (SplitLayer)", + "split_layer (SplitLayer)": { + "dense (Dense)", + "dense_1 (Dense)", + }, + "dense (Dense)": "concat_layer (ConcatLayer)", + "dense_1 (Dense)": "concat_layer (ConcatLayer)", + }, + ) + self.multi_plot_model(model, "split-functional") + + def test_plot_sequential_in_sequential(self): + inner_model = keras.models.Sequential( + [ + keras.layers.Dense(10, name="dense2"), + keras.layers.Dense(10, name="dense3"), + ], + name="sub", + ) + model = keras.models.Sequential( + [ + keras.layers.Dense(10, name="dense1"), + inner_model, + ], + ) + model.build((1, 10)) + + # + # +-------------------------+ + # | dense1 (Dense) | + # +-------------------------+ + # | + # v + # +-------------------------+ + # | sub (Sequential) | + # +-------------------------+ + # + edge_dict = get_edge_dict(model_to_dot(model)) + self.assertEqual( + edge_dict, + { + "dense1 (Dense)": "sub (Sequential)", + }, + ) + self.multi_plot_model(model, "sequential_in_sequential") + + # + # +-------------------------+ + # | dense1 (Dense) | + # +-------------------------+ + # | + # +--------------|--------------+ + # | sub v | + # | +-------------------------+ | + # | | dense2 (Dense) | | + # | +-------------------------+ | + # | | | + # | v | + # | +-------------------------+ | + # | | dense3 (Dense) | | + # | +-------------------------+ | + # +-----------------------------+ + # + edge_dict = get_edge_dict(model_to_dot(model, expand_nested=True)) + self.assertEqual( + edge_dict, + { + "dense1 (Dense)": "sub > dense2 (Dense)", + "sub > dense2 (Dense)": "sub > dense3 (Dense)", + }, + ) + self.multi_plot_model( + model, "sequential_in_sequential", expand_nested=True + ) + + def test_plot_functional_in_functional(self): + inner_input = keras.layers.Input((10,), name="inner_input") + x = keras.layers.Dense(10, name="dense1")(inner_input) + x = keras.layers.Dense(10, name="dense2")(x) + inner_model = keras.models.Model(inner_input, x, name="inner") + + outer_input = keras.layers.Input((10,), name="outer_input") + model = keras.models.Model(outer_input, inner_model(outer_input)) + + # + # +-------------------------+ + # |outer_input (InputLayer) | + # +-------------------------+ + # | + # v + # +-------------------------+ + # | inner (Functional) | + # +-------------------------+ + # + edge_dict = get_edge_dict(model_to_dot(model)) + self.assertEqual( + edge_dict, + { + "outer_input (InputLayer)": "inner (Functional)", + }, + ) + self.multi_plot_model(model, "functional_in_functional") + + # + # +-------------------------+ + # |outer_input (InputLayer) | + # +-------------------------+ + # | + # +--------------|--------------+ + # | inner v | + # | +-------------------------+ | + # | |inner_input (InputLayer) | | + # | +-------------------------+ | + # | | | + # | v | + # | +-------------------------+ | + # | | dense1 (Dense) | | + # | +-------------------------+ | + # | | | + # | v | + # | +-------------------------+ | + # | | dense2 (Dense) | | + # | +-------------------------+ | + # +-----------------------------+ + # + edge_dict = get_edge_dict(model_to_dot(model, expand_nested=True)) + self.assertEqual( + edge_dict, + { + "outer_input (InputLayer)": "inner > inner_input (InputLayer)", + "inner > inner_input (InputLayer)": "inner > dense1 (Dense)", + "inner > dense1 (Dense)": "inner > dense2 (Dense)", + }, + ) + self.multi_plot_model( + model, "functional_in_functional", expand_nested=True + ) + + def test_plot_sequential_in_sequential_in_sequential(self): + inner_model = keras.models.Sequential( + [ + keras.layers.Dense(10, name="dense2"), + keras.layers.Dense(10, name="dense3"), + ], + name="inner", + ) + mid_model = keras.models.Sequential( + [ + inner_model, + ], + name="mid", + ) + model = keras.models.Sequential( + [ + keras.layers.Dense(10, name="dense1"), + mid_model, + keras.layers.Dense(10, name="dense4"), + ], + ) + model.build((1, 10)) + + # + # +-------------------------+ + # | dense1 (Dense) | + # +-------------------------+ + # | + # v + # +-------------------------+ + # | mid (Sequential) | + # +-------------------------+ + # | + # v + # +-------------------------+ + # | dense4 (Dense) | + # +-------------------------+ + # + edge_dict = get_edge_dict(model_to_dot(model)) + self.assertEqual( + edge_dict, + { + "dense1 (Dense)": "mid (Sequential)", + "mid (Sequential)": "dense4 (Dense)", + }, + ) + self.multi_plot_model(model, "sequential_in_sequential_in_sequential") + + # + # +-------------------------+ + # | dense1 (Dense) | + # +-------------------------+ + # | + # +----------------|----------------+ + # | mid | | + # | +--------------|--------------+ | + # | | inner v | | + # | | +-------------------------+ | | + # | | | dense2 (Dense) | | | + # | | +-------------------------+ | | + # | | | | | + # | | v | | + # | | +-------------------------+ | | + # | | | dense3 (Dense) | | | + # | | +-------------------------+ | | + # | +--------------|--------------+ | + # +----------------|----------------+ + # v + # +-------------------------+ + # | dense4 (Dense) | + # +-------------------------+ + # + edge_dict = get_edge_dict(model_to_dot(model, expand_nested=True)) + self.assertEqual( + edge_dict, + { + "dense1 (Dense)": "mid > inner > dense2 (Dense)", + "mid > inner > dense2 (Dense)": "mid > inner > dense3 (Dense)", + "mid > inner > dense3 (Dense)": "dense4 (Dense)", + }, + ) + self.multi_plot_model( + model, "sequential_in_sequential_in_sequential", expand_nested=True + ) + + def test_plot_functional_in_sequential_in_sequential(self): + input1 = keras.layers.Input((10,), name="input1") + x = keras.layers.Dense(10, name="dense2")(input1) + inner_model = keras.models.Model(input1, x, name="inner") + + mid_model = keras.models.Sequential( + [ + inner_model, + ], + name="mid", + ) + model = keras.models.Sequential( + [ + keras.layers.Dense(10, name="dense1"), + mid_model, + keras.layers.Dense(10, name="dense3"), + ], + ) + model.build((1, 10)) + + # + # +-------------------------+ + # | dense1 (Dense) | + # +-------------------------+ + # | + # v + # +-------------------------+ + # | mid (Sequential) | + # +-------------------------+ + # | + # v + # +-------------------------+ + # | dense3 (Dense) | + # +-------------------------+ + # + edge_dict = get_edge_dict(model_to_dot(model)) + self.assertEqual( + edge_dict, + { + "dense1 (Dense)": "mid (Sequential)", + "mid (Sequential)": "dense3 (Dense)", + }, + ) + self.multi_plot_model(model, "functional_in_sequential_in_sequential") + + # + # +-------------------------+ + # | dense1 (Dense) | + # +-------------------------+ + # | + # +----------------|----------------+ + # | mid | | + # | +--------------|--------------+ | + # | | inner v | | + # | | +-------------------------+ | | + # | | | input1 (Inputlayer) | | | + # | | +-------------------------+ | | + # | | | | | + # | | v | | + # | | +-------------------------+ | | + # | | | dense2 (Dense) | | | + # | | +-------------------------+ | | + # | +--------------|--------------+ | + # +----------------|----------------+ + # v + # +-------------------------+ + # | dense3 (Dense) | + # +-------------------------+ + # + edge_dict = get_edge_dict(model_to_dot(model, expand_nested=True)) + self.assertEqual( + edge_dict, + { + "dense1 (Dense)": "mid > inner > input1 (InputLayer)", + "mid > inner > input1 (InputLayer)": "mid > inner > dense2 (Dense)", # noqa: E501 + "mid > inner > dense2 (Dense)": "dense3 (Dense)", + }, + ) + self.multi_plot_model( + model, "functional_in_sequential_in_sequential", expand_nested=True + ) + + def test_plot_functional_in_functional_in_functional(self): + # From https://github.com/keras-team/keras/issues/21119 + inner_input = keras.layers.Input((10,), name="inner_input") + x = keras.layers.Dense(10, name="dense1")(inner_input) + inner_model = keras.models.Model(inner_input, x, name="inner") + + mid_input = keras.layers.Input((10,), name="mid_input") + mid_output = inner_model(mid_input) + mid_model = keras.models.Model(mid_input, mid_output, name="mid") + + outer_input = keras.layers.Input((10,), name="outer_input") + x = mid_model(outer_input) + x = keras.layers.Dense(10, name="dense2")(x) + model = keras.models.Model(outer_input, x) + + # + # +-------------------------+ + # |outer_input (InputLayer) | + # +-------------------------+ + # | + # v + # +-------------------------+ + # | mid (Functional) | + # +-------------------------+ + # | + # v + # +-------------------------+ + # | dense2 (Dense) | + # +-------------------------+ + # + edge_dict = get_edge_dict(model_to_dot(model)) + self.assertEqual( + edge_dict, + { + "outer_input (InputLayer)": "mid (Functional)", + "mid (Functional)": "dense2 (Dense)", + }, + ) + self.multi_plot_model(model, "functional_in_functional_in_functional") + + # + # +-------------------------+ + # |outer_input (InputLayer) | + # +-------------------------+ + # | + # +----------------|----------------+ + # | mid | | + # | +-------------------------+ | + # | | mid_input (Inputlayer) | | + # | +-------------------------+ | + # | +--------------|--------------+ | + # | | inner v | | + # | | +-------------------------+ | | + # | | |inner_input (Inputlayer) | | | + # | | +-------------------------+ | | + # | | | | | + # | | v | | + # | | +-------------------------+ | | + # | | | dense1 (Dense) | | | + # | | +-------------------------+ | | + # | +--------------|--------------+ | + # +----------------|----------------+ + # v + # +-------------------------+ + # | dense2 (Dense) | + # +-------------------------+ + # + edge_dict = get_edge_dict(model_to_dot(model, expand_nested=True)) + self.assertEqual( + edge_dict, + { + "outer_input (InputLayer)": "mid > mid_input (InputLayer)", + "mid > mid_input (InputLayer)": "mid > inner > inner_input (InputLayer)", # noqa: E501 + "mid > inner > inner_input (InputLayer)": "mid > inner > dense1 (Dense)", # noqa: E501 + "mid > inner > dense1 (Dense)": "dense2 (Dense)", + }, + ) + self.multi_plot_model( + model, "functional_in_functional_in_functional", expand_nested=True + ) + + def test_plot_complex(self): + # Note: this test exercises the case when `output_index` is not 0 and + # changes when going deeply in nested models to resolve the destination + # of an edge. + inner_inpt1 = keras.layers.Input(shape=(10,), name="inner_inpt1") + inner_inpt2 = keras.layers.Input(shape=(10,), name="inner_inpt2") + inner_model = keras.models.Model( + [inner_inpt1, inner_inpt2], + [ + keras.layers.Dense(10, name="dense1")(inner_inpt1), + keras.layers.Dense(10, name="dense2")(inner_inpt2), + ], + name="inner", + ) + + input0 = keras.layers.Input(shape=(10,), name="input0") + input1 = keras.layers.Input(shape=(10,), name="input1") + input2 = keras.layers.Input(shape=(10,), name="input2") + input3 = keras.layers.Input(shape=(10,), name="input3") + + mid_sequential = keras.models.Sequential( + [ + keras.layers.Dense(10, name="dense0"), + SubclassModel(name="subclass0"), + ], + name="seq", + ) + mid_subclass = SubclassModel(name="subclass3") + mid_model = keras.models.Model( + [input0, input1, input2, input3], + [ + mid_sequential(input0), + *inner_model([input1, input2]), + mid_subclass(input3), + ], + name="mid", + ) + + outer_input = keras.layers.Input((10,), name="outer_input") + mid_outputs = mid_model( + [outer_input, outer_input, outer_input, outer_input] + ) + model = keras.models.Model( + outer_input, + [ + keras.layers.Add(name="add1")([mid_outputs[0], mid_outputs[1]]), + keras.layers.Add(name="add2")([mid_outputs[2], mid_outputs[3]]), + ], + ) + + # + # +-------------------------+ + # |outer_input (InputLayer) | + # +-------------------------+ + # | + # v + # +-------------------------+ + # | mid (Functional) | + # +-------------------------+ + # | | + # v v + # +-------------------------+ +-------------------------+ + # | add1 (Add) | | add2 (Add) | + # +-------------------------+ +-------------------------+ + # + edge_dict = get_edge_dict(model_to_dot(model)) + self.assertEqual( + edge_dict, + { + "outer_input (InputLayer)": "mid (Functional)", + "mid (Functional)": {"add1 (Add)", "add2 (Add)"}, + }, + ) + self.multi_plot_model(model, "complex") + + # + # +-----------+ + # +------------------|outer_input|-----------------+ + # | +-----------+ | + # | | | | + # +---------|-------------------|---------|------------------|-------+ + # | mid v v v v | + # | +-----------+ +-----------+ +-----------+ +-----------+ | + # | | input0 | | input1 | | input2 | | input3 | | + # | +-----------+ +-----------+ +-----------+ +-----------+ | + # | +-------|-------+ +-------|-------------|-------+ | | + # | | seq v | | inner v v | | | + # | | +-----------+ | | +-----------+ +-----------+ | +-----------+ | + # | | | dense0 | | | |inner_inp1t| |inner_inp2t| | | subclass3 | | + # | | +-----------+ | | +-----------+ +-----------+ | +-----------+ | + # | | | | | | | | | | + # | | v | | v v | | | + # | | +-----------+ | | +-----------+ +-----------+ | | | + # | | | subclass0 | | | | dense1 | | dense2 | | | | + # | | +-----------+ | | +-----------+ +-----------+ | | | + # | +-----------|---+ +---|---------------------|---+ | | + # +-------------|---------|---------------------|--------|-----------+ + # v v v v + # +-----------+ +-----------+ + # | add1 | | add2 | + # +-----------+ +-----------+ + # + edge_dict = get_edge_dict(model_to_dot(model, expand_nested=True)) + self.assertEqual( + edge_dict, + { + # 1st row + "outer_input (InputLayer)": { + "mid > input0 (InputLayer)", + "mid > input1 (InputLayer)", + "mid > input2 (InputLayer)", + "mid > input3 (InputLayer)", + }, + # 2nd row + "mid > input0 (InputLayer)": "mid > seq > dense0 (Dense)", + "mid > input1 (InputLayer)": "mid > inner > inner_inpt1 (InputLayer)", # noqa: E501 + "mid > input2 (InputLayer)": "mid > inner > inner_inpt2 (InputLayer)", # noqa: E501 + "mid > input3 (InputLayer)": "mid > subclass3 (SubclassModel)", + # 3rd row + "mid > seq > dense0 (Dense)": "mid > seq > subclass0 (SubclassModel)", # noqa: E501 + "mid > inner > inner_inpt1 (InputLayer)": "mid > inner > dense1 (Dense)", # noqa: E501 + "mid > inner > inner_inpt2 (InputLayer)": "mid > inner > dense2 (Dense)", # noqa: E501 + # 4th row + "mid > seq > subclass0 (SubclassModel)": "add1 (Add)", + "mid > inner > dense1 (Dense)": "add1 (Add)", + "mid > inner > dense2 (Dense)": "add2 (Add)", + "mid > subclass3 (SubclassModel)": "add2 (Add)", + }, + ) + self.multi_plot_model(model, "complex", expand_nested=True) diff --git a/integration_tests/numerical_test.py b/integration_tests/numerical_test.py new file mode 100644 index 000000000000..39a077ff53c0 --- /dev/null +++ b/integration_tests/numerical_test.py @@ -0,0 +1,147 @@ +import keras # isort: skip, keep it on top for torch test + +import sys + +import numpy as np +import tf_keras + +keras.backend.set_image_data_format("channels_last") +tf_keras.backend.set_image_data_format("channels_last") + +NUM_CLASSES = 10 +BATCH_SIZE = 32 +EPOCHS = 1 + + +def build_mnist_data(num_classes): + (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data() + + # Scale images to the [0, 1] range + x_train = x_train.astype("float32") / 255 + x_test = x_test.astype("float32") / 255 + # Make sure images have shape (28, 28, 1) + x_train = np.expand_dims(x_train, -1) + x_test = np.expand_dims(x_test, -1) + + # convert class vectors to binary class matrices + y_train = keras.utils.to_categorical(y_train, num_classes) + y_test = keras.utils.to_categorical(y_test, num_classes) + + return x_train[:100], y_train[:100] + + +def build_keras_model(keras_module, num_classes): + input_shape = (28, 28, 1) + + model = keras_module.Sequential( + [ + keras_module.Input(shape=input_shape), + keras_module.layers.Conv2D( + 32, kernel_size=(3, 3), activation="relu" + ), + keras_module.layers.BatchNormalization(), + keras_module.layers.MaxPooling2D(pool_size=(2, 2)), + keras_module.layers.Conv2D( + 64, kernel_size=(3, 3), activation="relu" + ), + keras_module.layers.BatchNormalization(scale=False, center=True), + keras_module.layers.MaxPooling2D(pool_size=(2, 2)), + keras_module.layers.Flatten(), + keras_module.layers.Dense(num_classes, activation="softmax"), + ] + ) + return model + + +def compile_model(model): + model.compile( + loss="categorical_crossentropy", + optimizer="adam", + metrics=["mae", "accuracy"], + jit_compile=False, + run_eagerly=True, + ) + + +def train_model(model, x, y): + return model.fit( + x, + y, + batch_size=BATCH_SIZE, + epochs=EPOCHS, + shuffle=False, + verbose=0, + ) + + +def eval_model(model, x, y): + score = model.evaluate(x, y, verbose=0, batch_size=BATCH_SIZE) + print(score) + return score + + +def check_history(h1, h2): + for key in h1.history.keys(): + print(f"{key}:") + print(h1.history[key]) + print(h2.history[key]) + np.testing.assert_allclose( + h1.history[key], + h2.history[key], + atol=1e-3, + ) + + +def predict_model(model, x): + return model.predict(x, batch_size=BATCH_SIZE, verbose=0) + + +def numerical_test(): + x_train, y_train = build_mnist_data(NUM_CLASSES) + keras_model = build_keras_model(keras, NUM_CLASSES) + tf_keras_model = build_keras_model(tf_keras, NUM_CLASSES) + + # Make sure both model have same weights before training + weights = [weight.numpy() for weight in keras_model.weights] + tf_keras_model.set_weights(weights) + + for kw, kcw in zip(keras_model.weights, tf_keras_model.weights): + np.testing.assert_allclose(kw.numpy(), kcw.numpy()) + + compile_model(keras_model) + compile_model(tf_keras_model) + + print("Checking training histories:") + keras_history = train_model(keras_model, x_train, y_train) + tf_keras_history = train_model(tf_keras_model, x_train, y_train) + check_history(keras_history, tf_keras_history) + print("Training histories match.") + print() + + print("Checking trained weights:") + for kw, kcw in zip(keras_model.weights, tf_keras_model.weights): + np.testing.assert_allclose(kw.numpy(), kcw.numpy(), atol=1e-3) + print("Trained weights match.") + print() + + print("Checking predict:") + outputs1 = predict_model(keras_model, x_train) + outputs2 = predict_model(tf_keras_model, x_train) + np.testing.assert_allclose(outputs1, outputs2, atol=1e-3) + print("Predict results match.") + print() + + print("Checking evaluate:") + score1 = eval_model(keras_model, x_train, y_train) + score2 = eval_model(tf_keras_model, x_train, y_train) + np.testing.assert_allclose(score1, score2, atol=1e-3) + print("Evaluate results match.") + + +if __name__ == "__main__": + if keras.backend.backend() == "openvino": + # this test requires trainable backend + sys.exit(0) + keras.utils.set_random_seed(1337) + tf_keras.utils.set_random_seed(1337) + numerical_test() diff --git a/integration_tests/tf_custom_fit_test.py b/integration_tests/tf_custom_fit_test.py new file mode 100644 index 000000000000..c409a7033b27 --- /dev/null +++ b/integration_tests/tf_custom_fit_test.py @@ -0,0 +1,50 @@ +import numpy as np +import tensorflow as tf + +import keras + + +def test_custom_fit(): + class CustomModel(keras.Model): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.loss_tracker = keras.metrics.Mean(name="loss") + self.mae_metric = keras.metrics.MeanAbsoluteError(name="mae") + self.loss_fn = keras.losses.MeanSquaredError() + + def train_step(self, data): + x, y = data + with tf.GradientTape() as tape: + y_pred = self(x, training=True) + loss = self.loss_fn(y, y_pred) + trainable_vars = self.trainable_variables + gradients = tape.gradient(loss, trainable_vars) + self.optimizer.apply(gradients, trainable_vars) + self.loss_tracker.update_state(loss) + self.mae_metric.update_state(y, y_pred) + return { + "loss": self.loss_tracker.result(), + "mae": self.mae_metric.result(), + } + + @property + def metrics(self): + return [self.loss_tracker, self.mae_metric] + + inputs = keras.Input(shape=(32,)) + outputs = keras.layers.Dense(1)(inputs) + model = CustomModel(inputs, outputs) + model.compile(optimizer="adam") + x = np.random.random((64, 32)) + y = np.random.random((64, 1)) + history = model.fit(x, y, epochs=1) + + assert "loss" in history.history + assert "mae" in history.history + + print("History:") + print(history.history) + + +if __name__ == "__main__": + test_custom_fit() diff --git a/integration_tests/tf_distribute_training_test.py b/integration_tests/tf_distribute_training_test.py new file mode 100644 index 000000000000..ec2a7d5bfb92 --- /dev/null +++ b/integration_tests/tf_distribute_training_test.py @@ -0,0 +1,76 @@ +import numpy as np +import tensorflow as tf + +import keras +from keras.src import layers +from keras.src import losses +from keras.src import metrics +from keras.src import models +from keras.src import optimizers +from keras.src.callbacks import LearningRateScheduler + + +def test_model_fit(): + cpus = tf.config.list_physical_devices("CPU") + tf.config.set_logical_device_configuration( + cpus[0], + [ + tf.config.LogicalDeviceConfiguration(), + tf.config.LogicalDeviceConfiguration(), + ], + ) + + keras.utils.set_random_seed(1337) + + strategy = tf.distribute.MirroredStrategy(["CPU:0", "CPU:1"]) + with strategy.scope(): + inputs = layers.Input((100,), batch_size=32) + x = layers.Dense(256, activation="relu")(inputs) + x = layers.Dense(256, activation="relu")(x) + x = layers.Dense(256, activation="relu")(x) + x = layers.BatchNormalization()(x) + outputs = layers.Dense(16)(x) + model = models.Model(inputs, outputs) + + callbacks = [LearningRateScheduler(lambda _: 0.1)] + + model.summary() + + x = np.random.random((5000, 100)) + y = np.random.random((5000, 16)) + batch_size = 32 + epochs = 2 + + # Fit from numpy arrays: + with strategy.scope(): + model.compile( + optimizer=optimizers.LossScaleOptimizer( + optimizers.SGD(learning_rate=0.001, momentum=0.01) + ), + loss=losses.MeanSquaredError(), + metrics=[metrics.MeanSquaredError()], + ) + history = model.fit( + x, + y, + batch_size=batch_size, + epochs=epochs, + validation_split=0.2, + callbacks=callbacks, + ) + + print("History:") + print(history.history) + + # Fit again from distributed dataset: + with strategy.scope(): + dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(batch_size) + dataset = strategy.experimental_distribute_dataset(dataset) + history = model.fit(dataset, epochs=epochs, callbacks=callbacks) + + print("History:") + print(history.history) + + +if __name__ == "__main__": + test_model_fit() diff --git a/integration_tests/torch_custom_fit_test.py b/integration_tests/torch_custom_fit_test.py new file mode 100644 index 000000000000..24201eab1e80 --- /dev/null +++ b/integration_tests/torch_custom_fit_test.py @@ -0,0 +1,52 @@ +import numpy as np +import torch + +import keras + + +def test_custom_fit(): + class CustomModel(keras.Model): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.loss_tracker = keras.metrics.Mean(name="loss") + self.mae_metric = keras.metrics.MeanAbsoluteError(name="mae") + self.loss_fn = keras.losses.MeanSquaredError() + + def train_step(self, data): + x, y = data + self.zero_grad() + y_pred = self(x, training=True) + loss = self.loss_fn(y, y_pred) + loss.backward() + trainable_weights = [v for v in self.trainable_weights] + gradients = [v.value.grad for v in trainable_weights] + with torch.no_grad(): + self.optimizer.apply(gradients, trainable_weights) + self.loss_tracker.update_state(loss) + self.mae_metric.update_state(y, y_pred) + return { + "loss": self.loss_tracker.result(), + "mae": self.mae_metric.result(), + } + + @property + def metrics(self): + return [self.loss_tracker, self.mae_metric] + + inputs = keras.Input(shape=(32,)) + outputs = keras.layers.Dense(1)(inputs) + model = CustomModel(inputs, outputs) + model.compile(optimizer="adam") + x = np.random.random((64, 32)) + y = np.random.random((64, 1)) + history = model.fit(x, y, epochs=1) + + assert "loss" in history.history + assert "mae" in history.history + + print("History:") + print(history.history) + + +if __name__ == "__main__": + test_custom_fit() diff --git a/integration_tests/torch_workflow_test.py b/integration_tests/torch_workflow_test.py new file mode 100644 index 000000000000..3737197b86e5 --- /dev/null +++ b/integration_tests/torch_workflow_test.py @@ -0,0 +1,34 @@ +import torch + +from keras.src import layers +from keras.src import testing +from keras.src.backend.common import KerasVariable + + +class Net(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc1 = layers.Dense(1) + + def forward(self, x): + x = self.fc1(x) + return x + + +class TorchWorkflowTest(testing.TestCase): + def test_keras_layer_in_nn_module(self): + net = Net() + + # Test using Keras layer in a nn.Module. + # Test forward pass + self.assertAllEqual(list(net(torch.empty(100, 10)).shape), [100, 1]) + # Test KerasVariables are added as nn.Parameter. + self.assertLen(list(net.parameters()), 2) + + # Test using KerasVariable as a torch tensor for torch ops. + kernel = net.fc1.kernel + transposed_kernel = torch.transpose(kernel, 0, 1) + self.assertIsInstance(kernel, KerasVariable) + self.assertIsInstance( + torch.mul(kernel, transposed_kernel), torch.Tensor + ) diff --git a/keras/__init__.py b/keras/__init__.py index e69de29bb2d1..0dc0f6aad102 100644 --- a/keras/__init__.py +++ b/keras/__init__.py @@ -0,0 +1,13 @@ +# This file should NEVER be packaged! This is a hack to make "import keras" from +# the base of the repo just import the source files. We'll keep it for compat. + +import os # isort: skip + +# Add everything in /api/ to the module search path. +__path__.append(os.path.join(os.path.dirname(__file__), "api")) # noqa: F405 + +from keras.api import * # noqa: F403, E402 +from keras.api import __version__ # noqa: E402 + +# Don't pollute namespace. +del os diff --git a/keras/activations.py b/keras/activations.py deleted file mode 100644 index 0e7f19ddd657..000000000000 --- a/keras/activations.py +++ /dev/null @@ -1,28 +0,0 @@ -import theano -import theano.tensor as T -import types - -def softmax(x): - return T.nnet.softmax(x) - -def softplus(x): - return T.nnet.softplus(x) - -def relu(x): - return (x + abs(x)) / 2.0 - -def tanh(x): - return T.tanh(x) - -def sigmoid(x): - return T.nnet.sigmoid(x) - -def hard_sigmoid(x): - return T.nnet.hard_sigmoid(x) - -def linear(x): - return x - -from utils.generic_utils import get_from_module -def get(identifier): - return get_from_module(identifier, globals(), 'activation function') \ No newline at end of file diff --git a/keras/api/__init__.py b/keras/api/__init__.py new file mode 100644 index 000000000000..dee6cea5bb19 --- /dev/null +++ b/keras/api/__init__.py @@ -0,0 +1,66 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras import _tf_keras as _tf_keras +from keras import activations as activations +from keras import applications as applications +from keras import backend as backend +from keras import callbacks as callbacks +from keras import config as config +from keras import constraints as constraints +from keras import datasets as datasets +from keras import distribution as distribution +from keras import dtype_policies as dtype_policies +from keras import export as export +from keras import initializers as initializers +from keras import layers as layers +from keras import legacy as legacy +from keras import losses as losses +from keras import metrics as metrics +from keras import mixed_precision as mixed_precision +from keras import models as models +from keras import ops as ops +from keras import optimizers as optimizers +from keras import preprocessing as preprocessing +from keras import quantizers as quantizers +from keras import random as random +from keras import regularizers as regularizers +from keras import saving as saving +from keras import tree as tree +from keras import utils as utils +from keras import visualization as visualization +from keras import wrappers as wrappers +from keras.src.backend import Variable as Variable +from keras.src.backend import device as device +from keras.src.backend import name_scope as name_scope +from keras.src.backend.common.keras_tensor import KerasTensor as KerasTensor +from keras.src.backend.common.remat import RematScope as RematScope +from keras.src.backend.common.remat import remat as remat +from keras.src.backend.common.stateless_scope import ( + StatelessScope as StatelessScope, +) +from keras.src.backend.common.symbolic_scope import ( + SymbolicScope as SymbolicScope, +) +from keras.src.dtype_policies.dtype_policy import DTypePolicy as DTypePolicy +from keras.src.dtype_policies.dtype_policy import ( + FloatDTypePolicy as FloatDTypePolicy, +) +from keras.src.initializers.initializer import Initializer as Initializer +from keras.src.layers.core.input_layer import Input as Input +from keras.src.layers.input_spec import InputSpec as InputSpec +from keras.src.layers.layer import Layer as Layer +from keras.src.losses.loss import Loss as Loss +from keras.src.metrics.metric import Metric as Metric +from keras.src.models.model import Model as Model +from keras.src.models.sequential import Sequential as Sequential +from keras.src.ops.function import Function as Function +from keras.src.ops.operation import Operation as Operation +from keras.src.optimizers.optimizer import Optimizer as Optimizer +from keras.src.quantizers.quantizers import Quantizer as Quantizer +from keras.src.regularizers.regularizers import Regularizer as Regularizer +from keras.src.version import __version__ as __version__ +from keras.src.version import version as version diff --git a/keras/api/_tf_keras/__init__.py b/keras/api/_tf_keras/__init__.py new file mode 100644 index 000000000000..4c0e16d122e4 --- /dev/null +++ b/keras/api/_tf_keras/__init__.py @@ -0,0 +1 @@ +from keras._tf_keras import keras diff --git a/keras/api/_tf_keras/keras/__init__.py b/keras/api/_tf_keras/keras/__init__.py new file mode 100644 index 000000000000..67d4738a0f3c --- /dev/null +++ b/keras/api/_tf_keras/keras/__init__.py @@ -0,0 +1,64 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras import activations as activations +from keras import applications as applications +from keras import callbacks as callbacks +from keras import config as config +from keras import constraints as constraints +from keras import datasets as datasets +from keras import distribution as distribution +from keras import dtype_policies as dtype_policies +from keras import export as export +from keras import initializers as initializers +from keras import legacy as legacy +from keras import mixed_precision as mixed_precision +from keras import models as models +from keras import ops as ops +from keras import optimizers as optimizers +from keras import quantizers as quantizers +from keras import random as random +from keras import regularizers as regularizers +from keras import tree as tree +from keras import utils as utils +from keras import visualization as visualization +from keras import wrappers as wrappers +from keras._tf_keras.keras import backend as backend +from keras._tf_keras.keras import layers as layers +from keras._tf_keras.keras import losses as losses +from keras._tf_keras.keras import metrics as metrics +from keras._tf_keras.keras import preprocessing as preprocessing +from keras.src.backend import Variable as Variable +from keras.src.backend import device as device +from keras.src.backend import name_scope as name_scope +from keras.src.backend.common.keras_tensor import KerasTensor as KerasTensor +from keras.src.backend.common.remat import RematScope as RematScope +from keras.src.backend.common.remat import remat as remat +from keras.src.backend.common.stateless_scope import ( + StatelessScope as StatelessScope, +) +from keras.src.backend.common.symbolic_scope import ( + SymbolicScope as SymbolicScope, +) +from keras.src.dtype_policies.dtype_policy import DTypePolicy as DTypePolicy +from keras.src.dtype_policies.dtype_policy import ( + FloatDTypePolicy as FloatDTypePolicy, +) +from keras.src.initializers.initializer import Initializer as Initializer +from keras.src.layers.core.input_layer import Input as Input +from keras.src.layers.input_spec import InputSpec as InputSpec +from keras.src.layers.layer import Layer as Layer +from keras.src.losses.loss import Loss as Loss +from keras.src.metrics.metric import Metric as Metric +from keras.src.models.model import Model as Model +from keras.src.models.sequential import Sequential as Sequential +from keras.src.ops.function import Function as Function +from keras.src.ops.operation import Operation as Operation +from keras.src.optimizers.optimizer import Optimizer as Optimizer +from keras.src.quantizers.quantizers import Quantizer as Quantizer +from keras.src.regularizers.regularizers import Regularizer as Regularizer +from keras.src.version import __version__ as __version__ +from keras.src.version import version as version diff --git a/keras/api/_tf_keras/keras/activations/__init__.py b/keras/api/_tf_keras/keras/activations/__init__.py new file mode 100644 index 000000000000..85ae031a72dc --- /dev/null +++ b/keras/api/_tf_keras/keras/activations/__init__.py @@ -0,0 +1,41 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.activations import deserialize as deserialize +from keras.src.activations import get as get +from keras.src.activations import serialize as serialize +from keras.src.activations.activations import celu as celu +from keras.src.activations.activations import elu as elu +from keras.src.activations.activations import exponential as exponential +from keras.src.activations.activations import gelu as gelu +from keras.src.activations.activations import glu as glu +from keras.src.activations.activations import hard_shrink as hard_shrink +from keras.src.activations.activations import hard_sigmoid as hard_sigmoid +from keras.src.activations.activations import hard_silu as hard_silu +from keras.src.activations.activations import hard_silu as hard_swish +from keras.src.activations.activations import hard_tanh as hard_tanh +from keras.src.activations.activations import leaky_relu as leaky_relu +from keras.src.activations.activations import linear as linear +from keras.src.activations.activations import log_sigmoid as log_sigmoid +from keras.src.activations.activations import log_softmax as log_softmax +from keras.src.activations.activations import mish as mish +from keras.src.activations.activations import relu as relu +from keras.src.activations.activations import relu6 as relu6 +from keras.src.activations.activations import selu as selu +from keras.src.activations.activations import sigmoid as sigmoid +from keras.src.activations.activations import silu as silu +from keras.src.activations.activations import silu as swish +from keras.src.activations.activations import soft_shrink as soft_shrink +from keras.src.activations.activations import softmax as softmax +from keras.src.activations.activations import softplus as softplus +from keras.src.activations.activations import softsign as softsign +from keras.src.activations.activations import sparse_plus as sparse_plus +from keras.src.activations.activations import sparse_sigmoid as sparse_sigmoid +from keras.src.activations.activations import sparsemax as sparsemax +from keras.src.activations.activations import squareplus as squareplus +from keras.src.activations.activations import tanh as tanh +from keras.src.activations.activations import tanh_shrink as tanh_shrink +from keras.src.activations.activations import threshold as threshold diff --git a/keras/api/_tf_keras/keras/applications/__init__.py b/keras/api/_tf_keras/keras/applications/__init__.py new file mode 100644 index 000000000000..7c030b36bd4e --- /dev/null +++ b/keras/api/_tf_keras/keras/applications/__init__.py @@ -0,0 +1,83 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.applications import convnext as convnext +from keras.applications import densenet as densenet +from keras.applications import efficientnet as efficientnet +from keras.applications import efficientnet_v2 as efficientnet_v2 +from keras.applications import imagenet_utils as imagenet_utils +from keras.applications import inception_resnet_v2 as inception_resnet_v2 +from keras.applications import inception_v3 as inception_v3 +from keras.applications import mobilenet as mobilenet +from keras.applications import mobilenet_v2 as mobilenet_v2 +from keras.applications import mobilenet_v3 as mobilenet_v3 +from keras.applications import nasnet as nasnet +from keras.applications import resnet as resnet +from keras.applications import resnet50 as resnet50 +from keras.applications import resnet_v2 as resnet_v2 +from keras.applications import vgg16 as vgg16 +from keras.applications import vgg19 as vgg19 +from keras.applications import xception as xception +from keras.src.applications.convnext import ConvNeXtBase as ConvNeXtBase +from keras.src.applications.convnext import ConvNeXtLarge as ConvNeXtLarge +from keras.src.applications.convnext import ConvNeXtSmall as ConvNeXtSmall +from keras.src.applications.convnext import ConvNeXtTiny as ConvNeXtTiny +from keras.src.applications.convnext import ConvNeXtXLarge as ConvNeXtXLarge +from keras.src.applications.densenet import DenseNet121 as DenseNet121 +from keras.src.applications.densenet import DenseNet169 as DenseNet169 +from keras.src.applications.densenet import DenseNet201 as DenseNet201 +from keras.src.applications.efficientnet import EfficientNetB0 as EfficientNetB0 +from keras.src.applications.efficientnet import EfficientNetB1 as EfficientNetB1 +from keras.src.applications.efficientnet import EfficientNetB2 as EfficientNetB2 +from keras.src.applications.efficientnet import EfficientNetB3 as EfficientNetB3 +from keras.src.applications.efficientnet import EfficientNetB4 as EfficientNetB4 +from keras.src.applications.efficientnet import EfficientNetB5 as EfficientNetB5 +from keras.src.applications.efficientnet import EfficientNetB6 as EfficientNetB6 +from keras.src.applications.efficientnet import EfficientNetB7 as EfficientNetB7 +from keras.src.applications.efficientnet_v2 import ( + EfficientNetV2B0 as EfficientNetV2B0, +) +from keras.src.applications.efficientnet_v2 import ( + EfficientNetV2B1 as EfficientNetV2B1, +) +from keras.src.applications.efficientnet_v2 import ( + EfficientNetV2B2 as EfficientNetV2B2, +) +from keras.src.applications.efficientnet_v2 import ( + EfficientNetV2B3 as EfficientNetV2B3, +) +from keras.src.applications.efficientnet_v2 import ( + EfficientNetV2L as EfficientNetV2L, +) +from keras.src.applications.efficientnet_v2 import ( + EfficientNetV2M as EfficientNetV2M, +) +from keras.src.applications.efficientnet_v2 import ( + EfficientNetV2S as EfficientNetV2S, +) +from keras.src.applications.inception_resnet_v2 import ( + InceptionResNetV2 as InceptionResNetV2, +) +from keras.src.applications.inception_v3 import InceptionV3 as InceptionV3 +from keras.src.applications.mobilenet import MobileNet as MobileNet +from keras.src.applications.mobilenet_v2 import MobileNetV2 as MobileNetV2 +from keras.src.applications.mobilenet_v3 import ( + MobileNetV3Large as MobileNetV3Large, +) +from keras.src.applications.mobilenet_v3 import ( + MobileNetV3Small as MobileNetV3Small, +) +from keras.src.applications.nasnet import NASNetLarge as NASNetLarge +from keras.src.applications.nasnet import NASNetMobile as NASNetMobile +from keras.src.applications.resnet import ResNet50 as ResNet50 +from keras.src.applications.resnet import ResNet101 as ResNet101 +from keras.src.applications.resnet import ResNet152 as ResNet152 +from keras.src.applications.resnet_v2 import ResNet50V2 as ResNet50V2 +from keras.src.applications.resnet_v2 import ResNet101V2 as ResNet101V2 +from keras.src.applications.resnet_v2 import ResNet152V2 as ResNet152V2 +from keras.src.applications.vgg16 import VGG16 as VGG16 +from keras.src.applications.vgg19 import VGG19 as VGG19 +from keras.src.applications.xception import Xception as Xception diff --git a/keras/api/_tf_keras/keras/applications/convnext/__init__.py b/keras/api/_tf_keras/keras/applications/convnext/__init__.py new file mode 100644 index 000000000000..c6d7bb7117e8 --- /dev/null +++ b/keras/api/_tf_keras/keras/applications/convnext/__init__.py @@ -0,0 +1,15 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.applications.convnext import ConvNeXtBase as ConvNeXtBase +from keras.src.applications.convnext import ConvNeXtLarge as ConvNeXtLarge +from keras.src.applications.convnext import ConvNeXtSmall as ConvNeXtSmall +from keras.src.applications.convnext import ConvNeXtTiny as ConvNeXtTiny +from keras.src.applications.convnext import ConvNeXtXLarge as ConvNeXtXLarge +from keras.src.applications.convnext import ( + decode_predictions as decode_predictions, +) +from keras.src.applications.convnext import preprocess_input as preprocess_input diff --git a/keras/api/_tf_keras/keras/applications/densenet/__init__.py b/keras/api/_tf_keras/keras/applications/densenet/__init__.py new file mode 100644 index 000000000000..6d6a27101099 --- /dev/null +++ b/keras/api/_tf_keras/keras/applications/densenet/__init__.py @@ -0,0 +1,13 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.applications.densenet import DenseNet121 as DenseNet121 +from keras.src.applications.densenet import DenseNet169 as DenseNet169 +from keras.src.applications.densenet import DenseNet201 as DenseNet201 +from keras.src.applications.densenet import ( + decode_predictions as decode_predictions, +) +from keras.src.applications.densenet import preprocess_input as preprocess_input diff --git a/keras/api/_tf_keras/keras/applications/efficientnet/__init__.py b/keras/api/_tf_keras/keras/applications/efficientnet/__init__.py new file mode 100644 index 000000000000..16384b74e2b2 --- /dev/null +++ b/keras/api/_tf_keras/keras/applications/efficientnet/__init__.py @@ -0,0 +1,20 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.applications.efficientnet import EfficientNetB0 as EfficientNetB0 +from keras.src.applications.efficientnet import EfficientNetB1 as EfficientNetB1 +from keras.src.applications.efficientnet import EfficientNetB2 as EfficientNetB2 +from keras.src.applications.efficientnet import EfficientNetB3 as EfficientNetB3 +from keras.src.applications.efficientnet import EfficientNetB4 as EfficientNetB4 +from keras.src.applications.efficientnet import EfficientNetB5 as EfficientNetB5 +from keras.src.applications.efficientnet import EfficientNetB6 as EfficientNetB6 +from keras.src.applications.efficientnet import EfficientNetB7 as EfficientNetB7 +from keras.src.applications.efficientnet import ( + decode_predictions as decode_predictions, +) +from keras.src.applications.efficientnet import ( + preprocess_input as preprocess_input, +) diff --git a/keras/api/_tf_keras/keras/applications/efficientnet_v2/__init__.py b/keras/api/_tf_keras/keras/applications/efficientnet_v2/__init__.py new file mode 100644 index 000000000000..8d13352008b6 --- /dev/null +++ b/keras/api/_tf_keras/keras/applications/efficientnet_v2/__init__.py @@ -0,0 +1,33 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.applications.efficientnet_v2 import ( + EfficientNetV2B0 as EfficientNetV2B0, +) +from keras.src.applications.efficientnet_v2 import ( + EfficientNetV2B1 as EfficientNetV2B1, +) +from keras.src.applications.efficientnet_v2 import ( + EfficientNetV2B2 as EfficientNetV2B2, +) +from keras.src.applications.efficientnet_v2 import ( + EfficientNetV2B3 as EfficientNetV2B3, +) +from keras.src.applications.efficientnet_v2 import ( + EfficientNetV2L as EfficientNetV2L, +) +from keras.src.applications.efficientnet_v2 import ( + EfficientNetV2M as EfficientNetV2M, +) +from keras.src.applications.efficientnet_v2 import ( + EfficientNetV2S as EfficientNetV2S, +) +from keras.src.applications.efficientnet_v2 import ( + decode_predictions as decode_predictions, +) +from keras.src.applications.efficientnet_v2 import ( + preprocess_input as preprocess_input, +) diff --git a/keras/api/_tf_keras/keras/applications/imagenet_utils/__init__.py b/keras/api/_tf_keras/keras/applications/imagenet_utils/__init__.py new file mode 100644 index 000000000000..66804964efbe --- /dev/null +++ b/keras/api/_tf_keras/keras/applications/imagenet_utils/__init__.py @@ -0,0 +1,12 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.applications.imagenet_utils import ( + decode_predictions as decode_predictions, +) +from keras.src.applications.imagenet_utils import ( + preprocess_input as preprocess_input, +) diff --git a/keras/api/_tf_keras/keras/applications/inception_resnet_v2/__init__.py b/keras/api/_tf_keras/keras/applications/inception_resnet_v2/__init__.py new file mode 100644 index 000000000000..4cb545a39fe1 --- /dev/null +++ b/keras/api/_tf_keras/keras/applications/inception_resnet_v2/__init__.py @@ -0,0 +1,15 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.applications.inception_resnet_v2 import ( + InceptionResNetV2 as InceptionResNetV2, +) +from keras.src.applications.inception_resnet_v2 import ( + decode_predictions as decode_predictions, +) +from keras.src.applications.inception_resnet_v2 import ( + preprocess_input as preprocess_input, +) diff --git a/keras/api/_tf_keras/keras/applications/inception_v3/__init__.py b/keras/api/_tf_keras/keras/applications/inception_v3/__init__.py new file mode 100644 index 000000000000..a7db7bd80ce8 --- /dev/null +++ b/keras/api/_tf_keras/keras/applications/inception_v3/__init__.py @@ -0,0 +1,13 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.applications.inception_v3 import InceptionV3 as InceptionV3 +from keras.src.applications.inception_v3 import ( + decode_predictions as decode_predictions, +) +from keras.src.applications.inception_v3 import ( + preprocess_input as preprocess_input, +) diff --git a/keras/api/_tf_keras/keras/applications/mobilenet/__init__.py b/keras/api/_tf_keras/keras/applications/mobilenet/__init__.py new file mode 100644 index 000000000000..6e721019c42e --- /dev/null +++ b/keras/api/_tf_keras/keras/applications/mobilenet/__init__.py @@ -0,0 +1,13 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.applications.mobilenet import MobileNet as MobileNet +from keras.src.applications.mobilenet import ( + decode_predictions as decode_predictions, +) +from keras.src.applications.mobilenet import ( + preprocess_input as preprocess_input, +) diff --git a/keras/api/_tf_keras/keras/applications/mobilenet_v2/__init__.py b/keras/api/_tf_keras/keras/applications/mobilenet_v2/__init__.py new file mode 100644 index 000000000000..15ebaa3155a6 --- /dev/null +++ b/keras/api/_tf_keras/keras/applications/mobilenet_v2/__init__.py @@ -0,0 +1,13 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.applications.mobilenet_v2 import MobileNetV2 as MobileNetV2 +from keras.src.applications.mobilenet_v2 import ( + decode_predictions as decode_predictions, +) +from keras.src.applications.mobilenet_v2 import ( + preprocess_input as preprocess_input, +) diff --git a/keras/api/_tf_keras/keras/applications/mobilenet_v3/__init__.py b/keras/api/_tf_keras/keras/applications/mobilenet_v3/__init__.py new file mode 100644 index 000000000000..a5abb926247c --- /dev/null +++ b/keras/api/_tf_keras/keras/applications/mobilenet_v3/__init__.py @@ -0,0 +1,12 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.applications.mobilenet_v3 import ( + decode_predictions as decode_predictions, +) +from keras.src.applications.mobilenet_v3 import ( + preprocess_input as preprocess_input, +) diff --git a/keras/api/_tf_keras/keras/applications/nasnet/__init__.py b/keras/api/_tf_keras/keras/applications/nasnet/__init__.py new file mode 100644 index 000000000000..c831e135fbd6 --- /dev/null +++ b/keras/api/_tf_keras/keras/applications/nasnet/__init__.py @@ -0,0 +1,12 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.applications.nasnet import NASNetLarge as NASNetLarge +from keras.src.applications.nasnet import NASNetMobile as NASNetMobile +from keras.src.applications.nasnet import ( + decode_predictions as decode_predictions, +) +from keras.src.applications.nasnet import preprocess_input as preprocess_input diff --git a/keras/api/_tf_keras/keras/applications/resnet/__init__.py b/keras/api/_tf_keras/keras/applications/resnet/__init__.py new file mode 100644 index 000000000000..b8a25644e1d9 --- /dev/null +++ b/keras/api/_tf_keras/keras/applications/resnet/__init__.py @@ -0,0 +1,13 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.applications.resnet import ResNet50 as ResNet50 +from keras.src.applications.resnet import ResNet101 as ResNet101 +from keras.src.applications.resnet import ResNet152 as ResNet152 +from keras.src.applications.resnet import ( + decode_predictions as decode_predictions, +) +from keras.src.applications.resnet import preprocess_input as preprocess_input diff --git a/keras/api/_tf_keras/keras/applications/resnet50/__init__.py b/keras/api/_tf_keras/keras/applications/resnet50/__init__.py new file mode 100644 index 000000000000..6cff78c6749c --- /dev/null +++ b/keras/api/_tf_keras/keras/applications/resnet50/__init__.py @@ -0,0 +1,11 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.applications.resnet import ResNet50 as ResNet50 +from keras.src.applications.resnet import ( + decode_predictions as decode_predictions, +) +from keras.src.applications.resnet import preprocess_input as preprocess_input diff --git a/keras/api/_tf_keras/keras/applications/resnet_v2/__init__.py b/keras/api/_tf_keras/keras/applications/resnet_v2/__init__.py new file mode 100644 index 000000000000..7f92dd56f374 --- /dev/null +++ b/keras/api/_tf_keras/keras/applications/resnet_v2/__init__.py @@ -0,0 +1,15 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.applications.resnet_v2 import ResNet50V2 as ResNet50V2 +from keras.src.applications.resnet_v2 import ResNet101V2 as ResNet101V2 +from keras.src.applications.resnet_v2 import ResNet152V2 as ResNet152V2 +from keras.src.applications.resnet_v2 import ( + decode_predictions as decode_predictions, +) +from keras.src.applications.resnet_v2 import ( + preprocess_input as preprocess_input, +) diff --git a/keras/api/_tf_keras/keras/applications/vgg16/__init__.py b/keras/api/_tf_keras/keras/applications/vgg16/__init__.py new file mode 100644 index 000000000000..17fb30585d9a --- /dev/null +++ b/keras/api/_tf_keras/keras/applications/vgg16/__init__.py @@ -0,0 +1,11 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.applications.vgg16 import VGG16 as VGG16 +from keras.src.applications.vgg16 import ( + decode_predictions as decode_predictions, +) +from keras.src.applications.vgg16 import preprocess_input as preprocess_input diff --git a/keras/api/_tf_keras/keras/applications/vgg19/__init__.py b/keras/api/_tf_keras/keras/applications/vgg19/__init__.py new file mode 100644 index 000000000000..83f865b3876b --- /dev/null +++ b/keras/api/_tf_keras/keras/applications/vgg19/__init__.py @@ -0,0 +1,11 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.applications.vgg19 import VGG19 as VGG19 +from keras.src.applications.vgg19 import ( + decode_predictions as decode_predictions, +) +from keras.src.applications.vgg19 import preprocess_input as preprocess_input diff --git a/keras/api/_tf_keras/keras/applications/xception/__init__.py b/keras/api/_tf_keras/keras/applications/xception/__init__.py new file mode 100644 index 000000000000..09a5859aab4b --- /dev/null +++ b/keras/api/_tf_keras/keras/applications/xception/__init__.py @@ -0,0 +1,11 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.applications.xception import Xception as Xception +from keras.src.applications.xception import ( + decode_predictions as decode_predictions, +) +from keras.src.applications.xception import preprocess_input as preprocess_input diff --git a/keras/api/_tf_keras/keras/backend/__init__.py b/keras/api/_tf_keras/keras/backend/__init__.py new file mode 100644 index 000000000000..cd9037bcf4d6 --- /dev/null +++ b/keras/api/_tf_keras/keras/backend/__init__.py @@ -0,0 +1,165 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.backend.common.dtypes import result_type as result_type +from keras.src.backend.common.global_state import clear_session as clear_session +from keras.src.backend.common.keras_tensor import ( + is_keras_tensor as is_keras_tensor, +) +from keras.src.backend.common.variables import is_float_dtype as is_float_dtype +from keras.src.backend.common.variables import is_int_dtype as is_int_dtype +from keras.src.backend.common.variables import ( + standardize_dtype as standardize_dtype, +) +from keras.src.backend.config import backend as backend +from keras.src.backend.config import epsilon as epsilon +from keras.src.backend.config import floatx as floatx +from keras.src.backend.config import image_data_format as image_data_format +from keras.src.backend.config import set_epsilon as set_epsilon +from keras.src.backend.config import set_floatx as set_floatx +from keras.src.backend.config import ( + set_image_data_format as set_image_data_format, +) +from keras.src.legacy.backend import abs as abs +from keras.src.legacy.backend import all as all +from keras.src.legacy.backend import any as any +from keras.src.legacy.backend import arange as arange +from keras.src.legacy.backend import argmax as argmax +from keras.src.legacy.backend import argmin as argmin +from keras.src.legacy.backend import batch_dot as batch_dot +from keras.src.legacy.backend import batch_flatten as batch_flatten +from keras.src.legacy.backend import batch_get_value as batch_get_value +from keras.src.legacy.backend import batch_normalization as batch_normalization +from keras.src.legacy.backend import batch_set_value as batch_set_value +from keras.src.legacy.backend import bias_add as bias_add +from keras.src.legacy.backend import binary_crossentropy as binary_crossentropy +from keras.src.legacy.backend import ( + binary_focal_crossentropy as binary_focal_crossentropy, +) +from keras.src.legacy.backend import cast as cast +from keras.src.legacy.backend import cast_to_floatx as cast_to_floatx +from keras.src.legacy.backend import ( + categorical_crossentropy as categorical_crossentropy, +) +from keras.src.legacy.backend import ( + categorical_focal_crossentropy as categorical_focal_crossentropy, +) +from keras.src.legacy.backend import clip as clip +from keras.src.legacy.backend import concatenate as concatenate +from keras.src.legacy.backend import constant as constant +from keras.src.legacy.backend import conv1d as conv1d +from keras.src.legacy.backend import conv2d as conv2d +from keras.src.legacy.backend import conv2d_transpose as conv2d_transpose +from keras.src.legacy.backend import conv3d as conv3d +from keras.src.legacy.backend import cos as cos +from keras.src.legacy.backend import count_params as count_params +from keras.src.legacy.backend import ctc_batch_cost as ctc_batch_cost +from keras.src.legacy.backend import ctc_decode as ctc_decode +from keras.src.legacy.backend import ( + ctc_label_dense_to_sparse as ctc_label_dense_to_sparse, +) +from keras.src.legacy.backend import cumprod as cumprod +from keras.src.legacy.backend import cumsum as cumsum +from keras.src.legacy.backend import depthwise_conv2d as depthwise_conv2d +from keras.src.legacy.backend import dot as dot +from keras.src.legacy.backend import dropout as dropout +from keras.src.legacy.backend import dtype as dtype +from keras.src.legacy.backend import elu as elu +from keras.src.legacy.backend import equal as equal +from keras.src.legacy.backend import eval as eval +from keras.src.legacy.backend import exp as exp +from keras.src.legacy.backend import expand_dims as expand_dims +from keras.src.legacy.backend import eye as eye +from keras.src.legacy.backend import flatten as flatten +from keras.src.legacy.backend import foldl as foldl +from keras.src.legacy.backend import foldr as foldr +from keras.src.legacy.backend import gather as gather +from keras.src.legacy.backend import get_value as get_value +from keras.src.legacy.backend import gradients as gradients +from keras.src.legacy.backend import greater as greater +from keras.src.legacy.backend import greater_equal as greater_equal +from keras.src.legacy.backend import hard_sigmoid as hard_sigmoid +from keras.src.legacy.backend import in_top_k as in_top_k +from keras.src.legacy.backend import int_shape as int_shape +from keras.src.legacy.backend import is_sparse as is_sparse +from keras.src.legacy.backend import l2_normalize as l2_normalize +from keras.src.legacy.backend import less as less +from keras.src.legacy.backend import less_equal as less_equal +from keras.src.legacy.backend import log as log +from keras.src.legacy.backend import map_fn as map_fn +from keras.src.legacy.backend import max as max +from keras.src.legacy.backend import maximum as maximum +from keras.src.legacy.backend import mean as mean +from keras.src.legacy.backend import min as min +from keras.src.legacy.backend import minimum as minimum +from keras.src.legacy.backend import ( + moving_average_update as moving_average_update, +) +from keras.src.legacy.backend import name_scope as name_scope +from keras.src.legacy.backend import ndim as ndim +from keras.src.legacy.backend import not_equal as not_equal +from keras.src.legacy.backend import one_hot as one_hot +from keras.src.legacy.backend import ones as ones +from keras.src.legacy.backend import ones_like as ones_like +from keras.src.legacy.backend import permute_dimensions as permute_dimensions +from keras.src.legacy.backend import pool2d as pool2d +from keras.src.legacy.backend import pool3d as pool3d +from keras.src.legacy.backend import pow as pow +from keras.src.legacy.backend import prod as prod +from keras.src.legacy.backend import random_bernoulli as random_bernoulli +from keras.src.legacy.backend import random_normal as random_normal +from keras.src.legacy.backend import ( + random_normal_variable as random_normal_variable, +) +from keras.src.legacy.backend import random_uniform as random_uniform +from keras.src.legacy.backend import ( + random_uniform_variable as random_uniform_variable, +) +from keras.src.legacy.backend import relu as relu +from keras.src.legacy.backend import repeat as repeat +from keras.src.legacy.backend import repeat_elements as repeat_elements +from keras.src.legacy.backend import reshape as reshape +from keras.src.legacy.backend import resize_images as resize_images +from keras.src.legacy.backend import resize_volumes as resize_volumes +from keras.src.legacy.backend import reverse as reverse +from keras.src.legacy.backend import rnn as rnn +from keras.src.legacy.backend import round as round +from keras.src.legacy.backend import separable_conv2d as separable_conv2d +from keras.src.legacy.backend import set_value as set_value +from keras.src.legacy.backend import shape as shape +from keras.src.legacy.backend import sigmoid as sigmoid +from keras.src.legacy.backend import sign as sign +from keras.src.legacy.backend import sin as sin +from keras.src.legacy.backend import softmax as softmax +from keras.src.legacy.backend import softplus as softplus +from keras.src.legacy.backend import softsign as softsign +from keras.src.legacy.backend import ( + sparse_categorical_crossentropy as sparse_categorical_crossentropy, +) +from keras.src.legacy.backend import spatial_2d_padding as spatial_2d_padding +from keras.src.legacy.backend import spatial_3d_padding as spatial_3d_padding +from keras.src.legacy.backend import sqrt as sqrt +from keras.src.legacy.backend import square as square +from keras.src.legacy.backend import squeeze as squeeze +from keras.src.legacy.backend import stack as stack +from keras.src.legacy.backend import std as std +from keras.src.legacy.backend import stop_gradient as stop_gradient +from keras.src.legacy.backend import sum as sum +from keras.src.legacy.backend import switch as switch +from keras.src.legacy.backend import tanh as tanh +from keras.src.legacy.backend import temporal_padding as temporal_padding +from keras.src.legacy.backend import tile as tile +from keras.src.legacy.backend import to_dense as to_dense +from keras.src.legacy.backend import transpose as transpose +from keras.src.legacy.backend import truncated_normal as truncated_normal +from keras.src.legacy.backend import update as update +from keras.src.legacy.backend import update_add as update_add +from keras.src.legacy.backend import update_sub as update_sub +from keras.src.legacy.backend import var as var +from keras.src.legacy.backend import variable as variable +from keras.src.legacy.backend import zeros as zeros +from keras.src.legacy.backend import zeros_like as zeros_like +from keras.src.utils.naming import get_uid as get_uid diff --git a/keras/api/_tf_keras/keras/callbacks/__init__.py b/keras/api/_tf_keras/keras/callbacks/__init__.py new file mode 100644 index 000000000000..4e165cddb6a8 --- /dev/null +++ b/keras/api/_tf_keras/keras/callbacks/__init__.py @@ -0,0 +1,33 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.callbacks.backup_and_restore import ( + BackupAndRestore as BackupAndRestore, +) +from keras.src.callbacks.callback import Callback as Callback +from keras.src.callbacks.callback_list import CallbackList as CallbackList +from keras.src.callbacks.csv_logger import CSVLogger as CSVLogger +from keras.src.callbacks.early_stopping import EarlyStopping as EarlyStopping +from keras.src.callbacks.history import History as History +from keras.src.callbacks.lambda_callback import LambdaCallback as LambdaCallback +from keras.src.callbacks.learning_rate_scheduler import ( + LearningRateScheduler as LearningRateScheduler, +) +from keras.src.callbacks.model_checkpoint import ( + ModelCheckpoint as ModelCheckpoint, +) +from keras.src.callbacks.progbar_logger import ProgbarLogger as ProgbarLogger +from keras.src.callbacks.reduce_lr_on_plateau import ( + ReduceLROnPlateau as ReduceLROnPlateau, +) +from keras.src.callbacks.remote_monitor import RemoteMonitor as RemoteMonitor +from keras.src.callbacks.swap_ema_weights import ( + SwapEMAWeights as SwapEMAWeights, +) +from keras.src.callbacks.tensorboard import TensorBoard as TensorBoard +from keras.src.callbacks.terminate_on_nan import ( + TerminateOnNaN as TerminateOnNaN, +) diff --git a/keras/api/_tf_keras/keras/config/__init__.py b/keras/api/_tf_keras/keras/config/__init__.py new file mode 100644 index 000000000000..8cf3a1c30abd --- /dev/null +++ b/keras/api/_tf_keras/keras/config/__init__.py @@ -0,0 +1,57 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.backend.config import backend as backend +from keras.src.backend.config import ( + disable_flash_attention as disable_flash_attention, +) +from keras.src.backend.config import ( + enable_flash_attention as enable_flash_attention, +) +from keras.src.backend.config import epsilon as epsilon +from keras.src.backend.config import floatx as floatx +from keras.src.backend.config import image_data_format as image_data_format +from keras.src.backend.config import ( + is_flash_attention_enabled as is_flash_attention_enabled, +) +from keras.src.backend.config import is_nnx_enabled as is_nnx_enabled +from keras.src.backend.config import max_epochs as max_epochs +from keras.src.backend.config import max_steps_per_epoch as max_steps_per_epoch +from keras.src.backend.config import set_epsilon as set_epsilon +from keras.src.backend.config import set_floatx as set_floatx +from keras.src.backend.config import ( + set_image_data_format as set_image_data_format, +) +from keras.src.backend.config import set_max_epochs as set_max_epochs +from keras.src.backend.config import ( + set_max_steps_per_epoch as set_max_steps_per_epoch, +) +from keras.src.dtype_policies.dtype_policy import dtype_policy as dtype_policy +from keras.src.dtype_policies.dtype_policy import ( + set_dtype_policy as set_dtype_policy, +) +from keras.src.saving.serialization_lib import ( + enable_unsafe_deserialization as enable_unsafe_deserialization, +) +from keras.src.utils.backend_utils import set_backend as set_backend +from keras.src.utils.io_utils import ( + disable_interactive_logging as disable_interactive_logging, +) +from keras.src.utils.io_utils import ( + enable_interactive_logging as enable_interactive_logging, +) +from keras.src.utils.io_utils import ( + is_interactive_logging_enabled as is_interactive_logging_enabled, +) +from keras.src.utils.traceback_utils import ( + disable_traceback_filtering as disable_traceback_filtering, +) +from keras.src.utils.traceback_utils import ( + enable_traceback_filtering as enable_traceback_filtering, +) +from keras.src.utils.traceback_utils import ( + is_traceback_filtering_enabled as is_traceback_filtering_enabled, +) diff --git a/keras/api/_tf_keras/keras/constraints/__init__.py b/keras/api/_tf_keras/keras/constraints/__init__.py new file mode 100644 index 000000000000..47d73d44627f --- /dev/null +++ b/keras/api/_tf_keras/keras/constraints/__init__.py @@ -0,0 +1,18 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.constraints import deserialize as deserialize +from keras.src.constraints import get as get +from keras.src.constraints import serialize as serialize +from keras.src.constraints.constraints import Constraint as Constraint +from keras.src.constraints.constraints import MaxNorm as MaxNorm +from keras.src.constraints.constraints import MaxNorm as max_norm +from keras.src.constraints.constraints import MinMaxNorm as MinMaxNorm +from keras.src.constraints.constraints import MinMaxNorm as min_max_norm +from keras.src.constraints.constraints import NonNeg as NonNeg +from keras.src.constraints.constraints import NonNeg as non_neg +from keras.src.constraints.constraints import UnitNorm as UnitNorm +from keras.src.constraints.constraints import UnitNorm as unit_norm diff --git a/keras/api/_tf_keras/keras/datasets/__init__.py b/keras/api/_tf_keras/keras/datasets/__init__.py new file mode 100644 index 000000000000..f61e994a4bff --- /dev/null +++ b/keras/api/_tf_keras/keras/datasets/__init__.py @@ -0,0 +1,14 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.datasets import boston_housing as boston_housing +from keras.datasets import california_housing as california_housing +from keras.datasets import cifar10 as cifar10 +from keras.datasets import cifar100 as cifar100 +from keras.datasets import fashion_mnist as fashion_mnist +from keras.datasets import imdb as imdb +from keras.datasets import mnist as mnist +from keras.datasets import reuters as reuters diff --git a/keras/api/_tf_keras/keras/datasets/boston_housing/__init__.py b/keras/api/_tf_keras/keras/datasets/boston_housing/__init__.py new file mode 100644 index 000000000000..897f8516ca82 --- /dev/null +++ b/keras/api/_tf_keras/keras/datasets/boston_housing/__init__.py @@ -0,0 +1,7 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.datasets.boston_housing import load_data as load_data diff --git a/keras/api/_tf_keras/keras/datasets/california_housing/__init__.py b/keras/api/_tf_keras/keras/datasets/california_housing/__init__.py new file mode 100644 index 000000000000..602bf81ac2cd --- /dev/null +++ b/keras/api/_tf_keras/keras/datasets/california_housing/__init__.py @@ -0,0 +1,7 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.datasets.california_housing import load_data as load_data diff --git a/keras/api/_tf_keras/keras/datasets/cifar10/__init__.py b/keras/api/_tf_keras/keras/datasets/cifar10/__init__.py new file mode 100644 index 000000000000..f7aad7fd1a55 --- /dev/null +++ b/keras/api/_tf_keras/keras/datasets/cifar10/__init__.py @@ -0,0 +1,7 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.datasets.cifar10 import load_data as load_data diff --git a/keras/api/_tf_keras/keras/datasets/cifar100/__init__.py b/keras/api/_tf_keras/keras/datasets/cifar100/__init__.py new file mode 100644 index 000000000000..237fafab6fc6 --- /dev/null +++ b/keras/api/_tf_keras/keras/datasets/cifar100/__init__.py @@ -0,0 +1,7 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.datasets.cifar100 import load_data as load_data diff --git a/keras/api/_tf_keras/keras/datasets/fashion_mnist/__init__.py b/keras/api/_tf_keras/keras/datasets/fashion_mnist/__init__.py new file mode 100644 index 000000000000..317f0951a063 --- /dev/null +++ b/keras/api/_tf_keras/keras/datasets/fashion_mnist/__init__.py @@ -0,0 +1,7 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.datasets.fashion_mnist import load_data as load_data diff --git a/keras/api/_tf_keras/keras/datasets/imdb/__init__.py b/keras/api/_tf_keras/keras/datasets/imdb/__init__.py new file mode 100644 index 000000000000..66931a4a30eb --- /dev/null +++ b/keras/api/_tf_keras/keras/datasets/imdb/__init__.py @@ -0,0 +1,8 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.datasets.imdb import get_word_index as get_word_index +from keras.src.datasets.imdb import load_data as load_data diff --git a/keras/api/_tf_keras/keras/datasets/mnist/__init__.py b/keras/api/_tf_keras/keras/datasets/mnist/__init__.py new file mode 100644 index 000000000000..0fc59f334c50 --- /dev/null +++ b/keras/api/_tf_keras/keras/datasets/mnist/__init__.py @@ -0,0 +1,7 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.datasets.mnist import load_data as load_data diff --git a/keras/api/_tf_keras/keras/datasets/reuters/__init__.py b/keras/api/_tf_keras/keras/datasets/reuters/__init__.py new file mode 100644 index 000000000000..0b2af62d785b --- /dev/null +++ b/keras/api/_tf_keras/keras/datasets/reuters/__init__.py @@ -0,0 +1,9 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.datasets.reuters import get_label_names as get_label_names +from keras.src.datasets.reuters import get_word_index as get_word_index +from keras.src.datasets.reuters import load_data as load_data diff --git a/keras/api/_tf_keras/keras/distribution/__init__.py b/keras/api/_tf_keras/keras/distribution/__init__.py new file mode 100644 index 000000000000..66fed24c761d --- /dev/null +++ b/keras/api/_tf_keras/keras/distribution/__init__.py @@ -0,0 +1,22 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.distribution.distribution_lib import DataParallel as DataParallel +from keras.src.distribution.distribution_lib import DeviceMesh as DeviceMesh +from keras.src.distribution.distribution_lib import LayoutMap as LayoutMap +from keras.src.distribution.distribution_lib import ( + ModelParallel as ModelParallel, +) +from keras.src.distribution.distribution_lib import TensorLayout as TensorLayout +from keras.src.distribution.distribution_lib import ( + distribute_tensor as distribute_tensor, +) +from keras.src.distribution.distribution_lib import distribution as distribution +from keras.src.distribution.distribution_lib import initialize as initialize +from keras.src.distribution.distribution_lib import list_devices as list_devices +from keras.src.distribution.distribution_lib import ( + set_distribution as set_distribution, +) diff --git a/keras/api/_tf_keras/keras/dtype_policies/__init__.py b/keras/api/_tf_keras/keras/dtype_policies/__init__.py new file mode 100644 index 000000000000..04f947d157c3 --- /dev/null +++ b/keras/api/_tf_keras/keras/dtype_policies/__init__.py @@ -0,0 +1,25 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.dtype_policies import deserialize as deserialize +from keras.src.dtype_policies import get as get +from keras.src.dtype_policies import serialize as serialize +from keras.src.dtype_policies.dtype_policy import DTypePolicy as DTypePolicy +from keras.src.dtype_policies.dtype_policy import ( + FloatDTypePolicy as FloatDTypePolicy, +) +from keras.src.dtype_policies.dtype_policy import ( + GPTQDTypePolicy as GPTQDTypePolicy, +) +from keras.src.dtype_policies.dtype_policy import ( + QuantizedDTypePolicy as QuantizedDTypePolicy, +) +from keras.src.dtype_policies.dtype_policy import ( + QuantizedFloat8DTypePolicy as QuantizedFloat8DTypePolicy, +) +from keras.src.dtype_policies.dtype_policy_map import ( + DTypePolicyMap as DTypePolicyMap, +) diff --git a/keras/api/_tf_keras/keras/export/__init__.py b/keras/api/_tf_keras/keras/export/__init__.py new file mode 100644 index 000000000000..fc8e748defcc --- /dev/null +++ b/keras/api/_tf_keras/keras/export/__init__.py @@ -0,0 +1,7 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.export.saved_model import ExportArchive as ExportArchive diff --git a/keras/api/_tf_keras/keras/initializers/__init__.py b/keras/api/_tf_keras/keras/initializers/__init__.py new file mode 100644 index 000000000000..e88013d97315 --- /dev/null +++ b/keras/api/_tf_keras/keras/initializers/__init__.py @@ -0,0 +1,81 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.initializers import deserialize as deserialize +from keras.src.initializers import get as get +from keras.src.initializers import serialize as serialize +from keras.src.initializers.constant_initializers import STFT as STFT +from keras.src.initializers.constant_initializers import STFT as STFTInitializer +from keras.src.initializers.constant_initializers import STFT as stft +from keras.src.initializers.constant_initializers import Constant as Constant +from keras.src.initializers.constant_initializers import Constant as constant +from keras.src.initializers.constant_initializers import Identity as Identity +from keras.src.initializers.constant_initializers import ( + Identity as IdentityInitializer, +) +from keras.src.initializers.constant_initializers import Identity as identity +from keras.src.initializers.constant_initializers import Ones as Ones +from keras.src.initializers.constant_initializers import Ones as ones +from keras.src.initializers.constant_initializers import Zeros as Zeros +from keras.src.initializers.constant_initializers import Zeros as zeros +from keras.src.initializers.initializer import Initializer as Initializer +from keras.src.initializers.random_initializers import ( + GlorotNormal as GlorotNormal, +) +from keras.src.initializers.random_initializers import ( + GlorotNormal as glorot_normal, +) +from keras.src.initializers.random_initializers import ( + GlorotUniform as GlorotUniform, +) +from keras.src.initializers.random_initializers import ( + GlorotUniform as glorot_uniform, +) +from keras.src.initializers.random_initializers import HeNormal as HeNormal +from keras.src.initializers.random_initializers import HeNormal as he_normal +from keras.src.initializers.random_initializers import HeUniform as HeUniform +from keras.src.initializers.random_initializers import HeUniform as he_uniform +from keras.src.initializers.random_initializers import ( + LecunNormal as LecunNormal, +) +from keras.src.initializers.random_initializers import ( + LecunNormal as lecun_normal, +) +from keras.src.initializers.random_initializers import ( + LecunUniform as LecunUniform, +) +from keras.src.initializers.random_initializers import ( + LecunUniform as lecun_uniform, +) +from keras.src.initializers.random_initializers import Orthogonal as Orthogonal +from keras.src.initializers.random_initializers import ( + Orthogonal as OrthogonalInitializer, +) +from keras.src.initializers.random_initializers import Orthogonal as orthogonal +from keras.src.initializers.random_initializers import ( + RandomNormal as RandomNormal, +) +from keras.src.initializers.random_initializers import ( + RandomNormal as random_normal, +) +from keras.src.initializers.random_initializers import ( + RandomUniform as RandomUniform, +) +from keras.src.initializers.random_initializers import ( + RandomUniform as random_uniform, +) +from keras.src.initializers.random_initializers import ( + TruncatedNormal as TruncatedNormal, +) +from keras.src.initializers.random_initializers import ( + TruncatedNormal as truncated_normal, +) +from keras.src.initializers.random_initializers import ( + VarianceScaling as VarianceScaling, +) +from keras.src.initializers.random_initializers import ( + VarianceScaling as variance_scaling, +) diff --git a/keras/api/_tf_keras/keras/layers/__init__.py b/keras/api/_tf_keras/keras/layers/__init__.py new file mode 100644 index 000000000000..c33886cc4716 --- /dev/null +++ b/keras/api/_tf_keras/keras/layers/__init__.py @@ -0,0 +1,362 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.export.tfsm_layer import TFSMLayer as TFSMLayer +from keras.src.layers import deserialize as deserialize +from keras.src.layers import serialize as serialize +from keras.src.layers.activations.activation import Activation as Activation +from keras.src.layers.activations.elu import ELU as ELU +from keras.src.layers.activations.leaky_relu import LeakyReLU as LeakyReLU +from keras.src.layers.activations.prelu import PReLU as PReLU +from keras.src.layers.activations.relu import ReLU as ReLU +from keras.src.layers.activations.softmax import Softmax as Softmax +from keras.src.layers.attention.additive_attention import ( + AdditiveAttention as AdditiveAttention, +) +from keras.src.layers.attention.attention import Attention as Attention +from keras.src.layers.attention.grouped_query_attention import ( + GroupedQueryAttention as GroupQueryAttention, +) +from keras.src.layers.attention.multi_head_attention import ( + MultiHeadAttention as MultiHeadAttention, +) +from keras.src.layers.convolutional.conv1d import Conv1D as Conv1D +from keras.src.layers.convolutional.conv1d import Conv1D as Convolution1D +from keras.src.layers.convolutional.conv1d_transpose import ( + Conv1DTranspose as Conv1DTranspose, +) +from keras.src.layers.convolutional.conv1d_transpose import ( + Conv1DTranspose as Convolution1DTranspose, +) +from keras.src.layers.convolutional.conv2d import Conv2D as Conv2D +from keras.src.layers.convolutional.conv2d import Conv2D as Convolution2D +from keras.src.layers.convolutional.conv2d_transpose import ( + Conv2DTranspose as Conv2DTranspose, +) +from keras.src.layers.convolutional.conv2d_transpose import ( + Conv2DTranspose as Convolution2DTranspose, +) +from keras.src.layers.convolutional.conv3d import Conv3D as Conv3D +from keras.src.layers.convolutional.conv3d import Conv3D as Convolution3D +from keras.src.layers.convolutional.conv3d_transpose import ( + Conv3DTranspose as Conv3DTranspose, +) +from keras.src.layers.convolutional.conv3d_transpose import ( + Conv3DTranspose as Convolution3DTranspose, +) +from keras.src.layers.convolutional.depthwise_conv1d import ( + DepthwiseConv1D as DepthwiseConv1D, +) +from keras.src.layers.convolutional.depthwise_conv2d import ( + DepthwiseConv2D as DepthwiseConv2D, +) +from keras.src.layers.convolutional.separable_conv1d import ( + SeparableConv1D as SeparableConv1D, +) +from keras.src.layers.convolutional.separable_conv1d import ( + SeparableConv1D as SeparableConvolution1D, +) +from keras.src.layers.convolutional.separable_conv2d import ( + SeparableConv2D as SeparableConv2D, +) +from keras.src.layers.convolutional.separable_conv2d import ( + SeparableConv2D as SeparableConvolution2D, +) +from keras.src.layers.core.dense import Dense as Dense +from keras.src.layers.core.einsum_dense import EinsumDense as EinsumDense +from keras.src.layers.core.embedding import Embedding as Embedding +from keras.src.layers.core.identity import Identity as Identity +from keras.src.layers.core.input_layer import Input as Input +from keras.src.layers.core.input_layer import InputLayer as InputLayer +from keras.src.layers.core.lambda_layer import Lambda as Lambda +from keras.src.layers.core.masking import Masking as Masking +from keras.src.layers.core.wrapper import Wrapper as Wrapper +from keras.src.layers.input_spec import InputSpec as InputSpec +from keras.src.layers.layer import Layer as Layer +from keras.src.layers.merging.add import Add as Add +from keras.src.layers.merging.add import add as add +from keras.src.layers.merging.average import Average as Average +from keras.src.layers.merging.average import average as average +from keras.src.layers.merging.concatenate import Concatenate as Concatenate +from keras.src.layers.merging.concatenate import concatenate as concatenate +from keras.src.layers.merging.dot import Dot as Dot +from keras.src.layers.merging.dot import dot as dot +from keras.src.layers.merging.maximum import Maximum as Maximum +from keras.src.layers.merging.maximum import maximum as maximum +from keras.src.layers.merging.minimum import Minimum as Minimum +from keras.src.layers.merging.minimum import minimum as minimum +from keras.src.layers.merging.multiply import Multiply as Multiply +from keras.src.layers.merging.multiply import multiply as multiply +from keras.src.layers.merging.subtract import Subtract as Subtract +from keras.src.layers.merging.subtract import subtract as subtract +from keras.src.layers.normalization.batch_normalization import ( + BatchNormalization as BatchNormalization, +) +from keras.src.layers.normalization.group_normalization import ( + GroupNormalization as GroupNormalization, +) +from keras.src.layers.normalization.layer_normalization import ( + LayerNormalization as LayerNormalization, +) +from keras.src.layers.normalization.rms_normalization import ( + RMSNormalization as RMSNormalization, +) +from keras.src.layers.normalization.spectral_normalization import ( + SpectralNormalization as SpectralNormalization, +) +from keras.src.layers.normalization.unit_normalization import ( + UnitNormalization as UnitNormalization, +) +from keras.src.layers.pooling.average_pooling1d import ( + AveragePooling1D as AveragePooling1D, +) +from keras.src.layers.pooling.average_pooling1d import ( + AveragePooling1D as AvgPool1D, +) +from keras.src.layers.pooling.average_pooling2d import ( + AveragePooling2D as AveragePooling2D, +) +from keras.src.layers.pooling.average_pooling2d import ( + AveragePooling2D as AvgPool2D, +) +from keras.src.layers.pooling.average_pooling3d import ( + AveragePooling3D as AveragePooling3D, +) +from keras.src.layers.pooling.average_pooling3d import ( + AveragePooling3D as AvgPool3D, +) +from keras.src.layers.pooling.global_average_pooling1d import ( + GlobalAveragePooling1D as GlobalAveragePooling1D, +) +from keras.src.layers.pooling.global_average_pooling1d import ( + GlobalAveragePooling1D as GlobalAvgPool1D, +) +from keras.src.layers.pooling.global_average_pooling2d import ( + GlobalAveragePooling2D as GlobalAveragePooling2D, +) +from keras.src.layers.pooling.global_average_pooling2d import ( + GlobalAveragePooling2D as GlobalAvgPool2D, +) +from keras.src.layers.pooling.global_average_pooling3d import ( + GlobalAveragePooling3D as GlobalAveragePooling3D, +) +from keras.src.layers.pooling.global_average_pooling3d import ( + GlobalAveragePooling3D as GlobalAvgPool3D, +) +from keras.src.layers.pooling.global_max_pooling1d import ( + GlobalMaxPooling1D as GlobalMaxPool1D, +) +from keras.src.layers.pooling.global_max_pooling1d import ( + GlobalMaxPooling1D as GlobalMaxPooling1D, +) +from keras.src.layers.pooling.global_max_pooling2d import ( + GlobalMaxPooling2D as GlobalMaxPool2D, +) +from keras.src.layers.pooling.global_max_pooling2d import ( + GlobalMaxPooling2D as GlobalMaxPooling2D, +) +from keras.src.layers.pooling.global_max_pooling3d import ( + GlobalMaxPooling3D as GlobalMaxPool3D, +) +from keras.src.layers.pooling.global_max_pooling3d import ( + GlobalMaxPooling3D as GlobalMaxPooling3D, +) +from keras.src.layers.pooling.max_pooling1d import MaxPooling1D as MaxPool1D +from keras.src.layers.pooling.max_pooling1d import MaxPooling1D as MaxPooling1D +from keras.src.layers.pooling.max_pooling2d import MaxPooling2D as MaxPool2D +from keras.src.layers.pooling.max_pooling2d import MaxPooling2D as MaxPooling2D +from keras.src.layers.pooling.max_pooling3d import MaxPooling3D as MaxPool3D +from keras.src.layers.pooling.max_pooling3d import MaxPooling3D as MaxPooling3D +from keras.src.layers.preprocessing.category_encoding import ( + CategoryEncoding as CategoryEncoding, +) +from keras.src.layers.preprocessing.discretization import ( + Discretization as Discretization, +) +from keras.src.layers.preprocessing.hashed_crossing import ( + HashedCrossing as HashedCrossing, +) +from keras.src.layers.preprocessing.hashing import Hashing as Hashing +from keras.src.layers.preprocessing.image_preprocessing.aug_mix import ( + AugMix as AugMix, +) +from keras.src.layers.preprocessing.image_preprocessing.auto_contrast import ( + AutoContrast as AutoContrast, +) +from keras.src.layers.preprocessing.image_preprocessing.center_crop import ( + CenterCrop as CenterCrop, +) +from keras.src.layers.preprocessing.image_preprocessing.cut_mix import ( + CutMix as CutMix, +) +from keras.src.layers.preprocessing.image_preprocessing.equalization import ( + Equalization as Equalization, +) +from keras.src.layers.preprocessing.image_preprocessing.max_num_bounding_box import ( + MaxNumBoundingBoxes as MaxNumBoundingBoxes, +) +from keras.src.layers.preprocessing.image_preprocessing.mix_up import ( + MixUp as MixUp, +) +from keras.src.layers.preprocessing.image_preprocessing.rand_augment import ( + RandAugment as RandAugment, +) +from keras.src.layers.preprocessing.image_preprocessing.random_brightness import ( + RandomBrightness as RandomBrightness, +) +from keras.src.layers.preprocessing.image_preprocessing.random_color_degeneration import ( + RandomColorDegeneration as RandomColorDegeneration, +) +from keras.src.layers.preprocessing.image_preprocessing.random_color_jitter import ( + RandomColorJitter as RandomColorJitter, +) +from keras.src.layers.preprocessing.image_preprocessing.random_contrast import ( + RandomContrast as RandomContrast, +) +from keras.src.layers.preprocessing.image_preprocessing.random_crop import ( + RandomCrop as RandomCrop, +) +from keras.src.layers.preprocessing.image_preprocessing.random_elastic_transform import ( + RandomElasticTransform as RandomElasticTransform, +) +from keras.src.layers.preprocessing.image_preprocessing.random_erasing import ( + RandomErasing as RandomErasing, +) +from keras.src.layers.preprocessing.image_preprocessing.random_flip import ( + RandomFlip as RandomFlip, +) +from keras.src.layers.preprocessing.image_preprocessing.random_gaussian_blur import ( + RandomGaussianBlur as RandomGaussianBlur, +) +from keras.src.layers.preprocessing.image_preprocessing.random_grayscale import ( + RandomGrayscale as RandomGrayscale, +) +from keras.src.layers.preprocessing.image_preprocessing.random_hue import ( + RandomHue as RandomHue, +) +from keras.src.layers.preprocessing.image_preprocessing.random_invert import ( + RandomInvert as RandomInvert, +) +from keras.src.layers.preprocessing.image_preprocessing.random_perspective import ( + RandomPerspective as RandomPerspective, +) +from keras.src.layers.preprocessing.image_preprocessing.random_posterization import ( + RandomPosterization as RandomPosterization, +) +from keras.src.layers.preprocessing.image_preprocessing.random_rotation import ( + RandomRotation as RandomRotation, +) +from keras.src.layers.preprocessing.image_preprocessing.random_saturation import ( + RandomSaturation as RandomSaturation, +) +from keras.src.layers.preprocessing.image_preprocessing.random_sharpness import ( + RandomSharpness as RandomSharpness, +) +from keras.src.layers.preprocessing.image_preprocessing.random_shear import ( + RandomShear as RandomShear, +) +from keras.src.layers.preprocessing.image_preprocessing.random_translation import ( + RandomTranslation as RandomTranslation, +) +from keras.src.layers.preprocessing.image_preprocessing.random_zoom import ( + RandomZoom as RandomZoom, +) +from keras.src.layers.preprocessing.image_preprocessing.resizing import ( + Resizing as Resizing, +) +from keras.src.layers.preprocessing.image_preprocessing.solarization import ( + Solarization as Solarization, +) +from keras.src.layers.preprocessing.integer_lookup import ( + IntegerLookup as IntegerLookup, +) +from keras.src.layers.preprocessing.mel_spectrogram import ( + MelSpectrogram as MelSpectrogram, +) +from keras.src.layers.preprocessing.normalization import ( + Normalization as Normalization, +) +from keras.src.layers.preprocessing.pipeline import Pipeline as Pipeline +from keras.src.layers.preprocessing.rescaling import Rescaling as Rescaling +from keras.src.layers.preprocessing.stft_spectrogram import ( + STFTSpectrogram as STFTSpectrogram, +) +from keras.src.layers.preprocessing.string_lookup import ( + StringLookup as StringLookup, +) +from keras.src.layers.preprocessing.text_vectorization import ( + TextVectorization as TextVectorization, +) +from keras.src.layers.regularization.activity_regularization import ( + ActivityRegularization as ActivityRegularization, +) +from keras.src.layers.regularization.dropout import Dropout as Dropout +from keras.src.layers.regularization.gaussian_dropout import ( + GaussianDropout as GaussianDropout, +) +from keras.src.layers.regularization.gaussian_noise import ( + GaussianNoise as GaussianNoise, +) +from keras.src.layers.regularization.spatial_dropout import ( + SpatialDropout1D as SpatialDropout1D, +) +from keras.src.layers.regularization.spatial_dropout import ( + SpatialDropout2D as SpatialDropout2D, +) +from keras.src.layers.regularization.spatial_dropout import ( + SpatialDropout3D as SpatialDropout3D, +) +from keras.src.layers.reshaping.cropping1d import Cropping1D as Cropping1D +from keras.src.layers.reshaping.cropping2d import Cropping2D as Cropping2D +from keras.src.layers.reshaping.cropping3d import Cropping3D as Cropping3D +from keras.src.layers.reshaping.flatten import Flatten as Flatten +from keras.src.layers.reshaping.permute import Permute as Permute +from keras.src.layers.reshaping.repeat_vector import ( + RepeatVector as RepeatVector, +) +from keras.src.layers.reshaping.reshape import Reshape as Reshape +from keras.src.layers.reshaping.up_sampling1d import ( + UpSampling1D as UpSampling1D, +) +from keras.src.layers.reshaping.up_sampling2d import ( + UpSampling2D as UpSampling2D, +) +from keras.src.layers.reshaping.up_sampling3d import ( + UpSampling3D as UpSampling3D, +) +from keras.src.layers.reshaping.zero_padding1d import ( + ZeroPadding1D as ZeroPadding1D, +) +from keras.src.layers.reshaping.zero_padding2d import ( + ZeroPadding2D as ZeroPadding2D, +) +from keras.src.layers.reshaping.zero_padding3d import ( + ZeroPadding3D as ZeroPadding3D, +) +from keras.src.layers.rnn.bidirectional import Bidirectional as Bidirectional +from keras.src.layers.rnn.conv_lstm1d import ConvLSTM1D as ConvLSTM1D +from keras.src.layers.rnn.conv_lstm2d import ConvLSTM2D as ConvLSTM2D +from keras.src.layers.rnn.conv_lstm3d import ConvLSTM3D as ConvLSTM3D +from keras.src.layers.rnn.gru import GRU as GRU +from keras.src.layers.rnn.gru import GRUCell as GRUCell +from keras.src.layers.rnn.lstm import LSTM as LSTM +from keras.src.layers.rnn.lstm import LSTMCell as LSTMCell +from keras.src.layers.rnn.rnn import RNN as RNN +from keras.src.layers.rnn.simple_rnn import SimpleRNN as SimpleRNN +from keras.src.layers.rnn.simple_rnn import SimpleRNNCell as SimpleRNNCell +from keras.src.layers.rnn.stacked_rnn_cells import ( + StackedRNNCells as StackedRNNCells, +) +from keras.src.layers.rnn.time_distributed import ( + TimeDistributed as TimeDistributed, +) +from keras.src.legacy.layers import AlphaDropout as AlphaDropout +from keras.src.legacy.layers import RandomHeight as RandomHeight +from keras.src.legacy.layers import RandomWidth as RandomWidth +from keras.src.legacy.layers import ThresholdedReLU as ThresholdedReLU +from keras.src.utils.jax_layer import FlaxLayer as FlaxLayer +from keras.src.utils.jax_layer import JaxLayer as JaxLayer +from keras.src.utils.torch_utils import TorchModuleWrapper as TorchModuleWrapper diff --git a/keras/api/_tf_keras/keras/legacy/__init__.py b/keras/api/_tf_keras/keras/legacy/__init__.py new file mode 100644 index 000000000000..e71ba4312ee0 --- /dev/null +++ b/keras/api/_tf_keras/keras/legacy/__init__.py @@ -0,0 +1,7 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.legacy import saving as saving diff --git a/keras/api/_tf_keras/keras/legacy/saving/__init__.py b/keras/api/_tf_keras/keras/legacy/saving/__init__.py new file mode 100644 index 000000000000..1e3aa0ee9d5c --- /dev/null +++ b/keras/api/_tf_keras/keras/legacy/saving/__init__.py @@ -0,0 +1,12 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.legacy.saving.serialization import ( + deserialize_keras_object as deserialize_keras_object, +) +from keras.src.legacy.saving.serialization import ( + serialize_keras_object as serialize_keras_object, +) diff --git a/keras/api/_tf_keras/keras/losses/__init__.py b/keras/api/_tf_keras/keras/losses/__init__.py new file mode 100644 index 000000000000..73cc8e82db82 --- /dev/null +++ b/keras/api/_tf_keras/keras/losses/__init__.py @@ -0,0 +1,85 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.legacy.losses import Reduction as Reduction +from keras.src.losses import deserialize as deserialize +from keras.src.losses import get as get +from keras.src.losses import serialize as serialize +from keras.src.losses.loss import Loss as Loss +from keras.src.losses.losses import CTC as CTC +from keras.src.losses.losses import BinaryCrossentropy as BinaryCrossentropy +from keras.src.losses.losses import ( + BinaryFocalCrossentropy as BinaryFocalCrossentropy, +) +from keras.src.losses.losses import ( + CategoricalCrossentropy as CategoricalCrossentropy, +) +from keras.src.losses.losses import ( + CategoricalFocalCrossentropy as CategoricalFocalCrossentropy, +) +from keras.src.losses.losses import ( + CategoricalGeneralizedCrossEntropy as CategoricalGeneralizedCrossEntropy, +) +from keras.src.losses.losses import CategoricalHinge as CategoricalHinge +from keras.src.losses.losses import Circle as Circle +from keras.src.losses.losses import CosineSimilarity as CosineSimilarity +from keras.src.losses.losses import Dice as Dice +from keras.src.losses.losses import Hinge as Hinge +from keras.src.losses.losses import Huber as Huber +from keras.src.losses.losses import KLDivergence as KLDivergence +from keras.src.losses.losses import LogCosh as LogCosh +from keras.src.losses.losses import MeanAbsoluteError as MeanAbsoluteError +from keras.src.losses.losses import ( + MeanAbsolutePercentageError as MeanAbsolutePercentageError, +) +from keras.src.losses.losses import MeanSquaredError as MeanSquaredError +from keras.src.losses.losses import ( + MeanSquaredLogarithmicError as MeanSquaredLogarithmicError, +) +from keras.src.losses.losses import Poisson as Poisson +from keras.src.losses.losses import ( + SparseCategoricalCrossentropy as SparseCategoricalCrossentropy, +) +from keras.src.losses.losses import SquaredHinge as SquaredHinge +from keras.src.losses.losses import Tversky as Tversky +from keras.src.losses.losses import binary_crossentropy as binary_crossentropy +from keras.src.losses.losses import ( + binary_focal_crossentropy as binary_focal_crossentropy, +) +from keras.src.losses.losses import ( + categorical_crossentropy as categorical_crossentropy, +) +from keras.src.losses.losses import ( + categorical_focal_crossentropy as categorical_focal_crossentropy, +) +from keras.src.losses.losses import ( + categorical_generalized_cross_entropy as categorical_generalized_cross_entropy, +) +from keras.src.losses.losses import categorical_hinge as categorical_hinge +from keras.src.losses.losses import circle as circle +from keras.src.losses.losses import cosine_similarity as cosine_similarity +from keras.src.losses.losses import ctc as ctc +from keras.src.losses.losses import dice as dice +from keras.src.losses.losses import hinge as hinge +from keras.src.losses.losses import huber as huber +from keras.src.losses.losses import kl_divergence as KLD +from keras.src.losses.losses import kl_divergence as kld +from keras.src.losses.losses import kl_divergence as kullback_leibler_divergence +from keras.src.losses.losses import log_cosh as logcosh +from keras.src.losses.losses import mean_absolute_error as MAE +from keras.src.losses.losses import mean_absolute_error as mae +from keras.src.losses.losses import mean_absolute_percentage_error as MAPE +from keras.src.losses.losses import mean_absolute_percentage_error as mape +from keras.src.losses.losses import mean_squared_error as MSE +from keras.src.losses.losses import mean_squared_error as mse +from keras.src.losses.losses import mean_squared_logarithmic_error as MSLE +from keras.src.losses.losses import mean_squared_logarithmic_error as msle +from keras.src.losses.losses import poisson as poisson +from keras.src.losses.losses import ( + sparse_categorical_crossentropy as sparse_categorical_crossentropy, +) +from keras.src.losses.losses import squared_hinge as squared_hinge +from keras.src.losses.losses import tversky as tversky diff --git a/keras/api/_tf_keras/keras/metrics/__init__.py b/keras/api/_tf_keras/keras/metrics/__init__.py new file mode 100644 index 000000000000..11fd5db493cd --- /dev/null +++ b/keras/api/_tf_keras/keras/metrics/__init__.py @@ -0,0 +1,146 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.losses.losses import binary_crossentropy as binary_crossentropy +from keras.src.losses.losses import ( + binary_focal_crossentropy as binary_focal_crossentropy, +) +from keras.src.losses.losses import ( + categorical_crossentropy as categorical_crossentropy, +) +from keras.src.losses.losses import ( + categorical_focal_crossentropy as categorical_focal_crossentropy, +) +from keras.src.losses.losses import categorical_hinge as categorical_hinge +from keras.src.losses.losses import hinge as hinge +from keras.src.losses.losses import huber as huber +from keras.src.losses.losses import kl_divergence as KLD +from keras.src.losses.losses import kl_divergence as kld +from keras.src.losses.losses import kl_divergence as kullback_leibler_divergence +from keras.src.losses.losses import log_cosh as logcosh +from keras.src.losses.losses import mean_absolute_error as MAE +from keras.src.losses.losses import mean_absolute_error as mae +from keras.src.losses.losses import mean_absolute_percentage_error as MAPE +from keras.src.losses.losses import mean_absolute_percentage_error as mape +from keras.src.losses.losses import mean_squared_error as MSE +from keras.src.losses.losses import mean_squared_error as mse +from keras.src.losses.losses import mean_squared_logarithmic_error as MSLE +from keras.src.losses.losses import mean_squared_logarithmic_error as msle +from keras.src.losses.losses import poisson as poisson +from keras.src.losses.losses import ( + sparse_categorical_crossentropy as sparse_categorical_crossentropy, +) +from keras.src.losses.losses import squared_hinge as squared_hinge +from keras.src.metrics import deserialize as deserialize +from keras.src.metrics import get as get +from keras.src.metrics import serialize as serialize +from keras.src.metrics.accuracy_metrics import Accuracy as Accuracy +from keras.src.metrics.accuracy_metrics import BinaryAccuracy as BinaryAccuracy +from keras.src.metrics.accuracy_metrics import ( + CategoricalAccuracy as CategoricalAccuracy, +) +from keras.src.metrics.accuracy_metrics import ( + SparseCategoricalAccuracy as SparseCategoricalAccuracy, +) +from keras.src.metrics.accuracy_metrics import ( + SparseTopKCategoricalAccuracy as SparseTopKCategoricalAccuracy, +) +from keras.src.metrics.accuracy_metrics import ( + TopKCategoricalAccuracy as TopKCategoricalAccuracy, +) +from keras.src.metrics.accuracy_metrics import ( + binary_accuracy as binary_accuracy, +) +from keras.src.metrics.accuracy_metrics import ( + categorical_accuracy as categorical_accuracy, +) +from keras.src.metrics.accuracy_metrics import ( + sparse_categorical_accuracy as sparse_categorical_accuracy, +) +from keras.src.metrics.accuracy_metrics import ( + sparse_top_k_categorical_accuracy as sparse_top_k_categorical_accuracy, +) +from keras.src.metrics.accuracy_metrics import ( + top_k_categorical_accuracy as top_k_categorical_accuracy, +) +from keras.src.metrics.confusion_metrics import AUC as AUC +from keras.src.metrics.confusion_metrics import FalseNegatives as FalseNegatives +from keras.src.metrics.confusion_metrics import FalsePositives as FalsePositives +from keras.src.metrics.confusion_metrics import Precision as Precision +from keras.src.metrics.confusion_metrics import ( + PrecisionAtRecall as PrecisionAtRecall, +) +from keras.src.metrics.confusion_metrics import Recall as Recall +from keras.src.metrics.confusion_metrics import ( + RecallAtPrecision as RecallAtPrecision, +) +from keras.src.metrics.confusion_metrics import ( + SensitivityAtSpecificity as SensitivityAtSpecificity, +) +from keras.src.metrics.confusion_metrics import ( + SpecificityAtSensitivity as SpecificityAtSensitivity, +) +from keras.src.metrics.confusion_metrics import TrueNegatives as TrueNegatives +from keras.src.metrics.confusion_metrics import TruePositives as TruePositives +from keras.src.metrics.correlation_metrics import ( + ConcordanceCorrelation as ConcordanceCorrelation, +) +from keras.src.metrics.correlation_metrics import ( + PearsonCorrelation as PearsonCorrelation, +) +from keras.src.metrics.correlation_metrics import ( + concordance_correlation as concordance_correlation, +) +from keras.src.metrics.correlation_metrics import ( + pearson_correlation as pearson_correlation, +) +from keras.src.metrics.f_score_metrics import F1Score as F1Score +from keras.src.metrics.f_score_metrics import FBetaScore as FBetaScore +from keras.src.metrics.hinge_metrics import CategoricalHinge as CategoricalHinge +from keras.src.metrics.hinge_metrics import Hinge as Hinge +from keras.src.metrics.hinge_metrics import SquaredHinge as SquaredHinge +from keras.src.metrics.iou_metrics import BinaryIoU as BinaryIoU +from keras.src.metrics.iou_metrics import IoU as IoU +from keras.src.metrics.iou_metrics import MeanIoU as MeanIoU +from keras.src.metrics.iou_metrics import OneHotIoU as OneHotIoU +from keras.src.metrics.iou_metrics import OneHotMeanIoU as OneHotMeanIoU +from keras.src.metrics.metric import Metric as Metric +from keras.src.metrics.probabilistic_metrics import ( + BinaryCrossentropy as BinaryCrossentropy, +) +from keras.src.metrics.probabilistic_metrics import ( + CategoricalCrossentropy as CategoricalCrossentropy, +) +from keras.src.metrics.probabilistic_metrics import KLDivergence as KLDivergence +from keras.src.metrics.probabilistic_metrics import Poisson as Poisson +from keras.src.metrics.probabilistic_metrics import ( + SparseCategoricalCrossentropy as SparseCategoricalCrossentropy, +) +from keras.src.metrics.reduction_metrics import Mean as Mean +from keras.src.metrics.reduction_metrics import ( + MeanMetricWrapper as MeanMetricWrapper, +) +from keras.src.metrics.reduction_metrics import Sum as Sum +from keras.src.metrics.regression_metrics import ( + CosineSimilarity as CosineSimilarity, +) +from keras.src.metrics.regression_metrics import LogCoshError as LogCoshError +from keras.src.metrics.regression_metrics import ( + MeanAbsoluteError as MeanAbsoluteError, +) +from keras.src.metrics.regression_metrics import ( + MeanAbsolutePercentageError as MeanAbsolutePercentageError, +) +from keras.src.metrics.regression_metrics import ( + MeanSquaredError as MeanSquaredError, +) +from keras.src.metrics.regression_metrics import ( + MeanSquaredLogarithmicError as MeanSquaredLogarithmicError, +) +from keras.src.metrics.regression_metrics import R2Score as R2Score +from keras.src.metrics.regression_metrics import ( + RootMeanSquaredError as RootMeanSquaredError, +) diff --git a/keras/api/_tf_keras/keras/mixed_precision/__init__.py b/keras/api/_tf_keras/keras/mixed_precision/__init__.py new file mode 100644 index 000000000000..9555b8639385 --- /dev/null +++ b/keras/api/_tf_keras/keras/mixed_precision/__init__.py @@ -0,0 +1,19 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.dtype_policies.dtype_policy import DTypePolicy as DTypePolicy +from keras.src.dtype_policies.dtype_policy import DTypePolicy as Policy +from keras.src.dtype_policies.dtype_policy import dtype_policy as dtype_policy +from keras.src.dtype_policies.dtype_policy import dtype_policy as global_policy +from keras.src.dtype_policies.dtype_policy import ( + set_dtype_policy as set_dtype_policy, +) +from keras.src.dtype_policies.dtype_policy import ( + set_dtype_policy as set_global_policy, +) +from keras.src.optimizers.loss_scale_optimizer import ( + LossScaleOptimizer as LossScaleOptimizer, +) diff --git a/keras/api/_tf_keras/keras/models/__init__.py b/keras/api/_tf_keras/keras/models/__init__.py new file mode 100644 index 000000000000..f9dd57556d53 --- /dev/null +++ b/keras/api/_tf_keras/keras/models/__init__.py @@ -0,0 +1,12 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.models.cloning import clone_model as clone_model +from keras.src.models.model import Model as Model +from keras.src.models.model import model_from_json as model_from_json +from keras.src.models.sequential import Sequential as Sequential +from keras.src.saving.saving_api import load_model as load_model +from keras.src.saving.saving_api import save_model as save_model diff --git a/keras/api/_tf_keras/keras/ops/__init__.py b/keras/api/_tf_keras/keras/ops/__init__.py new file mode 100644 index 000000000000..2194c975b89f --- /dev/null +++ b/keras/api/_tf_keras/keras/ops/__init__.py @@ -0,0 +1,300 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.ops import image as image +from keras.ops import linalg as linalg +from keras.ops import nn as nn +from keras.ops import numpy as numpy +from keras.src.ops.core import associative_scan as associative_scan +from keras.src.ops.core import cast as cast +from keras.src.ops.core import cond as cond +from keras.src.ops.core import convert_to_numpy as convert_to_numpy +from keras.src.ops.core import convert_to_tensor as convert_to_tensor +from keras.src.ops.core import custom_gradient as custom_gradient +from keras.src.ops.core import dtype as dtype +from keras.src.ops.core import fori_loop as fori_loop +from keras.src.ops.core import is_tensor as is_tensor +from keras.src.ops.core import map as map +from keras.src.ops.core import saturate_cast as saturate_cast +from keras.src.ops.core import scan as scan +from keras.src.ops.core import scatter as scatter +from keras.src.ops.core import scatter_update as scatter_update +from keras.src.ops.core import shape as shape +from keras.src.ops.core import slice as slice +from keras.src.ops.core import slice_update as slice_update +from keras.src.ops.core import stop_gradient as stop_gradient +from keras.src.ops.core import switch as switch +from keras.src.ops.core import unstack as unstack +from keras.src.ops.core import vectorized_map as vectorized_map +from keras.src.ops.core import while_loop as while_loop +from keras.src.ops.einops import rearrange as rearrange +from keras.src.ops.linalg import cholesky as cholesky +from keras.src.ops.linalg import cholesky_inverse as cholesky_inverse +from keras.src.ops.linalg import det as det +from keras.src.ops.linalg import eig as eig +from keras.src.ops.linalg import eigh as eigh +from keras.src.ops.linalg import inv as inv +from keras.src.ops.linalg import jvp as jvp +from keras.src.ops.linalg import lstsq as lstsq +from keras.src.ops.linalg import lu_factor as lu_factor +from keras.src.ops.linalg import norm as norm +from keras.src.ops.linalg import qr as qr +from keras.src.ops.linalg import solve as solve +from keras.src.ops.linalg import solve_triangular as solve_triangular +from keras.src.ops.linalg import svd as svd +from keras.src.ops.math import erf as erf +from keras.src.ops.math import erfinv as erfinv +from keras.src.ops.math import extract_sequences as extract_sequences +from keras.src.ops.math import fft as fft +from keras.src.ops.math import fft2 as fft2 +from keras.src.ops.math import ifft2 as ifft2 +from keras.src.ops.math import in_top_k as in_top_k +from keras.src.ops.math import irfft as irfft +from keras.src.ops.math import istft as istft +from keras.src.ops.math import logdet as logdet +from keras.src.ops.math import logsumexp as logsumexp +from keras.src.ops.math import rfft as rfft +from keras.src.ops.math import rsqrt as rsqrt +from keras.src.ops.math import segment_max as segment_max +from keras.src.ops.math import segment_sum as segment_sum +from keras.src.ops.math import stft as stft +from keras.src.ops.math import top_k as top_k +from keras.src.ops.math import view_as_complex as view_as_complex +from keras.src.ops.math import view_as_real as view_as_real +from keras.src.ops.nn import average_pool as average_pool +from keras.src.ops.nn import batch_normalization as batch_normalization +from keras.src.ops.nn import binary_crossentropy as binary_crossentropy +from keras.src.ops.nn import ( + categorical_crossentropy as categorical_crossentropy, +) +from keras.src.ops.nn import celu as celu +from keras.src.ops.nn import conv as conv +from keras.src.ops.nn import conv_transpose as conv_transpose +from keras.src.ops.nn import ctc_decode as ctc_decode +from keras.src.ops.nn import ctc_loss as ctc_loss +from keras.src.ops.nn import depthwise_conv as depthwise_conv +from keras.src.ops.nn import dot_product_attention as dot_product_attention +from keras.src.ops.nn import elu as elu +from keras.src.ops.nn import gelu as gelu +from keras.src.ops.nn import glu as glu +from keras.src.ops.nn import hard_shrink as hard_shrink +from keras.src.ops.nn import hard_sigmoid as hard_sigmoid +from keras.src.ops.nn import hard_silu as hard_silu +from keras.src.ops.nn import hard_silu as hard_swish +from keras.src.ops.nn import hard_tanh as hard_tanh +from keras.src.ops.nn import layer_normalization as layer_normalization +from keras.src.ops.nn import leaky_relu as leaky_relu +from keras.src.ops.nn import log_sigmoid as log_sigmoid +from keras.src.ops.nn import log_softmax as log_softmax +from keras.src.ops.nn import max_pool as max_pool +from keras.src.ops.nn import moments as moments +from keras.src.ops.nn import multi_hot as multi_hot +from keras.src.ops.nn import normalize as normalize +from keras.src.ops.nn import one_hot as one_hot +from keras.src.ops.nn import polar as polar +from keras.src.ops.nn import psnr as psnr +from keras.src.ops.nn import relu as relu +from keras.src.ops.nn import relu6 as relu6 +from keras.src.ops.nn import rms_normalization as rms_normalization +from keras.src.ops.nn import selu as selu +from keras.src.ops.nn import separable_conv as separable_conv +from keras.src.ops.nn import sigmoid as sigmoid +from keras.src.ops.nn import silu as silu +from keras.src.ops.nn import silu as swish +from keras.src.ops.nn import soft_shrink as soft_shrink +from keras.src.ops.nn import softmax as softmax +from keras.src.ops.nn import softplus as softplus +from keras.src.ops.nn import softsign as softsign +from keras.src.ops.nn import ( + sparse_categorical_crossentropy as sparse_categorical_crossentropy, +) +from keras.src.ops.nn import sparse_plus as sparse_plus +from keras.src.ops.nn import sparse_sigmoid as sparse_sigmoid +from keras.src.ops.nn import sparsemax as sparsemax +from keras.src.ops.nn import squareplus as squareplus +from keras.src.ops.nn import tanh_shrink as tanh_shrink +from keras.src.ops.nn import threshold as threshold +from keras.src.ops.nn import unfold as unfold +from keras.src.ops.numpy import abs as abs +from keras.src.ops.numpy import absolute as absolute +from keras.src.ops.numpy import add as add +from keras.src.ops.numpy import all as all +from keras.src.ops.numpy import amax as amax +from keras.src.ops.numpy import amin as amin +from keras.src.ops.numpy import angle as angle +from keras.src.ops.numpy import any as any +from keras.src.ops.numpy import append as append +from keras.src.ops.numpy import arange as arange +from keras.src.ops.numpy import arccos as arccos +from keras.src.ops.numpy import arccosh as arccosh +from keras.src.ops.numpy import arcsin as arcsin +from keras.src.ops.numpy import arcsinh as arcsinh +from keras.src.ops.numpy import arctan as arctan +from keras.src.ops.numpy import arctan2 as arctan2 +from keras.src.ops.numpy import arctanh as arctanh +from keras.src.ops.numpy import argmax as argmax +from keras.src.ops.numpy import argmin as argmin +from keras.src.ops.numpy import argpartition as argpartition +from keras.src.ops.numpy import argsort as argsort +from keras.src.ops.numpy import array as array +from keras.src.ops.numpy import average as average +from keras.src.ops.numpy import bartlett as bartlett +from keras.src.ops.numpy import bincount as bincount +from keras.src.ops.numpy import bitwise_and as bitwise_and +from keras.src.ops.numpy import bitwise_invert as bitwise_invert +from keras.src.ops.numpy import bitwise_left_shift as bitwise_left_shift +from keras.src.ops.numpy import bitwise_not as bitwise_not +from keras.src.ops.numpy import bitwise_or as bitwise_or +from keras.src.ops.numpy import bitwise_right_shift as bitwise_right_shift +from keras.src.ops.numpy import bitwise_xor as bitwise_xor +from keras.src.ops.numpy import blackman as blackman +from keras.src.ops.numpy import broadcast_to as broadcast_to +from keras.src.ops.numpy import cbrt as cbrt +from keras.src.ops.numpy import ceil as ceil +from keras.src.ops.numpy import clip as clip +from keras.src.ops.numpy import concatenate as concatenate +from keras.src.ops.numpy import conj as conj +from keras.src.ops.numpy import conjugate as conjugate +from keras.src.ops.numpy import copy as copy +from keras.src.ops.numpy import corrcoef as corrcoef +from keras.src.ops.numpy import correlate as correlate +from keras.src.ops.numpy import cos as cos +from keras.src.ops.numpy import cosh as cosh +from keras.src.ops.numpy import count_nonzero as count_nonzero +from keras.src.ops.numpy import cross as cross +from keras.src.ops.numpy import cumprod as cumprod +from keras.src.ops.numpy import cumsum as cumsum +from keras.src.ops.numpy import deg2rad as deg2rad +from keras.src.ops.numpy import diag as diag +from keras.src.ops.numpy import diagflat as diagflat +from keras.src.ops.numpy import diagonal as diagonal +from keras.src.ops.numpy import diff as diff +from keras.src.ops.numpy import digitize as digitize +from keras.src.ops.numpy import divide as divide +from keras.src.ops.numpy import divide_no_nan as divide_no_nan +from keras.src.ops.numpy import dot as dot +from keras.src.ops.numpy import einsum as einsum +from keras.src.ops.numpy import empty as empty +from keras.src.ops.numpy import equal as equal +from keras.src.ops.numpy import exp as exp +from keras.src.ops.numpy import exp2 as exp2 +from keras.src.ops.numpy import expand_dims as expand_dims +from keras.src.ops.numpy import expm1 as expm1 +from keras.src.ops.numpy import eye as eye +from keras.src.ops.numpy import flip as flip +from keras.src.ops.numpy import floor as floor +from keras.src.ops.numpy import floor_divide as floor_divide +from keras.src.ops.numpy import full as full +from keras.src.ops.numpy import full_like as full_like +from keras.src.ops.numpy import gcd as gcd +from keras.src.ops.numpy import get_item as get_item +from keras.src.ops.numpy import greater as greater +from keras.src.ops.numpy import greater_equal as greater_equal +from keras.src.ops.numpy import hamming as hamming +from keras.src.ops.numpy import hanning as hanning +from keras.src.ops.numpy import heaviside as heaviside +from keras.src.ops.numpy import histogram as histogram +from keras.src.ops.numpy import hstack as hstack +from keras.src.ops.numpy import hypot as hypot +from keras.src.ops.numpy import identity as identity +from keras.src.ops.numpy import imag as imag +from keras.src.ops.numpy import inner as inner +from keras.src.ops.numpy import isclose as isclose +from keras.src.ops.numpy import isfinite as isfinite +from keras.src.ops.numpy import isin as isin +from keras.src.ops.numpy import isinf as isinf +from keras.src.ops.numpy import isnan as isnan +from keras.src.ops.numpy import isneginf as isneginf +from keras.src.ops.numpy import isposinf as isposinf +from keras.src.ops.numpy import kaiser as kaiser +from keras.src.ops.numpy import kron as kron +from keras.src.ops.numpy import lcm as lcm +from keras.src.ops.numpy import left_shift as left_shift +from keras.src.ops.numpy import less as less +from keras.src.ops.numpy import less_equal as less_equal +from keras.src.ops.numpy import linspace as linspace +from keras.src.ops.numpy import log as log +from keras.src.ops.numpy import log1p as log1p +from keras.src.ops.numpy import log2 as log2 +from keras.src.ops.numpy import log10 as log10 +from keras.src.ops.numpy import logaddexp as logaddexp +from keras.src.ops.numpy import logaddexp2 as logaddexp2 +from keras.src.ops.numpy import logical_and as logical_and +from keras.src.ops.numpy import logical_not as logical_not +from keras.src.ops.numpy import logical_or as logical_or +from keras.src.ops.numpy import logical_xor as logical_xor +from keras.src.ops.numpy import logspace as logspace +from keras.src.ops.numpy import matmul as matmul +from keras.src.ops.numpy import max as max +from keras.src.ops.numpy import maximum as maximum +from keras.src.ops.numpy import mean as mean +from keras.src.ops.numpy import median as median +from keras.src.ops.numpy import meshgrid as meshgrid +from keras.src.ops.numpy import min as min +from keras.src.ops.numpy import minimum as minimum +from keras.src.ops.numpy import mod as mod +from keras.src.ops.numpy import moveaxis as moveaxis +from keras.src.ops.numpy import multiply as multiply +from keras.src.ops.numpy import nan_to_num as nan_to_num +from keras.src.ops.numpy import ndim as ndim +from keras.src.ops.numpy import negative as negative +from keras.src.ops.numpy import nonzero as nonzero +from keras.src.ops.numpy import not_equal as not_equal +from keras.src.ops.numpy import ones as ones +from keras.src.ops.numpy import ones_like as ones_like +from keras.src.ops.numpy import outer as outer +from keras.src.ops.numpy import pad as pad +from keras.src.ops.numpy import power as power +from keras.src.ops.numpy import prod as prod +from keras.src.ops.numpy import quantile as quantile +from keras.src.ops.numpy import ravel as ravel +from keras.src.ops.numpy import real as real +from keras.src.ops.numpy import reciprocal as reciprocal +from keras.src.ops.numpy import repeat as repeat +from keras.src.ops.numpy import reshape as reshape +from keras.src.ops.numpy import right_shift as right_shift +from keras.src.ops.numpy import roll as roll +from keras.src.ops.numpy import rot90 as rot90 +from keras.src.ops.numpy import round as round +from keras.src.ops.numpy import searchsorted as searchsorted +from keras.src.ops.numpy import select as select +from keras.src.ops.numpy import sign as sign +from keras.src.ops.numpy import signbit as signbit +from keras.src.ops.numpy import sin as sin +from keras.src.ops.numpy import sinh as sinh +from keras.src.ops.numpy import size as size +from keras.src.ops.numpy import slogdet as slogdet +from keras.src.ops.numpy import sort as sort +from keras.src.ops.numpy import split as split +from keras.src.ops.numpy import sqrt as sqrt +from keras.src.ops.numpy import square as square +from keras.src.ops.numpy import squeeze as squeeze +from keras.src.ops.numpy import stack as stack +from keras.src.ops.numpy import std as std +from keras.src.ops.numpy import subtract as subtract +from keras.src.ops.numpy import sum as sum +from keras.src.ops.numpy import swapaxes as swapaxes +from keras.src.ops.numpy import take as take +from keras.src.ops.numpy import take_along_axis as take_along_axis +from keras.src.ops.numpy import tan as tan +from keras.src.ops.numpy import tanh as tanh +from keras.src.ops.numpy import tensordot as tensordot +from keras.src.ops.numpy import tile as tile +from keras.src.ops.numpy import trace as trace +from keras.src.ops.numpy import transpose as transpose +from keras.src.ops.numpy import tri as tri +from keras.src.ops.numpy import tril as tril +from keras.src.ops.numpy import triu as triu +from keras.src.ops.numpy import true_divide as true_divide +from keras.src.ops.numpy import trunc as trunc +from keras.src.ops.numpy import unravel_index as unravel_index +from keras.src.ops.numpy import var as var +from keras.src.ops.numpy import vdot as vdot +from keras.src.ops.numpy import vectorize as vectorize +from keras.src.ops.numpy import vstack as vstack +from keras.src.ops.numpy import where as where +from keras.src.ops.numpy import zeros as zeros +from keras.src.ops.numpy import zeros_like as zeros_like diff --git a/keras/api/_tf_keras/keras/ops/image/__init__.py b/keras/api/_tf_keras/keras/ops/image/__init__.py new file mode 100644 index 000000000000..3a81f191259d --- /dev/null +++ b/keras/api/_tf_keras/keras/ops/image/__init__.py @@ -0,0 +1,19 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.ops.image import affine_transform as affine_transform +from keras.src.ops.image import crop_images as crop_images +from keras.src.ops.image import elastic_transform as elastic_transform +from keras.src.ops.image import extract_patches as extract_patches +from keras.src.ops.image import gaussian_blur as gaussian_blur +from keras.src.ops.image import hsv_to_rgb as hsv_to_rgb +from keras.src.ops.image import map_coordinates as map_coordinates +from keras.src.ops.image import pad_images as pad_images +from keras.src.ops.image import perspective_transform as perspective_transform +from keras.src.ops.image import resize as resize +from keras.src.ops.image import rgb_to_grayscale as rgb_to_grayscale +from keras.src.ops.image import rgb_to_hsv as rgb_to_hsv +from keras.src.ops.image import scale_and_translate as scale_and_translate diff --git a/keras/api/_tf_keras/keras/ops/linalg/__init__.py b/keras/api/_tf_keras/keras/ops/linalg/__init__.py new file mode 100644 index 000000000000..764fa8e74269 --- /dev/null +++ b/keras/api/_tf_keras/keras/ops/linalg/__init__.py @@ -0,0 +1,20 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.ops.linalg import cholesky as cholesky +from keras.src.ops.linalg import cholesky_inverse as cholesky_inverse +from keras.src.ops.linalg import det as det +from keras.src.ops.linalg import eig as eig +from keras.src.ops.linalg import eigh as eigh +from keras.src.ops.linalg import inv as inv +from keras.src.ops.linalg import jvp as jvp +from keras.src.ops.linalg import lstsq as lstsq +from keras.src.ops.linalg import lu_factor as lu_factor +from keras.src.ops.linalg import norm as norm +from keras.src.ops.linalg import qr as qr +from keras.src.ops.linalg import solve as solve +from keras.src.ops.linalg import solve_triangular as solve_triangular +from keras.src.ops.linalg import svd as svd diff --git a/keras/api/_tf_keras/keras/ops/nn/__init__.py b/keras/api/_tf_keras/keras/ops/nn/__init__.py new file mode 100644 index 000000000000..da08f380f227 --- /dev/null +++ b/keras/api/_tf_keras/keras/ops/nn/__init__.py @@ -0,0 +1,60 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.ops.nn import average_pool as average_pool +from keras.src.ops.nn import batch_normalization as batch_normalization +from keras.src.ops.nn import binary_crossentropy as binary_crossentropy +from keras.src.ops.nn import ( + categorical_crossentropy as categorical_crossentropy, +) +from keras.src.ops.nn import celu as celu +from keras.src.ops.nn import conv as conv +from keras.src.ops.nn import conv_transpose as conv_transpose +from keras.src.ops.nn import ctc_decode as ctc_decode +from keras.src.ops.nn import ctc_loss as ctc_loss +from keras.src.ops.nn import depthwise_conv as depthwise_conv +from keras.src.ops.nn import dot_product_attention as dot_product_attention +from keras.src.ops.nn import elu as elu +from keras.src.ops.nn import gelu as gelu +from keras.src.ops.nn import glu as glu +from keras.src.ops.nn import hard_shrink as hard_shrink +from keras.src.ops.nn import hard_sigmoid as hard_sigmoid +from keras.src.ops.nn import hard_silu as hard_silu +from keras.src.ops.nn import hard_silu as hard_swish +from keras.src.ops.nn import hard_tanh as hard_tanh +from keras.src.ops.nn import layer_normalization as layer_normalization +from keras.src.ops.nn import leaky_relu as leaky_relu +from keras.src.ops.nn import log_sigmoid as log_sigmoid +from keras.src.ops.nn import log_softmax as log_softmax +from keras.src.ops.nn import max_pool as max_pool +from keras.src.ops.nn import moments as moments +from keras.src.ops.nn import multi_hot as multi_hot +from keras.src.ops.nn import normalize as normalize +from keras.src.ops.nn import one_hot as one_hot +from keras.src.ops.nn import polar as polar +from keras.src.ops.nn import psnr as psnr +from keras.src.ops.nn import relu as relu +from keras.src.ops.nn import relu6 as relu6 +from keras.src.ops.nn import rms_normalization as rms_normalization +from keras.src.ops.nn import selu as selu +from keras.src.ops.nn import separable_conv as separable_conv +from keras.src.ops.nn import sigmoid as sigmoid +from keras.src.ops.nn import silu as silu +from keras.src.ops.nn import silu as swish +from keras.src.ops.nn import soft_shrink as soft_shrink +from keras.src.ops.nn import softmax as softmax +from keras.src.ops.nn import softplus as softplus +from keras.src.ops.nn import softsign as softsign +from keras.src.ops.nn import ( + sparse_categorical_crossentropy as sparse_categorical_crossentropy, +) +from keras.src.ops.nn import sparse_plus as sparse_plus +from keras.src.ops.nn import sparse_sigmoid as sparse_sigmoid +from keras.src.ops.nn import sparsemax as sparsemax +from keras.src.ops.nn import squareplus as squareplus +from keras.src.ops.nn import tanh_shrink as tanh_shrink +from keras.src.ops.nn import threshold as threshold +from keras.src.ops.nn import unfold as unfold diff --git a/keras/api/_tf_keras/keras/ops/numpy/__init__.py b/keras/api/_tf_keras/keras/ops/numpy/__init__.py new file mode 100644 index 000000000000..ebeb384c181c --- /dev/null +++ b/keras/api/_tf_keras/keras/ops/numpy/__init__.py @@ -0,0 +1,186 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.ops.numpy import abs as abs +from keras.src.ops.numpy import absolute as absolute +from keras.src.ops.numpy import add as add +from keras.src.ops.numpy import all as all +from keras.src.ops.numpy import amax as amax +from keras.src.ops.numpy import amin as amin +from keras.src.ops.numpy import angle as angle +from keras.src.ops.numpy import any as any +from keras.src.ops.numpy import append as append +from keras.src.ops.numpy import arange as arange +from keras.src.ops.numpy import arccos as arccos +from keras.src.ops.numpy import arccosh as arccosh +from keras.src.ops.numpy import arcsin as arcsin +from keras.src.ops.numpy import arcsinh as arcsinh +from keras.src.ops.numpy import arctan as arctan +from keras.src.ops.numpy import arctan2 as arctan2 +from keras.src.ops.numpy import arctanh as arctanh +from keras.src.ops.numpy import argmax as argmax +from keras.src.ops.numpy import argmin as argmin +from keras.src.ops.numpy import argpartition as argpartition +from keras.src.ops.numpy import argsort as argsort +from keras.src.ops.numpy import array as array +from keras.src.ops.numpy import average as average +from keras.src.ops.numpy import bartlett as bartlett +from keras.src.ops.numpy import bincount as bincount +from keras.src.ops.numpy import bitwise_and as bitwise_and +from keras.src.ops.numpy import bitwise_invert as bitwise_invert +from keras.src.ops.numpy import bitwise_left_shift as bitwise_left_shift +from keras.src.ops.numpy import bitwise_not as bitwise_not +from keras.src.ops.numpy import bitwise_or as bitwise_or +from keras.src.ops.numpy import bitwise_right_shift as bitwise_right_shift +from keras.src.ops.numpy import bitwise_xor as bitwise_xor +from keras.src.ops.numpy import blackman as blackman +from keras.src.ops.numpy import broadcast_to as broadcast_to +from keras.src.ops.numpy import cbrt as cbrt +from keras.src.ops.numpy import ceil as ceil +from keras.src.ops.numpy import clip as clip +from keras.src.ops.numpy import concatenate as concatenate +from keras.src.ops.numpy import conj as conj +from keras.src.ops.numpy import conjugate as conjugate +from keras.src.ops.numpy import copy as copy +from keras.src.ops.numpy import corrcoef as corrcoef +from keras.src.ops.numpy import correlate as correlate +from keras.src.ops.numpy import cos as cos +from keras.src.ops.numpy import cosh as cosh +from keras.src.ops.numpy import count_nonzero as count_nonzero +from keras.src.ops.numpy import cross as cross +from keras.src.ops.numpy import cumprod as cumprod +from keras.src.ops.numpy import cumsum as cumsum +from keras.src.ops.numpy import deg2rad as deg2rad +from keras.src.ops.numpy import diag as diag +from keras.src.ops.numpy import diagflat as diagflat +from keras.src.ops.numpy import diagonal as diagonal +from keras.src.ops.numpy import diff as diff +from keras.src.ops.numpy import digitize as digitize +from keras.src.ops.numpy import divide as divide +from keras.src.ops.numpy import divide_no_nan as divide_no_nan +from keras.src.ops.numpy import dot as dot +from keras.src.ops.numpy import einsum as einsum +from keras.src.ops.numpy import empty as empty +from keras.src.ops.numpy import equal as equal +from keras.src.ops.numpy import exp as exp +from keras.src.ops.numpy import exp2 as exp2 +from keras.src.ops.numpy import expand_dims as expand_dims +from keras.src.ops.numpy import expm1 as expm1 +from keras.src.ops.numpy import eye as eye +from keras.src.ops.numpy import flip as flip +from keras.src.ops.numpy import floor as floor +from keras.src.ops.numpy import floor_divide as floor_divide +from keras.src.ops.numpy import full as full +from keras.src.ops.numpy import full_like as full_like +from keras.src.ops.numpy import gcd as gcd +from keras.src.ops.numpy import get_item as get_item +from keras.src.ops.numpy import greater as greater +from keras.src.ops.numpy import greater_equal as greater_equal +from keras.src.ops.numpy import hamming as hamming +from keras.src.ops.numpy import hanning as hanning +from keras.src.ops.numpy import heaviside as heaviside +from keras.src.ops.numpy import histogram as histogram +from keras.src.ops.numpy import hstack as hstack +from keras.src.ops.numpy import hypot as hypot +from keras.src.ops.numpy import identity as identity +from keras.src.ops.numpy import imag as imag +from keras.src.ops.numpy import inner as inner +from keras.src.ops.numpy import isclose as isclose +from keras.src.ops.numpy import isfinite as isfinite +from keras.src.ops.numpy import isin as isin +from keras.src.ops.numpy import isinf as isinf +from keras.src.ops.numpy import isnan as isnan +from keras.src.ops.numpy import isneginf as isneginf +from keras.src.ops.numpy import isposinf as isposinf +from keras.src.ops.numpy import kaiser as kaiser +from keras.src.ops.numpy import kron as kron +from keras.src.ops.numpy import lcm as lcm +from keras.src.ops.numpy import left_shift as left_shift +from keras.src.ops.numpy import less as less +from keras.src.ops.numpy import less_equal as less_equal +from keras.src.ops.numpy import linspace as linspace +from keras.src.ops.numpy import log as log +from keras.src.ops.numpy import log1p as log1p +from keras.src.ops.numpy import log2 as log2 +from keras.src.ops.numpy import log10 as log10 +from keras.src.ops.numpy import logaddexp as logaddexp +from keras.src.ops.numpy import logaddexp2 as logaddexp2 +from keras.src.ops.numpy import logical_and as logical_and +from keras.src.ops.numpy import logical_not as logical_not +from keras.src.ops.numpy import logical_or as logical_or +from keras.src.ops.numpy import logical_xor as logical_xor +from keras.src.ops.numpy import logspace as logspace +from keras.src.ops.numpy import matmul as matmul +from keras.src.ops.numpy import max as max +from keras.src.ops.numpy import maximum as maximum +from keras.src.ops.numpy import mean as mean +from keras.src.ops.numpy import median as median +from keras.src.ops.numpy import meshgrid as meshgrid +from keras.src.ops.numpy import min as min +from keras.src.ops.numpy import minimum as minimum +from keras.src.ops.numpy import mod as mod +from keras.src.ops.numpy import moveaxis as moveaxis +from keras.src.ops.numpy import multiply as multiply +from keras.src.ops.numpy import nan_to_num as nan_to_num +from keras.src.ops.numpy import ndim as ndim +from keras.src.ops.numpy import negative as negative +from keras.src.ops.numpy import nonzero as nonzero +from keras.src.ops.numpy import not_equal as not_equal +from keras.src.ops.numpy import ones as ones +from keras.src.ops.numpy import ones_like as ones_like +from keras.src.ops.numpy import outer as outer +from keras.src.ops.numpy import pad as pad +from keras.src.ops.numpy import power as power +from keras.src.ops.numpy import prod as prod +from keras.src.ops.numpy import quantile as quantile +from keras.src.ops.numpy import ravel as ravel +from keras.src.ops.numpy import real as real +from keras.src.ops.numpy import reciprocal as reciprocal +from keras.src.ops.numpy import repeat as repeat +from keras.src.ops.numpy import reshape as reshape +from keras.src.ops.numpy import right_shift as right_shift +from keras.src.ops.numpy import roll as roll +from keras.src.ops.numpy import rot90 as rot90 +from keras.src.ops.numpy import round as round +from keras.src.ops.numpy import searchsorted as searchsorted +from keras.src.ops.numpy import select as select +from keras.src.ops.numpy import sign as sign +from keras.src.ops.numpy import signbit as signbit +from keras.src.ops.numpy import sin as sin +from keras.src.ops.numpy import sinh as sinh +from keras.src.ops.numpy import size as size +from keras.src.ops.numpy import slogdet as slogdet +from keras.src.ops.numpy import sort as sort +from keras.src.ops.numpy import split as split +from keras.src.ops.numpy import sqrt as sqrt +from keras.src.ops.numpy import square as square +from keras.src.ops.numpy import squeeze as squeeze +from keras.src.ops.numpy import stack as stack +from keras.src.ops.numpy import std as std +from keras.src.ops.numpy import subtract as subtract +from keras.src.ops.numpy import sum as sum +from keras.src.ops.numpy import swapaxes as swapaxes +from keras.src.ops.numpy import take as take +from keras.src.ops.numpy import take_along_axis as take_along_axis +from keras.src.ops.numpy import tan as tan +from keras.src.ops.numpy import tanh as tanh +from keras.src.ops.numpy import tensordot as tensordot +from keras.src.ops.numpy import tile as tile +from keras.src.ops.numpy import trace as trace +from keras.src.ops.numpy import transpose as transpose +from keras.src.ops.numpy import tri as tri +from keras.src.ops.numpy import tril as tril +from keras.src.ops.numpy import triu as triu +from keras.src.ops.numpy import true_divide as true_divide +from keras.src.ops.numpy import trunc as trunc +from keras.src.ops.numpy import unravel_index as unravel_index +from keras.src.ops.numpy import var as var +from keras.src.ops.numpy import vdot as vdot +from keras.src.ops.numpy import vectorize as vectorize +from keras.src.ops.numpy import vstack as vstack +from keras.src.ops.numpy import where as where +from keras.src.ops.numpy import zeros as zeros +from keras.src.ops.numpy import zeros_like as zeros_like diff --git a/keras/api/_tf_keras/keras/optimizers/__init__.py b/keras/api/_tf_keras/keras/optimizers/__init__.py new file mode 100644 index 000000000000..40f6ab4018f5 --- /dev/null +++ b/keras/api/_tf_keras/keras/optimizers/__init__.py @@ -0,0 +1,28 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.optimizers import legacy as legacy +from keras.optimizers import schedules as schedules +from keras.src.optimizers import deserialize as deserialize +from keras.src.optimizers import get as get +from keras.src.optimizers import serialize as serialize +from keras.src.optimizers.adadelta import Adadelta as Adadelta +from keras.src.optimizers.adafactor import Adafactor as Adafactor +from keras.src.optimizers.adagrad import Adagrad as Adagrad +from keras.src.optimizers.adam import Adam as Adam +from keras.src.optimizers.adamax import Adamax as Adamax +from keras.src.optimizers.adamw import AdamW as AdamW +from keras.src.optimizers.ftrl import Ftrl as Ftrl +from keras.src.optimizers.lamb import Lamb as Lamb +from keras.src.optimizers.lion import Lion as Lion +from keras.src.optimizers.loss_scale_optimizer import ( + LossScaleOptimizer as LossScaleOptimizer, +) +from keras.src.optimizers.muon import Muon as Muon +from keras.src.optimizers.nadam import Nadam as Nadam +from keras.src.optimizers.optimizer import Optimizer as Optimizer +from keras.src.optimizers.rmsprop import RMSprop as RMSprop +from keras.src.optimizers.sgd import SGD as SGD diff --git a/keras/api/_tf_keras/keras/optimizers/legacy/__init__.py b/keras/api/_tf_keras/keras/optimizers/legacy/__init__.py new file mode 100644 index 000000000000..bff1a0313630 --- /dev/null +++ b/keras/api/_tf_keras/keras/optimizers/legacy/__init__.py @@ -0,0 +1,12 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.optimizers import LegacyOptimizerWarning as Adagrad +from keras.src.optimizers import LegacyOptimizerWarning as Adam +from keras.src.optimizers import LegacyOptimizerWarning as Ftrl +from keras.src.optimizers import LegacyOptimizerWarning as Optimizer +from keras.src.optimizers import LegacyOptimizerWarning as RMSprop +from keras.src.optimizers import LegacyOptimizerWarning as SGD diff --git a/keras/api/_tf_keras/keras/optimizers/schedules/__init__.py b/keras/api/_tf_keras/keras/optimizers/schedules/__init__.py new file mode 100644 index 000000000000..da9621aa36b1 --- /dev/null +++ b/keras/api/_tf_keras/keras/optimizers/schedules/__init__.py @@ -0,0 +1,33 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.optimizers.schedules.learning_rate_schedule import ( + CosineDecay as CosineDecay, +) +from keras.src.optimizers.schedules.learning_rate_schedule import ( + CosineDecayRestarts as CosineDecayRestarts, +) +from keras.src.optimizers.schedules.learning_rate_schedule import ( + ExponentialDecay as ExponentialDecay, +) +from keras.src.optimizers.schedules.learning_rate_schedule import ( + InverseTimeDecay as InverseTimeDecay, +) +from keras.src.optimizers.schedules.learning_rate_schedule import ( + LearningRateSchedule as LearningRateSchedule, +) +from keras.src.optimizers.schedules.learning_rate_schedule import ( + PiecewiseConstantDecay as PiecewiseConstantDecay, +) +from keras.src.optimizers.schedules.learning_rate_schedule import ( + PolynomialDecay as PolynomialDecay, +) +from keras.src.optimizers.schedules.learning_rate_schedule import ( + deserialize as deserialize, +) +from keras.src.optimizers.schedules.learning_rate_schedule import ( + serialize as serialize, +) diff --git a/keras/api/_tf_keras/keras/preprocessing/__init__.py b/keras/api/_tf_keras/keras/preprocessing/__init__.py new file mode 100644 index 000000000000..b11b4f3fd272 --- /dev/null +++ b/keras/api/_tf_keras/keras/preprocessing/__init__.py @@ -0,0 +1,18 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras._tf_keras.keras.preprocessing import image as image +from keras._tf_keras.keras.preprocessing import sequence as sequence +from keras._tf_keras.keras.preprocessing import text as text +from keras.src.utils.image_dataset_utils import ( + image_dataset_from_directory as image_dataset_from_directory, +) +from keras.src.utils.text_dataset_utils import ( + text_dataset_from_directory as text_dataset_from_directory, +) +from keras.src.utils.timeseries_dataset_utils import ( + timeseries_dataset_from_array as timeseries_dataset_from_array, +) diff --git a/keras/api/_tf_keras/keras/preprocessing/image/__init__.py b/keras/api/_tf_keras/keras/preprocessing/image/__init__.py new file mode 100644 index 000000000000..43986878eb40 --- /dev/null +++ b/keras/api/_tf_keras/keras/preprocessing/image/__init__.py @@ -0,0 +1,42 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.legacy.preprocessing.image import ( + DirectoryIterator as DirectoryIterator, +) +from keras.src.legacy.preprocessing.image import ( + ImageDataGenerator as ImageDataGenerator, +) +from keras.src.legacy.preprocessing.image import Iterator as Iterator +from keras.src.legacy.preprocessing.image import ( + NumpyArrayIterator as NumpyArrayIterator, +) +from keras.src.legacy.preprocessing.image import ( + apply_affine_transform as apply_affine_transform, +) +from keras.src.legacy.preprocessing.image import ( + apply_brightness_shift as apply_brightness_shift, +) +from keras.src.legacy.preprocessing.image import ( + apply_channel_shift as apply_channel_shift, +) +from keras.src.legacy.preprocessing.image import ( + random_brightness as random_brightness, +) +from keras.src.legacy.preprocessing.image import ( + random_channel_shift as random_channel_shift, +) +from keras.src.legacy.preprocessing.image import ( + random_rotation as random_rotation, +) +from keras.src.legacy.preprocessing.image import random_shear as random_shear +from keras.src.legacy.preprocessing.image import random_shift as random_shift +from keras.src.legacy.preprocessing.image import random_zoom as random_zoom +from keras.src.utils.image_utils import array_to_img as array_to_img +from keras.src.utils.image_utils import img_to_array as img_to_array +from keras.src.utils.image_utils import load_img as load_img +from keras.src.utils.image_utils import save_img as save_img +from keras.src.utils.image_utils import smart_resize as smart_resize diff --git a/keras/api/_tf_keras/keras/preprocessing/sequence/__init__.py b/keras/api/_tf_keras/keras/preprocessing/sequence/__init__.py new file mode 100644 index 000000000000..501c1f1123de --- /dev/null +++ b/keras/api/_tf_keras/keras/preprocessing/sequence/__init__.py @@ -0,0 +1,14 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.legacy.preprocessing.sequence import ( + TimeseriesGenerator as TimeseriesGenerator, +) +from keras.src.legacy.preprocessing.sequence import ( + make_sampling_table as make_sampling_table, +) +from keras.src.legacy.preprocessing.sequence import skipgrams as skipgrams +from keras.src.utils.sequence_utils import pad_sequences as pad_sequences diff --git a/keras/api/_tf_keras/keras/preprocessing/text/__init__.py b/keras/api/_tf_keras/keras/preprocessing/text/__init__.py new file mode 100644 index 000000000000..01399ab15737 --- /dev/null +++ b/keras/api/_tf_keras/keras/preprocessing/text/__init__.py @@ -0,0 +1,15 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.legacy.preprocessing.text import Tokenizer as Tokenizer +from keras.src.legacy.preprocessing.text import hashing_trick as hashing_trick +from keras.src.legacy.preprocessing.text import one_hot as one_hot +from keras.src.legacy.preprocessing.text import ( + text_to_word_sequence as text_to_word_sequence, +) +from keras.src.legacy.preprocessing.text import ( + tokenizer_from_json as tokenizer_from_json, +) diff --git a/keras/api/_tf_keras/keras/quantizers/__init__.py b/keras/api/_tf_keras/keras/quantizers/__init__.py new file mode 100644 index 000000000000..299e467ac1bb --- /dev/null +++ b/keras/api/_tf_keras/keras/quantizers/__init__.py @@ -0,0 +1,27 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.quantizers import deserialize as deserialize +from keras.src.quantizers import get as get +from keras.src.quantizers import serialize as serialize +from keras.src.quantizers.gptq_config import GPTQConfig as GPTQConfig +from keras.src.quantizers.quantizers import AbsMaxQuantizer as AbsMaxQuantizer +from keras.src.quantizers.quantizers import Quantizer as Quantizer +from keras.src.quantizers.quantizers import abs_max_quantize as abs_max_quantize +from keras.src.quantizers.quantizers import ( + compute_float8_amax_history as compute_float8_amax_history, +) +from keras.src.quantizers.quantizers import ( + compute_float8_scale as compute_float8_scale, +) +from keras.src.quantizers.quantizers import ( + fake_quant_with_min_max_vars as fake_quant_with_min_max_vars, +) +from keras.src.quantizers.quantizers import pack_int4 as pack_int4 +from keras.src.quantizers.quantizers import ( + quantize_and_dequantize as quantize_and_dequantize, +) +from keras.src.quantizers.quantizers import unpack_int4 as unpack_int4 diff --git a/keras/api/_tf_keras/keras/random/__init__.py b/keras/api/_tf_keras/keras/random/__init__.py new file mode 100644 index 000000000000..d0ee60a77c92 --- /dev/null +++ b/keras/api/_tf_keras/keras/random/__init__.py @@ -0,0 +1,17 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.random.random import beta as beta +from keras.src.random.random import binomial as binomial +from keras.src.random.random import categorical as categorical +from keras.src.random.random import dropout as dropout +from keras.src.random.random import gamma as gamma +from keras.src.random.random import normal as normal +from keras.src.random.random import randint as randint +from keras.src.random.random import shuffle as shuffle +from keras.src.random.random import truncated_normal as truncated_normal +from keras.src.random.random import uniform as uniform +from keras.src.random.seed_generator import SeedGenerator as SeedGenerator diff --git a/keras/api/_tf_keras/keras/regularizers/__init__.py b/keras/api/_tf_keras/keras/regularizers/__init__.py new file mode 100644 index 000000000000..1e3609f71c75 --- /dev/null +++ b/keras/api/_tf_keras/keras/regularizers/__init__.py @@ -0,0 +1,22 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.regularizers import deserialize as deserialize +from keras.src.regularizers import get as get +from keras.src.regularizers import serialize as serialize +from keras.src.regularizers.regularizers import L1 as L1 +from keras.src.regularizers.regularizers import L1 as l1 +from keras.src.regularizers.regularizers import L1L2 as L1L2 +from keras.src.regularizers.regularizers import L1L2 as l1_l2 +from keras.src.regularizers.regularizers import L2 as L2 +from keras.src.regularizers.regularizers import L2 as l2 +from keras.src.regularizers.regularizers import ( + OrthogonalRegularizer as OrthogonalRegularizer, +) +from keras.src.regularizers.regularizers import ( + OrthogonalRegularizer as orthogonal_regularizer, +) +from keras.src.regularizers.regularizers import Regularizer as Regularizer diff --git a/keras/api/_tf_keras/keras/saving/__init__.py b/keras/api/_tf_keras/keras/saving/__init__.py new file mode 100644 index 000000000000..28edd8779337 --- /dev/null +++ b/keras/api/_tf_keras/keras/saving/__init__.py @@ -0,0 +1,35 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.saving.file_editor import KerasFileEditor as KerasFileEditor +from keras.src.saving.object_registration import ( + CustomObjectScope as CustomObjectScope, +) +from keras.src.saving.object_registration import ( + CustomObjectScope as custom_object_scope, +) +from keras.src.saving.object_registration import ( + get_custom_objects as get_custom_objects, +) +from keras.src.saving.object_registration import ( + get_registered_name as get_registered_name, +) +from keras.src.saving.object_registration import ( + get_registered_object as get_registered_object, +) +from keras.src.saving.object_registration import ( + register_keras_serializable as register_keras_serializable, +) +from keras.src.saving.saving_api import load_model as load_model +from keras.src.saving.saving_api import load_weights as load_weights +from keras.src.saving.saving_api import save_model as save_model +from keras.src.saving.saving_api import save_weights as save_weights +from keras.src.saving.serialization_lib import ( + deserialize_keras_object as deserialize_keras_object, +) +from keras.src.saving.serialization_lib import ( + serialize_keras_object as serialize_keras_object, +) diff --git a/keras/api/_tf_keras/keras/tree/__init__.py b/keras/api/_tf_keras/keras/tree/__init__.py new file mode 100644 index 000000000000..80d9f25244e8 --- /dev/null +++ b/keras/api/_tf_keras/keras/tree/__init__.py @@ -0,0 +1,20 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.tree.tree_api import MAP_TO_NONE as MAP_TO_NONE +from keras.src.tree.tree_api import assert_same_paths as assert_same_paths +from keras.src.tree.tree_api import ( + assert_same_structure as assert_same_structure, +) +from keras.src.tree.tree_api import flatten as flatten +from keras.src.tree.tree_api import flatten_with_path as flatten_with_path +from keras.src.tree.tree_api import is_nested as is_nested +from keras.src.tree.tree_api import lists_to_tuples as lists_to_tuples +from keras.src.tree.tree_api import map_shape_structure as map_shape_structure +from keras.src.tree.tree_api import map_structure as map_structure +from keras.src.tree.tree_api import map_structure_up_to as map_structure_up_to +from keras.src.tree.tree_api import pack_sequence_as as pack_sequence_as +from keras.src.tree.tree_api import traverse as traverse diff --git a/keras/api/_tf_keras/keras/utils/__init__.py b/keras/api/_tf_keras/keras/utils/__init__.py new file mode 100644 index 000000000000..8ddbda527609 --- /dev/null +++ b/keras/api/_tf_keras/keras/utils/__init__.py @@ -0,0 +1,90 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.backend.common.global_state import clear_session as clear_session +from keras.src.backend.common.keras_tensor import ( + is_keras_tensor as is_keras_tensor, +) +from keras.src.backend.common.variables import ( + standardize_dtype as standardize_dtype, +) +from keras.src.layers.preprocessing.feature_space import ( + FeatureSpace as FeatureSpace, +) +from keras.src.ops.operation_utils import get_source_inputs as get_source_inputs +from keras.src.saving.object_registration import ( + CustomObjectScope as CustomObjectScope, +) +from keras.src.saving.object_registration import ( + CustomObjectScope as custom_object_scope, +) +from keras.src.saving.object_registration import ( + get_custom_objects as get_custom_objects, +) +from keras.src.saving.object_registration import ( + get_registered_name as get_registered_name, +) +from keras.src.saving.object_registration import ( + get_registered_object as get_registered_object, +) +from keras.src.saving.object_registration import ( + register_keras_serializable as register_keras_serializable, +) +from keras.src.saving.serialization_lib import ( + deserialize_keras_object as deserialize_keras_object, +) +from keras.src.saving.serialization_lib import ( + serialize_keras_object as serialize_keras_object, +) +from keras.src.trainers.data_adapters.data_adapter_utils import ( + pack_x_y_sample_weight as pack_x_y_sample_weight, +) +from keras.src.trainers.data_adapters.data_adapter_utils import ( + unpack_x_y_sample_weight as unpack_x_y_sample_weight, +) +from keras.src.trainers.data_adapters.py_dataset_adapter import ( + PyDataset as PyDataset, +) +from keras.src.trainers.data_adapters.py_dataset_adapter import ( + PyDataset as Sequence, +) +from keras.src.utils.audio_dataset_utils import ( + audio_dataset_from_directory as audio_dataset_from_directory, +) +from keras.src.utils.config import Config as Config +from keras.src.utils.dataset_utils import split_dataset as split_dataset +from keras.src.utils.file_utils import get_file as get_file +from keras.src.utils.image_dataset_utils import ( + image_dataset_from_directory as image_dataset_from_directory, +) +from keras.src.utils.image_utils import array_to_img as array_to_img +from keras.src.utils.image_utils import img_to_array as img_to_array +from keras.src.utils.image_utils import load_img as load_img +from keras.src.utils.image_utils import save_img as save_img +from keras.src.utils.io_utils import ( + disable_interactive_logging as disable_interactive_logging, +) +from keras.src.utils.io_utils import ( + enable_interactive_logging as enable_interactive_logging, +) +from keras.src.utils.io_utils import ( + is_interactive_logging_enabled as is_interactive_logging_enabled, +) +from keras.src.utils.model_visualization import model_to_dot as model_to_dot +from keras.src.utils.model_visualization import plot_model as plot_model +from keras.src.utils.numerical_utils import normalize as normalize +from keras.src.utils.numerical_utils import to_categorical as to_categorical +from keras.src.utils.progbar import Progbar as Progbar +from keras.src.utils.rng_utils import set_random_seed as set_random_seed +from keras.src.utils.sequence_utils import pad_sequences as pad_sequences +from keras.src.utils.text_dataset_utils import ( + text_dataset_from_directory as text_dataset_from_directory, +) +from keras.src.utils.timeseries_dataset_utils import ( + timeseries_dataset_from_array as timeseries_dataset_from_array, +) +from keras.utils import bounding_boxes as bounding_boxes +from keras.utils import legacy as legacy diff --git a/keras/api/_tf_keras/keras/utils/bounding_boxes/__init__.py b/keras/api/_tf_keras/keras/utils/bounding_boxes/__init__.py new file mode 100644 index 000000000000..40221bd75c94 --- /dev/null +++ b/keras/api/_tf_keras/keras/utils/bounding_boxes/__init__.py @@ -0,0 +1,33 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import ( + affine_transform as affine_transform, +) +from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import ( + clip_to_image_size as clip_to_image_size, +) +from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import ( + convert_format as convert_format, +) +from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import ( + crop as crop, +) +from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import ( + decode_deltas_to_boxes as decode_deltas_to_boxes, +) +from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import ( + encode_box_to_deltas as encode_box_to_deltas, +) +from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import ( + pad as pad, +) +from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.iou import ( + compute_ciou as compute_ciou, +) +from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.iou import ( + compute_iou as compute_iou, +) diff --git a/keras/api/_tf_keras/keras/utils/legacy/__init__.py b/keras/api/_tf_keras/keras/utils/legacy/__init__.py new file mode 100644 index 000000000000..1e3aa0ee9d5c --- /dev/null +++ b/keras/api/_tf_keras/keras/utils/legacy/__init__.py @@ -0,0 +1,12 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.legacy.saving.serialization import ( + deserialize_keras_object as deserialize_keras_object, +) +from keras.src.legacy.saving.serialization import ( + serialize_keras_object as serialize_keras_object, +) diff --git a/keras/api/_tf_keras/keras/visualization/__init__.py b/keras/api/_tf_keras/keras/visualization/__init__.py new file mode 100644 index 000000000000..6e3482a8d59a --- /dev/null +++ b/keras/api/_tf_keras/keras/visualization/__init__.py @@ -0,0 +1,21 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.visualization.draw_bounding_boxes import ( + draw_bounding_boxes as draw_bounding_boxes, +) +from keras.src.visualization.draw_segmentation_masks import ( + draw_segmentation_masks as draw_segmentation_masks, +) +from keras.src.visualization.plot_bounding_box_gallery import ( + plot_bounding_box_gallery as plot_bounding_box_gallery, +) +from keras.src.visualization.plot_image_gallery import ( + plot_image_gallery as plot_image_gallery, +) +from keras.src.visualization.plot_segmentation_mask_gallery import ( + plot_segmentation_mask_gallery as plot_segmentation_mask_gallery, +) diff --git a/keras/api/_tf_keras/keras/wrappers/__init__.py b/keras/api/_tf_keras/keras/wrappers/__init__.py new file mode 100644 index 000000000000..e3aa52524ca6 --- /dev/null +++ b/keras/api/_tf_keras/keras/wrappers/__init__.py @@ -0,0 +1,15 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.wrappers.sklearn_wrapper import ( + SKLearnClassifier as SKLearnClassifier, +) +from keras.src.wrappers.sklearn_wrapper import ( + SKLearnRegressor as SKLearnRegressor, +) +from keras.src.wrappers.sklearn_wrapper import ( + SKLearnTransformer as SKLearnTransformer, +) diff --git a/keras/api/activations/__init__.py b/keras/api/activations/__init__.py new file mode 100644 index 000000000000..85ae031a72dc --- /dev/null +++ b/keras/api/activations/__init__.py @@ -0,0 +1,41 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.activations import deserialize as deserialize +from keras.src.activations import get as get +from keras.src.activations import serialize as serialize +from keras.src.activations.activations import celu as celu +from keras.src.activations.activations import elu as elu +from keras.src.activations.activations import exponential as exponential +from keras.src.activations.activations import gelu as gelu +from keras.src.activations.activations import glu as glu +from keras.src.activations.activations import hard_shrink as hard_shrink +from keras.src.activations.activations import hard_sigmoid as hard_sigmoid +from keras.src.activations.activations import hard_silu as hard_silu +from keras.src.activations.activations import hard_silu as hard_swish +from keras.src.activations.activations import hard_tanh as hard_tanh +from keras.src.activations.activations import leaky_relu as leaky_relu +from keras.src.activations.activations import linear as linear +from keras.src.activations.activations import log_sigmoid as log_sigmoid +from keras.src.activations.activations import log_softmax as log_softmax +from keras.src.activations.activations import mish as mish +from keras.src.activations.activations import relu as relu +from keras.src.activations.activations import relu6 as relu6 +from keras.src.activations.activations import selu as selu +from keras.src.activations.activations import sigmoid as sigmoid +from keras.src.activations.activations import silu as silu +from keras.src.activations.activations import silu as swish +from keras.src.activations.activations import soft_shrink as soft_shrink +from keras.src.activations.activations import softmax as softmax +from keras.src.activations.activations import softplus as softplus +from keras.src.activations.activations import softsign as softsign +from keras.src.activations.activations import sparse_plus as sparse_plus +from keras.src.activations.activations import sparse_sigmoid as sparse_sigmoid +from keras.src.activations.activations import sparsemax as sparsemax +from keras.src.activations.activations import squareplus as squareplus +from keras.src.activations.activations import tanh as tanh +from keras.src.activations.activations import tanh_shrink as tanh_shrink +from keras.src.activations.activations import threshold as threshold diff --git a/keras/api/applications/__init__.py b/keras/api/applications/__init__.py new file mode 100644 index 000000000000..7c030b36bd4e --- /dev/null +++ b/keras/api/applications/__init__.py @@ -0,0 +1,83 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.applications import convnext as convnext +from keras.applications import densenet as densenet +from keras.applications import efficientnet as efficientnet +from keras.applications import efficientnet_v2 as efficientnet_v2 +from keras.applications import imagenet_utils as imagenet_utils +from keras.applications import inception_resnet_v2 as inception_resnet_v2 +from keras.applications import inception_v3 as inception_v3 +from keras.applications import mobilenet as mobilenet +from keras.applications import mobilenet_v2 as mobilenet_v2 +from keras.applications import mobilenet_v3 as mobilenet_v3 +from keras.applications import nasnet as nasnet +from keras.applications import resnet as resnet +from keras.applications import resnet50 as resnet50 +from keras.applications import resnet_v2 as resnet_v2 +from keras.applications import vgg16 as vgg16 +from keras.applications import vgg19 as vgg19 +from keras.applications import xception as xception +from keras.src.applications.convnext import ConvNeXtBase as ConvNeXtBase +from keras.src.applications.convnext import ConvNeXtLarge as ConvNeXtLarge +from keras.src.applications.convnext import ConvNeXtSmall as ConvNeXtSmall +from keras.src.applications.convnext import ConvNeXtTiny as ConvNeXtTiny +from keras.src.applications.convnext import ConvNeXtXLarge as ConvNeXtXLarge +from keras.src.applications.densenet import DenseNet121 as DenseNet121 +from keras.src.applications.densenet import DenseNet169 as DenseNet169 +from keras.src.applications.densenet import DenseNet201 as DenseNet201 +from keras.src.applications.efficientnet import EfficientNetB0 as EfficientNetB0 +from keras.src.applications.efficientnet import EfficientNetB1 as EfficientNetB1 +from keras.src.applications.efficientnet import EfficientNetB2 as EfficientNetB2 +from keras.src.applications.efficientnet import EfficientNetB3 as EfficientNetB3 +from keras.src.applications.efficientnet import EfficientNetB4 as EfficientNetB4 +from keras.src.applications.efficientnet import EfficientNetB5 as EfficientNetB5 +from keras.src.applications.efficientnet import EfficientNetB6 as EfficientNetB6 +from keras.src.applications.efficientnet import EfficientNetB7 as EfficientNetB7 +from keras.src.applications.efficientnet_v2 import ( + EfficientNetV2B0 as EfficientNetV2B0, +) +from keras.src.applications.efficientnet_v2 import ( + EfficientNetV2B1 as EfficientNetV2B1, +) +from keras.src.applications.efficientnet_v2 import ( + EfficientNetV2B2 as EfficientNetV2B2, +) +from keras.src.applications.efficientnet_v2 import ( + EfficientNetV2B3 as EfficientNetV2B3, +) +from keras.src.applications.efficientnet_v2 import ( + EfficientNetV2L as EfficientNetV2L, +) +from keras.src.applications.efficientnet_v2 import ( + EfficientNetV2M as EfficientNetV2M, +) +from keras.src.applications.efficientnet_v2 import ( + EfficientNetV2S as EfficientNetV2S, +) +from keras.src.applications.inception_resnet_v2 import ( + InceptionResNetV2 as InceptionResNetV2, +) +from keras.src.applications.inception_v3 import InceptionV3 as InceptionV3 +from keras.src.applications.mobilenet import MobileNet as MobileNet +from keras.src.applications.mobilenet_v2 import MobileNetV2 as MobileNetV2 +from keras.src.applications.mobilenet_v3 import ( + MobileNetV3Large as MobileNetV3Large, +) +from keras.src.applications.mobilenet_v3 import ( + MobileNetV3Small as MobileNetV3Small, +) +from keras.src.applications.nasnet import NASNetLarge as NASNetLarge +from keras.src.applications.nasnet import NASNetMobile as NASNetMobile +from keras.src.applications.resnet import ResNet50 as ResNet50 +from keras.src.applications.resnet import ResNet101 as ResNet101 +from keras.src.applications.resnet import ResNet152 as ResNet152 +from keras.src.applications.resnet_v2 import ResNet50V2 as ResNet50V2 +from keras.src.applications.resnet_v2 import ResNet101V2 as ResNet101V2 +from keras.src.applications.resnet_v2 import ResNet152V2 as ResNet152V2 +from keras.src.applications.vgg16 import VGG16 as VGG16 +from keras.src.applications.vgg19 import VGG19 as VGG19 +from keras.src.applications.xception import Xception as Xception diff --git a/keras/api/applications/convnext/__init__.py b/keras/api/applications/convnext/__init__.py new file mode 100644 index 000000000000..c6d7bb7117e8 --- /dev/null +++ b/keras/api/applications/convnext/__init__.py @@ -0,0 +1,15 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.applications.convnext import ConvNeXtBase as ConvNeXtBase +from keras.src.applications.convnext import ConvNeXtLarge as ConvNeXtLarge +from keras.src.applications.convnext import ConvNeXtSmall as ConvNeXtSmall +from keras.src.applications.convnext import ConvNeXtTiny as ConvNeXtTiny +from keras.src.applications.convnext import ConvNeXtXLarge as ConvNeXtXLarge +from keras.src.applications.convnext import ( + decode_predictions as decode_predictions, +) +from keras.src.applications.convnext import preprocess_input as preprocess_input diff --git a/keras/api/applications/densenet/__init__.py b/keras/api/applications/densenet/__init__.py new file mode 100644 index 000000000000..6d6a27101099 --- /dev/null +++ b/keras/api/applications/densenet/__init__.py @@ -0,0 +1,13 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.applications.densenet import DenseNet121 as DenseNet121 +from keras.src.applications.densenet import DenseNet169 as DenseNet169 +from keras.src.applications.densenet import DenseNet201 as DenseNet201 +from keras.src.applications.densenet import ( + decode_predictions as decode_predictions, +) +from keras.src.applications.densenet import preprocess_input as preprocess_input diff --git a/keras/api/applications/efficientnet/__init__.py b/keras/api/applications/efficientnet/__init__.py new file mode 100644 index 000000000000..16384b74e2b2 --- /dev/null +++ b/keras/api/applications/efficientnet/__init__.py @@ -0,0 +1,20 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.applications.efficientnet import EfficientNetB0 as EfficientNetB0 +from keras.src.applications.efficientnet import EfficientNetB1 as EfficientNetB1 +from keras.src.applications.efficientnet import EfficientNetB2 as EfficientNetB2 +from keras.src.applications.efficientnet import EfficientNetB3 as EfficientNetB3 +from keras.src.applications.efficientnet import EfficientNetB4 as EfficientNetB4 +from keras.src.applications.efficientnet import EfficientNetB5 as EfficientNetB5 +from keras.src.applications.efficientnet import EfficientNetB6 as EfficientNetB6 +from keras.src.applications.efficientnet import EfficientNetB7 as EfficientNetB7 +from keras.src.applications.efficientnet import ( + decode_predictions as decode_predictions, +) +from keras.src.applications.efficientnet import ( + preprocess_input as preprocess_input, +) diff --git a/keras/api/applications/efficientnet_v2/__init__.py b/keras/api/applications/efficientnet_v2/__init__.py new file mode 100644 index 000000000000..8d13352008b6 --- /dev/null +++ b/keras/api/applications/efficientnet_v2/__init__.py @@ -0,0 +1,33 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.applications.efficientnet_v2 import ( + EfficientNetV2B0 as EfficientNetV2B0, +) +from keras.src.applications.efficientnet_v2 import ( + EfficientNetV2B1 as EfficientNetV2B1, +) +from keras.src.applications.efficientnet_v2 import ( + EfficientNetV2B2 as EfficientNetV2B2, +) +from keras.src.applications.efficientnet_v2 import ( + EfficientNetV2B3 as EfficientNetV2B3, +) +from keras.src.applications.efficientnet_v2 import ( + EfficientNetV2L as EfficientNetV2L, +) +from keras.src.applications.efficientnet_v2 import ( + EfficientNetV2M as EfficientNetV2M, +) +from keras.src.applications.efficientnet_v2 import ( + EfficientNetV2S as EfficientNetV2S, +) +from keras.src.applications.efficientnet_v2 import ( + decode_predictions as decode_predictions, +) +from keras.src.applications.efficientnet_v2 import ( + preprocess_input as preprocess_input, +) diff --git a/keras/api/applications/imagenet_utils/__init__.py b/keras/api/applications/imagenet_utils/__init__.py new file mode 100644 index 000000000000..66804964efbe --- /dev/null +++ b/keras/api/applications/imagenet_utils/__init__.py @@ -0,0 +1,12 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.applications.imagenet_utils import ( + decode_predictions as decode_predictions, +) +from keras.src.applications.imagenet_utils import ( + preprocess_input as preprocess_input, +) diff --git a/keras/api/applications/inception_resnet_v2/__init__.py b/keras/api/applications/inception_resnet_v2/__init__.py new file mode 100644 index 000000000000..4cb545a39fe1 --- /dev/null +++ b/keras/api/applications/inception_resnet_v2/__init__.py @@ -0,0 +1,15 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.applications.inception_resnet_v2 import ( + InceptionResNetV2 as InceptionResNetV2, +) +from keras.src.applications.inception_resnet_v2 import ( + decode_predictions as decode_predictions, +) +from keras.src.applications.inception_resnet_v2 import ( + preprocess_input as preprocess_input, +) diff --git a/keras/api/applications/inception_v3/__init__.py b/keras/api/applications/inception_v3/__init__.py new file mode 100644 index 000000000000..a7db7bd80ce8 --- /dev/null +++ b/keras/api/applications/inception_v3/__init__.py @@ -0,0 +1,13 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.applications.inception_v3 import InceptionV3 as InceptionV3 +from keras.src.applications.inception_v3 import ( + decode_predictions as decode_predictions, +) +from keras.src.applications.inception_v3 import ( + preprocess_input as preprocess_input, +) diff --git a/keras/api/applications/mobilenet/__init__.py b/keras/api/applications/mobilenet/__init__.py new file mode 100644 index 000000000000..6e721019c42e --- /dev/null +++ b/keras/api/applications/mobilenet/__init__.py @@ -0,0 +1,13 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.applications.mobilenet import MobileNet as MobileNet +from keras.src.applications.mobilenet import ( + decode_predictions as decode_predictions, +) +from keras.src.applications.mobilenet import ( + preprocess_input as preprocess_input, +) diff --git a/keras/api/applications/mobilenet_v2/__init__.py b/keras/api/applications/mobilenet_v2/__init__.py new file mode 100644 index 000000000000..15ebaa3155a6 --- /dev/null +++ b/keras/api/applications/mobilenet_v2/__init__.py @@ -0,0 +1,13 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.applications.mobilenet_v2 import MobileNetV2 as MobileNetV2 +from keras.src.applications.mobilenet_v2 import ( + decode_predictions as decode_predictions, +) +from keras.src.applications.mobilenet_v2 import ( + preprocess_input as preprocess_input, +) diff --git a/keras/api/applications/mobilenet_v3/__init__.py b/keras/api/applications/mobilenet_v3/__init__.py new file mode 100644 index 000000000000..a5abb926247c --- /dev/null +++ b/keras/api/applications/mobilenet_v3/__init__.py @@ -0,0 +1,12 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.applications.mobilenet_v3 import ( + decode_predictions as decode_predictions, +) +from keras.src.applications.mobilenet_v3 import ( + preprocess_input as preprocess_input, +) diff --git a/keras/api/applications/nasnet/__init__.py b/keras/api/applications/nasnet/__init__.py new file mode 100644 index 000000000000..c831e135fbd6 --- /dev/null +++ b/keras/api/applications/nasnet/__init__.py @@ -0,0 +1,12 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.applications.nasnet import NASNetLarge as NASNetLarge +from keras.src.applications.nasnet import NASNetMobile as NASNetMobile +from keras.src.applications.nasnet import ( + decode_predictions as decode_predictions, +) +from keras.src.applications.nasnet import preprocess_input as preprocess_input diff --git a/keras/api/applications/resnet/__init__.py b/keras/api/applications/resnet/__init__.py new file mode 100644 index 000000000000..b8a25644e1d9 --- /dev/null +++ b/keras/api/applications/resnet/__init__.py @@ -0,0 +1,13 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.applications.resnet import ResNet50 as ResNet50 +from keras.src.applications.resnet import ResNet101 as ResNet101 +from keras.src.applications.resnet import ResNet152 as ResNet152 +from keras.src.applications.resnet import ( + decode_predictions as decode_predictions, +) +from keras.src.applications.resnet import preprocess_input as preprocess_input diff --git a/keras/api/applications/resnet50/__init__.py b/keras/api/applications/resnet50/__init__.py new file mode 100644 index 000000000000..6cff78c6749c --- /dev/null +++ b/keras/api/applications/resnet50/__init__.py @@ -0,0 +1,11 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.applications.resnet import ResNet50 as ResNet50 +from keras.src.applications.resnet import ( + decode_predictions as decode_predictions, +) +from keras.src.applications.resnet import preprocess_input as preprocess_input diff --git a/keras/api/applications/resnet_v2/__init__.py b/keras/api/applications/resnet_v2/__init__.py new file mode 100644 index 000000000000..7f92dd56f374 --- /dev/null +++ b/keras/api/applications/resnet_v2/__init__.py @@ -0,0 +1,15 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.applications.resnet_v2 import ResNet50V2 as ResNet50V2 +from keras.src.applications.resnet_v2 import ResNet101V2 as ResNet101V2 +from keras.src.applications.resnet_v2 import ResNet152V2 as ResNet152V2 +from keras.src.applications.resnet_v2 import ( + decode_predictions as decode_predictions, +) +from keras.src.applications.resnet_v2 import ( + preprocess_input as preprocess_input, +) diff --git a/keras/api/applications/vgg16/__init__.py b/keras/api/applications/vgg16/__init__.py new file mode 100644 index 000000000000..17fb30585d9a --- /dev/null +++ b/keras/api/applications/vgg16/__init__.py @@ -0,0 +1,11 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.applications.vgg16 import VGG16 as VGG16 +from keras.src.applications.vgg16 import ( + decode_predictions as decode_predictions, +) +from keras.src.applications.vgg16 import preprocess_input as preprocess_input diff --git a/keras/api/applications/vgg19/__init__.py b/keras/api/applications/vgg19/__init__.py new file mode 100644 index 000000000000..83f865b3876b --- /dev/null +++ b/keras/api/applications/vgg19/__init__.py @@ -0,0 +1,11 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.applications.vgg19 import VGG19 as VGG19 +from keras.src.applications.vgg19 import ( + decode_predictions as decode_predictions, +) +from keras.src.applications.vgg19 import preprocess_input as preprocess_input diff --git a/keras/api/applications/xception/__init__.py b/keras/api/applications/xception/__init__.py new file mode 100644 index 000000000000..09a5859aab4b --- /dev/null +++ b/keras/api/applications/xception/__init__.py @@ -0,0 +1,11 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.applications.xception import Xception as Xception +from keras.src.applications.xception import ( + decode_predictions as decode_predictions, +) +from keras.src.applications.xception import preprocess_input as preprocess_input diff --git a/keras/api/backend/__init__.py b/keras/api/backend/__init__.py new file mode 100644 index 000000000000..a2a50b9033a4 --- /dev/null +++ b/keras/api/backend/__init__.py @@ -0,0 +1,26 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.backend.common.dtypes import result_type as result_type +from keras.src.backend.common.global_state import clear_session as clear_session +from keras.src.backend.common.keras_tensor import ( + is_keras_tensor as is_keras_tensor, +) +from keras.src.backend.common.variables import is_float_dtype as is_float_dtype +from keras.src.backend.common.variables import is_int_dtype as is_int_dtype +from keras.src.backend.common.variables import ( + standardize_dtype as standardize_dtype, +) +from keras.src.backend.config import backend as backend +from keras.src.backend.config import epsilon as epsilon +from keras.src.backend.config import floatx as floatx +from keras.src.backend.config import image_data_format as image_data_format +from keras.src.backend.config import set_epsilon as set_epsilon +from keras.src.backend.config import set_floatx as set_floatx +from keras.src.backend.config import ( + set_image_data_format as set_image_data_format, +) +from keras.src.utils.naming import get_uid as get_uid diff --git a/keras/api/callbacks/__init__.py b/keras/api/callbacks/__init__.py new file mode 100644 index 000000000000..4e165cddb6a8 --- /dev/null +++ b/keras/api/callbacks/__init__.py @@ -0,0 +1,33 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.callbacks.backup_and_restore import ( + BackupAndRestore as BackupAndRestore, +) +from keras.src.callbacks.callback import Callback as Callback +from keras.src.callbacks.callback_list import CallbackList as CallbackList +from keras.src.callbacks.csv_logger import CSVLogger as CSVLogger +from keras.src.callbacks.early_stopping import EarlyStopping as EarlyStopping +from keras.src.callbacks.history import History as History +from keras.src.callbacks.lambda_callback import LambdaCallback as LambdaCallback +from keras.src.callbacks.learning_rate_scheduler import ( + LearningRateScheduler as LearningRateScheduler, +) +from keras.src.callbacks.model_checkpoint import ( + ModelCheckpoint as ModelCheckpoint, +) +from keras.src.callbacks.progbar_logger import ProgbarLogger as ProgbarLogger +from keras.src.callbacks.reduce_lr_on_plateau import ( + ReduceLROnPlateau as ReduceLROnPlateau, +) +from keras.src.callbacks.remote_monitor import RemoteMonitor as RemoteMonitor +from keras.src.callbacks.swap_ema_weights import ( + SwapEMAWeights as SwapEMAWeights, +) +from keras.src.callbacks.tensorboard import TensorBoard as TensorBoard +from keras.src.callbacks.terminate_on_nan import ( + TerminateOnNaN as TerminateOnNaN, +) diff --git a/keras/api/config/__init__.py b/keras/api/config/__init__.py new file mode 100644 index 000000000000..8cf3a1c30abd --- /dev/null +++ b/keras/api/config/__init__.py @@ -0,0 +1,57 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.backend.config import backend as backend +from keras.src.backend.config import ( + disable_flash_attention as disable_flash_attention, +) +from keras.src.backend.config import ( + enable_flash_attention as enable_flash_attention, +) +from keras.src.backend.config import epsilon as epsilon +from keras.src.backend.config import floatx as floatx +from keras.src.backend.config import image_data_format as image_data_format +from keras.src.backend.config import ( + is_flash_attention_enabled as is_flash_attention_enabled, +) +from keras.src.backend.config import is_nnx_enabled as is_nnx_enabled +from keras.src.backend.config import max_epochs as max_epochs +from keras.src.backend.config import max_steps_per_epoch as max_steps_per_epoch +from keras.src.backend.config import set_epsilon as set_epsilon +from keras.src.backend.config import set_floatx as set_floatx +from keras.src.backend.config import ( + set_image_data_format as set_image_data_format, +) +from keras.src.backend.config import set_max_epochs as set_max_epochs +from keras.src.backend.config import ( + set_max_steps_per_epoch as set_max_steps_per_epoch, +) +from keras.src.dtype_policies.dtype_policy import dtype_policy as dtype_policy +from keras.src.dtype_policies.dtype_policy import ( + set_dtype_policy as set_dtype_policy, +) +from keras.src.saving.serialization_lib import ( + enable_unsafe_deserialization as enable_unsafe_deserialization, +) +from keras.src.utils.backend_utils import set_backend as set_backend +from keras.src.utils.io_utils import ( + disable_interactive_logging as disable_interactive_logging, +) +from keras.src.utils.io_utils import ( + enable_interactive_logging as enable_interactive_logging, +) +from keras.src.utils.io_utils import ( + is_interactive_logging_enabled as is_interactive_logging_enabled, +) +from keras.src.utils.traceback_utils import ( + disable_traceback_filtering as disable_traceback_filtering, +) +from keras.src.utils.traceback_utils import ( + enable_traceback_filtering as enable_traceback_filtering, +) +from keras.src.utils.traceback_utils import ( + is_traceback_filtering_enabled as is_traceback_filtering_enabled, +) diff --git a/keras/api/constraints/__init__.py b/keras/api/constraints/__init__.py new file mode 100644 index 000000000000..47d73d44627f --- /dev/null +++ b/keras/api/constraints/__init__.py @@ -0,0 +1,18 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.constraints import deserialize as deserialize +from keras.src.constraints import get as get +from keras.src.constraints import serialize as serialize +from keras.src.constraints.constraints import Constraint as Constraint +from keras.src.constraints.constraints import MaxNorm as MaxNorm +from keras.src.constraints.constraints import MaxNorm as max_norm +from keras.src.constraints.constraints import MinMaxNorm as MinMaxNorm +from keras.src.constraints.constraints import MinMaxNorm as min_max_norm +from keras.src.constraints.constraints import NonNeg as NonNeg +from keras.src.constraints.constraints import NonNeg as non_neg +from keras.src.constraints.constraints import UnitNorm as UnitNorm +from keras.src.constraints.constraints import UnitNorm as unit_norm diff --git a/keras/api/datasets/__init__.py b/keras/api/datasets/__init__.py new file mode 100644 index 000000000000..f61e994a4bff --- /dev/null +++ b/keras/api/datasets/__init__.py @@ -0,0 +1,14 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.datasets import boston_housing as boston_housing +from keras.datasets import california_housing as california_housing +from keras.datasets import cifar10 as cifar10 +from keras.datasets import cifar100 as cifar100 +from keras.datasets import fashion_mnist as fashion_mnist +from keras.datasets import imdb as imdb +from keras.datasets import mnist as mnist +from keras.datasets import reuters as reuters diff --git a/keras/api/datasets/boston_housing/__init__.py b/keras/api/datasets/boston_housing/__init__.py new file mode 100644 index 000000000000..897f8516ca82 --- /dev/null +++ b/keras/api/datasets/boston_housing/__init__.py @@ -0,0 +1,7 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.datasets.boston_housing import load_data as load_data diff --git a/keras/api/datasets/california_housing/__init__.py b/keras/api/datasets/california_housing/__init__.py new file mode 100644 index 000000000000..602bf81ac2cd --- /dev/null +++ b/keras/api/datasets/california_housing/__init__.py @@ -0,0 +1,7 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.datasets.california_housing import load_data as load_data diff --git a/keras/api/datasets/cifar10/__init__.py b/keras/api/datasets/cifar10/__init__.py new file mode 100644 index 000000000000..f7aad7fd1a55 --- /dev/null +++ b/keras/api/datasets/cifar10/__init__.py @@ -0,0 +1,7 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.datasets.cifar10 import load_data as load_data diff --git a/keras/api/datasets/cifar100/__init__.py b/keras/api/datasets/cifar100/__init__.py new file mode 100644 index 000000000000..237fafab6fc6 --- /dev/null +++ b/keras/api/datasets/cifar100/__init__.py @@ -0,0 +1,7 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.datasets.cifar100 import load_data as load_data diff --git a/keras/api/datasets/fashion_mnist/__init__.py b/keras/api/datasets/fashion_mnist/__init__.py new file mode 100644 index 000000000000..317f0951a063 --- /dev/null +++ b/keras/api/datasets/fashion_mnist/__init__.py @@ -0,0 +1,7 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.datasets.fashion_mnist import load_data as load_data diff --git a/keras/api/datasets/imdb/__init__.py b/keras/api/datasets/imdb/__init__.py new file mode 100644 index 000000000000..66931a4a30eb --- /dev/null +++ b/keras/api/datasets/imdb/__init__.py @@ -0,0 +1,8 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.datasets.imdb import get_word_index as get_word_index +from keras.src.datasets.imdb import load_data as load_data diff --git a/keras/api/datasets/mnist/__init__.py b/keras/api/datasets/mnist/__init__.py new file mode 100644 index 000000000000..0fc59f334c50 --- /dev/null +++ b/keras/api/datasets/mnist/__init__.py @@ -0,0 +1,7 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.datasets.mnist import load_data as load_data diff --git a/keras/api/datasets/reuters/__init__.py b/keras/api/datasets/reuters/__init__.py new file mode 100644 index 000000000000..0b2af62d785b --- /dev/null +++ b/keras/api/datasets/reuters/__init__.py @@ -0,0 +1,9 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.datasets.reuters import get_label_names as get_label_names +from keras.src.datasets.reuters import get_word_index as get_word_index +from keras.src.datasets.reuters import load_data as load_data diff --git a/keras/api/distribution/__init__.py b/keras/api/distribution/__init__.py new file mode 100644 index 000000000000..66fed24c761d --- /dev/null +++ b/keras/api/distribution/__init__.py @@ -0,0 +1,22 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.distribution.distribution_lib import DataParallel as DataParallel +from keras.src.distribution.distribution_lib import DeviceMesh as DeviceMesh +from keras.src.distribution.distribution_lib import LayoutMap as LayoutMap +from keras.src.distribution.distribution_lib import ( + ModelParallel as ModelParallel, +) +from keras.src.distribution.distribution_lib import TensorLayout as TensorLayout +from keras.src.distribution.distribution_lib import ( + distribute_tensor as distribute_tensor, +) +from keras.src.distribution.distribution_lib import distribution as distribution +from keras.src.distribution.distribution_lib import initialize as initialize +from keras.src.distribution.distribution_lib import list_devices as list_devices +from keras.src.distribution.distribution_lib import ( + set_distribution as set_distribution, +) diff --git a/keras/api/dtype_policies/__init__.py b/keras/api/dtype_policies/__init__.py new file mode 100644 index 000000000000..04f947d157c3 --- /dev/null +++ b/keras/api/dtype_policies/__init__.py @@ -0,0 +1,25 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.dtype_policies import deserialize as deserialize +from keras.src.dtype_policies import get as get +from keras.src.dtype_policies import serialize as serialize +from keras.src.dtype_policies.dtype_policy import DTypePolicy as DTypePolicy +from keras.src.dtype_policies.dtype_policy import ( + FloatDTypePolicy as FloatDTypePolicy, +) +from keras.src.dtype_policies.dtype_policy import ( + GPTQDTypePolicy as GPTQDTypePolicy, +) +from keras.src.dtype_policies.dtype_policy import ( + QuantizedDTypePolicy as QuantizedDTypePolicy, +) +from keras.src.dtype_policies.dtype_policy import ( + QuantizedFloat8DTypePolicy as QuantizedFloat8DTypePolicy, +) +from keras.src.dtype_policies.dtype_policy_map import ( + DTypePolicyMap as DTypePolicyMap, +) diff --git a/keras/api/export/__init__.py b/keras/api/export/__init__.py new file mode 100644 index 000000000000..fc8e748defcc --- /dev/null +++ b/keras/api/export/__init__.py @@ -0,0 +1,7 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.export.saved_model import ExportArchive as ExportArchive diff --git a/keras/api/initializers/__init__.py b/keras/api/initializers/__init__.py new file mode 100644 index 000000000000..e88013d97315 --- /dev/null +++ b/keras/api/initializers/__init__.py @@ -0,0 +1,81 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.initializers import deserialize as deserialize +from keras.src.initializers import get as get +from keras.src.initializers import serialize as serialize +from keras.src.initializers.constant_initializers import STFT as STFT +from keras.src.initializers.constant_initializers import STFT as STFTInitializer +from keras.src.initializers.constant_initializers import STFT as stft +from keras.src.initializers.constant_initializers import Constant as Constant +from keras.src.initializers.constant_initializers import Constant as constant +from keras.src.initializers.constant_initializers import Identity as Identity +from keras.src.initializers.constant_initializers import ( + Identity as IdentityInitializer, +) +from keras.src.initializers.constant_initializers import Identity as identity +from keras.src.initializers.constant_initializers import Ones as Ones +from keras.src.initializers.constant_initializers import Ones as ones +from keras.src.initializers.constant_initializers import Zeros as Zeros +from keras.src.initializers.constant_initializers import Zeros as zeros +from keras.src.initializers.initializer import Initializer as Initializer +from keras.src.initializers.random_initializers import ( + GlorotNormal as GlorotNormal, +) +from keras.src.initializers.random_initializers import ( + GlorotNormal as glorot_normal, +) +from keras.src.initializers.random_initializers import ( + GlorotUniform as GlorotUniform, +) +from keras.src.initializers.random_initializers import ( + GlorotUniform as glorot_uniform, +) +from keras.src.initializers.random_initializers import HeNormal as HeNormal +from keras.src.initializers.random_initializers import HeNormal as he_normal +from keras.src.initializers.random_initializers import HeUniform as HeUniform +from keras.src.initializers.random_initializers import HeUniform as he_uniform +from keras.src.initializers.random_initializers import ( + LecunNormal as LecunNormal, +) +from keras.src.initializers.random_initializers import ( + LecunNormal as lecun_normal, +) +from keras.src.initializers.random_initializers import ( + LecunUniform as LecunUniform, +) +from keras.src.initializers.random_initializers import ( + LecunUniform as lecun_uniform, +) +from keras.src.initializers.random_initializers import Orthogonal as Orthogonal +from keras.src.initializers.random_initializers import ( + Orthogonal as OrthogonalInitializer, +) +from keras.src.initializers.random_initializers import Orthogonal as orthogonal +from keras.src.initializers.random_initializers import ( + RandomNormal as RandomNormal, +) +from keras.src.initializers.random_initializers import ( + RandomNormal as random_normal, +) +from keras.src.initializers.random_initializers import ( + RandomUniform as RandomUniform, +) +from keras.src.initializers.random_initializers import ( + RandomUniform as random_uniform, +) +from keras.src.initializers.random_initializers import ( + TruncatedNormal as TruncatedNormal, +) +from keras.src.initializers.random_initializers import ( + TruncatedNormal as truncated_normal, +) +from keras.src.initializers.random_initializers import ( + VarianceScaling as VarianceScaling, +) +from keras.src.initializers.random_initializers import ( + VarianceScaling as variance_scaling, +) diff --git a/keras/api/layers/__init__.py b/keras/api/layers/__init__.py new file mode 100644 index 000000000000..82550fd33462 --- /dev/null +++ b/keras/api/layers/__init__.py @@ -0,0 +1,361 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.export.tfsm_layer import TFSMLayer as TFSMLayer +from keras.src.layers import deserialize as deserialize +from keras.src.layers import serialize as serialize +from keras.src.layers.activations.activation import Activation as Activation +from keras.src.layers.activations.elu import ELU as ELU +from keras.src.layers.activations.leaky_relu import LeakyReLU as LeakyReLU +from keras.src.layers.activations.prelu import PReLU as PReLU +from keras.src.layers.activations.relu import ReLU as ReLU +from keras.src.layers.activations.softmax import Softmax as Softmax +from keras.src.layers.attention.additive_attention import ( + AdditiveAttention as AdditiveAttention, +) +from keras.src.layers.attention.attention import Attention as Attention +from keras.src.layers.attention.grouped_query_attention import ( + GroupedQueryAttention as GroupQueryAttention, +) +from keras.src.layers.attention.multi_head_attention import ( + MultiHeadAttention as MultiHeadAttention, +) +from keras.src.layers.convolutional.conv1d import Conv1D as Conv1D +from keras.src.layers.convolutional.conv1d import Conv1D as Convolution1D +from keras.src.layers.convolutional.conv1d_transpose import ( + Conv1DTranspose as Conv1DTranspose, +) +from keras.src.layers.convolutional.conv1d_transpose import ( + Conv1DTranspose as Convolution1DTranspose, +) +from keras.src.layers.convolutional.conv2d import Conv2D as Conv2D +from keras.src.layers.convolutional.conv2d import Conv2D as Convolution2D +from keras.src.layers.convolutional.conv2d_transpose import ( + Conv2DTranspose as Conv2DTranspose, +) +from keras.src.layers.convolutional.conv2d_transpose import ( + Conv2DTranspose as Convolution2DTranspose, +) +from keras.src.layers.convolutional.conv3d import Conv3D as Conv3D +from keras.src.layers.convolutional.conv3d import Conv3D as Convolution3D +from keras.src.layers.convolutional.conv3d_transpose import ( + Conv3DTranspose as Conv3DTranspose, +) +from keras.src.layers.convolutional.conv3d_transpose import ( + Conv3DTranspose as Convolution3DTranspose, +) +from keras.src.layers.convolutional.depthwise_conv1d import ( + DepthwiseConv1D as DepthwiseConv1D, +) +from keras.src.layers.convolutional.depthwise_conv2d import ( + DepthwiseConv2D as DepthwiseConv2D, +) +from keras.src.layers.convolutional.separable_conv1d import ( + SeparableConv1D as SeparableConv1D, +) +from keras.src.layers.convolutional.separable_conv1d import ( + SeparableConv1D as SeparableConvolution1D, +) +from keras.src.layers.convolutional.separable_conv2d import ( + SeparableConv2D as SeparableConv2D, +) +from keras.src.layers.convolutional.separable_conv2d import ( + SeparableConv2D as SeparableConvolution2D, +) +from keras.src.layers.core.dense import Dense as Dense +from keras.src.layers.core.einsum_dense import EinsumDense as EinsumDense +from keras.src.layers.core.embedding import Embedding as Embedding +from keras.src.layers.core.identity import Identity as Identity +from keras.src.layers.core.input_layer import Input as Input +from keras.src.layers.core.input_layer import InputLayer as InputLayer +from keras.src.layers.core.lambda_layer import Lambda as Lambda +from keras.src.layers.core.masking import Masking as Masking +from keras.src.layers.core.wrapper import Wrapper as Wrapper +from keras.src.layers.input_spec import InputSpec as InputSpec +from keras.src.layers.layer import Layer as Layer +from keras.src.layers.merging.add import Add as Add +from keras.src.layers.merging.add import add as add +from keras.src.layers.merging.average import Average as Average +from keras.src.layers.merging.average import average as average +from keras.src.layers.merging.concatenate import Concatenate as Concatenate +from keras.src.layers.merging.concatenate import concatenate as concatenate +from keras.src.layers.merging.dot import Dot as Dot +from keras.src.layers.merging.dot import dot as dot +from keras.src.layers.merging.maximum import Maximum as Maximum +from keras.src.layers.merging.maximum import maximum as maximum +from keras.src.layers.merging.minimum import Minimum as Minimum +from keras.src.layers.merging.minimum import minimum as minimum +from keras.src.layers.merging.multiply import Multiply as Multiply +from keras.src.layers.merging.multiply import multiply as multiply +from keras.src.layers.merging.subtract import Subtract as Subtract +from keras.src.layers.merging.subtract import subtract as subtract +from keras.src.layers.normalization.batch_normalization import ( + BatchNormalization as BatchNormalization, +) +from keras.src.layers.normalization.group_normalization import ( + GroupNormalization as GroupNormalization, +) +from keras.src.layers.normalization.layer_normalization import ( + LayerNormalization as LayerNormalization, +) +from keras.src.layers.normalization.rms_normalization import ( + RMSNormalization as RMSNormalization, +) +from keras.src.layers.normalization.spectral_normalization import ( + SpectralNormalization as SpectralNormalization, +) +from keras.src.layers.normalization.unit_normalization import ( + UnitNormalization as UnitNormalization, +) +from keras.src.layers.pooling.average_pooling1d import ( + AveragePooling1D as AveragePooling1D, +) +from keras.src.layers.pooling.average_pooling1d import ( + AveragePooling1D as AvgPool1D, +) +from keras.src.layers.pooling.average_pooling2d import ( + AveragePooling2D as AveragePooling2D, +) +from keras.src.layers.pooling.average_pooling2d import ( + AveragePooling2D as AvgPool2D, +) +from keras.src.layers.pooling.average_pooling3d import ( + AveragePooling3D as AveragePooling3D, +) +from keras.src.layers.pooling.average_pooling3d import ( + AveragePooling3D as AvgPool3D, +) +from keras.src.layers.pooling.global_average_pooling1d import ( + GlobalAveragePooling1D as GlobalAveragePooling1D, +) +from keras.src.layers.pooling.global_average_pooling1d import ( + GlobalAveragePooling1D as GlobalAvgPool1D, +) +from keras.src.layers.pooling.global_average_pooling2d import ( + GlobalAveragePooling2D as GlobalAveragePooling2D, +) +from keras.src.layers.pooling.global_average_pooling2d import ( + GlobalAveragePooling2D as GlobalAvgPool2D, +) +from keras.src.layers.pooling.global_average_pooling3d import ( + GlobalAveragePooling3D as GlobalAveragePooling3D, +) +from keras.src.layers.pooling.global_average_pooling3d import ( + GlobalAveragePooling3D as GlobalAvgPool3D, +) +from keras.src.layers.pooling.global_max_pooling1d import ( + GlobalMaxPooling1D as GlobalMaxPool1D, +) +from keras.src.layers.pooling.global_max_pooling1d import ( + GlobalMaxPooling1D as GlobalMaxPooling1D, +) +from keras.src.layers.pooling.global_max_pooling2d import ( + GlobalMaxPooling2D as GlobalMaxPool2D, +) +from keras.src.layers.pooling.global_max_pooling2d import ( + GlobalMaxPooling2D as GlobalMaxPooling2D, +) +from keras.src.layers.pooling.global_max_pooling3d import ( + GlobalMaxPooling3D as GlobalMaxPool3D, +) +from keras.src.layers.pooling.global_max_pooling3d import ( + GlobalMaxPooling3D as GlobalMaxPooling3D, +) +from keras.src.layers.pooling.max_pooling1d import MaxPooling1D as MaxPool1D +from keras.src.layers.pooling.max_pooling1d import MaxPooling1D as MaxPooling1D +from keras.src.layers.pooling.max_pooling2d import MaxPooling2D as MaxPool2D +from keras.src.layers.pooling.max_pooling2d import MaxPooling2D as MaxPooling2D +from keras.src.layers.pooling.max_pooling3d import MaxPooling3D as MaxPool3D +from keras.src.layers.pooling.max_pooling3d import MaxPooling3D as MaxPooling3D +from keras.src.layers.preprocessing.category_encoding import ( + CategoryEncoding as CategoryEncoding, +) +from keras.src.layers.preprocessing.discretization import ( + Discretization as Discretization, +) +from keras.src.layers.preprocessing.hashed_crossing import ( + HashedCrossing as HashedCrossing, +) +from keras.src.layers.preprocessing.hashing import Hashing as Hashing +from keras.src.layers.preprocessing.image_preprocessing.aug_mix import ( + AugMix as AugMix, +) +from keras.src.layers.preprocessing.image_preprocessing.auto_contrast import ( + AutoContrast as AutoContrast, +) +from keras.src.layers.preprocessing.image_preprocessing.center_crop import ( + CenterCrop as CenterCrop, +) +from keras.src.layers.preprocessing.image_preprocessing.cut_mix import ( + CutMix as CutMix, +) +from keras.src.layers.preprocessing.image_preprocessing.equalization import ( + Equalization as Equalization, +) +from keras.src.layers.preprocessing.image_preprocessing.max_num_bounding_box import ( + MaxNumBoundingBoxes as MaxNumBoundingBoxes, +) +from keras.src.layers.preprocessing.image_preprocessing.mix_up import ( + MixUp as MixUp, +) +from keras.src.layers.preprocessing.image_preprocessing.rand_augment import ( + RandAugment as RandAugment, +) +from keras.src.layers.preprocessing.image_preprocessing.random_brightness import ( + RandomBrightness as RandomBrightness, +) +from keras.src.layers.preprocessing.image_preprocessing.random_color_degeneration import ( + RandomColorDegeneration as RandomColorDegeneration, +) +from keras.src.layers.preprocessing.image_preprocessing.random_color_jitter import ( + RandomColorJitter as RandomColorJitter, +) +from keras.src.layers.preprocessing.image_preprocessing.random_contrast import ( + RandomContrast as RandomContrast, +) +from keras.src.layers.preprocessing.image_preprocessing.random_crop import ( + RandomCrop as RandomCrop, +) +from keras.src.layers.preprocessing.image_preprocessing.random_elastic_transform import ( + RandomElasticTransform as RandomElasticTransform, +) +from keras.src.layers.preprocessing.image_preprocessing.random_erasing import ( + RandomErasing as RandomErasing, +) +from keras.src.layers.preprocessing.image_preprocessing.random_flip import ( + RandomFlip as RandomFlip, +) +from keras.src.layers.preprocessing.image_preprocessing.random_gaussian_blur import ( + RandomGaussianBlur as RandomGaussianBlur, +) +from keras.src.layers.preprocessing.image_preprocessing.random_grayscale import ( + RandomGrayscale as RandomGrayscale, +) +from keras.src.layers.preprocessing.image_preprocessing.random_hue import ( + RandomHue as RandomHue, +) +from keras.src.layers.preprocessing.image_preprocessing.random_invert import ( + RandomInvert as RandomInvert, +) +from keras.src.layers.preprocessing.image_preprocessing.random_perspective import ( + RandomPerspective as RandomPerspective, +) +from keras.src.layers.preprocessing.image_preprocessing.random_posterization import ( + RandomPosterization as RandomPosterization, +) +from keras.src.layers.preprocessing.image_preprocessing.random_rotation import ( + RandomRotation as RandomRotation, +) +from keras.src.layers.preprocessing.image_preprocessing.random_saturation import ( + RandomSaturation as RandomSaturation, +) +from keras.src.layers.preprocessing.image_preprocessing.random_sharpness import ( + RandomSharpness as RandomSharpness, +) +from keras.src.layers.preprocessing.image_preprocessing.random_shear import ( + RandomShear as RandomShear, +) +from keras.src.layers.preprocessing.image_preprocessing.random_translation import ( + RandomTranslation as RandomTranslation, +) +from keras.src.layers.preprocessing.image_preprocessing.random_zoom import ( + RandomZoom as RandomZoom, +) +from keras.src.layers.preprocessing.image_preprocessing.resizing import ( + Resizing as Resizing, +) +from keras.src.layers.preprocessing.image_preprocessing.solarization import ( + Solarization as Solarization, +) +from keras.src.layers.preprocessing.integer_lookup import ( + IntegerLookup as IntegerLookup, +) +from keras.src.layers.preprocessing.mel_spectrogram import ( + MelSpectrogram as MelSpectrogram, +) +from keras.src.layers.preprocessing.normalization import ( + Normalization as Normalization, +) +from keras.src.layers.preprocessing.pipeline import Pipeline as Pipeline +from keras.src.layers.preprocessing.rescaling import Rescaling as Rescaling +from keras.src.layers.preprocessing.stft_spectrogram import ( + STFTSpectrogram as STFTSpectrogram, +) +from keras.src.layers.preprocessing.string_lookup import ( + StringLookup as StringLookup, +) +from keras.src.layers.preprocessing.text_vectorization import ( + TextVectorization as TextVectorization, +) +from keras.src.layers.regularization.activity_regularization import ( + ActivityRegularization as ActivityRegularization, +) +from keras.src.layers.regularization.alpha_dropout import ( + AlphaDropout as AlphaDropout, +) +from keras.src.layers.regularization.dropout import Dropout as Dropout +from keras.src.layers.regularization.gaussian_dropout import ( + GaussianDropout as GaussianDropout, +) +from keras.src.layers.regularization.gaussian_noise import ( + GaussianNoise as GaussianNoise, +) +from keras.src.layers.regularization.spatial_dropout import ( + SpatialDropout1D as SpatialDropout1D, +) +from keras.src.layers.regularization.spatial_dropout import ( + SpatialDropout2D as SpatialDropout2D, +) +from keras.src.layers.regularization.spatial_dropout import ( + SpatialDropout3D as SpatialDropout3D, +) +from keras.src.layers.reshaping.cropping1d import Cropping1D as Cropping1D +from keras.src.layers.reshaping.cropping2d import Cropping2D as Cropping2D +from keras.src.layers.reshaping.cropping3d import Cropping3D as Cropping3D +from keras.src.layers.reshaping.flatten import Flatten as Flatten +from keras.src.layers.reshaping.permute import Permute as Permute +from keras.src.layers.reshaping.repeat_vector import ( + RepeatVector as RepeatVector, +) +from keras.src.layers.reshaping.reshape import Reshape as Reshape +from keras.src.layers.reshaping.up_sampling1d import ( + UpSampling1D as UpSampling1D, +) +from keras.src.layers.reshaping.up_sampling2d import ( + UpSampling2D as UpSampling2D, +) +from keras.src.layers.reshaping.up_sampling3d import ( + UpSampling3D as UpSampling3D, +) +from keras.src.layers.reshaping.zero_padding1d import ( + ZeroPadding1D as ZeroPadding1D, +) +from keras.src.layers.reshaping.zero_padding2d import ( + ZeroPadding2D as ZeroPadding2D, +) +from keras.src.layers.reshaping.zero_padding3d import ( + ZeroPadding3D as ZeroPadding3D, +) +from keras.src.layers.rnn.bidirectional import Bidirectional as Bidirectional +from keras.src.layers.rnn.conv_lstm1d import ConvLSTM1D as ConvLSTM1D +from keras.src.layers.rnn.conv_lstm2d import ConvLSTM2D as ConvLSTM2D +from keras.src.layers.rnn.conv_lstm3d import ConvLSTM3D as ConvLSTM3D +from keras.src.layers.rnn.gru import GRU as GRU +from keras.src.layers.rnn.gru import GRUCell as GRUCell +from keras.src.layers.rnn.lstm import LSTM as LSTM +from keras.src.layers.rnn.lstm import LSTMCell as LSTMCell +from keras.src.layers.rnn.rnn import RNN as RNN +from keras.src.layers.rnn.simple_rnn import SimpleRNN as SimpleRNN +from keras.src.layers.rnn.simple_rnn import SimpleRNNCell as SimpleRNNCell +from keras.src.layers.rnn.stacked_rnn_cells import ( + StackedRNNCells as StackedRNNCells, +) +from keras.src.layers.rnn.time_distributed import ( + TimeDistributed as TimeDistributed, +) +from keras.src.utils.jax_layer import FlaxLayer as FlaxLayer +from keras.src.utils.jax_layer import JaxLayer as JaxLayer +from keras.src.utils.torch_utils import TorchModuleWrapper as TorchModuleWrapper diff --git a/keras/api/legacy/__init__.py b/keras/api/legacy/__init__.py new file mode 100644 index 000000000000..e71ba4312ee0 --- /dev/null +++ b/keras/api/legacy/__init__.py @@ -0,0 +1,7 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.legacy import saving as saving diff --git a/keras/api/legacy/saving/__init__.py b/keras/api/legacy/saving/__init__.py new file mode 100644 index 000000000000..1e3aa0ee9d5c --- /dev/null +++ b/keras/api/legacy/saving/__init__.py @@ -0,0 +1,12 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.legacy.saving.serialization import ( + deserialize_keras_object as deserialize_keras_object, +) +from keras.src.legacy.saving.serialization import ( + serialize_keras_object as serialize_keras_object, +) diff --git a/keras/api/losses/__init__.py b/keras/api/losses/__init__.py new file mode 100644 index 000000000000..60414fe301d0 --- /dev/null +++ b/keras/api/losses/__init__.py @@ -0,0 +1,82 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.losses import deserialize as deserialize +from keras.src.losses import get as get +from keras.src.losses import serialize as serialize +from keras.src.losses.loss import Loss as Loss +from keras.src.losses.losses import CTC as CTC +from keras.src.losses.losses import BinaryCrossentropy as BinaryCrossentropy +from keras.src.losses.losses import ( + BinaryFocalCrossentropy as BinaryFocalCrossentropy, +) +from keras.src.losses.losses import ( + CategoricalCrossentropy as CategoricalCrossentropy, +) +from keras.src.losses.losses import ( + CategoricalFocalCrossentropy as CategoricalFocalCrossentropy, +) +from keras.src.losses.losses import ( + CategoricalGeneralizedCrossEntropy as CategoricalGeneralizedCrossEntropy, +) +from keras.src.losses.losses import CategoricalHinge as CategoricalHinge +from keras.src.losses.losses import Circle as Circle +from keras.src.losses.losses import CosineSimilarity as CosineSimilarity +from keras.src.losses.losses import Dice as Dice +from keras.src.losses.losses import Hinge as Hinge +from keras.src.losses.losses import Huber as Huber +from keras.src.losses.losses import KLDivergence as KLDivergence +from keras.src.losses.losses import LogCosh as LogCosh +from keras.src.losses.losses import MeanAbsoluteError as MeanAbsoluteError +from keras.src.losses.losses import ( + MeanAbsolutePercentageError as MeanAbsolutePercentageError, +) +from keras.src.losses.losses import MeanSquaredError as MeanSquaredError +from keras.src.losses.losses import ( + MeanSquaredLogarithmicError as MeanSquaredLogarithmicError, +) +from keras.src.losses.losses import Poisson as Poisson +from keras.src.losses.losses import ( + SparseCategoricalCrossentropy as SparseCategoricalCrossentropy, +) +from keras.src.losses.losses import SquaredHinge as SquaredHinge +from keras.src.losses.losses import Tversky as Tversky +from keras.src.losses.losses import binary_crossentropy as binary_crossentropy +from keras.src.losses.losses import ( + binary_focal_crossentropy as binary_focal_crossentropy, +) +from keras.src.losses.losses import ( + categorical_crossentropy as categorical_crossentropy, +) +from keras.src.losses.losses import ( + categorical_focal_crossentropy as categorical_focal_crossentropy, +) +from keras.src.losses.losses import ( + categorical_generalized_cross_entropy as categorical_generalized_cross_entropy, +) +from keras.src.losses.losses import categorical_hinge as categorical_hinge +from keras.src.losses.losses import circle as circle +from keras.src.losses.losses import cosine_similarity as cosine_similarity +from keras.src.losses.losses import ctc as ctc +from keras.src.losses.losses import dice as dice +from keras.src.losses.losses import hinge as hinge +from keras.src.losses.losses import huber as huber +from keras.src.losses.losses import kl_divergence as kl_divergence +from keras.src.losses.losses import log_cosh as log_cosh +from keras.src.losses.losses import mean_absolute_error as mean_absolute_error +from keras.src.losses.losses import ( + mean_absolute_percentage_error as mean_absolute_percentage_error, +) +from keras.src.losses.losses import mean_squared_error as mean_squared_error +from keras.src.losses.losses import ( + mean_squared_logarithmic_error as mean_squared_logarithmic_error, +) +from keras.src.losses.losses import poisson as poisson +from keras.src.losses.losses import ( + sparse_categorical_crossentropy as sparse_categorical_crossentropy, +) +from keras.src.losses.losses import squared_hinge as squared_hinge +from keras.src.losses.losses import tversky as tversky diff --git a/keras/api/metrics/__init__.py b/keras/api/metrics/__init__.py new file mode 100644 index 000000000000..e7ba55dbcb0c --- /dev/null +++ b/keras/api/metrics/__init__.py @@ -0,0 +1,144 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.losses.losses import binary_crossentropy as binary_crossentropy +from keras.src.losses.losses import ( + binary_focal_crossentropy as binary_focal_crossentropy, +) +from keras.src.losses.losses import ( + categorical_crossentropy as categorical_crossentropy, +) +from keras.src.losses.losses import ( + categorical_focal_crossentropy as categorical_focal_crossentropy, +) +from keras.src.losses.losses import categorical_hinge as categorical_hinge +from keras.src.losses.losses import hinge as hinge +from keras.src.losses.losses import huber as huber +from keras.src.losses.losses import kl_divergence as kl_divergence +from keras.src.losses.losses import log_cosh as log_cosh +from keras.src.losses.losses import mean_absolute_error as mean_absolute_error +from keras.src.losses.losses import ( + mean_absolute_percentage_error as mean_absolute_percentage_error, +) +from keras.src.losses.losses import mean_squared_error as mean_squared_error +from keras.src.losses.losses import ( + mean_squared_logarithmic_error as mean_squared_logarithmic_error, +) +from keras.src.losses.losses import poisson as poisson +from keras.src.losses.losses import ( + sparse_categorical_crossentropy as sparse_categorical_crossentropy, +) +from keras.src.losses.losses import squared_hinge as squared_hinge +from keras.src.metrics import deserialize as deserialize +from keras.src.metrics import get as get +from keras.src.metrics import serialize as serialize +from keras.src.metrics.accuracy_metrics import Accuracy as Accuracy +from keras.src.metrics.accuracy_metrics import BinaryAccuracy as BinaryAccuracy +from keras.src.metrics.accuracy_metrics import ( + CategoricalAccuracy as CategoricalAccuracy, +) +from keras.src.metrics.accuracy_metrics import ( + SparseCategoricalAccuracy as SparseCategoricalAccuracy, +) +from keras.src.metrics.accuracy_metrics import ( + SparseTopKCategoricalAccuracy as SparseTopKCategoricalAccuracy, +) +from keras.src.metrics.accuracy_metrics import ( + TopKCategoricalAccuracy as TopKCategoricalAccuracy, +) +from keras.src.metrics.accuracy_metrics import ( + binary_accuracy as binary_accuracy, +) +from keras.src.metrics.accuracy_metrics import ( + categorical_accuracy as categorical_accuracy, +) +from keras.src.metrics.accuracy_metrics import ( + sparse_categorical_accuracy as sparse_categorical_accuracy, +) +from keras.src.metrics.accuracy_metrics import ( + sparse_top_k_categorical_accuracy as sparse_top_k_categorical_accuracy, +) +from keras.src.metrics.accuracy_metrics import ( + top_k_categorical_accuracy as top_k_categorical_accuracy, +) +from keras.src.metrics.confusion_metrics import AUC as AUC +from keras.src.metrics.confusion_metrics import FalseNegatives as FalseNegatives +from keras.src.metrics.confusion_metrics import FalsePositives as FalsePositives +from keras.src.metrics.confusion_metrics import Precision as Precision +from keras.src.metrics.confusion_metrics import ( + PrecisionAtRecall as PrecisionAtRecall, +) +from keras.src.metrics.confusion_metrics import Recall as Recall +from keras.src.metrics.confusion_metrics import ( + RecallAtPrecision as RecallAtPrecision, +) +from keras.src.metrics.confusion_metrics import ( + SensitivityAtSpecificity as SensitivityAtSpecificity, +) +from keras.src.metrics.confusion_metrics import ( + SpecificityAtSensitivity as SpecificityAtSensitivity, +) +from keras.src.metrics.confusion_metrics import TrueNegatives as TrueNegatives +from keras.src.metrics.confusion_metrics import TruePositives as TruePositives +from keras.src.metrics.correlation_metrics import ( + ConcordanceCorrelation as ConcordanceCorrelation, +) +from keras.src.metrics.correlation_metrics import ( + PearsonCorrelation as PearsonCorrelation, +) +from keras.src.metrics.correlation_metrics import ( + concordance_correlation as concordance_correlation, +) +from keras.src.metrics.correlation_metrics import ( + pearson_correlation as pearson_correlation, +) +from keras.src.metrics.f_score_metrics import F1Score as F1Score +from keras.src.metrics.f_score_metrics import FBetaScore as FBetaScore +from keras.src.metrics.hinge_metrics import CategoricalHinge as CategoricalHinge +from keras.src.metrics.hinge_metrics import Hinge as Hinge +from keras.src.metrics.hinge_metrics import SquaredHinge as SquaredHinge +from keras.src.metrics.iou_metrics import BinaryIoU as BinaryIoU +from keras.src.metrics.iou_metrics import IoU as IoU +from keras.src.metrics.iou_metrics import MeanIoU as MeanIoU +from keras.src.metrics.iou_metrics import OneHotIoU as OneHotIoU +from keras.src.metrics.iou_metrics import OneHotMeanIoU as OneHotMeanIoU +from keras.src.metrics.metric import Metric as Metric +from keras.src.metrics.probabilistic_metrics import ( + BinaryCrossentropy as BinaryCrossentropy, +) +from keras.src.metrics.probabilistic_metrics import ( + CategoricalCrossentropy as CategoricalCrossentropy, +) +from keras.src.metrics.probabilistic_metrics import KLDivergence as KLDivergence +from keras.src.metrics.probabilistic_metrics import Poisson as Poisson +from keras.src.metrics.probabilistic_metrics import ( + SparseCategoricalCrossentropy as SparseCategoricalCrossentropy, +) +from keras.src.metrics.reduction_metrics import Mean as Mean +from keras.src.metrics.reduction_metrics import ( + MeanMetricWrapper as MeanMetricWrapper, +) +from keras.src.metrics.reduction_metrics import Sum as Sum +from keras.src.metrics.regression_metrics import ( + CosineSimilarity as CosineSimilarity, +) +from keras.src.metrics.regression_metrics import LogCoshError as LogCoshError +from keras.src.metrics.regression_metrics import ( + MeanAbsoluteError as MeanAbsoluteError, +) +from keras.src.metrics.regression_metrics import ( + MeanAbsolutePercentageError as MeanAbsolutePercentageError, +) +from keras.src.metrics.regression_metrics import ( + MeanSquaredError as MeanSquaredError, +) +from keras.src.metrics.regression_metrics import ( + MeanSquaredLogarithmicError as MeanSquaredLogarithmicError, +) +from keras.src.metrics.regression_metrics import R2Score as R2Score +from keras.src.metrics.regression_metrics import ( + RootMeanSquaredError as RootMeanSquaredError, +) diff --git a/keras/api/mixed_precision/__init__.py b/keras/api/mixed_precision/__init__.py new file mode 100644 index 000000000000..9555b8639385 --- /dev/null +++ b/keras/api/mixed_precision/__init__.py @@ -0,0 +1,19 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.dtype_policies.dtype_policy import DTypePolicy as DTypePolicy +from keras.src.dtype_policies.dtype_policy import DTypePolicy as Policy +from keras.src.dtype_policies.dtype_policy import dtype_policy as dtype_policy +from keras.src.dtype_policies.dtype_policy import dtype_policy as global_policy +from keras.src.dtype_policies.dtype_policy import ( + set_dtype_policy as set_dtype_policy, +) +from keras.src.dtype_policies.dtype_policy import ( + set_dtype_policy as set_global_policy, +) +from keras.src.optimizers.loss_scale_optimizer import ( + LossScaleOptimizer as LossScaleOptimizer, +) diff --git a/keras/api/models/__init__.py b/keras/api/models/__init__.py new file mode 100644 index 000000000000..f9dd57556d53 --- /dev/null +++ b/keras/api/models/__init__.py @@ -0,0 +1,12 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.models.cloning import clone_model as clone_model +from keras.src.models.model import Model as Model +from keras.src.models.model import model_from_json as model_from_json +from keras.src.models.sequential import Sequential as Sequential +from keras.src.saving.saving_api import load_model as load_model +from keras.src.saving.saving_api import save_model as save_model diff --git a/keras/api/ops/__init__.py b/keras/api/ops/__init__.py new file mode 100644 index 000000000000..2194c975b89f --- /dev/null +++ b/keras/api/ops/__init__.py @@ -0,0 +1,300 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.ops import image as image +from keras.ops import linalg as linalg +from keras.ops import nn as nn +from keras.ops import numpy as numpy +from keras.src.ops.core import associative_scan as associative_scan +from keras.src.ops.core import cast as cast +from keras.src.ops.core import cond as cond +from keras.src.ops.core import convert_to_numpy as convert_to_numpy +from keras.src.ops.core import convert_to_tensor as convert_to_tensor +from keras.src.ops.core import custom_gradient as custom_gradient +from keras.src.ops.core import dtype as dtype +from keras.src.ops.core import fori_loop as fori_loop +from keras.src.ops.core import is_tensor as is_tensor +from keras.src.ops.core import map as map +from keras.src.ops.core import saturate_cast as saturate_cast +from keras.src.ops.core import scan as scan +from keras.src.ops.core import scatter as scatter +from keras.src.ops.core import scatter_update as scatter_update +from keras.src.ops.core import shape as shape +from keras.src.ops.core import slice as slice +from keras.src.ops.core import slice_update as slice_update +from keras.src.ops.core import stop_gradient as stop_gradient +from keras.src.ops.core import switch as switch +from keras.src.ops.core import unstack as unstack +from keras.src.ops.core import vectorized_map as vectorized_map +from keras.src.ops.core import while_loop as while_loop +from keras.src.ops.einops import rearrange as rearrange +from keras.src.ops.linalg import cholesky as cholesky +from keras.src.ops.linalg import cholesky_inverse as cholesky_inverse +from keras.src.ops.linalg import det as det +from keras.src.ops.linalg import eig as eig +from keras.src.ops.linalg import eigh as eigh +from keras.src.ops.linalg import inv as inv +from keras.src.ops.linalg import jvp as jvp +from keras.src.ops.linalg import lstsq as lstsq +from keras.src.ops.linalg import lu_factor as lu_factor +from keras.src.ops.linalg import norm as norm +from keras.src.ops.linalg import qr as qr +from keras.src.ops.linalg import solve as solve +from keras.src.ops.linalg import solve_triangular as solve_triangular +from keras.src.ops.linalg import svd as svd +from keras.src.ops.math import erf as erf +from keras.src.ops.math import erfinv as erfinv +from keras.src.ops.math import extract_sequences as extract_sequences +from keras.src.ops.math import fft as fft +from keras.src.ops.math import fft2 as fft2 +from keras.src.ops.math import ifft2 as ifft2 +from keras.src.ops.math import in_top_k as in_top_k +from keras.src.ops.math import irfft as irfft +from keras.src.ops.math import istft as istft +from keras.src.ops.math import logdet as logdet +from keras.src.ops.math import logsumexp as logsumexp +from keras.src.ops.math import rfft as rfft +from keras.src.ops.math import rsqrt as rsqrt +from keras.src.ops.math import segment_max as segment_max +from keras.src.ops.math import segment_sum as segment_sum +from keras.src.ops.math import stft as stft +from keras.src.ops.math import top_k as top_k +from keras.src.ops.math import view_as_complex as view_as_complex +from keras.src.ops.math import view_as_real as view_as_real +from keras.src.ops.nn import average_pool as average_pool +from keras.src.ops.nn import batch_normalization as batch_normalization +from keras.src.ops.nn import binary_crossentropy as binary_crossentropy +from keras.src.ops.nn import ( + categorical_crossentropy as categorical_crossentropy, +) +from keras.src.ops.nn import celu as celu +from keras.src.ops.nn import conv as conv +from keras.src.ops.nn import conv_transpose as conv_transpose +from keras.src.ops.nn import ctc_decode as ctc_decode +from keras.src.ops.nn import ctc_loss as ctc_loss +from keras.src.ops.nn import depthwise_conv as depthwise_conv +from keras.src.ops.nn import dot_product_attention as dot_product_attention +from keras.src.ops.nn import elu as elu +from keras.src.ops.nn import gelu as gelu +from keras.src.ops.nn import glu as glu +from keras.src.ops.nn import hard_shrink as hard_shrink +from keras.src.ops.nn import hard_sigmoid as hard_sigmoid +from keras.src.ops.nn import hard_silu as hard_silu +from keras.src.ops.nn import hard_silu as hard_swish +from keras.src.ops.nn import hard_tanh as hard_tanh +from keras.src.ops.nn import layer_normalization as layer_normalization +from keras.src.ops.nn import leaky_relu as leaky_relu +from keras.src.ops.nn import log_sigmoid as log_sigmoid +from keras.src.ops.nn import log_softmax as log_softmax +from keras.src.ops.nn import max_pool as max_pool +from keras.src.ops.nn import moments as moments +from keras.src.ops.nn import multi_hot as multi_hot +from keras.src.ops.nn import normalize as normalize +from keras.src.ops.nn import one_hot as one_hot +from keras.src.ops.nn import polar as polar +from keras.src.ops.nn import psnr as psnr +from keras.src.ops.nn import relu as relu +from keras.src.ops.nn import relu6 as relu6 +from keras.src.ops.nn import rms_normalization as rms_normalization +from keras.src.ops.nn import selu as selu +from keras.src.ops.nn import separable_conv as separable_conv +from keras.src.ops.nn import sigmoid as sigmoid +from keras.src.ops.nn import silu as silu +from keras.src.ops.nn import silu as swish +from keras.src.ops.nn import soft_shrink as soft_shrink +from keras.src.ops.nn import softmax as softmax +from keras.src.ops.nn import softplus as softplus +from keras.src.ops.nn import softsign as softsign +from keras.src.ops.nn import ( + sparse_categorical_crossentropy as sparse_categorical_crossentropy, +) +from keras.src.ops.nn import sparse_plus as sparse_plus +from keras.src.ops.nn import sparse_sigmoid as sparse_sigmoid +from keras.src.ops.nn import sparsemax as sparsemax +from keras.src.ops.nn import squareplus as squareplus +from keras.src.ops.nn import tanh_shrink as tanh_shrink +from keras.src.ops.nn import threshold as threshold +from keras.src.ops.nn import unfold as unfold +from keras.src.ops.numpy import abs as abs +from keras.src.ops.numpy import absolute as absolute +from keras.src.ops.numpy import add as add +from keras.src.ops.numpy import all as all +from keras.src.ops.numpy import amax as amax +from keras.src.ops.numpy import amin as amin +from keras.src.ops.numpy import angle as angle +from keras.src.ops.numpy import any as any +from keras.src.ops.numpy import append as append +from keras.src.ops.numpy import arange as arange +from keras.src.ops.numpy import arccos as arccos +from keras.src.ops.numpy import arccosh as arccosh +from keras.src.ops.numpy import arcsin as arcsin +from keras.src.ops.numpy import arcsinh as arcsinh +from keras.src.ops.numpy import arctan as arctan +from keras.src.ops.numpy import arctan2 as arctan2 +from keras.src.ops.numpy import arctanh as arctanh +from keras.src.ops.numpy import argmax as argmax +from keras.src.ops.numpy import argmin as argmin +from keras.src.ops.numpy import argpartition as argpartition +from keras.src.ops.numpy import argsort as argsort +from keras.src.ops.numpy import array as array +from keras.src.ops.numpy import average as average +from keras.src.ops.numpy import bartlett as bartlett +from keras.src.ops.numpy import bincount as bincount +from keras.src.ops.numpy import bitwise_and as bitwise_and +from keras.src.ops.numpy import bitwise_invert as bitwise_invert +from keras.src.ops.numpy import bitwise_left_shift as bitwise_left_shift +from keras.src.ops.numpy import bitwise_not as bitwise_not +from keras.src.ops.numpy import bitwise_or as bitwise_or +from keras.src.ops.numpy import bitwise_right_shift as bitwise_right_shift +from keras.src.ops.numpy import bitwise_xor as bitwise_xor +from keras.src.ops.numpy import blackman as blackman +from keras.src.ops.numpy import broadcast_to as broadcast_to +from keras.src.ops.numpy import cbrt as cbrt +from keras.src.ops.numpy import ceil as ceil +from keras.src.ops.numpy import clip as clip +from keras.src.ops.numpy import concatenate as concatenate +from keras.src.ops.numpy import conj as conj +from keras.src.ops.numpy import conjugate as conjugate +from keras.src.ops.numpy import copy as copy +from keras.src.ops.numpy import corrcoef as corrcoef +from keras.src.ops.numpy import correlate as correlate +from keras.src.ops.numpy import cos as cos +from keras.src.ops.numpy import cosh as cosh +from keras.src.ops.numpy import count_nonzero as count_nonzero +from keras.src.ops.numpy import cross as cross +from keras.src.ops.numpy import cumprod as cumprod +from keras.src.ops.numpy import cumsum as cumsum +from keras.src.ops.numpy import deg2rad as deg2rad +from keras.src.ops.numpy import diag as diag +from keras.src.ops.numpy import diagflat as diagflat +from keras.src.ops.numpy import diagonal as diagonal +from keras.src.ops.numpy import diff as diff +from keras.src.ops.numpy import digitize as digitize +from keras.src.ops.numpy import divide as divide +from keras.src.ops.numpy import divide_no_nan as divide_no_nan +from keras.src.ops.numpy import dot as dot +from keras.src.ops.numpy import einsum as einsum +from keras.src.ops.numpy import empty as empty +from keras.src.ops.numpy import equal as equal +from keras.src.ops.numpy import exp as exp +from keras.src.ops.numpy import exp2 as exp2 +from keras.src.ops.numpy import expand_dims as expand_dims +from keras.src.ops.numpy import expm1 as expm1 +from keras.src.ops.numpy import eye as eye +from keras.src.ops.numpy import flip as flip +from keras.src.ops.numpy import floor as floor +from keras.src.ops.numpy import floor_divide as floor_divide +from keras.src.ops.numpy import full as full +from keras.src.ops.numpy import full_like as full_like +from keras.src.ops.numpy import gcd as gcd +from keras.src.ops.numpy import get_item as get_item +from keras.src.ops.numpy import greater as greater +from keras.src.ops.numpy import greater_equal as greater_equal +from keras.src.ops.numpy import hamming as hamming +from keras.src.ops.numpy import hanning as hanning +from keras.src.ops.numpy import heaviside as heaviside +from keras.src.ops.numpy import histogram as histogram +from keras.src.ops.numpy import hstack as hstack +from keras.src.ops.numpy import hypot as hypot +from keras.src.ops.numpy import identity as identity +from keras.src.ops.numpy import imag as imag +from keras.src.ops.numpy import inner as inner +from keras.src.ops.numpy import isclose as isclose +from keras.src.ops.numpy import isfinite as isfinite +from keras.src.ops.numpy import isin as isin +from keras.src.ops.numpy import isinf as isinf +from keras.src.ops.numpy import isnan as isnan +from keras.src.ops.numpy import isneginf as isneginf +from keras.src.ops.numpy import isposinf as isposinf +from keras.src.ops.numpy import kaiser as kaiser +from keras.src.ops.numpy import kron as kron +from keras.src.ops.numpy import lcm as lcm +from keras.src.ops.numpy import left_shift as left_shift +from keras.src.ops.numpy import less as less +from keras.src.ops.numpy import less_equal as less_equal +from keras.src.ops.numpy import linspace as linspace +from keras.src.ops.numpy import log as log +from keras.src.ops.numpy import log1p as log1p +from keras.src.ops.numpy import log2 as log2 +from keras.src.ops.numpy import log10 as log10 +from keras.src.ops.numpy import logaddexp as logaddexp +from keras.src.ops.numpy import logaddexp2 as logaddexp2 +from keras.src.ops.numpy import logical_and as logical_and +from keras.src.ops.numpy import logical_not as logical_not +from keras.src.ops.numpy import logical_or as logical_or +from keras.src.ops.numpy import logical_xor as logical_xor +from keras.src.ops.numpy import logspace as logspace +from keras.src.ops.numpy import matmul as matmul +from keras.src.ops.numpy import max as max +from keras.src.ops.numpy import maximum as maximum +from keras.src.ops.numpy import mean as mean +from keras.src.ops.numpy import median as median +from keras.src.ops.numpy import meshgrid as meshgrid +from keras.src.ops.numpy import min as min +from keras.src.ops.numpy import minimum as minimum +from keras.src.ops.numpy import mod as mod +from keras.src.ops.numpy import moveaxis as moveaxis +from keras.src.ops.numpy import multiply as multiply +from keras.src.ops.numpy import nan_to_num as nan_to_num +from keras.src.ops.numpy import ndim as ndim +from keras.src.ops.numpy import negative as negative +from keras.src.ops.numpy import nonzero as nonzero +from keras.src.ops.numpy import not_equal as not_equal +from keras.src.ops.numpy import ones as ones +from keras.src.ops.numpy import ones_like as ones_like +from keras.src.ops.numpy import outer as outer +from keras.src.ops.numpy import pad as pad +from keras.src.ops.numpy import power as power +from keras.src.ops.numpy import prod as prod +from keras.src.ops.numpy import quantile as quantile +from keras.src.ops.numpy import ravel as ravel +from keras.src.ops.numpy import real as real +from keras.src.ops.numpy import reciprocal as reciprocal +from keras.src.ops.numpy import repeat as repeat +from keras.src.ops.numpy import reshape as reshape +from keras.src.ops.numpy import right_shift as right_shift +from keras.src.ops.numpy import roll as roll +from keras.src.ops.numpy import rot90 as rot90 +from keras.src.ops.numpy import round as round +from keras.src.ops.numpy import searchsorted as searchsorted +from keras.src.ops.numpy import select as select +from keras.src.ops.numpy import sign as sign +from keras.src.ops.numpy import signbit as signbit +from keras.src.ops.numpy import sin as sin +from keras.src.ops.numpy import sinh as sinh +from keras.src.ops.numpy import size as size +from keras.src.ops.numpy import slogdet as slogdet +from keras.src.ops.numpy import sort as sort +from keras.src.ops.numpy import split as split +from keras.src.ops.numpy import sqrt as sqrt +from keras.src.ops.numpy import square as square +from keras.src.ops.numpy import squeeze as squeeze +from keras.src.ops.numpy import stack as stack +from keras.src.ops.numpy import std as std +from keras.src.ops.numpy import subtract as subtract +from keras.src.ops.numpy import sum as sum +from keras.src.ops.numpy import swapaxes as swapaxes +from keras.src.ops.numpy import take as take +from keras.src.ops.numpy import take_along_axis as take_along_axis +from keras.src.ops.numpy import tan as tan +from keras.src.ops.numpy import tanh as tanh +from keras.src.ops.numpy import tensordot as tensordot +from keras.src.ops.numpy import tile as tile +from keras.src.ops.numpy import trace as trace +from keras.src.ops.numpy import transpose as transpose +from keras.src.ops.numpy import tri as tri +from keras.src.ops.numpy import tril as tril +from keras.src.ops.numpy import triu as triu +from keras.src.ops.numpy import true_divide as true_divide +from keras.src.ops.numpy import trunc as trunc +from keras.src.ops.numpy import unravel_index as unravel_index +from keras.src.ops.numpy import var as var +from keras.src.ops.numpy import vdot as vdot +from keras.src.ops.numpy import vectorize as vectorize +from keras.src.ops.numpy import vstack as vstack +from keras.src.ops.numpy import where as where +from keras.src.ops.numpy import zeros as zeros +from keras.src.ops.numpy import zeros_like as zeros_like diff --git a/keras/api/ops/image/__init__.py b/keras/api/ops/image/__init__.py new file mode 100644 index 000000000000..3a81f191259d --- /dev/null +++ b/keras/api/ops/image/__init__.py @@ -0,0 +1,19 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.ops.image import affine_transform as affine_transform +from keras.src.ops.image import crop_images as crop_images +from keras.src.ops.image import elastic_transform as elastic_transform +from keras.src.ops.image import extract_patches as extract_patches +from keras.src.ops.image import gaussian_blur as gaussian_blur +from keras.src.ops.image import hsv_to_rgb as hsv_to_rgb +from keras.src.ops.image import map_coordinates as map_coordinates +from keras.src.ops.image import pad_images as pad_images +from keras.src.ops.image import perspective_transform as perspective_transform +from keras.src.ops.image import resize as resize +from keras.src.ops.image import rgb_to_grayscale as rgb_to_grayscale +from keras.src.ops.image import rgb_to_hsv as rgb_to_hsv +from keras.src.ops.image import scale_and_translate as scale_and_translate diff --git a/keras/api/ops/linalg/__init__.py b/keras/api/ops/linalg/__init__.py new file mode 100644 index 000000000000..764fa8e74269 --- /dev/null +++ b/keras/api/ops/linalg/__init__.py @@ -0,0 +1,20 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.ops.linalg import cholesky as cholesky +from keras.src.ops.linalg import cholesky_inverse as cholesky_inverse +from keras.src.ops.linalg import det as det +from keras.src.ops.linalg import eig as eig +from keras.src.ops.linalg import eigh as eigh +from keras.src.ops.linalg import inv as inv +from keras.src.ops.linalg import jvp as jvp +from keras.src.ops.linalg import lstsq as lstsq +from keras.src.ops.linalg import lu_factor as lu_factor +from keras.src.ops.linalg import norm as norm +from keras.src.ops.linalg import qr as qr +from keras.src.ops.linalg import solve as solve +from keras.src.ops.linalg import solve_triangular as solve_triangular +from keras.src.ops.linalg import svd as svd diff --git a/keras/api/ops/nn/__init__.py b/keras/api/ops/nn/__init__.py new file mode 100644 index 000000000000..da08f380f227 --- /dev/null +++ b/keras/api/ops/nn/__init__.py @@ -0,0 +1,60 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.ops.nn import average_pool as average_pool +from keras.src.ops.nn import batch_normalization as batch_normalization +from keras.src.ops.nn import binary_crossentropy as binary_crossentropy +from keras.src.ops.nn import ( + categorical_crossentropy as categorical_crossentropy, +) +from keras.src.ops.nn import celu as celu +from keras.src.ops.nn import conv as conv +from keras.src.ops.nn import conv_transpose as conv_transpose +from keras.src.ops.nn import ctc_decode as ctc_decode +from keras.src.ops.nn import ctc_loss as ctc_loss +from keras.src.ops.nn import depthwise_conv as depthwise_conv +from keras.src.ops.nn import dot_product_attention as dot_product_attention +from keras.src.ops.nn import elu as elu +from keras.src.ops.nn import gelu as gelu +from keras.src.ops.nn import glu as glu +from keras.src.ops.nn import hard_shrink as hard_shrink +from keras.src.ops.nn import hard_sigmoid as hard_sigmoid +from keras.src.ops.nn import hard_silu as hard_silu +from keras.src.ops.nn import hard_silu as hard_swish +from keras.src.ops.nn import hard_tanh as hard_tanh +from keras.src.ops.nn import layer_normalization as layer_normalization +from keras.src.ops.nn import leaky_relu as leaky_relu +from keras.src.ops.nn import log_sigmoid as log_sigmoid +from keras.src.ops.nn import log_softmax as log_softmax +from keras.src.ops.nn import max_pool as max_pool +from keras.src.ops.nn import moments as moments +from keras.src.ops.nn import multi_hot as multi_hot +from keras.src.ops.nn import normalize as normalize +from keras.src.ops.nn import one_hot as one_hot +from keras.src.ops.nn import polar as polar +from keras.src.ops.nn import psnr as psnr +from keras.src.ops.nn import relu as relu +from keras.src.ops.nn import relu6 as relu6 +from keras.src.ops.nn import rms_normalization as rms_normalization +from keras.src.ops.nn import selu as selu +from keras.src.ops.nn import separable_conv as separable_conv +from keras.src.ops.nn import sigmoid as sigmoid +from keras.src.ops.nn import silu as silu +from keras.src.ops.nn import silu as swish +from keras.src.ops.nn import soft_shrink as soft_shrink +from keras.src.ops.nn import softmax as softmax +from keras.src.ops.nn import softplus as softplus +from keras.src.ops.nn import softsign as softsign +from keras.src.ops.nn import ( + sparse_categorical_crossentropy as sparse_categorical_crossentropy, +) +from keras.src.ops.nn import sparse_plus as sparse_plus +from keras.src.ops.nn import sparse_sigmoid as sparse_sigmoid +from keras.src.ops.nn import sparsemax as sparsemax +from keras.src.ops.nn import squareplus as squareplus +from keras.src.ops.nn import tanh_shrink as tanh_shrink +from keras.src.ops.nn import threshold as threshold +from keras.src.ops.nn import unfold as unfold diff --git a/keras/api/ops/numpy/__init__.py b/keras/api/ops/numpy/__init__.py new file mode 100644 index 000000000000..ebeb384c181c --- /dev/null +++ b/keras/api/ops/numpy/__init__.py @@ -0,0 +1,186 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.ops.numpy import abs as abs +from keras.src.ops.numpy import absolute as absolute +from keras.src.ops.numpy import add as add +from keras.src.ops.numpy import all as all +from keras.src.ops.numpy import amax as amax +from keras.src.ops.numpy import amin as amin +from keras.src.ops.numpy import angle as angle +from keras.src.ops.numpy import any as any +from keras.src.ops.numpy import append as append +from keras.src.ops.numpy import arange as arange +from keras.src.ops.numpy import arccos as arccos +from keras.src.ops.numpy import arccosh as arccosh +from keras.src.ops.numpy import arcsin as arcsin +from keras.src.ops.numpy import arcsinh as arcsinh +from keras.src.ops.numpy import arctan as arctan +from keras.src.ops.numpy import arctan2 as arctan2 +from keras.src.ops.numpy import arctanh as arctanh +from keras.src.ops.numpy import argmax as argmax +from keras.src.ops.numpy import argmin as argmin +from keras.src.ops.numpy import argpartition as argpartition +from keras.src.ops.numpy import argsort as argsort +from keras.src.ops.numpy import array as array +from keras.src.ops.numpy import average as average +from keras.src.ops.numpy import bartlett as bartlett +from keras.src.ops.numpy import bincount as bincount +from keras.src.ops.numpy import bitwise_and as bitwise_and +from keras.src.ops.numpy import bitwise_invert as bitwise_invert +from keras.src.ops.numpy import bitwise_left_shift as bitwise_left_shift +from keras.src.ops.numpy import bitwise_not as bitwise_not +from keras.src.ops.numpy import bitwise_or as bitwise_or +from keras.src.ops.numpy import bitwise_right_shift as bitwise_right_shift +from keras.src.ops.numpy import bitwise_xor as bitwise_xor +from keras.src.ops.numpy import blackman as blackman +from keras.src.ops.numpy import broadcast_to as broadcast_to +from keras.src.ops.numpy import cbrt as cbrt +from keras.src.ops.numpy import ceil as ceil +from keras.src.ops.numpy import clip as clip +from keras.src.ops.numpy import concatenate as concatenate +from keras.src.ops.numpy import conj as conj +from keras.src.ops.numpy import conjugate as conjugate +from keras.src.ops.numpy import copy as copy +from keras.src.ops.numpy import corrcoef as corrcoef +from keras.src.ops.numpy import correlate as correlate +from keras.src.ops.numpy import cos as cos +from keras.src.ops.numpy import cosh as cosh +from keras.src.ops.numpy import count_nonzero as count_nonzero +from keras.src.ops.numpy import cross as cross +from keras.src.ops.numpy import cumprod as cumprod +from keras.src.ops.numpy import cumsum as cumsum +from keras.src.ops.numpy import deg2rad as deg2rad +from keras.src.ops.numpy import diag as diag +from keras.src.ops.numpy import diagflat as diagflat +from keras.src.ops.numpy import diagonal as diagonal +from keras.src.ops.numpy import diff as diff +from keras.src.ops.numpy import digitize as digitize +from keras.src.ops.numpy import divide as divide +from keras.src.ops.numpy import divide_no_nan as divide_no_nan +from keras.src.ops.numpy import dot as dot +from keras.src.ops.numpy import einsum as einsum +from keras.src.ops.numpy import empty as empty +from keras.src.ops.numpy import equal as equal +from keras.src.ops.numpy import exp as exp +from keras.src.ops.numpy import exp2 as exp2 +from keras.src.ops.numpy import expand_dims as expand_dims +from keras.src.ops.numpy import expm1 as expm1 +from keras.src.ops.numpy import eye as eye +from keras.src.ops.numpy import flip as flip +from keras.src.ops.numpy import floor as floor +from keras.src.ops.numpy import floor_divide as floor_divide +from keras.src.ops.numpy import full as full +from keras.src.ops.numpy import full_like as full_like +from keras.src.ops.numpy import gcd as gcd +from keras.src.ops.numpy import get_item as get_item +from keras.src.ops.numpy import greater as greater +from keras.src.ops.numpy import greater_equal as greater_equal +from keras.src.ops.numpy import hamming as hamming +from keras.src.ops.numpy import hanning as hanning +from keras.src.ops.numpy import heaviside as heaviside +from keras.src.ops.numpy import histogram as histogram +from keras.src.ops.numpy import hstack as hstack +from keras.src.ops.numpy import hypot as hypot +from keras.src.ops.numpy import identity as identity +from keras.src.ops.numpy import imag as imag +from keras.src.ops.numpy import inner as inner +from keras.src.ops.numpy import isclose as isclose +from keras.src.ops.numpy import isfinite as isfinite +from keras.src.ops.numpy import isin as isin +from keras.src.ops.numpy import isinf as isinf +from keras.src.ops.numpy import isnan as isnan +from keras.src.ops.numpy import isneginf as isneginf +from keras.src.ops.numpy import isposinf as isposinf +from keras.src.ops.numpy import kaiser as kaiser +from keras.src.ops.numpy import kron as kron +from keras.src.ops.numpy import lcm as lcm +from keras.src.ops.numpy import left_shift as left_shift +from keras.src.ops.numpy import less as less +from keras.src.ops.numpy import less_equal as less_equal +from keras.src.ops.numpy import linspace as linspace +from keras.src.ops.numpy import log as log +from keras.src.ops.numpy import log1p as log1p +from keras.src.ops.numpy import log2 as log2 +from keras.src.ops.numpy import log10 as log10 +from keras.src.ops.numpy import logaddexp as logaddexp +from keras.src.ops.numpy import logaddexp2 as logaddexp2 +from keras.src.ops.numpy import logical_and as logical_and +from keras.src.ops.numpy import logical_not as logical_not +from keras.src.ops.numpy import logical_or as logical_or +from keras.src.ops.numpy import logical_xor as logical_xor +from keras.src.ops.numpy import logspace as logspace +from keras.src.ops.numpy import matmul as matmul +from keras.src.ops.numpy import max as max +from keras.src.ops.numpy import maximum as maximum +from keras.src.ops.numpy import mean as mean +from keras.src.ops.numpy import median as median +from keras.src.ops.numpy import meshgrid as meshgrid +from keras.src.ops.numpy import min as min +from keras.src.ops.numpy import minimum as minimum +from keras.src.ops.numpy import mod as mod +from keras.src.ops.numpy import moveaxis as moveaxis +from keras.src.ops.numpy import multiply as multiply +from keras.src.ops.numpy import nan_to_num as nan_to_num +from keras.src.ops.numpy import ndim as ndim +from keras.src.ops.numpy import negative as negative +from keras.src.ops.numpy import nonzero as nonzero +from keras.src.ops.numpy import not_equal as not_equal +from keras.src.ops.numpy import ones as ones +from keras.src.ops.numpy import ones_like as ones_like +from keras.src.ops.numpy import outer as outer +from keras.src.ops.numpy import pad as pad +from keras.src.ops.numpy import power as power +from keras.src.ops.numpy import prod as prod +from keras.src.ops.numpy import quantile as quantile +from keras.src.ops.numpy import ravel as ravel +from keras.src.ops.numpy import real as real +from keras.src.ops.numpy import reciprocal as reciprocal +from keras.src.ops.numpy import repeat as repeat +from keras.src.ops.numpy import reshape as reshape +from keras.src.ops.numpy import right_shift as right_shift +from keras.src.ops.numpy import roll as roll +from keras.src.ops.numpy import rot90 as rot90 +from keras.src.ops.numpy import round as round +from keras.src.ops.numpy import searchsorted as searchsorted +from keras.src.ops.numpy import select as select +from keras.src.ops.numpy import sign as sign +from keras.src.ops.numpy import signbit as signbit +from keras.src.ops.numpy import sin as sin +from keras.src.ops.numpy import sinh as sinh +from keras.src.ops.numpy import size as size +from keras.src.ops.numpy import slogdet as slogdet +from keras.src.ops.numpy import sort as sort +from keras.src.ops.numpy import split as split +from keras.src.ops.numpy import sqrt as sqrt +from keras.src.ops.numpy import square as square +from keras.src.ops.numpy import squeeze as squeeze +from keras.src.ops.numpy import stack as stack +from keras.src.ops.numpy import std as std +from keras.src.ops.numpy import subtract as subtract +from keras.src.ops.numpy import sum as sum +from keras.src.ops.numpy import swapaxes as swapaxes +from keras.src.ops.numpy import take as take +from keras.src.ops.numpy import take_along_axis as take_along_axis +from keras.src.ops.numpy import tan as tan +from keras.src.ops.numpy import tanh as tanh +from keras.src.ops.numpy import tensordot as tensordot +from keras.src.ops.numpy import tile as tile +from keras.src.ops.numpy import trace as trace +from keras.src.ops.numpy import transpose as transpose +from keras.src.ops.numpy import tri as tri +from keras.src.ops.numpy import tril as tril +from keras.src.ops.numpy import triu as triu +from keras.src.ops.numpy import true_divide as true_divide +from keras.src.ops.numpy import trunc as trunc +from keras.src.ops.numpy import unravel_index as unravel_index +from keras.src.ops.numpy import var as var +from keras.src.ops.numpy import vdot as vdot +from keras.src.ops.numpy import vectorize as vectorize +from keras.src.ops.numpy import vstack as vstack +from keras.src.ops.numpy import where as where +from keras.src.ops.numpy import zeros as zeros +from keras.src.ops.numpy import zeros_like as zeros_like diff --git a/keras/api/optimizers/__init__.py b/keras/api/optimizers/__init__.py new file mode 100644 index 000000000000..40f6ab4018f5 --- /dev/null +++ b/keras/api/optimizers/__init__.py @@ -0,0 +1,28 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.optimizers import legacy as legacy +from keras.optimizers import schedules as schedules +from keras.src.optimizers import deserialize as deserialize +from keras.src.optimizers import get as get +from keras.src.optimizers import serialize as serialize +from keras.src.optimizers.adadelta import Adadelta as Adadelta +from keras.src.optimizers.adafactor import Adafactor as Adafactor +from keras.src.optimizers.adagrad import Adagrad as Adagrad +from keras.src.optimizers.adam import Adam as Adam +from keras.src.optimizers.adamax import Adamax as Adamax +from keras.src.optimizers.adamw import AdamW as AdamW +from keras.src.optimizers.ftrl import Ftrl as Ftrl +from keras.src.optimizers.lamb import Lamb as Lamb +from keras.src.optimizers.lion import Lion as Lion +from keras.src.optimizers.loss_scale_optimizer import ( + LossScaleOptimizer as LossScaleOptimizer, +) +from keras.src.optimizers.muon import Muon as Muon +from keras.src.optimizers.nadam import Nadam as Nadam +from keras.src.optimizers.optimizer import Optimizer as Optimizer +from keras.src.optimizers.rmsprop import RMSprop as RMSprop +from keras.src.optimizers.sgd import SGD as SGD diff --git a/keras/api/optimizers/legacy/__init__.py b/keras/api/optimizers/legacy/__init__.py new file mode 100644 index 000000000000..bff1a0313630 --- /dev/null +++ b/keras/api/optimizers/legacy/__init__.py @@ -0,0 +1,12 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.optimizers import LegacyOptimizerWarning as Adagrad +from keras.src.optimizers import LegacyOptimizerWarning as Adam +from keras.src.optimizers import LegacyOptimizerWarning as Ftrl +from keras.src.optimizers import LegacyOptimizerWarning as Optimizer +from keras.src.optimizers import LegacyOptimizerWarning as RMSprop +from keras.src.optimizers import LegacyOptimizerWarning as SGD diff --git a/keras/api/optimizers/schedules/__init__.py b/keras/api/optimizers/schedules/__init__.py new file mode 100644 index 000000000000..da9621aa36b1 --- /dev/null +++ b/keras/api/optimizers/schedules/__init__.py @@ -0,0 +1,33 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.optimizers.schedules.learning_rate_schedule import ( + CosineDecay as CosineDecay, +) +from keras.src.optimizers.schedules.learning_rate_schedule import ( + CosineDecayRestarts as CosineDecayRestarts, +) +from keras.src.optimizers.schedules.learning_rate_schedule import ( + ExponentialDecay as ExponentialDecay, +) +from keras.src.optimizers.schedules.learning_rate_schedule import ( + InverseTimeDecay as InverseTimeDecay, +) +from keras.src.optimizers.schedules.learning_rate_schedule import ( + LearningRateSchedule as LearningRateSchedule, +) +from keras.src.optimizers.schedules.learning_rate_schedule import ( + PiecewiseConstantDecay as PiecewiseConstantDecay, +) +from keras.src.optimizers.schedules.learning_rate_schedule import ( + PolynomialDecay as PolynomialDecay, +) +from keras.src.optimizers.schedules.learning_rate_schedule import ( + deserialize as deserialize, +) +from keras.src.optimizers.schedules.learning_rate_schedule import ( + serialize as serialize, +) diff --git a/keras/api/preprocessing/__init__.py b/keras/api/preprocessing/__init__.py new file mode 100644 index 000000000000..49a47f66337e --- /dev/null +++ b/keras/api/preprocessing/__init__.py @@ -0,0 +1,17 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.preprocessing import image as image +from keras.preprocessing import sequence as sequence +from keras.src.utils.image_dataset_utils import ( + image_dataset_from_directory as image_dataset_from_directory, +) +from keras.src.utils.text_dataset_utils import ( + text_dataset_from_directory as text_dataset_from_directory, +) +from keras.src.utils.timeseries_dataset_utils import ( + timeseries_dataset_from_array as timeseries_dataset_from_array, +) diff --git a/keras/api/preprocessing/image/__init__.py b/keras/api/preprocessing/image/__init__.py new file mode 100644 index 000000000000..59f4e125116f --- /dev/null +++ b/keras/api/preprocessing/image/__init__.py @@ -0,0 +1,11 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.utils.image_utils import array_to_img as array_to_img +from keras.src.utils.image_utils import img_to_array as img_to_array +from keras.src.utils.image_utils import load_img as load_img +from keras.src.utils.image_utils import save_img as save_img +from keras.src.utils.image_utils import smart_resize as smart_resize diff --git a/keras/api/preprocessing/sequence/__init__.py b/keras/api/preprocessing/sequence/__init__.py new file mode 100644 index 000000000000..ed43e838795d --- /dev/null +++ b/keras/api/preprocessing/sequence/__init__.py @@ -0,0 +1,7 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.utils.sequence_utils import pad_sequences as pad_sequences diff --git a/keras/api/quantizers/__init__.py b/keras/api/quantizers/__init__.py new file mode 100644 index 000000000000..299e467ac1bb --- /dev/null +++ b/keras/api/quantizers/__init__.py @@ -0,0 +1,27 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.quantizers import deserialize as deserialize +from keras.src.quantizers import get as get +from keras.src.quantizers import serialize as serialize +from keras.src.quantizers.gptq_config import GPTQConfig as GPTQConfig +from keras.src.quantizers.quantizers import AbsMaxQuantizer as AbsMaxQuantizer +from keras.src.quantizers.quantizers import Quantizer as Quantizer +from keras.src.quantizers.quantizers import abs_max_quantize as abs_max_quantize +from keras.src.quantizers.quantizers import ( + compute_float8_amax_history as compute_float8_amax_history, +) +from keras.src.quantizers.quantizers import ( + compute_float8_scale as compute_float8_scale, +) +from keras.src.quantizers.quantizers import ( + fake_quant_with_min_max_vars as fake_quant_with_min_max_vars, +) +from keras.src.quantizers.quantizers import pack_int4 as pack_int4 +from keras.src.quantizers.quantizers import ( + quantize_and_dequantize as quantize_and_dequantize, +) +from keras.src.quantizers.quantizers import unpack_int4 as unpack_int4 diff --git a/keras/api/random/__init__.py b/keras/api/random/__init__.py new file mode 100644 index 000000000000..d0ee60a77c92 --- /dev/null +++ b/keras/api/random/__init__.py @@ -0,0 +1,17 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.random.random import beta as beta +from keras.src.random.random import binomial as binomial +from keras.src.random.random import categorical as categorical +from keras.src.random.random import dropout as dropout +from keras.src.random.random import gamma as gamma +from keras.src.random.random import normal as normal +from keras.src.random.random import randint as randint +from keras.src.random.random import shuffle as shuffle +from keras.src.random.random import truncated_normal as truncated_normal +from keras.src.random.random import uniform as uniform +from keras.src.random.seed_generator import SeedGenerator as SeedGenerator diff --git a/keras/api/regularizers/__init__.py b/keras/api/regularizers/__init__.py new file mode 100644 index 000000000000..1e3609f71c75 --- /dev/null +++ b/keras/api/regularizers/__init__.py @@ -0,0 +1,22 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.regularizers import deserialize as deserialize +from keras.src.regularizers import get as get +from keras.src.regularizers import serialize as serialize +from keras.src.regularizers.regularizers import L1 as L1 +from keras.src.regularizers.regularizers import L1 as l1 +from keras.src.regularizers.regularizers import L1L2 as L1L2 +from keras.src.regularizers.regularizers import L1L2 as l1_l2 +from keras.src.regularizers.regularizers import L2 as L2 +from keras.src.regularizers.regularizers import L2 as l2 +from keras.src.regularizers.regularizers import ( + OrthogonalRegularizer as OrthogonalRegularizer, +) +from keras.src.regularizers.regularizers import ( + OrthogonalRegularizer as orthogonal_regularizer, +) +from keras.src.regularizers.regularizers import Regularizer as Regularizer diff --git a/keras/api/saving/__init__.py b/keras/api/saving/__init__.py new file mode 100644 index 000000000000..28edd8779337 --- /dev/null +++ b/keras/api/saving/__init__.py @@ -0,0 +1,35 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.saving.file_editor import KerasFileEditor as KerasFileEditor +from keras.src.saving.object_registration import ( + CustomObjectScope as CustomObjectScope, +) +from keras.src.saving.object_registration import ( + CustomObjectScope as custom_object_scope, +) +from keras.src.saving.object_registration import ( + get_custom_objects as get_custom_objects, +) +from keras.src.saving.object_registration import ( + get_registered_name as get_registered_name, +) +from keras.src.saving.object_registration import ( + get_registered_object as get_registered_object, +) +from keras.src.saving.object_registration import ( + register_keras_serializable as register_keras_serializable, +) +from keras.src.saving.saving_api import load_model as load_model +from keras.src.saving.saving_api import load_weights as load_weights +from keras.src.saving.saving_api import save_model as save_model +from keras.src.saving.saving_api import save_weights as save_weights +from keras.src.saving.serialization_lib import ( + deserialize_keras_object as deserialize_keras_object, +) +from keras.src.saving.serialization_lib import ( + serialize_keras_object as serialize_keras_object, +) diff --git a/keras/api/tree/__init__.py b/keras/api/tree/__init__.py new file mode 100644 index 000000000000..80d9f25244e8 --- /dev/null +++ b/keras/api/tree/__init__.py @@ -0,0 +1,20 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.tree.tree_api import MAP_TO_NONE as MAP_TO_NONE +from keras.src.tree.tree_api import assert_same_paths as assert_same_paths +from keras.src.tree.tree_api import ( + assert_same_structure as assert_same_structure, +) +from keras.src.tree.tree_api import flatten as flatten +from keras.src.tree.tree_api import flatten_with_path as flatten_with_path +from keras.src.tree.tree_api import is_nested as is_nested +from keras.src.tree.tree_api import lists_to_tuples as lists_to_tuples +from keras.src.tree.tree_api import map_shape_structure as map_shape_structure +from keras.src.tree.tree_api import map_structure as map_structure +from keras.src.tree.tree_api import map_structure_up_to as map_structure_up_to +from keras.src.tree.tree_api import pack_sequence_as as pack_sequence_as +from keras.src.tree.tree_api import traverse as traverse diff --git a/keras/api/utils/__init__.py b/keras/api/utils/__init__.py new file mode 100644 index 000000000000..8ddbda527609 --- /dev/null +++ b/keras/api/utils/__init__.py @@ -0,0 +1,90 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.backend.common.global_state import clear_session as clear_session +from keras.src.backend.common.keras_tensor import ( + is_keras_tensor as is_keras_tensor, +) +from keras.src.backend.common.variables import ( + standardize_dtype as standardize_dtype, +) +from keras.src.layers.preprocessing.feature_space import ( + FeatureSpace as FeatureSpace, +) +from keras.src.ops.operation_utils import get_source_inputs as get_source_inputs +from keras.src.saving.object_registration import ( + CustomObjectScope as CustomObjectScope, +) +from keras.src.saving.object_registration import ( + CustomObjectScope as custom_object_scope, +) +from keras.src.saving.object_registration import ( + get_custom_objects as get_custom_objects, +) +from keras.src.saving.object_registration import ( + get_registered_name as get_registered_name, +) +from keras.src.saving.object_registration import ( + get_registered_object as get_registered_object, +) +from keras.src.saving.object_registration import ( + register_keras_serializable as register_keras_serializable, +) +from keras.src.saving.serialization_lib import ( + deserialize_keras_object as deserialize_keras_object, +) +from keras.src.saving.serialization_lib import ( + serialize_keras_object as serialize_keras_object, +) +from keras.src.trainers.data_adapters.data_adapter_utils import ( + pack_x_y_sample_weight as pack_x_y_sample_weight, +) +from keras.src.trainers.data_adapters.data_adapter_utils import ( + unpack_x_y_sample_weight as unpack_x_y_sample_weight, +) +from keras.src.trainers.data_adapters.py_dataset_adapter import ( + PyDataset as PyDataset, +) +from keras.src.trainers.data_adapters.py_dataset_adapter import ( + PyDataset as Sequence, +) +from keras.src.utils.audio_dataset_utils import ( + audio_dataset_from_directory as audio_dataset_from_directory, +) +from keras.src.utils.config import Config as Config +from keras.src.utils.dataset_utils import split_dataset as split_dataset +from keras.src.utils.file_utils import get_file as get_file +from keras.src.utils.image_dataset_utils import ( + image_dataset_from_directory as image_dataset_from_directory, +) +from keras.src.utils.image_utils import array_to_img as array_to_img +from keras.src.utils.image_utils import img_to_array as img_to_array +from keras.src.utils.image_utils import load_img as load_img +from keras.src.utils.image_utils import save_img as save_img +from keras.src.utils.io_utils import ( + disable_interactive_logging as disable_interactive_logging, +) +from keras.src.utils.io_utils import ( + enable_interactive_logging as enable_interactive_logging, +) +from keras.src.utils.io_utils import ( + is_interactive_logging_enabled as is_interactive_logging_enabled, +) +from keras.src.utils.model_visualization import model_to_dot as model_to_dot +from keras.src.utils.model_visualization import plot_model as plot_model +from keras.src.utils.numerical_utils import normalize as normalize +from keras.src.utils.numerical_utils import to_categorical as to_categorical +from keras.src.utils.progbar import Progbar as Progbar +from keras.src.utils.rng_utils import set_random_seed as set_random_seed +from keras.src.utils.sequence_utils import pad_sequences as pad_sequences +from keras.src.utils.text_dataset_utils import ( + text_dataset_from_directory as text_dataset_from_directory, +) +from keras.src.utils.timeseries_dataset_utils import ( + timeseries_dataset_from_array as timeseries_dataset_from_array, +) +from keras.utils import bounding_boxes as bounding_boxes +from keras.utils import legacy as legacy diff --git a/keras/api/utils/bounding_boxes/__init__.py b/keras/api/utils/bounding_boxes/__init__.py new file mode 100644 index 000000000000..40221bd75c94 --- /dev/null +++ b/keras/api/utils/bounding_boxes/__init__.py @@ -0,0 +1,33 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import ( + affine_transform as affine_transform, +) +from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import ( + clip_to_image_size as clip_to_image_size, +) +from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import ( + convert_format as convert_format, +) +from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import ( + crop as crop, +) +from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import ( + decode_deltas_to_boxes as decode_deltas_to_boxes, +) +from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import ( + encode_box_to_deltas as encode_box_to_deltas, +) +from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import ( + pad as pad, +) +from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.iou import ( + compute_ciou as compute_ciou, +) +from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.iou import ( + compute_iou as compute_iou, +) diff --git a/keras/api/utils/legacy/__init__.py b/keras/api/utils/legacy/__init__.py new file mode 100644 index 000000000000..1e3aa0ee9d5c --- /dev/null +++ b/keras/api/utils/legacy/__init__.py @@ -0,0 +1,12 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.legacy.saving.serialization import ( + deserialize_keras_object as deserialize_keras_object, +) +from keras.src.legacy.saving.serialization import ( + serialize_keras_object as serialize_keras_object, +) diff --git a/keras/api/visualization/__init__.py b/keras/api/visualization/__init__.py new file mode 100644 index 000000000000..6e3482a8d59a --- /dev/null +++ b/keras/api/visualization/__init__.py @@ -0,0 +1,21 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.visualization.draw_bounding_boxes import ( + draw_bounding_boxes as draw_bounding_boxes, +) +from keras.src.visualization.draw_segmentation_masks import ( + draw_segmentation_masks as draw_segmentation_masks, +) +from keras.src.visualization.plot_bounding_box_gallery import ( + plot_bounding_box_gallery as plot_bounding_box_gallery, +) +from keras.src.visualization.plot_image_gallery import ( + plot_image_gallery as plot_image_gallery, +) +from keras.src.visualization.plot_segmentation_mask_gallery import ( + plot_segmentation_mask_gallery as plot_segmentation_mask_gallery, +) diff --git a/keras/api/wrappers/__init__.py b/keras/api/wrappers/__init__.py new file mode 100644 index 000000000000..e3aa52524ca6 --- /dev/null +++ b/keras/api/wrappers/__init__.py @@ -0,0 +1,15 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras.src.wrappers.sklearn_wrapper import ( + SKLearnClassifier as SKLearnClassifier, +) +from keras.src.wrappers.sklearn_wrapper import ( + SKLearnRegressor as SKLearnRegressor, +) +from keras.src.wrappers.sklearn_wrapper import ( + SKLearnTransformer as SKLearnTransformer, +) diff --git a/keras/datasets/cifar10.py b/keras/datasets/cifar10.py deleted file mode 100644 index 98e8b17be6b3..000000000000 --- a/keras/datasets/cifar10.py +++ /dev/null @@ -1,41 +0,0 @@ -from data_utils import get_file -import random -import cPickle -import numpy as np -from PIL import Image - -def load_data(test_split=0.1, seed=113): - dirname = "cifar-10-batches-py" - origin = "http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" - path = get_file(dirname, origin=origin, untar=True) - - nb_samples = 50000 - X = np.zeros((nb_samples, 3, 32, 32), dtype="uint8") - y = np.zeros((nb_samples,)) - for i in range(1, 6): - fpath = path + '/data_batch_' + str(i) - f = open(fpath, 'rb') - d = cPickle.load(f) - f.close() - data = d["data"] - labels = d["labels"] - - data = data.reshape(data.shape[0], 3, 32, 32) - X[(i-1)*10000:i*10000, :, :, :] = data - y[(i-1)*10000:i*10000] = labels - - np.random.seed(seed) - np.random.shuffle(X) - np.random.seed(seed) - np.random.shuffle(y) - - y = np.reshape(y, (len(y), 1)) - - X_train = X[:int(len(X)*(1-test_split))] - y_train = y[:int(len(X)*(1-test_split))] - - X_test = X[int(len(X)*(1-test_split)):] - y_test = y[int(len(X)*(1-test_split)):] - - return (X_train, y_train), (X_test, y_test) - diff --git a/keras/datasets/data_utils.py b/keras/datasets/data_utils.py deleted file mode 100644 index 437618365406..000000000000 --- a/keras/datasets/data_utils.py +++ /dev/null @@ -1,45 +0,0 @@ -import urllib, tarfile -import inspect, os -from ..utils.generic_utils import Progbar - -def get_file(fname, origin, untar=False): - datadir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) - datadir = os.path.join(datadir, 'data') - if not os.path.exists(datadir): - os.makedirs(datadir) - - if untar: - untar_fpath = os.path.join(datadir, fname) - fpath = untar_fpath + '.tar.gz' - else: - fpath = os.path.join(datadir, fname) - - try: - f = open(fpath) - except: - print 'Downloading data from', origin - - global progbar - progbar = None - def dl_progress(count, block_size, total_size): - global progbar - if progbar is None: - progbar = Progbar(total_size) - else: - progbar.update(count*block_size) - - urllib.urlretrieve(origin, fpath, dl_progress) - progbar = None - - if untar: - if not os.path.exists(untar_fpath): - print 'Untaring file...' - tfile = tarfile.open(fpath, 'r:gz') - tfile.extractall(path=datadir) - tfile.close() - return untar_fpath - - return fpath - - - diff --git a/keras/datasets/imdb.py b/keras/datasets/imdb.py deleted file mode 100644 index aa6ee4c7014f..000000000000 --- a/keras/datasets/imdb.py +++ /dev/null @@ -1,40 +0,0 @@ -import cPickle -import gzip -from data_utils import get_file -import random - -def load_data(path="imdb.pkl", nb_words=100000, maxlen=None, test_split=0.2, seed=113): - path = get_file(path, origin="https://s3.amazonaws.com/text-datasets/imdb.pkl") - - if path.endswith(".gz"): - f = gzip.open(path, 'rb') - else: - f = open(path, 'rb') - - X, labels = cPickle.load(f) - f.close() - - random.seed(seed) - random.shuffle(X) - random.seed(seed) - random.shuffle(labels) - - if maxlen: - new_X = [] - new_labels = [] - for x, y in zip(X, labels): - if len(x) < maxlen: - new_X.append(x) - new_labels.append(y) - X = new_X - labels = new_labels - - X = [[1 if w >= nb_words else w for w in x] for x in X] - X_train = X[:int(len(X)*(1-test_split))] - y_train = labels[:int(len(X)*(1-test_split))] - - X_test = X[int(len(X)*(1-test_split)):] - y_test = labels[int(len(X)*(1-test_split)):] - - return (X_train, y_train), (X_test, y_test) - diff --git a/keras/datasets/reuters.py b/keras/datasets/reuters.py deleted file mode 100644 index ed6ae0c68299..000000000000 --- a/keras/datasets/reuters.py +++ /dev/null @@ -1,111 +0,0 @@ -# -*- coding: utf-8 -*- -from data_utils import get_file -import string -import random -import cPickle - -def make_reuters_dataset(path='datasets/temp/reuters21578/', min_samples_per_topic=15): - import os - import re - from preprocessing.text import Tokenizer - - wire_topics = [] - topic_counts = {} - wire_bodies = [] - - for fname in os.listdir(path): - if 'sgm' in fname: - s = open(path + fname).read() - tag = '' - while tag in s: - s = s[s.find(tag)+len(tag):] - topics = s[:s.find('' in topics: - topic = topics.replace('', '').replace('', '') - wire_topics.append(topic) - topic_counts[topic] = topic_counts.get(topic, 0) + 1 - else: - continue - - bodytag = '' - body = s[s.find(bodytag)+len(bodytag):] - body = body[:body.find('= min_samples_per_topic: - kept_topics.add(x[0]) - print '-' - print 'Kept topics:', len(kept_topics) - - # filter wires with rare topics - kept_wires = [] - labels = [] - topic_indexes = {} - for t, b in zip(wire_topics, wire_bodies): - if t in kept_topics: - if t not in topic_indexes: - topic_index = len(topic_indexes) - topic_indexes[t] = topic_index - else: - topic_index = topic_indexes[t] - - labels.append(topic_index) - kept_wires.append(b) - - # vectorize wires - tokenizer = Tokenizer() - tokenizer.fit(kept_wires) - X = tokenizer.transform(kept_wires) - - print 'Sanity check:' - for w in ["banana", "oil", "chocolate", "the", "dsft"]: - print '...index of', w, ':', tokenizer.word_index.get(w) - - dataset = (X, labels) - print '-' - print 'Saving...' - cPickle.dump(dataset, open('datasets/data/reuters.pkl', 'w')) - - - -def load_data(path="reuters.pkl", nb_words=100000, maxlen=None, test_split=0.2, seed=113): - path = get_file(path, origin="https://s3.amazonaws.com/text-datasets/reuters.pkl") - f = open(path, 'rb') - - X, labels = cPickle.load(f) - f.close() - random.seed(seed) - random.shuffle(X) - random.seed(seed) - random.shuffle(labels) - - if maxlen: - new_X = [] - new_labels = [] - for x, y in zip(X, labels): - if len(x) < maxlen: - new_X.append(x) - new_labels.append(y) - X = new_X - labels = new_labels - - X = [[1 if w >= nb_words else w for w in x] for x in X] - X_train = X[:int(len(X)*(1-test_split))] - y_train = labels[:int(len(X)*(1-test_split))] - - X_test = X[int(len(X)*(1-test_split)):] - y_test = labels[int(len(X)*(1-test_split)):] - - return (X_train, y_train), (X_test, y_test) - - -if __name__ == "__main__": - make_reuters_dataset() - (X_train, y_train), (X_test, y_test) = load_data() diff --git a/keras/initializations.py b/keras/initializations.py deleted file mode 100644 index 5753c771adbc..000000000000 --- a/keras/initializations.py +++ /dev/null @@ -1,36 +0,0 @@ -import theano -import theano.tensor as T -import numpy as np - -from utils.theano_utils import sharedX - -def uniform(shape, scale=0.05): - return sharedX(np.random.uniform(low=-scale, high=scale, size=shape)) - -def normal(shape, scale=0.05): - return sharedX(np.random.randn(*shape) * scale) - -def lecun_uniform(shape): - ''' Reference: LeCun 98, Efficient Backprop - http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf - ''' - m = 1 - for s in shape: - m *= s - scale = 1./np.sqrt(m) - return uniform(shape, scale) - -def orthogonal(shape, scale=1.1): - ''' From Lasagne - ''' - flat_shape = (shape[0], np.prod(shape[1:])) - a = np.random.normal(0.0, 1.0, flat_shape) - u, _, v = np.linalg.svd(a, full_matrices=False) - q = u if u.shape == flat_shape else v # pick the one with the correct shape - q = q.reshape(shape) - return sharedX(scale * q[:shape[0], :shape[1]]) - - -from utils.generic_utils import get_from_module -def get(identifier): - return get_from_module(identifier, globals(), 'initialization') \ No newline at end of file diff --git a/keras/layers/advanced_activations.py b/keras/layers/advanced_activations.py deleted file mode 100644 index d4210afacb1b..000000000000 --- a/keras/layers/advanced_activations.py +++ /dev/null @@ -1,28 +0,0 @@ -from ..layers.core import Layer -from ..utils.theano_utils import shared_zeros - -class LeakyReLU(Layer): - def __init__(self, alpha=0.3): - self.alpha = alpha - self.params = [] - - def output(self, train): - X = self.get_input(train) - return ((X + abs(X)) / 2.0) + self.alpha * ((X - abs(X)) / 2.0) - - -class PReLU(Layer): - ''' - Reference: - Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification - http://arxiv.org/pdf/1502.01852v1.pdf - ''' - def __init__(self, input_shape): - self.alphas = shared_zeros(input_shape) - self.params = [self.alphas] - - def output(self, train): - X = self.get_input(train) - pos = ((X + abs(X)) / 2.0) - neg = self.alphas * ((X - abs(X)) / 2.0) - return pos + neg diff --git a/keras/layers/convolutional.py b/keras/layers/convolutional.py deleted file mode 100644 index 3f3524a20abc..000000000000 --- a/keras/layers/convolutional.py +++ /dev/null @@ -1,60 +0,0 @@ -# -*- coding: utf-8 -*- -import theano -import theano.tensor as T -from theano.tensor.signal import downsample - -from .. import activations, initializations -from ..utils.theano_utils import shared_zeros -from ..layers.core import Layer - - -# class Convolution1D(Layer): TODO - -# class MaxPooling1D(Layer): TODO - - -class Convolution2D(Layer): - def __init__(self, nb_filter, stack_size, nb_row, nb_col, - init='uniform', activation='linear', weights=None, - image_shape=None, border_mode='valid', subsample=(1,1)): - - self.init = initializations.get(init) - self.activation = activations.get(activation) - self.subsample = subsample - self.border_mode = border_mode - self.image_shape = image_shape - - self.input = T.tensor4() - self.W_shape = (nb_filter, stack_size, nb_row, nb_col) - self.W = self.init(self.W_shape) - self.b = shared_zeros((nb_filter,)) - - self.params = [self.W, self.b] - - if weights is not None: - self.set_weights(weights) - - def output(self, train): - X = self.get_input(train) - - conv_out = theano.tensor.nnet.conv.conv2d(X, self.W, - border_mode=self.border_mode, subsample=self.subsample, image_shape=self.image_shape) - output = self.activation(conv_out + self.b.dimshuffle('x', 0, 'x', 'x')) - return output - - -class MaxPooling2D(Layer): - def __init__(self, poolsize=(2, 2), ignore_border=True): - self.input = T.tensor4() - self.poolsize = poolsize - self.ignore_border = ignore_border - self.params = [] - - def output(self, train): - X = self.get_input(train) - output = downsample.max_pool_2d(X, self.poolsize, ignore_border=self.ignore_border) - return output - - -# class ZeroPadding2D(Layer): TODO - diff --git a/keras/layers/core.py b/keras/layers/core.py deleted file mode 100644 index c525a8b14234..000000000000 --- a/keras/layers/core.py +++ /dev/null @@ -1,165 +0,0 @@ -# -*- coding: utf-8 -*- -import theano -import theano.tensor as T - -from .. import activations, initializations -from ..utils.theano_utils import shared_zeros, floatX -from ..utils.generic_utils import make_tuple - -from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams -srng = RandomStreams() - -class Layer(object): - def connect(self, previous_layer): - self.previous_layer = previous_layer - - def output(self, train): - raise NotImplementedError - - def get_input(self, train): - if hasattr(self, 'previous_layer'): - return self.previous_layer.output(train=train) - else: - return self.input - - def set_weights(self, weights): - for p, w in zip(self.params, weights): - p.set_value(floatX(w)) - - def get_weights(self): - weights = [] - for p in self.params: - weights.append(p.get_value()) - return weights - - -class Dropout(Layer): - ''' - Hinton's dropout. - ''' - def __init__(self, p): - self.p = p - self.params = [] - - def output(self, train): - X = self.get_input(train) - if self.p > 0.: - retain_prob = 1. - self.p - if train: - X *= srng.binomial(X.shape, p=retain_prob, dtype=theano.config.floatX) - else: - X *= retain_prob - return X - - -class Activation(Layer): - ''' - Apply an activation function to an output. - ''' - def __init__(self, activation): - self.activation = activations.get(activation) - self.params = [] - - def output(self, train): - X = self.get_input(train) - return self.activation(X) - - -class Reshape(Layer): - ''' - Reshape an output to a certain shape. - Can't be used as first layer in a model (no fixed input!) - First dimension is assumed to be nb_samples. - ''' - def __init__(self, *dims): - self.dims = dims - self.params = [] - - def output(self, train): - X = self.get_input(train) - nshape = make_tuple(X.shape[0], *self.dims) - return theano.tensor.reshape(X, nshape) - - -class Flatten(Layer): - ''' - Reshape input to flat shape. - First dimension is assumed to be nb_samples. - ''' - def __init__(self, size): - self.size = size - self.params = [] - - def output(self, train): - X = self.get_input(train) - nshape = (X.shape[0], self.size) - return theano.tensor.reshape(X, nshape) - - -class RepeatVector(Layer): - ''' - Repeat input n times. - - Dimensions of input are assumed to be (nb_samples, dim). - Return tensor of shape (nb_samples, n, dim). - ''' - def __init__(self, n): - self.n = n - self.params = [] - - def output(self, train): - X = self.get_input(train) - tensors = [X]*self.n - stacked = theano.tensor.stack(*tensors) - return stacked.dimshuffle((1,0,2)) - - -class Dense(Layer): - ''' - Just your regular fully connected NN layer. - ''' - def __init__(self, input_dim, output_dim, init='uniform', activation='linear', weights=None): - self.init = initializations.get(init) - self.activation = activations.get(activation) - self.input_dim = input_dim - self.output_dim = output_dim - - self.input = T.matrix() - self.W = self.init((self.input_dim, self.output_dim)) - self.b = shared_zeros((self.output_dim)) - - self.params = [self.W, self.b] - - if weights is not None: - self.set_weights(weights) - - def output(self, train): - X = self.get_input(train) - output = self.activation(T.dot(X, self.W) + self.b) - return output - - -class Embedding(Layer): - ''' - Turn a list of integers >=0 into a dense vector of fixed size. - eg. [4, 50, 123, 26] -> [0.25, 0.1] - - @input_dim: size of vocabulary (highest input integer + 1) - @out_dim: size of dense representation - ''' - def __init__(self, input_dim, output_dim, init='uniform', weights=None): - self.init = initializations.get(init) - self.input_dim = input_dim - self.output_dim = output_dim - - self.input = T.imatrix() - self.W = self.init((self.input_dim, self.output_dim)) - self.params = [self.W] - - if weights is not None: - self.set_weights(weights) - - def output(self, train): - X = self.get_input(train) - return self.W[X] - diff --git a/keras/layers/normalization.py b/keras/layers/normalization.py deleted file mode 100644 index abbaf45cbc3f..000000000000 --- a/keras/layers/normalization.py +++ /dev/null @@ -1,27 +0,0 @@ -from ..layers.core import Layer -from ..utils.theano_utils import shared_zeros -from .. import initializations - -class BatchNormalization(Layer): - ''' - Reference: - Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift - http://arxiv.org/pdf/1502.03167v3.pdf - ''' - def __init__(self, input_shape, epsilon=1e-6, weights=None): - self.init = initializations.get("uniform") - self.input_shape = input_shape - self.epsilon = epsilon - - self.gamma = self.init((self.input_shape)) - self.beta = shared_zeros(self.input_shape) - - self.params = [self.gamma, self.beta] - if weights is not None: - self.set_weights(weights) - - def output(self, train): - X = self.get_input(train) - X_normed = (X - X.mean(keepdims=True)) / (X.std(keepdims=True) + self.epsilon) - out = self.gamma * X_normed + self.beta - return out \ No newline at end of file diff --git a/keras/layers/recurrent.py b/keras/layers/recurrent.py deleted file mode 100644 index 5eed96cad58f..000000000000 --- a/keras/layers/recurrent.py +++ /dev/null @@ -1,320 +0,0 @@ -# -*- coding: utf-8 -*- -import theano -import theano.tensor as T -import numpy as np - -from .. import activations, initializations -from ..utils.theano_utils import shared_zeros, alloc_zeros_matrix -from ..layers.core import Layer - -class SimpleRNN(Layer): - ''' - Fully connected RNN where output is to fed back to input. - - Not a particularly useful model, - included for demonstration purposes - (demonstrates how to use theano.scan to build a basic RNN). - ''' - def __init__(self, input_dim, output_dim, - init='uniform', inner_init='orthogonal', activation='sigmoid', weights=None, - truncate_gradient=-1, return_sequences=False): - self.init = initializations.get(init) - self.inner_init = initializations.get(inner_init) - self.input_dim = input_dim - self.output_dim = output_dim - self.truncate_gradient = truncate_gradient - self.activation = activations.get(activation) - self.return_sequences = return_sequences - self.input = T.tensor3() - - self.W = self.init((self.input_dim, self.output_dim)) - self.U = self.init((self.output_dim, self.output_dim)) - self.b = shared_zeros((self.output_dim)) - self.params = [self.W, self.U, self.b] - - if weights is not None: - self.set_weights(weights) - - def _step(self, x_t, h_tm1, u): - ''' - Variable names follow the conventions from: - http://deeplearning.net/software/theano/library/scan.html - - ''' - return self.activation(x_t + T.dot(h_tm1, u)) - - def output(self, train): - X = self.get_input(train) # shape: (nb_samples, time (padded with zeros at the end), input_dim) - # new shape: (time, nb_samples, input_dim) -> because theano.scan iterates over main dimension - X = X.dimshuffle((1,0,2)) - - x = T.dot(X, self.W) + self.b - - # scan = theano symbolic loop. - # See: http://deeplearning.net/software/theano/library/scan.html - # Iterate over the first dimension of the x array (=time). - outputs, updates = theano.scan( - self._step, # this will be called with arguments (sequences[i], outputs[i-1], non_sequences[i]) - sequences=x, # tensors to iterate over, inputs to _step - # initialization of the output. Input to _step with default tap=-1. - outputs_info=alloc_zeros_matrix(X.shape[1], self.output_dim), - non_sequences=self.U, # static inputs to _step - truncate_gradient=self.truncate_gradient - ) - if self.return_sequences: - return outputs.dimshuffle((1,0,2)) - return outputs[-1] - - -class SimpleDeepRNN(Layer): - ''' - Fully connected RNN where the output of multiple timesteps - (up to "depth" steps in the past) is fed back to the input: - - output = activation( W.x_t + b + inner_activation(U_1.h_tm1) + inner_activation(U_2.h_tm2) + ... ) - - This demonstrates how to build RNNs with arbitrary lookback. - Also (probably) not a super useful model. - ''' - def __init__(self, input_dim, output_dim, depth=3, - init='uniform', inner_init='orthogonal', - activation='sigmoid', inner_activation='hard_sigmoid', - weights=None, truncate_gradient=-1, return_sequences=False): - self.init = initializations.get(init) - self.inner_init = initializations.get(inner_init) - self.input_dim = input_dim - self.output_dim = output_dim - self.truncate_gradient = truncate_gradient - self.activation = activations.get(activation) - self.inner_activation = activations.get(inner_activation) - self.depth = depth - self.return_sequences = return_sequences - self.input = T.tensor3() - - self.W = self.init((self.input_dim, self.output_dim)) - self.Us = [self.init((self.output_dim, self.output_dim)) for _ in range(self.depth)] - self.b = shared_zeros((self.output_dim)) - self.params = [self.W] + self.Us + [self.b] - - if weights is not None: - self.set_weights(weights) - - def _step(self, *args): - o = args[0] - for i in range(1, self.depth+1): - o += self.inner_activation(T.dot(args[i], args[i+self.depth])) - return o - - def output(self, train): - X = self.get_input(train) - X = X.dimshuffle((1,0,2)) - - x = T.dot(X, self.W) + self.b - - outputs, updates = theano.scan( - self._step, - sequences=x, - outputs_info=[dict( - initial=T.alloc(np.cast[theano.config.floatX](0.), self.depth, X.shape[1], self.output_dim), - taps = [(-i-1) for i in range(self.depth)] - )], - non_sequences=self.Us, - truncate_gradient=self.truncate_gradient - ) - if self.return_sequences: - return outputs.dimshuffle((1,0,2)) - return outputs[-1] - - - -class GRU(Layer): - ''' - Gated Recurrent Unit - Cho et al. 2014 - - Acts as a spatiotemporal projection, - turning a sequence of vectors into a single vector. - - Eats inputs with shape: - (nb_samples, max_sample_length (samples shorter than this are padded with zeros at the end), input_dim) - - and returns outputs with shape: - if not return_sequences: - (nb_samples, output_dim) - if return_sequences: - (nb_samples, max_sample_length, output_dim) - - References: - On the Properties of Neural Machine Translation: Encoder–Decoder Approaches - http://www.aclweb.org/anthology/W14-4012 - Empirical Evaluation of Gated Recurrent Neural Networks on Sequence Modeling - http://arxiv.org/pdf/1412.3555v1.pdf - ''' - def __init__(self, input_dim, output_dim=128, - init='uniform', inner_init='orthogonal', - activation='sigmoid', inner_activation='hard_sigmoid', - truncate_gradient=-1, weights=None, return_sequences=False): - - self.input_dim = input_dim - self.output_dim = output_dim - self.truncate_gradient = truncate_gradient - self.return_sequences = return_sequences - - self.init = initializations.get(init) - self.inner_init = initializations.get(inner_init) - self.activation = activations.get(activation) - self.inner_activation = activations.get(inner_activation) - self.input = T.tensor3() - - self.W_z = self.init((self.input_dim, self.output_dim)) - self.U_z = self.inner_init((self.output_dim, self.output_dim)) - self.b_z = shared_zeros((self.output_dim)) - - self.W_r = self.init((self.input_dim, self.output_dim)) - self.U_r = self.inner_init((self.output_dim, self.output_dim)) - self.b_r = shared_zeros((self.output_dim)) - - self.W_h = self.init((self.input_dim, self.output_dim)) - self.U_h = self.inner_init((self.output_dim, self.output_dim)) - self.b_h = shared_zeros((self.output_dim)) - - self.params = [ - self.W_z, self.U_z, self.b_z, - self.W_r, self.U_r, self.b_r, - self.W_h, self.U_h, self.b_h, - ] - - if weights is not None: - self.set_weights(weights) - - def _step(self, - xz_t, xr_t, xh_t, - h_tm1, - u_z, u_r, u_h): - z = self.inner_activation(xz_t + T.dot(h_tm1, u_z)) - r = self.inner_activation(xr_t + T.dot(h_tm1, u_r)) - hh_t = self.activation(xh_t + T.dot(r * h_tm1, u_h)) - h_t = z * h_tm1 + (1 - z) * hh_t - return h_t - - def output(self, train): - X = self.get_input(train) - X = X.dimshuffle((1,0,2)) - - x_z = T.dot(X, self.W_z) + self.b_z - x_r = T.dot(X, self.W_r) + self.b_r - x_h = T.dot(X, self.W_h) + self.b_h - outputs, updates = theano.scan( - self._step, - sequences=[x_z, x_r, x_h], - outputs_info=alloc_zeros_matrix(X.shape[1], self.output_dim), - non_sequences=[self.U_z, self.U_r, self.U_h], - truncate_gradient=self.truncate_gradient - ) - if self.return_sequences: - return outputs.dimshuffle((1,0,2)) - return outputs[-1] - - - -class LSTM(Layer): - ''' - Acts as a spatiotemporal projection, - turning a sequence of vectors into a single vector. - - Eats inputs with shape: - (nb_samples, max_sample_length (samples shorter than this are padded with zeros at the end), input_dim) - - and returns outputs with shape: - if not return_sequences: - (nb_samples, output_dim) - if return_sequences: - (nb_samples, max_sample_length, output_dim) - - For a step-by-step description of the algorithm, see: - http://deeplearning.net/tutorial/lstm.html - - References: - Long short-term memory (original 97 paper) - http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf - Learning to forget: Continual prediction with LSTM - http://www.mitpressjournals.org/doi/pdf/10.1162/089976600300015015 - Supervised sequence labelling with recurrent neural networks - http://www.cs.toronto.edu/~graves/preprint.pdf - ''' - def __init__(self, input_dim, output_dim=128, - init='uniform', inner_init='orthogonal', - activation='tanh', inner_activation='hard_sigmoid', - truncate_gradient=-1, weights=None, return_sequences=False): - - self.input_dim = input_dim - self.output_dim = output_dim - self.truncate_gradient = truncate_gradient - self.return_sequences = return_sequences - - self.init = initializations.get(init) - self.inner_init = initializations.get(inner_init) - self.activation = activations.get(activation) - self.inner_activation = activations.get(inner_activation) - self.input = T.tensor3() - - self.W_i = self.init((self.input_dim, self.output_dim)) - self.U_i = self.inner_init((self.output_dim, self.output_dim)) - self.b_i = shared_zeros((self.output_dim)) - - self.W_f = self.init((self.input_dim, self.output_dim)) - self.U_f = self.inner_init((self.output_dim, self.output_dim)) - self.b_f = shared_zeros((self.output_dim)) - - self.W_c = self.init((self.input_dim, self.output_dim)) - self.U_c = self.inner_init((self.output_dim, self.output_dim)) - self.b_c = shared_zeros((self.output_dim)) - - self.W_o = self.init((self.input_dim, self.output_dim)) - self.U_o = self.inner_init((self.output_dim, self.output_dim)) - self.b_o = shared_zeros((self.output_dim)) - - self.params = [ - self.W_i, self.U_i, self.b_i, - self.W_c, self.U_c, self.b_c, - self.W_f, self.U_f, self.b_f, - self.W_o, self.U_o, self.b_o, - ] - - if weights is not None: - self.set_weights(weights) - - def _step(self, - xi_t, xf_t, xo_t, xc_t, - h_tm1, c_tm1, - u_i, u_f, u_o, u_c): - i_t = self.inner_activation(xi_t + T.dot(h_tm1, u_i)) - f_t = self.inner_activation(xf_t + T.dot(h_tm1, u_f)) - c_t = f_t * c_tm1 + i_t * self.activation(xc_t + T.dot(h_tm1, u_c)) - o_t = self.inner_activation(xo_t + T.dot(h_tm1, u_o)) - h_t = o_t * self.activation(c_t) - return h_t, c_t - - def output(self, train): - X = self.get_input(train) - X = X.dimshuffle((1,0,2)) - - xi = T.dot(X, self.W_i) + self.b_i - xf = T.dot(X, self.W_f) + self.b_f - xc = T.dot(X, self.W_c) + self.b_c - xo = T.dot(X, self.W_o) + self.b_o - - [outputs, memories], updates = theano.scan( - self._step, - sequences=[xi, xf, xo, xc], - outputs_info=[ - alloc_zeros_matrix(X.shape[1], self.output_dim), - alloc_zeros_matrix(X.shape[1], self.output_dim) - ], - non_sequences=[self.U_i, self.U_f, self.U_o, self.U_c], - truncate_gradient=self.truncate_gradient - ) - if self.return_sequences: - return outputs.dimshuffle((1,0,2)) - return outputs[-1] - - diff --git a/keras/models.py b/keras/models.py deleted file mode 100644 index 727960df2215..000000000000 --- a/keras/models.py +++ /dev/null @@ -1,112 +0,0 @@ -import theano -import theano.tensor as T -import numpy as np - -import optimizers -import objectives -import time, copy -from utils.generic_utils import Progbar - -def standardize_y(y): - y = np.asarray(y) - if len(y.shape) == 1: - y = np.reshape(y, (len(y), 1)) - return y - -class Sequential(object): - def __init__(self): - self.layers = [] - self.params = [] - - def add(self, layer): - self.layers.append(layer) - if len(self.layers) > 1: - self.layers[-1].connect(self.layers[-2]) - self.params += [p for p in layer.params] - - def compile(self, optimizer, loss): - self.optimizer = optimizers.get(optimizer) - self.loss = objectives.get(loss) - - self.X = self.layers[0].input # input of model - # (first layer must have an "input" attribute!) - self.y_train = self.layers[-1].output(train=True) - self.y_test = self.layers[-1].output(train=False) - - # output of model - self.Y = T.matrix() # TODO: support for custom output shapes - - train_loss = self.loss(self.Y, self.y_train) - test_score = self.loss(self.Y, self.y_test) - updates = self.optimizer.get_updates(self.params, train_loss) - - self._train = theano.function([self.X, self.Y], train_loss, - updates=updates, allow_input_downcast=True) - self._predict = theano.function([self.X], self.y_test, - allow_input_downcast=True) - self._test = theano.function([self.X, self.Y], test_score, - allow_input_downcast=True) - - def train(self, X, y): - y = standardize_y(y) - loss = self._train(X, y) - return loss - - def test(self, X, y): - y = standardize_y(y) - score = self._test(X, y) - return score - - def fit(self, X, y, batch_size=128, nb_epoch=100, verbose=1, shuffle=True): - y = standardize_y(y) - index_array = np.arange(len(X)) - for epoch in range(nb_epoch): - if verbose: - print 'Epoch', epoch - if shuffle: - np.random.shuffle(index_array) - nb_batch = len(X)/batch_size+1 - progbar = Progbar(target=len(X)) - for batch_index in range(0, nb_batch): - batch = range(batch_index*batch_size, min(len(X), (batch_index+1)*batch_size)) - if not batch: - break - prog = batch[-1]+1 - batch = index_array[batch] - loss = self._train(X[batch], y[batch]) - if verbose: - progbar.update(prog, [('loss', loss)]) - - def predict_proba(self, X, batch_size=128): - for batch_index in range(0, len(X)/batch_size+1): - batch = range(batch_index*batch_size, min(len(X), (batch_index+1)*batch_size)) - if not batch: - break - batch_preds = self._predict(X[batch]) - - if batch_index == 0: - preds = np.zeros((len(X), batch_preds.shape[1])) - preds[batch] = batch_preds - return preds - - def predict_classes(self, X, batch_size=128): - proba = self.predict_proba(X, batch_size=batch_size) - if proba.shape[1] > 1: - return proba.argmax(axis=1) - else: - return np.array([1 if p > 0.5 else 0 for p in proba]) - - def evaluate(self, X, y, batch_size=128): - y = standardize_y(y) - av_score = 0. - samples = 0 - for batch_index in range(0, len(X)/batch_size+1): - batch = range(batch_index*batch_size, min(len(X), (batch_index+1)*batch_size)) - if not batch: - break - score = self._test(X[batch], y[batch]) - av_score += len(batch)*score - samples += len(batch) - return av_score/samples - - diff --git a/keras/objectives.py b/keras/objectives.py deleted file mode 100644 index fadccef1ff7a..000000000000 --- a/keras/objectives.py +++ /dev/null @@ -1,47 +0,0 @@ -import theano -import theano.tensor as T -import numpy as np - -epsilon = 1.0e-15 - -def mean_squared_error(y_true, y_pred): - return T.sqr(y_pred - y_true).mean() - -def mean_absolute_error(y_true, y_pred): - return T.abs_(y_pred - y_true).mean() - -def squared_hinge(y_true, y_pred): - return T.sqr(T.maximum(1. - y_true * y_pred, 0.)).mean() - -def hinge(y_true, y_pred): - return T.maximum(1. - y_true * y_pred, 0.).mean() - -def categorical_crossentropy(y_true, y_pred): - '''Expects a binary class matrix instead of a vector of scalar classes - ''' - y_pred = T.clip(y_pred, epsilon, 1.0 - epsilon) - # scale preds so that the class probas of each sample sum to 1 - y_pred /= y_pred.sum(axis=1, keepdims=True) - return T.nnet.categorical_crossentropy(y_pred, y_true).mean() - -def binary_crossentropy(y_true, y_pred): - y_pred = T.clip(y_pred, epsilon, 1.0 - epsilon) - return T.nnet.binary_crossentropy(y_pred, y_true).mean() - -# aliases -mse = MSE = mean_squared_error -mae = MAE = mean_absolute_error - -from utils.generic_utils import get_from_module -def get(identifier): - return get_from_module(identifier, globals(), 'objective') - -def to_categorical(y): - '''Convert class vector (integers from 0 to nb_classes) - to binary class matrix, for use with categorical_crossentropy - ''' - nb_classes = np.max(y)+1 - Y = np.zeros((len(y), nb_classes)) - for i in range(len(y)): - Y[i, y[i]] = 1. - return Y diff --git a/keras/optimizers.py b/keras/optimizers.py deleted file mode 100644 index 4b963050fd59..000000000000 --- a/keras/optimizers.py +++ /dev/null @@ -1,136 +0,0 @@ -import theano -import theano.tensor as T -import numpy as np - -from utils.theano_utils import shared_zeros, shared_scalar - -def clip_norm(g, c, n): - if c > 0: - g = T.switch(T.ge(n, c), g*c/n, g) - return g - -class Optimizer(object): - def get_updates(self, params, grads): - raise NotImplementedError - - def get_gradients(self, cost, params): - grads = T.grad(cost, params) - - if hasattr(self, 'clipnorm') and self.clipnorm > 0: - norm = T.sqrt(sum([T.sum(g**2) for g in grads])) - grads = [clip_norm(g, c, norm) for g in grads] - - new_grads = [] - for p, g in zip(params, grads): - if hasattr(self, 'l1') and self.l1 > 0: - g += T.sgn(p) * self.l1 - - if hasattr(self, 'l2') and self.l2 > 0: - g += p * self.l2 - - if hasattr(self, 'maxnorm') and self.maxnorm > 0: - norms = T.sqrt(T.sum(T.sqr(p), axis=0)) - desired = T.clip(norms, 0, self.maxnorm) - p = p * (desired / (1e-7 + norms)) - - new_grads.append(g) - return new_grads - - -class SGD(Optimizer): - - def __init__(self, lr=0.01, momentum=0., decay=0., nesterov=False, *args, **kwargs): - self.__dict__.update(locals()) - self.iterations = shared_scalar(0) - - def get_updates(self, params, cost): - grads = self.get_gradients(cost, params) - lr = self.lr - self.decay * self.iterations - updates = [(self.iterations, self.iterations+1.)] - - for p, g in zip(params, grads): - m = shared_zeros(p.get_value().shape) # momentum - v = self.momentum * m - lr * g # velocity - updates.append((m, v)) - - if self.nesterov: - new_p = p + self.momentum * v - lr * g - else: - new_p = p + v - updates.append((p, new_p)) - return updates - - -class RMSprop(Optimizer): - - def __init__(self, lr=0.001, rho=0.9, epsilon=1e-6, *args, **kwargs): - self.__dict__.update(locals()) - - def get_updates(self, params, cost): - grads = self.get_gradients(cost, params) - accumulators = [shared_zeros(p.get_value().shape) for p in params] - updates = [] - - for p, g, a in zip(params, grads, accumulators): - new_a = self.rho * a + (1 - self.rho) * g ** 2 # update accumulator - updates.append((a, new_a)) - - new_p = p - self.lr * g / T.sqrt(new_a + self.epsilon) - updates.append((p, new_p)) - return updates - - -class Adagrad(Optimizer): - - def __init__(self, lr=0.01, epsilon=1e-6, *args, **kwargs): - self.__dict__.update(locals()) - - def get_updates(self, params, cost): - grads = self.get_gradients(cost, params) - accumulators = [shared_zeros(p.get_value().shape) for p in params] - updates = [] - - for p, g, a in zip(params, grads, accumulators): - new_a = a + g ** 2 # update accumulator - updates.append((a, new_a)) - - new_p = p - self.lr * g / T.sqrt(new_a + self.epsilon) - updates.append((p, new_p)) - return updates - - -class Adadelta(Optimizer): - - def __init__(self, lr=1.0, rho=0.95, epsilon=1e-6, *args, **kwargs): - self.__dict__.update(locals()) - - def get_updates(self, params, cost): - grads = self.get_gradients(cost, params) - accumulators = [shared_zeros(p.get_value().shape) for p in params] - delta_accumulators = [shared_zeros(p.get_value().shape) for p in params] - updates = [] - - for p, g, a, d_a in zip(params, grads, accumulators, delta_accumulators): - new_a = self.rho * a + (1 - self.rho) * g ** 2 # update accumulator - updates.append((a, new_a)) - - # use the new accumulator and the *old* delta_accumulator - update = g * T.sqrt(d_a + self.epsilon) / T.sqrt(new_a + self.epsilon) - - new_p = p - self.lr * update - updates.append((p, new_p)) - - # update delta_accumulator - new_d_a = self.rho * d_a + (1 - self.rho) * update ** 2 - updates.append((d_a, new_d_a)) - return updates - -# aliases -sgd = SGD -rmsprop = RMSprop -adagrad = Adagrad -adadelta = Adadelta - -from utils.generic_utils import get_from_module -def get(identifier): - return get_from_module(identifier, globals(), 'optimizer', instantiate=True) diff --git a/keras/preprocessing/image.py b/keras/preprocessing/image.py deleted file mode 100644 index 6d3540e8880a..000000000000 --- a/keras/preprocessing/image.py +++ /dev/null @@ -1,238 +0,0 @@ -from PIL import Image -import numpy as np -from scipy import ndimage -from scipy import linalg - -from os import listdir -from os.path import isfile, join -import random, math - -''' - Fairly basic set of tools for realtime data augmentation on image data. - Can easily be extended to include new transforms, new preprocessing methods, etc... -''' - -def random_rotation(x, rg, fill_mode="nearest", cval=0.): - angle = random.uniform(-rg, rg) - x = ndimage.interpolation.rotate(x, angle, axes=(1,2), reshape=False, mode=fill_mode, cval=cval) - return x - -def random_shift(x, wrg, hrg, fill_mode="nearest", cval=0.): - crop_left_pixels = 0 - crop_right_pixels = 0 - crop_top_pixels = 0 - crop_bottom_pixels = 0 - - original_w = x.shape[1] - original_h = x.shape[2] - - if wrg: - crop = random.uniform(0., wrg) - split = random.uniform(0, 1) - crop_left_pixels = int(split*crop*x.shape[1]) - crop_right_pixels = int((1-split)*crop*x.shape[1]) - - if hrg: - crop = random.uniform(0., hrg) - split = random.uniform(0, 1) - crop_top_pixels = int(split*crop*x.shape[2]) - crop_bottom_pixels = int((1-split)*crop*x.shape[2]) - - x = ndimage.interpolation.shift(x, (0, crop_left_pixels, crop_top_pixels), mode=fill_mode, cval=cval) - return x - -def horizontal_flip(x): - for i in range(x.shape[0]): - x[i] = np.fliplr(x[i]) - return x - -def vertical_flip(x): - for i in range(x.shape[0]): - x[i] = np.flipud(x[i]) - return x - - -def random_barrel_transform(x, intensity): - # TODO - pass - -def random_shear(x, intensity): - # TODO - pass - -def random_channel_shift(x, rg): - # TODO - pass - -def random_zoom(x, rg, fill_mode="nearest", cval=0.): - zoom_w = random.uniform(1.-rg, 1.) - zoom_h = random.uniform(1.-rg, 1.) - x = ndimage.interpolation.zoom(x, zoom=(1., zoom_w, zoom_h), mode=fill_mode, cval=cval) - return x # shape of result will be different from shape of input! - - - - -def array_to_img(x, scale=True): - x = x.transpose(1, 2, 0) - if scale: - x += max(-np.min(x), 0) - x /= np.max(x) - x *= 255 - if x.shape[2] == 3: - # RGB - return Image.fromarray(x.astype("uint8"), "RGB") - else: - # grayscale - return Image.fromarray(x.astype("uint8"), "L") - - -def img_to_array(img): - x = np.asarray(img, dtype='float32') - return x.transpose(2, 0, 1) - - -def load_img(path, grayscale=False): - img = Image.open(open(path)) - if grayscale: - img = img.convert('L') - return img - - -def list_pictures(directory, ext='jpg|jpeg|bmp|png'): - return [join(directory,f) for f in listdir(directory) \ - if isfile(join(directory,f)) and re.match('([\w]+\.(?:' + ext + '))', f)] - - - -class ImageDataGenerator(object): - ''' - Generate minibatches with - realtime data augmentation. - ''' - def __init__(self, - featurewise_center=True, # set input mean to 0 over the dataset - samplewise_center=False, # set each sample mean to 0 - featurewise_std_normalization=True, # divide inputs by std of the dataset - samplewise_std_normalization=False, # divide each input by its std - - zca_whitening=False, # apply ZCA whitening - rotation_range=0., # degrees (0 to 180) - width_shift_range=0., # fraction of total width - height_shift_range=0., # fraction of total height - horizontal_flip=False, - vertical_flip=False, - ): - self.__dict__.update(locals()) - self.mean = None - self.std = None - self.principal_components = None - - - def flow(self, X, y, batch_size=32, shuffle=False, seed=None, save_to_dir=None, save_prefix="", save_format="jpeg"): - if seed: - random.seed(seed) - - if shuffle: - seed = random.randint(1, 10e6) - np.random.seed(seed) - np.random.shuffle(X) - np.random.seed(seed) - np.random.shuffle(y) - - nb_batch = int(math.ceil(float(X.shape[0])/batch_size)) - for b in range(nb_batch): - batch_end = (b+1)*batch_size - if batch_end > X.shape[0]: - nb_samples = X.shape[0] - b*batch_size - else: - nb_samples = batch_size - - bX = np.zeros(tuple([nb_samples]+list(X.shape)[1:])) - for i in range(nb_samples): - x = X[b*batch_size+i] - x = self.random_transform(x.astype("float32")) - x = self.standardize(x) - bX[i] = x - - if save_to_dir: - for i in range(nb_samples): - img = array_to_img(bX[i], scale=True) - img.save(save_to_dir + "/" + save_prefix + "_" + str(i) + "." + save_format) - - yield bX, y[b*batch_size:b*batch_size+nb_samples] - - - def standardize(self, x): - if self.featurewise_center: - x -= self.mean - if self.featurewise_std_normalization: - x /= self.std - - if self.zca_whitening: - flatx = np.reshape(x, (x.shape[0]*x.shape[1]*x.shape[2])) - whitex = np.dot(flatx, self.principal_components) - x = np.reshape(whitex, (x.shape[0], x.shape[1], x.shape[2])) - - if self.samplewise_center: - x -= np.mean(x) - if self.samplewise_std_normalization: - x /= np.std(x) - - return x - - - def random_transform(self, x): - if self.rotation_range: - x = random_rotation(x, self.rotation_range) - if self.width_shift_range or self.height_shift_range: - x = random_shift(x, self.width_shift_range, self.height_shift_range) - if self.horizontal_flip: - if random.random() < 0.5: - x = horizontal_flip(x) - if self.vertical_flip: - if random.random() < 0.5: - x = vertical_flip(x) - - # TODO: - # zoom - # barrel/fisheye - # shearing - # channel shifting - return x - - - def fit(self, X, - augment=False, # fit on randomly augmented samples - rounds=1, # if augment, how many augmentation passes over the data do we use - seed=None - ): - ''' - Required for featurewise_center, featurewise_std_normalization and zca_whitening. - ''' - X = np.copy(X) - - if augment: - aX = np.zeros(tuple([rounds*X.shape[0]]+list(X.shape)[1:])) - for r in range(rounds): - for i in range(X.shape[0]): - img = array_to_img(X[i]) - img = self.random_transform(img) - aX[i+r*X.shape[0]] = img_to_array(img) - X = aX - - if self.featurewise_center: - self.mean = np.mean(X, axis=0) - X -= self.mean - if self.featurewise_std_normalization: - self.std = np.std(X) - X /= self.std - - if self.zca_whitening: - flatX = np.reshape(X, (X.shape[0], X.shape[1]*X.shape[2]*X.shape[3])) - fudge = 10e-6 - sigma = np.dot(flatX.T, flatX) / flatX.shape[1] - U, S, V = linalg.svd(sigma) - self.principal_components = np.dot(np.dot(U, np.diag(1. / np.sqrt(S + fudge))), U.T) - - diff --git a/keras/preprocessing/sequence.py b/keras/preprocessing/sequence.py deleted file mode 100644 index f71d01f2b0c1..000000000000 --- a/keras/preprocessing/sequence.py +++ /dev/null @@ -1,21 +0,0 @@ -import numpy as np - -def pad_sequences(seqs, maxlen=None, dtype='int32'): - """ - Pad each sequence to the same lenght: - the lenght of the longuest sequence. - - If maxlen is provided, any sequence longer - than maxlen is truncated to maxlen. - """ - lengths = [len(s) for s in seqs] - - nb_samples = len(seqs) - if maxlen is None: - maxlen = np.max(lengths) - - x = np.zeros((nb_samples, maxlen)).astype(dtype) - for idx, s in enumerate(seqs): - x[idx, :lengths[idx]] = s[:maxlen] - - return x \ No newline at end of file diff --git a/keras/preprocessing/text.py b/keras/preprocessing/text.py deleted file mode 100644 index f56ae8e96a87..000000000000 --- a/keras/preprocessing/text.py +++ /dev/null @@ -1,158 +0,0 @@ -# -*- coding: utf-8 -*- -''' - These preprocessing utils would greatly benefit - from a fast Cython rewrite. -''' - -import string -import numpy as np - -def base_filter(): - f = string.punctuation - f += '\t\n' - return f - -def text_to_word_sequence(text, filters=base_filter(), lower=True, split=" "): - '''prune: sequence of characters to filter out - ''' - if lower: - text = text.lower() - text = text.translate(string.maketrans("",""), filters) - return text.split(split) - - -def one_hot(text, n): - seq = text_to_word_sequence(text) - return [abs(hash(w))%n for w in seq] - - -class Tokenizer(object): - def __init__(self, filters=base_filter(), lower=True, nb_words=None): - self.word_counts = {} - self.word_docs = {} - self.filters = filters - self.lower = lower - self.nb_words = nb_words - self.document_count = 0 - - def fit_on_texts(self, texts): - ''' - required before using texts_to_sequences or texts_to_matrix - ''' - for text in texts: - seq = text_to_word_sequence(text, self.filters, self.lower) - for w in seq: - if w in self.word_counts: - self.word_counts[w] += 1 - else: - self.word_counts[w] = 1 - for w in set(seq): - if w in self.word_docs: - self.word_docs[w] += 1 - else: - self.word_docs[w] = 1 - self.document_count = len(texts) - - wcounts = self.word_counts.items() - wcounts.sort(key = lambda x: x[1], reverse=True) - sorted_voc = [wc[0] for wc in wcounts] - self.word_index = dict(zip(sorted_voc, range(len(sorted_voc)))) - - self.index_docs = {} - for w, c in self.word_docs.items(): - self.index_docs[self.word_index[w]] = c - - - def fit_on_sequences(self, sequences): - ''' - required before using sequences_to_matrix - (if fit_on_texts was never called) - ''' - self.document_count = len(sequences) - self.index_docs = {} - for seq in sequences: - seq = set(seq) - for i in seq: - if i not in self.index_docs: - self.index_docs[i] = 1 - else: - self.index_docs[i] += 1 - - - def texts_to_sequences(self, texts): - ''' - Transform each text in texts in a sequence of integers. - Only top "nb_words" most frequent words will be taken into account. - Only words know by the tokenizer will be taken into account. - ''' - nb_words = self.nb_words - res = [] - for text in texts: - seq = text_to_word_sequence(text, self.filters, self.lower) - vect = [] - for w in seq: - i = self.word_index.get(w) - if i is not None: - if nb_words and i >= nb_words: - pass - else: - vect.append(i) - res.append(vect) - return res - - def texts_to_matrix(self, texts, mode="binary"): - ''' - modes: binary, count, tfidf, freq - ''' - sequences = self.texts_to_sequences(texts) - return self.sequences_to_matrix(sequences, mode=mode) - - def sequences_to_matrix(self, sequences, mode="binary"): - ''' - modes: binary, count, tfidf, freq - ''' - if not self.nb_words: - if self.word_index: - nb_words = len(self.word_index) - else: - raise Exception("Specify a dimension (nb_words argument), or fit on some text data first") - else: - nb_words = self.nb_words - - if mode == "tfidf" and not self.document_count: - raise Exception("Fit the Tokenizer on some data before using tfidf mode") - - X = np.zeros((len(sequences), nb_words)) - for i, seq in enumerate(sequences): - if not seq: - pass - counts = {} - for j in seq: - if j >= nb_words: - pass - if j not in counts: - counts[j] = 1. - else: - counts[j] += 1 - for j, c in counts.items(): - if mode == "count": - X[i][j] = c - elif mode == "freq": - X[i][j] = c/len(seq) - elif mode == "binary": - X[i][j] = 1 - elif mode == "tfidf": - tf = np.log(c/len(seq)) - df = (1 + np.log(1 + self.index_docs.get(j, 0)/(1 + self.document_count))) - X[i][j] = tf / df - else: - raise Exception("Unknown vectorization mode: " + str(mode)) - return X - - - - - - - - diff --git a/keras/src/__init__.py b/keras/src/__init__.py new file mode 100644 index 000000000000..9778bcd4d63a --- /dev/null +++ b/keras/src/__init__.py @@ -0,0 +1,20 @@ +from keras.src import activations +from keras.src import applications +from keras.src import backend +from keras.src import constraints +from keras.src import datasets +from keras.src import initializers +from keras.src import layers +from keras.src import models +from keras.src import ops +from keras.src import optimizers +from keras.src import regularizers +from keras.src import utils +from keras.src import visualization +from keras.src.backend import KerasTensor +from keras.src.layers import Input +from keras.src.layers import Layer +from keras.src.models import Functional +from keras.src.models import Model +from keras.src.models import Sequential +from keras.src.version import __version__ diff --git a/keras/src/activations/__init__.py b/keras/src/activations/__init__.py new file mode 100644 index 000000000000..e1a4184afa7e --- /dev/null +++ b/keras/src/activations/__init__.py @@ -0,0 +1,130 @@ +import types + +from keras.src.activations.activations import celu +from keras.src.activations.activations import elu +from keras.src.activations.activations import exponential +from keras.src.activations.activations import gelu +from keras.src.activations.activations import glu +from keras.src.activations.activations import hard_shrink +from keras.src.activations.activations import hard_sigmoid +from keras.src.activations.activations import hard_silu +from keras.src.activations.activations import hard_tanh +from keras.src.activations.activations import leaky_relu +from keras.src.activations.activations import linear +from keras.src.activations.activations import log_sigmoid +from keras.src.activations.activations import log_softmax +from keras.src.activations.activations import mish +from keras.src.activations.activations import relu +from keras.src.activations.activations import relu6 +from keras.src.activations.activations import selu +from keras.src.activations.activations import sigmoid +from keras.src.activations.activations import silu +from keras.src.activations.activations import soft_shrink +from keras.src.activations.activations import softmax +from keras.src.activations.activations import softplus +from keras.src.activations.activations import softsign +from keras.src.activations.activations import sparse_plus +from keras.src.activations.activations import sparse_sigmoid +from keras.src.activations.activations import sparsemax +from keras.src.activations.activations import squareplus +from keras.src.activations.activations import tanh +from keras.src.activations.activations import tanh_shrink +from keras.src.activations.activations import threshold +from keras.src.api_export import keras_export +from keras.src.saving import object_registration +from keras.src.saving import serialization_lib + +ALL_OBJECTS = { + relu, + leaky_relu, + relu6, + softmax, + celu, + elu, + selu, + softplus, + softsign, + squareplus, + soft_shrink, + sparse_plus, + silu, + gelu, + glu, + tanh, + tanh_shrink, + threshold, + sigmoid, + sparse_sigmoid, + exponential, + hard_sigmoid, + hard_silu, + hard_tanh, + hard_shrink, + linear, + mish, + log_softmax, + log_sigmoid, + sparsemax, +} + +ALL_OBJECTS_DICT = {fn.__name__: fn for fn in ALL_OBJECTS} +# Additional aliases +ALL_OBJECTS_DICT["swish"] = silu +ALL_OBJECTS_DICT["hard_swish"] = hard_silu + + +@keras_export("keras.activations.serialize") +def serialize(activation): + fn_config = serialization_lib.serialize_keras_object(activation) + if "config" not in fn_config: + raise ValueError( + f"Unknown activation function '{activation}' cannot be " + "serialized due to invalid function name. Make sure to use " + "an activation name that matches the references defined in " + "activations.py or use " + "`@keras.saving.register_keras_serializable()`" + "to register any custom activations. " + f"config={fn_config}" + ) + if not isinstance(activation, types.FunctionType): + # Case for additional custom activations represented by objects + return fn_config + if ( + isinstance(fn_config["config"], str) + and fn_config["config"] not in globals() + ): + # Case for custom activation functions from external activations modules + fn_config["config"] = object_registration.get_registered_name( + activation + ) + return fn_config + # Case for keras.activations builtins (simply return name) + return fn_config["config"] + + +@keras_export("keras.activations.deserialize") +def deserialize(config, custom_objects=None): + """Return a Keras activation function via its config.""" + return serialization_lib.deserialize_keras_object( + config, + module_objects=ALL_OBJECTS_DICT, + custom_objects=custom_objects, + ) + + +@keras_export("keras.activations.get") +def get(identifier): + """Retrieve a Keras activation function via an identifier.""" + if identifier is None: + return linear + if isinstance(identifier, dict): + obj = serialization_lib.deserialize_keras_object(identifier) + elif isinstance(identifier, str): + obj = ALL_OBJECTS_DICT.get(identifier, None) + else: + obj = identifier + if callable(obj): + return obj + raise ValueError( + f"Could not interpret activation function identifier: {identifier}" + ) diff --git a/keras/src/activations/activations.py b/keras/src/activations/activations.py new file mode 100644 index 000000000000..889ba3d9baae --- /dev/null +++ b/keras/src/activations/activations.py @@ -0,0 +1,684 @@ +from keras.src import backend +from keras.src import ops +from keras.src.api_export import keras_export + + +@keras_export("keras.activations.relu") +def relu(x, negative_slope=0.0, max_value=None, threshold=0.0): + """Applies the rectified linear unit activation function. + + With default values, this returns the standard ReLU activation: + `max(x, 0)`, the element-wise maximum of 0 and the input tensor. + + Modifying default parameters allows you to use non-zero thresholds, + change the max value of the activation, + and to use a non-zero multiple of the input for values below the threshold. + + Examples: + + >>> x = [-10, -5, 0.0, 5, 10] + >>> keras.activations.relu(x) + [ 0., 0., 0., 5., 10.] + >>> keras.activations.relu(x, negative_slope=0.5) + [-5. , -2.5, 0. , 5. , 10. ] + >>> keras.activations.relu(x, max_value=5.) + [0., 0., 0., 5., 5.] + >>> keras.activations.relu(x, threshold=5.) + [-0., -0., 0., 0., 10.] + + Args: + x: Input tensor. + negative_slope: A `float` that controls the slope + for values lower than the threshold. + max_value: A `float` that sets the saturation threshold (the largest + value the function will return). + threshold: A `float` giving the threshold value of the activation + function below which values will be damped or set to zero. + + Returns: + A tensor with the same shape and dtype as input `x`. + """ + if backend.any_symbolic_tensors((x,)): + return ReLU( + negative_slope=negative_slope, + max_value=max_value, + threshold=threshold, + )(x) + return ReLU.static_call( + x, + negative_slope=negative_slope, + max_value=max_value, + threshold=threshold, + ) + + +class ReLU(ops.Operation): + def __init__( + self, negative_slope=0.0, max_value=None, threshold=0.0, name=None + ): + super().__init__(name=name) + self.negative_slope = negative_slope + self.max_value = max_value + self.threshold = threshold + + def call(self, x): + return self.static_call( + x, + negative_slope=self.negative_slope, + max_value=self.max_value, + threshold=self.threshold, + ) + + def compute_output_spec(self, x): + return backend.KerasTensor(x.shape, x.dtype) + + @staticmethod + def static_call(x, negative_slope=0.0, max_value=None, threshold=0.0): + x = backend.convert_to_tensor(x) + if negative_slope != 0.0: + if max_value is None and threshold == 0: + return backend.nn.leaky_relu(x, negative_slope=negative_slope) + + if threshold != 0: + negative_part = backend.nn.relu(-x + threshold) + else: + negative_part = backend.nn.relu(-x) + else: + negative_part = 1 + + clip_max = max_value is not None + if threshold != 0: + # computes x for x > threshold else 0 + threshold = ops.cast(threshold, dtype=x.dtype) + x = x * backend.cast( + backend.numpy.greater(x, threshold), dtype=x.dtype + ) + elif max_value == 6: + # if no threshold, then can use nn.relu6 native op for performance + x = backend.nn.relu6(x) + clip_max = False + else: + x = backend.nn.relu(x) + + if clip_max: + min_value = ops.cast(0.0, dtype=x.dtype) + max_value = ops.cast(max_value, dtype=x.dtype) + x = backend.numpy.clip(x, min_value, max_value) + + if negative_slope != 0.0: + x -= negative_slope * negative_part + return x + + +@keras_export("keras.activations.leaky_relu") +def leaky_relu(x, negative_slope=0.2): + """Leaky relu activation function. + + Args: + x: Input tensor. + negative_slope: A `float` that controls the slope + for values lower than the threshold. + """ + return ops.leaky_relu(x, negative_slope=negative_slope) + + +@keras_export("keras.activations.relu6") +def relu6(x): + """Relu6 activation function. + + It's the ReLU function, but truncated to a maximum value of 6. + + Args: + x: Input tensor. + """ + return ops.relu6(x) + + +@keras_export("keras.activations.softmax") +def softmax(x, axis=-1): + """Softmax converts a vector of values to a probability distribution. + + The elements of the output vector are in range `[0, 1]` and sum to 1. + + Each input vector is handled independently. + The `axis` argument sets which axis of the input the function + is applied along. + + Softmax is often used as the activation for the last + layer of a classification network because the result could be interpreted as + a probability distribution. + + The softmax of each vector x is computed as + `exp(x) / sum(exp(x))`. + + The input values in are the log-odds of the resulting probability. + + Args: + x: Input tensor. + axis: Integer, axis along which the softmax is applied. + """ + output = ops.softmax(x, axis=axis) + # Cache the logits to use for crossentropy loss. + try: + output._keras_logits = x + except AttributeError: + # We're dealing with a C-type. + pass + return output + + +@keras_export("keras.activations.elu") +def elu(x, alpha=1.0): + """Exponential Linear Unit. + + The exponential linear unit (ELU) with `alpha > 0` is defined as: + + - `x` if `x > 0` + - alpha * `exp(x) - 1` if `x < 0` + + ELUs have negative values which pushes the mean of the activations + closer to zero. + + Mean activations that are closer to zero enable faster learning as they + bring the gradient closer to the natural gradient. + ELUs saturate to a negative value when the argument gets smaller. + Saturation means a small derivative which decreases the variation + and the information that is propagated to the next layer. + + Args: + x: Input tensor. + alpha: A scalar, slope of positive section. Defaults to `1.0`. + + Reference: + + - [Clevert et al., 2016](https://arxiv.org/abs/1511.07289) + """ + return ops.elu(x, alpha=alpha) + + +@keras_export("keras.activations.selu") +def selu(x): + """Scaled Exponential Linear Unit (SELU). + + The Scaled Exponential Linear Unit (SELU) activation function is defined as: + + - `scale * x` if `x > 0` + - `scale * alpha * (exp(x) - 1)` if `x < 0` + + where `alpha` and `scale` are pre-defined constants + (`alpha=1.67326324` and `scale=1.05070098`). + + Basically, the SELU activation function multiplies `scale` (> 1) with the + output of the `keras.activations.elu` function to ensure a slope larger + than one for positive inputs. + + The values of `alpha` and `scale` are + chosen so that the mean and variance of the inputs are preserved + between two consecutive layers as long as the weights are initialized + correctly (see `keras.initializers.LecunNormal` initializer) + and the number of input units is "large enough" + (see reference paper for more information). + + Args: + x: Input tensor. + + Notes: + + - To be used together with the + `keras.initializers.LecunNormal` initializer. + - To be used together with the dropout variant + `keras.layers.AlphaDropout` (rather than regular dropout). + + Reference: + + - [Klambauer et al., 2017](https://arxiv.org/abs/1706.02515) + """ + return ops.selu(x) + + +@keras_export("keras.activations.softplus") +def softplus(x): + """Softplus activation function. + + It is defined as: `softplus(x) = log(exp(x) + 1)`. + + Args: + x: Input tensor. + """ + return ops.softplus(x) + + +@keras_export("keras.activations.softsign") +def softsign(x): + """Softsign activation function. + + Softsign is defined as: `softsign(x) = x / (abs(x) + 1)`. + + Args: + x: Input tensor. + """ + return ops.softsign(x) + + +@keras_export("keras.activations.soft_shrink") +def soft_shrink(x, threshold=0.5): + """Soft Shrink activation function. + + It is defined as: + + `soft_shrink(x) = x - threshold` if `x > threshold`, + `soft_shrink(x) = x + threshold` if `x < -threshold`, + `soft_shrink(x) = 0` otherwise. + + Args: + x: Input tensor. + threshold: Threshold value. Defaults to 0.5. + + """ + return ops.soft_shrink(x, threshold=threshold) + + +@keras_export("keras.activations.sparse_plus") +def sparse_plus(x): + """SparsePlus activation function. + + SparsePlus is defined as: + + `sparse_plus(x) = 0` for `x <= -1`. + `sparse_plus(x) = (1/4) * (x + 1)^2` for `-1 < x < 1`. + `sparse_plus(x) = x` for `x >= 1`. + + Args: + x: Input tensor. + + """ + return ops.sparse_plus(x) + + +@keras_export(["keras.activations.silu", "keras.activations.swish"]) +def silu(x): + """Swish (or Silu) activation function. + + It is defined as: `swish(x) = x * sigmoid(x)`. + + The Swish (or Silu) activation function is a smooth, + non-monotonic function that is unbounded above and + bounded below. + + Args: + x: Input tensor. + + Reference: + + - [Ramachandran et al., 2017](https://arxiv.org/abs/1710.05941) + """ + return ops.silu(x) + + +@keras_export("keras.activations.squareplus") +def squareplus(x, b=4): + """Squareplus activation function. + + The Squareplus activation function is defined as: + + `f(x) = (x + sqrt(x^2 + b)) / 2` + + Where `b` is a smoothness parameter. + + Args: + x: Input tensor. + b: Smoothness parameter. Defaults to 4. + + Reference: + + - [Ramachandran et al., 2021](https://arxiv.org/abs/2112.11687) + """ + return ops.squareplus(x, b=b) + + +@keras_export("keras.activations.gelu") +def gelu(x, approximate=False): + """Gaussian error linear unit (GELU) activation function. + + The Gaussian error linear unit (GELU) is defined as: + + `gelu(x) = x * P(X <= x)` where `P(X) ~ N(0, 1)`, + i.e. `gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))`. + + GELU weights inputs by their value, rather than gating + inputs by their sign as in ReLU. + + Args: + x: Input tensor. + approximate: A `bool`, whether to enable approximation. + + Reference: + + - [Hendrycks et al., 2016](https://arxiv.org/abs/1606.08415) + """ + return ops.gelu(x, approximate=approximate) + + +@keras_export("keras.activations.celu") +def celu(x, alpha=1.0): + """Continuously Differentiable Exponential Linear Unit. + + The CeLU activation function is defined as: + + `celu(x) = alpha * (exp(x / alpha) - 1) for x < 0`,`celu(x) = x for x >= 0`. + + where `alpha` is a scaling parameter that controls the activation's shape. + + Args: + x: Input tensor. + alpha: The α value for the CeLU formulation. Defaults to `1.0`. + + Reference: + + - [Barron, J. T., 2017](https://arxiv.org/abs/1704.07483) + """ + return ops.celu(x, alpha=alpha) + + +@keras_export("keras.activations.glu") +def glu(x, axis=-1): + """Gated Linear Unit (GLU) activation function. + + The GLU activation function is defined as: + + `glu(x) = a * sigmoid(b)`, + + where `x` is split into two equal parts `a` and `b` along the given axis. + + Args: + x: Input tensor. + axis: The axis along which to split the input tensor. Defaults to `-1`. + + Reference: + + - [Dauphin et al., 2017](https://arxiv.org/abs/1612.08083) + """ + return ops.glu(x, axis=axis) + + +@keras_export("keras.activations.tanh") +def tanh(x): + """Hyperbolic tangent activation function. + + It is defined as: + `tanh(x) = sinh(x) / cosh(x)`, i.e. + `tanh(x) = ((exp(x) - exp(-x)) / (exp(x) + exp(-x)))`. + + Args: + x: Input tensor. + """ + return ops.tanh(x) + + +@keras_export("keras.activations.tanh_shrink") +def tanh_shrink(x): + """Tanh shrink activation function. + + It is defined as: + + `f(x) = x - tanh(x)`. + + Args: + x: Input tensor. + """ + return ops.tanh_shrink(x) + + +@keras_export("keras.activations.hard_tanh") +def hard_tanh(x): + """HardTanh activation function. + + It is defined as: + `hard_tanh(x) = -1 for x < -1`, + `hard_tanh(x) = x for -1 <= x <= 1`, + `hard_tanh(x) = 1 for x > 1`. + + Args: + x: Input tensor. + """ + return ops.hard_tanh(x) + + +@keras_export("keras.activations.hard_shrink") +def hard_shrink(x, threshold=0.5): + """Hard Shrink activation function. + + It is defined as: + + `hard_shrink(x) = x` if `|x| > threshold`, + `hard_shrink(x) = 0` otherwise. + + Args: + x: Input tensor. + threshold: Threshold value. Defaults to 0.5. + + """ + return ops.hard_shrink(x, threshold=threshold) + + +@keras_export("keras.activations.threshold") +def threshold(x, threshold, default_value): + """Threshold activation function. + + It is defined as: + + `threshold(x) = x` if `x > threshold`, + `threshold(x) = default_value` otherwise. + + Args: + x: Input tensor. + threshold: The value that decides when to retain or replace x. + default_value: Value to assign when `x <= threshold`. + + """ + return ops.threshold(x, threshold, default_value) + + +@keras_export("keras.activations.sigmoid") +def sigmoid(x): + """Sigmoid activation function. + + It is defined as: `sigmoid(x) = 1 / (1 + exp(-x))`. + + For small values (<-5), + `sigmoid` returns a value close to zero, and for large values (>5) + the result of the function gets close to 1. + + Sigmoid is equivalent to a 2-element softmax, where the second element is + assumed to be zero. The sigmoid function always returns a value between + 0 and 1. + + Args: + x: Input tensor. + """ + output = ops.sigmoid(x) + # Cache the logits to use for crossentropy loss. + try: + output._keras_logits = x + except AttributeError: + # We're dealing with a C-type. + pass + return output + + +@keras_export("keras.activations.exponential") +def exponential(x): + """Exponential activation function. + + Args: + x: Input tensor. + """ + return ops.exp(x) + + +@keras_export("keras.activations.hard_sigmoid") +def hard_sigmoid(x): + """Hard sigmoid activation function. + + The hard sigmoid activation is defined as: + + - `0` if `if x <= -3` + - `1` if `x >= 3` + - `(x/6) + 0.5` if `-3 < x < 3` + + It's a faster, piecewise linear approximation + of the sigmoid activation. + + Args: + x: Input tensor. + + Reference: + + - [Wikipedia "Hard sigmoid"](https://en.wikipedia.org/wiki/Hard_sigmoid) + """ + return ops.hard_sigmoid(x) + + +@keras_export("keras.activations.log_sigmoid") +def log_sigmoid(x): + """Logarithm of the sigmoid activation function. + + It is defined as `f(x) = log(1 / (1 + exp(-x)))`. + + Args: + x: Input tensor. + + """ + return ops.log_sigmoid(x) + + +@keras_export("keras.activations.sparse_sigmoid") +def sparse_sigmoid(x): + """Sparse sigmoid activation function. + + It is defined as + + `f(x) = 0` for `x <= -1`, + `f(x) = 0.5 * (x + 1)` for `-1 < x < 1`, + `f(x) = 1` for `x >= 1`. + + Args: + x: Input tensor. + + Reference: + + - [M. Blondel, A. F. T. Martins, V. Niculae, 2019](https://arxiv.org/pdf/1901.02324) + + """ + return ops.sparse_sigmoid(x) + + +@keras_export(["keras.activations.hard_silu", "keras.activations.hard_swish"]) +def hard_silu(x): + """Hard SiLU activation function, also known as Hard Swish. + + It is defined as: + + - `0` if `if x < -3` + - `x` if `x > 3` + - `x * (x + 3) / 6` if `-3 <= x <= 3` + + It's a faster, piecewise linear approximation of the silu activation. + + Args: + x: Input tensor. + + Reference: + + - [A Howard, 2019](https://arxiv.org/abs/1905.02244) + """ + x = backend.convert_to_tensor(x) + return ops.hard_silu(x) + + +@keras_export("keras.activations.linear") +def linear(x): + """Linear activation function (pass-through). + + A "linear" activation is an identity function: + it returns the input, unmodified. + + Args: + x: Input tensor. + """ + return x + + +class Mish(ops.Operation): + def call(self, x): + return self.static_call(x) + + def compute_output_spec(self, x): + return backend.KerasTensor(x.shape, x.dtype) + + @staticmethod + def static_call(x): + return x * backend.nn.tanh(backend.nn.softplus(x)) + + +@keras_export("keras.activations.mish") +def mish(x): + """Mish activation function. + + It is defined as: + + `mish(x) = x * tanh(softplus(x))` + + where `softplus` is defined as: + + `softplus(x) = log(exp(x) + 1)` + + Args: + x: Input tensor. + + Reference: + + - [Misra, 2019](https://arxiv.org/abs/1908.08681) + """ + x = backend.convert_to_tensor(x) + return Mish.static_call(x) + + +@keras_export("keras.activations.log_softmax") +def log_softmax(x, axis=-1): + """Log-Softmax activation function. + + Each input vector is handled independently. + The `axis` argument sets which axis of the input the function + is applied along. + + Args: + x: Input tensor. + axis: Integer, axis along which the softmax is applied. + """ + return ops.log_softmax(x, axis=axis) + + +@keras_export(["keras.activations.sparsemax"]) +def sparsemax(x, axis=-1): + """Sparsemax activation function. + + For each batch `i`, and class `j`, + sparsemax activation function is defined as: + + `sparsemax(x)[i, j] = max(x[i, j] - τ(x[i, :]), 0).` + + Args: + x: Input tensor. + axis: `int`, axis along which the sparsemax operation is applied. + + Returns: + A tensor, output of sparsemax transformation. Has the same type and + shape as `x`. + + Reference: + + - [Martins et.al., 2016](https://arxiv.org/abs/1602.02068) + """ + x = backend.convert_to_tensor(x) + return ops.sparsemax(x, axis) diff --git a/keras/src/activations/activations_test.py b/keras/src/activations/activations_test.py new file mode 100644 index 000000000000..b679f16803d2 --- /dev/null +++ b/keras/src/activations/activations_test.py @@ -0,0 +1,1010 @@ +import numpy as np + +from keras.src import activations +from keras.src import backend +from keras.src import testing + + +def _ref_softmax(values): + m = np.max(values) + e = np.exp(values - m) + return e / np.sum(e) + + +def _ref_softplus(x): + return np.log(np.ones_like(x) + np.exp(x)) + + +def _ref_log_softmax(values): + max_val = np.max(values) # for numerical stability + stabilized_values = values - max_val + log_sum_exp = np.log(np.sum(np.exp(stabilized_values))) + return stabilized_values - log_sum_exp + + +def _ref_leaky_relu(x, alpha=0.2): + return x if x > 0 else alpha * x + + +def _ref_relu6(x): + return min(max(0, x), 6) + + +def _ref_silu(x): + return x / (1 + np.exp(-x)) + + +def _ref_hard_sigmoid(x): + x = (x / 6.0) + 0.5 + z = 0.0 if x <= 0 else (1.0 if x >= 1 else x) + return z + + +def _ref_sparse_sigmoid(x): + return np.where(x <= -1, 0, np.where(x >= 1, 1, 0.5 * (x + 1))) + + +def _ref_log_sigmoid(x): + return -1 * _ref_softplus(-x) + + +def _ref_hard_silu(x): + return x * np.minimum(np.maximum(0.0, x + 3.0), 6.0) * (1.0 / 6.0) + + +def _ref_sigmoid(x): + if x >= 0: + return 1 / (1 + np.exp(-x)) + else: + z = np.exp(x) + return z / (1 + z) + + +def _ref_softsign(x): + return np.divide(x, np.ones_like(x) + np.absolute(x)) + + +class ActivationsTest(testing.TestCase): + def test_softmax(self): + x = np.random.random((2, 5)) + + result = activations.softmax(x[np.newaxis, :])[0] + expected = _ref_softmax(x[0]) + self.assertAllClose(result[0], expected, rtol=1e-05) + + def test_softmax_2d_axis_0(self): + x = np.random.random((2, 5)) + result = activations.softmax(x[np.newaxis, :], axis=1)[0] + expected = np.zeros((2, 5)) + for i in range(5): + expected[:, i] = _ref_softmax(x[:, i]) + self.assertAllClose(result, expected, rtol=1e-05) + + def test_softmax_3d_axis_tuple(self): + x = np.random.random((2, 3, 5)) + result = activations.softmax(x, axis=(1, 2)) + expected = np.zeros((2, 3, 5)) + for i in range(2): + expected[i, :, :] = _ref_softmax(x[i, :, :]) + self.assertAllClose(result, expected, rtol=1e-05) + + def test_softmax_1d(self): + x = np.random.random(5) + result = activations.softmax(x) + expected = _ref_softmax(x) + self.assertAllClose(result, expected, rtol=1e-05) + + def test_softmax_higher_dim(self): + x = np.random.random((2, 3, 4, 5)) + result = activations.softmax(x, axis=(2, 3)) + expected = np.zeros((2, 3, 4, 5)) + for i in range(2): + for j in range(3): + expected[i, j, :, :] = _ref_softmax(x[i, j, :, :]) + self.assertAllClose(result, expected, rtol=1e-05) + + def test_softmax_higher_dim_multiple_axes(self): + x = np.random.random((2, 3, 4, 5, 6)) + result = activations.softmax(x, axis=(2, 3, 4)) + expected = np.zeros((2, 3, 4, 5, 6)) + for i in range(2): + for j in range(3): + expected[i, j, :, :, :] = _ref_softmax(x[i, j, :, :, :]) + self.assertAllClose(result, expected, rtol=1e-05) + + def test_softmax_negative_axis(self): + x = np.random.random((2, 5)) + result = activations.softmax(x, axis=-1) + expected = np.zeros((2, 5)) + for i in range(2): + expected[i, :] = _ref_softmax(x[i, :]) + self.assertAllClose(result, expected, rtol=1e-05) + + def test_temporal_softmax(self): + x = np.random.random((2, 2, 3)) * 10 + result = activations.softmax(x[np.newaxis, :])[0] + expected = _ref_softmax(x[0, 0]) + self.assertAllClose(result[0, 0], expected, rtol=1e-05) + + def test_log_softmax_2d_axis_0(self): + x = np.random.random((2, 5)) + result = activations.log_softmax(x[np.newaxis, :], axis=1)[0] + expected = np.zeros((2, 5)) + for i in range(5): + expected[:, i] = _ref_log_softmax(x[:, i]) + self.assertAllClose(result, expected, rtol=1e-05) + + def test_log_softmax_3d_axis_tuple(self): + x = np.random.random((2, 3, 5)) + result = activations.log_softmax(x, axis=(1, 2)) + expected = np.zeros((2, 3, 5)) + for i in range(2): + expected[i, :, :] = _ref_log_softmax(x[i, :, :]) + self.assertAllClose(result, expected, rtol=1e-05) + + def test_log_softmax_1d(self): + x = np.random.random(5) + result = activations.log_softmax(x) + expected = _ref_log_softmax(x) + self.assertAllClose(result, expected, rtol=1e-05) + + def test_log_softmax_higher_dim(self): + x = np.random.random((2, 3, 4, 5)) + result = activations.log_softmax(x, axis=(2, 3)) + expected = np.zeros((2, 3, 4, 5)) + for i in range(2): + for j in range(3): + expected[i, j, :, :] = _ref_log_softmax(x[i, j, :, :]) + self.assertAllClose(result, expected, rtol=1e-05) + + def test_log_softmax_higher_dim_multiple_axes(self): + x = np.random.random((2, 3, 4, 5, 6)) + result = activations.log_softmax(x, axis=(2, 3, 4)) + expected = np.zeros((2, 3, 4, 5, 6)) + for i in range(2): + for j in range(3): + expected[i, j, :, :, :] = _ref_log_softmax(x[i, j, :, :, :]) + self.assertAllClose(result, expected, rtol=1e-05) + + def test_log_softmax_negative_axis(self): + x = np.random.random((2, 5)) + result = activations.log_softmax(x, axis=-1) + expected = np.zeros((2, 5)) + for i in range(2): + expected[i, :] = _ref_log_softmax(x[i, :]) + self.assertAllClose(result, expected, rtol=1e-05) + + def test_temporal_log_softmax(self): + x = np.random.random((2, 2, 3)) * 10 + result = activations.log_softmax(x[np.newaxis, :])[0] + expected = _ref_log_softmax(x[0, 0]) + self.assertAllClose(result[0, 0], expected, rtol=1e-05) + + def test_selu(self): + alpha = 1.6732632423543772848170429916717 + scale = 1.0507009873554804934193349852946 + + positive_values = np.array([[1, 2]], dtype=backend.floatx()) + result = activations.selu(positive_values[np.newaxis, :])[0] + self.assertAllClose(result, positive_values * scale, rtol=1e-05) + + negative_values = np.array([[-1, -2]], dtype=backend.floatx()) + result = activations.selu(negative_values[np.newaxis, :])[0] + true_result = (np.exp(negative_values) - 1) * scale * alpha + self.assertAllClose(result, true_result) + + def test_softplus(self): + # Basic test for random values between 0 and 1 + x = np.random.uniform(0, 1, (2, 5)) + result = activations.softplus(x[np.newaxis, :])[0] + expected = np.vectorize(_ref_softplus)(x) + self.assertAllClose(result, expected, rtol=1e-05) + + # Test with 1D array + x_1d = np.random.uniform(-10, 10, 5) + result_1d = activations.softplus(x_1d) + expected_1d = np.vectorize(_ref_softplus)(x_1d) + self.assertAllClose(result_1d, expected_1d, rtol=1e-05) + + # Test with 3D array + x_3d = np.random.uniform(-10, 10, (3, 3, 3)) + result_3d = activations.softplus(x_3d) + expected_3d = np.vectorize(_ref_softplus)(x_3d) + self.assertAllClose(result_3d, expected_3d, rtol=1e-05) + + # Test near zero values + x_zero = np.random.uniform(-1e-7, 1e-7, (2, 5)) + result_zero = activations.softplus(x_zero) + expected_zero = np.vectorize(_ref_softplus)(x_zero) + self.assertAllClose(result_zero, expected_zero, rtol=1e-05) + + # Test large positive values + x_large_positive = np.random.uniform(10, 100, (2, 5)) + result_large_positive = activations.softplus(x_large_positive) + expected_large_positive = np.vectorize(_ref_softplus)(x_large_positive) + self.assertAllClose( + result_large_positive, expected_large_positive, rtol=1e-05 + ) + + # Test large negative values + x_large_negative = np.random.uniform(-100, -10, (2, 5)) + result_large_negative = activations.softplus(x_large_negative) + expected_large_negative = np.vectorize(_ref_softplus)(x_large_negative) + self.assertAllClose( + result_large_negative, expected_large_negative, rtol=1e-05 + ) + + def test_softsign(self): + # Basic test for random values between 0 and 1 + x = np.random.uniform(0, 1, (2, 5)) + result = activations.softsign(x[np.newaxis, :])[0] + expected = np.vectorize(_ref_softsign)(x) + self.assertAllClose(result, expected, rtol=1e-05) + + # Test with 1D array + x_1d = np.random.uniform(-10, 10, 5) + result_1d = activations.softsign(x_1d) + expected_1d = np.vectorize(_ref_softsign)(x_1d) + self.assertAllClose(result_1d, expected_1d, rtol=1e-05) + + # Test with 3D array + x_3d = np.random.uniform(-10, 10, (3, 3, 3)) + result_3d = activations.softsign(x_3d) + expected_3d = np.vectorize(_ref_softsign)(x_3d) + self.assertAllClose(result_3d, expected_3d, rtol=1e-05) + + # Test near zero values + x_zero = np.random.uniform(-1e-7, 1e-7, (2, 5)) + result_zero = activations.softsign(x_zero) + expected_zero = np.vectorize(_ref_softsign)(x_zero) + self.assertAllClose(result_zero, expected_zero, rtol=1e-05) + + # Test large positive values + x_large_positive = np.random.uniform(10, 100, (2, 5)) + result_large_positive = activations.softsign(x_large_positive) + expected_large_positive = np.vectorize(_ref_softsign)(x_large_positive) + self.assertAllClose( + result_large_positive, expected_large_positive, rtol=1e-05 + ) + + # Test large negative values + x_large_negative = np.random.uniform(-100, -10, (2, 5)) + result_large_negative = activations.softsign(x_large_negative) + expected_large_negative = np.vectorize(_ref_softsign)(x_large_negative) + self.assertAllClose( + result_large_negative, expected_large_negative, rtol=1e-05 + ) + + def test_sigmoid(self): + # Basic test for random values between 0 and 1 + x = np.random.uniform(0, 1, (2, 5)) + result = activations.sigmoid(x[np.newaxis, :])[0] + expected = np.vectorize(_ref_sigmoid)(x) + self.assertAllClose(result, expected, rtol=1e-05) + + # Test with 1D array + x_1d = np.random.uniform(-10, 10, 5) + result_1d = activations.sigmoid(x_1d) + expected_1d = np.vectorize(_ref_sigmoid)(x_1d) + self.assertAllClose(result_1d, expected_1d, rtol=1e-05) + + # Test with 3D array + x_3d = np.random.uniform(-10, 10, (3, 3, 3)) + result_3d = activations.sigmoid(x_3d) + expected_3d = np.vectorize(_ref_sigmoid)(x_3d) + self.assertAllClose(result_3d, expected_3d, rtol=1e-05) + + # Test near zero values + x_zero = np.random.uniform(-1e-7, 1e-7, (2, 5)) + result_zero = activations.sigmoid(x_zero) + expected_zero = np.vectorize(_ref_sigmoid)(x_zero) + self.assertAllClose(result_zero, expected_zero, rtol=1e-05) + + # Test large positive values + x_large_positive = np.random.uniform(10, 100, (2, 5)) + result_large_positive = activations.sigmoid(x_large_positive) + expected_large_positive = np.vectorize(_ref_sigmoid)(x_large_positive) + self.assertAllClose( + result_large_positive, expected_large_positive, rtol=1e-05 + ) + + # Test large negative values + x_large_negative = np.random.uniform(-100, -10, (2, 5)) + result_large_negative = activations.sigmoid(x_large_negative) + expected_large_negative = np.vectorize(_ref_sigmoid)(x_large_negative) + self.assertAllClose( + result_large_negative, expected_large_negative, rtol=1e-05 + ) + + def test_hard_sigmoid(self): + # Basic test for random values between 0 and 1 + x = np.random.uniform(0, 1, (2, 5)) + result = activations.hard_sigmoid(x[np.newaxis, :])[0] + expected = np.vectorize(_ref_hard_sigmoid)(x) + self.assertAllClose(result, expected, rtol=1e-05) + + # Test with 1D array + x_1d = np.random.uniform(-10, 10, 5) + result_1d = activations.hard_sigmoid(x_1d) + expected_1d = np.vectorize(_ref_hard_sigmoid)(x_1d) + self.assertAllClose(result_1d, expected_1d, rtol=1e-05) + + # Test with 3D array + x_3d = np.random.uniform(-10, 10, (3, 3, 3)) + result_3d = activations.hard_sigmoid(x_3d) + expected_3d = np.vectorize(_ref_hard_sigmoid)(x_3d) + self.assertAllClose(result_3d, expected_3d, rtol=1e-05) + + # Test with strictly positive values much larger than 1 + x_positive_above_1 = np.random.uniform( + 5, 10, (2, 5) + ) # Adjusted this range + result_positive_above_1 = activations.hard_sigmoid(x_positive_above_1) + expected_positive_above_1 = np.ones((2, 5)) + self.assertAllClose( + result_positive_above_1, expected_positive_above_1, rtol=1e-05 + ) + + def test_sparse_sigmoid(self): + # Basic test for random values between 0 and 1 + x = np.random.uniform(0, 1, (2, 5)) + result = activations.sparse_sigmoid(x[np.newaxis, :])[0] + expected = np.vectorize(_ref_sparse_sigmoid)(x) + self.assertAllClose(result, expected, rtol=1e-05) + + # Test with 1D array + x_1d = np.random.uniform(-10, 10, 5) + result_1d = activations.sparse_sigmoid(x_1d) + expected_1d = np.vectorize(_ref_sparse_sigmoid)(x_1d) + self.assertAllClose(result_1d, expected_1d, rtol=1e-05) + + # Test with 3D array + x_3d = np.random.uniform(-10, 10, (3, 3, 3)) + result_3d = activations.sparse_sigmoid(x_3d) + expected_3d = np.vectorize(_ref_sparse_sigmoid)(x_3d) + self.assertAllClose(result_3d, expected_3d, rtol=1e-05) + + # Test large positive values + x_large_positive = np.random.uniform(10, 100, (2, 5)) + result_large_positive = activations.sparse_sigmoid(x_large_positive) + expected_large_positive = np.vectorize(_ref_sparse_sigmoid)( + x_large_positive + ) + self.assertAllClose( + result_large_positive, expected_large_positive, rtol=1e-05 + ) + + # Test large negative values + x_large_negative = np.random.uniform(-100, -10, (2, 5)) + result_large_negative = activations.sparse_sigmoid(x_large_negative) + expected_large_negative = np.vectorize(_ref_sparse_sigmoid)( + x_large_negative + ) + self.assertAllClose( + result_large_negative, expected_large_negative, rtol=1e-05 + ) + + def test_log_sigmoid(self): + # Basic test for random values between 0 and 1 + x = np.random.uniform(0, 1, (2, 5)) + result = activations.log_sigmoid(x[np.newaxis, :])[0] + expected = np.vectorize(_ref_log_sigmoid)(x) + self.assertAllClose(result, expected, rtol=1e-05) + + # Test with 1D array + x_1d = np.random.uniform(-10, 10, 5) + result_1d = activations.log_sigmoid(x_1d) + expected_1d = np.vectorize(_ref_log_sigmoid)(x_1d) + self.assertAllClose(result_1d, expected_1d, rtol=1e-05) + + # Test with 3D array + x_3d = np.random.uniform(-10, 10, (3, 3, 3)) + result_3d = activations.log_sigmoid(x_3d) + expected_3d = np.vectorize(_ref_log_sigmoid)(x_3d) + self.assertAllClose(result_3d, expected_3d, rtol=1e-05) + + # Test large positive values + x_large_positive = np.random.uniform(10, 100, (2, 5)) + result_large_positive = activations.log_sigmoid(x_large_positive) + expected_large_positive = np.vectorize(_ref_log_sigmoid)( + x_large_positive + ) + self.assertAllClose( + result_large_positive, expected_large_positive, rtol=1e-05 + ) + + # Test large negative values + x_large_negative = np.random.uniform(-100, -10, (2, 5)) + result_large_negative = activations.log_sigmoid(x_large_negative) + expected_large_negative = np.vectorize(_ref_log_sigmoid)( + x_large_negative + ) + self.assertAllClose( + result_large_negative, expected_large_negative, rtol=1e-05 + ) + + def test_hard_silu(self): + # Basic test for random values between -3 and 3 + x = np.random.uniform(-3, 3, (2, 5)).astype("float32") + result = activations.hard_silu(x[np.newaxis, :])[0] + expected = np.vectorize(_ref_hard_silu)(x) + self.assertAllClose(result, expected, rtol=1e-05) + + # Test with 1D array + x_1d = np.random.uniform(-10, 10, 5).astype("float32") + result_1d = activations.hard_silu(x_1d) + expected_1d = np.vectorize(_ref_hard_silu)(x_1d) + self.assertAllClose(result_1d, expected_1d, rtol=1e-05) + + # Test with 3D array + x_3d = np.random.uniform(-10, 10, (3, 3, 3)).astype("float32") + result_3d = activations.hard_silu(x_3d) + expected_3d = np.vectorize(_ref_hard_silu)(x_3d) + self.assertAllClose(result_3d, expected_3d, rtol=1e-05) + + # Test with strictly positive values much larger than 3 + x_positive_above_3 = np.random.uniform(5, 10, (2, 5)).astype("float32") + result_positive_above_3 = activations.hard_silu(x_positive_above_3) + expected_positive_above_3 = x_positive_above_3 + self.assertAllClose( + result_positive_above_3, expected_positive_above_3, rtol=1e-05 + ) + + # Test with strictly negative values much smaller than -3 + x_negatives = np.random.uniform(-10, -5, (2, 5)).astype("float32") + result = activations.hard_silu(x_negatives) + expected_zeros = np.zeros_like(x_negatives) + self.assertAllClose(result, expected_zeros, rtol=1e-05) + + def test_relu_negative_slope(self): + # Define the input tensor + x = np.array([-10, -5, 0.0, 5, 10]) + + # Test with only negative_slope + result_negative_slope = activations.relu(x, negative_slope=0.5) + expected_negative_slope = np.array([-5.0, -2.5, 0.0, 5.0, 10.0]) + self.assertAllClose( + result_negative_slope, expected_negative_slope, rtol=1e-05 + ) + + def test_relu_max_value(self): + # Define the input tensor + x = np.array([-10, -5, 0.0, 5, 10]) + + # Test with only max_value + result_max_value = activations.relu(x, max_value=5.0) + expected_max_value = np.array([0.0, 0.0, 0.0, 5.0, 5.0]) + self.assertAllClose(result_max_value, expected_max_value, rtol=1e-05) + + def test_relu_threshold(self): + # Define the input tensor + x = np.array([-10, -5, 0.0, 5, 10]) + + # Test with only threshold + result_threshold = activations.relu(x, threshold=5.0) + expected_threshold = np.array([-0.0, -0.0, 0.0, 0.0, 10.0]) + self.assertAllClose(result_threshold, expected_threshold, rtol=1e-05) + + def test_relu_combined_threshold_and_max_value(self): + # Define the input tensor + x = np.array([-10, -5, 0.0, 5, 10]) + + # Test with threshold and max_value + result_combined = activations.relu(x, threshold=5.0, max_value=5.0) + expected_combined = np.array([0.0, 0.0, 0.0, 0.0, 5.0]) + self.assertAllClose(result_combined, expected_combined, rtol=1e-05) + + def test_relu_combined_all_parameters(self): + # Define the input tensor + x = np.array([-10, -5, 0.0, 5, 10]) + + # Test with negative_slope, max_value, and threshold + result_combined = activations.relu( + x, negative_slope=0.5, max_value=5.0, threshold=5.0 + ) + expected_combined = np.array([-7.5, -5.0, -2.5, 0.0, 5.0]) + self.assertAllClose(result_combined, expected_combined, rtol=1e-05) + + def test_relu_to_trigger_relu6(self): + x = np.array([-10, -5, 0.0, 5, 10, 12]) + result_relu6 = activations.relu(x, max_value=6.0) + expected_relu6 = np.array([0.0, 0.0, 0.0, 5.0, 6.0, 6.0]) + self.assertAllClose(result_relu6, expected_relu6, rtol=1e-05) + + def test_relu_to_trigger_leaky(self): + x = np.array([-10, -5, 0.0, 5, 10]) + result_leaky = activations.relu(x, negative_slope=0.5) + expected_leaky = np.array([-5.0, -2.5, 0.0, 5.0, 10.0]) + self.assertAllClose(result_leaky, expected_leaky, rtol=1e-05) + + def test_relu(self): + # Basic test for positive values + positive_values = np.random.uniform(0.1, 10, (2, 5)) + result = activations.relu(positive_values[np.newaxis, :])[0] + self.assertAllClose(result, positive_values, rtol=1e-05) + + # Basic test for negative values + negative_values = np.random.uniform(-10, -0.1, (2, 5)) + result = activations.relu(negative_values[np.newaxis, :])[0] + expected = np.zeros((2, 5)) + self.assertAllClose(result, expected, rtol=1e-05) + + # Test with 1D array + x_1d = np.random.uniform(-10, 10, 5) + result_1d = activations.relu(x_1d) + expected_1d = np.maximum(0, x_1d) + self.assertAllClose(result_1d, expected_1d, rtol=1e-05) + + # Test with 3D array + x_3d = np.random.uniform(-10, 10, (3, 3, 3)) + result_3d = activations.relu(x_3d) + expected_3d = np.maximum(0, x_3d) + self.assertAllClose(result_3d, expected_3d, rtol=1e-05) + + # Test near zero values + x_zero = np.random.uniform(-1e-7, 1e-7, (2, 5)) + result_zero = activations.relu(x_zero) + expected_zero = np.maximum(0, x_zero) + self.assertAllClose(result_zero, expected_zero, rtol=1e-05) + + # Test large positive values + x_large_positive = np.random.uniform(1e4, 1e5, (2, 5)) + result_large_positive = activations.relu(x_large_positive) + self.assertAllClose(result_large_positive, x_large_positive, rtol=1e-05) + + # Test large negative values + x_large_negative = np.random.uniform(-1e5, -1e4, (2, 5)) + result_large_negative = activations.relu(x_large_negative) + expected_large_negative = np.zeros((2, 5)) + self.assertAllClose( + result_large_negative, expected_large_negative, rtol=1e-05 + ) + + def test_leaky_relu(self): + leaky_relu_vectorized = np.vectorize(_ref_leaky_relu) + + # Test for negative_slope = 0.01 + # Test positive values + positive_values = np.random.random((2, 5)) + result = activations.leaky_relu( + positive_values[np.newaxis, :], negative_slope=0.01 + )[0] + expected = leaky_relu_vectorized(positive_values, alpha=0.01) + self.assertAllClose(result, expected, rtol=1e-05) + + # Test negative values + negative_values = np.random.uniform(-1, 0, (2, 5)) + result = activations.leaky_relu( + negative_values[np.newaxis, :], negative_slope=0.01 + )[0] + expected = leaky_relu_vectorized(negative_values, alpha=0.01) + self.assertAllClose(result, expected, rtol=1e-05) + + # Test for negative_slope = 0.3 + # Test positive values + positive_values = np.random.random((2, 5)) + result = activations.leaky_relu( + positive_values[np.newaxis, :], negative_slope=0.3 + )[0] + expected = leaky_relu_vectorized(positive_values, alpha=0.3) + self.assertAllClose(result, expected, rtol=1e-05) + + # Test negative values + negative_values = np.random.uniform(-1, 0, (2, 5)) + result = activations.leaky_relu( + negative_values[np.newaxis, :], negative_slope=0.3 + )[0] + expected = leaky_relu_vectorized(negative_values, alpha=0.3) + self.assertAllClose(result, expected, rtol=1e-05) + + def test_relu6(self): + relu6_vectorized = np.vectorize(_ref_relu6) + + # Test positive values less than 6 + positive_values = np.random.uniform(0, 5.9, (2, 5)) + result = activations.relu6(positive_values[np.newaxis, :])[0] + expected = relu6_vectorized(positive_values) + self.assertAllClose(result, expected, rtol=1e-05) + + # Test positive values greater than 6 + positive_values_above_6 = np.random.uniform(6.1, 10, (2, 5)) + result = activations.relu6(positive_values_above_6[np.newaxis, :])[0] + expected = relu6_vectorized(positive_values_above_6) + self.assertAllClose(result, expected, rtol=1e-05) + + # Test negative values + negative_values = np.random.uniform(-1, 0, (2, 5)) + result = activations.relu6(negative_values[np.newaxis, :])[0] + expected = relu6_vectorized(negative_values) + self.assertAllClose(result, expected, rtol=1e-05) + + def test_silu(self): + silu_vectorized = np.vectorize(_ref_silu) + + # Test positive values + positive_values = np.random.uniform(0, 5.9, (2, 5)) + result = activations.silu(positive_values[np.newaxis, :])[0] + expected = silu_vectorized(positive_values) + self.assertAllClose(result, expected, rtol=1e-05) + + # Test values around zero (to ensure sigmoid behaves correctly) + around_zero_values = np.random.uniform(-1, 1, (2, 5)) + result = activations.silu(around_zero_values[np.newaxis, :])[0] + expected = silu_vectorized(around_zero_values) + self.assertAllClose(result, expected, rtol=1e-05) + + # Test negative values + negative_values = np.random.uniform(-5.9, 0, (2, 5)) + result = activations.silu(negative_values[np.newaxis, :])[0] + expected = silu_vectorized(negative_values) + self.assertAllClose(result, expected, rtol=1e-05) + + def test_gelu(self): + def gelu(x, approximate=False): + if approximate: + return ( + 0.5 + * x + * ( + 1.0 + + np.tanh( + np.sqrt(2.0 / np.pi) + * (x + 0.044715 * np.power(x, 3)) + ) + ) + ) + else: + from scipy.stats import norm + + return x * norm.cdf(x) + + x = np.random.random((2, 5)) + result = activations.gelu(x[np.newaxis, :])[0] + expected = gelu(x) + self.assertAllClose(result, expected, rtol=1e-05) + + x = np.random.random((2, 5)) + result = activations.gelu(x[np.newaxis, :], approximate=True)[0] + expected = gelu(x, True) + self.assertAllClose(result, expected, rtol=1e-05) + + def test_celu(self): + def celu(x, alpha=1.0): + return np.maximum(x, 0.0) + alpha * np.expm1( + np.minimum(x, 0.0) / alpha + ) + + x = np.random.random((2, 5)) + result = activations.celu(x[np.newaxis, :])[0] + expected = celu(x) + self.assertAllClose(result, expected, rtol=1e-05) + + x = np.random.random((2, 5)) + result = activations.celu(x[np.newaxis, :], alpha=0.5)[0] + expected = celu(x, alpha=0.5) + self.assertAllClose(result, expected, rtol=1e-05) + + def test_glu(self): + def glu(x, axis=-1): + x1, x2 = np.split(x, 2, axis) + return x1 * (1 / (1 + np.exp(-x2))) + + x = np.random.random((2, 4)) + result = activations.glu(x[np.newaxis, :])[0] + expected = glu(x) + self.assertAllClose(result, expected, rtol=1e-05) + + x = np.random.random((2, 4)) + result = activations.glu(x[np.newaxis, :], axis=-2)[0] + expected = glu(x, axis=-2) + self.assertAllClose(result, expected, rtol=1e-05) + + def test_tanh_shrink(self): + def tanh_shrink(x): + return x - np.tanh(x) + + x = np.random.random((2, 5)) + result = activations.tanh_shrink(x[np.newaxis, :])[0] + expected = tanh_shrink(x) + self.assertAllClose(result, expected, rtol=1e-05) + + def test_hard_tanh(self): + def hard_tanh(x): + return np.clip(x, -1.0, 1.0) + + x = np.random.random((2, 5)) + result = activations.hard_tanh(x[np.newaxis, :])[0] + expected = hard_tanh(x) + self.assertAllClose(result, expected, rtol=1e-05) + + def test_hard_shrink(self): + def hard_shrink(x): + return np.where(np.abs(x) > 0.5, x, 0.0) + + x = np.random.random((2, 5)) + result = activations.hard_shrink(x[np.newaxis, :])[0] + expected = hard_shrink(x) + self.assertAllClose(result, expected, rtol=1e-05) + + def test_threshold(self): + def threshold(x, threshold_value, value): + return np.where( + x > threshold_value, x, np.array(value, dtype=x.dtype) + ) + + x = np.random.random((2, 5)) + result = activations.threshold(x[np.newaxis, :], 0, 0)[0] + expected = threshold(x, 0, 0) + self.assertAllClose(result, expected, rtol=1e-05) + + def test_squareplus(self): + def squareplus(x, b=4): + y = x + np.sqrt(x**2 + b) + return y / 2 + + x = np.random.random((2, 5)) + result = activations.squareplus(x[np.newaxis, :])[0] + expected = squareplus(x) + self.assertAllClose(result, expected, rtol=1e-05) + + def test_soft_shrink(self): + def soft_shrink(x, threshold=0.5): + return np.where( + x > threshold, + x - threshold, + np.where(x < -threshold, x + threshold, 0.0), + ) + + x = np.random.random((2, 5)) + result = activations.soft_shrink(x[np.newaxis, :])[0] + expected = soft_shrink(x) + self.assertAllClose(result, expected, rtol=1e-05) + + def test_sparse_plus(self): + def sparse_plus(x): + return np.where( + x <= -1, + np.zeros_like(x), + np.where(x < 1, (1 / 4) * (x + 1) ** 2, x), + ) + + x = np.random.random((2, 5)) + result = activations.sparse_plus(x[np.newaxis, :])[0] + expected = sparse_plus(x) + self.assertAllClose(result, expected, rtol=1e-05) + + def test_elu(self): + x = np.random.random((2, 5)) + result = activations.elu(x[np.newaxis, :])[0] + self.assertAllClose(result, x, rtol=1e-05) + negative_values = np.array([[-1, -2]], dtype=backend.floatx()) + result = activations.elu(negative_values[np.newaxis, :])[0] + true_result = np.exp(negative_values) - 1 + self.assertAllClose(result, true_result) + + def test_tanh(self): + # Basic test for the tanh activation function + x = np.random.random((2, 5)) + result = activations.tanh(x[np.newaxis, :])[0] + expected = np.tanh(x) + self.assertAllClose(result, expected, rtol=1e-05) + + # Basic test for the tanh activation function + x = np.random.uniform(-10, 10, (2, 5)) + result = activations.tanh(x[np.newaxis, :])[0] + expected = np.tanh(x) + self.assertAllClose(result, expected, rtol=1e-05) + + # Test with 1D array + x_1d = np.random.uniform(-10, 10, 5) + result_1d = activations.tanh(x_1d) + expected_1d = np.tanh(x_1d) + self.assertAllClose(result_1d, expected_1d, rtol=1e-05) + + # Test with 3D array + x_3d = np.random.uniform(-10, 10, (3, 3, 3)) + result_3d = activations.tanh(x_3d) + expected_3d = np.tanh(x_3d) + self.assertAllClose(result_3d, expected_3d, rtol=1e-05) + + # Test with strictly positive values + x_positive = np.random.uniform(0, 10, (2, 5)) + result_positive = activations.tanh(x_positive) + expected_positive = np.tanh(x_positive) + self.assertAllClose(result_positive, expected_positive, rtol=1e-05) + + # Test with strictly negative values + x_negative = np.random.uniform(-10, 0, (2, 5)) + result_negative = activations.tanh(x_negative) + expected_negative = np.tanh(x_negative) + self.assertAllClose(result_negative, expected_negative, rtol=1e-05) + + # Test near zero values + x_zero = np.random.uniform(-1e-7, 1e-7, (2, 5)) + result_zero = activations.tanh(x_zero) + expected_zero = np.tanh(x_zero) + self.assertAllClose(result_zero, expected_zero, rtol=1e-05) + + # Test large values to check stability + x_large = np.random.uniform(1e4, 1e5, (2, 5)) + result_large = activations.tanh(x_large) + expected_large = np.tanh(x_large) + self.assertAllClose(result_large, expected_large, rtol=1e-05) + + def test_exponential(self): + # Basic test for the exponential activation function + x = np.random.random((2, 5)) + result = activations.exponential(x[np.newaxis, :])[0] + expected = np.exp(x) + self.assertAllClose(result, expected, rtol=1e-05) + + x = np.random.uniform(-10, 10, (2, 5)) + result = activations.exponential(x[np.newaxis, :])[0] + expected = np.exp(x) + self.assertAllClose(result, expected, rtol=1e-05) + + # Test with 1D array + x_1d = np.random.uniform(-10, 10, 5) + result_1d = activations.exponential(x_1d) + expected_1d = np.exp(x_1d) + self.assertAllClose(result_1d, expected_1d, rtol=1e-05) + + # Test with 3D array + x_3d = np.random.uniform(-10, 10, (3, 3, 3)) + result_3d = activations.exponential(x_3d) + expected_3d = np.exp(x_3d) + self.assertAllClose(result_3d, expected_3d, rtol=1e-05) + + # Test with strictly positive values + x_positive = np.random.uniform(0, 10, (2, 5)) + result_positive = activations.exponential(x_positive) + expected_positive = np.exp(x_positive) + self.assertAllClose(result_positive, expected_positive, rtol=1e-05) + + # Test with strictly negative values + x_negative = np.random.uniform(-10, 0, (2, 5)) + result_negative = activations.exponential(x_negative) + expected_negative = np.exp(x_negative) + self.assertAllClose(result_negative, expected_negative, rtol=1e-05) + + # Test near zero values + x_zero = np.random.uniform(-1e-7, 1e-7, (2, 5)) + result_zero = activations.exponential(x_zero) + expected_zero = np.exp(x_zero) + self.assertAllClose(result_zero, expected_zero, rtol=1e-05) + + # Test large values to check stability + x_large = np.random.uniform(1e4, 1e5, (2, 5)) + result_large = activations.exponential(x_large) + expected_large = np.exp(x_large) + self.assertAllClose(result_large, expected_large, rtol=1e-05) + + def test_mish(self): + # Basic test for the mish activation function + x = np.random.random((2, 5)) + result = activations.mish(x[np.newaxis, :])[0] + expected = x * np.tanh(_ref_softplus(x)) + self.assertAllClose(result, expected, rtol=1e-05) + + x = np.random.uniform(-10, 10, (2, 5)) + result = activations.mish(x[np.newaxis, :])[0] + expected = x * np.tanh(_ref_softplus(x)) + self.assertAllClose(result, expected, rtol=1e-05) + + # Test with 1D array + x_1d = np.random.uniform(-10, 10, 5) + result_1d = activations.mish(x_1d) + expected_1d = x_1d * np.tanh(_ref_softplus(x_1d)) + self.assertAllClose(result_1d, expected_1d, rtol=1e-05) + + # Test with 3D array + x_3d = np.random.uniform(-10, 10, (3, 3, 3)) + result_3d = activations.mish(x_3d) + expected_3d = x_3d * np.tanh(_ref_softplus(x_3d)) + self.assertAllClose(result_3d, expected_3d, rtol=1e-05) + + # Test with strictly positive values + x_positive = np.random.uniform(0, 10, (2, 5)) + result_positive = activations.mish(x_positive) + expected_positive = x_positive * np.tanh(_ref_softplus(x_positive)) + self.assertAllClose(result_positive, expected_positive, rtol=1e-05) + + # Test with strictly negative values + x_negative = np.random.uniform(-10, 0, (2, 5)) + result_negative = activations.mish(x_negative) + expected_negative = x_negative * np.tanh(_ref_softplus(x_negative)) + self.assertAllClose(result_negative, expected_negative, rtol=1e-05) + + # Test near zero values + x_zero = np.random.uniform(-1e-7, 1e-7, (2, 5)) + result_zero = activations.mish(x_zero) + expected_zero = x_zero * np.tanh(_ref_softplus(x_zero)) + self.assertAllClose(result_zero, expected_zero, rtol=1e-05) + + # Test large values to check stability + x_large = np.random.uniform(1e4, 1e5, (2, 5)) + result_large = activations.mish(x_large) + expected_large = x_large * np.tanh(_ref_softplus(x_large)) + self.assertAllClose(result_large, expected_large, rtol=1e-05) + + def test_linear(self): + x = np.random.random((10, 5)) + self.assertAllClose(x, activations.linear(x)) + + # Test with 1D array + x_1d = np.random.uniform(-10, 10, 5) + self.assertAllClose(x_1d, activations.linear(x_1d)) + + # Test with 2D array + x = np.random.uniform(-10, 10, (10, 5)) + self.assertAllClose(x, activations.linear(x)) + + # Test with 3D array + x_3d = np.random.uniform(-10, 10, (5, 5, 5)) + self.assertAllClose(x_3d, activations.linear(x_3d)) + + # Test with float32 data type + x_float32 = np.random.uniform(-10, 10, (10, 5)).astype(np.float32) + self.assertAllClose(x_float32, activations.linear(x_float32)) + # Test with int32 data type + x_int32 = np.random.randint(-10, 10, (10, 5)).astype(np.int32) + self.assertAllClose(x_int32, activations.linear(x_int32)) + + def test_sparsemax(self): + # result check with 1d + x_1d = np.linspace(1, 12, num=12) + expected_result = np.zeros_like(x_1d) + expected_result[-1] = 1.0 + self.assertAllClose(expected_result, activations.sparsemax(x_1d)) + + # result check with 2d + x_2d = np.linspace(1, 12, num=12).reshape(-1, 2) + expected_result = np.zeros_like(x_2d) + expected_result[:, -1] = 1.0 + self.assertAllClose(expected_result, activations.sparsemax(x_2d)) + + # result check with 3d + x_3d = np.linspace(1, 12, num=12).reshape(-1, 1, 3) + expected_result = np.zeros_like(x_3d) + expected_result[:, :, -1] = 1.0 + self.assertAllClose(expected_result, activations.sparsemax(x_3d)) + + # result check with axis=-2 with 2d input + x_2d = np.linspace(1, 12, num=12).reshape(-1, 2) + expected_result = np.zeros_like(x_2d) + expected_result[-1, :] = 1.0 + self.assertAllClose( + expected_result, activations.sparsemax(x_2d, axis=-2) + ) + + # result check with axis=-2 with 3d input + x_3d = np.linspace(1, 12, num=12).reshape(-1, 1, 3) + expected_result = np.ones_like(x_3d) + self.assertAllClose( + expected_result, activations.sparsemax(x_3d, axis=-2) + ) + + # result check with axis=-3 with 3d input + x_3d = np.linspace(1, 12, num=12).reshape(-1, 1, 3) + expected_result = np.zeros_like(x_3d) + expected_result[-1, :, :] = 1.0 + self.assertAllClose( + expected_result, activations.sparsemax(x_3d, axis=-3) + ) + + # result check with axis=-3 with 4d input + x_4d = np.linspace(1, 12, num=12).reshape(-1, 1, 1, 2) + expected_result = np.ones_like(x_4d) + self.assertAllClose( + expected_result, activations.sparsemax(x_4d, axis=-3) + ) + + def test_get_method(self): + obj = activations.get("relu") + self.assertEqual(obj, activations.relu) + + obj = activations.get(None) + self.assertEqual(obj, activations.linear) + + with self.assertRaises(ValueError): + activations.get("typo") diff --git a/keras/src/api_export.py b/keras/src/api_export.py new file mode 100644 index 000000000000..76d007dc2af0 --- /dev/null +++ b/keras/src/api_export.py @@ -0,0 +1,49 @@ +try: + import namex +except ImportError: + namex = None + + +# These dicts reference "canonical names" only +# (i.e. the first name an object was registered with). +REGISTERED_NAMES_TO_OBJS = {} +REGISTERED_OBJS_TO_NAMES = {} + + +def register_internal_serializable(path, symbol): + global REGISTERED_NAMES_TO_OBJS + if isinstance(path, (list, tuple)): + name = path[0] + else: + name = path + REGISTERED_NAMES_TO_OBJS[name] = symbol + REGISTERED_OBJS_TO_NAMES[symbol] = name + + +def get_symbol_from_name(name): + return REGISTERED_NAMES_TO_OBJS.get(name, None) + + +def get_name_from_symbol(symbol): + return REGISTERED_OBJS_TO_NAMES.get(symbol, None) + + +if namex: + + class keras_export(namex.export): + def __init__(self, path): + super().__init__(package="keras", path=path) + + def __call__(self, symbol): + register_internal_serializable(self.path, symbol) + return super().__call__(symbol) + +else: + + class keras_export: + def __init__(self, path): + self.path = path + + def __call__(self, symbol): + register_internal_serializable(self.path, symbol) + return symbol diff --git a/keras/src/applications/__init__.py b/keras/src/applications/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/keras/src/applications/applications_test.py b/keras/src/applications/applications_test.py new file mode 100644 index 000000000000..c43627e261e2 --- /dev/null +++ b/keras/src/applications/applications_test.py @@ -0,0 +1,294 @@ +import os + +import numpy as np +import pytest +from absl.testing import parameterized + +from keras.src import backend +from keras.src import testing +from keras.src.applications import convnext +from keras.src.applications import densenet +from keras.src.applications import efficientnet +from keras.src.applications import efficientnet_v2 +from keras.src.applications import inception_resnet_v2 +from keras.src.applications import inception_v3 +from keras.src.applications import mobilenet +from keras.src.applications import mobilenet_v2 +from keras.src.applications import mobilenet_v3 +from keras.src.applications import nasnet +from keras.src.applications import resnet +from keras.src.applications import resnet_v2 +from keras.src.applications import vgg16 +from keras.src.applications import vgg19 +from keras.src.applications import xception +from keras.src.layers import Conv2D +from keras.src.layers import Input +from keras.src.saving import serialization_lib +from keras.src.utils import file_utils +from keras.src.utils import image_utils + +try: + import PIL +except ImportError: + PIL = None + +MODEL_LIST = [ + # vgg + (vgg16.VGG16, 512, vgg16), + (vgg19.VGG19, 512, vgg19), + # xception + (xception.Xception, 2048, xception), + # inception + (inception_v3.InceptionV3, 2048, inception_v3), + (inception_resnet_v2.InceptionResNetV2, 1536, inception_resnet_v2), + # mobilenet + (mobilenet.MobileNet, 1024, mobilenet), + (mobilenet_v2.MobileNetV2, 1280, mobilenet_v2), + (mobilenet_v3.MobileNetV3Small, 576, mobilenet_v3), + (mobilenet_v3.MobileNetV3Large, 960, mobilenet_v3), + # efficientnet + (efficientnet.EfficientNetB0, 1280, efficientnet), + (efficientnet.EfficientNetB1, 1280, efficientnet), + (efficientnet.EfficientNetB2, 1408, efficientnet), + (efficientnet.EfficientNetB3, 1536, efficientnet), + (efficientnet.EfficientNetB4, 1792, efficientnet), + (efficientnet.EfficientNetB5, 2048, efficientnet), + (efficientnet.EfficientNetB6, 2304, efficientnet), + (efficientnet.EfficientNetB7, 2560, efficientnet), + (efficientnet_v2.EfficientNetV2B0, 1280, efficientnet_v2), + (efficientnet_v2.EfficientNetV2B1, 1280, efficientnet_v2), + (efficientnet_v2.EfficientNetV2B2, 1408, efficientnet_v2), + (efficientnet_v2.EfficientNetV2B3, 1536, efficientnet_v2), + (efficientnet_v2.EfficientNetV2S, 1280, efficientnet_v2), + (efficientnet_v2.EfficientNetV2M, 1280, efficientnet_v2), + (efficientnet_v2.EfficientNetV2L, 1280, efficientnet_v2), + # densenet + (densenet.DenseNet121, 1024, densenet), + (densenet.DenseNet169, 1664, densenet), + (densenet.DenseNet201, 1920, densenet), + # convnext + (convnext.ConvNeXtTiny, 768, convnext), + (convnext.ConvNeXtSmall, 768, convnext), + (convnext.ConvNeXtBase, 1024, convnext), + (convnext.ConvNeXtLarge, 1536, convnext), + (convnext.ConvNeXtXLarge, 2048, convnext), + # nasnet + (nasnet.NASNetMobile, 1056, nasnet), + (nasnet.NASNetLarge, 4032, nasnet), + # resnet + (resnet.ResNet50, 2048, resnet), + (resnet.ResNet101, 2048, resnet), + (resnet.ResNet152, 2048, resnet), + (resnet_v2.ResNet50V2, 2048, resnet_v2), + (resnet_v2.ResNet101V2, 2048, resnet_v2), + (resnet_v2.ResNet152V2, 2048, resnet_v2), +] +MODELS_UNSUPPORTED_CHANNELS_FIRST = ["ConvNeXt", "DenseNet", "NASNet"] + +# Add names for `named_parameters`, and add each data format for each model +test_parameters = [ + ( + "{}_{}".format(model[0].__name__, image_data_format), + *model, + image_data_format, + ) + for image_data_format in ["channels_first", "channels_last"] + for model in MODEL_LIST +] + + +def _get_elephant(target_size): + # For models that don't include a Flatten step, + # the default is to accept variable-size inputs + # even when loading ImageNet weights (since it is possible). + # In this case, default to 299x299. + TEST_IMAGE_PATH = ( + "https://storage.googleapis.com/tensorflow/" + "keras-applications/tests/elephant.jpg" + ) + + if target_size[0] is None: + target_size = (299, 299) + test_image = file_utils.get_file("elephant.jpg", TEST_IMAGE_PATH) + img = image_utils.load_img(test_image, target_size=tuple(target_size)) + x = image_utils.img_to_array(img) + return np.expand_dims(x, axis=0) + + +@pytest.mark.skipif( + os.environ.get("SKIP_APPLICATIONS_TESTS"), + reason="Env variable set to skip.", +) +@pytest.mark.requires_trainable_backend +class ApplicationsTest(testing.TestCase): + @classmethod + def setUpClass(cls): + cls.original_image_data_format = backend.image_data_format() + + @classmethod + def tearDownClass(cls): + backend.set_image_data_format(cls.original_image_data_format) + + def skip_if_invalid_image_data_format_for_model( + self, app, image_data_format + ): + does_not_support_channels_first = any( + [ + unsupported_name.lower() in app.__name__.lower() + for unsupported_name in MODELS_UNSUPPORTED_CHANNELS_FIRST + ] + ) + if ( + image_data_format == "channels_first" + and does_not_support_channels_first + ): + self.skipTest( + "{} does not support channels first".format(app.__name__) + ) + + @parameterized.named_parameters(test_parameters) + def test_application_notop_variable_input_channels( + self, app, last_dim, _, image_data_format + ): + if app == nasnet.NASNetMobile and backend.backend() == "torch": + self.skipTest( + "NASNetMobile pretrained incorrect with torch backend." + ) + self.skip_if_invalid_image_data_format_for_model(app, image_data_format) + backend.set_image_data_format(image_data_format) + + # Test compatibility with 1 channel + if image_data_format == "channels_first": + input_shape = (1, None, None) + correct_output_shape = [None, last_dim, None, None] + else: + input_shape = (None, None, 1) + correct_output_shape = [None, None, None, last_dim] + + model = app(weights=None, include_top=False, input_shape=input_shape) + output_shape = list(model.outputs[0].shape) + self.assertEqual(output_shape, correct_output_shape) + + # Test compatibility with 4 channels + if image_data_format == "channels_first": + input_shape = (4, None, None) + else: + input_shape = (None, None, 4) + model = app(weights=None, include_top=False, input_shape=input_shape) + output_shape = list(model.outputs[0].shape) + self.assertEqual(output_shape, correct_output_shape) + + @parameterized.named_parameters(test_parameters) + @pytest.mark.skipif(PIL is None, reason="Requires PIL.") + def test_application_base(self, app, _, app_module, image_data_format): + import tensorflow as tf + + if app == nasnet.NASNetMobile and backend.backend() == "torch": + self.skipTest( + "NASNetMobile pretrained incorrect with torch backend." + ) + if ( + image_data_format == "channels_first" + and len(tf.config.list_physical_devices("GPU")) == 0 + and backend.backend() == "tensorflow" + ): + self.skipTest( + "Conv2D doesn't support channels_first using CPU with " + "tensorflow backend" + ) + self.skip_if_invalid_image_data_format_for_model(app, image_data_format) + backend.set_image_data_format(image_data_format) + + # Can be instantiated with default arguments + model = app(weights="imagenet") + + # Can run a correct inference on a test image + if image_data_format == "channels_first": + shape = model.input_shape[2:4] + else: + shape = model.input_shape[1:3] + x = _get_elephant(shape) + + x = app_module.preprocess_input(x) + preds = model.predict(x) + names = [p[1] for p in app_module.decode_predictions(preds)[0]] + # Test correct label is in top 3 (weak correctness test). + self.assertIn("African_elephant", names[:3]) + + # Can be serialized and deserialized + config = serialization_lib.serialize_keras_object(model) + reconstructed_model = serialization_lib.deserialize_keras_object(config) + self.assertEqual(len(model.weights), len(reconstructed_model.weights)) + + @parameterized.named_parameters(test_parameters) + def test_application_notop_custom_input_shape( + self, app, last_dim, _, image_data_format + ): + if app == nasnet.NASNetMobile and backend.backend() == "torch": + self.skipTest( + "NASNetMobile pretrained incorrect with torch backend." + ) + self.skip_if_invalid_image_data_format_for_model(app, image_data_format) + backend.set_image_data_format(image_data_format) + + if image_data_format == "channels_first": + input_shape = (3, 123, 123) + last_dim_axis = 1 + else: + input_shape = (123, 123, 3) + last_dim_axis = -1 + model = app(weights=None, include_top=False, input_shape=input_shape) + output_shape = list(model.outputs[0].shape) + self.assertEqual(output_shape[last_dim_axis], last_dim) + + @parameterized.named_parameters(test_parameters) + def test_application_notop_custom_input_tensor( + self, app, last_dim, _, image_data_format + ): + if app == nasnet.NASNetMobile and backend.backend() == "torch": + self.skipTest( + "NASNetMobile pretrained incorrect with torch backend." + ) + self.skip_if_invalid_image_data_format_for_model(app, image_data_format) + backend.set_image_data_format(image_data_format) + + if image_data_format == "channels_first": + input_shape = (4, 123, 123) + last_dim_axis = 1 + else: + input_shape = (123, 123, 4) + last_dim_axis = -1 + + inputs_custom = Input(shape=input_shape, name="custom_input") + inputs_custom = Conv2D(3, (2, 2), padding="valid", strides=(2, 2))( + inputs_custom + ) + model = app(weights=None, include_top=False, input_tensor=inputs_custom) + output_shape = list(model.outputs[0].shape) + self.assertEqual(output_shape[last_dim_axis], last_dim) + + @parameterized.named_parameters(test_parameters) + def test_application_pooling(self, app, last_dim, _, image_data_format): + if app == nasnet.NASNetMobile and backend.backend() == "torch": + self.skipTest( + "NASNetMobile pretrained incorrect with torch backend." + ) + self.skip_if_invalid_image_data_format_for_model(app, image_data_format) + backend.set_image_data_format(image_data_format) + + model = app(weights=None, include_top=False, pooling="max") + output_shape = list(model.outputs[0].shape) + self.assertEqual(output_shape, [None, last_dim]) + + @parameterized.named_parameters(test_parameters) + def test_application_classifier_activation(self, app, *_): + if app == nasnet.NASNetMobile and backend.backend() == "torch": + self.skipTest( + "NASNetMobile pretrained incorrect with torch backend." + ) + + model = app( + weights=None, include_top=True, classifier_activation="softmax" + ) + last_layer_act = model.layers[-1].activation.__name__ + self.assertEqual(last_layer_act, "softmax") diff --git a/keras/src/applications/convnext.py b/keras/src/applications/convnext.py new file mode 100644 index 000000000000..39e9b52fa75d --- /dev/null +++ b/keras/src/applications/convnext.py @@ -0,0 +1,789 @@ +import numpy as np + +from keras.src import backend +from keras.src import initializers +from keras.src import layers +from keras.src import ops +from keras.src import random +from keras.src.api_export import keras_export +from keras.src.applications import imagenet_utils +from keras.src.layers.layer import Layer +from keras.src.models import Functional +from keras.src.models import Sequential +from keras.src.ops import operation_utils +from keras.src.utils import file_utils + +BASE_WEIGHTS_PATH = ( + "https://storage.googleapis.com/tensorflow/keras-applications/convnext/" +) + +WEIGHTS_HASHES = { + "convnext_tiny": ( + "8ae6e78ce2933352b1ef4008e6dd2f17bc40771563877d156bc6426c7cf503ff", + "d547c096cabd03329d7be5562c5e14798aa39ed24b474157cef5e85ab9e49ef1", + ), + "convnext_small": ( + "ce1277d8f1ee5a0ef0e171469089c18f5233860ceaf9b168049cb9263fd7483c", + "6fc8009faa2f00c1c1dfce59feea9b0745eb260a7dd11bee65c8e20843da6eab", + ), + "convnext_base": ( + "52cbb006d3dadd03f6e095a8ca1aca47aecdd75acb4bc74bce1f5c695d0086e6", + "40a20c5548a5e9202f69735ecc06c990e6b7c9d2de39f0361e27baeb24cb7c45", + ), + "convnext_large": ( + "070c5ed9ed289581e477741d3b34beffa920db8cf590899d6d2c67fba2a198a6", + "96f02b6f0753d4f543261bc9d09bed650f24dd6bc02ddde3066135b63d23a1cd", + ), + "convnext_xlarge": ( + "c1f5ccab661354fc3a79a10fa99af82f0fbf10ec65cb894a3ae0815f17a889ee", + "de3f8a54174130e0cecdc71583354753d557fcf1f4487331558e2a16ba0cfe05", + ), +} + + +MODEL_CONFIGS = { + "tiny": { + "depths": [3, 3, 9, 3], + "projection_dims": [96, 192, 384, 768], + "default_size": 224, + }, + "small": { + "depths": [3, 3, 27, 3], + "projection_dims": [96, 192, 384, 768], + "default_size": 224, + }, + "base": { + "depths": [3, 3, 27, 3], + "projection_dims": [128, 256, 512, 1024], + "default_size": 224, + }, + "large": { + "depths": [3, 3, 27, 3], + "projection_dims": [192, 384, 768, 1536], + "default_size": 224, + }, + "xlarge": { + "depths": [3, 3, 27, 3], + "projection_dims": [256, 512, 1024, 2048], + "default_size": 224, + }, +} + +BASE_DOCSTRING = """Instantiates the {name} architecture. + +References: +- [A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545) +(CVPR 2022) + +For image classification use cases, see +[this page for detailed examples]( +https://keras.io/api/applications/#usage-examples-for-image-classification-models). +For transfer learning use cases, make sure to read the +[guide to transfer learning & fine-tuning]( +https://keras.io/guides/transfer_learning/). + +The `base`, `large`, and `xlarge` models were first pre-trained on the +ImageNet-21k dataset and then fine-tuned on the ImageNet-1k dataset. The +pre-trained parameters of the models were assembled from the +[official repository](https://github.com/facebookresearch/ConvNeXt). To get a +sense of how these parameters were converted to Keras compatible parameters, +please refer to +[this repository](https://github.com/sayakpaul/keras-convnext-conversion). + +Note: Each Keras Application expects a specific kind of input preprocessing. +For ConvNeXt, preprocessing is included in the model using a `Normalization` +layer. ConvNeXt models expect their inputs to be float or uint8 tensors of +pixels with values in the [0-255] range. + +When calling the `summary()` method after instantiating a ConvNeXt model, +prefer setting the `expand_nested` argument `summary()` to `True` to better +investigate the instantiated model. + +Args: + include_top: Whether to include the fully-connected + layer at the top of the network. Defaults to `True`. + weights: One of `None` (random initialization), + `"imagenet"` (pre-training on ImageNet-1k), or the path to the weights + file to be loaded. Defaults to `"imagenet"`. + input_tensor: Optional Keras tensor + (i.e. output of `layers.Input()`) + to use as image input for the model. + input_shape: Optional shape tuple, only to be specified + if `include_top` is `False`. + It should have exactly 3 inputs channels. + pooling: Optional pooling mode for feature extraction + when `include_top` is `False`. Defaults to None. + - `None` means that the output of the model will be + the 4D tensor output of the last convolutional layer. + - `avg` means that global average pooling + will be applied to the output of the + last convolutional layer, and thus + the output of the model will be a 2D tensor. + - `max` means that global max pooling will + be applied. + classes: Optional number of classes to classify images + into, only to be specified if `include_top` is `True`, and + if no `weights` argument is specified. Defaults to 1000 (number of + ImageNet classes). + classifier_activation: A `str` or callable. The activation function to use + on the "top" layer. Ignored unless `include_top=True`. Set + `classifier_activation=None` to return the logits of the "top" layer. + Defaults to `"softmax"`. + When loading pretrained weights, `classifier_activation` can only + be `None` or `"softmax"`. + name: The name of the model (string). + +Returns: + A model instance. +""" + + +class StochasticDepth(Layer): + """Stochastic Depth module. + + It performs batch-wise dropping rather than sample-wise. In libraries like + `timm`, it's similar to `DropPath` layers that drops residual paths + sample-wise. + + References: + - https://github.com/rwightman/pytorch-image-models + + Args: + drop_path_rate (float): Probability of dropping paths. Should be within + [0, 1]. + + Returns: + Tensor either with the residual path dropped or kept. + """ + + def __init__(self, drop_path_rate, **kwargs): + super().__init__(**kwargs) + self.drop_path_rate = drop_path_rate + + def call(self, x, training=None): + if training: + keep_prob = 1 - self.drop_path_rate + shape = (ops.shape(x)[0],) + (1,) * (len(ops.shape(x)) - 1) + random_tensor = keep_prob + random.uniform(shape, 0, 1) + random_tensor = ops.floor(random_tensor) + return (x / keep_prob) * random_tensor + return x + + def get_config(self): + config = super().get_config() + config.update({"drop_path_rate": self.drop_path_rate}) + return config + + +class LayerScale(Layer): + """Layer scale module. + + References: + + - https://arxiv.org/abs/2103.17239 + + Args: + init_values (float): Initial value for layer scale. Should be within + [0, 1]. + projection_dim (int): Projection dimensionality. + + Returns: + Tensor multiplied to the scale. + """ + + def __init__(self, init_values, projection_dim, **kwargs): + super().__init__(**kwargs) + self.init_values = init_values + self.projection_dim = projection_dim + + def build(self, _): + self.gamma = self.add_weight( + shape=(self.projection_dim,), + initializer=initializers.Constant(self.init_values), + trainable=True, + ) + + def call(self, x): + return x * self.gamma + + def get_config(self): + config = super().get_config() + config.update( + { + "init_values": self.init_values, + "projection_dim": self.projection_dim, + } + ) + return config + + +def ConvNeXtBlock( + projection_dim, drop_path_rate=0.0, layer_scale_init_value=1e-6, name=None +): + """ConvNeXt block. + + References: + - https://arxiv.org/abs/2201.03545 + - https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py + + Notes: + In the original ConvNeXt implementation (linked above), the authors use + `Dense` layers for pointwise convolutions for increased efficiency. + Following that, this implementation also uses the same. + + Args: + projection_dim (int): Number of filters for convolution layers. In the + ConvNeXt paper, this is referred to as projection dimension. + drop_path_rate (float): Probability of dropping paths. Should be within + [0, 1]. + layer_scale_init_value (float): Layer scale value. + Should be a small float number. + name: name to path to the keras layer. + + Returns: + A function representing a ConvNeXtBlock block. + """ + if name is None: + name = f"prestem{str(backend.get_uid('prestem'))}" + + def apply(inputs): + x = inputs + + x = layers.Conv2D( + filters=projection_dim, + kernel_size=7, + padding="same", + groups=projection_dim, + name=f"{name}_depthwise_conv", + )(x) + x = layers.LayerNormalization(epsilon=1e-6, name=f"{name}_layernorm")(x) + x = layers.Dense(4 * projection_dim, name=f"{name}_pointwise_conv_1")(x) + x = layers.Activation("gelu", name=f"{name}_gelu")(x) + x = layers.Dense(projection_dim, name=f"{name}_pointwise_conv_2")(x) + + if layer_scale_init_value is not None: + x = LayerScale( + layer_scale_init_value, + projection_dim, + name=f"{name}_layer_scale", + )(x) + if drop_path_rate: + layer = StochasticDepth( + drop_path_rate, name=f"{name}_stochastic_depth" + ) + else: + layer = layers.Activation("linear", name=f"{name}_identity") + + return inputs + layer(x) + + return apply + + +def PreStem(name=None): + """Normalizes inputs with ImageNet-1k mean and std.""" + if name is None: + name = "prestem{0}".format(str(backend.get_uid("prestem"))) + + def apply(x): + x = layers.Normalization( + mean=[0.485 * 255, 0.456 * 255, 0.406 * 255], + variance=[ + (0.229 * 255) ** 2, + (0.224 * 255) ** 2, + (0.225 * 255) ** 2, + ], + name=f"{name}_prestem_normalization", + )(x) + return x + + return apply + + +def Head(num_classes=1000, classifier_activation=None, name=None): + """Implementation of classification head of ConvNeXt. + + Args: + num_classes: number of classes for Dense layer + classifier_activation: activation function for the Dense layer + name: name prefix + + Returns: + Classification head function. + """ + if name is None: + name = str(backend.get_uid("head")) + + def apply(x): + x = layers.GlobalAveragePooling2D(name=f"{name}_head_gap")(x) + x = layers.LayerNormalization( + epsilon=1e-6, name=f"{name}_head_layernorm" + )(x) + x = layers.Dense( + num_classes, + activation=classifier_activation, + name=f"{name}_head_dense", + )(x) + return x + + return apply + + +def ConvNeXt( + depths, + projection_dims, + drop_path_rate=0.0, + layer_scale_init_value=1e-6, + default_size=224, + name="convnext", + include_preprocessing=True, + include_top=True, + weights=None, + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000, + classifier_activation="softmax", + weights_name=None, +): + """Instantiates ConvNeXt architecture given specific configuration. + + Args: + depths: An iterable containing depths for each individual stages. + projection_dims: An iterable containing output number of channels of + each individual stages. + drop_path_rate: Stochastic depth probability. If 0.0, then stochastic + depth won't be used. + layer_scale_init_value: Layer scale coefficient. If 0.0, layer scaling + won't be used. + default_size: Default input image size. + name: An optional name for the model. + include_preprocessing: boolean denoting whether to + include preprocessing in the model. + When `weights="imagenet"` this should always be `True`. + But for other models (e.g., randomly initialized) you should set it + to `False` and apply preprocessing to data accordingly. + include_top: Boolean denoting whether to include classification + head to the model. + weights: one of `None` (random initialization), `"imagenet"` + (pre-training on ImageNet-1k), + or the path to the weights file to be loaded. + input_tensor: optional Keras tensor (i.e. output of `layers.Input()`) to + use as image input for the model. + input_shape: optional shape tuple, only to be specified if `include_top` + is `False`. It should have exactly 3 inputs channels. + pooling: optional pooling mode for feature extraction when `include_top` + is `False`. + - `None` means that the output of the model will be + the 4D tensor output of the last convolutional layer. + - `avg` means that global average pooling will be applied + to the output of the last convolutional layer, + and thus the output of the model will be a 2D tensor. + - `max` means that global max pooling will be applied. + classes: optional number of classes to classify images into, + only to be specified if `include_top` is `True`, + and if no `weights` argument is specified. + classifier_activation: A `str` or callable. + The activation function to use + on the "top" layer. Ignored unless `include_top=True`. + Set `classifier_activation=None` to return the logits + of the "top" layer. + + Returns: + A model instance. + """ + if backend.image_data_format() == "channels_first": + raise ValueError( + "ConvNeXt does not support the `channels_first` image data " + "format. Switch to `channels_last` by editing your local " + "config file at ~/.keras/keras.json" + ) + if not (weights in {"imagenet", None} or file_utils.exists(weights)): + raise ValueError( + "The `weights` argument should be either " + "`None` (random initialization), `imagenet` " + "(pre-training on ImageNet), " + "or the path to the weights file to be loaded." + ) + + if weights == "imagenet" and include_top and classes != 1000: + raise ValueError( + 'If using `weights="imagenet"` with `include_top=True`, ' + "`classes` should be 1000. " + f"Received classes={classes}" + ) + + # Determine proper input shape. + input_shape = imagenet_utils.obtain_input_shape( + input_shape, + default_size=default_size, + min_size=32, + data_format=backend.image_data_format(), + require_flatten=include_top, + weights=weights, + ) + + if input_tensor is None: + img_input = layers.Input(shape=input_shape) + else: + if not backend.is_keras_tensor(input_tensor): + img_input = layers.Input(tensor=input_tensor, shape=input_shape) + else: + img_input = input_tensor + + if input_tensor is not None: + inputs = operation_utils.get_source_inputs(input_tensor)[0] + x = input_tensor + else: + inputs = img_input + x = inputs + + if include_preprocessing: + channel_axis = ( + 3 if backend.image_data_format() == "channels_last" else 1 + ) + num_channels = input_shape[channel_axis - 1] + if num_channels == 3: + x = PreStem(name=name)(x) + + # Stem block. + stem = Sequential( + [ + layers.Conv2D( + projection_dims[0], + kernel_size=4, + strides=4, + name=f"{name}_stem_conv", + ), + layers.LayerNormalization( + epsilon=1e-6, name=f"{name}_stem_layernorm" + ), + ], + name=f"{name}_stem", + ) + + # Downsampling blocks. + downsample_layers = [] + downsample_layers.append(stem) + + num_downsample_layers = 3 + for i in range(num_downsample_layers): + downsample_layer = Sequential( + [ + layers.LayerNormalization( + epsilon=1e-6, + name=f"{name}_downsampling_layernorm_{i}", + ), + layers.Conv2D( + projection_dims[i + 1], + kernel_size=2, + strides=2, + name=f"{name}_downsampling_conv_{i}", + ), + ], + name=f"{name}_downsampling_block_{i}", + ) + downsample_layers.append(downsample_layer) + + # Stochastic depth schedule. + # This is referred from the original ConvNeXt codebase: + # https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py#L86 + depth_drop_rates = [ + float(x) for x in np.linspace(0.0, drop_path_rate, sum(depths)) + ] + + # First apply downsampling blocks and then apply ConvNeXt stages. + cur = 0 + + num_convnext_blocks = 4 + for i in range(num_convnext_blocks): + x = downsample_layers[i](x) + for j in range(depths[i]): + x = ConvNeXtBlock( + projection_dim=projection_dims[i], + drop_path_rate=depth_drop_rates[cur + j], + layer_scale_init_value=layer_scale_init_value, + name=name + f"_stage_{i}_block_{j}", + )(x) + cur += depths[i] + + if include_top: + imagenet_utils.validate_activation(classifier_activation, weights) + x = Head( + num_classes=classes, + classifier_activation=classifier_activation, + name=name, + )(x) + + else: + if pooling == "avg": + x = layers.GlobalAveragePooling2D()(x) + elif pooling == "max": + x = layers.GlobalMaxPooling2D()(x) + x = layers.LayerNormalization(epsilon=1e-6)(x) + + model = Functional(inputs=inputs, outputs=x, name=name) + + # Validate weights before requesting them from the API + if weights == "imagenet": + expected_config = MODEL_CONFIGS[weights_name.split("convnext_")[-1]] + if ( + depths != expected_config["depths"] + or projection_dims != expected_config["projection_dims"] + ): + raise ValueError( + f"Architecture configuration does not match {weights_name} " + f"variant. When using pre-trained weights, the model " + f"architecture must match the pre-trained configuration " + f"exactly. Expected depths: {expected_config['depths']}, " + f"got: {depths}. Expected projection_dims: " + f"{expected_config['projection_dims']}, got: {projection_dims}." + ) + + if weights_name not in name: + raise ValueError( + f'Model name "{name}" does not match weights variant ' + f'"{weights_name}". When using imagenet weights, model name ' + f'must contain the weights variant (e.g., "convnext_' + f'{weights_name.split("convnext_")[-1]}").' + ) + + # Load weights. + if weights == "imagenet": + if include_top: + file_suffix = ".h5" + file_hash = WEIGHTS_HASHES[weights_name][0] + else: + file_suffix = "_notop.h5" + file_hash = WEIGHTS_HASHES[weights_name][1] + file_name = name + file_suffix + weights_path = file_utils.get_file( + file_name, + BASE_WEIGHTS_PATH + file_name, + cache_subdir="models", + file_hash=file_hash, + ) + model.load_weights(weights_path) + elif weights is not None: + model.load_weights(weights) + + return model + + +## Instantiating variants ## + + +@keras_export( + [ + "keras.applications.convnext.ConvNeXtTiny", + "keras.applications.ConvNeXtTiny", + ] +) +def ConvNeXtTiny( + include_top=True, + include_preprocessing=True, + weights="imagenet", + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000, + classifier_activation="softmax", + name="convnext_tiny", +): + return ConvNeXt( + weights_name="convnext_tiny", + depths=MODEL_CONFIGS["tiny"]["depths"], + projection_dims=MODEL_CONFIGS["tiny"]["projection_dims"], + drop_path_rate=0.0, + layer_scale_init_value=1e-6, + default_size=MODEL_CONFIGS["tiny"]["default_size"], + name=name, + include_top=include_top, + include_preprocessing=include_preprocessing, + weights=weights, + input_tensor=input_tensor, + input_shape=input_shape, + pooling=pooling, + classes=classes, + classifier_activation=classifier_activation, + ) + + +@keras_export( + [ + "keras.applications.convnext.ConvNeXtSmall", + "keras.applications.ConvNeXtSmall", + ] +) +def ConvNeXtSmall( + include_top=True, + include_preprocessing=True, + weights="imagenet", + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000, + classifier_activation="softmax", + name="convnext_small", +): + return ConvNeXt( + weights_name="convnext_small", + depths=MODEL_CONFIGS["small"]["depths"], + projection_dims=MODEL_CONFIGS["small"]["projection_dims"], + drop_path_rate=0.0, + layer_scale_init_value=1e-6, + default_size=MODEL_CONFIGS["small"]["default_size"], + name=name, + include_top=include_top, + include_preprocessing=include_preprocessing, + weights=weights, + input_tensor=input_tensor, + input_shape=input_shape, + pooling=pooling, + classes=classes, + classifier_activation=classifier_activation, + ) + + +@keras_export( + [ + "keras.applications.convnext.ConvNeXtBase", + "keras.applications.ConvNeXtBase", + ] +) +def ConvNeXtBase( + include_top=True, + include_preprocessing=True, + weights="imagenet", + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000, + classifier_activation="softmax", + name="convnext_base", +): + return ConvNeXt( + weights_name="convnext_base", + depths=MODEL_CONFIGS["base"]["depths"], + projection_dims=MODEL_CONFIGS["base"]["projection_dims"], + drop_path_rate=0.0, + layer_scale_init_value=1e-6, + default_size=MODEL_CONFIGS["base"]["default_size"], + name=name, + include_top=include_top, + include_preprocessing=include_preprocessing, + weights=weights, + input_tensor=input_tensor, + input_shape=input_shape, + pooling=pooling, + classes=classes, + classifier_activation=classifier_activation, + ) + + +@keras_export( + [ + "keras.applications.convnext.ConvNeXtLarge", + "keras.applications.ConvNeXtLarge", + ] +) +def ConvNeXtLarge( + include_top=True, + include_preprocessing=True, + weights="imagenet", + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000, + classifier_activation="softmax", + name="convnext_large", +): + return ConvNeXt( + weights_name="convnext_large", + depths=MODEL_CONFIGS["large"]["depths"], + projection_dims=MODEL_CONFIGS["large"]["projection_dims"], + drop_path_rate=0.0, + layer_scale_init_value=1e-6, + default_size=MODEL_CONFIGS["large"]["default_size"], + name=name, + include_top=include_top, + include_preprocessing=include_preprocessing, + weights=weights, + input_tensor=input_tensor, + input_shape=input_shape, + pooling=pooling, + classes=classes, + classifier_activation=classifier_activation, + ) + + +@keras_export( + [ + "keras.applications.convnext.ConvNeXtXLarge", + "keras.applications.ConvNeXtXLarge", + ] +) +def ConvNeXtXLarge( + include_top=True, + include_preprocessing=True, + weights="imagenet", + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000, + classifier_activation="softmax", + name="convnext_xlarge", +): + return ConvNeXt( + weights_name="convnext_xlarge", + depths=MODEL_CONFIGS["xlarge"]["depths"], + projection_dims=MODEL_CONFIGS["xlarge"]["projection_dims"], + drop_path_rate=0.0, + layer_scale_init_value=1e-6, + default_size=MODEL_CONFIGS["xlarge"]["default_size"], + name=name, + include_top=include_top, + include_preprocessing=include_preprocessing, + weights=weights, + input_tensor=input_tensor, + input_shape=input_shape, + pooling=pooling, + classes=classes, + classifier_activation=classifier_activation, + ) + + +ConvNeXtTiny.__doc__ = BASE_DOCSTRING.format(name="ConvNeXtTiny") +ConvNeXtSmall.__doc__ = BASE_DOCSTRING.format(name="ConvNeXtSmall") +ConvNeXtBase.__doc__ = BASE_DOCSTRING.format(name="ConvNeXtBase") +ConvNeXtLarge.__doc__ = BASE_DOCSTRING.format(name="ConvNeXtLarge") +ConvNeXtXLarge.__doc__ = BASE_DOCSTRING.format(name="ConvNeXtXLarge") + + +@keras_export("keras.applications.convnext.preprocess_input") +def preprocess_input(x, data_format=None): + """A placeholder method for backward compatibility. + + The preprocessing logic has been included in the convnext model + implementation. Users are no longer required to call this method to + normalize the input data. This method does nothing and only kept as a + placeholder to align the API surface between old and new version of model. + + Args: + x: A floating point `numpy.array` or a tensor. + data_format: Optional data format of the image tensor/array. Defaults to + None, in which case the global setting + `keras.backend.image_data_format()` is used + (unless you changed it, it defaults to `"channels_last"`).{mode} + + Returns: + Unchanged `numpy.array` or tensor. + """ + return x + + +@keras_export("keras.applications.convnext.decode_predictions") +def decode_predictions(preds, top=5): + return imagenet_utils.decode_predictions(preds, top=top) + + +decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__ diff --git a/keras/src/applications/densenet.py b/keras/src/applications/densenet.py new file mode 100644 index 000000000000..9021f2ba0093 --- /dev/null +++ b/keras/src/applications/densenet.py @@ -0,0 +1,492 @@ +from keras.src import backend +from keras.src import layers +from keras.src.api_export import keras_export +from keras.src.applications import imagenet_utils +from keras.src.models import Functional +from keras.src.ops import operation_utils +from keras.src.utils import file_utils + +BASE_WEIGHTS_PATH = ( + "https://storage.googleapis.com/tensorflow/keras-applications/densenet/" +) +DENSENET121_WEIGHT_PATH = ( + f"{BASE_WEIGHTS_PATH}densenet121_weights_tf_dim_ordering_tf_kernels.h5" +) +DENSENET121_WEIGHT_PATH_NO_TOP = ( + f"{BASE_WEIGHTS_PATH}" + "densenet121_weights_tf_dim_ordering_tf_kernels_notop.h5" +) +DENSENET169_WEIGHT_PATH = ( + f"{BASE_WEIGHTS_PATH}densenet169_weights_tf_dim_ordering_tf_kernels.h5" +) +DENSENET169_WEIGHT_PATH_NO_TOP = ( + f"{BASE_WEIGHTS_PATH}" + "densenet169_weights_tf_dim_ordering_tf_kernels_notop.h5" +) +DENSENET201_WEIGHT_PATH = ( + f"{BASE_WEIGHTS_PATH}densenet201_weights_tf_dim_ordering_tf_kernels.h5" +) +DENSENET201_WEIGHT_PATH_NO_TOP = ( + f"{BASE_WEIGHTS_PATH}" + "densenet201_weights_tf_dim_ordering_tf_kernels_notop.h5" +) + + +def dense_block(x, blocks, name): + """A dense block. + + Args: + x: input tensor. + blocks: integer, the number of building blocks. + name: string, block label. + + Returns: + Output tensor for the block. + """ + for i in range(blocks): + x = conv_block(x, 32, name=f"{name}_block{i + 1}") + return x + + +def transition_block(x, reduction, name): + """A transition block. + + Args: + x: input tensor. + reduction: float, compression rate at transition layers. + name: string, block label. + + Returns: + Output tensor for the block. + """ + bn_axis = 3 if backend.image_data_format() == "channels_last" else 1 + x = layers.BatchNormalization( + axis=bn_axis, epsilon=1.001e-5, name=f"{name}_bn" + )(x) + x = layers.Activation("relu", name=f"{name}_relu")(x) + x = layers.Conv2D( + int(x.shape[bn_axis] * reduction), + 1, + use_bias=False, + name=f"{name}_conv", + )(x) + x = layers.AveragePooling2D(2, strides=2, name=f"{name}_pool")(x) + return x + + +def conv_block(x, growth_rate, name): + """A building block for a dense block. + + Args: + x: input tensor. + growth_rate: float, growth rate at dense layers. + name: string, block label. + + Returns: + Output tensor for the block. + """ + bn_axis = 3 if backend.image_data_format() == "channels_last" else 1 + x1 = layers.BatchNormalization( + axis=bn_axis, epsilon=1.001e-5, name=f"{name}_0_bn" + )(x) + x1 = layers.Activation("relu", name=f"{name}_0_relu")(x1) + x1 = layers.Conv2D( + 4 * growth_rate, 1, use_bias=False, name=f"{name}_1_conv" + )(x1) + x1 = layers.BatchNormalization( + axis=bn_axis, epsilon=1.001e-5, name=f"{name}_1_bn" + )(x1) + x1 = layers.Activation("relu", name=f"{name}_1_relu")(x1) + x1 = layers.Conv2D( + growth_rate, 3, padding="same", use_bias=False, name=f"{name}_2_conv" + )(x1) + x = layers.Concatenate(axis=bn_axis, name=f"{name}_concat")([x, x1]) + return x + + +def DenseNet( + blocks, + include_top=True, + weights="imagenet", + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000, + classifier_activation="softmax", + name="densenet", +): + """Instantiates the DenseNet architecture. + + Reference: + - [Densely Connected Convolutional Networks]( + https://arxiv.org/abs/1608.06993) (CVPR 2017) + + This function returns a Keras image classification model, + optionally loaded with weights pre-trained on ImageNet. + + For image classification use cases, see + [this page for detailed examples]( + https://keras.io/api/applications/#usage-examples-for-image-classification-models). + + For transfer learning use cases, make sure to read the + [guide to transfer learning & fine-tuning]( + https://keras.io/guides/transfer_learning/). + + Note: each Keras Application expects a specific kind of input preprocessing. + For DenseNet, call `keras.applications.densenet.preprocess_input` + on your inputs before passing them to the model. + `densenet.preprocess_input` will scale pixels between 0 and 1 and then + will normalize each channel with respect to the ImageNet + dataset statistics. + + Args: + blocks: numbers of building blocks for the four dense layers. + include_top: whether to include the fully-connected + layer at the top of the network. + weights: one of `None` (random initialization), + `"imagenet"` (pre-training on ImageNet), + or the path to the weights file to be loaded. + input_tensor: optional Keras tensor + (i.e. output of `layers.Input()`) + to use as image input for the model. + input_shape: optional shape tuple, only to be specified + if `include_top` is False (otherwise the input shape + has to be `(224, 224, 3)` + (with `'channels_last'` data format) + or `(3, 224, 224)` (with `'channels_first'` data format). + It should have exactly 3 inputs channels, + and width and height should be no smaller than 32. + E.g. `(200, 200, 3)` would be one valid value. + pooling: optional pooling mode for feature extraction + when `include_top` is `False`. + - `None` means that the output of the model will be + the 4D tensor output of the + last convolutional block. + - `avg` means that global average pooling + will be applied to the output of the + last convolutional block, and thus + the output of the model will be a 2D tensor. + - `max` means that global max pooling will + be applied. + classes: optional number of classes to classify images + into, only to be specified if `include_top` is `True`, and + if no `weights` argument is specified. Defaults to `1000`. + classifier_activation: A `str` or callable. + The activation function to use + on the "top" layer. Ignored unless `include_top=True`. Set + `classifier_activation=None` to return the logits of the "top" + layer. When loading pretrained weights, `classifier_activation` + can only be `None` or `"softmax"`. + name: The name of the model (string). + + Returns: + A model instance. + """ + if backend.image_data_format() == "channels_first": + raise ValueError( + "DenseNet does not support the `channels_first` image data " + "format. Switch to `channels_last` by editing your local " + "config file at ~/.keras/keras.json" + ) + if not (weights in {"imagenet", None} or file_utils.exists(weights)): + raise ValueError( + "The `weights` argument should be either " + "`None` (random initialization), `imagenet` " + "(pre-training on ImageNet), " + "or the path to the weights file to be loaded." + ) + + if weights == "imagenet" and include_top and classes != 1000: + raise ValueError( + 'If using `weights` as `"imagenet"` with `include_top`' + " as true, `classes` should be 1000" + ) + + # Determine proper input shape + input_shape = imagenet_utils.obtain_input_shape( + input_shape, + default_size=224, + min_size=32, + data_format=backend.image_data_format(), + require_flatten=include_top, + weights=weights, + ) + + if input_tensor is None: + img_input = layers.Input(shape=input_shape) + else: + if not backend.is_keras_tensor(input_tensor): + img_input = layers.Input(tensor=input_tensor, shape=input_shape) + else: + img_input = input_tensor + + bn_axis = 3 if backend.image_data_format() == "channels_last" else 1 + + x = layers.ZeroPadding2D(padding=((3, 3), (3, 3)))(img_input) + x = layers.Conv2D(64, 7, strides=2, use_bias=False, name="conv1_conv")(x) + x = layers.BatchNormalization( + axis=bn_axis, epsilon=1.001e-5, name="conv1_bn" + )(x) + x = layers.Activation("relu", name="conv1_relu")(x) + x = layers.ZeroPadding2D(padding=((1, 1), (1, 1)))(x) + x = layers.MaxPooling2D(3, strides=2, name="pool1")(x) + + x = dense_block(x, blocks[0], name="conv2") + x = transition_block(x, 0.5, name="pool2") + x = dense_block(x, blocks[1], name="conv3") + x = transition_block(x, 0.5, name="pool3") + x = dense_block(x, blocks[2], name="conv4") + x = transition_block(x, 0.5, name="pool4") + x = dense_block(x, blocks[3], name="conv5") + + x = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5, name="bn")(x) + x = layers.Activation("relu", name="relu")(x) + + if include_top: + x = layers.GlobalAveragePooling2D(name="avg_pool")(x) + + imagenet_utils.validate_activation(classifier_activation, weights) + x = layers.Dense( + classes, activation=classifier_activation, name="predictions" + )(x) + else: + if pooling == "avg": + x = layers.GlobalAveragePooling2D(name="avg_pool")(x) + elif pooling == "max": + x = layers.GlobalMaxPooling2D(name="max_pool")(x) + + # Ensure that the model takes into account + # any potential predecessors of `input_tensor`. + if input_tensor is not None: + inputs = operation_utils.get_source_inputs(input_tensor) + else: + inputs = img_input + + # Create model. + model = Functional(inputs, x, name=name) + + # Load weights. + if weights == "imagenet": + if include_top: + if blocks == [6, 12, 24, 16]: + weights_path = file_utils.get_file( + "densenet121_weights_tf_dim_ordering_tf_kernels.h5", + DENSENET121_WEIGHT_PATH, + cache_subdir="models", + file_hash="9d60b8095a5708f2dcce2bca79d332c7", + ) + elif blocks == [6, 12, 32, 32]: + weights_path = file_utils.get_file( + "densenet169_weights_tf_dim_ordering_tf_kernels.h5", + DENSENET169_WEIGHT_PATH, + cache_subdir="models", + file_hash="d699b8f76981ab1b30698df4c175e90b", + ) + elif blocks == [6, 12, 48, 32]: + weights_path = file_utils.get_file( + "densenet201_weights_tf_dim_ordering_tf_kernels.h5", + DENSENET201_WEIGHT_PATH, + cache_subdir="models", + file_hash="1ceb130c1ea1b78c3bf6114dbdfd8807", + ) + else: + raise ValueError("weights_path undefined") + else: + if blocks == [6, 12, 24, 16]: + weights_path = file_utils.get_file( + "densenet121_weights_tf_dim_ordering_tf_kernels_notop.h5", + DENSENET121_WEIGHT_PATH_NO_TOP, + cache_subdir="models", + file_hash="30ee3e1110167f948a6b9946edeeb738", + ) + elif blocks == [6, 12, 32, 32]: + weights_path = file_utils.get_file( + "densenet169_weights_tf_dim_ordering_tf_kernels_notop.h5", + DENSENET169_WEIGHT_PATH_NO_TOP, + cache_subdir="models", + file_hash="b8c4d4c20dd625c148057b9ff1c1176b", + ) + elif blocks == [6, 12, 48, 32]: + weights_path = file_utils.get_file( + "densenet201_weights_tf_dim_ordering_tf_kernels_notop.h5", + DENSENET201_WEIGHT_PATH_NO_TOP, + cache_subdir="models", + file_hash="c13680b51ded0fb44dff2d8f86ac8bb1", + ) + else: + raise ValueError("weights_path undefined") + model.load_weights(weights_path) + elif weights is not None: + model.load_weights(weights) + + return model + + +@keras_export( + [ + "keras.applications.densenet.DenseNet121", + "keras.applications.DenseNet121", + ] +) +def DenseNet121( + include_top=True, + weights="imagenet", + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000, + classifier_activation="softmax", + name="densenet121", +): + """Instantiates the Densenet121 architecture.""" + return DenseNet( + [6, 12, 24, 16], + include_top, + weights, + input_tensor, + input_shape, + pooling, + classes, + classifier_activation, + name=name, + ) + + +@keras_export( + [ + "keras.applications.densenet.DenseNet169", + "keras.applications.DenseNet169", + ] +) +def DenseNet169( + include_top=True, + weights="imagenet", + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000, + classifier_activation="softmax", + name="densenet169", +): + """Instantiates the Densenet169 architecture.""" + return DenseNet( + [6, 12, 32, 32], + include_top, + weights, + input_tensor, + input_shape, + pooling, + classes, + classifier_activation, + name=name, + ) + + +@keras_export( + [ + "keras.applications.densenet.DenseNet201", + "keras.applications.DenseNet201", + ] +) +def DenseNet201( + include_top=True, + weights="imagenet", + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000, + classifier_activation="softmax", + name="densenet201", +): + """Instantiates the Densenet201 architecture.""" + return DenseNet( + [6, 12, 48, 32], + include_top, + weights, + input_tensor, + input_shape, + pooling, + classes, + classifier_activation, + name=name, + ) + + +@keras_export("keras.applications.densenet.preprocess_input") +def preprocess_input(x, data_format=None): + return imagenet_utils.preprocess_input( + x, data_format=data_format, mode="torch" + ) + + +@keras_export("keras.applications.densenet.decode_predictions") +def decode_predictions(preds, top=5): + return imagenet_utils.decode_predictions(preds, top=top) + + +preprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format( + mode="", + ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TORCH, + error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC, +) +decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__ + +DOC = """ + +Reference: +- [Densely Connected Convolutional Networks]( + https://arxiv.org/abs/1608.06993) (CVPR 2017) + +Optionally loads weights pre-trained on ImageNet. +Note that the data format convention used by the model is +the one specified in your Keras config at `~/.keras/keras.json`. + +Note: each Keras Application expects a specific kind of input preprocessing. +For DenseNet, call `keras.applications.densenet.preprocess_input` +on your inputs before passing them to the model. + +Args: + include_top: whether to include the fully-connected + layer at the top of the network. + weights: one of `None` (random initialization), + `"imagenet"` (pre-training on ImageNet), + or the path to the weights file to be loaded. + input_tensor: optional Keras tensor + (i.e. output of `layers.Input()`) + to use as image input for the model. + input_shape: optional shape tuple, only to be specified + if `include_top` is False (otherwise the input shape + has to be `(224, 224, 3)` (with `'channels_last'` data format) + or `(3, 224, 224)` (with `'channels_first'` data format). + It should have exactly 3 inputs channels, + and width and height should be no smaller than 32. + E.g. `(200, 200, 3)` would be one valid value. + pooling: Optional pooling mode for feature extraction + when `include_top` is `False`. + - `None` means that the output of the model will be + the 4D tensor output of the + last convolutional block. + - `avg` means that global average pooling + will be applied to the output of the + last convolutional block, and thus + the output of the model will be a 2D tensor. + - `max` means that global max pooling will + be applied. + classes: optional number of classes to classify images + into, only to be specified if `include_top` is `True`, and + if no `weights` argument is specified. Defaults to 1000. + classifier_activation: A `str` or callable. + The activation function to use + on the "top" layer. Ignored unless `include_top=True`. Set + `classifier_activation=None` to return the logits + of the "top" layer. When loading pretrained weights, + `classifier_activation` can only be `None` or `"softmax"`. + name: The name of the model (string). + +Returns: + A Keras model instance. +""" + +setattr(DenseNet121, "__doc__", DenseNet121.__doc__ + DOC) +setattr(DenseNet169, "__doc__", DenseNet169.__doc__ + DOC) +setattr(DenseNet201, "__doc__", DenseNet201.__doc__ + DOC) diff --git a/keras/src/applications/efficientnet.py b/keras/src/applications/efficientnet.py new file mode 100644 index 000000000000..44dcad9bc8c2 --- /dev/null +++ b/keras/src/applications/efficientnet.py @@ -0,0 +1,856 @@ +import copy +import math + +from keras.src import backend +from keras.src import layers +from keras.src.api_export import keras_export +from keras.src.applications import imagenet_utils +from keras.src.models import Functional +from keras.src.ops import operation_utils +from keras.src.utils import file_utils + +BASE_WEIGHTS_PATH = "https://storage.googleapis.com/keras-applications/" + +WEIGHTS_HASHES = { + "b0": ( + "902e53a9f72be733fc0bcb005b3ebbac", + "50bc09e76180e00e4465e1a485ddc09d", + ), + "b1": ( + "1d254153d4ab51201f1646940f018540", + "74c4e6b3e1f6a1eea24c589628592432", + ), + "b2": ( + "b15cce36ff4dcbd00b6dd88e7857a6ad", + "111f8e2ac8aa800a7a99e3239f7bfb39", + ), + "b3": ( + "ffd1fdc53d0ce67064dc6a9c7960ede0", + "af6d107764bb5b1abb91932881670226", + ), + "b4": ( + "18c95ad55216b8f92d7e70b3a046e2fc", + "ebc24e6d6c33eaebbd558eafbeedf1ba", + ), + "b5": ( + "ace28f2a6363774853a83a0b21b9421a", + "38879255a25d3c92d5e44e04ae6cec6f", + ), + "b6": ( + "165f6e37dce68623721b423839de8be5", + "9ecce42647a20130c1f39a5d4cb75743", + ), + "b7": ( + "8c03f828fec3ef71311cd463b6759d99", + "cbcfe4450ddf6f3ad90b1b398090fe4a", + ), +} + +DEFAULT_BLOCKS_ARGS = [ + { + "kernel_size": 3, + "repeats": 1, + "filters_in": 32, + "filters_out": 16, + "expand_ratio": 1, + "id_skip": True, + "strides": 1, + "se_ratio": 0.25, + }, + { + "kernel_size": 3, + "repeats": 2, + "filters_in": 16, + "filters_out": 24, + "expand_ratio": 6, + "id_skip": True, + "strides": 2, + "se_ratio": 0.25, + }, + { + "kernel_size": 5, + "repeats": 2, + "filters_in": 24, + "filters_out": 40, + "expand_ratio": 6, + "id_skip": True, + "strides": 2, + "se_ratio": 0.25, + }, + { + "kernel_size": 3, + "repeats": 3, + "filters_in": 40, + "filters_out": 80, + "expand_ratio": 6, + "id_skip": True, + "strides": 2, + "se_ratio": 0.25, + }, + { + "kernel_size": 5, + "repeats": 3, + "filters_in": 80, + "filters_out": 112, + "expand_ratio": 6, + "id_skip": True, + "strides": 1, + "se_ratio": 0.25, + }, + { + "kernel_size": 5, + "repeats": 4, + "filters_in": 112, + "filters_out": 192, + "expand_ratio": 6, + "id_skip": True, + "strides": 2, + "se_ratio": 0.25, + }, + { + "kernel_size": 3, + "repeats": 1, + "filters_in": 192, + "filters_out": 320, + "expand_ratio": 6, + "id_skip": True, + "strides": 1, + "se_ratio": 0.25, + }, +] + +CONV_KERNEL_INITIALIZER = { + "class_name": "VarianceScaling", + "config": { + "scale": 2.0, + "mode": "fan_out", + "distribution": "truncated_normal", + }, +} + +DENSE_KERNEL_INITIALIZER = { + "class_name": "VarianceScaling", + "config": { + "scale": 1.0 / 3.0, + "mode": "fan_out", + "distribution": "uniform", + }, +} + +BASE_DOCSTRING = """Instantiates the {name} architecture. + +Reference: +- [EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks]( + https://arxiv.org/abs/1905.11946) (ICML 2019) + +This function returns a Keras image classification model, +optionally loaded with weights pre-trained on ImageNet. + +For image classification use cases, see +[this page for detailed examples]( +https://keras.io/api/applications/#usage-examples-for-image-classification-models). + +For transfer learning use cases, make sure to read the +[guide to transfer learning & fine-tuning]( +https://keras.io/guides/transfer_learning/). + +Note: each Keras Application expects a specific kind of input preprocessing. +For EfficientNet, input preprocessing is included as part of the model +(as a `Rescaling` layer), and thus +`keras.applications.efficientnet.preprocess_input` is actually a +pass-through function. EfficientNet models expect their inputs to be float +tensors of pixels with values in the `[0-255]` range. + +Args: + include_top: Whether to include the fully-connected + layer at the top of the network. Defaults to `True`. + weights: One of `None` (random initialization), + `"imagenet"` (pre-training on ImageNet), + or the path to the weights file to be loaded. + Defaults to `"imagenet"`. + input_tensor: Optional Keras tensor + (i.e. output of `layers.Input()`) + to use as image input for the model. + input_shape: Optional shape tuple, only to be specified + if `include_top` is False. + It should have exactly 3 inputs channels. + pooling: Optional pooling mode for feature extraction + when `include_top` is `False`. Defaults to `None`. + - `None` means that the output of the model will be + the 4D tensor output of the + last convolutional layer. + - `avg` means that global average pooling + will be applied to the output of the + last convolutional layer, and thus + the output of the model will be a 2D tensor. + - `max` means that global max pooling will + be applied. + classes: Optional number of classes to classify images + into, only to be specified if `include_top` is True, and + if no `weights` argument is specified. 1000 is how many + ImageNet classes there are. Defaults to `1000`. + classifier_activation: A `str` or callable. The activation function to use + on the "top" layer. Ignored unless `include_top=True`. Set + `classifier_activation=None` to return the logits of the "top" layer. + Defaults to `'softmax'`. + When loading pretrained weights, `classifier_activation` can only + be `None` or `"softmax"`. + name: The name of the model (string). + +Returns: + A model instance. +""" + + +IMAGENET_STDDEV_RGB = [0.229, 0.224, 0.225] + + +def EfficientNet( + width_coefficient, + depth_coefficient, + default_size, + dropout_rate=0.2, + drop_connect_rate=0.2, + depth_divisor=8, + activation="swish", + blocks_args="default", + name="efficientnet", + include_top=True, + weights="imagenet", + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000, + classifier_activation="softmax", + weights_name=None, +): + """Instantiates the EfficientNet architecture. + + Args: + width_coefficient: float, scaling coefficient for network width. + depth_coefficient: float, scaling coefficient for network depth. + default_size: integer, default input image size. + dropout_rate: float, dropout rate before final classifier layer. + drop_connect_rate: float, dropout rate at skip connections. + depth_divisor: integer, a unit of network width. + activation: activation function. + blocks_args: list of dicts, parameters to construct block modules. + name: string, model name. + include_top: whether to include the fully-connected + layer at the top of the network. + weights: one of `None` (random initialization), + 'imagenet' (pre-training on ImageNet), + or the path to the weights file to be loaded. + input_tensor: optional Keras tensor + (i.e. output of `layers.Input()`) + to use as image input for the model. + input_shape: optional shape tuple, only to be specified + if `include_top` is False. + It should have exactly 3 inputs channels. + pooling: optional pooling mode for feature extraction + when `include_top` is `False`. + - `None` means that the output of the model will be + the 4D tensor output of the + last convolutional layer. + - `avg` means that global average pooling + will be applied to the output of the + last convolutional layer, and thus + the output of the model will be a 2D tensor. + - `max` means that global max pooling will + be applied. + classes: optional number of classes to classify images + into, only to be specified if `include_top` is True, and + if no `weights` argument is specified. + classifier_activation: A `str` or callable. The activation function to use + on the "top" layer. Ignored unless `include_top=True`. Set + `classifier_activation=None` to return the logits of the "top" layer. + + Returns: + A model instance. + """ + if blocks_args == "default": + blocks_args = DEFAULT_BLOCKS_ARGS + + if not (weights in {"imagenet", None} or file_utils.exists(weights)): + raise ValueError( + "The `weights` argument should be either " + "`None` (random initialization), `imagenet` " + "(pre-training on ImageNet), " + "or the path to the weights file to be loaded." + ) + + if weights == "imagenet" and include_top and classes != 1000: + raise ValueError( + 'If using `weights="imagenet"` with `include_top`' + " as true, `classes` should be 1000" + ) + + # Determine proper input shape + input_shape = imagenet_utils.obtain_input_shape( + input_shape, + default_size=default_size, + min_size=32, + data_format=backend.image_data_format(), + require_flatten=include_top, + weights=weights, + ) + + if input_tensor is None: + img_input = layers.Input(shape=input_shape) + else: + if not backend.is_keras_tensor(input_tensor): + img_input = layers.Input(tensor=input_tensor, shape=input_shape) + else: + img_input = input_tensor + + bn_axis = 3 if backend.image_data_format() == "channels_last" else 1 + + def round_filters(filters, divisor=depth_divisor): + """Round number of filters based on depth multiplier.""" + filters *= width_coefficient + new_filters = max( + divisor, int(filters + divisor / 2) // divisor * divisor + ) + # Make sure that round down does not go down by more than 10%. + if new_filters < 0.9 * filters: + new_filters += divisor + return int(new_filters) + + def round_repeats(repeats): + """Round number of repeats based on depth multiplier.""" + return int(math.ceil(depth_coefficient * repeats)) + + # Build stem + x = img_input + x = layers.Rescaling(1.0 / 255.0)(x) + x = layers.Normalization(axis=bn_axis)(x) + + if weights == "imagenet": + # Note that the normalization layer uses square value of STDDEV as the + # variance for the layer: result = (input - mean) / sqrt(var) + # However, the original implementation uses (input - mean) / var to + # normalize the input, we need to divide another sqrt(var) to match the + # original implementation. + # See https://github.com/tensorflow/tensorflow/issues/49930 for more + # details + x = layers.Rescaling( + [1.0 / math.sqrt(stddev) for stddev in IMAGENET_STDDEV_RGB] + )(x) + + x = layers.ZeroPadding2D( + padding=imagenet_utils.correct_pad(x, 3), name="stem_conv_pad" + )(x) + x = layers.Conv2D( + round_filters(32), + 3, + strides=2, + padding="valid", + use_bias=False, + kernel_initializer=CONV_KERNEL_INITIALIZER, + name="stem_conv", + )(x) + x = layers.BatchNormalization(axis=bn_axis, name="stem_bn")(x) + x = layers.Activation(activation, name="stem_activation")(x) + + # Build blocks + blocks_args = copy.deepcopy(blocks_args) + + b = 0 + blocks = float(sum(round_repeats(args["repeats"]) for args in blocks_args)) + for i, args in enumerate(blocks_args): + assert args["repeats"] > 0 + # Update block input and output filters based on depth multiplier. + args["filters_in"] = round_filters(args["filters_in"]) + args["filters_out"] = round_filters(args["filters_out"]) + + for j in range(round_repeats(args.pop("repeats"))): + # The first block needs to take care of stride and filter size + # increase. + if j > 0: + args["strides"] = 1 + args["filters_in"] = args["filters_out"] + x = block( + x, + activation, + drop_connect_rate * b / blocks, + name=f"block{i + 1}{chr(j + 97)}_", + **args, + ) + b += 1 + + # Build top + x = layers.Conv2D( + round_filters(1280), + 1, + padding="same", + use_bias=False, + kernel_initializer=CONV_KERNEL_INITIALIZER, + name="top_conv", + )(x) + x = layers.BatchNormalization(axis=bn_axis, name="top_bn")(x) + x = layers.Activation(activation, name="top_activation")(x) + if include_top: + x = layers.GlobalAveragePooling2D(name="avg_pool")(x) + if dropout_rate > 0: + x = layers.Dropout(dropout_rate, name="top_dropout")(x) + imagenet_utils.validate_activation(classifier_activation, weights) + x = layers.Dense( + classes, + activation=classifier_activation, + kernel_initializer=DENSE_KERNEL_INITIALIZER, + name="predictions", + )(x) + else: + if pooling == "avg": + x = layers.GlobalAveragePooling2D(name="avg_pool")(x) + elif pooling == "max": + x = layers.GlobalMaxPooling2D(name="max_pool")(x) + + # Ensure that the model takes into account + # any potential predecessors of `input_tensor`. + if input_tensor is not None: + inputs = operation_utils.get_source_inputs(input_tensor) + else: + inputs = img_input + + # Create model. + model = Functional(inputs, x, name=name) + + # Load weights. + if weights == "imagenet": + if include_top: + file_suffix = ".h5" + file_hash = WEIGHTS_HASHES[weights_name][0] + else: + file_suffix = "_notop.h5" + file_hash = WEIGHTS_HASHES[weights_name][1] + file_name = name + file_suffix + weights_path = file_utils.get_file( + file_name, + BASE_WEIGHTS_PATH + file_name, + cache_subdir="models", + file_hash=file_hash, + ) + model.load_weights(weights_path) + elif weights is not None: + model.load_weights(weights) + return model + + +def block( + inputs, + activation="swish", + drop_rate=0.0, + name="", + filters_in=32, + filters_out=16, + kernel_size=3, + strides=1, + expand_ratio=1, + se_ratio=0.0, + id_skip=True, +): + """An inverted residual block. + + Args: + inputs: input tensor. + activation: activation function. + drop_rate: float between 0 and 1, fraction of the input units to drop. + name: string, block label. + filters_in: integer, the number of input filters. + filters_out: integer, the number of output filters. + kernel_size: integer, the dimension of the convolution window. + strides: integer, the stride of the convolution. + expand_ratio: integer, scaling coefficient for the input filters. + se_ratio: float between 0 and 1, fraction to squeeze the input filters. + id_skip: boolean. + + Returns: + output tensor for the block. + """ + bn_axis = 3 if backend.image_data_format() == "channels_last" else 1 + + # Expansion phase + filters = filters_in * expand_ratio + if expand_ratio != 1: + x = layers.Conv2D( + filters, + 1, + padding="same", + use_bias=False, + kernel_initializer=CONV_KERNEL_INITIALIZER, + name=f"{name}expand_conv", + )(inputs) + x = layers.BatchNormalization(axis=bn_axis, name=f"{name}expand_bn")(x) + x = layers.Activation(activation, name=f"{name}expand_activation")(x) + else: + x = inputs + + # Depthwise Convolution + if strides == 2: + x = layers.ZeroPadding2D( + padding=imagenet_utils.correct_pad(x, kernel_size), + name=f"{name}dwconv_pad", + )(x) + conv_pad = "valid" + else: + conv_pad = "same" + x = layers.DepthwiseConv2D( + kernel_size, + strides=strides, + padding=conv_pad, + use_bias=False, + depthwise_initializer=CONV_KERNEL_INITIALIZER, + name=f"{name}dwconv", + )(x) + x = layers.BatchNormalization(axis=bn_axis, name=f"{name}bn")(x) + x = layers.Activation(activation, name=f"{name}activation")(x) + + # Squeeze and Excitation phase + if 0 < se_ratio <= 1: + filters_se = max(1, int(filters_in * se_ratio)) + se = layers.GlobalAveragePooling2D(name=f"{name}se_squeeze")(x) + if bn_axis == 1: + se_shape = (filters, 1, 1) + else: + se_shape = (1, 1, filters) + se = layers.Reshape(se_shape, name=f"{name}se_reshape")(se) + se = layers.Conv2D( + filters_se, + 1, + padding="same", + activation=activation, + kernel_initializer=CONV_KERNEL_INITIALIZER, + name=f"{name}se_reduce", + )(se) + se = layers.Conv2D( + filters, + 1, + padding="same", + activation="sigmoid", + kernel_initializer=CONV_KERNEL_INITIALIZER, + name=f"{name}se_expand", + )(se) + x = layers.multiply([x, se], name=f"{name}se_excite") + + # Output phase + x = layers.Conv2D( + filters_out, + 1, + padding="same", + use_bias=False, + kernel_initializer=CONV_KERNEL_INITIALIZER, + name=f"{name}project_conv", + )(x) + x = layers.BatchNormalization(axis=bn_axis, name=f"{name}project_bn")(x) + if id_skip and strides == 1 and filters_in == filters_out: + if drop_rate > 0: + x = layers.Dropout( + drop_rate, noise_shape=(None, 1, 1, 1), name=f"{name}drop" + )(x) + x = layers.add([x, inputs], name=f"{name}add") + return x + + +@keras_export( + [ + "keras.applications.efficientnet.EfficientNetB0", + "keras.applications.EfficientNetB0", + ] +) +def EfficientNetB0( + include_top=True, + weights="imagenet", + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000, + classifier_activation="softmax", + name="efficientnetb0", +): + return EfficientNet( + 1.0, + 1.0, + 224, + 0.2, + name=name, + include_top=include_top, + weights=weights, + input_tensor=input_tensor, + input_shape=input_shape, + pooling=pooling, + classes=classes, + classifier_activation=classifier_activation, + weights_name="b0", + ) + + +@keras_export( + [ + "keras.applications.efficientnet.EfficientNetB1", + "keras.applications.EfficientNetB1", + ] +) +def EfficientNetB1( + include_top=True, + weights="imagenet", + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000, + classifier_activation="softmax", + name="efficientnetb1", +): + return EfficientNet( + 1.0, + 1.1, + 240, + 0.2, + name=name, + include_top=include_top, + weights=weights, + input_tensor=input_tensor, + input_shape=input_shape, + pooling=pooling, + classes=classes, + classifier_activation=classifier_activation, + weights_name="b1", + ) + + +@keras_export( + [ + "keras.applications.efficientnet.EfficientNetB2", + "keras.applications.EfficientNetB2", + ] +) +def EfficientNetB2( + include_top=True, + weights="imagenet", + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000, + classifier_activation="softmax", + name="efficientnetb2", +): + return EfficientNet( + 1.1, + 1.2, + 260, + 0.3, + name=name, + include_top=include_top, + weights=weights, + input_tensor=input_tensor, + input_shape=input_shape, + pooling=pooling, + classes=classes, + classifier_activation=classifier_activation, + weights_name="b2", + ) + + +@keras_export( + [ + "keras.applications.efficientnet.EfficientNetB3", + "keras.applications.EfficientNetB3", + ] +) +def EfficientNetB3( + include_top=True, + weights="imagenet", + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000, + classifier_activation="softmax", + name="efficientnetb3", +): + return EfficientNet( + 1.2, + 1.4, + 300, + 0.3, + name=name, + include_top=include_top, + weights=weights, + input_tensor=input_tensor, + input_shape=input_shape, + pooling=pooling, + classes=classes, + classifier_activation=classifier_activation, + weights_name="b3", + ) + + +@keras_export( + [ + "keras.applications.efficientnet.EfficientNetB4", + "keras.applications.EfficientNetB4", + ] +) +def EfficientNetB4( + include_top=True, + weights="imagenet", + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000, + classifier_activation="softmax", + name="efficientnetb4", +): + return EfficientNet( + 1.4, + 1.8, + 380, + 0.4, + name=name, + include_top=include_top, + weights=weights, + input_tensor=input_tensor, + input_shape=input_shape, + pooling=pooling, + classes=classes, + classifier_activation=classifier_activation, + weights_name="b4", + ) + + +@keras_export( + [ + "keras.applications.efficientnet.EfficientNetB5", + "keras.applications.EfficientNetB5", + ] +) +def EfficientNetB5( + include_top=True, + weights="imagenet", + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000, + classifier_activation="softmax", + name="efficientnetb5", +): + return EfficientNet( + 1.6, + 2.2, + 456, + 0.4, + name=name, + include_top=include_top, + weights=weights, + input_tensor=input_tensor, + input_shape=input_shape, + pooling=pooling, + classes=classes, + classifier_activation=classifier_activation, + weights_name="b5", + ) + + +@keras_export( + [ + "keras.applications.efficientnet.EfficientNetB6", + "keras.applications.EfficientNetB6", + ] +) +def EfficientNetB6( + include_top=True, + weights="imagenet", + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000, + classifier_activation="softmax", + name="efficientnetb6", +): + return EfficientNet( + 1.8, + 2.6, + 528, + 0.5, + name=name, + include_top=include_top, + weights=weights, + input_tensor=input_tensor, + input_shape=input_shape, + pooling=pooling, + classes=classes, + classifier_activation=classifier_activation, + weights_name="b6", + ) + + +@keras_export( + [ + "keras.applications.efficientnet.EfficientNetB7", + "keras.applications.EfficientNetB7", + ] +) +def EfficientNetB7( + include_top=True, + weights="imagenet", + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000, + classifier_activation="softmax", + name="efficientnetb7", +): + return EfficientNet( + 2.0, + 3.1, + 600, + 0.5, + name=name, + include_top=include_top, + weights=weights, + input_tensor=input_tensor, + input_shape=input_shape, + pooling=pooling, + classes=classes, + classifier_activation=classifier_activation, + weights_name="b7", + ) + + +EfficientNetB0.__doc__ = BASE_DOCSTRING.format(name="EfficientNetB0") +EfficientNetB1.__doc__ = BASE_DOCSTRING.format(name="EfficientNetB1") +EfficientNetB2.__doc__ = BASE_DOCSTRING.format(name="EfficientNetB2") +EfficientNetB3.__doc__ = BASE_DOCSTRING.format(name="EfficientNetB3") +EfficientNetB4.__doc__ = BASE_DOCSTRING.format(name="EfficientNetB4") +EfficientNetB5.__doc__ = BASE_DOCSTRING.format(name="EfficientNetB5") +EfficientNetB6.__doc__ = BASE_DOCSTRING.format(name="EfficientNetB6") +EfficientNetB7.__doc__ = BASE_DOCSTRING.format(name="EfficientNetB7") + + +@keras_export("keras.applications.efficientnet.preprocess_input") +def preprocess_input(x, data_format=None): + """A placeholder method for backward compatibility. + + The preprocessing logic has been included in the efficientnet model + implementation. Users are no longer required to call this method to + normalize the input data. This method does nothing and only kept as a + placeholder to align the API surface between old and new version of model. + + Args: + x: A floating point `numpy.array` or a tensor. + data_format: Optional data format of the image tensor/array. `None` + means the global setting `keras.backend.image_data_format()` + is used (unless you changed it, it uses `"channels_last"`). + Defaults to `None`. + + Returns: + Unchanged `numpy.array` or tensor. + """ + return x + + +@keras_export("keras.applications.efficientnet.decode_predictions") +def decode_predictions(preds, top=5): + return imagenet_utils.decode_predictions(preds, top=top) + + +decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__ diff --git a/keras/src/applications/efficientnet_v2.py b/keras/src/applications/efficientnet_v2.py new file mode 100644 index 000000000000..86e8e2827844 --- /dev/null +++ b/keras/src/applications/efficientnet_v2.py @@ -0,0 +1,1368 @@ +import copy +import math + +from keras.src import backend +from keras.src import initializers +from keras.src import layers +from keras.src.api_export import keras_export +from keras.src.applications import imagenet_utils +from keras.src.models import Functional +from keras.src.ops import operation_utils +from keras.src.utils import file_utils + +BASE_WEIGHTS_PATH = "https://storage.googleapis.com/tensorflow/keras-applications/efficientnet_v2/" # noqa: E501 + +WEIGHTS_HASHES = { + "b0": ( + "21ecbf6da12460d5c40bb2f29ceb2188", + "893217f2bb855e2983157299931e43ff", + ), + "b1": ( + "069f0534ff22adf035c89e2d9547a9dc", + "0e80663031ca32d657f9caa404b6ec37", + ), + "b2": ( + "424e49f28180edbde1e94797771950a7", + "1dfe2e7a5d45b6632553a8961ea609eb", + ), + "b3": ( + "1f1fc43bd98a6e4fd8fdfd551e02c7a0", + "f6abf7b5849ac99a89b50dd3fd532856", + ), + "-s": ( + "e1d88a8495beba45748fedd0cecbe016", + "af0682fb74e8c54910f2d4393339c070", + ), + "-m": ( + "a3bf6aa3276309f4fc6a34aa114c95cd", + "1b8dc055df72dde80d614482840fe342", + ), + "-l": ( + "27e6d408b53c7ebc868fefa357689935", + "b0b66b5c863aef5b46e8608fe1711615", + ), +} + +DEFAULT_BLOCKS_ARGS = { + "efficientnetv2-s": [ + { + "kernel_size": 3, + "num_repeat": 2, + "input_filters": 24, + "output_filters": 24, + "expand_ratio": 1, + "se_ratio": 0.0, + "strides": 1, + "conv_type": 1, + }, + { + "kernel_size": 3, + "num_repeat": 4, + "input_filters": 24, + "output_filters": 48, + "expand_ratio": 4, + "se_ratio": 0.0, + "strides": 2, + "conv_type": 1, + }, + { + "conv_type": 1, + "expand_ratio": 4, + "input_filters": 48, + "kernel_size": 3, + "num_repeat": 4, + "output_filters": 64, + "se_ratio": 0, + "strides": 2, + }, + { + "conv_type": 0, + "expand_ratio": 4, + "input_filters": 64, + "kernel_size": 3, + "num_repeat": 6, + "output_filters": 128, + "se_ratio": 0.25, + "strides": 2, + }, + { + "conv_type": 0, + "expand_ratio": 6, + "input_filters": 128, + "kernel_size": 3, + "num_repeat": 9, + "output_filters": 160, + "se_ratio": 0.25, + "strides": 1, + }, + { + "conv_type": 0, + "expand_ratio": 6, + "input_filters": 160, + "kernel_size": 3, + "num_repeat": 15, + "output_filters": 256, + "se_ratio": 0.25, + "strides": 2, + }, + ], + "efficientnetv2-m": [ + { + "kernel_size": 3, + "num_repeat": 3, + "input_filters": 24, + "output_filters": 24, + "expand_ratio": 1, + "se_ratio": 0, + "strides": 1, + "conv_type": 1, + }, + { + "kernel_size": 3, + "num_repeat": 5, + "input_filters": 24, + "output_filters": 48, + "expand_ratio": 4, + "se_ratio": 0, + "strides": 2, + "conv_type": 1, + }, + { + "kernel_size": 3, + "num_repeat": 5, + "input_filters": 48, + "output_filters": 80, + "expand_ratio": 4, + "se_ratio": 0, + "strides": 2, + "conv_type": 1, + }, + { + "kernel_size": 3, + "num_repeat": 7, + "input_filters": 80, + "output_filters": 160, + "expand_ratio": 4, + "se_ratio": 0.25, + "strides": 2, + "conv_type": 0, + }, + { + "kernel_size": 3, + "num_repeat": 14, + "input_filters": 160, + "output_filters": 176, + "expand_ratio": 6, + "se_ratio": 0.25, + "strides": 1, + "conv_type": 0, + }, + { + "kernel_size": 3, + "num_repeat": 18, + "input_filters": 176, + "output_filters": 304, + "expand_ratio": 6, + "se_ratio": 0.25, + "strides": 2, + "conv_type": 0, + }, + { + "kernel_size": 3, + "num_repeat": 5, + "input_filters": 304, + "output_filters": 512, + "expand_ratio": 6, + "se_ratio": 0.25, + "strides": 1, + "conv_type": 0, + }, + ], + "efficientnetv2-l": [ + { + "kernel_size": 3, + "num_repeat": 4, + "input_filters": 32, + "output_filters": 32, + "expand_ratio": 1, + "se_ratio": 0, + "strides": 1, + "conv_type": 1, + }, + { + "kernel_size": 3, + "num_repeat": 7, + "input_filters": 32, + "output_filters": 64, + "expand_ratio": 4, + "se_ratio": 0, + "strides": 2, + "conv_type": 1, + }, + { + "kernel_size": 3, + "num_repeat": 7, + "input_filters": 64, + "output_filters": 96, + "expand_ratio": 4, + "se_ratio": 0, + "strides": 2, + "conv_type": 1, + }, + { + "kernel_size": 3, + "num_repeat": 10, + "input_filters": 96, + "output_filters": 192, + "expand_ratio": 4, + "se_ratio": 0.25, + "strides": 2, + "conv_type": 0, + }, + { + "kernel_size": 3, + "num_repeat": 19, + "input_filters": 192, + "output_filters": 224, + "expand_ratio": 6, + "se_ratio": 0.25, + "strides": 1, + "conv_type": 0, + }, + { + "kernel_size": 3, + "num_repeat": 25, + "input_filters": 224, + "output_filters": 384, + "expand_ratio": 6, + "se_ratio": 0.25, + "strides": 2, + "conv_type": 0, + }, + { + "kernel_size": 3, + "num_repeat": 7, + "input_filters": 384, + "output_filters": 640, + "expand_ratio": 6, + "se_ratio": 0.25, + "strides": 1, + "conv_type": 0, + }, + ], + "efficientnetv2-b0": [ + { + "kernel_size": 3, + "num_repeat": 1, + "input_filters": 32, + "output_filters": 16, + "expand_ratio": 1, + "se_ratio": 0, + "strides": 1, + "conv_type": 1, + }, + { + "kernel_size": 3, + "num_repeat": 2, + "input_filters": 16, + "output_filters": 32, + "expand_ratio": 4, + "se_ratio": 0, + "strides": 2, + "conv_type": 1, + }, + { + "kernel_size": 3, + "num_repeat": 2, + "input_filters": 32, + "output_filters": 48, + "expand_ratio": 4, + "se_ratio": 0, + "strides": 2, + "conv_type": 1, + }, + { + "kernel_size": 3, + "num_repeat": 3, + "input_filters": 48, + "output_filters": 96, + "expand_ratio": 4, + "se_ratio": 0.25, + "strides": 2, + "conv_type": 0, + }, + { + "kernel_size": 3, + "num_repeat": 5, + "input_filters": 96, + "output_filters": 112, + "expand_ratio": 6, + "se_ratio": 0.25, + "strides": 1, + "conv_type": 0, + }, + { + "kernel_size": 3, + "num_repeat": 8, + "input_filters": 112, + "output_filters": 192, + "expand_ratio": 6, + "se_ratio": 0.25, + "strides": 2, + "conv_type": 0, + }, + ], + "efficientnetv2-b1": [ + { + "kernel_size": 3, + "num_repeat": 1, + "input_filters": 32, + "output_filters": 16, + "expand_ratio": 1, + "se_ratio": 0, + "strides": 1, + "conv_type": 1, + }, + { + "kernel_size": 3, + "num_repeat": 2, + "input_filters": 16, + "output_filters": 32, + "expand_ratio": 4, + "se_ratio": 0, + "strides": 2, + "conv_type": 1, + }, + { + "kernel_size": 3, + "num_repeat": 2, + "input_filters": 32, + "output_filters": 48, + "expand_ratio": 4, + "se_ratio": 0, + "strides": 2, + "conv_type": 1, + }, + { + "kernel_size": 3, + "num_repeat": 3, + "input_filters": 48, + "output_filters": 96, + "expand_ratio": 4, + "se_ratio": 0.25, + "strides": 2, + "conv_type": 0, + }, + { + "kernel_size": 3, + "num_repeat": 5, + "input_filters": 96, + "output_filters": 112, + "expand_ratio": 6, + "se_ratio": 0.25, + "strides": 1, + "conv_type": 0, + }, + { + "kernel_size": 3, + "num_repeat": 8, + "input_filters": 112, + "output_filters": 192, + "expand_ratio": 6, + "se_ratio": 0.25, + "strides": 2, + "conv_type": 0, + }, + ], + "efficientnetv2-b2": [ + { + "kernel_size": 3, + "num_repeat": 1, + "input_filters": 32, + "output_filters": 16, + "expand_ratio": 1, + "se_ratio": 0, + "strides": 1, + "conv_type": 1, + }, + { + "kernel_size": 3, + "num_repeat": 2, + "input_filters": 16, + "output_filters": 32, + "expand_ratio": 4, + "se_ratio": 0, + "strides": 2, + "conv_type": 1, + }, + { + "kernel_size": 3, + "num_repeat": 2, + "input_filters": 32, + "output_filters": 48, + "expand_ratio": 4, + "se_ratio": 0, + "strides": 2, + "conv_type": 1, + }, + { + "kernel_size": 3, + "num_repeat": 3, + "input_filters": 48, + "output_filters": 96, + "expand_ratio": 4, + "se_ratio": 0.25, + "strides": 2, + "conv_type": 0, + }, + { + "kernel_size": 3, + "num_repeat": 5, + "input_filters": 96, + "output_filters": 112, + "expand_ratio": 6, + "se_ratio": 0.25, + "strides": 1, + "conv_type": 0, + }, + { + "kernel_size": 3, + "num_repeat": 8, + "input_filters": 112, + "output_filters": 192, + "expand_ratio": 6, + "se_ratio": 0.25, + "strides": 2, + "conv_type": 0, + }, + ], + "efficientnetv2-b3": [ + { + "kernel_size": 3, + "num_repeat": 1, + "input_filters": 32, + "output_filters": 16, + "expand_ratio": 1, + "se_ratio": 0, + "strides": 1, + "conv_type": 1, + }, + { + "kernel_size": 3, + "num_repeat": 2, + "input_filters": 16, + "output_filters": 32, + "expand_ratio": 4, + "se_ratio": 0, + "strides": 2, + "conv_type": 1, + }, + { + "kernel_size": 3, + "num_repeat": 2, + "input_filters": 32, + "output_filters": 48, + "expand_ratio": 4, + "se_ratio": 0, + "strides": 2, + "conv_type": 1, + }, + { + "kernel_size": 3, + "num_repeat": 3, + "input_filters": 48, + "output_filters": 96, + "expand_ratio": 4, + "se_ratio": 0.25, + "strides": 2, + "conv_type": 0, + }, + { + "kernel_size": 3, + "num_repeat": 5, + "input_filters": 96, + "output_filters": 112, + "expand_ratio": 6, + "se_ratio": 0.25, + "strides": 1, + "conv_type": 0, + }, + { + "kernel_size": 3, + "num_repeat": 8, + "input_filters": 112, + "output_filters": 192, + "expand_ratio": 6, + "se_ratio": 0.25, + "strides": 2, + "conv_type": 0, + }, + ], +} + +CONV_KERNEL_INITIALIZER = { + "class_name": "VarianceScaling", + "config": { + "scale": 2.0, + "mode": "fan_out", + "distribution": "truncated_normal", + }, +} + +DENSE_KERNEL_INITIALIZER = { + "class_name": "VarianceScaling", + "config": { + "scale": 1.0 / 3.0, + "mode": "fan_out", + "distribution": "uniform", + }, +} + +BASE_DOCSTRING = """Instantiates the {name} architecture. + +Reference: +- [EfficientNetV2: Smaller Models and Faster Training]( + https://arxiv.org/abs/2104.00298) (ICML 2021) + +This function returns a Keras image classification model, +optionally loaded with weights pre-trained on ImageNet. + +For image classification use cases, see +[this page for detailed examples]( +https://keras.io/api/applications/#usage-examples-for-image-classification-models). + +For transfer learning use cases, make sure to read the +[guide to transfer learning & fine-tuning]( +https://keras.io/guides/transfer_learning/). + +Note: each Keras Application expects a specific kind of input preprocessing. +For EfficientNetV2, by default input preprocessing is included as a part of +the model (as a `Rescaling` layer), and thus +`keras.applications.efficientnet_v2.preprocess_input` is actually a +pass-through function. In this use case, EfficientNetV2 models expect their +inputs to be float tensors of pixels with values in the `[0, 255]` range. +At the same time, preprocessing as a part of the model (i.e. `Rescaling` +layer) can be disabled by setting `include_preprocessing` argument to `False`. +With preprocessing disabled EfficientNetV2 models expect their inputs to be +float tensors of pixels with values in the `[-1, 1]` range. + +Args: + include_top: Boolean, whether to include the fully-connected + layer at the top of the network. Defaults to `True`. + weights: One of `None` (random initialization), + `"imagenet"` (pre-training on ImageNet), + or the path to the weights file to be loaded. Defaults to `"imagenet"`. + input_tensor: Optional Keras tensor + (i.e. output of `layers.Input()`) + to use as image input for the model. + input_shape: Optional shape tuple, only to be specified + if `include_top` is `False`. + It should have exactly 3 inputs channels. + pooling: Optional pooling mode for feature extraction + when `include_top` is `False`. Defaults to None. + - `None` means that the output of the model will be + the 4D tensor output of the + last convolutional layer. + - `"avg"` means that global average pooling + will be applied to the output of the + last convolutional layer, and thus + the output of the model will be a 2D tensor. + - `"max"` means that global max pooling will + be applied. + classes: Optional number of classes to classify images + into, only to be specified if `include_top` is `True`, and + if no `weights` argument is specified. Defaults to 1000 (number of + ImageNet classes). + classifier_activation: A string or callable. The activation function to use + on the "top" layer. Ignored unless `include_top=True`. Set + `classifier_activation=None` to return the logits of the "top" layer. + Defaults to `"softmax"`. + When loading pretrained weights, `classifier_activation` can only + be `None` or `"softmax"`. + name: The name of the model (string). + +Returns: + A model instance. +""" + + +def round_filters(filters, width_coefficient, min_depth, depth_divisor): + """Round number of filters based on depth multiplier.""" + filters *= width_coefficient + minimum_depth = min_depth or depth_divisor + new_filters = max( + minimum_depth, + int(filters + depth_divisor / 2) // depth_divisor * depth_divisor, + ) + return int(new_filters) + + +def round_repeats(repeats, depth_coefficient): + """Round number of repeats based on depth multiplier.""" + return int(math.ceil(depth_coefficient * repeats)) + + +def MBConvBlock( + input_filters, + output_filters, + expand_ratio=1, + kernel_size=3, + strides=1, + se_ratio=0.0, + bn_momentum=0.9, + activation="swish", + survival_probability=0.8, + name=None, +): + """MBConv block: Mobile Inverted Residual Bottleneck.""" + bn_axis = 3 if backend.image_data_format() == "channels_last" else 1 + + if name is None: + name = backend.get_uid("block0") + + def apply(inputs): + # Expansion phase + filters = input_filters * expand_ratio + if expand_ratio != 1: + x = layers.Conv2D( + filters=filters, + kernel_size=1, + strides=1, + kernel_initializer=CONV_KERNEL_INITIALIZER, + padding="same", + data_format=backend.image_data_format(), + use_bias=False, + name=f"{name}expand_conv", + )(inputs) + x = layers.BatchNormalization( + axis=bn_axis, + momentum=bn_momentum, + name=f"{name}expand_bn", + )(x) + x = layers.Activation(activation, name=f"{name}expand_activation")( + x + ) + else: + x = inputs + + # Depthwise conv + x = layers.DepthwiseConv2D( + kernel_size=kernel_size, + strides=strides, + depthwise_initializer=CONV_KERNEL_INITIALIZER, + padding="same", + data_format=backend.image_data_format(), + use_bias=False, + name=f"{name}dwconv2", + )(x) + x = layers.BatchNormalization( + axis=bn_axis, momentum=bn_momentum, name=f"{name}bn" + )(x) + x = layers.Activation(activation, name=f"{name}activation")(x) + + # Squeeze and excite + if 0 < se_ratio <= 1: + filters_se = max(1, int(input_filters * se_ratio)) + se = layers.GlobalAveragePooling2D(name=f"{name}se_squeeze")(x) + if bn_axis == 1: + se_shape = (filters, 1, 1) + else: + se_shape = (1, 1, filters) + se = layers.Reshape(se_shape, name=f"{name}se_reshape")(se) + + se = layers.Conv2D( + filters_se, + 1, + padding="same", + activation=activation, + kernel_initializer=CONV_KERNEL_INITIALIZER, + name=f"{name}se_reduce", + )(se) + se = layers.Conv2D( + filters, + 1, + padding="same", + activation="sigmoid", + kernel_initializer=CONV_KERNEL_INITIALIZER, + name=f"{name}se_expand", + )(se) + + x = layers.multiply([x, se], name=f"{name}se_excite") + + # Output phase + x = layers.Conv2D( + filters=output_filters, + kernel_size=1, + strides=1, + kernel_initializer=CONV_KERNEL_INITIALIZER, + padding="same", + data_format=backend.image_data_format(), + use_bias=False, + name=f"{name}project_conv", + )(x) + x = layers.BatchNormalization( + axis=bn_axis, momentum=bn_momentum, name=f"{name}project_bn" + )(x) + + if strides == 1 and input_filters == output_filters: + if survival_probability: + x = layers.Dropout( + survival_probability, + noise_shape=(None, 1, 1, 1), + name=f"{name}drop", + )(x) + x = layers.add([x, inputs], name=f"{name}add") + + return x + + return apply + + +def FusedMBConvBlock( + input_filters, + output_filters, + expand_ratio=1, + kernel_size=3, + strides=1, + se_ratio=0.0, + bn_momentum=0.9, + activation="swish", + survival_probability=0.8, + name=None, +): + """Fuses the proj conv1x1 and depthwise_conv into a conv2d.""" + bn_axis = 3 if backend.image_data_format() == "channels_last" else 1 + + if name is None: + name = backend.get_uid("block0") + + def apply(inputs): + filters = input_filters * expand_ratio + if expand_ratio != 1: + x = layers.Conv2D( + filters, + kernel_size=kernel_size, + strides=strides, + kernel_initializer=CONV_KERNEL_INITIALIZER, + data_format=backend.image_data_format(), + padding="same", + use_bias=False, + name=f"{name}expand_conv", + )(inputs) + x = layers.BatchNormalization( + axis=bn_axis, momentum=bn_momentum, name=f"{name}expand_bn" + )(x) + x = layers.Activation( + activation=activation, name=f"{name}expand_activation" + )(x) + else: + x = inputs + + # Squeeze and excite + if 0 < se_ratio <= 1: + filters_se = max(1, int(input_filters * se_ratio)) + se = layers.GlobalAveragePooling2D(name=f"{name}se_squeeze")(x) + if bn_axis == 1: + se_shape = (filters, 1, 1) + else: + se_shape = (1, 1, filters) + + se = layers.Reshape(se_shape, name=f"{name}se_reshape")(se) + + se = layers.Conv2D( + filters_se, + 1, + padding="same", + activation=activation, + kernel_initializer=CONV_KERNEL_INITIALIZER, + name=f"{name}se_reduce", + )(se) + se = layers.Conv2D( + filters, + 1, + padding="same", + activation="sigmoid", + kernel_initializer=CONV_KERNEL_INITIALIZER, + name=f"{name}se_expand", + )(se) + + x = layers.multiply([x, se], name=f"{name}se_excite") + + # Output phase: + x = layers.Conv2D( + output_filters, + kernel_size=1 if expand_ratio != 1 else kernel_size, + strides=1 if expand_ratio != 1 else strides, + kernel_initializer=CONV_KERNEL_INITIALIZER, + padding="same", + use_bias=False, + name=f"{name}project_conv", + )(x) + x = layers.BatchNormalization( + axis=bn_axis, momentum=bn_momentum, name=f"{name}project_bn" + )(x) + if expand_ratio == 1: + x = layers.Activation( + activation=activation, name=f"{name}project_activation" + )(x) + + # Residual: + if strides == 1 and input_filters == output_filters: + if survival_probability: + x = layers.Dropout( + survival_probability, + noise_shape=(None, 1, 1, 1), + name=f"{name}drop", + )(x) + x = layers.add([x, inputs], name=f"{name}add") + return x + + return apply + + +def EfficientNetV2( + width_coefficient, + depth_coefficient, + default_size, + dropout_rate=0.2, + drop_connect_rate=0.2, + depth_divisor=8, + min_depth=8, + bn_momentum=0.9, + activation="swish", + blocks_args="default", + name="efficientnetv2", + include_top=True, + weights="imagenet", + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000, + classifier_activation="softmax", + include_preprocessing=True, + weights_name=None, +): + """Instantiates the EfficientNetV2 architecture using given scaling + coefficients. + + Args: + width_coefficient: float, scaling coefficient for network width. + depth_coefficient: float, scaling coefficient for network depth. + default_size: integer, default input image size. + dropout_rate: float, dropout rate before final classifier layer. + drop_connect_rate: float, dropout rate at skip connections. + depth_divisor: integer, a unit of network width. + min_depth: integer, minimum number of filters. + bn_momentum: float. Momentum parameter for Batch Normalization layers. + activation: activation function. + blocks_args: list of dicts, parameters to construct block modules. + name: string, model name. + include_top: whether to include the fully-connected layer at the top of + the network. + weights: one of `None` (random initialization), `"imagenet"` + (pre-training on ImageNet), + or the path to the weights file to be loaded. + input_tensor: optional Keras tensor (i.e. output of `layers.Input()`) or + numpy array to use as image input for the model. + input_shape: optional shape tuple, only to be specified if `include_top` + is `False`. It should have exactly 3 inputs channels. + pooling: optional pooling mode for feature extraction when `include_top` + is `False`. + - `None` means that the output of the model will be the + 4D tensor output of the last convolutional layer. + - "avg" means that global average pooling will be applied to + the output of the last convolutional layer, + and thus the output of the model will be a 2D tensor. + - `"max"` means that global max pooling will be applied. + classes: optional number of classes to classify images into, + only to be specified if `include_top` is `True`, and if no `weights` + argument is specified. + classifier_activation: A string or callable. The activation function to + use on the "top" layer. Ignored unless `include_top=True`. Set + `classifier_activation=None` to return the logits of the "top" + layer. + include_preprocessing: Boolean, whether to include the preprocessing + layer (`Rescaling`) at the bottom of the network. + Defaults to `True`. + + Returns: + A model instance. + """ + + if blocks_args == "default": + blocks_args = DEFAULT_BLOCKS_ARGS[name] + + if not (weights in {"imagenet", None} or file_utils.exists(weights)): + raise ValueError( + "The `weights` argument should be either " + "`None` (random initialization), `imagenet` " + "(pre-training on ImageNet), " + "or the path to the weights file to be loaded." + f"Received: weights={weights}" + ) + + if weights == "imagenet" and include_top and classes != 1000: + raise ValueError( + 'If using `weights="imagenet"` with `include_top`' + " as true, `classes` should be 1000" + ) + + # Determine proper input shape + input_shape = imagenet_utils.obtain_input_shape( + input_shape, + default_size=default_size, + min_size=32, + data_format=backend.image_data_format(), + require_flatten=include_top, + weights=weights, + ) + + if input_tensor is None: + img_input = layers.Input(shape=input_shape) + else: + if not backend.is_keras_tensor(input_tensor): + img_input = layers.Input(tensor=input_tensor, shape=input_shape) + else: + img_input = input_tensor + + bn_axis = 3 if backend.image_data_format() == "channels_last" else 1 + + x = img_input + + if include_preprocessing: + # Apply original V1 preprocessing for Bx variants + # if number of channels allows it + num_channels = input_shape[bn_axis - 1] + if name.split("-")[-1].startswith("b") and num_channels == 3: + x = layers.Rescaling(scale=1.0 / 255)(x) + if backend.image_data_format() == "channels_first": + mean = [[[[0.485]], [[0.456]], [[0.406]]]] # shape [1,3,1,1] + variance = [ + [[[0.229**2]], [[0.224**2]], [[0.225**2]]] + ] # shape [1,3,1,1] + else: + mean = [0.485, 0.456, 0.406] + variance = [0.229**2, 0.224**2, 0.225**2] + x = layers.Normalization( + mean=mean, + variance=variance, + axis=bn_axis, + )(x) + else: + x = layers.Rescaling(scale=1.0 / 128.0, offset=-1)(x) + + # Build stem + stem_filters = round_filters( + filters=blocks_args[0]["input_filters"], + width_coefficient=width_coefficient, + min_depth=min_depth, + depth_divisor=depth_divisor, + ) + x = layers.Conv2D( + filters=stem_filters, + kernel_size=3, + strides=2, + kernel_initializer=CONV_KERNEL_INITIALIZER, + padding="same", + use_bias=False, + name="stem_conv", + )(x) + x = layers.BatchNormalization( + axis=bn_axis, + momentum=bn_momentum, + name="stem_bn", + )(x) + x = layers.Activation(activation, name="stem_activation")(x) + + # Build blocks + blocks_args = copy.deepcopy(blocks_args) + b = 0 + blocks = float(sum(args["num_repeat"] for args in blocks_args)) + + for i, args in enumerate(blocks_args): + assert args["num_repeat"] > 0 + + # Update block input and output filters based on depth multiplier. + args["input_filters"] = round_filters( + filters=args["input_filters"], + width_coefficient=width_coefficient, + min_depth=min_depth, + depth_divisor=depth_divisor, + ) + args["output_filters"] = round_filters( + filters=args["output_filters"], + width_coefficient=width_coefficient, + min_depth=min_depth, + depth_divisor=depth_divisor, + ) + + # Determine which conv type to use: + block = {0: MBConvBlock, 1: FusedMBConvBlock}[args.pop("conv_type")] + repeats = round_repeats( + repeats=args.pop("num_repeat"), depth_coefficient=depth_coefficient + ) + for j in range(repeats): + # The first block needs to take care of stride and filter size + # increase. + if j > 0: + args["strides"] = 1 + args["input_filters"] = args["output_filters"] + + x = block( + activation=activation, + bn_momentum=bn_momentum, + survival_probability=drop_connect_rate * b / blocks, + name=f"block{i + 1}{chr(j + 97)}_", + **args, + )(x) + b += 1 + + # Build top + top_filters = round_filters( + filters=1280, + width_coefficient=width_coefficient, + min_depth=min_depth, + depth_divisor=depth_divisor, + ) + x = layers.Conv2D( + filters=top_filters, + kernel_size=1, + strides=1, + kernel_initializer=CONV_KERNEL_INITIALIZER, + padding="same", + data_format=backend.image_data_format(), + use_bias=False, + name="top_conv", + )(x) + x = layers.BatchNormalization( + axis=bn_axis, + momentum=bn_momentum, + name="top_bn", + )(x) + x = layers.Activation(activation=activation, name="top_activation")(x) + + if include_top: + x = layers.GlobalAveragePooling2D(name="avg_pool")(x) + if dropout_rate > 0: + x = layers.Dropout(dropout_rate, name="top_dropout")(x) + imagenet_utils.validate_activation(classifier_activation, weights) + x = layers.Dense( + classes, + activation=classifier_activation, + kernel_initializer=DENSE_KERNEL_INITIALIZER, + bias_initializer=initializers.Constant(0.0), + name="predictions", + )(x) + else: + if pooling == "avg": + x = layers.GlobalAveragePooling2D(name="avg_pool")(x) + elif pooling == "max": + x = layers.GlobalMaxPooling2D(name="max_pool")(x) + + # Ensure that the model takes into account + # any potential predecessors of `input_tensor`. + if input_tensor is not None: + inputs = operation_utils.get_source_inputs(input_tensor) + else: + inputs = img_input + + # Create model. + model = Functional(inputs, x, name=name) + + # Load weights. + if weights == "imagenet": + if include_top: + file_suffix = ".h5" + file_hash = WEIGHTS_HASHES[weights_name][0] + else: + file_suffix = "_notop.h5" + file_hash = WEIGHTS_HASHES[weights_name][1] + file_name = name + file_suffix + weights_path = file_utils.get_file( + file_name, + BASE_WEIGHTS_PATH + file_name, + cache_subdir="models", + file_hash=file_hash, + ) + model.load_weights(weights_path) + elif weights is not None: + model.load_weights(weights) + + return model + + +@keras_export( + [ + "keras.applications.efficientnet_v2.EfficientNetV2B0", + "keras.applications.EfficientNetV2B0", + ] +) +def EfficientNetV2B0( + include_top=True, + weights="imagenet", + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000, + classifier_activation="softmax", + include_preprocessing=True, + name="efficientnetv2-b0", +): + return EfficientNetV2( + width_coefficient=1.0, + depth_coefficient=1.0, + default_size=224, + name=name, + include_top=include_top, + weights=weights, + input_tensor=input_tensor, + input_shape=input_shape, + pooling=pooling, + classes=classes, + classifier_activation=classifier_activation, + include_preprocessing=include_preprocessing, + weights_name="b0", + ) + + +@keras_export( + [ + "keras.applications.efficientnet_v2.EfficientNetV2B1", + "keras.applications.EfficientNetV2B1", + ] +) +def EfficientNetV2B1( + include_top=True, + weights="imagenet", + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000, + classifier_activation="softmax", + include_preprocessing=True, + name="efficientnetv2-b1", +): + return EfficientNetV2( + width_coefficient=1.0, + depth_coefficient=1.1, + default_size=240, + name=name, + include_top=include_top, + weights=weights, + input_tensor=input_tensor, + input_shape=input_shape, + pooling=pooling, + classes=classes, + classifier_activation=classifier_activation, + include_preprocessing=include_preprocessing, + weights_name="b1", + ) + + +@keras_export( + [ + "keras.applications.efficientnet_v2.EfficientNetV2B2", + "keras.applications.EfficientNetV2B2", + ] +) +def EfficientNetV2B2( + include_top=True, + weights="imagenet", + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000, + classifier_activation="softmax", + include_preprocessing=True, + name="efficientnetv2-b2", +): + return EfficientNetV2( + width_coefficient=1.1, + depth_coefficient=1.2, + default_size=260, + name=name, + include_top=include_top, + weights=weights, + input_tensor=input_tensor, + input_shape=input_shape, + pooling=pooling, + classes=classes, + classifier_activation=classifier_activation, + include_preprocessing=include_preprocessing, + weights_name="b2", + ) + + +@keras_export( + [ + "keras.applications.efficientnet_v2.EfficientNetV2B3", + "keras.applications.EfficientNetV2B3", + ] +) +def EfficientNetV2B3( + include_top=True, + weights="imagenet", + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000, + classifier_activation="softmax", + include_preprocessing=True, + name="efficientnetv2-b3", +): + return EfficientNetV2( + width_coefficient=1.2, + depth_coefficient=1.4, + default_size=300, + name=name, + include_top=include_top, + weights=weights, + input_tensor=input_tensor, + input_shape=input_shape, + pooling=pooling, + classes=classes, + classifier_activation=classifier_activation, + include_preprocessing=include_preprocessing, + weights_name="b3", + ) + + +@keras_export( + [ + "keras.applications.efficientnet_v2.EfficientNetV2S", + "keras.applications.EfficientNetV2S", + ] +) +def EfficientNetV2S( + include_top=True, + weights="imagenet", + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000, + classifier_activation="softmax", + include_preprocessing=True, + name="efficientnetv2-s", +): + return EfficientNetV2( + width_coefficient=1.0, + depth_coefficient=1.0, + default_size=384, + name=name, + include_top=include_top, + weights=weights, + input_tensor=input_tensor, + input_shape=input_shape, + pooling=pooling, + classes=classes, + classifier_activation=classifier_activation, + include_preprocessing=include_preprocessing, + weights_name="-s", + ) + + +@keras_export( + [ + "keras.applications.efficientnet_v2.EfficientNetV2M", + "keras.applications.EfficientNetV2M", + ] +) +def EfficientNetV2M( + include_top=True, + weights="imagenet", + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000, + classifier_activation="softmax", + include_preprocessing=True, + name="efficientnetv2-m", +): + return EfficientNetV2( + width_coefficient=1.0, + depth_coefficient=1.0, + default_size=480, + name=name, + include_top=include_top, + weights=weights, + input_tensor=input_tensor, + input_shape=input_shape, + pooling=pooling, + classes=classes, + classifier_activation=classifier_activation, + include_preprocessing=include_preprocessing, + weights_name="-m", + ) + + +@keras_export( + [ + "keras.applications.efficientnet_v2.EfficientNetV2L", + "keras.applications.EfficientNetV2L", + ] +) +def EfficientNetV2L( + include_top=True, + weights="imagenet", + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000, + classifier_activation="softmax", + include_preprocessing=True, + name="efficientnetv2-l", +): + return EfficientNetV2( + width_coefficient=1.0, + depth_coefficient=1.0, + default_size=480, + name=name, + include_top=include_top, + weights=weights, + input_tensor=input_tensor, + input_shape=input_shape, + pooling=pooling, + classes=classes, + classifier_activation=classifier_activation, + include_preprocessing=include_preprocessing, + weights_name="-l", + ) + + +EfficientNetV2B0.__doc__ = BASE_DOCSTRING.format(name="EfficientNetV2B0") +EfficientNetV2B1.__doc__ = BASE_DOCSTRING.format(name="EfficientNetV2B1") +EfficientNetV2B2.__doc__ = BASE_DOCSTRING.format(name="EfficientNetV2B2") +EfficientNetV2B3.__doc__ = BASE_DOCSTRING.format(name="EfficientNetV2B3") +EfficientNetV2S.__doc__ = BASE_DOCSTRING.format(name="EfficientNetV2S") +EfficientNetV2M.__doc__ = BASE_DOCSTRING.format(name="EfficientNetV2M") +EfficientNetV2L.__doc__ = BASE_DOCSTRING.format(name="EfficientNetV2L") + + +@keras_export("keras.applications.efficientnet_v2.preprocess_input") +def preprocess_input(x, data_format=None): + """A placeholder method for backward compatibility. + + The preprocessing logic has been included in the EfficientNetV2 model + implementation. Users are no longer required to call this method to + normalize the input data. This method does nothing and only kept as a + placeholder to align the API surface between old and new version of model. + + Args: + x: A floating point `numpy.array` or a tensor. + data_format: Optional data format of the image tensor/array. Defaults to + None, in which case the global setting + `keras.backend.image_data_format()` is used + (unless you changed it, it defaults to "channels_last").{mode} + + Returns: + Unchanged `numpy.array` or tensor. + """ + return x + + +@keras_export("keras.applications.efficientnet_v2.decode_predictions") +def decode_predictions(preds, top=5): + return imagenet_utils.decode_predictions(preds, top=top) + + +decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__ diff --git a/keras/src/applications/imagenet_utils.py b/keras/src/applications/imagenet_utils.py new file mode 100644 index 000000000000..f88c0af64d88 --- /dev/null +++ b/keras/src/applications/imagenet_utils.py @@ -0,0 +1,462 @@ +import json +import warnings + +import numpy as np + +from keras.src import activations +from keras.src import backend +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.utils import file_utils + +CLASS_INDEX = None +CLASS_INDEX_PATH = ( + "https://storage.googleapis.com/download.tensorflow.org/" + "data/imagenet_class_index.json" +) + + +PREPROCESS_INPUT_DOC = """ + Preprocesses a tensor or Numpy array encoding a batch of images. + + Usage example with `applications.MobileNet`: + + ```python + i = keras.layers.Input([None, None, 3], dtype="uint8") + x = ops.cast(i, "float32") + x = keras.applications.mobilenet.preprocess_input(x) + core = keras.applications.MobileNet() + x = core(x) + model = keras.Model(inputs=[i], outputs=[x]) + result = model(image) + ``` + + Args: + x: A floating point `numpy.array` or a backend-native tensor, + 3D or 4D with 3 color + channels, with values in the range [0, 255]. + The preprocessed data are written over the input data + if the data types are compatible. To avoid this + behaviour, `numpy.copy(x)` can be used. + data_format: Optional data format of the image tensor/array. None, means + the global setting `keras.backend.image_data_format()` is used + (unless you changed it, it uses "channels_last").{mode} + Defaults to `None`. + + Returns: + Preprocessed array with type `float32`. + {ret} + + Raises: + {error} + """ + +PREPROCESS_INPUT_MODE_DOC = """ + mode: One of "caffe", "tf" or "torch". + - caffe: will convert the images from RGB to BGR, + then will zero-center each color channel with + respect to the ImageNet dataset, + without scaling. + - tf: will scale pixels between -1 and 1, + sample-wise. + - torch: will scale pixels between 0 and 1 and then + will normalize each channel with respect to the + ImageNet dataset. + Defaults to `"caffe"`. + """ + +PREPROCESS_INPUT_DEFAULT_ERROR_DOC = """ + ValueError: In case of unknown `mode` or `data_format` argument.""" + +PREPROCESS_INPUT_ERROR_DOC = """ + ValueError: In case of unknown `data_format` argument.""" + +PREPROCESS_INPUT_RET_DOC_TF = """ + The inputs pixel values are scaled between -1 and 1, sample-wise.""" + +PREPROCESS_INPUT_RET_DOC_TORCH = """ + The input pixels values are scaled between 0 and 1 and each channel is + normalized with respect to the ImageNet dataset.""" + +PREPROCESS_INPUT_RET_DOC_CAFFE = """ + The images are converted from RGB to BGR, then each color channel is + zero-centered with respect to the ImageNet dataset, without scaling.""" + + +@keras_export("keras.applications.imagenet_utils.preprocess_input") +def preprocess_input(x, data_format=None, mode="caffe"): + """Preprocesses a tensor or Numpy array encoding a batch of images.""" + if mode not in {"caffe", "tf", "torch"}: + raise ValueError( + "Expected mode to be one of `caffe`, `tf` or `torch`. " + f"Received: mode={mode}" + ) + + if data_format is None: + data_format = backend.image_data_format() + elif data_format not in {"channels_first", "channels_last"}: + raise ValueError( + "Expected data_format to be one of `channels_first` or " + f"`channels_last`. Received: data_format={data_format}" + ) + + if isinstance(x, np.ndarray): + return _preprocess_numpy_input(x, data_format=data_format, mode=mode) + else: + return _preprocess_tensor_input(x, data_format=data_format, mode=mode) + + +preprocess_input.__doc__ = PREPROCESS_INPUT_DOC.format( + mode=PREPROCESS_INPUT_MODE_DOC, + ret="", + error=PREPROCESS_INPUT_DEFAULT_ERROR_DOC, +) + + +@keras_export("keras.applications.imagenet_utils.decode_predictions") +def decode_predictions(preds, top=5): + """Decodes the prediction of an ImageNet model. + + Args: + preds: NumPy array encoding a batch of predictions. + top: Integer, how many top-guesses to return. Defaults to `5`. + + Returns: + A list of lists of top class prediction tuples + `(class_name, class_description, score)`. + One list of tuples per sample in batch input. + + Raises: + ValueError: In case of invalid shape of the `pred` array + (must be 2D). + """ + global CLASS_INDEX + + if len(preds.shape) != 2 or preds.shape[1] != 1000: + raise ValueError( + "`decode_predictions` expects " + "a batch of predictions " + "(i.e. a 2D array of shape (samples, 1000)). " + f"Received array with shape: {preds.shape}" + ) + if CLASS_INDEX is None: + fpath = file_utils.get_file( + "imagenet_class_index.json", + CLASS_INDEX_PATH, + cache_subdir="models", + file_hash="c2c37ea517e94d9795004a39431a14cb", + ) + with open(fpath) as f: + CLASS_INDEX = json.load(f) + results = [] + preds = ops.convert_to_numpy(preds) + for pred in preds: + top_indices = pred.argsort()[-top:][::-1] + result = [tuple(CLASS_INDEX[str(i)]) + (pred[i],) for i in top_indices] + result.sort(key=lambda x: x[2], reverse=True) + results.append(result) + return results + + +def _preprocess_numpy_input(x, data_format, mode): + """Preprocesses a NumPy array encoding a batch of images. + + Args: + x: Input array, 3D or 4D. + data_format: Data format of the image array. + mode: One of "caffe", "tf" or "torch". + - caffe: will convert the images from RGB to BGR, + then will zero-center each color channel with + respect to the ImageNet dataset, + without scaling. + - tf: will scale pixels between -1 and 1, + sample-wise. + - torch: will scale pixels between 0 and 1 and then + will normalize each channel with respect to the + ImageNet dataset. + + Returns: + Preprocessed Numpy array. + """ + if not issubclass(x.dtype.type, np.floating): + x = x.astype(backend.floatx(), copy=False) + + if mode == "tf": + x /= 127.5 + x -= 1.0 + return x + elif mode == "torch": + x /= 255.0 + mean = [0.485, 0.456, 0.406] + std = [0.229, 0.224, 0.225] + else: + if data_format == "channels_first": + # 'RGB'->'BGR' + if len(x.shape) == 3: + x = x[::-1, ...] + else: + x = x[:, ::-1, ...] + else: + # 'RGB'->'BGR' + x = x[..., ::-1] + mean = [103.939, 116.779, 123.68] + std = None + + # Zero-center by mean pixel + if data_format == "channels_first": + if len(x.shape) == 3: + x[0, :, :] -= mean[0] + x[1, :, :] -= mean[1] + x[2, :, :] -= mean[2] + if std is not None: + x[0, :, :] /= std[0] + x[1, :, :] /= std[1] + x[2, :, :] /= std[2] + else: + x[:, 0, :, :] -= mean[0] + x[:, 1, :, :] -= mean[1] + x[:, 2, :, :] -= mean[2] + if std is not None: + x[:, 0, :, :] /= std[0] + x[:, 1, :, :] /= std[1] + x[:, 2, :, :] /= std[2] + else: + x[..., 0] -= mean[0] + x[..., 1] -= mean[1] + x[..., 2] -= mean[2] + if std is not None: + x[..., 0] /= std[0] + x[..., 1] /= std[1] + x[..., 2] /= std[2] + return x + + +def _preprocess_tensor_input(x, data_format, mode): + """Preprocesses a tensor encoding a batch of images. + + Args: + x: Input tensor, 3D or 4D. + data_format: Data format of the image tensor. + mode: One of "caffe", "tf" or "torch". + - caffe: will convert the images from RGB to BGR, + then will zero-center each color channel with + respect to the ImageNet dataset, + without scaling. + - tf: will scale pixels between -1 and 1, + sample-wise. + - torch: will scale pixels between 0 and 1 and then + will normalize each channel with respect to the + ImageNet dataset. + + Returns: + Preprocessed tensor. + """ + ndim = len(x.shape) + + if mode == "tf": + x /= 127.5 + x -= 1.0 + return x + elif mode == "torch": + x /= 255.0 + mean = [0.485, 0.456, 0.406] + std = [0.229, 0.224, 0.225] + else: + if data_format == "channels_first": + # 'RGB'->'BGR' + if len(x.shape) == 3: + x = ops.stack([x[i, ...] for i in (2, 1, 0)], axis=0) + else: + x = ops.stack([x[:, i, :] for i in (2, 1, 0)], axis=1) + else: + # 'RGB'->'BGR' + x = ops.stack([x[..., i] for i in (2, 1, 0)], axis=-1) + mean = [103.939, 116.779, 123.68] + std = None + + mean_tensor = ops.convert_to_tensor(-np.array(mean), dtype=x.dtype) + + # Zero-center by mean pixel + if data_format == "channels_first": + mean_tensor = ops.reshape(mean_tensor, (1, 3) + (1,) * (ndim - 2)) + else: + mean_tensor = ops.reshape(mean_tensor, (1,) * (ndim - 1) + (3,)) + x += mean_tensor + if std is not None: + std_tensor = ops.convert_to_tensor(np.array(std), dtype=x.dtype) + if data_format == "channels_first": + std_tensor = ops.reshape(std_tensor, (-1, 1, 1)) + x /= std_tensor + return x + + +def obtain_input_shape( + input_shape, + default_size, + min_size, + data_format, + require_flatten, + weights=None, +): + """Internal utility to compute/validate a model's input shape. + + Args: + input_shape: Either None (will return the default network input shape), + or a user-provided shape to be validated. + default_size: Default input width/height for the model. + min_size: Minimum input width/height accepted by the model. + data_format: Image data format to use. + require_flatten: Whether the model is expected to + be linked to a classifier via a Flatten layer. + weights: One of `None` (random initialization) + or 'imagenet' (pre-training on ImageNet). + If weights='imagenet' input channels must be equal to 3. + + Returns: + An integer shape tuple (may include None entries). + + Raises: + ValueError: In case of invalid argument values. + """ + if weights != "imagenet" and input_shape and len(input_shape) == 3: + if data_format == "channels_first": + correct_channel_axis = 1 if len(input_shape) == 4 else 0 + if input_shape[correct_channel_axis] not in {1, 3}: + warnings.warn( + "This model usually expects 1 or 3 input channels. " + "However, it was passed an input_shape " + f"with {input_shape[0]} input channels.", + stacklevel=2, + ) + default_shape = (input_shape[0], default_size, default_size) + else: + if input_shape[-1] not in {1, 3}: + warnings.warn( + "This model usually expects 1 or 3 input channels. " + "However, it was passed an input_shape " + f"with {input_shape[-1]} input channels.", + stacklevel=2, + ) + default_shape = (default_size, default_size, input_shape[-1]) + else: + if data_format == "channels_first": + default_shape = (3, default_size, default_size) + else: + default_shape = (default_size, default_size, 3) + if weights == "imagenet" and require_flatten: + if input_shape is not None: + if input_shape != default_shape: + raise ValueError( + "When setting `include_top=True` " + "and loading `imagenet` weights, " + f"`input_shape` should be {default_shape}. " + f"Received: input_shape={input_shape}" + ) + return default_shape + if input_shape: + if data_format == "channels_first": + if input_shape is not None: + if len(input_shape) != 3: + raise ValueError( + "`input_shape` must be a tuple of three integers." + ) + if input_shape[0] != 3 and weights == "imagenet": + raise ValueError( + "The input must have 3 channels; Received " + f"`input_shape={input_shape}`" + ) + if ( + input_shape[1] is not None and input_shape[1] < min_size + ) or (input_shape[2] is not None and input_shape[2] < min_size): + raise ValueError( + f"Input size must be at least {min_size}" + f"x{min_size}; Received: " + f"input_shape={input_shape}" + ) + else: + if input_shape is not None: + if len(input_shape) != 3: + raise ValueError( + "`input_shape` must be a tuple of three integers." + ) + if input_shape[-1] != 3 and weights == "imagenet": + raise ValueError( + "The input must have 3 channels; Received " + f"`input_shape={input_shape}`" + ) + if ( + input_shape[0] is not None and input_shape[0] < min_size + ) or (input_shape[1] is not None and input_shape[1] < min_size): + raise ValueError( + "Input size must be at least " + f"{min_size}x{min_size}; Received: " + f"input_shape={input_shape}" + ) + else: + if require_flatten: + input_shape = default_shape + else: + if data_format == "channels_first": + input_shape = (3, None, None) + else: + input_shape = (None, None, 3) + if require_flatten: + if None in input_shape: + raise ValueError( + "If `include_top` is True, " + "you should specify a static `input_shape`. " + f"Received: input_shape={input_shape}" + ) + return input_shape + + +def correct_pad(inputs, kernel_size): + """Returns a tuple for zero-padding for 2D convolution with downsampling. + + Args: + inputs: Input tensor. + kernel_size: An integer or tuple/list of 2 integers. + + Returns: + A tuple. + """ + img_dim = 2 if backend.image_data_format() == "channels_first" else 1 + input_size = inputs.shape[img_dim : (img_dim + 2)] + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + if input_size[0] is None: + adjust = (1, 1) + else: + adjust = (1 - input_size[0] % 2, 1 - input_size[1] % 2) + correct = (kernel_size[0] // 2, kernel_size[1] // 2) + return ( + (correct[0] - adjust[0], correct[0]), + (correct[1] - adjust[1], correct[1]), + ) + + +def validate_activation(classifier_activation, weights): + """validates that the classifer_activation is compatible with the weights. + + Args: + classifier_activation: str or callable activation function + weights: The pretrained weights to load. + + Raises: + ValueError: if an activation other than `None` or `softmax` are used with + pretrained weights. + """ + if weights is None: + return + + classifier_activation = activations.get(classifier_activation) + if classifier_activation not in { + activations.get("softmax"), + activations.get(None), + }: + raise ValueError( + "Only `None` and `softmax` activations are allowed " + "for the `classifier_activation` argument when using " + "pretrained weights, with `include_top=True`; Received: " + f"classifier_activation={classifier_activation}" + ) diff --git a/keras/src/applications/imagenet_utils_test.py b/keras/src/applications/imagenet_utils_test.py new file mode 100644 index 000000000000..9eb254a56c6a --- /dev/null +++ b/keras/src/applications/imagenet_utils_test.py @@ -0,0 +1,314 @@ +import numpy as np +import pytest +from absl.testing import parameterized + +import keras +from keras.src import backend +from keras.src import testing +from keras.src.applications import imagenet_utils as utils +from keras.src.dtype_policies.dtype_policy import set_dtype_policy + + +class TestImageNetUtils(testing.TestCase): + def test_preprocess_input(self): + # Test invalid mode check + x = np.random.uniform(0, 255, (10, 10, 3)) + with self.assertRaises(ValueError): + utils.preprocess_input(x, mode="some_unknown_mode") + + # Test image batch with float and int image input + x = np.random.uniform(0, 255, (2, 10, 10, 3)) + xint = x.astype("int32") + self.assertEqual(utils.preprocess_input(x).shape, x.shape) + self.assertEqual(utils.preprocess_input(xint).shape, xint.shape) + + out1 = utils.preprocess_input(x, "channels_last") + out1int = utils.preprocess_input(xint, "channels_last") + out2 = utils.preprocess_input( + np.transpose(x, (0, 3, 1, 2)), "channels_first" + ) + out2int = utils.preprocess_input( + np.transpose(xint, (0, 3, 1, 2)), "channels_first" + ) + self.assertAllClose(out1, out2.transpose(0, 2, 3, 1)) + self.assertAllClose(out1int, out2int.transpose(0, 2, 3, 1)) + + # Test single image + x = np.random.uniform(0, 255, (10, 10, 3)) + xint = x.astype("int32") + self.assertEqual(utils.preprocess_input(x).shape, x.shape) + self.assertEqual(utils.preprocess_input(xint).shape, xint.shape) + + out1 = utils.preprocess_input(x, "channels_last") + out1int = utils.preprocess_input(xint, "channels_last") + out2 = utils.preprocess_input( + np.transpose(x, (2, 0, 1)), "channels_first" + ) + out2int = utils.preprocess_input( + np.transpose(xint, (2, 0, 1)), "channels_first" + ) + self.assertAllClose(out1, out2.transpose(1, 2, 0)) + self.assertAllClose(out1int, out2int.transpose(1, 2, 0)) + + # Test that writing over the input data works predictably + for mode in ["torch", "tf"]: + x = np.random.uniform(0, 255, (2, 10, 10, 3)) + xint = x.astype("int") + x2 = utils.preprocess_input(x, "channels_last", mode=mode) + xint2 = utils.preprocess_input(xint, "channels_last") + self.assertAllClose(x, x2) + self.assertNotEqual(xint.astype("float").max(), xint2.max()) + + # Caffe mode works differently from the others + x = np.random.uniform(0, 255, (2, 10, 10, 3)) + xint = x.astype("int") + x2 = utils.preprocess_input( + x, data_format="channels_last", mode="caffe" + ) + xint2 = utils.preprocess_input(xint, data_format="channels_last") + self.assertAllClose(x, x2[..., ::-1]) + self.assertNotEqual(xint.astype("float").max(), xint2.max()) + + @parameterized.named_parameters( + [ + {"testcase_name": "mode_torch", "mode": "torch"}, + {"testcase_name": "mode_tf", "mode": "tf"}, + {"testcase_name": "mode_caffe", "mode": "caffe"}, + ] + ) + @pytest.mark.requires_trainable_backend + def test_preprocess_input_symbolic(self, mode): + backend_data_format = backend.image_data_format() + # Test image batch + if backend_data_format == "channels_last": + x = np.random.uniform(0, 255, (2, 10, 10, 3)) + elif backend_data_format == "channels_first": + x = np.random.uniform(0, 255, (2, 3, 10, 10)) + inputs = keras.layers.Input(shape=x.shape[1:]) + outputs = keras.layers.Lambda( + lambda x: utils.preprocess_input(x, mode=mode), + output_shape=x.shape[1:], + )(inputs) + model = keras.Model(inputs, outputs) + self.assertEqual(model.predict(x).shape, x.shape) + + x = np.random.uniform(0, 255, (2, 10, 10, 3)) + inputs = keras.layers.Input(shape=x.shape[1:]) + outputs1 = keras.layers.Lambda( + lambda x: utils.preprocess_input(x, "channels_last", mode=mode), + output_shape=x.shape[1:], + )(inputs) + model1 = keras.Model(inputs, outputs1) + out1 = model1.predict(x) + x2 = np.transpose(x, (0, 3, 1, 2)) + inputs2 = keras.layers.Input(shape=x2.shape[1:]) + outputs2 = keras.layers.Lambda( + lambda x: utils.preprocess_input(x, "channels_first", mode=mode), + output_shape=x2.shape[1:], + )(inputs2) + model2 = keras.Model(inputs2, outputs2) + out2 = model2.predict(x2) + self.assertAllClose(out1, out2.transpose(0, 2, 3, 1)) + + # Test single image + if backend_data_format == "channels_last": + x = np.random.uniform(0, 255, (10, 10, 3)) + elif backend_data_format == "channels_first": + x = np.random.uniform(0, 255, (3, 10, 10)) + inputs = keras.layers.Input(shape=x.shape) + outputs = keras.layers.Lambda( + lambda x: utils.preprocess_input(x, mode=mode), output_shape=x.shape + )(inputs) + model = keras.Model(inputs, outputs) + self.assertEqual(model.predict(x[np.newaxis])[0].shape, x.shape) + + x = np.random.uniform(0, 255, (10, 10, 3)) + inputs = keras.layers.Input(shape=x.shape) + outputs1 = keras.layers.Lambda( + lambda x: utils.preprocess_input(x, "channels_last", mode=mode), + output_shape=x.shape, + )(inputs) + model1 = keras.Model(inputs, outputs1) + out1 = model1.predict(x[np.newaxis])[0] + x2 = np.transpose(x, (2, 0, 1)) + inputs2 = keras.layers.Input(shape=x2.shape) + outputs2 = keras.layers.Lambda( + lambda x: utils.preprocess_input(x, "channels_first", mode=mode), + output_shape=x2.shape, + )(inputs2) + model2 = keras.Model(inputs2, outputs2) + out2 = model2.predict(x2[np.newaxis])[0] + self.assertAllClose(out1, out2.transpose(1, 2, 0)) + + @parameterized.named_parameters( + [ + {"testcase_name": "mode_torch", "mode": "torch"}, + {"testcase_name": "mode_tf", "mode": "tf"}, + {"testcase_name": "mode_caffe", "mode": "caffe"}, + ] + ) + def test_preprocess_input_symbolic_mixed_precision(self, mode): + set_dtype_policy("mixed_float16") + shape = (20, 20, 3) + inputs = keras.layers.Input(shape=shape) + try: + keras.layers.Lambda( + lambda x: utils.preprocess_input(x, mode=mode), + output_shape=shape, + )(inputs) + finally: + set_dtype_policy("float32") + + @parameterized.named_parameters( + [ + { + "testcase_name": "channels_last_format", + "data_format": "channels_last", + }, + { + "testcase_name": "channels_first_format", + "data_format": "channels_first", + }, + ] + ) + def test_obtain_input_shape(self, data_format): + # input_shape and default_size are not identical. + with self.assertRaises(ValueError): + utils.obtain_input_shape( + input_shape=(224, 224, 3), + default_size=299, + min_size=139, + data_format="channels_last", + require_flatten=True, + weights="imagenet", + ) + + # Test invalid use cases + + shape = (139, 139) + if data_format == "channels_last": + input_shape = shape + (99,) + else: + input_shape = (99,) + shape + + # input_shape is smaller than min_size. + shape = (100, 100) + if data_format == "channels_last": + input_shape = shape + (3,) + else: + input_shape = (3,) + shape + with self.assertRaises(ValueError): + utils.obtain_input_shape( + input_shape=input_shape, + default_size=None, + min_size=139, + data_format=data_format, + require_flatten=False, + ) + + # shape is 1D. + shape = (100,) + if data_format == "channels_last": + input_shape = shape + (3,) + else: + input_shape = (3,) + shape + with self.assertRaises(ValueError): + utils.obtain_input_shape( + input_shape=input_shape, + default_size=None, + min_size=139, + data_format=data_format, + require_flatten=False, + ) + + # the number of channels is 5 not 3. + shape = (100, 100) + if data_format == "channels_last": + input_shape = shape + (5,) + else: + input_shape = (5,) + shape + with self.assertRaises(ValueError): + utils.obtain_input_shape( + input_shape=input_shape, + default_size=None, + min_size=139, + data_format=data_format, + require_flatten=False, + ) + + # require_flatten=True with dynamic input shape. + with self.assertRaises(ValueError): + utils.obtain_input_shape( + input_shape=None, + default_size=None, + min_size=139, + data_format="channels_first", + require_flatten=True, + ) + + # test include top + self.assertEqual( + utils.obtain_input_shape( + input_shape=(3, 200, 200), + default_size=None, + min_size=139, + data_format="channels_first", + require_flatten=True, + ), + (3, 200, 200), + ) + + self.assertEqual( + utils.obtain_input_shape( + input_shape=None, + default_size=None, + min_size=139, + data_format="channels_last", + require_flatten=False, + ), + (None, None, 3), + ) + + self.assertEqual( + utils.obtain_input_shape( + input_shape=None, + default_size=None, + min_size=139, + data_format="channels_first", + require_flatten=False, + ), + (3, None, None), + ) + + self.assertEqual( + utils.obtain_input_shape( + input_shape=None, + default_size=None, + min_size=139, + data_format="channels_last", + require_flatten=False, + ), + (None, None, 3), + ) + + self.assertEqual( + utils.obtain_input_shape( + input_shape=(150, 150, 3), + default_size=None, + min_size=139, + data_format="channels_last", + require_flatten=False, + ), + (150, 150, 3), + ) + + self.assertEqual( + utils.obtain_input_shape( + input_shape=(3, None, None), + default_size=None, + min_size=139, + data_format="channels_first", + require_flatten=False, + ), + (3, None, None), + ) diff --git a/keras/src/applications/inception_resnet_v2.py b/keras/src/applications/inception_resnet_v2.py new file mode 100644 index 000000000000..5289c14f2f87 --- /dev/null +++ b/keras/src/applications/inception_resnet_v2.py @@ -0,0 +1,396 @@ +from keras.src import backend +from keras.src import layers +from keras.src.api_export import keras_export +from keras.src.applications import imagenet_utils +from keras.src.layers.layer import Layer +from keras.src.models import Functional +from keras.src.ops import operation_utils +from keras.src.utils import file_utils + +BASE_WEIGHT_URL = ( + "https://storage.googleapis.com/tensorflow/" + "keras-applications/inception_resnet_v2/" +) + + +@keras_export( + [ + "keras.applications.inception_resnet_v2.InceptionResNetV2", + "keras.applications.InceptionResNetV2", + ] +) +def InceptionResNetV2( + include_top=True, + weights="imagenet", + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000, + classifier_activation="softmax", + name="inception_resnet_v2", +): + """Instantiates the Inception-ResNet v2 architecture. + + Reference: + - [Inception-v4, Inception-ResNet and the Impact of + Residual Connections on Learning](https://arxiv.org/abs/1602.07261) + (AAAI 2017) + + This function returns a Keras image classification model, + optionally loaded with weights pre-trained on ImageNet. + + For image classification use cases, see + [this page for detailed examples]( + https://keras.io/api/applications/#usage-examples-for-image-classification-models). + + For transfer learning use cases, make sure to read the + [guide to transfer learning & fine-tuning]( + https://keras.io/guides/transfer_learning/). + + Note: each Keras Application expects a specific kind of + input preprocessing. For InceptionResNetV2, call + `keras.applications.inception_resnet_v2.preprocess_input` + on your inputs before passing them to the model. + `inception_resnet_v2.preprocess_input` + will scale input pixels between -1 and 1. + + Args: + include_top: whether to include the fully-connected + layer at the top of the network. + weights: one of `None` (random initialization), + `"imagenet"` (pre-training on ImageNet), + or the path to the weights file to be loaded. + input_tensor: optional Keras tensor + (i.e. output of `layers.Input()`) + to use as image input for the model. + input_shape: optional shape tuple, only to be specified + if `include_top` is `False` (otherwise the input shape + has to be `(299, 299, 3)` + (with `'channels_last'` data format) + or `(3, 299, 299)` (with `'channels_first'` data format). + It should have exactly 3 inputs channels, + and width and height should be no smaller than 75. + E.g. `(150, 150, 3)` would be one valid value. + pooling: Optional pooling mode for feature extraction + when `include_top` is `False`. + - `None` means that the output of the model will be + the 4D tensor output of the last convolutional block. + - `'avg'` means that global average pooling + will be applied to the output of the + last convolutional block, and thus + the output of the model will be a 2D tensor. + - `'max'` means that global max pooling will be applied. + classes: optional number of classes to classify images + into, only to be specified if `include_top` is `True`, + and if no `weights` argument is specified. + classifier_activation: A `str` or callable. + The activation function to use on the "top" layer. + Ignored unless `include_top=True`. + Set `classifier_activation=None` to return the logits + of the "top" layer. When loading pretrained weights, + `classifier_activation` can only be `None` or `"softmax"`. + name: The name of the model (string). + + Returns: + A model instance. + """ + if not (weights in {"imagenet", None} or file_utils.exists(weights)): + raise ValueError( + "The `weights` argument should be either " + "`None` (random initialization), `imagenet` " + "(pre-training on ImageNet), " + "or the path to the weights file to be loaded." + ) + + if weights == "imagenet" and include_top and classes != 1000: + raise ValueError( + 'If using `weights="imagenet"` with `include_top=True`, ' + "`classes` should be 1000. " + f"Received classes={classes}" + ) + + # Determine proper input shape + input_shape = imagenet_utils.obtain_input_shape( + input_shape, + default_size=299, + min_size=75, + data_format=backend.image_data_format(), + require_flatten=include_top, + weights=weights, + ) + + if input_tensor is None: + img_input = layers.Input(shape=input_shape) + else: + if not backend.is_keras_tensor(input_tensor): + img_input = layers.Input(tensor=input_tensor, shape=input_shape) + else: + img_input = input_tensor + + # Stem block: 35 x 35 x 192 + x = conv2d_bn(img_input, 32, 3, strides=2, padding="valid") + x = conv2d_bn(x, 32, 3, padding="valid") + x = conv2d_bn(x, 64, 3) + x = layers.MaxPooling2D(3, strides=2)(x) + x = conv2d_bn(x, 80, 1, padding="valid") + x = conv2d_bn(x, 192, 3, padding="valid") + x = layers.MaxPooling2D(3, strides=2)(x) + + # Mixed 5b (Inception-A block): 35 x 35 x 320 + branch_0 = conv2d_bn(x, 96, 1) + branch_1 = conv2d_bn(x, 48, 1) + branch_1 = conv2d_bn(branch_1, 64, 5) + branch_2 = conv2d_bn(x, 64, 1) + branch_2 = conv2d_bn(branch_2, 96, 3) + branch_2 = conv2d_bn(branch_2, 96, 3) + branch_pool = layers.AveragePooling2D(3, strides=1, padding="same")(x) + branch_pool = conv2d_bn(branch_pool, 64, 1) + branches = [branch_0, branch_1, branch_2, branch_pool] + channel_axis = 1 if backend.image_data_format() == "channels_first" else 3 + x = layers.Concatenate(axis=channel_axis, name="mixed_5b")(branches) + + # 10x block35 (Inception-ResNet-A block): 35 x 35 x 320 + for block_idx in range(1, 11): + x = inception_resnet_block( + x, scale=0.17, block_type="block35", block_idx=block_idx + ) + + # Mixed 6a (Reduction-A block): 17 x 17 x 1088 + branch_0 = conv2d_bn(x, 384, 3, strides=2, padding="valid") + branch_1 = conv2d_bn(x, 256, 1) + branch_1 = conv2d_bn(branch_1, 256, 3) + branch_1 = conv2d_bn(branch_1, 384, 3, strides=2, padding="valid") + branch_pool = layers.MaxPooling2D(3, strides=2, padding="valid")(x) + branches = [branch_0, branch_1, branch_pool] + x = layers.Concatenate(axis=channel_axis, name="mixed_6a")(branches) + + # 20x block17 (Inception-ResNet-B block): 17 x 17 x 1088 + for block_idx in range(1, 21): + x = inception_resnet_block( + x, scale=0.1, block_type="block17", block_idx=block_idx + ) + + # Mixed 7a (Reduction-B block): 8 x 8 x 2080 + branch_0 = conv2d_bn(x, 256, 1) + branch_0 = conv2d_bn(branch_0, 384, 3, strides=2, padding="valid") + branch_1 = conv2d_bn(x, 256, 1) + branch_1 = conv2d_bn(branch_1, 288, 3, strides=2, padding="valid") + branch_2 = conv2d_bn(x, 256, 1) + branch_2 = conv2d_bn(branch_2, 288, 3) + branch_2 = conv2d_bn(branch_2, 320, 3, strides=2, padding="valid") + branch_pool = layers.MaxPooling2D(3, strides=2, padding="valid")(x) + branches = [branch_0, branch_1, branch_2, branch_pool] + x = layers.Concatenate(axis=channel_axis, name="mixed_7a")(branches) + + # 10x block8 (Inception-ResNet-C block): 8 x 8 x 2080 + for block_idx in range(1, 10): + x = inception_resnet_block( + x, scale=0.2, block_type="block8", block_idx=block_idx + ) + x = inception_resnet_block( + x, scale=1.0, activation=None, block_type="block8", block_idx=10 + ) + + # Final convolution block: 8 x 8 x 1536 + x = conv2d_bn(x, 1536, 1, name="conv_7b") + + if include_top: + # Classification block + x = layers.GlobalAveragePooling2D(name="avg_pool")(x) + imagenet_utils.validate_activation(classifier_activation, weights) + x = layers.Dense( + classes, activation=classifier_activation, name="predictions" + )(x) + else: + if pooling == "avg": + x = layers.GlobalAveragePooling2D()(x) + elif pooling == "max": + x = layers.GlobalMaxPooling2D()(x) + + # Ensure that the model takes into account + # any potential predecessors of `input_tensor`. + if input_tensor is not None: + inputs = operation_utils.get_source_inputs(input_tensor) + else: + inputs = img_input + + # Create model. + model = Functional(inputs, x, name=name) + + # Load weights. + if weights == "imagenet": + if include_top: + fname = "inception_resnet_v2_weights_tf_dim_ordering_tf_kernels.h5" + weights_path = file_utils.get_file( + fname, + BASE_WEIGHT_URL + fname, + cache_subdir="models", + file_hash="e693bd0210a403b3192acc6073ad2e96", + ) + else: + fname = ( + "inception_resnet_v2_weights_" + "tf_dim_ordering_tf_kernels_notop.h5" + ) + weights_path = file_utils.get_file( + fname, + BASE_WEIGHT_URL + fname, + cache_subdir="models", + file_hash="d19885ff4a710c122648d3b5c3b684e4", + ) + model.load_weights(weights_path) + elif weights is not None: + model.load_weights(weights) + + return model + + +def conv2d_bn( + x, + filters, + kernel_size, + strides=1, + padding="same", + activation="relu", + use_bias=False, + name=None, +): + """Utility function to apply conv + BN. + + Args: + x: input tensor. + filters: filters in `Conv2D`. + kernel_size: kernel size as in `Conv2D`. + strides: strides in `Conv2D`. + padding: padding mode in `Conv2D`. + activation: activation in `Conv2D`. + use_bias: whether to use a bias in `Conv2D`. + name: name of the ops; will become `name + '_ac'` + for the activation and `name + '_bn'` for the batch norm layer. + + Returns: + Output tensor after applying `Conv2D` and `BatchNormalization`. + """ + x = layers.Conv2D( + filters, + kernel_size, + strides=strides, + padding=padding, + use_bias=use_bias, + name=name, + )(x) + if not use_bias: + bn_axis = 1 if backend.image_data_format() == "channels_first" else 3 + bn_name = None if name is None else f"{name}_bn" + x = layers.BatchNormalization(axis=bn_axis, scale=False, name=bn_name)( + x + ) + if activation is not None: + ac_name = None if name is None else f"{name}_ac" + x = layers.Activation(activation, name=ac_name)(x) + return x + + +class CustomScaleLayer(Layer): + def __init__(self, scale, **kwargs): + super().__init__(**kwargs) + self.scale = scale + + def get_config(self): + config = super().get_config() + config.update({"scale": self.scale}) + return config + + def call(self, inputs): + return inputs[0] + inputs[1] * self.scale + + +def inception_resnet_block(x, scale, block_type, block_idx, activation="relu"): + """Adds an Inception-ResNet block. + + Args: + x: input tensor. + scale: scaling factor to scale the residuals + (i.e., the output of passing `x` through an inception module) + before adding them to the shortcut + branch. Let `r` be the output from the residual branch, + the output of this block will be `x + scale * r`. + block_type: `'block35'`, `'block17'` or `'block8'`, + determines the network structure in the residual branch. + block_idx: an `int` used for generating layer names. + The Inception-ResNet blocks are repeated many times + in this network. We use `block_idx` to identify each + of the repetitions. For example, the first + Inception-ResNet-A block will have + `block_type='block35', block_idx=0`, and the layer names + will have a common prefix `'block35_0'`. + activation: activation function to use at the end of the block. + + Returns: + Output tensor for the block. + """ + if block_type == "block35": + branch_0 = conv2d_bn(x, 32, 1) + branch_1 = conv2d_bn(x, 32, 1) + branch_1 = conv2d_bn(branch_1, 32, 3) + branch_2 = conv2d_bn(x, 32, 1) + branch_2 = conv2d_bn(branch_2, 48, 3) + branch_2 = conv2d_bn(branch_2, 64, 3) + branches = [branch_0, branch_1, branch_2] + elif block_type == "block17": + branch_0 = conv2d_bn(x, 192, 1) + branch_1 = conv2d_bn(x, 128, 1) + branch_1 = conv2d_bn(branch_1, 160, [1, 7]) + branch_1 = conv2d_bn(branch_1, 192, [7, 1]) + branches = [branch_0, branch_1] + elif block_type == "block8": + branch_0 = conv2d_bn(x, 192, 1) + branch_1 = conv2d_bn(x, 192, 1) + branch_1 = conv2d_bn(branch_1, 224, [1, 3]) + branch_1 = conv2d_bn(branch_1, 256, [3, 1]) + branches = [branch_0, branch_1] + else: + raise ValueError( + "Unknown Inception-ResNet block type. " + 'Expects "block35", "block17" or "block8", ' + f"but got: {block_type}" + ) + + block_name = f"{block_type}_{block_idx}" + channel_axis = 1 if backend.image_data_format() == "channels_first" else 3 + mixed = layers.Concatenate(axis=channel_axis, name=f"{block_name}_mixed")( + branches + ) + up = conv2d_bn( + mixed, + x.shape[channel_axis], + 1, + activation=None, + use_bias=True, + name=f"{block_name}_conv", + ) + + x = CustomScaleLayer(scale)([x, up]) + if activation is not None: + x = layers.Activation(activation, name=f"{block_name}_ac")(x) + return x + + +@keras_export("keras.applications.inception_resnet_v2.preprocess_input") +def preprocess_input(x, data_format=None): + return imagenet_utils.preprocess_input( + x, data_format=data_format, mode="tf" + ) + + +@keras_export("keras.applications.inception_resnet_v2.decode_predictions") +def decode_predictions(preds, top=5): + return imagenet_utils.decode_predictions(preds, top=top) + + +preprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format( + mode="", + ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF, + error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC, +) +decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__ diff --git a/keras/src/applications/inception_v3.py b/keras/src/applications/inception_v3.py new file mode 100644 index 000000000000..50d3e0bf0bda --- /dev/null +++ b/keras/src/applications/inception_v3.py @@ -0,0 +1,442 @@ +from keras.src import backend +from keras.src import layers +from keras.src.api_export import keras_export +from keras.src.applications import imagenet_utils +from keras.src.models import Functional +from keras.src.ops import operation_utils +from keras.src.utils import file_utils + +WEIGHTS_PATH = ( + "https://storage.googleapis.com/tensorflow/keras-applications/" + "inception_v3/inception_v3_weights_tf_dim_ordering_tf_kernels.h5" +) +WEIGHTS_PATH_NO_TOP = ( + "https://storage.googleapis.com/tensorflow/keras-applications/" + "inception_v3/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5" +) + + +@keras_export( + [ + "keras.applications.inception_v3.InceptionV3", + "keras.applications.InceptionV3", + ] +) +def InceptionV3( + include_top=True, + weights="imagenet", + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000, + classifier_activation="softmax", + name="inception_v3", +): + """Instantiates the Inception v3 architecture. + + Reference: + - [Rethinking the Inception Architecture for Computer Vision]( + http://arxiv.org/abs/1512.00567) (CVPR 2016) + + This function returns a Keras image classification model, + optionally loaded with weights pre-trained on ImageNet. + + For image classification use cases, see + [this page for detailed examples]( + https://keras.io/api/applications/#usage-examples-for-image-classification-models). + + For transfer learning use cases, make sure to read the + [guide to transfer learning & fine-tuning]( + https://keras.io/guides/transfer_learning/). + + Note: each Keras Application expects a specific kind of input preprocessing. + For `InceptionV3`, call + `keras.applications.inception_v3.preprocess_input` on your inputs + before passing them to the model. + `inception_v3.preprocess_input` will scale input pixels between -1 and 1. + + Args: + include_top: Boolean, whether to include the fully-connected + layer at the top, as the last layer of the network. + Defaults to `True`. + weights: One of `None` (random initialization), + `imagenet` (pre-training on ImageNet), + or the path to the weights file to be loaded. + Defaults to `"imagenet"`. + input_tensor: Optional Keras tensor (i.e. output of `layers.Input()`) + to use as image input for the model. `input_tensor` is useful for + sharing inputs between multiple different networks. + Defaults to `None`. + input_shape: Optional shape tuple, only to be specified + if `include_top` is False (otherwise the input shape + has to be `(299, 299, 3)` (with `channels_last` data format) + or `(3, 299, 299)` (with `channels_first` data format). + It should have exactly 3 inputs channels, + and width and height should be no smaller than 75. + E.g. `(150, 150, 3)` would be one valid value. + `input_shape` will be ignored if the `input_tensor` is provided. + pooling: Optional pooling mode for feature extraction + when `include_top` is `False`. + - `None` (default) means that the output of the model will be + the 4D tensor output of the last convolutional block. + - `avg` means that global average pooling + will be applied to the output of the + last convolutional block, and thus + the output of the model will be a 2D tensor. + - `max` means that global max pooling will be applied. + classes: optional number of classes to classify images + into, only to be specified if `include_top` is `True`, and + if no `weights` argument is specified. Defaults to 1000. + classifier_activation: A `str` or callable. The activation function + to use on the "top" layer. Ignored unless `include_top=True`. + Set `classifier_activation=None` to return the logits of the "top" + layer. When loading pretrained weights, `classifier_activation` + can only be `None` or `"softmax"`. + name: The name of the model (string). + + Returns: + A model instance. + """ + if not (weights in {"imagenet", None} or file_utils.exists(weights)): + raise ValueError( + "The `weights` argument should be either " + "`None` (random initialization), `imagenet` " + "(pre-training on ImageNet), " + "or the path to the weights file to be loaded; " + f"Received: weights={weights}" + ) + + if weights == "imagenet" and include_top and classes != 1000: + raise ValueError( + 'If using `weights="imagenet"` with `include_top=True`, ' + "`classes` should be 1000. " + f"Received classes={classes}" + ) + + # Determine proper input shape + input_shape = imagenet_utils.obtain_input_shape( + input_shape, + default_size=299, + min_size=75, + data_format=backend.image_data_format(), + require_flatten=include_top, + weights=weights, + ) + + if input_tensor is None: + img_input = layers.Input(shape=input_shape) + else: + if not backend.is_keras_tensor(input_tensor): + img_input = layers.Input(tensor=input_tensor, shape=input_shape) + else: + img_input = input_tensor + + if backend.image_data_format() == "channels_first": + channel_axis = 1 + else: + channel_axis = 3 + + x = conv2d_bn(img_input, 32, 3, 3, strides=(2, 2), padding="valid") + x = conv2d_bn(x, 32, 3, 3, padding="valid") + x = conv2d_bn(x, 64, 3, 3) + x = layers.MaxPooling2D((3, 3), strides=(2, 2))(x) + + x = conv2d_bn(x, 80, 1, 1, padding="valid") + x = conv2d_bn(x, 192, 3, 3, padding="valid") + x = layers.MaxPooling2D((3, 3), strides=(2, 2))(x) + + # mixed 0: 35 x 35 x 256 + branch1x1 = conv2d_bn(x, 64, 1, 1) + + branch5x5 = conv2d_bn(x, 48, 1, 1) + branch5x5 = conv2d_bn(branch5x5, 64, 5, 5) + + branch3x3dbl = conv2d_bn(x, 64, 1, 1) + branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3) + branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3) + + branch_pool = layers.AveragePooling2D( + (3, 3), strides=(1, 1), padding="same" + )(x) + branch_pool = conv2d_bn(branch_pool, 32, 1, 1) + x = layers.concatenate( + [branch1x1, branch5x5, branch3x3dbl, branch_pool], + axis=channel_axis, + name="mixed0", + ) + + # mixed 1: 35 x 35 x 288 + branch1x1 = conv2d_bn(x, 64, 1, 1) + + branch5x5 = conv2d_bn(x, 48, 1, 1) + branch5x5 = conv2d_bn(branch5x5, 64, 5, 5) + + branch3x3dbl = conv2d_bn(x, 64, 1, 1) + branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3) + branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3) + + branch_pool = layers.AveragePooling2D( + (3, 3), strides=(1, 1), padding="same" + )(x) + branch_pool = conv2d_bn(branch_pool, 64, 1, 1) + x = layers.concatenate( + [branch1x1, branch5x5, branch3x3dbl, branch_pool], + axis=channel_axis, + name="mixed1", + ) + + # mixed 2: 35 x 35 x 288 + branch1x1 = conv2d_bn(x, 64, 1, 1) + + branch5x5 = conv2d_bn(x, 48, 1, 1) + branch5x5 = conv2d_bn(branch5x5, 64, 5, 5) + + branch3x3dbl = conv2d_bn(x, 64, 1, 1) + branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3) + branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3) + + branch_pool = layers.AveragePooling2D( + (3, 3), strides=(1, 1), padding="same" + )(x) + branch_pool = conv2d_bn(branch_pool, 64, 1, 1) + x = layers.concatenate( + [branch1x1, branch5x5, branch3x3dbl, branch_pool], + axis=channel_axis, + name="mixed2", + ) + + # mixed 3: 17 x 17 x 768 + branch3x3 = conv2d_bn(x, 384, 3, 3, strides=(2, 2), padding="valid") + + branch3x3dbl = conv2d_bn(x, 64, 1, 1) + branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3) + branch3x3dbl = conv2d_bn( + branch3x3dbl, 96, 3, 3, strides=(2, 2), padding="valid" + ) + + branch_pool = layers.MaxPooling2D((3, 3), strides=(2, 2))(x) + x = layers.concatenate( + [branch3x3, branch3x3dbl, branch_pool], axis=channel_axis, name="mixed3" + ) + + # mixed 4: 17 x 17 x 768 + branch1x1 = conv2d_bn(x, 192, 1, 1) + + branch7x7 = conv2d_bn(x, 128, 1, 1) + branch7x7 = conv2d_bn(branch7x7, 128, 1, 7) + branch7x7 = conv2d_bn(branch7x7, 192, 7, 1) + + branch7x7dbl = conv2d_bn(x, 128, 1, 1) + branch7x7dbl = conv2d_bn(branch7x7dbl, 128, 7, 1) + branch7x7dbl = conv2d_bn(branch7x7dbl, 128, 1, 7) + branch7x7dbl = conv2d_bn(branch7x7dbl, 128, 7, 1) + branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 1, 7) + + branch_pool = layers.AveragePooling2D( + (3, 3), strides=(1, 1), padding="same" + )(x) + branch_pool = conv2d_bn(branch_pool, 192, 1, 1) + x = layers.concatenate( + [branch1x1, branch7x7, branch7x7dbl, branch_pool], + axis=channel_axis, + name="mixed4", + ) + + # mixed 5, 6: 17 x 17 x 768 + for i in range(2): + branch1x1 = conv2d_bn(x, 192, 1, 1) + + branch7x7 = conv2d_bn(x, 160, 1, 1) + branch7x7 = conv2d_bn(branch7x7, 160, 1, 7) + branch7x7 = conv2d_bn(branch7x7, 192, 7, 1) + + branch7x7dbl = conv2d_bn(x, 160, 1, 1) + branch7x7dbl = conv2d_bn(branch7x7dbl, 160, 7, 1) + branch7x7dbl = conv2d_bn(branch7x7dbl, 160, 1, 7) + branch7x7dbl = conv2d_bn(branch7x7dbl, 160, 7, 1) + branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 1, 7) + + branch_pool = layers.AveragePooling2D( + (3, 3), strides=(1, 1), padding="same" + )(x) + branch_pool = conv2d_bn(branch_pool, 192, 1, 1) + x = layers.concatenate( + [branch1x1, branch7x7, branch7x7dbl, branch_pool], + axis=channel_axis, + name="mixed{0}".format(5 + i), + ) + + # mixed 7: 17 x 17 x 768 + branch1x1 = conv2d_bn(x, 192, 1, 1) + + branch7x7 = conv2d_bn(x, 192, 1, 1) + branch7x7 = conv2d_bn(branch7x7, 192, 1, 7) + branch7x7 = conv2d_bn(branch7x7, 192, 7, 1) + + branch7x7dbl = conv2d_bn(x, 192, 1, 1) + branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 7, 1) + branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 1, 7) + branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 7, 1) + branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 1, 7) + + branch_pool = layers.AveragePooling2D( + (3, 3), strides=(1, 1), padding="same" + )(x) + branch_pool = conv2d_bn(branch_pool, 192, 1, 1) + x = layers.concatenate( + [branch1x1, branch7x7, branch7x7dbl, branch_pool], + axis=channel_axis, + name="mixed7", + ) + + # mixed 8: 8 x 8 x 1280 + branch3x3 = conv2d_bn(x, 192, 1, 1) + branch3x3 = conv2d_bn(branch3x3, 320, 3, 3, strides=(2, 2), padding="valid") + + branch7x7x3 = conv2d_bn(x, 192, 1, 1) + branch7x7x3 = conv2d_bn(branch7x7x3, 192, 1, 7) + branch7x7x3 = conv2d_bn(branch7x7x3, 192, 7, 1) + branch7x7x3 = conv2d_bn( + branch7x7x3, 192, 3, 3, strides=(2, 2), padding="valid" + ) + + branch_pool = layers.MaxPooling2D((3, 3), strides=(2, 2))(x) + x = layers.concatenate( + [branch3x3, branch7x7x3, branch_pool], axis=channel_axis, name="mixed8" + ) + + # mixed 9: 8 x 8 x 2048 + for i in range(2): + branch1x1 = conv2d_bn(x, 320, 1, 1) + + branch3x3 = conv2d_bn(x, 384, 1, 1) + branch3x3_1 = conv2d_bn(branch3x3, 384, 1, 3) + branch3x3_2 = conv2d_bn(branch3x3, 384, 3, 1) + branch3x3 = layers.concatenate( + [branch3x3_1, branch3x3_2], + axis=channel_axis, + name=f"mixed9_{i}", + ) + + branch3x3dbl = conv2d_bn(x, 448, 1, 1) + branch3x3dbl = conv2d_bn(branch3x3dbl, 384, 3, 3) + branch3x3dbl_1 = conv2d_bn(branch3x3dbl, 384, 1, 3) + branch3x3dbl_2 = conv2d_bn(branch3x3dbl, 384, 3, 1) + branch3x3dbl = layers.concatenate( + [branch3x3dbl_1, branch3x3dbl_2], axis=channel_axis + ) + + branch_pool = layers.AveragePooling2D( + (3, 3), strides=(1, 1), padding="same" + )(x) + branch_pool = conv2d_bn(branch_pool, 192, 1, 1) + x = layers.concatenate( + [branch1x1, branch3x3, branch3x3dbl, branch_pool], + axis=channel_axis, + name=f"mixed{9 + i}", + ) + if include_top: + # Classification block + x = layers.GlobalAveragePooling2D(name="avg_pool")(x) + imagenet_utils.validate_activation(classifier_activation, weights) + x = layers.Dense( + classes, activation=classifier_activation, name="predictions" + )(x) + else: + if pooling == "avg": + x = layers.GlobalAveragePooling2D()(x) + elif pooling == "max": + x = layers.GlobalMaxPooling2D()(x) + + # Ensure that the model takes into account + # any potential predecessors of `input_tensor`. + if input_tensor is not None: + inputs = operation_utils.get_source_inputs(input_tensor) + else: + inputs = img_input + # Create model. + model = Functional(inputs, x, name=name) + + # Load weights. + if weights == "imagenet": + if include_top: + weights_path = file_utils.get_file( + "inception_v3_weights_tf_dim_ordering_tf_kernels.h5", + WEIGHTS_PATH, + cache_subdir="models", + file_hash="9a0d58056eeedaa3f26cb7ebd46da564", + ) + else: + weights_path = file_utils.get_file( + "inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5", + WEIGHTS_PATH_NO_TOP, + cache_subdir="models", + file_hash="bcbd6486424b2319ff4ef7d526e38f63", + ) + model.load_weights(weights_path) + elif weights is not None: + model.load_weights(weights) + + return model + + +def conv2d_bn( + x, filters, num_row, num_col, padding="same", strides=(1, 1), name=None +): + """Utility function to apply conv + BN. + + Args: + x: input tensor. + filters: filters in `Conv2D`. + num_row: height of the convolution kernel. + num_col: width of the convolution kernel. + padding: padding mode in `Conv2D`. + strides: strides in `Conv2D`. + name: name of the ops; will become `name + '_conv'` + for the convolution and `name + '_bn'` for the + batch norm layer. + + Returns: + Output tensor after applying `Conv2D` and `BatchNormalization`. + """ + if name is not None: + bn_name = f"{name}_bn" + conv_name = f"{name}_conv" + else: + bn_name = None + conv_name = None + if backend.image_data_format() == "channels_first": + bn_axis = 1 + else: + bn_axis = 3 + x = layers.Conv2D( + filters, + (num_row, num_col), + strides=strides, + padding=padding, + use_bias=False, + name=conv_name, + )(x) + x = layers.BatchNormalization(axis=bn_axis, scale=False, name=bn_name)(x) + x = layers.Activation("relu", name=name)(x) + return x + + +@keras_export("keras.applications.inception_v3.preprocess_input") +def preprocess_input(x, data_format=None): + return imagenet_utils.preprocess_input( + x, data_format=data_format, mode="tf" + ) + + +@keras_export("keras.applications.inception_v3.decode_predictions") +def decode_predictions(preds, top=5): + return imagenet_utils.decode_predictions(preds, top=top) + + +preprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format( + mode="", + ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF, + error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC, +) +decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__ diff --git a/keras/src/applications/mobilenet.py b/keras/src/applications/mobilenet.py new file mode 100644 index 000000000000..ea1b5d581374 --- /dev/null +++ b/keras/src/applications/mobilenet.py @@ -0,0 +1,435 @@ +import warnings + +from keras.src import backend +from keras.src import layers +from keras.src.api_export import keras_export +from keras.src.applications import imagenet_utils +from keras.src.models import Functional +from keras.src.ops import operation_utils +from keras.src.utils import file_utils + +BASE_WEIGHT_PATH = ( + "https://storage.googleapis.com/tensorflow/keras-applications/mobilenet/" +) + + +@keras_export( + [ + "keras.applications.mobilenet.MobileNet", + "keras.applications.MobileNet", + ] +) +def MobileNet( + input_shape=None, + alpha=1.0, + depth_multiplier=1, + dropout=1e-3, + include_top=True, + weights="imagenet", + input_tensor=None, + pooling=None, + classes=1000, + classifier_activation="softmax", + name=None, +): + """Instantiates the MobileNet architecture. + + Reference: + - [MobileNets: Efficient Convolutional Neural Networks + for Mobile Vision Applications]( + https://arxiv.org/abs/1704.04861) + + This function returns a Keras image classification model, + optionally loaded with weights pre-trained on ImageNet. + + For image classification use cases, see + [this page for detailed examples]( + https://keras.io/api/applications/#usage-examples-for-image-classification-models). + + For transfer learning use cases, make sure to read the + [guide to transfer learning & fine-tuning]( + https://keras.io/guides/transfer_learning/). + + Note: each Keras Application expects a specific kind of input preprocessing. + For MobileNet, call `keras.applications.mobilenet.preprocess_input` + on your inputs before passing them to the model. + `mobilenet.preprocess_input` will scale input pixels between -1 and 1. + + Args: + input_shape: Optional shape tuple, only to be specified if `include_top` + is `False` (otherwise the input shape has to be `(224, 224, 3)` + (with `"channels_last"` data format) or `(3, 224, 224)` + (with `"channels_first"` data format). + It should have exactly 3 inputs channels, and width and + height should be no smaller than 32. E.g. `(200, 200, 3)` would + be one valid value. Defaults to `None`. + `input_shape` will be ignored if the `input_tensor` is provided. + alpha: Controls the width of the network. This is known as the width + multiplier in the MobileNet paper. + - If `alpha < 1.0`, proportionally decreases the number + of filters in each layer. + - If `alpha > 1.0`, proportionally increases the number + of filters in each layer. + - If `alpha == 1`, default number of filters from the paper + are used at each layer. Defaults to `1.0`. + depth_multiplier: Depth multiplier for depthwise convolution. + This is called the resolution multiplier in the MobileNet paper. + Defaults to `1.0`. + dropout: Dropout rate. Defaults to `0.001`. + include_top: Boolean, whether to include the fully-connected layer + at the top of the network. Defaults to `True`. + weights: One of `None` (random initialization), `"imagenet"` + (pre-training on ImageNet), or the path to the weights file + to be loaded. Defaults to `"imagenet"`. + input_tensor: Optional Keras tensor (i.e. output of `layers.Input()`) + to use as image input for the model. `input_tensor` is useful + for sharing inputs between multiple different networks. + Defaults to `None`. + pooling: Optional pooling mode for feature extraction when `include_top` + is `False`. + - `None` (default) means that the output of the model will be + the 4D tensor output of the last convolutional block. + - `avg` means that global average pooling + will be applied to the output of the + last convolutional block, and thus + the output of the model will be a 2D tensor. + - `max` means that global max pooling will be applied. + classes: Optional number of classes to classify images into, + only to be specified if `include_top` is `True`, and if + no `weights` argument is specified. Defaults to `1000`. + classifier_activation: A `str` or callable. The activation function + to use on the "top" layer. Ignored unless `include_top=True`. + Set `classifier_activation=None` to return the logits of the "top" + layer. When loading pretrained weights, `classifier_activation` + can only be `None` or `"softmax"`. + name: String, the name of the model. + + Returns: + A model instance. + """ + if not (weights in {"imagenet", None} or file_utils.exists(weights)): + raise ValueError( + "The `weights` argument should be either " + "`None` (random initialization), 'imagenet' " + "(pre-training on ImageNet), " + "or the path to the weights file to be loaded. " + f"Received weights={weights}" + ) + + if weights == "imagenet" and include_top and classes != 1000: + raise ValueError( + "If using `weights='imagenet'` with `include_top=True`, " + "`classes` should be 1000. " + f"Received classes={classes}" + ) + + # Determine proper input shape and default size. + if input_shape is None: + default_size = 224 + else: + if backend.image_data_format() == "channels_first": + rows = input_shape[1] + cols = input_shape[2] + else: + rows = input_shape[0] + cols = input_shape[1] + + if rows == cols and rows in [128, 160, 192, 224]: + default_size = rows + else: + default_size = 224 + + input_shape = imagenet_utils.obtain_input_shape( + input_shape, + default_size=default_size, + min_size=32, + data_format=backend.image_data_format(), + require_flatten=include_top, + weights=weights, + ) + + if backend.image_data_format() == "channels_last": + row_axis, col_axis = (0, 1) + else: + row_axis, col_axis = (1, 2) + rows = input_shape[row_axis] + cols = input_shape[col_axis] + + if weights == "imagenet": + if depth_multiplier != 1: + raise ValueError( + "If imagenet weights are being loaded, " + "depth multiplier must be 1. " + f"Received depth_multiplier={depth_multiplier}" + ) + + if alpha not in [0.25, 0.50, 0.75, 1.0]: + raise ValueError( + "If imagenet weights are being loaded, " + "alpha can be one of" + "`0.25`, `0.50`, `0.75` or `1.0` only. " + f"Received alpha={alpha}" + ) + + if rows != cols or rows not in [128, 160, 192, 224]: + rows = 224 + warnings.warn( + "`input_shape` is undefined or non-square, " + "or `rows` is not in [128, 160, 192, 224]. " + "Weights for input shape (224, 224) will be " + "loaded as the default.", + stacklevel=2, + ) + + if input_tensor is None: + img_input = layers.Input(shape=input_shape) + else: + if not backend.is_keras_tensor(input_tensor): + img_input = layers.Input(tensor=input_tensor, shape=input_shape) + else: + img_input = input_tensor + + x = _conv_block(img_input, 32, alpha, strides=(2, 2)) + x = _depthwise_conv_block(x, 64, alpha, depth_multiplier, block_id=1) + + x = _depthwise_conv_block( + x, 128, alpha, depth_multiplier, strides=(2, 2), block_id=2 + ) + x = _depthwise_conv_block(x, 128, alpha, depth_multiplier, block_id=3) + + x = _depthwise_conv_block( + x, 256, alpha, depth_multiplier, strides=(2, 2), block_id=4 + ) + x = _depthwise_conv_block(x, 256, alpha, depth_multiplier, block_id=5) + + x = _depthwise_conv_block( + x, 512, alpha, depth_multiplier, strides=(2, 2), block_id=6 + ) + x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=7) + x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=8) + x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=9) + x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=10) + x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=11) + + x = _depthwise_conv_block( + x, 1024, alpha, depth_multiplier, strides=(2, 2), block_id=12 + ) + x = _depthwise_conv_block(x, 1024, alpha, depth_multiplier, block_id=13) + + if include_top: + x = layers.GlobalAveragePooling2D(keepdims=True)(x) + x = layers.Dropout(dropout, name="dropout")(x) + x = layers.Conv2D(classes, (1, 1), padding="same", name="conv_preds")(x) + x = layers.Reshape((classes,), name="reshape_2")(x) + imagenet_utils.validate_activation(classifier_activation, weights) + x = layers.Activation( + activation=classifier_activation, name="predictions" + )(x) + else: + if pooling == "avg": + x = layers.GlobalAveragePooling2D()(x) + elif pooling == "max": + x = layers.GlobalMaxPooling2D()(x) + + # Ensure that the model takes into account + # any potential predecessors of `input_tensor`. + if input_tensor is not None: + inputs = operation_utils.get_source_inputs(input_tensor) + else: + inputs = img_input + + # Create model. + if name is None: + name = f"mobilenet_{alpha:0.2f}_{rows}" + model = Functional(inputs, x, name=name) + + # Load weights. + if weights == "imagenet": + if alpha == 1.0: + alpha_text = "1_0" + elif alpha == 0.75: + alpha_text = "7_5" + elif alpha == 0.50: + alpha_text = "5_0" + else: + alpha_text = "2_5" + + if include_top: + model_name = "mobilenet_%s_%d_tf.h5" % (alpha_text, rows) + weight_path = BASE_WEIGHT_PATH + model_name + weights_path = file_utils.get_file( + model_name, weight_path, cache_subdir="models" + ) + else: + model_name = "mobilenet_%s_%d_tf_no_top.h5" % (alpha_text, rows) + weight_path = BASE_WEIGHT_PATH + model_name + weights_path = file_utils.get_file( + model_name, weight_path, cache_subdir="models" + ) + model.load_weights(weights_path) + elif weights is not None: + model.load_weights(weights) + + return model + + +def _conv_block(inputs, filters, alpha, kernel=(3, 3), strides=(1, 1)): + """Adds an initial convolution layer (with batch normalization and relu6). + + Args: + inputs: Input tensor of shape `(rows, cols, 3)` (with `channels_last` + data format) or (3, rows, cols) (with `channels_first` data format). + It should have exactly 3 inputs channels, and width and height + should be no smaller than 32. E.g. `(224, 224, 3)` would be + one valid value. + filters: Integer, the dimensionality of the output space (i.e. the + number of output filters in the convolution). + alpha: controls the width of the network. - If `alpha` < 1.0, + proportionally decreases the number of filters in each layer. + - If `alpha` > 1.0, proportionally increases the number of filters + in each layer. + - If `alpha` = 1, default number of filters from the paper are + used at each layer. + kernel: An integer or tuple/list of 2 integers, specifying the width + and height of the 2D convolution window. + Can be a single integer to specify the same value for + all spatial dimensions. + strides: An integer or tuple/list of 2 integers, specifying the strides + of the convolution along the width and height. + Can be a single integer to specify the same value for all + spatial dimensions. Specifying any stride value != 1 is + incompatible with specifying any `dilation_rate` + value != 1. + + Input shape: + 4D tensor with shape: `(samples, channels, rows, cols)` if + data_format='channels_first' + or 4D tensor with shape: `(samples, rows, cols, channels)` if + data_format='channels_last'. # Output shape + 4D tensor with shape: `(samples, filters, new_rows, new_cols)` + if data_format='channels_first' + or 4D tensor with shape: `(samples, new_rows, new_cols, filters)` + if data_format='channels_last'. `rows` and `cols` values + might have changed due to stride. + + Returns: + Output tensor of block. + """ + channel_axis = 1 if backend.image_data_format() == "channels_first" else -1 + filters = int(filters * alpha) + x = layers.Conv2D( + filters, + kernel, + padding="same", + use_bias=False, + strides=strides, + name="conv1", + )(inputs) + x = layers.BatchNormalization(axis=channel_axis, name="conv1_bn")(x) + return layers.ReLU(6.0, name="conv1_relu")(x) + + +def _depthwise_conv_block( + inputs, + pointwise_conv_filters, + alpha, + depth_multiplier=1, + strides=(1, 1), + block_id=1, +): + """Adds a depthwise convolution block. + + A depthwise convolution block consists of a depthwise conv, + batch normalization, relu6, pointwise convolution, + batch normalization and relu6 activation. + + Args: + inputs: Input tensor of shape `(rows, cols, channels)` (with + `channels_last` data format) or (channels, rows, cols) (with + `channels_first` data format). + pointwise_conv_filters: Integer, the dimensionality of the output space + (i.e. the number of output filters in the pointwise convolution). + alpha: controls the width of the network. - If `alpha` < 1.0, + proportionally decreases the number of filters in each layer. + - If `alpha` > 1.0, proportionally increases the number of filters + in each layer. + - If `alpha` = 1, default number of filters from the paper are + used at each layer. + depth_multiplier: The number of depthwise convolution output channels + for each input channel. The total number of depthwise convolution + output channels will be equal to `filters_in * depth_multiplier`. + strides: An integer or tuple/list of 2 integers, specifying the strides + of the convolution along the width and height. + Can be a single integer to specify the same value for + all spatial dimensions. Specifying any stride value != 1 is + incompatible with specifying any `dilation_rate` value != 1. + block_id: Integer, a unique identification designating the block number. + + Input shape: + 4D tensor with shape: `(batch, channels, rows, cols)` if + data_format='channels_first' + or 4D tensor with shape: `(batch, rows, cols, channels)` if + data_format='channels_last'. # Output shape + 4D tensor with shape: `(batch, filters, new_rows, new_cols)` if + data_format='channels_first' + or 4D tensor with shape: `(batch, new_rows, new_cols, filters)` if + data_format='channels_last'. `rows` and `cols` values might have + changed due to stride. + + Returns: + Output tensor of block. + """ + channel_axis = 1 if backend.image_data_format() == "channels_first" else -1 + pointwise_conv_filters = int(pointwise_conv_filters * alpha) + + if strides == (1, 1): + x = inputs + else: + x = layers.ZeroPadding2D( + ((0, 1), (0, 1)), name="conv_pad_%d" % block_id + )(inputs) + x = layers.DepthwiseConv2D( + (3, 3), + padding="same" if strides == (1, 1) else "valid", + depth_multiplier=depth_multiplier, + strides=strides, + use_bias=False, + name="conv_dw_%d" % block_id, + )(x) + x = layers.BatchNormalization( + axis=channel_axis, name="conv_dw_%d_bn" % block_id + )(x) + x = layers.ReLU(6.0, name="conv_dw_%d_relu" % block_id)(x) + + x = layers.Conv2D( + pointwise_conv_filters, + (1, 1), + padding="same", + use_bias=False, + strides=(1, 1), + name="conv_pw_%d" % block_id, + )(x) + x = layers.BatchNormalization( + axis=channel_axis, name="conv_pw_%d_bn" % block_id + )(x) + return layers.ReLU(6.0, name="conv_pw_%d_relu" % block_id)(x) + + +@keras_export("keras.applications.mobilenet.preprocess_input") +def preprocess_input(x, data_format=None): + return imagenet_utils.preprocess_input( + x, data_format=data_format, mode="tf" + ) + + +@keras_export("keras.applications.mobilenet.decode_predictions") +def decode_predictions(preds, top=5): + return imagenet_utils.decode_predictions(preds, top=top) + + +preprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format( + mode="", + ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF, + error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC, +) +decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__ diff --git a/keras/src/applications/mobilenet_v2.py b/keras/src/applications/mobilenet_v2.py new file mode 100644 index 000000000000..50e475329e63 --- /dev/null +++ b/keras/src/applications/mobilenet_v2.py @@ -0,0 +1,497 @@ +import warnings + +from keras.src import backend +from keras.src import layers +from keras.src.api_export import keras_export +from keras.src.applications import imagenet_utils +from keras.src.models import Functional +from keras.src.ops import operation_utils +from keras.src.utils import file_utils + +BASE_WEIGHT_PATH = ( + "https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/" +) + + +@keras_export( + [ + "keras.applications.mobilenet_v2.MobileNetV2", + "keras.applications.MobileNetV2", + ] +) +def MobileNetV2( + input_shape=None, + alpha=1.0, + include_top=True, + weights="imagenet", + input_tensor=None, + pooling=None, + classes=1000, + classifier_activation="softmax", + name=None, +): + """Instantiates the MobileNetV2 architecture. + + MobileNetV2 is very similar to the original MobileNet, + except that it uses inverted residual blocks with + bottlenecking features. It has a drastically lower + parameter count than the original MobileNet. + MobileNets support any input size greater + than 32 x 32, with larger image sizes + offering better performance. + + Reference: + - [MobileNetV2: Inverted Residuals and Linear Bottlenecks]( + https://arxiv.org/abs/1801.04381) (CVPR 2018) + + This function returns a Keras image classification model, + optionally loaded with weights pre-trained on ImageNet. + + For image classification use cases, see + [this page for detailed examples]( + https://keras.io/api/applications/#usage-examples-for-image-classification-models). + + For transfer learning use cases, make sure to read the + [guide to transfer learning & fine-tuning]( + https://keras.io/guides/transfer_learning/). + + Note: each Keras Application expects a specific kind of input preprocessing. + For MobileNetV2, call + `keras.applications.mobilenet_v2.preprocess_input` + on your inputs before passing them to the model. + `mobilenet_v2.preprocess_input` will scale input pixels between -1 and 1. + + Args: + input_shape: Optional shape tuple, only to be specified if `include_top` + is `False` (otherwise the input shape has to be `(224, 224, 3)` + (with `"channels_last"` data format) or `(3, 224, 224)` + (with `"channels_first"` data format). + It should have exactly 3 inputs channels, and width and + height should be no smaller than 32. E.g. `(200, 200, 3)` would + be one valid value. Defaults to `None`. + `input_shape` will be ignored if the `input_tensor` is provided. + alpha: Controls the width of the network. This is known as the width + multiplier in the MobileNet paper. + - If `alpha < 1.0`, proportionally decreases the number + of filters in each layer. + - If `alpha > 1.0`, proportionally increases the number + of filters in each layer. + - If `alpha == 1`, default number of filters from the paper + are used at each layer. Defaults to `1.0`. + include_top: Boolean, whether to include the fully-connected layer + at the top of the network. Defaults to `True`. + weights: One of `None` (random initialization), `"imagenet"` + (pre-training on ImageNet), or the path to the weights file + to be loaded. Defaults to `"imagenet"`. + input_tensor: Optional Keras tensor (i.e. output of `layers.Input()`) + to use as image input for the model. `input_tensor` is useful + for sharing inputs between multiple different networks. + Defaults to `None`. + pooling: Optional pooling mode for feature extraction when `include_top` + is `False`. + - `None` (default) means that the output of the model will be + the 4D tensor output of the last convolutional block. + - `avg` means that global average pooling + will be applied to the output of the + last convolutional block, and thus + the output of the model will be a 2D tensor. + - `max` means that global max pooling will be applied. + classes: Optional number of classes to classify images into, + only to be specified if `include_top` is `True`, and if + no `weights` argument is specified. Defaults to `1000`. + classifier_activation: A `str` or callable. The activation function + to use on the "top" layer. Ignored unless `include_top=True`. + Set `classifier_activation=None` to return the logits of the "top" + layer. When loading pretrained weights, `classifier_activation` + can only be `None` or `"softmax"`. + name: String, the name of the model. + + Returns: + A model instance. + """ + if not (weights in {"imagenet", None} or file_utils.exists(weights)): + raise ValueError( + "The `weights` argument should be either " + "`None` (random initialization), `imagenet` " + "(pre-training on ImageNet), " + "or the path to the weights file to be loaded. " + f"Received `weights={weights}`" + ) + + if weights == "imagenet" and include_top and classes != 1000: + raise ValueError( + 'If using `weights="imagenet"` with `include_top` ' + f"as true, `classes` should be 1000. Received `classes={classes}`" + ) + + # Determine proper input shape and default size. + # If both input_shape and input_tensor are used, they should match + if input_shape is not None and input_tensor is not None: + try: + is_input_t_tensor = backend.is_keras_tensor(input_tensor) + except ValueError: + try: + is_input_t_tensor = backend.is_keras_tensor( + operation_utils.get_source_inputs(input_tensor) + ) + except ValueError: + raise ValueError( + f"input_tensor: {input_tensor}" + "is not type input_tensor. " + f"Received `type(input_tensor)={type(input_tensor)}`" + ) + if is_input_t_tensor: + if backend.image_data_format() == "channels_first": + if input_tensor.shape[1] != input_shape[1]: + raise ValueError( + "input_shape[1] must equal shape(input_tensor)[1] " + "when `image_data_format` is `channels_first`; " + "Received `input_tensor.shape=" + f"{input_tensor.shape}`" + f", `input_shape={input_shape}`" + ) + else: + if input_tensor.shape[2] != input_shape[1]: + raise ValueError( + "input_tensor.shape[2] must equal input_shape[1]; " + "Received `input_tensor.shape=" + f"{input_tensor.shape}`, " + f"`input_shape={input_shape}`" + ) + else: + raise ValueError( + "input_tensor is not a Keras tensor; " + f"Received `input_tensor={input_tensor}`" + ) + + # If input_shape is None, infer shape from input_tensor. + if input_shape is None and input_tensor is not None: + try: + backend.is_keras_tensor(input_tensor) + except ValueError: + raise ValueError( + "input_tensor must be a valid Keras tensor type; " + f"Received {input_tensor} of type {type(input_tensor)}" + ) + + if input_shape is None and not backend.is_keras_tensor(input_tensor): + default_size = 224 + elif input_shape is None and backend.is_keras_tensor(input_tensor): + if backend.image_data_format() == "channels_first": + rows = input_tensor.shape[2] + cols = input_tensor.shape[3] + else: + rows = input_tensor.shape[1] + cols = input_tensor.shape[2] + + if rows == cols and rows in [96, 128, 160, 192, 224]: + default_size = rows + else: + default_size = 224 + + # If input_shape is None and no input_tensor + elif input_shape is None: + default_size = 224 + + # If input_shape is not None, assume default size. + else: + if backend.image_data_format() == "channels_first": + rows = input_shape[1] + cols = input_shape[2] + else: + rows = input_shape[0] + cols = input_shape[1] + + if rows == cols and rows in [96, 128, 160, 192, 224]: + default_size = rows + else: + default_size = 224 + + input_shape = imagenet_utils.obtain_input_shape( + input_shape, + default_size=default_size, + min_size=32, + data_format=backend.image_data_format(), + require_flatten=include_top, + weights=weights, + ) + + if backend.image_data_format() == "channels_last": + row_axis, col_axis = (0, 1) + else: + row_axis, col_axis = (1, 2) + rows = input_shape[row_axis] + cols = input_shape[col_axis] + + if weights == "imagenet": + if alpha not in [0.35, 0.50, 0.75, 1.0, 1.3, 1.4]: + raise ValueError( + "If imagenet weights are being loaded, " + "alpha must be one of `0.35`, `0.50`, `0.75`, " + "`1.0`, `1.3` or `1.4` only;" + f" Received `alpha={alpha}`" + ) + + if rows != cols or rows not in [96, 128, 160, 192, 224]: + rows = 224 + warnings.warn( + "`input_shape` is undefined or non-square, " + "or `rows` is not in [96, 128, 160, 192, 224]. " + "Weights for input shape (224, 224) will be " + "loaded as the default.", + stacklevel=2, + ) + + if input_tensor is None: + img_input = layers.Input(shape=input_shape) + else: + if not backend.is_keras_tensor(input_tensor): + img_input = layers.Input(tensor=input_tensor, shape=input_shape) + else: + img_input = input_tensor + + channel_axis = 1 if backend.image_data_format() == "channels_first" else -1 + + first_block_filters = _make_divisible(32 * alpha, 8) + x = layers.Conv2D( + first_block_filters, + kernel_size=3, + strides=(2, 2), + padding="same", + use_bias=False, + name="Conv1", + )(img_input) + x = layers.BatchNormalization( + axis=channel_axis, epsilon=1e-3, momentum=0.999, name="bn_Conv1" + )(x) + x = layers.ReLU(6.0, name="Conv1_relu")(x) + + x = _inverted_res_block( + x, filters=16, alpha=alpha, stride=1, expansion=1, block_id=0 + ) + + x = _inverted_res_block( + x, filters=24, alpha=alpha, stride=2, expansion=6, block_id=1 + ) + x = _inverted_res_block( + x, filters=24, alpha=alpha, stride=1, expansion=6, block_id=2 + ) + + x = _inverted_res_block( + x, filters=32, alpha=alpha, stride=2, expansion=6, block_id=3 + ) + x = _inverted_res_block( + x, filters=32, alpha=alpha, stride=1, expansion=6, block_id=4 + ) + x = _inverted_res_block( + x, filters=32, alpha=alpha, stride=1, expansion=6, block_id=5 + ) + + x = _inverted_res_block( + x, filters=64, alpha=alpha, stride=2, expansion=6, block_id=6 + ) + x = _inverted_res_block( + x, filters=64, alpha=alpha, stride=1, expansion=6, block_id=7 + ) + x = _inverted_res_block( + x, filters=64, alpha=alpha, stride=1, expansion=6, block_id=8 + ) + x = _inverted_res_block( + x, filters=64, alpha=alpha, stride=1, expansion=6, block_id=9 + ) + + x = _inverted_res_block( + x, filters=96, alpha=alpha, stride=1, expansion=6, block_id=10 + ) + x = _inverted_res_block( + x, filters=96, alpha=alpha, stride=1, expansion=6, block_id=11 + ) + x = _inverted_res_block( + x, filters=96, alpha=alpha, stride=1, expansion=6, block_id=12 + ) + + x = _inverted_res_block( + x, filters=160, alpha=alpha, stride=2, expansion=6, block_id=13 + ) + x = _inverted_res_block( + x, filters=160, alpha=alpha, stride=1, expansion=6, block_id=14 + ) + x = _inverted_res_block( + x, filters=160, alpha=alpha, stride=1, expansion=6, block_id=15 + ) + + x = _inverted_res_block( + x, filters=320, alpha=alpha, stride=1, expansion=6, block_id=16 + ) + + # no alpha applied to last conv as stated in the paper: + # if the width multiplier is greater than 1 we increase the number of output + # channels. + if alpha > 1.0: + last_block_filters = _make_divisible(1280 * alpha, 8) + else: + last_block_filters = 1280 + + x = layers.Conv2D( + last_block_filters, kernel_size=1, use_bias=False, name="Conv_1" + )(x) + x = layers.BatchNormalization( + axis=channel_axis, epsilon=1e-3, momentum=0.999, name="Conv_1_bn" + )(x) + x = layers.ReLU(6.0, name="out_relu")(x) + + if include_top: + x = layers.GlobalAveragePooling2D()(x) + imagenet_utils.validate_activation(classifier_activation, weights) + x = layers.Dense( + classes, activation=classifier_activation, name="predictions" + )(x) + + else: + if pooling == "avg": + x = layers.GlobalAveragePooling2D()(x) + elif pooling == "max": + x = layers.GlobalMaxPooling2D()(x) + + # Ensure that the model takes into account any potential predecessors of + # `input_tensor`. + if input_tensor is not None: + inputs = operation_utils.get_source_inputs(input_tensor) + else: + inputs = img_input + + # Create model. + if name is None: + name = f"mobilenetv2_{alpha:0.2f}_{rows}" + model = Functional(inputs, x, name=name) + + # Load weights. + if weights == "imagenet": + if include_top: + model_name = ( + "mobilenet_v2_weights_tf_dim_ordering_tf_kernels" + f"_{float(alpha)}_{rows}.h5" + ) + weight_path = BASE_WEIGHT_PATH + model_name + weights_path = file_utils.get_file( + model_name, weight_path, cache_subdir="models" + ) + else: + model_name = ( + "mobilenet_v2_weights_tf_dim_ordering_tf_kernels_" + f"{float(alpha)}_{rows}_no_top.h5" + ) + weight_path = BASE_WEIGHT_PATH + model_name + weights_path = file_utils.get_file( + model_name, weight_path, cache_subdir="models" + ) + model.load_weights(weights_path) + elif weights is not None: + model.load_weights(weights) + + return model + + +def _inverted_res_block(inputs, expansion, stride, alpha, filters, block_id): + """Inverted ResNet block.""" + channel_axis = 1 if backend.image_data_format() == "channels_first" else -1 + + in_channels = inputs.shape[channel_axis] + pointwise_conv_filters = int(filters * alpha) + # Ensure the number of filters on the last 1x1 convolution is divisible by + # 8. + pointwise_filters = _make_divisible(pointwise_conv_filters, 8) + x = inputs + prefix = f"block_{block_id}_" + + if block_id: + # Expand with a pointwise 1x1 convolution. + x = layers.Conv2D( + expansion * in_channels, + kernel_size=1, + padding="same", + use_bias=False, + activation=None, + name=f"{prefix}expand", + )(x) + x = layers.BatchNormalization( + axis=channel_axis, + epsilon=1e-3, + momentum=0.999, + name=f"{prefix}expand_BN", + )(x) + x = layers.ReLU(6.0, name=f"{prefix}expand_relu")(x) + else: + prefix = "expanded_conv_" + + # Depthwise 3x3 convolution. + if stride == 2: + x = layers.ZeroPadding2D( + padding=imagenet_utils.correct_pad(x, 3), name=f"{prefix}pad" + )(x) + x = layers.DepthwiseConv2D( + kernel_size=3, + strides=stride, + activation=None, + use_bias=False, + padding="same" if stride == 1 else "valid", + name=f"{prefix}depthwise", + )(x) + x = layers.BatchNormalization( + axis=channel_axis, + epsilon=1e-3, + momentum=0.999, + name=f"{prefix}depthwise_BN", + )(x) + + x = layers.ReLU(6.0, name=f"{prefix}depthwise_relu")(x) + + # Project with a pointwise 1x1 convolution. + x = layers.Conv2D( + pointwise_filters, + kernel_size=1, + padding="same", + use_bias=False, + activation=None, + name=f"{prefix}project", + )(x) + x = layers.BatchNormalization( + axis=channel_axis, + epsilon=1e-3, + momentum=0.999, + name=f"{prefix}project_BN", + )(x) + + if in_channels == pointwise_filters and stride == 1: + return layers.Add(name=f"{prefix}add")([inputs, x]) + return x + + +def _make_divisible(v, divisor, min_value=None): + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +@keras_export("keras.applications.mobilenet_v2.preprocess_input") +def preprocess_input(x, data_format=None): + return imagenet_utils.preprocess_input( + x, data_format=data_format, mode="tf" + ) + + +@keras_export("keras.applications.mobilenet_v2.decode_predictions") +def decode_predictions(preds, top=5): + return imagenet_utils.decode_predictions(preds, top=top) + + +preprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format( + mode="", + ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF, + error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC, +) +decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__ diff --git a/keras/src/applications/mobilenet_v3.py b/keras/src/applications/mobilenet_v3.py new file mode 100644 index 000000000000..8496e9b257f3 --- /dev/null +++ b/keras/src/applications/mobilenet_v3.py @@ -0,0 +1,688 @@ +import warnings + +from keras.src import backend +from keras.src import layers +from keras.src.api_export import keras_export +from keras.src.applications import imagenet_utils +from keras.src.models import Functional +from keras.src.ops import operation_utils +from keras.src.utils import file_utils + +BASE_WEIGHT_PATH = ( + "https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v3/" +) +WEIGHTS_HASHES = { + "large_224_0.75_float": ( + "765b44a33ad4005b3ac83185abf1d0eb", + "40af19a13ebea4e2ee0c676887f69a2e", + ), + "large_224_1.0_float": ( + "59e551e166be033d707958cf9e29a6a7", + "07fb09a5933dd0c8eaafa16978110389", + ), + "large_minimalistic_224_1.0_float": ( + "675e7b876c45c57e9e63e6d90a36599c", + "ec5221f64a2f6d1ef965a614bdae7973", + ), + "small_224_0.75_float": ( + "cb65d4e5be93758266aa0a7f2c6708b7", + "ebdb5cc8e0b497cd13a7c275d475c819", + ), + "small_224_1.0_float": ( + "8768d4c2e7dee89b9d02b2d03d65d862", + "d3e8ec802a04aa4fc771ee12a9a9b836", + ), + "small_minimalistic_224_1.0_float": ( + "99cd97fb2fcdad2bf028eb838de69e37", + "cde8136e733e811080d9fcd8a252f7e4", + ), +} + + +BASE_DOCSTRING = """Instantiates the {name} architecture. + +Reference: +- [Searching for MobileNetV3]( + https://arxiv.org/pdf/1905.02244.pdf) (ICCV 2019) + +The following table describes the performance of MobileNets v3: +------------------------------------------------------------------------ +MACs stands for Multiply Adds + +|Classification Checkpoint|MACs(M)|Parameters(M)|Top1 Accuracy|Pixel1 CPU(ms)| +|---|---|---|---|---| +| mobilenet_v3_large_1.0_224 | 217 | 5.4 | 75.6 | 51.2 | +| mobilenet_v3_large_0.75_224 | 155 | 4.0 | 73.3 | 39.8 | +| mobilenet_v3_large_minimalistic_1.0_224 | 209 | 3.9 | 72.3 | 44.1 | +| mobilenet_v3_small_1.0_224 | 66 | 2.9 | 68.1 | 15.8 | +| mobilenet_v3_small_0.75_224 | 44 | 2.4 | 65.4 | 12.8 | +| mobilenet_v3_small_minimalistic_1.0_224 | 65 | 2.0 | 61.9 | 12.2 | + +For image classification use cases, see +[this page for detailed examples]( +https://keras.io/api/applications/#usage-examples-for-image-classification-models). + +For transfer learning use cases, make sure to read the +[guide to transfer learning & fine-tuning]( +https://keras.io/guides/transfer_learning/). + +Note: each Keras Application expects a specific kind of input preprocessing. +For MobileNetV3, by default input preprocessing is included as a part of the +model (as a `Rescaling` layer), and thus +`keras.applications.mobilenet_v3.preprocess_input` is actually a +pass-through function. In this use case, MobileNetV3 models expect their +inputs to be float tensors of pixels with values in the `[0-255]` range. +At the same time, preprocessing as a part of the model (i.e. `Rescaling` +layer) can be disabled by setting `include_preprocessing` argument to `False`. +With preprocessing disabled MobileNetV3 models expect their inputs to be float +tensors of pixels with values in the `[-1, 1]` range. + +Args: + input_shape: Optional shape tuple, to be specified if you would + like to use a model with an input image resolution that is not + `(224, 224, 3)`. + It should have exactly 3 inputs channels. + You can also omit this option if you would like + to infer input_shape from an input_tensor. + If you choose to include both input_tensor and input_shape then + input_shape will be used if they match, if the shapes + do not match then we will throw an error. + E.g. `(160, 160, 3)` would be one valid value. + alpha: controls the width of the network. This is known as the + depth multiplier in the MobileNetV3 paper, but the name is kept for + consistency with MobileNetV1 in Keras. + When `weights` is `imagenet`, `alpha` can be one of `0.75` or `1.0` + for non-minimalistic models, and `1.0` for minimalistic models. + - If `alpha < 1.0`, proportionally decreases the number + of filters in each layer. + - If `alpha > 1.0`, proportionally increases the number + of filters in each layer. + - If `alpha == 1`, default number of filters from the paper + are used at each layer. + minimalistic: In addition to large and small models this module also + contains so-called minimalistic models, these models have the same + per-layer dimensions characteristic as MobilenetV3 however, they don't + utilize any of the advanced blocks (squeeze-and-excite units, + hard-swish, and 5x5 convolutions). + While these models are less efficient on CPU, they + are much more performant on GPU/DSP. + include_top: Boolean, whether to include the fully-connected + layer at the top of the network. Defaults to `True`. + weights: String, one of `None` (random initialization), + `"imagenet"` (pre-training on ImageNet), + or the path to the weights file to be loaded. + input_tensor: Optional Keras tensor (i.e. output of + `layers.Input()`) + to use as image input for the model. + pooling: String, optional pooling mode for feature extraction + when `include_top` is `False`. + - `None` means that the output of the model + will be the 4D tensor output of the + last convolutional block. + - `avg` means that global average pooling + will be applied to the output of the + last convolutional block, and thus + the output of the model will be a + 2D tensor. + - `max` means that global max pooling will + be applied. + classes: Integer, optional number of classes to classify images + into, only to be specified if `include_top` is `True`, and + if no `weights` argument is specified. + dropout_rate: fraction of the input units to drop on the last layer. + classifier_activation: A `str` or callable. The activation function to use + on the "top" layer. Ignored unless `include_top=True`. Set + `classifier_activation=None` to return the logits of the "top" layer. + When loading pretrained weights, `classifier_activation` can only + be `None` or `"softmax"`. + include_preprocessing: Boolean, whether to include the preprocessing + layer (`Rescaling`) at the bottom of the network. Defaults to `True`. + name: String, the name of the model. + +Call arguments: + inputs: A floating point `numpy.array` or backend-native tensor, + 4D with 3 color channels, with values in the range `[0, 255]` + if `include_preprocessing` is `True` and in the range `[-1, 1]` + otherwise. + +Returns: + A model instance. +""" + + +def MobileNetV3( + stack_fn, + last_point_ch, + input_shape=None, + alpha=1.0, + model_type="large", + minimalistic=False, + include_top=True, + weights="imagenet", + input_tensor=None, + classes=1000, + pooling=None, + dropout_rate=0.2, + classifier_activation="softmax", + include_preprocessing=True, + name=None, +): + if not (weights in {"imagenet", None} or file_utils.exists(weights)): + raise ValueError( + "The `weights` argument should be either " + "`None` (random initialization), `imagenet` " + "(pre-training on ImageNet), " + "or the path to the weights file to be loaded. " + f"Received weights={weights}" + ) + + if weights == "imagenet" and include_top and classes != 1000: + raise ValueError( + 'If using `weights="imagenet"` with `include_top` ' + "as true, `classes` should be 1000. " + f"Received classes={classes}" + ) + + # Determine proper input shape and default size. + # If both input_shape and input_tensor are used, they should match + if input_shape is not None and input_tensor is not None: + try: + is_input_t_tensor = backend.is_keras_tensor(input_tensor) + except ValueError: + try: + is_input_t_tensor = backend.is_keras_tensor( + operation_utils.get_source_inputs(input_tensor) + ) + except ValueError: + raise ValueError( + "input_tensor: ", + input_tensor, + "is not type input_tensor. " + f"Received type(input_tensor)={type(input_tensor)}", + ) + if is_input_t_tensor: + if backend.image_data_format() == "channels_first": + if input_tensor.shape[1] != input_shape[1]: + raise ValueError( + "When backend.image_data_format()=channels_first, " + "input_shape[1] must equal " + "input_tensor.shape[1]. Received " + f"input_shape={input_shape}, " + "input_tensor.shape=" + f"{input_tensor.shape}" + ) + else: + if input_tensor.shape[2] != input_shape[1]: + raise ValueError( + "input_shape[1] must equal " + "input_tensor.shape[2]. Received " + f"input_shape={input_shape}, " + "input_tensor.shape=" + f"{input_tensor.shape}" + ) + else: + raise ValueError( + "input_tensor specified: ", + input_tensor, + "is not a keras tensor", + ) + + # If input_shape is None, infer shape from input_tensor + if input_shape is None and input_tensor is not None: + try: + backend.is_keras_tensor(input_tensor) + except ValueError: + raise ValueError( + "input_tensor: ", + input_tensor, + "is type: ", + type(input_tensor), + "which is not a valid type", + ) + + if backend.is_keras_tensor(input_tensor): + if backend.image_data_format() == "channels_first": + rows = input_tensor.shape[2] + cols = input_tensor.shape[3] + input_shape = (3, cols, rows) + else: + rows = input_tensor.shape[1] + cols = input_tensor.shape[2] + input_shape = (cols, rows, 3) + # If input_shape is None and input_tensor is None using standard shape + if input_shape is None and input_tensor is None: + if backend.image_data_format() == "channels_last": + input_shape = (None, None, 3) + else: + input_shape = (3, None, None) + + if backend.image_data_format() == "channels_last": + row_axis, col_axis = (0, 1) + else: + row_axis, col_axis = (1, 2) + rows = input_shape[row_axis] + cols = input_shape[col_axis] + if rows and cols and (rows < 32 or cols < 32): + raise ValueError( + "Input size must be at least 32x32; Received `input_shape=" + f"{input_shape}`" + ) + if weights == "imagenet": + if ( + not minimalistic + and alpha not in [0.75, 1.0] + or minimalistic + and alpha != 1.0 + ): + raise ValueError( + "If imagenet weights are being loaded, " + "alpha can be one of `0.75`, `1.0` for non minimalistic " + "or `1.0` for minimalistic only." + ) + + if rows != cols or rows != 224: + warnings.warn( + "`input_shape` is undefined or non-square, " + "or `rows` is not 224. " + "Weights for input shape (224, 224) will be " + "loaded as the default.", + stacklevel=2, + ) + + if input_tensor is None: + img_input = layers.Input(shape=input_shape) + else: + if not backend.is_keras_tensor(input_tensor): + img_input = layers.Input(tensor=input_tensor, shape=input_shape) + else: + img_input = input_tensor + + channel_axis = 1 if backend.image_data_format() == "channels_first" else -1 + + if minimalistic: + kernel = 3 + activation = relu + se_ratio = None + else: + kernel = 5 + activation = hard_swish + se_ratio = 0.25 + + x = img_input + if include_preprocessing: + x = layers.Rescaling(scale=1.0 / 127.5, offset=-1.0)(x) + x = layers.Conv2D( + 16, + kernel_size=3, + strides=(2, 2), + padding="same", + use_bias=False, + name="conv", + )(x) + x = layers.BatchNormalization( + axis=channel_axis, epsilon=1e-3, momentum=0.999, name="conv_bn" + )(x) + x = activation(x) + + x = stack_fn(x, kernel, activation, se_ratio) + + last_conv_ch = _depth(x.shape[channel_axis] * 6) + + # if the width multiplier is greater than 1 we + # increase the number of output channels + if alpha > 1.0: + last_point_ch = _depth(last_point_ch * alpha) + x = layers.Conv2D( + last_conv_ch, + kernel_size=1, + padding="same", + use_bias=False, + name="conv_1", + )(x) + x = layers.BatchNormalization( + axis=channel_axis, epsilon=1e-3, momentum=0.999, name="conv_1_bn" + )(x) + x = activation(x) + if include_top: + x = layers.GlobalAveragePooling2D(keepdims=True)(x) + x = layers.Conv2D( + last_point_ch, + kernel_size=1, + padding="same", + use_bias=True, + name="conv_2", + )(x) + x = activation(x) + + if dropout_rate > 0: + x = layers.Dropout(dropout_rate)(x) + x = layers.Conv2D( + classes, kernel_size=1, padding="same", name="logits" + )(x) + x = layers.Flatten()(x) + imagenet_utils.validate_activation(classifier_activation, weights) + x = layers.Activation( + activation=classifier_activation, name="predictions" + )(x) + else: + if pooling == "avg": + x = layers.GlobalAveragePooling2D(name="avg_pool")(x) + elif pooling == "max": + x = layers.GlobalMaxPooling2D(name="max_pool")(x) + # Ensure that the model takes into account + # any potential predecessors of `input_tensor`. + if input_tensor is not None: + inputs = operation_utils.get_source_inputs(input_tensor) + else: + inputs = img_input + + # Create model. + model = Functional(inputs, x, name=name) + + # Load weights. + if weights == "imagenet": + model_name = "{}{}_224_{}_float".format( + model_type, "_minimalistic" if minimalistic else "", str(alpha) + ) + if include_top: + file_name = f"weights_mobilenet_v3_{model_name}.h5" + file_hash = WEIGHTS_HASHES[model_name][0] + else: + file_name = f"weights_mobilenet_v3_{model_name}_no_top_v2.h5" + file_hash = WEIGHTS_HASHES[model_name][1] + weights_path = file_utils.get_file( + file_name, + BASE_WEIGHT_PATH + file_name, + cache_subdir="models", + file_hash=file_hash, + ) + model.load_weights(weights_path) + elif weights is not None: + model.load_weights(weights) + + return model + + +@keras_export("keras.applications.MobileNetV3Small") +def MobileNetV3Small( + input_shape=None, + alpha=1.0, + minimalistic=False, + include_top=True, + weights="imagenet", + input_tensor=None, + classes=1000, + pooling=None, + dropout_rate=0.2, + classifier_activation="softmax", + include_preprocessing=True, + name="MobileNetV3Small", +): + def stack_fn(x, kernel, activation, se_ratio): + def depth(d): + return _depth(d * alpha) + + x = _inverted_res_block(x, 1, depth(16), 3, 2, se_ratio, relu, 0) + x = _inverted_res_block(x, 72.0 / 16, depth(24), 3, 2, None, relu, 1) + x = _inverted_res_block(x, 88.0 / 24, depth(24), 3, 1, None, relu, 2) + x = _inverted_res_block( + x, 4, depth(40), kernel, 2, se_ratio, activation, 3 + ) + x = _inverted_res_block( + x, 6, depth(40), kernel, 1, se_ratio, activation, 4 + ) + x = _inverted_res_block( + x, 6, depth(40), kernel, 1, se_ratio, activation, 5 + ) + x = _inverted_res_block( + x, 3, depth(48), kernel, 1, se_ratio, activation, 6 + ) + x = _inverted_res_block( + x, 3, depth(48), kernel, 1, se_ratio, activation, 7 + ) + x = _inverted_res_block( + x, 6, depth(96), kernel, 2, se_ratio, activation, 8 + ) + x = _inverted_res_block( + x, 6, depth(96), kernel, 1, se_ratio, activation, 9 + ) + x = _inverted_res_block( + x, 6, depth(96), kernel, 1, se_ratio, activation, 10 + ) + return x + + return MobileNetV3( + stack_fn, + 1024, + input_shape, + alpha, + "small", + minimalistic, + include_top, + weights, + input_tensor, + classes, + pooling, + dropout_rate, + classifier_activation, + include_preprocessing, + name=name, + ) + + +@keras_export("keras.applications.MobileNetV3Large") +def MobileNetV3Large( + input_shape=None, + alpha=1.0, + minimalistic=False, + include_top=True, + weights="imagenet", + input_tensor=None, + classes=1000, + pooling=None, + dropout_rate=0.2, + classifier_activation="softmax", + include_preprocessing=True, + name="MobileNetV3Large", +): + def stack_fn(x, kernel, activation, se_ratio): + def depth(d): + return _depth(d * alpha) + + x = _inverted_res_block(x, 1, depth(16), 3, 1, None, relu, 0) + x = _inverted_res_block(x, 4, depth(24), 3, 2, None, relu, 1) + x = _inverted_res_block(x, 3, depth(24), 3, 1, None, relu, 2) + x = _inverted_res_block(x, 3, depth(40), kernel, 2, se_ratio, relu, 3) + x = _inverted_res_block(x, 3, depth(40), kernel, 1, se_ratio, relu, 4) + x = _inverted_res_block(x, 3, depth(40), kernel, 1, se_ratio, relu, 5) + x = _inverted_res_block(x, 6, depth(80), 3, 2, None, activation, 6) + x = _inverted_res_block(x, 2.5, depth(80), 3, 1, None, activation, 7) + x = _inverted_res_block(x, 2.3, depth(80), 3, 1, None, activation, 8) + x = _inverted_res_block(x, 2.3, depth(80), 3, 1, None, activation, 9) + x = _inverted_res_block( + x, 6, depth(112), 3, 1, se_ratio, activation, 10 + ) + x = _inverted_res_block( + x, 6, depth(112), 3, 1, se_ratio, activation, 11 + ) + x = _inverted_res_block( + x, 6, depth(160), kernel, 2, se_ratio, activation, 12 + ) + x = _inverted_res_block( + x, 6, depth(160), kernel, 1, se_ratio, activation, 13 + ) + x = _inverted_res_block( + x, 6, depth(160), kernel, 1, se_ratio, activation, 14 + ) + return x + + return MobileNetV3( + stack_fn, + 1280, + input_shape, + alpha, + "large", + minimalistic, + include_top, + weights, + input_tensor, + classes, + pooling, + dropout_rate, + classifier_activation, + include_preprocessing, + name=name, + ) + + +MobileNetV3Small.__doc__ = BASE_DOCSTRING.format(name="MobileNetV3Small") +MobileNetV3Large.__doc__ = BASE_DOCSTRING.format(name="MobileNetV3Large") + + +def relu(x): + return layers.ReLU()(x) + + +def hard_sigmoid(x): + return layers.ReLU(6.0)(x + 3.0) * (1.0 / 6.0) + + +def hard_swish(x): + return layers.Activation("hard_swish")(x) + + +# This function is taken from the original tf repo. +# It ensures that all layers have a channel number that is divisible by 8 +# It can be seen here: +# https://github.com/tensorflow/models/blob/master/research/ +# slim/nets/mobilenet/mobilenet.py + + +def _depth(v, divisor=8, min_value=None): + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +def _se_block(inputs, filters, se_ratio, prefix): + x = layers.GlobalAveragePooling2D( + keepdims=True, name=f"{prefix}squeeze_excite_avg_pool" + )(inputs) + x = layers.Conv2D( + _depth(filters * se_ratio), + kernel_size=1, + padding="same", + name=f"{prefix}squeeze_excite_conv", + )(x) + x = layers.ReLU(name=f"{prefix}squeeze_excite_relu")(x) + x = layers.Conv2D( + filters, + kernel_size=1, + padding="same", + name=f"{prefix}squeeze_excite_conv_1", + )(x) + x = hard_sigmoid(x) + x = layers.Multiply(name=f"{prefix}squeeze_excite_mul")([inputs, x]) + return x + + +def _inverted_res_block( + x, expansion, filters, kernel_size, stride, se_ratio, activation, block_id +): + channel_axis = 1 if backend.image_data_format() == "channels_first" else -1 + shortcut = x + prefix = "expanded_conv_" + infilters = x.shape[channel_axis] + if block_id: + # Expand + prefix = f"expanded_conv_{block_id}_" + x = layers.Conv2D( + _depth(infilters * expansion), + kernel_size=1, + padding="same", + use_bias=False, + name=f"{prefix}expand", + )(x) + x = layers.BatchNormalization( + axis=channel_axis, + epsilon=1e-3, + momentum=0.999, + name=f"{prefix}expand_bn", + )(x) + x = activation(x) + + if stride == 2: + x = layers.ZeroPadding2D( + padding=imagenet_utils.correct_pad(x, kernel_size), + name=f"{prefix}depthwise_pad", + )(x) + x = layers.DepthwiseConv2D( + kernel_size, + strides=stride, + padding="same" if stride == 1 else "valid", + use_bias=False, + name=f"{prefix}depthwise", + )(x) + x = layers.BatchNormalization( + axis=channel_axis, + epsilon=1e-3, + momentum=0.999, + name=f"{prefix}depthwise_bn", + )(x) + x = activation(x) + + if se_ratio: + x = _se_block(x, _depth(infilters * expansion), se_ratio, prefix) + + x = layers.Conv2D( + filters, + kernel_size=1, + padding="same", + use_bias=False, + name=f"{prefix}project", + )(x) + x = layers.BatchNormalization( + axis=channel_axis, + epsilon=1e-3, + momentum=0.999, + name=f"{prefix}project_bn", + )(x) + + if stride == 1 and infilters == filters: + x = layers.Add(name=f"{prefix}add")([shortcut, x]) + return x + + +@keras_export("keras.applications.mobilenet_v3.preprocess_input") +def preprocess_input(x, data_format=None): + """A placeholder method for backward compatibility. + + The preprocessing logic has been included in the mobilenet_v3 model + implementation. Users are no longer required to call this method to + normalize the input data. This method does nothing and only kept as a + placeholder to align the API surface between old and new version of model. + + Args: + x: A floating point `numpy.array` or a tensor. + data_format: Optional data format of the image tensor/array. + `None` means the global setting + `keras.config.image_data_format()` is used + (unless you changed it, it uses `"channels_last"`). + Defaults to `None`. + + Returns: + Unchanged `numpy.array` or tensor. + """ + return x + + +@keras_export("keras.applications.mobilenet_v3.decode_predictions") +def decode_predictions(preds, top=5): + return imagenet_utils.decode_predictions(preds, top=top) + + +decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__ diff --git a/keras/src/applications/nasnet.py b/keras/src/applications/nasnet.py new file mode 100644 index 000000000000..e0f55da4f467 --- /dev/null +++ b/keras/src/applications/nasnet.py @@ -0,0 +1,869 @@ +import warnings + +from keras.src import backend +from keras.src import layers +from keras.src.api_export import keras_export +from keras.src.applications import imagenet_utils +from keras.src.models import Functional +from keras.src.ops import operation_utils +from keras.src.utils import file_utils + +BASE_WEIGHTS_PATH = ( + "https://storage.googleapis.com/tensorflow/keras-applications/nasnet/" +) +NASNET_MOBILE_WEIGHT_PATH = f"{BASE_WEIGHTS_PATH}NASNet-mobile.h5" +NASNET_MOBILE_WEIGHT_PATH_NO_TOP = f"{BASE_WEIGHTS_PATH}NASNet-mobile-no-top.h5" +NASNET_LARGE_WEIGHT_PATH = f"{BASE_WEIGHTS_PATH}NASNet-large.h5" +NASNET_LARGE_WEIGHT_PATH_NO_TOP = f"{BASE_WEIGHTS_PATH}NASNet-large-no-top.h5" + + +def NASNet( + input_shape=None, + penultimate_filters=4032, + num_blocks=6, + stem_block_filters=96, + skip_reduction=True, + filter_multiplier=2, + include_top=True, + weights="imagenet", + input_tensor=None, + pooling=None, + classes=1000, + default_size=None, + classifier_activation="softmax", + name="NASNet", +): + """Instantiates a NASNet model. + + Reference: + - [Learning Transferable Architectures for Scalable Image Recognition]( + https://arxiv.org/abs/1707.07012) (CVPR 2018) + + For image classification use cases, see + [this page for detailed examples]( + https://keras.io/api/applications/#usage-examples-for-image-classification-models). + + For transfer learning use cases, make sure to read the + [guide to transfer learning & fine-tuning]( + https://keras.io/guides/transfer_learning/). + + Note: each Keras Application expects a specific kind of input preprocessing. + For NasNet, call `keras.applications.nasnet.preprocess_input` + on your inputs before passing them to the model. + `nasnet.preprocess_input` will scale input pixels between -1 and 1. + + Args: + input_shape: Optional shape tuple, the input shape + is by default `(331, 331, 3)` for NASNetLarge and + `(224, 224, 3)` for NASNetMobile. + It should have exactly 3 input channels, + and width and height should be no smaller than 32. + E.g. `(224, 224, 3)` would be one valid value. + penultimate_filters: Number of filters in the penultimate layer. + NASNet models use the notation `NASNet (N @ P)`, where: + - N is the number of blocks + - P is the number of penultimate filters + num_blocks: Number of repeated blocks of the NASNet model. + NASNet models use the notation `NASNet (N @ P)`, where: + - N is the number of blocks + - P is the number of penultimate filters + stem_block_filters: Number of filters in the initial stem block + skip_reduction: Whether to skip the reduction step at the tail + end of the network. + filter_multiplier: Controls the width of the network. + - If `filter_multiplier` < 1.0, proportionally decreases the number + of filters in each layer. + - If `filter_multiplier` > 1.0, proportionally increases the number + of filters in each layer. + - If `filter_multiplier` = 1, default number of filters from the + paper are used at each layer. + include_top: Whether to include the fully-connected + layer at the top of the network. + weights: `None` (random initialization) or + `imagenet` (ImageNet weights) + input_tensor: Optional Keras tensor (i.e. output of + `layers.Input()`) + to use as image input for the model. + pooling: Optional pooling mode for feature extraction + when `include_top` is `False`. + - `None` means that the output of the model + will be the 4D tensor output of the + last convolutional block. + - `avg` means that global average pooling + will be applied to the output of the + last convolutional block, and thus + the output of the model will be a + 2D tensor. + - `max` means that global max pooling will + be applied. + classes: Optional number of classes to classify images + into, only to be specified if `include_top` is `True`, and + if no `weights` argument is specified. + default_size: Specifies the default image size of the model + classifier_activation: A `str` or callable. + The activation function to use on the "top" layer. + Ignored unless `include_top=True`. + Set `classifier_activation=None` to return the logits + of the "top" layer. When loading pretrained weights, + `classifier_activation` can only be `None` or `"softmax"`. + name: The name of the model (string). + + Returns: + A model instance. + """ + if backend.image_data_format() == "channels_first": + raise ValueError( + "NASNet does not support the `channels_first` image data " + "format. Switch to `channels_last` by editing your local " + "config file at ~/.keras/keras.json" + ) + if not (weights in {"imagenet", None} or file_utils.exists(weights)): + raise ValueError( + "The `weights` argument should be either " + "`None` (random initialization), `imagenet` " + "(pre-training on ImageNet), " + "or the path to the weights file to be loaded." + ) + + if weights == "imagenet" and include_top and classes != 1000: + raise ValueError( + 'If using `weights` as `"imagenet"` with `include_top` ' + "as true, `classes` should be 1000" + ) + + if ( + isinstance(input_shape, tuple) + and None in input_shape + and weights == "imagenet" + ): + raise ValueError( + "When specifying the input shape of a NASNet and loading " + "`ImageNet` weights, the input_shape argument must be static" + f" (no None entries). Got: `input_shape={input_shape}`." + ) + + if default_size is None: + default_size = 331 + + # Determine proper input shape and default size. + input_shape = imagenet_utils.obtain_input_shape( + input_shape, + default_size=default_size, + min_size=32, + data_format=backend.image_data_format(), + require_flatten=include_top, + weights=weights, + ) + + if backend.image_data_format() != "channels_last": + warnings.warn( + "The NASNet family of models is only available " + 'for the input data format "channels_last" ' + "(width, height, channels). " + "However your settings specify the default " + 'data format "channels_first" (channels, width, height).' + ' You should set `image_data_format="channels_last"` ' + "in your Keras config located at ~/.keras/keras.json. " + "The model being returned right now will expect inputs " + 'to follow the "channels_last" data format.', + stacklevel=2, + ) + + if input_tensor is None: + img_input = layers.Input(shape=input_shape) + else: + if not backend.is_keras_tensor(input_tensor): + img_input = layers.Input(tensor=input_tensor, shape=input_shape) + else: + img_input = input_tensor + + if penultimate_filters % (24 * (filter_multiplier**2)) != 0: + raise ValueError( + "For NASNet-A models, the `penultimate_filters` must be a multiple " + "of 24 * (`filter_multiplier` ** 2). " + f"Current value: {penultimate_filters}" + ) + + channel_dim = 1 if backend.image_data_format() == "channels_first" else -1 + filters = penultimate_filters // 24 + + x = layers.Conv2D( + stem_block_filters, + (3, 3), + strides=(2, 2), + padding="valid", + use_bias=False, + name="stem_conv1", + kernel_initializer="he_normal", + )(img_input) + + x = layers.BatchNormalization( + axis=channel_dim, momentum=0.9997, epsilon=1e-3, name="stem_bn1" + )(x) + + p = None + x, p = _reduction_a_cell( + x, p, filters // (filter_multiplier**2), block_id="stem_1" + ) + x, p = _reduction_a_cell( + x, p, filters // filter_multiplier, block_id="stem_2" + ) + + for i in range(num_blocks): + x, p = _normal_a_cell(x, p, filters, block_id=f"{i}") + + x, p0 = _reduction_a_cell( + x, p, filters * filter_multiplier, block_id=f"reduce_{num_blocks}" + ) + + p = p0 if not skip_reduction else p + + for i in range(num_blocks): + x, p = _normal_a_cell( + x, + p, + filters * filter_multiplier, + block_id=f"{num_blocks + i + 1}", + ) + + x, p0 = _reduction_a_cell( + x, + p, + filters * filter_multiplier**2, + block_id=f"reduce_{2 * num_blocks}", + ) + + p = p0 if not skip_reduction else p + + for i in range(num_blocks): + x, p = _normal_a_cell( + x, + p, + filters * filter_multiplier**2, + block_id=f"{2 * num_blocks + i + 1}", + ) + + x = layers.Activation("relu")(x) + + if include_top: + x = layers.GlobalAveragePooling2D()(x) + imagenet_utils.validate_activation(classifier_activation, weights) + x = layers.Dense( + classes, activation=classifier_activation, name="predictions" + )(x) + else: + if pooling == "avg": + x = layers.GlobalAveragePooling2D()(x) + elif pooling == "max": + x = layers.GlobalMaxPooling2D()(x) + + # Ensure that the model takes into account + # any potential predecessors of `input_tensor`. + if input_tensor is not None: + inputs = operation_utils.get_source_inputs(input_tensor) + else: + inputs = img_input + + model = Functional(inputs, x, name=name) + + # Load weights. + if weights == "imagenet": + if default_size == 224: # mobile version + if include_top: + weights_path = file_utils.get_file( + "nasnet_mobile.h5", + NASNET_MOBILE_WEIGHT_PATH, + cache_subdir="models", + file_hash="020fb642bf7360b370c678b08e0adf61", + ) + else: + weights_path = file_utils.get_file( + "nasnet_mobile_no_top.h5", + NASNET_MOBILE_WEIGHT_PATH_NO_TOP, + cache_subdir="models", + file_hash="1ed92395b5b598bdda52abe5c0dbfd63", + ) + model.load_weights(weights_path) + elif default_size == 331: # large version + if include_top: + weights_path = file_utils.get_file( + "nasnet_large.h5", + NASNET_LARGE_WEIGHT_PATH, + cache_subdir="models", + file_hash="11577c9a518f0070763c2b964a382f17", + ) + else: + weights_path = file_utils.get_file( + "nasnet_large_no_top.h5", + NASNET_LARGE_WEIGHT_PATH_NO_TOP, + cache_subdir="models", + file_hash="d81d89dc07e6e56530c4e77faddd61b5", + ) + model.load_weights(weights_path) + else: + raise ValueError( + "ImageNet weights can only be loaded with NASNetLarge" + " or NASNetMobile" + ) + elif weights is not None: + model.load_weights(weights) + + return model + + +@keras_export( + [ + "keras.applications.nasnet.NASNetMobile", + "keras.applications.NASNetMobile", + ] +) +def NASNetMobile( + input_shape=None, + include_top=True, + weights="imagenet", + input_tensor=None, + pooling=None, + classes=1000, + classifier_activation="softmax", + name="nasnet_mobile", +): + """Instantiates a Mobile NASNet model in ImageNet mode. + + Reference: + - [Learning Transferable Architectures for Scalable Image Recognition]( + https://arxiv.org/abs/1707.07012) (CVPR 2018) + + Optionally loads weights pre-trained on ImageNet. + Note that the data format convention used by the model is + the one specified in your Keras config at `~/.keras/keras.json`. + + Note: each Keras Application expects a specific kind of input preprocessing. + For NASNet, call `keras.applications.nasnet.preprocess_input` on your + inputs before passing them to the model. + + Args: + input_shape: Optional shape tuple, only to be specified + if `include_top` is False (otherwise the input shape + has to be `(224, 224, 3)` for NASNetMobile + It should have exactly 3 inputs channels, + and width and height should be no smaller than 32. + E.g. `(224, 224, 3)` would be one valid value. + include_top: Whether to include the fully-connected + layer at the top of the network. + weights: `None` (random initialization) or + `imagenet` (ImageNet weights). For loading `imagenet` weights, + `input_shape` should be (224, 224, 3) + input_tensor: Optional Keras tensor (i.e. output of + `layers.Input()`) + to use as image input for the model. + pooling: Optional pooling mode for feature extraction + when `include_top` is `False`. + - `None` means that the output of the model + will be the 4D tensor output of the + last convolutional layer. + - `avg` means that global average pooling + will be applied to the output of the + last convolutional layer, and thus + the output of the model will be a + 2D tensor. + - `max` means that global max pooling will + be applied. + classes: Optional number of classes to classify images + into, only to be specified if `include_top` is `True`, and + if no `weights` argument is specified. + classifier_activation: A `str` or callable. The activation function to + use on the "top" layer. Ignored unless `include_top=True`. Set + `classifier_activation=None` to return the logits of the "top" + layer. When loading pretrained weights, `classifier_activation` can + only be `None` or `"softmax"`. + name: The name of the model (string). + + Returns: + A Keras model instance. + """ + if backend.backend() == "torch": + raise ValueError( + "NASNetMobile is not available with the torch backend " + "at this time due to an outstanding bug. " + "If interested, please open a PR." + ) + if not include_top and input_shape is None: + input_shape = (224, 224, 3) + return NASNet( + input_shape, + penultimate_filters=1056, + num_blocks=4, + stem_block_filters=32, + skip_reduction=False, + filter_multiplier=2, + include_top=include_top, + weights=weights, + input_tensor=input_tensor, + pooling=pooling, + classes=classes, + default_size=224, + classifier_activation=classifier_activation, + name=name, + ) + + +@keras_export( + [ + "keras.applications.nasnet.NASNetLarge", + "keras.applications.NASNetLarge", + ] +) +def NASNetLarge( + input_shape=None, + include_top=True, + weights="imagenet", + input_tensor=None, + pooling=None, + classes=1000, + classifier_activation="softmax", + name="nasnet_large", +): + """Instantiates a NASNet model in ImageNet mode. + + Reference: + - [Learning Transferable Architectures for Scalable Image Recognition]( + https://arxiv.org/abs/1707.07012) (CVPR 2018) + + Optionally loads weights pre-trained on ImageNet. + Note that the data format convention used by the model is + the one specified in your Keras config at `~/.keras/keras.json`. + + Note: each Keras Application expects a specific kind of input preprocessing. + For NASNet, call `keras.applications.nasnet.preprocess_input` on your + inputs before passing them to the model. + + Args: + input_shape: Optional shape tuple, only to be specified + if `include_top` is False (otherwise the input shape + has to be `(331, 331, 3)` for NASNetLarge. + It should have exactly 3 inputs channels, + and width and height should be no smaller than 32. + E.g. `(224, 224, 3)` would be one valid value. + include_top: Whether to include the fully-connected + layer at the top of the network. + weights: `None` (random initialization) or + `imagenet` (ImageNet weights). For loading `imagenet` weights, + `input_shape` should be (331, 331, 3) + input_tensor: Optional Keras tensor (i.e. output of + `layers.Input()`) + to use as image input for the model. + pooling: Optional pooling mode for feature extraction + when `include_top` is `False`. + - `None` means that the output of the model + will be the 4D tensor output of the + last convolutional layer. + - `avg` means that global average pooling + will be applied to the output of the + last convolutional layer, and thus + the output of the model will be a + 2D tensor. + - `max` means that global max pooling will + be applied. + classes: Optional number of classes to classify images + into, only to be specified if `include_top` is `True`, and + if no `weights` argument is specified. + classifier_activation: A `str` or callable. The activation function to + use on the "top" layer. Ignored unless `include_top=True`. Set + `classifier_activation=None` to return the logits of the "top" + layer. When loading pretrained weights, `classifier_activation` + can only be `None` or `"softmax"`. + name: The name of the model (string). + + Returns: + A Keras model instance. + """ + return NASNet( + input_shape, + penultimate_filters=4032, + num_blocks=6, + stem_block_filters=96, + skip_reduction=True, + filter_multiplier=2, + include_top=include_top, + weights=weights, + input_tensor=input_tensor, + pooling=pooling, + classes=classes, + default_size=331, + classifier_activation=classifier_activation, + name=name, + ) + + +def _separable_conv_block( + ip, filters, kernel_size=(3, 3), strides=(1, 1), block_id=None +): + """Adds 2 blocks of [relu-separable conv-batchnorm]. + + Args: + ip: Input tensor + filters: Number of output filters per layer + kernel_size: Kernel size of separable convolutions + strides: Strided convolution for downsampling + block_id: String block_id + + Returns: + A Keras tensor + """ + channel_dim = 1 if backend.image_data_format() == "channels_first" else -1 + + with backend.name_scope(f"separable_conv_block_{block_id}"): + x = layers.Activation("relu")(ip) + if strides == (2, 2): + x = layers.ZeroPadding2D( + padding=imagenet_utils.correct_pad(x, kernel_size), + name=f"separable_conv_1_pad_{block_id}", + )(x) + conv_pad = "valid" + else: + conv_pad = "same" + x = layers.SeparableConv2D( + filters, + kernel_size, + strides=strides, + name=f"separable_conv_1_{block_id}", + padding=conv_pad, + use_bias=False, + )(x) + x = layers.BatchNormalization( + axis=channel_dim, + momentum=0.9997, + epsilon=1e-3, + name=f"separable_conv_1_bn_{block_id}", + )(x) + x = layers.Activation("relu")(x) + x = layers.SeparableConv2D( + filters, + kernel_size, + name=f"separable_conv_2_{block_id}", + padding="same", + use_bias=False, + )(x) + x = layers.BatchNormalization( + axis=channel_dim, + momentum=0.9997, + epsilon=1e-3, + name=f"separable_conv_2_bn_{block_id}", + )(x) + return x + + +def _adjust_block(p, ip, filters, block_id=None): + """Adjusts the input `previous path` to match the shape of the `input`. + + Used in situations where the output number of filters needs to be changed. + + Args: + p: Input tensor which needs to be modified + ip: Input tensor whose shape needs to be matched + filters: Number of output filters to be matched + block_id: String block_id + + Returns: + Adjusted Keras tensor + """ + channel_dim = 1 if backend.image_data_format() == "channels_first" else -1 + img_dim = 2 if backend.image_data_format() == "channels_first" else -2 + + with backend.name_scope("adjust_block"): + if p is None: + p = ip + + elif p.shape[img_dim] != ip.shape[img_dim]: + with backend.name_scope(f"adjust_reduction_block_{block_id}"): + p = layers.Activation("relu", name=f"adjust_relu_1_{block_id}")( + p + ) + p1 = layers.AveragePooling2D( + (1, 1), + strides=(2, 2), + padding="valid", + name=f"adjust_avg_pool_1_{block_id}", + )(p) + p1 = layers.Conv2D( + filters // 2, + (1, 1), + padding="same", + use_bias=False, + name=f"adjust_conv_1_{block_id}", + kernel_initializer="he_normal", + )(p1) + + p2 = layers.ZeroPadding2D(padding=((0, 1), (0, 1)))(p) + p2 = layers.Cropping2D(cropping=((1, 0), (1, 0)))(p2) + p2 = layers.AveragePooling2D( + (1, 1), + strides=(2, 2), + padding="valid", + name=f"adjust_avg_pool_2_{block_id}", + )(p2) + p2 = layers.Conv2D( + filters // 2, + (1, 1), + padding="same", + use_bias=False, + name=f"adjust_conv_2_{block_id}", + kernel_initializer="he_normal", + )(p2) + + p = layers.concatenate([p1, p2], axis=channel_dim) + p = layers.BatchNormalization( + axis=channel_dim, + momentum=0.9997, + epsilon=1e-3, + name=f"adjust_bn_{block_id}", + )(p) + + elif p.shape[channel_dim] != filters: + with backend.name_scope(f"adjust_projection_block_{block_id}"): + p = layers.Activation("relu")(p) + p = layers.Conv2D( + filters, + (1, 1), + strides=(1, 1), + padding="same", + name=f"adjust_conv_projection_{block_id}", + use_bias=False, + kernel_initializer="he_normal", + )(p) + p = layers.BatchNormalization( + axis=channel_dim, + momentum=0.9997, + epsilon=1e-3, + name=f"adjust_bn_{block_id}", + )(p) + return p + + +def _normal_a_cell(ip, p, filters, block_id=None): + """Adds a Normal cell for NASNet-A (Fig. 4 in the paper). + + Args: + ip: Input tensor `x` + p: Input tensor `p` + filters: Number of output filters + block_id: String block_id + + Returns: + A Keras tensor + """ + channel_dim = 1 if backend.image_data_format() == "channels_first" else -1 + + with backend.name_scope(f"normal_A_block_{block_id}"): + p = _adjust_block(p, ip, filters, block_id) + + h = layers.Activation("relu")(ip) + h = layers.Conv2D( + filters, + (1, 1), + strides=(1, 1), + padding="same", + name=f"normal_conv_1_{block_id}", + use_bias=False, + kernel_initializer="he_normal", + )(h) + h = layers.BatchNormalization( + axis=channel_dim, + momentum=0.9997, + epsilon=1e-3, + name=f"normal_bn_1_{block_id}", + )(h) + + with backend.name_scope("block_1"): + x1_1 = _separable_conv_block( + h, + filters, + kernel_size=(5, 5), + block_id=f"normal_left1_{block_id}", + ) + x1_2 = _separable_conv_block( + p, filters, block_id=f"normal_right1_{block_id}" + ) + x1 = layers.add([x1_1, x1_2], name=f"normal_add_1_{block_id}") + + with backend.name_scope("block_2"): + x2_1 = _separable_conv_block( + p, filters, (5, 5), block_id=f"normal_left2_{block_id}" + ) + x2_2 = _separable_conv_block( + p, filters, (3, 3), block_id=f"normal_right2_{block_id}" + ) + x2 = layers.add([x2_1, x2_2], name=f"normal_add_2_{block_id}") + + with backend.name_scope("block_3"): + x3 = layers.AveragePooling2D( + (3, 3), + strides=(1, 1), + padding="same", + name=f"normal_left3_{block_id}", + )(h) + x3 = layers.add([x3, p], name=f"normal_add_3_{block_id}") + + with backend.name_scope("block_4"): + x4_1 = layers.AveragePooling2D( + (3, 3), + strides=(1, 1), + padding="same", + name=f"normal_left4_{block_id}", + )(p) + x4_2 = layers.AveragePooling2D( + (3, 3), + strides=(1, 1), + padding="same", + name=f"normal_right4_{block_id}", + )(p) + x4 = layers.add([x4_1, x4_2], name=f"normal_add_4_{block_id}") + + with backend.name_scope("block_5"): + x5 = _separable_conv_block( + h, filters, block_id=f"normal_left5_{block_id}" + ) + x5 = layers.add([x5, h], name=f"normal_add_5_{block_id}") + + x = layers.concatenate( + [p, x1, x2, x3, x4, x5], + axis=channel_dim, + name=f"normal_concat_{block_id}", + ) + return x, ip + + +def _reduction_a_cell(ip, p, filters, block_id=None): + """Adds a Reduction cell for NASNet-A (Fig. 4 in the paper). + + Args: + ip: Input tensor `x` + p: Input tensor `p` + filters: Number of output filters + block_id: String block_id + + Returns: + A Keras tensor + """ + channel_dim = 1 if backend.image_data_format() == "channels_first" else -1 + + with backend.name_scope(f"reduction_A_block_{block_id}"): + p = _adjust_block(p, ip, filters, block_id) + + h = layers.Activation("relu")(ip) + h = layers.Conv2D( + filters, + (1, 1), + strides=(1, 1), + padding="same", + name=f"reduction_conv_1_{block_id}", + use_bias=False, + kernel_initializer="he_normal", + )(h) + h = layers.BatchNormalization( + axis=channel_dim, + momentum=0.9997, + epsilon=1e-3, + name=f"reduction_bn_1_{block_id}", + )(h) + h3 = layers.ZeroPadding2D( + padding=imagenet_utils.correct_pad(h, 3), + name=f"reduction_pad_1_{block_id}", + )(h) + + with backend.name_scope("block_1"): + x1_1 = _separable_conv_block( + h, + filters, + (5, 5), + strides=(2, 2), + block_id=f"reduction_left1_{block_id}", + ) + x1_2 = _separable_conv_block( + p, + filters, + (7, 7), + strides=(2, 2), + block_id=f"reduction_right1_{block_id}", + ) + x1 = layers.add([x1_1, x1_2], name=f"reduction_add_1_{block_id}") + + with backend.name_scope("block_2"): + x2_1 = layers.MaxPooling2D( + (3, 3), + strides=(2, 2), + padding="valid", + name=f"reduction_left2_{block_id}", + )(h3) + x2_2 = _separable_conv_block( + p, + filters, + (7, 7), + strides=(2, 2), + block_id=f"reduction_right2_{block_id}", + ) + x2 = layers.add([x2_1, x2_2], name=f"reduction_add_2_{block_id}") + + with backend.name_scope("block_3"): + x3_1 = layers.AveragePooling2D( + (3, 3), + strides=(2, 2), + padding="valid", + name=f"reduction_left3_{block_id}", + )(h3) + x3_2 = _separable_conv_block( + p, + filters, + (5, 5), + strides=(2, 2), + block_id=f"reduction_right3_{block_id}", + ) + x3 = layers.add([x3_1, x3_2], name=f"reduction_add3_{block_id}") + + with backend.name_scope("block_4"): + x4 = layers.AveragePooling2D( + (3, 3), + strides=(1, 1), + padding="same", + name=f"reduction_left4_{block_id}", + )(x1) + x4 = layers.add([x2, x4]) + + with backend.name_scope("block_5"): + x5_1 = _separable_conv_block( + x1, filters, (3, 3), block_id=f"reduction_left4_{block_id}" + ) + x5_2 = layers.MaxPooling2D( + (3, 3), + strides=(2, 2), + padding="valid", + name=f"reduction_right5_{block_id}", + )(h3) + x5 = layers.add([x5_1, x5_2], name=f"reduction_add4_{block_id}") + + x = layers.concatenate( + [x2, x3, x4, x5], + axis=channel_dim, + name=f"reduction_concat_{block_id}", + ) + return x, ip + + +@keras_export("keras.applications.nasnet.preprocess_input") +def preprocess_input(x, data_format=None): + return imagenet_utils.preprocess_input( + x, data_format=data_format, mode="tf" + ) + + +@keras_export("keras.applications.nasnet.decode_predictions") +def decode_predictions(preds, top=5): + return imagenet_utils.decode_predictions(preds, top=top) + + +preprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format( + mode="", + ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF, + error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC, +) +decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__ diff --git a/keras/src/applications/resnet.py b/keras/src/applications/resnet.py new file mode 100644 index 000000000000..95c805cffc9a --- /dev/null +++ b/keras/src/applications/resnet.py @@ -0,0 +1,591 @@ +from keras.src import backend +from keras.src import layers +from keras.src.api_export import keras_export +from keras.src.applications import imagenet_utils +from keras.src.models import Functional +from keras.src.ops import operation_utils +from keras.src.utils import file_utils + +BASE_WEIGHTS_PATH = ( + "https://storage.googleapis.com/tensorflow/keras-applications/resnet/" +) +WEIGHTS_HASHES = { + "resnet50": ( + "2cb95161c43110f7111970584f804107", + "4d473c1dd8becc155b73f8504c6f6626", + ), + "resnet101": ( + "f1aeb4b969a6efcfb50fad2f0c20cfc5", + "88cf7a10940856eca736dc7b7e228a21", + ), + "resnet152": ( + "100835be76be38e30d865e96f2aaae62", + "ee4c566cf9a93f14d82f913c2dc6dd0c", + ), + "resnet50v2": ( + "3ef43a0b657b3be2300d5770ece849e0", + "fac2f116257151a9d068a22e544a4917", + ), + "resnet101v2": ( + "6343647c601c52e1368623803854d971", + "c0ed64b8031c3730f411d2eb4eea35b5", + ), + "resnet152v2": ( + "a49b44d1979771252814e80f8ec446f9", + "ed17cf2e0169df9d443503ef94b23b33", + ), + "resnext50": ( + "67a5b30d522ed92f75a1f16eef299d1a", + "62527c363bdd9ec598bed41947b379fc", + ), + "resnext101": ( + "34fb605428fcc7aa4d62f44404c11509", + "0f678c91647380debd923963594981b3", + ), +} + + +def ResNet( + stack_fn, + preact, + use_bias, + include_top=True, + weights="imagenet", + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000, + classifier_activation="softmax", + name="resnet", + weights_name=None, +): + """Instantiates the ResNet, ResNetV2, and ResNeXt architecture. + + Args: + stack_fn: A function that returns output tensor for the + stacked residual blocks. + preact: Whether to use pre-activation or not. `True` for ResNetV2, + `False` for ResNet and ResNeXt. + use_bias: Whether to use biases for convolutional layers or not. + `True` for ResNet and ResNetV2, `False` for ResNeXt. + name: Name of the model. + include_top: Whether to include the fully-connected + layer at the top of the network. + weights: One of `None` (random initialization), + `"imagenet"` (pre-training on ImageNet), + or the path to the weights file to be loaded. + input_tensor: Optional Keras tensor (i.e. output of `layers.Input()`) + to use as image input for the model. + input_shape: Optional shape tuple, only to be specified + if `include_top` is `False` (otherwise the input shape + has to be `(224, 224, 3)` (with `channels_last` data format) + or `(3, 224, 224)` (with `"channels_first"` data format). It + should have exactly 3 inputs channels. + pooling: Optional pooling mode for feature extraction + when `include_top` is `False`. + - `None` means that the output of the model will be + the 4D tensor output of the + last convolutional layer. + - `avg` means that global average pooling + will be applied to the output of the + last convolutional layer, and thus + the output of the model will be a 2D tensor. + - `max` means that global max pooling will + be applied. + classes: optional number of classes to classify images + into, only to be specified if `include_top` is `True`, + and if no `weights` argument is specified. + classifier_activation: A `str` or callable. The activation + function to use on the "top" layer. Ignored unless + `include_top=True`. Set `classifier_activation=None` to + return the logits of the "top" layer. When loading + pretrained weights, `classifier_activation` can only be + `None` or `"softmax"`. + name: The name of the model (string). + + Returns: + A Model instance. + """ + + if not (weights in {"imagenet", None} or file_utils.exists(weights)): + raise ValueError( + "The `weights` argument should be either " + "`None` (random initialization), 'imagenet' " + "(pre-training on ImageNet), " + "or the path to the weights file to be loaded. Received: " + f"weights={weights}" + ) + + if weights == "imagenet" and include_top and classes != 1000: + raise ValueError( + "If using `weights='imagenet'` with `include_top=True`, " + "`classes` should be 1000. " + f"Received classes={classes}" + ) + + # Determine proper input shape + input_shape = imagenet_utils.obtain_input_shape( + input_shape, + default_size=224, + min_size=32, + data_format=backend.image_data_format(), + require_flatten=include_top, + weights=weights, + ) + + if input_tensor is None: + img_input = layers.Input(shape=input_shape) + else: + if not backend.is_keras_tensor(input_tensor): + img_input = layers.Input(tensor=input_tensor, shape=input_shape) + else: + img_input = input_tensor + + if backend.image_data_format() == "channels_last": + bn_axis = 3 + else: + bn_axis = 1 + + x = layers.ZeroPadding2D(padding=((3, 3), (3, 3)), name="conv1_pad")( + img_input + ) + x = layers.Conv2D(64, 7, strides=2, use_bias=use_bias, name="conv1_conv")(x) + + if not preact: + x = layers.BatchNormalization( + axis=bn_axis, epsilon=1.001e-5, name="conv1_bn" + )(x) + x = layers.Activation("relu", name="conv1_relu")(x) + + x = layers.ZeroPadding2D(padding=((1, 1), (1, 1)), name="pool1_pad")(x) + x = layers.MaxPooling2D(3, strides=2, name="pool1_pool")(x) + + x = stack_fn(x) + + if preact: + x = layers.BatchNormalization( + axis=bn_axis, epsilon=1.001e-5, name="post_bn" + )(x) + x = layers.Activation("relu", name="post_relu")(x) + + if include_top: + x = layers.GlobalAveragePooling2D(name="avg_pool")(x) + + # Validate activation for the classifier layer + imagenet_utils.validate_activation(classifier_activation, weights) + + x = layers.Dense( + classes, activation=classifier_activation, name="predictions" + )(x) + else: + if pooling == "avg": + x = layers.GlobalAveragePooling2D(name="avg_pool")(x) + elif pooling == "max": + x = layers.GlobalMaxPooling2D(name="max_pool")(x) + + # Ensure that the model takes into account + # any potential predecessors of `input_tensor`. + if input_tensor is not None: + inputs = operation_utils.get_source_inputs(input_tensor) + else: + inputs = img_input + + # Create model. + model = Functional(inputs, x, name=name) + + # Load weights. + if (weights == "imagenet") and (weights_name in WEIGHTS_HASHES): + if include_top: + file_name = f"{weights_name}_weights_tf_dim_ordering_tf_kernels.h5" + file_hash = WEIGHTS_HASHES[weights_name][0] + else: + file_name = ( + f"{weights_name}_weights_tf_dim_ordering_tf_kernels_notop.h5" + ) + file_hash = WEIGHTS_HASHES[weights_name][1] + weights_path = file_utils.get_file( + file_name, + f"{BASE_WEIGHTS_PATH}{file_name}", + cache_subdir="models", + file_hash=file_hash, + ) + model.load_weights(weights_path) + elif weights is not None: + model.load_weights(weights) + + return model + + +def residual_block_v1( + x, filters, kernel_size=3, stride=1, conv_shortcut=True, name=None +): + """A residual block for ResNet*_v1. + + Args: + x: Input tensor. + filters: No of filters in the bottleneck layer. + kernel_size: Kernel size of the bottleneck layer. Defaults to `3`. + stride: Stride of the first layer. Defaults to `1`. + conv_shortcut: Use convolution shortcut if `True`, otherwise + use identity shortcut. Defaults to `True` + name(optional): Name of the block + + Returns: + Output tensor for the residual block. + """ + + if backend.image_data_format() == "channels_last": + bn_axis = 3 + else: + bn_axis = 1 + + if conv_shortcut: + shortcut = layers.Conv2D( + 4 * filters, 1, strides=stride, name=f"{name}_0_conv" + )(x) + shortcut = layers.BatchNormalization( + axis=bn_axis, epsilon=1.001e-5, name=f"{name}_0_bn" + )(shortcut) + else: + shortcut = x + + x = layers.Conv2D(filters, 1, strides=stride, name=f"{name}_1_conv")(x) + x = layers.BatchNormalization( + axis=bn_axis, epsilon=1.001e-5, name=f"{name}_1_bn" + )(x) + x = layers.Activation("relu", name=f"{name}_1_relu")(x) + + x = layers.Conv2D( + filters, kernel_size, padding="SAME", name=f"{name}_2_conv" + )(x) + x = layers.BatchNormalization( + axis=bn_axis, epsilon=1.001e-5, name=f"{name}_2_bn" + )(x) + x = layers.Activation("relu", name=f"{name}_2_relu")(x) + + x = layers.Conv2D(4 * filters, 1, name=f"{name}_3_conv")(x) + x = layers.BatchNormalization( + axis=bn_axis, epsilon=1.001e-5, name=f"{name}_3_bn" + )(x) + + x = layers.Add(name=f"{name}_add")([shortcut, x]) + x = layers.Activation("relu", name=f"{name}_out")(x) + return x + + +def stack_residual_blocks_v1(x, filters, blocks, stride1=2, name=None): + """A set of stacked residual blocks. + + Args: + x: Input tensor. + filters: Number of filters in the bottleneck layer in a block. + blocks: Number of blocks in the stacked blocks. + stride1: Stride of the first layer in the first block. Defaults to `2`. + name: Stack label. + + Returns: + Output tensor for the stacked blocks. + """ + + x = residual_block_v1(x, filters, stride=stride1, name=f"{name}_block1") + for i in range(2, blocks + 1): + x = residual_block_v1( + x, filters, conv_shortcut=False, name=f"{name}_block{i}" + ) + return x + + +def residual_block_v2( + x, filters, kernel_size=3, stride=1, conv_shortcut=False, name=None +): + """A residual block for ResNet*_v2. + + Args: + x: Input tensor. + filters: No of filters in the bottleneck layer. + kernel_size: Kernel size of the bottleneck layer. Defaults to `3`. + stride: Stride of the first layer. Defaults to `1`. + conv_shortcut: Use convolution shortcut if `True`, otherwise + use identity shortcut. Defaults to `True` + name(optional): Name of the block + + Returns: + Output tensor for the residual block. + """ + + if backend.image_data_format() == "channels_last": + bn_axis = 3 + else: + bn_axis = 1 + + preact = layers.BatchNormalization( + axis=bn_axis, epsilon=1.001e-5, name=f"{name}_preact_bn" + )(x) + preact = layers.Activation("relu", name=f"{name}_preact_relu")(preact) + + if conv_shortcut: + shortcut = layers.Conv2D( + 4 * filters, 1, strides=stride, name=f"{name}_0_conv" + )(preact) + else: + shortcut = ( + layers.MaxPooling2D(1, strides=stride)(x) if stride > 1 else x + ) + + x = layers.Conv2D( + filters, 1, strides=1, use_bias=False, name=f"{name}_1_conv" + )(preact) + x = layers.BatchNormalization( + axis=bn_axis, epsilon=1.001e-5, name=f"{name}_1_bn" + )(x) + x = layers.Activation("relu", name=f"{name}_1_relu")(x) + + x = layers.ZeroPadding2D(padding=((1, 1), (1, 1)), name=f"{name}_2_pad")(x) + x = layers.Conv2D( + filters, + kernel_size, + strides=stride, + use_bias=False, + name=f"{name}_2_conv", + )(x) + x = layers.BatchNormalization( + axis=bn_axis, epsilon=1.001e-5, name=f"{name}_2_bn" + )(x) + x = layers.Activation("relu", name=f"{name}_2_relu")(x) + + x = layers.Conv2D(4 * filters, 1, name=f"{name}_3_conv")(x) + x = layers.Add(name=f"{name}_out")([shortcut, x]) + return x + + +def stack_residual_blocks_v2(x, filters, blocks, stride1=2, name=None): + """A set of stacked residual blocks. + + Args: + x: Input tensor. + filters: Number of filters in the bottleneck layer in a block. + blocks: Number of blocks in the stacked blocks. + stride1: Stride of the first layer in the first block. Defaults to `2`. + name: Stack label. + + Returns: + Output tensor for the stacked blocks. + """ + + x = residual_block_v2(x, filters, conv_shortcut=True, name=f"{name}_block1") + for i in range(2, blocks): + x = residual_block_v2(x, filters, name=f"{name}_block{i}") + x = residual_block_v2( + x, filters, stride=stride1, name=f"{name}_block{str(blocks)}" + ) + return x + + +@keras_export( + [ + "keras.applications.resnet50.ResNet50", + "keras.applications.resnet.ResNet50", + "keras.applications.ResNet50", + ] +) +def ResNet50( + include_top=True, + weights="imagenet", + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000, + classifier_activation="softmax", + name="resnet50", +): + """Instantiates the ResNet50 architecture.""" + + def stack_fn(x): + x = stack_residual_blocks_v1(x, 64, 3, stride1=1, name="conv2") + x = stack_residual_blocks_v1(x, 128, 4, name="conv3") + x = stack_residual_blocks_v1(x, 256, 6, name="conv4") + return stack_residual_blocks_v1(x, 512, 3, name="conv5") + + return ResNet( + stack_fn, + preact=False, + use_bias=True, + weights_name="resnet50", + name=name, + include_top=include_top, + weights=weights, + input_tensor=input_tensor, + input_shape=input_shape, + pooling=pooling, + classes=classes, + classifier_activation=classifier_activation, + ) + + +@keras_export( + [ + "keras.applications.resnet.ResNet101", + "keras.applications.ResNet101", + ] +) +def ResNet101( + include_top=True, + weights="imagenet", + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000, + classifier_activation="softmax", + name="resnet101", +): + """Instantiates the ResNet101 architecture.""" + + def stack_fn(x): + x = stack_residual_blocks_v1(x, 64, 3, stride1=1, name="conv2") + x = stack_residual_blocks_v1(x, 128, 4, name="conv3") + x = stack_residual_blocks_v1(x, 256, 23, name="conv4") + return stack_residual_blocks_v1(x, 512, 3, name="conv5") + + return ResNet( + stack_fn, + preact=False, + use_bias=True, + name=name, + weights_name="resnet101", + include_top=include_top, + weights=weights, + input_tensor=input_tensor, + input_shape=input_shape, + pooling=pooling, + classes=classes, + classifier_activation=classifier_activation, + ) + + +@keras_export( + [ + "keras.applications.resnet.ResNet152", + "keras.applications.ResNet152", + ] +) +def ResNet152( + include_top=True, + weights="imagenet", + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000, + classifier_activation="softmax", + name="resnet152", +): + """Instantiates the ResNet152 architecture.""" + + def stack_fn(x): + x = stack_residual_blocks_v1(x, 64, 3, stride1=1, name="conv2") + x = stack_residual_blocks_v1(x, 128, 8, name="conv3") + x = stack_residual_blocks_v1(x, 256, 36, name="conv4") + return stack_residual_blocks_v1(x, 512, 3, name="conv5") + + return ResNet( + stack_fn, + preact=False, + use_bias=True, + name=name, + weights_name="resnet152", + include_top=include_top, + weights=weights, + input_tensor=input_tensor, + input_shape=input_shape, + pooling=pooling, + classes=classes, + classifier_activation=classifier_activation, + ) + + +@keras_export( + [ + "keras.applications.resnet50.preprocess_input", + "keras.applications.resnet.preprocess_input", + ] +) +def preprocess_input(x, data_format=None): + return imagenet_utils.preprocess_input( + x, data_format=data_format, mode="caffe" + ) + + +@keras_export( + [ + "keras.applications.resnet50.decode_predictions", + "keras.applications.resnet.decode_predictions", + ] +) +def decode_predictions(preds, top=5): + return imagenet_utils.decode_predictions(preds, top=top) + + +preprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format( + mode="", + ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_CAFFE, + error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC, +) +decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__ + +DOC = """ + +Reference: +- [Deep Residual Learning for Image Recognition]( + https://arxiv.org/abs/1512.03385) (CVPR 2015) + +For image classification use cases, see [this page for detailed examples]( + https://keras.io/api/applications/#usage-examples-for-image-classification-models). + +For transfer learning use cases, make sure to read the +[guide to transfer learning & fine-tuning]( + https://keras.io/guides/transfer_learning/). + +Note: each Keras Application expects a specific kind of input preprocessing. +For ResNet, call `keras.applications.resnet.preprocess_input` on your +inputs before passing them to the model. `resnet.preprocess_input` will convert +the input images from RGB to BGR, then will zero-center each color channel with +respect to the ImageNet dataset, without scaling. + +Args: + include_top: whether to include the fully-connected + layer at the top of the network. + weights: one of `None` (random initialization), + `"imagenet"` (pre-training on ImageNet), or the path to the weights + file to be loaded. + input_tensor: optional Keras tensor (i.e. output of `layers.Input()`) + to use as image input for the model. + input_shape: optional shape tuple, only to be specified if `include_top` + is `False` (otherwise the input shape has to be `(224, 224, 3)` + (with `"channels_last"` data format) or `(3, 224, 224)` + (with `"channels_first"` data format). It should have exactly 3 + inputs channels, and width and height should be no smaller than 32. + E.g. `(200, 200, 3)` would be one valid value. + pooling: Optional pooling mode for feature extraction when `include_top` + is `False`. + - `None` means that the output of the model will be the 4D tensor + output of the last convolutional block. + - `avg` means that global average pooling will be applied to the output + of the last convolutional block, and thus the output of the + model will be a 2D tensor. + - `max` means that global max pooling will be applied. + classes: optional number of classes to classify images into, only to be + specified if `include_top` is `True`, and if no `weights` argument is + specified. Defaults to `1000`. + classifier_activation: A `str` or callable. The activation function to + use on the "top" layer. Ignored unless `include_top=True`. Set + `classifier_activation=None` to return the logits of the "top" layer. + When loading pretrained weights, `classifier_activation` can only + be `None` or `"softmax"`. + name: The name of the model (string). + +Returns: + A Model instance. +""" + +setattr(ResNet50, "__doc__", ResNet50.__doc__ + DOC) +setattr(ResNet101, "__doc__", ResNet101.__doc__ + DOC) +setattr(ResNet152, "__doc__", ResNet152.__doc__ + DOC) diff --git a/keras/src/applications/resnet_v2.py b/keras/src/applications/resnet_v2.py new file mode 100644 index 000000000000..590efa0bbbda --- /dev/null +++ b/keras/src/applications/resnet_v2.py @@ -0,0 +1,208 @@ +from keras.src.api_export import keras_export +from keras.src.applications import imagenet_utils +from keras.src.applications import resnet + + +@keras_export( + [ + "keras.applications.ResNet50V2", + "keras.applications.resnet_v2.ResNet50V2", + ] +) +def ResNet50V2( + include_top=True, + weights="imagenet", + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000, + classifier_activation="softmax", + name="resnet50v2", +): + """Instantiates the ResNet50V2 architecture.""" + + def stack_fn(x): + x = resnet.stack_residual_blocks_v2(x, 64, 3, name="conv2") + x = resnet.stack_residual_blocks_v2(x, 128, 4, name="conv3") + x = resnet.stack_residual_blocks_v2(x, 256, 6, name="conv4") + return resnet.stack_residual_blocks_v2( + x, 512, 3, stride1=1, name="conv5" + ) + + return resnet.ResNet( + stack_fn, + True, + True, + name=name, + weights_name="resnet50v2", + include_top=include_top, + weights=weights, + input_tensor=input_tensor, + input_shape=input_shape, + pooling=pooling, + classes=classes, + classifier_activation=classifier_activation, + ) + + +@keras_export( + [ + "keras.applications.ResNet101V2", + "keras.applications.resnet_v2.ResNet101V2", + ] +) +def ResNet101V2( + include_top=True, + weights="imagenet", + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000, + classifier_activation="softmax", + name="resnet101v2", +): + """Instantiates the ResNet101V2 architecture.""" + + def stack_fn(x): + x = resnet.stack_residual_blocks_v2(x, 64, 3, name="conv2") + x = resnet.stack_residual_blocks_v2(x, 128, 4, name="conv3") + x = resnet.stack_residual_blocks_v2(x, 256, 23, name="conv4") + return resnet.stack_residual_blocks_v2( + x, 512, 3, stride1=1, name="conv5" + ) + + return resnet.ResNet( + stack_fn, + True, + True, + name=name, + weights_name="resnet101v2", + include_top=include_top, + weights=weights, + input_tensor=input_tensor, + input_shape=input_shape, + pooling=pooling, + classes=classes, + classifier_activation=classifier_activation, + ) + + +@keras_export( + [ + "keras.applications.ResNet152V2", + "keras.applications.resnet_v2.ResNet152V2", + ] +) +def ResNet152V2( + include_top=True, + weights="imagenet", + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000, + classifier_activation="softmax", + name="resnet152v2", +): + """Instantiates the ResNet152V2 architecture.""" + + def stack_fn(x): + x = resnet.stack_residual_blocks_v2(x, 64, 3, name="conv2") + x = resnet.stack_residual_blocks_v2(x, 128, 8, name="conv3") + x = resnet.stack_residual_blocks_v2(x, 256, 36, name="conv4") + return resnet.stack_residual_blocks_v2( + x, 512, 3, stride1=1, name="conv5" + ) + + return resnet.ResNet( + stack_fn, + True, + True, + name=name, + weights_name="resnet152v2", + include_top=include_top, + weights=weights, + input_tensor=input_tensor, + input_shape=input_shape, + pooling=pooling, + classes=classes, + classifier_activation=classifier_activation, + ) + + +@keras_export("keras.applications.resnet_v2.preprocess_input") +def preprocess_input(x, data_format=None): + return imagenet_utils.preprocess_input( + x, data_format=data_format, mode="tf" + ) + + +@keras_export("keras.applications.resnet_v2.decode_predictions") +def decode_predictions(preds, top=5): + return imagenet_utils.decode_predictions(preds, top=top) + + +preprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format( + mode="", + ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF, + error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC, +) +decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__ + + +DOC = """ + +Reference: +- [Identity Mappings in Deep Residual Networks]( + https://arxiv.org/abs/1603.05027) (CVPR 2016) + +For image classification use cases, see [this page for detailed examples]( + https://keras.io/api/applications/#usage-examples-for-image-classification-models). + +For transfer learning use cases, make sure to read the +[guide to transfer learning & fine-tuning]( + https://keras.io/guides/transfer_learning/). + +Note: each Keras Application expects a specific kind of input preprocessing. +For ResNet, call `keras.applications.resnet_v2.preprocess_input` on your +inputs before passing them to the model. `resnet_v2.preprocess_input` will +scale input pixels between -1 and 1. + +Args: + include_top: whether to include the fully-connected + layer at the top of the network. + weights: one of `None` (random initialization), + `"imagenet"` (pre-training on ImageNet), or the path to the weights + file to be loaded. + input_tensor: optional Keras tensor (i.e. output of `layers.Input()`) + to use as image input for the model. + input_shape: optional shape tuple, only to be specified if `include_top` + is `False` (otherwise the input shape has to be `(224, 224, 3)` + (with `"channels_last"` data format) or `(3, 224, 224)` + (with `"channels_first"` data format). It should have exactly 3 + inputs channels, and width and height should be no smaller than 32. + E.g. `(200, 200, 3)` would be one valid value. + pooling: Optional pooling mode for feature extraction when `include_top` + is `False`. + - `None` means that the output of the model will be the 4D tensor + output of the last convolutional block. + - `avg` means that global average pooling will be applied to the output + of the last convolutional block, and thus the output of the + model will be a 2D tensor. + - `max` means that global max pooling will be applied. + classes: optional number of classes to classify images into, only to be + specified if `include_top` is `True`, and if no `weights` argument is + specified. + classifier_activation: A `str` or callable. The activation function to + use on the "top" layer. Ignored unless `include_top=True`. Set + `classifier_activation=None` to return the logits of the "top" layer. + When loading pretrained weights, `classifier_activation` can only + be `None` or `"softmax"`. + name: The name of the model (string). + +Returns: + A Model instance. +""" + +setattr(ResNet50V2, "__doc__", ResNet50V2.__doc__ + DOC) +setattr(ResNet101V2, "__doc__", ResNet101V2.__doc__ + DOC) +setattr(ResNet152V2, "__doc__", ResNet152V2.__doc__ + DOC) diff --git a/keras/src/applications/vgg16.py b/keras/src/applications/vgg16.py new file mode 100644 index 000000000000..21163a4efd49 --- /dev/null +++ b/keras/src/applications/vgg16.py @@ -0,0 +1,248 @@ +from keras.src import backend +from keras.src import layers +from keras.src.api_export import keras_export +from keras.src.applications import imagenet_utils +from keras.src.models import Functional +from keras.src.ops import operation_utils +from keras.src.utils import file_utils + +WEIGHTS_PATH = ( + "https://storage.googleapis.com/tensorflow/keras-applications/" + "vgg16/vgg16_weights_tf_dim_ordering_tf_kernels.h5" +) +WEIGHTS_PATH_NO_TOP = ( + "https://storage.googleapis.com/tensorflow/" + "keras-applications/vgg16/" + "vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5" +) + + +@keras_export(["keras.applications.vgg16.VGG16", "keras.applications.VGG16"]) +def VGG16( + include_top=True, + weights="imagenet", + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000, + classifier_activation="softmax", + name="vgg16", +): + """Instantiates the VGG16 model. + + Reference: + - [Very Deep Convolutional Networks for Large-Scale Image Recognition]( + https://arxiv.org/abs/1409.1556) (ICLR 2015) + + For image classification use cases, see + [this page for detailed examples]( + https://keras.io/api/applications/#usage-examples-for-image-classification-models). + + For transfer learning use cases, make sure to read the + [guide to transfer learning & fine-tuning]( + https://keras.io/guides/transfer_learning/). + + The default input size for this model is 224x224. + + Note: each Keras Application expects a specific kind of input preprocessing. + For VGG16, call `keras.applications.vgg16.preprocess_input` on your + inputs before passing them to the model. + `vgg16.preprocess_input` will convert the input images from RGB to BGR, + then will zero-center each color channel with respect to the ImageNet + dataset, without scaling. + + Args: + include_top: whether to include the 3 fully-connected + layers at the top of the network. + weights: one of `None` (random initialization), + `"imagenet"` (pre-training on ImageNet), + or the path to the weights file to be loaded. + input_tensor: optional Keras tensor + (i.e. output of `layers.Input()`) + to use as image input for the model. + input_shape: optional shape tuple, only to be specified + if `include_top` is `False` (otherwise the input shape + has to be `(224, 224, 3)` + (with `channels_last` data format) or + `(3, 224, 224)` (with `"channels_first"` data format). + It should have exactly 3 input channels, + and width and height should be no smaller than 32. + E.g. `(200, 200, 3)` would be one valid value. + pooling: Optional pooling mode for feature extraction + when `include_top` is `False`. + - `None` means that the output of the model will be + the 4D tensor output of the + last convolutional block. + - `avg` means that global average pooling + will be applied to the output of the + last convolutional block, and thus + the output of the model will be a 2D tensor. + - `max` means that global max pooling will + be applied. + classes: optional number of classes to classify images + into, only to be specified if `include_top` is `True`, and + if no `weights` argument is specified. + classifier_activation: A `str` or callable. The activation function to + use on the "top" layer. Ignored unless `include_top=True`. Set + `classifier_activation=None` to return the logits of the "top" + layer. When loading pretrained weights, `classifier_activation` + can only be `None` or `"softmax"`. + name: The name of the model (string). + + Returns: + A `Model` instance. + """ + if not (weights in {"imagenet", None} or file_utils.exists(weights)): + raise ValueError( + "The `weights` argument should be either " + "`None` (random initialization), 'imagenet' " + "(pre-training on ImageNet), " + "or the path to the weights file to be loaded. Received: " + f"weights={weights}" + ) + + if weights == "imagenet" and include_top and classes != 1000: + raise ValueError( + "If using `weights='imagenet'` with `include_top=True`, " + "`classes` should be 1000. " + f"Received classes={classes}" + ) + + # Determine proper input shape + input_shape = imagenet_utils.obtain_input_shape( + input_shape, + default_size=224, + min_size=32, + data_format=backend.image_data_format(), + require_flatten=include_top, + weights=weights, + ) + + if input_tensor is None: + img_input = layers.Input(shape=input_shape) + else: + if not backend.is_keras_tensor(input_tensor): + img_input = layers.Input(tensor=input_tensor, shape=input_shape) + else: + img_input = input_tensor + # Block 1 + x = layers.Conv2D( + 64, (3, 3), activation="relu", padding="same", name="block1_conv1" + )(img_input) + x = layers.Conv2D( + 64, (3, 3), activation="relu", padding="same", name="block1_conv2" + )(x) + x = layers.MaxPooling2D((2, 2), strides=(2, 2), name="block1_pool")(x) + + # Block 2 + x = layers.Conv2D( + 128, (3, 3), activation="relu", padding="same", name="block2_conv1" + )(x) + x = layers.Conv2D( + 128, (3, 3), activation="relu", padding="same", name="block2_conv2" + )(x) + x = layers.MaxPooling2D((2, 2), strides=(2, 2), name="block2_pool")(x) + + # Block 3 + x = layers.Conv2D( + 256, (3, 3), activation="relu", padding="same", name="block3_conv1" + )(x) + x = layers.Conv2D( + 256, (3, 3), activation="relu", padding="same", name="block3_conv2" + )(x) + x = layers.Conv2D( + 256, (3, 3), activation="relu", padding="same", name="block3_conv3" + )(x) + x = layers.MaxPooling2D((2, 2), strides=(2, 2), name="block3_pool")(x) + + # Block 4 + x = layers.Conv2D( + 512, (3, 3), activation="relu", padding="same", name="block4_conv1" + )(x) + x = layers.Conv2D( + 512, (3, 3), activation="relu", padding="same", name="block4_conv2" + )(x) + x = layers.Conv2D( + 512, (3, 3), activation="relu", padding="same", name="block4_conv3" + )(x) + x = layers.MaxPooling2D((2, 2), strides=(2, 2), name="block4_pool")(x) + + # Block 5 + x = layers.Conv2D( + 512, (3, 3), activation="relu", padding="same", name="block5_conv1" + )(x) + x = layers.Conv2D( + 512, (3, 3), activation="relu", padding="same", name="block5_conv2" + )(x) + x = layers.Conv2D( + 512, (3, 3), activation="relu", padding="same", name="block5_conv3" + )(x) + x = layers.MaxPooling2D((2, 2), strides=(2, 2), name="block5_pool")(x) + + if include_top: + # Classification block + x = layers.Flatten(name="flatten")(x) + x = layers.Dense(4096, activation="relu", name="fc1")(x) + x = layers.Dense(4096, activation="relu", name="fc2")(x) + + imagenet_utils.validate_activation(classifier_activation, weights) + x = layers.Dense( + classes, activation=classifier_activation, name="predictions" + )(x) + else: + if pooling == "avg": + x = layers.GlobalAveragePooling2D()(x) + elif pooling == "max": + x = layers.GlobalMaxPooling2D()(x) + + # Ensure that the model takes into account + # any potential predecessors of `input_tensor`. + if input_tensor is not None: + inputs = operation_utils.get_source_inputs(input_tensor) + else: + inputs = img_input + + # Create model. + model = Functional(inputs, x, name=name) + + # Load weights. + if weights == "imagenet": + if include_top: + weights_path = file_utils.get_file( + "vgg16_weights_tf_dim_ordering_tf_kernels.h5", + WEIGHTS_PATH, + cache_subdir="models", + file_hash="64373286793e3c8b2b4e3219cbf3544b", + ) + else: + weights_path = file_utils.get_file( + "vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5", + WEIGHTS_PATH_NO_TOP, + cache_subdir="models", + file_hash="6d6bbae143d832006294945121d1f1fc", + ) + model.load_weights(weights_path) + elif weights is not None: + model.load_weights(weights) + + return model + + +@keras_export("keras.applications.vgg16.preprocess_input") +def preprocess_input(x, data_format=None): + return imagenet_utils.preprocess_input( + x, data_format=data_format, mode="caffe" + ) + + +@keras_export("keras.applications.vgg16.decode_predictions") +def decode_predictions(preds, top=5): + return imagenet_utils.decode_predictions(preds, top=top) + + +preprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format( + mode="", + ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_CAFFE, + error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC, +) +decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__ diff --git a/keras/src/applications/vgg19.py b/keras/src/applications/vgg19.py new file mode 100644 index 000000000000..d7ea1fce2d9c --- /dev/null +++ b/keras/src/applications/vgg19.py @@ -0,0 +1,256 @@ +from keras.src import backend +from keras.src import layers +from keras.src.api_export import keras_export +from keras.src.applications import imagenet_utils +from keras.src.models import Functional +from keras.src.ops import operation_utils +from keras.src.utils import file_utils + +WEIGHTS_PATH = ( + "https://storage.googleapis.com/tensorflow/keras-applications/" + "vgg19/vgg19_weights_tf_dim_ordering_tf_kernels.h5" +) +WEIGHTS_PATH_NO_TOP = ( + "https://storage.googleapis.com/tensorflow/" + "keras-applications/vgg19/" + "vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5" +) + + +@keras_export(["keras.applications.vgg19.VGG19", "keras.applications.VGG19"]) +def VGG19( + include_top=True, + weights="imagenet", + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000, + classifier_activation="softmax", + name="vgg19", +): + """Instantiates the VGG19 model. + + Reference: + - [Very Deep Convolutional Networks for Large-Scale Image Recognition]( + https://arxiv.org/abs/1409.1556) (ICLR 2015) + + For image classification use cases, see + [this page for detailed examples]( + https://keras.io/api/applications/#usage-examples-for-image-classification-models). + + For transfer learning use cases, make sure to read the + [guide to transfer learning & fine-tuning]( + https://keras.io/guides/transfer_learning/). + + The default input size for this model is 224x224. + + Note: each Keras Application expects a specific kind of input preprocessing. + For VGG19, call `keras.applications.vgg19.preprocess_input` on your + inputs before passing them to the model. + `vgg19.preprocess_input` will convert the input images from RGB to BGR, + then will zero-center each color channel with respect to the ImageNet + dataset, without scaling. + + Args: + include_top: whether to include the 3 fully-connected + layers at the top of the network. + weights: one of `None` (random initialization), + `"imagenet"` (pre-training on ImageNet), + or the path to the weights file to be loaded. + input_tensor: optional Keras tensor + (i.e. output of `layers.Input()`) + to use as image input for the model. + input_shape: optional shape tuple, only to be specified + if `include_top` is `False` (otherwise the input shape + has to be `(224, 224, 3)` + (with `channels_last` data format) or + `(3, 224, 224)` (with `"channels_first"` data format). + It should have exactly 3 input channels, + and width and height should be no smaller than 32. + E.g. `(200, 200, 3)` would be one valid value. + pooling: Optional pooling mode for feature extraction + when `include_top` is `False`. + - `None` means that the output of the model will be + the 4D tensor output of the + last convolutional block. + - `avg` means that global average pooling + will be applied to the output of the + last convolutional block, and thus + the output of the model will be a 2D tensor. + - `max` means that global max pooling will + be applied. + classes: optional number of classes to classify images + into, only to be specified if `include_top` is `True`, and + if no `weights` argument is specified. + classifier_activation: A `str` or callable. The activation function to + use on the "top" layer. Ignored unless `include_top=True`. Set + `classifier_activation=None` to return the logits of the "top" + layer. When loading pretrained weights, `classifier_activation` can + only be `None` or `"softmax"`. + name: The name of the model (string). + + Returns: + A model instance. + """ + if not (weights in {"imagenet", None} or file_utils.exists(weights)): + raise ValueError( + "The `weights` argument should be either " + "`None` (random initialization), 'imagenet' " + "(pre-training on ImageNet), " + "or the path to the weights file to be loaded. Received: " + f"weights={weights}" + ) + + if weights == "imagenet" and include_top and classes != 1000: + raise ValueError( + "If using `weights='imagenet'` with `include_top=True`, " + "`classes` should be 1000. " + f"Received classes={classes}" + ) + + # Determine proper input shape + input_shape = imagenet_utils.obtain_input_shape( + input_shape, + default_size=224, + min_size=32, + data_format=backend.image_data_format(), + require_flatten=include_top, + weights=weights, + ) + + if input_tensor is None: + img_input = layers.Input(shape=input_shape) + else: + if not backend.is_keras_tensor(input_tensor): + img_input = layers.Input(tensor=input_tensor, shape=input_shape) + else: + img_input = input_tensor + # Block 1 + x = layers.Conv2D( + 64, (3, 3), activation="relu", padding="same", name="block1_conv1" + )(img_input) + x = layers.Conv2D( + 64, (3, 3), activation="relu", padding="same", name="block1_conv2" + )(x) + x = layers.MaxPooling2D((2, 2), strides=(2, 2), name="block1_pool")(x) + + # Block 2 + x = layers.Conv2D( + 128, (3, 3), activation="relu", padding="same", name="block2_conv1" + )(x) + x = layers.Conv2D( + 128, (3, 3), activation="relu", padding="same", name="block2_conv2" + )(x) + x = layers.MaxPooling2D((2, 2), strides=(2, 2), name="block2_pool")(x) + + # Block 3 + x = layers.Conv2D( + 256, (3, 3), activation="relu", padding="same", name="block3_conv1" + )(x) + x = layers.Conv2D( + 256, (3, 3), activation="relu", padding="same", name="block3_conv2" + )(x) + x = layers.Conv2D( + 256, (3, 3), activation="relu", padding="same", name="block3_conv3" + )(x) + x = layers.Conv2D( + 256, (3, 3), activation="relu", padding="same", name="block3_conv4" + )(x) + x = layers.MaxPooling2D((2, 2), strides=(2, 2), name="block3_pool")(x) + + # Block 4 + x = layers.Conv2D( + 512, (3, 3), activation="relu", padding="same", name="block4_conv1" + )(x) + x = layers.Conv2D( + 512, (3, 3), activation="relu", padding="same", name="block4_conv2" + )(x) + x = layers.Conv2D( + 512, (3, 3), activation="relu", padding="same", name="block4_conv3" + )(x) + x = layers.Conv2D( + 512, (3, 3), activation="relu", padding="same", name="block4_conv4" + )(x) + x = layers.MaxPooling2D((2, 2), strides=(2, 2), name="block4_pool")(x) + + # Block 5 + x = layers.Conv2D( + 512, (3, 3), activation="relu", padding="same", name="block5_conv1" + )(x) + x = layers.Conv2D( + 512, (3, 3), activation="relu", padding="same", name="block5_conv2" + )(x) + x = layers.Conv2D( + 512, (3, 3), activation="relu", padding="same", name="block5_conv3" + )(x) + x = layers.Conv2D( + 512, (3, 3), activation="relu", padding="same", name="block5_conv4" + )(x) + x = layers.MaxPooling2D((2, 2), strides=(2, 2), name="block5_pool")(x) + + if include_top: + # Classification block + x = layers.Flatten(name="flatten")(x) + x = layers.Dense(4096, activation="relu", name="fc1")(x) + x = layers.Dense(4096, activation="relu", name="fc2")(x) + imagenet_utils.validate_activation(classifier_activation, weights) + x = layers.Dense( + classes, activation=classifier_activation, name="predictions" + )(x) + else: + if pooling == "avg": + x = layers.GlobalAveragePooling2D()(x) + elif pooling == "max": + x = layers.GlobalMaxPooling2D()(x) + + # Ensure that the model takes into account + # any potential predecessors of `input_tensor`. + if input_tensor is not None: + inputs = operation_utils.get_source_inputs(input_tensor) + else: + inputs = img_input + + # Create model. + model = Functional(inputs, x, name=name) + + # Load weights. + if weights == "imagenet": + if include_top: + weights_path = file_utils.get_file( + "vgg19_weights_tf_dim_ordering_tf_kernels.h5", + WEIGHTS_PATH, + cache_subdir="models", + file_hash="cbe5617147190e668d6c5d5026f83318", + ) + else: + weights_path = file_utils.get_file( + "vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5", + WEIGHTS_PATH_NO_TOP, + cache_subdir="models", + file_hash="253f8cb515780f3b799900260a226db6", + ) + model.load_weights(weights_path) + elif weights is not None: + model.load_weights(weights) + + return model + + +@keras_export("keras.applications.vgg19.preprocess_input") +def preprocess_input(x, data_format=None): + return imagenet_utils.preprocess_input( + x, data_format=data_format, mode="caffe" + ) + + +@keras_export("keras.applications.vgg19.decode_predictions") +def decode_predictions(preds, top=5): + return imagenet_utils.decode_predictions(preds, top=top) + + +preprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format( + mode="", + ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_CAFFE, + error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC, +) +decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__ diff --git a/keras/src/applications/xception.py b/keras/src/applications/xception.py new file mode 100644 index 000000000000..45d0f8179031 --- /dev/null +++ b/keras/src/applications/xception.py @@ -0,0 +1,355 @@ +from keras.src import backend +from keras.src import layers +from keras.src.api_export import keras_export +from keras.src.applications import imagenet_utils +from keras.src.models import Functional +from keras.src.ops import operation_utils +from keras.src.utils import file_utils + +WEIGHTS_PATH = ( + "https://storage.googleapis.com/tensorflow/keras-applications/" + "xception/xception_weights_tf_dim_ordering_tf_kernels.h5" +) +WEIGHTS_PATH_NO_TOP = ( + "https://storage.googleapis.com/tensorflow/keras-applications/" + "xception/xception_weights_tf_dim_ordering_tf_kernels_notop.h5" +) + + +@keras_export( + [ + "keras.applications.xception.Xception", + "keras.applications.Xception", + ] +) +def Xception( + include_top=True, + weights="imagenet", + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000, + classifier_activation="softmax", + name="xception", +): + """Instantiates the Xception architecture. + + Reference: + - [Xception: Deep Learning with Depthwise Separable Convolutions]( + https://arxiv.org/abs/1610.02357) (CVPR 2017) + + For image classification use cases, see + [this page for detailed examples]( + https://keras.io/api/applications/#usage-examples-for-image-classification-models). + + For transfer learning use cases, make sure to read the + [guide to transfer learning & fine-tuning]( + https://keras.io/guides/transfer_learning/). + + The default input image size for this model is 299x299. + + Note: each Keras Application expects a specific kind of input preprocessing. + For Xception, call `keras.applications.xception.preprocess_input` + on your inputs before passing them to the model. + `xception.preprocess_input` will scale input pixels between -1 and 1. + + Args: + include_top: whether to include the 3 fully-connected + layers at the top of the network. + weights: one of `None` (random initialization), + `"imagenet"` (pre-training on ImageNet), + or the path to the weights file to be loaded. + input_tensor: optional Keras tensor + (i.e. output of `layers.Input()`) + to use as image input for the model. + input_shape: optional shape tuple, only to be specified + if `include_top` is `False` (otherwise the input shape + has to be `(299, 299, 3)`. + It should have exactly 3 inputs channels, + and width and height should be no smaller than 71. + E.g. `(150, 150, 3)` would be one valid value. + pooling: Optional pooling mode for feature extraction + when `include_top` is `False`. + - `None` means that the output of the model will be + the 4D tensor output of the + last convolutional block. + - `avg` means that global average pooling + will be applied to the output of the + last convolutional block, and thus + the output of the model will be a 2D tensor. + - `max` means that global max pooling will + be applied. + classes: optional number of classes to classify images + into, only to be specified if `include_top` is `True`, and + if no `weights` argument is specified. Defaults to `1000`. + classifier_activation: A `str` or callable. The activation function to + use on the "top" layer. Ignored unless `include_top=True`. Set + `classifier_activation=None` to return the logits of the "top" + layer. When loading pretrained weights, `classifier_activation` can + only be `None` or `"softmax"`. + name: The name of the model (string). + + Returns: + A model instance. + """ + if not (weights in {"imagenet", None} or file_utils.exists(weights)): + raise ValueError( + "The `weights` argument should be either " + "`None` (random initialization), 'imagenet' " + "(pre-training on ImageNet), " + "or the path to the weights file to be loaded." + ) + + if weights == "imagenet" and include_top and classes != 1000: + raise ValueError( + "If using `weights='imagenet'` with `include_top=True`, " + "`classes` should be 1000. " + f"Received classes={classes}" + ) + + # Determine proper input shape + input_shape = imagenet_utils.obtain_input_shape( + input_shape, + default_size=299, + min_size=71, + data_format=backend.image_data_format(), + require_flatten=include_top, + weights=weights, + ) + + if input_tensor is None: + img_input = layers.Input(shape=input_shape) + else: + if not backend.is_keras_tensor(input_tensor): + img_input = layers.Input(tensor=input_tensor, shape=input_shape) + else: + img_input = input_tensor + + channel_axis = 1 if backend.image_data_format() == "channels_first" else -1 + + x = layers.Conv2D( + 32, (3, 3), strides=(2, 2), use_bias=False, name="block1_conv1" + )(img_input) + x = layers.BatchNormalization(axis=channel_axis, name="block1_conv1_bn")(x) + x = layers.Activation("relu", name="block1_conv1_act")(x) + x = layers.Conv2D(64, (3, 3), use_bias=False, name="block1_conv2")(x) + x = layers.BatchNormalization(axis=channel_axis, name="block1_conv2_bn")(x) + x = layers.Activation("relu", name="block1_conv2_act")(x) + + residual = layers.Conv2D( + 128, (1, 1), strides=(2, 2), padding="same", use_bias=False + )(x) + residual = layers.BatchNormalization(axis=channel_axis)(residual) + + x = layers.SeparableConv2D( + 128, (3, 3), padding="same", use_bias=False, name="block2_sepconv1" + )(x) + x = layers.BatchNormalization(axis=channel_axis, name="block2_sepconv1_bn")( + x + ) + x = layers.Activation("relu", name="block2_sepconv2_act")(x) + x = layers.SeparableConv2D( + 128, (3, 3), padding="same", use_bias=False, name="block2_sepconv2" + )(x) + x = layers.BatchNormalization(axis=channel_axis, name="block2_sepconv2_bn")( + x + ) + + x = layers.MaxPooling2D( + (3, 3), strides=(2, 2), padding="same", name="block2_pool" + )(x) + x = layers.add([x, residual]) + + residual = layers.Conv2D( + 256, (1, 1), strides=(2, 2), padding="same", use_bias=False + )(x) + residual = layers.BatchNormalization(axis=channel_axis)(residual) + + x = layers.Activation("relu", name="block3_sepconv1_act")(x) + x = layers.SeparableConv2D( + 256, (3, 3), padding="same", use_bias=False, name="block3_sepconv1" + )(x) + x = layers.BatchNormalization(axis=channel_axis, name="block3_sepconv1_bn")( + x + ) + x = layers.Activation("relu", name="block3_sepconv2_act")(x) + x = layers.SeparableConv2D( + 256, (3, 3), padding="same", use_bias=False, name="block3_sepconv2" + )(x) + x = layers.BatchNormalization(axis=channel_axis, name="block3_sepconv2_bn")( + x + ) + + x = layers.MaxPooling2D( + (3, 3), strides=(2, 2), padding="same", name="block3_pool" + )(x) + x = layers.add([x, residual]) + + residual = layers.Conv2D( + 728, (1, 1), strides=(2, 2), padding="same", use_bias=False + )(x) + residual = layers.BatchNormalization(axis=channel_axis)(residual) + + x = layers.Activation("relu", name="block4_sepconv1_act")(x) + x = layers.SeparableConv2D( + 728, (3, 3), padding="same", use_bias=False, name="block4_sepconv1" + )(x) + x = layers.BatchNormalization(axis=channel_axis, name="block4_sepconv1_bn")( + x + ) + x = layers.Activation("relu", name="block4_sepconv2_act")(x) + x = layers.SeparableConv2D( + 728, (3, 3), padding="same", use_bias=False, name="block4_sepconv2" + )(x) + x = layers.BatchNormalization(axis=channel_axis, name="block4_sepconv2_bn")( + x + ) + + x = layers.MaxPooling2D( + (3, 3), strides=(2, 2), padding="same", name="block4_pool" + )(x) + x = layers.add([x, residual]) + + for i in range(8): + residual = x + prefix = f"block{i + 5}" + + x = layers.Activation("relu", name=f"{prefix}_sepconv1_act")(x) + x = layers.SeparableConv2D( + 728, + (3, 3), + padding="same", + use_bias=False, + name=f"{prefix}_sepconv1", + )(x) + x = layers.BatchNormalization( + axis=channel_axis, name=f"{prefix}_sepconv1_bn" + )(x) + x = layers.Activation("relu", name=f"{prefix}_sepconv2_act")(x) + x = layers.SeparableConv2D( + 728, + (3, 3), + padding="same", + use_bias=False, + name=f"{prefix}_sepconv2", + )(x) + x = layers.BatchNormalization( + axis=channel_axis, name=f"{prefix}_sepconv2_bn" + )(x) + x = layers.Activation("relu", name=f"{prefix}_sepconv3_act")(x) + x = layers.SeparableConv2D( + 728, + (3, 3), + padding="same", + use_bias=False, + name=f"{prefix}_sepconv3", + )(x) + x = layers.BatchNormalization( + axis=channel_axis, name=f"{prefix}_sepconv3_bn" + )(x) + + x = layers.add([x, residual]) + + residual = layers.Conv2D( + 1024, (1, 1), strides=(2, 2), padding="same", use_bias=False + )(x) + residual = layers.BatchNormalization(axis=channel_axis)(residual) + + x = layers.Activation("relu", name="block13_sepconv1_act")(x) + x = layers.SeparableConv2D( + 728, (3, 3), padding="same", use_bias=False, name="block13_sepconv1" + )(x) + x = layers.BatchNormalization( + axis=channel_axis, name="block13_sepconv1_bn" + )(x) + x = layers.Activation("relu", name="block13_sepconv2_act")(x) + x = layers.SeparableConv2D( + 1024, (3, 3), padding="same", use_bias=False, name="block13_sepconv2" + )(x) + x = layers.BatchNormalization( + axis=channel_axis, name="block13_sepconv2_bn" + )(x) + + x = layers.MaxPooling2D( + (3, 3), strides=(2, 2), padding="same", name="block13_pool" + )(x) + x = layers.add([x, residual]) + + x = layers.SeparableConv2D( + 1536, (3, 3), padding="same", use_bias=False, name="block14_sepconv1" + )(x) + x = layers.BatchNormalization( + axis=channel_axis, name="block14_sepconv1_bn" + )(x) + x = layers.Activation("relu", name="block14_sepconv1_act")(x) + + x = layers.SeparableConv2D( + 2048, (3, 3), padding="same", use_bias=False, name="block14_sepconv2" + )(x) + x = layers.BatchNormalization( + axis=channel_axis, name="block14_sepconv2_bn" + )(x) + x = layers.Activation("relu", name="block14_sepconv2_act")(x) + + if include_top: + x = layers.GlobalAveragePooling2D(name="avg_pool")(x) + imagenet_utils.validate_activation(classifier_activation, weights) + x = layers.Dense( + classes, activation=classifier_activation, name="predictions" + )(x) + else: + if pooling == "avg": + x = layers.GlobalAveragePooling2D()(x) + elif pooling == "max": + x = layers.GlobalMaxPooling2D()(x) + + # Ensure that the model takes into account + # any potential predecessors of `input_tensor`. + if input_tensor is not None: + inputs = operation_utils.get_source_inputs(input_tensor) + else: + inputs = img_input + # Create model. + model = Functional(inputs, x, name=name) + + # Load weights. + if weights == "imagenet": + if include_top: + weights_path = file_utils.get_file( + "xception_weights_tf_dim_ordering_tf_kernels.h5", + WEIGHTS_PATH, + cache_subdir="models", + file_hash="0a58e3b7378bc2990ea3b43d5981f1f6", + ) + else: + weights_path = file_utils.get_file( + "xception_weights_tf_dim_ordering_tf_kernels_notop.h5", + WEIGHTS_PATH_NO_TOP, + cache_subdir="models", + file_hash="b0042744bf5b25fce3cb969f33bebb97", + ) + model.load_weights(weights_path) + elif weights is not None: + model.load_weights(weights) + + return model + + +@keras_export("keras.applications.xception.preprocess_input") +def preprocess_input(x, data_format=None): + return imagenet_utils.preprocess_input( + x, data_format=data_format, mode="tf" + ) + + +@keras_export("keras.applications.xception.decode_predictions") +def decode_predictions(preds, top=5): + return imagenet_utils.decode_predictions(preds, top=top) + + +preprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format( + mode="", + ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF, + error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC, +) +decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__ diff --git a/keras/src/backend/__init__.py b/keras/src/backend/__init__.py new file mode 100644 index 000000000000..15f1af2145d5 --- /dev/null +++ b/keras/src/backend/__init__.py @@ -0,0 +1,77 @@ +from keras.src.backend.config import backend + +if backend() == "torch": + # When using the torch backend, + # torch needs to be imported first, otherwise it will segfault + # upon import. + import torch + +from keras.src.api_export import keras_export +from keras.src.backend.common.dtypes import result_type +from keras.src.backend.common.keras_tensor import KerasTensor +from keras.src.backend.common.keras_tensor import any_symbolic_tensors +from keras.src.backend.common.keras_tensor import is_keras_tensor +from keras.src.backend.common.masking import get_keras_mask +from keras.src.backend.common.masking import set_keras_mask +from keras.src.backend.common.stateless_scope import StatelessScope +from keras.src.backend.common.stateless_scope import get_stateless_scope +from keras.src.backend.common.stateless_scope import in_stateless_scope +from keras.src.backend.common.symbolic_scope import SymbolicScope +from keras.src.backend.common.symbolic_scope import in_symbolic_scope +from keras.src.backend.common.variables import AutocastScope +from keras.src.backend.common.variables import Variable +from keras.src.backend.common.variables import get_autocast_scope +from keras.src.backend.common.variables import is_float_dtype +from keras.src.backend.common.variables import is_int_dtype +from keras.src.backend.common.variables import standardize_dtype +from keras.src.backend.common.variables import standardize_shape +from keras.src.backend.config import epsilon +from keras.src.backend.config import floatx +from keras.src.backend.config import image_data_format +from keras.src.backend.config import set_epsilon +from keras.src.backend.config import set_floatx +from keras.src.backend.config import set_image_data_format +from keras.src.backend.config import standardize_data_format + +# Import backend functions. +if backend() == "tensorflow": + from keras.src.backend.tensorflow import * # noqa: F403 + from keras.src.backend.tensorflow.core import Variable as BackendVariable +elif backend() == "jax": + from keras.src.backend.jax import * # noqa: F403 + from keras.src.backend.jax.core import Variable as BackendVariable +elif backend() == "torch": + from keras.src.backend.torch import * # noqa: F403 + from keras.src.backend.torch.core import Variable as BackendVariable + + distribution_lib = None +elif backend() == "numpy": + from keras.src.backend.numpy import * # noqa: F403 + from keras.src.backend.numpy.core import Variable as BackendVariable + + distribution_lib = None +elif backend() == "openvino": + from keras.src.backend.openvino import * # noqa: F403 + from keras.src.backend.openvino.core import Variable as BackendVariable + + distribution_lib = None +else: + raise ValueError(f"Unable to import backend : {backend()}") + + +@keras_export("keras.Variable") +class Variable(BackendVariable): # noqa: F811 + pass + + +backend_name_scope = name_scope # noqa: F405 + + +@keras_export("keras.name_scope") +class name_scope(backend_name_scope): + pass + + +@keras_export("keras.device") +def device(device_name): + return device_scope(device_name) # noqa: F405 diff --git a/keras/src/backend/common/__init__.py b/keras/src/backend/common/__init__.py new file mode 100644 index 000000000000..27ab20a03aec --- /dev/null +++ b/keras/src/backend/common/__init__.py @@ -0,0 +1,10 @@ +from keras.src.backend.common import backend_utils +from keras.src.backend.common.dtypes import result_type +from keras.src.backend.common.variables import AutocastScope +from keras.src.backend.common.variables import Variable as KerasVariable +from keras.src.backend.common.variables import get_autocast_scope +from keras.src.backend.common.variables import is_float_dtype +from keras.src.backend.common.variables import is_int_dtype +from keras.src.backend.common.variables import standardize_dtype +from keras.src.backend.common.variables import standardize_shape +from keras.src.random import random diff --git a/keras/src/backend/common/backend_utils.py b/keras/src/backend/common/backend_utils.py new file mode 100644 index 000000000000..fb809c2cc7b2 --- /dev/null +++ b/keras/src/backend/common/backend_utils.py @@ -0,0 +1,525 @@ +import functools +import operator +import re +import warnings + + +def _convert_conv_transpose_padding_args_from_keras_to_jax( + kernel_size, stride, dilation_rate, padding, output_padding +): + """Convert the padding arguments from Keras to the ones used by JAX. + JAX starts with an shape of size `(input-1) * stride - kernel_size + 2`, + then adds `left_pad` on the left, and `right_pad` on the right. + In Keras, the `padding` argument determines a base shape, to which + `output_padding` is added on the right. If `output_padding` is None, it will + be given a default value. + """ + + assert padding.lower() in {"valid", "same"} + kernel_size = (kernel_size - 1) * dilation_rate + 1 + + if padding.lower() == "valid": + # If output_padding is None, we fill it so that the shape of the output + # is `(input-1)*s + max(kernel_size, stride)` + output_padding = ( + max(kernel_size, stride) - kernel_size + if output_padding is None + else output_padding + ) + left_pad = kernel_size - 1 + right_pad = kernel_size - 1 + output_padding + + else: + if output_padding is None: + # When output_padding is None, we want the shape of the output to + # be `input * s`, therefore a total padding of + # `stride + kernel_size - 2` + pad_len = stride + kernel_size - 2 + else: + # When output_padding is filled, we want the shape of the output to + # be `(input-1)*stride + kernel_size%2 + output_padding` + pad_len = kernel_size + kernel_size % 2 - 2 + output_padding + left_pad = min(pad_len // 2 + pad_len % 2, kernel_size - 1) + right_pad = pad_len - left_pad + + return left_pad, right_pad + + +def _convert_conv_transpose_padding_args_from_keras_to_torch( + kernel_size, stride, dilation_rate, padding, output_padding +): + """Convert the padding arguments from Keras to the ones used by Torch. + Torch starts with an output shape of `(input-1) * stride + kernel_size`, + then removes `torch_padding` from both sides, and adds + `torch_output_padding` on the right. + Because in Torch the output_padding can only be added to the right, + consistency with Tensorflow is not always possible. In particular this is + the case when both the Torch padding and output_padding values are + strictly positive. + """ + assert padding.lower() in {"valid", "same"} + original_kernel_size = kernel_size + kernel_size = (kernel_size - 1) * dilation_rate + 1 + + if padding.lower() == "valid": + # If output_padding is None, we fill it so that the shape of the output + # is `(i-1)*s + max(k, s)` + output_padding = ( + max(kernel_size, stride) - kernel_size + if output_padding is None + else output_padding + ) + torch_padding = 0 + torch_output_padding = output_padding + + else: + # When output_padding is None, we want the shape of the output to be + # `input * s`, otherwise we use the value provided. + output_padding = ( + stride - kernel_size % 2 + if output_padding is None + else output_padding + ) + torch_padding = max( + -((kernel_size % 2 - kernel_size + output_padding) // 2), 0 + ) + torch_output_padding = ( + 2 * torch_padding + kernel_size % 2 - kernel_size + output_padding + ) + + if torch_padding > 0 and torch_output_padding > 0: + warnings.warn( + f"You might experience inconsistencies across backends when " + f"calling conv transpose with kernel_size={original_kernel_size}, " + f"stride={stride}, dilation_rate={dilation_rate}, " + f"padding={padding}, output_padding={output_padding}." + ) + + if torch_output_padding >= stride: + raise ValueError( + f"The padding arguments (padding={padding}) and " + f"output_padding={output_padding}) lead to a Torch " + f"output_padding ({torch_output_padding}) that is greater than " + f"strides ({stride}). This is not supported. You can change the " + f"padding arguments, kernel or stride, or run on another backend. " + ) + + return torch_padding, torch_output_padding + + +def compute_conv_transpose_padding_args_for_jax( + input_shape, + kernel_shape, + strides, + padding, + output_padding, + dilation_rate, +): + num_spatial_dims = len(input_shape) - 2 + kernel_spatial_shape = kernel_shape[:-2] + + jax_padding = [] + for i in range(num_spatial_dims): + output_padding_i = ( + output_padding + if output_padding is None or isinstance(output_padding, int) + else output_padding[i] + ) + strides_i = strides if isinstance(strides, int) else strides[i] + dilation_rate_i = ( + dilation_rate + if isinstance(dilation_rate, int) + else dilation_rate[i] + ) + ( + pad_left, + pad_right, + ) = _convert_conv_transpose_padding_args_from_keras_to_jax( + kernel_size=kernel_spatial_shape[i], + stride=strides_i, + dilation_rate=dilation_rate_i, + padding=padding, + output_padding=output_padding_i, + ) + jax_padding.append((pad_left, pad_right)) + + return jax_padding + + +def compute_conv_transpose_padding_args_for_torch( + input_shape, + kernel_shape, + strides, + padding, + output_padding, + dilation_rate, +): + num_spatial_dims = len(input_shape) - 2 + kernel_spatial_shape = kernel_shape[:-2] + + torch_paddings = [] + torch_output_paddings = [] + for i in range(num_spatial_dims): + output_padding_i = ( + output_padding + if output_padding is None or isinstance(output_padding, int) + else output_padding[i] + ) + strides_i = strides if isinstance(strides, int) else strides[i] + dilation_rate_i = ( + dilation_rate + if isinstance(dilation_rate, int) + else dilation_rate[i] + ) + ( + torch_padding, + torch_output_padding, + ) = _convert_conv_transpose_padding_args_from_keras_to_torch( + kernel_size=kernel_spatial_shape[i], + stride=strides_i, + dilation_rate=dilation_rate_i, + padding=padding, + output_padding=output_padding_i, + ) + torch_paddings.append(torch_padding) + torch_output_paddings.append(torch_output_padding) + + return torch_paddings, torch_output_paddings + + +def _get_output_shape_given_tf_padding( + input_size, kernel_size, strides, padding, output_padding, dilation_rate +): + if input_size is None: + return None + + assert padding.lower() in {"valid", "same"} + + kernel_size = (kernel_size - 1) * dilation_rate + 1 + + if padding.lower() == "valid": + output_padding = ( + max(kernel_size, strides) - kernel_size + if output_padding is None + else output_padding + ) + return (input_size - 1) * strides + kernel_size + output_padding + + else: + if output_padding is None: + return input_size * strides + else: + return (input_size - 1) * strides + kernel_size % 2 + output_padding + + +def compute_conv_transpose_output_shape( + input_shape, + kernel_size, + filters, + strides, + padding, + output_padding=None, + data_format="channels_last", + dilation_rate=1, +): + num_spatial_dims = len(input_shape) - 2 + kernel_spatial_shape = kernel_size + + if isinstance(output_padding, int): + output_padding = (output_padding,) * len(kernel_spatial_shape) + if isinstance(strides, int): + strides = (strides,) * num_spatial_dims + if isinstance(dilation_rate, int): + dilation_rate = (dilation_rate,) * num_spatial_dims + + if data_format == "channels_last": + input_spatial_shape = input_shape[1:-1] + else: + input_spatial_shape = input_shape[2:] + + output_shape = [] + for i in range(num_spatial_dims): + current_output_padding = ( + None if output_padding is None else output_padding[i] + ) + + shape_i = _get_output_shape_given_tf_padding( + input_size=input_spatial_shape[i], + kernel_size=kernel_spatial_shape[i], + strides=strides[i], + padding=padding, + output_padding=current_output_padding, + dilation_rate=dilation_rate[i], + ) + output_shape.append(shape_i) + + if data_format == "channels_last": + output_shape = [input_shape[0]] + output_shape + [filters] + else: + output_shape = [input_shape[0], filters] + output_shape + return output_shape + + +def canonicalize_axis(axis, num_dims): + """Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims).""" + axis = operator.index(axis) + if not -num_dims <= axis < num_dims: + raise ValueError( + f"axis {axis} is out of bounds for an array with dimension " + f"{num_dims}." + ) + if axis < 0: + axis = axis + num_dims + return axis + + +def standardize_axis_for_numpy(axis): + """Standardize an axis to a tuple if it is a list in the numpy backend.""" + return tuple(axis) if isinstance(axis, list) else axis + + +def to_tuple_or_list(value): + """Convert the non-`None` value to either a tuple or a list.""" + if value is None: + return value + if not isinstance(value, (int, tuple, list)): + raise ValueError( + "`value` must be an integer, tuple or list. " + f"Received: value={value}" + ) + if isinstance(value, int): + return (value,) + return value + + +### Code for ops.vectorize() used for TF and torch backends. + +# See http://docs.scipy.org/doc/numpy/reference/c-api.generalized-ufuncs.html +_DIMENSION_NAME = r"\w+" +_CORE_DIMENSION_LIST = "(?:{0:}(?:,{0:})*)?".format(_DIMENSION_NAME) +_ARGUMENT = rf"\({_CORE_DIMENSION_LIST}\)" +_ARGUMENT_LIST = "{0:}(?:,{0:})*".format(_ARGUMENT) +_SIGNATURE = "^{0:}->{0:}$".format(_ARGUMENT_LIST) + + +def _vectorize_parse_gufunc_signature( + signature, +): + if not re.match(_SIGNATURE, signature): + raise ValueError(f"not a valid gufunc signature: {signature}") + args, retvals = ( + [ + tuple(re.findall(_DIMENSION_NAME, arg)) + for arg in re.findall(_ARGUMENT, arg_list) + ] + for arg_list in signature.split("->") + ) + return args, retvals + + +def _vectorize_update_dim_sizes(dim_sizes, shape, core_dims, is_input=True): + num_core_dims = len(core_dims) + if is_input: + if len(shape) < num_core_dims: + raise ValueError( + f"input with shape {shape} does not " + "have enough dimensions for all core " + f"dimensions {core_dims}" + ) + else: + if len(shape) != num_core_dims: + raise ValueError( + f"output shape {shape} does not " + f"match core dimensions {core_dims}" + ) + + core_shape = shape[-num_core_dims:] if core_dims else () + for dim, size in zip(core_dims, core_shape): + if dim not in dim_sizes: + dim_sizes[dim] = size + elif size != dim_sizes[dim]: + raise ValueError( + f"inconsistent size for core dimension {dim}: " + f"{size} vs {dim_sizes[dim]}" + ) + + +def _vectorize_parse_input_dimensions( + args, + input_core_dims, +): + from keras.src import ops + + if len(args) != len(input_core_dims): + raise TypeError( + "wrong number of positional arguments: " + f"expected {len(input_core_dims)}, got {len(args)}" + ) + shapes = [] + dim_sizes = {} + for arg, core_dims in zip(args, input_core_dims): + _vectorize_update_dim_sizes( + dim_sizes, arg.shape, core_dims, is_input=True + ) + ndim = arg.ndim - len(core_dims) + shapes.append(arg.shape[:ndim]) + broadcast_shape = shapes[0] + for s in shapes: + broadcast_shape = ops.broadcast_shapes(broadcast_shape, s) + return broadcast_shape, dim_sizes + + +def _vectorize_check_output_dims( + func, + dim_sizes, + expected_output_core_dims, +): + from keras.src import ops + + def wrapped(*args): + out = func(*args) + if isinstance(out, (list, tuple)): + out_shapes = [ops.shape(x) for x in out] + else: + out_shapes = [out.shape] + + if expected_output_core_dims is None: + output_core_dims = [()] * len(out_shapes) + else: + output_core_dims = expected_output_core_dims + if len(output_core_dims) > 1 and not isinstance(out, tuple): + raise TypeError( + "output must be a tuple when multiple outputs " + f"are expected, got: {out}" + ) + if len(out_shapes) != len(output_core_dims): + raise TypeError( + "wrong number of output arguments: " + f"expected {len(output_core_dims)}, got {len(out_shapes)}" + ) + + sizes = dict(dim_sizes) + for shape, core_dims in zip(out_shapes, output_core_dims): + _vectorize_update_dim_sizes(sizes, shape, core_dims, is_input=False) + + return out + + return wrapped + + +def _vectorize_apply_excluded(func, excluded, args, kwargs): + if not excluded: + return func, args, kwargs + + dynamic_args = [arg for i, arg in enumerate(args) if i not in excluded] + dynamic_kwargs = { + key: val for key, val in kwargs.items() if key not in excluded + } + static_args = [ + (i, args[i]) + for i in sorted(e for e in excluded if isinstance(e, int)) + if i < len(args) + ] + static_kwargs = {key: val for key, val in kwargs.items() if key in excluded} + + def new_func(*args, **kwargs): + args = list(args) + for i, arg in static_args: + args.insert(i, arg) + return func(*args, **kwargs, **static_kwargs) + + return new_func, dynamic_args, dynamic_kwargs + + +def vectorize_impl(pyfunc, vmap_fn, *, excluded=None, signature=None): + """Implementation adapted from JAX and NumPy.""" + + from keras.src import ops + + excluded = None or set() + + @functools.wraps(pyfunc) + def wrapped(*args, **kwargs): + excluded_func, args, kwargs = _vectorize_apply_excluded( + pyfunc, excluded, args, kwargs + ) + + if signature is not None: + input_core_dims, output_core_dims = ( + _vectorize_parse_gufunc_signature(signature) + ) + else: + input_core_dims = [()] * len(args) + output_core_dims = None + + none_args = {i for i, arg in enumerate(args) if arg is None} + if any(none_args): + if any(input_core_dims[i] != () for i in none_args): + raise ValueError( + f"Cannot pass None at locations {none_args} " + f"with signature={signature}" + ) + excluded_func, args, _ = _vectorize_apply_excluded( + excluded_func, none_args, args, {} + ) + input_core_dims = [ + dim + for i, dim in enumerate(input_core_dims) + if i not in none_args + ] + + args = tuple(map(ops.convert_to_tensor, args)) + + broadcast_shape, dim_sizes = _vectorize_parse_input_dimensions( + args, input_core_dims + ) + checked_func = _vectorize_check_output_dims( + excluded_func, dim_sizes, output_core_dims + ) + squeezed_args = [] + rev_filled_shapes = [] + for arg, core_dims in zip(args, input_core_dims): + noncore_shape = arg.shape[: arg.ndim - len(core_dims)] + + pad_ndim = len(broadcast_shape) - len(noncore_shape) + filled_shape = pad_ndim * (1,) + noncore_shape + rev_filled_shapes.append(filled_shape[::-1]) + + squeeze_indices = tuple( + i for i, size in enumerate(noncore_shape) if size == 1 + ) + squeezed_arg = ops.squeeze(arg, axis=squeeze_indices) + squeezed_args.append(squeezed_arg) + + vectorized_func = checked_func + dims_to_expand = [] + for negdim, axis_sizes in enumerate(zip(*rev_filled_shapes)): + in_axes = tuple(None if size == 1 else 0 for size in axis_sizes) + if all(axis is None for axis in in_axes): + dims_to_expand.append(len(broadcast_shape) - 1 - negdim) + else: + vectorized_func = vmap_fn(vectorized_func, in_axes) + result = vectorized_func(*squeezed_args) + + if not dims_to_expand: + return result + elif isinstance(result, tuple): + return tuple( + ops.expand_dims(r, axis=dims_to_expand) for r in result + ) + else: + return ops.expand_dims(result, axis=dims_to_expand) + + return wrapped + + +def slice_along_axis(x, start=0, stop=None, step=1, axis=0): + """Slice a Tensor along the given axis.""" + # Ref: same util function defined in tfp.math.scan_associative + if axis >= 0: + slices = [slice(None)] * axis + [slice(start, stop, step)] + else: + slices = [Ellipsis, slice(start, stop, step)] + [slice(None)] * ( + -1 - axis + ) + return x[tuple(slices)] diff --git a/keras/src/backend/common/backend_utils_test.py b/keras/src/backend/common/backend_utils_test.py new file mode 100644 index 000000000000..deea5fc17267 --- /dev/null +++ b/keras/src/backend/common/backend_utils_test.py @@ -0,0 +1,235 @@ +from keras.src.backend.common.backend_utils import ( + _convert_conv_transpose_padding_args_from_keras_to_jax, +) +from keras.src.backend.common.backend_utils import ( + _convert_conv_transpose_padding_args_from_keras_to_torch, +) +from keras.src.backend.common.backend_utils import ( + _get_output_shape_given_tf_padding, +) +from keras.src.backend.common.backend_utils import ( + compute_conv_transpose_padding_args_for_jax, +) +from keras.src.backend.common.backend_utils import ( + compute_conv_transpose_padding_args_for_torch, +) +from keras.src.testing import test_case + + +class ConvertConvTransposePaddingArgsJAXTest(test_case.TestCase): + def test_valid_padding_without_output_padding(self): + """Test conversion with 'valid' padding and no output padding""" + ( + left_pad, + right_pad, + ) = _convert_conv_transpose_padding_args_from_keras_to_jax( + kernel_size=3, + stride=2, + dilation_rate=1, + padding="valid", + output_padding=None, + ) + self.assertEqual(left_pad, 2) + self.assertEqual(right_pad, 2) + + def test_same_padding_without_output_padding(self): + """Test conversion with 'same' padding and no output padding.""" + ( + left_pad, + right_pad, + ) = _convert_conv_transpose_padding_args_from_keras_to_jax( + kernel_size=3, + stride=2, + dilation_rate=1, + padding="same", + output_padding=None, + ) + self.assertEqual(left_pad, 2) + self.assertEqual(right_pad, 1) + + +class ConvertConvTransposePaddingArgsTorchTest(test_case.TestCase): + def test_valid_padding_without_output_padding(self): + """Test conversion with 'valid' padding and no output padding""" + ( + torch_padding, + torch_output_padding, + ) = _convert_conv_transpose_padding_args_from_keras_to_torch( + kernel_size=3, + stride=2, + dilation_rate=1, + padding="valid", + output_padding=None, + ) + self.assertEqual(torch_padding, 0) + self.assertEqual(torch_output_padding, 0) + + def test_same_padding_without_output_padding(self): + """Test conversion with 'same' padding and no output padding""" + ( + torch_padding, + torch_output_padding, + ) = _convert_conv_transpose_padding_args_from_keras_to_torch( + kernel_size=3, + stride=2, + dilation_rate=1, + padding="same", + output_padding=None, + ) + self.assertEqual(torch_padding, 1) + self.assertEqual(torch_output_padding, 1) + + +class ComputeConvTransposePaddingArgsForJAXTest(test_case.TestCase): + def test_valid_padding_without_output_padding(self): + """Test computation with 'valid' padding and no output padding""" + jax_padding = compute_conv_transpose_padding_args_for_jax( + input_shape=(1, 5, 5, 3), + kernel_shape=(3, 3, 3, 3), + strides=2, + padding="valid", + output_padding=None, + dilation_rate=1, + ) + self.assertEqual(jax_padding, [(2, 2), (2, 2)]) + + def test_same_padding_without_output_padding(self): + """Test computation with 'same' padding and no output padding""" + jax_padding = compute_conv_transpose_padding_args_for_jax( + input_shape=(1, 5, 5, 3), + kernel_shape=(3, 3, 3, 3), + strides=2, + padding="same", + output_padding=None, + dilation_rate=1, + ) + + self.assertEqual(jax_padding, [(2, 1), (2, 1)]) + + +class ComputeConvTransposePaddingArgsForTorchTest(test_case.TestCase): + def test_valid_padding_without_output_padding(self): + """Test computation with 'valid' padding and no output padding""" + ( + torch_paddings, + torch_output_paddings, + ) = compute_conv_transpose_padding_args_for_torch( + input_shape=(1, 5, 5, 3), + kernel_shape=(3, 3, 3, 3), + strides=2, + padding="valid", + output_padding=None, + dilation_rate=1, + ) + self.assertEqual(torch_paddings, [0, 0]) + self.assertEqual(torch_output_paddings, [0, 0]) + + def test_same_padding_without_output_padding(self): + """Test computation with 'same' padding and no output padding""" + ( + torch_paddings, + torch_output_paddings, + ) = compute_conv_transpose_padding_args_for_torch( + input_shape=(1, 5, 5, 3), + kernel_shape=(3, 3, 3, 3), + strides=2, + padding="same", + output_padding=None, + dilation_rate=1, + ) + self.assertEqual(torch_paddings, [1, 1]) + self.assertEqual(torch_output_paddings, [1, 1]) + + def test_valid_padding_with_none_output_padding(self): + """Test conversion with 'valid' padding and no output padding""" + ( + torch_padding, + torch_output_padding, + ) = _convert_conv_transpose_padding_args_from_keras_to_torch( + kernel_size=3, + stride=2, + dilation_rate=1, + padding="valid", + output_padding=None, + ) + self.assertEqual(torch_padding, 0) + self.assertEqual(torch_output_padding, 0) + + def test_valid_padding_with_output_padding(self): + """Test conversion with 'valid' padding and output padding for Torch.""" + ( + torch_padding, + torch_output_padding, + ) = _convert_conv_transpose_padding_args_from_keras_to_torch( + kernel_size=3, + stride=2, + dilation_rate=1, + padding="valid", + output_padding=1, + ) + self.assertEqual(torch_padding, 0) + self.assertEqual(torch_output_padding, 1) + + +class GetOutputShapeGivenTFPaddingTest(test_case.TestCase): + def test_valid_padding_without_output_padding(self): + """Test computation with 'valid' padding and no output padding.""" + output_shape = _get_output_shape_given_tf_padding( + input_size=5, + kernel_size=3, + strides=2, + padding="valid", + output_padding=None, + dilation_rate=1, + ) + self.assertEqual(output_shape, 11) + + def test_same_padding_without_output_padding(self): + """Test computation with 'same' padding and no output padding.""" + output_shape = _get_output_shape_given_tf_padding( + input_size=5, + kernel_size=3, + strides=2, + padding="same", + output_padding=None, + dilation_rate=1, + ) + self.assertEqual(output_shape, 10) + + def test_valid_padding_with_output_padding(self): + """Test computation with 'valid' padding and output padding.""" + output_shape = _get_output_shape_given_tf_padding( + input_size=5, + kernel_size=3, + strides=2, + padding="valid", + output_padding=1, + dilation_rate=1, + ) + self.assertEqual(output_shape, 12) + + def test_warning_for_inconsistencies(self): + """Test that a warning is raised for potential inconsistencies""" + with self.assertWarns(Warning): + _convert_conv_transpose_padding_args_from_keras_to_torch( + kernel_size=3, + stride=2, + dilation_rate=1, + padding="same", + output_padding=1, + ) + + def test_same_padding_without_output_padding_for_torch_(self): + """Test conversion with 'same' padding and no output padding.""" + ( + torch_padding, + torch_output_padding, + ) = _convert_conv_transpose_padding_args_from_keras_to_torch( + kernel_size=3, + stride=2, + dilation_rate=1, + padding="same", + output_padding=None, + ) + self.assertEqual(torch_padding, max(-((3 % 2 - 3) // 2), 0)) + self.assertEqual(torch_output_padding, 1) diff --git a/keras/src/backend/common/compute_output_spec_test.py b/keras/src/backend/common/compute_output_spec_test.py new file mode 100644 index 000000000000..8ee856f4d31b --- /dev/null +++ b/keras/src/backend/common/compute_output_spec_test.py @@ -0,0 +1,64 @@ +import pytest + +from keras.src import backend +from keras.src import testing + + +def example_fn(x): + x = (x + 2) * backend.numpy.ones_like(x) + x = backend.numpy.stack([x, x], axis=-1) + return x + + +class ComputeOutputSpecTest(testing.TestCase): + def test_basics(self): + out = backend.compute_output_spec( + example_fn, backend.KerasTensor((2, 3)) + ) + self.assertIsInstance(out, backend.KerasTensor) + self.assertEqual(out.shape, (2, 3, 2)) + + out = backend.compute_output_spec( + example_fn, backend.KerasTensor((None, 3)) + ) + self.assertIsInstance(out, backend.KerasTensor) + self.assertEqual(out.shape, (None, 3, 2)) + + out = backend.compute_output_spec( + example_fn, backend.KerasTensor((2, None)) + ) + self.assertIsInstance(out, backend.KerasTensor) + self.assertEqual(out.shape, (2, None, 2)) + + @pytest.mark.skipif( + backend.backend() != "torch", reason="Only applicable for torch" + ) + def test_torch_meta_device_incompatible_ops(self): + class Container: + def __init__(self): + self.canary = False + + def example_meta_fn(self, x): + y = backend.numpy.ones(x.shape) + if str(y.device) == "meta": + self.canary = True + raise ValueError("Erroring out on meta device") + x = (x + 2) * y + x = backend.numpy.stack([x, x], axis=-1) + return x + + instance = Container() + out = backend.compute_output_spec( + instance.example_meta_fn, backend.KerasTensor((2, 3)) + ) + self.assertIsInstance(out, backend.KerasTensor) + self.assertTrue(instance.canary) + self.assertEqual(out.shape, (2, 3, 2)) + + instance = Container() + out = backend.compute_output_spec( + instance.example_meta_fn, backend.KerasTensor((2, None)) + ) + self.assertIsInstance(out, backend.KerasTensor) + self.assertTrue(instance.canary) + self.assertEqual(out.shape, (2, None, 2)) diff --git a/keras/src/backend/common/dtypes.py b/keras/src/backend/common/dtypes.py new file mode 100644 index 000000000000..9fcb7b15357a --- /dev/null +++ b/keras/src/backend/common/dtypes.py @@ -0,0 +1,324 @@ +import functools + +from keras.src.api_export import keras_export +from keras.src.backend import config +from keras.src.backend.common.variables import standardize_dtype + +BOOL_TYPES = ("bool",) +INT_TYPES = ( + "uint8", + "uint16", + "uint32", + "uint64", + "int8", + "int16", + "int32", + "int64", +) +FLOAT_TYPES = ("bfloat16", "float16", "float32", "float64") +WEAK_TYPES = ("int", "float") +COMPLEX_TYPES = ("complex64", "complex128") +# We need to separate float8 from float because there are no implicit +# conversions from float8 dtypes to other dtypes. +# Ref: https://github.com/google/jax/issues/16705 +FLOAT8_TYPES = ("float8_e4m3fn", "float8_e5m2") + +# All supported dtypes in Keras +ALLOWED_DTYPES = ( + "float16", + "float32", + "float64", + "uint8", + "uint16", + "uint32", + "uint64", + "int8", + "int16", + "int32", + "int64", + "bfloat16", + "bool", + "string", + "float8_e4m3fn", + "float8_e5m2", + "complex64", + "complex128", +) +PYTHON_DTYPES_MAP = { + bool: "bool", + int: "int64" if config.backend() == "tensorflow" else "int32", + float: "float32", + str: "string", + # special case for string value + "int": "int64" if config.backend() == "tensorflow" else "int32", + complex: "complex128" if config.backend() == "tensorflow" else "complex64", +} + +# We adapted the type promotion lattice from JAX. Ref: +# https://github.com/google/jax/blob/main/jax/_src/dtypes.py + + +def _type_promotion_lattice(): + """ + Return the type promotion lattice in the form of a DAG. + This DAG maps each type to its immediately higher type on the lattice. + """ + (b1,) = BOOL_TYPES + (u1, u2, u4, u8, i1, i2, i4, i8) = INT_TYPES + bf, f2, f4, f8 = FLOAT_TYPES + i_, f_ = WEAK_TYPES + c64, c128 = COMPLEX_TYPES + out = { + b1: [i_], + u1: [i2, u2], + u2: [i4, u4], + u4: [i8, u8], + u8: [f_], + i_: [u1, i1, c64], + i1: [i2], + i2: [i4], + i4: [i8], + i8: [f_], + f_: [bf, f2], + bf: [f4], + f2: [f4], + f4: [f8, c64], + f8: [c128], + c64: [c128], + c128: [], + } + return out + + +def _make_lattice_upper_bounds(): + lattice = _type_promotion_lattice() + upper_bounds = {node: {node} for node in lattice} + for n in lattice: + while True: + new_upper_bounds = set().union( + *(lattice[b] for b in upper_bounds[n]) + ) + if n in new_upper_bounds: + raise ValueError( + f"cycle detected in type promotion lattice for node {n}" + ) + if new_upper_bounds.issubset(upper_bounds[n]): + break + upper_bounds[n] |= new_upper_bounds + return upper_bounds + + +LATTICE_UPPER_BOUNDS = _make_lattice_upper_bounds() + + +@functools.lru_cache(512) +def _least_upper_bound(*nodes): + """Compute the least upper bound of a set of nodes. + + Args: + nodes: sequence of entries from dtypes + weak_types + + Returns: + The type representing the least upper bound of the input nodes on the + promotion lattice. + """ + # This function computes the least upper bound of a set of nodes N within a + # partially ordered set defined by the lattice generated above. + # Given a partially ordered set S, let the set of upper bounds of n ∈ S be + # UB(n) ≡ {m ∈ S | n ≤ m} + # Further, for a set of nodes N ⊆ S, let the set of common upper bounds be + # given by + # CUB(N) ≡ {a ∈ S | ∀ b ∈ N: a ∈ UB(b)} + # Then the least upper bound of N is defined as + # LUB(N) ≡ {c ∈ CUB(N) | ∀ d ∈ CUB(N), c ≤ d} + # The definition of an upper bound implies that + # c ≤ d if and only if d ∈ UB(c), + # so the LUB can be expressed: + # LUB(N) = {c ∈ CUB(N) | ∀ d ∈ CUB(N): d ∈ UB(c)} + # or, equivalently: + # LUB(N) = {c ∈ CUB(N) | CUB(N) ⊆ UB(c)} + # By definition, LUB(N) has a cardinality of 1 for a partially ordered set. + # Note a potential algorithmic shortcut: from the definition of CUB(N), + # we have + # ∀ c ∈ N: CUB(N) ⊆ UB(c) + # So if N ∩ CUB(N) is nonempty, if follows that LUB(N) = N ∩ CUB(N). + N = set(nodes) + UB = LATTICE_UPPER_BOUNDS + try: + bounds = [UB[n] for n in N] + except KeyError: + dtype = next(n for n in N if n not in UB) + raise ValueError( + f"{dtype=} is not a valid dtype for Keras type promotion." + ) + CUB = set.intersection(*bounds) + LUB = (CUB & N) or {c for c in CUB if CUB.issubset(UB[c])} + if len(LUB) == 1: + return LUB.pop() + elif len(LUB) == 0: + msg = ( + f"Input dtypes {tuple(str(n) for n in nodes)} have no available " + "implicit dtype promotion path. Try explicitly casting inputs to " + "the desired output type." + ) + raise ValueError(msg) + else: + # If we get here, it means the lattice is ill-formed. + raise ValueError( + f"Internal Type Promotion error: {nodes} do not have a unique " + f"least upper bound on the specified lattice; options are {LUB}. " + "This is an unexpected error in Keras's internal logic; " + "please report it to the maintainers." + ) + + +def _dtype_and_weaktype(value): + """Return a (dtype, weak_type) tuple for the given input.""" + is_weak_type = False + if value is int or value is float: + # Note that we can't use `value in [int, float]` because the dtype + # might be equal to python scalar types. + # e.g, tf.float32 == float is True + is_weak_type = True + return standardize_dtype(value), is_weak_type + + +@functools.lru_cache(maxsize=None) +def _respect_weak_type(dtype, weak_type): + """Return the weak dtype of `dtype` if `weak_type==True`.""" + if weak_type: + if dtype == "bool": + return dtype + elif "float" in dtype: + return "float" + elif "int" in dtype: + return "int" + elif "complex" in dtype: + return "complex" + else: + raise ValueError( + "Invalid value for argument `dtype`. Expected one of " + f"{ALLOWED_DTYPES}. Received: dtype={dtype}" + ) + return dtype + + +@functools.lru_cache(maxsize=None) +def _resolve_weak_type(dtype, precision="32"): + """Resolve weak type by the precision of `backend.floatx()`.""" + extended_allowed_dtypes = set(ALLOWED_DTYPES).union(WEAK_TYPES) + if dtype not in extended_allowed_dtypes: + raise ValueError( + "Invalid value for argument `dtype`. Expected one of " + f"{extended_allowed_dtypes}. Received: dtype={dtype}" + ) + if precision not in ["16", "32", "64"]: + raise ValueError( + f"Invalid value for argument `precision`. Expected one of " + f"('16', '32', '64'). Received: precision={precision}" + ) + if dtype == "bfloat16": # special case for bfloat16 + dtype_indicator = "f" + else: + dtype_indicator = dtype[:1] + + if dtype_indicator == "b": + return "bool" + elif dtype_indicator == "i": + return f"int{precision}" + elif dtype_indicator == "u": + return f"uint{precision}" + else: + return f"float{precision}" + + +BIT64_TO_BIT32_DTYPE = { + # Since TF variables require int64 to be placed on the GPU, we exclusively + # enable the int64 dtype for TF. + "int64": "int64" if config.backend() == "tensorflow" else "int32", + "uint64": "uint32", + "float64": "float64" if config.backend() == "tensorflow" else "float32", + "complex128": "complex64", +} + + +def _lattice_result_type(*args): + dtypes, weak_types = zip(*(_dtype_and_weaktype(arg) for arg in args)) + if len(dtypes) == 1: + out_dtype = dtypes[0] + out_weak_type = weak_types[0] + elif len(set(dtypes)) == 1 and not all(weak_types): + # Trivial promotion case. This allows extended dtypes through. + out_dtype = dtypes[0] + out_weak_type = False + elif all(weak_types): + # If all inputs are weakly typed, we compute the bound of the + # strongly-typed counterparts and apply the weak type at the end. This + # avoids returning the incorrect result with non-canonical weak types + # (e.g. weak int16). + out_dtype = _least_upper_bound( + *{_respect_weak_type(d, False) for d in dtypes} + ) + out_weak_type = True + else: + out_dtype = _least_upper_bound( + *{_respect_weak_type(d, w) for d, w in zip(dtypes, weak_types)} + ) + out_weak_type = any(out_dtype is t for t in WEAK_TYPES) + + out_weak_type = (out_dtype != "bool") and out_weak_type + precision = config.floatx()[-2:] + if out_weak_type: + out_dtype = _resolve_weak_type(out_dtype, precision=precision) + + # Force to be 32-bit dtype when encountering 64-bit dtype. This is to + # be aligned with JAX's default behavior. + out_dtype = BIT64_TO_BIT32_DTYPE.get(out_dtype, out_dtype) + return out_dtype + + +@keras_export("keras.backend.result_type") +def result_type(*dtypes): + """Returns the type from applying the Keras type promotion rules. + + In general, each argument is first parsed by `backend.standardize_dtype`, + and the resulting dtype is determined by the least upper bound of the type + promotion lattice. + + Note: This function attempts to match the result of `jnp.result_type`. + + Args: + dtypes: Input dtypes. + + Returns: + The result dtype. + + Examples: + + >>> x = keras.ops.ones((1,), dtype="bfloat16") + >>> keras.backend.result_type(x.dtype, int) + "bfloat16" + + >>> x = keras.ops.ones((1,), dtype="int32") + >>> y = keras.ops.ones((1,), dtype="float32") + >>> keras.backend.result_type(x.dtype, y.dtype) + "float32" + + >>> z= keras.ops.ones((1,), dtype='complex64') + >>> keras.backend.result_type(z.dtype, int) + "float64" + + """ + if len(dtypes) == 0: + # If no dtypes provided, default to floatx, this matches + # `ops.convert_to_tensor([])` + return config.floatx() + for dtype in dtypes: + if dtype in FLOAT8_TYPES: + raise ValueError( + "There is no implicit conversions from float8 dtypes to others." + f" You must cast it internally. Received: {dtypes}" + ) + return _lattice_result_type( + *(config.floatx() if arg is None else arg for arg in dtypes), + ) diff --git a/keras/src/backend/common/dtypes_test.py b/keras/src/backend/common/dtypes_test.py new file mode 100644 index 000000000000..a113992b9458 --- /dev/null +++ b/keras/src/backend/common/dtypes_test.py @@ -0,0 +1,289 @@ +from unittest.mock import patch + +import pytest +from absl.testing import parameterized + +from keras.src import backend +from keras.src import ops +from keras.src.backend.common import dtypes +from keras.src.testing import test_case +from keras.src.testing.test_utils import named_product + + +class DtypesTest(test_case.TestCase): + """Test the dtype to verify that the behavior matches JAX.""" + + ALL_DTYPES = [ + x + for x in dtypes.ALLOWED_DTYPES + if x + not in ( + "string", + "complex128", + "float64", + "uint64", + "int64", + ) + + dtypes.FLOAT8_TYPES # Remove float8 dtypes for the following tests + ] + [None] + if backend.backend() == "torch": + ALL_DTYPES = [x for x in ALL_DTYPES if x not in ("uint16", "uint32")] + elif backend.backend() == "tensorflow": + # TODO(hongyu): Re-enable uint32 tests once we determine how to handle + # dtypes.result_type(uint32, int*) -> int64 promotion. + # Since TF variables require int64 to be placed on the GPU, we + # exclusively enable the int64 dtype for TF. However, JAX does not + # natively support int64, which prevents us from comparing the dtypes. + ALL_DTYPES = [x for x in ALL_DTYPES if x not in ("uint32",)] + elif backend.backend() == "openvino": + ALL_DTYPES = [x for x in ALL_DTYPES if x not in ("complex64",)] + + @parameterized.named_parameters( + named_product(dtype1=ALL_DTYPES, dtype2=[bool, int, float]) + ) + def test_result_type_with_python_scalar_types(self, dtype1, dtype2): + import jax.numpy as jnp + + out = backend.result_type(dtype1, dtype2) + expected = jnp.result_type(dtype1, dtype2).name + self.assertEqual(out, expected) + + @parameterized.named_parameters( + named_product(dtype1=ALL_DTYPES, dtype2=ALL_DTYPES) + ) + def test_result_type_with_tensor(self, dtype1, dtype2): + import jax.numpy as jnp + + x1 = ops.ones((1,), dtype=dtype1) + x2 = ops.ones((1,), dtype=dtype2) + x1_jax = jnp.ones((1,), dtype=dtype1) + x2_jax = jnp.ones((1,), dtype=dtype2) + + out = backend.result_type(x1.dtype, x2.dtype) + expected = jnp.result_type(x1_jax, x2_jax).name + self.assertEqual(out, expected) + + @parameterized.named_parameters( + named_product( + dtype=[ + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + ] + ) + ) + @pytest.mark.skipif( + backend.backend() != "tensorflow", reason="TensorFlow only" + ) + def test_result_type_with_int64(self, dtype): + # https://github.com/keras-team/keras/issues/21677 + x1 = ops.ones((1,), dtype="int64") + x2 = ops.ones((1,), dtype=dtype) + out = backend.result_type(x1.dtype, x2.dtype) + self.assertEqual(out, "int64") + + @parameterized.named_parameters( + named_product( + dtype=[ + "float16", + "bfloat16", + "float32", + "float64", + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + ] + ) + ) + @pytest.mark.skipif( + backend.backend() != "tensorflow", reason="TensorFlow only" + ) + def test_result_type_with_float64(self, dtype): + # Float types have a similar issue as int64 in TF.: + # https://github.com/keras-team/keras/issues/21677 + x1 = ops.ones((1,), dtype="float64") + x2 = ops.ones((1,), dtype=dtype) + out = backend.result_type(x1.dtype, x2.dtype) + self.assertEqual(out, "float64") + + def test_result_type_with_none(self): + import jax.numpy as jnp + + self.assertEqual(backend.result_type(None), jnp.result_type(None).name) + + def test_result_type_empty_list(self): + self.assertEqual(backend.result_type(), "float32") + + def test_respect_weak_type_for_bool(self): + self.assertEqual(dtypes._respect_weak_type("bool", True), "bool") + + def test_respect_weak_type_for_int(self): + self.assertEqual(dtypes._respect_weak_type("int32", True), "int") + + def test_respect_weak_type_for_float(self): + self.assertEqual(dtypes._respect_weak_type("float32", True), "float") + + def test_resolve_weak_type_for_bfloat16(self): + self.assertEqual(dtypes._resolve_weak_type("bfloat16"), "float32") + + def test_resolve_weak_type_for_bfloat16_with_precision(self): + self.assertEqual( + dtypes._resolve_weak_type("bfloat16", precision="64"), "float64" + ) + + def test_respect_weak_type_for_complex64(self): + self.assertAllEqual( + dtypes._respect_weak_type("complex64", True), "complex" + ) + + def test_respect_weak_type_for_complex128(self): + self.assertAllEqual( + dtypes._respect_weak_type("complex128", True), "complex" + ) + + def test_invalid_dtype_for_keras_promotion(self): + with self.assertRaisesRegex( + ValueError, "is not a valid dtype for Keras type promotion." + ): + dtypes._least_upper_bound("invalid_dtype") + + def test_resolve_weak_type_for_invalid_dtype(self): + with self.assertRaisesRegex( + ValueError, "Invalid value for argument `dtype`. Expected one of" + ): + dtypes._resolve_weak_type("invalid_dtype") + + def test_resolve_weak_type_for_invalid_precision(self): + with self.assertRaisesRegex( + ValueError, + "Invalid value for argument `precision`. Expected one of", + ): + dtypes._resolve_weak_type("int32", precision="invalid_precision") + + def test_cycle_detection_in_make_lattice_upper_bounds(self): + original_lattice_function = dtypes._type_promotion_lattice + + def mock_lattice(): + lattice = original_lattice_function() + lattice["int32"].append("float32") + lattice["float32"].append("int32") + return lattice + + dtypes._type_promotion_lattice = mock_lattice + + with self.assertRaisesRegex( + ValueError, "cycle detected in type promotion lattice for node" + ): + dtypes._make_lattice_upper_bounds() + + dtypes._type_promotion_lattice = original_lattice_function + + def test_respect_weak_type_for_invalid_dtype(self): + with self.assertRaisesRegex( + ValueError, "Invalid value for argument `dtype`. Expected one of" + ): + dtypes._respect_weak_type("invalid_dtype", True) + + def test_invalid_dtype_in_least_upper_bound(self): + invalid_dtype = "non_existent_dtype" + with self.assertRaisesRegex( + ValueError, "is not a valid dtype for Keras type promotion" + ): + dtypes._least_upper_bound(invalid_dtype) + + def test_empty_lub_in_least_upper_bound(self): + dtype1 = "float32" + dtype2 = "int32" + with patch.dict( + dtypes.LATTICE_UPPER_BOUNDS, + {"float32": set(), "int32": set()}, + clear=True, + ): + with self.assertRaisesRegex( + ValueError, "no available implicit dtype promotion path" + ): + dtypes._least_upper_bound(dtype1, dtype2) + + def test_valid_dtype_leading_to_single_lub_element(self): + self.assertEqual( + dtypes._least_upper_bound("float32", "int32"), "float32" + ) + + def test_valid_dtype_leading_to_keyerror_and_valueerror(self): + invalid_dtype = "non_existent_dtype" + with self.assertRaisesRegex( + ValueError, "is not a valid dtype for Keras type promotion" + ): + dtypes._least_upper_bound(invalid_dtype) + + def test_resolve_weak_type_bool(self): + self.assertEqual(dtypes._resolve_weak_type("bool"), "bool") + + def test_resolve_weak_type_int(self): + self.assertEqual( + dtypes._resolve_weak_type("int32", precision="32"), "int32" + ) + self.assertEqual( + dtypes._resolve_weak_type("int64", precision="64"), "int64" + ) + + def test_resolve_weak_type_uint(self): + self.assertEqual( + dtypes._resolve_weak_type("uint32", precision="32"), "uint32" + ) + self.assertEqual( + dtypes._resolve_weak_type("uint64", precision="64"), "uint64" + ) + + def test_resolve_weak_type_float(self): + self.assertEqual( + dtypes._resolve_weak_type("float32", precision="32"), "float32" + ) + self.assertEqual( + dtypes._resolve_weak_type("float64", precision="64"), "float64" + ) + + def test_least_upper_bound_ensure_order_independence(self): + # Test to ensure _least_upper_bound is order-independent. + result1 = dtypes._least_upper_bound("float32", "int32") + result2 = dtypes._least_upper_bound("int32", "float32") + self.assertEqual(result1, result2) + + def test_least_upper_bound_single_element(self): + dtypes.LATTICE_UPPER_BOUNDS["test_dtype"] = {"test_dtype"} + self.assertEqual(dtypes._least_upper_bound("test_dtype"), "test_dtype") + + def test_least_upper_bound_no_element(self): + dtypes.LATTICE_UPPER_BOUNDS["test_dtype"] = set() + with self.assertRaisesRegex( + ValueError, "no available implicit dtype promotion path" + ): + dtypes._least_upper_bound("test_dtype") + + def test_least_upper_bound_with_no_common_upper_bound(self): + with patch.dict( + dtypes.LATTICE_UPPER_BOUNDS, + {"test_dtype1": set(), "test_dtype2": set()}, + clear=True, + ): + with self.assertRaisesRegex( + ValueError, "no available implicit dtype promotion path" + ): + dtypes._least_upper_bound("test_dtype1", "test_dtype2") + + def test_invalid_float8_dtype(self): + with self.assertRaisesRegex( + ValueError, "There is no implicit conversions from float8 dtypes" + ): + dtypes.result_type("float8_e4m3fn", "bfloat16") + with self.assertRaisesRegex( + ValueError, "There is no implicit conversions from float8 dtypes" + ): + dtypes.result_type("float8_e5m2", "bfloat16") diff --git a/keras/src/backend/common/global_state.py b/keras/src/backend/common/global_state.py new file mode 100644 index 000000000000..8ecf11b95056 --- /dev/null +++ b/keras/src/backend/common/global_state.py @@ -0,0 +1,98 @@ +import gc +import threading + +from keras.src import backend +from keras.src.api_export import keras_export + +GLOBAL_STATE_TRACKER = threading.local() +GLOBAL_SETTINGS_TRACKER = threading.local() + + +def set_global_attribute(name, value): + setattr(GLOBAL_STATE_TRACKER, name, value) + + +def get_global_attribute(name, default=None, set_to_default=False): + attr = getattr(GLOBAL_STATE_TRACKER, name, None) + if attr is None and default is not None: + attr = default + if set_to_default: + set_global_attribute(name, attr) + return attr + + +@keras_export(["keras.utils.clear_session", "keras.backend.clear_session"]) +def clear_session(free_memory=True): + """Resets all state generated by Keras. + + Keras manages a global state, which it uses to implement the Functional + model-building API and to uniquify autogenerated layer names. + + If you are creating many models in a loop, this global state will consume + an increasing amount of memory over time, and you may want to clear it. + Calling `clear_session()` releases the global state: this helps avoid + clutter from old models and layers, especially when memory is limited. + + Args: + free_memory: Whether to call Python garbage collection. + It's usually a good practice to call it to make sure + memory used by deleted objects is immediately freed. + However, it may take a few seconds to execute, so + when using `clear_session()` in a short loop, + you may want to skip it. + + Example 1: calling `clear_session()` when creating models in a loop + + ```python + for _ in range(100): + # Without `clear_session()`, each iteration of this loop will + # slightly increase the size of the global state managed by Keras + model = keras.Sequential([ + keras.layers.Dense(10) for _ in range(10)]) + + for _ in range(100): + # With `clear_session()` called at the beginning, + # Keras starts with a blank state at each iteration + # and memory consumption is constant over time. + keras.backend.clear_session() + model = keras.Sequential([ + keras.layers.Dense(10) for _ in range(10)]) + ``` + + Example 2: resetting the layer name generation counter + + >>> layers = [keras.layers.Dense(10) for _ in range(10)] + >>> new_layer = keras.layers.Dense(10) + >>> print(new_layer.name) + dense_10 + >>> keras.backend.clear_session() + >>> new_layer = keras.layers.Dense(10) + >>> print(new_layer.name) + dense + """ + global GLOBAL_STATE_TRACKER + global GLOBAL_SETTINGS_TRACKER + + GLOBAL_STATE_TRACKER = threading.local() + GLOBAL_SETTINGS_TRACKER = threading.local() + + if backend.backend() == "tensorflow": + from keras.src.utils.module_utils import tensorflow as tf + + tf.compat.v1.reset_default_graph() + if tf.executing_eagerly(): + # Clear pending nodes in eager executors, kernel caches and + # step_containers. + from tensorflow.python.eager import context + + context.context().clear_kernel_cache() + elif backend.backend() == "torch": + import torch._dynamo as dynamo + + # reset's torchdynamo's cache so that cached guards, compiled fn, etc + # do not persist between clear_session() calls + dynamo.reset() + + if free_memory: + # Manually trigger garbage collection. + gc.collect() diff --git a/keras/src/backend/common/global_state_test.py b/keras/src/backend/common/global_state_test.py new file mode 100644 index 000000000000..5f2a05ba15a4 --- /dev/null +++ b/keras/src/backend/common/global_state_test.py @@ -0,0 +1,14 @@ +from keras.src.backend.common import global_state +from keras.src.testing import test_case +from keras.src.utils.naming import auto_name + + +class GlobalStateTest(test_case.TestCase): + def test_clear_session(self): + name0 = auto_name("somename") + self.assertEqual(name0, "somename") + name1 = auto_name("somename") + self.assertEqual(name1, "somename_1") + global_state.clear_session() + name0 = auto_name("somename") + self.assertEqual(name0, "somename") diff --git a/keras/src/backend/common/keras_tensor.py b/keras/src/backend/common/keras_tensor.py new file mode 100644 index 000000000000..c03d6afe53e1 --- /dev/null +++ b/keras/src/backend/common/keras_tensor.py @@ -0,0 +1,422 @@ +from keras.src import tree +from keras.src.api_export import keras_export +from keras.src.utils.naming import auto_name + + +@keras_export("keras.KerasTensor") +class KerasTensor: + """Symbolic tensor -- encapsulates a shape and a dtype. + + You can use `KerasTensor` instances to build computation + graphs of Keras operations, such as `keras.Function` + objects or Functional `keras.models.Model` objects. + + Example: + + >>> x = keras.KerasTensor(shape=(3, 4), dtype="float32") + >>> x.shape + (3, 4) + >>> x.dtype + float32 + + Calling a Keras operation (including a layer or a model) + on a `KerasTensor` instance will return another `KerasTensor` + instance with the appropriate shape and dtype. This is + called a "symbolic call" (since there is no actual data + involved). The computation of the correct output shape and + dtype is called "static shape inference". + """ + + def __init__( + self, + shape, + dtype="float32", + sparse=False, + ragged=False, + record_history=True, + name=None, + **kwargs, + ): + from keras.src import backend + + ragged_rank = kwargs.pop("ragged_rank", None) + row_splits_dtype = kwargs.pop("row_splits_dtype", None) + if kwargs: + raise TypeError( + f"Unexpected keyword arguments: {', '.join(kwargs.keys())}" + ) + + self._shape = backend.standardize_shape(shape) + self._dtype = backend.standardize_dtype(dtype) + self._sparse = bool(sparse) + self._ragged = bool(ragged) + if self._sparse and self._ragged: + raise ValueError( + "KerasTensor cannot have `sparse=True` and `ragged=True` at " + "the same time." + ) + self._ragged_rank = ( + int(ragged_rank) if ragged_rank is not None else None + ) + self._row_splits_dtype = ( + backend.standardize_dtype(row_splits_dtype) + if row_splits_dtype is not None + else None + ) + self.name = name or auto_name(self.__class__.__name__) + self.record_history = record_history + + @property + def shape(self): + return self._shape + + @shape.setter + def shape(self, value): + raise AttributeError( + "The `shape` attribute of KerasTensor is immutable. One should " + "create a new instance of KerasTensor for this." + ) + + @property + def dtype(self): + return self._dtype + + @dtype.setter + def dtype(self, value): + raise AttributeError( + "The `dtype` attribute of KerasTensor is immutable. One should " + "create a new instance of KerasTensor for this." + ) + + @property + def sparse(self): + return self._sparse + + @sparse.setter + def sparse(self, value): + raise AttributeError( + "The `sparse` attribute of KerasTensor is immutable. One should " + "create a new instance of KerasTensor for this." + ) + + @property + def ragged_rank(self): + return self._ragged_rank + + @ragged_rank.setter + def ragged_rank(self, value): + raise AttributeError( + "The `ragged_rank` attribute of KerasTensor is immutable. One " + "should create a new instance of KerasTensor for this." + ) + + @property + def row_splits_dtype(self): + return self._row_splits_dtype + + @row_splits_dtype.setter + def row_splits_dtype(self, value): + raise AttributeError( + "The `row_splits_dtype` attribute of KerasTensor is immutable. One " + "should create a new instance of KerasTensor for this." + ) + + @property + def ragged(self): + return self._ragged + + @ragged.setter + def ragged(self, value): + raise AttributeError( + "The `ragged` attribute of KerasTensor is immutable. One should " + "create a new instance of KerasTensor for this." + ) + + @property + def ndim(self): + return len(self.shape) + + def reshape(self, newshape): + from keras.src import ops + + return ops.Reshape(newshape)(self) + + def squeeze(self, axis=None): + from keras.src import ops + + return ops.Squeeze(axis)(self) + + def __int__(self): + raise ValueError( + "A KerasTensor is symbolic: it's a placeholder for a shape " + "an a dtype. It doesn't have any actual numerical value. " + "You cannot convert it to an int." + ) + + def __float__(self): + raise ValueError( + "A KerasTensor is symbolic: it's a placeholder for a shape " + "an a dtype. It doesn't have any actual numerical value. " + "You cannot convert it to a float." + ) + + def __array__(self): + raise ValueError( + "A KerasTensor is symbolic: it's a placeholder for a shape " + "an a dtype. It doesn't have any actual numerical value. " + "You cannot convert it to a NumPy array." + ) + + def __jax_array__(self): + raise ValueError( + "A KerasTensor cannot be used as input to a JAX function. " + "A KerasTensor is a symbolic placeholder for a shape and dtype, " + "used when constructing Keras Functional models " + "or Keras Functions. You can only use it as input to a Keras layer " + "or a Keras operation (from the namespaces `keras.layers` " + "and `keras.ops`). " + "You are likely doing something like:\n\n" + "```\n" + "x = Input(...)\n" + "...\n" + "jax_fn(x) # Invalid.\n" + "```\n\n" + "What you should do instead is wrap `jax_fn` in a layer:\n\n" + "```\n" + "class MyLayer(Layer):\n" + " def call(self, x):\n" + " return jax_fn(x)\n\n" + "x = MyLayer()(x)\n" + "```\n" + ) + + def __tf_tensor__(self, dtype=None, name=None): + raise ValueError( + "A KerasTensor cannot be used as input to a TensorFlow function. " + "A KerasTensor is a symbolic placeholder for a shape and dtype, " + "used when constructing Keras Functional models " + "or Keras Functions. You can only use it as input to a Keras layer " + "or a Keras operation (from the namespaces `keras.layers` " + "and `keras.ops`). " + "You are likely doing something like:\n\n" + "```\n" + "x = Input(...)\n" + "...\n" + "tf_fn(x) # Invalid.\n" + "```\n\n" + "What you should do instead is wrap `tf_fn` in a layer:\n\n" + "```\n" + "class MyLayer(Layer):\n" + " def call(self, x):\n" + " return tf_fn(x)\n\n" + "x = MyLayer()(x)\n" + "```\n" + ) + + def __repr__(self): + return ( + f"" + ) + + def __iter__(self): + raise NotImplementedError( + "Iterating over a symbolic KerasTensor is not supported." + ) + + def __bool__(self): + raise TypeError("A symbolic KerasTensor cannot be used as a boolean.") + + def __add__(self, other): + from keras.src import ops + + return ops.Add().symbolic_call(self, other) + + def __radd__(self, other): + from keras.src import ops + + return ops.Add().symbolic_call(other, self) + + def __sub__(self, other): + from keras.src import ops + + return ops.Subtract().symbolic_call(self, other) + + def __rsub__(self, other): + from keras.src import ops + + return ops.Subtract().symbolic_call(other, self) + + def __mul__(self, other): + from keras.src import ops + + return ops.Multiply().symbolic_call(self, other) + + def __rmul__(self, other): + from keras.src import ops + + return ops.Multiply().symbolic_call(other, self) + + def __matmul__(self, other): + from keras.src import ops + + return ops.Matmul().symbolic_call(self, other) + + def __rmatmul__(self, other): + from keras.src import ops + + return ops.Matmul().symbolic_call(other, self) + + def __div__(self, other): + from keras.src import ops + + return ops.Divide().symbolic_call(self, other) + + def __rdiv__(self, other): + from keras.src import ops + + return ops.Divide().symbolic_call(other, self) + + def __truediv__(self, other): + from keras.src import ops + + return ops.TrueDivide().symbolic_call(self, other) + + def __rtruediv__(self, other): + from keras.src import ops + + return ops.TrueDivide().symbolic_call(other, self) + + def __neg__(self): + from keras.src import ops + + return ops.Negative().symbolic_call(self) + + def __abs__(self): + from keras.src import ops + + return ops.Absolute().symbolic_call(self) + + def __pow__(self, other): + from keras.src import ops + + return ops.Power().symbolic_call(self, other) + + def __rpow__(self, other): + from keras.src import ops + + return ops.Power().symbolic_call(other, self) + + def __floordiv__(self, other): + from keras.src import ops + + return ops.FloorDivide().symbolic_call(self, other) + + def __rfloordiv__(self, other): + from keras.src import ops + + return ops.FloorDivide().symbolic_call(other, self) + + def __mod__(self, other): + from keras.src import ops + + return ops.Mod().symbolic_call(self, other) + + def __rmod__(self, other): + from keras.src import ops + + return ops.Mod().symbolic_call(other, self) + + def __lt__(self, other): + from keras.src import ops + + return ops.Less().symbolic_call(self, other) + + def __le__(self, other): + from keras.src import ops + + return ops.LessEqual().symbolic_call(self, other) + + def __gt__(self, other): + from keras.src import ops + + return ops.Greater().symbolic_call(self, other) + + def __ge__(self, other): + from keras.src import ops + + return ops.GreaterEqual().symbolic_call(self, other) + + def __ne__(self, other): + from keras.src import ops + + return ops.NotEqual().symbolic_call(self, other) + + def __and__(self, other): + from keras.src import ops + + return ops.LogicalAnd().symbolic_call(self, other) + + def __rand__(self, other): + from keras.src import ops + + return ops.LogicalAnd().symbolic_call(other, self) + + def __or__(self, other): + from keras.src import ops + + return ops.LogicalOr().symbolic_call(self, other) + + def __ror__(self, other): + from keras.src import ops + + return ops.LogicalOr().symbolic_call(other, self) + + def __invert__(self): + from keras.src import ops + + return ops.LogicalNot().symbolic_call(self) + + def __xor__(self, other): + from keras.src import ops + + return ops.LogicalXor().symbolic_call(self, other) + + def __rxor__(self, other): + from keras.src import ops + + return ops.LogicalXor().symbolic_call(other, self) + + def __getitem__(self, key): + from keras.src import ops + + return ops.GetItem().symbolic_call(self, key) + + def __round__(self, ndigits=None): + from keras.src import ops + + decimals = ndigits or 0 + return ops.Round(decimals=decimals).symbolic_call(self) + + +def any_symbolic_tensors(args=None, kwargs=None): + args = args or () + kwargs = kwargs or {} + for x in tree.flatten((args, kwargs)): + if isinstance(x, KerasTensor): + return True + return False + + +@keras_export(["keras.utils.is_keras_tensor", "keras.backend.is_keras_tensor"]) +def is_keras_tensor(x): + """Returns whether `x` is a Keras tensor. + + A "Keras tensor" is a *symbolic tensor*, such as a tensor + that was created via `Input()`. A "symbolic tensor" + can be understood as a placeholder -- it does not + contain any actual numerical data, only a shape and dtype. + It can be used for building Functional models, but it + cannot be used in actual computations. + """ + return isinstance(x, KerasTensor) diff --git a/keras/src/backend/common/keras_tensor_test.py b/keras/src/backend/common/keras_tensor_test.py new file mode 100644 index 000000000000..c2e84417c92d --- /dev/null +++ b/keras/src/backend/common/keras_tensor_test.py @@ -0,0 +1,425 @@ +from unittest.mock import Mock +from unittest.mock import patch + +import numpy as np +import tensorflow as tf + +from keras.src import backend +from keras.src import ops +from keras.src import testing +from keras.src.backend.common import keras_tensor + + +class KerasTensorTest(testing.TestCase): + def test_attributes(self): + x = keras_tensor.KerasTensor(shape=(3,), dtype="float32", sparse=True) + self.assertEqual(x.dtype, "float32") + self.assertEqual(x.shape, (3,)) + self.assertEqual(x.sparse, True) + + # Raise error if trying to set attributes + with self.assertRaisesRegex( + AttributeError, "The `shape` attribute of KerasTensor is immutable." + ): + x.shape = [3, 2] + with self.assertRaisesRegex( + AttributeError, "The `dtype` attribute of KerasTensor is immutable." + ): + x.dtype = "int32" + + def test_attributes_sparse(self): + x = keras_tensor.KerasTensor(shape=(3,), dtype="float32", sparse=True) + self.assertEqual(x.sparse, True) + + # Raise error if trying to set attributes + with self.assertRaisesRegex( + AttributeError, + "The `sparse` attribute of KerasTensor is immutable.", + ): + x.sparse = False + + def test_attributes_ragged(self): + x = keras_tensor.KerasTensor(shape=(3,), dtype="float32", ragged=True) + self.assertEqual(x.ragged, True) + + # Raise error if trying to set attributes + with self.assertRaisesRegex( + AttributeError, + "The `ragged` attribute of KerasTensor is immutable.", + ): + x.ragged = False + + def test_init_sparse_ragged_raises(self): + with self.assertRaisesRegex( + ValueError, "cannot have `sparse=True` and `ragged=True`" + ): + keras_tensor.KerasTensor(shape=(3,), sparse=True, ragged=True) + + def test_numpy_methods(self): + x = keras_tensor.KerasTensor(shape=(3, 2), dtype="float32") + + # reshape + x = x.reshape((6,)) + self.assertEqual(x.shape, (6,)) + + # expand_dims, squeeze + x = ops.expand_dims(x, -1) + self.assertEqual(x.shape, (6, 1)) + x = x.squeeze() + self.assertEqual(x.shape, (6,)) + x = ops.expand_dims(x, axis=0) + self.assertEqual(x.shape, (1, 6)) + x = x.squeeze(axis=0) + self.assertEqual(x.shape, (6,)) + + def test_invalid_usage(self): + x = keras_tensor.KerasTensor(shape=(3,), dtype="float32") + with self.assertRaisesRegex( + ValueError, "doesn't have any actual numerical value" + ): + np.array(x) + + if backend.backend() == "jax": + from jax import numpy as jnp + + with self.assertRaisesRegex( + ValueError, "cannot be used as input to a JAX function" + ): + jnp.array(x) + + with self.assertRaisesRegex( + ValueError, "cannot be used as input to a TensorFlow function" + ): + tf.convert_to_tensor(x) + + def test_bool(self): + tensor = keras_tensor.KerasTensor(shape=(3, 4), dtype="float32") + with self.assertRaisesRegex(TypeError, "cannot be used as a boolean."): + bool(tensor) + + def test_representation(self): + x = keras_tensor.KerasTensor(shape=(3, 4), dtype="float32") + self.assertIn(" y + mock_symbolic_call.assert_called_once_with(x, y) + self.assertEqual(result, mock_tensor) + + @patch("keras.src.ops.GreaterEqual.symbolic_call") + def test_ge_method(self, mock_symbolic_call): + """Test __ge__ method""" + mock_tensor = Mock() + mock_symbolic_call.return_value = mock_tensor + x = keras_tensor.KerasTensor(shape=(3, 4), dtype="float32") + y = Mock() + result = x >= y + mock_symbolic_call.assert_called_once_with(x, y) + self.assertEqual(result, mock_tensor) + + @patch("keras.src.ops.NotEqual.symbolic_call") + def test_ne_method(self, mock_symbolic_call): + """Test __ne__ method""" + mock_tensor = Mock() + mock_symbolic_call.return_value = mock_tensor + x = keras_tensor.KerasTensor(shape=(3, 4), dtype="float32") + y = Mock() + result = x != y + mock_symbolic_call.assert_called_once_with(x, y) + self.assertEqual(result, mock_tensor) + + @patch("keras.src.ops.LogicalAnd.symbolic_call") + def test_rand_method(self, mock_symbolic_call): + """Test __rand__ method""" + mock_tensor = Mock() + mock_symbolic_call.return_value = mock_tensor + x = keras_tensor.KerasTensor(shape=(3, 4), dtype="bool") + y = Mock() + result = y & x + mock_symbolic_call.assert_called_once_with(y, x) + self.assertEqual(result, mock_tensor) + + @patch("keras.src.ops.LogicalOr.symbolic_call") + def test_ror_method(self, mock_symbolic_call): + """Test __ror__ method""" + mock_tensor = Mock() + mock_symbolic_call.return_value = mock_tensor + x = keras_tensor.KerasTensor(shape=(3, 4), dtype="bool") + y = Mock() + result = y | x + mock_symbolic_call.assert_called_once_with(y, x) + self.assertEqual(result, mock_tensor) + + @patch("keras.src.ops.LogicalNot.symbolic_call") + def test_invert_method(self, mock_symbolic_call): + """Test __invert__ method""" + mock_tensor = Mock() + mock_symbolic_call.return_value = mock_tensor + x = keras_tensor.KerasTensor(shape=(3, 4), dtype="bool") + result = ~x + mock_symbolic_call.assert_called_once_with(x) + self.assertEqual(result, mock_tensor) + + @patch("keras.src.ops.LogicalXor.symbolic_call") + def test_xor_method(self, mock_symbolic_call): + """Test __xor__ method""" + mock_tensor = Mock() + mock_symbolic_call.return_value = mock_tensor + x = keras_tensor.KerasTensor(shape=(3, 4), dtype="bool") + y = Mock() + result = x ^ y + mock_symbolic_call.assert_called_once_with(x, y) + self.assertEqual(result, mock_tensor) + + @patch("keras.src.ops.LogicalXor.symbolic_call") + def test_rxor_method(self, mock_symbolic_call): + """Test __rxor__ method""" + mock_tensor = Mock() + mock_symbolic_call.return_value = mock_tensor + x = keras_tensor.KerasTensor(shape=(3, 4), dtype="bool") + y = Mock() + result = y ^ x + mock_symbolic_call.assert_called_once_with(y, x) + self.assertEqual(result, mock_tensor) + + @patch("keras.src.ops.TrueDivide.symbolic_call") + def test_truediv_method(self, mock_symbolic_call): + """Test __truediv__ method""" + mock_tensor = Mock() + mock_symbolic_call.return_value = mock_tensor + x = keras_tensor.KerasTensor(shape=(3, 4), dtype="float32") + y = Mock() + result = x / y + mock_symbolic_call.assert_called_once_with(x, y) + self.assertEqual(result, mock_tensor) + + @patch("keras.src.ops.TrueDivide.symbolic_call") + def test_rtruediv_method(self, mock_symbolic_call): + """Test __rtruediv__ method""" + mock_tensor = Mock() + mock_symbolic_call.return_value = mock_tensor + x = keras_tensor.KerasTensor(shape=(3, 4), dtype="float32") + y = Mock() + result = y / x + mock_symbolic_call.assert_called_once_with(y, x) + self.assertEqual(result, mock_tensor) + + @patch("keras.src.ops.Divide.symbolic_call") + def test_div_method(self, mock_symbolic_call): + """Test __div__ method""" + mock_tensor = Mock() + mock_symbolic_call.return_value = mock_tensor + x = keras_tensor.KerasTensor(shape=(3, 4), dtype="float32") + y = keras_tensor.KerasTensor(shape=(3, 4), dtype="float32") + # to ensure compatibility across Python versions + result = x.__div__(y) + mock_symbolic_call.assert_called_once_with(x, y) + self.assertEqual(result, mock_tensor) + + @patch("keras.src.ops.Divide.symbolic_call") + def test_rdiv_method(self, mock_symbolic_call): + """Test __rdiv__ method""" + mock_tensor = Mock() + mock_symbolic_call.return_value = mock_tensor + x = keras_tensor.KerasTensor(shape=(3, 4), dtype="float32") + y = keras_tensor.KerasTensor(shape=(3, 4), dtype="float32") + # to ensure compatibility across Python versions + result = x.__rdiv__(y) + mock_symbolic_call.assert_called_once_with(y, x) + self.assertEqual(result, mock_tensor) diff --git a/keras/src/backend/common/masking.py b/keras/src/backend/common/masking.py new file mode 100644 index 000000000000..afd0c2b64733 --- /dev/null +++ b/keras/src/backend/common/masking.py @@ -0,0 +1,26 @@ +from keras.src.backend.common.tensor_attributes import get_tensor_attr +from keras.src.backend.common.tensor_attributes import set_tensor_attr + + +def set_keras_mask(x, mask): + """Sets the Keras mask attribute for the given tensor in-place. + + Args: + x: Input tensor. + mask: The mask tensor to be set. If `None`, the `_keras_mask` attribute + will be cleared. + """ + set_tensor_attr(x, "_keras_mask", mask) + + +def get_keras_mask(x): + """Gets the Keras mask attribute from the given tensor. + + Args: + x: Input tensor. + + Returns: + The mask tensor associated with the input tensor, or `None` if no mask + has been set. + """ + return get_tensor_attr(x, "_keras_mask") diff --git a/keras/src/backend/common/masking_test.py b/keras/src/backend/common/masking_test.py new file mode 100644 index 000000000000..f1ac8a5c26d5 --- /dev/null +++ b/keras/src/backend/common/masking_test.py @@ -0,0 +1,43 @@ +from keras.src import backend +from keras.src import ops +from keras.src import testing +from keras.src.backend.common.masking import get_keras_mask +from keras.src.backend.common.masking import set_keras_mask + + +class MaskingTest(testing.TestCase): + def test_mask_on_eager_tensor(self): + x = ops.zeros((2, 3)) + self.assertIsNone(get_keras_mask(x)) + + set_keras_mask(x, None) + self.assertIsNone(get_keras_mask(x)) + + mask = ops.ones((2, 3)) + set_keras_mask(x, mask) + self.assertIs(get_keras_mask(x), mask) + + set_keras_mask(x, None) + self.assertIsNone(get_keras_mask(x)) + + set_keras_mask(x, None) + self.assertIsNone(get_keras_mask(x)) + + def test_mask_on_tracer_tensor(self): + def fn(x): + self.assertIsNone(get_keras_mask(x)) + + set_keras_mask(x, None) + self.assertIsNone(get_keras_mask(x)) + + mask = ops.ones((2, 3)) + set_keras_mask(x, mask) + self.assertIs(get_keras_mask(x), mask) + + set_keras_mask(x, None) + self.assertIsNone(get_keras_mask(x)) + + set_keras_mask(x, None) # key is now deleted, should be a no-op + self.assertIsNone(get_keras_mask(x)) + + backend.compute_output_spec(fn, backend.KerasTensor((2, 3))) diff --git a/keras/src/backend/common/name_scope.py b/keras/src/backend/common/name_scope.py new file mode 100644 index 000000000000..71a8408767b6 --- /dev/null +++ b/keras/src/backend/common/name_scope.py @@ -0,0 +1,73 @@ +from keras.src.backend.common import global_state + + +class name_scope: + """Creates a sub-namespace for variable paths. + + Args: + name: Name of the current scope (string). + caller: Optional ID of a caller object (e.g. class instance). + deduplicate: If `True`, if `caller` was passed, + and the previous caller matches the current caller, + and the previous name matches the current name, + do not reenter a new namespace. + override_parent: Can be used to provide an absolute path + which would override any previously opened name scopes. + """ + + def __init__( + self, name, caller=None, deduplicate=True, override_parent=None + ): + if not isinstance(name, str) or "/" in name: + raise ValueError( + "Argument `name` must be a string and " + "cannot contain character `/`. " + f"Received: name={name}" + ) + self.name = name + self.caller = caller + self.deduplicate = deduplicate + self.override_parent = override_parent + if ( + override_parent is None + and deduplicate + and getattr(caller, "_parent_path", None) is not None + ): + self.override_parent = caller._parent_path + self._pop_on_exit = False + + def __enter__(self): + name_scope_stack = global_state.get_global_attribute( + "name_scope_stack", default=[], set_to_default=True + ) + if self.deduplicate and name_scope_stack: + parent_caller = name_scope_stack[-1].caller + parent_name = name_scope_stack[-1].name + if ( + self.caller is not None + and self.caller is parent_caller + and self.name == parent_name + ): + return self + name_scope_stack.append(self) + self._pop_on_exit = True + return self + + def __exit__(self, *args, **kwargs): + if self._pop_on_exit: + name_scope_stack = global_state.get_global_attribute( + "name_scope_stack" + ) + name_scope_stack.pop() + + +def current_path(): + name_scope_stack = global_state.get_global_attribute("name_scope_stack") + if name_scope_stack is None: + return "" + parts = [] + for entry in name_scope_stack: + if entry.override_parent is not None: + parts = [p for p in entry.override_parent.split("/") if p] + parts.append(entry.name) + return "/".join(parts) diff --git a/keras/src/backend/common/name_scope_test.py b/keras/src/backend/common/name_scope_test.py new file mode 100644 index 000000000000..2e79f2146958 --- /dev/null +++ b/keras/src/backend/common/name_scope_test.py @@ -0,0 +1,48 @@ +from keras.src import testing +from keras.src.backend.common.name_scope import current_path +from keras.src.backend.common.name_scope import name_scope + + +class NameScopeTest(testing.TestCase): + def test_stacking(self): + self.assertEqual(current_path(), "") + with name_scope("outer") as outer: + self.assertEqual(outer.name, "outer") + self.assertEqual(current_path(), "outer") + with name_scope("middle") as middle: + self.assertEqual(middle.name, "middle") + self.assertEqual(current_path(), "outer/middle") + with name_scope("inner") as inner: + self.assertEqual(inner.name, "inner") + self.assertEqual(current_path(), "outer/middle/inner") + self.assertEqual(current_path(), "outer/middle") + self.assertEqual(current_path(), "outer") + self.assertEqual(current_path(), "") + + def test_deduplication(self): + self.assertEqual(current_path(), "") + with name_scope("name", caller=1): + with name_scope("name", caller=1): + self.assertEqual(current_path(), "name") + self.assertEqual(current_path(), "") + with name_scope("name"): + with name_scope("name"): + self.assertEqual(current_path(), "name/name") + + def test_errors(self): + with self.assertRaisesRegex(ValueError, "must be a string"): + name_scope("foo/bar") + with self.assertRaisesRegex(ValueError, "must be a string"): + name_scope(4) + + def test_override_parent(self): + self.assertEqual(current_path(), "") + with name_scope("outer"): + self.assertEqual(current_path(), "outer") + with name_scope("middle", override_parent="/absolute/path"): + self.assertEqual(current_path(), "absolute/path/middle") + with name_scope("inner"): + self.assertEqual( + current_path(), "absolute/path/middle/inner" + ) + self.assertEqual(current_path(), "outer") diff --git a/keras/src/backend/common/remat.py b/keras/src/backend/common/remat.py new file mode 100644 index 000000000000..8465bda25d0b --- /dev/null +++ b/keras/src/backend/common/remat.py @@ -0,0 +1,186 @@ +from collections import namedtuple + +from keras.src import backend +from keras.src.api_export import keras_export +from keras.src.backend.common import global_state + + +@keras_export("keras.RematScope") +class RematScope: + """A context manager for enabling rematerialization in Keras. + + Rematerialization (gradient checkpointing) trades memory for computation by + recomputing intermediate activations during the backward pass. This is + particularly useful for training large models or large batch sizes within + limited memory constraints. + + This should be used when initializing the layer (e.g., `layer(input)`). + Rematerialization applies at execution time, not at creation time. + + Args: + mode: Rematerialization mode to apply. + Options: + - `"full"`: Apply rematerialization globally to all supported + operations. + - `"activations"`: Apply rematerialization to activations on any + layers that contain `keras.activations` (e.g., `Dense(..., + activation=relu)`). + - `"larger_than"`: Apply rematerialization to layers with output + sizes larger than `output_size_threshold`. + - `"list_of_layers"`: Apply rematerialization to a specific list of + layer names. + - `None`: Disable rematerialization. + output_size_threshold: Output size threshold for the + `"larger_than"` mode. Layers producing outputs larger than this + threshold will be rematerialized. Default is `1024`. + layer_names: List of layer names for the + `"list_of_layers"` mode. Default is an empty list. + + Examples: + Using "list_of_layers" mode: + + ```python + from keras import RematScope + input_tensor = tf.random.normal((1, 32, 32, 3)) + with RematScope(mode="list_of_layers", layer_names=["dense_1", + "conv2d_1"]): + layer1 = keras.layers.Dense(128, name="dense_1") + layer2 = keras.layers.Conv2D(64, (3, 3), name="conv2d_1") + layer3 = keras.layers.Dense(64, name="dense_2") + # Only layer1 and layer2 will apply rematerialization + output1 = layer1(input_tensor) + output2 = layer2(output1) + output3 = layer3(output2) + ``` + + Using "larger_than" mode with a specific output size threshold: + + ```python + with RematScope(mode="larger_than", output_size_threshold=2048): + layer = keras.layers.Conv2D(64, (3, 3)) + output = layer(input_tensor) # Conv2D outputs larger than 2048 + ``` + + Nested scopes for fine-grained control: + + ```python + with RematScope(mode="full"): + # Create layers + layer1 = keras.layers.Dense(128, activation='relu') + output1 = layer1(input_tensor) # layer1 is fully rematerialized + with RematScope(mode="larger_than", output_size_threshold=512): + layer2 = keras.layers.Conv2D(32, (3, 3)) + output2 = layer2(output1) # layer2 is conditionally rematerialized + # if output > 512 + ``` + """ + + def __init__( + self, mode="full", output_size_threshold=1024, layer_names=None + ): + if mode not in { + "full", + "activations", + "larger_than", + "list_of_layers", + None, + }: + raise ValueError( + f"Invalid mode '{mode}'. Supported modes are: " + "'full', 'activations', 'larger_than', 'list_of_layers', or " + " None." + ) + self.mode = mode + self.output_size_threshold = output_size_threshold + self.layer_names = layer_names or [] + self._pop_on_exit = False + + def __enter__(self): + remat_scope_stack = global_state.get_global_attribute( + "remat_scope_stack", default=[], set_to_default=True + ) + remat_scope_stack.append(self) + self._pop_on_exit = True + return self + + def __exit__(self, *args, **kwargs): + if self._pop_on_exit: + remat_scope_stack = global_state.get_global_attribute( + "remat_scope_stack" + ) + remat_scope_stack.pop() + + +RematMode = namedtuple( + "RematMode", ["mode", "output_size_threshold", "layer_names"] +) + + +def get_current_remat_mode(): + """Get the current rematerialization mode and associated settings. + + Returns: + RematMode or None: The current rematerialization mode, or None if not + set. + """ + remat_scope_stack = global_state.get_global_attribute("remat_scope_stack") + if not remat_scope_stack: + return None + active_scope = remat_scope_stack[-1] + return RematMode( + active_scope.mode, + active_scope.output_size_threshold, + active_scope.layer_names, + ) + + +@keras_export("keras.remat") +def remat(f): + """Applies rematerialization to a function or layer for memory optimization. + + Rematerialization is a memory optimization technique that trades off + computation for memory. Instead of storing intermediate results + (e.g. activations) for backpropagation, they are recomputed during the + backward pass. This reduces peak memory usage at the cost of increased + computation time, allowing the training of larger models or using larger + batch sizes within the same memory constraints. + + Args: + f: A callable function, to which rematerialization is + applied. This is typically a computationally expensive operation + where intermediate states can be recomputed instead of stored. + + Returns: + A wrapped function that applies rematerialization. The returned + function defines a custom gradient, ensuring that during the backward + pass, the forward computation is recomputed as needed. + + Example: + + ```python + from keras import Model + class CustomRematLayer(layers.Layer): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.remat_function = remat(self.intermediate_function) + + def intermediate_function(self, x): + for _ in range(2): + x = x + x * 0.1 # Simple scaled transformation + return x + + def call(self, inputs): + return self.remat_function(inputs) + + # Define a simple model using the custom layer + inputs = layers.Input(shape=(4,)) + x = layers.Dense(4, activation="relu")(inputs) + x = CustomRematLayer()(x) # Custom layer with rematerialization + outputs = layers.Dense(1)(x) + + # Create and compile the model + model = Model(inputs=inputs, outputs=outputs) + model.compile(optimizer="sgd", loss="mse") + ``` + """ + return backend.core.remat(f) diff --git a/keras/src/backend/common/remat_test.py b/keras/src/backend/common/remat_test.py new file mode 100644 index 000000000000..2732f5da964a --- /dev/null +++ b/keras/src/backend/common/remat_test.py @@ -0,0 +1,118 @@ +import numpy as np + +from keras.src import backend +from keras.src import layers +from keras.src import models +from keras.src import testing +from keras.src.backend.common import global_state +from keras.src.backend.common.remat import RematScope +from keras.src.backend.common.remat import get_current_remat_mode +from keras.src.layers import activations + + +class TestRematScope(testing.TestCase): + def setUp(self): + """Reset global state before each test.""" + global_state.clear_session() + + def test_remat_scope_activation(self): + self.assertIsNone( + get_current_remat_mode() + ) # Initially, no mode is active + + with RematScope(mode="full"): + self.assertEqual( + get_current_remat_mode().mode, "full" + ) # Mode is set to "full" + + self.assertIsNone( + get_current_remat_mode() + ) # Mode is restored to None after scope ends + + def test_remat_scope_nested(self): + """Test nested scopes with different rematerialization modes.""" + with RematScope(mode="full"): + self.assertEqual( + get_current_remat_mode().mode, "full" + ) # Outer scope is "full" + + with RematScope(mode="activations"): + self.assertEqual( + get_current_remat_mode().mode, "activations" + ) # Inner scope is "activations" + + self.assertEqual( + get_current_remat_mode().mode, "full" + ) # Back to outer scope + + self.assertIsNone( + get_current_remat_mode() + ) # Mode is restored to None after all scopes + + def test_remat_scope_stack_management(self): + """Test that the remat_scope_stack is managed correctly.""" + self.assertIsNone( + global_state.get_global_attribute("remat_scope_stack") + ) # No stack initially + + with RematScope(mode="full"): + remat_stack = global_state.get_global_attribute("remat_scope_stack") + self.assertIsNotNone(remat_stack) # Stack is initialized + self.assertEqual(len(remat_stack), 1) # Stack contains one entry + + with RematScope(mode="activations"): + remat_stack = global_state.get_global_attribute( + "remat_scope_stack" + ) + self.assertEqual( + len(remat_stack), 2 + ) # Stack contains two entries + + remat_stack = global_state.get_global_attribute("remat_scope_stack") + self.assertEqual(len(remat_stack), 1) # Back to one entry + + self.assertEqual( + global_state.get_global_attribute("remat_scope_stack"), [] + ) # Stack is cleared + + def test_invalid_mode(self): + """Test that invalid rematerialization modes raise an error.""" + with self.assertRaises(ValueError): + RematScope(mode="invalid") # Invalid mode should raise ValueError + + +class RematTest(testing.TestCase): + def test_remat_basic_call(self): + if backend.backend() in ("openvino", "numpy"): + self.skipTest( + "remat is not supported in openvino and numpy backends." + ) + # Generate dummy data + data_size = 10**5 + x_train = np.random.normal(size=(data_size, 4)) + y_train = np.random.normal(size=(data_size, 1)) + + epochs = 5 + batch_size = 512 + # test applying remat + output_with_remat = backend.core.remat(activations.ReLU())(x_train) + output_without_remat = activations.ReLU()(x_train) + self.assertAllClose(output_with_remat, output_without_remat) + # test remat in a model + intermediate_function = backend.core.remat(activations.ReLU()) + inputs = layers.Input(shape=(4,)) + x = layers.Dense(4)(inputs) + x = layers.Lambda(intermediate_function)(x) + outputs = layers.Dense(1)(x) + model = models.Model(inputs=inputs, outputs=outputs) + model.predict(x_train) + model.compile(optimizer="sgd", loss="mse") + + # Train model + model.fit( + x_train, + y_train, + epochs=epochs, + batch_size=batch_size, + verbose=0, + ) diff --git a/keras/src/backend/common/stateless_scope.py b/keras/src/backend/common/stateless_scope.py new file mode 100644 index 000000000000..cbefd64a7551 --- /dev/null +++ b/keras/src/backend/common/stateless_scope.py @@ -0,0 +1,105 @@ +from keras.src.api_export import keras_export +from keras.src.backend.common import global_state + + +@keras_export("keras.StatelessScope") +class StatelessScope: + """Scope to prevent any update to Keras Variables. + + The values of variables to be used inside the scope + should be passed via the `state_mapping` argument, a + list of tuples `(k, v)` where `k` is a `Variable` + and `v` is the intended value for this variable + (a backend tensor). + + Updated values can be collected on scope exit via + `value = scope.get_current_value(variable)`. No updates + will be applied in-place to any variables for the duration + of the scope. + + Example: + + ```python + state_mapping = [(k, ops.ones(k.shape, k.dtype)) for k in model.weights] + with keras.StatelessScope(state_mapping) as scope: + outputs = model.some_function(inputs) + + # All model variables remain unchanged. Their new values can be + # collected via: + for k in model.weights: + new_value = scope.get_current_value(k) + print(f"New value for {k}: {new_value}) + ``` + """ + + def __init__( + self, + state_mapping=None, + collect_losses=False, + initialize_variables=True, + ): + from keras.src import backend + from keras.src.backend.common.variables import Variable + + self.collect_losses = collect_losses + self.initialize_variables = initialize_variables + self.losses = [] + self.state_mapping = {} + state_mapping = state_mapping or {} + for k, v in state_mapping: + if not isinstance(k, Variable): + raise ValueError( + "Invalid reference variable in StatelessScope: " + "all keys in argument `mapping` must be Variable " + f"instances. Received instead: {k}" + ) + if isinstance(v, Variable): + v = backend.cast(v.value, dtype=k.dtype) + else: + v = backend.convert_to_tensor(v, dtype=k.dtype) + if k.shape != v.shape: + raise ValueError( + "Invalid variable value in StatelessScope: " + "all values in argument `mapping` must be tensors with " + "a shape that matches the corresponding variable shape. " + f"For variable {k}, received invalid value {v} with shape " + f"{v.shape}." + ) + self.state_mapping[id(k)] = v + + def __enter__(self): + self.original_scope = get_stateless_scope() + global_state.set_global_attribute("stateless_scope", self) + return self + + def add_loss(self, loss): + self.losses.append(loss) + + def add_update(self, update): + variable, value = update + self.state_mapping[id(variable)] = value + + def get_current_value(self, variable): + return self.state_mapping.get(id(variable), None) + + def __exit__(self, *args, **kwargs): + global_state.set_global_attribute( + "stateless_scope", self.original_scope + ) + if self.original_scope is None and self.initialize_variables: + # We're back in eager scope; + # if any variables were created within the stateless + # scope, we initialize them here. + from keras.src.backend.common.variables import ( + initialize_all_variables, + ) + + initialize_all_variables() + + +def in_stateless_scope(): + return global_state.get_global_attribute("stateless_scope") is not None + + +def get_stateless_scope(): + return global_state.get_global_attribute("stateless_scope") diff --git a/keras/src/backend/common/stateless_scope_test.py b/keras/src/backend/common/stateless_scope_test.py new file mode 100644 index 000000000000..68aaa397ff8c --- /dev/null +++ b/keras/src/backend/common/stateless_scope_test.py @@ -0,0 +1,55 @@ +import numpy as np + +from keras.src import backend +from keras.src import ops +from keras.src import testing +from keras.src.backend.common.stateless_scope import StatelessScope + + +class TestStatelessScope(testing.TestCase): + def test_basic_flow(self): + var1 = backend.Variable(np.zeros((2,))) + var2 = backend.Variable(np.zeros((2,))) + var_out = backend.Variable(np.zeros((2,))) + + value1 = ops.ones(shape=(2,)) + value2 = ops.ones(shape=(2,)) + with StatelessScope( + state_mapping=[(var1, value1), (var2, value2)] + ) as scope: + out = var1 + var2 + var_out.assign(out) + var_out_value = var_out + 0.0 + # Inside scope: new value is used. + self.assertAllClose(var_out_value, 2 * np.ones((2,))) + + # Out of scope: old value is used. + var_out_value = var_out + 0.0 + self.assertAllClose(var_out_value, np.zeros((2,))) + + # Updates are tracked. + var_out_value = scope.get_current_value(var_out) + self.assertAllClose(var_out_value, 2 * np.ones((2,))) + + # Updates can be reapplied. + var_out.assign(scope.get_current_value(var_out)) + self.assertAllClose(var_out_value, 2 * np.ones((2,))) + + def test_invalid_key_in_state_mapping(self): + # var1 = backend.Variable(np.zeros((2,))) + invalid_key = "not_a_keras_variable" + value1 = ops.ones(shape=(2,)) + + with self.assertRaisesRegex( + ValueError, "all keys in argument `mapping` must be Variable" + ): + StatelessScope(state_mapping=[(invalid_key, value1)]) + + def test_invalid_value_shape_in_state_mapping(self): + var1 = backend.Variable(np.zeros((2,))) + invalid_value = ops.ones(shape=(3,)) # Incorrect shape + + with self.assertRaisesRegex( + ValueError, "all values in argument `mapping` must be tensors with" + ): + StatelessScope(state_mapping=[(var1, invalid_value)]) diff --git a/keras/src/backend/common/symbolic_scope.py b/keras/src/backend/common/symbolic_scope.py new file mode 100644 index 000000000000..15cd7a5ee059 --- /dev/null +++ b/keras/src/backend/common/symbolic_scope.py @@ -0,0 +1,23 @@ +from keras.src.api_export import keras_export +from keras.src.backend.common import global_state + + +@keras_export("keras.SymbolicScope") +class SymbolicScope: + """Scope to indicate the symbolic stage.""" + + def __enter__(self): + self.original_scope = get_symbolic_scope() + global_state.set_global_attribute("symbolic_scope", self) + return self + + def __exit__(self, *args, **kwargs): + global_state.set_global_attribute("symbolic_scope", self.original_scope) + + +def in_symbolic_scope(): + return global_state.get_global_attribute("symbolic_scope") is not None + + +def get_symbolic_scope(): + return global_state.get_global_attribute("symbolic_scope") diff --git a/keras/src/backend/common/symbolic_scope_test.py b/keras/src/backend/common/symbolic_scope_test.py new file mode 100644 index 000000000000..72b8746cb96e --- /dev/null +++ b/keras/src/backend/common/symbolic_scope_test.py @@ -0,0 +1,25 @@ +import numpy as np + +from keras.src import ops +from keras.src import testing +from keras.src.backend.common.symbolic_scope import SymbolicScope +from keras.src.backend.common.symbolic_scope import in_symbolic_scope + + +class TestSymbolicScope(testing.TestCase): + def test_basic_flow(self): + # Define a function that behaves differently according to + # `in_symbolic_scope`. + def compute_loss(y, y_pred): + if in_symbolic_scope(): + return ops.zeros_like(y) + return ops.add(y, y_pred) + + y = ops.ones(shape=(2,)) + y_pred = ops.ones(shape=(2,)) + with SymbolicScope(): + loss = compute_loss(y, y_pred) + self.assertAllClose(loss, np.zeros((2,))) + + loss = compute_loss(y, y_pred) + self.assertAllClose(loss, 2 * np.ones((2,))) diff --git a/keras/src/backend/common/tensor_attributes.py b/keras/src/backend/common/tensor_attributes.py new file mode 100644 index 000000000000..8d3496198e1d --- /dev/null +++ b/keras/src/backend/common/tensor_attributes.py @@ -0,0 +1,36 @@ +import weakref + +from keras.src.backend.common import global_state + + +def _clear_tensor_attr(tensor_id, attr): + attr_dict = global_state.get_global_attribute(f"{attr}_dict") + if attr_dict is not None and tensor_id in attr_dict: + del attr_dict[tensor_id] + + +def set_tensor_attr(tensor, attr, value): + try: + setattr(tensor, attr, value) + except AttributeError: + attr_dict = global_state.get_global_attribute(f"{attr}_dict") + if attr_dict is None: + if value is None: + return + attr_dict = {} + global_state.set_global_attribute(f"{attr}_dict", attr_dict) + if value is not None: + attr_dict[id(tensor)] = value + weakref.finalize(tensor, _clear_tensor_attr, id(tensor), attr) + elif id(tensor) in attr_dict: + del attr_dict[id(tensor)] + + +def get_tensor_attr(tensor, attr): + if not hasattr(tensor, attr): + attr_dict = global_state.get_global_attribute(f"{attr}_dict") + if attr_dict is not None: + return attr_dict.get(id(tensor), None) + else: + return None + return getattr(tensor, attr, None) diff --git a/keras/src/backend/common/thread_safe_test.py b/keras/src/backend/common/thread_safe_test.py new file mode 100644 index 000000000000..b5775cca3586 --- /dev/null +++ b/keras/src/backend/common/thread_safe_test.py @@ -0,0 +1,29 @@ +import concurrent + +import numpy as np + +from keras.src import backend +from keras.src import ops +from keras.src import testing + + +class TestThreadSafe(testing.TestCase): + def test_is_thread_safe(self): + if backend.IS_THREAD_SAFE: + executor = concurrent.futures.ThreadPoolExecutor() + + def sum(x, axis): + return ops.sum(x, axis=axis) + + futures = [] + + for i in range(10000): + futures.clear() + x = ops.convert_to_tensor(np.random.rand(100, 100)) + futures.append(executor.submit(sum, x, 1)) + x = ops.convert_to_tensor(np.random.rand(100)) + futures.append(executor.submit(sum, x, 0)) + concurrent.futures.wait( + futures, return_when=concurrent.futures.ALL_COMPLETED + ) + [future.result() for future in futures] diff --git a/keras/src/backend/common/variables.py b/keras/src/backend/common/variables.py new file mode 100644 index 000000000000..84289a35f64c --- /dev/null +++ b/keras/src/backend/common/variables.py @@ -0,0 +1,686 @@ +import numpy as np + +from keras.src import backend +from keras.src.api_export import keras_export +from keras.src.backend import config +from keras.src.backend.common import dtypes +from keras.src.backend.common import global_state +from keras.src.backend.common.name_scope import current_path +from keras.src.backend.common.stateless_scope import get_stateless_scope +from keras.src.backend.common.stateless_scope import in_stateless_scope +from keras.src.utils.module_utils import tensorflow as tf +from keras.src.utils.naming import auto_name + + +class Variable: + """Represents a backend-agnostic variable in Keras. + + A `Variable` acts as a container for state. It holds a tensor value and can + be updated. With the JAX backend, variables are used to implement + "functionalization", the pattern of lifting stateful operations out of + a piece of computation to turn it into a stateless function. + + Args: + initializer: Initial value or callable for initialization. + If a callable is used, it should take the arguments + `shape` and `dtype`. + shape: Optional. Tuple for the variable's shape. + Required if `initializer` is a callable. + dtype: Optional. Data type of the variable. Defaults to the global float + dtype type (`"float32"` if never configured). + trainable: Optional. Boolean indicating if variable is trainable. + Defaults to `True`. + autocast: Optional. Boolean indicating whether the variable supports + autocasting. If `True`, the layer may first convert the variable + to the compute data type when accessed. Defaults to `True`. + aggregation: Optional string, one of `None`, `"none"`, `"mean"`, + `"sum"` or `"only_first_replica"` specifying how a distributed + variable will be aggregated. This serves as a semantic annotation, + to be taken into account by downstream backends or users. Defaults + to `"none"`. + name: Optional. A unique name for the variable. Automatically generated + if not set. + + Attributes: + shape: The shape of the variable (tuple of integers). + ndim: The number of dimensions of the variable (integer). + dtype: The data type of the variable (string). + trainable: Whether the variable is trainable (boolean). + autocast: Whether the variable supports autocasting (boolean). + aggregation: How a distributed variable will be aggregated (string). + value: The current value of the variable (NumPy array or tensor). + name: The name of the variable (string). + path: The path of the variable within the Keras model or layer (string). + kwargs: Additional backend-specific keyword arguments. + + Examples: + + **Initializing a `Variable` with a NumPy array:** + + ```python + import numpy as np + import keras + initial_array = np.ones((3, 3)) + variable_from_array = keras.Variable(initializer=initial_array) + ``` + + **Using a Keras initializer to create a `Variable`:** + + ```python + from keras.src.initializers import Ones + variable_from_initializer = keras.Variable( + initializer=Ones(), shape=(3, 3), dtype="float32" + ) + ``` + + **Updating the value of a `Variable`:** + + ```python + new_value = np.zeros((3, 3), dtype="float32") + variable_from_array.assign(new_value) + ``` + + **Marking a `Variable` as non-trainable:** + + ```python + non_trainable_variable = keras.Variable( + initializer=np.ones((3, 3), dtype="float32"), trainable=False + ) + ``` + """ + + def __init__( + self, + initializer, + shape=None, + dtype=None, + trainable=True, + autocast=True, + aggregation="none", + synchronization="auto", + name=None, + **kwargs, + ): + del kwargs + name = name or auto_name(self.__class__.__name__) + if not isinstance(name, str) or "/" in name: + raise ValueError( + "Argument `name` must be a string and " + "cannot contain character `/`. " + f"Received: name={name}" + ) + if aggregation not in ( + None, + "none", + "mean", + "sum", + "only_first_replica", + ): + raise ValueError( + "Invalid value for argument `aggregation`. Expected " + "one of `None`, `'none'`, `'mean'`, `'sum'`, " + "`'only_first_replica'`. " + f"Received: aggregation={aggregation}" + ) + if aggregation is None: + aggregation = "none" + if synchronization not in ( + None, + "none", + "on_read", + "on_write", + "auto", + ): + raise ValueError( + "Invalid value for argument `synchronization`. Expected " + "one of `None`, `'none'`, `'on_read'`, `'on_write'`, " + "`'auto'`. " + f"Received: synchronization={synchronization}" + ) + if synchronization is None: + synchronization = "none" + self._name = name + parent_path = current_path() + if parent_path: + self._path = f"{parent_path}/{name}" + else: + self._path = name + self._shape = None + self._initializer = None + self._regularizer = None + self._constraint = None + self._trainable = bool(trainable) + self._autocast = bool(autocast) + self._aggregation = aggregation + self._synchronization = synchronization + # `self._overwrite_with_gradient` is an internal property to determine + # whether this variable should be overwritten by the computed gradient. + # Ref: https://github.com/google/flax/blob/main/flax/linen/fp8_ops.py + self._overwrite_with_gradient = False + if isinstance(initializer, str): + from keras.src import initializers + + initializer = initializers.get(initializer) + if callable(initializer): + if shape is None: + raise ValueError( + "When creating a Variable from an initializer, " + "the `shape` argument should be specified. " + f"Received: initializer={initializer} " + f"and shape={shape}" + ) + else: + initializer = self._convert_to_tensor(initializer, dtype=dtype) + # If dtype is None and `initializer` is an array, use its dtype. + if dtype is None: + dtype = initializer.dtype + self._dtype = standardize_dtype(dtype) + + if in_stateless_scope(): + if callable(initializer): + self._value = None + self._initializer = initializer + self._shape = self._validate_shape(shape) + register_uninitialized_variable(self) + else: + raise ValueError( + "You are attempting to create a variable " + "while in a stateless scope. This is disallowed. " + "Make sure that all variables are created " + "before you start using your layer/model objects.\n\n" + "In some cases, you might be seeing this error " + "because you need to " + "implement a `def build(self, input_shape)` method " + "on your layer/model, which will " + "create its variables.\n\n" + "In some other cases, you might be seeing this error " + "because you are instantiating a `Variable` and " + "assigning it to a layer without going through " + "self.add_variable()/self.add_weight(). Always prefer " + "using these methods " + "(with a `shape` and `initializer` argument)." + ) + else: + if callable(initializer): + self._shape = self._validate_shape(shape) + self._initialize_with_initializer(initializer) + else: + self._initialize(initializer) + self._shape = self._validate_shape(self._value.shape) + self._ndim = len(self._shape) + + def _deferred_initialize(self): + if self._value is not None: + # If NNX is enabled, it's possible the variable was already + # initialized by a concrete call. In this case, _deferred_initialize + # returns early and does not raise an error. + if config.is_nnx_enabled(): + return + raise ValueError(f"Variable {self.path} is already initialized.") + + if in_stateless_scope(): + raise ValueError( + "You are attempting to initialize a variable " + "while in a stateless scope. This is disallowed. " + "Make sure that all variables are initialized " + "before you start using your layer/model objects." + ) + self._initialize_with_initializer(self._initializer) + self._initializer = None + + def _validate_shape(self, shape): + shape = standardize_shape(shape) + if None in shape: + raise ValueError( + "Shapes used to initialize variables must be " + "fully-defined (no `None` dimensions). Received: " + f"shape={shape} for variable path='{self.path}'" + ) + return shape + + def _maybe_autocast(self, value): + autocast_scope = get_autocast_scope() + if self._autocast and autocast_scope is not None: + return autocast_scope.maybe_cast(value) + return value + + def numpy(self): + return np.array(self) + + @property + def aggregation(self): + """The strategy for aggregating this variable.""" + return self._aggregation + + @property + def synchronization(self): + """The strategy for synchronizing this variable.""" + return self._synchronization + + @property + def value(self): + """The current value of the variable (numpy array or backend tensor).""" + if in_stateless_scope(): + scope = get_stateless_scope() + value = scope.get_current_value(self) + if value is not None: + return self._maybe_autocast(value) + if self._value is None: + # Uninitialized variable. Return a placeholder. + # This is fine because it's only ever used + # in during shape inference / graph tracing + # (anything else would be a bug, to be fixed.) + return self._maybe_autocast( + self._initializer(self._shape, dtype=self._dtype) + ) + return self._maybe_autocast(self._value) + + def assign(self, value): + value = self._convert_to_tensor(value, dtype=self.dtype) + if not shape_equal(value.shape, self.shape): + raise ValueError( + "The shape of the target variable and " + "the shape of the target value in " + "`variable.assign(value)` must match. " + f"variable.shape={self.value.shape}, " + f"Received: value.shape={value.shape}. " + f"Target variable: {self}" + ) + if in_stateless_scope(): + scope = get_stateless_scope() + scope.add_update((self, value)) + else: + self._direct_assign(value) + return value + + def assign_add(self, value): + return self.assign(self + value) + + def assign_sub(self, value): + return self.assign(self - value) + + @property + def dtype(self): + """The data type of the variable.""" + autocast_scope = get_autocast_scope() + if ( + self._autocast + and autocast_scope is not None + and is_float_dtype(self._dtype) + ): + dtype = autocast_scope.dtype + else: + dtype = self._dtype + return backend.standardize_dtype(dtype) + + @property + def shape(self): + """The shape of the variable.""" + return self._shape + + @property + def ndim(self): + """The number of dimensions of the variable.""" + return self._ndim + + @property + def trainable(self): + """Whether the variable is trainable.""" + return self._trainable + + @trainable.setter + def trainable(self, value): + self._trainable = bool(value) + + @property + def name(self): + """The name of the variable.""" + return self._name + + @property + def path(self): + """The path of the variable within the Keras model or layer.""" + return self._path + + @property + def overwrite_with_gradient(self): + """Whether this variable should be overwritten by the gradient. + + This property is designed for a special case where we want to overwrite + the variable directly with its computed gradient. For example, in float8 + training, new `scale` and `amax_history` are computed as gradients, and + we want to overwrite them directly instead of following the typical + procedure such as gradient descent with a learning rate, gradient + clipping and weight decaying. + """ + return self._overwrite_with_gradient + + @overwrite_with_gradient.setter + def overwrite_with_gradient(self, value): + if not isinstance(value, bool): + raise TypeError( + "`overwrite_with_gradient` must be a boolean. " + f"Received: {value}" + ) + self._overwrite_with_gradient = value + + @property + def regularizer(self): + return self._regularizer + + @regularizer.setter + def regularizer(self, value): + from keras.src.regularizers import Regularizer + + if value is not None and not isinstance(value, Regularizer): + raise ValueError( + "Invalid value for attribute `regularizer`. Expected an " + "instance of `keras.regularizers.Regularizer`, or `None`. " + f"Received: regularizer={value}" + ) + self._regularizer = value + + @property + def constraint(self): + return self._constraint + + @constraint.setter + def constraint(self, value): + from keras.src.constraints import Constraint + + if value is not None and not isinstance(value, Constraint): + raise ValueError( + "Invalid value for attribute `constraint`. Expected an " + "instance of `keras.constraints.Constraint`, or `None`. " + f"Received: constraint={value}" + ) + self._constraint = value + + def __repr__(self): + value = None + if hasattr(self, "_value") and self._value is not None: + value = backend.core.convert_to_numpy(self._value) + value_str = f", value={value}" if value is not None else "" + return ( + f"" + ) + + def _initialize(self, value): + raise NotImplementedError + + def _initialize_with_initializer(self, initializer): + value = self._convert_to_tensor( + initializer(self._shape, dtype=self._dtype) + ) + self._initialize(value) + + def _convert_to_tensor(self, value, dtype=None): + raise NotImplementedError + + def __getitem__(self, idx): + return self.value.__getitem__(idx) + + def __int__(self): + if self.ndim > 0: + raise TypeError( + "Only scalar arrays can be converted to Python scalars. " + f"Got: shape={self.shape}" + ) + return int(self.value) + + def __float__(self): + if self.ndim > 0: + raise TypeError( + "Only scalar arrays can be converted to Python scalars. " + f"Got: shape={self.shape}" + ) + return float(self.value) + + def __array__(self, dtype=None): + # We can't directly use self.value.__array__ here because of scalar. + # Numpy require this method to return as array like object. In the case + # of scalar, it will fail the type checking from numpy. We need to + # return a 0d array via numpy. + return np.asarray(self.value.__array__(dtype)) + + def __bool__(self): + raise TypeError("A Keras Variable cannot be used as a boolean.") + + def __neg__(self): + return self.value.__neg__() + + def __pos__(self): + return self.value + + def __abs__(self): + return self.value.__abs__() + + def __invert__(self): + return self.value.__invert__() + + def __eq__(self, other): + return backend.numpy.equal(self.value, other) + + def __ne__(self, other): + return backend.numpy.not_equal(self.value, other) + + def __lt__(self, other): + return backend.numpy.less(self.value, other) + + def __le__(self, other): + return backend.numpy.less_equal(self.value, other) + + def __gt__(self, other): + return backend.numpy.greater(self.value, other) + + def __ge__(self, other): + return backend.numpy.greater_equal(self.value, other) + + def __add__(self, other): + return backend.numpy.add(self.value, other) + + def __radd__(self, other): + return backend.numpy.add(other, self.value) + + def __sub__(self, other): + return backend.numpy.subtract(self.value, other) + + def __rsub__(self, other): + return backend.numpy.subtract(other, self.value) + + def __mul__(self, other): + return backend.numpy.multiply(self.value, other) + + def __rmul__(self, other): + return backend.numpy.multiply(other, self.value) + + def __truediv__(self, other): + return backend.numpy.true_divide(self.value, other) + + def __rtruediv__(self, other): + return backend.numpy.true_divide(other, self.value) + + def __floordiv__(self, other): + return backend.numpy.floor_divide(self.value, other) + + def __rfloordiv__(self, other): + return backend.numpy.floor_divide(other, self.value) + + def __mod__(self, other): + return backend.numpy.mod(self.value, other) + + def __rmod__(self, other): + return backend.numpy.mod(other, self.value) + + def __pow__(self, other): + return backend.numpy.power(self.value, other) + + def __rpow__(self, other): + return backend.numpy.power(other, self.value) + + def __matmul__(self, other): + return backend.numpy.matmul(self.value, other) + + def __rmatmul__(self, other): + return backend.numpy.matmul(other, self.value) + + def __and__(self, other): + return backend.numpy.logical_and(self.value, other) + + def __rand__(self, other): + return backend.numpy.logical_and(other, self.value) + + def __or__(self, other): + return backend.numpy.logical_or(self.value, other) + + def __ror__(self, other): + return backend.numpy.logical_or(other, self.value) + + def __xor__(self, other): + return backend.numpy.logical_xor(self.value, other) + + def __rxor__(self, other): + return backend.numpy.logical_xor(other, self.value) + + def __round__(self, ndigits=None): + decimals = ndigits or 0 + return backend.numpy.round(self.value, decimals=decimals) + + +def register_uninitialized_variable(variable): + uninitialized_variables = global_state.get_global_attribute( + "uninitialized_variables", [], set_to_default=True + ) + uninitialized_variables.append(variable) + + +def initialize_all_variables(): + collection = global_state.get_global_attribute("uninitialized_variables") + if collection: + for v in collection: + v._deferred_initialize() + global_state.set_global_attribute("uninitialized_variables", []) + + +@keras_export( + ["keras.utils.standardize_dtype", "keras.backend.standardize_dtype"] +) +def standardize_dtype(dtype): + if dtype is None: + return config.floatx() + dtype = dtypes.PYTHON_DTYPES_MAP.get(dtype, dtype) + if hasattr(dtype, "name"): + dtype = dtype.name + elif hasattr(dtype, "__name__"): + dtype = dtype.__name__ + elif hasattr(dtype, "__str__") and ( + "torch" in str(dtype) or "jax.numpy" in str(dtype) + ): + dtype = str(dtype).split(".")[-1] + + if dtype not in dtypes.ALLOWED_DTYPES: + raise ValueError(f"Invalid dtype: {dtype}") + return dtype + + +def standardize_shape(shape): + if not isinstance(shape, tuple): + if shape is None: + raise ValueError("Undefined shapes are not supported.") + if not hasattr(shape, "__iter__"): + raise ValueError(f"Cannot convert '{shape}' to a shape.") + if config.backend() == "tensorflow": + if isinstance(shape, tf.TensorShape): + # `tf.TensorShape` may contain `Dimension` objects. + # We need to convert the items in it to either int or `None` + shape = shape.as_list() + shape = tuple(shape) + + if config.backend() == "jax": + # Replace `_DimExpr` (dimension expression) with None + from jax import export as jax_export + + shape = tuple( + None if jax_export.is_symbolic_dim(d) else d for d in shape + ) + + if config.backend() == "torch": + # `shape` might be `torch.Size`. We need to convert the items in it to + # either int or `None` + shape = tuple(map(lambda x: int(x) if x is not None else None, shape)) + + for e in shape: + if e is None: + continue + if not is_int_dtype(type(e)): + raise ValueError( + f"Cannot convert '{shape}' to a shape. " + f"Found invalid entry '{e}' of type '{type(e)}'. " + ) + if e < 0: + raise ValueError( + f"Cannot convert '{shape}' to a shape. " + "Negative dimensions are not allowed." + ) + return shape + + +def shape_equal(a_shape, b_shape): + """Return whether a_shape == b_shape (allows None entries).""" + if len(a_shape) != len(b_shape): + return False + for e1, e2 in zip(a_shape, b_shape): + if e1 is not None and e2 is not None and e1 != e2: + return False + return True + + +@keras_export("keras.backend.is_float_dtype") +def is_float_dtype(dtype): + dtype = standardize_dtype(dtype) + return dtype.startswith("float") or dtype.startswith("bfloat") + + +@keras_export("keras.backend.is_int_dtype") +def is_int_dtype(dtype): + dtype = standardize_dtype(dtype) + return dtype.startswith("int") or dtype.startswith("uint") + + +def get_autocast_scope(): + return global_state.get_global_attribute("autocast_scope") + + +class AutocastScope: + """Context manager that enables the autocasting of float variables. + + Under this context manager, float `Variables`s will be cast to `dtype` + (note that `dtype` must also be float). + """ + + def __init__(self, dtype): + if dtype is not None: + dtype = standardize_dtype(dtype) + if not is_float_dtype(dtype): + raise ValueError( + "`AutocastScope` can only be used with " + "a floating-point target dtype, such as 'float16'. " + f"Received: dtype={dtype}" + ) + self.dtype = dtype + self.original_scope = None + + def maybe_cast(self, value): + from keras.src import backend + + if self.dtype is not None and is_float_dtype(value.dtype): + return backend.cast(value, dtype=self.dtype) + return value + + def __enter__(self): + self.original_scope = get_autocast_scope() + global_state.set_global_attribute("autocast_scope", self) + + def __exit__(self, *args, **kwargs): + global_state.set_global_attribute("autocast_scope", self.original_scope) diff --git a/keras/src/backend/common/variables_test.py b/keras/src/backend/common/variables_test.py new file mode 100644 index 000000000000..519a69bb3ade --- /dev/null +++ b/keras/src/backend/common/variables_test.py @@ -0,0 +1,1266 @@ +import itertools + +import numpy as np +import pytest +from absl.testing import parameterized + +from conftest import skip_if_backend +from keras.src import backend +from keras.src import initializers +from keras.src import ops +from keras.src.backend.common import dtypes +from keras.src.backend.common.variables import AutocastScope +from keras.src.backend.common.variables import shape_equal +from keras.src.backend.common.variables import standardize_dtype +from keras.src.backend.common.variables import standardize_shape +from keras.src.testing import test_case +from keras.src.testing.test_utils import named_product + + +class VariableInitializationTest(test_case.TestCase): + """Tests for Variable.__init__()""" + + def test_deferred_initialization(self): + """Tests deferred initialization of variables.""" + with backend.StatelessScope(): + v = backend.Variable( + initializer=initializers.RandomNormal(), shape=(2, 2) + ) + self.assertEqual(v._value, None) + # Variables can nevertheless be accessed + _ = v + 1 + self.assertEqual(v._value.shape, (2, 2)) + + with self.assertRaisesRegex(ValueError, "while in a stateless scope"): + with backend.StatelessScope(): + v = backend.Variable(initializer=0) + + def test_variable_initialization_with_numpy_array(self): + """Test variable init with numpy array initializer.""" + v = backend.Variable( + initializer=np.ones((2, 2), dtype=np.int32), trainable=False + ) + self.assertAllClose(v.value, np.ones((2, 2))) + self.assertEqual(v.dtype, "int32") + + def test_variable_initialization_with_native_array(self): + """Test variable init with native array initializer.""" + v = backend.Variable( + initializer=ops.ones((2, 2), dtype="int32"), trainable=False + ) + self.assertAllClose(v.value, np.ones((2, 2))) + self.assertEqual(v.dtype, "int32") + + def test_variable_initialization_with_python_array(self): + """Test variable init with python array initializer.""" + v = backend.Variable(initializer=[[1, 1], [1, 1]], trainable=False) + self.assertAllClose(v.value, np.ones((2, 2))) + self.assertEqual(v.dtype, "int32") + v = backend.Variable( + initializer=[[1.0, 1.0], [1.0, 1.0]], trainable=False + ) + self.assertAllClose(v.value, np.ones((2, 2))) + self.assertEqual(v.dtype, "float32") + + def test_variable_initialization_with_lambda_expression(self): + # Test Python number + v = backend.Variable( + initializer=lambda *a, **kw: 1.0, + shape=(), + dtype="float32", + ) + self.assertAllClose(v.value, 1.0) + self.assertEqual(v.dtype, "float32") + + # Test Python array + v = backend.Variable( + initializer=lambda *a, **kw: [1.0], + shape=(1,), + dtype="float32", + ) + self.assertAllClose(v.value, np.ones((1,))) + self.assertEqual(v.dtype, "float32") + + # Test numpy array + v = backend.Variable( + initializer=lambda *a, **kw: np.ones((1,)), + shape=(1,), + dtype="float32", + ) + self.assertAllClose(v.value, np.ones((1,))) + self.assertEqual(v.dtype, "float32") + + # Test backend array + v = backend.Variable( + initializer=lambda *a, **kw: ops.ones((1,)), + shape=(1,), + dtype="float32", + ) + self.assertAllClose(v.value, np.ones((1,))) + self.assertEqual(v.dtype, "float32") + + def test_variable_initialization_with_strings(self): + """Test variable init with non-callable initializer.""" + v = backend.Variable(initializer="ones", shape=(2, 2)) + self.assertAllClose(v.value, np.ones((2, 2))) + + def test_variable_initialization_with_non_trainable(self): + """Test variable initialization with non-trainable flag.""" + v = backend.Variable(initializer=np.ones((2, 2)), trainable=False) + self.assertFalse(v.trainable) + + def test_variable_initialization_without_shape(self): + """Test variable init without a shape.""" + with self.assertRaisesRegex( + ValueError, + "When creating a Variable from an initializer, the `shape` ", + ): + backend.Variable(initializer=initializers.RandomNormal()) + + def test_deferred_initialize_already_initialized(self): + """Test deferred init on an already initialized variable.""" + v = backend.Variable(initializer=np.ones((2, 2))) + with self.assertRaisesRegex( + ValueError, f"Variable {v.path} is already initialized." + ): + v._deferred_initialize() + + def test_variable_initialize(self): + """Test initializing a variable.""" + v = backend.Variable(initializer=np.array([1.0, 2.0, 3.0])) + init_value = np.array([4.0, 5.0, 6.0]) + v._initialize(value=init_value) + self.assertAllClose(v.value, init_value) + + def test_variable_without_shape_from_callable_initializer(self): + """Test that Variable raises error + if shape is not provided for callable initializer.""" + with self.assertRaisesRegex( + ValueError, "When creating a Variable from an initializer" + ): + backend.Variable(initializer=lambda: np.ones((2, 2))) + + +class VariablePropertiesTest(test_case.TestCase): + """Tests for Variable._deferred_initialize Variable._maybe_autocast""" + + @skip_if_backend( + "openvino", "Can not constant fold eltwise node by CPU plugin" + ) + def test_deferred_assignment(self): + """Tests deferred assignment to variables.""" + with backend.StatelessScope() as scope: + v = backend.Variable( + initializer=initializers.RandomNormal(), shape=(2, 2) + ) + self.assertEqual(v._value, None) + v.assign(np.zeros((2, 2))) + v.assign_add(2 * np.ones((2, 2))) + v.assign_sub(np.ones((2, 2))) + out = scope.get_current_value(v) + self.assertAllClose(out, np.ones((2, 2))) + + def test_trainable_setter(self): + """Tests the trainable setter.""" + v = backend.Variable( + initializer=initializers.RandomNormal(), + shape=(2, 2), + ) + self.assertTrue(v.trainable) + v.trainable = False + self.assertFalse(v.trainable) + + if backend.backend() == "torch": + v.trainable = True + self.assertTrue(v._value.requires_grad) + v.trainable = False + self.assertFalse(v._value.requires_grad) + + def test_autocasting(self): + """Tests autocasting of float variables.""" + v = backend.Variable( + initializer=initializers.RandomNormal(), + shape=(2, 2), + dtype="float32", + ) + self.assertEqual(v.dtype, "float32") + self.assertEqual(backend.standardize_dtype(v.value.dtype), "float32") + with AutocastScope("float16"): + self.assertEqual( + backend.standardize_dtype(v.value.dtype), "float16" + ) + self.assertEqual(backend.standardize_dtype(v.value.dtype), "float32") + + # Test non-float variables are not affected + v = backend.Variable( + initializer=initializers.Ones(), + shape=(2, 2), + dtype="int32", + trainable=False, + ) + self.assertEqual(v.dtype, "int32") + self.assertEqual(backend.standardize_dtype(v.value.dtype), "int32") + + with AutocastScope("float16"): + self.assertEqual(backend.standardize_dtype(v.value.dtype), "int32") + + # Test autocast argument + v = backend.Variable( + initializer=initializers.RandomNormal(), + shape=(2, 2), + dtype="float32", + autocast=False, + ) + self.assertEqual(v.dtype, "float32") + self.assertEqual(backend.standardize_dtype(v.value.dtype), "float32") + with AutocastScope("float16"): + self.assertEqual( + backend.standardize_dtype(v.value.dtype), + "float32", # ignore AutocastScope + ) + self.assertEqual(backend.standardize_dtype(v.value.dtype), "float32") + + @parameterized.parameters( + *( + ( + dtype + for dtype in dtypes.ALLOWED_DTYPES + if dtype not in ["string", "complex64", "complex28"] + ) + ) + ) + def test_standardize_dtype(self, dtype): + """Tests standardize_dtype for all ALLOWED_DTYPES except string.""" + if backend.backend() == "torch" and dtype in ( + "uint16", + "uint32", + "uint64", + "complex64", + "complex128", + ): + self.skipTest(f"torch backend does not support dtype {dtype}") + + if backend.backend() == "jax": + if dtype in ("complex128",): + self.skipTest(f"jax backend does not support dtype {dtype}") + import jax + + if not jax.config.x64_enabled and "64" in dtype: + self.skipTest( + f"jax backend does not support {dtype} without x64 enabled" + ) + + if backend.backend() == "openvino" and dtype in ( + "complex64", + "complex128", + ): + self.skipTest(f"openvino backend does not support dtype {dtype}") + + x = backend.convert_to_tensor(np.zeros(()), dtype) + actual = standardize_dtype(x.dtype) + self.assertEqual(actual, dtype) + + def test_standardize_dtype_with_torch_dtype(self): + """Tests dtype standardization with PyTorch dtypes.""" + import torch + + x = torch.randn(4, 4) + backend.standardize_dtype(x.dtype) + + def test_name_validation(self): + """Tests validation of variable names.""" + with self.assertRaisesRegex( + ValueError, "Argument `name` must be a string" + ): + backend.Variable( + initializer=initializers.RandomNormal(), name=12345 + ) + + with self.assertRaisesRegex(ValueError, "cannot contain character `/`"): + backend.Variable( + initializer=initializers.RandomNormal(), name="invalid/name" + ) + + def test_standardize_shape_with_none(self): + """Tests standardizing shape with None.""" + with self.assertRaisesRegex( + ValueError, "Undefined shapes are not supported." + ): + standardize_shape(None) + + def test_standardize_shape_with_non_iterable(self): + """Tests shape standardization with non-iterables.""" + with self.assertRaisesRegex( + ValueError, "Cannot convert '42' to a shape." + ): + standardize_shape(42) + + def test_standardize_shape_with_valid_input(self): + """Tests standardizing shape with valid input.""" + shape = [3, 4, 5] + standardized_shape = standardize_shape(shape) + self.assertEqual(standardized_shape, (3, 4, 5)) + + def test_standardize_shape_with_negative_entry(self): + """Tests standardizing shape with negative entries.""" + with self.assertRaisesRegex( + ValueError, + "Cannot convert '\\(3, 4, -5\\)' to a shape. Negative dimensions", + ): + standardize_shape([3, 4, -5]) + + def test_shape_equal_length_mismatch(self): + """Test mismatch in lengths of shapes.""" + self.assertFalse(shape_equal((3, 2), (3, 2, 4))) + self.assertFalse(shape_equal((), (3,))) + self.assertFalse(shape_equal((3, 2, 4, 5), (3, 2, 4))) + + def test_autocast_scope_with_non_float_dtype(self): + """Tests autocast scope with non-float dtype.""" + with self.assertRaisesRegex( + ValueError, + "`AutocastScope` can only be used with a floating-point", + ): + _ = AutocastScope("int32") + + def test_variable_path_creation(self): + """Test path creation for a variable.""" + v = backend.Variable(initializer=np.ones((2, 2)), name="test_var") + self.assertEqual(v.path, "test_var") + + with backend.name_scope("test_scope"): + v = backend.Variable(initializer=np.ones((2, 2)), name="test_var") + self.assertEqual(v.path, "test_scope/test_var") + + def test_overwrite_with_gradient_setter(self): + v = backend.Variable( + initializer=initializers.RandomNormal(), + shape=(2, 2), + ) + self.assertFalse(v.overwrite_with_gradient) + v.overwrite_with_gradient = True + self.assertTrue(v.overwrite_with_gradient) + + with self.assertRaisesRegex(TypeError, "must be a boolean."): + v.overwrite_with_gradient = "true" + + +class VariableNumpyValueAndAssignmentTest(test_case.TestCase): + """tests for Variable.numpy(), Variable.value() and Variable.assign()""" + + def test_variable_numpy(self): + """Test retrieving the value of a variable as a numpy array.""" + v = backend.Variable(initializer=np.array([1.0, 2.0, 3.0])) + self.assertIsInstance(v.numpy(), np.ndarray) + self.assertAllClose(v.numpy(), np.array([1.0, 2.0, 3.0])) + + @pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="Tests for MirroredVariable under tf backend", + ) + def test_variable_numpy_scalar(self): + from keras.src.utils.module_utils import tensorflow as tf + + strategy = tf.distribute.MirroredStrategy(["cpu:0", "cpu:1"]) + with strategy.scope(): + v = backend.Variable(initializer=0.0) + + np_value = backend.convert_to_numpy(v) + self.assertIsInstance(np_value, np.ndarray) + self.assertAllClose(np_value, 0.0) + + def test_variable_value(self): + """Test retrieving the value of a variable.""" + v = backend.Variable(initializer=np.array([1.0, 2.0, 3.0])) + self.assertAllClose(v.value, np.array([1.0, 2.0, 3.0])) + + def test_variable_assign(self): + """Test assigning a new value to a variable.""" + v = backend.Variable(initializer=np.array([1.0, 2.0, 3.0])) + v.assign(np.array([4.0, 5.0, 6.0])) + self.assertAllClose(v.value, np.array([4.0, 5.0, 6.0])) + + def test_variable_assign_return(self): + """Test assigning a new value and returning.""" + v = backend.Variable(initializer=np.array([1.0, 2.0, 3.0])) + r = v.assign(np.array([4.0, 5.0, 6.0])) + self.assertAllClose(r, np.array([4.0, 5.0, 6.0])) + + def test_variable_assign_add(self): + """Test the assign_add method on a variable.""" + v = backend.Variable(initializer=np.array([1.0, 2.0, 3.0])) + v.assign_add(np.array([1.0, 1.0, 1.0])) + self.assertAllClose(v.value, np.array([2.0, 3.0, 4.0])) + + def test_variable_assign_add_return(self): + """Test assign_add a new value and returning.""" + v = backend.Variable(initializer=np.array([1.0, 2.0, 3.0])) + r = v.assign_add(np.array([1.0, 1.0, 1.0])) + self.assertAllClose(r, np.array([2.0, 3.0, 4.0])) + + def test_variable_assign_sub(self): + """Test the assign_sub method on a variable.""" + v = backend.Variable(initializer=np.array([2.0, 3.0, 4.0])) + v.assign_sub(np.array([1.0, 1.0, 1.0])) + self.assertAllClose(v.value, np.array([1.0, 2.0, 3.0])) + + def test_variable_assign_sub_return(self): + """Test assign_sub a new value and returning.""" + v = backend.Variable(initializer=np.array([2.0, 3.0, 4.0])) + r = v.assign_sub(np.array([1.0, 1.0, 1.0])) + self.assertAllClose(r, np.array([1.0, 2.0, 3.0])) + + def test_deferred_initialize_within_stateless_scope(self): + """Test deferred init within a stateless scope.""" + with backend.StatelessScope(): + v = backend.Variable( + initializer=initializers.RandomNormal(), shape=(2, 2) + ) + with self.assertRaisesRegex( + ValueError, + "You are attempting to initialize a variable " + "while in a stateless scope. This is disallowed.", + ): + v._deferred_initialize() + + +class VariableDtypeShapeNdimRepr(test_case.TestCase): + """tests for dtype, shape, ndim, __repr__""" + + def test_variable_dtype(self): + """Test retrieving the dtype of a variable.""" + v = backend.Variable( + initializer=np.array([1.0, 2.0, 3.0], dtype=np.float32) + ) + self.assertEqual(v.dtype, "float32") + + def test_variable_shape(self): + """Test retrieving the shape of a variable.""" + v = backend.Variable(initializer=np.array([[1.0, 2.0], [3.0, 4.0]])) + self.assertEqual(v.shape, (2, 2)) + + def test_variable_ndim(self): + """Test retrieving the number of dimensions of a variable.""" + v = backend.Variable(initializer=np.array([[1.0, 2.0], [3.0, 4.0]])) + self.assertEqual(v.ndim, 2) + + def test_variable_repr(self): + """Test the string representation of a variable.""" + v = backend.Variable( + initializer=np.array([1.0, 2.0, 3.0], dtype=np.float32), + name="test_var", + ) + expected_repr = ( + "" + ) + self.assertEqual(repr(v), expected_repr) + + # Test with `backend.StatelessScope()` + with backend.StatelessScope(): + v = backend.Variable( + initializer="zeros", shape=(3,), name="test_var" + ) + expected_repr = ( + "" + ) + self.assertEqual(repr(v), expected_repr) + + def test_variable_getitem(self): + """Test getting an item from a variable.""" + v = backend.Variable(initializer=np.array([1.0, 2.0, 3.0])) + self.assertEqual(v[0], 1) + + def test_variable_initialize(self): + """Test initializing a variable.""" + v = backend.Variable(initializer=np.array([1.0, 2.0, 3.0])) + init_value = np.array([4.0, 5.0, 6.0]) + v._initialize(value=init_value) + self.assertAllClose(v.value, init_value) + + def test_variable_convert_to_tensor(self): + """Test converting a variable to a tensor.""" + v = backend.Variable(initializer=np.array([1.0, 2.0, 3.0])) + self.assertAllClose( + v._convert_to_tensor(v.value), np.array([1.0, 2.0, 3.0]) + ) + + def test_variable_convert_to_tensor_with_dtype(self): + """Test converting a variable to a tensor with a dtype.""" + v = backend.Variable(initializer=np.array([1.0, 2.0, 3.0])) + self.assertAllClose( + v._convert_to_tensor(v.value, dtype="float32"), + np.array([1.0, 2.0, 3.0]), + ) + + def test_variable_array(self): + """Test converting a variable to an array.""" + v = backend.Variable(initializer=np.array([1.0, 2.0, 3.0])) + self.assertAllClose(v.__array__(), np.array([1.0, 2.0, 3.0])) + + +class VariableOpsCorrectnessTest(test_case.TestCase): + """Tests for operations on Variable.""" + + def test_int(self): + v = backend.Variable(initializer=np.array(-1.1)) + self.assertAllClose(int(v), np.array(-1)) + + def test_float(self): + v = backend.Variable(initializer=np.array(-1.1)) + self.assertAllClose(float(v), np.array(-1.1)) + + def test__neg__(self): + """Test negating a variable.""" + v = backend.Variable(initializer=np.array([-1.0, 2.0]), trainable=False) + self.assertAllClose(v.__neg__(), np.array([1.0, -2.0])) + + def test__abs__(self): + """Test absolute value on a variable.""" + v = backend.Variable(initializer=np.array([-1.0, 2.0]), trainable=False) + self.assertAllClose(v.__abs__(), np.array([1.0, 2.0])) + + def test__invert__(self): + """Test bitwise not on a variable.""" + v = backend.Variable( + initializer=np.array([True, False]), trainable=False, dtype="bool" + ) + self.assertAllClose(v.__invert__(), np.array([False, True])) + + def test__eq__(self): + """Test equality comparison on a variable.""" + v = backend.Variable(initializer=np.array([1.0, 2.0]), trainable=False) + self.assertAllClose( + v.__eq__(np.array([1.0, 2.0])), np.array([True, True]) + ) + + def test__ne__(self): + """Test inequality comparison on a variable.""" + v = backend.Variable(initializer=np.array([1.0, 2.0]), trainable=False) + self.assertAllClose( + v.__ne__(np.array([1.0, 2.0])), np.array([False, False]) + ) + + def test__lt__(self): + """Test less than comparison on a variable.""" + v = backend.Variable(initializer=np.array([1.0, 2.0]), trainable=False) + self.assertAllClose( + v.__lt__(np.array([1.0, 2.0])), np.array([False, False]) + ) + + def test__le__(self): + """Test less than or equal to comparison on a variable.""" + v = backend.Variable(initializer=np.array([1.0, 2.0]), trainable=False) + self.assertAllClose( + v.__le__(np.array([1.0, 2.0])), np.array([True, True]) + ) + + def test__gt__(self): + """Test greater than comparison on a variable.""" + v = backend.Variable(initializer=np.array([1.0, 2.0]), trainable=False) + self.assertAllClose( + v.__gt__(np.array([1.0, 2.0])), np.array([False, False]) + ) + + def test__ge__(self): + """Test greater than or equal to comparison on a variable.""" + v = backend.Variable(initializer=np.array([1.0, 2.0]), trainable=False) + self.assertAllClose( + v.__ge__(np.array([1.0, 2.0])), np.array([True, True]) + ) + + def test__add__(self): + """Test addition operation on a variable.""" + v1 = backend.Variable(initializer=np.array([1.0, 2.0, 3.0])) + v2 = backend.Variable(initializer=np.array([4.0, 5.0, 6.0])) + self.assertAllClose(v1.__add__(v2), np.array([5.0, 7.0, 9.0])) + + def test__radd__(self): + """Test reverse addition operation on a variable.""" + v1 = backend.Variable(initializer=np.array([1.0, 2.0, 3.0])) + v2 = backend.Variable(initializer=np.array([4.0, 5.0, 6.0])) + self.assertAllClose(v1.__radd__(v2), np.array([5.0, 7.0, 9.0])) + + def test__sub__(self): + """Test subtraction operation on a variable.""" + v1 = backend.Variable(initializer=np.array([1.0, 2.0, 3.0])) + v2 = backend.Variable(initializer=np.array([4.0, 5.0, 6.0])) + self.assertAllClose(v1.__sub__(v2), np.array([-3.0, -3.0, -3.0])) + + def test__rsub__(self): + """Test reverse subtraction operation on a variable.""" + v1 = backend.Variable(initializer=np.array([4.0, 5.0, 6.0])) + v2 = backend.Variable(initializer=np.array([1.0, 2.0, 3.0])) + self.assertAllClose(v1.__rsub__(v2), np.array([-3.0, -3.0, -3.0])) + + def test__mul__(self): + """Test multiplication operation on a variable.""" + v1 = backend.Variable(initializer=np.array([1.0, 2.0, 3.0])) + v2 = backend.Variable(initializer=np.array([4.0, 5.0, 6.0])) + self.assertAllClose(v1.__mul__(v2), np.array([4.0, 10.0, 18.0])) + + def test__rmul__(self): + """Test reverse multiplication operation on a variable.""" + v1 = backend.Variable(initializer=np.array([1.0, 2.0, 3.0])) + v2 = backend.Variable(initializer=np.array([4.0, 5.0, 6.0])) + self.assertAllClose(v1.__rmul__(v2), np.array([4.0, 10.0, 18.0])) + + def test__truediv__(self): + """Test true division operation on a variable.""" + v1 = backend.Variable(initializer=np.array([1.0, 2.0, 3.0])) + v2 = backend.Variable(initializer=np.array([4.0, 5.0, 6.0])) + self.assertAllClose(v1.__truediv__(v2), np.array([0.25, 0.4, 0.5])) + + def test__rtruediv__(self): + """Test reverse true division operation on a variable.""" + v1 = backend.Variable(initializer=np.array([4.0, 5.0, 6.0])) + v2 = backend.Variable(initializer=np.array([1.0, 2.0, 3.0])) + self.assertAllClose(v1.__rtruediv__(v2), np.array([0.25, 0.4, 0.5])) + + @skip_if_backend( + "openvino", "`floor_divide` is not supported with openvino backend" + ) + def test__floordiv__(self): + """Test floordiv operation on a variable.""" + v1 = backend.Variable(initializer=np.array([1.0, 2.0, 3.0])) + v2 = backend.Variable(initializer=np.array([-4.0, 5.0, 6.0])) + self.assertAllClose(v1.__floordiv__(v2), np.array([-1.0, 0.0, 0.0])) + + @skip_if_backend( + "openvino", "`floor_divide` is not supported with openvino backend" + ) + def test__rfloordiv__(self): + """Test reverse floordiv operation on a variable.""" + v1 = backend.Variable(initializer=np.array([-4.0, 5.0, 6.0])) + v2 = backend.Variable(initializer=np.array([1.0, 2.0, 3.0])) + self.assertAllClose(v1.__rfloordiv__(v2), np.array([-1.0, 0.0, 0.0])) + + def test__mod__(self): + """Test mod operation on a variable.""" + v1 = backend.Variable(initializer=np.array([1.0, 2.0, 3.0])) + v2 = backend.Variable(initializer=np.array([-4.0, 5.0, 6.0])) + self.assertAllClose(v1.__mod__(v2), np.array([-3.0, 2.0, 3.0])) + + def test__rmod__(self): + """Test reverse mod operation on a variable.""" + v1 = backend.Variable(initializer=np.array([1.0, 2.0, 3.0])) + v2 = backend.Variable(initializer=np.array([1.0, 2.0, 3.0])) + self.assertAllClose(v1.__rmod__(v2), np.array([0.0, 0.0, 0.0])) + + def test__pow__(self): + """Test pow operation on a variable.""" + v1 = backend.Variable(initializer=np.array([1.0, 2.0, 3.0])) + v2 = backend.Variable(initializer=np.array([-4.0, 5.0, 6.0])) + self.assertAllClose(v1.__pow__(v2), np.array([1.0, 32.0, 729.0])) + + def test__rpow__(self): + """Test reverse power operation on a variable.""" + v1 = backend.Variable(initializer=np.array([1.0, 2.0, 3.0])) + v2 = backend.Variable(initializer=np.array([1.0, 2.0, 3.0])) + self.assertAllClose(v1.__rpow__(v2), np.array([1.0, 4.0, 27.0])) + + def test__matmul__(self): + """Test matmul operation on a variable.""" + v1 = backend.Variable(initializer=np.array([[1.0, 2.0], [3.0, 4.0]])) + v2 = backend.Variable(initializer=np.array([[5.0, 6.0], [7.0, 8.0]])) + self.assertAllClose( + v1.__matmul__(v2), np.array([[19.0, 22.0], [43.0, 50.0]]) + ) + + def test__rmatmul__(self): + """Test reverse matmul operation on a variable.""" + v1 = backend.Variable(initializer=np.array([[1.0, 2.0], [3.0, 4.0]])) + v2 = backend.Variable(initializer=np.array([[5.0, 6.0], [7.0, 8.0]])) + self.assertAllClose( + v1.__rmatmul__(v2), np.array([[23.0, 34.0], [31.0, 46.0]]) + ) + + def test__and__(self): + """Test bitwise and operation on a variable.""" + v1 = backend.Variable( + initializer=np.array([True, False]), dtype="bool", trainable=False + ) + v2 = backend.Variable( + initializer=np.array([True, True]), dtype="bool", trainable=False + ) + self.assertAllClose(v1.__and__(v2), np.array([True, False])) + + def test__rand__(self): + """Test reverse bitwise and operation on a variable.""" + v1 = backend.Variable( + initializer=np.array([True, False]), dtype="bool", trainable=False + ) + v2 = backend.Variable( + initializer=np.array([True, True]), dtype="bool", trainable=False + ) + self.assertAllClose(v1.__rand__(v2), np.array([True, False])) + + def test__or__(self): + """Test bitwise or operation on a variable.""" + v1 = backend.Variable( + initializer=np.array([True, False]), dtype="bool", trainable=False + ) + v2 = backend.Variable( + initializer=np.array([True, True]), dtype="bool", trainable=False + ) + self.assertAllClose(v1.__or__(v2), np.array([True, True])) + + def test__ror__(self): + """Test reverse bitwise or operation on a variable.""" + v1 = backend.Variable( + initializer=np.array([True, False]), dtype="bool", trainable=False + ) + v2 = backend.Variable( + initializer=np.array([True, True]), dtype="bool", trainable=False + ) + self.assertAllClose(v1.__ror__(v2), np.array([True, True])) + + def test__xor__(self): + """Test bitwise xor operation on a variable.""" + v1 = backend.Variable( + initializer=np.array([True, False]), dtype="bool", trainable=False + ) + v2 = backend.Variable( + initializer=np.array([True, True]), dtype="bool", trainable=False + ) + self.assertAllClose(v1.__xor__(v2), np.array([False, True])) + + def test__rxor__(self): + """Test reverse bitwise xor operation on a variable.""" + v1 = backend.Variable( + initializer=np.array([True, False]), dtype="bool", trainable=False + ) + v2 = backend.Variable( + initializer=np.array([True, True]), dtype="bool", trainable=False + ) + self.assertAllClose(v1.__rxor__(v2), np.array([False, True])) + + def test__pos__(self): + """Test unary plus on a variable.""" + v = backend.Variable(initializer=np.array([-1.0, 2.0]), trainable=False) + self.assertAllClose(v.__pos__(), np.array([-1.0, 2.0])) + + def test_variable_pow(self): + """Test pow operation on a variable.""" + v1 = backend.Variable(initializer=np.array([1.0, 2.0, 3.0])) + v2 = backend.Variable(initializer=np.array([4.0, 5.0, 6.0])) + result = v1**v2 + self.assertAllClose(result, np.array([1.0, 32.0, 729.0])) + + def test_variable_rpow(self): + """Test reverse power operation on a variable.""" + v1 = backend.Variable(initializer=np.array([1.0, 2.0, 3.0])) + v2 = backend.Variable(initializer=np.array([4.0, 5.0, 6.0])) + result = v2**v1 + self.assertAllClose(result, np.array([4.0, 25.0, 216.0])) + + @skip_if_backend( + "openvino", "`round` is not supported with openvino backend" + ) + def test_round(self): + v = backend.Variable(initializer=np.array([1.1, 2.2, 3.3])) + self.assertAllClose(round(v), np.array([1.0, 2.0, 3.0])) + + +class VariableOpsBehaviorTest(test_case.TestCase): + def test_invalid_bool(self): + """Test converting a variable to boolean.""" + v = backend.Variable(initializer=np.ones((2, 2))) + with self.assertRaisesRegex( + TypeError, "A Keras Variable cannot be used as a boolean." + ): + bool(v) + + def test_invalid_int(self): + v = backend.Variable(initializer=np.ones((2, 2))) + with self.assertRaisesRegex( + TypeError, "Only scalar arrays can be converted to Python scalars." + ): + int(v) + + def test_invalid_float(self): + v = backend.Variable(initializer=np.ones((2, 2))) + with self.assertRaisesRegex( + TypeError, "Only scalar arrays can be converted to Python scalars." + ): + float(v) + + +class VariableOpsDTypeTest(test_case.TestCase): + """Test the dtype to verify that the behavior matches JAX.""" + + ALL_DTYPES = [ + x + for x in dtypes.ALLOWED_DTYPES + if x + not in ( + "string", + "complex128", + # Remove 64-bit dtypes. + "float64", + "uint64", + "int64", + ) + + dtypes.FLOAT8_TYPES # Remove float8 dtypes for the following tests + ] + [None] + INT_DTYPES = [x for x in dtypes.INT_TYPES if x not in ("uint64", "int64")] + FLOAT_DTYPES = [x for x in dtypes.FLOAT_TYPES if x not in ("float64",)] + COMPLEX_DTYPES = ["complex32", "complex64"] + if backend.backend() == "torch": + ALL_DTYPES = [ + x for x in ALL_DTYPES if x not in ("uint16", "uint32", "complex64") + ] + INT_DTYPES = [x for x in INT_DTYPES if x not in ("uint16", "uint32")] + elif backend.backend() == "tensorflow": + # TODO(hongyu): Re-enable uint32 tests once we determine how to handle + # dtypes.result_type(uint32, int*) -> int64 promotion. + # Since TF variables require int64 to be placed on the GPU, we + # exclusively enable the int64 dtype for TF. However, JAX does not + # natively support int64, which prevents us from comparing the dtypes. + ALL_DTYPES = [x for x in ALL_DTYPES if x not in ("uint32",)] + INT_DTYPES = [x for x in INT_DTYPES if x not in ("uint32",)] + elif backend.backend() == "openvino": + ALL_DTYPES = [x for x in ALL_DTYPES if x not in ("complex64",)] + NON_COMPLEX_DTYPES = [ + x for x in ALL_DTYPES if x and x not in ["complex32", "complex64"] + ] + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_eq(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = backend.Variable("ones", shape=(1,), dtype=dtype1, trainable=False) + x2 = backend.Variable("ones", shape=(1,), dtype=dtype2, trainable=False) + x1_jax = jnp.ones((1,), dtype=dtype1) + x2_jax = jnp.ones((1,), dtype=dtype2) + expected_dtype = standardize_dtype(jnp.equal(x1_jax, x2_jax).dtype) + + self.assertDType(x1 == x2, expected_dtype) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_ne(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = backend.Variable("ones", shape=(1,), dtype=dtype1, trainable=False) + x2 = backend.Variable("ones", shape=(1,), dtype=dtype2, trainable=False) + x1_jax = jnp.ones((1,), dtype=dtype1) + x2_jax = jnp.ones((1,), dtype=dtype2) + expected_dtype = standardize_dtype(jnp.not_equal(x1_jax, x2_jax).dtype) + + self.assertDType(x1 != x2, expected_dtype) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(NON_COMPLEX_DTYPES, 2)) + ) + def test_lt(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = backend.Variable("ones", shape=(1,), dtype=dtype1, trainable=False) + x2 = backend.Variable("ones", shape=(1,), dtype=dtype2, trainable=False) + x1_jax = jnp.ones((1,), dtype=dtype1) + x2_jax = jnp.ones((1,), dtype=dtype2) + expected_dtype = standardize_dtype(jnp.less(x1_jax, x2_jax).dtype) + + self.assertDType(x1 < x2, expected_dtype) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(NON_COMPLEX_DTYPES, 2)) + ) + def test_le(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = backend.Variable("ones", shape=(1,), dtype=dtype1, trainable=False) + x2 = backend.Variable("ones", shape=(1,), dtype=dtype2, trainable=False) + x1_jax = jnp.ones((1,), dtype=dtype1) + x2_jax = jnp.ones((1,), dtype=dtype2) + expected_dtype = standardize_dtype(jnp.less_equal(x1_jax, x2_jax).dtype) + + self.assertDType(x1 <= x2, expected_dtype) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(NON_COMPLEX_DTYPES, 2)) + ) + def test_gt(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = backend.Variable("ones", shape=(1,), dtype=dtype1, trainable=False) + x2 = backend.Variable("ones", shape=(1,), dtype=dtype2, trainable=False) + x1_jax = jnp.ones((1,), dtype=dtype1) + x2_jax = jnp.ones((1,), dtype=dtype2) + expected_dtype = standardize_dtype(jnp.greater(x1_jax, x2_jax).dtype) + + self.assertDType(x1 > x2, expected_dtype) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(NON_COMPLEX_DTYPES, 2)) + ) + def test_ge(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = backend.Variable("ones", shape=(1,), dtype=dtype1, trainable=False) + x2 = backend.Variable("ones", shape=(1,), dtype=dtype2, trainable=False) + x1_jax = jnp.ones((1,), dtype=dtype1) + x2_jax = jnp.ones((1,), dtype=dtype2) + expected_dtype = standardize_dtype( + jnp.greater_equal(x1_jax, x2_jax).dtype + ) + + self.assertDType(x1 >= x2, expected_dtype) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_add(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = backend.Variable("ones", shape=(1,), dtype=dtype1, trainable=False) + x2 = backend.Variable("ones", shape=(1,), dtype=dtype2, trainable=False) + x1_jax = jnp.ones((1,), dtype=dtype1) + x2_jax = jnp.ones((1,), dtype=dtype2) + expected_dtype = standardize_dtype(jnp.add(x1_jax, x2_jax).dtype) + + self.assertDType(x1 + x2, expected_dtype) + self.assertDType(x1.__radd__(x2), expected_dtype) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_sub(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = backend.Variable("ones", shape=(1,), dtype=dtype1, trainable=False) + x2 = backend.Variable("ones", shape=(1,), dtype=dtype2, trainable=False) + x1_jax = jnp.ones((1,), dtype=dtype1) + x2_jax = jnp.ones((1,), dtype=dtype2) + expected_dtype = standardize_dtype(jnp.add(x1_jax, x2_jax).dtype) + + self.assertDType(x1 - x2, expected_dtype) + self.assertDType(x1.__rsub__(x2), expected_dtype) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_mul(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = backend.Variable("ones", shape=(1,), dtype=dtype1, trainable=False) + x2 = backend.Variable("ones", shape=(1,), dtype=dtype2, trainable=False) + x1_jax = jnp.ones((1,), dtype=dtype1) + x2_jax = jnp.ones((1,), dtype=dtype2) + expected_dtype = standardize_dtype(jnp.add(x1_jax, x2_jax).dtype) + + self.assertDType(x1 * x2, expected_dtype) + self.assertDType(x1.__rmul__(x2), expected_dtype) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(NON_COMPLEX_DTYPES, 2)) + ) + def test_truediv(self, dtypes): + import jax.numpy as jnp + + try: + # JAX v0.8.0 and newer + from jax import enable_x64 + except ImportError: + # JAX v0.7.2 and older + from jax.experimental import enable_x64 + + # We have to disable x64 for jax since jnp.true_divide doesn't respect + # JAX_DEFAULT_DTYPE_BITS=32 in `./conftest.py`. We also need to downcast + # the expected dtype from 64 bit to 32 bit when using jax backend. + with enable_x64(False): + dtype1, dtype2 = dtypes + x1 = backend.Variable( + "ones", shape=(1,), dtype=dtype1, trainable=False + ) + x2 = backend.Variable( + "ones", shape=(1,), dtype=dtype2, trainable=False + ) + x1_jax = jnp.ones((1,), dtype=dtype1) + x2_jax = jnp.ones((1,), dtype=dtype2) + expected_dtype = standardize_dtype( + jnp.true_divide(x1_jax, x2_jax).dtype + ) + if "float64" in (dtype1, dtype2): + expected_dtype = "float64" + if backend.backend() == "jax": + expected_dtype = expected_dtype.replace("64", "32") + + self.assertDType(x1 / x2, expected_dtype) + self.assertDType(x1.__rtruediv__(x2), expected_dtype) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(NON_COMPLEX_DTYPES, 2)) + ) + @skip_if_backend( + "openvino", "`floor_divide` is not supported with openvino backend" + ) + def test_floordiv(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = backend.Variable("ones", shape=(1,), dtype=dtype1, trainable=False) + x2 = backend.Variable("ones", shape=(1,), dtype=dtype2, trainable=False) + x1_jax = jnp.ones((1,), dtype=dtype1) + x2_jax = jnp.ones((1,), dtype=dtype2) + expected_dtype = standardize_dtype( + jnp.floor_divide(x1_jax, x2_jax).dtype + ) + + self.assertDType(x1 // x2, expected_dtype) + self.assertDType(x1.__rfloordiv__(x2), expected_dtype) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(NON_COMPLEX_DTYPES, 2)) + ) + def test_mod(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = backend.Variable("ones", shape=(1,), dtype=dtype1, trainable=False) + x2 = backend.Variable("ones", shape=(1,), dtype=dtype2, trainable=False) + x1_jax = jnp.ones((1,), dtype=dtype1) + x2_jax = jnp.ones((1,), dtype=dtype2) + expected_dtype = standardize_dtype(jnp.mod(x1_jax, x2_jax).dtype) + + self.assertDType(x1 % x2, expected_dtype) + self.assertDType(x1.__rmod__(x2), expected_dtype) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_pow(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = backend.Variable("ones", shape=(1,), dtype=dtype1, trainable=False) + x2 = backend.Variable("ones", shape=(1,), dtype=dtype2, trainable=False) + x1_jax = jnp.ones((1,), dtype=dtype1) + x2_jax = jnp.ones((1,), dtype=dtype2) + expected_dtype = standardize_dtype(jnp.power(x1_jax, x2_jax).dtype) + + self.assertDType(x1**x2, expected_dtype) + self.assertDType(x1.__rpow__(x2), expected_dtype) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_matmul(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = backend.Variable("ones", shape=(1,), dtype=dtype1, trainable=False) + x2 = backend.Variable("ones", shape=(1,), dtype=dtype2, trainable=False) + x1_jax = jnp.ones((1,), dtype=dtype1) + x2_jax = jnp.ones((1,), dtype=dtype2) + expected_dtype = standardize_dtype(jnp.matmul(x1_jax, x2_jax).dtype) + + self.assertDType(x1 @ x2, expected_dtype) + self.assertDType(x1.__rmatmul__(x2), expected_dtype) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_and(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = backend.Variable("ones", shape=(1,), dtype=dtype1, trainable=False) + x2 = backend.Variable("ones", shape=(1,), dtype=dtype2, trainable=False) + x1_jax = jnp.ones((1,), dtype=dtype1) + x2_jax = jnp.ones((1,), dtype=dtype2) + expected_dtype = standardize_dtype( + jnp.logical_and(x1_jax, x2_jax).dtype + ) + + self.assertDType(x1 & x2, expected_dtype) + self.assertDType(x1.__rand__(x2), expected_dtype) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_or(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = backend.Variable("ones", shape=(1,), dtype=dtype1, trainable=False) + x2 = backend.Variable("ones", shape=(1,), dtype=dtype2, trainable=False) + x1_jax = jnp.ones((1,), dtype=dtype1) + x2_jax = jnp.ones((1,), dtype=dtype2) + expected_dtype = standardize_dtype(jnp.logical_or(x1_jax, x2_jax).dtype) + + self.assertDType(x1 | x2, expected_dtype) + self.assertDType(x1.__ror__(x2), expected_dtype) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_xor(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = backend.Variable("ones", shape=(1,), dtype=dtype1, trainable=False) + x2 = backend.Variable("ones", shape=(1,), dtype=dtype2, trainable=False) + x1_jax = jnp.ones((1,), dtype=dtype1) + x2_jax = jnp.ones((1,), dtype=dtype2) + expected_dtype = standardize_dtype( + jnp.logical_xor(x1_jax, x2_jax).dtype + ) + + self.assertDType(x1 ^ x2, expected_dtype) + self.assertDType(x1.__rxor__(x2), expected_dtype) + + +@pytest.mark.skipif( + backend.backend() != "torch", + reason="Tests for standardize_shape with Torch backend", +) +class TestStandardizeShapeWithTorch(test_case.TestCase): + """Tests for standardize_shape with Torch backend.""" + + def test_standardize_shape_with_torch_size_containing_negative_value(self): + """Tests shape with a negative value.""" + shape_with_negative_value = (3, 4, -5) + with self.assertRaisesRegex( + ValueError, + "Cannot convert '\\(3, 4, -5\\)' to a shape. Negative dimensions", + ): + _ = standardize_shape(shape_with_negative_value) + + def test_standardize_shape_with_torch_size_valid(self): + """Tests a valid shape.""" + shape_valid = (3, 4, 5) + standardized_shape = standardize_shape(shape_valid) + self.assertEqual(standardized_shape, (3, 4, 5)) + + def test_standardize_shape_with_torch_size_multidimensional(self): + """Tests shape of a multi-dimensional tensor.""" + import torch + + tensor = torch.randn(3, 4, 5) + shape = tensor.size() + standardized_shape = standardize_shape(shape) + self.assertEqual(standardized_shape, (3, 4, 5)) + + def test_standardize_shape_with_torch_size_single_dimension(self): + """Tests shape of a single-dimensional tensor.""" + import torch + + tensor = torch.randn(10) + shape = tensor.size() + standardized_shape = standardize_shape(shape) + self.assertEqual(standardized_shape, (10,)) + + def test_standardize_shape_with_torch_size_with_valid_1_dimension(self): + """Tests a valid shape.""" + shape_valid = [3] + standardized_shape = standardize_shape(shape_valid) + self.assertEqual(standardized_shape, (3,)) + + def test_standardize_shape_with_torch_size_with_valid_2_dimension(self): + """Tests a valid shape.""" + shape_valid = [3, 4] + standardized_shape = standardize_shape(shape_valid) + self.assertEqual(standardized_shape, (3, 4)) + + def test_standardize_shape_with_torch_size_with_valid_3_dimension(self): + """Tests a valid shape.""" + shape_valid = [3, 4, 5] + standardized_shape = standardize_shape(shape_valid) + self.assertEqual(standardized_shape, (3, 4, 5)) + + def test_standardize_shape_with_torch_size_with_negative_value(self): + """Tests shape with a negative value appended.""" + import torch + + tensor = torch.randn(3, 4, 5) + shape = tuple(tensor.size()) + shape_with_negative = shape + (-1,) + with self.assertRaisesRegex( + ValueError, + "Cannot convert .* to a shape. Negative dimensions are not", + ): + _ = standardize_shape(shape_with_negative) + + def test_standardize_shape_with_non_integer_entry(self): + """Tests shape with a non-integer value.""" + with self.assertRaisesRegex( + # different error message for torch + ValueError, + r"invalid literal for int\(\) with base 10: 'a'", + ): + standardize_shape([3, 4, "a"]) + + def test_standardize_shape_with_negative_entry(self): + """Tests shape with a negative value.""" + with self.assertRaisesRegex( + ValueError, + "Cannot convert '\\(3, 4, -5\\)' to a shape. Negative dimensions", + ): + standardize_shape([3, 4, -5]) + + def test_standardize_shape_with_valid_not_tuple(self): + """Tests a valid shape.""" + shape_valid = [3, 4, 5] + standardized_shape = standardize_shape(shape_valid) + self.assertEqual(standardized_shape, (3, 4, 5)) + + +@pytest.mark.skipif( + backend.backend() == "torch", + reason="Tests for standardize_shape with others backend", +) +class TestStandardizeShapeWithOutTorch(test_case.TestCase): + """Tests for standardize_shape with others backend.""" + + def test_standardize_shape_with_out_torch_negative_value(self): + """Tests shape with a negative value.""" + shape_with_negative_value = (3, 4, -5) + with self.assertRaisesRegex( + ValueError, + "Cannot convert '\\(3, 4, -5\\)' to a shape. Negative dimensions", + ): + _ = standardize_shape(shape_with_negative_value) + + def test_standardize_shape_with_out_torch_string(self): + """Tests shape with a string value.""" + shape_with_string = (3, 4, "5") + with self.assertRaisesRegex( + ValueError, + "Cannot convert .* to a shape. Found invalid entry '5'.", + ): + _ = standardize_shape(shape_with_string) + + def test_standardize_shape_with_out_torch_float(self): + """Tests shape with a float value.""" + shape_with_float = (3, 4, 5.0) + with self.assertRaisesRegex( + ValueError, + "Cannot convert .* to a shape. Found invalid entry '5.0'.", + ): + _ = standardize_shape(shape_with_float) + + def test_standardize_shape_with_out_torch_valid(self): + """Tests a valid shape.""" + shape_valid = (3, 4, 5) + standardized_shape = standardize_shape(shape_valid) + self.assertEqual(standardized_shape, (3, 4, 5)) + + def test_standardize_shape_with_out_torch_valid_not_tuple(self): + """Tests a valid shape.""" + shape_valid = [3, 4, 5] + standardized_shape = standardize_shape(shape_valid) + self.assertEqual(standardized_shape, (3, 4, 5)) diff --git a/keras/src/backend/config.py b/keras/src/backend/config.py new file mode 100644 index 000000000000..3986a467de92 --- /dev/null +++ b/keras/src/backend/config.py @@ -0,0 +1,448 @@ +import json +import os + +from keras.src.api_export import keras_export + +# The type of float to use throughout a session. +_FLOATX = "float32" + +# Epsilon fuzz factor used throughout the codebase. +_EPSILON = 1e-7 + +# Default image data format, one of "channels_last", "channels_first". +_IMAGE_DATA_FORMAT = "channels_last" + +# Default backend: TensorFlow. +_BACKEND = "tensorflow" + +# Whether NNX is enabled. +_NNX_ENABLED = False + +# Cap run duration for debugging. +_MAX_EPOCHS = None +_MAX_STEPS_PER_EPOCH = None + + +@keras_export(["keras.config.floatx", "keras.backend.floatx"]) +def floatx(): + """Return the default float type, as a string. + + E.g. `'bfloat16'`, `'float16'`, `'float32'`, `'float64'`. + + Returns: + String, the current default float type. + + Example: + + >>> keras.config.floatx() + 'float32' + + """ + return _FLOATX + + +@keras_export(["keras.config.set_floatx", "keras.backend.set_floatx"]) +def set_floatx(value): + """Set the default float dtype. + + Note: It is not recommended to set this to `"float16"` for training, + as this will likely cause numeric stability issues. + Instead, mixed precision, which leverages + a mix of `float16` and `float32`. It can be configured by calling + `keras.mixed_precision.set_dtype_policy('mixed_float16')`. + + Args: + value: String; `'bfloat16'`, `'float16'`, `'float32'`, or `'float64'`. + + Examples: + >>> keras.config.floatx() + 'float32' + + >>> keras.config.set_floatx('float64') + >>> keras.config.floatx() + 'float64' + + >>> # Set it back to float32 + >>> keras.config.set_floatx('float32') + + Raises: + ValueError: In case of invalid value. + """ + global _FLOATX + accepted_dtypes = {"bfloat16", "float16", "float32", "float64"} + if value not in accepted_dtypes: + raise ValueError( + f"Unknown `floatx` value: {value}. " + f"Expected one of {accepted_dtypes}" + ) + _FLOATX = str(value) + + +@keras_export(["keras.config.epsilon", "keras.backend.epsilon"]) +def epsilon(): + """Return the value of the fuzz factor used in numeric expressions. + + Returns: + A float. + + Example: + + >>> keras.config.epsilon() + 1e-07 + + """ + return _EPSILON + + +@keras_export(["keras.config.set_epsilon", "keras.backend.set_epsilon"]) +def set_epsilon(value): + """Set the value of the fuzz factor used in numeric expressions. + + Args: + value: float. New value of epsilon. + + Examples: + >>> keras.config.epsilon() + 1e-07 + + >>> keras.config.set_epsilon(1e-5) + >>> keras.config.epsilon() + 1e-05 + + >>> # Set it back to the default value. + >>> keras.config.set_epsilon(1e-7) + + """ + global _EPSILON + _EPSILON = value + + +@keras_export( + [ + "keras.config.image_data_format", + "keras.backend.image_data_format", + ] +) +def image_data_format(): + """Return the default image data format convention. + + Returns: + A string, either `'channels_first'` or `'channels_last'`. + + Example: + + >>> keras.config.image_data_format() + 'channels_last' + + """ + return _IMAGE_DATA_FORMAT + + +@keras_export( + [ + "keras.config.set_image_data_format", + "keras.backend.set_image_data_format", + ] +) +def set_image_data_format(data_format): + """Set the value of the image data format convention. + + Args: + data_format: string. `'channels_first'` or `'channels_last'`. + + Examples: + + >>> keras.config.image_data_format() + 'channels_last' + + >>> keras.config.set_image_data_format('channels_first') + >>> keras.config.image_data_format() + 'channels_first' + + >>> # Set it back to `'channels_last'` + >>> keras.config.set_image_data_format('channels_last') + + """ + global _IMAGE_DATA_FORMAT + data_format = str(data_format).lower() + if data_format not in {"channels_first", "channels_last"}: + raise ValueError( + "The `data_format` argument must be one of " + "{'channels_first', 'channels_last'}. " + f"Received: data_format={data_format}" + ) + _IMAGE_DATA_FORMAT = data_format + + +@keras_export("keras.config.enable_flash_attention") +def enable_flash_attention(): + """Enable flash attention. + + Flash attention offers performance optimization for attention layers, + making it especially useful for large language models (LLMs) that + benefit from faster and more memory-efficient attention computations. + + Once enabled, supported layers like `MultiHeadAttention` will **attempt** to + use flash attention for faster computations. By default, this feature is + enabled. + + Note that enabling flash attention does not guarantee it will always be + used. Typically, the inputs must be in `float16` or `bfloat16` dtype, and + input layout requirements may vary depending on the backend. + """ + from keras.src.backend.common import global_state + + global_state.set_global_attribute("flash_attention", None) + + +@keras_export("keras.config.disable_flash_attention") +def disable_flash_attention(): + """Disable flash attention. + + Flash attention offers performance optimization for attention layers, + making it especially useful for large language models (LLMs) that + benefit from faster and more memory-efficient attention computations. + + Once disabled, supported layers like `MultiHeadAttention` will not + use flash attention for faster computations. + """ + from keras.src.backend.common import global_state + + global_state.set_global_attribute("flash_attention", False) + + +@keras_export("keras.config.is_flash_attention_enabled") +def is_flash_attention_enabled(): + """Checks whether flash attention is globally enabled in Keras. + + Flash attention is a performance-optimized method for computing attention + in large models, such as transformers, allowing for faster and more + memory-efficient operations. This function checks the global Keras + configuration to determine if flash attention is enabled for compatible + layers (e.g., `MultiHeadAttention`). + + Note that enabling flash attention does not guarantee it will always be + used. Typically, the inputs must be in `float16` or `bfloat16` dtype, and + input layout requirements may vary depending on the backend. + + Returns: + `False` if disabled; otherwise, it indicates that it is enabled. + """ + from keras.src.backend.common import global_state + + return global_state.get_global_attribute("flash_attention", default=None) + + +@keras_export("keras.config.is_nnx_enabled") +def is_nnx_enabled(): + """Checks whether NNX specific features are enabled for the JAX backend. + + Returns: + bool: `True` if NNX backend features are enabled, `False` otherwise. + Defaults to `False`. + """ + return _NNX_ENABLED + + +def set_nnx_enabled(value): + global _NNX_ENABLED + from keras.src.backend.common import global_state + + _NNX_ENABLED = bool(value) + if _NNX_ENABLED: + try: + from flax import nnx # noqa F401 + except ImportError: + raise ImportError( + "To use NNX with the JAX backend, you must install `flax`." + ) + global_state.set_global_attribute("nnx_enabled", bool(value)) + + +def standardize_data_format(data_format): + if data_format is None: + return image_data_format() + data_format = str(data_format).lower() + if data_format not in {"channels_first", "channels_last"}: + raise ValueError( + "The `data_format` argument must be one of " + "{'channels_first', 'channels_last'}. " + f"Received: data_format={data_format}" + ) + return data_format + + +# Set Keras base dir path given KERAS_HOME env variable, if applicable. +# Otherwise either ~/.keras or /tmp. +if "KERAS_HOME" in os.environ: + _KERAS_DIR = os.environ.get("KERAS_HOME") +else: + _keras_base_dir = os.path.expanduser("~") + if not os.access(_keras_base_dir, os.W_OK): + _keras_base_dir = "/tmp" + _KERAS_DIR = os.path.join(_keras_base_dir, ".keras") + + +def keras_home(): + # Private accessor for the keras home location. + return _KERAS_DIR + + +# Attempt to read Keras config file. +_config_path = os.path.expanduser(os.path.join(_KERAS_DIR, "keras.json")) +if os.path.exists(_config_path): + try: + with open(_config_path) as f: + _config = json.load(f) + except ValueError: + _config = {} + _floatx = _config.get("floatx", floatx()) + assert _floatx in {"float16", "float32", "float64"} + _epsilon = _config.get("epsilon", epsilon()) + assert isinstance(_epsilon, float) + _backend = _config.get("backend", _BACKEND) + _image_data_format = _config.get("image_data_format", image_data_format()) + assert _image_data_format in {"channels_last", "channels_first"} + _nnx_enabled_config = _config.get("nnx_enabled", _NNX_ENABLED) + + # Apply basic configs that don't cause circular import + set_floatx(_floatx) + _NNX_ENABLED = _nnx_enabled_config + set_epsilon(_epsilon) + set_image_data_format(_image_data_format) + _BACKEND = _backend + +# Save config file, if possible. +if not os.path.exists(_KERAS_DIR): + try: + os.makedirs(_KERAS_DIR) + except OSError: + # Except permission denied and potential race conditions + # in multi-threaded environments. + pass + +if not os.path.exists(_config_path): + _config = { + "floatx": floatx(), + "epsilon": epsilon(), + "backend": _BACKEND, + "image_data_format": image_data_format(), + } + try: + with open(_config_path, "w") as f: + f.write(json.dumps(_config, indent=4)) + except IOError: + # Except permission denied. + pass + +# Set backend based on KERAS_BACKEND flag, if applicable. +if "KERAS_BACKEND" in os.environ: + _backend = os.environ["KERAS_BACKEND"] + if _backend: + _BACKEND = _backend +if "KERAS_MAX_EPOCHS" in os.environ: + _MAX_EPOCHS = int(os.environ["KERAS_MAX_EPOCHS"]) +if "KERAS_MAX_STEPS_PER_EPOCH" in os.environ: + _MAX_STEPS_PER_EPOCH = int(os.environ["KERAS_MAX_STEPS_PER_EPOCH"]) + + +if _BACKEND != "tensorflow": + # If we are not running on the tensorflow backend, we should stop tensorflow + # from using all available GPU memory. See + # https://www.tensorflow.org/guide/gpu#limiting_gpu_memory_growth + os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" + + +@keras_export( + [ + "keras.config.backend", + "keras.backend.backend", + ] +) +def backend(): + """Publicly accessible method for determining the current backend. + + Returns: + String, the name of the backend Keras is currently using. One of + `"tensorflow"`, `"torch"`, or `"jax"`. + + Example: + + >>> keras.config.backend() + 'tensorflow' + + """ + return _BACKEND + + +@keras_export(["keras.config.set_max_epochs"]) +def set_max_epochs(max_epochs): + """Limit the maximum number of epochs for any call to fit. + + This will cap the number of epochs for any training run using `model.fit()`. + This is purely for debugging, and can also be set via the `KERAS_MAX_EPOCHS` + environment variable to quickly run a script without modifying its source. + + Args: + max_epochs: The integer limit on the number of epochs or `None`. If + `None`, no limit is applied. + """ + global _MAX_EPOCHS + _MAX_EPOCHS = max_epochs + + +@keras_export(["keras.config.set_max_steps_per_epoch"]) +def set_max_steps_per_epoch(max_steps_per_epoch): + """Limit the maximum number of steps for any call to fit/evaluate/predict. + + This will cap the number of steps for single epoch of a call to `fit()`, + `evaluate()`, or `predict()`. This is purely for debugging, and can also be + set via the `KERAS_MAX_STEPS_PER_EPOCH` environment variable to quickly run + a scrip without modifying its source. + + Args: + max_epochs: The integer limit on the number of epochs or `None`. If + `None`, no limit is applied. + """ + global _MAX_STEPS_PER_EPOCH + _MAX_STEPS_PER_EPOCH = max_steps_per_epoch + + +@keras_export(["keras.config.max_epochs"]) +def max_epochs(): + """Get the maximum number of epochs for any call to fit. + + Retrieves the limit on the number of epochs set by + `keras.config.set_max_epochs` or the `KERAS_MAX_EPOCHS` environment + variable. + + Returns: + The integer limit on the number of epochs or `None`, if no limit has + been set. + """ + return _MAX_EPOCHS + + +@keras_export(["keras.config.max_steps_per_epoch"]) +def max_steps_per_epoch(): + """Get the maximum number of steps for any call to fit/evaluate/predict. + + Retrieves the limit on the number of epochs set by + `keras.config.set_max_steps_per_epoch` or the `KERAS_MAX_STEPS_PER_EPOCH` + environment variable. + + Args: + max_epochs: The integer limit on the number of epochs or `None`. If + `None`, no limit is applied. + """ + return _MAX_STEPS_PER_EPOCH + + +if "KERAS_NNX_ENABLED" in os.environ: + env_val = os.environ["KERAS_NNX_ENABLED"].lower() + if env_val == "true" or env_val == "1": + _NNX_ENABLED = True + else: + _NNX_ENABLED = False + +set_nnx_enabled(_NNX_ENABLED) diff --git a/keras/src/backend/jax/__init__.py b/keras/src/backend/jax/__init__.py new file mode 100644 index 000000000000..89ac0fa71c8c --- /dev/null +++ b/keras/src/backend/jax/__init__.py @@ -0,0 +1,31 @@ +from keras.src.backend.config import is_nnx_enabled +from keras.src.backend.jax import core +from keras.src.backend.jax import distribution_lib +from keras.src.backend.jax import image +from keras.src.backend.jax import linalg +from keras.src.backend.jax import math +from keras.src.backend.jax import nn +from keras.src.backend.jax import numpy +from keras.src.backend.jax import random +from keras.src.backend.jax import tensorboard +from keras.src.backend.jax.core import IS_THREAD_SAFE +from keras.src.backend.jax.core import SUPPORTS_RAGGED_TENSORS +from keras.src.backend.jax.core import SUPPORTS_SPARSE_TENSORS +from keras.src.backend.jax.core import Variable +from keras.src.backend.jax.core import cast +from keras.src.backend.jax.core import compute_output_spec +from keras.src.backend.jax.core import cond +from keras.src.backend.jax.core import convert_to_numpy +from keras.src.backend.jax.core import convert_to_tensor +from keras.src.backend.jax.core import device_scope +from keras.src.backend.jax.core import is_tensor +from keras.src.backend.jax.core import name_scope +from keras.src.backend.jax.core import random_seed_dtype +from keras.src.backend.jax.core import scatter +from keras.src.backend.jax.core import shape +from keras.src.backend.jax.core import stop_gradient +from keras.src.backend.jax.core import vectorized_map +from keras.src.backend.jax.rnn import cudnn_ok +from keras.src.backend.jax.rnn import gru +from keras.src.backend.jax.rnn import lstm +from keras.src.backend.jax.rnn import rnn diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py new file mode 100644 index 000000000000..7dc5a98fb8d5 --- /dev/null +++ b/keras/src/backend/jax/core.py @@ -0,0 +1,574 @@ +import jax +import jax.experimental.sparse as jax_sparse +import jax.numpy as jnp +import ml_dtypes +import numpy as np +from jax import export as jax_export + +from keras.src import tree +from keras.src.backend import config +from keras.src.backend.common import KerasVariable +from keras.src.backend.common import global_state +from keras.src.backend.common import standardize_dtype +from keras.src.backend.common.keras_tensor import KerasTensor +from keras.src.backend.common.name_scope import name_scope as base_name_scope +from keras.src.backend.common.stateless_scope import StatelessScope +from keras.src.backend.common.stateless_scope import get_stateless_scope +from keras.src.backend.common.stateless_scope import in_stateless_scope +from keras.src.backend.common.symbolic_scope import SymbolicScope +from keras.src.backend.jax import distribution_lib + +SUPPORTS_SPARSE_TENSORS = True +SUPPORTS_RAGGED_TENSORS = False +IS_THREAD_SAFE = True + + +class JaxVariable(KerasVariable): + def __init__(self, *args, layout=None, **kwargs): + # Intercept layout parameter so that it is available + # during initialization. + self._layout = layout + super().__init__(*args, **kwargs) + + def _initialize(self, value): + # Note that variable.shape is needed by distribution_lib + self._shape = self._validate_shape(value.shape) + # We can't import the keras/distribution/distribution_lib + # due to circular dependency. + distribution = global_state.get_global_attribute("distribution") + if self._layout is None and distribution is not None: + tensor_layout = distribution.get_variable_layout(self) + from keras.src.distribution import TensorLayout + + if isinstance(tensor_layout, TensorLayout): + self._layout = tensor_layout.backend_layout + else: + self._layout = tensor_layout + self._direct_assign(value) + + def _direct_assign(self, value): + if self._layout is not None: + value = distribution_lib.distribute_variable(value, self._layout) + self._value = value + + def _convert_to_tensor(self, value, dtype=None): + return convert_to_tensor(value, dtype=dtype, sparse=False) + + # Overload native accessor. + def __jax_array__(self): + return self.value + + +Variable = JaxVariable +if config.is_nnx_enabled(): + from flax import nnx + + class NnxVariable(JaxVariable, nnx.Variable): + def __init__( + self, + initializer, + shape=None, + dtype=None, + trainable=True, + autocast=True, + aggregation="none", + synchronization="auto", + name=None, + layout=None, + mutable=None, + **nnx_metadata, + ): + # Ensure 'mutable' is in nnx_metadata, but explicit 'mutable' + # param takes precedence. + nnx_metadata["mutable"] = trainable if mutable is None else mutable + + # First, initialize a basic nnx.Variable with a dummy value + # This sets up the NNX variable structure + if shape is None: + dummy_value = jnp.array(0.0) + else: + dummy_value = jnp.zeros(shape, dtype=standardize_dtype(dtype)) + + # Initialize nnx.Variable first + nnx.Variable.__init__(self, value=dummy_value, **nnx_metadata) + + # Now we can safely set layout + self._layout = layout + + # Initialize JaxVariable (which will call KerasVariable.__init__ + # and set up the real value). + JaxVariable.__init__( + self, + initializer=initializer, + shape=shape, + dtype=dtype, + trainable=trainable, + autocast=autocast, + aggregation=aggregation, + synchronization=synchronization, + name=name, + ) + + # The real value is now set in self._value, sync it to raw_value + object.__setattr__(self, "raw_value", self._value) + + @property + def _value(self): + if hasattr(self, "raw_value"): + return self.raw_value + return None + + @_value.setter + def _value(self, new_keras_value): + self._direct_assign(new_keras_value) + + def __getstate__(self): + # Get the state from KerasVariable (attributes in __dict__) + # KerasVariable does not have a custom __getstate__, so we mimic + # default behavior. + try: + keras_state = KerasVariable.__getstate__(self) + except AttributeError: + keras_state = object.__getstate__(self) + + # Get the state from nnx.Variable + nnx_specific_state = nnx.Variable.__getstate__(self) + + # Merge them. Keras state is primary. NNX specific state adds + # to it. + if "raw_value" in nnx_specific_state: + keras_state["_value"] = nnx_specific_state["raw_value"] + + # Add NNX attributes that are not in Keras's __dict__ + if "_trace_state" in nnx_specific_state: + keras_state["_trace_state"] = nnx_specific_state["_trace_state"] + if "_var_metadata" in nnx_specific_state: + keras_state["_var_metadata"] = nnx_specific_state[ + "_var_metadata" + ] + + # Remove elements that might be problematic or redundant if + # nnx.Variable's __getstate__ + keras_state.pop("raw_value", None) + + return keras_state + + def __setstate__(self, state): + # Separate nnx specific keys that we added if they are not part + # of Keras __dict__ this __getstate__ puts them into the main + # state dictionary. + nnx_raw_value = state["_value"] # This was raw_value + nnx_trace_state = state.pop("_trace_state", None) + nnx_var_metadata = state.pop("_var_metadata", None) + + # Populate the instance's __dict__ with the Keras attributes. + self.__dict__.update(state) + + # restore the nnx.Variable specific slotted attributes. + object.__setattr__(self, "raw_value", nnx_raw_value) + + if nnx_trace_state is not None: + object.__setattr__(self, "_trace_state", nnx_trace_state) + else: + pass + + if nnx_var_metadata is not None: + object.__setattr__(self, "_var_metadata", nnx_var_metadata) + else: + pass + + # Ensure Keras's self._value is also consistent with the + # restored raw_value + self._value = nnx_raw_value + + if hasattr(self, "_shape") and self._shape is not None: + self._ndim = len(self._shape) + else: + # Fallback if shape isn't immediately available. + self._ndim = len(self.raw_value.shape) + + def _direct_assign(self, value): + # Apply JAX-specific distribution if layout is present + if self._layout is not None: + value = distribution_lib.distribute_variable( + value, self._layout + ) + + # Apply on_set_value hook if it exists + if ( + hasattr(self, "_var_metadata") + and "on_set_value" in self._var_metadata + ): + value = self._var_metadata["on_set_value"](self, value) + + # Set the value for both Keras and NNX parts + # This ensures both systems see the same value + object.__setattr__(self, "raw_value", value) + + @property + def value(self): + if in_stateless_scope(): + scope = get_stateless_scope() + stateless_value = scope.get_current_value(self) + if stateless_value is not None: + return self._maybe_autocast(stateless_value) + if not hasattr(self, "raw_value"): + if self._initializer is not None: + self._initialize( + self._initializer(self.shape, dtype=self.dtype) + ) + else: + raise AttributeError( + "Variable is not properly initialized (raw_value " + "missing) and has no initializer." + ) + current_value = self.raw_value + if ( + hasattr(self, "_var_metadata") + and "on_get_value" in self._var_metadata + ): + current_value = self._var_metadata["on_get_value"]( + self, current_value + ) + return self._maybe_autocast(current_value) + + Variable = NnxVariable + + +def convert_to_tensor(x, dtype=None, sparse=None, ragged=None): + if ragged: + raise ValueError("`ragged=True` is not supported with jax backend") + if dtype is not None: + dtype = standardize_dtype(dtype) + if isinstance(x, (jnp.ndarray, jax.Array)) and ( + dtype is None or x.dtype == dtype + ): + # Skip the conversion early if the instance is already a JAX array. + # This is important in the multi-process context since jax.array(x) for + # an existing distributed jax array will raise error. + return x + + if isinstance(x, Variable): + if dtype is not None and x.dtype != dtype: + return x.value.astype(dtype) + return x.value + + if isinstance(x, jax_sparse.JAXSparse): + if sparse is not None and not sparse: + x = x.todense() + elif dtype is not None and x.dtype != dtype: + return x.astype(dtype) + else: + return x + + if not is_tensor(x) and standardize_dtype(dtype) == "bfloat16": + # Can't create bfloat16 arrays on the fly (e.g. from a h5 Dataset). + # Instead we convert "as is" (to stored dtype) and cast. + return jnp.asarray(x).astype(dtype) + return jnp.asarray(x, dtype=dtype) + + +def convert_to_numpy(x): + if isinstance(x, jax_sparse.JAXSparse): + x = x.todense() + if is_tensor(x) and x.dtype == "bfloat16": + return np.array(x, dtype=ml_dtypes.bfloat16) + return np.array(x) + + +def is_tensor(x): + if isinstance(x, (jnp.ndarray, jax_sparse.JAXSparse)): + return True + return False + + +def shape(x): + return x.shape + + +def cast(x, dtype): + return convert_to_tensor(x, dtype=dtype) + + +# Shape / dtype / sparseness inference util +def compute_output_spec(fn, *args, **kwargs): + with StatelessScope(), SymbolicScope(): + built_in_types = (type(None), int, float, str, bool, complex, bytes) + + # First, separate symbolic args from other args + static_args_idx = [] + static_args = [] + maybe_symbolic_args = [] + static_kwargs = {} + maybe_symbolic_kwargs = {} + for idx, arg in enumerate(args): + if isinstance(arg, built_in_types): + static_args_idx.append(idx) + static_args.append(arg) + else: + maybe_symbolic_args.append(arg) + maybe_symbolic_args = tuple(maybe_symbolic_args) + for k, v in kwargs.items(): + if isinstance(v, built_in_types): + static_kwargs[k] = v + else: + maybe_symbolic_kwargs[k] = v + + # Create a _DimExpr instance for one dimension by creating a symbolic + # shape with one dimension and extracting it. + # + # We create a single dynamic dimension and reuse it instead of creating + # N dynamic dimensions. This is for backwards compatibility. Previously + # we would fill all dynamic dimensions with the same concrete value. + # This can handle the case where there is an implicit assumption that + # two dimensions are the same (e.g. square images). + # + # We add the constraint "dynamic_dimension>=2" to prevent JAX from + # assuming that the dimension can be broadcastable or squeezable. It + # removes this ambiguity. + dynamic_dimension = jax_export.symbolic_shape( + "(dynamic_dimension)", + constraints=["dynamic_dimension>=2"], + )[0] + + def convert_keras_tensor_to_jax(x): + if isinstance(x, KerasTensor): + shape = tuple( + [d if d is not None else dynamic_dimension for d in x.shape] + ) + return jax.ShapeDtypeStruct(shape, dtype=x.dtype) + return x + + def wrapped_fn(*args, **kwargs): + # Turn inputs that are sparse to BCOO tensors + def to_bcoo_if_sparse(x, maybe_symbolic_x): + if ( + isinstance(maybe_symbolic_x, KerasTensor) + and maybe_symbolic_x.sparse + ): + return jax_sparse.BCOO.fromdense(x, nse=1) + return x + + args, kwargs = tree.map_structure( + to_bcoo_if_sparse, + (args, kwargs), + (maybe_symbolic_args, maybe_symbolic_kwargs), + ) + + rec_args = [] + idx_static = 0 + idx_sym = 0 + i = 0 + while idx_static < len(static_args) or idx_sym < len(args): + if i in static_args_idx: + rec_args.append(static_args[idx_static]) + idx_static += 1 + else: + rec_args.append(args[idx_sym]) + idx_sym += 1 + + i += 1 + with StatelessScope(): + return fn(*rec_args, **kwargs, **static_kwargs) + + maybe_symbolic_args_jax, maybe_symbolic_kwargs_jax = tree.map_structure( + convert_keras_tensor_to_jax, + (maybe_symbolic_args, maybe_symbolic_kwargs), + ) + jax_out = jax.eval_shape( + wrapped_fn, *maybe_symbolic_args_jax, **maybe_symbolic_kwargs_jax + ) + + def convert_jax_spec_to_keras_tensor(x): + if isinstance(x, jax.ShapeDtypeStruct): + shape = tuple( + d if isinstance(d, int) else None for d in x.shape + ) + return KerasTensor(shape, x.dtype) + elif isinstance(x, jax_sparse.BCOO): + shape = tuple( + d if isinstance(d, int) else None for d in x.shape + ) + return KerasTensor(shape, x.dtype, sparse=True) + return x + + return tree.map_structure(convert_jax_spec_to_keras_tensor, jax_out) + + +def cond(pred, true_fn, false_fn): + return jax.lax.cond(pred, true_fun=true_fn, false_fun=false_fn) + + +def vectorized_map(function, elements): + return jax.vmap(function)(elements) + + +def map(f, xs): + return jax.lax.map(f, xs) + + +def scan(f, init, xs=None, length=None, reverse=False, unroll=1): + if not isinstance(unroll, bool): + if not isinstance(unroll, int) or unroll < 1: + raise ValueError( + "`unroll` must be an positive integer or boolean. " + f"Received: unroll={unroll}" + ) + return jax.lax.scan( + f, init=init, xs=xs, length=length, reverse=reverse, unroll=unroll + ) + + +def associative_scan(f, elems, reverse=False, axis=0): + return jax.lax.associative_scan(f, elems, reverse, axis) + + +def scatter(indices, values, shape): + zeros = jnp.zeros(shape, values.dtype) + key = tuple(jnp.moveaxis(indices, -1, 0)) + return zeros.at[key].add(values) + + +def scatter_update(inputs, indices, updates): + inputs = convert_to_tensor(inputs) + indices = jnp.array(indices) + indices = jnp.transpose(indices) + inputs = inputs.at[tuple(indices)].set(updates) + return inputs + + +def slice(inputs, start_indices, shape): + # If shape[i] is -1, all remaining elements in dimension i are included in + # the slice. + final_shape = tuple( + inputs.shape[i] - start_indices[i] if s == -1 else s + for i, s in enumerate(shape) + ) + return jax.lax.dynamic_slice(inputs, start_indices, final_shape) + + +def slice_update(inputs, start_indices, updates): + return jax.lax.dynamic_update_slice(inputs, updates, start_indices) + + +def switch(index, branches, *operands): + return jax.lax.switch(index, branches, *operands) + + +def while_loop( + cond, + body, + loop_vars, + maximum_iterations=None, +): + is_tuple = isinstance(loop_vars, (tuple, list)) + loop_vars = tuple(loop_vars) if is_tuple else (loop_vars,) + if maximum_iterations is not None: + current_iter = 0 + loop_vars = loop_vars + (current_iter,) + + # Unpack list/tuple args. The last argument is `current_iter`. + def _cond(args): + return cond(*args[:-1]) & (args[-1] < maximum_iterations) + + def _body(args): + outputs = body(*args[:-1]) + outputs = tuple(outputs) if is_tuple else (outputs,) + return outputs + (args[-1] + 1,) + + else: + + def _cond(args): + return cond(*args) + + def _body(args): + outputs = body(*args) + return tuple(outputs) if is_tuple else (outputs,) + + outputs = jax.lax.while_loop(_cond, _body, loop_vars) + if maximum_iterations is not None: + outputs = outputs[:-1] + return outputs if is_tuple else outputs[0] + + +def fori_loop(lower, upper, body_fun, init_val): + return jax.lax.fori_loop(lower, upper, body_fun, init_val) + + +def stop_gradient(variable): + if isinstance(variable, Variable): + variable = variable.value + return jax.lax.stop_gradient(variable) + + +def unstack(x, num=None, axis=0): + return [ + jax.lax.index_in_dim(x, i, axis, keepdims=False) + for i in range(x.shape[axis]) + ] + + +def random_seed_dtype(): + # jax random seed uses uint32. + return "uint32" + + +def custom_gradient(fun): + return jax.custom_gradient(fun=fun) + + +def remat(f): + """Implementation of rematerialization. + + Args: + f: The function or operation to rematerialize. + Returns: + A function wrapping f that defines a custom gradient, which + recomputes f on the backwards pass of a gradient call. + """ + return jax.checkpoint(f) + + +class name_scope(base_name_scope): + def __init__(self, name, **kwargs): + super().__init__(name, **kwargs) + self._jax_name_scope = jax.named_scope(name) + + def __enter__(self): + name_scope_stack = global_state.get_global_attribute( + "name_scope_stack", default=[], set_to_default=True + ) + if self.deduplicate and name_scope_stack: + parent_caller = name_scope_stack[-1].caller + parent_name = name_scope_stack[-1].name + if ( + self.caller is not None + and self.caller is parent_caller + and self.name == parent_name + ): + return self + name_scope_stack.append(self) + self._pop_on_exit = True + self._jax_name_scope.__enter__() + return self + + def __exit__(self, *args, **kwargs): + super().__exit__(*args, **kwargs) + if self._pop_on_exit: + self._jax_name_scope.__exit__(*args, **kwargs) + + +def device_scope(device_name): + if isinstance(device_name, str): + # We support string value like "cpu:0", "gpu:1", etc. + device_name = device_name.lower() + jax_device = distribution_lib._to_backend_device(device_name) + elif not isinstance(device_name, jax.Device): + raise ValueError( + "Invalid value for argument `device_name`. " + "Expected a string like 'gpu:0' or a `jax.Device` instance. " + f"Received: device_name='{device_name}'" + ) + else: + jax_device = device_name + return jax.default_device(jax_device) diff --git a/keras/src/backend/jax/core_test.py b/keras/src/backend/jax/core_test.py new file mode 100644 index 000000000000..792cf25e67f0 --- /dev/null +++ b/keras/src/backend/jax/core_test.py @@ -0,0 +1,68 @@ +import os + +import jax +import jax.numpy as jnp +import numpy as np +import pytest + +import keras +from keras.src import backend +from keras.src import testing +from keras.src.backend.config import is_nnx_enabled + +if is_nnx_enabled(): + from flax import nnx + + from keras.src.backend.jax.core import NnxVariable + + +@pytest.mark.skipif( + backend.backend() != "jax", + reason="JAX backend specific test for core Variable integration with NNX.", +) +@pytest.mark.skipif( + not is_nnx_enabled(), + reason="Test requires NNX backend to be enabled by default for setup.", +) +class NnxVariableTest(testing.TestCase): + def setup(self): + super().setup() + + class NNXModel(nnx.Module): + def __init__(self, rngs): + self.linear = nnx.Linear(2, 3, rngs=rngs) + # Use NnxVariable directly as KerasJaxVariable + # might be JaxVariable if NNX is disabled globally. + self.custom_variable = NnxVariable(jnp.ones((1, 3))) + + def __call__(self, x): + return self.linear(x) + self.custom_variable + + self.nnx_model = NNXModel(rngs=nnx.Rngs(0)) + self.keras_nnx_model = keras.Sequential( + [keras.layers.Dense(units=1, input_shape=(10,))] + ) + self.single_dummy_input = np.random.rand(1, 10) + + def test_variable_in_nnx_module(self): + self.assertTrue(hasattr(self.nnx_model.custom_variable, "_trace_state")) + self.assertIsNotNone(self.nnx_model.custom_variable._trace_state) + self.assertAllEqual(self.nnx_model.custom_variable.value, [[1, 1, 1]]) + self.assertTrue( + isinstance(self.nnx_model.custom_variable, nnx.Variable) + ) + + def test_model_saving(self): + path = os.path.join(self.get_temp_dir(), "model.keras") + original_outputs = self.keras_nnx_model(self.single_dummy_input) + self.keras_nnx_model.save(path, save_format="keras_v3") + restored_model = keras.models.load_model(path) + restored_outputs = restored_model(self.single_dummy_input) + self.assertAllEqual(original_outputs, restored_outputs) + + def test_keras_variable_nnx_split_merge_sync(self): + variable1 = keras.Variable(jnp.array(1.0)) + graphdef, state = nnx.split(variable1) + state = jax.tree.map(lambda x: x + 1, state) + variable2 = nnx.merge(graphdef, state) + self.assertEqual(variable2._value, variable2.value) diff --git a/keras/src/backend/jax/distribution_lib.py b/keras/src/backend/jax/distribution_lib.py new file mode 100644 index 000000000000..6b5bf37314c0 --- /dev/null +++ b/keras/src/backend/jax/distribution_lib.py @@ -0,0 +1,248 @@ +"""Utilities for distribution strategy with JAX backend.""" + +import jax +import numpy as np + +from keras.src.backend.common import global_state +from keras.src.random import seed_generator +from keras.src.utils import jax_utils +from keras.src.utils import rng_utils + + +def list_devices(device_type=None): + """Return all the available devices based on the device type. + + Note that this should return the global devices in a distributed setting. + + Args: + device_type: string of `"cpu"`, `"gpu"` or `"tpu"`. Defaults to `"gpu"` + or `"tpu"` if available when device_type is not provided. Otherwise + will return the `"cpu"` devices. + + Return: + List of devices that are available for distribute computation. + """ + device_type = device_type.lower() if device_type else None + jax_devices = jax.devices(backend=device_type) + return [f"{device.platform}:{device.id}" for device in jax_devices] + + +def distribute_variable(value, layout): + """Create a distributed variable for JAX. + + Since JAX doesn't have a variable class, this will just return a `jax.Array` + with the corresponding layout/sharding specified. + + Note that this function should be used in eager context, not in jitted + function. + + Args: + value: the initial value of the variable. + layout: `TensorLayout` for the created variable, or a + JAX-supported layout instance (e.g. `jax.sharding.Sharding`). + + Returns: + jax.Array which is the distributed variable. + """ + return distribute_tensor(value, layout) + + +def distribute_tensor(tensor, layout): + """Distribute the tensor based on the layout. + + Note that this function can be used both in eager context, or within a + jitted function. + + Args: + tensor: `jax.Array` that need to be distributed. + layout: `TensorLayout` for the created variable, or a + JAX-supported layout instance (e.g. `jax.sharding.Sharding`). + + Returns: + Distributed value. + """ + # Avoid circular imports. + from keras.src.distribution import TensorLayout + + if isinstance(layout, TensorLayout): + layout = layout.backend_layout + + # TODO(scottzhu): This might not be a cheap check, we should consider + # have some proper JAX API for doing this check. + if jax_utils.is_in_jax_tracing_scope(): + return jax.lax.with_sharding_constraint(tensor, layout) + + # Skip relayout if unnecessary. + if isinstance(tensor, jax.Array): + if isinstance( + layout, jax.sharding.Sharding + ) and tensor.sharding.is_equivalent_to(layout, ndim=len(tensor.shape)): + return tensor + # JAX explicit "layout" support. + elif hasattr(layout, "layout"): + current_layout = getattr(tensor, "layout", None) + if current_layout == layout: + return tensor + # JAX explicit "format" support. + elif hasattr(layout, "format"): + current_layout = getattr(tensor, "format", None) + if current_layout == layout: + return tensor + + return jax.device_put(tensor, layout) + + +def distribute_data_input(per_process_batch, layout, batch_dim_name): + """Distribute the input data with the corresponding layout. + + Note that the inputs here is a local worker batch. Within the local worker, + the data need to be further partitioned to map to each of the devices. + + Args: + inputs: `jax.Array` that is already sharded to a local process size. + layout: `TensorLayout` for the distribution information, or a + `jax.sharding.Sharding` instance. + + Returns: + A global batch distributed according to `layout`. + """ + # Avoid circular imports. + from keras.src.distribution import TensorLayout + + if isinstance(layout, TensorLayout): + layout = layout.backend_layout + + return jax.make_array_from_process_local_data(layout, per_process_batch) + + +def initialize_rng(): + """Initializes the global random number generator across processes. + + This is required for consistent initialization in multi-host settings. + """ + global_seed = rng_utils.get_random_seed() + # Only set a random seed if not already set + # via keras.config.set_random_seed() + if global_seed is None: + # Generate a random seed on each CPU host and psum them to get a single + # consistent seed across all processes. + cpu_devices = jax.devices("cpu") + num_local_cpu_devices = jax.local_device_count("cpu") + # Seed must be in range [0, 2^32 - 1], so to ensure proper range and + # avoid signed integer overflow, we use uint32. + local_seed = jax.numpy.asarray( + [seed_generator.make_default_seed()] * num_local_cpu_devices, + dtype=jax.numpy.uint32, + ) + # Sum across processes and pull out the first item. + global_seed = jax.pmap( + lambda x: jax.lax.psum(x, "all"), + axis_name="all", + devices=cpu_devices, + )(local_seed).item(0) + # Set the global seed. + rng_utils.set_random_seed(global_seed) + + # Check if the global seed generator is set and ensure it has an initialized + # seed. Otherwise, reset the seed to the global seed. + global_seed_generator = global_state.get_global_attribute( + "global_seed_generator" + ) + if global_seed_generator is not None: + seed = global_seed_generator.get_config()["seed"] + if seed is None: + global_state.set_global_attribute( + "global_seed_generator", + seed_generator.SeedGenerator( + seed=global_seed, + name=global_seed_generator.name, + backend=global_seed_generator.backend, + ), + ) + + +def initialize(job_addresses, num_processes, process_id): + if job_addresses and "," in job_addresses: + # When user provide all the job addresses, we will split and get the + # first one, which is the coordinator. + job_addresses = job_addresses.split(",") + # Do a sanity check to make sure the number of addresses also match + # the num_processes. + if num_processes is not None and num_processes != len(job_addresses): + raise ValueError( + f"The provided job_addresses {job_addresses} has " + f"{len(job_addresses)} jobs, but num_processes is " + f"{num_processes}" + ) + coordinator_address = job_addresses[0] + else: + coordinator_address = job_addresses + + jax.distributed.initialize( + coordinator_address=coordinator_address, + num_processes=num_processes, + process_id=process_id, + ) + + # Ensure the random number generator is initialized across processes. + initialize_rng() + + +def num_processes(): + """Return the number of processes for the current distribution setting.""" + return jax.process_count() + + +def process_id(): + """Return the current process ID for the distribution setting.""" + return jax.process_index() + + +def _to_backend_device(device_name): + if isinstance(device_name, jax.Device): + return device_name + device_name = str(device_name) + if ":" not in device_name: + device_type, device_id = device_name, 0 + else: + device_type, device_id = device_name.split(":") + + devices = jax.devices(backend=device_type) + for device in devices: + if device.platform == device_type and device.id == int(device_id): + return device + raise ValueError(f"Device not found: {device_name}") + + +def _to_backend_mesh(device_mesh): + """Convert the DeviceMesh to JAX backend specific Mesh. + + Args: + device_mesh: DeviceMesh instance to convert. + + Returns: + A `jax.sharding.Mesh` instance. + """ + shape = device_mesh.devices.shape + devices = [_to_backend_device(d) for d in device_mesh.devices.flatten()] + devices = np.array(devices).reshape(shape) + return jax.sharding.Mesh(devices, device_mesh.axis_names) + + +def _to_backend_layout(tensor_layout): + """Convert the TensorLayout to JAX backend specific Sharding. + + Args: + tensor_layout: TensorLayout instance to convert. + + Returns: + A `jax.sharding.NamedSharding` instance. + """ + if tensor_layout.device_mesh is None: + raise ValueError( + "Cannot create sharding when device mesh is not set " + "for TensorLayout." + ) + partition_spec = jax.sharding.PartitionSpec(*tensor_layout.axes) + jax_mesh = tensor_layout.device_mesh.backend_mesh + return jax.sharding.NamedSharding(jax_mesh, partition_spec) diff --git a/keras/src/backend/jax/distribution_lib_test.py b/keras/src/backend/jax/distribution_lib_test.py new file mode 100644 index 000000000000..8938c14fc50a --- /dev/null +++ b/keras/src/backend/jax/distribution_lib_test.py @@ -0,0 +1,454 @@ +"""Test for distribution_lib.py.""" + +import functools +import os +from unittest import mock + +import jax +import numpy as np +import pytest +from jax.experimental import layout as jax_layout + +from keras.src import backend +from keras.src import layers +from keras.src import models +from keras.src import testing +from keras.src.backend import distribution_lib as backend_dlib +from keras.src.distribution import distribution_lib + +if backend.backend() == "jax": + # Due to https://github.com/google/jax/issues/17188, we can't + # override the XLA flag after the JAX back init. We have to + # run this at top level to let JAX pick the flag value. + xla_flags = os.getenv("XLA_FLAGS") or "" + # Don't override user-specified device count, or other XLA flags. + if "xla_force_host_platform_device_count" not in xla_flags: + os.environ["XLA_FLAGS"] = ( + f"{xla_flags} --xla_force_host_platform_device_count=8" + ) + + +@pytest.mark.skipif( + backend.backend() != "jax", + reason="Backend specific test", +) +class JaxDistributionLibTest(testing.TestCase): + def _create_jax_layout(self, sharding): + # Use jax_layout.Format or jax_layout.Layout if available. + if hasattr(jax_layout, "Format"): + return jax_layout.Format(sharding=sharding) + elif hasattr(jax_layout, "Layout"): + return jax_layout.Layout(sharding=sharding) + + return sharding + + def test_list_devices(self): + self.assertEqual(len(distribution_lib.list_devices()), 8) + self.assertEqual(len(distribution_lib.list_devices("cpu")), 8) + self.assertEqual(len(distribution_lib.list_devices("cpu")), 8) + + def test_device_conversion(self): + devices = distribution_lib.list_devices("cpu") + jax_devices = jax.devices("cpu") + + for d, jax_d in zip(devices, jax_devices): + converted_jax_device = backend_dlib._to_backend_device(d) + self.assertIsInstance(converted_jax_device, jax.Device) + self.assertEqual(jax_d, converted_jax_device) + + @mock.patch.object(jax.distributed, "initialize", return_value=None) + def test_initialize_with_all_job_addresses(self, mock_jax_initialize): + backend_dlib.initialize("10.0.0.1:1234,10.0.0.2:2345", 2, 0) + mock_jax_initialize.assert_called_once_with( + coordinator_address="10.0.0.1:1234", num_processes=2, process_id=0 + ) + + def test_initialize_validate_job_and_process(self): + with self.assertRaisesRegex( + ValueError, "has 2 jobs, but num_processes is 3" + ): + backend_dlib.initialize("10.0.0.1:1234,10.0.0.2:2345", 3, 0) + + @mock.patch.object(jax.distributed, "initialize", return_value=None) + def test_initialize_with_coordinator_address(self, mock_jax_initialize): + backend_dlib.initialize("10.0.0.1:1234", 2, 0) + mock_jax_initialize.assert_called_once_with( + coordinator_address="10.0.0.1:1234", num_processes=2, process_id=0 + ) + + def test_distribute_tensor(self): + jax_mesh = jax.sharding.Mesh( + np.array(jax.devices()).reshape(2, 4), ("batch", "model") + ) + + inputs = jax.numpy.array(np.random.normal(size=(16, 8))) + target_layout = jax.sharding.NamedSharding( + jax_mesh, jax.sharding.PartitionSpec("batch", None) + ) + + @functools.partial(jax.jit, static_argnames="target_layout") + def test_function(inputs, target_layout): + return distribution_lib.distribute_tensor(inputs, target_layout) + + result = test_function(inputs, target_layout) + # Note that the returned tensor has a different sharding implementation + # which is GSPMDSharding, but it should be equivalent as the target + # layout specified. + self.assertTrue(result.sharding.is_equivalent_to(target_layout, ndim=2)) + + # Test without jit + result = distribution_lib.distribute_tensor(inputs, target_layout) + self.assertTrue(result.sharding.is_equivalent_to(target_layout, ndim=2)) + + def test_distribute_variable(self): + # This test only verify the single worker/process behavior. + jax_mesh = jax.sharding.Mesh( + np.array(jax.devices()).reshape(2, 4), ("batch", "model") + ) + + variable = jax.numpy.array(np.random.normal(size=(16, 8))) + target_layout = jax.sharding.NamedSharding( + jax_mesh, jax.sharding.PartitionSpec("model", None) + ) + + result = backend_dlib.distribute_variable(variable, target_layout) + # Note that the returned tensor has a different sharding implementation + # which is GSPMDSharding, but it should be equivalent as the target + # layout specified. + self.assertTrue(result.sharding.is_equivalent_to(target_layout, ndim=2)) + + def test_distribute_input_data(self): + # This test only verify the single worker/process behavior. + # The multi-process test lives in g3. + jax_mesh = jax.sharding.Mesh( + np.array(jax.devices()).reshape(2, 4), ("batch", "model") + ) + + input_data = jax.numpy.array(np.random.normal(size=(16, 8))) + target_layout = jax.sharding.NamedSharding( + jax_mesh, jax.sharding.PartitionSpec("batch", None) + ) + + result = backend_dlib.distribute_variable(input_data, target_layout) + # Note that the returned tensor has a different sharding implementation + # which is GSPMDSharding, but it should be equivalent as the target + # layout specified. + self.assertTrue(result.sharding.is_equivalent_to(target_layout, ndim=2)) + + def test_distribute_tensor_with_jax_layout(self): + jax_mesh = jax.sharding.Mesh( + np.array(jax.devices()).reshape(2, 4), ("batch", "model") + ) + + inputs = jax.numpy.array(np.random.normal(size=(16, 8))) + target_layout = self._create_jax_layout( + sharding=jax.sharding.NamedSharding( + jax_mesh, jax.sharding.PartitionSpec("batch", None) + ) + ) + + @functools.partial(jax.jit, static_argnames="target_layout") + def test_function(inputs, target_layout): + return distribution_lib.distribute_tensor(inputs, target_layout) + + result = test_function(inputs, target_layout) + # Note that the returned tensor has a different sharding implementation + # which is GSPMDSharding, but it should be equivalent as the target + # layout specified. + self.assertTrue( + result.sharding.is_equivalent_to(target_layout.sharding, ndim=2) + ) + + # Test without jit. + result = distribution_lib.distribute_tensor(inputs, target_layout) + self.assertTrue( + result.sharding.is_equivalent_to(target_layout.sharding, ndim=2) + ) + + def test_distribute_variable_with_jax_layout(self): + # This test only verify the single worker/process behavior. + jax_mesh = jax.sharding.Mesh( + np.array(jax.devices()).reshape(2, 4), ("batch", "model") + ) + + variable = jax.numpy.array(np.random.normal(size=(16, 8))) + target_layout = self._create_jax_layout( + sharding=jax.sharding.NamedSharding( + jax_mesh, jax.sharding.PartitionSpec("model", None) + ) + ) + + result = backend_dlib.distribute_variable(variable, target_layout) + # Note that the returned tensor has a different sharding implementation + # which is GSPMDSharding, but it should be equivalent as the target + # layout specified. + self.assertTrue( + result.sharding.is_equivalent_to(target_layout.sharding, ndim=2) + ) + + def test_distribute_input_data_with_jax_layout(self): + # This test only verify the single worker/process behavior. + jax_mesh = jax.sharding.Mesh( + np.array(jax.devices()).reshape(2, 4), ("batch", "model") + ) + + input_data = jax.numpy.array(np.random.normal(size=(16, 8))) + target_layout = self._create_jax_layout( + sharding=jax.sharding.NamedSharding( + jax_mesh, jax.sharding.PartitionSpec("batch", None) + ) + ) + + result = backend_dlib.distribute_variable(input_data, target_layout) + # Note that the returned tensor has a different sharding implementation + # which is GSPMDSharding, but it should be equivalent as the target + # layout specified. + self.assertTrue( + result.sharding.is_equivalent_to(target_layout.sharding, ndim=2) + ) + + def test_processes(self): + self.assertEqual(backend_dlib.process_id(), 0) + self.assertEqual(backend_dlib.num_processes(), 1) + + def test_to_backend_mesh(self): + devices = [f"cpu:{i}" for i in range(8)] + shape = (4, 2) + axis_names = ["batch", "model"] + + mesh = distribution_lib.DeviceMesh(shape, axis_names, devices) + jax_mesh = backend_dlib._to_backend_mesh(mesh) + + self.assertIsInstance(jax_mesh, jax.sharding.Mesh) + self.assertEqual(jax_mesh.devices.shape, shape) + self.assertEqual(jax_mesh.axis_names, ("batch", "model")) + + def test_to_backend_layout(self): + axes = ["data", None] + mesh = distribution_lib.DeviceMesh( + (4, 2), ["data", "model"], [f"cpu:{i}" for i in range(8)] + ) + layout = distribution_lib.TensorLayout(axes, mesh) + jax_sharding = backend_dlib._to_backend_layout(layout) + jax_mesh = backend_dlib._to_backend_mesh(mesh) + self.assertEqual( + jax_sharding, + jax.sharding.NamedSharding( + jax_mesh, jax.sharding.PartitionSpec("data", None) + ), + ) + + def test_validation_for_device_mesh(self): + axes = ["data", None] + layout = distribution_lib.TensorLayout(axes, device_mesh=None) + + with self.assertRaisesRegex( + ValueError, "Cannot create sharding when device mesh is not set" + ): + backend_dlib._to_backend_layout(layout) + + def test_variable_assignment_reuse_layout(self): + shape = (4, 2) + axis_names = ["batch", "model"] + device_mesh = distribution_lib.DeviceMesh( + shape, axis_names, backend_dlib.list_devices() + ) + layout_map = distribution_lib.LayoutMap(device_mesh) + layout_map[".*dense.*kernel"] = distribution_lib.TensorLayout( + [None, "model"] + ) + layout_map[".*dense.*bias"] = distribution_lib.TensorLayout(["model"]) + + distribution = distribution_lib.ModelParallel( + layout_map=layout_map, batch_dim_name="batch" + ) + + with distribution.scope(): + dense_layer = layers.Dense(8) + dense_layer.build((16, 16)) + + self.assertEqual( + dense_layer.kernel._value.sharding.spec, (None, "model") + ) + self.assertEqual(dense_layer.bias._value.sharding.spec, ("model",)) + + # Assign a numpy value to dense layer to mimic the model weight loading + new_kernel = np.random.normal(size=(16, 8)) + new_bias = np.random.normal(size=(8)) + dense_layer.kernel.assign(new_kernel) + dense_layer.bias.assign(new_bias) + + # Make sure the loaded value still use the layout when it is + # initialized, even outside of the distribution scope. + self.assertEqual( + dense_layer.kernel._value.sharding.spec, (None, "model") + ) + self.assertEqual(dense_layer.bias._value.sharding.spec, ("model",)) + + def test_e2e_data_parallel_model(self): + distribution = distribution_lib.DataParallel( + devices=backend_dlib.list_devices() + ) + + with distribution.scope(): + inputs = layers.Input(shape=[28, 28, 1]) + y = layers.Flatten()(inputs) + y = layers.Dense(units=200, use_bias=False, activation="relu")(y) + y = layers.Dropout(0.4)(y) + y = layers.Dense(units=10, activation="softmax")(y) + model = models.Model(inputs=inputs, outputs=y) + + # Make sure all the weights are properly sharded. + for weight in model.weights: + self.assertTrue(weight._value.sharding.is_fully_replicated) + + inputs = np.random.normal(size=(32, 28, 28, 1)) + labels = np.random.normal(size=(32, 10)) + + with distribution.scope(): + model.compile(loss="mse") + model.fit(inputs, labels) + + def test_e2e_model_parallel_model(self): + shape = (4, 2) + axis_names = ["batch", "model"] + device_mesh = distribution_lib.DeviceMesh( + shape, axis_names, backend_dlib.list_devices() + ) + + layout_map = distribution_lib.LayoutMap(device_mesh) + layout_map[".*dense.*kernel"] = distribution_lib.TensorLayout( + [None, "model"] + ) + layout_map[".*dense.*bias"] = distribution_lib.TensorLayout(["model"]) + + distribution = distribution_lib.ModelParallel( + layout_map=layout_map, batch_dim_name="batch" + ) + with distribution.scope(): + inputs = layers.Input(shape=[28, 28, 1]) + y = layers.Flatten()(inputs) + y = layers.Dense(units=200, use_bias=False, activation="relu")(y) + y = layers.Dropout(0.4)(y) + y = layers.Dense(units=10, activation="softmax")(y) + model = models.Model(inputs=inputs, outputs=y) + + for weight in model.weights: + if "kernel" in weight.name: + self.assertEqual(weight._value.sharding.spec, (None, "model")) + elif "bias" in weight.name: + self.assertEqual(weight._value.sharding.spec, ("model",)) + else: + self.assertTrue(weight._value.sharding.is_fully_replicated) + + inputs = np.random.normal(size=(32, 28, 28, 1)) + labels = np.random.normal(size=(32, 10)) + + with distribution.scope(): + model.compile(loss="mse") + model.fit(inputs, labels) + + def test_e2e_model_parallel_with_output_sharding(self): + shape = (4, 2) + axis_names = ["batch", "model"] + device_mesh = distribution_lib.DeviceMesh( + shape, axis_names, backend_dlib.list_devices() + ) + + layout_map = distribution_lib.LayoutMap(device_mesh) + layout_map[".*dense.*kernel"] = distribution_lib.TensorLayout( + [None, "model"] + ) + layout_map[".*dense.*bias"] = distribution_lib.TensorLayout(["model"]) + # Force the dense layer output to be batch parallel only, and not + # sharded on model dimension. + layout_map[".*dense.*output"] = ("batch", None) + + distribution = distribution_lib.ModelParallel( + layout_map=layout_map, batch_dim_name="batch" + ) + sharding_capture = ShardingCaptureLayer() + with distribution.scope(): + inputs = layers.Input(shape=[28, 28, 1]) + y = layers.Flatten()(inputs) + y = layers.Dense(units=200, use_bias=False, activation="relu")(y) + y = sharding_capture(y) + y = layers.Dropout(0.4)(y) + y = layers.Dense(units=10, activation="softmax")(y) + model = models.Model(inputs=inputs, outputs=y) + + for weight in model.weights: + if "kernel" in weight.name: + self.assertEqual(weight._value.sharding.spec, (None, "model")) + elif "bias" in weight.name: + self.assertEqual(weight._value.sharding.spec, ("model",)) + else: + self.assertTrue(weight._value.sharding.is_fully_replicated) + + inputs = np.random.normal(size=(32, 28, 28, 1)) + labels = np.random.normal(size=(32, 10)) + + with distribution.scope(): + model.compile(loss="mse") + model.fit(inputs, labels) + + # Note that the intermediate_tensor_layout is only captured during the + # actual training, and not at the model building time. + intermediate_tensor_layout = jax.sharding.NamedSharding( + backend_dlib._to_backend_mesh(distribution.device_mesh), + jax.sharding.PartitionSpec("batch", None), + ) + self.assertTrue( + sharding_capture.captured_input_sharding.is_equivalent_to( + intermediate_tensor_layout, ndim=2 + ) + ) + + def test_distribute_data_input(self): + per_process_batch = jax.numpy.arange(24).reshape( + 6, 4 + ) # Example input array + devices = jax.devices()[:4] # Simulate 4 devices + batch_dim_size, model_dim_size = 2, 2 + mesh = jax.sharding.Mesh( + np.array(devices).reshape(batch_dim_size, model_dim_size), + axis_names=["batch", "model"], + ) + layout = jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec("batch", None) + ) + + result = backend_dlib.distribute_data_input( + per_process_batch, layout, "batch" + ) + + # Check the shape of the global batch array + self.assertEqual( + result.shape, (6, 4) + ) # (per_replica_batch_size * num_model_replicas_total, 4) + + # Check the sharding of the global batch array + self.assertEqual(len(result.addressable_shards), len(devices)) + # Since batch_dim_size=2, there are 2 model replicas so there is one + # replication of data for model replica #1 and another replication of + # data for model replica #2. Within each model replica, the data is + # sharded to two shards. Therefore, each shard has 1/2 of + # per_process_batch. + for shard in result.addressable_shards: + self.assertEqual(shard.data.shape, (3, 4)) + + +class ShardingCaptureLayer(layers.Layer): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.captured_input_sharding = None + self.supports_masking = True + + def call(self, inputs): + jax.debug.inspect_array_sharding( + inputs, callback=lambda x: self.capture_input_sharding(x) + ) + return inputs + + def capture_input_sharding(self, sharding): + self.captured_input_sharding = sharding diff --git a/keras/src/backend/jax/export.py b/keras/src/backend/jax/export.py new file mode 100644 index 000000000000..71f0d88a5768 --- /dev/null +++ b/keras/src/backend/jax/export.py @@ -0,0 +1,184 @@ +import copy +import inspect +import itertools +import string +import warnings + +from keras.src import tree +from keras.src.backend.common.stateless_scope import StatelessScope +from keras.src.utils.module_utils import tensorflow as tf + + +class JaxExportArchive: + def __init__(self): + self._backend_variables = [] + self._backend_trainable_variables = [] + self._backend_non_trainable_variables = [] + + def _track_layer(self, layer): + # Variables in the lists below are actually part of the trackables + # that get saved, because the lists are created in __init__. + trainable_variables = layer.trainable_variables + non_trainable_variables = layer.non_trainable_variables + + self._tf_trackable.trainable_variables += tree.map_structure( + self._convert_to_tf_variable, trainable_variables + ) + self._tf_trackable.non_trainable_variables += tree.map_structure( + self._convert_to_tf_variable, non_trainable_variables + ) + self._tf_trackable.variables = ( + self._tf_trackable.trainable_variables + + self._tf_trackable.non_trainable_variables + ) + + self._backend_trainable_variables += trainable_variables + self._backend_non_trainable_variables += non_trainable_variables + self._backend_variables = ( + self._backend_trainable_variables + + self._backend_non_trainable_variables + ) + + def add_endpoint(self, name, fn, input_signature=None, **kwargs): + jax2tf_kwargs = kwargs.pop("jax2tf_kwargs", None) + # Use `copy.copy()` to avoid modification issues. + jax2tf_kwargs = copy.copy(jax2tf_kwargs) or {} + is_static = bool(kwargs.pop("is_static", False)) + + # Configure `jax2tf_kwargs` + if "native_serialization" not in jax2tf_kwargs: + jax2tf_kwargs["native_serialization"] = ( + self._check_device_compatible() + ) + if "polymorphic_shapes" not in jax2tf_kwargs: + jax2tf_kwargs["polymorphic_shapes"] = self._to_polymorphic_shape( + input_signature + ) + + # Note: we truncate the number of parameters to what is specified by + # `input_signature`. + fn_signature = inspect.signature(fn) + fn_parameters = list(fn_signature.parameters.values()) + + if is_static: + from jax.experimental import jax2tf + + jax_fn = jax2tf.convert(fn, **jax2tf_kwargs) + jax_fn.__signature__ = inspect.Signature( + parameters=fn_parameters[0 : len(input_signature)], + return_annotation=fn_signature.return_annotation, + ) + + decorated_fn = tf.function( + jax_fn, + input_signature=input_signature, + autograph=False, + ) + else: + # 1. Create a stateless wrapper for `fn` + # 2. jax2tf the stateless wrapper + # 3. Create a stateful function that binds the variables with + # the jax2tf converted stateless wrapper + # 4. Make the signature of the stateful function the same as the + # original function + # 5. Wrap in a `tf.function` + def stateless_fn(variables, *args, **kwargs): + state_mapping = zip(self._backend_variables, variables) + with StatelessScope(state_mapping=state_mapping) as scope: + output = fn(*args, **kwargs) + + # Gather updated non-trainable variables + non_trainable_variables = [] + for var in self._backend_non_trainable_variables: + new_value = scope.get_current_value(var) + non_trainable_variables.append(new_value) + return output, non_trainable_variables + + jax2tf_stateless_fn = self._convert_jax2tf_function( + stateless_fn, input_signature, jax2tf_kwargs=jax2tf_kwargs + ) + + def stateful_fn(*args, **kwargs): + output, non_trainable_variables = jax2tf_stateless_fn( + # Change the trackable `ListWrapper` to a plain `list` + list(self._tf_trackable.variables), + *args, + **kwargs, + ) + for var, new_value in zip( + self._tf_trackable.non_trainable_variables, + non_trainable_variables, + ): + var.assign(tf.cast(new_value, var.dtype)) + return output + + stateful_fn.__signature__ = inspect.Signature( + parameters=fn_parameters[0 : len(input_signature)], + return_annotation=fn_signature.return_annotation, + ) + + decorated_fn = tf.function( + stateful_fn, + input_signature=input_signature, + autograph=False, + ) + return decorated_fn + + def _convert_jax2tf_function(self, fn, input_signature, jax2tf_kwargs=None): + from jax.experimental import jax2tf + + variables_shapes = self._to_polymorphic_shape( + self._backend_variables, allow_none=False + ) + input_shapes = list(jax2tf_kwargs["polymorphic_shapes"]) + jax2tf_kwargs["polymorphic_shapes"] = [variables_shapes] + input_shapes + return jax2tf.convert(fn, **jax2tf_kwargs) + + def _to_polymorphic_shape(self, struct, allow_none=True): + if allow_none: + # Generates unique names: a, b, ... z, aa, ab, ... az, ba, ... zz + # for unknown non-batch dims. Defined here to be scope per endpoint. + dim_names = itertools.chain( + string.ascii_lowercase, + itertools.starmap( + lambda a, b: a + b, + itertools.product(string.ascii_lowercase, repeat=2), + ), + ) + + def convert_shape(x): + poly_shape = [] + for index, dim in enumerate(list(x.shape)): + if dim is not None: + poly_shape.append(str(dim)) + elif not allow_none: + raise ValueError( + f"Illegal None dimension in {x} with shape {x.shape}" + ) + elif index == 0: + poly_shape.append("batch") + else: + poly_shape.append(next(dim_names)) + return f"({', '.join(poly_shape)})" + + return tree.map_structure(convert_shape, struct) + + def _check_device_compatible(self): + from jax import default_backend as jax_device + + if ( + jax_device() == "gpu" + and len(tf.config.list_physical_devices("GPU")) == 0 + ): + warnings.warn( + "JAX backend is using GPU for export, but installed " + "TF package cannot access GPU, so reloading the model with " + "the TF runtime in the same environment will not work. " + "To use JAX-native serialization for high-performance export " + "and serving, please install `tensorflow-gpu` and ensure " + "CUDA version compatibility between your JAX and TF " + "installations." + ) + return False + else: + return True diff --git a/keras/src/backend/jax/image.py b/keras/src/backend/jax/image.py new file mode 100644 index 000000000000..52e37eed6c45 --- /dev/null +++ b/keras/src/backend/jax/image.py @@ -0,0 +1,897 @@ +import functools + +import jax +import jax.numpy as jnp + +from keras.src import backend +from keras.src.backend.jax.core import convert_to_tensor +from keras.src.random.seed_generator import draw_seed + +RESIZE_INTERPOLATIONS = ( + "bilinear", + "nearest", + "lanczos3", + "lanczos5", + "bicubic", +) +AFFINE_TRANSFORM_INTERPOLATIONS = { # map to order + "nearest": 0, + "bilinear": 1, +} +AFFINE_TRANSFORM_FILL_MODES = { + "constant", + "nearest", + "wrap", + "mirror", + "reflect", +} +MAP_COORDINATES_FILL_MODES = { + "constant", + "nearest", + "wrap", + "mirror", + "reflect", +} +SCALE_AND_TRANSLATE_METHODS = { + "linear", + "bilinear", + "trilinear", + "cubic", + "bicubic", + "tricubic", + "lanczos3", + "lanczos5", +} + + +def rgb_to_grayscale(images, data_format=None): + images = convert_to_tensor(images) + data_format = backend.standardize_data_format(data_format) + channels_axis = -1 if data_format == "channels_last" else -3 + if len(images.shape) not in (3, 4): + raise ValueError( + "Invalid images rank: expected rank 3 (single image) " + "or rank 4 (batch of images). Received input with shape: " + f"images.shape={images.shape}" + ) + # Convert to floats + original_dtype = images.dtype + compute_dtype = backend.result_type(images.dtype, float) + images = images.astype(compute_dtype) + + # Ref: tf.image.rgb_to_grayscale + rgb_weights = convert_to_tensor( + [0.2989, 0.5870, 0.1140], dtype=images.dtype + ) + images = jnp.tensordot(images, rgb_weights, axes=(channels_axis, -1)) + images = jnp.expand_dims(images, axis=channels_axis) + return images.astype(original_dtype) + + +def rgb_to_hsv(images, data_format=None): + # Ref: dm_pix + images = convert_to_tensor(images) + dtype = images.dtype + data_format = backend.standardize_data_format(data_format) + channels_axis = -1 if data_format == "channels_last" else -3 + if len(images.shape) not in (3, 4): + raise ValueError( + "Invalid images rank: expected rank 3 (single image) " + "or rank 4 (batch of images). Received input with shape: " + f"images.shape={images.shape}" + ) + if not backend.is_float_dtype(dtype): + raise ValueError( + "Invalid images dtype: expected float dtype. " + f"Received: images.dtype={backend.standardize_dtype(dtype)}" + ) + eps = jnp.finfo(dtype).eps + images = jnp.where(jnp.abs(images) < eps, 0.0, images) + red, green, blue = jnp.split(images, 3, channels_axis) + red = jnp.squeeze(red, channels_axis) + green = jnp.squeeze(green, channels_axis) + blue = jnp.squeeze(blue, channels_axis) + + def rgb_planes_to_hsv_planes(r, g, b): + value = jnp.maximum(jnp.maximum(r, g), b) + minimum = jnp.minimum(jnp.minimum(r, g), b) + range_ = value - minimum + + safe_value = jnp.where(value > 0, value, 1.0) + safe_range = jnp.where(range_ > 0, range_, 1.0) + + saturation = jnp.where(value > 0, range_ / safe_value, 0.0) + norm = 1.0 / (6.0 * safe_range) + + hue = jnp.where( + value == g, + norm * (b - r) + 2.0 / 6.0, + norm * (r - g) + 4.0 / 6.0, + ) + hue = jnp.where(value == r, norm * (g - b), hue) + hue = jnp.where(range_ > 0, hue, 0.0) + (hue < 0.0).astype(hue.dtype) + return hue, saturation, value + + images = jnp.stack( + rgb_planes_to_hsv_planes(red, green, blue), axis=channels_axis + ) + return images + + +def hsv_to_rgb(images, data_format=None): + # Ref: dm_pix + images = convert_to_tensor(images) + dtype = images.dtype + data_format = backend.standardize_data_format(data_format) + channels_axis = -1 if data_format == "channels_last" else -3 + if len(images.shape) not in (3, 4): + raise ValueError( + "Invalid images rank: expected rank 3 (single image) " + "or rank 4 (batch of images). Received input with shape: " + f"images.shape={images.shape}" + ) + if not backend.is_float_dtype(dtype): + raise ValueError( + "Invalid images dtype: expected float dtype. " + f"Received: images.dtype={backend.standardize_dtype(dtype)}" + ) + hue, saturation, value = jnp.split(images, 3, channels_axis) + hue = jnp.squeeze(hue, channels_axis) + saturation = jnp.squeeze(saturation, channels_axis) + value = jnp.squeeze(value, channels_axis) + + def hsv_planes_to_rgb_planes(hue, saturation, value): + dh = jnp.mod(hue, 1.0) * 6.0 + dr = jnp.clip(jnp.abs(dh - 3.0) - 1.0, 0.0, 1.0) + dg = jnp.clip(2.0 - jnp.abs(dh - 2.0), 0.0, 1.0) + db = jnp.clip(2.0 - jnp.abs(dh - 4.0), 0.0, 1.0) + one_minus_s = 1.0 - saturation + + red = value * (one_minus_s + saturation * dr) + green = value * (one_minus_s + saturation * dg) + blue = value * (one_minus_s + saturation * db) + return red, green, blue + + images = jnp.stack( + hsv_planes_to_rgb_planes(hue, saturation, value), axis=channels_axis + ) + return images + + +def resize( + images, + size, + interpolation="bilinear", + antialias=False, + crop_to_aspect_ratio=False, + pad_to_aspect_ratio=False, + fill_mode="constant", + fill_value=0.0, + data_format=None, +): + data_format = backend.standardize_data_format(data_format) + if interpolation not in RESIZE_INTERPOLATIONS: + raise ValueError( + "Invalid value for argument `interpolation`. Expected of one " + f"{RESIZE_INTERPOLATIONS}. Received: interpolation={interpolation}" + ) + if fill_mode != "constant": + raise ValueError( + "Invalid value for argument `fill_mode`. Only `'constant'` " + f"is supported. Received: fill_mode={fill_mode}" + ) + if pad_to_aspect_ratio and crop_to_aspect_ratio: + raise ValueError( + "Only one of `pad_to_aspect_ratio` & `crop_to_aspect_ratio` " + "can be `True`." + ) + if not len(size) == 2: + raise ValueError( + "Argument `size` must be a tuple of two elements " + f"(height, width). Received: size={size}" + ) + size = tuple(size) + target_height, target_width = size + if len(images.shape) == 4: + if data_format == "channels_last": + size = (images.shape[0],) + size + (images.shape[-1],) + else: + size = (images.shape[0], images.shape[1]) + size + batch_size = images.shape[0] + elif len(images.shape) == 3: + if data_format == "channels_last": + size = size + (images.shape[-1],) + else: + size = (images.shape[0],) + size + else: + raise ValueError( + "Invalid images rank: expected rank 3 (single image) " + "or rank 4 (batch of images). Received input with shape: " + f"images.shape={images.shape}" + ) + + if crop_to_aspect_ratio: + shape = images.shape + if data_format == "channels_last": + height, width = shape[-3], shape[-2] + else: + height, width = shape[-2], shape[-1] + crop_height = int(float(width * target_height) / target_width) + crop_height = max(min(height, crop_height), 1) + crop_width = int(float(height * target_width) / target_height) + crop_width = max(min(width, crop_width), 1) + crop_box_hstart = int(float(height - crop_height) / 2) + crop_box_wstart = int(float(width - crop_width) / 2) + if data_format == "channels_last": + if len(images.shape) == 4: + images = images[ + :, + crop_box_hstart : crop_box_hstart + crop_height, + crop_box_wstart : crop_box_wstart + crop_width, + :, + ] + else: + images = images[ + crop_box_hstart : crop_box_hstart + crop_height, + crop_box_wstart : crop_box_wstart + crop_width, + :, + ] + else: + if len(images.shape) == 4: + images = images[ + :, + :, + crop_box_hstart : crop_box_hstart + crop_height, + crop_box_wstart : crop_box_wstart + crop_width, + ] + else: + images = images[ + :, + crop_box_hstart : crop_box_hstart + crop_height, + crop_box_wstart : crop_box_wstart + crop_width, + ] + elif pad_to_aspect_ratio: + shape = images.shape + if data_format == "channels_last": + height, width, channels = shape[-3], shape[-2], shape[-1] + else: + height, width, channels = shape[-2], shape[-1], shape[-3] + + pad_height = int(float(width * target_height) / target_width) + pad_height = max(height, pad_height) + pad_width = int(float(height * target_width) / target_height) + pad_width = max(width, pad_width) + img_box_hstart = int(float(pad_height - height) / 2) + img_box_wstart = int(float(pad_width - width) / 2) + if data_format == "channels_last": + if img_box_hstart > 0: + if len(images.shape) == 4: + padded_img = jnp.concatenate( + [ + jnp.ones( + (batch_size, img_box_hstart, width, channels), + dtype=images.dtype, + ) + * fill_value, + images, + jnp.ones( + (batch_size, img_box_hstart, width, channels), + dtype=images.dtype, + ) + * fill_value, + ], + axis=1, + ) + else: + padded_img = jnp.concatenate( + [ + jnp.ones( + (img_box_hstart, width, channels), + dtype=images.dtype, + ) + * fill_value, + images, + jnp.ones( + (img_box_hstart, width, channels), + dtype=images.dtype, + ) + * fill_value, + ], + axis=0, + ) + elif img_box_wstart > 0: + if len(images.shape) == 4: + padded_img = jnp.concatenate( + [ + jnp.ones( + (batch_size, height, img_box_wstart, channels), + dtype=images.dtype, + ) + * fill_value, + images, + jnp.ones( + (batch_size, height, img_box_wstart, channels), + dtype=images.dtype, + ) + * fill_value, + ], + axis=2, + ) + else: + padded_img = jnp.concatenate( + [ + jnp.ones( + (height, img_box_wstart, channels), + dtype=images.dtype, + ) + * fill_value, + images, + jnp.ones( + (height, img_box_wstart, channels), + dtype=images.dtype, + ) + * fill_value, + ], + axis=1, + ) + else: + padded_img = images + else: + if img_box_hstart > 0: + if len(images.shape) == 4: + padded_img = jnp.concatenate( + [ + jnp.ones( + (batch_size, channels, img_box_hstart, width) + ) + * fill_value, + images, + jnp.ones( + (batch_size, channels, img_box_hstart, width) + ) + * fill_value, + ], + axis=2, + ) + else: + padded_img = jnp.concatenate( + [ + jnp.ones((channels, img_box_hstart, width)) + * fill_value, + images, + jnp.ones((channels, img_box_hstart, width)) + * fill_value, + ], + axis=1, + ) + elif img_box_wstart > 0: + if len(images.shape) == 4: + padded_img = jnp.concatenate( + [ + jnp.ones( + (batch_size, channels, height, img_box_wstart) + ) + * fill_value, + images, + jnp.ones( + (batch_size, channels, height, img_box_wstart) + ) + * fill_value, + ], + axis=3, + ) + else: + padded_img = jnp.concatenate( + [ + jnp.ones((channels, height, img_box_wstart)) + * fill_value, + images, + jnp.ones((channels, height, img_box_wstart)) + * fill_value, + ], + axis=2, + ) + else: + padded_img = images + images = padded_img + + return jax.image.resize( + images, size, method=interpolation, antialias=antialias + ) + + +def affine_transform( + images, + transform, + interpolation="bilinear", + fill_mode="constant", + fill_value=0, + data_format=None, +): + data_format = backend.standardize_data_format(data_format) + if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS.keys(): + raise ValueError( + "Invalid value for argument `interpolation`. Expected of one " + f"{set(AFFINE_TRANSFORM_INTERPOLATIONS.keys())}. Received: " + f"interpolation={interpolation}" + ) + if fill_mode not in AFFINE_TRANSFORM_FILL_MODES: + raise ValueError( + "Invalid value for argument `fill_mode`. Expected of one " + f"{AFFINE_TRANSFORM_FILL_MODES}. Received: fill_mode={fill_mode}" + ) + + transform = convert_to_tensor(transform) + + if len(images.shape) not in (3, 4): + raise ValueError( + "Invalid images rank: expected rank 3 (single image) " + "or rank 4 (batch of images). Received input with shape: " + f"images.shape={images.shape}" + ) + if len(transform.shape) not in (1, 2): + raise ValueError( + "Invalid transform rank: expected rank 1 (single transform) " + "or rank 2 (batch of transforms). Received input with shape: " + f"transform.shape={transform.shape}" + ) + + # unbatched case + need_squeeze = False + if len(images.shape) == 3: + images = jnp.expand_dims(images, axis=0) + need_squeeze = True + if len(transform.shape) == 1: + transform = jnp.expand_dims(transform, axis=0) + + if data_format == "channels_first": + images = jnp.transpose(images, (0, 2, 3, 1)) + + batch_size = images.shape[0] + + # get indices + meshgrid = jnp.meshgrid( + *[jnp.arange(size) for size in images.shape[1:]], indexing="ij" + ) + indices = jnp.concatenate( + [jnp.expand_dims(x, axis=-1) for x in meshgrid], axis=-1 + ) + indices = jnp.tile(indices, (batch_size, 1, 1, 1, 1)) + + # swap the values + a0 = transform[:, 0] + a2 = transform[:, 2] + b1 = transform[:, 4] + b2 = transform[:, 5] + transform = transform.at[:, 0].set(b1) + transform = transform.at[:, 2].set(b2) + transform = transform.at[:, 4].set(a0) + transform = transform.at[:, 5].set(a2) + + # deal with transform + transform = jnp.pad( + transform, pad_width=[[0, 0], [0, 1]], constant_values=1 + ) + transform = jnp.reshape(transform, (batch_size, 3, 3)) + offset = transform[:, 0:2, 2] + offset = jnp.pad(offset, pad_width=[[0, 0], [0, 1]]) + transform = transform.at[:, 0:2, 2].set(0) + + # transform the indices + coordinates = jnp.einsum("Bhwij, Bjk -> Bhwik", indices, transform) + coordinates = jnp.moveaxis(coordinates, source=-1, destination=1) + coordinates += jnp.reshape(offset, shape=(*offset.shape, 1, 1, 1)) + + # apply affine transformation + _map_coordinates = functools.partial( + jax.scipy.ndimage.map_coordinates, + order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation], + mode=fill_mode, + cval=fill_value, + ) + affined = jax.vmap(_map_coordinates)(images, coordinates) + + if data_format == "channels_first": + affined = jnp.transpose(affined, (0, 3, 1, 2)) + if need_squeeze: + affined = jnp.squeeze(affined, axis=0) + return affined + + +def perspective_transform( + images, + start_points, + end_points, + interpolation="bilinear", + fill_value=0, + data_format=None, +): + data_format = backend.standardize_data_format(data_format) + if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS.keys(): + raise ValueError( + "Invalid value for argument `interpolation`. Expected of one " + f"{set(AFFINE_TRANSFORM_INTERPOLATIONS.keys())}. Received: " + f"interpolation={interpolation}" + ) + + if len(images.shape) not in (3, 4): + raise ValueError( + "Invalid images rank: expected rank 3 (single image) " + "or rank 4 (batch of images). Received input with shape: " + f"images.shape={images.shape}" + ) + + if start_points.shape[-2:] != (4, 2) or start_points.ndim not in (2, 3): + raise ValueError( + "Invalid start_points shape: expected (4,2) for a single image" + f" or (N,4,2) for a batch. Received shape: {start_points.shape}" + ) + if end_points.shape[-2:] != (4, 2) or end_points.ndim not in (2, 3): + raise ValueError( + "Invalid end_points shape: expected (4,2) for a single image" + f" or (N,4,2) for a batch. Received shape: {end_points.shape}" + ) + if start_points.shape != end_points.shape: + raise ValueError( + "start_points and end_points must have the same shape." + f" Received start_points.shape={start_points.shape}, " + f"end_points.shape={end_points.shape}" + ) + + need_squeeze = False + if len(images.shape) == 3: + images = jnp.expand_dims(images, axis=0) + need_squeeze = True + + if len(start_points.shape) == 2: + start_points = jnp.expand_dims(start_points, axis=0) + if len(end_points.shape) == 2: + end_points = jnp.expand_dims(end_points, axis=0) + + if data_format == "channels_first": + images = jnp.transpose(images, (0, 2, 3, 1)) + + _, height, width, _ = images.shape + transforms = compute_homography_matrix( + jnp.asarray(start_points, dtype="float32"), + jnp.asarray(end_points, dtype="float32"), + ) + + x, y = jnp.meshgrid(jnp.arange(width), jnp.arange(height), indexing="xy") + grid = jnp.stack([x.ravel(), y.ravel(), jnp.ones_like(x).ravel()], axis=0) + + def transform_coordinates(transform): + denom = transform[6] * grid[0] + transform[7] * grid[1] + 1.0 + x_in = ( + transform[0] * grid[0] + transform[1] * grid[1] + transform[2] + ) / denom + y_in = ( + transform[3] * grid[0] + transform[4] * grid[1] + transform[5] + ) / denom + return jnp.stack([y_in, x_in], axis=0) + + transformed_coords = jax.vmap(transform_coordinates)(transforms) + + def interpolate_image(image, coords): + def interpolate_channel(channel_img): + return jax.scipy.ndimage.map_coordinates( + channel_img, + coords, + order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation], + mode="constant", + cval=fill_value, + ).reshape(height, width) + + return jax.vmap(interpolate_channel, in_axes=0)( + jnp.moveaxis(image, -1, 0) + ) + + output = jax.vmap(interpolate_image, in_axes=(0, 0))( + images, transformed_coords + ) + output = jnp.moveaxis(output, 1, -1) + + if data_format == "channels_first": + output = jnp.transpose(output, (0, 3, 1, 2)) + if need_squeeze: + output = jnp.squeeze(output, axis=0) + + return output + + +def compute_homography_matrix(start_points, end_points): + start_x, start_y = start_points[..., 0], start_points[..., 1] + end_x, end_y = end_points[..., 0], end_points[..., 1] + + zeros = jnp.zeros_like(end_x) + ones = jnp.ones_like(end_x) + + x_rows = jnp.stack( + [ + end_x, + end_y, + ones, + zeros, + zeros, + zeros, + -start_x * end_x, + -start_x * end_y, + ], + axis=-1, + ) + y_rows = jnp.stack( + [ + zeros, + zeros, + zeros, + end_x, + end_y, + ones, + -start_y * end_x, + -start_y * end_y, + ], + axis=-1, + ) + + coefficient_matrix = jnp.concatenate([x_rows, y_rows], axis=1) + + target_vector = jnp.expand_dims( + jnp.concatenate([start_x, start_y], axis=-1), axis=-1 + ) + + homography_matrix = jnp.linalg.solve(coefficient_matrix, target_vector) + + return homography_matrix.squeeze(-1) + + +def map_coordinates( + inputs, coordinates, order, fill_mode="constant", fill_value=0.0 +): + inputs = convert_to_tensor(inputs) + coordinates = convert_to_tensor(coordinates) + if coordinates.shape[0] != len(inputs.shape): + raise ValueError( + "First dim of `coordinates` must be the same as the rank of " + "`inputs`. " + f"Received inputs with shape: {inputs.shape} and coordinate " + f"leading dim of {coordinates.shape[0]}" + ) + if len(coordinates.shape) < 2: + raise ValueError( + "Invalid coordinates rank: expected at least rank 2." + f" Received input with shape: {coordinates.shape}" + ) + if fill_mode not in MAP_COORDINATES_FILL_MODES: + raise ValueError( + "Invalid value for argument `fill_mode`. Expected one of " + f"{set(MAP_COORDINATES_FILL_MODES)}. Received: " + f"fill_mode={fill_mode}" + ) + if order not in range(2): + raise ValueError( + "Invalid value for argument `order`. Expected one of " + f"{[0, 1]}. Received: order={order}" + ) + return jax.scipy.ndimage.map_coordinates( + inputs, coordinates, order, fill_mode, fill_value + ) + + +def gaussian_blur( + images, kernel_size=(3, 3), sigma=(1.0, 1.0), data_format=None +): + def _create_gaussian_kernel(kernel_size, sigma, dtype): + def _get_gaussian_kernel1d(size, sigma): + x = jnp.arange(size, dtype=dtype) - jnp.array( + (size - 1) / 2, dtype=dtype + ) + kernel1d = jnp.exp(-0.5 * (x / sigma) ** 2) + return kernel1d / jnp.sum(kernel1d) + + def _get_gaussian_kernel2d(size, sigma): + kernel1d_x = _get_gaussian_kernel1d(size[0], sigma[0]) + kernel1d_y = _get_gaussian_kernel1d(size[1], sigma[1]) + return jnp.outer(kernel1d_y, kernel1d_x) + + kernel = _get_gaussian_kernel2d(kernel_size, sigma)[ + jnp.newaxis, jnp.newaxis, :, : + ] + return kernel + + images = convert_to_tensor(images) + dtype = backend.standardize_dtype(images.dtype) + sigma = convert_to_tensor(sigma, dtype=dtype) + + if len(images.shape) not in (3, 4): + raise ValueError( + "Invalid images rank: expected rank 3 (single image) " + "or rank 4 (batch of images). Received input with shape: " + f"images.shape={images.shape}" + ) + + need_squeeze = False + if images.ndim == 3: + images = images[jnp.newaxis, ...] + need_squeeze = True + + if data_format == "channels_last": + images = jnp.transpose(images, (0, 3, 1, 2)) + + num_channels = images.shape[1] + kernel = _create_gaussian_kernel(kernel_size, sigma, dtype) + + kernel = jnp.tile(kernel, (num_channels, 1, 1, 1)) + + blurred_images = jax.lax.conv_general_dilated( + images, + kernel, + window_strides=(1, 1), + padding="SAME", + dimension_numbers=("NCHW", "OIHW", "NCHW"), + feature_group_count=num_channels, + ) + + if data_format == "channels_last": + blurred_images = jnp.transpose(blurred_images, (0, 2, 3, 1)) + + if need_squeeze: + blurred_images = blurred_images.squeeze(axis=0) + + return blurred_images + + +def elastic_transform( + images, + alpha=20.0, + sigma=5.0, + interpolation="bilinear", + fill_mode="reflect", + fill_value=0.0, + seed=None, + data_format=None, +): + data_format = backend.standardize_data_format(data_format) + if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS.keys(): + raise ValueError( + "Invalid value for argument `interpolation`. Expected of one " + f"{set(AFFINE_TRANSFORM_INTERPOLATIONS.keys())}. Received: " + f"interpolation={interpolation}" + ) + if fill_mode not in AFFINE_TRANSFORM_FILL_MODES: + raise ValueError( + "Invalid value for argument `fill_mode`. Expected of one " + f"{AFFINE_TRANSFORM_FILL_MODES}. Received: fill_mode={fill_mode}" + ) + if len(images.shape) not in (3, 4): + raise ValueError( + "Invalid images rank: expected rank 3 (single image) " + "or rank 4 (batch of images). Received input with shape: " + f"images.shape={images.shape}" + ) + + images = convert_to_tensor(images) + alpha = convert_to_tensor(alpha) + sigma = convert_to_tensor(sigma) + input_dtype = images.dtype + kernel_size = (int(6 * sigma) | 1, int(6 * sigma) | 1) + + need_squeeze = False + if len(images.shape) == 3: + images = jnp.expand_dims(images, axis=0) + need_squeeze = True + + if data_format == "channels_last": + batch_size, height, width, channels = images.shape + channel_axis = -1 + else: + batch_size, channels, height, width = images.shape + channel_axis = 1 + + seed = draw_seed(seed) + dx = ( + jax.random.normal( + seed, shape=(batch_size, height, width), dtype=input_dtype + ) + * sigma + ) + dy = ( + jax.random.normal( + seed, shape=(batch_size, height, width), dtype=input_dtype + ) + * sigma + ) + + dx = gaussian_blur( + jnp.expand_dims(dx, axis=channel_axis), + kernel_size=kernel_size, + sigma=(sigma, sigma), + data_format=data_format, + ) + dy = gaussian_blur( + jnp.expand_dims(dy, axis=channel_axis), + kernel_size=kernel_size, + sigma=(sigma, sigma), + data_format=data_format, + ) + + dx = jnp.squeeze(dx) + dy = jnp.squeeze(dy) + + x, y = jnp.meshgrid(jnp.arange(width), jnp.arange(height)) + x, y = x[None, :, :], y[None, :, :] + + distorted_x = x + alpha * dx + distorted_y = y + alpha * dy + + transformed_images = jnp.zeros_like(images) + + if data_format == "channels_last": + for i in range(channels): + transformed_images = transformed_images.at[..., i].set( + jnp.stack( + [ + map_coordinates( + images[b, ..., i], + [distorted_y[b], distorted_x[b]], + order=AFFINE_TRANSFORM_INTERPOLATIONS[ + interpolation + ], + fill_mode=fill_mode, + fill_value=fill_value, + ) + for b in range(batch_size) + ] + ) + ) + else: + for i in range(channels): + transformed_images = transformed_images.at[:, i, :, :].set( + jnp.stack( + [ + map_coordinates( + images[b, i, ...], + [distorted_y[b], distorted_x[b]], + order=AFFINE_TRANSFORM_INTERPOLATIONS[ + interpolation + ], + fill_mode=fill_mode, + fill_value=fill_value, + ) + for b in range(batch_size) + ] + ) + ) + + if need_squeeze: + transformed_images = jnp.squeeze(transformed_images, axis=0) + transformed_images = transformed_images.astype(input_dtype) + + return transformed_images + + +def scale_and_translate( + images, + output_shape, + scale, + translation, + spatial_dims, + method, + antialias=True, +): + if method not in SCALE_AND_TRANSLATE_METHODS: + raise ValueError( + "Invalid value for argument `method`. Expected of one " + f"{SCALE_AND_TRANSLATE_METHODS}. Received: method={method}" + ) + images = convert_to_tensor(images) + scale = convert_to_tensor(scale) + translation = convert_to_tensor(translation) + return jax.image.scale_and_translate( + images, + output_shape, + spatial_dims, + scale, + translation, + method, + antialias, + ) diff --git a/keras/src/backend/jax/layer.py b/keras/src/backend/jax/layer.py new file mode 100644 index 000000000000..9810ec7d8ed6 --- /dev/null +++ b/keras/src/backend/jax/layer.py @@ -0,0 +1,14 @@ +from keras.src.backend.config import is_nnx_enabled + +if is_nnx_enabled(): + from flax import nnx + + class BaseLayer(nnx.Module): + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(pytree=False, **kwargs) +else: + BaseLayer = object + + +class JaxLayer(BaseLayer): + pass diff --git a/keras/src/backend/jax/linalg.py b/keras/src/backend/jax/linalg.py new file mode 100644 index 000000000000..2b0ff9b1fcf0 --- /dev/null +++ b/keras/src/backend/jax/linalg.py @@ -0,0 +1,103 @@ +import jax +import jax.numpy as jnp +import jax.scipy as jsp + +from keras.src.backend import config +from keras.src.backend import standardize_dtype +from keras.src.backend.common import dtypes +from keras.src.backend.jax.core import cast +from keras.src.backend.jax.core import convert_to_tensor + + +def cholesky(a, upper=False): + out = jnp.linalg.cholesky(a, upper=upper) + try: + # In eager mode, raise for nan to + # achieve behavior consistency with numpy + if jnp.any(jnp.isnan(out)): + raise ValueError( + "Cholesky decomposition failed. " + "The input might not be a valid " + "positive definite matrix." + ) + except jax.errors.TracerBoolConversionError: + # Cannot raise for nan in tracing mode + pass + return out + + +def cholesky_inverse(a, upper=False): + identity = jnp.eye(a.shape[-1], dtype=a.dtype) + inv_chol = solve_triangular(a, identity, lower=not upper) + if upper: + a_inv = jnp.matmul(inv_chol, jnp.transpose(inv_chol)) + else: + a_inv = jnp.matmul(jnp.transpose(inv_chol), inv_chol) + return a_inv + + +def det(a): + return jnp.linalg.det(a) + + +def eig(x): + return jnp.linalg.eig(x) + + +def eigh(x): + return jnp.linalg.eigh(x) + + +def inv(a): + return jnp.linalg.inv(a) + + +def lu_factor(x): + lu_factor_fn = jsp.linalg.lu_factor + if x.ndim > 2: + for i in range(x.ndim - 2): + lu_factor_fn = jax.vmap(lu_factor_fn) + + return lu_factor_fn(x) + + +def norm(x, ord=None, axis=None, keepdims=False): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = cast(x, dtype) + return jnp.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims) + + +def qr(x, mode="reduced"): + if mode not in {"reduced", "complete"}: + raise ValueError( + "`mode` argument value not supported. " + "Expected one of {'reduced', 'complete'}. " + f"Received: mode={mode}" + ) + return jnp.linalg.qr(x, mode=mode) + + +def solve(a, b): + return jnp.linalg.solve(a, b) + + +def solve_triangular(a, b, lower=False): + return jsp.linalg.solve_triangular(a, b, lower=lower) + + +def svd(x, full_matrices=True, compute_uv=True): + return jnp.linalg.svd(x, full_matrices=full_matrices, compute_uv=compute_uv) + + +def lstsq(a, b, rcond=None): + a = convert_to_tensor(a) + b = convert_to_tensor(b) + return jnp.linalg.lstsq(a, b, rcond=rcond)[0] + + +def jvp(fun, primals, tangents, has_aux=False): + return jax.jvp(fun, primals, tangents, has_aux=has_aux) diff --git a/keras/src/backend/jax/math.py b/keras/src/backend/jax/math.py new file mode 100644 index 000000000000..6b04f58a4303 --- /dev/null +++ b/keras/src/backend/jax/math.py @@ -0,0 +1,298 @@ +import math + +import jax +import jax.numpy as jnp + +from keras.src.backend import config +from keras.src.backend import standardize_dtype +from keras.src.backend.common import dtypes +from keras.src.backend.jax.core import cast +from keras.src.backend.jax.core import convert_to_tensor +from keras.src.utils.module_utils import scipy + + +def segment_sum(data, segment_ids, num_segments=None, sorted=False): + if num_segments is None: + raise ValueError( + "Argument `num_segments` must be set when using the JAX backend. " + "Received: num_segments=None" + ) + return jax.ops.segment_sum( + data, segment_ids, num_segments, indices_are_sorted=sorted + ) + + +def segment_max(data, segment_ids, num_segments=None, sorted=False): + if num_segments is None: + raise ValueError( + "Argument `num_segments` must be set when using the JAX backend. " + "Received: num_segments=None" + ) + return jax.ops.segment_max( + data, segment_ids, num_segments, indices_are_sorted=sorted + ) + + +def top_k(x, k, sorted=True): + # Jax does not supported `sorted`, but in the case where `sorted=False`, + # order is not guaranteed, so OK to return sorted output. + return jax.lax.top_k(x, k) + + +def in_top_k(targets, predictions, k): + preds_at_label = jnp.take_along_axis( + predictions, jnp.expand_dims(targets, axis=-1), axis=-1 + ) + # `nan` shouldn't be considered as large probability. + preds_at_label = jnp.where( + jnp.isnan(preds_at_label), -jnp.inf, preds_at_label + ) + rank = 1 + jnp.sum(jnp.greater(predictions, preds_at_label), axis=-1) + return jnp.less_equal(rank, k) + + +def logsumexp(x, axis=None, keepdims=False): + return jax.scipy.special.logsumexp(x, axis=axis, keepdims=keepdims) + + +def qr(x, mode="reduced"): + if mode not in {"reduced", "complete"}: + raise ValueError( + "`mode` argument value not supported. " + "Expected one of {'reduced', 'complete'}. " + f"Received: mode={mode}" + ) + return jnp.linalg.qr(x, mode=mode) + + +def extract_sequences(x, sequence_length, sequence_stride): + *batch_shape, signal_length = x.shape + batch_shape = list(batch_shape) + x = jnp.reshape(x, (math.prod(batch_shape), signal_length, 1)) + x = jax.lax.conv_general_dilated_patches( + x, + (sequence_length,), + (sequence_stride,), + "VALID", + dimension_numbers=("NTC", "OIT", "NTC"), + ) + return jnp.reshape(x, (*batch_shape, *x.shape[-2:])) + + +def _get_complex_tensor_from_tuple(x): + if not isinstance(x, (tuple, list)) or len(x) != 2: + raise ValueError( + "Input `x` should be a tuple of two tensors - real and imaginary." + f"Received: x={x}" + ) + # `convert_to_tensor` does not support passing complex tensors. We separate + # the input out into real and imaginary and convert them separately. + real, imag = x + # Check shapes. + if real.shape != imag.shape: + raise ValueError( + "Input `x` should be a tuple of two tensors - real and imaginary." + "Both the real and imaginary parts should have the same shape. " + f"Received: x[0].shape = {real.shape}, x[1].shape = {imag.shape}" + ) + # Ensure dtype is float. + if not jnp.issubdtype(real.dtype, jnp.floating) or not jnp.issubdtype( + imag.dtype, jnp.floating + ): + raise ValueError( + "At least one tensor in input `x` is not of type float." + f"Received: x={x}." + ) + complex_input = jax.lax.complex(real, imag) + return complex_input + + +def fft(x): + complex_input = _get_complex_tensor_from_tuple(x) + complex_output = jnp.fft.fft(complex_input) + return jnp.real(complex_output), jnp.imag(complex_output) + + +def fft2(x): + complex_input = _get_complex_tensor_from_tuple(x) + complex_output = jnp.fft.fft2(complex_input) + return jnp.real(complex_output), jnp.imag(complex_output) + + +def ifft2(x): + complex_input = _get_complex_tensor_from_tuple(x) + complex_output = jnp.fft.ifft2(complex_input) + return jnp.real(complex_output), jnp.imag(complex_output) + + +def rfft(x, fft_length=None): + complex_output = jnp.fft.rfft(x, n=fft_length, axis=-1, norm="backward") + return jnp.real(complex_output), jnp.imag(complex_output) + + +def irfft(x, fft_length=None): + complex_input = _get_complex_tensor_from_tuple(x) + return jnp.fft.irfft(complex_input, n=fft_length, axis=-1, norm="backward") + + +def stft( + x, sequence_length, sequence_stride, fft_length, window="hann", center=True +): + if standardize_dtype(x.dtype) not in {"float32", "float64"}: + raise TypeError( + "Invalid input type. Expected `float32` or `float64`. " + f"Received: input type={x.dtype}" + ) + if fft_length < sequence_length: + raise ValueError( + "`fft_length` must equal or larger than `sequence_length`. " + f"Received: sequence_length={sequence_length}, " + f"fft_length={fft_length}" + ) + if isinstance(window, str): + if window not in {"hann", "hamming"}: + raise ValueError( + "If a string is passed to `window`, it must be one of " + f'`"hann"`, `"hamming"`. Received: window={window}' + ) + x = convert_to_tensor(x) + + if center: + pad_width = [(0, 0) for _ in range(len(x.shape))] + pad_width[-1] = (fft_length // 2, fft_length // 2) + x = jnp.pad(x, pad_width, mode="reflect") + + l_pad = (fft_length - sequence_length) // 2 + r_pad = fft_length - sequence_length - l_pad + + if window is not None: + if isinstance(window, str): + win = convert_to_tensor( + scipy.signal.get_window(window, sequence_length), dtype=x.dtype + ) + else: + win = convert_to_tensor(window, dtype=x.dtype) + if len(win.shape) != 1 or win.shape[-1] != sequence_length: + raise ValueError( + "The shape of `window` must be equal to [sequence_length]." + f"Received: window shape={win.shape}" + ) + win = jnp.pad(win, [[l_pad, r_pad]]) + else: + win = jnp.ones((sequence_length + l_pad + r_pad), dtype=x.dtype) + + result = jax.scipy.signal.stft( + x, + fs=1.0, + window=win, + nperseg=(sequence_length + l_pad + r_pad), + noverlap=(sequence_length + l_pad + r_pad - sequence_stride), + nfft=fft_length, + boundary=None, + padded=False, + )[-1] + # scale and swap to (..., num_sequences, fft_bins) + scale = jnp.sqrt(1.0 / win.sum() ** 2) + result = result / scale + result = jnp.swapaxes(result, -2, -1) + return jnp.real(result), jnp.imag(result) + + +def istft( + x, + sequence_length, + sequence_stride, + fft_length, + length=None, + window="hann", + center=True, +): + x = _get_complex_tensor_from_tuple(x) + dtype = jnp.real(x).dtype + + if len(x.shape) < 2: + raise ValueError( + f"Input `x` must have at least 2 dimensions. " + f"Received shape: {x.shape}" + ) + + expected_output_len = fft_length + sequence_stride * (x.shape[-2] - 1) + l_pad = (fft_length - sequence_length) // 2 + r_pad = fft_length - sequence_length - l_pad + + if window is not None: + if isinstance(window, str): + win = convert_to_tensor( + scipy.signal.get_window(window, sequence_length), dtype=dtype + ) + else: + win = convert_to_tensor(window, dtype=dtype) + if len(win.shape) != 1 or win.shape[-1] != sequence_length: + raise ValueError( + "The shape of `window` must be equal to [sequence_length]." + f"Received: window shape={win.shape}" + ) + win = jnp.pad(win, [[l_pad, r_pad]]) + else: + win = jnp.ones((sequence_length + l_pad + r_pad), dtype=dtype) + + x = jax.scipy.signal.istft( + x, + fs=1.0, + window=win, + nperseg=(sequence_length + l_pad + r_pad), + noverlap=(sequence_length + l_pad + r_pad - sequence_stride), + nfft=fft_length, + boundary=False, + time_axis=-2, + freq_axis=-1, + )[-1] + + # scale + x = x / win.sum() if window is not None else x / sequence_stride + + start = 0 if center is False else fft_length // 2 + if length is not None: + end = start + length + elif center is True: + end = -(fft_length // 2) + else: + end = expected_output_len + return x[..., start:end] + + +def rsqrt(x): + return jax.lax.rsqrt(x) + + +def erf(x): + return jax.lax.erf(x) + + +def erfinv(x): + return jax.lax.erf_inv(x) + + +def solve(a, b): + a = convert_to_tensor(a) + b = convert_to_tensor(b) + return jnp.linalg.solve(a, b) + + +def norm(x, ord=None, axis=None, keepdims=False): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = cast(x, dtype) + return jnp.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims) + + +def logdet(x): + from keras.src.backend.jax.numpy import slogdet + + # In JAX (like in NumPy) slogdet is more stable than + # `np.log(np.linalg.det(x))`. See + # https://numpy.org/doc/stable/reference/generated/numpy.linalg.slogdet.html + return slogdet(x)[1] diff --git a/keras/src/backend/jax/nn.py b/keras/src/backend/jax/nn.py new file mode 100644 index 000000000000..3e8c08e860df --- /dev/null +++ b/keras/src/backend/jax/nn.py @@ -0,0 +1,1458 @@ +import builtins +import inspect +import math + +import jax +import jax.experimental.sparse as jax_sparse +import jax.numpy as jnp +from absl import logging +from jax import lax +from jax import nn as jnn +from jax.experimental.pallas.ops.tpu.splash_attention import ( + splash_attention_kernel, +) +from jax.experimental.pallas.ops.tpu.splash_attention import ( + splash_attention_mask, +) + +from keras.src import backend +from keras.src.backend.common.backend_utils import ( + compute_conv_transpose_padding_args_for_jax, +) +from keras.src.backend.jax.core import cast +from keras.src.backend.jax.core import convert_to_tensor + + +def relu(x): + x = convert_to_tensor(x) + return jnn.relu(x) + + +def relu6(x): + x = convert_to_tensor(x) + return jnn.relu6(x) + + +def sigmoid(x): + x = convert_to_tensor(x) + return jnn.sigmoid(x) + + +def sparse_sigmoid(x): + x = convert_to_tensor(x) + return jnn.sparse_sigmoid(x) + + +def tanh(x): + x = convert_to_tensor(x) + return jnn.tanh(x) + + +def tanh_shrink(x): + x = convert_to_tensor(x) + return x - jnp.tanh(x) + + +def softplus(x): + x = convert_to_tensor(x) + return jnn.softplus(x) + + +def softsign(x): + x = convert_to_tensor(x) + return jnn.soft_sign(x) + + +def soft_shrink(x, threshold=0.5): + x = convert_to_tensor(x) + return jnp.where( + x > threshold, + x - threshold, + jnp.where(x < -threshold, x + threshold, 0.0), + ) + + +def sparse_plus(x): + x = convert_to_tensor(x) + return jnn.sparse_plus(x) + + +def silu(x): + x = convert_to_tensor(x) + return jnn.silu(x) + + +def squareplus(x, b=4): + x = convert_to_tensor(x) + return jnn.squareplus(x, b=b) + + +def log_sigmoid(x): + x = convert_to_tensor(x) + return jnn.log_sigmoid(x) + + +def leaky_relu(x, negative_slope=0.2): + x = convert_to_tensor(x) + return jnn.leaky_relu(x, negative_slope=negative_slope) + + +def hard_sigmoid(x): + x = convert_to_tensor(x) + return jnn.hard_sigmoid(x) + + +def hard_silu(x): + x = convert_to_tensor(x) + return jnn.hard_silu(x) + + +def elu(x, alpha=1.0): + x = convert_to_tensor(x) + return jnn.elu(x, alpha=alpha) + + +def selu(x): + x = convert_to_tensor(x) + return jnn.selu(x) + + +def gelu(x, approximate=True): + x = convert_to_tensor(x) + return jnn.gelu(x, approximate) + + +def celu(x, alpha=1.0): + x = convert_to_tensor(x) + return jnn.celu(x, alpha=alpha) + + +def glu(x, axis=-1): + x = convert_to_tensor(x) + return jnn.glu(x, axis=axis) + + +def hard_tanh(x): + x = convert_to_tensor(x) + return jnn.hard_tanh(x) + + +def hard_shrink(x, threshold=0.5): + x = convert_to_tensor(x) + return jnp.where(jnp.abs(x) > threshold, x, 0.0) + + +def threshold(x, threshold, default_value): + x = convert_to_tensor(x) + return jnp.where(x > threshold, x, default_value) + + +def softmax(x, axis=-1): + x = convert_to_tensor(x) + return jnn.softmax(x, axis=axis) + + +def log_softmax(x, axis=-1): + x = convert_to_tensor(x) + return jnn.log_softmax(x, axis=axis) + + +def sparsemax(x, axis=-1): + # Sort logits along the specified axis in descending order + logits = convert_to_tensor(x) + logits_sorted = -1.0 * jnp.sort(logits * -1.0, axis=axis) + logits_cumsum = jnp.cumsum(logits_sorted, axis=axis) # find cumulative sum + r = jnp.arange(1, logits.shape[axis] + 1) # Determine the sparsity + r_shape = [1] * logits.ndim + r_shape[axis] = -1 # Broadcast to match the target axis + r = r.reshape(r_shape) + support = logits_sorted - (logits_cumsum - 1) / r > 0 + # Find the threshold + k = jnp.sum(support, axis=axis, keepdims=True) + logits_cumsum_safe = jnp.where(support, logits_cumsum, 0.0) + tau = (jnp.sum(logits_cumsum_safe, axis=axis, keepdims=True) - 1) / k + output = jnp.maximum(logits - tau, 0.0) + return output + + +def _convert_to_spatial_operand( + x, + num_spatial_dims, + data_format="channels_last", + include_batch_and_channels=True, +): + # Helper function that converts an operand to a spatial operand. + x = (x,) * num_spatial_dims if isinstance(x, int) else x + if not include_batch_and_channels: + return x + if data_format == "channels_last": + x = (1,) + x + (1,) + else: + x = (1,) + (1,) + x + return x + + +def _pool( + inputs, + initial_value, + reduce_fn, + pool_size, + strides=None, + padding="valid", +): + """Helper function to define pooling functions. + + Args: + inputs: input data of shape `N+2`. + initial_value: the initial value for the reduction. + reduce_fn: a reduce function of the form `(T, T) -> T`. + pool_size: a sequence of `N` integers, representing the window size to + reduce over. + strides: a sequence of `N` integers, representing the inter-window + strides (default: `(1, ..., 1)`). + padding: either the string `same` or `valid`. + + Returns: + The output of the reduction for each window slice. + """ + if padding not in ("same", "valid"): + raise ValueError( + f"Invalid padding '{padding}', must be 'same' or 'valid'." + ) + padding = padding.upper() + return lax.reduce_window( + inputs, + initial_value, + reduce_fn, + pool_size, + strides, + padding, + ) + + +def max_pool( + inputs, + pool_size, + strides=None, + padding="valid", + data_format=None, +): + data_format = backend.standardize_data_format(data_format) + num_spatial_dims = inputs.ndim - 2 + pool_size = _convert_to_spatial_operand( + pool_size, num_spatial_dims, data_format + ) + strides = pool_size if strides is None else strides + strides = _convert_to_spatial_operand( + strides, num_spatial_dims, data_format + ) + return _pool(inputs, -jnp.inf, lax.max, pool_size, strides, padding) + + +def average_pool( + inputs, + pool_size, + strides=None, + padding="valid", + data_format=None, +): + data_format = backend.standardize_data_format(data_format) + num_spatial_dims = inputs.ndim - 2 + pool_size = _convert_to_spatial_operand( + pool_size, num_spatial_dims, data_format + ) + strides = pool_size if strides is None else strides + strides = _convert_to_spatial_operand( + strides, num_spatial_dims, data_format + ) + + pooled = _pool(inputs, 0.0, lax.add, pool_size, strides, padding) + if padding == "valid": + # Avoid the extra reduce_window. + return pooled / math.prod(pool_size) + else: + # Count the number of valid entries at each input point, then use that + # for computing average. Assumes that any two arrays of same shape will + # be padded the same. Avoid broadcasting on axis where pooling is + # skipped. + shape = [ + (a if b != 1 else 1) for (a, b) in zip(inputs.shape, pool_size) + ] + window_counts = _pool( + jnp.ones(shape, inputs.dtype), + 0.0, + lax.add, + pool_size, + strides, + padding, + ) + return pooled / window_counts + + +def _convert_to_lax_conv_dimension_numbers( + num_spatial_dims, + data_format="channels_last", + transpose=False, +): + """Create a `lax.ConvDimensionNumbers` for the given inputs.""" + num_dims = num_spatial_dims + 2 + + if data_format == "channels_last": + spatial_dims = tuple(range(1, num_dims - 1)) + inputs_dn = (0, num_dims - 1) + spatial_dims + else: + spatial_dims = tuple(range(2, num_dims)) + inputs_dn = (0, 1) + spatial_dims + + if transpose: + kernel_dn = (num_dims - 2, num_dims - 1) + tuple(range(num_dims - 2)) + else: + kernel_dn = (num_dims - 1, num_dims - 2) + tuple(range(num_dims - 2)) + + return lax.ConvDimensionNumbers( + lhs_spec=inputs_dn, rhs_spec=kernel_dn, out_spec=inputs_dn + ) + + +def conv( + inputs, + kernel, + strides=1, + padding="valid", + data_format=None, + dilation_rate=1, +): + data_format = backend.standardize_data_format(data_format) + num_spatial_dims = inputs.ndim - 2 + dimension_numbers = _convert_to_lax_conv_dimension_numbers( + num_spatial_dims, + data_format, + transpose=False, + ) + strides = _convert_to_spatial_operand( + strides, + num_spatial_dims, + data_format, + include_batch_and_channels=False, + ) + dilation_rate = _convert_to_spatial_operand( + dilation_rate, + num_spatial_dims, + data_format, + include_batch_and_channels=False, + ) + if data_format == "channels_last": + channels = inputs.shape[-1] + else: + channels = inputs.shape[1] + kernel_in_channels = kernel.shape[-2] + if channels % kernel_in_channels > 0: + raise ValueError( + "The number of input channels must be evenly divisible by " + f"kernel's in_channels. Received input channels {channels} and " + f"kernel in_channels {kernel_in_channels}. " + ) + feature_group_count = channels // kernel_in_channels + kernel = convert_to_tensor(kernel) + inputs = convert_to_tensor(inputs, dtype=kernel.dtype) + return jax.lax.conv_general_dilated( + inputs, + kernel, + strides, + padding, + rhs_dilation=dilation_rate, + dimension_numbers=dimension_numbers, + feature_group_count=feature_group_count, + ) + + +def depthwise_conv( + inputs, + kernel, + strides=1, + padding="valid", + data_format=None, + dilation_rate=1, +): + data_format = backend.standardize_data_format(data_format) + num_spatial_dims = inputs.ndim - 2 + dimension_numbers = _convert_to_lax_conv_dimension_numbers( + num_spatial_dims, + data_format, + transpose=False, + ) + strides = _convert_to_spatial_operand( + strides, + num_spatial_dims, + data_format, + include_batch_and_channels=False, + ) + dilation_rate = _convert_to_spatial_operand( + dilation_rate, + num_spatial_dims, + data_format, + include_batch_and_channels=False, + ) + feature_group_count = ( + inputs.shape[-1] if data_format == "channels_last" else inputs.shape[1] + ) + kernel = convert_to_tensor(kernel) + inputs = convert_to_tensor(inputs) + kernel = jnp.reshape( + kernel, + kernel.shape[:-2] + (1, feature_group_count * kernel.shape[-1]), + ) + return jax.lax.conv_general_dilated( + inputs, + kernel, + strides, + padding, + rhs_dilation=dilation_rate, + dimension_numbers=dimension_numbers, + feature_group_count=feature_group_count, + ) + + +def separable_conv( + inputs, + depthwise_kernel, + pointwise_kernel, + strides=1, + padding="valid", + data_format=None, + dilation_rate=1, +): + data_format = backend.standardize_data_format(data_format) + depthwise_conv_output = depthwise_conv( + inputs, + depthwise_kernel, + strides, + padding, + data_format, + dilation_rate, + ) + return conv( + depthwise_conv_output, + pointwise_kernel, + strides=1, + padding="valid", + data_format=data_format, + dilation_rate=dilation_rate, + ) + + +def conv_transpose( + inputs, + kernel, + strides=1, + padding="valid", + output_padding=None, + data_format=None, + dilation_rate=1, +): + data_format = backend.standardize_data_format(data_format) + num_spatial_dims = inputs.ndim - 2 + padding_values = compute_conv_transpose_padding_args_for_jax( + input_shape=inputs.shape, + kernel_shape=kernel.shape, + strides=strides, + padding=padding, + output_padding=output_padding, + dilation_rate=dilation_rate, + ) + dimension_numbers = _convert_to_lax_conv_dimension_numbers( + num_spatial_dims, + data_format, + transpose=False, + ) + strides = _convert_to_spatial_operand( + strides, + num_spatial_dims, + data_format, + include_batch_and_channels=False, + ) + dilation_rate = _convert_to_spatial_operand( + dilation_rate, + num_spatial_dims, + data_format, + include_batch_and_channels=False, + ) + + return jax.lax.conv_transpose( + inputs, + kernel, + strides, + padding=padding_values, + rhs_dilation=dilation_rate, + dimension_numbers=dimension_numbers, + transpose_kernel=True, + ) + + +def one_hot(x, num_classes, axis=-1, dtype=None, sparse=False): + x = convert_to_tensor(x) + if sparse: + if axis < 0: + axis = axis + len(x.shape) + 1 + if dtype is None: + dtype = "float32" + # We deal with negative inputs by having zeros in the output although + # it's useless. It makes shapes static. + values = jnp.greater_equal(jnp.ravel(x), 0).astype(dtype) + values_count = values.shape[0] + indices = [jnp.arange(dim) for dim in x.shape] + indices = jnp.meshgrid(*indices, indexing="ij") + indices.insert(axis, jnp.maximum(x, 0)) # Deal with negative indices + indices = [a.reshape(values_count, 1).astype("int32") for a in indices] + indices = jnp.concatenate(indices, axis=1) + shape = list(x.shape) + shape.insert(axis, num_classes) + shape = tuple(shape) + return jax_sparse.BCOO( + (values, indices), + shape=shape, + indices_sorted=True, + unique_indices=True, + ) + return jnn.one_hot(x, num_classes, axis=axis, dtype=dtype) + + +def multi_hot(x, num_classes, axis=-1, dtype=None, sparse=False): + x = convert_to_tensor(x) + reduction_axis = 1 if len(x.shape) > 1 else 0 + if sparse: + result = one_hot( + x, num_classes, axis=axis, dtype="int32", sparse=sparse + ) + # JAX's BCOO does not support max reduction, use sum and compare with 0. + result = jax_sparse.bcoo_reduce_sum(result, axes=(reduction_axis,)) + result = jax_sparse.bcoo_sum_duplicates(result) + values = jnp.greater_equal(result.data, 0).astype(dtype) + return jax_sparse.BCOO( + (values, result.indices), + shape=result.shape, + indices_sorted=True, + unique_indices=True, + ) + return jnp.max( + one_hot(cast(x, "int32"), num_classes, axis=axis, dtype=dtype), + axis=reduction_axis, + ) + + +def categorical_crossentropy(target, output, from_logits=False, axis=-1): + target = jnp.array(target) + output = jnp.array(output) + + if target.shape != output.shape: + raise ValueError( + "Arguments `target` and `output` must have the same shape. " + "Received: " + f"target.shape={target.shape}, output.shape={output.shape}" + ) + if len(target.shape) < 1: + raise ValueError( + "Arguments `target` and `output` must be at least rank 1. " + "Received: " + f"target.shape={target.shape}, output.shape={output.shape}" + ) + + if from_logits: + log_prob = jax.nn.log_softmax(output, axis=axis) + else: + output = output / jnp.sum(output, axis, keepdims=True) + output = jnp.clip(output, backend.epsilon(), 1.0 - backend.epsilon()) + log_prob = jnp.log(output) + return -jnp.sum(target * log_prob, axis=axis) + + +def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1): + target = jnp.array(target, dtype="int32") + output = jnp.array(output) + if len(target.shape) == len(output.shape) and target.shape[-1] == 1: + target = jnp.squeeze(target, axis=-1) + + if len(output.shape) < 1: + raise ValueError( + "Argument `output` must be at least rank 1. " + "Received: " + f"output.shape={output.shape}" + ) + if target.shape != output.shape[:-1]: + raise ValueError( + "Arguments `target` and `output` must have the same shape " + "up until the last dimension: " + f"target.shape={target.shape}, output.shape={output.shape}" + ) + if from_logits: + log_prob = jax.nn.log_softmax(output, axis=axis) + else: + output = output / jnp.sum(output, axis, keepdims=True) + output = jnp.clip(output, backend.epsilon(), 1.0 - backend.epsilon()) + log_prob = jnp.log(output) + target = jnn.one_hot(target, output.shape[axis], axis=axis) + return -jnp.sum(target * log_prob, axis=axis) + + +def binary_crossentropy(target, output, from_logits=False): + target = jnp.array(target) + output = jnp.array(output) + + if target.shape != output.shape: + raise ValueError( + "Arguments `target` and `output` must have the same shape. " + "Received: " + f"target.shape={target.shape}, output.shape={output.shape}" + ) + + if from_logits: + log_logits = jax.nn.log_sigmoid(output) + log_neg_logits = jax.nn.log_sigmoid(-output) + return -1.0 * target * log_logits - (1.0 - target) * log_neg_logits + + output = jnp.clip(output, backend.epsilon(), 1.0 - backend.epsilon()) + bce = target * jnp.log(output) + bce += (1.0 - target) * jnp.log(1.0 - output) + return -bce + + +def moments(x, axes, keepdims=False, synchronized=False): + if synchronized: + raise NotImplementedError( + "Argument synchronized=True is not supported with JAX." + ) + # The dynamic range of float16 is too limited for statistics. As a + # workaround, we simply perform the operations on float32 and convert back + # to float16 + need_cast = False + ori_dtype = backend.standardize_dtype(x.dtype) + if ori_dtype in ("float16", "bfloat16"): + need_cast = True + x = cast(x, "float32") + + mean = jnp.mean(x, axes, keepdims=True) + variance = jnp.var(x, axis=axes, keepdims=True) + + if not keepdims: + mean = jnp.squeeze(mean, axes) + variance = jnp.squeeze(variance, axes) + if need_cast: + # avoid overflow and underflow when casting from float16 to float32 + mean = jnp.clip( + mean, jnp.finfo(jnp.float16).min, jnp.finfo(jnp.float16).max + ) + variance = jnp.clip( + variance, jnp.finfo(jnp.float16).min, jnp.finfo(jnp.float16).max + ) + mean = cast(mean, ori_dtype) + variance = cast(variance, ori_dtype) + return mean, variance + + +def batch_normalization( + x, mean, variance, axis, offset=None, scale=None, epsilon=1e-3 +): + shape = [1] * len(x.shape) + shape[axis] = mean.shape[0] + mean = jnp.reshape(mean, shape) + variance = jnp.reshape(variance, shape) + + inv = jax.lax.rsqrt(variance + epsilon) + if scale is not None: + scale = jnp.reshape(scale, shape) + inv = inv * scale + + res = -mean * inv + if offset is not None: + offset = jnp.reshape(offset, shape) + res = res + offset + + return jnp.add(x * inv, res) + + +def ctc_loss(target, output, target_length, output_length, mask_index=0): + # Ref: https://github.com/google-deepmind/optax + # optax.ctc_loss_with_forward_probs + target = convert_to_tensor(target, dtype="int32") + output = convert_to_tensor(output) + target_length = convert_to_tensor(target_length, "int32") + output_length = convert_to_tensor(output_length, "int32") + batch_size, max_input_length, num_classes = output.shape + batch_size, max_label_length = target.shape + log_epsilon = -1e5 + + # Ensure that the dtype promotion behavior matches that of `tf.nn.ctc_loss` + dtype = backend.result_type(output.dtype, "float32") + output = cast(output, dtype) + + def _lengths_to_paddings(lengths, max_length): + indices = jnp.arange(max_length).reshape( + (1,) * lengths.ndim + (max_length,) + ) + lengths = jnp.expand_dims(lengths, axis=-1) + elem_valid = indices < lengths + return jnp.logical_not(elem_valid) + + target_paddings = _lengths_to_paddings(target_length, max_label_length) + output_paddings = _lengths_to_paddings(output_length, max_input_length) + target_paddings = target_paddings.astype(output.dtype) + output_paddings = output_paddings.astype(output.dtype) + + logprobs = jnn.log_softmax(output) + label_lengths = max_label_length - jnp.sum(target_paddings, axis=1).astype( + jnp.int32 + ) + + # repeat[b, n] == 1.0 when label[b, n] == label[b, n+1]. + repeat = (target[:, :-1] == target[:, 1:]).astype(jnp.float32) + repeat = jnp.pad(repeat, ((0, 0), (0, 1))) + + logprobs_phi = logprobs[:, :, mask_index : mask_index + 1] # [B, T, 1] + logprobs_phi = jnp.transpose(logprobs_phi, (1, 0, 2)) # [T, B, 1] + + _one_hot = jax.nn.one_hot( + target, num_classes=num_classes, dtype=logprobs.dtype + ) # [B, N, K] + logprobs_emit = jnp.einsum("btk,bnk->btn", logprobs, _one_hot) + logprobs_emit = jnp.transpose(logprobs_emit, (1, 0, 2)) # [T, B, N] + + # [B, N] + logalpha_phi_init = ( + jnp.ones((batch_size, max_label_length + 1), dtype=output.dtype) + * log_epsilon + ) + logalpha_phi_init = logalpha_phi_init.at[:, 0].set(0.0) + logalpha_emit_init = ( + jnp.ones((batch_size, max_label_length), dtype=output.dtype) + * log_epsilon + ) + + def update_phi_score(phi, added_score): + # Update `phi[:, 1:]`` with adding `added_score` in log space. + return jnp.concatenate( + [phi[:, :1], jnp.logaddexp(phi[:, 1:], added_score)], axis=-1 + ) + + def loop_body(prev, x): + prev_phi, prev_emit = prev + # emit-to-phi epsilon transition, except if the next label is repetition + prev_phi_orig = prev_phi + prev_phi = update_phi_score(prev_phi, prev_emit + log_epsilon * repeat) + + logprob_emit, logprob_phi, pad = x + + # phi-to-emit transition + next_emit = jnp.logaddexp( + prev_phi[:, :-1] + logprob_emit, prev_emit + logprob_emit + ) + # self-loop transition + next_phi = prev_phi + logprob_phi + # emit-to-phi blank transition only when the next label is repetition + next_phi = update_phi_score( + next_phi, prev_emit + logprob_phi + log_epsilon * (1.0 - repeat) + ) + + pad = pad.reshape((batch_size, 1)) + next_emit = pad * prev_emit + (1.0 - pad) * next_emit + next_phi = pad * prev_phi_orig + (1.0 - pad) * next_phi + + return (next_phi, next_emit), (next_phi, next_emit) + + xs = (logprobs_emit, logprobs_phi, output_paddings.transpose((1, 0))) + _, (logalpha_phi, logalpha_emit) = jax.lax.scan( + loop_body, (logalpha_phi_init, logalpha_emit_init), xs + ) + + # last row needs to be updated with the last epsilon transition + logalpha_phi_last = update_phi_score(logalpha_phi[-1], logalpha_emit[-1]) + logalpha_phi = logalpha_phi.at[-1].set(logalpha_phi_last) + + # extract per_seq_loss + # [B, N+1] + _one_hot = jax.nn.one_hot( + label_lengths, + num_classes=max_label_length + 1, + dtype=logalpha_phi_last.dtype, + ) + per_seq_loss = -jnp.einsum("bn,bn->b", logalpha_phi_last, _one_hot) + return per_seq_loss + + +def _ctc_greedy_decode( + inputs, + sequence_lengths, + merge_repeated=True, + mask_index=None, +): + inputs = convert_to_tensor(inputs) + sequence_lengths = convert_to_tensor(sequence_lengths, dtype="int32") + batch_size, max_length, num_classes = inputs.shape + + if mask_index is None: + mask_index = num_classes - 1 + + indices = jnp.argmax(inputs, axis=-1) + scores = jnp.max(inputs, axis=-1) + + seqlen_mask = jnp.arange(max_length)[None, :] + seqlen_mask = seqlen_mask >= sequence_lengths[:, None] + + indices = jnp.where(seqlen_mask, mask_index, indices) + scores = jnp.where(seqlen_mask, 0.0, scores) + + if merge_repeated: + repeat_mask = indices[:, 1:] == indices[:, :-1] + repeat_mask = jnp.pad(repeat_mask, ((0, 0), (1, 0))) + indices = jnp.where(repeat_mask, mask_index, indices) + + # We set to -1 for blank labels + invalid_mask = indices == mask_index + indices = jnp.where(invalid_mask, -1, indices) + + # We rearrange the indices by moving `mask_index` to the end of the array + order = jnp.expand_dims(jnp.arange(max_length), axis=0) # [1, N] + order = jnp.tile(order, (batch_size, 1)) # [B, N] + order = jnp.where(invalid_mask, max_length, order) + order = jnp.argsort(order, axis=-1) + indices = jnp.take_along_axis(indices, order, axis=-1) + + scores = -jnp.sum(scores, axis=1)[:, None] + indices = jnp.expand_dims(indices, axis=0) + return indices, scores + + +def _ctc_beam_search_decode( + inputs, + sequence_lengths, + beam_width=100, + top_paths=1, + mask_index=None, +): + inputs = convert_to_tensor(inputs) + sequence_lengths = convert_to_tensor(sequence_lengths) + + batch_size, max_seq_len, num_classes = inputs.shape + inputs = jnn.log_softmax(inputs) + seqlen_mask = jnp.arange(max_seq_len)[None, :] >= sequence_lengths[:, None] + + if mask_index is None: + mask_index = num_classes - 1 + + # This is a workaround for the fact that jnp.argsort does not support + # the order parameter which is used to break ties when scores are equal. + # For compatibility with the tensorflow implementation, we flip the inputs + # and the mask_index, and then flip the classes back to the correct indices + inputs = jnp.flip(inputs, axis=2) + mask_index = num_classes - mask_index - 1 + + _pad = -1 + + init_paths = jnp.full( + (batch_size, 2 * beam_width, max_seq_len), _pad, dtype=jnp.int32 + ) + + num_init_paths = builtins.min(num_classes, beam_width) + max_classes = jnp.argsort(inputs[:, 0], axis=1)[:, -num_init_paths:] + init_classes = jnp.where(max_classes == mask_index, _pad, max_classes) + init_paths = init_paths.at[:, :num_init_paths, 0].set(init_classes) + + init_scores = ( + jnp.full((batch_size, 2 * beam_width), -jnp.inf, dtype=inputs.dtype) + .at[:, :num_init_paths] + .set(jnp.take_along_axis(inputs[:, 0], max_classes, axis=1)) + ) + init_masked = init_paths[:, :, 0] == _pad + + def _extend_paths(paths, scores, masked, x): + paths = jnp.repeat(paths, num_classes, axis=0) + scores = jnp.repeat(scores, num_classes) + masked = jnp.repeat(masked, num_classes) + + path_tail_index = jnp.argmax(paths == _pad, axis=1) + paths_arange = jnp.arange(2 * beam_width * num_classes) + path_tails = paths[paths_arange, path_tail_index - 1] + path_tails = jnp.where(path_tail_index == 0, _pad, path_tails) + + classes = jnp.arange(num_classes).at[mask_index].set(_pad) + classes = jnp.tile(classes, 2 * beam_width) + + prev_masked = masked + masked = classes == _pad + + masked_repeat = ~prev_masked & (path_tails == classes) + classes = jnp.where(masked_repeat, _pad, classes) + paths = paths.at[paths_arange, path_tail_index].set(classes) + + x = jnp.tile(x, 2 * beam_width) + scores = scores + x + + return paths, scores, masked + + def _merge_scores(unique_inverse, scores): + scores_max = jnp.max(scores) + scores_exp = jnp.exp(scores - scores_max) + scores = jnp.zeros_like(scores).at[unique_inverse].add(scores_exp) + scores = jnp.log(scores) + scores_max + return scores + + def _prune_paths(paths, scores, masked): + paths, unique_inverse = jnp.unique( + paths, + return_inverse=True, + size=2 * num_classes * beam_width, + axis=0, + fill_value=_pad, + ) + if len(unique_inverse.shape) >= 2: + unique_inverse = jnp.squeeze(unique_inverse, axis=1) + + emit_scores = jnp.where(masked, -jnp.inf, scores) + mask_scores = jnp.where(masked, scores, -jnp.inf) + + emit_scores = _merge_scores(unique_inverse, emit_scores) + mask_scores = _merge_scores(unique_inverse, mask_scores) + + total_scores = jnp.logaddexp(emit_scores, mask_scores) + top_indices = jnp.argsort(total_scores)[-beam_width:] + + paths = paths[top_indices] + emit_scores = emit_scores[top_indices] + mask_scores = mask_scores[top_indices] + + paths = jnp.tile(paths, (2, 1)) + scores = jnp.concatenate([emit_scores, mask_scores]) + masked = jnp.concatenate( + [jnp.zeros(beam_width, bool), jnp.ones(beam_width, bool)] + ) + + return paths, scores, masked + + def _decode_step(paths, scores, masked, x): + paths, scores, masked = _extend_paths(paths, scores, masked, x) + paths, scores, masked = _prune_paths(paths, scores, masked) + return paths, scores, masked + + def _step(prev, x): + paths, scores, masked = prev + x, seqlen_mask = x + + paths, scores, masked = lax.cond( + seqlen_mask, + lambda paths, scores, masked, x: (paths, scores, masked), + _decode_step, + paths, + scores, + masked, + x, + ) + + return (paths, scores, masked), None + + def _decode_batch( + init_paths, init_scores, init_masked, inputs, seqlen_mask + ): + (paths, scores, masked), _ = lax.scan( + _step, + (init_paths, init_scores, init_masked), + (inputs[1:], seqlen_mask[1:]), + ) + + paths, unique_inverse = jnp.unique( + paths, + return_inverse=True, + size=2 * num_classes * beam_width, + axis=0, + fill_value=_pad, + ) + if len(unique_inverse.shape) >= 2: + unique_inverse = jnp.squeeze(unique_inverse, axis=1) + scores = _merge_scores(unique_inverse, scores) + + top_indices = jnp.argsort(scores)[-top_paths:][::-1] + paths = paths[top_indices] + scores = scores[top_indices] + + return paths, scores + + paths, scores = jax.vmap(_decode_batch)( + init_paths, init_scores, init_masked, inputs, seqlen_mask + ) + + # convert classes back to the correct indices + paths = jnp.where(paths == _pad, _pad, num_classes - paths - 1) + paths = jnp.transpose(paths, [1, 0, 2]) + return paths, scores + + +def ctc_decode( + inputs, + sequence_lengths, + strategy="greedy", + beam_width=100, + top_paths=1, + merge_repeated=True, + mask_index=0, +): + inputs = convert_to_tensor(inputs) + dtype = backend.result_type(inputs.dtype, "float32") + inputs = cast(inputs, dtype) + + if strategy == "greedy": + return _ctc_greedy_decode( + inputs, + sequence_lengths, + merge_repeated=merge_repeated, + mask_index=mask_index, + ) + elif strategy == "beam_search": + return _ctc_beam_search_decode( + inputs, + sequence_lengths, + beam_width=beam_width, + top_paths=top_paths, + mask_index=mask_index, + ) + else: + raise ValueError( + f"Invalid strategy {strategy}. Supported values are " + "'greedy' and 'beam_search'." + ) + + +def psnr(x1, x2, max_val): + if x1.shape != x2.shape: + raise ValueError( + f"Input shapes {x1.shape} and {x2.shape} must " + "match for PSNR calculation. " + ) + + max_val = convert_to_tensor(max_val, dtype=x2.dtype) + mse = jnp.mean(jnp.square(x1 - x2)) + psnr = 20 * jnp.log10(max_val) - 10 * jnp.log10(mse) + return psnr + + +def _can_use_flash_attention(query, key, value, bias, raise_error=False): + """Verify the availability of flash attention.""" + try: + from jax._src.cudnn.fused_attention_stablehlo import _normalize_layout + from jax._src.cudnn.fused_attention_stablehlo import ( + check_compute_capability, + ) + from jax._src.cudnn.fused_attention_stablehlo import check_cudnn_version + from jax._src.cudnn.fused_attention_stablehlo import ( + check_is_flash_attention, + ) + from jax._src.cudnn.fused_attention_stablehlo import check_layout + from jax.nn import dot_product_attention as dot_product_attention + except ImportError: + if raise_error: + raise ImportError( + "Flash attention is not supported in your current JAX version. " + "Please update it by following the official guide: " + "https://jax.readthedocs.io/en/latest/installation.html" + ) + return False + + if jax.devices()[0].platform == "tpu": + return True + try: + # Check if cuDNN is installed and raise RuntimeError if cuDNN is not + # detected + cudnn_version = check_cudnn_version() + # Only support at least Ampere + if not check_compute_capability("8.0"): + raise RuntimeError("Require at least Ampere arch to run") + # Check inputs layout + check_layout_params = list( + inspect.signature(check_layout).parameters.keys() + ) + for known_param in ("query", "key", "value", "bias", "layout"): + check_layout_params.remove(known_param) + # Defaults to `None` when not specified. + kwargs = {key: None for key in check_layout_params} + check_layout( + query, key, value, bias, layout=_normalize_layout("BTNH"), **kwargs + ) + check_is_flash_attention( + query, + key, + _normalize_layout("BTNH"), + cudnn_version, + bias is not None, + is_training=False, + ) + return True + except: + if raise_error: + raise + return False + + +def _apply_masks(logits, mask, is_causal): + if mask is None and not is_causal: + return logits + + combined_mask = jnp.ones_like(logits, dtype="bool") + if mask is not None: + combined_mask = jnp.logical_and(combined_mask, mask) + + if is_causal: + T, S = logits.shape[2], logits.shape[3] + mask = jnp.tril(jnp.ones((T, S), dtype="bool")) + mask = mask[None, None, :, :] + combined_mask = jnp.logical_and(combined_mask, mask) + + large_negative_number = jnp.asarray( + -0.7 * jnp.finfo(logits.dtype).max, dtype=logits.dtype + ) + padded_logits = jnp.where(combined_mask, logits, large_negative_number) + return padded_logits + + +def _dot_product_attention_core( + query, key, value, bias, mask, is_causal, scale +): + logits_dtype = jnp.promote_types(query.dtype, jnp.float32) + logits = jnp.einsum( + "BTNH,BSNH->BNTS", query, key, preferred_element_type=logits_dtype + ) + logits *= jnp.array(scale, dtype=logits.dtype) + + if bias is not None: + logits = (logits + bias).astype(logits.dtype) + + padded_logits = _apply_masks(logits, mask, is_causal) + + # Softmax and it is always carried out in fp32. + padded_logits = padded_logits.astype(jnp.float32) + probs = jax.nn.softmax(padded_logits, axis=-1).astype(key.dtype) + return jnp.einsum("BNTS,BSNH->BTNH", probs, value) + + +def wrap_flash_attention( + query, + key, + value, + decoder_segment_ids, + custom_mask=None, + attn_logits_soft_cap=None, + head_shards=1, + q_seq_shards=1, +): + """Applies a wrapped flash attention mechanism using the Splash kernel. + This function prepares the appropriate attention mask (causal or custom), + constructs a multi-head mask, and applies the Splash multi-head attention + kernel to the provided query, key, and value tensors. It supports optional + sharding and soft capping of attention logits. + Args: + query: jax.Array. The query tensor of shape + (batch, num_heads, seq_len, head_dim). + key: jax.Array. The key tensor of shape + (batch, num_heads, seq_len, head_dim). + value: jax.Array. The value tensor of shape + (batch, num_heads, seq_len, head_dim). + decoder_segment_ids: Optional. Segment IDs for the decoder, used for + sharding or masking. + custom_mask: Optional[jax.Array]. A custom attention mask to apply. If + None, a causal mask is used. + attn_logits_soft_cap: Optional[float]. If provided, applies a soft cap + to the attention logits. + head_shards: int, default=1. Number of shards for the attention heads. + q_seq_shards: int, default=1. Number of shards for the query sequence + dimension. + Returns: + jax.Array: The result of applying the Splash multi-head attention + kernel to the inputs. + Raises: + AssertionError: If sharding along the sequence dimension is attempted + with decoder_segment_ids. + """ + if decoder_segment_ids is not None: + assert query.shape[2] == decoder_segment_ids.q.shape[1], ( + "Sharding along sequence dimension not allowed" + " in TPU kernel attention" + ) + + if custom_mask is not None: + mask = splash_attention_mask.NumpyMask(array=custom_mask) + else: + mask = splash_attention_mask.CausalMask( + shape=(query.shape[2], query.shape[2]) + ) + + # Create multi-head mask + multi_head_mask = splash_attention_mask.MultiHeadMask( + masks=(mask,) * query.shape[1] + ) + splash_kernel = splash_attention_kernel.make_splash_mha( + mask=multi_head_mask, + head_shards=head_shards, + q_seq_shards=q_seq_shards, + attn_logits_soft_cap=attn_logits_soft_cap, + ) + + return jax.vmap(splash_kernel)( + query, key, value, segment_ids=decoder_segment_ids + ) + + +def dot_product_attention( + query, + key, + value, + bias=None, + mask=None, + scale=None, + is_causal=False, + flash_attention=None, + attn_logits_soft_cap=None, +): + """Computes dot-product attention given query, key, and value. + + This is the core computation of attention that is used in transformers. + For TPU platforms, flash attention optimizations are automatically applied + when possible, and sharding parameters are inferred from the layout map + in the current distribution context. + + Args: + query: Queries with shape `[batch, time, heads, + depth_k]`. + key: Keys with shape `[batch, time, heads, + depth_k]`. + value: Values with shape `[batch, time, heads, + depth_v]`. + bias: Optional bias with shape broadcastable to + `[batch, heads, dest_time, source_time]`. + mask: Optional mask with shape broadcastable to + `[batch, heads, dest_time, source_time]`. + scale: Float. Optional scale that is applied to the attention + computation. + is_causal: Boolean. Specifying whether causal masking is applied. + flash_attention: Boolean. Whether to use flash attention optimization + for increased performance. Default to None, which means it will + be auto-determined based on the platform, input shapes and + compatibility. + attn_logits_soft_cap: Float. Optional float to softly cap attention + logits to avoid numerical stability issues. Applied as: + `logits = logits / (1.0 + abs(logits) / attn_logits_soft_cap)`. + + Returns: + JAX Array of shape `[batch, time, heads, depth_v]`. + """ + query = convert_to_tensor(query) + key = convert_to_tensor(key) + value = convert_to_tensor(value) + if len(query.shape) != 4 or len(key.shape) != 4 or len(value.shape) != 4: + raise ValueError( + "`dot_product_attention` only supports 4D inputs. " + f"Received: query.shape={query.shape}, key.shape={key.shape}, " + f"value.shape={value.shape}." + ) + compute_dtype = backend.result_type(query.dtype, key.dtype, value.dtype) + query = cast(query, compute_dtype) + key = cast(key, compute_dtype) + value = cast(value, compute_dtype) + if bias is not None: + bias = convert_to_tensor(bias, dtype=compute_dtype) + + # Check platform + platform = jax.devices()[0].platform + is_tpu = platform == "tpu" + + # Determine flash attention compatibility + if flash_attention is None: + flash_attention = _can_use_flash_attention(query, key, value, bias) + elif flash_attention is True: + # Use `raise_error=True` to provide more details if the inputs failed to + # use flash attention + _can_use_flash_attention(query, key, value, bias, raise_error=True) + + # TPU-specific flash attention path + if is_tpu and flash_attention: + # Get sharding parameters from distribution context + head_shards = 1 + # Typically keep q_seq_shards=1 for best performance + q_seq_shards = 1 + try: + from keras.src.distribution.distribution_lib import ModelParallel + from keras.src.distribution.distribution_lib import ( + distribution as get_dist, + ) + + # Get current distribution if available + dist = get_dist() + if dist and isinstance(dist, ModelParallel): + mesh = dist.device_mesh + if "model" in mesh.axis_names: + model_dim_index = mesh.axis_names.index("model") + # Set head_shards based on the model dimension of the mesh + head_shards = mesh.shape[model_dim_index] + except (ImportError, ValueError, AttributeError): + # Use default values if detection fails + logging.exception( + "Failed to determine distribution context for sharding. " + "Using default head_shards=1 and q_seq_shards=1." + ) + # Transpose to ('batch', 'heads', 'length', 'head_dim') + query_tpu_layout = jnp.transpose(query, axes=(0, 2, 1, 3)) + key_tpu_layout = jnp.transpose(key, axes=(0, 2, 1, 3)) + value_tpu_layout = jnp.transpose(value, axes=(0, 2, 1, 3)) + + bs, num_heads, q_len, head_dim = query_tpu_layout.shape + + # Apply scale to query if provided + if scale is not None: + # TPU kernel applies 1/sqrt(head_dim) internally, to achieve + # overall QK^T * scale, scale query by (scale * sqrt(head_dim)) + query_tpu_layout = query_tpu_layout * (scale * math.sqrt(head_dim)) + + # Create segment IDs for Splash Attention (for packing/batching) + segment_ids = jnp.zeros([bs, q_len], dtype=jnp.int32) + decoder_segment_ids = splash_attention_kernel.SegmentIds( + q=segment_ids, kv=segment_ids + ) + + # Process mask for Splash Attention + custom_mask = None + if mask is not None: + mask_bool = mask.astype("bool") if mask.dtype != jnp.bool_ else mask + + if mask_bool.ndim == 3 and mask_bool.shape[0] == bs: + custom_mask = mask_bool[0] + elif mask_bool.ndim == 4 and mask_bool.shape[0] == bs: + custom_mask = mask_bool[0, 0] + + if is_causal and custom_mask is not None: + causal_mask = jnp.tril( + jnp.ones((q_len, q_len), dtype=jnp.bool_) + ) + custom_mask = jnp.logical_and(custom_mask, causal_mask) + + if custom_mask is None and is_causal: + custom_mask = jnp.tril(jnp.ones((q_len, q_len), dtype=jnp.bool_)) + + try: + output = wrap_flash_attention( + query_tpu_layout, + key_tpu_layout, + value_tpu_layout, + decoder_segment_ids=decoder_segment_ids, + custom_mask=custom_mask, + attn_logits_soft_cap=attn_logits_soft_cap, + head_shards=head_shards, + q_seq_shards=q_seq_shards, + ) + # Transpose output back to Keras layout + return jnp.transpose(output, axes=(0, 2, 1, 3)) + except Exception: + logging.exception( + "Failed to apply Splash kernel for flash attention. " + "Falling back to JAX native dot_product_attention." + ) + flash_attention = False + + # JAX native dot_product_attention for GPU or fallback for TPU + if hasattr(jax.nn, "dot_product_attention"): + impls = ["cudnn", "xla"] if flash_attention else ["xla"] + for impl in impls: + try: + return jax.nn.dot_product_attention( + query, + key, + value, + bias=bias, + mask=mask, + scale=scale, + is_causal=is_causal, + implementation=impl, + ) + except Exception: + logging.exception( + f"Failed to apply {impl} implementation of " + "jax.nn.dot_product_attention." + ) + + if flash_attention: + raise RuntimeError( + "Flash attention is not supported in your current JAX version. " + "Please update it by following the official guide: " + "https://jax.readthedocs.io/en/latest/installation.html" + ) + # Ref: jax.nn.dot_product_attention + # https://github.com/jax-ml/jax/blob/jax-v0.4.33/jax/_src/nn/functions.py#L886 + # Not support `query_seq_lengths` and `key_value_seq_lengths` args + + # Fallback to custom XLA implementation + # This is the reference implementation from jax.nn.dot_product_attention + output_shape = query.shape + _, _, K, H = key.shape + scale = (1.0 / jnp.sqrt(H)) if scale is None else scale + + # _dot_product_attention_xla + B, T, N, H = query.shape + G = N // K + query = jnp.reshape(query, (B, T, K, G, H)) + + def _reshape_to_grouped(t): + if t is not None: + tB, tN, tT, tS = t.shape + if tN == 1: + t = jnp.broadcast_to(t[:, :, None, :, :], (tB, tN, G, tT, tS)) + else: + assert tN == N + t = jnp.reshape(t, (tB, K, G, tT, tS)) + return t + + bias = _reshape_to_grouped(bias) + mask = _reshape_to_grouped(mask) + vmapped_fn = jax.vmap( + _dot_product_attention_core, + in_axes=(3, None, None, 2, 2, None, None), + out_axes=3, + ) + encoded = vmapped_fn(query, key, value, bias, mask, is_causal, scale) + return jnp.reshape(encoded, output_shape) + + +def unfold(input, kernel_size, dilation=1, padding=0, stride=1): + """JAX implementation of Unfold. + Extract sliding local blocks from a **NCHW** batched image tensor. + + Args: + input: 4-D tensor, shape (N, C, H, W) **required**. + kernel_size: int or (kH, kW) + dilation: int or (dH, dW), default 1 + padding: int or (pH, pW), default 0 + stride: int or (sH, sW), default 1 + + Returns: + 3-D tensor, shape (N, C*kH*kW, L) + """ + + def _pair(x): + return (x, x) if isinstance(x, int) else x + + k = _pair(kernel_size) + d = _pair(dilation) + p = _pair(padding) + s = _pair(stride) + + N, C, H, W = input.shape + + # ---- padding ---- + if any(_ > 0 for _ in p): + input = jnp.pad(input, ((0, 0), (0, 0), (p[0], p[0]), (p[1], p[1]))) + + patches = lax.conv_general_dilated_patches( + input, + filter_shape=k, + window_strides=s, + padding="VALID", # has padde + rhs_dilation=d, + dimension_numbers=("NCHW", "OIHW", "NCHW"), # only support 'NCHW' + ) # shape: (N, C*kH*kW, oH, oW) + + # ---- reshape -> (N, C*kH*kW, L) ---- + _, CKK, oH, oW = patches.shape + return patches.reshape(N, CKK, oH * oW) diff --git a/keras/src/backend/jax/numpy.py b/keras/src/backend/jax/numpy.py new file mode 100644 index 000000000000..e9def4b255c9 --- /dev/null +++ b/keras/src/backend/jax/numpy.py @@ -0,0 +1,1445 @@ +import builtins +import math + +import jax.experimental.sparse as jax_sparse +import jax.numpy as jnp +from jax import export as jax_export + +from keras.src.backend import config +from keras.src.backend.common import dtypes +from keras.src.backend.common.backend_utils import canonicalize_axis +from keras.src.backend.common.backend_utils import to_tuple_or_list +from keras.src.backend.common.variables import standardize_dtype +from keras.src.backend.jax import nn +from keras.src.backend.jax import sparse +from keras.src.backend.jax.core import cast +from keras.src.backend.jax.core import convert_to_tensor + + +def rot90(array, k=1, axes=(0, 1)): + """Rotate an array by 90 degrees in the specified plane.""" + if array.ndim < 2: + raise ValueError( + f"Input array must have at least 2 dimensions. " + f"Received: array.ndim={array.ndim}" + ) + if len(axes) != 2 or axes[0] == axes[1]: + raise ValueError( + f"Invalid axes: {axes}. Axes must be a tuple of " + "two different dimensions." + ) + return jnp.rot90(array, k=k, axes=axes) + + +@sparse.elementwise_binary_union(linear=True, use_sparsify=True) +def add(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + return jnp.add(x1, x2) + + +def bartlett(x): + x = convert_to_tensor(x) + return cast(jnp.bartlett(x), config.floatx()) + + +def hamming(x): + x = convert_to_tensor(x) + return cast(jnp.hamming(x), config.floatx()) + + +def hanning(x): + x = convert_to_tensor(x) + return cast(jnp.hanning(x), config.floatx()) + + +def heaviside(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + return jnp.heaviside(x1, x2) + + +def hypot(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + return jnp.hypot(x1, x2) + + +def kaiser(x, beta): + x = convert_to_tensor(x) + return cast(jnp.kaiser(x, beta), config.floatx()) + + +def bincount(x, weights=None, minlength=0, sparse=False): + # Note: bincount is never traceable / jittable because the output shape + # depends on the values in x. + if sparse or isinstance(x, jax_sparse.BCOO): + if isinstance(x, jax_sparse.BCOO): + if weights is not None: + if not isinstance(weights, jax_sparse.BCOO): + raise ValueError("`x` and `weights` must both be BCOOs") + if x.indices is not weights.indices: + # This test works in eager mode only + if not jnp.all(jnp.equal(x.indices, weights.indices)): + raise ValueError( + "`x` and `weights` BCOOs must have the same indices" + ) + weights = weights.data + x = x.data + reduction_axis = 1 if len(x.shape) > 1 else 0 + maxlength = jnp.maximum(jnp.max(x) + 1, minlength) + one_hot_encoding = nn.one_hot(x, maxlength, sparse=True) + if weights is not None: + expanded_weights = jnp.expand_dims(weights, reduction_axis + 1) + one_hot_encoding = one_hot_encoding * expanded_weights + + outputs = jax_sparse.bcoo_reduce_sum( + one_hot_encoding, + axes=(reduction_axis,), + ) + return outputs + if len(x.shape) == 2: + if weights is None: + + def bincount_fn(arr): + return jnp.bincount(arr, minlength=minlength) + + bincounts = list(map(bincount_fn, x)) + else: + + def bincount_fn(arr_w): + return jnp.bincount( + arr_w[0], weights=arr_w[1], minlength=minlength + ) + + bincounts = list(map(bincount_fn, zip(x, weights))) + + return jnp.stack(bincounts) + return jnp.bincount(x, weights=weights, minlength=minlength) + + +def einsum(subscripts, *operands, **kwargs): + operands = [convert_to_tensor(x) for x in operands] + # When all operands are of int8, specifying `preferred_element_type` as + # int32 to enable hardware-accelerated einsum + dtypes = list(set(standardize_dtype(x.dtype) for x in operands)) + if len(dtypes) == 1 and dtypes[0] == "int8": + preferred_element_type = "int32" + else: + preferred_element_type = None + kwargs["preferred_element_type"] = preferred_element_type + return jnp.einsum(subscripts, *operands, **kwargs) + + +@sparse.elementwise_binary_union(linear=True, use_sparsify=True) +def subtract(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + return jnp.subtract(x1, x2) + + +def matmul(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + # When both x1 and x2 are of int8, specifying `preferred_element_type` as + # int32 to enable hardware-accelerated matmul + x1_dtype = standardize_dtype(x1.dtype) + x2_dtype = standardize_dtype(x2.dtype) + if x1_dtype == "int8" and x2_dtype == "int8": + preferred_element_type = "int32" + else: + preferred_element_type = None + if isinstance(x1, jax_sparse.JAXSparse) or isinstance( + x2, jax_sparse.JAXSparse + ): + if not hasattr(matmul, "sparse_matmul"): + matmul.sparse_matmul = jax_sparse.sparsify(jnp.matmul) + if isinstance(x1, jax_sparse.BCOO): + x1 = jax_sparse.bcoo_update_layout( + x1, n_batch=len(x1.shape) - 2, on_inefficient="warn" + ) + if isinstance(x2, jax_sparse.BCOO): + x2 = jax_sparse.bcoo_update_layout( + x2, n_batch=len(x2.shape) - 2, on_inefficient="warn" + ) + return matmul.sparse_matmul( + x1, x2, preferred_element_type=preferred_element_type + ) + + return jnp.matmul(x1, x2, preferred_element_type=preferred_element_type) + + +def multiply(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + if isinstance(x1, jax_sparse.BCOO): + if isinstance(x2, jax_sparse.BCOO): + # x1 is sparse, x2 is sparse. + if x1.indices is x2.indices: + # `bcoo_multiply_sparse` will not detect that the indices are + # the same, optimize this case here. + if not x1.unique_indices: + x1 = jax_sparse.bcoo_sum_duplicates(x1) + x2 = jax_sparse.bcoo_sum_duplicates(x2) + return jax_sparse.BCOO( + (jnp.multiply(x1.data, x2.data), x1.indices), + shape=x1.shape, + indices_sorted=True, + unique_indices=True, + ) + else: + return jax_sparse.bcoo_multiply_sparse(x1, x2) + else: + # x1 is sparse, x2 is dense. + out_data = jax_sparse.bcoo_multiply_dense(x1, x2) + return jax_sparse.BCOO( + (out_data, x1.indices), + shape=x1.shape, + indices_sorted=x1.indices_sorted, + unique_indices=x1.unique_indices, + ) + elif isinstance(x2, jax_sparse.BCOO): + # x1 is dense, x2 is sparse. + out_data = jax_sparse.bcoo_multiply_dense(x2, x1) + return jax_sparse.BCOO( + (out_data, x2.indices), + shape=x2.shape, + indices_sorted=x2.indices_sorted, + unique_indices=x2.unique_indices, + ) + return jnp.multiply(x1, x2) + + +def mean(x, axis=None, keepdims=False): + x = convert_to_tensor(x) + ori_dtype = standardize_dtype(x.dtype) + # `jnp.mean` does not handle low precision (e.g., float16) overflow + # correctly, so we compute with float32 and cast back to the original type. + compute_dtype = dtypes.result_type(x.dtype, "float32") + if "int" in ori_dtype or ori_dtype == "bool": + result_dtype = compute_dtype + else: + result_dtype = ori_dtype + if isinstance(x, jax_sparse.BCOO): + if axis is None: + axis = tuple(range(len(x.shape))) + ( + canonical_axis, + keep_dims_shape, + broadcast_dimensions, + ) = sparse.axis_shape_dims_for_broadcast_in_dim( + axis, x.shape, insert_dims=False + ) + divisor = math.prod(x.shape[i] for i in canonical_axis) + output = jax_sparse.bcoo_reduce_sum(x, axes=canonical_axis) + output = jax_sparse.BCOO( + (output.data.astype(result_dtype) / divisor, output.indices), + shape=output.shape, + ) + if keepdims: + # `bcoo_reduce_sum` does not support keepdims, neither does + # sparsify(jnp.sum), so we recreate the empty dimensions. + output = jax_sparse.bcoo_broadcast_in_dim( + output, + shape=keep_dims_shape, + broadcast_dimensions=broadcast_dimensions, + ) + return output + else: + output = jnp.mean(x, axis=axis, keepdims=keepdims, dtype=compute_dtype) + return cast(output, result_dtype) + + +def max(x, axis=None, keepdims=False, initial=None): + x = convert_to_tensor(x) + return jnp.max(x, axis=axis, keepdims=keepdims, initial=initial) + + +def ones(shape, dtype=None): + dtype = dtype or config.floatx() + return jnp.ones(shape, dtype=dtype) + + +def zeros(shape, dtype=None): + dtype = dtype or config.floatx() + return jnp.zeros(shape, dtype=dtype) + + +@sparse.elementwise_unary(linear=False) +def absolute(x): + x = convert_to_tensor(x) + return jnp.absolute(x) + + +def abs(x): + return absolute(x) + + +def all(x, axis=None, keepdims=False): + return jnp.all(x, axis=axis, keepdims=keepdims) + + +def angle(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = cast(x, dtype) + return jnp.angle(x) + + +def any(x, axis=None, keepdims=False): + return jnp.any(x, axis=axis, keepdims=keepdims) + + +def amax(x, axis=None, keepdims=False): + return jnp.amax(x, axis=axis, keepdims=keepdims) + + +def amin(x, axis=None, keepdims=False): + return jnp.amin(x, axis=axis, keepdims=keepdims) + + +def append(x1, x2, axis=None): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + return jnp.append(x1, x2, axis=axis) + + +def arange(start, stop=None, step=None, dtype=None): + def get_dtype(x): + if hasattr(x, "dtype"): + return x.dtype + if jax_export.is_symbolic_dim(x): + return int + return type(x) + + if dtype is None: + dtypes_to_resolve = [get_dtype(start)] + if stop is not None: + dtypes_to_resolve.append(get_dtype(stop)) + if step is not None: + dtypes_to_resolve.append(get_dtype(step)) + dtype = dtypes.result_type(*dtypes_to_resolve) + dtype = standardize_dtype(dtype) + return jnp.arange(start, stop, step=step, dtype=dtype) + + +@sparse.densifying_unary +def arccos(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = cast(x, dtype) + return jnp.arccos(x) + + +@sparse.densifying_unary +def arccosh(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = cast(x, dtype) + return jnp.arccosh(x) + + +@sparse.elementwise_unary(linear=False) +def arcsin(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = cast(x, dtype) + return jnp.arcsin(x) + + +@sparse.elementwise_unary(linear=False) +def arcsinh(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = cast(x, dtype) + return jnp.arcsinh(x) + + +@sparse.elementwise_unary(linear=False) +def arctan(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = cast(x, dtype) + return jnp.arctan(x) + + +def arctan2(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type(x1.dtype, x2.dtype, float) + x1 = cast(x1, dtype) + x2 = cast(x2, dtype) + return jnp.arctan2(x1, x2) + + +@sparse.elementwise_unary(linear=False) +def arctanh(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = cast(x, dtype) + return jnp.arctanh(x) + + +def argmax(x, axis=None, keepdims=False): + from keras.src.testing.test_case import uses_cpu + + x = convert_to_tensor(x) + dtype = standardize_dtype(x.dtype) + if "float" not in dtype or not uses_cpu() or x.ndim == 0: + return jnp.argmax(x, axis=axis, keepdims=keepdims) + + # Fix the flush-to-zero (FTZ) issue based on this issue: + # https://github.com/jax-ml/jax/issues/24280 + dtype = dtypes.result_type(dtype, "float32") + x = cast(x, dtype) + is_negative_zero = (x == 0.0) & jnp.signbit(x) + x = jnp.where(is_negative_zero, -jnp.finfo(x.dtype).tiny, x) + return jnp.argmax(x, axis=axis, keepdims=keepdims) + + +def argmin(x, axis=None, keepdims=False): + from keras.src.testing.test_case import uses_cpu + + x = convert_to_tensor(x) + dtype = standardize_dtype(x.dtype) + if "float" not in dtype or not uses_cpu() or x.ndim == 0: + return jnp.argmin(x, axis=axis, keepdims=keepdims) + + # Fix the flush-to-zero (FTZ) issue based on this issue: + # https://github.com/jax-ml/jax/issues/24280 + dtype = dtypes.result_type(dtype, "float32") + x = cast(x, dtype) + is_negative_zero = (x == 0.0) & jnp.signbit(x) + x = jnp.where(is_negative_zero, -jnp.finfo(x.dtype).tiny, x) + return jnp.argmin(x, axis=axis, keepdims=keepdims) + + +def argsort(x, axis=-1): + x = convert_to_tensor(x) + if x.ndim == 0: + return jnp.argsort(x, axis=None) + return jnp.argsort(x, axis=axis) + + +def array(x, dtype=None): + return jnp.array(x, dtype=dtype) + + +def average(x, axis=None, weights=None): + x = convert_to_tensor(x) + dtypes_to_resolve = [x.dtype, float] + if weights is not None: + weights = convert_to_tensor(weights) + dtypes_to_resolve.append(weights.dtype) + dtype = dtypes.result_type(*dtypes_to_resolve) + x = cast(x, dtype) + if weights is not None: + weights = cast(weights, dtype) + return jnp.average(x, weights=weights, axis=axis) + + +def bitwise_and(x, y): + x = convert_to_tensor(x) + y = convert_to_tensor(y) + return jnp.bitwise_and(x, y) + + +def bitwise_invert(x): + x = convert_to_tensor(x) + return jnp.invert(x) + + +def bitwise_not(x): + return bitwise_invert(x) + + +def bitwise_or(x, y): + x = convert_to_tensor(x) + y = convert_to_tensor(y) + return jnp.bitwise_or(x, y) + + +def bitwise_xor(x, y): + x = convert_to_tensor(x) + y = convert_to_tensor(y) + return jnp.bitwise_xor(x, y) + + +def bitwise_left_shift(x, y): + x = convert_to_tensor(x) + if not isinstance(y, int): + y = convert_to_tensor(y) + return jnp.left_shift(x, y) + + +def left_shift(x, y): + return bitwise_left_shift(x, y) + + +def bitwise_right_shift(x, y): + x = convert_to_tensor(x) + if not isinstance(y, int): + y = convert_to_tensor(y) + return jnp.right_shift(x, y) + + +def right_shift(x, y): + return bitwise_right_shift(x, y) + + +def blackman(x): + x = convert_to_tensor(x) + return cast(jnp.blackman(x), config.floatx()) + + +def broadcast_to(x, shape): + x = convert_to_tensor(x) + return jnp.broadcast_to(x, shape) + + +def cbrt(x): + x = convert_to_tensor(x) + return jnp.cbrt(x) + + +@sparse.elementwise_unary(linear=False) +def ceil(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = cast(x, dtype) + return jnp.ceil(x) + + +def clip(x, x_min, x_max): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "bool": + x = cast(x, "int32") + return jnp.clip(x, x_min, x_max) + + +def concatenate(xs, axis=0): + bcoo_count = builtins.sum(isinstance(x, jax_sparse.BCOO) for x in xs) + if bcoo_count == len(xs): + axis = canonicalize_axis(axis, len(xs[0].shape)) + return jax_sparse.bcoo_concatenate(xs, dimension=axis) + elif bcoo_count: + xs = [ + x.todense() + if isinstance(x, jax_sparse.JAXSparse) + else convert_to_tensor(x) + for x in xs + ] + else: + xs = [convert_to_tensor(x) for x in xs] + return jnp.concatenate(xs, axis=axis) + + +@sparse.elementwise_unary(linear=True) +def conjugate(x): + x = convert_to_tensor(x) + return jnp.conjugate(x) + + +@sparse.elementwise_unary(linear=True) +def conj(x): + x = convert_to_tensor(x) + return jnp.conjugate(x) + + +@sparse.elementwise_unary(linear=True) +def copy(x): + x = convert_to_tensor(x) + return jnp.copy(x) + + +@sparse.densifying_unary +def cos(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = cast(x, dtype) + return jnp.cos(x) + + +@sparse.densifying_unary +def cosh(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = cast(x, dtype) + return jnp.cosh(x) + + +def count_nonzero(x, axis=None): + return cast(jnp.count_nonzero(x, axis=axis), "int32") + + +def cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=None): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + return jnp.cross( + x1, + x2, + axisa=axisa, + axisb=axisb, + axisc=axisc, + axis=axis, + ) + + +def cumprod(x, axis=None, dtype=None): + x = convert_to_tensor(x) + return jnp.cumprod(x, axis=axis, dtype=dtype) + + +def cumsum(x, axis=None, dtype=None): + x = convert_to_tensor(x) + return jnp.cumsum(x, axis=axis, dtype=dtype) + + +def deg2rad(x): + x = convert_to_tensor(x) + return jnp.deg2rad(x) + + +def diag(x, k=0): + x = convert_to_tensor(x) + return jnp.diag(x, k=k) + + +def diagflat(x, k=0): + x = convert_to_tensor(x) + return jnp.diagflat(x, k=k) + + +def diagonal(x, offset=0, axis1=0, axis2=1): + x = convert_to_tensor(x) + return jnp.diagonal( + x, + offset=offset, + axis1=axis1, + axis2=axis2, + ) + + +def diff(a, n=1, axis=-1): + a = convert_to_tensor(a) + return jnp.diff(a, n=n, axis=axis) + + +@sparse.elementwise_unary(linear=False) +def digitize(x, bins): + x = convert_to_tensor(x) + bins = convert_to_tensor(bins) + return jnp.digitize(x, bins) + + +def dot(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + return jnp.dot(x1, x2) + + +def empty(shape, dtype=None): + dtype = dtype or config.floatx() + return jnp.empty(shape, dtype=dtype) + + +def equal(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + return jnp.equal(x1, x2) + + +@sparse.densifying_unary +def exp(x): + x = convert_to_tensor(x) + ori_dtype = standardize_dtype(x.dtype) + if "int" in ori_dtype or ori_dtype == "bool": + x = cast(x, config.floatx()) + return jnp.exp(x) + + +@sparse.densifying_unary +def exp2(x): + x = convert_to_tensor(x) + ori_dtype = standardize_dtype(x.dtype) + if "int" in ori_dtype or ori_dtype == "bool": + x = cast(x, config.floatx()) + return jnp.exp2(x) + + +def expand_dims(x, axis): + x = convert_to_tensor(x) + if isinstance(x, jax_sparse.BCOO): + ( + _, + result_shape, + broadcast_dimensions, + ) = sparse.axis_shape_dims_for_broadcast_in_dim( + axis, x.shape, insert_dims=True + ) + return jax_sparse.bcoo_broadcast_in_dim( + x, shape=result_shape, broadcast_dimensions=broadcast_dimensions + ) + return jnp.expand_dims(x, axis) + + +@sparse.elementwise_unary(linear=False) +def expm1(x): + x = convert_to_tensor(x) + ori_dtype = standardize_dtype(x.dtype) + if "int" in ori_dtype or ori_dtype == "bool": + x = cast(x, config.floatx()) + return jnp.expm1(x) + + +def flip(x, axis=None): + return jnp.flip(x, axis=axis) + + +@sparse.elementwise_unary(linear=False) +def floor(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = cast(x, dtype) + return jnp.floor(x) + + +def full(shape, fill_value, dtype=None): + dtype = dtype or config.floatx() + return jnp.full(shape, fill_value, dtype=dtype) + + +def full_like(x, fill_value, dtype=None): + return jnp.full_like(x, fill_value, dtype=dtype) + + +def gcd(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + return jnp.gcd(x1, x2) + + +def greater(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + return jnp.greater(x1, x2) + + +def greater_equal(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + return jnp.greater_equal(x1, x2) + + +def hstack(xs): + return jnp.hstack(xs) + + +def identity(n, dtype=None): + dtype = dtype or config.floatx() + return jnp.identity(n, dtype=dtype) + + +@sparse.elementwise_unary(linear=True) +def imag(x): + x = convert_to_tensor(x) + return jnp.imag(x) + + +def isclose(x1, x2, rtol=1e-5, atol=1e-8, equal_nan=False): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + return jnp.isclose(x1, x2, rtol, atol, equal_nan) + + +@sparse.densifying_unary +def isfinite(x): + x = convert_to_tensor(x) + return jnp.isfinite(x) + + +def isin(x1, x2, assume_unique=False, invert=False): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + return jnp.isin(x1, x2, assume_unique=assume_unique, invert=invert) + + +@sparse.elementwise_unary(linear=False) +def isinf(x): + x = convert_to_tensor(x) + return jnp.isinf(x) + + +@sparse.elementwise_unary(linear=False) +def isnan(x): + x = convert_to_tensor(x) + return jnp.isnan(x) + + +def isneginf(x): + x = convert_to_tensor(x) + return jnp.isneginf(x) + + +def isposinf(x): + x = convert_to_tensor(x) + return jnp.isposinf(x) + + +def kron(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + return jnp.kron(x1, x2) + + +def lcm(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + return jnp.lcm(x1, x2) + + +def less(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + return jnp.less(x1, x2) + + +def less_equal(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + return jnp.less_equal(x1, x2) + + +def linspace( + start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0 +): + return jnp.linspace( + start, + stop, + num=num, + endpoint=endpoint, + retstep=retstep, + dtype=dtype, + axis=axis, + ) + + +@sparse.densifying_unary +def log(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + x = cast(x, config.floatx()) + return jnp.log(x) + + +@sparse.densifying_unary +def log10(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + x = cast(x, config.floatx()) + return jnp.log10(x) + + +@sparse.elementwise_unary(linear=False) +def log1p(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + x = cast(x, config.floatx()) + return jnp.log1p(x) + + +@sparse.densifying_unary +def log2(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + x = cast(x, config.floatx()) + return jnp.log2(x) + + +def logaddexp(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type(x1.dtype, x2.dtype, float) + x1 = cast(x1, dtype) + x2 = cast(x2, dtype) + return jnp.logaddexp(x1, x2) + + +def logaddexp2(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type(x1.dtype, x2.dtype, float) + x1 = cast(x1, dtype) + x2 = cast(x2, dtype) + return jnp.logaddexp2(x1, x2) + + +def logical_and(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + return jnp.logical_and(x1, x2) + + +def logical_not(x): + x = convert_to_tensor(x) + return jnp.logical_not(x) + + +def logical_or(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + return jnp.logical_or(x1, x2) + + +def logspace(start, stop, num=50, endpoint=True, base=10, dtype=None, axis=0): + return jnp.logspace( + start, + stop, + num=num, + endpoint=endpoint, + base=base, + dtype=dtype, + axis=axis, + ) + + +@sparse.elementwise_binary_union(linear=False, use_sparsify=False) +def maximum(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + return jnp.maximum(x1, x2) + + +def median(x, axis=None, keepdims=False): + # axis of jnp.median must be hashable + if isinstance(axis, list): + axis = tuple(axis) + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + x = cast(x, config.floatx()) + + result = jnp.median(x, axis=axis, keepdims=keepdims) + + # TODO: with jax < 0.4.26 jnp.median failed to keepdims when axis is None + if keepdims is True and axis is None: + while result.ndim < x.ndim: + result = jnp.expand_dims(result, axis=-1) + return result + + +def meshgrid(*x, indexing="xy"): + return jnp.meshgrid(*x, indexing=indexing) + + +def min(x, axis=None, keepdims=False, initial=None): + x = convert_to_tensor(x) + return jnp.min(x, axis=axis, keepdims=keepdims, initial=initial) + + +@sparse.elementwise_binary_union(linear=False, use_sparsify=False) +def minimum(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + return jnp.minimum(x1, x2) + + +def mod(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + return jnp.mod(x1, x2) + + +def moveaxis(x, source, destination): + return jnp.moveaxis(x, source=source, destination=destination) + + +def nan_to_num(x, nan=0.0, posinf=None, neginf=None): + x = convert_to_tensor(x) + return jnp.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf) + + +def ndim(x): + return jnp.ndim(x) + + +def nonzero(x): + return jnp.nonzero(x) + + +def not_equal(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + return jnp.not_equal(x1, x2) + + +def ones_like(x, dtype=None): + return jnp.ones_like(x, dtype=dtype) + + +def zeros_like(x, dtype=None): + return jnp.zeros_like(x, dtype=dtype) + + +def outer(x1, x2): + return jnp.outer(x1, x2) + + +def pad(x, pad_width, mode="constant", constant_values=None): + x = convert_to_tensor(x) + kwargs = {} + if constant_values is not None: + if mode != "constant": + raise ValueError( + "Argument `constant_values` can only be " + "provided when `mode == 'constant'`. " + f"Received: mode={mode}" + ) + kwargs["constant_values"] = constant_values + return jnp.pad(x, pad_width, mode=mode, **kwargs) + + +def prod(x, axis=None, keepdims=False, dtype=None): + x = convert_to_tensor(x) + return jnp.prod(x, axis=axis, keepdims=keepdims, dtype=dtype) + + +def quantile(x, q, axis=None, method="linear", keepdims=False): + x = convert_to_tensor(x) + q = convert_to_tensor(q) + if standardize_dtype(x.dtype) == "int64": + x = cast(x, config.floatx()) + + result = jnp.quantile(x, q, axis=axis, method=method, keepdims=keepdims) + + # TODO: with jax < 0.4.26 jnp.quantile failed to keepdims when axis is None + if keepdims is True and axis is None: + result_ndim = x.ndim + (1 if len(q.shape) > 0 else 0) + while result.ndim < result_ndim: + result = jnp.expand_dims(result, axis=-1) + return result + + +def ravel(x): + x = convert_to_tensor(x) + return jnp.ravel(x) + + +def unravel_index(indices, shape): + indices = convert_to_tensor(indices) + return jnp.unravel_index(indices, shape) + + +@sparse.elementwise_unary(linear=True) +def real(x): + x = convert_to_tensor(x) + return jnp.real(x) + + +@sparse.densifying_unary +def reciprocal(x): + x = convert_to_tensor(x) + return jnp.reciprocal(x) + + +def repeat(x, repeats, axis=None): + x = convert_to_tensor(x) + return jnp.repeat(x, repeats, axis=axis) + + +def reshape(x, newshape): + if isinstance(x, jax_sparse.BCOO): + from keras.src.ops import operation_utils + + # Resolve the -1 in `new_shape` if applicable and possible + output_shape = operation_utils.compute_reshape_output_shape( + x.shape, newshape, "new_shape" + ) + if None not in output_shape: + newshape = output_shape + return jax_sparse.bcoo_reshape(x, new_sizes=newshape) + x = convert_to_tensor(x) + return jnp.reshape(x, newshape) + + +def roll(x, shift, axis=None): + return jnp.roll(x, shift, axis=axis) + + +def searchsorted(sorted_sequence, values, side="left"): + if ndim(sorted_sequence) != 1: + raise ValueError( + "`searchsorted` only supports 1-D sorted sequences. " + "You can use `keras.ops.vectorized_map` " + "to extend it to N-D sequences. Received: " + f"sorted_sequence.shape={sorted_sequence.shape}" + ) + return jnp.searchsorted(sorted_sequence, values, side=side) + + +@sparse.elementwise_unary(linear=False) +def sign(x): + x = convert_to_tensor(x) + return jnp.sign(x) + + +@sparse.elementwise_unary(linear=False) +def signbit(x): + x = convert_to_tensor(x) + return jnp.signbit(x) + + +@sparse.elementwise_unary(linear=False) +def sin(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = cast(x, dtype) + return jnp.sin(x) + + +@sparse.elementwise_unary(linear=False) +def sinh(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = cast(x, dtype) + return jnp.sinh(x) + + +def size(x): + return jnp.size(x) + + +def sort(x, axis=-1): + x = convert_to_tensor(x) + return jnp.sort(x, axis=axis) + + +def split(x, indices_or_sections, axis=0): + x = convert_to_tensor(x) + return jnp.split(x, indices_or_sections, axis=axis) + + +def stack(x, axis=0): + x = [convert_to_tensor(t) for t in x] + return jnp.stack(x, axis=axis) + + +def std(x, axis=None, keepdims=False): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + x = cast(x, config.floatx()) + return jnp.std(x, axis=axis, keepdims=keepdims) + + +def swapaxes(x, axis1, axis2): + x = convert_to_tensor(x) + return jnp.swapaxes(x, axis1=axis1, axis2=axis2) + + +def take(x, indices, axis=None): + x = convert_to_tensor(x) + indices = convert_to_tensor(indices, sparse=False) + return jnp.take(x, indices, axis=axis) + + +def take_along_axis(x, indices, axis=None): + return jnp.take_along_axis(x, indices, axis=axis) + + +@sparse.elementwise_unary(linear=False) +def tan(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = cast(x, dtype) + return jnp.tan(x) + + +@sparse.elementwise_unary(linear=False) +def tanh(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = cast(x, dtype) + return jnp.tanh(x) + + +def tensordot(x1, x2, axes=2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + return jnp.tensordot(x1, x2, axes=axes) + + +@sparse.elementwise_unary(linear=False) +def round(x, decimals=0): + x = convert_to_tensor(x) + + # jnp.round doesn't support decimals < 0 for integers + x_dtype = standardize_dtype(x.dtype) + if "int" in x_dtype and decimals < 0: + factor = cast(math.pow(10, decimals), config.floatx()) + x = cast(x, config.floatx()) + x = jnp.multiply(x, factor) + x = jnp.round(x) + x = jnp.divide(x, factor) + return cast(x, x_dtype) + else: + return jnp.round(x, decimals=decimals) + + +def tile(x, repeats): + return jnp.tile(x, repeats) + + +def trace(x, offset=0, axis1=0, axis2=1): + x = convert_to_tensor(x) + dtype = None + # TODO: Remove the condition of uint8 and uint16 once we have jax>=0.4.27 + # for both CPU & GPU environments. + # uint8 and uint16 will be casted to uint32 when jax>=0.4.27 but to int32 + # otherwise. + if standardize_dtype(x.dtype) in ("bool", "uint8", "uint16"): + dtype = "int32" + return jnp.trace(x, offset=offset, axis1=axis1, axis2=axis2, dtype=dtype) + + +def tri(N, M=None, k=0, dtype=None): + dtype = dtype or config.floatx() + return jnp.tri(N, M=M, k=k, dtype=dtype) + + +def tril(x, k=0): + x = convert_to_tensor(x) + return jnp.tril(x, k=k) + + +def triu(x, k=0): + x = convert_to_tensor(x) + return jnp.triu(x, k=k) + + +def trunc(x): + x = convert_to_tensor(x) + dtype = standardize_dtype(x.dtype) + if "int" in dtype or "bool" == dtype: + return x + return jnp.trunc(x) + + +def vdot(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + return jnp.vdot(x1, x2) + + +def inner(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + return jnp.inner(x1, x2) + + +def vstack(xs): + return jnp.vstack(xs) + + +def vectorize(pyfunc, *, excluded=None, signature=None): + if excluded is None: + excluded = set() + return jnp.vectorize(pyfunc, excluded=excluded, signature=signature) + + +def where(condition, x1=None, x2=None): + return jnp.where(condition, x1, x2) + + +@sparse.elementwise_division +def divide(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + return jnp.divide(x1, x2) + + +def divide_no_nan(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + safe_x2 = jnp.where(x2 == 0, 1, x2) + return jnp.where(x2 == 0, 0, jnp.divide(x1, safe_x2)) + + +def true_divide(x1, x2): + return divide(x1, x2) + + +def power(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + return jnp.power(x1, x2) + + +@sparse.elementwise_unary(linear=True) +def negative(x): + x = convert_to_tensor(x) + return jnp.negative(x) + + +@sparse.elementwise_unary(linear=False) +def square(x): + x = convert_to_tensor(x) + return jnp.square(x) + + +@sparse.elementwise_unary(linear=False) +def sqrt(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + x = cast(x, config.floatx()) + return jnp.sqrt(x) + + +def squeeze(x, axis=None): + if isinstance(x, jax_sparse.BCOO): + if axis is None: + axis = tuple(i for i, d in enumerate(x.shape) if d == 1) + axis = to_tuple_or_list(axis) + return jax_sparse.bcoo_squeeze(x, dimensions=axis) + x = convert_to_tensor(x) + return jnp.squeeze(x, axis=axis) + + +def transpose(x, axes=None): + x = convert_to_tensor(x) + if isinstance(x, jax_sparse.BCOO): + num_dims = len(x.shape) + if axes is None: + permutation = tuple(range(num_dims)[::-1]) + else: + permutation = [] + for a in axes: + a = canonicalize_axis(a, num_dims) + permutation.append(a) + return jax_sparse.bcoo_transpose(x, permutation=permutation) + return jnp.transpose(x, axes=axes) + + +def var(x, axis=None, keepdims=False): + x = convert_to_tensor(x) + # `jnp.var` does not handle low precision (e.g., float16) overflow + # correctly, so we compute with float32 and cast back to the original type. + compute_dtype = dtypes.result_type(x.dtype, "float32") + result_dtype = dtypes.result_type(x.dtype, float) + return cast( + jnp.var(x, axis=axis, keepdims=keepdims, dtype=compute_dtype), + result_dtype, + ) + + +def sum(x, axis=None, keepdims=False): + x = convert_to_tensor(x) + if isinstance(x, jax_sparse.BCOO): + if axis is None: + axis = tuple(range(len(x.shape))) + ( + canonical_axis, + keep_dims_shape, + broadcast_dimensions, + ) = sparse.axis_shape_dims_for_broadcast_in_dim( + axis, x.shape, insert_dims=False + ) + output = jax_sparse.bcoo_reduce_sum(x, axes=canonical_axis) + if keepdims: + # `bcoo_reduce_sum` does not support keepdims, neither does + # sparsify(jnp.sum), so we recreate the empty dimensions. + output = jax_sparse.bcoo_broadcast_in_dim( + output, + shape=keep_dims_shape, + broadcast_dimensions=broadcast_dimensions, + ) + return output + return jnp.sum(x, axis=axis, keepdims=keepdims) + + +def eye(N, M=None, k=0, dtype=None): + dtype = dtype or config.floatx() + return jnp.eye(N, M=M, k=k, dtype=dtype) + + +def floor_divide(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + return jnp.floor_divide(x1, x2) + + +def logical_xor(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + return jnp.logical_xor(x1, x2) + + +def corrcoef(x): + x = convert_to_tensor(x) + return jnp.corrcoef(x) + + +def correlate(x1, x2, mode="valid"): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + return jnp.correlate(x1, x2, mode) + + +def select(condlist, choicelist, default=0): + return jnp.select(condlist, choicelist, default=default) + + +def slogdet(x): + x = convert_to_tensor(x) + return tuple(jnp.linalg.slogdet(x)) + + +def argpartition(x, kth, axis=-1): + return jnp.argpartition(x, kth, axis) + + +def histogram(x, bins=10, range=None): + return jnp.histogram(x, bins=bins, range=range) diff --git a/keras/src/backend/jax/optimizer.py b/keras/src/backend/jax/optimizer.py new file mode 100644 index 000000000000..5cd6a40f65fb --- /dev/null +++ b/keras/src/backend/jax/optimizer.py @@ -0,0 +1,111 @@ +"""A class for JAX specific optimizer logic. + +Its purpose is to route around statelessness +requirements in cond ops used for EMA handling +and gradient accumulation handling. We do this +by skipping conditionals entirely. +""" + +import jax +from jax import numpy as jnp + +from keras.src.optimizers import base_optimizer + + +class JaxOptimizer(base_optimizer.BaseOptimizer): + def _backend_apply_gradients(self, grads, trainable_variables): + if self.gradient_accumulation_steps: + is_update_step = ( + self._iterations + 1 + ) % self.gradient_accumulation_steps == 0 + steps = self.gradient_accumulation_steps + + current_trainable_vars_value = [ + v.value for v in trainable_variables + ] + current_optimizer_vars_value = [v.value for v in self.variables] + + # `trainable_variables` might have been filtered in previous + # processing steps, so we need to ensure the correct mapping between + # `self._accumulated_gradients` and `trainable_variables` + acc_grads = [ + self._accumulated_gradients[self._get_variable_index(v)] + for v in trainable_variables + ] + + new_g_accs = jax.lax.cond( + is_update_step, + lambda: [jnp.zeros(g.shape, dtype=g.dtype) for g in acc_grads], + lambda: [g + acc_g.value for g, acc_g in zip(grads, acc_grads)], + ) + + grads = jax.lax.cond( + is_update_step, + lambda: [ + (g + acc_g.value) / steps + for g, acc_g in zip(grads, acc_grads) + ], + lambda: list(grads), + ) + + # Apply clipping and weight decay. + grads = self._clip_gradients(grads) + self._apply_weight_decay(trainable_variables) + + self._backend_update_step( + grads, trainable_variables, self.learning_rate + ) + new_trainable_vars = jax.lax.cond( + is_update_step, + lambda: [v.value for v in trainable_variables], + lambda: current_trainable_vars_value, + ) + new_opt_vars = jax.lax.cond( + is_update_step, + lambda: [v.value for v in self.variables], + lambda: current_optimizer_vars_value, + ) + + for value, v in zip(new_trainable_vars, trainable_variables): + v.assign(value) + + for value, v in zip(new_opt_vars, self.variables): + v.assign(value) + + for n_g_acc, g_acc in zip(new_g_accs, acc_grads): + g_acc.assign(n_g_acc) + + else: + # Apply clipping and weight decay. + grads = self._clip_gradients(grads) + self._apply_weight_decay(trainable_variables) + + self._backend_update_step( + grads, trainable_variables, self.learning_rate + ) + + if self.use_ema: + self._update_model_variables_moving_average( + self._trainable_variables + ) + if self.ema_overwrite_frequency is not None: + should_overwrite_model_vars = ( + self.iterations + 1 + ) % self.ema_overwrite_frequency == 0 + should_overwrite_model_vars_int = ( + should_overwrite_model_vars.astype("int32") + ) + should_not_overwrite_model_vars_int = jnp.logical_not( + should_overwrite_model_vars + ).astype("int32") + current_trainable_vars_value = [ + v.value for v in self._trainable_variables + ] + for var, average_var in zip( + self._trainable_variables, + self._model_variables_moving_average, + ): + var.assign( + average_var * should_overwrite_model_vars_int + + var.value * should_not_overwrite_model_vars_int + ) diff --git a/keras/src/backend/jax/random.py b/keras/src/backend/jax/random.py new file mode 100644 index 000000000000..79901696339f --- /dev/null +++ b/keras/src/backend/jax/random.py @@ -0,0 +1,116 @@ +import jax + +from keras.src.backend.config import floatx +from keras.src.random.seed_generator import SeedGenerator +from keras.src.random.seed_generator import draw_seed +from keras.src.random.seed_generator import make_default_seed + + +def jax_draw_seed(seed): + if isinstance(seed, jax.Array): + return seed + else: + return draw_seed(seed) + + +def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): + dtype = dtype or floatx() + seed = jax_draw_seed(seed) + sample = jax.random.normal(seed, shape=shape, dtype=dtype) + return sample * stddev + mean + + +def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None): + dtype = dtype or floatx() + seed = jax_draw_seed(seed) + return jax.random.uniform( + seed, shape=shape, dtype=dtype, minval=minval, maxval=maxval + ) + + +def categorical(logits, num_samples, dtype="int32", seed=None): + seed = jax_draw_seed(seed) + output_shape = list(logits.shape) + output_shape[1] = num_samples + output_shape = tuple(output_shape) + output = jax.random.categorical( + seed, logits[..., None], shape=output_shape, axis=1 + ) + return output.astype(dtype) + + +def randint(shape, minval, maxval, dtype="int32", seed=None): + seed = jax_draw_seed(seed) + return jax.random.randint( + seed, shape=shape, dtype=dtype, minval=minval, maxval=maxval + ) + + +def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): + dtype = dtype or floatx() + seed = jax_draw_seed(seed) + sample = jax.random.truncated_normal( + seed, shape=shape, lower=-2.0, upper=2.0, dtype=dtype + ) + return sample * stddev + mean + + +def _get_concrete_noise_shape(inputs, noise_shape): + if noise_shape is None: + return inputs.shape + + concrete_inputs_shape = inputs.shape + concrete_noise_shape = [] + for i, value in enumerate(noise_shape): + concrete_noise_shape.append( + concrete_inputs_shape[i] if value is None else value + ) + return concrete_noise_shape + + +def dropout(inputs, rate, noise_shape=None, seed=None): + seed = jax_draw_seed(seed) + keep_prob = 1.0 - rate + # The `noise_shape` may contain `None` so we need to convert it + # into a concrete shape before passing it on to jax. + noise_shape = _get_concrete_noise_shape(inputs, noise_shape) + mask = jax.random.bernoulli(seed, p=keep_prob, shape=noise_shape) + mask = jax.numpy.broadcast_to(mask, inputs.shape) + return jax.lax.select( + mask, inputs / keep_prob, jax.numpy.zeros_like(inputs) + ) + + +def shuffle(x, axis=0, seed=None): + seed = jax_draw_seed(seed) + return jax.random.permutation(seed, x, axis, independent=True) + + +def gamma(shape, alpha, dtype=None, seed=None): + seed = jax_draw_seed(seed) + dtype = dtype or floatx() + return jax.random.gamma(seed, alpha, shape=shape, dtype=dtype) + + +def binomial(shape, counts, probabilities, dtype=None, seed=None): + dtype = dtype or floatx() + seed = jax_draw_seed(seed) + # jax doesn't accept python lists as arguments + counts = jax.numpy.array(counts) + probabilities = jax.numpy.array(probabilities) + sample = jax.random.binomial( + key=seed, n=counts, p=probabilities, shape=shape, dtype=dtype + ) + return sample + + +def beta(shape, alpha, beta, dtype=None, seed=None): + dtype = dtype or floatx() + seed = jax_draw_seed(seed) + # jax doesn't accept python lists as arguments + alpha = jax.numpy.array(alpha) + beta = jax.numpy.array(beta) + sample = jax.random.beta( + key=seed, a=alpha, b=beta, shape=shape, dtype=dtype + ) + return sample diff --git a/keras/src/backend/jax/rnn.py b/keras/src/backend/jax/rnn.py new file mode 100644 index 000000000000..ec7e5146acf1 --- /dev/null +++ b/keras/src/backend/jax/rnn.py @@ -0,0 +1,230 @@ +import contextlib + +from jax import lax +from jax import numpy as jnp + +from keras.src import tree +from keras.src.backend.common import stateless_scope + + +def rnn( + step_function, + inputs, + initial_states, + go_backwards=False, + mask=None, + constants=None, + unroll=False, + input_length=None, + time_major=False, + zero_output_for_mask=False, + return_all_outputs=True, +): + def swap_batch_timestep(input_t): + # Swap the batch and timestep dim for the incoming tensor. + axes = list(range(len(input_t.shape))) + axes[0], axes[1] = 1, 0 + return jnp.transpose(input_t, axes) + + if not time_major: + inputs = tree.map_structure(swap_batch_timestep, inputs) + + flattened_inputs = tree.flatten(inputs) + time_steps = flattened_inputs[0].shape[0] + + if mask is not None: + if mask.dtype != "bool": + mask = mask.astype("bool") + if len(mask.shape) == 2: + mask = jnp.expand_dims(mask, axis=-1) + if not time_major: + mask = swap_batch_timestep(mask) + + if constants is None: + constants = [] + + def _expand_mask(mask_t, input_t, fixed_dim=1): + if tree.is_nested(mask_t): + raise ValueError( + f"mask_t is expected to be tensor, but got {mask_t}" + ) + if tree.is_nested(input_t): + raise ValueError( + f"input_t is expected to be tensor, but got {input_t}" + ) + rank_diff = len(input_t.shape) - len(mask_t.shape) + for _ in range(rank_diff): + mask_t = jnp.expand_dims(mask_t, -1) + multiples = [1] * fixed_dim + list(input_t.shape[fixed_dim:]) + return jnp.tile(mask_t, multiples) + + if unroll: + if not time_steps: + raise ValueError("Unrolling requires a fixed number of timesteps.") + states = tuple(initial_states) + successive_states = [] + successive_outputs = [] + + # Process the input tensors. The input tensor need to be split on the + # time_step dim, and reverse if go_backwards is True. In the case of + # nested input, the input is flattened and then transformed + # individually. The result of this will be a tuple of lists, each of + # the item in tuple is list of the tensor with shape (batch, feature) + def _process_single_input_t(input_t): + input_t = unstack(input_t) # unstack for time_step dim + if go_backwards: + input_t.reverse() + return input_t + + if tree.is_nested(inputs): + processed_input = tree.map_structure( + _process_single_input_t, inputs + ) + else: + processed_input = (_process_single_input_t(inputs),) + + def _get_input_tensor(time): + inp = [t_[time] for t_ in processed_input] + return tree.pack_sequence_as(inputs, inp) + + if mask is not None: + mask_list = unstack(mask) + if go_backwards: + mask_list.reverse() + + for i in range(time_steps): + inp = _get_input_tensor(i) + mask_t = mask_list[i] + output, new_states = step_function( + inp, tuple(states) + tuple(constants) + ) + tiled_mask_t = _expand_mask(mask_t, output) + + if not successive_outputs: + prev_output = jnp.zeros_like(output) + else: + prev_output = successive_outputs[-1] + + output = jnp.where(tiled_mask_t, output, prev_output) + + flat_states = tree.flatten(states) + flat_new_states = tree.flatten(new_states) + tiled_mask_t = tuple( + _expand_mask(mask_t, s) for s in flat_states + ) + flat_final_states = tuple( + jnp.where(m, s, ps) + for m, s, ps in zip( + tiled_mask_t, flat_new_states, flat_states + ) + ) + states = tree.pack_sequence_as(states, flat_final_states) + + if return_all_outputs: + successive_outputs.append(output) + successive_states.append(states) + else: + successive_outputs = [output] + successive_states = [states] + last_output = successive_outputs[-1] + new_states = successive_states[-1] + outputs = jnp.stack(successive_outputs) + + else: # mask is None + for i in range(time_steps): + inp = _get_input_tensor(i) + output, states = step_function( + inp, tuple(states) + tuple(constants) + ) + if return_all_outputs: + successive_outputs.append(output) + successive_states.append(states) + else: + successive_outputs = [output] + successive_states = [states] + last_output = successive_outputs[-1] + new_states = successive_states[-1] + outputs = jnp.stack(successive_outputs) + + else: # Unroll == False + if mask is not None: + + def _step(states, current_input): + current_input, current_mask = current_input + is_masked = jnp.all( + jnp.logical_not(current_mask), axis=-1, keepdims=True + ) + + output_t, new_states = step_function(current_input, states) + + if zero_output_for_mask: + masked_outs = jnp.where( + is_masked, jnp.zeros_like(output_t), output_t + ) + else: + # Assume the first state is the previous output. + output_tm1 = states[0] + if tree.is_nested(output_tm1): + # Stacked RNN case: assume first state of last cell. + output_tm1 = states[-1][0] + masked_outs = jnp.where(is_masked, output_tm1, output_t) + + new_states = tree.map_structure( + lambda s, ns: jnp.where(is_masked, s, ns), + states, + new_states, + ) + return (new_states, masked_outs) + + scan_xs = (inputs, mask) + + else: + + def _step(states, current_input): + output_t, new_states = step_function(current_input, states) + return new_states, output_t + + scan_xs = inputs + + if stateless_scope.in_stateless_scope(): + # Reuse the existing parent stateless scope. + scope = contextlib.nullcontext() + else: + scope = stateless_scope.StatelessScope() + with scope: + # We must use a stateless scope because `scan` will involve + # JAX tracing -- any variable update at this stage would + # be a leak. + new_states, outputs = lax.scan( + f=_step, + init=initial_states, + xs=scan_xs, + reverse=go_backwards, + ) + if go_backwards: + outputs = jnp.flip(outputs, axis=0) + last_output = outputs[-1] + + if not time_major: + outputs = tree.map_structure(swap_batch_timestep, outputs) + + return last_output, outputs, new_states + + +def cudnn_ok(*args, **kwargs): + return False + + +def lstm(*args, **kwargs): + raise NotImplementedError + + +def gru(*args, **kwargs): + raise NotImplementedError + + +def unstack(x, axis=0): + return [ + lax.index_in_dim(x, i, axis, keepdims=False) + for i in range(x.shape[axis]) + ] diff --git a/keras/src/backend/jax/sparse.py b/keras/src/backend/jax/sparse.py new file mode 100644 index 000000000000..f2d7f19d7d16 --- /dev/null +++ b/keras/src/backend/jax/sparse.py @@ -0,0 +1,328 @@ +import functools + +import jax.experimental.sparse as jax_sparse +import jax.numpy as jnp + +from keras.src.utils import jax_utils + + +def axis_shape_dims_for_broadcast_in_dim(axis, input_shape, insert_dims): + """Turn the `axis` argument to the arguments needed by `broadcast_in_dim`. + + Args: + axis: single int or a tuple of ints for the axis argument. The list of + dimensions to reduce or insert. + input_shape: the shape of the input as a tuple ints. + insert_dims: `False` turns dimensions in `axis` to 1s (use case: + reduction along `axis` with `keep_dims=True`). `True`, inserts 1s + according to `axis` (use case: `expand_dims`). + Returns: + A tuple of three lists + - The canonical value for `axis`: always a list, negative values have + been resolved and values are sorted in ascending order. + - The output shape: `input_shape` with 1s at the indices in `axis`, for + use as the `shape` argument of `broadcast_in_dim`. + - The broadcast dimensions: list of dimensions not in `axis`, for use as + the `broadcast_dimensions` argument of `broadcast_in_dim`. + """ + if axis is None: + raise ValueError("Received `None` value for `axis`") + if isinstance(axis, int): + axis = (axis,) + # Check uniqueness. + if len(set(axis)) != len(axis): + raise ValueError(f"Repeated axis in `axis`: {axis}") + result_dims = len(input_shape) + if insert_dims: + result_dims += len(axis) + + # Resolve negative values. + canonical_axis = [] + for a in axis: + if not -result_dims <= a < result_dims: + raise ValueError( + f"In `axis`, axis {a} is out of bounds for array " + f"of dimension {result_dims}" + ) + if a < 0: + a = a + result_dims + canonical_axis.append(a) + + # Check uniqueness again after resolving negative values. + if len(set(canonical_axis)) != len(canonical_axis): + raise ValueError(f"Repeated axis in `axis`: {canonical_axis}") + canonical_axis = sorted(canonical_axis) + + # Compute output shape. + output_shape = list(input_shape) + for i in canonical_axis: + if insert_dims: + output_shape.insert(i, 1) + else: + output_shape[i] = 1 + broadcast_dims = [i for i in range(result_dims) if i not in canonical_axis] + return canonical_axis, output_shape, broadcast_dims + + +def bcoo_add_indices(x1, x2, sum_duplicates): + """Add the indices of `x2` to `x1` with zero values. + + Args: + x1: `BCOO` tensor to add indices to. + x2: `BCOO` tensor to take the indices to add to x1. + sum_duplicates: if `True` calls `bcoo_sum_duplicates` on the output. + Returns: + a `BCOO` tensor equal to `x1` but with extra zeros at indices in `x2` + that were missing in `x1`. + """ + x2_zeros = jnp.zeros(x2.data.shape, x1.data.dtype) + concat_axis = len(x1.indices.shape) - 2 + output_indices = jnp.concatenate([x1.indices, x2.indices], axis=concat_axis) + output_data = jnp.concatenate([x1.data, x2_zeros], axis=concat_axis) + output = jax_sparse.BCOO((output_data, output_indices), shape=x1.shape) + if sum_duplicates: + output = jax_sparse.bcoo_sum_duplicates(output) + return output + + +def densifying_unary(func): + """Decorator to add support for `JAXSparse` tensors (including `BCOO`) to a + non-zero-preserving element-wise unary operator. + + There are requirements on the operator for this decorator to work correctly: + + - The operator must be element-wise + - The operator must be unary (one input tensor and one output tensor) + - The operator must return a tensor of the same shape. + + Additional arguments to the function (besides the input tensor) are + supported. The returned result is a dense tensor. + + Args: + func: The unary operator to wrap. + Returns: + Wrapped function that supports `JAXSparse` tensors. + """ + + @functools.wraps(func) + def sparse_wrapper(x, *args, **kwargs): + if isinstance(x, jax_sparse.JAXSparse): + x = x.todense() + return func(x, *args, **kwargs) + + return sparse_wrapper + + +def elementwise_unary(linear): + """Decorator to add support for `BCOO` sparse tensors to a zero-preserving + element-wise unary operator. + + There are requirements on the operator for this decorator to work correctly: + + - The operator must be element-wise + - The operator must be unary (one input tensor and one output tensor) + - The operator must return a tensor of the same shape, and if it is a + `BCOO` tensor, the indices of the result must be the same. Therefore: + - Reduction operations are not supported (e.g. `mean`). + - Operations for which the result may be dense (e.g. `reciprocal`), or + the sparse indices depend on the inputs are not supported (e.g. + `clip`). This implies that `func(0)` must be 0. + + Additional arguments to the function (besides the input tensor) are + supported as long as they cannot change the indices of the result. For + instance,`round` is supported, but `clip` is not supported as + `clip(x, 1.0, 2.0)` would always return a dense tensor. + + Note that if an input sparse tensor contains zero values, the indices and + the zero values are preserved. + + Args: + linear: if `True`, means that the operation is such that + `op(a + b) == op(a) + op(b)`. + Returns: + Wrapped function that supports `BCOO` sparse tensors. + """ + + def wrap_elementwise_unary(func): + @functools.wraps(func) + def sparse_wrapper(x, *args, **kwargs): + if isinstance(x, jax_sparse.BCOO): + if not linear and not x.unique_indices: + x = jax_sparse.bcoo_sum_duplicates(x) + return jax_sparse.BCOO( + (func(x.data, *args, **kwargs), x.indices), shape=x.shape + ) + else: + return func(x, *args, **kwargs) + + return sparse_wrapper + + return wrap_elementwise_unary + + +def elementwise_binary_union(linear, use_sparsify): + """Decorator to add support for `JAXSparse` tensors (including `BCOO`) to an + element-wise binary operator such that the indices present in the result are + are the union of the indices in the two operand. + + The primary use case for this is the `add` and `subtract` operators. + + There are requirements on the operator for this decorator to work correctly: + + - The operator must be element-wise. + - The operator must be binary (two input tensors and one output tensor). + - Both inputs must be of the same shape or one input must be a scalar. + - The output must be of the same shape as the (non scalar) inputs. + - The indices of the output must be the union of the indices of the inputs. + This implies that func(0, 0) must be 0. As a result, if one operand is + dense or a scalar, then the result will be dense. + + Additional arguments to the function (besides the input tensors) are not + supported. + + Note that if the result of the operation is zero at some indices, including + because the operands were zero at these indices, the zeros and indices are + preserved. + + The `BCOO` format is the only supported one in all cases. Other formats are + not supported when `use_sparsify` is `False`. + + Args: + use_sparsify: indicates that the JAX `sparsify` transform supports this + operation. + linear: if `True`, mean that the operation is such that + `op(a + b, c) == op(a, c) + op(b, c)` and + `op(a, c + d) == op(a, c) + op(a, d)`. + Returns: + Wrapped function that supports `JAXSparse`. + """ + + def wrap_elementwise_binary_union(func): + sparse_func = jax_sparse.sparsify(func) if use_sparsify else None + + @functools.wraps(func) + def sparse_wrapper(x1, x2): + if isinstance(x1, jax_sparse.JAXSparse): + if isinstance(x2, jax_sparse.JAXSparse): + # x1 and x2 are sparse. + # The way we use `sparsify` it cannot know that the indices + # are the same, so we optimize this case here. + if ( + x1.indices is x2.indices + and isinstance(x1, jax_sparse.BCOO) + and isinstance(x2, jax_sparse.BCOO) + ): + if not linear and not x1.unique_indices: + x1 = jax_sparse.bcoo_sum_duplicates(x1) + x2 = jax_sparse.bcoo_sum_duplicates(x2) + return jax_sparse.BCOO( + (func(x1.data, x2.data), x1.indices), + shape=x1.shape, + indices_sorted=x1.indices_sorted, + unique_indices=x1.unique_indices, + ) + elif use_sparsify: + return sparse_func(x1, x2) + elif isinstance(x1, jax_sparse.BCOO) and isinstance( + x2, jax_sparse.BCOO + ): + x1 = bcoo_add_indices(x1, x2, sum_duplicates=not linear) + x2 = bcoo_add_indices(x2, x1, sum_duplicates=not linear) + return jax_sparse.BCOO( + (func(x1.data, x2.data), x1.indices), + shape=x1.shape, + indices_sorted=True, + unique_indices=True, + ) + else: + ValueError( + "Unsupported sparse format: " + f"{x1.__class__} and {x2.__class__}" + ) + else: + # x1 is sparse, x2 is dense, densify x2. + x1 = x1.todense() + elif isinstance(x2, jax_sparse.JAXSparse): + # x1 is dense, x2 is sparse, densify x2. + x2 = x2.todense() + return func(x1, x2) + + return sparse_wrapper + + return wrap_elementwise_binary_union + + +def elementwise_division(func): + """Decorator to add support for `BCOO` sparse tensors to element-wise binary + division and related operators. + + This decorator is designed for operations related to the division of two + two operands (e.g. `divide`). It accepts `BCOO` tensors for both the + dividend and the divisor, but handles them differently based on whether they + are the dividend or the divisor. + + - If the divisor is sparse, it is densified and the result is dense because + the result contains Inf or Nan outside of the indices of the dividend. + - If the dividend is sparse and the divisor is dense, it finds occurrences + of zeros and NaNs in the divisor. The result may therefore have more + indices than there were in the dividend to return correct values where the + divisor was zero or NaN. + - If the dividend is sparse and the divisor is a scalar, it does the + division element-wise. Note that the result is incorrectly sparse if the + scalar divisor is zero. + + Args: + func: The function to wrap. + Returns: + Wrapped function that supports `BCOO` sparse tensors. + """ + sparse_func = jax_sparse.sparsify(func) + + @functools.wraps(func) + def sparse_wrapper(x1, x2): + if isinstance(x1, jax_sparse.JAXSparse): + if isinstance(x2, jax_sparse.JAXSparse): + # x1 is sparse and x2 is sparse. + # Divisor is sparse, meaning we're doing divisions by zero + # outside of x2.indices, so the result is dense. Densify both. + x1 = x1.todense() + x2 = x2.todense() + elif isinstance(x1, jax_sparse.BCOO): + if not hasattr(x2, "shape") or len(x2.shape) == 0: + # x1 is sparse BCOO, x2 is scalar, apply func element-wise. + return jax_sparse.BCOO( + (func(x1.data, x2), x1.indices), + shape=x1.shape, + indices_sorted=x1.indices_sorted, + unique_indices=x1.unique_indices, + ) + else: + # x1 is sparse BCOO, x2 is dense. + if not jax_utils.is_in_jax_tracing_scope(x2): + # Find zeros and nans in x2 and add indices to x1. + # 1. Create a dense mask for zeros and nans. + x2_zeros_and_nans = jnp.equal(x2, 0) + if not jnp.issubdtype(x2.dtype, jnp.integer): + x2_zeros_and_nans = jnp.logical_or( + x2_zeros_and_nans, jnp.isnan(x2) + ) + # 2. Make it a BCOO of True values. + x2_zeros_and_nans = jax_sparse.bcoo_fromdense( + x2_zeros_and_nans, + n_batch=x1.n_batch, + n_dense=x1.n_dense, + index_dtype=x1.indices.dtype, + ) + # 3. Add the indices to x1. + x1 = bcoo_add_indices( + x1, x2_zeros_and_nans, sum_duplicates=True + ) + return sparse_func(x1, x2) + else: + raise ValueError(f"Unsupported sparse format: {x1.__class__}") + elif isinstance(x2, jax_sparse.JAXSparse): + # x1 is dense, x2 is sparse, densify x2 + x2 = x2.todense() + return func(x1, x2) + + return sparse_wrapper diff --git a/keras/src/backend/jax/tensorboard.py b/keras/src/backend/jax/tensorboard.py new file mode 100644 index 000000000000..d8f105b3a9f2 --- /dev/null +++ b/keras/src/backend/jax/tensorboard.py @@ -0,0 +1,23 @@ +from keras.src.utils.module_utils import jax + + +def start_trace(logdir): + if logdir: + jax.profiler.start_trace(logdir) + + +def stop_trace(save): + if save: + jax.profiler.stop_trace() + + +def start_batch_trace(batch): + batch_trace_context = jax.profiler.TraceAnnotation( + f"Profiled batch {batch}" + ) + batch_trace_context.__enter__() + return batch_trace_context + + +def stop_batch_trace(batch_trace_context): + batch_trace_context.__exit__(None, None, None) diff --git a/keras/src/backend/jax/trainer.py b/keras/src/backend/jax/trainer.py new file mode 100644 index 000000000000..5f01505c2d47 --- /dev/null +++ b/keras/src/backend/jax/trainer.py @@ -0,0 +1,1038 @@ +import collections +import itertools +import warnings +from functools import partial + +import jax +import numpy as np + +from keras.src import backend +from keras.src import callbacks as callbacks_module +from keras.src import optimizers as optimizers_module +from keras.src import tree +from keras.src.backend import config +from keras.src.backend import distribution_lib as jax_distribution_lib +from keras.src.backend.config import is_nnx_enabled +from keras.src.distribution import distribution_lib +from keras.src.trainers import trainer as base_trainer +from keras.src.trainers.data_adapters import array_slicing +from keras.src.trainers.data_adapters import data_adapter_utils +from keras.src.trainers.epoch_iterator import EpochIterator +from keras.src.utils import traceback_utils + +if is_nnx_enabled(): + from flax import nnx + + jit = nnx.jit +else: + jit = jax.jit + + +class JAXTrainer(base_trainer.Trainer): + def __init__(self): + super().__init__() + self.train_function = None + self.test_function = None + self.predict_function = None + self._jax_state_synced = True + + def compute_loss_and_updates( + self, + trainable_variables, + non_trainable_variables, + metrics_variables, + x, + y, + sample_weight, + training=False, + optimizer_variables=None, + ): + """This method is stateless and is intended for use with jax.grad.""" + kwargs = {} + if self._call_has_training_arg: + kwargs["training"] = training + + # Run stateless forward pass + y_pred, non_trainable_variables, losses = self.stateless_call( + trainable_variables, + non_trainable_variables, + x, + return_losses=True, + **kwargs, + ) + if losses: + # Make forward pass losses available to compute_loss. + self._losses_override.clear() + self._losses_override = losses + + loss, variables = self.stateless_compute_loss( + trainable_variables, + non_trainable_variables, + metrics_variables, + x=x, + y=y, + y_pred=y_pred, + sample_weight=sample_weight, + training=training, + ) + if losses: + self._losses_override.clear() + (trainable_variables, non_trainable_variables, metrics_variables) = ( + variables + ) + + # Handle loss scaling + unscaled_loss = loss + if training and self.optimizer is not None: + # Scale loss with a StatelessScope, to use an update scale variable. + mapping = list(zip(self.optimizer.variables, optimizer_variables)) + with backend.StatelessScope(state_mapping=mapping): + loss = self.optimizer.scale_loss(loss) + return loss, ( + unscaled_loss, + y_pred, + non_trainable_variables, + metrics_variables, + ) + + def _update_metrics_variables( + self, metrics_variables, unscaled_loss, x, y, y_pred, sample_weight + ): + with backend.StatelessScope( + state_mapping=[ + (ref_v, v) + for ref_v, v in zip(self.metrics_variables, metrics_variables) + ] + ) as scope: + self._loss_tracker.update_state( + unscaled_loss, + sample_weight=next( + i for i in tree.flatten(x) if i is not None + ).shape[0], + ) + logs = self.compute_metrics(x, y, y_pred, sample_weight) + + new_metrics_variables = [] + for ref_v in self.metrics_variables: + new_v = scope.get_current_value(ref_v) + if new_v is None: + new_v = ref_v.value + new_metrics_variables.append(new_v) + return logs, new_metrics_variables + + def train_step(self, state, data): + ( + trainable_variables, + non_trainable_variables, + optimizer_variables, + metrics_variables, + ) = state + x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(data) + grad_fn = jax.value_and_grad( + self.compute_loss_and_updates, has_aux=True + ) + (loss, aux), grads = grad_fn( + trainable_variables, + non_trainable_variables, + metrics_variables, + x, + y, + sample_weight, + training=True, + optimizer_variables=optimizer_variables, + ) + (unscaled_loss, y_pred, non_trainable_variables, metrics_variables) = ( + aux + ) + + ( + trainable_variables, + optimizer_variables, + ) = self.optimizer.stateless_apply( + optimizer_variables, grads, trainable_variables + ) + + logs, metrics_variables = self._update_metrics_variables( + metrics_variables, unscaled_loss, x, y, y_pred, sample_weight + ) + + state = ( + trainable_variables, + non_trainable_variables, + optimizer_variables, + metrics_variables, + ) + return logs, state + + def test_step(self, state, data): + ( + trainable_variables, + non_trainable_variables, + metrics_variables, + ) = state + x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(data) + loss, aux = self.compute_loss_and_updates( + trainable_variables, + non_trainable_variables, + metrics_variables, + x, + y, + sample_weight, + training=False, + ) + (unscaled_loss, y_pred, non_trainable_variables, metrics_variables) = ( + aux + ) + + logs, metrics_variables = self._update_metrics_variables( + metrics_variables, unscaled_loss, x, y, y_pred, sample_weight + ) + + state = ( + trainable_variables, + non_trainable_variables, + metrics_variables, + ) + return logs, state + + def predict_step(self, state, data): + trainable_variables, non_trainable_variables = state + kwargs = {} + if self._call_has_training_arg: + kwargs["training"] = False + + x, _, _ = data_adapter_utils.unpack_x_y_sample_weight(data) + outputs, non_trainable_variables = self.stateless_call( + trainable_variables, non_trainable_variables, x, **kwargs + ) + return outputs, non_trainable_variables + + def _make_function(self, step_function, concatenate_outputs=False): + if self.steps_per_execution > 1: + if concatenate_outputs: + + def concatenate(outputs): + output = outputs[0] + for next_output in outputs[1:]: + output = tree.map_structure( + lambda t1, t2: jax.numpy.concatenate([t1, t2]), + output, + next_output, + ) + return output + + if not self.run_eagerly and self.jit_compile: + concatenate = jit(concatenate) + + def iterator_step(state, iterator): + data = next(iterator) + outputs, state = step_function(state, data) + outputs = [outputs] + try: + for _ in range(self.steps_per_execution - 1): + data = next(iterator) + _outputs, state = step_function(state, data) + outputs.append(_outputs) + except StopIteration: + pass + outputs = concatenate(outputs) + return outputs, state + + else: + + def iterator_step(state, iterator): + data = next(iterator) + outputs, state = step_function(state, data) + try: + for _ in range(self.steps_per_execution - 1): + data = next(iterator) + outputs, state = step_function(state, data) + except StopIteration: + pass + return outputs, state + + else: + + def iterator_step(state, iterator): + return step_function(state, next(iterator)) + + return iterator_step + + def make_train_function(self, force=False): + if self.train_function is not None and not force: + return + if not self.run_eagerly and self.jit_compile: + out_shardings = None + if distribution_lib.distribution() is not None: + state_shardings = self._get_state_sharding_spec() + out_shardings = (None, state_shardings) + train_step = jit( + self.train_step, + donate_argnums=0, + out_shardings=out_shardings, + ) + else: + train_step = self.train_step + + step_function = self._make_function(train_step) + + self.train_function = step_function + + def make_test_function(self, force=False): + if self.test_function is not None and not force: + return + if not self.run_eagerly and self.jit_compile: + out_shardings = None + if distribution_lib.distribution() is not None: + ( + trainable_shardings, + non_trainable_shardings, + _, # optimizer_shardings + metrics_shardings, + ) = self._get_state_sharding_spec() + state_shardings = ( + trainable_shardings, + non_trainable_shardings, + metrics_shardings, + ) + out_shardings = (None, state_shardings) + test_step = jit( + self.test_step, + donate_argnums=0, + out_shardings=out_shardings, + ) + else: + test_step = self.test_step + + step_function = self._make_function(test_step) + + self.test_function = step_function + + def make_predict_function(self, force=False): + if self.predict_function is not None and not force: + return self.predict_function + + def predict_step(state, data): + outputs, non_trainable_variables = self.predict_step(state, data) + return outputs, (state[0], non_trainable_variables) + + if not self.run_eagerly and self.jit_compile: + out_shardings = None + if distribution_lib.distribution() is not None: + ( + trainable_shardings, + non_trainable_shardings, + _, # optimizer_shardings + _, # metrics_shardings + ) = self._get_state_sharding_spec() + state_shardings = ( + trainable_shardings, + non_trainable_shardings, + ) + out_shardings = (None, state_shardings) + predict_step = jit( + predict_step, + donate_argnums=0, + out_shardings=out_shardings, + ) + + _step_function = self._make_function( + predict_step, concatenate_outputs=True + ) + + def step_function(state, iterator): + outputs, state = _step_function(state, iterator) + return outputs, state + + self.predict_function = step_function + + @traceback_utils.filter_traceback + def fit( + self, + x=None, + y=None, + batch_size=None, + epochs=1, + verbose="auto", + callbacks=None, + validation_split=0.0, + validation_data=None, + shuffle=True, + class_weight=None, + sample_weight=None, + initial_epoch=0, + steps_per_epoch=None, + validation_steps=None, + validation_batch_size=None, + validation_freq=1, + ): + self._assert_compile_called("fit") + # Possibly cap epochs for debugging runs. + max_epochs = config.max_epochs() + if max_epochs and max_epochs < epochs: + warnings.warn("Limiting epochs to %d" % max_epochs) + epochs = max_epochs + # TODO: respect compiled trainable state + self._eval_epoch_iterator = None + if validation_split and validation_data is None: + # Create the validation data using the training data. Only supported + # for TF/numpy/jax arrays. + ( + (x, y, sample_weight), + validation_data, + ) = array_slicing.train_validation_split( + (x, y, sample_weight), validation_split=validation_split + ) + + if validation_data is not None: + ( + val_x, + val_y, + val_sample_weight, + ) = data_adapter_utils.unpack_x_y_sample_weight(validation_data) + + # Create an iterator that yields batches for one epoch. + epoch_iterator = JAXEpochIterator( + x=x, + y=y, + sample_weight=sample_weight, + batch_size=batch_size, + steps_per_epoch=steps_per_epoch, + shuffle=shuffle, + class_weight=class_weight, + steps_per_execution=self.steps_per_execution, + ) + + self._symbolic_build(iterator=epoch_iterator) + epoch_iterator.reset() + + # Container that configures and calls callbacks. + if not isinstance(callbacks, callbacks_module.CallbackList): + callbacks = callbacks_module.CallbackList( + callbacks, + add_history=True, + add_progbar=verbose != 0, + verbose=verbose, + epochs=epochs, + steps=epoch_iterator.num_batches, + model=self, + ) + + self.make_train_function() + self.stop_training = False + training_logs = {} + training_finished = False + callbacks.on_train_begin() + initial_epoch = self._initial_epoch or initial_epoch + try: + for epoch in range(initial_epoch, epochs): + self.reset_metrics() + callbacks.on_epoch_begin(epoch) + + self._jax_state_synced = True + with epoch_iterator.catch_stop_iteration(): + for begin_step, end_step, iterator in epoch_iterator: + # Callbacks + callbacks.on_train_batch_begin(begin_step) + + # Train step + if self._jax_state_synced: + # The state may have been synced by a callback. + state = self._get_jax_state( + trainable_variables=True, + non_trainable_variables=True, + optimizer_variables=True, + metrics_variables=True, + purge_model_variables=True, + ) + self._jax_state_synced = False + + logs, state = self.train_function(state, iterator) + ( + trainable_variables, + non_trainable_variables, + optimizer_variables, + metrics_variables, + ) = state + + # Setting _jax_state enables callbacks to force a state + # sync if they need to. + self._jax_state = { + "trainable_variables": trainable_variables, + "non_trainable_variables": non_trainable_variables, + "optimizer_variables": optimizer_variables, + "metrics_variables": metrics_variables, + } + # Dispatch callbacks. This takes care of async dispatch. + callbacks.on_train_batch_end(end_step, logs) + + if self.stop_training: + # Stop training if a callback has set + # this flag in on_(train_)batch_end. + break + + # Reattach state to the model + # (if not already done by a callback). + # NOTE: doing this after each step would be a big performance + # bottleneck. + self.jax_state_sync() + + # Override with model metrics instead of last step logs if + # needed. + epoch_logs = dict(self._get_metrics_result_or_logs(logs)) + + # Run validation. + if validation_data is not None and self._should_eval( + epoch, validation_freq + ): + # Create JAXEpochIterator for evaluation and cache it. + if getattr(self, "_eval_epoch_iterator", None) is None: + self._eval_epoch_iterator = JAXEpochIterator( + x=val_x, + y=val_y, + sample_weight=val_sample_weight, + batch_size=validation_batch_size or batch_size, + steps_per_execution=self.steps_per_execution, + steps_per_epoch=validation_steps, + shuffle=False, + ) + val_logs = self.evaluate( + x=val_x, + y=val_y, + sample_weight=val_sample_weight, + batch_size=validation_batch_size or batch_size, + steps=validation_steps, + callbacks=callbacks, + return_dict=True, + _use_cached_eval_dataset=True, + ) + val_logs = { + f"val_{name}": val for name, val in val_logs.items() + } + epoch_logs.update(val_logs) + + callbacks.on_epoch_end(epoch, epoch_logs) + training_logs = epoch_logs + if self.stop_training: + break + training_finished = True + + finally: + self.jax_state_sync() + if ( + isinstance(self.optimizer, optimizers_module.Optimizer) + and epochs > 0 + ): + self.optimizer.finalize_variable_values(self.trainable_weights) + + # If _eval_epoch_iterator exists, delete it after all epochs + # are done. + if getattr(self, "_eval_epoch_iterator", None) is not None: + del self._eval_epoch_iterator + if training_finished: + callbacks.on_train_end(logs=training_logs) + self._jax_state = None + return self.history + + @traceback_utils.filter_traceback + def evaluate( + self, + x=None, + y=None, + batch_size=None, + verbose="auto", + sample_weight=None, + steps=None, + callbacks=None, + return_dict=False, + **kwargs, + ): + self._assert_compile_called("evaluate") + # TODO: respect compiled trainable state + use_cached_eval_dataset = kwargs.pop("_use_cached_eval_dataset", False) + if kwargs: + raise ValueError(f"Arguments not recognized: {kwargs}") + + if use_cached_eval_dataset: + epoch_iterator = self._eval_epoch_iterator + else: + # Create an iterator that yields batches of + # input/target data. + epoch_iterator = JAXEpochIterator( + x=x, + y=y, + sample_weight=sample_weight, + batch_size=batch_size, + steps_per_epoch=steps, + shuffle=False, + steps_per_execution=self.steps_per_execution, + ) + + self._symbolic_build(iterator=epoch_iterator) + epoch_iterator.reset() + + # Container that configures and calls callbacks. + if not isinstance(callbacks, callbacks_module.CallbackList): + callbacks = callbacks_module.CallbackList( + callbacks, + add_progbar=verbose != 0, + verbose=verbose, + epochs=1, + steps=epoch_iterator.num_batches, + model=self, + ) + + self.make_test_function() + self.stop_evaluating = False + callbacks.on_test_begin() + logs = {} + self.reset_metrics() + + self._jax_state_synced = True + with epoch_iterator.catch_stop_iteration(): + for begin_step, end_step, iterator in epoch_iterator: + callbacks.on_test_batch_begin(begin_step) + + if self._jax_state_synced: + # The state may have been synced by a callback. + state = self._get_jax_state( + trainable_variables=True, + non_trainable_variables=True, + metrics_variables=True, + purge_model_variables=True, + ) + self._jax_state_synced = False + + logs, state = self.test_function(state, iterator) + ( + trainable_variables, + non_trainable_variables, + metrics_variables, + ) = state + + # Setting _jax_state enables callbacks to force a state sync + # if they need to. + self._jax_state = { + # I wouldn't recommend modifying non-trainable model state + # during evaluate(), but it's allowed. + "trainable_variables": trainable_variables, + "non_trainable_variables": non_trainable_variables, + "metrics_variables": metrics_variables, + } + + # Dispatch callbacks. This takes care of async dispatch. + callbacks.on_test_batch_end(end_step, logs) + + if self.stop_evaluating: + break + + # Reattach state back to model (if not already done by a callback). + self.jax_state_sync() + + logs = self._get_metrics_result_or_logs(logs) + callbacks.on_test_end(logs) + self._jax_state = None + if return_dict: + return logs + return self._flatten_metrics_in_order(logs) + + @traceback_utils.filter_traceback + def predict( + self, x, batch_size=None, verbose="auto", steps=None, callbacks=None + ): + # Create an iterator that yields batches of input data. + epoch_iterator = JAXEpochIterator( + x=x, + batch_size=batch_size, + steps_per_epoch=steps, + shuffle=False, + steps_per_execution=self.steps_per_execution, + ) + + if not all(layer.built for layer in self._flatten_layers()): + # Build the model on one batch of data. + for _, _, iterator in epoch_iterator: + # Build model + x, _, _ = data_adapter_utils.unpack_x_y_sample_weight( + next(iterator) + ) + if is_nnx_enabled(): + self(x) + else: + with backend.StatelessScope(): + self(x) + break + epoch_iterator.reset() + # Container that configures and calls callbacks. + if not isinstance(callbacks, callbacks_module.CallbackList): + callbacks = callbacks_module.CallbackList( + callbacks, + add_progbar=verbose != 0, + verbose=verbose, + epochs=1, + steps=epoch_iterator.num_batches, + model=self, + ) + + self.make_predict_function() + self.stop_predicting = False + callbacks.on_predict_begin() + + def append_to_outputs(batch_outputs, outputs): + if outputs is None: + outputs = tree.map_structure( + lambda batch_output: [batch_output], + batch_outputs, + ) + else: + tree.map_structure_up_to( + batch_outputs, + lambda output, batch_output: output.append(batch_output), + outputs, + batch_outputs, + ) + return outputs + + self._jax_state_synced = True + outputs = None + non_trainable_variables = None + with epoch_iterator.catch_stop_iteration(): + for begin_step, end_step, iterator in epoch_iterator: + callbacks.on_predict_batch_begin(begin_step) + if self._jax_state_synced: + # The state may have been synced by a callback. + state = self._get_jax_state( + trainable_variables=True, + non_trainable_variables=True, + purge_model_variables=True, + ) + self._jax_state_synced = False + batch_outputs, state = self.predict_function(state, iterator) + ( + trainable_variables, + non_trainable_variables, + ) = state + self._jax_state = { + "trainable_variables": trainable_variables, + # I wouldn't recommend modifying non-trainable model state + # during predict(), but it's allowed. + "non_trainable_variables": non_trainable_variables, + } + outputs = append_to_outputs(batch_outputs, outputs) + + # Dispatch callbacks. This takes care of async dispatch. + callbacks.on_predict_batch_end( + end_step, {"outputs": batch_outputs} + ) + + if self.stop_predicting: + break + + self.jax_state_sync() + callbacks.on_predict_end() + self._jax_state = None + return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs) + + def train_on_batch( + self, + x, + y=None, + sample_weight=None, + class_weight=None, + return_dict=False, + ): + self._assert_compile_called("train_on_batch") + if class_weight is not None: + if sample_weight is not None: + raise ValueError( + "Arguments `sample_weight` and `class_weight` " + "cannot be specified at the same time. " + f"Received: sample_weight={sample_weight}, " + f"class_weight={class_weight}" + ) + sample_weight = data_adapter_utils.class_weight_to_sample_weights( + y, class_weight + ) + + def data(): + yield _distribute_data((x, y, sample_weight)) + + # Maybe build model + self._symbolic_build(data_batch=next(data())) + self.make_train_function() + + # Train step + state = self._get_jax_state( + trainable_variables=True, + non_trainable_variables=True, + optimizer_variables=True, + metrics_variables=True, + purge_model_variables=False, + ) + self._jax_state_synced = False + logs, state = self.train_function(state, data()) + + # State sync + ( + trainable_variables, + non_trainable_variables, + optimizer_variables, + metrics_variables, + ) = state + self._jax_state = { + "trainable_variables": trainable_variables, + "non_trainable_variables": non_trainable_variables, + "optimizer_variables": optimizer_variables, + "metrics_variables": metrics_variables, + } + self.jax_state_sync() + + # Format return values + logs = tree.map_structure(lambda x: np.array(x), logs) + if return_dict: + return logs + return self._flatten_metrics_in_order(logs) + + def test_on_batch( + self, + x, + y=None, + sample_weight=None, + return_dict=False, + ): + self._assert_compile_called("test_on_batch") + + def data(): + yield _distribute_data((x, y, sample_weight)) + + # Maybe build model + self._symbolic_build(data_batch=next(data())) + self.make_test_function() + + # Test step + state = self._get_jax_state( + trainable_variables=True, + non_trainable_variables=True, + metrics_variables=True, + purge_model_variables=False, + ) + self._jax_state_synced = False + logs, state = self.test_function(state, data()) + + # State sync + trainable_variables, non_trainable_variables, metrics_variables = state + self._jax_state = { + "trainable_variables": trainable_variables, + "non_trainable_variables": non_trainable_variables, + "metrics_variables": metrics_variables, + } + self.jax_state_sync() + + # Format return values. + logs = tree.map_structure(lambda x: np.array(x), logs) + if return_dict: + return logs + return self._flatten_metrics_in_order(logs) + + def predict_on_batch(self, x): + if not all(layer.built for layer in self._flatten_layers()): + # Build model + with backend.StatelessScope(): + self(x) + self.make_predict_function() + + state = self._get_jax_state( + trainable_variables=True, + non_trainable_variables=True, + metrics_variables=False, + purge_model_variables=False, + ) + self._jax_state_synced = False + + def data(): + yield (x,) + + batch_outputs, state = self.predict_function(state, data()) + trainable_variables, non_trainable_variables = state + self._jax_state = { + "trainable_variables": trainable_variables, + "non_trainable_variables": non_trainable_variables, + } + self.jax_state_sync() + batch_outputs = tree.map_structure(lambda x: np.array(x), batch_outputs) + return batch_outputs + + def jax_state_sync(self): + if not getattr(self, "_jax_state", None) or self._jax_state_synced: + return + + trainable_variables = self._jax_state.get("trainable_variables", None) + non_trainable_variables = self._jax_state.get( + "non_trainable_variables", None + ) + optimizer_variables = self._jax_state.get("optimizer_variables", None) + metrics_variables = self._jax_state.get("metrics_variables", None) + if trainable_variables: + for ref_v, v in zip(self.trainable_variables, trainable_variables): + ref_v.assign(v) + if non_trainable_variables: + for ref_v, v in zip( + self.non_trainable_variables, non_trainable_variables + ): + ref_v.assign(v) + if optimizer_variables: + for ref_v, v in zip(self.optimizer.variables, optimizer_variables): + ref_v.assign(v) + if metrics_variables: + for ref_v, v in zip(self.metrics_variables, metrics_variables): + ref_v.assign(v) + self._jax_state_synced = True + + def _get_state_sharding_spec(self): + trainable_shardings = [ + v.value.sharding for v in self.trainable_variables + ] + non_trainable_shardings = [ + v.value.sharding for v in self.non_trainable_variables + ] + if hasattr(self, "optimizer") and self.optimizer is not None: + optimizer_shardings = [ + v.value.sharding for v in self.optimizer.variables + ] + else: + optimizer_shardings = [] + metrics_shardings = [v.value.sharding for v in self.metrics_variables] + return ( + trainable_shardings, + non_trainable_shardings, + optimizer_shardings, + metrics_shardings, + ) + + def _purge_model_variables( + self, + trainable_variables=False, + non_trainable_variables=False, + optimizer_variables=False, + metrics_variables=False, + ): + """Remove all the model variable for memory saving. + + During JAX training, since the training function is stateless, we have + to pass in and get the model weights over and over, during which the + copy of the weights that attached to the Variable are still and + occupying extra memory. We remove those variable to save memory (for + better memory utilization) at the beginning of the epoch, and reattach + the value back to variables at the end of the epoch, via + `jax_state_sync()`. + """ + if trainable_variables: + for v in self.trainable_variables: + v._value = None + if non_trainable_variables: + for v in self.non_trainable_variables: + v._value = None + if optimizer_variables: + for v in self.optimizer.variables: + v._value = None + if metrics_variables: + for v in self.metrics_variables: + v._value = None + + def _get_jax_state( + self, + trainable_variables=False, + non_trainable_variables=False, + optimizer_variables=False, + metrics_variables=False, + purge_model_variables=False, + ): + state = [] + if trainable_variables: + state.append([v.value for v in self.trainable_variables]) + if non_trainable_variables: + state.append([v.value for v in self.non_trainable_variables]) + if optimizer_variables: + state.append([v.value for v in self.optimizer.variables]) + if metrics_variables: + state.append([v.value for v in self.metrics_variables]) + if purge_model_variables: + self._purge_model_variables( + trainable_variables=trainable_variables, + non_trainable_variables=non_trainable_variables, + optimizer_variables=optimizer_variables, + metrics_variables=metrics_variables, + ) + return tuple(state) + + +def _distribute_data(data, layouts=None): + distribution = distribution_lib.distribution() + + if distribution is not None: + if layouts is None: + layouts = tree.map_structure( + lambda d: distribution.get_data_layout(d.shape), + data, + ) + jax_dist_data_input = partial( + jax_distribution_lib.distribute_data_input, + batch_dim_name=distribution.batch_dim_name, + ) + return tree.map_structure(jax_dist_data_input, data, layouts) + + return tree.map_structure(jax.device_put, data) + + +class JAXEpochIterator(EpochIterator): + def __next__(self): + return next(self._epoch_iterator) + + def _get_iterator(self): + distribution = distribution_lib.distribution() + if distribution is not None: + return self._get_distributed_iterator(distribution) + if self.data_adapter.builtin_prefetch: + return self.data_adapter.get_jax_iterator() + else: + return self._prefetch_numpy_iterator( + self.data_adapter.get_jax_iterator() + ) + + def _get_distributed_iterator(self, distribution): + """Lazily compute layouts to reduce host to device transfer latency.""" + layouts = None + for data in self.data_adapter.get_jax_iterator(): + if layouts is None: + layouts = tree.map_structure( + lambda d: distribution.get_data_layout( + d.shape + ).backend_layout, + data, + ) + yield _distribute_data(data, layouts) + + def _prefetch_numpy_iterator(self, numpy_iterator): + """Shard and prefetch batches on device. + + Most of the implementation has been borrowed from + `flax.jax_utils.prefetch_to_device` + + This utility takes an iterator and returns a new iterator which fills an + on device prefetch buffer. Eager prefetching can improve the performance + of training loops significantly by overlapping compute and data + transfer. + """ + queue = collections.deque() + + # If you're training on GPUs, 2 is generally the best choice because + # this guarantees that you can overlap a training step on GPU with a + # data prefetch step on CPU. + def enqueue(n=2): + for data in itertools.islice(numpy_iterator, n): + queue.append(_distribute_data(data)) + + enqueue(n=2) # TODO: should we make `n` configurable? + while queue: + yield queue.popleft() + enqueue(1) diff --git a/keras/src/backend/numpy/__init__.py b/keras/src/backend/numpy/__init__.py new file mode 100644 index 000000000000..1a9d8eeb7916 --- /dev/null +++ b/keras/src/backend/numpy/__init__.py @@ -0,0 +1,26 @@ +from keras.src.backend.common.name_scope import name_scope +from keras.src.backend.numpy import core +from keras.src.backend.numpy import image +from keras.src.backend.numpy import linalg +from keras.src.backend.numpy import math +from keras.src.backend.numpy import nn +from keras.src.backend.numpy import numpy +from keras.src.backend.numpy import random +from keras.src.backend.numpy.core import IS_THREAD_SAFE +from keras.src.backend.numpy.core import SUPPORTS_RAGGED_TENSORS +from keras.src.backend.numpy.core import SUPPORTS_SPARSE_TENSORS +from keras.src.backend.numpy.core import Variable +from keras.src.backend.numpy.core import cast +from keras.src.backend.numpy.core import compute_output_spec +from keras.src.backend.numpy.core import cond +from keras.src.backend.numpy.core import convert_to_numpy +from keras.src.backend.numpy.core import convert_to_tensor +from keras.src.backend.numpy.core import device_scope +from keras.src.backend.numpy.core import is_tensor +from keras.src.backend.numpy.core import random_seed_dtype +from keras.src.backend.numpy.core import shape +from keras.src.backend.numpy.core import vectorized_map +from keras.src.backend.numpy.rnn import cudnn_ok +from keras.src.backend.numpy.rnn import gru +from keras.src.backend.numpy.rnn import lstm +from keras.src.backend.numpy.rnn import rnn diff --git a/keras/src/backend/numpy/core.py b/keras/src/backend/numpy/core.py new file mode 100644 index 000000000000..16b2303e5e43 --- /dev/null +++ b/keras/src/backend/numpy/core.py @@ -0,0 +1,454 @@ +import builtins +import contextlib +import functools +import warnings + +import numpy as np + +from keras.src import tree +from keras.src.backend.common import KerasVariable +from keras.src.backend.common import standardize_dtype +from keras.src.backend.common.backend_utils import slice_along_axis +from keras.src.backend.common.dtypes import result_type +from keras.src.backend.common.keras_tensor import KerasTensor +from keras.src.backend.common.stateless_scope import StatelessScope +from keras.src.backend.common.symbolic_scope import SymbolicScope + +SUPPORTS_SPARSE_TENSORS = False +SUPPORTS_RAGGED_TENSORS = False +IS_THREAD_SAFE = True + + +class Variable(KerasVariable): + def _initialize(self, value): + self._value = value + + def _direct_assign(self, value): + self._value = np.array(value, dtype=self._dtype) + + def _convert_to_tensor(self, value, dtype=None): + return convert_to_tensor(value, dtype=dtype) + + # Overload native accessor. + def __array__(self): + return self.value + + +def convert_to_tensor(x, dtype=None, sparse=None, ragged=None): + if sparse: + raise ValueError("`sparse=True` is not supported with numpy backend") + if ragged: + raise ValueError("`ragged=True` is not supported with numpy backend") + if dtype is not None: + dtype = standardize_dtype(dtype) + if isinstance(x, Variable): + if dtype and dtype != x.dtype: + return x.value.astype(dtype) + return x.value + if not is_tensor(x) and standardize_dtype(dtype) == "bfloat16": + # Can't create bfloat16 arrays on the fly (e.g. from a h5 Dataset). + # Instead we convert "as is" (to stored dtype) and cast. + return np.asarray(x).astype(dtype) + if dtype is None: + dtype = result_type( + *[getattr(item, "dtype", type(item)) for item in tree.flatten(x)] + ) + return np.array(x, dtype=dtype) + + +def convert_to_numpy(x): + return np.array(x) + + +def is_tensor(x): + if isinstance(x, (np.generic, np.ndarray)): + return True + return False + + +def shape(x): + return x.shape + + +def cast(x, dtype): + return convert_to_tensor(x, dtype=dtype) + + +def cond(pred, true_fn, false_fn): + if pred: + return true_fn() + return false_fn() + + +def vectorized_map(function, elements): + if not isinstance(elements, (list, tuple)): + return np.stack([function(x) for x in elements]) + else: + batch_size = elements[0].shape[0] + output_store = [] + for index in range(batch_size): + output_store.append(function([x[index] for x in elements])) + return np.stack(output_store) + + +# Shape / dtype inference util +def compute_output_spec(fn, *args, **kwargs): + with StatelessScope(), SymbolicScope(): + + def has_none_shape(x): + if isinstance(x, KerasTensor): + return None in x.shape + return False + + none_in_shape = any( + builtins.map(has_none_shape, tree.flatten((args, kwargs))) + ) + + def convert_keras_tensor_to_numpy(x, fill_value=None): + if isinstance(x, KerasTensor): + shape = list(x.shape) + if fill_value: + for i, e in enumerate(shape): + if e is None: + shape[i] = fill_value + return np.empty( + shape=shape, + dtype=x.dtype, + ) + return x + + args_1, kwargs_1 = tree.map_structure( + lambda x: convert_keras_tensor_to_numpy(x, fill_value=83), + (args, kwargs), + ) + outputs_1 = fn(*args_1, **kwargs_1) + + outputs = outputs_1 + + if none_in_shape: + args_2, kwargs_2 = tree.map_structure( + lambda x: convert_keras_tensor_to_numpy(x, fill_value=89), + (args, kwargs), + ) + outputs_2 = fn(*args_2, **kwargs_2) + + flat_out_1 = tree.flatten(outputs_1) + flat_out_2 = tree.flatten(outputs_2) + + flat_out = [] + for x1, x2 in zip(flat_out_1, flat_out_2): + shape = list(x1.shape) + for i, e in enumerate(x2.shape): + if e != shape[i]: + shape[i] = None + flat_out.append(KerasTensor(shape, standardize_dtype(x1.dtype))) + outputs = tree.pack_sequence_as(outputs_1, flat_out) + + def convert_numpy_to_keras_tensor(x): + if is_tensor(x): + return KerasTensor(x.shape, standardize_dtype(x.dtype)) + return x + + output_spec = tree.map_structure(convert_numpy_to_keras_tensor, outputs) + return output_spec + + +def map(f, xs): + def g(_, x): + return (), f(x) + + _, ys = scan(g, (), xs) + return ys + + +def scan(f, init, xs=None, length=None, reverse=False, unroll=1): + # Ref: jax.lax.scan + if not callable(f): + raise TypeError(f"`f` should be a callable. Received: f={f}") + if not isinstance(unroll, bool): + if not isinstance(unroll, int) or unroll < 1: + raise ValueError( + "`unroll` must be an positive integer or boolean. " + f"Received: unroll={unroll}" + ) + if xs is None and length is None: + raise ValueError("Got no `xs` to scan over and `length` not provided.") + + input_is_sequence = tree.is_nested(xs) + output_is_sequence = tree.is_nested(init) + + def pack_input(x): + return tree.pack_sequence_as(xs, x) if input_is_sequence else x[0] + + def pack_output(x): + return tree.pack_sequence_as(init, x) if output_is_sequence else x[0] + + if xs is None: + xs_flat = [] + n = int(length) + else: + xs_flat = tree.flatten(xs) + xs_flat = [convert_to_tensor(elem) for elem in xs_flat] + n = int(length) if length is not None else shape(xs_flat[0])[0] + + init_flat = tree.flatten(init) + init_flat = [convert_to_tensor(init) for init in init_flat] + init = pack_output(init_flat) + dummy_y = [np.zeros_like(init) for init in init_flat] + + carry = init + ys = [] + maybe_reversed = reversed if reverse else lambda x: x + for i in maybe_reversed(range(n)): + xs_slice = [x[i] for x in xs_flat] + packed_xs = pack_input(xs_slice) if len(xs_slice) > 0 else None + carry, y = f(carry, packed_xs) + ys.append(y if y is not None else dummy_y) + stacked_y = tree.map_structure( + lambda *ys: np.stack(ys), *maybe_reversed(ys) + ) + return carry, stacked_y + + +def associative_scan(f, elems, reverse=False, axis=0): + # Ref: jax.lax.associative_scan + if not callable(f): + raise TypeError(f"`f` should be a callable. Received: f={f}") + elems_flat = tree.flatten(elems) + elems_flat = [convert_to_tensor(elem) for elem in elems_flat] + if reverse: + elems_flat = [np.flip(elem, (axis,)) for elem in elems_flat] + + def _combine(a_flat, b_flat): + a = tree.pack_sequence_as(elems, a_flat) + b = tree.pack_sequence_as(elems, b_flat) + c = f(a, b) + c_flat = tree.flatten(c) + return c_flat + + num_elems = int(elems_flat[0].shape[axis]) + if not all(int(elem.shape[axis]) == num_elems for elem in elems_flat[1:]): + raise ValueError( + "Array inputs to associative_scan must have the same " + "first dimension. (saw: {})".format( + [elem.shape for elem in elems_flat] + ) + ) + + def _interleave(a, b, axis): + """Given two Tensors of static shape, interleave them along axis.""" + assert ( + a.shape[axis] == b.shape[axis] or a.shape[axis] == b.shape[axis] + 1 + ) + + # we want to get a: [a1, a2], b: [b1, b2] + # to a: [a1, 0, a2, 0], b: [0, b1, 0, b2] + a_shape = list(a.shape) + a_shape[axis] = a.shape[axis] * 2 - 1 + + b_shape = list(b.shape) + b_shape[axis] = b.shape[axis] * 2 - 1 + + a_dil = np.zeros(a_shape) + np.copyto(slice_along_axis(a_dil, 0, None, 2, axis), a) + b_dil = np.zeros(b_shape) + np.copyto(slice_along_axis(b_dil, 0, None, 2, axis), b) + + a_pad = [[0, 0] for _ in range(a.ndim)] + a_pad[axis][-1] = 1 if a.shape[axis] == b.shape[axis] else 0 + + b_pad = [[0, 0] for _ in range(b.ndim)] + b_pad[axis] = [1, 0] if a.shape[axis] == b.shape[axis] else [1, 1] + + op = np.bitwise_or if a.dtype == np.bool_ else np.add + return op( + np.pad(a_dil, a_pad), + np.pad(b_dil, b_pad), + ) + + def _scan(elems): + num_elems = elems[0].shape[axis] + if num_elems < 2: + return elems + + reduced_elems = _combine( + [ + slice_along_axis(elem, 0, -1, step=2, axis=axis) + for elem in elems + ], + [ + slice_along_axis(elem, 1, None, step=2, axis=axis) + for elem in elems + ], + ) + + odd_elems = _scan(reduced_elems) + if num_elems % 2 == 0: + even_elems = _combine( + [slice_along_axis(e, 0, -1, axis=axis) for e in odd_elems], + [ + slice_along_axis(e, 2, None, step=2, axis=axis) + for e in elems + ], + ) + else: + even_elems = _combine( + odd_elems, + [ + slice_along_axis(e, 2, None, step=2, axis=axis) + for e in elems + ], + ) + + even_elems = [ + np.concatenate( + [slice_along_axis(elem, 0, 1, axis=axis), result], + axis=axis, + ) + for (elem, result) in zip(elems, even_elems) + ] + return list( + builtins.map( + functools.partial(_interleave, axis=axis), even_elems, odd_elems + ) + ) + + scans = _scan(elems_flat) + if reverse: + scans = [np.flip(scanned, (axis,)) for scanned in scans] + + return tree.pack_sequence_as(elems, scans) + + +def scatter(indices, values, shape): + indices = convert_to_tensor(indices) + values = convert_to_tensor(values) + zeros = np.zeros(shape, dtype=values.dtype) + + index_length = indices.shape[-1] + value_shape = shape[index_length:] + indices = np.reshape(indices, [-1, index_length]) + values = np.reshape(values, [-1] + list(value_shape)) + + for i in range(indices.shape[0]): + index = indices[i] + zeros[tuple(index)] += values[i] + return zeros + + +def scatter_update(inputs, indices, updates): + indices = np.array(indices) + indices = np.transpose(indices) + inputs[tuple(indices)] = updates + return inputs + + +def slice(inputs, start_indices, shape): + # Validate inputs + assert len(start_indices) == len(shape) + + # Generate list of indices arrays for each dimension + indices = [ + np.arange(start, start + length) + for start, length in zip(start_indices, shape) + ] + + # Use np.ix_ to create a multidimensional index array + mesh = np.ix_(*indices) + + return inputs[mesh] + + +def slice_update(inputs, start_indices, updates): + # Generate list of indices arrays for each dimension + indices = [ + np.arange(start, start + length) + for start, length in zip(start_indices, updates.shape) + ] + + # Use np.ix_ to create a multidimensional index array + mesh = np.ix_(*indices) + inputs[mesh] = updates + return inputs + + +def switch(index, branches, *operands): + index = convert_to_tensor(index, "int32") + index = np.clip(index, 0, len(branches) - 1) + return branches[index](*operands) + + +def while_loop( + cond, + body, + loop_vars, + maximum_iterations=None, +): + current_iter = 0 + iteration_check = ( + lambda iter: maximum_iterations is None or iter < maximum_iterations + ) + is_tuple = isinstance(loop_vars, (tuple, list)) + loop_vars = tuple(loop_vars) if is_tuple else (loop_vars,) + loop_vars = tree.map_structure(convert_to_tensor, loop_vars) + while cond(*loop_vars) and iteration_check(current_iter): + loop_vars = body(*loop_vars) + if not isinstance(loop_vars, (list, tuple)): + loop_vars = (loop_vars,) + loop_vars = tuple(loop_vars) + current_iter += 1 + return loop_vars if is_tuple else loop_vars[0] + + +def fori_loop(lower, upper, body_fun, init_val): + val = init_val + for i in range(lower, upper): + val = body_fun(i, val) + return val + + +def stop_gradient(variable): + return variable + + +def unstack(x, num=None, axis=0): + x = np.moveaxis(x, axis, 0) + return [x[i] for i in range(x.shape[0])] + + +def random_seed_dtype(): + return "uint32" + + +class custom_gradient: + """Decorator for custom gradients. + + Args: + fun: Forward pass function. + """ + + def __init__(self, fun): + warnings.warn( + "`custom_gradient` for the numpy backend acts as a pass-through to " + "support the forward pass. No gradient computation or modification " + "takes place." + ) + self.fun = fun + + def __call__(self, *args, **kwargs): + outputs, _ = self.fun(*args, **kwargs) + return outputs + + +@contextlib.contextmanager +def device_scope(device_name): + yield + + +def remat(f): + warnings.warn( + "Rematerialization memory optimization is not supported by the " + "Numpy backend. Please switch to JAX, TensorFlow, or PyTorch to " + "utilize this feature." + ) + return f diff --git a/keras/src/backend/numpy/export.py b/keras/src/backend/numpy/export.py new file mode 100644 index 000000000000..f754c5bc6333 --- /dev/null +++ b/keras/src/backend/numpy/export.py @@ -0,0 +1,10 @@ +class NumpyExportArchive: + def track(self, resource): + raise NotImplementedError( + "`track` is not implemented in the numpy backend." + ) + + def add_endpoint(self, name, fn, input_signature=None, **kwargs): + raise NotImplementedError( + "`add_endpoint` is not implemented in the numpy backend." + ) diff --git a/keras/src/backend/numpy/image.py b/keras/src/backend/numpy/image.py new file mode 100644 index 000000000000..30ce1c9bba4c --- /dev/null +++ b/keras/src/backend/numpy/image.py @@ -0,0 +1,1202 @@ +import ml_dtypes +import numpy as np + +from keras.src import backend +from keras.src.backend.numpy.core import convert_to_tensor +from keras.src.random.seed_generator import draw_seed +from keras.src.utils.module_utils import scipy + +RESIZE_INTERPOLATIONS = ( + "bilinear", + "nearest", + "lanczos3", + "lanczos5", + "bicubic", +) +AFFINE_TRANSFORM_INTERPOLATIONS = { # map to order + "nearest": 0, + "bilinear": 1, +} +AFFINE_TRANSFORM_FILL_MODES = { + "constant", + "nearest", + "wrap", + "mirror", + "reflect", +} +MAP_COORDINATES_FILL_MODES = { + "constant", + "nearest", + "wrap", + "mirror", + "reflect", +} +SCALE_AND_TRANSLATE_METHODS = { + "linear", + "bilinear", + "trilinear", + "cubic", + "bicubic", + "tricubic", + "lanczos3", + "lanczos5", +} + + +def rgb_to_grayscale(images, data_format=None): + images = convert_to_tensor(images) + data_format = backend.standardize_data_format(data_format) + channels_axis = -1 if data_format == "channels_last" else -3 + if len(images.shape) not in (3, 4): + raise ValueError( + "Invalid images rank: expected rank 3 (single image) " + "or rank 4 (batch of images). Received input with shape: " + f"images.shape={images.shape}" + ) + # Convert to floats + original_dtype = images.dtype + compute_dtype = backend.result_type(images.dtype, float) + images = images.astype(compute_dtype) + + # Ref: tf.image.rgb_to_grayscale + rgb_weights = np.array([0.2989, 0.5870, 0.1140], dtype=images.dtype) + grayscales = np.tensordot(images, rgb_weights, axes=(channels_axis, -1)) + grayscales = np.expand_dims(grayscales, axis=channels_axis) + return grayscales.astype(original_dtype) + + +def rgb_to_hsv(images, data_format=None): + # Ref: dm_pix + images = convert_to_tensor(images) + dtype = backend.standardize_dtype(images.dtype) + data_format = backend.standardize_data_format(data_format) + channels_axis = -1 if data_format == "channels_last" else -3 + if len(images.shape) not in (3, 4): + raise ValueError( + "Invalid images rank: expected rank 3 (single image) " + "or rank 4 (batch of images). Received input with shape: " + f"images.shape={images.shape}" + ) + if not backend.is_float_dtype(dtype): + raise ValueError( + "Invalid images dtype: expected float dtype. " + f"Received: images.dtype={dtype}" + ) + eps = ml_dtypes.finfo(dtype).eps + images = np.where(np.abs(images) < eps, 0.0, images) + red, green, blue = np.split(images, 3, channels_axis) + red = np.squeeze(red, channels_axis) + green = np.squeeze(green, channels_axis) + blue = np.squeeze(blue, channels_axis) + + def rgb_planes_to_hsv_planes(r, g, b): + value = np.maximum(np.maximum(r, g), b) + minimum = np.minimum(np.minimum(r, g), b) + range_ = value - minimum + + safe_value = np.where(value > 0, value, 1.0) + safe_range = np.where(range_ > 0, range_, 1.0) + + saturation = np.where(value > 0, range_ / safe_value, 0.0) + norm = 1.0 / (6.0 * safe_range) + + hue = np.where( + value == g, + norm * (b - r) + 2.0 / 6.0, + norm * (r - g) + 4.0 / 6.0, + ) + hue = np.where(value == r, norm * (g - b), hue) + hue = np.where(range_ > 0, hue, 0.0) + (hue < 0.0).astype(hue.dtype) + return hue, saturation, value + + images = np.stack( + rgb_planes_to_hsv_planes(red, green, blue), axis=channels_axis + ) + return images.astype(dtype) + + +def hsv_to_rgb(images, data_format=None): + # Ref: dm_pix + images = convert_to_tensor(images) + dtype = images.dtype + data_format = backend.standardize_data_format(data_format) + channels_axis = -1 if data_format == "channels_last" else -3 + if len(images.shape) not in (3, 4): + raise ValueError( + "Invalid images rank: expected rank 3 (single image) " + "or rank 4 (batch of images). Received input with shape: " + f"images.shape={images.shape}" + ) + if not backend.is_float_dtype(dtype): + raise ValueError( + "Invalid images dtype: expected float dtype. " + f"Received: images.dtype={backend.standardize_dtype(dtype)}" + ) + hue, saturation, value = np.split(images, 3, channels_axis) + hue = np.squeeze(hue, channels_axis) + saturation = np.squeeze(saturation, channels_axis) + value = np.squeeze(value, channels_axis) + + def hsv_planes_to_rgb_planes(hue, saturation, value): + dh = np.mod(hue, 1.0) * 6.0 + dr = np.clip(np.abs(dh - 3.0) - 1.0, 0.0, 1.0) + dg = np.clip(2.0 - np.abs(dh - 2.0), 0.0, 1.0) + db = np.clip(2.0 - np.abs(dh - 4.0), 0.0, 1.0) + one_minus_s = 1.0 - saturation + + red = value * (one_minus_s + saturation * dr) + green = value * (one_minus_s + saturation * dg) + blue = value * (one_minus_s + saturation * db) + return red, green, blue + + images = np.stack( + hsv_planes_to_rgb_planes(hue, saturation, value), axis=channels_axis + ) + return images.astype(dtype) + + +def resize( + images, + size, + interpolation="bilinear", + antialias=False, + crop_to_aspect_ratio=False, + pad_to_aspect_ratio=False, + fill_mode="constant", + fill_value=0.0, + data_format=None, +): + data_format = backend.standardize_data_format(data_format) + if interpolation not in RESIZE_INTERPOLATIONS: + raise ValueError( + "Invalid value for argument `interpolation`. Expected of one " + f"{RESIZE_INTERPOLATIONS}. Received: interpolation={interpolation}" + ) + if fill_mode != "constant": + raise ValueError( + "Invalid value for argument `fill_mode`. Only `'constant'` " + f"is supported. Received: fill_mode={fill_mode}" + ) + if pad_to_aspect_ratio and crop_to_aspect_ratio: + raise ValueError( + "Only one of `pad_to_aspect_ratio` & `crop_to_aspect_ratio` " + "can be `True`." + ) + if not len(size) == 2: + raise ValueError( + "Argument `size` must be a tuple of two elements " + f"(height, width). Received: size={size}" + ) + size = tuple(size) + target_height, target_width = size + if len(images.shape) == 4: + if data_format == "channels_last": + size = (images.shape[0],) + size + (images.shape[-1],) + else: + size = (images.shape[0], images.shape[1]) + size + elif len(images.shape) == 3: + if data_format == "channels_last": + size = size + (images.shape[-1],) + else: + size = (images.shape[0],) + size + else: + raise ValueError( + "Invalid images rank: expected rank 3 (single image) " + "or rank 4 (batch of images). Received input with shape: " + f"images.shape={images.shape}" + ) + + if crop_to_aspect_ratio: + shape = images.shape + if data_format == "channels_last": + height, width = shape[-3], shape[-2] + else: + height, width = shape[-2], shape[-1] + crop_height = int(float(width * target_height) / target_width) + crop_height = max(min(height, crop_height), 1) + crop_width = int(float(height * target_width) / target_height) + crop_width = max(min(width, crop_width), 1) + crop_box_hstart = int(float(height - crop_height) / 2) + crop_box_wstart = int(float(width - crop_width) / 2) + if data_format == "channels_last": + if len(images.shape) == 4: + images = images[ + :, + crop_box_hstart : crop_box_hstart + crop_height, + crop_box_wstart : crop_box_wstart + crop_width, + :, + ] + else: + images = images[ + crop_box_hstart : crop_box_hstart + crop_height, + crop_box_wstart : crop_box_wstart + crop_width, + :, + ] + else: + if len(images.shape) == 4: + images = images[ + :, + :, + crop_box_hstart : crop_box_hstart + crop_height, + crop_box_wstart : crop_box_wstart + crop_width, + ] + else: + images = images[ + :, + crop_box_hstart : crop_box_hstart + crop_height, + crop_box_wstart : crop_box_wstart + crop_width, + ] + elif pad_to_aspect_ratio: + shape = images.shape + batch_size = images.shape[0] + if data_format == "channels_last": + height, width, channels = shape[-3], shape[-2], shape[-1] + else: + channels, height, width = shape[-3], shape[-2], shape[-1] + pad_height = int(float(width * target_height) / target_width) + pad_height = max(height, pad_height) + pad_width = int(float(height * target_width) / target_height) + pad_width = max(width, pad_width) + img_box_hstart = int(float(pad_height - height) / 2) + img_box_wstart = int(float(pad_width - width) / 2) + + if data_format == "channels_last": + if img_box_hstart > 0: + if len(images.shape) == 4: + padded_img = np.concatenate( + [ + np.ones( + (batch_size, img_box_hstart, width, channels), + dtype=images.dtype, + ) + * fill_value, + images, + np.ones( + (batch_size, img_box_hstart, width, channels), + dtype=images.dtype, + ) + * fill_value, + ], + axis=1, + ) + else: + padded_img = np.concatenate( + [ + np.ones( + (img_box_hstart, width, channels), + dtype=images.dtype, + ) + * fill_value, + images, + np.ones( + (img_box_hstart, width, channels), + dtype=images.dtype, + ) + * fill_value, + ], + axis=0, + ) + elif img_box_wstart > 0: + if len(images.shape) == 4: + padded_img = np.concatenate( + [ + np.ones( + (batch_size, height, img_box_wstart, channels), + dtype=images.dtype, + ) + * fill_value, + images, + np.ones( + (batch_size, height, img_box_wstart, channels), + dtype=images.dtype, + ) + * fill_value, + ], + axis=2, + ) + else: + padded_img = np.concatenate( + [ + np.ones( + (height, img_box_wstart, channels), + dtype=images.dtype, + ) + * fill_value, + images, + np.ones( + (height, img_box_wstart, channels), + dtype=images.dtype, + ) + * fill_value, + ], + axis=1, + ) + else: + padded_img = images + else: + if img_box_hstart > 0: + if len(images.shape) == 4: + padded_img = np.concatenate( + [ + np.ones( + (batch_size, channels, img_box_hstart, width) + ) + * fill_value, + images, + np.ones( + (batch_size, channels, img_box_hstart, width) + ) + * fill_value, + ], + axis=2, + ) + else: + padded_img = np.concatenate( + [ + np.ones((channels, img_box_hstart, width)) + * fill_value, + images, + np.ones((channels, img_box_hstart, width)) + * fill_value, + ], + axis=1, + ) + elif img_box_wstart > 0: + if len(images.shape) == 4: + padded_img = np.concatenate( + [ + np.ones( + (batch_size, channels, height, img_box_wstart) + ) + * fill_value, + images, + np.ones( + (batch_size, channels, height, img_box_wstart) + ) + * fill_value, + ], + axis=3, + ) + else: + padded_img = np.concatenate( + [ + np.ones((channels, height, img_box_wstart)) + * fill_value, + images, + np.ones((channels, height, img_box_wstart)) + * fill_value, + ], + axis=2, + ) + else: + padded_img = images + images = padded_img + + return _resize(images, size, method=interpolation, antialias=antialias) + + +def _compute_weight_mat( + input_size, output_size, scale, translation, kernel, antialias +): + dtype = np.result_type(scale, translation) + inv_scale = 1.0 / scale + kernel_scale = np.maximum(inv_scale, 1.0) if antialias else 1.0 + + sample_f = ( + (np.arange(output_size, dtype=dtype) + 0.5) * inv_scale + - translation * inv_scale + - 0.5 + ) + + x = ( + np.abs( + sample_f[np.newaxis, :] + - np.arange(input_size, dtype=dtype)[:, np.newaxis] + ) + / kernel_scale + ) + + weights = kernel(x) + + total_weight_sum = np.sum(weights, axis=0, keepdims=True) + weights = np.where( + np.abs(total_weight_sum) > 1000.0 * np.finfo(np.float32).eps, + np.divide( + weights, np.where(total_weight_sum != 0, total_weight_sum, 1) + ), + 0, + ) + + input_size_minus_0_5 = input_size - 0.5 + return np.where( + np.logical_and(sample_f >= -0.5, sample_f <= input_size_minus_0_5)[ + np.newaxis, : + ], + weights, + 0, + ) + + +def _resize(image, shape, method, antialias): + if method == "nearest": + return _resize_nearest(image, shape) + else: + kernel = _kernels.get(method, None) + if kernel is None: + raise ValueError("Unknown resize method") + + spatial_dims = tuple( + i for i in range(len(shape)) if image.shape[i] != shape[i] + ) + scale = [ + shape[d] / image.shape[d] if image.shape[d] != 0 else 1.0 + for d in spatial_dims + ] + + return _scale_and_translate( + image, + shape, + spatial_dims, + scale, + [0.0] * len(spatial_dims), + kernel, + antialias, + ) + + +def _resize_nearest(x, output_shape): + input_shape = x.shape + spatial_dims = tuple( + i for i in range(len(input_shape)) if input_shape[i] != output_shape[i] + ) + + for d in spatial_dims: + m, n = input_shape[d], output_shape[d] + offsets = (np.arange(n, dtype=np.float32) + 0.5) * m / n + offsets = np.floor(offsets).astype(np.int32) + indices = [slice(None)] * len(input_shape) + indices[d] = offsets + x = x[tuple(indices)] + return x + + +def _fill_triangle_kernel(x): + return np.maximum(0, 1 - np.abs(x)) + + +def _fill_keys_cubic_kernel(x): + out = ((1.5 * x - 2.5) * x) * x + 1.0 + out = np.where(x >= 1.0, ((-0.5 * x + 2.5) * x - 4.0) * x + 2.0, out) + return np.where(x >= 2.0, 0.0, out) + + +def _fill_lanczos_kernel(radius, x): + y = radius * np.sin(np.pi * x) * np.sin(np.pi * x / radius) + out = np.where( + x > 1e-3, np.divide(y, np.where(x != 0, np.pi**2 * x**2, 1)), 1 + ) + return np.where(x > radius, 0.0, out) + + +_kernels = { + "linear": _fill_triangle_kernel, + "bilinear": _fill_triangle_kernel, # For `resize`. + "cubic": _fill_keys_cubic_kernel, + "bicubic": _fill_keys_cubic_kernel, # For `resize`. + "lanczos3": lambda x: _fill_lanczos_kernel(3.0, x), + "lanczos5": lambda x: _fill_lanczos_kernel(5.0, x), +} + + +def _scale_and_translate( + x, output_shape, spatial_dims, scale, translation, kernel, antialias +): + input_shape = x.shape + + if len(spatial_dims) == 0: + return x + + if np.issubdtype(x.dtype, np.integer): + output = x.astype(np.float32) + use_rounding = True + else: + output = x.copy() + use_rounding = False + + for i, d in enumerate(spatial_dims): + d = d % x.ndim + m, n = input_shape[d], output_shape[d] + + w = _compute_weight_mat( + m, n, scale[i], translation[i], kernel, antialias + ).astype(output.dtype) + output = np.tensordot(output, w, axes=(d, 0)) + output = np.moveaxis(output, -1, d) + + if use_rounding: + output = np.clip(np.round(output), x.min(), x.max()) + output = output.astype(x.dtype) + return output + + +def affine_transform( + images, + transform, + interpolation="bilinear", + fill_mode="constant", + fill_value=0, + data_format=None, +): + data_format = backend.standardize_data_format(data_format) + if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS.keys(): + raise ValueError( + "Invalid value for argument `interpolation`. Expected of one " + f"{set(AFFINE_TRANSFORM_INTERPOLATIONS.keys())}. Received: " + f"interpolation={interpolation}" + ) + if fill_mode not in AFFINE_TRANSFORM_FILL_MODES: + raise ValueError( + "Invalid value for argument `fill_mode`. Expected of one " + f"{AFFINE_TRANSFORM_FILL_MODES}. Received: fill_mode={fill_mode}" + ) + + images = convert_to_tensor(images) + transform = convert_to_tensor(transform) + + if len(images.shape) not in (3, 4): + raise ValueError( + "Invalid images rank: expected rank 3 (single image) " + "or rank 4 (batch of images). Received input with shape: " + f"images.shape={images.shape}" + ) + if len(transform.shape) not in (1, 2): + raise ValueError( + "Invalid transform rank: expected rank 1 (single transform) " + "or rank 2 (batch of transforms). Received input with shape: " + f"transform.shape={transform.shape}" + ) + + # `scipy.ndimage.map_coordinates` lacks support for float16 and bfloat16. + input_dtype = backend.standardize_dtype(images.dtype) + compute_dtype = backend.result_type(input_dtype, "float32") + images = images.astype(compute_dtype) + transform = transform.astype(compute_dtype) + + # unbatched case + need_squeeze = False + if len(images.shape) == 3: + images = np.expand_dims(images, axis=0) + need_squeeze = True + if len(transform.shape) == 1: + transform = np.expand_dims(transform, axis=0) + + if data_format == "channels_first": + images = np.transpose(images, (0, 2, 3, 1)) + + batch_size = images.shape[0] + + # get indices + meshgrid = np.meshgrid( + *[np.arange(size) for size in images.shape[1:]], indexing="ij" + ) + indices = np.concatenate( + [np.expand_dims(x, axis=-1) for x in meshgrid], axis=-1 + ) + indices = np.tile(indices, (batch_size, 1, 1, 1, 1)) + + # swap the values + a0 = transform[:, 0].copy() + a2 = transform[:, 2].copy() + b1 = transform[:, 4].copy() + b2 = transform[:, 5].copy() + transform[:, 0] = b1 + transform[:, 2] = b2 + transform[:, 4] = a0 + transform[:, 5] = a2 + + # deal with transform + transform = np.pad(transform, pad_width=[[0, 0], [0, 1]], constant_values=1) + transform = np.reshape(transform, (batch_size, 3, 3)) + offset = transform[:, 0:2, 2].copy() + offset = np.pad(offset, pad_width=[[0, 0], [0, 1]]) + transform[:, 0:2, 2] = 0 + + # transform the indices + coordinates = np.einsum("Bhwij, Bjk -> Bhwik", indices, transform) + coordinates = np.moveaxis(coordinates, source=-1, destination=1) + coordinates += np.reshape(offset, (*offset.shape, 1, 1, 1)) + + # apply affine transformation + affined = np.stack( + [ + map_coordinates( + images[i], + coordinates[i], + order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation], + fill_mode=fill_mode, + fill_value=fill_value, + ) + for i in range(batch_size) + ], + axis=0, + ) + + if data_format == "channels_first": + affined = np.transpose(affined, (0, 3, 1, 2)) + if need_squeeze: + affined = np.squeeze(affined, axis=0) + return affined.astype(input_dtype) + + +def perspective_transform( + images, + start_points, + end_points, + interpolation="bilinear", + fill_value=0, + data_format=None, +): + data_format = backend.standardize_data_format(data_format) + start_points = convert_to_tensor(start_points) + end_points = convert_to_tensor(end_points) + + if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS: + raise ValueError( + "Invalid value for argument `interpolation`. Expected of one " + f"{AFFINE_TRANSFORM_INTERPOLATIONS}. Received: " + f"interpolation={interpolation}" + ) + + if len(images.shape) not in (3, 4): + raise ValueError( + "Invalid images rank: expected rank 3 (single image) " + "or rank 4 (batch of images). Received input with shape: " + f"images.shape={images.shape}" + ) + + if start_points.ndim not in (2, 3) or start_points.shape[-2:] != (4, 2): + raise ValueError( + "Invalid start_points shape: expected (4,2) for a single image" + f" or (N,4,2) for a batch. Received shape: {start_points.shape}" + ) + if end_points.ndim not in (2, 3) or end_points.shape[-2:] != (4, 2): + raise ValueError( + "Invalid end_points shape: expected (4,2) for a single image" + f" or (N,4,2) for a batch. Received shape: {end_points.shape}" + ) + if start_points.shape != end_points.shape: + raise ValueError( + "start_points and end_points must have the same shape." + f" Received start_points.shape={start_points.shape}, " + f"end_points.shape={end_points.shape}" + ) + + input_dtype = images.dtype + if input_dtype == "float16": + images = images.astype("float32") + + need_squeeze = False + if len(images.shape) == 3: + images = np.expand_dims(images, axis=0) + need_squeeze = True + + if len(start_points.shape) == 2: + start_points = np.expand_dims(start_points, axis=0) + if len(end_points.shape) == 2: + end_points = np.expand_dims(end_points, axis=0) + + if data_format == "channels_first": + images = np.transpose(images, (0, 2, 3, 1)) + + batch_size, height, width, channels = images.shape + + transforms = compute_homography_matrix(start_points, end_points) + + if len(transforms.shape) == 1: + transforms = np.expand_dims(transforms, axis=0) + if transforms.shape[0] == 1 and batch_size > 1: + transforms = np.tile(transforms, (batch_size, 1)) + + x, y = np.meshgrid( + np.arange(width, dtype=np.float32), + np.arange(height, dtype=np.float32), + indexing="xy", + ) + + output = np.empty((batch_size, height, width, channels)) + + for i in range(batch_size): + a0, a1, a2, a3, a4, a5, a6, a7 = transforms[i] + denom = a6 * x + a7 * y + 1.0 + x_in = (a0 * x + a1 * y + a2) / denom + y_in = (a3 * x + a4 * y + a5) / denom + + coords = np.stack([y_in.ravel(), x_in.ravel()], axis=0) + + mapped_channels = [] + for channel in range(channels): + channel_img = images[i, :, :, channel] + + mapped_channel = map_coordinates( + channel_img, + coords, + order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation], + fill_mode="constant", + fill_value=fill_value, + ) + mapped_channels.append(mapped_channel.reshape(height, width)) + + output[i] = np.stack(mapped_channels, axis=-1) + + if data_format == "channels_first": + output = np.transpose(output, (0, 3, 1, 2)) + if need_squeeze: + output = np.squeeze(output, axis=0) + output = output.astype(input_dtype) + + return output + + +def compute_homography_matrix(start_points, end_points): + start_points = convert_to_tensor(start_points) + end_points = convert_to_tensor(end_points) + dtype = backend.result_type(start_points.dtype, end_points.dtype, float) + # `np.linalg.solve` lacks support for float16 and bfloat16. + compute_dtype = backend.result_type(dtype, "float32") + start_points = start_points.astype(dtype) + end_points = end_points.astype(dtype) + + start_x1, start_y1 = start_points[:, 0, 0], start_points[:, 0, 1] + start_x2, start_y2 = start_points[:, 1, 0], start_points[:, 1, 1] + start_x3, start_y3 = start_points[:, 2, 0], start_points[:, 2, 1] + start_x4, start_y4 = start_points[:, 3, 0], start_points[:, 3, 1] + + end_x1, end_y1 = end_points[:, 0, 0], end_points[:, 0, 1] + end_x2, end_y2 = end_points[:, 1, 0], end_points[:, 1, 1] + end_x3, end_y3 = end_points[:, 2, 0], end_points[:, 2, 1] + end_x4, end_y4 = end_points[:, 3, 0], end_points[:, 3, 1] + + coefficient_matrix = np.stack( + [ + np.stack( + [ + end_x1, + end_y1, + np.ones_like(end_x1), + np.zeros_like(end_x1), + np.zeros_like(end_x1), + np.zeros_like(end_x1), + -start_x1 * end_x1, + -start_x1 * end_y1, + ], + axis=-1, + ), + np.stack( + [ + np.zeros_like(end_x1), + np.zeros_like(end_x1), + np.zeros_like(end_x1), + end_x1, + end_y1, + np.ones_like(end_x1), + -start_y1 * end_x1, + -start_y1 * end_y1, + ], + axis=-1, + ), + np.stack( + [ + end_x2, + end_y2, + np.ones_like(end_x2), + np.zeros_like(end_x2), + np.zeros_like(end_x2), + np.zeros_like(end_x2), + -start_x2 * end_x2, + -start_x2 * end_y2, + ], + axis=-1, + ), + np.stack( + [ + np.zeros_like(end_x2), + np.zeros_like(end_x2), + np.zeros_like(end_x2), + end_x2, + end_y2, + np.ones_like(end_x2), + -start_y2 * end_x2, + -start_y2 * end_y2, + ], + axis=-1, + ), + np.stack( + [ + end_x3, + end_y3, + np.ones_like(end_x3), + np.zeros_like(end_x3), + np.zeros_like(end_x3), + np.zeros_like(end_x3), + -start_x3 * end_x3, + -start_x3 * end_y3, + ], + axis=-1, + ), + np.stack( + [ + np.zeros_like(end_x3), + np.zeros_like(end_x3), + np.zeros_like(end_x3), + end_x3, + end_y3, + np.ones_like(end_x3), + -start_y3 * end_x3, + -start_y3 * end_y3, + ], + axis=-1, + ), + np.stack( + [ + end_x4, + end_y4, + np.ones_like(end_x4), + np.zeros_like(end_x4), + np.zeros_like(end_x4), + np.zeros_like(end_x4), + -start_x4 * end_x4, + -start_x4 * end_y4, + ], + axis=-1, + ), + np.stack( + [ + np.zeros_like(end_x4), + np.zeros_like(end_x4), + np.zeros_like(end_x4), + end_x4, + end_y4, + np.ones_like(end_x4), + -start_y4 * end_x4, + -start_y4 * end_y4, + ], + axis=-1, + ), + ], + axis=1, + ) + + target_vector = np.stack( + [ + start_x1, + start_y1, + start_x2, + start_y2, + start_x3, + start_y3, + start_x4, + start_y4, + ], + axis=-1, + ) + target_vector = np.expand_dims(target_vector, axis=-1) + coefficient_matrix = coefficient_matrix.astype(compute_dtype) + target_vector = target_vector.astype(compute_dtype) + homography_matrix = np.linalg.solve(coefficient_matrix, target_vector) + homography_matrix = np.reshape(homography_matrix, [-1, 8]) + return homography_matrix.astype(dtype) + + +def map_coordinates( + inputs, coordinates, order, fill_mode="constant", fill_value=0.0 +): + inputs = convert_to_tensor(inputs) + coordinates = convert_to_tensor(coordinates) + if coordinates.shape[0] != len(inputs.shape): + raise ValueError( + "First dim of `coordinates` must be the same as the rank of " + "`inputs`. " + f"Received inputs with shape: {inputs.shape} and coordinate " + f"leading dim of {coordinates.shape[0]}" + ) + if len(coordinates.shape) < 2: + raise ValueError( + "Invalid coordinates rank: expected at least rank 2." + f" Received input with shape: {coordinates.shape}" + ) + if fill_mode not in MAP_COORDINATES_FILL_MODES: + raise ValueError( + "Invalid value for argument `fill_mode`. Expected one of " + f"{set(MAP_COORDINATES_FILL_MODES.keys())}. Received: " + f"fill_mode={fill_mode}" + ) + if order not in range(2): + raise ValueError( + "Invalid value for argument `order`. Expected one of " + f"{[0, 1]}. Received: order={order}" + ) + # SciPy's implementation of map_coordinates handles boundaries incorrectly, + # unless mode='reflect'. For order=1, this only affects interpolation + # outside the bounds of the original array. + # https://github.com/scipy/scipy/issues/2640 + padding = [ + ( + max(-np.floor(c.min()).astype(int) + 1, 0), + max(np.ceil(c.max()).astype(int) + 1 - size, 0), + ) + for c, size in zip(coordinates, inputs.shape) + ] + shifted_coords = [c + p[0] for p, c in zip(padding, coordinates)] + pad_mode = { + "nearest": "edge", + "mirror": "reflect", + "reflect": "symmetric", + }.get(fill_mode, fill_mode) + if fill_mode == "constant": + padded = np.pad( + inputs, padding, mode=pad_mode, constant_values=fill_value + ) + else: + padded = np.pad(inputs, padding, mode=pad_mode) + + # `scipy.ndimage.map_coordinates` lacks support for float16 and bfloat16. + if backend.is_float_dtype(padded.dtype): + padded = padded.astype("float32") + result = scipy.ndimage.map_coordinates( + padded, shifted_coords, order=order, mode=fill_mode, cval=fill_value + ) + return result.astype(inputs.dtype) + + +def gaussian_blur( + images, kernel_size=(3, 3), sigma=(1.0, 1.0), data_format=None +): + def _create_gaussian_kernel(kernel_size, sigma, num_channels, dtype): + def _get_gaussian_kernel1d(size, sigma): + x = np.arange(size, dtype=dtype) - (size - 1) / 2 + kernel1d = np.exp(-0.5 * (x / sigma) ** 2) + return kernel1d / np.sum(kernel1d) + + def _get_gaussian_kernel2d(size, sigma): + size = np.asarray(size, dtype) + kernel1d_x = _get_gaussian_kernel1d(size[0], sigma[0]) + kernel1d_y = _get_gaussian_kernel1d(size[1], sigma[1]) + return np.outer(kernel1d_y, kernel1d_x) + + kernel = _get_gaussian_kernel2d(kernel_size, sigma) + kernel = kernel[:, :, np.newaxis] + kernel = np.tile(kernel, (1, 1, num_channels)) + return kernel.astype(dtype) + + images = convert_to_tensor(images) + kernel_size = convert_to_tensor(kernel_size) + sigma = convert_to_tensor(sigma) + input_dtype = backend.standardize_dtype(images.dtype) + # `scipy.signal.convolve2d` lacks support for float16 and bfloat16. + compute_dtype = backend.result_type(input_dtype, "float32") + images = images.astype(compute_dtype) + sigma = sigma.astype(compute_dtype) + + if len(images.shape) not in (3, 4): + raise ValueError( + "Invalid images rank: expected rank 3 (single image) " + "or rank 4 (batch of images). Received input with shape: " + f"images.shape={images.shape}" + ) + + need_squeeze = False + if len(images.shape) == 3: + images = np.expand_dims(images, axis=0) + need_squeeze = True + + if data_format == "channels_first": + images = np.transpose(images, (0, 2, 3, 1)) + + batch_size, height, width, num_channels = images.shape + + kernel = _create_gaussian_kernel( + kernel_size, sigma, num_channels, input_dtype + ) + + pad_h = kernel_size[0] // 2 + pad_w = kernel_size[1] // 2 + + blurred_images = np.empty_like(images) + + for b in range(batch_size): + for ch in range(num_channels): + padded = np.pad( + images[b, :, :, ch], + ((pad_h, pad_h), (pad_w, pad_w)), + mode="constant", + ) + blurred_images[b, :, :, ch] = scipy.signal.convolve2d( + padded, kernel[:, :, ch], mode="valid" + ) + + if data_format == "channels_first": + blurred_images = np.transpose(blurred_images, (0, 3, 1, 2)) + if need_squeeze: + blurred_images = np.squeeze(blurred_images, axis=0) + return blurred_images.astype(input_dtype) + + +def elastic_transform( + images, + alpha=20.0, + sigma=5.0, + interpolation="bilinear", + fill_mode="reflect", + fill_value=0.0, + seed=None, + data_format=None, +): + data_format = backend.standardize_data_format(data_format) + if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS.keys(): + raise ValueError( + "Invalid value for argument `interpolation`. Expected of one " + f"{set(AFFINE_TRANSFORM_INTERPOLATIONS.keys())}. Received: " + f"interpolation={interpolation}" + ) + if fill_mode not in AFFINE_TRANSFORM_FILL_MODES: + raise ValueError( + "Invalid value for argument `fill_mode`. Expected of one " + f"{AFFINE_TRANSFORM_FILL_MODES}. Received: fill_mode={fill_mode}" + ) + if len(images.shape) not in (3, 4): + raise ValueError( + "Invalid images rank: expected rank 3 (single image) " + "or rank 4 (batch of images). Received input with shape: " + f"images.shape={images.shape}" + ) + + images = convert_to_tensor(images) + input_dtype = images.dtype + + alpha = convert_to_tensor(alpha, dtype=input_dtype) + sigma = convert_to_tensor(sigma, dtype=input_dtype) + + kernel_size = (int(6 * sigma) | 1, int(6 * sigma) | 1) + + need_squeeze = False + if len(images.shape) == 3: + images = np.expand_dims(images, axis=0) + need_squeeze = True + + if data_format == "channels_last": + batch_size, height, width, channels = images.shape + channel_axis = -1 + else: + batch_size, channels, height, width = images.shape + channel_axis = 1 + + seed = draw_seed(seed) + rng = np.random.default_rng(seed) + dx = ( + rng.normal(size=(batch_size, height, width), loc=0.0, scale=1.0).astype( + input_dtype + ) + * sigma + ) + dy = ( + rng.normal(size=(batch_size, height, width), loc=0.0, scale=1.0).astype( + input_dtype + ) + * sigma + ) + + dx = gaussian_blur( + np.expand_dims(dx, axis=channel_axis), + kernel_size=kernel_size, + sigma=(sigma, sigma), + data_format=data_format, + ) + dy = gaussian_blur( + np.expand_dims(dy, axis=channel_axis), + kernel_size=kernel_size, + sigma=(sigma, sigma), + data_format=data_format, + ) + + dx = np.squeeze(dx) + dy = np.squeeze(dy) + + x, y = np.meshgrid(np.arange(width), np.arange(height)) + x, y = x[None, :, :], y[None, :, :] + + distorted_x = x + alpha * dx + distorted_y = y + alpha * dy + + transformed_images = np.zeros_like(images) + + if data_format == "channels_last": + for i in range(channels): + transformed_images[..., i] = np.stack( + [ + map_coordinates( + images[b, ..., i], + [distorted_y[b], distorted_x[b]], + order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation], + fill_mode=fill_mode, + fill_value=fill_value, + ) + for b in range(batch_size) + ] + ) + else: + for i in range(channels): + transformed_images[:, i, :, :] = np.stack( + [ + map_coordinates( + images[b, i, ...], + [distorted_y[b], distorted_x[b]], + order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation], + fill_mode=fill_mode, + fill_value=fill_value, + ) + for b in range(batch_size) + ] + ) + + if need_squeeze: + transformed_images = np.squeeze(transformed_images, axis=0) + transformed_images = transformed_images.astype(input_dtype) + + return transformed_images + + +def scale_and_translate( + images, + output_shape, + scale, + translation, + spatial_dims, + method, + antialias=True, +): + if method not in SCALE_AND_TRANSLATE_METHODS: + raise ValueError( + "Invalid value for argument `method`. Expected of one " + f"{SCALE_AND_TRANSLATE_METHODS}. Received: method={method}" + ) + if method in ("linear", "bilinear", "trilinear", "triangle"): + method = "linear" + elif method in ("cubic", "bicubic", "tricubic"): + method = "cubic" + + images = convert_to_tensor(images) + scale = convert_to_tensor(scale) + translation = convert_to_tensor(translation) + kernel = _kernels[method] + dtype = backend.result_type(scale.dtype, translation.dtype) + scale = scale.astype(dtype) + translation = translation.astype(dtype) + return _scale_and_translate( + images, + output_shape, + spatial_dims, + scale, + translation, + kernel, + antialias, + ) diff --git a/keras/src/backend/numpy/layer.py b/keras/src/backend/numpy/layer.py new file mode 100644 index 000000000000..08b761f972e8 --- /dev/null +++ b/keras/src/backend/numpy/layer.py @@ -0,0 +1,2 @@ +class NumpyLayer: + pass diff --git a/keras/src/backend/numpy/linalg.py b/keras/src/backend/numpy/linalg.py new file mode 100644 index 000000000000..881911d7240a --- /dev/null +++ b/keras/src/backend/numpy/linalg.py @@ -0,0 +1,102 @@ +import numpy as np +import scipy.linalg as sl + +from keras.src.backend import standardize_dtype +from keras.src.backend.common import dtypes +from keras.src.backend.numpy.core import convert_to_tensor + + +def cholesky(a, upper=False): + return np.linalg.cholesky(a, upper=upper) + + +def cholesky_inverse(a, upper=False): + identity = np.eye(a.shape[-1], dtype=a.dtype) + inv_chol = solve_triangular(a, identity, lower=not upper) + if upper: + a_inv = np.matmul(inv_chol, inv_chol.T) + else: + a_inv = np.matmul(inv_chol.T, inv_chol) + return a_inv + + +def det(a): + return np.linalg.det(a) + + +def eig(a): + return np.linalg.eig(a) + + +def eigh(a): + return np.linalg.eigh(a) + + +def inv(a): + return np.linalg.inv(a) + + +def lu_factor(a): + if a.ndim == 2: + return sl.lu_factor(a) + + m, n = a.shape[-2:] + signature = "(m,n) -> (m,n), " + signature += "(m)" if m <= n else "(n)" + _lu_factor_gufunc = np.vectorize( + sl.lu_factor, + signature=signature, + ) + return _lu_factor_gufunc(a) + + +def norm(x, ord=None, axis=None, keepdims=False): + x = convert_to_tensor(x) + dtype = standardize_dtype(x.dtype) + if "int" in dtype or dtype == "bool": + dtype = dtypes.result_type(x.dtype, "float32") + return np.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims).astype( + dtype + ) + + +def qr(x, mode="reduced"): + if mode not in {"reduced", "complete"}: + raise ValueError( + "`mode` argument value not supported. " + "Expected one of {'reduced', 'complete'}. " + f"Received: mode={mode}" + ) + return np.linalg.qr(x, mode=mode) + + +def solve(a, b): + return np.linalg.solve(a, b) + + +def solve_triangular(a, b, lower=False): + if a.ndim == 2: + return sl.solve_triangular(a, b, lower=lower) + + _vectorized_solve_triangular = np.vectorize( + lambda a, b: sl.solve_triangular(a, b, lower=lower), + signature="(n,n),(n,m)->(n,m)", + ) + if b.ndim == a.ndim - 1: + b = np.expand_dims(b, axis=-1) + return _vectorized_solve_triangular(a, b).squeeze(axis=-1) + return _vectorized_solve_triangular(a, b) + + +def svd(x, full_matrices=True, compute_uv=True): + return np.linalg.svd(x, full_matrices=full_matrices, compute_uv=compute_uv) + + +def lstsq(a, b, rcond=None): + a = convert_to_tensor(a) + b = convert_to_tensor(b) + return np.linalg.lstsq(a, b, rcond=rcond)[0] + + +def jvp(fun, primals, tangents, has_aux=False): + raise NotImplementedError("JVP is not supported by the Numpy backend.") diff --git a/keras/src/backend/numpy/math.py b/keras/src/backend/numpy/math.py new file mode 100644 index 000000000000..db2cdbfc68ea --- /dev/null +++ b/keras/src/backend/numpy/math.py @@ -0,0 +1,322 @@ +import numpy as np + +from keras.src.backend import standardize_dtype +from keras.src.backend.common import dtypes +from keras.src.backend.jax.math import fft as jax_fft +from keras.src.backend.jax.math import fft2 as jax_fft2 +from keras.src.backend.numpy.core import convert_to_tensor +from keras.src.utils.module_utils import scipy + + +def _segment_reduction_fn( + data, segment_ids, reduction_method, num_segments, sorted +): + if num_segments is None: + num_segments = np.amax(segment_ids) + 1 + + valid_indices = segment_ids >= 0 # Ignore segment_ids that are -1 + valid_data = data[valid_indices] + valid_segment_ids = segment_ids[valid_indices] + + data_shape = list(valid_data.shape) + data_shape[0] = ( + num_segments # Replace first dimension (which corresponds to segments) + ) + + if reduction_method == np.maximum: + result = np.ones(data_shape, dtype=valid_data.dtype) * -np.inf + else: + result = np.zeros(data_shape, dtype=valid_data.dtype) + + if sorted: + reduction_method.at(result, valid_segment_ids, valid_data) + else: + sort_indices = np.argsort(valid_segment_ids) + sorted_segment_ids = valid_segment_ids[sort_indices] + sorted_data = valid_data[sort_indices] + + reduction_method.at(result, sorted_segment_ids, sorted_data) + + return result + + +def segment_sum(data, segment_ids, num_segments=None, sorted=False): + return _segment_reduction_fn( + data, segment_ids, np.add, num_segments, sorted + ) + + +def segment_max(data, segment_ids, num_segments=None, sorted=False): + return _segment_reduction_fn( + data, segment_ids, np.maximum, num_segments, sorted + ) + + +def top_k(x, k, sorted=True): + if sorted: + # Take the k largest values. + sorted_indices = np.argsort(x, axis=-1)[..., ::-1] + sorted_values = np.take_along_axis(x, sorted_indices, axis=-1) + top_k_values = sorted_values[..., :k] + top_k_indices = sorted_indices[..., :k] + else: + # Partition the array such that all values larger than the k-th + # largest value are to the right of it. + top_k_indices = np.argpartition(x, -k, axis=-1)[..., -k:] + top_k_values = np.take_along_axis(x, top_k_indices, axis=-1) + return top_k_values, top_k_indices + + +def in_top_k(targets, predictions, k): + targets = targets[:, None] + topk_values = top_k(predictions, k)[0] + targets_values = np.take_along_axis(predictions, targets, axis=-1) + mask = targets_values >= topk_values + return np.any(mask, axis=-1) + + +def logsumexp(x, axis=None, keepdims=False): + return scipy.special.logsumexp(x, axis=axis, keepdims=keepdims) + + +def qr(x, mode="reduced"): + if mode not in {"reduced", "complete"}: + raise ValueError( + "`mode` argument value not supported. " + "Expected one of {'reduced', 'complete'}. " + f"Received: mode={mode}" + ) + return np.linalg.qr(x, mode=mode) + + +def extract_sequences(x, sequence_length, sequence_stride): + *batch_shape, _ = x.shape + batch_shape = list(batch_shape) + shape = x.shape[:-1] + ( + (x.shape[-1] - (sequence_length - sequence_stride)) // sequence_stride, + sequence_length, + ) + strides = x.strides[:-1] + ( + sequence_stride * x.strides[-1], + x.strides[-1], + ) + x = np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides) + return np.reshape(x, (*batch_shape, *x.shape[-2:])) + + +def _get_complex_tensor_from_tuple(x): + if not isinstance(x, (tuple, list)) or len(x) != 2: + raise ValueError( + "Input `x` should be a tuple of two tensors - real and imaginary." + f"Received: x={x}" + ) + # `convert_to_tensor` does not support passing complex tensors. We separate + # the input out into real and imaginary and convert them separately. + real, imag = x + # Check shapes. + if real.shape != imag.shape: + raise ValueError( + "Input `x` should be a tuple of two tensors - real and imaginary." + "Both the real and imaginary parts should have the same shape. " + f"Received: x[0].shape = {real.shape}, x[1].shape = {imag.shape}" + ) + # Ensure dtype is float. + if not np.issubdtype(real.dtype, np.floating) or not np.issubdtype( + imag.dtype, np.floating + ): + raise ValueError( + "At least one tensor in input `x` is not of type float." + f"Received: x={x}." + ) + complex_input = real + 1j * imag + return complex_input + + +def fft(x): + real, imag = jax_fft(x) + return np.array(real), np.array(imag) + + +def fft2(x): + real, imag = jax_fft2(x) + return np.array(real), np.array(imag) + + +def ifft2(x): + complex_input = _get_complex_tensor_from_tuple(x) + complex_output = np.fft.ifft2(complex_input) + return np.real(complex_output), np.imag(complex_output) + + +def rfft(x, fft_length=None): + complex_output = np.fft.rfft(x, n=fft_length, axis=-1, norm="backward") + # numpy always outputs complex128, so we need to recast the dtype + return ( + np.real(complex_output).astype(x.dtype), + np.imag(complex_output).astype(x.dtype), + ) + + +def irfft(x, fft_length=None): + complex_input = _get_complex_tensor_from_tuple(x) + # numpy always outputs float64, so we need to recast the dtype + return np.fft.irfft( + complex_input, n=fft_length, axis=-1, norm="backward" + ).astype(x[0].dtype) + + +def stft( + x, sequence_length, sequence_stride, fft_length, window="hann", center=True +): + if standardize_dtype(x.dtype) not in {"float32", "float64"}: + raise TypeError( + "Invalid input type. Expected `float32` or `float64`. " + f"Received: input type={x.dtype}" + ) + if fft_length < sequence_length: + raise ValueError( + "`fft_length` must equal or larger than `sequence_length`. " + f"Received: sequence_length={sequence_length}, " + f"fft_length={fft_length}" + ) + if isinstance(window, str): + if window not in {"hann", "hamming"}: + raise ValueError( + "If a string is passed to `window`, it must be one of " + f'`"hann"`, `"hamming"`. Received: window={window}' + ) + x = convert_to_tensor(x) + ori_dtype = x.dtype + + if center: + pad_width = [(0, 0) for _ in range(len(x.shape))] + pad_width[-1] = (fft_length // 2, fft_length // 2) + x = np.pad(x, pad_width, mode="reflect") + + l_pad = (fft_length - sequence_length) // 2 + r_pad = fft_length - sequence_length - l_pad + + if window is not None: + if isinstance(window, str): + win = convert_to_tensor( + scipy.signal.get_window(window, sequence_length), dtype=x.dtype + ) + else: + win = convert_to_tensor(window, dtype=x.dtype) + if len(win.shape) != 1 or win.shape[-1] != sequence_length: + raise ValueError( + "The shape of `window` must be equal to [sequence_length]." + f"Received: window shape={win.shape}" + ) + win = np.pad(win, [[l_pad, r_pad]]) + else: + win = np.ones((sequence_length + l_pad + r_pad), dtype=x.dtype) + + x = scipy.signal.stft( + x, + fs=1.0, + window=win, + nperseg=(sequence_length + l_pad + r_pad), + noverlap=(sequence_length + l_pad + r_pad - sequence_stride), + nfft=fft_length, + boundary=None, + padded=False, + )[-1] + + # scale and swap to (..., num_sequences, fft_bins) + x = x / np.sqrt(1.0 / win.sum() ** 2) + x = np.swapaxes(x, -2, -1) + return np.real(x).astype(ori_dtype), np.imag(x).astype(ori_dtype) + + +def istft( + x, + sequence_length, + sequence_stride, + fft_length, + length=None, + window="hann", + center=True, +): + x = _get_complex_tensor_from_tuple(x) + dtype = np.real(x).dtype + + expected_output_len = fft_length + sequence_stride * (x.shape[-2] - 1) + l_pad = (fft_length - sequence_length) // 2 + r_pad = fft_length - sequence_length - l_pad + + if window is not None: + if isinstance(window, str): + win = convert_to_tensor( + scipy.signal.get_window(window, sequence_length), dtype=dtype + ) + else: + win = convert_to_tensor(window, dtype=dtype) + if len(win.shape) != 1 or win.shape[-1] != sequence_length: + raise ValueError( + "The shape of `window` must be equal to [sequence_length]." + f"Received: window shape={win.shape}" + ) + win = np.pad(win, [[l_pad, r_pad]]) + else: + win = np.ones((sequence_length + l_pad + r_pad), dtype=dtype) + + x = scipy.signal.istft( + x, + fs=1.0, + window=win, + nperseg=(sequence_length + l_pad + r_pad), + noverlap=(sequence_length + l_pad + r_pad - sequence_stride), + nfft=fft_length, + boundary=False, + time_axis=-2, + freq_axis=-1, + )[-1] + + # scale + x = x / win.sum() if window is not None else x / sequence_stride + + start = 0 if center is False else fft_length // 2 + if length is not None: + end = start + length + elif center is True: + end = -(fft_length // 2) + else: + end = expected_output_len + return x[..., start:end] + + +def rsqrt(x): + return 1.0 / np.sqrt(x) + + +def erf(x): + return np.array(scipy.special.erf(x)) + + +def erfinv(x): + return np.array(scipy.special.erfinv(x)) + + +def solve(a, b): + a = convert_to_tensor(a) + b = convert_to_tensor(b) + return np.linalg.solve(a, b) + + +def norm(x, ord=None, axis=None, keepdims=False): + x = convert_to_tensor(x) + dtype = standardize_dtype(x.dtype) + if "int" in dtype or dtype == "bool": + dtype = dtypes.result_type(x.dtype, "float32") + return np.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims).astype( + dtype + ) + + +def logdet(x): + from keras.src.backend.numpy.numpy import slogdet + + # In NumPy slogdet is more stable than `np.log(np.linalg.det(x))`. See + # https://numpy.org/doc/stable/reference/generated/numpy.linalg.slogdet.html + return slogdet(x)[1] diff --git a/keras/src/backend/numpy/nn.py b/keras/src/backend/numpy/nn.py new file mode 100644 index 000000000000..93e0f57831a4 --- /dev/null +++ b/keras/src/backend/numpy/nn.py @@ -0,0 +1,1231 @@ +import jax +import numpy as np +from jax import lax + +from keras.src import backend +from keras.src.backend.common.backend_utils import ( + compute_conv_transpose_padding_args_for_jax, +) +from keras.src.backend.numpy.core import cast +from keras.src.backend.numpy.core import convert_to_tensor +from keras.src.backend.numpy.core import is_tensor +from keras.src.utils.module_utils import scipy + + +def relu(x): + x = convert_to_tensor(x) + return np.maximum(x, np.array(0.0, x.dtype)) + + +def relu6(x): + x = convert_to_tensor(x) + # np.clip incorrectly promote bfloat16 to float32, so we replace it with + # np.minimum and np.maximum here + return np.minimum( + np.maximum(x, np.array(0.0, x.dtype)), np.array(6.0, x.dtype) + ) + + +def sigmoid(x): + x = convert_to_tensor(x) + return np.array(1.0, x.dtype) / (np.array(1.0, x.dtype) + np.exp(-x)) + + +def sparse_sigmoid(x): + x = convert_to_tensor(x) + return np.where( + x <= -1, + np.array(0.0, x.dtype), + np.where( + x >= 1, np.array(1.0, x.dtype), np.array(0.5 * (x + 1), x.dtype) + ), + ) + + +def tanh(x): + return np.tanh(x) + + +def tanh_shrink(x): + x = convert_to_tensor(x) + return x - np.tanh(x) + + +def softplus(x): + x = convert_to_tensor(x) + return np.logaddexp(x, np.array(0.0, x.dtype)) + + +def softsign(x): + x = convert_to_tensor(x) + return x / (np.array(1.0, x.dtype) + np.abs(x)) + + +def soft_shrink(x, threshold=0.5): + return np.where( + x > threshold, + np.array(x - threshold, dtype=x.dtype), + np.where( + x < -threshold, + np.array(x + threshold, dtype=x.dtype), + np.array(0.0, dtype=x.dtype), + ), + ) + + +def sparse_plus(x): + return np.where( + x <= -1, + np.zeros_like(x, dtype=x.dtype), + np.where(x < 1, np.array((1 / 4) * (x + 1) ** 2, dtype=x.dtype), x), + ) + + +def silu(x): + x = convert_to_tensor(x) + return x * sigmoid(x) + + +def squareplus(x, b=4): + x = convert_to_tensor(x) + b = convert_to_tensor(b, dtype=x.dtype) + y = x + np.sqrt(x**2 + b) + return y / 2 + + +def log_sigmoid(x): + x = convert_to_tensor(x) + return -softplus(-x) + + +def leaky_relu(x, negative_slope=0.2): + x = convert_to_tensor(x) + return np.maximum(x, np.array(negative_slope, x.dtype) * x) + + +def hard_sigmoid(x): + # python numbers will be promoted to float64 by np, so it's necessary to + # first convert the python numbers to np scalars + x = x / np.array(6.0, x.dtype) + np.array(0.5, x.dtype) + return np.where( + x <= 0.0, + np.array(0.0, x.dtype), + np.where(x >= 1.0, np.array(1.0, x.dtype), x), + ) + + +def hard_silu(x): + return x * hard_sigmoid(x) + + +def elu(x, alpha=1.0): + x = convert_to_tensor(x) + return np.where( + x >= np.array(0.0, x.dtype), x, np.array(alpha, x.dtype) * np.expm1(x) + ) + + +def selu(x): + alpha = 1.6732632423543772848170429916717 + scale = 1.0507009873554804934193349852946 + x = convert_to_tensor(x) + return np.array(scale, x.dtype) * elu(x, alpha) + + +def gelu(x, approximate=True): + x = convert_to_tensor(x) + # followed by JAX's implementation + if approximate: + sqrt_2_over_pi = np.sqrt(2 / np.pi).astype(x.dtype) + cdf = np.array(0.5, x.dtype) * ( + np.array(1.0, x.dtype) + + np.tanh( + sqrt_2_over_pi + * (x + np.array(0.044715, x.dtype) * (x**3).astype(x.dtype)) + ) + ) + return x * cdf + else: + sqrt_2 = np.sqrt(2).astype(x.dtype) + return ( + x + * (scipy.special.erf(x / sqrt_2) + 1).astype(x.dtype) + / np.array(2, x.dtype) + ) + + +def celu(x, alpha=1.0): + x = convert_to_tensor(x) + alpha = np.array(alpha, x.dtype) + return np.maximum(x, np.array(0.0, dtype=x.dtype)) + alpha * np.expm1( + np.minimum(x, np.array(0.0, dtype=x.dtype)) / alpha + ) + + +def glu(x, axis=-1): + x = convert_to_tensor(x) + dtype = x.dtype + if x.shape[axis] % 2 != 0: + raise ValueError( + "axis size must be divisible by 2. " + f"Received: x.shape={x.shape} with axis={axis}" + ) + x1, x2 = np.split(x, 2, axis) + return (x1 * sigmoid(x2)).astype(dtype) + + +def hard_tanh(x): + x = convert_to_tensor(x) + min_val = np.asarray(-1.0, x.dtype) + max_val = np.asarray(1.0, x.dtype) + return np.array(np.clip(x, min_val, max_val), dtype=x.dtype) + + +def hard_shrink(x, threshold=0.5): + x = convert_to_tensor(x) + threshold = np.asarray(threshold, x.dtype) + return np.array( + np.where(np.abs(x) > threshold, x, np.array(0.0, dtype=x.dtype)), + dtype=x.dtype, + ) + + +def threshold(x, threshold, default_value): + x = convert_to_tensor(x) + return np.where(x > threshold, x, np.array(default_value, dtype=x.dtype)) + + +def softmax(x, axis=-1): + exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True)) + return exp_x / np.sum(exp_x, axis=axis, keepdims=True) + + +def log_softmax(x, axis=-1): + max_x = np.max(x, axis=axis, keepdims=True) + logsumexp = np.log(np.exp(x - max_x).sum(axis=axis, keepdims=True)) + return x - max_x - logsumexp + + +def sparsemax(x, axis=-1): + # Sort logits along the specified axis in descending order + logits = convert_to_tensor(x) + logits_sorted = -1.0 * np.sort(-1.0 * logits, axis=axis) + logits_cumsum = np.cumsum(logits_sorted, axis=axis) + r = np.arange(1, logits.shape[axis] + 1) + r_shape = [1] * logits.ndim + r_shape[axis] = -1 # Broadcast to match the target axis + r = r.reshape(r_shape) + support = logits_sorted - (logits_cumsum - 1) / r > 0 + # Find the threshold + k = np.sum(support, axis=axis, keepdims=True) + logits_cumsum_safe = np.where(support, logits_cumsum, 0.0) + tau = (np.sum(logits_cumsum_safe, axis=axis, keepdims=True) - 1) / k + output = np.maximum(logits - tau, 0.0) + return output + + +def _convert_to_spatial_operand( + x, + num_spatial_dims, + data_format="channels_last", + include_batch_and_channels=True, +): + # Helper function that converts an operand to a spatial operand. + x = (x,) * num_spatial_dims if isinstance(x, int) else x + if not include_batch_and_channels: + return x + if data_format == "channels_last": + x = (1,) + x + (1,) + else: + x = (1,) + (1,) + x + return x + + +def _pool( + inputs, + initial_value, + reduce_fn, + pool_size, + strides=None, + padding="valid", +): + """Helper function to define pooling functions. + + Args: + inputs: input data of shape `N+2`. + initial_value: the initial value for the reduction. + reduce_fn: a reduce function of the form `(T, T) -> T`. + pool_size: a sequence of `N` integers, representing the window size to + reduce over. + strides: a sequence of `N` integers, representing the inter-window + strides (default: `(1, ..., 1)`). + padding: either the string `same` or `valid`. + + Returns: + The output of the reduction for each window slice. + """ + if padding not in ("same", "valid"): + raise ValueError( + f"Invalid padding '{padding}', must be 'same' or 'valid'." + ) + padding = padding.upper() + return np.array( + lax.reduce_window( + inputs, + initial_value, + reduce_fn, + pool_size, + strides, + padding, + ) + ) + + +def max_pool( + inputs, + pool_size, + strides=None, + padding="valid", + data_format=None, +): + data_format = backend.standardize_data_format(data_format) + num_spatial_dims = inputs.ndim - 2 + pool_size = _convert_to_spatial_operand( + pool_size, num_spatial_dims, data_format + ) + strides = pool_size if strides is None else strides + strides = _convert_to_spatial_operand( + strides, num_spatial_dims, data_format + ) + return _pool(inputs, -np.inf, lax.max, pool_size, strides, padding) + + +def average_pool( + inputs, + pool_size, + strides=None, + padding="valid", + data_format=None, +): + data_format = backend.standardize_data_format(data_format) + num_spatial_dims = inputs.ndim - 2 + pool_size = _convert_to_spatial_operand( + pool_size, num_spatial_dims, data_format + ) + strides = pool_size if strides is None else strides + strides = _convert_to_spatial_operand( + strides, num_spatial_dims, data_format + ) + + pooled = _pool(inputs, 0.0, lax.add, pool_size, strides, padding) + if padding == "valid": + # Avoid the extra reduce_window. + return pooled / np.prod(pool_size) + else: + # Count the number of valid entries at each input point, then use that + # for computing average. Assumes that any two arrays of same shape will + # be padded the same. Avoid broadcasting on axis where pooling is + # skipped. + shape = [ + (a if b != 1 else 1) for (a, b) in zip(inputs.shape, pool_size) + ] + window_counts = _pool( + np.ones(shape, inputs.dtype), + 0.0, + lax.add, + pool_size, + strides, + padding, + ) + return pooled / window_counts + + +def _convert_to_lax_conv_dimension_numbers( + num_spatial_dims, + data_format="channels_last", + transpose=False, +): + """Create a `lax.ConvDimensionNumbers` for the given inputs.""" + num_dims = num_spatial_dims + 2 + + if data_format == "channels_last": + spatial_dims = tuple(range(1, num_dims - 1)) + inputs_dn = (0, num_dims - 1) + spatial_dims + else: + spatial_dims = tuple(range(2, num_dims)) + inputs_dn = (0, 1) + spatial_dims + + if transpose: + kernel_dn = (num_dims - 2, num_dims - 1) + tuple(range(num_dims - 2)) + else: + kernel_dn = (num_dims - 1, num_dims - 2) + tuple(range(num_dims - 2)) + + return lax.ConvDimensionNumbers( + lhs_spec=inputs_dn, rhs_spec=kernel_dn, out_spec=inputs_dn + ) + + +def conv( + inputs, + kernel, + strides=1, + padding="valid", + data_format=None, + dilation_rate=1, +): + data_format = backend.standardize_data_format(data_format) + num_spatial_dims = inputs.ndim - 2 + dimension_numbers = _convert_to_lax_conv_dimension_numbers( + num_spatial_dims, + data_format, + transpose=False, + ) + strides = _convert_to_spatial_operand( + strides, + num_spatial_dims, + data_format, + include_batch_and_channels=False, + ) + dilation_rate = _convert_to_spatial_operand( + dilation_rate, + num_spatial_dims, + data_format, + include_batch_and_channels=False, + ) + if data_format == "channels_last": + channels = inputs.shape[-1] + else: + channels = inputs.shape[1] + kernel_in_channels = kernel.shape[-2] + if channels % kernel_in_channels > 0: + raise ValueError( + "The number of input channels must be evenly divisible by " + f"kernel's in_channels. Received input channels {channels} and " + f"kernel in_channels {kernel_in_channels}. " + ) + feature_group_count = channels // kernel_in_channels + return np.array( + jax.lax.conv_general_dilated( + inputs, + kernel if is_tensor(kernel) else kernel.numpy(), + strides, + padding, + rhs_dilation=dilation_rate, + dimension_numbers=dimension_numbers, + feature_group_count=feature_group_count, + ) + ) + + +def depthwise_conv( + inputs, + kernel, + strides=1, + padding="valid", + data_format=None, + dilation_rate=1, +): + data_format = backend.standardize_data_format(data_format) + num_spatial_dims = inputs.ndim - 2 + dimension_numbers = _convert_to_lax_conv_dimension_numbers( + num_spatial_dims, + data_format, + transpose=False, + ) + strides = _convert_to_spatial_operand( + strides, + num_spatial_dims, + data_format, + include_batch_and_channels=False, + ) + dilation_rate = _convert_to_spatial_operand( + dilation_rate, + num_spatial_dims, + data_format, + include_batch_and_channels=False, + ) + feature_group_count = ( + inputs.shape[-1] if data_format == "channels_last" else inputs.shape[1] + ) + kernel = np.reshape( + kernel if is_tensor(kernel) else kernel.numpy(), + kernel.shape[:-2] + (1, feature_group_count * kernel.shape[-1]), + ) + return np.array( + jax.lax.conv_general_dilated( + inputs, + kernel, + strides, + padding, + rhs_dilation=dilation_rate, + dimension_numbers=dimension_numbers, + feature_group_count=feature_group_count, + ) + ) + + +def separable_conv( + inputs, + depthwise_kernel, + pointwise_kernel, + strides=1, + padding="valid", + data_format=None, + dilation_rate=1, +): + data_format = backend.standardize_data_format(data_format) + depthwise_conv_output = depthwise_conv( + inputs, + depthwise_kernel, + strides, + padding, + data_format, + dilation_rate, + ) + return conv( + depthwise_conv_output, + pointwise_kernel, + strides=1, + padding="valid", + data_format=data_format, + dilation_rate=dilation_rate, + ) + + +def conv_transpose( + inputs, + kernel, + strides=1, + padding="valid", + output_padding=None, + data_format=None, + dilation_rate=1, +): + data_format = backend.standardize_data_format(data_format) + num_spatial_dims = inputs.ndim - 2 + padding_values = compute_conv_transpose_padding_args_for_jax( + input_shape=inputs.shape, + kernel_shape=kernel.shape, + strides=strides, + padding=padding, + output_padding=output_padding, + dilation_rate=dilation_rate, + ) + dimension_numbers = _convert_to_lax_conv_dimension_numbers( + num_spatial_dims, + data_format, + transpose=False, + ) + strides = _convert_to_spatial_operand( + strides, + num_spatial_dims, + data_format, + include_batch_and_channels=False, + ) + dilation_rate = _convert_to_spatial_operand( + dilation_rate, + num_spatial_dims, + data_format, + include_batch_and_channels=False, + ) + + return np.array( + jax.lax.conv_transpose( + inputs, + kernel if is_tensor(kernel) else kernel.numpy(), + strides, + padding=padding_values, + rhs_dilation=dilation_rate, + dimension_numbers=dimension_numbers, + transpose_kernel=True, + ) + ) + + +def one_hot(x, num_classes, axis=-1, dtype=None, sparse=False): + if sparse: + raise ValueError("Unsupported value `sparse=True` with numpy backend") + if dtype is None: + dtype = "float32" + x = convert_to_tensor(x) + input_shape = x.shape + + x = x.reshape(-1) + if not num_classes: + num_classes = np.max(x) + 1 + + batch_size = x.shape[0] + categorical = np.zeros((batch_size, num_classes), dtype=dtype) + valid_indices = x >= 0 + categorical[np.arange(batch_size)[valid_indices], x[valid_indices]] = 1 + + # First, reshape the array with the extra dimension at the end + output_shape = input_shape + (num_classes,) + categorical = np.reshape(categorical, output_shape) + + # Then, move this new dimension to the right place (according to axis) + if axis != -1: + categorical = np.moveaxis(categorical, -1, axis) + + return categorical + + +def multi_hot(x, num_classes, axis=-1, dtype=None, sparse=False): + if sparse: + raise ValueError("Unsupported value `sparse=True` with numpy backend") + x = convert_to_tensor(x) + reduction_axis = 1 if len(x.shape) > 1 else 0 + outputs = np.max( + one_hot(cast(x, "int32"), num_classes, axis=axis, dtype=dtype), + axis=reduction_axis, + ) + return outputs + + +def categorical_crossentropy(target, output, from_logits=False, axis=-1): + target = np.array(target) + output = np.array(output) + + if target.shape != output.shape: + raise ValueError( + "Arguments `target` and `output` must have the same shape. " + "Received: " + f"target.shape={target.shape}, output.shape={output.shape}" + ) + if len(target.shape) < 1: + raise ValueError( + "Arguments `target` and `output` must be at least rank 1. " + "Received: " + f"target.shape={target.shape}, output.shape={output.shape}" + ) + + if from_logits: + log_prob = log_softmax(output, axis=axis) + else: + output = output / np.sum(output, axis, keepdims=True) + output = np.clip(output, backend.epsilon(), 1.0 - backend.epsilon()) + log_prob = np.log(output) + return -np.sum(target * log_prob, axis=axis) + + +def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1): + target = np.array(target, dtype="int32") + output = np.array(output) + if len(target.shape) == len(output.shape) and target.shape[-1] == 1: + target = np.squeeze(target, axis=-1) + + if len(output.shape) < 1: + raise ValueError( + "Argument `output` must be at least rank 1. " + "Received: " + f"output.shape={output.shape}" + ) + if target.shape != output.shape[:-1]: + raise ValueError( + "Arguments `target` and `output` must have the same shape " + "up until the last dimension: " + f"target.shape={target.shape}, output.shape={output.shape}" + ) + if from_logits: + log_prob = log_softmax(output, axis=axis) + else: + output = output / np.sum(output, axis, keepdims=True) + output = np.clip(output, backend.epsilon(), 1.0 - backend.epsilon()) + log_prob = np.log(output) + target = one_hot(target, output.shape[axis], axis=axis) + return -np.sum(target * log_prob, axis=axis) + + +def binary_crossentropy(target, output, from_logits=False): + target = np.array(target) + output = np.array(output) + + if target.shape != output.shape: + raise ValueError( + "Arguments `target` and `output` must have the same shape. " + "Received: " + f"target.shape={target.shape}, output.shape={output.shape}" + ) + + if from_logits: + output = sigmoid(output) + + output = np.clip(output, backend.epsilon(), 1.0 - backend.epsilon()) + bce = target * np.log(output) + bce += (1.0 - target) * np.log(1.0 - output) + return -bce + + +def moments(x, axes, keepdims=False, synchronized=False): + if synchronized: + raise NotImplementedError( + "Argument synchronized=True is not supported with NumPy." + ) + axes = tuple(axes) if isinstance(axes, list) else axes + # The dynamic range of float16 is too limited for statistics. As a + # workaround, we simply perform the operations on float32 and convert back + # to float16 + need_cast = False + ori_dtype = backend.standardize_dtype(x.dtype) + if ori_dtype == "float16": + need_cast = True + x = cast(x, "float32") + + mean = np.mean(x, axes, keepdims=True) + + # The variance is computed using $Var = E[|x|^2] - |E[x]|^2$, It is faster + # but less numerically stable. + variance = np.mean(np.square(x), axis=axes, keepdims=True) - np.square(mean) + + if not keepdims: + mean = np.squeeze(mean, axes) + variance = np.squeeze(variance, axes) + if need_cast: + # avoid overflow and underflow when casting from float16 to float32 + mean = np.clip(mean, np.finfo(np.float16).min, np.finfo(np.float16).max) + variance = np.clip( + variance, np.finfo(np.float16).min, np.finfo(np.float16).max + ) + mean = cast(mean, ori_dtype) + variance = cast(variance, ori_dtype) + return mean, variance + + +def batch_normalization( + x, mean, variance, axis, offset=None, scale=None, epsilon=1e-3 +): + shape = [1] * len(x.shape) + shape[axis] = mean.shape[0] + mean = np.reshape(mean, shape) + variance = np.reshape(variance, shape) + + inv = 1.0 / np.sqrt(variance + epsilon) + if scale is not None: + scale = np.reshape(scale, shape) + inv = inv * scale + + res = -mean * inv + if offset is not None: + offset = np.reshape(offset, shape) + res = res + offset + + return x * inv + res + + +def ctc_loss(target, output, target_length, output_length, mask_index=0): + # Ref: https://github.com/google-deepmind/optax + # optax.ctc_loss_with_forward_probs + target = convert_to_tensor(target, dtype="int32") + output = convert_to_tensor(output) + target_length = convert_to_tensor(target_length, "int32") + output_length = convert_to_tensor(output_length, "int32") + batch_size, max_input_length, num_classes = output.shape + batch_size, max_label_length = target.shape + log_epsilon = -1e5 + + # Ensure that the dtype promotion behavior matches that of `tf.nn.ctc_loss` + dtype = backend.result_type(output.dtype, "float32") + output = output.astype(dtype) + + def _lengths_to_paddings(lengths, max_length): + indices = np.arange(max_length).reshape( + (1,) * lengths.ndim + (max_length,) + ) + lengths = np.expand_dims(lengths, axis=-1) + elem_valid = indices < lengths + return np.logical_not(elem_valid) + + target_paddings = _lengths_to_paddings(target_length, max_label_length) + output_paddings = _lengths_to_paddings(output_length, max_input_length) + target_paddings = target_paddings.astype(output.dtype) + output_paddings = output_paddings.astype(output.dtype) + + logprobs = log_softmax(output, axis=-1) + label_lengths = max_label_length - np.sum(target_paddings, axis=1).astype( + np.int32 + ) + + # repeat[b, n] == 1.0 when label[b, n] == label[b, n+1]. + repeat = (target[:, :-1] == target[:, 1:]).astype(np.float32) + repeat = np.pad(repeat, ((0, 0), (0, 1))) + + logprobs_phi = logprobs[:, :, mask_index : mask_index + 1] # [B, T, 1] + logprobs_phi = np.transpose(logprobs_phi, (1, 0, 2)) # [T, B, 1] + + _one_hot = one_hot(target, num_classes=num_classes) # [B, N, K] + logprobs_emit = np.einsum("btk,bnk->btn", logprobs, _one_hot) + logprobs_emit = np.transpose(logprobs_emit, (1, 0, 2)) # [T, B, N] + + # [B, N] + logalpha_phi_init = ( + np.ones((batch_size, max_label_length + 1), dtype=output.dtype) + * log_epsilon + ) + logalpha_phi_init[:, 0] = 0.0 + logalpha_emit_init = ( + np.ones((batch_size, max_label_length), dtype=output.dtype) + * log_epsilon + ) + + def update_phi_score(phi, added_score): + # Update `phi[:, 1:]`` with adding `added_score` in log space. + return np.concatenate( + [phi[:, :1], np.logaddexp(phi[:, 1:], added_score)], axis=-1 + ) + + def loop_body(prev, x): + prev_phi, prev_emit = prev + # emit-to-phi epsilon transition, except if the next label is repetition + prev_phi_orig = prev_phi + prev_phi = update_phi_score(prev_phi, prev_emit + log_epsilon * repeat) + + logprob_emit, logprob_phi, pad = x + + # phi-to-emit transition + next_emit = np.logaddexp( + prev_phi[:, :-1] + logprob_emit, prev_emit + logprob_emit + ) + # self-loop transition + next_phi = prev_phi + logprob_phi + # emit-to-phi blank transition only when the next label is repetition + next_phi = update_phi_score( + next_phi, prev_emit + logprob_phi + log_epsilon * (1.0 - repeat) + ) + + pad = pad.reshape((batch_size, 1)) + next_emit = pad * prev_emit + (1.0 - pad) * next_emit + next_phi = pad * prev_phi_orig + (1.0 - pad) * next_phi + + return (next_phi, next_emit), (next_phi, next_emit) + + def np_scan(f, init, xs): + carry = init + ys = [] + for x in zip(*xs): + carry, y = f(carry, x) + ys.append(y) + result = [] + for i in range(len(ys[0])): + result.append(np.stack([y[i] for y in ys])) + return carry, result + + xs = (logprobs_emit, logprobs_phi, output_paddings.transpose((1, 0))) + _, (logalpha_phi, logalpha_emit) = np_scan( + loop_body, (logalpha_phi_init, logalpha_emit_init), xs + ) + + # last row needs to be updated with the last epsilon transition + logalpha_phi_last = update_phi_score(logalpha_phi[-1], logalpha_emit[-1]) + logalpha_phi[-1] = logalpha_phi_last + + # extract per_seq_loss + # [B, N+1] + _one_hot = one_hot(label_lengths, num_classes=max_label_length + 1) + per_seq_loss = -np.einsum("bn,bn->b", logalpha_phi_last, _one_hot) + return per_seq_loss + + +def _ctc_greedy_decode( + inputs, + sequence_lengths, + merge_repeated=True, + mask_index=None, +): + inputs = convert_to_tensor(inputs) + sequence_lengths = convert_to_tensor(sequence_lengths, dtype="int32") + batch_size, max_length, num_classes = inputs.shape + + if mask_index is None: + mask_index = num_classes - 1 + + indices = np.argmax(inputs, axis=-1).astype("int32") + scores = np.max(inputs, axis=-1) + + seqlen_mask = np.arange(max_length)[None, :] + seqlen_mask = seqlen_mask >= sequence_lengths[:, None] + + indices = np.where(seqlen_mask, mask_index, indices) + scores = np.where(seqlen_mask, 0.0, scores) + + if merge_repeated: + repeat_mask = indices[:, 1:] == indices[:, :-1] + repeat_mask = np.pad(repeat_mask, ((0, 0), (1, 0))) + indices = np.where(repeat_mask, mask_index, indices) + + # We set to -1 for blank labels + invalid_mask = indices == mask_index + indices = np.where(invalid_mask, -1, indices) + + # We rearrange the indices by moving `mask_index` to the end of the array + order = np.expand_dims(np.arange(max_length), axis=0) # [1, N] + order = np.tile(order, (batch_size, 1)) # [B, N] + order = np.where(invalid_mask, max_length, order) + order = np.argsort(order, axis=-1) + indices = np.take_along_axis(indices, order, axis=-1) + + scores = -np.sum(scores, axis=1)[:, None] + indices = np.expand_dims(indices, axis=0) + return indices, scores + + +def _ctc_beam_search_decode( + inputs, + sequence_lengths, + beam_width=100, + top_paths=1, + mask_index=None, +): + inputs = convert_to_tensor(inputs) + sequence_lengths = convert_to_tensor(sequence_lengths) + + batch_size, max_seq_len, num_classes = inputs.shape + inputs = log_softmax(inputs, axis=-1) + seqlen_mask = np.arange(max_seq_len)[None, :] >= sequence_lengths[:, None] + + if mask_index is None: + mask_index = num_classes - 1 + + # This is a workaround for the fact that np.argsort does not support + # the order parameter which is used to break ties when scores are equal. + # For compatibility with the tensorflow implementation, we flip the inputs + # and the mask_index, and then flip the classes back to the correct indices + inputs = np.flip(inputs, axis=2) + mask_index = num_classes - mask_index - 1 + + _pad = -1 + + init_paths = np.full( + (batch_size, 2 * beam_width, max_seq_len), _pad, dtype=np.int32 + ) + + num_init_paths = np.min(np.array([num_classes, beam_width])) + max_classes = np.argsort(inputs[:, 0], axis=1)[:, -num_init_paths:] + init_classes = np.where(max_classes == mask_index, _pad, max_classes) + init_paths[:, :num_init_paths, 0] = init_classes + + init_scores = np.full( + (batch_size, 2 * beam_width), -np.inf, dtype=inputs.dtype + ) + init_scores[:, :num_init_paths] = np.take_along_axis( + inputs[:, 0], max_classes, axis=1 + ) + init_masked = init_paths[:, :, 0] == _pad + + def _extend_paths(paths, scores, masked, x): + paths = np.repeat(paths, num_classes, axis=0) + scores = np.repeat(scores, num_classes) + masked = np.repeat(masked, num_classes) + + path_tail_index = np.argmax(paths == _pad, axis=1) + paths_arange = np.arange(2 * beam_width * num_classes) + path_tails = paths[paths_arange, path_tail_index - 1] + path_tails = np.where(path_tail_index == 0, _pad, path_tails) + + classes = np.arange(num_classes) + classes[mask_index] = _pad + classes = np.tile(classes, 2 * beam_width) + + prev_masked = masked + masked = classes == _pad + + masked_repeat = ~prev_masked & (path_tails == classes) + classes = np.where(masked_repeat, _pad, classes) + paths[paths_arange, path_tail_index] = classes + + x = np.tile(x, 2 * beam_width) + scores = scores + x + + return paths, scores, masked + + def _merge_scores(unique_inverse, scores): + scores_max = np.max(scores) + scores_exp = np.exp(scores - scores_max) + scores = np.zeros_like(scores) + for i, u in enumerate(unique_inverse): + scores[u] += scores_exp[i] + scores = np.log(scores) + scores_max + return scores + + def _prune_paths(paths, scores, masked): + paths, unique_inverse = np.unique(paths, return_inverse=True, axis=0) + pad_size = (2 * num_classes * beam_width) - len(paths) + if pad_size > 0: + paths = np.pad(paths, [[0, pad_size], [0, 0]], constant_values=_pad) + paths = paths[: 2 * num_classes * beam_width] + if len(unique_inverse.shape) >= 2: + unique_inverse = np.squeeze(unique_inverse, axis=1) + + emit_scores = np.where(masked, -np.inf, scores) + mask_scores = np.where(masked, scores, -np.inf) + + emit_scores = _merge_scores(unique_inverse, emit_scores) + mask_scores = _merge_scores(unique_inverse, mask_scores) + + total_scores = np.logaddexp(emit_scores, mask_scores) + top_indices = np.argsort(total_scores, kind="stable")[-beam_width:] + + paths = paths[top_indices] + emit_scores = emit_scores[top_indices] + mask_scores = mask_scores[top_indices] + + paths = np.tile(paths, (2, 1)) + scores = np.concatenate([emit_scores, mask_scores]) + masked = np.concatenate( + [np.zeros(beam_width, bool), np.ones(beam_width, bool)] + ) + + return paths, scores, masked + + def _decode_step(paths, scores, masked, x): + paths, scores, masked = _extend_paths(paths, scores, masked, x) + paths, scores, masked = _prune_paths(paths, scores, masked) + return paths, scores, masked + + def _step(prev, x): + paths, scores, masked = prev + x, seqlen_mask = x + if not seqlen_mask: + paths, scores, masked = _decode_step(paths, scores, masked, x) + return (paths, scores, masked), None + + def _decode_batch( + init_paths, init_scores, init_masked, inputs, seqlen_mask + ): + def np_scan_only_carry(f, init, xs): + carry = init + for x in zip(*xs): + carry, y = f(carry, x) + return carry, None + + (paths, scores, masked), _ = np_scan_only_carry( + _step, + (init_paths, init_scores, init_masked), + (inputs[1:], seqlen_mask[1:]), + ) + + paths, unique_inverse = np.unique(paths, return_inverse=True, axis=0) + pad_size = (2 * num_classes * beam_width) - len(paths) + if pad_size > 0: + paths = np.pad(paths, [[0, pad_size], [0, 0]], constant_values=_pad) + paths = paths[: 2 * num_classes * beam_width] + if len(unique_inverse.shape) >= 2: + unique_inverse = np.squeeze(unique_inverse, axis=1) + scores = _merge_scores(unique_inverse, scores) + + top_indices = np.argsort(scores)[-top_paths:][::-1] + paths = paths[top_indices] + scores = scores[top_indices] + + return paths, scores + + results = [ + _decode_batch(p, s, m, i, sm) + for p, s, m, i, sm in zip( + init_paths, init_scores, init_masked, inputs, seqlen_mask + ) + ] + paths = np.stack([r[0] for r in results]) + scores = np.stack([r[1] for r in results]) + + # convert classes back to the correct indices + paths = np.where(paths == _pad, _pad, num_classes - paths - 1) + paths = np.transpose(paths, [1, 0, 2]) + return paths, scores + + +def ctc_decode( + inputs, + sequence_lengths, + strategy="greedy", + beam_width=100, + top_paths=1, + merge_repeated=True, + mask_index=0, +): + inputs = convert_to_tensor(inputs) + dtype = backend.result_type(inputs.dtype, "float32") + inputs = cast(inputs, dtype) + + if strategy == "greedy": + return _ctc_greedy_decode( + inputs, + sequence_lengths, + merge_repeated=merge_repeated, + mask_index=mask_index, + ) + elif strategy == "beam_search": + return _ctc_beam_search_decode( + inputs, + sequence_lengths, + beam_width=beam_width, + top_paths=top_paths, + mask_index=mask_index, + ) + else: + raise ValueError( + f"Invalid strategy {strategy}. Supported values are " + "'greedy' and 'beam_search'." + ) + + +def psnr(x1, x2, max_val): + if x1.shape != x2.shape: + raise ValueError( + f"Input shapes {x1.shape} and {x2.shape} must " + "match for PSNR calculation. " + ) + + max_val = convert_to_tensor(max_val, dtype=x2.dtype) + mse = np.mean(np.square(x1 - x2)) + psnr = 20 * np.log10(max_val) - 10 * np.log10(mse) + return psnr + + +def _get_large_negative(dtype): + dtype = backend.standardize_dtype(dtype) + val = 65500.0 if dtype == "float16" else 3.38953e38 + return np.asarray(val * -0.7, dtype=dtype) + + +def _apply_masks(logits, mask, is_causal): + if mask is None and not is_causal: + return logits + + combined_mask = np.ones_like(logits, dtype=np.bool_) + if mask is not None: + combined_mask = np.logical_and(combined_mask, mask) + + if is_causal: + T, S = logits.shape[2], logits.shape[3] + mask = np.tril(np.ones((T, S), dtype=np.bool_)) + mask = mask[None, None, :, :] + combined_mask = np.logical_and(combined_mask, mask) + + padded_logits = np.where( + combined_mask, logits, _get_large_negative(logits.dtype) + ) + return padded_logits + + +def _dot_product_attention_xla(query, key, value, bias, mask, is_causal, scale): + original_dtype = key.dtype + logits_dtype = np.promote_types(query.dtype, np.float32) + if backend.standardize_dtype(key.dtype) == "bfloat16": + # `np.einsum` doesn't support bfloat16 + key = key.astype("float32") + value = value.astype("float32") + logits = np.einsum("BTNH,BSNH->BNTS", query, key) + logits = logits.astype(logits_dtype) + logits *= np.array(scale, dtype=logits.dtype) + + if bias is not None: + logits = (logits + bias).astype(logits.dtype) + + padded_logits = _apply_masks(logits, mask, is_causal) + + # Softmax and it is always carried out in fp32. + padded_logits = padded_logits.astype(np.float32) + probs = softmax(padded_logits, axis=-1).astype(original_dtype) + encoded_dtype = probs.dtype + if backend.standardize_dtype(probs.dtype) == "bfloat16": + # `np.einsum` doesn't support bfloat16 + probs = probs.astype("float32") + value = value.astype("float32") + encoded = np.einsum("BNTS,BSNH->BTNH", probs, value) + encoded = encoded.astype(encoded_dtype) + return encoded + + +def dot_product_attention( + query, + key, + value, + bias=None, + mask=None, + scale=None, + is_causal=False, + flash_attention=None, + attn_logits_soft_cap=None, +): + if flash_attention is None: + flash_attention = False + if flash_attention: + raise ValueError("Flash attention is not supported in numpy backend.") + + # Ref: jax.nn.dot_product_attention + # https://github.com/jax-ml/jax/blob/jax-v0.4.32/jax/_src/nn/functions.py#L828 + # Not support `query_seq_lengths` and `key_value_seq_lengths` args + query = convert_to_tensor(query) + key = convert_to_tensor(key) + value = convert_to_tensor(value) + if len(query.shape) != 4: + raise ValueError( + "`dot_product_attention` only supports 4D inputs. " + f"Received: query.shape={query.shape}, key.shape={key.shape}, " + f"value.shape={value.shape}." + ) + compute_dtype = backend.result_type(query.dtype, key.dtype, value.dtype) + query = cast(query, compute_dtype) + key = cast(key, compute_dtype) + value = cast(value, compute_dtype) + if bias is not None: + bias = convert_to_tensor(bias, dtype=compute_dtype) + + _, _, _, H = key.shape + scale = (1.0 / np.sqrt(H)) if scale is None else scale + return _dot_product_attention_xla( + query, key, value, bias, mask, is_causal, scale + ) + + +def unfold(input, kernel_size, dilation=1, padding=0, stride=1): + """NumPy implementation of Unfold. + Extract sliding local blocks from a **NCHW** batched image tensor. + + Args: + input: 4-D tensor, shape (N, C, H, W) **required**. + kernel_size: int or (kH, kW) + dilation: int or (dH, dW), default 1 + padding: int or (pH, pW), default 0 + stride: int or (sH, sW), default 1 + + Returns: + 3-D tensor, shape (N, C*kH*kW, L) + """ + + def _pair(x): + return (x, x) if isinstance(x, int) else x + + k = _pair(kernel_size) + d = _pair(dilation) + p = _pair(padding) + s = _pair(stride) + + N, C, H, W = input.shape + + # ---- padding ---- + if any(_ > 0 for _ in p): + input = np.pad( + input, ((0, 0), (0, 0), (p[0], p[0]), (p[1], p[1])), mode="constant" + ) + + # ---- spatial size ---- + oH = (input.shape[2] - (k[0] - 1) * d[0] - 1) // s[0] + 1 + oW = (input.shape[3] - (k[1] - 1) * d[1] - 1) // s[1] + 1 + + i0 = np.arange(0, oH) * s[0] + j0 = np.arange(0, oW) * s[1] + i, j = np.meshgrid(i0, j0, indexing="ij") # shape (oH, oW) + i = i.reshape(-1) + j = j.reshape(-1) + + # ---- flatten patches ---- + patches = np.empty((N, C, k[0], k[1], oH * oW), dtype=input.dtype) + for idx in range(k[0]): + for jdx in range(k[1]): + patches[:, :, idx, jdx, :] = input[ + :, :, i + idx * d[0], j + jdx * d[1] + ] + + # ---- reshape -> (N, C*kH*kW, L) ---- + return patches.reshape(N, C * k[0] * k[1], -1) diff --git a/keras/src/backend/numpy/numpy.py b/keras/src/backend/numpy/numpy.py new file mode 100644 index 000000000000..d8d4b8930341 --- /dev/null +++ b/keras/src/backend/numpy/numpy.py @@ -0,0 +1,1416 @@ +import numpy as np + +from keras.src import tree +from keras.src.backend import config +from keras.src.backend import standardize_dtype +from keras.src.backend.common import dtypes +from keras.src.backend.common.backend_utils import standardize_axis_for_numpy +from keras.src.backend.numpy.core import convert_to_tensor + + +def rot90(array, k=1, axes=(0, 1)): + """Rotate an array by 90 degrees in the specified plane.""" + if array.ndim < 2: + raise ValueError( + "Input array must have at least 2 dimensions. " + f"Received: array.ndim={array.ndim}" + ) + if len(axes) != 2 or axes[0] == axes[1]: + raise ValueError( + f"Invalid axes: {axes}. Axes must be a tuple " + "of two different dimensions." + ) + return np.rot90(array, k=k, axes=axes) + + +def add(x1, x2): + if not isinstance(x1, (int, float)): + x1 = convert_to_tensor(x1) + if not isinstance(x2, (int, float)): + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + ) + x1 = convert_to_tensor(x1, dtype) + x2 = convert_to_tensor(x2, dtype) + return np.add(x1, x2) + + +def einsum(subscripts, *operands, **kwargs): + operands = tree.map_structure(convert_to_tensor, operands) + dtypes_to_resolve = list(set(standardize_dtype(x.dtype) for x in operands)) + # When operands are of int8, we cast the result to int32 to align with + # the behavior of jax. + if len(dtypes_to_resolve) == 1 and dtypes_to_resolve[0] == "int8": + compute_dtype = "int32" # prevent overflow + result_dtype = "int32" + else: + result_dtype = dtypes.result_type(*dtypes_to_resolve) + compute_dtype = result_dtype + # TODO: np.einsum doesn't support bfloat16 + if compute_dtype == "bfloat16": + compute_dtype = "float32" + operands = tree.map_structure(lambda x: x.astype(compute_dtype), operands) + return np.einsum(subscripts, *operands, **kwargs).astype(result_dtype) + + +def subtract(x1, x2): + if not isinstance(x1, (int, float)): + x1 = convert_to_tensor(x1) + if not isinstance(x2, (int, float)): + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + ) + x1 = convert_to_tensor(x1, dtype) + x2 = convert_to_tensor(x2, dtype) + return np.subtract(x1, x2) + + +def matmul(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + # When both x1 and x2 are of int8, we cast the outputs to int32 to align + # with jax + x1_dtype = standardize_dtype(x1.dtype) + x2_dtype = standardize_dtype(x2.dtype) + if x1_dtype == "int8" and x2_dtype == "int8": + dtype = "int32" + else: + dtype = dtypes.result_type(x1.dtype, x2.dtype) + x1 = x1.astype(dtype) + x2 = x2.astype(dtype) + return np.matmul(x1, x2).astype(dtype) + + +def multiply(x1, x2): + if not isinstance(x1, (int, float)): + x1 = convert_to_tensor(x1) + if not isinstance(x2, (int, float)): + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + ) + x1 = convert_to_tensor(x1, dtype) + x2 = convert_to_tensor(x2, dtype) + return np.multiply(x1, x2) + + +def mean(x, axis=None, keepdims=False): + axis = standardize_axis_for_numpy(axis) + x = convert_to_tensor(x) + ori_dtype = standardize_dtype(x.dtype) + if "int" in ori_dtype or ori_dtype == "bool": + result_dtype = dtypes.result_type(x.dtype, "float32") + else: + result_dtype = ori_dtype + return np.mean(x, axis=axis, keepdims=keepdims).astype(result_dtype) + + +def max(x, axis=None, keepdims=False, initial=None): + axis = standardize_axis_for_numpy(axis) + return np.max(x, axis=axis, keepdims=keepdims, initial=initial) + + +def ones(shape, dtype=None): + dtype = dtype or config.floatx() + return np.ones(shape, dtype=dtype) + + +def zeros(shape, dtype=None): + dtype = dtype or config.floatx() + return np.zeros(shape, dtype=dtype) + + +def absolute(x): + return np.absolute(x) + + +def abs(x): + return absolute(x) + + +def all(x, axis=None, keepdims=False): + axis = standardize_axis_for_numpy(axis) + return np.all(x, axis=axis, keepdims=keepdims) + + +def angle(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = x.astype(dtype) + return np.angle(x) + + +def any(x, axis=None, keepdims=False): + axis = standardize_axis_for_numpy(axis) + return np.any(x, axis=axis, keepdims=keepdims) + + +def amax(x, axis=None, keepdims=False): + axis = standardize_axis_for_numpy(axis) + return np.amax(x, axis=axis, keepdims=keepdims) + + +def amin(x, axis=None, keepdims=False): + axis = standardize_axis_for_numpy(axis) + return np.amin(x, axis=axis, keepdims=keepdims) + + +def append(x1, x2, axis=None): + axis = standardize_axis_for_numpy(axis) + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type(x1.dtype, x2.dtype) + x1 = x1.astype(dtype) + x2 = x2.astype(dtype) + return np.append(x1, x2, axis=axis) + + +def arange(start, stop=None, step=None, dtype=None): + if dtype is None: + dtypes_to_resolve = [getattr(start, "dtype", type(start))] + if stop is not None: + dtypes_to_resolve.append(getattr(stop, "dtype", type(stop))) + if step is not None: + dtypes_to_resolve.append(getattr(step, "dtype", type(step))) + dtype = dtypes.result_type(*dtypes_to_resolve) + if stop is None: + start, stop = 0, start + if step is None: + step = 1 + return np.arange(start, stop, step=step, dtype=dtype) + + +def arccos(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = x.astype(dtype) + return np.arccos(x) + + +def arccosh(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = x.astype(dtype) + return np.arccosh(x) + + +def arcsin(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = x.astype(dtype) + return np.arcsin(x) + + +def arcsinh(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = x.astype(dtype) + return np.arcsinh(x) + + +def arctan(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = x.astype(dtype) + return np.arctan(x) + + +def arctan2(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type(x1.dtype, x2.dtype, float) + x1 = x1.astype(dtype) + x2 = x2.astype(dtype) + return np.arctan2(x1, x2) + + +def arctanh(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = x.astype(dtype) + return np.arctanh(x) + + +def argmax(x, axis=None, keepdims=False): + x = convert_to_tensor(x) + axis = standardize_axis_for_numpy(axis) + dtype = standardize_dtype(x.dtype) + if "float" not in dtype or x.ndim == 0: + return np.argmax(x, axis=axis, keepdims=keepdims).astype("int32") + + dtype = dtypes.result_type(dtype, "float32") + x = x.astype(dtype) + is_negative_zero = (x == 0.0) & np.signbit(x) + x = np.where(is_negative_zero, -np.finfo(x.dtype).tiny, x) + return np.argmax(x, axis=axis, keepdims=keepdims).astype("int32") + + +def argmin(x, axis=None, keepdims=False): + x = convert_to_tensor(x) + axis = standardize_axis_for_numpy(axis) + dtype = standardize_dtype(x.dtype) + if "float" not in dtype or x.ndim == 0: + return np.argmin(x, axis=axis, keepdims=keepdims).astype("int32") + + dtype = dtypes.result_type(dtype, "float32") + x = x.astype(dtype) + is_negative_zero = (x == 0.0) & np.signbit(x) + x = np.where(is_negative_zero, -np.finfo(x.dtype).tiny, x) + return np.argmin(x, axis=axis, keepdims=keepdims).astype("int32") + + +def argsort(x, axis=-1): + axis = standardize_axis_for_numpy(axis) + return np.argsort(x, axis=axis).astype("int32") + + +def array(x, dtype=None): + return convert_to_tensor(x, dtype=dtype) + + +def average(x, axis=None, weights=None): + axis = standardize_axis_for_numpy(axis) + x = convert_to_tensor(x) + dtypes_to_resolve = [x.dtype, float] + if weights is not None: + weights = convert_to_tensor(weights) + dtypes_to_resolve.append(weights.dtype) + dtype = dtypes.result_type(*dtypes_to_resolve) + x = x.astype(dtype) + if weights is not None: + weights = weights.astype(dtype) + return np.average(x, weights=weights, axis=axis) + + +def bartlett(x): + x = convert_to_tensor(x) + return np.bartlett(x).astype(config.floatx()) + + +def hamming(x): + x = convert_to_tensor(x) + return np.hamming(x).astype(config.floatx()) + + +def hanning(x): + x = convert_to_tensor(x) + return np.hanning(x).astype(config.floatx()) + + +def heaviside(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + + dtype = dtypes.result_type(x1.dtype, x2.dtype) + if dtype in ["int8", "int16", "int32", "uint8", "uint16", "uint32"]: + dtype = config.floatx() + elif dtype in ["int64"]: + dtype = "float64" + + return np.heaviside(x1, x2).astype(dtype) + + +def kaiser(x, beta): + x = convert_to_tensor(x) + return np.kaiser(x, beta).astype(config.floatx()) + + +def bincount(x, weights=None, minlength=0, sparse=False): + if sparse: + raise ValueError("Unsupported value `sparse=True` with numpy backend") + x = convert_to_tensor(x) + dtypes_to_resolve = [x.dtype] + if weights is not None: + weights = convert_to_tensor(weights) + dtypes_to_resolve.append(weights.dtype) + dtype = dtypes.result_type(*dtypes_to_resolve) + else: + dtype = "int32" + if len(x.shape) == 2: + if weights is None: + + def bincount_fn(arr): + return np.bincount(arr, minlength=minlength) + + bincounts = list(map(bincount_fn, x)) + else: + + def bincount_fn(arr_w): + return np.bincount( + arr_w[0], weights=arr_w[1], minlength=minlength + ) + + bincounts = list(map(bincount_fn, zip(x, weights))) + + return np.stack(bincounts).astype(dtype) + return np.bincount(x, weights, minlength).astype(dtype) + + +def bitwise_and(x, y): + x = convert_to_tensor(x) + y = convert_to_tensor(y) + dtype = dtypes.result_type(x.dtype, y.dtype) + x = x.astype(dtype) + y = y.astype(dtype) + return np.bitwise_and(x, y) + + +def bitwise_invert(x): + x = convert_to_tensor(x) + return np.bitwise_not(x) + + +def bitwise_not(x): + return bitwise_invert(x) + + +def bitwise_or(x, y): + x = convert_to_tensor(x) + y = convert_to_tensor(y) + dtype = dtypes.result_type(x.dtype, y.dtype) + x = x.astype(dtype) + y = y.astype(dtype) + return np.bitwise_or(x, y) + + +def bitwise_xor(x, y): + x = convert_to_tensor(x) + y = convert_to_tensor(y) + dtype = dtypes.result_type(x.dtype, y.dtype) + x = x.astype(dtype) + y = y.astype(dtype) + return np.bitwise_xor(x, y) + + +def bitwise_left_shift(x, y): + x = convert_to_tensor(x) + if not isinstance(y, int): + y = convert_to_tensor(y) + dtype = dtypes.result_type(x.dtype, y.dtype) + x = x.astype(dtype) + y = y.astype(dtype) + return np.left_shift(x, y) + + +def left_shift(x, y): + return bitwise_left_shift(x, y) + + +def bitwise_right_shift(x, y): + x = convert_to_tensor(x) + if not isinstance(y, int): + y = convert_to_tensor(y) + dtype = dtypes.result_type(x.dtype, y.dtype) + x = x.astype(dtype) + y = y.astype(dtype) + return np.right_shift(x, y) + + +def right_shift(x, y): + return bitwise_right_shift(x, y) + + +def blackman(x): + x = convert_to_tensor(x) + return np.blackman(x).astype(config.floatx()) + + +def broadcast_to(x, shape): + return np.broadcast_to(x, shape) + + +def cbrt(x): + x = convert_to_tensor(x) + + dtype = standardize_dtype(x.dtype) + if dtype in ["bool", "int8", "int16", "int32", "uint8", "uint16", "uint32"]: + dtype = config.floatx() + elif dtype == "int64": + dtype = "float64" + + return np.cbrt(x).astype(dtype) + + +def ceil(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = x.astype(dtype) + return np.ceil(x) + + +def clip(x, x_min, x_max): + x = convert_to_tensor(x) + dtype = standardize_dtype(x.dtype) + if dtype == "bool": + dtype = "int32" + return np.clip(x, x_min, x_max).astype(dtype) + + +def concatenate(xs, axis=0): + axis = standardize_axis_for_numpy(axis) + dtype_set = set([getattr(x, "dtype", type(x)) for x in xs]) + if len(dtype_set) > 1: + dtype = dtypes.result_type(*dtype_set) + xs = tree.map_structure( + lambda x: convert_to_tensor(x).astype(dtype), xs + ) + return np.concatenate(xs, axis=axis) + + +def conjugate(x): + return np.conjugate(x) + + +def conj(x): + return conjugate(x) + + +def copy(x): + return np.copy(x) + + +def cos(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = x.astype(dtype) + return np.cos(x) + + +def cosh(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = x.astype(dtype) + return np.cosh(x) + + +def count_nonzero(x, axis=None): + axis = standardize_axis_for_numpy(axis) + # np.count_nonzero will return python int when axis=None, so we need + # to convert_to_tensor + return convert_to_tensor(np.count_nonzero(x, axis=axis)).astype("int32") + + +def cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=None): + axis = standardize_axis_for_numpy(axis) + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type(x1.dtype, x2.dtype) + x1 = x1.astype(dtype) + x2 = x2.astype(dtype) + return np.cross( + x1, + x2, + axisa=axisa, + axisb=axisb, + axisc=axisc, + axis=axis, + ) + + +def cumprod(x, axis=None, dtype=None): + axis = standardize_axis_for_numpy(axis) + dtype = dtypes.result_type(dtype or x.dtype) + if dtype == "bool": + dtype = "int32" + return np.cumprod(x, axis=axis, dtype=dtype) + + +def cumsum(x, axis=None, dtype=None): + axis = standardize_axis_for_numpy(axis) + dtype = dtypes.result_type(dtype or x.dtype) + if dtype == "bool": + dtype = "int32" + return np.cumsum(x, axis=axis, dtype=dtype) + + +def deg2rad(x): + x = convert_to_tensor(x) + + if x.dtype in ["int64", "float64"]: + dtype = "float64" + elif x.dtype in ["bfloat16", "float16"]: + dtype = x.dtype + else: + dtype = config.floatx() + + return np.deg2rad(x).astype(dtype) + + +def diag(x, k=0): + return np.diag(x, k=k) + + +def diagflat(x, k=0): + return np.diagflat(x, k=k) + + +def diagonal(x, offset=0, axis1=0, axis2=1): + axis1 = standardize_axis_for_numpy(axis1) + axis2 = standardize_axis_for_numpy(axis2) + return np.diagonal(x, offset=offset, axis1=axis1, axis2=axis2) + + +def diff(a, n=1, axis=-1): + return np.diff(a, n=n, axis=axis) + + +def digitize(x, bins): + return np.digitize(x, bins).astype(np.int32) + + +def dot(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type(x1.dtype, x2.dtype) + x1 = x1.astype(dtype) + x2 = x2.astype(dtype) + return np.dot(x1, x2) + + +def empty(shape, dtype=None): + dtype = dtype or config.floatx() + return np.empty(shape, dtype=dtype) + + +def equal(x1, x2): + return np.equal(x1, x2) + + +def exp(x): + x = convert_to_tensor(x) + ori_dtype = standardize_dtype(x.dtype) + if "int" in ori_dtype or ori_dtype == "bool": + x = x.astype(config.floatx()) + return np.exp(x) + + +def exp2(x): + x = convert_to_tensor(x) + ori_dtype = standardize_dtype(x.dtype) + if "int" in ori_dtype or ori_dtype == "bool": + x = x.astype(config.floatx()) + return np.exp2(x) + + +def expand_dims(x, axis): + axis = standardize_axis_for_numpy(axis) + return np.expand_dims(x, axis) + + +def expm1(x): + x = convert_to_tensor(x) + ori_dtype = standardize_dtype(x.dtype) + if "int" in ori_dtype or ori_dtype == "bool": + x = x.astype(config.floatx()) + return np.expm1(x) + + +def flip(x, axis=None): + axis = standardize_axis_for_numpy(axis) + return np.flip(x, axis=axis) + + +def floor(x): + x = convert_to_tensor(x) + dtype = ( + config.floatx() + if standardize_dtype(x.dtype) == "int64" + else dtypes.result_type(x.dtype, float) + ) + x = x.astype(dtype) + return np.floor(x) + + +def full(shape, fill_value, dtype=None): + dtype = dtype or config.floatx() + return np.full(shape, fill_value, dtype=dtype) + + +def full_like(x, fill_value, dtype=None): + return np.full_like(x, fill_value, dtype=dtype) + + +def gcd(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + + dtype = dtypes.result_type(x1.dtype, x2.dtype) + return np.gcd(x1, x2).astype(dtype) + + +def greater(x1, x2): + return np.greater(x1, x2) + + +def greater_equal(x1, x2): + return np.greater_equal(x1, x2) + + +def hstack(xs): + dtype_set = set([getattr(x, "dtype", type(x)) for x in xs]) + if len(dtype_set) > 1: + dtype = dtypes.result_type(*dtype_set) + xs = tree.map_structure( + lambda x: convert_to_tensor(x).astype(dtype), xs + ) + return np.hstack(xs) + + +def hypot(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + + dtype = dtypes.result_type(x1.dtype, x2.dtype) + if dtype in ["int8", "int16", "int32", "uint8", "uint16", "uint32"]: + dtype = config.floatx() + elif dtype in ["int64"]: + dtype = "float64" + + return np.hypot(x1, x2).astype(dtype) + + +def identity(n, dtype=None): + dtype = dtype or config.floatx() + return np.identity(n, dtype=dtype) + + +def imag(x): + return np.imag(x) + + +def isclose(x1, x2, rtol=1e-5, atol=1e-8, equal_nan=False): + return np.isclose(x1, x2, rtol, atol, equal_nan) + + +def isfinite(x): + return np.isfinite(x) + + +def isin(x1, x2, assume_unique=False, invert=False): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + return np.isin(x1, x2, assume_unique=assume_unique, invert=invert) + + +def isinf(x): + return np.isinf(x) + + +def isnan(x): + return np.isnan(x) + + +def isneginf(x): + x = convert_to_tensor(x) + return np.isneginf(x) + + +def isposinf(x): + x = convert_to_tensor(x) + return np.isposinf(x) + + +def kron(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type(x1.dtype, x2.dtype) + return np.kron(x1, x2).astype(dtype) + + +def lcm(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type(x1.dtype, x2.dtype) + return np.lcm(x1, x2).astype(dtype) + + +def less(x1, x2): + return np.less(x1, x2) + + +def less_equal(x1, x2): + return np.less_equal(x1, x2) + + +def linspace( + start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0 +): + axis = standardize_axis_for_numpy(axis) + if dtype is None: + dtypes_to_resolve = [ + getattr(start, "dtype", type(start)), + getattr(stop, "dtype", type(stop)), + float, + ] + dtype = dtypes.result_type(*dtypes_to_resolve) + return np.linspace( + start, + stop, + num=num, + endpoint=endpoint, + retstep=retstep, + dtype=dtype, + axis=axis, + ) + + +def log(x): + x = convert_to_tensor(x) + dtype = ( + config.floatx() + if standardize_dtype(x.dtype) == "int64" + else dtypes.result_type(x.dtype, float) + ) + return np.log(x, dtype=dtype) + + +def log10(x): + x = convert_to_tensor(x) + dtype = ( + config.floatx() + if standardize_dtype(x.dtype) == "int64" + else dtypes.result_type(x.dtype, float) + ) + return np.log10(x, dtype=dtype) + + +def log1p(x): + x = convert_to_tensor(x) + dtype = ( + config.floatx() + if standardize_dtype(x.dtype) == "int64" + else dtypes.result_type(x.dtype, float) + ) + return np.log1p(x, dtype=dtype) + + +def log2(x): + x = convert_to_tensor(x) + dtype = ( + config.floatx() + if standardize_dtype(x.dtype) == "int64" + else dtypes.result_type(x.dtype, float) + ) + return np.log2(x, dtype=dtype) + + +def logaddexp(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type(x1.dtype, x2.dtype, float) + x1 = x1.astype(dtype) + x2 = x2.astype(dtype) + return np.logaddexp(x1, x2) + + +def logaddexp2(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type(x1.dtype, x2.dtype, float) + return np.logaddexp2(x1, x2).astype(dtype) + + +def logical_and(x1, x2): + return np.logical_and(x1, x2) + + +def logical_not(x): + return np.logical_not(x) + + +def logical_or(x1, x2): + return np.logical_or(x1, x2) + + +def logspace(start, stop, num=50, endpoint=True, base=10, dtype=None, axis=0): + if dtype is None: + dtypes_to_resolve = [ + getattr(start, "dtype", type(start)), + getattr(stop, "dtype", type(stop)), + float, + ] + dtype = dtypes.result_type(*dtypes_to_resolve) + return np.logspace( + start, + stop, + num=num, + endpoint=endpoint, + base=base, + dtype=dtype, + axis=axis, + ) + + +def maximum(x1, x2): + if not isinstance(x1, (int, float)): + x1 = convert_to_tensor(x1) + if not isinstance(x2, (int, float)): + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + ) + x1 = convert_to_tensor(x1, dtype) + x2 = convert_to_tensor(x2, dtype) + return np.maximum(x1, x2) + + +def median(x, axis=None, keepdims=False): + dtype = dtypes.result_type(x.dtype, float) + return np.median(x, axis=axis, keepdims=keepdims).astype(dtype) + + +def meshgrid(*x, indexing="xy"): + return np.meshgrid(*x, indexing=indexing) + + +def min(x, axis=None, keepdims=False, initial=None): + axis = standardize_axis_for_numpy(axis) + return np.min(x, axis=axis, keepdims=keepdims, initial=initial) + + +def minimum(x1, x2): + if not isinstance(x1, (int, float)): + x1 = convert_to_tensor(x1) + if not isinstance(x2, (int, float)): + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + ) + x1 = convert_to_tensor(x1, dtype) + x2 = convert_to_tensor(x2, dtype) + return np.minimum(x1, x2) + + +def mod(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type(x1.dtype, x2.dtype) + if dtype == "bool": + dtype = "int32" + x1 = x1.astype(dtype) + x2 = x2.astype(dtype) + return np.mod(x1, x2) + + +def moveaxis(x, source, destination): + return np.moveaxis(x, source=source, destination=destination) + + +def nan_to_num(x, nan=0.0, posinf=None, neginf=None): + return np.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf) + + +def ndim(x): + return np.ndim(x) + + +def nonzero(x): + return tuple(indices.astype("int32") for indices in np.nonzero(x)) + + +def not_equal(x1, x2): + return np.not_equal(x1, x2) + + +def zeros_like(x, dtype=None): + return np.zeros_like(x, dtype=dtype) + + +def ones_like(x, dtype=None): + return np.ones_like(x, dtype=dtype) + + +def outer(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type(x1.dtype, x2.dtype) + x1 = x1.astype(dtype) + x2 = x2.astype(dtype) + return np.outer(x1, x2) + + +def pad(x, pad_width, mode="constant", constant_values=None): + kwargs = {} + if constant_values is not None: + if mode != "constant": + raise ValueError( + "Argument `constant_values` can only be " + "provided when `mode == 'constant'`. " + f"Received: mode={mode}" + ) + kwargs["constant_values"] = constant_values + return np.pad(x, pad_width, mode=mode, **kwargs) + + +def prod(x, axis=None, keepdims=False, dtype=None): + axis = standardize_axis_for_numpy(axis) + x = convert_to_tensor(x) + if dtype is None: + dtype = dtypes.result_type(x.dtype) + if dtype in ("bool", "int8", "int16"): + dtype = "int32" + elif dtype in ("uint8", "uint16"): + dtype = "uint32" + return np.prod(x, axis=axis, keepdims=keepdims, dtype=dtype) + + +def quantile(x, q, axis=None, method="linear", keepdims=False): + axis = standardize_axis_for_numpy(axis) + x = convert_to_tensor(x) + + ori_dtype = standardize_dtype(x.dtype) + # np.quantile doesn't support bool + if ori_dtype == "bool": + x = x.astype(config.floatx()) + if ori_dtype == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + return np.quantile( + x, q, axis=axis, method=method, keepdims=keepdims + ).astype(dtype) + + +def ravel(x): + return np.ravel(x) + + +def unravel_index(indices, shape): + dtype = dtypes.result_type(indices.dtype) + return tuple( + indices.astype(dtype) for indices in np.unravel_index(indices, shape) + ) + + +def real(x): + return np.real(x) + + +def reciprocal(x): + return np.reciprocal(x) + + +def repeat(x, repeats, axis=None): + return np.repeat(x, repeats, axis=axis) + + +def reshape(x, newshape): + return np.reshape(x, newshape) + + +def roll(x, shift, axis=None): + return np.roll(x, shift, axis=axis) + + +def searchsorted(sorted_sequence, values, side="left"): + if ndim(sorted_sequence) != 1: + raise ValueError( + "`searchsorted` only supports 1-D sorted sequences. " + "You can use `keras.ops.vectorized_map` " + "to extend it to N-D sequences. Received: " + f"sorted_sequence.shape={sorted_sequence.shape}" + ) + out_type = ( + "int32" + if sorted_sequence.shape[0] <= np.iinfo(np.int32).max + else "int64" + ) + return np.searchsorted(sorted_sequence, values, side=side).astype(out_type) + + +def sign(x): + return np.sign(x) + + +def signbit(x): + return np.signbit(x) + + +def sin(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = x.astype(dtype) + return np.sin(x) + + +def sinh(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = x.astype(dtype) + return np.sinh(x) + + +def size(x): + return np.size(x) + + +def sort(x, axis=-1): + axis = standardize_axis_for_numpy(axis) + return np.sort(x, axis=axis) + + +def split(x, indices_or_sections, axis=0): + axis = standardize_axis_for_numpy(axis) + return np.split(x, indices_or_sections, axis=axis) + + +def stack(x, axis=0): + axis = standardize_axis_for_numpy(axis) + dtype_set = set([getattr(a, "dtype", type(a)) for a in x]) + if len(dtype_set) > 1: + dtype = dtypes.result_type(*dtype_set) + x = tree.map_structure(lambda a: convert_to_tensor(a).astype(dtype), x) + return np.stack(x, axis=axis) + + +def std(x, axis=None, keepdims=False): + axis = standardize_axis_for_numpy(axis) + x = convert_to_tensor(x) + ori_dtype = standardize_dtype(x.dtype) + if "int" in ori_dtype or ori_dtype == "bool": + x = x.astype(config.floatx()) + return np.std(x, axis=axis, keepdims=keepdims) + + +def swapaxes(x, axis1, axis2): + return np.swapaxes(x, axis1=axis1, axis2=axis2) + + +def take(x, indices, axis=None): + axis = standardize_axis_for_numpy(axis) + return np.take(x, indices, axis=axis) + + +def take_along_axis(x, indices, axis=None): + axis = standardize_axis_for_numpy(axis) + return np.take_along_axis(x, indices, axis=axis) + + +def tan(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = x.astype(dtype) + return np.tan(x) + + +def tanh(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = x.astype(dtype) + return np.tanh(x) + + +def tensordot(x1, x2, axes=2): + axes = tuple(axes) if isinstance(axes, list) else axes + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type(x1.dtype, x2.dtype) + x1 = x1.astype(dtype) + x2 = x2.astype(dtype) + return np.tensordot(x1, x2, axes=axes) + + +def round(x, decimals=0): + return np.round(x, decimals=decimals) + + +def tile(x, repeats): + return np.tile(x, repeats) + + +def trace(x, offset=0, axis1=0, axis2=1): + axis1 = standardize_axis_for_numpy(axis1) + axis2 = standardize_axis_for_numpy(axis2) + x = convert_to_tensor(x) + dtype = standardize_dtype(x.dtype) + if dtype not in ("int64", "uint32", "uint64"): + dtype = dtypes.result_type(dtype, "int32") + return np.trace(x, offset=offset, axis1=axis1, axis2=axis2, dtype=dtype) + + +def tri(N, M=None, k=0, dtype=None): + dtype = dtype or config.floatx() + return np.tri(N, M=M, k=k, dtype=dtype) + + +def tril(x, k=0): + return np.tril(x, k=k) + + +def triu(x, k=0): + return np.triu(x, k=k) + + +def trunc(x): + x = convert_to_tensor(x) + dtype = standardize_dtype(x.dtype) + if "int" in dtype or "bool" == dtype: + return x + return np.trunc(x) + + +def vdot(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type(x1.dtype, x2.dtype) + x1 = x1.astype(dtype) + x2 = x2.astype(dtype) + return np.vdot(x1, x2) + + +def inner(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type(x1.dtype, x2.dtype) + x1 = x1.astype(dtype) + x2 = x2.astype(dtype) + return np.inner(x1, x2) + + +def vstack(xs): + dtype_set = set([getattr(x, "dtype", type(x)) for x in xs]) + if len(dtype_set) > 1: + dtype = dtypes.result_type(*dtype_set) + xs = tree.map_structure( + lambda x: convert_to_tensor(x).astype(dtype), xs + ) + return np.vstack(xs) + + +def vectorize(pyfunc, *, excluded=None, signature=None): + return np.vectorize(pyfunc, excluded=excluded, signature=signature) + + +def where(condition, x1=None, x2=None): + if x1 is not None and x2 is not None: + if not isinstance(x1, (int, float)): + x1 = convert_to_tensor(x1) + if not isinstance(x2, (int, float)): + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + ) + x1 = convert_to_tensor(x1, dtype) + x2 = convert_to_tensor(x2, dtype) + return np.where(condition, x1, x2) + else: + return np.where(condition) + + +def divide(x1, x2): + if not isinstance(x1, (int, float)): + x1 = convert_to_tensor(x1) + if not isinstance(x2, (int, float)): + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + float, + ) + x1 = convert_to_tensor(x1, dtype) + x2 = convert_to_tensor(x2, dtype) + return np.divide(x1, x2) + + +def divide_no_nan(x1, x2): + if not isinstance(x1, (int, float)): + x1 = convert_to_tensor(x1) + if not isinstance(x2, (int, float)): + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + float, + ) + x1 = convert_to_tensor(x1, dtype) + x2 = convert_to_tensor(x2, dtype) + # No need for the double-where trick since we don't calculate gradients in + # numpy backend. + return np.where(x2 == 0, np.array(0, dtype=dtype), np.divide(x1, x2)) + + +def true_divide(x1, x2): + return divide(x1, x2) + + +def power(x1, x2): + if not isinstance(x1, (int, float)): + x1 = convert_to_tensor(x1) + if not isinstance(x2, (int, float)): + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + ) + x1 = convert_to_tensor(x1, dtype) + x2 = convert_to_tensor(x2, dtype) + return np.power(x1, x2) + + +def negative(x): + return np.negative(x) + + +def square(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "bool": + x = x.astype("int32") + return np.square(x) + + +def sqrt(x): + x = convert_to_tensor(x) + # upcast to float64 for int64 which matches JAX's behavior + dtype = ( + config.floatx() + if standardize_dtype(x.dtype) == "int64" + else dtypes.result_type(x.dtype, float) + ) + return np.sqrt(x, dtype=dtype) + + +def squeeze(x, axis=None): + axis = standardize_axis_for_numpy(axis) + return np.squeeze(x, axis=axis) + + +def transpose(x, axes=None): + axes = tuple(axes) if isinstance(axes, list) else axes + return np.transpose(x, axes=axes) + + +def var(x, axis=None, keepdims=False): + axis = standardize_axis_for_numpy(axis) + x = convert_to_tensor(x) + compute_dtype = dtypes.result_type(x.dtype, "float32") + result_dtype = dtypes.result_type(x.dtype, float) + return np.var(x, axis=axis, keepdims=keepdims, dtype=compute_dtype).astype( + result_dtype + ) + + +def sum(x, axis=None, keepdims=False): + axis = standardize_axis_for_numpy(axis) + dtype = standardize_dtype(x.dtype) + # follow jax's rule + if dtype in ("bool", "int8", "int16"): + dtype = "int32" + elif dtype in ("uint8", "uint16"): + dtype = "uint32" + return np.sum(x, axis=axis, keepdims=keepdims).astype(dtype) + + +def eye(N, M=None, k=0, dtype=None): + dtype = dtype or config.floatx() + return np.eye(N, M=M, k=k, dtype=dtype) + + +def floor_divide(x1, x2): + if not isinstance(x1, (int, float)): + x1 = convert_to_tensor(x1) + if not isinstance(x2, (int, float)): + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), getattr(x2, "dtype", type(x2)) + ) + x1 = convert_to_tensor(x1, dtype) + x2 = convert_to_tensor(x2, dtype) + return np.floor_divide(x1, x2) + + +def logical_xor(x1, x2): + return np.logical_xor(x1, x2) + + +def corrcoef(x): + if x.dtype in ["int64", "float64"]: + dtype = "float64" + elif x.dtype in ["bfloat16", "float16"]: + dtype = x.dtype + else: + dtype = config.floatx() + + x = convert_to_tensor(x) + + return np.corrcoef(x).astype(dtype) + + +def correlate(x1, x2, mode="valid"): + dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + ) + if dtype == "int64": + dtype = "float64" + elif dtype not in ["bfloat16", "float16", "float64"]: + dtype = "float32" + + x1 = convert_to_tensor(x1, dtype) + x2 = convert_to_tensor(x2, dtype) + return np.correlate(x1, x2, mode) + + +def select(condlist, choicelist, default=0): + return np.select(condlist, choicelist, default=default) + + +def slogdet(x): + return tuple(np.linalg.slogdet(x)) + + +def argpartition(x, kth, axis=-1): + return np.argpartition(x, kth, axis).astype("int32") + + +def histogram(x, bins=10, range=None): + return np.histogram(x, bins=bins, range=range) diff --git a/keras/src/backend/numpy/random.py b/keras/src/backend/numpy/random.py new file mode 100644 index 000000000000..f8fd65aa38ba --- /dev/null +++ b/keras/src/backend/numpy/random.py @@ -0,0 +1,120 @@ +import numpy as np + +from keras.src.backend.config import floatx +from keras.src.backend.numpy.nn import softmax +from keras.src.random.seed_generator import SeedGenerator +from keras.src.random.seed_generator import draw_seed +from keras.src.random.seed_generator import make_default_seed + + +def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): + dtype = dtype or floatx() + seed = draw_seed(seed) + rng = np.random.default_rng(seed) + return rng.normal(size=shape, loc=mean, scale=stddev).astype(dtype) + + +def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None): + dtype = dtype or floatx() + seed = draw_seed(seed) + rng = np.random.default_rng(seed) + return rng.uniform(size=shape, low=minval, high=maxval).astype(dtype) + + +def categorical(logits, num_samples, dtype="int64", seed=None): + seed = draw_seed(seed) + rng = np.random.default_rng(seed) + output = [] + for logits_instance in logits: + probabilities = softmax(logits_instance) + classes = np.arange(logits_instance.shape[-1]) + samples = rng.choice(classes, size=num_samples, p=probabilities) + output.append(samples) + return np.array(output).astype(dtype) + + +def randint(shape, minval, maxval, dtype="int32", seed=None): + seed = draw_seed(seed) + rng = np.random.default_rng(seed) + output = rng.integers(low=minval, high=maxval, size=shape, dtype=dtype) + return output + + +def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): + dtype = dtype or floatx() + seed = draw_seed(seed) + rng = np.random.default_rng(seed) + + lower_bound = mean - 2 * stddev + upper_bound = mean + 2 * stddev + + flat_shape = np.prod(shape) + random_numbers = np.empty(0) + + # loop until we have enough valid numbers to fill our desired shape + while random_numbers.shape[0] < flat_shape: + # Generate a batch of random numbers from a normal distribution + batch = rng.normal(loc=mean, scale=stddev, size=flat_shape) + + # Filter the numbers to keep only those within the specified bounds + valid = batch[(batch >= lower_bound) & (batch <= upper_bound)] + + # Append the valid numbers to the result array + random_numbers = np.append(random_numbers, valid) + + # Truncate the result array to the desired size and reshape it + return random_numbers[:flat_shape].astype(dtype).reshape(shape) + + +def dropout(inputs, rate, noise_shape=None, seed=None): + dtype = inputs.dtype + seed = draw_seed(seed) + + keep_prob = 1.0 - rate + + # If noise_shape is not provided, use the shape of inputs + if noise_shape is None: + noise_shape = inputs.shape + else: + # If noise_shape is provided, replace None with corresponding + # input shape + noise_shape = [ + n if n is not None else inputs.shape[i] + for i, n in enumerate(noise_shape) + ] + + rng = np.random.default_rng(seed) + mask = rng.uniform(size=noise_shape) < keep_prob + mask = np.broadcast_to(mask, inputs.shape) + return np.where( + mask, (inputs / keep_prob).astype(dtype), np.zeros_like(inputs) + ) + + +def shuffle(x, axis=0, seed=None): + seed = draw_seed(seed) + rng = np.random.default_rng(seed) + return rng.permuted(x, axis=axis) + + +def gamma(shape, alpha, dtype=None, seed=None): + dtype = dtype or floatx() + seed = draw_seed(seed) + rng = np.random.default_rng(seed) + return rng.gamma(alpha, scale=1.0, size=shape).astype(dtype) + + +def binomial(shape, counts, probabilities, dtype=None, seed=None): + dtype = dtype or floatx() + seed = draw_seed(seed) + rng = np.random.default_rng(seed) + sample = rng.binomial(n=counts, p=probabilities, size=shape).astype(dtype) + return sample + + +def beta(shape, alpha, beta, dtype=None, seed=None): + dtype = dtype or floatx() + seed = draw_seed(seed) + rng = np.random.default_rng(seed) + sample = rng.beta(a=alpha, b=beta, size=shape).astype(dtype) + return sample diff --git a/keras/src/backend/numpy/rnn.py b/keras/src/backend/numpy/rnn.py new file mode 100644 index 000000000000..7a3f990112dc --- /dev/null +++ b/keras/src/backend/numpy/rnn.py @@ -0,0 +1,243 @@ +import numpy as np + +from keras.src import tree + + +def rnn( + step_function, + inputs, + initial_states, + go_backwards=False, + mask=None, + constants=None, + unroll=False, + input_length=None, + time_major=False, + zero_output_for_mask=False, + return_all_outputs=True, +): + def swap_batch_timestep(input_t): + # Swap the batch and timestep dim for the incoming tensor. + axes = list(range(len(input_t.shape))) + axes[0], axes[1] = 1, 0 + return np.transpose(input_t, axes) + + if not time_major: + inputs = tree.map_structure(swap_batch_timestep, inputs) + + flattened_inputs = tree.flatten(inputs) + time_steps = flattened_inputs[0].shape[0] + + if mask is not None: + if mask.dtype != "bool": + mask = mask.astype("bool") + if len(mask.shape) == 2: + mask = np.expand_dims(mask, axis=-1) + if not time_major: + mask = swap_batch_timestep(mask) + + if constants is None: + constants = [] + + def _expand_mask(mask_t, input_t, fixed_dim=1): + if tree.is_nested(mask_t): + raise ValueError( + f"mask_t is expected to be tensor, but got {mask_t}" + ) + if tree.is_nested(input_t): + raise ValueError( + f"input_t is expected to be tensor, but got {input_t}" + ) + rank_diff = len(input_t.shape) - len(mask_t.shape) + for _ in range(rank_diff): + mask_t = np.expand_dims(mask_t, -1) + multiples = [1] * fixed_dim + list(input_t.shape[fixed_dim:]) + return np.tile(mask_t, multiples) + + if unroll: + if not time_steps: + raise ValueError("Unrolling requires a fixed number of timesteps.") + states = tuple(initial_states) + successive_states = [] + successive_outputs = [] + + # Process the input tensors. The input tensor need to be split on the + # time_step dim, and reverse if go_backwards is True. In the case of + # nested input, the input is flattened and then transformed + # individually. The result of this will be a tuple of lists, each of + # the item in tuple is list of the tensor with shape (batch, feature) + def _process_single_input_t(input_t): + input_t = unstack(input_t) # unstack for time_step dim + if go_backwards: + input_t.reverse() + return input_t + + if tree.is_nested(inputs): + processed_input = tree.map_structure( + _process_single_input_t, inputs + ) + else: + processed_input = (_process_single_input_t(inputs),) + + def _get_input_tensor(time): + inp = [t_[time] for t_ in processed_input] + return tree.pack_sequence_as(inputs, inp) + + if mask is not None: + mask_list = unstack(mask) + if go_backwards: + mask_list.reverse() + + for i in range(time_steps): + inp = _get_input_tensor(i) + mask_t = mask_list[i] + output, new_states = step_function( + inp, tuple(states) + tuple(constants) + ) + tiled_mask_t = _expand_mask(mask_t, output) + + if not successive_outputs: + prev_output = np.zeros_like(output) + else: + prev_output = successive_outputs[-1] + + output = np.where(tiled_mask_t, output, prev_output) + + flat_states = tree.flatten(states) + flat_new_states = tree.flatten(new_states) + tiled_mask_t = tuple( + _expand_mask(mask_t, s) for s in flat_states + ) + flat_final_states = tuple( + np.where(m, s, ps) + for m, s, ps in zip( + tiled_mask_t, flat_new_states, flat_states + ) + ) + states = tree.pack_sequence_as(states, flat_final_states) + + if return_all_outputs: + successive_outputs.append(output) + successive_states.append(states) + else: + successive_outputs = [output] + successive_states = [states] + last_output = successive_outputs[-1] + new_states = successive_states[-1] + outputs = np.stack(successive_outputs) + + else: # mask is None + for i in range(time_steps): + inp = _get_input_tensor(i) + output, states = step_function( + inp, tuple(states) + tuple(constants) + ) + if return_all_outputs: + successive_outputs.append(output) + successive_states.append(states) + else: + successive_outputs = [output] + successive_states = [states] + last_output = successive_outputs[-1] + new_states = successive_states[-1] + outputs = np.stack(successive_outputs) + + else: # Unroll == False + if mask is not None: + + def _step(states, current_input): + current_input, current_mask = current_input + is_masked = np.all( + np.logical_not(current_mask), axis=-1, keepdims=True + ) + + output_t, new_states = step_function(current_input, states) + + if zero_output_for_mask: + masked_outs = np.where( + is_masked, np.zeros_like(output_t), output_t + ) + else: + # Assume the first state is the previous output. + output_tm1 = states[0] + if tree.is_nested(output_tm1): + # Stacked RNN case: assume first state of last cell. + output_tm1 = states[-1][0] + masked_outs = np.where(is_masked, output_tm1, output_t) + + new_states = tree.map_structure( + lambda s, ns: np.where(is_masked, s, ns), + states, + new_states, + ) + return (new_states, masked_outs) + + scan_xs = (inputs, mask) + + else: + + def _step(states, current_input): + output_t, new_states = step_function(current_input, states) + return new_states, output_t + + scan_xs = inputs + + new_states, outputs = numpy_scan( + f=_step, + init=initial_states, + xs=scan_xs, + reverse=go_backwards, + mask=mask, + ) + + if go_backwards: + outputs = np.flip(outputs, axis=0) + last_output = outputs[-1] + + if not time_major: + outputs = tree.map_structure(swap_batch_timestep, outputs) + + return last_output, outputs, new_states + + +def lstm(*args, **kwargs): + raise NotImplementedError + + +def gru(*args, **kwargs): + raise NotImplementedError + + +def unstack(x, axis=0): + return [x.take(i, axis) for i in range(x.shape[axis])] + + +def numpy_scan(f, init, xs, reverse=False, mask=None): + states = init + outputs = [] + + if mask is not None: + x, mask = xs + x = np.flip(x, axis=0) if reverse else x + mask = np.flip(mask, axis=0) if reverse else mask + + for each_x, each_mask in zip(x, mask): + states, output = f(states, (each_x, each_mask)) + outputs.append(output) + else: + xs = np.flip(xs, axis=0) if reverse else xs + + for x in xs: + states, output = f(states, x) + outputs.append(output) + + outputs = np.array(outputs) + + if reverse: + outputs = np.flip(outputs, axis=0) + + return states, outputs + + +def cudnn_ok(*args, **kwargs): + return False diff --git a/keras/src/backend/numpy/trainer.py b/keras/src/backend/numpy/trainer.py new file mode 100644 index 000000000000..fd8c276a86d2 --- /dev/null +++ b/keras/src/backend/numpy/trainer.py @@ -0,0 +1,331 @@ +import numpy as np + +from keras.src import backend +from keras.src import callbacks as callbacks_module +from keras.src import tree +from keras.src.backend.common import standardize_dtype +from keras.src.backend.common.keras_tensor import KerasTensor +from keras.src.backend.numpy.core import is_tensor +from keras.src.trainers import trainer as base_trainer +from keras.src.trainers.data_adapters import data_adapter_utils +from keras.src.trainers.epoch_iterator import EpochIterator +from keras.src.utils import traceback_utils + + +class NumpyTrainer(base_trainer.Trainer): + def __init__(self): + super().__init__() + self.test_function = None + self.predict_function = None + + def test_step(self, data): + ( + x, + y, + sample_weight, + ) = data_adapter_utils.unpack_x_y_sample_weight(data) + if self._call_has_training_arg: + y_pred = self(x, training=False) + else: + y_pred = self(x) + loss = self._compute_loss( + x=x, y=y, y_pred=y_pred, sample_weight=sample_weight, training=False + ) + self._loss_tracker.update_state( + loss, sample_weight=tree.flatten(x)[0].shape[0] + ) + return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight) + + def predict_step(self, data): + x, _, _ = data_adapter_utils.unpack_x_y_sample_weight(data) + if self._call_has_training_arg: + y_pred = self(x, training=False) + else: + y_pred = self(x) + return y_pred + + def make_test_function(self, force=False): + if self.test_function is not None and not force: + return self.test_function + + def one_test_step(data): + data = data[0] + return self.test_step(data) + + def multi_test_steps(data): + for single_step_data in data: + logs = one_test_step([single_step_data]) + return logs + + if self.steps_per_execution > 1: + test_step = multi_test_steps + else: + test_step = one_test_step + + self.test_function = test_step + + def make_predict_function(self, force=False): + if self.predict_function is not None and not force: + return self.predict_function + + def one_predict_step(data): + data = data[0] + return self.predict_step(data) + + def multi_predict_steps(data): + outputs = one_predict_step(data[:1]) + + for single_step_data in data[1:]: + step_outputs = one_predict_step([single_step_data]) + outputs = tree.map_structure( + lambda t1, t2: np.concatenate([t1, t2]), + outputs, + step_outputs, + ) + return outputs + + if self.steps_per_execution > 1: + predict_step = multi_predict_steps + else: + predict_step = one_predict_step + + self.predict_function = predict_step + + def _symbolic_build(self, data_batch): + model_unbuilt = not all(layer.built for layer in self._flatten_layers()) + compile_metrics_unbuilt = ( + self._compile_metrics is not None + and not self._compile_metrics.built + ) + compile_loss_unbuilt = ( + self._compile_loss is not None and not self._compile_loss.built + ) + if model_unbuilt or compile_metrics_unbuilt or compile_loss_unbuilt: + # Create symbolic tensors matching an input batch. + + def to_symbolic_input(v): + if is_tensor(v): + return KerasTensor(v.shape, standardize_dtype(v.dtype)) + return v + + data_batch = tree.map_structure(to_symbolic_input, data_batch) + ( + x, + y, + sample_weight, + ) = data_adapter_utils.unpack_x_y_sample_weight(data_batch) + # Build all model state with `backend.compute_output_spec`. + try: + y_pred = backend.compute_output_spec(self, x) + except: + raise RuntimeError( + "Unable to automatically build the model. " + "Please build it yourself before calling " + "fit/evaluate/predict. " + "A model is 'built' when its variables have " + "been created and its `self.built` attribute " + "is True. Usually, calling the model on a batch " + "of data is the right way to build it." + ) + if compile_metrics_unbuilt: + # Build all metric state with `backend.compute_output_spec`. + backend.compute_output_spec( + self.compute_metrics, + x, + y, + y_pred, + sample_weight=sample_weight, + ) + if compile_loss_unbuilt: + # Build `CompileLoss` state with `backend.compute_output_spec`. + backend.compute_output_spec( + self._compute_loss, + x, + y, + y_pred, + sample_weight=sample_weight, + ) + self._post_build() + + def fit( + self, + x=None, + y=None, + batch_size=None, + epochs=1, + verbose="auto", + callbacks=None, + validation_split=0.0, + validation_data=None, + shuffle=True, + class_weight=None, + sample_weight=None, + initial_epoch=0, + steps_per_epoch=None, + validation_steps=None, + validation_batch_size=None, + validation_freq=1, + ): + raise NotImplementedError("fit not implemented for NumPy backend.") + + @traceback_utils.filter_traceback + def predict( + self, x, batch_size=None, verbose="auto", steps=None, callbacks=None + ): + # Create an iterator that yields batches of input data. + epoch_iterator = EpochIterator( + x=x, + batch_size=batch_size, + steps_per_epoch=steps, + shuffle=False, + steps_per_execution=self.steps_per_execution, + ) + + # Container that configures and calls callbacks. + if not isinstance(callbacks, callbacks_module.CallbackList): + callbacks = callbacks_module.CallbackList( + callbacks, + add_progbar=verbose != 0, + verbose=verbose, + epochs=1, + steps=epoch_iterator.num_batches, + model=self, + ) + + def append_to_outputs(batch_outputs, outputs): + if outputs is None: + outputs = tree.map_structure( + lambda batch_output: [batch_output], + batch_outputs, + ) + else: + tree.map_structure_up_to( + batch_outputs, + lambda output, batch_output: output.append(batch_output), + outputs, + batch_outputs, + ) + return outputs + + self.make_predict_function() + self.stop_predicting = False + callbacks.on_predict_begin() + outputs = None + for begin_step, end_step, data in epoch_iterator: + callbacks.on_predict_batch_begin(begin_step) + batch_outputs = self.predict_function(data) + outputs = append_to_outputs(batch_outputs, outputs) + callbacks.on_predict_batch_end(end_step, {"outputs": batch_outputs}) + if self.stop_predicting: + break + callbacks.on_predict_end() + return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs) + + @traceback_utils.filter_traceback + def evaluate( + self, + x=None, + y=None, + batch_size=None, + verbose="auto", + sample_weight=None, + steps=None, + callbacks=None, + return_dict=False, + **kwargs, + ): + # TODO: respect compiled trainable state + use_cached_eval_dataset = kwargs.pop("_use_cached_eval_dataset", False) + if kwargs: + raise ValueError(f"Arguments not recognized: {kwargs}") + + if use_cached_eval_dataset: + epoch_iterator = self._eval_epoch_iterator + else: + # Create an iterator that yields batches of input/target data. + epoch_iterator = EpochIterator( + x=x, + y=y, + sample_weight=sample_weight, + batch_size=batch_size, + steps_per_epoch=steps, + shuffle=False, + steps_per_execution=self.steps_per_execution, + ) + + if not all(layer.built for layer in self._flatten_layers()): + # Build the model on one batch of data. + for _, _, data in epoch_iterator: + data_batch = data[0] + self._symbolic_build(data_batch) + break + + # Container that configures and calls callbacks. + if not isinstance(callbacks, callbacks_module.CallbackList): + callbacks = callbacks_module.CallbackList( + callbacks, + add_progbar=verbose != 0, + verbose=verbose, + epochs=1, + steps=epoch_iterator.num_batches, + model=self, + ) + + self.make_test_function() + self.stop_evaluating = False + callbacks.on_test_begin() + logs = {} + self.reset_metrics() + for begin_step, end_step, data in epoch_iterator: + callbacks.on_test_batch_begin(begin_step) + logs = self.test_function(data) + callbacks.on_test_batch_end(end_step, logs) + if self.stop_evaluating: + break + logs = self._get_metrics_result_or_logs(logs) + callbacks.on_test_end(logs) + + if return_dict: + return logs + return self._flatten_metrics_in_order(logs) + + def train_on_batch( + self, + x, + y=None, + sample_weight=None, + class_weight=None, + return_dict=False, + ): + raise NotImplementedError( + "train_on_batch not implemented for NumPy backend." + ) + + def test_on_batch( + self, + x, + y=None, + sample_weight=None, + return_dict=False, + ): + self._assert_compile_called("test_on_batch") + + data = (x, y, sample_weight) + + # Maybe build model + self._symbolic_build(data) + self.make_test_function() + + logs = self.test_function([data]) + logs = tree.map_structure(lambda x: np.array(x), logs) + if return_dict: + return logs + return self._flatten_metrics_in_order(logs) + + def predict_on_batch(self, x): + self.make_predict_function() + batch_outputs = self.predict_function([(x,)]) + batch_outputs = tree.map_structure( + backend.convert_to_numpy, batch_outputs + ) + return batch_outputs diff --git a/keras/src/backend/openvino/__init__.py b/keras/src/backend/openvino/__init__.py new file mode 100644 index 000000000000..0612260452ea --- /dev/null +++ b/keras/src/backend/openvino/__init__.py @@ -0,0 +1,25 @@ +from keras.src.backend.common.name_scope import name_scope +from keras.src.backend.openvino import core +from keras.src.backend.openvino import image +from keras.src.backend.openvino import linalg +from keras.src.backend.openvino import math +from keras.src.backend.openvino import nn +from keras.src.backend.openvino import numpy +from keras.src.backend.openvino import random +from keras.src.backend.openvino.core import IS_THREAD_SAFE +from keras.src.backend.openvino.core import SUPPORTS_RAGGED_TENSORS +from keras.src.backend.openvino.core import SUPPORTS_SPARSE_TENSORS +from keras.src.backend.openvino.core import Variable +from keras.src.backend.openvino.core import cast +from keras.src.backend.openvino.core import compute_output_spec +from keras.src.backend.openvino.core import cond +from keras.src.backend.openvino.core import convert_to_numpy +from keras.src.backend.openvino.core import convert_to_tensor +from keras.src.backend.openvino.core import is_tensor +from keras.src.backend.openvino.core import random_seed_dtype +from keras.src.backend.openvino.core import shape +from keras.src.backend.openvino.core import vectorized_map +from keras.src.backend.openvino.rnn import cudnn_ok +from keras.src.backend.openvino.rnn import gru +from keras.src.backend.openvino.rnn import lstm +from keras.src.backend.openvino.rnn import rnn diff --git a/keras/src/backend/openvino/core.py b/keras/src/backend/openvino/core.py new file mode 100644 index 000000000000..93f9f5819c8b --- /dev/null +++ b/keras/src/backend/openvino/core.py @@ -0,0 +1,1187 @@ +import builtins +import contextlib +import warnings + +import numpy as np +import openvino as ov +import openvino.opset14 as ov_opset +from openvino import Model +from openvino import Tensor +from openvino import Type +from openvino import compile_model + +from keras.src import tree +from keras.src.backend.common import KerasVariable +from keras.src.backend.common import dtypes +from keras.src.backend.common import global_state +from keras.src.backend.common import standardize_dtype +from keras.src.backend.common.dtypes import result_type +from keras.src.backend.common.keras_tensor import KerasTensor +from keras.src.backend.common.stateless_scope import StatelessScope + +SUPPORTS_SPARSE_TENSORS = False +SUPPORTS_RAGGED_TENSORS = False +IS_THREAD_SAFE = True + +OPENVINO_DTYPES = { + "float16": ov.Type.f16, + "float32": ov.Type.f32, + "float64": ov.Type.f64, + "uint8": ov.Type.u8, + "uint16": ov.Type.u16, + "uint32": ov.Type.u32, + "uint64": ov.Type.u64, + "int8": ov.Type.i8, + "int16": ov.Type.i16, + "int32": ov.Type.i32, + "int64": ov.Type.i64, + "bfloat16": ov.Type.bf16, + "bool": ov.Type.boolean, + "float8_e4m3fn": ov.Type.f8e4m3, + "float8_e5m2": ov.Type.f8e5m2, + "string": ov.Type.string, +} + +DTYPES_MAX = { + ov.Type.bf16: 3.38953139e38, + ov.Type.f16: np.finfo(np.float16).max, + ov.Type.f32: np.finfo(np.float32).max, + ov.Type.f64: np.finfo(np.float64).max, + ov.Type.u8: np.iinfo(np.uint8).max, + ov.Type.u16: np.iinfo(np.uint16).max, + ov.Type.u32: np.iinfo(np.uint32).max, + ov.Type.u64: np.iinfo(np.uint64).max, + ov.Type.i8: np.iinfo(np.int8).max, + ov.Type.i16: np.iinfo(np.int16).max, + ov.Type.i32: np.iinfo(np.int32).max, + ov.Type.i64: np.iinfo(np.int64).max, + ov.Type.boolean: 1, +} + +DTYPES_MIN = { + ov.Type.bf16: -3.38953139e38, + ov.Type.f16: np.finfo(np.float16).min, + ov.Type.f32: np.finfo(np.float32).min, + ov.Type.f64: np.finfo(np.float64).min, + ov.Type.u8: np.iinfo(np.uint8).min, + ov.Type.u16: np.iinfo(np.uint16).min, + ov.Type.u32: np.iinfo(np.uint32).min, + ov.Type.u64: np.iinfo(np.uint64).min, + ov.Type.i8: np.iinfo(np.int8).min, + ov.Type.i16: np.iinfo(np.int16).min, + ov.Type.i32: np.iinfo(np.int32).min, + ov.Type.i64: np.iinfo(np.int64).min, + ov.Type.boolean: 0, +} + + +def align_operand_types(x1, x2, op_name): + x1_type = x1.element_type + x2_type = x2.element_type + if x1_type.is_dynamic() or x2_type.is_dynamic(): + raise ValueError( + f"'{op_name}' operation is not supported for dynamic operand type " + "with openvino backend" + ) + x1_type = ov_to_keras_type(x1_type) + x2_type = ov_to_keras_type(x2_type) + result_type = dtypes.result_type(x1_type, x2_type) + result_type = OPENVINO_DTYPES[result_type] + if x1_type != result_type: + x1 = ov_opset.convert(x1, result_type).output(0) + if x2_type != result_type: + x2 = ov_opset.convert(x2, result_type).output(0) + return x1, x2 + + +# create ov.Output (symbolic OpenVINO tensor) +# for different input `x` +def get_ov_output(x, ov_type=None): + if isinstance(x, float): + if ov_type is None: + ov_type = Type.f32 + x = ov_opset.constant(x, ov_type).output(0) + elif isinstance(x, int): + if ov_type is None: + ov_type = Type.i32 + x = ov_opset.constant(x, ov_type).output(0) + elif isinstance(x, np.ndarray): + if x.dtype == np.dtype("bfloat16"): + x = ov_opset.constant(x, OPENVINO_DTYPES["bfloat16"]).output(0) + else: + x = ov_opset.constant(x).output(0) + elif isinstance(x, (list, tuple)): + if isinstance(x, tuple): + x = list(x) + if ov_type is None: + x = ov_opset.constant(x).output(0) + else: + x = ov_opset.constant(x, ov_type).output(0) + elif np.isscalar(x): + x = ov_opset.constant(x).output(0) + elif isinstance(x, KerasVariable): + if isinstance(x.value, OpenVINOKerasTensor): + return x.value.output + x = ov_opset.constant(x.value.data).output(0) + elif isinstance(x, OpenVINOKerasTensor): + x = x.output + elif isinstance(x, Tensor): + x = ov_opset.constant(x.data).output(0) + else: + raise ValueError( + "unsupported type of `x` to create ov.Output: {}".format(type(x)) + ) + return x + + +# wrapper for OpenVINO symbolic tensor ov.Output +# that provides interface similar to KerasTensor +# with dtype and shape members +class OpenVINOKerasTensor: + def __init__(self, x, data=None): + x_shape = x.get_partial_shape() + if x_shape.rank.is_dynamic: + x_keras_shape = None + else: + x_keras_shape = [ + None if dim.is_dynamic else dim.get_length() + for dim in list(x_shape) + ] + x_type = x.get_element_type() + x_keras_type = ov_to_keras_type(x_type) + self.output = x + self.shape = tuple(x_keras_shape) + self.dtype = x_keras_type + self.ndim = None + self.data = data + if x.get_partial_shape().rank.is_static: + self.ndim = x.get_partial_shape().rank.get_length() + + def __add__(self, other): + first = self.output + other = get_ov_output(other) + first, other = align_operand_types( + first, other, "OpenVINOKerasTensor::__add__" + ) + return OpenVINOKerasTensor(ov_opset.add(first, other).output(0)) + + def __radd__(self, other): + first = self.output + other = get_ov_output(other) + first, other = align_operand_types( + first, other, "OpenVINOKerasTensor::__radd__" + ) + return OpenVINOKerasTensor(ov_opset.add(first, other).output(0)) + + def __sub__(self, other): + first = self.output + other = get_ov_output(other) + first, other = align_operand_types( + first, other, "OpenVINOKerasTensor::__sub__" + ) + if first.get_element_type() == Type.boolean: + return OpenVINOKerasTensor( + ov_opset.logical_xor(first, other).output(0) + ) + return OpenVINOKerasTensor(ov_opset.subtract(first, other).output(0)) + + def __rsub__(self, other): + first = self.output + other = get_ov_output(other) + first, other = align_operand_types( + first, other, "OpenVINOKerasTensor::__rsub__" + ) + return OpenVINOKerasTensor(ov_opset.subtract(other, first).output(0)) + + def __mul__(self, other): + first = self.output + other = get_ov_output(other) + first, other = align_operand_types( + first, other, "OpenVINOKerasTensor::__mul__" + ) + if first.get_element_type() == Type.boolean: + return OpenVINOKerasTensor( + ov_opset.logical_and(first, other).output(0) + ) + return OpenVINOKerasTensor(ov_opset.multiply(first, other).output(0)) + + def __rmul__(self, other): + first = self.output + other = get_ov_output(other) + first, other = align_operand_types( + first, other, "OpenVINOKerasTensor::__rmul__" + ) + if first.get_element_type() == Type.boolean: + return OpenVINOKerasTensor( + ov_opset.logical_and(first, other).output(0) + ) + return OpenVINOKerasTensor(ov_opset.multiply(first, other).output(0)) + + def __truediv__(self, other): + first = self.output + other = get_ov_output(other) + first, other = align_operand_types( + first, other, "OpenVINOKerasTensor::__truediv__" + ) + return OpenVINOKerasTensor(ov_opset.divide(first, other).output(0)) + + def __rtruediv__(self, other): + first = self.output + other = get_ov_output(other) + first, other = align_operand_types( + first, other, "OpenVINOKerasTensor::__rtruediv__" + ) + return OpenVINOKerasTensor(ov_opset.divide(other, first).output(0)) + + def __floordiv__(self, other): + first = self.output + other = get_ov_output(other) + first, other = align_operand_types( + first, other, "OpenVINOKerasTensor::__floordiv__" + ) + return OpenVINOKerasTensor(ov_opset.divide(first, other).output(0)) + + def __rfloordiv__(self, other): + first = self.output + other = get_ov_output(other) + first, other = align_operand_types( + first, other, "OpenVINOKerasTensor::__rfloordiv__" + ) + return OpenVINOKerasTensor(ov_opset.divide(other, first).output(0)) + + def __neg__(self): + first = self.output + return OpenVINOKerasTensor(ov_opset.negative(first).output(0)) + + def __abs__(self): + first = self.output + return OpenVINOKerasTensor(ov_opset.absolute(first).output(0)) + + def __invert__(self): + first = self.output + return OpenVINOKerasTensor(ov_opset.logical_not(first).output(0)) + + def __pow__(self, other): + first = self.output + other = get_ov_output(other) + first, other = align_operand_types( + first, other, "OpenVINOKerasTensor::__pow__" + ) + return OpenVINOKerasTensor(ov_opset.power(first, other).output(0)) + + def __rpow__(self, other): + other = get_ov_output(other) + first = self.output + first, other = align_operand_types( + first, other, "OpenVINOKerasTensor::__rpow__" + ) + return OpenVINOKerasTensor(ov_opset.power(other, first).output(0)) + + def __lt__(self, other): + first = self.output + other = get_ov_output(other) + first, other = align_operand_types( + first, other, "OpenVINOKerasTensor::__lt__" + ) + return OpenVINOKerasTensor(ov_opset.less(first, other).output(0)) + + def __gt__(self, other): + first = self.output + other = get_ov_output(other) + first, other = align_operand_types( + first, other, "OpenVINOKerasTensor::__gt__" + ) + return OpenVINOKerasTensor(ov_opset.greater(first, other).output(0)) + + def __le__(self, other): + first = self.output + other = get_ov_output(other) + first, other = align_operand_types( + first, other, "OpenVINOKerasTensor::__le__" + ) + return OpenVINOKerasTensor(ov_opset.less_equal(first, other).output(0)) + + def __ge__(self, other): + first = self.output + other = get_ov_output(other) + first, other = align_operand_types( + first, other, "OpenVINOKerasTensor::__ge__" + ) + return OpenVINOKerasTensor( + ov_opset.greater_equal(first, other).output(0) + ) + + def __eq__(self, other): + first = self.output + other = get_ov_output(other) + first, other = align_operand_types( + first, other, "OpenVINOKerasTensor::__eq__" + ) + return OpenVINOKerasTensor(ov_opset.equal(first, other).output(0)) + + def __ne__(self, other): + first = self.output + other = get_ov_output(other) + first, other = align_operand_types( + first, other, "OpenVINOKerasTensor::__ne__" + ) + return OpenVINOKerasTensor(ov_opset.not_equal(first, other).output(0)) + + def __getitem__(self, indices): + data = self.output + rank = len(data.get_partial_shape()) + axes, gather_indices_nodes = [], [] + slice_axes, slice_starts, slice_ends, slice_steps = [], [], [], [] + unsqueeze_axes = [] + + if not isinstance(indices, tuple): + indices = (indices,) + + if any(i is Ellipsis for i in indices): + ellipsis_pos = indices.index(Ellipsis) + num_specified = sum( + i is not Ellipsis and i is not None for i in indices + ) + num_missing = rank - num_specified + indices = ( + indices[:ellipsis_pos] + + (builtins.slice(None),) * num_missing + + indices[ellipsis_pos + 1 :] + ) + + def count_unsqueeze_before(dim): + return sum(1 for i in range(dim) if indices[i] is None) + + partial_shape = ov_opset.shape_of(data, Type.i32) + zero_const = ov_opset.constant(0, Type.i32) + + for dim, index in enumerate(indices): + if isinstance(index, bool): + raise ValueError( + "OpenVINO backend does not support boolean indexing" + ) + elif isinstance(index, (int, np.integer, np.ndarray)): + if isinstance(index, (np.ndarray, np.integer)): + if isinstance(index, np.ndarray) and len(index.shape) != 0: + raise ValueError( + "OpenVINO backend does not support" + "multi-dimensional indexing" + ) + index = int(index) + actual_dim = dim - count_unsqueeze_before(dim) + if not (0 <= actual_dim < rank): + raise IndexError( + f"Index {index} is out of bounds for " + f"axis {dim} with rank {rank}" + ) + length = ov_opset.gather( + partial_shape, + ov_opset.constant([actual_dim], Type.i32), + zero_const, + ) + if index >= 0: + idx_value = ov_opset.constant([index], Type.i32) + else: + idx_value = ov_opset.add( + ov_opset.constant([index], Type.i32), length + ) + axes.append(dim) + gather_indices_nodes.append(idx_value.output(0)) + elif isinstance(index, builtins.slice): + if index == builtins.slice(None): + continue + if index.step is not None and index.step < 0: + raise ValueError("OpenVINO doesn't support negative steps") + slice_axes.append(dim) + slice_starts.append(0 if index.start is None else index.start) + slice_ends.append( + 2**31 - 1 if index.stop is None else index.stop + ) + slice_steps.append(1 if index.step is None else index.step) + elif index is None: + unsqueeze_axes.append(dim) + elif isinstance(index, OpenVINOKerasTensor): + index = get_ov_output(index) + index_type = index.get_element_type() + index_shape = index.get_partial_shape() + if index_type == Type.boolean or not index_type.is_integral(): + raise ValueError( + "OpenVINO backend does not " + f"support {index_type} indexing" + ) + axes.append(dim) + if len(index_shape) > 1: + raise ValueError( + "OpenVINO backend does not " + "support multi-dimensional indexing" + ) + if len(index_shape) == 0: + index = ov_opset.unsqueeze(index, zero_const).output(0) + if index_type != Type.i32: + index = ov_opset.convert(index, Type.i32).output(0) + shape_tensor = ov_opset.shape_of(data, Type.i32) + axis_i32 = ov_opset.constant([dim], dtype=Type.i32) + dim_size = ov_opset.gather(shape_tensor, axis_i32, zero_const) + is_negative = ov_opset.less(index, zero_const) + adjusted_index = ov_opset.add(index, dim_size) + index = ov_opset.select( + is_negative, adjusted_index, index + ).output(0) + gather_indices_nodes.append(index) + else: + raise ValueError( + f"Unsupported index type {type(index)} " + "in OpenVINOKerasTensor.__getitem__" + ) + + if slice_axes: + step = ov_opset.constant(slice_steps, Type.i32).output(0) + start = ov_opset.constant(slice_starts, Type.i32).output(0) + stop = ov_opset.constant(slice_ends, Type.i32).output(0) + adjusted_slice_axes = [ + ax - sum(1 for unsq in unsqueeze_axes if unsq <= ax) + for ax in slice_axes + ] + axes_const = ov_opset.constant( + adjusted_slice_axes, Type.i32 + ).output(0) + data = ov_opset.slice(data, start, stop, step, axes_const).output(0) + + if axes: + gather_indices_const = ( + gather_indices_nodes[0] + if len(gather_indices_nodes) == 1 + else ov_opset.concat(gather_indices_nodes, axis=0).output(0) + ) + adjusted_axes = [ + ax - sum(1 for unsq in unsqueeze_axes if unsq <= ax) + for ax in axes + ] + if len(axes) == 1: + data = ov_opset.gather( + data, gather_indices_const, adjusted_axes[0] + ).output(0) + data = ov_opset.squeeze(data, adjusted_axes[0]).output(0) + else: + rank = len(data.get_partial_shape()) + remaining_axes = [ + i for i in range(rank) if i not in adjusted_axes + ] + perm = ov_opset.constant( + adjusted_axes + remaining_axes, Type.i32 + ) + data = ov_opset.transpose(data, perm).output(0) + data = ov_opset.gather_nd(data, gather_indices_const).output(0) + + if unsqueeze_axes: + adjusted_unsqueeze = [] + for ax in unsqueeze_axes: + ax -= sum(1 for s in axes if s < ax) + ax -= sum(1 for s in slice_axes if s < ax) + adjusted_unsqueeze.append(ax) + unsqueeze_const = ov_opset.constant( + adjusted_unsqueeze, Type.i32 + ).output(0) + data = ov_opset.unsqueeze(data, unsqueeze_const).output(0) + + return OpenVINOKerasTensor(data) + + def __len__(self): + ov_output = self.output + ov_shape = ov_output.get_partial_shape() + assert ov_shape.rank.is_static and ov_shape.rank.get_length() > 0, ( + "rank must be static and greater than zero" + ) + assert ov_shape[0].is_static, "the first dimension must be static" + return ov_shape[0].get_length() + + def __mod__(self, other): + first = self.output + other = get_ov_output(other) + first, other = align_operand_types( + first, other, "OpenVINOKerasTensor::__mod__" + ) + return OpenVINOKerasTensor(ov_opset.mod(first, other).output(0)) + + def __array__(self, dtype=None): + try: + tensor = cast(self, dtype=dtype) if dtype is not None else self + return convert_to_numpy(tensor) + except Exception as e: + raise RuntimeError( + "An OpenVINOKerasTensor is symbolic: it's a placeholder " + "for a shape and a dtype.\n" + "It doesn't have any actual numerical value.\n" + "You cannot convert it to a NumPy array." + ) from e + + def numpy(self): + return self.__array__() + + +def ov_to_keras_type(ov_type): + for _keras_type, _ov_type in OPENVINO_DTYPES.items(): + if ov_type == _ov_type: + return _keras_type + raise ValueError( + f"Requested OpenVINO type has no keras analogue '{ov_type.to_string()}'" + ) + + +@contextlib.contextmanager +def device_scope(device_name): + current_device = _parse_device_input(device_name) + global_state.set_global_attribute("openvino_device", current_device) + + +def get_device(): + device = global_state.get_global_attribute("openvino_device", None) + if device is None: + return "CPU" + return device + + +def _parse_device_input(device_name): + if isinstance(device_name, str): + # We support string value like "cpu:0", "gpu:1", and need to convert + # "gpu" to "cuda" + device_name = device_name.upper() + device_type, _ = device_name.split(":") + return device_type + else: + raise ValueError( + "Invalid value for argument `device_name`. " + "Expected a string like 'gpu:0' or 'cpu'. " + f"Received: device_name='{device_name}'" + ) + return device_name + + +class Variable(KerasVariable): + def _initialize(self, value): + if isinstance(value, OpenVINOKerasTensor): + self._value = value + elif isinstance(value, Tensor): + value_const = ov_opset.constant( + value.data, dtype=OPENVINO_DTYPES[self._dtype] + ) + self._value = OpenVINOKerasTensor(value_const.output(0)) + else: + value_const = ov_opset.constant( + value, dtype=OPENVINO_DTYPES[self._dtype] + ) + self._value = OpenVINOKerasTensor(value_const.output(0)) + + def _direct_assign(self, value): + self._value = value + + def _convert_to_tensor(self, value, dtype=None): + return convert_to_tensor(value, dtype=dtype) + + def __array__(self): + if isinstance(self.value, OpenVINOKerasTensor): + return self.value.output.get_node().data + return self.value.data + + def __getitem__(self, idx): + if isinstance(self.value, OpenVINOKerasTensor): + arr = self.value.output.get_node().data + return arr.__getitem__(idx) + return self.value.__getitem__(idx) + + def __int__(self): + if isinstance(self.value, OpenVINOKerasTensor): + arr = self.value.output.get_node().data + else: + arr = self.value.data + if arr.ndim > 0: + raise TypeError( + "Only scalar arrays can be converted to Python scalars. " + f"Got: shape={arr.shape}" + ) + return int(arr) + + def __float__(self): + if isinstance(self.value, OpenVINOKerasTensor): + arr = self.value.output.get_node().data + else: + arr = self.value.data + if arr.ndim > 0: + raise TypeError( + "Only scalar arrays can be converted to Python scalars. " + f"Got: shape={arr.shape}" + ) + return float(arr) + + +def _is_scalar(elem): + return not isinstance(elem, (list, tuple, set, dict)) + + +def convert_to_tensor(x, dtype=None, sparse=None, ragged=None): + if sparse: + raise ValueError("`sparse=True` is not supported with openvino backend") + if ragged: + raise ValueError("`ragged=True` is not supported with openvino backend") + if dtype is not None: + dtype = standardize_dtype(dtype) + if isinstance(x, OpenVINOKerasTensor): + if dtype and dtype != standardize_dtype(x.dtype): + x = cast(x, dtype) + return x + elif isinstance(x, np.ndarray): + if dtype is not None: + ov_type = OPENVINO_DTYPES[dtype] + else: + ov_type = OPENVINO_DTYPES[standardize_dtype(x.dtype)] + return OpenVINOKerasTensor(ov_opset.constant(x, ov_type).output(0)) + elif isinstance(x, (list, tuple)): + if dtype is None: + dtype = result_type( + *[ + getattr(item, "dtype", type(item)) + for item in tree.flatten(x) + ] + ) + x = np.array(x, dtype=dtype) + ov_type = OPENVINO_DTYPES[dtype] + return OpenVINOKerasTensor(ov_opset.constant(x, ov_type).output(0), x) + elif isinstance(x, (float, int, bool)): + if dtype is None: + dtype = standardize_dtype(type(x)) + ov_type = OPENVINO_DTYPES[dtype] + return OpenVINOKerasTensor(ov_opset.constant(x, ov_type).output(0), x) + elif isinstance(x, ov.Output): + return OpenVINOKerasTensor(x) + if isinstance(x, Variable): + x = x.value + if dtype and dtype != x.dtype: + x = cast(x, dtype) + return x + original_type = type(x) + try: + if dtype is None: + dtype = getattr(x, "dtype", original_type) + ov_type = OPENVINO_DTYPES[standardize_dtype(dtype)] + else: + ov_type = OPENVINO_DTYPES[dtype] + x = np.array(x) + return OpenVINOKerasTensor(ov_opset.constant(x, ov_type).output(0)) + except Exception as e: + raise TypeError( + f"Cannot convert object of type {original_type} " + f"to OpenVINOKerasTensor: {e}" + ) + + +def convert_to_numpy(x): + if isinstance(x, np.ndarray): + return x + elif isinstance(x, (int, float)): + return np.array(x) + elif isinstance(x, (list, tuple)): + x_new = [] + for elem in x: + x_new.append(convert_to_numpy(elem)) + return np.array(x_new) + elif np.isscalar(x): + return x + elif isinstance(x, ov.Tensor): + return x.data + elif x is None: + return x + elif isinstance(x, KerasVariable): + if isinstance(x.value, OpenVINOKerasTensor): + x = x.value + else: + return x.value.data + assert isinstance(x, OpenVINOKerasTensor), ( + "unsupported type {} for `convert_to_numpy` in openvino backend".format( + type(x) + ) + ) + try: + ov_result = x.output + ov_model = Model(results=[ov_result], parameters=[]) + ov_compiled_model = compile_model(ov_model, get_device()) + result = ov_compiled_model({})[0] + except Exception as inner_exception: + raise RuntimeError( + "`convert_to_numpy` failed to convert the tensor." + ) from inner_exception + return result + + +def is_tensor(x): + if isinstance(x, OpenVINOKerasTensor): + return True + if isinstance(x, ov.Tensor): + return True + return False + + +def shape(x): + return tuple(x.shape) + + +def cast(x, dtype): + dtype = standardize_dtype(dtype) + ov_type = OPENVINO_DTYPES[dtype] + x = get_ov_output(x) + return OpenVINOKerasTensor(ov_opset.convert(x, ov_type).output(0)) + + +def cond(pred, true_fn, false_fn): + raise NotImplementedError("`cond` is not supported with openvino backend") + + +def vectorized_map(function, elements): + raise NotImplementedError( + "`vectorized_map` is not supported with openvino backend" + ) + + +# Shape / dtype inference util +def compute_output_spec(fn, *args, **kwargs): + with StatelessScope(): + + def convert_keras_tensor_to_openvino(x): + if isinstance(x, KerasTensor): + x_shape = list(x.shape) + x_shape = [-1 if dim is None else dim for dim in x_shape] + x_type = OPENVINO_DTYPES[x.dtype] + param = ov_opset.parameter(shape=x_shape, dtype=x_type) + return OpenVINOKerasTensor(param.output(0)) + return x + + args_1, kwargs_1 = tree.map_structure( + lambda x: convert_keras_tensor_to_openvino(x), + (args, kwargs), + ) + outputs_1 = fn(*args_1, **kwargs_1) + + outputs = outputs_1 + + def convert_openvino_to_keras_tensor(x): + if is_tensor(x): + x_type = x.dtype + x_shape = x.shape + return KerasTensor(x_shape, x_type) + elif isinstance(x, OpenVINOKerasTensor): + x_type = x.dtype + x_shape = x.shape + return KerasTensor(x_shape, x_type) + return x + + output_spec = tree.map_structure( + convert_openvino_to_keras_tensor, outputs + ) + return output_spec + + +def scan(f, init, xs=None, length=None, reverse=False, unroll=1): + raise NotImplementedError("`scan` is not supported with openvino backend") + + +def scatter(indices, values, shape): + raise NotImplementedError( + "`scatter` is not supported with openvino backend" + ) + + +def scatter_update(inputs, indices, updates): + raise NotImplementedError( + "`scatter_update` is not supported with openvino backend" + ) + + +def slice(inputs, start_indices, shape): + inputs = get_ov_output(inputs) + if isinstance(start_indices, (list, np.ndarray)): + start_indices = tuple(start_indices) + if isinstance(shape, (list, np.ndarray)): + shape = tuple(shape) + assert isinstance(start_indices, tuple), ( + "`slice` is not supported by openvino backend" + " for `start_indices` of type {}".format(type(start_indices)) + ) + assert isinstance(shape, tuple), ( + "`slice` is not supported by openvino backend" + " for `shape` of type {}".format(type(shape)) + ) + + axes = [] + start = [] + stop = [] + + def prepare_slice_index(val): + val_type = val.get_element_type() + if not val_type.is_integral(): + raise ValueError( + "`slice` is not supported by OpenVINO backend " + "for `start_indices` or `shape` with non-integer types" + ) + if val_type != Type.i32: + val = ov_opset.convert(val, Type.i32).output(0) + if len(val.get_partial_shape()) == 0: + val = ov_opset.unsqueeze( + val, ov_opset.constant(0, Type.i32) + ).output(0) + return val + + for idx, length in enumerate(shape): + if length is not None and length >= 0: + axes.append(idx) + start_val = prepare_slice_index(get_ov_output(start_indices[idx])) + stop_val = prepare_slice_index( + get_ov_output(start_indices[idx] + length) + ) + start.append(start_val) + stop.append(stop_val) + + if len(axes) == 0: + return inputs + + step = [1] * len(start) + step = ov_opset.constant(step, Type.i32).output(0) + start = ov_opset.concat(start, axis=0).output(0) + stop = ov_opset.concat(stop, axis=0).output(0) + axes = ov_opset.constant(axes, Type.i32).output(0) + result = ov_opset.slice(inputs, start, stop, step, axes).output(0) + + # Apply reshape to ensure output matches expected shape + # Convert None (dynamic) dimensions to -1 for OpenVINO compatibility + if all(dim is None or (isinstance(dim, int) and dim >= 0) for dim in shape): + reshape_pattern = [(-1 if dim is None else dim) for dim in shape] + target_shape = ov_opset.constant(reshape_pattern, Type.i32).output(0) + result = ov_opset.reshape(result, target_shape, False).output(0) + + return OpenVINOKerasTensor(result) + + +def slice_update(inputs, start_indices, updates): + inputs = get_ov_output(inputs) + updates_tensor = get_ov_output(updates) + + if isinstance(start_indices, (list, np.ndarray)): + start_indices = tuple(start_indices) + if not isinstance(start_indices, tuple): + raise ValueError( + "`slice_update` is not supported by openvino backend" + " for `start_indices` of type {}".format(type(start_indices)) + ) + + zero_scalar = ov_opset.constant(0, Type.i32) + one_scalar = ov_opset.constant(1, Type.i32) + zero_tensor = ov_opset.constant([0], Type.i32) + one_tensor = ov_opset.constant([1], Type.i32) + + processed_start_indices = [] + for idx in start_indices: + val = get_ov_output(idx) + if not val.get_element_type().is_integral(): + raise ValueError("`slice_update` requires integral start_indices") + if val.get_element_type() != Type.i32: + val = ov_opset.convert(val, Type.i32).output(0) + if val.get_partial_shape().rank.get_length() == 0: + val = ov_opset.unsqueeze(val, zero_scalar).output(0) + processed_start_indices.append(val) + + updates_shape = ov_opset.shape_of(updates_tensor, Type.i32).output(0) + rank = updates_tensor.get_partial_shape().rank.get_length() + if rank == 0: + # Handle scalar update + start_tensor = ov_opset.concat(processed_start_indices, axis=0).output( + 0 + ) + # For scatter_nd_update, + # indices should be of shape [num_updates, rank_of_inputs] + # and updates should be of shape [num_updates]. Here num_updates is 1. + absolute_indices = ov_opset.unsqueeze(start_tensor, zero_scalar).output( + 0 + ) + updates_flat = ov_opset.unsqueeze(updates_tensor, zero_scalar).output(0) + result = ov_opset.scatter_nd_update( + inputs, absolute_indices, updates_flat + ).output(0) + return OpenVINOKerasTensor(result) + + # Compute the total number of elements in the updates tensor. + # Example: + # if updates.shape = [2, 3], total_elements = 6. + total_elements = ov_opset.reduce_prod( + updates_shape, zero_tensor, keep_dims=False + ).output(0) + + # Generate a flat range [0, 1, ..., total_elements-1]. + # This will be used to enumerate all positions in the updates tensor. + flat_indices = ov_opset.range( + zero_scalar, total_elements, one_scalar, output_type=Type.i32 + ).output(0) + + dim_sizes = [] + strides = [] + + # For each dimension, compute its size and the stride. + # (number of elements to skip to move to the next index in this dimension). + # Example: + # for shape [2, 3], strides = [3, 1]. + for dim in range(rank): + dim_size = ov_opset.gather( + updates_shape, ov_opset.constant([dim], Type.i32), zero_scalar + ).output(0) + dim_size_scalar = ov_opset.squeeze(dim_size, zero_tensor).output(0) + dim_sizes.append(dim_size_scalar) + + # Strides to convert a flat index into a multi-dimensional index. + # This allows us to map each element in the flattened updates tensor + # to its correct N-dimensional position, so we can compute the absolute + # index in the input tensor for the scatter update. + # Stride for a dimension is the product of all dimensions after it. + # For the last dimension, stride is 1. + # Example: + # For a 3D tensor with shape [2, 3, 4]: + # - stride for dim=0 (first axis) is 3*4=12 + # (to move to the next "block" along axis 0) + # - stride for dim=1 is 4 (to move to the next row along axis 1) + # - stride for dim=2 is 1 (to move to the next element along axis 2) + # This is equivalent to how numpy flattens multi-dimensional arrays. + if dim < rank - 1: + remaining_dims = ov_opset.slice( + updates_shape, + ov_opset.constant([dim + 1], Type.i32), + ov_opset.constant([rank], Type.i32), + one_tensor, + zero_tensor, + ).output(0) + stride = ov_opset.reduce_prod( + remaining_dims, zero_tensor, keep_dims=False + ).output(0) + else: + stride = one_scalar + strides.append(stride) + + coord_tensors = [] + # For each dimension, compute the coordinate for every flat index. + # Example: + # for shape [2, 3], flat index 4 -> coordinates [1, 1] (row 1, col 1). + for dim in range(rank): + coords = ov_opset.mod( + ov_opset.divide(flat_indices, strides[dim]).output(0), + dim_sizes[dim], + ).output(0) + coord_tensors.append(coords) + + coord_tensors_unsqueezed = [] + for coord in coord_tensors: + # Unsqueeze to make each coordinate a column vector for concatenation. + coord_unsqueezed = ov_opset.unsqueeze(coord, one_tensor).output(0) + coord_tensors_unsqueezed.append(coord_unsqueezed) + + # Concatenate all coordinate columns to form [total_elements, rank] matrix. + # Each row is a multi-dimensional index into the updates tensor. + # Example: + # for shape [2, 3], row 4 = [1, 1]. + indices_matrix = ov_opset.concat(coord_tensors_unsqueezed, axis=1).output(0) + + # Broadcast start indices to match the number of updates. + # Example: + # start_indices = (2, 3), indices_matrix = [[0,0],[0,1],...], + # start_broadcast = [[2,3],[2,3],...] + start_tensor = ov_opset.concat(processed_start_indices, axis=0).output(0) + start_reshaped = ov_opset.reshape( + start_tensor, ov_opset.constant([1, rank], Type.i32), special_zero=False + ).output(0) + + broadcast_shape = ov_opset.concat( + [ + ov_opset.unsqueeze(total_elements, zero_tensor).output(0), + one_tensor, + ], + axis=0, + ).output(0) + + start_broadcast = ov_opset.tile(start_reshaped, broadcast_shape).output(0) + + # Add the broadcasted start indices to the relative indices + # to get absolute indices in the input tensor. + # Example: + # if start=(2,3), update index [1,1] -> absolute index [3,4]. + absolute_indices = ov_opset.add(indices_matrix, start_broadcast).output(0) + + # Flatten the updates tensor to match the flat indices. + updates_flat = ov_opset.reshape( + updates_tensor, + ov_opset.unsqueeze(total_elements, zero_tensor).output(0), + special_zero=False, + ).output(0) + + # Perform the scatter update: for each absolute index, + # set the corresponding value from updates_flat. + result = ov_opset.scatter_nd_update( + inputs, absolute_indices, updates_flat + ).output(0) + return OpenVINOKerasTensor(result) + + +def while_loop( + cond, + body, + loop_vars, + maximum_iterations=None, +): + def flatten_structure(data): + if isinstance(data, dict): + return [v for k in sorted(data) for v in flatten_structure(data[k])] + elif isinstance(data, (tuple, list)): + return [v for item in data for v in flatten_structure(item)] + else: + return [data] + + def pack_structure(template, flat): + if isinstance(template, dict): + keys = sorted(template) + packed = {} + for k in keys: + value, flat = pack_structure(template[k], flat) + packed[k] = value + return packed, flat + elif isinstance(template, (tuple, list)): + packed = [] + for item in template: + value, flat = pack_structure(item, flat) + packed.append(value) + return ( + tuple(packed) if isinstance(template, tuple) else packed + ), flat + else: + return flat[0], flat[1:] + + is_scalar_input = _is_scalar(loop_vars) + + if is_scalar_input: + loop_vars = (loop_vars,) + elif isinstance(loop_vars, (list, np.ndarray)): + loop_vars = tuple(loop_vars) + else: + assert isinstance(loop_vars, (tuple, dict)), ( + f"Unsupported type {type(loop_vars)} for loop_vars" + ) + + flat_loop_vars = flatten_structure(loop_vars) + loop_vars_ov = [get_ov_output(var) for var in flat_loop_vars] + + maximum_iterations = ( + ov_opset.constant(-1, Type.i32).output(0) + if maximum_iterations is None + else get_ov_output(maximum_iterations) + ) + + trip_count = maximum_iterations + execution_condition = ov_opset.constant(True, Type.boolean).output(0) + loop = ov_opset.loop(trip_count, execution_condition) + + shapes = [var.get_partial_shape() for var in loop_vars_ov] + types = [var.get_element_type() for var in loop_vars_ov] + params = [ + ov_opset.parameter(shape, dtype) for shape, dtype in zip(shapes, types) + ] + param_tensors = [OpenVINOKerasTensor(p.output(0)) for p in params] + + packed_args, _ = pack_structure(loop_vars, param_tensors) + if isinstance(packed_args, dict): + body_out = body(packed_args) + else: + body_out = body(*packed_args) + + if not isinstance(body_out, (list, tuple, dict)): + body_out = (body_out,) + + flat_body_out = flatten_structure(body_out) + if isinstance(packed_args, dict): + cond_output = get_ov_output(cond(body_out)) + else: + cond_output = get_ov_output(cond(*body_out)) + + if len(cond_output.get_partial_shape()) != 0: + raise ValueError( + "`cond` function must return a scalar boolean value, " + "but got shape {}".format(cond_output.get_partial_shape()) + ) + + for p, out in zip(params, flat_body_out): + out_shape = get_ov_output(out).get_partial_shape() + p.set_partial_shape(out_shape) + + results = [cond_output] + [get_ov_output(x) for x in flat_body_out] + body_func = Model(results=results, parameters=params) + loop.set_function(body_func) + loop.set_special_body_ports([-1, 0]) + + for param, init_val, next_val in zip(params, loop_vars_ov, flat_body_out): + loop.set_merged_input(param, init_val, get_ov_output(next_val)) + + outputs_flat = [ + OpenVINOKerasTensor(loop.get_iter_value(get_ov_output(val))) + for val in flat_body_out + ] + final_output, _ = pack_structure(loop_vars, outputs_flat) + + if is_scalar_input: + if isinstance(final_output, tuple): + return final_output[0] + else: + return final_output + else: + return final_output + + +def fori_loop(lower, upper, body_fun, init_val): + raise NotImplementedError( + "`fori_loop` is not supported with openvino backend" + ) + + +def stop_gradient(variable): + return variable + + +def unstack(x, num=None, axis=0): + raise NotImplementedError( + "`unstack` is not supported with openvino backend" + ) + + +def random_seed_dtype(): + return "uint32" + + +def custom_gradient(fun): + """Decorator for custom gradients. + + Args: + fun: Forward pass function. + """ + + def __init__(self, fun): + warnings.warn( + "`custom_gradient` for the openvino backend" + " acts as a pass-through to " + "support the forward pass." + " No gradient computation or modification " + "takes place." + ) + self.fun = fun + + def __call__(self, *args, **kwargs): + outputs, _ = self.fun(*args, **kwargs) + return outputs + + +def remat(f): + warnings.warn( + "Rematerialization memory optimization is not supported by the " + "OpenVino backend. Please switch to JAX, TensorFlow, or PyTorch to " + "utilize this feature." + ) + return f diff --git a/keras/src/backend/openvino/excluded_concrete_tests.txt b/keras/src/backend/openvino/excluded_concrete_tests.txt new file mode 100644 index 000000000000..13bae27343d5 --- /dev/null +++ b/keras/src/backend/openvino/excluded_concrete_tests.txt @@ -0,0 +1,294 @@ +NumPyTestRot90 +NumpyArrayCreateOpsCorrectnessTest::test_eye +NumpyDtypeTest::test_absolute_bool +NumpyDtypeTest::test_add_ +NumpyDtypeTest::test_all +NumpyDtypeTest::test_angle +NumpyDtypeTest::test_any +NumpyDtypeTest::test_argpartition +NumpyDtypeTest::test_array +NumpyDtypeTest::test_bartlett +NumpyDtypeTest::test_blackman +NumpyDtypeTest::test_gcd +NumpyDtypeTest::test_hamming +NumpyDtypeTest::test_hanning +NumpyDtypeTest::test_heaviside +NumpyDtypeTest::test_hypot +NumpyDtypeTest::test_kaiser +NumpyDtypeTest::test_bitwise +NumpyDtypeTest::test_cbrt +NumpyDtypeTest::test_ceil +NumpyDtypeTest::test_concatenate +NumpyDtypeTest::test_corrcoef +NumpyDtypeTest::test_correlate +NumpyDtypeTest::test_cross +NumpyDtypeTest::test_cumprod +NumpyDtypeTest::test_cumsum_bool +NumpyDtypeTest::test_diag +NumpyDtypeTest::test_digitize +NumpyDtypeTest::test_einsum +NumpyDtypeTest::test_exp2 +NumpyDtypeTest::test_eye +NumpyDtypeTest::test_flip +NumpyDtypeTest::test_floor +NumpyDtypeTest::test_inner +NumpyDtypeTest::test_isfinite +NumpyDtypeTest::test_isin +NumpyDtypeTest::test_isinf +NumpyDtypeTest::test_isnan +NumpyDtypeTest::test_isposinf +NumpyDtypeTest::test_kron +NumpyDtypeTest::test_lcm +NumpyDtypeTest::test_logaddexp2 +NumpyDtypeTest::test_matmul_ +NumpyDtypeTest::test_max +NumpyDtypeTest::test_mean +NumpyDtypeTest::test_minimum_python_types +NumpyDtypeTest::test_multiply +NumpyDtypeTest::test_power +NumpyDtypeTest::test_quantile +NumpyDtypeTest::test_roll +NumpyDtypeTest::test_round +NumpyDtypeTest::test_searchsorted +NumpyDtypeTest::test_signbit +NumpyDtypeTest::test_sqrt +NumpyDtypeTest::test_std +NumpyDtypeTest::test_subtract +NumpyDtypeTest::test_sum +NumpyDtypeTest::test_swapaxes +NumpyDtypeTest::test_tensordot_ +NumpyDtypeTest::test_tile +NumpyDtypeTest::test_trace +NumpyDtypeTest::test_trunc +NumpyDtypeTest::test_unravel +NumpyDtypeTest::test_var +NumpyDtypeTest::test_vdot +NumpyDtypeTest::test_vstack +NumpyDtypeTest::test_clip_bool +NumpyDtypeTest::test_square_bool +HistogramTest +NumpyOneInputOpsCorrectnessTest::test_all +NumpyOneInputOpsCorrectnessTest::test_angle +NumpyOneInputOpsCorrectnessTest::test_any +NumpyOneInputOpsCorrectnessTest::test_argpartition +NumpyOneInputOpsCorrectnessTest::test_array +NumpyOneInputOpsCorrectnessTest::test_bartlett +NumpyOneInputOpsCorrectnessTest::test_blackman +NumpyOneInputOpsCorrectnessTest::test_hamming +NumpyOneInputOpsCorrectnessTest::test_hanning +NumpyOneInputOpsCorrectnessTest::test_kaiser +NumpyOneInputOpsCorrectnessTest::test_bitwise_invert +NumpyOneInputOpsCorrectnessTest::test_cbrt +NumpyOneInputOpsCorrectnessTest::test_conj +NumpyOneInputOpsCorrectnessTest::test_corrcoef +NumpyOneInputOpsCorrectnessTest::test_correlate +NumpyOneInputOpsCorrectnessTest::test_cumprod +NumpyOneInputOpsCorrectnessTest::test_diag +NumpyOneInputOpsCorrectnessTest::test_diagonal +NumpyOneInputOpsCorrectnessTest::test_exp2 +NumpyOneInputOpsCorrectnessTest::test_flip +NumpyOneInputOpsCorrectnessTest::test_floor_divide +NumpyOneInputOpsCorrectnessTest::test_imag +NumpyOneInputOpsCorrectnessTest::test_isfinite +NumpyOneInputOpsCorrectnessTest::test_isinf +NumpyOneInputOpsCorrectnessTest::test_isposinf +NumpyOneInputOpsCorrectnessTest::test_logaddexp2 +NumpyOneInputOpsCorrectnessTest::test_max +NumpyOneInputOpsCorrectnessTest::test_mean +NumpyOneInputOpsCorrectnessTest::test_pad_float16_constant_2 +NumpyOneInputOpsCorrectnessTest::test_pad_float32_constant_2 +NumpyOneInputOpsCorrectnessTest::test_pad_float64_constant_2 +NumpyOneInputOpsCorrectnessTest::test_pad_int16_constant_2 +NumpyOneInputOpsCorrectnessTest::test_pad_int8_constant_2 +NumpyOneInputOpsCorrectnessTest::test_pad_uint8_constant_2 +NumpyOneInputOpsCorrectnessTest::test_pad_int32_constant_2 +NumpyOneInputOpsCorrectnessTest::test_real +NumpyOneInputOpsCorrectnessTest::test_reshape +NumpyOneInputOpsCorrectnessTest::test_roll +NumpyOneInputOpsCorrectnessTest::test_round +NumpyOneInputOpsCorrectnessTest::test_searchsorted +NumpyOneInputOpsCorrectnessTest::test_select +NumpyOneInputOpsCorrectnessTest::test_signbit +NumpyOneInputOpsCorrectnessTest::test_size +NumpyOneInputOpsCorrectnessTest::test_slogdet +NumpyOneInputOpsCorrectnessTest::test_sqrt_int32 +NumpyOneInputOpsCorrectnessTest::test_squeeze +NumpyOneInputOpsCorrectnessTest::test_std +NumpyOneInputOpsCorrectnessTest::test_sum +NumpyOneInputOpsCorrectnessTest::test_swapaxes +NumpyOneInputOpsCorrectnessTest::test_tile +NumpyOneInputOpsCorrectnessTest::test_trace +NumpyOneInputOpsCorrectnessTest::test_transpose +NumpyOneInputOpsCorrectnessTest::test_trunc +NumpyOneInputOpsCorrectnessTest::test_unravel_index +NumpyOneInputOpsCorrectnessTest::test_var +NumpyOneInputOpsCorrectnessTest::test_vectorize +NumpyOneInputOpsCorrectnessTest::test_vstack +NumpyTwoInputOpsCorrectnessTest::test_bitwise_and +NumpyTwoInputOpsCorrectnessTest::test_bitwise_left_shift +NumpyTwoInputOpsCorrectnessTest::test_bitwise_or +NumpyTwoInputOpsCorrectnessTest::test_bitwise_right_shift +NumpyTwoInputOpsCorrectnessTest::test_bitwise_xor +NumpyTwoInputOpsCorrectnessTest::test_cross +NumpyTwoInputOpsCorrectnessTest::test_digitize +NumpyTwoInputOpsCorrectnessTest::test_divide_no_nan +NumpyTwoInputOpsCorrectnessTest::test_einsum +NumpyTwoInputOpsCorrectnessTest::test_gcd +NumpyTwoInputOpsCorrectnessTest::test_heaviside +NumpyTwoInputOpsCorrectnessTest::test_hypot +NumpyTwoInputOpsCorrectnessTest::test_inner +NumpyTwoInputOpsCorrectnessTest::test_isin +NumpyTwoInputOpsCorrectnessTest::test_kron +NumpyTwoInputOpsCorrectnessTest::test_lcm +NumpyTwoInputOpsCorrectnessTest::test_quantile +NumpyTwoInputOpsCorrectnessTest::test_tensordot +NumpyTwoInputOpsCorrectnessTest::test_vdot +NumpyOneInputOpsDynamicShapeTest::test_angle +NumpyOneInputOpsDynamicShapeTest::test_bartlett +NumpyOneInputOpsDynamicShapeTest::test_blackman +NumpyOneInputOpsDynamicShapeTest::test_cbrt +NumpyOneInputOpsDynamicShapeTest::test_corrcoef +NumpyOneInputOpsDynamicShapeTest::test_hamming +NumpyOneInputOpsDynamicShapeTest::test_hanning +NumpyOneInputOpsDynamicShapeTest::test_isposinf +NumpyOneInputOpsDynamicShapeTest::test_kaiser +NumpyOneInputOpsStaticShapeTest::test_angle +NumpyOneInputOpsStaticShapeTest::test_cbrt +NumpyOneInputOpsStaticShapeTest::test_isposinf +NumpyTwoInputOpsDynamicShapeTest::test_gcd +NumpyTwoInputOpsDynamicShapeTest::test_heaviside +NumpyTwoInputOpsDynamicShapeTest::test_hypot +NumpyTwoInputOpsDynamicShapeTest::test_isin +NumpyTwoInputOpsDynamicShapeTest::test_kron +NumpyTwoInputOpsDynamicShapeTest::test_lcm +NumpyTwoInputOpsStaticShapeTest::test_gcd +NumpyTwoInputOpsStaticShapeTest::test_heaviside +NumpyTwoInputOpsStaticShapeTest::test_hypot +NumpyTwoInputOpsStaticShapeTest::test_isin +NumpyTwoInputOpsStaticShapeTest::test_kron +NumpyTwoInputOpsStaticShapeTest::test_lcm +CoreOpsBehaviorTests::test_associative_scan_invalid_arguments +CoreOpsBehaviorTests::test_scan_invalid_arguments +CoreOpsCallsTests::test_associative_scan_basic_call +CoreOpsCallsTests::test_fori_loop_basic_functionality +CoreOpsCallsTests::test_map_basic_call +CoreOpsCallsTests::test_scan_basic_call +CoreOpsCallsTests::test_scatter_basic_call +CoreOpsCallsTests::test_scatter_update_basic_call +CoreOpsCallsTests::test_switch_basic_call +CoreOpsCallsTests::test_unstack_basic_functionality +CoreOpsCorrectnessTest::test_associative_scan +CoreOpsCorrectnessTest::test_cond +CoreOpsCorrectnessTest::test_fori_loop +CoreOpsCorrectnessTest::test_map +CoreOpsCorrectnessTest::test_scan +CoreOpsCorrectnessTest::test_scatter +CoreOpsCorrectnessTest::test_switch +CoreOpsCorrectnessTest::test_unstack +CoreOpsCorrectnessTest::test_vectorized_map +CoreOpsBehaviorTests::test_vectorized_map_serialization +ExtractSequencesOpTest::test_extract_sequences_call +InTopKTest::test_in_top_k_call +MathOpsCorrectnessTest::test_erfinv_operation_basic +MathOpsCorrectnessTest::test_erfinv_operation_dtype +MathOpsCorrectnessTest::test_erfinv_operation_edge_cases +MathOpsCorrectnessTest::test_extract_sequences +MathOpsCorrectnessTest::test_fft +MathOpsCorrectnessTest::test_fft2 +MathOpsCorrectnessTest::test_ifft2 +MathOpsCorrectnessTest::test_in_top_k +MathOpsCorrectnessTest::test_irfft0 +MathOpsCorrectnessTest::test_irfft1 +MathOpsCorrectnessTest::test_irfft2 +MathOpsCorrectnessTest::test_istft0 +MathOpsCorrectnessTest::test_istft1 +MathOpsCorrectnessTest::test_istft2 +MathOpsCorrectnessTest::test_istft3 +MathOpsCorrectnessTest::test_istft4 +MathOpsCorrectnessTest::test_istft5 +MathOpsCorrectnessTest::test_istft6 +MathOpsCorrectnessTest::test_logdet +MathOpsCorrectnessTest::test_rfft0 +MathOpsCorrectnessTest::test_rfft1 +MathOpsCorrectnessTest::test_rfft2 +MathOpsCorrectnessTest::test_segment_reduce0 +MathOpsCorrectnessTest::test_segment_reduce1 +MathOpsCorrectnessTest::test_segment_reduce2 +MathOpsCorrectnessTest::test_segment_reduce3 +MathOpsCorrectnessTest::test_segment_reduce_explicit_num_segments0 +MathOpsCorrectnessTest::test_segment_reduce_explicit_num_segments1 +MathOpsCorrectnessTest::test_segment_reduce_explicit_num_segments2 +MathOpsCorrectnessTest::test_segment_reduce_explicit_num_segments3 +MathOpsCorrectnessTest::test_segment_reduce_explicit_num_segments4 +MathOpsCorrectnessTest::test_segment_reduce_explicit_num_segments5 +MathOpsCorrectnessTest::test_stft0 +MathOpsCorrectnessTest::test_stft1 +MathOpsCorrectnessTest::test_stft2 +MathOpsCorrectnessTest::test_stft3 +MathOpsCorrectnessTest::test_stft4 +MathOpsCorrectnessTest::test_stft5 +MathOpsCorrectnessTest::test_stft6 +RandomCorrectnessTest::test_beta0 +RandomCorrectnessTest::test_beta1 +RandomCorrectnessTest::test_beta2 +RandomCorrectnessTest::test_binomial0 +RandomCorrectnessTest::test_binomial1 +RandomCorrectnessTest::test_binomial2 +RandomCorrectnessTest::test_dropout +RandomCorrectnessTest::test_dropout_noise_shape +RandomCorrectnessTest::test_gamma0 +RandomCorrectnessTest::test_gamma1 +RandomCorrectnessTest::test_gamma2 +RandomCorrectnessTest::test_randint0 +RandomCorrectnessTest::test_randint1 +RandomCorrectnessTest::test_randint2 +RandomCorrectnessTest::test_randint3 +RandomCorrectnessTest::test_randint4 +RandomCorrectnessTest::test_shuffle +RandomCorrectnessTest::test_truncated_normal0 +RandomCorrectnessTest::test_truncated_normal1 +RandomCorrectnessTest::test_truncated_normal2 +RandomCorrectnessTest::test_truncated_normal3 +RandomCorrectnessTest::test_truncated_normal4 +RandomCorrectnessTest::test_truncated_normal5 +RandomCorrectnessTest::test_uniform0 +RandomCorrectnessTest::test_uniform1 +RandomCorrectnessTest::test_uniform2 +RandomCorrectnessTest::test_uniform3 +RandomCorrectnessTest::test_uniform4 +RandomBehaviorTest::test_beta_tf_data_compatibility +RandomDTypeTest::test_beta_bfloat16 +RandomDTypeTest::test_beta_float16 +RandomDTypeTest::test_beta_float32 +RandomDTypeTest::test_beta_float64 +RandomDTypeTest::test_binomial_bfloat16 +RandomDTypeTest::test_binomial_float16 +RandomDTypeTest::test_binomial_float32 +RandomDTypeTest::test_binomial_float64 +RandomDTypeTest::test_dropout_bfloat16 +RandomDTypeTest::test_dropout_float16 +RandomDTypeTest::test_dropout_float32 +RandomDTypeTest::test_dropout_float64 +RandomDTypeTest::test_gamma_bfloat16 +RandomDTypeTest::test_gamma_float16 +RandomDTypeTest::test_gamma_float32 +RandomDTypeTest::test_gamma_float64 +RandomDTypeTest::test_normal_bfloat16 +RandomDTypeTest::test_randint_int16 +RandomDTypeTest::test_randint_int32 +RandomDTypeTest::test_randint_int64 +RandomDTypeTest::test_randint_int8 +RandomDTypeTest::test_randint_uint16 +RandomDTypeTest::test_randint_uint32 +RandomDTypeTest::test_randint_uint8 +RandomDTypeTest::test_truncated_normal_bfloat16 +RandomDTypeTest::test_uniform_bfloat16 +SegmentSumTest::test_segment_sum_call +SegmentMaxTest::test_segment_max_call +TestMathErrors::test_invalid_fft_length +TestMathErrors::test_istft_invalid_window_shape_2D_inputs +TestMathErrors::test_stft_invalid_input_type +TestMathErrors::test_stft_invalid_window +TestMathErrors::test_stft_invalid_window_shape +LinalgOpsCorrectnessTest::test_cholesky +LinalgOpsCorrectnessTest::test_cholesky_inverse diff --git a/keras/src/backend/openvino/excluded_tests.txt b/keras/src/backend/openvino/excluded_tests.txt new file mode 100644 index 000000000000..93821712ef20 --- /dev/null +++ b/keras/src/backend/openvino/excluded_tests.txt @@ -0,0 +1,39 @@ +keras/src/activations +keras/src/backend/common/dtypes_test.py +keras/src/callbacks/early_stopping_test.py +keras/src/dtype_policies/dtype_policy_map_test.py +keras/src/layers/attention +keras/src/layers/convolutional/conv_transpose_test.py +keras/src/layers/convolutional/separable_conv_test.py +keras/src/layers/core/dense_test.py +keras/src/layers/core/einsum_dense_test.py +keras/src/layers/core/embedding_test.py +keras/src/layers/normalization/spectral_normalization_test.py +keras/src/layers/normalization/unit_normalization_test.py +keras/src/layers/pooling/average_pooling_test.py +keras/src/layers/pooling/max_pooling_test.py +keras/src/layers/preprocessing +keras/src/layers/regularization +keras/src/layers/reshaping/reshape_test.py +keras/src/layers/reshaping/up_sampling1d_test.py +keras/src/layers/reshaping/up_sampling2d_test.py +keras/src/layers/reshaping/up_sampling3d_test.py +keras/src/layers/reshaping/zero_padding1d_test.py +keras/src/layers/reshaping/zero_padding2d_test.py +keras/src/layers/reshaping/zero_padding3d_test.py +keras/src/layers/layer_test.py +keras/src/layers/rnn +keras/src/legacy +keras/src/losses +keras/src/metrics +keras/src/models +keras/src/ops/image_test.py +keras/src/ops/linalg_test.py +keras/src/ops/nn_test.py +keras/src/optimizers +keras/src/quantizers +keras/src/random/seed_generator_test.py +keras/src/regularizers +keras/src/saving +keras/src/trainers +keras/src/utils \ No newline at end of file diff --git a/keras/src/backend/openvino/export.py b/keras/src/backend/openvino/export.py new file mode 100644 index 000000000000..977ce42607b8 --- /dev/null +++ b/keras/src/backend/openvino/export.py @@ -0,0 +1,10 @@ +class OpenvinoExportArchive: + def track(self, resource): + raise NotImplementedError( + "`track` is not implemented in the openvino backend." + ) + + def add_endpoint(self, name, fn, input_signature=None, **kwargs): + raise NotImplementedError( + "`add_endpoint` is not implemented in the openvino backend." + ) diff --git a/keras/src/backend/openvino/image.py b/keras/src/backend/openvino/image.py new file mode 100644 index 000000000000..1788495fac4e --- /dev/null +++ b/keras/src/backend/openvino/image.py @@ -0,0 +1,89 @@ +def rgb_to_grayscale(images, data_format=None): + raise NotImplementedError( + "`rgb_to_grayscale` is not supported with openvino backend" + ) + + +def resize( + image, + size, + interpolation="bilinear", + antialias=False, + crop_to_aspect_ratio=False, + pad_to_aspect_ratio=False, + fill_mode="constant", + fill_value=0.0, + data_format="channels_last", +): + raise NotImplementedError("`resize` is not supported with openvino backend") + + +def affine_transform( + images, + transform, + interpolation="bilinear", + fill_mode="constant", + fill_value=0, + data_format=None, +): + raise NotImplementedError( + "`affine_transform` is not supported with openvino backend" + ) + + +def perspective_transform( + images, + start_points, + end_points, + interpolation="bilinear", + fill_value=0, + data_format=None, +): + raise NotImplementedError( + "`perspective_transform` is not supported with openvino backend" + ) + + +def map_coordinates( + inputs, coordinates, order, fill_mode="constant", fill_value=0 +): + raise NotImplementedError( + "`map_coordinates` is not supported with openvino backend" + ) + + +def gaussian_blur( + images, kernel_size=(3, 3), sigma=(1.0, 1.0), data_format=None +): + raise NotImplementedError( + "`gaussian_blur` is not supported with openvino backend" + ) + + +def elastic_transform( + images, + alpha=20.0, + sigma=5.0, + interpolation="bilinear", + fill_mode="reflect", + fill_value=0.0, + seed=None, + data_format=None, +): + raise NotImplementedError( + "`elastic_transform` is not supported with openvino backend" + ) + + +def scale_and_translate( + images, + output_shape, + scale, + translation, + spatial_dims, + method, + antialias=True, +): + raise NotImplementedError( + "`scale_and_translate` is not supported with openvino backend" + ) diff --git a/keras/src/backend/openvino/layer.py b/keras/src/backend/openvino/layer.py new file mode 100644 index 000000000000..334c32958a7b --- /dev/null +++ b/keras/src/backend/openvino/layer.py @@ -0,0 +1,2 @@ +class OpenvinoLayer: + pass diff --git a/keras/src/backend/openvino/linalg.py b/keras/src/backend/openvino/linalg.py new file mode 100644 index 000000000000..e5e495fa1ac7 --- /dev/null +++ b/keras/src/backend/openvino/linalg.py @@ -0,0 +1,62 @@ +def cholesky(a, upper=False): + raise NotImplementedError( + "`cholesky` is not supported with openvino backend." + ) + + +def cholesky_inverse(a, upper=False): + raise NotImplementedError( + "`cholesky_inverse` is not supported with openvino backend." + ) + + +def det(a): + raise NotImplementedError("`det` is not supported with openvino backend") + + +def eig(a): + raise NotImplementedError("`eig` is not supported with openvino backend") + + +def eigh(a): + raise NotImplementedError("`eigh` is not supported with openvino backend") + + +def inv(a): + raise NotImplementedError("`inv` is not supported with openvino backend") + + +def lu_factor(a): + raise NotImplementedError( + "`lu_factor` is not supported with openvino backend" + ) + + +def norm(x, ord=None, axis=None, keepdims=False): + raise NotImplementedError("`norm` is not supported with openvino backend") + + +def qr(x, mode="reduced"): + raise NotImplementedError("`qr` is not supported with openvino backend") + + +def solve(a, b): + raise NotImplementedError("`solve` is not supported with openvino backend") + + +def solve_triangular(a, b, lower=False): + raise NotImplementedError( + "`solve_triangular` is not supported with openvino backend" + ) + + +def svd(x, full_matrices=True, compute_uv=True): + raise NotImplementedError("`svd` is not supported with openvino backend") + + +def lstsq(a, b, rcond=None): + raise NotImplementedError("`lstsq` is not supported with openvino backend") + + +def jvp(fun, primals, tangents, has_aux=False): + raise NotImplementedError("`jvp` is not supported with openvino backend") diff --git a/keras/src/backend/openvino/math.py b/keras/src/backend/openvino/math.py new file mode 100644 index 000000000000..33fa47e13ad5 --- /dev/null +++ b/keras/src/backend/openvino/math.py @@ -0,0 +1,128 @@ +import openvino.opset14 as ov_opset +from openvino import Type + +from keras.src.backend.openvino.core import OpenVINOKerasTensor +from keras.src.backend.openvino.core import get_ov_output + + +def segment_sum(data, segment_ids, num_segments=None, sorted=False): + raise NotImplementedError( + "`segment_sum` is not supported with openvino backend" + ) + + +def segment_max(data, segment_ids, num_segments=None, sorted=False): + raise NotImplementedError( + "`segment_max` is not supported with openvino backend" + ) + + +def top_k(x, k, sorted=True): + x = get_ov_output(x) + k_tensor = ov_opset.constant(k, dtype=Type.i32) + axis = -1 + sort_type = "value" if sorted else "none" + topk_node = ov_opset.topk(x, k_tensor, axis, "max", sort_type) + values = topk_node.output(0) + indices = topk_node.output(1) + return OpenVINOKerasTensor(values), OpenVINOKerasTensor(indices) + + +def in_top_k(targets, predictions, k): + raise NotImplementedError( + "`in_top_k` is not supported with openvino backend" + ) + + +def logsumexp(x, axis=None, keepdims=False): + x = get_ov_output(x) + if axis is None: + flatten_shape = ov_opset.constant([-1], Type.i32).output(0) + x = ov_opset.reshape(x, flatten_shape, False).output(0) + axis = 0 + if isinstance(axis, tuple): + axis = list(axis) + axis = ov_opset.constant(axis, Type.i32).output(0) + const_zero = ov_opset.constant(0, x.get_element_type()).output(0) + # Use keepdims=True for reduce_max to ensure proper broadcasting + reduce_max = ov_opset.reduce_max(x, axis, True).output(0) + is_finite = ov_opset.is_finite(reduce_max).output(0) + norm_max = ov_opset.select(is_finite, reduce_max, const_zero).output(0) + norm_max_sub = ov_opset.subtract(x, norm_max).output(0) + exp_norm_max = ov_opset.exp(norm_max_sub).output(0) + sum_exp = ov_opset.reduce_sum(exp_norm_max, axis, keepdims).output(0) + log_sum_exp = ov_opset.log(sum_exp).output(0) + # Squeeze norm_max if needed to match dimensions + if not keepdims: + norm_max = ov_opset.squeeze(norm_max, axis).output(0) + log_sum_exp = ov_opset.add(norm_max, log_sum_exp).output(0) + return OpenVINOKerasTensor(log_sum_exp) + + +def qr(x, mode="reduced"): + raise NotImplementedError("`qr` is not supported with openvino backend") + + +def extract_sequences(x, sequence_length, sequence_stride): + raise NotImplementedError( + "`extract_sequences` is not supported with openvino backend" + ) + + +def fft(x): + raise NotImplementedError("`fft` is not supported with openvino backend") + + +def fft2(x): + raise NotImplementedError("`fft2` is not supported with openvino backend") + + +def rfft(x, fft_length=None): + raise NotImplementedError("`rfft` is not supported with openvino backend") + + +def irfft(x, fft_length=None): + raise NotImplementedError("`irfft` is not supported with openvino backend") + + +def stft( + x, sequence_length, sequence_stride, fft_length, window="hann", center=True +): + raise NotImplementedError("`stft` is not supported with openvino backend") + + +def istft( + x, + sequence_length, + sequence_stride, + fft_length, + length=None, + window="hann", + center=True, +): + raise NotImplementedError("`istft` is not supported with openvino backend") + + +def rsqrt(x): + x = get_ov_output(x) + const_one = ov_opset.constant(1, x.get_element_type()).output(0) + sqrt = ov_opset.sqrt(x).output(0) + return OpenVINOKerasTensor(ov_opset.divide(const_one, sqrt).output(0)) + + +def erf(x): + x = get_ov_output(x) + erf = ov_opset.erf(x).output(0) + return OpenVINOKerasTensor(erf) + + +def erfinv(x): + raise NotImplementedError("`erfinv` is not supported with openvino backend") + + +def solve(a, b): + raise NotImplementedError("`solve` is not supported with openvino backend") + + +def norm(x, ord=None, axis=None, keepdims=False): + raise NotImplementedError("`norm` is not supported with openvino backend") diff --git a/keras/src/backend/openvino/nn.py b/keras/src/backend/openvino/nn.py new file mode 100644 index 000000000000..2c025825ed82 --- /dev/null +++ b/keras/src/backend/openvino/nn.py @@ -0,0 +1,508 @@ +import openvino.opset14 as ov_opset +from openvino import Type + +from keras.src import backend +from keras.src.backend.openvino.core import OpenVINOKerasTensor +from keras.src.backend.openvino.core import get_ov_output + + +def relu(x): + x = get_ov_output(x) + return OpenVINOKerasTensor(ov_opset.relu(x).output(0)) + + +def relu6(x): + x = get_ov_output(x) + return OpenVINOKerasTensor(ov_opset.clamp(x, 0.0, 6.0).output(0)) + + +def sigmoid(x): + x = get_ov_output(x) + return OpenVINOKerasTensor(ov_opset.sigmoid(x).output(0)) + + +def tanh(x): + x = get_ov_output(x) + return OpenVINOKerasTensor(ov_opset.tanh(x).output(0)) + + +def softplus(x): + x = get_ov_output(x) + return OpenVINOKerasTensor(ov_opset.softplus(x).output(0)) + + +def softsign(x): + x = get_ov_output(x) + return OpenVINOKerasTensor(ov_opset.softsign(x).output(0)) + + +def silu(x): + x = get_ov_output(x) + return OpenVINOKerasTensor( + ov_opset.multiply(x, ov_opset.sigmoid(x)).output(0) + ) + + +def log_sigmoid(x): + raise NotImplementedError( + "`log_sigmoid` is not supported with openvino backend" + ) + + +def leaky_relu(x, negative_slope=0.2): + x = get_ov_output(x) + slope_const = ov_opset.constant( + negative_slope, x.get_element_type() + ).output(0) + leaky_relu = ov_opset.prelu(x, slope_const).output(0) + return OpenVINOKerasTensor(leaky_relu) + + +def hard_sigmoid(x): + x = get_ov_output(x) + alpha = get_ov_output(1.0 / 6.0, x.get_element_type()) + beta = get_ov_output(0.5, x.get_element_type()) + return OpenVINOKerasTensor(ov_opset.hard_sigmoid(x, alpha, beta).output(0)) + + +def hard_silu(x): + hard_sigmoid_output = get_ov_output(hard_sigmoid(x)) + x = get_ov_output(x) + return OpenVINOKerasTensor( + ov_opset.multiply(x, hard_sigmoid_output).output(0) + ) + + +def elu(x, alpha=1.0): + x = get_ov_output(x) + return OpenVINOKerasTensor(ov_opset.elu(x, alpha).output(0)) + + +def selu(x): + alpha = 1.6732632423543772848170429916717 + scale = 1.0507009873554804934193349852946 + x = get_ov_output(x) + alpha = get_ov_output(alpha, x.get_element_type()) + scale = get_ov_output(scale, x.get_element_type()) + return OpenVINOKerasTensor(ov_opset.selu(x, alpha, scale).output(0)) + + +def gelu(x, approximate=True): + x = get_ov_output(x) + approximate_mode = "erf" + if approximate: + approximate_mode = "tanh" + return OpenVINOKerasTensor(ov_opset.gelu(x, approximate_mode).output(0)) + + +def softmax(x, axis=-1): + x = get_ov_output(x) + if axis is None: + x_shape = ov_opset.shape_of(x) + flatten_shape = ov_opset.constant([-1], Type.i32).output(0) + flatten_x = ov_opset.reshape(x, flatten_shape, False).output(0) + softmax_x = ov_opset.softmax(flatten_x, 0).output(0) + return OpenVINOKerasTensor( + ov_opset.reshape(softmax_x, x_shape, False).output(0) + ) + return OpenVINOKerasTensor(ov_opset.softmax(x, axis).output(0)) + + +def log_softmax(x, axis=-1): + x = get_ov_output(x) + if axis is None: + x_shape = ov_opset.shape_of(x) + flatten_shape = ov_opset.constant([-1], Type.i32).output(0) + flatten_x = ov_opset.reshape(x, flatten_shape, False).output(0) + log_softmax_x = ov_opset.log_softmax(flatten_x, 0).output(0) + return OpenVINOKerasTensor( + ov_opset.reshape(log_softmax_x, x_shape, False).output(0) + ) + return OpenVINOKerasTensor(ov_opset.log_softmax(x, axis).output(0)) + + +def max_pool( + inputs, + pool_size, + strides=None, + padding="valid", + data_format=None, +): + raise NotImplementedError( + "`max_pool` is not supported with openvino backend" + ) + + +def average_pool( + inputs, + pool_size, + strides=None, + padding="valid", + data_format=None, +): + raise NotImplementedError( + "`average_pool` is not supported with openvino backend" + ) + + +def _adjust_strides_dilation( + x, + num_spatial_dims, +): + # Helper function that converts an operand to a spatial operand. + x = (x,) * num_spatial_dims if isinstance(x, int) else x + # OpenVINO expects input in NCHW layout + # x = [1, 1] + list(x) + x = list(x) + return x + + +def _adjust_padding( + padding, +): + padding = padding.lower() if isinstance(padding, str) else padding + if padding == "same": + return "SAME_UPPER", [], [] + elif padding == "same_lower": + return "SAME_LOWER", [], [] + elif padding == "valid": + return "VALID", [], [] + pads_begin = [] + pads_end = [] + for padding_pair in padding: + pads_begin.append(padding_pair[0]) + pads_end.append(padding_pair[1]) + return "EXPLICIT", pads_begin, pads_end + + +def _adjust_input(inputs, num_spatial_dims, data_format): + if data_format == "channels_first": + return inputs + if num_spatial_dims == 1: + permutation = [0, 2, 1] + elif num_spatial_dims == 2: + permutation = [0, 3, 1, 2] + else: + permutation = [0, 4, 1, 2, 3] + permutation = ov_opset.constant(permutation, Type.i32) + return ov_opset.transpose(inputs, permutation).output(0) + + +def _adjust_kernel(kernel, num_spatial_dims): + if num_spatial_dims == 1: + permutation = [2, 1, 0] + elif num_spatial_dims == 2: + permutation = [3, 2, 0, 1] + else: + permutation = [4, 3, 0, 1, 2] + permutation = ov_opset.constant(permutation, Type.i32) + return ov_opset.transpose(kernel, permutation).output(0) + + +def _adjust_depthwise_kernel(kernel, num_spatial_dims): + # kernel layout: filter_H, filter_W, C_IN, Ch_mul + if num_spatial_dims == 1: + # kernel layout: filter_H, C_IN, Ch_mul + permutation = [1, 2, 0] + elif num_spatial_dims == 2: + # kernel layout: filter_H, filter_W, C_IN, Ch_mul + permutation = [2, 3, 0, 1] + else: + # kernel layout: filter_H, filter_W, filter_Z, C_IN, Ch_mul + permutation = [3, 4, 0, 1, 2] + permutation = ov_opset.constant(permutation, Type.i32) + return ov_opset.transpose(kernel, permutation).output(0) + + +def _adjust_outputs(outputs, num_spatial_dims, data_format): + if data_format == "channels_first": + return outputs + # convert a tensor from NCHW to NHWC layout + if num_spatial_dims == 1: + permutation = [0, 2, 1] + elif num_spatial_dims == 2: + permutation = [0, 2, 3, 1] + else: + permutation = [0, 2, 3, 4, 1] + permutation = ov_opset.constant(permutation, Type.i32) + return ov_opset.transpose(outputs, permutation).output(0) + + +def conv( + inputs, + kernel, + strides=1, + padding="valid", + data_format=None, + dilation_rate=1, +): + inputs = get_ov_output(inputs) + kernel = get_ov_output(kernel) + + data_format = backend.standardize_data_format(data_format) + num_spatial_dims = inputs.get_partial_shape().rank.get_length() - 2 + + if data_format == "channels_last": + inputs_in_channels = inputs.get_partial_shape()[ + 2 + num_spatial_dims - 1 + ] + else: + inputs_in_channels = inputs.get_partial_shape()[1] + kernel_in_channels = kernel.get_partial_shape()[-2] + + strides = _adjust_strides_dilation(strides, num_spatial_dims) + dilation_rate = _adjust_strides_dilation(dilation_rate, num_spatial_dims) + pad_mode, pads_begin, pads_end = _adjust_padding(padding) + inputs = _adjust_input(inputs, num_spatial_dims, data_format) + kernel = _adjust_kernel(kernel, num_spatial_dims) + + num_groups = ( + inputs_in_channels.get_length() // kernel_in_channels.get_length() + ) + if num_groups == 1: + conv = ov_opset.convolution( + inputs, + kernel, + strides, + pads_begin, + pads_end, + dilation_rate, + pad_mode, + ) + else: + input_shape = ov_opset.shape_of(inputs).output(0) + filter_shape = ov_opset.shape_of(kernel).output(0) + zero_const = ov_opset.constant([0], Type.i32).output(0) + one_const = ov_opset.constant([1], Type.i32).output(0) + two_const = ov_opset.constant([2], Type.i32).output(0) + input_cin = ov_opset.slice( + input_shape, one_const, two_const, one_const + ).output(0) + filter_cin = ov_opset.slice( + filter_shape, one_const, two_const, one_const + ).output(0) + num_groups = ov_opset.divide(input_cin, filter_cin).output(0) + + # reshape the filter based on the number of groups information + int_max_const = ov_opset.constant([2**31 - 1], Type.i32).output(0) + filter_cout = ov_opset.slice( + filter_shape, zero_const, one_const, one_const + ).output(0) + filter_new_cout = ov_opset.divide(filter_cout, num_groups).output(0) + shape_cin_xy = ov_opset.slice( + filter_shape, one_const, int_max_const, one_const + ).output(0) + filter_new_shape = ov_opset.concat( + [num_groups, filter_new_cout, shape_cin_xy], 0 + ).output(0) + new_filter = ov_opset.reshape(kernel, filter_new_shape, False).output(0) + conv = ov_opset.group_convolution( + inputs, + new_filter, + strides, + pads_begin, + pads_end, + dilation_rate, + pad_mode, + ) + conv = _adjust_outputs(conv.output(0), num_spatial_dims, data_format) + return OpenVINOKerasTensor(conv) + + +def depthwise_conv( + inputs, + kernel, + strides=1, + padding="valid", + data_format=None, + dilation_rate=1, +): + inputs = get_ov_output(inputs) + kernel = get_ov_output(kernel) + + data_format = backend.standardize_data_format(data_format) + num_spatial_dims = inputs.get_partial_shape().rank.get_length() - 2 + + assert data_format == "channels_last", ( + "`depthwise_conv` is supported only for channels_last data_format" + ) + + strides = _adjust_strides_dilation(strides, num_spatial_dims) + dilation_rate = _adjust_strides_dilation(dilation_rate, num_spatial_dims) + pad_mode, pads_begin, pads_end = _adjust_padding(padding) + + inputs = _adjust_input(inputs, num_spatial_dims, data_format) + kernel = _adjust_depthwise_kernel(kernel, num_spatial_dims) + unsqueeze_dim = ov_opset.constant([2], Type.i32) + kernel = ov_opset.unsqueeze(kernel, unsqueeze_dim) + + group_conv = ov_opset.group_convolution( + inputs, kernel, strides, pads_begin, pads_end, dilation_rate, pad_mode + ) + group_conv = _adjust_outputs( + group_conv.output(0), num_spatial_dims, data_format + ) + return OpenVINOKerasTensor(group_conv) + + +def separable_conv( + inputs, + depthwise_kernel, + pointwise_kernel, + strides=1, + padding="valid", + data_format=None, + dilation_rate=1, +): + raise NotImplementedError( + "`separable_conv` is not supported with openvino backend" + ) + + +def conv_transpose( + inputs, + kernel, + strides=1, + padding="valid", + output_padding=None, + data_format=None, + dilation_rate=1, +): + raise NotImplementedError( + "`conv_transpose` is not supported with openvino backend" + ) + + +def one_hot(x, num_classes, axis=-1, dtype=None, sparse=False): + raise NotImplementedError( + "`one_hot` is not supported with openvino backend" + ) + + +def multi_hot(x, num_classes, axis=-1, dtype=None, sparse=False): + raise NotImplementedError( + "`multi_hot` is not supported with openvino backend" + ) + + +def categorical_crossentropy(target, output, from_logits=False, axis=-1): + raise NotImplementedError( + "`categorical_crossentropy` is not supported with openvino backend" + ) + + +def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1): + raise NotImplementedError( + "`sparse_categorical_crossentropy` is not supported " + "with openvino backend" + ) + + +def binary_crossentropy(target, output, from_logits=False): + raise NotImplementedError( + "`binary_crossentropy` is not supported with openvino backend" + ) + + +def moments(x, axes, keepdims=False, synchronized=False): + x = get_ov_output(x) + axes = ov_opset.constant(axes, Type.i32).output(0) + # The variance is computed using $Var = E[|x|^2] - |E[x]|^2$, It is faster + # but less numerically stable. + mean = ov_opset.reduce_mean(x, axes, keepdims).output(0) + const_two = ov_opset.constant(2, x.get_element_type()).output(0) + squared_x = ov_opset.power(x, const_two).output(0) + squared_mean = ov_opset.power(mean, const_two).output(0) + squared_x_mean = ov_opset.reduce_mean(squared_x, axes, keepdims) + mean = OpenVINOKerasTensor(mean) + variance = OpenVINOKerasTensor( + ov_opset.subtract(squared_x_mean, squared_mean).output(0) + ) + return mean, variance + + +def batch_normalization( + x, mean, variance, axis, offset=None, scale=None, epsilon=1e-3 +): + x = get_ov_output(x) + mean = get_ov_output(mean) + variance = get_ov_output(variance) + if offset is not None: + offset = get_ov_output(offset) + else: + mean_shape = ov_opset.shape_of(mean) + mean_type = mean.get_element_type() + zero_const = ov_opset.constant([0], mean_type) + offset = ov_opset.broadcast(zero_const, mean_shape) + if scale is not None: + scale = get_ov_output(scale) + else: + mean_shape = ov_opset.shape_of(mean) + mean_type = mean.get_element_type() + one_const = ov_opset.constant([1], mean_type) + scale = ov_opset.broadcast(one_const, mean_shape) + + # adjust x input to have the second dimension representing the channel axis + x_rank = x.get_partial_shape().rank.get_length() + if axis < 0: + axis += x_rank + if axis != 1: + perm_vector = list(range(0, x_rank)) + perm_vector[1] = axis + perm_vector[axis] = 1 + perm_vector = ov_opset.constant(perm_vector, Type.i32).output(0) + x = ov_opset.transpose(x, perm_vector).output(0) + batch_norm = ov_opset.batch_norm_inference( + x, scale, offset, mean, variance, epsilon + ).output(0) + if axis != 1: + perm_vector = list(range(0, x_rank)) + perm_vector[1] = axis + perm_vector[axis] = 1 + perm_vector = ov_opset.constant(perm_vector, Type.i32).output(0) + batch_norm = ov_opset.transpose(batch_norm, perm_vector).output(0) + return OpenVINOKerasTensor(batch_norm) + + +def ctc_loss(target, output, target_length, output_length, mask_index=0): + raise NotImplementedError( + "`ctc_loss` is not supported with openvino backend" + ) + + +def ctc_decode( + inputs, + sequence_lengths, + strategy="greedy", + beam_width=100, + top_paths=1, + merge_repeated=True, + mask_index=0, +): + raise NotImplementedError( + "`ctc_decode` is not supported with openvino backend" + ) + + +def psnr(x1, x2, max_val): + raise NotImplementedError("`psnr` is not supported with openvino backend") + + +def dot_product_attention( + query, + key, + value, + bias=None, + mask=None, + scale=None, + is_causal=False, + flash_attention=None, + attn_logits_soft_cap=None, +): + raise NotImplementedError( + "`dot_product_attention` is not supported with openvino backend" + ) + + +def unfold(input, kernel_size, dilation=1, padding=0, stride=1): + raise NotImplementedError("`unfold` is not supported with openvino backend") diff --git a/keras/src/backend/openvino/numpy.py b/keras/src/backend/openvino/numpy.py new file mode 100644 index 000000000000..c750d409d4a0 --- /dev/null +++ b/keras/src/backend/openvino/numpy.py @@ -0,0 +1,2412 @@ +import numpy as np +import openvino.opset14 as ov_opset +from openvino import Type + +from keras.src.backend import config +from keras.src.backend.common import dtypes +from keras.src.backend.common.variables import standardize_dtype +from keras.src.backend.openvino.core import DTYPES_MAX +from keras.src.backend.openvino.core import DTYPES_MIN +from keras.src.backend.openvino.core import OPENVINO_DTYPES +from keras.src.backend.openvino.core import OpenVINOKerasTensor +from keras.src.backend.openvino.core import ( + align_operand_types as _align_operand_types, +) +from keras.src.backend.openvino.core import convert_to_tensor +from keras.src.backend.openvino.core import get_ov_output +from keras.src.backend.openvino.core import ov_to_keras_type + + +def add(x1, x2): + element_type = None + if isinstance(x1, OpenVINOKerasTensor): + element_type = x1.output.get_element_type() + if isinstance(x2, OpenVINOKerasTensor): + element_type = x2.output.get_element_type() + x1 = get_ov_output(x1, element_type) + x2 = get_ov_output(x2, element_type) + x1, x2 = _align_operand_types(x1, x2, "add()") + return OpenVINOKerasTensor(ov_opset.add(x1, x2).output(0)) + + +def einsum(subscripts, *operands, **kwargs): + inputs = [] + for operand in operands: + operand = get_ov_output(operand) + inputs.append(operand) + return OpenVINOKerasTensor(ov_opset.einsum(inputs, subscripts).output(0)) + + +def subtract(x1, x2): + element_type = None + if isinstance(x1, OpenVINOKerasTensor): + element_type = x1.output.get_element_type() + if isinstance(x2, OpenVINOKerasTensor): + element_type = x2.output.get_element_type() + x1 = get_ov_output(x1, element_type) + x2 = get_ov_output(x2, element_type) + x1, x2 = _align_operand_types(x1, x2, "subtract()") + if x1.get_element_type() == Type.boolean: + return OpenVINOKerasTensor(ov_opset.logical_xor(x1, x2).output(0)) + return OpenVINOKerasTensor(ov_opset.subtract(x1, x2).output(0)) + + +def matmul(x1, x2): + element_type = None + if isinstance(x1, OpenVINOKerasTensor): + element_type = x1.output.get_element_type() + if isinstance(x2, OpenVINOKerasTensor): + element_type = x2.output.get_element_type() + x1 = get_ov_output(x1, element_type) + x2 = get_ov_output(x2, element_type) + x1, x2 = _align_operand_types(x1, x2, "matmul()") + return OpenVINOKerasTensor(ov_opset.matmul(x1, x2, False, False).output(0)) + + +def multiply(x1, x2): + element_type = None + if isinstance(x1, OpenVINOKerasTensor): + element_type = x1.output.get_element_type() + if isinstance(x2, OpenVINOKerasTensor): + element_type = x2.output.get_element_type() + x1 = get_ov_output(x1, element_type) + x2 = get_ov_output(x2, element_type) + x1, x2 = _align_operand_types(x1, x2, "multiply()") + return OpenVINOKerasTensor(ov_opset.multiply(x1, x2).output(0)) + + +def mean(x, axis=None, keepdims=False): + x = get_ov_output(x) + if axis is None: + flatten_shape = ov_opset.constant([-1], Type.i32).output(0) + x = ov_opset.reshape(x, flatten_shape, False).output(0) + axis = 0 + axis_const = ov_opset.constant(axis, dtype=Type.i32).output(0) + mean_ops = ov_opset.reduce_mean(x, axis_const, keepdims) + return OpenVINOKerasTensor(mean_ops.output(0)) + + +def max(x, axis=None, keepdims=False, initial=None): + assert initial is None, ( + "`max` with not None initial is not supported by openvino backend" + ) + x = get_ov_output(x) + reduce_axis = ov_opset.constant(axis, Type.i32).output(0) + return OpenVINOKerasTensor( + ov_opset.reduce_max(x, reduce_axis, keepdims).output(0) + ) + + +def ones(shape, dtype=None): + dtype = standardize_dtype(dtype) or config.floatx() + ov_type = OPENVINO_DTYPES[dtype] + const_one = ov_opset.constant(1, ov_type).output(0) + if isinstance(shape, tuple): + shape = list(shape) + elif isinstance(shape, int): + shape = [shape] + output_shape = ov_opset.constant(shape, dtype=Type.i32).output(0) + ones = ov_opset.broadcast(const_one, output_shape) + return OpenVINOKerasTensor(ones.output(0)) + + +def zeros(shape, dtype=None): + dtype = standardize_dtype(dtype) or config.floatx() + ov_type = OPENVINO_DTYPES[dtype] + const_zero = ov_opset.constant(0, dtype=ov_type).output(0) + if isinstance(shape, tuple): + shape = list(shape) + elif isinstance(shape, int): + shape = [shape] + output_shape = ov_opset.constant(shape, dtype=Type.i32).output(0) + zeros = ov_opset.broadcast(const_zero, output_shape) + return OpenVINOKerasTensor(zeros.output(0)) + + +def absolute(x): + x = get_ov_output(x) + return OpenVINOKerasTensor(ov_opset.absolute(x).output(0)) + + +def abs(x): + x = get_ov_output(x) + return OpenVINOKerasTensor(ov_opset.absolute(x).output(0)) + + +def all(x, axis=None, keepdims=False): + x = get_ov_output(x) + if axis is None: + flatten_shape = ov_opset.constant([-1], Type.i32).output(0) + x = ov_opset.reshape(x, flatten_shape, False).output(0) + axis = 0 + axis = ov_opset.constant(axis, Type.i32).output(0) + return OpenVINOKerasTensor( + ov_opset.reduce_logical_and(x, axis, keepdims).output(0) + ) + + +def angle(x): + raise NotImplementedError("`angle` is not supported with openvino backend") + + +def any(x, axis=None, keepdims=False): + x = get_ov_output(x) + if axis is None: + flatten_shape = ov_opset.constant([-1], Type.i32).output(0) + x = ov_opset.reshape(x, flatten_shape, False).output(0) + axis = 0 + axis = ov_opset.constant(axis, Type.i32).output(0) + return OpenVINOKerasTensor( + ov_opset.reduce_logical_or(x, axis, keepdims).output(0) + ) + + +def amax(x, axis=None, keepdims=False): + if axis == () or axis == []: + return x + x = get_ov_output(x) + x_type = x.get_element_type() + if axis is None: + flatten_shape = ov_opset.constant([-1], Type.i32).output(0) + x = ov_opset.reshape(x, flatten_shape, False).output(0) + axis = 0 + if isinstance(axis, tuple): + axis = list(axis) + axis = ov_opset.constant(axis, Type.i32).output(0) + if x_type == Type.boolean: + return OpenVINOKerasTensor( + ov_opset.reduce_logical_or(x, axis, keepdims).output(0) + ) + return OpenVINOKerasTensor(ov_opset.reduce_max(x, axis, keepdims).output(0)) + + +def amin(x, axis=None, keepdims=False): + if axis == () or axis == []: + return x + x = get_ov_output(x) + x_type = x.get_element_type() + if axis is None: + flatten_shape = ov_opset.constant([-1], Type.i32).output(0) + x = ov_opset.reshape(x, flatten_shape, False).output(0) + axis = 0 + if isinstance(axis, tuple): + axis = list(axis) + axis = ov_opset.constant(axis, Type.i32).output(0) + if x_type == Type.boolean: + return OpenVINOKerasTensor( + ov_opset.reduce_logical_and(x, axis, keepdims).output(0) + ) + return OpenVINOKerasTensor(ov_opset.reduce_min(x, axis, keepdims).output(0)) + + +def append(x1, x2, axis=None): + x1, x2 = get_ov_output(x1), get_ov_output(x2) + x1, x2 = _align_operand_types(x1, x2, "append()") + if axis is None: + flatten_shape = ov_opset.constant([-1], Type.i32).output(0) + x1 = ov_opset.reshape(x1, flatten_shape, False).output(0) + x2 = ov_opset.reshape(x2, flatten_shape, False).output(0) + axis = 0 + return OpenVINOKerasTensor(ov_opset.concat([x1, x2], axis).output(0)) + + +def arange(start, stop=None, step=None, dtype=None): + if stop is None: + start, stop = get_ov_output(0), get_ov_output(start) + else: + start, stop = get_ov_output(start), get_ov_output(stop) + + step = get_ov_output(1) if step is None else get_ov_output(step) + + ov_type = None + if dtype is not None: + ov_type = OPENVINO_DTYPES[standardize_dtype(dtype)] + else: + ov_type = OPENVINO_DTYPES[ + dtypes.result_type( + ov_to_keras_type(start.get_element_type()), + ov_to_keras_type(stop.get_element_type()), + ov_to_keras_type(step.get_element_type()), + "int32", + ) + ] + + start_node = ov_opset.convert(start, ov_type) + stop_node = ov_opset.convert(stop, ov_type) + step_node = ov_opset.convert(step, ov_type) + + return OpenVINOKerasTensor( + ov_opset.range(start_node, stop_node, step_node, ov_type).output(0) + ) + + +def arccos(x): + x = get_ov_output(x) + x_type = x.get_element_type() + if x_type.is_integral(): + ov_type = OPENVINO_DTYPES[config.floatx()] + x = ov_opset.convert(x, ov_type) + return OpenVINOKerasTensor(ov_opset.acos(x).output(0)) + + +def arccosh(x): + x = get_ov_output(x) + x_type = x.get_element_type() + if x_type.is_integral(): + ov_type = OPENVINO_DTYPES[config.floatx()] + x = ov_opset.convert(x, ov_type) + return OpenVINOKerasTensor(ov_opset.acosh(x).output(0)) + + +def arcsin(x): + x = get_ov_output(x) + x_type = x.get_element_type() + if x_type.is_integral(): + ov_type = OPENVINO_DTYPES[config.floatx()] + x = ov_opset.convert(x, ov_type) + return OpenVINOKerasTensor(ov_opset.asin(x).output(0)) + + +def arcsinh(x): + x = get_ov_output(x) + x_type = x.get_element_type() + if x_type.is_integral(): + ov_type = OPENVINO_DTYPES[config.floatx()] + x = ov_opset.convert(x, ov_type) + return OpenVINOKerasTensor(ov_opset.asinh(x).output(0)) + + +def arctan(x): + x = get_ov_output(x) + x_type = x.get_element_type() + if x_type.is_integral(): + ov_type = OPENVINO_DTYPES[config.floatx()] + x = ov_opset.convert(x, ov_type) + return OpenVINOKerasTensor(ov_opset.atan(x).output(0)) + + +def arctan2(x1, x2): + x1 = get_ov_output(x1) + x2 = get_ov_output(x2) + + x1_type = ov_to_keras_type(x1.get_element_type()) + x2_type = ov_to_keras_type(x2.get_element_type()) + result_type = dtypes.result_type(x1_type, x2_type, float) + result_type = OPENVINO_DTYPES[result_type] + x1 = ov_opset.convert(x1, result_type) + x2 = ov_opset.convert(x2, result_type) + + x = ov_opset.divide(x1, x2) + y = ov_opset.atan(x) + + ov_type = x1.get_element_type() + pi = ov_opset.constant(float(np.pi), ov_type) + half_pi = ov_opset.constant(float(np.pi / 2), ov_type) + neg_half_pi = ov_opset.constant(-float(np.pi / 2), ov_type) + zero_const = ov_opset.constant(0.0, ov_type) + + cond_x2_gt0 = ov_opset.greater(x2, zero_const).output(0) + cond_x2_lt0 = ov_opset.less(x2, zero_const).output(0) + + cond_x1_ge0 = ov_opset.greater_equal(x1, zero_const).output(0) + cond_x1_gt0 = ov_opset.greater(x1, zero_const).output(0) + cond_x1_eq0 = ov_opset.equal(x1, zero_const).output(0) + + out_x2_lt0 = ov_opset.select( + cond_x1_ge0, + ov_opset.add(y, pi), + ov_opset.subtract(y, pi), + ) + + out_x1_zero = ov_opset.select(cond_x1_eq0, zero_const, neg_half_pi) + out_x2_zero = ov_opset.select(cond_x1_gt0, half_pi, out_x1_zero) + + out_not_pos = ov_opset.select(cond_x2_lt0, out_x2_lt0, out_x2_zero) + + final_out = ov_opset.select(cond_x2_gt0, y, out_not_pos) + return OpenVINOKerasTensor(final_out.output(0)) + + +def arctanh(x): + x = get_ov_output(x) + x_type = x.get_element_type() + if x_type.is_integral(): + ov_type = OPENVINO_DTYPES[config.floatx()] + x = ov_opset.convert(x, ov_type) + return OpenVINOKerasTensor(ov_opset.atanh(x).output(0)) + + +def argmax(x, axis=None, keepdims=False): + x = get_ov_output(x) + x_shape = x.get_partial_shape() + rank = x_shape.rank.get_length() + if rank == 0: + return OpenVINOKerasTensor(ov_opset.constant([0], Type.i32).output(0)) + if axis is None: + flatten_shape = ov_opset.constant( + [-1] + [1] * (rank - 1), Type.i32 + ).output(0) + x = ov_opset.reshape(x, flatten_shape, False).output(0) + axis = 0 + k = ov_opset.constant(1, Type.i32).output(0) + else: + if axis < 0: + axis = rank + axis + k = ov_opset.constant(1, Type.i32).output(0) + topk_outputs = ov_opset.topk( + x, + k=k, + axis=axis, + mode="max", + sort="value", + stable=True, + index_element_type=Type.i32, + ) + topk_indices = topk_outputs.output(1) + if not keepdims: + topk_indices = ov_opset.squeeze(topk_indices, [axis]).output(0) + return OpenVINOKerasTensor(topk_indices) + + +def argmin(x, axis=None, keepdims=False): + x = get_ov_output(x) + x_shape = x.get_partial_shape() + rank = x_shape.rank.get_length() + if rank == 0: + return OpenVINOKerasTensor(ov_opset.constant([0], Type.i32).output(0)) + if axis is None: + flatten_shape = ov_opset.constant( + [-1] + [1] * (rank - 1), Type.i32 + ).output(0) + x = ov_opset.reshape(x, flatten_shape, False).output(0) + axis = 0 + k = ov_opset.constant(1, Type.i32).output(0) + else: + if axis < 0: + axis = rank + axis + k = ov_opset.constant(1, Type.i32).output(0) + topk_outputs = ov_opset.topk( + x, + k=k, + axis=axis, + mode="min", + sort="value", + stable=True, + index_element_type=Type.i32, + ) + topk_indices = topk_outputs.output(1) + if not keepdims: + topk_indices = ov_opset.squeeze(topk_indices, [axis]).output(0) + return OpenVINOKerasTensor(topk_indices) + + +def argsort(x, axis=-1): + x = get_ov_output(x) + x_shape = x.get_partial_shape() + rank = x_shape.rank.get_length() + if rank == 0: + return OpenVINOKerasTensor(ov_opset.constant([0], Type.i32).output(0)) + if axis is None: + flatten_shape = ov_opset.constant([-1], Type.i32).output(0) + x = ov_opset.reshape(x, flatten_shape, False).output(0) + x_shape_tensor = ov_opset.shape_of(x, Type.i32).output(0) + k = ov_opset.reduce_prod( + x_shape_tensor, ov_opset.constant([0], Type.i32), keep_dims=False + ) + axis = 0 + else: + if axis < 0: + axis = rank + axis + x_shape_tensor = ov_opset.shape_of(x, Type.i32).output(0) + k = ov_opset.gather( + x_shape_tensor, + ov_opset.constant(axis, Type.i32).output(0), + ov_opset.constant(0, Type.i32).output(0), + ).output(0) + sorted_indices = ov_opset.topk( + x, + k=k, + axis=axis, + mode="min", + sort="value", + ).output(1) + return OpenVINOKerasTensor(sorted_indices) + + +def array(x, dtype=None): + if dtype is not None: + return np.array(x, dtype=dtype) + return np.array(x) + + +def average(x, axis=None, weights=None): + x = get_ov_output(x) + if weights is not None: + weights = get_ov_output(weights) + if axis is None: + flatten_shape = ov_opset.constant([-1], Type.i32).output(0) + x = ov_opset.reshape(x, flatten_shape, False).output(0) + if weights is not None: + weights = ov_opset.reshape(weights, flatten_shape, False).output(0) + axis = 0 + + if weights is not None: + x_type = x.get_element_type() + weights_type = weights.get_element_type() + if (weights_type.is_integral() or weights_type == Type.boolean) and ( + x_type.is_integral() or x_type == Type.boolean + ): + x = ov_opset.convert(x, Type.f32).output(0) + weights = ov_opset.convert(weights, Type.f32).output(0) + x, weights = _align_operand_types(x, weights, "multiply()") + x = ov_opset.multiply(x, weights) + + if isinstance(axis, tuple): + axis = list(axis) + if axis == []: + return OpenVINOKerasTensor(x) + + axis_const = ov_opset.constant(axis, dtype=Type.i32).output(0) + mean_ops = ov_opset.reduce_mean(x, axis_const, False) + return OpenVINOKerasTensor(mean_ops.output(0)) + + +def bartlett(x): + raise NotImplementedError( + "`bartlett` is not supported with openvino backend" + ) + + +def hamming(x): + raise NotImplementedError( + "`hamming` is not supported with openvino backend" + ) + + +def heaviside(x1, x2): + raise NotImplementedError( + "`heaviside` is not supported with openvino backend" + ) + + +def kaiser(x, beta): + raise NotImplementedError("`kaiser` is not supported with openvino backend") + + +def bincount(x, weights=None, minlength=0, sparse=False): + if x is None: + raise ValueError("input x is None") + if sparse: + raise ValueError("Unsupported value `sparse=True`") + x = get_ov_output(x) + x_type = x.get_element_type() + shape_x = ov_opset.shape_of(x, "i64").output(0) + rank_x = ov_opset.shape_of(shape_x, "i64").output(0) + rank_x = ov_opset.convert(rank_x, x_type).output(0) + scalar_shape = ov_opset.constant([], x_type).output(0) + rank_x = ov_opset.reshape(rank_x, scalar_shape, False).output(0) + const_minus_one = ov_opset.constant(-1, x_type).output(0) + rank_minus_one = ov_opset.add(rank_x, const_minus_one).output(0) + minlength = get_ov_output(minlength) + minlength = ov_opset.convert(minlength, x_type).output(0) + const_one = ov_opset.constant(1, x_type).output(0) + const_zero = ov_opset.constant(0, x_type).output(0) + max_element = ov_opset.reduce_max(x, const_zero, keep_dims=False).output(0) + depth = ov_opset.add(max_element, const_one).output(0) + depth = ov_opset.maximum(depth, minlength).output(0) + depth_scalar = ov_opset.reduce_max( + depth, const_zero, keep_dims=False + ).output(0) + one_hot = ov_opset.one_hot( + x, depth_scalar, const_one, const_zero, axis=-1 + ).output(0) + if weights is not None: + weights = get_ov_output(weights) + weights_type = weights.get_element_type() + weights_new = ov_opset.reshape(weights, [-1, 1], False).output(0) + one_hot = ov_opset.convert(one_hot, weights_type).output(0) + final_one_hot = ov_opset.multiply(one_hot, weights_new).output(0) + final_output = ov_opset.reduce_sum( + final_one_hot, rank_minus_one, keep_dims=False + ).output(0) + return OpenVINOKerasTensor(final_output) + else: + final_output = ov_opset.reduce_sum( + one_hot, rank_minus_one, keep_dims=False + ).output(0) + final_output = ov_opset.convert(final_output, Type.i32).output(0) + return OpenVINOKerasTensor(final_output) + + +def blackman(x): + raise NotImplementedError( + "`blackman` is not supported with openvino backend" + ) + + +def broadcast_to(x, shape): + assert isinstance(shape, (tuple, list)), ( + "`broadcast_to` is supported only for tuple and list `shape`" + ) + target_shape = ov_opset.constant(list(shape), Type.i32).output(0) + x = get_ov_output(x) + return OpenVINOKerasTensor(ov_opset.broadcast(x, target_shape).output(0)) + + +def cbrt(x): + raise NotImplementedError("`cbrt` is not supported with openvino backend") + + +def ceil(x): + x = get_ov_output(x) + return OpenVINOKerasTensor(ov_opset.ceil(x).output(0)) + + +def clip(x, x_min, x_max): + x = get_ov_output(x) + x_min = get_ov_output(x_min, x.get_element_type()) + x_max = get_ov_output(x_max, x.get_element_type()) + clip_by_min = ov_opset.maximum(x, x_min).output(0) + clip_by_max = ov_opset.minimum(clip_by_min, x_max).output(0) + return OpenVINOKerasTensor(clip_by_max) + + +def concatenate(xs, axis=0): + assert isinstance(xs, list), "`concatenate` is supported only for `x` list" + elems = [] + for elem in xs: + elem = get_ov_output(elem) + elems.append(elem) + res = ov_opset.concat(elems, axis).output(0) + return OpenVINOKerasTensor(res) + + +def conjugate(x): + raise NotImplementedError( + "`conjugate` is not supported with openvino backend" + ) + + +def conj(x): + raise NotImplementedError("`conj` is not supported with openvino backend") + + +def copy(x): + return x + + +def cos(x): + x = get_ov_output(x) + x_type = x.get_element_type() + if x_type.is_integral(): + ov_type = OPENVINO_DTYPES[config.floatx()] + x = ov_opset.convert(x, ov_type) + return OpenVINOKerasTensor(ov_opset.cos(x).output(0)) + + +def cosh(x): + x = get_ov_output(x) + x_type = x.get_element_type() + if x_type.is_integral(): + ov_type = OPENVINO_DTYPES[config.floatx()] + x = ov_opset.convert(x, ov_type) + return OpenVINOKerasTensor(ov_opset.cosh(x).output(0)) + + +def count_nonzero(x, axis=None): + x = get_ov_output(x) + zero_constant = ov_opset.constant(0, dtype=Type.i32).output(0) + zero_constant = ov_opset.convert_like(zero_constant, x) + x = ov_opset.not_equal(x, zero_constant).output(0) + x = ov_opset.convert(x, Type.i32).output(0) + if axis is None: + flatten_shape = ov_opset.constant([-1], Type.i32).output(0) + x = ov_opset.reshape(x, flatten_shape, False).output(0) + axis = 0 + if isinstance(axis, tuple): + axis = list(axis) + if axis == []: + return OpenVINOKerasTensor(x) + axis = ov_opset.constant(axis, Type.i32).output(0) + return OpenVINOKerasTensor(ov_opset.reduce_sum(x, axis, False).output(0)) + + +def cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=None): + raise NotImplementedError("`cross` is not supported with openvino backend") + + +def cumprod(x, axis=None, dtype=None): + raise NotImplementedError( + "`cumprod` is not supported with openvino backend" + ) + + +def cumsum(x, axis=None, dtype=None): + x = get_ov_output(x) + if dtype is not None: + ov_type = OPENVINO_DTYPES[standardize_dtype(dtype)] + x = ov_opset.convert(x, ov_type).output(0) + if axis is None: + flatten_shape = ov_opset.constant([-1], Type.i32).output(0) + x = ov_opset.reshape(x, flatten_shape, False).output(0) + axis = 0 + axis = ov_opset.constant(axis, Type.i32).output(0) + return OpenVINOKerasTensor(ov_opset.cumsum(x, axis).output(0)) + + +def deg2rad(x): + x = get_ov_output(x) + x_type = x.get_element_type() + pi_over_180 = np.pi / 180.0 + + if x_type == Type.i64: + output_type = Type.f64 + elif x_type.is_integral(): + output_type = OPENVINO_DTYPES[config.floatx()] + else: + output_type = x_type + + if x_type != output_type: + x = ov_opset.convert(x, output_type) + + const_pi_over_180 = ov_opset.constant(pi_over_180, output_type).output(0) + result = ov_opset.multiply(x, const_pi_over_180).output(0) + + return OpenVINOKerasTensor(result) + + +def diag(x, k=0): + raise NotImplementedError("`diag` is not supported with openvino backend") + + +def diagonal(x, offset=0, axis1=0, axis2=1): + raise NotImplementedError( + "`diagonal` is not supported with openvino backend" + ) + + +def diff(a, n=1, axis=-1): + if n == 0: + return OpenVINOKerasTensor(get_ov_output(a)) + if n < 0: + raise ValueError(f"order must be non-negative but got {repr(n)}") + a = get_ov_output(a) + a_type = a.get_element_type() + if isinstance(a, np.ndarray): + rank = a.ndim + else: + rank = a.get_partial_shape().rank.get_length() + if axis < 0: + axis = axis + rank + result = a + for _ in range(n): + rank = result.get_partial_shape().rank.get_length() + strides = ov_opset.constant( + np.array([1] * rank, dtype=np.int64), Type.i64 + ).output(0) + + begin_upper_list = [0] * rank + begin_upper_list[axis] = 1 + begin_upper = ov_opset.constant( + np.array(begin_upper_list, dtype=np.int64), Type.i64 + ).output(0) + end_upper = ov_opset.constant( + np.array([0] * rank, dtype=np.int64), Type.i64 + ).output(0) + begin_mask_upper = [1] * rank + begin_mask_upper[axis] = 0 + end_mask_upper = [1] * rank + upper = ov_opset.strided_slice( + data=result, + begin=begin_upper, + end=end_upper, + strides=strides, + begin_mask=begin_mask_upper, + end_mask=end_mask_upper, + new_axis_mask=[], + shrink_axis_mask=[], + ellipsis_mask=[], + ).output(0) + + begin_lower = ov_opset.constant( + np.array([0] * rank, dtype=np.int64), Type.i64 + ).output(0) + end_lower_list = [0] * rank + end_lower_list[axis] = -1 + end_lower = ov_opset.constant( + np.array(end_lower_list, dtype=np.int64), Type.i64 + ).output(0) + begin_mask_lower = [1] * rank + end_mask_lower = [1] * rank + end_mask_lower[axis] = 0 + lower = ov_opset.strided_slice( + data=result, + begin=begin_lower, + end=end_lower, + strides=strides, + begin_mask=begin_mask_lower, + end_mask=end_mask_lower, + new_axis_mask=[], + shrink_axis_mask=[], + ellipsis_mask=[], + ).output(0) + + if a_type == Type.boolean: + result = ov_opset.not_equal(upper, lower).output(0) + else: + result = ov_opset.subtract(upper, lower).output(0) + return OpenVINOKerasTensor(result) + + +def digitize(x, bins): + raise NotImplementedError( + "`digitize` is not supported with openvino backend" + ) + + +def dot(x1, x2): + element_type = None + if isinstance(x1, OpenVINOKerasTensor): + element_type = x1.output.get_element_type() + if isinstance(x2, OpenVINOKerasTensor): + element_type = x2.output.get_element_type() + x1 = get_ov_output(x1, element_type) + x2 = get_ov_output(x2, element_type) + x1, x2 = _align_operand_types(x1, x2, "dot()") + if x1.get_partial_shape().rank == 0 or x2.get_partial_shape().rank == 0: + return OpenVINOKerasTensor(ov_opset.multiply(x1, x2).output(0)) + return OpenVINOKerasTensor(ov_opset.matmul(x1, x2, False, False).output(0)) + + +def empty(shape, dtype=None): + dtype = standardize_dtype(dtype) or config.floatx() + ov_type = OPENVINO_DTYPES[dtype] + if isinstance(shape, tuple): + shape = list(shape) + elif isinstance(shape, int): + shape = [shape] + shape_node = ov_opset.constant(shape, Type.i32).output(0) + const_zero = ov_opset.constant(0, dtype=ov_type).output(0) + empty_tensor = ov_opset.broadcast(const_zero, shape_node).output(0) + return OpenVINOKerasTensor(empty_tensor) + + +def equal(x1, x2): + element_type = None + if isinstance(x1, OpenVINOKerasTensor): + element_type = x1.output.get_element_type() + if isinstance(x2, OpenVINOKerasTensor): + element_type = x2.output.get_element_type() + x1 = get_ov_output(x1, element_type) + x2 = get_ov_output(x2, element_type) + x1, x2 = _align_operand_types(x1, x2, "equal()") + return OpenVINOKerasTensor(ov_opset.equal(x1, x2).output(0)) + + +def exp(x): + x = get_ov_output(x) + x_type = x.get_element_type() + if x_type.is_integral(): + ov_type = OPENVINO_DTYPES[config.floatx()] + x = ov_opset.convert(x, ov_type) + return OpenVINOKerasTensor(ov_opset.exp(x).output(0)) + + +def expand_dims(x, axis): + x = get_ov_output(x) + if isinstance(axis, tuple): + axis = list(axis) + axis = ov_opset.constant(axis, Type.i32).output(0) + return OpenVINOKerasTensor(ov_opset.unsqueeze(x, axis).output(0)) + + +def expm1(x): + x = get_ov_output(x) + x_type = x.get_element_type() + if x_type.is_integral(): + ov_type = OPENVINO_DTYPES[config.floatx()] + x = ov_opset.convert(x, ov_type) + exp_x = ov_opset.exp(x).output(0) + const_one = ov_opset.constant(1, exp_x.get_element_type()) + result = ov_opset.subtract(exp_x, const_one).output(0) + return OpenVINOKerasTensor(result) + + +def flip(x, axis=None): + raise NotImplementedError("`flip` is not supported with openvino backend") + + +def floor(x): + x = get_ov_output(x) + return OpenVINOKerasTensor(ov_opset.floor(x).output(0)) + + +def full(shape, fill_value, dtype=None): + dtype = standardize_dtype(dtype) or config.floatx() + ov_type = OPENVINO_DTYPES[dtype] + fill_value = get_ov_output(fill_value, ov_type) + if isinstance(shape, tuple): + shape = list(shape) + target_shape = ov_opset.constant(shape, Type.i32) + return OpenVINOKerasTensor( + ov_opset.broadcast(fill_value, target_shape).output(0) + ) + + +def full_like(x, fill_value, dtype=None): + x = get_ov_output(x) + shape_x = ov_opset.shape_of(x) + if dtype is not None: + ov_type = OPENVINO_DTYPES[standardize_dtype(dtype)] + else: + ov_type = x.get_element_type() + const_value = ov_opset.constant(fill_value, ov_type).output(0) + res = ov_opset.broadcast(const_value, shape_x).output(0) + return OpenVINOKerasTensor(res) + + +def gcd(x1, x2): + raise NotImplementedError("`gcd` is not supported with openvino backend") + + +def greater(x1, x2): + element_type = None + if isinstance(x1, OpenVINOKerasTensor): + element_type = x1.output.get_element_type() + if isinstance(x2, OpenVINOKerasTensor): + element_type = x2.output.get_element_type() + x1 = get_ov_output(x1, element_type) + x2 = get_ov_output(x2, element_type) + x1, x2 = _align_operand_types(x1, x2, "greater()") + return OpenVINOKerasTensor(ov_opset.greater(x1, x2).output(0)) + + +def greater_equal(x1, x2): + element_type = None + if isinstance(x1, OpenVINOKerasTensor): + element_type = x1.output.get_element_type() + if isinstance(x2, OpenVINOKerasTensor): + element_type = x2.output.get_element_type() + x1 = get_ov_output(x1, element_type) + x2 = get_ov_output(x2, element_type) + x1, x2 = _align_operand_types(x1, x2, "greater_equal()") + return OpenVINOKerasTensor(ov_opset.greater_equal(x1, x2).output(0)) + + +def hstack(xs): + if not isinstance(xs, (list, tuple)): + xs = (xs,) + elems = [convert_to_tensor(elem) for elem in xs] + element_type = elems[0].output.get_element_type() + elems = [get_ov_output(elem, element_type) for elem in elems] + is_1d = elems and len(elems[0].get_partial_shape().to_shape()) == 1 + axis = 0 if is_1d else 1 + for i in range(1, len(elems)): + elems[0], elems[i] = _align_operand_types( + elems[0], elems[i], "hstack()" + ) + return OpenVINOKerasTensor(ov_opset.concat(elems, axis).output(0)) + + +def hypot(x1, x2): + raise NotImplementedError("`hypot` is not supported with openvino backend") + + +def identity(n, dtype=None): + n = get_ov_output(n) + dtype = Type.f32 if dtype is None else dtype + if isinstance(dtype, str): + ov_dtype = OPENVINO_DTYPES[dtype] + else: + ov_dtype = dtype + n32 = ov_opset.convert(n, Type.i32).output(0) + identity_matrix = ov_opset.eye( + num_rows=n32, num_columns=n32, diagonal_index=0, output_type=ov_dtype + ) + return OpenVINOKerasTensor(identity_matrix.output(0)) + + +def imag(x): + raise NotImplementedError("`imag` is not supported with openvino backend") + + +def isclose(x1, x2, rtol=1e-5, atol=1e-8, equal_nan=False): + dtype = OPENVINO_DTYPES[config.floatx()] + + x1 = ov_opset.convert(get_ov_output(x1), dtype) + x2 = ov_opset.convert(get_ov_output(x2), dtype) + rtol = ov_opset.convert(get_ov_output(rtol), dtype) + atol = ov_opset.convert(get_ov_output(atol), dtype) + + abs_diff = ov_opset.abs(x1 - x2) + abs_x2 = ov_opset.abs(x2) + total_tolerance = atol + rtol * abs_x2 + is_close = ov_opset.less_equal(abs_diff, total_tolerance) + if equal_nan: + both_nan = ov_opset.logical_and(ov_opset.isnan(x1), ov_opset.isnan(x2)) + is_close = ov_opset.logical_or(is_close, both_nan) + + return OpenVINOKerasTensor(is_close.output(0)) + + +def isfinite(x): + x = get_ov_output(x) + return OpenVINOKerasTensor(ov_opset.is_finite(x).output(0)) + + +def isin(x1, x2, assume_unique=False, invert=False): + raise NotImplementedError("`isin` is not supported with openvino backend") + + +def isinf(x): + x = get_ov_output(x) + return OpenVINOKerasTensor(ov_opset.is_inf(x).output(0)) + + +def isnan(x): + x = get_ov_output(x) + return OpenVINOKerasTensor(ov_opset.is_nan(x).output(0)) + + +def isneginf(x): + x = get_ov_output(x) + x_type = x.get_element_type() + + if x_type.is_integral() or x_type == Type.boolean: + shape = ov_opset.shape_of(x, "i32").output(0) + false_const = ov_opset.constant(False, Type.boolean).output(0) + return OpenVINOKerasTensor( + ov_opset.broadcast(false_const, shape).output(0) + ) + + if x_type == Type.bf16: + x_f32 = ov_opset.convert(x, Type.f32).output(0) + neg_inf = ov_opset.constant(-np.inf, Type.f32).output(0) + is_neg_inf = ov_opset.equal(x_f32, neg_inf).output(0) + else: + if x_type == Type.f16: + neg_inf = ov_opset.constant(-np.inf, Type.f16).output(0) + elif x_type == Type.f32: + neg_inf = ov_opset.constant(-np.inf, Type.f32).output(0) + elif x_type == Type.f64: + neg_inf = ov_opset.constant(-np.inf, Type.f64).output(0) + else: + neg_inf = ov_opset.constant(-np.inf, Type.f32).output(0) + is_neg_inf = ov_opset.equal(x, neg_inf).output(0) + + return OpenVINOKerasTensor(is_neg_inf) + + +def isposinf(x): + raise NotImplementedError( + "`isposinf` is not supported with openvino backend" + ) + + +def kron(x1, x2): + raise NotImplementedError("`kron` is not supported with openvino backend") + + +def lcm(x1, x2): + raise NotImplementedError("`lcm` is not supported with openvino backend") + + +def less(x1, x2): + element_type = None + if isinstance(x1, OpenVINOKerasTensor): + element_type = x1.output.get_element_type() + if isinstance(x2, OpenVINOKerasTensor): + element_type = x2.output.get_element_type() + x1 = get_ov_output(x1, element_type) + x2 = get_ov_output(x2, element_type) + x1, x2 = _align_operand_types(x1, x2, "less()") + return OpenVINOKerasTensor(ov_opset.less(x1, x2).output(0)) + + +def less_equal(x1, x2): + element_type = None + if isinstance(x1, OpenVINOKerasTensor): + element_type = x1.output.get_element_type() + if isinstance(x2, OpenVINOKerasTensor): + element_type = x2.output.get_element_type() + x1 = get_ov_output(x1, element_type) + x2 = get_ov_output(x2, element_type) + x1, x2 = _align_operand_types(x1, x2, "less_equal()") + return OpenVINOKerasTensor(ov_opset.less_equal(x1, x2).output(0)) + + +def linspace( + start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0 +): + """Return evenly spaced numbers over a specified interval. + + Supports axis=0 (prepend) and axis=-1 (append). Intermediate axis values are + treated as axis=-1. + + If `retstep` is True, also returns the step size between values. + + """ + + start = get_ov_output(start) + stop = get_ov_output(stop) + + if hasattr(num, "output") or isinstance(num, OpenVINOKerasTensor): + num_tensor = get_ov_output(num) + try: + if num_tensor.get_node().get_type_name() == "Constant": + num_value = num_tensor.get_node().get_vector()[0] + num = int(num_value) + else: + raise NotImplementedError( + "Dynamic num values not fully supported" + ) + except Exception as e: + raise NotImplementedError( + "Could not extract num value from tensor" + ) from e + else: + num = int(num) + + if dtype is None: + output_type = OPENVINO_DTYPES[config.floatx()] + else: + output_type = OPENVINO_DTYPES[dtype] + + start = ov_opset.convert(start, output_type).output(0) + stop = ov_opset.convert(stop, output_type).output(0) + + if num < 0: + raise ValueError("Number of samples, `num`, must be non-negative.") + + if num == 0: + empty_shape = ov_opset.constant([0], Type.i32).output(0) + result = ov_opset.broadcast( + ov_opset.constant(0.0, output_type).output(0), empty_shape + ).output(0) + if retstep: + nan_step = ov_opset.constant(np.nan, output_type).output(0) + return OpenVINOKerasTensor(result), OpenVINOKerasTensor(nan_step) + return OpenVINOKerasTensor(result) + + if num == 1: + result_val = start + axis_const = ov_opset.constant([axis], Type.i32).output(0) + result = ov_opset.unsqueeze(result_val, axis_const).output(0) + if retstep: + if endpoint: + step = ov_opset.constant(np.nan, output_type).output(0) + else: + step = ov_opset.subtract(stop, start).output(0) + return OpenVINOKerasTensor(result), OpenVINOKerasTensor(step) + zero_i32 = ov_opset.constant(0, Type.i32).output(0) + one_i32 = ov_opset.constant(1, Type.i32).output(0) + one_i32_array = ov_opset.constant([1], Type.i32).output(0) + + num_const = ov_opset.constant(num, output_type).output(0) + + if endpoint: + divisor = ov_opset.subtract( + num_const, ov_opset.constant(1, output_type).output(0) + ).output(0) + else: + divisor = num_const + + step = ov_opset.divide( + ov_opset.subtract(stop, start).output(0), divisor + ).output(0) + + indices = ov_opset.range( + zero_i32, + ov_opset.constant(num, Type.i32).output(0), + one_i32, + output_type, + ).output(0) + + start_shape = ov_opset.convert( + ov_opset.shape_of(start).output(0), Type.i32 + ).output(0) + indices_shape = ov_opset.convert( + ov_opset.shape_of(indices).output(0), Type.i32 + ).output(0) + + start_rank = ov_opset.shape_of(start_shape).output(0) + ones_for_start = ov_opset.broadcast(one_i32, start_rank).output(0) + + if axis == 0: + indices_target_shape = ov_opset.concat( + [indices_shape, ones_for_start], 0 + ).output(0) + start_target_shape = ov_opset.concat( + [one_i32_array, start_shape], 0 + ).output(0) + else: + indices_target_shape = ov_opset.concat( + [ones_for_start, indices_shape], 0 + ).output(0) + start_target_shape = ov_opset.concat( + [start_shape, one_i32_array], 0 + ).output(0) + + indices_reshaped = ov_opset.reshape( + indices, indices_target_shape, False + ).output(0) + start_reshaped = ov_opset.reshape(start, start_target_shape, False).output( + 0 + ) + step_reshaped = ov_opset.reshape(step, start_target_shape, False).output(0) + + scaled_indices = ov_opset.multiply(indices_reshaped, step_reshaped).output( + 0 + ) + result = ov_opset.add(start_reshaped, scaled_indices).output(0) + + if retstep: + return OpenVINOKerasTensor(result), OpenVINOKerasTensor(step) + return OpenVINOKerasTensor(result) + + +def log(x): + x = get_ov_output(x) + x_type = x.get_element_type() + if x_type.is_integral(): + x_type = OPENVINO_DTYPES[config.floatx()] + x = ov_opset.convert(x, x_type) + return OpenVINOKerasTensor(ov_opset.log(x).output(0)) + + +def log10(x): + x = get_ov_output(x) + x_type = x.get_element_type() + if x_type.is_integral(): + x_type = OPENVINO_DTYPES[config.floatx()] + x = ov_opset.convert(x, x_type) + log_x = ov_opset.log(x).output(0) + const_10 = ov_opset.constant(10, x_type).output(0) + log_10 = ov_opset.log(const_10).output(0) + result = ov_opset.divide(log_x, log_10).output(0) + return OpenVINOKerasTensor(result) + + +def log1p(x): + x = get_ov_output(x) + x_type = x.get_element_type() + + if x_type.is_integral(): + x_type = OPENVINO_DTYPES[config.floatx()] + x = ov_opset.convert(x, x_type) + + one_const = ov_opset.constant(1, x_type).output(0) + added = ov_opset.add(x, one_const).output(0) + result = ov_opset.log(added).output(0) + return OpenVINOKerasTensor(result) + + +def log2(x): + x = get_ov_output(x) + x_type = x.get_element_type() + if x_type.is_integral(): + x_type = OPENVINO_DTYPES[config.floatx()] + x = ov_opset.convert(x, x_type) + log_x = ov_opset.log(x).output(0) + const_2 = ov_opset.constant(2, x_type).output(0) + log_2 = ov_opset.log(const_2).output(0) + result = ov_opset.divide(log_x, log_2).output(0) + return OpenVINOKerasTensor(result) + + +def logaddexp(x1, x2): + element_type = None + if isinstance(x1, OpenVINOKerasTensor): + element_type = x1.output.get_element_type() + if isinstance(x2, OpenVINOKerasTensor): + element_type = x2.output.get_element_type() + x1 = get_ov_output(x1, element_type) + x2 = get_ov_output(x2, element_type) + x1, x2 = _align_operand_types(x1, x2, "logaddexp()") + + if x1.element_type.is_integral() or x2.element_type.is_integral(): + float_dtype = OPENVINO_DTYPES[config.floatx()] + if x1.element_type.is_integral(): + x1 = ov_opset.convert(x1, float_dtype) + if x2.element_type.is_integral(): + x2 = ov_opset.convert(x2, float_dtype) + + # Get the output nodes properly + max_val_node = ov_opset.maximum(x1, x2) + max_val = max_val_node.output(0) + + # Compute absolute difference + sub_node = ov_opset.subtract(x1, x2) + abs_diff_node = ov_opset.abs(sub_node.output(0)) + abs_diff = abs_diff_node.output(0) + + # Compute negative absolute difference and its exponential + neg_abs_diff_node = ov_opset.negative(abs_diff) + neg_abs_diff = neg_abs_diff_node.output(0) + exp_neg_abs_node = ov_opset.exp(neg_abs_diff) + exp_neg_abs = exp_neg_abs_node.output(0) + + # Get the element type from the node, not the output + element_type = exp_neg_abs_node.get_element_type() + one_node = ov_opset.constant(1, element_type) + one = one_node.output(0) + + # Compute log term + one_plus_exp_node = ov_opset.add(one, exp_neg_abs) + one_plus_exp = one_plus_exp_node.output(0) + log_term_node = ov_opset.log(one_plus_exp) + log_term = log_term_node.output(0) + + # Final result + result_node = ov_opset.add(max_val, log_term) + result = result_node.output(0) + + return OpenVINOKerasTensor(result) + + +def logaddexp2(x1, x2): + raise NotImplementedError( + "`logaddexp2` is not supported with openvino backend" + ) + + +def logical_and(x1, x2): + x1 = get_ov_output(x1) + x2 = get_ov_output(x2) + x1 = ov_opset.convert(x1, Type.boolean).output(0) + x2 = ov_opset.convert(x2, Type.boolean).output(0) + return OpenVINOKerasTensor(ov_opset.logical_and(x1, x2).output(0)) + + +def logical_not(x): + x = get_ov_output(x) + x = ov_opset.convert(x, Type.boolean).output(0) + return OpenVINOKerasTensor(ov_opset.logical_not(x).output(0)) + + +def logical_or(x1, x2): + x1 = get_ov_output(x1) + x2 = get_ov_output(x2) + x1 = ov_opset.convert(x1, Type.boolean).output(0) + x2 = ov_opset.convert(x2, Type.boolean).output(0) + return OpenVINOKerasTensor(ov_opset.logical_or(x1, x2).output(0)) + + +def logspace(start, stop, num=50, endpoint=True, base=10, dtype=None, axis=0): + linear_samples = linspace( + start=start, + stop=stop, + num=num, + endpoint=endpoint, + retstep=False, + dtype=dtype, + axis=axis, + ) + + if dtype is None: + output_type = OPENVINO_DTYPES[config.floatx()] + else: + output_type = OPENVINO_DTYPES[dtype] + + linear_output = get_ov_output(linear_samples) + base_tensor = get_ov_output(base) + + base_tensor = ov_opset.convert(base_tensor, output_type).output(0) + + result = ov_opset.power(base_tensor, linear_output).output(0) + + return OpenVINOKerasTensor(result) + + +def maximum(x1, x2): + x1 = get_ov_output(x1) + x2 = get_ov_output(x2) + x1, x2 = _align_operand_types(x1, x2, "maximum()") + return OpenVINOKerasTensor(ov_opset.maximum(x1, x2).output(0)) + + +def median(x, axis=None, keepdims=False): + x = get_ov_output(x) + x_shape = x.get_partial_shape() + rank = x_shape.rank.get_length() + + if rank == 0: + return OpenVINOKerasTensor(x) + + # Handle axis=None by flattening the input + flattened_all = False + if axis is None: + x = ov_opset.reshape(x, [-1], False).output(0) + axis = 0 + original_rank = rank + rank = 1 + flattened_all = True + else: + # Handle tuple axis - for median, we only support single axis + if isinstance(axis, (tuple, list)): + if len(axis) != 1: + raise ValueError("median only supports single axis reduction") + axis = axis[0] + + # Handle negative axis + if axis < 0: + axis = rank + axis + original_rank = rank + + # Get the size of the dimension to sort + shape_tensor = ov_opset.shape_of(x, output_type=Type.i32).output(0) + k = ov_opset.gather( + shape_tensor, + ov_opset.constant([axis], Type.i32).output(0), + ov_opset.constant(0, Type.i32).output(0), + ).output(0) + + # Convert k to a scalar value + k_scalar = ov_opset.squeeze(k, [0]).output(0) + + # Use topk with k=size_of_axis to get all elements sorted + topk_outputs = ov_opset.topk( + x, k=k_scalar, axis=axis, mode="min", sort="value", stable=True + ) + + # Get the sorted values + sorted_values = topk_outputs.output(0) + + # Convert to float for median calculation + x1_type = ov_to_keras_type(sorted_values.get_element_type()) + result_type = dtypes.result_type(x1_type, float) + result_type = OPENVINO_DTYPES[result_type] + sorted_values = ov_opset.convert(sorted_values, result_type).output(0) + + # Calculate median indices + # For odd length: median_idx = (k-1) // 2 + # For even length: we need indices (k//2 - 1) and k//2, then average + + k_minus_1 = ov_opset.subtract( + k_scalar, ov_opset.constant(1, Type.i32).output(0) + ).output(0) + k_div_2 = ov_opset.divide( + k_scalar, ov_opset.constant(2, Type.i32).output(0) + ).output(0) + k_minus_1_div_2 = ov_opset.divide( + k_minus_1, ov_opset.constant(2, Type.i32).output(0) + ).output(0) + + # Check if k is odd + k_mod_2 = ov_opset.mod( + k_scalar, ov_opset.constant(2, Type.i32).output(0) + ).output(0) + is_odd = ov_opset.equal( + k_mod_2, ov_opset.constant(1, Type.i32).output(0) + ).output(0) + + # For odd case: take the middle element + odd_idx = k_minus_1_div_2 + + # For even case: take average of two middle elements + even_idx1 = ov_opset.subtract( + k_div_2, ov_opset.constant(1, Type.i32).output(0) + ).output(0) + even_idx2 = k_div_2 + + # Gather elements for both cases + # Create gather indices tensor for the axis + gather_indices_odd = ov_opset.unsqueeze(odd_idx, [0]).output(0) + gather_indices_even1 = ov_opset.unsqueeze(even_idx1, [0]).output(0) + gather_indices_even2 = ov_opset.unsqueeze(even_idx2, [0]).output(0) + + # Gather the median elements + odd_result = ov_opset.gather( + sorted_values, + gather_indices_odd, + ov_opset.constant(axis, Type.i32).output(0), + ).output(0) + even_result1 = ov_opset.gather( + sorted_values, + gather_indices_even1, + ov_opset.constant(axis, Type.i32).output(0), + ).output(0) + even_result2 = ov_opset.gather( + sorted_values, + gather_indices_even2, + ov_opset.constant(axis, Type.i32).output(0), + ).output(0) + + # Average the two middle elements for even case + even_sum = ov_opset.add(even_result1, even_result2).output(0) + even_result = ov_opset.divide( + even_sum, ov_opset.constant(2.0, result_type).output(0) + ).output(0) + + # Select between odd and even results + median_result = ov_opset.select(is_odd, odd_result, even_result).output(0) + + # Remove the gathered dimension (squeeze) + median_result = ov_opset.squeeze(median_result, [axis]).output(0) + + # Handle keepdims + if keepdims: + if flattened_all: + # When axis=None, keepdims should restore all dimensions as 1 + ones_shape = ov_opset.constant( + [1] * original_rank, Type.i32 + ).output(0) + median_result = ov_opset.reshape( + median_result, ones_shape, False + ).output(0) + else: + median_result = ov_opset.unsqueeze(median_result, [axis]).output(0) + + return OpenVINOKerasTensor(median_result) + + +def meshgrid(*x, indexing="xy"): + if len(x) < 2: + raise ValueError( + "meshgrid requires at least 2 input arrays. " + f"Received: {len(x)} input array(s)." + ) + if indexing not in ("xy", "ij"): + raise ValueError("indexing must be either 'xy' or 'ij'") + + tensors = [get_ov_output(xi) for xi in x] + n = len(tensors) + + shapes = [ + ov_opset.shape_of(t, Type.i64).output(0) for t in tensors + ] # each is [Ni] + one = ov_opset.constant([1], Type.i64).output(0) + + if indexing == "xy": + shape_list = [shapes[1], shapes[0]] + shapes[2:] + out_shape = ov_opset.concat(shape_list, axis=0).output(0) + else: + out_shape = ov_opset.concat(shapes, axis=0).output(0) + + outputs = [] + for i, t in enumerate(tensors): + reshape_parts = [one] * n + if indexing == "xy": + if i == 0: + reshape_parts[1] = shapes[0] + elif i == 1: + reshape_parts[0] = shapes[1] + else: + reshape_parts[i] = shapes[i] + else: + reshape_parts[i] = shapes[i] + + reshape_shape = ov_opset.concat(reshape_parts, axis=0).output(0) + reshaped = ov_opset.reshape(t, reshape_shape, False).output(0) + broadcasted = ov_opset.broadcast(reshaped, out_shape).output(0) + outputs.append(OpenVINOKerasTensor(broadcasted)) + + return outputs + + +def min(x, axis=None, keepdims=False, initial=None): + x = get_ov_output(x) + original_type = x.get_element_type() + x_type = original_type + x_shape = x.get_partial_shape().to_shape() + + is_bool = x_type == Type.boolean + if is_bool: + x = ov_opset.convert(x, Type.i32).output(0) + x_type = Type.i32 + + if isinstance(axis, tuple) and len(axis) == 0: + return OpenVINOKerasTensor(x) + + if axis is None: + flatten_shape = ov_opset.constant([-1], Type.i32).output(0) + x = ov_opset.reshape(x, flatten_shape, False).output(0) + axis = 0 + + if isinstance(axis, tuple): + axis = list(axis) + + axis_const = ov_opset.constant(axis, Type.i32).output(0) + min_result = ov_opset.reduce_min(x, axis_const, keepdims).output(0) + + if initial is not None: + initial_tensor = ov_opset.constant(initial, x_type).output(0) + min_result = ov_opset.minimum(min_result, initial_tensor).output(0) + + if keepdims: + result_shape = [1] * len(x_shape) + min_result = ov_opset.reshape( + min_result, + ov_opset.constant(result_shape, Type.i32).output(0), + False, + ).output(0) + + if is_bool: + min_result = ov_opset.convert(min_result, Type.boolean).output(0) + + return OpenVINOKerasTensor(min_result) + + +def minimum(x1, x2): + x1 = get_ov_output(x1) + x2 = get_ov_output(x2) + x1, x2 = _align_operand_types(x1, x2, "minimum()") + return OpenVINOKerasTensor(ov_opset.minimum(x1, x2).output(0)) + + +def mod(x1, x2): + x1 = get_ov_output(x1) + x2 = get_ov_output(x2) + x1, x2 = _align_operand_types(x1, x2, "mod()") + return OpenVINOKerasTensor(ov_opset.floor_mod(x1, x2).output(0)) + + +def moveaxis(x, source, destination): + x = get_ov_output(x) + if isinstance(source, int): + source = [source] + if isinstance(destination, int): + destination = [destination] + + ndim = x.get_partial_shape().rank.get_length() + source = [axis if axis >= 0 else axis + ndim for axis in source] + destination = [axis if axis >= 0 else axis + ndim for axis in destination] + + axes = list(range(ndim)) + for src, dst in zip(source, destination): + axes.remove(src) + axes.insert(dst, src) + + axes_const = ov_opset.constant(axes, Type.i32).output(0) + return OpenVINOKerasTensor(ov_opset.transpose(x, axes_const).output(0)) + + +def nan_to_num(x, nan=0.0, posinf=None, neginf=None): + x = get_ov_output(x) + dtype = x.get_element_type() + if dtype.is_integral(): + return OpenVINOKerasTensor(x) + isfloat64 = True if dtype == Type.f64 else False + if isfloat64: # conversion to f32 due to https://github.com/openvinotoolkit/openvino/issues/30264 + x = ov_opset.convert(x, Type.f32).output(0) + dtype = Type.f32 + nan_val = ov_opset.constant(nan, dtype).output(0) + posinf_val = ov_opset.constant( + posinf if posinf is not None else DTYPES_MAX[dtype], dtype + ).output(0) + neginf_val = ov_opset.constant( + neginf if neginf is not None else DTYPES_MIN[dtype], dtype + ).output(0) + posinf_mask = ov_opset.is_inf( + x, + {"detect_positive": True, "detect_negative": False}, + ).output(0) + neginf_mask = ov_opset.is_inf( + x, + {"detect_positive": False, "detect_negative": True}, + ).output(0) + nan_mask = ov_opset.is_nan(x).output(0) + x = ov_opset.select(nan_mask, nan_val, x).output(0) + x = ov_opset.select(posinf_mask, posinf_val, x).output(0) + x = ov_opset.select(neginf_mask, neginf_val, x).output(0) + if isfloat64: + x = ov_opset.convert(x, Type.f64).output(0) + return OpenVINOKerasTensor(x) + + +def ndim(x): + x = get_ov_output(x) + shape_tensor = ov_opset.shape_of(x, Type.i64).output(0) + rank_tensor = ov_opset.shape_of(shape_tensor, Type.i64).output(0) + return OpenVINOKerasTensor(rank_tensor) + + +def nonzero(x): + x = get_ov_output(x) + res = ov_opset.non_zero(data=x, output_type="i32").output(0) + return OpenVINOKerasTensor(res) + + +def not_equal(x1, x2): + element_type = None + if isinstance(x1, OpenVINOKerasTensor): + element_type = x1.output.get_element_type() + if isinstance(x2, OpenVINOKerasTensor): + element_type = x2.output.get_element_type() + x1 = get_ov_output(x1, element_type) + x2 = get_ov_output(x2, element_type) + x1, x2 = _align_operand_types(x1, x2, "not_equal()") + return OpenVINOKerasTensor(ov_opset.not_equal(x1, x2).output(0)) + + +def zeros_like(x, dtype=None): + x = get_ov_output(x) + shape_x = ov_opset.shape_of(x) + if dtype is not None: + ov_type = OPENVINO_DTYPES[standardize_dtype(dtype)] + const_zero = ov_opset.constant(0, ov_type).output(0) + else: + const_zero = ov_opset.constant(0, x.get_element_type()).output(0) + res = ov_opset.broadcast(const_zero, shape_x).output(0) + return OpenVINOKerasTensor(res) + + +def ones_like(x, dtype=None): + x = get_ov_output(x) + shape_x = ov_opset.shape_of(x) + if dtype is not None: + ov_type = OPENVINO_DTYPES[standardize_dtype(dtype)] + const_one = ov_opset.constant(1, ov_type).output(0) + else: + const_one = ov_opset.constant(1, x.get_element_type()).output(0) + res = ov_opset.broadcast(const_one, shape_x).output(0) + return OpenVINOKerasTensor(res) + + +def outer(x1, x2): + x1 = get_ov_output(x1) + x2 = get_ov_output(x2) + + x1, x2 = _align_operand_types(x1, x2, "outer()") + + new_shape_x1 = ov_opset.constant([-1, 1], Type.i32).output(0) + new_shape_x2 = ov_opset.constant([1, -1], Type.i32).output(0) + + # Reshape directly from original tensors + x1_reshaped = ov_opset.reshape(x1, new_shape_x1, False).output(0) + x2_reshaped = ov_opset.reshape(x2, new_shape_x2, False).output(0) + + result = ov_opset.multiply(x1_reshaped, x2_reshaped).output(0) + + return OpenVINOKerasTensor(result) + + +def pad(x, pad_width, mode="constant", constant_values=None): + x = get_ov_output(x) + pad_value = None + if constant_values is not None: + if mode != "constant": + raise ValueError( + "Argument `constant_values` can only be " + "provided when `mode == 'constant'`. " + f"Received: mode={mode}" + ) + assert isinstance(constant_values, int), ( + "`pad` operation supports only scalar pad value " + "in constant mode by openvino backend" + ) + pad_value = constant_values + + # split pad_width into two tensors pads_begin and pads_end + pads_begin = [] + pads_end = [] + for pads_pair in pad_width: + pads_begin.append(pads_pair[0]) + pads_end.append(pads_pair[1]) + pads_begin = ov_opset.constant(pads_begin, Type.i32).output(0) + pads_end = ov_opset.constant(pads_end, Type.i32).output(0) + return OpenVINOKerasTensor( + ov_opset.pad(x, pads_begin, pads_end, mode, pad_value).output(0) + ) + + +def prod(x, axis=None, keepdims=False, dtype=None): + x = get_ov_output(x) + + # If a specific dtype is requested, cast the input to that dtype. + if dtype is not None: + ov_dtype = OPENVINO_DTYPES[standardize_dtype(dtype)] + x = ov_opset.convert(x, ov_dtype).output(0) + # Otherwise, apply dtype promotion rules before reduction. + else: + x_type = x.get_element_type() + if x_type == Type.boolean: + x = ov_opset.convert(x, Type.i32).output(0) + elif x_type in (Type.i8, Type.i16): + x = ov_opset.convert(x, Type.i32).output(0) + elif x_type in (Type.u8, Type.u16): + x = ov_opset.convert(x, Type.u32).output(0) + + if axis is None: + flatten_shape = ov_opset.constant([-1], Type.i32).output(0) + x = ov_opset.reshape(x, flatten_shape, False).output(0) + axis = 0 + + if isinstance(axis, tuple): + axis = list(axis) + axis = ov_opset.constant(axis, Type.i32).output(0) + + # Compute the product + result = ov_opset.reduce_prod(x, axis, keepdims).output(0) + + return OpenVINOKerasTensor(result) + + +def quantile(x, q, axis=None, method="linear", keepdims=False): + raise NotImplementedError( + "`quantile` is not supported with openvino backend" + ) + + +def ravel(x): + x = get_ov_output(x) + target_shape = ov_opset.constant([-1], dtype=Type.i32).output(0) + return OpenVINOKerasTensor( + ov_opset.reshape(x, target_shape, special_zero=False).output(0) + ) + + +def real(x): + raise NotImplementedError("`real` is not supported with openvino backend") + + +def reciprocal(x): + x = get_ov_output(x) + one_constant = ov_opset.constant(1, dtype=x.get_element_type()).output(0) + x = ov_opset.divide(one_constant, x).output(0) + return OpenVINOKerasTensor(x) + + +def repeat(x, repeats, axis=None): + x = get_ov_output(x) + const_0 = ov_opset.constant(0, Type.i32) + const_1 = ov_opset.constant(1, Type.i32) + const_neg_1 = ov_opset.constant([-1], Type.i32) + + if axis is not None and axis < 0: + axis += len(x.get_partial_shape()) + + if axis is None: + x = ov_opset.reshape(x, const_neg_1, special_zero=False) + axis = 0 + + if isinstance(repeats, (int, np.integer)) or ( + isinstance(repeats, np.ndarray) + and repeats.ndim == 1 + and repeats.size == 1 + ): + repeats_val = ( + int(repeats) + if isinstance(repeats, (np.integer, np.ndarray)) + else repeats + ) + dim_len = ov_opset.gather( + ov_opset.shape_of(x, Type.i32), + ov_opset.constant([axis], Type.i32), + const_0, + ) + dim_len = ov_opset.squeeze(dim_len, ov_opset.constant([0], Type.i32)) + idx_range = ov_opset.range( + const_0, dim_len, const_1, output_type=Type.i32 + ) + idx_range = ov_opset.unsqueeze(idx_range, const_1) + tiled = ov_opset.tile( + idx_range, ov_opset.constant([1, repeats_val], Type.i32) + ) + idx = ov_opset.reshape(tiled, const_neg_1, special_zero=False) + result = ov_opset.gather(x, idx, ov_opset.constant(axis, Type.i32)) + return OpenVINOKerasTensor(result.output(0)) + repeats_tensor = get_ov_output(repeats) + cumsum = ov_opset.cumsum(repeats_tensor, const_0) + total = ov_opset.reduce_sum( + repeats_tensor, ov_opset.constant([0], Type.i32), keep_dims=False + ) + total = ov_opset.convert(total, Type.i32) + out_indices = ov_opset.range(const_0, total, const_1, output_type=Type.i32) + cumsum_unsq = ov_opset.unsqueeze(cumsum, const_0) + out_indices_unsq = ov_opset.unsqueeze(out_indices, const_1) + cumsum_unsq = ov_opset.convert(cumsum_unsq, Type.i32) + mask = ov_opset.greater_equal(out_indices_unsq, cumsum_unsq) + gather_indices = ov_opset.reduce_sum( + ov_opset.convert(mask, Type.i32), ov_opset.constant([1], Type.i32) + ) + result = ov_opset.gather( + x, gather_indices, ov_opset.constant(axis, Type.i32) + ) + return OpenVINOKerasTensor(result.output(0)) + + +def reshape(x, newshape): + x = get_ov_output(x) + if isinstance(newshape, tuple): + newshape = list(newshape) + newshape = ov_opset.constant(newshape, Type.i32).output(0) + return OpenVINOKerasTensor(ov_opset.reshape(x, newshape, False).output(0)) + + +def roll(x, shift, axis=None): + raise NotImplementedError("`roll` is not supported with openvino backend") + + +def sign(x): + x = get_ov_output(x) + return OpenVINOKerasTensor(ov_opset.sign(x).output(0)) + + +def signbit(x): + raise NotImplementedError( + "`signbit` is not supported with openvino backend" + ) + + +def sin(x): + x = get_ov_output(x) + x_type = x.get_element_type() + if x_type.is_integral(): + ov_type = OPENVINO_DTYPES[config.floatx()] + x = ov_opset.convert(x, ov_type) + return OpenVINOKerasTensor(ov_opset.sin(x).output(0)) + + +def sinh(x): + x = get_ov_output(x) + x_type = x.get_element_type() + if x_type.is_integral(): + ov_type = OPENVINO_DTYPES[config.floatx()] + x = ov_opset.convert(x, ov_type) + return OpenVINOKerasTensor(ov_opset.sinh(x).output(0)) + + +def size(x): + raise NotImplementedError("`size` is not supported with openvino backend") + + +def sort(x, axis=-1): + x = get_ov_output(x) + x_shape = x.get_partial_shape() + rank = x_shape.rank.get_length() + + if rank == 0: + return OpenVINOKerasTensor(x) + + # Handle axis=None by flattening the input + if axis is None: + x = ov_opset.reshape( + x, ov_opset.constant([-1], Type.i32), False + ).output(0) + axis = 0 + # Handle negative axis + elif axis < 0: + axis = rank + axis + + # Get the size of the dimension to sort + shape_tensor = ov_opset.shape_of(x, output_type=Type.i32).output(0) + k = ov_opset.gather( + shape_tensor, + ov_opset.constant([axis], Type.i32).output(0), + ov_opset.constant(0, Type.i32).output(0), + ).output(0) + + # Convert k to a scalar value + k_scalar = ov_opset.squeeze(k, ov_opset.constant([0], Type.i32)).output(0) + + # Use topk with k=size_of_axis to get all elements sorted + topk_outputs = ov_opset.topk( + x, k=k_scalar, axis=axis, mode="min", sort="value", stable=True + ) + + # Get the sorted values + sorted_values = topk_outputs.output(0) + + return OpenVINOKerasTensor(sorted_values) + + +def split(x, indices_or_sections, axis=0): + x = get_ov_output(x) + axis_tensor = ov_opset.constant(axis, dtype=Type.i32).output(0) + + shape_tensor = ov_opset.shape_of(x) + axis_i32 = ov_opset.constant([axis], dtype=Type.i32) + dim_at_axis_tensor = ov_opset.gather( + shape_tensor, axis_i32, ov_opset.constant(0, dtype=Type.i32) + ) + + if isinstance(indices_or_sections, int): + num_splits = indices_or_sections + splits = ov_opset.split(x, axis_tensor, num_splits=num_splits) + result = [] + for i in range(num_splits): + result.append(OpenVINOKerasTensor(splits.output(i))) + return result + + if isinstance(indices_or_sections, (list, tuple, np.ndarray)): + indices = list(indices_or_sections) + split_lengths = [] + split_lengths.append(indices[0]) + for i in range(1, len(indices)): + split_lengths.append(indices[i] - indices[i - 1]) + + last_index_tensor = ov_opset.constant(indices[-1], dtype=Type.i64) + remaining_length_tensor = ov_opset.subtract( + dim_at_axis_tensor, last_index_tensor + ) + + length_parts = [] + length_parts.append(ov_opset.constant(split_lengths, dtype=Type.i64)) + length_parts.append(remaining_length_tensor) + length_tensor = ov_opset.concat(length_parts, axis=0) + + splits = ov_opset.variadic_split(x, axis_tensor, length_tensor) + result = [] + for i in range(len(split_lengths) + 1): + result.append(OpenVINOKerasTensor(splits.output(i))) + return result + + raise TypeError( + f"unsupported type of indices_or_sections: {type(indices_or_sections)}" + ) + + +def stack(x, axis=0): + if isinstance(x, tuple): + x = list(x) + assert isinstance(x, list), "`stack` supports only `x` as list or tuple" + elems = [get_ov_output(e) for e in x] + ref = elems[0] + for i in range(1, len(elems)): + ref, elems[i] = _align_operand_types(ref, elems[i], "stack()") + elems[0] = ref + const_axis = ov_opset.constant(axis, Type.i32).output(0) + elems = [ov_opset.unsqueeze(e, const_axis).output(0) for e in elems] + res = ov_opset.concat(elems, axis).output(0) + return OpenVINOKerasTensor(res) + + +def std(x, axis=None, keepdims=False): + x = get_ov_output(x) + if axis is None: + flatten_shape = ov_opset.constant([-1], Type.i32).output(0) + x = ov_opset.reshape(x, flatten_shape, False).output(0) + axis = 0 + axis = ov_opset.constant(axis, Type.i32).output(0) + # The variance is computed using $Var = E[|x|^2] - |E[x]|^2$, It is faster + # but less numerically stable. + mean = ov_opset.reduce_mean(x, axis, keepdims).output(0) + const_two = ov_opset.constant(2, x.get_element_type()).output(0) + squared_x = ov_opset.power(x, const_two).output(0) + squared_mean = ov_opset.power(mean, const_two).output(0) + squared_x_mean = ov_opset.reduce_mean(squared_x, axis, keepdims) + variance = ov_opset.subtract(squared_x_mean, squared_mean).output(0) + std_var = OpenVINOKerasTensor(ov_opset.sqrt(variance).output(0)) + return std_var + + +def swapaxes(x, axis1, axis2): + raise NotImplementedError( + "`swapaxes` is not supported with openvino backend" + ) + + +def take(x, indices, axis=None): + x = get_ov_output(x) + indices = get_ov_output(indices) + if axis is None: + target_shape = ov_opset.constant([-1], dtype=Type.i32).output(0) + x = ov_opset.reshape(x, target_shape, False).output(0) + axis = ov_opset.constant(0, dtype=Type.i32).output(0) + else: + axis = ov_opset.constant(axis, dtype=Type.i32).output(0) + return OpenVINOKerasTensor(ov_opset.gather(x, indices, axis).output(0)) + + +def take_along_axis(x, indices, axis=None): + x = get_ov_output(x) + indices = get_ov_output(indices) + + if axis is None: + target_shape = ov_opset.constant([-1], dtype=Type.i32).output(0) + x_flat = ov_opset.reshape(x, target_shape, False).output(0) + indices_flat = ov_opset.reshape(indices, target_shape, False).output(0) + result = ov_opset.gather_elements(x_flat, indices_flat, 0).output(0) + return OpenVINOKerasTensor(result) + + x_rank = len(x.get_partial_shape()) + if axis < 0: + axis += x_rank + + x_shape = ov_opset.shape_of(x, Type.i32).output(0) + indices_shape = ov_opset.shape_of(indices, Type.i32).output(0) + + zero_const = ov_opset.constant(0, dtype=Type.i32).output(0) + axis_index = ov_opset.constant([axis], dtype=Type.i32).output(0) + + # Fix negative indices + dim_size = ov_opset.squeeze( + ov_opset.gather(x_shape, axis_index, zero_const).output(0), zero_const + ).output(0) + zero_scalar = ov_opset.constant(0, indices.get_element_type()).output(0) + is_neg = ov_opset.less(indices, zero_scalar).output(0) + dim_size_cast = ov_opset.convert( + dim_size, indices.get_element_type() + ).output(0) + indices = ov_opset.select( + is_neg, ov_opset.add(indices, dim_size_cast).output(0), indices + ).output(0) + indices = ov_opset.convert(indices, Type.i32).output(0) + + x_target_parts, indices_target_parts = [], [] + + for i in range(x_rank): + dim_idx = ov_opset.constant([i], dtype=Type.i32).output(0) + x_dim = ov_opset.gather(x_shape, dim_idx, zero_const).output(0) + indices_dim = ov_opset.gather( + indices_shape, dim_idx, zero_const + ).output(0) + + if i == axis: + # For axis dimension: keep original dimensions + x_target_parts.append(x_dim) + indices_target_parts.append(indices_dim) + else: + # For other dimensions: use maximum for broadcasting + max_dim = ov_opset.maximum(x_dim, indices_dim).output(0) + x_target_parts.append(max_dim) + indices_target_parts.append(max_dim) + + x_target_shape = ov_opset.concat(x_target_parts, axis=0).output(0) + indices_target_shape = ov_opset.concat(indices_target_parts, axis=0).output( + 0 + ) + + # Broadcast to target shapes and gather elements + x_broadcasted = ov_opset.broadcast(x, x_target_shape).output(0) + indices_broadcasted = ov_opset.broadcast( + indices, indices_target_shape + ).output(0) + result = ov_opset.gather_elements( + x_broadcasted, indices_broadcasted, axis + ).output(0) + + return OpenVINOKerasTensor(result) + + +def tan(x): + x = get_ov_output(x) + x_type = x.get_element_type() + if x_type.is_integral(): + ov_type = OPENVINO_DTYPES[config.floatx()] + x = ov_opset.convert(x, ov_type) + return OpenVINOKerasTensor(ov_opset.tan(x).output(0)) + + +def tanh(x): + x = get_ov_output(x) + x_type = x.get_element_type() + if x_type.is_integral(): + ov_type = OPENVINO_DTYPES[config.floatx()] + x = ov_opset.convert(x, ov_type) + return OpenVINOKerasTensor(ov_opset.tanh(x).output(0)) + + +def tensordot(x1, x2, axes=2): + raise NotImplementedError( + "`tensordot` is not supported with openvino backend" + ) + + +def round(x, decimals=0): + raise NotImplementedError("`round` is not supported with openvino backend") + + +def tile(x, repeats): + raise NotImplementedError("`tile` is not supported with openvino backend") + + +def trace(x, offset=0, axis1=0, axis2=1): + raise NotImplementedError("`trace` is not supported with openvino backend") + + +def tri(N, M=None, k=0, dtype=None): + if M is None: + M = N + if dtype is None: + dtype = "float32" + + ov_dtype = OPENVINO_DTYPES[dtype] + + def ensure_constant(value, default_type=Type.i32): + if isinstance(value, (int, float)): + return ov_opset.constant(value, default_type) + elif hasattr(value, "get_element_type"): + if value.get_element_type() != Type.i32: + value = ov_opset.convert(value, Type.i32) + return ov_opset.squeeze(value, ov_opset.constant([0], Type.i32)) + else: + return ov_opset.constant(value, default_type) + + N_const = ensure_constant(N) + M_const = ensure_constant(M) + k_const = ensure_constant(k) + + # Create row and column indices + row_range = ov_opset.range( + ov_opset.constant(0, Type.i32), + N_const, + ov_opset.constant(1, Type.i32), + output_type=Type.i32, + ) + col_range = ov_opset.range( + ov_opset.constant(0, Type.i32), + M_const, + ov_opset.constant(1, Type.i32), + output_type=Type.i32, + ) + + # Reshape indices for broadcasting + row_idx = ov_opset.unsqueeze(row_range, ov_opset.constant([1], Type.i32)) + col_idx = ov_opset.unsqueeze(col_range, ov_opset.constant([0], Type.i32)) + + mask = ov_opset.less_equal(col_idx, ov_opset.add(row_idx, k_const)) + + if ov_dtype == Type.boolean: + result = mask + else: + result = ov_opset.convert(mask, ov_dtype) + + return OpenVINOKerasTensor(result.output(0)) + + +def tril(x, k=0): + x = get_ov_output(x) + ov_type = x.get_element_type() + shape = ov_opset.shape_of(x, Type.i32) + zero_const = ov_opset.constant(0, Type.i32) + minus2 = ov_opset.constant([-2], Type.i32) + minus1 = ov_opset.constant([-1], Type.i32) + M = ov_opset.squeeze(ov_opset.gather(shape, minus2, zero_const), zero_const) + N = ov_opset.squeeze(ov_opset.gather(shape, minus1, zero_const), zero_const) + tri_mask = tri(M, N, k=k, dtype="bool").output + mask = ov_opset.convert(tri_mask, ov_type) + if ov_type == Type.boolean: + out = ov_opset.logical_and(x, mask) + else: + out = ov_opset.multiply(x, mask) + return OpenVINOKerasTensor(out.output(0)) + + +def triu(x, k=0): + x = get_ov_output(x) + ov_type = x.get_element_type() + shape = ov_opset.shape_of(x, Type.i32) + zero_const = ov_opset.constant(0, Type.i32) + minus2 = ov_opset.constant([-2], Type.i32) + minus1 = ov_opset.constant([-1], Type.i32) + M = ov_opset.squeeze(ov_opset.gather(shape, minus2, zero_const), zero_const) + N = ov_opset.squeeze(ov_opset.gather(shape, minus1, zero_const), zero_const) + tri_mask = tri(M, N, k=k - 1, dtype="bool").output + if ov_type == Type.boolean: + mask = ov_opset.logical_not(tri_mask) + else: + const_one = ov_opset.constant(1, ov_type) + converted_mask = ov_opset.convert(tri_mask, ov_type) + mask = ov_opset.subtract(const_one, converted_mask) + if ov_type == Type.boolean: + out = ov_opset.logical_and(x, mask) + else: + out = ov_opset.multiply(x, mask) + return OpenVINOKerasTensor(out.output(0)) + + +def vdot(x1, x2): + raise NotImplementedError("`vdot` is not supported with openvino backend") + + +def vstack(xs): + raise NotImplementedError("`vstack` is not supported with openvino backend") + + +def vectorize(pyfunc, *, excluded=None, signature=None): + raise NotImplementedError( + "`vectorize` is not supported with openvino backend" + ) + + +def where(condition, x1=None, x2=None): + condition = get_ov_output(condition) + if x1 is None and x2 is None: + nonzero_indices = ov_opset.non_zero(condition) + return OpenVINOKerasTensor(nonzero_indices.output(0)) + if x1 is None: + return OpenVINOKerasTensor(condition) + if x2 is None: + raise ValueError("x2 must be provided if x1 is specified.") + + def cast_literal_like_tensor(literal, x): + ov_type = get_ov_output(x).get_element_type() + is_bool = ov_type == Type.boolean + is_float_to_int = isinstance(literal, float) and ov_type.is_integral() + if is_bool or is_float_to_int: + return get_ov_output(literal), get_ov_output(x) + return get_ov_output(literal, ov_type), get_ov_output(x) + + if isinstance(x1, (int, float)): + x1, x2 = cast_literal_like_tensor(x1, x2) + elif isinstance(x2, (int, float)): + x2, x1 = cast_literal_like_tensor(x2, x1) + else: + x1 = get_ov_output(x1) + x2 = get_ov_output(x2) + x1, x2 = _align_operand_types(x1, x2, "select()") + return OpenVINOKerasTensor(ov_opset.select(condition, x1, x2).output(0)) + + +def divide(x1, x2): + element_type = None + if isinstance(x1, OpenVINOKerasTensor): + element_type = x1.output.get_element_type() + if isinstance(x2, OpenVINOKerasTensor): + element_type = x2.output.get_element_type() + x1 = get_ov_output(x1, element_type) + x2 = get_ov_output(x2, element_type) + x1_type = ov_to_keras_type(x1.get_element_type()) + x2_type = ov_to_keras_type(x2.get_element_type()) + result_type = dtypes.result_type(x1_type, x2_type, float) + result_type = OPENVINO_DTYPES[result_type] + x1 = ov_opset.convert(x1, result_type).output(0) + x2 = ov_opset.convert(x2, result_type).output(0) + return OpenVINOKerasTensor(ov_opset.divide(x1, x2).output(0)) + + +def divide_no_nan(x1, x2): + raise NotImplementedError( + "`divide_no_nan` is not supported with openvino backend" + ) + + +def true_divide(x1, x2): + return divide(x1, x2) + + +def power(x1, x2): + element_type = None + if isinstance(x1, OpenVINOKerasTensor): + element_type = x1.output.get_element_type() + if isinstance(x2, OpenVINOKerasTensor): + element_type = x2.output.get_element_type() + x1 = get_ov_output(x1, element_type) + x2 = get_ov_output(x2, element_type) + x1, x2 = _align_operand_types(x1, x2, "power()") + return OpenVINOKerasTensor(ov_opset.power(x1, x2).output(0)) + + +def negative(x): + x = get_ov_output(x) + return OpenVINOKerasTensor(ov_opset.negative(x).output(0)) + + +def square(x): + x = get_ov_output(x) + const_two = ov_opset.constant(2, x.get_element_type()).output(0) + return OpenVINOKerasTensor(ov_opset.power(x, const_two).output(0)) + + +def sqrt(x): + x = get_ov_output(x) + return OpenVINOKerasTensor(ov_opset.sqrt(x).output(0)) + + +def squeeze(x, axis=None): + x = get_ov_output(x) + if axis is None: + axis = [] + for idx, dim in enumerate(x.get_partial_shape()): + if dim == 1: + axis.append(idx) + axis = ov_opset.constant(axis, Type.i32).output(0) + return OpenVINOKerasTensor(ov_opset.squeeze(x, axis).output(0)) + + +def transpose(x, axes=None): + x = get_ov_output(x) + if axes is None: + # generate reverse permutation vector + shape_x = ov_opset.shape_of(x, "i64").output(0) + rank_x = ov_opset.shape_of(shape_x, "i64").output(0) + scalar_shape = ov_opset.constant([], Type.i32).output(0) + rank_x = ov_opset.reshape(rank_x, scalar_shape, False).output(0) + const_minus_one = ov_opset.constant(-1, Type.i64).output(0) + rank_minus_one = ov_opset.add(rank_x, const_minus_one).output(0) + axes = ov_opset.range( + rank_minus_one, const_minus_one, const_minus_one, "i64" + ).output(0) + else: + if isinstance(axes, tuple): + axes = list(axes) + axes = ov_opset.constant(axes, Type.i32).output(0) + return OpenVINOKerasTensor(ov_opset.transpose(x, axes).output(0)) + + +def var(x, axis=None, keepdims=False): + x = get_ov_output(x) + if axis is None: + flatten_shape = ov_opset.constant([-1], Type.i32).output(0) + x = ov_opset.reshape(x, flatten_shape, False).output(0) + axis = 0 + axis = ov_opset.constant(axis, Type.i32).output(0) + # The variance is computed using $Var = E[|x|^2] - |E[x]|^2$, It is faster + # but less numerically stable. + mean = ov_opset.reduce_mean(x, axis, keepdims).output(0) + const_two = ov_opset.constant(2, x.get_element_type()).output(0) + squared_x = ov_opset.power(x, const_two).output(0) + squared_mean = ov_opset.power(mean, const_two).output(0) + squared_x_mean = ov_opset.reduce_mean(squared_x, axis, keepdims) + variance = OpenVINOKerasTensor( + ov_opset.subtract(squared_x_mean, squared_mean).output(0) + ) + return variance + + +def sum(x, axis=None, keepdims=False): + x = get_ov_output(x) + if axis is None: + flatten_shape = ov_opset.constant([-1], Type.i32).output(0) + x = ov_opset.reshape(x, flatten_shape, False).output(0) + axis = 0 + axis = ov_opset.constant(axis, Type.i32).output(0) + return OpenVINOKerasTensor(ov_opset.reduce_sum(x, axis, keepdims).output(0)) + + +def eye(N, M=None, k=0, dtype=None): + raise NotImplementedError("`eye` is not supported with openvino backend") + + +def floor_divide(x1, x2): + raise NotImplementedError( + "`floor_divide` is not supported with openvino backend" + ) + + +def logical_xor(x1, x2): + x1 = get_ov_output(x1) + x2 = get_ov_output(x2) + x1 = ov_opset.convert(x1, Type.boolean).output(0) + x2 = ov_opset.convert(x2, Type.boolean).output(0) + return OpenVINOKerasTensor(ov_opset.logical_xor(x1, x2).output(0)) + + +def corrcoef(x): + raise NotImplementedError( + "`corrcoef` is not supported with openvino backend" + ) + + +def correlate(x1, x2, mode="valid"): + raise NotImplementedError( + "`correlate` is not supported with openvino backend" + ) + + +def select(condlist, choicelist, default=0): + raise NotImplementedError("`select` is not supported with openvino backend") + + +def slogdet(x): + raise NotImplementedError( + "`slogdet` is not supported with openvino backend" + ) + + +def argpartition(x, kth, axis=-1): + raise NotImplementedError( + "`argpartition` is not supported with openvino backend" + ) diff --git a/keras/src/backend/openvino/random.py b/keras/src/backend/openvino/random.py new file mode 100644 index 000000000000..38de21294677 --- /dev/null +++ b/keras/src/backend/openvino/random.py @@ -0,0 +1,149 @@ +import numpy as np +import openvino.opset14 as ov_opset +from openvino import Type + +from keras.src.backend.config import floatx +from keras.src.backend.openvino.core import OPENVINO_DTYPES +from keras.src.backend.openvino.core import OpenVINOKerasTensor +from keras.src.backend.openvino.core import convert_to_numpy +from keras.src.backend.openvino.core import get_ov_output +from keras.src.random.seed_generator import SeedGenerator +from keras.src.random.seed_generator import draw_seed +from keras.src.random.seed_generator import make_default_seed + + +def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): + dtype = dtype or floatx() + seed = draw_seed(seed) + rng = np.random.default_rng(seed.data) + normal_const = rng.normal(size=shape, loc=mean, scale=stddev).astype(dtype) + return OpenVINOKerasTensor(ov_opset.constant(normal_const).output(0)) + + +def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None): + dtype = dtype or floatx() + seed_val = draw_seed(seed) + if isinstance(seed_val, OpenVINOKerasTensor): + seed_data = convert_to_numpy(seed_val) + else: + seed_data = seed_val.data + rng = np.random.default_rng(seed_data) + random_values = rng.uniform(minval, maxval, size=shape).astype(dtype) + return OpenVINOKerasTensor(ov_opset.constant(random_values).output(0)) + + +def categorical(logits, num_samples, dtype="int64", seed=None): + dtype = dtype or "int64" + ov_dtype = OPENVINO_DTYPES[dtype] + logits = get_ov_output(logits) + + zero_const = ov_opset.constant(0, Type.i32).output(0) + one_const = ov_opset.constant(1, Type.i32).output(0) + neg_one_const = ov_opset.constant(-1, Type.i32).output(0) + + # Compute probabilities and cumulative sum + probs = ov_opset.softmax(logits, axis=-1).output(0) + cumsum_probs = ov_opset.cumsum(probs, neg_one_const).output(0) + + # Get shape and compute batch dimensions + logits_shape = ov_opset.shape_of(logits, Type.i32).output(0) + rank = ov_opset.shape_of(logits_shape, Type.i32).output(0) + rank_scalar = ov_opset.squeeze(rank, zero_const).output(0) + rank_minus_1 = ov_opset.subtract(rank_scalar, one_const).output(0) + + # Extract batch shape (all dimensions except last) + batch_indices = ov_opset.range( + zero_const, rank_minus_1, one_const, output_type=Type.i32 + ).output(0) + batch_shape = ov_opset.gather(logits_shape, batch_indices, axis=0).output(0) + + # Create final shape [batch_dims..., num_samples] + num_samples_const = ov_opset.constant([num_samples], Type.i32).output(0) + final_shape = ov_opset.concat( + [batch_shape, num_samples_const], axis=0 + ).output(0) + + seed_tensor = draw_seed(seed) + if isinstance(seed_tensor, OpenVINOKerasTensor): + seed1, seed2 = convert_to_numpy(seed_tensor) + else: + seed1, seed2 = seed_tensor.data + + probs_dtype = probs.get_element_type() + zero_float = ov_opset.constant(0.0, probs_dtype).output(0) + one_float = ov_opset.constant(1.0, probs_dtype).output(0) + + rand = ov_opset.random_uniform( + final_shape, zero_float, one_float, probs_dtype, seed1, seed2 + ).output(0) + + rand_unsqueezed = ov_opset.unsqueeze(rand, neg_one_const).output(0) + cumsum_unsqueezed = ov_opset.unsqueeze(cumsum_probs, one_const).output(0) + + # Count how many cumulative probabilities each random number exceeds + greater = ov_opset.greater(rand_unsqueezed, cumsum_unsqueezed).output(0) + samples = ov_opset.reduce_sum( + ov_opset.convert(greater, Type.i32).output(0), neg_one_const + ).output(0) + + result = ov_opset.convert(samples, ov_dtype).output(0) + return OpenVINOKerasTensor(result) + + +def randint(shape, minval, maxval, dtype="int32", seed=None): + raise NotImplementedError( + "`randint` is not supported with openvino backend" + ) + + +def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): + dtype = dtype or floatx() + seed = draw_seed(seed) + rng = np.random.default_rng(seed.data) + + lower_bound = mean - 2 * stddev + upper_bound = mean + 2 * stddev + + flat_shape = np.prod(shape) + random_numbers = np.empty(0) + + # loop until we have enough valid numbers to fill our desired shape + while random_numbers.shape[0] < flat_shape: + # Generate a batch of random numbers from a normal distribution + batch = rng.normal(loc=mean, scale=stddev, size=flat_shape) + + # Filter the numbers to keep only those within the specified bounds + valid = batch[(batch >= lower_bound) & (batch <= upper_bound)] + + # Append the valid numbers to the result array + random_numbers = np.append(random_numbers, valid) + + # Truncate the result array to the desired size and reshape it + np_array_res = random_numbers[:flat_shape].astype(dtype).reshape(shape) + return OpenVINOKerasTensor(ov_opset.constant(np_array_res).output(0)) + + +def dropout(inputs, rate, noise_shape=None, seed=None): + raise NotImplementedError( + "`dropout` is not supported with openvino backend" + ) + + +def shuffle(x, axis=0, seed=None): + raise NotImplementedError( + "`shuffle` is not supported with openvino backend" + ) + + +def gamma(shape, alpha, dtype=None, seed=None): + raise NotImplementedError("`gamma` is not supported with openvino backend") + + +def binomial(shape, counts, probabilities, dtype=None, seed=None): + raise NotImplementedError( + "`binomial` is not supported with openvino backend" + ) + + +def beta(shape, alpha, beta, dtype=None, seed=None): + raise NotImplementedError("`beta` is not supported with openvino backend") diff --git a/keras/src/backend/openvino/rnn.py b/keras/src/backend/openvino/rnn.py new file mode 100644 index 000000000000..70190fc47c8b --- /dev/null +++ b/keras/src/backend/openvino/rnn.py @@ -0,0 +1,38 @@ +def rnn( + step_function, + inputs, + initial_states, + go_backwards=False, + mask=None, + constants=None, + unroll=False, + input_length=None, + time_major=False, + zero_output_for_mask=False, + return_all_outputs=True, +): + raise NotImplementedError("`rnn` is not supported with openvino backend") + + +def lstm(*args, **kwargs): + raise NotImplementedError("`lstm` is not supported with openvino backend") + + +def gru(*args, **kwargs): + raise NotImplementedError("`gru` is not supported with openvino backend") + + +def unstack(x, axis=0): + raise NotImplementedError( + "`unstack` is not supported with openvino backend" + ) + + +def numpy_scan(f, init, xs, reverse=False, mask=None): + raise NotImplementedError( + "`numpy_scan` is not supported with openvino backend" + ) + + +def cudnn_ok(*args, **kwargs): + return False diff --git a/keras/src/backend/openvino/trainer.py b/keras/src/backend/openvino/trainer.py new file mode 100644 index 000000000000..ac2e64a8060c --- /dev/null +++ b/keras/src/backend/openvino/trainer.py @@ -0,0 +1,272 @@ +import numpy as np +import openvino as ov +import openvino.opset14 as ov_opset + +from keras.src import backend +from keras.src import callbacks as callbacks_module +from keras.src import tree +from keras.src.backend.openvino.core import OPENVINO_DTYPES +from keras.src.backend.openvino.core import OpenVINOKerasTensor +from keras.src.backend.openvino.core import get_device +from keras.src.trainers import trainer as base_trainer +from keras.src.trainers.data_adapters import data_adapter_utils +from keras.src.trainers.epoch_iterator import EpochIterator +from keras.src.utils import traceback_utils + + +class OpenVINOTrainer(base_trainer.Trainer): + def __init__(self): + super().__init__() + self.test_function = None + self.predict_function = None + self.ov_compiled_model = None + self.ov_device = None + self.struct_params = None + self.struct_outputs = None + + def _unpack_singleton(self, x): + if isinstance(x, (list, tuple)) and len(x) == 1: + return x[0] + return x + + def test_step(self, data): + raise NotImplementedError( + "`test_step` is not supported with openvino backend" + ) + + def predict_step(self, data): + x, _, _ = data_adapter_utils.unpack_x_y_sample_weight(data) + ov_compiled_model = self._get_compiled_model(x) + flatten_x = tree.flatten(x) + y_pred = ov_compiled_model(flatten_x) + # recover structure of the model output + y_pred = self._unpack_singleton( + tree.pack_sequence_as(self.struct_outputs, y_pred.to_tuple()) + ) + return y_pred + + def make_test_function(self, force=False): + if self.test_function is not None and not force: + return self.test_function + + def one_test_step(data): + data = data[0] + return self.test_step(data) + + def multi_test_steps(data): + for single_step_data in data: + logs = one_test_step([single_step_data]) + return logs + + if self.steps_per_execution > 1: + test_step = multi_test_steps + else: + test_step = one_test_step + + self.test_function = test_step + + def _parameterize_data(self, data): + if isinstance(data, (list, tuple)): + parametrize_data = [] + for elem in data: + param_elem = self._parameterize_data(elem) + parametrize_data.append(param_elem) + elif isinstance(data, dict): + parametrize_data = dict() + for elem_name, elem in data.items(): + param_elem = self._parameterize_data(elem) + parametrize_data[elem_name] = param_elem + elif isinstance(data, np.ndarray) or np.isscalar(data): + ov_type = OPENVINO_DTYPES[str(data.dtype)] + ov_shape = list(data.shape) + param = ov_opset.parameter(shape=ov_shape, dtype=ov_type) + parametrize_data = OpenVINOKerasTensor(param.output(0)) + elif isinstance(data, int): + param = ov_opset.parameter(shape=[], dtype=ov.Type.i32) + parametrize_data = OpenVINOKerasTensor(param.output(0)) + elif isinstance(data, float): + param = ov_opset.parameter(shape=[], dtype=ov.Type.f32) + parametrize_data = OpenVINOKerasTensor(param.output(0)) + else: + raise "Unknown type of input data {}".format(type(data)) + return parametrize_data + + def _get_compiled_model(self, data): + if ( + self.ov_compiled_model is not None + and get_device() == self.ov_device + ): + return self.ov_compiled_model + + # remove the previous cached compiled model if exists + del self.ov_compiled_model + + # prepare parameterized input + self.struct_params = self._parameterize_data(data) + # construct OpenVINO graph during calling Keras Model + self.struct_outputs = self(self.struct_params) + + parameters = [] + for p in tree.flatten(self.struct_params): + parameters.append(p.output.get_node()) + results = [] + for r in tree.flatten(self.struct_outputs): + results.append(ov_opset.result(r.output)) + + # prepare compiled model from scratch + ov_model = ov.Model(results=results, parameters=parameters) + self.ov_compiled_model = ov.compile_model(ov_model, get_device()) + self.ov_device = get_device() + return self.ov_compiled_model + + def make_predict_function(self, force=False): + if self.predict_function is not None and not force: + return self.predict_function + + def one_predict_step(data): + data = data[0] + return self.predict_step(data) + + def multi_predict_steps(data): + outputs = one_predict_step(data[:1]) + + for single_step_data in data[1:]: + step_outputs = one_predict_step([single_step_data]) + outputs = tree.map_structure( + lambda t1, t2: np.concatenate([t1, t2]), + outputs, + step_outputs, + ) + return outputs + + if self.steps_per_execution > 1: + predict_step = multi_predict_steps + else: + predict_step = one_predict_step + + self.predict_function = predict_step + + def fit( + self, + x=None, + y=None, + batch_size=None, + epochs=1, + verbose="auto", + callbacks=None, + validation_split=0.0, + validation_data=None, + shuffle=True, + class_weight=None, + sample_weight=None, + initial_epoch=0, + steps_per_epoch=None, + validation_steps=None, + validation_batch_size=None, + validation_freq=1, + ): + raise NotImplementedError( + "`fit` is not supported with openvino backend" + ) + + @traceback_utils.filter_traceback + def predict( + self, x, batch_size=None, verbose="auto", steps=None, callbacks=None + ): + # Create an iterator that yields batches of input data. + epoch_iterator = EpochIterator( + x=x, + batch_size=batch_size, + steps_per_epoch=steps, + shuffle=False, + steps_per_execution=self.steps_per_execution, + ) + + # Container that configures and calls callbacks. + if not isinstance(callbacks, callbacks_module.CallbackList): + callbacks = callbacks_module.CallbackList( + callbacks, + add_history=True, + add_progbar=verbose != 0, + verbose=verbose, + epochs=1, + steps=epoch_iterator.num_batches, + model=self, + ) + + def append_to_outputs(batch_outputs, outputs): + if outputs is None: + outputs = tree.map_structure( + lambda batch_output: [batch_output], + batch_outputs, + ) + else: + tree.map_structure_up_to( + batch_outputs, + lambda output, batch_output: output.append(batch_output), + outputs, + batch_outputs, + ) + return outputs + + self.make_predict_function() + self.stop_predicting = False + callbacks.on_predict_begin() + outputs = None + for begin_step, end_step, data in epoch_iterator.enumerate_epoch(): + callbacks.on_predict_batch_begin(begin_step) + batch_outputs = self.predict_function(data) + outputs = append_to_outputs(batch_outputs, outputs) + callbacks.on_predict_batch_end(end_step, {"outputs": batch_outputs}) + if self.stop_predicting: + break + callbacks.on_predict_end() + return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs) + + @traceback_utils.filter_traceback + def evaluate( + self, + x=None, + y=None, + batch_size=None, + verbose="auto", + sample_weight=None, + steps=None, + callbacks=None, + return_dict=False, + **kwargs, + ): + raise NotImplementedError( + "`evaluate` is not supported with openvino backend" + ) + + def train_on_batch( + self, + x, + y=None, + sample_weight=None, + class_weight=None, + return_dict=False, + ): + raise NotImplementedError( + "`train_on_batch` is not supported with openvino backend" + ) + + def test_on_batch( + self, + x, + y=None, + sample_weight=None, + return_dict=False, + ): + raise NotImplementedError( + "`test_on_batch` is not supported with openvino backend" + ) + + def predict_on_batch(self, x): + self.make_predict_function() + batch_outputs = self.predict_function([(x,)]) + batch_outputs = tree.map_structure( + backend.convert_to_numpy, batch_outputs + ) + return batch_outputs diff --git a/keras/src/backend/tensorflow/__init__.py b/keras/src/backend/tensorflow/__init__.py new file mode 100644 index 000000000000..ea4eed39b8da --- /dev/null +++ b/keras/src/backend/tensorflow/__init__.py @@ -0,0 +1,30 @@ +from keras.src.backend.tensorflow import core +from keras.src.backend.tensorflow import distribution_lib +from keras.src.backend.tensorflow import image +from keras.src.backend.tensorflow import linalg +from keras.src.backend.tensorflow import math +from keras.src.backend.tensorflow import nn +from keras.src.backend.tensorflow import numpy +from keras.src.backend.tensorflow import random +from keras.src.backend.tensorflow import tensorboard +from keras.src.backend.tensorflow.core import IS_THREAD_SAFE +from keras.src.backend.tensorflow.core import SUPPORTS_RAGGED_TENSORS +from keras.src.backend.tensorflow.core import SUPPORTS_SPARSE_TENSORS +from keras.src.backend.tensorflow.core import Variable +from keras.src.backend.tensorflow.core import cast +from keras.src.backend.tensorflow.core import compute_output_spec +from keras.src.backend.tensorflow.core import cond +from keras.src.backend.tensorflow.core import convert_to_numpy +from keras.src.backend.tensorflow.core import convert_to_tensor +from keras.src.backend.tensorflow.core import device_scope +from keras.src.backend.tensorflow.core import is_tensor +from keras.src.backend.tensorflow.core import name_scope +from keras.src.backend.tensorflow.core import random_seed_dtype +from keras.src.backend.tensorflow.core import scatter +from keras.src.backend.tensorflow.core import shape +from keras.src.backend.tensorflow.core import stop_gradient +from keras.src.backend.tensorflow.core import vectorized_map +from keras.src.backend.tensorflow.rnn import cudnn_ok +from keras.src.backend.tensorflow.rnn import gru +from keras.src.backend.tensorflow.rnn import lstm +from keras.src.backend.tensorflow.rnn import rnn diff --git a/keras/src/backend/tensorflow/core.py b/keras/src/backend/tensorflow/core.py new file mode 100644 index 000000000000..6896b74c519c --- /dev/null +++ b/keras/src/backend/tensorflow/core.py @@ -0,0 +1,698 @@ +import builtins + +import numpy as np +import tensorflow as tf +from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice + +from keras.src import tree +from keras.src.backend.common import KerasVariable +from keras.src.backend.common import global_state +from keras.src.backend.common import is_int_dtype +from keras.src.backend.common import standardize_dtype +from keras.src.backend.common.backend_utils import slice_along_axis +from keras.src.backend.common.keras_tensor import KerasTensor +from keras.src.backend.common.name_scope import name_scope as base_name_scope +from keras.src.backend.common.stateless_scope import StatelessScope +from keras.src.backend.common.stateless_scope import in_stateless_scope +from keras.src.backend.common.symbolic_scope import SymbolicScope +from keras.src.backend.tensorflow.sparse import sparse_to_dense +from keras.src.utils.naming import auto_name + +SUPPORTS_SPARSE_TENSORS = True +SUPPORTS_RAGGED_TENSORS = True +# https://github.com/tensorflow/tensorflow/issues/78338 +IS_THREAD_SAFE = False + + +class Variable( + KerasVariable, + tf.__internal__.types.Tensor, + tf.__internal__.tracking.Trackable, +): + _should_act_as_resource_variable = True + + @property + def handle(self): + return self.value.handle + + def _initialize(self, value): + if isinstance(value, tf.Variable): + self._value = value + else: + self._value = tf.Variable( + value, + dtype=self._dtype, + trainable=self.trainable, + name=self.name, + aggregation=self._map_aggregation(self.aggregation), + synchronization=self._map_synchronization(self.synchronization), + ) + + def _initialize_with_initializer(self, initializer): + self._initialize(lambda: initializer(self._shape, dtype=self._dtype)) + + def _deferred_initialize(self): + if self._value is not None: + raise ValueError(f"Variable {self.path} is already initialized.") + + if in_stateless_scope(): + raise ValueError( + "You are attempting to initialize a variable " + "while in a stateless scope. This is disallowed. " + "Make sure that all variables are initialized " + "before you start using your layer/model objects." + ) + with tf.init_scope(): + self._initialize_with_initializer(self._initializer) + self._initializer = None + + def _direct_assign(self, value): + self._value.assign(tf.cast(value, self._value.dtype)) + + def _convert_to_tensor(self, value, dtype=None): + return convert_to_tensor(value, dtype=dtype) + + def numpy(self): # noqa: F811 + return self.value.numpy() + + @property + def shape(self): + return tf.TensorShape(super().shape) + + # Overload native accessor. + def __tf_tensor__(self, dtype=None, name=None): + return tf.convert_to_tensor(self.value, dtype=dtype, name=name) + + # Methods below are for SavedModel support + @property + def _shared_name(self): + return self.value._shared_name + + def _serialize_to_tensors(self): + try: + return self.value._serialize_to_tensors() + except NotImplementedError: + return {"VARIABLE_VALUE": self.value} + + def _restore_from_tensors(self, restored_tensors): + try: + return self.value._restore_from_tensors(restored_tensors) + except NotImplementedError: + self.assign(restored_tensors["VARIABLE_VALUE"]) + return self.value + + def _copy_trackable_to_cpu(self, object_map): + self.value._copy_trackable_to_cpu(object_map) + object_map[self] = tf.Variable(object_map[self.value]) + + def _export_to_saved_model_graph( + self, object_map, tensor_map, options, **kwargs + ): + resource_list = self.value._export_to_saved_model_graph( + object_map, tensor_map, options, **kwargs + ) + object_map[self] = tf.Variable(object_map[self.value]) + return resource_list + + def _write_object_proto(self, proto, options): + return self.value._write_object_proto(proto, options) + + def _map_aggregation(self, aggregation): + mapping = { + "none": tf.VariableAggregation.NONE, + "sum": tf.VariableAggregation.SUM, + "mean": tf.VariableAggregation.MEAN, + "only_first_replica": tf.VariableAggregation.ONLY_FIRST_REPLICA, + } + return mapping[aggregation] + + def _map_synchronization(self, synchronization): + mapping = { + "none": tf.VariableSynchronization.NONE, + "on_read": tf.VariableSynchronization.ON_READ, + "on_write": tf.VariableSynchronization.ON_WRITE, + "auto": tf.VariableSynchronization.AUTO, + } + return mapping[synchronization] + + +def convert_to_tensor(x, dtype=None, sparse=None, ragged=None): + if isinstance(x, tf.SparseTensor) and sparse is not None and not sparse: + x = sparse_to_dense(x) + if isinstance(x, tf.RaggedTensor) and ragged is not None and not ragged: + x = x.to_tensor() + if dtype is not None: + dtype = standardize_dtype(dtype) + if not tf.is_tensor(x): + if dtype == "bool" or is_int_dtype(dtype): + # TensorFlow conversion is stricter than other backends, it does not + # allow ints for bools or floats for ints. We convert without dtype + # and cast instead. + x = tf.convert_to_tensor(x) + return tf.cast(x, dtype) + return tf.convert_to_tensor(x, dtype=dtype) + elif dtype is not None and not standardize_dtype(x.dtype) == dtype: + if isinstance(x, tf.SparseTensor): + x_shape = x.shape + x = tf.cast(x, dtype) + x.set_shape(x_shape) + return x + return tf.cast(x, dtype=dtype) + return x + + +def convert_to_numpy(x): + if isinstance(x, tf.SparseTensor): + x = sparse_to_dense(x) + elif isinstance(x, tf.IndexedSlices): + x = tf.convert_to_tensor(x) + elif isinstance(x, tf.RaggedTensor): + x = x.to_tensor() + return np.array(x) + + +def is_tensor(x): + return tf.is_tensor(x) + + +def shape(x): + """Always return a tuple shape. + + `tf.shape` will return a `tf.Tensor`, which differs from the tuple return + type on the torch and jax backends. We write our own method instead which + always returns a tuple, with integer values when the shape is known, and + tensor values when the shape is unknown (this is tf specific, as dynamic + shapes do not apply in other backends). + """ + if isinstance(x, KerasTensor): + return x.shape + if not tf.is_tensor(x): + x = tf.convert_to_tensor(x) + if x.shape == tf.TensorShape(None): + raise ValueError( + "All tensors passed to `ops.shape` must have a statically known " + f"rank. Received: x={x} with unknown rank." + ) + shape = x.shape.as_list() + dynamic = tf.shape(x) + for i in range(len(shape)): + if shape[i] is None: + try: + shape[i] = dynamic[i] + except: + # With RaggedTensors, accessing a ragged dimension will fail, + # we leave it as None. + pass + return tuple(shape) + + +def cast(x, dtype): + dtype = standardize_dtype(dtype) + if isinstance(x, tf.SparseTensor): + x_shape = x.shape + x = tf.cast(x, dtype) + x.set_shape(x_shape) + return x + else: + return tf.cast(x, dtype=dtype) + + +def compute_output_spec(fn, *args, **kwargs): + with StatelessScope(), SymbolicScope(): + graph_name = auto_name("scratch_graph") + with tf.__internal__.FuncGraph(graph_name).as_default(): + + def convert_keras_tensor_to_tf(x): + if isinstance(x, KerasTensor): + if x.sparse: + return tf.compat.v1.sparse_placeholder( + shape=x.shape, dtype=x.dtype + ) + else: + return tf.compat.v1.placeholder( + shape=x.shape, dtype=x.dtype + ) + return x + + args, kwargs = tree.map_structure( + convert_keras_tensor_to_tf, (args, kwargs) + ) + tf_out = fn(*args, **kwargs) + + def convert_tf_to_keras_tensor(x): + if tf.is_tensor(x): + return KerasTensor( + x.shape, x.dtype, sparse=isinstance(x, tf.SparseTensor) + ) + return x + + output_spec = tree.map_structure(convert_tf_to_keras_tensor, tf_out) + return output_spec + + +def cond(pred, true_fn, false_fn): + if isinstance(pred, tf.Variable): + return tf.cond(pred, true_fn=true_fn, false_fn=false_fn) + return tf.__internal__.smart_cond.smart_cond( + pred, true_fn=true_fn, false_fn=false_fn + ) + + +def vectorized_map(function, elements): + return tf.vectorized_map(function, elements) + + +def map(f, xs): + xs = tree.map_structure(convert_to_tensor, xs) + + def get_fn_output_signature(x): + out = f(x) + return tree.map_structure(tf.TensorSpec.from_tensor, out) + + if tree.is_nested(xs): + input = tree.pack_sequence_as(xs, [x[0] for x in tree.flatten(xs)]) + fn_output_signature = get_fn_output_signature(input) + return tf.map_fn(f, xs, fn_output_signature=fn_output_signature) + else: + fn_output_signature = get_fn_output_signature(xs[0]) + return tf.map_fn(f, xs, fn_output_signature=fn_output_signature) + + +def scan(f, init, xs=None, length=None, reverse=False, unroll=1): + # We have reimplemented `scan` to match the behavior of `jax.lax.scan` + # Ref: tf.scan, jax.lax.scan + if not callable(f): + raise TypeError(f"`f` should be a callable. Received: f={f}") + if not isinstance(unroll, bool): + if not isinstance(unroll, int) or unroll < 1: + raise ValueError( + "`unroll` must be an positive integer or boolean. " + f"Received: unroll={unroll}" + ) + if xs is None and length is None: + raise ValueError("Got no `xs` to scan over and `length` not provided.") + + input_is_sequence = tree.is_nested(xs) + output_is_sequence = tree.is_nested(init) + + def pack_input(x): + return tree.pack_sequence_as(xs, x) if input_is_sequence else x[0] + + def pack_output(x): + return tree.pack_sequence_as(init, x) if output_is_sequence else x[0] + + if xs is None: + xs_flat = [] + n = int(length) + else: + # xs_flat = flatten_input(xs) + xs_flat = tree.flatten(xs) + xs_flat = [tf.convert_to_tensor(elem) for elem in xs_flat] + n = int(length) if length is not None else tf.shape(xs_flat[0])[0] + + # TensorArrays are always flat + xs_array = [ + tf.TensorArray( + dtype=x.dtype, + size=n, + dynamic_size=False, + element_shape=x.shape[1:], + infer_shape=True, + ) + for x in xs_flat + ] + xs_array = [x_a.unstack(x) for x_a, x in zip(xs_array, xs_flat)] + + init_flat = tree.flatten(init) + carry_flat = [tf.convert_to_tensor(init) for init in init_flat] + + # Store the intermediate values + # Note: there is a constraint that the output of `f` must have the same + # shape and dtype as carry (`init`). + ys_array = [ + tf.TensorArray( + dtype=carry.dtype, + size=n, + dynamic_size=False, + element_shape=carry.shape, + infer_shape=True, + ) + for carry in carry_flat + ] + carry_array = [ + tf.TensorArray( + dtype=carry.dtype, + size=1, + dynamic_size=False, + clear_after_read=False, + element_shape=carry.shape, + infer_shape=True, + ) + for carry in carry_flat + ] + carry_array = [ + carry.write(0, c) for (carry, c) in zip(carry_array, carry_flat) + ] + + def loop_body(i, carry_array, ys_array): + packed_xs = ( + pack_input([xs.read(i) for xs in xs_array]) + if len(xs_array) > 0 + else None + ) + packed_carry = pack_output([carry.read(0) for carry in carry_array]) + + carry, ys = f(packed_carry, packed_xs) + + if ys is not None: + flat_ys = tree.flatten(ys) + ys_array = [ys.write(i, v) for (ys, v) in zip(ys_array, flat_ys)] + if carry is not None: + flat_carry = tree.flatten(carry) + carry_array = [ + carry.write(0, v) for (carry, v) in zip(carry_array, flat_carry) + ] + next_i = i + 1 if not reverse else i - 1 + return (next_i, carry_array, ys_array) + + if isinstance(unroll, bool): + unroll = max(n, 1) if unroll else 1 + + _, carry_array, ys_array = tf.while_loop( + lambda i, _1, _2: i >= 0 if reverse else i < n, + loop_body, + (n - 1 if reverse else 0, carry_array, ys_array), + parallel_iterations=unroll, + ) + + ys_flat = [ys.stack() for ys in ys_array] + carry_flat = [carry.read(0) for carry in carry_array] + if xs is not None: + n_static = xs_flat[0].get_shape().with_rank_at_least(1)[0] + if not isinstance(n_static, int): + for x in xs_flat[1:]: + n_static.assert_is_compatible_with( + x.get_shape().with_rank_at_least(1)[0] + ) + for r in ys_flat: + r.set_shape(tf.TensorShape(n_static).concatenate(r.get_shape()[1:])) + return pack_output(carry_flat), pack_output(ys_flat) + + +def associative_scan(f, elems, reverse=False, axis=0): + # Implementation is the same as tfp.math.scan_associative + # with additional checks to ensure similar behavior with jax + if not callable(f): + raise TypeError(f"`f` should be a callable. Received: f={f}") + elems_flat = tree.flatten(elems) + elems_flat = [tf.convert_to_tensor(elem) for elem in elems_flat] + if reverse: + elems_flat = [tf.reverse(elem, [axis]) for elem in elems_flat] + + def _combine(a_flat, b_flat): + a = tree.pack_sequence_as(elems, a_flat) + b = tree.pack_sequence_as(elems, b_flat) + c = f(a, b) + c_flat = tree.flatten(c) + return c_flat + + def _get_dim(x): + return shape(x)[axis] + + # TODO add constant dim check + num_elems = _get_dim(elems_flat[0]) + if not all(_get_dim(elem) == num_elems for elem in elems_flat[1:]): + raise ValueError( + "Array inputs to associative_scan must have the same " + "first dimension. (saw: {})".format( + [tf.shape(elem) for elem in elems_flat] + ) + ) + + def _interleave(a, b, axis): + # [a b c ...] [d e f ...] -> [a d b e c f ...] + num_elems_a = _get_dim(a) + num_elems_b = _get_dim(b) + + # Note that interleaving implies rank(a)==rank(b). + axis = tf.where(axis >= 0, axis, tf.rank(a) + axis) + axis = ( + int(axis) # Avoid ndarray values. + if tf.get_static_value(axis) is not None + else axis + ) + + def _interleave_with_b(a): + return tf.reshape( + # Work around lack of support for Tensor axes in + # `tf.stack` by using `concat` and `expand_dims` instead. + tf.concat( + [ + tf.expand_dims(a, axis=axis + 1), + tf.expand_dims(b, axis=axis + 1), + ], + axis=axis + 1, + ), + tf.concat( + [ + a.get_shape()[:axis], + [2 * num_elems_b], + a.get_shape()[axis + 1 :], + ], + axis=0, + ), + ) + + return tf.cond( + tf.equal(num_elems_a, num_elems_b + 1), + lambda: tf.concat( + [ + _interleave_with_b( + slice_along_axis(a, None, -1, axis=axis) + ), + slice_along_axis(a, -1, None, axis=axis), + ], + axis=axis, + ), + lambda: _interleave_with_b(a), + ) + + def _scan(elems): + elem_length = _get_dim(elems[0]) + a = [slice_along_axis(elem, 0, -1, step=2, axis=axis) for elem in elems] + b = [ + slice_along_axis(elem, 1, None, step=2, axis=axis) for elem in elems + ] + reduced_elems = _combine(a, b) + + def _handle_base_case_elem_length_two(): + return [ + tf.concat( + [slice_along_axis(elem, 0, 1, axis=axis), reduced_elem], + axis=axis, + ) + for (reduced_elem, elem) in zip(reduced_elems, elems) + ] + + def _handle_base_case_elem_length_three(): + reduced_reduced_elems = _combine( + reduced_elems, + [slice_along_axis(elem, 2, 3, axis=axis) for elem in elems], + ) + return [ + tf.concat( + [ + slice_along_axis(elem, 0, 1, axis=axis), + reduced_elem, + reduced_reduced_elem, + ], + axis=axis, + ) + for (reduced_reduced_elem, reduced_elem, elem) in zip( + reduced_reduced_elems, reduced_elems, elems + ) + ] + + at_base_case = tf.logical_or( + tf.equal(elem_length, 2), tf.equal(elem_length, 3) + ) + + def _base_case(): + return tf.cond( + tf.equal(elem_length, 2), + _handle_base_case_elem_length_two, + _handle_base_case_elem_length_three, + ) + + def _recursive_case(): + odd_elems = _scan(reduced_elems) + + def _even_length_case(): + return _combine( + [ + slice_along_axis(odd_elem, 0, -1, axis=axis) + for odd_elem in odd_elems + ], + [ + slice_along_axis(elem, 2, None, 2, axis=axis) + for elem in elems + ], + ) + + def _odd_length_case(): + return _combine( + [odd_elem for odd_elem in odd_elems], + [ + slice_along_axis(elem, 2, None, 2, axis=axis) + for elem in elems + ], + ) + + results = tf.cond( + tf.equal(elem_length % 2, 0), + _even_length_case, + _odd_length_case, + ) + + even_elems = [ + tf.concat( + [slice_along_axis(elem, 0, 1, axis=axis), result], axis=axis + ) + for (elem, result) in zip(elems, results) + ] + return list( + builtins.map( + lambda a, b: _interleave(a, b, axis=axis), + even_elems, + odd_elems, + ) + ) + + return tf.cond(at_base_case, _base_case, _recursive_case) + + scans = _scan(elems_flat) + if reverse: + scans = [tf.reverse(scanned, [axis]) for scanned in scans] + + return tree.pack_sequence_as(elems, scans) + + +def scatter(indices, values, shape): + return tf.scatter_nd(indices, values, shape) + + +def scatter_update(inputs, indices, updates): + return tf.tensor_scatter_nd_update(inputs, indices, updates) + + +def slice(inputs, start_indices, shape): + return tf.slice(inputs, start_indices, shape) + + +def slice_update(inputs, start_indices, updates): + return dynamic_update_slice(inputs, updates, start_indices) + + +def switch(index, branches, *operands): + index = convert_to_tensor(index, "int32") + index = tf.clip_by_value(index, 0, len(branches) - 1) + + # Workaround to deal with python closures. More details: + # https://github.com/tensorflow/tensorflow/issues/8776#issuecomment-311383887 + def gen_fn(i): + return lambda: branches[i](*operands) + + branch_fns = [gen_fn(i) for i in range(len(branches))] + return tf.switch_case(index, branch_fns) + + +def while_loop( + cond, + body, + loop_vars, + maximum_iterations=None, +): + is_tuple = isinstance(loop_vars, (tuple, list)) + loop_vars = tuple(loop_vars) if is_tuple else (loop_vars,) + + def _body(*args): + outputs = body(*args) + return tuple(outputs) if is_tuple else (outputs,) + + outputs = tf.while_loop( + cond, + _body, + loop_vars, + maximum_iterations=maximum_iterations, + ) + return outputs if is_tuple else outputs[0] + + +def fori_loop(lower, upper, body_fun, init_val): + return tf.while_loop( + lambda i, val: i < upper, + lambda i, val: (i + 1, body_fun(i, val)), + (lower, init_val), + )[1] + + +def stop_gradient(variable): + return tf.stop_gradient(variable) + + +def unstack(x, num=None, axis=0): + return tf.unstack(x, num=num, axis=axis) + + +def random_seed_dtype(): + # tensorflow random operation only works on int32/int64, not uint32. + return "int64" + + +def custom_gradient(fun): + return tf.custom_gradient(f=fun) + + +def remat(f): + """Implementation of rematerialization. + + Args: + f: The function or operation to rematerialize. + Returns: + A function wrapping f that defines a custom gradient, which + recomputes f on the backwards pass of a gradient call. + """ + return tf.recompute_grad(f) + + +class name_scope(base_name_scope): + def __init__(self, name, **kwargs): + super().__init__(name, **kwargs) + self._tf_name_scope = tf.name_scope(name) + + def __enter__(self): + name_scope_stack = global_state.get_global_attribute( + "name_scope_stack", default=[], set_to_default=True + ) + if self.deduplicate and name_scope_stack: + parent_caller = name_scope_stack[-1].caller + parent_name = name_scope_stack[-1].name + if ( + self.caller is not None + and self.caller is parent_caller + and self.name == parent_name + ): + return self + name_scope_stack.append(self) + self._pop_on_exit = True + self._tf_name_scope.__enter__() + return self + + def __exit__(self, *args, **kwargs): + super().__exit__(*args, **kwargs) + if self._pop_on_exit: + self._tf_name_scope.__exit__(*args, **kwargs) + + +def device_scope(device_name): + return tf.device(device_name) diff --git a/keras/src/backend/tensorflow/distribute_test.py b/keras/src/backend/tensorflow/distribute_test.py new file mode 100644 index 000000000000..d2381bf64c14 --- /dev/null +++ b/keras/src/backend/tensorflow/distribute_test.py @@ -0,0 +1,205 @@ +"""Tests for tf.distribute related functionality under tf implementation.""" + +import numpy as np +import pytest +import tensorflow as tf +from tensorflow.python.eager import context + +import keras +from keras.src import backend +from keras.src import layers +from keras.src import models +from keras.src import testing +from keras.src.backend.tensorflow import trainer as tf_trainer + + +@pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="The distribute test can only run with TF backend.", +) +class DistributeTest(testing.TestCase): + def setUp(self): + super().setUp() + # Need at least 2 devices for distribution related tests. + cpus = tf.config.list_physical_devices("CPU") + context._reset_context() + tf.config.set_logical_device_configuration( + cpus[0], + [ + tf.config.LogicalDeviceConfiguration(), + tf.config.LogicalDeviceConfiguration(), + ], + ) + + def test_variable_creation(self): + strategy = tf.distribute.MirroredStrategy(["CPU:0", "CPU:1"]) + with strategy.scope(): + dense = layers.Dense(2) + dense.build([4, 2]) + + self.assertIsInstance(dense.kernel, backend.Variable) + self.assertIsInstance( + dense.kernel.value, tf.distribute.DistributedValues + ) + self.assertIn("MirroredVariable", dense.kernel.value.__class__.__name__) + + self.assertIsInstance(dense.kernel, backend.Variable) + self.assertIsInstance(dense.bias.value, tf.distribute.DistributedValues) + self.assertIn("MirroredVariable", dense.bias.value.__class__.__name__) + + def test_strategy_run(self): + strategy = tf.distribute.MirroredStrategy(["CPU:0", "CPU:1"]) + + with strategy.scope(): + inputs = layers.Input(shape=[4]) + dense = layers.Dense(2) + output = dense(inputs) + model = models.Functional(inputs, output) + + self.assertIsInstance(dense.kernel, backend.Variable) + self.assertIsInstance( + dense.kernel.value, tf.distribute.DistributedValues + ) + + def input_fn(ctx): + if ctx.replica_id_in_sync_group == 1: + return tf.ones([8, 4]) + else: + return tf.zeros([8, 4]) + + distributed_inputs = ( + strategy.experimental_distribute_values_from_function(input_fn) + ) + + @tf.function + def run_fn(data): + return model(data) + + result = strategy.run(run_fn, args=(distributed_inputs,)) + + self.assertIsInstance( + result, tf.types.experimental.distributed.PerReplica + ) + self.assertLen(result.values, 2) + self.assertEqual(result.values[0].shape, [8, 2]) + self.assertEqual(result.values[1].shape, [8, 2]) + self.assertNotAllClose(result.values[0], result.values[1]) + self.assertAllClose(result.values[0], tf.zeros([8, 2])) + + def test_epoch_iterator(self): + x = np.random.random((100, 16)) + y = np.random.random((100, 4)) + sample_weight = np.random.random((100,)) + batch_size = 16 + shuffle = True + + strategy = tf.distribute.MirroredStrategy(["CPU:0", "CPU:1"]) + + epoch_iterator = tf_trainer.TFEpochIterator( + x=x, + y=y, + sample_weight=sample_weight, + batch_size=batch_size, + shuffle=shuffle, + distribute_strategy=strategy, + ) + steps_seen = [] + for step, _, data_iterator in epoch_iterator: + steps_seen.append(step) + batch = next(data_iterator) + self.assertEqual(len(batch), 3) + x, y, sample_weight = batch + self.assertTrue( + isinstance(x, tf.types.experimental.distributed.PerReplica) + ) + # Make sure the local batch size is 8 + if step < 6: + self.assertEqual(x.values[0].shape, [8, 16]) + self.assertEqual(y.values[0].shape, [8, 4]) + self.assertEqual(sample_weight.values[0].shape, [8]) + else: + # Last partial batch + self.assertEqual(x.values[0].shape, [2, 16]) + self.assertEqual(y.values[0].shape, [2, 4]) + self.assertEqual(sample_weight.values[0].shape, [2]) + self.assertEqual(steps_seen, [0, 1, 2, 3, 4, 5, 6]) + + def test_variable_aggregation(self): + strategy = tf.distribute.MirroredStrategy(["CPU:0", "CPU:1"]) + + with strategy.scope(): + x = np.random.random((4, 4)) + v1 = backend.Variable(x, dtype="float32") + self.assertEqual(v1.aggregation, "none") + self.assertEqual(v1.value.aggregation, tf.VariableAggregation.NONE) + + v2 = backend.Variable(x, dtype="float32", aggregation="sum") + self.assertEqual(v2.aggregation, "sum") + self.assertEqual(v2.value.aggregation, tf.VariableAggregation.SUM) + + def test_variable_synchronization(self): + strategy = tf.distribute.MirroredStrategy(["CPU:0", "CPU:1"]) + + with strategy.scope(): + x = np.random.random((4, 4)) + v1 = backend.Variable(x, dtype="float32") + self.assertEqual(v1.synchronization, "auto") + # AUTO with MirroredStrategy defaults to ON_WRITE + self.assertEqual( + v1.value.synchronization, tf.VariableSynchronization.ON_WRITE + ) + + v2 = backend.Variable(x, dtype="float32", synchronization="on_read") + self.assertEqual(v2.synchronization, "on_read") + self.assertEqual( + v2.value.synchronization, tf.VariableSynchronization.ON_READ + ) + + def test_seed_generator(self): + strategy = tf.distribute.MirroredStrategy(["CPU:0", "CPU:1"]) + with strategy.scope(): + seed_generator = keras.random.SeedGenerator(42) + states = strategy.run(lambda: seed_generator.state.value).values + for s in states: + self.assertAllClose(keras.ops.convert_to_numpy(s), (42, 0)) + + def test_correctness_with_fit_and_regularizer(self): + strategy = tf.distribute.MirroredStrategy(["CPU:0", "CPU:1"]) + + batch_size = 12 + x = keras.ops.ones((batch_size, 1)) + y = keras.ops.zeros((batch_size, 1)) + + # Runs without a strategy to get expected weights. + inputs = layers.Input(shape=(1,)) + layer = layers.Dense( + 1, + use_bias=False, + kernel_initializer=keras.initializers.Constant(1), + kernel_regularizer=keras.regularizers.L1L2(l1=0.01, l2=0.01), + ) + model = models.Model(inputs, layer(inputs)) + model.compile(loss="mse", optimizer="sgd") + history = model.fit(x, y, batch_size=batch_size, epochs=1) + expected_loss = history.history["loss"] + expected_weights = keras.ops.convert_to_numpy(layer.kernel) + + # Runs with a mirrored strategy. + with strategy.scope(): + inputs = layers.Input(shape=(1,)) + layer = layers.Dense( + 1, + use_bias=False, + kernel_initializer=keras.initializers.Constant(1), + kernel_regularizer=keras.regularizers.L1L2(l1=0.01, l2=0.01), + ) + model = models.Model(inputs, layer(inputs)) + model.compile(loss="mse", optimizer="sgd") + history = model.fit(x, y, batch_size=batch_size, epochs=1) + weights = strategy.run(lambda: layer.kernel.value).values + + self.assertAllClose(history.history["loss"], expected_loss) + for w in weights: + self.assertAllClose( + keras.ops.convert_to_numpy(w), expected_weights + ) diff --git a/keras/src/backend/tensorflow/distribution_lib.py b/keras/src/backend/tensorflow/distribution_lib.py new file mode 100644 index 000000000000..b306fd07dd0e --- /dev/null +++ b/keras/src/backend/tensorflow/distribution_lib.py @@ -0,0 +1,87 @@ +"""!!!DO NOT USE!!! + +Distribution related class for Tensorflow backend. + +This is just a prototype and we might want to unify it +with other backends in the future. +""" + +import tensorflow as tf +from tensorflow.experimental import dtensor + + +def list_devices(device_type=None): + """Return all the available devices based on the device type. + + Note that this should return the global devices in a distributed setting. + + Args: + device_type: string of `"cpu"`, `"gpu"` or `"tpu"`. Default to `gpu` or + `tpu` if available when device_type is not provided. Otherwise will + return the `cpu` devices. + + Return: + List of devices that are available for distribute computation. + """ + device_type = device_type.upper() if device_type else None + + # DTensor doesn't support getting global devices, even when knowing the + # Mesh. Use TF API instead to get global devices. Coordinator service is + # enabled by default with DTensor, so that list_logical_devices() returns + # a list of global devices. More context can be found in b/254911601. + tf_devices = tf.config.list_logical_devices(device_type=device_type) + cpu_devices = [] + other_devices = [] + for device in tf_devices: + if device.device_type.lower() == "cpu": + cpu_devices.append(device) + else: + other_devices.append(device) + if device_type is None: + tf_devices = other_devices if len(other_devices) > 0 else cpu_devices + return [ + f"{device.device_type.lower()}:{device.name.split(':')[-1]}" + for device in tf_devices + ] + + +def distribute_value(value, tensor_layout): + # TODO + pass + + +def _to_backend_mesh(device_mesh): + """Convert the DeviceMesh to Tensorflow backend specific Mesh. + + Args: + device_mesh: DeviceMesh instance to convert. + + Returns: + A `tf.dtensor.Mesh` instance. + """ + mesh_dims = list(zip(device_mesh.axis_names, device_mesh.shape)) + return dtensor.create_distributed_mesh( + mesh_dims=mesh_dims, local_devices=device_mesh.devices.flatten() + ) + + +def _to_backend_layout(tensor_layout): + """Convert the TensorLayout to Tensorflow backend specific Sharding. + + Args: + tensor_layout: TensorLayout instance to convert. + + Returns: + A `tf.dtensor.Layout` instance. + """ + if tensor_layout.device_mesh is None: + raise ValueError( + "Cannot create sharding when device mesh is not set for " + "TensorLayout." + ) + + sharding_specs = [ + axis if axis else dtensor.UNSHARDED for axis in tensor_layout.axes + ] + dtensor_mesh = tensor_layout.device_mesh.backend_mesh + return dtensor.Layout(sharding_specs=sharding_specs, mesh=dtensor_mesh) diff --git a/keras/src/backend/tensorflow/export.py b/keras/src/backend/tensorflow/export.py new file mode 100644 index 000000000000..e57f74cc8bde --- /dev/null +++ b/keras/src/backend/tensorflow/export.py @@ -0,0 +1,19 @@ +import tensorflow as tf + + +class TFExportArchive: + def _track_layer(self, layer): + # Variables in the lists below are actually part of the trackables + # that get saved, because the lists are created in __init__. + variables = layer.variables + trainable_variables = layer.trainable_variables + non_trainable_variables = layer.non_trainable_variables + self._tf_trackable.variables += variables + self._tf_trackable.trainable_variables += trainable_variables + self._tf_trackable.non_trainable_variables += non_trainable_variables + + def add_endpoint(self, name, fn, input_signature=None, **kwargs): + decorated_fn = tf.function( + fn, input_signature=input_signature, autograph=False + ) + return decorated_fn diff --git a/keras/src/backend/tensorflow/image.py b/keras/src/backend/tensorflow/image.py new file mode 100644 index 000000000000..0c693f4ff243 --- /dev/null +++ b/keras/src/backend/tensorflow/image.py @@ -0,0 +1,1076 @@ +import functools +import itertools +import operator + +import numpy as np +import tensorflow as tf + +from keras.src import backend +from keras.src.backend.tensorflow.core import convert_to_tensor +from keras.src.backend.tensorflow.numpy import moveaxis +from keras.src.random.seed_generator import draw_seed + +RESIZE_INTERPOLATIONS = ( + "bilinear", + "nearest", + "lanczos3", + "lanczos5", + "bicubic", + "area", +) +AFFINE_TRANSFORM_INTERPOLATIONS = ( + "nearest", + "bilinear", +) +AFFINE_TRANSFORM_FILL_MODES = ( + "constant", + "nearest", + "wrap", + # "mirror", not supported by TF + "reflect", +) +MAP_COORDINATES_FILL_MODES = { + "constant", + "nearest", + "wrap", + "mirror", + "reflect", +} +SCALE_AND_TRANSLATE_METHODS = { + "linear", + "bilinear", + "trilinear", + "cubic", + "bicubic", + "tricubic", + "lanczos3", + "lanczos5", +} + + +def rgb_to_grayscale(images, data_format=None): + images = convert_to_tensor(images) + data_format = backend.standardize_data_format(data_format) + channels_axis = -1 if data_format == "channels_last" else -3 + if len(images.shape) not in (3, 4): + raise ValueError( + "Invalid images rank: expected rank 3 (single image) " + "or rank 4 (batch of images). Received input with shape: " + f"images.shape={images.shape}" + ) + # Convert to floats + original_dtype = images.dtype + compute_dtype = backend.result_type(images.dtype, float) + images = tf.cast(images, compute_dtype) + + # Ref: tf.image.rgb_to_grayscale + rgb_weights = convert_to_tensor( + [0.2989, 0.5870, 0.1140], dtype=images.dtype + ) + images = tf.tensordot(images, rgb_weights, axes=(channels_axis, -1)) + images = tf.expand_dims(images, axis=channels_axis) + return tf.cast(images, original_dtype) + + +def rgb_to_hsv(images, data_format=None): + images = convert_to_tensor(images) + dtype = images.dtype + data_format = backend.standardize_data_format(data_format) + if len(images.shape) not in (3, 4): + raise ValueError( + "Invalid images rank: expected rank 3 (single image) " + "or rank 4 (batch of images). Received input with shape: " + f"images.shape={images.shape}" + ) + if not backend.is_float_dtype(dtype): + raise ValueError( + "Invalid images dtype: expected float dtype. " + f"Received: images.dtype={backend.standardize_dtype(dtype)}" + ) + if data_format == "channels_first": + if len(images.shape) == 4: + images = tf.transpose(images, (0, 2, 3, 1)) + else: + images = tf.transpose(images, (1, 2, 0)) + images = tf.image.rgb_to_hsv(images) + if data_format == "channels_first": + if len(images.shape) == 4: + images = tf.transpose(images, (0, 3, 1, 2)) + elif len(images.shape) == 3: + images = tf.transpose(images, (2, 0, 1)) + return images + + +def hsv_to_rgb(images, data_format=None): + images = convert_to_tensor(images) + dtype = images.dtype + data_format = backend.standardize_data_format(data_format) + if len(images.shape) not in (3, 4): + raise ValueError( + "Invalid images rank: expected rank 3 (single image) " + "or rank 4 (batch of images). Received input with shape: " + f"images.shape={images.shape}" + ) + if not backend.is_float_dtype(dtype): + raise ValueError( + "Invalid images dtype: expected float dtype. " + f"Received: images.dtype={backend.standardize_dtype(dtype)}" + ) + if data_format == "channels_first": + if len(images.shape) == 4: + images = tf.transpose(images, (0, 2, 3, 1)) + else: + images = tf.transpose(images, (1, 2, 0)) + images = tf.image.hsv_to_rgb(images) + if data_format == "channels_first": + if len(images.shape) == 4: + images = tf.transpose(images, (0, 3, 1, 2)) + elif len(images.shape) == 3: + images = tf.transpose(images, (2, 0, 1)) + return images + + +def resize( + images, + size, + interpolation="bilinear", + antialias=False, + crop_to_aspect_ratio=False, + pad_to_aspect_ratio=False, + fill_mode="constant", + fill_value=0.0, + data_format=None, +): + data_format = backend.standardize_data_format(data_format) + if interpolation not in RESIZE_INTERPOLATIONS: + raise ValueError( + "Invalid value for argument `interpolation`. Expected of one " + f"{RESIZE_INTERPOLATIONS}. Received: interpolation={interpolation}" + ) + if fill_mode != "constant": + raise ValueError( + "Invalid value for argument `fill_mode`. Only `'constant'` " + f"is supported. Received: fill_mode={fill_mode}" + ) + if pad_to_aspect_ratio and crop_to_aspect_ratio: + raise ValueError( + "Only one of `pad_to_aspect_ratio` & `crop_to_aspect_ratio` " + "can be `True`." + ) + if not len(size) == 2: + raise ValueError( + "Argument `size` must be a tuple of two elements " + f"(height, width). Received: size={size}" + ) + size = tuple(size) + if len(images.shape) not in (3, 4): + raise ValueError( + "Invalid images rank: expected rank 3 (single image) " + "or rank 4 (batch of images). Received input with shape: " + f"images.shape={images.shape}" + ) + if data_format == "channels_first": + if len(images.shape) == 4: + images = tf.transpose(images, (0, 2, 3, 1)) + else: + images = tf.transpose(images, (1, 2, 0)) + + if crop_to_aspect_ratio: + shape = tf.shape(images) + height, width = shape[-3], shape[-2] + target_height, target_width = size + crop_height = tf.cast( + tf.cast(width * target_height, "float32") / target_width, + "int32", + ) + crop_height = tf.maximum(tf.minimum(height, crop_height), 1) + crop_height = tf.cast(crop_height, "int32") + crop_width = tf.cast( + tf.cast(height * target_width, "float32") / target_height, + "int32", + ) + crop_width = tf.maximum(tf.minimum(width, crop_width), 1) + crop_width = tf.cast(crop_width, "int32") + + crop_box_hstart = tf.cast( + tf.cast(height - crop_height, "float32") / 2, "int32" + ) + crop_box_wstart = tf.cast( + tf.cast(width - crop_width, "float32") / 2, "int32" + ) + if len(images.shape) == 4: + images = images[ + :, + crop_box_hstart : crop_box_hstart + crop_height, + crop_box_wstart : crop_box_wstart + crop_width, + :, + ] + else: + images = images[ + crop_box_hstart : crop_box_hstart + crop_height, + crop_box_wstart : crop_box_wstart + crop_width, + :, + ] + elif pad_to_aspect_ratio: + shape = tf.shape(images) + height, width = shape[-3], shape[-2] + target_height, target_width = size + pad_height = tf.cast( + tf.cast(width * target_height, "float32") / target_width, + "int32", + ) + pad_height = tf.maximum(height, pad_height) + pad_height = tf.cast(pad_height, "int32") + pad_width = tf.cast( + tf.cast(height * target_width, "float32") / target_height, + "int32", + ) + pad_width = tf.maximum(width, pad_width) + pad_width = tf.cast(pad_width, "int32") + + img_box_hstart = tf.cast( + tf.cast(pad_height - height, "float32") / 2, "int32" + ) + img_box_wstart = tf.cast( + tf.cast(pad_width - width, "float32") / 2, "int32" + ) + if len(images.shape) == 4: + batch_size = tf.shape(images)[0] + channels = tf.shape(images)[3] + padded_img = tf.cond( + img_box_hstart > 0, + lambda: tf.concat( + [ + tf.ones( + (batch_size, img_box_hstart, width, channels), + dtype=images.dtype, + ) + * fill_value, + images, + tf.ones( + (batch_size, img_box_hstart, width, channels), + dtype=images.dtype, + ) + * fill_value, + ], + axis=1, + ), + lambda: images, + ) + padded_img = tf.cond( + img_box_wstart > 0, + lambda: tf.concat( + [ + tf.ones( + (batch_size, height, img_box_wstart, channels), + dtype=images.dtype, + ) + * fill_value, + padded_img, + tf.ones( + (batch_size, height, img_box_wstart, channels), + dtype=images.dtype, + ) + * fill_value, + ], + axis=2, + ), + lambda: padded_img, + ) + else: + channels = tf.shape(images)[2] + padded_img = tf.cond( + img_box_hstart > 0, + lambda: tf.concat( + [ + tf.ones( + (img_box_hstart, width, channels), + dtype=images.dtype, + ) + * fill_value, + images, + tf.ones( + (img_box_hstart, width, channels), + dtype=images.dtype, + ) + * fill_value, + ], + axis=0, + ), + lambda: images, + ) + padded_img = tf.cond( + img_box_wstart > 0, + lambda: tf.concat( + [ + tf.ones( + (height, img_box_wstart, channels), + dtype=images.dtype, + ) + * fill_value, + padded_img, + tf.ones( + (height, img_box_wstart, channels), + dtype=images.dtype, + ) + * fill_value, + ], + axis=1, + ), + lambda: padded_img, + ) + images = padded_img + + resized = tf.image.resize( + images, size, method=interpolation, antialias=antialias + ) + if data_format == "channels_first": + if len(images.shape) == 4: + resized = tf.transpose(resized, (0, 3, 1, 2)) + elif len(images.shape) == 3: + resized = tf.transpose(resized, (2, 0, 1)) + return resized + + +def affine_transform( + images, + transform, + interpolation="bilinear", + fill_mode="constant", + fill_value=0, + data_format=None, +): + data_format = backend.standardize_data_format(data_format) + if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS: + raise ValueError( + "Invalid value for argument `interpolation`. Expected of one " + f"{AFFINE_TRANSFORM_INTERPOLATIONS}. Received: " + f"interpolation={interpolation}" + ) + if fill_mode not in AFFINE_TRANSFORM_FILL_MODES: + raise ValueError( + "Invalid value for argument `fill_mode`. Expected of one " + f"{AFFINE_TRANSFORM_FILL_MODES}. Received: fill_mode={fill_mode}" + ) + if len(images.shape) not in (3, 4): + raise ValueError( + "Invalid images rank: expected rank 3 (single image) " + "or rank 4 (batch of images). Received input with shape: " + f"images.shape={images.shape}" + ) + if len(transform.shape) not in (1, 2): + raise ValueError( + "Invalid transform rank: expected rank 1 (single transform) " + "or rank 2 (batch of transforms). Received input with shape: " + f"transform.shape={transform.shape}" + ) + # unbatched case + need_squeeze = False + if len(images.shape) == 3: + images = tf.expand_dims(images, axis=0) + need_squeeze = True + if len(transform.shape) == 1: + transform = tf.expand_dims(transform, axis=0) + + if data_format == "channels_first": + images = tf.transpose(images, (0, 2, 3, 1)) + + affined = tf.raw_ops.ImageProjectiveTransformV3( + images=images, + transforms=tf.cast(transform, dtype=tf.float32), + output_shape=tf.shape(images)[1:-1], + fill_value=fill_value, + interpolation=interpolation.upper(), + fill_mode=fill_mode.upper(), + ) + affined = tf.ensure_shape(affined, images.shape) + + if data_format == "channels_first": + affined = tf.transpose(affined, (0, 3, 1, 2)) + if need_squeeze: + affined = tf.squeeze(affined, axis=0) + return affined + + +def perspective_transform( + images, + start_points, + end_points, + interpolation="bilinear", + fill_value=0, + data_format=None, +): + data_format = backend.standardize_data_format(data_format) + start_points = convert_to_tensor(start_points, dtype=tf.float32) + end_points = convert_to_tensor(end_points, dtype=tf.float32) + + if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS: + raise ValueError( + "Invalid value for argument `interpolation`. Expected of one " + f"{AFFINE_TRANSFORM_INTERPOLATIONS}. Received: " + f"interpolation={interpolation}" + ) + if len(images.shape) not in (3, 4): + raise ValueError( + "Invalid images rank: expected rank 3 (single image) " + "or rank 4 (batch of images). Received input with shape: " + f"images.shape={images.shape}" + ) + + if start_points.shape.rank not in (2, 3) or start_points.shape[-2:] != ( + 4, + 2, + ): + raise ValueError( + "Invalid start_points shape: expected (4,2) for a single image" + f" or (N,4,2) for a batch. Received shape: {start_points.shape}" + ) + if end_points.shape.rank not in (2, 3) or end_points.shape[-2:] != (4, 2): + raise ValueError( + "Invalid end_points shape: expected (4,2) for a single image" + f" or (N,4,2) for a batch. Received shape: {end_points.shape}" + ) + if start_points.shape != end_points.shape: + raise ValueError( + "start_points and end_points must have the same shape." + f" Received start_points.shape={start_points.shape}, " + f"end_points.shape={end_points.shape}" + ) + + need_squeeze = False + if len(images.shape) == 3: + images = tf.expand_dims(images, axis=0) + need_squeeze = True + + if len(start_points.shape) == 2: + start_points = tf.expand_dims(start_points, axis=0) + if len(end_points.shape) == 2: + end_points = tf.expand_dims(end_points, axis=0) + + if data_format == "channels_first": + images = tf.transpose(images, (0, 2, 3, 1)) + + transform = compute_homography_matrix(start_points, end_points) + if len(transform.shape) == 1: + transform = tf.expand_dims(transform, axis=0) + + output = tf.raw_ops.ImageProjectiveTransformV3( + images=images, + transforms=tf.cast(transform, dtype=tf.float32), + output_shape=tf.shape(images)[1:-1], + fill_value=fill_value, + interpolation=interpolation.upper(), + ) + output = tf.ensure_shape(output, images.shape) + + if data_format == "channels_first": + output = tf.transpose(output, (0, 3, 1, 2)) + if need_squeeze: + output = tf.squeeze(output, axis=0) + return output + + +def compute_homography_matrix(start_points, end_points): + start_x1, start_y1 = start_points[:, 0, 0], start_points[:, 0, 1] + start_x2, start_y2 = start_points[:, 1, 0], start_points[:, 1, 1] + start_x3, start_y3 = start_points[:, 2, 0], start_points[:, 2, 1] + start_x4, start_y4 = start_points[:, 3, 0], start_points[:, 3, 1] + + end_x1, end_y1 = end_points[:, 0, 0], end_points[:, 0, 1] + end_x2, end_y2 = end_points[:, 1, 0], end_points[:, 1, 1] + end_x3, end_y3 = end_points[:, 2, 0], end_points[:, 2, 1] + end_x4, end_y4 = end_points[:, 3, 0], end_points[:, 3, 1] + + coefficient_matrix = tf.stack( + [ + tf.stack( + [ + end_x1, + end_y1, + tf.ones_like(end_x1), + tf.zeros_like(end_x1), + tf.zeros_like(end_x1), + tf.zeros_like(end_x1), + -start_x1 * end_x1, + -start_x1 * end_y1, + ], + axis=-1, + ), + tf.stack( + [ + tf.zeros_like(end_x1), + tf.zeros_like(end_x1), + tf.zeros_like(end_x1), + end_x1, + end_y1, + tf.ones_like(end_x1), + -start_y1 * end_x1, + -start_y1 * end_y1, + ], + axis=-1, + ), + tf.stack( + [ + end_x2, + end_y2, + tf.ones_like(end_x2), + tf.zeros_like(end_x2), + tf.zeros_like(end_x2), + tf.zeros_like(end_x2), + -start_x2 * end_x2, + -start_x2 * end_y2, + ], + axis=-1, + ), + tf.stack( + [ + tf.zeros_like(end_x2), + tf.zeros_like(end_x2), + tf.zeros_like(end_x2), + end_x2, + end_y2, + tf.ones_like(end_x2), + -start_y2 * end_x2, + -start_y2 * end_y2, + ], + axis=-1, + ), + tf.stack( + [ + end_x3, + end_y3, + tf.ones_like(end_x3), + tf.zeros_like(end_x3), + tf.zeros_like(end_x3), + tf.zeros_like(end_x3), + -start_x3 * end_x3, + -start_x3 * end_y3, + ], + axis=-1, + ), + tf.stack( + [ + tf.zeros_like(end_x3), + tf.zeros_like(end_x3), + tf.zeros_like(end_x3), + end_x3, + end_y3, + tf.ones_like(end_x3), + -start_y3 * end_x3, + -start_y3 * end_y3, + ], + axis=-1, + ), + tf.stack( + [ + end_x4, + end_y4, + tf.ones_like(end_x4), + tf.zeros_like(end_x4), + tf.zeros_like(end_x4), + tf.zeros_like(end_x4), + -start_x4 * end_x4, + -start_x4 * end_y4, + ], + axis=-1, + ), + tf.stack( + [ + tf.zeros_like(end_x4), + tf.zeros_like(end_x4), + tf.zeros_like(end_x4), + end_x4, + end_y4, + tf.ones_like(end_x4), + -start_y4 * end_x4, + -start_y4 * end_y4, + ], + axis=-1, + ), + ], + axis=1, + ) + + target_vector = tf.stack( + [ + start_x1, + start_y1, + start_x2, + start_y2, + start_x3, + start_y3, + start_x4, + start_y4, + ], + axis=-1, + ) + target_vector = tf.expand_dims(target_vector, axis=-1) + + homography_matrix = tf.linalg.solve(coefficient_matrix, target_vector) + homography_matrix = tf.reshape(homography_matrix, [-1, 8]) + + return homography_matrix + + +def _mirror_index_fixer(index, size): + s = size - 1 # Half-wavelength of triangular wave + # Scaled, integer-valued version of the triangular wave |x - round(x)| + return tf.abs((index + s) % (2 * s) - s) + + +def _reflect_index_fixer(index, size): + return tf.math.floordiv( + _mirror_index_fixer(2 * index + 1, 2 * size + 1) - 1, 2 + ) + + +def _nearest_indices_and_weights(coordinate): + coordinate = ( + coordinate if coordinate.dtype.is_integer else tf.round(coordinate) + ) + index = tf.cast(coordinate, tf.int32) + weight = tf.constant(1, coordinate.dtype) + return [(index, weight)] + + +def _linear_indices_and_weights(coordinate): + lower = tf.floor(coordinate) + upper_weight = coordinate - lower + lower_weight = 1 - upper_weight + index = tf.cast(lower, tf.int32) + return [(index, lower_weight), (index + 1, upper_weight)] + + +def map_coordinates( + inputs, coordinates, order, fill_mode="constant", fill_value=0.0 +): + input_arr = convert_to_tensor(inputs) + coordinate_arrs = convert_to_tensor(coordinates) + + if coordinate_arrs.shape[0] != len(input_arr.shape): + raise ValueError( + "First dim of `coordinates` must be the same as the rank of " + "`inputs`. " + f"Received inputs with shape: {input_arr.shape} and coordinate " + f"leading dim of {coordinate_arrs.shape[0]}" + ) + if len(coordinate_arrs.shape) < 2: + raise ValueError( + "Invalid coordinates rank: expected at least rank 2." + f" Received input with shape: {coordinate_arrs.shape}" + ) + if fill_mode not in MAP_COORDINATES_FILL_MODES: + raise ValueError( + "Invalid value for argument `fill_mode`. Expected one of " + f"{set(MAP_COORDINATES_FILL_MODES.keys())}. Received: " + f"fill_mode={fill_mode}" + ) + + fill_value = convert_to_tensor(fill_value, dtype=input_arr.dtype) + + coordinate_arrs = tf.unstack(coordinate_arrs, axis=0) + + if order == 0: + interp_fun = _nearest_indices_and_weights + elif order == 1: + interp_fun = _linear_indices_and_weights + else: + raise NotImplementedError("map_coordinates currently requires order<=1") + + def process_coordinates(coords, size): + if fill_mode == "constant": + valid = (coords >= 0) & (coords < size) + safe_coords = tf.clip_by_value(coords, 0, size - 1) + return safe_coords, valid + elif fill_mode == "nearest": + return tf.clip_by_value(coords, 0, size - 1), tf.ones_like( + coords, dtype=tf.bool + ) + elif fill_mode in ["mirror", "reflect"]: + coords = tf.abs(coords) + size_2 = size * 2 + mod = tf.math.mod(coords, size_2) + under = mod < size + over = ~under + # reflect mode is same as mirror for under + coords = tf.where(under, mod, size_2 - mod) + # for reflect mode, adjust the over case + if fill_mode == "reflect": + coords = tf.where(over, coords - 1, coords) + return coords, tf.ones_like(coords, dtype=tf.bool) + elif fill_mode == "wrap": + coords = tf.math.mod(coords, size) + return coords, tf.ones_like(coords, dtype=tf.bool) + else: + raise ValueError(f"Unknown fill_mode: {fill_mode}") + + valid_1d_interpolations = [] + for coordinate, size in zip(coordinate_arrs, input_arr.shape): + interp_nodes = interp_fun(coordinate) + valid_interp = [] + for index, weight in interp_nodes: + safe_index, valid = process_coordinates(index, size) + valid_interp.append((safe_index, valid, weight)) + valid_1d_interpolations.append(valid_interp) + + outputs = [] + for items in itertools.product(*valid_1d_interpolations): + indices, validities, weights = zip(*items) + indices = tf.transpose(tf.stack(indices)) + + gathered = tf.transpose(tf.gather_nd(input_arr, indices)) + + # Cast to computation dtype early to avoid type issues + dtype = weights[0].dtype + gathered = tf.cast(gathered, dtype) + gathered = tf.cast(gathered, weights[0].dtype) + + if fill_mode == "constant": + all_valid = tf.reduce_all(validities, axis=0) + fill_value_typed = tf.cast(fill_value, dtype) + gathered = tf.where(all_valid, gathered, fill_value_typed) + + outputs.append(functools.reduce(operator.mul, weights) * gathered) + + result = functools.reduce(operator.add, outputs) + + if input_arr.dtype.is_integer: + result = tf.round(result) + return tf.cast(result, input_arr.dtype) + + +def gaussian_blur( + images, kernel_size=(3, 3), sigma=(1.0, 1.0), data_format=None +): + def _create_gaussian_kernel(kernel_size, sigma, num_channels, dtype): + def _get_gaussian_kernel1d(size, sigma): + x = tf.range(size, dtype=dtype) - (size - 1) / 2 + kernel1d = tf.exp(-0.5 * (x / sigma) ** 2) + return kernel1d / tf.reduce_sum(kernel1d) + + def _get_gaussian_kernel2d(size, sigma): + size = tf.cast(size, dtype) + kernel1d_x = _get_gaussian_kernel1d(size[0], sigma[0]) + kernel1d_y = _get_gaussian_kernel1d(size[1], sigma[1]) + return tf.tensordot(kernel1d_y, kernel1d_x, axes=0) + + kernel = _get_gaussian_kernel2d(kernel_size, sigma) + kernel = tf.reshape(kernel, (kernel_size[0], kernel_size[1], 1, 1)) + kernel = tf.tile(kernel, [1, 1, num_channels, 1]) + kernel = tf.cast(kernel, dtype) + return kernel + + images = convert_to_tensor(images) + dtype = backend.standardize_dtype(images.dtype) + kernel_size = convert_to_tensor(kernel_size, dtype=dtype) + sigma = convert_to_tensor(sigma, dtype=dtype) + + if len(images.shape) not in (3, 4): + raise ValueError( + "Invalid images rank: expected rank 3 (single image) " + "or rank 4 (batch of images). Received input with shape: " + f"images.shape={images.shape}" + ) + + need_squeeze = False + if len(images.shape) == 3: + images = tf.expand_dims(images, axis=0) + need_squeeze = True + + if data_format == "channels_first": + images = tf.transpose(images, (0, 2, 3, 1)) + + num_channels = tf.shape(images)[-1] + kernel = _create_gaussian_kernel(kernel_size, sigma, num_channels, dtype) + + blurred_images = tf.nn.depthwise_conv2d( + images, kernel, strides=[1, 1, 1, 1], padding="SAME" + ) + + if data_format == "channels_first": + blurred_images = tf.transpose(blurred_images, (0, 3, 1, 2)) + if need_squeeze: + blurred_images = tf.squeeze(blurred_images, axis=0) + + return blurred_images + + +def elastic_transform( + images, + alpha=20.0, + sigma=5.0, + interpolation="bilinear", + fill_mode="reflect", + fill_value=0.0, + seed=None, + data_format=None, +): + data_format = backend.standardize_data_format(data_format) + if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS: + raise ValueError( + "Invalid value for argument `interpolation`. Expected of one " + f"{AFFINE_TRANSFORM_INTERPOLATIONS}. Received: " + f"interpolation={interpolation}" + ) + if fill_mode not in AFFINE_TRANSFORM_FILL_MODES: + raise ValueError( + "Invalid value for argument `fill_mode`. Expected of one " + f"{AFFINE_TRANSFORM_FILL_MODES}. Received: fill_mode={fill_mode}" + ) + if len(images.shape) not in (3, 4): + raise ValueError( + "Invalid images rank: expected rank 3 (single image) " + "or rank 4 (batch of images). Received input with shape: " + f"images.shape={images.shape}" + ) + + images = convert_to_tensor(images) + input_dtype = images.dtype + + alpha = convert_to_tensor(alpha, dtype=input_dtype) + sigma = convert_to_tensor(sigma, dtype=input_dtype) + kernel_factor = convert_to_tensor(sigma, dtype="int32") + kernel_size = (6 * kernel_factor | 1, 6 * kernel_factor | 1) + + need_squeeze = False + if len(images.shape) == 3: + images = tf.expand_dims(images, axis=0) + need_squeeze = True + + if data_format == "channels_last": + batch_size, height, width, channels = images.shape + channel_axis = -1 + else: + batch_size, channels, height, width = images.shape + channel_axis = 1 + + seed = draw_seed(seed) + + if batch_size is None: + batch_size = 1 + + dx = ( + tf.random.stateless_normal( + shape=(batch_size, height, width), + mean=0.0, + stddev=1.0, + dtype=input_dtype, + seed=seed, + ) + * sigma + ) + dy = ( + tf.random.stateless_normal( + shape=(batch_size, height, width), + mean=0.0, + stddev=1.0, + dtype=input_dtype, + seed=seed, + ) + * sigma + ) + + dx = gaussian_blur( + tf.expand_dims(dx, axis=channel_axis), + kernel_size=kernel_size, + sigma=(sigma, sigma), + data_format=data_format, + ) + dy = gaussian_blur( + tf.expand_dims(dy, axis=channel_axis), + kernel_size=kernel_size, + sigma=(sigma, sigma), + data_format=data_format, + ) + + dx = tf.squeeze(dx, axis=channel_axis) + dy = tf.squeeze(dy, axis=channel_axis) + + x, y = tf.meshgrid( + tf.range(width, dtype=input_dtype), + tf.range(height, dtype=input_dtype), + indexing="xy", + ) + x = tf.expand_dims(x, axis=0) + y = tf.expand_dims(y, axis=0) + + distorted_x = x + alpha * dx + distorted_y = y + alpha * dy + + channel_outputs = [] + if data_format == "channels_last": + for i in range(channels): + channel_transformed = tf.stack( + [ + map_coordinates( + images[b, ..., i], + [distorted_y[b], distorted_x[b]], + order=AFFINE_TRANSFORM_INTERPOLATIONS.index( + interpolation + ), + fill_mode=fill_mode, + fill_value=fill_value, + ) + for b in range(batch_size) + ], + axis=0, + ) + channel_outputs.append(channel_transformed) + transformed_images = tf.stack(channel_outputs, axis=-1) + else: + for i in range(channels): + channel_transformed = tf.stack( + [ + map_coordinates( + images[b, i, ...], + [distorted_y[b], distorted_x[b]], + order=AFFINE_TRANSFORM_INTERPOLATIONS.index( + interpolation + ), + fill_mode=fill_mode, + fill_value=fill_value, + ) + for b in range(batch_size) + ], + axis=0, + ) + channel_outputs.append(channel_transformed) + transformed_images = tf.stack(channel_outputs, axis=1) + + if need_squeeze: + transformed_images = tf.squeeze(transformed_images, axis=0) + transformed_images = tf.cast(transformed_images, input_dtype) + + return transformed_images + + +def _fill_triangle_kernel(x): + return tf.maximum(tf.constant(0, dtype=x.dtype), 1 - tf.abs(x)) + + +def _fill_keys_cubic_kernel(x): + out = ((1.5 * x - 2.5) * x) * x + 1.0 + out = tf.where(x >= 1.0, ((-0.5 * x + 2.5) * x - 4.0) * x + 2.0, out) + return tf.where(x >= 2.0, 0.0, out) + + +def _fill_lanczos_kernel(radius, x): + y = radius * tf.sin(np.pi * x) * tf.sin(np.pi * x / radius) + out = tf.where( + x > 1e-3, tf.divide(y, tf.where(x != 0, np.pi**2 * x**2, 1)), 1 + ) + return tf.where(x > radius, 0.0, out) + + +_kernels = { + "linear": _fill_triangle_kernel, + "cubic": _fill_keys_cubic_kernel, + "lanczos3": lambda x: _fill_lanczos_kernel(3.0, x), + "lanczos5": lambda x: _fill_lanczos_kernel(5.0, x), +} + + +def _compute_weight_mat( + input_size, output_size, scale, translation, kernel, antialias +): + dtype = backend.result_type(scale.dtype, translation.dtype) + inv_scale = 1.0 / scale + kernel_scale = tf.maximum(inv_scale, 1.0) if antialias else 1.0 + sample_f = ( + (tf.range(output_size, dtype=dtype) + 0.5) * inv_scale + - translation * inv_scale + - 0.5 + ) + x = ( + tf.abs( + sample_f[tf.newaxis, :] + - tf.range(input_size, dtype=dtype)[:, tf.newaxis] + ) + / kernel_scale + ) + weights = kernel(x) + total_weight_sum = tf.reduce_sum(weights, axis=0, keepdims=True) + weights = tf.where( + tf.abs(total_weight_sum) > 1000.0 * float(np.finfo(np.float32).eps), + tf.divide( + weights, tf.where(total_weight_sum != 0, total_weight_sum, 1) + ), + 0, + ) + input_size_minus_0_5 = tf.cast(input_size, dtype=dtype) - 0.5 + return tf.where( + tf.logical_and(sample_f >= -0.5, sample_f <= input_size_minus_0_5)[ + tf.newaxis, : + ], + weights, + 0, + ) + + +def _scale_and_translate( + x, output_shape, spatial_dims, scale, translation, kernel, antialias +): + x = convert_to_tensor(x) + input_shape = tf.shape(x) + if len(spatial_dims) == 0: + return x + if backend.is_int_dtype(x.dtype): + output = tf.cast(x, tf.float32) + use_rounding = True + else: + output = tf.identity(x) + use_rounding = False + for i, d in enumerate(spatial_dims): + d = d % x.ndim + m, n = input_shape[d], output_shape[d] + w = tf.cast( + _compute_weight_mat( + m, n, scale[i], translation[i], kernel, antialias + ), + output.dtype, + ) + output = tf.tensordot(output, w, axes=(d, 0)) + output = moveaxis(output, -1, d) + if use_rounding: + output = tf.clip_by_value( + tf.round(output), tf.reduce_min(x), tf.reduce_max(x) + ) + output = tf.cast(output, x.dtype) + return output + + +def scale_and_translate( + images, + output_shape, + scale, + translation, + spatial_dims, + method, + antialias=True, +): + if method not in SCALE_AND_TRANSLATE_METHODS: + raise ValueError( + "Invalid value for argument `method`. Expected of one " + f"{SCALE_AND_TRANSLATE_METHODS}. Received: method={method}" + ) + if method in ("linear", "bilinear", "trilinear", "triangle"): + method = "linear" + elif method in ("cubic", "bicubic", "tricubic"): + method = "cubic" + + images = convert_to_tensor(images) + scale = convert_to_tensor(scale) + translation = convert_to_tensor(translation) + kernel = _kernels[method] + dtype = backend.result_type(scale.dtype, translation.dtype) + scale = tf.cast(scale, dtype) + translation = tf.cast(translation, dtype) + return _scale_and_translate( + images, + output_shape, + spatial_dims, + scale, + translation, + kernel, + antialias, + ) diff --git a/keras/src/backend/tensorflow/layer.py b/keras/src/backend/tensorflow/layer.py new file mode 100644 index 000000000000..2e0c4cd2c144 --- /dev/null +++ b/keras/src/backend/tensorflow/layer.py @@ -0,0 +1,114 @@ +import tensorflow as tf + +from keras.src import tree +from keras.src.backend.tensorflow.trackable import KerasAutoTrackable +from keras.src.utils import tf_utils +from keras.src.utils import tracking + + +class TFLayer(KerasAutoTrackable): + def __init__(self, *args, **kwargs): + # Export-related attributes + self._saved_model_inputs_spec = None + self._saved_model_arg_spec = None + self._tracked = [] + + @tf.__internal__.tracking.no_automatic_dependency_tracking + def _set_save_spec(self, inputs, args=None, kwargs=None): + """Defines the save spec so that serialization can trace layer calls. + + The TensorSpecs of the call function `inputs`, `args`, and `kwargs` are + saved into a tuple of `([inputs] + args, kwargs)`. + + Args: + inputs: possibly nested inputs passed into the call function. + args: a list of positional arguments passed into call. + kwargs: a dictionary of keyword arguments passed into call. + """ + if self._saved_model_inputs_spec is not None: + return # Already set. + + inputs_spec = tree.map_structure(tf_utils.get_tensor_spec, inputs) + args_spec = tree.map_structure(tf_utils.get_tensor_spec, args or []) + kwargs_spec = {} + # Filter out non-tensor arguments from kwargs. + for key, kwarg in kwargs.items(): + flat_kwarg = tree.flatten(kwarg) + flat_specs = [tf_utils.get_tensor_spec(x) for x in flat_kwarg] + if any(s is None for s in flat_specs): + continue + kwargs_spec[key] = tree.pack_sequence_as(kwarg, flat_specs) + + self._saved_model_inputs_spec = inputs_spec + self._saved_model_arg_spec = ( + [inputs_spec] + list(args_spec), + kwargs_spec, + ) + + def _trackable_children(self, save_type="checkpoint", **kwargs): + if save_type == "savedmodel": + # SavedModel needs to ignore the execution functions. + train_function = getattr(self, "train_function", None) + test_function = getattr(self, "test_function", None) + predict_function = getattr(self, "predict_function", None) + self.train_function = None + self.test_function = None + self.predict_function = None + + children = super()._trackable_children(save_type, **kwargs) + + if save_type == "savedmodel": + self.train_function = train_function + self.test_function = test_function + self.predict_function = predict_function + + for tracked_attr in self._tracked: + tracked_item = getattr(self, tracked_attr) + if isinstance(tracked_item, tracking.TrackedList): + children[tracked_attr] = list(tracked_item) + if isinstance(tracked_item, tracking.TrackedDict): + children[tracked_attr] = dict(tracked_item) + if isinstance(tracked_item, tracking.TrackedSet): + children[tracked_attr] = list(tracked_item) + + return children + + @property + def _default_save_signature(self): + """For SavedModel support: returns the default serving signature.""" + + from keras.src.models.functional import Functional + from keras.src.models.model import Model + from keras.src.models.sequential import Sequential + + if not isinstance(self, Model): + return None + + inputs = None + if ( + isinstance(self, Sequential) + and getattr(self, "_functional", None) is not None + ): + inputs = self._functional.input + elif isinstance(self, Functional): + inputs = self.input + + if inputs is not None: + input_signature = ( + tree.map_structure( + lambda x: tf.TensorSpec(x.shape, x.dtype), inputs + ), + ) + else: + input_signature = tuple( + tree.map_shape_structure( + lambda s: tf.TensorSpec(s, self.input_dtype), value + ) + for value in self._build_shapes_dict.values() + ) + + @tf.function(input_signature=input_signature) + def serving_default(inputs): + return self(inputs) + + return serving_default diff --git a/keras/src/backend/tensorflow/linalg.py b/keras/src/backend/tensorflow/linalg.py new file mode 100644 index 000000000000..16053ad5c812 --- /dev/null +++ b/keras/src/backend/tensorflow/linalg.py @@ -0,0 +1,270 @@ +import tensorflow as tf + +from keras.src.backend import config +from keras.src.backend import standardize_dtype +from keras.src.backend.common import dtypes +from keras.src.backend.tensorflow.core import cast +from keras.src.backend.tensorflow.core import convert_to_tensor + + +def cholesky(a, upper=False): + out = tf.linalg.cholesky(a) + # tf.linalg.cholesky simply returns NaNs for non-positive definite matrices + out = tf.debugging.check_numerics(out, "Cholesky") + if upper: + return tf.linalg.adjoint(out) + return out + + +def cholesky_inverse(a, upper=False): + identity = tf.eye(num_rows=tf.shape(a)[-1], dtype=a.dtype) + inv_chol = tf.linalg.triangular_solve(a, identity, lower=not upper) + if upper: + a_inv = tf.matmul(inv_chol, inv_chol, transpose_b=True) + else: + a_inv = tf.matmul(inv_chol, inv_chol, transpose_a=True) + return a_inv + + +def det(a): + return tf.linalg.det(a) + + +def eig(a): + return tf.linalg.eig(a) + + +def eigh(a): + return tf.linalg.eigh(a) + + +def inv(a): + return tf.linalg.inv(a) + + +def lu_factor(a): + lu, p = tf.linalg.lu(a) + return lu, tf.math.invert_permutation(p) + + +def norm(x, ord=None, axis=None, keepdims=False): + from keras.src.backend.tensorflow.numpy import moveaxis + + x = convert_to_tensor(x) + x_shape = x.shape + ndim = x_shape.rank + + if axis is None: + axis = tuple(range(ndim)) + elif isinstance(axis, int): + axis = (axis,) + if any(a < -ndim or a >= ndim for a in axis): + raise ValueError( + "All `axis` values must be in the range [-ndim, ndim). " + f"Received inputs with ndim={ndim}, while axis={axis}" + ) + axis = axis[0] if len(axis) == 1 else axis + num_axes = 1 if isinstance(axis, int) else len(axis) + + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = cast(x, dtype) + + # Ref: jax.numpy.linalg.norm + if num_axes == 1: + if ord is None or ord == 2: + return tf.sqrt( + tf.reduce_sum(x * tf.math.conj(x), axis=axis, keepdims=keepdims) + ) + elif ord == float("inf"): + return tf.math.reduce_max( + tf.math.abs(x), axis=axis, keepdims=keepdims + ) + elif ord == float("-inf"): + return tf.math.reduce_min( + tf.math.abs(x), axis=axis, keepdims=keepdims + ) + elif ord == 0: + return tf.math.reduce_sum( + tf.cast(tf.not_equal(x, 0), dtype=x.dtype), + axis=axis, + keepdims=keepdims, + ) + elif isinstance(ord, str): + raise ValueError( + f"Invalid `ord` argument for vector norm. Received: ord={ord}" + ) + else: + ord = convert_to_tensor(ord, dtype=x.dtype) + out = tf.math.reduce_sum( + tf.pow(tf.math.abs(x), ord), axis=axis, keepdims=keepdims + ) + return tf.pow(out, 1.0 / ord) + elif num_axes == 2: + row_axis, col_axis = axis[0], axis[1] + row_axis = row_axis + ndim if row_axis < 0 else row_axis + col_axis = col_axis + ndim if col_axis < 0 else col_axis + if ord is None or ord == "fro": + return tf.sqrt( + tf.reduce_sum(x * tf.math.conj(x), axis=axis, keepdims=keepdims) + ) + elif ord == 1: + if not keepdims and col_axis > row_axis: + col_axis -= 1 + x = tf.math.reduce_max( + tf.reduce_sum(tf.math.abs(x), axis=row_axis, keepdims=keepdims), + axis=col_axis, + keepdims=keepdims, + ) + elif ord == -1: + if not keepdims and col_axis > row_axis: + col_axis -= 1 + x = tf.math.reduce_min( + tf.reduce_sum(tf.math.abs(x), axis=row_axis, keepdims=keepdims), + axis=col_axis, + keepdims=keepdims, + ) + elif ord == float("inf"): + if not keepdims and row_axis > col_axis: + row_axis -= 1 + x = tf.math.reduce_max( + tf.reduce_sum(tf.math.abs(x), axis=col_axis, keepdims=keepdims), + axis=row_axis, + keepdims=keepdims, + ) + elif ord == float("-inf"): + if not keepdims and row_axis > col_axis: + row_axis -= 1 + x = tf.math.reduce_min( + tf.reduce_sum(tf.math.abs(x), axis=col_axis, keepdims=keepdims), + axis=row_axis, + keepdims=keepdims, + ) + elif ord in ("nuc", 2, -2): + x = moveaxis(x, axis, (-2, -1)) + if ord == -2: + x = tf.math.reduce_min( + tf.linalg.svd(x, compute_uv=False), axis=-1 + ) + elif ord == 2: + x = tf.math.reduce_max( + tf.linalg.svd(x, compute_uv=False), axis=-1 + ) + else: + x = tf.math.reduce_sum( + tf.linalg.svd(x, compute_uv=False), axis=-1 + ) + if keepdims: + x = tf.expand_dims(x, axis[0]) + x = tf.expand_dims(x, axis[1]) + else: + raise ValueError( + f"Invalid `ord` argument for matrix norm. Received: ord={ord}" + ) + return x + else: + raise ValueError(f"Invalid axis values. Received: axis={axis}") + + +def qr(x, mode="reduced"): + if mode not in {"reduced", "complete"}: + raise ValueError( + "`mode` argument value not supported. " + "Expected one of {'reduced', 'complete'}. " + f"Received: mode={mode}" + ) + if mode == "reduced": + return tf.linalg.qr(x) + return tf.linalg.qr(x, full_matrices=True) + + +def solve(a, b): + # tensorflow.linalg.solve only supports same rank inputs + if b.shape.ndims == a.shape.ndims - 1: + b = tf.expand_dims(b, axis=-1) + return tf.squeeze(tf.linalg.solve(a, b), axis=-1) + return tf.linalg.solve(a, b) + + +def solve_triangular(a, b, lower=False): + if b.shape.ndims == a.shape.ndims - 1: + b = tf.expand_dims(b, axis=-1) + return tf.squeeze( + tf.linalg.triangular_solve(a, b, lower=lower), axis=-1 + ) + return tf.linalg.triangular_solve(a, b, lower=lower) + + +def svd(x, full_matrices=True, compute_uv=True): + if compute_uv is False: + return tf.linalg.svd(x, full_matrices=full_matrices, compute_uv=False) + s, u, v = tf.linalg.svd( + x, full_matrices=full_matrices, compute_uv=compute_uv + ) + return u, s, tf.linalg.adjoint(v) + + +def lstsq(a, b, rcond=None): + a = convert_to_tensor(a) + b = convert_to_tensor(b) + if a.shape[0] != b.shape[0]: + raise ValueError("Leading dimensions of input arrays must match") + b_orig_ndim = b.ndim + if b_orig_ndim == 1: + b = b[:, None] + if a.ndim != 2: + raise TypeError( + f"{a.ndim}-dimensional array given. Array must be two-dimensional" + ) + if b.ndim != 2: + raise TypeError( + f"{b.ndim}-dimensional array given. " + "Array must be one or two-dimensional" + ) + m, n = a.shape + dtype = a.dtype + eps = tf.experimental.numpy.finfo(dtype).eps + if a.shape == (): + s = tf.zeros(0, dtype=a.dtype) + x = tf.zeros((n, *b.shape[1:]), dtype=a.dtype) + else: + if rcond is None: + rcond = eps * max(n, m) + else: + rcond = tf.where(rcond < 0, eps, rcond) + u, s, vt = svd(a, full_matrices=False) + mask = s >= tf.convert_to_tensor(rcond, dtype=s.dtype) * s[0] + safe_s = tf.cast(tf.where(mask, s, 1), dtype=a.dtype) + s_inv = tf.where(mask, 1 / safe_s, 0)[:, tf.newaxis] + u_t_b = tf.matmul(tf.transpose(tf.math.conj(u)), b) + x = tf.matmul(tf.transpose(tf.math.conj(vt)), s_inv * u_t_b) + + if b_orig_ndim == 1: + x = tf.reshape(x, [-1]) + return x + + +def jvp(fun, primals, tangents, has_aux=False): + primal_flat = tf.nest.flatten(primals) + tangent_flat = tf.nest.flatten(tangents) + + tangent_flat = [ + tf.cast(t, p.dtype) for t, p in zip(tangent_flat, primal_flat) + ] + + with tf.autodiff.ForwardAccumulator(primal_flat, tangent_flat) as acc: + if has_aux: + primals_out, aux = fun(*primals) + else: + primals_out = fun(*primals) + + primals_out_flat = tf.nest.flatten(primals_out) + tangents_out_flat = [acc.jvp(po) for po in primals_out_flat] + + tangents_out = tf.nest.pack_sequence_as(primals_out, tangents_out_flat) + + if has_aux: + return primals_out, tangents_out, aux + return primals_out, tangents_out diff --git a/keras/src/backend/tensorflow/math.py b/keras/src/backend/tensorflow/math.py new file mode 100644 index 000000000000..e01e40e682db --- /dev/null +++ b/keras/src/backend/tensorflow/math.py @@ -0,0 +1,381 @@ +import tensorflow as tf + +from keras.src.backend import config +from keras.src.backend import standardize_dtype +from keras.src.backend.common import dtypes +from keras.src.backend.tensorflow.core import cast +from keras.src.backend.tensorflow.core import convert_to_tensor + + +def segment_sum(data, segment_ids, num_segments=None, sorted=False): + if sorted: + if num_segments is not None: + raise ValueError( + "Argument `num_segments` cannot be set when sorted is True " + "when using the tensorflow backend." + f"Received: num_segments={num_segments}, sorted={sorted}." + ) + return tf.math.segment_sum(data, segment_ids) + else: + if num_segments is None: + unique_segment_ids, _ = tf.unique(segment_ids) + num_segments = tf.shape(unique_segment_ids)[0] + return tf.math.unsorted_segment_sum(data, segment_ids, num_segments) + + +def segment_max(data, segment_ids, num_segments=None, sorted=False): + if sorted: + if num_segments is not None: + raise ValueError( + "Argument `num_segments` cannot be set when sorted is True " + "when using the tensorflow backend." + f"Received: num_segments={num_segments}, sorted={sorted}." + ) + return tf.math.segment_max(data, segment_ids) + else: + if num_segments is None: + unique_segment_ids, _ = tf.unique(segment_ids) + num_segments = tf.shape(unique_segment_ids)[0] + return tf.math.unsorted_segment_max(data, segment_ids, num_segments) + + +def top_k(x, k, sorted=True): + return tf.math.top_k(x, k, sorted=sorted) + + +def in_top_k(targets, predictions, k): + return tf.math.in_top_k(targets, predictions, k) + + +def logsumexp(x, axis=None, keepdims=False): + return tf.math.reduce_logsumexp(x, axis=axis, keepdims=keepdims) + + +def qr(x, mode="reduced"): + if mode not in {"reduced", "complete"}: + raise ValueError( + "`mode` argument value not supported. " + "Expected one of {'reduced', 'complete'}. " + f"Received: mode={mode}" + ) + if mode == "reduced": + return tf.linalg.qr(x) + return tf.linalg.qr(x, full_matrices=True) + + +def extract_sequences(x, sequence_length, sequence_stride): + return tf.signal.frame( + x, + frame_length=sequence_length, + frame_step=sequence_stride, + axis=-1, + pad_end=False, + ) + + +def _get_complex_tensor_from_tuple(x): + if not isinstance(x, (tuple, list)) or len(x) != 2: + raise ValueError( + "Input `x` should be a tuple of two tensors - real and imaginary." + f"Received: x={x}" + ) + # `convert_to_tensor` does not support passing complex tensors. We separate + # the input out into real and imaginary and convert them separately. + real, imag = x + real = convert_to_tensor(real) + imag = convert_to_tensor(imag) + # Check shapes. + if real.shape != imag.shape: + raise ValueError( + "Input `x` should be a tuple of two tensors - real and imaginary." + "Both the real and imaginary parts should have the same shape. " + f"Received: x[0].shape = {real.shape}, x[1].shape = {imag.shape}" + ) + # Ensure dtype is float. + if not real.dtype.is_floating or not imag.dtype.is_floating: + raise ValueError( + "At least one tensor in input `x` is not of type float." + f"Received: x={x}." + ) + complex_input = tf.dtypes.complex(real, imag) + return complex_input + + +def fft(x): + complex_input = _get_complex_tensor_from_tuple(x) + complex_output = tf.signal.fft(complex_input) + return tf.math.real(complex_output), tf.math.imag(complex_output) + + +def fft2(x): + complex_input = _get_complex_tensor_from_tuple(x) + complex_output = tf.signal.fft2d(complex_input) + return tf.math.real(complex_output), tf.math.imag(complex_output) + + +def ifft2(x): + real, imag = x + h = cast(tf.shape(real)[-2], real.dtype) + w = cast(tf.shape(real)[-1], real.dtype) + real_conj, imag_conj = real, -imag + fft_real, fft_imag = fft2((real_conj, imag_conj)) + return fft_real / (h * w), -fft_imag / (h * w) + + +def rfft(x, fft_length=None): + if fft_length is not None: + fft_length = [fft_length] + complex_output = tf.signal.rfft(x, fft_length=fft_length) + return tf.math.real(complex_output), tf.math.imag(complex_output) + + +def irfft(x, fft_length=None): + complex_input = _get_complex_tensor_from_tuple(x) + if fft_length is not None: + fft_length = [fft_length] + return tf.signal.irfft(complex_input, fft_length) + + +def stft( + x, sequence_length, sequence_stride, fft_length, window="hann", center=True +): + if standardize_dtype(x.dtype) not in {"float32", "float64"}: + raise TypeError( + "Invalid input type. Expected `float32` or `float64`. " + f"Received: input type={x.dtype}" + ) + if fft_length < sequence_length: + raise ValueError( + "`fft_length` must equal or larger than `sequence_length`. " + f"Received: sequence_length={sequence_length}, " + f"fft_length={fft_length}" + ) + if isinstance(window, str): + if window not in {"hann", "hamming"}: + raise ValueError( + "If a string is passed to `window`, it must be one of " + f'`"hann"`, `"hamming"`. Received: window={window}' + ) + x = convert_to_tensor(x) + + if center: + pad_width = [(0, 0) for _ in range(len(x.shape))] + pad_width[-1] = (fft_length // 2, fft_length // 2) + x = tf.pad(x, pad_width, mode="reflect") + + l_pad = (fft_length - sequence_length) // 2 + r_pad = fft_length - sequence_length - l_pad + + if window is not None: + if isinstance(window, str): + if window == "hann": + win_array = tf.signal.hann_window( + sequence_length, periodic=True, dtype=x.dtype + ) + else: + win_array = tf.signal.hamming_window( + sequence_length, periodic=True, dtype=x.dtype + ) + else: + win_array = convert_to_tensor(window, dtype=x.dtype) + if len(win_array.shape) != 1 or win_array.shape[-1] != sequence_length: + raise ValueError( + "The shape of `window` must be equal to [sequence_length]." + f"Received: window shape={win_array.shape}" + ) + win_array = tf.pad(win_array, [[l_pad, r_pad]]) + + def win(frame_step, dtype): + return win_array + + else: + win = None + + result = tf.signal.stft( + x, + frame_length=(sequence_length + l_pad + r_pad), + frame_step=sequence_stride, + fft_length=fft_length, + window_fn=win, + ) + return tf.math.real(result), tf.math.imag(result) + + +def istft( + x, + sequence_length, + sequence_stride, + fft_length, + length=None, + window="hann", + center=True, +): + complex_input = _get_complex_tensor_from_tuple(x) + dtype = tf.math.real(complex_input).dtype + + expected_output_len = fft_length + sequence_stride * ( + tf.shape(complex_input)[-2] - 1 + ) + l_pad = (fft_length - sequence_length) // 2 + r_pad = fft_length - sequence_length - l_pad + + if window is not None: + if isinstance(window, str): + if window == "hann": + win_array = tf.signal.hann_window( + sequence_length, periodic=True, dtype=dtype + ) + else: + win_array = tf.signal.hamming_window( + sequence_length, periodic=True, dtype=dtype + ) + else: + win_array = convert_to_tensor(window, dtype=dtype) + if len(win_array.shape) != 1 or win_array.shape[-1] != sequence_length: + raise ValueError( + "The shape of `window` must be equal to [sequence_length]." + f"Received: window shape={win_array.shape}" + ) + win_array = tf.pad(win_array, [[l_pad, r_pad]]) + win = tf.signal.inverse_stft_window_fn( + sequence_stride, lambda frame_step, dtype: win_array + ) + else: + win = None + + x = tf.signal.inverse_stft( + complex_input, + frame_length=(sequence_length + l_pad + r_pad), + frame_step=sequence_stride, + fft_length=fft_length, + window_fn=win, + ) + + start = 0 if center is False else fft_length // 2 + if length is not None: + end = start + length + elif center is True: + end = -(fft_length // 2) + else: + end = expected_output_len + return x[..., start:end] + + +def rsqrt(x): + return tf.math.rsqrt(x) + + +def erf(x): + return tf.math.erf(x) + + +def erfinv(x): + return tf.math.erfinv(x) + + +def solve(a, b): + a = convert_to_tensor(a) + b = convert_to_tensor(b) + return tf.linalg.solve(a, b) + + +def norm(x, ord=None, axis=None, keepdims=False): + from keras.src.backend.tensorflow.numpy import moveaxis + + x = convert_to_tensor(x) + x_shape = x.shape + ndim = x_shape.rank + + if axis is None: + axis = tuple(range(ndim)) + elif isinstance(axis, int): + axis = (axis,) + + axis = axis[0] if len(axis) == 1 else axis + num_axes = 1 if isinstance(axis, int) else len(axis) + + if num_axes == 1 and ord is None: + ord = "euclidean" + elif num_axes == 2 and ord is None: + ord = "fro" + + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = cast(x, dtype) + + # Fast path to utilize `tf.linalg.norm` + if (num_axes == 1 and ord in ("euclidean", 1, 2, float("inf"))) or ( + num_axes == 2 and ord in ("euclidean", "fro", 1, 2, float("inf")) + ): + return tf.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims) + + # Ref: jax.numpy.linalg.norm + if num_axes == 1 and ord not in ("fro", "nuc"): + if ord == float("-inf"): + return tf.math.reduce_min( + tf.math.abs(x), axis=axis, keepdims=keepdims + ) + elif ord == 0: + return tf.math.reduce_sum( + tf.cast(tf.not_equal(x, 0), dtype=x.dtype), + axis=axis, + keepdims=keepdims, + ) + else: + ord = convert_to_tensor(ord, dtype=x.dtype) + out = tf.math.reduce_sum( + tf.pow(tf.math.abs(x), ord), axis=axis, keepdims=keepdims + ) + return tf.pow(out, 1.0 / ord) + elif num_axes == 2 and ord in ("nuc", float("-inf"), -2, -1): + row_axis, col_axis = axis[0], axis[1] + row_axis = row_axis + ndim if row_axis < 0 else row_axis + col_axis = col_axis + ndim if col_axis < 0 else col_axis + if ord == float("-inf"): + if not keepdims and row_axis > col_axis: + row_axis -= 1 + x = tf.math.reduce_min( + tf.reduce_sum(tf.math.abs(x), axis=col_axis, keepdims=keepdims), + axis=row_axis, + keepdims=keepdims, + ) + elif ord == -1: + if not keepdims and col_axis > row_axis: + col_axis -= 1 + x = tf.math.reduce_min( + tf.reduce_sum(tf.math.abs(x), axis=row_axis, keepdims=keepdims), + axis=col_axis, + keepdims=keepdims, + ) + else: + x = moveaxis(x, axis, (-2, -1)) + if ord == -2: + x = tf.math.reduce_min( + tf.linalg.svd(x, compute_uv=False), axis=-1 + ) + else: + x = tf.math.reduce_sum( + tf.linalg.svd(x, compute_uv=False), axis=-1 + ) + if keepdims: + x = tf.expand_dims(x, axis[0]) + x = tf.expand_dims(x, axis[1]) + return x + + if num_axes == 1: + raise ValueError( + f"Invalid `ord` argument for vector norm. Received: ord={ord}" + ) + elif num_axes == 2: + raise ValueError( + f"Invalid `ord` argument for matrix norm. Received: ord={ord}" + ) + else: + raise ValueError(f"Invalid axis values. Received: axis={axis}") + + +def logdet(x): + x = convert_to_tensor(x) + return tf.linalg.logdet(x) diff --git a/keras/src/backend/tensorflow/name_scope_test.py b/keras/src/backend/tensorflow/name_scope_test.py new file mode 100644 index 000000000000..f9d8eb7b8499 --- /dev/null +++ b/keras/src/backend/tensorflow/name_scope_test.py @@ -0,0 +1,37 @@ +import tensorflow as tf + +from keras.src.backend.tensorflow.core import name_scope +from keras.src.testing import TestCase + + +class TFNameScopeTest(TestCase): + def test_stacking(self): + self.assertEqual(tf.Variable(0, name="x").name, "x:0") + with name_scope("outer") as outer: + self.assertEqual(outer.name, "outer") + self.assertEqual(tf.Variable(0, name="x").name, "outer/x:0") + with name_scope("middle") as middle: + self.assertEqual(middle.name, "middle") + self.assertEqual( + tf.Variable(0, name="x").name, "outer/middle/x:0" + ) + with name_scope("inner") as inner: + self.assertEqual(inner.name, "inner") + self.assertEqual( + tf.Variable(0, name="x").name, "outer/middle/inner/x:0" + ) + self.assertEqual( + tf.Variable(0, name="x").name, "outer/middle/x:0" + ) + self.assertEqual(tf.Variable(0, name="x").name, "outer/x:0") + self.assertEqual(tf.Variable(0, name="x").name, "x:0") + + def test_deduplicate(self): + self.assertEqual(tf.Variable(0, name="x").name, "x:0") + with name_scope("name", caller=1): + with name_scope("name", caller=1): + self.assertEqual(tf.Variable(0, name="x").name, "name/x:0") + self.assertEqual(tf.Variable(0, name="x").name, "x:0") + with name_scope("name"): + with name_scope("name"): + self.assertEqual(tf.Variable(0, name="x").name, "name/name/x:0") diff --git a/keras/src/backend/tensorflow/nn.py b/keras/src/backend/tensorflow/nn.py new file mode 100644 index 000000000000..8ba64b10b78f --- /dev/null +++ b/keras/src/backend/tensorflow/nn.py @@ -0,0 +1,1126 @@ +import math +import warnings + +import tensorflow as tf + +from keras.src import backend +from keras.src.backend.common.backend_utils import ( + compute_conv_transpose_output_shape, +) +from keras.src.backend.tensorflow.core import cast +from keras.src.backend.tensorflow.core import convert_to_tensor + + +def relu(x): + return tf.nn.relu(x) + + +def relu6(x): + return tf.nn.relu6(x) + + +def sigmoid(x): + logits = x + output = tf.nn.sigmoid(x) + output._keras_logits = logits + return output + + +def sparse_sigmoid(x): + x = convert_to_tensor(x) + return tf.where( + x <= -1, + tf.constant(0.0, dtype=x.dtype), + tf.where(x >= 1, tf.constant(1.0, dtype=x.dtype), 0.5 * (x + 1)), + ) + + +def tanh(x): + return tf.nn.tanh(x) + + +def tanh_shrink(x): + return x - tf.math.tanh(x) + + +def softplus(x): + return tf.math.softplus(x) + + +def softsign(x): + return tf.nn.softsign(x) + + +def soft_shrink(x, threshold=0.5): + return tf.where( + x > threshold, + x - threshold, + tf.where(x < -threshold, x + threshold, tf.zeros_like(x)), + ) + + +def sparse_plus(x): + return tf.where( + x <= -1, + tf.zeros_like(x), + tf.where(x < 1, (1 / 4) * tf.pow(x + 1, 2), x), + ) + + +def silu(x): + return tf.nn.silu(x) + + +def squareplus(x, b=4): + x = convert_to_tensor(x) + b = convert_to_tensor(b, dtype=x.dtype) + y = x + tf.sqrt(tf.square(x) + b) + return y / 2 + + +def log_sigmoid(x): + return tf.math.log_sigmoid(x) + + +def leaky_relu(x, negative_slope=0.2): + return tf.nn.leaky_relu(x, alpha=negative_slope) + + +def hard_sigmoid(x): + x = convert_to_tensor(x) + return relu6(x + tf.constant(3.0, x.dtype)) / tf.constant(6.0, x.dtype) + + +def hard_silu(x): + return x * hard_sigmoid(x) + + +def elu(x, alpha=1.0): + res = tf.nn.elu(x) + if alpha == 1: + return res + else: + return tf.where(x > 0, res, alpha * res) + + +def selu(x): + return tf.nn.selu(x) + + +def gelu(x, approximate=True): + x = convert_to_tensor(x) + return tf.nn.gelu(x, approximate=approximate) + + +def celu(x, alpha=1.0): + return tf.maximum(x, 0.0) + alpha * tf.math.expm1( + tf.minimum(x, 0.0) / alpha + ) + + +def glu(x, axis=-1): + if x.shape[axis] % 2 != 0: + raise ValueError( + "axis size must be divisible by 2. " + f"Received: x.shape={x.shape} with axis={axis}" + ) + x1, x2 = tf.split(x, num_or_size_splits=2, axis=axis) + return x1 * tf.sigmoid(x2) + + +def hard_tanh(x): + return tf.clip_by_value(x, clip_value_min=-1.0, clip_value_max=1.0) + + +def hard_shrink(x, threshold=0.5): + return tf.where(tf.abs(x) > threshold, x, tf.zeros_like(x)) + + +def threshold(x, threshold, default_value): + return tf.where(x > threshold, x, default_value) + + +def softmax(x, axis=-1): + logits = x + if axis is None: + # Unlike numpy, tf will handle axis=None as axis=-1. + # We need this workaround for the reduction on every dim. + output = tf.reshape(x, [-1]) + output = tf.nn.softmax(output, axis=-1) + output = tf.reshape(output, tf.shape(x)) + else: + output = tf.nn.softmax(x, axis=axis) + output._keras_logits = logits + return output + + +def log_softmax(x, axis=-1): + if axis is None: + # Unlike numpy, tf will handle axis=None as axis=-1. + # We need this workaround for the reduction on every dim. + output = tf.reshape(x, [-1]) + output = tf.nn.log_softmax(output, axis=-1) + return tf.reshape(output, tf.shape(x)) + return tf.nn.log_softmax(x, axis=axis) + + +def sparsemax(x, axis=-1): + # Sort logits along the specified axis in descending order + logits = convert_to_tensor(x) + logits_sorted = tf.sort(logits, direction="DESCENDING", axis=axis) + logits_cumsum = tf.cumsum(logits_sorted, axis=axis) + r = tf.range(1, tf.shape(logits)[axis] + 1, dtype=logits.dtype) + r_shape = [1] * len(logits.shape) + r_shape[axis] = -1 # Broadcast to match the target axis + r = tf.reshape(r, r_shape) # Reshape for broadcasting + support = logits_sorted - (logits_cumsum - 1) / r > 0 + # Find the threshold + logits_cumsum_safe = tf.where(support, logits_cumsum, 0.0) + k = tf.reduce_sum(tf.cast(support, logits.dtype), axis=axis, keepdims=True) + tau = (tf.reduce_sum(logits_cumsum_safe, axis=axis, keepdims=True) - 1) / k + output = tf.maximum(logits - tau, 0.0) + return output + + +def _transpose_spatial_inputs(inputs): + num_spatial_dims = len(inputs.shape) - 2 + # Tensorflow pooling does not support `channels_first` format, so + # we need to transpose to `channels_last` format. + if num_spatial_dims == 1: + inputs = tf.transpose(inputs, (0, 2, 1)) + elif num_spatial_dims == 2: + inputs = tf.transpose(inputs, (0, 2, 3, 1)) + elif num_spatial_dims == 3: + inputs = tf.transpose(inputs, (0, 2, 3, 4, 1)) + else: + raise ValueError( + "Pooling inputs's shape must be 3, 4 or 5, corresponding to 1D, 2D " + f"and 3D inputs. But received shape: {inputs.shape}." + ) + return inputs + + +def _transpose_spatial_outputs(outputs): + # Undo the transpose in `_transpose_spatial_inputs`. + num_spatial_dims = len(outputs.shape) - 2 + if num_spatial_dims == 1: + outputs = tf.transpose(outputs, (0, 2, 1)) + elif num_spatial_dims == 2: + outputs = tf.transpose(outputs, (0, 3, 1, 2)) + elif num_spatial_dims == 3: + outputs = tf.transpose(outputs, (0, 4, 1, 2, 3)) + return outputs + + +def max_pool( + inputs, + pool_size, + strides=None, + padding="valid", + data_format=None, +): + data_format = backend.standardize_data_format(data_format) + strides = pool_size if strides is None else strides + padding = padding.upper() + tf_data_format = _convert_data_format("channels_last", len(inputs.shape)) + if data_format == "channels_first": + # Tensorflow pooling does not support `channels_first` format, so + # we need to transpose to `channels_last` format. + inputs = _transpose_spatial_inputs(inputs) + + outputs = tf.nn.max_pool( + inputs, + pool_size, + strides, + padding, + tf_data_format, + ) + if data_format == "channels_first": + outputs = _transpose_spatial_outputs(outputs) + return outputs + + +def average_pool( + inputs, + pool_size, + strides=None, + padding="valid", + data_format=None, +): + data_format = backend.standardize_data_format(data_format) + strides = pool_size if strides is None else strides + padding = padding.upper() + tf_data_format = _convert_data_format("channels_last", len(inputs.shape)) + if data_format == "channels_first": + # Tensorflow pooling does not support `channels_first` format, so + # we need to transpose to `channels_last` format. + inputs = _transpose_spatial_inputs(inputs) + + outputs = tf.nn.avg_pool( + inputs, + pool_size, + strides, + padding, + tf_data_format, + ) + if data_format == "channels_first": + outputs = _transpose_spatial_outputs(outputs) + return outputs + + +def _convert_data_format(data_format, ndim): + if data_format == "channels_last": + if ndim == 3: + return "NWC" + elif ndim == 4: + return "NHWC" + elif ndim == 5: + return "NDHWC" + else: + raise ValueError( + f"Input rank not supported: {ndim}. " + "Expected values are [3, 4, 5]" + ) + elif data_format == "channels_first": + if ndim == 3: + return "NCW" + elif ndim == 4: + return "NCHW" + elif ndim == 5: + return "NCDHW" + else: + raise ValueError( + f"Input rank not supported: {ndim}. " + "Expected values are [3, 4, 5]" + ) + else: + raise ValueError( + f"Invalid data_format: {data_format}. " + 'Expected values are ["channels_first", "channels_last"]' + ) + + +def conv( + inputs, + kernel, + strides=1, + padding="valid", + data_format=None, + dilation_rate=1, +): + def _conv(): + tf_data_format = _convert_data_format(data_format, len(inputs.shape)) + return tf.nn.convolution( + inputs, + kernel, + strides, + padding.upper(), + data_format=tf_data_format, + dilations=dilation_rate, + ) + + # Certain ops are are broken in Tensorflow on CPU only. + # We can work around by compiling the op with XLA. + @tf.function(jit_compile=True) + def _conv_xla(): + return _conv() + + # Channels first "NCDHW" (3d convolutions) are broken on CPU without XLA. + needs_xla = data_format == "channels_first" and len(inputs.shape) == 5 + # grouped convolutions are broken on CPU without XLA. + data_format = backend.standardize_data_format(data_format) + if data_format == "channels_last": + channels = inputs.shape[-1] + else: + channels = inputs.shape[1] + needs_xla = needs_xla or channels != kernel.shape[-2] + if needs_xla: + return _conv_xla() + else: + return _conv() + + +def depthwise_conv( + inputs, + kernel, + strides=1, + padding="valid", + data_format=None, + dilation_rate=1, +): + data_format = backend.standardize_data_format(data_format) + num_spatial_dims = len(inputs.shape) - 2 + if num_spatial_dims > 2: + raise ValueError( + "`inputs` rank must be 3 (1D conv) or 4 (2D conv). Received: " + f"{inputs.ndim}." + ) + # Because we use `tf.nn.depthwise_conv2d` for both 1D and 2D convs, we set + # `tf_data_format` using 2D conv format. + tf_data_format = _convert_data_format(data_format, 4) + padding = padding.upper() + if isinstance(strides, int): + strides = (strides,) * num_spatial_dims + if isinstance(dilation_rate, int): + dilation_rate = (dilation_rate,) * num_spatial_dims + if num_spatial_dims == 1: + # 1D depthwise conv. + if data_format == "channels_last": + strides = (1,) + strides * 2 + (1,) + spatial_start_dim = 1 + else: + strides = (1, 1) + strides * 2 + spatial_start_dim = 2 + inputs = tf.expand_dims(inputs, spatial_start_dim) + kernel = tf.expand_dims(kernel, axis=0) + + dilation_rate = None if dilation_rate is None else (1,) + dilation_rate + + outputs = tf.nn.depthwise_conv2d( + inputs, + kernel, + strides, + padding, + data_format=tf_data_format, + dilations=dilation_rate, + ) + return tf.squeeze(outputs, [spatial_start_dim]) + + if data_format == "channels_last": + strides = (1,) + strides + (1,) + spatial_start_dim = 1 + else: + strides = (1, 1) + strides + spatial_start_dim = 2 + return tf.nn.depthwise_conv2d( + inputs, + kernel, + strides, + padding, + data_format=tf_data_format, + dilations=dilation_rate, + ) + + +def separable_conv( + inputs, + depthwise_kernel, + pointwise_kernel, + strides=1, + padding="valid", + data_format=None, + dilation_rate=1, +): + data_format = backend.standardize_data_format(data_format) + num_spatial_dims = len(inputs.shape) - 2 + if num_spatial_dims > 2: + raise ValueError( + "`num_spatial_dims` must be 1 or 2. Received: " + f"num_spatial_dims={num_spatial_dims}." + ) + # Because we use `tf.nn.separable_conv2d` for both 1D and 2D convs, we set + # `tf_data_format` using 2D conv format. + tf_data_format = _convert_data_format(data_format, 4) + padding = padding.upper() + if isinstance(strides, int): + strides = (strides,) * num_spatial_dims + if isinstance(dilation_rate, int): + dilation_rate = (dilation_rate,) * num_spatial_dims + if num_spatial_dims == 1: + # 1D depthwise conv. + if data_format == "channels_last": + strides = (1,) + strides * 2 + (1,) + spatial_start_dim = 1 + else: + strides = (1, 1) + strides * 2 + spatial_start_dim = 2 + inputs = tf.expand_dims(inputs, spatial_start_dim) + depthwise_kernel = tf.expand_dims(depthwise_kernel, axis=0) + pointwise_kernel = tf.expand_dims(pointwise_kernel, axis=0) + dilation_rate = None if dilation_rate is None else (1,) + dilation_rate + + outputs = tf.nn.separable_conv2d( + inputs, + depthwise_kernel, + pointwise_kernel, + strides, + padding, + data_format=tf_data_format, + dilations=dilation_rate, + ) + return tf.squeeze(outputs, [spatial_start_dim]) + + if data_format == "channels_last": + strides = (1,) + strides + (1,) + else: + strides = (1, 1) + strides + return tf.nn.separable_conv2d( + inputs, + depthwise_kernel, + pointwise_kernel, + strides, + padding, + data_format=tf_data_format, + dilations=dilation_rate, + ) + + +def conv_transpose( + inputs, + kernel, + strides=1, + padding="valid", + output_padding=None, + data_format=None, + dilation_rate=1, +): + data_format = backend.standardize_data_format(data_format) + tf_data_format = _convert_data_format(data_format, len(inputs.shape)) + kernel_size = kernel.shape[:-2] + filters = kernel.shape[-2] + input_shape = list(inputs.shape) + symbolic_shape = tf.shape(inputs) + for i, e in enumerate(input_shape): + if e is None: + input_shape[i] = symbolic_shape[i] + output_shape = compute_conv_transpose_output_shape( + input_shape, + kernel_size, + filters, + strides, + padding, + output_padding, + data_format, + dilation_rate, + ) + + return tf.nn.conv_transpose( + inputs, + kernel, + output_shape, + strides, + padding=padding.upper(), + data_format=tf_data_format, + dilations=dilation_rate, + ) + + +def one_hot(x, num_classes, axis=-1, dtype=None, sparse=False): + x = convert_to_tensor(x, dtype="int64") + if dtype is None: + dtype = "float32" + else: + dtype = backend.standardize_dtype(dtype) + if sparse: + # We don't use `tf.sparse.bincount`, it doesn't handle negative indices + # and only support rank 1 and 2 tensors (`one_hot` adds a dimension). + if axis < 0: + axis = axis + len(x.shape) + 1 + values_count = math.prod(x.shape) + values = tf.reshape(x, (values_count,)) + # We deal with negative inputs by having zeros in the output although + # it's useless. It makes shapes static. + values = tf.cast(tf.greater_equal(values, 0), dtype=dtype) + indices = [tf.range(dim) for dim in x.shape] + indices = tf.meshgrid(*indices, indexing="ij") + indices.insert(axis, tf.maximum(x, 0)) # Deal with negative indices + indices = [tf.reshape(a, (values_count, 1)) for a in indices] + indices = [tf.cast(a, tf.int64) for a in indices] + indices = tf.concat(indices, axis=1) + shape = list(x.shape) + shape.insert(axis, num_classes) + return tf.SparseTensor(indices, values, shape) + on_value, off_value = (True, False) if dtype == "bool" else (None, None) + return tf.one_hot( + x, + num_classes, + on_value=on_value, + off_value=off_value, + axis=axis, + dtype=dtype, + ) + + +def multi_hot(x, num_classes, axis=-1, dtype=None, sparse=False): + reduction_axis = 1 if len(x.shape) > 1 else 0 + if backend.standardize_dtype(dtype) == "bool": + if sparse: + # `tf.sparse.reduce_max` doesn't work on bool and there is no + # `tf.sparse.reduce_any`. + outputs = one_hot( + x, num_classes, axis=axis, dtype="int8", sparse=True + ) + outputs = tf.sparse.reduce_max( + outputs, axis=reduction_axis, output_is_sparse=True + ) + outputs_shape = outputs.shape + outputs = tf.cast(outputs, dtype) + outputs.set_shape(outputs_shape) + return outputs + else: + outputs = one_hot(x, num_classes, axis=axis, dtype=dtype) + return tf.reduce_any(outputs, axis=reduction_axis) + else: + if sparse: + # We don't use `tf.sparse.bincount`, it doesn't handle negative + # indices and has a rank limitation. + outputs = one_hot( + x, num_classes, axis=axis, dtype=dtype, sparse=True + ) + return tf.sparse.reduce_max( + outputs, axis=reduction_axis, output_is_sparse=True + ) + else: + outputs = one_hot(x, num_classes, axis=axis, dtype=dtype) + return tf.reduce_max(outputs, axis=reduction_axis) + + +def _get_logits(output, from_logits, op_type, fn_name): + """Retrieves logits tensor from maybe-softmax or maybe-sigmoid tensor.""" + output_ = output + from_logits_ = from_logits + + has_keras_logits = hasattr(output, "_keras_logits") + if has_keras_logits: + output_ = output._keras_logits + from_logits_ = True + + from_expected_op_type = ( + hasattr(output, "op") + and not isinstance(output, (tf.__internal__.EagerTensor, tf.Variable)) + and output.op.type == op_type + ) and not has_keras_logits + + if from_expected_op_type: + # When softmax activation function is used for output operation, we + # use logits from the softmax function directly to compute loss in order + # to prevent collapsing zero when training. + assert len(output.op.inputs) == 1 + output_ = output.op.inputs[0] + from_logits_ = True + + if from_logits and (has_keras_logits or from_expected_op_type): + warnings.warn( + f'"`{fn_name}` received `from_logits=True`, but ' + f"the `output` argument was produced by a {op_type} " + "activation and thus does not represent logits. " + "Was this intended?", + stacklevel=2, + ) + return output_, from_logits_ + + +def categorical_crossentropy(target, output, from_logits=False, axis=-1): + """Categorical crossentropy between an output tensor and a target tensor. + + Args: + target: A tensor of the same shape as `output`. + output: A tensor resulting from a softmax + (unless `from_logits` is `True`, in which + case `output` is expected to be the logits). + from_logits: Boolean, whether `output` is the + result of a softmax, or is a tensor of logits. + axis: Int specifying the channels axis. `axis=-1` corresponds to data + format `channels_last`, and `axis=1` corresponds to data format + `channels_first`. + + Returns: + Output tensor. + + Example: + + >>> a = tf.constant([1., 0., 0., 0., 1., 0., 0., 0., 1.], shape=[3,3]) + >>> print(a) + tf.Tensor( + [[1. 0. 0.] + [0. 1. 0.] + [0. 0. 1.]], shape=(3, 3), dtype=float32) + >>> b = tf.constant([.9, .05, .05, .05, .89, .06, .05, .01, .94], + ... shape=[3, 3]) + >>> print(b) + tf.Tensor( + [[0.9 0.05 0.05] + [0.05 0.89 0.06] + [0.05 0.01 0.94]], shape=(3, 3), dtype=float32) + >>> loss = categorical_crossentropy(a, b) + >>> print(np.around(loss, 5)) + [0.10536 0.11653 0.06188] + >>> loss = categorical_crossentropy(a, a) + >>> print(np.around(loss, 5)) + [0. 0. 0.] + """ + target = tf.convert_to_tensor(target) + output = tf.convert_to_tensor(output) + + if len(target.shape) < 1: + raise ValueError( + "Arguments `target` and `output` must be at least rank 1. " + "Received: " + f"target.shape={target.shape}, output.shape={output.shape}" + ) + if len(target.shape) != len(output.shape): + raise ValueError( + "Arguments `target` and `output` must have the same rank " + "(ndim). Received: " + f"target.shape={target.shape}, output.shape={output.shape}" + ) + for e1, e2 in zip(target.shape, output.shape): + if e1 is not None and e2 is not None and e1 != e2: + raise ValueError( + "Arguments `target` and `output` must have the same shape. " + "Received: " + f"target.shape={target.shape}, output.shape={output.shape}" + ) + + output, from_logits = _get_logits( + output, from_logits, "Softmax", "categorical_crossentropy" + ) + if from_logits: + return tf.nn.softmax_cross_entropy_with_logits( + labels=target, logits=output, axis=axis + ) + + # Adjust the predictions so that the probability of + # each class for every sample adds up to 1 + # This is needed to ensure that the cross entropy is + # computed correctly. + output = output / tf.reduce_sum(output, axis, keepdims=True) + + # Compute cross entropy from probabilities. + output = tf.clip_by_value( + output, backend.epsilon(), 1.0 - backend.epsilon() + ) + return -tf.reduce_sum(target * tf.math.log(output), axis) + + +def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1): + """Categorical crossentropy with integer targets. + + Args: + target: An integer tensor. + output: A tensor resulting from a softmax + (unless `from_logits` is True, in which + case `output` is expected to be the logits). + from_logits: Boolean, whether `output` is the + result of a softmax, or is a tensor of logits. + axis: Int specifying the channels axis. `axis=-1` corresponds to data + format `channels_last`, and `axis=1` corresponds to data format + `channels_first`. + + Returns: + Output tensor. + """ + if axis != -1 and axis != len(output.shape) - 1: + raise ValueError( + f"Only axis=-1 is currently supported. Received: axis={axis}" + ) + output, from_logits = _get_logits( + output, from_logits, "Softmax", "sparse_categorical_crossentropy" + ) + + target = tf.convert_to_tensor(target) + target = tf.cast(target, dtype="int64") + output = tf.convert_to_tensor(output) + if len(target.shape) == len(output.shape) and target.shape[-1] == 1: + target = tf.squeeze(target, axis=-1) + + if len(output.shape) < 1: + raise ValueError( + "Argument `output` must be at least rank 1. " + "Received: " + f"output.shape={output.shape}" + ) + if len(target.shape) != len(output.shape[:-1]): + raise ValueError( + "Argument `output` must have rank (ndim) `target.ndim - 1`. " + "Received: " + f"target.shape={target.shape}, output.shape={output.shape}" + ) + for e1, e2 in zip(target.shape, output.shape[:-1]): + if e1 is not None and e2 is not None and e1 != e2: + raise ValueError( + "Arguments `target` and `output` must have the same shape " + "up until the last dimension: " + f"target.shape={target.shape}, output.shape={output.shape}" + ) + + if not from_logits: + output = tf.clip_by_value( + output, backend.epsilon(), 1 - backend.epsilon() + ) + output = tf.math.log(output) + + result = tf.nn.sparse_softmax_cross_entropy_with_logits( + labels=target, logits=output + ) + return result + + +def binary_crossentropy(target, output, from_logits=False): + """Binary crossentropy between an output tensor and a target tensor. + + Args: + target: A tensor with the same shape as `output`. + output: A tensor. + from_logits: Whether `output` is expected to be a logits tensor. + By default, we consider that `output` + encodes a probability distribution. + + Returns: + A tensor. + """ + target = tf.convert_to_tensor(target) + output = tf.convert_to_tensor(output) + + if len(target.shape) != len(output.shape): + raise ValueError( + "Arguments `target` and `output` must have the same rank " + "(ndim). Received: " + f"target.shape={target.shape}, output.shape={output.shape}" + ) + for e1, e2 in zip(target.shape, output.shape): + if e1 is not None and e2 is not None and e1 != e2: + raise ValueError( + "Arguments `target` and `output` must have the same shape. " + "Received: " + f"target.shape={target.shape}, output.shape={output.shape}" + ) + + output, from_logits = _get_logits( + output, from_logits, "Sigmoid", "binary_crossentropy" + ) + + if from_logits: + return tf.nn.sigmoid_cross_entropy_with_logits( + labels=target, logits=output + ) + + # Compute cross entropy from probabilities. + output = tf.clip_by_value( + output, backend.epsilon(), 1.0 - backend.epsilon() + ) + bce = target * tf.math.log(output) + bce += (1 - target) * tf.math.log(1 - output) + return -bce + + +def moments(x, axes, keepdims=False, synchronized=False): + # The dynamic range of float16 is too limited for statistics. As a + # workaround, we simply perform the operations on float32 and convert back + # to float16 + need_cast = False + ori_dtype = backend.standardize_dtype(x.dtype) + if ori_dtype in ("float16", "bfloat16"): + need_cast = True + x = cast(x, "float32") + + if synchronized: + mean, variance = _compute_moments_sync(x, axes, keepdims) + else: + mean, variance = _compute_moments(x, axes, keepdims) + if need_cast: + # avoid overflow and underflow when casting from float16 to float32 + mean = tf.clip_by_value(mean, tf.float16.min, tf.float16.max) + variance = tf.clip_by_value(variance, tf.float16.min, tf.float16.max) + mean = cast(mean, ori_dtype) + variance = cast(variance, ori_dtype) + return mean, variance + + +def _compute_moments_sync(x, axes, keepdims): + replica_ctx = tf.distribute.get_replica_context() + if not replica_ctx: + return _compute_moments(x, axes, keepdims) + + local_count = tf.ones_like(x, name="count") + + local_sum = tf.reduce_sum(x, axis=axes, keepdims=True) + local_squared_sum = tf.reduce_sum(tf.square(x), axis=axes, keepdims=True) + local_count = tf.reduce_sum(local_count, axis=axes, keepdims=True) + + # TODO(b/163099951): batch the all-reduces once we sort out the + # ordering issue for NCCL. We don't have a mechanism to launch + # NCCL in the same order in each replica nowadays, so we limit + # NCCL to batch all-reduces. + y_sum = replica_ctx.all_reduce(tf.distribute.ReduceOp.SUM, local_sum) + y_squared_sum = replica_ctx.all_reduce( + tf.distribute.ReduceOp.SUM, local_squared_sum + ) + count_sum = replica_ctx.all_reduce(tf.distribute.ReduceOp.SUM, local_count) + + mean = tf.math.divide_no_nan(y_sum, count_sum) + y_squared_mean = tf.math.divide_no_nan(y_squared_sum, count_sum) + # var = E(x^2) - E(x)^2 + variance = tf.maximum(y_squared_mean - tf.square(mean), 0.0) + if not keepdims: + mean = tf.squeeze(mean, axes) + variance = tf.squeeze(variance, axes) + + return mean, variance + + +def _compute_moments(x, axes, keepdims): + return tf.nn.moments(x, axes, keepdims=keepdims) + + +def batch_normalization( + x, mean, variance, axis, offset=None, scale=None, epsilon=1e-3 +): + if axis != -1: + shape = [1] * len(x.shape) + shape[axis] = mean.shape[0] + mean = tf.reshape(mean, shape) + variance = tf.reshape(variance, shape) + if offset is not None: + offset = tf.reshape(offset, shape) + if scale is not None: + scale = tf.reshape(scale, shape) + + return tf.nn.batch_normalization( + x=x, + mean=mean, + variance=variance, + offset=offset, + scale=scale, + variance_epsilon=epsilon, + ) + + +def ctc_loss(target, output, target_length, output_length, mask_index=0): + target = convert_to_tensor(target) + output = convert_to_tensor(output) + target = tf.cast(target, dtype="int32") + + # `tf.nn.ctc_loss` will internally cast to float32 when the input is float16 + # or bfloat16. Additionally, it will raise an error when the input is + # float64. As a result, we perform the casting externally and add support + # for float64. + result_dtype = backend.result_type(output.dtype, "float32") + compute_dtype = "float32" if result_dtype == "float64" else result_dtype + output = tf.cast(output, compute_dtype) + loss = tf.nn.ctc_loss( + labels=target, + logits=output, + label_length=target_length, + logit_length=output_length, + blank_index=mask_index, + logits_time_major=False, + ) + return tf.cast(loss, result_dtype) + + +def ctc_decode( + inputs, + sequence_lengths, + strategy="greedy", + beam_width=100, + top_paths=1, + merge_repeated=True, + mask_index=0, +): + inputs = convert_to_tensor(inputs) + input_shape = tf.shape(inputs) + num_samples, num_steps = input_shape[0], input_shape[1] + inputs = tf.transpose(inputs, (1, 0, 2)) + + dtype = backend.result_type(inputs.dtype, "float32") + inputs = tf.cast(inputs, dtype) + + sequence_lengths = convert_to_tensor(sequence_lengths, dtype="int32") + if strategy == "greedy": + (decoded, scores) = tf.nn.ctc_greedy_decoder( + inputs=inputs, + sequence_length=sequence_lengths, + merge_repeated=merge_repeated, + blank_index=mask_index, + ) + elif strategy == "beam_search": + # Move `mask_index` column to the last position since this is the + # default for `tf.nn.ctc_beam_search_decoder` + if mask_index is not None: + inputs_before = inputs[..., :mask_index] + inputs_mask = inputs[..., mask_index : mask_index + 1] + inputs_after = inputs[..., mask_index + 1 :] + inputs = tf.concat( + [inputs_before, inputs_after, inputs_mask], axis=-1 + ) + (decoded, scores) = tf.nn.ctc_beam_search_decoder( + inputs=inputs, + sequence_length=sequence_lengths, + beam_width=beam_width, + top_paths=top_paths, + ) + else: + raise ValueError( + f"Invalid strategy {strategy}. Supported values are " + "'greedy' and 'beam_search'." + ) + + # Postprocess sparse tensor + decoded_dense = [] + for st in decoded: + st = tf.SparseTensor(st.indices, st.values, (num_samples, num_steps)) + decoded_dense.append(tf.sparse.to_dense(sp_input=st, default_value=-1)) + decoded_dense = tf.stack(decoded_dense, axis=0) + decoded_dense = tf.cast(decoded_dense, "int32") + + # We need to recover the labels because we swapped the indices earlier + if strategy == "beam_search" and mask_index is not None: + if mask_index < 0: + mask_index = mask_index + input_shape[-1] + decoded_dense = tf.where( + decoded_dense >= mask_index, decoded_dense + 1, decoded_dense + ) + return decoded_dense, scores + + +def psnr(x1, x2, max_val): + from keras.src.backend.tensorflow.numpy import log10 + + if x1.shape != x2.shape: + raise ValueError( + f"Input shapes {x1.shape} and {x2.shape} must " + "match for PSNR calculation. " + ) + + max_val = convert_to_tensor(max_val, dtype=x2.dtype) + mse = tf.reduce_mean(tf.square(x1 - x2)) + psnr = 20 * log10(max_val) - 10 * log10(mse) + return psnr + + +def _get_large_negative(dtype): + dtype = backend.standardize_dtype(dtype) + val = 65500.0 if dtype == "float16" else 3.38953e38 + return tf.constant(val * -0.7, dtype=dtype) + + +def _apply_masks(logits, mask, is_causal): + if mask is None and not is_causal: + return logits + + combined_mask = tf.ones_like(logits, dtype="bool") + if mask is not None: + combined_mask = tf.logical_and(combined_mask, mask) + + if is_causal: + logits_shape = tf.shape(logits) + T, S = logits_shape[2], logits_shape[3] + mask = tf.linalg.band_part(tf.ones((T, S), "bool"), -1, 0) + mask = mask[None, None, :, :] + combined_mask = tf.logical_and(combined_mask, mask) + + padded_logits = tf.where( + combined_mask, logits, _get_large_negative(logits.dtype) + ) + return padded_logits + + +def _dot_product_attention_xla(query, key, value, bias, mask, is_causal, scale): + logits_dtype = backend.result_type(query.dtype, "float32") + logits = tf.einsum("BTNH,BSNH->BNTS", query, key, optimize="optimal") + logits = tf.cast(logits, logits_dtype) + logits = tf.multiply(logits, tf.cast(scale, logits.dtype)) + + if bias is not None: + logits = tf.add(logits, tf.cast(bias, logits.dtype)) + + padded_logits = _apply_masks(logits, mask, is_causal) + + # Softmax is always carried out in high precision. + probs_dtype = backend.result_type(padded_logits.dtype, "float32") + probs = tf.cast( + tf.nn.softmax(tf.cast(padded_logits, probs_dtype), axis=-1), key.dtype + ) + return tf.einsum("BNTS,BSNH->BTNH", probs, value, optimize="optimal") + + +def dot_product_attention( + query, + key, + value, + bias=None, + mask=None, + scale=None, + is_causal=False, + flash_attention=None, + attn_logits_soft_cap=None, +): + if flash_attention is None: + flash_attention = False + if flash_attention: + raise ValueError( + "Flash attention is not supported in tensorflow backend." + ) + + # Ref: jax.nn.dot_product_attention + # https://github.com/jax-ml/jax/blob/jax-v0.4.32/jax/_src/nn/functions.py#L828 + # Not support `query_seq_lengths` and `key_value_seq_lengths` args + query = convert_to_tensor(query) + key = convert_to_tensor(key) + value = convert_to_tensor(value) + if len(query.shape) != 4: + raise ValueError( + "`dot_product_attention` only supports 4D inputs. " + f"Received: query.shape={query.shape}, key.shape={key.shape}, " + f"value.shape={value.shape}." + ) + compute_dtype = backend.result_type(query.dtype, key.dtype, value.dtype) + query = cast(query, compute_dtype) + key = cast(key, compute_dtype) + value = cast(value, compute_dtype) + if bias is not None: + bias = convert_to_tensor(bias, dtype=compute_dtype) + + H = tf.shape(key)[-1] + scale = (1.0 / tf.sqrt(tf.cast(H, "float32"))) if scale is None else scale + return _dot_product_attention_xla( + query, key, value, bias, mask, is_causal, scale + ) + + +def unfold(input, kernel_size, dilation=1, padding=0, stride=1): + """Tensorflow implementation of Unfold. + Extract sliding local blocks from a **NCHW** batched image tensor. + + Args: + input: 4-D tensor, shape (N, C, H, W) **required**. + kernel_size: int or (kH, kW) + dilation: int or (dH, dW), default 1 + padding: int or (pH, pW), default 0 + stride: int or (sH, sW), default 1 + + Returns: + 3-D tensor, shape (N, C*kH*kW, L) + """ + k = ( + (kernel_size, kernel_size) + if isinstance(kernel_size, int) + else kernel_size + ) + d = (dilation, dilation) if isinstance(dilation, int) else dilation + p = (padding, padding) if isinstance(padding, int) else padding + s = (stride, stride) if isinstance(stride, int) else stride + N, C, H, W = input.shape + + # ---- padding ---- + if any(_ > 0 for _ in p): + input = tf.pad(input, [[0, 0], [0, 0], [p[0], p[0]], [p[1], p[1]]]) + x = tf.transpose(input, [0, 2, 3, 1]) # (N, H, W, C) + patches = tf.image.extract_patches( + images=x, + sizes=[1, k[0], k[1], 1], + strides=[1, s[0], s[1], 1], + rates=[1, d[0], d[1], 1], + padding="VALID", + ) # (N, nH, nW, kH*kW*C) + + N, nH, nW, D = patches.shape + patches = tf.reshape( + patches, [N, nH, nW, k[0], k[1], C] + ) # (N, nH, nW, kH, kW, C) + patches = tf.transpose( + patches, [0, 5, 3, 4, 1, 2] + ) # (N, C, kH, kW, nH, nW) + patches = tf.reshape(patches, [N, C * k[0] * k[1], nH * nW]) + return patches diff --git a/keras/src/backend/tensorflow/numpy.py b/keras/src/backend/tensorflow/numpy.py new file mode 100644 index 000000000000..ff146a41253b --- /dev/null +++ b/keras/src/backend/tensorflow/numpy.py @@ -0,0 +1,3174 @@ +import builtins +import collections +import functools +import math +import string +import warnings + +import numpy as np +import tensorflow as tf +from tensorflow.python.ops.linalg.sparse import sparse_csr_matrix_ops +from tensorflow.python.ops.math_ops import is_nan + +from keras.src import tree +from keras.src.backend import config +from keras.src.backend import standardize_dtype +from keras.src.backend.common import dtypes +from keras.src.backend.common.backend_utils import canonicalize_axis +from keras.src.backend.common.backend_utils import to_tuple_or_list +from keras.src.backend.common.backend_utils import vectorize_impl +from keras.src.backend.tensorflow import sparse +from keras.src.backend.tensorflow.core import cast +from keras.src.backend.tensorflow.core import convert_to_tensor +from keras.src.backend.tensorflow.core import shape as shape_op + + +def rot90(array, k=1, axes=(0, 1)): + """Rotate an array by 90 degrees in the specified plane. + + Args: + array: Input tensor + k: Number of 90-degree rotations (default=1) + axes: Tuple of two axes that define the plane of rotation. + Defaults to (0, 1). + + Returns: + Rotated tensor with correct shape transformation + """ + array = convert_to_tensor(array) + + if array.shape.rank < 2: + raise ValueError( + f"Input array must have at least 2 dimensions. " + f"Received: array.ndim={array.shape.rank}" + ) + + if len(axes) != 2 or axes[0] == axes[1]: + raise ValueError( + f"Invalid axes: {axes}. Axes must be a tuple of " + "two different dimensions." + ) + + k = k % 4 + if k == 0: + return array + + axes = tuple( + axis if axis >= 0 else array.shape.rank + axis for axis in axes + ) + + perm = [i for i in range(array.shape.rank) if i not in axes] + perm.extend(axes) + array = tf.transpose(array, perm) + + shape = tf.shape(array) + non_rot_shape = shape[:-2] + h, w = shape[-2], shape[-1] + + array = tf.reshape(array, tf.concat([[-1], [h, w]], axis=0)) + + array = tf.reverse(array, axis=[2]) + array = tf.transpose(array, [0, 2, 1]) + + if k % 2 == 1: + final_h, final_w = w, h + else: + final_h, final_w = h, w + + if k > 1: + array = tf.reshape(array, tf.concat([[-1], [final_h, final_w]], axis=0)) + for _ in range(k - 1): + array = tf.reverse(array, axis=[2]) + array = tf.transpose(array, [0, 2, 1]) + + final_shape = tf.concat([non_rot_shape, [final_h, final_w]], axis=0) + array = tf.reshape(array, final_shape) + + inv_perm = [0] * len(perm) + for i, p in enumerate(perm): + inv_perm[p] = i + array = tf.transpose(array, inv_perm) + + return array + + +@sparse.elementwise_binary_union(tf.sparse.add) +def add(x1, x2): + if not isinstance(x1, (int, float)): + x1 = convert_to_tensor(x1) + if not isinstance(x2, (int, float)): + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + ) + x1 = convert_to_tensor(x1, dtype) + x2 = convert_to_tensor(x2, dtype) + + # Special case of `tf.add`: `tf.nn.bias_add` + # `BiasAdd` can be fused with `MatMul` and `Conv*` kernels + # Expecting `x1` to be `inputs` and `x2` to be `bias` (no swapping) + x2_squeeze_shape = [d for d in x2.shape.as_list() if d is None or d > 1] + if ( + # `x2` looks like bias (can be squeezed to vector) + 1 == len(x2_squeeze_shape) + # `x1` looks like input tensor (rank >= 2) + and len(x1.shape) > 1 + # `x2` non-squeezable dimension defined + and x2_squeeze_shape[0] is not None + # `x2` non-squeezable dimension match `x1` channel dimension + and x2_squeeze_shape[0] + in {x1.shape.as_list()[1], x1.shape.as_list()[-1]} + ): + if x1.shape[-1] == x2_squeeze_shape[0]: + data_format = "NHWC" + else: + data_format = "NCHW" + if len(x2.shape) > 1: + x2 = tf.squeeze(x2) + return tf.nn.bias_add(x1, x2, data_format=data_format) + + return tf.add(x1, x2) + + +def bartlett(x): + x = convert_to_tensor(x, dtype=config.floatx()) + if x == 0: + return tf.constant([]) + if x == 1: + return tf.ones([1]) + + n = tf.range(x) + half = (x - 1) / 2 + + window = tf.where(n <= half, 2.0 * n / (x - 1), 2.0 - 2.0 * n / (x - 1)) + + return window + + +def hamming(x): + x = convert_to_tensor(x, dtype=tf.int32) + return tf.signal.hamming_window(x, periodic=False) + + +def hanning(x): + x = convert_to_tensor(x, dtype=tf.int32) + return tf.signal.hann_window(x, periodic=False) + + +def heaviside(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + + dtype = dtypes.result_type(x1.dtype, x2.dtype) + if dtype in ["int8", "int16", "int32", "uint8", "uint16", "uint32"]: + dtype = config.floatx() + elif dtype in ["int64"]: + dtype = "float64" + + x1 = tf.cast(x1, dtype) + x2 = tf.cast(x2, dtype) + + return tf.where( + x1 < 0, + tf.zeros_like(x1), + tf.where(x1 > 0, tf.ones_like(x1), x2), + ) + + +def kaiser(x, beta): + x = convert_to_tensor(x, dtype=tf.int32) + return tf.signal.kaiser_window(x, beta=beta) + + +def bincount(x, weights=None, minlength=0, sparse=False): + x = convert_to_tensor(x) + dtypes_to_resolve = [x.dtype] + if standardize_dtype(x.dtype) not in ["int32", "int64"]: + x = tf.cast(x, tf.int32) + if weights is not None: + weights = convert_to_tensor(weights) + dtypes_to_resolve.append(weights.dtype) + dtype = dtypes.result_type(*dtypes_to_resolve) + if standardize_dtype(weights.dtype) not in [ + "int32", + "int64", + "float32", + "float64", + ]: + if "int" in standardize_dtype(weights.dtype): + weights = tf.cast(weights, tf.int32) + else: + weights = tf.cast(weights, tf.float32) + else: + dtype = "int32" + if sparse or isinstance(x, tf.SparseTensor): + output = tf.sparse.bincount( + x, + weights=weights, + minlength=minlength, + axis=-1, + ) + actual_length = output.shape[-1] + if actual_length is None: + actual_length = tf.shape(output)[-1] + output = cast(output, dtype) + if x.shape.rank == 1: + output_shape = (actual_length,) + else: + batch_size = output.shape[0] + if batch_size is None: + batch_size = tf.shape(output)[0] + output_shape = (batch_size, actual_length) + return tf.SparseTensor( + indices=output.indices, + values=output.values, + dense_shape=output_shape, + ) + return tf.cast( + tf.math.bincount(x, weights=weights, minlength=minlength, axis=-1), + dtype, + ) + + +@functools.lru_cache(512) +def _normalize_einsum_subscripts(subscripts): + # string.ascii_letters + mapping = {} + normalized_subscripts = "" + for c in subscripts: + if c in string.ascii_letters: + if c not in mapping: + mapping[c] = string.ascii_letters[len(mapping)] + normalized_subscripts += mapping[c] + else: + normalized_subscripts += c + return normalized_subscripts + + +def einsum(subscripts, *operands, **kwargs): + operands = tree.map_structure(convert_to_tensor, operands) + subscripts = _normalize_einsum_subscripts(subscripts) + + def is_valid_for_custom_ops(subscripts, *operands): + # Check that `subscripts` is supported and the shape of operands is not + # `None`. + if subscripts in [ + "a,b->ab", + "ab,b->a", + "ab,bc->ac", + "ab,cb->ac", + "abc,cd->abd", + "abc,dc->abd", + "abcd,abde->abce", + "abcd,abed->abce", + "abcd,acbe->adbe", + "abcd,adbe->acbe", + "abcd,aecd->acbe", + "abcd,aecd->aceb", + ]: + # These subscripts don't require the shape information + return True + elif subscripts == "abc,cde->abde": + _, b1, c1 = operands[0].shape + c2, d2, e2 = operands[1].shape + b, c, d, e = b1, c1 or c2, d2, e2 + if None in (b, c, d, e): + return False + return True + elif subscripts == "abc,dce->abde": + _, b1, c1 = operands[0].shape + d2, c2, e2 = operands[1].shape + b, c, d, e = b1, c1 or c2, d2, e2 + if None in (b, c, d, e): + return False + return True + elif subscripts == "abc,dec->abde": + _, b1, c1 = operands[0].shape + d2, e2, c2 = operands[1].shape + b, c, d, e = b1, c1 or c2, d2, e2 + if None in (b, c, d, e): + return False + return True + elif subscripts == "abcd,cde->abe": + _, b1, c1, d1 = operands[0].shape + c2, d2, e2 = operands[1].shape + b, c, d, e = b1, c1 or c2, d1 or d2, e2 + if None in (b, c, d, e): + return False + return True + elif subscripts == "abcd,ced->abe": + _, b1, c1, d1 = operands[0].shape + c2, e2, d2 = operands[1].shape + b, c, d, e = b1, c1 or c2, d1 or d2, e2 + if None in (b, c, d, e): + return False + return True + elif subscripts == "abcd,ecd->abe": + _, b1, c1, d1 = operands[0].shape + e2, c2, d2 = operands[1].shape + b, c, d, e = b1, c1 or c2, d1 or d2, e2 + if None in (b, c, d, e): + return False + return True + elif subscripts == "abcde,aebf->adbcf": + _, b1, c1, d1, e1 = operands[0].shape + _, e2, b2, f2 = operands[1].shape + b, c, d, e, f = b1 or b2, c1, d1, e1 or e2, f2 + if None in (b, c, d, e, f): + return False + return True + elif subscripts == "abcde,afce->acdbf": + _, b1, c1, d1, e1 = operands[0].shape + _, f2, c2, e2 = operands[1].shape + b, c, d, e, f = b1, c1 or c2, d1, e1 or e2, f2 + if None in (b, c, d, e, f): + return False + return True + else: + # No match in subscripts + return False + + def use_custom_ops(subscripts, *operands, output_type): + # Replace tf.einsum with custom ops to utilize hardware-accelerated + # matmul + x, y = operands[0], operands[1] + if subscripts == "a,b->ab": + x = tf.expand_dims(x, axis=-1) + y = tf.expand_dims(y, axis=0) + return tf.matmul(x, y, output_type=output_type) + elif subscripts == "ab,b->a": + y = tf.expand_dims(y, axis=-1) + result = tf.matmul(x, y, output_type=output_type) + return tf.squeeze(result, axis=-1) + elif subscripts == "ab,bc->ac": + return tf.matmul(x, y, output_type=output_type) + elif subscripts == "ab,cb->ac": + y = tf.transpose(y, [1, 0]) + return tf.matmul(x, y, output_type=output_type) + elif subscripts == "abc,cd->abd": + return tf.matmul(x, y, output_type=output_type) + elif subscripts == "abc,cde->abde": + _, b1, c1 = x.shape + c2, d2, e2 = y.shape + b, c, d, e = b1, c1 or c2, d2, e2 + y = tf.reshape(y, [c, -1]) + result = tf.matmul(x, y, output_type=output_type) + return tf.reshape(result, [-1, b, d, e]) + elif subscripts == "abc,dc->abd": + y = tf.transpose(y, [1, 0]) + return tf.matmul(x, y, output_type=output_type) + elif subscripts == "abc,dce->abde": + _, b1, c1 = x.shape + d2, c2, e2 = y.shape + b, c, d, e = b1, c1 or c2, d2, e2 + y = tf.transpose(y, [1, 0, 2]) # cde + y = tf.reshape(y, [c, -1]) + result = tf.matmul(x, y, output_type=output_type) + return tf.reshape(result, [-1, b, d, e]) + elif subscripts == "abc,dec->abde": + _, b1, c1 = x.shape + d2, e2, c2 = y.shape + b, c, d, e = b1, c1 or c2, d2, e2 + y = tf.transpose(y, [2, 0, 1]) # cde + y = tf.reshape(y, [c, -1]) + result = tf.matmul(x, y, output_type=output_type) + return tf.reshape(result, [-1, b, d, e]) + elif subscripts == "abcd,abde->abce": + return tf.matmul(x, y, output_type=output_type) + elif subscripts == "abcd,abed->abce": + y = tf.transpose(y, [0, 1, 3, 2]) + return tf.matmul(x, y, output_type=output_type) + elif subscripts == "abcd,acbe->adbe": + x = tf.transpose(x, [0, 1, 3, 2]) + y = tf.transpose(y, [0, 2, 1, 3]) + result = tf.matmul(x, y, output_type=output_type) + return tf.transpose(result, [0, 2, 1, 3]) + elif subscripts == "abcd,adbe->acbe": + y = tf.transpose(y, [0, 2, 1, 3]) # abde + result = tf.matmul(x, y, output_type=output_type) # abce + return tf.transpose(result, [0, 2, 1, 3]) + elif subscripts == "abcd,aecd->acbe": + x = tf.transpose(x, [0, 2, 1, 3]) # acbd + y = tf.transpose(y, [0, 2, 3, 1]) # acde + return tf.matmul(x, y, output_type=output_type) + elif subscripts == "abcd,aecd->aceb": + x = tf.transpose(x, [0, 2, 1, 3]) + y = tf.transpose(y, [0, 2, 3, 1]) + result = tf.matmul(x, y, output_type=output_type) # acbe + return tf.transpose(result, [0, 1, 3, 2]) + elif subscripts == "abcd,cde->abe": + _, b1, c1, d1 = x.shape + c2, d2, e2 = y.shape + b, c, d, e = b1, c1 or c2, d1 or d2, e2 + x = tf.reshape(x, [-1, b, c * d]) + y = tf.reshape(y, [-1, e]) + return tf.matmul(x, y, output_type=output_type) + elif subscripts == "abcd,ced->abe": + _, b1, c1, d1 = x.shape + c2, e2, d2 = y.shape + b, c, d, e = b1, c1 or c2, d1 or d2, e2 + x = tf.reshape(x, [-1, b, c * d]) + y = tf.transpose(y, [0, 2, 1]) + y = tf.reshape(y, [-1, e]) + return tf.matmul(x, y, output_type=output_type) + elif subscripts == "abcd,ecd->abe": + _, b1, c1, d1 = x.shape + e2, c2, d2 = y.shape + b, c, d, e = b1, c1 or c2, d1 or d2, e2 + x = tf.reshape(x, [-1, b, c * d]) + y = tf.transpose(y, [1, 2, 0]) + y = tf.reshape(y, [-1, e]) + return tf.matmul(x, y, output_type=output_type) + elif subscripts == "abcde,aebf->adbcf": + _, b1, c1, d1, e1 = x.shape + _, e2, b2, f2 = y.shape + b, c, d, e, f = b1 or b2, c1, d1, e1 or e2, f2 + x = tf.reshape(x, [-1, b, c * d, e]) # ab(cd)e + y = tf.transpose(y, [0, 2, 1, 3]) # abef + result = tf.matmul(x, y, output_type=output_type) # ab(cd)f + result = tf.reshape(result, [-1, b, c, d, f]) # abcdf + return tf.transpose(result, [0, 3, 1, 2, 4]) + elif subscripts == "abcde,afce->acdbf": + _, b1, c1, d1, e1 = x.shape + _, f2, c2, e2 = y.shape + b, c, d, e, f = b1, c1 or c2, d1, e1 or e2, f2 + x = tf.transpose(x, [0, 2, 3, 1, 4]) # acdbe + x = tf.reshape(x, [-1, c, d * b, e]) # ac(db)e + y = tf.transpose(y, [0, 2, 3, 1]) # acef + result = tf.matmul(x, y, output_type=output_type) # ac(db)f + return tf.reshape(result, [-1, c, d, b, f]) + else: + raise NotImplementedError + + dtypes_to_resolve = list(set(standardize_dtype(x.dtype) for x in operands)) + # When operands are of int8, we cast the result to int32 to align with + # the behavior of jax. + if len(dtypes_to_resolve) == 1 and dtypes_to_resolve[0] == "int8": + compute_dtype = "int8" + result_dtype = "int32" + output_type = "int32" + else: + result_dtype = dtypes.result_type(*dtypes_to_resolve) + compute_dtype = result_dtype + output_type = None + + # TODO: Remove the condition once `tf.einsum` supports int8xint8->int32 + if is_valid_for_custom_ops(subscripts, *operands) and not kwargs: + # TODO: tf.matmul doesn't support integer dtype if not specifying + # output_type="int32" + if "int" in compute_dtype and output_type is None: + compute_dtype = config.floatx() + operands = tree.map_structure( + lambda x: tf.cast(x, compute_dtype), operands + ) + result = use_custom_ops(subscripts, *operands, output_type=output_type) + else: + # TODO: tf.einsum doesn't support integer dtype with gpu + if "int" in compute_dtype: + compute_dtype = config.floatx() + operands = tree.map_structure( + lambda x: tf.cast(x, compute_dtype), operands + ) + result = tf.einsum(subscripts, *operands, **kwargs) + return tf.cast(result, result_dtype) + + +@sparse.elementwise_binary_union(sparse.sparse_subtract) +def subtract(x1, x2): + if not isinstance(x1, (int, float)): + x1 = convert_to_tensor(x1) + if not isinstance(x2, (int, float)): + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + ) + x1 = convert_to_tensor(x1, dtype) + x2 = convert_to_tensor(x2, dtype) + return tf.subtract(x1, x2) + + +def matmul(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + x1_shape = x1.shape + x2_shape = x2.shape + x1_sparse = isinstance(x1, tf.SparseTensor) + x2_sparse = isinstance(x2, tf.SparseTensor) + # When both x1 and x2 are of int8 and dense tensor, specifying `output_type` + # as int32 to enable hardware-accelerated matmul + x1_dtype = standardize_dtype(x1.dtype) + x2_dtype = standardize_dtype(x2.dtype) + if ( + x1_dtype == "int8" + and x2_dtype == "int8" + and not x1_sparse + and not x2_sparse + and x1_shape.rank != 1 # TODO: support tf.tensordot + and x2_shape.rank != 1 # TODO: support tf.tensordot + ): + compute_dtype = "int8" + result_dtype = "int32" + output_type = result_dtype + else: + # TODO: Typically, GPU and XLA only support float types + compute_dtype = dtypes.result_type(x1.dtype, x2.dtype, float) + result_dtype = dtypes.result_type(x1.dtype, x2.dtype) + output_type = None + x1 = tf.cast(x1, compute_dtype) + x2 = tf.cast(x2, compute_dtype) + + def with_combined_batch_dimensions(a, b, output_shape, fn_3d): + a_sparse = isinstance(a, tf.SparseTensor) + b_sparse = isinstance(b, tf.SparseTensor) + batch_shape = b.shape[:-2] if b_sparse else a.shape[:-2] + batch_size = math.prod(batch_shape) + a3d_shape = [batch_size] + a.shape[-2:] + a_3d = ( + tf.sparse.reshape(a, a3d_shape) + if a_sparse + else tf.reshape(a, a3d_shape) + ) + b3d_shape = [batch_size] + b.shape[-2:] + b_3d = ( + tf.sparse.reshape(b, b3d_shape) + if b_sparse + else tf.reshape(b, b3d_shape) + ) + result_3d = fn_3d(a_3d, b_3d) + return ( + tf.sparse.reshape(result_3d, output_shape) + if isinstance(result_3d, tf.SparseTensor) + else tf.reshape(result_3d, output_shape) + ) + + def sparse_sparse_matmul(a, b): + dtype = a.values.dtype + # Convert SparseTensors to CSR SparseMatrix. + a_csr = sparse_csr_matrix_ops.sparse_tensor_to_csr_sparse_matrix( + a.indices, a.values, a.dense_shape + ) + b_csr = sparse_csr_matrix_ops.sparse_tensor_to_csr_sparse_matrix( + b.indices, b.values, b.dense_shape + ) + # Compute the CSR SparseMatrix matrix multiplication. + result_csr = sparse_csr_matrix_ops.sparse_matrix_sparse_mat_mul( + a_csr, b_csr, dtype + ) + # Convert the CSR SparseMatrix to a SparseTensor. + res = sparse_csr_matrix_ops.csr_sparse_matrix_to_sparse_tensor( + result_csr, dtype + ) + return tf.SparseTensor(res.indices, res.values, res.dense_shape) + + def embedding_lookup_sparse_dense_matmul(a, b): + # We need at least one id per rows for embedding_lookup_sparse, + # otherwise there will be missing rows in the output. + a, _ = tf.sparse.fill_empty_rows(a, 0) + # We need to split x1 into separate ids and weights tensors. The ids + # should be the column indices of x1 and the values of the weights + # can continue to be the actual x1. The column arrangement of ids + # and weights does not matter as we sum over columns. See details in + # the documentation for sparse_ops.sparse_tensor_dense_matmul. + ids = tf.SparseTensor( + indices=a.indices, + values=a.indices[:, 1], + dense_shape=a.dense_shape, + ) + return tf.nn.embedding_lookup_sparse(b, ids, a, combiner="sum") + + # Either a or b is sparse + def sparse_dense_matmul_3d(a, b): + return tf.map_fn( + lambda x: tf.sparse.sparse_dense_matmul(x[0], x[1]), + elems=(a, b), + fn_output_signature=a.dtype, + ) + + if x1_sparse or x2_sparse: + from keras.src.ops.operation_utils import compute_matmul_output_shape + + output_shape = compute_matmul_output_shape(x1_shape, x2_shape) + if x1_sparse and x2_sparse: + if x1_shape.rank <= 3: + output = sparse_sparse_matmul(x1, x2) + else: + output = with_combined_batch_dimensions( + x1, x2, output_shape, sparse_sparse_matmul + ) + else: + # Sparse * dense or dense * sparse + sparse_rank = x1_shape.rank if x1_sparse else x2_shape.rank + + # Special case: embedding_lookup_sparse for sparse * dense, rank 2 + if x1_sparse and sparse_rank == 2: + output = embedding_lookup_sparse_dense_matmul(x1, x2) + elif sparse_rank == 2: + output = tf.sparse.sparse_dense_matmul(x1, x2) + elif sparse_rank == 3: + output = sparse_dense_matmul_3d(x1, x2) + else: + output = with_combined_batch_dimensions( + x1, x2, output_shape, sparse_dense_matmul_3d + ) + output = tf.cast(output, result_dtype) + output.set_shape(output_shape) + return output + else: + if x1_shape.rank == 2 and x2_shape.rank == 2: + output = tf.matmul(x1, x2, output_type=output_type) + elif x2_shape.rank == 1: + output = tf.tensordot(x1, x2, axes=1) + elif x1_shape.rank == 1: + output = tf.tensordot(x1, x2, axes=[[0], [-2]]) + else: + output = tf.matmul(x1, x2, output_type=output_type) + return tf.cast(output, result_dtype) + + +@sparse.elementwise_binary_intersection +def multiply(x1, x2): + if not isinstance(x1, (int, float)): + x1 = convert_to_tensor(x1) + if not isinstance(x2, (int, float)): + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + ) + x1 = convert_to_tensor(x1, dtype) + x2 = convert_to_tensor(x2, dtype) + return tf.multiply(x1, x2) + + +def mean(x, axis=None, keepdims=False): + if isinstance(x, tf.IndexedSlices): + if axis is None: + # Reduce against all axes, result is a single value and dense. + # The denominator has to account for `dense_shape`. + sum = tf.reduce_sum(x.values, keepdims=keepdims) + return sum / tf.cast(tf.reduce_prod(x.dense_shape), dtype=sum.dtype) + + axis = to_tuple_or_list(axis) + if not axis: + # Empty axis tuple, this is a no-op + return x + + dense_shape = tf.convert_to_tensor(x.dense_shape) + rank = tf.shape(dense_shape)[0] + # Normalize axis: convert negative values and sort + axis = [canonicalize_axis(a, rank) for a in axis] + axis.sort() + + if axis == [0]: + # Reduce against `axis=0` only, result is dense. + # The denominator has to account for `dense_shape[0]`. + sum = tf.reduce_sum(x.values, axis=0, keepdims=keepdims) + return sum / tf.cast(dense_shape[0], dtype=sum.dtype) + elif axis[0] == 0: + # Reduce against axis 0 and other axes, result is dense. + # We do `axis=0` separately first. The denominator has to account + # for `dense_shape[0]`. + # We use `keepdims=True` in `reduce_sum`` so that we can leave the + # 0 in axis and do `reduce_mean` with `keepdims` to apply it for all + # axes. + sum = tf.reduce_sum(x.values, axis=0, keepdims=True) + axis_0_mean = sum / tf.cast(dense_shape[0], dtype=sum.dtype) + return tf.reduce_mean(axis_0_mean, axis=axis, keepdims=keepdims) + elif keepdims: + # With `keepdims=True`, result is an `IndexedSlices` with the same + # indices since axis 0 is not touched. The only thing to do is to + # correct `dense_shape` to account for dimensions that became 1. + new_values = tf.reduce_mean(x.values, axis=axis, keepdims=True) + new_dense_shape = tf.concat( + [dense_shape[0:1], new_values.shape[1:]], axis=0 + ) + return tf.IndexedSlices(new_values, x.indices, new_dense_shape) + elif rank == len(axis) + 1: + # `keepdims=False` and reducing against all axes except 0, result is + # a 1D tensor, which cannot be `IndexedSlices`. We have to scatter + # the computed means to construct the correct dense tensor. + return tf.scatter_nd( + tf.expand_dims(x.indices, axis=1), + tf.reduce_mean(x.values, axis=axis), + [dense_shape[0]], + ) + else: + # `keepdims=False`, not reducing against axis 0 and there is at + # least one other axis we are not reducing against. We simply need + # to fix `dense_shape` to remove dimensions that were reduced. + gather_indices = [i for i in range(rank) if i not in axis] + return tf.IndexedSlices( + tf.reduce_mean(x.values, axis=axis), + x.indices, + tf.gather(x.dense_shape, gather_indices, axis=0), + ) + x = convert_to_tensor(x) + ori_dtype = standardize_dtype(x.dtype) + compute_dtype = dtypes.result_type(x.dtype, "float32") + # `tf.reduce_mean` does not handle low precision (e.g., float16) overflow + # correctly, so we compute with float32 and cast back to the original type. + if "int" in ori_dtype or ori_dtype == "bool": + result_dtype = compute_dtype + else: + result_dtype = ori_dtype + output = tf.reduce_mean( + tf.cast(x, compute_dtype), axis=axis, keepdims=keepdims + ) + return tf.cast(output, result_dtype) + + +def max(x, axis=None, keepdims=False, initial=None): + x = convert_to_tensor(x) + + # The TensorFlow numpy API implementation doesn't support `initial` so we + # handle it manually here. + if initial is not None: + if standardize_dtype(x.dtype) == "bool": + x = tf.reduce_any(x, axis=axis, keepdims=keepdims) + x = tf.math.maximum(tf.cast(x, "int32"), tf.cast(initial, "int32")) + return tf.cast(x, "bool") + else: + x = tf.reduce_max(x, axis=axis, keepdims=keepdims) + return tf.math.maximum(x, initial) + + # TensorFlow returns -inf by default for an empty list, but for consistency + # with other backends and the numpy API we want to throw in this case. + if tf.executing_eagerly(): + size_x = size(x) + tf.assert_greater( + size_x, + tf.constant(0, dtype=size_x.dtype), + message="Cannot compute the max of an empty tensor.", + ) + + if standardize_dtype(x.dtype) == "bool": + return tf.reduce_any(x, axis=axis, keepdims=keepdims) + else: + return tf.reduce_max(x, axis=axis, keepdims=keepdims) + + +def ones(shape, dtype=None): + dtype = dtype or config.floatx() + return tf.ones(shape, dtype=dtype) + + +def zeros(shape, dtype=None): + dtype = dtype or config.floatx() + return tf.zeros(shape, dtype=dtype) + + +@sparse.elementwise_unary +def absolute(x): + x = convert_to_tensor(x) + # uintx and bool are always non-negative + dtype = standardize_dtype(x.dtype) + if "uint" in dtype or dtype == "bool": + return x + return tf.abs(x) + + +def abs(x): + return absolute(x) + + +def all(x, axis=None, keepdims=False): + x = tf.cast(x, "bool") + return tf.reduce_all(x, axis=axis, keepdims=keepdims) + + +def angle(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = tf.cast(x, dtype) + return tf.math.angle(x) + + +def any(x, axis=None, keepdims=False): + x = tf.cast(x, "bool") + return tf.reduce_any(x, axis=axis, keepdims=keepdims) + + +def amax(x, axis=None, keepdims=False): + return max(x, axis=axis, keepdims=keepdims) + + +def amin(x, axis=None, keepdims=False): + return min(x, axis=axis, keepdims=keepdims) + + +def append(x1, x2, axis=None): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type(x1.dtype, x2.dtype) + x1 = tf.cast(x1, dtype) + x2 = tf.cast(x2, dtype) + if axis is None: + return tf.concat([tf.reshape(x1, [-1]), tf.reshape(x2, [-1])], axis=0) + else: + return tf.concat([x1, x2], axis=axis) + + +def arange(start, stop=None, step=None, dtype=None): + if dtype is None: + dtypes_to_resolve = [getattr(start, "dtype", type(start))] + if stop is not None: + dtypes_to_resolve.append(getattr(stop, "dtype", type(stop))) + if step is not None: + dtypes_to_resolve.append(getattr(step, "dtype", type(step))) + dtype = dtypes.result_type(*dtypes_to_resolve) + dtype = standardize_dtype(dtype) + if step is None: + step = 1 + try: + out = tf.range(start, stop, delta=step, dtype=dtype) + except tf.errors.NotFoundError: + # Some dtypes may not work in eager mode on CPU or GPU. + out = tf.range(start, stop, delta=step, dtype="float32") + out = tf.cast(out, dtype) + return out + + +@sparse.densifying_unary(0.5 * np.pi) +def arccos(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = tf.cast(x, dtype) + return tf.math.acos(x) + + +@sparse.densifying_unary(np.nan) +def arccosh(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = tf.cast(x, dtype) + return tf.math.acosh(x) + + +@sparse.elementwise_unary +def arcsin(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = tf.cast(x, dtype) + return tf.math.asin(x) + + +@sparse.elementwise_unary +def arcsinh(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = tf.cast(x, dtype) + return tf.math.asinh(x) + + +@sparse.elementwise_unary +def arctan(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = tf.cast(x, dtype) + return tf.math.atan(x) + + +def arctan2(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type(x1.dtype, x2.dtype, float) + x1 = tf.cast(x1, dtype) + x2 = tf.cast(x2, dtype) + return tf.math.atan2(x1, x2) + + +@sparse.elementwise_unary +def arctanh(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = tf.cast(x, dtype) + return tf.math.atanh(x) + + +def _keepdims(x, y, axis): + if axis is None: + shape = [1 for _ in range(len(x.shape))] + else: + shape = list(shape_op(x)) + for axis in tree.flatten(axis): + shape[axis] = 1 + y = tf.reshape(y, shape) + return y + + +def argmax(x, axis=None, keepdims=False): + x = convert_to_tensor(x) + dtype = standardize_dtype(x.dtype) + if "float" not in dtype or x.ndim == 0: + _x = x + if axis is None: + x = tf.reshape(x, [-1]) + y = tf.argmax(x, axis=axis, output_type="int32") + if keepdims: + y = _keepdims(_x, y, axis) + return y + + # Fix the flush-to-zero (FTZ) issue based on this issue: + # https://github.com/jax-ml/jax/issues/24280 + dtype = dtypes.result_type(dtype, "float32") + x = cast(x, dtype) + is_negative_zero = tf.logical_and(tf.equal(x, 0.0), signbit(x)) + x = tf.where( + is_negative_zero, -np.finfo(standardize_dtype(x.dtype)).tiny, x + ) + _x = x + if axis is None: + x = tf.reshape(x, [-1]) + y = tf.argmax(x, axis=axis, output_type="int32") + if keepdims: + y = _keepdims(_x, y, axis) + return y + + +def argmin(x, axis=None, keepdims=False): + from keras.src.testing.test_case import uses_cpu + + x = convert_to_tensor(x) + dtype = standardize_dtype(x.dtype) + if "float" not in dtype or not uses_cpu() or x.ndim == 0: + _x = x + if axis is None: + x = tf.reshape(x, [-1]) + y = tf.argmin(x, axis=axis, output_type="int32") + if keepdims: + y = _keepdims(_x, y, axis) + return y + + # Fix the flush-to-zero (FTZ) issue based on this issue: + # https://github.com/jax-ml/jax/issues/24280 + dtype = dtypes.result_type(dtype, "float32") + x = cast(x, dtype) + is_negative_zero = tf.logical_and(tf.equal(x, 0.0), signbit(x)) + x = tf.where( + is_negative_zero, -np.finfo(standardize_dtype(x.dtype)).tiny, x + ) + _x = x + if axis is None: + x = tf.reshape(x, [-1]) + y = tf.argmin(x, axis=axis, output_type="int32") + if keepdims: + y = _keepdims(_x, y, axis) + return y + + +def argsort(x, axis=-1): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "bool": + x = tf.cast(x, "uint8") + + x_shape = x.shape + if x_shape.rank == 0: + return tf.cast([0], "int32") + + if axis is None: + x = tf.reshape(x, [-1]) + axis = 0 + return tf.argsort(x, axis=axis) + + +def array(x, dtype=None): + return convert_to_tensor(x, dtype=dtype) + + +def average(x, axis=None, weights=None): + x = convert_to_tensor(x) + + if weights is None: # Treat all weights as 1 + dtype = dtypes.result_type(x.dtype, float) + x = tf.cast(x, dtype) + avg = tf.reduce_mean(x, axis=axis) + else: + weights = convert_to_tensor(weights) + dtype = dtypes.result_type(x.dtype, weights.dtype, float) + x = tf.cast(x, dtype) + weights = tf.cast(weights, dtype) + + def _rank_equal_case(): + weights_sum = tf.reduce_sum(weights, axis=axis) + return tf.reduce_sum(x * weights, axis=axis) / weights_sum + + def _rank_not_equal_case(): + weights_sum = tf.reduce_sum(weights) + axes = tf.convert_to_tensor([[axis], [0]]) + return tf.tensordot(x, weights, axes) / weights_sum + + if axis is None: + avg = _rank_equal_case() + else: + if len(x.shape) == len(weights.shape): + avg = _rank_equal_case() + else: + avg = _rank_not_equal_case() + return avg + + +def bitwise_and(x, y): + x = convert_to_tensor(x) + y = convert_to_tensor(y) + dtype = dtypes.result_type(x.dtype, y.dtype) + x = tf.cast(x, dtype) + y = tf.cast(y, dtype) + return tf.bitwise.bitwise_and(x, y) + + +def bitwise_invert(x): + x = convert_to_tensor(x) + return tf.bitwise.invert(x) + + +def bitwise_not(x): + return bitwise_invert(x) + + +def bitwise_or(x, y): + x = convert_to_tensor(x) + y = convert_to_tensor(y) + dtype = dtypes.result_type(x.dtype, y.dtype) + x = tf.cast(x, dtype) + y = tf.cast(y, dtype) + return tf.bitwise.bitwise_or(x, y) + + +def bitwise_xor(x, y): + x = convert_to_tensor(x) + y = convert_to_tensor(y) + dtype = dtypes.result_type(x.dtype, y.dtype) + x = tf.cast(x, dtype) + y = tf.cast(y, dtype) + return tf.bitwise.bitwise_xor(x, y) + + +def bitwise_left_shift(x, y): + x = convert_to_tensor(x) + if not isinstance(y, int): + y = convert_to_tensor(y) + dtype = dtypes.result_type(x.dtype, y.dtype) + x = tf.cast(x, dtype) + y = tf.cast(y, dtype) + return tf.bitwise.left_shift(x, y) + + +def left_shift(x, y): + return bitwise_left_shift(x, y) + + +def bitwise_right_shift(x, y): + x = convert_to_tensor(x) + if not isinstance(y, int): + y = convert_to_tensor(y) + dtype = dtypes.result_type(x.dtype, y.dtype) + x = tf.cast(x, dtype) + y = tf.cast(y, dtype) + return tf.bitwise.right_shift(x, y) + + +def right_shift(x, y): + return bitwise_right_shift(x, y) + + +def blackman(x): + dtype = config.floatx() + x = tf.cast(x, dtype) + n = tf.range(x, dtype=dtype) + n_minus_1 = tf.cast(x - 1, dtype) + term1 = 0.42 + term2 = -0.5 * tf.cos(2 * np.pi * n / n_minus_1) + term3 = 0.08 * tf.cos(4 * np.pi * n / n_minus_1) + window = term1 + term2 + term3 + return window + + +def broadcast_to(x, shape): + return tf.broadcast_to(x, shape) + + +def cbrt(x): + x = convert_to_tensor(x) + + dtype = standardize_dtype(x.dtype) + if dtype == "int64": + x = tf.cast(x, "float64") + elif dtype not in ["bfloat16", "float16", "float64"]: + x = tf.cast(x, config.floatx()) + + return tf.sign(x) * tf.pow(tf.abs(x), 1.0 / 3.0) + + +@sparse.elementwise_unary +def ceil(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = tf.cast(x, dtype) + return tf.math.ceil(x) + + +def clip(x, x_min, x_max): + dtype = standardize_dtype(x.dtype) + if dtype == "bool": + x = tf.cast(x, "int32") + return tf.clip_by_value(x, x_min, x_max) + + +def concatenate(xs, axis=0): + sparse_count = builtins.sum(isinstance(x, tf.SparseTensor) for x in xs) + if sparse_count: + if sparse_count == len(xs): + return tf.sparse.concat(axis=axis, sp_inputs=xs) + else: + xs = [ + ( + convert_to_tensor(x, sparse=False) + if isinstance(x, tf.SparseTensor) + else x + ) + for x in xs + ] + xs = tree.map_structure(convert_to_tensor, xs) + dtype_set = set([x.dtype for x in xs]) + if len(dtype_set) > 1: + dtype = dtypes.result_type(*dtype_set) + xs = tree.map_structure(lambda x: tf.cast(x, dtype), xs) + return tf.concat(xs, axis=axis) + + +@sparse.elementwise_unary +def conjugate(x): + return tf.math.conj(x) + + +@sparse.elementwise_unary +def conj(x): + return tf.math.conj(x) + + +@sparse.elementwise_unary +def copy(x): + x = convert_to_tensor(x) + return tf.identity(x) + + +@sparse.densifying_unary(1) +def cos(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = tf.cast(x, dtype) + return tf.math.cos(x) + + +@sparse.densifying_unary(1) +def cosh(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = tf.cast(x, dtype) + return tf.math.cosh(x) + + +def count_nonzero(x, axis=None): + return tf.math.count_nonzero(x, axis=axis, dtype="int32") + + +def cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=None): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type(x1.dtype, x2.dtype) + x1 = tf.cast(x1, dtype) + x2 = tf.cast(x2, dtype) + + if axis is not None: + axisa = axis + axisb = axis + axisc = axis + x1 = moveaxis(x1, axisa, -1) + x2 = moveaxis(x2, axisb, -1) + + def maybe_pad_zeros(x, size_of_last_dim): + def pad_zeros(x): + return tf.pad( + x, + tf.concat( + [ + tf.zeros([tf.rank(x) - 1, 2], "int32"), + tf.constant([[0, 1]], "int32"), + ], + axis=0, + ), + ) + + if isinstance(size_of_last_dim, int): + if size_of_last_dim == 2: + return pad_zeros(x) + return x + + return tf.cond( + tf.equal(size_of_last_dim, 2), lambda: pad_zeros(x), lambda: x + ) + + x1_dim = shape_op(x1)[-1] + x2_dim = shape_op(x2)[-1] + + x1 = maybe_pad_zeros(x1, x1_dim) + x2 = maybe_pad_zeros(x2, x2_dim) + + # Broadcast each other + shape = shape_op(x1) + + shape = tf.broadcast_dynamic_shape(shape, shape_op(x2)) + x1 = tf.broadcast_to(x1, shape) + x2 = tf.broadcast_to(x2, shape) + + c = tf.linalg.cross(x1, x2) + + if isinstance(x1_dim, int) and isinstance(x2_dim, int): + if (x1_dim == 2) & (x2_dim == 2): + return c[..., 2] + return moveaxis(c, -1, axisc) + + return tf.cond( + (x1_dim == 2) & (x2_dim == 2), + lambda: c[..., 2], + lambda: moveaxis(c, -1, axisc), + ) + + +def cumprod(x, axis=None, dtype=None): + x = convert_to_tensor(x, dtype=dtype) + # tf.math.cumprod doesn't support bool + if standardize_dtype(x.dtype) == "bool": + x = tf.cast(x, "int32") + if axis is None: + x = tf.reshape(x, [-1]) + axis = 0 + return tf.math.cumprod(x, axis=axis) + + +def cumsum(x, axis=None, dtype=None): + x = convert_to_tensor(x, dtype=dtype) + # tf.math.cumprod doesn't support bool + if standardize_dtype(x.dtype) == "bool": + x = tf.cast(x, "int32") + if axis is None: + x = tf.reshape(x, [-1]) + axis = 0 + return tf.math.cumsum(x, axis=axis) + + +def deg2rad(x): + x = convert_to_tensor(x) + + dtype = x.dtype + if standardize_dtype(dtype) in [ + "bool", + "int8", + "int16", + "int32", + "uint8", + "uint16", + "uint32", + ]: + dtype = config.floatx() + elif standardize_dtype(dtype) in ["int64"]: + dtype = "float64" + x = tf.cast(x, dtype) + + pi = tf.constant(math.pi, dtype=dtype) + return x * (pi / tf.constant(180.0, dtype=dtype)) + + +def diag(x, k=0): + x = convert_to_tensor(x) + if len(x.shape) == 1: + return tf.cond( + tf.equal(tf.size(x), 0), + lambda: tf.zeros([builtins.abs(k), builtins.abs(k)], dtype=x.dtype), + lambda: tf.linalg.diag(x, k=k), + ) + elif len(x.shape) == 2: + return diagonal(x, offset=k) + else: + raise ValueError(f"`x` must be 1d or 2d. Received: x.shape={x.shape}") + + +def diagflat(x, k=0): + x = convert_to_tensor(x) + return diag(tf.reshape(x, [-1]), k) + + +def diagonal(x, offset=0, axis1=0, axis2=1): + x = convert_to_tensor(x) + x_rank = x.ndim + if ( + offset == 0 + and (axis1 == x_rank - 2 or axis1 == -2) + and (axis2 == x_rank - 1 or axis2 == -1) + ): + return tf.linalg.diag_part(x) + + x = moveaxis(x, (axis1, axis2), (-2, -1)) + x_shape = shape_op(x) + + def _zeros(): + return tf.zeros(tf.concat([x_shape[:-1], [0]], 0), dtype=x.dtype) + + if isinstance(x_shape[-1], int) and isinstance(x_shape[-2], int): + if offset <= -1 * x_shape[-2] or offset >= x_shape[-1]: + x = _zeros() + else: + x = tf.cond( + tf.logical_or( + tf.less_equal(offset, -1 * x_shape[-2]), + tf.greater_equal(offset, x_shape[-1]), + ), + lambda: _zeros(), + lambda: x, + ) + return tf.linalg.diag_part(x, k=offset) + + +def diff(a, n=1, axis=-1): + a = convert_to_tensor(a) + if n == 0: + return a + elif n < 0: + raise ValueError(f"Order `n` must be non-negative. Received n={n}") + elif a.ndim == 0: + raise ValueError( + "`diff` requires input that is at least one dimensional. " + f"Received: a={a}" + ) + axis = canonicalize_axis(axis, a.ndim) + slice1 = [slice(None)] * a.ndim + slice2 = [slice(None)] * a.ndim + slice1[axis] = slice(1, None) + slice2[axis] = slice(None, -1) + slice1_tuple = tuple(slice1) + slice2_tuple = tuple(slice2) + for _ in range(n): + if standardize_dtype(a.dtype) == "bool": + a = tf.not_equal(a[slice1_tuple], a[slice2_tuple]) + else: + a = tf.subtract(a[slice1_tuple], a[slice2_tuple]) + return a + + +def digitize(x, bins): + x = convert_to_tensor(x) + bins = list(bins) + + # bins must be float type + bins = tree.map_structure(lambda x: float(x), bins) + + # TODO: tf.raw_ops.Bucketize doesn't support bool, bfloat16, float16, int8 + # int16, uint8, uint16, uint32 + ori_dtype = standardize_dtype(x.dtype) + if ori_dtype in ("bool", "int8", "int16", "uint8", "uint16"): + x = cast(x, "int32") + elif ori_dtype == "uint32": + x = cast(x, "int64") + elif ori_dtype in ("bfloat16", "float16"): + x = cast(x, "float32") + + if isinstance(x, tf.RaggedTensor): + return tf.ragged.map_flat_values( + lambda y: tf.raw_ops.Bucketize(input=y, boundaries=bins), x + ) + elif isinstance(x, tf.SparseTensor): + output = tf.SparseTensor( + indices=tf.identity(x.indices), + values=tf.raw_ops.Bucketize(input=x.values, boundaries=bins), + dense_shape=tf.identity(x.dense_shape), + ) + output.set_shape(x.shape) + return output + return tf.raw_ops.Bucketize(input=x, boundaries=bins) + + +def dot(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + result_dtype = dtypes.result_type(x1.dtype, x2.dtype) + # GPU only supports float types + compute_dtype = dtypes.result_type(result_dtype, float) + x1 = tf.cast(x1, compute_dtype) + x2 = tf.cast(x2, compute_dtype) + + x_shape = x1.shape + y_shape = x2.shape + if x_shape.rank == 0 or y_shape.rank == 0: + output = x1 * x2 + elif y_shape.rank == 1: + output = tf.tensordot(x1, x2, axes=[[-1], [-1]]) + else: + output = tf.tensordot(x1, x2, axes=[[-1], [-2]]) + return tf.cast(output, result_dtype) + + +def empty(shape, dtype=None): + dtype = dtype or config.floatx() + return tf.zeros(shape, dtype=dtype) + + +def equal(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type(x1.dtype, x2.dtype) + x1 = tf.cast(x1, dtype) + x2 = tf.cast(x2, dtype) + return tf.equal(x1, x2) + + +@sparse.densifying_unary(1) +def exp(x): + x = convert_to_tensor(x) + ori_dtype = standardize_dtype(x.dtype) + if "int" in ori_dtype or ori_dtype == "bool": + x = tf.cast(x, config.floatx()) + return tf.exp(x) + + +@sparse.densifying_unary(1) +def exp2(x): + x = convert_to_tensor(x) + ori_dtype = standardize_dtype(x.dtype) + if "int" in ori_dtype or ori_dtype == "bool": + x = tf.cast(x, config.floatx()) + return tf.math.pow(2.0, x) + + +def expand_dims(x, axis): + x = convert_to_tensor(x) + axis = to_tuple_or_list(axis) + out_ndim = len(x.shape) + len(axis) + axis = sorted([canonicalize_axis(a, out_ndim) for a in axis]) + if isinstance(x, tf.SparseTensor): + from keras.src.ops.operation_utils import ( + compute_expand_dims_output_shape, + ) + + output_shape = compute_expand_dims_output_shape(x.shape, axis) + for a in axis: + x = tf.sparse.expand_dims(x, a) + x.set_shape(output_shape) + return x + for a in axis: + x = tf.expand_dims(x, a) + return x + + +@sparse.elementwise_unary +def expm1(x): + x = convert_to_tensor(x) + ori_dtype = standardize_dtype(x.dtype) + if "int" in ori_dtype or ori_dtype == "bool": + x = tf.cast(x, config.floatx()) + return tf.math.expm1(x) + + +def flip(x, axis=None): + x = convert_to_tensor(x) + if axis is None: + return tf.reverse(x, tf.range(tf.rank(x))) + return tf.reverse(x, [axis]) + + +@sparse.elementwise_unary +def floor(x): + x = convert_to_tensor(x) + dtype = ( + config.floatx() + if standardize_dtype(x.dtype) == "int64" + else dtypes.result_type(x.dtype, float) + ) + x = tf.cast(x, dtype) + return tf.floor(x) + + +def full(shape, fill_value, dtype=None): + dtype = dtype or config.floatx() + fill_value = convert_to_tensor(fill_value, dtype) + return tf.broadcast_to(fill_value, shape) + + +def full_like(x, fill_value, dtype=None): + x = convert_to_tensor(x) + dtype = dtypes.result_type(dtype or x.dtype) + fill_value = convert_to_tensor(fill_value, dtype) + return tf.broadcast_to(fill_value, tf.shape(x)) + + +def gcd(x1, x2): + x1 = tf.convert_to_tensor(x1) + x2 = tf.convert_to_tensor(x2) + + dtype = dtypes.result_type(x1.dtype, x2.dtype) + x1 = tf.cast(x1, dtype) + x2 = tf.cast(x2, dtype) + + if not x1.dtype.is_integer: + raise TypeError("Arguments to gcd must be integers.") + + target_shape = tf.broadcast_static_shape(x1.shape, x2.shape) + x1 = tf.broadcast_to(x1, target_shape) + x2 = tf.broadcast_to(x2, target_shape) + + def cond(a, b): + return tf.reduce_any(b != 0) + + def body(a, b): + b_safe = tf.where(tf.equal(b, 0), tf.ones_like(b), b) + return ( + tf.where(tf.not_equal(b, 0), b, a), + tf.where( + tf.not_equal(b, 0), + tf.math.floormod(a, b_safe), + tf.zeros_like(b), + ), + ) + + if dtype not in [tf.uint8, tf.uint16, tf.uint32, tf.uint64]: + x1 = tf.abs(x1) + x2 = tf.abs(x2) + + gcd_val, _ = tf.while_loop(cond, body, [x1, x2]) + return gcd_val + + +def greater(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type(x1.dtype, x2.dtype) + x1 = tf.cast(x1, dtype) + x2 = tf.cast(x2, dtype) + return tf.greater(x1, x2) + + +def greater_equal(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type(x1.dtype, x2.dtype) + x1 = tf.cast(x1, dtype) + x2 = tf.cast(x2, dtype) + return tf.greater_equal(x1, x2) + + +def hstack(xs): + dtype_set = set([getattr(x, "dtype", type(x)) for x in xs]) + if len(dtype_set) > 1: + dtype = dtypes.result_type(*dtype_set) + xs = tree.map_structure(lambda x: convert_to_tensor(x, dtype), xs) + if len(xs[0].shape) == 1: + return tf.concat(xs, axis=0) + return tf.concat(xs, axis=1) + + +def hypot(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + + dtype = dtypes.result_type(x1.dtype, x2.dtype) + if dtype in ["int8", "int16", "int32", "uint8", "uint16", "uint32"]: + dtype = config.floatx() + elif dtype in ["int64"]: + dtype = "float64" + + x1 = tf.cast(x1, dtype) + x2 = tf.cast(x2, dtype) + + x1_abs = tf.abs(x1) + x2_abs = tf.abs(x2) + max_val = tf.maximum(x1_abs, x2_abs) + min_val = tf.minimum(x1_abs, x2_abs) + + ratio = tf.math.divide_no_nan(min_val, max_val) + return max_val * tf.sqrt(1.0 + tf.square(ratio)) + + +def identity(n, dtype=None): + return eye(N=n, M=n, dtype=dtype) + + +@sparse.elementwise_unary +def imag(x): + return tf.math.imag(x) + + +def isclose(x1, x2, rtol=1e-5, atol=1e-8, equal_nan=False): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type(x1.dtype, x2.dtype) + x1 = tf.cast(x1, dtype) + x2 = tf.cast(x2, dtype) + if "float" in dtype: + result = tf.abs(x1 - x2) <= (atol + rtol * tf.abs(x2)) + if equal_nan: + result = result | (is_nan(x1) & is_nan(x2)) + return result + else: + return tf.equal(x1, x2) + + +@sparse.densifying_unary(True) +def isfinite(x): + x = convert_to_tensor(x) + dtype_as_dtype = tf.as_dtype(x.dtype) + if dtype_as_dtype.is_integer or not dtype_as_dtype.is_numeric: + return tf.ones(x.shape, tf.bool) + return tf.math.is_finite(x) + + +def isin(x1, x2, assume_unique=False, invert=False): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + + dtype = dtypes.result_type(x1.dtype, x2.dtype) + x1 = tf.cast(x1, dtype) + x2 = tf.cast(x2, dtype) + + output_shape = tf.shape(x1) + + x1 = tf.reshape(x1, [-1]) + x2 = tf.reshape(x2, [-1]) + + if not assume_unique: + x2 = tf.unique(x2)[0] + + if tf.size(x1) == 0 or tf.size(x2) == 0: + return tf.zeros(output_shape, dtype=tf.bool) + + cmp = tf.equal(tf.expand_dims(x1, 1), tf.expand_dims(x2, 0)) + result_flat = tf.reduce_any(cmp, axis=1) + + if invert: + result_flat = tf.logical_not(result_flat) + + return tf.reshape(result_flat, output_shape) + + +def isinf(x): + x = convert_to_tensor(x) + dtype_as_dtype = tf.as_dtype(x.dtype) + if dtype_as_dtype.is_integer or not dtype_as_dtype.is_numeric: + return tf.zeros(x.shape, tf.bool) + return tf.math.is_inf(x) + + +def isnan(x): + x = convert_to_tensor(x) + dtype_as_dtype = tf.as_dtype(x.dtype) + if dtype_as_dtype.is_integer or not dtype_as_dtype.is_numeric: + return tf.zeros(x.shape, tf.bool) + return tf.math.is_nan(x) + + +def isneginf(x): + x = convert_to_tensor(x) + dtype_as_dtype = tf.as_dtype(x.dtype) + if dtype_as_dtype.is_integer or not dtype_as_dtype.is_numeric: + return tf.zeros_like(x, dtype=tf.bool) + return tf.math.equal(x, -tf.constant(float("inf"), dtype=x.dtype)) + + +def isposinf(x): + x = convert_to_tensor(x) + dtype_as_dtype = tf.as_dtype(x.dtype) + if dtype_as_dtype.is_integer or not dtype_as_dtype.is_numeric: + return tf.zeros_like(x, dtype=tf.bool) + return tf.math.equal(x, tf.constant(float("inf"), dtype=x.dtype)) + + +def kron(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + + dtype = dtypes.result_type(x1.dtype, x2.dtype) + x1 = tf.cast(x1, dtype) + x2 = tf.cast(x2, dtype) + + ndim_x1 = tf.rank(x1) + ndim_x2 = tf.rank(x2) + + def expand_front(x, num): + for _ in range(num): + x = tf.expand_dims(x, axis=0) + return x + + x1 = tf.cond( + ndim_x1 < ndim_x2, + lambda: expand_front(x1, ndim_x2 - ndim_x1), + lambda: x1, + ) + x2 = tf.cond( + ndim_x2 < ndim_x1, + lambda: expand_front(x2, ndim_x1 - ndim_x2), + lambda: x2, + ) + + x1_reshaped = tf.reshape( + x1, + tf.reshape( + tf.stack([tf.shape(x1), tf.ones_like(tf.shape(x1))], axis=1), [-1] + ), + ) + x2_reshaped = tf.reshape( + x2, + tf.reshape( + tf.stack([tf.ones_like(tf.shape(x2)), tf.shape(x2)], axis=1), [-1] + ), + ) + + out = tf.multiply(x1_reshaped, x2_reshaped) + out_shape = tf.multiply(tf.shape(x1), tf.shape(x2)) + out = tf.reshape(out, out_shape) + return out + + +def lcm(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + + if not (x1.dtype.is_integer and x2.dtype.is_integer): + raise TypeError( + f"Arguments to lcm must be integers. " + f"Received: x1.dtype={x1.dtype.name}, x2.dtype={x2.dtype.name}" + ) + + dtype = dtypes.result_type(x1.dtype, x2.dtype) + x1 = tf.cast(x1, dtype) + x2 = tf.cast(x2, dtype) + + if dtype not in [tf.uint8, tf.uint16, tf.uint32, tf.uint64]: + x1 = tf.math.abs(x1) + x2 = tf.math.abs(x2) + + divisor = gcd(x1, x2) + divisor_safe = tf.where( + divisor == 0, tf.constant(1, dtype=divisor.dtype), divisor + ) + + result = x1 * (x2 // divisor_safe) + result = tf.where(divisor == 0, tf.zeros_like(result), result) + + return result + + +def less(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type(x1.dtype, x2.dtype) + x1 = tf.cast(x1, dtype) + x2 = tf.cast(x2, dtype) + return tf.less(x1, x2) + + +def less_equal(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type(x1.dtype, x2.dtype) + x1 = tf.cast(x1, dtype) + x2 = tf.cast(x2, dtype) + return tf.less_equal(x1, x2) + + +def linspace( + start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0 +): + if num < 0: + raise ValueError( + f"`num` must be a non-negative integer. Received: num={num}" + ) + if dtype is None: + dtypes_to_resolve = [ + getattr(start, "dtype", type(start)), + getattr(stop, "dtype", type(stop)), + float, + ] + dtype = dtypes.result_type(*dtypes_to_resolve) + else: + dtype = standardize_dtype(dtype) + start = convert_to_tensor(start, dtype=dtype) + stop = convert_to_tensor(stop, dtype=dtype) + step = convert_to_tensor(np.nan) + if endpoint: + result = tf.linspace(start, stop, num, axis=axis) + if num > 1: + step = (stop - start) / (tf.cast(num, dtype) - 1) + else: + # tf.linspace doesn't support endpoint=False, so we manually handle it + if num > 0: + step = (stop - start) / tf.cast(num, dtype) + if num > 1: + new_stop = tf.cast(stop, step.dtype) - step + start = tf.cast(start, new_stop.dtype) + result = tf.linspace(start, new_stop, num, axis=axis) + else: + result = tf.linspace(start, stop, num, axis=axis) + if dtype is not None: + if "int" in dtype: + result = tf.floor(result) + result = tf.cast(result, dtype) + if retstep: + return (result, step) + else: + return result + + +@sparse.densifying_unary(-np.inf) +def log(x): + x = convert_to_tensor(x) + dtype = ( + config.floatx() + if standardize_dtype(x.dtype) == "int64" + else dtypes.result_type(x.dtype, float) + ) + x = tf.cast(x, dtype) + return tf.math.log(x) + + +@sparse.densifying_unary(-np.inf) +def log10(x): + x = convert_to_tensor(x) + dtype = ( + config.floatx() + if standardize_dtype(x.dtype) == "int64" + else dtypes.result_type(x.dtype, float) + ) + x = tf.cast(x, dtype) + return tf.math.log(x) / tf.math.log(tf.constant(10, x.dtype)) + + +@sparse.elementwise_unary +def log1p(x): + x = convert_to_tensor(x) + dtype = ( + config.floatx() + if standardize_dtype(x.dtype) == "int64" + else dtypes.result_type(x.dtype, float) + ) + x = tf.cast(x, dtype) + return tf.math.log1p(x) + + +@sparse.densifying_unary(-np.inf) +def log2(x): + x = convert_to_tensor(x) + dtype = ( + config.floatx() + if standardize_dtype(x.dtype) == "int64" + else dtypes.result_type(x.dtype, float) + ) + x = tf.cast(x, dtype) + return tf.math.log(x) / tf.math.log(tf.constant(2, x.dtype)) + + +def logaddexp(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type(x1.dtype, x2.dtype, float) + x1 = tf.cast(x1, dtype) + x2 = tf.cast(x2, dtype) + delta = x1 - x2 + return tf.where( + tf.math.is_nan(delta), + x1 + x2, + tf.maximum(x1, x2) + tf.math.log1p(tf.math.exp(-tf.abs(delta))), + ) + + +def logaddexp2(x1, x2): + x1 = tf.convert_to_tensor(x1) + x2 = tf.convert_to_tensor(x2) + dtype = dtypes.result_type(x1.dtype, x2.dtype, float) + x1 = tf.cast(x1, dtype) + x2 = tf.cast(x2, dtype) + delta = x1 - x2 + log2 = tf.cast(tf.math.log(2.0), dtype) + return tf.where( + tf.math.is_nan(delta), + x1 + x2, + tf.maximum(x1, x2) + + tf.math.log1p(tf.math.exp(-tf.abs(delta) * log2)) / log2, + ) + + +def logical_and(x1, x2): + x1 = tf.cast(x1, "bool") + x2 = tf.cast(x2, "bool") + return tf.logical_and(x1, x2) + + +def logical_not(x): + x = tf.cast(x, "bool") + return tf.logical_not(x) + + +def logical_or(x1, x2): + x1 = tf.cast(x1, "bool") + x2 = tf.cast(x2, "bool") + return tf.logical_or(x1, x2) + + +def logspace(start, stop, num=50, endpoint=True, base=10, dtype=None, axis=0): + result = linspace( + start=start, + stop=stop, + num=num, + endpoint=endpoint, + dtype=dtype, + axis=axis, + ) + return tf.pow(tf.cast(base, result.dtype), result) + + +@sparse.elementwise_binary_union(tf.sparse.maximum, densify_mixed=True) +def maximum(x1, x2): + if not isinstance(x1, (int, float)): + x1 = convert_to_tensor(x1) + if not isinstance(x2, (int, float)): + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + ) + x1 = convert_to_tensor(x1, dtype) + x2 = convert_to_tensor(x2, dtype) + return tf.maximum(x1, x2) + + +def median(x, axis=None, keepdims=False): + return quantile(x, 0.5, axis=axis, keepdims=keepdims) + + +def meshgrid(*x, indexing="xy"): + return tf.meshgrid(*x, indexing=indexing) + + +def min(x, axis=None, keepdims=False, initial=None): + x = convert_to_tensor(x) + + # The TensorFlow numpy API implementation doesn't support `initial` so we + # handle it manually here. + if initial is not None: + if standardize_dtype(x.dtype) == "bool": + x = tf.reduce_all(x, axis=axis, keepdims=keepdims) + x = tf.math.minimum(tf.cast(x, "int32"), tf.cast(initial, "int32")) + return tf.cast(x, "bool") + else: + x = tf.reduce_min(x, axis=axis, keepdims=keepdims) + return tf.math.minimum(x, initial) + + # TensorFlow returns inf by default for an empty list, but for consistency + # with other backends and the numpy API we want to throw in this case. + if tf.executing_eagerly(): + size_x = size(x) + tf.assert_greater( + size_x, + tf.constant(0, dtype=size_x.dtype), + message="Cannot compute the min of an empty tensor.", + ) + + if standardize_dtype(x.dtype) == "bool": + return tf.reduce_all(x, axis=axis, keepdims=keepdims) + else: + return tf.reduce_min(x, axis=axis, keepdims=keepdims) + + +@sparse.elementwise_binary_union(tf.sparse.minimum, densify_mixed=True) +def minimum(x1, x2): + if not isinstance(x1, (int, float)): + x1 = convert_to_tensor(x1) + if not isinstance(x2, (int, float)): + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + ) + x1 = convert_to_tensor(x1, dtype) + x2 = convert_to_tensor(x2, dtype) + return tf.minimum(x1, x2) + + +def mod(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type(x1.dtype, x2.dtype) + if dtype == "bool": + dtype = "int32" + x1 = tf.cast(x1, dtype) + x2 = tf.cast(x2, dtype) + return tf.math.mod(x1, x2) + + +def moveaxis(x, source, destination): + x = convert_to_tensor(x) + + _source = to_tuple_or_list(source) + _destination = to_tuple_or_list(destination) + _source = tuple(canonicalize_axis(i, x.ndim) for i in _source) + _destination = tuple(canonicalize_axis(i, x.ndim) for i in _destination) + if len(_source) != len(_destination): + raise ValueError( + "Inconsistent number of `source` and `destination`. " + f"Received: source={source}, destination={destination}" + ) + # Directly return x if no movement is required + if _source == _destination: + return x + perm = [i for i in range(x.ndim) if i not in _source] + for dest, src in sorted(zip(_destination, _source)): + perm.insert(dest, src) + return tf.transpose(x, perm) + + +def nan_to_num(x, nan=0.0, posinf=None, neginf=None): + x = convert_to_tensor(x) + + dtype = x.dtype + dtype_as_dtype = tf.as_dtype(dtype) + if dtype_as_dtype.is_integer or not dtype_as_dtype.is_numeric: + return x + + # Replace NaN with `nan` + x = tf.where(tf.math.is_nan(x), tf.constant(nan, dtype), x) + + # Replace positive infinity with `posinf` or `dtype.max` + if posinf is None: + posinf = dtype.max + x = tf.where(tf.math.is_inf(x) & (x > 0), tf.constant(posinf, dtype), x) + + # Replace negative infinity with `neginf` or `dtype.min` + if neginf is None: + neginf = dtype.min + x = tf.where(tf.math.is_inf(x) & (x < 0), tf.constant(neginf, dtype), x) + + return x + + +def ndim(x): + x = convert_to_tensor(x) + return x.ndim + + +def nonzero(x): + x = convert_to_tensor(x) + result = tf.unstack(tf.where(tf.cast(x, "bool")), x.shape.rank, axis=1) + return tree.map_structure(lambda indices: tf.cast(indices, "int32"), result) + + +def not_equal(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type(x1.dtype, x2.dtype) + x1 = tf.cast(x1, dtype) + x2 = tf.cast(x2, dtype) + return tf.not_equal(x1, x2) + + +def ones_like(x, dtype=None): + return tf.ones_like(x, dtype=dtype) + + +def zeros_like(x, dtype=None): + return tf.zeros_like(x, dtype=dtype) + + +def outer(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type(x1.dtype, x2.dtype) + x1 = tf.cast(x1, dtype) + x2 = tf.cast(x2, dtype) + return tf.reshape(x1, [-1, 1]) * tf.reshape(x2, [-1]) + + +def pad(x, pad_width, mode="constant", constant_values=None): + x = convert_to_tensor(x) + kwargs = {} + if constant_values is not None: + if mode != "constant": + raise ValueError( + "Argument `constant_values` can only be " + "provided when `mode == 'constant'`. " + f"Received: mode={mode}" + ) + kwargs["constant_values"] = constant_values + pad_width = convert_to_tensor(pad_width, "int32") + return tf.pad(x, pad_width, mode.upper(), **kwargs) + + +def prod(x, axis=None, keepdims=False, dtype=None): + x = convert_to_tensor(x) + if dtype is None: + dtype = dtypes.result_type(x.dtype) + if dtype == "bool": + dtype = "int32" + elif dtype in ("int8", "int16"): + dtype = "int32" + elif dtype in ("uint8", "uint16"): + dtype = "uint32" + x = tf.cast(x, dtype) + return tf.reduce_prod(x, axis=axis, keepdims=keepdims) + + +def _quantile(x, q, axis=None, method="linear", keepdims=False): + # ref: tfp.stats.percentile + # float64 is needed here and below, else we get the wrong index if the array + # is huge along axis. + q = tf.cast(q, "float64") + + # Move `axis` dims of `x` to the rightmost, call it `y`. + if axis is None: + y = tf.reshape(x, [-1]) + else: + x_ndims = len(x.shape) + # _make_static_axis_non_negative_list + axis = [canonicalize_axis(a, x_ndims) for a in axis] + + # _move_dims_to_flat_end + other_dims = sorted(set(range(x_ndims)).difference(axis)) + perm = other_dims + list(axis) + x_permed = tf.transpose(a=x, perm=perm) + if None not in x.shape: + x_shape = list(x.shape) + other_shape = [x_shape[i] for i in other_dims] + end_shape = [math.prod([x_shape[i] for i in axis])] + full_shape = other_shape + end_shape + else: + other_shape = tf.gather(tf.shape(x), tf.cast(other_dims, tf.int64)) + full_shape = tf.concat([other_shape, [-1]], axis=0) + y = tf.reshape(x_permed, shape=full_shape) + + # Sort (in ascending order) everything which allows multiple calls to sort + # only once (under the hood) and use CSE. + sorted_y = tf.sort(y, axis=-1, direction="ASCENDING") + + d = tf.cast(tf.shape(y)[-1], "float64") + + def _get_indices(method): + """Get values of y at the indices implied by method.""" + if method == "lower": + indices = tf.math.floor((d - 1) * q) + elif method == "higher": + indices = tf.math.ceil((d - 1) * q) + elif method == "nearest": + indices = tf.round((d - 1) * q) + # d - 1 will be distinct from d in int32, but not necessarily double. + # So clip to avoid out of bounds errors. + return tf.clip_by_value( + tf.cast(indices, "int32"), 0, tf.shape(y)[-1] - 1 + ) + + if method in ["nearest", "lower", "higher"]: + gathered_y = tf.gather(sorted_y, _get_indices(method), axis=-1) + elif method == "midpoint": + gathered_y = 0.5 * ( + tf.gather(sorted_y, _get_indices("lower"), axis=-1) + + tf.gather(sorted_y, _get_indices("higher"), axis=-1) + ) + elif method == "linear": + larger_y_idx = _get_indices("higher") + exact_idx = (d - 1) * q + # preserve_gradients + smaller_y_idx = tf.maximum(larger_y_idx - 1, 0) + larger_y_idx = tf.minimum(smaller_y_idx + 1, tf.shape(y)[-1] - 1) + fraction = tf.cast(larger_y_idx, tf.float64) - exact_idx + fraction = tf.cast(fraction, y.dtype) + gathered_y = ( + tf.gather(sorted_y, larger_y_idx, axis=-1) * (1 - fraction) + + tf.gather(sorted_y, smaller_y_idx, axis=-1) * fraction + ) + + # Propagate NaNs + if x.dtype in (tf.bfloat16, tf.float16, tf.float32, tf.float64): + # Apparently tf.is_nan doesn't like other dtypes + nan_batch_members = tf.reduce_any(tf.math.is_nan(x), axis=axis) + right_rank_matched_shape = tf.pad( + tf.shape(nan_batch_members), + paddings=[[0, tf.rank(q)]], + constant_values=1, + ) + nan_batch_members = tf.reshape( + nan_batch_members, shape=right_rank_matched_shape + ) + nan_value = tf.constant(float("NaN"), dtype=x.dtype) + gathered_y = tf.where(nan_batch_members, nan_value, gathered_y) + + # Expand dimensions if requested + if keepdims: + if axis is None: + ones_vec = tf.ones(shape=[tf.rank(x) + tf.rank(q)], dtype="int32") + gathered_y *= tf.ones(ones_vec, dtype=gathered_y.dtype) + else: + for i in sorted(axis): + gathered_y = tf.expand_dims(gathered_y, axis=i) + + # rotate_transpose + shift_value_static = tf.get_static_value(tf.rank(q)) + ndims = tf.TensorShape(gathered_y.shape).rank + if ndims < 2: + return gathered_y + shift_value_static = int( + math.copysign(1, shift_value_static) + * (builtins.abs(shift_value_static) % ndims) + ) + if shift_value_static == 0: + return gathered_y + perm = collections.deque(range(ndims)) + perm.rotate(shift_value_static) + return tf.transpose(a=gathered_y, perm=perm) + + +def quantile(x, q, axis=None, method="linear", keepdims=False): + x = convert_to_tensor(x) + q = convert_to_tensor(q) + axis = to_tuple_or_list(axis) + compute_dtype = dtypes.result_type(x.dtype, float) + x = tf.cast(x, compute_dtype) + return _quantile(x, q, axis=axis, method=method, keepdims=keepdims) + + +def ravel(x): + x = convert_to_tensor(x) + return tf.reshape(x, [-1]) + + +def unravel_index(indices, shape): + indices = tf.convert_to_tensor(indices) + input_dtype = indices.dtype + + if None in shape: + raise ValueError( + f"`shape` argument cannot contain `None`. Received: shape={shape}" + ) + + if indices.ndim == 1: + coords = [] + for dim in reversed(shape): + coords.append(tf.cast(indices % dim, input_dtype)) + indices = indices // dim + return tuple(reversed(coords)) + + indices_shape = indices.shape + coords = [] + for dim in shape: + coords.append( + tf.reshape(tf.cast(indices % dim, input_dtype), indices_shape) + ) + indices = indices // dim + + return tuple(reversed(coords)) + + +@sparse.elementwise_unary +def real(x): + x = convert_to_tensor(x) + return tf.math.real(x) + + +@sparse.densifying_unary(np.inf) +def reciprocal(x): + x = convert_to_tensor(x) + return tf.math.reciprocal(x) + + +def repeat(x, repeats, axis=None): + x = convert_to_tensor(x) + # TODO: tf.repeat doesn't support uint16 + if standardize_dtype(x.dtype) == "uint16": + x = tf.cast(x, "uint32") + return tf.cast(tf.repeat(x, repeats, axis=axis), "uint16") + return tf.repeat(x, repeats, axis=axis) + + +def reshape(x, newshape): + x = convert_to_tensor(x) + if isinstance(x, tf.SparseTensor): + from keras.src.ops.operation_utils import compute_reshape_output_shape + + output_shape = compute_reshape_output_shape( + x.shape, newshape, "newshape" + ) + output = tf.sparse.reshape(x, newshape) + output.set_shape(output_shape) + return output + return tf.reshape(x, newshape) + + +def roll(x, shift, axis=None): + x = convert_to_tensor(x) + if axis is not None: + return tf.roll(x, shift=shift, axis=axis) + + # If axis is None, the roll happens as a 1-d tensor. + original_shape = tf.shape(x) + x = tf.roll(tf.reshape(x, [-1]), shift, 0) + return tf.reshape(x, original_shape) + + +def searchsorted(sorted_sequence, values, side="left"): + if ndim(sorted_sequence) != 1: + raise ValueError( + "`searchsorted` only supports 1-D sorted sequences. " + "You can use `keras.ops.vectorized_map` " + "to extend it to N-D sequences. Received: " + f"sorted_sequence.shape={sorted_sequence.shape}" + ) + sequence_len = sorted_sequence.shape[0] + out_type = ( + "int32" + if sequence_len is not None and sequence_len <= np.iinfo(np.int32).max + else "int64" + ) + return tf.searchsorted( + sorted_sequence, values, side=side, out_type=out_type + ) + + +@sparse.elementwise_unary +def sign(x): + x = convert_to_tensor(x) + ori_dtype = standardize_dtype(x.dtype) + # TODO: tf.sign doesn't support uint8, uint16, uint32 + if ori_dtype in ("uint8", "uint16", "uint32"): + x = tf.cast(x, "int32") + return tf.cast(tf.sign(x), ori_dtype) + return tf.sign(x) + + +@sparse.elementwise_unary +def signbit(x): + x = convert_to_tensor(x) + ori_dtype = standardize_dtype(x.dtype) + if ori_dtype == "bool": + return tf.fill(tf.shape(x), False) + elif "int" in ori_dtype: + return x < 0 + else: + x = cast(x, "float32") + return tf.less( + tf.bitwise.bitwise_and( + tf.bitcast(x, tf.int32), + # tf.float32 sign bit + tf.constant(tf.int32.min, dtype=tf.int32), + ), + 0, + ) + + +@sparse.elementwise_unary +def sin(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = tf.cast(x, dtype) + return tf.math.sin(x) + + +@sparse.elementwise_unary +def sinh(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = tf.cast(x, dtype) + return tf.math.sinh(x) + + +def size(x): + x = convert_to_tensor(x) + return tf.size(x) + + +def sort(x, axis=-1): + x = convert_to_tensor(x) + ori_dtype = standardize_dtype(x.dtype) + # TODO: tf.sort doesn't support bool + if ori_dtype == "bool": + x = tf.cast(x, "int8") + return tf.cast(tf.sort(x, axis=axis), ori_dtype) + return tf.sort(x, axis=axis) + + +def split(x, indices_or_sections, axis=0): + if not isinstance(indices_or_sections, int): + # `tf.split` requires `num_or_size_splits`, so we need to convert + # `indices_or_sections` to the appropriate format. + total_size = x.shape[axis] + indices_or_sections = convert_to_tensor(indices_or_sections) + start_size = indices_or_sections[0:1] + end_size = total_size - indices_or_sections[-1:] + num_or_size_splits = tf.concat( + [start_size, diff(indices_or_sections), end_size], axis=0 + ) + else: + num_or_size_splits = indices_or_sections + return tf.split(x, num_or_size_splits, axis=axis) + + +def stack(x, axis=0): + dtype_set = set([getattr(a, "dtype", type(a)) for a in x]) + if len(dtype_set) > 1: + dtype = dtypes.result_type(*dtype_set) + x = tree.map_structure(lambda a: convert_to_tensor(a, dtype), x) + return tf.stack(x, axis=axis) + + +def std(x, axis=None, keepdims=False): + x = convert_to_tensor(x) + ori_dtype = standardize_dtype(x.dtype) + if "int" in ori_dtype or ori_dtype == "bool": + x = tf.cast(x, config.floatx()) + return tf.math.reduce_std(x, axis=axis, keepdims=keepdims) + + +def swapaxes(x, axis1, axis2): + x = convert_to_tensor(x) + + if ( + x.shape.rank is not None + and isinstance(axis1, int) + and isinstance(axis2, int) + ): + # This branch makes sure `perm` is statically known, to avoid a + # not-compile-time-constant XLA error. + axis1 = canonicalize_axis(axis1, x.ndim) + axis2 = canonicalize_axis(axis2, x.ndim) + + # Directly return x if no movement is required + if axis1 == axis2: + return x + + perm = list(range(x.ndim)) + perm[axis1] = axis2 + perm[axis2] = axis1 + else: + x_rank = tf.rank(x) + axis1 = tf.where(axis1 < 0, tf.add(axis1, x_rank), axis1) + axis2 = tf.where(axis2 < 0, tf.add(axis2, x_rank), axis2) + perm = tf.range(x_rank) + perm = tf.tensor_scatter_nd_update( + perm, [[axis1], [axis2]], [axis2, axis1] + ) + return tf.transpose(x, perm) + + +def take(x, indices, axis=None): + x = convert_to_tensor(x) + if axis is None: + x = tf.reshape(x, (-1,)) + axis = 0 + + def fix_negative_indices(i): + # Correct the indices using "fill" mode which is the same as in jax + return tf.where(i < 0, i + tf.cast(tf.shape(x)[axis], i.dtype), i) + + if isinstance(indices, tf.SparseTensor): + if x.dtype not in (tf.float16, tf.float32, tf.float64, tf.bfloat16): + warnings.warn( + "`take` with the TensorFlow backend does not support " + f"`x.dtype={x.dtype}` when `indices` is a sparse tensor; " + "densifying `indices`." + ) + indices = convert_to_tensor(indices, sparse=False) + elif axis != 0: + warnings.warn( + "`take` with the TensorFlow backend does not support " + f"`axis={axis}` when `indices` is a sparse tensor; " + "densifying `indices`." + ) + indices = convert_to_tensor(indices, sparse=False) + else: + indices = sparse.sparse_with_values( + indices, fix_negative_indices(indices.values) + ) + # `expand_dims` on `indices` prevents combiner from being applied. + output = tf.nn.safe_embedding_lookup_sparse( + embedding_weights=tf.convert_to_tensor(x), + sparse_ids=tf.sparse.expand_dims(indices, axis=-1), + default_id=0, + ) + output.set_shape(indices.shape + output.shape[len(indices.shape) :]) + return output + elif isinstance(indices, tf.RaggedTensor): + indices = indices.with_values(fix_negative_indices(indices.values)) + if axis == 0: + return tf.nn.embedding_lookup(x, indices) + else: + return tf.gather(x, indices, axis=axis) + + indices = fix_negative_indices(convert_to_tensor(indices)) + return tf.gather(x, indices, axis=axis) + + +def take_along_axis(x, indices, axis=None): + from keras.src.ops import operation_utils + + x = convert_to_tensor(x) + indices = convert_to_tensor(indices, "int64") + if axis is None: + if indices.ndim != 1: + raise ValueError( + "`indices` must be 1D if axis=None. " + f"Received: indices.shape={indices.shape}" + ) + return take_along_axis(tf.reshape(x, [-1]), indices, 0) + + # Compute the static output shape as later on, all shapes manipulations + # use dynamic shapes. + static_output_shape = operation_utils.compute_take_along_axis_output_shape( + x.shape, indices.shape, axis + ) + rank = x.ndim + static_axis = axis + axis = axis + rank if axis < 0 else axis + + if axis >= rank: + raise ValueError(f"Invalid axis: {static_axis} for input rank: {rank}") + + x_original_shape = shape_op(x) + indices_original_shape = shape_op(indices) + + # Broadcast the static shapes first, but not for the `axis` dimension. + x_static_shape = list(x.shape) + indices_static_shape = list(indices.shape) + x_static_shape[axis] = 1 + indices_static_shape[axis] = 1 + broadcast_shape = operation_utils.broadcast_shapes( + x_static_shape, indices_static_shape + ) + + if None in broadcast_shape: + # Dynamic broadcast case. Note that `tf.broadcast_dynamic_shape` is + # not always XLA compilable with dynamic dimensions. + # We replace `None`s with the dynamic dimensions. + # `maximum` is the correct formula only when shapes are broadcastable, + # we rely on the broacast itself to fail in the incorrect case rather + # than make some expensive dynamic checks here. + broadcast_shape = [ + tf.maximum(x_original_shape[i], indices_original_shape[i]) + if dim is None + else dim + for i, dim in enumerate(broadcast_shape) + ] + + x_shape = list(broadcast_shape) + x_shape[axis] = x_original_shape[axis] + indices_shape = list(broadcast_shape) + indices_shape[axis] = indices_original_shape[axis] + x = tf.broadcast_to(x, x_shape) + indices = tf.broadcast_to(indices, indices_shape) + + # Correct the indices using "fill" mode which is the same as in jax + indices = tf.where( + indices < 0, + indices + tf.cast(x_shape[static_axis], dtype=indices.dtype), + indices, + ) + + x = swapaxes(x, static_axis, -1) + indices = swapaxes(indices, static_axis, -1) + + x_shape = tf.shape(x) + x = tf.reshape(x, [-1, x_shape[-1]]) + indices_shape = tf.shape(indices) + indices = tf.reshape(indices, [-1, indices_shape[-1]]) + + result = tf.gather(x, indices, batch_dims=1) + result = tf.reshape(result, indices_shape) + result = swapaxes(result, static_axis, -1) + result.set_shape(static_output_shape) + return result + + +@sparse.elementwise_unary +def tan(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = tf.cast(x, dtype) + return tf.math.tan(x) + + +@sparse.elementwise_unary +def tanh(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = tf.cast(x, dtype) + return tf.math.tanh(x) + + +def tensordot(x1, x2, axes=2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + result_dtype = dtypes.result_type(x1.dtype, x2.dtype) + # TODO: tf.tensordot only supports float types + compute_dtype = dtypes.result_type(result_dtype, float) + x1 = tf.cast(x1, compute_dtype) + x2 = tf.cast(x2, compute_dtype) + return tf.cast(tf.tensordot(x1, x2, axes=axes), dtype=result_dtype) + + +@sparse.elementwise_unary +def round(x, decimals=0): + if decimals == 0: + return tf.round(x) + x_dtype = x.dtype + if tf.as_dtype(x_dtype).is_integer: + # int + if decimals > 0: + return x + # temporarily convert to floats + factor = tf.cast(math.pow(10, decimals), config.floatx()) + x = tf.cast(x, config.floatx()) + else: + # float + factor = tf.cast(math.pow(10, decimals), x.dtype) + x = tf.multiply(x, factor) + x = tf.round(x) + x = tf.divide(x, factor) + return tf.cast(x, x_dtype) + + +def tile(x, repeats): + x = convert_to_tensor(x) + repeats = tf.reshape(convert_to_tensor(repeats, dtype="int32"), [-1]) + repeats_size = tf.size(repeats) + repeats = tf.pad( + repeats, + [[tf.maximum(x.shape.rank - repeats_size, 0), 0]], + constant_values=1, + ) + x_shape = tf.pad( + tf.shape(x), + [[tf.maximum(repeats_size - x.shape.rank, 0), 0]], + constant_values=1, + ) + x = tf.reshape(x, x_shape) + return tf.tile(x, repeats) + + +def trace(x, offset=0, axis1=0, axis2=1): + x = convert_to_tensor(x) + dtype = standardize_dtype(x.dtype) + if dtype not in ("int64", "uint32", "uint64"): + dtype = dtypes.result_type(dtype, "int32") + x_shape = tf.shape(x) + x = moveaxis(x, (axis1, axis2), (-2, -1)) + # Mask out the diagonal and reduce. + x = tf.where( + eye(x_shape[axis1], x_shape[axis2], k=offset, dtype="bool"), + x, + tf.zeros_like(x), + ) + # The output dtype is set to "int32" if the input dtype is "bool" + if standardize_dtype(x.dtype) == "bool": + x = tf.cast(x, "int32") + return tf.cast(tf.reduce_sum(x, axis=(-2, -1)), dtype) + + +def tri(N, M=None, k=0, dtype=None): + M = M if M is not None else N + dtype = standardize_dtype(dtype or config.floatx()) + if k < 0: + lower = -k - 1 + if lower > N: + r = tf.zeros([N, M], dtype=dtype) + else: + o = tf.ones([N, M], dtype="bool") + r = tf.cast( + tf.logical_not(tf.linalg.band_part(o, lower, -1)), dtype=dtype + ) + else: + o = tf.ones([N, M], dtype=dtype) + if k > M: + r = o + else: + r = tf.linalg.band_part(o, -1, k) + return r + + +def tril(x, k=0): + x = convert_to_tensor(x) + + def _negative_k_branch(): + shape = tf.shape(x) + rows, cols = shape[-2], shape[-1] + i, j = tf.meshgrid(tf.range(rows), tf.range(cols), indexing="ij") + mask = i >= j - k + return tf.where(tf.broadcast_to(mask, shape), x, tf.zeros_like(x)) + + if isinstance(k, int): + if k >= 0: + return tf.linalg.band_part(x, -1, k) + return _negative_k_branch() + + # when `k` is a tensor + return tf.cond( + tf.greater_equal(k, 0), + lambda: tf.linalg.band_part(x, -1, k), + _negative_k_branch, + ) + + +def triu(x, k=0): + x = convert_to_tensor(x) + + def _positive_k_branch(): + shape = tf.shape(x) + rows, cols = shape[-2], shape[-1] + i, j = tf.meshgrid(tf.range(rows), tf.range(cols), indexing="ij") + mask = i <= j - k + return tf.where(tf.broadcast_to(mask, shape), x, tf.zeros_like(x)) + + if isinstance(k, int): + if k <= 0: + return tf.linalg.band_part(x, -k, -1) + return _positive_k_branch() + + # when `k` is a tensor + return tf.cond( + tf.less_equal(k, 0), + lambda: tf.linalg.band_part(x, -k, -1), + _positive_k_branch, + ) + + +def trunc(x): + x = convert_to_tensor(x) + dtype = standardize_dtype(x.dtype) + if dtype == "bool" or "int" in dtype: + return x + return tf.where(x < 0, tf.math.ceil(x), tf.math.floor(x)) + + +def vdot(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + result_dtype = dtypes.result_type(x1.dtype, x2.dtype) + compute_dtype = dtypes.result_type(result_dtype, float) + x1 = tf.cast(x1, compute_dtype) + x2 = tf.cast(x2, compute_dtype) + x1 = tf.reshape(x1, [-1]) + x2 = tf.reshape(x2, [-1]) + return tf.cast(dot(x1, x2), result_dtype) + + +def inner(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + result_dtype = dtypes.result_type(x1.dtype, x2.dtype) + compute_dtype = dtypes.result_type(result_dtype, float) + x1 = tf.cast(x1, compute_dtype) + x2 = tf.cast(x2, compute_dtype) + x = tf.cond( + tf.math.logical_or( + tf.math.equal(tf.rank(x1), 0), + tf.math.equal(tf.rank(x2), 0), + ), + lambda: x1 * x2, + lambda: tf.tensordot(x1, x2, axes=[[-1], [-1]]), + ) + return tf.cast(x, result_dtype) + + +def vstack(xs): + dtype_set = set([getattr(x, "dtype", type(x)) for x in xs]) + if len(dtype_set) > 1: + dtype = dtypes.result_type(*dtype_set) + xs = tree.map_structure(lambda x: convert_to_tensor(x, dtype), xs) + return tf.concat(xs, axis=0) + + +def _vmap_fn(fn, in_axes=0): + if in_axes != 0: + raise ValueError( + "Not supported with `vectorize()` with the TensorFlow backend." + ) + + @functools.wraps(fn) + def wrapped(x): + return tf.vectorized_map(fn, x) + + return wrapped + + +def vectorize(pyfunc, *, excluded=None, signature=None): + return vectorize_impl( + pyfunc, _vmap_fn, excluded=excluded, signature=signature + ) + + +def where(condition, x1=None, x2=None): + condition = tf.cast(condition, "bool") + if x1 is not None and x2 is not None: + if not isinstance(x1, (int, float)): + x1 = convert_to_tensor(x1) + if not isinstance(x2, (int, float)): + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + ) + x1 = convert_to_tensor(x1, dtype) + x2 = convert_to_tensor(x2, dtype) + return tf.where(condition, x1, x2) + if x1 is None and x2 is None: + return nonzero(condition) + raise ValueError( + "`x1` and `x2` either both should be `None`" + " or both should have non-None value." + ) + + +@sparse.elementwise_division +def divide(x1, x2): + if not isinstance(x1, (int, float)): + x1 = convert_to_tensor(x1) + if not isinstance(x2, (int, float)): + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + float, + ) + x1 = convert_to_tensor(x1, dtype) + x2 = convert_to_tensor(x2, dtype) + return tf.divide(x1, x2) + + +def divide_no_nan(x1, x2): + if not isinstance(x1, (int, float)): + x1 = convert_to_tensor(x1) + if not isinstance(x2, (int, float)): + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + float, + ) + x1 = convert_to_tensor(x1, dtype) + x2 = convert_to_tensor(x2, dtype) + return tf.math.divide_no_nan(x1, x2) + + +def true_divide(x1, x2): + return divide(x1, x2) + + +def power(x1, x2): + if not isinstance(x1, (int, float)): + x1 = convert_to_tensor(x1) + if not isinstance(x2, (int, float)): + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + ) + # TODO: tf.pow doesn't support uint* types + if "uint" in dtype: + x1 = convert_to_tensor(x1, "int32") + x2 = convert_to_tensor(x2, "int32") + return tf.cast(tf.pow(x1, x2), dtype) + x1 = convert_to_tensor(x1, dtype) + x2 = convert_to_tensor(x2, dtype) + return tf.pow(x1, x2) + + +@sparse.elementwise_unary +def negative(x): + return tf.negative(x) + + +@sparse.elementwise_unary +def square(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "bool": + x = tf.cast(x, "int32") + return tf.square(x) + + +@sparse.elementwise_unary +def sqrt(x): + x = convert_to_tensor(x) + dtype = ( + config.floatx() + if standardize_dtype(x.dtype) == "int64" + else dtypes.result_type(x.dtype, float) + ) + x = tf.cast(x, dtype) + return tf.math.sqrt(x) + + +def squeeze(x, axis=None): + x = convert_to_tensor(x) + axis = to_tuple_or_list(axis) + static_shape = x.shape.as_list() + if axis is not None: + for a in axis: + if static_shape[a] != 1: + raise ValueError( + f"Cannot squeeze axis={a}, because the dimension is not 1." + ) + axis = sorted([canonicalize_axis(a, len(static_shape)) for a in axis]) + if isinstance(x, tf.SparseTensor): + dynamic_shape = tf.shape(x) + new_shape = [] + gather_indices = [] + for i, dim in enumerate(static_shape): + if not (dim == 1 if axis is None else i in axis): + new_shape.append(dim if dim is not None else dynamic_shape[i]) + gather_indices.append(i) + new_indices = tf.gather(x.indices, gather_indices, axis=1) + return tf.SparseTensor(new_indices, x.values, tuple(new_shape)) + return tf.squeeze(x, axis=axis) + + +def transpose(x, axes=None): + if isinstance(x, tf.SparseTensor): + from keras.src.ops.operation_utils import compute_transpose_output_shape + + output = tf.sparse.transpose(x, perm=axes) + output.set_shape(compute_transpose_output_shape(x.shape, axes)) + return output + return tf.transpose(x, perm=axes) + + +def var(x, axis=None, keepdims=False): + x = convert_to_tensor(x) + compute_dtype = dtypes.result_type(x.dtype, "float32") + result_dtype = dtypes.result_type(x.dtype, float) + x = tf.cast(x, compute_dtype) + return tf.cast( + tf.math.reduce_variance(x, axis=axis, keepdims=keepdims), + result_dtype, + ) + + +def sum(x, axis=None, keepdims=False): + x = convert_to_tensor(x) + dtype = standardize_dtype(x.dtype) + # follow jax's rule + if dtype in ("bool", "int8", "int16"): + dtype = "int32" + elif dtype in ("uint8", "uint16"): + dtype = "uint32" + x = cast(x, dtype) + if isinstance(x, tf.SparseTensor): + return tf.sparse.reduce_sum( + x, axis=axis, keepdims=keepdims, output_is_sparse=True + ) + return tf.reduce_sum(x, axis=axis, keepdims=keepdims) + + +def eye(N, M=None, k=0, dtype=None): + dtype = dtype or config.floatx() + M = N if M is None else M + if isinstance(k, int) and k == 0: + return tf.eye(N, M, dtype=dtype) + # Create a smaller square eye and pad appropriately. + return tf.pad( + tf.eye(tf.minimum(M - k, N + k), dtype=dtype), + paddings=( + (tf.maximum(-k, 0), tf.maximum(N - M + k, 0)), + (tf.maximum(k, 0), tf.maximum(M - N - k, 0)), + ), + ) + + +def floor_divide(x1, x2): + if not isinstance(x1, (int, float)): + x1 = convert_to_tensor(x1) + if not isinstance(x2, (int, float)): + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + ) + x1 = convert_to_tensor(x1, dtype) + x2 = convert_to_tensor(x2, dtype) + return tf.math.floordiv(x1, x2) + + +def logical_xor(x1, x2): + x1 = tf.cast(x1, "bool") + x2 = tf.cast(x2, "bool") + return tf.math.logical_xor(x1, x2) + + +def corrcoef(x): + dtype = x.dtype + if dtype in ["bool", "int8", "int16", "int32", "uint8", "uint16", "uint32"]: + dtype = config.floatx() + x = convert_to_tensor(x, dtype) + + if tf.rank(x) == 0: + return tf.constant(float("nan"), dtype=config.floatx()) + + mean = tf.reduce_mean(x, axis=-1, keepdims=True) + x_centered = x - mean + + num_samples = tf.cast(tf.shape(x)[-1], x.dtype) + cov_matrix = tf.matmul(x_centered, x_centered, adjoint_b=True) / ( + num_samples - 1 + ) + + diag = tf.linalg.diag_part(cov_matrix) + stddev = tf.sqrt(tf.math.real(diag)) + + outer_std = tf.tensordot(stddev, stddev, axes=0) + outer_std = tf.cast(outer_std, cov_matrix.dtype) + correlation = cov_matrix / outer_std + + correlation_clipped = tf.clip_by_value(tf.math.real(correlation), -1.0, 1.0) + if correlation.dtype.is_complex: + imag_clipped = tf.clip_by_value(tf.math.imag(correlation), -1.0, 1.0) + return tf.complex(correlation_clipped, imag_clipped) + else: + return correlation_clipped + + +def correlate(x1, x2, mode="valid"): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + + dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + ) + if dtype == tf.int64: + dtype = tf.float64 + elif dtype not in [tf.bfloat16, tf.float16, tf.float64]: + dtype = tf.float32 + + x1 = tf.cast(x1, dtype) + x2 = tf.cast(x2, dtype) + + x1_len, x2_len = int(x1.shape[0]), int(x2.shape[0]) + + if mode == "full": + full_len = x1_len + x2_len - 1 + + x1_pad = (full_len - x1_len) / 2 + x2_pad = (full_len - x2_len) / 2 + + x1 = tf.pad( + x1, paddings=[[tf.math.floor(x1_pad), tf.math.ceil(x1_pad)]] + ) + x2 = tf.pad( + x2, paddings=[[tf.math.floor(x2_pad), tf.math.ceil(x2_pad)]] + ) + + x1 = tf.reshape(x1, (1, full_len, 1)) + x2 = tf.reshape(x2, (full_len, 1, 1)) + + return tf.squeeze(tf.nn.conv1d(x1, x2, stride=1, padding="SAME")) + + x1 = tf.reshape(x1, (1, x1_len, 1)) + x2 = tf.reshape(x2, (x2_len, 1, 1)) + + return tf.squeeze(tf.nn.conv1d(x1, x2, stride=1, padding=mode.upper())) + + +def select(condlist, choicelist, default=0): + return tf.experimental.numpy.select(condlist, choicelist, default=default) + + +def slogdet(x): + x = convert_to_tensor(x) + return tuple(tf.linalg.slogdet(x)) + + +def argpartition(x, kth, axis=-1): + x = convert_to_tensor(x, tf.int32) + + x = swapaxes(x, axis, -1) + bottom_ind = tf.math.top_k(-x, kth + 1).indices + + n = tf.shape(x)[-1] + + mask = tf.reduce_sum(tf.one_hot(bottom_ind, n, dtype=tf.int32), axis=0) + + indices = tf.where(mask) + updates = tf.squeeze(tf.zeros(tf.shape(indices)[0], dtype=tf.int32)) + + final_mask = tf.tensor_scatter_nd_update(x, indices, updates) + + top_ind = tf.math.top_k(final_mask, tf.shape(x)[-1] - kth - 1).indices + + out = tf.concat([bottom_ind, top_ind], axis=x.ndim - 1) + return swapaxes(out, -1, axis) + + +def histogram(x, bins=10, range=None): + """Computes a histogram of the data tensor `x`. + + Note: the `tf.histogram_fixed_width()` and + `tf.histogram_fixed_width_bins()` functions + yield slight numerical differences for some edge cases. + """ + + x = tf.convert_to_tensor(x, dtype=x.dtype) + + # Handle the range argument + if range is None: + min_val = tf.reduce_min(x) + max_val = tf.reduce_max(x) + else: + min_val, max_val = range + + x = tf.boolean_mask(x, (x >= min_val) & (x <= max_val)) + bin_edges = tf.linspace(min_val, max_val, bins + 1) + bin_edges = tf.cast(bin_edges, x.dtype) + bin_indices = tf.searchsorted(bin_edges[1:-1], x, side="right") + + # tf.math.bincount does not work with XLA in this case. So, we use + # `scatter_nd`. + bin_counts = tf.scatter_nd( + indices=tf.expand_dims(bin_indices, axis=-1), + updates=tf.ones_like(bin_indices, dtype=x.dtype), + shape=(bins,), + ) + return bin_counts, bin_edges diff --git a/keras/src/backend/tensorflow/optimizer.py b/keras/src/backend/tensorflow/optimizer.py new file mode 100644 index 000000000000..f4497543d6ab --- /dev/null +++ b/keras/src/backend/tensorflow/optimizer.py @@ -0,0 +1,253 @@ +"""A class for Tensorflow specific optimizer logic. + +The major behavior change for this class is for tf.distribute. + +It will override methods from base Keras core Optimizer, +which provide distribute specific functionality, e.g. variable +creation, loss reduction, etc. +""" + +import warnings + +import tensorflow as tf + +from keras.src import backend +from keras.src.backend.tensorflow.trackable import KerasAutoTrackable +from keras.src.optimizers import base_optimizer + + +class TFOptimizer(KerasAutoTrackable, base_optimizer.BaseOptimizer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._distribution_strategy = tf.distribute.get_strategy() + + def add_variable_from_reference( + self, reference_variable, name=None, initializer="zeros" + ): + if isinstance(reference_variable, backend.Variable): + colocate_var = reference_variable.value + else: + colocate_var = reference_variable + + with self._distribution_strategy.extended.colocate_vars_with( + colocate_var + ): + return super().add_variable_from_reference( + reference_variable, name=name, initializer=initializer + ) + + def stateless_apply(self, optimizer_variables, grads, trainable_variables): + # This is mainly due to the interaction with tf.distribute.Strategy, + # which requires tf.Variable as the inputs for most of its APIs. + raise ValueError( + "stateless_apply is not supported with the TensorFlow backend " + "(as it is incompatible with tf.distribute)." + ) + + def assign(self, variable, value): + if isinstance(variable, backend.Variable): + variable = variable.value + value = tf.cast(value, variable.dtype) + if isinstance(value, tf.IndexedSlices): + variable.scatter_update(value) + else: + variable.assign(value) + + def assign_add(self, variable, value): + if isinstance(variable, backend.Variable): + variable = variable.value + value = tf.cast(value, variable.dtype) + if isinstance(value, tf.IndexedSlices): + variable.scatter_add(value) + else: + variable.assign_add(value) + + def assign_sub(self, variable, value): + if isinstance(variable, backend.Variable): + variable = variable.value + value = tf.cast(value, variable.dtype) + if isinstance(value, tf.IndexedSlices): + variable.scatter_sub(value) + else: + variable.assign_sub(value) + + def _var_key(self, variable): + if isinstance(variable, backend.Variable): + variable = variable.value # Convert to tf.Variable + if hasattr(variable, "_distributed_container"): + variable = variable._distributed_container() + elif ( + isinstance(variable, tf.__internal__.CompositeTensor) + and hasattr(variable, "handle") + and hasattr(variable.handle, "_distributed_container") + ): + # For ResourceVariables, the _distributed_container attribute + # is added to their handle tensors. + variable = variable.handle._distributed_container() + return variable._unique_id + + def _apply_weight_decay(self, variables): + if self.weight_decay is None: + return + + def distributed_apply_weight_decay(distribution, variables, **kwargs): + def weight_decay_fn(variable): + if self._use_weight_decay(variable): + lr = tf.cast(self.learning_rate, variable.dtype) + wd = tf.cast(self.weight_decay, variable.dtype) + variable.assign_sub(variable * wd * lr) + + for variable in variables: + if isinstance(variable, backend.Variable): + variable = variable.value # Convert to tf.Variable + distribution.extended.update( + variable, weight_decay_fn, group=False + ) + + tf.__internal__.distribute.interim.maybe_merge_call( + distributed_apply_weight_decay, + self._distribution_strategy, + variables, + ) + + def _backend_update_step(self, grads, trainable_variables, learning_rate): + trainable_variables = [ + v.value if isinstance(v, backend.Variable) else v + for v in trainable_variables + ] + grads_and_vars = list(zip(grads, trainable_variables)) + grads_and_vars = self._all_reduce_sum_gradients(grads_and_vars) + tf.__internal__.distribute.interim.maybe_merge_call( + self._distributed_tf_update_step, + self._distribution_strategy, + grads_and_vars, + learning_rate, + ) + + def _distributed_tf_update_step( + self, distribution, grads_and_vars, learning_rate + ): + def apply_grad_to_update_var(var, grad, learning_rate): + return self.update_step(grad, var, learning_rate) + + for grad, var in grads_and_vars: + distribution.extended.update( + var, + apply_grad_to_update_var, + args=(grad, learning_rate), + group=False, + ) + + def _all_reduce_sum_gradients(self, grads_and_vars): + """Returns all-reduced gradients aggregated via summation. + + Args: + grads_and_vars: List of (gradient, variable) pairs. + + Returns: + List of (gradient, variable) pairs + where gradients have been all-reduced. + """ + replica_context = tf.distribute.get_replica_context() + if not replica_context: + return grads_and_vars + + grads_and_vars = list(grads_and_vars) + filtered_grads_and_vars = filter_empty_gradients(grads_and_vars) + if filtered_grads_and_vars: + grads = [pair[0] for pair in filtered_grads_and_vars] + reduced = tf.distribute.get_replica_context().all_reduce( + tf.distribute.ReduceOp.SUM, grads + ) + else: + reduced = [] + # Copy 'reduced' but add None gradients back in + reduced_with_nones = [] + reduced_pos = 0 + for g, v in grads_and_vars: + if g is None: + reduced_with_nones.append((None, v)) + else: + reduced_with_nones.append((reduced[reduced_pos], v)) + reduced_pos += 1 + assert reduced_pos == len(reduced), "Failed to add all gradients" + return reduced_with_nones + + def _overwrite_model_variables_with_average_value( + self, trainable_variables + ): + """Overwrite model variables with their moving average values. + + This function overwrites variables on each device. + + Args: + var_list: list of model variables. + """ + trainable_variables = [ + v.value if isinstance(v, backend.Variable) else v + for v in trainable_variables + ] + # Override model variable by the stored average value on all devices. + for var, average_var in zip( + trainable_variables, self._model_variables_moving_average + ): + self._distribution_strategy.extended.update( + var, lambda a, b: a.assign(b), args=(average_var,) + ) + + def _backend_increment_gradient_accumulators(self, grads, acc_grads): + def update_accumulator(var, grad): + var.assign(var + grad) + + accumulators = [v.value for v in acc_grads] + + def _distributed_tf_increment_grad_acc( + distribution, grads, accumulators + ): + for grad, var in zip(grads, accumulators): + distribution.extended.update( + var, update_accumulator, args=(grad,), group=False + ) + + tf.__internal__.distribute.interim.maybe_merge_call( + _distributed_tf_increment_grad_acc, + self._distribution_strategy, + grads, + accumulators, + ) + + def _clip_by_norm(self, values, axes=None): + # We need to use TF-specific OP to support the case, + # when `values` are `tf.IndexedSlices`. + return tf.clip_by_norm(values, self.clipnorm, axes) + + +def filter_empty_gradients(grads_and_vars): + """Filter out `(grad, var)` pairs that have a gradient equal to `None`.""" + grads_and_vars = tuple(grads_and_vars) + if not grads_and_vars: + return grads_and_vars + + filtered = [] + vars_with_empty_grads = [] + for grad, var in grads_and_vars: + if grad is None: + vars_with_empty_grads.append(var) + else: + filtered.append((grad, var)) + filtered = tuple(filtered) + + if not filtered: + variable = ([v.name for _, v in grads_and_vars],) + raise ValueError( + f"No gradients provided for any variable: {variable}. " + f"Provided `grads_and_vars` is {grads_and_vars}." + ) + if vars_with_empty_grads: + warnings.warn( + "Gradients do not exist for variables %s when minimizing the " + "loss. If you're using `model.compile()`, did you forget to " + "provide a `loss` argument?", + ([v.name for v in vars_with_empty_grads]), + ) + return filtered diff --git a/keras/src/backend/tensorflow/optimizer_distribute_test.py b/keras/src/backend/tensorflow/optimizer_distribute_test.py new file mode 100644 index 000000000000..fe31af4366a5 --- /dev/null +++ b/keras/src/backend/tensorflow/optimizer_distribute_test.py @@ -0,0 +1,212 @@ +# flake8: noqa + +import numpy as np +import pytest +import tensorflow as tf +from absl.testing import parameterized +from tensorflow.python.eager import context + +from keras.src import backend +from keras.src import testing +from keras.src.optimizers.sgd import SGD + + +@pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="The distribute test can only run with TF backend.", +) +class OptimizerDistributeTest(testing.TestCase): + def setUp(self): + super().setUp() + # Need at least 2 devices for distribution related tests. + cpus = tf.config.list_physical_devices("CPU") + context._reset_context() + tf.config.set_logical_device_configuration( + cpus[0], + [ + tf.config.LogicalDeviceConfiguration(), + tf.config.LogicalDeviceConfiguration(), + ], + ) + self.strategy = tf.distribute.MirroredStrategy(["CPU:0", "CPU:1"]) + + def test_config(self): + with self.strategy.scope(): + optimizer = SGD( + learning_rate=0.5, + momentum=0.06, + nesterov=True, + weight_decay=0.004, + ) + self.run_class_serialization_test(optimizer) + + @parameterized.parameters([("keras_sgd",), ("tf_keras_sgd",)]) + def test_single_step(self, optimizer_type): + if optimizer_type == "tf_keras_sgd": + try: + import tf_keras + + optimizer_fn = tf_keras.optimizers.SGD + except (ImportError, AttributeError): + self.skipTest("tf_keras not installed") + else: + optimizer_fn = SGD + with self.strategy.scope(): + optimizer = optimizer_fn( + learning_rate=0.5, + momentum=0.06, + ) + # use tf variable to work both in k2 & k3. + vars = tf.Variable([1.0, 2.0, 3.0, 4.0]) + + def update(): + grads = tf.constant([1.0, 6.0, 7.0, 2.0]) + optimizer.apply_gradients(zip([grads], [vars])) + + self.strategy.run(update) + self.assertAllClose( + vars, [0.0, -4.0, -4.0, 2.0], rtol=1e-4, atol=1e-4 + ) + + def test_weight_decay(self): + with self.strategy.scope(): + grads, var1, var2, var3 = ( + tf.zeros(()), + backend.Variable(2.0), + backend.Variable(3.0, name="exclude"), + backend.Variable(4.0), + ) + optimizer_1 = SGD(learning_rate=1.0, weight_decay=0.004) + self.strategy.run( + lambda: optimizer_1.apply_gradients(zip([grads], [var1])) + ) + + optimizer_2 = SGD(learning_rate=1.0, weight_decay=0.004) + + def opt2_run(): + optimizer_2.exclude_from_weight_decay(var_names=["exclude"]) + optimizer_2.apply_gradients(zip([grads, grads], [var1, var2])) + + self.strategy.run(opt2_run) + + optimizer_3 = SGD(learning_rate=1.0, weight_decay=0.004) + + def opt3_run(): + optimizer_3.exclude_from_weight_decay(var_list=[var3]) + optimizer_3.apply_gradients(zip([grads, grads], [var1, var3])) + + self.strategy.run(opt3_run) + + self.assertAlmostEqual(var1.numpy(), 1.9760959) + self.assertAlmostEqual(var2.numpy(), 3.0) + self.assertAlmostEqual(var3.numpy(), 4.0) + + def test_correctness_with_golden(self): + with self.strategy.scope(): + optimizer = SGD(nesterov=True) + x = backend.Variable(np.ones([10])) + + def update_grads(): + grads = backend.convert_to_tensor(np.arange(0.1, 1.1, 0.1)) + optimizer.apply_gradients(zip([grads], [x])) + + def update_first_grads(): + first_grads = backend.convert_to_tensor(np.full((10,), 0.01)) + optimizer.apply_gradients(zip([first_grads], [x])) + + # fmt: off + golden = np.array( + [ + [0.9980, 0.9960, 0.9940, 0.9920, 0.9900, 0.9880, 0.9860, 0.9840, 0.9820, 0.9800], + [0.9978, 0.9958, 0.9938, 0.9918, 0.9898, 0.9878, 0.9858, 0.9838, 0.9818, 0.9798], + [0.9976, 0.9956, 0.9936, 0.9916, 0.9896, 0.9876, 0.9856, 0.9836, 0.9816, 0.9796], + [0.9974, 0.9954, 0.9934, 0.9914, 0.9894, 0.9874, 0.9854, 0.9834, 0.9814, 0.9794], + [0.9972, 0.9952, 0.9932, 0.9912, 0.9892, 0.9872, 0.9852, 0.9832, 0.9812, 0.9792], + ] + ) + # fmt: on + + self.strategy.run(update_grads) + for i in range(5): + self.assertAllClose(x, golden[i], rtol=5e-4, atol=5e-4) + self.strategy.run(update_first_grads) + + def test_clip_norm(self): + with self.strategy.scope(): + optimizer = SGD(clipnorm=1) + grad = [np.array([100.0, 100.0])] + clipped_grad = optimizer._clip_gradients(grad) + self.assertAllClose(clipped_grad[0], [2**0.5 / 2, 2**0.5 / 2]) + + def test_clip_value(self): + with self.strategy.scope(): + optimizer = SGD(clipvalue=1) + grad = [np.array([100.0, 100.0])] + clipped_grad = optimizer._clip_gradients(grad) + self.assertAllClose(clipped_grad[0], [1.0, 1.0]) + + def test_stateless_not_supported(self): + optimizer = SGD(learning_rate=0.5) + grads = [np.array([1.0, 6.0, 7.0, 2.0])] + vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])] + optimizer.build(vars) + with self.assertRaisesRegex(ValueError, "not supported"): + optimizer.stateless_apply(optimizer.variables, grads, vars) + + def test_ema(self): + with self.strategy.scope(): + v = backend.Variable([[3.0, 4.0], [5.0, 6.0]]) + grads = backend.convert_to_tensor([[1.0, 1.0], [1.0, 1.0]]) + optimizer = SGD( + learning_rate=1.0, + use_ema=True, + ema_momentum=0.9, + ema_overwrite_frequency=3, + ) + self.strategy.run(lambda: optimizer.apply_gradients([(grads, v)])) + self.assertAllClose(v, [[2.0, 3.0], [4.0, 5.0]]) + self.assertAllClose( + optimizer._model_variables_moving_average[0], + [[2.0, 3.0], [4.0, 5.0]], # initialized after first step + ) + self.strategy.run(lambda: optimizer.apply_gradients([(grads, v)])) + self.assertAllClose(v, [[1.0, 2.0], [3.0, 4.0]]) + self.assertAllClose( + optimizer._model_variables_moving_average[0], + [[1.9, 2.9], [3.9, 4.9]], + ) + self.strategy.run(lambda: optimizer.apply_gradients([(grads, v)])) + # Variables were overwritten with EMA + self.assertAllClose(v, [[1.71, 2.71], [3.71, 4.71]]) + self.assertAllClose( + optimizer._model_variables_moving_average[0], + [[1.71, 2.71], [3.71, 4.71]], + ) + + def test_gradient_accumulation(self): + with self.strategy.scope(): + v = backend.Variable([[1.0, 2.0], [3.0, 4.0]]) + grads = backend.convert_to_tensor([[1.0, 1.0], [2.0, 2.0]]) + optimizer = SGD(learning_rate=1.0, gradient_accumulation_steps=3) + self.assertEqual(optimizer.gradient_accumulation_steps, 3) + self.strategy.run(lambda: optimizer.apply_gradients([(grads, v)])) + self.assertAllClose(v, [[1.0, 2.0], [3.0, 4.0]]) + self.assertAllClose( + optimizer._accumulated_gradients[0], [[1.0, 1.0], [2.0, 2.0]] + ) + self.assertAllClose(optimizer._iterations, 1) + self.assertAllClose(optimizer.iterations, 0) + self.strategy.run(lambda: optimizer.apply_gradients([(grads, v)])) + self.assertAllClose(v, [[1.0, 2.0], [3.0, 4.0]]) + self.assertAllClose( + optimizer._accumulated_gradients[0], [[2.0, 2.0], [4.0, 4.0]] + ) + self.assertAllClose(optimizer._iterations, 2) + self.assertAllClose(optimizer.iterations, 0) + self.strategy.run(lambda: optimizer.apply_gradients([(grads, v)])) + self.assertAllClose(v, [[-1.0, 0.0], [-1.0, 0.0]]) + self.assertAllClose( + optimizer._accumulated_gradients[0], [[0.0, 0.0], [0.0, 0.0]] + ) + self.assertAllClose(optimizer._iterations, 3) + self.assertAllClose(optimizer.iterations, 1) diff --git a/keras/src/backend/tensorflow/random.py b/keras/src/backend/tensorflow/random.py new file mode 100644 index 000000000000..e807b0de9aab --- /dev/null +++ b/keras/src/backend/tensorflow/random.py @@ -0,0 +1,190 @@ +import tensorflow as tf + +from keras.src.backend.common import standardize_dtype +from keras.src.backend.config import floatx +from keras.src.random.seed_generator import SeedGenerator +from keras.src.random.seed_generator import draw_seed +from keras.src.random.seed_generator import make_default_seed + + +def _cast_seed(seed): + # TensorFlow has a device placement issue that `Variable` must be int64 + # in `SeedGenerator`. However, all `tf.random.stateless_*` expect the seed + # to be int32 to run with XLA. + # This function addresses the inconsistency using `floormod`. + # Ref: https://www.tensorflow.org/api_docs/python/tf/random + if standardize_dtype(seed.dtype) == "int32": + return seed + else: + seed = tf.cast(tf.math.floormod(seed, tf.int32.max - 1), dtype="int32") + return seed + + +def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): + dtype = dtype or floatx() + seed = _cast_seed(draw_seed(seed)) + return tf.random.stateless_normal( + shape=shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed + ) + + +def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None): + dtype = dtype or floatx() + seed = _cast_seed(draw_seed(seed)) + return tf.random.stateless_uniform( + shape=shape, + minval=tf.cast(minval, dtype), + maxval=tf.cast(maxval, dtype), + dtype=dtype, + seed=seed, + ) + + +def categorical(logits, num_samples, dtype="int64", seed=None): + seed = _cast_seed(draw_seed(seed)) + output = tf.random.stateless_categorical(logits, num_samples, seed=seed) + return tf.cast(output, dtype) + + +def randint(shape, minval, maxval, dtype="int32", seed=None): + intermediate_dtype = dtype + if standardize_dtype(dtype) not in ["int32", "int64"]: + intermediate_dtype = "int64" + seed = _cast_seed(draw_seed(seed)) + output = tf.random.stateless_uniform( + shape=shape, + minval=minval, + maxval=maxval, + dtype=intermediate_dtype, + seed=seed, + ) + return tf.cast(output, dtype) + + +def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): + dtype = dtype or floatx() + seed = _cast_seed(draw_seed(seed)) + return tf.random.stateless_truncated_normal( + shape=shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed + ) + + +def _get_concrete_noise_shape(inputs, noise_shape): + if noise_shape is None: + return tf.shape(inputs) + + concrete_inputs_shape = tf.shape(inputs) + concrete_noise_shape = [] + for i, value in enumerate(noise_shape): + concrete_noise_shape.append( + concrete_inputs_shape[i] if value is None else value + ) + return concrete_noise_shape + + +def dropout(inputs, rate, noise_shape=None, seed=None): + seed = _cast_seed(draw_seed(seed)) + noise_shape = _get_concrete_noise_shape(inputs, noise_shape) + return tf.nn.experimental.stateless_dropout( + inputs, + rate=rate, + noise_shape=noise_shape, + seed=seed, + ) + + +def shuffle(x, axis=0, seed=None): + seed = _cast_seed(draw_seed(seed)) + indices = tf.argsort( + tf.random.stateless_uniform(shape=[tf.shape(x)[axis]], seed=seed) + ) + return tf.gather(x, indices, axis=axis) + + +def gamma(shape, alpha, dtype=None, seed=None): + dtype = dtype or floatx() + seed = _cast_seed(draw_seed(seed)) + # TODO: `tf.random.stateless_gamma` doesn't support bfloat16 + intermediate_dtype = dtype + if standardize_dtype(dtype) == "bfloat16": + intermediate_dtype = "float32" + return tf.cast( + tf.random.stateless_gamma( + shape, + alpha=alpha, + dtype=intermediate_dtype, + seed=seed, + ), + dtype, + ) + + +def binomial(shape, counts, probabilities, dtype=None, seed=None): + dtype = dtype or floatx() + seed = _cast_seed(draw_seed(seed)) + # TODO: `tf.random.stateless_binomial` doesn't support bfloat16 + intermediate_dtype = dtype + if standardize_dtype(dtype) == "bfloat16": + intermediate_dtype = "float32" + return tf.cast( + tf.random.stateless_binomial( + shape=shape, + seed=seed, + counts=counts, + probs=probabilities, + output_dtype=intermediate_dtype, + ), + dtype, + ) + + +def beta(shape, alpha, beta, dtype=None, seed=None): + dtype = dtype or floatx() + # since tensorflow doesn't offer a beta distribution function + # so we'll use the formula U(a,b) = (X(a) / (X(a) + Y(b)), + # where U(a,b) is a beta-distributed random variable with + # parameters a and b, and X(a) and Y(b) are gamma-distributed + # random variables with parameters a and b respectively. + + # Additionally, we'll use two different seeds for our two + # gamma random variables to prevent any unintended + # dependencies and correlations between the generated values + # due to the usage of same seed. + seed_1 = _cast_seed(draw_seed(seed)) + # The choice of 12 is totally arbitrary, as we're + # incrementing the first drawn seed by a CONSTANT to + # ensure deterministic results. + seed_2 = seed_1 + 12 + + # TODO: `tf.random.stateless_gamma` doesn't support bfloat16 + intermediate_dtype = dtype + if standardize_dtype(dtype) == "bfloat16": + intermediate_dtype = "float32" + alpha = tf.convert_to_tensor(alpha, dtype=intermediate_dtype) + beta = tf.convert_to_tensor(beta, dtype=intermediate_dtype) + + # tensorflow's tf.random.stateless_gamma has a bit of unconventional + # implementation of the stateless_gamma function where it checks the + # broadcastability of alpha's shape with ONLY the RIGHTMOST dimension of + # the specified output shape instead of considering the whole. + # Consequently, it then results in errors for perfectly broadcastable shapes + # such as for output shape of (2, 3) and alpha shape of (1, 3) + # So to resolve this, we explicitly broadcast alpha and beta to shape before + # passing them to the stateless_gamma function. + alpha = tf.broadcast_to(alpha, shape) + beta = tf.broadcast_to(beta, shape) + + gamma_a = tf.cast( + tf.random.stateless_gamma( + shape=shape, seed=seed_1, alpha=alpha, dtype=intermediate_dtype + ), + dtype, + ) + gamma_b = tf.cast( + tf.random.stateless_gamma( + shape=shape, seed=seed_2, alpha=beta, dtype=intermediate_dtype + ), + dtype, + ) + sample = gamma_a / (gamma_a + gamma_b) + return sample diff --git a/keras/src/backend/tensorflow/rnn.py b/keras/src/backend/tensorflow/rnn.py new file mode 100644 index 000000000000..06d450a18838 --- /dev/null +++ b/keras/src/backend/tensorflow/rnn.py @@ -0,0 +1,971 @@ +import tensorflow as tf + +from keras.src import tree + + +def rnn( + step_function, + inputs, + initial_states, + go_backwards=False, + mask=None, + constants=None, + unroll=False, + input_length=None, + time_major=False, + zero_output_for_mask=False, + return_all_outputs=True, +): + """Iterates over the time dimension of a tensor. + + Args: + step_function: RNN step function. + Args; + `input`; Tensor with shape `(samples, ...)` (no time dimension), + representing input for the batch of samples at a certain + time step. + `states`; List of tensors. + Returns; + `output`; Tensor with shape `(samples, output_dim)` + (no time dimension). + `new_states`; List of tensors, same length and shapes + as 'states'. The first state in the list must be the + output tensor at the previous timestep. + inputs: Tensor of temporal data of shape `(samples, time, ...)` + (at least 3D), or nested tensors, and each of which has shape + `(samples, time, ...)`. + initial_states: Tensor with shape `(samples, state_size)` + (no time dimension), containing the initial values for the states + used in the step function. In the case that state_size is in a + nested shape, the shape of initial_states will also follow the + nested structure. + go_backwards: Boolean. If `True`, do the iteration over the time + dimension in reverse order and return the reversed sequence. + mask: Binary tensor with shape `(samples, time, 1)`, + with a zero for every element that is masked. + constants: List of constant values passed at each step. + unroll: Whether to unroll the RNN or to use a symbolic `while_loop`. + input_length: An integer or a 1-D Tensor, depending on whether + the time dimension is fixed-length or not. In case of variable + length input, it is used for masking in case there's no mask + specified. + time_major: Boolean. If `True`, the inputs and outputs will be in shape + `(timesteps, batch, ...)`, whereas in the False case, it will be + `(batch, timesteps, ...)`. Using `time_major = True` is a bit more + efficient because it avoids transposes at the beginning and end of + the RNN calculation. However, most TensorFlow data is batch-major, + so by default this function accepts input and emits output in + batch-major form. + zero_output_for_mask: Boolean. If `True`, the output for masked timestep + will be zeros, whereas in the `False` case, output from previous + timestep is returned. + return_all_outputs: Boolean. If `True`, return the recurrent outputs for + all timesteps in the sequence. If `False`, only return the output + for the last timestep (which consumes less memory). + + Returns: + A tuple, `(last_output, outputs, new_states)`. + - `last_output`: the latest output of the rnn, + with shape `(samples, ...)`. + - `outputs`: + - If `return_all_outputs=True`: a tensor with shape + `(samples, time, ...)` where each entry `outputs[s, t]` is the + output of the step function at time `t` for sample `s` + - Else, a tensor equal to `last_output` with shape + `(samples, 1, ...)` + - `new_states`: list of tensors, latest states returned by + the step function, of shape `(samples, ...)`. + """ + input_length = input_length or inputs.shape[1] + + def swap_batch_timestep(input_t): + # Swap the batch and timestep dim for the incoming tensor. + axes = list(range(len(input_t.shape))) + axes[0], axes[1] = 1, 0 + return tf.transpose(input_t, axes) + + if not time_major: + inputs = tree.map_structure(swap_batch_timestep, inputs) + + flattened_inputs = tree.flatten(inputs) + time_steps = flattened_inputs[0].shape[0] + time_steps_t = ( + tf.shape(flattened_inputs[0])[0] if time_steps is None else time_steps + ) + + for input_ in flattened_inputs: + input_.shape.with_rank_at_least(3) + + if mask is not None: + if mask.dtype != tf.bool: + mask = tf.cast(mask, tf.bool) + if len(mask.shape) == 2: + mask = tf.expand_dims(mask, axis=-1) + if not time_major: + mask = swap_batch_timestep(mask) + + if constants is None: + constants = [] + + # tf.where needs its condition tensor to be the same shape as its two + # result tensors, but in our case the condition (mask) tensor is + # (nsamples, 1), and inputs are (nsamples, ndimensions) or even more. + # So we need to broadcast the mask to match the shape of inputs. + # That's what the tile call does, it just repeats the mask along its + # second dimension n times. + def _expand_mask(mask_t, input_t, fixed_dim=1): + if tree.is_nested(mask_t): + raise ValueError( + f"mask_t is expected to be tensor, but got {mask_t}" + ) + if tree.is_nested(input_t): + raise ValueError( + f"input_t is expected to be tensor, but got {input_t}" + ) + rank_diff = len(input_t.shape) - len(mask_t.shape) + for _ in range(rank_diff): + mask_t = tf.expand_dims(mask_t, -1) + multiples = [1] * fixed_dim + input_t.shape.as_list()[fixed_dim:] + return tf.tile(mask_t, multiples) + + if unroll: + if not time_steps: + raise ValueError("Unrolling requires a fixed number of timesteps.") + states = tuple(initial_states) + successive_states = [] + successive_outputs = [] + + # Process the input tensors. The input tensor need to be split on the + # time_step dim, and reverse if go_backwards is True. In the case of + # nested input, the input is flattened and then transformed + # individually. The result of this will be a tuple of lists, each of + # the item in tuple is list of the tensor with shape (batch, feature) + def _process_single_input_t(input_t): + input_t = tf.unstack(input_t) # unstack for time_step dim + if go_backwards: + input_t.reverse() + return input_t + + if tree.is_nested(inputs): + processed_input = tree.map_structure( + _process_single_input_t, inputs + ) + else: + processed_input = (_process_single_input_t(inputs),) + + def _get_input_tensor(time): + inp = [t_[time] for t_ in processed_input] + return tree.pack_sequence_as(inputs, inp) + + if mask is not None: + mask_list = tf.unstack(mask) + if go_backwards: + mask_list.reverse() + + for i in range(time_steps): + inp = _get_input_tensor(i) + mask_t = mask_list[i] + output, new_states = step_function( + inp, tuple(states) + tuple(constants) + ) + tiled_mask_t = _expand_mask(mask_t, output) + + if not successive_outputs: + prev_output = tf.zeros_like(output) + else: + prev_output = successive_outputs[-1] + + output = tf.where(tiled_mask_t, output, prev_output) + + flat_states = tree.flatten(states) + flat_new_states = tree.flatten(new_states) + tiled_mask_t = tuple( + _expand_mask(mask_t, s) for s in flat_states + ) + flat_final_states = tuple( + tf.where(m, s, ps) + for m, s, ps in zip( + tiled_mask_t, flat_new_states, flat_states + ) + ) + states = tree.pack_sequence_as(states, flat_final_states) + + if return_all_outputs: + successive_outputs.append(output) + successive_states.append(states) + else: + successive_outputs = [output] + successive_states = [states] + last_output = successive_outputs[-1] + new_states = successive_states[-1] + outputs = tf.stack(successive_outputs) + + if zero_output_for_mask: + last_output = tf.where( + _expand_mask(mask_list[-1], last_output), + last_output, + tf.zeros_like(last_output), + ) + outputs = tf.where( + _expand_mask(mask, outputs, fixed_dim=2), + outputs, + tf.zeros_like(outputs), + ) + + else: # mask is None + for i in range(time_steps): + inp = _get_input_tensor(i) + output, states = step_function( + inp, tuple(states) + tuple(constants) + ) + if return_all_outputs: + successive_outputs.append(output) + successive_states.append(states) + else: + successive_outputs = [output] + successive_states = [states] + last_output = successive_outputs[-1] + new_states = successive_states[-1] + outputs = tf.stack(successive_outputs) + + else: # Unroll == False + states = tuple(initial_states) + + # Create input tensor array, if the inputs is nested tensors, then it + # will be flattened first, and tensor array will be created one per + # flattened tensor. + input_ta = tuple( + tf.TensorArray( + dtype=inp.dtype, + size=time_steps_t, + tensor_array_name=f"input_ta_{i}", + ) + for i, inp in enumerate(flattened_inputs) + ) + input_ta = tuple( + ( + ta.unstack(input_) + if not go_backwards + else ta.unstack(tf.reverse(input_, [0])) + ) + for ta, input_ in zip(input_ta, flattened_inputs) + ) + + # Get the time(0) input and compute the output for that, the output will + # be used to determine the dtype of output tensor array. Don't read from + # input_ta due to TensorArray clear_after_read default to True. + input_time_zero = tree.pack_sequence_as( + inputs, [inp[0] for inp in flattened_inputs] + ) + # output_time_zero is used to determine the cell output shape and its + # dtype. the value is discarded. + output_time_zero, _ = step_function( + input_time_zero, tuple(initial_states) + tuple(constants) + ) + + output_ta_size = time_steps_t if return_all_outputs else 1 + output_ta = tuple( + tf.TensorArray( + dtype=out.dtype, + size=output_ta_size, + element_shape=out.shape, + tensor_array_name=f"output_ta_{i}", + ) + for i, out in enumerate(tree.flatten(output_time_zero)) + ) + + time = tf.constant(0, dtype="int32", name="time") + + if input_length is None: + max_iterations = time_steps_t + else: + max_iterations = tf.reduce_max(input_length) + + while_loop_kwargs = { + "cond": lambda time, *_: time < time_steps_t, + "maximum_iterations": max_iterations, + "parallel_iterations": 32, + "swap_memory": True, + } + if mask is not None: + if go_backwards: + mask = tf.reverse(mask, [0]) + + mask_ta = tf.TensorArray( + dtype=tf.bool, size=time_steps_t, tensor_array_name="mask_ta" + ) + mask_ta = mask_ta.unstack(mask) + + def masking_fn(time): + return mask_ta.read(time) + + def compute_masked_output(mask_t, flat_out, flat_mask): + tiled_mask_t = tuple( + _expand_mask(mask_t, o, fixed_dim=len(mask_t.shape)) + for o in flat_out + ) + return tuple( + tf.where(m, o, fm) + for m, o, fm in zip(tiled_mask_t, flat_out, flat_mask) + ) + + elif isinstance(input_length, tf.Tensor): + if go_backwards: + max_len = tf.reduce_max(input_length, axis=0) + rev_input_length = tf.subtract(max_len - 1, input_length) + + def masking_fn(time): + return tf.less(rev_input_length, time) + + else: + + def masking_fn(time): + return tf.greater(input_length, time) + + def compute_masked_output(mask_t, flat_out, flat_mask): + return tuple( + tf.where(mask_t, o, zo) + for (o, zo) in zip(flat_out, flat_mask) + ) + + else: + masking_fn = None + + if masking_fn is not None: + # Mask for the T output will be base on the output of T - 1. In the + # case T = 0, a zero filled tensor will be used. + flat_zero_output = tuple( + tf.zeros_like(o) for o in tree.flatten(output_time_zero) + ) + + def _step(time, output_ta_t, prev_output, *states): + """RNN step function. + + Args: + time: Current timestep value. + output_ta_t: TensorArray. + prev_output: tuple of outputs from time - 1. + *states: List of states. + + Returns: + Tuple: `(time + 1, output_ta_t, output) + tuple(new_states)` + """ + current_input = tuple(ta.read(time) for ta in input_ta) + # maybe set shape. + current_input = tree.pack_sequence_as(inputs, current_input) + mask_t = masking_fn(time) + output, new_states = step_function( + current_input, tuple(states) + tuple(constants) + ) + # mask output + flat_output = tree.flatten(output) + flat_mask_output = ( + flat_zero_output + if zero_output_for_mask + else tree.flatten(prev_output) + ) + flat_new_output = compute_masked_output( + mask_t, flat_output, flat_mask_output + ) + + # mask states + flat_state = tree.flatten(states) + flat_new_state = tree.flatten(new_states) + flat_final_state = compute_masked_output( + mask_t, flat_new_state, flat_state + ) + new_states = tree.pack_sequence_as(new_states, flat_final_state) + + ta_index_to_write = time if return_all_outputs else 0 + output_ta_t = tuple( + ta.write(ta_index_to_write, out) + for ta, out in zip(output_ta_t, flat_new_output) + ) + + return (time + 1, output_ta_t, tuple(flat_new_output)) + tuple( + new_states + ) + + final_outputs = tf.while_loop( + body=_step, + loop_vars=(time, output_ta, flat_zero_output) + states, + **while_loop_kwargs, + ) + # Skip final_outputs[2] which is the output for final timestep. + new_states = final_outputs[3:] + else: + + def _step(time, output_ta_t, *states): + """RNN step function. + + Args: + time: Current timestep value. + output_ta_t: TensorArray. + *states: List of states. + + Returns: + Tuple: `(time + 1,output_ta_t) + tuple(new_states)` + """ + current_input = tuple(ta.read(time) for ta in input_ta) + current_input = tree.pack_sequence_as(inputs, current_input) + output, new_states = step_function( + current_input, tuple(states) + tuple(constants) + ) + flat_new_state = tree.flatten(new_states) + + flat_output = tree.flatten(output) + ta_index_to_write = time if return_all_outputs else 0 + output_ta_t = tuple( + ta.write(ta_index_to_write, out) + for ta, out in zip(output_ta_t, flat_output) + ) + + new_states = tree.pack_sequence_as( + initial_states, flat_new_state + ) + return (time + 1, output_ta_t) + tuple(new_states) + + final_outputs = tf.while_loop( + body=_step, + loop_vars=(time, output_ta) + states, + **while_loop_kwargs, + ) + new_states = final_outputs[2:] + + output_ta = final_outputs[1] + + outputs = tuple(o.stack() for o in output_ta) + last_output = tuple(o[-1] for o in outputs) + + outputs = tree.pack_sequence_as(output_time_zero, outputs) + last_output = tree.pack_sequence_as(output_time_zero, last_output) + + if not time_major: + outputs = tree.map_structure(swap_batch_timestep, outputs) + + return last_output, outputs, new_states + + +def gru( + inputs, + initial_state, + mask, + kernel, + recurrent_kernel, + bias, + activation, + recurrent_activation, + return_sequences=False, + go_backwards=False, + unroll=False, + time_major=False, + reset_after=True, +): + cudnn_supported = cudnn_ok( + activation, + recurrent_activation, + unroll, + use_bias=bias is not None, + reset_after=reset_after, + ) + if not cudnn_supported: + raise NotImplementedError + + from keras.src.backend.tensorflow import Variable + + if isinstance(kernel, Variable): + kernel = kernel.value + if isinstance(recurrent_kernel, Variable): + recurrent_kernel = recurrent_kernel.value + if isinstance(bias, Variable): + bias = bias.value + + try: + return _cudnn_gru( + inputs, + initial_state, + kernel, + recurrent_kernel, + bias, + mask, + time_major, + go_backwards, + return_sequences, + ) + except tf.errors.InvalidArgumentError: + # cuDNN op not found. + raise NotImplementedError + except tf.errors.NotFoundError: + # alternative error: device not found for op + raise NotImplementedError + + +def _do_gru_arguments_support_cudnn( + activation, + recurrent_activation, + unroll, + use_bias, + reset_after, +): + from keras.src import activations + from keras.src import ops + + return ( + activation in (activations.tanh, tf.tanh, ops.tanh) + and recurrent_activation + in (activations.sigmoid, tf.sigmoid, ops.sigmoid) + and not unroll + and use_bias + and reset_after + ) + + +def _do_lstm_arguments_support_cudnn( + activation, + recurrent_activation, + unroll, + use_bias, +): + from keras.src import activations + from keras.src import ops + + return ( + activation in (activations.tanh, tf.tanh, ops.tanh) + and recurrent_activation + in (activations.sigmoid, tf.sigmoid, ops.sigmoid) + and not unroll + and use_bias + ) + + +def _has_fully_masked_sequence(mask): + # Cudnn kernel will error out if the input sequence contains any + # fully masked data. We walk around this issue by rerouting the computation + # to standard kernel, until the issue on cudnn side has been fixed. For a + # fully masked sequence, it will contain all Falses. To make it easy to + # check, we inverse the boolean, check if any of the sequence has all True. + return tf.reduce_any( + tf.reduce_all(tf.logical_not(tf.cast(mask, dtype="bool")), axis=1) + ) + + +def _assert_valid_mask(mask): + valid = tf.logical_and( + tf.logical_not(_has_fully_masked_sequence(mask)), + _is_sequence_right_padded(mask), + ) + tf.Assert( + valid, + [ + ( + "You are passing a RNN mask that does not correspond to " + "right-padded sequences, while using cuDNN, which is not " + "supported. With cuDNN, RNN masks can only be used for " + "right-padding, e.g. `[[True, True, False, False]]` would " + "be a valid mask, but any mask that isn't just contiguous " + "`True`'s on the left and contiguous `False`'s on the right " + "would be invalid. You can pass `use_cudnn=False` to your " + "RNN layer to stop using cuDNN (this may be slower)." + ) + ], + ) + + +def _standardize_cudnn_weights(weights, biases, shape, transpose_weights=False): + """Utility function convert variable to cuDNN compatible parameter. + + Note that Keras weights for kernels are different from the cuDNN format. + Eg.: + + ``` + Keras cuDNN + [[0, 1, 2], <---> [[0, 2, 4], + [3, 4, 5]] [1, 3, 5]] + ``` + + If the input weights need to be in a unified format, then set + `transpose_weights=True` to convert the weights. + + Args: + weights: list of weights for the kernels and recurrent kernels. + biases: list of biases for individual gate. + shape: the shape for the converted variables that will be feed to cuDNN. + transpose_weights: boolean, whether to transpose the weights. + + Returns: + The converted weights that can be feed to cuDNN ops as param. + """ + + def convert(w): + return tf.transpose(w) if transpose_weights else w + + weights = [tf.reshape(convert(x), shape) for x in weights] + biases = [tf.reshape(x, shape) for x in biases] + return tf.concat(weights + biases, axis=0) + + +def _is_sequence_right_padded(mask): + """Check the mask tensor and see if it right padded. + + cuDNN uses the sequence length param to skip the tailing + timestep. If the data is left padded, or not a strict right padding (has + masked value in the middle of the sequence), then cuDNN won't work + properly in those cases. + + Left padded data: [[False, False, True, True, True]]. + Right padded data: [[True, True, True, False, False]]. + Mixture of mask/unmasked data: [[True, False, True, False, False]]. + + Note that for the mixed data example above, the actually data RNN should see + are those 2 Trues (index 0 and 2), the index 1 False should be ignored and + not pollute the internal states. + + Args: + mask: the Boolean tensor with shape [batch, timestep] + + Returns: + boolean scalar tensor, whether the mask is strictly right padded. + """ + max_seq_length = tf.shape(mask)[1] + count_of_true = tf.reduce_sum(tf.cast(mask, tf.int32), axis=1) + right_padded_mask = tf.sequence_mask(count_of_true, maxlen=max_seq_length) + return tf.reduce_all( + tf.equal( + tf.cast(mask, dtype="bool"), + tf.cast(right_padded_mask, dtype="bool"), + ) + ) + + +def _compute_sequence_length_from_mask(mask, time_major): + """Calculate the sequence length tensor (1-D) based on the masking tensor. + + The masking tensor is a 2D boolean tensor with shape [batch, timestep]. For + any timestep that should be masked, the corresponding field will be False. + Consider the following example: + a = [[True, True, False, False], + [True, True, True, False]] + It is a (2, 4) tensor, and the corresponding sequence length result should + be 1D tensor with value [2, 3]. Note that the masking tensor must be right + padded that could be checked by, e.g., `is_sequence_right_padded()`. + + Args: + mask: Boolean tensor with shape [batch, timestep] or [timestep, batch] + if time_major=True. + time_major: Boolean, which indicates whether the mask is time major or + batch major. + + Returns: + sequence_length: 1D int32 tensor. + """ + timestep_index = 0 if time_major else 1 + return tf.reduce_sum(tf.cast(mask, tf.int32), axis=timestep_index) + + +def _is_gpu_available(): + return bool(tf.config.list_logical_devices("GPU")) + + +def _cudnn_gru( + inputs, + initial_state, + kernel, + recurrent_kernel, + bias, + mask, + time_major, + go_backwards, + return_sequences, +): + """GRU with cuDNN implementation which is only available for GPU.""" + if mask is not None: + _assert_valid_mask(mask) + sequence_lengths = _compute_sequence_length_from_mask(mask, time_major) + else: + if time_major: + batch_dim = tf.shape(inputs)[1] + max_sequence_length = tf.shape(inputs)[0] + else: + batch_dim = tf.shape(inputs)[0] + max_sequence_length = tf.shape(inputs)[1] + sequence_lengths = tf.fill([batch_dim], max_sequence_length) + + if not time_major and sequence_lengths is None: + inputs = tf.transpose(inputs, perm=(1, 0, 2)) + seq_axis, batch_axis = (0, 1) + else: + seq_axis, batch_axis = (0, 1) if time_major else (1, 0) + + # For init_h, cuDNN expects one more dim of num_layers before or after batch + # dim for time major or batch major inputs respectively + init_h = tf.expand_dims(initial_state, axis=seq_axis) + + weights = tf.split(kernel, 3, axis=1) + weights += tf.split(recurrent_kernel, 3, axis=1) + # Note that the bias was initialized as shape (2, 3 * units), flatten it to + # (6 * units) + bias = tf.split(tf.reshape(bias, [-1]), 6) + + if tf.sysconfig.get_build_info()["is_cuda_build"]: + # Note that the gate order for cuDNN is different from the canonical + # format. canonical format is [z, r, h], whereas cuDNN is [r, z, h]. + # The swap need to be done for kernel, recurrent_kernel, input_bias, + # recurrent_bias. + # z is update gate weights. + # r is reset gate weights. + # h is output gate weights. + weights[0], weights[1] = weights[1], weights[0] + weights[3], weights[4] = weights[4], weights[3] + bias[0], bias[1] = bias[1], bias[0] + bias[3], bias[4] = bias[4], bias[3] + + params = _standardize_cudnn_weights( + weights=weights, + biases=bias, + shape=tf.constant([-1]), + transpose_weights=True, + ) + + if go_backwards: + # Three reversals are required. E.g., + # normal input = [1, 2, 3, 0, 0] # where 0 need to be masked + # reversed_input_to_cudnn = [3, 2, 1, 0, 0] + # output_from_cudnn = [6, 5, 4, 0, 0] + # expected_output = [0, 0, 6, 5 ,4] + inputs = tf.reverse_sequence( + inputs, + sequence_lengths, + seq_axis=seq_axis, + batch_axis=batch_axis, + ) + outputs, h, _, _, _ = tf.raw_ops.CudnnRNNV3( + input=inputs, + input_h=init_h, + input_c=0, + params=params, + is_training=True, + rnn_mode="gru", + sequence_lengths=sequence_lengths, + time_major=time_major, + ) + if go_backwards: + outputs = tf.reverse_sequence( + outputs, + sequence_lengths, + seq_axis=seq_axis, + batch_axis=batch_axis, + ) + outputs = tf.reverse(outputs, axis=[seq_axis]) + + last_output = outputs[-1] + if not time_major and sequence_lengths is None and return_sequences: + outputs = tf.transpose(outputs, perm=[1, 0, 2]) + state = tf.squeeze(h, axis=seq_axis) + + # In the case of variable length input, the cudnn kernel will fill zeros for + # the output, whereas the default keras behavior is to bring over the + # previous output for t-1, so that in the return_sequence=False case, user + # can quickly get the final effect output instead just 0s at the last + # timestep. In order to mimic the default keras behavior, we copy the final + # h state as the last_output, since it is numerically same as the output. + if sequence_lengths is not None: + last_output = state + + # Match CPU return format + if not return_sequences: + outputs = tf.expand_dims(last_output, axis=0 if time_major else 1) + + return ( + last_output, + outputs, + [state], + ) + + +def cudnn_ok( + activation, + recurrent_activation, + unroll, + use_bias, + reset_after=None, +): + if reset_after is None: + args_supported = _do_lstm_arguments_support_cudnn( + activation=activation, + recurrent_activation=recurrent_activation, + unroll=unroll, + use_bias=use_bias, + ) + else: + args_supported = _do_gru_arguments_support_cudnn( + activation=activation, + recurrent_activation=recurrent_activation, + unroll=unroll, + use_bias=use_bias, + reset_after=reset_after, + ) + return args_supported and _is_gpu_available() + + +def lstm( + inputs, + initial_state_h, + initial_state_c, + mask, + kernel, + recurrent_kernel, + bias, + activation, + recurrent_activation, + return_sequences=False, + go_backwards=False, + unroll=False, + time_major=False, +): + cudnn_supported = cudnn_ok( + activation, recurrent_activation, unroll, use_bias=bias is not None + ) + if not cudnn_supported: + raise NotImplementedError + + from keras.src.backend.tensorflow import Variable + + if isinstance(kernel, Variable): + kernel = kernel.value + if isinstance(recurrent_kernel, Variable): + recurrent_kernel = recurrent_kernel.value + if isinstance(bias, Variable): + bias = bias.value + + try: + return _cudnn_lstm( + inputs, + initial_state_h, + initial_state_c, + kernel, + recurrent_kernel, + bias, + mask, + time_major, + go_backwards, + return_sequences, + ) + except tf.errors.InvalidArgumentError: + # cuDNN op not found. + raise NotImplementedError + except tf.errors.NotFoundError: + # alternative error: device not found for op + raise NotImplementedError + + +def _cudnn_lstm( + inputs, + initial_state_h, + initial_state_c, + kernel, + recurrent_kernel, + bias, + mask, + time_major, + go_backwards, + return_sequences, +): + if mask is not None: + _assert_valid_mask(mask) + sequence_lengths = _compute_sequence_length_from_mask(mask, time_major) + else: + if time_major: + batch_dim = tf.shape(inputs)[1] + max_sequence_length = tf.shape(inputs)[0] + else: + batch_dim = tf.shape(inputs)[0] + max_sequence_length = tf.shape(inputs)[1] + sequence_lengths = tf.fill([batch_dim], max_sequence_length) + + if not time_major and sequence_lengths is None: + inputs = tf.transpose(inputs, perm=(1, 0, 2)) + seq_axis, batch_axis = (0, 1) + else: + seq_axis, batch_axis = (0, 1) if time_major else (1, 0) + # For init_h and init_c, cuDNN expects one more dim of num_layers before or + # after batch dim for time major or batch major inputs respectively + init_h = tf.expand_dims(initial_state_h, axis=seq_axis) + init_c = tf.expand_dims(initial_state_c, axis=seq_axis) + + weights = tf.split(kernel, 4, axis=1) + weights += tf.split(recurrent_kernel, 4, axis=1) + # cuDNN has an extra set of bias for inputs, we disable them (setting to 0), + # so that mathematically it is same as the canonical LSTM implementation. + full_bias = tf.concat((tf.zeros_like(bias), bias), 0) + + if tf.sysconfig.get_build_info()["is_rocm_build"]: + # ROCm MIOpen's weight sequence for LSTM is different from both + # canonical and Cudnn format + # MIOpen: [i, f, o, c] Cudnn/Canonical: [i, f, c, o] + # i is input gate weights. + # f is forget gate weights. + # o is output gate weights. + # c is cell gate weights. + weights = [weights[x] for x in (0, 1, 3, 2, 4, 5, 7, 6)] + # full_bias is a tensor of shape (8*n,) + full_bias = tf.split(full_bias, 8, axis=0) + full_bias = [full_bias[x] for x in (0, 1, 3, 2, 4, 5, 7, 6)] + + params = _standardize_cudnn_weights( + weights=weights, + biases=tf.split(full_bias, 8), + shape=tf.constant([-1]), + transpose_weights=True, + ) + + if go_backwards: + # Three reversals are required. E.g., + # normal input = [1, 2, 3, 0, 0] # where 0 need to be masked + # reversed_input_to_cudnn = [3, 2, 1, 0, 0] + # output_from_cudnn = [6, 5, 4, 0, 0] + # expected_output = [0, 0, 6, 5 ,4] + inputs = tf.reverse_sequence( + inputs, + sequence_lengths, + seq_axis=seq_axis, + batch_axis=batch_axis, + ) + outputs, h, c, _, _ = tf.raw_ops.CudnnRNNV3( + input=inputs, + input_h=init_h, + input_c=init_c, + params=params, + is_training=True, + rnn_mode="lstm", + sequence_lengths=sequence_lengths, + time_major=time_major, + ) + if go_backwards: + outputs = tf.reverse_sequence( + outputs, + sequence_lengths, + seq_axis=seq_axis, + batch_axis=batch_axis, + ) + outputs = tf.reverse(outputs, axis=[seq_axis]) + + last_output = outputs[-1] + if not time_major and sequence_lengths is None and return_sequences: + outputs = tf.transpose(outputs, perm=[1, 0, 2]) + h = tf.squeeze(h, axis=seq_axis) + c = tf.squeeze(c, axis=seq_axis) + + # In the case of variable length input, the cudnn kernel will fill zeros for + # the output, whereas the default keras behavior is to bring over the + # previous output for t-1, so that in the return_sequence=False case, user + # can quickly get the final effect output instead just 0s at the last + # timestep. In order to mimic the default keras behavior, we copy the final + # h state as the last_output, since it is numerically same as the output. + if sequence_lengths is not None: + last_output = h + + # Match CPU return format + if not return_sequences: + outputs = tf.expand_dims(last_output, axis=0 if time_major else 1) + + return (last_output, outputs, [h, c]) diff --git a/keras/src/backend/tensorflow/saved_model_test.py b/keras/src/backend/tensorflow/saved_model_test.py new file mode 100644 index 000000000000..4a7a4643f095 --- /dev/null +++ b/keras/src/backend/tensorflow/saved_model_test.py @@ -0,0 +1,436 @@ +"""Tests for SavedModel functionality under tf implementation.""" + +import os + +import numpy as np +import pytest +import tensorflow as tf +from absl.testing import parameterized + +from keras.src import backend +from keras.src import layers +from keras.src import metrics +from keras.src import models +from keras.src import ops +from keras.src import optimizers +from keras.src import testing +from keras.src.saving import object_registration +from keras.src.testing.test_utils import named_product + + +@object_registration.register_keras_serializable(package="my_package") +class CustomModelX(models.Model): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.dense1 = layers.Dense(1) + self.dense2 = layers.Dense(1) + + def call(self, inputs): + out = self.dense1(inputs) + return self.dense2(out) + + def one(self): + return 1 + + +@object_registration.register_keras_serializable(package="my_package") +class CustomSignatureModel(models.Model): + def __init__(self): + super(CustomSignatureModel, self).__init__() + self.v = tf.Variable(1.0) + + @tf.function + def __call__(self, x): + return x * self.v + + @tf.function(input_signature=[tf.TensorSpec([], tf.float32)]) + def mutate(self, new_v): + self.v.assign(new_v) + + +@pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="The SavedModel test can only run with TF backend.", +) +class SavedModelTest(testing.TestCase): + def test_sequential(self): + model = models.Sequential([layers.Dense(1)]) + model.compile(loss="mse", optimizer="adam") + X_train = np.random.rand(100, 3) + y_train = np.random.rand(100, 1) + model.fit(X_train, y_train) + path = os.path.join(self.get_temp_dir(), "my_keras_model") + tf.saved_model.save(model, path) + restored_model = tf.saved_model.load(path) + self.assertAllClose( + model(X_train), + restored_model.signatures["serving_default"]( + tf.convert_to_tensor(X_train, dtype=tf.float32) + )["output_0"], + rtol=1e-4, + atol=1e-4, + ) + + def test_functional(self): + inputs = layers.Input(shape=(3,)) + x = layers.Dense(1, name="first_dense")(inputs) + outputs = layers.Dense(1, name="second_dense")(x) + model = models.Model(inputs, outputs) + model.compile( + optimizer="adam", + loss="mse", + ) + X_train = np.random.rand(100, 3) + y_train = np.random.rand(100, 1) + model.fit(X_train, y_train) + path = os.path.join(self.get_temp_dir(), "my_keras_model") + tf.saved_model.save(model, path) + restored_model = tf.saved_model.load(path) + self.assertAllClose( + model(X_train), + restored_model.signatures["serving_default"]( + tf.convert_to_tensor(X_train, dtype=tf.float32) + )["output_0"], + rtol=1e-4, + atol=1e-4, + ) + + def test_subclassed(self): + model = CustomModelX() + model.compile( + optimizer="adam", + loss="mse", + metrics=[metrics.Hinge(), "mse"], + ) + X_train = np.random.rand(100, 3) + y_train = np.random.rand(100, 1) + model.fit(X_train, y_train) + path = os.path.join(self.get_temp_dir(), "my_keras_model") + tf.saved_model.save(model, path) + restored_model = tf.saved_model.load(path) + self.assertAllClose( + model(X_train), + restored_model.signatures["serving_default"]( + tf.convert_to_tensor(X_train, dtype=tf.float32) + )["output_0"], + rtol=1e-4, + atol=1e-4, + ) + + def test_custom_model_and_layer(self): + @object_registration.register_keras_serializable(package="my_package") + class CustomLayer(layers.Layer): + def __call__(self, inputs): + return inputs + + @object_registration.register_keras_serializable(package="my_package") + class Model(models.Model): + def __init__(self): + super().__init__() + self.layer = CustomLayer() + + @tf.function(input_signature=[tf.TensorSpec([None, 1])]) + def call(self, inputs): + return self.layer(inputs) + + model = Model() + inp = np.array([[1.0]]) + result = model(inp) + path = os.path.join(self.get_temp_dir(), "my_keras_model") + tf.saved_model.save(model, path) + restored_model = tf.saved_model.load(path) + self.assertAllClose( + result, + restored_model.call(inp), + rtol=1e-4, + atol=1e-4, + ) + + @parameterized.named_parameters( + named_product(struct_type=["tuple", "array", "dict"]) + ) + def test_model_with_input_structure(self, struct_type): + class TupleModel(models.Model): + def call(self, inputs): + x, y = inputs + return x + ops.mean(y, axis=1) + + class ArrayModel(models.Model): + def call(self, inputs): + x = inputs[0] + y = inputs[1] + return x + ops.mean(y, axis=1) + + class DictModel(models.Model): + def call(self, inputs): + x = inputs["x"] + y = inputs["y"] + return x + ops.mean(y, axis=1) + + input_x = tf.constant([1.0]) + input_y = tf.constant([[1.0, 0.0, 2.0]]) + if struct_type == "tuple": + model = TupleModel() + inputs = (input_x, input_y) + elif struct_type == "array": + model = ArrayModel() + inputs = [input_x, input_y] + elif struct_type == "dict": + model = DictModel() + inputs = {"x": input_x, "y": input_y} + + result = model(inputs) + path = os.path.join(self.get_temp_dir(), "my_keras_model") + tf.saved_model.save(model, path) + restored_model = tf.saved_model.load(path) + outputs = restored_model.signatures["serving_default"]( + inputs=input_x, inputs_1=input_y + ) + self.assertAllClose(result, outputs["output_0"], rtol=1e-4, atol=1e-4) + + def test_multi_input_model(self): + input_1 = layers.Input(shape=(3,)) + input_2 = layers.Input(shape=(5,)) + + y1 = layers.Dense(1)(input_1) + y2 = layers.Dense(1)(input_2) + layer_2 = layers.Dense(1, activation="relu") + output_1 = layer_2(y1) + output_2 = layer_2(y2) + model = models.Model([input_1, input_2], [output_1, output_2]) + + input_arr_1 = np.random.random((1, 3)).astype("float32") + input_arr_2 = np.random.random((1, 5)).astype("float32") + + model = models.Model([input_1, input_2], [output_1, output_2]) + path = os.path.join(self.get_temp_dir(), "my_keras_model") + outputs_1 = model( + inputs=[ + tf.convert_to_tensor(input_arr_1, dtype=tf.float32), + tf.convert_to_tensor(input_arr_2, dtype=tf.float32), + ], + ) + tf.saved_model.save(model, path) + restored_model = tf.saved_model.load(path) + + outputs_2 = restored_model.signatures["serving_default"]( + inputs=tf.convert_to_tensor(input_arr_1, dtype=tf.float32), + inputs_1=tf.convert_to_tensor(input_arr_2, dtype=tf.float32), + ) + self.assertAllClose( + outputs_1[0], outputs_2["output_0"], rtol=1e-4, atol=1e-4 + ) + self.assertAllClose( + outputs_1[1], outputs_2["output_1"], rtol=1e-4, atol=1e-4 + ) + + def test_multi_input_custom_model_and_layer(self): + @object_registration.register_keras_serializable(package="my_package") + class CustomLayer(layers.Layer): + def build(self, *input_shape): + pass + + def call(self, *input_list): + self.add_loss(input_list[-2] * 2) + return sum(input_list) + + @object_registration.register_keras_serializable(package="my_package") + class CustomModel(models.Model): + def build(self, *input_shape): + self.layer = CustomLayer() + self.layer.build(*input_shape) + + @tf.function + def call(self, *inputs): + inputs = list(inputs) + return self.layer(*inputs) + + model = CustomModel() + inp = [ + tf.constant(i, shape=[1, 1], dtype=tf.float32) for i in range(1, 4) + ] + expected = model(*inp) + path = os.path.join(self.get_temp_dir(), "my_keras_model") + tf.saved_model.save(model, path) + restored_model = tf.saved_model.load(path) + output = restored_model.call(*inp) + self.assertAllClose(expected, output, rtol=1e-4, atol=1e-4) + + def test_list_trackable_children_tracking(self): + @object_registration.register_keras_serializable(package="my_package") + class CustomLayerList(layers.Layer): + def __init__(self): + super().__init__() + self.sublayers = [ + layers.Dense(2), + layers.Dense(2), + ] + + def call(self, inputs): + x = inputs + for sublayer in self.sublayers: + x = sublayer(x) + return x + + inputs = layers.Input(shape=(1,)) + outputs = CustomLayerList()(inputs) + model = models.Model(inputs, outputs) + + inp = np.array([[1.0]]) + expected = model(inp) + path = os.path.join(self.get_temp_dir(), "my_keras_model") + tf.saved_model.save(model, path) + restored_model = tf.saved_model.load(path) + self.assertAllClose( + expected, + restored_model.signatures["serving_default"]( + tf.convert_to_tensor(inp, dtype=tf.float32) + )["output_0"], + rtol=1e-4, + atol=1e-4, + ) + + def test_dict_trackable_children_tracking(self): + @object_registration.register_keras_serializable(package="my_package") + class CustomLayerDict(layers.Layer): + def __init__(self): + super().__init__() + self.sublayers = { + "first_layer": layers.Dense(2), + "second_layer": layers.Dense(2), + } + + def call(self, inputs): + x = inputs + for key, sublayer in self.sublayers.items(): + x = sublayer(x) + return x + + inputs = layers.Input(shape=(1,)) + outputs = CustomLayerDict()(inputs) + model = models.Model(inputs, outputs) + + inp = np.array([[1.0]]) + expected = model(inp) + path = os.path.join(self.get_temp_dir(), "my_keras_model") + tf.saved_model.save(model, path) + restored_model = tf.saved_model.load(path) + self.assertAllClose( + expected, + restored_model.signatures["serving_default"]( + tf.convert_to_tensor(inp, dtype=tf.float32) + )["output_0"], + rtol=1e-4, + atol=1e-4, + ) + + def test_fixed_signature_string_dtype(self): + @object_registration.register_keras_serializable(package="my_package") + class Adder(models.Model): + @tf.function( + input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)] + ) + def concat(self, x): + return x + x + + model = Adder() + path = os.path.join(self.get_temp_dir(), "my_keras_model") + tf.saved_model.save(model, path) + restored_model = tf.saved_model.load(path) + self.assertEqual(model.concat("hello"), restored_model.concat("hello")) + + def test_non_fixed_signature_string_dtype(self): + @object_registration.register_keras_serializable(package="my_package") + class Adder(models.Model): + @tf.function + def concat(self, x): + return x + x + + model = Adder() + + no_fn_path = os.path.join(self.get_temp_dir(), "my_keras_model_no_fn") + tf.saved_model.save(model, no_fn_path) + restored_model = tf.saved_model.load(no_fn_path) + with self.assertRaisesRegex(ValueError, "zero restored functions"): + _ = restored_model.concat("hello") + + path = os.path.join(self.get_temp_dir(), "my_keras_model") + tf.saved_model.save( + model, + path, + signatures=model.concat.get_concrete_function( + tf.TensorSpec(shape=[], dtype=tf.string, name="string_input") + ), + ) + restored_model = tf.saved_model.load(path) + self.assertEqual(model.concat("hello"), restored_model.concat("hello")) + + def test_fine_tuning(self): + model = CustomSignatureModel() + model_no_signatures_path = os.path.join( + self.get_temp_dir(), "model_no_signatures" + ) + _ = model(tf.constant(0.0)) + + tf.saved_model.save(model, model_no_signatures_path) + restored_model = tf.saved_model.load(model_no_signatures_path) + + self.assertLen(list(restored_model.signatures.keys()), 0) + self.assertEqual(restored_model(tf.constant(3.0)).numpy(), 3) + restored_model.mutate(tf.constant(2.0)) + self.assertEqual(restored_model(tf.constant(3.0)).numpy(), 6) + optimizer = optimizers.SGD(0.05) + + def train_step(): + with tf.GradientTape() as tape: + loss = (10.0 - restored_model(tf.constant(2.0))) ** 2 + variables = tape.watched_variables() + grads = tape.gradient(loss, variables) + optimizer.apply_gradients(zip(grads, variables)) + return loss + + for _ in range(10): + # "v" approaches 5, "loss" approaches 0 + loss = train_step() + + self.assertAllClose(loss, 0.0, rtol=1e-2, atol=1e-2) + self.assertAllClose(restored_model.v.numpy(), 5.0, rtol=1e-2, atol=1e-2) + + def test_signatures_path(self): + model = CustomSignatureModel() + model_with_signature_path = os.path.join( + self.get_temp_dir(), "model_with_signature" + ) + call = model.__call__.get_concrete_function( + tf.TensorSpec(None, tf.float32) + ) + + tf.saved_model.save(model, model_with_signature_path, signatures=call) + restored_model = tf.saved_model.load(model_with_signature_path) + self.assertEqual( + list(restored_model.signatures.keys()), ["serving_default"] + ) + + def test_multiple_signatures_dict_path(self): + model = CustomSignatureModel() + model_multiple_signatures_path = os.path.join( + self.get_temp_dir(), "model_with_multiple_signatures" + ) + call = model.__call__.get_concrete_function( + tf.TensorSpec(None, tf.float32) + ) + signatures = { + "serving_default": call, + "array_input": model.__call__.get_concrete_function( + tf.TensorSpec([None], tf.float32) + ), + } + + tf.saved_model.save( + model, model_multiple_signatures_path, signatures=signatures + ) + restored_model = tf.saved_model.load(model_multiple_signatures_path) + self.assertEqual( + list(restored_model.signatures.keys()), + ["serving_default", "array_input"], + ) diff --git a/keras/src/backend/tensorflow/sparse.py b/keras/src/backend/tensorflow/sparse.py new file mode 100644 index 000000000000..f6a1da210d29 --- /dev/null +++ b/keras/src/backend/tensorflow/sparse.py @@ -0,0 +1,782 @@ +import functools + +import tensorflow as tf + +ones_bool = functools.partial(tf.ones, dtype=tf.bool) +ones_int8 = functools.partial(tf.ones, dtype=tf.int8) +zeros_int8 = functools.partial(tf.zeros, dtype=tf.int8) +ones_like_int8 = functools.partial(tf.ones_like, dtype=tf.int8) +zeros_like_int8 = functools.partial(tf.zeros_like, dtype=tf.int8) + + +def sparse_to_dense(x, default_value=None): + x_shape = x.shape + if x_shape.rank == 0: + # Workaround for bug on GPU when sparse tensor represents a scalar. + if x.values.shape[0] == 0: + return tf.constant(default_value, dtype=x.dtype) + else: + return tf.reshape(x.values, ()) + x = tf.sparse.to_dense(x, default_value=default_value) + x.set_shape(x_shape) + return x + + +def sparse_with_values(x, values): + x_shape = x.shape + x = tf.SparseTensor(x.indices, values, x.dense_shape) + x.set_shape(x_shape) + return x + + +def broadcast_scalar_to_sparse_shape(scalar, sparse): + output = tf.broadcast_to(scalar, sparse.dense_shape) + output.set_shape(sparse.shape) + return output + + +def sparse_subtract(x1, x2): + """Subtraction for `tf.SparseTensor`s. + + Either `x1` or `x2` or both can be `tf.SparseTensor`s. + + Args: + x1: fist tensor to add. + x2: second tensor to add. + Returns: + The sum of `x1` and `x2`, which is a `tf.SparseTensor` if and only if + both `x1` or `x2` are `tf.SparseTensor`s. + """ + if isinstance(x2, tf.SparseTensor): + return tf.sparse.add(x1, tf.sparse.map_values(tf.negative, x2)) + else: + return tf.sparse.add(x1, tf.negative(x2)) + + +def sparse_union_indices_and_values(x1, x2_indices, x2_values=None): + """Compute the indices for the union of the indices of the provided + `tf.SparseTensor`s and another set of indices and return the modified values + for these indices. + + Args: + x: a `tf.SparseTensor`. + indices: another set of indices in the `tf.SparseTensor` format. + Returns: A tuple containing: + - the indices for the union + - `x1` values for the union indices (some zeros were added) + - `x2` values for the union indices (some zeros were added) or `None` if + `x2_values` was `None`. + """ + # Add zeros at the x2 indices to x1 to create the union. + zeros2 = tf.SparseTensor( + x2_indices, + tf.zeros((tf.shape(x2_indices)[0],), x1.values.dtype), + x1.dense_shape, + ) + x1_for_union = tf.sparse.add(x1, zeros2) + if x2_values is not None: + # Add zeros at the x1 indices to x2 to create the union. + x2 = tf.SparseTensor(x2_indices, x2_values, x1.dense_shape) + zeros1 = tf.sparse.map_values(tf.zeros_like, x1) + x2_for_union = tf.sparse.add(x2, zeros1) + return x1_for_union.indices, x1_for_union.values, x2_for_union.values + else: + return x1_for_union.indices, x1_for_union.values, None + + +def indexed_slices_union_indices_and_values(x1, x2_indices, x2_values=None): + """Compute the indices for the union of two `tf.IndexedSlices` and modify + the values for these indices. + + Args: + x1: the first `tf.IndexedSlices`. + x2_indices: the indices for the second `tf.IndexedSlices`. + x2_value: (optional) the values for the second `tf.IndexedSlices`. + Returns: A tuple containing: + - the indices for the union + - `x1` values for the union indices (some zeros were added) + - `x2` values for the union indices (some zeros were added) or `None` if + `x2_values` was `None`. + """ + # Compute the union of the indices by doing a logical or between the one-hot + # encoded indices for x1 and x2. + dim_0 = x1.dense_shape[0] + x1_indices_expanded = tf.expand_dims(x1.indices, axis=1) + x2_indices_expanded = tf.expand_dims(x2_indices, axis=1) + x1_indices_count = tf.shape(x1_indices_expanded)[0] + x2_indices_count = tf.shape(x2_indices_expanded)[0] + x1_indices_one_hot = tf.scatter_nd( + x1_indices_expanded, + ones_bool((x1_indices_count,)), + (dim_0,), + ) + x2_indices_one_hot = tf.scatter_nd( + x2_indices_expanded, + ones_bool((x2_indices_count,)), + (dim_0,), + ) + union_indices = tf.squeeze( + tf.where(tf.math.logical_or(x1_indices_one_hot, x2_indices_one_hot)), + axis=-1, + ) + union_indices_count = tf.shape(union_indices)[0] + + # Re-gather the values with extra zeros added at indices that are part of + # the union but were not in x1 or x2. + def values_for_union(indices_expanded, indices_count, values): + indices_indices = tf.scatter_nd( + indices_expanded, + tf.range(1, indices_count + 1), + (dim_0,), + ) + to_union_indices = tf.gather(indices_indices, union_indices) + values_with_leading_zeros = tf.concat( + [tf.zeros_like(values[0:1]), values], axis=0 + ) + return tf.gather(values_with_leading_zeros, to_union_indices) + + # Only recompute values if some indices were added. + x1_values_for_union_indices = tf.cond( + tf.equal(x1_indices_count, union_indices_count), + lambda: x1.values, + lambda: values_for_union( + x1_indices_expanded, x1_indices_count, x1.values + ), + ) + if x2_values is not None: + x2_values_for_union_indices = tf.cond( + tf.equal(x2_indices_count, union_indices_count), + lambda: x2_values, + lambda: values_for_union( + x2_indices_expanded, x2_indices_count, x2_values + ), + ) + else: + x2_values_for_union_indices = None + + return ( + union_indices, + x1_values_for_union_indices, + x2_values_for_union_indices, + ) + + +def sparse_intersection_indices_and_values(x1, x2): + """Compute the indices for the intersection of two `tf.SparseTensor`s and + modify the values for these indices. + + Args: + x1: the first `tf.SparseTensor`. + x2: the second `tf.SparseTensor`. + Returns: A tuple containing: + - the indices for the intersection + - `x1` values for the intersection indices (some values were removed) + - `x2` values for the intersection indices (some values were removed) + """ + # Compute the intersection of indices in the form of a sparse + # tensor containing ones as values. + ones1 = tf.sparse.map_values(ones_like_int8, x1) + ones2 = tf.sparse.map_values(ones_like_int8, x2) + # tf.sets.intersection ignores the last dimension when, so we + # need to add a dummy extra dimension and then remove it. + intersection_extra_dim = tf.sets.intersection( + tf.sparse.expand_dims(ones1, axis=-1), + tf.sparse.expand_dims(ones2, axis=-1), + ) + + def empty_intersection(): + return ( + tf.zeros((0, x1.shape.rank), dtype=tf.int64), + tf.zeros((0,), dtype=x1.values.dtype), + tf.zeros((0,), dtype=x2.values.dtype), + ) + + def non_empty_intersection(): + intersection = tf.sparse.reshape(intersection_extra_dim, x1.dense_shape) + + # Compute the masks to remove indices in x1 and x2 that are not + # in the intersection, then trim x1 and x2. + zeros1 = tf.sparse.map_values(zeros_like_int8, x1) + zeros2 = tf.sparse.map_values(zeros_like_int8, x2) + mask1 = tf.sparse.add(zeros1, intersection) + mask2 = tf.sparse.add(zeros2, intersection) + return ( + intersection.indices, + tf.sparse.retain(x1, tf.cast(mask1.values, tf.bool)).values, + tf.sparse.retain(x2, tf.cast(mask2.values, tf.bool)).values, + ) + + return tf.cond( + tf.equal(tf.size(intersection_extra_dim), 0), + empty_intersection, + non_empty_intersection, + ) + + +def indexed_slices_intersection_indices_and_values(x1, x2): + """Compute the indices for the intersection of two `tf.IndexedSlices` and + modify the values for these indices. + + Args: + x1: the first `tf.IndexedSlices`. + x2: the second `tf.IndexedSlices`. + Returns: A tuple containing: + - the indices for the intersection + - `x1` values for the intersection indices (some values were removed) + - `x2` values for the intersection indices (some values were removed) + """ + # Compute the intersection of the indices by doing a logical + # and between the one hot encoded indices for x1 and x2. + dim_0 = x1.dense_shape[0] + x1_indices_expanded = tf.expand_dims(x1.indices, axis=1) + x2_indices_expanded = tf.expand_dims(x2.indices, axis=1) + x1_indices_count = x1_indices_expanded.shape[0] + x2_indices_count = x2_indices_expanded.shape[0] + x1_indices_one_hot = tf.scatter_nd( + x1_indices_expanded, + ones_bool((x1_indices_count,)), + (dim_0,), + ) + x2_indices_one_hot = tf.scatter_nd( + x2_indices_expanded, + ones_bool((x2_indices_count,)), + (dim_0,), + ) + intersection_indices = tf.squeeze( + tf.where(tf.math.logical_and(x1_indices_one_hot, x2_indices_one_hot)), + axis=-1, + ) + intersection_indices_count = tf.shape(intersection_indices)[0] + + def empty_intersection(): + return ( + intersection_indices, + tf.zeros((0,) + x1.values.shape[1:], x1.dtype), + tf.zeros((0,) + x2.values.shape[1:], x2.dtype), + ) + + def non_empty_intersection(): + # Re-gather sub parts of the values that are part of the intersection. + def values_for_intersection(indices_expanded, indices_count, values): + indices_indices = tf.scatter_nd( + indices_expanded, + tf.range(indices_count), + (dim_0,), + ) + to_intersection_indices = tf.gather( + indices_indices, intersection_indices + ) + return tf.gather(values, to_intersection_indices) + + # Only recompute values if some indices were removed. + x1_values_for_intersection = tf.cond( + tf.equal(x1_indices_count, intersection_indices_count), + lambda: x1.values, + lambda: values_for_intersection( + x1_indices_expanded, x1_indices_count, x1.values + ), + ) + x2_values_for_intersection = tf.cond( + tf.equal(x2_indices_count, intersection_indices_count), + lambda: x2.values, + lambda: values_for_intersection( + x2_indices_expanded, x2_indices_count, x2.values + ), + ) + + return ( + intersection_indices, + x1_values_for_intersection, + x2_values_for_intersection, + ) + + return tf.cond( + tf.equal(intersection_indices_count, 0), + empty_intersection, + non_empty_intersection, + ) + + +def densifying_unary(default_value): + """Decorator to add support for `tf.SparseTensor` and `tf.IndexedSlices` to + a non-zero-preserving element-wise unary operator. + + There are requirements on the operator for this decorator to work correctly: + + - The operator must be element-wise + - The operator must be unary (one input tensor and one output tensor) + - The operator must return a tensor of the same shape. + + Additional arguments to the function (besides the input tensor) are + supported. The returned result is a dense tensor and contains + `default_value` outside of the indices of the input tensor. + + Args: + default_value: The value to use outside of indices. It must be the value + that the operator returns for zero values. + Returns: + Wrapped function that supports `tf.SparseTensor` and `tf.IndexedSlices`. + """ + + def wrap_densifying_unary(func): + @functools.wraps(func) + def sparse_wrapper(x, *args, **kwargs): + if isinstance(x, tf.SparseTensor): + sparse_output = sparse_with_values( + x, func(x.values, *args, **kwargs) + ) + return sparse_to_dense( + sparse_output, + tf.cast(default_value, sparse_output.values.dtype), + ) + elif isinstance(x, tf.IndexedSlices): + sparse_output_values = func(x.values, *args, **kwargs) + output = tf.fill( + x.dense_shape, + tf.cast(default_value, sparse_output_values.dtype), + ) + return tf.tensor_scatter_nd_update( + output, tf.expand_dims(x.indices, 1), sparse_output_values + ) + return func(x, *args, **kwargs) + + return sparse_wrapper + + return wrap_densifying_unary + + +def elementwise_unary(func): + """Decorator to add support for `tf.SparseTensor` and `tf.IndexedSlices` to + a zero-preserving element-wise unary operator. + + There are requirements on the operator for this decorator to work correctly: + + - The operator must be element-wise + - The operator must be unary (one input tensor and one output tensor) + - The operator must return a tensor of the same shape, and if it is a + `tf.SparseTensor` or `tf.IndexedSlices`, the indices of the result must be + the same. Therefore: + - Reduction operations are not supported (e.g. `mean`). + - Operations for which the result may be dense (e.g. `reciprocal`), or + the sparse indices depend on the inputs are not supported (e.g. + `clip`). This implies that `func(0)` must be 0. + + Additional arguments to the function (besides the input tensor) are + supported as long as they cannot change the indices of the result. For + instance,`round` is supported, but `clip` is not supported as + `clip(x, 1.0, 2.0)` would always return a dense tensor. + + Note that if an input sparse tensor contains zero values, the indices and + the zero values are preserved. + + Args: + func: The function to wrap. + Returns: + Wrapped function that supports `tf.SparseTensor` and `tf.IndexedSlices`. + """ + + @functools.wraps(func) + def sparse_wrapper(x, *args, **kwargs): + if isinstance(x, tf.SparseTensor): + return sparse_with_values(x, func(x.values, *args, **kwargs)) + elif isinstance(x, tf.IndexedSlices): + return tf.IndexedSlices( + func(x.values, *args, **kwargs), x.indices, x.dense_shape + ) + else: + return func(x, *args, **kwargs) + + return sparse_wrapper + + +def elementwise_binary_union(sparse_op, densify_mixed=False): + """Decorator to add support for `tf.SparseTensor` and `tf.IndexedSlices` to + an element-wise binary operator such that the indices present in the result + are the union of the indices in the two operand. + + The primary use case for this is the `add` and `subtract` operators. + + There are requirements on the operator for this decorator to work correctly: + + - The operator must be element-wise. + - The operator must be binary (two input tensors and one output tensor). + - Both inputs must be of the same shape or one input must be a scalar. + - The output must be of the same shape as the (non scalar) inputs. + - The indices of the output must be the union of the indices of the inputs. + This implies that func(0, 0) must be 0. As a result, if one operand is + dense or a scalar, then the result will be dense. + + Additional arguments to the function (besides the input tensors) are not + supported. + + Note that if the result of the operation is zero at some indices, including + because the operands were zero at these indices, the zeros and indices are + preserved. + + Args: + sparse_op: implementation of the operation for `tf.SparseTensor`. Must + work if both of the operands are `tf.SparseTensor`s and can + optionally work if one of the operand is a `tf.SparseTensor` and + the other one is dense tensor, see `densify_mixed`. + densify_mixed: if `True`, `sparse_op` does not support a mix of + `tf.SparseTensor` and dense tensor or dense tensor with + `tf.SparseTensor` and the `tf.SparseTensor` tensor is densified. + Returns: + Wrapped function that supports `tf.SparseTensor` and `tf.IndexedSlices`. + """ + + def wrap_elementwise_binary_union(func): + @functools.wraps(func) + def sparse_wrapper(x1, x2): + if isinstance(x1, tf.SparseTensor): + if isinstance(x2, tf.SparseTensor): + # x1 is a SparseTensor and x2 is a SparseTensor. + if x1.indices is x2.indices: + return sparse_with_values( + x1, func(x1.values, x2.values) + ) + else: + output = sparse_op(x1, x2) + output.set_shape(x1.shape) + return output + else: + # x1 is a SparseTensor. + if densify_mixed: + x1 = sparse_to_dense(x1) + else: + if not hasattr(x2, "shape") or len(x2.shape) == 0: + # x2 is a scalar, broadcast. + x2 = broadcast_scalar_to_sparse_shape(x2, x1) + return sparse_op(x1, x2) + elif isinstance(x2, tf.SparseTensor): + # x2 is a SparseTensor. + if densify_mixed: + x2 = sparse_to_dense(x2) + else: + if not hasattr(x1, "shape") or len(x1.shape) == 0: + # x1 is a scalar, broadcast. + x1 = broadcast_scalar_to_sparse_shape(x1, x2) + return sparse_op(x1, x2) + elif isinstance(x1, tf.IndexedSlices): + if isinstance(x2, tf.IndexedSlices): + # x1 is an IndexedSlices and x2 is an IndexedSlices. + if x1.indices is x2.indices: + return tf.IndexedSlices( + func(x1.values, x2.values), + x1.indices, + x1.dense_shape, + ) + else: + # Compute the union of indices. + ( + union_indices, + x1_values_for_union, + x2_values_for_union, + ) = indexed_slices_union_indices_and_values( + x1, x2.indices, x2.values + ) + # Now, it is an element-wise operation on the union. + return tf.IndexedSlices( + func( + x1_values_for_union, + x2_values_for_union, + ), + union_indices, + x1.dense_shape, + ) + else: + # x1 is an IndexedSlices, densify. + x1 = tf.convert_to_tensor(x1) + elif isinstance(x2, tf.IndexedSlices): + # x2 is an IndexedSlices, densify. + x2 = tf.convert_to_tensor(x2) + return func(x1, x2) + + return sparse_wrapper + + return wrap_elementwise_binary_union + + +def elementwise_binary_intersection(func): + """Decorator to add support for `tf.SparseTensor` and `tf.IndexedSlices` to + an element-wise binary operator such that the indices present in the result + are the intersection of the indices in the two operand. + + The primary use case for this is the `multiply` operator. + + There are requirements on the operator for this decorator to work correctly: + + - The operator must be element-wise. + - The operator must be binary (two input tensors and one output tensor). + - Both inputs must be of the same shape or one input must be a scalar. + - The output must be of the same shape as the (non scalar) inputs. + - The indices of the output must be the intersection of the indices of the + inputs. This implies that func(0, x) and func(x, 0) must be 0 for any x. + As a result, if one operand is dense or a scalar, then the indices are the + ones from the other operand. + + Additional arguments to the function (besides the input tensors) are not + supported. + + Note that if the operands contains zero values at some common indices, the + indices and the zero values are preserved. + + Args: + func: The function to wrap. + Returns: + Wrapped function that supports `tf.SparseTensor` and `tf.IndexedSlices`. + """ + + @functools.wraps(func) + def sparse_wrapper(x1, x2): + if isinstance(x1, tf.SparseTensor): + if isinstance(x2, tf.SparseTensor): + # x1 is a SparseTensor and x2 is a SparseTensor. + if x1.indices is x2.indices: + return sparse_with_values(x1, func(x1.values, x2.values)) + else: + # Compute the intersection of indices. + ( + intersection_indices, + x1_values_for_intersection, + x2_values_for_intersection, + ) = sparse_intersection_indices_and_values(x1, x2) + # Now, it is an element-wise operation on the intersection. + output = tf.SparseTensor( + intersection_indices, + func( + x1_values_for_intersection, + x2_values_for_intersection, + ), + x1.dense_shape, + ) + output.set_shape(x1.shape) + return output + else: + # x1 is a SparseTensor. + if not hasattr(x2, "shape") or len(x2.shape) == 0: + # x2 is a scalar, apply func element-wise. + return sparse_with_values(x1, func(x1.values, x2)) + else: + # x2 is dense, gather values from x1 indices. + return sparse_with_values( + x1, func(x1.values, tf.gather_nd(x2, x1.indices)) + ) + elif isinstance(x2, tf.SparseTensor): + # x2 is a SparseTensor. + if not hasattr(x1, "shape") or len(x1.shape) == 0: + # x1 is a scalar, apply func element-wise. + return sparse_with_values(x2, func(x1, x2.values)) + else: + # x1 is dense, gather values from x2 indices. + return sparse_with_values( + x2, func(tf.gather_nd(x1, x2.indices), x2.values) + ) + elif isinstance(x1, tf.IndexedSlices): + if isinstance(x2, tf.IndexedSlices): + # x1 is an IndexedSlices and x2 is an IndexedSlices. + if x1.indices is x2.indices: + return tf.IndexedSlices( + func(x1.values, x2.values), x1.indices, x1.dense_shape + ) + else: + # Compute the intersection of indices. + ( + intersection_indices, + x1_values_for_intersection, + x2_values_for_intersection, + ) = indexed_slices_intersection_indices_and_values(x1, x2) + # Now, it is an element-wise operation on the intersection. + return tf.IndexedSlices( + func( + x1_values_for_intersection, + x2_values_for_intersection, + ), + intersection_indices, + x1.dense_shape, + ) + else: + # x1 is an IndexedSlices. + if not hasattr(x2, "shape") or len(x2.shape) == 0: + # x2 is a scalar, apply func element-wise. + return tf.IndexedSlices( + func(x1.values, x2), x1.indices, x1.dense_shape + ) + else: + # x2 is dense, gather values from x1 indices. + return tf.IndexedSlices( + func(x1.values, tf.gather(x2, x1.indices)), + x1.indices, + x1.dense_shape, + ) + elif isinstance(x2, tf.IndexedSlices): + # x2 is an IndexedSlices. + if not hasattr(x1, "shape") or len(x1.shape) == 0: + # x1 is a scalar, apply func element-wise. + return tf.IndexedSlices( + func(x1, x2.values), x2.indices, x2.dense_shape + ) + else: + # x1 is dense, gather values from x2 indices. + return tf.IndexedSlices( + func(tf.gather(x1, x2.indices), x2.values), + x2.indices, + x2.dense_shape, + ) + # Default case, no SparseTensor and no IndexedSlices. + return func(x1, x2) + + return sparse_wrapper + + +def elementwise_division(func): + """Decorator to add support for `tf.SparseTensor` and `tf.IndexedSlices` to + element-wise binary division and related operators. + + This decorator is designed for operations related to the division of two + operands (e.g. `divide`). It accepts `tf.SparseTensor` and + `tf.IndexedSlices` for both the dividend and the divisor, but handles them + differently based on whether they are the dividend or the divisor. + + - If the divisor is a `tf.SparseTensor` or `tf.IndexedSlices`, it is + densified and the result is dense because the result contains Inf or Nan + outside of the indices of the dividend. + - If the dividend is a `tf.SparseTensor` or `tf.IndexedSlices` and the + divisor is dense, it finds occurrences of zeros and NaNs in the divisor. + The result may therefore have more indices than there were in the dividend + to return correct values where the divisor was zero or NaN. + - If the dividend is a `tf.SparseTensor` or `tf.IndexedSlices` and the + divisor is a scalar, it does the division element-wise. Note that the + result is incorrectly sparse if the scalar divisor is zero. + + Args: + func: The function to wrap. + Returns: + Wrapped function that supports `tf.SparseTensor` and `tf.IndexedSlices`. + """ + + @functools.wraps(func) + def sparse_wrapper(x1, x2): + if isinstance(x1, tf.SparseTensor): + if isinstance(x2, tf.SparseTensor): + # x1 is a SparseTensor and x2 is a SparseTensor. + # Divisor is sparse, meaning we're doing divisions by zero + # outside of x2.indices, so the result is dense. Densify both. + x1 = sparse_to_dense(x1) + x2 = sparse_to_dense(x2) + else: + # x1 is a SparseTensor. + if not hasattr(x2, "shape") or len(x2.shape) == 0: + # x2 is a scalar, apply func element-wise. + return sparse_with_values(x1, func(x1.values, x2)) + else: + # x2 is dense. + x2_zeros_and_nans = tf.equal(x2, 0) + if not tf.as_dtype(x2.dtype).is_integer: + x2_zeros_and_nans = tf.math.logical_or( + x2_zeros_and_nans, tf.math.is_nan(x2) + ) + + def func_for_x1_indices(): + # Gather values from x1 indices. + return sparse_with_values( + x1, func(x1.values, tf.gather_nd(x2, x1.indices)) + ) + + def func_for_union_indices(): + # Compute the union of indices to keep zeros and NaNs. + x2_zeros_and_nan_indices = tf.where(x2_zeros_and_nans) + ( + union_indices, + x1_values_for_union, + _, + ) = sparse_union_indices_and_values( + x1, x2_zeros_and_nan_indices + ) + output = tf.SparseTensor( + union_indices, + func( + x1_values_for_union, + tf.gather_nd(x2, union_indices), + ), + x1.dense_shape, + ) + output.set_shape(x1.shape) + return output + + return tf.cond( + tf.reduce_any(x2_zeros_and_nans), + func_for_union_indices, + func_for_x1_indices, + ) + elif isinstance(x2, tf.SparseTensor): + # x2 is a SparseTensor. + # Divisor is sparse, densify to do the divisions by zero correctly. + x2 = sparse_to_dense(x2) + elif isinstance(x1, tf.IndexedSlices): + if isinstance(x2, tf.IndexedSlices): + # x1 is an IndexedSlices and x2 is an IndexedSlices. + # Divisor is slices, meaning we're doing divisions by zero + # outside of x2.indices, so the result is dense. Densify both. + x1 = tf.convert_to_tensor(x1) + x2 = tf.convert_to_tensor(x2) + else: + # x1 is a IndexedSlices. + if not hasattr(x2, "shape") or len(x2.shape) == 0: + # x2 is a scalar, apply func element-wise. + return tf.IndexedSlices( + func(x1.values, x2), x1.indices, x1.dense_shape + ) + else: + # x2 is dense. + x2_zeros_and_nans = tf.equal(x2, 0) + if not tf.as_dtype(x2.dtype).is_integer: + x2_zeros_and_nans = tf.math.logical_or( + x2_zeros_and_nans, tf.math.is_nan(x2) + ) + x2_zeros_and_nans = tf.reduce_any( + x2_zeros_and_nans, axis=tuple(range(1, x2.shape.rank)) + ) + + def func_for_x1_indices(): + # Gather values from x1 indices. + return tf.IndexedSlices( + func(x1.values, tf.gather(x2, x1.indices)), + x1.indices, + x1.dense_shape, + ) + + def func_for_union_indices(): + x2_zeros_and_nan_indices = tf.squeeze( + tf.where(x2_zeros_and_nans), axis=-1 + ) + # Compute the union of indices to keep zeros and NaNs. + ( + union_indices, + x1_values_for_union, + _, + ) = indexed_slices_union_indices_and_values( + x1, x2_zeros_and_nan_indices + ) + return tf.IndexedSlices( + func( + x1_values_for_union, + tf.gather(x2, union_indices), + ), + union_indices, + x1.dense_shape, + ) + + return tf.cond( + tf.reduce_any(x2_zeros_and_nans), + func_for_union_indices, + func_for_x1_indices, + ) + elif isinstance(x2, tf.IndexedSlices): + # x2 is a IndexedSlices. + # Divisor is slices, densify to do the divisions by zero correctly. + x2 = tf.convert_to_tensor(x2) + # Default case, no SparseTensor and no IndexedSlices. + return func(x1, x2) + + return sparse_wrapper diff --git a/keras/src/backend/tensorflow/tensorboard.py b/keras/src/backend/tensorflow/tensorboard.py new file mode 100644 index 000000000000..cf1c4c5102d8 --- /dev/null +++ b/keras/src/backend/tensorflow/tensorboard.py @@ -0,0 +1,21 @@ +from keras.src.utils.module_utils import tensorflow as tf + + +def start_trace(logdir): + tf.profiler.experimental.start(logdir=logdir) + + +def stop_trace(save): + tf.profiler.experimental.stop(save=save) + + +def start_batch_trace(batch): + batch_trace_context = tf.profiler.experimental.Trace( + "Profiled batch", step_num=batch + ) + batch_trace_context.__enter__() + return batch_trace_context + + +def stop_batch_trace(batch_trace_context): + batch_trace_context.__exit__(None, None, None) diff --git a/keras/src/backend/tensorflow/trackable.py b/keras/src/backend/tensorflow/trackable.py new file mode 100644 index 000000000000..e14b2996af34 --- /dev/null +++ b/keras/src/backend/tensorflow/trackable.py @@ -0,0 +1,58 @@ +import tensorflow as tf + +from keras.src.utils import tracking + + +class KerasAutoTrackable(tf.__internal__.tracking.AutoTrackable): + """Manages dependencies on other objects with Keras tracking. + + Similar to TF AutoTrackable, but disabling tracking is based + on tracking within Keras. + + This serves as an interface between Keras tracking and TF tracking. + """ + + def __setattr__(self, name, value): + """Support self.foo = trackable syntax.""" + try: + if getattr(self, name) is value: + # Short circuit for `self.$x = self.$x`. + return + except AttributeError: + pass + + if getattr(self, "_self_setattr_tracking", True): + value = sticky_attribute_assignment( + trackable=self, value=value, name=name + ) + super().__setattr__(name, value) + + +def sticky_attribute_assignment(trackable, name, value): + """Adds dependencies, called from __setattr__. + + Args: + trackable: The object to add dependencies to (generally the one having + an attribute assigned). + name: The attribute name being assigned. + value: The value being assigned. Not necessarily a trackable object. + + Returns: + The value which should be stored in the attribute. + """ + if isinstance( + value, (tracking.TrackedList, tracking.TrackedDict, tracking.TrackedSet) + ) and hasattr(trackable, "_tracked"): + trackable._tracked.append(name) + if not tracking.is_tracking_enabled(): + return value + if isinstance(value, tf.__internal__.tracking.Trackable): + trackable._track_trackable( # pylint: disable=protected-access + value, + name=name, + # Allow the user to switch the Trackable which is tracked by this + # name, since assigning a new variable to an attribute has + # historically been fine (e.g. Adam did this). + overwrite=True, + ) + return value diff --git a/keras/src/backend/tensorflow/trainer.py b/keras/src/backend/tensorflow/trainer.py new file mode 100644 index 000000000000..cd6410999dd2 --- /dev/null +++ b/keras/src/backend/tensorflow/trainer.py @@ -0,0 +1,1004 @@ +import contextlib +import functools +import warnings + +import numpy as np +import tensorflow as tf +from tensorflow.python.eager import context as tf_context + +from keras.src import callbacks as callbacks_module +from keras.src import metrics as metrics_module +from keras.src import optimizers as optimizers_module +from keras.src import tree +from keras.src.backend import config +from keras.src.losses import loss as loss_module +from keras.src.trainers import trainer as base_trainer +from keras.src.trainers.data_adapters import array_slicing +from keras.src.trainers.data_adapters import data_adapter_utils +from keras.src.trainers.epoch_iterator import EpochIterator +from keras.src.utils import traceback_utils + + +class TensorFlowTrainer(base_trainer.Trainer): + def __init__(self): + super().__init__() + self.train_function = None + self.test_function = None + self.predict_function = None + + # Specifies how many steps of the step_per_execution loop to unroll. + # Increasing this value can reduce kernel launch overhead, + # but will increase memory usage and compilation time. + self.unrolled_steps_per_execution = 1 + + # Model must be created under scope of DistStrat it will be trained + # with. + if tf.distribute.has_strategy(): + self._distribute_strategy = tf.distribute.get_strategy() + else: + self._distribute_strategy = None + + @property + def distribute_strategy(self): + return self._distribute_strategy or tf.distribute.get_strategy() + + @property + def distribute_reduction_method(self): + return self._distribute_reduction_method or "auto" + + @distribute_reduction_method.setter + def distribute_reduction_method(self, value): + self._distribute_reduction_method = value + + def train_step(self, data): + x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(data) + + # Forward pass + with tf.GradientTape() as tape: + if self._call_has_training_arg: + y_pred = self(x, training=True) + else: + y_pred = self(x) + loss = self._compute_loss( + x=x, + y=y, + y_pred=y_pred, + sample_weight=sample_weight, + training=True, + ) + self._loss_tracker.update_state( + loss_module.unscale_loss_for_distribution(loss), + sample_weight=tf.shape( + next(i for i in tree.flatten(x) if i is not None) + )[0], + ) + if self.optimizer is not None: + loss = self.optimizer.scale_loss(loss) + + # Compute gradients + if self.trainable_weights: + trainable_weights = self.trainable_weights + gradients = tape.gradient(loss, trainable_weights) + + # Update weights + self.optimizer.apply_gradients(zip(gradients, trainable_weights)) + else: + warnings.warn("The model does not have any trainable weights.") + + return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight) + + def test_step(self, data): + x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(data) + if self._call_has_training_arg: + y_pred = self(x, training=False) + else: + y_pred = self(x) + loss = self._compute_loss( + x=x, y=y, y_pred=y_pred, sample_weight=sample_weight, training=False + ) + self._loss_tracker.update_state( + loss_module.unscale_loss_for_distribution(loss), + sample_weight=tf.shape( + next(i for i in tree.flatten(x) if i is not None) + )[0], + ) + return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight) + + def predict_step(self, data): + x, _, _ = data_adapter_utils.unpack_x_y_sample_weight(data) + if self._call_has_training_arg: + y_pred = self(x, training=False) + else: + y_pred = self(x) + return y_pred + + def _autoconvert_optionals(self, step_func): + # Wrapper converting (nested) TF Optional in input data to None + @functools.wraps(step_func) + def wrapper(data): + converted_data = tree.map_structure( + lambda i: ( + None if isinstance(i, tf.experimental.Optional) else i + ), + data, + ) + result = step_func(converted_data) + return result + + return wrapper + + def _make_function(self, step_function): + @tf.autograph.experimental.do_not_convert + def one_step_on_data(data): + """Runs a single training step on a batch of data.""" + outputs = self.distribute_strategy.run(step_function, args=(data,)) + outputs = reduce_per_replica( + outputs, + self.distribute_strategy, + reduction="auto", + ) + return outputs + + if not self.run_eagerly: + one_step_on_data = tf.function( + one_step_on_data, + reduce_retracing=True, + jit_compile=self.jit_compile, + ) + one_step_on_data = self._autoconvert_optionals(one_step_on_data) + + @tf.autograph.experimental.do_not_convert + def multi_step_on_iterator(iterator): + if self.steps_per_execution == 1: + return tf.experimental.Optional.from_value( + one_step_on_data(iterator.get_next()) + ) + + # the spec is set lazily during the tracing of `tf.while_loop` + empty_outputs = tf.experimental.Optional.empty(None) + + def cond(execution_step, optional_outputs, next_optional_inputs): + return tf.logical_and( + tf.less(execution_step, self.steps_per_execution), + next_optional_inputs.has_value(), + ) + + def inner_body( + execution_step, optional_outputs, next_optional_inputs + ): + def has_next(): + next_optional_outputs = tf.experimental.Optional.from_value( + one_step_on_data(next_optional_inputs.get_value()) + ) + empty_outputs._element_spec = ( + next_optional_outputs.element_spec + ) + return next_optional_outputs + + def no_has_next(): + optional_outputs._element_spec = empty_outputs._element_spec + return optional_outputs + + next_optional_outputs = tf.cond( + tf.logical_and( + tf.less(execution_step, self.steps_per_execution), + next_optional_inputs.has_value(), + ), + has_next, + no_has_next, + ) + + return ( + execution_step + 1, + next_optional_outputs, + # We don't want to iterate if we have reached + # `steps_per_execution` steps + tf.cond( + tf.less(execution_step + 1, self.steps_per_execution), + lambda: iterator.get_next_as_optional(), + lambda: next_optional_inputs, + ), + ) + + def body(execution_step, optional_outputs, next_optional_inputs): + for _ in range( + min( + self.unrolled_steps_per_execution, + self.steps_per_execution, + ) + ): + execution_step, optional_outputs, next_optional_inputs = ( + inner_body( + execution_step, + optional_outputs, + next_optional_inputs, + ) + ) + + return (execution_step, optional_outputs, next_optional_inputs) + + execution_step = tf.constant(0) + next_optional_inputs = iterator.get_next_as_optional() + + # Run the while loop + _, final_optional_outputs, _ = tf.while_loop( + cond, + body, + loop_vars=[execution_step, empty_outputs, next_optional_inputs], + ) + final_optional_outputs._element_spec = empty_outputs.element_spec + return final_optional_outputs + + if not self.run_eagerly: + multi_step_on_iterator = tf.function( + multi_step_on_iterator, reduce_retracing=True + ) + + def function(iterator): + if isinstance( + iterator, (tf.data.Iterator, tf.distribute.DistributedIterator) + ): + opt_outputs = multi_step_on_iterator(iterator) + if not opt_outputs.has_value(): + raise StopIteration + return opt_outputs.get_value() + else: + for step, data in zip( + range(self.steps_per_execution), iterator + ): + outputs = one_step_on_data(data) + return outputs + + return function + + def make_train_function(self, force=False): + if self.train_function is not None and not force: + return self.train_function + self.train_function = self._make_function(self.train_step) + + def make_test_function(self, force=False): + if self.test_function is not None and not force: + return self.test_function + self.test_function = self._make_function(self.test_step) + + def make_predict_function(self, force=False): + if self.predict_function is not None and not force: + return self.predict_function + + @tf.autograph.experimental.do_not_convert + def one_step_on_data(data): + """Runs a predict test step on a batch of data.""" + return self.predict_step(data) + + if not self.run_eagerly and self.jit_compile: + one_step_on_data = tf.function( + one_step_on_data, reduce_retracing=True, jit_compile=True + ) + one_step_on_data = self._autoconvert_optionals(one_step_on_data) + + @tf.autograph.experimental.do_not_convert + def one_step_on_data_distributed(data): + data = data[0] + outputs = self.distribute_strategy.run( + one_step_on_data, args=(data,) + ) + outputs = reduce_per_replica( + outputs, + self.distribute_strategy, + reduction="concat", + ) + return outputs + + @tf.autograph.experimental.do_not_convert + def multi_step_on_data(data): + outputs = one_step_on_data_distributed(data[:1]) + for single_step_data in data[1:]: + step_outputs = one_step_on_data_distributed([single_step_data]) + outputs = tree.map_structure( + lambda t1, t2: concat([t1, t2]), outputs, step_outputs + ) + return outputs + + if self.steps_per_execution > 1: + predict_function = multi_step_on_data + else: + predict_function = one_step_on_data_distributed + + if not self.run_eagerly: + predict_function = tf.function( + predict_function, reduce_retracing=True + ) + + self.predict_function = predict_function + + @traceback_utils.filter_traceback + def fit( + self, + x=None, + y=None, + batch_size=None, + epochs=1, + verbose="auto", + callbacks=None, + validation_split=0.0, + validation_data=None, + shuffle=True, + class_weight=None, + sample_weight=None, + initial_epoch=0, + steps_per_epoch=None, + validation_steps=None, + validation_batch_size=None, + validation_freq=1, + ): + self._assert_compile_called("fit") + # Possibly cap epochs for debugging runs. + max_epochs = config.max_epochs() + if max_epochs and max_epochs < epochs: + warnings.warn("Limiting epochs to %d" % max_epochs) + epochs = max_epochs + # TODO: respect compiled trainable state + self._eval_epoch_iterator = None + if validation_split and validation_data is None: + # Create the validation data using the training data. Only supported + # for TF/numpy/jax arrays. + ( + (x, y, sample_weight), + validation_data, + ) = array_slicing.train_validation_split( + (x, y, sample_weight), validation_split=validation_split + ) + + if validation_data is not None: + ( + val_x, + val_y, + val_sample_weight, + ) = data_adapter_utils.unpack_x_y_sample_weight(validation_data) + + # Create an iterator that yields batches for one epoch. + epoch_iterator = TFEpochIterator( + x=x, + y=y, + sample_weight=sample_weight, + batch_size=batch_size, + steps_per_epoch=steps_per_epoch, + shuffle=shuffle, + class_weight=class_weight, + distribute_strategy=self.distribute_strategy, + steps_per_execution=self.steps_per_execution, + ) + + self._maybe_symbolic_build(iterator=epoch_iterator) + epoch_iterator.reset() + + # Container that configures and calls callbacks. + if not isinstance(callbacks, callbacks_module.CallbackList): + callbacks = callbacks_module.CallbackList( + callbacks, + add_history=True, + add_progbar=verbose != 0, + verbose=verbose, + epochs=epochs, + steps=epoch_iterator.num_batches, + model=self, + ) + + self.stop_training = False + self.make_train_function() + callbacks.on_train_begin() + training_logs = None + logs = {} + initial_epoch = self._initial_epoch or initial_epoch + for epoch in range(initial_epoch, epochs): + self.reset_metrics() + callbacks.on_epoch_begin(epoch) + with epoch_iterator.catch_stop_iteration(): + for begin_step, end_step, iterator in epoch_iterator: + callbacks.on_train_batch_begin(begin_step) + logs = self.train_function(iterator) + callbacks.on_train_batch_end(end_step, logs) + if self.stop_training: + break + + # Override with model metrics instead of last step logs if needed. + epoch_logs = dict(self._get_metrics_result_or_logs(logs)) + + # Run validation. + if validation_data is not None and self._should_eval( + epoch, validation_freq + ): + # Create EpochIterator for evaluation and cache it. + if getattr(self, "_eval_epoch_iterator", None) is None: + self._eval_epoch_iterator = TFEpochIterator( + x=val_x, + y=val_y, + sample_weight=val_sample_weight, + batch_size=validation_batch_size or batch_size, + distribute_strategy=self.distribute_strategy, + steps_per_execution=self.steps_per_execution, + steps_per_epoch=validation_steps, + shuffle=False, + ) + val_logs = self.evaluate( + x=val_x, + y=val_y, + sample_weight=val_sample_weight, + batch_size=validation_batch_size or batch_size, + steps=validation_steps, + callbacks=callbacks, + return_dict=True, + _use_cached_eval_dataset=True, + ) + val_logs = { + f"val_{name}": val for name, val in val_logs.items() + } + epoch_logs.update(val_logs) + + callbacks.on_epoch_end(epoch, epoch_logs) + training_logs = epoch_logs + if self.stop_training: + break + + if ( + isinstance(self.optimizer, optimizers_module.Optimizer) + and epochs > 0 + ): + self.optimizer.finalize_variable_values(self.trainable_weights) + + # If _eval_epoch_iterator exists, delete it after all epochs are done. + if getattr(self, "_eval_epoch_iterator", None) is not None: + del self._eval_epoch_iterator + callbacks.on_train_end(logs=training_logs) + return self.history + + @traceback_utils.filter_traceback + def evaluate( + self, + x=None, + y=None, + batch_size=None, + verbose="auto", + sample_weight=None, + steps=None, + callbacks=None, + return_dict=False, + **kwargs, + ): + self._assert_compile_called("evaluate") + # TODO: respect compiled trainable state + use_cached_eval_dataset = kwargs.pop("_use_cached_eval_dataset", False) + if kwargs: + raise ValueError(f"Arguments not recognized: {kwargs}") + + if use_cached_eval_dataset: + epoch_iterator = self._eval_epoch_iterator + else: + # Create an iterator that yields batches of input/target data. + epoch_iterator = TFEpochIterator( + x=x, + y=y, + sample_weight=sample_weight, + batch_size=batch_size, + steps_per_epoch=steps, + shuffle=False, + distribute_strategy=self.distribute_strategy, + steps_per_execution=self.steps_per_execution, + ) + + self._maybe_symbolic_build(iterator=epoch_iterator) + epoch_iterator.reset() + + # Container that configures and calls callbacks. + if not isinstance(callbacks, callbacks_module.CallbackList): + callbacks = callbacks_module.CallbackList( + callbacks, + add_progbar=verbose != 0, + verbose=verbose, + epochs=1, + steps=epoch_iterator.num_batches, + model=self, + ) + + self.make_test_function() + self.stop_evaluating = False + callbacks.on_test_begin() + logs = {} + self.reset_metrics() + with epoch_iterator.catch_stop_iteration(): + for begin_step, end_step, iterator in epoch_iterator: + callbacks.on_test_batch_begin(begin_step) + logs = self.test_function(iterator) + callbacks.on_test_batch_end(end_step, logs) + if self.stop_evaluating: + break + logs = self._get_metrics_result_or_logs(logs) + callbacks.on_test_end(logs) + + if return_dict: + return logs + return self._flatten_metrics_in_order(logs) + + @traceback_utils.filter_traceback + def predict( + self, x, batch_size=None, verbose="auto", steps=None, callbacks=None + ): + # Create an iterator that yields batches of input data. + epoch_iterator = TFEpochIterator( + x=x, + batch_size=batch_size, + steps_per_epoch=steps, + shuffle=False, + distribute_strategy=self.distribute_strategy, + steps_per_execution=self.steps_per_execution, + ) + + # Container that configures and calls callbacks. + if not isinstance(callbacks, callbacks_module.CallbackList): + callbacks = callbacks_module.CallbackList( + callbacks, + add_progbar=verbose != 0, + verbose=verbose, + epochs=1, + steps=epoch_iterator.num_batches, + model=self, + ) + + def append_to_outputs(batch_outputs, outputs): + if outputs is None: + outputs = tree.map_structure( + lambda batch_output: [batch_output], + batch_outputs, + ) + else: + tree.map_structure_up_to( + batch_outputs, + lambda output, batch_output: output.append(batch_output), + outputs, + batch_outputs, + ) + return outputs + + def get_data(iterator): + """Returns data for the next execution.""" + data = [] + for _ in range(self.steps_per_execution): + try: + single_step_data = next(iterator) + except (StopIteration, tf.errors.OutOfRangeError) as e: + if hasattr(data, "__len__") and len(data) > 0: + # Suppress the error when still have remaining data. + return data + else: + # Re-raise the error for + # EpochIterator.catch_stop_iteration() to catch when + # no data left. + raise e + data.append(single_step_data) + return data + + self.make_predict_function() + self.stop_predicting = False + callbacks.on_predict_begin() + outputs = None + with epoch_iterator.catch_stop_iteration(): + for begin_step, end_step, iterator in epoch_iterator: + callbacks.on_predict_batch_begin(begin_step) + data = get_data(iterator) + batch_outputs = self.predict_function(data) + outputs = append_to_outputs(batch_outputs, outputs) + callbacks.on_predict_batch_end( + end_step, {"outputs": batch_outputs} + ) + if self.stop_predicting: + break + callbacks.on_predict_end() + outputs = tree.map_structure_up_to( + batch_outputs, potentially_ragged_concat, outputs + ) + return tree.map_structure(convert_to_np_if_not_ragged, outputs) + + def train_on_batch( + self, + x, + y=None, + sample_weight=None, + class_weight=None, + return_dict=False, + ): + self._assert_compile_called("train_on_batch") + if class_weight is not None: + if sample_weight is not None: + raise ValueError( + "Arguments `sample_weight` and `class_weight` " + "cannot be specified at the same time. " + f"Received: sample_weight={sample_weight}, " + f"class_weight={class_weight}" + ) + sample_weight = data_adapter_utils.class_weight_to_sample_weights( + y, class_weight + ) + + # Maybe build model + self._maybe_symbolic_build(data_batch=(x, y, sample_weight)) + self.make_train_function() + + def data(): + yield (x, y, sample_weight) + + logs = self.train_function(data()) + logs = tree.map_structure(lambda x: np.array(x), logs) + if return_dict: + return logs + return self._flatten_metrics_in_order(logs) + + def test_on_batch( + self, + x, + y=None, + sample_weight=None, + return_dict=False, + ): + self._assert_compile_called("test_on_batch") + + def data(): + yield (x, y, sample_weight) + + # Maybe build model + self._maybe_symbolic_build(data_batch=(x, y, sample_weight)) + self.make_test_function() + + logs = self.test_function(data()) + logs = tree.map_structure(lambda x: np.array(x), logs) + if return_dict: + return logs + return self._flatten_metrics_in_order(logs) + + def predict_on_batch(self, x): + self.make_predict_function() + batch_outputs = self.predict_function([(x,)]) + batch_outputs = tree.map_structure( + convert_to_np_if_not_ragged, batch_outputs + ) + return batch_outputs + + # Backwards compatibility shims. + @property + def compiled_metrics(self): + class DeprecatedCompiledMetric: + def update_state(_, y, y_pred, sample_weight=None): + return self._compiled_metrics_update_state( + y, y_pred, sample_weight=sample_weight + ) + + return DeprecatedCompiledMetric() + + def _compiled_metrics_update_state(self, y, y_pred, sample_weight=None): + warnings.warn( + "`model.compiled_metrics()` is deprecated. " + "Instead, use e.g.:\n" + "```\n" + "for metric in self.metrics:\n" + " metric.update_state(y, y_pred)\n" + "```\n", + stacklevel=2, + ) + for metric in self.metrics: + if isinstance(metric, metrics_module.Mean): + metric.update_state(y_pred, sample_weight=sample_weight) + else: + metric.update_state(y, y_pred, sample_weight=sample_weight) + + def compiled_loss( + self, y, y_pred, sample_weight=None, regularization_losses=None + ): + warnings.warn( + "`model.compiled_loss()` is deprecated. Instead, use " + "`model.compute_loss(x, y, y_pred, sample_weight, training)`.", + ) + return self.compute_loss( + x=None, y=y, y_pred=y_pred, sample_weight=sample_weight + ) + + def loss(self, y, y_pred, sample_weight=None): + warnings.warn( + "`model.loss()` is deprecated. Instead, use " + "`model.compute_loss(x, y, y_pred, sample_weight, training)`.", + ) + return self.compute_loss( + x=None, y=y, y_pred=y_pred, sample_weight=sample_weight + ) + + def _maybe_symbolic_build(self, iterator=None, data_batch=None): + # Only symbolic build when distribute strategy is created in tf trainer + if self._distribute_strategy is None: + # When no distribution strategy is set, defer building + # to when the train/test/predict function gets traced. + # This maximizes backwards compatibility. + return + + # Unlike jax/torch iterator, tf iterator returns an iterator instead + # of data batch in `iterator`. + if iterator is not None: + for _, _, it in iterator: + maybe_distributed_data_batch = next(it) + has_distributed_values = tree.map_structure( + lambda x: isinstance(x, tf.distribute.DistributedValues), + maybe_distributed_data_batch, + ) + if all(tree.flatten(has_distributed_values)): + data_batch = self.distribute_strategy.reduce( + "MEAN", + maybe_distributed_data_batch, + axis=None, + ) + else: + data_batch = maybe_distributed_data_batch + break + with self.distribute_strategy.scope(): + self._symbolic_build(data_batch=data_batch) + + def _aggregate_additional_loss(self, loss): + loss = super()._aggregate_additional_loss(loss) + return loss_module.scale_loss_for_distribution(loss) + + +class TFEpochIterator(EpochIterator): + def __init__(self, distribute_strategy=None, *args, **kwargs): + super().__init__(*args, **kwargs) + self._distribute_strategy = distribute_strategy + dataset = self.data_adapter.get_tf_dataset() + if not isinstance(dataset, tf.distribute.DistributedDataset): + dataset = self._distribute_strategy.experimental_distribute_dataset( + dataset + ) + self._distributed_dataset = dataset + + def _get_iterator(self): + return self._distributed_dataset + + def tf_sync(self): + tf_context.async_wait() + + def __next__(self): + return next(self._epoch_iterator) + + @contextlib.contextmanager + def catch_stop_iteration(self): + """Catches errors when an iterator runs out of data.""" + with super().catch_stop_iteration(): + try: + yield + self.tf_sync() + except tf.errors.OutOfRangeError: + raise StopIteration + + +def reduce_per_replica(values, strategy, reduction): + """Attempt to reduce the structure `values` to single values. + + Given `values` (a `tf.Tensor` or a `PerReplica` structure), + which represents the values across all the replicas, `reduce_per_replica` + attempts to "reduce" those values and returns the corresponding structure + that represents only single values. + + Currently, `reduce_per_replica` is only used for reducing the metric results + from `tf.distribute.Strategy.run()`. Depending on the underlying + `Strategy` implementation, `values` may be a `PerReplica` object, + which can be thought of as a collection of values across the replicas, + or a `tf.Tensor`, if the strategy has already conducted the reduction + for the downstream library. + + There are five possible outcomes of reduction: + + 1) if the `values` is a structure of simple `tf.Tensor`s, meaning that + reduction is not actually needed, `reduce_per_replica` returns the + structure as-is. + 2) else, if `reduction="auto"`, then the best reduction strategy is + chosen based on the current environment. This should only be used + for training cases (`fit()`). + 3) else, if `reduction="first"`, then `reduce_per_replica` + returns the values of the first replica. This is used in the case of + training and evaluation, where `values` is expected to hold the same + value across the replicas as a result of `Strategy`'s synchronization + across the replicas. + `reduce_per_replica` does not synchronize the values. + 4) else, if `reduction="sum"`, then `reduce_per_replica` returns the sum + of values for all replicas. This may be used in the custom training loop + case, where each replica contain different values which are not + synchronized. + 5) else, if `reduction="concat"`, then `reduce_per_replica` + returns the concatenation of the values across the replicas, along the + axis of dimension 0. This is used in the inference case (`predict()`). + + Args: + values: Structure of `PerReplica` objects or `tf.Tensor`s. + `tf.Tensor`s are returned as-is. + strategy: `tf.distribute.Strategy` object. + reduction: One of `"auto"`, `"first"`, `"concat"`, `"mean"`, or `"sum"`. + `"auto"` will select `"first"` when used under a TPUStrategy, or + `"mean"` otherwise. + + Returns: + Structure of `Tensor`s, representing the result of reduction. + """ + + if reduction == "auto": + if isinstance(strategy, tf.distribute.TPUStrategy): + reduction = "first" + else: + reduction = "mean" + + def _reduce(v): + """Reduce a single `PerReplica` object.""" + if _collective_all_reduce_multi_worker(strategy): + if reduction == "concat": + return _multi_worker_concat(v, strategy) + elif reduction == "sum": + return strategy.reduce("SUM", v) + elif reduction == "mean": + return strategy.reduce("MEAN", v, axis=0) + + if not _is_per_replica_instance(v): + return v + elif reduction == "first": + return strategy.experimental_local_results(v)[0] + elif reduction == "concat": + if _is_tpu_multi_host(strategy): + return _tpu_multi_host_concat(v, strategy) + else: + return concat(strategy.experimental_local_results(v)) + elif reduction == "sum": + return tf.reduce_sum(strategy.experimental_local_results(v)) + elif reduction == "mean": + return tf.reduce_mean( + strategy.experimental_local_results(v), axis=0 + ) + else: + raise ValueError( + "`reduction` must be one of " + '"first", "concat", "mean", "sum", or "auto". ' + f"Received: reduction={reduction}." + ) + + return tree.map_structure(_reduce, values) + + +def _multi_worker_concat(v, strategy): + """Order PerReplica objects for CollectiveAllReduceStrategy and concat.""" + replicas = strategy.gather(v, axis=0) + # v might not have the same shape on different replicas + if _is_per_replica_instance(v): + shapes = tf.concat( + [ + tf.expand_dims(tf.shape(single_value)[0], axis=0) + for single_value in v.values + ], + axis=0, + ) + all_shapes = strategy.gather(shapes, axis=0) + else: + # v is a tensor. This may happen when, say, we have 2x1 multi-worker. + all_shapes = strategy.gather( + tf.expand_dims(tf.shape(v)[0], axis=0), axis=0 + ) + + replicas = tf.split( + replicas, + num_or_size_splits=all_shapes, + num=strategy.num_replicas_in_sync, + ) + ordered_replicas = [] + num_replicas_per_worker = len(strategy.extended.worker_devices) + for replica_id in range(num_replicas_per_worker): + ordered_replicas += replicas[replica_id::num_replicas_per_worker] + return concat(ordered_replicas) + + +def concat(tensors, axis=0): + """Concats `tensor`s along `axis`.""" + if isinstance(tensors[0], tf.SparseTensor): + return tf.sparse.concat(axis=axis, sp_inputs=tensors) + elif _is_scalar(tensors[0]): + return tf.stack(tensors, axis=axis) + else: + return tf.concat(tensors, axis=axis) + + +def _tpu_multi_host_concat(v, strategy): + """Correctly order TPU PerReplica objects.""" + replicas = strategy.experimental_local_results(v) + # When distributed datasets are created from Tensors / NumPy, + # TPUStrategy.experimental_distribute_dataset shards data in + # (Replica, Host) order, and TPUStrategy.experimental_local_results returns + # it in (Host, Replica) order. + num_replicas_per_host = strategy.extended.num_replicas_per_host + ordered_replicas = [] + for replica_id in range(num_replicas_per_host): + ordered_replicas += replicas[replica_id::num_replicas_per_host] + return concat(ordered_replicas) + + +def _collective_all_reduce_multi_worker(strategy): + return ( + isinstance(strategy, tf.distribute.MultiWorkerMirroredStrategy) + ) and strategy.extended._in_multi_worker_mode() + + +def _is_per_replica_instance(obj): + return isinstance(obj, tf.distribute.DistributedValues) and isinstance( + obj, tf.__internal__.CompositeTensor + ) + + +def _is_scalar(x): + return isinstance(x, (tf.Tensor, tf.Variable)) and x.shape.rank == 0 + + +def _is_tpu_multi_host(strategy): + return _is_tpu_strategy(strategy) and strategy.extended.num_hosts > 1 + + +def _is_tpu_strategy(strategy): + return _is_tpu_strategy_class(strategy.__class__) + + +def _is_tpu_strategy_class(clz): + def is_tpu_strat(k): + return k.__name__.startswith("TPUStrategy") + + if is_tpu_strat(clz): + return True + return any(map(_is_tpu_strategy_class, clz.__bases__)) + + +def convert_to_np_if_not_ragged(x): + if isinstance(x, tf.RaggedTensor): + return x + elif isinstance(x, tf.SparseTensor): + return x + return x.numpy() + + +def potentially_ragged_concat(tensors): + """Concats `Tensor`s along their first dimension. + + Args: + tensors: List of `Tensor`s. + + Returns: + Concatenation of the inputs along the first dimension -- of type + `np.ndarray` if all input shapes are compatible, or `tf.RaggedTensor` + if not. + """ + if len(tensors) == 1: + return tensors[0] + elif isinstance(tensors[0], tf.SparseTensor): + return tf.sparse.concat(axis=0, sp_inputs=tensors) + elif isinstance(tensors[0], tf.RaggedTensor): + return tf.concat(tensors, axis=0) + + non_batch_shapes = tf.stack([tf.shape(tensor)[1:] for tensor in tensors]) + constant_dims = tf.math.reduce_all( + non_batch_shapes == non_batch_shapes[:1], axis=0 + ) + if tf.math.reduce_all(constant_dims).numpy().item(): + # All non-batch dims are constant + if _is_scalar(tensors[0]): + return tf.stack(tensors, axis=0) + else: + return tf.concat(tensors, axis=0) + + # First, identify constant inner dimensions by finding the + # rightmost dimension that is not constant + constant_inner_dimensions = ( + constant_dims.numpy().tolist()[::-1].index(False) + ) + # If there are constant inner dimensions, define a constant inner shape + if constant_inner_dimensions == 0: + constant_inner_shape = None + else: + constant_inner_shape = tensors[0].shape[-constant_inner_dimensions:] + return tf.ragged.constant( + [tensor.numpy() for tensor in tensors], inner_shape=constant_inner_shape + ).merge_dims(0, 1) diff --git a/keras/src/backend/tests/compute_output_spec_test.py b/keras/src/backend/tests/compute_output_spec_test.py new file mode 100644 index 000000000000..4d6fa2795f81 --- /dev/null +++ b/keras/src/backend/tests/compute_output_spec_test.py @@ -0,0 +1,111 @@ +import unittest + +import pytest + +from keras.src import backend +from keras.src import ops +from keras.src.backend.common.keras_tensor import KerasTensor + + +def single_arg_test_fn(x): + return ops.concatenate([(x + 1) ** 2, x], axis=-1) + + +def three_args_2_kwarg_test_fn(x1, x2, x3=None): + x1 = ops.max(x1, axis=1) + x2 = ops.max(x2, axis=1) + if x3 is not None: + x1 += ops.max(x3, axis=1) + return x1 + x2 + + +class ComputeOutputSpecTest(unittest.TestCase): + def test_dynamic_batch_size(self): + x = KerasTensor(shape=(None, 3, 5)) + y = backend.compute_output_spec(single_arg_test_fn, x) + self.assertEqual(y.shape, (None, 3, 10)) + + x1 = KerasTensor(shape=(None, 3, 5)) + x2 = KerasTensor(shape=(None, 3, 5)) + x3 = KerasTensor(shape=(None, 3, 5)) + y = backend.compute_output_spec( + three_args_2_kwarg_test_fn, x1, x2, x3=x3 + ) + self.assertEqual(y.shape, (None, 5)) + + def test_dynamic_everything(self): + x = KerasTensor(shape=(2, None, 3)) + y = backend.compute_output_spec(single_arg_test_fn, x) + self.assertEqual(y.shape, (2, None, 6)) + + x1 = KerasTensor(shape=(None, None, 5)) + x2 = KerasTensor(shape=(None, None, 5)) + x3 = KerasTensor(shape=(None, None, 5)) + y = backend.compute_output_spec( + three_args_2_kwarg_test_fn, x1, x2, x3=x3 + ) + self.assertEqual(y.shape, (None, 5)) + + @pytest.mark.skipif( + not backend.SUPPORTS_SPARSE_TENSORS, + reason="Backend does not support sparse tensors.", + ) + def test_sparse_to_sparse(self): + def single_arg_sparse_fn(x): + y0 = ops.transpose(x, axes=(0, 2, 1)) + y1 = ops.squeeze(ops.expand_dims(x, axis=3), axis=3) + return (y0, y1) + + x = KerasTensor(shape=(None, 3, 3), sparse=True) + ys = backend.compute_output_spec(single_arg_sparse_fn, x) + for y in ys: + self.assertEqual(y.shape, (None, 3, 3)) + self.assertTrue(y.sparse) + + def three_args_sparse_fn(x1, x2, x3=None): + y0 = ops.add(x1, x2) # sparse, sparse + y1 = ops.divide(x1, x3) # sparse, dense + y2 = ops.matmul(x1, x2) # sparse, sparse + y3 = ops.multiply(x1, x3) # sparse, dense + return (y0, y1, y2, y3) + + x1 = KerasTensor(shape=(None, 3, 3), sparse=True) + x2 = KerasTensor(shape=(None, 3, 3), sparse=True) + x3 = KerasTensor(shape=(None, 3, 3), sparse=False) + ys = backend.compute_output_spec(three_args_sparse_fn, x1, x2, x3=x3) + for y in ys: + self.assertEqual(y.shape, (None, 3, 3)) + self.assertTrue(y.sparse) + + @pytest.mark.skipif( + not backend.SUPPORTS_SPARSE_TENSORS, + reason="Backend does not support sparse tensors.", + ) + def test_sparse_to_dense(self): + def single_arg_dense_fn(x): + y0 = ops.exp(x) + return (y0,) + + x = KerasTensor(shape=(None, 3, 3), sparse=True) + ys = backend.compute_output_spec(single_arg_dense_fn, x) + for y in ys: + self.assertEqual(y.shape, (None, 3, 3)) + self.assertFalse(y.sparse) + + def three_args_dense_fn(x1, x2, x3=None): + y0 = ops.add(x1, x2) # sparse, dense + y1 = ops.add(x2, x1) # dense, sparse + y2 = ops.concatenate([x1, x2], axis=0) # sparse, dense + y3 = ops.matmul(x1, x2) # sparse, dense + y4 = ops.matmul(x2, x1) # dense, sparse + y5 = ops.take(x2, indices=x3, axis=1) # dense, sparse + y6 = ops.divide(x1, x1) # sparse, sparse + return (y0, y1, y2, y3, y4, y5, y6) + + x1 = KerasTensor(shape=(None, 3, 3), sparse=True) + x2 = KerasTensor(shape=(None, 3, 3), sparse=False) + x3 = KerasTensor(shape=(3,), dtype="int64", sparse=True) + ys = backend.compute_output_spec(three_args_dense_fn, x1, x2, x3=x3) + for y in ys: + self.assertEqual(y.shape, (None, 3, 3)) + self.assertFalse(y.sparse) diff --git a/keras/src/backend/tests/device_scope_test.py b/keras/src/backend/tests/device_scope_test.py new file mode 100644 index 000000000000..0b0f2f91c4d6 --- /dev/null +++ b/keras/src/backend/tests/device_scope_test.py @@ -0,0 +1,106 @@ +import pytest + +from keras.src import backend +from keras.src import testing + + +class DeviceTest(testing.TestCase): + @pytest.mark.skipif(backend.backend() != "tensorflow", reason="tf only") + def test_tf_device_scope(self): + import tensorflow as tf + + if not tf.config.list_physical_devices("GPU"): + self.skipTest("Need at least one GPU for testing") + + with backend.device("cpu:0"): + t = backend.numpy.ones((2, 1)) + self.assertIn("CPU:0", t.device) + with backend.device("CPU:0"): + t = backend.numpy.ones((2, 1)) + self.assertIn("CPU:0", t.device) + + # When leaving the scope, the device should be back with gpu:0 + t = backend.numpy.ones((2, 1)) + self.assertIn("GPU:0", t.device) + + # Also verify the explicit gpu device + with backend.device("gpu:0"): + t = backend.numpy.ones((2, 1)) + self.assertIn("GPU:0", t.device) + + @pytest.mark.skipif(backend.backend() != "jax", reason="jax only") + def test_jax_device_scope(self): + import jax + + def get_device(t): + # After updating to Jax 0.4.33, Directly access via t.device attr. + return list(t.devices())[0] + + platform = jax.default_backend() + + if platform != "gpu": + self.skipTest("Need at least one GPU for testing") + + with backend.device("cpu:0"): + t = backend.numpy.ones((2, 1)) + self.assertEqual(get_device(t), jax.devices("cpu")[0]) + with backend.device("CPU:0"): + t = backend.numpy.ones((2, 1)) + self.assertEqual(get_device(t), jax.devices("cpu")[0]) + + # When leaving the scope, the device should be back with gpu:0 + t = backend.numpy.ones((2, 1)) + self.assertEqual(get_device(t), jax.devices("gpu")[0]) + + # Also verify the explicit gpu device + with backend.device("gpu:0"): + t = backend.numpy.ones((2, 1)) + self.assertEqual(get_device(t), jax.devices("gpu")[0]) + + @pytest.mark.skipif(backend.backend() != "jax", reason="jax only") + def test_invalid_jax_device(self): + with self.assertRaisesRegex(ValueError, "Received: device_name='123'"): + backend.device(123).__enter__() + + @pytest.mark.skipif(backend.backend() != "torch", reason="torch only") + def test_torch_device_scope(self): + import torch + + with backend.device("cpu:0"): + t = backend.numpy.ones((2, 1)) + self.assertEqual(t.device, torch.device("cpu")) + with backend.device("CPU:0"): + t = backend.numpy.ones((2, 1)) + self.assertEqual(t.device, torch.device("cpu")) + + # Need at least one GPU for the following testing. + if not torch.cuda.is_available(): + return + + # When leaving the scope, the device should be back with gpu:0 + t = backend.numpy.ones((2, 1)) + self.assertEqual(t.device, torch.device("cuda", 0)) + + # Also verify the explicit gpu -> cuda conversion + with backend.device("gpu:0"): + t = backend.numpy.ones((2, 1)) + self.assertEqual(t.device, torch.device("cuda", 0)) + + @pytest.mark.skipif(backend.backend() != "torch", reason="torch only") + def test_invalid_torch_device(self): + with self.assertRaisesRegex(ValueError, "Received: device_name='123'"): + backend.device(123).__enter__() + + @pytest.mark.skipif(backend.backend() != "torch", reason="torch only") + def test_torch_meta_device(self): + import torch + + with torch.device("meta"): + x = torch.ones(5) + + t = backend.convert_to_tensor(x) + + if not torch.cuda.is_available(): + self.assertEqual(t.device, torch.device("cpu")) + else: + self.assertEqual(t.device, torch.device("cuda", 0)) diff --git a/keras/src/backend/torch/__init__.py b/keras/src/backend/torch/__init__.py new file mode 100644 index 000000000000..371a62cd0f52 --- /dev/null +++ b/keras/src/backend/torch/__init__.py @@ -0,0 +1,45 @@ +"""Torch backend APIs. + +# Note on device placement + +Torch has a different device placement style compared to TF and JAX. +In short, variables/tensors are not created on GPU by default, +and the GPU cannot directly communicate with the CPU. +To bring Torch behavior in line with TF and JAX automated device placement, +we are doing the following to automate device placement if a GPU is available: + +- Variables are created on GPU. +- Input data will be placed on GPU at the first `keras.layers.Layer` call. +- Tensor creation happens on GPU, e.g., `zeros()` will create a tensor on GPU. +- `convert_to_numpy` will bring the tensor to CPU before converting it to NumPy. +""" + +from keras.src.backend.common.name_scope import name_scope +from keras.src.backend.torch import core +from keras.src.backend.torch import image +from keras.src.backend.torch import linalg +from keras.src.backend.torch import math +from keras.src.backend.torch import nn +from keras.src.backend.torch import numpy +from keras.src.backend.torch import random +from keras.src.backend.torch.core import IS_THREAD_SAFE +from keras.src.backend.torch.core import SUPPORTS_RAGGED_TENSORS +from keras.src.backend.torch.core import SUPPORTS_SPARSE_TENSORS +from keras.src.backend.torch.core import Variable +from keras.src.backend.torch.core import cast +from keras.src.backend.torch.core import compute_output_spec +from keras.src.backend.torch.core import cond +from keras.src.backend.torch.core import convert_to_numpy +from keras.src.backend.torch.core import convert_to_tensor +from keras.src.backend.torch.core import device_scope +from keras.src.backend.torch.core import is_tensor +from keras.src.backend.torch.core import random_seed_dtype +from keras.src.backend.torch.core import scatter +from keras.src.backend.torch.core import shape +from keras.src.backend.torch.core import stop_gradient +from keras.src.backend.torch.core import to_torch_dtype +from keras.src.backend.torch.core import vectorized_map +from keras.src.backend.torch.rnn import cudnn_ok +from keras.src.backend.torch.rnn import gru +from keras.src.backend.torch.rnn import lstm +from keras.src.backend.torch.rnn import rnn diff --git a/keras/src/backend/torch/core.py b/keras/src/backend/torch/core.py new file mode 100644 index 000000000000..877dc6909ea1 --- /dev/null +++ b/keras/src/backend/torch/core.py @@ -0,0 +1,732 @@ +import builtins +import contextlib +import functools + +import ml_dtypes +import numpy as np +import torch + +from keras.src import tree +from keras.src.backend.common import KerasVariable +from keras.src.backend.common import global_state +from keras.src.backend.common import standardize_dtype +from keras.src.backend.common.backend_utils import slice_along_axis +from keras.src.backend.common.dtypes import result_type +from keras.src.backend.common.keras_tensor import KerasTensor +from keras.src.backend.common.stateless_scope import StatelessScope +from keras.src.backend.common.stateless_scope import get_stateless_scope +from keras.src.backend.common.stateless_scope import in_stateless_scope +from keras.src.backend.common.symbolic_scope import SymbolicScope +from keras.src.backend.config import floatx + +SUPPORTS_SPARSE_TENSORS = False +SUPPORTS_RAGGED_TENSORS = False +IS_THREAD_SAFE = True + +# Some operators such as 'aten::_foreach_mul_.Scalar' +# are not currently implemented for the MPS device. +# check https://github.com/pytorch/pytorch/issues/77764. +if torch.backends.mps.is_available(): + DEFAULT_DEVICE = "mps" +elif torch.cuda.is_available(): + DEFAULT_DEVICE = "cuda" +elif hasattr(torch, "xpu") and torch.xpu.is_available(): + DEFAULT_DEVICE = "xpu" +else: + DEFAULT_DEVICE = "cpu" + +TORCH_DTYPES = { + "float16": torch.float16, + "float32": torch.float32, + "float64": torch.float64, + "uint8": torch.uint8, + "uint16": torch.int32, # TODO: Torch doesn't have `uint16` dtype. + "uint32": torch.int64, # TODO: Torch doesn't have `uint32` dtype. + "int8": torch.int8, + "int16": torch.int16, + "int32": torch.int32, + "int64": torch.int64, + "bfloat16": torch.bfloat16, + "bool": torch.bool, + "float8_e4m3fn": torch.float8_e4m3fn, + "float8_e5m2": torch.float8_e5m2, + "complex32": torch.complex32, + "complex64": torch.complex64, + "complex128": torch.complex128, +} + + +@contextlib.contextmanager +def device_scope(device_name): + previous_device = global_state.get_global_attribute("torch_device", None) + current_device = _parse_device_input(device_name) + global_state.set_global_attribute("torch_device", current_device) + try: + yield torch.device(current_device) + finally: + global_state.set_global_attribute("torch_device", previous_device) + + +def get_device(): + device = global_state.get_global_attribute("torch_device", None) + if device is None: + return DEFAULT_DEVICE + return device + + +def _parse_device_input(device_name): + if isinstance(device_name, str): + # We support string value like "cpu:0", "gpu:1", and need to convert + # "gpu" to "cuda" + device_name = device_name.lower() + if "gpu" in device_name: + device_name = device_name.replace("gpu", "cuda") + else: + raise ValueError( + "Invalid value for argument `device_name`. " + "Expected a string like 'gpu:0' or 'cpu'. " + f"Received: device_name='{device_name}'" + ) + # The torch.Device instance can be used directly. + return device_name + + +def to_torch_dtype(dtype): + standardized_dtype = TORCH_DTYPES.get(standardize_dtype(dtype), None) + if standardized_dtype is None: + raise ValueError(f"Unsupported dtype for PyTorch: {dtype}") + return standardized_dtype + + +class Variable(KerasVariable): + def _initialize(self, value): + if isinstance(value, torch.nn.Parameter): + # Reuse same parameter + self._value = value + else: + self._value = torch.nn.Parameter( + convert_to_tensor(value, dtype=self._dtype), + requires_grad=self.trainable, + ).to(get_device()) + + def _direct_assign(self, value): + with torch.no_grad(): + self.value.copy_(value) + + def _convert_to_tensor(self, value, dtype=None): + return convert_to_tensor(value, dtype=dtype) + + # Overload native accessor. + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + args = [arg.value if isinstance(arg, Variable) else arg for arg in args] + if kwargs is None: + kwargs = {} + kwargs = { + key: value.value if isinstance(value, Variable) else value + for key, value in kwargs.items() + } + return func(*args, **kwargs) + + def __array__(self, dtype=None): + value = convert_to_numpy(self.value) + if dtype: + return value.astype(dtype) + return value + + @property + def value(self): + # We cannot chain super() here because it will fail TorchDynamo. The + # reason why is unclear. + def maybe_use_symbolic_tensor(value): + # Create and use a symbolic tensor stub in symbolic calls. + if str(get_device()) == "meta" and str(value.device) != "meta": + return torch.nn.Parameter( + torch.empty( + size=self._shape, + dtype=to_torch_dtype(self._dtype), + device="meta", + ), + requires_grad=self.trainable, + ) + return value + + if in_stateless_scope(): + scope = get_stateless_scope() + value = scope.get_current_value(self) + if value is not None: + value = self._maybe_autocast(value) + return maybe_use_symbolic_tensor(value) + if self._value is None: + # Uninitialized variable. Return a placeholder. + # This is fine because it's only ever used + # in during shape inference / graph tracing + # (anything else would be a bug, to be fixed.) + value = self._maybe_autocast( + self._initializer(self._shape, dtype=self._dtype) + ) + else: + value = self._maybe_autocast(self._value) + return maybe_use_symbolic_tensor(value) + + @property + def trainable(self): + return self._trainable + + @trainable.setter + def trainable(self, value): + self._trainable = value + if self._value is not None: + self._value.requires_grad = value + + def __eq__(self, other): + try: + return super().__eq__(other) + except Exception: + return False + + +def convert_to_tensor(x, dtype=None, sparse=None, ragged=None): + if sparse: + raise ValueError("`sparse=True` is not supported with torch backend") + if ragged: + raise ValueError("`ragged=True` is not supported with torch backend") + if isinstance(x, Variable) or is_tensor(x): + if isinstance(x, Variable): + x = x.value + device = get_device() + if x.device != device: + if x.is_meta: + x = torch.empty_like(x, device=device) + else: + x = x.to(device) + if dtype is not None: + x = x.to(to_torch_dtype(dtype)) + return x + if dtype is None: + if isinstance(x, bool): + return torch.as_tensor(x, dtype=torch.bool, device=get_device()) + elif isinstance(x, int): + return torch.as_tensor(x, dtype=torch.int32, device=get_device()) + elif isinstance(x, float): + return torch.as_tensor( + x, dtype=to_torch_dtype(floatx()), device=get_device() + ) + + # Convert to np in case of any array-like that is not list or tuple. + if not isinstance(x, (list, tuple)): + x = np.array(x) + elif len(x) > 0 and any(isinstance(x1, torch.Tensor) for x1 in x): + # Handle list or tuple of torch tensors + return torch.stack([convert_to_tensor(x1) for x1 in x]) + if isinstance(x, np.ndarray): + if x.dtype == np.uint32: + # Torch backend does not support uint32. + x = x.astype(np.int64) + if standardize_dtype(x.dtype) == "bfloat16": + # Torch backend does not support converting bfloat16 ndarray. + x = x.astype(np.float32) + dtype = "bfloat16" + dtype = dtype or x.dtype + if dtype is None: + dtype = result_type( + *[getattr(item, "dtype", type(item)) for item in tree.flatten(x)] + ) + dtype = to_torch_dtype(dtype) + return torch.as_tensor(x, dtype=dtype, device=get_device()) + + +def convert_to_numpy(x): + def transform(x): + if is_tensor(x): + if x.requires_grad: + x = x.detach() + # Tensor has to be moved to CPU before converting to numpy. + if x.device != torch.device("cpu"): + x = x.cpu() + if x.dtype == torch.bfloat16: + # Attempting to call .numpy() on a bfloat16 torch tensor leads + # to an immediate error. Instead we upcast to float32 and then + # convert to the numpy friendly bfloat16 type. + # https://github.com/pytorch/pytorch/issues/90574 + return np.array(x.to(torch.float32)).astype(ml_dtypes.bfloat16) + return np.array(x) + + if isinstance(x, (list, tuple)): + return np.array([transform(e) for e in x]) + return transform(x) + + +def is_tensor(x): + # Using the built-in `isinstance` is recommended by pytorch + # over using torch.is_tensor + # see: https://pytorch.org/docs/stable/generated/torch.is_tensor.html + # + # Also, `torch.is_tensor()` causes issues with dynamo caching when + # a torch.Tensor and numpy.ndarray of the same size, shape, and dtype + # is passed, if called on a Tensor first the second call with ndarray + # will return `True` and vice-versa. + return isinstance(x, torch.Tensor) + + +def shape(x): + # Convert from `torch.Size` to plain tuple. + return tuple(x.shape) + + +def cast(x, dtype): + dtype = to_torch_dtype(dtype) + if isinstance(x, Variable): + x = x.value + if is_tensor(x): + if x.dtype == dtype: + return x + else: + return x.to(dtype) + return convert_to_tensor(x, dtype) + + +# Shape / dtype inference util +def compute_output_spec(fn, *args, **kwargs): + def has_none_shape(x): + """Check for if a `KerasTensor` has dynamic shape.""" + if isinstance(x, KerasTensor): + return None in x.shape + return False + + def convert_keras_tensor_to_torch(x, fill_value=None): + """Convert `KerasTensor`s to `torch.Tensor`s.""" + if isinstance(x, KerasTensor): + shape = list(x.shape) + if fill_value: + for i, e in enumerate(shape): + if e is None: + shape[i] = fill_value + return torch.ones( + size=shape, + dtype=TORCH_DTYPES[x.dtype], + device=get_device(), + ) + return x + + def convert_torch_to_keras_tensor(x): + """Convert `torch.Tensor`s to `KerasTensor`s.""" + if is_tensor(x): + return KerasTensor(x.shape, standardize_dtype(x.dtype)) + return x + + def symbolic_call(fn, args, kwargs, fill_value): + """Call `fn` to infer output shape and dtype.""" + try: + # First try instantiating all tensors on the `"meta"` device, + # which should give a "zero flop" way to trace shape, but does + # not have universal support with torch operations. + with device_scope("meta"): + meta_args, meta_kwargs = tree.map_structure( + lambda x: convert_keras_tensor_to_torch(x, fill_value), + (args, kwargs), + ) + return fn(*meta_args, **meta_kwargs) + except: + with device_scope(DEFAULT_DEVICE): + # If the `"meta"` device placement fails, fall back to tracing + # eagerly with tensors on the default device. This will be + # more robust, but more expensive. + eager_args, eager_kwargs = tree.map_structure( + lambda x: convert_keras_tensor_to_torch(x, fill_value), + (args, kwargs), + ) + return fn(*eager_args, **eager_kwargs) + + with StatelessScope(), SymbolicScope(), torch.no_grad(): + outputs = symbolic_call(fn, args, kwargs, fill_value=83) + + none_in_shape = any( + builtins.map(has_none_shape, tree.flatten((args, kwargs))) + ) + if none_in_shape: + outputs_1 = outputs + outputs_2 = symbolic_call(fn, args, kwargs, fill_value=89) + + flat_out_1 = tree.flatten(outputs_1) + flat_out_2 = tree.flatten(outputs_2) + + flat_out = [] + for x1, x2 in zip(flat_out_1, flat_out_2): + shape = list(x1.shape) + for i, e in enumerate(x2.shape): + if e != shape[i]: + shape[i] = None + flat_out.append(KerasTensor(shape, standardize_dtype(x1.dtype))) + outputs = tree.pack_sequence_as(outputs_1, flat_out) + + output_spec = tree.map_structure(convert_torch_to_keras_tensor, outputs) + return output_spec + + +def cond(pred, true_fn, false_fn): + # When symbolic execution, take pred as true. + if get_device() == "meta": + return true_fn() + + if pred: + return true_fn() + return false_fn() + + +def vectorized_map(function, elements): + return torch.vmap(function)(elements) + + +def map(f, xs): + def g(_, x): + return (), f(x) + + _, ys = scan(g, (), xs) + return ys + + +def scan(f, init, xs=None, length=None, reverse=False, unroll=1): + # Ref: jax.lax.scan + if not callable(f): + raise TypeError(f"`f` should be a callable. Received: f={f}") + if not isinstance(unroll, bool): + if not isinstance(unroll, int) or unroll < 1: + raise ValueError( + "`unroll` must be an positive integer or boolean. " + f"Received: unroll={unroll}" + ) + if xs is None and length is None: + raise ValueError("Got no `xs` to scan over and `length` not provided.") + + input_is_sequence = tree.is_nested(xs) + output_is_sequence = tree.is_nested(init) + + def pack_input(x): + return tree.pack_sequence_as(xs, x) if input_is_sequence else x[0] + + def pack_output(x): + return tree.pack_sequence_as(init, x) if output_is_sequence else x[0] + + if xs is None: + xs_flat = [] + n = int(length) + else: + xs_flat = tree.flatten(xs) + xs_flat = [convert_to_tensor(elem) for elem in xs_flat] + n = int(length) if length is not None else shape(xs_flat[0])[0] + + init_flat = tree.flatten(init) + init_flat = [convert_to_tensor(init) for init in init_flat] + init = pack_output(init_flat) + dummy_y = [torch.zeros_like(init) for init in init_flat] + + carry = init + ys = [] + maybe_reversed = reversed if reverse else lambda x: x + for i in maybe_reversed(range(n)): + xs_slice = [x[i] for x in xs_flat] + packed_xs = pack_input(xs_slice) if len(xs_slice) > 0 else None + carry, y = f(carry, packed_xs) + ys.append(y if y is not None else dummy_y) + stacked_y = tree.map_structure( + lambda *ys: torch.stack(ys), *maybe_reversed(ys) + ) + return carry, stacked_y + + +def associative_scan(f, elems, reverse=False, axis=0): + # Ref: jax.lax.associative_scan + if not callable(f): + raise TypeError(f"`f` should be a callable. Received: f={f}") + elems_flat = tree.flatten(elems) + elems_flat = [convert_to_tensor(elem) for elem in elems_flat] + if reverse: + elems_flat = [torch.flip(elem, (axis,)) for elem in elems_flat] + + def _combine(a_flat, b_flat): + a_flat = [convert_to_tensor(a) for a in a_flat] + b_flat = [convert_to_tensor(b) for b in b_flat] + + a = tree.pack_sequence_as(elems, a_flat) + b = tree.pack_sequence_as(elems, b_flat) + c = f(a, b) + c_flat = tree.flatten(c) + return c_flat + + num_elems = int(elems_flat[0].shape[axis]) + if not all(int(elem.shape[axis]) == num_elems for elem in elems_flat[1:]): + raise ValueError( + "Array inputs to associative_scan must have the same " + "first dimension. (saw: {})".format( + [elem.shape for elem in elems_flat] + ) + ) + + def _interleave(a, b, axis): + """Given two Tensors of static shape, interleave them along axis.""" + assert ( + a.shape[axis] == b.shape[axis] or a.shape[axis] == b.shape[axis] + 1 + ) + + # we want to get a: [a1, a2], b: [b1, b2] + # to a: [a1, 0, a2, 0], b: [0, b1, 0, b2] + a_shape = list(a.shape) + a_shape[axis] = a.shape[axis] * 2 - 1 + + b_shape = list(b.shape) + b_shape[axis] = b.shape[axis] * 2 - 1 + + a_dil = torch.zeros(a_shape) + slice_along_axis(a_dil, 0, None, 2, axis).copy_(a) + + b_dil = torch.zeros(b_shape) + slice_along_axis(b_dil, 0, None, 2, axis).copy_(b) + + a_pad = [[0, 0] for _ in range(a.dim())] + a_pad[axis][-1] = 1 if a.shape[axis] == b.shape[axis] else 0 + a_pad = a_pad[::-1] + a_pad = tree.flatten(a_pad) + + b_pad = [[0, 0] for _ in range(b.dim())] + b_pad[axis] = [1, 0] if a.shape[axis] == b.shape[axis] else [1, 1] + b_pad = b_pad[::-1] + b_pad = tree.flatten(b_pad) + + op = torch.bitwise_or if a.dtype == torch.bool else torch.add + return op( + torch.nn.functional.pad(a_dil, a_pad), + torch.nn.functional.pad(b_dil, b_pad), + ) + + def _scan(elems): + num_elems = elems[0].shape[axis] + if num_elems < 2: + return elems + + reduced_elems = _combine( + [ + slice_along_axis(elem, 0, -1, step=2, axis=axis) + for elem in elems + ], + [ + slice_along_axis(elem, 1, None, step=2, axis=axis) + for elem in elems + ], + ) + + odd_elems = _scan(reduced_elems) + if num_elems % 2 == 0: + even_elems = _combine( + [slice_along_axis(e, 0, -1, axis=axis) for e in odd_elems], + [ + slice_along_axis(e, 2, None, step=2, axis=axis) + for e in elems + ], + ) + else: + even_elems = _combine( + odd_elems, + [ + slice_along_axis(e, 2, None, step=2, axis=axis) + for e in elems + ], + ) + + even_elems = [ + torch.cat( + [slice_along_axis(elem, 0, 1, axis=axis), result], + dim=axis, + ) + for (elem, result) in zip(elems, even_elems) + ] + return list( + builtins.map( + functools.partial(_interleave, axis=axis), even_elems, odd_elems + ) + ) + + scans = _scan(elems_flat) + if reverse: + scans = [torch.flip(scanned, (axis,)) for scanned in scans] + + return tree.pack_sequence_as(elems, scans) + + +def scatter(indices, values, shape): + indices = convert_to_tensor(indices) + values = convert_to_tensor(values) + zeros = torch.zeros(shape, dtype=values.dtype, device=get_device()) + + index_length = indices.shape[-1] + value_shape = shape[index_length:] + indices = torch.reshape(indices, [-1, index_length]) + values = torch.reshape(values, [-1] + list(value_shape)) + + for i in range(indices.shape[0]): + index = indices[i] + zeros[tuple(index)] += values[i] + return zeros + + +def scatter_update(inputs, indices, updates): + inputs = convert_to_tensor(inputs) + indices = convert_to_tensor(indices, dtype="int64") + updates = convert_to_tensor(updates, dtype=inputs.dtype) + indices = torch.transpose(indices, 0, 1) + + outputs = torch.clone(inputs) + outputs[tuple(indices)] = updates + return outputs + + +def slice(inputs, start_indices, shape): + shape_dtype = to_torch_dtype("int64") + inputs = convert_to_tensor(inputs) + start_indices = convert_to_tensor(start_indices).to(shape_dtype) + shape = convert_to_tensor(shape).to(shape_dtype) + + python_slice = __builtins__["slice"] + slices = [ + python_slice(start_index, start_index + length) + for start_index, length in zip(start_indices, shape) + ] + return inputs[slices] + + +def slice_update(inputs, start_indices, updates): + shape_dtype = to_torch_dtype("int64") + inputs = convert_to_tensor(inputs) + start_indices = convert_to_tensor(start_indices).to(shape_dtype) + updates = convert_to_tensor(updates) + + python_slice = __builtins__["slice"] + slices = [ + python_slice(start_index, start_index + update_length) + for start_index, update_length in zip(start_indices, updates.shape) + ] + outputs = torch.clone(inputs) + outputs[slices] = updates + return outputs + + +def switch(index, branches, *operands): + index = convert_to_tensor(index, "int32") + index = torch.clamp(index, 0, len(branches) - 1) + return branches[index](*operands) + + +def while_loop( + cond, + body, + loop_vars, + maximum_iterations=None, +): + current_iter = 0 + iteration_check = ( + lambda iter: maximum_iterations is None or iter < maximum_iterations + ) + is_tuple = isinstance(loop_vars, (tuple, list)) + loop_vars = tuple(loop_vars) if is_tuple else (loop_vars,) + loop_vars = tree.map_structure(convert_to_tensor, loop_vars) + while cond(*loop_vars) and iteration_check(current_iter): + loop_vars = body(*loop_vars) + if not isinstance(loop_vars, (list, tuple)): + loop_vars = (loop_vars,) + loop_vars = tuple(loop_vars) + current_iter += 1 + return loop_vars if is_tuple else loop_vars[0] + + +def fori_loop(lower, upper, body_fun, init_val): + val = init_val + for i in range(lower, upper): + val = body_fun(i, val) + return val + + +def stop_gradient(variable): + if isinstance(variable, Variable): + variable = variable.value + # We can't use `.requires_grad_(False)` here since it only + # works when the tensor is a leaf node in the graph. + return variable.detach() + + +def unstack(x, num=None, axis=0): + return x.unbind(axis) + + +def random_seed_dtype(): + # uint32 doesn't exist in torch, use int32 instead. + return "int32" + + +def remat(f): + """Implementation of rematerialization. + + Args: + f: The function or operation to rematerialize. + Returns: + A function wrapping f that defines a custom gradient, which + recomputes f on the backwards pass of a gradient call. + """ + + def wrapped(*args, **kwargs): + return torch.utils.checkpoint.checkpoint(f, *args, use_reentrant=False) + + return wrapped + + +class custom_gradient: + """Decorator for custom gradients. + + Args: + forward_fn: Forward pass function. + """ + + def __init__(self, forward_fn): + self.forward_fn = forward_fn + + def __call__(self, *args, **kwargs): + return CustomGradientFunction.apply(self.forward_fn, *args, **kwargs) + + +class CustomGradientFunction(torch.autograd.Function): + """Enables custom forward & backward passes for gradient computation.""" + + @staticmethod + def forward(ctx, forward_fn, *args, **kwargs): + """Forward pass computation specification. + + Args: + ctx: Context object. + forward_fn: Function to compute forward pass. + *args: Arguments for the forward pass. + **kwargs: Keyword arguments for the forward pass. + """ + ctx.forward_fn = forward_fn + ctx.save_for_backward(*args) + try: + output, ctx.grad_fn = forward_fn(*args, **kwargs) + except: + output = forward_fn(*args, **kwargs) + ctx.grad_fn = lambda *args, **kwargs: torch.full((), float("nan")) + return output + + @staticmethod + def backward(ctx, grad_output): + """Backward pass computation specification. + + Args: + ctx: Context object. + grad_output: Gradient with respect to the output. + """ + args = ctx.saved_tensors + grad_fn = ctx.grad_fn + if grad_fn is None: + raise ValueError("grad_fn must be provided for custom gradient") + grads = grad_fn(*args, upstream=grad_output) + if not isinstance(grads, tuple): + grads = (grads,) + return (None,) + grads diff --git a/keras/src/backend/torch/export.py b/keras/src/backend/torch/export.py new file mode 100644 index 000000000000..4ec77610a046 --- /dev/null +++ b/keras/src/backend/torch/export.py @@ -0,0 +1,128 @@ +import copy +import warnings + +import torch + +from keras.src import tree +from keras.src.export.export_utils import convert_spec_to_tensor +from keras.src.utils.module_utils import tensorflow as tf +from keras.src.utils.module_utils import torch_xla + + +class TorchExportArchive: + def _track_layer(self, layer): + raise NotImplementedError( + "`track` is not supported for `Layer`s and `Model`s in the torch " + "backend. Use `track_and_add_endpoint` instead." + ) + + def add_endpoint(self, name, fn, input_signature, **kwargs): + raise NotImplementedError( + "`add_endpoint` is not supported for `Layer`s and `Model`s in the " + "torch backend. Use `track_and_add_endpoint` instead." + ) + + def track_and_add_endpoint(self, name, resource, input_signature, **kwargs): + # Disable false alarms related to lifting parameters. + warnings.filterwarnings("ignore", message=".*created when tracing.*") + warnings.filterwarnings( + "ignore", message=".*Unable to find the path of the module.*" + ) + + if not isinstance(resource, torch.nn.Module): + raise TypeError( + "`resource` must be an instance of `torch.nn.Module`. " + f"Received: resource={resource} (of type {type(resource)})" + ) + + sample_inputs = tree.map_structure( + lambda x: convert_spec_to_tensor(x, replace_none_number=1), + input_signature, + ) + sample_inputs = tuple(sample_inputs) + + # Ref: torch_xla.tf_saved_model_integration + # TODO: Utilize `dynamic_shapes` + exported = torch.export.export( + resource, sample_inputs, dynamic_shapes=None, strict=False + ) + options = torch_xla.stablehlo.StableHLOExportOptions( + override_tracing_arguments=sample_inputs + ) + stablehlo_model = torch_xla.stablehlo.exported_program_to_stablehlo( + exported, options + ) + state_dict_keys = list(stablehlo_model._bundle.state_dict.keys()) + + # Remove unused variables. + for k in state_dict_keys: + if "lifted" not in k: + stablehlo_model._bundle.state_dict.pop(k) + + bundle = copy.deepcopy(stablehlo_model._bundle) + bundle.state_dict = { + k: tf.Variable(v, trainable=False, name=k) + for k, v in bundle.state_dict.items() + } + bundle.additional_constants = [ + tf.Variable(v, trainable=False) for v in bundle.additional_constants + ] + + # Track variables in `bundle` for `write_out`. + self._tf_trackable.variables += ( + list(bundle.state_dict.values()) + bundle.additional_constants + ) + + # Ref: torch_xla.tf_saved_model_integration.save_stablehlo_graph_as_tf + def make_tf_function(func, bundle): + from tensorflow.compiler.tf2xla.python import xla as tfxla + + def _get_shape_with_dynamic(signature): + shape = copy.copy(signature.shape) + for i in signature.dynamic_dims: + shape[i] = None + return shape + + def _extract_call_parameters(args, meta, bundle): + call_args = [] + if meta.input_pytree_spec is not None: + args = tree.flatten(args) + for loc in meta.input_locations: + if loc.type_ == torch_xla.stablehlo.VariableType.PARAMETER: + call_args.append(bundle.state_dict[loc.name]) + elif loc.type_ == torch_xla.stablehlo.VariableType.CONSTANT: + call_args.append( + bundle.additional_constants[loc.position] + ) + else: + call_args.append(args[loc.position]) + return call_args + + def inner(*args): + Touts = [sig.dtype for sig in func.meta.output_signature] + Souts = [ + _get_shape_with_dynamic(sig) + for sig in func.meta.output_signature + ] + call_args = _extract_call_parameters(args, func.meta, bundle) + results = tfxla.call_module( + tuple(call_args), + version=5, + Tout=Touts, # dtype information + Sout=Souts, # Shape information + function_list=[], + module=func.bytecode, + ) + if len(Souts) == 1: + results = results[0] + return results + + return inner + + decorated_fn = tf.function( + make_tf_function( + stablehlo_model._bundle.stablehlo_funcs[0], bundle + ), + input_signature=input_signature, + ) + return decorated_fn diff --git a/keras/src/backend/torch/image.py b/keras/src/backend/torch/image.py new file mode 100644 index 000000000000..b6976dc8569a --- /dev/null +++ b/keras/src/backend/torch/image.py @@ -0,0 +1,1192 @@ +import functools +import itertools +import operator + +import numpy as np +import torch +import torch._dynamo as dynamo +import torch.nn.functional as F + +from keras.src import backend +from keras.src.backend.torch.core import cast +from keras.src.backend.torch.core import convert_to_tensor +from keras.src.backend.torch.core import get_device +from keras.src.backend.torch.core import to_torch_dtype +from keras.src.random.seed_generator import draw_seed + +RESIZE_INTERPOLATIONS = { + "bilinear": "bilinear", + "nearest": "nearest-exact", + "bicubic": "bicubic", +} +UNSUPPORTED_INTERPOLATIONS = ( + "lanczos3", + "lanczos5", +) +AFFINE_TRANSFORM_INTERPOLATIONS = { + "nearest": 0, + "bilinear": 1, +} +AFFINE_TRANSFORM_FILL_MODES = { + "constant", + "nearest", + "wrap", + "mirror", + "reflect", +} +SCALE_AND_TRANSLATE_METHODS = { + "linear", + "bilinear", + "trilinear", + "cubic", + "bicubic", + "tricubic", + "lanczos3", + "lanczos5", +} + + +def rgb_to_grayscale(images, data_format=None): + images = convert_to_tensor(images) + data_format = backend.standardize_data_format(data_format) + if images.ndim not in (3, 4): + raise ValueError( + "Invalid images rank: expected rank 3 (single image) " + "or rank 4 (batch of images). Received input with shape: " + f"images.shape={images.shape}" + ) + channel_axis = -3 if data_format == "channels_first" else -1 + if images.shape[channel_axis] not in (1, 3): + raise ValueError( + "Invalid channel size: expected 3 (RGB) or 1 (Grayscale). " + f"Received input with shape: images.shape={images.shape}" + ) + + # This implementation is based on + # https://github.com/pytorch/vision/blob/main/torchvision/transforms/_functional_tensor.py + if images.shape[channel_axis] == 3: + r, g, b = images.unbind(dim=channel_axis) + images = (0.2989 * r + 0.587 * g + 0.114 * b).to(images.dtype) + images = images.unsqueeze(dim=channel_axis) + else: + images = images.clone() + return images + + +def rgb_to_hsv(images, data_format=None): + # Ref: dm_pix + images = convert_to_tensor(images) + dtype = images.dtype + data_format = backend.standardize_data_format(data_format) + channels_axis = -1 if data_format == "channels_last" else -3 + if len(images.shape) not in (3, 4): + raise ValueError( + "Invalid images rank: expected rank 3 (single image) " + "or rank 4 (batch of images). Received input with shape: " + f"images.shape={images.shape}" + ) + if not backend.is_float_dtype(dtype): + raise ValueError( + "Invalid images dtype: expected float dtype. " + f"Received: images.dtype={backend.standardize_dtype(dtype)}" + ) + eps = torch.finfo(dtype).eps + images = torch.where(torch.abs(images) < eps, 0.0, images) + red, green, blue = torch.split(images, [1, 1, 1], channels_axis) + red = torch.squeeze(red, channels_axis) + green = torch.squeeze(green, channels_axis) + blue = torch.squeeze(blue, channels_axis) + + def rgb_planes_to_hsv_planes(r, g, b): + value = torch.maximum(torch.maximum(r, g), b) + minimum = torch.minimum(torch.minimum(r, g), b) + range_ = value - minimum + + safe_value = torch.where(value > 0, value, 1.0) + safe_range = torch.where(range_ > 0, range_, 1.0) + + saturation = torch.where(value > 0, range_ / safe_value, 0.0) + norm = 1.0 / (6.0 * safe_range) + + hue = torch.where( + value == g, + norm * (b - r) + 2.0 / 6.0, + norm * (r - g) + 4.0 / 6.0, + ) + hue = torch.where(value == r, norm * (g - b), hue) + hue = torch.where(range_ > 0, hue, 0.0) + (hue < 0.0).to(hue.dtype) + return hue, saturation, value + + images = torch.stack( + rgb_planes_to_hsv_planes(red, green, blue), axis=channels_axis + ) + return images + + +def hsv_to_rgb(images, data_format=None): + # Ref: dm_pix + images = convert_to_tensor(images) + dtype = images.dtype + data_format = backend.standardize_data_format(data_format) + channels_axis = -1 if data_format == "channels_last" else -3 + if len(images.shape) not in (3, 4): + raise ValueError( + "Invalid images rank: expected rank 3 (single image) " + "or rank 4 (batch of images). Received input with shape: " + f"images.shape={images.shape}" + ) + if not backend.is_float_dtype(dtype): + raise ValueError( + "Invalid images dtype: expected float dtype. " + f"Received: images.dtype={backend.standardize_dtype(dtype)}" + ) + hue, saturation, value = torch.split(images, [1, 1, 1], channels_axis) + hue = torch.squeeze(hue, channels_axis) + saturation = torch.squeeze(saturation, channels_axis) + value = torch.squeeze(value, channels_axis) + + def hsv_planes_to_rgb_planes(hue, saturation, value): + dh = torch.remainder(hue, 1.0) * 6.0 + dr = torch.clip(torch.abs(dh - 3.0) - 1.0, 0.0, 1.0) + dg = torch.clip(2.0 - torch.abs(dh - 2.0), 0.0, 1.0) + db = torch.clip(2.0 - torch.abs(dh - 4.0), 0.0, 1.0) + one_minus_s = 1.0 - saturation + + red = value * (one_minus_s + saturation * dr) + green = value * (one_minus_s + saturation * dg) + blue = value * (one_minus_s + saturation * db) + return red, green, blue + + images = torch.stack( + hsv_planes_to_rgb_planes(hue, saturation, value), axis=channels_axis + ) + return images + + +def _cast_squeeze_in(image, req_dtypes): + need_squeeze = False + # make image NCHW + if image.ndim < 4: + image = image.unsqueeze(dim=0) + need_squeeze = True + + out_dtype = image.dtype + need_cast = False + if out_dtype not in req_dtypes: + need_cast = True + req_dtype = req_dtypes[0] + image = image.to(req_dtype) + return image, need_cast, need_squeeze, out_dtype + + +def _cast_squeeze_out(image, need_cast, need_squeeze, out_dtype): + if need_squeeze: + image = image.squeeze(dim=0) + + if need_cast: + if out_dtype in ( + torch.uint8, + torch.int8, + torch.int16, + torch.int32, + torch.int64, + ): + # it is better to round before cast + image = torch.round(image) + image = image.to(out_dtype) + return image + + +def resize( + images, + size, + interpolation="bilinear", + antialias=False, + crop_to_aspect_ratio=False, + pad_to_aspect_ratio=False, + fill_mode="constant", + fill_value=0.0, + data_format=None, +): + data_format = backend.standardize_data_format(data_format) + if interpolation in UNSUPPORTED_INTERPOLATIONS: + raise ValueError( + "Resizing with Lanczos interpolation is " + "not supported by the PyTorch backend. " + f"Received: interpolation={interpolation}." + ) + if interpolation not in RESIZE_INTERPOLATIONS: + raise ValueError( + "Invalid value for argument `interpolation`. Expected of one " + f"{RESIZE_INTERPOLATIONS}. Received: interpolation={interpolation}" + ) + if fill_mode != "constant": + raise ValueError( + "Invalid value for argument `fill_mode`. Only `'constant'` " + f"is supported. Received: fill_mode={fill_mode}" + ) + if pad_to_aspect_ratio and crop_to_aspect_ratio: + raise ValueError( + "Only one of `pad_to_aspect_ratio` & `crop_to_aspect_ratio` " + "can be `True`." + ) + if not len(size) == 2: + raise ValueError( + "Argument `size` must be a tuple of two elements " + f"(height, width). Received: size={size}" + ) + size = tuple(size) + images = convert_to_tensor(images) + if images.ndim not in (3, 4): + raise ValueError( + "Invalid images rank: expected rank 3 (single image) " + "or rank 4 (batch of images). Received input with shape: " + f"images.shape={images.shape}" + ) + images, need_cast, need_squeeze, out_dtype = _cast_squeeze_in( + images, [torch.float32, torch.float64] + ) + if data_format == "channels_last": + images = images.permute((0, 3, 1, 2)) + + if crop_to_aspect_ratio: + shape = images.shape + height, width = shape[-2], shape[-1] + target_height, target_width = size + crop_height = int(float(width * target_height) / target_width) + crop_height = max(min(height, crop_height), 1) + crop_width = int(float(height * target_width) / target_height) + crop_width = max(min(width, crop_width), 1) + crop_box_hstart = int(float(height - crop_height) / 2) + crop_box_wstart = int(float(width - crop_width) / 2) + images = images[ + :, + :, + crop_box_hstart : crop_box_hstart + crop_height, + crop_box_wstart : crop_box_wstart + crop_width, + ] + elif pad_to_aspect_ratio: + shape = images.shape + height, width = shape[-2], shape[-1] + target_height, target_width = size + pad_height = int(float(width * target_height) / target_width) + pad_height = max(height, pad_height) + pad_width = int(float(height * target_width) / target_height) + pad_width = max(width, pad_width) + img_box_hstart = int(float(pad_height - height) / 2) + img_box_wstart = int(float(pad_width - width) / 2) + + batch_size = images.shape[0] + channels = images.shape[1] + if img_box_hstart > 0: + padded_img = torch.cat( + [ + torch.ones( + (batch_size, channels, img_box_hstart, width), + dtype=images.dtype, + device=images.device, + ) + * fill_value, + images, + torch.ones( + (batch_size, channels, img_box_hstart, width), + dtype=images.dtype, + device=images.device, + ) + * fill_value, + ], + axis=2, + ) + else: + padded_img = images + if img_box_wstart > 0: + padded_img = torch.cat( + [ + torch.ones( + (batch_size, channels, height, img_box_wstart), + dtype=images.dtype, + device=images.device, + ), + padded_img, + torch.ones( + (batch_size, channels, height, img_box_wstart), + dtype=images.dtype, + device=images.device, + ) + * fill_value, + ], + axis=3, + ) + images = padded_img + + # This implementation is based on + # https://github.com/pytorch/vision/blob/main/torchvision/transforms/_functional_tensor.py + if antialias and interpolation not in ("bilinear", "bicubic"): + # We manually set it to False to avoid an error downstream in + # interpolate(). This behaviour is documented: the parameter is + # irrelevant for modes that are not bilinear or bicubic. We used to + # raise an error here, but now we don't use True as the default. + antialias = False + # Define align_corners to avoid warnings + align_corners = False if interpolation in ("bilinear", "bicubic") else None + resized = F.interpolate( + images, + size=size, + mode=RESIZE_INTERPOLATIONS[interpolation], + align_corners=align_corners, + antialias=antialias, + ) + if interpolation == "bicubic" and out_dtype == torch.uint8: + resized = resized.clamp(min=0, max=255) + if data_format == "channels_last": + resized = resized.permute((0, 2, 3, 1)) + resized = _cast_squeeze_out( + resized, + need_cast=need_cast, + need_squeeze=need_squeeze, + out_dtype=out_dtype, + ) + return resized + + +def affine_transform( + images, + transform, + interpolation="bilinear", + fill_mode="constant", + fill_value=0, + data_format=None, +): + data_format = backend.standardize_data_format(data_format) + if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS.keys(): + raise ValueError( + "Invalid value for argument `interpolation`. Expected of one " + f"{set(AFFINE_TRANSFORM_INTERPOLATIONS.keys())}. Received: " + f"interpolation={interpolation}" + ) + if fill_mode not in AFFINE_TRANSFORM_FILL_MODES: + raise ValueError( + "Invalid value for argument `fill_mode`. Expected of one " + f"{AFFINE_TRANSFORM_FILL_MODES}. Received: fill_mode={fill_mode}" + ) + + images = convert_to_tensor(images) + transform = convert_to_tensor(transform) + + if images.ndim not in (3, 4): + raise ValueError( + "Invalid images rank: expected rank 3 (single image) " + "or rank 4 (batch of images). Received input with shape: " + f"images.shape={images.shape}" + ) + if transform.ndim not in (1, 2): + raise ValueError( + "Invalid transform rank: expected rank 1 (single transform) " + "or rank 2 (batch of transforms). Received input with shape: " + f"transform.shape={transform.shape}" + ) + + # unbatched case + need_squeeze = False + if images.ndim == 3: + images = images.unsqueeze(dim=0) + need_squeeze = True + if transform.ndim == 1: + transform = transform.unsqueeze(dim=0) + + if data_format == "channels_first": + images = images.permute((0, 2, 3, 1)) + + batch_size = images.shape[0] + + # get indices + meshgrid = torch.meshgrid( + *[ + torch.arange(size, dtype=transform.dtype, device=transform.device) + for size in images.shape[1:] + ], + indexing="ij", + ) + indices = torch.concatenate( + [torch.unsqueeze(x, dim=-1) for x in meshgrid], dim=-1 + ) + indices = torch.tile(indices, (batch_size, 1, 1, 1, 1)) + + # swap the values + a0 = transform[:, 0].clone() + a2 = transform[:, 2].clone() + b1 = transform[:, 4].clone() + b2 = transform[:, 5].clone() + transform[:, 0] = b1 + transform[:, 2] = b2 + transform[:, 4] = a0 + transform[:, 5] = a2 + + # deal with transform + transform = torch.nn.functional.pad( + transform, pad=[0, 1, 0, 0], mode="constant", value=1 + ) + transform = torch.reshape(transform, (batch_size, 3, 3)) + offset = transform[:, 0:2, 2].clone() + offset = torch.nn.functional.pad(offset, pad=[0, 1, 0, 0]) + transform[:, 0:2, 2] = 0 + + # transform the indices + coordinates = torch.einsum("Bhwij, Bjk -> Bhwik", indices, transform) + coordinates = torch.moveaxis(coordinates, source=-1, destination=1) + coordinates += torch.reshape(offset, shape=(*offset.shape, 1, 1, 1)) + + # Note: torch.stack is faster than torch.vmap when the batch size is small. + affined = torch.stack( + [ + map_coordinates( + images[i], + coordinates[i], + order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation], + fill_mode=fill_mode, + fill_value=fill_value, + ) + for i in range(len(images)) + ], + ) + + if data_format == "channels_first": + affined = affined.permute((0, 3, 1, 2)) + if need_squeeze: + affined = affined.squeeze(dim=0) + return affined + + +def perspective_transform( + images, + start_points, + end_points, + interpolation="bilinear", + fill_value=0, + data_format=None, +): + data_format = backend.standardize_data_format(data_format) + + images = convert_to_tensor(images) + dtype = backend.standardize_dtype(images.dtype) + start_points = convert_to_tensor(start_points, dtype=dtype) + end_points = convert_to_tensor(end_points, dtype=dtype) + + if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS.keys(): + raise ValueError( + "Invalid value for argument `interpolation`. Expected of one " + f"{set(AFFINE_TRANSFORM_INTERPOLATIONS.keys())}. Received: " + f"interpolation={interpolation}" + ) + + if images.ndim not in (3, 4): + raise ValueError( + "Invalid images rank: expected rank 3 (single image) " + "or rank 4 (batch of images). Received input with shape: " + f"images.shape={images.shape}" + ) + + if start_points.shape[-2:] != (4, 2) or start_points.dim() not in (2, 3): + raise ValueError( + "Invalid start_points shape: expected (4,2) for a single image" + f" or (N,4,2) for a batch. Received shape: {start_points.shape}" + ) + if end_points.shape[-2:] != (4, 2) or end_points.dim() not in (2, 3): + raise ValueError( + "Invalid end_points shape: expected (4,2) for a single image" + f" or (N,4,2) for a batch. Received shape: {end_points.shape}" + ) + if start_points.shape != end_points.shape: + raise ValueError( + "start_points and end_points must have the same shape." + f" Received start_points.shape={start_points.shape}, " + f"end_points.shape={end_points.shape}" + ) + + need_squeeze = False + if images.ndim == 3: + images = images.unsqueeze(dim=0) + need_squeeze = True + + if start_points.ndim == 2: + start_points = start_points.unsqueeze(dim=0) + if end_points.ndim == 2: + end_points = end_points.unsqueeze(dim=0) + + if data_format == "channels_first": + images = images.permute((0, 2, 3, 1)) + + batch_size, height, width, channels = images.shape + + transforms = compute_homography_matrix(start_points, end_points) + + if transforms.dim() == 1: + transforms = transforms.unsqueeze(0) + if transforms.shape[0] == 1 and batch_size > 1: + transforms = transforms.repeat(batch_size, 1) + + grid_x, grid_y = torch.meshgrid( + torch.arange(width, dtype=to_torch_dtype(dtype), device=images.device), + torch.arange(height, dtype=to_torch_dtype(dtype), device=images.device), + indexing="xy", + ) + + output = torch.empty( + [batch_size, height, width, channels], + dtype=to_torch_dtype(dtype), + device=images.device, + ) + + for i in range(batch_size): + a0, a1, a2, a3, a4, a5, a6, a7 = transforms[i] + denom = a6 * grid_x + a7 * grid_y + 1.0 + x_in = (a0 * grid_x + a1 * grid_y + a2) / denom + y_in = (a3 * grid_x + a4 * grid_y + a5) / denom + + coords = torch.stack([y_in.flatten(), x_in.flatten()], dim=0) + mapped_channels = [] + for channel in range(channels): + channel_img = images[i, :, :, channel] + mapped_channel = map_coordinates( + channel_img, + coords, + order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation], + fill_mode="constant", + fill_value=fill_value, + ) + mapped_channels.append(mapped_channel.reshape(height, width)) + output[i] = torch.stack(mapped_channels, dim=-1) + + if data_format == "channels_first": + output = output.permute((0, 3, 1, 2)) + if need_squeeze: + output = output.squeeze(dim=0) + + return output + + +def compute_homography_matrix(start_points, end_points): + start_points = convert_to_tensor(start_points) + end_points = convert_to_tensor(end_points) + dtype = backend.result_type(start_points.dtype, end_points.dtype, float) + # `torch.linalg.solve` requires float32. + compute_dtype = backend.result_type(dtype, "float32") + start_points = cast(start_points, dtype) + end_points = cast(end_points, dtype) + + start_x1, start_y1 = start_points[:, 0, 0], start_points[:, 0, 1] + start_x2, start_y2 = start_points[:, 1, 0], start_points[:, 1, 1] + start_x3, start_y3 = start_points[:, 2, 0], start_points[:, 2, 1] + start_x4, start_y4 = start_points[:, 3, 0], start_points[:, 3, 1] + + end_x1, end_y1 = end_points[:, 0, 0], end_points[:, 0, 1] + end_x2, end_y2 = end_points[:, 1, 0], end_points[:, 1, 1] + end_x3, end_y3 = end_points[:, 2, 0], end_points[:, 2, 1] + end_x4, end_y4 = end_points[:, 3, 0], end_points[:, 3, 1] + + coefficient_matrix = torch.stack( + [ + torch.stack( + [ + end_x1, + end_y1, + torch.ones_like(end_x1), + torch.zeros_like(end_x1), + torch.zeros_like(end_x1), + torch.zeros_like(end_x1), + -start_x1 * end_x1, + -start_x1 * end_y1, + ], + dim=-1, + ), + torch.stack( + [ + torch.zeros_like(end_x1), + torch.zeros_like(end_x1), + torch.zeros_like(end_x1), + end_x1, + end_y1, + torch.ones_like(end_x1), + -start_y1 * end_x1, + -start_y1 * end_y1, + ], + dim=-1, + ), + torch.stack( + [ + end_x2, + end_y2, + torch.ones_like(end_x2), + torch.zeros_like(end_x2), + torch.zeros_like(end_x2), + torch.zeros_like(end_x2), + -start_x2 * end_x2, + -start_x2 * end_y2, + ], + dim=-1, + ), + torch.stack( + [ + torch.zeros_like(end_x2), + torch.zeros_like(end_x2), + torch.zeros_like(end_x2), + end_x2, + end_y2, + torch.ones_like(end_x2), + -start_y2 * end_x2, + -start_y2 * end_y2, + ], + dim=-1, + ), + torch.stack( + [ + end_x3, + end_y3, + torch.ones_like(end_x3), + torch.zeros_like(end_x3), + torch.zeros_like(end_x3), + torch.zeros_like(end_x3), + -start_x3 * end_x3, + -start_x3 * end_y3, + ], + dim=-1, + ), + torch.stack( + [ + torch.zeros_like(end_x3), + torch.zeros_like(end_x3), + torch.zeros_like(end_x3), + end_x3, + end_y3, + torch.ones_like(end_x3), + -start_y3 * end_x3, + -start_y3 * end_y3, + ], + dim=-1, + ), + torch.stack( + [ + end_x4, + end_y4, + torch.ones_like(end_x4), + torch.zeros_like(end_x4), + torch.zeros_like(end_x4), + torch.zeros_like(end_x4), + -start_x4 * end_x4, + -start_x4 * end_y4, + ], + dim=-1, + ), + torch.stack( + [ + torch.zeros_like(end_x4), + torch.zeros_like(end_x4), + torch.zeros_like(end_x4), + end_x4, + end_y4, + torch.ones_like(end_x4), + -start_y4 * end_x4, + -start_y4 * end_y4, + ], + dim=-1, + ), + ], + dim=1, + ) + + target_vector = torch.stack( + [ + start_x1, + start_y1, + start_x2, + start_y2, + start_x3, + start_y3, + start_x4, + start_y4, + ], + dim=-1, + ).unsqueeze(-1) + + coefficient_matrix = cast(coefficient_matrix, compute_dtype) + target_vector = cast(target_vector, compute_dtype) + homography_matrix = torch.linalg.solve(coefficient_matrix, target_vector) + homography_matrix = homography_matrix.reshape(-1, 8) + homography_matrix = cast(homography_matrix, dtype) + return homography_matrix + + +def _mirror_index_fixer(index, size): + s = size - 1 # Half-wavelength of triangular wave + # Scaled, integer-valued version of the triangular wave |x - round(x)| + return torch.abs((index + s) % (2 * s) - s) + + +def _reflect_index_fixer(index, size): + return torch.floor_divide( + _mirror_index_fixer(2 * index + 1, 2 * size + 1) - 1, 2 + ) + + +_INDEX_FIXERS = { + # we need to take care of out-of-bound indices in torch + "constant": lambda index, size: torch.clip(index, 0, size - 1), + "nearest": lambda index, size: torch.clip(index, 0, size - 1), + "wrap": lambda index, size: index % size, + "mirror": _mirror_index_fixer, + "reflect": _reflect_index_fixer, +} + + +def _is_integer(a): + if not torch.is_floating_point(a) and not torch.is_complex(a): + return True + return False + + +def _nearest_indices_and_weights(coordinate): + coordinate = ( + coordinate if _is_integer(coordinate) else torch.round(coordinate) + ) + index = coordinate.to(torch.int32) + return [(index, 1)] + + +def _linear_indices_and_weights(coordinate): + lower = torch.floor(coordinate) + upper_weight = coordinate - lower + lower_weight = 1 - upper_weight + index = lower.to(torch.int32) + return [(index, lower_weight), (index + 1, upper_weight)] + + +def map_coordinates( + inputs, coordinates, order, fill_mode="constant", fill_value=0.0 +): + input_arr = convert_to_tensor(inputs) + coordinate_arrs = [convert_to_tensor(c) for c in coordinates] + + if len(coordinate_arrs) != len(input_arr.shape): + raise ValueError( + "First dim of `coordinates` must be the same as the rank of " + "`inputs`. " + f"Received inputs with shape: {input_arr.shape} and coordinate " + f"leading dim of {len(coordinate_arrs)}" + ) + if len(coordinate_arrs[0].shape) < 1: + dim = len(coordinate_arrs) + shape = (dim,) + coordinate_arrs[0].shape + raise ValueError( + "Invalid coordinates rank: expected at least rank 2." + f" Received input with shape: {shape}" + ) + + # skip tensor creation as possible + if isinstance(fill_value, (int, float)) and _is_integer(input_arr): + fill_value = int(fill_value) + + if len(coordinates) != len(input_arr.shape): + raise ValueError( + "coordinates must be a sequence of length inputs.shape, but " + f"{len(coordinates)} != {len(input_arr.shape)}" + ) + + index_fixer = _INDEX_FIXERS.get(fill_mode) + if index_fixer is None: + raise ValueError( + "Invalid value for argument `fill_mode`. Expected one of " + f"{set(_INDEX_FIXERS.keys())}. Received: fill_mode={fill_mode}" + ) + + if order == 0: + interp_fun = _nearest_indices_and_weights + elif order == 1: + interp_fun = _linear_indices_and_weights + else: + raise NotImplementedError("map_coordinates currently requires order<=1") + + if fill_mode == "constant": + + def is_valid(index, size): + return (0 <= index) & (index < size) + + else: + + def is_valid(index, size): + return True + + valid_1d_interpolations = [] + for coordinate, size in zip(coordinate_arrs, input_arr.shape): + interp_nodes = interp_fun(coordinate) + valid_interp = [] + for index, weight in interp_nodes: + fixed_index = index_fixer(index, size) + valid = is_valid(index, size) + valid_interp.append((fixed_index, valid, weight)) + valid_1d_interpolations.append(valid_interp) + + outputs = [] + for items in itertools.product(*valid_1d_interpolations): + indices, validities, weights = zip(*items) + if all(valid is True for valid in validities): + # fast path + contribution = input_arr[indices] + else: + all_valid = functools.reduce(operator.and_, validities) + contribution = torch.where( + all_valid, input_arr[indices], fill_value + ) + outputs.append(functools.reduce(operator.mul, weights) * contribution) + result = functools.reduce(operator.add, outputs) + if _is_integer(input_arr): + result = result if _is_integer(result) else torch.round(result) + return result.to(input_arr.dtype) + + +def gaussian_blur( + images, kernel_size=(3, 3), sigma=(1.0, 1.0), data_format=None +): + def _create_gaussian_kernel(kernel_size, sigma, dtype): + def _get_gaussian_kernel1d(size, sigma): + x = ( + torch.arange(size, dtype=dtype, device=sigma.device) + - (size - 1) / 2 + ) + kernel1d = torch.exp(-0.5 * (x / sigma) ** 2) + return kernel1d / torch.sum(kernel1d) + + def _get_gaussian_kernel2d(size, sigma): + kernel1d_x = _get_gaussian_kernel1d(size[0], sigma[0]) + kernel1d_y = _get_gaussian_kernel1d(size[1], sigma[1]) + return torch.outer(kernel1d_y, kernel1d_x) + + kernel = _get_gaussian_kernel2d(kernel_size, sigma) + + kernel = kernel.view(1, 1, kernel_size[0], kernel_size[1]) + return kernel + + images = convert_to_tensor(images) + kernel_size = convert_to_tensor(kernel_size) + sigma = convert_to_tensor(sigma) + dtype = images.dtype + + if len(images.shape) not in (3, 4): + raise ValueError( + "Invalid images rank: expected rank 3 (single image) " + "or rank 4 (batch of images). Received input with shape: " + f"images.shape={images.shape}" + ) + + need_squeeze = False + if images.ndim == 3: + images = images.unsqueeze(dim=0) + need_squeeze = True + + if data_format == "channels_last": + images = images.permute(0, 3, 1, 2) + + num_channels = images.shape[1] + kernel = _create_gaussian_kernel(kernel_size, sigma, dtype) + + kernel = kernel.expand(num_channels, 1, kernel_size[0], kernel_size[1]) + + blurred_images = torch.nn.functional.conv2d( + images, + kernel, + stride=1, + padding=int(kernel_size[0] // 2), + groups=num_channels, + ) + + if data_format == "channels_last": + blurred_images = blurred_images.permute(0, 2, 3, 1) + + if need_squeeze: + blurred_images = blurred_images.squeeze(dim=0) + + return blurred_images + + +@dynamo.disable() +def _torch_seed_generator(seed): + first_seed, second_seed = draw_seed(seed) + device = get_device() + if device == "meta": + return None + generator = torch.Generator(device=get_device()) + generator.manual_seed(int(first_seed + second_seed)) + return generator + + +def elastic_transform( + images, + alpha=20.0, + sigma=5.0, + interpolation="bilinear", + fill_mode="reflect", + fill_value=0.0, + seed=None, + data_format=None, +): + data_format = backend.standardize_data_format(data_format) + if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS.keys(): + raise ValueError( + "Invalid value for argument `interpolation`. Expected of one " + f"{set(AFFINE_TRANSFORM_INTERPOLATIONS.keys())}. Received: " + f"interpolation={interpolation}" + ) + if fill_mode not in AFFINE_TRANSFORM_FILL_MODES: + raise ValueError( + "Invalid value for argument `fill_mode`. Expected of one " + f"{AFFINE_TRANSFORM_FILL_MODES}. Received: fill_mode={fill_mode}" + ) + if len(images.shape) not in (3, 4): + raise ValueError( + "Invalid images rank: expected rank 3 (single image) " + "or rank 4 (batch of images). Received input with shape: " + f"images.shape={images.shape}" + ) + + images = convert_to_tensor(images) + alpha = convert_to_tensor(alpha) + sigma = convert_to_tensor(sigma) + input_dtype = images.dtype + kernel_size = (int(6 * sigma) | 1, int(6 * sigma) | 1) + + need_squeeze = False + if images.ndim == 3: + images = images.unsqueeze(dim=0) + need_squeeze = True + + if data_format == "channels_last": + batch_size, height, width, channels = images.shape + channel_axis = -1 + else: + batch_size, channels, height, width = images.shape + channel_axis = 1 + + generator = _torch_seed_generator(seed) if get_device() == "meta" else None + dx = ( + torch.normal( + 0.0, + 1.0, + size=(batch_size, height, width), + generator=generator, + dtype=input_dtype, + device=images.device, + ) + * sigma + ) + + dy = ( + torch.normal( + 0.0, + 1.0, + size=(batch_size, height, width), + generator=generator, + dtype=input_dtype, + device=images.device, + ) + * sigma + ) + + dx = gaussian_blur( + dx.unsqueeze(dim=channel_axis), + kernel_size=kernel_size, + sigma=(sigma, sigma), + data_format=data_format, + ) + dy = gaussian_blur( + dy.unsqueeze(dim=channel_axis), + kernel_size=kernel_size, + sigma=(sigma, sigma), + data_format=data_format, + ) + + dx = dx.squeeze() + dy = dy.squeeze() + + x, y = torch.meshgrid( + torch.arange(width), torch.arange(height), indexing="xy" + ) + x, y = x.unsqueeze(0).to(images.device), y.unsqueeze(0).to(images.device) + + distorted_x = x + alpha * dx + distorted_y = y + alpha * dy + + transformed_images = torch.zeros_like(images) + + if data_format == "channels_last": + for i in range(channels): + transformed_images[..., i] = torch.stack( + [ + map_coordinates( + images[b, ..., i], + [distorted_y[b], distorted_x[b]], + order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation], + fill_mode=fill_mode, + fill_value=fill_value, + ) + for b in range(batch_size) + ] + ) + else: + for i in range(channels): + transformed_images[:, i, :, :] = torch.stack( + [ + map_coordinates( + images[b, i, ...], + [distorted_y[b], distorted_x[b]], + order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation], + fill_mode=fill_mode, + fill_value=fill_value, + ) + for b in range(batch_size) + ] + ) + + if need_squeeze: + transformed_images = transformed_images.squeeze(0) + transformed_images = transformed_images.to(input_dtype) + + return transformed_images + + +def _fill_triangle_kernel(x): + return torch.maximum(torch.tensor(0, dtype=x.dtype), 1 - torch.abs(x)) + + +def _fill_keys_cubic_kernel(x): + out = ((1.5 * x - 2.5) * x) * x + 1.0 + out = torch.where(x >= 1.0, ((-0.5 * x + 2.5) * x - 4.0) * x + 2.0, out) + return torch.where(x >= 2.0, 0.0, out) + + +def _fill_lanczos_kernel(radius, x): + y = radius * torch.sin(np.pi * x) * torch.sin(np.pi * x / radius) + out = torch.where( + x > 1e-3, torch.divide(y, torch.where(x != 0, np.pi**2 * x**2, 1)), 1 + ) + return torch.where(x > radius, 0.0, out) + + +_kernels = { + "linear": _fill_triangle_kernel, + "cubic": _fill_keys_cubic_kernel, + "lanczos3": lambda x: _fill_lanczos_kernel(3.0, x), + "lanczos5": lambda x: _fill_lanczos_kernel(5.0, x), +} + + +def _compute_weight_mat( + input_size, output_size, scale, translation, kernel, antialias +): + dtype = to_torch_dtype(backend.result_type(scale.dtype, translation.dtype)) + inv_scale = 1.0 / scale + kernel_scale = ( + torch.maximum( + inv_scale, + torch.tensor(1.0, dtype=inv_scale.dtype, device=inv_scale.device), + ) + if antialias + else 1.0 + ) + sample_f = ( + (torch.arange(output_size, dtype=dtype, device=inv_scale.device) + 0.5) + * inv_scale + - translation * inv_scale + - 0.5 + ) + x = ( + torch.abs( + sample_f[torch.newaxis, :] + - torch.arange(input_size, dtype=dtype, device=sample_f.device)[ + :, torch.newaxis + ] + ) + / kernel_scale + ) + weights = kernel(x) + total_weight_sum = torch.sum(weights, dim=0, keepdims=True) + weights = torch.where( + torch.abs(total_weight_sum) > 1000.0 * float(np.finfo(np.float32).eps), + torch.divide( + weights, torch.where(total_weight_sum != 0, total_weight_sum, 1) + ), + 0, + ) + input_size_minus_0_5 = input_size - 0.5 + return torch.where( + torch.logical_and(sample_f >= -0.5, sample_f <= input_size_minus_0_5)[ + torch.newaxis, : + ], + weights, + 0, + ) + + +def _scale_and_translate( + x, output_shape, spatial_dims, scale, translation, kernel, antialias +): + x = convert_to_tensor(x) + input_shape = x.shape + if len(spatial_dims) == 0: + return x + if backend.is_int_dtype(x.dtype): + output = cast(x, "float32") + use_rounding = True + else: + output = torch.clone(x) + use_rounding = False + for i, d in enumerate(spatial_dims): + d = d % x.ndim + m, n = input_shape[d], output_shape[d] + w = cast( + _compute_weight_mat( + m, n, scale[i], translation[i], kernel, antialias + ), + output.dtype, + ) + output = torch.tensordot(output, w, dims=((d,), (0,))) + output = torch.moveaxis(output, -1, d) + if use_rounding: + output = torch.clip(torch.round(output), torch.min(x), torch.max(x)) + output = cast(output, x.dtype) + return output + + +def scale_and_translate( + images, + output_shape, + scale, + translation, + spatial_dims, + method, + antialias=True, +): + if method not in SCALE_AND_TRANSLATE_METHODS: + raise ValueError( + "Invalid value for argument `method`. Expected of one " + f"{SCALE_AND_TRANSLATE_METHODS}. Received: method={method}" + ) + if method in ("linear", "bilinear", "trilinear", "triangle"): + method = "linear" + elif method in ("cubic", "bicubic", "tricubic"): + method = "cubic" + + images = convert_to_tensor(images) + scale = convert_to_tensor(scale) + translation = convert_to_tensor(translation) + kernel = _kernels[method] + dtype = backend.result_type(scale.dtype, translation.dtype) + scale = cast(scale, dtype) + translation = cast(translation, dtype) + return _scale_and_translate( + images, + output_shape, + spatial_dims, + scale, + translation, + kernel, + antialias, + ) diff --git a/keras/src/backend/torch/layer.py b/keras/src/backend/torch/layer.py new file mode 100644 index 000000000000..da05f32ddfb4 --- /dev/null +++ b/keras/src/backend/torch/layer.py @@ -0,0 +1,65 @@ +import torch + +from keras.src.backend.common.stateless_scope import in_stateless_scope +from keras.src.ops.operation import Operation + + +class TorchLayer(torch.nn.Module): + @property + def torch_params(self): + if not hasattr(self, "_torch_params"): + self._track_variables() + return self._torch_params + + def _post_build(self): + # Do not track variables when in a stateless scope. + # The variables are not initialized. + if in_stateless_scope(): + return + self._track_variables() + + def _track_variables(self): + # set torch_params attribute will have module automatically track + # parameters. + self._torch_params = torch.nn.ParameterDict( + {variable.path: variable.value for variable in self.variables} + ) + + def named_parameters( + self, + prefix="", + recurse=True, + remove_duplicate=True, + ): + if not hasattr(self, "_torch_params"): + self._track_variables() + return torch.nn.Module.named_parameters( + self, prefix, recurse, remove_duplicate + ) + + def forward(self, *args, **kwargs): + return Operation.__call__(self, *args, **kwargs) + + def _setattr_hook(self, name, value): + from keras.src.layers import Layer + + if ( + isinstance(value, torch.nn.Module) + and not isinstance(value, Layer) + and not name == "_torch_params" + ): + from keras.src.utils.torch_utils import TorchModuleWrapper + + if not isinstance(self, TorchModuleWrapper): + value = TorchModuleWrapper(value) + return name, value + + def _post_track_variable(self, variable): + if hasattr(self, "_torch_params"): + if variable.path not in self.torch_params: + self.torch_params[variable.path] = variable.value + + def _post_untrack_variable(self, variable): + if hasattr(self, "_torch_params"): + if variable.path in self.torch_params: + self.torch_params.pop(variable.path) diff --git a/keras/src/backend/torch/linalg.py b/keras/src/backend/torch/linalg.py new file mode 100644 index 000000000000..5ea66de90f09 --- /dev/null +++ b/keras/src/backend/torch/linalg.py @@ -0,0 +1,86 @@ +import torch + +from keras.src.backend import config +from keras.src.backend import standardize_dtype +from keras.src.backend.common import dtypes +from keras.src.backend.torch.core import cast +from keras.src.backend.torch.core import convert_to_tensor + + +def cholesky(x, upper=False): + return torch.linalg.cholesky(x, upper=upper) + + +def cholesky_inverse(x, upper=False): + return torch.cholesky_inverse(x, upper=upper) + + +def det(x): + return torch.det(x) + + +def eig(x): + return torch.linalg.eig(x) + + +def eigh(x): + return torch.linalg.eigh(x) + + +def inv(x): + return torch.linalg.inv(x) + + +def lu_factor(x): + LU, pivots = torch.linalg.lu_factor(x) + # torch returns pivots with 1-based indexing + return LU, pivots - 1 + + +def norm(x, ord=None, axis=None, keepdims=False): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = cast(x, dtype) + return torch.linalg.norm(x, ord=ord, dim=axis, keepdim=keepdims) + + +def qr(x, mode="reduced"): + if mode not in {"reduced", "complete"}: + raise ValueError( + "`mode` argument value not supported. " + "Expected one of {'reduced', 'complete'}. " + f"Received: mode={mode}" + ) + return torch.linalg.qr(x, mode=mode) + + +def solve(a, b): + return torch.linalg.solve(a, b) + + +def solve_triangular(a, b, lower=False): + if b.ndim == a.ndim - 1: + b = torch.unsqueeze(b, axis=-1) + return torch.linalg.solve_triangular(a, b, upper=not lower).squeeze( + axis=-1 + ) + return torch.linalg.solve_triangular(a, b, upper=not lower) + + +def svd(x, full_matrices=True, compute_uv=True): + if not compute_uv: + return torch.linalg.svdvals(x) + return torch.linalg.svd(x, full_matrices=full_matrices) + + +def lstsq(a, b, rcond=None): + a = convert_to_tensor(a) + b = convert_to_tensor(b) + return torch.linalg.lstsq(a, b, rcond=rcond)[0] + + +def jvp(fun, primals, tangents, has_aux=False): + return torch.func.jvp(fun, primals, tangents, has_aux=has_aux) diff --git a/keras/src/backend/torch/math.py b/keras/src/backend/torch/math.py new file mode 100644 index 000000000000..40e45e1d6981 --- /dev/null +++ b/keras/src/backend/torch/math.py @@ -0,0 +1,419 @@ +import math + +import torch + +from keras.src.backend import config +from keras.src.backend import standardize_dtype +from keras.src.backend.common import dtypes +from keras.src.backend.torch.core import cast +from keras.src.backend.torch.core import convert_to_tensor +from keras.src.backend.torch.core import get_device +from keras.src.backend.torch.numpy import pad + + +def _segment_reduction_fn(data, segment_ids, reduction_method, num_segments): + num_repeats = torch.prod( + torch.tensor(data.shape[1:], device=get_device()) + ).long() + # To use `scatter_add` in torch, we need to replicate `segment_ids` into the + # shape of `data`. + segment_ids = ( + segment_ids.repeat_interleave(num_repeats) + .view(*data.shape) + .type(torch.int64) + ) + num_segments = num_segments or len(torch.unique(segment_ids)) + + # .scatter_add does not support -1 in the indices. + # Add all out-of-bound indices value to an extra dimension after + # num_segments, which is removed before returning the result. + + # Replacing the out-of-bound indices. + segment_ids = torch.where(segment_ids >= 0, segment_ids, num_segments) + segment_ids = torch.where( + segment_ids < num_segments, segment_ids, num_segments + ) + + # Add one more dimension to the result shape with the "+1". + shape = (num_segments + 1,) + tuple(data.shape[1:]) + + if reduction_method == "amax": + result = torch.ones(*shape, device=get_device()) * -float("Inf") + else: + result = torch.zeros(*shape, device=get_device()) + + result = result.scatter_reduce( + 0, segment_ids, data.float(), reduction_method + ) + + # Removing the extra dimension. + result = result[:-1, ...] + + return result.type(data.dtype) + + +def segment_sum(data, segment_ids, num_segments=None, sorted=False): + data = convert_to_tensor(data) + segment_ids = convert_to_tensor(segment_ids) + return _segment_reduction_fn(data, segment_ids, "sum", num_segments) + + +def segment_max(data, segment_ids, num_segments=None, sorted=False): + data = convert_to_tensor(data) + segment_ids = convert_to_tensor(segment_ids) + return _segment_reduction_fn(data, segment_ids, "amax", num_segments) + + +def top_k(x, k, sorted=True): + x = convert_to_tensor(x) + return torch.topk(x, k, sorted=sorted) + + +def in_top_k(targets, predictions, k): + targets = convert_to_tensor(targets).type(torch.int64) + targets = targets[:, None] + predictions = convert_to_tensor(predictions) + topk_values = top_k(predictions, k).values + targets_values = torch.take_along_dim(predictions, targets, dim=-1) + mask = targets_values >= topk_values + return torch.any(mask, axis=-1) + + +def logsumexp(x, axis=None, keepdims=False): + x = convert_to_tensor(x) + axis = tuple(range(x.dim())) if axis is None else axis + return torch.logsumexp(x, dim=axis, keepdim=keepdims) + + +def qr(x, mode="reduced"): + x = convert_to_tensor(x) + if mode not in {"reduced", "complete"}: + raise ValueError( + "`mode` argument value not supported. " + "Expected one of {'reduced', 'complete'}. " + f"Received: mode={mode}" + ) + x = convert_to_tensor(x) + return torch.linalg.qr(x, mode=mode) + + +def extract_sequences(x, sequence_length, sequence_stride): + x = convert_to_tensor(x) + return torch.unfold_copy( + x, dimension=-1, size=sequence_length, step=sequence_stride + ) + + +def _overlap_sequences(x, sequence_stride): + # Ref: https://github.com/google/jax/blob/main/jax/_src/scipy/signal.py + x = convert_to_tensor(x) + *batch_shape, num_sequences, sequence_length = x.shape + if sequence_stride > sequence_length: + raise ValueError( + "`sequence_stride` must equal or less than x.shape[-1]. " + f"Received: sequence_stride={sequence_stride}, " + f"x.shape[-1]={sequence_length}" + ) + if sequence_stride < (sequence_length / num_sequences): + raise ValueError( + "`sequence_stride` must equal or greater than " + "x.shape[-1] / x.shape[-2]. " + f"Received: sequence_stride={sequence_stride}, " + f"x.shape[-1]={sequence_length}, x.shape[-2]={num_sequences}" + ) + flat_batchsize = math.prod(batch_shape) + x = torch.reshape(x, (flat_batchsize, num_sequences, sequence_length)) + output_size = sequence_stride * (num_sequences - 1) + sequence_length + nstep_per_segment = 1 + (sequence_length - 1) // sequence_stride + # Here, we use shorter notation for axes. + # B: batch_size, N: num_sequences, S: nstep_per_segment, + # T: sequence_length divided by S + padded_segment_len = nstep_per_segment * sequence_stride + x = torch.nn.functional.pad( + x, (0, padded_segment_len - sequence_length, 0, 0, 0, 0) + ) + x = torch.reshape( + x, (flat_batchsize, num_sequences, nstep_per_segment, sequence_stride) + ) + # For obtaining shifted signals, this routine reinterprets flattened array + # with a shrinked axis. With appropriate truncation/ padding, this + # operation pushes the last padded elements of the previous row to the head + # of the current row. + # See implementation of `overlap_and_add` in Tensorflow for details. + x = torch.permute(x, (0, 2, 1, 3)) # x: (B, S, N, T) + x = torch.nn.functional.pad(x, (0, 0, 0, num_sequences, 0, 0, 0, 0)) + # x: (B, S, N*2, T) + shrinked = x.shape[2] - 1 + x = torch.reshape(x, (flat_batchsize, -1)) + x = x[:, : (nstep_per_segment * shrinked * sequence_stride)] + x = torch.reshape( + x, (flat_batchsize, nstep_per_segment, shrinked * sequence_stride) + ) + # Finally, sum shifted segments, and truncate results to the output_size. + x = torch.sum(x, dim=1)[:, :output_size] + return torch.reshape(x, tuple(batch_shape) + (-1,)) + + +def _get_complex_tensor_from_tuple(x): + if not isinstance(x, (tuple, list)) or len(x) != 2: + raise ValueError( + "Input `x` should be a tuple of two tensors - real and imaginary." + f"Received: x={x}" + ) + # `convert_to_tensor` does not support passing complex tensors. We separate + # the input out into real and imaginary and convert them separately. + real, imag = x + real = convert_to_tensor(real) + imag = convert_to_tensor(imag) + # Check shape. + if real.shape != imag.shape: + raise ValueError( + "Input `x` should be a tuple of two tensors - real and imaginary." + "Both the real and imaginary parts should have the same shape. " + f"Received: x[0].shape = {real.shape}, x[1].shape = {imag.shape}" + ) + # Ensure dtype is float. + if not torch.is_floating_point(real) or not torch.is_floating_point(imag): + raise ValueError( + "At least one tensor in input `x` is not of type float." + f"Received: x={x}." + ) + + complex_input = torch.complex(real, imag) + return complex_input + + +def fft(x): + complex_input = _get_complex_tensor_from_tuple(x) + complex_output = torch.fft.fft(complex_input) + return complex_output.real, complex_output.imag + + +def fft2(x): + complex_input = _get_complex_tensor_from_tuple(x) + complex_output = torch.fft.fft2(complex_input) + return complex_output.real, complex_output.imag + + +def ifft2(x): + complex_input = _get_complex_tensor_from_tuple(x) + complex_output = torch.fft.ifft2(complex_input) + return complex_output.real, complex_output.imag + + +def rfft(x, fft_length=None): + x = convert_to_tensor(x) + complex_output = torch.fft.rfft(x, n=fft_length, dim=-1, norm="backward") + return complex_output.real, complex_output.imag + + +def irfft(x, fft_length=None): + complex_input = _get_complex_tensor_from_tuple(x) + return torch.fft.irfft(complex_input, n=fft_length, dim=-1, norm="backward") + + +def stft( + x, sequence_length, sequence_stride, fft_length, window="hann", center=True +): + if standardize_dtype(x.dtype) not in {"float32", "float64"}: + raise TypeError( + "Invalid input type. Expected `float32` or `float64`. " + f"Received: input type={x.dtype}" + ) + if fft_length < sequence_length: + raise ValueError( + "`fft_length` must equal or larger than `sequence_length`. " + f"Received: sequence_length={sequence_length}, " + f"fft_length={fft_length}" + ) + if isinstance(window, str): + if window not in {"hann", "hamming"}: + raise ValueError( + "If a string is passed to `window`, it must be one of " + f'`"hann"`, `"hamming"`. Received: window={window}' + ) + x = convert_to_tensor(x) + + if window is not None: + if isinstance(window, str): + if window == "hann": + win = torch.hann_window( + sequence_length, + periodic=True, + dtype=x.dtype, + device=get_device(), + ) + else: + win = torch.hamming_window( + sequence_length, + periodic=True, + dtype=x.dtype, + device=get_device(), + ) + else: + win = convert_to_tensor(window, dtype=x.dtype) + if len(win.shape) != 1 or win.shape[-1] != sequence_length: + raise ValueError( + "The shape of `window` must be equal to [sequence_length]." + f"Received: window shape={win.shape}" + ) + else: + win = torch.ones((sequence_length,), dtype=x.dtype, device=get_device()) + + need_unpack = False + *batch_shape, samples = x.shape + if len(x.shape) > 2: + need_unpack = True + flat_batchsize = math.prod(batch_shape) + x = torch.reshape(x, (flat_batchsize, samples)) + + x = torch.stft( + x, + n_fft=fft_length, + hop_length=sequence_stride, + win_length=sequence_length, + window=win, + center=center, + return_complex=True, + ) + if need_unpack: + fft_unique_bins, num_sequences = x.shape[-2:] + x = torch.reshape(x, (*batch_shape, fft_unique_bins, num_sequences)) + + x = torch.swapaxes(x, -2, -1) + return x.real, x.imag + + +def istft( + x, + sequence_length, + sequence_stride, + fft_length, + length=None, + window="hann", + center=True, +): + complex_input = _get_complex_tensor_from_tuple(x) + dtype = complex_input.real.dtype + win = None + if window is not None: + if isinstance(window, str): + if window == "hann": + win = torch.hann_window( + sequence_length, + periodic=True, + dtype=dtype, + device=get_device(), + ) + else: + win = torch.hamming_window( + sequence_length, + periodic=True, + dtype=dtype, + device=get_device(), + ) + else: + win = convert_to_tensor(window, dtype=dtype) + if len(win.shape) != 1 or win.shape[-1] != sequence_length: + raise ValueError( + "The shape of `window` must be equal to [sequence_length]." + f"Received: window shape={win.shape}" + ) + + if sequence_length == fft_length and center is True and win is not None: + # can be fallen back to torch.istft + need_unpack = False + *batch_shape, num_sequences, fft_unique_bins = complex_input.shape + if len(complex_input.shape) > 3: + need_unpack = True + flat_batchsize = math.prod(batch_shape) + complex_input = torch.reshape( + complex_input, (flat_batchsize, num_sequences, fft_unique_bins) + ) + complex_input = torch.swapaxes(complex_input, -2, -1) + x = torch.istft( + complex_input, + n_fft=fft_length, + hop_length=sequence_stride, + win_length=sequence_length, + window=win, + center=center, + length=length, + return_complex=False, + ) + if need_unpack: + samples = x.shape[-1] + x = torch.reshape(x, (*batch_shape, samples)) + return x + + # custom implementation with irfft and _overlap_sequences + # references: + # torch: aten/src/ATen/native/SpectralOps.cpp + # tf: tf.signal.inverse_stft_window_fn + x = irfft(x, fft_length) + + expected_output_len = fft_length + sequence_stride * (x.shape[-2] - 1) + + if win is not None: + l_pad = (fft_length - sequence_length) // 2 + r_pad = fft_length - sequence_length - l_pad + win = pad(win, [[l_pad, r_pad]], "constant") + + # square and sum + _sequence_length = sequence_length + l_pad + r_pad + denom = torch.square(win) + overlaps = -(-_sequence_length // sequence_stride) + denom = pad(denom, [(0, overlaps * sequence_stride - _sequence_length)]) + denom = torch.reshape(denom, [overlaps, sequence_stride]) + denom = torch.sum(denom, 0, keepdims=True) + denom = torch.tile(denom, [overlaps, 1]) + denom = torch.reshape(denom, [overlaps * sequence_stride]) + win = torch.divide(win, denom[:_sequence_length]) + x = torch.multiply(x, win) + + x = _overlap_sequences(x, sequence_stride) + + start = 0 if center is False else fft_length // 2 + if length is not None: + end = start + length + elif center is True: + end = -(fft_length // 2) + else: + end = expected_output_len + return x[..., start:end] + + +def rsqrt(x): + x = convert_to_tensor(x) + return torch.rsqrt(x) + + +def erf(x): + x = convert_to_tensor(x) + return torch.erf(x) + + +def erfinv(x): + x = convert_to_tensor(x) + return torch.erfinv(x) + + +def solve(a, b): + a = convert_to_tensor(a) + b = convert_to_tensor(b) + return torch.linalg.solve(a, b) + + +def norm(x, ord=None, axis=None, keepdims=False): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + x = cast(x, dtype) + return torch.linalg.norm(x, ord=ord, dim=axis, keepdim=keepdims) + + +def logdet(x): + x = convert_to_tensor(x) + return torch.logdet(x) diff --git a/keras/src/backend/torch/nn.py b/keras/src/backend/torch/nn.py new file mode 100644 index 000000000000..85b2a32d5560 --- /dev/null +++ b/keras/src/backend/torch/nn.py @@ -0,0 +1,1117 @@ +import torch +import torch.nn.functional as tnn + +from keras.src import backend +from keras.src.backend.common.backend_utils import ( + compute_conv_transpose_padding_args_for_torch, +) +from keras.src.backend.torch.core import cast +from keras.src.backend.torch.core import convert_to_tensor +from keras.src.backend.torch.core import get_device +from keras.src.backend.torch.numpy import expand_dims +from keras.src.backend.torch.numpy import where +from keras.src.utils.argument_validation import standardize_tuple + + +def relu(x): + x = convert_to_tensor(x) + return tnn.relu(x) + + +def relu6(x): + x = convert_to_tensor(x) + return tnn.relu6(x) + + +def sigmoid(x): + x = convert_to_tensor(x) + return tnn.sigmoid(x) + + +def sparse_sigmoid(x): + x = convert_to_tensor(x) + return torch.where( + x <= -1, + torch.tensor(0.0, device=x.device, dtype=x.dtype), + torch.where( + x >= 1, + torch.tensor(1.0, device=x.device, dtype=x.dtype), + 0.5 * (x + 1), + ), + ) + + +def tanh(x): + x = convert_to_tensor(x) + return tnn.tanh(x) + + +def tanh_shrink(x): + x = convert_to_tensor(x) + return tnn.tanhshrink(x) + + +def softplus(x): + x = convert_to_tensor(x) + return tnn.softplus(x) + + +def softsign(x): + x = convert_to_tensor(x) + return tnn.softsign(x) + + +def soft_shrink(x, threshold=0.5): + x = convert_to_tensor(x) + return tnn.softshrink(x, lambd=threshold) + + +def sparse_plus(x): + x = convert_to_tensor(x) + return torch.where( + x <= -1, + torch.zeros_like(x), + torch.where(x < 1, (1 / 4) * (x + 1) ** 2, x), + ) + + +def silu(x): + x = convert_to_tensor(x) + return tnn.silu(x) + + +def squareplus(x, b=4): + x = convert_to_tensor(x) + b = convert_to_tensor(b) + y = x + torch.sqrt(x**2 + b) + return y / 2 + + +def log_sigmoid(x): + x = convert_to_tensor(x) + return tnn.logsigmoid(x) + + +def leaky_relu(x, negative_slope=0.2): + x = convert_to_tensor(x) + return tnn.leaky_relu(x, negative_slope=negative_slope) + + +def hard_sigmoid(x): + x = convert_to_tensor(x) + return tnn.hardsigmoid(x) + + +def hard_silu(x): + x = convert_to_tensor(x) + return tnn.hardswish(x) + + +def elu(x, alpha=1.0): + x = convert_to_tensor(x) + return tnn.elu(x, alpha) + + +def selu(x): + x = convert_to_tensor(x) + return tnn.selu(x) + + +def gelu(x, approximate=True): + # TODO: torch.nn.gelu expects string approximate of `"none"` or `"tanh"` + x = convert_to_tensor(x) + if approximate: + return tnn.gelu(x, approximate="tanh") + return tnn.gelu(x) + + +def celu(x, alpha=1.0): + x = convert_to_tensor(x) + return tnn.celu(x, alpha=alpha) + + +def glu(x, axis=-1): + x = convert_to_tensor(x) + return tnn.glu(x, dim=axis) + + +def hard_tanh(x): + x = convert_to_tensor(x) + return tnn.hardtanh(x, min_val=-1.0, max_val=1.0) + + +def hard_shrink(x, threshold=0.5): + x = convert_to_tensor(x) + return tnn.hardshrink(x, lambd=threshold) + + +def threshold(x, threshold, default_value): + x = convert_to_tensor(x) + return tnn.threshold(x, threshold=threshold, value=default_value) + + +def softmax(x, axis=-1): + x = convert_to_tensor(x) + dtype = backend.standardize_dtype(x.dtype) + # TODO: tnn.softmax doesn't support float16 using cpu + if ( + get_device() == "cpu" + and backend.standardize_dtype(x.dtype) == "float16" + ): + x = cast(x, "float32") + if axis is None: + # Unlike numpy, PyTorch will handle axis=None as axis=-1. + # We need this workaround for the reduction on every dim. + output = torch.reshape(x, [-1]) + output = tnn.softmax(output, dim=-1) + output = torch.reshape(output, x.shape) + else: + output = tnn.softmax(x, dim=axis) + return cast(output, dtype) + + +def log_softmax(x, axis=-1): + x = convert_to_tensor(x) + dtype = backend.standardize_dtype(x.dtype) + # TODO: tnn.log_softmax doesn't support float16 using cpu + if ( + get_device() == "cpu" + and backend.standardize_dtype(x.dtype) == "float16" + ): + x = cast(x, "float32") + if axis is None: + # Unlike numpy, PyTorch will handle axis=None as axis=-1. + # We need this workaround for the reduction on every dim. + output = torch.reshape(x, [-1]) + output = tnn.log_softmax(output, dim=-1) + output = torch.reshape(output, x.shape) + else: + output = tnn.log_softmax(x, dim=axis) + return cast(output, dtype) + + +def sparsemax(x, axis=-1): + # Sort logits along the specified axis in descending order + logits = convert_to_tensor(x) + logits_sorted, _ = torch.sort(logits, dim=axis, descending=True) + logits_cumsum = torch.cumsum(logits_sorted, dim=axis) + r = torch.arange( + 1, logits.size(axis) + 1, device=logits.device, dtype=logits.dtype + ) + r_shape = [1] * logits.ndim + r_shape[axis] = -1 # Broadcast to match the target axis + r = r.view(r_shape) + support = logits_sorted - (logits_cumsum - 1) / r > 0 + # Find the threshold + k = torch.sum(support, dim=axis, keepdim=True) + logits_cumsum_safe = torch.where( + support, logits_cumsum, torch.tensor(0.0, device=logits.device) + ) + tau = (torch.sum(logits_cumsum_safe, dim=axis, keepdim=True) - 1) / k + output = torch.clamp(logits - tau, min=0.0) + return output + + +def _compute_padding_length( + input_length, kernel_length, stride, dilation_rate=1 +): + """Compute padding length along one dimension with support + for asymmetric padding.""" + effective_k_size = (kernel_length - 1) * dilation_rate + 1 + if stride == 1: + # total padding is kernel_size - 1 + total_padding = effective_k_size - 1 + else: + # calc. needed padding for case with stride involved + output_size = (input_length + stride - 1) // stride + total_padding = max( + 0, (output_size - 1) * stride + effective_k_size - input_length + ) + + # divide padding evenly, with extra pixel going at the end if needed + left_padding = total_padding // 2 + right_padding = total_padding - left_padding + return (left_padding, right_padding) + + +def _apply_same_padding( + inputs, kernel_size, strides, data_format, operation_type, dilation_rate=1 +): + """Apply same padding to the input tensor. + + This function will evaluate if the padding value is compatible with torch + functions. To avoid calling `pad()` as much as possible, which may cause + performance or memory issues, when compatible, it does not apply the padding + to the tensor, but returns the input tensor and the padding value to pass to + the torch functions. If not compatible, it returns the padded tensor and 0 + as the padding value. + + Returns: + tensor: A padded tensor or the inputs. + padding: The padding value, ready to pass to the torch functions. + """ + spatial_shape = inputs.shape[2:] + num_spatial_dims = len(spatial_shape) + padding = [] + + if operation_type != "pooling": + dilation_rate = standardize_tuple( + dilation_rate, num_spatial_dims, "dilation_rate" + ) + + for i in range(num_spatial_dims): + dil = 1 if operation_type == "pooling" else dilation_rate[i] + pad = _compute_padding_length( + spatial_shape[i], kernel_size[i], strides[i], dil + ) + padding.append(pad) + + # convert padding to torch format + if all(left == right for left, right in padding): + return inputs, [left for left, _ in padding] + + # else, need to pad manually + flattened_padding = [] + for pad in reversed(padding): + flattened_padding.extend(pad) + + mode = "replicate" if operation_type == "pooling" else "constant" + return tnn.pad(inputs, pad=tuple(flattened_padding), mode=mode), 0 + + +def _transpose_spatial_inputs(inputs): + """Transpose inputs from channels_last to channels_first format.""" + # Torch pooling does not support `channels_last` format, so + # we need to transpose to `channels_first` format. + ndim = inputs.ndim - 2 + if ndim == 1: # 1D case + return torch.permute(inputs, (0, 2, 1)) + elif ndim == 2: # 2D case + return torch.permute(inputs, (0, 3, 1, 2)) + elif ndim == 3: # 3D case + return torch.permute(inputs, (0, 4, 1, 2, 3)) + raise ValueError( + "Inputs must have ndim=3, 4 or 5, " + "corresponding to 1D, 2D and 3D inputs. " + f"Received input shape: {inputs.shape}." + ) + + +def _transpose_spatial_outputs(outputs): + # Undo the transpose in `_transpose_spatial_inputs`. + num_spatial_dims = len(outputs.shape) - 2 + if num_spatial_dims == 1: + outputs = torch.permute(outputs, (0, 2, 1)) + elif num_spatial_dims == 2: + outputs = torch.permute(outputs, (0, 2, 3, 1)) + elif num_spatial_dims == 3: + outputs = torch.permute(outputs, (0, 2, 3, 4, 1)) + return outputs + + +def _transpose_conv_kernel(kernel): + # Torch requires conv kernel of format + # `(out_channels, in_channels, spatial_dims)`, we need to transpose. + num_spatial_dims = len(kernel.shape) - 2 + if num_spatial_dims == 1: + kernel = torch.permute(kernel, (2, 1, 0)) + elif num_spatial_dims == 2: + kernel = torch.permute(kernel, (3, 2, 0, 1)) + elif num_spatial_dims == 3: + kernel = torch.permute(kernel, (4, 3, 0, 1, 2)) + return kernel + + +def max_pool( + inputs, + pool_size, + strides=None, + padding="valid", + data_format=None, +): + """Fixed max pooling implementation.""" + inputs = convert_to_tensor(inputs) + num_spatial_dims = inputs.ndim - 2 + pool_size = standardize_tuple(pool_size, num_spatial_dims, "pool_size") + if strides is None: + strides = pool_size + else: + strides = standardize_tuple(strides, num_spatial_dims, "strides") + + data_format = backend.standardize_data_format(data_format) + if data_format == "channels_last": + inputs = _transpose_spatial_inputs(inputs) + + if padding == "same": + # Torch does not natively support `"same"` padding, we need to manually + # apply the right amount of padding to `inputs`. + inputs, padding = _apply_same_padding( + inputs, pool_size, strides, data_format, "pooling" + ) + else: + padding = 0 + + device = get_device() + # Torch max pooling ops do not support symbolic tensors. + # Create a real tensor to execute the ops. + if device == "meta": + inputs = torch.empty( + size=inputs.shape, dtype=inputs.dtype, device="cpu" + ) + + if num_spatial_dims == 1: + outputs = tnn.max_pool1d( + inputs, kernel_size=pool_size, stride=strides, padding=padding + ) + elif num_spatial_dims == 2: + outputs = tnn.max_pool2d( + inputs, kernel_size=pool_size, stride=strides, padding=padding + ) + elif num_spatial_dims == 3: + outputs = tnn.max_pool3d( + inputs, kernel_size=pool_size, stride=strides, padding=padding + ) + else: + raise ValueError( + "Inputs to pooling op must have ndim=3, 4 or 5, " + "corresponding to 1D, 2D and 3D inputs. " + f"Received input shape: {inputs.shape}." + ) + + outputs = outputs.to(device) + if data_format == "channels_last": + outputs = _transpose_spatial_outputs(outputs) + return outputs + + +def average_pool( + inputs, + pool_size, + strides=None, + padding="valid", + data_format=None, +): + """Fixed average pooling with correct padding calculation.""" + inputs = convert_to_tensor(inputs) + num_spatial_dims = inputs.ndim - 2 + pool_size = standardize_tuple(pool_size, num_spatial_dims, "pool_size") + strides = ( + pool_size + if strides is None + else standardize_tuple(strides, num_spatial_dims, "strides") + ) + + data_format = backend.standardize_data_format(data_format) + orig_format = data_format + + if data_format == "channels_last": + inputs = _transpose_spatial_inputs(inputs) + + if padding == "same": + # Torch does not natively support `"same"` padding, we need to manually + # apply the right amount of padding to `inputs`. + inputs, padding = _apply_same_padding( + inputs, + pool_size, + strides, + "channels_first", # we're in channels_first here + "pooling", + ) + else: + padding = 0 + + # apply pooling + if num_spatial_dims == 1: + outputs = tnn.avg_pool1d( + inputs, + kernel_size=pool_size, + stride=strides, + padding=padding, + count_include_pad=False, + ) + elif num_spatial_dims == 2: + outputs = tnn.avg_pool2d( + inputs, + kernel_size=pool_size, + stride=strides, + padding=padding, + count_include_pad=False, + ) + elif num_spatial_dims == 3: + outputs = tnn.avg_pool3d( + inputs, + kernel_size=pool_size, + stride=strides, + padding=padding, + count_include_pad=False, + ) + else: + raise ValueError( + "Inputs to pooling op must have ndim=3, 4 or 5, " + "corresponding to 1D, 2D and 3D inputs. " + f"Received input shape: {inputs.shape}." + ) + + if orig_format == "channels_last": + outputs = _transpose_spatial_outputs(outputs) + + return outputs + + +def conv( + inputs, + kernel, + strides=1, + padding="valid", + data_format=None, + dilation_rate=1, +): + """Convolution with fixed group handling.""" + inputs = convert_to_tensor(inputs) + kernel = convert_to_tensor(kernel) + num_spatial_dims = inputs.ndim - 2 + strides = standardize_tuple(strides, num_spatial_dims, "strides") + + data_format = backend.standardize_data_format(data_format) + if data_format == "channels_last": + inputs = _transpose_spatial_inputs(inputs) + + kernel = _transpose_conv_kernel(kernel) + + # calc. groups snippet + in_channels = inputs.shape[1] + kernel_in_channels = kernel.shape[1] + if in_channels % kernel_in_channels != 0: + raise ValueError( + f"Input channels ({in_channels}) must be divisible by " + f"kernel input channels ({kernel_in_channels})" + ) + groups = in_channels // kernel_in_channels + + # handle padding + if padding == "same": + inputs, padding = _apply_same_padding( + inputs, + kernel.shape[2:], + strides, + data_format, + "conv", + dilation_rate, + ) + else: + padding = 0 + + # apply convolution + if num_spatial_dims == 1: + outputs = tnn.conv1d( + inputs, + kernel, + stride=strides, + padding=padding, + dilation=dilation_rate, + groups=groups, + ) + elif num_spatial_dims == 2: + outputs = tnn.conv2d( + inputs, + kernel, + stride=strides, + padding=padding, + dilation=dilation_rate, + groups=groups, + ) + elif num_spatial_dims == 3: + outputs = tnn.conv3d( + inputs, + kernel, + stride=strides, + padding=padding, + dilation=dilation_rate, + groups=groups, + ) + else: + raise ValueError( + "Inputs to conv operation should have ndim=3, 4, or 5," + "corresponding to 1D, 2D and 3D inputs. Received input " + f"shape: {inputs.shape}." + ) + + if data_format == "channels_last": + outputs = _transpose_spatial_outputs(outputs) + return outputs + + +def depthwise_conv( + inputs, + kernel, + strides=1, + padding="valid", + data_format=None, + dilation_rate=1, +): + kernel = convert_to_tensor(kernel) + kernel = torch.reshape( + kernel, kernel.shape[:-2] + (1, kernel.shape[-2] * kernel.shape[-1]) + ) + return conv(inputs, kernel, strides, padding, data_format, dilation_rate) + + +def separable_conv( + inputs, + depthwise_kernel, + pointwise_kernel, + strides=1, + padding="valid", + data_format=None, + dilation_rate=1, +): + depthwise_conv_output = depthwise_conv( + inputs, + depthwise_kernel, + strides, + padding, + data_format, + dilation_rate, + ) + return conv( + depthwise_conv_output, + pointwise_kernel, + strides=1, + padding="valid", + data_format=data_format, + dilation_rate=dilation_rate, + ) + + +def conv_transpose( + inputs, + kernel, + strides=1, + padding="valid", + output_padding=None, + data_format=None, + dilation_rate=1, +): + inputs = convert_to_tensor(inputs) + kernel = convert_to_tensor(kernel) + num_spatial_dims = inputs.ndim - 2 + strides = standardize_tuple(strides, num_spatial_dims, "strides") + + data_format = backend.standardize_data_format(data_format) + ( + torch_padding, + torch_output_padding, + ) = compute_conv_transpose_padding_args_for_torch( + input_shape=inputs.shape, + kernel_shape=kernel.shape, + strides=strides, + padding=padding, + output_padding=output_padding, + dilation_rate=dilation_rate, + ) + if data_format == "channels_last": + inputs = _transpose_spatial_inputs(inputs) + # Transpose kernel from keras format to torch format. + kernel = _transpose_conv_kernel(kernel) + kernel_spatial_shape = kernel.shape[2:] + if isinstance(dilation_rate, int): + dilation_rate = [dilation_rate] * len(kernel_spatial_shape) + + if num_spatial_dims == 1: + outputs = tnn.conv_transpose1d( + inputs, + kernel, + stride=strides, + padding=torch_padding, + output_padding=torch_output_padding, + dilation=dilation_rate, + ) + elif num_spatial_dims == 2: + outputs = tnn.conv_transpose2d( + inputs, + kernel, + stride=strides, + padding=torch_padding, + output_padding=torch_output_padding, + dilation=dilation_rate, + ) + elif num_spatial_dims == 3: + outputs = tnn.conv_transpose3d( + inputs, + kernel, + stride=strides, + padding=torch_padding, + output_padding=torch_output_padding, + dilation=dilation_rate, + ) + else: + raise ValueError( + "Inputs to conv transpose operation should have ndim=3, 4, or 5," + "corresponding to 1D, 2D and 3D inputs. Received input " + f"shape: {inputs.shape}." + ) + if data_format == "channels_last": + outputs = _transpose_spatial_outputs(outputs) + return outputs + + +def one_hot(x, num_classes, axis=-1, dtype=None, sparse=False): + if sparse: + raise ValueError("Unsupported value `sparse=True` with torch backend") + # Axis is the output axis. By default, PyTorch, outputs to last axis. + # If axis is not last, change output to axis and shift remaining elements. + x = convert_to_tensor(x, dtype=torch.long) + zero = convert_to_tensor(0, dtype=torch.long) + + # Torch one_hot does not natively handle negative values, so we add some + # manual handling for negatives in the input to one_hot by using max(x, 0). + # The output will have some invalid results, so we set them back to 0 using + # `where` afterwards. + output = tnn.one_hot(torch.clamp(x, min=0), num_classes) + output = where(expand_dims(x, axis=-1) >= 0, output, zero) + output = convert_to_tensor(output, dtype=dtype) + dims = output.dim() + if axis != -1 and axis != dims: + new_axes_order = list(range(dims)) + new_axes_order[axis] = -1 # Shifts output to axis position + # Shift remaining axes with offset by 1 since output moved to `axis`. + for ax in range(axis + 1, dims): + new_axes_order[ax] -= 1 + output = output.permute(new_axes_order) + return output + + +def multi_hot(x, num_classes, axis=-1, dtype=None, sparse=False): + if sparse: + raise ValueError("Unsupported value `sparse=True` with torch backend") + x = convert_to_tensor(x) + reduction_axis = 1 if len(x.shape) > 1 else 0 + outputs = torch.amax( + one_hot(cast(x, "int32"), num_classes, axis=axis, dtype=dtype), + dim=reduction_axis, + ) + return outputs + + +def categorical_crossentropy(target, output, from_logits=False, axis=-1): + target = convert_to_tensor(target) + output = convert_to_tensor(output) + + if target.shape != output.shape: + raise ValueError( + "Arguments `target` and `output` must have the same shape. " + "Received: " + f"target.shape={target.shape}, output.shape={output.shape}" + ) + if len(target.shape) < 1: + raise ValueError( + "Arguments `target` and `output` must be at least rank 1. " + "Received: " + f"target.shape={target.shape}, output.shape={output.shape}" + ) + + if from_logits: + log_prob = tnn.log_softmax(output, dim=axis) + else: + output = output / torch.sum(output, dim=axis, keepdim=True) + output = torch.clip(output, backend.epsilon(), 1.0 - backend.epsilon()) + log_prob = torch.log(output) + return -torch.sum(target * log_prob, dim=axis) + + +def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1): + target = convert_to_tensor(target, dtype=torch.long) + output = convert_to_tensor(output) + + if len(target.shape) == len(output.shape) and target.shape[-1] == 1: + target = torch.squeeze(target, dim=-1) + + if len(output.shape) < 1: + raise ValueError( + "Argument `output` must be at least rank 1. " + "Received: " + f"output.shape={output.shape}" + ) + output_shape_without_class_dim = list(output.shape) + del output_shape_without_class_dim[axis] + + if list(target.shape) != output_shape_without_class_dim: + raise ValueError( + "Arguments `target` and `output` must have the same shape " + "up until the last dimension: " + f"target.shape={target.shape}, output.shape={output.shape}" + ) + if from_logits: + log_prob = tnn.log_softmax(output, dim=axis) + else: + output = output / torch.sum(output, dim=axis, keepdim=True) + output = torch.clip(output, backend.epsilon(), 1.0 - backend.epsilon()) + log_prob = torch.log(output) + target = one_hot(target, output.shape[axis], axis=axis) + return -torch.sum(target * log_prob, dim=axis) + + +def binary_crossentropy(target, output, from_logits=False): + target = convert_to_tensor(target) + output = convert_to_tensor(output) + + if target.shape != output.shape: + raise ValueError( + "Arguments `target` and `output` must have the same shape. " + "Received: " + f"target.shape={target.shape}, output.shape={output.shape}" + ) + # By default, PyTorch, does reduction of `sum` over all rows, + # change reduction to `none` to keep dim + if from_logits: + return tnn.binary_cross_entropy_with_logits( + output, target, reduction="none" + ) + else: + output = torch.clip(output, backend.epsilon(), 1.0 - backend.epsilon()) + return tnn.binary_cross_entropy(output, target, reduction="none") + + +def moments(x, axes, keepdims=False, synchronized=False): + if synchronized: + raise NotImplementedError( + "Argument synchronized=True is not supported with PyTorch." + ) + x = convert_to_tensor(x) + # The dynamic range of float16 is too limited for statistics. As a + # workaround, we simply perform the operations on float32 and convert back + # to float16 + need_cast = False + ori_dtype = backend.standardize_dtype(x.dtype) + if ori_dtype == "float16": + need_cast = True + x = cast(x, "float32") + + mean = torch.mean(x, dim=axes, keepdim=True) + + # The variance is computed using $Var = E[|x|^2] - |E[x]|^2$, It is faster + # but less numerically stable. + # Note: stop_gradient does not change the gradient to the mean, because that + # gradient is zero. + variance = torch.mean( + torch.square(x), dim=axes, keepdim=True + ) - torch.square(mean) + + if not keepdims: + mean = torch.squeeze(mean, axes) + variance = torch.squeeze(variance, axes) + if need_cast: + # avoid overflow and underflow when casting from float16 to float32 + mean = torch.clip( + mean, + torch.finfo(torch.float16).min, + torch.finfo(torch.float16).max, + ) + variance = torch.clip( + variance, + torch.finfo(torch.float16).min, + torch.finfo(torch.float16).max, + ) + mean = cast(mean, ori_dtype) + variance = cast(variance, ori_dtype) + return mean, variance + + +def batch_normalization( + x, mean, variance, axis, offset=None, scale=None, epsilon=1e-3 +): + x = convert_to_tensor(x) + mean = convert_to_tensor(mean) + variance = convert_to_tensor(variance) + + shape = [1] * len(x.shape) + shape[axis] = mean.shape[0] + mean = torch.reshape(mean, shape) + variance = torch.reshape(variance, shape) + + if offset is not None: + offset = convert_to_tensor(offset) + offset = torch.reshape(offset, shape) + else: + offset = torch.zeros_like(mean) + if scale is not None: + scale = convert_to_tensor(scale) + scale = torch.reshape(scale, shape) + else: + scale = torch.ones_like(variance) + + return ( + x.subtract(mean) + .mul_(variance.add(epsilon).rsqrt_().mul(scale)) + .add_(offset) + ) + + +def ctc_loss(target, output, target_length, output_length, mask_index=0): + target = convert_to_tensor(target) + output = convert_to_tensor(output) + target_length = convert_to_tensor(target_length) + output_length = convert_to_tensor(output_length) + + # Ensure that the dtype promotion behavior matches that of `tf.nn.ctc_loss` + dtype = backend.result_type(output.dtype, "float32") + output = cast(output, dtype) + + output = torch.transpose(output, 1, 0) + logits = tnn.log_softmax(output, dim=-1) + loss = tnn.ctc_loss( + logits, + target, + output_length, + target_length, + blank=mask_index, + reduction="none", + ) + return loss + + +def _ctc_greedy_decode( + inputs, + sequence_lengths, + merge_repeated=True, + mask_index=None, +): + inputs = convert_to_tensor(inputs) + sequence_lengths = convert_to_tensor(sequence_lengths, dtype="int32") + batch_size, max_length, num_classes = inputs.shape + + if mask_index is None: + mask_index = num_classes - 1 + + indices = torch.argmax(inputs, axis=-1) + indices = cast(indices, "int32") + scores = torch.max(inputs, axis=-1)[0] + + seqlen_mask = torch.arange(max_length, device=indices.device)[None, :] + seqlen_mask = seqlen_mask >= sequence_lengths[:, None] + + indices = torch.where(seqlen_mask, mask_index, indices) + scores = torch.where(seqlen_mask, 0.0, scores) + + if merge_repeated: + repeat = indices[:, 1:] == indices[:, :-1] + repeat = tnn.pad(repeat, (1, 0, 0, 0)) + indices = torch.where(repeat, mask_index, indices) + + # We set to -1 for blank labels + invalid_mask = indices == mask_index + indices = torch.where(invalid_mask, -1, indices) + + # We rearrange the indices by moving `mask_index` to the end of the array + order = torch.unsqueeze( + torch.arange(max_length, device=indices.device), dim=0 + ) # [1, N] + order = torch.tile(order, (batch_size, 1)) # [B, N] + order = torch.where(invalid_mask, max_length, order) + order = torch.argsort(order, dim=-1) + indices = torch.take_along_dim(indices, order, dim=-1) + + scores = -torch.sum(scores, axis=1)[:, None] + indices = torch.unsqueeze(indices, dim=0) + return indices, scores + + +def ctc_decode( + inputs, + sequence_lengths, + strategy="greedy", + beam_width=100, + top_paths=1, + merge_repeated=True, + mask_index=0, +): + inputs = convert_to_tensor(inputs) + dtype = backend.result_type(inputs.dtype, "float32") + inputs = cast(inputs, dtype) + + if strategy == "greedy": + return _ctc_greedy_decode( + inputs, + sequence_lengths, + merge_repeated=merge_repeated, + mask_index=mask_index, + ) + elif strategy == "beam_search": + raise NotImplementedError( + "Torch backend doesn't yet support the beam search strategy for CTC" + "decoding." + ) + else: + raise ValueError( + f"Invalid strategy {strategy}. Supported values are " + "'greedy' and 'beam_search'." + ) + + +def psnr(x1, x2, max_val): + if x1.shape != x2.shape: + raise ValueError( + f"Input shapes {x1.shape} and {x2.shape} must " + "match for PSNR calculation. " + ) + + x1, x2 = ( + convert_to_tensor(x1), + convert_to_tensor(x2), + ) + max_val = convert_to_tensor(max_val, dtype=x1.dtype) + mse = torch.mean((x1 - x2) ** 2) + psnr = 20 * torch.log10(max_val) - 10 * torch.log10(mse) + return psnr + + +def _get_large_negative(dtype): + dtype = backend.standardize_dtype(dtype) + if dtype == "float16": + val = 65500.0 + else: + val = 3.38953e38 + return convert_to_tensor(val * -0.7, dtype=dtype) + + +def _can_use_flash_attention( + query, key, value, mask=None, is_causal=False, raise_error=False +): + """Verify the availability of flash attention.""" + try: + from torch.backends.cuda import SDPAParams + from torch.backends.cuda import can_use_flash_attention + except ImportError: + if raise_error: + raise ImportError( + "Flash attention is not supported in your current PyTorch " + "version. Please update it by following the official guide: " + "https://pytorch.org/get-started/locally/" + ) + return False + + try: + spda_params = SDPAParams( + query, + key, + value, + mask, + 0.0, # dropout_p + is_causal, + False, # enable_gqa + ) + except TypeError: + # The old function signature for the older version of PyTorch + spda_params = SDPAParams( + query, + key, + value, + mask, + 0.0, # dropout_p + is_causal, + ) + if raise_error and can_use_flash_attention(spda_params, True) is False: + raise RuntimeError( + "Flash attention is not supported with the provided inputs. " + "Please check the warnings for more details." + ) + return can_use_flash_attention(spda_params, False) + + +def dot_product_attention( + query, + key, + value, + bias=None, + mask=None, + scale=None, + is_causal=False, + flash_attention=None, + attn_logits_soft_cap=None, +): + if bias is not None: + raise ValueError( + "torch's `dot_product_attention` doesn't support `bias`." + ) + query = convert_to_tensor(query) + key = convert_to_tensor(key) + value = convert_to_tensor(value) + if len(query.shape) != 4 or len(key.shape) != 4 or len(value.shape) != 4: + raise ValueError( + "`dot_product_attention` only supports 4D inputs. " + f"Received: query.shape={query.shape}, key.shape={key.shape}, " + f"value.shape={value.shape}." + ) + compute_dtype = backend.result_type(query.dtype, key.dtype, value.dtype) + query = cast(query, compute_dtype) + key = cast(key, compute_dtype) + value = cast(value, compute_dtype) + + mask = mask if mask is None else convert_to_tensor(mask, dtype="bool") + if mask is not None: + # Explicit set `is_causal` to `False` when `mask` is not `None`. + is_causal = False + mask = torch.where(mask, 0.0, _get_large_negative(query.dtype)) + + axis0, axis1 = 1, 2 + query = torch.transpose(query, axis0, axis1) + key = torch.transpose(key, axis0, axis1) + value = torch.transpose(value, axis0, axis1) + + if flash_attention is None: + flash_attention = _can_use_flash_attention( + query, key, value, mask, is_causal + ) + elif flash_attention is True: + # Use `raise_error=True` to provide more details if the inputs failed to + # use flash attention + _can_use_flash_attention( + query, key, value, mask, is_causal, raise_error=True + ) + if flash_attention: + with torch.nn.attention.sdpa_kernel( + backends=[torch.nn.attention.SDPBackend.FLASH_ATTENTION], + ): + attention_output = torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + attn_mask=mask, + is_causal=is_causal, + scale=scale, + ) + else: + if mask is not None: + mask = mask.contiguous() + attention_output = torch.nn.functional.scaled_dot_product_attention( + query.contiguous(), + key.contiguous(), + value.contiguous(), + attn_mask=mask, + is_causal=is_causal, + scale=scale, + ) + return torch.transpose(attention_output, axis1, axis0) + + +def unfold(input, kernel_size, dilation=1, padding=0, stride=1): + """Native PyTorch implementation of Unfold. + Extract sliding local blocks from a **NCHW** batched image tensor. + + Args: + input: 4-D tensor, shape (N, C, H, W) **required**. + kernel_size: int or (kH, kW) + dilation: int or (dH, dW), default 1 + padding: int or (pH, pW), default 0 + stride: int or (sH, sW), default 1 + + Returns: + 3-D tensor, shape (N, C*kH*kW, L) + """ + return tnn.unfold( + input, + kernel_size=kernel_size, + dilation=dilation, + padding=padding, + stride=stride, + ) diff --git a/keras/src/backend/torch/numpy.py b/keras/src/backend/torch/numpy.py new file mode 100644 index 000000000000..553faea4fd40 --- /dev/null +++ b/keras/src/backend/torch/numpy.py @@ -0,0 +1,1944 @@ +import builtins +import math + +import numpy as np +import torch + +from keras.src.backend import KerasTensor +from keras.src.backend import config +from keras.src.backend.common import dtypes +from keras.src.backend.common.backend_utils import canonicalize_axis +from keras.src.backend.common.backend_utils import to_tuple_or_list +from keras.src.backend.common.backend_utils import vectorize_impl +from keras.src.backend.common.variables import standardize_dtype +from keras.src.backend.torch.core import cast +from keras.src.backend.torch.core import convert_to_tensor +from keras.src.backend.torch.core import get_device +from keras.src.backend.torch.core import is_tensor +from keras.src.backend.torch.core import to_torch_dtype + +TORCH_INT_TYPES = ( + torch.int8, + torch.int16, + torch.int32, + torch.int64, +) + + +def rot90(array, k=1, axes=(0, 1)): + """Rotate an array by 90 degrees in the specified plane using PyTorch. + + Args: + array: Input tensor + k: Number of 90-degree rotations (default=1) + axes: Tuple of two axes that define the + plane of rotation (defaults to `(0, 1)`). + + Returns: + Rotated tensor + """ + array = convert_to_tensor(array) + + if array.ndim < 2: + raise ValueError( + "Input array must have at least 2 dimensions. " + f"Received: array.ndim={array.ndim}" + ) + if len(axes) != 2 or axes[0] == axes[1]: + raise ValueError( + f"Invalid axes: {axes}. Axes must be a tuple " + "of two different dimensions." + ) + + axes = tuple(axis if axis >= 0 else array.ndim + axis for axis in axes) + + if not builtins.all(0 <= axis < array.ndim for axis in axes): + raise ValueError( + f"Invalid axes {axes} for tensor with {array.ndim} dimensions" + ) + + rotated = torch.rot90(array, k=k, dims=axes) + if isinstance(array, np.ndarray): + rotated = rotated.cpu().numpy() + + return rotated + + +def add(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + return torch.add(x1, x2) + + +def einsum(subscripts, *operands, **kwargs): + operands = [convert_to_tensor(operand) for operand in operands] + # When all operands are of int8, we cast the result to int32 to align with + # the behavior of jax. + dtypes_to_resolve = list(set(standardize_dtype(x.dtype) for x in operands)) + if len(dtypes_to_resolve) == 1 and dtypes_to_resolve[0] == "int8": + compute_dtype = "int32" + if get_device() == "cuda": + # TODO: torch.einsum doesn't support int32 when using cuda + compute_dtype = config.floatx() + # prevent overflow + operands = [cast(operand, compute_dtype) for operand in operands] + return cast(torch.einsum(subscripts, *operands), "int32") + return torch.einsum(subscripts, *operands) + + +def subtract(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + # TODO: torch.subtract doesn't support bool + if standardize_dtype(x1.dtype) == "bool": + x1 = cast(x1, x2.dtype) + if standardize_dtype(x2.dtype) == "bool": + x2 = cast(x2, x1.dtype) + return torch.subtract(x1, x2) + + +def matmul(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + + def can_use_int_matmul(x1, x2): + # torch._int_mm only accepts the following conditions: + # 1. cuda + # 2. both inputs must have int8 dtype + # 3. both inputs must be 2d + # 4. x1.shape must be [>16, >= 16 and a multiplier of 8] + # 5. x2.shape must be [>= 16 and a multiplier of 8, multiplier of 8] + if get_device() != "cuda": + return False + x1_dtype = standardize_dtype(x1.dtype) + x2_dtype = standardize_dtype(x2.dtype) + if x1_dtype != "int8" or x2_dtype != "int8": + return False + x1_shape = x1.shape + x2_shape = x2.shape + if x1.ndim != 2 or x2.ndim != 2: + return False + if x1_shape[0] <= 16 or x1_shape[1] < 16 or x1_shape[1] % 8 != 0: + return False + if x2_shape[0] < 16 or x2_shape[0] % 8 != 0 or x2_shape[1] % 8 != 0: + return False + return True + + # Shortcut for torch._int_mm + # TODO: Loosen the restriction of the usage of torch._int_mm + # TODO: We should replace torch._int_mm with the public api if possible + if can_use_int_matmul(x1, x2): + return torch._int_mm(x1, x2) + + x1_dtype = standardize_dtype(x1.dtype) + x2_dtype = standardize_dtype(x2.dtype) + if x1_dtype == "int8" and x2_dtype == "int8": + result_dtype = "int32" + else: + result_dtype = dtypes.result_type(x1.dtype, x2.dtype) + compute_dtype = result_dtype + + # TODO: torch.matmul doesn't support bool + if compute_dtype == "bool": + compute_dtype = config.floatx() + # TODO: torch.matmul doesn't support float16 with cpu + if get_device() == "cpu" and compute_dtype == "float16": + compute_dtype = "float32" + # TODO: torch.matmul doesn't support integer types with cuda + if get_device() == "cuda" and "int" in compute_dtype: + compute_dtype = config.floatx() + + x1 = cast(x1, compute_dtype) + x2 = cast(x2, compute_dtype) + return cast(torch.matmul(x1, x2), result_dtype) + + +def multiply(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + return torch.multiply(x1, x2) + + +def mean(x, axis=None, keepdims=False): + if isinstance(x, (list, tuple)): + x = stack(x) + x = convert_to_tensor(x) + if axis == () or axis == []: + # Torch handles the empty axis case differently from numpy. + return x + axis = to_tuple_or_list(axis) # see [NB] below + + ori_dtype = standardize_dtype(x.dtype) + # torch.mean only supports floating point inputs + compute_dtype = dtypes.result_type(x.dtype, "float32") + if "int" in ori_dtype or ori_dtype == "bool": + result_dtype = compute_dtype + else: + result_dtype = ori_dtype + + # [NB] the python torch op torch.mean() is generated into + # `torch._C._VariableFunctions.pyi`, and the method + # signature is overloaded. + # Dynamo won't actually find the correct signature of + # `torch.mean()` if arguments are passed via kwargs + # So we have to pass the arguments via positional args + # EXCEPT for those that are forced as kwargs via the `*` + # delimiter in the overloaded method signatures. + # Additionally, we have to create a singleton-tuple + # when `axis` is an int to match the existing fn signature + result = torch.mean( + x, + axis, + keepdims, + dtype=to_torch_dtype(compute_dtype), + ) + return cast(result, result_dtype) + + +def max(x, axis=None, keepdims=False, initial=None): + x = convert_to_tensor(x) + if 0 in x.shape: + if initial is None: + raise ValueError("Cannot compute the max of an empty tensor.") + elif keepdims: + return torch.full((1,) * len(x.shape), initial) + else: + return torch.tensor(initial) + + if axis is None: + result = torch.max(x) + else: + result = amax(x, axis=axis, keepdims=keepdims) + if isinstance(getattr(result, "values", None), torch.Tensor): + result = result.values + + if initial is not None: + dtype = to_torch_dtype(result.dtype) + initial = convert_to_tensor(initial, dtype=dtype) + return torch.maximum( + result, torch.full(result.shape, initial, dtype=dtype) + ) + return result + + +def ones(shape, dtype=None): + dtype = to_torch_dtype(dtype or config.floatx()) + if isinstance(shape, int): + shape = (shape,) + return torch.ones(size=shape, dtype=dtype, device=get_device()) + + +def zeros(shape, dtype=None): + dtype = to_torch_dtype(dtype or config.floatx()) + if isinstance(shape, int): + shape = (shape,) + return torch.zeros(size=shape, dtype=dtype, device=get_device()) + + +def zeros_like(x, dtype=None): + x = convert_to_tensor(x) + dtype = to_torch_dtype(dtype or x.dtype) + return torch.zeros_like(x, dtype=dtype) + + +def absolute(x): + x = convert_to_tensor(x) + # bool are always non-negative + if standardize_dtype(x.dtype) == "bool": + return x + return torch.abs(x) + + +def abs(x): + return absolute(x) + + +def all(x, axis=None, keepdims=False): + x = convert_to_tensor(x) + if axis is None: + return cast(torch.all(x), "bool") + axis = to_tuple_or_list(axis) + for a in axis: + # `torch.all` does not handle multiple axes. + x = torch.all(x, dim=a, keepdim=keepdims) + return cast(x, "bool") + + +def angle(x): + x = convert_to_tensor(x) + ori_dtype = standardize_dtype(x.dtype) + + # torch.angle doesn't support float16 with cuda + if get_device() != "cpu" and ori_dtype == "float16": + x = cast(x, "float32") + return cast(torch.angle(x), "float16") + return torch.angle(x) + + +def any(x, axis=None, keepdims=False): + x = convert_to_tensor(x) + if axis is None: + return cast(torch.any(x), "bool") + axis = to_tuple_or_list(axis) + for a in axis: + # `torch.any` does not handle multiple axes. + x = torch.any(x, dim=a, keepdim=keepdims) + return cast(x, "bool") + + +def amax(x, axis=None, keepdims=False): + x = convert_to_tensor(x) + if axis is None: + return torch.amax(x) + if axis == () or axis == []: + # Torch handles the empty axis case differently from numpy. + return x + return torch.amax(x, dim=axis, keepdim=keepdims) + + +def amin(x, axis=None, keepdims=False): + x = convert_to_tensor(x) + if axis is None: + return torch.amin(x) + if axis == () or axis == []: + # Torch handles the empty axis case differently from numpy. + return x + return torch.amin(x, dim=axis, keepdim=keepdims) + + +def append(x1, x2, axis=None): + x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) + if axis is None: + return torch.cat((x1.flatten(), x2.flatten())) + return torch.cat((x1, x2), dim=axis) + + +def arange(start, stop=None, step=None, dtype=None): + if dtype is None: + dtypes_to_resolve = [getattr(start, "dtype", type(start))] + if stop is not None: + dtypes_to_resolve.append(getattr(stop, "dtype", type(stop))) + if step is not None: + dtypes_to_resolve.append(getattr(step, "dtype", type(step))) + dtype = dtypes.result_type(*dtypes_to_resolve) + dtype = to_torch_dtype(dtype) + if stop is None: + start, stop = 0, start + if step is None: + step = 1 + return torch.arange( + start, stop, step=step, dtype=dtype, device=get_device() + ) + + +def arccos(x): + x = convert_to_tensor(x) + return torch.arccos(x) + + +def arccosh(x): + x = convert_to_tensor(x) + return torch.arccosh(x) + + +def arcsin(x): + x = convert_to_tensor(x) + return torch.arcsin(x) + + +def arcsinh(x): + x = convert_to_tensor(x) + return torch.arcsinh(x) + + +def arctan(x): + x = convert_to_tensor(x) + return torch.arctan(x) + + +def arctan2(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + result_dtype = dtypes.result_type(x1.dtype, x2.dtype, float) + compute_dtype = result_dtype + # TODO: torch.arctan2 doesn't support float16 with cpu + if get_device() == "cpu" and compute_dtype == "float16": + compute_dtype = "float32" + x1 = cast(x1, compute_dtype) + x2 = cast(x2, compute_dtype) + return cast(torch.arctan2(x1, x2), result_dtype) + + +def arctanh(x): + x = convert_to_tensor(x) + return torch.arctanh(x) + + +def argmax(x, axis=None, keepdims=False): + x = convert_to_tensor(x) + + # TODO: torch.argmax doesn't support bool + if standardize_dtype(x.dtype) == "bool": + x = cast(x, "uint8") + + return cast(torch.argmax(x, dim=axis, keepdim=keepdims), dtype="int32") + + +def argmin(x, axis=None, keepdims=False): + x = convert_to_tensor(x) + + # TODO: torch.argmin doesn't support bool + if standardize_dtype(x.dtype) == "bool": + x = cast(x, "uint8") + + return cast(torch.argmin(x, dim=axis, keepdim=keepdims), dtype="int32") + + +def argsort(x, axis=-1): + x = convert_to_tensor(x) + + # TODO: torch.argsort doesn't support bool + if standardize_dtype(x.dtype) == "bool": + x = cast(x, "uint8") + + if axis is None: + axis = -1 + x = x.reshape(-1) + return cast(torch.argsort(x, dim=axis, stable=True), dtype="int32") + + +def array(x, dtype=None): + return convert_to_tensor(x, dtype=dtype) + + +def average(x, axis=None, weights=None): + x = convert_to_tensor(x) + dtypes_to_resolve = [x.dtype, float] + if weights is not None: + weights = convert_to_tensor(weights) + dtypes_to_resolve.append(weights.dtype) + dtype = dtypes.result_type(*dtypes_to_resolve) + x = cast(x, dtype) + if weights is not None: + weights = cast(weights, dtype) + if axis == () or axis == []: + # Torch handles the empty axis case differently from numpy. + return x + if weights is not None: + return torch.sum(torch.mul(x, weights), dim=axis) / torch.sum( + weights, dim=-1 + ) + return torch.mean(x, axis) + + +def bartlett(x): + x = convert_to_tensor(x) + return torch.signal.windows.bartlett(x) + + +def hamming(x): + x = convert_to_tensor(x) + return torch.signal.windows.hamming(x) + + +def hanning(x): + x = convert_to_tensor(x) + return torch.signal.windows.hann(x) + + +def heaviside(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + + dtype = dtypes.result_type(x1.dtype, x2.dtype) + if dtype in ["int8", "int16", "int32", "uint8", "uint16", "uint32"]: + dtype = config.floatx() + elif dtype == "int64": + dtype = "float64" + + x1 = cast(x1, dtype) + x2 = cast(x2, dtype) + + return torch.heaviside(x1, x2) + + +def kaiser(x, beta): + x = convert_to_tensor(x) + return torch.signal.windows.kaiser(x, beta=beta) + + +def bincount(x, weights=None, minlength=0, sparse=False): + if sparse: + raise ValueError("Unsupported value `sparse=True` with torch backend") + x = convert_to_tensor(x) + dtypes_to_resolve = [x.dtype] + if weights is not None: + weights = convert_to_tensor(weights) + dtypes_to_resolve.append(weights.dtype) + dtype = dtypes.result_type(*dtypes_to_resolve) + else: + dtype = "int32" + if len(x.shape) == 2: + if weights is None: + + def bincount_fn(arr): + return torch.bincount(arr, minlength=minlength) + + bincounts = list(map(bincount_fn, x)) + else: + + def bincount_fn(arr_w): + return torch.bincount( + arr_w[0], weights=arr_w[1], minlength=minlength + ) + + bincounts = list(map(bincount_fn, zip(x, weights))) + + return cast(torch.stack(bincounts), dtype) + return cast(torch.bincount(x, weights, minlength), dtype) + + +def bitwise_and(x, y): + x = convert_to_tensor(x) + y = convert_to_tensor(y) + return torch.bitwise_and(x, y) + + +def bitwise_invert(x): + x = convert_to_tensor(x) + return torch.bitwise_not(x) + + +def bitwise_not(x): + return bitwise_invert(x) + + +def bitwise_or(x, y): + x = convert_to_tensor(x) + y = convert_to_tensor(y) + return torch.bitwise_or(x, y) + + +def bitwise_xor(x, y): + x = convert_to_tensor(x) + y = convert_to_tensor(y) + return torch.bitwise_xor(x, y) + + +def bitwise_left_shift(x, y): + x = convert_to_tensor(x) + if not isinstance(y, int): + y = convert_to_tensor(y) + return torch.bitwise_left_shift(x, y) + + +def left_shift(x, y): + return bitwise_left_shift(x, y) + + +def bitwise_right_shift(x, y): + x = convert_to_tensor(x) + if not isinstance(y, int): + y = convert_to_tensor(y) + return torch.bitwise_right_shift(x, y) + + +def right_shift(x, y): + return bitwise_right_shift(x, y) + + +def blackman(x): + x = convert_to_tensor(x) + return torch.signal.windows.blackman(x) + + +def broadcast_to(x, shape): + x = convert_to_tensor(x) + return torch.broadcast_to(x, shape) + + +def cbrt(x): + x = convert_to_tensor(x) + + dtype = standardize_dtype(x.dtype) + if dtype == "bool": + x = cast(x, "int32") + elif dtype == "int64": + x = cast(x, "float64") + + return torch.sign(x) * torch.abs(x) ** (1.0 / 3.0) + + +def ceil(x): + x = convert_to_tensor(x) + ori_dtype = standardize_dtype(x.dtype) + + # TODO: torch.ceil doesn't support bool + if ori_dtype == "bool": + x = cast(x, "uint8") + # TODO: torch.ceil doesn't support float16 with cpu + elif get_device() == "cpu" and ori_dtype == "float16": + x = cast(x, config.floatx()) + + if ori_dtype == "int64": + dtype = config.floatx() + else: + dtype = dtypes.result_type(ori_dtype, float) + return cast(torch.ceil(x), dtype=dtype) + + +def clip(x, x_min, x_max): + x = convert_to_tensor(x) + x_min = convert_to_tensor(x_min) + x_max = convert_to_tensor(x_max) + ori_dtype = standardize_dtype(x.dtype) + + # TODO: torch.clip doesn't support float16 with cpu + if get_device() == "cpu" and ori_dtype == "float16": + x = cast(x, "float32") + return cast(torch.clip(x, min=x_min, max=x_max), "float16") + + if ori_dtype == "bool": + x = cast(x, "int32") + return torch.clip(x, min=x_min, max=x_max) + + +def concatenate(xs, axis=0): + xs = [convert_to_tensor(x) for x in xs] + return torch.cat(xs, dim=axis) + + +def conjugate(x): + if not isinstance(x, torch.Tensor): + x = torch.from_numpy(x) # needed for complex type conversion + return torch.conj(x).resolve_conj() + + +def conj(x): + if not isinstance(x, torch.Tensor): + x = torch.from_numpy(x) # needed for complex type conversion + return torch.conj(x).resolve_conj() + + +def copy(x): + x = convert_to_tensor(x) + return torch.clone(x) + + +def cos(x): + x = convert_to_tensor(x) + return torch.cos(x) + + +def cosh(x): + x = convert_to_tensor(x) + return torch.cosh(x) + + +def count_nonzero(x, axis=None): + x = convert_to_tensor(x) + if axis == () or axis == []: + # Torch handles the empty axis case differently from numpy. + return cast(torch.ne(x, 0), "int32") + return cast(torch.count_nonzero(x, dim=axis).T, "int32") + + +def cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=None): + if axisa != -1 or axisb != -1 or axisc != -1: + raise ValueError( + "Torch backend does not support `axisa`, `axisb`, or `axisc`. " + f"Received: axisa={axisa}, axisb={axisb}, axisc={axisc}. Please " + "use `axis` arg in torch backend." + ) + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + compute_dtype = dtypes.result_type(x1.dtype, x2.dtype) + result_dtype = compute_dtype + # TODO: torch.cross doesn't support bfloat16 with gpu + if get_device() == "cuda" and compute_dtype == "bfloat16": + compute_dtype = "float32" + # TODO: torch.cross doesn't support float16 with cpu + elif get_device() == "cpu" and compute_dtype == "float16": + compute_dtype = "float32" + x1 = cast(x1, compute_dtype) + x2 = cast(x2, compute_dtype) + return cast(torch.cross(x1, x2, dim=axis), result_dtype) + + +def cumprod(x, axis=None, dtype=None): + x = convert_to_tensor(x) + if axis is None: + x = x.flatten() + axis = 0 + dtype = dtypes.result_type(dtype or x.dtype) + if dtype == "bool": + dtype = "int32" + # TODO: torch.cumprod doesn't support float16 with cpu + elif get_device() == "cpu" and dtype == "float16": + return cast( + torch.cumprod(x, dim=axis, dtype=to_torch_dtype("float32")), + "float16", + ) + return torch.cumprod(x, dim=axis, dtype=to_torch_dtype(dtype)) + + +def cumsum(x, axis=None, dtype=None): + x = convert_to_tensor(x) + if axis is None: + x = x.flatten() + axis = 0 + dtype = dtypes.result_type(dtype or x.dtype) + if dtype == "bool": + dtype = "int32" + # TODO: torch.cumsum doesn't support float16 with cpu + elif get_device() == "cpu" and dtype == "float16": + return cast( + torch.cumsum(x, dim=axis, dtype=to_torch_dtype("float32")), + "float16", + ) + return torch.cumsum(x, dim=axis, dtype=to_torch_dtype(dtype)) + + +def deg2rad(x): + x = convert_to_tensor(x) + + if standardize_dtype(x.dtype) == "int64": + return cast(torch.deg2rad(x), "float64") + + return torch.deg2rad(x) + + +def diag(x, k=0): + x = convert_to_tensor(x) + return torch.diag(x, diagonal=k) + + +def diagflat(x, k=0): + x = convert_to_tensor(x) + return torch.diagflat(x, offset=k) + + +def diagonal(x, offset=0, axis1=0, axis2=1): + x = convert_to_tensor(x) + return torch.diagonal( + x, + offset=offset, + dim1=axis1, + dim2=axis2, + ) + + +def diff(a, n=1, axis=-1): + a = convert_to_tensor(a) + return torch.diff(a, n=n, dim=axis) + + +def digitize(x, bins): + x = convert_to_tensor(x) + bins = convert_to_tensor(bins) + if standardize_dtype(x.dtype) == "bool": + x = cast(x, "uint8") + return cast(torch.bucketize(x, bins, right=True), "int32") + + +def dot(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + result_dtype = dtypes.result_type(x1.dtype, x2.dtype) + # GPU only supports float types + compute_dtype = dtypes.result_type(result_dtype, float) + + # TODO: torch.matmul doesn't support float16 with cpu + if get_device() == "cpu" and compute_dtype == "float16": + compute_dtype = "float32" + + x1 = cast(x1, compute_dtype) + x2 = cast(x2, compute_dtype) + if x1.ndim == 0 or x2.ndim == 0: + return cast(torch.multiply(x1, x2), result_dtype) + return cast(torch.matmul(x1, x2), result_dtype) + + +def empty(shape, dtype=None): + dtype = to_torch_dtype(dtype or config.floatx()) + return torch.empty(size=shape, dtype=dtype, device=get_device()) + + +def equal(x1, x2): + x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) + return torch.eq(x1, x2) + + +def exp(x): + x = convert_to_tensor(x) + ori_dtype = standardize_dtype(x.dtype) + if "int" in ori_dtype or ori_dtype == "bool": + x = cast(x, config.floatx()) + return torch.exp(x) + + +def exp2(x): + x = convert_to_tensor(x) + ori_dtype = standardize_dtype(x.dtype) + if "int" in ori_dtype or ori_dtype == "bool": + x = cast(x, config.floatx()) + return torch.exp2(x) + + +def expand_dims(x, axis): + x = convert_to_tensor(x) + axis = to_tuple_or_list(axis) + out_ndim = len(x.shape) + len(axis) + axis = sorted([canonicalize_axis(a, out_ndim) for a in axis]) + for a in axis: + x = torch.unsqueeze(x, dim=a) + return x + + +def expm1(x): + x = convert_to_tensor(x) + ori_dtype = standardize_dtype(x.dtype) + if "int" in ori_dtype or ori_dtype == "bool": + x = cast(x, config.floatx()) + return torch.expm1(x) + + +def flip(x, axis=None): + x = convert_to_tensor(x) + if axis is None: + axis = tuple(range(x.ndim)) + axis = to_tuple_or_list(axis) + return torch.flip(x, dims=axis) + + +def floor(x): + x = convert_to_tensor(x) + dtype = ( + config.floatx() + if standardize_dtype(x.dtype) == "int64" + else dtypes.result_type(x.dtype, float) + ) + x = cast(x, dtype) + return torch.floor(x) + + +def full(shape, fill_value, dtype=None): + dtype = to_torch_dtype(dtype) + fill_value = convert_to_tensor(fill_value, dtype=dtype) + if len(fill_value.shape) > 0: + # `torch.full` only supports scala `fill_value`. + expand_size = len(shape) - len(fill_value.shape) + tile_shape = tuple(shape[:expand_size]) + (1,) * len(fill_value.shape) + return torch.tile(fill_value, tile_shape) + return torch.full( + size=shape, fill_value=fill_value, dtype=dtype, device=get_device() + ) + + +def full_like(x, fill_value, dtype=None): + dtype = dtype or x.dtype + return full(shape=x.shape, fill_value=fill_value, dtype=dtype) + + +def gcd(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + return torch.gcd(x1, x2) + + +def greater(x1, x2): + x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) + return torch.greater(x1, x2) + + +def greater_equal(x1, x2): + x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) + return torch.greater_equal(x1, x2) + + +def hstack(xs): + xs = [convert_to_tensor(x) for x in xs] + return torch.hstack(xs) + + +def hypot(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + + dtype = dtypes.result_type(x1.dtype, x2.dtype) + if dtype in ["int8", "int16", "int32", "uint8", "uint16", "uint32"]: + dtype = config.floatx() + elif dtype == "int64": + dtype = "float64" + + x1 = cast(x1, dtype) + x2 = cast(x2, dtype) + + return torch.hypot(x1, x2) + + +def identity(n, dtype=None): + dtype = to_torch_dtype(dtype or config.floatx()) + + # TODO: torch.eye doesn't support bfloat16 with cpu + if get_device() == "cpu" and dtype == torch.bfloat16: + return cast( + torch.eye(n, dtype=to_torch_dtype("float32"), device=get_device()), + dtype, + ) + return torch.eye(n, dtype=dtype, device=get_device()) + + +def imag(x): + if not isinstance(x, torch.Tensor): + x = torch.from_numpy(x) # needed for complex type conversion + return torch.imag(x) + + +def isclose(x1, x2, rtol=1e-5, atol=1e-8, equal_nan=False): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + result_dtype = dtypes.result_type(x1.dtype, x2.dtype) + x1 = cast(x1, result_dtype) + x2 = cast(x2, result_dtype) + return torch.isclose(x1, x2, rtol, atol, equal_nan) + + +def isfinite(x): + x = convert_to_tensor(x) + return torch.isfinite(x) + + +def isin(x1, x2, assume_unique=False, invert=False): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + + dtype = dtypes.result_type(x1.dtype, x2.dtype) + if dtype == "bool": + x1 = cast(x1, "int32") + x2 = cast(x2, "int32") + + if standardize_dtype(x1.dtype) == "bool": + x1 = cast(x1, x2.dtype) + if standardize_dtype(x2.dtype) == "bool": + x2 = cast(x2, x1.dtype) + + return torch.isin(x1, x2, assume_unique=assume_unique, invert=invert) + + +def isinf(x): + x = convert_to_tensor(x) + return torch.isinf(x) + + +def isnan(x): + x = convert_to_tensor(x) + return torch.isnan(x) + + +def isneginf(x): + x = convert_to_tensor(x) + return torch.isneginf(x) + + +def isposinf(x): + x = convert_to_tensor(x) + return torch.isposinf(x) + + +def kron(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + return torch.kron(x1, x2) + + +def lcm(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + return torch.lcm(x1, x2) + + +def less(x1, x2): + x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) + return torch.less(x1, x2) + + +def less_equal(x1, x2): + x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) + return torch.less_equal(x1, x2) + + +def linspace( + start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0 +): + if axis != 0: + raise ValueError( + "torch.linspace does not support an `axis` argument. " + f"Received axis={axis}" + ) + if dtype is None: + dtypes_to_resolve = [ + getattr(start, "dtype", type(start)), + getattr(stop, "dtype", type(stop)), + float, + ] + dtype = dtypes.result_type(*dtypes_to_resolve) + dtype = to_torch_dtype(dtype) + + step = convert_to_tensor(torch.nan) + if endpoint: + if num > 1: + step = (stop - start) / (num - 1) + else: + if num > 0: + step = (stop - start) / num + if num > 1: + stop = stop - ((stop - start) / num) + if hasattr(start, "__len__") and hasattr(stop, "__len__"): + start = convert_to_tensor(start, dtype=dtype) + stop = convert_to_tensor(stop, dtype=dtype) + steps = torch.arange(num, dtype=dtype, device=get_device()) / (num - 1) + + # reshape `steps` to allow for broadcasting + for i in range(start.ndim): + steps = steps.unsqueeze(-1) + + # increments from `start` to `stop` in each dimension + linspace = start[None] + steps * (stop - start)[None] + else: + linspace = torch.linspace( + start=start, + end=stop, + steps=num, + dtype=dtype, + device=get_device(), + ) + if retstep is True: + return (linspace, step) + return linspace + + +def log(x): + x = convert_to_tensor(x) + return torch.log(x) + + +def log10(x): + x = convert_to_tensor(x) + return torch.log10(x) + + +def log1p(x): + x = convert_to_tensor(x) + return torch.log1p(x) + + +def log2(x): + x = convert_to_tensor(x) + return torch.log2(x) + + +def logaddexp(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type(x1.dtype, x2.dtype, float) + + # TODO: torch.logaddexp doesn't support float16 with cpu + if get_device() == "cpu" and dtype == "float16": + x1 = cast(x1, "float32") + x2 = cast(x2, "float32") + return cast(torch.logaddexp(x1, x2), dtype) + else: + x1 = cast(x1, dtype) + x2 = cast(x2, dtype) + return torch.logaddexp(x1, x2) + + +def logaddexp2(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type(x1.dtype, x2.dtype, float) + x1 = cast(x1, dtype) + x2 = cast(x2, dtype) + return torch.logaddexp2(x1, x2) + + +def logical_and(x1, x2): + x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) + return torch.logical_and(x1, x2) + + +def logical_not(x): + x = convert_to_tensor(x) + return torch.logical_not(x) + + +def logical_or(x1, x2): + x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) + return torch.logical_or(x1, x2) + + +def logspace(start, stop, num=50, endpoint=True, base=10, dtype=None, axis=0): + if axis != 0: + raise ValueError( + "torch.logspace does not support an `axis` argument. " + f"Received axis={axis}" + ) + if dtype is None: + dtypes_to_resolve = [ + getattr(start, "dtype", type(start)), + getattr(stop, "dtype", type(stop)), + float, + ] + dtype = dtypes.result_type(*dtypes_to_resolve) + dtype = to_torch_dtype(dtype) + + if endpoint is False: + stop = stop - ((stop - start) / num) + if hasattr(start, "__len__") and hasattr(stop, "__len__"): + start = convert_to_tensor(start, dtype=dtype) + stop = convert_to_tensor(stop, dtype=dtype) + steps = torch.arange(num, dtype=dtype, device=get_device()) / (num - 1) + + # reshape `steps` to allow for broadcasting + for i in range(start.ndim): + steps = steps.unsqueeze(-1) + + # increments from `start` to `stop` in each dimension + linspace = start[None] + steps * (stop - start)[None] + logspace = base**linspace + else: + compute_dtype = dtype + # TODO: torch.logspace doesn't support float16 with cpu + if get_device() == "cpu" and dtype == torch.float16: + compute_dtype = torch.float32 + logspace = cast( + torch.logspace( + start=start, + end=stop, + steps=num, + base=base, + dtype=compute_dtype, + device=get_device(), + ), + dtype, + ) + return logspace + + +def maximum(x1, x2): + if not isinstance(x1, (int, float)): + x1 = convert_to_tensor(x1) + if not isinstance(x2, (int, float)): + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + ) + x1 = convert_to_tensor(x1, dtype) + x2 = convert_to_tensor(x2, dtype) + return torch.maximum(x1, x2) + + +def median(x, axis=None, keepdims=False): + x = convert_to_tensor(x) + compute_dtype = dtypes.result_type(x.dtype, "float32") + result_dtype = dtypes.result_type(x.dtype, float) + x = cast(x, compute_dtype) + + if axis is None and keepdims is False: + return cast(torch.median(x), result_dtype) + elif isinstance(axis, int): + return cast( + torch.median(x, dim=axis, keepdim=keepdims)[0], result_dtype + ) + + # support multiple axes + if axis is None: + y = reshape(x, [-1]) + else: + # transpose + axis = [canonicalize_axis(a, x.ndim) for a in axis] + other_dims = sorted(set(range(x.ndim)).difference(axis)) + perm = other_dims + list(axis) + x_permed = torch.permute(x, dims=perm) + # reshape + x_shape = list(x.shape) + other_shape = [x_shape[i] for i in other_dims] + end_shape = [math.prod([x_shape[i] for i in axis])] + full_shape = other_shape + end_shape + y = reshape(x_permed, full_shape) + + y = torch.median(y, dim=-1)[0] + + if keepdims: + if axis is None: + for _ in range(x.ndim): + y = expand_dims(y, axis=-1) + else: + for i in sorted(axis): + y = expand_dims(y, axis=i) + + return cast(y, result_dtype) + + +def meshgrid(*x, indexing="xy"): + x = [convert_to_tensor(sc_tensor) for sc_tensor in x] + return torch.meshgrid(x, indexing=indexing) + + +def min(x, axis=None, keepdims=False, initial=None): + x = convert_to_tensor(x) + if 0 in x.shape: + if initial is None: + raise ValueError("Cannot compute the min of an empty tensor.") + elif keepdims: + return torch.full((1,) * len(x.shape), initial) + else: + return torch.tensor(initial) + + if axis is None: + result = torch.min(x) + else: + result = amin(x, axis=axis, keepdims=keepdims) + + if isinstance(getattr(result, "values", None), torch.Tensor): + result = result.values + + if initial is not None: + dtype = to_torch_dtype(result.dtype) + initial = convert_to_tensor(initial, dtype=dtype) + return torch.minimum(result, initial) + return result + + +def minimum(x1, x2): + if not isinstance(x1, (int, float)): + x1 = convert_to_tensor(x1) + if not isinstance(x2, (int, float)): + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + ) + x1 = convert_to_tensor(x1, dtype) + x2 = convert_to_tensor(x2, dtype) + return torch.minimum(x1, x2) + + +def mod(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type(x1.dtype, x2.dtype) + if dtype == "bool": + x1 = cast(x1, "int32") + x2 = cast(x2, "int32") + return torch.remainder(x1, x2) + + +def moveaxis(x, source, destination): + x = convert_to_tensor(x) + return torch.moveaxis(x, source=source, destination=destination) + + +def nan_to_num(x, nan=0.0, posinf=None, neginf=None): + x = convert_to_tensor(x) + return torch.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf) + + +def ndim(x): + x = convert_to_tensor(x) + return x.ndim + + +def nonzero(x): + x = convert_to_tensor(x) + return cast(torch.nonzero(x).T, "int32") + + +def not_equal(x1, x2): + x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) + return torch.not_equal(x1, x2) + + +def ones_like(x, dtype=None): + x = convert_to_tensor(x) + dtype = to_torch_dtype(dtype or x.dtype) + return torch.ones_like(x, dtype=dtype) + + +def outer(x1, x2): + x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) + return torch.outer(x1.flatten(), x2.flatten()) + + +def pad(x, pad_width, mode="constant", constant_values=None): + kwargs = {} + if constant_values is not None: + if mode != "constant": + raise ValueError( + "Argument `constant_values` can only be " + "provided when `mode == 'constant'`. " + f"Received: mode={mode}" + ) + kwargs["value"] = constant_values + x = convert_to_tensor(x) + pad_sum = [] + pad_width = list(pad_width)[::-1] # torch uses reverse order + pad_width_sum = 0 + for pad in pad_width: + pad_width_sum += pad[0] + pad[1] + for pad in pad_width: + pad_sum += pad + pad_width_sum -= pad[0] + pad[1] + if pad_width_sum == 0: # early break when no padding in higher order + break + if mode == "symmetric": + mode = "replicate" + if mode == "constant": + return torch.nn.functional.pad(x, pad=pad_sum, mode=mode, **kwargs) + # TODO: reflect and symmetric padding are implemented for padding the + # last 3 dimensions of a 4D or 5D input tensor, the last 2 dimensions of a + # 3D or 4D input tensor, or the last dimension of a 2D or 3D input tensor. + # https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + ori_dtype = x.dtype + ori_ndim = x.ndim + need_squeeze = False + if x.ndim < 3: + need_squeeze = True + new_dims = [1] * (3 - x.ndim) + x = x.view(*new_dims, *x.shape) + need_cast = False + if x.dtype not in (torch.float32, torch.float64): + # TODO: reflect and symmetric padding are only supported with float32/64 + # https://github.com/pytorch/pytorch/issues/40763 + need_cast = True + x = cast(x, torch.float32) + x = torch.nn.functional.pad(x, pad=pad_sum, mode=mode) + if need_cast: + x = cast(x, ori_dtype) + if need_squeeze: + x = torch.squeeze(x, dim=tuple(range(3 - ori_ndim))) + return x + + +def prod(x, axis=None, keepdims=False, dtype=None): + x = convert_to_tensor(x) + if dtype is None: + dtype = dtypes.result_type(x.dtype) + if dtype == "bool": + dtype = "int32" + elif dtype in ("int8", "int16"): + dtype = "int32" + # TODO: torch.prod doesn't support uint32 + elif dtype == "uint8": + dtype = "int32" + compute_dtype = dtype + # TODO: torch.prod doesn't support float16 with cpu + if get_device() == "cpu" and compute_dtype == "float16": + compute_dtype = "float32" + if axis is None: + return cast(torch.prod(x, dtype=to_torch_dtype(compute_dtype)), dtype) + axis = to_tuple_or_list(axis) + for a in axis: + # `torch.prod` does not handle multiple axes. + x = cast( + torch.prod( + x, dim=a, keepdim=keepdims, dtype=to_torch_dtype(compute_dtype) + ), + dtype, + ) + return x + + +def quantile(x, q, axis=None, method="linear", keepdims=False): + x = convert_to_tensor(x) + q = convert_to_tensor(q) + axis = to_tuple_or_list(axis) + + compute_dtype = dtypes.result_type(x.dtype, "float32") + result_dtype = dtypes.result_type(x.dtype, float) + + x = cast(x, compute_dtype) + # q must be same dtype as x + if x.dtype != q.dtype: + q = cast(q, x.dtype) + + # support multiple axes + if axis is None: + y = reshape(x, [-1]) + else: + # transpose + axis = [canonicalize_axis(a, x.ndim) for a in axis] + other_dims = sorted(set(range(x.ndim)).difference(axis)) + perm = other_dims + list(axis) + x_permed = torch.permute(x, dims=perm) + # reshape + x_shape = list(x.shape) + other_shape = [x_shape[i] for i in other_dims] + end_shape = [math.prod([x_shape[i] for i in axis])] + full_shape = other_shape + end_shape + y = reshape(x_permed, full_shape) + + y = torch.quantile(y, q, dim=-1, interpolation=method) + + if keepdims: + if axis is None: + for _ in range(x.ndim): + y = expand_dims(y, axis=-1) + else: + for i in sorted(axis): + i = i + 1 if q.ndim > 0 else i + y = expand_dims(y, axis=i) + + return cast(y, result_dtype) + + +def ravel(x): + x = convert_to_tensor(x) + return torch.ravel(x) + + +def unravel_index(indices, shape): + indices = convert_to_tensor(indices) + dtype = dtypes.result_type(indices.dtype) + return tuple( + cast(idx, dtype) for idx in torch.unravel_index(indices, shape) + ) + + +def real(x): + if not isinstance(x, torch.Tensor): + x = torch.from_numpy(x) # needed for complex type conversion + return torch.real(x) + + +def reciprocal(x): + x = convert_to_tensor(x) + return torch.reciprocal(x) + + +def repeat(x, repeats, axis=None): + x = convert_to_tensor(x) + + if get_device() == "meta": + x = KerasTensor(x.shape, standardize_dtype(x.dtype)) + outputs = repeat(x, repeats, axis=axis) + + return torch.empty( + size=outputs.shape, + dtype=to_torch_dtype(outputs.dtype), + device=get_device(), + ) + + repeats = convert_to_tensor(repeats, dtype=int) + + return torch.repeat_interleave(x, repeats, dim=axis) + + +def reshape(x, newshape): + if not isinstance(newshape, (list, tuple)): + newshape = (newshape,) + x = convert_to_tensor(x) + return torch.reshape(x, newshape) + + +def roll(x, shift, axis=None): + x = convert_to_tensor(x) + return torch.roll(x, shift, dims=axis) + + +def searchsorted(sorted_sequence, values, side="left"): + if ndim(sorted_sequence) != 1: + raise ValueError( + "`searchsorted` only supports 1-D sorted sequences. " + "You can use `keras.ops.vectorized_map` " + "to extend it to N-D sequences. Received: " + f"sorted_sequence.shape={sorted_sequence.shape}" + ) + out_int32 = sorted_sequence.shape[0] <= np.iinfo(np.int32).max + return torch.searchsorted( + sorted_sequence, values, side=side, out_int32=out_int32 + ) + + +def sign(x): + x = convert_to_tensor(x) + return torch.sign(x) + + +def signbit(x): + x = convert_to_tensor(x) + return torch.signbit(x) + + +def sin(x): + x = convert_to_tensor(x) + return torch.sin(x) + + +def sinh(x): + x = convert_to_tensor(x) + return torch.sinh(x) + + +def size(x): + x_shape = convert_to_tensor(tuple(x.shape)) + return torch.prod(x_shape) + + +def sort(x, axis=-1): + x = convert_to_tensor(x) + # TODO: torch.sort doesn't support bool with cuda + if get_device() == "cuda" and standardize_dtype(x.dtype) == "bool": + x = cast(x, "uint8") + return cast(torch.sort(x, dim=axis).values, "bool") + return torch.sort(x, dim=axis).values + + +def split(x, indices_or_sections, axis=0): + x = convert_to_tensor(x) + dim = x.shape[axis] + if not isinstance(indices_or_sections, int): + indices_or_sections = convert_to_tensor(indices_or_sections) + start_size = indices_or_sections[0:1] + end_size = dim - indices_or_sections[-1:] + chunk_sizes = torch.concat( + [start_size, torch.diff(indices_or_sections), end_size], dim=0 + ) + # torch.split doesn't support tensor input for `split_size_or_sections` + chunk_sizes = chunk_sizes.tolist() + else: + if dim % indices_or_sections != 0: + raise ValueError( + f"Received indices_or_sections={indices_or_sections} " + f"(interpreted as a number of sections) and axis={axis}, " + f"but input dimension x.shape[{axis}]={x.shape[axis]} " + f"is not divisible by {indices_or_sections}. " + f"Full input shape: x.shape={x.shape}" + ) + chunk_sizes = dim // indices_or_sections + out = torch.split( + tensor=x, + split_size_or_sections=chunk_sizes, + dim=axis, + ) + if dim == 0 and isinstance(indices_or_sections, int): + out = [out[0].clone() for _ in range(indices_or_sections)] + return list(out) + + +def stack(x, axis=0): + x = [convert_to_tensor(elem) for elem in x] + return torch.stack(x, dim=axis) + + +def std(x, axis=None, keepdims=False): + x = convert_to_tensor(x) + ori_dtype = standardize_dtype(x.dtype) + if "int" in ori_dtype or ori_dtype == "bool": + x = cast(x, "float32") + # Remove Bessel correction to align with numpy + return torch.std(x, dim=axis, keepdim=keepdims, unbiased=False) + + +def swapaxes(x, axis1, axis2): + x = convert_to_tensor(x) + return torch.swapaxes(x, axis0=axis1, axis1=axis2) + + +def take(x, indices, axis=None): + x = convert_to_tensor(x) + indices = convert_to_tensor(indices).long() + # Correct the indices using "fill" mode which is the same as in jax + x_dim = x.shape[axis] if axis is not None else x.shape[0] + indices = torch.where( + indices < 0, + indices + x_dim, + indices, + ) + if x.ndim == 2 and axis == 0: + # This case is equivalent to embedding lookup. + return torch.nn.functional.embedding(indices, x) + if axis is None: + x = torch.reshape(x, (-1,)) + axis = 0 + if axis is not None: + axis = canonicalize_axis(axis, x.ndim) + shape = x.shape[:axis] + indices.shape + x.shape[axis + 1 :] + # ravel the `indices` since `index_select` expects `indices` + # to be a vector (1-D tensor). + indices = indices.ravel() + out = torch.index_select(x, dim=axis, index=indices).squeeze(axis) + return out.reshape(shape) + return torch.take(x, index=indices) + + +def take_along_axis(x, indices, axis=None): + x = convert_to_tensor(x) + indices = convert_to_tensor(indices).long() + # Correct the indices using "fill" mode which is the same as in jax + x_dim = x.shape[axis] if axis is not None else x.shape[0] + indices = torch.where( + indices < 0, + indices + x_dim, + indices, + ) + return torch.take_along_dim(x, indices, dim=axis) + + +def tan(x): + x = convert_to_tensor(x) + return torch.tan(x) + + +def tanh(x): + x = convert_to_tensor(x) + return torch.tanh(x) + + +def tensordot(x1, x2, axes=2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + result_dtype = dtypes.result_type(x1.dtype, x2.dtype) + # TODO: torch.tensordot only supports float types + compute_dtype = dtypes.result_type(result_dtype, float) + # TODO: torch.tensordot doesn't support float16 with cpu + if get_device() == "cpu" and compute_dtype == "float16": + compute_dtype = "float32" + x1 = cast(x1, compute_dtype) + x2 = cast(x2, compute_dtype) + # torch only handles dims=((0,), (1,)), numpy accepts axes=(0, 1). + if isinstance(axes, (list, tuple)): + first, second = axes + if not isinstance(first, (list, tuple)): + first = (first,) + if not isinstance(second, (list, tuple)): + second = (second,) + axes = (first, second) + return cast(torch.tensordot(x1, x2, dims=axes), result_dtype) + + +def round(x, decimals=0): + x = convert_to_tensor(x) + ori_dtype = standardize_dtype(x.dtype) + # TODO: torch.round doesn't support int8, int16, int32, int64, uint8 + if "int" in ori_dtype: + x = cast(x, config.floatx()) + return cast(torch.round(x, decimals=decimals), ori_dtype) + return torch.round(x, decimals=decimals) + + +def tile(x, repeats): + if is_tensor(repeats): + repeats = tuple(repeats.int().numpy()) + if isinstance(repeats, int): + repeats = (repeats,) + x = convert_to_tensor(x) + return torch.tile(x, dims=repeats) + + +def trace(x, offset=0, axis1=0, axis2=1): + x = convert_to_tensor(x) + dtype = standardize_dtype(x.dtype) + if dtype != "int64": + dtype = dtypes.result_type(dtype, "int32") + return torch.sum( + torch.diagonal(x, offset, axis1, axis2), + dim=-1, + dtype=to_torch_dtype(dtype), + ) + + +def tri(N, M=None, k=0, dtype=None): + dtype = to_torch_dtype(dtype or config.floatx()) + M = M or N + x = torch.ones((N, M), dtype=dtype, device=get_device()) + return torch.tril(x, diagonal=k) + + +def tril(x, k=0): + x = convert_to_tensor(x) + return torch.tril(x, diagonal=k) + + +def triu(x, k=0): + x = convert_to_tensor(x) + return torch.triu(x, diagonal=k) + + +def trunc(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "bool": + return x + return torch.trunc(x) + + +def vdot(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + result_dtype = dtypes.result_type(x1.dtype, x2.dtype) + # TODO: torch.vdot only supports float types + compute_dtype = dtypes.result_type(result_dtype, float) + + # TODO: torch.vdot doesn't support float16 with cpu + if get_device() == "cpu" and compute_dtype == "float16": + compute_dtype = "float32" + + x1 = cast(x1, compute_dtype) + x2 = cast(x2, compute_dtype) + return cast(torch.vdot(x1, x2), result_dtype) + + +def inner(x1, x2): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + result_dtype = dtypes.result_type(x1.dtype, x2.dtype) + compute_dtype = dtypes.result_type(result_dtype, float) + + if get_device() == "cpu" and compute_dtype == "float16": + compute_dtype = "float32" + + x1 = cast(x1, compute_dtype) + x2 = cast(x2, compute_dtype) + return cast(torch.inner(x1, x2), result_dtype) + + +def vstack(xs): + xs = [convert_to_tensor(x) for x in xs] + return torch.vstack(xs) + + +def vectorize(pyfunc, *, excluded=None, signature=None): + return vectorize_impl( + pyfunc, torch.vmap, excluded=excluded, signature=signature + ) + + +def where(condition, x1=None, x2=None): + condition = convert_to_tensor(condition, dtype=bool) + if x1 is not None and x2 is not None: + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + return torch.where(condition, x1, x2) + else: + return torch.where(condition) + + +def divide(x1, x2): + if not isinstance(x1, (int, float)): + x1 = convert_to_tensor(x1) + if not isinstance(x2, (int, float)): + x2 = convert_to_tensor(x2) + return torch.divide(x1, x2) + + +def divide_no_nan(x1, x2): + if not isinstance(x1, (int, float)): + x1 = convert_to_tensor(x1) + if not isinstance(x2, (int, float)): + x2 = convert_to_tensor(x2) + return torch.where(x2 == 0, 0, torch.divide(x1, x2)) + + +def true_divide(x1, x2): + return divide(x1, x2) + + +def power(x1, x2): + x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) + return torch.pow(x1, x2) + + +def negative(x): + x = convert_to_tensor(x) + return torch.negative(x) + + +def square(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "bool": + x = cast(x, "int32") + return torch.square(x) + + +def sqrt(x): + x = convert_to_tensor(x) + if standardize_dtype(x.dtype) == "int64": + x = cast(x, config.floatx()) + return torch.sqrt(x) + + +def squeeze(x, axis=None): + x = convert_to_tensor(x) + if axis is not None: + return torch.squeeze(x, dim=axis) + return torch.squeeze(x) + + +def transpose(x, axes=None): + x = convert_to_tensor(x) + if axes is not None: + return torch.permute(x, dims=axes) + return x.T + + +def var(x, axis=None, keepdims=False): + x = convert_to_tensor(x) + compute_dtype = dtypes.result_type(x.dtype, "float32") + result_dtype = dtypes.result_type(x.dtype, float) + if axis == [] or axis == (): + # Torch handles the empty axis case differently from numpy. + return zeros_like(x, result_dtype) + # Bessel correction removed for numpy compatibility + x = cast(x, compute_dtype) + return cast( + torch.var(x, dim=axis, keepdim=keepdims, correction=0), result_dtype + ) + + +def sum(x, axis=None, keepdims=False): + if isinstance(x, (list, tuple)): + x = stack(x) + x = convert_to_tensor(x) + if axis == () or axis == []: + # Torch handles the empty axis case differently from numpy. + return x + dtype = standardize_dtype(x.dtype) + # follow jax's rule + # TODO: torch doesn't support uint32 + if dtype in ("bool", "uint8", "int8", "int16"): + dtype = "int32" + if axis is not None: + return cast(torch.sum(x, axis=axis, keepdim=keepdims), dtype) + return cast(torch.sum(x), dtype) + + +def eye(N, M=None, k=0, dtype=None): + dtype = to_torch_dtype(dtype or config.floatx()) + M = N if M is None else M + k = 0 if k is None else k + if k == 0: + # TODO: torch.eye doesn't support bfloat16 with cpu + if get_device() == "cpu" and dtype == torch.bfloat16: + return cast( + torch.eye( + N, M, dtype=to_torch_dtype("float32"), device=get_device() + ), + dtype, + ) + return torch.eye(N, M, dtype=dtype, device=get_device()) + diag_length = builtins.max(N, M) + diag = torch.ones(diag_length, dtype=dtype, device=get_device()) + return torch.diag(diag, diagonal=k)[:N, :M] + + +def floor_divide(x1, x2): + if not isinstance(x1, (int, float)): + x1 = convert_to_tensor(x1) + if not isinstance(x2, (int, float)): + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + ) + return cast(torch.floor_divide(x1, x2), dtype) + + +def logical_xor(x1, x2): + x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) + return torch.logical_xor(x1, x2) + + +def corrcoef(x): + x = convert_to_tensor(x) + + if standardize_dtype(x.dtype) == "bool": + x = cast(x, config.floatx()) + elif standardize_dtype(x.dtype) == "int64": + x = cast(x, "float64") + + return torch.corrcoef(x) + + +def correlate(x1, x2, mode="valid"): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + + dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + ) + if dtype == "int64": + dtype = "float64" + elif dtype not in ["bfloat16", "float16", "float64"]: + dtype = "float32" + + x1 = cast(x1, dtype) + x2 = cast(x2, dtype) + + x1_len, x2_len = x1.size(0), x2.size(0) + + if x1.shape[:-1] != x2.shape[:-1]: + new_shape = [max(i, j) for i, j in zip(x1.shape[:-1], x2.shape[:-1])] + x1 = torch.broadcast_to(x1, new_shape + [x1.shape[-1]]) + x2 = torch.broadcast_to(x2, new_shape + [x2.shape[-1]]) + + num_signals = torch.tensor(x1.shape[:-1]).prod() + x1 = torch.reshape(x1, (int(num_signals), x1.size(-1))) + x2 = torch.reshape(x2, (int(num_signals), x2.size(-1))) + + output = torch.nn.functional.conv1d( + x1, x2.unsqueeze(1), groups=x1.size(0), padding=x2.size(-1) - 1 + ) + output_shape = x1.shape[:-1] + (-1,) + result = output.reshape(output_shape) + + if mode == "valid": + target_length = ( + builtins.max(x1_len, x2_len) - builtins.min(x1_len, x2_len) + 1 + ) + start_idx = (result.size(-1) - target_length) // 2 + result = result[..., start_idx : start_idx + target_length] + + if mode == "same": + start_idx = (result.size(-1) - x1_len) // 2 + result = result[..., start_idx : start_idx + x1_len] + + return torch.squeeze(result) + + +def select(condlist, choicelist, default=0): + condlist = [convert_to_tensor(c) for c in condlist] + choicelist = [convert_to_tensor(c) for c in choicelist] + out = convert_to_tensor(default) + for c, v in reversed(list(zip(condlist, choicelist))): + out = torch.where(c, v, out) + return out + + +def slogdet(x): + x = convert_to_tensor(x) + return tuple(torch.linalg.slogdet(x)) + + +def argpartition(x, kth, axis=-1): + x = convert_to_tensor(x, "int32") + x = torch.transpose(x, axis, -1) + bottom_ind = torch.topk(-x, kth + 1)[1] + + def set_to_zero(a, i): + a[i] = torch.zeros(1, dtype=a.dtype, device=a.device) + return a + + for _ in range(x.dim() - 1): + set_to_zero = torch.vmap(set_to_zero) + proxy = set_to_zero(torch.ones_like(x, dtype=torch.int32), bottom_ind) + top_ind = torch.topk(proxy, x.shape[-1] - kth - 1)[1] + out = torch.cat([bottom_ind, top_ind], dim=x.dim() - 1) + return cast(torch.transpose(out, -1, axis), "int32") + + +def histogram(x, bins=10, range=None): + hist_result = torch.histogram(x, bins=bins, range=range) + return hist_result.hist, hist_result.bin_edges diff --git a/keras/src/backend/torch/optimizers/__init__.py b/keras/src/backend/torch/optimizers/__init__.py new file mode 100644 index 000000000000..008312b04b63 --- /dev/null +++ b/keras/src/backend/torch/optimizers/__init__.py @@ -0,0 +1 @@ +from keras.src.backend.torch.optimizers.torch_optimizer import TorchOptimizer diff --git a/keras/src/backend/torch/optimizers/torch_adadelta.py b/keras/src/backend/torch/optimizers/torch_adadelta.py new file mode 100644 index 000000000000..9e6038e7b6eb --- /dev/null +++ b/keras/src/backend/torch/optimizers/torch_adadelta.py @@ -0,0 +1,56 @@ +import torch + +from keras.src import ops +from keras.src import optimizers +from keras.src.backend.torch.optimizers import torch_parallel_optimizer + + +class Adadelta( + torch_parallel_optimizer.TorchParallelOptimizer, optimizers.Adadelta +): + def _parallel_update_step( + self, + grads, + variables, + learning_rate, + ): + keras_variables = variables + variables = [v.value for v in variables] + + dtype = variables[0].dtype + lr = ops.cast(learning_rate, dtype) + rho = self.rho + + accumulated_grads = [ + self._accumulated_grads[self._get_variable_index(variable)].value + for variable in keras_variables + ] + accumulated_delta_vars = [ + self._accumulated_delta_vars[ + self._get_variable_index(variable) + ].value + for variable in keras_variables + ] + torch._foreach_mul_(accumulated_grads, rho) + torch._foreach_add_( + accumulated_grads, torch._foreach_mul(grads, grads), alpha=1 - rho + ) + + def rms(x): + return torch._foreach_sqrt(torch._foreach_add(x, self.epsilon)) + + delta_vars = torch._foreach_mul( + torch._foreach_div( + torch._foreach_mul(rms(accumulated_delta_vars), grads), + rms(accumulated_grads), + ), + -1, + ) + torch._foreach_mul_(accumulated_delta_vars, rho) + torch._foreach_add_( + accumulated_delta_vars, + torch._foreach_mul(delta_vars, delta_vars), + alpha=1 - rho, + ) + + torch._foreach_add_(variables, delta_vars, alpha=lr) diff --git a/keras/src/backend/torch/optimizers/torch_adagrad.py b/keras/src/backend/torch/optimizers/torch_adagrad.py new file mode 100644 index 000000000000..2a1e19f70fd6 --- /dev/null +++ b/keras/src/backend/torch/optimizers/torch_adagrad.py @@ -0,0 +1,37 @@ +import torch + +from keras.src import ops +from keras.src import optimizers +from keras.src.backend.torch.optimizers import torch_parallel_optimizer + + +class Adagrad( + torch_parallel_optimizer.TorchParallelOptimizer, optimizers.Adagrad +): + def _parallel_update_step( + self, + grads, + variables, + learning_rate, + ): + keras_variables = variables + variables = [v.value for v in variables] + + dtype = variables[0].dtype + lr = ops.cast(learning_rate, dtype) + + accumulators = [ + self._accumulators[self._get_variable_index(variable)].value + for variable in keras_variables + ] + torch._foreach_add_(accumulators, torch._foreach_mul(grads, grads)) + torch._foreach_add_( + variables, + torch._foreach_div( + torch._foreach_mul(grads, lr), + torch._foreach_sqrt( + torch._foreach_add(accumulators, self.epsilon) + ), + ), + alpha=-1, + ) diff --git a/keras/src/backend/torch/optimizers/torch_adam.py b/keras/src/backend/torch/optimizers/torch_adam.py new file mode 100644 index 000000000000..3bb7db7c341c --- /dev/null +++ b/keras/src/backend/torch/optimizers/torch_adam.py @@ -0,0 +1,58 @@ +import torch + +from keras.src import ops +from keras.src import optimizers +from keras.src.backend.torch.optimizers import torch_parallel_optimizer + + +class Adam(torch_parallel_optimizer.TorchParallelOptimizer, optimizers.Adam): + def _parallel_update_step( + self, + grads, + variables, + learning_rate, + ): + keras_variables = variables + variables = [v.value for v in variables] + + dtype = variables[0].dtype + lr = ops.cast(learning_rate, dtype) + local_step = ops.cast(self.iterations + 1, dtype) + + beta_1_power = ops.power(ops.cast(self.beta_1, dtype), local_step) + beta_2_power = ops.power(ops.cast(self.beta_2, dtype), local_step) + alpha = lr * ops.sqrt(1 - beta_2_power) / (1 - beta_1_power) + + m_list = [ + self._momentums[self._get_variable_index(variable)].value + for variable in keras_variables + ] + v_list = [ + self._velocities[self._get_variable_index(variable)].value + for variable in keras_variables + ] + + torch._foreach_mul_(m_list, self.beta_1) + torch._foreach_add_(m_list, grads, alpha=1 - self.beta_1) + + torch._foreach_mul_(v_list, self.beta_2) + torch._foreach_add_( + v_list, torch._foreach_mul(grads, grads), alpha=1 - self.beta_2 + ) + + if self.amsgrad: + v_hat_list = [ + self._velocity_hats[self._get_variable_index(variable)].value + for variable in keras_variables + ] + torch._foreach_maximum_(v_hat_list, v_list) + v_list = v_hat_list + + torch._foreach_add_( + variables, + torch._foreach_div( + torch._foreach_mul(m_list, alpha), + torch._foreach_add(torch._foreach_sqrt(v_list), self.epsilon), + ), + alpha=-1, + ) diff --git a/keras/src/backend/torch/optimizers/torch_adamax.py b/keras/src/backend/torch/optimizers/torch_adamax.py new file mode 100644 index 000000000000..9cb3c0184499 --- /dev/null +++ b/keras/src/backend/torch/optimizers/torch_adamax.py @@ -0,0 +1,52 @@ +import torch + +from keras.src import ops +from keras.src import optimizers +from keras.src.backend.torch.optimizers import torch_parallel_optimizer + + +class Adamax( + torch_parallel_optimizer.TorchParallelOptimizer, optimizers.Adamax +): + def _parallel_update_step( + self, + grads, + variables, + learning_rate, + ): + keras_variables = variables + variables = [v.value for v in variables] + + dtype = variables[0].dtype + lr = ops.cast(learning_rate, dtype) + + local_step = ops.cast(self.iterations + 1, dtype) + + beta_1_power = ops.power(ops.cast(self.beta_1, dtype), local_step) + + m_list = [ + self._m[self._get_variable_index(variable)].value + for variable in keras_variables + ] + u_list = [ + self._u[self._get_variable_index(variable)].value + for variable in keras_variables + ] + + torch._foreach_mul_(m_list, self.beta_1) + torch._foreach_add_(m_list, grads, alpha=1 - self.beta_1) + + torch._foreach_mul_(u_list, self.beta_2) + torch._foreach_maximum_(u_list, torch._foreach_abs(grads)) + + torch._foreach_add_( + variables, + torch._foreach_div( + torch._foreach_mul(m_list, lr), + torch._foreach_mul( + torch._foreach_add(u_list, self.epsilon), + 1 - beta_1_power, + ), + ), + alpha=-1, + ) diff --git a/keras/src/backend/torch/optimizers/torch_adamw.py b/keras/src/backend/torch/optimizers/torch_adamw.py new file mode 100644 index 000000000000..394727cd9b59 --- /dev/null +++ b/keras/src/backend/torch/optimizers/torch_adamw.py @@ -0,0 +1,6 @@ +from keras.src import optimizers +from keras.src.backend.torch.optimizers import torch_adam + + +class AdamW(torch_adam.Adam, optimizers.AdamW): + pass diff --git a/keras/src/backend/torch/optimizers/torch_lion.py b/keras/src/backend/torch/optimizers/torch_lion.py new file mode 100644 index 000000000000..f2022ad6e53e --- /dev/null +++ b/keras/src/backend/torch/optimizers/torch_lion.py @@ -0,0 +1,37 @@ +import torch + +from keras.src import ops +from keras.src import optimizers +from keras.src.backend.torch.optimizers import torch_parallel_optimizer + + +class Lion(torch_parallel_optimizer.TorchParallelOptimizer, optimizers.Lion): + def _parallel_update_step( + self, + grads, + variables, + learning_rate, + ): + keras_variables = variables + variables = [v.value for v in variables] + + dtype = variables[0].dtype + lr = ops.cast(learning_rate, dtype) + + m_list = [ + self._momentums[self._get_variable_index(variable)].value + for variable in keras_variables + ] + + c_t = torch._foreach_mul(m_list, self.beta_1) + torch._foreach_add_(c_t, grads, alpha=1 - self.beta_1) + c_t = [c.sign() for c in c_t] + + torch._foreach_add_( + variables, + torch._foreach_mul(c_t, lr), + alpha=-1, + ) + + torch._foreach_mul_(m_list, self.beta_2) + torch._foreach_add_(m_list, grads, alpha=1 - self.beta_2) diff --git a/keras/src/backend/torch/optimizers/torch_nadam.py b/keras/src/backend/torch/optimizers/torch_nadam.py new file mode 100644 index 000000000000..df82bd2c473b --- /dev/null +++ b/keras/src/backend/torch/optimizers/torch_nadam.py @@ -0,0 +1,74 @@ +import torch + +from keras.src import ops +from keras.src import optimizers +from keras.src.backend.torch import core +from keras.src.backend.torch.optimizers import torch_parallel_optimizer + + +class Nadam(torch_parallel_optimizer.TorchParallelOptimizer, optimizers.Nadam): + def _parallel_update_step( + self, + grads, + variables, + learning_rate, + ): + keras_variables = variables + variables = [v.value for v in variables] + + dtype = variables[0].dtype + lr = ops.cast(learning_rate, dtype) + + local_step = ops.cast(self.iterations + 1, dtype) + next_step = ops.cast(self.iterations + 2, dtype) + decay = ops.cast(0.96, dtype) + beta_1 = ops.cast(self.beta_1, dtype) + beta_2 = ops.cast(self.beta_2, dtype) + u_t = beta_1 * (1.0 - 0.5 * (ops.power(decay, local_step))) + u_t_1 = beta_1 * (1.0 - 0.5 * (ops.power(decay, next_step))) + u_product_t = self._u_product.value * u_t + u_product_t_1 = u_product_t * u_t_1 + beta_2_power = ops.power(beta_2, local_step) + + self._u_product.assign(u_product_t) + + m_list = [ + self._momentums[self._get_variable_index(variable)].value + for variable in keras_variables + ] + v_list = [ + self._velocities[self._get_variable_index(variable)].value + for variable in keras_variables + ] + + torch._foreach_mul_(m_list, self.beta_1) + torch._foreach_add_(m_list, grads, alpha=1 - self.beta_1) + + torch._foreach_mul_(v_list, self.beta_2) + torch._foreach_add_( + v_list, torch._foreach_mul(grads, grads), alpha=1 - self.beta_2 + ) + + m_hat_list = torch._foreach_add( + torch._foreach_div( + torch._foreach_mul(m_list, u_t_1), + 1 - core.convert_to_numpy(u_product_t_1), + ), + torch._foreach_div( + torch._foreach_mul(grads, 1 - u_t), + 1 - core.convert_to_numpy(u_product_t), + ), + ) + + v_hat_list = torch._foreach_div(v_list, 1 - beta_2_power) + + torch._foreach_add_( + variables, + torch._foreach_div( + torch._foreach_mul(m_hat_list, lr), + torch._foreach_add( + torch._foreach_sqrt(v_hat_list), self.epsilon + ), + ), + alpha=-1, + ) diff --git a/keras/src/backend/torch/optimizers/torch_optimizer.py b/keras/src/backend/torch/optimizers/torch_optimizer.py new file mode 100644 index 000000000000..85fc274c574f --- /dev/null +++ b/keras/src/backend/torch/optimizers/torch_optimizer.py @@ -0,0 +1,45 @@ +import torch + +from keras.src import optimizers +from keras.src.optimizers.base_optimizer import BaseOptimizer +from keras.src.utils import torch_utils + + +class TorchOptimizer(BaseOptimizer): + def __new__(cls, *args, **kwargs): + # Import locally to avoid circular imports. + from keras.src.backend.torch.optimizers import torch_adadelta + from keras.src.backend.torch.optimizers import torch_adagrad + from keras.src.backend.torch.optimizers import torch_adam + from keras.src.backend.torch.optimizers import torch_adamax + from keras.src.backend.torch.optimizers import torch_adamw + from keras.src.backend.torch.optimizers import torch_lion + from keras.src.backend.torch.optimizers import torch_nadam + from keras.src.backend.torch.optimizers import torch_rmsprop + from keras.src.backend.torch.optimizers import torch_sgd + + OPTIMIZERS = { + optimizers.Adadelta: torch_adadelta.Adadelta, + optimizers.Adagrad: torch_adagrad.Adagrad, + optimizers.Adam: torch_adam.Adam, + optimizers.Adamax: torch_adamax.Adamax, + optimizers.AdamW: torch_adamw.AdamW, + optimizers.Lion: torch_lion.Lion, + optimizers.Nadam: torch_nadam.Nadam, + optimizers.RMSprop: torch_rmsprop.RMSprop, + optimizers.SGD: torch_sgd.SGD, + } + + if cls in OPTIMIZERS: + return OPTIMIZERS[cls](*args, **kwargs) + return super().__new__(cls) + + @torch_utils.no_grad + def _apply_weight_decay(self, variables): + if self.weight_decay is None: + return + + torch._foreach_mul_( + [v.value for v in variables if self._use_weight_decay(v)], + 1 - self.weight_decay * self._get_current_learning_rate(), + ) diff --git a/keras/src/backend/torch/optimizers/torch_parallel_optimizer.py b/keras/src/backend/torch/optimizers/torch_parallel_optimizer.py new file mode 100644 index 000000000000..450fbf50ec54 --- /dev/null +++ b/keras/src/backend/torch/optimizers/torch_parallel_optimizer.py @@ -0,0 +1,26 @@ +import torch + +from keras.src.optimizers.base_optimizer import BaseOptimizer +from keras.src.utils import torch_utils + + +class TorchParallelOptimizer(BaseOptimizer): + @torch_utils.no_grad + def _backend_update_step(self, grads, trainable_variables, learning_rate): + self._parallel_update_step( + grads, + trainable_variables, + learning_rate, + ) + + @torch_utils.no_grad + def _backend_reset_gradient_accumulators(self): + acc_list = [ + v.value for v in self._accumulated_gradients if v is not None + ] + torch._foreach_mul_(acc_list, 0.0) + + @torch_utils.no_grad + def _backend_increment_gradient_accumulators(self, grads, acc_grads): + acc_list = [v.value for v in acc_grads] + torch._foreach_add_(acc_list, grads, alpha=1.0) diff --git a/keras/src/backend/torch/optimizers/torch_rmsprop.py b/keras/src/backend/torch/optimizers/torch_rmsprop.py new file mode 100644 index 000000000000..49c4c3916bc1 --- /dev/null +++ b/keras/src/backend/torch/optimizers/torch_rmsprop.py @@ -0,0 +1,64 @@ +import torch + +from keras.src import ops +from keras.src import optimizers +from keras.src.backend.torch.optimizers import torch_parallel_optimizer + + +class RMSprop( + torch_parallel_optimizer.TorchParallelOptimizer, optimizers.RMSprop +): + def _parallel_update_step( + self, + grads, + variables, + learning_rate, + ): + keras_variables = variables + variables = [v.value for v in variables] + + dtype = variables[0].dtype + lr = ops.cast(learning_rate, dtype) + + velocities = [ + self._velocities[self._get_variable_index(variable)].value + for variable in keras_variables + ] + + rho = self.rho + + torch._foreach_mul_(velocities, rho) + torch._foreach_add_( + velocities, torch._foreach_mul(grads, grads), alpha=1 - rho + ) + + denominators = torch._foreach_add(velocities, self.epsilon) + if self.centered: + average_grads = [ + self._average_gradients[ + self._get_variable_index(variable) + ].value + for variable in keras_variables + ] + torch._foreach_mul_(average_grads, rho) + torch._foreach_add_(average_grads, grads, alpha=1 - rho) + torch._foreach_add_( + denominators, + torch._foreach_mul(average_grads, average_grads), + alpha=-1, + ) + torch._foreach_sqrt_(denominators) + increments = torch._foreach_div( + torch._foreach_mul(grads, lr), denominators + ) + + if self.momentum > 0: + momentum_list = [ + self._momentums[self._get_variable_index(variable)].value + for variable in keras_variables + ] + torch._foreach_mul_(momentum_list, self.momentum) + torch._foreach_add_(momentum_list, increments) + torch._foreach_add_(variables, momentum_list, alpha=-1) + else: + torch._foreach_add_(variables, increments, alpha=-1) diff --git a/keras/src/backend/torch/optimizers/torch_sgd.py b/keras/src/backend/torch/optimizers/torch_sgd.py new file mode 100644 index 000000000000..f16220d85ac3 --- /dev/null +++ b/keras/src/backend/torch/optimizers/torch_sgd.py @@ -0,0 +1,36 @@ +import torch + +from keras.src import optimizers +from keras.src.backend.torch.optimizers import torch_parallel_optimizer + + +class SGD(torch_parallel_optimizer.TorchParallelOptimizer, optimizers.SGD): + def _parallel_update_step( + self, + grads, + variables, + learning_rate, + ): + keras_variables = variables + variables = [v.value for v in variables] + if self.momentum != 0: + bufs = [ + self.momentums[self._get_variable_index(variable)].value + for variable in keras_variables + ] + + for i in range(len(bufs)): + if bufs[i] is None: + bufs[i] = torch.clone(grads[i]).detach() + + torch._foreach_mul_(bufs, self.momentum) + torch._foreach_add_(bufs, grads, alpha=-learning_rate) + + if self.nesterov: + torch._foreach_add_(variables, grads, alpha=-learning_rate) + torch._foreach_add_(variables, bufs, alpha=self.momentum) + else: + torch._foreach_add_(variables, bufs) + + else: + torch._foreach_add_(variables, grads, alpha=-learning_rate) diff --git a/keras/src/backend/torch/random.py b/keras/src/backend/torch/random.py new file mode 100644 index 000000000000..e080731952e6 --- /dev/null +++ b/keras/src/backend/torch/random.py @@ -0,0 +1,244 @@ +import torch +import torch._dynamo as dynamo +import torch.nn.functional as tnn + +from keras.src.backend.config import floatx +from keras.src.backend.torch.core import convert_to_tensor +from keras.src.backend.torch.core import get_device +from keras.src.backend.torch.core import to_torch_dtype +from keras.src.random.seed_generator import SeedGenerator +from keras.src.random.seed_generator import draw_seed +from keras.src.random.seed_generator import make_default_seed + + +# torch.Generator not supported with dynamo +# see: https://github.com/pytorch/pytorch/issues/88576 +@dynamo.disable() +def torch_seed_generator(seed): + first_seed, second_seed = draw_seed(seed) + device = get_device() + if device == "meta": + # Generator is not supported by the meta device. + return None + generator = torch.Generator(device=get_device()) + generator.manual_seed(int(first_seed + second_seed)) + return generator + + +def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): + dtype = dtype or floatx() + dtype = to_torch_dtype(dtype) + # Do not use generator during symbolic execution. + if get_device() == "meta": + return torch.normal( + mean, stddev, size=shape, dtype=dtype, device=get_device() + ) + generator = torch_seed_generator(seed) + return torch.normal( + mean, + stddev, + size=shape, + generator=generator, + dtype=dtype, + device=get_device(), + ) + + +def categorical(logits, num_samples, dtype="int32", seed=None): + logits = convert_to_tensor(logits) + dtype = to_torch_dtype(dtype) + probs = torch.softmax(logits, dim=-1) + # Do not use generator during symbolic execution. + if get_device() == "meta": + return torch.multinomial( + probs, + num_samples, + replacement=True, + ).type(dtype) + generator = torch_seed_generator(seed) + return torch.multinomial( + probs, + num_samples, + replacement=True, + generator=generator, + ).type(dtype) + + +def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None): + dtype = dtype or floatx() + dtype = to_torch_dtype(dtype) + requested_shape = shape + if len(requested_shape) == 0: + shape = (1,) + # Do not use generator during symbolic execution. + if get_device() == "meta": + rand_tensor = torch.rand(size=shape, dtype=dtype, device=get_device()) + else: + generator = torch_seed_generator(seed) + rand_tensor = torch.rand( + size=shape, generator=generator, dtype=dtype, device=get_device() + ) + + output = (maxval - minval) * rand_tensor + minval + + if len(requested_shape) == 0: + return output[0] + return output + + +def randint(shape, minval, maxval, dtype="int32", seed=None): + dtype = to_torch_dtype(dtype) + # Do not use generator during symbolic execution. + if get_device() == "meta": + return torch.randint( + low=minval, + high=maxval, + size=shape, + dtype=dtype, + device=get_device(), + ) + generator = torch_seed_generator(seed) + return torch.randint( + low=minval, + high=maxval, + size=shape, + generator=generator, + dtype=dtype, + device=get_device(), + ) + + +def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): + dtype = to_torch_dtype(dtype) + # Take a larger standard normal dist, discard values outside 2 * stddev + # Offset by mean and stddev + x = normal(tuple(shape) + (4,), mean=0, stddev=1, dtype=dtype, seed=seed) + valid = (x > -2) & (x < 2) + indexes = valid.max(-1, keepdim=True)[1] + trunc_x = torch.empty(shape, dtype=dtype, device=get_device()) + trunc_x.data.copy_(x.gather(-1, indexes).squeeze(-1)) + trunc_x.data.mul_(stddev).add_(mean) + return trunc_x + + +def _get_concrete_noise_shape(inputs, noise_shape): + if noise_shape is None: + return inputs.shape + + concrete_inputs_shape = inputs.shape + concrete_noise_shape = [] + for i, value in enumerate(noise_shape): + concrete_noise_shape.append( + concrete_inputs_shape[i] if value is None else value + ) + return concrete_noise_shape + + +def dropout(inputs, rate, noise_shape=None, seed=None): + if ( + seed is not None + and not (isinstance(seed, SeedGenerator) and seed._initial_seed is None) + or noise_shape is not None + ): + keep_prob = 1.0 - rate + noise_shape = _get_concrete_noise_shape(inputs, noise_shape) + keep_prob_matrix = torch.full( + noise_shape, keep_prob, device=get_device() + ) + generator = torch_seed_generator(seed) + + # Do not use generator during symbolic execution. + if get_device() == "meta": + mask = torch.bernoulli(keep_prob_matrix) + else: + mask = torch.bernoulli(keep_prob_matrix, generator=generator) + + mask = mask.bool() + mask = torch.broadcast_to(mask, inputs.shape) + return torch.where( + mask, + inputs / keep_prob, + torch.zeros_like(inputs, dtype=inputs.dtype), + ) + # Fast path, unseeded (since torch doesn't support seeding dropout!!!!) + # Using the above implementation is possible, but much slower. + return torch.nn.functional.dropout( + inputs, p=rate, training=True, inplace=False + ) + + +def shuffle(x, axis=0, seed=None): + # Ref: https://github.com/pytorch/pytorch/issues/71409 + x = convert_to_tensor(x) + + # Get permutation indices + # Do not use generator during symbolic execution. + if get_device() == "meta": + row_perm = torch.rand(x.shape[: axis + 1], device=get_device()).argsort( + axis + ) + else: + generator = torch_seed_generator(seed) + row_perm = torch.rand( + x.shape[: axis + 1], generator=generator, device=get_device() + ).argsort(axis) + for _ in range(x.ndim - axis - 1): + row_perm.unsqueeze_(-1) + + # Reformat this for the gather operation + row_perm = row_perm.repeat( + *[1 for _ in range(axis + 1)], *(x.shape[axis + 1 :]) + ) + return x.gather(axis, row_perm) + + +def gamma(shape, alpha, dtype=None, seed=None): + dtype = dtype or floatx() + dtype = to_torch_dtype(dtype) + alpha = torch.broadcast_to(convert_to_tensor(alpha), shape) + beta = torch.ones(shape, device=get_device()) + prev_rng_state = torch.random.get_rng_state() + # Do not draw seed during symbolic execution + if not get_device() == "meta": + first_seed, second_seed = draw_seed(seed) + torch.manual_seed(first_seed + second_seed) + gamma_distribution = torch.distributions.gamma.Gamma(alpha, beta) + sample = gamma_distribution.sample().type(dtype) + torch.random.set_rng_state(prev_rng_state) + return sample + + +def binomial(shape, counts, probabilities, dtype=None, seed=None): + dtype = dtype or floatx() + dtype = to_torch_dtype(dtype) + counts = torch.broadcast_to(convert_to_tensor(counts), shape) + probabilities = torch.broadcast_to(convert_to_tensor(probabilities), shape) + prev_rng_state = torch.random.get_rng_state() + # Do not draw seed during symbolic execution + if not get_device() == "meta": + first_seed, second_seed = draw_seed(seed) + torch.manual_seed(first_seed + second_seed) + binomial_distribution = torch.distributions.binomial.Binomial( + total_count=counts, probs=probabilities + ) + sample = binomial_distribution.sample().type(dtype) + torch.random.set_rng_state(prev_rng_state) + return sample + + +def beta(shape, alpha, beta, dtype=None, seed=None): + dtype = dtype or floatx() + dtype = to_torch_dtype(dtype) + alpha = torch.broadcast_to(convert_to_tensor(alpha), shape) + beta = torch.broadcast_to(convert_to_tensor(beta), shape) + prev_rng_state = torch.random.get_rng_state() + # Do not draw seed during symbolic execution + if not get_device() == "meta": + first_seed, second_seed = draw_seed(seed) + torch.manual_seed(first_seed + second_seed) + beta_distribution = torch.distributions.beta.Beta( + concentration1=alpha, concentration0=beta + ) + sample = beta_distribution.sample().type(dtype) + torch.random.set_rng_state(prev_rng_state) + return sample diff --git a/keras/src/backend/torch/rnn.py b/keras/src/backend/torch/rnn.py new file mode 100644 index 000000000000..bd9f2efe4731 --- /dev/null +++ b/keras/src/backend/torch/rnn.py @@ -0,0 +1,728 @@ +import numpy as np +import torch + +from keras.src import tree +from keras.src.backend.torch.core import convert_to_tensor +from keras.src.backend.torch.core import get_device + + +def rnn( + step_function, + inputs, + initial_states, + go_backwards=False, + mask=None, + constants=None, + unroll=False, + input_length=None, + time_major=False, + zero_output_for_mask=False, + return_all_outputs=True, +): + input_length = input_length or inputs.shape[1] + + def swap_batch_timestep(input_t): + # Swap the batch and timestep dim for the incoming tensor. + axes = list(range(len(input_t.shape))) + axes[0], axes[1] = 1, 0 + return torch.permute(input_t, axes) + + if not time_major: + inputs = tree.map_structure(swap_batch_timestep, inputs) + + flattened_inputs = tree.flatten(inputs) + time_steps = flattened_inputs[0].shape[0] + time_steps_t = time_steps + + if mask is not None: + if mask.dtype != torch.bool: + mask = mask.type(torch.bool) + if len(mask.shape) == 2: + mask = torch.unsqueeze(mask, -1) + if not time_major: + mask = swap_batch_timestep(mask) + + if constants is None: + constants = [] + + def _expand_mask(mask_t, input_t, fixed_dim=1): + if tree.is_nested(mask_t): + raise ValueError( + f"mask_t is expected to be tensor,\ + but got {mask_t}" + ) + if tree.is_nested(input_t): + raise ValueError( + f"input_t is expected to be tensor,\ + but got {input_t}" + ) + rank_diff = len(input_t.shape) - len(mask_t.shape) + for _ in range(rank_diff): + mask_t = torch.unsqueeze(mask_t, -1) + multiples = [1] * fixed_dim + list(input_t.shape[fixed_dim:]) + return torch.tile(mask_t, multiples) + + if unroll: + if not time_steps: + raise ValueError("Unrolling requires a fixed number of timesteps.") + states = tuple(initial_states) + successive_states = [] + successive_outputs = [] + + # Process the input tensors. The input tensor need to be split on the + # time_step dim, and reverse if go_backwards is True. In the case of + # nested input, the input is flattened and then transformed + # individually. The result of this will be a tuple of lists, each of + # the item in tuple is list of the tensor with shape (batch, feature) + def _process_single_input_t(input_t): + input_t = torch.unbind(input_t) # unstack for time_step dim + if go_backwards: + input_t = input_t[::-1] + return input_t + + if tree.is_nested(inputs): + processed_input = tree.map_structure( + _process_single_input_t, inputs + ) # noqa: E501 + else: + processed_input = (_process_single_input_t(inputs),) + + def _get_input_tensor(time): + inp = [t_[time] for t_ in processed_input] + return tree.pack_sequence_as(inputs, inp) + + if mask is not None: + mask_list = torch.unbind(mask) + if go_backwards: + mask_list = torch.flip(mask_list, dims=mask_list.shape) + + for i in range(time_steps): + inp = _get_input_tensor(i) + mask_t = mask_list[i] + output, new_states = step_function( + inp, tuple(states) + tuple(constants) + ) + tiled_mask_t = _expand_mask(mask_t, output) + + if not successive_outputs: + prev_output = torch.zeros_like(output) + else: + prev_output = successive_outputs[-1] + + output = torch.where(tiled_mask_t, output, prev_output) + + flat_states = tree.flatten(states) + flat_new_states = tree.flatten(new_states) + tiled_mask_t = tuple( + _expand_mask(mask_t, s) for s in flat_states + ) # noqa: E501 + flat_final_states = tuple( + torch.where(m, s, ps) + for m, s, ps in zip( + tiled_mask_t, flat_new_states, flat_states + ) # noqa: E501 + ) + states = tree.pack_sequence_as(states, flat_final_states) + + if return_all_outputs: + successive_outputs.append(output) + successive_states.append(states) + else: + successive_outputs = [output] + successive_states = [states] + last_output = successive_outputs[-1] + new_states = successive_states[-1] + outputs = torch.stack(successive_outputs) + + if zero_output_for_mask: + last_output = torch.where( + _expand_mask(mask_list[-1], last_output), + last_output, + torch.zeros_like(last_output), + ) + outputs = torch.where( + _expand_mask(mask, outputs, fixed_dim=2), + outputs, + torch.zeros_like(outputs), + ) + + else: # mask is None + for i in range(time_steps): + inp = _get_input_tensor(i) + output, states = step_function( + inp, tuple(states) + tuple(constants) + ) # noqa: E501 + if return_all_outputs: + successive_outputs.append(output) + successive_states.append(states) + else: + successive_outputs = [output] + successive_states = [states] + last_output = successive_outputs[-1] + new_states = successive_states[-1] + outputs = torch.stack(successive_outputs) + + else: # Unroll == False + states = tuple(initial_states) + + # Create input tensor array, if the inputs is nested tensors, then it + # will be flattened first, and tensor array will be created one per + # flattened tensor. + + input_ta = tuple( + ( + list(torch.unbind(input_)) + if not go_backwards + else list(torch.unbind(torch.flip(input_, [0]))) + ) + for input_ in flattened_inputs + ) + + # Get the time(0) input and compute the output for that. + input_time_zero = tree.pack_sequence_as( + inputs, [inp[0] for inp in flattened_inputs] + ) + # output_time_zero is used to determine the cell output shape. + output_time_zero, _ = step_function( + input_time_zero, tuple(initial_states) + tuple(constants) + ) + + output_ta_size = time_steps_t if return_all_outputs else 1 + output_ta = [] + for out in tree.flatten(output_time_zero): + out_list = list(out) + if len(out) < output_ta_size: + out_list.extend([[]] * (output_ta_size - len(out))) + output_ta.append(out_list) + + time = torch.tensor(0, dtype=torch.int32) + + if input_length is None: + max_iterations = time_steps_t + else: + if hasattr(input_length, "__len__"): + input_length = convert_to_tensor(input_length) + max_iterations = torch.max(input_length) + else: + max_iterations = input_length + + if mask is not None: + if go_backwards: + mask = torch.flip(mask, [0]) + + mask_ta = list(torch.unbind(mask)) + + def masking_fn(time): + return mask_ta[time] + + def compute_masked_output(mask_t, flat_out, flat_mask): + tiled_mask_t = tuple( + _expand_mask(mask_t, o, fixed_dim=len(mask_t.shape)) + for o in flat_out + ) + return tuple( + torch.where(m, o, fm) + for m, o, fm in zip(tiled_mask_t, flat_out, flat_mask) + ) + + elif isinstance(input_length, torch.Tensor): + if go_backwards: + max_len = torch.max(input_length, dim=0) + if isinstance(max_len, torch.return_types.max): + max_len = max_len[0] + rev_input_length = torch.subtract(max_len - 1, input_length) + + def masking_fn(time): + return torch.less(rev_input_length, time) + + else: + + def masking_fn(time): + return torch.greater(input_length, time) + + def compute_masked_output(mask_t, flat_out, flat_mask): + return tuple( + torch.where(mask_t, o, zo) + for (o, zo) in zip(flat_out, flat_mask) # noqa: E501 + ) + + else: + masking_fn = None + + if masking_fn is not None: + # Mask for the T output will be base on the output of T - 1. In the + # case T = 0, a zero filled tensor will be used. + flat_zero_output = tuple( + torch.zeros_like(o) for o in tree.flatten(output_time_zero) + ) + + def _step(time, output_ta_t, prev_output, *states): + """RNN step function. + + Args: + time: Current timestep value. + output_ta_t: TensorArray. + prev_output: tuple of outputs from time - 1. + *states: List of states. + + Returns: + Tuple: `(time + 1, output_ta_t, output) + tuple(new_states)` + """ + current_input = tuple(ta[time] for ta in input_ta) + # maybe set shape. + current_input = tree.pack_sequence_as(inputs, current_input) + mask_t = masking_fn(time) + output, new_states = step_function( + current_input, tuple(states) + tuple(constants) + ) + # mask output + flat_output = tree.flatten(output) + flat_mask_output = ( + flat_zero_output + if zero_output_for_mask + else tree.flatten(prev_output) + ) + flat_new_output = compute_masked_output( + mask_t, flat_output, flat_mask_output + ) + + # mask states + flat_state = tree.flatten(states) + flat_new_state = tree.flatten(new_states) + flat_final_state = compute_masked_output( + mask_t, flat_new_state, flat_state + ) + new_states = tree.pack_sequence_as(new_states, flat_final_state) # noqa: E501 + + ta_index_to_write = time if return_all_outputs else 0 + for ta, out in zip(output_ta_t, flat_new_output): + ta[ta_index_to_write] = out + + return (time + 1, output_ta_t, tuple(flat_new_output)) + tuple( + new_states + ) + + it = 0 + output_ta_t, new_states, prev_output = ( + output_ta, + states, + flat_zero_output, + ) + while time < time_steps_t and it < max_iterations: + final_outputs = _step( + time, output_ta_t, prev_output, *new_states + ) # noqa: E501 + time, output_ta_t, prev_output = final_outputs[:3] + new_states = final_outputs[3:] + it += 1 + + else: + + def _step(time, output_ta_t, *states): + """RNN step function. + + Args: + time: Current timestep value. + output_ta_t: TensorArray. + *states: List of states. + + Returns: + Tuple: `(time + 1,output_ta_t) + tuple(new_states)` + """ + current_input = tuple(ta[time] for ta in input_ta) + current_input = tree.pack_sequence_as(inputs, current_input) + output, new_states = step_function( + current_input, tuple(states) + tuple(constants) + ) + flat_new_state = tree.flatten(new_states) + + flat_output = tree.flatten(output) + ta_index_to_write = time if return_all_outputs else 0 + for ta, out in zip(output_ta_t, flat_output): + ta[ta_index_to_write] = out + + new_states = tree.pack_sequence_as( + initial_states, flat_new_state + ) # noqa: E501 + return (time + 1, output_ta_t) + tuple(new_states) + + it = 0 + output_ta_t = output_ta + new_states = states + while time < time_steps_t and it < max_iterations: + final_outputs = _step(time, output_ta_t, *new_states) + time, output_ta_t = final_outputs[:2] + new_states = final_outputs[2:] + it += 1 + + def _stack(tensor_list): + max_ndims = max([t.ndim for t in tensor_list]) + max_list = [] + for i, t in enumerate(tensor_list): + if t.ndim == max_ndims: + max_list.append(t) + return torch.stack(max_list) + + output_ta = final_outputs[1] + + outputs = tuple(_stack(o) for o in output_ta) + last_output = tuple(o[-1] for o in outputs) + + outputs = tree.pack_sequence_as(output_time_zero, outputs) + last_output = tree.pack_sequence_as(output_time_zero, last_output) + + if not time_major: + outputs = tree.map_structure(swap_batch_timestep, outputs) + + return last_output, outputs, new_states + + +def _is_sequence_right_padded(mask): + """Check the mask tensor and see if it right padded. + + cuDNN uses the sequence length param to skip the tailing + timestep. If the data is left padded, or not a strict right padding (has + masked value in the middle of the sequence), then cuDNN won't work + properly in those cases. + + Left padded data: [[False, False, True, True, True]]. + Right padded data: [[True, True, True, False, False]]. + Mixture of mask/unmasked data: [[True, False, True, False, False]]. + + Note that for the mixed data example above, the actually data RNN should see + are those 2 Trues (index 0 and 2), the index 1 False should be ignored and + not pollute the internal states. + + Args: + mask: the Boolean tensor with shape [batch, timestep] + + Returns: + boolean scalar tensor, whether the mask is strictly right padded. + """ + # Get max sequence length + max_seq_length = mask.shape[1] + # Count True values in each sequence + count_of_true = torch.sum(mask, dim=1) + # Create right padded mask + batch_size = mask.shape[0] + indices = torch.arange(max_seq_length, device=mask.device).repeat( + batch_size, 1 + ) # noqa: E501 + right_padded_mask = indices < count_of_true.unsqueeze(1) + return torch.all(mask == right_padded_mask) + + +def _has_fully_masked_sequence(mask): + # Cudnn kernel will error out if the input sequence contains any + # fully masked data. We walk around this issue by rerouting the computation + # to standard kernel, until the issue on cudnn side has been fixed. For a + # fully masked sequence, it will contain all Falses. To make it easy to + # check, we inverse the boolean, check if any of the sequence has all True. + return torch.any(torch.all(~mask, dim=1)) + + +def _assert_valid_mask(mask): + # Check if mask is valid for cuDNN + no_fully_masked = ~_has_fully_masked_sequence(mask) + is_right_padded = _is_sequence_right_padded(mask) + valid = no_fully_masked & is_right_padded + + if not valid.item(): + error_message = ( + "You are passing a RNN mask that does not correspond to " + "right-padded sequences, while using cuDNN, which is not " + "supported. With cuDNN, RNN masks can only be used for " + "right-padding, e.g. `[[True, True, False, False]]` would " + "be a valid mask, but any mask that isn't just contiguous " + "`True`'s on the left and contiguous `False`'s on the right " + "would be invalid. You can pass `use_cudnn=False` to your " + "RNN layer to stop using cuDNN (this may be slower)." + ) + raise ValueError(error_message) + + +def _compute_sequence_length_from_mask(mask, batch_first): + """Calculate the sequence length tensor (1-D) based on the masking tensor. + + The masking tensor is a 2D boolean tensor with shape [batch, timestep]. For + any timestep that should be masked, the corresponding field will be False. + Consider the following example: + a = [[True, True, False, False] + [True, True, True, False]] + It is a (2, 4) tensor, and the corresponding sequence length result should + be 1D tensor with value [2, 3]. Note that the masking tensor must be right + padded that could be checked by, e.g., `is_sequence_right_padded()`. + + Args: + mask: Boolean tensor with shape [batch, timestep] or [timestep, batch] + if time_major=True. + time_major: Boolean, which indicates whether the mask is time major or + batch major. + + Returns: + sequence_length: 1D int32 tensor. + """ + timestep_index = 0 if not batch_first else 1 + return torch.sum(mask.int(), dim=timestep_index) + + +def prepare_lstm_weights(lstm, kernel, recurrent_kernel, bias, device): + """Copies kernel and recurrent kernel weights in the Pytorch format + We split the kernel and recurrent kernel weights, create associated + torch tensors adapted to be in line with the Cudnn optimization. + After we have copied the weights, we ensure the paramters are on + the same device and memory layout is optimized for Cudnn. + + """ + + lstm = lstm.to(device) + hidden_size = lstm.hidden_size + + # Convert gates from Keras [i,f,c,o] to PyTorch [i,f,g,o] + i_k, f_k, c_k, o_k = np.split(kernel, 4, axis=1) + weight_ih_data = np.concatenate([i_k, f_k, c_k, o_k], axis=1).T + + i_r, f_r, c_r, o_r = np.split(recurrent_kernel, 4, axis=1) + weight_hh_data = np.concatenate([i_r, f_r, c_r, o_r], axis=1).T + + if bias is not None: + # Split Keras combined bias into input and hidden biases + bias_ih_data = convert_to_tensor(bias, dtype="float32") + bias_hh_data = torch.zeros_like(bias_ih_data) + + else: + bias_ih_data = torch.zeros(4 * hidden_size, device=device) + bias_hh_data = torch.zeros(4 * hidden_size, device=device) + + # Create PyTorch tensors for weights + weight_ih = convert_to_tensor(weight_ih_data, dtype="float32").contiguous() + weight_hh = convert_to_tensor(weight_hh_data, dtype="float32").contiguous() + bias_ih = convert_to_tensor(bias_ih_data, dtype="float32").contiguous() + bias_hh = convert_to_tensor(bias_hh_data, dtype="float32").contiguous() + + # Ensure the weights are all on the same device + weight_ih = weight_ih.to(device) + weight_hh = weight_hh.to(device) + bias_ih = bias_ih.to(device) + bias_hh = bias_hh.to(device) + + # Copy Keras weights into Torch's flat weights + with torch.no_grad(): + lstm.weight_ih_l0.copy_(weight_ih) + lstm.weight_hh_l0.copy_(weight_hh) + lstm.bias_ih_l0.copy_(bias_ih) + lstm.bias_hh_l0.copy_(bias_hh) + + # Optimize the layout + lstm.flatten_parameters() + + # After prepare_lstm_weights: + # Force all LSTM parameters to be on the correct device + for param in lstm.parameters(): + if param.device != device: + param.data = param.data.to(device) + + +def _is_cuda_cudnn_available(): + # We check if the cuda device and drivers are available + return torch.cuda.is_available() and torch.backends.cudnn.is_available() + + +def cudnn_ok( + activation, + recurrent_activation, + unroll, + use_bias=True, +): + from keras.src import activations + from keras.src import ops + + return ( + activation in (activations.tanh, torch.tanh, ops.tanh) + and recurrent_activation + in (activations.sigmoid, torch.sigmoid, ops.sigmoid) # noqa: E501 + and not unroll + and use_bias + and _is_cuda_cudnn_available() + ) + + +def lstm( + inputs, + initial_state_h, + initial_state_c, + mask, + kernel, + recurrent_kernel, + bias, + activation, + recurrent_activation, + return_sequences=False, + go_backwards=False, + unroll=False, + batch_first=True, +): + cudnn_supported = cudnn_ok( + activation, + recurrent_activation, + unroll, + use_bias=bias is not None, + ) + + if not cudnn_supported: + raise NotImplementedError + + # Get device from inputs + device = get_device() + + from keras.src.backend.torch import Variable + + if isinstance(kernel, Variable): + kernel = kernel.value + if isinstance(recurrent_kernel, Variable): + recurrent_kernel = recurrent_kernel.value + if isinstance(bias, Variable): + bias = bias.value + + # Convert to torch tensors + inputs = convert_to_tensor(inputs, dtype="float32") + initial_state_h = convert_to_tensor(initial_state_h, dtype="float32") + initial_state_c = convert_to_tensor(initial_state_c, dtype="float32") + if mask is not None: + mask = convert_to_tensor(mask, dtype="bool") + + # Preprocess for go_backwards by flipping the sequence + if go_backwards: + seq_dim = 1 if batch_first else 0 + inputs = torch.flip(inputs, dims=[seq_dim]) + if mask is not None: + mask = torch.flip(mask, dims=[seq_dim]) + + # Move all tensors to the same device + inputs = inputs.to(device) + initial_state_h = initial_state_h.to(device) + initial_state_c = initial_state_c.to(device) + if mask is not None: + mask = mask.to(device) + + try: + return _cudnn_lstm( + inputs, + initial_state_h, + initial_state_c, + kernel, + recurrent_kernel, + bias, + mask, + batch_first, + go_backwards, + return_sequences, + device, + ) + except Exception: + raise NotImplementedError + + +def _cudnn_lstm( + inputs, + initial_state_h, + initial_state_c, + kernel, + recurrent_kernel, + bias, + mask, + batch_first, + go_backwards, + return_sequences, + device, +): + if mask is not None: + _assert_valid_mask(mask) + sequence_lengths = _compute_sequence_length_from_mask(mask, batch_first) + + # Ensure inputs are in batch_first format for consistency + if not batch_first: + inputs = inputs.permute(1, 0, 2) + + seq_axis, batch_axis = (0, 1) if not batch_first else (1, 0) + + # If shape is [batch, hidden]; Make [1, batch, hidden] + if initial_state_h.dim() == 2: + initial_state_h = initial_state_h.unsqueeze(0) + initial_state_c = initial_state_c.unsqueeze(0) + # If shape is [batch, 1, hidden] + elif initial_state_h.dim() == 3 and initial_state_h.shape[1] == 1: + initial_state_h = initial_state_h.permute(1, 0, 2) + initial_state_c = initial_state_c.permute(1, 0, 2) + + input_size = kernel.shape[0] + hidden_size = recurrent_kernel.shape[0] + + # Configure LSTM with the provided parameters + lstm = torch.nn.LSTM( + input_size=input_size, + hidden_size=hidden_size, + num_layers=1, + batch_first=batch_first, + bidirectional=False, + ) + + prepare_lstm_weights(lstm, kernel, recurrent_kernel, bias, device) + + if mask is not None: + # Sort and pack + sorted_lengths, sorted_indices = torch.sort( + sequence_lengths, descending=True + ) # noqa: E501 + sorted_inputs = inputs[sorted_indices] + sorted_initial_h = initial_state_h[:, sorted_indices] + sorted_initial_c = initial_state_c[:, sorted_indices] + + # Create the packed sequence + packed_inputs = torch.nn.utils.rnn.pack_padded_sequence( + sorted_inputs, sorted_lengths.cpu(), batch_first + ) + + # Process with LSTM (which handles the packed sequence correctly) + packed_outputs, (h_n, c_n) = lstm( + packed_inputs, (sorted_initial_h, sorted_initial_c) + ) + + # Unpack back to padded tensor + outputs, _ = torch.nn.utils.rnn.pad_packed_sequence( + packed_outputs, batch_first + ) # noqa: E501 + + else: + # Run LSTM without packing for fixed-length sequences + outputs, (h_n, c_n) = lstm(inputs, (initial_state_h, initial_state_c)) + + outputs = outputs.detach().clone().cpu() + h_n = h_n.detach().clone().cpu() + c_n = c_n.detach().clone().cpu() + # Reshape hidden states for return + h_n = h_n.squeeze(batch_axis) + c_n = c_n.squeeze(batch_axis) + + # Return appropriate outputs based on return_sequences flag + + if mask is not None: + last_output = h_n + else: + last_output = outputs[:, -1] if batch_first else outputs[-1] + + if not return_sequences: + outputs = ( + last_output.unsqueeze(1) + if batch_first + else last_output.unsqueeze(0) + ) # noqa: E501 + + if go_backwards and return_sequences: + outputs = torch.flip(outputs, dims=[seq_axis]) + + return last_output, outputs, [h_n, c_n] + + +def gru(*args, **kwargs): + raise NotImplementedError diff --git a/keras/src/backend/torch/trainer.py b/keras/src/backend/torch/trainer.py new file mode 100644 index 000000000000..ad68c2f3a7ec --- /dev/null +++ b/keras/src/backend/torch/trainer.py @@ -0,0 +1,518 @@ +import warnings + +import numpy as np +import torch +from packaging.version import parse + +from keras.src import backend +from keras.src import callbacks as callbacks_module +from keras.src import optimizers as optimizers_module +from keras.src import tree +from keras.src.backend import config +from keras.src.trainers import trainer as base_trainer +from keras.src.trainers.data_adapters import array_slicing +from keras.src.trainers.data_adapters import data_adapter_utils +from keras.src.trainers.epoch_iterator import EpochIterator +from keras.src.utils import traceback_utils + + +class TorchTrainer(base_trainer.Trainer): + def __init__(self): + super().__init__() + self.train_function = None + self.test_function = None + self.predict_function = None + + def _should_torch_compile(self): + # require torch>=2.1.0 to enable dynamo since it + # includes many improvements/fixes to torch.compile() + # TODO eventually we want to get rid of this when + # torch is upgraded to >=2.1 (from 2.0.1) in g3 + if self.jit_compile and parse(torch.__version__) < parse("2.1.0"): + warnings.warn( + "Please upgrade to torch>=2.1.0 for `jit_compile=True` " + "to take effect. Using `jit_compile=False`" + ) + self.jit_compile = False + + return self.jit_compile + + def train_step(self, data): + x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(data) + + # Compute predictions + if self._call_has_training_arg: + y_pred = self(x, training=True) + else: + y_pred = self(x) + + # Call torch.nn.Module.zero_grad() to clear the leftover gradients + # for the weights from the previous train step. + self.zero_grad() + + loss = self._compute_loss( + x=x, y=y, y_pred=y_pred, sample_weight=sample_weight, training=True + ) + self._loss_tracker.update_state( + loss, + sample_weight=next( + i for i in tree.flatten(x) if i is not None + ).shape[0], + ) + if self.optimizer is not None: + loss = self.optimizer.scale_loss(loss) + + # Compute gradients + if self.trainable_weights: + # Call torch.Tensor.backward() on the loss to compute gradients + # for the weights. + loss.backward() + + trainable_weights = self.trainable_weights[:] + gradients = [v.value.grad for v in trainable_weights] + + # Update weights + with torch.no_grad(): + self.optimizer.apply(gradients, trainable_weights) + else: + warnings.warn("The model does not have any trainable weights.") + + return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight) + + def test_step(self, data): + ( + x, + y, + sample_weight, + ) = data_adapter_utils.unpack_x_y_sample_weight(data) + if self._call_has_training_arg: + y_pred = self(x, training=False) + else: + y_pred = self(x) + loss = self._compute_loss( + x=x, y=y, y_pred=y_pred, sample_weight=sample_weight, training=False + ) + self._loss_tracker.update_state( + loss, + sample_weight=next( + i for i in tree.flatten(x) if i is not None + ).shape[0], + ) + return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight) + + def predict_step(self, data): + x, _, _ = data_adapter_utils.unpack_x_y_sample_weight(data) + if self._call_has_training_arg: + y_pred = self(x, training=False) + else: + y_pred = self(x) + return y_pred + + def make_train_function(self, force=False): + if self.train_function is not None and not force: + return self.train_function + + if self.steps_per_execution > 1: + raise ValueError( + "`steps_per_execution` must be 1 with the PyTorch backend. " + f"Received: steps_per_execution={self.steps_per_execution}" + ) + + def one_step_on_data(data): + """Runs a single training step on a batch of data.""" + data = data[0] + return self.train_step(data) + + if self._should_torch_compile(): + self.train_function = torch.compile(one_step_on_data) + else: + self.train_function = one_step_on_data + + def make_test_function(self, force=False): + if self.test_function is not None and not force: + return self.test_function + + if self.steps_per_execution > 1: + raise ValueError( + "`steps_per_execution` must be 1 with the PyTorch backend. " + f"Received: steps_per_execution={self.steps_per_execution}" + ) + + def one_step_on_data(data): + """Runs a single test step on a batch of data.""" + data = data[0] + with torch.no_grad(): + return self.test_step(data) + + if self._should_torch_compile(): + self.test_function = torch.compile(one_step_on_data) + else: + self.test_function = one_step_on_data + + def make_predict_function(self, force=False): + if self.predict_function is not None and not force: + return self.predict_function + + if self.steps_per_execution > 1: + raise ValueError( + "`steps_per_execution` must be 1 with the PyTorch backend. " + f"Received: steps_per_execution={self.steps_per_execution}" + ) + + def one_step_on_data(data): + """Runs a predict test step on a batch of data.""" + data = data[0] + with torch.no_grad(): + return self.predict_step(data) + + if self._should_torch_compile(): + self.predict_function = torch.compile(one_step_on_data) + else: + self.predict_function = one_step_on_data + + @traceback_utils.filter_traceback + def fit( + self, + x=None, + y=None, + batch_size=None, + epochs=1, + verbose="auto", + callbacks=None, + validation_split=0.0, + validation_data=None, + shuffle=True, + class_weight=None, + sample_weight=None, + initial_epoch=0, + steps_per_epoch=None, + validation_steps=None, + validation_batch_size=None, + validation_freq=1, + ): + if not self.compiled: + raise ValueError( + "You must call `compile()` before calling `fit()`." + ) + # Possibly cap epochs for debugging runs. + max_epochs = config.max_epochs() + if max_epochs and max_epochs < epochs: + warnings.warn("Limiting epochs to %d" % max_epochs) + epochs = max_epochs + + # TODO: respect compiled trainable state + self._eval_epoch_iterator = None + if validation_split and validation_data is None: + # Create the validation data using the training data. Only supported + # for TF/numpy/jax arrays. + # TODO: Support torch tensors for validation data. + ( + (x, y, sample_weight), + validation_data, + ) = array_slicing.train_validation_split( + (x, y, sample_weight), validation_split=validation_split + ) + + if validation_data is not None: + ( + val_x, + val_y, + val_sample_weight, + ) = data_adapter_utils.unpack_x_y_sample_weight(validation_data) + + # Create an iterator that yields batches for one epoch. + epoch_iterator = TorchEpochIterator( + x=x, + y=y, + sample_weight=sample_weight, + batch_size=batch_size, + steps_per_epoch=steps_per_epoch, + shuffle=shuffle, + class_weight=class_weight, + steps_per_execution=self.steps_per_execution, + ) + + self._symbolic_build(iterator=epoch_iterator) + epoch_iterator.reset() + + # Container that configures and calls callbacks. + if not isinstance(callbacks, callbacks_module.CallbackList): + callbacks = callbacks_module.CallbackList( + callbacks, + add_history=True, + add_progbar=verbose != 0, + verbose=verbose, + epochs=epochs, + steps=epoch_iterator.num_batches, + model=self, + ) + + self.stop_training = False + training_logs = {} + self.make_train_function() + callbacks.on_train_begin() + initial_epoch = self._initial_epoch or initial_epoch + for epoch in range(initial_epoch, epochs): + self.reset_metrics() + callbacks.on_epoch_begin(epoch) + + # Switch the torch Module to training mode. Inform torch layers to + # do training behavior in case the user did not use `self.training` + # when implementing a custom layer with torch layers. + self.train() + + logs = {} + for begin_step, end_step, data in epoch_iterator: + # Callbacks + callbacks.on_train_batch_begin(begin_step) + + logs = self.train_function(data) + + # Callbacks + callbacks.on_train_batch_end(end_step, logs) + if self.stop_training: + break + + # Override with model metrics instead of last step logs if needed. + epoch_logs = dict(self._get_metrics_result_or_logs(logs)) + + # Switch the torch Module back to testing mode. + self.eval() + + # Run validation. + if validation_data is not None and self._should_eval( + epoch, validation_freq + ): + # Create TorchEpochIterator for evaluation and cache it. + if getattr(self, "_eval_epoch_iterator", None) is None: + self._eval_epoch_iterator = TorchEpochIterator( + x=val_x, + y=val_y, + sample_weight=val_sample_weight, + batch_size=validation_batch_size or batch_size, + steps_per_execution=self.steps_per_execution, + steps_per_epoch=validation_steps, + shuffle=False, + ) + val_logs = self.evaluate( + x=val_x, + y=val_y, + sample_weight=val_sample_weight, + batch_size=validation_batch_size or batch_size, + steps=validation_steps, + callbacks=callbacks, + return_dict=True, + _use_cached_eval_dataset=True, + ) + val_logs = { + f"val_{name}": val for name, val in val_logs.items() + } + epoch_logs.update(val_logs) + + callbacks.on_epoch_end(epoch, epoch_logs) + training_logs = epoch_logs + if self.stop_training: + break + + if ( + isinstance(self.optimizer, optimizers_module.Optimizer) + and epochs > 0 + ): + self.optimizer.finalize_variable_values(self.trainable_weights) + + # If _eval_epoch_iterator exists, delete it after all epochs are done. + if getattr(self, "_eval_epoch_iterator", None) is not None: + del self._eval_epoch_iterator + callbacks.on_train_end(logs=training_logs) + return self.history + + @traceback_utils.filter_traceback + def evaluate( + self, + x=None, + y=None, + batch_size=None, + verbose="auto", + sample_weight=None, + steps=None, + callbacks=None, + return_dict=False, + **kwargs, + ): + # TODO: respect compiled trainable state + use_cached_eval_dataset = kwargs.pop("_use_cached_eval_dataset", False) + if kwargs: + raise ValueError(f"Arguments not recognized: {kwargs}") + + if use_cached_eval_dataset: + epoch_iterator = self._eval_epoch_iterator + else: + # Create an iterator that yields batches of input/target data. + epoch_iterator = TorchEpochIterator( + x=x, + y=y, + sample_weight=sample_weight, + batch_size=batch_size, + steps_per_epoch=steps, + shuffle=False, + steps_per_execution=self.steps_per_execution, + ) + + self._symbolic_build(iterator=epoch_iterator) + epoch_iterator.reset() + + # Container that configures and calls callbacks. + if not isinstance(callbacks, callbacks_module.CallbackList): + callbacks = callbacks_module.CallbackList( + callbacks, + add_progbar=verbose != 0, + verbose=verbose, + epochs=1, + steps=epoch_iterator.num_batches, + model=self, + ) + + # Switch the torch Module back to testing mode. + self.eval() + + self.make_test_function() + self.stop_evaluating = False + callbacks.on_test_begin() + logs = {} + self.reset_metrics() + for begin_step, end_step, data in epoch_iterator: + callbacks.on_test_batch_begin(begin_step) + logs = self.test_function(data) + callbacks.on_test_batch_end(end_step, logs) + if self.stop_evaluating: + break + logs = self._get_metrics_result_or_logs(logs) + callbacks.on_test_end(logs) + + if return_dict: + return logs + return self._flatten_metrics_in_order(logs) + + @traceback_utils.filter_traceback + def predict( + self, x, batch_size=None, verbose="auto", steps=None, callbacks=None + ): + # Create an iterator that yields batches of input data. + epoch_iterator = TorchEpochIterator( + x=x, + batch_size=batch_size, + steps_per_epoch=steps, + shuffle=False, + steps_per_execution=self.steps_per_execution, + ) + + # Container that configures and calls callbacks. + if not isinstance(callbacks, callbacks_module.CallbackList): + callbacks = callbacks_module.CallbackList( + callbacks, + add_progbar=verbose != 0, + verbose=verbose, + epochs=1, + steps=epoch_iterator.num_batches, + model=self, + ) + + def append_to_outputs(batch_outputs, outputs): + if outputs is None: + outputs = tree.map_structure( + lambda batch_output: [batch_output], + batch_outputs, + ) + else: + tree.map_structure_up_to( + batch_outputs, + lambda output, batch_output: output.append(batch_output), + outputs, + batch_outputs, + ) + return outputs + + # Switch the torch Module back to testing mode. + self.eval() + + self.make_predict_function() + self.stop_predicting = False + callbacks.on_predict_begin() + outputs = None + for begin_step, end_step, data in epoch_iterator: + callbacks.on_predict_batch_begin(begin_step) + batch_outputs = self.predict_function(data) + outputs = append_to_outputs(batch_outputs, outputs) + callbacks.on_predict_batch_end(end_step, {"outputs": batch_outputs}) + if self.stop_predicting: + break + callbacks.on_predict_end() + outputs = tree.map_structure(backend.convert_to_numpy, outputs) + return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs) + + def train_on_batch( + self, + x, + y=None, + sample_weight=None, + class_weight=None, + return_dict=False, + ): + self._assert_compile_called("train_on_batch") + if class_weight is not None: + if sample_weight is not None: + raise ValueError( + "Arguments `sample_weight` and `class_weight` " + "cannot be specified at the same time. " + f"Received: sample_weight={sample_weight}, " + f"class_weight={class_weight}" + ) + sample_weight = data_adapter_utils.class_weight_to_sample_weights( + y, class_weight + ) + + data = (x, y, sample_weight) + + # Maybe build model + self._symbolic_build(data_batch=data) + self.make_train_function() + + logs = self.train_function([data]) + logs = tree.map_structure(lambda x: np.array(x), logs) + if return_dict: + return logs + return self._flatten_metrics_in_order(logs) + + def test_on_batch( + self, + x, + y=None, + sample_weight=None, + return_dict=False, + ): + self._assert_compile_called("test_on_batch") + + data = (x, y, sample_weight) + + # Maybe build model + self._symbolic_build(data_batch=data) + self.make_test_function() + + logs = self.test_function([data]) + logs = tree.map_structure(lambda x: np.array(x), logs) + if return_dict: + return logs + return self._flatten_metrics_in_order(logs) + + def predict_on_batch(self, x): + self.make_predict_function() + batch_outputs = self.predict_function([(x,)]) + batch_outputs = tree.map_structure( + backend.convert_to_numpy, batch_outputs + ) + return batch_outputs + + +class TorchEpochIterator(EpochIterator): + def _get_iterator(self): + return self.data_adapter.get_torch_dataloader() diff --git a/keras/src/callbacks/__init__.py b/keras/src/callbacks/__init__.py new file mode 100644 index 000000000000..427c4f6da95f --- /dev/null +++ b/keras/src/callbacks/__init__.py @@ -0,0 +1,16 @@ +from keras.src.callbacks.backup_and_restore import BackupAndRestore +from keras.src.callbacks.callback import Callback +from keras.src.callbacks.callback_list import CallbackList +from keras.src.callbacks.csv_logger import CSVLogger +from keras.src.callbacks.early_stopping import EarlyStopping +from keras.src.callbacks.history import History +from keras.src.callbacks.lambda_callback import LambdaCallback +from keras.src.callbacks.learning_rate_scheduler import LearningRateScheduler +from keras.src.callbacks.model_checkpoint import ModelCheckpoint +from keras.src.callbacks.monitor_callback import MonitorCallback +from keras.src.callbacks.progbar_logger import ProgbarLogger +from keras.src.callbacks.reduce_lr_on_plateau import ReduceLROnPlateau +from keras.src.callbacks.remote_monitor import RemoteMonitor +from keras.src.callbacks.swap_ema_weights import SwapEMAWeights +from keras.src.callbacks.tensorboard import TensorBoard +from keras.src.callbacks.terminate_on_nan import TerminateOnNaN diff --git a/keras/src/callbacks/backup_and_restore.py b/keras/src/callbacks/backup_and_restore.py new file mode 100644 index 000000000000..55053cc43640 --- /dev/null +++ b/keras/src/callbacks/backup_and_restore.py @@ -0,0 +1,210 @@ +import json + +from keras.src.api_export import keras_export +from keras.src.callbacks.callback import Callback +from keras.src.utils import file_utils + + +@keras_export("keras.callbacks.BackupAndRestore") +class BackupAndRestore(Callback): + """Callback to back up and restore the training state. + + `BackupAndRestore` callback is intended to recover training from an + interruption that has happened in the middle of a `Model.fit` execution, by + backing up the training states in a temporary checkpoint file, at the end of + each epoch. Each backup overwrites the previously written checkpoint file, + so at any given time there is at most one such checkpoint file for + backup/restoring purpose. + + If training restarts before completion, the training state (which includes + the `Model` weights and epoch number) is restored to the most recently saved + state at the beginning of a new `Model.fit` run. At the completion of a + `Model.fit` run, the temporary checkpoint file is deleted. + + Note that the user is responsible to bring jobs back after the interruption. + This callback is important for the backup and restore mechanism for fault + tolerance purpose, and the model to be restored from a previous checkpoint + is expected to be the same as the one used to back up. If user changes + arguments passed to compile or fit, the checkpoint saved for fault tolerance + can become invalid. + + Example: + + >>> class InterruptingCallback(keras.callbacks.Callback): + ... def on_epoch_begin(self, epoch, logs=None): + ... if epoch == 4: + ... raise RuntimeError('Interrupting!') + >>> callback = keras.callbacks.BackupAndRestore(backup_dir="/tmp/backup") + >>> model = keras.models.Sequential([keras.layers.Dense(10)]) + >>> model.compile(keras.optimizers.SGD(), loss='mse') + >>> model.build(input_shape=(None, 20)) + >>> try: + ... model.fit(np.arange(100).reshape(5, 20), np.zeros(5), epochs=10, + ... batch_size=1, callbacks=[callback, InterruptingCallback()], + ... verbose=0) + ... except: + ... pass + >>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5), + ... epochs=10, batch_size=1, callbacks=[callback], + ... verbose=0) + >>> # Only 6 more epochs are run, since first training got interrupted at + >>> # zero-indexed epoch 4, second training will continue from 4 to 9. + >>> len(history.history['loss']) + >>> 6 + + Args: + backup_dir: String, path of directory where to store the data + needed to restore the model. The directory + cannot be reused elsewhere to store other files, e.g. by the + `BackupAndRestore` callback of another training run, + or by another callback (e.g. `ModelCheckpoint`) + of the same training run. + save_freq: `"epoch"`, integer, or `False`. When set to `"epoch"` + the callback saves the checkpoint at the end of each epoch. + When set to an integer, the callback saves the checkpoint every + `save_freq` batches. Set `save_freq=False` only if using + preemption checkpointing (i.e. with `save_before_preemption=True`). + double_checkpoint: Boolean. If enabled, `BackupAndRestore` callback + will save 2 last training states (current and previous). After + interruption if current state can't be loaded due to IO error + (e.g. file corrupted) it will try to restore previous one. Such + behaviour will consume twice more space on disk, but increase fault + tolerance. Defaults to `False`. + delete_checkpoint: Boolean. This `BackupAndRestore` + callback works by saving a checkpoint to back up the training state. + If `delete_checkpoint=True`, the checkpoint will be deleted after + training is finished. Use `False` if you'd like to keep the checkpoint + for future usage. Defaults to `True`. + """ + + def __init__( + self, + backup_dir, + save_freq="epoch", + double_checkpoint=False, + delete_checkpoint=True, + ): + super().__init__() + self.save_freq = save_freq + self.double_checkpoint = double_checkpoint + self.delete_checkpoint = delete_checkpoint + self._batches_seen_since_last_saving = 0 + self._last_batch_seen = 0 + self._current_epoch = 0 + + if not backup_dir: + raise ValueError("Empty `backup_dir` argument passed") + self.backup_dir = backup_dir + self._weights_path = file_utils.join(backup_dir, "latest.weights.h5") + self._training_metadata_path = file_utils.join( + backup_dir, "training_metadata.json" + ) + self._prev_weights_path = f"{self._weights_path}.bkp" + self._prev_training_metadata_path = ( + f"{self._training_metadata_path}.bkp" + ) + if save_freq != "epoch" and not isinstance(save_freq, int): + raise ValueError( + "Invalid value for argument `save_freq`. " + f"Received: save_freq={save_freq}. " + "Expected either 'epoch' or an integer value." + ) + + def on_train_begin(self, logs=None): + try: + self._load_model() + except OSError as e: + # Weights may be corrupted. Trying to load previous one. + if not file_utils.exists(self._prev_weights_path): + raise e + file_utils.copy(self._prev_weights_path, self._weights_path) + if file_utils.exists(self._prev_training_metadata_path): + file_utils.copy( + self._prev_training_metadata_path, + self._training_metadata_path, + ) + elif file_utils.exists(self._training_metadata_path): + file_utils.remove(self._training_metadata_path) + self._load_model() + + def _load_model(self): + """Get training state from temporary file and restore it.""" + if not self.model.built: + raise ValueError( + "To use the BackupAndRestore callback, " + "you model must be built before you call `fit()`. " + f"Model {self.model} is unbuilt. You can build it " + "beforehand by calling it on a batch of data." + ) + if file_utils.exists(self._weights_path): + if ( + self.model.optimizer is not None + and not self.model.optimizer.built + ): + # Make sure optimizer weights exist before loading. + self.model.optimizer.build(self.model.trainable_variables) + self.model.load_weights(self._weights_path) + + if file_utils.exists(self._training_metadata_path): + with file_utils.File(self._training_metadata_path, "r") as f: + training_metadata = json.loads(f.read()) + epoch = training_metadata["epoch"] + self.model._initial_epoch = epoch + + def on_epoch_end(self, epoch, logs=None): + self._current_epoch = epoch + 1 + self._last_batch_seen = 0 + if self.save_freq == "epoch": + self._save_model() + + def on_train_batch_end(self, batch, logs=None): + if self._should_save_on_batch(batch): + self._save_model() + + def _save_model(self): + """Saves the model. + + Args: + epoch: the epoch this iteration is in. + batch: the batch this iteration is in. `None` if the `save_freq` + is set to `"epoch"`. + logs: the `logs` dict passed in to `on_batch_end` or `on_epoch_end`. + """ + # Create host directory if it doesn't exist. + if not file_utils.exists(self.backup_dir): + file_utils.makedirs(self.backup_dir) + if self.double_checkpoint and file_utils.exists(self._weights_path): + file_utils.copy(self._weights_path, self._prev_weights_path) + if self.double_checkpoint and file_utils.exists( + self._training_metadata_path + ): + file_utils.copy( + self._training_metadata_path, self._prev_training_metadata_path + ) + self.model.save_weights(filepath=self._weights_path, overwrite=True) + with file_utils.File(self._training_metadata_path, "w") as f: + training_metadata = { + "epoch": self._current_epoch, + "batch": self._last_batch_seen, + } + f.write(json.dumps(training_metadata)) + + def _should_save_on_batch(self, batch): + """Handles batch-level saving logic, supports steps_per_execution.""" + if self.save_freq == "epoch": + return False + if batch <= self._last_batch_seen: # New epoch. + add_batches = batch + 1 # batches are zero-indexed. + else: + add_batches = batch - self._last_batch_seen + self._batches_seen_since_last_saving += add_batches + self._last_batch_seen = batch + + if self._batches_seen_since_last_saving >= self.save_freq: + self._batches_seen_since_last_saving = 0 + return True + return False + + def on_train_end(self, logs=None): + if self.delete_checkpoint and file_utils.exists(self.backup_dir): + file_utils.rmtree(self.backup_dir) diff --git a/keras/src/callbacks/backup_and_restore_test.py b/keras/src/callbacks/backup_and_restore_test.py new file mode 100644 index 000000000000..cde8dd87eb82 --- /dev/null +++ b/keras/src/callbacks/backup_and_restore_test.py @@ -0,0 +1,231 @@ +import numpy as np +import pytest + +from keras.src import callbacks +from keras.src import layers +from keras.src import testing +from keras.src.models import Sequential +from keras.src.utils import file_utils + + +class InterruptingCallback(callbacks.Callback): + """A callback to intentionally interrupt training.""" + + def __init__(self, steps_int, epoch_int): + self.batch_count = 0 + self.epoch_count = 0 + self.steps_int = steps_int + self.epoch_int = epoch_int + + def on_epoch_end(self, epoch, log=None): + self.epoch_count += 1 + if self.epoch_int is not None and self.epoch_count == self.epoch_int: + raise RuntimeError("EpochInterruption") + + def on_batch_end(self, batch, logs=None): + self.batch_count += 1 + if self.steps_int is not None and self.batch_count == self.steps_int: + raise RuntimeError("StepsInterruption") + + +class CanaryLayer(layers.Layer): + def __init__(self): + super().__init__() + self.counter = self.add_weight( + shape=(), initializer="zeros", dtype="float32", trainable=False + ) + + def call(self, x): + self.counter.assign_add(1) + return x + + +class BackupAndRestoreCallbackTest(testing.TestCase): + def make_model(self): + model = Sequential( + [ + layers.Input((3,)), + CanaryLayer(), + layers.Dense(1), + ] + ) + model.compile( + loss="mse", + optimizer="sgd", + metrics=["mse"], + ) + return model + + # Check invalid save_freq, both string and non integer + def test_save_freq_unknown_error(self): + with self.assertRaisesRegex(ValueError, expected_regex="Invalid value"): + callbacks.BackupAndRestore( + backup_dir="backup_dir", save_freq="batch" + ) + + with self.assertRaisesRegex(ValueError, expected_regex="Invalid value"): + callbacks.BackupAndRestore(backup_dir="backup_dir", save_freq=0.15) + + # Checking if after interruption, correct model params and + # weights are loaded in step-wise backup + @pytest.mark.requires_trainable_backend + def test_best_case_step(self): + temp_dir = self.get_temp_dir() + backup_dir = file_utils.join(temp_dir, "subdir") + self.assertFalse(file_utils.exists(backup_dir)) + + model = self.make_model() + cbk = callbacks.BackupAndRestore(backup_dir, save_freq=1) + + x_train = np.random.random((10, 3)) + y_train = np.random.random((10, 1)) + + try: + model.fit( + x_train, + y_train, + batch_size=4, + callbacks=[ + cbk, + InterruptingCallback(steps_int=2, epoch_int=None), + ], + epochs=2, + verbose=0, + ) + except RuntimeError: + self.assertTrue(file_utils.exists(backup_dir)) + self.assertEqual(cbk._current_epoch, 0) + self.assertEqual(cbk._last_batch_seen, 1) + self.assertEqual(int(model.layers[0].counter.value), 2) + + hist = model.fit( + x_train, y_train, batch_size=4, callbacks=[cbk], epochs=5 + ) + + self.assertEqual(cbk._current_epoch, 5) + self.assertEqual(hist.epoch[-1], 4) + self.assertEqual(int(model.layers[0].counter.value), 17) + + # Checking if after interruption, correct model params and + # weights are loaded in epoch-wise backup + @pytest.mark.requires_trainable_backend + def test_best_case_epoch(self): + temp_dir = self.get_temp_dir() + backup_dir = file_utils.join(temp_dir, "subdir") + self.assertFalse(file_utils.exists(backup_dir)) + + model = self.make_model() + self.assertEqual(int(model.layers[0].counter.value), 0) + cbk = callbacks.BackupAndRestore( + backup_dir=backup_dir, save_freq="epoch" + ) + + x_train = np.random.random((10, 3)) + y_train = np.random.random((10, 1)) + + try: + model.fit( + x_train, + y_train, + batch_size=4, + callbacks=[ + cbk, + InterruptingCallback(steps_int=None, epoch_int=2), + ], + epochs=6, + verbose=0, + ) + except RuntimeError: + self.assertEqual(cbk._current_epoch, 2) + self.assertTrue(file_utils.exists(backup_dir)) + self.assertEqual(int(model.layers[0].counter.value), 6) + + hist = model.fit( + x_train, y_train, batch_size=4, callbacks=[cbk], epochs=5 + ) + self.assertEqual(cbk._current_epoch, 5) + self.assertEqual(hist.epoch[-1], 4) + self.assertEqual(int(model.layers[0].counter.value), 5 * 3) + + # Checking if after interruption and weights corruption, previous model + # params and weights are loaded + @pytest.mark.requires_trainable_backend + def test_backup_corrupted(self): + temp_dir = self.get_temp_dir() + backup_dir = file_utils.join(temp_dir, "subdir") + self.assertFalse(file_utils.exists(backup_dir)) + + model = self.make_model() + self.assertEqual(int(model.layers[0].counter.value), 0) + cbk = callbacks.BackupAndRestore( + backup_dir=backup_dir, save_freq="epoch", double_checkpoint=True + ) + + x_train = np.random.random((10, 3)) + y_train = np.random.random((10, 1)) + + try: + model.fit( + x_train, + y_train, + batch_size=4, + callbacks=[ + cbk, + InterruptingCallback(steps_int=None, epoch_int=2), + ], + epochs=6, + verbose=0, + ) + except RuntimeError: + self.assertEqual(cbk._current_epoch, 2) + self.assertTrue(file_utils.exists(backup_dir)) + self.assertTrue(file_utils.exists(cbk._weights_path)) + self.assertTrue(file_utils.exists(cbk._training_metadata_path)) + self.assertTrue(file_utils.exists(cbk._prev_weights_path)) + self.assertTrue(file_utils.exists(cbk._prev_training_metadata_path)) + self.assertEqual(int(model.layers[0].counter.value), 6) + + # Corruption weights + with file_utils.File(cbk._weights_path, "w") as f: + f.write("0") + + hist = model.fit( + x_train, y_train, batch_size=4, callbacks=[cbk], epochs=5 + ) + self.assertEqual(cbk._current_epoch, 5) + self.assertEqual(hist.epoch[-1], 4) + self.assertEqual(int(model.layers[0].counter.value), 5 * 3) + + # Checking if after interruption, when model is deleted + @pytest.mark.requires_trainable_backend + def test_model_deleted_case_epoch(self): + temp_dir = self.get_temp_dir() + backup_dir = file_utils.join(temp_dir, "subdir") + self.assertFalse(file_utils.exists(backup_dir)) + + model = self.make_model() + cbk = callbacks.BackupAndRestore(backup_dir, save_freq="epoch") + + x_train = np.random.random((10, 3)) + y_train = np.random.random((10, 1)) + model.fit( + x_train, + y_train, + batch_size=4, + callbacks=[cbk], + epochs=2, + verbose=0, + ) + self.assertFalse(file_utils.exists(backup_dir)) + + def test_backup_dir_empty_error(self): + with self.assertRaisesRegex( + ValueError, expected_regex="Empty `backup_dir` argument passed" + ): + callbacks.BackupAndRestore(backup_dir="", save_freq="epoch") + + def test_backup_dir_none_error(self): + with self.assertRaisesRegex( + ValueError, expected_regex="Empty `backup_dir` argument passed" + ): + callbacks.BackupAndRestore(backup_dir=None, save_freq="epoch") diff --git a/keras/src/callbacks/callback.py b/keras/src/callbacks/callback.py new file mode 100644 index 000000000000..f3f359657394 --- /dev/null +++ b/keras/src/callbacks/callback.py @@ -0,0 +1,308 @@ +from keras.src import backend +from keras.src import utils +from keras.src.api_export import keras_export + + +@keras_export("keras.callbacks.Callback") +class Callback: + """Base class used to build new callbacks. + + Callbacks can be passed to keras methods such as `fit()`, `evaluate()`, and + `predict()` in order to hook into the various stages of the model training, + evaluation, and inference lifecycle. + + To create a custom callback, subclass `keras.callbacks.Callback` and + override the method associated with the stage of interest. + + Example: + + >>> training_finished = False + >>> class MyCallback(Callback): + ... def on_train_end(self, logs=None): + ... global training_finished + ... training_finished = True + >>> model = Sequential([ + ... layers.Dense(1, input_shape=(1,))]) + >>> model.compile(loss='mean_squared_error') + >>> model.fit(np.array([[1.0]]), np.array([[1.0]]), + ... callbacks=[MyCallback()]) + >>> assert training_finished == True + + If you want to use `Callback` objects in a custom training loop: + + 1. You should pack all your callbacks into a single `callbacks.CallbackList` + so they can all be called together. + 2. You will need to manually call all the `on_*` methods at the appropriate + locations in your loop. Like this: + + Example: + + ```python + callbacks = keras.callbacks.CallbackList([...]) + callbacks.append(...) + callbacks.on_train_begin(...) + for epoch in range(EPOCHS): + callbacks.on_epoch_begin(epoch) + for i, data in dataset.enumerate(): + callbacks.on_train_batch_begin(i) + batch_logs = model.train_step(data) + callbacks.on_train_batch_end(i, batch_logs) + epoch_logs = ... + callbacks.on_epoch_end(epoch, epoch_logs) + final_logs=... + callbacks.on_train_end(final_logs) + ``` + + Attributes: + params: Dict. Training parameters + (eg. verbosity, batch size, number of epochs...). + model: Instance of `Model`. + Reference of the model being trained. + + The `logs` dictionary that callback methods + take as argument will contain keys for quantities relevant to + the current batch or epoch (see method-specific docstrings). + """ + + def __init__(self): + self.params = None + self._model = None + + def set_params(self, params): + self.params = params + + def set_model(self, model): + self._model = model + + @property + def model(self): + if backend.backend() == "torch": + from torch.nn.parallel import DistributedDataParallel + + if isinstance(self._model, DistributedDataParallel): + # Keras Callbacks expect to work with Keras models. e.g + # ModelCheckpoint and EarlyStopping both attempt to call + # keras-specific APIs on the value returned from this + # property. If this callback was created against a DDP + # wrapper instead of the underlying keras.Model, it is + # likely to fail. Return self._model.module for DDP + # instances instead. + return self._model.module + + if backend.backend() == "jax" and hasattr( + self._model, "jax_state_sync" + ): + # With JAX, by default the model state is not + # attached to the model in the middle of an + # epoch. We have to force a sync before + # accessing model state for e.g. checkpointing. + self._model.jax_state_sync() + return self._model + + @utils.default + def on_batch_begin(self, batch, logs=None): + """A backwards compatibility alias for `on_train_batch_begin`.""" + + @utils.default + def on_batch_end(self, batch, logs=None): + """A backwards compatibility alias for `on_train_batch_end`.""" + + @utils.default + def on_epoch_begin(self, epoch, logs=None): + """Called at the start of an epoch. + + Subclasses should override for any actions to run. This function should + only be called during TRAIN mode. + + Args: + epoch: Integer, index of epoch. + logs: Dict. Currently no data is passed to this argument for this + method but that may change in the future. + """ + + @utils.default + def on_epoch_end(self, epoch, logs=None): + """Called at the end of an epoch. + + Subclasses should override for any actions to run. This function should + only be called during TRAIN mode. + + Args: + epoch: Integer, index of epoch. + logs: Dict, metric results for this training epoch, and for the + validation epoch if validation is performed. Validation result + keys are prefixed with `val_`. For training epoch, the values of + the `Model`'s metrics are returned. Example: + `{'loss': 0.2, 'accuracy': 0.7}`. + """ + + @utils.default + def on_train_batch_begin(self, batch, logs=None): + """Called at the beginning of a training batch in `fit` methods. + + Subclasses should override for any actions to run. + + Note that if the `steps_per_execution` argument to `compile` in + `Model` is set to `N`, this method will only be called every + `N` batches. + + Args: + batch: Integer, index of batch within the current epoch. + logs: Dict. Currently no data is passed to this argument for this + method but that may change in the future. + """ + # For backwards compatibility. + self.on_batch_begin(batch, logs=logs) + + @utils.default + def on_train_batch_end(self, batch, logs=None): + """Called at the end of a training batch in `fit` methods. + + Subclasses should override for any actions to run. + + Note that if the `steps_per_execution` argument to `compile` in + `Model` is set to `N`, this method will only be called every + `N` batches. + + Args: + batch: Integer, index of batch within the current epoch. + logs: Dict. Aggregated metric results up until this batch. + """ + # For backwards compatibility. + self.on_batch_end(batch, logs=logs) + + @utils.default + def on_test_batch_begin(self, batch, logs=None): + """Called at the beginning of a batch in `evaluate` methods. + + Also called at the beginning of a validation batch in the `fit` + methods, if validation data is provided. + + Subclasses should override for any actions to run. + + Note that if the `steps_per_execution` argument to `compile` in + `Model` is set to `N`, this method will only be called every + `N` batches. + + Args: + batch: Integer, index of batch within the current epoch. + logs: Dict. Currently no data is passed to this argument for this + method but that may change in the future. + """ + + @utils.default + def on_test_batch_end(self, batch, logs=None): + """Called at the end of a batch in `evaluate` methods. + + Also called at the end of a validation batch in the `fit` + methods, if validation data is provided. + + Subclasses should override for any actions to run. + + Note that if the `steps_per_execution` argument to `compile` in + `Model` is set to `N`, this method will only be called every + `N` batches. + + Args: + batch: Integer, index of batch within the current epoch. + logs: Dict. Aggregated metric results up until this batch. + """ + + @utils.default + def on_predict_batch_begin(self, batch, logs=None): + """Called at the beginning of a batch in `predict` methods. + + Subclasses should override for any actions to run. + + Note that if the `steps_per_execution` argument to `compile` in + `Model` is set to `N`, this method will only be called every + `N` batches. + + Args: + batch: Integer, index of batch within the current epoch. + logs: Dict. Currently no data is passed to this argument for this + method but that may change in the future. + """ + + @utils.default + def on_predict_batch_end(self, batch, logs=None): + """Called at the end of a batch in `predict` methods. + + Subclasses should override for any actions to run. + + Note that if the `steps_per_execution` argument to `compile` in + `Model` is set to `N`, this method will only be called every + `N` batches. + + Args: + batch: Integer, index of batch within the current epoch. + logs: Dict. Aggregated metric results up until this batch. + """ + + @utils.default + def on_train_begin(self, logs=None): + """Called at the beginning of training. + + Subclasses should override for any actions to run. + + Args: + logs: Dict. Currently no data is passed to this argument for this + method but that may change in the future. + """ + + @utils.default + def on_train_end(self, logs=None): + """Called at the end of training. + + Subclasses should override for any actions to run. + + Args: + logs: Dict. Currently the output of the last call to + `on_epoch_end()` is passed to this argument for this method but + that may change in the future. + """ + + @utils.default + def on_test_begin(self, logs=None): + """Called at the beginning of evaluation or validation. + + Subclasses should override for any actions to run. + + Args: + logs: Dict. Currently no data is passed to this argument for this + method but that may change in the future. + """ + + @utils.default + def on_test_end(self, logs=None): + """Called at the end of evaluation or validation. + + Subclasses should override for any actions to run. + + Args: + logs: Dict. Currently the output of the last call to + `on_test_batch_end()` is passed to this argument for this method + but that may change in the future. + """ + + @utils.default + def on_predict_begin(self, logs=None): + """Called at the beginning of prediction. + + Subclasses should override for any actions to run. + + Args: + logs: Dict. Currently no data is passed to this argument for this + method but that may change in the future. + """ + + @utils.default + def on_predict_end(self, logs=None): + """Called at the end of prediction. + + Subclasses should override for any actions to run. + + Args: + logs: Dict. Currently no data is passed to this argument for this + method but that may change in the future. + """ diff --git a/keras/src/callbacks/callback_list.py b/keras/src/callbacks/callback_list.py new file mode 100644 index 000000000000..e020154ccd41 --- /dev/null +++ b/keras/src/callbacks/callback_list.py @@ -0,0 +1,281 @@ +import concurrent.futures + +from keras.src import backend +from keras.src import tree +from keras.src import utils +from keras.src.api_export import keras_export +from keras.src.callbacks.callback import Callback +from keras.src.callbacks.history import History +from keras.src.callbacks.progbar_logger import ProgbarLogger +from keras.src.utils import python_utils + + +@keras_export("keras.callbacks.CallbackList") +class CallbackList(Callback): + """Container abstracting a list of callbacks.""" + + def __init__( + self, + callbacks=None, + add_history=False, + add_progbar=False, + model=None, + **params, + ): + """Container for `Callback` instances. + + This object wraps a list of `Callback` instances, making it possible + to call them all at once via a single endpoint + (e.g. `callback_list.on_epoch_end(...)`). + + Args: + callbacks: List of `Callback` instances. + add_history: Whether a `History` callback should be added, if one + does not already exist in the `callbacks` list. + add_progbar: Whether a `ProgbarLogger` callback should be added, if + one does not already exist in the `callbacks` list. + model: The `Model` these callbacks are used with. + **params: If provided, parameters will be passed to each `Callback` + via `Callback.set_params`. + """ + self.callbacks = tree.flatten(callbacks) if callbacks else [] + self._in_begin_end_block_count = 0 + self._executor = None + self._async_train = False + self._async_test = False + self._async_predict = False + self._futures = [] + self._configure_async_dispatch(callbacks) + self._add_default_callbacks(add_history, add_progbar) + self.set_model(model) + self.set_params(params) + + def set_params(self, params): + self.params = params + if params: + for callback in self.callbacks: + callback.set_params(params) + + def _configure_async_dispatch(self, callbacks): + # Determine whether callbacks can be dispatched asynchronously. + if not backend.IS_THREAD_SAFE: + return + async_train = True + async_test = True + async_predict = True + if callbacks: + if isinstance(callbacks, (list, tuple)): + for cbk in callbacks: + if getattr(cbk, "async_safe", False): + # Callbacks that expose self.async_safe == True + # will be assumed safe for async dispatch. + continue + if not utils.is_default(cbk.on_batch_end): + async_train = False + if not utils.is_default(cbk.on_train_batch_end): + async_train = False + if not utils.is_default(cbk.on_test_batch_end): + async_test = False + if not utils.is_default(cbk.on_predict_batch_end): + async_predict = False + + self._async_train = async_train + self._async_test = async_test + self._async_predict = async_predict + + def _add_default_callbacks(self, add_history, add_progbar): + """Adds `Callback`s that are always present.""" + self._progbar = None + self._history = None + + for cb in self.callbacks: + if isinstance(cb, ProgbarLogger): + self._progbar = cb + elif isinstance(cb, History): + self._history = cb + + if self._history is None and add_history: + self._history = History() + self.callbacks.append(self._history) + + if self._progbar is None and add_progbar: + self._progbar = ProgbarLogger() + self.callbacks.append(self._progbar) + + def set_model(self, model): + if not model: + return + super().set_model(model) + if self._history: + model.history = self._history + for callback in self.callbacks: + callback.set_model(model) + + def _on_begin(self): + """Called by `on_train/test/predict_begin`. + + Start the executor for async calls if needed. + """ + self._in_begin_end_block_count += 1 + if ( + self._in_begin_end_block_count == 1 + and (self._async_train or self._async_test or self._async_predict) + and self._executor is None + ): + self._executor = concurrent.futures.ThreadPoolExecutor() + + def _on_end(self): + """Called by `on_train/test/predict_end`. + + Shutdown the executor for async calls if all begin/end blocks completed. + """ + self._in_begin_end_block_count -= 1 + if self._in_begin_end_block_count < 0: + raise ValueError( + "`on_xxx_end` called without corresponding `on_xxx_begin`" + ) + if self._in_begin_end_block_count == 0 and self._executor is not None: + self._executor.shutdown() + self._executor = None + + def _async_dispatch(self, fn, *args): + for future in self._futures: + if future.done(): + future.result() + self._futures.remove(future) + future = self._executor.submit(fn, *args) + self._futures.append(future) + + def _flush_futures(self): + """Waits for all futures to complete and clears the list.""" + for future in self._futures: + future.result() + self._futures = [] + + def on_batch_begin(self, batch, logs=None): + logs = python_utils.pythonify_logs(logs) + for callback in self.callbacks: + callback.on_batch_begin(batch, logs=logs) + + def on_epoch_begin(self, epoch, logs=None): + logs = python_utils.pythonify_logs(logs) + for callback in self.callbacks: + callback.on_epoch_begin(epoch, logs) + + def on_epoch_end(self, epoch, logs=None): + if self._async_train: + self._flush_futures() + + logs = python_utils.pythonify_logs(logs) + for callback in self.callbacks: + callback.on_epoch_end(epoch, logs) + + def on_train_batch_begin(self, batch, logs=None): + logs = python_utils.pythonify_logs(logs) + for callback in self.callbacks: + callback.on_train_batch_begin(batch, logs=logs) + + def on_test_batch_begin(self, batch, logs=None): + logs = python_utils.pythonify_logs(logs) + for callback in self.callbacks: + callback.on_test_batch_begin(batch, logs=logs) + + def on_predict_batch_begin(self, batch, logs=None): + logs = python_utils.pythonify_logs(logs) + for callback in self.callbacks: + callback.on_predict_batch_begin(batch, logs=logs) + + def on_batch_end(self, batch, logs=None): + if self._async_train: + self._async_dispatch(self._on_batch_end, batch, logs) + else: + self._on_batch_end(batch, logs) + + def on_train_batch_end(self, batch, logs=None): + if self._async_train: + self._async_dispatch(self._on_train_batch_end, batch, logs) + else: + self._on_train_batch_end(batch, logs) + + def on_test_batch_end(self, batch, logs=None): + if self._async_test: + self._async_dispatch(self._on_test_batch_end, batch, logs) + else: + self._on_test_batch_end(batch, logs) + + def on_predict_batch_end(self, batch, logs=None): + if self._async_predict: + self._async_dispatch(self._on_predict_batch_end, batch, logs) + else: + self._on_predict_batch_end(batch, logs) + + def _on_batch_end(self, batch, logs=None): + logs = python_utils.pythonify_logs(logs) + for callback in self.callbacks: + callback.on_batch_end(batch, logs=logs) + + def _on_train_batch_end(self, batch, logs=None): + logs = python_utils.pythonify_logs(logs) + for callback in self.callbacks: + callback.on_train_batch_end(batch, logs=logs) + + def _on_test_batch_end(self, batch, logs=None): + logs = python_utils.pythonify_logs(logs) + for callback in self.callbacks: + callback.on_test_batch_end(batch, logs=logs) + + def _on_predict_batch_end(self, batch, logs=None): + logs = python_utils.pythonify_logs(logs) + for callback in self.callbacks: + callback.on_predict_batch_end(batch, logs=logs) + + def on_train_begin(self, logs=None): + self._on_begin() + + logs = python_utils.pythonify_logs(logs) + for callback in self.callbacks: + callback.on_train_begin(logs) + + def on_train_end(self, logs=None): + if self._async_train: + self._flush_futures() + + logs = python_utils.pythonify_logs(logs) + for callback in self.callbacks: + callback.on_train_end(logs) + + self._on_end() + + def on_test_begin(self, logs=None): + self._on_begin() + + logs = python_utils.pythonify_logs(logs) + for callback in self.callbacks: + callback.on_test_begin(logs) + + def on_test_end(self, logs=None): + if self._async_test: + self._flush_futures() + + logs = python_utils.pythonify_logs(logs) + for callback in self.callbacks: + callback.on_test_end(logs) + + self._on_end() + + def on_predict_begin(self, logs=None): + self._on_begin() + + logs = python_utils.pythonify_logs(logs) + for callback in self.callbacks: + callback.on_predict_begin(logs) + + def on_predict_end(self, logs=None): + if self._async_predict: + self._flush_futures() + + logs = python_utils.pythonify_logs(logs) + for callback in self.callbacks: + callback.on_predict_end(logs) + + self._on_end() diff --git a/keras/src/callbacks/callback_test.py b/keras/src/callbacks/callback_test.py new file mode 100644 index 000000000000..31c77c904ceb --- /dev/null +++ b/keras/src/callbacks/callback_test.py @@ -0,0 +1,31 @@ +import numpy as np +import pytest + +from keras.src import models +from keras.src import testing +from keras.src.callbacks.callback import Callback + + +class CallbackTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_model_state_is_current_on_epoch_end(self): + class TestModel(models.Model): + def __init__(self): + super().__init__() + self.iterations = self.add_variable( + shape=(), initializer="zeros", trainable=False + ) + + def call(self, inputs): + self.iterations.assign(self.iterations + 1) + return inputs + + class CBK(Callback): + def on_batch_end(self, batch, logs): + assert np.int32(self.model.iterations) == batch + 1 + + model = TestModel() + model.compile(optimizer="sgd", loss="mse") + x = np.random.random((8, 1)) + y = np.random.random((8, 1)) + model.fit(x, y, callbacks=[CBK()], batch_size=2) diff --git a/keras/src/callbacks/csv_logger.py b/keras/src/callbacks/csv_logger.py new file mode 100644 index 000000000000..88dbeadb158f --- /dev/null +++ b/keras/src/callbacks/csv_logger.py @@ -0,0 +1,107 @@ +import collections +import csv + +import numpy as np + +from keras.src.api_export import keras_export +from keras.src.callbacks.callback import Callback +from keras.src.utils import file_utils + + +@keras_export("keras.callbacks.CSVLogger") +class CSVLogger(Callback): + """Callback that streams epoch results to a CSV file. + + Supports all values that can be represented as a string, + including 1D iterables such as `np.ndarray`. + + Args: + filename: Filename of the CSV file, e.g. `'run/log.csv'`. + separator: String used to separate elements in the CSV file. + append: Boolean. True: append if file exists (useful for continuing + training). False: overwrite existing file. + + Example: + + ```python + csv_logger = CSVLogger('training.log') + model.fit(X_train, Y_train, callbacks=[csv_logger]) + ``` + """ + + def __init__(self, filename, separator=",", append=False): + super().__init__() + self.sep = separator + self.filename = file_utils.path_to_string(filename) + self.append = append + self.writer = None + self.keys = None + self.append_header = True + self.csv_file = None + + def on_train_begin(self, logs=None): + if self.append: + if file_utils.exists(self.filename): + with file_utils.File(self.filename, "r") as f: + self.append_header = not bool(len(f.readline())) + mode = "a" + else: + mode = "w" + # ensure csv_file is None or closed before reassigning + if self.csv_file and not self.csv_file.closed: + self.csv_file.close() + self.csv_file = file_utils.File(self.filename, mode) + # Reset writer and keys + self.writer = None + self.keys = None + + def on_epoch_end(self, epoch, logs=None): + logs = logs or {} + + def handle_value(k): + is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0 + if isinstance(k, str): + return k + elif ( + isinstance(k, collections.abc.Iterable) + and not is_zero_dim_ndarray + ): + return f'"[{", ".join(map(str, k))}]"' + else: + return k + + if self.keys is None: + self.keys = sorted(logs.keys()) + + val_keys_found = False + for key in self.keys: + if key.startswith("val_"): + val_keys_found = True + break + if not val_keys_found and self.keys: + self.keys.extend([f"val_{k}" for k in self.keys]) + + if not self.writer: + + class CustomDialect(csv.excel): + delimiter = self.sep + + fieldnames = ["epoch"] + (self.keys or []) + + self.writer = csv.DictWriter( + self.csv_file, fieldnames=fieldnames, dialect=CustomDialect + ) + if self.append_header: + self.writer.writeheader() + + row_dict = collections.OrderedDict({"epoch": epoch}) + row_dict.update( + (key, handle_value(logs.get(key, "NA"))) for key in self.keys + ) + self.writer.writerow(row_dict) + self.csv_file.flush() + + def on_train_end(self, logs=None): + if self.csv_file and not self.csv_file.closed: + self.csv_file.close() + self.writer = None diff --git a/keras/src/callbacks/csv_logger_test.py b/keras/src/callbacks/csv_logger_test.py new file mode 100644 index 000000000000..9da3be6aaa53 --- /dev/null +++ b/keras/src/callbacks/csv_logger_test.py @@ -0,0 +1,179 @@ +import csv +import os +import re +import tempfile + +import numpy as np +import pytest + +from keras.src import callbacks +from keras.src import initializers +from keras.src import layers +from keras.src import testing +from keras.src.models import Sequential +from keras.src.utils import numerical_utils + +TRAIN_SAMPLES = 10 +TEST_SAMPLES = 10 +INPUT_DIM = 3 +BATCH_SIZE = 4 + + +class CSVLoggerTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_CSVLogger(self): + OUTPUT_DIM = 1 + np.random.seed(1337) + temp_dir = tempfile.TemporaryDirectory() + filepath = os.path.join(temp_dir.name, "log.tsv") + + sep = "\t" + x_train = np.random.random((TRAIN_SAMPLES, INPUT_DIM)) + y_train = np.random.random((TRAIN_SAMPLES, OUTPUT_DIM)) + x_test = np.random.random((TEST_SAMPLES, INPUT_DIM)) + y_test = np.random.random((TEST_SAMPLES, OUTPUT_DIM)) + + def make_model(): + np.random.seed(1337) + model = Sequential( + [ + layers.Dense(2, activation="relu"), + layers.Dense(OUTPUT_DIM), + ] + ) + model.compile( + loss="mse", + optimizer="sgd", + metrics=["mse"], + ) + return model + + # case 1, create new file with defined separator + model = make_model() + cbks = [callbacks.CSVLogger(filepath, separator=sep)] + model.fit( + x_train, + y_train, + batch_size=BATCH_SIZE, + validation_data=(x_test, y_test), + callbacks=cbks, + epochs=1, + verbose=0, + ) + + assert os.path.exists(filepath) + with open(filepath) as csvfile: + dialect = csv.Sniffer().sniff(csvfile.read()) + assert dialect.delimiter == sep + del model + del cbks + + # case 2, append data to existing file, skip header + model = make_model() + cbks = [callbacks.CSVLogger(filepath, separator=sep, append=True)] + model.fit( + x_train, + y_train, + batch_size=BATCH_SIZE, + validation_data=(x_test, y_test), + callbacks=cbks, + epochs=1, + verbose=0, + ) + + # case 3, reuse of CSVLogger object + model.fit( + x_train, + y_train, + batch_size=BATCH_SIZE, + validation_data=(x_test, y_test), + callbacks=cbks, + epochs=2, + verbose=0, + ) + + with open(filepath) as csvfile: + list_lines = csvfile.readlines() + for line in list_lines: + assert line.count(sep) == 4 + assert len(list_lines) == 5 + output = " ".join(list_lines) + assert len(re.findall("epoch", output)) == 1 + + os.remove(filepath) + + # case 3, Verify Val. loss also registered when Validation Freq > 1 + model = make_model() + cbks = [callbacks.CSVLogger(filepath, separator=sep)] + hist = model.fit( + x_train, + y_train, + batch_size=BATCH_SIZE, + validation_data=(x_test, y_test), + validation_freq=3, + callbacks=cbks, + epochs=5, + verbose=0, + ) + assert os.path.exists(filepath) + # Verify that validation loss is registered at val. freq + with open(filepath) as csvfile: + rows = csv.DictReader(csvfile, delimiter=sep) + for idx, row in enumerate(rows, 1): + self.assertIn("val_loss", row) + if idx == 3: + self.assertEqual( + row["val_loss"], str(hist.history["val_loss"][0]) + ) + else: + self.assertEqual(row["val_loss"], "NA") + + @pytest.mark.requires_trainable_backend + def test_stop_training_csv(self): + # Test that using the CSVLogger callback with the TerminateOnNaN + # callback does not result in invalid CSVs. + tmpdir = tempfile.TemporaryDirectory() + csv_logfile = os.path.join(tmpdir.name, "csv_logger.csv") + NUM_CLASSES = 2 + np.random.seed(1337) + x_train = np.random.random((TRAIN_SAMPLES, INPUT_DIM)) + y_train = np.random.choice(np.arange(NUM_CLASSES), size=TRAIN_SAMPLES) + x_test = np.random.random((TEST_SAMPLES, INPUT_DIM)) + y_test = np.random.choice(np.arange(NUM_CLASSES), size=TEST_SAMPLES) + + y_test = numerical_utils.to_categorical(y_test) + y_train = numerical_utils.to_categorical(y_train) + model = Sequential() + initializer = initializers.Constant(value=1e5) + for _ in range(5): + model.add( + layers.Dense( + 2, + activation="relu", + kernel_initializer=initializer, + ) + ) + model.add(layers.Dense(NUM_CLASSES)) + model.compile(loss="mean_squared_error", optimizer="sgd") + + history = model.fit( + x_train, + y_train, + batch_size=BATCH_SIZE, + validation_data=(x_test, y_test), + callbacks=[ + callbacks.TerminateOnNaN(), + callbacks.CSVLogger(csv_logfile), + ], + epochs=20, + ) + loss = history.history["loss"] + self.assertEqual(len(loss), 1) + self.assertTrue(np.isnan(loss[0]) or np.isinf(loss[0])) + + values = [] + with open(csv_logfile) as f: + # On Windows, due to \r\n line ends, we may end up reading empty + # lines after each line. Skip empty lines. + values = [x for x in csv.reader(f) if x] + self.assertIn("nan", values[-1], "NaN not logged in CSV Logger.") diff --git a/keras/src/callbacks/early_stopping.py b/keras/src/callbacks/early_stopping.py new file mode 100644 index 000000000000..30fef26b8d9e --- /dev/null +++ b/keras/src/callbacks/early_stopping.py @@ -0,0 +1,154 @@ +import warnings + +from keras.src.api_export import keras_export +from keras.src.callbacks.monitor_callback import MonitorCallback +from keras.src.utils import io_utils + + +@keras_export("keras.callbacks.EarlyStopping") +class EarlyStopping(MonitorCallback): + """Stop training when a monitored metric has stopped improving. + + Assuming the goal of a training is to minimize the loss. With this, the + metric to be monitored would be `'loss'`, and mode would be `'min'`. A + `model.fit()` training loop will check at end of every epoch whether + the loss is no longer decreasing, considering the `min_delta` and + `patience` if applicable. Once it's found no longer decreasing, + `model.stop_training` is marked True and the training terminates. + + The quantity to be monitored needs to be available in `logs` dict. + To make it so, pass the loss or metrics at `model.compile()`. + + Args: + monitor: Quantity to be monitored. Defaults to `"val_loss"`. + min_delta: Minimum change in the monitored quantity to qualify as an + improvement, i.e. an absolute change of less than min_delta, will + count as no improvement. Defaults to `0`. + patience: Number of epochs with no improvement after which training will + be stopped. Defaults to `0`. + verbose: Verbosity mode, 0 or 1. Mode 0 is silent, and mode 1 displays + messages when the callback takes an action. Defaults to `0`. + mode: One of `{"auto", "min", "max"}`. In `min` mode, training will stop + when the quantity monitored has stopped decreasing; in `"max"` mode + it will stop when the quantity monitored has stopped increasing; in + `"auto"` mode, the direction is automatically inferred from the name + of the monitored quantity. Defaults to `"auto"`. + baseline: Baseline value for the monitored quantity. If not `None`, + training will stop if the model doesn't show improvement over the + baseline. Defaults to `None`. + restore_best_weights: Whether to restore model weights from the epoch + with the best value of the monitored quantity. If `False`, the model + weights obtained at the last step of training are used. An epoch + will be restored regardless of the performance relative to the + `baseline`. If no epoch improves on `baseline`, training will run + for `patience` epochs and restore weights from the best epoch in + that set. Defaults to `False`. + start_from_epoch: Number of epochs to wait before starting to monitor + improvement. This allows for a warm-up period in which no + improvement is expected and thus training will not be stopped. + Defaults to `0`. + + Example: + + >>> callback = keras.callbacks.EarlyStopping(monitor='loss', + ... patience=3) + >>> # This callback will stop the training when there is no improvement in + >>> # the loss for three consecutive epochs. + >>> model = keras.models.Sequential([keras.layers.Dense(10)]) + >>> model.compile(keras.optimizers.SGD(), loss='mse') + >>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5), + ... epochs=10, batch_size=1, callbacks=[callback], + ... verbose=0) + >>> len(history.history['loss']) # Only 4 epochs are run. + 4 + """ + + def __init__( + self, + monitor="val_loss", + min_delta=0, + patience=0, + verbose=0, + mode="auto", + baseline=None, + restore_best_weights=False, + start_from_epoch=0, + ): + super().__init__(monitor, mode, min_delta=min_delta) + self.patience = patience + self.verbose = verbose + self.baseline = baseline + self.wait = 0 + self.stopped_epoch = 0 + self.restore_best_weights = restore_best_weights + self.best_weights = None + self.start_from_epoch = start_from_epoch + + def on_train_begin(self, logs=None): + # Allow instances to be re-used + self.wait = 0 + self.stopped_epoch = 0 + self.best_weights = None + self.best_epoch = 0 + + def on_epoch_end(self, epoch, logs=None): + if self.monitor_op is None: + # Delay setup until the model's metrics are all built + self._set_monitor_op() + + current = self.get_monitor_value(logs) + if current is None or epoch < self.start_from_epoch: + # If no monitor value exists or still in initial warm-up stage. + return + if self.restore_best_weights and self.best_weights is None: + # If best weights were never set, + # then the current weights are the best. + self.best_weights = self.model.get_weights() + self.best_epoch = epoch + + self.wait += 1 + if self._is_improvement(current, self.best): + self.best = current + self.best_epoch = epoch + if self.restore_best_weights: + self.best_weights = self.model.get_weights() + # Only restart wait if we beat both the baseline and our previous + # best. + if self.baseline is None or self._is_improvement( + current, self.baseline + ): + self.wait = 0 + return + + if self.wait >= self.patience and epoch > 0: + # Patience has been exceeded: stop training + self.stopped_epoch = epoch + self.model.stop_training = True + + def on_train_end(self, logs=None): + if self.stopped_epoch > 0 and self.verbose > 0: + io_utils.print_msg( + f"Epoch {self.stopped_epoch + 1}: early stopping" + ) + if self.restore_best_weights and self.best_weights is not None: + if self.verbose > 0: + io_utils.print_msg( + "Restoring model weights from " + "the end of the best epoch: " + f"{self.best_epoch + 1}." + ) + self.model.set_weights(self.best_weights) + + def get_monitor_value(self, logs): + logs = logs or {} + monitor_value = logs.get(self.monitor) + if monitor_value is None: + warnings.warn( + ( + f"Early stopping conditioned on metric `{self.monitor}` " + "which is not available. " + f"Available metrics are: {','.join(list(logs.keys()))}" + ), + stacklevel=2, + ) + return monitor_value diff --git a/keras/src/callbacks/early_stopping_test.py b/keras/src/callbacks/early_stopping_test.py new file mode 100644 index 000000000000..d4b127675e7b --- /dev/null +++ b/keras/src/callbacks/early_stopping_test.py @@ -0,0 +1,271 @@ +import numpy as np +import pytest + +from keras.src import callbacks +from keras.src import layers +from keras.src import metrics +from keras.src import models +from keras.src import ops +from keras.src import testing + + +class EarlyStoppingTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_early_stopping(self): + x_train = np.random.random((10, 5)) + y_train = np.random.random((10, 1)) + x_test = np.random.random((10, 5)) + y_test = np.random.random((10, 1)) + model = models.Sequential( + ( + layers.Dense(1, activation="relu"), + layers.Dense(1, activation="relu"), + ) + ) + model.compile( + loss="mae", + optimizer="adam", + metrics=[ + "mse", + "acc", + "accuracy", + "hinge", + metrics.F1Score(name="f1_score"), + ], + ) + + cases = [ + ("max", "val_mse", "max"), + ("min", "val_loss", "min"), + ("auto", "val_mse", "min"), + ("auto", "loss", "min"), + ("auto", "acc", "max"), + ("auto", "val_accuracy", "max"), + ("auto", "hinge", "min"), + ("auto", "f1_score", "max"), + ] + for mode, monitor, expected_mode in cases: + patience = 0 + cbks = [ + callbacks.EarlyStopping( + patience=patience, monitor=monitor, mode=mode + ) + ] + model.fit( + x_train, + y_train, + batch_size=5, + validation_data=(x_test, y_test), + callbacks=cbks, + epochs=2, + verbose=0, + ) + if expected_mode == "max": + monitor_op = ops.greater + else: + monitor_op = ops.less + self.assertEqual(cbks[0].monitor_op, monitor_op) + + with self.assertRaises(ValueError): + cbks = [ + callbacks.EarlyStopping(patience=patience, monitor="unknown") + ] + model.fit( + x_train, + y_train, + batch_size=5, + validation_data=(x_test, y_test), + callbacks=cbks, + epochs=2, + verbose=0, + ) + + @pytest.mark.requires_trainable_backend + def test_early_stopping_patience(self): + cases = [0, 1, 2, 3] + losses = [10.0, 9.0, 8.0, 9.0, 8.9, 8.8, 8.7, 8.6, 8.5] + + for patience in cases: + stopper = callbacks.EarlyStopping(monitor="loss", patience=patience) + stopper.set_model(models.Sequential()) + stopper.model.compile(loss="mse", optimizer="sgd") + stopper.on_train_begin() + + for epoch, loss in enumerate(losses): + stopper.on_epoch_end(epoch=epoch, logs={"loss": loss}) + if stopper.model.stop_training: + break + + self.assertEqual(stopper.stopped_epoch, max(patience, 1) + 2) + + @pytest.mark.requires_trainable_backend + def test_early_stopping_reuse(self): + patience = 3 + data = np.random.random((100, 1)) + labels = np.where(data > 0.5, 1, 0) + model = models.Sequential( + ( + layers.Dense(1, activation="relu"), + layers.Dense(1, activation="relu"), + ) + ) + model.compile( + optimizer="sgd", + loss="mae", + metrics=["mse"], + ) + stopper = callbacks.EarlyStopping(monitor="mse", patience=patience) + + history1 = model.fit( + data, labels, callbacks=[stopper], verbose=0, epochs=20 + ) + self.assertGreaterEqual(len(history1.epoch), patience) + + history2 = model.fit( + data, labels, callbacks=[stopper], verbose=0, epochs=20 + ) + self.assertGreaterEqual(len(history2.epoch), patience) + + @pytest.mark.requires_trainable_backend + def test_early_stopping_with_baseline(self): + baseline = 0.6 + x_train = np.random.random((10, 5)) + y_train = np.random.random((10, 1)) + model = models.Sequential( + ( + layers.Dense(1, activation="relu"), + layers.Dense(1, activation="relu"), + ) + ) + model.compile(optimizer="sgd", loss="mae", metrics=["mse"]) + + patience = 3 + stopper = callbacks.EarlyStopping( + monitor="mse", patience=patience, baseline=baseline + ) + hist = model.fit( + x_train, y_train, callbacks=[stopper], verbose=0, epochs=20 + ) + assert len(hist.epoch) >= patience + + def test_early_stopping_final_weights_when_restoring_model_weights(self): + class DummyModel: + def __init__(self): + self.stop_training = False + self.weights = -1 + + def get_weights(self): + return self.weights + + def set_weights(self, weights): + self.weights = weights + + def set_weight_to_epoch(self, epoch): + self.weights = epoch + + early_stop = callbacks.EarlyStopping( + monitor="val_loss", patience=2, restore_best_weights=True + ) + early_stop.set_model(DummyModel()) + losses = [0.2, 0.15, 0.1, 0.11, 0.12] + # The best configuration is in the epoch 2 (loss = 0.1000). + epochs_trained = 0 + early_stop.on_train_begin() + for epoch in range(len(losses)): + epochs_trained += 1 + early_stop.model.set_weight_to_epoch(epoch=epoch) + early_stop.on_epoch_end(epoch, logs={"val_loss": losses[epoch]}) + if early_stop.model.stop_training: + break + early_stop.on_train_end() + # The best configuration is in epoch 2 (loss = 0.1000), + # and while patience = 2, we're restoring the best weights, + # so we end up at the epoch with the best weights, i.e. epoch 2 + self.assertEqual(early_stop.model.get_weights(), 2) + + # Check early stopping when no model beats the baseline. + early_stop = callbacks.EarlyStopping( + monitor="val_loss", + patience=5, + baseline=0.5, + restore_best_weights=True, + ) + early_stop.set_model(DummyModel()) + losses = [0.9, 0.8, 0.7, 0.71, 0.72, 0.73] + # The best configuration is in the epoch 2 (loss = 0.7000). + epochs_trained = 0 + early_stop.on_train_begin() + for epoch in range(len(losses)): + epochs_trained += 1 + early_stop.model.set_weight_to_epoch(epoch=epoch) + early_stop.on_epoch_end(epoch, logs={"val_loss": losses[epoch]}) + if early_stop.model.stop_training: + break + early_stop.on_train_end() + # No epoch improves on the baseline, so we should train for only 5 + # epochs, and restore the second model. + self.assertEqual(epochs_trained, 5) + self.assertEqual(early_stop.model.get_weights(), 2) + + # Check weight restoration when another callback requests a stop. + early_stop = callbacks.EarlyStopping( + monitor="val_loss", + patience=5, + baseline=0.5, + restore_best_weights=True, + ) + early_stop.set_model(DummyModel()) + losses = [0.9, 0.8, 0.7, 0.71, 0.72, 0.73] + # The best configuration is in the epoch 2 (loss = 0.7000). + epochs_trained = 0 + early_stop.on_train_begin() + for epoch in range(len(losses)): + epochs_trained += 1 + early_stop.model.set_weight_to_epoch(epoch=epoch) + early_stop.on_epoch_end(epoch, logs={"val_loss": losses[epoch]}) + if epoch == 3: + early_stop.model.stop_training = True + if early_stop.model.stop_training: + break + early_stop.on_train_end() + # We should restore the second model. + self.assertEqual(epochs_trained, 4) + self.assertEqual(early_stop.model.get_weights(), 2) + + @pytest.mark.requires_trainable_backend + def test_early_stopping_with_start_from_epoch(self): + x_train = np.random.random((10, 5)) + y_train = np.random.random((10, 1)) + model = models.Sequential( + ( + layers.Dense(1, activation="relu"), + layers.Dense(1, activation="relu"), + ) + ) + model.compile(optimizer="sgd", loss="mae", metrics=["mse"]) + start_from_epoch = 2 + patience = 3 + stopper = callbacks.EarlyStopping( + monitor="mse", + patience=patience, + start_from_epoch=start_from_epoch, + ) + history = model.fit( + x_train, y_train, callbacks=[stopper], verbose=0, epochs=20 + ) + # Test 'patience' argument functions correctly when used + # in conjunction with 'start_from_epoch'. + self.assertGreaterEqual(len(history.epoch), patience + start_from_epoch) + + start_from_epoch = 2 + patience = 0 + stopper = callbacks.EarlyStopping( + monitor="mse", + patience=patience, + start_from_epoch=start_from_epoch, + ) + history = model.fit( + x_train, y_train, callbacks=[stopper], verbose=0, epochs=20 + ) + # Test for boundary condition when 'patience' = 0. + self.assertGreaterEqual(len(history.epoch), start_from_epoch) diff --git a/keras/src/callbacks/history.py b/keras/src/callbacks/history.py new file mode 100644 index 000000000000..6fb3c3c86171 --- /dev/null +++ b/keras/src/callbacks/history.py @@ -0,0 +1,42 @@ +from keras.src.api_export import keras_export +from keras.src.callbacks.callback import Callback + + +@keras_export("keras.callbacks.History") +class History(Callback): + """Callback that records events into a `History` object. + + This callback is automatically applied to + every Keras model. The `History` object + gets returned by the `fit()` method of models. + + Example: + + >>> model = Sequential([layers.Dense(10)]) + >>> model.compile(SGD(), loss='mse') + >>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5), + ... epochs=10, verbose=1) + >>> print(history.params) + {'verbose': 1, 'epochs': 10, 'steps': 1} + >>> # check the keys of history object + >>> print(history.history.keys()) + dict_keys(['loss']) + + """ + + def __init__(self): + super().__init__() + self.history = {} + + def on_train_begin(self, logs=None): + self.epoch = [] + + def on_epoch_end(self, epoch, logs=None): + logs = logs or {} + self.epoch.append(epoch) + for k, v in logs.items(): + self.history.setdefault(k, []).append(v) + + # Set the history attribute on the model after the epoch ends. This will + # make sure that the state which is set is the latest one. + self.model.history = self diff --git a/keras/src/callbacks/lambda_callback.py b/keras/src/callbacks/lambda_callback.py new file mode 100644 index 000000000000..4a391167ef17 --- /dev/null +++ b/keras/src/callbacks/lambda_callback.py @@ -0,0 +1,87 @@ +from keras.src.api_export import keras_export +from keras.src.callbacks.callback import Callback + + +@keras_export("keras.callbacks.LambdaCallback") +class LambdaCallback(Callback): + """Callback for creating simple, custom callbacks on-the-fly. + + This callback is constructed with anonymous functions that will be called + at the appropriate time (during `Model.{fit | evaluate | predict}`). + Note that the callbacks expects positional arguments, as: + + - `on_epoch_begin` and `on_epoch_end` expect two positional arguments: + `epoch`, `logs` + - `on_train_begin` and `on_train_end` expect one positional argument: + `logs` + - `on_train_batch_begin` and `on_train_batch_end` expect a positional + argument `batch` and a keyword argument `logs` + - See `Callback` class definition for the full list of functions and their + expected arguments. + + Args: + on_epoch_begin: called at the beginning of every epoch. + on_epoch_end: called at the end of every epoch. + on_train_begin: called at the beginning of model training. + on_train_end: called at the end of model training. + on_train_batch_begin: called at the beginning of every train batch. + on_train_batch_end: called at the end of every train batch. + kwargs: Any function in `Callback` that you want to override by + passing `function_name=function`. For example, + `LambdaCallback(.., on_train_end=train_end_fn)`. The custom function + needs to have same arguments as the ones defined in `Callback`. + + Example: + + ```python + # Print the batch number at the beginning of every batch. + batch_print_callback = LambdaCallback( + on_train_batch_begin=lambda batch,logs: print(batch)) + + # Stream the epoch loss to a file in JSON format. The file content + # is not well-formed JSON but rather has a JSON object per line. + import json + json_log = open('loss_log.json', mode='wt', buffering=1) + json_logging_callback = LambdaCallback( + on_epoch_end=lambda epoch, logs: json_log.write( + json.dumps({'epoch': epoch, 'loss': logs['loss']}) + '\n'), + on_train_end=lambda logs: json_log.close() + ) + + # Terminate some processes after having finished model training. + processes = ... + cleanup_callback = LambdaCallback( + on_train_end=lambda logs: [ + p.terminate() for p in processes if p.is_alive()]) + + model.fit(..., + callbacks=[batch_print_callback, + json_logging_callback, + cleanup_callback]) + ``` + """ + + def __init__( + self, + on_epoch_begin=None, + on_epoch_end=None, + on_train_begin=None, + on_train_end=None, + on_train_batch_begin=None, + on_train_batch_end=None, + **kwargs, + ): + super().__init__() + self.__dict__.update(kwargs) + if on_epoch_begin is not None: + self.on_epoch_begin = on_epoch_begin + if on_epoch_end is not None: + self.on_epoch_end = on_epoch_end + if on_train_begin is not None: + self.on_train_begin = on_train_begin + if on_train_end is not None: + self.on_train_end = on_train_end + if on_train_batch_begin is not None: + self.on_train_batch_begin = on_train_batch_begin + if on_train_batch_end is not None: + self.on_train_batch_end = on_train_batch_end diff --git a/keras/src/callbacks/lambda_callback_test.py b/keras/src/callbacks/lambda_callback_test.py new file mode 100644 index 000000000000..4c8a6add2146 --- /dev/null +++ b/keras/src/callbacks/lambda_callback_test.py @@ -0,0 +1,165 @@ +import numpy as np +import pytest +from absl import logging + +from keras.src import callbacks +from keras.src import layers +from keras.src import losses +from keras.src import optimizers +from keras.src import testing +from keras.src.models.sequential import Sequential + + +class LambdaCallbackTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_lambda_callback(self): + """Test standard LambdaCallback functionalities with training.""" + batch_size = 4 + model = Sequential( + [layers.Input(shape=(2,), batch_size=batch_size), layers.Dense(1)] + ) + model.compile( + optimizer=optimizers.SGD(), loss=losses.MeanSquaredError() + ) + x = np.random.randn(16, 2) + y = np.random.randn(16, 1) + lambda_log_callback = callbacks.LambdaCallback( + on_train_begin=lambda logs: logging.warning("on_train_begin"), + on_epoch_begin=lambda epoch, logs: logging.warning( + "on_epoch_begin" + ), + on_epoch_end=lambda epoch, logs: logging.warning("on_epoch_end"), + on_train_end=lambda logs: logging.warning("on_train_end"), + ) + with self.assertLogs(level="WARNING") as logs: + model.fit( + x, + y, + batch_size=batch_size, + validation_split=0.2, + callbacks=[lambda_log_callback], + epochs=5, + verbose=0, + ) + self.assertTrue(any("on_train_begin" in log for log in logs.output)) + self.assertTrue(any("on_epoch_begin" in log for log in logs.output)) + self.assertTrue(any("on_epoch_end" in log for log in logs.output)) + self.assertTrue(any("on_train_end" in log for log in logs.output)) + + @pytest.mark.requires_trainable_backend + def test_lambda_callback_with_batches(self): + """Test LambdaCallback's behavior with batch-level callbacks.""" + batch_size = 4 + model = Sequential( + [layers.Input(shape=(2,), batch_size=batch_size), layers.Dense(1)] + ) + model.compile( + optimizer=optimizers.SGD(), loss=losses.MeanSquaredError() + ) + x = np.random.randn(16, 2) + y = np.random.randn(16, 1) + lambda_log_callback = callbacks.LambdaCallback( + on_train_batch_begin=lambda batch, logs: logging.warning( + "on_train_batch_begin" + ), + on_train_batch_end=lambda batch, logs: logging.warning( + "on_train_batch_end" + ), + ) + with self.assertLogs(level="WARNING") as logs: + model.fit( + x, + y, + batch_size=batch_size, + validation_split=0.2, + callbacks=[lambda_log_callback], + epochs=5, + verbose=0, + ) + self.assertTrue( + any("on_train_batch_begin" in log for log in logs.output) + ) + self.assertTrue( + any("on_train_batch_end" in log for log in logs.output) + ) + + @pytest.mark.requires_trainable_backend + def test_lambda_callback_with_kwargs(self): + """Test LambdaCallback's behavior with custom defined callback.""" + batch_size = 4 + model = Sequential( + [layers.Input(shape=(2,), batch_size=batch_size), layers.Dense(1)] + ) + model.compile( + optimizer=optimizers.SGD(), loss=losses.MeanSquaredError() + ) + x = np.random.randn(16, 2) + y = np.random.randn(16, 1) + model.fit( + x, y, batch_size=batch_size, epochs=1, verbose=0 + ) # Train briefly for evaluation to work. + + def custom_on_test_begin(logs): + logging.warning("custom_on_test_begin_executed") + + lambda_log_callback = callbacks.LambdaCallback( + on_test_begin=custom_on_test_begin + ) + with self.assertLogs(level="WARNING") as logs: + model.evaluate( + x, + y, + batch_size=batch_size, + callbacks=[lambda_log_callback], + verbose=0, + ) + self.assertTrue( + any( + "custom_on_test_begin_executed" in log + for log in logs.output + ) + ) + + @pytest.mark.requires_trainable_backend + def test_lambda_callback_no_args(self): + """Test initializing LambdaCallback without any arguments.""" + lambda_callback = callbacks.LambdaCallback() + self.assertIsInstance(lambda_callback, callbacks.LambdaCallback) + + @pytest.mark.requires_trainable_backend + def test_lambda_callback_with_additional_kwargs(self): + """Test initializing LambdaCallback with non-predefined kwargs.""" + + def custom_callback(logs): + pass + + lambda_callback = callbacks.LambdaCallback( + custom_method=custom_callback + ) + self.assertTrue(hasattr(lambda_callback, "custom_method")) + + @pytest.mark.requires_trainable_backend + def test_lambda_callback_during_prediction(self): + """Test LambdaCallback's functionality during model prediction.""" + batch_size = 4 + model = Sequential( + [layers.Input(shape=(2,), batch_size=batch_size), layers.Dense(1)] + ) + model.compile( + optimizer=optimizers.SGD(), loss=losses.MeanSquaredError() + ) + x = np.random.randn(16, 2) + + def custom_on_predict_begin(logs): + logging.warning("on_predict_begin_executed") + + lambda_callback = callbacks.LambdaCallback( + on_predict_begin=custom_on_predict_begin + ) + with self.assertLogs(level="WARNING") as logs: + model.predict( + x, batch_size=batch_size, callbacks=[lambda_callback], verbose=0 + ) + self.assertTrue( + any("on_predict_begin_executed" in log for log in logs.output) + ) diff --git a/keras/src/callbacks/learning_rate_scheduler.py b/keras/src/callbacks/learning_rate_scheduler.py new file mode 100644 index 000000000000..6ac1486e8797 --- /dev/null +++ b/keras/src/callbacks/learning_rate_scheduler.py @@ -0,0 +1,81 @@ +import numpy as np + +from keras.src import backend +from keras.src.api_export import keras_export +from keras.src.callbacks.callback import Callback +from keras.src.utils import io_utils + + +@keras_export("keras.callbacks.LearningRateScheduler") +class LearningRateScheduler(Callback): + """Learning rate scheduler. + + At the beginning of every epoch, this callback gets the updated learning + rate value from `schedule` function provided at `__init__`, with the current + epoch and current learning rate, and applies the updated learning rate on + the optimizer. + + Args: + schedule: A function that takes an epoch index (integer, indexed from 0) + and current learning rate (float) as inputs and returns a new + learning rate as output (float). + verbose: Integer. 0: quiet, 1: log update messages. + + Example: + + >>> # This function keeps the initial learning rate for the first ten epochs + >>> # and decreases it exponentially after that. + >>> def scheduler(epoch, lr): + ... if epoch < 10: + ... return lr + ... else: + ... return lr * ops.exp(-0.1) + >>> + >>> model = keras.models.Sequential([keras.layers.Dense(10)]) + >>> model.compile(keras.optimizers.SGD(), loss='mse') + >>> round(model.optimizer.learning_rate, 5) + 0.01 + + >>> callback = keras.callbacks.LearningRateScheduler(scheduler) + >>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5), + ... epochs=15, callbacks=[callback], verbose=0) + >>> round(model.optimizer.learning_rate, 5) + 0.00607 + + """ + + def __init__(self, schedule, verbose=0): + super().__init__() + self.schedule = schedule + self.verbose = verbose + + def on_epoch_begin(self, epoch, logs=None): + if not hasattr(self.model.optimizer, "learning_rate"): + raise ValueError('Optimizer must have a "learning_rate" attribute.') + + try: # new API + learning_rate = float( + backend.convert_to_numpy(self.model.optimizer.learning_rate) + ) + learning_rate = self.schedule(epoch, learning_rate) + except TypeError: # Support for old API for backward compatibility + learning_rate = self.schedule(epoch) + + if not isinstance(learning_rate, (float, np.float32, np.float64)): + raise ValueError( + "The output of the `schedule` function should be a float. " + f"Got: {learning_rate}" + ) + + self.model.optimizer.learning_rate = learning_rate + if self.verbose > 0: + io_utils.print_msg( + f"\nEpoch {epoch + 1}: LearningRateScheduler setting learning " + f"rate to {learning_rate}." + ) + + def on_epoch_end(self, epoch, logs=None): + logs = logs or {} + logs["learning_rate"] = float( + backend.convert_to_numpy(self.model.optimizer.learning_rate) + ) diff --git a/keras/src/callbacks/learning_rate_scheduler_test.py b/keras/src/callbacks/learning_rate_scheduler_test.py new file mode 100644 index 000000000000..b76bcf8cf3cf --- /dev/null +++ b/keras/src/callbacks/learning_rate_scheduler_test.py @@ -0,0 +1,124 @@ +import pytest + +from keras.src import callbacks +from keras.src import layers +from keras.src import optimizers +from keras.src import testing +from keras.src.models import Sequential +from keras.src.testing import test_utils +from keras.src.utils import io_utils +from keras.src.utils import numerical_utils + + +class LearningRateSchedulerTest(testing.TestCase): + def setUp(self): + (x_train, y_train), _ = test_utils.get_test_data( + train_samples=10, + test_samples=10, + input_shape=(3,), + num_classes=2, + ) + y_train = numerical_utils.to_categorical(y_train) + + model = Sequential([layers.Dense(5), layers.Dense(2)]) + + model.compile( + loss="mse", + optimizer="sgd", + ) + + self.model = model + self.x_train = x_train + self.y_train = y_train + + @pytest.mark.requires_trainable_backend + def test_updates_learning_rate(self): + lr_scheduler = callbacks.LearningRateScheduler( + lambda step: 1.0 / (2.0 + step), verbose=1 + ) + + self.model.fit( + self.x_train, + self.y_train, + callbacks=[lr_scheduler], + epochs=1, + ) + + self.assertEqual(self.model.optimizer.learning_rate.value, 0.5) + + @pytest.mark.requires_trainable_backend + def test_verbose_logging(self): + lr_scheduler = callbacks.LearningRateScheduler( + lambda step: 1.0 / (1.0 + step), verbose=1 + ) + io_utils.disable_interactive_logging() + io_utils.set_logging_verbosity("INFO") + + with self.assertLogs() as logs: + self.model.fit( + self.x_train, + self.y_train, + callbacks=[lr_scheduler], + epochs=1, + ) + expected_log = "LearningRateScheduler setting learning rate to 1.0" + self.assertTrue(any(expected_log in log for log in logs.output)) + + @pytest.mark.requires_trainable_backend + def test_schedule_dependent_on_previous_learning_rate(self): + lr_scheduler = callbacks.LearningRateScheduler(lambda step, lr: lr / 2) + + initial_lr = 0.03 + self.model.compile( + loss="mse", + optimizer=optimizers.Adam(initial_lr), + ) + + self.model.fit( + self.x_train, + self.y_train, + callbacks=[lr_scheduler], + epochs=2, + ) + self.assertEqual( + self.model.optimizer.learning_rate.value, initial_lr / 4.0 + ) + + @pytest.mark.requires_trainable_backend + def test_throws_when_optimizer_has_schedule(self): + lr_scheduler = callbacks.LearningRateScheduler(lambda step, lr: lr / 2) + + self.model.compile( + loss="mse", + optimizer=optimizers.Adam( + optimizers.schedules.PolynomialDecay( + initial_learning_rate=0.1, decay_steps=10 + ) + ), + ) + + with self.assertRaisesRegex( + TypeError, + "This optimizer was created with a `LearningRateSchedule`", + ): + self.model.fit( + self.x_train, + self.y_train, + callbacks=[lr_scheduler], + epochs=2, + ) + + @pytest.mark.requires_trainable_backend + def test_learning_rate_in_history(self): + lr_scheduler = callbacks.LearningRateScheduler(lambda step, lr: 0.5) + + history = self.model.fit( + self.x_train, + self.y_train, + callbacks=[lr_scheduler], + epochs=1, + ) + + self.assertTrue("learning_rate" in history.history) + self.assertEqual(type(history.history["learning_rate"][0]), float) + self.assertEqual(history.history["learning_rate"][0], 0.5) diff --git a/keras/src/callbacks/model_checkpoint.py b/keras/src/callbacks/model_checkpoint.py new file mode 100644 index 000000000000..6143cbfa8fcf --- /dev/null +++ b/keras/src/callbacks/model_checkpoint.py @@ -0,0 +1,412 @@ +import os +import re +import warnings + +import numpy as np + +from keras.src import backend +from keras.src.api_export import keras_export +from keras.src.callbacks.monitor_callback import MonitorCallback +from keras.src.utils import file_utils +from keras.src.utils import io_utils + + +@keras_export("keras.callbacks.ModelCheckpoint") +class ModelCheckpoint(MonitorCallback): + """Callback to save the Keras model or model weights at some frequency. + + `ModelCheckpoint` callback is used in conjunction with training using + `model.fit()` to save a model or weights (in a checkpoint file) at some + interval, so the model or weights can be loaded later to continue the + training from the state saved. + + A few options this callback provides include: + + - Whether to only keep the model that has achieved the "best performance" so + far, or whether to save the model at the end of every epoch regardless of + performance. + - Definition of "best"; which quantity to monitor and whether it should be + maximized or minimized. + - The frequency it should save at. Currently, the callback supports saving + at the end of every epoch, or after a fixed number of training batches. + - Whether only weights are saved, or the whole model is saved. + + Example: + + ```python + model.compile(loss=..., optimizer=..., + metrics=['accuracy']) + + EPOCHS = 10 + checkpoint_filepath = '/tmp/ckpt/checkpoint.model.keras' + model_checkpoint_callback = keras.callbacks.ModelCheckpoint( + filepath=checkpoint_filepath, + monitor='val_accuracy', + mode='max', + save_best_only=True) + + # Model is saved at the end of every epoch, if it's the best seen so far. + model.fit(epochs=EPOCHS, callbacks=[model_checkpoint_callback]) + + # The model (that are considered the best) can be loaded as - + keras.models.load_model(checkpoint_filepath) + + # Alternatively, one could checkpoint just the model weights as - + checkpoint_filepath = '/tmp/ckpt/checkpoint.weights.h5' + model_checkpoint_callback = keras.callbacks.ModelCheckpoint( + filepath=checkpoint_filepath, + save_weights_only=True, + monitor='val_accuracy', + mode='max', + save_best_only=True) + + # Model weights are saved at the end of every epoch, if it's the best seen + # so far. + model.fit(epochs=EPOCHS, callbacks=[model_checkpoint_callback]) + + # The model weights (that are considered the best) can be loaded as - + model.load_weights(checkpoint_filepath) + ``` + + Args: + filepath: string or `PathLike`, path to save the model file. + `filepath` can contain named formatting options, + which will be filled the value of `epoch` and keys in `logs` + (passed in `on_epoch_end`). + The `filepath` name needs to end with `".weights.h5"` when + `save_weights_only=True` or should end with `".keras"` or `".h5"` + when checkpoint saving the whole model (default). + For example: + if `filepath` is `"{epoch:02d}-{val_loss:.2f}.keras"` or + "{epoch:02d}-{val_loss:.2f}.weights.h5"`, then the model + checkpoints will be saved with the epoch number and the validation + loss in the filename. The directory of the filepath + should not be reused by any other callbacks to avoid conflicts. + monitor: The metric name to monitor. Typically the metrics are set by + the `Model.compile` method. Note: + * Prefix the name with `"val_"` to monitor validation metrics. + * Use `"loss"` or `"val_loss"` to monitor the model's total loss. + * If you specify metrics as strings, like `"accuracy"`, pass the + same string (with or without the `"val_"` prefix). + * If you pass `metrics.Metric` objects, `monitor` should be set to + `metric.name` + * If you're not sure about the metric names you can check the + contents of the `history.history` dictionary returned by + `history = model.fit()` + * Multi-output models set additional prefixes on the metric names. + verbose: Verbosity mode, 0 or 1. Mode 0 is silent, and mode 1 + displays messages when the callback takes an action. + save_best_only: if `save_best_only=True`, it only saves when the model + is considered the "best" and the latest best model according to the + quantity monitored will not be overwritten. If `filepath` doesn't + contain formatting options like `{epoch}` then `filepath` will be + overwritten by each new better model. + mode: one of {`"auto"`, `"min"`, `"max"`}. If `save_best_only=True`, the + decision to overwrite the current save file is made based on either + the maximization or the minimization of the monitored quantity. + For `val_acc`, this should be `"max"`, for `val_loss` this should be + `"min"`, etc. In `"auto"` mode, the direction is automatically + inferred from the name of the monitored quantity. + save_weights_only: if `True`, then only the model's weights will be + saved (`model.save_weights(filepath)`), else the full model is + saved (`model.save(filepath)`). + save_freq: `"epoch"` or integer. When using `"epoch"`, the callback + saves the model after each epoch. When using integer, the callback + saves the model at end of this many batches. If the `Model` is + compiled with `steps_per_execution=N`, then the saving criteria will + be checked every Nth batch. Note that if the saving isn't aligned to + epochs, the monitored metric may potentially be less reliable (it + could reflect as little as 1 batch, since the metrics get reset + every epoch). Defaults to `"epoch"`. + initial_value_threshold: Floating point initial "best" value of the + metric to be monitored. Only applies if `save_best_value=True`. Only + overwrites the model weights already saved if the performance of + current model is better than this value. + """ + + def __init__( + self, + filepath, + monitor="val_loss", + verbose=0, + save_best_only=False, + save_weights_only=False, + mode="auto", + save_freq="epoch", + initial_value_threshold=None, + ): + super().__init__(monitor, mode, initial_value_threshold) + self.verbose = verbose + self.filepath = file_utils.path_to_string(filepath) + self.save_best_only = save_best_only + self.save_weights_only = save_weights_only + self.save_freq = save_freq + self._batches_seen_since_last_saving = 0 + self._last_batch_seen = 0 + + if self.save_freq != "epoch" and not isinstance(self.save_freq, int): + raise ValueError( + f"Unrecognized save_freq: {self.save_freq}. " + "Expected save_freq are 'epoch' or integer values" + ) + + if save_weights_only: + if not self.filepath.endswith(".weights.h5"): + raise ValueError( + "When using `save_weights_only=True` in `ModelCheckpoint`" + ", the filepath provided must end in `.weights.h5` " + "(Keras weights format). Received: " + f"filepath={self.filepath}" + ) + else: + if not any( + self.filepath.endswith(ext) for ext in (".keras", ".h5") + ): + raise ValueError( + "The filepath provided must end in `.keras` " + "(Keras model format). Received: " + f"filepath={self.filepath}" + ) + + def on_train_batch_end(self, batch, logs=None): + if self._should_save_on_batch(batch): + self._save_model(epoch=self._current_epoch, batch=batch, logs=logs) + + def on_epoch_begin(self, epoch, logs=None): + self._current_epoch = epoch + + def on_epoch_end(self, epoch, logs=None): + if self.monitor_op is None: + # Delay setup until the model's metrics are all built + self._set_monitor_op() + + if self.save_freq == "epoch": + self._save_model(epoch=epoch, batch=None, logs=logs) + + def _should_save_on_batch(self, batch): + """Handles batch-level saving logic, supports steps_per_execution.""" + if self.save_freq == "epoch": + return False + if batch <= self._last_batch_seen: # New epoch. + add_batches = batch + 1 # batches are zero-indexed. + else: + add_batches = batch - self._last_batch_seen + self._batches_seen_since_last_saving += add_batches + self._last_batch_seen = batch + + if self._batches_seen_since_last_saving >= self.save_freq: + self._batches_seen_since_last_saving = 0 + return True + return False + + def _should_save_model(self, epoch, batch, logs, filepath): + """Determines whether the model should be saved. + + The model should be saved in the following cases: + + - self.save_best_only is False + - self.save_best_only is True and `monitor` is a numpy array or + backend tensor (falls back to `save_best_only=False`) + - self.save_best_only is True and `self.monitor_op(current, self.best)` + evaluates to True. + + Args: + epoch: the epoch this iteration is in. + batch: the batch this iteration is in. `None` if the `save_freq` + is set to `"epoch"`. + logs: the `logs` dict passed in to `on_batch_end` or + `on_epoch_end`. + filepath: the path where the model would be saved + """ + logs = logs or {} + if self.save_best_only: + current = logs.get(self.monitor) + if current is None: + warnings.warn( + f"Can save best model only with {self.monitor} available.", + stacklevel=2, + ) + return True + elif ( + isinstance(current, np.ndarray) or backend.is_tensor(current) + ) and len(current.shape) > 0: + warnings.warn( + "Can save best model only when `monitor` is " + f"a scalar value. Received: {current}. " + "Falling back to `save_best_only=False`." + ) + return True + else: + best_str = "None" if self.best is None else f"{self.best:.5f}" + if self._is_improvement(current, self.best): + if self.verbose > 0: + io_utils.print_msg( + f"\nEpoch {epoch + 1}: {self.monitor} " + f"improved from {best_str} to {current:.5f}, " + f"saving model to {filepath}" + ) + self.best = current + return True + else: + if self.verbose > 0: + io_utils.print_msg( + f"\nEpoch {epoch + 1}: " + f"{self.monitor} did not improve from {best_str}" + ) + return False + else: + if self.verbose > 0: + io_utils.print_msg( + f"\nEpoch {epoch + 1}: saving model to {filepath}" + ) + return True + + def _save_model(self, epoch, batch, logs): + """Saves the model. + + Args: + epoch: the epoch this iteration is in. + batch: the batch this iteration is in. `None` if the `save_freq` + is set to `"epoch"`. + logs: the `logs` dict passed in to `on_batch_end` or `on_epoch_end`. + """ + filepath = self._get_file_path(epoch, batch, logs) + + try: + if self._should_save_model(epoch, batch, logs, filepath): + # Create host directory if it doesn't exist. + dirname = os.path.dirname(filepath) + if dirname and not file_utils.exists(dirname): + file_utils.makedirs(dirname) + + if self.save_weights_only: + self.model.save_weights(filepath, overwrite=True) + else: + self.model.save(filepath, overwrite=True) + except IsADirectoryError: # h5py 3.x + raise IOError( + "Please specify a non-directory filepath for " + "ModelCheckpoint. Filepath used is an existing " + f"directory: {filepath}" + ) + except IOError as e: # h5py 2.x + # `e.errno` appears to be `None` so checking the content of + # `e.args[0]`. + if "is a directory" in str(e.args[0]).lower(): + raise IOError( + "Please specify a non-directory filepath for " + "ModelCheckpoint. Filepath used is an existing " + f"directory: f{filepath}" + ) + # Re-throw the error for any other causes. + raise e + + def _get_file_path(self, epoch, batch, logs): + """Returns the file path for checkpoint.""" + + try: + # `filepath` may contain placeholders such as + # `{epoch:02d}`,`{batch:02d}` and `{mape:.2f}`. A mismatch between + # logged metrics and the path's placeholders can cause formatting to + # fail. + if batch is None or "batch" in logs: + file_path = self.filepath.format(epoch=epoch + 1, **logs) + else: + file_path = self.filepath.format( + epoch=epoch + 1, batch=batch + 1, **logs + ) + except KeyError as e: + raise KeyError( + f'Failed to format this callback filepath: "{self.filepath}". ' + f"Reason: {e}" + ) + return file_path + + def _checkpoint_exists(self, filepath): + """Returns whether the checkpoint `filepath` refers to exists.""" + return file_utils.exists(filepath) + + def _get_most_recently_modified_file_matching_pattern(self, pattern): + """Returns the most recently modified filepath matching pattern. + + In the rare case where there are more than one pattern-matching file + having the same modified time that is most recent among all, return the + filepath that is largest (by `>` operator, lexicographically using the + numeric equivalents). This provides a tie-breaker when multiple files + are most recent. Note that a larger `filepath` can sometimes indicate a + later time of modification (for instance, when epoch/batch is used as + formatting option), but not necessarily (when accuracy or loss is used). + The tie-breaker is put in the logic as best effort to return the most + recent, and to avoid nondeterministic result. + + Modified time of a file is obtained with `os.path.getmtime()`. + + This utility function is best demonstrated via an example: + + ```python + file_pattern = 'batch{batch:02d}epoch{epoch:02d}.keras' + test_dir = self.get_temp_dir() + path_pattern = os.path.join(test_dir, file_pattern) + file_paths = [ + os.path.join(test_dir, file_name) for file_name in + ['batch03epoch02.keras', + 'batch02epoch02.keras', 'batch01epoch01.keras'] + ] + for file_path in file_paths: + # Write something to each of the files + ... + self.assertEqual( + _get_most_recently_modified_file_matching_pattern(path_pattern), + file_paths[-1]) + ``` + + Args: + pattern: The file pattern that may optionally contain python + placeholder such as `{epoch:02d}`. + + Returns: + The most recently modified file's full filepath matching `pattern`. + If `pattern` does not contain any placeholder, this returns the + filepath that exactly matches `pattern`. Returns `None` if no match + is found. + """ + dir_name = os.path.dirname(pattern) + base_name = os.path.basename(pattern) + base_name_regex = f"^{re.sub(r'{.*}', r'.*', base_name)}$" + + latest_mod_time = 0 + file_path_with_latest_mod_time = None + n_file_with_latest_mod_time = 0 + file_path_with_largest_file_name = None + + if file_utils.exists(dir_name): + for file_name in os.listdir(dir_name): + # Only consider if `file_name` matches the pattern. + if re.match(base_name_regex, file_name): + file_path = os.path.join(dir_name, file_name) + mod_time = os.path.getmtime(file_path) + if ( + file_path_with_largest_file_name is None + or file_path > file_path_with_largest_file_name + ): + file_path_with_largest_file_name = file_path + if mod_time > latest_mod_time: + latest_mod_time = mod_time + file_path_with_latest_mod_time = file_path + # In the case a file with later modified time is found, + # reset the counter for the number of files with latest + # modified time. + n_file_with_latest_mod_time = 1 + elif mod_time == latest_mod_time: + # In the case a file has modified time tied with the + # most recent, increment the counter for the number of + # files with latest modified time by 1. + n_file_with_latest_mod_time += 1 + + if n_file_with_latest_mod_time == 1: + # Return the sole file that has most recent modified time. + return file_path_with_latest_mod_time + else: + # If there are more than one file having latest modified time, + # return the file path with the largest file name. + return file_path_with_largest_file_name diff --git a/keras/src/callbacks/model_checkpoint_test.py b/keras/src/callbacks/model_checkpoint_test.py new file mode 100644 index 000000000000..2a2def35878c --- /dev/null +++ b/keras/src/callbacks/model_checkpoint_test.py @@ -0,0 +1,583 @@ +import os +import warnings + +import pytest + +from keras.src import callbacks +from keras.src import layers +from keras.src import metrics +from keras.src import models +from keras.src import saving +from keras.src import testing +from keras.src.models import Sequential +from keras.src.testing import test_utils +from keras.src.utils import numerical_utils + +try: + import h5py +except ImportError: + h5py = None + +TRAIN_SAMPLES = 30 +TEST_SAMPLES = 30 +NUM_CLASSES = 3 +INPUT_DIM = 3 +NUM_HIDDEN = 5 +BATCH_SIZE = 5 + + +class ModelCheckpointTest(testing.TestCase): + @pytest.mark.skipif( + h5py is None, + reason="`h5py` is a required dependency for `ModelCheckpoint` tests.", + ) + @pytest.mark.skipif( + testing.jax_uses_gpu(), + reason="Mysterious core dump on CI after upgrading JAX", + ) + @pytest.mark.requires_trainable_backend + def test_model_checkpoint_options(self): + def get_model(): + model = Sequential( + [ + layers.Dense(NUM_HIDDEN, activation="relu"), + layers.Dense(NUM_CLASSES, activation="softmax"), + ] + ) + model.compile( + loss="categorical_crossentropy", + optimizer="sgd", + metrics=[metrics.Accuracy("acc")], + ) + return model + + model = get_model() + temp_dir = self.get_temp_dir() + + # Save model to a subdir inside the temp_dir so we can test + # automatic directory creation. + filepath = os.path.join(temp_dir, "subdir", "checkpoint.keras") + (x_train, y_train), (x_test, y_test) = test_utils.get_test_data( + random_seed=42, + train_samples=TRAIN_SAMPLES, + test_samples=TEST_SAMPLES, + input_shape=(INPUT_DIM,), + num_classes=NUM_CLASSES, + ) + y_test = numerical_utils.to_categorical(y_test, num_classes=NUM_CLASSES) + y_train = numerical_utils.to_categorical( + y_train, num_classes=NUM_CLASSES + ) + + # Case 1 + monitor = "val_loss" + save_best_only = False + mode = "auto" + + cbks = [ + callbacks.ModelCheckpoint( + filepath, + monitor=monitor, + save_best_only=save_best_only, + mode=mode, + ) + ] + model.fit( + x_train, + y_train, + batch_size=BATCH_SIZE, + validation_data=(x_test, y_test), + callbacks=cbks, + epochs=1, + verbose=0, + ) + self.assertTrue(os.path.exists(filepath)) + os.remove(filepath) + + # Case 2 + mode = "min" + cbks = [ + callbacks.ModelCheckpoint( + filepath, + monitor=monitor, + save_best_only=save_best_only, + mode=mode, + ) + ] + model.fit( + x_train, + y_train, + batch_size=BATCH_SIZE, + validation_data=(x_test, y_test), + callbacks=cbks, + epochs=1, + verbose=0, + ) + self.assertTrue(os.path.exists(filepath)) + os.remove(filepath) + + # Case 3 + mode = "max" + monitor = "val_acc" + cbks = [ + callbacks.ModelCheckpoint( + filepath, + monitor=monitor, + save_best_only=save_best_only, + mode=mode, + ) + ] + model.fit( + x_train, + y_train, + batch_size=BATCH_SIZE, + validation_data=(x_test, y_test), + callbacks=cbks, + epochs=1, + verbose=0, + ) + self.assertTrue(os.path.exists(filepath)) + os.remove(filepath) + + # Case 4 + save_best_only = True + cbks = [ + callbacks.ModelCheckpoint( + filepath, + monitor=monitor, + save_best_only=save_best_only, + mode=mode, + ) + ] + model.fit( + x_train, + y_train, + batch_size=BATCH_SIZE, + validation_data=(x_test, y_test), + callbacks=cbks, + epochs=1, + verbose=0, + ) + self.assertTrue(os.path.exists(filepath)) + os.remove(filepath) + + # Case 5: metric not available. + cbks = [ + callbacks.ModelCheckpoint( + filepath, monitor="unknown", save_best_only=True, mode="min" + ) + ] + with pytest.warns(UserWarning): + model.fit( + x_train, + y_train, + batch_size=BATCH_SIZE, + validation_data=(x_test, y_test), + callbacks=cbks, + epochs=1, + verbose=0, + ) + self.assertTrue(os.path.exists(filepath)) + + # Case 6 + with warnings.catch_warnings(record=True) as warning_logs: + warnings.simplefilter("always") + callbacks.ModelCheckpoint( + filepath, + monitor=monitor, + save_best_only=save_best_only, + mode="unknown", + ) + self.assertIn( + "ModelCheckpoint mode 'unknown' is unknown", + str(warning_logs[-1].message), + ) + + # Case 8a: `ModelCheckpoint` with an integer `save_freq` + temp_dir = self.get_temp_dir() + filepath = os.path.join(temp_dir, "checkpoint.epoch{epoch:02d}.keras") + save_best_only = False + cbks = [ + callbacks.ModelCheckpoint( + filepath, + monitor=monitor, + save_best_only=save_best_only, + mode=mode, + save_freq=15, + ) + ] + self.assertFalse(os.path.exists(filepath.format(epoch=3))) + model.fit( + x_train, + y_train, + batch_size=6, # 5 batches / epoch, so should backup every 3 epochs + validation_data=(x_test, y_test), + callbacks=cbks, + epochs=10, + verbose=0, + ) + self.assertFalse(os.path.exists(filepath.format(epoch=1))) + self.assertFalse(os.path.exists(filepath.format(epoch=2))) + self.assertTrue(os.path.exists(filepath.format(epoch=3))) + self.assertFalse(os.path.exists(filepath.format(epoch=4))) + self.assertFalse(os.path.exists(filepath.format(epoch=5))) + self.assertTrue(os.path.exists(filepath.format(epoch=6))) + self.assertFalse(os.path.exists(filepath.format(epoch=7))) + self.assertFalse(os.path.exists(filepath.format(epoch=8))) + self.assertTrue(os.path.exists(filepath.format(epoch=9))) + os.remove(filepath.format(epoch=3)) + os.remove(filepath.format(epoch=6)) + os.remove(filepath.format(epoch=9)) + + # Case 8b: `ModelCheckpoint` with int `save_freq` & `save_weights_only` + temp_dir = self.get_temp_dir() + filepath = os.path.join( + temp_dir, "checkpoint.epoch{epoch:02d}.weights.h5" + ) + cbks = [ + callbacks.ModelCheckpoint( + filepath, monitor=monitor, save_freq=15, save_weights_only=True + ) + ] + self.assertFalse(os.path.exists(filepath.format(epoch=3))) + model.fit( + x_train, + y_train, + batch_size=6, + validation_data=(x_test, y_test), + callbacks=cbks, + epochs=10, + verbose=0, + ) + self.assertFalse(os.path.exists(filepath.format(epoch=1))) + self.assertFalse(os.path.exists(filepath.format(epoch=2))) + self.assertTrue(os.path.exists(filepath.format(epoch=3))) + self.assertFalse(os.path.exists(filepath.format(epoch=4))) + self.assertFalse(os.path.exists(filepath.format(epoch=5))) + self.assertTrue(os.path.exists(filepath.format(epoch=6))) + self.assertFalse(os.path.exists(filepath.format(epoch=7))) + self.assertFalse(os.path.exists(filepath.format(epoch=8))) + self.assertTrue(os.path.exists(filepath.format(epoch=9))) + + # Case 9: `ModelCheckpoint` with valid and invalid save_freq argument. + with self.assertRaisesRegex(ValueError, "Unrecognized save_freq"): + callbacks.ModelCheckpoint( + filepath, + monitor=monitor, + save_best_only=save_best_only, + save_weights_only=True, + mode=mode, + save_freq="invalid_save_freq", + ) + # The following should not raise ValueError. + callbacks.ModelCheckpoint( + filepath, + monitor=monitor, + save_best_only=save_best_only, + save_weights_only=True, + mode=mode, + save_freq="epoch", + ) + callbacks.ModelCheckpoint( + filepath, + monitor=monitor, + save_best_only=save_best_only, + save_weights_only=True, + mode=mode, + save_freq=3, + ) + + # Case 10a: `ModelCheckpoint` save with batch in filename. + temp_dir = self.get_temp_dir() + filepath = os.path.join( + temp_dir, "checkpoint.epoch{epoch:02d}batch{batch:02d}.keras" + ) + cbks = [ + callbacks.ModelCheckpoint(filepath, monitor=monitor, save_freq=1) + ] + model.fit( + x_train, + y_train, + batch_size=15, + validation_data=(x_test, y_test), + callbacks=cbks, + epochs=5, + verbose=1, + ) + self.assertTrue(os.path.exists(filepath.format(epoch=1, batch=1))) + self.assertTrue(os.path.exists(filepath.format(epoch=1, batch=2))) + self.assertTrue(os.path.exists(filepath.format(epoch=2, batch=1))) + self.assertTrue(os.path.exists(filepath.format(epoch=2, batch=2))) + self.assertTrue(os.path.exists(filepath.format(epoch=3, batch=1))) + self.assertTrue(os.path.exists(filepath.format(epoch=3, batch=2))) + self.assertTrue(os.path.exists(filepath.format(epoch=4, batch=1))) + self.assertTrue(os.path.exists(filepath.format(epoch=4, batch=2))) + self.assertTrue(os.path.exists(filepath.format(epoch=5, batch=1))) + self.assertTrue(os.path.exists(filepath.format(epoch=5, batch=2))) + + # Case 10b: `ModelCheckpoint` save weights with batch in filename. + temp_dir = self.get_temp_dir() + filepath = os.path.join( + temp_dir, "checkpoint.epoch{epoch:02d}batch{batch:02d}.weights.h5" + ) + cbks = [ + callbacks.ModelCheckpoint( + filepath, monitor=monitor, save_freq=1, save_weights_only=True + ) + ] + model.fit( + x_train, + y_train, + batch_size=15, + validation_data=(x_test, y_test), + callbacks=cbks, + epochs=5, + verbose=1, + ) + + self.assertTrue(os.path.exists(filepath.format(epoch=1, batch=1))) + self.assertTrue(os.path.exists(filepath.format(epoch=1, batch=2))) + self.assertTrue(os.path.exists(filepath.format(epoch=2, batch=1))) + self.assertTrue(os.path.exists(filepath.format(epoch=2, batch=2))) + self.assertTrue(os.path.exists(filepath.format(epoch=3, batch=1))) + self.assertTrue(os.path.exists(filepath.format(epoch=3, batch=2))) + self.assertTrue(os.path.exists(filepath.format(epoch=4, batch=1))) + self.assertTrue(os.path.exists(filepath.format(epoch=4, batch=2))) + self.assertTrue(os.path.exists(filepath.format(epoch=5, batch=1))) + self.assertTrue(os.path.exists(filepath.format(epoch=5, batch=2))) + + # Case 11: ModelCheckpoint saves model with initial_value_threshold + # param + mode = "max" + monitor = "val_acc" + initial_value_threshold = -0.01 + save_best_only = True + filepath = os.path.join(temp_dir, "checkpoint.keras") + cbks = [ + callbacks.ModelCheckpoint( + filepath, + monitor=monitor, + save_best_only=save_best_only, + initial_value_threshold=initial_value_threshold, + mode=mode, + ) + ] + model.fit( + x_train, + y_train, + batch_size=BATCH_SIZE, + validation_data=(x_test, y_test), + callbacks=cbks, + epochs=1, + verbose=0, + ) + self.assertTrue(os.path.exists(filepath)) + os.remove(filepath) + + # Case 12: ModelCheckpoint saves model with initial_value_threshold + # param + mode = "auto" + monitor = "val_loss" + initial_value_threshold = None + save_best_only = True + cbks = [ + callbacks.ModelCheckpoint( + filepath, + monitor=monitor, + save_best_only=save_best_only, + initial_value_threshold=initial_value_threshold, + mode=mode, + ) + ] + model.fit( + x_train, + y_train, + batch_size=BATCH_SIZE, + validation_data=(x_test, y_test), + callbacks=cbks, + epochs=1, + verbose=0, + ) + self.assertTrue(os.path.exists(filepath)) + os.remove(filepath) + + # Case 13: ModelCheckpoint doesn't save model if loss was minimum + # earlier + mode = "min" + monitor = "val_loss" + initial_value_threshold = 0 + save_best_only = True + cbks = [ + callbacks.ModelCheckpoint( + filepath, + monitor=monitor, + save_best_only=save_best_only, + initial_value_threshold=initial_value_threshold, + mode=mode, + ) + ] + model.fit( + x_train, + y_train, + batch_size=BATCH_SIZE, + validation_data=(x_test, y_test), + callbacks=cbks, + epochs=1, + verbose=0, + ) + self.assertFalse(os.path.exists(filepath)) + + # Case 14: ModelCheckpoint doesn't save model if loss was min earlier in + # auto mode + mode = "auto" + monitor = "val_loss" + initial_value_threshold = 0 + save_best_only = True + cbks = [ + callbacks.ModelCheckpoint( + filepath, + monitor=monitor, + save_best_only=save_best_only, + initial_value_threshold=initial_value_threshold, + mode=mode, + ) + ] + model.fit( + x_train, + y_train, + batch_size=BATCH_SIZE, + validation_data=(x_test, y_test), + callbacks=cbks, + epochs=1, + verbose=0, + ) + self.assertFalse(os.path.exists(filepath)) + + # Case 15: ModelCheckpoint doesn't save model if auc was max earlier in + # auto mode + mode = "auto" + monitor = "val_auc" + initial_value_threshold = 1 + save_best_only = True + cbks = [ + callbacks.ModelCheckpoint( + filepath, + monitor=monitor, + save_best_only=save_best_only, + initial_value_threshold=initial_value_threshold, + mode=mode, + ) + ] + model.compile( + loss="categorical_crossentropy", + optimizer="sgd", + metrics=[metrics.AUC()], + ) + model.fit( + x_train, + y_train, + batch_size=BATCH_SIZE, + validation_data=(x_test, y_test), + callbacks=cbks, + epochs=1, + verbose=0, + ) + self.assertFalse(os.path.exists(filepath)) + + @pytest.mark.skipif( + h5py is None, + reason="`h5py` is a required dependency for `ModelCheckpoint` tests.", + ) + @pytest.mark.requires_trainable_backend + def test_model_checkpoint_loading(self): + def get_model(): + inputs = layers.Input(shape=(INPUT_DIM,), batch_size=5) + x = layers.Dense(NUM_HIDDEN, activation="relu")(inputs) + outputs = layers.Dense(NUM_CLASSES, activation="softmax")(x) + functional_model = models.Model(inputs, outputs) + functional_model.compile( + loss="categorical_crossentropy", + optimizer="sgd", + metrics=[metrics.Accuracy("acc")], + ) + return functional_model + + (x_train, y_train), (x_test, y_test) = test_utils.get_test_data( + random_seed=42, + train_samples=TRAIN_SAMPLES, + test_samples=TEST_SAMPLES, + input_shape=(INPUT_DIM,), + num_classes=NUM_CLASSES, + ) + y_test = numerical_utils.to_categorical(y_test, num_classes=NUM_CLASSES) + y_train = numerical_utils.to_categorical( + y_train, num_classes=NUM_CLASSES + ) + + # Model Checkpoint load model (default) + model = get_model() + temp_dir = self.get_temp_dir() + filepath = os.path.join(temp_dir, "checkpoint.model.keras") + mode = "auto" + monitor = "val_loss" + save_best_only = True + + cbks = [ + callbacks.ModelCheckpoint( + filepath, + monitor=monitor, + save_best_only=save_best_only, + mode=mode, + ) + ] + model.fit( + x_train, + y_train, + batch_size=BATCH_SIZE, + validation_data=(x_test, y_test), + callbacks=cbks, + epochs=1, + verbose=0, + ) + ref_weights = model.get_weights() + self.assertTrue(os.path.exists(filepath)) + new_model = saving.load_model(filepath) + new_weights = new_model.get_weights() + self.assertEqual(len(ref_weights), len(new_weights)) + for ref_w, w in zip(ref_weights, new_weights): + self.assertAllClose(ref_w, w) + + # Model Checkpoint load model weights + model = get_model() + temp_dir = self.get_temp_dir() + filepath = os.path.join(temp_dir, "checkpoint.weights.h5") + mode = "auto" + monitor = "val_loss" + save_best_only = True + + cbks = [ + callbacks.ModelCheckpoint( + filepath, + monitor=monitor, + save_best_only=save_best_only, + save_weights_only=True, + mode=mode, + ) + ] + model.fit( + x_train, + y_train, + batch_size=BATCH_SIZE, + validation_data=(x_test, y_test), + callbacks=cbks, + epochs=1, + verbose=0, + ) + ref_weights = model.get_weights() + self.assertTrue(os.path.exists(filepath)) + new_model = get_model() + new_model.load_weights(filepath) + new_weights = new_model.get_weights() + self.assertEqual(len(ref_weights), len(new_weights)) + for ref_w, w in zip(ref_weights, new_weights): + self.assertAllClose(ref_w, w) diff --git a/keras/src/callbacks/monitor_callback.py b/keras/src/callbacks/monitor_callback.py new file mode 100644 index 000000000000..30510ca54e16 --- /dev/null +++ b/keras/src/callbacks/monitor_callback.py @@ -0,0 +1,104 @@ +import warnings + +from keras.src import ops +from keras.src.callbacks.callback import Callback +from keras.src.trainers import compile_utils + + +class MonitorCallback(Callback): + """Base class for callbacks that monitor a quantity and evaluates + improvements. + + This class provides common functionality for callbacks that monitor a + metric during training to determine whether a condition has been met, + such as improvement over time. It encapsulates logic for selecting + the comparison operation based on a `monitor` value and `mode`, and + computing whether a new value is an improvement. + + It is intended to be subclassed by other callbacks like `ModelCheckpoint`, + `EarlyStopping`, or `ReduceLROnPlateau`, and is not meant to be used + directly. + + Arguments: + monitor: Quantity to be monitored. Defaults to `"val_loss"`. + mode: One of `{"auto", "min", "max"}`. In `min` mode, training will aim + to minimize the monitored quantity; in `'max'` mode it will aim to + maximize it.; in `"auto"` mode, the direction is automatically + inferred from the name of the monitored quantity. Defaults to + `"auto"`. + baseline: Floating point initial "best" value of the metric to be + monitored. If `None` (default), the first monitored value will be + used. + min_delta: Minimum change in the monitored quantity to qualify as an + improvement, i.e. an absolute change of less than min_delta, will + count as no improvement. Defaults to `0`. + + Raises: + ValueError: If `mode='auto'` is selected and the direction of the metric + cannot be inferred. + """ + + def __init__( + self, + monitor="val_loss", + mode="auto", + baseline=None, + min_delta=0, + ): + super().__init__() + if mode not in ["auto", "min", "max"]: + warnings.warn( + f"{self.__class__.__name__} mode '{mode}' is unknown, fallback " + "to auto mode.", + stacklevel=2, + ) + mode = "auto" + self.monitor = monitor + self.mode = mode + self.best = baseline + self.min_delta = abs(min_delta) + self.monitor_op = None + + def _set_monitor_op(self): + if self.mode == "min": + self.monitor_op = ops.less + elif self.mode == "max": + self.monitor_op = ops.greater + else: + metric_name = self.monitor.removeprefix("val_") + if metric_name == "loss": + self.monitor_op = ops.less + if hasattr(self.model, "metrics"): + all_metrics = [] + for m in self.model.metrics: + if isinstance( + m, + ( + compile_utils.CompileMetrics, + compile_utils.MetricsList, + ), + ): + all_metrics.extend(m.metrics) + for m in all_metrics: + if m.name == metric_name: + if hasattr(m, "_direction"): + if m._direction == "up": + self.monitor_op = ops.greater + else: + self.monitor_op = ops.less + if self.monitor_op is None: + raise ValueError( + f"{self.__class__.__name__} callback received " + f"monitor={self.monitor}, but Keras isn't able to " + "automatically determine whether that metric should be " + "maximized or minimized. Pass `mode='max'` in order to " + "monitor based on the highest metric value, or pass " + "`mode='min'` in order to use the lowest value." + ) + if self.monitor_op == ops.less: + self.min_delta *= -1 + + def _is_improvement(self, monitor_value, reference_value): + if reference_value is None: + return True + return self.monitor_op(monitor_value - self.min_delta, reference_value) diff --git a/keras/src/callbacks/monitor_callback_test.py b/keras/src/callbacks/monitor_callback_test.py new file mode 100644 index 000000000000..f81112ed7122 --- /dev/null +++ b/keras/src/callbacks/monitor_callback_test.py @@ -0,0 +1,86 @@ +import numpy as np +import pytest + +from keras.src import callbacks +from keras.src import layers +from keras.src import metrics +from keras.src import models +from keras.src import ops +from keras.src import testing + + +class MonitorCallbackTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_monitor_op_logic(self): + x_train = np.random.random((10, 5)) + y_train = np.random.random((10, 1)) + x_test = np.random.random((10, 5)) + y_test = np.random.random((10, 1)) + model = models.Sequential( + ( + layers.Dense(1, activation="relu"), + layers.Dense(1, activation="relu"), + ) + ) + model.compile( + loss="mae", + optimizer="adam", + metrics=[ + "mse", + "acc", + "accuracy", + "hinge", + metrics.F1Score(name="f1_score"), + ], + ) + + cases = [ + ("max", "val_mse", "max"), + ("min", "val_loss", "min"), + ("auto", "val_mse", "min"), + ("auto", "loss", "min"), + ("auto", "acc", "max"), + ("auto", "val_accuracy", "max"), + ("auto", "hinge", "min"), + ("auto", "f1_score", "max"), + ] + for mode, monitor, expected_mode in cases: + monitor_callback = callbacks.MonitorCallback(monitor, mode) + monitor_callback.set_model(model) + model.fit( + x_train, + y_train, + batch_size=5, + validation_data=(x_test, y_test), + epochs=2, + verbose=0, + ) + monitor_callback._set_monitor_op() + if expected_mode == "max": + monitor_op = ops.greater + else: + monitor_op = ops.less + self.assertEqual(monitor_callback.monitor_op, monitor_op) + + with self.assertRaises(ValueError): + monitor = "unknown" + monitor_callback = callbacks.MonitorCallback(monitor) + monitor_callback.set_model(model) + model.fit( + x_train, + y_train, + batch_size=5, + validation_data=(x_test, y_test), + epochs=2, + verbose=0, + ) + monitor_callback._set_monitor_op() + + @pytest.mark.requires_trainable_backend + def test_min_delta(self): + monitor_callback = callbacks.MonitorCallback(mode="max", min_delta=0.5) + monitor_callback._set_monitor_op() + self.assertTrue(monitor_callback._is_improvement(0.75, 0)) + self.assertTrue(monitor_callback._is_improvement(0.5, None)) + self.assertFalse(monitor_callback._is_improvement(0.5, 0)) + self.assertFalse(monitor_callback._is_improvement(0.2, 0.5)) diff --git a/keras/src/callbacks/progbar_logger.py b/keras/src/callbacks/progbar_logger.py new file mode 100644 index 000000000000..ac10d655a97c --- /dev/null +++ b/keras/src/callbacks/progbar_logger.py @@ -0,0 +1,102 @@ +from keras.src.api_export import keras_export +from keras.src.callbacks.callback import Callback +from keras.src.utils import io_utils +from keras.src.utils.progbar import Progbar + + +@keras_export("keras.callbacks.ProgbarLogger") +class ProgbarLogger(Callback): + """Callback that prints metrics to stdout. + + Args: + count_mode: One of `"steps"` or `"samples"`. + Whether the progress bar should + count samples seen or steps (batches) seen. + + Raises: + ValueError: In case of invalid `count_mode`. + """ + + def __init__(self): + super().__init__() + self.seen = 0 + self.progbar = None + self.target = None + self.verbose = 1 + self.epochs = 1 + + self._called_in_fit = False + + def set_params(self, params): + verbose = params["verbose"] + if verbose == "auto": + verbose = 1 + self.verbose = verbose + self.epochs = params["epochs"] + self.target = params["steps"] + + def on_train_begin(self, logs=None): + # When this logger is called inside `fit`, validation is silent. + self._called_in_fit = True + + def on_test_begin(self, logs=None): + if not self._called_in_fit: + self._reset_progbar() + self._maybe_init_progbar() + + def on_predict_begin(self, logs=None): + self._reset_progbar() + self._maybe_init_progbar() + + def on_epoch_begin(self, epoch, logs=None): + self._reset_progbar() + self._maybe_init_progbar() + if self.verbose and self.epochs > 1: + io_utils.print_msg(f"Epoch {epoch + 1}/{self.epochs}") + + def on_train_batch_end(self, batch, logs=None): + self._update_progbar(batch, logs) + + def on_test_batch_end(self, batch, logs=None): + if not self._called_in_fit: + self._update_progbar(batch, logs) + + def on_predict_batch_end(self, batch, logs=None): + # Don't pass prediction results. + self._update_progbar(batch, None) + + def on_epoch_end(self, epoch, logs=None): + self._finalize_progbar(logs) + + def on_test_end(self, logs=None): + if not self._called_in_fit: + self._finalize_progbar(logs) + + def on_predict_end(self, logs=None): + self._finalize_progbar(logs) + + def _reset_progbar(self): + self.seen = 0 + self.progbar = None + + def _maybe_init_progbar(self): + if self.progbar is None: + self.progbar = Progbar( + target=self.target, verbose=self.verbose, unit_name="step" + ) + + def _update_progbar(self, batch, logs=None): + """Updates the progbar.""" + logs = logs or {} + self._maybe_init_progbar() + self.seen = batch + 1 # One-indexed. + + if self.verbose == 1: + self.progbar.update(self.seen, list(logs.items()), finalize=False) + + def _finalize_progbar(self, logs): + logs = logs or {} + if self.target is None: + self.target = self.seen + self.progbar.target = self.target + self.progbar.update(self.target, list(logs.items()), finalize=True) diff --git a/keras/src/callbacks/reduce_lr_on_plateau.py b/keras/src/callbacks/reduce_lr_on_plateau.py new file mode 100644 index 000000000000..b9c40afc4e92 --- /dev/null +++ b/keras/src/callbacks/reduce_lr_on_plateau.py @@ -0,0 +1,130 @@ +import warnings + +import numpy as np + +from keras.src import backend +from keras.src.api_export import keras_export +from keras.src.callbacks.monitor_callback import MonitorCallback +from keras.src.utils import io_utils + + +@keras_export("keras.callbacks.ReduceLROnPlateau") +class ReduceLROnPlateau(MonitorCallback): + """Reduce learning rate when a metric has stopped improving. + + Models often benefit from reducing the learning rate by a factor + of 2-10 once learning stagnates. This callback monitors a + quantity and if no improvement is seen for a 'patience' number + of epochs, the learning rate is reduced. + + Example: + + ```python + reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, + patience=5, min_lr=0.001) + model.fit(x_train, y_train, callbacks=[reduce_lr]) + ``` + + Args: + monitor: String. Quantity to be monitored. + factor: Float. Factor by which the learning rate will be reduced. + `new_lr = lr * factor`. + patience: Integer. Number of epochs with no improvement after which + learning rate will be reduced. + verbose: Integer. 0: quiet, 1: update messages. + mode: String. One of `{'auto', 'min', 'max'}`. In `'min'` mode, + the learning rate will be reduced when the + quantity monitored has stopped decreasing; in `'max'` mode it will + be reduced when the quantity monitored has stopped increasing; in + `'auto'` mode, the direction is automatically inferred from the name + of the monitored quantity. + min_delta: Float. Threshold for measuring the new optimum, to only focus + on significant changes. + cooldown: Integer. Number of epochs to wait before resuming normal + operation after the learning rate has been reduced. + min_lr: Float. Lower bound on the learning rate. + """ + + def __init__( + self, + monitor="val_loss", + factor=0.1, + patience=10, + verbose=0, + mode="auto", + min_delta=1e-4, + cooldown=0, + min_lr=0.0, + **kwargs, + ): + super().__init__(monitor, mode, min_delta=min_delta) + if factor >= 1.0: + raise ValueError( + "ReduceLROnPlateau does not support a factor >= 1.0. " + f"Received factor={factor}" + ) + + self.factor = factor + self.min_lr = min_lr + self.patience = patience + self.verbose = verbose + self.cooldown = cooldown + self.cooldown_counter = 0 # Cooldown counter. + self.wait = 0 + + def _reset(self): + """Resets wait counter and cooldown counter.""" + self.cooldown_counter = 0 + self.wait = 0 + + def on_train_begin(self, logs=None): + self._reset() + + def on_epoch_end(self, epoch, logs=None): + if self.monitor_op is None: + # Delay setup until the model's metrics are all built + self._set_monitor_op() + logs = logs or {} + logs["learning_rate"] = float( + backend.convert_to_numpy(self.model.optimizer.learning_rate) + ) + current = logs.get(self.monitor) + + if current is None: + warnings.warn( + "Learning rate reduction is conditioned on metric " + f"`{self.monitor}` which is not available. Available metrics " + f"are: {','.join(list(logs.keys()))}.", + stacklevel=2, + ) + else: + if self.in_cooldown(): + self.cooldown_counter -= 1 + self.wait = 0 + + if self._is_improvement(current, self.best): + self.best = current + self.wait = 0 + elif not self.in_cooldown(): + self.wait += 1 + if self.wait >= self.patience: + old_lr = float( + backend.convert_to_numpy( + self.model.optimizer.learning_rate + ) + ) + if old_lr > np.float32(self.min_lr): + new_lr = old_lr * self.factor + new_lr = max(new_lr, self.min_lr) + self.model.optimizer.learning_rate = new_lr + if self.verbose > 0: + io_utils.print_msg( + f"\nEpoch {epoch + 1}: " + "ReduceLROnPlateau reducing " + f"learning rate to {new_lr}." + ) + self.cooldown_counter = self.cooldown + self.wait = 0 + + def in_cooldown(self): + return self.cooldown_counter > 0 diff --git a/keras/src/callbacks/reduce_lr_on_plateau_test.py b/keras/src/callbacks/reduce_lr_on_plateau_test.py new file mode 100644 index 000000000000..96ebbaab2cf2 --- /dev/null +++ b/keras/src/callbacks/reduce_lr_on_plateau_test.py @@ -0,0 +1,139 @@ +import pytest + +from keras.src import callbacks +from keras.src import layers +from keras.src import optimizers +from keras.src import testing +from keras.src.models import Sequential +from keras.src.testing import test_utils +from keras.src.utils import io_utils +from keras.src.utils import numerical_utils + + +class ReduceLROnPlateauTest(testing.TestCase): + def setUp(self): + (x_train, y_train), (x_test, y_test) = test_utils.get_test_data( + train_samples=10, + test_samples=10, + input_shape=(3,), + num_classes=2, + ) + y_test = numerical_utils.to_categorical(y_test) + y_train = numerical_utils.to_categorical(y_train) + + model = Sequential([layers.Dense(5), layers.Dense(2)]) + + model.compile( + loss="mse", + optimizer=optimizers.Adam(0.1), + ) + + self.model = model + self.x_train = x_train + self.x_test = x_test + self.y_train = y_train + self.y_test = y_test + + @pytest.mark.requires_trainable_backend + def test_reduces_lr_with_model_fit(self): + reduce_lr = callbacks.ReduceLROnPlateau( + patience=1, factor=0.1, monitor="val_loss", min_delta=100 + ) + + self.model.fit( + self.x_train, + self.y_train, + validation_data=(self.x_test, self.y_test), + callbacks=[reduce_lr], + epochs=2, + ) + + self.assertEqual(self.model.optimizer.learning_rate.value, 0.01) + + @pytest.mark.requires_trainable_backend + def test_throws_when_optimizer_has_schedule(self): + reduce_lr = callbacks.ReduceLROnPlateau( + patience=1, factor=0.1, monitor="val_loss", min_delta=100 + ) + + self.model.compile( + loss="mse", + optimizer=optimizers.Adam( + optimizers.schedules.PolynomialDecay( + initial_learning_rate=0.1, decay_steps=10 + ) + ), + ) + + with self.assertRaisesRegex( + TypeError, + "This optimizer was created with a `LearningRateSchedule`", + ): + self.model.fit( + self.x_train, + self.y_train, + validation_data=(self.x_test, self.y_test), + callbacks=[reduce_lr], + epochs=2, + ) + + @pytest.mark.requires_trainable_backend + def test_verbose_logging(self): + reduce_lr = callbacks.ReduceLROnPlateau( + patience=1, factor=0.1, monitor="val_loss", min_delta=100, verbose=1 + ) + io_utils.disable_interactive_logging() + io_utils.set_logging_verbosity("INFO") + + with self.assertLogs() as logs: + self.model.fit( + self.x_train, + self.y_train, + validation_data=(self.x_test, self.y_test), + callbacks=[reduce_lr], + epochs=2, + ) + expected_log = "ReduceLROnPlateau reducing learning rate to 0.01" + self.assertTrue(any(expected_log in log for log in logs.output)) + + @pytest.mark.requires_trainable_backend + def test_honors_min_lr(self): + reduce_lr = callbacks.ReduceLROnPlateau( + patience=1, + factor=0.1, + monitor="val_loss", + min_delta=10, + min_lr=0.005, + ) + + self.model.fit( + self.x_train, + self.y_train, + validation_data=(self.x_test, self.y_test), + callbacks=[reduce_lr], + epochs=4, + ) + + self.assertEqual(self.model.optimizer.learning_rate.value, 0.005) + + @pytest.mark.requires_trainable_backend + def test_cooldown(self): + reduce_lr = callbacks.ReduceLROnPlateau( + patience=1, + factor=0.1, + monitor="val_loss", + min_delta=100, + cooldown=2, + ) + + self.model.fit( + self.x_train, + self.y_train, + validation_data=(self.x_test, self.y_test), + callbacks=[reduce_lr], + epochs=4, + ) + + # With a cooldown of 2 epochs, we should only reduce the LR every other + # epoch, so after 4 epochs we will have reduced 2 times. + self.assertAllClose(self.model.optimizer.learning_rate.value, 0.001) diff --git a/keras/src/callbacks/remote_monitor.py b/keras/src/callbacks/remote_monitor.py new file mode 100644 index 000000000000..f8605a5c1726 --- /dev/null +++ b/keras/src/callbacks/remote_monitor.py @@ -0,0 +1,83 @@ +import json +import warnings + +import numpy as np + +from keras.src.api_export import keras_export +from keras.src.callbacks.callback import Callback + +try: + import requests +except ImportError: + requests = None + + +@keras_export("keras.callbacks.RemoteMonitor") +class RemoteMonitor(Callback): + """Callback used to stream events to a server. + + Requires the `requests` library. + Events are sent to `root + '/publish/epoch/end/'` by default. Calls are + HTTP POST, with a `data` argument which is a + JSON-encoded dictionary of event data. + If `send_as_json=True`, the content type of the request will be + `"application/json"`. + Otherwise the serialized JSON will be sent within a form. + + Args: + root: String; root url of the target server. + path: String; path relative to `root` to which the events will be sent. + field: String; JSON field under which the data will be stored. + The field is used only if the payload is sent within a form + (i.e. when `send_as_json=False`). + headers: Dictionary; optional custom HTTP headers. + send_as_json: Boolean; whether the request should be + sent as `"application/json"`. + """ + + def __init__( + self, + root="http://localhost:9000", + path="/publish/epoch/end/", + field="data", + headers=None, + send_as_json=False, + ): + super().__init__() + + self.root = root + self.path = path + self.field = field + self.headers = headers + self.send_as_json = send_as_json + + def on_epoch_end(self, epoch, logs=None): + if requests is None: + raise ImportError("RemoteMonitor requires the `requests` library.") + logs = logs or {} + send = {} + send["epoch"] = epoch + for k, v in logs.items(): + # np.ndarray and np.generic are not scalar types + # therefore we must unwrap their scalar values and + # pass to the json-serializable dict 'send' + if isinstance(v, (np.ndarray, np.generic)): + send[k] = v.item() + else: + send[k] = v + try: + if self.send_as_json: + requests.post( + self.root + self.path, json=send, headers=self.headers + ) + else: + requests.post( + self.root + self.path, + {self.field: json.dumps(send)}, + headers=self.headers, + ) + except requests.exceptions.RequestException: + warnings.warn( + f"Could not reach RemoteMonitor root server at {self.root}", + stacklevel=2, + ) diff --git a/keras/src/callbacks/remote_monitor_test.py b/keras/src/callbacks/remote_monitor_test.py new file mode 100644 index 000000000000..bc77aa6c9788 --- /dev/null +++ b/keras/src/callbacks/remote_monitor_test.py @@ -0,0 +1,104 @@ +import warnings +from unittest import mock + +import numpy as np + +from conftest import skip_if_backend +from keras.src import backend +from keras.src import callbacks +from keras.src import layers +from keras.src import testing +from keras.src.models import Sequential +from keras.src.utils import numerical_utils + +try: + import requests +except ImportError: + requests = None + + +class TerminateOnNaNTest(testing.TestCase): + def test_RemoteMonitor(self): + if requests is None: + self.skipTest("`requests` required to run this test") + + monitor = callbacks.RemoteMonitor() + # This will raise a warning since the default address in unreachable: + warning_msg = "Could not reach RemoteMonitor root server" + with warnings.catch_warnings(record=True) as warning_logs: + warnings.simplefilter("always") + monitor.on_epoch_end(0, logs={"loss": 0.0}) + self.assertIn(warning_msg, str(warning_logs[-1].message)) + + def test_RemoteMonitor_np_array(self): + if requests is None: + self.skipTest("`requests` required to run this test") + + with mock.patch("requests.post") as requests_post: + monitor = callbacks.RemoteMonitor(send_as_json=True) + a = np.arange(1) # a 1 by 1 array + logs = {"loss": 0.0, "val": a} + monitor.on_epoch_end(0, logs=logs) + send = {"loss": 0.0, "epoch": 0, "val": 0} + requests_post.assert_called_once_with( + monitor.root + monitor.path, json=send, headers=monitor.headers + ) + + def test_RemoteMonitor_np_float32(self): + if requests is None: + self.skipTest("`requests` required to run this test") + + with mock.patch("requests.post") as requests_post: + monitor = callbacks.RemoteMonitor(send_as_json=True) + a = np.float32(1.0) # a float32 generic type + logs = {"loss": 0.0, "val": a} + monitor.on_epoch_end(0, logs=logs) + send = {"loss": 0.0, "epoch": 0, "val": 1.0} + requests_post.assert_called_once_with( + monitor.root + monitor.path, json=send, headers=monitor.headers + ) + + @skip_if_backend( + "openvino", "openvino backend does not support `fit` method" + ) + def test_RemoteMonitorWithJsonPayload(self): + if requests is None: + self.skipTest("`requests` required to run this test") + + if backend.backend() == "numpy": + self.skipTest("Trainer not implemented from NumPy backend.") + TRAIN_SAMPLES = 10 + TEST_SAMPLES = 10 + INPUT_DIM = 3 + NUM_CLASSES = 2 + BATCH_SIZE = 4 + + np.random.seed(1337) + x_train = np.random.random((TRAIN_SAMPLES, INPUT_DIM)) + y_train = np.random.choice(np.arange(NUM_CLASSES), size=TRAIN_SAMPLES) + x_test = np.random.random((TEST_SAMPLES, INPUT_DIM)) + y_test = np.random.choice(np.arange(NUM_CLASSES), size=TEST_SAMPLES) + y_test = numerical_utils.to_categorical(y_test) + y_train = numerical_utils.to_categorical(y_train) + + model = Sequential([layers.Dense(NUM_CLASSES)]) + model.compile(loss="mean_squared_error", optimizer="sgd") + + with mock.patch("requests.post") as requests_post: + monitor = callbacks.RemoteMonitor(send_as_json=True) + hist = model.fit( + x_train, + y_train, + batch_size=BATCH_SIZE, + validation_data=(x_test, y_test), + callbacks=[monitor], + epochs=1, + ) + send = { + "epoch": 0, + "loss": hist.history["loss"][0], + "val_loss": hist.history["val_loss"][0], + } + requests_post.assert_called_once_with( + monitor.root + monitor.path, json=send, headers=monitor.headers + ) diff --git a/keras/src/callbacks/swap_ema_weights.py b/keras/src/callbacks/swap_ema_weights.py new file mode 100644 index 000000000000..9c13a90fff53 --- /dev/null +++ b/keras/src/callbacks/swap_ema_weights.py @@ -0,0 +1,180 @@ +from keras.src import backend +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.callbacks.callback import Callback + + +@keras_export("keras.callbacks.SwapEMAWeights") +class SwapEMAWeights(Callback): + """Swaps model weights and EMA weights before and after evaluation. + + This callbacks replaces the model's weight values with the values of + the optimizer's EMA weights (the exponential moving average of the past + model weights values, implementing "Polyak averaging") before model + evaluation, and restores the previous weights after evaluation. + + The `SwapEMAWeights` callback is to be used in conjunction with + an optimizer that sets `use_ema=True`. + + Note that the weights are swapped in-place in order to save memory. + The behavior is undefined if you modify the EMA weights + or model weights in other callbacks. + + Example: + + ```python + # Remember to set `use_ema=True` in the optimizer + optimizer = SGD(use_ema=True) + model.compile(optimizer=optimizer, loss=..., metrics=...) + + # Metrics will be computed with EMA weights + model.fit(X_train, Y_train, callbacks=[SwapEMAWeights()]) + + # If you want to save model checkpoint with EMA weights, you can set + # `swap_on_epoch=True` and place ModelCheckpoint after SwapEMAWeights. + model.fit( + X_train, + Y_train, + callbacks=[SwapEMAWeights(swap_on_epoch=True), ModelCheckpoint(...)] + ) + ``` + + Args: + swap_on_epoch: whether to perform swapping at `on_epoch_begin()` + and `on_epoch_end()`. This is useful if you want to use + EMA weights for other callbacks such as `ModelCheckpoint`. + Defaults to `False`. + """ + + def __init__(self, swap_on_epoch=False): + super().__init__() + self.swap_on_epoch = swap_on_epoch + + self._ema_weights_in_model = False + + def _tf_swap_variables(self, optimizer): + for var, average_var in zip( + self.model.trainable_variables, + optimizer._model_variables_moving_average, + ): + if isinstance(var, backend.Variable): + var = var.value + if isinstance(average_var, backend.Variable): + average_var = average_var.value + # swap using addition to prevent variable creation + optimizer._distribution_strategy.extended.update( + var, + lambda a, b: a.assign_add(b), + args=(average_var,), + ) + optimizer._distribution_strategy.extended.update( + var, + lambda a, b: b.assign(a - b), + args=(average_var,), + ) + optimizer._distribution_strategy.extended.update( + var, + lambda a, b: a.assign(a - b), + args=(average_var,), + ) + + def _backend_swap_variables(self, optimizer): + for var, average_var in zip( + self.model.trainable_variables, + optimizer._model_variables_moving_average, + ): + temporary_variable = ops.convert_to_numpy(var) + var.assign(average_var) + average_var.assign(temporary_variable) + + def _tf_finalize_ema_values(self, optimizer): + for var, average_var in zip( + self.model.trainable_variables, + optimizer._model_variables_moving_average, + ): + if isinstance(var, backend.Variable): + var = var.value + if isinstance(average_var, backend.Variable): + average_var = average_var.value + optimizer._distribution_strategy.extended.update( + average_var, + lambda a, b: a.assign(b), + args=(var,), + ) + + def _backend_finalize_ema_values(self, optimizer): + for var, average_var in zip( + self.model.trainable_variables, + optimizer._model_variables_moving_average, + ): + average_var.assign(var) + + def _swap_variables(self): + if hasattr(self.model.optimizer, "inner_optimizer"): + # LossScaleOptimizer + optimizer = self.model.optimizer.inner_optimizer + else: + optimizer = self.model.optimizer + if not hasattr(optimizer, "_model_variables_moving_average"): + raise ValueError( + "SwapEMAWeights must be used when " + "`use_ema=True` is set on the optimizer. " + f"Received: use_ema={optimizer.use_ema}" + ) + if backend.backend() == "tensorflow": + self._tf_swap_variables(optimizer) + else: + self._backend_swap_variables(optimizer) + + def _finalize_ema_values(self): + if hasattr(self.model.optimizer, "inner_optimizer"): + # LossScaleOptimizer + optimizer = self.model.optimizer.inner_optimizer + else: + optimizer = self.model.optimizer + if not hasattr(optimizer, "_model_variables_moving_average"): + raise ValueError( + "SwapEMAWeights must be used when " + "`use_ema=True` is set on the optimizer. " + f"Received: use_ema={optimizer.use_ema}" + ) + if backend.backend() == "tensorflow": + self._tf_finalize_ema_values(optimizer) + else: + self._backend_finalize_ema_values(optimizer) + + def on_epoch_begin(self, epoch, logs=None): + if self.swap_on_epoch and self._ema_weights_in_model: + self._swap_variables() + self._ema_weights_in_model = False + + def on_epoch_end(self, epoch, logs=None): + if self.swap_on_epoch and not self._ema_weights_in_model: + self._swap_variables() + self._ema_weights_in_model = True + # We need to recover EMA weights from the previously swapped weights + # in the last epoch. This is because, at the end of the fitting, + # `finalize_variable_values` will be called to assign + # `_model_variables_moving_average` to `trainable_variables`. + if epoch == self.params["epochs"] - 1: + self._finalize_ema_values() + + def on_test_begin(self, logs=None): + if not self._ema_weights_in_model: + self._swap_variables() + self._ema_weights_in_model = True + + def on_test_end(self, logs=None): + if self._ema_weights_in_model: + self._swap_variables() + self._ema_weights_in_model = False + + def on_predict_begin(self, logs=None): + if not self._ema_weights_in_model: + self._swap_variables() + self._ema_weights_in_model = True + + def on_predict_end(self, logs=None): + if not self._ema_weights_in_model: + self._swap_variables() + self._ema_weights_in_model = False diff --git a/keras/src/callbacks/swap_ema_weights_test.py b/keras/src/callbacks/swap_ema_weights_test.py new file mode 100644 index 000000000000..795f1452a189 --- /dev/null +++ b/keras/src/callbacks/swap_ema_weights_test.py @@ -0,0 +1,188 @@ +import os.path +import tempfile + +import pytest +import tensorflow as tf +from tensorflow.python.eager import context + +from keras.src import backend +from keras.src import callbacks +from keras.src import layers +from keras.src import losses +from keras.src import metrics +from keras.src import optimizers +from keras.src import saving +from keras.src import testing +from keras.src.models import Sequential +from keras.src.testing import test_utils +from keras.src.utils import numerical_utils + + +class SwapEMAWeightsTest(testing.TestCase): + def setUp(self): + (x_train, y_train), _ = test_utils.get_test_data( + train_samples=10, + test_samples=10, + input_shape=(3,), + num_classes=2, + random_seed=2023, + ) + y_train = numerical_utils.to_categorical(y_train) + + self.x_train = x_train + self.y_train = y_train + + def _get_compiled_model( + self, use_ema=True, jit_compile=True, loss_scale=False + ): + optimizer = optimizers.SGD(use_ema=use_ema, ema_momentum=0.9) + if loss_scale: + optimizer = optimizers.LossScaleOptimizer(optimizer) + model = Sequential( + [layers.Dense(2, kernel_initializer="ones", use_bias=False)] + ) + model.compile( + optimizer=optimizer, + loss=losses.MeanSquaredError(), + metrics=[metrics.MeanSquaredError()], + jit_compile=jit_compile, + ) + return model + + @pytest.mark.requires_trainable_backend + def test_swap_ema_weights_with_invalid_optimizer(self): + model = self._get_compiled_model(use_ema=False) + with self.assertRaisesRegex( + ValueError, + ("SwapEMAWeights must be used when `use_ema=True` is set"), + ): + model.fit( + self.x_train, + self.y_train, + epochs=2, + callbacks=[callbacks.SwapEMAWeights()], + validation_data=(self.x_train, self.y_train), + ) + + @pytest.mark.requires_trainable_backend + def test_swap_ema_weights(self): + # not using SwapEMAWeights + model = self._get_compiled_model() + history = model.fit( + self.x_train, + self.y_train, + epochs=2, + validation_data=(self.x_train, self.y_train), + ) + logs = model.evaluate(self.x_train, self.y_train, return_dict=True) + # final metric during fitting is different from the evaluation + self.assertNotEqual( + history.history["val_mean_squared_error"][-1], + logs["mean_squared_error"], + ) + + # using SwapEMAWeights + model = self._get_compiled_model() + history = model.fit( + self.x_train, + self.y_train, + epochs=2, + callbacks=[callbacks.SwapEMAWeights()], + validation_data=(self.x_train, self.y_train), + ) + logs = model.evaluate(self.x_train, self.y_train, return_dict=True) + # final metric during fitting is same as the evaluation + self.assertEqual( + history.history["val_mean_squared_error"][-1], + logs["mean_squared_error"], + ) + + @pytest.mark.requires_trainable_backend + def test_swap_ema_weights_on_epoch(self): + # using SwapEMAWeights together with ModelCheckpoint + model = self._get_compiled_model() + with tempfile.TemporaryDirectory() as temp_dir: + model.fit( + self.x_train, + self.y_train, + epochs=2, + callbacks=[ + callbacks.SwapEMAWeights(swap_on_epoch=True), + callbacks.ModelCheckpoint( + os.path.join(temp_dir, "{epoch:1d}.keras") + ), + ], + validation_data=(self.x_train, self.y_train), + ) + model2 = saving.load_model(os.path.join(temp_dir, "2.keras")) + + logs = model.evaluate(self.x_train, self.y_train, return_dict=True) + logs2 = model2.evaluate(self.x_train, self.y_train, return_dict=True) + # saved checkpoint will be applied by EMA weights + self.assertEqual( + logs["mean_squared_error"], + logs2["mean_squared_error"], + ) + + @pytest.mark.requires_trainable_backend + def test_swap_ema_weights_with_loss_scale_optimizer(self): + model = self._get_compiled_model(loss_scale=True) + history = model.fit( + self.x_train, + self.y_train, + epochs=2, + callbacks=[callbacks.SwapEMAWeights()], + validation_data=(self.x_train, self.y_train), + ) + logs = model.evaluate(self.x_train, self.y_train, return_dict=True) + # final metric during fitting is same as the evaluation + self.assertEqual( + history.history["val_mean_squared_error"][-1], + logs["mean_squared_error"], + ) + + @pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="The distribute test can only run with TF backend.", + ) + def test_swap_ema_weights_with_tf_distribute(self): + # Need at least 2 devices for distribution related tests. + cpus = tf.config.list_physical_devices("CPU") + context._reset_context() + tf.config.set_logical_device_configuration( + cpus[0], + [ + tf.config.LogicalDeviceConfiguration(), + tf.config.LogicalDeviceConfiguration(), + ], + ) + strategy = tf.distribute.MirroredStrategy(["CPU:0", "CPU:1"]) + with strategy.scope(): + # TODO: set jit_compile=True once the issue is resolved in + # integration_tests/tf_distribute_training_test.py#L52 + model = self._get_compiled_model(jit_compile=False) + with tempfile.TemporaryDirectory() as temp_dir: + model.fit( + self.x_train, + self.y_train, + epochs=2, + callbacks=[ + callbacks.SwapEMAWeights(swap_on_epoch=True), + callbacks.ModelCheckpoint( + os.path.join( + temp_dir, "distributed_{epoch:1d}.keras" + ) + ), + ], + validation_data=(self.x_train, self.y_train), + ) + model2 = saving.load_model( + os.path.join(temp_dir, "distributed_2.keras") + ) + logs = model.evaluate(self.x_train, self.y_train, return_dict=True) + logs2 = model2.evaluate(self.x_train, self.y_train, return_dict=True) + # saved checkpoint will be applied by EMA weights + self.assertEqual( + logs["mean_squared_error"], + logs2["mean_squared_error"], + ) diff --git a/keras/src/callbacks/tensorboard.py b/keras/src/callbacks/tensorboard.py new file mode 100644 index 000000000000..506c8d6dafb4 --- /dev/null +++ b/keras/src/callbacks/tensorboard.py @@ -0,0 +1,688 @@ +import logging +import os +import sys +import time +import warnings + +from keras.src import backend +from keras.src import ops +from keras.src import tree +from keras.src.api_export import keras_export +from keras.src.callbacks.callback import Callback +from keras.src.layers import Embedding +from keras.src.optimizers import Optimizer +from keras.src.utils import file_utils + + +@keras_export("keras.callbacks.TensorBoard") +class TensorBoard(Callback): + """Enable visualizations for TensorBoard. + + TensorBoard is a visualization tool provided with TensorFlow. A TensorFlow + installation is required to use this callback. + + This callback logs events for TensorBoard, including: + + * Metrics summary plots + * Training graph visualization + * Weight histograms + * Sampled profiling + + When used in `model.evaluate()` or regular validation + in addition to epoch summaries, there will be a summary that records + evaluation metrics vs `model.optimizer.iterations` written. The metric names + will be prepended with `evaluation`, with `model.optimizer.iterations` being + the step in the visualized TensorBoard. + + If you have installed TensorFlow with pip, you should be able + to launch TensorBoard from the command line: + + ``` + tensorboard --logdir=path_to_your_logs + ``` + + You can find more information about TensorBoard + [here](https://www.tensorflow.org/get_started/summaries_and_tensorboard). + + Args: + log_dir: the path of the directory where to save the log files to be + parsed by TensorBoard. e.g., + `log_dir = os.path.join(working_dir, 'logs')`. + This directory should not be reused by any other callbacks. + histogram_freq: frequency (in epochs) at which to compute + weight histograms for the layers of the model. If set to 0, + histograms won't be computed. Validation data (or split) must be + specified for histogram visualizations. + write_graph: (Not supported at this time) + Whether to visualize the graph in TensorBoard. + Note that the log file can become quite large + when `write_graph` is set to `True`. + write_images: whether to write model weights to visualize as image in + TensorBoard. + write_steps_per_second: whether to log the training steps per second + into TensorBoard. This supports both epoch and batch frequency + logging. + update_freq: `"batch"` or `"epoch"` or integer. When using `"epoch"`, + writes the losses and metrics to TensorBoard after every epoch. + If using an integer, let's say `1000`, all metrics and losses + (including custom ones added by `Model.compile`) will be logged to + TensorBoard every 1000 batches. `"batch"` is a synonym for 1, + meaning that they will be written every batch. + Note however that writing too frequently to TensorBoard can slow + down your training, especially when used with distribution + strategies as it will incur additional synchronization overhead. + Batch-level summary writing is also available via `train_step` + override. Please see + [TensorBoard Scalars tutorial]( + https://www.tensorflow.org/tensorboard/scalars_and_keras#batch-level_logging) + for more details. + profile_batch: Profile the batch(es) to sample compute characteristics. + profile_batch must be a non-negative integer or a tuple of integers. + A pair of positive integers signify a range of batches to profile. + By default, profiling is disabled. + embeddings_freq: frequency (in epochs) at which embedding layers will be + visualized. If set to 0, embeddings won't be visualized. + embeddings_metadata: Dictionary which maps embedding layer names to the + filename of a file in which to save metadata for the embedding layer. + In case the same metadata file is to be + used for all embedding layers, a single filename can be passed. + + Examples: + + ```python + tensorboard_callback = keras.callbacks.TensorBoard(log_dir="./logs") + model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback]) + # Then run the tensorboard command to view the visualizations. + ``` + + Custom batch-level summaries in a subclassed Model: + + ```python + class MyModel(keras.Model): + + def build(self, _): + self.dense = keras.layers.Dense(10) + + def call(self, x): + outputs = self.dense(x) + tf.summary.histogram('outputs', outputs) + return outputs + + model = MyModel() + model.compile('sgd', 'mse') + + # Make sure to set `update_freq=N` to log a batch-level summary every N + # batches. In addition to any `tf.summary` contained in `model.call()`, + # metrics added in `Model.compile` will be logged every N batches. + tb_callback = keras.callbacks.TensorBoard('./logs', update_freq=1) + model.fit(x_train, y_train, callbacks=[tb_callback]) + ``` + + Custom batch-level summaries in a Functional API Model: + + ```python + def my_summary(x): + tf.summary.histogram('x', x) + return x + + inputs = keras.Input(10) + x = keras.layers.Dense(10)(inputs) + outputs = keras.layers.Lambda(my_summary)(x) + model = keras.Model(inputs, outputs) + model.compile('sgd', 'mse') + + # Make sure to set `update_freq=N` to log a batch-level summary every N + # batches. In addition to any `tf.summary` contained in `Model.call`, + # metrics added in `Model.compile` will be logged every N batches. + tb_callback = keras.callbacks.TensorBoard('./logs', update_freq=1) + model.fit(x_train, y_train, callbacks=[tb_callback]) + ``` + + Profiling: + + ```python + # Profile a single batch, e.g. the 5th batch. + tensorboard_callback = keras.callbacks.TensorBoard( + log_dir='./logs', profile_batch=5) + model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback]) + + # Profile a range of batches, e.g. from 10 to 20. + tensorboard_callback = keras.callbacks.TensorBoard( + log_dir='./logs', profile_batch=(10,20)) + model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback]) + ``` + """ # noqa: E501 + + def __init__( + self, + log_dir="logs", + histogram_freq=0, + write_graph=True, + write_images=False, + write_steps_per_second=False, + update_freq="epoch", + profile_batch=0, + embeddings_freq=0, + embeddings_metadata=None, + ): + super().__init__() + + self.log_dir = str(log_dir) + self.histogram_freq = histogram_freq + self.write_graph = write_graph + self.write_images = write_images + self.write_steps_per_second = write_steps_per_second + self.update_freq = 1 if update_freq == "batch" else update_freq + self.embeddings_freq = embeddings_freq + self.embeddings_metadata = embeddings_metadata + if profile_batch: + if backend.backend() not in ("jax", "tensorflow"): + # TODO: profiling not available in torch, numpy + raise ValueError( + "Profiling is not yet available with the " + f"{backend.backend()} backend. Please open a PR " + "if you'd like to add this feature. Received: " + f"profile_batch={profile_batch} (must be 0)" + ) + elif backend.backend() == "jax": + if sys.version_info[1] < 12: + warnings.warn( + "Profiling with the " + f"{backend.backend()} backend requires python >= 3.12." + ) + profile_batch = 0 + + self._init_profile_batch(profile_batch) + self._global_train_batch = 0 + self._global_test_batch = 0 + self._previous_epoch_iterations = 0 + self._train_accumulated_time = 0 + self._batch_start_time = 0 + self._summary_module = None + + # Lazily initialized in order to avoid creating event files when + # not needed. + self._writers = {} + + # Used to restore any existing `SummaryWriter` after training ends. + self._prev_summary_state = [] + + def set_model(self, model): + """Sets Keras model and writes graph if specified.""" + self._model = model + self._log_write_dir = self.log_dir + + self._train_dir = os.path.join(self._log_write_dir, "train") + self._val_dir = os.path.join(self._log_write_dir, "validation") + self._writers = {} # Resets writers. + + self._should_write_train_graph = False + if self.write_graph: + self._write_keras_model_summary() + self._should_write_train_graph = True + if self.embeddings_freq: + self._configure_embeddings() + + @property + def summary(self): + if self._summary_module is None: + import tensorflow.summary as summary + + self._summary_module = summary + return self._summary_module + + @property + def _train_writer(self): + if "train" not in self._writers: + self._writers["train"] = self.summary.create_file_writer( + self._train_dir + ) + return self._writers["train"] + + @property + def _val_writer(self): + if "val" not in self._writers: + self._writers["val"] = self.summary.create_file_writer( + self._val_dir + ) + return self._writers["val"] + + def _write_keras_model_train_graph(self): + """Writes Keras model train_function graph to TensorBoard.""" + with self._train_writer.as_default(): + train_fn = self.model.train_function + # If the train_function is a `tf.function`, we can write out a + # graph + if hasattr(train_fn, "function_spec"): + # TODO(b/243822285): Use _variable_creation_fn directly. + if hasattr(train_fn, "_concrete_stateful_fn"): + self.summary.graph(train_fn._concrete_stateful_fn.graph) + else: + self.summary.graph( + train_fn._concrete_variable_creation_fn.graph + ) + + def _write_keras_model_summary(self): + """Writes Keras graph network summary to TensorBoard.""" + with self._train_writer.as_default(): + if ( + self.model.__class__.__name__ == "Functional" + or self.model.__class__.__name__ == "Sequential" + ): + keras_model_summary("keras", self.model, step=0) + + def _configure_embeddings(self): + """Configure the Projector for embeddings.""" + from google.protobuf import text_format + from tensorboard.plugins import projector + + config = projector.ProjectorConfig() + for layer in self.model.layers: + if isinstance(layer, Embedding): + embedding = config.embeddings.add() + # Embeddings are always the first layer, so this naming should + # be consistent in any keras models checkpoints. + name = ( + "layer_with_weights-0/embeddings/.ATTRIBUTES/VARIABLE_VALUE" + ) + embedding.tensor_name = name + + if self.embeddings_metadata is not None: + if isinstance(self.embeddings_metadata, str): + embedding.metadata_path = self.embeddings_metadata + else: + if layer.name in self.embeddings_metadata.keys(): + embedding.metadata_path = ( + self.embeddings_metadata.pop(layer.name) + ) + + if self.embeddings_metadata and not isinstance( + self.embeddings_metadata, str + ): + raise ValueError( + "Unrecognized `Embedding` layer names passed to " + "`keras.callbacks.TensorBoard` `embeddings_metadata` " + f"argument: {self.embeddings_metadata.keys()}" + ) + + config_pbtxt = text_format.MessageToString(config) + path = os.path.join(self._log_write_dir, "projector_config.pbtxt") + with file_utils.File(path, "w") as f: + f.write(config_pbtxt) + + def _push_writer(self, writer, step): + """Sets the default writer for custom batch-level summaries.""" + if self.update_freq == "epoch": + return + + def should_record(): + return step % self.update_freq == 0 + + summary_context = ( + writer.as_default(step), + self.summary.record_if(should_record), + ) + self._prev_summary_state.append(summary_context) + summary_context[0].__enter__() + summary_context[1].__enter__() + + def _pop_writer(self): + """Pops the current writer.""" + if self.update_freq == "epoch": + return + + # See _push_writer for the content of the previous_context, which is + # pair of context. + previous_context = self._prev_summary_state.pop() + previous_context[1].__exit__(*sys.exc_info()) + previous_context[0].__exit__(*sys.exc_info()) + + def _close_writers(self): + for writer in self._writers.values(): + writer.close() + + def _init_profile_batch(self, profile_batch): + """Validate profile_batch value and set the range of batches to profile. + + Sets values of _start_batch and _stop_batch attributes, + specifying the start and stop batch to profile. + Setting `profile_batch=0` disables profiling. + + Args: + profile_batch: The range of batches to profile. Should be a + non-negative integer or a comma separated string of pair of positive + integers. A pair of positive integers signify a range of batches to + profile. + + Raises: + ValueError: If profile_batch is not an integer or a comma separated + pair of positive integers. + + """ + profile_batch_error_message = ( + "profile_batch must be a non-negative integer or " + "2-tuple of positive " + "integers. A pair of positive integers " + "signifies a range of batches " + f"to profile. Found: {profile_batch}" + ) + + # Support legacy way of specifying "start,stop" or "start" as str. + if isinstance(profile_batch, str): + profile_batch = str(profile_batch).split(",") + profile_batch = tree.map_structure(int, profile_batch) + + if isinstance(profile_batch, int): + self._start_batch = profile_batch + self._stop_batch = profile_batch + elif ( + isinstance(profile_batch, (tuple, list)) and len(profile_batch) == 2 + ): + self._start_batch, self._stop_batch = profile_batch + else: + raise ValueError(profile_batch_error_message) + + if self._start_batch < 0 or self._stop_batch < self._start_batch: + raise ValueError(profile_batch_error_message) + + # True when the profiler was successfully started by this callback. + # We track the status here to make sure callbacks do not interfere with + # each other. The callback will only stop the profiler it started. + self._profiler_started = False + self._batch_trace_context = None + + if self._start_batch > 0: + # Warm up and improve the profiling accuracy. + self._start_profiler(logdir="") + self._stop_profiler(save=False) + # True when a trace is running. + self._is_tracing = False + + # Setting `profile_batch=0` disables profiling. + self._should_trace = not ( + self._start_batch == 0 and self._stop_batch == 0 + ) + + def on_train_begin(self, logs=None): + self._global_train_batch = 0 + self._previous_epoch_iterations = 0 + self._push_writer(self._train_writer, self._global_train_batch) + + def on_train_end(self, logs=None): + self._pop_writer() + + if self._is_tracing: + self._stop_trace() + + self._close_writers() + + def on_test_begin(self, logs=None): + self._push_writer(self._val_writer, self._global_test_batch) + + def on_test_end(self, logs=None): + if self.model.optimizer and hasattr(self.model.optimizer, "iterations"): + with self._val_writer.as_default(): + for name, value in logs.items(): + self.summary.scalar( + f"evaluation_{name}_vs_iterations", + value, + step=self.model.optimizer.iterations, + ) + self._pop_writer() + + def on_train_batch_begin(self, batch, logs=None): + self._global_train_batch += 1 + if self.write_steps_per_second: + self._batch_start_time = time.time() + if not self._should_trace: + return + + if self._global_train_batch == self._start_batch: + self._start_trace() + if self._profiler_started: + self._batch_trace_context = backend.tensorboard.start_batch_trace( + batch + ) + + def on_train_batch_end(self, batch, logs=None): + if self._should_write_train_graph: + self._write_keras_model_train_graph() + self._should_write_train_graph = False + if self.write_steps_per_second: + batch_run_time = time.time() - self._batch_start_time + self.summary.scalar( + "batch_steps_per_second", + 1.0 / batch_run_time, + step=self._global_train_batch, + ) + + # `logs` isn't necessarily always a dict + if isinstance(logs, dict): + for name, value in logs.items(): + self.summary.scalar( + f"batch_{name}", value, step=self._global_train_batch + ) + + if not self._should_trace: + return + + if self._is_tracing: + if self._profiler_started and self._batch_trace_context is not None: + backend.tensorboard.stop_batch_trace(self._batch_trace_context) + self._batch_trace_context = None + if self._global_train_batch >= self._stop_batch: + self._stop_trace() + + def on_test_batch_begin(self, batch, logs=None): + self._global_test_batch += 1 + + def on_epoch_begin(self, epoch, logs=None): + # Keeps track of epoch for profiling. + if self.write_steps_per_second: + self._previous_epoch_iterations = ops.convert_to_tensor( + self.model.optimizer.iterations, "float32" + ) + self._epoch_start_time = time.time() + + def on_epoch_end(self, epoch, logs=None): + """Runs metrics and histogram summaries at epoch end.""" + self._log_epoch_metrics(epoch, logs) + + if self.histogram_freq and epoch % self.histogram_freq == 0: + self._log_weights(epoch) + + if self.embeddings_freq and epoch % self.embeddings_freq == 0: + self._log_embeddings(epoch) + + def _start_trace(self): + self.summary.trace_on(graph=True, profiler=False) + self._start_profiler(logdir=self._train_dir) + self._is_tracing = True + + def _stop_trace(self, batch=None): + """Logs the trace graph to TensorBoard.""" + if batch is None: + batch = self._stop_batch + with self._train_writer.as_default(): + # TODO(b/126388999): Remove step info in the summary name. + self.summary.trace_export(name="batch_%d" % batch, step=batch) + self._stop_profiler() + self._is_tracing = False + + def _collect_learning_rate(self, logs): + if isinstance(self.model.optimizer, Optimizer): + logs["learning_rate"] = float( + ops.convert_to_numpy(self.model.optimizer.learning_rate) + ) + return logs + + def _compute_steps_per_second(self): + current_iteration = self.model.optimizer.iterations + time_since_epoch_begin = time.time() - self._epoch_start_time + current_iteration = ops.convert_to_tensor(current_iteration, "float32") + time_since_epoch_begin = ops.convert_to_tensor( + time_since_epoch_begin, "float32" + ) + + steps_per_second = ( + current_iteration - self._previous_epoch_iterations + ) / time_since_epoch_begin + return float(steps_per_second) + + def _log_epoch_metrics(self, epoch, logs): + """Writes epoch metrics out as scalar summaries. + + Args: + epoch: Int. The global step to use for TensorBoard. + logs: Dict. Keys are scalar summary names, values are scalars. + """ + if not logs: + return + + train_logs = {k: v for k, v in logs.items() if not k.startswith("val_")} + val_logs = {k: v for k, v in logs.items() if k.startswith("val_")} + train_logs = self._collect_learning_rate(train_logs) + if self.write_steps_per_second: + train_logs["steps_per_second"] = self._compute_steps_per_second() + + if train_logs: + with self._train_writer.as_default(): + for name, value in train_logs.items(): + self.summary.scalar(f"epoch_{name}", value, step=epoch) + if val_logs: + with self._val_writer.as_default(): + for name, value in val_logs.items(): + name = name[4:] # Remove 'val_' prefix. + self.summary.scalar(f"epoch_{name}", value, step=epoch) + + def _log_weights(self, epoch): + """Logs the weights of the Model to TensorBoard.""" + with self._train_writer.as_default(): + for layer in self.model.layers: + for weight in layer.weights: + weight_name = weight.name.replace(":", "_") + # Add a suffix to prevent summary tag name collision. + histogram_weight_name = f"{weight_name}/histogram" + self.summary.histogram( + histogram_weight_name, weight, step=epoch + ) + if self.write_images: + # Add a suffix to prevent summary tag name + # collision. + image_weight_name = f"{weight_name}/image" + self._log_weight_as_image( + weight, image_weight_name, epoch + ) + self._train_writer.flush() + + def _log_weight_as_image(self, weight, weight_name, epoch): + """Logs a weight as a TensorBoard image.""" + w_img = ops.squeeze(weight) + shape = w_img.shape + if len(shape) == 1: # Bias case + w_img = ops.reshape(w_img, [1, shape[0], 1, 1]) + elif len(shape) == 2: # Dense layer kernel case + if shape[0] > shape[1]: + w_img = ops.transpose(w_img) + shape = w_img.shape + w_img = ops.reshape(w_img, [1, shape[0], shape[1], 1]) + elif len(shape) == 3: # ConvNet case + if backend.image_data_format() == "channels_last": + # Switch to channels_first to display every kernel as a separate + # image. + w_img = ops.transpose(w_img, [2, 0, 1]) + shape = w_img.shape + w_img = ops.reshape(w_img, [shape[0], shape[1], shape[2], 1]) + + w_img = backend.convert_to_numpy(w_img) + shape = w_img.shape + # Not possible to handle 3D convnets etc. + if len(shape) == 4 and shape[-1] in [1, 3, 4]: + self.summary.image(weight_name, w_img, step=epoch) + + def _log_embeddings(self, epoch): + embeddings_ckpt = os.path.join( + self._log_write_dir, + "train", + f"keras_embedding.ckpt-{epoch}.weights.h5", + ) + self.model.save_weights(embeddings_ckpt) + + def _start_profiler(self, logdir): + """Starts the profiler if currently inactive. + + Args: + logdir: Directory where profiler results will be saved. + """ + if self._profiler_started: + return + try: + backend.tensorboard.start_trace(logdir) + self._profiler_started = True + except Exception as e: + # Profiler errors should not be fatal. + logging.error("Failed to start profiler: %s", e) + + def _stop_profiler(self, save=True): + """Stops the profiler if currently active. + + Args: + save: Whether to save the profiler results to TensorBoard. + """ + if not self._profiler_started: + return + try: + backend.tensorboard.stop_trace(save=save) + except Exception as e: + # Profiler errors should not be fatal. + logging.error("Failed to stop profiler: %s", e) + finally: + self._profiler_started = False + + +def keras_model_summary(name, data, step=None): + """Writes a Keras model as JSON to as a Summary. + + Writing the Keras model configuration allows the TensorBoard graph plugin to + render a conceptual graph, as opposed to graph of ops. In case the model + fails to serialize as JSON, it ignores and returns False. + + Args: + name: A name for this summary. The summary tag used for TensorBoard will + be this name prefixed by any active name scopes. + data: A Keras Model to write. + step: Explicit `int64`-castable monotonic step value for this summary. + If omitted, this defaults to `tf.summary.experimental.get_step()`, + which must not be `None`. + + Returns: + True on success, or False if no summary was written because no default + summary writer was available. + + Raises: + ValueError: if a default writer exists, but no step was provided and + `tf.summary.experimental.get_step()` is `None`. + """ + import tensorflow.summary as summary + from tensorflow.compat.v1 import SummaryMetadata + + summary_metadata = SummaryMetadata() + # Hard coding a plugin name. Please refer to go/tb-plugin-name-hardcode for + # the rationale. + summary_metadata.plugin_data.plugin_name = "graph_keras_model" + # version number = 1 + summary_metadata.plugin_data.content = b"1" + + try: + json_string = data.to_json() + except Exception as exc: + # An exception should not break a model code. + warnings.warn(f"Model failed to serialize as JSON. Ignoring... {exc}") + return False + + with summary.experimental.summary_scope( + name, "graph_keras_model", [data, step] + ) as (tag, _): + return summary.write( + tag=tag, tensor=json_string, step=step, metadata=summary_metadata + ) diff --git a/keras/src/callbacks/tensorboard_test.py b/keras/src/callbacks/tensorboard_test.py new file mode 100644 index 000000000000..a691509ea7db --- /dev/null +++ b/keras/src/callbacks/tensorboard_test.py @@ -0,0 +1,784 @@ +import collections +import os +import random +import sys + +import numpy as np +import pytest +import tensorflow.summary as summary +from tensorflow.compat.v1 import SummaryMetadata +from tensorflow.core.util import event_pb2 +from tensorflow.python.lib.io import tf_record + +from keras.src import backend +from keras.src import callbacks +from keras.src import layers +from keras.src import losses +from keras.src import models +from keras.src import ops +from keras.src import optimizers +from keras.src import testing +from keras.src.optimizers import schedules + +# Note: this file and tensorboard in general has a dependency on tensorflow + +# A summary that was emitted during a test. Fields: +# logdir: str. The logdir of the FileWriter to which the summary was +# written. +# tag: str. The name of the summary. +_ObservedSummary = collections.namedtuple("_ObservedSummary", ("logdir", "tag")) + + +class _SummaryIterator: + """Yields `Event` protocol buffers from a given path.""" + + def __init__(self, path): + self._tf_record_iterator = tf_record.tf_record_iterator(path) + + def __iter__(self): + return self + + def __next__(self): + r = next(self._tf_record_iterator) + return event_pb2.Event.FromString(r) + + next = __next__ + + +class _SummaryFile: + """A record of summary tags and the files to which they were written. + + Fields `scalars`, `images`, `histograms`, and `tensors` are sets + containing `_ObservedSummary` values. + """ + + def __init__(self): + self.scalars = set() + self.images = set() + self.histograms = set() + self.tensors = set() + self.graph_defs = [] + self.convert_from_v2_summary_proto = False + + +def list_summaries(logdir): + """Read all summaries under the logdir into a `_SummaryFile`. + + Args: + logdir: A path to a directory that contains zero or more event + files, either as direct children or in transitive subdirectories. + Summaries in these events must only contain old-style scalars, + images, and histograms. Non-summary events, like `graph_def`s, are + ignored. + + Returns: + A `_SummaryFile` object reflecting all summaries written to any + event files in the logdir or any of its descendant directories. + + Raises: + ValueError: If an event file contains an summary of unexpected kind. + """ + result = _SummaryFile() + for dirpath, _, filenames in os.walk(logdir): + for filename in filenames: + if not filename.startswith("events.out."): + continue + path = os.path.join(dirpath, filename) + for event in _SummaryIterator(path): + if event.graph_def: + result.graph_defs.append(event.graph_def) + if not event.summary: # (e.g., it's a `graph_def` event) + continue + for value in event.summary.value: + tag = value.tag + # Case on the `value` rather than the summary metadata + # because the Keras callback uses `summary_ops_v2` to emit + # old-style summaries. See b/124535134. + kind = value.WhichOneof("value") + container = { + "simple_value": result.scalars, + "image": result.images, + "histo": result.histograms, + "tensor": result.tensors, + }.get(kind) + if container is None: + raise ValueError( + "Unexpected summary kind %r in event file %s:\n%r" + % (kind, path, event) + ) + elif kind == "tensor" and tag != "keras": + # Convert the tf2 summary proto to old style for type + # checking. + plugin_name = value.metadata.plugin_data.plugin_name + container = { + "images": result.images, + "histograms": result.histograms, + "scalars": result.scalars, + }.get(plugin_name) + if container is not None: + result.convert_from_v2_summary_proto = True + else: + container = result.tensors + container.add(_ObservedSummary(logdir=dirpath, tag=tag)) + return result + + +class TestTensorBoardV2(testing.TestCase): + def _get_log_dirs(self): + logdir = os.path.join( + self.get_temp_dir(), str(random.randint(1, int(1e7))), "tb" + ) + train_dir = os.path.join(logdir, "train") + validation_dir = os.path.join(logdir, "validation") + return logdir, train_dir, validation_dir + + def _get_model(self, compile_model=True): + model = models.Sequential( + [ + layers.Input((10, 10, 1)), + layers.Flatten(), + layers.Dense(1), + ] + ) + if compile_model: + model.compile("sgd", "mse") + return model + + @pytest.mark.requires_trainable_backend + def test_TensorBoard_basic(self): + model = self._get_model() + x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1)) + logdir, train_dir, validation_dir = self._get_log_dirs() + tb_cbk = callbacks.TensorBoard(logdir) + + model.fit( + x, + y, + batch_size=2, + epochs=2, + validation_data=(x, y), + callbacks=[tb_cbk], + ) + + summary_file = list_summaries(logdir) + self.assertEqual( + summary_file.scalars, + { + _ObservedSummary(logdir=train_dir, tag="epoch_loss"), + _ObservedSummary(logdir=validation_dir, tag="epoch_loss"), + _ObservedSummary(logdir=train_dir, tag="epoch_learning_rate"), + _ObservedSummary( + logdir=validation_dir, + tag="evaluation_loss_vs_iterations", + ), + }, + ) + + @pytest.mark.requires_trainable_backend + def test_TensorBoard_across_invocations(self): + """Regression test for summary writer resource use-after-free.""" + model = self._get_model() + x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1)) + logdir, train_dir, validation_dir = self._get_log_dirs() + tb_cbk = callbacks.TensorBoard(logdir) + + for _ in (1, 2): + model.fit( + x, + y, + batch_size=2, + epochs=2, + validation_data=(x, y), + callbacks=[tb_cbk], + ) + + summary_file = list_summaries(logdir) + self.assertEqual( + summary_file.scalars, + { + _ObservedSummary(logdir=train_dir, tag="epoch_loss"), + _ObservedSummary(logdir=validation_dir, tag="epoch_loss"), + _ObservedSummary(logdir=train_dir, tag="epoch_learning_rate"), + _ObservedSummary( + logdir=validation_dir, + tag="evaluation_loss_vs_iterations", + ), + }, + ) + + @pytest.mark.requires_trainable_backend + def test_TensorBoard_no_spurious_event_files(self): + model = self._get_model() + x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1)) + logdir, train_dir, _ = self._get_log_dirs() + tb_cbk = callbacks.TensorBoard(logdir) + model.fit(x, y, batch_size=2, epochs=2, callbacks=[tb_cbk]) + + events_file_run_basenames = set() + for dirpath, _, filenames in os.walk(train_dir): + if any(fn.startswith("events.out.") for fn in filenames): + events_file_run_basenames.add(os.path.basename(dirpath)) + self.assertEqual(events_file_run_basenames, {"train"}) + + @pytest.mark.requires_trainable_backend + def test_TensorBoard_batch_metrics(self): + model = self._get_model() + x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1)) + logdir, train_dir, validation_dir = self._get_log_dirs() + tb_cbk = callbacks.TensorBoard(logdir, update_freq=1) + + model.fit( + x, + y, + batch_size=2, + epochs=2, + validation_data=(x, y), + callbacks=[tb_cbk], + ) + + summary_file = list_summaries(logdir) + self.assertEqual( + summary_file.scalars, + { + _ObservedSummary(logdir=train_dir, tag="batch_loss"), + _ObservedSummary(logdir=train_dir, tag="epoch_loss"), + _ObservedSummary(logdir=validation_dir, tag="epoch_loss"), + _ObservedSummary(logdir=train_dir, tag="epoch_learning_rate"), + _ObservedSummary( + logdir=validation_dir, + tag="evaluation_loss_vs_iterations", + ), + }, + ) + + @pytest.mark.requires_trainable_backend + def test_TensorBoard_learning_rate_schedules(self): + model = self._get_model(compile_model=False) + opt = optimizers.SGD(schedules.CosineDecay(0.01, 1)) + model.compile(opt, "mse") + logdir, train_dir, _ = self._get_log_dirs() + x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1)) + + model.fit( + x, + y, + batch_size=2, + epochs=2, + callbacks=[callbacks.TensorBoard(logdir)], + ) + + summary_file = list_summaries(logdir) + self.assertEqual( + summary_file.scalars, + { + _ObservedSummary(logdir=train_dir, tag="epoch_loss"), + _ObservedSummary(logdir=train_dir, tag="epoch_learning_rate"), + }, + ) + + @pytest.mark.requires_trainable_backend + def test_TensorBoard_global_step(self): + model = self._get_model(compile_model=False) + opt = optimizers.SGD(schedules.CosineDecay(0.01, 1)) + model.compile(opt, "mse") + logdir, train_dir, _ = self._get_log_dirs() + x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1)) + + model.fit( + x, + y, + batch_size=2, + epochs=2, + verbose=0, + callbacks=[ + callbacks.TensorBoard( + logdir, + update_freq=1, + profile_batch=0, + write_steps_per_second=True, + ) + ], + ) + + summary_file = list_summaries(logdir) + self.assertEqual( + summary_file.scalars, + { + _ObservedSummary(logdir=train_dir, tag="batch_loss"), + _ObservedSummary(logdir=train_dir, tag="epoch_loss"), + _ObservedSummary(logdir=train_dir, tag="epoch_learning_rate"), + _ObservedSummary( + logdir=train_dir, tag="epoch_steps_per_second" + ), + _ObservedSummary( + logdir=train_dir, tag="batch_steps_per_second" + ), + }, + ) + + @pytest.mark.requires_trainable_backend + def test_TensorBoard_weight_histograms(self): + model = self._get_model() + x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1)) + logdir, train_dir, validation_dir = self._get_log_dirs() + tb_cbk = callbacks.TensorBoard(logdir, histogram_freq=1) + + model.fit( + x, + y, + batch_size=2, + epochs=2, + validation_data=(x, y), + callbacks=[tb_cbk], + ) + summary_file = list_summaries(logdir) + + self.assertEqual( + summary_file.scalars, + { + _ObservedSummary(logdir=train_dir, tag="epoch_loss"), + _ObservedSummary(logdir=validation_dir, tag="epoch_loss"), + _ObservedSummary(logdir=train_dir, tag="epoch_learning_rate"), + _ObservedSummary( + logdir=validation_dir, + tag="evaluation_loss_vs_iterations", + ), + }, + ) + self.assertEqual( + self._strip_layer_names(summary_file.histograms, "sequential"), + {_ObservedSummary(logdir=train_dir, tag="histogram")}, + ) + + @pytest.mark.requires_trainable_backend + def test_TensorBoard_weight_images(self): + if backend.config.image_data_format() == "channels_last": + input_shape = (10, 10, 1) + x_shape = (10, 10, 10, 1) + else: + input_shape = (1, 10, 10) + x_shape = (10, 1, 10, 10) + x, y = np.ones(x_shape), np.ones((10, 1)) + logdir, train_dir, validation_dir = self._get_log_dirs() + tb_cbk = callbacks.TensorBoard( + logdir, histogram_freq=1, write_images=True + ) + model_type = "sequential" + model = models.Sequential( + [ + layers.Input(input_shape), + layers.Conv2D(3, 10), + layers.GlobalAveragePooling2D(), + layers.Dense(1), + ] + ) + model.compile("sgd", "mse") + model.fit( + x, + y, + batch_size=2, + epochs=2, + validation_data=(x, y), + callbacks=[tb_cbk], + ) + summary_file = list_summaries(logdir) + + self.assertEqual( + summary_file.scalars, + { + _ObservedSummary(logdir=train_dir, tag="epoch_loss"), + _ObservedSummary(logdir=validation_dir, tag="epoch_loss"), + _ObservedSummary(logdir=train_dir, tag="epoch_learning_rate"), + _ObservedSummary( + logdir=validation_dir, + tag="evaluation_loss_vs_iterations", + ), + }, + ) + self.assertEqual( + self._strip_layer_names(summary_file.histograms, model_type), + { + _ObservedSummary(logdir=train_dir, tag="histogram"), + }, + ) + expected_image_summaries = { + _ObservedSummary(logdir=train_dir, tag="bias/image"), + _ObservedSummary(logdir=train_dir, tag="kernel/image"), + } + self.assertEqual( + self._strip_variable_names(summary_file.images), + expected_image_summaries, + ) + + @pytest.mark.requires_trainable_backend + def test_TensorBoard_projector_callback(self): + model = models.Sequential( + [ + layers.Input((10,)), + layers.Embedding(10, 10, name="test_embedding"), + layers.Dense(1, activation="sigmoid"), + ] + ) + model.compile( + optimizer="adam", loss=losses.BinaryCrossentropy(from_logits=True) + ) + x, y = np.ones((10, 10)), np.ones((10, 10)) + logdir, _, _ = self._get_log_dirs() + tb_cbk = callbacks.TensorBoard( + logdir, + embeddings_freq=1, + embeddings_metadata={"test_embedding": "metadata.tsv"}, + ) + + model.fit( + x, + y, + batch_size=2, + epochs=2, + validation_data=(x, y), + callbacks=[tb_cbk], + ) + + with open(os.path.join(logdir, "projector_config.pbtxt")) as f: + self.assertEqual( + f.readlines(), + [ + "embeddings {\n", + " tensor_name: " + '"layer_with_weights-0/embeddings/.ATTRIBUTES/' + 'VARIABLE_VALUE"\n', + ' metadata_path: "metadata.tsv"\n', + "}\n", + ], + ) + + @pytest.mark.requires_trainable_backend + def test_custom_summary(self): + def scalar_v2_mock(name, data, step=None): + """A reimplementation of the scalar plugin to avoid circular + deps.""" + metadata = SummaryMetadata() + # Should match value in tensorboard/plugins/scalar/metadata.py. + metadata.plugin_data.plugin_name = "scalars" + with summary.experimental.summary_scope( + name, "scalar_summary", values=[data, step] + ) as (tag, _): + tensor = backend.convert_to_tensor(data, dtype="float32") + if backend.backend() == "torch": + # TODO: Use device scope after the API is added. + if tensor.is_cuda: + tensor = tensor.cpu() + summary.write( + tag=tag, + tensor=tensor, + step=step, + metadata=metadata, + ) + + class LayerWithSummary(layers.Layer): + def call(self, x): + scalar_v2_mock("custom_summary", ops.sum(x)) + return x + + model = models.Sequential( + [ + layers.Input((5,)), + LayerWithSummary(), + ] + ) + + # summary ops not compatible with XLA + model.compile("sgd", "mse", jit_compile=False) + logdir, train_dir, validation_dir = self._get_log_dirs() + tb_cbk = callbacks.TensorBoard(logdir, update_freq=1) + x, y = np.ones((10, 5)), np.ones((10, 5)) + model.fit( + x, y, batch_size=2, validation_data=(x, y), callbacks=[tb_cbk] + ) + summary_file = list_summaries(logdir) + # TODO: tensorflow will tag with model/layer_with_summary/custom_summary + # Jax will only use custom_summary tag + self.assertEqual( + self._strip_to_only_final_name(summary_file.scalars), + { + _ObservedSummary(logdir=train_dir, tag="batch_loss"), + _ObservedSummary(logdir=train_dir, tag="epoch_loss"), + _ObservedSummary(logdir=train_dir, tag="epoch_learning_rate"), + _ObservedSummary(logdir=validation_dir, tag="epoch_loss"), + _ObservedSummary( + logdir=validation_dir, + tag="evaluation_loss_vs_iterations", + ), + _ObservedSummary( + logdir=train_dir, + tag="custom_summary", + ), + _ObservedSummary( + logdir=validation_dir, + tag="custom_summary", + ), + }, + ) + # self.assertEqual( + # summary_file.scalars, + # { + # _ObservedSummary(logdir=train_dir, tag="batch_loss"), + # _ObservedSummary(logdir=train_dir, tag="epoch_loss"), + # _ObservedSummary(logdir=validation_dir, + # tag="epoch_loss"), + # _ObservedSummary( + # logdir=validation_dir, + # tag="evaluation_loss_vs_iterations", + # ), + # _ObservedSummary( + # logdir=train_dir, + # tag="model/layer_with_summary/custom_summary", + # ), + # _ObservedSummary( + # logdir=validation_dir, + # tag="model/layer_with_summary/custom_summary", + # ), + # }, + # ) + + def _strip_to_only_final_name(self, summaries): + """Removes all leading names in a summary + + Args: + summaries: A `set` of `_ObservedSummary` values. + + Returns: + A new `set` of `_ObservedSummary` values striped of all + name except for the terminal one. + + """ + result = set() + for s in summaries: + if "/" not in s.tag: + result.add(s) + else: + new_tag = s.tag.split("/")[-1] + result.add(s._replace(tag=new_tag)) + return result + + def _strip_layer_names(self, summaries, model_type): + """Deduplicate summary names modulo layer prefix. + + This removes the first slash-component of each tag name: for + instance, "foo/bar/baz" becomes "bar/baz". + + Args: + summaries: A `set` of `_ObservedSummary` values. + model_type: The model type currently being tested. + + Returns: + A new `set` of `_ObservedSummary` values with layer prefixes + removed. + """ + result = set() + for s in summaries: + if "/" not in s.tag: + raise ValueError(f"tag has no layer name: {s.tag!r}") + start_from = 2 if "subclass" in model_type else 1 + new_tag = "/".join(s.tag.split("/")[start_from:]) + result.add(s._replace(tag=new_tag)) + return result + + def _strip_variable_names(self, summaries): + """Remove `variable_n` from summary tag + + `variable_n` tag names are added with random numbers. Removing them + ensures deterministic tag names. + + Args: + summaries: A `set` of `_ObservedSummary` values. + + Returns: + A new `set` of `_ObservedSummary` values with layer prefixes + removed. + """ + result = set() + for s in summaries: + if "/" not in s.tag: + result.add(s) + else: + split_tag = s.tag.split("/") + if "variable" in split_tag[0]: + result.add(s._replace(tag=split_tag[-1])) + else: + result.add(s) + return result + + @pytest.mark.skipif( + backend.backend() == "torch", + reason="Torch backend requires blocking numpy conversion.", + ) + @pytest.mark.requires_trainable_backend + def test_TensorBoard_non_blocking(self): + logdir, _, _ = self._get_log_dirs() + model = models.Sequential([layers.Dense(1)]) + model.optimizer = optimizers.Adam() + tb = callbacks.TensorBoard(logdir) + cb_list = callbacks.CallbackList( + [tb], model=model, epochs=1, steps=100, verbose=0 + ) + tensor = ops.convert_to_tensor(1.0) + + def mock_numpy(): + raise RuntimeError( + "If this error is seen, TensorBoard is causing a blocking " + "NumPy conversion." + ) + + tensor.numpy = mock_numpy + + logs = {"metric": tensor} + + cb_list.on_train_begin(logs) + cb_list.on_epoch_begin(0, logs) + cb_list.on_train_batch_begin(0, logs) + cb_list.on_train_batch_end(0, logs) + cb_list.on_epoch_end(0, logs) + cb_list.on_train_end(logs) + + cb_list.on_test_begin(logs) + cb_list.on_test_batch_begin(0, logs) + cb_list.on_test_batch_end(0, logs) + cb_list.on_test_end(logs) + + cb_list.on_predict_begin(logs) + cb_list.on_predict_batch_begin(logs) + cb_list.on_predict_batch_end(logs) + cb_list.on_predict_end(logs) + + def _count_xplane_file(self, logdir): + profile_dir = os.path.join(logdir, "plugins", "profile") + count = 0 + for dirpath, dirnames, filenames in os.walk(profile_dir): + del dirpath # unused + del dirnames # unused + for filename in filenames: + if filename.endswith(".xplane.pb"): + count += 1 + return count + + def fitModelAndAssertKerasModelWritten(self, model): + x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1)) + logdir, train_dir, validation_dir = self._get_log_dirs() + tb_cbk = callbacks.TensorBoard( + logdir, write_graph=True, profile_batch=0 + ) + model.fit( + x, + y, + batch_size=2, + epochs=3, + validation_data=(x, y), + callbacks=[tb_cbk], + ) + summary_file = list_summaries(logdir) + self.assertEqual( + summary_file.tensors, + { + _ObservedSummary(logdir=train_dir, tag="keras"), + }, + ) + if not model.run_eagerly: + # There should be one train graph + self.assertLen(summary_file.graph_defs, 1) + for graph_def in summary_file.graph_defs: + graph_def_str = str(graph_def) + + # All the model layers should appear in the graphs + for layer in model.layers: + if "input" not in layer.name: + self.assertIn(layer.name, graph_def_str) + + def test_TensorBoard_write_sequential_model_no_input_shape(self): + # TODO: Requires to_json implementation in trainer + # model = models.Sequential( + # [ + # Conv2D(8, (3, 3)), + # Flatten(), + # Dense(1), + # ] + # ) + # model.compile("sgd", "mse") + # self.fitModelAndAssertKerasModelWritten(model) + pass + + def test_TensorBoard_write_sequential_model_with_input_shape(self): + # TODO: Requires to_json implementation in trainer + # model = models.Sequential( + # [ + # Input(input_shape=(10, 10, 1)), + # Conv2D(8, (3, 3)), + # Flatten(), + # Dense(1), + # ] + # ) + # model.compile("sgd", "mse") + # self.fitModelAndAssertKerasModelWritten(model) + pass + + def test_TensorBoard_write_model(self): + # TODO: Requires to_json implementation in trainer + # See https://github.com/keras-team/keras/blob/ \ + # a8d4a7f1ffc9de3c5932828a107e4e95e8803fb4/ \ + # keras/engine/training.py#L3313 + # inputs = Input([10, 10, 1]) + # x = Conv2D(8, (3, 3), activation="relu")(inputs) + # x = Flatten()(x) + # x = Dense(1)(x) + # model = models.Model(inputs=inputs, outputs=[x]) + # model.compile("sgd", "mse") + # breakpoint() + # self.fitModelAndAssertKerasModelWritten(model) + pass + + @pytest.mark.skipif( + backend.backend() not in ("jax", "tensorflow"), + reason="The profiling test can only run with TF and JAX backends.", + ) + def test_TensorBoard_auto_trace(self): + logdir, train_dir, validation_dir = self._get_log_dirs() + model = models.Sequential( + [ + layers.Input((10, 10, 1)), + layers.Flatten(), + layers.Dense(1), + ] + ) + x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1)) + if backend.backend() == "jax" and sys.version_info[1] < 12: + with pytest.warns(match="backend requires python >= 3.12"): + callbacks.TensorBoard( + logdir, histogram_freq=1, profile_batch=1, write_graph=False + ) + self.skipTest( + "Profiling with JAX and python < 3.12 " + "raises segmentation fault." + ) + + tb_cbk = callbacks.TensorBoard( + logdir, histogram_freq=1, profile_batch=1, write_graph=False + ) + model.compile("sgd", "mse") + model.fit( + x, + y, + batch_size=2, + epochs=2, + validation_data=(x, y), + callbacks=[tb_cbk], + ) + summary_file = list_summaries(logdir) + + self.assertEqual( + summary_file.tensors, + { + _ObservedSummary(logdir=train_dir, tag="batch_1"), + }, + ) + self.assertEqual(1, self._count_xplane_file(logdir=train_dir)) + pass diff --git a/keras/src/callbacks/terminate_on_nan.py b/keras/src/callbacks/terminate_on_nan.py new file mode 100644 index 000000000000..55f7e4c06ab8 --- /dev/null +++ b/keras/src/callbacks/terminate_on_nan.py @@ -0,0 +1,20 @@ +import numpy as np + +from keras.src.api_export import keras_export +from keras.src.callbacks.callback import Callback +from keras.src.utils import io_utils + + +@keras_export("keras.callbacks.TerminateOnNaN") +class TerminateOnNaN(Callback): + """Callback that terminates training when a NaN loss is encountered.""" + + def on_batch_end(self, batch, logs=None): + logs = logs or {} + loss = logs.get("loss") + if loss is not None: + if np.isnan(loss) or np.isinf(loss): + io_utils.print_msg( + f"Batch {batch}: Invalid loss, terminating training" + ) + self.model.stop_training = True diff --git a/keras/src/callbacks/terminate_on_nan_test.py b/keras/src/callbacks/terminate_on_nan_test.py new file mode 100644 index 000000000000..f84b1b89b6bc --- /dev/null +++ b/keras/src/callbacks/terminate_on_nan_test.py @@ -0,0 +1,52 @@ +import numpy as np +import pytest + +from keras.src import callbacks +from keras.src import initializers +from keras.src import layers +from keras.src import testing +from keras.src.models import Sequential +from keras.src.utils import numerical_utils + + +class TerminateOnNaNTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_TerminateOnNaN(self): + TRAIN_SAMPLES = 10 + TEST_SAMPLES = 10 + INPUT_DIM = 3 + NUM_CLASSES = 2 + BATCH_SIZE = 4 + + np.random.seed(1337) + x_train = np.random.random((TRAIN_SAMPLES, INPUT_DIM)) + y_train = np.random.choice(np.arange(NUM_CLASSES), size=TRAIN_SAMPLES) + x_test = np.random.random((TEST_SAMPLES, INPUT_DIM)) + y_test = np.random.choice(np.arange(NUM_CLASSES), size=TEST_SAMPLES) + + y_test = numerical_utils.to_categorical(y_test) + y_train = numerical_utils.to_categorical(y_train) + model = Sequential() + initializer = initializers.Constant(value=1e5) + for _ in range(5): + model.add( + layers.Dense( + 2, + activation="relu", + kernel_initializer=initializer, + ) + ) + model.add(layers.Dense(NUM_CLASSES)) + model.compile(loss="mean_squared_error", optimizer="sgd") + + history = model.fit( + x_train, + y_train, + batch_size=BATCH_SIZE, + validation_data=(x_test, y_test), + callbacks=[callbacks.TerminateOnNaN()], + epochs=20, + ) + loss = history.history["loss"] + self.assertEqual(len(loss), 1) + self.assertTrue(np.isnan(loss[0]) or np.isinf(loss[0])) diff --git a/keras/src/constraints/__init__.py b/keras/src/constraints/__init__.py new file mode 100644 index 000000000000..cfafab080cd6 --- /dev/null +++ b/keras/src/constraints/__init__.py @@ -0,0 +1,60 @@ +import inspect + +from keras.src.api_export import keras_export +from keras.src.constraints.constraints import Constraint +from keras.src.constraints.constraints import MaxNorm +from keras.src.constraints.constraints import MinMaxNorm +from keras.src.constraints.constraints import NonNeg +from keras.src.constraints.constraints import UnitNorm +from keras.src.saving import serialization_lib +from keras.src.utils.naming import to_snake_case + +ALL_OBJECTS = { + Constraint, + MaxNorm, + MinMaxNorm, + NonNeg, + UnitNorm, +} + +ALL_OBJECTS_DICT = {cls.__name__: cls for cls in ALL_OBJECTS} +ALL_OBJECTS_DICT.update( + {to_snake_case(cls.__name__): cls for cls in ALL_OBJECTS} +) + + +@keras_export("keras.constraints.serialize") +def serialize(constraint): + return serialization_lib.serialize_keras_object(constraint) + + +@keras_export("keras.constraints.deserialize") +def deserialize(config, custom_objects=None): + """Return a Keras constraint object via its config.""" + return serialization_lib.deserialize_keras_object( + config, + module_objects=ALL_OBJECTS_DICT, + custom_objects=custom_objects, + ) + + +@keras_export("keras.constraints.get") +def get(identifier): + """Retrieve a Keras constraint object via an identifier.""" + if identifier is None: + return None + if isinstance(identifier, dict): + obj = deserialize(identifier) + elif isinstance(identifier, str): + obj = ALL_OBJECTS_DICT.get(identifier, None) + else: + obj = identifier + + if callable(obj): + if inspect.isclass(obj): + obj = obj() + return obj + else: + raise ValueError( + f"Could not interpret constraint identifier: {identifier}" + ) diff --git a/keras/src/constraints/constraints.py b/keras/src/constraints/constraints.py new file mode 100644 index 000000000000..2fc9305e7486 --- /dev/null +++ b/keras/src/constraints/constraints.py @@ -0,0 +1,215 @@ +from keras.src import backend +from keras.src import ops +from keras.src.api_export import keras_export + + +@keras_export("keras.constraints.Constraint") +class Constraint: + """Base class for weight constraints. + + A `Constraint` instance works like a stateless function. + Users who subclass this + class should override the `__call__()` method, which takes a single + weight parameter and return a projected version of that parameter + (e.g. normalized or clipped). Constraints can be used with various Keras + layers via the `kernel_constraint` or `bias_constraint` arguments. + + Here's a simple example of a non-negative weight constraint: + + >>> class NonNegative(keras.constraints.Constraint): + ... + ... def __call__(self, w): + ... return w * ops.cast(ops.greater_equal(w, 0.), dtype=w.dtype) + + >>> weight = ops.convert_to_tensor((-1.0, 1.0)) + >>> NonNegative()(weight) + [0., 1.] + + Usage in a layer: + + >>> keras.layers.Dense(4, kernel_constraint=NonNegative()) + """ + + def __call__(self, w): + """Applies the constraint to the input weight variable. + + By default, the inputs weight variable is not modified. + Users should override this method to implement their own projection + function. + + Args: + w: Input weight variable. + + Returns: + Projected variable (by default, returns unmodified inputs). + """ + return w + + def get_config(self): + """Returns a Python dict of the object config. + + A constraint config is a Python dictionary (JSON-serializable) that can + be used to reinstantiate the same object. + + Returns: + Python dict containing the configuration of the constraint object. + """ + return {} + + @classmethod + def from_config(cls, config): + """Instantiates a weight constraint from a configuration dictionary. + + Example: + + ```python + constraint = UnitNorm() + config = constraint.get_config() + constraint = UnitNorm.from_config(config) + ``` + + Args: + config: A Python dictionary, the output of `get_config()`. + + Returns: + A `keras.constraints.Constraint` instance. + """ + return cls(**config) + + +@keras_export(["keras.constraints.MaxNorm", "keras.constraints.max_norm"]) +class MaxNorm(Constraint): + """MaxNorm weight constraint. + + Constrains the weights incident to each hidden unit + to have a norm less than or equal to a desired value. + + Also available via the shortcut function `keras.constraints.max_norm`. + + Args: + max_value: the maximum norm value for the incoming weights. + axis: integer, axis along which to calculate weight norms. + For instance, in a `Dense` layer the weight matrix + has shape `(input_dim, output_dim)`, + set `axis` to `0` to constrain each weight vector + of length `(input_dim,)`. + In a `Conv2D` layer with `data_format="channels_last"`, + the weight tensor has shape + `(rows, cols, input_depth, output_depth)`, + set `axis` to `[0, 1, 2]` + to constrain the weights of each filter tensor of size + `(rows, cols, input_depth)`. + + """ + + def __init__(self, max_value=2, axis=0): + self.max_value = max_value + self.axis = axis + + def __call__(self, w): + w = backend.convert_to_tensor(w) + norms = ops.sqrt(ops.sum(ops.square(w), axis=self.axis, keepdims=True)) + desired = ops.clip(norms, 0, self.max_value) + return ops.cast(w, norms.dtype) * ( + desired / (backend.epsilon() + norms) + ) + + def get_config(self): + return {"max_value": self.max_value, "axis": self.axis} + + +@keras_export(["keras.constraints.NonNeg", "keras.constraints.non_neg"]) +class NonNeg(Constraint): + """Constrains the weights to be non-negative.""" + + def __call__(self, w): + w = backend.convert_to_tensor(w) + return ops.multiply(w, ops.greater_equal(w, 0.0)) + + +@keras_export(["keras.constraints.UnitNorm", "keras.constraints.unit_norm"]) +class UnitNorm(Constraint): + """Constrains the weights incident to each hidden unit to have unit norm. + + Args: + axis: integer, axis along which to calculate weight norms. + For instance, in a `Dense` layer the weight matrix + has shape `(input_dim, output_dim)`, + set `axis` to `0` to constrain each weight vector + of length `(input_dim,)`. + In a `Conv2D` layer with `data_format="channels_last"`, + the weight tensor has shape + `(rows, cols, input_depth, output_depth)`, + set `axis` to `[0, 1, 2]` + to constrain the weights of each filter tensor of size + `(rows, cols, input_depth)`. + """ + + def __init__(self, axis=0): + self.axis = axis + + def __call__(self, w): + w = backend.convert_to_tensor(w) + norms = ops.sqrt(ops.sum(ops.square(w), axis=self.axis, keepdims=True)) + return ops.cast(w, norms.dtype) / (backend.epsilon() + norms) + + def get_config(self): + return {"axis": self.axis} + + +@keras_export( + ["keras.constraints.MinMaxNorm", "keras.constraints.min_max_norm"] +) +class MinMaxNorm(Constraint): + """MinMaxNorm weight constraint. + + Constrains the weights incident to each hidden unit + to have the norm between a lower bound and an upper bound. + + Args: + min_value: the minimum norm for the incoming weights. + max_value: the maximum norm for the incoming weights. + rate: rate for enforcing the constraint: weights will be + rescaled to yield + `(1 - rate) * norm + rate * norm.clip(min_value, max_value)`. + Effectively, this means that rate=1.0 stands for strict + enforcement of the constraint, while rate<1.0 means that + weights will be rescaled at each step to slowly move + towards a value inside the desired interval. + axis: integer, axis along which to calculate weight norms. + For instance, in a `Dense` layer the weight matrix + has shape `(input_dim, output_dim)`, + set `axis` to `0` to constrain each weight vector + of length `(input_dim,)`. + In a `Conv2D` layer with `data_format="channels_last"`, + the weight tensor has shape + `(rows, cols, input_depth, output_depth)`, + set `axis` to `[0, 1, 2]` + to constrain the weights of each filter tensor of size + `(rows, cols, input_depth)`. + """ + + def __init__(self, min_value=0.0, max_value=1.0, rate=1.0, axis=0): + self.min_value = min_value + self.max_value = max_value + self.rate = rate + self.axis = axis + + def __call__(self, w): + w = backend.convert_to_tensor(w) + norms = ops.sqrt(ops.sum(ops.square(w), axis=self.axis, keepdims=True)) + desired = ( + self.rate * ops.clip(norms, self.min_value, self.max_value) + + (1 - self.rate) * norms + ) + return ops.cast(w, norms.dtype) * ( + desired / (backend.epsilon() + norms) + ) + + def get_config(self): + return { + "min_value": self.min_value, + "max_value": self.max_value, + "rate": self.rate, + "axis": self.axis, + } diff --git a/keras/src/constraints/constraints_test.py b/keras/src/constraints/constraints_test.py new file mode 100644 index 000000000000..50f9b3134545 --- /dev/null +++ b/keras/src/constraints/constraints_test.py @@ -0,0 +1,101 @@ +import numpy as np + +from keras.src import backend +from keras.src import constraints +from keras.src import testing + + +def get_example_array(): + np.random.seed(3537) + example_array = np.random.random((100, 100)) * 100.0 - 50.0 + example_array[0, 0] = 0.0 # Possible edge case + return example_array + + +class ConstraintsTest(testing.TestCase): + def test_max_norm(self): + constraint_fn = constraints.MaxNorm(2.0) + x = np.array([[0, 0, 0], [1.0, 0, 0], [3, 0, 0], [3, 3, 3]]).T + target = np.array( + [ + [0, 0, 0], + [1.0, 0, 0], + [2.0, 0, 0], + [2.0 / np.sqrt(3), 2.0 / np.sqrt(3), 2.0 / np.sqrt(3)], + ] + ).T + output = constraint_fn(x) + self.assertAllClose(target, output) + + def test_non_neg(self): + constraint_fn = constraints.NonNeg() + output = constraint_fn(get_example_array()) + output = backend.convert_to_numpy(output) + self.assertTrue((np.min(output, axis=1) >= 0.0).all()) + + def test_unit_norm(self): + constraint_fn = constraints.UnitNorm() + output = constraint_fn(get_example_array()) + output = backend.convert_to_numpy(output) + l2 = np.sqrt(np.sum(np.square(output), axis=0)) + self.assertAllClose(l2, 1.0) + + def test_min_max_norm(self): + constraint_fn = constraints.MinMaxNorm(min_value=0.2, max_value=0.5) + output = constraint_fn(get_example_array()) + output = backend.convert_to_numpy(output) + l2 = np.sqrt(np.sum(np.square(output), axis=0)) + self.assertTrue(np.all(l2 >= 0.2)) + self.assertTrue(np.all(l2 <= 0.5 + 1e-6)) + + def test_get_method(self): + obj = constraints.get("unit_norm") + self.assertTrue(obj, constraints.UnitNorm) + + obj = constraints.get(None) + self.assertEqual(obj, None) + + with self.assertRaises(ValueError): + constraints.get("typo") + + def test_default_constraint_call(self): + constraint_fn = constraints.Constraint() + x = np.array([1.0, 2.0, 3.0]) + output = constraint_fn(x) + self.assertAllClose(x, output) + + def test_constraint_get_config(self): + constraint_fn = constraints.Constraint() + config = constraint_fn.get_config() + self.assertEqual(config, {}) + + def test_constraint_from_config(self): + constraint_fn = constraints.Constraint() + config = constraint_fn.get_config() + recreated_constraint_fn = constraints.Constraint.from_config(config) + self.assertIsInstance(recreated_constraint_fn, constraints.Constraint) + + def test_max_norm_get_config(self): + constraint_fn = constraints.MaxNorm(max_value=3.0, axis=1) + config = constraint_fn.get_config() + expected_config = {"max_value": 3.0, "axis": 1} + self.assertEqual(config, expected_config) + + def test_unit_norm_get_config(self): + constraint_fn = constraints.UnitNorm(axis=1) + config = constraint_fn.get_config() + expected_config = {"axis": 1} + self.assertEqual(config, expected_config) + + def test_min_max_norm_get_config(self): + constraint_fn = constraints.MinMaxNorm( + min_value=0.5, max_value=2.0, rate=0.7, axis=1 + ) + config = constraint_fn.get_config() + expected_config = { + "min_value": 0.5, + "max_value": 2.0, + "rate": 0.7, + "axis": 1, + } + self.assertEqual(config, expected_config) diff --git a/keras/src/datasets/__init__.py b/keras/src/datasets/__init__.py new file mode 100644 index 000000000000..b62b41c4e61b --- /dev/null +++ b/keras/src/datasets/__init__.py @@ -0,0 +1,10 @@ +"""Small NumPy datasets for debugging/testing.""" + +from keras.src.datasets import boston_housing +from keras.src.datasets import california_housing +from keras.src.datasets import cifar10 +from keras.src.datasets import cifar100 +from keras.src.datasets import fashion_mnist +from keras.src.datasets import imdb +from keras.src.datasets import mnist +from keras.src.datasets import reuters diff --git a/keras/src/datasets/boston_housing.py b/keras/src/datasets/boston_housing.py new file mode 100644 index 000000000000..7864ea126b3b --- /dev/null +++ b/keras/src/datasets/boston_housing.py @@ -0,0 +1,70 @@ +import numpy as np + +from keras.src.api_export import keras_export +from keras.src.utils.file_utils import get_file + + +@keras_export("keras.datasets.boston_housing.load_data") +def load_data(path="boston_housing.npz", test_split=0.2, seed=113): + """Loads the Boston Housing dataset. + + This is a dataset taken from the StatLib library which is maintained at + Carnegie Mellon University. + + **WARNING:** This dataset has an ethical problem: the authors of this + dataset included a variable, "B", that may appear to assume that racial + self-segregation influences house prices. As such, we strongly discourage + the use of this dataset, unless in the context of illustrating ethical + issues in data science and machine learning. + + Samples contain 13 attributes of houses at different locations around the + Boston suburbs in the late 1970s. Targets are the median values of + the houses at a location (in k$). + + The attributes themselves are defined in the + [StatLib website](http://lib.stat.cmu.edu/datasets/boston). + + Args: + path: path where to cache the dataset locally + (relative to `~/.keras/datasets`). + test_split: fraction of the data to reserve as test set. + seed: Random seed for shuffling the data + before computing the test split. + + Returns: + Tuple of NumPy arrays: `(x_train, y_train), (x_test, y_test)`. + + **x_train, x_test**: NumPy arrays with shape `(num_samples, 13)` + containing either the training samples (for x_train), + or test samples (for y_train). + + **y_train, y_test**: NumPy arrays of shape `(num_samples,)` containing the + target scalars. The targets are float scalars typically between 10 and + 50 that represent the home prices in k$. + """ + assert 0 <= test_split < 1 + origin_folder = ( + "https://storage.googleapis.com/tensorflow/tf-keras-datasets/" + ) + path = get_file( + path, + origin=f"{origin_folder}boston_housing.npz", + file_hash=( # noqa: E501 + "f553886a1f8d56431e820c5b82552d9d95cfcb96d1e678153f8839538947dff5" + ), + ) + with np.load(path, allow_pickle=True) as f: + x = f["x"] + y = f["y"] + + rng = np.random.RandomState(seed) + indices = np.arange(len(x)) + rng.shuffle(indices) + x = x[indices] + y = y[indices] + + x_train = np.array(x[: int(len(x) * (1 - test_split))]) + y_train = np.array(y[: int(len(x) * (1 - test_split))]) + x_test = np.array(x[int(len(x) * (1 - test_split)) :]) + y_test = np.array(y[int(len(x) * (1 - test_split)) :]) + return (x_train, y_train), (x_test, y_test) diff --git a/keras/src/datasets/california_housing.py b/keras/src/datasets/california_housing.py new file mode 100644 index 000000000000..f93a8f47be15 --- /dev/null +++ b/keras/src/datasets/california_housing.py @@ -0,0 +1,104 @@ +"""Boston housing price regression dataset.""" + +import numpy as np + +from keras.src.api_export import keras_export +from keras.src.utils.file_utils import get_file + + +@keras_export("keras.datasets.california_housing.load_data") +def load_data( + version="large", path="california_housing.npz", test_split=0.2, seed=113 +): + """Loads the California Housing dataset. + + This dataset was obtained from the [StatLib repository]( + https://www.dcc.fc.up.pt/~ltorgo/Regression/cal_housing.html). + + It's a continuous regression dataset with 20,640 samples with + 8 features each. + + The target variable is a scalar: the median house value + for California districts, in dollars. + + The 8 input features are the following: + + - MedInc: median income in block group + - HouseAge: median house age in block group + - AveRooms: average number of rooms per household + - AveBedrms: average number of bedrooms per household + - Population: block group population + - AveOccup: average number of household members + - Latitude: block group latitude + - Longitude: block group longitude + + This dataset was derived from the 1990 U.S. census, using one row + per census block group. A block group is the smallest geographical + unit for which the U.S. Census Bureau publishes sample data + (a block group typically has a population of 600 to 3,000 people). + + A household is a group of people residing within a home. + Since the average number of rooms and bedrooms in this dataset are + provided per household, these columns may take surprisingly large + values for block groups with few households and many empty houses, + such as vacation resorts. + + Args: + version: `"small"` or `"large"`. The small version + contains 600 samples, the large version contains + 20,640 samples. The purpose of the small version is + to serve as an approximate replacement for the + deprecated `boston_housing` dataset. + path: path where to cache the dataset locally + (relative to `~/.keras/datasets`). + test_split: fraction of the data to reserve as test set. + seed: Random seed for shuffling the data + before computing the test split. + + Returns: + Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`. + + **`x_train`, `x_test`**: numpy arrays with shape `(num_samples, 8)` + containing either the training samples (for `x_train`), + or test samples (for `y_train`). + + **`y_train`, `y_test`**: numpy arrays of shape `(num_samples,)` + containing the target scalars. The targets are float scalars + typically between 25,000 and 500,000 that represent + the home prices in dollars. + """ + assert 0 <= test_split < 1 + origin_folder = ( + "https://storage.googleapis.com/tensorflow/tf-keras-datasets/" + ) + path = get_file( + path, + origin=f"{origin_folder}california_housing.npz", + file_hash=( # noqa: E501 + "1a2e3a52e0398de6463aebe6f4a8da34fb21fbb6b934cf88c3425e766f2a1a6f" + ), + ) + with np.load(path, allow_pickle=True) as f: + x = f["x"] + y = f["y"] + + if version == "small": + x = x[:600] + y = y[:600] + elif version != "large": + raise ValueError( + "Argument `version` must be one of 'small', 'large'. " + f"Received: version={version}" + ) + + rng = np.random.RandomState(seed) + indices = np.arange(len(x)) + rng.shuffle(indices) + x = x[indices] + y = y[indices] + + x_train = np.array(x[: int(len(x) * (1 - test_split))]) + y_train = np.array(y[: int(len(x) * (1 - test_split))]) + x_test = np.array(x[int(len(x) * (1 - test_split)) :]) + y_test = np.array(y[int(len(x) * (1 - test_split)) :]) + return (x_train, y_train), (x_test, y_test) diff --git a/keras/src/datasets/cifar.py b/keras/src/datasets/cifar.py new file mode 100644 index 000000000000..4998174abbea --- /dev/null +++ b/keras/src/datasets/cifar.py @@ -0,0 +1,28 @@ +"""Utilities common to CIFAR10 and CIFAR100 datasets.""" + +import _pickle as cPickle + + +def load_batch(fpath, label_key="labels"): + """Internal utility for parsing CIFAR data. + + Args: + fpath: path the file to parse. + label_key: key for label data in the retrieve + dictionary. + + Returns: + A tuple `(data, labels)`. + """ + with open(fpath, "rb") as f: + d = cPickle.load(f, encoding="bytes") + # decode utf8 + d_decoded = {} + for k, v in d.items(): + d_decoded[k.decode("utf8")] = v + d = d_decoded + data = d["data"] + labels = d[label_key] + + data = data.reshape(data.shape[0], 3, 32, 32) + return data, labels diff --git a/keras/src/datasets/cifar10.py b/keras/src/datasets/cifar10.py new file mode 100644 index 000000000000..8b0f2e995fef --- /dev/null +++ b/keras/src/datasets/cifar10.py @@ -0,0 +1,101 @@ +"""CIFAR10 small images classification dataset.""" + +import os + +import numpy as np + +from keras.src import backend +from keras.src.api_export import keras_export +from keras.src.datasets.cifar import load_batch +from keras.src.utils.file_utils import get_file + + +@keras_export("keras.datasets.cifar10.load_data") +def load_data(): + """Loads the CIFAR10 dataset. + + This is a dataset of 50,000 32x32 color training images and 10,000 test + images, labeled over 10 categories. See more info at the + [CIFAR homepage](https://www.cs.toronto.edu/~kriz/cifar.html). + + The classes are: + + | Label | Description | + |:-----:|-------------| + | 0 | airplane | + | 1 | automobile | + | 2 | bird | + | 3 | cat | + | 4 | deer | + | 5 | dog | + | 6 | frog | + | 7 | horse | + | 8 | ship | + | 9 | truck | + + Returns: + Tuple of NumPy arrays: `(x_train, y_train), (x_test, y_test)`. + + **`x_train`**: `uint8` NumPy array of grayscale image data with shapes + `(50000, 32, 32, 3)`, containing the training data. Pixel values range + from 0 to 255. + + **`y_train`**: `uint8` NumPy array of labels (integers in range 0-9) + with shape `(50000, 1)` for the training data. + + **`x_test`**: `uint8` NumPy array of grayscale image data with shapes + `(10000, 32, 32, 3)`, containing the test data. Pixel values range + from 0 to 255. + + **`y_test`**: `uint8` NumPy array of labels (integers in range 0-9) + with shape `(10000, 1)` for the test data. + + Example: + + ```python + (x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data() + assert x_train.shape == (50000, 32, 32, 3) + assert x_test.shape == (10000, 32, 32, 3) + assert y_train.shape == (50000, 1) + assert y_test.shape == (10000, 1) + ``` + """ + dirname = "cifar-10-batches-py-target" + origin = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" + path = get_file( + fname=dirname, + origin=origin, + extract=True, + file_hash=( # noqa: E501 + "6d958be074577803d12ecdefd02955f39262c83c16fe9348329d7fe0b5c001ce" + ), + ) + + num_train_samples = 50000 + + x_train = np.empty((num_train_samples, 3, 32, 32), dtype="uint8") + y_train = np.empty((num_train_samples,), dtype="uint8") + + # batches are within an inner folder + path = os.path.join(path, "cifar-10-batches-py") + for i in range(1, 6): + fpath = os.path.join(path, f"data_batch_{i}") + ( + x_train[(i - 1) * 10000 : i * 10000, :, :, :], + y_train[(i - 1) * 10000 : i * 10000], + ) = load_batch(fpath) + + fpath = os.path.join(path, "test_batch") + x_test, y_test = load_batch(fpath) + + y_train = np.reshape(y_train, (len(y_train), 1)) + y_test = np.reshape(y_test, (len(y_test), 1)) + + if backend.image_data_format() == "channels_last": + x_train = x_train.transpose(0, 2, 3, 1) + x_test = x_test.transpose(0, 2, 3, 1) + + x_test = x_test.astype(x_train.dtype) + y_test = y_test.astype(y_train.dtype) + + return (x_train, y_train), (x_test, y_test) diff --git a/keras/src/datasets/cifar100.py b/keras/src/datasets/cifar100.py new file mode 100644 index 000000000000..7576afd89878 --- /dev/null +++ b/keras/src/datasets/cifar100.py @@ -0,0 +1,86 @@ +"""CIFAR100 small images classification dataset.""" + +import os + +import numpy as np + +from keras.src import backend +from keras.src.api_export import keras_export +from keras.src.datasets.cifar import load_batch +from keras.src.utils.file_utils import get_file + + +@keras_export("keras.datasets.cifar100.load_data") +def load_data(label_mode="fine"): + """Loads the CIFAR100 dataset. + + This is a dataset of 50,000 32x32 color training images and + 10,000 test images, labeled over 100 fine-grained classes that are + grouped into 20 coarse-grained classes. See more info at the + [CIFAR homepage](https://www.cs.toronto.edu/~kriz/cifar.html). + + Args: + label_mode: one of `"fine"`, `"coarse"`. + If it is `"fine"`, the category labels + are the fine-grained labels, and if it is `"coarse"`, + the output labels are the coarse-grained superclasses. + + Returns: + Tuple of NumPy arrays: `(x_train, y_train), (x_test, y_test)`. + + **`x_train`**: `uint8` NumPy array of grayscale image data with shapes + `(50000, 32, 32, 3)`, containing the training data. Pixel values range + from 0 to 255. + + **`y_train`**: `uint8` NumPy array of labels (integers in range 0-99) + with shape `(50000, 1)` for the training data. + + **`x_test`**: `uint8` NumPy array of grayscale image data with shapes + `(10000, 32, 32, 3)`, containing the test data. Pixel values range + from 0 to 255. + + **`y_test`**: `uint8` NumPy array of labels (integers in range 0-99) + with shape `(10000, 1)` for the test data. + + Example: + + ```python + (x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data() + assert x_train.shape == (50000, 32, 32, 3) + assert x_test.shape == (10000, 32, 32, 3) + assert y_train.shape == (50000, 1) + assert y_test.shape == (10000, 1) + ``` + """ + if label_mode not in ["fine", "coarse"]: + raise ValueError( + '`label_mode` must be one of `"fine"`, `"coarse"`. ' + f"Received: label_mode={label_mode}." + ) + + dirname = "cifar-100-python-target" + origin = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" + path = get_file( + fname=dirname, + origin=origin, + extract=True, + file_hash=( # noqa: E501 + "85cd44d02ba6437773c5bbd22e183051d648de2e7d6b014e1ef29b855ba677a7" + ), + ) + + path = os.path.join(path, "cifar-100-python") + fpath = os.path.join(path, "train") + x_train, y_train = load_batch(fpath, label_key=f"{label_mode}_labels") + + fpath = os.path.join(path, "test") + x_test, y_test = load_batch(fpath, label_key=f"{label_mode}_labels") + + y_train = np.reshape(y_train, (len(y_train), 1)) + y_test = np.reshape(y_test, (len(y_test), 1)) + + if backend.image_data_format() == "channels_last": + x_train = x_train.transpose(0, 2, 3, 1) + x_test = x_test.transpose(0, 2, 3, 1) + + return (x_train, y_train), (x_test, y_test) diff --git a/keras/src/datasets/fashion_mnist.py b/keras/src/datasets/fashion_mnist.py new file mode 100644 index 000000000000..6700490e058d --- /dev/null +++ b/keras/src/datasets/fashion_mnist.py @@ -0,0 +1,96 @@ +"""Fashion-MNIST dataset.""" + +import gzip +import os + +import numpy as np + +from keras.src.api_export import keras_export +from keras.src.utils.file_utils import get_file + + +@keras_export("keras.datasets.fashion_mnist.load_data") +def load_data(): + """Loads the Fashion-MNIST dataset. + + This is a dataset of 60,000 28x28 grayscale images of 10 fashion categories, + along with a test set of 10,000 images. This dataset can be used as + a drop-in replacement for MNIST. + + The classes are: + + | Label | Description | + |:-----:|-------------| + | 0 | T-shirt/top | + | 1 | Trouser | + | 2 | Pullover | + | 3 | Dress | + | 4 | Coat | + | 5 | Sandal | + | 6 | Shirt | + | 7 | Sneaker | + | 8 | Bag | + | 9 | Ankle boot | + + Returns: + + Tuple of NumPy arrays: `(x_train, y_train), (x_test, y_test)`. + + **`x_train`**: `uint8` NumPy array of grayscale image data with shapes + `(60000, 28, 28)`, containing the training data. + + **`y_train`**: `uint8` NumPy array of labels (integers in range 0-9) + with shape `(60000,)` for the training data. + + **`x_test`**: `uint8` NumPy array of grayscale image data with shapes + (10000, 28, 28), containing the test data. + + **`y_test`**: `uint8` NumPy array of labels (integers in range 0-9) + with shape `(10000,)` for the test data. + + Example: + + ```python + (x_train, y_train), (x_test, y_test) = fashion_mnist.load_data() + assert x_train.shape == (60000, 28, 28) + assert x_test.shape == (10000, 28, 28) + assert y_train.shape == (60000,) + assert y_test.shape == (10000,) + ``` + + License: + + The copyright for Fashion-MNIST is held by Zalando SE. + Fashion-MNIST is licensed under the [MIT license]( + https://github.com/zalandoresearch/fashion-mnist/blob/master/LICENSE). + """ + dirname = os.path.join("datasets", "fashion-mnist") + base = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/" + files = [ + "train-labels-idx1-ubyte.gz", + "train-images-idx3-ubyte.gz", + "t10k-labels-idx1-ubyte.gz", + "t10k-images-idx3-ubyte.gz", + ] + + paths = [] + for fname in files: + paths.append(get_file(fname, origin=base + fname, cache_subdir=dirname)) + + with gzip.open(paths[0], "rb") as lbpath: + y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8) + + with gzip.open(paths[1], "rb") as imgpath: + x_train = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape( + len(y_train), 28, 28 + ) + + with gzip.open(paths[2], "rb") as lbpath: + y_test = np.frombuffer(lbpath.read(), np.uint8, offset=8) + + with gzip.open(paths[3], "rb") as imgpath: + x_test = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape( + len(y_test), 28, 28 + ) + + return (x_train, y_train), (x_test, y_test) diff --git a/keras/src/datasets/imdb.py b/keras/src/datasets/imdb.py new file mode 100644 index 000000000000..753d7474cd54 --- /dev/null +++ b/keras/src/datasets/imdb.py @@ -0,0 +1,188 @@ +"""IMDB sentiment classification dataset.""" + +import json + +import numpy as np + +from keras.src.api_export import keras_export +from keras.src.utils.file_utils import get_file +from keras.src.utils.python_utils import remove_long_seq + + +@keras_export("keras.datasets.imdb.load_data") +def load_data( + path="imdb.npz", + num_words=None, + skip_top=0, + maxlen=None, + seed=113, + start_char=1, + oov_char=2, + index_from=3, + **kwargs, +): + """Loads the [IMDB dataset](https://ai.stanford.edu/~amaas/data/sentiment/). + + This is a dataset of 25,000 movies reviews from IMDB, labeled by sentiment + (positive/negative). Reviews have been preprocessed, and each review is + encoded as a list of word indexes (integers). + For convenience, words are indexed by overall frequency in the dataset, + so that for instance the integer "3" encodes the 3rd most frequent word in + the data. This allows for quick filtering operations such as: + "only consider the top 10,000 most + common words, but eliminate the top 20 most common words". + + As a convention, "0" does not stand for a specific word, but instead is used + to encode the pad token. + + Args: + path: where to cache the data (relative to `~/.keras/dataset`). + num_words: integer or None. Words are + ranked by how often they occur (in the training set) and only + the `num_words` most frequent words are kept. Any less frequent word + will appear as `oov_char` value in the sequence data. If None, + all words are kept. Defaults to `None`. + skip_top: skip the top N most frequently occurring words + (which may not be informative). These words will appear as + `oov_char` value in the dataset. When 0, no words are + skipped. Defaults to `0`. + maxlen: int or None. Maximum sequence length. + Any longer sequence will be truncated. None, means no truncation. + Defaults to `None`. + seed: int. Seed for reproducible data shuffling. + start_char: int. The start of a sequence will be marked with this + character. 0 is usually the padding character. Defaults to `1`. + oov_char: int. The out-of-vocabulary character. + Words that were cut out because of the `num_words` or + `skip_top` limits will be replaced with this character. + index_from: int. Index actual words with this index and higher. + + Returns: + Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`. + + **`x_train`, `x_test`**: lists of sequences, which are lists of indexes + (integers). If the num_words argument was specific, the maximum + possible index value is `num_words - 1`. If the `maxlen` argument was + specified, the largest possible sequence length is `maxlen`. + + **`y_train`, `y_test`**: lists of integer labels (1 or 0). + + **Note**: The 'out of vocabulary' character is only used for + words that were present in the training set but are not included + because they're not making the `num_words` cut here. + Words that were not seen in the training set but are in the test set + have simply been skipped. + """ + origin_folder = ( + "https://storage.googleapis.com/tensorflow/tf-keras-datasets/" + ) + path = get_file( + fname=path, + origin=f"{origin_folder}imdb.npz", + file_hash=( # noqa: E501 + "69664113be75683a8fe16e3ed0ab59fda8886cb3cd7ada244f7d9544e4676b9f" + ), + ) + with np.load(path, allow_pickle=True) as f: + x_train, labels_train = f["x_train"], f["y_train"] + x_test, labels_test = f["x_test"], f["y_test"] + + rng = np.random.RandomState(seed) + indices = np.arange(len(x_train)) + rng.shuffle(indices) + x_train = x_train[indices] + labels_train = labels_train[indices] + + indices = np.arange(len(x_test)) + rng.shuffle(indices) + x_test = x_test[indices] + labels_test = labels_test[indices] + + if start_char is not None: + x_train = [[start_char] + [w + index_from for w in x] for x in x_train] + x_test = [[start_char] + [w + index_from for w in x] for x in x_test] + elif index_from: + x_train = [[w + index_from for w in x] for x in x_train] + x_test = [[w + index_from for w in x] for x in x_test] + else: + x_train = [[w for w in x] for x in x_train] + x_test = [[w for w in x] for x in x_test] + + if maxlen: + x_train, labels_train = remove_long_seq(maxlen, x_train, labels_train) + x_test, labels_test = remove_long_seq(maxlen, x_test, labels_test) + if not x_train or not x_test: + raise ValueError( + "After filtering for sequences shorter than maxlen=" + f"{str(maxlen)}, no sequence was kept. Increase maxlen." + ) + + xs = x_train + x_test + labels = np.concatenate([labels_train, labels_test]) + + if not num_words: + num_words = max(max(x) for x in xs) + + # by convention, use 2 as OOV word + # reserve 'index_from' (=3 by default) characters: + # 0 (padding), 1 (start), 2 (OOV) + if oov_char is not None: + xs = [ + [w if (skip_top <= w < num_words) else oov_char for w in x] + for x in xs + ] + else: + xs = [[w for w in x if skip_top <= w < num_words] for x in xs] + + idx = len(x_train) + x_train, y_train = np.array(xs[:idx], dtype="object"), labels[:idx] + x_test, y_test = np.array(xs[idx:], dtype="object"), labels[idx:] + return (x_train, y_train), (x_test, y_test) + + +@keras_export("keras.datasets.imdb.get_word_index") +def get_word_index(path="imdb_word_index.json"): + """Retrieves a dict mapping words to their index in the IMDB dataset. + + Args: + path: where to cache the data (relative to `~/.keras/dataset`). + + Returns: + The word index dictionary. Keys are word strings, values are their + index. + + Example: + + ```python + # Use the default parameters to keras.datasets.imdb.load_data + start_char = 1 + oov_char = 2 + index_from = 3 + # Retrieve the training sequences. + (x_train, _), _ = keras.datasets.imdb.load_data( + start_char=start_char, oov_char=oov_char, index_from=index_from + ) + # Retrieve the word index file mapping words to indices + word_index = keras.datasets.imdb.get_word_index() + # Reverse the word index to obtain a dict mapping indices to words + # And add `index_from` to indices to sync with `x_train` + inverted_word_index = dict( + (i + index_from, word) for (word, i) in word_index.items() + ) + # Update `inverted_word_index` to include `start_char` and `oov_char` + inverted_word_index[start_char] = "[START]" + inverted_word_index[oov_char] = "[OOV]" + # Decode the first sequence in the dataset + decoded_sequence = " ".join(inverted_word_index[i] for i in x_train[0]) + ``` + """ + origin_folder = ( + "https://storage.googleapis.com/tensorflow/tf-keras-datasets/" + ) + path = get_file( + fname=path, + origin=f"{origin_folder}imdb_word_index.json", + file_hash="bfafd718b763782e994055a2d397834f", + ) + with open(path) as f: + return json.load(f) diff --git a/keras/src/datasets/mnist.py b/keras/src/datasets/mnist.py new file mode 100644 index 000000000000..697801b92cdf --- /dev/null +++ b/keras/src/datasets/mnist.py @@ -0,0 +1,71 @@ +"""MNIST handwritten digits dataset.""" + +import numpy as np + +from keras.src.api_export import keras_export +from keras.src.utils.file_utils import get_file + + +@keras_export("keras.datasets.mnist.load_data") +def load_data(path="mnist.npz"): + """Loads the MNIST dataset. + + This is a dataset of 60,000 28x28 grayscale images of the 10 digits, + along with a test set of 10,000 images. + More info can be found at the + [MNIST homepage](http://yann.lecun.com/exdb/mnist/). + + Args: + path: path where to cache the dataset locally + (relative to `~/.keras/datasets`). + + Returns: + Tuple of NumPy arrays: `(x_train, y_train), (x_test, y_test)`. + + **`x_train`**: `uint8` NumPy array of grayscale image data with shapes + `(60000, 28, 28)`, containing the training data. Pixel values range + from 0 to 255. + + **`y_train`**: `uint8` NumPy array of digit labels (integers in range 0-9) + with shape `(60000,)` for the training data. + + **`x_test`**: `uint8` NumPy array of grayscale image data with shapes + `(10000, 28, 28)`, containing the test data. Pixel values range + from 0 to 255. + + **`y_test`**: `uint8` NumPy array of digit labels (integers in range 0-9) + with shape `(10000,)` for the test data. + + Example: + + ```python + (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data() + assert x_train.shape == (60000, 28, 28) + assert x_test.shape == (10000, 28, 28) + assert y_train.shape == (60000,) + assert y_test.shape == (10000,) + ``` + + License: + + Yann LeCun and Corinna Cortes hold the copyright of MNIST dataset, + which is a derivative work from original NIST datasets. + MNIST dataset is made available under the terms of the + [Creative Commons Attribution-Share Alike 3.0 license.]( + https://creativecommons.org/licenses/by-sa/3.0/) + """ + origin_folder = ( + "https://storage.googleapis.com/tensorflow/tf-keras-datasets/" + ) + path = get_file( + fname=path, + origin=f"{origin_folder}mnist.npz", + file_hash=( # noqa: E501 + "731c5ac602752760c8e48fbffcf8c3b850d9dc2a2aedcf2cc48468fc17b673d1" + ), + ) + with np.load(path, allow_pickle=True) as f: + x_train, y_train = f["x_train"], f["y_train"] + x_test, y_test = f["x_test"], f["y_test"] + + return (x_train, y_train), (x_test, y_test) diff --git a/keras/src/datasets/reuters.py b/keras/src/datasets/reuters.py new file mode 100644 index 000000000000..b35a81859578 --- /dev/null +++ b/keras/src/datasets/reuters.py @@ -0,0 +1,221 @@ +"""Reuters topic classification dataset.""" + +import json + +import numpy as np + +from keras.src.api_export import keras_export +from keras.src.utils.file_utils import get_file +from keras.src.utils.python_utils import remove_long_seq + + +@keras_export("keras.datasets.reuters.load_data") +def load_data( + path="reuters.npz", + num_words=None, + skip_top=0, + maxlen=None, + test_split=0.2, + seed=113, + start_char=1, + oov_char=2, + index_from=3, +): + """Loads the Reuters newswire classification dataset. + + This is a dataset of 11,228 newswires from Reuters, labeled over 46 topics. + + This was originally generated by parsing and preprocessing the classic + Reuters-21578 dataset, but the preprocessing code is no longer packaged + with Keras. See this + [GitHub discussion](https://github.com/keras-team/keras/issues/12072) + for more info. + + Each newswire is encoded as a list of word indexes (integers). + For convenience, words are indexed by overall frequency in the dataset, + so that for instance the integer "3" encodes the 3rd most frequent word in + the data. This allows for quick filtering operations such as: + "only consider the top 10,000 most + common words, but eliminate the top 20 most common words". + + As a convention, "0" does not stand for a specific word, but instead is used + to encode any unknown word. + + Args: + path: where to cache the data (relative to `~/.keras/dataset`). + num_words: integer or None. Words are + ranked by how often they occur (in the training set) and only + the `num_words` most frequent words are kept. Any less frequent word + will appear as `oov_char` value in the sequence data. If None, + all words are kept. Defaults to `None`. + skip_top: skip the top N most frequently occurring words + (which may not be informative). These words will appear as + `oov_char` value in the dataset. 0 means no words are + skipped. Defaults to `0`. + maxlen: int or None. Maximum sequence length. + Any longer sequence will be truncated. None means no truncation. + Defaults to `None`. + test_split: Float between `0.` and `1.`. Fraction of the dataset to be + used as test data. `0.2` means that 20% of the dataset is used as + test data. Defaults to `0.2`. + seed: int. Seed for reproducible data shuffling. + start_char: int. The start of a sequence will be marked with this + character. 0 is usually the padding character. Defaults to `1`. + oov_char: int. The out-of-vocabulary character. + Words that were cut out because of the `num_words` or + `skip_top` limits will be replaced with this character. + index_from: int. Index actual words with this index and higher. + + Returns: + Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`. + + **`x_train`, `x_test`**: lists of sequences, which are lists of indexes + (integers). If the num_words argument was specific, the maximum + possible index value is `num_words - 1`. If the `maxlen` argument was + specified, the largest possible sequence length is `maxlen`. + + **`y_train`, `y_test`**: lists of integer labels (1 or 0). + + **Note**: The 'out of vocabulary' character is only used for + words that were present in the training set but are not included + because they're not making the `num_words` cut here. + Words that were not seen in the training set but are in the test set + have simply been skipped. + """ + origin_folder = ( + "https://storage.googleapis.com/tensorflow/tf-keras-datasets/" + ) + path = get_file( + fname=path, + origin=f"{origin_folder}reuters.npz", + file_hash=( # noqa: E501 + "d6586e694ee56d7a4e65172e12b3e987c03096cb01eab99753921ef915959916" + ), + ) + with np.load(path, allow_pickle=True) as f: + xs, labels = f["x"], f["y"] + + rng = np.random.RandomState(seed) + indices = np.arange(len(xs)) + rng.shuffle(indices) + xs = xs[indices] + labels = labels[indices] + + if start_char is not None: + xs = [[start_char] + [w + index_from for w in x] for x in xs] + elif index_from: + xs = [[w + index_from for w in x] for x in xs] + + if maxlen: + xs, labels = remove_long_seq(maxlen, xs, labels) + + if not num_words: + num_words = max(max(x) for x in xs) + + # by convention, use 2 as OOV word + # reserve 'index_from' (=3 by default) characters: + # 0 (padding), 1 (start), 2 (OOV) + if oov_char is not None: + xs = [ + [w if skip_top <= w < num_words else oov_char for w in x] + for x in xs + ] + else: + xs = [[w for w in x if skip_top <= w < num_words] for x in xs] + + idx = int(len(xs) * (1 - test_split)) + x_train, y_train = ( + np.array(xs[:idx], dtype="object"), + np.array(labels[:idx]), + ) + x_test, y_test = np.array(xs[idx:], dtype="object"), np.array(labels[idx:]) + + return (x_train, y_train), (x_test, y_test) + + +@keras_export("keras.datasets.reuters.get_word_index") +def get_word_index(path="reuters_word_index.json"): + """Retrieves a dict mapping words to their index in the Reuters dataset. + + Actual word indices starts from 3, with 3 indices reserved for: + 0 (padding), 1 (start), 2 (oov). + + E.g. word index of 'the' is 1, but the in the actual training data, the + index of 'the' will be 1 + 3 = 4. Vice versa, to translate word indices in + training data back to words using this mapping, indices need to subtract 3. + + Args: + path: where to cache the data (relative to `~/.keras/dataset`). + + Returns: + The word index dictionary. Keys are word strings, values are their + index. + """ + origin_folder = ( + "https://storage.googleapis.com/tensorflow/tf-keras-datasets/" + ) + path = get_file( + path, + origin=f"{origin_folder}reuters_word_index.json", + file_hash="4d44cc38712099c9e383dc6e5f11a921", + ) + with open(path) as f: + return json.load(f) + + +@keras_export("keras.datasets.reuters.get_label_names") +def get_label_names(): + """Returns labels as a list of strings with indices matching training data. + + Reference: + + - [Reuters Dataset](https://martin-thoma.com/nlp-reuters/) + """ + return ( + "cocoa", + "grain", + "veg-oil", + "earn", + "acq", + "wheat", + "copper", + "housing", + "money-supply", + "coffee", + "sugar", + "trade", + "reserves", + "ship", + "cotton", + "carcass", + "crude", + "nat-gas", + "cpi", + "money-fx", + "interest", + "gnp", + "meal-feed", + "alum", + "oilseed", + "gold", + "tin", + "strategic-metal", + "livestock", + "retail", + "ipi", + "iron-steel", + "rubber", + "heat", + "jobs", + "lei", + "bop", + "zinc", + "orange", + "pet-chem", + "dlr", + "gas", + "silver", + "wpi", + "hog", + "lead", + ) diff --git a/keras/src/distribution/__init__.py b/keras/src/distribution/__init__.py new file mode 100644 index 000000000000..04d907f35697 --- /dev/null +++ b/keras/src/distribution/__init__.py @@ -0,0 +1,11 @@ +from keras.src.distribution.distribution_lib import DataParallel +from keras.src.distribution.distribution_lib import DeviceMesh +from keras.src.distribution.distribution_lib import Distribution +from keras.src.distribution.distribution_lib import LayoutMap +from keras.src.distribution.distribution_lib import ModelParallel +from keras.src.distribution.distribution_lib import TensorLayout +from keras.src.distribution.distribution_lib import distribute_tensor +from keras.src.distribution.distribution_lib import distribution +from keras.src.distribution.distribution_lib import initialize +from keras.src.distribution.distribution_lib import list_devices +from keras.src.distribution.distribution_lib import set_distribution diff --git a/keras/src/distribution/distribution_lib.py b/keras/src/distribution/distribution_lib.py new file mode 100644 index 000000000000..2daef40a2ed8 --- /dev/null +++ b/keras/src/distribution/distribution_lib.py @@ -0,0 +1,898 @@ +"""Unified high-level distribution APIs across backends. + +Currently only the JAX backend is supported. The TensorFlow backend +will be supported in the future (via tf.dtensor API). +""" + +import collections +import contextlib +import os +import re +import warnings + +import numpy as np + +from keras.src.api_export import keras_export +from keras.src.backend import KerasTensor +from keras.src.backend import distribution_lib +from keras.src.backend.common import global_state + +DEFAULT_BATCH_DIM_NAME = "batch" +GLOBAL_ATTRIBUTE_NAME = "distribution" + + +@keras_export("keras.distribution.list_devices") +def list_devices(device_type=None): + """Return all the available devices based on the device type. + + Note: in a distributed setting, global devices are returned. + + Args: + device_type: string, one of `"cpu"`, `"gpu"` or `"tpu"`. + Defaults to `"gpu"` or `"tpu"` if available when + `device_type` is not provided. Otherwise + will return the `"cpu"` devices. + + Return: + List of devices that are available for distribute computation. + """ + return distribution_lib.list_devices(device_type) + + +@keras_export("keras.distribution.initialize") +def initialize(job_addresses=None, num_processes=None, process_id=None): + """Initialize the distribution system for multi-host/process setting. + + Calling `initialize` will prepare the backend for execution on multi-host + GPU or TPUs. It should be called before any computations. + + Note that the parameters can also be injected via environment variables, + which can be better controlled by the launch script at startup time. + For certain backend that also rely on the environment variables to + configure, Keras will properly forward them. + + Args: + job_addresses: string. Comma separated IP addresses for all the jobs + that will form the whole computation cluster. Note that for JAX + backend, only the address for job 0 (coodinator) is needed. For + certain runtime like cloud TPU, this value can be `None`, and the + backend will figure it out with the TPU environment variables. You + can also config this value via environment variable + `KERAS_DISTRIBUTION_JOB_ADDRESSES`. + num_processes: int. The number of worker/processes that will form the + whole computation cluster. For certain runtime like cloud TPU, this + value can be `None`, and the backend will figure it out with the TPU + environment variables. You can also configure this value via + environment variable `KERAS_DISTRIBUTION_NUM_PROCESSES`. + process_id: int. The ID number of the current worker/process. The value + should be ranged from `0` to `num_processes - 1`. `0` will indicate + the current worker/process is the master/coordinate job. You can + also configure this value via environment variable + `KERAS_DISTRIBUTION_PROCESS_ID`. + + Example: + Suppose there are two GPU processes, and process 0 is running at + address `10.0.0.1:1234`, and process 1 is running at address + `10.0.0.2:2345`. To configure such cluster, you can run + + On process 0: + ```python + keras.distribute.initialize( + job_addresses="10.0.0.1:1234,10.0.0.2:2345", + num_processes=2, + process_id=0) + ``` + + On process 1: + ```python + keras.distribute.initialize( + job_addresses="10.0.0.1:1234,10.0.0.2:2345", + num_processes=2, + process_id=1) + ``` + + or via the environment variables: + On process 0: + ```python + os.environ[ + "KERAS_DISTRIBUTION_JOB_ADDRESSES"] = "10.0.0.1:1234,10.0.0.2:2345" + os.environ["KERAS_DISTRIBUTION_NUM_PROCESSES"] = "2" + os.environ["KERAS_DISTRIBUTION_PROCESS_ID"] = "0" + keras.distribute.initialize() + ``` + + On process 1: + ```python + os.environ[ + "KERAS_DISTRIBUTION_JOB_ADDRESSES"] = "10.0.0.1:1234,10.0.0.2:2345" + os.environ["KERAS_DISTRIBUTION_NUM_PROCESSES"] = "2" + os.environ["KERAS_DISTRIBUTION_PROCESS_ID"] = "1" + keras.distribute.initialize() + ``` + + Also note that for JAX backend, the `job_addresses` can be further + reduced to just the master/coordinator address, which is + `10.0.0.1:1234`. + """ + if ( + job_addresses is None + and "KERAS_DISTRIBUTION_JOB_ADDRESSES" in os.environ + ): + job_addresses = os.environ["KERAS_DISTRIBUTION_JOB_ADDRESSES"] + if ( + num_processes is None + and "KERAS_DISTRIBUTION_NUM_PROCESSES" in os.environ + ): + num_processes = int(os.environ["KERAS_DISTRIBUTION_NUM_PROCESSES"]) + if process_id is None and "KERAS_DISTRIBUTION_PROCESS_ID" in os.environ: + process_id = int(os.environ["KERAS_DISTRIBUTION_PROCESS_ID"]) + distribution_lib.initialize(job_addresses, num_processes, process_id) + + +@keras_export("keras.distribution.DeviceMesh") +class DeviceMesh: + """A cluster of computation devices for distributed computation. + + This API is aligned with `jax.sharding.Mesh` and `tf.dtensor.Mesh`, which + represents the computation devices in the global context. + + See more details in [jax.sharding.Mesh]( + https://jax.readthedocs.io/en/latest/jax.sharding.html#jax.sharding.Mesh) + and [tf.dtensor.Mesh]( + https://www.tensorflow.org/api_docs/python/tf/experimental/dtensor/Mesh). + + Args: + shape: tuple of list of integers. The shape of the overall + `DeviceMesh`, e.g. `(8,)` for a data parallel only distribution, + or `(4, 2)` for a model+data parallel distribution. + axis_names: List of string. The logical name of the each axis for + the `DeviceMesh`. The length of the `axis_names` should match to + the rank of the `shape`. The `axis_names` will be used to + match/create the `TensorLayout` when distribute the data and + variables. + devices: Optional list of devices. Defaults to all the available + devices locally from `keras.distribution.list_devices()`. + """ + + def __init__( + self, + shape, + axis_names, + devices=None, + ): + if not shape or not axis_names: + raise ValueError( + "Shape and axis_names cannot be empty. Received: " + f"shape={shape}, axis_names={axis_names}" + ) + + if len(shape) != len(axis_names): + raise ValueError( + "Shape and axis_names should have same size. " + f"Received: shape={shape}, axis_names={axis_names}" + ) + if devices is None: + devices = list_devices() + devices = np.array(devices) + if np.prod(shape) != np.prod(devices.shape): + raise ValueError( + "Shape does not match the number of devices. " + f"Received: shape={shape}; devices.shape=" + f"{devices.shape}" + ) + + self._shape = shape + self._axis_names = axis_names + self._devices = np.reshape(devices, shape) + + @property + def shape(self): + return self._shape + + @property + def axis_names(self): + return self._axis_names + + @property + def devices(self): + return self._devices + + @property + def backend_mesh(self): + if not hasattr(self, "_backend_mesh"): + self._backend_mesh = distribution_lib._to_backend_mesh(self) + return self._backend_mesh + + def __repr__(self): + return ( + f"<{self.__class__.__name__} " + f"shape={self.shape}, axis_names={self.axis_names}>" + ) + + def __str__(self): + return self.__repr__() + + +@keras_export("keras.distribution.TensorLayout") +class TensorLayout: + """A layout to apply to a tensor. + + This API is aligned with `jax.sharding.NamedSharding` + and `tf.dtensor.Layout`. + + See more details in [jax.sharding.NamedSharding]( + https://jax.readthedocs.io/en/latest/jax.sharding.html#jax.sharding.NamedSharding) + and [tf.dtensor.Layout]( + https://www.tensorflow.org/api_docs/python/tf/experimental/dtensor/Layout). + + Args: + axes: tuple of strings that should map to the `axis_names` in + a `DeviceMesh`. For any dimensions that doesn't need any sharding, + A `None` can be used a placeholder. + device_mesh: Optional `DeviceMesh` that will be used to create + the layout. The actual mapping of tensor to physical device + is not known until the mesh is specified. + """ + + def __init__(self, axes, device_mesh=None): + self._axes = tuple(axes) + self._device_mesh = device_mesh + self._validate_axes() + + @property + def axes(self): + return self._axes + + @property + def device_mesh(self): + return self._device_mesh + + @device_mesh.setter + def device_mesh(self, device_mesh): + if self._device_mesh is not None: + raise ValueError( + "Cannot override device mesh value. Existing " + f"value is {self._device_mesh}" + ) + self._device_mesh = device_mesh + self._validate_axes() + + @property + def backend_layout(self): + if not hasattr(self, "_backend_layout"): + self._backend_layout = distribution_lib._to_backend_layout(self) + return self._backend_layout + + def _validate_axes(self): + if self._device_mesh: + valid_axis_names = set(self._device_mesh.axis_names) + axis_names = set(self._axes) - set([None]) + if axis_names - valid_axis_names: + raise ValueError( + "Invalid axis names for Layout. Valid axis " + f"names: {valid_axis_names}, Got {axis_names}" + ) + + def __repr__(self): + return ( + f"<{self.__class__.__name__} " + f"axes={self.axes}, device_mesh={self.device_mesh}>" + ) + + def __str__(self): + return self.__repr__() + + +class Distribution: + """Base class for variable distribution strategies. + + A `Distribution` has following key functionalities: + + 1. Distribute the model variables to a `DeviceMesh`. + 2. Distribute the input data to a `DeviceMesh`. + 3. Distribute an intermediate state tensor in the model. + + It can create a context scope so that the framework to properly detect the + `Distribution` and distribute the variable/data accordingly. + + Args: + device_mesh: A `DeviceMesh` instance. + batch_dim_name: Optional string name for the batch dimension. + Defaults to None. + auto_shard_dataset: Automatically shard the dataset amongst + processes in a multi-process setting. Set to `False` if the dataset + is already sharded across hosts. Defaults to `True`. + """ + + def __init__( + self, device_mesh, batch_dim_name=None, auto_shard_dataset=True + ): + self._device_mesh = device_mesh + self._batch_dim_name = batch_dim_name + self._auto_shard_dataset = auto_shard_dataset + + def get_data_layout(self, data_shape): + """Retrieve the `TensorLayout` for the input data. + + Args: + data_shape: shape for the input data in list or tuple format. + + Returns: + The `TensorLayout` for the data, which can be used by + `backend.distribute_value()` to redistribute a input data. + """ + raise NotImplementedError() + + def get_variable_layout(self, variable): + """Retrieve the `TensorLayout` for the variable. + + Args: + variable: A `Variable` instance. + + return: + The `TensorLayout` for the variable, which can be used by + `backend.distribute_value()` to redistribute a variable. + """ + raise NotImplementedError() + + def get_tensor_layout(self, path): + """Retrieve the `TensorLayout` for the intermediate tensor. + + Args: + path: a string path for the corresponding tensor. + + return: + The `TensorLayout` for the intermediate tensor, which can be used + by `backend.relayout()` to reshard the tensor. Could also return + None. + """ + raise NotImplementedError() + + @contextlib.contextmanager + def scope(self): + """Context manager to make the `Distribution` current.""" + original_scope = distribution() + set_distribution(self) + try: + yield + finally: + set_distribution(original_scope) + + @property + def device_mesh(self): + return self._device_mesh + + @property + def batch_dim_name(self): + return self._batch_dim_name + + @property + def auto_shard_dataset(self): + return self._auto_shard_dataset + + @auto_shard_dataset.setter + def auto_shard_dataset(self, auto_shard_dataset): + self._auto_shard_dataset = auto_shard_dataset + + def distribute_dataset(self, dataset): + """Create a distributed dataset from the original global dataset. + + Args: + dataset: the original global dataset instance. + + Returns: + If `auto_shard_dataset` is `True`, returns a sharded dataset that + only produces data for the current local worker/process. Otherwise, + returns the original dataset. + + Raises: + ValueError: if auto-sharding is requested in a multi-process + setting, but the dataset type is not supported. + """ + raise NotImplementedError() + + def __repr__(self): + return f"<{self.__class__.__name__} device_mesh={self.device_mesh}>" + + def __str__(self): + return self.__repr__() + + +@keras_export("keras.distribution.DataParallel") +class DataParallel(Distribution): + """Distribution for data parallelism. + + You can choose to create this instance by either specifying + the `device_mesh` or `devices` arguments (but not both). + + The `device_mesh` argument is expected to be a `DeviceMesh` instance, + and is expected to be 1D only. In case that the mesh has multiple axes, + then the first axis will be treated as the data parallel dimension + (and a warning will be raised). + + When a list of `devices` are provided, they will be used to construct a + 1D mesh. + + When both `mesh` and `devices` are absent, then `list_devices()` + will be used to detect any available devices and create a 1D mesh from + them. + + Args: + device_mesh: Optional `DeviceMesh` instance. + devices: Optional list of devices. + auto_shard_dataset: Automatically shard the dataset amongst + processes in a multi-process setting. Set to `False` if the dataset + is already sharded across hosts. Defaults to `True`. + """ + + def __init__(self, device_mesh=None, devices=None, auto_shard_dataset=True): + if device_mesh: + self._initialize_with_device_mesh(device_mesh, auto_shard_dataset) + elif devices: + self._initialize_mesh_from_devices(devices, auto_shard_dataset) + else: + self._initialize_mesh_from_list_devices(auto_shard_dataset) + + # Those following attributes might get convert to public methods. + self._num_process = distribution_lib.num_processes() + self._process_id = distribution_lib.process_id() + self._is_multi_process = self._num_process > 1 + + def _initialize_with_device_mesh(self, device_mesh, auto_shard_dataset): + if not isinstance(device_mesh, DeviceMesh): + raise ValueError( + "Expect `mesh` to be an instance of `DeviceMesh`. " + f"Received: mesh={device_mesh} (of type {type(device_mesh)})" + ) + super().__init__( + device_mesh, device_mesh.axis_names[0], auto_shard_dataset + ) + if self.device_mesh.devices.ndim != 1: + warnings.warn( + "Expect the input mesh to be 1D, but received " + "mesh.devices.ndim=%d. " + "The first axis will be used for data-parallel sharding.", + device_mesh.devices.ndim, + ) + + def _initialize_mesh_from_devices(self, devices, auto_shard_dataset): + devices = np.array(devices) + device_mesh = DeviceMesh( + shape=devices.shape, + axis_names=[DEFAULT_BATCH_DIM_NAME], + devices=devices, + ) + super().__init__( + device_mesh, DEFAULT_BATCH_DIM_NAME, auto_shard_dataset + ) + + def _initialize_mesh_from_list_devices(self, auto_shard_dataset): + devices = np.array(list_devices()) + device_mesh = DeviceMesh( + shape=devices.shape, + axis_names=[DEFAULT_BATCH_DIM_NAME], + devices=devices, + ) + super().__init__( + device_mesh, DEFAULT_BATCH_DIM_NAME, auto_shard_dataset + ) + + def get_data_layout(self, data_shape): + data_shard_spec = [None] * len(data_shape) + data_shard_spec[0] = self.batch_dim_name # Shard on the first dim + return TensorLayout(data_shard_spec, self.device_mesh) + + def get_variable_layout(self, variable): + # First check if the variable already has a layout assigned. + if getattr(variable, "_layout", None) is not None: + return variable._layout + # Otherwise, replicate variable. + variable_shard_spec = [None] * len(variable.shape) + return TensorLayout(variable_shard_spec, self.device_mesh) + + def get_tensor_layout(self, path): + # For data parallel training, the intermediate state is not changed. + return None + + def distribute_dataset(self, dataset): + if not self._is_multi_process or not self.auto_shard_dataset: + return dataset + + # Try to distribute a global tf.data.Dataset. + from keras.src.utils.module_utils import tensorflow as tf + + if not tf.available or not isinstance(dataset, tf.data.Dataset): + raise ValueError( + "Only `tf.data.Dataset` is supported for auto-sharding, " + f"got {type(dataset)}" + ) + + from tensorflow.python.data.experimental.ops import ( + distribute as tf_data_distribute, + ) + + batch_size = tf_data_distribute.compute_batch_size(dataset) + if batch_size.numpy() < 0: + raise ValueError( + "The batch size of the input dataset is " + "unknown. Please config the batch size for " + "the input dataset, e.g via `dataset.batch(batch_size)`" + ) + per_worker_batch_size = tf_data_distribute.batch_sizes_for_worker( + global_batch_size=batch_size, + num_workers=self._num_process, + num_replicas_per_worker=1, # We hard code this for now. + worker_index=self._process_id, + ) + distributed_dataset = dataset.rebatch(per_worker_batch_size) + distributed_dataset = tf_data_distribute._AutoShardDataset( + distributed_dataset, + num_workers=self._num_process, + index=self._process_id, + num_replicas=self._num_process, + ) + return distributed_dataset.prefetch(tf.data.AUTOTUNE) + + +@keras_export("keras.distribution.ModelParallel") +class ModelParallel(Distribution): + """Distribution that shards model variables. + + Compare to `DataParallel` which replicates the variables across all devices, + `ModelParallel` allows you to shard variables in addition to the input data. + + To construct a `ModelParallel` distribution, you need to provide a + `DeviceMesh` and a `LayoutMap`. + + 1. `DeviceMesh` contains physical device information. The axis names in + the mesh will be used to map the variable and data layout. + 2. `LayoutMap` contains the mapping between variable paths to their + corresponding `TensorLayout`. + + Example: + + ```python + devices = list_devices() # Assume there are 8 devices. + + # Create a mesh with 2 devices for data parallelism and 4 devices for + # model parallelism. + device_mesh = DeviceMesh(shape=(2, 4), axis_names=('batch', 'model'), + devices=devices) + # Create a layout map that shard the `Dense` layer and `Conv2D` + # layer variables on the last dimension. + # Based on the `device_mesh`, this means the variables + # will be split across 4 devices. Any other variable that doesn't + # match any key in the layout map will be fully replicated. + layout_map = LayoutMap(device_mesh) + layout_map['dense.*kernel'] = (None, 'model') + layout_map['dense.*bias'] = ('model',) + layout_map['conv2d.*kernel'] = (None, None, None, 'model') + layout_map['conv2d.*bias'] = ('model',) + + distribution = ModelParallel( + layout_map=layout_map, + batch_dim_name='batch', + ) + + # Set the global distribution, or via `with distribution.scope():` + set_distribution(distribution) + + model = model_creation() + model.compile() + model.fit(data) + ``` + + You can quickly update the device mesh shape to change the sharding factor + of the variables. E.g. + + ```python + # With only the shape change for the device mesh, the variables will be + # sharded across 8 devices instead of 4, which further reduces the memory + # footprint of variables on each of the device. + device_mesh = DeviceMesh( + shape=(1, 8), + axis_names=('batch', 'model'), + devices=devices, + ) + ``` + + To figure out a proper layout mapping rule for all the model variables, you + can first list out all the model variable paths, which will be used as the + key to map the variables to `TensorLayout`. + + e.g. + + ```python + model = create_model() + for v in model.variables: + print(v.path) + ``` + + Args: + layout_map: `LayoutMap` instance which map the variable path to the + corresponding tensor layout. + batch_dim_name: Optional string, the axis name in the device mesh + (of the `layout_map` object) + that will be used to distribute data. If unspecified, the + first axis from the device mesh will be used. + auto_shard_dataset: Automatically shard the dataset amongst + processes in a multi-process setting. Set to `False` if the dataset + is already sharded across hosts. Defaults to `True`. + """ + + def __init__( + self, + *, + layout_map=None, + batch_dim_name=None, + auto_shard_dataset=True, + **kwargs, + ): + kwargs.pop("device_mesh", None) + if layout_map is None: + raise ValueError("You must specify a layout_map argument.") + if not isinstance(layout_map, LayoutMap): + raise ValueError( + "Argument `layout_map` must be a `LayoutMap` instance. " + f"Received: layout_map={layout_map}" + ) + device_mesh = layout_map.device_mesh + batch_dim_name = batch_dim_name or device_mesh.axis_names[0] + super().__init__(device_mesh, batch_dim_name, auto_shard_dataset) + self._layout_map = layout_map + + # Those following attributes might get convert to public methods. + self._num_process = distribution_lib.num_processes() + self._process_id = distribution_lib.process_id() + self._is_multi_process = self._num_process > 1 + + def get_data_layout(self, data_shape): + data_shard_spec = [None] * len(data_shape) + data_shard_spec[0] = self.batch_dim_name # Shard on the first dim + return TensorLayout(data_shard_spec, self.device_mesh) + + def get_variable_layout(self, variable): + # First check if the variable already has a layout assigned. + if getattr(variable, "_layout", None) is not None: + return variable._layout + # Check the layout map. + variable_layout = self._layout_map[variable.path] + if variable_layout is not None: + return variable_layout + variable_shard_spec = [None] * len(variable.shape) + return TensorLayout(variable_shard_spec, self.device_mesh) + + def get_tensor_layout(self, path): + return self._layout_map[path] + + def distribute_dataset(self, dataset): + if not self._is_multi_process or not self.auto_shard_dataset: + return dataset + + # Try to distribute a global tf.data.Dataset. + from keras.src.utils.module_utils import tensorflow as tf + + if not tf.available or not isinstance(dataset, tf.data.Dataset): + raise ValueError( + "Only `tf.data.Dataset` is supported for auto-sharding, " + f"got {type(dataset)}" + ) + + from tensorflow.python.data.experimental.ops import ( + distribute as tf_data_distribute, + ) + + global_batch_size = tf_data_distribute.compute_batch_size(dataset) + if global_batch_size.numpy() < 0: + raise ValueError( + "The batch size of the input dataset is " + "unknown. Please config the batch size for " + "the input dataset, e.g via `dataset.batch(batch_size)`" + ) + + # We need to compute the per-process/worker/host batch size. + # This will depend on how many model replicas we have on each process. + # Note that this might be smaller than one if model replicas are sharded + # across multiple processes. + mesh_batch_dim_index = self.device_mesh.axis_names.index( + self.batch_dim_name + ) + num_model_replicas = self.device_mesh.shape[mesh_batch_dim_index] + if num_model_replicas == 1: + # No sharding is needed in this case. Each process will have the + # global batch size, and data from the iterator will need to be + # replicated across all processes. + return dataset.prefetch(tf.data.AUTOTUNE) + num_model_replicas_per_process = num_model_replicas / self._num_process + if num_model_replicas_per_process >= 1: + # Each process will have one or more full model replicas. Data will + # be sharded across all processes without replication. + if global_batch_size % self._num_process != 0: + raise ValueError( + "Global batch size must be divisible by the number of " + f"processes. `global_batch_size`={global_batch_size} and " + f"`num_process`={self._num_process}" + ) + per_process_batch_size = global_batch_size // self._num_process + distributed_dataset = dataset.rebatch(per_process_batch_size) + distributed_dataset = distributed_dataset.shard( + num_shards=self._num_process, + index=self._process_id, + ) + return distributed_dataset.prefetch(tf.data.AUTOTUNE) + else: + # Model replicas are sharded across multiple processes. Data will be + # sharded across model replicas, and replicated across processes + # within the same model replica. + if global_batch_size % num_model_replicas != 0: + raise ValueError( + "Global batch size must be divisible by the number of " + f"replicas. `global_batch_size`={global_batch_size} and " + f"`num_model_replicas`={num_model_replicas}" + ) + per_process_batch_size = global_batch_size // num_model_replicas + distributed_dataset = dataset.rebatch(per_process_batch_size) + processes_per_replica = self._num_process // num_model_replicas + # TODO: Figure out what the convention is for data sharding id. + data_shard_id = self._process_id % processes_per_replica + distributed_dataset = distributed_dataset.shard( + num_shards=num_model_replicas, + index=data_shard_id, + ) + return distributed_dataset.prefetch(tf.data.AUTOTUNE) + + +@keras_export("keras.distribution.LayoutMap") +class LayoutMap(collections.abc.MutableMapping): + """A dict-like object that maps string to `TensorLayout` instances. + + `LayoutMap` uses a string as key and a `TensorLayout` as value. There is a + behavior difference between a normal Python dict and this class. The string + key will be treated as a regex when retrieving the value. See the docstring + of `get` for more details. + + See below for a usage example. You can define the naming schema + of the `TensorLayout`, and then retrieve the corresponding + `TensorLayout` instance. + + In the normal case, the key to query is usually the `variable.path`, which + is the identifier of the variable. + + As shortcut, tuple or list of axis names are also allowed when inserting + as value, and will be converted to `TensorLayout`. + + ```python + layout_map = LayoutMap(device_mesh) + layout_map['dense.*kernel'] = (None, 'model') + layout_map['dense.*bias'] = ('model',) + layout_map['conv2d.*kernel'] = (None, None, None, 'model') + layout_map['conv2d.*bias'] = ('model',) + + layout_1 = layout_map['dense_1.kernel'] # layout_1 == layout_2d + layout_2 = layout_map['dense_1.bias'] # layout_2 == layout_1d + layout_3 = layout_map['dense_2.kernel'] # layout_3 == layout_2d + layout_4 = layout_map['dense_2.bias'] # layout_4 == layout_1d + layout_5 = layout_map['my_model/conv2d_123/kernel'] # layout_5 == layout_4d + layout_6 = layout_map['my_model/conv2d_123/bias'] # layout_6 == layout_1d + layout_7 = layout_map['my_model/conv3d_1/kernel'] # layout_7 == None + layout_8 = layout_map['my_model/conv3d_1/bias'] # layout_8 == None + ``` + + Args: + device_mesh: `keras.distribution.DeviceMesh` instance. + """ + + def __init__(self, device_mesh): + self._layout_map = collections.OrderedDict() + self._device_mesh = device_mesh + + def __getitem__(self, key): + """Retrieves the corresponding layout by the string key. + + When there isn't an exact match, all the existing keys in the layout map + will be treated as a regex and map against the input key again. When + there are multiple matches for the regex, an `ValueError` will be + raised. Returns `None` if there isn't any match found. + + Args: + key: String key to query a layout. + + Returns: + Corresponding layout based on the query. + """ + if key in self._layout_map: + return self._layout_map[key] + + matching_keys = [] + for k in self._layout_map: + if re.search(k, key): + matching_keys.append(k) + if len(matching_keys) > 1: + raise ValueError( + f"Path '{key}' matches multiple layout " + f"specification keys: {matching_keys}. Please make " + "sure each tensor/variable path only matches at most " + "one layout specification key in the LayoutMap." + ) + elif len(matching_keys) == 1: + return self._layout_map[matching_keys[0]] + return None + + def __setitem__(self, key, layout): + """Insert TensorLayout to the LayoutMap. + + Args: + key: String key for the `TensorLayout`. + layout: The `TensorLayout`. As a shortcut, tuple of string and None + are also acceptable, and will be converted to `TensorLayout`. + """ + if key in self._layout_map: + raise ValueError( + f"{key} already exist in the LayoutMap with " + f"value {self._layout_map[key]}. Please make sure to " + "not use duplicated keys." + ) + if isinstance(layout, tuple): + layout = TensorLayout(axes=layout, device_mesh=None) + + if not isinstance(layout, TensorLayout): + raise ValueError( + f"{layout} should be a TensorLayout type, got {type(layout)}" + ) + self._maybe_populate_device_mesh(layout) + self._layout_map[key] = layout + + def __delitem__(self, key): + # let the dict to handle the key missing error + return self._layout_map.pop(key) + + def __len__(self): + return len(self._layout_map) + + def __iter__(self): + return iter(self._layout_map) + + @property + def device_mesh(self): + return self._device_mesh + + def _maybe_populate_device_mesh(self, layout): + if layout.device_mesh is None and self.device_mesh is not None: + layout.device_mesh = self.device_mesh + + +LayoutMap.get.__doc__ = LayoutMap.__getitem__.__doc__ + + +@keras_export("keras.distribution.distribute_tensor") +def distribute_tensor(tensor, layout): + """Change the layout of a Tensor value in the jit function execution. + + Args: + tensor: a Tensor to change the layout. + layout: `TensorLayout` to be applied on the value. + + Returns: + a new value with the specified tensor layout. + """ + if isinstance(tensor, KerasTensor): + # keras tensor is only used for building functional model, and can't be + # used to alter layout/sharding. + return tensor + return distribution_lib.distribute_tensor(tensor, layout) + + +@keras_export("keras.distribution.distribution") +def distribution(): + """Retrieve the current distribution from global context.""" + return global_state.get_global_attribute(GLOBAL_ATTRIBUTE_NAME) + + +@keras_export("keras.distribution.set_distribution") +def set_distribution(value): + """Set the distribution as the global distribution setting. + + Args: + value: a `Distribution` instance. + """ + global_state.set_global_attribute(GLOBAL_ATTRIBUTE_NAME, value) diff --git a/keras/src/distribution/distribution_lib_test.py b/keras/src/distribution/distribution_lib_test.py new file mode 100644 index 000000000000..66f996b3fb68 --- /dev/null +++ b/keras/src/distribution/distribution_lib_test.py @@ -0,0 +1,537 @@ +"""Test for distribution_lib.py.""" + +import os +from unittest import mock + +import numpy as np +import pytest +import tensorflow as tf + +from keras.src import backend +from keras.src import testing +from keras.src.backend import distribution_lib as backend_dlib +from keras.src.distribution import distribution_lib + + +@pytest.mark.skipif( + backend.backend() != "jax", + reason="Only JAX has the backend to mock at the moment", +) +@mock.patch.object( + backend_dlib, + "initialize", + return_value=None, +) +class MultiProcessInitializeTest(testing.TestCase): + def tearDown(self): + super().tearDown() + os.environ.clear() + + def test_initialize_with_explicit_param(self, mock_backend_initialize): + job_addresses = "10.0.0.1:1234,10.0.0.2:2345" + num_processes = 2 + current_process_id = 0 + + distribution_lib.initialize( + job_addresses, num_processes, current_process_id + ) + + mock_backend_initialize.assert_called_once_with( + job_addresses, num_processes, current_process_id + ) + + def test_initialize_with_env_vars(self, mock_backend_initialize): + job_addresses = "10.0.0.1:1234,10.0.0.2:2345" + num_processes = 2 + current_process_id = 0 + os.environ["KERAS_DISTRIBUTION_JOB_ADDRESSES"] = job_addresses + os.environ["KERAS_DISTRIBUTION_NUM_PROCESSES"] = str(num_processes) + os.environ["KERAS_DISTRIBUTION_PROCESS_ID"] = str(current_process_id) + + distribution_lib.initialize() + mock_backend_initialize.assert_called_once_with( + job_addresses, num_processes, current_process_id + ) + + def test_init_with_nones(self, mock_backend_initialize): + # This is also valid case for Cloud TPU on JAX + distribution_lib.initialize() + mock_backend_initialize.assert_called_once_with(None, None, None) + + +class DeviceMeshTest(testing.TestCase): + def test_mesh_creation(self): + devices = [f"cpu:{i}" for i in range(8)] + shape = (4, 2) + axis_names = ["batch", "model"] + + mesh = distribution_lib.DeviceMesh(shape, axis_names, devices) + self.assertEqual(mesh.shape, shape) + self.assertEqual(mesh.axis_names, axis_names) + self.assertEqual(mesh.devices.shape, shape) + + def test_input_validation(self): + devices = [f"cpu:{i}" for i in range(4)] + with self.assertRaisesRegex( + ValueError, "Shape and axis_names cannot be empty" + ): + distribution_lib.DeviceMesh((4,), "", devices) + + with self.assertRaisesRegex( + ValueError, "Shape and axis_names should have same size" + ): + distribution_lib.DeviceMesh((4, 2), ["batch"], devices) + + with self.assertRaisesRegex( + ValueError, "Shape does not match the number of devices" + ): + distribution_lib.DeviceMesh((4, 2), ["batch", "model"], devices) + + +class TensorLayoutTest(testing.TestCase): + def setUp(self): + self.mesh = distribution_lib.DeviceMesh( + (4, 2), ["data", "model"], [f"cpu:{i}" for i in range(8)] + ) + + def test_tensor_layout_creation(self): + axes = ("data", None) + layout = distribution_lib.TensorLayout(axes, self.mesh) + + self.assertEqual(layout.device_mesh, self.mesh) + self.assertEqual(layout.axes, axes) + + def test_tensor_layout_validation(self): + axes = ("data", "unknown", None) + with self.assertRaisesRegex( + ValueError, "Invalid axis names for Layout" + ): + distribution_lib.TensorLayout(axes, self.mesh) + + def test_lazy_device_mesh_injection(self): + axes = ("data", None) + layout = distribution_lib.TensorLayout(axes, None) + + self.assertIsNone(layout.device_mesh) + self.assertEqual(layout.axes, axes) + + layout.device_mesh = self.mesh + + self.assertEqual(layout.device_mesh, self.mesh) + self.assertEqual(layout.axes, axes) + + def test_lazy_device_mesh_validation(self): + axes = ("data", "unknown", None) + layout = distribution_lib.TensorLayout(axes, None) + + self.assertIsNone(layout.device_mesh) + self.assertEqual(layout.axes, axes) + + with self.assertRaisesRegex( + ValueError, "Invalid axis names for Layout" + ): + layout.device_mesh = self.mesh + + +class DistributionTest(testing.TestCase): + def setUp(self): + super().setUp() + devices = [f"cpu:{i}" for i in range(8)] + shape = (4, 2) + axis_names = ["batch", "model"] + + self.device_mesh = distribution_lib.DeviceMesh( + shape, axis_names, devices + ) + + def test_init_with_device_mesh(self): + distribution = distribution_lib.Distribution(self.device_mesh) + self.assertIs(distribution.device_mesh, self.device_mesh) + + def test_scope(self): + distribution_1 = distribution_lib.Distribution(self.device_mesh) + distribution_2 = distribution_lib.Distribution(self.device_mesh) + + self.assertIsNone(distribution_lib.distribution()) + with distribution_1.scope(): + self.assertIs(distribution_lib.distribution(), distribution_1) + with distribution_2.scope(): + self.assertIs(distribution_lib.distribution(), distribution_2) + + self.assertIs(distribution_lib.distribution(), distribution_1) + + self.assertIsNone(distribution_lib.distribution()) + + +@pytest.mark.skipif( + backend.backend() != "jax", + reason="Only JAX has the proper backend distribution lib", +) +class DataParallelDistributionTest(testing.TestCase): + def setUp(self): + super().setUp() + self.devices = [f"cpu:{i}" for i in range(8)] + shape = (8,) + axis_names = ["data"] + + self.device_mesh = distribution_lib.DeviceMesh( + shape, axis_names, self.devices + ) + + def test_create_with_device_mesh(self): + distribution = distribution_lib.DataParallel( + device_mesh=self.device_mesh + ) + + device_mesh = distribution.device_mesh + self.assertEqual(len(device_mesh.devices), 8) + self.assertEqual(device_mesh.axis_names, ["data"]) + self.assertEqual(distribution.batch_dim_name, "data") + + self.assertFalse(distribution._is_multi_process) + self.assertEqual(distribution._process_id, 0) + self.assertEqual(distribution._num_process, 1) + + def test_create_with_devices(self): + distribution = distribution_lib.DataParallel(devices=self.devices) + device_mesh = distribution.device_mesh + self.assertEqual(len(device_mesh.devices), 8) + self.assertEqual(device_mesh.axis_names, ["batch"]) + self.assertEqual(distribution.batch_dim_name, "batch") + + @mock.patch.object( + distribution_lib, + "list_devices", + return_value=[f"cpu:{i}" for i in range(8)], + ) + def test_create_with_list_devices(self, mock_list_devices): + distribution = distribution_lib.DataParallel() + mock_list_devices.assert_called_once() + + device_mesh = distribution.device_mesh + self.assertEqual(len(device_mesh.devices), 8) + self.assertEqual(device_mesh.axis_names, ["batch"]) + self.assertEqual(distribution.batch_dim_name, "batch") + + def test_get_data_layout(self): + distribution = distribution_lib.DataParallel( + device_mesh=self.device_mesh + ) + + data = np.arange(16).reshape((4, 2, 2)) + data_layout = distribution.get_data_layout(data.shape) + self.assertIs(data_layout.device_mesh, self.device_mesh) + self.assertEqual(data_layout.axes, ("data", None, None)) + + @pytest.mark.skipif(testing.jax_uses_gpu(), reason="CI segfault") + def test_get_variable_layout(self): + distribution = distribution_lib.DataParallel( + device_mesh=self.device_mesh + ) + + variable = backend.Variable(initializer=[1, 2, 3]) + variable_layout = distribution.get_variable_layout(variable) + self.assertIs(variable_layout.device_mesh, self.device_mesh) + self.assertEqual(variable_layout.axes, (None,)) + + @pytest.mark.skipif(testing.jax_uses_gpu(), reason="CI segfault") + def test_get_variable_layout_with_explicit_layout(self): + distribution = distribution_lib.DataParallel( + device_mesh=self.device_mesh + ) + + explicit_mesh = distribution_lib.DeviceMesh((8,), ["x"], self.devices) + explicit_layout = distribution_lib.TensorLayout(["x"], explicit_mesh) + + variable = backend.Variable(initializer=[1, 2, 3]) + variable._layout = explicit_layout + variable_layout = distribution.get_variable_layout(variable) + self.assertIs(variable_layout.device_mesh, explicit_mesh) + self.assertEqual(variable_layout.axes, explicit_layout.axes) + + def test_get_tensor_layout(self): + distribution = distribution_lib.DataParallel( + device_mesh=self.device_mesh + ) + + path = "path/to/tensor" + tensor_layout = distribution.get_tensor_layout(path) + self.assertIsNone(tensor_layout) + + def test_distribute_dataset(self): + # We can only verify the single worker/process case in OSS for now. + dataset = tf.data.Dataset.range(8) + distribution = distribution_lib.DataParallel( + device_mesh=self.device_mesh + ) + distributed_dataset = distribution.distribute_dataset(dataset) + self.assertIs(dataset, distributed_dataset) + + +@pytest.mark.skipif( + backend.backend() != "jax", + reason="Only JAX has the proper backend distribution lib", +) +class ModelParallelDistributionTest(testing.TestCase): + def setUp(self): + super().setUp() + self.devices = [f"cpu:{i}" for i in range(8)] + shape = (2, 4) + axis_names = ["data", "model"] + + self.device_mesh = distribution_lib.DeviceMesh( + shape, axis_names, self.devices + ) + + @pytest.mark.skipif(testing.jax_uses_gpu(), reason="CI segfault") + def test_distribute_weights(self): + layout_map = distribution_lib.LayoutMap(self.device_mesh) + layout_map[".*kernel"] = distribution_lib.TensorLayout([None, "model"]) + layout_map[".*bias"] = distribution_lib.TensorLayout(["model"]) + + distribution = distribution_lib.ModelParallel( + layout_map=layout_map, batch_dim_name="data" + ) + kernel = backend.Variable(initializer=np.arange(8, 4), name="kernel") + bias = backend.Variable(initializer=np.arange(4), name="bias") + rng_seed = backend.Variable(initializer=[0, 1], name="seed") + + kernel_layout = distribution.get_variable_layout(kernel) + self.assertIs(kernel_layout.device_mesh, self.device_mesh) + self.assertEqual(kernel_layout.axes, (None, "model")) + + bias_layout = distribution.get_variable_layout(bias) + self.assertIs(bias_layout.device_mesh, self.device_mesh) + self.assertEqual(bias_layout.axes, ("model",)) + + rng_seed_layout = distribution.get_variable_layout(rng_seed) + self.assertIs(rng_seed_layout.device_mesh, self.device_mesh) + self.assertEqual(rng_seed_layout.axes, (None,)) + + def test_distribute_data(self): + layout_map = distribution_lib.LayoutMap(self.device_mesh) + distribution = distribution_lib.ModelParallel( + layout_map=layout_map, batch_dim_name="data" + ) + + data = np.arange(16).reshape((4, 2, 2)) + data_layout = distribution.get_data_layout(data.shape) + self.assertIs(data_layout.device_mesh, self.device_mesh) + self.assertEqual(data_layout.axes, ("data", None, None)) + + def test_get_tensor_layout(self): + layout_map = distribution_lib.LayoutMap(self.device_mesh) + layout_map[".*kernel"] = distribution_lib.TensorLayout([None, "model"]) + layout_map[".*bias"] = distribution_lib.TensorLayout(["model"]) + layout_map["/model/layer/tensor"] = ("data", None) + + distribution = distribution_lib.ModelParallel( + layout_map=layout_map, batch_dim_name="data" + ) + layout = distribution.get_tensor_layout("/model/layer/tensor") + self.assertIs(layout.device_mesh, self.device_mesh) + self.assertEqual(layout.axes, ("data", None)) + + layout = distribution.get_tensor_layout("/model/layer/other_tensor") + self.assertIsNone(layout) + + @pytest.mark.skipif(testing.jax_uses_gpu(), reason="CI segfault") + def test_get_variable_layout_with_explicit_layout(self): + layout_map = distribution_lib.LayoutMap(self.device_mesh) + layout_map[".*kernel"] = distribution_lib.TensorLayout([None, "model"]) + distribution = distribution_lib.ModelParallel( + layout_map=layout_map, batch_dim_name="data" + ) + + explicit_mesh = distribution_lib.DeviceMesh((8,), ["x"], self.devices) + explicit_layout = distribution_lib.TensorLayout(["x"], explicit_mesh) + variable = backend.Variable(initializer=[1, 2, 3], name="kernel") + variable._layout = explicit_layout + variable_layout = distribution.get_variable_layout(variable) + self.assertIs(variable_layout.device_mesh, explicit_mesh) + self.assertEqual(variable_layout.axes, explicit_layout.axes) + + def test_distribute_dataset(self): + # We can only verify the single worker/process case in OSS for now. + dataset = tf.data.Dataset.range(8) + layout_map = distribution_lib.LayoutMap(self.device_mesh) + distribution = distribution_lib.ModelParallel( + layout_map=layout_map, batch_dim_name="data" + ) + distributed_dataset = distribution.distribute_dataset(dataset) + self.assertIs(dataset, distributed_dataset) + + +class LayoutMapTest(testing.TestCase): + def setUp(self): + super().setUp() + self.devices = [f"cpu:{i}" for i in range(8)] + shape = (4, 2) + axis_names = ["data", "model"] + + self.device_mesh = distribution_lib.DeviceMesh( + shape, axis_names, self.devices + ) + self.sharded_2d = distribution_lib.TensorLayout([None, "model"]) + self.sharded_1d = distribution_lib.TensorLayout(["model"]) + + self.replicated_2d = distribution_lib.TensorLayout([None, None]) + self.replicated_1d = distribution_lib.TensorLayout([None]) + + def test_add(self): + layout_map = distribution_lib.LayoutMap(self.device_mesh) + layout_map["dense/kernel"] = self.sharded_2d + layout_map["dense/bias"] = self.sharded_1d + # Test for adding list/tuple as shortcut for TensorLayout + layout_map["conv/bias"] = ("model",) + + # Make there are two items in the map, and we access them via the + # underlying container at layout_map._layout_map + self.assertLen(layout_map, 3) + + kernel_layout = layout_map["dense/kernel"] + self.assertEqual(kernel_layout.axes, (None, "model")) + self.assertIs(kernel_layout.device_mesh, self.device_mesh) + + bias_layout = layout_map["dense/bias"] + self.assertEqual(bias_layout.axes, ("model",)) + self.assertIs(bias_layout.device_mesh, self.device_mesh) + + conv_bias_layout = layout_map["conv/bias"] + self.assertEqual(conv_bias_layout.axes, ("model",)) + self.assertIs(bias_layout.device_mesh, self.device_mesh) + + with self.assertRaisesRegex(ValueError, "dense/kernel already exist"): + layout_map["dense/kernel"] = self.sharded_2d + + with self.assertRaisesRegex(ValueError, "should be a TensorLayout"): + layout_map["conv.kernel"] = ["a", "b"] + + def test_get(self): + layout_map = distribution_lib.LayoutMap(self.device_mesh) + layout_map["dense/kernel"] = self.sharded_2d + layout_map["dense/bias"] = self.sharded_1d + + layout_map["dense.*kernel"] = self.replicated_2d + layout_map["dense.*bias"] = self.replicated_1d + + layout_map["bias"] = self.sharded_1d + + self.assertEqual(layout_map["dense/kernel"], self.sharded_2d) + self.assertEqual(layout_map["dense/bias"], self.sharded_1d) + + self.assertEqual(layout_map["dense_2/kernel"], self.replicated_2d) + # Map against the wildcard bias rule for dense. This will cause a + # ValueError + with self.assertRaisesRegex( + ValueError, "Path 'dense_2/bias' matches multiple layout" + ): + layout_map["dense_2/bias"] + + self.assertIsNone(layout_map["conv2d/kernel"]) + self.assertEqual(layout_map["conv2d/bias"], self.sharded_1d) + + def test_delete(self): + layout_map = distribution_lib.LayoutMap(self.device_mesh) + + layout_map["dense/kernel"] = self.sharded_2d + layout_map["dense/bias"] = self.sharded_1d + + self.assertEqual(layout_map.pop("dense/kernel"), self.sharded_2d) + # Make sure to match against the exact string, not the regex + with self.assertRaises(KeyError): + layout_map.pop(".*bias") + + # Make sure del also works + del layout_map["dense/bias"] + + self.assertLen(layout_map, 0) + + def test_len(self): + layout_map = distribution_lib.LayoutMap(self.device_mesh) + self.assertLen(layout_map, 0) + + layout_map["dense/kernel"] = self.sharded_2d + layout_map["dense/bias"] = self.sharded_1d + + self.assertLen(layout_map, 2) + + def test_iter(self): + layout_map = distribution_lib.LayoutMap(self.device_mesh) + + layout_map["dense/kernel"] = self.sharded_2d + layout_map["dense/bias"] = self.sharded_1d + + # Make sure the items are ordered based on the insertion order. + self.assertEqual( + list(layout_map.keys()), ["dense/kernel", "dense/bias"] + ) + + keys = [] + values = [] + for k, v in layout_map.items(): + keys.append(k) + values.append(v) + + self.assertEqual(keys, ["dense/kernel", "dense/bias"]) + self.assertEqual(values, [self.sharded_2d, self.sharded_1d]) + + +# @pytest.mark.skipif( +# backend.backend() != "tensorflow", +# reason="Backend specific test", +# ) +# class TensorflowDistributionLibTest(testing.TestCase): +# def setUp(self): +# super().setUp() +# # Config virtual devices for testing. +# cpus = tf.config.list_physical_devices("cpu") +# context._reset_context() +# tf.config.set_logical_device_configuration( +# cpus[0], [tf.config.LogicalDeviceConfiguration()] * 8 +# ) + +# dtensor.initialize_accelerator_system("cpu") + +# def tearDown(self) -> None: +# super().tearDown() +# dtensor.shutdown_accelerator_system() + +# def test_list_devices(self): +# self.assertEqual(len(distribution_lib.list_devices()), 8) +# self.assertEqual(len(distribution_lib.list_devices("cpu")), 8) +# self.assertEqual(len(distribution_lib.list_devices("cpu")), 8) + +# def test_to_dtensor_mesh(self): +# devices = [f"cpu:{i}" for i in range(8)] +# shape = (4, 2) +# axis_names = ["batch", "model"] + +# mesh = distribution_lib.DeviceMesh(shape, axis_names, devices) +# dtensor_mesh = backend_dlib._to_dtensor_mesh(mesh) + +# self.assertIsInstance(dtensor_mesh, dtensor.Mesh) +# self.assertEqual(dtensor_mesh.shape(), list(shape)) +# self.assertEqual(dtensor_mesh.dim_names, axis_names) + +# def test_to_dtensor_layout(self): +# axes = ["data", None] +# mesh = distribution_lib.DeviceMesh( +# (4, 2), ["data", "model"], [f"cpu:{i}" for i in range(8)] +# ) +# layout = distribution_lib.TensorLayout(axes, mesh) +# dtensor_layout = backend_dlib._to_dtensor_layout(layout) +# dtensor_mesh = backend_dlib._to_dtensor_mesh(mesh) +# self.assertEqual( +# dtensor_layout, +# dtensor.Layout(["data", dtensor.UNSHARDED], dtensor_mesh), +# ) + +# def test_validation_for_device_mesh(self): +# axes = ["data", None] +# layout = distribution_lib.TensorLayout(axes, device_mesh=None) + +# with self.assertRaisesRegex( +# ValueError, "Cannot create sharding when device mesh is not set" +# ): +# backend_dlib._to_dtensor_layout(layout) diff --git a/keras/src/dtype_policies/__init__.py b/keras/src/dtype_policies/__init__.py new file mode 100644 index 000000000000..6bf0eb45bbb7 --- /dev/null +++ b/keras/src/dtype_policies/__init__.py @@ -0,0 +1,109 @@ +from keras.src import backend +from keras.src.api_export import keras_export +from keras.src.dtype_policies import dtype_policy +from keras.src.dtype_policies.dtype_policy import QUANTIZATION_MODES +from keras.src.dtype_policies.dtype_policy import DTypePolicy +from keras.src.dtype_policies.dtype_policy import FloatDTypePolicy +from keras.src.dtype_policies.dtype_policy import GPTQDTypePolicy +from keras.src.dtype_policies.dtype_policy import QuantizedDTypePolicy +from keras.src.dtype_policies.dtype_policy import QuantizedFloat8DTypePolicy +from keras.src.dtype_policies.dtype_policy_map import DTypePolicyMap + +ALL_OBJECTS = { + DTypePolicy, + FloatDTypePolicy, + QuantizedDTypePolicy, + QuantizedFloat8DTypePolicy, + DTypePolicyMap, + GPTQDTypePolicy, +} +ALL_OBJECTS_DICT = {cls.__name__: cls for cls in ALL_OBJECTS} + + +@keras_export("keras.dtype_policies.serialize") +def serialize(dtype_policy): + """Serializes `DTypePolicy` instance. + + Args: + dtype_policy: A Keras `DTypePolicy` instance. + + Returns: + `DTypePolicy` configuration dictionary. + """ + from keras.src.saving import serialization_lib + + return serialization_lib.serialize_keras_object(dtype_policy) + + +@keras_export("keras.dtype_policies.deserialize") +def deserialize(config, custom_objects=None): + """Deserializes a serialized `DTypePolicy` instance. + + Args: + config: `DTypePolicy` configuration. + custom_objects: Optional dictionary mapping names (strings) to custom + objects (classes and functions) to be considered during + deserialization. + + Returns: + A Keras `DTypePolicy` instance. + """ + from keras.src.saving import serialization_lib + + return serialization_lib.deserialize_keras_object( + config, + module_objects=ALL_OBJECTS_DICT, + custom_objects=custom_objects, + ) + + +@keras_export("keras.dtype_policies.get") +def get(identifier): + """Retrieves a Keras `DTypePolicy` instance. + + The `identifier` may be the string name of a `DTypePolicy` class. + + >>> policy = dtype_policies.get("mixed_bfloat16") + >>> type(policy) + + + You can also specify `config` of the dtype policy to this function by + passing dict containing `class_name` and `config` as an identifier. Also + note that the `class_name` must map to a `DTypePolicy` class + + >>> identifier = {"class_name": "DTypePolicy", + ... "config": {"name": "float32"}} + >>> policy = dtype_policies.get(identifier) + >>> type(policy) + + + Args: + identifier: A dtype policy identifier. One of `None` or string name of a + `DTypePolicy` or `DTypePolicy` configuration dictionary or a + `DTypePolicy` instance. + + Returns: + A Keras `DTypePolicy` instance. + """ + from keras.src.dtype_policies.dtype_policy import ( + _get_quantized_dtype_policy_by_str, + ) + + if identifier is None: + return dtype_policy.dtype_policy() + if isinstance(identifier, DTypePolicy): + return identifier + if isinstance(identifier, dict): + return deserialize(identifier) + if isinstance(identifier, str): + if identifier.startswith(QUANTIZATION_MODES): + return _get_quantized_dtype_policy_by_str(identifier) + else: + return DTypePolicy(identifier) + try: + return DTypePolicy(backend.standardize_dtype(identifier)) + except: + raise ValueError( + "Cannot interpret `dtype` argument. Expected a string " + f"or an instance of DTypePolicy. Received: dtype={identifier}" + ) diff --git a/keras/src/dtype_policies/dtype_policy.py b/keras/src/dtype_policies/dtype_policy.py new file mode 100644 index 000000000000..0e5f8bb4f6fb --- /dev/null +++ b/keras/src/dtype_policies/dtype_policy.py @@ -0,0 +1,448 @@ +from keras.src import backend +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.backend.common import global_state + +QUANTIZATION_MODES = ("int8", "float8", "int4", "gptq") + + +@keras_export( + [ + "keras.DTypePolicy", + "keras.dtype_policies.DTypePolicy", + "keras.mixed_precision.DTypePolicy", # Legacy + "keras.mixed_precision.Policy", # Legacy + ] +) +class DTypePolicy: + """A dtype policy for a Keras layer. + + A dtype policy determines a layer's computation and variable dtypes. Each + layer has a policy. Policies can be passed to the `dtype` argument of layer + constructors, or a global policy can be set with + `keras.config.set_dtype_policy`. + + Args: + name: The policy name, which determines the compute and variable dtypes. + Can be any dtype name, such as `"float32"` or `"float64"`, + which causes both the compute and variable dtypes + will be that dtype. + Can also be the string `"mixed_float16"` or `"mixed_bfloat16"`, + which causes the compute dtype to be `float16` or `bfloat16` + and the variable dtype to be `float32`. + + Typically you only need to interact with dtype policies when using mixed + precision, which is the use of float16 or bfloat16 for computations and + float32 for variables. This is why the term `mixed_precision` appears in the + API name. Mixed precision can be enabled by passing `"mixed_float16"` or + `"mixed_bfloat16"` to `keras.mixed_precision.set_dtype_policy()`. + + >>> keras.config.set_dtype_policy("mixed_float16") + >>> layer1 = keras.layers.Dense(10) + >>> layer1.dtype_policy # layer1 will automatically use mixed precision + + >>> # Can optionally override layer to use float32 + >>> # instead of mixed precision. + >>> layer2 = keras.layers.Dense(10, dtype="float32") + >>> layer2.dtype_policy + + >>> # Set policy back to initial float32. + >>> keras.config.set_dtype_policy('float32') + + In the example above, passing `dtype="float32"` to the layer is + equivalent to passing + `dtype=keras.config.DTypePolicy("float32")`. + In general, passing a dtype policy name to a layer is equivalent + to passing the corresponding policy, so it is never necessary + to explicitly construct a `DTypePolicy` object. + """ + + def __init__(self, name=None): + # Use the global dtype policy if `name` is not specified + if name is None: + name = dtype_policy().name + self._name = name + self._compute_dtype, self._variable_dtype = self._parse_name(name) + self._quantization_mode = None + + def _parse_name(self, name): + """Parses a `DTypePolicy` name into a compute and variable dtype. + + Args: + name: The name of the policy. + + Returns: + The `(compute_dtype, variable_dtype)` pair. + """ + if not isinstance(name, str): + raise TypeError( + "'name' must be a string, such as 'mixed_float16'. " + f"Received: name={name} (of type {type(name)})" + ) + if name == "mixed_float16": + return "float16", "float32" + elif name == "mixed_bfloat16": + return "bfloat16", "float32" + try: + dtype = backend.standardize_dtype(name) + return dtype, dtype + except ValueError: + raise ValueError( + f"Cannot convert '{name}' to a mixed precision " + "DTypePolicy. Valid policies include 'mixed_float16', " + "'mixed_bfloat16', and the name of any float dtype such as " + "'float32'." + ) + + @property + def variable_dtype(self): + """The variable dtype of this policy. + + This is the dtype layers will create their variables in, unless a layer + explicitly chooses a different dtype. If this is different than + `DTypePolicy.compute_dtype`, Layers will cast variables to + the compute dtype to avoid type errors. + + Variable regularizers are run in the variable dtype, not the compute + dtype. + + Returns: + The variable dtype of this policy, as a string. + """ + return self._variable_dtype + + @property + def compute_dtype(self): + """The compute dtype of this policy. + + This is the dtype layers will do their computations in. Typically layers + output tensors with the compute dtype as well. + + Note that even if the compute dtype is float16 or bfloat16, hardware + devices may not do individual adds, multiplies, and other fundamental + operations in float16 or bfloat16, but instead may do some of them in + float32 for numeric stability. The compute dtype is the dtype of the + inputs and outputs of the ops that the layer executes. + Internally, many ops will do certain internal calculations in + float32 or some other device-internal intermediate format with higher + precision than float16/bfloat16, to increase numeric stability. + + Returns: + The compute dtype of this policy, as a string. + """ + return self._compute_dtype + + @property + def name(self): + """Returns the name of this policy.""" + return self._name + + @property + def quantization_mode(self): + """The quantization mode of this policy. + + Returns: + The quantization mode of this policy, as a string. If this policy is + not quantized, it will return `None`. + """ + return self._quantization_mode + + def convert_input(self, x, autocast, dtype): + """Converts the input dtype based on `autocast` and `dtype`. + + Note that `x` can be a tensor, symbolic tensor or numpy array, and this + method will keep integer inputs untouched and only apply casting to + floats. + """ + + dtype = backend.standardize_dtype(dtype) + if backend.is_tensor(x): + if self._should_cast(x, autocast, dtype): + x = backend.cast(x, dtype=dtype) + return x + elif backend.is_keras_tensor(x): + if self._should_cast(x, autocast, dtype): + x = ops.cast(x, dtype=dtype) + return x + elif hasattr(x, "__array__"): + try: + x = backend.convert_to_tensor(x) + except TypeError: + x = backend.convert_to_tensor(x, dtype=dtype) + if self._should_cast(x, autocast, dtype): + x = backend.cast(x, dtype=dtype) + return x + return x + + def get_config(self): + return {"name": self.name} + + @classmethod + def from_config(cls, config): + return cls(**config) + + def __repr__(self): + class_name = self.__class__.__name__ + if class_name == "FloatDTypePolicy": + class_name = "DTypePolicy" + return f'<{class_name} "{self._name}">' + + def __eq__(self, other): + if self.__class__ in (DTypePolicy, FloatDTypePolicy): + if type(other) not in (DTypePolicy, FloatDTypePolicy): + return False + else: + if type(other) is not self.__class__: + return False + return self._name == other._name + + def _should_cast(self, x, autocast, dtype): + x_dtype = backend.standardize_dtype(x.dtype) + if autocast and backend.is_float_dtype(x_dtype) and x_dtype != dtype: + return True + else: + return False + + +@keras_export( + ["keras.FloatDTypePolicy", "keras.dtype_policies.FloatDTypePolicy"] +) +class FloatDTypePolicy(DTypePolicy): + # An alias for `DTypePolicy` + pass + + +@keras_export("keras.dtype_policies.QuantizedDTypePolicy") +class QuantizedDTypePolicy(DTypePolicy): + def __init__(self, mode, source_name=None): + # Use the global dtype policy if `source_name` is not specified + if source_name is None: + source_name = dtype_policy().name + name = f"{mode}_from_{source_name}" + self._compute_dtype, self._variable_dtype = self._parse_name( + source_name + ) + self._check_quantization_mode(mode, self._compute_dtype) + + self._name = name + self._source_name = source_name + self._quantization_mode = mode + + def __eq__(self, other): + if super().__eq__(other) is False: + return False + return ( + self._quantization_mode == other._quantization_mode + and self._source_name == other._source_name + ) + + def get_config(self): + return { + "mode": self._quantization_mode, + "source_name": self._source_name, + } + + def _check_quantization_mode(self, mode, compute_dtype): + if mode not in QUANTIZATION_MODES: + raise ValueError( + "Invalid quantization mode. " + f"Expected one of {QUANTIZATION_MODES}. " + f"Received: mode={mode}" + ) + if compute_dtype == "float16" and mode == "int8": + raise ValueError( + f"Quantization mode='{mode}' doesn't work well with " + "compute_dtype='float16'." + ) + + +@keras_export("keras.dtype_policies.QuantizedFloat8DTypePolicy") +class QuantizedFloat8DTypePolicy(QuantizedDTypePolicy): + default_amax_history_length = 1024 + + def __init__(self, mode, source_name=None, amax_history_length=1024): + super().__init__(mode=mode, source_name=source_name) + if not isinstance(amax_history_length, int): + raise TypeError( + "`amax_history_length` must be an integer. " + f"Received: amax_history_length={amax_history_length}" + ) + self._amax_history_length = amax_history_length + + @property + def amax_history_length(self): + """The length of the amax history window. + + This property is used for scaling factor computation in float8 training. + """ + return self._amax_history_length + + def __eq__(self, other): + if super().__eq__(other) is False: + return False + return self._amax_history_length == other._amax_history_length + + def get_config(self): + config = super().get_config() + config.update({"amax_history_length": self.amax_history_length}) + return config + + +@keras_export("keras.dtype_policies.GPTQDTypePolicy") +class GPTQDTypePolicy(QuantizedDTypePolicy): + """Quantized dtype policy for GPTQ quantization. + + This policy helps propagate quantization settings for GPTQ + when loading a GPTQ quantized model in Keras format. + + Args: + mode: The quantization mode. This should be a string in the format + `"gptq//"`. + - `"gptq"`: The identifier for the quantization algorithm. + - ``: Number of bits to quantize weights to. + Supported values are 2, 3, 4, and 8. + - ``: The group size for quantization. Supported + values are -1 (for whole-tensor quantization) or any + positive integer. Typically a smaller group size leads + to better accuracy but slower speed. + Example: `"gptq/4/128"`. + source_name: The source dtype policy name, e.g. "float32". + """ + + def __init__( + self, + mode, + source_name=None, + ): + parts = mode.split("/") + expected_format = "'gptq//'" + + # Validate format + if len(parts) != 3 or parts[0] != "gptq": + raise ValueError( + "Invalid mode for GPTQDTypePolicy. Expected format " + f"{expected_format}, but got '{mode}'." + ) + + # Validate and cast weight_bits and group_size + try: + weight_bits = int(parts[1]) + group_size = int(parts[2]) + except ValueError: + raise ValueError( + "Invalid mode for GPTQDTypePolicy. and " + " must be integers. Expected format " + f"{expected_format}, but got '{mode}'." + ) + + # Validate supported values + if weight_bits not in [2, 3, 4, 8]: + raise ValueError( + "Invalid weight_bits in mode. Supported values are " + f"2, 3, 4, and 8, but got {weight_bits} from '{mode}'." + ) + + if group_size < -1 or group_size == 0: + raise ValueError( + "Invalid group_size in mode. Supported values are " + "-1 (whole-tensor) or a positive integer, " + f"but got {group_size} from '{mode}'." + ) + + base_mode = parts[0] + super().__init__( + mode=base_mode, + source_name=source_name, + ) + + self._name = f"{mode}_from_{source_name}" + self.mode = base_mode + self.weight_bits = weight_bits + self.group_size = group_size + + def __eq__(self, other): + if super().__eq__(other) is False: + return False + return ( + self.weight_bits == other.weight_bits + and self.group_size == other.group_size + ) + + def get_config(self): + config = super().get_config() + # Reconstruct the full mode string for serialization + mode = f"{self.mode}/{self.weight_bits}/{self.group_size}" + config.update({"mode": mode}) + return config + + +@keras_export( + [ + "keras.config.set_dtype_policy", + "keras.mixed_precision.set_dtype_policy", # Legacy + "keras.mixed_precision.set_global_policy", # Legacy + ] +) +def set_dtype_policy(policy): + """Sets the default dtype policy globally. + + Example: + + >>> keras.config.set_dtype_policy("mixed_float16") + """ + if not isinstance(policy, DTypePolicy): + if isinstance(policy, str): + if policy.startswith(QUANTIZATION_MODES): + policy = _get_quantized_dtype_policy_by_str(policy) + else: + policy = DTypePolicy(policy) + else: + raise ValueError( + "Invalid `policy` argument. " + "Expected the string name of a policy " + "(such as 'mixed_float16') or a `DTypePolicy` " + f"instance. Received: policy={policy} " + f"(of type {type(policy)})" + ) + global_state.set_global_attribute("dtype_policy", policy) + + +@keras_export( + [ + "keras.config.dtype_policy", + "keras.mixed_precision.dtype_policy", # Legacy + "keras.mixed_precision.global_policy", # Legacy + ] +) +def dtype_policy(): + """Returns the current default dtype policy object.""" + policy = global_state.get_global_attribute("dtype_policy", None) + if policy is None: + policy = DTypePolicy(backend.floatx()) + set_dtype_policy(policy) + return policy + + +def _get_quantized_dtype_policy_by_str(policy): + if not isinstance(policy, str): + raise TypeError(f"`policy` must be a string. Received: policy={policy}") + if not policy.startswith(QUANTIZATION_MODES): + raise ValueError( + "`policy` is incompatible with the current supported quantization." + ) + split_name = policy.split("_from_") + if len(split_name) != 2: + raise ValueError( + "Cannot convert `policy` into a valid pair (`mode`, `source_name`) " + "to instantiate `QuantizedDTypePolicy`. " + f"Received: policy={policy}" + ) + mode, source_name = split_name + if policy.startswith("int8") or policy.startswith("int4"): + return QuantizedDTypePolicy(mode, source_name) + elif policy.startswith("gptq"): + return GPTQDTypePolicy(mode, source_name) + elif policy.startswith("float8"): + return QuantizedFloat8DTypePolicy(mode, source_name) + else: + raise NotImplementedError diff --git a/keras/src/dtype_policies/dtype_policy_map.py b/keras/src/dtype_policies/dtype_policy_map.py new file mode 100644 index 000000000000..d6dc7617b7f9 --- /dev/null +++ b/keras/src/dtype_policies/dtype_policy_map.py @@ -0,0 +1,294 @@ +import re +from collections.abc import MutableMapping + +from keras.src import dtype_policies +from keras.src.api_export import keras_export +from keras.src.dtype_policies import DTypePolicy + + +@keras_export(["keras.dtype_policies.DTypePolicyMap"]) +class DTypePolicyMap(DTypePolicy, MutableMapping): + """Dict-like object mapping layer paths to `DTypePolicy` instances. + + `DTypePolicyMap` can be used in `get_config` in layers and subclasses to + support a complex configurations of dtype policies. + + For example, we can modify `get_config` in `layers.MultiHeadAttention` as + follows to support the mixing of dtype policies, such as quantization. + + ```python + @keras.saving.register_keras_serializable("MyPackage") + class MyMultiHeadAttention(keras.layers.MultiHeadAttention): + def get_config(self): + config = super().get_config() + dtype_policy_map = dtype_policies.DTypePolicyMap() + for layer in self._flatten_layers(): + if layer.dtype_policy.quantization_mode is not None: + dtype_policy_map[layer.path] = layer.dtype_policy + if len(dtype_policy_map) > 0: + config.update({"dtype": dtype_policy_map}) + return config + ``` + + Internally, `DTypePolicyMap` uses a string as a key and a `DTypePolicy` + as the value. Typically, the key used for querying is the `Layer.path`. + However, it is also possible to set a regex as the key. See the docstring of + `get` for more details. + + Args: + default_policy: An optional `DTypePolicy` instance specifying the + default dtype policy. If not specified, the value will default to + `keras.config.dtype_policy()`. + policy_map: An optional dict that maps string to `DTypePolicy` + instances. Defaults to `None` + + Example: + + ```python + >>> from keras.src import dtype_policies + >>> bfloat16 = dtype_policies.DTypePolicy("bfloat16") + >>> float16 = dtype_policies.DTypePolicy("float16") + >>> float32 = dtype_policies.DTypePolicy("float32") + >>> policy_map = DTypePolicyMap(default_policy=float32) + + # Set policies using an exact path and a regex pattern. + # Note: "decoder" will only match the exact path, not its children. + >>> policy_map["encoder/layer_0/dense"] = bfloat16 + >>> policy_map["encoder/.*"] = float16 + >>> policy_map["decoder"] = bfloat16 + + # 1. An exact match is found and returned directly. + >>> policy_map["encoder/layer_0/dense"].name + 'bfloat16' + + # 2. A regex match is found for a child layer. + # It matches the "encoder/.*" pattern. + >>> policy_map["encoder/attention/query"].name + 'float16' + + # 3. No implicit prefix matching occurs. + # "decoder/attention" does not match the key "decoder". + # The default policy is returned. + >>> policy_map["decoder/attention"].name + 'float32' + + # 4. A ValueError is raised if a path matches multiple patterns. + >>> policy_map["encoder/attention/.*"] = bfloat16 + # "encoder/attention/query" now matches two patterns: + # - "encoder/.*" + # - "encoder/attention/.*" + >>> try: + ... policy_map["encoder/attention/query"] + ... except ValueError as e: + ... print(e) + Path 'encoder/attention/query' matches multiple dtype policy .. + ``` + """ + + def __init__(self, default_policy=None, policy_map=None): + if isinstance(default_policy, DTypePolicyMap): + raise ValueError("`default_policy` cannot be a `DTypePolicyMap`.") + if policy_map is not None and not isinstance(policy_map, dict): + raise TypeError( + "If specified, `policy_map` must be a dict. " + f"Received: policy_map={policy_map} of type {type(policy_map)}" + ) + self._default_policy_arg = default_policy + self._default_policy = dtype_policies.get(default_policy) + self._policy_map = policy_map or dict() + + @property + def name(self): + return f"map_{self.default_policy._name}" + + @property + def default_policy(self): + """The default dtype policy. + + If `default_policy` is not specified in the constructor, this property + will be `keras.config.dtype_policy()`. + """ + return dtype_policies.get(self._default_policy) + + @property + def variable_dtype(self): + return self.default_policy.variable_dtype + + @property + def compute_dtype(self): + return self.default_policy.compute_dtype + + @property + def quantization_mode(self): + return self.default_policy.quantization_mode + + def __getitem__(self, key): + """Retrieves the corresponding `DTypePolicy` by the string key. + + This method first attempts an exact key match. If no exact match is + found, it treats all keys in the map as regular expression patterns + and uses `re.fullmatch` to find a policy. + + For example, to apply a policy to all sublayers of an `encoder` block, + the key should be explicitly set to `"encoder/.*"`. A key of + `"encoder"` will only match the layer with that exact path. + + Args: + key: str. The key to query for a `DTypePolicy`. + + Returns: + The corresponding `DTypePolicy`. If no match is found, this method + returns `self.default_policy`. + + Raises: + ValueError: If the `key` matches more than one regex pattern in the + map. + + Example: + + ```python + >>> from keras.src import dtype_policies + >>> bfloat16 = dtype_policies.DTypePolicy("bfloat16") + >>> float16 = dtype_policies.DTypePolicy("float16") + >>> float32 = dtype_policies.DTypePolicy("float32") + >>> policy_map = DTypePolicyMap(default_policy=float32) + + # Set policies using an exact path and a regex pattern. + # Note: "decoder" will only match the exact path, not its children. + >>> policy_map["encoder/layer_0/dense"] = bfloat16 + >>> policy_map["encoder/.*"] = float16 + >>> policy_map["decoder"] = bfloat16 + + # 1. An exact match is found and returned directly. + >>> policy_map["encoder/layer_0/dense"].name + 'bfloat16' + + # 2. A regex match is found for a child layer. + # It matches the "encoder/.*" pattern. + >>> policy_map["encoder/attention/query"].name + 'float16' + + # 3. No implicit prefix matching occurs. + # "decoder/attention" does not match the key "decoder". + # The default policy is returned. + >>> policy_map["decoder/attention"].name + 'float32' + + # 4. A ValueError is raised if a path matches multiple patterns. + >>> policy_map["encoder/attention/.*"] = bfloat16 + # "encoder/attention/query" now matches two patterns: + # - "encoder/.*" + # - "encoder/attention/.*" + >>> try: + ... policy_map["encoder/attention/query"] + ... except ValueError as e: + ... print(e) + Path 'encoder/attention/query' matches multiple dtype policy .. + ``` + """ + # 1. Check for an exact match. + if key in self._policy_map: + return self._policy_map[key] + + # 2. Fallback to a full regex match. + matching_keys = [ + pattern + for pattern in self._policy_map + if re.fullmatch(pattern, key) + ] + + # 3. Handle cases based on the number of matches found. + if len(matching_keys) > 1: + raise ValueError( + f"Path '{key}' matches multiple dtype policy " + f"specification keys: {matching_keys}. Please make " + "sure each path only matches at most " + "one dtype policy specification key in the DTypePolicyMap." + ) + elif len(matching_keys) == 1: + return self._policy_map[matching_keys[0]] + + # 4. If there were no matches, return the default. + return self.default_policy + + def __setitem__(self, key, policy): + """Insert `DTypePolicy` to the `DTypePolicyMap`. + + Args: + key: String key for the `DTypePolicy`. + policy: The `DTypePolicy`. + """ + if key in self._policy_map: + raise ValueError( + f"{key} already exist in the DTypePolicyMap with " + f"value {self._policy_map[key]}. Please make sure to " + "not use duplicated keys." + ) + try: + policy = dtype_policies.get(policy) + except Exception: + raise ValueError( + "Cannot interpret the assigned value by " + "`keras.dtype_policies.get`. " + f"Received: {policy} of type {type(policy)}" + ) + self._policy_map[key] = policy + + def __delitem__(self, key): + # Let the dict to handle the key missing error + return self._policy_map.pop(key) + + def __contains__(self, key): + return key in self._policy_map + + def get_config(self): + from keras.src.saving import serialization_lib + + policy_map = self._policy_map + if self._default_policy_arg is None: + # `default_policy=None` enables us to defer to + # `keras.config.dtype_policy()` during loading. + # To support this feature, we can set `_name` and `_source_name` to + # `None` in `DTypePolicy` and `QuantizedDTypePolicy`, + # respectively. + for policy in policy_map.values(): + if isinstance(policy, dtype_policies.QuantizedDTypePolicy): + policy._name = None + policy._source_name = None + elif isinstance(policy, dtype_policies.DTypePolicy): + policy._name = None + return { + "default_policy": self._default_policy_arg, + "policy_map": serialization_lib.serialize_keras_object(policy_map), + } + + @classmethod + def from_config(cls, config, custom_objects=None): + from keras.src.saving import serialization_lib + + config = config.copy() + config["policy_map"] = serialization_lib.deserialize_keras_object( + config["policy_map"], custom_objects=custom_objects + ) + return cls(**config) + + def __len__(self): + return len(self._policy_map) + + def __iter__(self): + return iter(self._policy_map) + + def __repr__(self): + default_policy = ( + self._default_policy.name + if self._default_policy is not None + else None + ) + mapping = [] + for k, v in self._policy_map.items(): + mapping.append((k, v.name)) + return ( + f"" + ) diff --git a/keras/src/dtype_policies/dtype_policy_map_test.py b/keras/src/dtype_policies/dtype_policy_map_test.py new file mode 100644 index 000000000000..a0e6673cd695 --- /dev/null +++ b/keras/src/dtype_policies/dtype_policy_map_test.py @@ -0,0 +1,361 @@ +import numpy as np +import pytest + +from keras.src import dtype_policies +from keras.src import layers +from keras.src import models +from keras.src import saving +from keras.src import testing +from keras.src.dtype_policies.dtype_policy import dtype_policy +from keras.src.dtype_policies.dtype_policy import set_dtype_policy +from keras.src.dtype_policies.dtype_policy_map import DTypePolicyMap + + +@pytest.mark.skipif(testing.jax_uses_gpu(), reason="Leads to core dumps on CI") +class DTypePolicyMapTest(testing.TestCase): + def setUp(self): + super().setUp() + self._global_dtype_policy = dtype_policy() + + def tearDown(self): + super().tearDown() + set_dtype_policy(self._global_dtype_policy) + + @pytest.mark.requires_trainable_backend + def test_basic_usage(self): + # Create a subclass that might contain mixing dtype policies for + # sublayers. + # It is important to ensure that `dtype` is passed to sublayers and + # that each sublayer has a unique `name`. + @saving.register_keras_serializable() + class Subclass(layers.Layer): + def __init__(self, dtype=None, name="subclass", **kwargs): + super().__init__(dtype=dtype, name=name, **kwargs) + self.dense = layers.Dense(8, dtype=dtype, name=f"{name}_dense") + self.bn = layers.BatchNormalization( + dtype=dtype, name=f"{name}_bn" + ) + self.relu = layers.ReLU(dtype=dtype, name=f"{name}_relu") + + def call(self, inputs, training=None): + return self.relu(self.bn(self.dense(inputs), training=training)) + + def get_config(self): + # Typically, we only need to record the quantized policy for + # `DTypePolicyMap` + config = super().get_config() + dtype_policy_map = DTypePolicyMap() + for layer in self._flatten_layers(): + if layer.quantization_mode is not None: + dtype_policy_map[layer.path] = layer.dtype_policy + if len(dtype_policy_map) > 0: + config.update({"dtype": dtype_policy_map}) + return config + + # Instantiate the model + inputs = layers.Input([4]) + outputs = Subclass()(inputs) + model = models.Model(inputs, outputs) + + # Quantize the model to make mixing of dtype policies in sublayers + model.quantize("int8") + for layer in model._flatten_layers(): + if isinstance(layer, layers.Dense): + self.assertEqual( + layer.dtype_policy, + dtype_policies.QuantizedDTypePolicy("int8"), + ) + elif isinstance(layer, layers.BatchNormalization): + self.assertEqual( + layer.dtype_policy, dtype_policies.DTypePolicy() + ) + elif isinstance(layer, layers.ReLU): + self.assertEqual( + layer.dtype_policy, dtype_policies.DTypePolicy() + ) + + # Verify the output after saving and loading + x = np.random.uniform(size=[16, 4]) + temp_dir = self.get_temp_dir() + y = model(x, training=False) + model.save(f"{temp_dir}/model.keras") + reloaded_model = saving.load_model(f"{temp_dir}/model.keras") + reloaded_y = reloaded_model(x, training=False) + self.assertAllClose(y, reloaded_y) + + def test_add(self): + dtype_policy_map = DTypePolicyMap() + dtype_policy_map["layer/dense_0"] = dtype_policies.DTypePolicy( + "bfloat16" + ) + dtype_policy_map["layer/dense_1"] = dtype_policies.QuantizedDTypePolicy( + "int8", "mixed_bfloat16" + ) + dtype_policy_map["layer/dense_2"] = ( + dtype_policies.QuantizedFloat8DTypePolicy("float8", "mixed_float16") + ) + + self.assertLen(dtype_policy_map, 3) + + policy = dtype_policy_map["layer/dense_0"] + self.assertIsInstance(policy, dtype_policies.DTypePolicy) + self.assertEqual(policy.name, "bfloat16") + + policy = dtype_policy_map["layer/dense_1"] + self.assertIsInstance(policy, dtype_policies.QuantizedDTypePolicy) + self.assertEqual(policy._source_name, "mixed_bfloat16") + self.assertEqual(policy.quantization_mode, "int8") + + policy = dtype_policy_map["layer/dense_2"] + self.assertIsInstance(policy, dtype_policies.QuantizedFloat8DTypePolicy) + self.assertEqual(policy._source_name, "mixed_float16") + self.assertEqual(policy.quantization_mode, "float8") + + with self.assertRaisesRegex( + ValueError, "layer/dense_0 already exist in the DTypePolicyMap" + ): + dtype_policy_map["layer/dense_0"] = dtype_policies.DTypePolicy( + "float32" + ) + + with self.assertRaisesRegex( + ValueError, "Cannot interpret the assigned value." + ): + dtype_policy_map["layer/dense_3"] = 123 + + def test_get(self): + # 1. Setup + bfloat16_policy = dtype_policies.DTypePolicy("bfloat16") + int8_policy = dtype_policies.QuantizedDTypePolicy( + "int8", "mixed_bfloat16" + ) + float32_policy = dtype_policies.DTypePolicy("float32") + float16_policy = dtype_policies.DTypePolicy("float16") + + policy_map = DTypePolicyMap() + # Policy for an exact layer path + policy_map["model/encoder/layer_0/dense"] = bfloat16_policy + # Policy for a layer that is also a prefix of another layer's name + policy_map["model/encoder/attention/query"] = int8_policy + # Regex policies for entire scopes MUST include wildcards + policy_map["model/decoder/.*"] = float32_policy + policy_map["model/decoder/attention/.*"] = float16_policy + + # 2. Test exact match + self.assertEqual( + policy_map["model/encoder/layer_0/dense"], bfloat16_policy + ) + self.assertEqual( + policy_map["model/encoder/attention/query"], int8_policy + ) + + # 3. Test successful regex fallback (explicit wildcard) + # "model/decoder/.*" should match its children. + self.assertEqual(policy_map["model/decoder/layer_0"], float32_policy) + + # 4. Test that partial matches are ignored + # The exact key "model/encoder/attention/query" should not match + # "model/encoder/attention/query_norm" without a wildcard. + self.assertEqual( + policy_map["model/encoder/attention/query_norm"], + policy_map.default_policy, + ) + # A plain key "model/decoder" will not match "model/decoder/layer_0" + policy_map["model/decoder"] = bfloat16_policy # Add exact key + self.assertEqual(policy_map["model/decoder/layer_0"], float32_policy) + # Still matches the more general regex + self.assertEqual(policy_map["model/decoder"], bfloat16_policy) + + # 5. Test no match + self.assertEqual( + policy_map["model/embedding"], policy_map.default_policy + ) + + # 6. Test multiple regex matches causing a ValueError + # "model/decoder/attention/output" matches two regex patterns: + # - "model/decoder/.*" + # - "model/decoder/attention/.*" + with self.assertRaisesRegex( + ValueError, + "Path 'model/decoder/attention/output' matches multiple " + "dtype policy", + ): + _ = policy_map["model/decoder/attention/output"] + + def test_delete(self): + dtype_policy_map = DTypePolicyMap() + dtype_policy_map["layer/dense_0"] = dtype_policies.DTypePolicy( + "bfloat16" + ) + dtype_policy_map["layer/dense_1"] = dtype_policies.QuantizedDTypePolicy( + "int8", "mixed_bfloat16" + ) + + self.assertEqual( + dtype_policy_map.pop("layer/dense_0"), + dtype_policies.DTypePolicy("bfloat16"), + ) + with self.assertRaises(KeyError): + dtype_policy_map.pop("layer/dense_0") + + # Test `del`, causing no hit + del dtype_policy_map["layer/dense_1"] + self.assertEqual( + dtype_policy_map["layer/dense_1"], dtype_policy_map.default_policy + ) + + self.assertLen(dtype_policy_map, 0) + + def test_len(self): + dtype_policy_map = DTypePolicyMap() + self.assertLen(dtype_policy_map, 0) + + dtype_policy_map["layer/dense_0"] = dtype_policies.DTypePolicy( + "bfloat16" + ) + dtype_policy_map["layer/dense_1"] = dtype_policies.QuantizedDTypePolicy( + "int8", "mixed_bfloat16" + ) + self.assertLen(dtype_policy_map, 2) + + def test_iter(self): + dtype_policy_map = DTypePolicyMap() + dtype_policy_map["layer/dense_0"] = dtype_policies.DTypePolicy( + "bfloat16" + ) + dtype_policy_map["layer/dense_1"] = dtype_policies.QuantizedDTypePolicy( + "int8", "mixed_bfloat16" + ) + + self.assertEqual( + list(dtype_policy_map.keys()), ["layer/dense_0", "layer/dense_1"] + ) + + keys = [] + values = [] + for k, v in dtype_policy_map.items(): + keys.append(k) + values.append(v) + self.assertEqual(keys, ["layer/dense_0", "layer/dense_1"]) + self.assertEqual( + values, + [ + dtype_policies.DTypePolicy("bfloat16"), + dtype_policies.QuantizedDTypePolicy("int8", "mixed_bfloat16"), + ], + ) + + def test_in(self): + dtype_policy_map = DTypePolicyMap() + dtype_policy_map["layer/dense_0"] = dtype_policies.DTypePolicy( + "bfloat16" + ) + dtype_policy_map["layer/dense_1"] = dtype_policies.QuantizedDTypePolicy( + "int8", "mixed_bfloat16" + ) + + self.assertTrue("layer/dense_0" in dtype_policy_map) + self.assertTrue("layer/dense_1" in dtype_policy_map) + self.assertFalse("layer/dense_2" in dtype_policy_map) + + def test_default_policy(self): + # Test default_policy is set to `"float32"` + dtype_policy_map = DTypePolicyMap(default_policy="mixed_bfloat16") + dtype_policy_map["layer/dense_0"] = dtype_policies.DTypePolicy( + "mixed_bfloat16" + ) + dtype_policy_map["layer/dense_1"] = dtype_policies.QuantizedDTypePolicy( + "int8", "mixed_bfloat16" + ) + config = dtype_policy_map.get_config() + dtype_policy_map = DTypePolicyMap.from_config(config) + self.assertEqual( + dtype_policy_map["layer/dense_0"], + dtype_policies.DTypePolicy("mixed_bfloat16"), + ) + self.assertEqual( + dtype_policy_map["layer/dense_1"], + dtype_policies.QuantizedDTypePolicy("int8", "mixed_bfloat16"), + ) + # No hit, defers to `dtype_policy_map.default_policy` + self.assertEqual( + dtype_policy_map["layer/dense_2"], dtype_policy_map.default_policy + ) + + # Test that default_policy defers to `keras.config.dtype_policy()` + # during loading + set_dtype_policy("bfloat16") + dtype_policy_map = DTypePolicyMap() + dtype_policy_map["layer/dense_0"] = dtype_policies.DTypePolicy( + "mixed_bfloat16" + ) + dtype_policy_map["layer/dense_1"] = dtype_policies.QuantizedDTypePolicy( + "int8", "mixed_bfloat16" + ) + config = dtype_policy_map.get_config() + dtype_policy_map = DTypePolicyMap.from_config(config) + self.assertEqual( + dtype_policy_map["layer/dense_0"], + dtype_policies.DTypePolicy("bfloat16"), + ) + self.assertEqual( + dtype_policy_map["layer/dense_1"], + dtype_policies.QuantizedDTypePolicy("int8", "bfloat16"), + ) + # No hit, defers to `dtype_policy_map.default_policy` which is + # `keras.config.dtype_policy()` + self.assertEqual( + dtype_policy_map["layer/dense_2"], dtype_policy_map.default_policy + ) + self.assertEqual( + dtype_policy_map["layer/dense_2"], dtype_policies.get("bfloat16") + ) + + def test_serialization(self): + dtype_policy_map = DTypePolicyMap(default_policy="mixed_bfloat16") + dtype_policy_map["layer/dense_0"] = dtype_policies.DTypePolicy( + "mixed_bfloat16" + ) + dtype_policy_map["layer/dense_1"] = dtype_policies.QuantizedDTypePolicy( + "int8", "mixed_bfloat16" + ) + + config = dtype_policies.serialize(dtype_policy_map) + reloaded_dtype_policy_map = dtype_policies.deserialize(config) + self.assertEqual( + dtype_policy_map.default_policy, + reloaded_dtype_policy_map.default_policy, + ) + for k, v in dtype_policy_map.items(): + self.assertEqual(reloaded_dtype_policy_map[k], v) + + # Test that config remains intact during deserialization + config = dtype_policy_map.get_config() + original_config = config.copy() + DTypePolicyMap.from_config(config) + self.assertDictEqual(config, original_config) + + def test_repr(self): + dtype_policy_map = DTypePolicyMap() + dtype_policy_map["layer/dense_0"] = dtype_policies.DTypePolicy( + "mixed_bfloat16" + ) + repr_str = repr(dtype_policy_map) + self.assertTrue("DTypePolicyMap" in repr_str) + self.assertTrue("default_policy" in repr_str) + self.assertTrue( + "mapping=[('layer/dense_0', 'mixed_bfloat16')]" in repr_str + ) + + def test_invalid_policy_map(self): + with self.assertRaisesRegex( + TypeError, "If specified, `policy_map` must be a dict." + ): + DTypePolicyMap(policy_map=123) + + with self.assertRaisesRegex( + TypeError, "If specified, `policy_map` must be a dict." + ): + DTypePolicyMap( + policy_map=dtype_policies.DTypePolicy("mixed_bfloat16") + ) diff --git a/keras/src/dtype_policies/dtype_policy_test.py b/keras/src/dtype_policies/dtype_policy_test.py new file mode 100644 index 000000000000..ac23fdbbd85f --- /dev/null +++ b/keras/src/dtype_policies/dtype_policy_test.py @@ -0,0 +1,746 @@ +from absl.testing import parameterized + +from keras.src.dtype_policies import deserialize +from keras.src.dtype_policies import get +from keras.src.dtype_policies import serialize +from keras.src.dtype_policies.dtype_policy import DTypePolicy +from keras.src.dtype_policies.dtype_policy import FloatDTypePolicy +from keras.src.dtype_policies.dtype_policy import QuantizedDTypePolicy +from keras.src.dtype_policies.dtype_policy import QuantizedFloat8DTypePolicy +from keras.src.dtype_policies.dtype_policy import dtype_policy +from keras.src.dtype_policies.dtype_policy import set_dtype_policy +from keras.src.quantizers.gptq_config import GPTQConfig +from keras.src.testing import test_case + + +class DTypePolicyTest(test_case.TestCase): + """Test `DTypePolicy`. + + In the tests, we also test `DTypePolicy` for historical reasons. + """ + + def setUp(self): + """Record the global dtype policy before each test.""" + super().setUp() + self._global_dtype_policy = dtype_policy() + + def tearDown(self): + super().tearDown() + """Restore the global dtype policy after each test.""" + set_dtype_policy(self._global_dtype_policy) + + def test_initialization_valid_name(self): + """Test initialization with a valid name.""" + policy = DTypePolicy("mixed_float16") + self.assertEqual(policy.compute_dtype, "float16") + self.assertEqual(policy.variable_dtype, "float32") + + policy = FloatDTypePolicy("mixed_float16") + self.assertEqual(policy.compute_dtype, "float16") + self.assertEqual(policy.variable_dtype, "float32") + + @parameterized.named_parameters( + ("float32", "float32", "float32", "float32"), + ("float16", "float16", "float16", "float16"), + ("bfloat16", "bfloat16", "bfloat16", "bfloat16"), + ("mixed_float16", "mixed_float16", "float16", "float32"), + ("mixed_bfloat16", "mixed_bfloat16", "bfloat16", "float32"), + ) + def test_initialization_from_global( + self, + global_dtype_policy, + expected_compute_dtype, + expected_variable_dtype, + ): + set_dtype_policy(global_dtype_policy) + + policy = DTypePolicy(name=None) + self.assertEqual(policy.name, global_dtype_policy) + self.assertEqual(policy.compute_dtype, expected_compute_dtype) + self.assertEqual(policy.variable_dtype, expected_variable_dtype) + + policy = FloatDTypePolicy(name=None) + self.assertEqual(policy.name, global_dtype_policy) + self.assertEqual(policy.compute_dtype, expected_compute_dtype) + self.assertEqual(policy.variable_dtype, expected_variable_dtype) + + def test_initialization_invalid_name(self): + """Test initialization with an invalid name.""" + with self.assertRaisesRegex(ValueError, "Cannot convert"): + DTypePolicy("invalid_name") + + with self.assertRaisesRegex(ValueError, "Cannot convert"): + FloatDTypePolicy("invalid_name") + + def test_initialization_non_string_name(self): + """Test initialization with a non-string name.""" + with self.assertRaisesRegex(TypeError, "'name' must be a string"): + DTypePolicy(123) + + with self.assertRaisesRegex(TypeError, "'name' must be a string"): + FloatDTypePolicy(123) + + def test_properties_mixed_float16(self): + """Test properties for 'mixed_float16'.""" + policy = DTypePolicy("mixed_float16") + self.assertEqual(policy.compute_dtype, "float16") + self.assertEqual(policy.variable_dtype, "float32") + + policy = FloatDTypePolicy("mixed_float16") + self.assertEqual(policy.compute_dtype, "float16") + self.assertEqual(policy.variable_dtype, "float32") + + def test_properties_mixed_bfloat16(self): + """Test properties for 'mixed_bfloat16'.""" + policy = DTypePolicy("mixed_bfloat16") + self.assertEqual(policy.compute_dtype, "bfloat16") + self.assertEqual(policy.variable_dtype, "float32") + + policy = FloatDTypePolicy("mixed_bfloat16") + self.assertEqual(policy.compute_dtype, "bfloat16") + self.assertEqual(policy.variable_dtype, "float32") + + def test_initialization_with_invalid_name_behaviour(self): + """Test initialization behavior with an invalid name.""" + with self.assertRaisesRegex(ValueError, "Cannot convert"): + DTypePolicy("invalid_name") + + with self.assertRaisesRegex(ValueError, "Cannot convert"): + FloatDTypePolicy("invalid_name") + + def test_properties(self): + """Test variable_dtype, compute_dtype, and name properties.""" + policy = DTypePolicy("mixed_float16") + self.assertEqual(policy.variable_dtype, "float32") + self.assertEqual(policy.compute_dtype, "float16") + self.assertEqual(policy.name, "mixed_float16") + self.assertIsNone(policy.quantization_mode) + + policy = FloatDTypePolicy("mixed_float16") + self.assertEqual(policy.variable_dtype, "float32") + self.assertEqual(policy.compute_dtype, "float16") + self.assertEqual(policy.name, "mixed_float16") + self.assertIsNone(policy.quantization_mode) + + def test_properties_uint8(self): + """Test properties for 'uint8'.""" + policy = DTypePolicy("uint8") + self.assertEqual(policy.compute_dtype, "uint8") + self.assertEqual(policy.variable_dtype, "uint8") + self.assertEqual(policy.name, "uint8") + + policy = FloatDTypePolicy("uint8") + self.assertEqual(policy.compute_dtype, "uint8") + self.assertEqual(policy.variable_dtype, "uint8") + self.assertEqual(policy.name, "uint8") + + def test_repr(self): + """Test __repr__ method.""" + policy = DTypePolicy("mixed_float16") + self.assertEqual(repr(policy), '') + + policy = FloatDTypePolicy("mixed_float16") + self.assertEqual(repr(policy), '') + + def test_get_config_from_config(self): + """Test get_config and from_config methods.""" + # Test DTypePolicy + policy = DTypePolicy("mixed_float16") + config = policy.get_config() + self.assertEqual(config, {"name": "mixed_float16"}) + new_policy = DTypePolicy.from_config(config) + self.assertEqual(new_policy.name, "mixed_float16") + + # Test FloatDTypePolicy + policy = FloatDTypePolicy("mixed_float16") + config = policy.get_config() + self.assertEqual(config, {"name": "mixed_float16"}) + new_policy = FloatDTypePolicy.from_config(config) + self.assertEqual(new_policy.name, "mixed_float16") + + def test_serialization(self): + # Test DTypePolicy + policy = DTypePolicy("mixed_float16") + config = serialize(policy) + reloaded_policy = deserialize(config) + self.assertEqual(policy.name, reloaded_policy.name) + reloaded_policy = get(config) + self.assertEqual(policy.name, reloaded_policy.name) + + # Test FloatDTypePolicy + policy = FloatDTypePolicy("mixed_float16") + config = serialize(policy) + reloaded_policy = deserialize(config) + self.assertEqual(policy.name, reloaded_policy.name) + reloaded_policy = get(config) + self.assertEqual(policy.name, reloaded_policy.name) + + def test_python_serialization(self): + """Test builtin serialization methods.""" + import copy + import pickle + + # Test DTypePolicy + policy = DTypePolicy("mixed_float16") + + # copy.deepcopy + copied_policy = copy.deepcopy(policy) + self.assertEqual(repr(copied_policy), '') + # copy.copy + copied_policy = copy.copy(policy) + self.assertEqual(repr(copied_policy), '') + # pickle + temp_dir = self.get_temp_dir() + with open(f"{temp_dir}/policy.pickle", "wb") as f: + pickle.dump(policy, f) + with open(f"{temp_dir}/policy.pickle", "rb") as f: + copied_policy = pickle.load(f) + self.assertEqual(repr(copied_policy), '') + + # Test FloatDTypePolicy + policy = FloatDTypePolicy("mixed_float16") + + # copy.deepcopy + copied_policy = copy.deepcopy(policy) + self.assertEqual(repr(copied_policy), '') + # copy.copy + copied_policy = copy.copy(policy) + self.assertEqual(repr(copied_policy), '') + # pickle + temp_dir = self.get_temp_dir() + with open(f"{temp_dir}/policy.pickle", "wb") as f: + pickle.dump(policy, f) + with open(f"{temp_dir}/policy.pickle", "rb") as f: + copied_policy = pickle.load(f) + self.assertEqual(repr(copied_policy), '') + + def test_eq(self): + policy = DTypePolicy("mixed_bfloat16") + + # Test True + self.assertEqual(policy, DTypePolicy("mixed_bfloat16")) + self.assertEqual(policy, FloatDTypePolicy("mixed_bfloat16")) + + # Test False + self.assertNotEqual(policy, "mixed_float16") + self.assertNotEqual( + policy, QuantizedDTypePolicy("int8", "mixed_bfloat16") + ) + + +class QuantizedDTypePolicyTest(test_case.TestCase): + def setUp(self): + """Record the global dtype policy before each test.""" + super().setUp() + self._global_dtype_policy = dtype_policy() + + def tearDown(self): + super().tearDown() + """Restore the global dtype policy after each test.""" + set_dtype_policy(self._global_dtype_policy) + + @parameterized.named_parameters( + ("float32", "float32", "float32", "float32"), + ("bfloat16", "bfloat16", "bfloat16", "bfloat16"), + ("mixed_bfloat16", "mixed_bfloat16", "bfloat16", "float32"), + ) + def test_initialization_for_int8( + self, source_name, expected_compute_dtype, expected_variable_dtype + ): + name = f"int8_from_{source_name}" + policy = QuantizedDTypePolicy(mode="int8", source_name=source_name) + self.assertEqual(policy.name, name) + self.assertEqual(policy.compute_dtype, expected_compute_dtype) + self.assertEqual(policy.variable_dtype, expected_variable_dtype) + self.assertEqual(repr(policy), f'') + + @parameterized.named_parameters( + ("float32", "float32", "float32", "float32"), + ("bfloat16", "bfloat16", "bfloat16", "bfloat16"), + ("mixed_bfloat16", "mixed_bfloat16", "bfloat16", "float32"), + ) + def test_initialization_for_int8_from_global( + self, + global_dtype_policy, + expected_compute_dtype, + expected_variable_dtype, + ): + set_dtype_policy(global_dtype_policy) + expected_name = f"int8_from_{global_dtype_policy}" + + policy = QuantizedDTypePolicy(mode="int8", source_name=None) + self.assertEqual(policy.name, expected_name) + self.assertEqual(policy.compute_dtype, expected_compute_dtype) + self.assertEqual(policy.variable_dtype, expected_variable_dtype) + + @parameterized.named_parameters( + ("float32", "float32", "float32", "float32"), + ("float16", "float16", "float16", "float16"), + ("bfloat16", "bfloat16", "bfloat16", "bfloat16"), + ("mixed_float16", "mixed_float16", "float16", "float32"), + ("mixed_bfloat16", "mixed_bfloat16", "bfloat16", "float32"), + ) + def test_initialization_for_float8( + self, source_name, expected_compute_dtype, expected_variable_dtype + ): + name = f"float8_from_{source_name}" + policy = QuantizedFloat8DTypePolicy( + mode="float8", source_name=source_name + ) + self.assertEqual(policy.name, name) + self.assertEqual(policy.compute_dtype, expected_compute_dtype) + self.assertEqual(policy.variable_dtype, expected_variable_dtype) + self.assertEqual(repr(policy), f'') + + @parameterized.named_parameters( + ("float32", "float32", "float32", "float32"), + ("float16", "float16", "float16", "float16"), + ("bfloat16", "bfloat16", "bfloat16", "bfloat16"), + ("mixed_float16", "mixed_float16", "float16", "float32"), + ("mixed_bfloat16", "mixed_bfloat16", "bfloat16", "float32"), + ) + def test_initialization_for_float8_from_global( + self, + global_dtype_policy, + expected_compute_dtype, + expected_variable_dtype, + ): + set_dtype_policy(global_dtype_policy) + expected_name = f"float8_from_{global_dtype_policy}" + + policy = QuantizedFloat8DTypePolicy(mode="float8", source_name=None) + self.assertEqual(policy.name, expected_name) + self.assertEqual(policy.compute_dtype, expected_compute_dtype) + self.assertEqual(policy.variable_dtype, expected_variable_dtype) + + @parameterized.named_parameters( + ("abc", "abc"), + ("abc_from_def", "def"), + ) + def test_initialization_with_invalid_name(self, invalid_name): + with self.assertRaisesRegex(ValueError, "Cannot convert"): + QuantizedDTypePolicy(mode="int8", source_name=invalid_name) + with self.assertRaisesRegex(ValueError, "Cannot convert"): + QuantizedFloat8DTypePolicy(mode="float8", source_name=invalid_name) + + @parameterized.named_parameters( + ("int7", "int7"), + ("float7", "float7"), + ) + def test_initialization_with_invalid_mode(self, invalid_mode): + with self.assertRaisesRegex(ValueError, "Invalid quantization mode."): + QuantizedDTypePolicy(mode=invalid_mode) + with self.assertRaisesRegex(ValueError, "Invalid quantization mode."): + QuantizedFloat8DTypePolicy(mode=invalid_mode) + + @parameterized.named_parameters( + ("int8_from_float16", "float16"), + ("int8_from_mixed_float16", "mixed_float16"), + ) + def test_initialization_with_invalid_compute_dtype(self, invalid_name): + with self.assertRaisesRegex(ValueError, "doesn't work well"): + QuantizedDTypePolicy(mode="int8", source_name=invalid_name) + + def test_initialization_non_string_name(self): + """Test initialization with a non-string name.""" + with self.assertRaisesRegex(TypeError, "'name' must be a string"): + QuantizedDTypePolicy(mode="int8", source_name=123) + with self.assertRaisesRegex(TypeError, "'name' must be a string"): + QuantizedFloat8DTypePolicy(mode="float8", source_name=123) + + def test_properties(self): + # Test int8 + policy = QuantizedDTypePolicy(mode="int8", source_name="mixed_bfloat16") + self.assertEqual(policy.variable_dtype, "float32") + self.assertEqual(policy.compute_dtype, "bfloat16") + self.assertEqual(policy.name, "int8_from_mixed_bfloat16") + self.assertEqual(policy.quantization_mode, "int8") + + # Test float8 + policy = QuantizedFloat8DTypePolicy( + mode="float8", source_name="mixed_bfloat16" + ) + self.assertEqual(policy.variable_dtype, "float32") + self.assertEqual(policy.compute_dtype, "bfloat16") + self.assertEqual(policy.name, "float8_from_mixed_bfloat16") + self.assertEqual(policy.quantization_mode, "float8") + self.assertEqual(policy.amax_history_length, 1024) + + # Test float8 with amax_history_length + policy = QuantizedFloat8DTypePolicy( + mode="float8", source_name="mixed_bfloat16", amax_history_length=512 + ) + self.assertEqual(policy.amax_history_length, 512) + + # Test float8 default_amax_history_length + self.assertEqual( + QuantizedFloat8DTypePolicy.default_amax_history_length, 1024 + ) + + def test_invalid_properties_for_float8(self): + with self.assertRaisesRegex(TypeError, "must be an integer."): + QuantizedFloat8DTypePolicy( + mode="float8", source_name="float32", amax_history_length="512" + ) + with self.assertRaisesRegex(TypeError, "must be an integer."): + QuantizedFloat8DTypePolicy( + mode="float8", source_name="float32", amax_history_length=512.0 + ) + + def test_get_config_from_config(self): + """Test get_config and from_config methods.""" + # Test QuantizedDTypePolicy + policy = QuantizedDTypePolicy(mode="int8", source_name="mixed_bfloat16") + config = policy.get_config() + self.assertEqual( + config, {"mode": "int8", "source_name": "mixed_bfloat16"} + ) + new_policy = QuantizedDTypePolicy.from_config(config) + self.assertEqual(new_policy.name, "int8_from_mixed_bfloat16") + + # Test QuantizedFloat8DTypePolicy + policy = QuantizedFloat8DTypePolicy( + mode="float8", source_name="mixed_bfloat16" + ) + config = policy.get_config() + self.assertEqual( + config, + { + "mode": "float8", + "source_name": "mixed_bfloat16", + "amax_history_length": 1024, + }, + ) + new_policy = QuantizedFloat8DTypePolicy.from_config(config) + self.assertEqual(new_policy.name, "float8_from_mixed_bfloat16") + + def test_serialization(self): + # Test QuantizedDTypePolicy + policy = QuantizedDTypePolicy(mode="int8", source_name="float32") + config = serialize(policy) + reloaded_policy = deserialize(config) + self.assertEqual(policy.name, reloaded_policy.name) + reloaded_policy = get(config) + self.assertEqual(policy.name, reloaded_policy.name) + + # Test QuantizedFloat8DTypePolicy + policy = QuantizedFloat8DTypePolicy( + mode="float8", source_name="float32" + ) + config = serialize(policy) + reloaded_policy = deserialize(config) + self.assertEqual(policy.name, reloaded_policy.name) + reloaded_policy = get(config) + self.assertEqual(policy.name, reloaded_policy.name) + + @parameterized.named_parameters( + ( + "int8_from_mixed_bfloat16", + "int8", + "mixed_bfloat16", + '', + ), + ( + "float8_from_mixed_bfloat16", + "float8", + "mixed_bfloat16", + '', + ), + ) + def test_python_serialization(self, mode, source_name, repr_str): + import copy + import pickle + + if mode == "int8": + policy = QuantizedDTypePolicy(mode=mode, source_name=source_name) + else: + policy = QuantizedFloat8DTypePolicy( + mode=mode, source_name=source_name, amax_history_length=123 + ) + + # copy.deepcopy + copied_policy = copy.deepcopy(policy) + self.assertEqual(repr(copied_policy), repr_str) + if mode == "float8": + self.assertEqual(copied_policy.amax_history_length, 123) + # copy.copy + copied_policy = copy.copy(policy) + self.assertEqual(repr(copied_policy), repr_str) + if mode == "float8": + self.assertEqual(copied_policy.amax_history_length, 123) + # pickle + temp_dir = self.get_temp_dir() + with open(f"{temp_dir}/policy.pickle", "wb") as f: + pickle.dump(policy, f) + with open(f"{temp_dir}/policy.pickle", "rb") as f: + copied_policy = pickle.load(f) + self.assertEqual(repr(copied_policy), repr_str) + if mode == "float8": + self.assertEqual(copied_policy.amax_history_length, 123) + + def test_serialization_for_float8(self): + policy = QuantizedFloat8DTypePolicy( + mode="float8", source_name="mixed_float16" + ) + config = serialize(policy) + reloaded_policy = deserialize(config) + self.assertEqual(policy.name, reloaded_policy.name) + self.assertEqual( + policy.amax_history_length, reloaded_policy.amax_history_length + ) + + # Test `dtype_policies.get` + reloaded_policy = get(config) + self.assertEqual(policy.name, reloaded_policy.name) + self.assertEqual( + policy.amax_history_length, reloaded_policy.amax_history_length + ) + + def test_eq(self): + policy = QuantizedDTypePolicy("int8", "mixed_bfloat16") + + # Test True + self.assertEqual(policy, QuantizedDTypePolicy("int8", "mixed_bfloat16")) + + # Test False + self.assertNotEqual(policy, "mixed_bfloat16") + self.assertNotEqual(policy, DTypePolicy("mixed_bfloat16")) + self.assertNotEqual( + policy, QuantizedFloat8DTypePolicy("float8", "mixed_bfloat16") + ) + + @parameterized.named_parameters( + ("int8_from_mixed_bfloat16", "int8_from_mixed_bfloat16"), + ("float8_from_mixed_bfloat16", "float8_from_mixed_bfloat16"), + ) + def test_get_quantized_dtype_policy_by_str(self, name): + from keras.src.dtype_policies.dtype_policy import ( + _get_quantized_dtype_policy_by_str, + ) + + policy = _get_quantized_dtype_policy_by_str(name) + self.assertEqual(policy.name, name) + + def test_invalid_get_quantized_dtype_policy_by_str(self): + from keras.src.dtype_policies.dtype_policy import ( + _get_quantized_dtype_policy_by_str, + ) + + with self.assertRaisesRegex(TypeError, "must be a string."): + _get_quantized_dtype_policy_by_str(123) + with self.assertRaisesRegex( + ValueError, + "is incompatible with the current supported quantization.", + ): + _get_quantized_dtype_policy_by_str("float7") + + +class DTypePolicyGlobalFunctionsTest(test_case.TestCase): + def setUp(self): + """Reset the global dtype policy before each test.""" + set_dtype_policy("float32") + + def test_set_dtype_policy_valid_string(self): + """Test set_dtype_policy with a valid string.""" + set_dtype_policy("mixed_float16") + policy = dtype_policy() + self.assertEqual(policy.name, "mixed_float16") + + def test_set_dtype_policy_valid_string_quantized(self): + """Test set_dtype_policy with a valid string.""" + set_dtype_policy("int8_from_mixed_bfloat16") + policy = dtype_policy() + self.assertEqual(policy.name, "int8_from_mixed_bfloat16") + + def test_set_dtype_policy_valid_policy(self): + """Test set_dtype_policy with a valid DTypePolicy object.""" + policy_obj = DTypePolicy("mixed_float16") + set_dtype_policy(policy_obj) + policy = dtype_policy() + self.assertEqual(policy.name, "mixed_float16") + + def test_set_dtype_policy_valid_policy_quantized(self): + """Test set_dtype_policy with a valid QuantizedDTypePolicy object.""" + policy_obj = QuantizedDTypePolicy( + mode="int8", source_name="mixed_bfloat16" + ) + set_dtype_policy(policy_obj) + policy = dtype_policy() + self.assertEqual(policy.name, "int8_from_mixed_bfloat16") + + def test_set_dtype_policy_invalid(self): + """Test set_dtype_policy with an invalid input.""" + with self.assertRaisesRegex(ValueError, "Invalid `policy` argument"): + set_dtype_policy(12345) + + def test_dtype_policy_default(self): + """Test dtype_policy default value.""" + policy = dtype_policy() + self.assertEqual(policy.name, "float32") + + def test_get_valid_policy(self): + policy = get("bfloat16") + self.assertEqual(policy.name, "bfloat16") + + policy = get("mixed_float16") + self.assertEqual(policy.name, "mixed_float16") + + policy = get(DTypePolicy("bfloat16")) + self.assertEqual(policy.name, "bfloat16") + + policy = get(FloatDTypePolicy("mixed_float16")) + self.assertEqual(policy.name, "mixed_float16") + + def test_get_valid_policy_quantized(self): + policy = get("int8_from_mixed_bfloat16") + self.assertEqual(policy.name, "int8_from_mixed_bfloat16") + + policy = get("float8_from_float32") + self.assertEqual(policy.name, "float8_from_float32") + + policy = get(QuantizedDTypePolicy("int8", "mixed_bfloat16")) + self.assertEqual(policy.name, "int8_from_mixed_bfloat16") + + policy = get(QuantizedFloat8DTypePolicy("float8", "mixed_float16")) + self.assertEqual(policy.name, "float8_from_mixed_float16") + + def test_get_invalid_policy(self): + with self.assertRaisesRegex(ValueError, "Cannot convert"): + get("mixed_bfloat15") + with self.assertRaisesRegex( + ValueError, "Cannot interpret `dtype` argument." + ): + get(123) + + def test_get_invalid_policy_quantized(self): + with self.assertRaisesRegex(ValueError, "Cannot convert"): + get("int8_from_mixed_bfloat15") + with self.assertRaisesRegex(ValueError, "Cannot convert"): + get("int8_from_") + with self.assertRaisesRegex( + ValueError, "Cannot convert `policy` into a valid pair" + ): + get("int8_abc_") + + +class DTypePolicyEdgeCasesTest(test_case.TestCase): + def test_empty_name(self): + """Test initialization with an empty name.""" + with self.assertRaisesRegex(ValueError, "Cannot convert"): + DTypePolicy("") + + def test_special_character_name(self): + """Test initialization with special characters in the name.""" + with self.assertRaisesRegex(ValueError, "Cannot convert"): + DTypePolicy("@mixed_float16!") + + def test_very_long_name(self): + """Test initialization with a very long name.""" + with self.assertRaisesRegex(ValueError, "Cannot convert"): + DTypePolicy("mixed_float16" * 100) + + def test_almost_valid_name(self): + """Test initialization with a name close to a valid one.""" + with self.assertRaisesRegex(ValueError, "Cannot convert"): + DTypePolicy("mixed_float15") + + +class QuantizedDTypePolicyEdgeCasesTest(test_case.TestCase): + def test_empty_name(self): + """Test initialization with an empty name.""" + with self.assertRaisesRegex(ValueError, "Cannot convert"): + QuantizedDTypePolicy(mode="int8", source_name="") + + def test_special_character_name(self): + """Test initialization with special characters in the name.""" + with self.assertRaisesRegex(ValueError, "Cannot convert"): + QuantizedDTypePolicy( + mode="int8", source_name="@int8_from_mixed_bfloat16!" + ) + + def test_very_long_name(self): + """Test initialization with a very long name.""" + with self.assertRaisesRegex(ValueError, "Cannot convert"): + QuantizedDTypePolicy( + mode="int8", source_name="int8_from_mixed_bfloat16" * 100 + ) + + def test_almost_valid_name(self): + """Test initialization with a name close to a valid one.""" + with self.assertRaisesRegex(ValueError, "Cannot convert"): + QuantizedDTypePolicy( + mode="int8", source_name="int7_from_mixed_bfloat16" + ) + + +class DTypePolicyGlobalFunctionsEdgeCasesTest(test_case.TestCase): + def setUp(self): + """Reset the global dtype policy before each test.""" + set_dtype_policy("float32") + + def test_set_policy_multiple_times(self): + """Test setting the policy multiple times in a row.""" + set_dtype_policy("mixed_float16") + policy = dtype_policy() + self.assertEqual(policy.name, "mixed_float16") + + set_dtype_policy("float32") + policy = dtype_policy() + self.assertEqual(policy.name, "float32") + + def test_set_policy_none(self): + """Test setting the policy to None.""" + with self.assertRaisesRegex(ValueError, "Invalid `policy` argument"): + set_dtype_policy(None) + + +class GPTQConfigErrorHandlingTest(test_case.TestCase): + """Test error handling in GPTQConfig.""" + + def test_invalid_weight_bits(self): + with self.assertRaisesRegex(ValueError, "Unsupported weight_bits"): + GPTQConfig( + dataset=None, + tokenizer=None, + weight_bits=5, + ) + + def test_negative_num_samples(self): + with self.assertRaisesRegex( + ValueError, "num_samples must be a positive integer." + ): + GPTQConfig( + dataset=None, + tokenizer=None, + num_samples=-10, + ) + + def test_zero_sequence_length(self): + with self.assertRaisesRegex( + ValueError, "sequence_length must be a positive integer." + ): + GPTQConfig( + dataset=None, + tokenizer=None, + sequence_length=0, + ) + + def test_invalid_hessian_damping(self): + with self.assertRaisesRegex( + ValueError, "hessian_damping must be between 0 and 1." + ): + GPTQConfig( + dataset=None, + tokenizer=None, + hessian_damping=1.5, + ) + + def test_invalid_group_size(self): + with self.assertRaisesRegex( + ValueError, "Invalid group_size. Supported values are -1" + ): + GPTQConfig( + dataset=None, + tokenizer=None, + group_size=0, + ) diff --git a/keras/src/export/__init__.py b/keras/src/export/__init__.py new file mode 100644 index 000000000000..7adfd18513f6 --- /dev/null +++ b/keras/src/export/__init__.py @@ -0,0 +1,5 @@ +from keras.src.export.onnx import export_onnx +from keras.src.export.openvino import export_openvino +from keras.src.export.saved_model import ExportArchive +from keras.src.export.saved_model import export_saved_model +from keras.src.export.tfsm_layer import TFSMLayer diff --git a/keras/src/export/export_utils.py b/keras/src/export/export_utils.py new file mode 100644 index 000000000000..4b76f68fe4a6 --- /dev/null +++ b/keras/src/export/export_utils.py @@ -0,0 +1,107 @@ +from keras.src import backend +from keras.src import layers +from keras.src import models +from keras.src import ops +from keras.src import tree +from keras.src.utils.module_utils import tensorflow as tf + + +def get_input_signature(model): + if not isinstance(model, models.Model): + raise TypeError( + "The model must be a `keras.Model`. " + f"Received: model={model} of the type {type(model)}" + ) + if not model.built: + raise ValueError( + "The model provided has not yet been built. It must be built " + "before export." + ) + if isinstance(model, models.Functional): + input_signature = [ + tree.map_structure(make_input_spec, model._inputs_struct) + ] + elif isinstance(model, models.Sequential): + input_signature = tree.map_structure(make_input_spec, model.inputs) + else: + input_signature = _infer_input_signature_from_model(model) + if not input_signature or not model._called: + raise ValueError( + "The model provided has never called. " + "It must be called at least once before export." + ) + return input_signature + + +def _infer_input_signature_from_model(model): + shapes_dict = getattr(model, "_build_shapes_dict", None) + if not shapes_dict: + return None + + def _make_input_spec(structure): + # We need to turn wrapper structures like TrackingDict or _DictWrapper + # into plain Python structures because they don't work with jax2tf/JAX. + if isinstance(structure, dict): + return {k: _make_input_spec(v) for k, v in structure.items()} + elif isinstance(structure, tuple): + if all(isinstance(d, (int, type(None))) for d in structure): + return layers.InputSpec( + shape=(None,) + structure[1:], dtype=model.input_dtype + ) + return tuple(_make_input_spec(v) for v in structure) + elif isinstance(structure, list): + if all(isinstance(d, (int, type(None))) for d in structure): + return layers.InputSpec( + shape=[None] + structure[1:], dtype=model.input_dtype + ) + return [_make_input_spec(v) for v in structure] + else: + raise ValueError( + f"Unsupported type {type(structure)} for {structure}" + ) + + return [_make_input_spec(value) for value in shapes_dict.values()] + + +def make_input_spec(x): + if isinstance(x, layers.InputSpec): + if x.shape is None or x.dtype is None: + raise ValueError( + f"The `shape` and `dtype` must be provided. Received: x={x}" + ) + input_spec = x + elif isinstance(x, backend.KerasTensor): + shape = (None,) + backend.standardize_shape(x.shape)[1:] + dtype = backend.standardize_dtype(x.dtype) + input_spec = layers.InputSpec(dtype=dtype, shape=shape, name=x.name) + elif backend.is_tensor(x): + shape = (None,) + backend.standardize_shape(x.shape)[1:] + dtype = backend.standardize_dtype(x.dtype) + input_spec = layers.InputSpec(dtype=dtype, shape=shape, name=None) + else: + raise TypeError( + f"Unsupported x={x} of the type ({type(x)}). Supported types are: " + "`keras.InputSpec`, `keras.KerasTensor` and backend tensor." + ) + return input_spec + + +def make_tf_tensor_spec(x): + if isinstance(x, tf.TensorSpec): + tensor_spec = x + else: + input_spec = make_input_spec(x) + tensor_spec = tf.TensorSpec( + input_spec.shape, dtype=input_spec.dtype, name=input_spec.name + ) + return tensor_spec + + +def convert_spec_to_tensor(spec, replace_none_number=None): + shape = backend.standardize_shape(spec.shape) + if replace_none_number is not None: + replace_none_number = int(replace_none_number) + shape = tuple( + s if s is not None else replace_none_number for s in shape + ) + return ops.ones(shape, spec.dtype) diff --git a/keras/src/export/onnx.py b/keras/src/export/onnx.py new file mode 100644 index 000000000000..7d4d37d5e758 --- /dev/null +++ b/keras/src/export/onnx.py @@ -0,0 +1,219 @@ +import warnings + +from keras.src import backend +from keras.src import tree +from keras.src.export.export_utils import convert_spec_to_tensor +from keras.src.export.export_utils import get_input_signature +from keras.src.export.export_utils import make_tf_tensor_spec +from keras.src.export.saved_model import DEFAULT_ENDPOINT_NAME +from keras.src.export.saved_model import ExportArchive +from keras.src.export.tf2onnx_lib import patch_tf2onnx +from keras.src.utils import io_utils + + +def export_onnx( + model, + filepath, + verbose=None, + input_signature=None, + opset_version=None, + **kwargs, +): + """Export the model as a ONNX artifact for inference. + + This method lets you export a model to a lightweight ONNX artifact + that contains the model's forward pass only (its `call()` method) + and can be served via e.g. ONNX Runtime. + + The original code of the model (including any custom layers you may + have used) is *no longer* necessary to reload the artifact -- it is + entirely standalone. + + Args: + filepath: `str` or `pathlib.Path` object. The path to save the artifact. + verbose: `bool`. Whether to print a message during export. Defaults to + `None`, which uses the default value set by different backends and + formats. + input_signature: Optional. Specifies the shape and dtype of the model + inputs. Can be a structure of `keras.InputSpec`, `tf.TensorSpec`, + `backend.KerasTensor`, or backend tensor. If not provided, it will + be automatically computed. Defaults to `None`. + opset_version: Optional. An integer value that specifies the ONNX opset + version. If not provided, the default version for the backend will + be used. Defaults to `None`. + **kwargs: Additional keyword arguments. + + **Note:** This feature is currently supported only with TensorFlow, JAX and + Torch backends. + + **Note:** The dtype policy must be "float32" for the model. You can further + optimize the ONNX artifact using the ONNX toolkit. Learn more here: + [https://onnxruntime.ai/docs/performance/](https://onnxruntime.ai/docs/performance/). + + **Note:** The dynamic shape feature is not yet supported with Torch + backend. As a result, you must fully define the shapes of the inputs using + `input_signature`. If `input_signature` is not provided, all instances of + `None` (such as the batch size) will be replaced with `1`. + + Example: + + ```python + # Export the model as a ONNX artifact + model.export("path/to/location", format="onnx") + + # Load the artifact in a different process/environment + ort_session = onnxruntime.InferenceSession("path/to/location") + ort_inputs = { + k.name: v for k, v in zip(ort_session.get_inputs(), input_data) + } + predictions = ort_session.run(None, ort_inputs) + ``` + """ + actual_verbose = verbose + if actual_verbose is None: + actual_verbose = True # Defaults to `True` for all backends. + + if input_signature is None: + input_signature = get_input_signature(model) + if not input_signature or not model._called: + raise ValueError( + "The model provided has never called. " + "It must be called at least once before export." + ) + input_names = [ + getattr(spec, "name", None) or f"input_{i}" + for i, spec in enumerate(input_signature) + ] + + if backend.backend() in ("tensorflow", "jax"): + from keras.src.utils.module_utils import tf2onnx + + input_signature = tree.map_structure( + make_tf_tensor_spec, input_signature + ) + decorated_fn = get_concrete_fn(model, input_signature, **kwargs) + + # Use `tf2onnx` to convert the `decorated_fn` to the ONNX format. + patch_tf2onnx() # TODO: Remove this once `tf2onnx` supports numpy 2. + tf2onnx.convert.from_function( + decorated_fn, + input_signature, + opset=opset_version, + output_path=filepath, + ) + + elif backend.backend() == "torch": + import torch + + sample_inputs = tree.map_structure( + lambda x: convert_spec_to_tensor(x, replace_none_number=1), + input_signature, + ) + sample_inputs = tuple(sample_inputs) + # TODO: Make dict model exportable. + if any(isinstance(x, dict) for x in sample_inputs): + raise ValueError( + "Currently, `export_onnx` in the torch backend doesn't support " + "dictionaries as inputs." + ) + + if hasattr(model, "eval"): + model.eval() + with warnings.catch_warnings(): + # Suppress some unuseful warnings. + warnings.filterwarnings( + "ignore", + message=r".*\n.*\n*.*\n*.*export will treat it as a constant.*", + ) + warnings.filterwarnings( + "ignore", + message=r".*not properly registered as a submodule,.*", + ) + warnings.filterwarnings( + "ignore", + message=r".*which is what 'get_attr' Nodes typically target.*", + ) + warnings.filterwarnings( + "ignore", + message=r".*underlying reference in the owning GraphModule.*", + ) + warnings.filterwarnings( + "ignore", message=r".*suppressed about get_attr references.*" + ) + try: + # Try the TorchDynamo-based ONNX exporter first. + onnx_program = torch.onnx.export( + model, + sample_inputs, + verbose=actual_verbose, + opset_version=opset_version, + input_names=input_names, + dynamo=True, + ) + if hasattr(onnx_program, "optimize"): + onnx_program.optimize() # Only supported by torch>=2.6.0. + onnx_program.save(filepath) + except: + if verbose is None: + # Set to `False` due to file system leakage issue: + # https://github.com/keras-team/keras/issues/20826 + actual_verbose = False + + # Fall back to the TorchScript-based ONNX exporter. + torch.onnx.export( + model, + sample_inputs, + filepath, + verbose=actual_verbose, + opset_version=opset_version, + input_names=input_names, + ) + else: + raise NotImplementedError( + "`export_onnx` is only compatible with TensorFlow, JAX and " + "Torch backends." + ) + + if actual_verbose: + io_utils.print_msg(f"Saved artifact at '{filepath}'.") + + +def _check_jax_kwargs(kwargs): + kwargs = kwargs.copy() + if "is_static" not in kwargs: + kwargs["is_static"] = True + if "jax2tf_kwargs" not in kwargs: + # TODO: These options will be deprecated in JAX. We need to + # find another way to export ONNX. + kwargs["jax2tf_kwargs"] = { + "enable_xla": False, + "native_serialization": False, + } + if kwargs["is_static"] is not True: + raise ValueError( + "`is_static` must be `True` in `kwargs` when using the jax backend." + ) + if kwargs["jax2tf_kwargs"]["enable_xla"] is not False: + raise ValueError( + "`enable_xla` must be `False` in `kwargs['jax2tf_kwargs']` " + "when using the jax backend." + ) + if kwargs["jax2tf_kwargs"]["native_serialization"] is not False: + raise ValueError( + "`native_serialization` must be `False` in " + "`kwargs['jax2tf_kwargs']` when using the jax backend." + ) + return kwargs + + +def get_concrete_fn(model, input_signature, **kwargs): + """Get the `tf.function` associated with the model.""" + if backend.backend() == "jax": + kwargs = _check_jax_kwargs(kwargs) + export_archive = ExportArchive() + export_archive.track_and_add_endpoint( + DEFAULT_ENDPOINT_NAME, model, input_signature, **kwargs + ) + if backend.backend() == "tensorflow": + export_archive._filter_and_track_resources() + return export_archive._get_concrete_fn(DEFAULT_ENDPOINT_NAME) diff --git a/keras/src/export/onnx_test.py b/keras/src/export/onnx_test.py new file mode 100644 index 000000000000..d073f01bacd0 --- /dev/null +++ b/keras/src/export/onnx_test.py @@ -0,0 +1,300 @@ +"""Tests for ONNX exporting utilities.""" + +import os + +import numpy as np +import onnxruntime +import pytest +from absl.testing import parameterized + +from keras.src import backend +from keras.src import layers +from keras.src import models +from keras.src import ops +from keras.src import testing +from keras.src import tree +from keras.src.export import onnx +from keras.src.layers.input_spec import InputSpec as InputSpec +from keras.src.saving import saving_lib +from keras.src.testing.test_utils import named_product + + +class CustomModel(models.Model): + def __init__(self, layer_list): + super().__init__() + self.layer_list = layer_list + + def call(self, input): + output = input + for layer in self.layer_list: + output = layer(output) + return output + + +def get_model(type="sequential", input_shape=(10,), layer_list=None): + layer_list = layer_list or [ + layers.Dense(10, activation="relu"), + layers.BatchNormalization(), + layers.Dense(1, activation="sigmoid"), + ] + if type == "sequential": + return models.Sequential(layer_list) + elif type == "functional": + input = output = tree.map_shape_structure(layers.Input, input_shape) + for layer in layer_list: + output = layer(output) + return models.Model(inputs=input, outputs=output) + elif type == "subclass": + return CustomModel(layer_list) + elif type == "lstm": + # https://github.com/keras-team/keras/issues/21390 + inputs = layers.Input((4, 10)) + x = layers.Bidirectional( + layers.LSTM( + 10, + kernel_initializer="he_normal", + return_sequences=True, + kernel_regularizer=None, + ), + merge_mode="sum", + )(inputs) + outputs = layers.Bidirectional( + layers.LSTM( + 10, + kernel_initializer="he_normal", + return_sequences=True, + kernel_regularizer=None, + ), + merge_mode="concat", + )(x) + return models.Model(inputs=inputs, outputs=outputs) + + +@pytest.mark.skipif( + backend.backend() not in ("tensorflow", "jax", "torch"), + reason=( + "`export_onnx` only currently supports the tensorflow, jax and torch " + "backends." + ), +) +@pytest.mark.skipif(testing.jax_uses_gpu(), reason="Leads to core dumps on CI") +@pytest.mark.skipif( + testing.tensorflow_uses_gpu(), reason="Leads to core dumps on CI" +) +class ExportONNXTest(testing.TestCase): + @parameterized.named_parameters( + named_product( + model_type=["sequential", "functional", "subclass", "lstm"] + ) + ) + def test_standard_model_export(self, model_type): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + model = get_model(model_type) + batch_size = 3 if backend.backend() != "torch" else 1 + if model_type == "lstm": + ref_input = np.random.normal(size=(batch_size, 4, 10)) + else: + ref_input = np.random.normal(size=(batch_size, 10)) + ref_input = ref_input.astype("float32") + ref_output = model(ref_input) + + onnx.export_onnx(model, temp_filepath) + ort_session = onnxruntime.InferenceSession(temp_filepath) + ort_inputs = { + k.name: v for k, v in zip(ort_session.get_inputs(), [ref_input]) + } + self.assertAllClose(ref_output, ort_session.run(None, ort_inputs)[0]) + # Test with a different batch size + if backend.backend() == "torch": + # TODO: Dynamic shape is not supported yet in the torch backend + return + ort_inputs = { + k.name: v + for k, v in zip( + ort_session.get_inputs(), + [np.concatenate([ref_input, ref_input], axis=0)], + ) + } + ort_session.run(None, ort_inputs) + + @parameterized.named_parameters( + named_product(struct_type=["tuple", "array", "dict"]) + ) + def test_model_with_input_structure(self, struct_type): + if backend.backend() == "torch" and struct_type == "dict": + self.skipTest("The torch backend doesn't support the dict model.") + + class TupleModel(models.Model): + def call(self, inputs): + x, y = inputs + return ops.add(x, y) + + class ArrayModel(models.Model): + def call(self, inputs): + x = inputs[0] + y = inputs[1] + return ops.add(x, y) + + class DictModel(models.Model): + def call(self, inputs): + x = inputs["x"] + y = inputs["y"] + return ops.add(x, y) + + batch_size = 3 if backend.backend() != "torch" else 1 + ref_input = np.random.normal(size=(batch_size, 10)).astype("float32") + if struct_type == "tuple": + model = TupleModel() + ref_input = (ref_input, ref_input * 2) + elif struct_type == "array": + model = ArrayModel() + ref_input = [ref_input, ref_input * 2] + elif struct_type == "dict": + model = DictModel() + ref_input = {"x": ref_input, "y": ref_input * 2} + + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + ref_output = model(tree.map_structure(ops.convert_to_tensor, ref_input)) + + onnx.export_onnx(model, temp_filepath) + ort_session = onnxruntime.InferenceSession(temp_filepath) + if isinstance(ref_input, dict): + ort_inputs = { + k.name: v + for k, v in zip(ort_session.get_inputs(), ref_input.values()) + } + else: + ort_inputs = { + k.name: v for k, v in zip(ort_session.get_inputs(), ref_input) + } + self.assertAllClose(ref_output, ort_session.run(None, ort_inputs)[0]) + + # Test with keras.saving_lib + temp_filepath = os.path.join( + self.get_temp_dir(), "exported_model.keras" + ) + saving_lib.save_model(model, temp_filepath) + revived_model = saving_lib.load_model( + temp_filepath, + { + "TupleModel": TupleModel, + "ArrayModel": ArrayModel, + "DictModel": DictModel, + }, + ) + self.assertAllClose(ref_output, revived_model(ref_input)) + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model2") + onnx.export_onnx(revived_model, temp_filepath) + + # Test with a different batch size + if backend.backend() == "torch": + # TODO: Dynamic shape is not supported yet in the torch backend + return + bigger_ref_input = tree.map_structure( + lambda x: np.concatenate([x, x], axis=0), ref_input + ) + if isinstance(bigger_ref_input, dict): + bigger_ort_inputs = { + k.name: v + for k, v in zip( + ort_session.get_inputs(), bigger_ref_input.values() + ) + } + else: + bigger_ort_inputs = { + k.name: v + for k, v in zip(ort_session.get_inputs(), bigger_ref_input) + } + ort_session.run(None, bigger_ort_inputs) + + def test_model_with_multiple_inputs(self): + class TwoInputsModel(models.Model): + def call(self, x, y): + return x + y + + def build(self, y_shape, x_shape): + self.built = True + + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + model = TwoInputsModel() + batch_size = 3 if backend.backend() != "torch" else 1 + ref_input_x = np.random.normal(size=(batch_size, 10)).astype("float32") + ref_input_y = np.random.normal(size=(batch_size, 10)).astype("float32") + ref_output = model(ref_input_x, ref_input_y) + + onnx.export_onnx(model, temp_filepath) + ort_session = onnxruntime.InferenceSession(temp_filepath) + ort_inputs = { + k.name: v + for k, v in zip( + ort_session.get_inputs(), [ref_input_x, ref_input_y] + ) + } + self.assertAllClose(ref_output, ort_session.run(None, ort_inputs)[0]) + # Test with a different batch size + if backend.backend() == "torch": + # TODO: Dynamic shape is not supported yet in the torch backend + return + ort_inputs = { + k.name: v + for k, v in zip( + ort_session.get_inputs(), + [ + np.concatenate([ref_input_x, ref_input_x], axis=0), + np.concatenate([ref_input_y, ref_input_y], axis=0), + ], + ) + } + ort_session.run(None, ort_inputs) + + @parameterized.named_parameters(named_product(opset_version=[None, 18])) + def test_export_with_opset_version(self, opset_version): + import onnx as onnx_lib + + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + model = get_model("sequential") + batch_size = 3 if backend.backend() != "torch" else 1 + ref_input = np.random.normal(size=(batch_size, 10)) + ref_input = ref_input.astype("float32") + ref_output = model(ref_input) + + onnx.export_onnx( + model, temp_filepath, opset_version=opset_version, verbose=True + ) + ort_session = onnxruntime.InferenceSession(temp_filepath) + ort_inputs = { + k.name: v for k, v in zip(ort_session.get_inputs(), [ref_input]) + } + self.assertAllClose(ref_output, ort_session.run(None, ort_inputs)[0]) + + if opset_version is not None: + onnx_model = onnx_lib.load(temp_filepath) + self.assertEqual(onnx_model.opset_import[0].version, opset_version) + + def test_export_with_input_names(self): + """Test ONNX export uses InputSpec.name for input names.""" + import onnx as onnx_lib + + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + model = get_model("sequential") + batch_size = 3 if backend.backend() != "torch" else 1 + ref_input = np.random.normal(size=(batch_size, 10)).astype("float32") + ref_output = model(ref_input) + + # Test with custom input name + input_spec = [ + InputSpec( + name="custom_input", shape=(batch_size, 10), dtype="float32" + ) + ] + onnx.export_onnx(model, temp_filepath, input_signature=input_spec) + + onnx_model = onnx_lib.load(temp_filepath) + input_names = [input.name for input in onnx_model.graph.input] + self.assertIn("custom_input", input_names) + + ort_session = onnxruntime.InferenceSession(temp_filepath) + ort_inputs = { + k.name: v for k, v in zip(ort_session.get_inputs(), [ref_input]) + } + self.assertAllClose(ref_output, ort_session.run(None, ort_inputs)[0]) diff --git a/keras/src/export/openvino.py b/keras/src/export/openvino.py new file mode 100644 index 000000000000..bdd4b5c5a82e --- /dev/null +++ b/keras/src/export/openvino.py @@ -0,0 +1,204 @@ +import warnings + +from keras.src import backend +from keras.src import tree +from keras.src.export.export_utils import convert_spec_to_tensor +from keras.src.export.export_utils import get_input_signature +from keras.src.export.export_utils import make_tf_tensor_spec +from keras.src.export.saved_model import DEFAULT_ENDPOINT_NAME +from keras.src.export.saved_model import ExportArchive +from keras.src.utils import io_utils + + +def export_openvino( + model, filepath, verbose=None, input_signature=None, **kwargs +): + """Export the model as an OpenVINO IR artifact for inference. + + This method exports the model to the OpenVINO IR format, + which includes two files: + a `.xml` file containing the model structure and a `.bin` file + containing the weights. + The exported model contains only the forward pass + (i.e., the model's `call()` method), and can be deployed with the + OpenVINO Runtime for fast inference on CPU and other Intel hardware. + + Args: + filepath: `str` or `pathlib.Path`. Path to the output `.xml` file. + The corresponding `.bin` file will be saved alongside it. + verbose: Optional `bool`. Whether to print a confirmation message + after export. If `None`, it uses the default verbosity configured + by the backend. + input_signature: Optional. Specifies the shape and dtype of the + model inputs. If not provided, it will be inferred. + **kwargs: Additional keyword arguments. + + Example: + + ```python + import keras + + # Define or load a Keras model + model = keras.models.Sequential([ + keras.layers.Input(shape=(128,)), + keras.layers.Dense(64, activation="relu"), + keras.layers.Dense(10) + ]) + + # Export to OpenVINO IR + model.export("model.xml", format="openvino") + ``` + """ + assert filepath.endswith(".xml"), ( + "The OpenVINO export requires the filepath to end with '.xml'. " + f"Got: {filepath}" + ) + + import openvino as ov + from openvino.runtime import opset14 as ov_opset + + from keras.src.backend.openvino.core import OPENVINO_DTYPES + from keras.src.backend.openvino.core import OpenVINOKerasTensor + + actual_verbose = verbose if verbose is not None else True + + if input_signature is None: + input_signature = get_input_signature(model) + + if backend.backend() == "openvino": + import inspect + + def parameterize_inputs(inputs, prefix=""): + if isinstance(inputs, (list, tuple)): + return [ + parameterize_inputs(e, f"{prefix}{i}") + for i, e in enumerate(inputs) + ] + elif isinstance(inputs, dict): + return {k: parameterize_inputs(v, k) for k, v in inputs.items()} + elif isinstance(inputs, OpenVINOKerasTensor): + ov_type = OPENVINO_DTYPES[str(inputs.dtype)] + ov_shape = list(inputs.shape) + param = ov_opset.parameter(shape=ov_shape, dtype=ov_type) + param.set_friendly_name(prefix) + return OpenVINOKerasTensor(param.output(0)) + else: + raise TypeError(f"Unknown input type: {type(inputs)}") + + if isinstance(input_signature, list) and len(input_signature) == 1: + input_signature = input_signature[0] + + sample_inputs = tree.map_structure( + lambda x: convert_spec_to_tensor(x, replace_none_number=1), + input_signature, + ) + params = parameterize_inputs(sample_inputs) + signature = inspect.signature(model.call) + if len(signature.parameters) > 1 and isinstance(params, (list, tuple)): + outputs = model(*params) + else: + outputs = model(params) + parameters = [p.output.get_node() for p in tree.flatten(params)] + results = [ov_opset.result(r.output) for r in tree.flatten(outputs)] + ov_model = ov.Model(results=results, parameters=parameters) + flat_specs = tree.flatten(input_signature) + for ov_input, spec in zip(ov_model.inputs, flat_specs): + # Respect the dynamic axes from the original input signature. + dynamic_shape_dims = [ + -1 if dim is None else dim for dim in spec.shape + ] + dynamic_shape = ov.PartialShape(dynamic_shape_dims) + ov_input.get_node().set_partial_shape(dynamic_shape) + + elif backend.backend() in ("tensorflow", "jax"): + inputs = tree.map_structure(make_tf_tensor_spec, input_signature) + decorated_fn = get_concrete_fn(model, inputs, **kwargs) + ov_model = ov.convert_model(decorated_fn) + set_names(ov_model, inputs) + elif backend.backend() == "torch": + import torch + + sample_inputs = tree.map_structure( + lambda x: convert_spec_to_tensor(x, replace_none_number=1), + input_signature, + ) + sample_inputs = tuple(sample_inputs) + if hasattr(model, "eval"): + model.eval() + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) + traced = torch.jit.trace(model, sample_inputs) + ov_model = ov.convert_model(traced) + set_names(ov_model, sample_inputs) + else: + raise NotImplementedError( + "`export_openvino` is only compatible with OpenVINO, " + "TensorFlow, JAX and Torch backends." + ) + + ov.serialize(ov_model, filepath) + + if actual_verbose: + io_utils.print_msg(f"Saved OpenVINO IR at '{filepath}'.") + + +def collect_names(structure): + if isinstance(structure, dict): + for k, v in structure.items(): + if isinstance(v, (dict, list, tuple)): + yield from collect_names(v) + else: + yield k + elif isinstance(structure, (list, tuple)): + for v in structure: + yield from collect_names(v) + else: + if hasattr(structure, "name") and structure.name: + yield structure.name + else: + yield "input" + + +def set_names(model, inputs): + names = list(collect_names(inputs)) + for ov_input, name in zip(model.inputs, names): + ov_input.get_node().set_friendly_name(name) + ov_input.tensor.set_names({name}) + + +def _check_jax_kwargs(kwargs): + kwargs = kwargs.copy() + if "is_static" not in kwargs: + kwargs["is_static"] = True + if "jax2tf_kwargs" not in kwargs: + kwargs["jax2tf_kwargs"] = { + "enable_xla": False, + "native_serialization": False, + } + if kwargs["is_static"] is not True: + raise ValueError( + "`is_static` must be `True` in `kwargs` when using the jax backend." + ) + if kwargs["jax2tf_kwargs"]["enable_xla"] is not False: + raise ValueError( + "`enable_xla` must be `False` in `kwargs['jax2tf_kwargs']` " + "when using the jax backend." + ) + if kwargs["jax2tf_kwargs"]["native_serialization"] is not False: + raise ValueError( + "`native_serialization` must be `False` in " + "`kwargs['jax2tf_kwargs']` when using the jax backend." + ) + return kwargs + + +def get_concrete_fn(model, input_signature, **kwargs): + if backend.backend() == "jax": + kwargs = _check_jax_kwargs(kwargs) + export_archive = ExportArchive() + export_archive.track_and_add_endpoint( + DEFAULT_ENDPOINT_NAME, model, input_signature, **kwargs + ) + if backend.backend() == "tensorflow": + export_archive._filter_and_track_resources() + return export_archive._get_concrete_fn(DEFAULT_ENDPOINT_NAME) diff --git a/keras/src/export/openvino_test.py b/keras/src/export/openvino_test.py new file mode 100644 index 000000000000..51b9f46cf1ad --- /dev/null +++ b/keras/src/export/openvino_test.py @@ -0,0 +1,229 @@ +import os + +import numpy as np +import pytest +from absl.testing import parameterized + +from keras.src import backend +from keras.src import layers +from keras.src import models +from keras.src import ops +from keras.src import testing +from keras.src import tree +from keras.src.export import openvino +from keras.src.saving import saving_lib +from keras.src.testing.test_utils import named_product + +try: + import openvino as ov +except ImportError: + ov = None + + +class CustomModel(models.Model): + def __init__(self, layer_list): + super().__init__() + self.layer_list = layer_list + + def call(self, input): + output = input + for layer in self.layer_list: + output = layer(output) + return output + + +def get_model(type="sequential", input_shape=(10,), layer_list=None): + layer_list = layer_list or [ + layers.Dense(10, activation="relu"), + layers.BatchNormalization(), + layers.Dense(1, activation="sigmoid"), + ] + if type == "sequential": + return models.Sequential(layer_list) + elif type == "functional": + input = output = tree.map_shape_structure(layers.Input, input_shape) + for layer in layer_list: + output = layer(output) + return models.Model(inputs=input, outputs=output) + elif type == "subclass": + return CustomModel(layer_list) + elif type == "lstm": + # https://github.com/keras-team/keras/issues/21390 + inputs = layers.Input((4, 10)) + x = layers.Bidirectional( + layers.LSTM( + 10, + kernel_initializer="he_normal", + return_sequences=True, + kernel_regularizer=None, + ), + merge_mode="sum", + )(inputs) + outputs = layers.Bidirectional( + layers.LSTM( + 10, + kernel_initializer="he_normal", + return_sequences=True, + kernel_regularizer=None, + ), + merge_mode="concat", + )(x) + return models.Model(inputs=inputs, outputs=outputs) + + +@pytest.mark.skipif(ov is None, reason="OpenVINO is not installed") +@pytest.mark.skipif( + backend.backend() not in ("tensorflow", "openvino", "jax", "torch"), + reason=( + "`export_openvino` only currently supports" + "the tensorflow, jax, torch and openvino backends." + ), +) +@pytest.mark.skipif(testing.jax_uses_gpu(), reason="Leads to core dumps on CI") +@pytest.mark.skipif( + testing.tensorflow_uses_gpu(), reason="Leads to core dumps on CI" +) +class ExportOpenVINOTest(testing.TestCase): + @parameterized.named_parameters( + named_product( + model_type=["sequential", "functional", "subclass", "lstm"] + ) + ) + def test_standard_model_export(self, model_type): + if model_type == "lstm": + self.skipTest( + "LSTM export not supported - unimplemented QR operation" + ) + + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model.xml") + model = get_model(model_type) + batch_size = 3 + if model_type == "lstm": + ref_input = np.random.normal(size=(batch_size, 4, 10)) + else: + ref_input = np.random.normal(size=(batch_size, 10)) + ref_input = ref_input.astype("float32") + ref_output = model(ref_input) + + openvino.export_openvino(model, temp_filepath) + + # Load and run inference with OpenVINO + core = ov.Core() + ov_model = core.read_model(temp_filepath) + compiled_model = core.compile_model(ov_model, "CPU") + + ov_output = compiled_model([ref_input])[compiled_model.output(0)] + + self.assertAllClose(ref_output, ov_output) + + larger_input = np.concatenate([ref_input, ref_input], axis=0) + compiled_model([larger_input]) + + @parameterized.named_parameters( + named_product(struct_type=["tuple", "array", "dict"]) + ) + def test_model_with_input_structure(self, struct_type): + class TupleModel(models.Model): + def call(self, inputs): + x, y = inputs + return ops.add(x, y) + + class ArrayModel(models.Model): + def call(self, inputs): + x = inputs[0] + y = inputs[1] + return ops.add(x, y) + + class DictModel(models.Model): + def call(self, inputs): + x = inputs["x"] + y = inputs["y"] + return ops.add(x, y) + + batch_size = 3 + ref_input = np.random.normal(size=(batch_size, 10)).astype("float32") + if struct_type == "tuple": + model = TupleModel() + ref_input = (ref_input, ref_input * 2) + elif struct_type == "array": + model = ArrayModel() + ref_input = [ref_input, ref_input * 2] + elif struct_type == "dict": + model = DictModel() + ref_input = {"x": ref_input, "y": ref_input * 2} + + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model.xml") + ref_output = model(tree.map_structure(ops.convert_to_tensor, ref_input)) + + openvino.export_openvino(model, temp_filepath) + + # Load and run inference with OpenVINO + core = ov.Core() + ov_model = core.read_model(temp_filepath) + compiled_model = core.compile_model(ov_model, "CPU") + + if isinstance(ref_input, dict): + ov_inputs = [ref_input[key] for key in ref_input.keys()] + else: + ov_inputs = list(ref_input) + + ov_output = compiled_model(ov_inputs)[compiled_model.output(0)] + self.assertAllClose(ref_output, ov_output) + + # Test with keras.saving_lib + temp_filepath = os.path.join( + self.get_temp_dir(), "exported_model.keras" + ) + saving_lib.save_model(model, temp_filepath) + revived_model = saving_lib.load_model( + temp_filepath, + { + "TupleModel": TupleModel, + "ArrayModel": ArrayModel, + "DictModel": DictModel, + }, + ) + self.assertAllClose(ref_output, revived_model(ref_input)) + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model2.xml") + openvino.export_openvino(revived_model, temp_filepath) + + bigger_ref_input = tree.map_structure( + lambda x: np.concatenate([x, x], axis=0), ref_input + ) + if isinstance(bigger_ref_input, dict): + bigger_ov_inputs = [ + bigger_ref_input[key] for key in bigger_ref_input.keys() + ] + else: + bigger_ov_inputs = list(bigger_ref_input) + compiled_model(bigger_ov_inputs) + + def test_model_with_multiple_inputs(self): + class TwoInputsModel(models.Model): + def call(self, x, y): + return x + y + + def build(self, y_shape, x_shape): + self.built = True + + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model.xml") + model = TwoInputsModel() + batch_size = 3 + ref_input_x = np.random.normal(size=(batch_size, 10)).astype("float32") + ref_input_y = np.random.normal(size=(batch_size, 10)).astype("float32") + ref_output = model(ref_input_x, ref_input_y) + + openvino.export_openvino(model, temp_filepath) + + # Load and run inference with OpenVINO + core = ov.Core() + ov_model = core.read_model(temp_filepath) + compiled_model = core.compile_model(ov_model, "CPU") + + ov_output = compiled_model([ref_input_x, ref_input_y])[ + compiled_model.output(0) + ] + self.assertAllClose(ref_output, ov_output) + larger_input_x = np.concatenate([ref_input_x, ref_input_x], axis=0) + larger_input_y = np.concatenate([ref_input_y, ref_input_y], axis=0) + compiled_model([larger_input_x, larger_input_y]) diff --git a/keras/src/export/saved_model.py b/keras/src/export/saved_model.py new file mode 100644 index 000000000000..d5009a7ec4a6 --- /dev/null +++ b/keras/src/export/saved_model.py @@ -0,0 +1,693 @@ +"""Library for exporting SavedModel for Keras models/layers.""" + +from keras.src import backend +from keras.src import layers +from keras.src import tree +from keras.src.api_export import keras_export +from keras.src.export.export_utils import get_input_signature +from keras.src.export.export_utils import make_tf_tensor_spec +from keras.src.utils import io_utils +from keras.src.utils.module_utils import tensorflow as tf + +if backend.backend() == "tensorflow": + from keras.src.backend.tensorflow.export import ( + TFExportArchive as BackendExportArchive, + ) +elif backend.backend() == "jax": + from keras.src.backend.jax.export import ( + JaxExportArchive as BackendExportArchive, + ) +elif backend.backend() == "torch": + from keras.src.backend.torch.export import ( + TorchExportArchive as BackendExportArchive, + ) +elif backend.backend() == "numpy": + from keras.src.backend.numpy.export import ( + NumpyExportArchive as BackendExportArchive, + ) +elif backend.backend() == "openvino": + from keras.src.backend.openvino.export import ( + OpenvinoExportArchive as BackendExportArchive, + ) +else: + raise RuntimeError( + f"Backend '{backend.backend()}' must implement ExportArchive." + ) + + +DEFAULT_ENDPOINT_NAME = "serve" + + +@keras_export("keras.export.ExportArchive") +class ExportArchive(BackendExportArchive): + """ExportArchive is used to write SavedModel artifacts (e.g. for inference). + + If you have a Keras model or layer that you want to export as SavedModel for + serving (e.g. via TensorFlow-Serving), you can use `ExportArchive` + to configure the different serving endpoints you need to make available, + as well as their signatures. Simply instantiate an `ExportArchive`, + use `track()` to register the layer(s) or model(s) to be used, + then use the `add_endpoint()` method to register a new serving endpoint. + When done, use the `write_out()` method to save the artifact. + + The resulting artifact is a SavedModel and can be reloaded via + `tf.saved_model.load`. + + Examples: + + Here's how to export a model for inference. + + ```python + export_archive = ExportArchive() + export_archive.track(model) + export_archive.add_endpoint( + name="serve", + fn=model.call, + input_signature=[keras.InputSpec(shape=(None, 3), dtype="float32")], + ) + export_archive.write_out("path/to/location") + + # Elsewhere, we can reload the artifact and serve it. + # The endpoint we added is available as a method: + serving_model = tf.saved_model.load("path/to/location") + outputs = serving_model.serve(inputs) + ``` + + Here's how to export a model with one endpoint for inference and one + endpoint for a training-mode forward pass (e.g. with dropout on). + + ```python + export_archive = ExportArchive() + export_archive.track(model) + export_archive.add_endpoint( + name="call_inference", + fn=lambda x: model.call(x, training=False), + input_signature=[keras.InputSpec(shape=(None, 3), dtype="float32")], + ) + export_archive.add_endpoint( + name="call_training", + fn=lambda x: model.call(x, training=True), + input_signature=[keras.InputSpec(shape=(None, 3), dtype="float32")], + ) + export_archive.write_out("path/to/location") + ``` + + **Note on resource tracking:** + + `ExportArchive` is able to automatically track all `keras.Variables` used + by its endpoints, so most of the time calling `.track(model)` + is not strictly required. However, if your model uses lookup layers such + as `IntegerLookup`, `StringLookup`, or `TextVectorization`, + it will need to be tracked explicitly via `.track(model)`. + + Explicit tracking is also required if you need to be able to access + the properties `variables`, `trainable_variables`, or + `non_trainable_variables` on the revived archive. + """ + + def __init__(self): + super().__init__() + if backend.backend() not in ("tensorflow", "jax", "torch"): + raise NotImplementedError( + "`ExportArchive` is only compatible with TensorFlow, JAX and " + "Torch backends." + ) + + self._endpoint_names = [] + self._endpoint_signatures = {} + self.tensorflow_version = tf.__version__ + + self._tf_trackable = tf.__internal__.tracking.AutoTrackable() + self._tf_trackable.variables = [] + self._tf_trackable.trainable_variables = [] + self._tf_trackable.non_trainable_variables = [] + + @property + def variables(self): + return self._tf_trackable.variables + + @property + def trainable_variables(self): + return self._tf_trackable.trainable_variables + + @property + def non_trainable_variables(self): + return self._tf_trackable.non_trainable_variables + + def track(self, resource): + """Track the variables (of a layer or model) and other assets. + + By default, all variables used by an endpoint function are automatically + tracked when you call `add_endpoint()`. However, non-variables assets + such as lookup tables need to be tracked manually. Note that lookup + tables used by built-in Keras layers (`TextVectorization`, + `IntegerLookup`, `StringLookup`) are automatically tracked by + `add_endpoint()`. + + Args: + resource: A layer, model or a TensorFlow trackable resource. + """ + if isinstance(resource, layers.Layer) and not resource.built: + raise ValueError( + "The layer provided has not yet been built. " + "It must be built before export." + ) + + # Note: with the TensorFlow backend, Layers and Models fall into both + # the Layer case and the Trackable case. The Trackable case is needed + # for preprocessing layers in order to track lookup tables. + if isinstance(resource, tf.__internal__.tracking.Trackable): + if not hasattr(self, "_tracked"): + self._tracked = [] + self._tracked.append(resource) + + if isinstance(resource, layers.Layer): + self._track_layer(resource) + elif not isinstance(resource, tf.__internal__.tracking.Trackable): + raise ValueError( + "Invalid resource type. Expected a Keras `Layer` or `Model` " + "or a TensorFlow `Trackable` object. " + f"Received object {resource} of type '{type(resource)}'. " + ) + + def add_endpoint(self, name, fn, input_signature=None, **kwargs): + """Register a new serving endpoint. + + Args: + name: `str`. The name of the endpoint. + fn: A callable. It should only leverage resources + (e.g. `keras.Variable` objects or `tf.lookup.StaticHashTable` + objects) that are available on the models/layers tracked by the + `ExportArchive` (you can call `.track(model)` to track a new + model). + The shape and dtype of the inputs to the function must be + known. For that purpose, you can either 1) make sure that `fn` + is a `tf.function` that has been called at least once, or 2) + provide an `input_signature` argument that specifies the shape + and dtype of the inputs (see below). + input_signature: Optional. Specifies the shape and dtype of `fn`. + Can be a structure of `keras.InputSpec`, `tf.TensorSpec`, + `backend.KerasTensor`, or backend tensor (see below for an + example showing a `Functional` model with 2 input arguments). If + not provided, `fn` must be a `tf.function` that has been called + at least once. Defaults to `None`. + **kwargs: Additional keyword arguments: + - Specific to the JAX backend: + - `is_static`: Optional `bool`. Indicates whether `fn` is + static. Set to `False` if `fn` involves state updates + (e.g., RNG seeds). + - `jax2tf_kwargs`: Optional `dict`. Arguments for + `jax2tf.convert`. See [`jax2tf.convert`]( + https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md). + If `native_serialization` and `polymorphic_shapes` are + not provided, they are automatically computed. + + Returns: + The `tf.function` wrapping `fn` that was added to the archive. + + Example: + + Adding an endpoint using the `input_signature` argument when the + model has a single input argument: + + ```python + export_archive = ExportArchive() + export_archive.track(model) + export_archive.add_endpoint( + name="serve", + fn=model.call, + input_signature=[keras.InputSpec(shape=(None, 3), dtype="float32")], + ) + ``` + + Adding an endpoint using the `input_signature` argument when the + model has two positional input arguments: + + ```python + export_archive = ExportArchive() + export_archive.track(model) + export_archive.add_endpoint( + name="serve", + fn=model.call, + input_signature=[ + keras.InputSpec(shape=(None, 3), dtype="float32"), + keras.InputSpec(shape=(None, 4), dtype="float32"), + ], + ) + ``` + + Adding an endpoint using the `input_signature` argument when the + model has one input argument that is a list of 2 tensors (e.g. + a Functional model with 2 inputs): + + ```python + model = keras.Model(inputs=[x1, x2], outputs=outputs) + + export_archive = ExportArchive() + export_archive.track(model) + export_archive.add_endpoint( + name="serve", + fn=model.call, + input_signature=[ + [ + keras.InputSpec(shape=(None, 3), dtype="float32"), + keras.InputSpec(shape=(None, 4), dtype="float32"), + ], + ], + ) + ``` + + This also works with dictionary inputs: + + ```python + model = keras.Model(inputs={"x1": x1, "x2": x2}, outputs=outputs) + + export_archive = ExportArchive() + export_archive.track(model) + export_archive.add_endpoint( + name="serve", + fn=model.call, + input_signature=[ + { + "x1": keras.InputSpec(shape=(None, 3), dtype="float32"), + "x2": keras.InputSpec(shape=(None, 4), dtype="float32"), + }, + ], + ) + ``` + + Adding an endpoint that is a `tf.function`: + + ```python + @tf.function() + def serving_fn(x): + return model(x) + + # The function must be traced, i.e. it must be called at least once. + serving_fn(tf.random.normal(shape=(2, 3))) + + export_archive = ExportArchive() + export_archive.track(model) + export_archive.add_endpoint(name="serve", fn=serving_fn) + ``` + + Combining a model with some TensorFlow preprocessing, which can use + TensorFlow resources: + + ```python + lookup_table = tf.lookup.StaticHashTable(initializer, default_value=0.0) + + export_archive = ExportArchive() + model_fn = export_archive.track_and_add_endpoint( + "model_fn", + model, + input_signature=[tf.TensorSpec(shape=(None, 5), dtype=tf.float32)], + ) + export_archive.track(lookup_table) + + @tf.function() + def serving_fn(x): + x = lookup_table.lookup(x) + return model_fn(x) + + export_archive.add_endpoint(name="serve", fn=serving_fn) + ``` + """ + if name in self._endpoint_names: + raise ValueError(f"Endpoint name '{name}' is already taken.") + + if backend.backend() != "jax": + if "jax2tf_kwargs" in kwargs or "is_static" in kwargs: + raise ValueError( + "'jax2tf_kwargs' and 'is_static' are only supported with " + f"the jax backend. Current backend: {backend.backend()}" + ) + + # The fast path if `fn` is already a `tf.function`. + if input_signature is None: + if isinstance(fn, tf.types.experimental.GenericFunction): + if not fn._list_all_concrete_functions(): + raise ValueError( + f"The provided tf.function '{fn}' " + "has never been called. " + "To specify the expected shape and dtype " + "of the function's arguments, " + "you must either provide a function that " + "has been called at least once, or alternatively pass " + "an `input_signature` argument in `add_endpoint()`." + ) + decorated_fn = fn + else: + raise ValueError( + "If the `fn` argument provided is not a `tf.function`, " + "you must provide an `input_signature` argument to " + "specify the shape and dtype of the function arguments. " + "Example:\n\n" + "export_archive.add_endpoint(\n" + " name='call',\n" + " fn=model.call,\n" + " input_signature=[\n" + " keras.InputSpec(\n" + " shape=(None, 224, 224, 3),\n" + " dtype='float32',\n" + " )\n" + " ],\n" + ")" + ) + setattr(self._tf_trackable, name, decorated_fn) + self._endpoint_names.append(name) + return decorated_fn + + input_signature = tree.map_structure( + make_tf_tensor_spec, input_signature + ) + decorated_fn = super().add_endpoint(name, fn, input_signature, **kwargs) + self._endpoint_signatures[name] = input_signature + setattr(self._tf_trackable, name, decorated_fn) + self._endpoint_names.append(name) + return decorated_fn + + def track_and_add_endpoint(self, name, resource, input_signature, **kwargs): + """Track the variables and register a new serving endpoint. + + This function combines the functionality of `track` and `add_endpoint`. + It tracks the variables of the `resource` (either a layer or a model) + and registers a serving endpoint using `resource.__call__`. + + Args: + name: `str`. The name of the endpoint. + resource: A trackable Keras resource, such as a layer or model. + input_signature: Optional. Specifies the shape and dtype of `fn`. + Can be a structure of `keras.InputSpec`, `tf.TensorSpec`, + `backend.KerasTensor`, or backend tensor (see below for an + example showing a `Functional` model with 2 input arguments). If + not provided, `fn` must be a `tf.function` that has been called + at least once. Defaults to `None`. + **kwargs: Additional keyword arguments: + - Specific to the JAX backend: + - `is_static`: Optional `bool`. Indicates whether `fn` is + static. Set to `False` if `fn` involves state updates + (e.g., RNG seeds). + - `jax2tf_kwargs`: Optional `dict`. Arguments for + `jax2tf.convert`. See [`jax2tf.convert`]( + https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md). + If `native_serialization` and `polymorphic_shapes` are + not provided, they are automatically computed. + + """ + if name in self._endpoint_names: + raise ValueError(f"Endpoint name '{name}' is already taken.") + if not isinstance(resource, layers.Layer): + raise ValueError( + "Invalid resource type. Expected an instance of a Keras " + "`Layer` or `Model`. " + f"Received: resource={resource} (of type {type(resource)})" + ) + if not resource.built: + raise ValueError( + "The layer provided has not yet been built. " + "It must be built before export." + ) + if backend.backend() != "jax": + if "jax2tf_kwargs" in kwargs or "is_static" in kwargs: + raise ValueError( + "'jax2tf_kwargs' and 'is_static' are only supported with " + f"the jax backend. Current backend: {backend.backend()}" + ) + + input_signature = tree.map_structure( + make_tf_tensor_spec, input_signature + ) + + if not hasattr(BackendExportArchive, "track_and_add_endpoint"): + # Default behavior. + self.track(resource) + return self.add_endpoint( + name, resource.__call__, input_signature, **kwargs + ) + else: + # Special case for the torch backend. + decorated_fn = super().track_and_add_endpoint( + name, resource, input_signature, **kwargs + ) + self._endpoint_signatures[name] = input_signature + setattr(self._tf_trackable, name, decorated_fn) + self._endpoint_names.append(name) + return decorated_fn + + def add_variable_collection(self, name, variables): + """Register a set of variables to be retrieved after reloading. + + Arguments: + name: The string name for the collection. + variables: A tuple/list/set of `keras.Variable` instances. + + Example: + + ```python + export_archive = ExportArchive() + export_archive.track(model) + # Register an endpoint + export_archive.add_endpoint( + name="serve", + fn=model.call, + input_signature=[keras.InputSpec(shape=(None, 3), dtype="float32")], + ) + # Save a variable collection + export_archive.add_variable_collection( + name="optimizer_variables", variables=model.optimizer.variables) + export_archive.write_out("path/to/location") + + # Reload the object + revived_object = tf.saved_model.load("path/to/location") + # Retrieve the variables + optimizer_variables = revived_object.optimizer_variables + ``` + """ + if not isinstance(variables, (list, tuple, set)): + raise ValueError( + "Expected `variables` to be a list/tuple/set. " + f"Received instead object of type '{type(variables)}'." + ) + # Ensure that all variables added are either tf.Variables + # or Variables created by Keras 3 with the TF or JAX backends. + if not all( + isinstance(v, (tf.Variable, backend.Variable)) for v in variables + ): + raise ValueError( + "Expected all elements in `variables` to be " + "`tf.Variable` instances. Found instead the following types: " + f"{list(set(type(v) for v in variables))}" + ) + if backend.backend() == "jax": + variables = tree.flatten( + tree.map_structure(self._convert_to_tf_variable, variables) + ) + setattr(self._tf_trackable, name, list(variables)) + + def write_out(self, filepath, options=None, verbose=True): + """Write the corresponding SavedModel to disk. + + Arguments: + filepath: `str` or `pathlib.Path` object. + Path where to save the artifact. + options: `tf.saved_model.SaveOptions` object that specifies + SavedModel saving options. + verbose: whether to print all the variables of an + exported SavedModel. + + **Note on TF-Serving**: all endpoints registered via `add_endpoint()` + are made visible for TF-Serving in the SavedModel artifact. In addition, + the first endpoint registered is made visible under the alias + `"serving_default"` (unless an endpoint with the name + `"serving_default"` was already registered manually), + since TF-Serving requires this endpoint to be set. + """ + if not self._endpoint_names: + raise ValueError( + "No endpoints have been set yet. Call add_endpoint()." + ) + self._filter_and_track_resources() + + signatures = {} + for name in self._endpoint_names: + signatures[name] = self._get_concrete_fn(name) + # Add "serving_default" signature key for TFServing + if "serving_default" not in self._endpoint_names: + signatures["serving_default"] = self._get_concrete_fn( + self._endpoint_names[0] + ) + + tf.saved_model.save( + self._tf_trackable, + filepath, + options=options, + signatures=signatures, + ) + + # Print out available endpoints + if verbose: + endpoints = "\n\n".join( + _print_signature( + getattr(self._tf_trackable, name), name, verbose=verbose + ) + for name in self._endpoint_names + ) + io_utils.print_msg( + f"Saved artifact at '{filepath}'. " + "The following endpoints are available:\n\n" + f"{endpoints}" + ) + + def _convert_to_tf_variable(self, backend_variable): + if not isinstance(backend_variable, backend.Variable): + raise TypeError( + "`backend_variable` must be a `backend.Variable`. " + f"Recevied: backend_variable={backend_variable} of type " + f"({type(backend_variable)})" + ) + return tf.Variable( + backend_variable.value, + dtype=backend_variable.dtype, + trainable=backend_variable.trainable, + name=backend_variable.name, + ) + + def _get_concrete_fn(self, endpoint): + """Workaround for some SavedModel quirks.""" + if endpoint in self._endpoint_signatures: + return getattr(self._tf_trackable, endpoint) + else: + traces = getattr(self._tf_trackable, endpoint)._trackable_children( + "saved_model" + ) + return list(traces.values())[0] + + def _get_variables_used_by_endpoints(self): + fns = [self._get_concrete_fn(name) for name in self._endpoint_names] + return _list_variables_used_by_fns(fns) + + def _filter_and_track_resources(self): + """Track resources used by endpoints / referenced in `track()` calls.""" + # Start by extracting variables from endpoints. + fns = [self._get_concrete_fn(name) for name in self._endpoint_names] + tvs, ntvs = _list_variables_used_by_fns(fns) + self._tf_trackable._all_variables = list(tvs + ntvs) + + # Next, track lookup tables. + # Hopefully, one day this will be automated at the tf.function level. + self._tf_trackable._misc_assets = [] + from tensorflow.saved_model.experimental import TrackableResource + + if hasattr(self, "_tracked"): + for root in self._tracked: + descendants = tf.train.TrackableView(root).descendants() + for trackable in descendants: + if isinstance(trackable, TrackableResource): + self._tf_trackable._misc_assets.append(trackable) + + +def export_saved_model( + model, filepath, verbose=None, input_signature=None, **kwargs +): + """Export the model as a TensorFlow SavedModel artifact for inference. + + This method lets you export a model to a lightweight SavedModel artifact + that contains the model's forward pass only (its `call()` method) + and can be served via e.g. TensorFlow Serving. The forward pass is + registered under the name `serve()` (see example below). + + The original code of the model (including any custom layers you may + have used) is *no longer* necessary to reload the artifact -- it is + entirely standalone. + + Args: + filepath: `str` or `pathlib.Path` object. The path to save the artifact. + verbose: `bool`. Whether to print a message during export. Defaults to + `None`, which uses the default value set by different backends and + formats. + input_signature: Optional. Specifies the shape and dtype of the model + inputs. Can be a structure of `keras.InputSpec`, `tf.TensorSpec`, + `backend.KerasTensor`, or backend tensor. If not provided, it will + be automatically computed. Defaults to `None`. + **kwargs: Additional keyword arguments: + - Specific to the JAX backend: + - `is_static`: Optional `bool`. Indicates whether `fn` is + static. Set to `False` if `fn` involves state updates + (e.g., RNG seeds). + - `jax2tf_kwargs`: Optional `dict`. Arguments for + `jax2tf.convert`. See [`jax2tf.convert`]( + https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md). + If `native_serialization` and `polymorphic_shapes` are not + provided, they are automatically computed. + + **Note:** This feature is currently supported only with TensorFlow, JAX and + Torch backends. Support for the Torch backend is experimental. + + **Note:** The dynamic shape feature is not yet supported with Torch + backend. As a result, you must fully define the shapes of the inputs using + `input_signature`. If `input_signature` is not provided, all instances of + `None` (such as the batch size) will be replaced with `1`. + + Example: + + ```python + # Export the model as a TensorFlow SavedModel artifact + model.export("path/to/location", format="tf_saved_model") + + # Load the artifact in a different process/environment + reloaded_artifact = tf.saved_model.load("path/to/location") + predictions = reloaded_artifact.serve(input_data) + ``` + + If you would like to customize your serving endpoints, you can + use the lower-level `keras.export.ExportArchive` class. The + `export()` method relies on `ExportArchive` internally. + """ + if verbose is None: + verbose = True # Defaults to `True` for all backends. + export_archive = ExportArchive() + if input_signature is None: + input_signature = get_input_signature(model) + + export_archive.track_and_add_endpoint( + DEFAULT_ENDPOINT_NAME, model, input_signature, **kwargs + ) + export_archive.write_out(filepath, verbose=verbose) + + +def _print_signature(fn, name, verbose=True): + concrete_fn = fn._list_all_concrete_functions()[0] + pprinted_signature = concrete_fn.pretty_printed_signature(verbose=verbose) + lines = pprinted_signature.split("\n") + lines = [f"* Endpoint '{name}'"] + lines[1:] + endpoint = "\n".join(lines) + return endpoint + + +def _list_variables_used_by_fns(fns): + trainable_variables = [] + non_trainable_variables = [] + trainable_variables_ids = set() + non_trainable_variables_ids = set() + for fn in fns: + if hasattr(fn, "concrete_functions"): + concrete_functions = fn.concrete_functions + elif hasattr(fn, "get_concrete_function"): + concrete_functions = [fn.get_concrete_function()] + else: + concrete_functions = [fn] + for concrete_fn in concrete_functions: + for v in concrete_fn.trainable_variables: + if id(v) not in trainable_variables_ids: + trainable_variables.append(v) + trainable_variables_ids.add(id(v)) + + for v in concrete_fn.variables: + if ( + id(v) not in trainable_variables_ids + and id(v) not in non_trainable_variables_ids + ): + non_trainable_variables.append(v) + non_trainable_variables_ids.add(id(v)) + return trainable_variables, non_trainable_variables diff --git a/keras/src/export/saved_model_test.py b/keras/src/export/saved_model_test.py new file mode 100644 index 000000000000..3401cc35de27 --- /dev/null +++ b/keras/src/export/saved_model_test.py @@ -0,0 +1,1017 @@ +"""Tests for SavedModel exporting utilities.""" + +import os + +import numpy as np +import pytest +import tensorflow as tf +from absl.testing import parameterized + +from keras.src import backend +from keras.src import layers +from keras.src import models +from keras.src import ops +from keras.src import random +from keras.src import testing +from keras.src import tree +from keras.src.export import saved_model +from keras.src.saving import saving_lib +from keras.src.testing.test_utils import named_product + + +class CustomModel(models.Model): + def __init__(self, layer_list): + super().__init__() + self.layer_list = layer_list + + def call(self, input): + output = input + for layer in self.layer_list: + output = layer(output) + return output + + +def get_model(type="sequential", input_shape=(10,), layer_list=None): + layer_list = layer_list or [ + layers.Dense(10, activation="relu"), + layers.BatchNormalization(), + layers.Dense(1, activation="sigmoid"), + ] + if type == "sequential": + return models.Sequential(layer_list) + elif type == "functional": + input = output = tree.map_shape_structure(layers.Input, input_shape) + for layer in layer_list: + output = layer(output) + return models.Model(inputs=input, outputs=output) + elif type == "subclass": + return CustomModel(layer_list) + + +@pytest.mark.skipif( + backend.backend() not in ("tensorflow", "jax", "torch"), + reason=( + "`export_saved_model` only currently supports the tensorflow, jax and " + "torch backends." + ), +) +@pytest.mark.skipif(testing.jax_uses_gpu(), reason="Leads to core dumps on CI") +@pytest.mark.skipif( + testing.torch_uses_gpu(), reason="Leads to core dumps on CI" +) +class ExportSavedModelTest(testing.TestCase): + @parameterized.named_parameters( + named_product(model_type=["sequential", "functional", "subclass"]) + ) + def test_standard_model_export(self, model_type): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + model = get_model(model_type) + batch_size = 3 if backend.backend() != "torch" else 1 + ref_input = np.random.normal(size=(batch_size, 10)).astype("float32") + ref_output = model(ref_input) + + saved_model.export_saved_model(model, temp_filepath) + revived_model = tf.saved_model.load(temp_filepath) + self.assertAllClose(ref_output, revived_model.serve(ref_input)) + # Test with a different batch size + if backend.backend() == "torch": + # TODO: Dynamic shape is not supported yet in the torch backend + return + revived_model.serve(tf.random.normal((6, 10))) + + @parameterized.named_parameters( + named_product(model_type=["sequential", "functional", "subclass"]) + ) + @pytest.mark.skipif( + backend.backend() == "torch", + reason=( + "RuntimeError: mutating a non-functional tensor with a " + "functional tensor is not allowed in the torch backend." + ), + ) + def test_model_with_rng_export(self, model_type): + class RandomLayer(layers.Layer): + def __init__(self): + super().__init__() + self.seed_generator = backend.random.SeedGenerator() + + def call(self, inputs): + return inputs + random.uniform( + ops.shape(inputs), seed=self.seed_generator + ) + + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + model = get_model(model_type, layer_list=[RandomLayer()]) + ref_input = tf.random.normal((3, 10)) + ref_output = model(ref_input) + + saved_model.export_saved_model(model, temp_filepath) + revived_model = tf.saved_model.load(temp_filepath) + self.assertEqual(ref_output.shape, revived_model.serve(ref_input).shape) + # Test with a different batch size + input = tf.random.normal((6, 10)) + output1 = revived_model.serve(input) + output2 = revived_model.serve(input) + # Verify RNG seeding works and produces random outputs + self.assertNotAllClose(output1, output2) + + @parameterized.named_parameters( + named_product(model_type=["sequential", "functional", "subclass"]) + ) + @pytest.mark.skipif( + backend.backend() == "torch", + reason=( + "RuntimeError: mutating a non-functional tensor with a " + "functional tensor is not allowed in the torch backend." + ), + ) + def test_model_with_non_trainable_state_export(self, model_type): + class StateLayer(layers.Layer): + def __init__(self): + super().__init__() + self.counter = self.add_variable( + (), "zeros", "int32", trainable=False + ) + + def call(self, inputs): + self.counter.assign_add(1) + return ops.array(inputs), ops.array(self.counter.value) + + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + model = get_model(model_type, layer_list=[StateLayer()]) + model(tf.random.normal((3, 10))) + + saved_model.export_saved_model(model, temp_filepath) + revived_model = tf.saved_model.load(temp_filepath) + + # The non-trainable counter is expected to increment + input = tf.random.normal((6, 10)) + output1, counter1 = revived_model.serve(input) + self.assertAllClose(output1, input) + self.assertAllClose(counter1, 2) + output2, counter2 = revived_model.serve(input) + self.assertAllClose(output2, input) + self.assertAllClose(counter2, 3) + + @parameterized.named_parameters( + named_product(model_type=["sequential", "functional", "subclass"]) + ) + def test_model_with_tf_data_layer(self, model_type): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + model = get_model(model_type, layer_list=[layers.Rescaling(scale=2.0)]) + batch_size = 3 if backend.backend() != "torch" else 1 + ref_input = np.random.normal(size=(batch_size, 10)).astype("float32") + ref_output = model(ref_input) + + saved_model.export_saved_model(model, temp_filepath) + revived_model = tf.saved_model.load(temp_filepath) + self.assertAllClose(ref_output, revived_model.serve(ref_input)) + # Test with a different batch size + if backend.backend() == "torch": + # TODO: Dynamic shape is not supported yet in the torch backend + return + revived_model.serve(tf.random.normal((6, 10))) + + @parameterized.named_parameters( + named_product(struct_type=["tuple", "array", "dict"]) + ) + def test_model_with_input_structure(self, struct_type): + class TupleModel(models.Model): + def call(self, inputs): + x, y = inputs + return ops.add(x, y) + + class ArrayModel(models.Model): + def call(self, inputs): + x = inputs[0] + y = inputs[1] + return ops.add(x, y) + + class DictModel(models.Model): + def call(self, inputs): + x = inputs["x"] + y = inputs["y"] + return ops.add(x, y) + + batch_size = 3 if backend.backend() != "torch" else 1 + ref_input = np.random.normal(size=(batch_size, 10)).astype("float32") + if struct_type == "tuple": + model = TupleModel() + ref_input = (ref_input, ref_input * 2) + elif struct_type == "array": + model = ArrayModel() + ref_input = [ref_input, ref_input * 2] + elif struct_type == "dict": + model = DictModel() + ref_input = {"x": ref_input, "y": ref_input * 2} + + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + ref_output = model(tree.map_structure(ops.convert_to_tensor, ref_input)) + + saved_model.export_saved_model(model, temp_filepath) + revived_model = tf.saved_model.load(temp_filepath) + self.assertAllClose(ref_output, revived_model.serve(ref_input)) + + # Test with keras.saving_lib + temp_filepath = os.path.join( + self.get_temp_dir(), "exported_model.keras" + ) + saving_lib.save_model(model, temp_filepath) + revived_model = saving_lib.load_model( + temp_filepath, + { + "TupleModel": TupleModel, + "ArrayModel": ArrayModel, + "DictModel": DictModel, + }, + ) + self.assertAllClose(ref_output, revived_model(ref_input)) + saved_model.export_saved_model(revived_model, self.get_temp_dir()) + + # Test with a different batch size + if backend.backend() == "torch": + # TODO: Dynamic shape is not supported yet in the torch backend + return + bigger_input = tree.map_structure( + lambda x: tf.concat([x, x], axis=0), ref_input + ) + revived_model(bigger_input) + + def test_model_with_multiple_inputs(self): + class TwoInputsModel(models.Model): + def call(self, x, y): + return x + y + + def build(self, y_shape, x_shape): + self.built = True + + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + model = TwoInputsModel() + batch_size = 3 if backend.backend() != "torch" else 1 + ref_input_x = np.random.normal(size=(batch_size, 10)).astype("float32") + ref_input_y = np.random.normal(size=(batch_size, 10)).astype("float32") + ref_output = model(ref_input_x, ref_input_y) + + saved_model.export_saved_model(model, temp_filepath) + revived_model = tf.saved_model.load(temp_filepath) + self.assertAllClose( + ref_output, revived_model.serve(ref_input_x, ref_input_y) + ) + # Test with a different batch size + if backend.backend() == "torch": + # TODO: Dynamic shape is not supported yet in the torch backend + return + revived_model.serve( + tf.random.normal((6, 10)), tf.random.normal((6, 10)) + ) + + @parameterized.named_parameters( + named_product( + model_type=["sequential", "functional", "subclass"], + input_signature=[ + layers.InputSpec( + dtype="float32", shape=(None, 10), name="inputs" + ), + tf.TensorSpec((None, 10), dtype="float32", name="inputs"), + backend.KerasTensor((None, 10), dtype="float32", name="inputs"), + "backend_tensor", + ], + ) + ) + def test_input_signature(self, model_type, input_signature): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + model = get_model(model_type) + batch_size = 3 if backend.backend() != "torch" else 1 + ref_input = ops.random.normal((batch_size, 10)) + ref_output = model(ref_input) + + if input_signature == "backend_tensor": + input_signature = (ref_input,) + else: + input_signature = (input_signature,) + saved_model.export_saved_model( + model, temp_filepath, input_signature=input_signature + ) + revived_model = tf.saved_model.load(temp_filepath) + self.assertAllClose( + ref_output, revived_model.serve(ops.convert_to_numpy(ref_input)) + ) + + def test_input_signature_error(self): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + model = get_model("functional") + with self.assertRaisesRegex(TypeError, "Unsupported x="): + input_signature = (123,) + saved_model.export_saved_model( + model, temp_filepath, input_signature=input_signature + ) + + @parameterized.named_parameters( + named_product( + model_type=["sequential", "functional", "subclass"], + is_static=(True, False), + jax2tf_kwargs=( + None, + {"enable_xla": True, "native_serialization": True}, + ), + ) + ) + @pytest.mark.skipif( + backend.backend() != "jax", + reason="This test is only for the jax backend.", + ) + def test_jax_specific_kwargs(self, model_type, is_static, jax2tf_kwargs): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + model = get_model(model_type) + ref_input = ops.random.uniform((3, 10)) + ref_output = model(ref_input) + + saved_model.export_saved_model( + model, + temp_filepath, + is_static=is_static, + jax2tf_kwargs=jax2tf_kwargs, + ) + revived_model = tf.saved_model.load(temp_filepath) + self.assertAllClose(ref_output, revived_model.serve(ref_input)) + + +@pytest.mark.skipif( + backend.backend() + not in ( + "tensorflow", + "jax", + # "torch", # TODO: Support low-level operations in the torch backend. + ), + reason="Export only currently supports the TF and JAX backends.", +) +@pytest.mark.skipif(testing.jax_uses_gpu(), reason="Leads to core dumps on CI") +@pytest.mark.skipif( + testing.torch_uses_gpu(), reason="Leads to core dumps on CI" +) +class ExportArchiveTest(testing.TestCase): + @parameterized.named_parameters( + named_product(model_type=["sequential", "functional", "subclass"]) + ) + def test_low_level_model_export(self, model_type): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + + model = get_model(model_type) + ref_input = tf.random.normal((3, 10)) + ref_output = model(ref_input) + + # Test variable tracking + export_archive = saved_model.ExportArchive() + export_archive.track(model) + self.assertLen(export_archive.variables, 8) + self.assertLen(export_archive.trainable_variables, 6) + self.assertLen(export_archive.non_trainable_variables, 2) + + export_archive = saved_model.ExportArchive() + export_archive.track(model) + export_archive.add_endpoint( + "call", + model.__call__, + input_signature=[tf.TensorSpec(shape=(None, 10), dtype=tf.float32)], + ) + export_archive.write_out(temp_filepath) + revived_model = tf.saved_model.load(temp_filepath) + self.assertAllClose(ref_output, revived_model.call(ref_input)) + # Test with a different batch size + revived_model.call(tf.random.normal((6, 10))) + + def test_low_level_model_export_with_alias(self): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + + model = get_model() + ref_input = tf.random.normal((3, 10)) + ref_output = model(ref_input) + + export_archive = saved_model.ExportArchive() + export_archive.track(model) + fn = export_archive.add_endpoint( + "call", + model.__call__, + input_signature=[tf.TensorSpec(shape=(None, 10), dtype=tf.float32)], + ) + export_archive.write_out( + temp_filepath, + tf.saved_model.SaveOptions(function_aliases={"call_alias": fn}), + ) + revived_model = tf.saved_model.load( + temp_filepath, + options=tf.saved_model.LoadOptions( + experimental_load_function_aliases=True + ), + ) + self.assertAllClose( + ref_output, revived_model.function_aliases["call_alias"](ref_input) + ) + # Test with a different batch size + revived_model.function_aliases["call_alias"](tf.random.normal((6, 10))) + + @parameterized.named_parameters( + named_product(model_type=["sequential", "functional", "subclass"]) + ) + def test_low_level_model_export_with_dynamic_dims(self, model_type): + class ReductionLayer(layers.Layer): + def call(self, inputs): + return ops.max(inputs, axis=1) + + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + + model = get_model( + model_type, + input_shape=[(None,), (None,)], + layer_list=[layers.Concatenate(), ReductionLayer()], + ) + ref_input = [tf.random.normal((3, 8)), tf.random.normal((3, 6))] + ref_output = model(ref_input) + + export_archive = saved_model.ExportArchive() + export_archive.track(model) + export_archive.add_endpoint( + "call", + model.__call__, + input_signature=[ + [ + tf.TensorSpec(shape=(None, None), dtype=tf.float32), + tf.TensorSpec(shape=(None, None), dtype=tf.float32), + ] + ], + ) + export_archive.write_out(temp_filepath) + revived_model = tf.saved_model.load(temp_filepath) + self.assertAllClose(ref_output, revived_model.call(ref_input)) + # Test with a different batch size + revived_model.call([tf.random.normal((6, 8)), tf.random.normal((6, 6))]) + # Test with a different batch size and different dynamic sizes + revived_model.call([tf.random.normal((6, 3)), tf.random.normal((6, 5))]) + + @pytest.mark.skipif( + backend.backend() != "jax", + reason="This test is only for the JAX backend.", + ) + def test_low_level_model_export_with_jax2tf_kwargs(self): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + + model = get_model() + ref_input = tf.random.normal((3, 10)) + ref_output = model(ref_input) + + export_archive = saved_model.ExportArchive() + export_archive.track(model) + export_archive.add_endpoint( + "call", + model.__call__, + input_signature=[tf.TensorSpec(shape=(None, 10), dtype=tf.float32)], + jax2tf_kwargs={ + "native_serialization": True, + "native_serialization_platforms": ("cpu", "tpu"), + }, + ) + with self.assertRaisesRegex( + ValueError, "native_serialization_platforms.*bogus" + ): + export_archive.add_endpoint( + "call2", + model.__call__, + input_signature=[ + tf.TensorSpec(shape=(None, 10), dtype=tf.float32) + ], + jax2tf_kwargs={ + "native_serialization": True, + "native_serialization_platforms": ("cpu", "bogus"), + }, + ) + export_archive.write_out(temp_filepath) + revived_model = tf.saved_model.load(temp_filepath) + self.assertAllClose(ref_output, revived_model.call(ref_input)) + + @pytest.mark.skipif( + backend.backend() != "jax", + reason="This test is only for the JAX backend.", + ) + def test_low_level_model_export_with_jax2tf_polymorphic_shapes(self): + class SquareLayer(layers.Layer): + def call(self, inputs): + return ops.matmul(inputs, inputs) + + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + + model = CustomModel([SquareLayer()]) + ref_input = tf.random.normal((3, 10, 10)) + ref_output = model(ref_input) + signature = [tf.TensorSpec(shape=(None, None, None), dtype=tf.float32)] + + with self.assertRaises(TypeError): + # This will fail because the polymorphic_shapes that is + # automatically generated will not account for the fact that + # dynamic dimensions 1 and 2 must have the same value. + export_archive = saved_model.ExportArchive() + export_archive.track(model) + export_archive.add_endpoint( + "call", + model.__call__, + input_signature=signature, + jax2tf_kwargs={}, + ) + export_archive.write_out(temp_filepath) + + export_archive = saved_model.ExportArchive() + export_archive.track(model) + export_archive.add_endpoint( + "call", + model.__call__, + input_signature=signature, + jax2tf_kwargs={"polymorphic_shapes": ["(batch, a, a)"]}, + ) + export_archive.write_out(temp_filepath) + revived_model = tf.saved_model.load(temp_filepath) + self.assertAllClose(ref_output, revived_model.call(ref_input)) + + @pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="This test is native to the TF backend.", + ) + def test_endpoint_registration_tf_function(self): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + + model = get_model() + ref_input = tf.random.normal((3, 10)) + ref_output = model(ref_input) + + # Test variable tracking + export_archive = saved_model.ExportArchive() + export_archive.track(model) + self.assertLen(export_archive.variables, 8) + self.assertLen(export_archive.trainable_variables, 6) + self.assertLen(export_archive.non_trainable_variables, 2) + + @tf.function() + def my_endpoint(x): + return model(x) + + # Test registering an endpoint that is a tf.function (called) + my_endpoint(ref_input) # Trace fn + + export_archive.add_endpoint( + "call", + my_endpoint, + ) + export_archive.write_out(temp_filepath) + + revived_model = tf.saved_model.load(temp_filepath) + self.assertFalse(hasattr(revived_model, "_tracked")) + self.assertAllClose(ref_output, revived_model.call(ref_input)) + self.assertLen(revived_model.variables, 8) + self.assertLen(revived_model.trainable_variables, 6) + self.assertLen(revived_model.non_trainable_variables, 2) + + @pytest.mark.skipif( + backend.backend() != "jax", + reason="This test is native to the JAX backend.", + ) + def test_jax_endpoint_registration_tf_function(self): + model = get_model() + ref_input = np.random.normal(size=(3, 10)) + model(ref_input) + + # build a JAX function + def model_call(x): + return model(x) + + from jax import default_backend as jax_device + from jax.experimental import jax2tf + + native_jax_compatible = not ( + jax_device() == "gpu" + and len(tf.config.list_physical_devices("GPU")) == 0 + ) + # now, convert JAX function + converted_model_call = jax2tf.convert( + model_call, + native_serialization=native_jax_compatible, + polymorphic_shapes=["(b, 10)"], + ) + + # you can now build a TF inference function + @tf.function( + input_signature=[tf.TensorSpec(shape=(None, 10), dtype=tf.float32)], + autograph=False, + ) + def infer_fn(x): + return converted_model_call(x) + + ref_output = infer_fn(ref_input) + + # Export with TF inference function as endpoint + temp_filepath = os.path.join(self.get_temp_dir(), "my_model") + export_archive = saved_model.ExportArchive() + export_archive.track(model) + export_archive.add_endpoint("serve", infer_fn) + export_archive.write_out(temp_filepath) + + # Reload and verify outputs + revived_model = tf.saved_model.load(temp_filepath) + self.assertFalse(hasattr(revived_model, "_tracked")) + self.assertAllClose(ref_output, revived_model.serve(ref_input)) + self.assertLen(revived_model.variables, 8) + self.assertLen(revived_model.trainable_variables, 6) + self.assertLen(revived_model.non_trainable_variables, 2) + + # Assert all variables wrapped as `tf.Variable` + assert isinstance(export_archive.variables[0], tf.Variable) + assert isinstance(export_archive.trainable_variables[0], tf.Variable) + assert isinstance( + export_archive.non_trainable_variables[0], tf.Variable + ) + + @pytest.mark.skipif( + backend.backend() != "jax", + reason="This test is native to the JAX backend.", + ) + def test_jax_multi_unknown_endpoint_registration(self): + window_size = 100 + + X = np.random.random((1024, window_size, 1)) + Y = np.random.random((1024, window_size, 1)) + + model = models.Sequential( + [ + layers.Dense(128, activation="relu"), + layers.Dense(64, activation="relu"), + layers.Dense(1, activation="relu"), + ] + ) + + model.compile(optimizer="adam", loss="mse") + + model.fit(X, Y, batch_size=32) + + # build a JAX function + def model_call(x): + return model(x) + + from jax import default_backend as jax_device + from jax.experimental import jax2tf + + native_jax_compatible = not ( + jax_device() == "gpu" + and len(tf.config.list_physical_devices("GPU")) == 0 + ) + # now, convert JAX function + converted_model_call = jax2tf.convert( + model_call, + native_serialization=native_jax_compatible, + polymorphic_shapes=["(b, t, 1)"], + ) + + # you can now build a TF inference function + @tf.function( + input_signature=[ + tf.TensorSpec(shape=(None, None, 1), dtype=tf.float32) + ], + autograph=False, + ) + def infer_fn(x): + return converted_model_call(x) + + ref_input = np.random.random((1024, window_size, 1)) + ref_output = infer_fn(ref_input) + + # Export with TF inference function as endpoint + temp_filepath = os.path.join(self.get_temp_dir(), "my_model") + export_archive = saved_model.ExportArchive() + export_archive.track(model) + export_archive.add_endpoint("serve", infer_fn) + export_archive.write_out(temp_filepath) + + # Reload and verify outputs + revived_model = tf.saved_model.load(temp_filepath) + self.assertFalse(hasattr(revived_model, "_tracked")) + self.assertAllClose(ref_output, revived_model.serve(ref_input)) + self.assertLen(revived_model.variables, 6) + self.assertLen(revived_model.trainable_variables, 6) + self.assertLen(revived_model.non_trainable_variables, 0) + + # Assert all variables wrapped as `tf.Variable` + assert isinstance(export_archive.variables[0], tf.Variable) + assert isinstance(export_archive.trainable_variables[0], tf.Variable) + + def test_layer_export(self): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_layer") + + layer = layers.BatchNormalization() + ref_input = tf.random.normal((3, 10)) + ref_output = layer(ref_input) # Build layer (important) + + export_archive = saved_model.ExportArchive() + export_archive.track(layer) + export_archive.add_endpoint( + "call", + layer.call, + input_signature=[tf.TensorSpec(shape=(None, 10), dtype=tf.float32)], + ) + export_archive.write_out(temp_filepath) + revived_layer = tf.saved_model.load(temp_filepath) + self.assertAllClose(ref_output, revived_layer.call(ref_input)) + + def test_multi_input_output_functional_model(self): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + x1 = layers.Input((2,)) + x2 = layers.Input((2,)) + y1 = layers.Dense(3)(x1) + y2 = layers.Dense(3)(x2) + model = models.Model([x1, x2], [y1, y2]) + + ref_inputs = [tf.random.normal((3, 2)), tf.random.normal((3, 2))] + ref_outputs = model(ref_inputs) + + model.export(temp_filepath) + revived_model = tf.saved_model.load(temp_filepath) + self.assertAllClose(ref_outputs[0], revived_model.serve(ref_inputs)[0]) + self.assertAllClose(ref_outputs[1], revived_model.serve(ref_inputs)[1]) + # Test with a different batch size + revived_model.serve( + [tf.random.normal((6, 2)), tf.random.normal((6, 2))] + ) + + # Now test dict inputs + model = models.Model({"x1": x1, "x2": x2}, [y1, y2]) + + ref_inputs = { + "x1": tf.random.normal((3, 2)), + "x2": tf.random.normal((3, 2)), + } + ref_outputs = model(ref_inputs) + + model.export(temp_filepath) + revived_model = tf.saved_model.load(temp_filepath) + self.assertAllClose(ref_outputs[0], revived_model.serve(ref_inputs)[0]) + self.assertAllClose(ref_outputs[1], revived_model.serve(ref_inputs)[1]) + # Test with a different batch size + revived_model.serve( + { + "x1": tf.random.normal((6, 2)), + "x2": tf.random.normal((6, 2)), + } + ) + + @pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="String lookup requires TensorFlow backend", + ) + def test_model_with_lookup_table(self): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + text_vectorization = layers.TextVectorization() + text_vectorization.adapt(["one two", "three four", "five six"]) + model = models.Sequential( + [ + layers.Input(shape=(), dtype="string"), + text_vectorization, + layers.Embedding(10, 32), + layers.Dense(1), + ] + ) + ref_input = tf.convert_to_tensor(["one two three four"]) + ref_output = model(ref_input) + + saved_model.export_saved_model(model, temp_filepath) + revived_model = tf.saved_model.load(temp_filepath) + self.assertAllClose(ref_output, revived_model.serve(ref_input)) + + def test_track_multiple_layers(self): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + layer_1 = layers.Dense(2) + ref_input_1 = tf.random.normal((3, 4)) + ref_output_1 = layer_1(ref_input_1) + layer_2 = layers.Dense(3) + ref_input_2 = tf.random.normal((3, 5)) + ref_output_2 = layer_2(ref_input_2) + + export_archive = saved_model.ExportArchive() + export_archive.add_endpoint( + "call_1", + layer_1.call, + input_signature=[tf.TensorSpec(shape=(None, 4), dtype=tf.float32)], + ) + export_archive.add_endpoint( + "call_2", + layer_2.call, + input_signature=[tf.TensorSpec(shape=(None, 5), dtype=tf.float32)], + ) + export_archive.write_out(temp_filepath) + revived_layer = tf.saved_model.load(temp_filepath) + self.assertAllClose(ref_output_1, revived_layer.call_1(ref_input_1)) + self.assertAllClose(ref_output_2, revived_layer.call_2(ref_input_2)) + + def test_non_standard_layer_signature(self): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_layer") + + layer = layers.MultiHeadAttention(2, 2) + x1 = tf.random.normal((3, 2, 2)) + x2 = tf.random.normal((3, 2, 2)) + ref_output = layer(x1, x2) # Build layer (important) + export_archive = saved_model.ExportArchive() + export_archive.track(layer) + export_archive.add_endpoint( + "call", + layer.call, + input_signature=[ + tf.TensorSpec(shape=(None, 2, 2), dtype=tf.float32), + tf.TensorSpec(shape=(None, 2, 2), dtype=tf.float32), + ], + ) + export_archive.write_out(temp_filepath) + revived_layer = tf.saved_model.load(temp_filepath) + self.assertAllClose(ref_output, revived_layer.call(x1, x2)) + + def test_non_standard_layer_signature_with_kwargs(self): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_layer") + + layer = layers.MultiHeadAttention(2, 2) + x1 = tf.random.normal((3, 2, 2)) + x2 = tf.random.normal((3, 2, 2)) + ref_output = layer(x1, x2) # Build layer (important) + export_archive = saved_model.ExportArchive() + export_archive.track(layer) + export_archive.add_endpoint( + "call", + layer.call, + input_signature=[ + tf.TensorSpec(shape=(None, 2, 2), dtype=tf.float32), + tf.TensorSpec(shape=(None, 2, 2), dtype=tf.float32), + ], + ) + export_archive.write_out(temp_filepath) + revived_layer = tf.saved_model.load(temp_filepath) + self.assertAllClose(ref_output, revived_layer.call(query=x1, value=x2)) + # Test with a different batch size + revived_layer.call( + query=tf.random.normal((6, 2, 2)), value=tf.random.normal((6, 2, 2)) + ) + + def test_variable_collection(self): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + + model = models.Sequential( + [ + layers.Input((10,)), + layers.Dense(2), + layers.Dense(2), + ] + ) + + # Test variable tracking + export_archive = saved_model.ExportArchive() + export_archive.track(model) + export_archive.add_endpoint( + "call", + model.__call__, + input_signature=[tf.TensorSpec(shape=(None, 10), dtype=tf.float32)], + ) + export_archive.add_variable_collection( + "my_vars", model.layers[1].weights + ) + + self.assertLen(export_archive._tf_trackable.my_vars, 2) + export_archive.write_out(temp_filepath) + revived_model = tf.saved_model.load(temp_filepath) + self.assertLen(revived_model.my_vars, 2) + + def test_export_saved_model_errors(self): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + + # Model has not been built + model = models.Sequential([layers.Dense(2)]) + with self.assertRaisesRegex(ValueError, "It must be built"): + saved_model.export_saved_model(model, temp_filepath) + + # Subclassed model has not been called + model = get_model("subclass") + model.build((2, 10)) + with self.assertRaisesRegex(ValueError, "It must be called"): + saved_model.export_saved_model(model, temp_filepath) + + def test_export_archive_errors(self): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + model = models.Sequential([layers.Dense(2)]) + model(tf.random.normal((2, 3))) + + # Endpoint name reuse + export_archive = saved_model.ExportArchive() + export_archive.track(model) + export_archive.add_endpoint( + "call", + model.__call__, + input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)], + ) + with self.assertRaisesRegex(ValueError, "already taken"): + export_archive.add_endpoint( + "call", + model.__call__, + input_signature=[ + tf.TensorSpec(shape=(None, 3), dtype=tf.float32) + ], + ) + + # Write out with no endpoints + export_archive = saved_model.ExportArchive() + export_archive.track(model) + with self.assertRaisesRegex(ValueError, "No endpoints have been set"): + export_archive.write_out(temp_filepath) + + # Invalid object type + with self.assertRaisesRegex(ValueError, "Invalid resource type"): + export_archive = saved_model.ExportArchive() + export_archive.track("model") + + # Set endpoint with no input signature + export_archive = saved_model.ExportArchive() + export_archive.track(model) + with self.assertRaisesRegex( + ValueError, "you must provide an `input_signature`" + ): + export_archive.add_endpoint("call", model.__call__) + + # Set endpoint that has never been called + export_archive = saved_model.ExportArchive() + export_archive.track(model) + + @tf.function() + def my_endpoint(x): + return model(x) + + export_archive = saved_model.ExportArchive() + export_archive.track(model) + with self.assertRaisesRegex( + ValueError, "you must either provide a function" + ): + export_archive.add_endpoint("call", my_endpoint) + + def test_export_no_assets(self): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + + # Case where there are legitimately no assets. + model = models.Sequential([layers.Flatten()]) + model(tf.random.normal((2, 3))) + export_archive = saved_model.ExportArchive() + export_archive.add_endpoint( + "call", + model.__call__, + input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)], + ) + export_archive.write_out(temp_filepath) + + @parameterized.named_parameters( + named_product(model_type=["sequential", "functional", "subclass"]) + ) + def test_model_export_method(self, model_type): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + model = get_model(model_type) + ref_input = tf.random.normal((3, 10)) + ref_output = model(ref_input) + + model.export(temp_filepath) + revived_model = tf.saved_model.load(temp_filepath) + self.assertAllClose(ref_output, revived_model.serve(ref_input)) + # Test with a different batch size + revived_model.serve(tf.random.normal((6, 10))) + + def test_model_combined_with_tf_preprocessing(self): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + + lookup_table = tf.lookup.StaticHashTable( + tf.lookup.KeyValueTensorInitializer( + tf.constant(["a", "b", "c"]), tf.constant([1.0, 2.0, 3.0]) + ), + default_value=-1.0, + ) + ref_input = tf.constant([["c", "b", "c", "a", "d"]]) + ref_intermediate = lookup_table.lookup(ref_input) + + model = models.Sequential([layers.Dense(1)]) + ref_output = model(ref_intermediate) + + export_archive = saved_model.ExportArchive() + model_fn = export_archive.track_and_add_endpoint( + "model", + model, + input_signature=[tf.TensorSpec(shape=(None, 5), dtype=tf.float32)], + ) + export_archive.track(lookup_table) + + @tf.function() + def combined_fn(x): + x = lookup_table.lookup(x) + x = model_fn(x) + return x + + self.assertAllClose(combined_fn(ref_input), ref_output) + + export_archive.add_endpoint("combined_fn", combined_fn) + export_archive.write_out(temp_filepath) + + revived_model = tf.saved_model.load(temp_filepath) + self.assertAllClose(revived_model.combined_fn(ref_input), ref_output) diff --git a/keras/src/export/tf2onnx_lib.py b/keras/src/export/tf2onnx_lib.py new file mode 100644 index 000000000000..b6ff3dfe37ae --- /dev/null +++ b/keras/src/export/tf2onnx_lib.py @@ -0,0 +1,178 @@ +import copy +import functools +import logging +import traceback + +import numpy as np + + +@functools.lru_cache() +def patch_tf2onnx(): + """Patches `tf2onnx` to ensure compatibility with numpy>=2.0.0.""" + + from onnx import AttributeProto + from onnx import TensorProto + + from keras.src.utils.module_utils import tf2onnx + + logger = logging.getLogger(tf2onnx.__name__) + + def patched_rewrite_constant_fold(g, ops): + """ + We call tensorflow transform with constant folding but in some cases + tensorflow does fold all constants. Since there are a bunch of ops in + onnx that use attributes where tensorflow has dynamic inputs, we badly + want constant folding to work. For cases where tensorflow missed + something, make another pass over the graph and fix want we care about. + """ + func_map = { + "Add": np.add, + "GreaterEqual": np.greater_equal, + "Cast": np.asarray, + "ConcatV2": np.concatenate, + "Less": np.less, + "ListDiff": np.setdiff1d, + "Mul": np.multiply, + "Pack": np.stack, + "Range": np.arange, + "Sqrt": np.sqrt, + "Sub": np.subtract, + } + ops = list(ops) + + keep_looking = True + while keep_looking: + keep_looking = False + for idx, op in enumerate(ops): + func = func_map.get(op.type) + if func is None: + continue + if set(op.output) & set(g.outputs): + continue + try: + inputs = [] + for node in op.inputs: + if not node.is_const(): + break + inputs.append(node.get_tensor_value(as_list=False)) + + logger.debug( + "op name %s, %s, %s", + op.name, + len(op.input), + len(inputs), + ) + if inputs and len(op.input) == len(inputs): + logger.info( + "folding node type=%s, name=%s" % (op.type, op.name) + ) + if op.type == "Cast": + dst = op.get_attr_int("to") + np_type = tf2onnx.utils.map_onnx_to_numpy_type(dst) + val = np.asarray(*inputs, dtype=np_type) + elif op.type == "ConcatV2": + axis = inputs[-1] + values = inputs[:-1] + val = func(tuple(values), axis) + elif op.type == "ListDiff": + out_type = op.get_attr_int("out_idx") + np_type = tf2onnx.utils.map_onnx_to_numpy_type( + out_type + ) + val = func(*inputs) + val = val.astype(np_type) + elif op.type in ["Pack"]: + # handle ops that need input array and axis + axis = op.get_attr_int("axis") + val = func(inputs, axis=axis) + elif op.type == "Range": + dtype = op.get_attr_int("Tidx") + np_type = tf2onnx.utils.map_onnx_to_numpy_type( + dtype + ) + val = func(*inputs, dtype=np_type) + else: + val = func(*inputs) + + new_node_name = tf2onnx.utils.make_name(op.name) + new_output_name = new_node_name + old_output_name = op.output[0] + old_node_name = op.name + logger.debug( + "create const node [%s] replacing [%s]", + new_node_name, + old_node_name, + ) + ops[idx] = g.make_const(new_node_name, val) + + logger.debug( + "replace old output [%s] with new output [%s]", + old_output_name, + new_output_name, + ) + # need to re-write the consumers input name to use the + # const name + consumers = g.find_output_consumers(old_output_name) + if consumers: + for consumer in consumers: + g.replace_input( + consumer, old_output_name, new_output_name + ) + + # keep looking until there is nothing we can fold. + # We keep the graph in topological order so if we + # folded, the result might help a following op. + keep_looking = True + except Exception as ex: + tb = traceback.format_exc() + logger.info("exception: %s, details: %s", ex, tb) + # ignore errors + + return ops + + def patched_get_value_attr(self, external_tensor_storage=None): + """ + Return onnx attr for value property of node. + Attr is modified to point to external tensor data stored in + external_tensor_storage, if included. + """ + a = self._attr["value"] + if ( + external_tensor_storage is not None + and self in external_tensor_storage.node_to_modified_value_attr + ): + return external_tensor_storage.node_to_modified_value_attr[self] + if external_tensor_storage is None or a.type != AttributeProto.TENSOR: + return a + + def prod(x): + if hasattr(np, "product"): + return np.product(x) + else: + return np.prod(x) + + if ( + prod(a.t.dims) + > external_tensor_storage.external_tensor_size_threshold + ): + a = copy.deepcopy(a) + tensor_name = ( + f"{self.name.strip()}_{external_tensor_storage.name_counter}" + ) + for c in '~"#%&*:<>?/\\{|}': + tensor_name = tensor_name.replace(c, "_") + external_tensor_storage.name_counter += 1 + external_tensor_storage.name_to_tensor_data[tensor_name] = ( + a.t.raw_data + ) + external_tensor_storage.node_to_modified_value_attr[self] = a + a.t.raw_data = b"" + a.t.ClearField("raw_data") + location = a.t.external_data.add() + location.key = "location" + location.value = tensor_name + a.t.data_location = TensorProto.EXTERNAL + return a + + tf2onnx.tfonnx.rewrite_constant_fold = patched_rewrite_constant_fold + tf2onnx.graph.Node.get_value_attr = patched_get_value_attr diff --git a/keras/src/export/tfsm_layer.py b/keras/src/export/tfsm_layer.py new file mode 100644 index 000000000000..71e97c3746ca --- /dev/null +++ b/keras/src/export/tfsm_layer.py @@ -0,0 +1,148 @@ +from keras.src import backend +from keras.src import layers +from keras.src.api_export import keras_export +from keras.src.export.saved_model import _list_variables_used_by_fns +from keras.src.utils.module_utils import tensorflow as tf + + +@keras_export("keras.layers.TFSMLayer") +class TFSMLayer(layers.Layer): + """Reload a Keras model/layer that was saved via SavedModel / ExportArchive. + + Arguments: + filepath: `str` or `pathlib.Path` object. The path to the SavedModel. + call_endpoint: Name of the endpoint to use as the `call()` method + of the reloaded layer. If the SavedModel was created + via `model.export()`, + then the default endpoint name is `'serve'`. In other cases + it may be named `'serving_default'`. + + Example: + + ```python + model.export("path/to/artifact") + reloaded_layer = TFSMLayer("path/to/artifact") + outputs = reloaded_layer(inputs) + ``` + + The reloaded object can be used like a regular Keras layer, and supports + training/fine-tuning of its trainable weights. Note that the reloaded + object retains none of the internal structure or custom methods of the + original object -- it's a brand new layer created around the saved + function. + + **Limitations:** + + * Only call endpoints with a single `inputs` tensor argument + (which may optionally be a dict/tuple/list of tensors) are supported. + For endpoints with multiple separate input tensor arguments, consider + subclassing `TFSMLayer` and implementing a `call()` method with a + custom signature. + * If you need training-time behavior to differ from inference-time behavior + (i.e. if you need the reloaded object to support a `training=True` argument + in `__call__()`), make sure that the training-time call function is + saved as a standalone endpoint in the artifact, and provide its name + to the `TFSMLayer` via the `call_training_endpoint` argument. + """ + + def __init__( + self, + filepath, + call_endpoint="serve", + call_training_endpoint=None, + trainable=True, + name=None, + dtype=None, + ): + if backend.backend() != "tensorflow": + raise NotImplementedError( + "The TFSMLayer is only currently supported with the " + "TensorFlow backend." + ) + + # Initialize an empty layer, then add_weight() etc. as needed. + super().__init__(trainable=trainable, name=name, dtype=dtype) + + self._reloaded_obj = tf.saved_model.load(filepath) + + self.filepath = filepath + self.call_endpoint = call_endpoint + self.call_training_endpoint = call_training_endpoint + + # Resolve the call function. + if hasattr(self._reloaded_obj, call_endpoint): + # Case 1: it's set as an attribute. + self.call_endpoint_fn = getattr(self._reloaded_obj, call_endpoint) + elif call_endpoint in self._reloaded_obj.signatures: + # Case 2: it's listed in the `signatures` field. + self.call_endpoint_fn = self._reloaded_obj.signatures[call_endpoint] + else: + raise ValueError( + f"The endpoint '{call_endpoint}' " + "is neither an attribute of the reloaded SavedModel, " + "nor an entry in the `signatures` field of " + "the reloaded SavedModel. Select another endpoint via " + "the `call_endpoint` argument. Available endpoints for " + "this SavedModel: " + f"{list(self._reloaded_obj.signatures.keys())}" + ) + + # Resolving the training function. + if call_training_endpoint: + if hasattr(self._reloaded_obj, call_training_endpoint): + self.call_training_endpoint_fn = getattr( + self._reloaded_obj, call_training_endpoint + ) + elif call_training_endpoint in self._reloaded_obj.signatures: + self.call_training_endpoint_fn = self._reloaded_obj.signatures[ + call_training_endpoint + ] + else: + raise ValueError( + f"The endpoint '{call_training_endpoint}' " + "is neither an attribute of the reloaded SavedModel, " + "nor an entry in the `signatures` field of " + "the reloaded SavedModel. Available endpoints for " + "this SavedModel: " + f"{list(self._reloaded_obj.signatures.keys())}" + ) + + # Add trainable and non-trainable weights from the call_endpoint_fn. + all_fns = [self.call_endpoint_fn] + if call_training_endpoint: + all_fns.append(self.call_training_endpoint_fn) + tvs, ntvs = _list_variables_used_by_fns(all_fns) + for v in tvs: + self._add_existing_weight(v) + for v in ntvs: + self._add_existing_weight(v) + + self._build_at_init() + + def _add_existing_weight(self, weight): + """Tracks an existing weight.""" + variable = backend.Variable( + initializer=weight, + trainable=weight.trainable, + dtype=weight.dtype, + shape=weight.shape, + # Keras variable names cannot contain slashes. + name=weight.name.replace("/", "_"), + ) + self._track_variable(variable) + + def call(self, inputs, training=False, **kwargs): + if training: + if self.call_training_endpoint: + return self.call_training_endpoint_fn(inputs, **kwargs) + return self.call_endpoint_fn(inputs, **kwargs) + + def get_config(self): + base_config = super().get_config() + config = { + # Note: this is not intended to be portable. + "filepath": self.filepath, + "call_endpoint": self.call_endpoint, + "call_training_endpoint": self.call_training_endpoint, + } + return {**base_config, **config} diff --git a/keras/src/export/tfsm_layer_test.py b/keras/src/export/tfsm_layer_test.py new file mode 100644 index 000000000000..887ed1070b6b --- /dev/null +++ b/keras/src/export/tfsm_layer_test.py @@ -0,0 +1,144 @@ +import os + +import numpy as np +import pytest +import tensorflow as tf + +from keras.src import backend +from keras.src import layers +from keras.src import models +from keras.src import testing +from keras.src import utils +from keras.src.export import saved_model +from keras.src.export import tfsm_layer +from keras.src.export.saved_model_test import get_model +from keras.src.saving import saving_lib + + +@pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="TFSM Layer reloading is only for the TF backend.", +) +class TestTFSMLayer(testing.TestCase): + def test_reloading_export_archive(self): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + model = get_model() + ref_input = tf.random.normal((3, 10)) + ref_output = model(ref_input) + + saved_model.export_saved_model(model, temp_filepath) + reloaded_layer = tfsm_layer.TFSMLayer(temp_filepath) + self.assertAllClose(reloaded_layer(ref_input), ref_output, atol=1e-7) + self.assertLen(reloaded_layer.weights, len(model.weights)) + self.assertLen( + reloaded_layer.trainable_weights, len(model.trainable_weights) + ) + self.assertLen( + reloaded_layer.non_trainable_weights, + len(model.non_trainable_weights), + ) + + def test_reloading_default_saved_model(self): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + model = get_model() + ref_input = tf.random.normal((3, 10)) + ref_output = model(ref_input) + + tf.saved_model.save(model, temp_filepath) + reloaded_layer = tfsm_layer.TFSMLayer( + temp_filepath, call_endpoint="serving_default" + ) + # The output is a dict, due to the nature of SavedModel saving. + new_output = reloaded_layer(ref_input) + self.assertAllClose( + new_output[list(new_output.keys())[0]], + ref_output, + atol=1e-7, + ) + self.assertLen(reloaded_layer.weights, len(model.weights)) + self.assertLen( + reloaded_layer.trainable_weights, len(model.trainable_weights) + ) + self.assertLen( + reloaded_layer.non_trainable_weights, + len(model.non_trainable_weights), + ) + for keras_var in reloaded_layer.weights: + self.assertIsInstance(keras_var, backend.Variable) + + def test_call_training(self): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + utils.set_random_seed(1337) + model = models.Sequential( + [ + layers.Input((10,)), + layers.Dense(10), + layers.Dropout(0.99999), + ] + ) + export_archive = saved_model.ExportArchive() + export_archive.track(model) + export_archive.add_endpoint( + name="call_inference", + fn=lambda x: model(x, training=False), + input_signature=[tf.TensorSpec(shape=(None, 10), dtype=tf.float32)], + ) + export_archive.add_endpoint( + name="call_training", + fn=lambda x: model(x, training=True), + input_signature=[tf.TensorSpec(shape=(None, 10), dtype=tf.float32)], + ) + export_archive.write_out(temp_filepath) + reloaded_layer = tfsm_layer.TFSMLayer( + temp_filepath, + call_endpoint="call_inference", + call_training_endpoint="call_training", + ) + inference_output = reloaded_layer( + tf.random.normal((1, 10)), training=False + ) + training_output = reloaded_layer( + tf.random.normal((1, 10)), training=True + ) + self.assertAllClose(np.mean(training_output), 0.0, atol=1e-7) + self.assertNotAllClose(np.mean(inference_output), 0.0, atol=1e-7) + + def test_serialization(self): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + model = get_model() + ref_input = tf.random.normal((3, 10)) + ref_output = model(ref_input) + + saved_model.export_saved_model(model, temp_filepath) + reloaded_layer = tfsm_layer.TFSMLayer(temp_filepath) + + # Test reinstantiation from config + config = reloaded_layer.get_config() + rereloaded_layer = tfsm_layer.TFSMLayer.from_config(config) + self.assertAllClose(rereloaded_layer(ref_input), ref_output, atol=1e-7) + + # Test whole model saving with reloaded layer inside + model = models.Sequential([reloaded_layer]) + temp_model_filepath = os.path.join(self.get_temp_dir(), "m.keras") + model.save(temp_model_filepath, save_format="keras_v3") + reloaded_model = saving_lib.load_model( + temp_model_filepath, + custom_objects={"TFSMLayer": tfsm_layer.TFSMLayer}, + ) + self.assertAllClose(reloaded_model(ref_input), ref_output, atol=1e-7) + + def test_errors(self): + # Test missing call endpoint + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + model = models.Sequential([layers.Input((2,)), layers.Dense(3)]) + saved_model.export_saved_model(model, temp_filepath) + with self.assertRaisesRegex(ValueError, "The endpoint 'wrong'"): + tfsm_layer.TFSMLayer(temp_filepath, call_endpoint="wrong") + + # Test missing call training endpoint + with self.assertRaisesRegex(ValueError, "The endpoint 'wrong'"): + tfsm_layer.TFSMLayer( + temp_filepath, + call_endpoint="serve", + call_training_endpoint="wrong", + ) diff --git a/keras/src/initializers/__init__.py b/keras/src/initializers/__init__.py new file mode 100644 index 000000000000..7223f5029f41 --- /dev/null +++ b/keras/src/initializers/__init__.py @@ -0,0 +1,163 @@ +import inspect + +import numpy as np + +from keras.src import backend +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.initializers.constant_initializers import STFT +from keras.src.initializers.constant_initializers import Constant +from keras.src.initializers.constant_initializers import Identity +from keras.src.initializers.constant_initializers import Ones +from keras.src.initializers.constant_initializers import Zeros +from keras.src.initializers.initializer import Initializer +from keras.src.initializers.random_initializers import GlorotNormal +from keras.src.initializers.random_initializers import GlorotUniform +from keras.src.initializers.random_initializers import HeNormal +from keras.src.initializers.random_initializers import HeUniform +from keras.src.initializers.random_initializers import LecunNormal +from keras.src.initializers.random_initializers import LecunUniform +from keras.src.initializers.random_initializers import Orthogonal +from keras.src.initializers.random_initializers import RandomNormal +from keras.src.initializers.random_initializers import RandomUniform +from keras.src.initializers.random_initializers import TruncatedNormal +from keras.src.initializers.random_initializers import VarianceScaling +from keras.src.saving import serialization_lib +from keras.src.utils.naming import to_snake_case + +ALL_OBJECTS = { + Initializer, + Constant, + Identity, + Ones, + STFT, + Zeros, + GlorotNormal, + GlorotUniform, + HeNormal, + HeUniform, + LecunNormal, + LecunUniform, + Orthogonal, + RandomNormal, + RandomUniform, + TruncatedNormal, + VarianceScaling, +} + +ALL_OBJECTS_DICT = {cls.__name__: cls for cls in ALL_OBJECTS} +ALL_OBJECTS_DICT.update( + {to_snake_case(cls.__name__): cls for cls in ALL_OBJECTS} +) +# Aliases +ALL_OBJECTS_DICT.update( + { + "IdentityInitializer": Identity, # For compatibility + "normal": RandomNormal, + "one": Ones, + "STFTInitializer": STFT, # For compatibility + "OrthogonalInitializer": Orthogonal, # For compatibility + "uniform": RandomUniform, + "zero": Zeros, + } +) + + +@keras_export("keras.initializers.serialize") +def serialize(initializer): + """Returns the initializer configuration as a Python dict.""" + return serialization_lib.serialize_keras_object(initializer) + + +@keras_export("keras.initializers.deserialize") +def deserialize(config, custom_objects=None): + """Returns a Keras initializer object via its configuration.""" + return serialization_lib.deserialize_keras_object( + config, + module_objects=ALL_OBJECTS_DICT, + custom_objects=custom_objects, + ) + + +@keras_export("keras.initializers.get") +def get(identifier): + """Retrieves a Keras initializer object via an identifier. + + The `identifier` may be the string name of a initializers function or class + (case-sensitively). + + >>> identifier = 'Ones' + >>> keras.initializers.get(identifier) + <...keras.initializers.initializers.Ones...> + + You can also specify `config` of the initializer to this function by passing + dict containing `class_name` and `config` as an identifier. Also note that + the `class_name` must map to a `Initializer` class. + + >>> cfg = {'class_name': 'Ones', 'config': {}} + >>> keras.initializers.get(cfg) + <...keras.initializers.initializers.Ones...> + + In the case that the `identifier` is a class, this method will return a new + instance of the class by its constructor. + + You may also pass a callable function with a signature that includes `shape` + and `dtype=None` as an identifier. + + >>> fn = lambda shape, dtype=None: ops.ones(shape, dtype) + >>> keras.initializers.get(fn) + at ...> + + Alternatively, you can pass a backend tensor or numpy array as the + `identifier` to define the initializer values directly. Note that when + calling the initializer, the specified `shape` argument must be the same as + the shape of the tensor. + + >>> tensor = ops.ones(shape=(5, 5)) + >>> keras.initializers.get(tensor) + .initialize_fn at ...> + + Args: + identifier: A string, dict, callable function, or tensor specifying + the initializer. If a string, it should be the name of an + initializer. If a dict, it should contain the configuration of an + initializer. Callable functions or predefined tensors are also + accepted. + + Returns: + Initializer instance base on the input identifier. + """ + if identifier is None: + return None + if isinstance(identifier, dict): + obj = deserialize(identifier) + elif isinstance(identifier, str): + config = {"class_name": str(identifier), "config": {}} + obj = deserialize(config) + elif ops.is_tensor(identifier) or isinstance( + identifier, (np.generic, np.ndarray) + ): + + def initialize_fn(shape, dtype=None): + dtype = backend.standardize_dtype(dtype) + if backend.standardize_shape(shape) != backend.standardize_shape( + identifier.shape + ): + raise ValueError( + f"Expected `shape` to be {identifier.shape} for direct " + f"tensor as initializer. Received shape={shape}" + ) + return ops.cast(identifier, dtype) + + obj = initialize_fn + else: + obj = identifier + + if callable(obj): + if inspect.isclass(obj): + obj = obj() + return obj + else: + raise ValueError( + f"Could not interpret initializer identifier: {identifier}" + ) diff --git a/keras/src/initializers/constant_initializers.py b/keras/src/initializers/constant_initializers.py new file mode 100644 index 000000000000..b80e2973d2f0 --- /dev/null +++ b/keras/src/initializers/constant_initializers.py @@ -0,0 +1,284 @@ +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.backend import standardize_dtype +from keras.src.initializers.initializer import Initializer +from keras.src.saving import serialization_lib +from keras.src.utils.module_utils import scipy + + +@keras_export(["keras.initializers.Constant", "keras.initializers.constant"]) +class Constant(Initializer): + """Initializer that generates tensors with constant values. + + Only scalar values are allowed. + The constant value provided must be convertible to the dtype requested + when calling the initializer. + + Examples: + + >>> # Standalone usage: + >>> initializer = Constant(10.) + >>> values = initializer(shape=(2, 2)) + + >>> # Usage in a Keras layer: + >>> initializer = Constant(10.) + >>> layer = Dense(3, kernel_initializer=initializer) + + Args: + value: A Python scalar. + """ + + def __init__(self, value=0.0): + self.value = value + + def __call__(self, shape, dtype=None): + dtype = standardize_dtype(dtype) + return ops.cast(self.value, dtype=dtype) * ops.ones( + shape=shape, dtype=dtype + ) + + def get_config(self): + return {"value": serialization_lib.serialize_keras_object(self.value)} + + @classmethod + def from_config(cls, config): + value = serialization_lib.deserialize_keras_object(config["value"]) + return cls(value) + + +@keras_export(["keras.initializers.Zeros", "keras.initializers.zeros"]) +class Zeros(Initializer): + """Initializer that generates tensors initialized to 0. + + Examples: + + >>> # Standalone usage: + >>> initializer = Zeros() + >>> values = initializer(shape=(2, 2)) + + >>> # Usage in a Keras layer: + >>> initializer = Zeros() + >>> layer = Dense(units=3, kernel_initializer=initializer) + """ + + def __call__(self, shape, dtype=None): + """Returns a tensor object initialized as specified by the initializer. + + Args: + shape: Shape of the tensor. + dtype: Optional dtype of the tensor. Only numeric or boolean dtypes + are supported. If not specified, `keras.backend.floatx()` + is used, which default to `float32` unless you configured it + otherwise (via `keras.backend.set_floatx(float_dtype)`). + """ + dtype = standardize_dtype(dtype) + return ops.zeros(shape, dtype=dtype) + + +@keras_export(["keras.initializers.Ones", "keras.initializers.ones"]) +class Ones(Initializer): + """Initializer that generates tensors initialized to 1. + + Also available via the shortcut function `ones`. + + Examples: + + >>> # Standalone usage: + >>> initializer = Ones() + >>> values = initializer(shape=(2, 2)) + + >>> # Usage in a Keras layer: + >>> initializer = Ones() + >>> layer = Dense(3, kernel_initializer=initializer) + """ + + def __call__(self, shape, dtype=None): + """Returns a tensor object initialized as specified by the initializer. + + Args: + shape: Shape of the tensor. + dtype: Optional dtype of the tensor. Only numeric or boolean dtypes + are supported. If not specified, `keras.backend.floatx()` + is used, which default to `float32` unless you configured it + otherwise (via `keras.backend.set_floatx(float_dtype)`). + """ + dtype = standardize_dtype(dtype) + return ops.ones(shape, dtype=dtype) + + +@keras_export( + [ + "keras.initializers.Identity", + "keras.initializers.identity", + "keras.initializers.IdentityInitializer", + ] +) +class Identity(Initializer): + """Initializer that generates the identity matrix. + + Only usable for generating 2D matrices. + + Examples: + + >>> # Standalone usage: + >>> initializer = Identity() + >>> values = initializer(shape=(2, 2)) + + >>> # Usage in a Keras layer: + >>> initializer = Identity() + >>> layer = Dense(3, kernel_initializer=initializer) + + Args: + gain: Multiplicative factor to apply to the identity matrix. + """ + + def __init__(self, gain=1.0): + self.gain = gain + + def __call__(self, shape, dtype=None): + """Returns a tensor object initialized as specified by the initializer. + + Args: + shape: Shape of the tensor. + dtype: Optional dtype of the tensor. Only numeric or boolean dtypes + are supported. If not specified, `keras.backend.floatx()` + is used, which default to `float32` unless you configured it + otherwise (via `keras.backend.set_floatx(float_dtype)`). + """ + if len(shape) != 2: + raise ValueError( + "Identity matrix initializer can only be used for 2D matrices. " + f"Received: shape={shape} of rank {len(shape)}." + ) + dtype = standardize_dtype(dtype) + return self.gain * ops.eye(*shape, dtype=dtype) + + +@keras_export( + [ + "keras.initializers.STFT", + "keras.initializers.stft", + "keras.initializers.STFTInitializer", + ] +) +class STFT(Initializer): + """Initializer of Conv kernels for Short-term Fourier Transformation (STFT). + + Since the formula involves complex numbers, this class compute either the + real or the imaginary components of the final output. + + Additionally, this initializer supports windowing functions across the time + dimension as commonly used in STFT. Windowing functions from the module + `scipy.signal.windows` are supported, including the common `hann` and + `hamming` windowing functions. This layer supports periodic windows and + scaling-based normalization. + + This is primarily intended for use in the `STFTSpectrogram` layer. + + Examples: + + >>> # Standalone usage: + >>> initializer = STFTInitializer("real", "hann", "density", False) + >>> values = initializer(shape=(128, 1, 513)) + + Args: + side: String, `"real"` or `"imag"` deciding if the kernel will compute + the real side or the imaginary side of the output. Defaults to + `"real"`. + window: String for the name of the windowing function in the + `scipy.signal.windows` module, or array_like for the window values, + or `None` for no windowing. + scaling: String, `"density"` or `"spectrum"` for scaling of the window + for normalization, either L2 or L1 normalization. + `None` for no scaling. + periodic: Boolean, if True, the window function will be treated as + periodic. Defaults to `False`. + """ + + def __init__( + self, side="real", window="hann", scaling="density", periodic=False + ): + if side not in ["real", "imag"]: + raise ValueError(f"side should be 'real' or 'imag', not {side}") + if isinstance(window, str): + # throws an exception for invalid window function + scipy.signal.get_window(window, 1) + if scaling is not None and scaling not in ["density", "spectrum"]: + raise ValueError( + "Scaling is invalid, it must be `None`, 'density' " + f"or 'spectrum'. Received scaling={scaling}" + ) + self.side = side + self.window = window + self.scaling = scaling + self.periodic = periodic + + def __call__(self, shape, dtype=None): + """Returns a tensor object initialized as specified by the initializer. + + The shape is assumed to be `(T, 1, F // 2 + 1)`, where `T` is the size + of the given window, and `F` is the number of frequency bands. Only half + the frequency bands are used, which is a common practice in STFT, + because the second half are the conjugates of the first half in + a reversed order. + + Args: + shape: Shape of the tensor. + dtype: Optional dtype of the tensor. Only numeric or boolean dtypes + are supported. If not specified, `keras.backend.floatx()` + is used, which default to `float32` unless you configured it + otherwise (via `keras.backend.set_floatx(float_dtype)`). + """ + dtype = standardize_dtype(dtype) + frame_length, input_channels, fft_length = shape + + win = None + scaling = 1 + if self.window is not None: + win = self.window + if isinstance(win, str): + # Using SciPy since it provides more windowing functions, + # easier to be compatible with multiple backends. + win = scipy.signal.get_window(win, frame_length, self.periodic) + win = ops.convert_to_tensor(win, dtype=dtype) + if len(win.shape) != 1 or win.shape[-1] != frame_length: + raise ValueError( + "The shape of `window` must be equal to [frame_length]." + f"Received: window shape={win.shape}" + ) + win = ops.reshape(win, [frame_length, 1, 1]) + if self.scaling == "density": + scaling = ops.sqrt(ops.sum(ops.square(win))) + elif self.scaling == "spectrum": + scaling = ops.sum(ops.abs(win)) + + _fft_length = (fft_length - 1) * 2 + freq = ops.divide( + ops.reshape( + ops.arange(fft_length, dtype=dtype), (1, 1, fft_length) + ), + _fft_length, + ) + time = ops.reshape( + ops.arange(frame_length, dtype=dtype), (frame_length, 1, 1) + ) + args = ops.multiply(ops.multiply(-2, time), freq) * ops.arccos( + ops.cast(-1, dtype) + ) + + if self.side == "real": + kernel = ops.cast(ops.cos(args), dtype) + else: + kernel = ops.cast(ops.sin(args), dtype) + + if win is not None: + kernel = ops.divide(ops.multiply(kernel, win), scaling) + return kernel + + def get_config(self): + return { + "side": self.side, + "window": self.window, + "periodic": self.periodic, + "scaling": self.scaling, + } diff --git a/keras/src/initializers/constant_initializers_test.py b/keras/src/initializers/constant_initializers_test.py new file mode 100644 index 000000000000..70c876cbd3bb --- /dev/null +++ b/keras/src/initializers/constant_initializers_test.py @@ -0,0 +1,138 @@ +import numpy as np +import scipy.signal + +from conftest import skip_if_backend +from keras.src import backend +from keras.src import initializers +from keras.src import testing + + +class ConstantInitializersTest(testing.TestCase): + def test_zeros_initializer(self): + shape = (3, 3) + + initializer = initializers.Zeros() + values = initializer(shape=shape) + self.assertEqual(values.shape, shape) + np_values = backend.convert_to_numpy(values) + self.assertAllClose(np_values, np.zeros(shape=shape)) + + self.run_class_serialization_test(initializer) + + def test_ones_initializer(self): + shape = (3, 3) + + initializer = initializers.Ones() + values = initializer(shape=shape) + self.assertEqual(values.shape, shape) + np_values = backend.convert_to_numpy(values) + self.assertAllClose(np_values, np.ones(shape=shape)) + + self.run_class_serialization_test(initializer) + + def test_constant_initializer(self): + shape = (3, 3) + constant_value = 6.0 + + initializer = initializers.Constant(value=constant_value) + values = initializer(shape=shape) + self.assertEqual(values.shape, shape) + np_values = backend.convert_to_numpy(values) + self.assertAllClose( + np_values, np.full(shape=shape, fill_value=constant_value) + ) + + self.run_class_serialization_test(initializer) + + def test_constant_initializer_array_value(self): + shape = (3, 3) + constant_value = np.random.random((3, 3)) + + initializer = initializers.Constant(value=constant_value) + values = initializer(shape=shape) + self.assertEqual(values.shape, shape) + np_values = backend.convert_to_numpy(values) + self.assertAllClose( + np_values, np.full(shape=shape, fill_value=constant_value) + ) + + self.run_class_serialization_test(initializer) + + @skip_if_backend("openvino", "openvino backend does not support `eye`") + def test_identity_initializer(self): + shape = (3, 3) + gain = 2 + + initializer = initializers.Identity(gain=gain) + values = initializer(shape=shape) + self.assertEqual(values.shape, shape) + np_values = backend.convert_to_numpy(values) + self.assertAllClose(np_values, np.eye(*shape) * gain) + + self.run_class_serialization_test(initializer) + + # Test compatible class_name + initializer = initializers.get("IdentityInitializer") + self.assertIsInstance(initializer, initializers.Identity) + + @skip_if_backend("openvino", "openvino backend does not support `arange`") + def test_stft_initializer(self): + shape = (256, 1, 513) + time_range = np.arange(256).reshape((-1, 1, 1)) + freq_range = (np.arange(513) / 1024.0).reshape((1, 1, -1)) + pi = np.arccos(np.float32(-1)) + args = -2 * pi * time_range * freq_range + tol_kwargs = {"atol": 1e-4, "rtol": 1e-6} + + initializer = initializers.STFT("real", None) + values = backend.convert_to_numpy(initializer(shape)) + self.assertAllClose(np.cos(args), values, atol=1e-4) + self.run_class_serialization_test(initializer) + + initializer = initializers.STFT( + "real", + "hamming", + None, + True, + ) + window = scipy.signal.windows.get_window("hamming", 256, True) + window = window.astype("float32").reshape((-1, 1, 1)) + values = backend.convert_to_numpy(initializer(shape, "float32")) + self.assertAllClose(np.cos(args) * window, values, **tol_kwargs) + self.run_class_serialization_test(initializer) + + initializer = initializers.STFT( + "imag", + "tukey", + "density", + False, + ) + window = scipy.signal.windows.get_window("tukey", 256, False) + window = window.astype("float32").reshape((-1, 1, 1)) + window = window / np.sqrt(np.sum(window**2)) + values = backend.convert_to_numpy(initializer(shape, "float32")) + self.assertAllClose(np.sin(args) * window, values, **tol_kwargs) + self.run_class_serialization_test(initializer) + + initializer = initializers.STFT( + "imag", + list(range(1, 257)), + "spectrum", + ) + window = np.arange(1, 257) + window = window.astype("float32").reshape((-1, 1, 1)) + window = window / np.sum(window) + values = backend.convert_to_numpy(initializer(shape, "float32")) + self.assertAllClose(np.sin(args) * window, values, **tol_kwargs) + self.run_class_serialization_test(initializer) + + with self.assertRaises(ValueError): + initializers.STFT("imaginary") + with self.assertRaises(ValueError): + initializers.STFT("real", scaling="l2") + with self.assertRaises(ValueError): + initializers.STFT("real", window="unknown") + + # Test compatible class_name + initializer = initializers.get("STFTInitializer") + self.assertIsInstance(initializer, initializers.STFT) diff --git a/keras/src/initializers/initializer.py b/keras/src/initializers/initializer.py new file mode 100644 index 000000000000..cef22f378c5c --- /dev/null +++ b/keras/src/initializers/initializer.py @@ -0,0 +1,84 @@ +from keras.src.api_export import keras_export + + +@keras_export(["keras.Initializer", "keras.initializers.Initializer"]) +class Initializer: + """Initializer base class: all Keras initializers inherit from this class. + + Initializers should implement a `__call__()` method with the following + signature: + + ```python + def __call__(self, shape, dtype=None, **kwargs): + # returns a tensor of shape `shape` and dtype `dtype` + # containing values drawn from a distribution of your choice. + ``` + + Optionally, you can also implement the method `get_config()` and the class + method `from_config` in order to support serialization, just like with + any Keras object. + + Here's a simple example: a random normal initializer. + + ```python + class ExampleRandomNormal(Initializer): + def __init__(self, mean, stddev): + self.mean = mean + self.stddev = stddev + + def __call__(self, shape, dtype=None, **kwargs): + return keras.random.normal( + shape, mean=self.mean, stddev=self.stddev, dtype=dtype + ) + + def get_config(self): # To support serialization + return {"mean": self.mean, "stddev": self.stddev} + ``` + + Note that we don't have to implement `from_config()` in the example above + since the constructor arguments of the class the keys in the config returned + by `get_config()` are the same. In this case, the default `from_config()` + works fine. + """ + + def __call__(self, shape, dtype=None): + """Returns a tensor object initialized as specified by the initializer. + + Args: + shape: Shape of the tensor. + dtype: Optional dtype of the tensor. + """ + raise NotImplementedError( + "Initializer subclasses must implement the `__call__()` method." + ) + + def get_config(self): + """Returns the initializer's configuration as a JSON-serializable dict. + + Returns: + A JSON-serializable Python dict. + """ + return {} + + @classmethod + def from_config(cls, config): + """Instantiates an initializer from a configuration dictionary. + + Example: + + ```python + initializer = RandomUniform(-1, 1) + config = initializer.get_config() + initializer = RandomUniform.from_config(config) + ``` + + Args: + config: A Python dictionary, the output of `get_config()`. + + Returns: + An `Initializer` instance. + """ + return cls(**config) + + def clone(self): + return self.__class__.from_config(self.get_config()) diff --git a/keras/src/initializers/random_initializers.py b/keras/src/initializers/random_initializers.py new file mode 100644 index 000000000000..ad1123e2a18f --- /dev/null +++ b/keras/src/initializers/random_initializers.py @@ -0,0 +1,715 @@ +import math + +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.backend import random +from keras.src.initializers.initializer import Initializer +from keras.src.saving import serialization_lib + + +class RandomInitializer(Initializer): + def __init__(self, seed=None): + self._init_seed = seed + if seed is None: + seed = random.make_default_seed() + elif isinstance(seed, dict): + seed = serialization_lib.deserialize_keras_object(seed) + elif not isinstance(seed, (int, random.SeedGenerator)): + raise ValueError( + "`seed` argument should be an instance of " + "`keras.random.SeedGenerator()` or an integer. " + f"Received: seed={seed}" + ) + self.seed = seed + + def get_config(self): + seed_config = serialization_lib.serialize_keras_object(self._init_seed) + return {"seed": seed_config} + + +@keras_export( + [ + "keras.initializers.RandomNormal", + "keras.initializers.random_normal", + ] +) +class RandomNormal(RandomInitializer): + """Random normal initializer. + + Draws samples from a normal distribution for given parameters. + + Examples: + + >>> # Standalone usage: + >>> initializer = RandomNormal(mean=0.0, stddev=1.0) + >>> values = initializer(shape=(2, 2)) + + >>> # Usage in a Keras layer: + >>> initializer = RandomNormal(mean=0.0, stddev=1.0) + >>> layer = Dense(3, kernel_initializer=initializer) + + Args: + mean: A python scalar or a scalar keras tensor. Mean of the random + values to generate. + stddev: A python scalar or a scalar keras tensor. Standard deviation of + the random values to generate. + seed: A Python integer or instance of + `keras.backend.SeedGenerator`. + Used to make the behavior of the initializer + deterministic. Note that an initializer seeded with an integer + or `None` (unseeded) will produce the same random values + across multiple calls. To get different random values + across multiple calls, use as seed an instance + of `keras.backend.SeedGenerator`. + """ + + def __init__(self, mean=0.0, stddev=0.05, seed=None): + self.mean = mean + self.stddev = stddev + super().__init__(seed=seed) + + def __call__(self, shape, dtype=None): + return random.normal( + shape=shape, + mean=self.mean, + stddev=self.stddev, + seed=self.seed, + dtype=dtype, + ) + + def get_config(self): + base_config = super().get_config() + config = {"mean": self.mean, "stddev": self.stddev} + return {**base_config, **config} + + +@keras_export( + [ + "keras.initializers.TruncatedNormal", + "keras.initializers.truncated_normal", + ] +) +class TruncatedNormal(RandomInitializer): + """Initializer that generates a truncated normal distribution. + + The values generated are similar to values from a + `RandomNormal` initializer, except that values more + than two standard deviations from the mean are + discarded and re-drawn. + + Examples: + + >>> # Standalone usage: + >>> initializer = TruncatedNormal(mean=0., stddev=1.) + >>> values = initializer(shape=(2, 2)) + + >>> # Usage in a Keras layer: + >>> initializer = TruncatedNormal(mean=0., stddev=1.) + >>> layer = Dense(3, kernel_initializer=initializer) + + Args: + mean: A python scalar or a scalar keras tensor. Mean of the random + values to generate. + stddev: A python scalar or a scalar keras tensor. Standard deviation of + the random values to generate. + seed: A Python integer or instance of + `keras.backend.SeedGenerator`. + Used to make the behavior of the initializer + deterministic. Note that an initializer seeded with an integer + or `None` (unseeded) will produce the same random values + across multiple calls. To get different random values + across multiple calls, use as seed an instance + of `keras.backend.SeedGenerator`. + """ + + def __init__(self, mean=0.0, stddev=0.05, seed=None): + self.mean = mean + self.stddev = stddev + super().__init__(seed=seed) + + def __call__(self, shape, dtype=None): + return random.truncated_normal( + shape=shape, + mean=self.mean, + stddev=self.stddev, + seed=self.seed, + dtype=dtype, + ) + + def get_config(self): + base_config = super().get_config() + config = {"mean": self.mean, "stddev": self.stddev} + return {**base_config, **config} + + +@keras_export( + [ + "keras.initializers.RandomUniform", + "keras.initializers.random_uniform", + ] +) +class RandomUniform(RandomInitializer): + """Random uniform initializer. + + Draws samples from a uniform distribution for given parameters. + + Examples: + + >>> # Standalone usage: + >>> initializer = RandomUniform(minval=0.0, maxval=1.0) + >>> values = initializer(shape=(2, 2)) + + >>> # Usage in a Keras layer: + >>> initializer = RandomUniform(minval=0.0, maxval=1.0) + >>> layer = Dense(3, kernel_initializer=initializer) + + Args: + minval: A python scalar or a scalar keras tensor. Lower bound of the + range of random values to generate (inclusive). + maxval: A python scalar or a scalar keras tensor. Upper bound of the + range of random values to generate (exclusive). + seed: A Python integer or instance of + `keras.backend.SeedGenerator`. + Used to make the behavior of the initializer + deterministic. Note that an initializer seeded with an integer + or `None` (unseeded) will produce the same random values + across multiple calls. To get different random values + across multiple calls, use as seed an instance + of `keras.backend.SeedGenerator`. + """ + + def __init__(self, minval=-0.05, maxval=0.05, seed=None): + self.minval = minval + self.maxval = maxval + super().__init__(seed=seed) + + def __call__(self, shape, dtype=None): + return random.uniform( + shape=shape, + minval=self.minval, + maxval=self.maxval, + seed=self.seed, + dtype=dtype, + ) + + def get_config(self): + base_config = super().get_config() + config = {"minval": self.minval, "maxval": self.maxval} + return {**base_config, **config} + + +@keras_export( + [ + "keras.initializers.VarianceScaling", + "keras.initializers.variance_scaling", + ] +) +class VarianceScaling(RandomInitializer): + """Initializer that adapts its scale to the shape of its input tensors. + + With `distribution="truncated_normal" or "untruncated_normal"`, samples are + drawn from a truncated/untruncated normal distribution with a mean of zero + and a standard deviation (after truncation, if used) `stddev = sqrt(scale / + n)`, where `n` is: + + - number of input units in the weight tensor, if `mode="fan_in"` + - number of output units, if `mode="fan_out"` + - average of the numbers of input and output units, if `mode="fan_avg"` + + With `distribution="uniform"`, samples are drawn from a uniform distribution + within `[-limit, limit]`, where `limit = sqrt(3 * scale / n)`. + + Examples: + + >>> # Standalone usage: + >>> initializer = VarianceScaling( + scale=0.1, mode='fan_in', distribution='uniform') + >>> values = initializer(shape=(2, 2)) + + >>> # Usage in a Keras layer: + >>> initializer = VarianceScaling( + scale=0.1, mode='fan_in', distribution='uniform') + >>> layer = Dense(3, kernel_initializer=initializer) + + Args: + scale: Scaling factor (positive float). + mode: One of `"fan_in"`, `"fan_out"`, `"fan_avg"`. + distribution: Random distribution to use. + One of `"truncated_normal"`, `"untruncated_normal"`, or `"uniform"`. + seed: A Python integer or instance of + `keras.backend.SeedGenerator`. + Used to make the behavior of the initializer + deterministic. Note that an initializer seeded with an integer + or `None` (unseeded) will produce the same random values + across multiple calls. To get different random values + across multiple calls, use as seed an instance + of `keras.backend.SeedGenerator`. + """ + + def __init__( + self, + scale=1.0, + mode="fan_in", + distribution="truncated_normal", + seed=None, + ): + if scale <= 0.0: + raise ValueError( + "Argument `scale` must be positive float. " + f"Received: scale={scale}" + ) + allowed_modes = {"fan_in", "fan_out", "fan_avg"} + if mode not in allowed_modes: + raise ValueError( + f"Invalid `mode` argument: {mode}. " + f"Please use one of {allowed_modes}" + ) + distribution = distribution.lower() + if distribution == "normal": + distribution = "truncated_normal" + allowed_distributions = { + "uniform", + "truncated_normal", + "untruncated_normal", + } + if distribution not in allowed_distributions: + raise ValueError( + f"Invalid `distribution` argument: {distribution}." + f"Please use one of {allowed_distributions}" + ) + self.scale = scale + self.mode = mode + self.distribution = distribution + super().__init__(seed=seed) + + def __call__(self, shape, dtype=None): + scale = self.scale + fan_in, fan_out = compute_fans(shape) + if self.mode == "fan_in": + scale /= max(1.0, fan_in) + elif self.mode == "fan_out": + scale /= max(1.0, fan_out) + else: + scale /= max(1.0, (fan_in + fan_out) / 2.0) + if self.distribution == "truncated_normal": + stddev = math.sqrt(scale) / 0.87962566103423978 + return random.truncated_normal( + shape, mean=0.0, stddev=stddev, dtype=dtype, seed=self.seed + ) + elif self.distribution == "untruncated_normal": + stddev = math.sqrt(scale) + return random.normal( + shape, mean=0.0, stddev=stddev, dtype=dtype, seed=self.seed + ) + else: + limit = math.sqrt(3.0 * scale) + return random.uniform( + shape, minval=-limit, maxval=limit, dtype=dtype, seed=self.seed + ) + + def get_config(self): + base_config = super().get_config() + config = { + "scale": self.scale, + "mode": self.mode, + "distribution": self.distribution, + } + return {**base_config, **config} + + +@keras_export( + [ + "keras.initializers.GlorotUniform", + "keras.initializers.glorot_uniform", + ] +) +class GlorotUniform(VarianceScaling): + """The Glorot uniform initializer, also called Xavier uniform initializer. + + Draws samples from a uniform distribution within `[-limit, limit]`, where + `limit = sqrt(6 / (fan_in + fan_out))` (`fan_in` is the number of input + units in the weight tensor and `fan_out` is the number of output units). + + Examples: + + >>> # Standalone usage: + >>> initializer = GlorotUniform() + >>> values = initializer(shape=(2, 2)) + + >>> # Usage in a Keras layer: + >>> initializer = GlorotUniform() + >>> layer = Dense(3, kernel_initializer=initializer) + + Args: + seed: A Python integer or instance of + `keras.backend.SeedGenerator`. + Used to make the behavior of the initializer + deterministic. Note that an initializer seeded with an integer + or `None` (unseeded) will produce the same random values + across multiple calls. To get different random values + across multiple calls, use as seed an instance + of `keras.backend.SeedGenerator`. + + Reference: + + - [Glorot et al., 2010](http://proceedings.mlr.press/v9/glorot10a.html) + """ + + def __init__(self, seed=None): + super().__init__( + scale=1.0, mode="fan_avg", distribution="uniform", seed=seed + ) + + def get_config(self): + return { + "seed": serialization_lib.serialize_keras_object(self._init_seed) + } + + +@keras_export( + [ + "keras.initializers.GlorotNormal", + "keras.initializers.glorot_normal", + ] +) +class GlorotNormal(VarianceScaling): + """The Glorot normal initializer, also called Xavier normal initializer. + + Draws samples from a truncated normal distribution centered on 0 with + `stddev = sqrt(2 / (fan_in + fan_out))` where `fan_in` is the number of + input units in the weight tensor and `fan_out` is the number of output units + in the weight tensor. + + Examples: + + >>> # Standalone usage: + >>> initializer = GlorotNormal() + >>> values = initializer(shape=(2, 2)) + + >>> # Usage in a Keras layer: + >>> initializer = GlorotNormal() + >>> layer = Dense(3, kernel_initializer=initializer) + + Args: + seed: A Python integer or instance of + `keras.backend.SeedGenerator`. + Used to make the behavior of the initializer + deterministic. Note that an initializer seeded with an integer + or `None` (unseeded) will produce the same random values + across multiple calls. To get different random values + across multiple calls, use as seed an instance + of `keras.backend.SeedGenerator`. + + Reference: + + - [Glorot et al., 2010](http://proceedings.mlr.press/v9/glorot10a.html) + """ + + def __init__(self, seed=None): + super().__init__( + scale=1.0, + mode="fan_avg", + distribution="truncated_normal", + seed=seed, + ) + + def get_config(self): + return { + "seed": serialization_lib.serialize_keras_object(self._init_seed) + } + + +@keras_export( + [ + "keras.initializers.LecunNormal", + "keras.initializers.lecun_normal", + ] +) +class LecunNormal(VarianceScaling): + """Lecun normal initializer. + + Initializers allow you to pre-specify an initialization strategy, encoded in + the Initializer object, without knowing the shape and dtype of the variable + being initialized. + + Draws samples from a truncated normal distribution centered on 0 with + `stddev = sqrt(1 / fan_in)` where `fan_in` is the number of input units in + the weight tensor. + + Examples: + + >>> # Standalone usage: + >>> initializer = LecunNormal() + >>> values = initializer(shape=(2, 2)) + + >>> # Usage in a Keras layer: + >>> initializer = LecunNormal() + >>> layer = Dense(3, kernel_initializer=initializer) + + Args: + seed: A Python integer or instance of + `keras.backend.SeedGenerator`. + Used to make the behavior of the initializer + deterministic. Note that an initializer seeded with an integer + or `None` (unseeded) will produce the same random values + across multiple calls. To get different random values + across multiple calls, use as seed an instance + of `keras.backend.SeedGenerator`. + + Reference: + + - [Klambauer et al., 2017](https://arxiv.org/abs/1706.02515) + """ + + def __init__(self, seed=None): + super().__init__( + scale=1.0, mode="fan_in", distribution="truncated_normal", seed=seed + ) + + def get_config(self): + return { + "seed": serialization_lib.serialize_keras_object(self._init_seed) + } + + +@keras_export( + [ + "keras.initializers.LecunUniform", + "keras.initializers.lecun_uniform", + ] +) +class LecunUniform(VarianceScaling): + """Lecun uniform initializer. + + Draws samples from a uniform distribution within `[-limit, limit]`, where + `limit = sqrt(3 / fan_in)` (`fan_in` is the number of input units in the + weight tensor). + + Examples: + + >>> # Standalone usage: + >>> initializer = LecunUniform() + >>> values = initializer(shape=(2, 2)) + + >>> # Usage in a Keras layer: + >>> initializer = LecunUniform() + >>> layer = Dense(3, kernel_initializer=initializer) + + Args: + seed: A Python integer or instance of + `keras.backend.SeedGenerator`. + Used to make the behavior of the initializer + deterministic. Note that an initializer seeded with an integer + or `None` (unseeded) will produce the same random values + across multiple calls. To get different random values + across multiple calls, use as seed an instance + of `keras.backend.SeedGenerator`. + + Reference: + + - [Klambauer et al., 2017](https://arxiv.org/abs/1706.02515) + """ + + def __init__(self, seed=None): + super().__init__( + scale=1.0, mode="fan_in", distribution="uniform", seed=seed + ) + + def get_config(self): + return { + "seed": serialization_lib.serialize_keras_object(self._init_seed) + } + + +@keras_export(["keras.initializers.HeNormal", "keras.initializers.he_normal"]) +class HeNormal(VarianceScaling): + """He normal initializer. + + It draws samples from a truncated normal distribution centered on 0 with + `stddev = sqrt(2 / fan_in)` where `fan_in` is the number of input units in + the weight tensor. + + Examples: + + >>> # Standalone usage: + >>> initializer = HeNormal() + >>> values = initializer(shape=(2, 2)) + + >>> # Usage in a Keras layer: + >>> initializer = HeNormal() + >>> layer = Dense(3, kernel_initializer=initializer) + + Args: + seed: A Python integer or instance of + `keras.backend.SeedGenerator`. + Used to make the behavior of the initializer + deterministic. Note that an initializer seeded with an integer + or `None` (unseeded) will produce the same random values + across multiple calls. To get different random values + across multiple calls, use as seed an instance + of `keras.backend.SeedGenerator`. + + Reference: + + - [He et al., 2015](https://arxiv.org/abs/1502.01852) + """ + + def __init__(self, seed=None): + super().__init__( + scale=2.0, mode="fan_in", distribution="truncated_normal", seed=seed + ) + + def get_config(self): + return { + "seed": serialization_lib.serialize_keras_object(self._init_seed) + } + + +@keras_export(["keras.initializers.HeUniform", "keras.initializers.he_uniform"]) +class HeUniform(VarianceScaling): + """He uniform variance scaling initializer. + + Draws samples from a uniform distribution within `[-limit, limit]`, where + `limit = sqrt(6 / fan_in)` (`fan_in` is the number of input units in the + weight tensor). + + Examples: + + >>> # Standalone usage: + >>> initializer = HeUniform() + >>> values = initializer(shape=(2, 2)) + + >>> # Usage in a Keras layer: + >>> initializer = HeUniform() + >>> layer = Dense(3, kernel_initializer=initializer) + + Args: + seed: A Python integer or instance of + `keras.backend.SeedGenerator`. + Used to make the behavior of the initializer + deterministic. Note that an initializer seeded with an integer + or `None` (unseeded) will produce the same random values + across multiple calls. To get different random values + across multiple calls, use as seed an instance + of `keras.backend.SeedGenerator`. + + Reference: + + - [He et al., 2015](https://arxiv.org/abs/1502.01852) + """ + + def __init__(self, seed=None): + super().__init__( + scale=2.0, mode="fan_in", distribution="uniform", seed=seed + ) + + def get_config(self): + return { + "seed": serialization_lib.serialize_keras_object(self._init_seed) + } + + +def compute_fans(shape): + """Computes the number of input and output units for a weight shape. + + Args: + shape: Integer shape tuple. + + Returns: + A tuple of integer scalars: `(fan_in, fan_out)`. + """ + shape = tuple(shape) + if len(shape) < 1: # Just to avoid errors for constants. + fan_in = fan_out = 1 + elif len(shape) == 1: + fan_in = fan_out = shape[0] + elif len(shape) == 2: + fan_in = shape[0] + fan_out = shape[1] + else: + # Assuming convolution kernels (2D, 3D, or more). + # kernel shape: (..., input_depth, depth) + receptive_field_size = 1 + for dim in shape[:-2]: + receptive_field_size *= dim + fan_in = shape[-2] * receptive_field_size + fan_out = shape[-1] * receptive_field_size + return int(fan_in), int(fan_out) + + +@keras_export( + [ + "keras.initializers.Orthogonal", + "keras.initializers.orthogonal", + "keras.initializers.OrthogonalInitializer", + ] +) +class Orthogonal(RandomInitializer): + """Initializer that generates an orthogonal matrix. + + If the shape of the tensor to initialize is two-dimensional, it is + initialized with an orthogonal matrix obtained from the QR decomposition of + a matrix of random numbers drawn from a normal distribution. If the matrix + has fewer rows than columns then the output will have orthogonal rows. + Otherwise, the output will have orthogonal columns. + + If the shape of the tensor to initialize is more than two-dimensional, + a matrix of shape `(shape[0] * ... * shape[n - 2], shape[n - 1])` + is initialized, where `n` is the length of the shape vector. + The matrix is subsequently reshaped to give a tensor of the desired shape. + + Examples: + + >>> # Standalone usage: + >>> initializer = keras.initializers.Orthogonal() + >>> values = initializer(shape=(2, 2)) + + >>> # Usage in a Keras layer: + >>> initializer = keras.initializers.Orthogonal() + >>> layer = keras.layers.Dense(3, kernel_initializer=initializer) + + Args: + gain: Multiplicative factor to apply to the orthogonal matrix. + seed: A Python integer. Used to make the behavior of the initializer + deterministic. + + Reference: + + - [Saxe et al., 2014](https://openreview.net/forum?id=_wzZwKpTDF_9C) + """ + + def __init__(self, gain=1.0, seed=None): + self.gain = gain + super().__init__(seed=seed) + + def __call__(self, shape, dtype=None): + if len(shape) < 2: + raise ValueError( + "The tensor to initialize must be " + "at least two-dimensional. Received: " + f"shape={shape} of rank {len(shape)}." + ) + + # Flatten the input shape with the last dimension remaining + # its original shape so it works for conv2d + num_rows = 1 + for dim in shape[:-1]: + num_rows *= dim + num_cols = shape[-1] + flat_shape = (max(num_cols, num_rows), min(num_cols, num_rows)) + + # Generate a random matrix + a = random.normal(flat_shape, seed=self.seed, dtype=dtype) + # Compute the qr factorization + q, r = ops.qr(a) + # Make Q uniform + d = ops.diag(r) + q *= ops.sign(d) + if num_rows < num_cols: + q = ops.transpose(q) + return self.gain * ops.reshape(q, shape) + + def get_config(self): + base_config = super().get_config() + config = {"gain": self.gain} + return {**base_config, **config} diff --git a/keras/src/initializers/random_initializers_test.py b/keras/src/initializers/random_initializers_test.py new file mode 100644 index 000000000000..aaad117acee0 --- /dev/null +++ b/keras/src/initializers/random_initializers_test.py @@ -0,0 +1,235 @@ +import numpy as np + +from conftest import skip_if_backend +from keras.src import backend +from keras.src import initializers +from keras.src import random +from keras.src import testing +from keras.src import utils + + +class RandomInitializersTest(testing.TestCase): + def test_random_normal(self): + utils.set_random_seed(1337) + shape = (25, 20) + mean = 0.0 + stddev = 1.0 + seed = 1234 + initializer = initializers.RandomNormal( + mean=mean, stddev=stddev, seed=seed + ) + values = initializer(shape=shape) + self.assertEqual(initializer.mean, mean) + self.assertEqual(initializer.stddev, stddev) + self.assertEqual(initializer.seed, seed) + self.assertEqual(values.shape, shape) + self.assertAllClose( + np.std(backend.convert_to_numpy(values)), stddev, atol=1e-1 + ) + + self.run_class_serialization_test(initializer) + + # Test that a fixed seed yields the same results each call. + initializer = initializers.RandomNormal( + mean=mean, stddev=stddev, seed=1337 + ) + values = initializer(shape=shape) + next_values = initializer(shape=shape) + self.assertAllClose(values, next_values) + + # Test that a SeedGenerator yields different results each call. + initializer = initializers.RandomNormal( + mean=mean, stddev=stddev, seed=backend.random.SeedGenerator(1337) + ) + values = initializer(shape=shape) + next_values = initializer(shape=shape) + self.assertNotAllClose(values, next_values) + + # Test serialization with SeedGenerator + initializer = initializers.RandomNormal( + mean=mean, stddev=stddev, seed=backend.random.SeedGenerator(1337) + ) + values = initializer(shape=shape) + + # Test that unseeded generator gets different results after cloning + initializer = initializers.RandomNormal( + mean=mean, stddev=stddev, seed=None + ) + values = initializer(shape=shape) + cloned_initializer = initializers.RandomNormal.from_config( + initializer.get_config() + ) + new_values = cloned_initializer(shape=shape) + self.assertNotAllClose(values, new_values) + + # Test that seeded generator gets same results after cloning + initializer = initializers.RandomNormal( + mean=mean, stddev=stddev, seed=1337 + ) + values = initializer(shape=shape) + cloned_initializer = initializers.RandomNormal.from_config( + initializer.get_config() + ) + new_values = cloned_initializer(shape=shape) + self.assertAllClose(values, new_values) + + def test_random_uniform(self): + shape = (5, 5) + minval = -1.0 + maxval = 1.0 + seed = 1234 + initializer = initializers.RandomUniform( + minval=minval, maxval=maxval, seed=seed + ) + values = initializer(shape=shape) + self.assertEqual(initializer.minval, minval) + self.assertEqual(initializer.maxval, maxval) + self.assertEqual(initializer.seed, seed) + self.assertEqual(values.shape, shape) + values = backend.convert_to_numpy(values) + self.assertGreaterEqual(np.min(values), minval) + self.assertLess(np.max(values), maxval) + + self.run_class_serialization_test(initializer) + + def test_variance_scaling(self): + utils.set_random_seed(1337) + shape = (25, 20) + scale = 2.0 + seed = 1234 + initializer = initializers.VarianceScaling( + scale=scale, seed=seed, mode="fan_in" + ) + values = initializer(shape=shape) + self.assertEqual(initializer.scale, scale) + self.assertEqual(initializer.seed, seed) + self.assertEqual(values.shape, shape) + self.assertAllClose( + np.std(backend.convert_to_numpy(values)), + np.sqrt(scale / 25), + atol=1e-1, + ) + self.run_class_serialization_test(initializer) + + initializer = initializers.VarianceScaling( + scale=scale, seed=seed, mode="fan_out" + ) + values = initializer(shape=shape) + self.assertEqual(initializer.scale, scale) + self.assertEqual(initializer.seed, seed) + self.assertEqual(values.shape, shape) + self.assertAllClose( + np.std(backend.convert_to_numpy(values)), + np.sqrt(scale / 20), + atol=1e-1, + ) + self.run_class_serialization_test(initializer) + + @skip_if_backend("openvino", "openvino backend does not support `qr`") + def test_orthogonal(self): + shape = (5, 5) + gain = 2.0 + seed = 1234 + initializer = initializers.Orthogonal(gain=gain, seed=seed) + values = initializer(shape=shape) + self.assertEqual(initializer.seed, seed) + self.assertEqual(initializer.gain, gain) + + self.assertEqual(values.shape, shape) + array = backend.convert_to_numpy(values) + # Making sure that the columns have gain * unit norm value + for column in array.T: + self.assertAlmostEqual(np.linalg.norm(column), gain * 1.0) + + # Making sure that each column is orthonormal to the other column + for i in range(array.shape[-1]): + for j in range(i + 1, array.shape[-1]): + self.assertAlmostEqual( + np.dot(array[..., i], array[..., j]), 0.0 + ) + + self.run_class_serialization_test(initializer) + + # Test compatible class_name + initializer = initializers.get("OrthogonalInitializer") + self.assertIsInstance(initializer, initializers.Orthogonal) + + def test_get_method(self): + obj = initializers.get("glorot_normal") + self.assertTrue(obj, initializers.GlorotNormal) + + obj = initializers.get(None) + self.assertEqual(obj, None) + + with self.assertRaises(ValueError): + initializers.get("typo") + + @skip_if_backend( + "openvino", "openvino backend does not support `uniform` with None seed" + ) + def test_get_method_with_tensor(self): + shape = (5, 5) + + # Test backend tensor + tensor = random.uniform(shape=shape) + initializer = initializers.get(tensor) + values = initializer(shape=shape) + self.assertAllClose(values, tensor) + + # Test numpy array + tensor = np.random.uniform(size=shape).astype("float32") + initializer = initializers.get(tensor) + values = initializer(shape=shape) + self.assertAllClose(values, tensor) + + # Test bad `shape` argument + with self.assertRaisesRegex(ValueError, r"Expected `shape` to be"): + initializer(shape=(10, 10)) + + def test_variance_scaling_invalid_scale(self): + seed = 1234 + + with self.assertRaisesRegex( + ValueError, "Argument `scale` must be positive float." + ): + initializers.VarianceScaling(scale=-1.0, seed=seed, mode="fan_in") + + def test_variance_scaling_invalid_mode(self): + scale = 2.0 + seed = 1234 + + with self.assertRaisesRegex(ValueError, "Invalid `mode` argument:"): + initializers.VarianceScaling( + scale=scale, seed=seed, mode="invalid_mode" + ) + + def test_variance_scaling_invalid_distribution(self): + scale = 2.0 + seed = 1234 + + with self.assertRaisesRegex( + ValueError, "Invalid `distribution` argument:" + ): + initializers.VarianceScaling( + scale=scale, + seed=seed, + mode="fan_in", + distribution="invalid_dist", + ) + + def test_serialization_with_seed_generator(self): + seed = random.SeedGenerator() + initializer = initializers.Orthogonal(seed=seed) + self.run_class_serialization_test(initializer) + + seed = random.SeedGenerator() + initializer = initializers.VarianceScaling(seed=seed) + self.run_class_serialization_test(initializer) + + seed = random.SeedGenerator() + initializer = initializers.RandomUniform(seed=seed) + self.run_class_serialization_test(initializer) + + seed = random.SeedGenerator() + initializer = initializers.RandomNormal(seed=seed) + self.run_class_serialization_test(initializer) diff --git a/keras/src/layers/__init__.py b/keras/src/layers/__init__.py new file mode 100644 index 000000000000..83a1f571251a --- /dev/null +++ b/keras/src/layers/__init__.py @@ -0,0 +1,252 @@ +from keras.src.api_export import keras_export +from keras.src.layers.activations.activation import Activation +from keras.src.layers.activations.elu import ELU +from keras.src.layers.activations.leaky_relu import LeakyReLU +from keras.src.layers.activations.prelu import PReLU +from keras.src.layers.activations.relu import ReLU +from keras.src.layers.activations.softmax import Softmax +from keras.src.layers.attention.additive_attention import AdditiveAttention +from keras.src.layers.attention.attention import Attention +from keras.src.layers.attention.grouped_query_attention import ( + GroupedQueryAttention, +) +from keras.src.layers.attention.multi_head_attention import MultiHeadAttention +from keras.src.layers.convolutional.conv1d import Conv1D +from keras.src.layers.convolutional.conv1d_transpose import Conv1DTranspose +from keras.src.layers.convolutional.conv2d import Conv2D +from keras.src.layers.convolutional.conv2d_transpose import Conv2DTranspose +from keras.src.layers.convolutional.conv3d import Conv3D +from keras.src.layers.convolutional.conv3d_transpose import Conv3DTranspose +from keras.src.layers.convolutional.depthwise_conv1d import DepthwiseConv1D +from keras.src.layers.convolutional.depthwise_conv2d import DepthwiseConv2D +from keras.src.layers.convolutional.separable_conv1d import SeparableConv1D +from keras.src.layers.convolutional.separable_conv2d import SeparableConv2D +from keras.src.layers.core.dense import Dense +from keras.src.layers.core.einsum_dense import EinsumDense +from keras.src.layers.core.embedding import Embedding +from keras.src.layers.core.identity import Identity +from keras.src.layers.core.input_layer import Input +from keras.src.layers.core.input_layer import InputLayer +from keras.src.layers.core.lambda_layer import Lambda +from keras.src.layers.core.masking import Masking +from keras.src.layers.core.wrapper import Wrapper +from keras.src.layers.input_spec import InputSpec +from keras.src.layers.layer import Layer +from keras.src.layers.merging.add import Add +from keras.src.layers.merging.add import add +from keras.src.layers.merging.average import Average +from keras.src.layers.merging.average import average +from keras.src.layers.merging.concatenate import Concatenate +from keras.src.layers.merging.concatenate import concatenate +from keras.src.layers.merging.dot import Dot +from keras.src.layers.merging.dot import dot +from keras.src.layers.merging.maximum import Maximum +from keras.src.layers.merging.maximum import maximum +from keras.src.layers.merging.minimum import Minimum +from keras.src.layers.merging.minimum import minimum +from keras.src.layers.merging.multiply import Multiply +from keras.src.layers.merging.multiply import multiply +from keras.src.layers.merging.subtract import Subtract +from keras.src.layers.merging.subtract import subtract +from keras.src.layers.normalization.batch_normalization import ( + BatchNormalization, +) +from keras.src.layers.normalization.group_normalization import ( + GroupNormalization, +) +from keras.src.layers.normalization.layer_normalization import ( + LayerNormalization, +) +from keras.src.layers.normalization.rms_normalization import RMSNormalization +from keras.src.layers.normalization.spectral_normalization import ( + SpectralNormalization, +) +from keras.src.layers.normalization.unit_normalization import UnitNormalization +from keras.src.layers.pooling.average_pooling1d import AveragePooling1D +from keras.src.layers.pooling.average_pooling2d import AveragePooling2D +from keras.src.layers.pooling.average_pooling3d import AveragePooling3D +from keras.src.layers.pooling.global_average_pooling1d import ( + GlobalAveragePooling1D, +) +from keras.src.layers.pooling.global_average_pooling2d import ( + GlobalAveragePooling2D, +) +from keras.src.layers.pooling.global_average_pooling3d import ( + GlobalAveragePooling3D, +) +from keras.src.layers.pooling.global_max_pooling1d import GlobalMaxPooling1D +from keras.src.layers.pooling.global_max_pooling2d import GlobalMaxPooling2D +from keras.src.layers.pooling.global_max_pooling3d import GlobalMaxPooling3D +from keras.src.layers.pooling.max_pooling1d import MaxPooling1D +from keras.src.layers.pooling.max_pooling2d import MaxPooling2D +from keras.src.layers.pooling.max_pooling3d import MaxPooling3D +from keras.src.layers.preprocessing.category_encoding import CategoryEncoding +from keras.src.layers.preprocessing.discretization import Discretization +from keras.src.layers.preprocessing.hashed_crossing import HashedCrossing +from keras.src.layers.preprocessing.hashing import Hashing +from keras.src.layers.preprocessing.image_preprocessing.aug_mix import AugMix +from keras.src.layers.preprocessing.image_preprocessing.auto_contrast import ( + AutoContrast, +) +from keras.src.layers.preprocessing.image_preprocessing.center_crop import ( + CenterCrop, +) +from keras.src.layers.preprocessing.image_preprocessing.cut_mix import CutMix +from keras.src.layers.preprocessing.image_preprocessing.equalization import ( + Equalization, +) +from keras.src.layers.preprocessing.image_preprocessing.max_num_bounding_box import ( + MaxNumBoundingBoxes, +) +from keras.src.layers.preprocessing.image_preprocessing.mix_up import MixUp +from keras.src.layers.preprocessing.image_preprocessing.rand_augment import ( + RandAugment, +) +from keras.src.layers.preprocessing.image_preprocessing.random_brightness import ( + RandomBrightness, +) +from keras.src.layers.preprocessing.image_preprocessing.random_color_degeneration import ( + RandomColorDegeneration, +) +from keras.src.layers.preprocessing.image_preprocessing.random_color_jitter import ( + RandomColorJitter, +) +from keras.src.layers.preprocessing.image_preprocessing.random_contrast import ( + RandomContrast, +) +from keras.src.layers.preprocessing.image_preprocessing.random_crop import ( + RandomCrop, +) +from keras.src.layers.preprocessing.image_preprocessing.random_elastic_transform import ( + RandomElasticTransform, +) +from keras.src.layers.preprocessing.image_preprocessing.random_erasing import ( + RandomErasing, +) +from keras.src.layers.preprocessing.image_preprocessing.random_flip import ( + RandomFlip, +) +from keras.src.layers.preprocessing.image_preprocessing.random_gaussian_blur import ( + RandomGaussianBlur, +) +from keras.src.layers.preprocessing.image_preprocessing.random_grayscale import ( + RandomGrayscale, +) +from keras.src.layers.preprocessing.image_preprocessing.random_hue import ( + RandomHue, +) +from keras.src.layers.preprocessing.image_preprocessing.random_invert import ( + RandomInvert, +) +from keras.src.layers.preprocessing.image_preprocessing.random_perspective import ( + RandomPerspective, +) +from keras.src.layers.preprocessing.image_preprocessing.random_posterization import ( + RandomPosterization, +) +from keras.src.layers.preprocessing.image_preprocessing.random_rotation import ( + RandomRotation, +) +from keras.src.layers.preprocessing.image_preprocessing.random_saturation import ( + RandomSaturation, +) +from keras.src.layers.preprocessing.image_preprocessing.random_sharpness import ( + RandomSharpness, +) +from keras.src.layers.preprocessing.image_preprocessing.random_shear import ( + RandomShear, +) +from keras.src.layers.preprocessing.image_preprocessing.random_translation import ( + RandomTranslation, +) +from keras.src.layers.preprocessing.image_preprocessing.random_zoom import ( + RandomZoom, +) +from keras.src.layers.preprocessing.image_preprocessing.resizing import Resizing +from keras.src.layers.preprocessing.image_preprocessing.solarization import ( + Solarization, +) +from keras.src.layers.preprocessing.index_lookup import IndexLookup +from keras.src.layers.preprocessing.integer_lookup import IntegerLookup +from keras.src.layers.preprocessing.mel_spectrogram import MelSpectrogram +from keras.src.layers.preprocessing.normalization import Normalization +from keras.src.layers.preprocessing.pipeline import Pipeline +from keras.src.layers.preprocessing.rescaling import Rescaling +from keras.src.layers.preprocessing.stft_spectrogram import STFTSpectrogram +from keras.src.layers.preprocessing.string_lookup import StringLookup +from keras.src.layers.preprocessing.text_vectorization import TextVectorization +from keras.src.layers.regularization.activity_regularization import ( + ActivityRegularization, +) +from keras.src.layers.regularization.alpha_dropout import AlphaDropout +from keras.src.layers.regularization.dropout import Dropout +from keras.src.layers.regularization.gaussian_dropout import GaussianDropout +from keras.src.layers.regularization.gaussian_noise import GaussianNoise +from keras.src.layers.regularization.spatial_dropout import SpatialDropout1D +from keras.src.layers.regularization.spatial_dropout import SpatialDropout2D +from keras.src.layers.regularization.spatial_dropout import SpatialDropout3D +from keras.src.layers.reshaping.cropping1d import Cropping1D +from keras.src.layers.reshaping.cropping2d import Cropping2D +from keras.src.layers.reshaping.cropping3d import Cropping3D +from keras.src.layers.reshaping.flatten import Flatten +from keras.src.layers.reshaping.permute import Permute +from keras.src.layers.reshaping.repeat_vector import RepeatVector +from keras.src.layers.reshaping.reshape import Reshape +from keras.src.layers.reshaping.up_sampling1d import UpSampling1D +from keras.src.layers.reshaping.up_sampling2d import UpSampling2D +from keras.src.layers.reshaping.up_sampling3d import UpSampling3D +from keras.src.layers.reshaping.zero_padding1d import ZeroPadding1D +from keras.src.layers.reshaping.zero_padding2d import ZeroPadding2D +from keras.src.layers.reshaping.zero_padding3d import ZeroPadding3D +from keras.src.layers.rnn.bidirectional import Bidirectional +from keras.src.layers.rnn.conv_lstm1d import ConvLSTM1D +from keras.src.layers.rnn.conv_lstm2d import ConvLSTM2D +from keras.src.layers.rnn.conv_lstm3d import ConvLSTM3D +from keras.src.layers.rnn.gru import GRU +from keras.src.layers.rnn.gru import GRUCell +from keras.src.layers.rnn.lstm import LSTM +from keras.src.layers.rnn.lstm import LSTMCell +from keras.src.layers.rnn.rnn import RNN +from keras.src.layers.rnn.simple_rnn import SimpleRNN +from keras.src.layers.rnn.simple_rnn import SimpleRNNCell +from keras.src.layers.rnn.stacked_rnn_cells import StackedRNNCells +from keras.src.layers.rnn.time_distributed import TimeDistributed +from keras.src.saving import serialization_lib + + +@keras_export("keras.layers.serialize") +def serialize(layer): + """Returns the layer configuration as a Python dict. + + Args: + layer: A `keras.layers.Layer` instance to serialize. + + Returns: + Python dict which contains the configuration of the layer. + """ + return serialization_lib.serialize_keras_object(layer) + + +@keras_export("keras.layers.deserialize") +def deserialize(config, custom_objects=None): + """Returns a Keras layer object via its configuration. + + Args: + config: A python dict containing a serialized layer configuration. + custom_objects: Optional dictionary mapping names (strings) to custom + objects (classes and functions) to be considered during + deserialization. + + Returns: + A Keras layer instance. + """ + obj = serialization_lib.deserialize_keras_object( + config, + custom_objects=custom_objects, + ) + if not isinstance(obj, Layer): + raise ValueError( + "`keras.layers.deserialize` was passed a `config` object that is " + f"not a `keras.layers.Layer`. Received: {config}" + ) + return obj diff --git a/keras/src/layers/activations/__init__.py b/keras/src/layers/activations/__init__.py new file mode 100644 index 000000000000..009ce976c51b --- /dev/null +++ b/keras/src/layers/activations/__init__.py @@ -0,0 +1,5 @@ +from keras.src.layers.activations.elu import ELU +from keras.src.layers.activations.leaky_relu import LeakyReLU +from keras.src.layers.activations.prelu import PReLU +from keras.src.layers.activations.relu import ReLU +from keras.src.layers.activations.softmax import Softmax diff --git a/keras/src/layers/activations/activation.py b/keras/src/layers/activations/activation.py new file mode 100644 index 000000000000..16b6a9748d95 --- /dev/null +++ b/keras/src/layers/activations/activation.py @@ -0,0 +1,41 @@ +from keras.src import activations +from keras.src.api_export import keras_export +from keras.src.layers.layer import Layer + + +@keras_export("keras.layers.Activation") +class Activation(Layer): + """Applies an activation function to an output. + + Args: + activation: Activation function. It could be a callable, or the name of + an activation from the `keras.activations` namespace. + **kwargs: Base layer keyword arguments, such as `name` and `dtype`. + + Example: + + >>> layer = keras.layers.Activation('relu') + >>> layer(np.array([-3.0, -1.0, 0.0, 2.0])) + [0.0, 0.0, 0.0, 2.0] + >>> layer = keras.layers.Activation(keras.activations.relu) + >>> layer(np.array([-3.0, -1.0, 0.0, 2.0])) + [0.0, 0.0, 0.0, 2.0] + """ + + def __init__(self, activation, **kwargs): + super().__init__(**kwargs) + self.supports_masking = True + self.activation = activations.get(activation) + + self._build_at_init() + + def call(self, inputs): + return self.activation(inputs) + + def compute_output_shape(self, input_shape): + return input_shape + + def get_config(self): + config = {"activation": activations.serialize(self.activation)} + base_config = super().get_config() + return {**base_config, **config} diff --git a/keras/src/layers/activations/activation_test.py b/keras/src/layers/activations/activation_test.py new file mode 100644 index 000000000000..82400648351f --- /dev/null +++ b/keras/src/layers/activations/activation_test.py @@ -0,0 +1,38 @@ +import pytest + +from keras.src import activations +from keras.src import layers +from keras.src import testing + + +class ActivationTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_activation_basics(self): + self.run_layer_test( + layers.Activation, + init_kwargs={ + "activation": "relu", + }, + input_shape=(2, 3), + expected_output_shape=(2, 3), + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=True, + assert_built_after_instantiation=True, + ) + self.run_layer_test( + layers.Activation, + init_kwargs={ + "activation": activations.gelu, + }, + input_shape=(2, 2), + expected_output_shape=(2, 2), + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=True, + assert_built_after_instantiation=True, + ) diff --git a/keras/src/layers/activations/elu.py b/keras/src/layers/activations/elu.py new file mode 100644 index 000000000000..5a63ee8e8e32 --- /dev/null +++ b/keras/src/layers/activations/elu.py @@ -0,0 +1,33 @@ +from keras.src import activations +from keras.src.api_export import keras_export +from keras.src.layers.layer import Layer + + +@keras_export("keras.layers.ELU") +class ELU(Layer): + """Applies an Exponential Linear Unit function to an output. + + Formula: + + ``` + f(x) = alpha * (exp(x) - 1.) for x < 0 + f(x) = x for x >= 0 + ``` + + Args: + alpha: float, slope of negative section. Defaults to `1.0`. + **kwargs: Base layer keyword arguments, such as `name` and `dtype`. + """ + + def __init__(self, alpha=1.0, **kwargs): + super().__init__(**kwargs) + self.alpha = alpha + self.supports_masking = True + + self._build_at_init() + + def call(self, inputs): + return activations.elu(inputs, alpha=self.alpha) + + def compute_output_shape(self, input_shape): + return input_shape diff --git a/keras/src/layers/activations/elu_test.py b/keras/src/layers/activations/elu_test.py new file mode 100644 index 000000000000..de797e968e42 --- /dev/null +++ b/keras/src/layers/activations/elu_test.py @@ -0,0 +1,32 @@ +import numpy as np +import pytest + +from keras.src import testing +from keras.src.layers.activations import elu + + +class ELUTest(testing.TestCase): + def test_config(self): + elu_layer = elu.ELU() + self.run_class_serialization_test(elu_layer) + + @pytest.mark.requires_trainable_backend + def test_elu(self): + self.run_layer_test( + elu.ELU, + init_kwargs={}, + input_shape=(2, 3, 4), + supports_masking=True, + assert_built_after_instantiation=True, + ) + + def test_correctness(self): + def np_elu(x, alpha=1.0): + return (x > 0) * x + (x <= 0) * alpha * (np.exp(x) - 1) + + x = np.random.random((2, 2, 5)) + elu_layer = elu.ELU() + self.assertAllClose(elu_layer(x), np_elu(x)) + + elu_layer = elu.ELU(alpha=0.7) + self.assertAllClose(elu_layer(x), np_elu(x, alpha=0.7)) diff --git a/keras/src/layers/activations/leaky_relu.py b/keras/src/layers/activations/leaky_relu.py new file mode 100644 index 000000000000..3b5602e0dbb7 --- /dev/null +++ b/keras/src/layers/activations/leaky_relu.py @@ -0,0 +1,67 @@ +import warnings + +from keras.src import activations +from keras.src.api_export import keras_export +from keras.src.layers.layer import Layer + + +@keras_export("keras.layers.LeakyReLU") +class LeakyReLU(Layer): + """Leaky version of a Rectified Linear Unit activation layer. + + This layer allows a small gradient when the unit is not active. + + Formula: + + ``` python + f(x) = alpha * x if x < 0 + f(x) = x if x >= 0 + ``` + + Example: + + ``` python + leaky_relu_layer = LeakyReLU(negative_slope=0.5) + input = np.array([-10, -5, 0.0, 5, 10]) + result = leaky_relu_layer(input) + # result = [-5. , -2.5, 0. , 5. , 10.] + ``` + + Args: + negative_slope: Float >= 0.0. Negative slope coefficient. + Defaults to `0.3`. + **kwargs: Base layer keyword arguments, such as + `name` and `dtype`. + + """ + + def __init__(self, negative_slope=0.3, **kwargs): + if "alpha" in kwargs: + negative_slope = kwargs.pop("alpha") + warnings.warn( + "Argument `alpha` is deprecated. Use `negative_slope` instead." + ) + super().__init__(**kwargs) + if negative_slope is None or negative_slope < 0: + raise ValueError( + "The negative_slope value of a Leaky ReLU layer " + "cannot be None or negative value. Expected a float." + f" Received: negative_slope={negative_slope}" + ) + self.negative_slope = negative_slope + self.supports_masking = True + + self._build_at_init() + + def call(self, inputs): + return activations.leaky_relu( + inputs, negative_slope=self.negative_slope + ) + + def get_config(self): + config = super().get_config() + config.update({"negative_slope": self.negative_slope}) + return config + + def compute_output_shape(self, input_shape): + return input_shape diff --git a/keras/src/layers/activations/leaky_relu_test.py b/keras/src/layers/activations/leaky_relu_test.py new file mode 100644 index 000000000000..2b06b17b9b6e --- /dev/null +++ b/keras/src/layers/activations/leaky_relu_test.py @@ -0,0 +1,38 @@ +import numpy as np +import pytest + +from keras.src import testing +from keras.src.layers.activations import leaky_relu + + +class LeakyReLUTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_leaky_relu(self): + self.run_layer_test( + leaky_relu.LeakyReLU, + init_kwargs={ + "negative_slope": 1, + }, + input_shape=(2, 3, 4), + supports_masking=True, + assert_built_after_instantiation=True, + ) + + def test_leaky_relu_correctness(self): + leaky_relu_layer = leaky_relu.LeakyReLU(negative_slope=0.5) + input = np.array([-10, -5, 0.0, 5, 10]) + expected_output = np.array([-5.0, -2.5, 0.0, 5.0, 10.0]) + result = leaky_relu_layer(input) + self.assertAllClose(result, expected_output) + + def test_invalid_usage(self): + with self.assertRaisesRegex( + ValueError, + "The negative_slope value of a Leaky ReLU layer cannot be None", + ): + self.run_layer_test( + leaky_relu.LeakyReLU, + init_kwargs={"negative_slope": None}, + input_shape=(2, 3, 4), + supports_masking=True, + ) diff --git a/keras/src/layers/activations/prelu.py b/keras/src/layers/activations/prelu.py new file mode 100644 index 000000000000..d4a054248c8d --- /dev/null +++ b/keras/src/layers/activations/prelu.py @@ -0,0 +1,98 @@ +from keras.src import activations +from keras.src import constraints +from keras.src import initializers +from keras.src import regularizers +from keras.src.api_export import keras_export +from keras.src.layers.input_spec import InputSpec +from keras.src.layers.layer import Layer + + +@keras_export("keras.layers.PReLU") +class PReLU(Layer): + """Parametric Rectified Linear Unit activation layer. + + Formula: + ``` python + f(x) = alpha * x for x < 0 + f(x) = x for x >= 0 + ``` + where `alpha` is a learned array with the same shape as x. + + Args: + alpha_initializer: Initializer function for the weights. + alpha_regularizer: Regularizer for the weights. + alpha_constraint: Constraint for the weights. + shared_axes: The axes along which to share learnable parameters for the + activation function. For example, if the incoming feature maps are + from a 2D convolution with output shape + `(batch, height, width, channels)`, and you wish to share parameters + across space so that each filter only has one set of parameters, + set `shared_axes=[1, 2]`. + **kwargs: Base layer keyword arguments, such as `name` and `dtype`. + """ + + def __init__( + self, + alpha_initializer="Zeros", + alpha_regularizer=None, + alpha_constraint=None, + shared_axes=None, + **kwargs, + ): + super().__init__(**kwargs) + self.supports_masking = True + self.alpha_initializer = initializers.get(alpha_initializer) + self.alpha_regularizer = regularizers.get(alpha_regularizer) + self.alpha_constraint = constraints.get(alpha_constraint) + if shared_axes is None: + self.shared_axes = None + elif not isinstance(shared_axes, (list, tuple)): + self.shared_axes = [shared_axes] + else: + self.shared_axes = list(shared_axes) + + def build(self, input_shape): + param_shape = list(input_shape[1:]) + if self.shared_axes is not None: + for i in self.shared_axes: + param_shape[i - 1] = 1 + self.alpha = self.add_weight( + shape=param_shape, + name="alpha", + initializer=self.alpha_initializer, + regularizer=self.alpha_regularizer, + constraint=self.alpha_constraint, + ) + # Set input spec + axes = {} + if self.shared_axes: + for i in range(1, len(input_shape)): + if i not in self.shared_axes: + axes[i] = input_shape[i] + self.input_spec = InputSpec(ndim=len(input_shape), axes=axes) + + def call(self, inputs): + pos = activations.relu(inputs) + neg = -self.alpha * activations.relu(-inputs) + return pos + neg + + def get_config(self): + config = super().get_config() + config.update( + { + "alpha_initializer": initializers.serialize( + self.alpha_initializer + ), + "alpha_regularizer": regularizers.serialize( + self.alpha_regularizer + ), + "alpha_constraint": constraints.serialize( + self.alpha_constraint + ), + "shared_axes": self.shared_axes, + } + ) + return config + + def compute_output_shape(self, input_shape): + return input_shape diff --git a/keras/src/layers/activations/prelu_test.py b/keras/src/layers/activations/prelu_test.py new file mode 100644 index 000000000000..63b4aee20617 --- /dev/null +++ b/keras/src/layers/activations/prelu_test.py @@ -0,0 +1,39 @@ +import numpy as np +import pytest + +from keras.src import testing +from keras.src.layers.activations import prelu + + +class PReLUTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_prelu(self): + self.run_layer_test( + prelu.PReLU, + init_kwargs={ + "alpha_initializer": "zeros", + "alpha_regularizer": "L1", + "alpha_constraint": "MaxNorm", + "shared_axes": 1, + }, + input_shape=(2, 3, 4), + supports_masking=True, + ) + + def test_prelu_correctness(self): + def np_prelu(x, alpha): + return (x > 0) * x + (x <= 0) * alpha * x + + inputs = np.random.randn(2, 10, 5, 3) + prelu_layer = prelu.PReLU( + alpha_initializer="glorot_uniform", + alpha_regularizer="l1", + alpha_constraint="non_neg", + shared_axes=(1, 2), + ) + prelu_layer.build(inputs.shape) + + weights = np.random.random((1, 1, 3)) + prelu_layer.alpha.assign(weights) + ref_out = np_prelu(inputs, weights) + self.assertAllClose(prelu_layer(inputs), ref_out) diff --git a/keras/src/layers/activations/relu.py b/keras/src/layers/activations/relu.py new file mode 100644 index 000000000000..72629ce32d98 --- /dev/null +++ b/keras/src/layers/activations/relu.py @@ -0,0 +1,87 @@ +from keras.src import activations +from keras.src.api_export import keras_export +from keras.src.layers.layer import Layer + + +@keras_export("keras.layers.ReLU") +class ReLU(Layer): + """Rectified Linear Unit activation function layer. + + Formula: + ``` python + f(x) = max(x,0) + f(x) = max_value if x >= max_value + f(x) = x if threshold <= x < max_value + f(x) = negative_slope * (x - threshold) otherwise + ``` + + Example: + ``` python + relu_layer = keras.layers.ReLU( + max_value=10, + negative_slope=0.5, + threshold=0, + ) + input = np.array([-10, -5, 0.0, 5, 10]) + result = relu_layer(input) + # result = [-5. , -2.5, 0. , 5. , 10.] + ``` + + Args: + max_value: Float >= 0. Maximum activation value. None means unlimited. + Defaults to `None`. + negative_slope: Float >= 0. Negative slope coefficient. + Defaults to `0.0`. + threshold: Float >= 0. Threshold value for thresholded activation. + Defaults to `0.0`. + **kwargs: Base layer keyword arguments, such as `name` and `dtype`. + """ + + def __init__( + self, max_value=None, negative_slope=0.0, threshold=0.0, **kwargs + ): + super().__init__(**kwargs) + if max_value is not None and max_value < 0.0: + raise ValueError( + "max_value of a ReLU layer cannot be a negative " + f"value. Received: max_value={max_value}" + ) + if negative_slope is None or negative_slope < 0.0: + raise ValueError( + "negative_slope of a ReLU layer cannot be a negative " + f"value. Received: negative_slope={negative_slope}" + ) + if threshold is None or threshold < 0.0: + raise ValueError( + "threshold of a ReLU layer cannot be a negative " + f"value. Received: threshold={threshold}" + ) + + self.max_value = max_value + self.negative_slope = negative_slope + self.threshold = threshold + self.supports_masking = True + + self._build_at_init() + + def call(self, inputs): + return activations.relu( + inputs, + negative_slope=self.negative_slope, + max_value=self.max_value, + threshold=self.threshold, + ) + + def get_config(self): + config = super().get_config() + config.update( + { + "max_value": self.max_value, + "negative_slope": self.negative_slope, + "threshold": self.threshold, + } + ) + return config + + def compute_output_shape(self, input_shape): + return input_shape diff --git a/keras/src/layers/activations/relu_test.py b/keras/src/layers/activations/relu_test.py new file mode 100644 index 000000000000..781d816ae9ad --- /dev/null +++ b/keras/src/layers/activations/relu_test.py @@ -0,0 +1,87 @@ +import numpy as np +import pytest + +from keras.src import testing +from keras.src.layers.activations import relu + + +class ReLUTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_relu(self): + self.run_layer_test( + relu.ReLU, + init_kwargs={ + "max_value": 10, + "negative_slope": 1, + "threshold": 0.5, + }, + input_shape=(2, 3, 4), + supports_masking=True, + assert_built_after_instantiation=True, + ) + + def test_normal_relu_correctness(self): + relu_layer = relu.ReLU(max_value=10, negative_slope=0.0, threshold=0) + input = np.array([-10, -5, 0.0, 5, 10]) + expected_output = np.array([0.0, 0.0, 0.0, 5.0, 10.0]) + result = relu_layer(input) + self.assertAllClose(result, expected_output) + + def test_leaky_relu_correctness(self): + relu_layer = relu.ReLU(max_value=10, negative_slope=0.5, threshold=0) + input = np.array([-10, -5, 0.0, 5, 10]) + expected_output = np.array([-5.0, -2.5, 0.0, 5.0, 10.0]) + result = relu_layer(input) + self.assertAllClose(result, expected_output) + + def test_threshold_relu_correctness(self): + relu_layer = relu.ReLU(max_value=8, negative_slope=0.0, threshold=5) + input = np.array([6.0, 7.0, 0.0, 5, 10]) + expected_output = np.array([6.0, 7.0, 0.0, 0.0, 8.0]) + result = relu_layer(input) + self.assertAllClose(result, expected_output) + + def test_invalid_usage(self): + with self.assertRaisesRegex( + ValueError, + "max_value of a ReLU layer cannot be a negative value", + ): + self.run_layer_test( + relu.ReLU, + init_kwargs={ + "max_value": -10, + "negative_slope": 1, + "threshold": 0.5, + }, + input_shape=(2, 3, 4), + supports_masking=True, + ) + + with self.assertRaisesRegex( + ValueError, + "negative_slope of a ReLU layer cannot be a negative value", + ): + self.run_layer_test( + relu.ReLU, + init_kwargs={ + "max_value": 10, + "negative_slope": -10, + "threshold": 0.5, + }, + input_shape=(2, 3, 4), + supports_masking=True, + ) + + with self.assertRaisesRegex( + ValueError, "threshold of a ReLU layer cannot be a negative value" + ): + self.run_layer_test( + relu.ReLU, + init_kwargs={ + "max_value": 10, + "negative_slope": 1, + "threshold": -10, + }, + input_shape=(2, 3, 4), + supports_masking=True, + ) diff --git a/keras/src/layers/activations/softmax.py b/keras/src/layers/activations/softmax.py new file mode 100644 index 000000000000..8660877977ec --- /dev/null +++ b/keras/src/layers/activations/softmax.py @@ -0,0 +1,87 @@ +from keras.src import activations +from keras.src import backend +from keras.src.api_export import keras_export +from keras.src.layers.layer import Layer + + +def _large_negative_number(dtype): + """Return a Large negative number based on dtype.""" + if backend.standardize_dtype(dtype) == "float16": + return -3e4 + return -1e9 + + +@keras_export("keras.layers.Softmax") +class Softmax(Layer): + """Softmax activation layer. + + Formula: + ``` python + exp_x = exp(x - max(x)) + f(x) = exp_x / sum(exp_x) + ``` + + Example: + >>> softmax_layer = keras.layers.Softmax() + >>> input = np.array([1.0, 2.0, 1.0]) + >>> result = softmax_layer(input) + >>> result + [0.21194157, 0.5761169, 0.21194157] + + + Args: + axis: Integer, or list of Integers, axis along which the softmax + normalization is applied. + **kwargs: Base layer keyword arguments, such as `name` and `dtype`. + + Call arguments: + inputs: The inputs (logits) to the softmax layer. + mask: A boolean mask of the same shape as `inputs`. The mask + specifies 1 to keep and 0 to mask. Defaults to `None`. + + Returns: + Softmaxed output with the same shape as `inputs`. + """ + + def __init__(self, axis=-1, **kwargs): + super().__init__(**kwargs) + self.axis = axis + self.supports_masking = True + + self._build_at_init() + + def call(self, inputs, mask=None): + if mask is not None: + adder = ( + 1.0 - backend.cast(mask, inputs.dtype) + ) * _large_negative_number(inputs.dtype) + inputs += adder + if isinstance(self.axis, (tuple, list)): + if len(self.axis) > 1: + outputs = backend.numpy.exp( + inputs + - backend.math.logsumexp( + inputs, axis=self.axis, keepdims=True + ) + ) + else: + outputs = activations.softmax(inputs, axis=self.axis[0]) + else: + outputs = activations.softmax(inputs, axis=self.axis) + + if mask is not None: + # Apply the mask to the softmax output to ensure that masked + # values are set to 0 in case the entire axis is masked. + outputs = backend.numpy.multiply( + outputs, backend.cast(mask, outputs.dtype) + ) + + return outputs + + def get_config(self): + config = super().get_config() + config.update({"axis": self.axis}) + return config + + def compute_output_shape(self, input_shape): + return input_shape diff --git a/keras/src/layers/activations/softmax_test.py b/keras/src/layers/activations/softmax_test.py new file mode 100644 index 000000000000..e5428854451e --- /dev/null +++ b/keras/src/layers/activations/softmax_test.py @@ -0,0 +1,88 @@ +import numpy as np +import pytest + +from keras.src import testing +from keras.src.layers.activations import softmax + + +class SoftmaxTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_softmax(self): + self.run_layer_test( + softmax.Softmax, + init_kwargs={}, + input_shape=(2, 3, 4), + supports_masking=True, + assert_built_after_instantiation=True, + ) + + def test_softmax_correctness(self): + softmax_layer = softmax.Softmax() + input = np.array([[1.0, 2.0, 1.0], [1.0, 2.0, 1.0]]) + expected_output = np.array( + [ + [0.21194157, 0.5761169, 0.21194157], + [0.21194157, 0.5761169, 0.21194157], + ] + ) + result = softmax_layer(input) + self.assertAllClose(result, expected_output) + + def test_softmax_correctness_with_mask(self): + softmax_layer = softmax.Softmax(axis=(1, 0)) + input = np.array([[1.0, 2.0, 1.0], [1.0, 2.0, 1.0]]) + mask = np.array([[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]) + expected_output = np.array( + [[0.21194154, 0.0, 0.21194154], [0.0, 0.57611686, 0.0]] + ) + result = softmax_layer(input, mask=mask) + self.assertAllClose(result, expected_output) + + def test_softmax_correctness_with_axis(self): + softmax_layer = softmax.Softmax(axis=(1)) + input = np.array([[1.0, 2.0, 1.0], [1.0, 2.0, 1.0]]) + expected_output = np.array( + [ + [0.21194157, 0.5761169, 0.21194157], + [0.21194157, 0.5761169, 0.21194157], + ] + ) + result = softmax_layer(input) + self.assertAllClose(result, expected_output) + + def test_softmax_masked_values_are_zero_including_fully_masked(self): + """ + Tests softmax with mask on default axis (-1). + Ensures output is 0 where mask is False. + Includes a row where all elements are masked. + """ + softmax_layer = softmax.Softmax() # Default axis = -1 + + input = np.array( + [ + [1.0, 2.0, 5.0, 1.0], + [1.0, 1.0, 1.0, 1.0], + [3.0, 1.0, 2.0, 4.0], + ], + dtype=np.float32, + ) + mask = np.array( + [ + [True, True, False, False], # Partially masked + [False, False, False, False], # Fully masked + [True, True, True, True], # Not masked + ], + dtype=bool, + ) + + expected_output = np.array( + [ + [0.268941, 0.731059, 0.0, 0.0], # last two masked + [0.0, 0.0, 0.0, 0.0], # Fully masked row should be all zeros + [0.236883, 0.032059, 0.087144, 0.643914], + ] + ) + + result = softmax_layer(input, mask=mask) + + self.assertAllClose(result, expected_output) diff --git a/keras/src/layers/attention/__init__.py b/keras/src/layers/attention/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/keras/src/layers/attention/additive_attention.py b/keras/src/layers/attention/additive_attention.py new file mode 100644 index 000000000000..6dac093d09d7 --- /dev/null +++ b/keras/src/layers/attention/additive_attention.py @@ -0,0 +1,102 @@ +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.layers.attention.attention import Attention + + +@keras_export("keras.layers.AdditiveAttention") +class AdditiveAttention(Attention): + """Additive attention layer, a.k.a. Bahdanau-style attention. + + Inputs are a list with 2 or 3 elements: + 1. A `query` tensor of shape `(batch_size, Tq, dim)`. + 2. A `value` tensor of shape `(batch_size, Tv, dim)`. + 3. A optional `key` tensor of shape `(batch_size, Tv, dim)`. If none + supplied, `value` will be used as `key`. + + The calculation follows the steps: + 1. Calculate attention scores using `query` and `key` with shape + `(batch_size, Tq, Tv)` as a non-linear sum + `scores = reduce_sum(tanh(query + key), axis=-1)`. + 2. Use scores to calculate a softmax distribution with shape + `(batch_size, Tq, Tv)`. + 3. Use the softmax distribution to create a linear combination of `value` + with shape `(batch_size, Tq, dim)`. + + Args: + use_scale: If `True`, will create a scalar variable to scale the + attention scores. + dropout: Float between 0 and 1. Fraction of the units to drop for the + attention scores. Defaults to `0.0`. + + Call arguments: + inputs: List of the following tensors: + - `query`: Query tensor of shape `(batch_size, Tq, dim)`. + - `value`: Value tensor of shape `(batch_size, Tv, dim)`. + - `key`: Optional key tensor of shape `(batch_size, Tv, dim)`. If + not given, will use `value` for both `key` and `value`, which is + the most common case. + mask: List of the following tensors: + - `query_mask`: A boolean mask tensor of shape `(batch_size, Tq)`. + If given, the output will be zero at the positions where + `mask==False`. + - `value_mask`: A boolean mask tensor of shape `(batch_size, Tv)`. + If given, will apply the mask such that values at positions + where `mask==False` do not contribute to the result. + return_attention_scores: bool, it `True`, returns the attention scores + (after masking and softmax) as an additional output argument. + training: Python boolean indicating whether the layer should behave in + training mode (adding dropout) or in inference mode (no dropout). + use_causal_mask: Boolean. Set to `True` for decoder self-attention. Adds + a mask such that position `i` cannot attend to positions `j > i`. + This prevents the flow of information from the future towards the + past. Defaults to `False`. + + Output: + Attention outputs of shape `(batch_size, Tq, dim)`. + (Optional) Attention scores after masking and softmax with shape + `(batch_size, Tq, Tv)`. + """ + + def __init__( + self, + use_scale=True, + dropout=0.0, + **kwargs, + ): + super().__init__(use_scale=use_scale, dropout=dropout, **kwargs) + + def build(self, input_shape): + self._validate_inputs(input_shape) + dim = input_shape[0][-1] + self.scale = None + if self.use_scale: + self.scale = self.add_weight( + name="scale", + shape=[dim], + initializer="glorot_uniform", + dtype=self.dtype, + trainable=True, + ) + + def _calculate_scores(self, query, key): + """Calculates attention scores as a nonlinear sum of query and key. + + Args: + query: Query tensor of shape `(batch_size, Tq, dim)`. + key: Key tensor of shape `(batch_size, Tv, dim)`. + + Returns: + Tensor of shape `(batch_size, Tq, Tv)`. + """ + # Reshape tensors to enable broadcasting. + # Reshape into [batch_size, Tq, 1, dim]. + q_reshaped = ops.expand_dims(query, axis=-2) + # Reshape into [batch_size, 1, Tv, dim]. + k_reshaped = ops.expand_dims(key, axis=-3) + scale = self.scale if self.use_scale else 1.0 + return ops.sum(scale * ops.tanh(q_reshaped + k_reshaped), axis=-1) + + def get_config(self): + base_config = super().get_config() + del base_config["score_mode"] + return base_config diff --git a/keras/src/layers/attention/additive_attention_test.py b/keras/src/layers/attention/additive_attention_test.py new file mode 100644 index 000000000000..51092c6c4918 --- /dev/null +++ b/keras/src/layers/attention/additive_attention_test.py @@ -0,0 +1,86 @@ +import numpy as np + +from keras.src import layers +from keras.src import testing + + +class AdditiveAttentionTest(testing.TestCase): + def test_attention_basics(self): + # No scale + self.run_layer_test( + layers.AdditiveAttention, + init_kwargs={ + "use_scale": True, + "dropout": 0.5, + }, + input_shape=[(2, 3, 4), (2, 4, 4)], + expected_output_shape=(2, 3, 4), + expected_num_trainable_weights=1, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=1, + expected_num_losses=0, + supports_masking=True, + run_training_check=False, + ) + # With scale. + self.run_layer_test( + layers.AdditiveAttention, + init_kwargs={ + "use_scale": False, + "dropout": 0.5, + }, + input_shape=[(2, 3, 4), (2, 4, 4)], + expected_output_shape=(2, 3, 4), + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=1, + expected_num_losses=0, + supports_masking=True, + run_training_check=False, + ) + + def test_attention_correctness(self): + query = np.array([[[1.0, 0.0], [0.0, 1.0]]]) + key = np.array([[[0.0, 1.0], [1.0, 0.0]]]) + value = np.array([[[1.0, 2.0], [3.0, 4.0]]]) + + layer = layers.AdditiveAttention(use_scale=False) + output, scores = layer( + [query, value, key], + return_attention_scores=True, + ) + self.assertAllClose( + output, [[[1.727, 2.727], [2.272, 3.272]]], atol=1e-3 + ) + self.assertAllClose( + scores, [[[0.636, 0.363], [0.363, 0.636]]], atol=1e-3 + ) + + def test_attention_with_mask(self): + layer = layers.AdditiveAttention(use_scale=False) + query = np.array([[[1.0, 0.0], [0.0, 1.0]]]) + value = np.array([[[1.0, 1.0], [1.0, 1.0]]]) + query_mask = np.array([[True, False]]) + value_mask = np.array([[True, False]]) + output, scores = layer( + [query, value], + mask=[query_mask, value_mask], + return_attention_scores=True, + ) + self.assertAllClose(output, [[[1.0, 1.0], [0.0, 0.0]]]) + self.assertAllClose(scores, [[[1.0, 0.0], [1.0, 0.0]]]) + + def test_attention_errors(self): + layer = layers.AdditiveAttention() + tensor = np.array([[[1.0, 1.0], [1.0, 1.0]]]) + with self.assertRaisesRegex(ValueError, "must be called on a list"): + layer(tensor) + + with self.assertRaisesRegex(ValueError, "length 2 or 3"): + layer([tensor, tensor, tensor, tensor]) + + with self.assertRaisesRegex(ValueError, "layer mask must be a list"): + layer([tensor, tensor], mask=tensor) + + with self.assertRaisesRegex(ValueError, "length 2 or 3"): + layer([tensor, tensor], mask=[tensor]) diff --git a/keras/src/layers/attention/attention.py b/keras/src/layers/attention/attention.py new file mode 100644 index 000000000000..04e3f399c5e5 --- /dev/null +++ b/keras/src/layers/attention/attention.py @@ -0,0 +1,331 @@ +from keras.src import backend +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.backend import KerasTensor +from keras.src.layers.layer import Layer + + +@keras_export("keras.layers.Attention") +class Attention(Layer): + """Dot-product attention layer, a.k.a. Luong-style attention. + + Inputs are a list with 2 or 3 elements: + 1. A `query` tensor of shape `(batch_size, Tq, dim)`. + 2. A `value` tensor of shape `(batch_size, Tv, dim)`. + 3. A optional `key` tensor of shape `(batch_size, Tv, dim)`. If none + supplied, `value` will be used as a `key`. + + The calculation follows the steps: + 1. Calculate attention scores using `query` and `key` with shape + `(batch_size, Tq, Tv)`. + 2. Use scores to calculate a softmax distribution with shape + `(batch_size, Tq, Tv)`. + 3. Use the softmax distribution to create a linear combination of `value` + with shape `(batch_size, Tq, dim)`. + + Args: + use_scale: If `True`, will create a scalar variable to scale the + attention scores. + dropout: Float between 0 and 1. Fraction of the units to drop for the + attention scores. Defaults to `0.0`. + seed: A Python integer to use as random seed in case of `dropout`. + score_mode: Function to use to compute attention scores, one of + `{"dot", "concat"}`. `"dot"` refers to the dot product between the + query and key vectors. `"concat"` refers to the hyperbolic tangent + of the concatenation of the `query` and `key` vectors. + + Call arguments: + inputs: List of the following tensors: + - `query`: Query tensor of shape `(batch_size, Tq, dim)`. + - `value`: Value tensor of shape `(batch_size, Tv, dim)`. + - `key`: Optional key tensor of shape `(batch_size, Tv, dim)`. If + not given, will use `value` for both `key` and `value`, which is + the most common case. + mask: List of the following tensors: + - `query_mask`: A boolean mask tensor of shape `(batch_size, Tq)`. + If given, the output will be zero at the positions where + `mask==False`. + - `value_mask`: A boolean mask tensor of shape `(batch_size, Tv)`. + If given, will apply the mask such that values at positions + where `mask==False` do not contribute to the result. + return_attention_scores: bool, it `True`, returns the attention scores + (after masking and softmax) as an additional output argument. + training: Python boolean indicating whether the layer should behave in + training mode (adding dropout) or in inference mode (no dropout). + use_causal_mask: Boolean. Set to `True` for decoder self-attention. Adds + a mask such that position `i` cannot attend to positions `j > i`. + This prevents the flow of information from the future towards the + past. Defaults to `False`. + + Output: + Attention outputs of shape `(batch_size, Tq, dim)`. + (Optional) Attention scores after masking and softmax with shape + `(batch_size, Tq, Tv)`. + """ + + def __init__( + self, + use_scale=False, + score_mode="dot", + dropout=0.0, + seed=None, + **kwargs, + ): + super().__init__(**kwargs) + self.use_scale = use_scale + self.score_mode = score_mode + self.dropout = dropout + if self.dropout > 0: + self.seed_generator = backend.random.SeedGenerator(seed=seed) + + if self.score_mode not in ["dot", "concat"]: + raise ValueError( + "Invalid value for argument score_mode. " + "Expected one of {'dot', 'concat'}. " + f"Received: score_mode={score_mode}" + ) + + self._return_attention_scores = False + + def build(self, input_shape): + self._validate_inputs(input_shape) + self.scale = None + self.concat_score_weight = None + if self.use_scale: + self.scale = self.add_weight( + name="scale", + shape=(), + initializer="ones", + dtype=self.dtype, + trainable=True, + ) + if self.score_mode == "concat": + self.concat_score_weight = self.add_weight( + name="concat_score_weight", + shape=(), + initializer="ones", + dtype=self.dtype, + trainable=True, + ) + + def _calculate_scores(self, query, key): + """Calculates attention scores as a query-key dot product. + + Args: + query: Query tensor of shape `(batch_size, Tq, dim)`. + key: Key tensor of shape `(batch_size, Tv, dim)`. + + Returns: + Tensor of shape `(batch_size, Tq, Tv)`. + """ + if self.score_mode == "dot": + scores = ops.matmul(query, ops.transpose(key, axes=[0, 2, 1])) + if self.scale is not None: + scores = ops.multiply(scores, self.scale) + elif self.score_mode == "concat": + # Reshape tensors to enable broadcasting. + # Reshape into [batch_size, Tq, 1, dim]. + q_reshaped = ops.expand_dims(query, axis=-2) + # Reshape into [batch_size, 1, Tv, dim]. + k_reshaped = ops.expand_dims(key, axis=-3) + if self.scale is not None: + scores = self.concat_score_weight * ops.sum( + ops.tanh(self.scale * (q_reshaped + k_reshaped)), axis=-1 + ) + else: + scores = self.concat_score_weight * ops.sum( + ops.tanh(q_reshaped + k_reshaped), axis=-1 + ) + else: + raise ValueError("scores not computed") + + return scores + + def _apply_scores(self, scores, value, scores_mask=None, training=False): + """Applies attention scores to the given value tensor. + + To use this method in your attention layer, follow the steps: + + * Use `query` tensor of shape `(batch_size, Tq)` and `key` tensor of + shape `(batch_size, Tv)` to calculate the attention `scores`. + * Pass `scores` and `value` tensors to this method. The method applies + `scores_mask`, calculates + `attention_distribution = softmax(scores)`, then returns + `matmul(attention_distribution, value). + * Apply `query_mask` and return the result. + + Args: + scores: Scores float tensor of shape `(batch_size, Tq, Tv)`. + value: Value tensor of shape `(batch_size, Tv, dim)`. + scores_mask: A boolean mask tensor of shape `(batch_size, 1, Tv)` + or `(batch_size, Tq, Tv)`. If given, scores at positions where + `scores_mask==False` do not contribute to the result. It must + contain at least one `True` value in each line along the last + dimension. + training: Python boolean indicating whether the layer should behave + in training mode (adding dropout) or in inference mode + (no dropout). + + Returns: + Tensor of shape `(batch_size, Tq, dim)`. + Attention scores after masking and softmax with shape + `(batch_size, Tq, Tv)`. + """ + if scores_mask is not None: + padding_mask = ops.logical_not(scores_mask) + # Bias so padding positions do not contribute to attention + # distribution. Note 65504. is the max float16 value. + max_value = 65504.0 if scores.dtype == "float16" else 1.0e9 + if len(padding_mask.shape) == 2: + padding_mask = ops.expand_dims(padding_mask, axis=-2) + scores -= max_value * ops.cast(padding_mask, dtype=scores.dtype) + + weights = ops.softmax(scores, axis=-1) + if training and self.dropout > 0: + weights = backend.random.dropout( + weights, + self.dropout, + seed=self.seed_generator, + ) + return ops.matmul(weights, value), weights + + def _calculate_score_mask(self, scores, v_mask, use_causal_mask): + if use_causal_mask: + # Creates a lower triangular mask, so position i cannot attend to + # positions j > i. This prevents the flow of information from the + # future into the past. + score_shape = ops.shape(scores) + # causal_mask_shape = [1, Tq, Tv]. + mask_shape = (1, score_shape[-2], score_shape[-1]) + ones_mask = ops.ones(shape=mask_shape, dtype="int32") + row_index = ops.cumsum(ones_mask, axis=-2) + col_index = ops.cumsum(ones_mask, axis=-1) + causal_mask = ops.greater_equal(row_index, col_index) + + if v_mask is not None: + # Mask of shape [batch_size, 1, Tv]. + v_mask = ops.expand_dims(v_mask, axis=-2) + return ops.logical_and(v_mask, causal_mask) + return causal_mask + else: + # If not using causal mask, return the value mask as is, + # or None if the value mask is not provided. + return v_mask + + def call( + self, + inputs, + mask=None, + training=False, + return_attention_scores=False, + use_causal_mask=False, + ): + self._validate_inputs(inputs=inputs, mask=mask) + self._return_attention_scores = return_attention_scores + q = inputs[0] + v = inputs[1] + k = inputs[2] if len(inputs) > 2 else v + q_mask = mask[0] if mask else None + v_mask = mask[1] if mask else None + scores = self._calculate_scores(query=q, key=k) + scores_mask = self._calculate_score_mask( + scores, v_mask, use_causal_mask + ) + attention_output, attention_scores = self._apply_scores( + scores=scores, value=v, scores_mask=scores_mask, training=training + ) + if q_mask is not None: + # Mask of shape [batch_size, Tq, 1]. + q_mask = ops.expand_dims(q_mask, axis=-1) + attention_output *= ops.cast(q_mask, dtype=attention_output.dtype) + if return_attention_scores: + return (attention_output, attention_scores) + else: + return attention_output + + def compute_mask(self, inputs, mask=None): + self._validate_inputs(inputs=inputs, mask=mask) + if mask is None or mask[0] is None: + return None + return ops.convert_to_tensor(mask[0]) + + def compute_output_shape(self, input_shape): + query_shape, value_shape, key_shape = input_shape + if key_shape is None: + key_shape = value_shape + + output_shape = (*query_shape[:-1], value_shape[-1]) + if self._return_attention_scores: + scores_shape = (query_shape[0], query_shape[1], key_shape[1]) + return output_shape, scores_shape + return output_shape + + def compute_output_spec( + self, + inputs, + mask=None, + return_attention_scores=False, + training=None, + use_causal_mask=False, + ): + # Validate and unpack inputs + self._validate_inputs(inputs, mask) + query = inputs[0] + value = inputs[1] + key = inputs[2] if len(inputs) > 2 else value + + # Compute primary output shape + output_shape = self.compute_output_shape( + [query.shape, value.shape, key.shape] + ) + output_spec = KerasTensor(output_shape, dtype=self.compute_dtype) + + # Handle attention scores if requested + if self._return_attention_scores or return_attention_scores: + scores_shape = ( + query.shape[0], + query.shape[1], + key.shape[1], + ) # (batch_size, Tq, Tv) + attention_scores_spec = KerasTensor( + scores_shape, dtype=self.compute_dtype + ) + return (output_spec, attention_scores_spec) + + return output_spec + + def _validate_inputs(self, inputs, mask=None): + """Validates arguments of the call method.""" + class_name = self.__class__.__name__ + if not isinstance(inputs, list): + raise ValueError( + f"{class_name} layer must be called on a list of inputs, " + "namely [query, value] or [query, value, key]. " + f"Received: inputs={inputs}." + ) + if len(inputs) < 2 or len(inputs) > 3: + raise ValueError( + f"{class_name} layer accepts inputs list of length 2 or 3, " + "namely [query, value] or [query, value, key]. " + f"Received length: {len(inputs)}." + ) + if mask is not None: + if not isinstance(mask, list): + raise ValueError( + f"{class_name} layer mask must be a list, " + f"namely [query_mask, value_mask]. Received: mask={mask}." + ) + if len(mask) < 2 or len(mask) > 3: + raise ValueError( + f"{class_name} layer accepts mask list of length 2 or 3. " + f"Received: inputs={inputs}, mask={mask}." + ) + + def get_config(self): + base_config = super().get_config() + config = { + "use_scale": self.use_scale, + "score_mode": self.score_mode, + "dropout": self.dropout, + } + return {**base_config, **config} diff --git a/keras/src/layers/attention/attention_test.py b/keras/src/layers/attention/attention_test.py new file mode 100644 index 000000000000..805314010996 --- /dev/null +++ b/keras/src/layers/attention/attention_test.py @@ -0,0 +1,448 @@ +import numpy as np + +from keras.src import layers +from keras.src import ops +from keras.src import testing + + +class AttentionTest(testing.TestCase): + def test_attention_basics(self): + # No scale, no concat. + self.run_layer_test( + layers.Attention, + init_kwargs={ + "score_mode": "dot", + "dropout": 0.5, + }, + input_shape=[(2, 3, 4), (2, 4, 4)], + expected_output_shape=(2, 3, 4), + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=1, + expected_num_losses=0, + supports_masking=True, + run_training_check=False, + ) + # Scale and concat. + self.run_layer_test( + layers.Attention, + init_kwargs={ + "use_scale": True, + "score_mode": "concat", + "dropout": 0.5, + }, + input_shape=[(2, 3, 4), (2, 4, 4)], + expected_output_shape=(2, 3, 4), + expected_num_trainable_weights=2, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=1, + expected_num_losses=0, + supports_masking=True, + run_training_check=False, + ) + + def test_attention_correctness(self): + query = np.array([[[1.0, 0.0], [0.0, 1.0]]]) + key = np.array([[[0.0, 1.0], [1.0, 0.0]]]) + value = np.array([[[1.0, 2.0], [3.0, 4.0]]]) + + # Dot. + layer = layers.Attention(score_mode="dot") + output, scores = layer( + [query, value, key], + return_attention_scores=True, + ) + self.assertAllClose( + output, [[[2.462, 3.462], [1.538, 2.538]]], atol=1e-3 + ) + self.assertAllClose( + scores, [[[0.269, 0.731], [0.731, 0.269]]], atol=1e-3 + ) + + # Concat. + layer = layers.Attention(score_mode="concat") + output, scores = layer( + [query, value, key], + return_attention_scores=True, + ) + self.assertAllClose( + output, [[[1.727, 2.727], [2.272, 3.272]]], atol=1e-3 + ) + self.assertAllClose( + scores, [[[0.636, 0.363], [0.363, 0.636]]], atol=1e-3 + ) + + def test_attention_with_mask(self): + layer = layers.Attention() + query = np.array([[[1.0, 0.0], [0.0, 1.0]]]) + value = np.array([[[1.0, 1.0], [1.0, 1.0]]]) + query_mask = np.array([[True, False]]) + value_mask = np.array([[True, False]]) + output, scores = layer( + [query, value], + mask=[query_mask, value_mask], + return_attention_scores=True, + ) + self.assertAllClose(output, [[[1.0, 1.0], [0.0, 0.0]]]) + self.assertAllClose(scores, [[[1.0, 0.0], [1.0, 0.0]]]) + + def test_attention_2D_mask_shape_mismatch(self): + layer = layers.Attention() + batch_size, Tq, Tv, dim = 2, 3, 4, 5 + query = np.random.random((batch_size, Tq, dim)).astype(np.float32) + value = np.random.random((batch_size, Tv, dim)).astype(np.float32) + query_mask = np.array([[True, False, True], [True, False, True]]) + value_mask = np.array( + [[True, False, True, True], [True, False, True, True]] + ) + output, scores = layer( + [query, value], + mask=[query_mask, value_mask], + return_attention_scores=True, + ) + self.assertEqual(output.shape, (batch_size, Tq, dim)) + self.assertEqual(scores.shape, (batch_size, Tq, Tv)) + + def test_attention_errors(self): + layer = layers.Attention() + tensor = np.array([[[1.0, 1.0], [1.0, 1.0]]]) + with self.assertRaisesRegex(ValueError, "must be called on a list"): + layer(tensor) + + with self.assertRaisesRegex(ValueError, "length 2 or 3"): + layer([tensor, tensor, tensor, tensor]) + + with self.assertRaisesRegex(ValueError, "layer mask must be a list"): + layer([tensor, tensor], mask=tensor) + + with self.assertRaisesRegex(ValueError, "length 2 or 3"): + layer([tensor, tensor], mask=[tensor]) + + def test_attention_with_dropout(self): + query = np.array([[[1.0, 0.0], [0.0, 1.0]]]) + value = np.array([[[1.0, 1.0], [1.0, 1.0]]]) + layer_with_dropout = layers.Attention(dropout=0.2) + layer_without_dropout = layers.Attention() + + output1, scores1 = layer_with_dropout( + [query, value], return_attention_scores=True, training=True + ) + output2, scores2 = layer_without_dropout( + [query, value], return_attention_scores=True, training=True + ) + self.assertNotAllClose(output1, output2) + self.assertNotAllClose(scores1, scores2) + + def test_attention_invalid_score_mode(self): + with self.assertRaisesRegex( + ValueError, + "Invalid value for argument score_mode. " + "Expected one of {'dot', 'concat'}", + ): + layers.Attention(score_mode="invalid_mode") + + def test_attention_calculate_scores_with_scale(self): + query = np.random.random((2, 3, 4)) + key = np.random.random((2, 4, 4)) + layer = layers.Attention(use_scale=True, score_mode="dot") + layer.build(input_shape=[(2, 3, 4), (2, 4, 4)]) + expected_scores = np.matmul(query, key.transpose((0, 2, 1))) + expected_scores *= layer.scale.numpy() + actual_scores = layer._calculate_scores(query, key) + self.assertAllClose(actual_scores, expected_scores) + + def test_attention_calculate_score_mask_no_causal_no_vmask(self): + scores = np.random.random((2, 3, 4)) + layer = layers.Attention() + mask = layer._calculate_score_mask( + scores, v_mask=None, use_causal_mask=False + ) + self.assertIsNone( + mask, + "Mask should be None when no causal mask and no value mask " + "are used", + ) + + def test_attention_calculate_score_mask_with_causal_no_vmask(self): + scores = np.random.random((2, 3, 4)) + layer = layers.Attention() + + causal_mask = layer._calculate_score_mask( + scores, v_mask=None, use_causal_mask=True + ) + expected_causal_mask = np.tril( + np.ones((1, scores.shape[1], scores.shape[2])), k=0 + ) + self.assertAllClose(causal_mask, expected_causal_mask, atol=1e-6) + + def test_attention_calculate_score_mask_with_causal_and_vmask(self): + scores = np.random.random((2, 3, 4)) + layer = layers.Attention() + v_mask = np.array([[True, False, True, False]]) + + combined_mask = layer._calculate_score_mask( + scores, v_mask=v_mask, use_causal_mask=True + ) + expected_causal_mask = np.tril( + np.ones((1, scores.shape[1], scores.shape[2])), k=0 + ) + expected_combined_mask = np.logical_and( + expected_causal_mask, v_mask[:, np.newaxis, :] + ) + self.assertAllClose(combined_mask, expected_combined_mask, atol=1e-6) + + def test_attention_compute_mask_with_no_mask(self): + layer = layers.Attention() + dummy_inputs = [ + np.random.random((2, 3, 4)), + np.random.random((2, 4, 4)), + ] + self.assertIsNone( + layer.compute_mask(inputs=dummy_inputs, mask=None), + "compute_mask should return None when mask is None", + ) + + def test_attention_compute_mask_with_first_element_none(self): + layer = layers.Attention() + dummy_inputs = [ + np.random.random((2, 3, 4)), + np.random.random((2, 4, 4)), + ] + mask = [None, np.array([True, False, True])] + self.assertIsNone( + layer.compute_mask(inputs=dummy_inputs, mask=mask), + "compute_mask should return None when the first element is None", + ) + + def test_attention_compute_mask_does_not_return_none_with_valid_mask(self): + layer = layers.Attention() + dummy_inputs = [ + np.random.random((2, 3, 4)), + np.random.random((2, 4, 4)), + ] + valid_mask = np.array([True, False, True]) + mask = [valid_mask, np.array([False, True, False])] + computed_mask = layer.compute_mask(inputs=dummy_inputs, mask=mask) + computed_mask = ops.convert_to_numpy(computed_mask) + self.assertIsNotNone( + computed_mask, + "compute_mask should not return None with a valid mask", + ) + + def test_attention_compute_mask_returns_correct_tensor_with_valid_mask( + self, + ): + layer = layers.Attention() + dummy_inputs = [ + np.random.random((2, 3, 4)), + np.random.random((2, 4, 4)), + ] + valid_mask = np.array([True, False, True]) + mask = [valid_mask, np.array([False, True, False])] + computed_mask = layer.compute_mask(inputs=dummy_inputs, mask=mask) + computed_mask = ops.convert_to_numpy(computed_mask) + self.assertTrue( + np.array_equal(computed_mask, valid_mask), + "compute_mask did not return the correct mask tensor", + ) + + def test_attention_compute_mask_returns_correct_tensor_with_all_true_mask( + self, + ): + layer = layers.Attention() + dummy_inputs = [np.ones((2, 3, 4)), np.ones((2, 4, 4))] + valid_mask = np.array([True, True, True]) + mask = [valid_mask, np.array([True, True, True])] + computed_mask = layer.compute_mask(inputs=dummy_inputs, mask=mask) + computed_mask = ops.convert_to_numpy(computed_mask) + expected_mask = np.array([True, True, True]) + self.assertTrue( + np.array_equal(computed_mask, expected_mask), + "compute_mask did not return the correct mask tensor", + ) + + def test_attention_compute_mask_returns_correct_tensor_with_all_false_mask( + self, + ): + layer = layers.Attention() + dummy_inputs = [np.ones((2, 3, 4)), np.ones((2, 4, 4))] + valid_mask = np.array([False, False, False]) + mask = [valid_mask, np.array([False, False, False])] + computed_mask = layer.compute_mask(inputs=dummy_inputs, mask=mask) + computed_mask = ops.convert_to_numpy(computed_mask) + expected_mask = np.array([False, False, False]) + self.assertTrue( + np.array_equal(computed_mask, expected_mask), + "compute_mask did not return the correct mask tensor", + ) + + def test_attention_compute_mask_with_tolerance_1e_3(self): + layer = layers.Attention() + dummy_inputs = [np.ones((2, 3, 4)), np.ones((2, 4, 4))] + valid_mask = np.array([1.0, 0.0, 1.0], dtype=float) + mask = [valid_mask, np.array([0.0, 1.0, 0.0], dtype=float)] + computed_mask = layer.compute_mask(inputs=dummy_inputs, mask=mask) + computed_mask = ops.convert_to_numpy(computed_mask) + expected_mask = valid_mask + self.assertTrue( + np.allclose(computed_mask, expected_mask, atol=1e-3), + "Incorrect mask tensor within tolerance 1e-3", + ) + + def test_attention_compute_mask_with_tolerance_1e_5(self): + layer = layers.Attention() + dummy_inputs = [np.ones((2, 3, 4)), np.ones((2, 4, 4))] + valid_mask = np.array([1.0, 0.0, 1.0], dtype=float) + mask = [valid_mask, np.array([0.0, 1.0, 0.0], dtype=float)] + computed_mask = layer.compute_mask(inputs=dummy_inputs, mask=mask) + computed_mask = ops.convert_to_numpy(computed_mask) + expected_mask = valid_mask + self.assertTrue( + np.allclose(computed_mask, expected_mask, atol=1e-5), + "Incorrect mask tensor within tolerance 1e-5", + ) + + def test_attention_compute_mask_with_tolerance_1e_7(self): + layer = layers.Attention() + dummy_inputs = [np.ones((2, 3, 4)), np.ones((2, 4, 4))] + valid_mask = np.array([1.0, 0.0, 1.0], dtype=float) + mask = [valid_mask, np.array([0.0, 1.0, 0.0], dtype=float)] + computed_mask = layer.compute_mask(inputs=dummy_inputs, mask=mask) + computed_mask = ops.convert_to_numpy(computed_mask) + expected_mask = valid_mask + self.assertTrue( + np.allclose(computed_mask, expected_mask, atol=1e-7), + "Incorrect mask tensor within tolerance 1e-7 ", + ) + + def test_attention_compute_mask_with_single_element_masks(self): + layer = layers.Attention() + dummy_inputs = [np.ones((2, 3, 4)), np.ones((2, 4, 4))] + valid_mask = np.array([True]) + mask = [valid_mask, np.array([False])] + computed_mask = layer.compute_mask(inputs=dummy_inputs, mask=mask) + computed_mask = ops.convert_to_numpy(computed_mask) + expected_shape = (1,) + self.assertEqual(computed_mask.shape, expected_shape) + + def test_attention_compute_mask_with_non_boolean_masks(self): + layer = layers.Attention() + dummy_inputs = [np.ones((2, 3, 4)), np.ones((2, 4, 4))] + valid_mask = np.array([1, 0, 1]) + mask = [valid_mask, np.array([0, 1, 0])] + computed_mask = layer.compute_mask(inputs=dummy_inputs, mask=mask) + computed_mask = ops.convert_to_numpy(computed_mask) + self.assertTrue(np.array_equal(computed_mask, valid_mask)) + + def test_attention_compute_mask_with_edge_case_masks(self): + layer = layers.Attention() + dummy_inputs = [np.ones((2, 3, 4)), np.ones((2, 4, 4))] + edge_case_masks = [ + np.array([True, True, True]), + np.array([False, False, False]), + np.array([True, False, True]), + ] + for mask in edge_case_masks: + computed_mask = layer.compute_mask( + inputs=dummy_inputs, mask=[mask, mask] + ) + computed_mask = ops.convert_to_numpy(computed_mask) + self.assertTrue(np.array_equal(computed_mask, mask)) + + def test_attention_compute_mask_with_different_input_shapes(self): + layer = layers.Attention() + input_shapes = [(2, 3, 4), (3, 2, 5), (4, 1, 6)] + valid_mask = np.array([True, False, True]) + for shape in input_shapes: + dummy_inputs = [np.ones(shape), np.ones(shape)] + mask = [valid_mask, np.array([False, True, False])] + computed_mask = layer.compute_mask(inputs=dummy_inputs, mask=mask) + computed_mask = ops.convert_to_numpy(computed_mask) + self.assertTrue(np.array_equal(computed_mask, valid_mask)) + + def test_attention_compute_output_shape(self): + layer = layers.Attention() + + query = np.random.random((2, 3, 4)) + value = np.random.random((2, 3, 5)) + key = np.random.random((2, 3, 4)) + layer = layers.Attention() + output = layer([query, value, key]) + self.assertAllEqual(output.shape, value.shape) + self.assertAllEqual( + layer.compute_output_shape( + input_shape=[query.shape, value.shape, key.shape] + ), + output.shape, + ) + + def test_return_attention_scores_true(self): + """Test that the layer returns attention scores along with outputs.""" + # Generate dummy input data + query = np.random.random((2, 8, 16)).astype(np.float32) + value = np.random.random((2, 4, 16)).astype(np.float32) + + # Initialize the Attention layer + layer = layers.Attention() + + # Call the layer with return_attention_scores=True + output, attention_scores = layer( + [query, value], return_attention_scores=True + ) + + # Check the shape of the outputs + self.assertEqual(output.shape, (2, 8, 16)) # Output shape + self.assertEqual( + attention_scores.shape, (2, 8, 4) + ) # Attention scores shape + + def test_return_attention_scores_true_and_tuple(self): + """Test that the layer outputs are a tuple when + return_attention_scores=True.""" + # Generate dummy input data + query = np.random.random((2, 8, 16)).astype(np.float32) + value = np.random.random((2, 4, 16)).astype(np.float32) + + # Initialize the Attention layer + layer = layers.Attention() + + # Call the layer with return_attention_scores=True + outputs = layer([query, value], return_attention_scores=True) + + # Check that outputs is a tuple + self.assertIsInstance( + outputs, tuple, "Expected the outputs to be a tuple" + ) + + def test_return_attention_scores_true_tuple_then_unpack(self): + """Test that outputs can be unpacked correctly.""" + # Generate dummy input data + query = np.random.random((2, 8, 16)).astype(np.float32) + value = np.random.random((2, 4, 16)).astype(np.float32) + + # Initialize the Attention layer + layer = layers.Attention() + + # Call the layer with return_attention_scores=True + outputs = layer([query, value], return_attention_scores=True) + + # Unpack the outputs + output, attention_scores = outputs + + # Check the shape of the unpacked outputs + self.assertEqual(output.shape, (2, 8, 16)) # Output shape + self.assertEqual( + attention_scores.shape, (2, 8, 4) + ) # Attention scores shape + + def test_return_attention_scores_with_symbolic_tensors(self): + """Test to check outputs with symbolic tensors with + return_attention_scores = True""" + attention = layers.Attention() + x = layers.Input(shape=(3, 5)) + y = layers.Input(shape=(4, 5)) + output, attention_scores = attention( + [x, y], return_attention_scores=True + ) + self.assertEqual(output.shape, (None, 3, 5)) # Output shape + self.assertEqual(attention_scores.shape, (None, 3, 4)) diff --git a/keras/src/layers/attention/grouped_query_attention.py b/keras/src/layers/attention/grouped_query_attention.py new file mode 100644 index 000000000000..b57028446f0d --- /dev/null +++ b/keras/src/layers/attention/grouped_query_attention.py @@ -0,0 +1,503 @@ +import math + +from keras.src import constraints +from keras.src import initializers +from keras.src import ops +from keras.src import regularizers +from keras.src.api_export import keras_export +from keras.src.backend.config import is_flash_attention_enabled +from keras.src.layers.activations.softmax import Softmax +from keras.src.layers.core.einsum_dense import EinsumDense +from keras.src.layers.layer import Layer +from keras.src.layers.regularization.dropout import Dropout + + +@keras_export("keras.layers.GroupQueryAttention") +class GroupedQueryAttention(Layer): + """Grouped Query Attention layer. + + This is an implementation of grouped-query attention introduced by + [Ainslie et al., 2023](https://arxiv.org/abs/2305.13245). Here + `num_key_value_heads` denotes number of groups, setting + `num_key_value_heads` to 1 is equivalent to multi-query attention, and + when `num_key_value_heads` is equal to `num_query_heads` it is equivalent + to multi-head attention. + + This layer first projects `query`, `key`, and `value` tensors. Then, `key` + and `value` are repeated to match the number of heads of `query`. + + Then, the `query` is scaled and dot-producted with `key` tensors. These are + softmaxed to obtain attention probabilities. The value tensors are then + interpolated by these probabilities and concatenated back to a single + tensor. + + Args: + head_dim: Size of each attention head. + num_query_heads: Number of query attention heads. + num_key_value_heads: Number of key and value attention heads. + dropout: Dropout probability. + use_bias: Boolean, whether the dense layers use bias vectors/matrices. + flash_attention: If `None`, the layer attempts to use flash + attention for faster and more memory-efficient attention + computations when possible. This behavior can be configured using + `keras.config.enable_flash_attention()` or + `keras.config.disable_flash_attention()`. + kernel_initializer: Initializer for dense layer kernels. + bias_initializer: Initializer for dense layer biases. + kernel_regularizer: Regularizer for dense layer kernels. + bias_regularizer: Regularizer for dense layer biases. + activity_regularizer: Regularizer for dense layer activity. + kernel_constraint: Constraint for dense layer kernels. + bias_constraint: Constraint for dense layer kernels. + seed: Optional integer to seed the dropout layer. + + Call arguments: + query: Query tensor of shape `(batch_dim, target_seq_len, feature_dim)`, + where `batch_dim` is batch size, `target_seq_len` is the length of + target sequence, and `feature_dim` is dimension of feature. + value: Value tensor of shape `(batch_dim, source_seq_len, feature_dim)`, + where `batch_dim` is batch size, `source_seq_len` is the length of + source sequence, and `feature_dim` is dimension of feature. + key: Optional key tensor of shape + `(batch_dim, source_seq_len, feature_dim)`. If not given, will use + `value` for both `key` and `value`, which is most common case. + attention_mask: A boolean mask of shape + `(batch_dim, target_seq_len, source_seq_len)`, that prevents + attention to certain positions. The boolean mask specifies which + query elements can attend to which key elements, where 1 indicates + attention and 0 indicates no attention. Broadcasting can happen for + the missing batch dimensions and the head dimension. + return_attention_scores: A boolean to indicate whether the output + should be `(attention_output, attention_scores)` if `True`, or + `attention_output` if `False`. Defaults to `False`. + training: Python boolean indicating whether the layer should behave in + training mode (adding dropout) or in inference mode (no dropout). + Will go with either using the training mode of the parent + layer/model or `False` (inference) if there is no parent layer. + use_causal_mask: A boolean to indicate whether to apply a causal mask to + prevent tokens from attending to future tokens (e.g., used in a + decoder Transformer). + + Returns: + attention_output: Result of the computation, of shape + `(batch_dim, target_seq_len, feature_dim)`, where `target_seq_len` + is for target sequence length and `feature_dim` is the query input + last dim. + attention_scores: (Optional) attention coefficients of shape + `(batch_dim, num_query_heads, target_seq_len, source_seq_len)`. + """ + + def __init__( + self, + head_dim, + num_query_heads, + num_key_value_heads, + dropout=0.0, + use_bias=True, + flash_attention=None, + kernel_initializer="glorot_uniform", + bias_initializer="zeros", + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + bias_constraint=None, + seed=None, + **kwargs, + ): + super().__init__(**kwargs) + self.supports_masking = True + self.head_dim = head_dim + self.num_query_heads = num_query_heads + self.num_key_value_heads = num_key_value_heads + if num_query_heads % num_key_value_heads != 0: + raise ValueError( + "`num_query_heads` must be divisible by `num_key_value_heads`." + ) + self.num_repeats = num_query_heads // num_key_value_heads + self.dropout = dropout + self.use_bias = use_bias + self._flash_attention = flash_attention or is_flash_attention_enabled() + self.kernel_initializer = initializers.get(kernel_initializer) + self.bias_initializer = initializers.get(bias_initializer) + self.kernel_regularizer = regularizers.get(kernel_regularizer) + self.bias_regularizer = regularizers.get(bias_regularizer) + self.activity_regularizer = regularizers.get(activity_regularizer) + self.kernel_constraint = constraints.get(kernel_constraint) + self.bias_constraint = constraints.get(bias_constraint) + self.seed = seed + + self._inverse_sqrt_head_dim = 1.0 / math.sqrt(float(self.head_dim)) + self._return_attention_scores = False + + # Check for flash attention constraints + if self._flash_attention and self.dropout > 0.0: + raise ValueError( + "Dropout is not supported when flash attention is enabled. " + "Please set dropout to 0.0 to use flash attention." + ) + + def build( + self, + query_shape, + value_shape, + key_shape=None, + ): + # Einsum variables: + # b = batch size + # q = query length + # k = key/value length + # m = model dim + # u = num query heads + # v = num key/value heads + # h = head dim + key_shape = value_shape if key_shape is None else key_shape + self.feature_dim = query_shape[-1] + self._query_dense = EinsumDense( + "bqm,muh->bquh", + output_shape=(None, self.num_query_heads, self.head_dim), + bias_axes="uh" if self.use_bias else None, + name="query", + **self._get_common_kwargs_for_sublayer(), + ) + self._query_dense.build(query_shape) + + self._key_dense = EinsumDense( + "bkm,mvh->bkvh", + output_shape=(None, self.num_key_value_heads, self.head_dim), + bias_axes="vh" if self.use_bias else None, + name="key", + **self._get_common_kwargs_for_sublayer(), + ) + self._key_dense.build(key_shape) + + self._value_dense = EinsumDense( + "bkm,mvh->bkvh", + output_shape=(None, self.num_key_value_heads, self.head_dim), + bias_axes="vh" if self.use_bias else None, + name="value", + **self._get_common_kwargs_for_sublayer(), + ) + self._value_dense.build(value_shape) + + self._softmax = Softmax(axis=-1, dtype=self.dtype_policy) + self._dropout_layer = Dropout( + rate=self.dropout, dtype=self.dtype_policy, seed=self.seed + ) + + self._dot_product_equation = "bquh,bkuh->buqk" + self._combine_equation = "buqk,bkuh->bquh" + + self._output_dense = EinsumDense( + "bquh,uhm->bqm", + output_shape=(None, self.feature_dim), + bias_axes="m" if self.use_bias else None, + name="attention_output", + **self._get_common_kwargs_for_sublayer(), + ) + self._output_dense.build( + (None, None, self.num_query_heads, self.head_dim) + ) + + def _get_common_kwargs_for_sublayer(self): + common_kwargs = dict( + kernel_regularizer=self.kernel_regularizer, + bias_regularizer=self.bias_regularizer, + activity_regularizer=self.activity_regularizer, + kernel_constraint=self.kernel_constraint, + bias_constraint=self.bias_constraint, + dtype=self.dtype_policy, + ) + # Create new clone of kernel/bias initializer, so that we don't reuse + # the initializer instance, which could lead to same init value since + # initializer is stateless. + kernel_initializer = self.kernel_initializer.__class__.from_config( + self.kernel_initializer.get_config() + ) + bias_initializer = self.bias_initializer.__class__.from_config( + self.bias_initializer.get_config() + ) + common_kwargs["kernel_initializer"] = kernel_initializer + common_kwargs["bias_initializer"] = bias_initializer + return common_kwargs + + def call( + self, + query, + value, + key=None, + query_mask=None, + value_mask=None, + key_mask=None, + attention_mask=None, + return_attention_scores=False, + training=None, + use_causal_mask=False, + ): + self._return_attention_scores = return_attention_scores + if key is None: + key = value + + attention_mask = self._compute_attention_mask( + query, + value, + query_mask=query_mask, + value_mask=value_mask, + key_mask=key_mask, + attention_mask=attention_mask, + use_causal_mask=use_causal_mask, + ) + + query = self._query_dense(query) + key = self._key_dense(key) + value = self._value_dense(value) + + key = ops.repeat( + key, self.num_repeats, axis=2 + ) # (batch_dim, source_seq_len, query_heads, head_dim) + value = ops.repeat( + value, self.num_repeats, axis=2 + ) # (batch_dim, source_seq_len, query_heads, head_dim) + + output, scores = self._compute_attention( + query, + key, + value, + attention_mask=attention_mask, + training=training, + ) + + output = self._output_dense( + output + ) # (batch_dim, target_seq_len, feature_dim) + + if return_attention_scores: + return output, scores + return output + + def _compute_attention_mask( + self, + query, + value, + query_mask=None, + value_mask=None, + key_mask=None, + attention_mask=None, + use_causal_mask=False, + ): + """Computes the attention mask, using the Keras masks of the inputs. + + * The `query`'s mask is reshaped from [B, T] to [B, T, 1]. + * The `value`'s mask is reshaped from [B, S] to [B, 1, S]. + * The `key`'s mask is reshaped from [B, S] to [B, 1, S]. The `key`'s + mask is ignored if `key` is `None` or if `key is value`. + * If `use_causal_mask=True`, then the causal mask is computed. Its shape + is [1, T, S]. + + All defined masks are merged using a logical AND operation (`&`). + + In general, if the `query` and `value` are masked, then there is no need + to define the `attention_mask`. + + Args: + query: Projected query tensor of shape `(B, T, N, key_dim)`. + key: Projected key tensor of shape `(B, T, N, key_dim)`. + value: Projected value tensor of shape `(B, T, N, value_dim)`. + attention_mask: a boolean mask of shape `(B, T, S)`, that prevents + attention to certain positions. + use_causal_mask: A boolean to indicate whether to apply a causal + mask to prevent tokens from attending to future tokens (e.g., + used in a decoder Transformer). + + Returns: + attention_mask: a boolean mask of shape `(B, T, S)`, that prevents + attention to certain positions, based on the Keras masks of the + `query`, `key`, `value`, and `attention_mask` tensors, and the + causal mask if `use_causal_mask=True`. + """ + auto_mask = None + if query_mask is not None: + query_mask = ops.cast(query_mask, "bool") # defensive casting + # B = batch size, T = max query length + auto_mask = ops.expand_dims(query_mask, -1) # shape is [B, T, 1] + if value_mask is not None: + value_mask = ops.cast(value_mask, "bool") # defensive casting + # B = batch size, S == max value length + mask = ops.expand_dims(value_mask, -2) # shape is [B, 1, S] + auto_mask = mask if auto_mask is None else auto_mask & mask + if key_mask is not None: + key_mask = ops.cast(key_mask, "bool") # defensive casting + # B == batch size, S == max key length == max value length + mask = ops.expand_dims(key_mask, -2) # shape is [B, 1, S] + auto_mask = mask if auto_mask is None else auto_mask & mask + if use_causal_mask: + # the shape of the causal mask is [1, T, S] + mask = self._compute_causal_mask(query, value) + auto_mask = mask if auto_mask is None else auto_mask & mask + if auto_mask is not None: + # merge attention_mask & automatic mask, to shape [B, T, S] + attention_mask = ( + auto_mask + if attention_mask is None + else ops.cast(attention_mask, bool) & auto_mask + ) + return attention_mask + + def _compute_causal_mask(self, query, value=None): + """Computes a causal mask (e.g., for masked self-attention layers). + + For example, if query and value both contain sequences of length 4, + this function returns a boolean tensor equal to: + + ``` + [[[True, False, False, False], + [True, True, False, False], + [True, True, True, False], + [True, True, True, True]]] + ``` + + Args: + query: query tensor of shape `(B, T, ...)`. + value: value tensor of shape `(B, S, ...)` (optional, defaults to + query). + + Returns: + mask: a boolean tensor of shape `(1, T, S)` containing a lower + triangular matrix of shape `(T, S)`. + """ + q_seq_length = ops.shape(query)[1] + v_seq_length = q_seq_length if value is None else ops.shape(value)[1] + ones_mask = ops.ones((1, q_seq_length, v_seq_length), dtype="int32") + row_index = ops.cumsum(ones_mask, axis=-2) + col_index = ops.cumsum(ones_mask, axis=-1) + return ops.greater_equal(row_index, col_index) + + def _compute_attention( + self, query, key, value, attention_mask=None, training=None + ): + # Check for flash attention constraints + if self._flash_attention and self._return_attention_scores: + raise ValueError( + "Returning attention scores is not supported when flash " + "attention is enabled. Please disable flash attention to access" + " attention scores." + ) + + # Determine whether to use dot-product attention + use_dot_product_attention = not ( + self.dropout > 0.0 + or self._return_attention_scores + or (len(query.shape) != 4) + ) + + if use_dot_product_attention: + if attention_mask is not None: + # Ensure attention_mask has the correct shape for broadcasting + # Expected shape: [batch_size, num_heads, query_seq_len, + # key_seq_len]. + mask_expansion_axis = -1 * 2 - 1 + len_attention_scores_shape = 4 # Only accepts 4D inputs + for _ in range( + len_attention_scores_shape - len(attention_mask.shape) + ): + attention_mask = ops.expand_dims( + attention_mask, axis=mask_expansion_axis + ) + attention_mask = ops.cast(attention_mask, dtype="bool") + # Directly compute the attention output using dot-product attention + attention_output = ops.dot_product_attention( + query=query, + key=key, + value=value, + bias=None, + mask=attention_mask, + scale=self._inverse_sqrt_head_dim, + is_causal=False, + flash_attention=self._flash_attention, + ) + return attention_output, None + + # Default behavior without flash attention, with explicit attention + # scores + query = ops.multiply( + query, ops.cast(self._inverse_sqrt_head_dim, query.dtype) + ) + # Take the dot product between "query" and "key" to get the raw + # attention scores. + scores = ops.einsum( + self._dot_product_equation, query, key + ) # (batch_dim, query_heads, target_seq_len, source_seq_len) + scores = self._masked_softmax(scores, attention_mask=attention_mask) + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + if self.dropout > 0.0: + scores_dropout = self._dropout_layer(scores, training=training) + else: + scores_dropout = scores + output = ops.einsum(self._combine_equation, scores_dropout, value) + return output, scores + + def _masked_softmax(self, scores, attention_mask=None): + # Normalize the attention scores to probabilities. + # scores = [B, N, T, S] + if attention_mask is not None: + # The expand dim happens starting from the `num_heads` dimension, + # (, num_heads, ) + mask_expansion_axis = -1 * 2 - 1 + for _ in range(len(scores.shape) - len(attention_mask.shape)): + attention_mask = ops.expand_dims( + attention_mask, axis=mask_expansion_axis + ) + return self._softmax(scores, mask=attention_mask) + + def compute_output_shape( + self, + query_shape, + value_shape, + key_shape=None, + ): + if key_shape is None: + key_shape = value_shape + + if query_shape[-1] != value_shape[-1]: + raise ValueError( + "The last dimension of `query_shape` and `value_shape` " + f"must be equal, but are {query_shape[-1]}, {value_shape[-1]}. " + f"Received: query_shape={query_shape}, " + f"value_shape={value_shape}" + ) + + if value_shape[1:-1] != key_shape[1:-1]: + raise ValueError( + "All dimensions of `value` and `key`, except the last one, " + f"must be equal. Received: value_shape={value_shape} and " + f"key_shape={key_shape}" + ) + + return query_shape + + def get_config(self): + config = { + "head_dim": self.head_dim, + "num_query_heads": self.num_query_heads, + "num_key_value_heads": self.num_key_value_heads, + "use_bias": self.use_bias, + "dropout": self.dropout, + "kernel_initializer": initializers.serialize( + self.kernel_initializer + ), + "bias_initializer": initializers.serialize(self.bias_initializer), + "kernel_regularizer": regularizers.serialize( + self.kernel_regularizer + ), + "bias_regularizer": regularizers.serialize(self.bias_regularizer), + "activity_regularizer": regularizers.serialize( + self.activity_regularizer + ), + "kernel_constraint": constraints.serialize(self.kernel_constraint), + "bias_constraint": constraints.serialize(self.bias_constraint), + "seed": self.seed, + } + base_config = super().get_config() + return {**base_config, **config} diff --git a/keras/src/layers/attention/grouped_query_attention_test.py b/keras/src/layers/attention/grouped_query_attention_test.py new file mode 100644 index 000000000000..7dec844bd983 --- /dev/null +++ b/keras/src/layers/attention/grouped_query_attention_test.py @@ -0,0 +1,398 @@ +import numpy as np +import pytest +from absl.testing import parameterized + +from keras.src import backend +from keras.src import initializers +from keras.src import layers +from keras.src import testing +from keras.src.backend.config import disable_flash_attention +from keras.src.backend.config import enable_flash_attention +from keras.src.backend.config import is_flash_attention_enabled + + +class GroupedQueryAttentionTest(testing.TestCase): + def setUp(self): + super().setUp() + # Flash attention is a newly introduced feature. We need to disable it + # for testing purposes. + disable_flash_attention() + + def tearDown(self): + enable_flash_attention() + return super().tearDown() + + def test_basics(self): + self.assertFalse(is_flash_attention_enabled()) + self.run_layer_test( + layers.GroupedQueryAttention, + init_kwargs={ + "num_query_heads": 2, + "num_key_value_heads": 2, + "head_dim": 2, + }, + input_shape={"query_shape": (2, 8, 16), "value_shape": (2, 4, 16)}, + expected_output_shape=(2, 8, 16), + expected_num_trainable_weights=8, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=True, + run_training_check=False, + ) + + self.run_layer_test( + layers.GroupedQueryAttention, + init_kwargs={ + "num_query_heads": 2, + "num_key_value_heads": 2, + "head_dim": 2, + "use_bias": False, + "dropout": 0.5, + }, + input_shape={"query_shape": (2, 8, 16), "value_shape": (2, 4, 16)}, + expected_output_shape=(2, 8, 16), + expected_num_trainable_weights=4, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=1, + expected_num_losses=0, + supports_masking=True, + run_training_check=False, + ) + + def test_basics_with_flash_attention(self): + enable_flash_attention() + init_kwargs = { + "num_query_heads": 2, + "num_key_value_heads": 2, + "head_dim": 8, + "dtype": "float16", + } + input_shape = { + "query_shape": (2, 8, 16), + "value_shape": (2, 4, 16), + } + expected_output_shape = (2, 8, 16) + if backend.backend() in ("tensorflow", "numpy"): + self.skipTest( + "Flash attention is not supported in tensorflow and numpy " + "backends." + ) + elif backend.backend() == "torch": + try: + self.run_layer_test( + layers.GroupedQueryAttention, + init_kwargs=init_kwargs, + input_shape=input_shape, + expected_output_shape=expected_output_shape, + expected_num_trainable_weights=8, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=True, + run_training_check=False, + ) + except ImportError as e: + if "Flash attention is not supported" in str(e.args[0]): + self.assertTrue( + ( + "Flash attention is not supported in your current " + "PyTorch version." + ) + in str(e.args[0]) + ) + except RuntimeError as e: + if ( + "Flash attention is not supported with the provided inputs" + in str(e.args[0]) + ): + self.assertTrue( + ( + "Flash attention is not supported with the " + "provided inputs" + ) + in str(e.args[0]) + ) + elif backend.backend() == "jax": + try: + self.run_layer_test( + layers.GroupedQueryAttention, + init_kwargs=init_kwargs, + input_shape=input_shape, + expected_output_shape=expected_output_shape, + expected_num_trainable_weights=8, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=True, + run_training_check=False, + ) + except ImportError as e: + if "Flash attention is not supported" in str(e.args[0]): + self.assertTrue( + ( + "Flash attention is not supported in your current " + "JAX version." + ) + in str(e.args[0]) + ) + except RuntimeError as e: + if "cuDNN" in str(e.args[0]): + self.assertTrue("cuDNN is not detected." in str(e.args[0])) + elif "Require at least" in str(e.args[0]): + self.assertTrue( + "Require at least Ampere arch to run" in str(e.args[0]) + ) + elif "Flash attention" in str(e.args[0]): + self.assertTrue( + ( + "Flash attention is not supported in your current " + "JAX version." + ) + in str(e.args[0]) + ) + + @parameterized.named_parameters( + ("without_key_proj_mha", (4, 8), (2, 8), None, 2, 2), + ("with_key_proj_mha", (4, 8), (2, 8), (2, 3), 2, 2), + ("without_key_proj_gqa", (4, 8), (2, 8), None, 4, 2), + ("with_key_proj_gqa", (4, 8), (2, 8), (2, 3), 4, 2), + ("without_key_value_proj_mqa", (4, 8), (2, 8), None, 4, 1), + ("with_key_value_proj_mqa", (4, 8), (2, 8), (2, 3), 4, 1), + ) + def test_compute_output_shape( + self, + query_dims, + value_dims, + key_dims, + num_query_heads, + num_key_value_heads, + ): + """Test computed shape is equal to the layer output's shape.""" + layer = layers.GroupedQueryAttention( + num_query_heads=num_query_heads, + num_key_value_heads=num_key_value_heads, + head_dim=2, + ) + batch_size = 7 + query_shape = (batch_size,) + query_dims + value_shape = (batch_size,) + value_dims + key_shape = (batch_size,) + key_dims if key_dims else None + + query = np.ones(query_shape) + value = np.ones(value_shape) + key = np.ones(key_shape) if key_shape else None + output = layer(query=query, value=value, key=key) + comp_output_shape = layer.compute_output_shape( + query_shape, value_shape, key_shape + ) + self.assertEqual(output.shape, comp_output_shape) + + @parameterized.named_parameters( + ("query_value_dim_mismatch", (2, 4, 8), (2, 2, 7), 2), + ("key_value_dim_mismatch", (2, 4, 8), (2, 2, 8), (2, 1, 7)), + ) + def test_shape_mismatch_error(self, query_shape, value_shape, key_shape): + """Test dimension mismatches""" + layer = layers.GroupedQueryAttention( + num_query_heads=4, + num_key_value_heads=4, + head_dim=2, + ) + with self.assertRaisesRegex(ValueError, r"must be equal"): + layer.compute_output_shape(query_shape, value_shape, key_shape) + + def test_initializer(self): + # Test with a specified initializer. + layer = layers.GroupedQueryAttention( + num_query_heads=16, + num_key_value_heads=16, + head_dim=64, + kernel_initializer=initializers.TruncatedNormal(stddev=0.02), + ) + layer.build((2, 4, 8), (2, 4, 8)) + + # Make sure the sub layers have different kernel init value. + self.assertNotAllClose( + layer._query_dense.kernel, + layer._key_dense.kernel, + ) + self.assertNotAllClose( + layer._query_dense.kernel, + layer._value_dense.kernel, + ) + self.assertNotAllClose( + layer._query_dense.kernel, + layer._output_dense.kernel, + ) + + @pytest.mark.skipif( + backend.backend() == "numpy", + reason="Numpy backend does not support masking.", + ) + def test_query_mask_propagation(self): + """Test automatic propagation of the query's mask.""" + try: + layer = layers.GroupedQueryAttention( + num_query_heads=2, num_key_value_heads=2, head_dim=2 + ) + self.assertTrue(layer.supports_masking) + query = np.array( + [[1, 2, 3, 0, 0], [3, 3, 1, 1, 2], [1, 0, 0, 0, 0]] + ) + masked_query = layers.Embedding(4, 8, mask_zero=True)(query) + value = np.random.normal(size=(3, 3, 8)) + output = layer(query=masked_query, value=value) + except RuntimeError as e: + if e.args[0].startswith( + "(*bias): last dimension must be contiguous" + ): + self.skipTest( + "PyTorch errors out on GPU: issue to track bug is here " + "https://github.com/keras-team/keras/issues/20459" + ) + self.assertAllClose(masked_query._keras_mask, output._keras_mask) + + @parameterized.named_parameters(("causal", True), ("not_causal", 0)) + @pytest.mark.skipif( + backend.backend() == "numpy", + reason="Numpy backend does not support masking.", + ) + def test_masking(self, use_causal_mask): + """Test that the value and causal masks are taken into account.""" + layer = layers.GroupedQueryAttention( + num_query_heads=2, num_key_value_heads=2, head_dim=2 + ) + query = np.array([[1, 2, 3, 0, 0], [3, 3, 1, 1, 2], [1, 0, 0, 0, 0]]) + masked_query = layers.Embedding(4, 8, mask_zero=True)(query) + value = np.array([[5, 4, 0], [3, 0, 0], [2, 1, 1]]) + masked_value = layers.Embedding(6, 8, mask_zero=True)(value) + output = layer( + query=masked_query, + value=masked_value, + use_causal_mask=use_causal_mask, + ) + mask = np.array( + [[[1, 1, 0]] * 3 + [[0, 0, 0]] * 2] + + [[[1, 0, 0]] * 5] + + [[[1, 1, 1]] + [[0, 0, 0]] * 4] + ).astype(bool) + if use_causal_mask: + mask = mask & np.array( + [[[1, 0, 0], [1, 1, 0]] + [[1, 1, 1]] * 3] + ).astype(bool) + del masked_query._keras_mask + del masked_value._keras_mask + output_with_manual_mask = layer( + query=masked_query, value=masked_value, attention_mask=mask + ) + self.assertAllClose(output, output_with_manual_mask) + + @parameterized.named_parameters( + ("disable_flash_attention", False), ("enable_flash_attention", True) + ) + def test_correctness(self, flash_attention): + if flash_attention: + # Let the backend decide whether to use flash attention + enable_flash_attention() + dtype = "float16" # Flash attention only accepts float16/bfloat16 + head_dim = 8 # key_dim % 8 == 0 to enable flash attention + num_query_heads = num_key_value_heads = 8 + + query = np.identity(head_dim)[np.newaxis, ...] + key = np.identity(head_dim)[np.newaxis, ...] + value = ( + np.reshape(np.arange(head_dim * head_dim), (1, head_dim, head_dim)) + / 100.0 # Prevent overflow/underflow + ) + + # Setup layer. + layer = layers.GroupedQueryAttention( + head_dim=head_dim, + num_query_heads=num_query_heads, + num_key_value_heads=num_key_value_heads, + dtype=dtype, + ) + layer.build(query.shape, key.shape, value.shape) + + # Set layer weights. + kernel = np.identity(head_dim) + # To get an identity kernel we need to add a head dim and repeat on it. + kernel = np.repeat(kernel[:, np.newaxis, :], num_query_heads, axis=1) + # Zeros for all biases. + bias = np.zeros((num_query_heads, head_dim)) + output_bias = np.zeros((head_dim,)) + layer.set_weights([kernel, bias] * 3 + [kernel, output_bias]) + + # Call layer and assert output. + expected_output = np.array( + [2.406, 2.440, 2.473, 2.504, 2.535, 2.568, 2.602, 2.633] + ) + expected_output = np.tile( + expected_output[np.newaxis, :, np.newaxis], (1, 1, head_dim) + ) + expected_score = np.array( + [ + [0.1187] * 0 + [0.1691] + [0.1187] * 7, + [0.1187] * 1 + [0.1691] + [0.1187] * 6, + [0.1187] * 2 + [0.1691] + [0.1187] * 5, + [0.1187] * 3 + [0.1691] + [0.1187] * 4, + [0.1187] * 4 + [0.1691] + [0.1187] * 3, + [0.1187] * 5 + [0.1691] + [0.1187] * 2, + [0.1187] * 6 + [0.1691] + [0.1187] * 1, + [0.1187] * 7 + [0.1691] + [0.1187] * 0, + ] + ) + expected_score = np.tile( + expected_score[np.newaxis, np.newaxis, ...], (1, head_dim, 1, 1) + ) + if flash_attention: + output = layer(query=query, value=value, key=key) + self.assertAllClose(output, expected_output, atol=1e-2) + else: + output, scores = layer( + query=query, + value=value, + key=key, + return_attention_scores=True, + ) + self.assertAllClose(output, expected_output, atol=1e-2) + self.assertAllClose(scores, expected_score, atol=1e-2) + + def test_flash_attention_with_errors(self): + if backend.backend() in ("numpy", "tensorflow"): + pytest.skip( + reason=( + "Flash attention is not supported on tensorflow and numpy." + ) + ) + # Check `flash_attention=True` and `dropout=0.1` + with self.assertRaisesRegex( + ValueError, + "Dropout is not supported when flash attention is enabled.", + ): + layer = layers.GroupedQueryAttention( + head_dim=2, + num_query_heads=2, + num_key_value_heads=2, + flash_attention=True, + dropout=0.1, + ) + + # Check `flash_attention=True` and `return_attention_scores=True` + layer = layers.GroupedQueryAttention( + head_dim=2, + num_query_heads=2, + num_key_value_heads=2, + flash_attention=True, + ) + self.assertTrue(layer._flash_attention) + query = np.random.random((2, 4, 8)) + value = np.random.random((2, 4, 8)) + with self.assertRaisesRegex( + ValueError, + "Returning attention scores is not supported when flash " + "attention is enabled. Please disable flash attention to access" + " attention scores.", + ): + layer(query=query, value=value, return_attention_scores=True) diff --git a/keras/src/layers/attention/multi_head_attention.py b/keras/src/layers/attention/multi_head_attention.py new file mode 100644 index 000000000000..a8aa86838d5a --- /dev/null +++ b/keras/src/layers/attention/multi_head_attention.py @@ -0,0 +1,826 @@ +import math +import string + +import numpy as np + +from keras.src import backend +from keras.src import constraints +from keras.src import initializers +from keras.src import ops +from keras.src import regularizers +from keras.src.api_export import keras_export +from keras.src.backend.config import is_flash_attention_enabled +from keras.src.layers.activations.softmax import Softmax +from keras.src.layers.core.einsum_dense import EinsumDense +from keras.src.layers.layer import Layer +from keras.src.layers.regularization.dropout import Dropout + + +@keras_export("keras.layers.MultiHeadAttention") +class MultiHeadAttention(Layer): + """MultiHeadAttention layer. + + This is an implementation of multi-headed attention as described in the + paper "Attention is all you Need" + [Vaswani et al., 2017](https://arxiv.org/abs/1706.03762). + If `query`, `key,` `value` are the same, then + this is self-attention. Each timestep in `query` attends to the + corresponding sequence in `key`, and returns a fixed-width vector. + + This layer first projects `query`, `key` and `value`. These are + (effectively) a list of tensors of length `num_attention_heads`, where the + corresponding shapes are `(batch_size, , key_dim)`, + `(batch_size, , key_dim)`, + `(batch_size, , value_dim)`. + + Then, the query and key tensors are dot-producted and scaled. These are + softmaxed to obtain attention probabilities. The value tensors are then + interpolated by these probabilities, then concatenated back to a single + tensor. + + Finally, the result tensor with the last dimension as `value_dim` can take + a linear projection and return. + + Args: + num_heads: Number of attention heads. + key_dim: Size of each attention head for query and key. + value_dim: Size of each attention head for value. + dropout: Dropout probability. + use_bias: Boolean, whether the dense layers use bias vectors/matrices. + output_shape: The expected shape of an output tensor, besides the batch + and sequence dims. If not specified, projects back to the query + feature dim (the query input's last dimension). + attention_axes: axes over which the attention is applied. `None` means + attention over all axes, but batch, heads, and features. + flash_attention: If `None`, the layer attempts to use flash + attention for faster and more memory-efficient attention + computations when possible. This behavior can be configured using + `keras.config.enable_flash_attention()` or + `keras.config.disable_flash_attention()`. + kernel_initializer: Initializer for dense layer kernels. + bias_initializer: Initializer for dense layer biases. + kernel_regularizer: Regularizer for dense layer kernels. + bias_regularizer: Regularizer for dense layer biases. + activity_regularizer: Regularizer for dense layer activity. + kernel_constraint: Constraint for dense layer kernels. + bias_constraint: Constraint for dense layer kernels. + seed: Optional integer to seed the dropout layer. + + Call arguments: + query: Query tensor of shape `(B, T, dim)`, where `B` is the batch size, + `T` is the target sequence length, and dim is the feature dimension. + value: Value tensor of shape `(B, S, dim)`, where `B` is the batch size, + `S` is the source sequence length, and dim is the feature dimension. + key: Optional key tensor of shape `(B, S, dim)`. If not given, will + use `value` for both `key` and `value`, which is the most common + case. + attention_mask: a boolean mask of shape `(B, T, S)`, that prevents + attention to certain positions. The boolean mask specifies which + query elements can attend to which key elements, 1 indicates + attention and 0 indicates no attention. Broadcasting can happen for + the missing batch dimensions and the head dimension. + return_attention_scores: A boolean to indicate whether the output should + be `(attention_output, attention_scores)` if `True`, or + `attention_output` if `False`. Defaults to `False`. + training: Python boolean indicating whether the layer should behave in + training mode (adding dropout) or in inference mode (no dropout). + Will go with either using the training mode of the parent + layer/model, or `False` (inference) if there is no parent layer. + use_causal_mask: A boolean to indicate whether to apply a causal mask to + prevent tokens from attending to future tokens (e.g., used in a + decoder Transformer). + + Returns: + attention_output: The result of the computation, of shape `(B, T, E)`, + where `T` is for target sequence shapes and `E` is the query input + last dimension if `output_shape` is `None`. Otherwise, the + multi-head outputs are projected to the shape specified by + `output_shape`. + attention_scores: (Optional) multi-head attention coefficients over + attention axes. + """ + + def __init__( + self, + num_heads, + key_dim, + value_dim=None, + dropout=0.0, + use_bias=True, + output_shape=None, + attention_axes=None, + flash_attention=None, + kernel_initializer="glorot_uniform", + bias_initializer="zeros", + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + bias_constraint=None, + seed=None, + **kwargs, + ): + super().__init__(**kwargs) + self.supports_masking = True + self._num_heads = num_heads + self._key_dim = key_dim + self._value_dim = value_dim if value_dim else key_dim + self._dropout = dropout + self._use_bias = use_bias + if output_shape: + if isinstance(output_shape, int): + output_shape = (output_shape,) + try: + output_shape = tuple(output_shape) + except: + raise ValueError( + f"Invalid `output_shape`: {output_shape}. When " + "specified, the `output_shape` should be of type tuple, " + "list, or int." + ) + self._output_shape = output_shape + self._flash_attention = flash_attention or is_flash_attention_enabled() + self._kernel_initializer = initializers.get(kernel_initializer) + self._bias_initializer = initializers.get(bias_initializer) + self._kernel_regularizer = regularizers.get(kernel_regularizer) + self._bias_regularizer = regularizers.get(bias_regularizer) + self._activity_regularizer = regularizers.get(activity_regularizer) + self._kernel_constraint = constraints.get(kernel_constraint) + self._bias_constraint = constraints.get(bias_constraint) + if isinstance(attention_axes, int): + attention_axes = (attention_axes,) + elif attention_axes and not isinstance(attention_axes, (list, tuple)): + raise ValueError( + "`attention_axes` must be an int, list, or tuple." + f"Received: attention_axes={attention_axes}" + ) + self._attention_axes = attention_axes + self.seed = seed + + self._inverse_sqrt_key_dim = 1.0 / math.sqrt(float(self._key_dim)) + + # Check for flash attention constraints + if self._flash_attention and self._dropout > 0.0: + raise ValueError( + "Dropout is not supported when flash attention is enabled. " + "Please set dropout to 0.0 to use flash attention." + ) + + @property + def num_heads(self): + return self._num_heads + + @property + def key_dim(self): + return self._key_dim + + @property + def value_dim(self): + return self._value_dim + + @property + def dropout(self): + return self._dropout + + @property + def use_bias(self): + return self._use_bias + + # Avoid exposing `output_shape` as it may conflict with `Functional` and + # `Sequential` models when calling `summary()`. + + @property + def attention_axes(self): + return self._attention_axes + + def get_config(self): + base_config = super().get_config() + config = { + "num_heads": self._num_heads, + "key_dim": self._key_dim, + "value_dim": self._value_dim, + "dropout": self._dropout, + "use_bias": self._use_bias, + "output_shape": self._output_shape, + "attention_axes": self._attention_axes, + "kernel_initializer": initializers.serialize( + self._kernel_initializer + ), + "bias_initializer": initializers.serialize(self._bias_initializer), + "kernel_regularizer": regularizers.serialize( + self._kernel_regularizer + ), + "bias_regularizer": regularizers.serialize(self._bias_regularizer), + "activity_regularizer": regularizers.serialize( + self._activity_regularizer + ), + "kernel_constraint": constraints.serialize(self._kernel_constraint), + "bias_constraint": constraints.serialize(self._bias_constraint), + "seed": self.seed, + } + return {**base_config, **config} + + def build( + self, + query_shape, + value_shape, + key_shape=None, + ): + """Builds layers and variables. + + Args: + query_shape: Shape of the `query` tensor. + value_shape: Shape of the `value` tensor. + key: Optional shape of the `key` tensor. + """ + key_shape = value_shape if key_shape is None else key_shape + + if value_shape[1:-1] != key_shape[1:-1]: + raise ValueError( + "All dimensions of `value` and `key`, except the last one, " + f"must be equal. Received: value_shape={value_shape} and " + f"key_shape={key_shape}" + ) + + query_rank = len(query_shape) + value_rank = len(value_shape) + key_rank = len(key_shape) + einsum_equation, bias_axes, output_rank = _build_proj_equation( + query_rank - 1, bound_dims=1, output_dims=2 + ) + self._query_dense = EinsumDense( + einsum_equation, + output_shape=_get_output_shape( + output_rank - 1, [self._num_heads, self._key_dim] + ), + bias_axes=bias_axes if self._use_bias else None, + name="query", + **self._get_common_kwargs_for_sublayer(), + ) + self._query_dense.build(query_shape) + einsum_equation, bias_axes, output_rank = _build_proj_equation( + key_rank - 1, bound_dims=1, output_dims=2 + ) + self._key_dense = EinsumDense( + einsum_equation, + output_shape=_get_output_shape( + output_rank - 1, [self._num_heads, self._key_dim] + ), + bias_axes=bias_axes if self._use_bias else None, + name="key", + **self._get_common_kwargs_for_sublayer(), + ) + self._key_dense.build(key_shape) + einsum_equation, bias_axes, output_rank = _build_proj_equation( + value_rank - 1, bound_dims=1, output_dims=2 + ) + self._value_dense = EinsumDense( + einsum_equation, + output_shape=_get_output_shape( + output_rank - 1, [self._num_heads, self._value_dim] + ), + bias_axes=bias_axes if self._use_bias else None, + name="value", + **self._get_common_kwargs_for_sublayer(), + ) + self._value_dense.build(value_shape) + + # Builds the attention computations for multi-head dot product + # attention. These computations could be wrapped into the keras + # attention layer once it supports multi-head einsum computations. + self._build_attention(output_rank) + self._output_dense = self._make_output_dense( + query_shape, + self._get_common_kwargs_for_sublayer(), + "attention_output", + ) + output_dense_input_shape = list( + self._query_dense.compute_output_shape(query_shape) + ) + output_dense_input_shape[-1] = self._value_dim + self._output_dense.build(tuple(output_dense_input_shape)) + + @property + def query_dense(self): + return self._query_dense + + @property + def key_dense(self): + return self._key_dense + + @property + def value_dense(self): + return self._value_dense + + @property + def output_dense(self): + return self._output_dense + + def _get_common_kwargs_for_sublayer(self): + common_kwargs = dict( + kernel_regularizer=self._kernel_regularizer, + bias_regularizer=self._bias_regularizer, + activity_regularizer=self._activity_regularizer, + kernel_constraint=self._kernel_constraint, + bias_constraint=self._bias_constraint, + dtype=self.dtype_policy, + ) + # Create new clone of kernel/bias initializer, so that we don't reuse + # the initializer instance, which could lead to same init value since + # initializer is stateless. + kernel_initializer = self._kernel_initializer.__class__.from_config( + self._kernel_initializer.get_config() + ) + bias_initializer = self._bias_initializer.__class__.from_config( + self._bias_initializer.get_config() + ) + common_kwargs["kernel_initializer"] = kernel_initializer + common_kwargs["bias_initializer"] = bias_initializer + return common_kwargs + + def _make_output_dense(self, query_shape, common_kwargs, name=None): + """Builds the output projection matrix. + + Args: + free_dims: Number of free dimensions for einsum equation building. + common_kwargs: Common keyword arguments for einsum layer. + name: Name for the projection layer. + + Returns: + Projection layer. + """ + query_rank = len(query_shape) + if self._output_shape: + output_shape = self._output_shape + else: + output_shape = [query_shape[-1]] + einsum_equation, bias_axes, output_rank = _build_proj_equation( + query_rank - 1, bound_dims=2, output_dims=len(output_shape) + ) + return EinsumDense( + einsum_equation, + output_shape=_get_output_shape(output_rank - 1, output_shape), + bias_axes=bias_axes if self._use_bias else None, + name=name, + **common_kwargs, + ) + + def _build_attention(self, rank): + """Builds multi-head dot-product attention computations. + + This function builds attributes necessary for `_compute_attention` to + customize attention computation to replace the default dot-product + attention. + + Args: + rank: the rank of query, key, value tensors. + """ + if self._attention_axes is None: + self._attention_axes = tuple(range(1, rank - 2)) + else: + self._attention_axes = tuple(self._attention_axes) + ( + self._dot_product_equation, + self._combine_equation, + attn_scores_rank, + ) = _build_attention_equation(rank, attn_axes=self._attention_axes) + norm_axes = tuple( + range( + attn_scores_rank - len(self._attention_axes), attn_scores_rank + ) + ) + self._softmax = Softmax(axis=norm_axes, dtype=self.dtype_policy) + self._dropout_layer = Dropout( + rate=self._dropout, dtype=self.dtype_policy, seed=self.seed + ) + + def _masked_softmax(self, attention_scores, attention_mask=None): + # Normalize the attention scores to probabilities. + # attention_scores = [B, N, T, S] + if attention_mask is not None: + # The expand dim happens starting from the `num_heads` dimension, + # (, num_heads, ) + mask_expansion_axis = -len(self._attention_axes) * 2 - 1 + for _ in range( + len(attention_scores.shape) - len(attention_mask.shape) + ): + attention_mask = ops.expand_dims( + attention_mask, axis=mask_expansion_axis + ) + return self._softmax(attention_scores, mask=attention_mask) + + def _compute_attention( + self, + query, + key, + value, + attention_mask=None, + training=None, + return_attention_scores=False, + ): + """Applies Dot-product attention with query, key, value tensors. + + This function defines the computation inside `call` with projected + multi-head Q, K, V inputs. Users can override this function for + customized attention implementation. + + Args: + query: Projected query tensor of shape `(B, T, N, key_dim)`. + key: Projected key tensor of shape `(B, S, N, key_dim)`. + value: Projected value tensor of shape `(B, S, N, value_dim)`. + attention_mask: a boolean mask of shape `(B, T, S)`, that prevents + attention to certain positions. It is generally not needed if + the `query` and `value` (and/or `key`) are masked. + training: Python boolean indicating whether the layer should behave + in training mode (adding dropout) or in inference mode (doing + nothing). + + Returns: + attention_output: Multi-headed outputs of attention computation. + attention_scores: Multi-headed attention weights. + """ + # Check for flash attention constraints + if self._flash_attention and return_attention_scores: + raise ValueError( + "Returning attention scores is not supported when flash " + "attention is enabled. Please disable flash attention to access" + " attention scores." + ) + + # Determine whether to use dot-product attention + use_dot_product_attention = not ( + self._dropout > 0.0 + or return_attention_scores + or (len(query.shape) != 4) + ) + + if use_dot_product_attention: + if attention_mask is not None: + # Ensure attention_mask has the correct shape for broadcasting + # Expected shape: [batch_size, num_heads, query_seq_len, + # key_seq_len]. + mask_expansion_axis = -len(self._attention_axes) * 2 - 1 + len_attention_scores_shape = 4 # Only accepts 4D inputs + for _ in range( + len_attention_scores_shape - len(attention_mask.shape) + ): + attention_mask = ops.expand_dims( + attention_mask, axis=mask_expansion_axis + ) + attention_mask = ops.cast(attention_mask, dtype="bool") + # Directly compute the attention output using dot-product attention + attention_output = ops.dot_product_attention( + query=query, + key=key, + value=value, + bias=None, + mask=attention_mask, + scale=self._inverse_sqrt_key_dim, + is_causal=False, + flash_attention=self._flash_attention, + ) + return attention_output, None + + # Default behavior without flash attention, with explicit attention + # scores + query = ops.multiply( + query, ops.cast(self._inverse_sqrt_key_dim, query.dtype) + ) + + # Take the dot product between "query" and "key" to get the raw + # attention scores. + attention_scores = ops.einsum(self._dot_product_equation, key, query) + + # Apply the mask using the custom masked softmax + attention_scores = self._masked_softmax( + attention_scores, attention_mask + ) + + # Apply dropout to the attention scores if needed + if self._dropout > 0.0: + final_attn_scores = self._dropout_layer( + attention_scores, training=training + ) + else: + final_attn_scores = attention_scores + + # `context_layer` = [B, T, N, H] + attention_output = ops.einsum( + self._combine_equation, final_attn_scores, value + ) + return attention_output, attention_scores + + def call( + self, + query, + value, + key=None, + query_mask=None, + value_mask=None, + key_mask=None, + attention_mask=None, + return_attention_scores=False, + training=None, + use_causal_mask=False, + ): + if key is None: + key = value + + # Delete the masks because the masks are handled at the level of the + # layer + query_mask = backend.get_keras_mask(query) + backend.set_keras_mask(query, None) + backend.set_keras_mask(value, None) + backend.set_keras_mask(key, None) + + attention_mask = self._compute_attention_mask( + query, + value, + query_mask=query_mask, + value_mask=value_mask, + key_mask=key_mask, + attention_mask=attention_mask, + use_causal_mask=use_causal_mask, + ) + # N = `num_attention_heads` + # H = `size_per_head` + + # `query` = [B, T, N, H] + query = self._query_dense(query) + + # `key` = [B, S, N, H] + key = self._key_dense(key) + + # `value` = [B, S, N, H] + value = self._value_dense(value) + attention_output, attention_scores = self._compute_attention( + query, + key, + value, + attention_mask, + training, + return_attention_scores, + ) + attention_output = self._output_dense(attention_output) + + # Set mask on output if needed + if query_mask is not None: + backend.set_keras_mask(attention_output, query_mask) + + if return_attention_scores: + return attention_output, attention_scores + return attention_output + + def _compute_attention_mask( + self, + query, + value, + query_mask=None, + value_mask=None, + key_mask=None, + attention_mask=None, + use_causal_mask=False, + ): + """Computes the attention mask, using the Keras masks of the inputs. + + * The `query`'s mask is reshaped from [B, T] to [B, T, 1]. + * The `value`'s mask is reshaped from [B, S] to [B, 1, S]. + * The `key`'s mask is reshaped from [B, S] to [B, 1, S]. The `key`'s + mask is ignored if `key` is `None` or if `key is value`. + * If `use_causal_mask=True`, then the causal mask is computed. Its shape + is [1, T, S]. + + All defined masks are merged using a logical AND operation (`&`). + + In general, if the `query` and `value` are masked, then there is no need + to define the `attention_mask`. + + Args: + query: Projected query tensor of shape `(B, T, N, key_dim)`. + key: Projected key tensor of shape `(B, T, N, key_dim)`. + value: Projected value tensor of shape `(B, T, N, value_dim)`. + attention_mask: a boolean mask of shape `(B, T, S)`, that prevents + attention to certain positions. + use_causal_mask: A boolean to indicate whether to apply a causal + mask to prevent tokens from attending to future tokens (e.g., + used in a decoder Transformer). + + Returns: + attention_mask: a boolean mask of shape `(B, T, S)`, that prevents + attention to certain positions, based on the Keras masks of the + `query`, `key`, `value`, and `attention_mask` tensors, and the + causal mask if `use_causal_mask=True`. + """ + auto_mask = None + if query_mask is not None: + query_mask = ops.cast(query_mask, "bool") # defensive casting + # B = batch size, T = max query length + auto_mask = ops.expand_dims(query_mask, -1) # shape is [B, T, 1] + if value_mask is not None: + value_mask = ops.cast(value_mask, "bool") # defensive casting + # B = batch size, S == max value length + mask = ops.expand_dims(value_mask, -2) # shape is [B, 1, S] + auto_mask = mask if auto_mask is None else auto_mask & mask + if key_mask is not None: + key_mask = ops.cast(key_mask, "bool") # defensive casting + # B == batch size, S == max key length == max value length + mask = ops.expand_dims(key_mask, -2) # shape is [B, 1, S] + auto_mask = mask if auto_mask is None else auto_mask & mask + if use_causal_mask: + # the shape of the causal mask is [1, T, S] + mask = self._compute_causal_mask(query, value) + auto_mask = mask if auto_mask is None else auto_mask & mask + + if attention_mask is not None: + attention_mask = ops.cast(attention_mask, "bool") + if auto_mask is not None: + # merge attention_mask & automatic mask, to shape [B, T, S] + attention_mask = ( + auto_mask + if attention_mask is None + else attention_mask & auto_mask + ) + return attention_mask + + def _compute_causal_mask(self, query, value=None): + """Computes a causal mask (e.g., for masked self-attention layers). + + For example, if query and value both contain sequences of length 4, + this function returns a boolean tensor equal to: + + ``` + [[[True, False, False, False], + [True, True, False, False], + [True, True, True, False], + [True, True, True, True]]] + ``` + + Args: + query: query tensor of shape `(B, T, ...)`. + value: value tensor of shape `(B, S, ...)` (optional, defaults to + query). + + Returns: + mask: a boolean tensor of shape `(1, T, S)` containing a lower + triangular matrix of shape `(T, S)`. + """ + q_seq_length = ops.shape(query)[1] + v_seq_length = q_seq_length if value is None else ops.shape(value)[1] + ones_mask = ops.ones((1, q_seq_length, v_seq_length), dtype="int32") + row_index = ops.cumsum(ones_mask, axis=-2) + col_index = ops.cumsum(ones_mask, axis=-1) + return ops.greater_equal(row_index, col_index) + + def compute_output_shape( + self, + query_shape, + value_shape, + key_shape=None, + ): + query_shape = tuple(query_shape) + value_shape = tuple(value_shape) + if key_shape is None: + key_shape = value_shape + else: + key_shape = tuple(key_shape) + + if value_shape[1:-1] != key_shape[1:-1]: + raise ValueError( + "All dimensions of `value` and `key`, except the last one, " + f"must be equal. Received: value_shape={value_shape} and " + f"key_shape={key_shape}" + ) + if self._output_shape: + query_shape = query_shape[:-1] + self._output_shape + return query_shape + + def compute_output_spec( + self, + query, + value, + key=None, + query_mask=None, + value_mask=None, + key_mask=None, + attention_mask=None, + return_attention_scores=False, + training=None, + use_causal_mask=False, + ): + if key is not None: + key_shape = key.shape + else: + key_shape = None + output_shape = self.compute_output_shape( + query.shape, value.shape, key_shape + ) + output_spec = backend.KerasTensor( + output_shape, dtype=self.compute_dtype + ) + if return_attention_scores: + length = query.shape[1] + attention_shape = (query.shape[0], self.num_heads, length, length) + return output_spec, backend.KerasTensor( + attention_shape, dtype=self.compute_dtype + ) + return output_spec + + +def _index_to_einsum_variable(i): + """Converts an index to a einsum variable name. + + We simply map indices to lowercase characters, e.g. 0 -> 'a', 1 -> 'b'. + """ + return string.ascii_lowercase[i] + + +def _build_attention_equation(rank, attn_axes): + """Builds einsum equations for the attention computation. + + Query, key, value inputs after projection are expected to have the shape as: + `(bs, , , num_heads, channels)`. + `bs` and `` are treated as ``. + + The attention operations can be generalized: + 1. Query-key dot product: + (, , num_heads, channels), + (, , num_heads, channels) -> + (, num_heads, , ) + 2. Combination: + (, num_heads, , ), + (, , num_heads, channels) -> (, , num_heads, channels) + + Args: + rank: Rank of query, key, value tensors. + attn_axes: List/tuple of axes, `[-1, rank)`, + that attention will be applied to. + + Returns: + Einsum equations. + """ + target_notation = "" + for i in range(rank): + target_notation += _index_to_einsum_variable(i) + # `batch_dims` includes the head dim. + batch_dims = tuple(np.delete(range(rank), attn_axes + (rank - 1,))) + letter_offset = rank + source_notation = "" + for i in range(rank): + if i in batch_dims or i == rank - 1: + source_notation += target_notation[i] + else: + source_notation += _index_to_einsum_variable(letter_offset) + letter_offset += 1 + + product_notation = "".join( + [target_notation[i] for i in batch_dims] + + [target_notation[i] for i in attn_axes] + + [source_notation[i] for i in attn_axes] + ) + dot_product_equation = "%s,%s->%s" % ( + source_notation, + target_notation, + product_notation, + ) + attn_scores_rank = len(product_notation) + combine_equation = "%s,%s->%s" % ( + product_notation, + source_notation, + target_notation, + ) + return dot_product_equation, combine_equation, attn_scores_rank + + +def _build_proj_equation(free_dims, bound_dims, output_dims): + """Builds an einsum equation for projections inside multi-head attention.""" + input_str = "" + kernel_str = "" + output_str = "" + bias_axes = "" + letter_offset = 0 + for i in range(free_dims): + char = _index_to_einsum_variable(i + letter_offset) + input_str += char + output_str += char + + letter_offset += free_dims + for i in range(bound_dims): + char = _index_to_einsum_variable(i + letter_offset) + input_str += char + kernel_str += char + + letter_offset += bound_dims + for i in range(output_dims): + char = _index_to_einsum_variable(i + letter_offset) + kernel_str += char + output_str += char + bias_axes += char + equation = f"{input_str},{kernel_str}->{output_str}" + + return equation, bias_axes, len(output_str) + + +def _get_output_shape(output_rank, known_last_dims): + return [None] * (output_rank - len(known_last_dims)) + list(known_last_dims) diff --git a/keras/src/layers/attention/multi_head_attention_test.py b/keras/src/layers/attention/multi_head_attention_test.py new file mode 100644 index 000000000000..d74abbd8841c --- /dev/null +++ b/keras/src/layers/attention/multi_head_attention_test.py @@ -0,0 +1,698 @@ +import os +import warnings + +import numpy as np +import pytest +from absl.testing import parameterized + +from keras.src import backend +from keras.src import constraints +from keras.src import dtype_policies +from keras.src import initializers +from keras.src import layers +from keras.src import models +from keras.src import ops +from keras.src import random +from keras.src import saving +from keras.src import testing +from keras.src.backend.config import disable_flash_attention +from keras.src.backend.config import enable_flash_attention +from keras.src.backend.config import is_flash_attention_enabled + + +class MultiHeadAttentionTest(testing.TestCase): + def setUp(self): + super().setUp() + # Flash attention is a newly introduced feature. We need to disable it + # for testing purposes. + disable_flash_attention() + + def tearDown(self): + enable_flash_attention() + return super().tearDown() + + def test_basics(self): + self.assertFalse(is_flash_attention_enabled()) + self.run_layer_test( + layers.MultiHeadAttention, + init_kwargs={ + "num_heads": 2, + "key_dim": 2, + }, + input_shape={"query_shape": (2, 8, 16), "value_shape": (2, 4, 16)}, + expected_output_shape=(2, 8, 16), + expected_num_trainable_weights=8, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=True, + run_training_check=False, + ) + + self.run_layer_test( + layers.MultiHeadAttention, + init_kwargs={ + "num_heads": 2, + "key_dim": 2, + "value_dim": 4, + "use_bias": False, + "dropout": 0.5, + }, + input_shape={"query_shape": (2, 8, 16), "value_shape": (2, 4, 16)}, + expected_output_shape=(2, 8, 16), + expected_num_trainable_weights=4, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=1, + expected_num_losses=0, + supports_masking=True, + run_training_check=False, + ) + + def test_basics_with_flash_attention(self): + enable_flash_attention() + if backend.backend() in ("tensorflow", "numpy"): + self.skipTest( + "Flash attention is not supported in tensorflow and numpy " + "backends." + ) + elif backend.backend() == "torch": + try: + self.run_layer_test( + layers.MultiHeadAttention, + init_kwargs={ + "num_heads": 2, + "key_dim": 8, + "dtype": "float16", + }, + input_shape={ + "query_shape": (2, 8, 16), + "value_shape": (2, 4, 16), + }, + expected_output_shape=(2, 8, 16), + expected_num_trainable_weights=8, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=True, + run_training_check=False, + ) + except ImportError as e: + if "Flash attention is not supported" in str(e.args[0]): + self.assertTrue( + ( + "Flash attention is not supported in your current " + "PyTorch version." + ) + in str(e.args[0]) + ) + except RuntimeError as e: + if ( + "Flash attention is not supported with the provided inputs" + in str(e.args[0]) + ): + self.assertTrue( + ( + "Flash attention is not supported with the " + "provided inputs" + ) + in str(e.args[0]) + ) + elif backend.backend() == "jax": + try: + self.run_layer_test( + layers.MultiHeadAttention, + init_kwargs={ + "num_heads": 2, + "key_dim": 8, + "dtype": "float16", + }, + input_shape={ + "query_shape": (2, 8, 16), + "value_shape": (2, 4, 16), + }, + expected_output_shape=(2, 8, 16), + expected_num_trainable_weights=8, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=True, + run_training_check=False, + ) + except ImportError as e: + if "Flash attention is not supported" in str(e.args[0]): + self.assertTrue( + ( + "Flash attention is not supported in your current " + "JAX version." + ) + in str(e.args[0]) + ) + except RuntimeError as e: + if "cuDNN" in str(e.args[0]): + self.assertTrue("cuDNN is not detected." in str(e.args[0])) + elif "Require at least" in str(e.args[0]): + self.assertTrue( + "Require at least Ampere arch to run" in str(e.args[0]) + ) + elif "Flash attention" in str(e.args[0]): + self.assertTrue( + ( + "Flash attention is not supported in your current " + "JAX version." + ) + in str(e.args[0]) + ) + + @parameterized.named_parameters( + ("4d_inputs_1freebatch_mask2", (3, 4), (3, 2), (4, 2), (2,)), + ("4d_inputs_1freebatch_mask3", (3, 4), (3, 2), (3, 4, 2), (2,)), + ("4d_inputs_1freebatch_mask4", (3, 4), (3, 2), (3, 2, 4, 2), (2,)), + ("4d_inputs_2d_attention", (3, 4), (3, 2), (3, 4, 3, 2), (1, 2)), + ("5d_inputs_2d_attention", (5, 3, 4), (5, 3, 2), (3, 4, 3, 2), (2, 3)), + ( + "5d_inputs_2d_attention_fullmask", + (5, 3, 4), + (5, 3, 2), + (5, 3, 4, 3, 2), + (2, 3), + ), + ) + def test_high_dim_attention( + self, q_dims, v_dims, mask_dims, attention_axes + ): + batch_size, hidden_size = 3, 8 + query_shape = (batch_size,) + q_dims + (hidden_size,) + value_shape = (batch_size,) + v_dims + (hidden_size,) + self.run_layer_test( + layers.MultiHeadAttention, + init_kwargs={ + "num_heads": 2, + "key_dim": 2, + "attention_axes": attention_axes, + }, + input_shape={ + "query_shape": query_shape, + "value_shape": value_shape, + }, + expected_output_shape=query_shape, + expected_num_trainable_weights=8, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=True, + run_training_check=False, + ) + + @parameterized.named_parameters( + ("without_key_same_proj", (4, 8), (2, 8), None, None), + ("with_key_same_proj", (4, 8), (2, 8), (2, 3), None), + ("without_key_different_proj", (4, 8), (2, 8), None, (3, 4)), + ("with_key_different_proj", (4, 8), (2, 8), (2, 3), (1, 5)), + ("high_dim_same_proj", (4, 2, 3, 8), (1, 1, 5, 8), (1, 1, 5, 2), None), + ( + "high_dim_different_proj", + (4, 2, 3, 8), + (1, 1, 5, 8), + (1, 1, 5, 2), + (3, 2), + ), + ( + "different_qv_last_dims", + (4, 2, 3, 8), + (4, 2, 3, 7), + (4, 2, 3, 8), + None, + ), + ) + def test_compute_output_shape( + self, query_dims, value_dims, key_dims, output_shape + ): + """Test computed shape is equal to the layer output's shape.""" + layer = layers.MultiHeadAttention( + num_heads=2, + key_dim=2, + value_dim=2, + output_shape=output_shape, + ) + batch_size = 7 + query_shape = (batch_size,) + query_dims + value_shape = (batch_size,) + value_dims + key_shape = (batch_size,) + key_dims if key_dims else None + + query = np.ones(query_shape) + value = np.ones(value_shape) + key = np.ones(key_shape) if key_shape else None + output = layer(query=query, value=value, key=key) + comp_output_shape = layer.compute_output_shape( + query_shape, value_shape, key_shape + ) + self.assertEqual(output.shape, comp_output_shape) + + # Test shapes as lists. + comp_output_shape = layer.compute_output_shape( + list(query_shape), + list(value_shape), + list(key_shape) if key_shape is not None else None, + ) + self.assertEqual(output.shape, comp_output_shape) + + @parameterized.named_parameters( + ("query_value_dim_mismatch", (2, 4, 8), (2, 2, 7), (2,)), + ("key_value_dim_mismatch", (2, 4, 8), (2, 2, 8), (2, 1, 7)), + ( + "key_value_dim_mismatch_high_dim", + (2, 4, 2, 3, 8), + (2, 1, 1, 5, 8), + (2, 1, 15, 5, 2), + ), + ) + def test_shape_mismatch_error(self, query_shape, value_shape, key_shape): + """Test dimension mismatches""" + layer = layers.MultiHeadAttention( + num_heads=4, + key_dim=2, + value_dim=2, + ) + with self.assertRaisesRegex(ValueError, r"must be equal"): + layer.compute_output_shape(query_shape, value_shape, key_shape) + with self.assertRaisesRegex(ValueError, r"must be equal"): + layer( + np.ones(query_shape), np.ones(value_shape), np.ones(key_shape) + ) + + def test_initializer(self): + # Test with a specified initializer. + layer = layers.MultiHeadAttention( + num_heads=12, + key_dim=64, + kernel_initializer=initializers.TruncatedNormal(stddev=0.02), + ) + layer.build((2, 4, 8), (2, 4, 8)) + + # Make sure the sub layers have different kernel init value. + self.assertNotAllClose( + layer._query_dense.kernel, + layer._key_dense.kernel, + ) + self.assertNotAllClose( + layer._query_dense.kernel, + layer._value_dense.kernel, + ) + self.assertNotAllClose( + layer._query_dense.kernel, + layer._output_dense.kernel, + ) + + @pytest.mark.skipif( + backend.backend() == "numpy", + reason="Numpy backend does not support masking.", + ) + def test_query_mask_propagation(self): + """Test automatic propagation of the query's mask.""" + try: + layer = layers.MultiHeadAttention(num_heads=2, key_dim=2) + self.assertTrue(layer.supports_masking) + query = np.array( + [[1, 2, 3, 0, 0], [3, 3, 1, 1, 2], [1, 0, 0, 0, 0]] + ) + masked_query = layers.Embedding(4, 8, mask_zero=True)(query) + query_mask = backend.get_keras_mask(masked_query) + value = np.random.normal(size=(3, 3, 8)) + output = layer(query=masked_query, value=value) + except RuntimeError as e: + if e.args[0].startswith( + "(*bias): last dimension must be contiguous" + ): + self.skipTest( + "PyTorch errors out on GPU: issue to track bug is here " + "https://github.com/keras-team/keras/issues/20459" + ) + self.assertAllClose(query_mask, output._keras_mask) + + @parameterized.named_parameters(("causal", True), ("not_causal", 0)) + @pytest.mark.skipif( + backend.backend() == "numpy", + reason="Numpy backend does not support masking.", + ) + def test_masking(self, use_causal_mask): + """Test that the value and causal masks are taken into account.""" + layer = layers.MultiHeadAttention(num_heads=2, key_dim=2) + query = np.array([[1, 2, 3, 0, 0], [3, 3, 1, 1, 2], [1, 0, 0, 0, 0]]) + masked_query = layers.Embedding(4, 8, mask_zero=True)(query) + value = np.array([[5, 4, 0], [3, 0, 0], [2, 1, 1]]) + masked_value = layers.Embedding(6, 8, mask_zero=True)(value) + output = layer( + query=masked_query, + value=masked_value, + use_causal_mask=use_causal_mask, + ) + mask = np.array( + [[[1, 1, 0]] * 3 + [[0, 0, 0]] * 2] + + [[[1, 0, 0]] * 5] + + [[[1, 1, 1]] + [[0, 0, 0]] * 4] + ) + if use_causal_mask: + mask = mask & np.array([[[1, 0, 0], [1, 1, 0]] + [[1, 1, 1]] * 3]) + del masked_query._keras_mask + del masked_value._keras_mask + output_with_manual_mask = layer( + query=masked_query, value=masked_value, attention_mask=mask + ) + self.assertAllClose(output, output_with_manual_mask) + + def test_masking_with_different_shapes(self): + x = random.uniform(shape=(2, 5, 8)) + mask = ops.tril(ops.ones((5, 5))) # (5, 5) + layer = layers.MultiHeadAttention(num_heads=2, key_dim=4) + output_1 = layer(query=x, value=x, attention_mask=mask) + + mask = ops.tile(mask[None, ...], (2, 1, 1)) # (2, 5, 5) + output_2 = layer(query=x, value=x, attention_mask=mask) + + mask = ops.tile(mask[:, None, ...], (1, 2, 1, 1)) # (2, 2, 5, 5) + output_3 = layer(query=x, value=x, attention_mask=mask) + + self.assertAllClose(output_1, output_2) + self.assertAllClose(output_1, output_3) + + @pytest.mark.skipif( + backend.backend() == "numpy", + reason="Numpy backend does not support masking.", + ) + def test_no_warning_with_keras_mask(self): + layer = layers.MultiHeadAttention(num_heads=2, key_dim=2) + query = np.array([[1, 2, 3, 0, 0], [3, 3, 1, 1, 2], [1, 0, 0, 0, 0]]) + masked_query = layers.Embedding(4, 8, mask_zero=True)(query) + value = np.array([[5, 4, 0], [3, 0, 0], [2, 1, 1]]) + masked_value = layers.Embedding(6, 8, mask_zero=True)(value) + + with warnings.catch_warnings(record=True) as warning_logs: + _ = layer(query=masked_query, value=masked_value) + self.assertLen(warning_logs, 0) + + @parameterized.named_parameters( + ("disable_flash_attention", False), ("enable_flash_attention", True) + ) + def test_correctness(self, flash_attention): + if flash_attention: + # Let the backend decide whether to use flash attention + enable_flash_attention() + dtype = "float16" # Flash attention only accepts float16/bfloat16 + + num_heads = 8 + key_dim = 8 # key_dim % 8 == 0 to enable flash attention + + query = np.identity(key_dim)[np.newaxis, ...] + key = np.identity(key_dim)[np.newaxis, ...] + value = ( + np.reshape(np.arange(key_dim * key_dim), (1, key_dim, key_dim)) + / 100.0 # Prevent overflow/underflow + ) + + # Setup layer. + layer = layers.MultiHeadAttention( + num_heads=num_heads, key_dim=key_dim, dtype=dtype + ) + layer.build(query.shape, key.shape, value.shape) + + # Set layer weights. + kernel = np.identity(key_dim) + # To get an identity kernel we need to add a head dim and repeat on it. + kernel = np.repeat(kernel[:, np.newaxis, :], num_heads, axis=1) + # Zeros for all biases. + bias = np.zeros((num_heads, key_dim)) + output_bias = np.zeros((key_dim,)) + layer.set_weights([kernel, bias] * 3 + [kernel, output_bias]) + # Call layer and assert output. + expected_output = np.array( + [2.406, 2.440, 2.473, 2.504, 2.535, 2.568, 2.602, 2.633] + ) + expected_output = np.tile( + expected_output[np.newaxis, :, np.newaxis], (1, 1, key_dim) + ) + expected_score = np.array( + [ + [0.1187] * 0 + [0.1691] + [0.1187] * 7, + [0.1187] * 1 + [0.1691] + [0.1187] * 6, + [0.1187] * 2 + [0.1691] + [0.1187] * 5, + [0.1187] * 3 + [0.1691] + [0.1187] * 4, + [0.1187] * 4 + [0.1691] + [0.1187] * 3, + [0.1187] * 5 + [0.1691] + [0.1187] * 2, + [0.1187] * 6 + [0.1691] + [0.1187] * 1, + [0.1187] * 7 + [0.1691] + [0.1187] * 0, + ] + ) + expected_score = np.tile( + expected_score[np.newaxis, np.newaxis, ...], (1, key_dim, 1, 1) + ) + if flash_attention: + output = layer(query=query, value=value, key=key) + self.assertAllClose(output, expected_output, atol=1e-2) + else: + output, scores = layer( + query=query, + value=value, + key=key, + return_attention_scores=True, + ) + self.assertAllClose(output, expected_output, atol=1e-2) + self.assertAllClose(scores, expected_score, atol=1e-2) + + def test_mha_constraints(self): + query = np.array([[[1.0, 0.0], [0.0, 1.0]]]) + key = np.array([[[0.0, 1.0], [1.0, 0.0]]]) + value = np.array([[[1.0, 2.0], [3.0, 4.0]]]) + num_heads = 2 + key_dim = 2 + layer = layers.MultiHeadAttention( + num_heads=num_heads, + key_dim=key_dim, + kernel_constraint="non_neg", + ) + layer.build(query.shape, key.shape, value.shape) + self.assertIsInstance( + layer._query_dense.kernel.constraint, constraints.NonNeg + ) + self.assertIsInstance( + layer._value_dense.kernel.constraint, constraints.NonNeg + ) + self.assertIsInstance( + layer._key_dense.kernel.constraint, constraints.NonNeg + ) + layer = layers.MultiHeadAttention( + num_heads=num_heads, + key_dim=key_dim, + bias_constraint="non_neg", + ) + layer.build(query.shape, key.shape, value.shape) + self.assertIsInstance( + layer._query_dense.bias.constraint, constraints.NonNeg + ) + self.assertIsInstance( + layer._value_dense.bias.constraint, constraints.NonNeg + ) + self.assertIsInstance( + layer._key_dense.bias.constraint, constraints.NonNeg + ) + + @pytest.mark.requires_trainable_backend + def test_lora(self): + query = np.array([[[1.0, 0.0], [0.0, 1.0]]]) + key = np.array([[[0.0, 1.0], [1.0, 0.0]]]) + value = np.array([[[1.0, 2.0], [3.0, 4.0]]]) + layer = layers.MultiHeadAttention( + num_heads=3, + key_dim=8, + use_bias=False, + ) + layer.build(query.shape, key.shape, value.shape) + layer.query_dense.enable_lora(2) + layer.key_dense.enable_lora(2) + layer.value_dense.enable_lora(2) + + self.assertLen(layer.trainable_variables, 7) + self.assertLen(layer.non_trainable_variables, 3) + + # Try eager call + x = { + "query": query, + "key": key, + "value": value, + } + y = np.random.random((1, 2, 2)) + _ = layer(**x) + + # Try calling fit() + inputs = { + "query": layers.Input((2, 2)), + "key": layers.Input((2, 2)), + "value": layers.Input((2, 2)), + } + outputs = layer(inputs["query"], inputs["key"], inputs["value"]) + model = models.Model(inputs, outputs) + model.compile(optimizer="sgd", loss="mse") + model.fit(x, y) + + # Try saving and reloading the model + temp_filepath = os.path.join(self.get_temp_dir(), "lora_model.keras") + model.save(temp_filepath) + + new_model = saving.load_model(temp_filepath) + self.assertAllClose(model.predict(x), new_model.predict(x)) + + # Try saving and reloading the model's weights only + temp_filepath = os.path.join( + self.get_temp_dir(), "lora_model.weights.h5" + ) + model.save_weights(temp_filepath) + + # Load the file into a fresh, non-lora model + inputs = { + "query": layers.Input((2, 2)), + "key": layers.Input((2, 2)), + "value": layers.Input((2, 2)), + } + outputs = layers.MultiHeadAttention( + num_heads=3, + key_dim=8, + use_bias=False, + )(inputs["query"], inputs["key"], inputs["value"]) + new_model = models.Model(inputs, outputs) + + new_model.load_weights(temp_filepath) + self.assertAllClose(model.predict(x), new_model.predict(x)) + + # Try loading a normal checkpoint into a lora model + new_model.save_weights(temp_filepath) + model.load_weights(temp_filepath) + self.assertAllClose(model.predict(x), new_model.predict(x)) + + @parameterized.parameters([((1, 2, 3),), ((2, 3, 5),)]) + def test_symbolic_return_attention_scores(self, shape): + mha = layers.MultiHeadAttention(num_heads=4, key_dim=2) + x = layers.Input(batch_shape=shape) + y = layers.Input(batch_shape=shape) + symbolic_out = mha(x, y, return_attention_scores=True) + self.assertLen(symbolic_out, 2) + + x = np.random.random(shape) + y = np.random.random(shape) + out = mha(x, y, return_attention_scores=True) + self.assertLen(out, 2) + self.assertEqual(symbolic_out[0].shape, out[0].shape) + self.assertEqual(symbolic_out[1].shape, out[1].shape) + + def test_dtype_policy_map(self): + quantized_policy = dtype_policies.QuantizedDTypePolicy( + "int8", "float32" + ) + policy_map = dtype_policies.DTypePolicyMap() + + # Preset the quantized policy + policy_map["mha/query"] = quantized_policy + policy_map["mha/key"] = quantized_policy + policy_map["mha/value"] = quantized_policy + query = np.array([[[1.0, 0.0], [0.0, 1.0]]]) + key = np.array([[[0.0, 1.0], [1.0, 0.0]]]) + value = np.array([[[1.0, 2.0], [3.0, 4.0]]]) + layer = layers.MultiHeadAttention( + num_heads=3, key_dim=8, use_bias=False, dtype=policy_map, name="mha" + ) + layer.build(query.shape, key.shape, value.shape) + + # Sublayers should be quantized + self.assertDType(layer._query_dense._kernel, "int8") + self.assertDType(layer._key_dense._kernel, "int8") + self.assertDType(layer._value_dense._kernel, "int8") + + def test_flash_attention_with_errors(self): + if backend.backend() in ("numpy", "tensorflow"): + pytest.skip( + reason=( + "Flash attention is not supported on tensorflow and numpy." + ) + ) + # Check `flash_attention=True` and `dropout=0.1` + with self.assertRaisesRegex( + ValueError, + "Dropout is not supported when flash attention is enabled.", + ): + layer = layers.MultiHeadAttention( + num_heads=2, key_dim=2, flash_attention=True, dropout=0.1 + ) + + # Check `flash_attention=True` and `return_attention_scores=True` + layer = layers.MultiHeadAttention( + num_heads=2, key_dim=2, flash_attention=True + ) + self.assertTrue(layer._flash_attention) + query = np.random.random((2, 4, 8)) + value = np.random.random((2, 4, 8)) + with self.assertRaisesRegex( + ValueError, + "Returning attention scores is not supported when flash " + "attention is enabled. Please disable flash attention to access" + " attention scores.", + ): + layer(query=query, value=value, return_attention_scores=True) + + def test_multi_head_attention_output_shape_as_int(self): + """Test MultiHeadAttention with output_shape as an int.""" + mha = layers.MultiHeadAttention(num_heads=2, key_dim=16, output_shape=8) + query = random.uniform((2, 4, 16)) + value = random.uniform((2, 4, 16)) + output = mha(query=query, value=value) + + assert output.shape == ( + 2, + 4, + 8, + ), f"Expected shape (2, 4, 8), got {output.shape}" + + def test_multi_head_attention_output_shape_as_tuple(self): + """Test MultiHeadAttention with output_shape as a tuple.""" + mha = layers.MultiHeadAttention( + num_heads=2, key_dim=16, output_shape=(8, 8) + ) + query = random.uniform((2, 4, 16)) + value = random.uniform((2, 4, 16)) + output = mha(query=query, value=value) + + assert output.shape == ( + 2, + 4, + 8, + 8, + ), f"Expected shape (2, 4, 8, 8), got {output.shape}" + + def test_multi_head_attention_output_shape_error(self): + with self.assertRaisesRegex(ValueError, r"Invalid `output_shape`"): + layers.MultiHeadAttention(num_heads=2, key_dim=16, output_shape=8.0) + + def test_quantize_int8(self): + query = np.array([[[1.0, 0.0], [0.0, 1.0]]]) + key = np.array([[[0.0, 1.0], [1.0, 0.0]]]) + value = np.array([[[1.0, 2.0], [3.0, 4.0]]]) + layer = layers.MultiHeadAttention( + num_heads=3, + key_dim=8, + use_bias=False, + ) + layer.build(query.shape, value.shape, key.shape) + output_float = layer(query, key, value) + for sublayer in layer._flatten_layers(): + try: + sublayer.quantize("int8") + except: + pass + + # Verify weights dtype + self.assertDType(layer._query_dense._kernel, "int8") + self.assertDType(layer._key_dense._kernel, "int8") + self.assertDType(layer._value_dense._kernel, "int8") + self.assertDType(layer._output_dense._kernel, "int8") + + # Try eager call and verify output correctness + output_quantized = layer(query, key, value) + mse = ops.mean(ops.square(output_float - output_quantized)) + self.assertLess(mse, 1e-3) # A weak correctness test diff --git a/keras/src/layers/convolutional/__init__.py b/keras/src/layers/convolutional/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/keras/src/layers/convolutional/base_conv.py b/keras/src/layers/convolutional/base_conv.py new file mode 100644 index 000000000000..9b43cab4bd22 --- /dev/null +++ b/keras/src/layers/convolutional/base_conv.py @@ -0,0 +1,413 @@ +"""Keras base class for convolution layers.""" + +from keras.src import activations +from keras.src import constraints +from keras.src import initializers +from keras.src import ops +from keras.src import regularizers +from keras.src.backend import standardize_data_format +from keras.src.layers.input_spec import InputSpec +from keras.src.layers.layer import Layer +from keras.src.ops.operation_utils import compute_conv_output_shape +from keras.src.utils.argument_validation import standardize_padding +from keras.src.utils.argument_validation import standardize_tuple + + +class BaseConv(Layer): + """Abstract N-D convolution layer (private, used as implementation base). + + This layer creates a convolution kernel that is convolved (actually + cross-correlated) with the layer input to produce a tensor of outputs. If + `use_bias` is True (and a `bias_initializer` is provided), a bias vector is + created and added to the outputs. Finally, if `activation` is not `None`, it + is applied to the outputs as well. + + Note: layer attributes cannot be modified after the layer has been called + once (except the `trainable` attribute). + + Args: + rank: int, the rank of the convolution, e.g. 2 for 2D convolution. + filters: int, the dimension of the output space (the number of filters + in the convolution). + kernel_size: int or tuple/list of `rank` integers, specifying the size + of the convolution window. + strides: int or tuple/list of `rank` integers, specifying the stride + length of the convolution. If only one int is specified, the same + stride size will be used for all dimensions. `strides > 1` is + incompatible with `dilation_rate > 1`. + padding: string, either `"valid"` or `"same"` (case-insensitive). + `"valid"` means no padding. `"same"` results in padding evenly to + the left/right or up/down of the input. When `padding="same"` and + `strides=1`, the output has the same size as the input. + data_format: string, either `"channels_last"` or `"channels_first"`. + The ordering of the dimensions in the inputs. `"channels_last"` + corresponds to inputs with shape `(batch, steps, features)` + while `"channels_first"` corresponds to inputs with shape + `(batch, features, steps)`. It defaults to the `image_data_format` + value found in your Keras config file at `~/.keras/keras.json`. + If you never set it, then it will be `"channels_last"`. + dilation_rate: int or tuple/list of `rank` integers, specifying the + dilation rate to use for dilated convolution. If only one int is + specified, the same dilation rate will be used for all dimensions. + groups: A positive int specifying the number of groups in which the + input is split along the channel axis. Each group is convolved + separately with `filters // groups` filters. The output is the + concatenation of all the `groups` results along the channel axis. + Input channels and `filters` must both be divisible by `groups`. + activation: Activation function. If `None`, no activation is applied. + use_bias: bool, if `True`, bias will be added to the output. + kernel_initializer: Initializer for the convolution kernel. If `None`, + the default initializer (`"glorot_uniform"`) will be used. + bias_initializer: Initializer for the bias vector. If `None`, the + default initializer (`"zeros"`) will be used. + kernel_regularizer: Optional regularizer for the convolution kernel. + bias_regularizer: Optional regularizer for the bias vector. + activity_regularizer: Optional regularizer function for the output. + kernel_constraint: Optional projection function to be applied to the + kernel after being updated by an `Optimizer` (e.g. used to implement + norm constraints or value constraints for layer weights). The + function must take as input the unprojected variable and must return + the projected variable (which must have the same shape). Constraints + are not safe to use when doing asynchronous distributed training. + bias_constraint: Optional projection function to be applied to the + bias after being updated by an `Optimizer`. + lora_rank: Optional integer. If set, the layer's forward pass + will implement LoRA (Low-Rank Adaptation) + with the provided rank. LoRA sets the layer's kernel + to non-trainable and replaces it with a delta over the + original kernel, obtained via multiplying two lower-rank + trainable matrices. This can be useful to reduce the + computation cost of fine-tuning large dense layers. + You can also enable LoRA on an existing layer by calling + `layer.enable_lora(rank)`. + lora_alpha: Optional integer. If set, this parameter scales the + low-rank adaptation delta (computed as the product of two lower-rank + trainable matrices) during the forward pass. The delta is scaled by + `lora_alpha / lora_rank`, allowing you to fine-tune the strength of + the LoRA adjustment independently of `lora_rank`. + """ + + def __init__( + self, + rank, + filters, + kernel_size, + strides=1, + padding="valid", + data_format=None, + dilation_rate=1, + groups=1, + activation=None, + use_bias=True, + kernel_initializer="glorot_uniform", + bias_initializer="zeros", + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + bias_constraint=None, + lora_rank=None, + lora_alpha=None, + **kwargs, + ): + super().__init__(activity_regularizer=activity_regularizer, **kwargs) + self.rank = rank + self.filters = filters + self.groups = groups + self.kernel_size = standardize_tuple(kernel_size, rank, "kernel_size") + self.strides = standardize_tuple(strides, rank, "strides") + self.dilation_rate = standardize_tuple( + dilation_rate, rank, "dilation_rate" + ) + self.padding = standardize_padding(padding, allow_causal=rank == 1) + self.data_format = standardize_data_format(data_format) + self.activation = activations.get(activation) + self.use_bias = use_bias + self.kernel_initializer = initializers.get(kernel_initializer) + self.bias_initializer = initializers.get(bias_initializer) + self.kernel_regularizer = regularizers.get(kernel_regularizer) + self.bias_regularizer = regularizers.get(bias_regularizer) + self.kernel_constraint = constraints.get(kernel_constraint) + self.bias_constraint = constraints.get(bias_constraint) + self.lora_rank = lora_rank + self.lora_alpha = lora_alpha if lora_alpha is not None else lora_rank + self.lora_enabled = False + self.input_spec = InputSpec(min_ndim=self.rank + 2) + self.data_format = self.data_format + + if self.filters is not None and self.filters <= 0: + raise ValueError( + "Invalid value for argument `filters`. Expected a strictly " + f"positive value. Received filters={self.filters}." + ) + + if self.groups <= 0: + raise ValueError( + "The number of groups must be a positive integer. " + f"Received: groups={self.groups}." + ) + + if self.filters is not None and self.filters % self.groups != 0: + raise ValueError( + "The number of filters must be evenly divisible by the " + f"number of groups. Received: groups={self.groups}, " + f"filters={self.filters}." + ) + + if not all(self.kernel_size): + raise ValueError( + "The argument `kernel_size` cannot contain 0. Received " + f"kernel_size={self.kernel_size}." + ) + + if not all(self.strides): + raise ValueError( + "The argument `strides` cannot contains 0. Received " + f"strides={self.strides}" + ) + + if max(self.strides) > 1 and max(self.dilation_rate) > 1: + raise ValueError( + "`strides > 1` not supported in conjunction with " + f"`dilation_rate > 1`. Received: strides={self.strides} and " + f"dilation_rate={self.dilation_rate}" + ) + + def build(self, input_shape): + if self.data_format == "channels_last": + channel_axis = -1 + input_channel = input_shape[-1] + else: + channel_axis = 1 + input_channel = input_shape[1] + self.input_spec = InputSpec( + min_ndim=self.rank + 2, axes={channel_axis: input_channel} + ) + if input_channel % self.groups != 0: + raise ValueError( + "The number of input channels must be evenly divisible by " + f"the number of groups. Received groups={self.groups}, but the " + f"input has {input_channel} channels (full input shape is " + f"{input_shape})." + ) + kernel_shape = self.kernel_size + ( + input_channel // self.groups, + self.filters, + ) + + # compute_output_shape contains some validation logic for the input + # shape, and make sure the output shape has all positive dimensions. + self.compute_output_shape(input_shape) + + self._kernel = self.add_weight( + name="kernel", + shape=kernel_shape, + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, + constraint=self.kernel_constraint, + trainable=True, + dtype=self.dtype, + ) + if self.use_bias: + self.bias = self.add_weight( + name="bias", + shape=(self.filters,), + initializer=self.bias_initializer, + regularizer=self.bias_regularizer, + constraint=self.bias_constraint, + trainable=True, + dtype=self.dtype, + ) + else: + self.bias = None + self.built = True + if self.lora_rank: + self.enable_lora(self.lora_rank, lora_alpha=self.lora_alpha) + + @property + def kernel(self): + if not self.built: + raise AttributeError( + "You must build the layer before accessing `kernel`." + ) + if self.lora_enabled: + return self._kernel + ( + self.lora_alpha / self.lora_rank + ) * ops.matmul(self.lora_kernel_a, self.lora_kernel_b) + return self._kernel + + def convolution_op(self, inputs, kernel): + return ops.conv( + inputs, + kernel, + strides=list(self.strides), + padding=self.padding, + dilation_rate=self.dilation_rate, + data_format=self.data_format, + ) + + def call(self, inputs): + outputs = self.convolution_op( + inputs, + self.kernel, + ) + if self.use_bias: + if self.data_format == "channels_last": + bias_shape = (1,) * (self.rank + 1) + (self.filters,) + else: + bias_shape = (1, self.filters) + (1,) * self.rank + bias = ops.reshape(self.bias, bias_shape) + outputs = ops.add(outputs, bias) + + if self.activation is not None: + return self.activation(outputs) + return outputs + + def compute_output_shape(self, input_shape): + return compute_conv_output_shape( + input_shape, + self.filters, + self.kernel_size, + strides=self.strides, + padding=self.padding, + data_format=self.data_format, + dilation_rate=self.dilation_rate, + ) + + def enable_lora( + self, + rank, + lora_alpha=None, + a_initializer="he_uniform", + b_initializer="zeros", + ): + if self.kernel_constraint: + raise ValueError( + "Lora is incompatible with kernel constraints. " + "In order to enable lora on this layer, remove the " + "`kernel_constraint` argument." + ) + if not self.built: + raise ValueError( + "Cannot enable lora on a layer that isn't yet built." + ) + if self.lora_enabled: + raise ValueError( + "lora is already enabled. This can only be done once per layer." + ) + self._tracker.unlock() + self.lora_kernel_a = self.add_weight( + name="lora_kernel_a", + shape=self._kernel.shape[:-1] + (rank,), + initializer=initializers.get(a_initializer), + regularizer=self.kernel_regularizer, + ) + self.lora_kernel_b = self.add_weight( + name="lora_kernel_b", + shape=(rank, self.filters), + initializer=initializers.get(b_initializer), + regularizer=self.kernel_regularizer, + ) + self._kernel.trainable = False + self._tracker.lock() + self.lora_enabled = True + self.lora_rank = rank + self.lora_alpha = lora_alpha if lora_alpha is not None else rank + + def save_own_variables(self, store): + # Do nothing if the layer isn't yet built + if not self.built: + return + target_variables = [self.kernel] + if self.use_bias: + target_variables.append(self.bias) + for i, variable in enumerate(target_variables): + store[str(i)] = variable + + def load_own_variables(self, store): + if not self.lora_enabled: + self._check_load_own_variables(store) + # Do nothing if the layer isn't yet built + if not self.built: + return + target_variables = [self._kernel] + if self.use_bias: + target_variables.append(self.bias) + for i, variable in enumerate(target_variables): + variable.assign(store[str(i)]) + if self.lora_enabled: + self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape)) + self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape)) + + def get_config(self): + config = super().get_config() + config.update( + { + "filters": self.filters, + "kernel_size": self.kernel_size, + "strides": self.strides, + "padding": self.padding, + "data_format": self.data_format, + "dilation_rate": self.dilation_rate, + "groups": self.groups, + "activation": activations.serialize(self.activation), + "use_bias": self.use_bias, + "kernel_initializer": initializers.serialize( + self.kernel_initializer + ), + "bias_initializer": initializers.serialize( + self.bias_initializer + ), + "kernel_regularizer": regularizers.serialize( + self.kernel_regularizer + ), + "bias_regularizer": regularizers.serialize( + self.bias_regularizer + ), + "activity_regularizer": regularizers.serialize( + self.activity_regularizer + ), + "kernel_constraint": constraints.serialize( + self.kernel_constraint + ), + "bias_constraint": constraints.serialize(self.bias_constraint), + } + ) + if self.lora_rank: + config["lora_rank"] = self.lora_rank + config["lora_alpha"] = self.lora_alpha + return config + + def _check_load_own_variables(self, store): + all_vars = self._trainable_variables + self._non_trainable_variables + if len(store.keys()) != len(all_vars): + if len(all_vars) == 0 and not self.built: + raise ValueError( + f"Layer '{self.name}' was never built " + "and thus it doesn't have any variables. " + f"However the weights file lists {len(store.keys())} " + "variables for this layer.\n" + "In most cases, this error indicates that either:\n\n" + "1. The layer is owned by a parent layer that " + "implements a `build()` method, but calling the " + "parent's `build()` method did NOT create the state of " + f"the child layer '{self.name}'. A `build()` method " + "must create ALL state for the layer, including " + "the state of any children layers.\n\n" + "2. You need to implement " + "the `def build_from_config(self, config)` method " + f"on layer '{self.name}', to specify how to rebuild " + "it during loading. " + "In this case, you might also want to implement the " + "method that generates the build config at saving time, " + "`def get_build_config(self)`. " + "The method `build_from_config()` is meant " + "to create the state " + "of the layer (i.e. its variables) upon deserialization.", + ) + raise ValueError( + f"Layer '{self.name}' expected {len(all_vars)} variables, " + "but received " + f"{len(store.keys())} variables during loading. " + f"Expected: {[v.name for v in all_vars]}" + ) diff --git a/keras/src/layers/convolutional/base_conv_transpose.py b/keras/src/layers/convolutional/base_conv_transpose.py new file mode 100644 index 000000000000..101a7d47d2a1 --- /dev/null +++ b/keras/src/layers/convolutional/base_conv_transpose.py @@ -0,0 +1,259 @@ +"""Keras base class for transpose convolution layers.""" + +from keras.src import activations +from keras.src import constraints +from keras.src import initializers +from keras.src import ops +from keras.src import regularizers +from keras.src.backend import standardize_data_format +from keras.src.backend.common.backend_utils import ( + compute_conv_transpose_output_shape, +) +from keras.src.layers.input_spec import InputSpec +from keras.src.layers.layer import Layer +from keras.src.utils.argument_validation import standardize_padding +from keras.src.utils.argument_validation import standardize_tuple + + +class BaseConvTranspose(Layer): + """Abstract N-D transposed convolution layer. + + The need for transposed convolutions generally arises from the desire to use + a transformation going in the opposite direction of a normal convolution, + i.e., from something that has the shape of the output of some convolution to + something that has the shape of its input while maintaining a connectivity + pattern that is compatible with said convolution. + + Args: + rank: int, the rank of the transposed convolution, e.g. 2 for 2D + transposed convolution. + filters: int, the dimension of the output space (the number of filters + in the transposed convolution). + kernel_size: int or tuple/list of `rank` integers, specifying the size + of the transposed convolution window. + strides: int or tuple/list of `rank` integers, specifying the stride + length of the transposed convolution. If only one int is specified, + the same stride size will be used for all dimensions. + `strides > 1` is incompatible with `dilation_rate > 1`. + padding: string, either `"valid"` or `"same"` (case-insensitive). + `"valid"` means no padding. `"same"` results in padding evenly to + the left/right or up/down of the input such that output has the same + height/width dimension as the input. + data_format: string, either `"channels_last"` or `"channels_first"`. + The ordering of the dimensions in the inputs. `"channels_last"` + corresponds to inputs with shape `(batch, steps, features)` + while `"channels_first"` corresponds to inputs with shape + `(batch, features, steps)`. It defaults to the `image_data_format` + value found in your Keras config file at `~/.keras/keras.json`. + If you never set it, then it will be `"channels_last"`. + dilation_rate: int or tuple/list of `rank` integers, specifying the + dilation rate to use for dilated convolution. If only one int is + specified, the same dilation rate will be used for all dimensions. + activation: Activation function. If `None`, no activation is applied. + use_bias: bool, if `True`, bias will be added to the output. + kernel_initializer: Initializer for the convolution kernel. If `None`, + the default initializer (`"glorot_uniform"`) will be used. + bias_initializer: Initializer for the bias vector. If `None`, the + default initializer (`"zeros"`) will be used. + kernel_regularizer: Optional regularizer for the convolution kernel. + bias_regularizer: Optional regularizer for the bias vector. + activity_regularizer: Optional regularizer function for the output. + kernel_constraint: Optional projection function to be applied to the + kernel after being updated by an `Optimizer` (e.g. used to implement + norm constraints or value constraints for layer weights). The + function must take as input the unprojected variable and must return + the projected variable (which must have the same shape). Constraints + are not safe to use when doing asynchronous distributed training. + bias_constraint: Optional projection function to be applied to the + bias after being updated by an `Optimizer`. + """ + + def __init__( + self, + rank, + filters, + kernel_size, + strides=1, + padding="valid", + output_padding=None, + data_format=None, + dilation_rate=1, + activation=None, + use_bias=True, + kernel_initializer="glorot_uniform", + bias_initializer="zeros", + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + bias_constraint=None, + trainable=True, + name=None, + **kwargs, + ): + super().__init__( + trainable=trainable, + name=name, + activity_regularizer=activity_regularizer, + **kwargs, + ) + self.rank = rank + self.filters = filters + self.kernel_size = standardize_tuple(kernel_size, rank, "kernel_size") + self.strides = standardize_tuple(strides, rank, "strides") + self.dilation_rate = standardize_tuple( + dilation_rate, rank, "dilation_rate" + ) + self.padding = standardize_padding(padding) + if output_padding is None: + self.output_padding = None + else: + self.output_padding = standardize_tuple( + output_padding, + rank, + "output_padding", + allow_zero=True, + ) + self.data_format = standardize_data_format(data_format) + self.activation = activations.get(activation) + self.use_bias = use_bias + self.kernel_initializer = initializers.get(kernel_initializer) + self.bias_initializer = initializers.get(bias_initializer) + self.kernel_regularizer = regularizers.get(kernel_regularizer) + self.bias_regularizer = regularizers.get(bias_regularizer) + self.kernel_constraint = constraints.get(kernel_constraint) + self.bias_constraint = constraints.get(bias_constraint) + self.input_spec = InputSpec(min_ndim=self.rank + 2) + self.data_format = self.data_format + + if self.filters is not None and self.filters <= 0: + raise ValueError( + "Invalid value for argument `filters`. Expected a strictly " + f"positive value. Received filters={self.filters}." + ) + + if not all(self.kernel_size): + raise ValueError( + "The argument `kernel_size` cannot contain 0. Received " + f"kernel_size={self.kernel_size}." + ) + + if not all(self.strides): + raise ValueError( + "The argument `strides` cannot contains 0. Received " + f"strides={self.strides}." + ) + + if max(self.strides) > 1 and max(self.dilation_rate) > 1: + raise ValueError( + "`strides > 1` not supported in conjunction with " + f"`dilation_rate > 1`. Received: strides={self.strides} and " + f"dilation_rate={self.dilation_rate}" + ) + + def build(self, input_shape): + if self.data_format == "channels_last": + channel_axis = -1 + input_channel = input_shape[-1] + else: + channel_axis = 1 + input_channel = input_shape[1] + self.input_spec = InputSpec( + min_ndim=self.rank + 2, axes={channel_axis: input_channel} + ) + kernel_shape = self.kernel_size + ( + self.filters, + input_channel, + ) + + self.kernel = self.add_weight( + name="kernel", + shape=kernel_shape, + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, + constraint=self.kernel_constraint, + trainable=True, + dtype=self.dtype, + ) + if self.use_bias: + self.bias = self.add_weight( + name="bias", + shape=(self.filters,), + initializer=self.bias_initializer, + regularizer=self.bias_regularizer, + constraint=self.bias_constraint, + trainable=True, + dtype=self.dtype, + ) + else: + self.bias = None + + def call(self, inputs): + outputs = ops.conv_transpose( + inputs, + self.kernel, + strides=list(self.strides), + padding=self.padding, + output_padding=self.output_padding, + dilation_rate=self.dilation_rate, + data_format=self.data_format, + ) + + if self.use_bias: + if self.data_format == "channels_last": + bias_shape = (1,) * (self.rank + 1) + (self.filters,) + else: + bias_shape = (1, self.filters) + (1,) * self.rank + bias = ops.reshape(self.bias, bias_shape) + outputs = ops.add(outputs, bias) + + if self.activation is not None: + return self.activation(outputs) + return outputs + + def compute_output_shape(self, input_shape): + return compute_conv_transpose_output_shape( + input_shape, + self.kernel_size, + self.filters, + strides=self.strides, + padding=self.padding, + output_padding=self.output_padding, + data_format=self.data_format, + dilation_rate=self.dilation_rate, + ) + + def get_config(self): + config = super().get_config() + config.update( + { + "filters": self.filters, + "kernel_size": self.kernel_size, + "strides": self.strides, + "padding": self.padding, + "data_format": self.data_format, + "dilation_rate": self.dilation_rate, + "activation": activations.serialize(self.activation), + "use_bias": self.use_bias, + "kernel_initializer": initializers.serialize( + self.kernel_initializer + ), + "bias_initializer": initializers.serialize( + self.bias_initializer + ), + "kernel_regularizer": regularizers.serialize( + self.kernel_regularizer + ), + "bias_regularizer": regularizers.serialize( + self.bias_regularizer + ), + "activity_regularizer": regularizers.serialize( + self.activity_regularizer + ), + "kernel_constraint": constraints.serialize( + self.kernel_constraint + ), + "bias_constraint": constraints.serialize(self.bias_constraint), + } + ) + return config diff --git a/keras/src/layers/convolutional/base_depthwise_conv.py b/keras/src/layers/convolutional/base_depthwise_conv.py new file mode 100644 index 000000000000..b4e529d607f9 --- /dev/null +++ b/keras/src/layers/convolutional/base_depthwise_conv.py @@ -0,0 +1,273 @@ +"""Keras base class for depthwise convolution layers.""" + +from keras.src import activations +from keras.src import constraints +from keras.src import initializers +from keras.src import ops +from keras.src import regularizers +from keras.src.backend import standardize_data_format +from keras.src.layers.input_spec import InputSpec +from keras.src.layers.layer import Layer +from keras.src.ops.operation_utils import compute_conv_output_shape +from keras.src.utils.argument_validation import standardize_padding +from keras.src.utils.argument_validation import standardize_tuple + + +class BaseDepthwiseConv(Layer): + """Abstract N-D depthwise convolution layer. + + Depthwise convolution is a type of convolution in which each input channel + is convolved with a different kernel (called a depthwise kernel). You can + understand depthwise convolution as the first step in a depthwise separable + convolution. + + It is implemented via the following steps: + + - Split the input into individual channels. + - Convolve each channel with an individual depthwise kernel with + `depth_multiplier` output channels. + - Concatenate the convolved outputs along the channels axis. + + Unlike a regular convolution, depthwise convolution does not mix information + across different input channels. + + The `depth_multiplier` argument determines how many filter are applied to + one input channel. As such, it controls the amount of output channels that + are generated per input channel in the depthwise step. + + + Args: + rank: int, the rank of the convolution, e.g. 2 for 2D convolution. + depth_multiplier: The number of depthwise convolution output channels + for each input channel. The total number of depthwise convolution + output channels will be equal to `input_channel * depth_multiplier`. + kernel_size: int or tuple/list of `rank` integers, specifying the size + of the depthwise convolution window. + strides: int or tuple/list of `rank` integers, specifying the stride + length of the depthwise convolution. If only one int is specified, + the same stride size will be used for all dimensions. + `strides > 1` is incompatible with `dilation_rate > 1`. + padding: string, either `"valid"` or `"same"` (case-insensitive). + `"valid"` means no padding. `"same"` results in padding evenly to + the left/right or up/down of the input. When `padding="same"` and + `strides=1`, the output has the same size as the input. + data_format: string, either `"channels_last"` or `"channels_first"`. + The ordering of the dimensions in the inputs. `"channels_last"` + corresponds to inputs with shape `(batch, steps, features)` + while `"channels_first"` corresponds to inputs with shape + `(batch, features, steps)`. It defaults to the `image_data_format` + value found in your Keras config file at `~/.keras/keras.json`. + If you never set it, then it will be `"channels_last"`. + dilation_rate: int or tuple/list of `rank` integers, specifying the + dilation rate to use for dilated convolution. If only one int is + specified, the same dilation rate will be used for all dimensions. + activation: Activation function. If `None`, no activation is applied. + use_bias: bool, if `True`, bias will be added to the output. + depthwise_initializer: Initializer for the depthwsie convolution + kernel. If `None`, the default initializer (`"glorot_uniform"`) + will be used. + bias_initializer: Initializer for the bias vector. If `None`, the + default initializer (`"zeros"`) will be used. + depthwise_regularizer: Optional regularizer for the convolution kernel. + bias_regularizer: Optional regularizer for the bias vector. + activity_regularizer: Optional regularizer function for the output. + depthwise_constraint: Optional projection function to be applied to the + kernel after being updated by an `Optimizer` (e.g. used to implement + norm constraints or value constraints for layer weights). The + function must take as input the unprojected variable and must return + the projected variable (which must have the same shape). Constraints + are not safe to use when doing asynchronous distributed training. + bias_constraint: Optional projection function to be applied to the + bias after being updated by an `Optimizer`. + """ + + def __init__( + self, + rank, + depth_multiplier, + kernel_size, + strides=1, + padding="valid", + data_format=None, + dilation_rate=1, + activation=None, + use_bias=True, + depthwise_initializer="glorot_uniform", + bias_initializer="zeros", + depthwise_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + depthwise_constraint=None, + bias_constraint=None, + trainable=True, + name=None, + **kwargs, + ): + super().__init__( + trainable=trainable, + name=name, + activity_regularizer=regularizers.get(activity_regularizer), + **kwargs, + ) + self.rank = rank + self.depth_multiplier = depth_multiplier + self.kernel_size = standardize_tuple(kernel_size, rank, "kernel_size") + self.strides = standardize_tuple(strides, rank, "strides") + self.dilation_rate = standardize_tuple( + dilation_rate, rank, "dilation_rate" + ) + self.padding = standardize_padding(padding) + self.data_format = standardize_data_format(data_format) + self.activation = activations.get(activation) + self.use_bias = use_bias + self.depthwise_initializer = initializers.get(depthwise_initializer) + self.bias_initializer = initializers.get(bias_initializer) + self.depthwise_regularizer = regularizers.get(depthwise_regularizer) + self.bias_regularizer = regularizers.get(bias_regularizer) + self.depthwise_constraint = constraints.get(depthwise_constraint) + self.bias_constraint = constraints.get(bias_constraint) + self.input_spec = InputSpec(min_ndim=self.rank + 2) + self.data_format = self.data_format + + if self.depth_multiplier is not None and self.depth_multiplier <= 0: + raise ValueError( + "Invalid value for argument `depth_multiplier`. Expected a " + "strictly positive value. Received " + f"depth_multiplier={self.depth_multiplier}." + ) + + if not all(self.kernel_size): + raise ValueError( + "The argument `kernel_size` cannot contain 0. Received " + f"kernel_size={self.kernel_size}." + ) + + if not all(self.strides): + raise ValueError( + "The argument `strides` cannot contains 0. Received " + f"strides={self.strides}" + ) + + if max(self.strides) > 1 and max(self.dilation_rate) > 1: + raise ValueError( + "`strides > 1` not supported in conjunction with " + f"`dilation_rate > 1`. Received: strides={self.strides} and " + f"dilation_rate={self.dilation_rate}" + ) + + def build(self, input_shape): + if self.data_format == "channels_last": + channel_axis = -1 + input_channel = input_shape[-1] + else: + channel_axis = 1 + input_channel = input_shape[1] + self.input_spec = InputSpec( + min_ndim=self.rank + 2, axes={channel_axis: input_channel} + ) + depthwise_shape = self.kernel_size + ( + input_channel, + self.depth_multiplier, + ) + self.kernel = self.add_weight( + name="kernel", + shape=depthwise_shape, + initializer=self.depthwise_initializer, + regularizer=self.depthwise_regularizer, + constraint=self.depthwise_constraint, + trainable=True, + dtype=self.dtype, + ) + if self.use_bias: + self.bias = self.add_weight( + name="bias", + shape=(self.depth_multiplier * input_channel,), + initializer=self.bias_initializer, + regularizer=self.bias_regularizer, + constraint=self.bias_constraint, + trainable=True, + dtype=self.dtype, + ) + else: + self.bias = None + + def _get_input_channel(self, input_shape): + if self.data_format == "channels_last": + input_channel = input_shape[-1] + else: + input_channel = input_shape[1] + return input_channel + + def call(self, inputs): + input_channel = self._get_input_channel(inputs.shape) + outputs = ops.depthwise_conv( + inputs, + self.kernel, + strides=self.strides, + padding=self.padding, + dilation_rate=self.dilation_rate, + data_format=self.data_format, + ) + + if self.use_bias: + if self.data_format == "channels_last": + bias_shape = (1,) * (self.rank + 1) + ( + self.depth_multiplier * input_channel, + ) + else: + bias_shape = (1, self.depth_multiplier * input_channel) + ( + 1, + ) * self.rank + bias = ops.reshape(self.bias, bias_shape) + outputs = ops.add(outputs, bias) + + if self.activation is not None: + return self.activation(outputs) + return outputs + + def compute_output_shape(self, input_shape): + input_channel = self._get_input_channel(input_shape) + return compute_conv_output_shape( + input_shape, + self.depth_multiplier * input_channel, + self.kernel_size, + strides=self.strides, + padding=self.padding, + data_format=self.data_format, + dilation_rate=self.dilation_rate, + ) + + def get_config(self): + config = super().get_config() + config.update( + { + "depth_multiplier": self.depth_multiplier, + "kernel_size": self.kernel_size, + "strides": self.strides, + "padding": self.padding, + "data_format": self.data_format, + "dilation_rate": self.dilation_rate, + "activation": activations.serialize(self.activation), + "use_bias": self.use_bias, + "depthwise_initializer": initializers.serialize( + self.depthwise_initializer + ), + "bias_initializer": initializers.serialize( + self.bias_initializer + ), + "depthwise_regularizer": regularizers.serialize( + self.depthwise_regularizer + ), + "bias_regularizer": regularizers.serialize( + self.bias_regularizer + ), + "activity_regularizer": regularizers.serialize( + self.activity_regularizer + ), + "depthwise_constraint": constraints.serialize( + self.depthwise_constraint + ), + "bias_constraint": constraints.serialize(self.bias_constraint), + } + ) + return config diff --git a/keras/src/layers/convolutional/base_separable_conv.py b/keras/src/layers/convolutional/base_separable_conv.py new file mode 100644 index 000000000000..2fcfc23fe521 --- /dev/null +++ b/keras/src/layers/convolutional/base_separable_conv.py @@ -0,0 +1,294 @@ +"""Keras abstract base layer for separable convolution.""" + +from keras.src import activations +from keras.src import constraints +from keras.src import initializers +from keras.src import ops +from keras.src import regularizers +from keras.src.backend import standardize_data_format +from keras.src.layers.input_spec import InputSpec +from keras.src.layers.layer import Layer +from keras.src.ops.operation_utils import compute_conv_output_shape +from keras.src.utils.argument_validation import standardize_padding +from keras.src.utils.argument_validation import standardize_tuple + + +class BaseSeparableConv(Layer): + """Abstract base layer for separable convolution. + + This layer performs a depthwise convolution that acts separately on + channels, followed by a pointwise convolution that mixes channels. If + `use_bias` is True and a bias initializer is provided, it adds a bias vector + to the output. + + Args: + rank: int, the rank of the convolution, e.g. 2 for 2D convolution. + depth_multiplier: The number of depthwise convolution output channels + for each input channel. The total number of depthwise convolution + output channels will be equal to `input_channel * depth_multiplier`. + filters: int, the dimensionality of the output space (i.e. the number + of filters in the pointwise convolution). + kernel_size: int or tuple/list of `rank` integers, specifying the size + of the depthwise convolution window. + strides: int or tuple/list of `rank` integers, specifying the stride + length of the depthwise convolution. If only one int is specified, + the same stride size will be used for all dimensions. + `stride value != 1` is incompatible with `dilation_rate != 1`. + padding: string, either `"valid"` or `"same"` (case-insensitive). + `"valid"` means no padding. `"same"` results in padding evenly to + the left/right or up/down of the input. When `padding="same"` and + `strides=1`, the output has the same size as the input. + data_format: string, either `"channels_last"` or `"channels_first"`. + The ordering of the dimensions in the inputs. `"channels_last"` + corresponds to inputs with shape `(batch, steps, features)` + while `"channels_first"` corresponds to inputs with shape + `(batch, features, steps)`. It defaults to the `image_data_format` + value found in your Keras config file at `~/.keras/keras.json`. + If you never set it, then it will be `"channels_last"`. + dilation_rate: int or tuple/list of `rank` integers, specifying the + dilation rate to use for dilated convolution. If only one int is + specified, the same dilation rate will be used for all dimensions. + activation: Activation function. If `None`, no activation is applied. + use_bias: bool, if `True`, bias will be added to the output. + depthwise_initializer: An initializer for the depthwise convolution + kernel. If None, then the default initializer (`"glorot_uniform"`) + will be used. + pointwise_initializer: An initializer for the pointwise convolution + kernel. If None, then the default initializer (`"glorot_uniform"`) + will be used. + bias_initializer: An initializer for the bias vector. If None, the + default initializer ('"zeros"') will be used. + depthwise_regularizer: Optional regularizer for the depthwise + convolution kernel. + pointwise_regularizer: Optional regularizer for the pointwise + convolution kernel. + bias_regularizer: Optional regularizer for the bias vector. + activity_regularizer: Optional regularizer function for the output. + depthwise_constraint: Optional projection function to be applied to the + depthwise kernel after being updated by an `Optimizer` (e.g. used + for norm constraints or value constraints for layer weights). The + function must take as input the unprojected variable and must return + the projected variable (which must have the same shape). + pointwise_constraint: Optional projection function to be applied to the + pointwise kernel after being updated by an `Optimizer`. + bias_constraint: Optional projection function to be applied to the + bias after being updated by an `Optimizer`. + """ + + def __init__( + self, + rank, + depth_multiplier, + filters, + kernel_size, + strides=1, + padding="valid", + data_format=None, + dilation_rate=1, + activation=None, + use_bias=True, + depthwise_initializer="glorot_uniform", + pointwise_initializer="glorot_uniform", + bias_initializer="zeros", + depthwise_regularizer=None, + pointwise_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + depthwise_constraint=None, + pointwise_constraint=None, + bias_constraint=None, + trainable=True, + name=None, + **kwargs, + ): + super().__init__( + trainable=trainable, + name=name, + activity_regularizer=regularizers.get(activity_regularizer), + **kwargs, + ) + self.rank = rank + self.depth_multiplier = depth_multiplier + self.filters = filters + self.kernel_size = standardize_tuple(kernel_size, rank, "kernel_size") + self.strides = standardize_tuple(strides, rank, "strides") + self.dilation_rate = standardize_tuple( + dilation_rate, rank, "dilation_rate" + ) + self.padding = standardize_padding(padding) + self.data_format = standardize_data_format(data_format) + self.activation = activations.get(activation) + self.use_bias = use_bias + self.depthwise_initializer = initializers.get(depthwise_initializer) + self.pointwise_initializer = initializers.get(pointwise_initializer) + self.bias_initializer = initializers.get(bias_initializer) + self.depthwise_regularizer = regularizers.get(depthwise_regularizer) + self.pointwise_regularizer = regularizers.get(pointwise_regularizer) + self.bias_regularizer = regularizers.get(bias_regularizer) + self.depthwise_constraint = constraints.get(depthwise_constraint) + self.pointwise_constraint = constraints.get(pointwise_constraint) + self.bias_constraint = constraints.get(bias_constraint) + self.data_format = self.data_format + + self.input_spec = InputSpec(min_ndim=self.rank + 2) + + if self.depth_multiplier is not None and self.depth_multiplier <= 0: + raise ValueError( + "Invalid value for argument `depth_multiplier`. Expected a " + "strictly positive value. Received " + f"depth_multiplier={self.depth_multiplier}." + ) + + if self.filters is not None and self.filters <= 0: + raise ValueError( + "Invalid value for argument `filters`. Expected a strictly " + f"positive value. Received filters={self.filters}." + ) + + if not all(self.kernel_size): + raise ValueError( + "The argument `kernel_size` cannot contain 0. Received: " + f"kernel_size={self.kernel_size}." + ) + + if not all(self.strides): + raise ValueError( + "The argument `strides` cannot contains 0(s). Received: " + f"strides={self.strides}" + ) + + if max(self.strides) > 1 and max(self.dilation_rate) > 1: + raise ValueError( + "`strides > 1` not supported in conjunction with " + f"`dilation_rate > 1`. Received: strides={self.strides} and " + f"dilation_rate={self.dilation_rate}" + ) + + def build(self, input_shape): + if self.data_format == "channels_last": + channel_axis = -1 + input_channel = input_shape[-1] + else: + channel_axis = 1 + input_channel = input_shape[1] + self.input_spec = InputSpec( + min_ndim=self.rank + 2, axes={channel_axis: input_channel} + ) + depthwise_kernel_shape = self.kernel_size + ( + input_channel, + self.depth_multiplier, + ) + pointwise_kernel_shape = (1,) * self.rank + ( + self.depth_multiplier * input_channel, + self.filters, + ) + + self.depthwise_kernel = self.add_weight( + name="depthwise_kernel", + shape=depthwise_kernel_shape, + initializer=self.depthwise_initializer, + regularizer=self.depthwise_regularizer, + constraint=self.depthwise_constraint, + trainable=True, + dtype=self.dtype, + ) + self.pointwise_kernel = self.add_weight( + name="pointwise_kernel", + shape=pointwise_kernel_shape, + initializer=self.pointwise_initializer, + regularizer=self.pointwise_regularizer, + constraint=self.pointwise_constraint, + trainable=True, + dtype=self.dtype, + ) + if self.use_bias: + self.bias = self.add_weight( + name="bias", + shape=(self.filters,), + initializer=self.bias_initializer, + regularizer=self.bias_regularizer, + constraint=self.bias_constraint, + trainable=True, + dtype=self.dtype, + ) + else: + self.bias = None + + def call(self, inputs): + outputs = ops.separable_conv( + inputs, + self.depthwise_kernel, + self.pointwise_kernel, + strides=self.strides, + padding=self.padding, + dilation_rate=self.dilation_rate, + data_format=self.data_format, + ) + + if self.use_bias: + if self.data_format == "channels_last": + bias_shape = (1,) * (self.rank + 1) + (self.filters,) + else: + bias_shape = (1, self.filters) + (1,) * self.rank + bias = ops.reshape(self.bias, bias_shape) + outputs = ops.add(outputs, bias) + + if self.activation is not None: + return self.activation(outputs) + return outputs + + def compute_output_shape(self, input_shape): + return compute_conv_output_shape( + input_shape, + self.filters, + self.kernel_size, + strides=self.strides, + padding=self.padding, + data_format=self.data_format, + dilation_rate=self.dilation_rate, + ) + + def get_config(self): + config = super().get_config() + config.update( + { + "depth_multiplier": self.depth_multiplier, + "filters": self.filters, + "kernel_size": self.kernel_size, + "strides": self.strides, + "padding": self.padding, + "data_format": self.data_format, + "dilation_rate": self.dilation_rate, + "activation": activations.serialize(self.activation), + "use_bias": self.use_bias, + "depthwise_initializer": initializers.serialize( + self.depthwise_initializer + ), + "pointwise_initializer": initializers.serialize( + self.pointwise_initializer + ), + "bias_initializer": initializers.serialize( + self.bias_initializer + ), + "depthwise_regularizer": regularizers.serialize( + self.depthwise_regularizer + ), + "pointwise_regularizer": regularizers.serialize( + self.pointwise_regularizer + ), + "bias_regularizer": regularizers.serialize( + self.bias_regularizer + ), + "activity_regularizer": regularizers.serialize( + self.activity_regularizer + ), + "depthwise_constraint": constraints.serialize( + self.depthwise_constraint + ), + "pointwise_constraint": constraints.serialize( + self.pointwise_constraint + ), + "bias_constraint": constraints.serialize(self.bias_constraint), + } + ) + return config diff --git a/keras/src/layers/convolutional/conv1d.py b/keras/src/layers/convolutional/conv1d.py new file mode 100644 index 000000000000..ce1ced8c422b --- /dev/null +++ b/keras/src/layers/convolutional/conv1d.py @@ -0,0 +1,170 @@ +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.layers.convolutional.base_conv import BaseConv + + +@keras_export(["keras.layers.Conv1D", "keras.layers.Convolution1D"]) +class Conv1D(BaseConv): + """1D convolution layer (e.g. temporal convolution). + + This layer creates a convolution kernel that is convolved with the layer + input over a single spatial (or temporal) dimension to produce a tensor of + outputs. If `use_bias` is True, a bias vector is created and added to the + outputs. Finally, if `activation` is not `None`, it is applied to the + outputs as well. + + Args: + filters: int, the dimension of the output space (the number of filters + in the convolution). + kernel_size: int or tuple/list of 1 integer, specifying the size of the + convolution window. + strides: int or tuple/list of 1 integer, specifying the stride length + of the convolution. `strides > 1` is incompatible with + `dilation_rate > 1`. + padding: string, `"valid"`, `"same"` or `"causal"`(case-insensitive). + `"valid"` means no padding. `"same"` results in padding evenly to + the left/right or up/down of the input. When `padding="same"` and + `strides=1`, the output has the same size as the input. + `"causal"` results in causal(dilated) convolutions, e.g. `output[t]` + does not depend on`input[t+1:]`. Useful when modeling temporal data + where the model should not violate the temporal order. + See [WaveNet: A Generative Model for Raw Audio, section2.1]( + https://arxiv.org/abs/1609.03499). + data_format: string, either `"channels_last"` or `"channels_first"`. + The ordering of the dimensions in the inputs. `"channels_last"` + corresponds to inputs with shape `(batch, steps, features)` + while `"channels_first"` corresponds to inputs with shape + `(batch, features, steps)`. It defaults to the `image_data_format` + value found in your Keras config file at `~/.keras/keras.json`. + If you never set it, then it will be `"channels_last"`. + dilation_rate: int or tuple/list of 1 integers, specifying the dilation + rate to use for dilated convolution. + groups: A positive int specifying the number of groups in which the + input is split along the channel axis. Each group is convolved + separately with `filters // groups` filters. The output is the + concatenation of all the `groups` results along the channel axis. + Input channels and `filters` must both be divisible by `groups`. + activation: Activation function. If `None`, no activation is applied. + use_bias: bool, if `True`, bias will be added to the output. + kernel_initializer: Initializer for the convolution kernel. If `None`, + the default initializer (`"glorot_uniform"`) will be used. + bias_initializer: Initializer for the bias vector. If `None`, the + default initializer (`"zeros"`) will be used. + kernel_regularizer: Optional regularizer for the convolution kernel. + bias_regularizer: Optional regularizer for the bias vector. + activity_regularizer: Optional regularizer function for the output. + kernel_constraint: Optional projection function to be applied to the + kernel after being updated by an `Optimizer` (e.g. used to implement + norm constraints or value constraints for layer weights). The + function must take as input the unprojected variable and must return + the projected variable (which must have the same shape). Constraints + are not safe to use when doing asynchronous distributed training. + bias_constraint: Optional projection function to be applied to the + bias after being updated by an `Optimizer`. + + Input shape: + + - If `data_format="channels_last"`: + A 3D tensor with shape: `(batch_shape, steps, channels)` + - If `data_format="channels_first"`: + A 3D tensor with shape: `(batch_shape, channels, steps)` + + Output shape: + + - If `data_format="channels_last"`: + A 3D tensor with shape: `(batch_shape, new_steps, filters)` + - If `data_format="channels_first"`: + A 3D tensor with shape: `(batch_shape, filters, new_steps)` + + Returns: + A 3D tensor representing `activation(conv1d(inputs, kernel) + bias)`. + + Raises: + ValueError: when both `strides > 1` and `dilation_rate > 1`. + + Example: + + >>> # The inputs are 128-length vectors with 10 timesteps, and the + >>> # batch size is 4. + >>> x = np.random.rand(4, 10, 128) + >>> y = keras.layers.Conv1D(32, 3, activation='relu')(x) + >>> print(y.shape) + (4, 8, 32) + """ + + def __init__( + self, + filters, + kernel_size, + strides=1, + padding="valid", + data_format=None, + dilation_rate=1, + groups=1, + activation=None, + use_bias=True, + kernel_initializer="glorot_uniform", + bias_initializer="zeros", + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + bias_constraint=None, + **kwargs, + ): + super().__init__( + rank=1, + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + groups=groups, + activation=activation, + use_bias=use_bias, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + kernel_constraint=kernel_constraint, + bias_constraint=bias_constraint, + **kwargs, + ) + + def _compute_causal_padding(self): + left_pad = self.dilation_rate[0] * (self.kernel_size[0] - 1) + if self.data_format == "channels_last": + causal_padding = [[0, 0], [left_pad, 0], [0, 0]] + else: + causal_padding = [[0, 0], [0, 0], [left_pad, 0]] + return causal_padding + + def call(self, inputs): + padding = self.padding + if self.padding == "causal": + # Apply causal padding to inputs. + inputs = ops.pad(inputs, self._compute_causal_padding()) + padding = "valid" + + outputs = ops.conv( + inputs, + self.kernel, + strides=list(self.strides), + padding=padding, + dilation_rate=self.dilation_rate, + data_format=self.data_format, + ) + + if self.use_bias: + if self.data_format == "channels_last": + bias_shape = (1,) * (self.rank + 1) + (self.filters,) + else: + bias_shape = (1, self.filters) + (1,) * self.rank + bias = ops.reshape(self.bias, bias_shape) + outputs = ops.add(outputs, bias) + + if self.activation is not None: + return self.activation(outputs) + return outputs diff --git a/keras/src/layers/convolutional/conv1d_transpose.py b/keras/src/layers/convolutional/conv1d_transpose.py new file mode 100644 index 000000000000..01c2d245973d --- /dev/null +++ b/keras/src/layers/convolutional/conv1d_transpose.py @@ -0,0 +1,140 @@ +from keras.src.api_export import keras_export +from keras.src.layers.convolutional.base_conv_transpose import BaseConvTranspose + + +@keras_export( + [ + "keras.layers.Conv1DTranspose", + "keras.layers.Convolution1DTranspose", + ] +) +class Conv1DTranspose(BaseConvTranspose): + """1D transposed convolution layer. + + The need for transposed convolutions generally arise from the desire to use + a transformation going in the opposite direction of a normal convolution, + i.e., from something that has the shape of the output of some convolution + to something that has the shape of its input while maintaining a + connectivity pattern that is compatible with said convolution. + + Args: + filters: int, the dimension of the output space (the number of filters + in the transpose convolution). + kernel_size: int or tuple/list of 1 integer, specifying the size of the + transposed convolution window. + strides: int or tuple/list of 1 integer, specifying the stride length + of the transposed convolution. `strides > 1` is incompatible with + `dilation_rate > 1`. + padding: string, either `"valid"` or `"same"` (case-insensitive). + `"valid"` means no padding. `"same"` results in padding evenly to + the left/right or up/down of the input such that output has the same + height/width dimension as the input. + output_padding: An integer tuple/list of 1 integer specifying the + amount of padding along the time dimension of the output tensor. + The amount of output padding must be lower than the stride. + If set to `None` (default), the output shape is inferred. + data_format: string, either `"channels_last"` or `"channels_first"`. + The ordering of the dimensions in the inputs. `"channels_last"` + corresponds to inputs with shape `(batch, steps, features)` + while `"channels_first"` corresponds to inputs with shape + `(batch, features, steps)`. It defaults to the `image_data_format` + value found in your Keras config file at `~/.keras/keras.json`. + If you never set it, then it will be `"channels_last"`. + dilation_rate: An integer tuple/list of 1 integer, specifying + the dilation rate to use for dilated convolution. + Currently, specifying a `dilation_rate` value != 1 is + incompatible with specifying a stride value != 1. + Also dilation rate larger than 1 is not currently supported. + activation: Activation function. If `None`, no activation is applied. + use_bias: bool, if `True`, bias will be added to the output. + kernel_initializer: Initializer for the convolution kernel. If `None`, + the default initializer (`"glorot_uniform"`) will be used. + bias_initializer: Initializer for the bias vector. If `None`, the + default initializer (`"zeros"`) will be used. + kernel_regularizer: Optional regularizer for the convolution kernel. + bias_regularizer: Optional regularizer for the bias vector. + activity_regularizer: Optional regularizer function for the output. + kernel_constraint: Optional projection function to be applied to the + kernel after being updated by an `Optimizer` (e.g. used to implement + norm constraints or value constraints for layer weights). The + function must take as input the unprojected variable and must return + the projected variable (which must have the same shape). Constraints + are not safe to use when doing asynchronous distributed training. + bias_constraint: Optional projection function to be applied to the + bias after being updated by an `Optimizer`. + + Input shape: + + - If `data_format="channels_last"`: + A 3D tensor with shape: `(batch_shape, steps, channels)` + - If `data_format="channels_first"`: + A 3D tensor with shape: `(batch_shape, channels, steps)` + + Output shape: + + - If `data_format="channels_last"`: + A 3D tensor with shape: `(batch_shape, new_steps, filters)` + - If `data_format="channels_first"`: + A 3D tensor with shape: `(batch_shape, filters, new_steps)` + + Returns: + A 3D tensor representing + `activation(conv1d_transpose(inputs, kernel) + bias)`. + + Raises: + ValueError: when both `strides > 1` and `dilation_rate > 1`. + + References: + - [A guide to convolution arithmetic for deep learning]( + https://arxiv.org/abs/1603.07285v1) + - [Deconvolutional Networks]( + https://www.matthewzeiler.com/mattzeiler/deconvolutionalnetworks.pdf) + + Example: + + >>> x = np.random.rand(4, 10, 128) + >>> y = keras.layers.Conv1DTranspose(32, 3, 2, activation='relu')(x) + >>> print(y.shape) + (4, 21, 32) + """ + + def __init__( + self, + filters, + kernel_size, + strides=1, + padding="valid", + output_padding=None, + data_format=None, + dilation_rate=1, + activation=None, + use_bias=True, + kernel_initializer="glorot_uniform", + bias_initializer="zeros", + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + bias_constraint=None, + **kwargs, + ): + super().__init__( + rank=1, + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + output_padding=output_padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + use_bias=use_bias, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + kernel_constraint=kernel_constraint, + bias_constraint=bias_constraint, + **kwargs, + ) diff --git a/keras/src/layers/convolutional/conv2d.py b/keras/src/layers/convolutional/conv2d.py new file mode 100644 index 000000000000..577ff664e841 --- /dev/null +++ b/keras/src/layers/convolutional/conv2d.py @@ -0,0 +1,137 @@ +from keras.src.api_export import keras_export +from keras.src.layers.convolutional.base_conv import BaseConv + + +@keras_export(["keras.layers.Conv2D", "keras.layers.Convolution2D"]) +class Conv2D(BaseConv): + """2D convolution layer. + + This layer creates a convolution kernel that is convolved with the layer + input over a 2D spatial (or temporal) dimension (height and width) to + produce a tensor of outputs. If `use_bias` is True, a bias vector is created + and added to the outputs. Finally, if `activation` is not `None`, it is + applied to the outputs as well. + + Note on numerical precision: While in general Keras operation execution + results are identical across backends up to 1e-7 precision in float32, + `Conv2D` operations may show larger variations. Due to the large + number of element-wise multiplications and additions in convolution + operations, especially with large inputs or kernel sizes, accumulated + floating-point differences can exceed this 1e-7 threshold. These variations + are particularly noticeable when using different backends (e.g., TensorFlow + vs JAX) or different hardware. + + Args: + filters: int, the dimension of the output space (the number of filters + in the convolution). + kernel_size: int or tuple/list of 2 integer, specifying the size of the + convolution window. + strides: int or tuple/list of 2 integer, specifying the stride length + of the convolution. `strides > 1` is incompatible with + `dilation_rate > 1`. + padding: string, either `"valid"` or `"same"` (case-insensitive). + `"valid"` means no padding. `"same"` results in padding evenly to + the left/right or up/down of the input. When `padding="same"` and + `strides=1`, the output has the same size as the input. + data_format: string, either `"channels_last"` or `"channels_first"`. + The ordering of the dimensions in the inputs. `"channels_last"` + corresponds to inputs with shape + `(batch_size, height, width, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch_size, channels, height, width)`. It defaults to the + `image_data_format` value found in your Keras config file at + `~/.keras/keras.json`. If you never set it, then it will be + `"channels_last"`. + dilation_rate: int or tuple/list of 2 integers, specifying the dilation + rate to use for dilated convolution. + groups: A positive int specifying the number of groups in which the + input is split along the channel axis. Each group is convolved + separately with `filters // groups` filters. The output is the + concatenation of all the `groups` results along the channel axis. + Input channels and `filters` must both be divisible by `groups`. + activation: Activation function. If `None`, no activation is applied. + use_bias: bool, if `True`, bias will be added to the output. + kernel_initializer: Initializer for the convolution kernel. If `None`, + the default initializer (`"glorot_uniform"`) will be used. + bias_initializer: Initializer for the bias vector. If `None`, the + default initializer (`"zeros"`) will be used. + kernel_regularizer: Optional regularizer for the convolution kernel. + bias_regularizer: Optional regularizer for the bias vector. + activity_regularizer: Optional regularizer function for the output. + kernel_constraint: Optional projection function to be applied to the + kernel after being updated by an `Optimizer` (e.g. used to implement + norm constraints or value constraints for layer weights). The + function must take as input the unprojected variable and must return + the projected variable (which must have the same shape). Constraints + are not safe to use when doing asynchronous distributed training. + bias_constraint: Optional projection function to be applied to the + bias after being updated by an `Optimizer`. + + Input shape: + + - If `data_format="channels_last"`: + A 4D tensor with shape: `(batch_size, height, width, channels)` + - If `data_format="channels_first"`: + A 4D tensor with shape: `(batch_size, channels, height, width)` + + Output shape: + + - If `data_format="channels_last"`: + A 4D tensor with shape: `(batch_size, new_height, new_width, filters)` + - If `data_format="channels_first"`: + A 4D tensor with shape: `(batch_size, filters, new_height, new_width)` + + Returns: + A 4D tensor representing `activation(conv2d(inputs, kernel) + bias)`. + + Raises: + ValueError: when both `strides > 1` and `dilation_rate > 1`. + + Example: + + >>> x = np.random.rand(4, 10, 10, 128) + >>> y = keras.layers.Conv2D(32, 3, activation='relu')(x) + >>> print(y.shape) + (4, 8, 8, 32) + """ + + def __init__( + self, + filters, + kernel_size, + strides=(1, 1), + padding="valid", + data_format=None, + dilation_rate=(1, 1), + groups=1, + activation=None, + use_bias=True, + kernel_initializer="glorot_uniform", + bias_initializer="zeros", + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + bias_constraint=None, + **kwargs, + ): + super().__init__( + rank=2, + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + groups=groups, + activation=activation, + use_bias=use_bias, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + kernel_constraint=kernel_constraint, + bias_constraint=bias_constraint, + **kwargs, + ) diff --git a/keras/src/layers/convolutional/conv2d_transpose.py b/keras/src/layers/convolutional/conv2d_transpose.py new file mode 100644 index 000000000000..33e0f9c607be --- /dev/null +++ b/keras/src/layers/convolutional/conv2d_transpose.py @@ -0,0 +1,148 @@ +from keras.src.api_export import keras_export +from keras.src.layers.convolutional.base_conv_transpose import BaseConvTranspose + + +@keras_export( + [ + "keras.layers.Conv2DTranspose", + "keras.layers.Convolution2DTranspose", + ] +) +class Conv2DTranspose(BaseConvTranspose): + """2D transposed convolution layer. + + The need for transposed convolutions generally arise from the desire to use + a transformation going in the opposite direction of a normal convolution, + i.e., from something that has the shape of the output of some convolution + to something that has the shape of its input while maintaining a + connectivity pattern that is compatible with said convolution. + + Args: + filters: int, the dimension of the output space (the number of filters + in the transposed convolution). + kernel_size: int or tuple/list of 1 integer, specifying the size of the + transposed convolution window. + strides: int or tuple/list of 1 integer, specifying the stride length + of the transposed convolution. `strides > 1` is incompatible with + `dilation_rate > 1`. + padding: string, either `"valid"` or `"same"` (case-insensitive). + `"valid"` means no padding. `"same"` results in padding evenly to + the left/right or up/down of the input. When `padding="same"` and + `strides=1`, the output has the same size as the input. + output_padding: An integer or tuple/list of 2 integers, + specifying the amount of padding along the height and width + of the output tensor. + Can be a single integer to specify the same value for all + spatial dimensions. + The amount of output padding along a given dimension must be + lower than the stride along that same dimension. + If set to `None` (default), the output shape is inferred. + data_format: string, either `"channels_last"` or `"channels_first"`. + The ordering of the dimensions in the inputs. `"channels_last"` + corresponds to inputs with shape + `(batch_size, height, width, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch_size, channels, height, width)`. It defaults to the + `image_data_format` value found in your Keras config file at + `~/.keras/keras.json`. If you never set it, then it will be + `"channels_last"`. + dilation_rate: An integer or tuple/list of 2 integers, + specifying the dilation rate for + all spatial dimensions for dilated convolution. + Specifying different dilation rates + for different dimensions is not supported. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any stride value != 1. + activation: Activation function. If `None`, no activation is applied. + use_bias: bool, if `True`, bias will be added to the output. + kernel_initializer: Initializer for the convolution kernel. If `None`, + the default initializer (`"glorot_uniform"`) will be used. + bias_initializer: Initializer for the bias vector. If `None`, the + default initializer (`"zeros"`) will be used. + kernel_regularizer: Optional regularizer for the convolution kernel. + bias_regularizer: Optional regularizer for the bias vector. + activity_regularizer: Optional regularizer function for the output. + kernel_constraint: Optional projection function to be applied to the + kernel after being updated by an `Optimizer` (e.g. used to implement + norm constraints or value constraints for layer weights). The + function must take as input the unprojected variable and must return + the projected variable (which must have the same shape). Constraints + are not safe to use when doing asynchronous distributed training. + bias_constraint: Optional projection function to be applied to the + bias after being updated by an `Optimizer`. + + Input shape: + + - If `data_format="channels_last"`: + A 4D tensor with shape: `(batch_size, height, width, channels)` + - If `data_format="channels_first"`: + A 4D tensor with shape: `(batch_size, channels, height, width)` + + Output shape: + + - If `data_format="channels_last"`: + A 4D tensor with shape: `(batch_size, new_height, new_width, filters)` + - If `data_format="channels_first"`: + A 4D tensor with shape: `(batch_size, filters, new_height, new_width)` + + Returns: + A 4D tensor representing + `activation(conv2d_transpose(inputs, kernel) + bias)`. + + Raises: + ValueError: when both `strides > 1` and `dilation_rate > 1`. + + References: + - [A guide to convolution arithmetic for deep learning]( + https://arxiv.org/abs/1603.07285v1) + - [Deconvolutional Networks]( + https://www.matthewzeiler.com/mattzeiler/deconvolutionalnetworks.pdf) + + Example: + + >>> x = np.random.rand(4, 10, 8, 128) + >>> y = keras.layers.Conv2DTranspose(32, 2, 2, activation='relu')(x) + >>> print(y.shape) + (4, 20, 16, 32) + """ + + def __init__( + self, + filters, + kernel_size, + strides=(1, 1), + padding="valid", + output_padding=None, + data_format=None, + dilation_rate=(1, 1), + activation=None, + use_bias=True, + kernel_initializer="glorot_uniform", + bias_initializer="zeros", + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + bias_constraint=None, + **kwargs, + ): + super().__init__( + rank=2, + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + output_padding=output_padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + use_bias=use_bias, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + kernel_constraint=kernel_constraint, + bias_constraint=bias_constraint, + **kwargs, + ) diff --git a/keras/src/layers/convolutional/conv3d.py b/keras/src/layers/convolutional/conv3d.py new file mode 100644 index 000000000000..4badd2042c37 --- /dev/null +++ b/keras/src/layers/convolutional/conv3d.py @@ -0,0 +1,134 @@ +from keras.src.api_export import keras_export +from keras.src.layers.convolutional.base_conv import BaseConv + + +@keras_export(["keras.layers.Conv3D", "keras.layers.Convolution3D"]) +class Conv3D(BaseConv): + """3D convolution layer. + + This layer creates a convolution kernel that is convolved with the layer + input over a 3D spatial (or temporal) dimension (width,height and depth) to + produce a tensor of outputs. If `use_bias` is True, a bias vector is created + and added to the outputs. Finally, if `activation` is not `None`, it is + applied to the outputs as well. + + Args: + filters: int, the dimension of the output space (the number of filters + in the convolution). + kernel_size: int or tuple/list of 3 integer, specifying the size of the + convolution window. + strides: int or tuple/list of 3 integer, specifying the stride length + of the convolution. `strides > 1` is incompatible with + `dilation_rate > 1`. + padding: string, either `"valid"` or `"same"` (case-insensitive). + `"valid"` means no padding. `"same"` results in padding evenly to + the left/right or up/down of the input. When `padding="same"` and + `strides=1`, the output has the same size as the input. + data_format: string, either `"channels_last"` or `"channels_first"`. + The ordering of the dimensions in the inputs. `"channels_last"` + corresponds to inputs with shape + `(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)`. + It defaults to the `image_data_format` value found in your Keras + config file at `~/.keras/keras.json`. If you never set it, then it + will be `"channels_last"`. + dilation_rate: int or tuple/list of 3 integers, specifying the dilation + rate to use for dilated convolution. + groups: A positive int specifying the number of groups in which the + input is split along the channel axis. Each group is convolved + separately with `filters // groups` filters. The output is the + concatenation of all the `groups` results along the channel axis. + Input channels and `filters` must both be divisible by `groups`. + activation: Activation function. If `None`, no activation is applied. + use_bias: bool, if `True`, bias will be added to the output. + kernel_initializer: Initializer for the convolution kernel. If `None`, + the default initializer (`"glorot_uniform"`) will be used. + bias_initializer: Initializer for the bias vector. If `None`, the + default initializer (`"zeros"`) will be used. + kernel_regularizer: Optional regularizer for the convolution kernel. + bias_regularizer: Optional regularizer for the bias vector. + activity_regularizer: Optional regularizer function for the output. + kernel_constraint: Optional projection function to be applied to the + kernel after being updated by an `Optimizer` (e.g. used to implement + norm constraints or value constraints for layer weights). The + function must take as input the unprojected variable and must return + the projected variable (which must have the same shape). Constraints + are not safe to use when doing asynchronous distributed training. + bias_constraint: Optional projection function to be applied to the + bias after being updated by an `Optimizer`. + + Input shape: + + - If `data_format="channels_last"`: + 5D tensor with shape: + `(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)` + - If `data_format="channels_first"`: + 5D tensor with shape: + `(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)` + + Output shape: + + - If `data_format="channels_last"`: + 5D tensor with shape: + `(batch_size, new_spatial_dim1, new_spatial_dim2, new_spatial_dim3, + filters)` + - If `data_format="channels_first"`: + 5D tensor with shape: + `(batch_size, filters, new_spatial_dim1, new_spatial_dim2, + new_spatial_dim3)` + + Returns: + A 5D tensor representing `activation(conv3d(inputs, kernel) + bias)`. + + Raises: + ValueError: when both `strides > 1` and `dilation_rate > 1`. + + Example: + + >>> x = np.random.rand(4, 10, 10, 10, 128) + >>> y = keras.layers.Conv3D(32, 3, activation='relu')(x) + >>> print(y.shape) + (4, 8, 8, 8, 32) + """ + + def __init__( + self, + filters, + kernel_size, + strides=(1, 1, 1), + padding="valid", + data_format=None, + dilation_rate=(1, 1, 1), + groups=1, + activation=None, + use_bias=True, + kernel_initializer="glorot_uniform", + bias_initializer="zeros", + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + bias_constraint=None, + **kwargs, + ): + super().__init__( + rank=3, + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + groups=groups, + activation=activation, + use_bias=use_bias, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + kernel_constraint=kernel_constraint, + bias_constraint=bias_constraint, + **kwargs, + ) diff --git a/keras/src/layers/convolutional/conv3d_transpose.py b/keras/src/layers/convolutional/conv3d_transpose.py new file mode 100644 index 000000000000..a46696563aa1 --- /dev/null +++ b/keras/src/layers/convolutional/conv3d_transpose.py @@ -0,0 +1,152 @@ +from keras.src.api_export import keras_export +from keras.src.layers.convolutional.base_conv_transpose import BaseConvTranspose + + +@keras_export( + [ + "keras.layers.Conv3DTranspose", + "keras.layers.Convolution3DTranspose", + ] +) +class Conv3DTranspose(BaseConvTranspose): + """3D transposed convolution layer. + + The need for transposed convolutions generally arise from the desire to use + a transformation going in the opposite direction of a normal convolution, + i.e., from something that has the shape of the output of some convolution + to something that has the shape of its input while maintaining a + connectivity pattern that is compatible with said convolution. + + Args: + filters: int, the dimension of the output space (the number of filters + in the transposed convolution). + kernel_size: int or tuple/list of 1 integer, specifying the size of the + transposed convolution window. + strides: int or tuple/list of 1 integer, specifying the stride length + of the transposed convolution. `strides > 1` is incompatible with + `dilation_rate > 1`. + padding: string, either `"valid"` or `"same"` (case-insensitive). + `"valid"` means no padding. `"same"` results in padding evenly to + the left/right or up/down of the input. When `padding="same"` and + `strides=1`, the output has the same size as the input. + output_padding: An integer or tuple/list of 3 integers, + specifying the amount of padding along the depth, height, and + width. + Can be a single integer to specify the same value for all + spatial dimensions. + The amount of output padding along a given dimension must be + lower than the stride along that same dimension. + If set to `None` (default), the output shape is inferred. + data_format: string, either `"channels_last"` or `"channels_first"`. + The ordering of the dimensions in the inputs. `"channels_last"` + corresponds to inputs with shape + `(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)`. + It defaults to the `image_data_format` value found in your Keras + config file at `~/.keras/keras.json`. If you never set it, then it + will be `"channels_last"`. + dilation_rate: an integer or tuple/list of 3 integers, specifying + the dilation rate to use for dilated convolution. + Can be a single integer to specify the same value for + all spatial dimensions. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any stride value != 1. + activation: Activation function. If `None`, no activation is applied. + use_bias: bool, if `True`, bias will be added to the output. + kernel_initializer: Initializer for the convolution kernel. If `None`, + the default initializer (`"glorot_uniform"`) will be used. + bias_initializer: Initializer for the bias vector. If `None`, the + default initializer (`"zeros"`) will be used. + kernel_regularizer: Optional regularizer for the convolution kernel. + bias_regularizer: Optional regularizer for the bias vector. + activity_regularizer: Optional regularizer function for the output. + kernel_constraint: Optional projection function to be applied to the + kernel after being updated by an `Optimizer` (e.g. used to implement + norm constraints or value constraints for layer weights). The + function must take as input the unprojected variable and must return + the projected variable (which must have the same shape). Constraints + are not safe to use when doing asynchronous distributed training. + bias_constraint: Optional projection function to be applied to the + bias after being updated by an `Optimizer`. + + Input shape: + + - If `data_format="channels_last"`: + 5D tensor with shape: + `(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)` + - If `data_format="channels_first"`: + 5D tensor with shape: + `(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)` + + Output shape: + + - If `data_format="channels_last"`: + 5D tensor with shape: + `(batch_size, new_spatial_dim1, new_spatial_dim2, new_spatial_dim3, + filters)` + - If `data_format="channels_first"`: + 5D tensor with shape: + `(batch_size, filters, new_spatial_dim1, new_spatial_dim2, + new_spatial_dim3)` + + Returns: + A 5D tensor representing `activation(conv3d(inputs, kernel) + bias)`. + + Raises: + ValueError: when both `strides > 1` and `dilation_rate > 1`. + + References: + - [A guide to convolution arithmetic for deep learning]( + https://arxiv.org/abs/1603.07285v1) + - [Deconvolutional Networks]( + https://www.matthewzeiler.com/mattzeiler/deconvolutionalnetworks.pdf) + + Example: + + >>> x = np.random.rand(4, 10, 8, 12, 128) + >>> y = keras.layers.Conv3DTranspose(32, 2, 2, activation='relu')(x) + >>> print(y.shape) + (4, 20, 16, 24, 32) + """ + + def __init__( + self, + filters, + kernel_size, + strides=(1, 1, 1), + padding="valid", + data_format=None, + output_padding=None, + dilation_rate=(1, 1, 1), + activation=None, + use_bias=True, + kernel_initializer="glorot_uniform", + bias_initializer="zeros", + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + bias_constraint=None, + **kwargs, + ): + super().__init__( + rank=3, + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + output_padding=output_padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + use_bias=use_bias, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + kernel_constraint=kernel_constraint, + bias_constraint=bias_constraint, + **kwargs, + ) diff --git a/keras/src/layers/convolutional/conv_test.py b/keras/src/layers/convolutional/conv_test.py new file mode 100644 index 000000000000..a734fa3b9cf2 --- /dev/null +++ b/keras/src/layers/convolutional/conv_test.py @@ -0,0 +1,1097 @@ +import os + +import numpy as np +import pytest +from absl.testing import parameterized +from numpy.lib.stride_tricks import as_strided + +from keras.src import backend +from keras.src import constraints +from keras.src import layers +from keras.src import models +from keras.src import ops +from keras.src import saving +from keras.src import testing + + +def _same_padding(input_size, kernel_size, stride): + if input_size % stride == 0: + padding = max(kernel_size - stride, 0) + else: + padding = max(kernel_size - (input_size % stride), 0) + return padding // 2, padding - padding // 2 + + +def np_conv1d( + x, + kernel_weights, + bias_weights, + strides, + padding, + data_format, + dilation_rate, + groups, +): + if data_format == "channels_first": + x = x.swapaxes(1, 2) + if isinstance(strides, (tuple, list)): + h_stride = strides[0] + else: + h_stride = strides + if isinstance(dilation_rate, (tuple, list)): + dilation_rate = dilation_rate[0] + kernel_size, ch_in, ch_out = kernel_weights.shape + + if dilation_rate > 1: + new_kernel_size = kernel_size + (dilation_rate - 1) * (kernel_size - 1) + new_kernel_weights = np.zeros( + (new_kernel_size, ch_in, ch_out), dtype=kernel_weights.dtype + ) + new_kernel_weights[::dilation_rate] = kernel_weights + kernel_weights = new_kernel_weights + kernel_size = kernel_weights.shape[0] + + if padding != "valid": + n_batch, h_x, _ = x.shape + h_pad = _same_padding(h_x, kernel_size, h_stride) + npad = [(0, 0)] * x.ndim + if padding == "causal": + npad[1] = (h_pad[0] + h_pad[1], 0) + else: + npad[1] = h_pad + x = np.pad(x, pad_width=npad, mode="constant", constant_values=0) + + n_batch, h_x, _ = x.shape + h_out = int((h_x - kernel_size) / h_stride) + 1 + + kernel_weights = kernel_weights.reshape(-1, ch_out) + bias_weights = bias_weights.reshape(1, ch_out) + + out_grps = [] + for grp in range(1, groups + 1): + x_in = x[..., (grp - 1) * ch_in : grp * ch_in] + stride_shape = (n_batch, h_out, kernel_size, ch_in) + strides = ( + x_in.strides[0], + h_stride * x_in.strides[1], + x_in.strides[1], + x_in.strides[2], + ) + inner_dim = kernel_size * ch_in + x_strided = as_strided( + x_in, shape=stride_shape, strides=strides + ).reshape(n_batch, h_out, inner_dim) + ch_out_groups = ch_out // groups + kernel_weights_grp = kernel_weights[ + ..., (grp - 1) * ch_out_groups : grp * ch_out_groups + ] + bias_weights_grp = bias_weights[ + ..., (grp - 1) * ch_out_groups : grp * ch_out_groups + ] + out_grps.append(x_strided @ kernel_weights_grp + bias_weights_grp) + out = np.concatenate(out_grps, axis=-1) + if data_format == "channels_first": + out = out.swapaxes(1, 2) + return out + + +def np_conv2d( + x, + kernel_weights, + bias_weights, + strides, + padding, + data_format, + dilation_rate, + groups, +): + if data_format == "channels_first": + x = x.transpose((0, 2, 3, 1)) + if isinstance(strides, (tuple, list)): + h_stride, w_stride = strides + else: + h_stride = strides + w_stride = strides + if isinstance(dilation_rate, (tuple, list)): + h_dilation, w_dilation = dilation_rate + else: + h_dilation = dilation_rate + w_dilation = dilation_rate + h_kernel, w_kernel, ch_in, ch_out = kernel_weights.shape + + if h_dilation > 1 or w_dilation > 1: + new_h_kernel = h_kernel + (h_dilation - 1) * (h_kernel - 1) + new_w_kernel = w_kernel + (w_dilation - 1) * (w_kernel - 1) + new_kenel_size_tuple = (new_h_kernel, new_w_kernel) + new_kernel_weights = np.zeros( + (*new_kenel_size_tuple, ch_in, ch_out), + dtype=kernel_weights.dtype, + ) + new_kernel_weights[::h_dilation, ::w_dilation] = kernel_weights + kernel_weights = new_kernel_weights + h_kernel, w_kernel = kernel_weights.shape[:2] + + if padding == "same": + n_batch, h_x, w_x, _ = x.shape + h_pad = _same_padding(h_x, h_kernel, h_stride) + w_pad = _same_padding(w_x, w_kernel, w_stride) + npad = [(0, 0)] * x.ndim + npad[1] = h_pad + npad[2] = w_pad + x = np.pad(x, pad_width=npad, mode="constant", constant_values=0) + + n_batch, h_x, w_x, _ = x.shape + h_out = int((h_x - h_kernel) / h_stride) + 1 + w_out = int((w_x - w_kernel) / w_stride) + 1 + + out_grps = [] + for grp in range(1, groups + 1): + x_in = x[..., (grp - 1) * ch_in : grp * ch_in] + stride_shape = (n_batch, h_out, w_out, h_kernel, w_kernel, ch_in) + strides = ( + x_in.strides[0], + h_stride * x_in.strides[1], + w_stride * x_in.strides[2], + x_in.strides[1], + x_in.strides[2], + x_in.strides[3], + ) + inner_dim = h_kernel * w_kernel * ch_in + x_strided = as_strided( + x_in, shape=stride_shape, strides=strides + ).reshape(-1, inner_dim) + ch_out_groups = ch_out // groups + kernel_weights_grp = kernel_weights[ + ..., (grp - 1) * ch_out_groups : grp * ch_out_groups + ].reshape(-1, ch_out_groups) + bias_weights_grp = bias_weights[ + ..., (grp - 1) * ch_out_groups : grp * ch_out_groups + ] + out_grps.append(x_strided @ kernel_weights_grp + bias_weights_grp) + out = np.concatenate(out_grps, axis=-1).reshape( + n_batch, h_out, w_out, ch_out + ) + if data_format == "channels_first": + out = out.transpose((0, 3, 1, 2)) + return out + + +def np_conv3d( + x, + kernel_weights, + bias_weights, + strides, + padding, + data_format, + dilation_rate, + groups, +): + if data_format == "channels_first": + x = x.transpose((0, 2, 3, 4, 1)) + if isinstance(strides, (tuple, list)): + h_stride, w_stride, d_stride = strides + else: + h_stride = strides + w_stride = strides + d_stride = strides + if isinstance(dilation_rate, (tuple, list)): + h_dilation, w_dilation, d_dilation = dilation_rate + else: + h_dilation = dilation_rate + w_dilation = dilation_rate + d_dilation = dilation_rate + + h_kernel, w_kernel, d_kernel, ch_in, ch_out = kernel_weights.shape + + if h_dilation > 1 or w_dilation > 1 or d_dilation > 1: + new_h_kernel = h_kernel + (h_dilation - 1) * (h_kernel - 1) + new_w_kernel = w_kernel + (w_dilation - 1) * (w_kernel - 1) + new_d_kernel = d_kernel + (d_dilation - 1) * (d_kernel - 1) + new_kenel_size_tuple = (new_h_kernel, new_w_kernel, new_d_kernel) + new_kernel_weights = np.zeros( + (*new_kenel_size_tuple, ch_in, ch_out), + dtype=kernel_weights.dtype, + ) + new_kernel_weights[::h_dilation, ::w_dilation, ::d_dilation] = ( + kernel_weights + ) + kernel_weights = new_kernel_weights + h_kernel, w_kernel, d_kernel = kernel_weights.shape[:3] + + if padding == "same": + n_batch, h_x, w_x, d_x, _ = x.shape + h_pad = _same_padding(h_x, h_kernel, h_stride) + w_pad = _same_padding(w_x, w_kernel, w_stride) + d_pad = _same_padding(d_x, d_kernel, d_stride) + npad = [(0, 0)] * x.ndim + npad[1] = h_pad + npad[2] = w_pad + npad[3] = d_pad + x = np.pad(x, pad_width=npad, mode="constant", constant_values=0) + + n_batch, h_x, w_x, d_x, _ = x.shape + h_out = int((h_x - h_kernel) / h_stride) + 1 + w_out = int((w_x - w_kernel) / w_stride) + 1 + d_out = int((d_x - d_kernel) / d_stride) + 1 + + out_grps = [] + for grp in range(1, groups + 1): + x_in = x[..., (grp - 1) * ch_in : grp * ch_in] + stride_shape = ( + n_batch, + h_out, + w_out, + d_out, + h_kernel, + w_kernel, + d_kernel, + ch_in, + ) + strides = ( + x_in.strides[0], + h_stride * x_in.strides[1], + w_stride * x_in.strides[2], + d_stride * x_in.strides[3], + x_in.strides[1], + x_in.strides[2], + x_in.strides[3], + x_in.strides[4], + ) + inner_dim = h_kernel * w_kernel * d_kernel * ch_in + x_strided = as_strided( + x_in, shape=stride_shape, strides=strides + ).reshape(-1, inner_dim) + ch_out_groups = ch_out // groups + kernel_weights_grp = kernel_weights[ + ..., (grp - 1) * ch_out_groups : grp * ch_out_groups + ].reshape(-1, ch_out_groups) + bias_weights_grp = bias_weights[ + ..., (grp - 1) * ch_out_groups : grp * ch_out_groups + ] + out_grps.append(x_strided @ kernel_weights_grp + bias_weights_grp) + out = np.concatenate(out_grps, axis=-1).reshape( + n_batch, h_out, w_out, d_out, ch_out + ) + if data_format == "channels_first": + out = out.transpose((0, 4, 1, 2, 3)) + return out + + +class ConvBasicTest(testing.TestCase): + @parameterized.parameters( + { + "filters": 5, + "kernel_size": 2, + "strides": 1, + "padding": "valid", + "data_format": "channels_last", + "dilation_rate": 1, + "groups": 1, + "input_shape": (3, 5, 4), + "output_shape": (3, 4, 5), + }, + { + "filters": 6, + "kernel_size": 2, + "strides": 1, + "padding": "same", + "data_format": "channels_last", + "dilation_rate": (2,), + "groups": 2, + "input_shape": (3, 4, 4), + "output_shape": (3, 4, 6), + }, + { + "filters": 6, + "kernel_size": 2, + "strides": 1, + "padding": "causal", + "data_format": "channels_last", + "dilation_rate": (2,), + "groups": 2, + "input_shape": (3, 4, 4), + "output_shape": (3, 4, 6), + }, + { + "filters": 6, + "kernel_size": 2, + "strides": (2,), + "padding": "valid", + "data_format": "channels_last", + "dilation_rate": 1, + "groups": 2, + "input_shape": (3, 5, 4), + "output_shape": (3, 2, 6), + }, + ) + @pytest.mark.requires_trainable_backend + def test_conv1d_basic( + self, + filters, + kernel_size, + strides, + padding, + data_format, + dilation_rate, + groups, + input_shape, + output_shape, + ): + self.run_layer_test( + layers.Conv1D, + init_kwargs={ + "filters": filters, + "kernel_size": kernel_size, + "strides": strides, + "padding": padding, + "data_format": data_format, + "dilation_rate": dilation_rate, + "groups": groups, + }, + input_shape=input_shape, + expected_output_shape=output_shape, + expected_num_trainable_weights=2, + expected_num_non_trainable_weights=0, + expected_num_losses=0, + supports_masking=False, + ) + + @parameterized.parameters( + { + "filters": 5, + "kernel_size": 2, + "strides": 1, + "padding": "valid", + "data_format": "channels_last", + "dilation_rate": 1, + "groups": 1, + "input_shape": (3, 5, 5, 4), + "output_shape": (3, 4, 4, 5), + }, + { + "filters": 6, + "kernel_size": 2, + "strides": 1, + "padding": "same", + "data_format": "channels_last", + "dilation_rate": (2, 2), + "groups": 2, + "input_shape": (3, 4, 4, 4), + "output_shape": (3, 4, 4, 6), + }, + { + "filters": 6, + "kernel_size": (2, 2), + "strides": (2, 1), + "padding": "valid", + "data_format": "channels_last", + "dilation_rate": (1, 1), + "groups": 2, + "input_shape": (3, 5, 5, 4), + "output_shape": (3, 2, 4, 6), + }, + ) + @pytest.mark.requires_trainable_backend + def test_conv2d_basic( + self, + filters, + kernel_size, + strides, + padding, + data_format, + dilation_rate, + groups, + input_shape, + output_shape, + ): + self.run_layer_test( + layers.Conv2D, + init_kwargs={ + "filters": filters, + "kernel_size": kernel_size, + "strides": strides, + "padding": padding, + "data_format": data_format, + "dilation_rate": dilation_rate, + "groups": groups, + }, + input_shape=input_shape, + expected_output_shape=output_shape, + expected_num_trainable_weights=2, + expected_num_non_trainable_weights=0, + expected_num_losses=0, + supports_masking=False, + ) + + @parameterized.parameters( + { + "filters": 5, + "kernel_size": 2, + "strides": 1, + "padding": "valid", + "data_format": "channels_last", + "dilation_rate": 1, + "groups": 1, + "input_shape": (3, 5, 5, 5, 4), + "output_shape": (3, 4, 4, 4, 5), + }, + { + "filters": 6, + "kernel_size": 2, + "strides": 1, + "padding": "same", + "data_format": "channels_last", + "dilation_rate": (2, 2, 2), + "groups": 2, + "input_shape": (3, 4, 4, 4, 4), + "output_shape": (3, 4, 4, 4, 6), + }, + { + "filters": 6, + "kernel_size": (2, 2, 3), + "strides": (2, 1, 2), + "padding": "valid", + "data_format": "channels_last", + "dilation_rate": (1, 1, 1), + "groups": 2, + "input_shape": (3, 5, 5, 5, 4), + "output_shape": (3, 2, 4, 2, 6), + }, + ) + @pytest.mark.requires_trainable_backend + def test_conv3d_basic( + self, + filters, + kernel_size, + strides, + padding, + data_format, + dilation_rate, + groups, + input_shape, + output_shape, + ): + self.run_layer_test( + layers.Conv3D, + init_kwargs={ + "filters": filters, + "kernel_size": kernel_size, + "strides": strides, + "padding": padding, + "data_format": data_format, + "dilation_rate": dilation_rate, + "groups": groups, + }, + input_shape=input_shape, + expected_output_shape=output_shape, + expected_num_trainable_weights=2, + expected_num_non_trainable_weights=0, + expected_num_losses=0, + supports_masking=False, + ) + + def test_bad_init_args(self): + # `filters` is not positive. + with self.assertRaisesRegex( + ValueError, + "Invalid value for argument `filters`. Expected a " + "strictly positive value. Received filters=0.", + ): + layers.Conv1D(filters=0, kernel_size=1) + + # `kernel_size` has 0. + with self.assertRaisesRegex( + ValueError, + r"The `kernel_size` argument must be a tuple of \d+ " + r"integers. Received kernel_size=\(1, 0\), including values \{0\} " + r"that do not satisfy `value > 0`", + ): + layers.Conv2D(filters=2, kernel_size=(1, 0)) + + # `strides` has 0. + with self.assertRaisesRegex( + ValueError, + r"The `strides` argument must be a tuple of \d+ " + r"integers. Received strides=\(1, 0\), including values \{0\} that " + r"do not satisfy `value > 0`", + ): + layers.Conv2D(filters=2, kernel_size=(2, 2), strides=(1, 0)) + + # `dilation_rate > 1` while `strides > 1`. + with self.assertRaisesRegex( + ValueError, + r"`strides > 1` not supported in conjunction with " + r"`dilation_rate > 1`. Received: strides=\(2, 2\) and " + r"dilation_rate=\(2, 1\)", + ): + layers.Conv2D( + filters=2, kernel_size=(2, 2), strides=2, dilation_rate=(2, 1) + ) + + # `groups` is not strictly positive. + with self.assertRaisesRegex( + ValueError, + "The number of groups must be a positive integer. " + "Received: groups=0.", + ): + layers.Conv2D(filters=5, kernel_size=(2, 2), groups=0) + + # `filters` cannot be divided by `groups`. + with self.assertRaisesRegex( + ValueError, + "The number of filters must be evenly divisible by the" + " number of groups. Received: groups=2, filters=5.", + ): + layers.Conv2D(filters=5, kernel_size=(2, 2), groups=2) + + @parameterized.named_parameters( + { + "testcase_name": "conv1d_kernel_size3_strides1", + "conv_cls": layers.Conv1D, + "filters": 6, + "kernel_size": 3, + "strides": 1, + "padding": "valid", + "data_format": "channels_last", + "dilation_rate": 1, + "groups": 1, + "input_shape": (None, 5, 4), + "output_shape": (None, 3, 6), + }, + { + "testcase_name": "conv1d_kernel_size2_strides2", + "conv_cls": layers.Conv1D, + "filters": 6, + "kernel_size": 2, + "strides": 2, + "padding": "valid", + "data_format": "channels_last", + "dilation_rate": 1, + "groups": 2, + "input_shape": (None, 5, 4), + "output_shape": (None, 2, 6), + }, + { + "testcase_name": "conv2d_kernel_size3_strides1", + "conv_cls": layers.Conv2D, + "filters": 6, + "kernel_size": 3, + "strides": 1, + "padding": "valid", + "data_format": "channels_last", + "dilation_rate": 1, + "groups": 1, + "input_shape": (None, 5, 5, 4), + "output_shape": (None, 3, 3, 6), + }, + { + "testcase_name": "conv2d_kernel_size2_strides2", + "conv_cls": layers.Conv2D, + "filters": 6, + "kernel_size": 2, + "strides": 2, + "padding": "valid", + "data_format": "channels_last", + "dilation_rate": 1, + "groups": 2, + "input_shape": (None, 5, 5, 4), + "output_shape": (None, 2, 2, 6), + }, + { + "testcase_name": "conv3d_kernel_size3_strides1", + "conv_cls": layers.Conv3D, + "filters": 6, + "kernel_size": 3, + "strides": 1, + "padding": "valid", + "data_format": "channels_last", + "dilation_rate": 1, + "groups": 1, + "input_shape": (None, 5, 5, 5, 4), + "output_shape": (None, 3, 3, 3, 6), + }, + { + "testcase_name": "conv3d_kernel_size2_strides2", + "conv_cls": layers.Conv3D, + "filters": 6, + "kernel_size": 2, + "strides": 2, + "padding": "valid", + "data_format": "channels_last", + "dilation_rate": 1, + "groups": 2, + "input_shape": (None, 5, 5, 5, 4), + "output_shape": (None, 2, 2, 2, 6), + }, + ) + @pytest.mark.requires_trainable_backend + def test_enable_lora( + self, + conv_cls, + filters, + kernel_size, + strides, + padding, + data_format, + dilation_rate, + groups, + input_shape, + output_shape, + ): + if conv_cls not in (layers.Conv1D, layers.Conv2D, layers.Conv3D): + raise TypeError + layer = conv_cls( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + groups=groups, + ) + layer.build(input_shape) + layer.enable_lora(2) + self.assertLen(layer.trainable_weights, 3) + self.assertLen(layer.non_trainable_weights, 1) + if backend.backend() == "torch": + self.assertLen(layer.torch_params, 4) + # Try eager call + x = np.random.random((64,) + input_shape[1:]) + y = np.random.random((64,) + output_shape[1:]) + _ = layer(x[:2]) + + init_lora_a_kernel_value = layer.lora_kernel_a.numpy() + init_lora_b_kernel_value = layer.lora_kernel_b.numpy() + + # Try calling fit() + model = models.Sequential([layer]) + model.compile(optimizer="sgd", loss="mse") + model.fit(x, y) + + final_lora_a_kernel_value = layer.lora_kernel_a.numpy() + final_lora_b_kernel_value = layer.lora_kernel_b.numpy() + diff_a = np.max( + np.abs(init_lora_a_kernel_value - final_lora_a_kernel_value) + ) + diff_b = np.max( + np.abs(init_lora_b_kernel_value - final_lora_b_kernel_value) + ) + self.assertGreater(diff_a, 0.0) + self.assertGreater(diff_b, 0.0) + + # Try saving and reloading the model + temp_filepath = os.path.join(self.get_temp_dir(), "lora_model.keras") + model.save(temp_filepath) + + new_model = saving.load_model(temp_filepath) + self.assertTrue(new_model.layers[0].lora_enabled) + self.assertAllClose(model.predict(x), new_model.predict(x)) + + # Try saving and reloading the model's weights only + temp_filepath = os.path.join( + self.get_temp_dir(), "lora_model.weights.h5" + ) + model.save_weights(temp_filepath) + + # Load the file into a fresh, non-lora model + new_model = models.Sequential( + [ + conv_cls( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + groups=groups, + ) + ] + ) + new_model.build(input_shape) + new_model.load_weights(temp_filepath) + self.assertAllClose(model.predict(x), new_model.predict(x)) + + # Try loading a normal checkpoint into a lora model + new_model.save_weights(temp_filepath) + model.load_weights(temp_filepath) + self.assertAllClose(model.predict(x), new_model.predict(x)) + + @pytest.mark.requires_trainable_backend + def test_lora_weight_name(self): + class MyModel(models.Model): + def __init__(self): + super().__init__(name="mymodel") + self.conv2d = layers.Conv2D(4, 3, name="conv2d") + + def build(self, input_shape): + self.conv2d.build(input_shape) + + def call(self, x): + return self.conv2d(x) + + model = MyModel() + model.build((None, 5, 5, 4)) + model.conv2d.enable_lora(2) + self.assertEqual( + model.conv2d.lora_kernel_a.path, "mymodel/conv2d/lora_kernel_a" + ) + + @pytest.mark.requires_trainable_backend + def test_enable_lora_with_alpha(self): + # Create a `Conv2D` layer with a small kernel for simplicity. + layer = layers.Conv2D(filters=3, kernel_size=(2, 2), padding="valid") + # Use a fixed input shape: batch size 1, height=4, width=4, channels=3. + input_shape = (1, 4, 4, 3) + layer.build(input_shape) + + # Set the base kernel to known, deterministic values. + base_kernel = np.linspace( + 0, 1, num=np.prod(layer.kernel.shape), dtype=np.float32 + ) + base_kernel = base_kernel.reshape(layer.kernel.shape) + layer.kernel.assign(base_kernel) + + # Enable LoRA with `rank`=2 and a custom `lora_alpha` value (e.g. 3.0). + layer.enable_lora(rank=2, lora_alpha=3.0) + self.assertEqual(layer.lora_rank, 2) + self.assertEqual(layer.lora_alpha, 3.0) + + # For `Conv2D`, assume the LoRA weights have shapes: + # `lora_kernel_a`: (kernel_height, kernel_width, in_channels, rank) + # `lora_kernel_b`: (rank, out_channels) + lora_a_shape = layer.lora_kernel_a.shape + lora_b_shape = layer.lora_kernel_b.shape + + # Assign known constant values to LoRA weights. + lora_a = np.full(lora_a_shape, 0.1, dtype=np.float32) + lora_b = np.full(lora_b_shape, 0.2, dtype=np.float32) + layer.lora_kernel_a.assign(lora_a) + layer.lora_kernel_b.assign(lora_b) + + # Compute the expected delta. + # Flatten `lora_kernel_a` to shape (-1, `rank`), + # multiply with `lora_kernel_b`, + # then reshape to the kernel's shape. + scaling = 3.0 / 2 # `lora_alpha / lora_rank` + delta = np.matmul(lora_a.reshape(-1, 2), lora_b) + delta = delta.reshape(base_kernel.shape) + expected_effective_kernel = base_kernel + scaling * delta + + # Compare the effective kernel computed via the property. + actual_effective_kernel = ops.convert_to_numpy(layer.kernel) + self.assertAllClose(actual_effective_kernel, expected_effective_kernel) + + @pytest.mark.requires_trainable_backend + def test_lora_rank_argument(self): + self.run_layer_test( + layers.Conv2D, + init_kwargs={ + "filters": 5, + "kernel_size": 3, + "activation": "sigmoid", + "data_format": "channels_last", + "kernel_regularizer": "l2", + "lora_rank": 2, + }, + input_shape=(2, 5, 5, 4), + expected_output_shape=(2, 3, 3, 5), + expected_num_trainable_weights=3, + expected_num_non_trainable_weights=1, + expected_num_seed_generators=0, + expected_num_losses=2, # we have 2 regularizers. + supports_masking=False, + ) + + +class ConvCorrectnessTest(testing.TestCase): + @parameterized.parameters( + { + "filters": 5, + "kernel_size": 2, + "strides": 1, + "padding": "valid", + "data_format": "channels_last", + "dilation_rate": 1, + "groups": 1, + }, + { + "filters": 6, + "kernel_size": 2, + "strides": 1, + "padding": "same", + "data_format": "channels_last", + "dilation_rate": (2,), + "groups": 2, + }, + { + "filters": 6, + "kernel_size": 2, + "strides": 1, + "padding": "causal", + "data_format": "channels_last", + "dilation_rate": (2,), + "groups": 2, + }, + { + "filters": 6, + "kernel_size": (2,), + "strides": (2,), + "padding": "valid", + "data_format": "channels_last", + "dilation_rate": 1, + "groups": 2, + }, + { + "filters": 6, + "kernel_size": (2,), + "strides": (2,), + "padding": "valid", + "data_format": "channels_first", + "dilation_rate": 1, + "groups": 2, + }, + ) + def test_conv1d( + self, + filters, + kernel_size, + strides, + padding, + data_format, + dilation_rate, + groups, + ): + layer = layers.Conv1D( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + groups=groups, + ) + + inputs = np.random.normal(size=[2, 8, 4]) + layer.build(input_shape=inputs.shape) + + kernel_shape = layer.kernel.shape + kernel_weights = np.random.normal(size=kernel_shape) + bias_weights = np.random.normal(size=(filters,)) + layer.kernel.assign(kernel_weights) + layer.bias.assign(bias_weights) + + outputs = layer(inputs) + expected = np_conv1d( + inputs, + kernel_weights, + bias_weights, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + groups=groups, + ) + self.assertAllClose(outputs, expected) + + @parameterized.parameters( + { + "filters": 5, + "kernel_size": 2, + "strides": 1, + "padding": "valid", + "data_format": "channels_last", + "dilation_rate": 1, + "groups": 1, + }, + { + "filters": 4, + "kernel_size": 3, + "strides": 2, + "padding": "same", + "data_format": "channels_last", + "dilation_rate": 1, + "groups": 1, + }, + { + "filters": 6, + "kernel_size": 2, + "strides": 1, + "padding": "same", + "data_format": "channels_last", + "dilation_rate": (2, 2), + "groups": 2, + }, + { + "filters": 6, + "kernel_size": 2, + "strides": 1, + "padding": "same", + "data_format": "channels_last", + "dilation_rate": (2, 3), + "groups": 2, + }, + { + "filters": 6, + "kernel_size": (4, 3), + "strides": (2, 1), + "padding": "valid", + "data_format": "channels_last", + "dilation_rate": (1, 1), + "groups": 2, + }, + { + "filters": 6, + "kernel_size": (4, 3), + "strides": (2, 1), + "padding": "valid", + "data_format": "channels_first", + "dilation_rate": (1, 1), + "groups": 2, + }, + ) + def test_conv2d( + self, + filters, + kernel_size, + strides, + padding, + data_format, + dilation_rate, + groups, + ): + layer = layers.Conv2D( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + groups=groups, + ) + + inputs = np.random.normal(size=[2, 8, 8, 4]) + layer.build(input_shape=inputs.shape) + + kernel_shape = layer.kernel.shape + kernel_weights = np.random.normal(size=kernel_shape) + bias_weights = np.random.normal(size=(filters,)) + layer.kernel.assign(kernel_weights) + layer.bias.assign(bias_weights) + + outputs = layer(inputs) + expected = np_conv2d( + inputs, + kernel_weights, + bias_weights, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + groups=groups, + ) + self.assertAllClose(outputs, expected, rtol=5e-4) + + @parameterized.parameters( + { + "filters": 5, + "kernel_size": 2, + "strides": 1, + "padding": "valid", + "data_format": "channels_last", + "dilation_rate": 1, + "groups": 1, + }, + { + "filters": 6, + "kernel_size": 2, + "strides": 1, + "padding": "same", + "data_format": "channels_last", + "dilation_rate": (2, 2, 2), + "groups": 2, + }, + { + "filters": 6, + "kernel_size": 2, + "strides": 1, + "padding": "same", + "data_format": "channels_last", + "dilation_rate": (2, 3, 4), + "groups": 2, + }, + { + "filters": 6, + "kernel_size": (2, 2, 3), + "strides": (2, 1, 2), + "padding": "valid", + "data_format": "channels_last", + "dilation_rate": (1, 1, 1), + "groups": 2, + }, + { + "filters": 6, + "kernel_size": (2, 2, 3), + "strides": (2, 1, 2), + "padding": "valid", + "data_format": "channels_first", + "dilation_rate": (1, 1, 1), + "groups": 2, + }, + ) + def test_conv3d( + self, + filters, + kernel_size, + strides, + padding, + data_format, + dilation_rate, + groups, + ): + layer = layers.Conv3D( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + groups=groups, + ) + + inputs = np.random.normal(size=[2, 8, 8, 8, 4]) + layer.build(input_shape=inputs.shape) + + kernel_shape = layer.kernel.shape + kernel_weights = np.random.normal(size=kernel_shape) + bias_weights = np.random.normal(size=(filters,)) + layer.kernel.assign(kernel_weights) + layer.bias.assign(bias_weights) + + outputs = layer(inputs) + expected = np_conv3d( + inputs, + kernel_weights, + bias_weights, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + groups=groups, + ) + self.assertAllClose(outputs, expected, rtol=1e-3) + + def test_conv_constraints(self): + layer = layers.Conv2D( + filters=4, + kernel_size=3, + kernel_constraint="non_neg", + ) + layer.build((None, 5, 5, 3)) + self.assertIsInstance(layer.kernel.constraint, constraints.NonNeg) + layer = layers.Conv2D( + filters=4, + kernel_size=3, + bias_constraint="non_neg", + ) + layer.build((None, 5, 5, 3)) + self.assertIsInstance(layer.bias.constraint, constraints.NonNeg) diff --git a/keras/src/layers/convolutional/conv_transpose_test.py b/keras/src/layers/convolutional/conv_transpose_test.py new file mode 100644 index 000000000000..89c5bd65e782 --- /dev/null +++ b/keras/src/layers/convolutional/conv_transpose_test.py @@ -0,0 +1,887 @@ +import numpy as np +import pytest +from absl.testing import parameterized + +from keras.src import backend +from keras.src import layers +from keras.src import testing +from keras.src.backend.common.backend_utils import ( + _convert_conv_transpose_padding_args_from_keras_to_torch, +) +from keras.src.backend.common.backend_utils import ( + compute_conv_transpose_output_shape, +) +from keras.src.backend.common.backend_utils import ( + compute_conv_transpose_padding_args_for_jax, +) + + +def np_conv1d_transpose( + x, + kernel_weights, + bias_weights, + strides, + padding, + output_padding, + data_format, + dilation_rate, +): + if data_format == "channels_first": + x = x.transpose((0, 2, 1)) + if isinstance(strides, (tuple, list)): + h_stride = strides[0] + else: + h_stride = strides + if isinstance(dilation_rate, (tuple, list)): + h_dilation = dilation_rate[0] + else: + h_dilation = dilation_rate + + h_kernel, ch_out, ch_in = kernel_weights.shape + n_batch, h_x, _ = x.shape + # Get output shape and padding + _, h_out, _ = compute_conv_transpose_output_shape( + x.shape, + kernel_weights.shape, + ch_out, + strides, + padding, + output_padding, + "channels_last", + dilation_rate, + ) + jax_padding = compute_conv_transpose_padding_args_for_jax( + input_shape=x.shape, + kernel_shape=kernel_weights.shape, + strides=strides, + padding=padding, + output_padding=output_padding, + dilation_rate=dilation_rate, + ) + h_pad_side1 = h_kernel - 1 - jax_padding[0][0] + + if h_dilation > 1: + # Increase kernel size + new_h_kernel = h_kernel + (h_dilation - 1) * (h_kernel - 1) + new_kernel_size_tuple = (new_h_kernel,) + new_kernel_weights = np.zeros( + (*new_kernel_size_tuple, ch_out, ch_in), + dtype=kernel_weights.dtype, + ) + new_kernel_weights[::h_dilation] = kernel_weights + kernel_weights = new_kernel_weights + h_kernel = kernel_weights.shape[0] + + # Compute output + output = np.zeros([n_batch, h_out + h_kernel, ch_out]) + for nb in range(n_batch): + for h_x_idx in range(h_x): + h_out_idx = h_x_idx * h_stride # Index in output + output[nb, h_out_idx : h_out_idx + h_kernel, :] += np.sum( + kernel_weights[:, :, :] * x[nb, h_x_idx, :], axis=-1 + ) + output = output + bias_weights + + # Cut padding results from output + output = output[:, h_pad_side1 : h_out + h_pad_side1] + if data_format == "channels_first": + output = output.transpose((0, 2, 1)) + return output + + +def np_conv2d_transpose( + x, + kernel_weights, + bias_weights, + strides, + padding, + output_padding, + data_format, + dilation_rate, +): + if data_format == "channels_first": + x = x.transpose((0, 2, 3, 1)) + if isinstance(strides, (tuple, list)): + h_stride, w_stride = strides + else: + h_stride = strides + w_stride = strides + if isinstance(dilation_rate, (tuple, list)): + h_dilation, w_dilation = dilation_rate + else: + h_dilation = dilation_rate + w_dilation = dilation_rate + + h_kernel, w_kernel, ch_out, ch_in = kernel_weights.shape + n_batch, h_x, w_x, _ = x.shape + # Get output shape and padding + _, h_out, w_out, _ = compute_conv_transpose_output_shape( + x.shape, + kernel_weights.shape, + ch_out, + strides, + padding, + output_padding, + "channels_last", + dilation_rate, + ) + jax_padding = compute_conv_transpose_padding_args_for_jax( + input_shape=x.shape, + kernel_shape=kernel_weights.shape, + strides=strides, + padding=padding, + output_padding=output_padding, + dilation_rate=dilation_rate, + ) + h_pad_side1 = h_kernel - 1 - jax_padding[0][0] + w_pad_side1 = w_kernel - 1 - jax_padding[1][0] + + if h_dilation > 1 or w_dilation > 1: + # Increase kernel size + new_h_kernel = h_kernel + (h_dilation - 1) * (h_kernel - 1) + new_w_kernel = w_kernel + (w_dilation - 1) * (w_kernel - 1) + new_kernel_size_tuple = (new_h_kernel, new_w_kernel) + new_kernel_weights = np.zeros( + (*new_kernel_size_tuple, ch_out, ch_in), + dtype=kernel_weights.dtype, + ) + new_kernel_weights[::h_dilation, ::w_dilation] = kernel_weights + kernel_weights = new_kernel_weights + h_kernel, w_kernel = kernel_weights.shape[:2] + + # Compute output + output = np.zeros([n_batch, h_out + h_kernel, w_out + w_kernel, ch_out]) + for nb in range(n_batch): + for h_x_idx in range(h_x): + h_out_idx = h_x_idx * h_stride # Index in output + for w_x_idx in range(w_x): + w_out_idx = w_x_idx * w_stride + output[ + nb, + h_out_idx : h_out_idx + h_kernel, + w_out_idx : w_out_idx + w_kernel, + :, + ] += np.sum( + kernel_weights[:, :, :, :] * x[nb, h_x_idx, w_x_idx, :], + axis=-1, + ) + output = output + bias_weights + + # Cut padding results from output + output = output[ + :, + h_pad_side1 : h_out + h_pad_side1, + w_pad_side1 : w_out + w_pad_side1, + ] + if data_format == "channels_first": + output = output.transpose((0, 3, 1, 2)) + return output + + +def np_conv3d_transpose( + x, + kernel_weights, + bias_weights, + strides, + padding, + output_padding, + data_format, + dilation_rate, +): + if data_format == "channels_first": + x = x.transpose((0, 2, 3, 4, 1)) + if isinstance(strides, (tuple, list)): + h_stride, w_stride, d_stride = strides + else: + h_stride = strides + w_stride = strides + d_stride = strides + if isinstance(dilation_rate, (tuple, list)): + h_dilation, w_dilation, d_dilation = dilation_rate + else: + h_dilation = dilation_rate + w_dilation = dilation_rate + d_dilation = dilation_rate + + h_kernel, w_kernel, d_kernel, ch_out, ch_in = kernel_weights.shape + n_batch, h_x, w_x, d_x, _ = x.shape + # Get output shape and padding + _, h_out, w_out, d_out, _ = compute_conv_transpose_output_shape( + x.shape, + kernel_weights.shape, + ch_out, + strides, + padding, + output_padding, + data_format, + dilation_rate, + ) + jax_padding = compute_conv_transpose_padding_args_for_jax( + input_shape=x.shape, + kernel_shape=kernel_weights.shape, + strides=strides, + padding=padding, + output_padding=output_padding, + dilation_rate=dilation_rate, + ) + h_pad_side1 = h_kernel - 1 - jax_padding[0][0] + w_pad_side1 = w_kernel - 1 - jax_padding[1][0] + d_pad_side1 = d_kernel - 1 - jax_padding[2][0] + + if h_dilation > 1 or w_dilation > 1 or d_dilation > 1: + # Increase kernel size + new_h_kernel = h_kernel + (h_dilation - 1) * (h_kernel - 1) + new_w_kernel = w_kernel + (w_dilation - 1) * (w_kernel - 1) + new_d_kernel = d_kernel + (d_dilation - 1) * (d_kernel - 1) + new_kernel_size_tuple = (new_h_kernel, new_w_kernel, new_d_kernel) + new_kernel_weights = np.zeros( + (*new_kernel_size_tuple, ch_out, ch_in), + dtype=kernel_weights.dtype, + ) + new_kernel_weights[::h_dilation, ::w_dilation, ::d_dilation] = ( + kernel_weights + ) + kernel_weights = new_kernel_weights + h_kernel, w_kernel, d_kernel = kernel_weights.shape[:3] + + # Compute output + output = np.zeros( + [ + n_batch, + h_out + h_kernel, + w_out + w_kernel, + d_out + d_kernel, + ch_out, + ] + ) + for nb in range(n_batch): + for h_x_idx in range(h_x): + h_out_idx = h_x_idx * h_stride # Index in output + for w_x_idx in range(w_x): + w_out_idx = w_x_idx * w_stride + for d_x_idx in range(d_x): + d_out_idx = d_x_idx * d_stride + output[ + nb, + h_out_idx : h_out_idx + h_kernel, + w_out_idx : w_out_idx + w_kernel, + d_out_idx : d_out_idx + d_kernel, + :, + ] += np.sum( + kernel_weights[:, :, :, :, :] + * x[nb, h_x_idx, w_x_idx, d_x_idx, :], + axis=-1, + ) + output = output + bias_weights + + # Cut padding results from output + output = output[ + :, + h_pad_side1 : h_out + h_pad_side1, + w_pad_side1 : w_out + w_pad_side1, + d_pad_side1 : d_out + d_pad_side1, + ] + if data_format == "channels_first": + output = output.transpose((0, 4, 1, 2, 3)) + return output + + +class ConvTransposeBasicTest(testing.TestCase): + @parameterized.parameters( + { + "filters": 5, + "kernel_size": 2, + "strides": 2, + "padding": "valid", + "output_padding": None, + "data_format": "channels_last", + "dilation_rate": 1, + "input_shape": (2, 8, 4), + "output_shape": (2, 16, 5), + }, + { + "filters": 6, + "kernel_size": 2, + "strides": 3, + "padding": "same", + "output_padding": 2, + "data_format": "channels_last", + "dilation_rate": (1,), + "input_shape": (2, 8, 4), + "output_shape": (2, 23, 6), + }, + { + "filters": 6, + "kernel_size": (2,), + "strides": (2,), + "padding": "valid", + "output_padding": None, + "data_format": "channels_last", + "dilation_rate": 1, + "input_shape": (2, 8, 4), + "output_shape": (2, 16, 6), + }, + ) + @pytest.mark.requires_trainable_backend + def test_conv1d_transpose_basic( + self, + filters, + kernel_size, + strides, + padding, + output_padding, + data_format, + dilation_rate, + input_shape, + output_shape, + ): + self.run_layer_test( + layers.Conv1DTranspose, + init_kwargs={ + "filters": filters, + "kernel_size": kernel_size, + "strides": strides, + "padding": padding, + "output_padding": output_padding, + "data_format": data_format, + "dilation_rate": dilation_rate, + }, + input_shape=input_shape, + expected_output_shape=output_shape, + expected_num_trainable_weights=2, + expected_num_non_trainable_weights=0, + expected_num_losses=0, + supports_masking=False, + ) + + @parameterized.parameters( + { + "filters": 5, + "kernel_size": 2, + "strides": 2, + "padding": "valid", + "output_padding": None, + "data_format": "channels_last", + "dilation_rate": 1, + "input_shape": (2, 8, 8, 4), + "output_shape": (2, 16, 16, 5), + }, + { + "filters": 6, + "kernel_size": 2, + "strides": 3, + "padding": "same", + "output_padding": 2, + "data_format": "channels_last", + "dilation_rate": (1, 1), + "input_shape": (2, 8, 8, 4), + "output_shape": (2, 23, 23, 6), + }, + { + "filters": 6, + "kernel_size": (2, 3), + "strides": (2, 1), + "padding": "valid", + "output_padding": None, + "data_format": "channels_first", + "dilation_rate": (1, 1), + "input_shape": (2, 4, 8, 8), + "output_shape": (2, 6, 16, 10), + }, + { + "filters": 2, + "kernel_size": (7, 7), + "strides": (16, 16), + "padding": "valid", + "output_padding": None, + "data_format": "channels_last", + "dilation_rate": (1, 1), + "input_shape": (1, 14, 14, 2), + "output_shape": (1, 224, 224, 2), + }, + ) + @pytest.mark.requires_trainable_backend + def test_conv2d_transpose_basic( + self, + filters, + kernel_size, + strides, + padding, + output_padding, + data_format, + dilation_rate, + input_shape, + output_shape, + ): + if ( + data_format == "channels_first" + and backend.backend() == "tensorflow" + ): + pytest.skip("channels_first unsupported on CPU with TF") + + self.run_layer_test( + layers.Conv2DTranspose, + init_kwargs={ + "filters": filters, + "kernel_size": kernel_size, + "strides": strides, + "padding": padding, + "output_padding": output_padding, + "data_format": data_format, + "dilation_rate": dilation_rate, + }, + input_shape=input_shape, + expected_output_shape=output_shape, + expected_num_trainable_weights=2, + expected_num_non_trainable_weights=0, + expected_num_losses=0, + supports_masking=False, + ) + + @parameterized.parameters( + { + "filters": 5, + "kernel_size": 2, + "strides": 2, + "padding": "valid", + "output_padding": None, + "data_format": "channels_last", + "dilation_rate": 1, + "input_shape": (2, 8, 8, 8, 4), + "output_shape": (2, 16, 16, 16, 5), + }, + { + "filters": 6, + "kernel_size": 2, + "strides": 3, + "padding": "same", + "output_padding": 2, + "data_format": "channels_last", + "dilation_rate": (1, 1, 1), + "input_shape": (2, 8, 8, 8, 4), + "output_shape": (2, 23, 23, 23, 6), + }, + { + "filters": 6, + "kernel_size": (2, 2, 3), + "strides": (2, 1, 2), + "padding": "valid", + "output_padding": None, + "data_format": "channels_last", + "dilation_rate": (1, 1, 1), + "input_shape": (2, 8, 8, 8, 4), + "output_shape": (2, 16, 9, 17, 6), + }, + ) + @pytest.mark.requires_trainable_backend + def test_conv3d_transpose_basic( + self, + filters, + kernel_size, + strides, + padding, + output_padding, + data_format, + dilation_rate, + input_shape, + output_shape, + ): + self.run_layer_test( + layers.Conv3DTranspose, + init_kwargs={ + "filters": filters, + "kernel_size": kernel_size, + "strides": strides, + "padding": padding, + "output_padding": output_padding, + "data_format": data_format, + "dilation_rate": dilation_rate, + }, + input_shape=input_shape, + expected_output_shape=output_shape, + expected_num_trainable_weights=2, + expected_num_non_trainable_weights=0, + expected_num_losses=0, + supports_masking=False, + ) + + def test_bad_init_args(self): + # `filters` is not positive. + with self.assertRaisesRegex( + ValueError, + "Invalid value for argument `filters`. Expected a " + "strictly positive value. Received filters=0.", + ): + layers.Conv1DTranspose(filters=0, kernel_size=1) + + # `kernel_size` has 0. + with self.assertRaisesRegex( + ValueError, + r"The `kernel_size` argument must be a tuple of " + r"\d+ integers. Received kernel_size=\(1, 0\), including values" + r" \{0\} that do not satisfy `value > 0`", + ): + layers.Conv2DTranspose(filters=2, kernel_size=(1, 0)) + + # `strides` has 0. + with self.assertRaisesRegex( + ValueError, + r"The `strides` argument must be a tuple of \d+ " + r"integers. Received strides=\(1, 0\), including values \{0\} " + r"that do not satisfy `value > 0`", + ): + layers.Conv2DTranspose( + filters=2, kernel_size=(2, 2), strides=(1, 0) + ) + + # `dilation_rate > 1` while `strides > 1`. + with self.assertRaisesRegex( + ValueError, + r"`strides > 1` not supported in conjunction with " + r"`dilation_rate > 1`. Received: strides=\(2, 2\) and " + r"dilation_rate=\(2, 1\)", + ): + layers.Conv2DTranspose( + filters=2, kernel_size=(2, 2), strides=2, dilation_rate=(2, 1) + ) + + +class ConvTransposeCorrectnessTest(testing.TestCase): + @parameterized.parameters( + { + "filters": 5, + "kernel_size": 2, + "strides": 2, + "padding": "valid", + "output_padding": None, + "data_format": "channels_last", + "dilation_rate": 1, + }, + { + "filters": 6, + "kernel_size": 2, + "strides": 3, + "padding": "same", + "output_padding": 2, + "data_format": "channels_last", + "dilation_rate": (1,), + }, + { + "filters": 6, + "kernel_size": (2,), + "strides": (2,), + "padding": "valid", + "output_padding": None, + "data_format": "channels_last", + "dilation_rate": 1, + }, + ) + def test_conv1d_transpose( + self, + filters, + kernel_size, + strides, + padding, + output_padding, + data_format, + dilation_rate, + ): + layer = layers.Conv1DTranspose( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + output_padding=output_padding, + data_format=data_format, + dilation_rate=dilation_rate, + ) + + inputs = np.random.normal(size=[2, 8, 4]) + layer.build(input_shape=inputs.shape) + + kernel_shape = layer.kernel.shape + kernel_weights = np.random.normal(size=kernel_shape) + bias_weights = np.random.normal(size=(filters,)) + layer.kernel.assign(kernel_weights) + layer.bias.assign(bias_weights) + + outputs = layer(inputs) + expected = np_conv1d_transpose( + inputs, + kernel_weights, + bias_weights, + strides, + padding, + output_padding, + data_format, + dilation_rate, + ) + self.assertAllClose(outputs, expected, atol=1e-5) + + @parameterized.parameters( + { + "filters": 5, + "kernel_size": 2, + "strides": 2, + "padding": "valid", + "output_padding": None, + "data_format": "channels_last", + "dilation_rate": 1, + }, + { + "filters": 6, + "kernel_size": 7, + "strides": 16, + "padding": "same", + "output_padding": 2, + "data_format": "channels_last", + "dilation_rate": (1, 1), + }, + { + "filters": 6, + "kernel_size": (2, 3), + "strides": (2, 1), + "padding": "valid", + "output_padding": None, + "data_format": "channels_last", + "dilation_rate": (1, 1), + }, + { + "filters": 2, + "kernel_size": (7, 7), + "strides": (16, 16), + "padding": "valid", + "output_padding": None, + "data_format": "channels_last", + "dilation_rate": (1, 1), + }, + ) + def test_conv2d_transpose( + self, + filters, + kernel_size, + strides, + padding, + output_padding, + data_format, + dilation_rate, + ): + layer = layers.Conv2DTranspose( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + output_padding=output_padding, + data_format=data_format, + dilation_rate=dilation_rate, + ) + + inputs = np.random.normal(size=[2, 14, 14, 4]) + layer.build(input_shape=inputs.shape) + + kernel_shape = layer.kernel.shape + kernel_weights = np.random.normal(size=kernel_shape) + bias_weights = np.random.normal(size=(filters,)) + layer.kernel.assign(kernel_weights) + layer.bias.assign(bias_weights) + + outputs = layer(inputs) + expected = np_conv2d_transpose( + inputs, + kernel_weights, + bias_weights, + strides, + padding, + output_padding, + data_format, + dilation_rate, + ) + self.assertAllClose(outputs, expected, atol=1e-5) + + @parameterized.parameters( + { + "filters": 5, + "kernel_size": 2, + "strides": 2, + "padding": "valid", + "output_padding": None, + "data_format": "channels_last", + "dilation_rate": 1, + }, + { + "filters": 6, + "kernel_size": 2, + "strides": 3, + "padding": "same", + "output_padding": 2, + "data_format": "channels_last", + "dilation_rate": (1, 1, 1), + }, + { + "filters": 6, + "kernel_size": (2, 2, 3), + "strides": (2, 1, 2), + "padding": "valid", + "output_padding": None, + "data_format": "channels_last", + "dilation_rate": (1, 1, 1), + }, + ) + def test_conv3d_transpose( + self, + filters, + kernel_size, + strides, + padding, + output_padding, + data_format, + dilation_rate, + ): + layer = layers.Conv3DTranspose( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + output_padding=output_padding, + data_format=data_format, + dilation_rate=dilation_rate, + ) + + inputs = np.random.normal(size=[2, 8, 8, 8, 4]) + layer.build(input_shape=inputs.shape) + + kernel_shape = layer.kernel.shape + kernel_weights = np.random.normal(size=kernel_shape) + bias_weights = np.random.normal(size=(filters,)) + layer.kernel.assign(kernel_weights) + layer.bias.assign(bias_weights) + + outputs = layer(inputs) + expected = np_conv3d_transpose( + inputs, + kernel_weights, + bias_weights, + strides, + padding, + output_padding, + data_format, + dilation_rate, + ) + self.assertAllClose(outputs, expected, atol=1e-5) + + @parameterized.product( + kernel_size=list(range(1, 5)), + strides=list(range(1, 5)), + padding=["same", "valid"], + output_padding=[None] + list(range(1, 5)), + ) + def test_conv1d_transpose_consistency( + self, kernel_size, strides, padding, output_padding + ): + """Test conv transpose, on an 1D array of size 3, against several + convolution parameters. In particular, tests if Torch inconsistencies + are raised. + """ + + # output_padding cannot be greater than strides + if isinstance(output_padding, int) and output_padding >= strides: + pytest.skip( + "`output_padding` greater than `strides` is not supported" + ) + + if backend.config.image_data_format() == "channels_last": + input_shape = (1, 3, 1) + else: + input_shape = (1, 1, 3) + + input = np.ones(shape=input_shape) + kernel_weights = np.arange(1, kernel_size + 1).reshape( + (kernel_size, 1, 1) + ) + + # Expected result + expected_res = np_conv1d_transpose( + x=input, + kernel_weights=kernel_weights, + bias_weights=np.zeros(shape=(1,)), + strides=strides, + padding=padding, + output_padding=output_padding, + data_format=backend.config.image_data_format(), + dilation_rate=1, + ) + + # keras layer + kc_layer = layers.Conv1DTranspose( + filters=1, + kernel_size=kernel_size, + strides=strides, + padding=padding, + output_padding=output_padding, + dilation_rate=1, + ) + kc_layer.build(input_shape=input_shape) + kc_layer.kernel.assign(kernel_weights) + + # Special cases for Torch + if backend.backend() == "torch": + # The following set of arguments lead to Torch output padding to be + # greater than strides, which is not supported by Torch. + # An error is raised. + if (kernel_size, strides, padding, output_padding) in [ + (2, 1, "same", None), + (4, 1, "same", None), + ]: + with pytest.raises(ValueError): + kc_res = kc_layer(input) + return + + # When both torch_padding and torch_output_padding are greater + # than 0, Torch outputs are inconsistent with the ones from + # Tensorflow. A warning is raised, and we expect the results to be + # different. + ( + torch_padding, + torch_output_padding, + ) = _convert_conv_transpose_padding_args_from_keras_to_torch( + kernel_size=kernel_size, + stride=strides, + dilation_rate=1, + padding=padding, + output_padding=output_padding, + ) + if torch_padding > 0 and torch_output_padding > 0: + with pytest.raises(AssertionError): + kc_res = kc_layer(input) + self.assertAllClose(expected_res, kc_res, atol=1e-5) + return + + # Compare results + kc_res = kc_layer(input) + self.assertAllClose(expected_res, kc_res, atol=1e-5) + + @parameterized.product( + kernel_size=list(range(1, 5)), + strides=list(range(1, 5)), + padding=["same", "valid"], + output_padding=[None] + list(range(1, 5)), + ) + def test_shape_inference_static_unknown_shape( + self, kernel_size, strides, padding, output_padding + ): + if backend.config.image_data_format() == "channels_last": + input_shape = (None, None, 3) + output_tensor_shape = (None, None, None, 2) + else: + input_shape = (3, None, None) + output_tensor_shape = (None, 2, None, None) + x = layers.Input(shape=input_shape) + x = layers.Conv2DTranspose( + filters=2, + kernel_size=kernel_size, + strides=strides, + padding=padding, + output_padding=output_padding, + dilation_rate=1, + )(x) + self.assertEqual(x.shape, output_tensor_shape) diff --git a/keras/src/layers/convolutional/depthwise_conv1d.py b/keras/src/layers/convolutional/depthwise_conv1d.py new file mode 100644 index 000000000000..51312d8447e2 --- /dev/null +++ b/keras/src/layers/convolutional/depthwise_conv1d.py @@ -0,0 +1,137 @@ +from keras.src.api_export import keras_export +from keras.src.layers.convolutional.base_depthwise_conv import BaseDepthwiseConv + + +@keras_export("keras.layers.DepthwiseConv1D") +class DepthwiseConv1D(BaseDepthwiseConv): + """1D depthwise convolution layer. + + Depthwise convolution is a type of convolution in which each input channel + is convolved with a different kernel (called a depthwise kernel). You can + understand depthwise convolution as the first step in a depthwise separable + convolution. + + It is implemented via the following steps: + + - Split the input into individual channels. + - Convolve each channel with an individual depthwise kernel with + `depth_multiplier` output channels. + - Concatenate the convolved outputs along the channels axis. + + Unlike a regular 1D convolution, depthwise convolution does not mix + information across different input channels. + + The `depth_multiplier` argument determines how many filters are applied to + one input channel. As such, it controls the amount of output channels that + are generated per input channel in the depthwise step. + + Args: + kernel_size: int or tuple/list of 1 integer, specifying the size of the + depthwise convolution window. + strides: int or tuple/list of 1 integer, specifying the stride length + of the convolution. `strides > 1` is incompatible with + `dilation_rate > 1`. + padding: string, either `"valid"` or `"same"` (case-insensitive). + `"valid"` means no padding. `"same"` results in padding evenly to + the left/right or up/down of the input. When `padding="same"` and + `strides=1`, the output has the same size as the input. + depth_multiplier: The number of depthwise convolution output channels + for each input channel. The total number of depthwise convolution + output channels will be equal to `input_channel * depth_multiplier`. + data_format: string, either `"channels_last"` or `"channels_first"`. + The ordering of the dimensions in the inputs. `"channels_last"` + corresponds to inputs with shape `(batch, steps, features)` + while `"channels_first"` corresponds to inputs with shape + `(batch, features, steps)`. It defaults to the `image_data_format` + value found in your Keras config file at `~/.keras/keras.json`. + If you never set it, then it will be `"channels_last"`. + dilation_rate: int or tuple/list of 1 integers, specifying the dilation + rate to use for dilated convolution. + activation: Activation function. If `None`, no activation is applied. + use_bias: bool, if `True`, bias will be added to the output. + depthwise_initializer: Initializer for the convolution kernel. + If `None`, the default initializer (`"glorot_uniform"`) + will be used. + bias_initializer: Initializer for the bias vector. If `None`, the + default initializer (`"zeros"`) will be used. + depthwise_regularizer: Optional regularizer for the convolution kernel. + bias_regularizer: Optional regularizer for the bias vector. + activity_regularizer: Optional regularizer function for the output. + depthwise_constraint: Optional projection function to be applied to the + kernel after being updated by an `Optimizer` (e.g. used to implement + norm constraints or value constraints for layer weights). The + function must take as input the unprojected variable and must return + the projected variable (which must have the same shape). Constraints + are not safe to use when doing asynchronous distributed training. + bias_constraint: Optional projection function to be applied to the + bias after being updated by an `Optimizer`. + + Input shape: + + - If `data_format="channels_last"`: + A 3D tensor with shape: `(batch_shape, steps, channels)` + - If `data_format="channels_first"`: + A 3D tensor with shape: `(batch_shape, channels, steps)` + + Output shape: + + - If `data_format="channels_last"`: + A 3D tensor with shape: + `(batch_shape, new_steps, channels * depth_multiplier)` + - If `data_format="channels_first"`: + A 3D tensor with shape: + `(batch_shape, channels * depth_multiplier, new_steps)` + + Returns: + A 3D tensor representing + `activation(depthwise_conv1d(inputs, kernel) + bias)`. + + Raises: + ValueError: when both `strides > 1` and `dilation_rate > 1`. + + Example: + + >>> x = np.random.rand(4, 10, 12) + >>> y = keras.layers.DepthwiseConv1D(3, 3, 2, activation='relu')(x) + >>> print(y.shape) + (4, 4, 36) + """ + + def __init__( + self, + kernel_size, + strides=1, + padding="valid", + depth_multiplier=1, + data_format=None, + dilation_rate=1, + activation=None, + use_bias=True, + depthwise_initializer="glorot_uniform", + bias_initializer="zeros", + depthwise_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + depthwise_constraint=None, + bias_constraint=None, + **kwargs, + ): + super().__init__( + rank=1, + depth_multiplier=depth_multiplier, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + use_bias=use_bias, + depthwise_initializer=depthwise_initializer, + bias_initializer=bias_initializer, + depthwise_regularizer=depthwise_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + depthwise_constraint=depthwise_constraint, + bias_constraint=bias_constraint, + **kwargs, + ) diff --git a/keras/src/layers/convolutional/depthwise_conv2d.py b/keras/src/layers/convolutional/depthwise_conv2d.py new file mode 100644 index 000000000000..71c950246e03 --- /dev/null +++ b/keras/src/layers/convolutional/depthwise_conv2d.py @@ -0,0 +1,138 @@ +from keras.src.api_export import keras_export +from keras.src.layers.convolutional.base_depthwise_conv import BaseDepthwiseConv + + +@keras_export("keras.layers.DepthwiseConv2D") +class DepthwiseConv2D(BaseDepthwiseConv): + """2D depthwise convolution layer. + + Depthwise convolution is a type of convolution in which each input channel + is convolved with a different kernel (called a depthwise kernel). You can + understand depthwise convolution as the first step in a depthwise separable + convolution. + + It is implemented via the following steps: + + - Split the input into individual channels. + - Convolve each channel with an individual depthwise kernel with + `depth_multiplier` output channels. + - Concatenate the convolved outputs along the channels axis. + + Unlike a regular 2D convolution, depthwise convolution does not mix + information across different input channels. + + The `depth_multiplier` argument determines how many filters are applied to + one input channel. As such, it controls the amount of output channels that + are generated per input channel in the depthwise step. + + Args: + kernel_size: int or tuple/list of 2 integer, specifying the size of the + depthwise convolution window. + strides: int or tuple/list of 2 integer, specifying the stride length + of the depthwise convolution. `strides > 1` is incompatible with + `dilation_rate > 1`. + padding: string, either `"valid"` or `"same"` (case-insensitive). + `"valid"` means no padding. `"same"` results in padding evenly to + the left/right or up/down of the input. When `padding="same"` and + `strides=1`, the output has the same size as the input. + depth_multiplier: The number of depthwise convolution output channels + for each input channel. The total number of depthwise convolution + output channels will be equal to `input_channel * depth_multiplier`. + data_format: string, either `"channels_last"` or `"channels_first"`. + The ordering of the dimensions in the inputs. `"channels_last"` + corresponds to inputs with shape `(batch, height, width, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch, channels, height, width)`. It defaults to the + `image_data_format` value found in your Keras config file + at `~/.keras/keras.json`. + If you never set it, then it will be `"channels_last"`. + dilation_rate: int or tuple/list of 2 integers, specifying the dilation + rate to use for dilated convolution. + activation: Activation function. If `None`, no activation is applied. + use_bias: bool, if `True`, bias will be added to the output. + depthwise_initializer: Initializer for the convolution kernel. + If `None`, the default initializer (`"glorot_uniform"`) + will be used. + bias_initializer: Initializer for the bias vector. If `None`, the + default initializer (`"zeros"`) will be used. + depthwise_regularizer: Optional regularizer for the convolution kernel. + bias_regularizer: Optional regularizer for the bias vector. + activity_regularizer: Optional regularizer function for the output. + depthwise_constraint: Optional projection function to be applied to the + kernel after being updated by an `Optimizer` (e.g. used to implement + norm constraints or value constraints for layer weights). The + function must take as input the unprojected variable and must return + the projected variable (which must have the same shape). Constraints + are not safe to use when doing asynchronous distributed training. + bias_constraint: Optional projection function to be applied to the + bias after being updated by an `Optimizer`. + + Input shape: + + - If `data_format="channels_last"`: + A 4D tensor with shape: `(batch_size, height, width, channels)` + - If `data_format="channels_first"`: + A 4D tensor with shape: `(batch_size, channels, height, width)` + + Output shape: + + - If `data_format="channels_last"`: + A 4D tensor with shape: + `(batch_size, new_height, new_width, channels * depth_multiplier)` + - If `data_format="channels_first"`: + A 4D tensor with shape: + `(batch_size, channels * depth_multiplier, new_height, new_width)` + + Returns: + A 4D tensor representing + `activation(depthwise_conv2d(inputs, kernel) + bias)`. + + Raises: + ValueError: when both `strides > 1` and `dilation_rate > 1`. + + Example: + + >>> x = np.random.rand(4, 10, 10, 12) + >>> y = keras.layers.DepthwiseConv2D(kernel_size=3, activation='relu')(x) + >>> print(y.shape) + (4, 8, 8, 12) + """ + + def __init__( + self, + kernel_size, + strides=(1, 1), + padding="valid", + depth_multiplier=1, + data_format=None, + dilation_rate=(1, 1), + activation=None, + use_bias=True, + depthwise_initializer="glorot_uniform", + bias_initializer="zeros", + depthwise_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + depthwise_constraint=None, + bias_constraint=None, + **kwargs, + ): + super().__init__( + rank=2, + depth_multiplier=depth_multiplier, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + use_bias=use_bias, + depthwise_initializer=depthwise_initializer, + bias_initializer=bias_initializer, + depthwise_regularizer=depthwise_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + depthwise_constraint=depthwise_constraint, + bias_constraint=bias_constraint, + **kwargs, + ) diff --git a/keras/src/layers/convolutional/depthwise_conv_test.py b/keras/src/layers/convolutional/depthwise_conv_test.py new file mode 100644 index 000000000000..a81dd69035b2 --- /dev/null +++ b/keras/src/layers/convolutional/depthwise_conv_test.py @@ -0,0 +1,469 @@ +import numpy as np +import pytest +from absl.testing import parameterized +from numpy.lib.stride_tricks import as_strided + +from keras.src import layers +from keras.src import testing + + +def _same_padding(input_size, kernel_size, stride): + if input_size % stride == 0: + padding = max(kernel_size - stride, 0) + else: + padding = max(kernel_size - (input_size % stride), 0) + return padding // 2, padding - padding // 2 + + +def np_depthwise_conv1d( + x, + kernel_weights, + bias_weights, + strides, + padding, + data_format, + dilation_rate, +): + if data_format == "channels_first": + x = x.transpose((0, 2, 1)) + if isinstance(strides, (tuple, list)): + h_stride = strides[0] + else: + h_stride = strides + if isinstance(dilation_rate, (tuple, list)): + h_dilation = dilation_rate[0] + else: + h_dilation = dilation_rate + h_kernel, ch_in, ch_out = kernel_weights.shape + + if h_dilation > 1: + new_h_kernel = h_kernel + (h_dilation - 1) * (h_kernel - 1) + new_kernel_weights = np.zeros( + (new_h_kernel, ch_in, ch_out), + dtype=kernel_weights.dtype, + ) + new_kernel_weights[::h_dilation] = kernel_weights + kernel_weights = new_kernel_weights + h_kernel = kernel_weights.shape[0] + + if padding == "same": + n_batch, h_x, _ = x.shape + h_pad = _same_padding(h_x, h_kernel, h_stride) + npad = [(0, 0)] * x.ndim + npad[1] = h_pad + x = np.pad(x, pad_width=npad, mode="constant", constant_values=0) + + n_batch, h_x, _ = x.shape + h_out = int((h_x - h_kernel) / h_stride) + 1 + + out_grps = [] + bias_weights = bias_weights.reshape(ch_in, ch_out) + for ch_in_idx in range(ch_in): + for ch_out_idx in range(ch_out): + x_in = np.ascontiguousarray(x[..., ch_in_idx]) + stride_shape = (n_batch, h_out, h_kernel) + strides = ( + x_in.strides[0], + h_stride * x_in.strides[1], + x_in.strides[1], + ) + inner_dim = h_kernel + x_strided = as_strided( + x_in, shape=stride_shape, strides=strides + ).reshape(-1, inner_dim) + kernel_weights_grp = kernel_weights[ + ..., ch_in_idx, ch_out_idx + ].reshape(-1, 1) + bias_weights_grp = bias_weights[..., ch_in_idx, ch_out_idx] + out_grps.append( + (x_strided @ kernel_weights_grp + bias_weights_grp).reshape( + n_batch, h_out, 1 + ) + ) + out = np.concatenate(out_grps, axis=-1) + if data_format == "channels_first": + out = out.transpose((0, 2, 1)) + return out + + +def np_depthwise_conv2d( + x, + kernel_weights, + bias_weights, + strides, + padding, + data_format, + dilation_rate, +): + if data_format == "channels_first": + x = x.transpose((0, 2, 3, 1)) + if isinstance(strides, (tuple, list)): + h_stride, w_stride = strides + else: + h_stride = strides + w_stride = strides + if isinstance(dilation_rate, (tuple, list)): + h_dilation, w_dilation = dilation_rate + else: + h_dilation = dilation_rate + w_dilation = dilation_rate + h_kernel, w_kernel, ch_in, ch_out = kernel_weights.shape + + if h_dilation > 1 or w_dilation > 1: + new_h_kernel = h_kernel + (h_dilation - 1) * (h_kernel - 1) + new_w_kernel = w_kernel + (w_dilation - 1) * (w_kernel - 1) + new_kenel_size_tuple = (new_h_kernel, new_w_kernel) + new_kernel_weights = np.zeros( + (*new_kenel_size_tuple, ch_in, ch_out), + dtype=kernel_weights.dtype, + ) + new_kernel_weights[::h_dilation, ::w_dilation] = kernel_weights + kernel_weights = new_kernel_weights + h_kernel, w_kernel = kernel_weights.shape[:2] + + if padding == "same": + n_batch, h_x, w_x, _ = x.shape + h_pad = _same_padding(h_x, h_kernel, h_stride) + w_pad = _same_padding(w_x, w_kernel, w_stride) + npad = [(0, 0)] * x.ndim + npad[1] = h_pad + npad[2] = w_pad + x = np.pad(x, pad_width=npad, mode="constant", constant_values=0) + + n_batch, h_x, w_x, _ = x.shape + h_out = int((h_x - h_kernel) / h_stride) + 1 + w_out = int((w_x - w_kernel) / w_stride) + 1 + + out_grps = [] + bias_weights = bias_weights.reshape(ch_in, ch_out) + for ch_in_idx in range(ch_in): + for ch_out_idx in range(ch_out): + x_in = np.ascontiguousarray(x[..., ch_in_idx]) + stride_shape = (n_batch, h_out, w_out, h_kernel, w_kernel) + strides = ( + x_in.strides[0], + h_stride * x_in.strides[1], + w_stride * x_in.strides[2], + x_in.strides[1], + x_in.strides[2], + ) + inner_dim = h_kernel * w_kernel + x_strided = as_strided( + x_in, shape=stride_shape, strides=strides + ).reshape(-1, inner_dim) + kernel_weights_grp = kernel_weights[ + ..., ch_in_idx, ch_out_idx + ].reshape(-1, 1) + bias_weights_grp = bias_weights[..., ch_in_idx, ch_out_idx] + out_grps.append( + (x_strided @ kernel_weights_grp + bias_weights_grp).reshape( + n_batch, h_out, w_out, 1 + ) + ) + out = np.concatenate(out_grps, axis=-1) + if data_format == "channels_first": + out = out.transpose((0, 3, 1, 2)) + return out + + +class DepthwiseConvBasicTest(testing.TestCase): + @parameterized.parameters( + { + "depth_multiplier": 5, + "kernel_size": 2, + "strides": 1, + "padding": "valid", + "data_format": "channels_last", + "dilation_rate": 1, + "input_shape": (3, 5, 4), + "output_shape": (3, 4, 20), + }, + { + "depth_multiplier": 6, + "kernel_size": 2, + "strides": 1, + "padding": "same", + "data_format": "channels_last", + "dilation_rate": (2,), + "input_shape": (3, 4, 4), + "output_shape": (3, 4, 24), + }, + { + "depth_multiplier": 6, + "kernel_size": 2, + "strides": (2,), + "padding": "valid", + "data_format": "channels_last", + "dilation_rate": 1, + "input_shape": (3, 5, 4), + "output_shape": (3, 2, 24), + }, + ) + @pytest.mark.requires_trainable_backend + def test_depthwise_conv1d_basic( + self, + depth_multiplier, + kernel_size, + strides, + padding, + data_format, + dilation_rate, + input_shape, + output_shape, + ): + self.run_layer_test( + layers.DepthwiseConv1D, + init_kwargs={ + "depth_multiplier": depth_multiplier, + "kernel_size": kernel_size, + "strides": strides, + "padding": padding, + "data_format": data_format, + "dilation_rate": dilation_rate, + }, + input_shape=input_shape, + expected_output_shape=output_shape, + expected_num_trainable_weights=2, + expected_num_non_trainable_weights=0, + expected_num_losses=0, + supports_masking=False, + ) + + @parameterized.parameters( + { + "depth_multiplier": 5, + "kernel_size": 2, + "strides": 1, + "padding": "valid", + "data_format": "channels_last", + "dilation_rate": 1, + "input_shape": (3, 5, 5, 4), + "output_shape": (3, 4, 4, 20), + }, + { + "depth_multiplier": 6, + "kernel_size": 2, + "strides": 1, + "padding": "same", + "data_format": "channels_last", + "dilation_rate": (2, 2), + "input_shape": (3, 4, 4, 4), + "output_shape": (3, 4, 4, 24), + }, + { + "depth_multiplier": 6, + "kernel_size": (2, 2), + "strides": (2, 2), + "padding": "valid", + "data_format": "channels_last", + "dilation_rate": (1, 1), + "input_shape": (3, 5, 5, 4), + "output_shape": (3, 2, 2, 24), + }, + ) + @pytest.mark.requires_trainable_backend + def test_depthwise_conv2d_basic( + self, + depth_multiplier, + kernel_size, + strides, + padding, + data_format, + dilation_rate, + input_shape, + output_shape, + ): + self.run_layer_test( + layers.DepthwiseConv2D, + init_kwargs={ + "depth_multiplier": depth_multiplier, + "kernel_size": kernel_size, + "strides": strides, + "padding": padding, + "data_format": data_format, + "dilation_rate": dilation_rate, + }, + input_shape=input_shape, + expected_output_shape=output_shape, + expected_num_trainable_weights=2, + expected_num_non_trainable_weights=0, + expected_num_losses=0, + supports_masking=False, + ) + + def test_bad_init_args(self): + # `depth_multiplier` is not positive. + with self.assertRaisesRegex( + ValueError, + "Invalid value for argument `depth_multiplier`. " + "Expected a strictly positive value. Received " + "depth_multiplier=0.", + ): + layers.DepthwiseConv1D(depth_multiplier=0, kernel_size=1) + + # `kernel_size` has 0. + with self.assertRaisesRegex( + ValueError, + r"The `kernel_size` argument must be a tuple of 2 " + r"integers. Received kernel_size=\(1, 0\), including values " + r"\{0\} that do not satisfy `value > 0`", + ): + layers.DepthwiseConv2D(depth_multiplier=2, kernel_size=(1, 0)) + + # `strides` has 0. + with self.assertRaisesRegex( + ValueError, + r"The `strides` argument must be a tuple of \d+ " + r"integers. Received strides=\(1, 0\), including values \{0\} " + r"that do not satisfy `value > 0`", + ): + layers.DepthwiseConv2D( + depth_multiplier=2, kernel_size=(2, 2), strides=(1, 0) + ) + + # `dilation_rate > 1` while `strides > 1`. + with self.assertRaisesRegex( + ValueError, + r"`strides > 1` not supported in conjunction with " + r"`dilation_rate > 1`. Received: strides=\(2, 2\) and " + r"dilation_rate=\(2, 1\)", + ): + layers.DepthwiseConv2D( + depth_multiplier=2, + kernel_size=(2, 2), + strides=2, + dilation_rate=(2, 1), + ) + + +class DepthwiseConvCorrectnessTest(testing.TestCase): + @parameterized.parameters( + { + "depth_multiplier": 5, + "kernel_size": 2, + "strides": 1, + "padding": "valid", + "data_format": "channels_last", + "dilation_rate": 1, + }, + { + "depth_multiplier": 6, + "kernel_size": 2, + "strides": 1, + "padding": "same", + "data_format": "channels_last", + "dilation_rate": (2,), + }, + { + "depth_multiplier": 6, + "kernel_size": (2,), + "strides": (2,), + "padding": "valid", + "data_format": "channels_last", + "dilation_rate": 1, + }, + ) + def test_depthwise_conv1d( + self, + depth_multiplier, + kernel_size, + strides, + padding, + data_format, + dilation_rate, + ): + layer = layers.DepthwiseConv1D( + depth_multiplier=depth_multiplier, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + ) + + inputs = np.random.normal(size=[2, 8, 4]) + layer.build(input_shape=inputs.shape) + + kernel_shape = layer.kernel.shape + kernel_weights = np.random.normal(size=kernel_shape) + bias_weights = np.random.normal(size=(depth_multiplier * 4,)) + layer.kernel.assign(kernel_weights) + layer.bias.assign(bias_weights) + + outputs = layer(inputs) + expected = np_depthwise_conv1d( + inputs, + kernel_weights, + bias_weights, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + ) + self.assertAllClose(outputs, expected) + + @parameterized.parameters( + { + "depth_multiplier": 5, + "kernel_size": 2, + "strides": 1, + "padding": "valid", + "data_format": "channels_last", + "dilation_rate": 1, + }, + { + "depth_multiplier": 6, + "kernel_size": 2, + "strides": 1, + "padding": "same", + "data_format": "channels_last", + "dilation_rate": (2, 2), + }, + { + "depth_multiplier": 6, + "kernel_size": (2, 2), + "strides": (2, 2), + "padding": "valid", + "data_format": "channels_last", + "dilation_rate": (1, 1), + }, + ) + def test_depthwise_conv2d( + self, + depth_multiplier, + kernel_size, + strides, + padding, + data_format, + dilation_rate, + ): + layer = layers.DepthwiseConv2D( + depth_multiplier=depth_multiplier, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + ) + + inputs = np.random.normal(size=[2, 8, 8, 4]) + layer.build(input_shape=inputs.shape) + + kernel_shape = layer.kernel.shape + kernel_weights = np.random.normal(size=kernel_shape) + bias_weights = np.random.normal(size=(depth_multiplier * 4,)) + layer.kernel.assign(kernel_weights) + layer.bias.assign(bias_weights) + + outputs = layer(inputs) + expected = np_depthwise_conv2d( + inputs, + kernel_weights, + bias_weights, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + ) + self.assertAllClose(outputs.shape, expected.shape) + self.assertAllClose(outputs, expected, atol=1e-5) diff --git a/keras/src/layers/convolutional/separable_conv1d.py b/keras/src/layers/convolutional/separable_conv1d.py new file mode 100644 index 000000000000..2f03161981d4 --- /dev/null +++ b/keras/src/layers/convolutional/separable_conv1d.py @@ -0,0 +1,143 @@ +from keras.src.api_export import keras_export +from keras.src.layers.convolutional.base_separable_conv import BaseSeparableConv + + +@keras_export( + [ + "keras.layers.SeparableConv1D", + "keras.layers.SeparableConvolution1D", + ] +) +class SeparableConv1D(BaseSeparableConv): + """1D separable convolution layer. + + This layer performs a depthwise convolution that acts separately on + channels, followed by a pointwise convolution that mixes channels. + If `use_bias` is True and a bias initializer is provided, + it adds a bias vector to the output. It then optionally applies an + activation function to produce the final output. + + Args: + filters: int, the dimensionality of the output space (i.e. the number + of filters in the pointwise convolution). + kernel_size: int or tuple/list of 1 integers, specifying the size of the + depthwise convolution window. + strides: int or tuple/list of 1 integers, specifying the stride length + of the depthwise convolution. If only one int is specified, the same + stride size will be used for all dimensions. `strides > 1` is + incompatible with `dilation_rate > 1`. + padding: string, either `"valid"` or `"same"` (case-insensitive). + `"valid"` means no padding. `"same"` results in padding evenly to + the left/right or up/down of the input. When `padding="same"` and + `strides=1`, the output has the same size as the input. + data_format: string, either `"channels_last"` or `"channels_first"`. + The ordering of the dimensions in the inputs. `"channels_last"` + corresponds to inputs with shape `(batch, steps, features)` + while `"channels_first"` corresponds to inputs with shape + `(batch, features, steps)`. It defaults to the `image_data_format` + value found in your Keras config file at `~/.keras/keras.json`. + If you never set it, then it will be `"channels_last"`. + dilation_rate: int or tuple/list of 1 integers, specifying the dilation + rate to use for dilated convolution. If only one int is specified, + the same dilation rate will be used for all dimensions. + depth_multiplier: The number of depthwise convolution output channels + for each input channel. The total number of depthwise convolution + output channels will be equal to `input_channel * depth_multiplier`. + activation: Activation function. If `None`, no activation is applied. + use_bias: bool, if `True`, bias will be added to the output. + depthwise_initializer: An initializer for the depthwise convolution + kernel. If None, then the default initializer (`"glorot_uniform"`) + will be used. + pointwise_initializer: An initializer for the pointwise convolution + kernel. If None, then the default initializer (`"glorot_uniform"`) + will be used. + bias_initializer: An initializer for the bias vector. If None, the + default initializer ('"zeros"') will be used. + depthwise_regularizer: Optional regularizer for the depthwise + convolution kernel. + pointwise_regularizer: Optional regularizer for the pointwise + convolution kernel. + bias_regularizer: Optional regularizer for the bias vector. + activity_regularizer: Optional regularizer function for the output. + depthwise_constraint: Optional projection function to be applied to the + depthwise kernel after being updated by an `Optimizer` (e.g. used + for norm constraints or value constraints for layer weights). The + function must take as input the unprojected variable and must return + the projected variable (which must have the same shape). + pointwise_constraint: Optional projection function to be applied to the + pointwise kernel after being updated by an `Optimizer`. + bias_constraint: Optional projection function to be applied to the + bias after being updated by an `Optimizer`. + + Input shape: + + - If `data_format="channels_last"`: + A 3D tensor with shape: `(batch_shape, steps, channels)` + - If `data_format="channels_first"`: + A 3D tensor with shape: `(batch_shape, channels, steps)` + + Output shape: + + - If `data_format="channels_last"`: + A 3D tensor with shape: `(batch_shape, new_steps, filters)` + - If `data_format="channels_first"`: + A 3D tensor with shape: `(batch_shape, filters, new_steps)` + + Returns: + A 3D tensor representing + `activation(separable_conv1d(inputs, kernel) + bias)`. + + Example: + + >>> x = np.random.rand(4, 10, 12) + >>> y = keras.layers.SeparableConv1D(3, 4, 3, 2, activation='relu')(x) + >>> print(y.shape) + (4, 4, 4) + """ + + def __init__( + self, + filters, + kernel_size, + strides=1, + padding="valid", + data_format=None, + dilation_rate=1, + depth_multiplier=1, + activation=None, + use_bias=True, + depthwise_initializer="glorot_uniform", + pointwise_initializer="glorot_uniform", + bias_initializer="zeros", + depthwise_regularizer=None, + pointwise_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + depthwise_constraint=None, + pointwise_constraint=None, + bias_constraint=None, + **kwargs, + ): + super().__init__( + rank=1, + depth_multiplier=depth_multiplier, + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + use_bias=use_bias, + depthwise_initializer=depthwise_initializer, + pointwise_initializer=pointwise_initializer, + bias_initializer=bias_initializer, + depthwise_regularizer=depthwise_regularizer, + pointwise_regularizer=pointwise_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + depthwise_constraint=depthwise_constraint, + pointwise_constraint=pointwise_constraint, + bias_constraint=bias_constraint, + **kwargs, + ) diff --git a/keras/src/layers/convolutional/separable_conv2d.py b/keras/src/layers/convolutional/separable_conv2d.py new file mode 100644 index 000000000000..27c1548231dd --- /dev/null +++ b/keras/src/layers/convolutional/separable_conv2d.py @@ -0,0 +1,144 @@ +from keras.src.api_export import keras_export +from keras.src.layers.convolutional.base_separable_conv import BaseSeparableConv + + +@keras_export( + [ + "keras.layers.SeparableConv2D", + "keras.layers.SeparableConvolution2D", + ] +) +class SeparableConv2D(BaseSeparableConv): + """2D separable convolution layer. + + This layer performs a depthwise convolution that acts separately on + channels, followed by a pointwise convolution that mixes channels. + If `use_bias` is True and a bias initializer is provided, + it adds a bias vector to the output. It then optionally applies an + activation function to produce the final output. + + Args: + filters: int, the dimensionality of the output space (i.e. the number + of filters in the pointwise convolution). + kernel_size: int or tuple/list of 2 integers, specifying the size of the + depthwise convolution window. + strides: int or tuple/list of 2 integers, specifying the stride length + of the depthwise convolution. If only one int is specified, the same + stride size will be used for all dimensions. `strides > 1` is + incompatible with `dilation_rate > 1`. + padding: string, either `"valid"` or `"same"` (case-insensitive). + `"valid"` means no padding. `"same"` results in padding evenly to + the left/right or up/down of the input. When `padding="same"` and + `strides=1`, the output has the same size as the input. + data_format: string, either `"channels_last"` or `"channels_first"`. + The ordering of the dimensions in the inputs. `"channels_last"` + corresponds to inputs with shape `(batch, height, width, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch, channels, height, width)`. It defaults to the + `image_data_format` value found in your Keras config file + at `~/.keras/keras.json`. + If you never set it, then it will be `"channels_last"`. + dilation_rate: int or tuple/list of 2 integers, specifying the dilation + rate to use for dilated convolution. If only one int is specified, + the same dilation rate will be used for all dimensions. + depth_multiplier: The number of depthwise convolution output channels + for each input channel. The total number of depthwise convolution + output channels will be equal to `input_channel * depth_multiplier`. + activation: Activation function. If `None`, no activation is applied. + use_bias: bool, if `True`, bias will be added to the output. + depthwise_initializer: An initializer for the depthwise convolution + kernel. If None, then the default initializer (`"glorot_uniform"`) + will be used. + pointwise_initializer: An initializer for the pointwise convolution + kernel. If None, then the default initializer (`"glorot_uniform"`) + will be used. + bias_initializer: An initializer for the bias vector. If None, the + default initializer ('"zeros"') will be used. + depthwise_regularizer: Optional regularizer for the depthwise + convolution kernel. + pointwise_regularizer: Optional regularizer for the pointwise + convolution kernel. + bias_regularizer: Optional regularizer for the bias vector. + activity_regularizer: Optional regularizer function for the output. + depthwise_constraint: Optional projection function to be applied to the + depthwise kernel after being updated by an `Optimizer` (e.g. used + for norm constraints or value constraints for layer weights). The + function must take as input the unprojected variable and must return + the projected variable (which must have the same shape). + pointwise_constraint: Optional projection function to be applied to the + pointwise kernel after being updated by an `Optimizer`. + bias_constraint: Optional projection function to be applied to the + bias after being updated by an `Optimizer`. + + Input shape: + + - If `data_format="channels_last"`: + A 4D tensor with shape: `(batch_size, height, width, channels)` + - If `data_format="channels_first"`: + A 4D tensor with shape: `(batch_size, channels, height, width)` + + Output shape: + + - If `data_format="channels_last"`: + A 4D tensor with shape: `(batch_size, new_height, new_width, filters)` + - If `data_format="channels_first"`: + A 4D tensor with shape: `(batch_size, filters, new_height, new_width)` + + Returns: + A 4D tensor representing + `activation(separable_conv2d(inputs, kernel) + bias)`. + + Example: + + >>> x = np.random.rand(4, 10, 10, 12) + >>> y = keras.layers.SeparableConv2D(3, 4, 3, 2, activation='relu')(x) + >>> print(y.shape) + (4, 4, 4, 4) + """ + + def __init__( + self, + filters, + kernel_size, + strides=(1, 1), + padding="valid", + data_format=None, + dilation_rate=(1, 1), + depth_multiplier=1, + activation=None, + use_bias=True, + depthwise_initializer="glorot_uniform", + pointwise_initializer="glorot_uniform", + bias_initializer="zeros", + depthwise_regularizer=None, + pointwise_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + depthwise_constraint=None, + pointwise_constraint=None, + bias_constraint=None, + **kwargs, + ): + super().__init__( + rank=2, + depth_multiplier=depth_multiplier, + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + use_bias=use_bias, + depthwise_initializer=depthwise_initializer, + pointwise_initializer=pointwise_initializer, + bias_initializer=bias_initializer, + depthwise_regularizer=depthwise_regularizer, + pointwise_regularizer=pointwise_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + depthwise_constraint=depthwise_constraint, + pointwise_constraint=pointwise_constraint, + bias_constraint=bias_constraint, + **kwargs, + ) diff --git a/keras/src/layers/convolutional/separable_conv_test.py b/keras/src/layers/convolutional/separable_conv_test.py new file mode 100644 index 000000000000..a3e600ca4898 --- /dev/null +++ b/keras/src/layers/convolutional/separable_conv_test.py @@ -0,0 +1,384 @@ +import numpy as np +import pytest +from absl.testing import parameterized + +from keras.src import layers +from keras.src import testing +from keras.src.layers.convolutional.conv_test import np_conv1d +from keras.src.layers.convolutional.conv_test import np_conv2d +from keras.src.layers.convolutional.depthwise_conv_test import ( + np_depthwise_conv1d, +) +from keras.src.layers.convolutional.depthwise_conv_test import ( + np_depthwise_conv2d, +) + + +class SeparableConvBasicTest(testing.TestCase): + @parameterized.parameters( + { + "depth_multiplier": 5, + "filters": 5, + "kernel_size": 2, + "strides": 1, + "padding": "valid", + "data_format": "channels_last", + "dilation_rate": 1, + "input_shape": (3, 5, 4), + "output_shape": (3, 4, 5), + }, + { + "depth_multiplier": 6, + "filters": 6, + "kernel_size": 2, + "strides": 1, + "padding": "same", + "data_format": "channels_last", + "dilation_rate": (2,), + "input_shape": (3, 4, 4), + "output_shape": (3, 4, 6), + }, + { + "depth_multiplier": 6, + "filters": 6, + "kernel_size": 2, + "strides": (2,), + "padding": "valid", + "data_format": "channels_last", + "dilation_rate": 1, + "input_shape": (3, 5, 4), + "output_shape": (3, 2, 6), + }, + ) + @pytest.mark.requires_trainable_backend + def test_separable_conv1d_basic( + self, + depth_multiplier, + filters, + kernel_size, + strides, + padding, + data_format, + dilation_rate, + input_shape, + output_shape, + ): + self.run_layer_test( + layers.SeparableConv1D, + init_kwargs={ + "depth_multiplier": depth_multiplier, + "filters": filters, + "kernel_size": kernel_size, + "strides": strides, + "padding": padding, + "data_format": data_format, + "dilation_rate": dilation_rate, + }, + input_shape=input_shape, + expected_output_shape=output_shape, + expected_num_trainable_weights=3, + expected_num_non_trainable_weights=0, + expected_num_losses=0, + supports_masking=False, + ) + + @parameterized.parameters( + { + "depth_multiplier": 5, + "filters": 5, + "kernel_size": 2, + "strides": 1, + "padding": "valid", + "data_format": "channels_last", + "dilation_rate": 1, + "input_shape": (3, 5, 5, 4), + "output_shape": (3, 4, 4, 5), + }, + { + "depth_multiplier": 6, + "filters": 6, + "kernel_size": 2, + "strides": 1, + "padding": "same", + "data_format": "channels_last", + "dilation_rate": (2, 2), + "input_shape": (3, 4, 4, 4), + "output_shape": (3, 4, 4, 6), + }, + { + "depth_multiplier": 6, + "filters": 6, + "kernel_size": (2, 2), + "strides": (2, 2), + "padding": "valid", + "data_format": "channels_last", + "dilation_rate": (1, 1), + "input_shape": (3, 5, 5, 4), + "output_shape": (3, 2, 2, 6), + }, + ) + @pytest.mark.requires_trainable_backend + def test_separable_conv2d_basic( + self, + depth_multiplier, + filters, + kernel_size, + strides, + padding, + data_format, + dilation_rate, + input_shape, + output_shape, + ): + self.run_layer_test( + layers.SeparableConv2D, + init_kwargs={ + "depth_multiplier": depth_multiplier, + "filters": filters, + "kernel_size": kernel_size, + "strides": strides, + "padding": padding, + "data_format": data_format, + "dilation_rate": dilation_rate, + }, + input_shape=input_shape, + expected_output_shape=output_shape, + expected_num_trainable_weights=3, + expected_num_non_trainable_weights=0, + expected_num_losses=0, + supports_masking=False, + ) + + def test_bad_init_args(self): + # `depth_multiplier` is not positive. + with self.assertRaisesRegex( + ValueError, + "Invalid value for argument `depth_multiplier`. " + "Expected a strictly positive value. Received " + "depth_multiplier=0.", + ): + layers.SeparableConv1D(depth_multiplier=0, filters=1, kernel_size=1) + + # `filters` is not positive. + with self.assertRaisesRegex( + ValueError, + "Invalid value for argument `filters`. Expected a " + "strictly positive value. Received filters=0.", + ): + layers.SeparableConv1D(depth_multiplier=1, filters=0, kernel_size=1) + + # `kernel_size` has 0. + with self.assertRaisesRegex( + ValueError, + r"The `kernel_size` argument must be a tuple of " + r"\d+ integers. Received kernel_size=\(1, 0\), including values" + r" \{0\} that do not satisfy `value > 0`", + ): + layers.SeparableConv2D( + depth_multiplier=2, filters=2, kernel_size=(1, 0) + ) + + # `strides` has 0. + with self.assertRaisesRegex( + ValueError, + r"The `strides` argument must be a tuple of \d+ " + r"integers. Received strides=\(1, 0\), including values \{0\} " + r"that do not satisfy `value > 0`", + ): + layers.SeparableConv2D( + depth_multiplier=2, + filters=2, + kernel_size=(2, 2), + strides=(1, 0), + ) + + # `dilation_rate > 1` while `strides > 1`. + with self.assertRaisesRegex( + ValueError, + r"`strides > 1` not supported in conjunction with " + r"`dilation_rate > 1`. Received: strides=\(2, 2\) and " + r"dilation_rate=\(2, 1\)", + ): + layers.SeparableConv2D( + depth_multiplier=2, + filters=2, + kernel_size=(2, 2), + strides=2, + dilation_rate=(2, 1), + ) + + +class SeparableConvCorrectnessTest(testing.TestCase): + @parameterized.parameters( + { + "depth_multiplier": 5, + "filters": 5, + "kernel_size": 2, + "strides": 1, + "padding": "valid", + "data_format": "channels_last", + "dilation_rate": 1, + }, + { + "depth_multiplier": 6, + "filters": 6, + "kernel_size": 2, + "strides": 1, + "padding": "same", + "data_format": "channels_last", + "dilation_rate": (2,), + }, + { + "depth_multiplier": 6, + "filters": 6, + "kernel_size": (2,), + "strides": (2,), + "padding": "valid", + "data_format": "channels_last", + "dilation_rate": 1, + }, + ) + def test_separable_conv1d( + self, + depth_multiplier, + filters, + kernel_size, + strides, + padding, + data_format, + dilation_rate, + ): + layer = layers.SeparableConv1D( + depth_multiplier=depth_multiplier, + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + ) + + inputs = np.random.normal(size=[2, 8, 4]) + layer.build(input_shape=inputs.shape) + + depthwise_kernel_shape = layer.depthwise_kernel.shape + depthwise_kernel_weights = np.random.normal(size=depthwise_kernel_shape) + layer.depthwise_kernel.assign(depthwise_kernel_weights) + + pointwise_kernel_shape = layer.pointwise_kernel.shape + pointwise_kernel_weights = np.random.normal(size=pointwise_kernel_shape) + layer.pointwise_kernel.assign(pointwise_kernel_weights) + + bias_weights = np.random.normal(size=(filters,)) + layer.bias.assign(bias_weights) + + outputs = layer(inputs) + expected_depthwise = np_depthwise_conv1d( + inputs, + depthwise_kernel_weights, + np.zeros(4 * depth_multiplier), + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + ) + expected = np_conv1d( + expected_depthwise, + pointwise_kernel_weights, + bias_weights, + strides=1, + padding=padding, + data_format=data_format, + dilation_rate=1, + groups=1, + ) + + self.assertAllClose(outputs.shape, expected.shape) + self.assertAllClose(outputs, expected, rtol=1e-5, atol=1e-5) + + @parameterized.parameters( + { + "depth_multiplier": 5, + "filters": 5, + "kernel_size": 2, + "strides": 1, + "padding": "valid", + "data_format": "channels_last", + "dilation_rate": 1, + }, + { + "depth_multiplier": 6, + "filters": 6, + "kernel_size": 2, + "strides": 1, + "padding": "same", + "data_format": "channels_last", + "dilation_rate": (2, 2), + }, + { + "depth_multiplier": 6, + "filters": 6, + "kernel_size": (2, 2), + "strides": (2, 2), + "padding": "valid", + "data_format": "channels_last", + "dilation_rate": (1, 1), + }, + ) + def test_separable_conv2d( + self, + depth_multiplier, + filters, + kernel_size, + strides, + padding, + data_format, + dilation_rate, + ): + layer = layers.SeparableConv2D( + depth_multiplier=depth_multiplier, + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + ) + + inputs = np.random.normal(size=[2, 8, 8, 4]) + layer.build(input_shape=inputs.shape) + + depthwise_kernel_shape = layer.depthwise_kernel.shape + depthwise_kernel_weights = np.random.normal(size=depthwise_kernel_shape) + layer.depthwise_kernel.assign(depthwise_kernel_weights) + + pointwise_kernel_shape = layer.pointwise_kernel.shape + pointwise_kernel_weights = np.random.normal(size=pointwise_kernel_shape) + layer.pointwise_kernel.assign(pointwise_kernel_weights) + + bias_weights = np.random.normal(size=(filters,)) + layer.bias.assign(bias_weights) + + outputs = layer(inputs) + expected_depthwise = np_depthwise_conv2d( + inputs, + depthwise_kernel_weights, + np.zeros(4 * depth_multiplier), + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + ) + expected = np_conv2d( + expected_depthwise, + pointwise_kernel_weights, + bias_weights, + strides=1, + padding=padding, + data_format=data_format, + dilation_rate=1, + groups=1, + ) + + self.assertAllClose(outputs.shape, expected.shape) + self.assertAllClose(outputs, expected, rtol=1e-5, atol=1e-5) diff --git a/keras/src/layers/core/__init__.py b/keras/src/layers/core/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/keras/src/layers/core/dense.py b/keras/src/layers/core/dense.py new file mode 100644 index 000000000000..7eedbbcc8783 --- /dev/null +++ b/keras/src/layers/core/dense.py @@ -0,0 +1,922 @@ +import math + +import ml_dtypes + +from keras.src import activations +from keras.src import constraints +from keras.src import initializers +from keras.src import ops +from keras.src import quantizers +from keras.src import regularizers +from keras.src.api_export import keras_export +from keras.src.layers.input_spec import InputSpec +from keras.src.layers.layer import Layer +from keras.src.quantizers.quantizers import dequantize_with_sz_map + + +@keras_export("keras.layers.Dense") +class Dense(Layer): + """Just your regular densely-connected NN layer. + + `Dense` implements the operation: + `output = activation(dot(input, kernel) + bias)` + where `activation` is the element-wise activation function + passed as the `activation` argument, `kernel` is a weights matrix + created by the layer, and `bias` is a bias vector created by the layer + (only applicable if `use_bias` is `True`). + + Note: If the input to the layer has a rank greater than 2, `Dense` + computes the dot product between the `inputs` and the `kernel` along the + last axis of the `inputs` and axis 0 of the `kernel` (using `tf.tensordot`). + For example, if input has dimensions `(batch_size, d0, d1)`, then we create + a `kernel` with shape `(d1, units)`, and the `kernel` operates along axis 2 + of the `input`, on every sub-tensor of shape `(1, 1, d1)` (there are + `batch_size * d0` such sub-tensors). The output in this case will have + shape `(batch_size, d0, units)`. + + Args: + units: Positive integer, dimensionality of the output space. + activation: Activation function to use. + If you don't specify anything, no activation is applied + (ie. "linear" activation: `a(x) = x`). + use_bias: Boolean, whether the layer uses a bias vector. + kernel_initializer: Initializer for the `kernel` weights matrix. + bias_initializer: Initializer for the bias vector. + kernel_regularizer: Regularizer function applied to + the `kernel` weights matrix. + bias_regularizer: Regularizer function applied to the bias vector. + activity_regularizer: Regularizer function applied to + the output of the layer (its "activation"). + kernel_constraint: Constraint function applied to + the `kernel` weights matrix. + bias_constraint: Constraint function applied to the bias vector. + lora_rank: Optional integer. If set, the layer's forward pass + will implement LoRA (Low-Rank Adaptation) + with the provided rank. LoRA sets the layer's kernel + to non-trainable and replaces it with a delta over the + original kernel, obtained via multiplying two lower-rank + trainable matrices. This can be useful to reduce the + computation cost of fine-tuning large dense layers. + You can also enable LoRA on an existing + `Dense` layer by calling `layer.enable_lora(rank)`. + lora_alpha: Optional integer. If set, this parameter scales the + low-rank adaptation delta (computed as the product of two lower-rank + trainable matrices) during the forward pass. The delta is scaled by + `lora_alpha / lora_rank`, allowing you to fine-tune the strength of + the LoRA adjustment independently of `lora_rank`. + + Input shape: + N-D tensor with shape: `(batch_size, ..., input_dim)`. + The most common situation would be + a 2D input with shape `(batch_size, input_dim)`. + + Output shape: + N-D tensor with shape: `(batch_size, ..., units)`. + For instance, for a 2D input with shape `(batch_size, input_dim)`, + the output would have shape `(batch_size, units)`. + """ + + def __init__( + self, + units, + activation=None, + use_bias=True, + kernel_initializer="glorot_uniform", + bias_initializer="zeros", + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + bias_constraint=None, + lora_rank=None, + lora_alpha=None, + **kwargs, + ): + super().__init__(activity_regularizer=activity_regularizer, **kwargs) + self.units = units + self.activation = activations.get(activation) + self.use_bias = use_bias + self.kernel_initializer = initializers.get(kernel_initializer) + self.bias_initializer = initializers.get(bias_initializer) + self.kernel_regularizer = regularizers.get(kernel_regularizer) + self.bias_regularizer = regularizers.get(bias_regularizer) + self.kernel_constraint = constraints.get(kernel_constraint) + self.bias_constraint = constraints.get(bias_constraint) + self.lora_rank = lora_rank + self.lora_alpha = lora_alpha if lora_alpha is not None else lora_rank + self.lora_enabled = False + self.input_spec = InputSpec(min_ndim=2) + self.supports_masking = True + + def build(self, input_shape): + kernel_shape = (input_shape[-1], self.units) + if self.quantization_mode: + self.quantized_build(kernel_shape, mode=self.quantization_mode) + if self.quantization_mode not in ("int8", "int4", "gptq"): + # If the layer is quantized to int8 or int4, `self._kernel` will be + # added in `self._int8_build` or `_int4_build`. Therefore, we skip + # it here. + self._kernel = self.add_weight( + name="kernel", + shape=kernel_shape, + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, + constraint=self.kernel_constraint, + ) + if self.use_bias: + self.bias = self.add_weight( + name="bias", + shape=(self.units,), + initializer=self.bias_initializer, + regularizer=self.bias_regularizer, + constraint=self.bias_constraint, + ) + else: + self.bias = None + self.input_spec = InputSpec(min_ndim=2, axes={-1: input_shape[-1]}) + self.built = True + if self.lora_rank: + self.enable_lora(self.lora_rank) + + @property + def kernel(self): + from keras.src.quantizers import gptq_core + + if not self.built: + raise AttributeError( + "You must build the layer before accessing `kernel`." + ) + + mode = self.quantization_mode + is_gptq = mode == "gptq" + is_int4 = mode == "int4" + calibrated = bool(getattr(self, "is_gptq_calibrated", False)) + gptq_bits = ( + gptq_core.get_weight_bits_for_layer(self, None) if is_gptq else None + ) + + # Decide the source tensor first (packed vs already-quantized vs plain + # kernel) + if is_gptq and calibrated and gptq_bits != 4: + # calibrated GPTQ, not 4-bit, no unpacking needed + kernel = self.quantized_kernel + else: + # Start with the stored kernel + kernel = getattr(self, "_kernel", None) + + # Handle int4 unpacking cases in one place + if is_int4: + kernel = quantizers.unpack_int4(kernel, self._orig_input_dim) + elif is_gptq and calibrated and gptq_bits == 4: + kernel = quantizers.unpack_int4( + self.quantized_kernel, + orig_len=self.units, + axis=0, + dtype="uint8", + ) + + # Apply LoRA once at the end. + if self.lora_enabled: + kernel = kernel + (self.lora_alpha / self.lora_rank) * ops.matmul( + self.lora_kernel_a, self.lora_kernel_b + ) + + return kernel + + def call(self, inputs, training=None): + x = ops.matmul(inputs, self.kernel) + if self.bias is not None: + x = ops.add(x, self.bias) + if self.activation is not None: + x = self.activation(x) + return x + + def compute_output_shape(self, input_shape): + output_shape = list(input_shape) + output_shape[-1] = self.units + return tuple(output_shape) + + def enable_lora( + self, + rank, + lora_alpha=None, + a_initializer="he_uniform", + b_initializer="zeros", + ): + if self.kernel_constraint: + raise ValueError( + "Lora is incompatible with kernel constraints. " + "In order to enable lora on this layer, remove the " + "`kernel_constraint` argument." + ) + if not self.built: + raise ValueError( + "Cannot enable lora on a layer that isn't yet built." + ) + if self.lora_enabled: + raise ValueError( + "lora is already enabled. This can only be done once per layer." + ) + if self.quantization_mode == "gptq": + raise NotImplementedError( + "lora is not currently supported with GPTQ quantization." + ) + self._tracker.unlock() + # Determine the correct input dimension for the LoRA A matrix. When + # the layer has been int4-quantized, `self._kernel` stores a *packed* + # representation whose first dimension is `ceil(input_dim/2)`. We + # saved the true, *unpacked* input dimension in `self._orig_input_dim` + # during quantization. Use it if available; otherwise fall back to the + # first dimension of `self.kernel`. + if self.quantization_mode == "int4" and hasattr( + self, "_orig_input_dim" + ): + input_dim_for_lora = self._orig_input_dim + else: + input_dim_for_lora = self.kernel.shape[0] + + self.lora_kernel_a = self.add_weight( + name="lora_kernel_a", + shape=(input_dim_for_lora, rank), + initializer=initializers.get(a_initializer), + regularizer=self.kernel_regularizer, + ) + self.lora_kernel_b = self.add_weight( + name="lora_kernel_b", + shape=(rank, self.kernel.shape[1]), + initializer=initializers.get(b_initializer), + regularizer=self.kernel_regularizer, + ) + self._kernel.trainable = False + self._tracker.lock() + self.lora_enabled = True + self.lora_rank = rank + self.lora_alpha = lora_alpha if lora_alpha is not None else rank + + def save_own_variables(self, store): + # Do nothing if the layer isn't yet built + if not self.built: + return + mode = self.quantization_mode + if mode not in self.quantization_variable_spec: + raise self._quantization_mode_error(mode) + + # Kernel plus optional merged LoRA-aware scale (returns (kernel, None) + # for None/gptq) + kernel_value, merged_kernel_scale = self._get_kernel_with_merged_lora() + + # Save the variables using the name as the key. + if mode != "gptq": + store["kernel"] = kernel_value + if self.bias is not None: + store["bias"] = self.bias + for name in self.quantization_variable_spec[mode]: + if name == "kernel_scale" and mode in ("int4", "int8"): + # For int4/int8, the merged LoRA scale (if any) comes from + # `_get_kernel_with_merged_lora()` + store[name] = merged_kernel_scale + else: + store[name] = getattr(self, name) + + def load_own_variables(self, store): + if not self.lora_enabled: + self._check_load_own_variables(store) + # Do nothing if the layer isn't yet built + if not self.built: + return + mode = self.quantization_mode + if mode not in self.quantization_variable_spec: + raise self._quantization_mode_error(mode) + + # Determine whether to use the legacy loading method. + if "0" in store: + return self._legacy_load_own_variables(store) + + # Load the variables using the name as the key. + if mode != "gptq": + self._kernel.assign(store["kernel"]) + if self.bias is not None: + self.bias.assign(store["bias"]) + for name in self.quantization_variable_spec[mode]: + getattr(self, name).assign(store[name]) + if self.lora_enabled: + self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape)) + self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape)) + + def _legacy_load_own_variables(self, store): + # The keys of the `store` will be saved as determined because the + # default ordering will change after quantization + mode = self.quantization_mode + targets = [] + if mode != "gptq": + targets.append(self._kernel) + if self.bias is not None: + targets.append(self.bias) + targets.extend( + getattr(self, name) + for name in self.quantization_variable_spec[mode] + ) + for i, variable in enumerate(targets): + variable.assign(store[str(i)]) + if self.lora_enabled: + self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape)) + self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape)) + + def get_config(self): + base_config = super().get_config() + config = { + "units": self.units, + "activation": activations.serialize(self.activation), + "use_bias": self.use_bias, + "kernel_initializer": initializers.serialize( + self.kernel_initializer + ), + "bias_initializer": initializers.serialize(self.bias_initializer), + "kernel_regularizer": regularizers.serialize( + self.kernel_regularizer + ), + "bias_regularizer": regularizers.serialize(self.bias_regularizer), + "kernel_constraint": constraints.serialize(self.kernel_constraint), + "bias_constraint": constraints.serialize(self.bias_constraint), + } + if self.lora_rank: + config["lora_rank"] = self.lora_rank + config["lora_alpha"] = self.lora_alpha + return {**base_config, **config} + + def _check_load_own_variables(self, store): + all_vars = self._trainable_variables + self._non_trainable_variables + if len(store.keys()) != len(all_vars): + if len(all_vars) == 0 and not self.built: + raise ValueError( + f"Layer '{self.name}' was never built " + "and thus it doesn't have any variables. " + f"However the weights file lists {len(store.keys())} " + "variables for this layer.\n" + "In most cases, this error indicates that either:\n\n" + "1. The layer is owned by a parent layer that " + "implements a `build()` method, but calling the " + "parent's `build()` method did NOT create the state of " + f"the child layer '{self.name}'. A `build()` method " + "must create ALL state for the layer, including " + "the state of any children layers.\n\n" + "2. You need to implement " + "the `def build_from_config(self, config)` method " + f"on layer '{self.name}', to specify how to rebuild " + "it during loading. " + "In this case, you might also want to implement the " + "method that generates the build config at saving time, " + "`def get_build_config(self)`. " + "The method `build_from_config()` is meant " + "to create the state " + "of the layer (i.e. its variables) upon deserialization.", + ) + raise ValueError( + f"Layer '{self.name}' expected {len(all_vars)} variables, " + "but received " + f"{len(store.keys())} variables during loading. " + f"Expected: {[v.name for v in all_vars]}" + ) + + @property + def quantization_variable_spec(self): + """Returns a dict mapping quantization modes to variable names. + + This spec is used by `save_own_variables` and `load_own_variables` to + determine which variables should be saved/loaded for each quantization + mode. + """ + return { + None: [], + "int8": ["kernel_scale"], + "int4": ["kernel_scale"], + "float8": [ + "inputs_scale", + "inputs_amax_history", + "kernel_scale", + "kernel_amax_history", + "outputs_grad_scale", + "outputs_grad_amax_history", + ], + "gptq": [ + "quantized_kernel", + "kernel_scale", + "kernel_zero", + "g_idx", + ], + } + + def quantized_build(self, kernel_shape, mode, config=None): + if mode == "int8": + self._int8_build(kernel_shape) + elif mode == "int4": + self._int4_build(kernel_shape) + elif mode == "float8": + self._float8_build() + elif mode == "gptq": + self._gptq_build(kernel_shape, config) + else: + raise self._quantization_mode_error(mode) + self._is_quantized = True + + def _int8_build(self, kernel_shape): + self.inputs_quantizer = quantizers.AbsMaxQuantizer(axis=-1) + self._kernel = self.add_weight( + name="kernel", + shape=kernel_shape, + initializer="zeros", + dtype="int8", + trainable=False, + ) + self.kernel_scale = self.add_weight( + name="kernel_scale", + shape=(self.units,), + initializer="ones", + trainable=False, + ) + + def _gptq_build(self, kernel_shape, config): + from keras.src.quantizers import gptq_core + + # Ensures the forward pass uses the original high-precision kernel + # until calibration has been performed. + self.is_gptq_calibrated = False + self.kernel_shape = kernel_shape + + weight_bits = gptq_core.get_weight_bits_for_layer(self, config) + # For 4-bit weights, we pack two values per byte. + units = ( + (kernel_shape[1] + 1) // 2 if weight_bits == 4 else kernel_shape[1] + ) + + self.quantized_kernel = self.add_weight( + name="kernel", + shape=(units, kernel_shape[0]), + initializer="zeros", + dtype="uint8", + trainable=False, + ) + + group_size = gptq_core.get_group_size_for_layer(self, config) + n_groups = ( + 1 + if group_size == -1 + else math.ceil(self.kernel_shape[0] / group_size) + ) + self.kernel_scale = self.add_weight( + name="kernel_scale", + shape=(self.units, n_groups), + initializer="ones", + trainable=False, + ) + self.kernel_zero = self.add_weight( + name="kernel_zero", + shape=(self.units, n_groups), + initializer="zeros", + dtype="uint8", + trainable=False, + ) + self.g_idx = self.add_weight( + name="g_idx", + shape=(self.kernel_shape[0],), + initializer="zeros", + dtype="float32", + trainable=False, + ) + + def _gptq_call(self, inputs, training=False): + from keras.src.quantizers import gptq_core + + if not self.is_gptq_calibrated: + W = self._kernel + else: + should_unpack = ( + gptq_core.get_weight_bits_for_layer(self, config=None) == 4 + ) + W = ( + quantizers.unpack_int4( + self.quantized_kernel, + orig_len=self.units, + axis=0, + dtype="uint8", + ) + if should_unpack + else self.quantized_kernel + ) + W = ops.transpose( + dequantize_with_sz_map( + W, + self.kernel_scale, + self.kernel_zero, + self.g_idx, + ) + ) + + y = ops.matmul(inputs, W) + if self.bias is not None: + y = ops.add(y, self.bias) + if self.activation is not None: + y = self.activation(y) + return y + + def _int4_build(self, kernel_shape): + """Build variables for int4 quantization. + + `kernel_shape` is the *original* float32 kernel shape + `(input_dim, units)`. We allocate the stored kernel with rows + `ceil(input_dim/2)` because two int4 values are packed into a single + int8 byte. + """ + # Per-channel int8 quantizer for the last axis (features). + self.inputs_quantizer = quantizers.AbsMaxQuantizer( + axis=-1, + ) + input_dim, output_dim = kernel_shape + packed_rows = (input_dim + 1) // 2 # ceil for odd dims + + # Kernel is stored *packed*: each int8 byte contains two int4 values. + self._kernel = self.add_weight( + name="kernel", + shape=(packed_rows, output_dim), + initializer="zeros", + dtype="int8", + trainable=False, + ) + # One scale per output unit (per-channel). + self.kernel_scale = self.add_weight( + name="kernel_scale", + shape=(self.units,), + initializer="ones", + trainable=False, + ) + # Record original input_dim for unpacking at runtime. + self._orig_input_dim = input_dim + + def _float8_build(self): + from keras.src.dtype_policies import QuantizedFloat8DTypePolicy + + # If `self.dtype_policy` is not QuantizedFloat8DTypePolicy, then set + # `amax_history_length` to its default value. + amax_history_length = getattr( + self.dtype_policy, + "amax_history_length", + QuantizedFloat8DTypePolicy.default_amax_history_length, + ) + # We set `trainable=True` because we will use the gradients to overwrite + # these variables + scale_kwargs = { + "shape": (), + "initializer": "ones", + "dtype": "float32", # Always be float32 + "trainable": True, + "autocast": False, + "overwrite_with_gradient": True, + } + amax_history_kwargs = { + "shape": (amax_history_length,), + "initializer": "zeros", + "dtype": "float32", # Always be float32 + "trainable": True, + "autocast": False, + "overwrite_with_gradient": True, + } + self.inputs_scale = self.add_weight(name="inputs_scale", **scale_kwargs) + self.inputs_amax_history = self.add_weight( + name="inputs_amax_history", **amax_history_kwargs + ) + self.kernel_scale = self.add_weight(name="kernel_scale", **scale_kwargs) + self.kernel_amax_history = self.add_weight( + name="kernel_amax_history", **amax_history_kwargs + ) + self.outputs_grad_scale = self.add_weight( + name="outputs_grad_scale", **scale_kwargs + ) + self.outputs_grad_amax_history = self.add_weight( + name="outputs_grad_amax_history", **amax_history_kwargs + ) + + def _int8_call(self, inputs, training=None): + @ops.custom_gradient + def matmul_with_inputs_gradient(inputs, kernel, kernel_scale): + """Custom gradient function to handle the int8 quantized weights. + + Automatic differentiation will not know how to handle the int8 + quantized weights. So a custom gradient function is needed to + handle the int8 quantized weights. + + The custom gradient function will use the dequantized kernel to + compute the gradient. + """ + + def grad_fn(*args, upstream=None): + if upstream is None: + (upstream,) = args + float_kernel = ops.divide( + ops.cast(kernel, dtype=self.compute_dtype), + kernel_scale, + ) + inputs_grad = ops.matmul(upstream, ops.transpose(float_kernel)) + return (inputs_grad, None, None) + + inputs, inputs_scale = self.inputs_quantizer(inputs) + x = ops.matmul(inputs, kernel) + # De-scale outputs + x = ops.cast(x, self.compute_dtype) + x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale)) + return x, grad_fn + + x = matmul_with_inputs_gradient( + inputs, + ops.convert_to_tensor(self._kernel), + ops.convert_to_tensor(self.kernel_scale), + ) + if self.lora_enabled: + lora_x = ops.matmul(inputs, self.lora_kernel_a) + lora_x = ops.matmul(lora_x, self.lora_kernel_b) + x = ops.add(x, (self.lora_alpha / self.lora_rank) * lora_x) + if self.bias is not None: + x = ops.add(x, self.bias) + if self.activation is not None: + x = self.activation(x) + return x + + def _int4_call(self, inputs, training=None): + """Forward pass for int4 quantized Dense layer.""" + + @ops.custom_gradient + def matmul_with_inputs_gradient(inputs, kernel, kernel_scale): + """Custom gradient function for int4 quantized weights. + + Automatic differentiation will not know how to handle the + int4 quantized weights. So a custom gradient function is needed + to handle the int4 quantized weights. + + The custom gradient function will use the dequantized kernel to + compute the gradient. + """ + + unpacked_kernel = quantizers.unpack_int4( + kernel, self._orig_input_dim + ) + + def grad_fn(*args, upstream=None): + if upstream is None: + (upstream,) = args + float_kernel = ops.divide( + ops.cast(unpacked_kernel, dtype=self.compute_dtype), + kernel_scale, + ) + inputs_grad = ops.matmul(upstream, ops.transpose(float_kernel)) + return (inputs_grad, None, None) + + inputs, inputs_scale = self.inputs_quantizer(inputs) + x = ops.matmul(inputs, unpacked_kernel) + x = ops.cast(x, self.compute_dtype) + x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale)) + return x, grad_fn + + x = matmul_with_inputs_gradient( + inputs, + ops.convert_to_tensor(self._kernel), + ops.convert_to_tensor(self.kernel_scale), + ) + + if self.lora_enabled: + lora_x = ops.matmul(inputs, self.lora_kernel_a) + lora_x = ops.matmul(lora_x, self.lora_kernel_b) + x = ops.add(x, (self.lora_alpha / self.lora_rank) * lora_x) + + # Add bias and activation + if self.bias is not None: + x = ops.add(x, self.bias) + if self.activation is not None: + x = self.activation(x) + return x + + def _float8_call(self, inputs, training=None): + if self.lora_enabled: + raise NotImplementedError( + "Currently, `_float8_call` doesn't support LoRA" + ) + + @ops.custom_gradient + def quantized_dequantize_inputs(inputs, scale, amax_history): + if training: + new_scale = quantizers.compute_float8_scale( + ops.max(amax_history, axis=0), + scale, + ops.cast( + float(ml_dtypes.finfo("float8_e4m3fn").max), "float32" + ), + ) + new_amax_history = quantizers.compute_float8_amax_history( + inputs, amax_history + ) + else: + new_scale = None + new_amax_history = None + qdq_inputs = quantizers.quantize_and_dequantize( + inputs, scale, "float8_e4m3fn", self.compute_dtype + ) + + def grad(*args, upstream=None, variables=None): + if upstream is None: + (upstream,) = args + return upstream, new_scale, new_amax_history + + return qdq_inputs, grad + + @ops.custom_gradient + def quantized_dequantize_outputs(outputs, scale, amax_history): + """Quantize-dequantize the output gradient but not the output.""" + + def grad(*args, upstream=None, variables=None): + if upstream is None: + (upstream,) = args + new_scale = quantizers.compute_float8_scale( + ops.max(amax_history, axis=0), + scale, + ops.cast( + float(ml_dtypes.finfo("float8_e5m2").max), "float32" + ), + ) + qdq_upstream = quantizers.quantize_and_dequantize( + upstream, scale, "float8_e5m2", self.compute_dtype + ) + new_amax_history = quantizers.compute_float8_amax_history( + upstream, amax_history + ) + return qdq_upstream, new_scale, new_amax_history + + return outputs, grad + + x = ops.matmul( + quantized_dequantize_inputs( + inputs, + ops.convert_to_tensor(self.inputs_scale), + ops.convert_to_tensor(self.inputs_amax_history), + ), + quantized_dequantize_inputs( + ops.convert_to_tensor(self._kernel), + ops.convert_to_tensor(self.kernel_scale), + ops.convert_to_tensor(self.kernel_amax_history), + ), + ) + # `quantized_dequantize_outputs` is placed immediately after + # `ops.matmul` for the sake of pattern matching in gemm_rewrite. That + # way, the qdq will be adjacent to the corresponding matmul_bprop in the + # bprop. + x = quantized_dequantize_outputs( + x, + ops.convert_to_tensor(self.outputs_grad_scale), + ops.convert_to_tensor(self.outputs_grad_amax_history), + ) + if self.bias is not None: + # Under non-mixed precision cases, F32 bias has to be converted to + # BF16 first to get the biasAdd fusion support. ref. PR + # https://github.com/tensorflow/tensorflow/pull/60306 + bias = self.bias + if self.dtype_policy.compute_dtype == "float32": + bias_bf16 = ops.cast(bias, "bfloat16") + bias = ops.cast(bias_bf16, bias.dtype) + x = ops.add(x, bias) + if self.activation is not None: + x = self.activation(x) + return x + + def quantize(self, mode, type_check=True, config=None): + # Prevent quantization of the subclasses + if type_check and (type(self) is not Dense): + raise self._not_implemented_error(self.quantize) + + kernel_shape = self._kernel.shape + if mode == "int8": + kernel_value, kernel_scale = quantizers.abs_max_quantize( + self._kernel, axis=0, to_numpy=True + ) + kernel_scale = ops.squeeze(kernel_scale, axis=0) + del self._kernel + # Build variables for int8 mode + self.quantized_build(kernel_shape, mode) + self._kernel.assign(kernel_value) + self.kernel_scale.assign(kernel_scale) + elif mode == "int4": + # 1. Quantize to int4 values (still int8 dtype, range [-8,7]) + kernel_value_int4, kernel_scale = quantizers.abs_max_quantize( + self._kernel, + axis=0, + value_range=(-8, 7), + dtype="int8", + to_numpy=True, + ) + kernel_scale = ops.squeeze(kernel_scale, axis=0) + # 2. Pack two int4 values into a single int8 byte. + packed_kernel_value, _, _ = quantizers.pack_int4(kernel_value_int4) + del self._kernel + # Build variables using the original kernel shape; _int4_build will + # compute the packed shape internally. + self.quantized_build(kernel_shape, mode) + # Assign packed values. + self._kernel.assign(packed_kernel_value) + self.kernel_scale.assign(kernel_scale) + elif mode == "gptq": + self.quantized_build(kernel_shape, mode, config) + elif mode == "float8": + self.quantized_build(kernel_shape, mode) + else: + raise self._quantization_mode_error(mode) + + # Set new dtype policy only for modes that already have a policy. + if self.dtype_policy.quantization_mode is None: + from keras.src import dtype_policies # local import to avoid cycle + + policy_name = mode + if mode == "gptq": + policy_name = config.dtype_policy_string() + policy = dtype_policies.get( + f"{policy_name}_from_{self.dtype_policy.name}" + ) + self.dtype_policy = policy + + def _get_kernel_with_merged_lora(self): + """Returns the kernel with LoRA matrices merged, for serialization. + + This method is called by `save_own_variables` to produce a single + kernel tensor that includes the adaptations from LoRA. This is useful + for deploying the model or for continuing training after permanently + applying the LoRA update. + + If the layer is quantized (`int8` or `int4`), the process is: + 1. Dequantize the base kernel to float. + 2. Compute the LoRA delta (`lora_kernel_a @ lora_kernel_b`) and add + it to the dequantized kernel. + 3. Re-quantize the merged result back to the original quantized + type (`int8` or packed `int4`), calculating a new scale factor. + + If the layer is not quantized, this method returns the result of the + `kernel` property (which computes the merge in floating-point) and a + scale of `None`. + + If LoRA is not enabled, it returns the original kernel and scale + without modification. + + Returns: + A tuple `(kernel_value, kernel_scale)`: + `kernel_value`: The merged kernel. A quantized tensor if + quantization is active, otherwise a high precision tensor. + `kernel_scale`: The quantization scale for the merged kernel. + This is `None` if the layer is not quantized. + """ + if self.dtype_policy.quantization_mode in (None, "gptq"): + return self.kernel, None + + kernel_value = self._kernel + kernel_scale = self.kernel_scale + + if not self.lora_enabled: + return kernel_value, kernel_scale + + # Dequantize, Merge, and Re-quantize + + # Dequantize kernel to float + if self.quantization_mode == "int4": + unpacked_kernel = quantizers.unpack_int4( + kernel_value, self._orig_input_dim + ) + float_kernel = ops.divide( + ops.cast(unpacked_kernel, self.compute_dtype), + kernel_scale, + ) + quant_range = (-8, 7) + elif self.quantization_mode == "int8": + float_kernel = ops.divide( + ops.cast(kernel_value, self.compute_dtype), kernel_scale + ) + quant_range = (-127, 127) + else: + raise ValueError( + f"Unsupported quantization mode: {self.quantization_mode}" + ) + + # Merge LoRA weights in float domain + lora_delta = (self.lora_alpha / self.lora_rank) * ops.matmul( + self.lora_kernel_a, self.lora_kernel_b + ) + merged_float_kernel = ops.add(float_kernel, lora_delta) + + # Requantize + requantized_kernel, kernel_scale = quantizers.abs_max_quantize( + merged_float_kernel, + axis=0, + value_range=quant_range, + dtype="int8", + to_numpy=True, + ) + kernel_scale = ops.squeeze(kernel_scale, axis=0) + + # Pack if int4 + if self.quantization_mode == "int4": + kernel_value, _, _ = quantizers.pack_int4(requantized_kernel) + else: + kernel_value = requantized_kernel + return kernel_value, kernel_scale diff --git a/keras/src/layers/core/dense_test.py b/keras/src/layers/core/dense_test.py new file mode 100644 index 000000000000..11b7587195c4 --- /dev/null +++ b/keras/src/layers/core/dense_test.py @@ -0,0 +1,955 @@ +import os + +import numpy as np +import pytest +from absl.testing import parameterized + +from keras.src import backend +from keras.src import constraints +from keras.src import export +from keras.src import layers +from keras.src import models +from keras.src import ops +from keras.src import optimizers +from keras.src import quantizers +from keras.src import random +from keras.src import saving +from keras.src import testing +from keras.src.backend.common import keras_tensor +from keras.src.quantizers.gptq_config import GPTQConfig + + +class DenseTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_dense_basics(self): + # 2D case, no bias. + self.run_layer_test( + layers.Dense, + init_kwargs={ + "units": 4, + "activation": "relu", + "kernel_initializer": "random_uniform", + "bias_initializer": "ones", + "use_bias": False, + }, + input_shape=(2, 3), + expected_output_shape=(2, 4), + expected_num_trainable_weights=1, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=True, + ) + # 3D case, some regularizers. + self.run_layer_test( + layers.Dense, + init_kwargs={ + "units": 5, + "activation": "sigmoid", + "kernel_regularizer": "l2", + "bias_regularizer": "l2", + }, + input_shape=(2, 3, 4), + expected_output_shape=(2, 3, 5), + expected_num_trainable_weights=2, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=2, # we have 2 regularizers. + supports_masking=True, + ) + + def test_dense_correctness(self): + # With bias and activation. + layer = layers.Dense(units=2, activation="relu") + layer.build((1, 2)) + layer.set_weights( + [ + np.array([[1.0, -2.0], [3.0, -4.0]]), + np.array([5.0, -6.0]), + ] + ) + inputs = np.array( + [[-1.0, 2.0]], + ) + self.assertAllClose(layer(inputs), [[10.0, 0.0]]) + + # Just a kernel matmul. + layer = layers.Dense(units=2, use_bias=False) + layer.build((1, 2)) + layer.set_weights( + [ + np.array([[1.0, -2.0], [3.0, -4.0]]), + ] + ) + inputs = np.array( + [[-1.0, 2.0]], + ) + self.assertEqual(layer.bias, None) + self.assertAllClose(layer(inputs), [[5.0, -6.0]]) + + def test_dense_errors(self): + with self.assertRaisesRegex(ValueError, "incompatible with the layer"): + layer = layers.Dense(units=2, activation="relu") + layer(keras_tensor.KerasTensor((1, 2))) + layer(keras_tensor.KerasTensor((1, 3))) + + @pytest.mark.skipif( + not backend.SUPPORTS_SPARSE_TENSORS, + reason="Backend does not support sparse tensors.", + ) + def test_dense_sparse(self): + import tensorflow as tf + + self.run_layer_test( + layers.Dense, + init_kwargs={ + "units": 4, + }, + input_shape=(2, 3), + input_sparse=True, + expected_output_shape=(2, 4), + expected_num_trainable_weights=2, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + ) + + inputs = 4 * backend.random.uniform((10, 10)) + inputs = tf.sparse.from_dense(tf.nn.dropout(inputs, 0.8)) + + inputs = np.random.random((10, 10)).astype("float32") + inputs = np.multiply(inputs, inputs >= 0.8) + + if backend.backend() == "tensorflow": + import tensorflow as tf + + inputs = tf.sparse.from_dense(inputs) + elif backend.backend() == "jax": + import jax.experimental.sparse as jax_sparse + + inputs = jax_sparse.BCOO.fromdense(inputs) + else: + self.fail(f"Sparse is unsupported with backend {backend.backend()}") + + layer = layers.Dense(units=10) + outputs = layer(inputs) + + # Verify the computation is the same as if it had been a dense tensor + expected_outputs = ops.add( + ops.matmul( + backend.convert_to_tensor(inputs, sparse=False), layer.kernel + ), + layer.bias, + ) + self.assertAllClose(outputs, expected_outputs) + + # Verify the gradient is sparse + if backend.backend() == "tensorflow": + import tensorflow as tf + + with tf.GradientTape() as g: + outputs = layer(inputs) + + self.assertIsInstance( + g.gradient(outputs, layer.kernel), tf.IndexedSlices + ) + + def test_dense_no_activation(self): + layer = layers.Dense(units=2, use_bias=False, activation=None) + layer.build((1, 2)) + layer.set_weights( + [ + np.array([[1.0, -2.0], [3.0, -4.0]]), + ] + ) + inputs = np.array( + [[-1.0, 2.0]], + ) + self.assertEqual(layer.bias, None) + self.assertAllClose(layer(inputs), [[5.0, -6.0]]) + + def test_dense_without_activation_set(self): + layer = layers.Dense(units=2, use_bias=False) + layer.build((1, 2)) + layer.set_weights( + [ + np.array([[1.0, -2.0], [3.0, -4.0]]), + ] + ) + layer.activation = None + inputs = np.array( + [[-1.0, 2.0]], + ) + self.assertEqual(layer.bias, None) + self.assertAllClose(layer(inputs), [[5.0, -6.0]]) + + def test_dense_with_activation(self): + layer = layers.Dense(units=2, use_bias=False, activation="relu") + layer.build((1, 2)) + layer.set_weights( + [ + np.array([[1.0, -2.0], [3.0, -4.0]]), + ] + ) + + inputs = np.array( + [[-1.0, 2.0]], + ) + output = layer(inputs) + expected_output = np.array([[5.0, 0.0]]) + self.assertAllClose(output, expected_output) + + def test_dense_constraints(self): + layer = layers.Dense(units=2, kernel_constraint="non_neg") + layer.build((None, 2)) + self.assertIsInstance(layer.kernel.constraint, constraints.NonNeg) + layer = layers.Dense(units=2, bias_constraint="non_neg") + layer.build((None, 2)) + self.assertIsInstance(layer.bias.constraint, constraints.NonNeg) + + @pytest.mark.requires_trainable_backend + def test_enable_lora(self): + layer = layers.Dense(units=16) + layer.build((None, 8)) + layer.enable_lora(4) + self.assertLen(layer.trainable_weights, 3) + self.assertLen(layer.non_trainable_weights, 1) + if backend.backend() == "torch": + self.assertLen(layer.torch_params, 4) + # Try eager call + x = np.random.random((64, 8)) + y = np.random.random((64, 16)) + _ = layer(x[:2]) + + init_lora_a_kernel_value = layer.lora_kernel_a.numpy() + init_lora_b_kernel_value = layer.lora_kernel_b.numpy() + + # Try calling fit() + model = models.Sequential( + [ + layer, + ] + ) + model.compile(optimizer="sgd", loss="mse") + model.fit(x, y) + + final_lora_a_kernel_value = layer.lora_kernel_a.numpy() + final_lora_b_kernel_value = layer.lora_kernel_b.numpy() + diff_a = np.max( + np.abs(init_lora_a_kernel_value - final_lora_a_kernel_value) + ) + diff_b = np.max( + np.abs(init_lora_b_kernel_value - final_lora_b_kernel_value) + ) + self.assertGreater(diff_a, 0.0) + self.assertGreater(diff_b, 0.0) + + # Try saving and reloading the model + temp_filepath = os.path.join(self.get_temp_dir(), "lora_model.keras") + model.save(temp_filepath) + + new_model = saving.load_model(temp_filepath) + self.assertTrue(new_model.layers[0].lora_enabled) + self.assertAllClose(model.predict(x), new_model.predict(x)) + + # Try saving and reloading the model's weights only + temp_filepath = os.path.join( + self.get_temp_dir(), "lora_model.weights.h5" + ) + model.save_weights(temp_filepath) + + # Load the file into a fresh, non-lora model + new_model = models.Sequential( + [ + layers.Dense(units=16), + ] + ) + new_model.build((None, 8)) + new_model.load_weights(temp_filepath) + self.assertAllClose(model.predict(x), new_model.predict(x)) + + # Try loading a normal checkpoint into a lora model + new_model.save_weights(temp_filepath) + model.load_weights(temp_filepath) + self.assertAllClose(model.predict(x), new_model.predict(x)) + + @pytest.mark.requires_trainable_backend + def test_enable_lora_with_alpha(self): + # Create a `Dense` layer and build it. + layer = layers.Dense(units=8) + layer.build((None, 4)) + + # Enable LoRA with `rank`=2 and `lora_alpha`=3.0. + layer.enable_lora(2, lora_alpha=3.0) + self.assertEqual(layer.lora_rank, 2) + self.assertEqual(layer.lora_alpha, 3.0) + + # Manually compute the expected effective kernel: + # `effective_kernel_expected` = `base_kernel` + + # `lora_alpha / lora_rank` * `lora_kernel_a @ lora_kernel_b` + base_kernel = ops.convert_to_numpy(layer._kernel) + lora_update = np.matmul( + ops.convert_to_numpy(layer.lora_kernel_a), + ops.convert_to_numpy(layer.lora_kernel_b), + ) + effective_kernel_expected = base_kernel + (3.0 / 2) * lora_update + + # Verify that the effective kernel matches expectation. + self.assertAllClose( + ops.convert_to_numpy(layer.kernel), effective_kernel_expected + ) + + @pytest.mark.requires_trainable_backend + def test_lora_weight_name(self): + class MyModel(models.Model): + def __init__(self): + super().__init__(name="mymodel") + self.dense = layers.Dense(16, name="dense") + + def build(self, input_shape): + self.dense.build(input_shape) + + def call(self, x): + return self.dense(x) + + model = MyModel() + model.build((None, 8)) + model.dense.enable_lora(4) + self.assertEqual( + model.dense.lora_kernel_a.path, "mymodel/dense/lora_kernel_a" + ) + + @pytest.mark.requires_trainable_backend + def test_lora_rank_argument(self): + self.run_layer_test( + layers.Dense, + init_kwargs={ + "units": 5, + "activation": "sigmoid", + "kernel_regularizer": "l2", + "lora_rank": 2, + }, + input_shape=(2, 3, 4), + expected_output_shape=(2, 3, 5), + expected_num_trainable_weights=3, + expected_num_non_trainable_weights=1, + expected_num_seed_generators=0, + expected_num_losses=2, # we have 2 regularizers. + supports_masking=True, + ) + + def test_enable_lora_with_kernel_constraint(self): + layer = layers.Dense(units=2, kernel_constraint="max_norm") + with self.assertRaisesRegex( + ValueError, "incompatible with kernel constraints" + ): + layer.enable_lora(rank=2) + + def test_enable_lora_on_unbuilt_layer(self): + layer = layers.Dense(units=2) + with self.assertRaisesRegex( + ValueError, "Cannot enable lora on a layer that isn't yet built" + ): + layer.enable_lora(rank=2) + + def test_enable_lora_when_already_enabled(self): + layer = layers.Dense(units=2) + layer.build((None, 2)) + layer.enable_lora(rank=2) + with self.assertRaisesRegex(ValueError, "lora is already enabled"): + layer.enable_lora(rank=2) + + # Test quantization-related methods. + + @parameterized.named_parameters( + ("int8", "int8", 1e-3), + ("int4", "int4", 2e-3), + ) + def test_quantize_int(self, mode, error_threshold): + if mode == "int4" and testing.tensorflow_uses_gpu(): + self.skipTest("Segfault") + layer = layers.Dense(units=16) + layer.build((None, 8)) + x = np.random.random((2, 8)) + y_float = layer(x) + layer.quantize(mode) + + # Verify the dtype of the weights. + # The kernel's data type is int8, despite the int4 quantization, because + # we pack the int4 values into int8. + self.assertEqual(backend.standardize_dtype(layer._kernel.dtype), "int8") + self.assertEqual( + backend.standardize_dtype(layer.kernel_scale.dtype), + layer.variable_dtype, + ) + + # Verify the correctness of the outputs. + y_quantized = layer(x) + mse = ops.mean(ops.square(y_float - y_quantized)) + self.assertLess(mse, error_threshold) # A weak correctness test + + # Check model save / load round-trip. + model = models.Sequential([layer]) + temp_filepath = os.path.join( + self.get_temp_dir(), "quantized_model.keras" + ) + model.save(temp_filepath) + new_model = saving.load_model(temp_filepath) + self.assertAllClose(model.predict(x), new_model.predict(x)) + + # Check weights-only save / load round-trip. + temp_filepath = os.path.join( + self.get_temp_dir(), "quantized_model.weights.h5" + ) + model.save_weights(temp_filepath) + new_model = models.Sequential([layers.Dense(units=16)]) + new_model.build((None, 8)) + new_model.quantize(mode) + new_model.load_weights(temp_filepath) + self.assertAllClose(model.predict(x), new_model.predict(x)) + + @parameterized.named_parameters( + ("int8", "int8"), + ("int4", "int4"), + ("float8", "float8"), + ) + def test_quantize_on_unbuilt_layer(self, mode): + layer = layers.Dense(units=2) + with self.assertRaisesRegex( + ValueError, "Cannot quantize a layer that isn't yet built." + ): + layer.quantize(mode) + + @parameterized.named_parameters( + ("int8", "int8"), + ("int4", "int4"), + ("float8", "float8"), + ) + def test_quantize_on_subclass(self, mode): + class MyDense(layers.Dense): + pass + + layer = MyDense(units=16) + layer.build((None, 8)) + with self.assertRaises(NotImplementedError): + layer.quantize(mode) + + layer.quantize(mode, type_check=False) # No error + + @parameterized.named_parameters( + ("int8", "int8"), + ("int4", "int4"), + ("float8", "float8"), + ) + def test_quantize_when_already_quantized(self, mode): + layer = layers.Dense(units=2) + layer.build((None, 2)) + layer.quantize(mode) + for m in ["int8", "int4", "float8"]: + with self.assertRaisesRegex( + ValueError, "is already quantized with dtype_policy=" + ): + layer.quantize(m) + + layer = layers.Dense(units=2, dtype=f"{mode}_from_float32") + layer.build((None, 2)) + for m in ["int8", "int4", "float8"]: + with self.assertRaisesRegex( + ValueError, "is already quantized with dtype_policy=" + ): + layer.quantize(m) + + @parameterized.named_parameters( + ("int8", "int8_from_float32", 3), + ("int4", "int4_from_float32", 3), # bias + packed kernel + scale + ("float8", "float8_from_float32", 8), + ) + @pytest.mark.skipif(testing.tensorflow_uses_gpu(), reason="Segfault") + def test_quantize_by_setting_dtype_policy( + self, policy, expected_num_variables + ): + layer = layers.Dense(units=2) + layer.build((None, 2)) + layer.dtype_policy = policy + self.assertLen(layer.variables, expected_num_variables) + + @parameterized.named_parameters( + ("int7", "int7"), + ("float7", "float7"), + ) + def test_quantize_invalid_mode(self, mode): + layer = layers.Dense(units=2) + layer.build((None, 2)) + x = np.random.random((1, 2)) + # dtype_policy should not be altered by failed quantization + original_dtype_policy = layer.dtype_policy + + # Test quantize + with self.assertRaisesRegex(ValueError, "Invalid quantization mode."): + layer.quantize(mode) + self.assertEqual(layer.dtype_policy, original_dtype_policy) + + # Test quantized_build + with self.assertRaisesRegex( + NotImplementedError, "Invalid quantization mode." + ): + layer.quantized_build((None, 2), mode) + self.assertEqual(layer.dtype_policy, original_dtype_policy) + + # Test quantized_call + with self.assertRaisesRegex( + NotImplementedError, "Invalid quantization mode." + ): + # Explicitly set quantization_mode + layer._dtype_policy._quantization_mode = mode + layer.quantized_call(x) + self.assertEqual(layer.dtype_policy, original_dtype_policy) + + @parameterized.named_parameters( + ("int8", "int8_from_mixed_bfloat16", 1, 2), + ("int4", "int4_from_mixed_bfloat16", 1, 2), + ("float8", "float8_from_mixed_bfloat16", 8, 0), + ) + @pytest.mark.requires_trainable_backend + @pytest.mark.skipif(testing.tensorflow_uses_gpu(), reason="Segfault") + def test_quantize_dtype_argument( + self, dtype, num_trainable_weights, num_non_trainable_weights + ): + self.run_layer_test( + layers.Dense, + init_kwargs={"units": 5, "dtype": dtype}, + input_shape=(2, 3, 4), + expected_output_shape=(2, 3, 5), + expected_num_trainable_weights=num_trainable_weights, + expected_num_non_trainable_weights=num_non_trainable_weights, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=True, + ) + + @parameterized.named_parameters( + ("int8", "int8", 3, 2, 5), + ("int4", "int4", 3, 2, 5), + ) + @pytest.mark.requires_trainable_backend + @pytest.mark.skipif(testing.tensorflow_uses_gpu(), reason="Segfault") + def test_quantize_lora_integration( + self, + mode, + num_trainable_weights, + num_non_trainable_weights, + num_torch_params, + ): + # Note that saving and loading with lora_enabled and quantized are + # lossy, so we use a weak correctness test for model outputs (atol=0.5). + config = dict(units=16) + layer = layers.Dense(**config) + layer.build((None, 8)) + layer.enable_lora(4) + layer.quantize(mode) + self.assertLen(layer.trainable_weights, num_trainable_weights) + self.assertLen(layer.non_trainable_weights, num_non_trainable_weights) + if backend.backend() == "torch": + self.assertLen(layer.torch_params, num_torch_params) + + # Try calling fit() + init_lora_a_kernel_value = layer.lora_kernel_a.numpy() + init_lora_b_kernel_value = layer.lora_kernel_b.numpy() + x = np.random.random((64, 8)) + y = np.random.random((64, 16)) + model = models.Sequential([layer]) + model.compile(optimizer="sgd", loss="mse") + model.fit(x, y, epochs=2) + + final_lora_a_kernel_value = layer.lora_kernel_a.numpy() + final_lora_b_kernel_value = layer.lora_kernel_b.numpy() + diff_a = np.max( + np.abs(init_lora_a_kernel_value - final_lora_a_kernel_value) + ) + diff_b = np.max( + np.abs(init_lora_b_kernel_value - final_lora_b_kernel_value) + ) + self.assertGreater(diff_a, 0.0) + self.assertGreater(diff_b, 0.0) + + # Try saving and reloading the model + temp_filepath = os.path.join( + self.get_temp_dir(), "quantized_lora_model.keras" + ) + model.save(temp_filepath) + new_model = saving.load_model(temp_filepath) + self.assertTrue(new_model.layers[0].lora_enabled) + self.assertAllClose(model.predict(x), new_model.predict(x), atol=0.5) + + # Try saving and reloading the model's weights only + temp_filepath = os.path.join( + self.get_temp_dir(), "quantized_lora_model.weights.h5" + ) + model.save_weights(temp_filepath) + new_model = models.Sequential([layers.Dense(**config)]) + new_model.build((None, 8)) + new_model.quantize(mode) + new_model.load_weights(temp_filepath) + self.assertFalse(new_model.layers[0].lora_enabled) + self.assertAllClose(model.predict(x), new_model.predict(x), atol=0.5) + + # Try loading a normal checkpoint into a lora model + new_model.save_weights(temp_filepath) + model.load_weights(temp_filepath) + self.assertAllClose(model.predict(x), new_model.predict(x), atol=0.5) + + # Test export and TFSMLayer reloading when using tensorflow backend + if backend.backend() == "tensorflow": + import tensorflow as tf + + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + ref_input = tf.random.normal((2, 8)) + ref_output = model(ref_input) + model.export(temp_filepath, format="tf_saved_model") + reloaded_layer = export.TFSMLayer(temp_filepath) + self.assertAllClose( + reloaded_layer(ref_input), ref_output, atol=1e-7 + ) + self.assertLen(reloaded_layer.weights, len(model.weights)) + self.assertLen( + reloaded_layer.trainable_weights, len(model.trainable_weights) + ) + self.assertLen( + reloaded_layer.non_trainable_weights, + len(model.non_trainable_weights), + ) + + @pytest.mark.requires_trainable_backend + @pytest.mark.skipif(testing.tensorflow_uses_gpu(), reason="Segfault") + def test_quantize_float8(self): + import ml_dtypes + + from keras.src import quantizers + + layer = layers.Dense(units=32) + layer.build((None, 16)) + layer.quantize("float8") + optimizer = optimizers.AdamW(learning_rate=0.1) + optimizer.build(layer.trainable_variables) + + def loss_fn(x, dy): + y = layer(x, training=True) + loss = y * ops.cast(dy, y.dtype) + return ops.sum(loss) + + if backend.backend() == "tensorflow": + import tensorflow as tf + + @tf.function(jit_compile=True) + def train_one_step(x, dy): + with tf.GradientTape() as tape: + loss = loss_fn(x, dy) + grads = tape.gradient(loss, layer.trainable_variables) + optimizer.apply(grads, layer.trainable_variables) + + elif backend.backend() == "jax": + import jax + + def stateless_loss_fn(trainable_variables, x, dy): + y = layer.stateless_call( + trainable_variables, [], x, training=True + )[0] + loss = y * ops.cast(dy, y.dtype) + return ops.sum(loss) + + grad_fn = jax.jit(jax.grad(stateless_loss_fn)) + + def train_one_step(x, dy): + trainable_variables = [ + v.value for v in layer.trainable_variables + ] + optimizer_variables = [v.value for v in optimizer.variables] + grads = grad_fn(trainable_variables, x, dy) + trainable_variables, optimizer_variables = ( + optimizer.stateless_apply( + optimizer_variables, grads, trainable_variables + ) + ) + for variable, value in zip( + layer.trainable_variables, trainable_variables + ): + variable.assign(value) + for variable, value in zip( + optimizer.variables, optimizer_variables + ): + variable.assign(value) + + elif backend.backend() == "torch": + + def train_one_step(x, dy): + layer.zero_grad() + loss = loss_fn(x, dy) + loss.backward() + grads = [v.value.grad for v in layer.trainable_variables] + optimizer.apply(grads, layer.trainable_variables) + + scale_x, amax_history_x = ops.ones(()), ops.zeros((1024,)) + scale_k, amax_history_k = ops.ones(()), ops.zeros((1024,)) + scale_g, amax_history_g = ops.ones(()), ops.zeros((1024,)) + e4m3_max = ops.cast( + float(ml_dtypes.finfo("float8_e4m3fn").max), "float32" + ) + e5m2_max = ops.cast( + float(ml_dtypes.finfo("float8_e5m2").max), "float32" + ) + + for _ in range(3): + x = random.normal((16, 16), dtype="float32") + g = random.normal((16, 32), dtype="float32") + k = ops.convert_to_tensor(layer._kernel) + + # Manually compute the expected amax history and scaling factors. + amax_from_history_x = ops.max(amax_history_x) + amax_from_history_k = ops.max(amax_history_k) + amax_from_history_g = ops.max(amax_history_g) + scale_x = quantizers.compute_float8_scale( + amax_from_history_x, scale_x, e4m3_max + ) + scale_k = quantizers.compute_float8_scale( + amax_from_history_k, scale_k, e4m3_max + ) + scale_g = quantizers.compute_float8_scale( + amax_from_history_g, scale_g, e5m2_max + ) + amax_history_x = quantizers.compute_float8_amax_history( + x, amax_history_x + ) + amax_history_k = quantizers.compute_float8_amax_history( + k, amax_history_k + ) + amax_history_g = quantizers.compute_float8_amax_history( + g, amax_history_g + ) + + train_one_step(x, g) + + self.assertAllClose(layer.inputs_amax_history, amax_history_x) + self.assertAllClose(layer.kernel_amax_history, amax_history_k) + self.assertAllClose(layer.outputs_grad_amax_history, amax_history_g) + self.assertAllClose(layer.inputs_scale, scale_x) + self.assertAllClose(layer.kernel_scale, scale_k) + self.assertAllClose(layer.outputs_grad_scale, scale_g) + + @pytest.mark.requires_trainable_backend + def test_quantize_float8_fitting(self): + config = dict(units=16) + layer = layers.Dense(**config) + layer.build((None, 8)) + layer.quantize("float8") + self.assertLen(layer.trainable_weights, 8) + self.assertLen(layer.non_trainable_weights, 0) + + # Try calling fit() + x = np.random.random((64, 8)) + y = np.random.random((64, 16)) + model = models.Sequential([layer]) + model.compile(optimizer="sgd", loss="mse") + model.fit(x, y, epochs=2) + + # Try saving and reloading the model + temp_filepath = os.path.join( + self.get_temp_dir(), "quantized_float8_model.keras" + ) + model.save(temp_filepath) + new_model = saving.load_model(temp_filepath) + self.assertAllClose(model.predict(x), new_model.predict(x)) + + # Try saving and reloading the model's weights only + temp_filepath = os.path.join( + self.get_temp_dir(), "quantized_float8_model.weights.h5" + ) + model.save_weights(temp_filepath) + new_model = models.Sequential([layers.Dense(**config)]) + new_model.build((None, 8)) + new_model.quantize("float8") + new_model.load_weights(temp_filepath) + self.assertAllClose(model.predict(x), new_model.predict(x)) + + # Test export and TFSMLayer reloading when using tensorflow backend + if backend.backend() == "tensorflow": + import tensorflow as tf + + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + ref_input = tf.random.normal((2, 8)) + ref_output = model(ref_input) + model.export(temp_filepath, format="tf_saved_model") + reloaded_layer = export.TFSMLayer(temp_filepath) + self.assertAllClose(reloaded_layer(ref_input), ref_output) + self.assertLen(reloaded_layer.weights, len(model.weights)) + self.assertLen( + reloaded_layer.trainable_weights, len(model.trainable_weights) + ) + self.assertLen( + reloaded_layer.non_trainable_weights, + len(model.non_trainable_weights), + ) + + def test_quantize_float8_inference(self): + config = dict(units=16) + layer = layers.Dense(**config) + layer.build((None, 8)) + layer.quantize("float8") + + # Try calling with `training=False` and the result must match + # `training=True` because there is no update. + x = np.random.random((64, 8)) + y_inference = layer(x, training=False) + y_training = layer(x, training=True) + self.assertAllClose(y_inference, y_training) + + def test_gptq_serialization(self): + """Test that a GPTQ-quantized layer can be serialized and deserialized + correctly.""" + layer = layers.Dense(units=16) + layer.build((None, 8)) + layer.quantize( + "gptq", + config=GPTQConfig( + dataset=None, tokenizer=None, weight_bits=4, group_size=8 + ), + ) + config = layer.get_config() + new_layer = layers.Dense.from_config(config) + new_layer.build((None, 8)) + self.assertEqual(new_layer.quantization_mode, "gptq") + + def test_int4_kernel_returns_unpacked_form(self): + """Test that the `kernel` property returns the unpacked int4 kernel.""" + layer = layers.Dense(units=2) + layer.build((None, 2)) + layer.quantize("int4") + packed_kernel = layer._kernel + self.assertAllClose( + layer.kernel, quantizers.unpack_int4(packed_kernel, 2) + ) + + def test_legacy_load_own_variables(self): + # In previous versions, `load_own_variables` accepted a store with + # numeric keys. + float32_store = { + "0": np.random.random((8, 16)).astype("float32"), + "1": np.random.random((16,)).astype("float32"), + } + int8_store = { + "0": np.random.randint(-128, 127, size=(8, 16), dtype="int8"), + "1": np.random.random((16,)).astype("float32"), + "2": np.random.random((16,)).astype("float32"), # kernel_scale. + } + int4_store = { + "0": np.random.randint(-128, 127, size=(4, 16), dtype="int8"), + "1": np.random.random((16,)).astype("float32"), + "2": np.random.random((16,)).astype("float32"), # kernel_scale. + } + float8_store = { + "0": np.random.random((8, 16)).astype("float32"), + "1": np.random.random((16,)).astype("float32"), + # inputs_scale. + "2": np.random.random(()).astype("float32"), + # inputs_amax_history. + "3": np.random.random((1024,)).astype("float32"), + # kernel_scale. + "4": np.random.random(()).astype("float32"), + # kernel_amax_history. + "5": np.random.random((1024,)).astype("float32"), + # outputs_grad_scale. + "6": np.random.random(()).astype("float32"), + # outputs_grad_amax_history. + "7": np.random.random((1024,)).astype("float32"), + } + gptq_store = { + # bias + "0": np.random.random((16,)).astype("float32"), + # quantized_kernel + "1": np.random.randint(0, 16, size=(8, 8), dtype="uint8"), + # kernel_scale. + "2": np.random.random((16, 1)).astype("float32"), + # kernel_zero + "3": np.random.random((16, 1)).astype("uint8"), + # g_idx + "4": np.random.random((8,)).astype("float32"), + } + + # Test float32 layer. + layer = layers.Dense(units=16) + layer.build((None, 8)) + layer.load_own_variables(float32_store) + self.assertAllClose(layer._kernel, float32_store["0"]) + self.assertAllClose(layer.bias, float32_store["1"]) + + # Test int8-quantized layer. + layer = layers.Dense(units=16, dtype="int8_from_float32") + layer.build((None, 8)) + layer.load_own_variables(int8_store) + self.assertAllClose(layer._kernel, int8_store["0"]) + self.assertAllClose(layer.bias, int8_store["1"]) + self.assertAllClose(layer.kernel_scale, int8_store["2"]) + + # Test int4-quantized layer. + layer = layers.Dense(units=16, dtype="int4_from_float32") + layer.build((None, 8)) + layer.load_own_variables(int4_store) + self.assertAllClose(layer._kernel, int4_store["0"]) + self.assertAllClose(layer.bias, int4_store["1"]) + self.assertAllClose(layer.kernel_scale, int4_store["2"]) + + # Test float8-quantized layer. + layer = layers.Dense(units=16, dtype="float8_from_float32") + layer.build((None, 8)) + layer.load_own_variables(float8_store) + self.assertAllClose(layer._kernel, float8_store["0"]) + self.assertAllClose(layer.bias, float8_store["1"]) + self.assertAllClose(layer.inputs_scale, float8_store["2"]) + self.assertAllClose(layer.inputs_amax_history, float8_store["3"]) + self.assertAllClose(layer.kernel_scale, float8_store["4"]) + self.assertAllClose(layer.kernel_amax_history, float8_store["5"]) + self.assertAllClose(layer.outputs_grad_scale, float8_store["6"]) + self.assertAllClose(layer.outputs_grad_amax_history, float8_store["7"]) + + # Test gptq-quantized layer. + layer = layers.Dense(units=16, dtype="gptq/4/8_from_float32") + layer.build((None, 8)) + layer.load_own_variables(gptq_store) + self.assertAllClose(layer.bias, gptq_store["0"]) + self.assertAllClose(layer.quantized_kernel, gptq_store["1"]) + self.assertAllClose(layer.kernel_scale, gptq_store["2"]) + self.assertAllClose(layer.kernel_zero, gptq_store["3"]) + self.assertAllClose(layer.g_idx, gptq_store["4"]) + + def test_int4_gptq_kernel_returns_unpacked_form(self): + """Test that the `kernel` property returns the unpacked int4 GPTQ + kernel.""" + layer = layers.Dense(units=2) + layer.build((None, 2)) + layer.quantize( + "gptq", + config=GPTQConfig( + dataset=None, tokenizer=None, weight_bits=4, group_size=8 + ), + ) + layer.is_gptq_calibrated = True # Bypass calibration check + packed_kernel = layer.quantized_kernel + self.assertAllClose( + layer.kernel, quantizers.unpack_int4(packed_kernel, 2) + ) + + def test_gptq_kernel_packing(self): + """Validates that 4-bit GPTQ packing reduces the kernel size.""" + layer = layers.Dense(units=16, use_bias=False) + layer.build((None, 8)) + + original_kernel_params = ops.prod(layer._kernel.shape) + + layer.quantize( + "gptq", + config=GPTQConfig( + dataset=None, tokenizer=None, weight_bits=4, group_size=8 + ), + ) + + quantized_kernel_params = ops.prod(layer.quantized_kernel.shape) + self.assertEqual(quantized_kernel_params, original_kernel_params // 2) diff --git a/keras/src/layers/core/einsum_dense.py b/keras/src/layers/core/einsum_dense.py new file mode 100644 index 000000000000..2c8f2e2d90d6 --- /dev/null +++ b/keras/src/layers/core/einsum_dense.py @@ -0,0 +1,1572 @@ +import math +import re +import string + +import ml_dtypes +import numpy as np + +from keras.src import activations +from keras.src import constraints +from keras.src import dtype_policies +from keras.src import initializers +from keras.src import ops +from keras.src import quantizers +from keras.src import regularizers +from keras.src.api_export import keras_export +from keras.src.layers.input_spec import InputSpec +from keras.src.layers.layer import Layer +from keras.src.quantizers.quantizers import dequantize_with_sz_map + + +@keras_export("keras.layers.EinsumDense") +class EinsumDense(Layer): + """A layer that uses `einsum` as the backing computation. + + This layer can perform einsum calculations of arbitrary dimensionality. + + Args: + equation: An equation describing the einsum to perform. + This equation must be a valid einsum string of the form + `ab,bc->ac`, `...ab,bc->...ac`, or + `ab...,bc->ac...` where 'ab', 'bc', and 'ac' can be any valid einsum + axis expression sequence. + output_shape: The expected shape of the output tensor + (excluding the batch dimension and any dimensions + represented by ellipses). You can specify `None` for any dimension + that is unknown or can be inferred from the input shape. + activation: Activation function to use. If you don't specify anything, + no activation is applied + (that is, a "linear" activation: `a(x) = x`). + bias_axes: A string containing the output dimension(s) + to apply a bias to. Each character in the `bias_axes` string + should correspond to a character in the output portion + of the `equation` string. + kernel_initializer: Initializer for the `kernel` weights matrix. + bias_initializer: Initializer for the bias vector. + kernel_regularizer: Regularizer function applied to the `kernel` weights + matrix. + bias_regularizer: Regularizer function applied to the bias vector. + kernel_constraint: Constraint function applied to the `kernel` weights + matrix. + bias_constraint: Constraint function applied to the bias vector. + lora_rank: Optional integer. If set, the layer's forward pass + will implement LoRA (Low-Rank Adaptation) + with the provided rank. LoRA sets the layer's kernel + to non-trainable and replaces it with a delta over the + original kernel, obtained via multiplying two lower-rank + trainable matrices + (the factorization happens on the last dimension). + This can be useful to reduce the + computation cost of fine-tuning large dense layers. + You can also enable LoRA on an existing + `EinsumDense` layer by calling `layer.enable_lora(rank)`. + lora_alpha: Optional integer. If set, this parameter scales the + low-rank adaptation delta (computed as the product of two lower-rank + trainable matrices) during the forward pass. The delta is scaled by + `lora_alpha / lora_rank`, allowing you to fine-tune the strength of + the LoRA adjustment independently of `lora_rank`. + **kwargs: Base layer keyword arguments, such as `name` and `dtype`. + + Examples: + + **Biased dense layer with einsums** + + This example shows how to instantiate a standard Keras dense layer using + einsum operations. This example is equivalent to + `keras.layers.Dense(64, use_bias=True)`. + + >>> layer = keras.layers.EinsumDense("ab,bc->ac", + ... output_shape=64, + ... bias_axes="c") + >>> input_tensor = keras.Input(shape=[32]) + >>> output_tensor = layer(input_tensor) + >>> output_tensor.shape + (None, 64) + + **Applying a dense layer to a sequence** + + This example shows how to instantiate a layer that applies the same dense + operation to every element in a sequence. Here, the `output_shape` has two + values (since there are two non-batch dimensions in the output); the first + dimension in the `output_shape` is `None`, because the sequence dimension + `b` has an unknown shape. + + >>> layer = keras.layers.EinsumDense("abc,cd->abd", + ... output_shape=(None, 64), + ... bias_axes="d") + >>> input_tensor = keras.Input(shape=[32, 128]) + >>> output_tensor = layer(input_tensor) + >>> output_tensor.shape + (None, 32, 64) + + **Applying a dense layer to a sequence using ellipses** + + This example shows how to instantiate a layer that applies the same dense + operation to every element in a sequence, but uses the ellipsis notation + instead of specifying the batch and sequence dimensions. + + Because we are using ellipsis notation and have specified only one axis, the + `output_shape` arg is a single value. When instantiated in this way, the + layer can handle any number of sequence dimensions - including the case + where no sequence dimension exists. + + >>> layer = keras.layers.EinsumDense("...x,xy->...y", + ... output_shape=64, + ... bias_axes="y") + >>> input_tensor = keras.Input(shape=[32, 128]) + >>> output_tensor = layer(input_tensor) + >>> output_tensor.shape + (None, 32, 64) + """ + + def __init__( + self, + equation, + output_shape, + activation=None, + bias_axes=None, + kernel_initializer="glorot_uniform", + bias_initializer="zeros", + kernel_regularizer=None, + bias_regularizer=None, + kernel_constraint=None, + bias_constraint=None, + lora_rank=None, + lora_alpha=None, + gptq_unpacked_column_size=None, + **kwargs, + ): + super().__init__(**kwargs) + self.equation = equation + if isinstance(output_shape, int): + self.partial_output_shape = (output_shape,) + else: + self.partial_output_shape = tuple(output_shape) + self.bias_axes = bias_axes + self.activation = activations.get(activation) + self.kernel_initializer = initializers.get(kernel_initializer) + self.bias_initializer = initializers.get(bias_initializer) + self.kernel_regularizer = regularizers.get(kernel_regularizer) + self.bias_regularizer = regularizers.get(bias_regularizer) + self.kernel_constraint = constraints.get(kernel_constraint) + self.bias_constraint = constraints.get(bias_constraint) + self.lora_rank = lora_rank + self.lora_alpha = lora_alpha if lora_alpha is not None else lora_rank + self.lora_enabled = False + self.gptq_unpacked_column_size = gptq_unpacked_column_size + + def build(self, input_shape): + shape_data = _analyze_einsum_string( + self.equation, + self.bias_axes, + input_shape, + self.partial_output_shape, + ) + kernel_shape, bias_shape, full_output_shape = shape_data + self.full_output_shape = tuple(full_output_shape) + self.input_spec = InputSpec(ndim=len(input_shape)) + if self.quantization_mode is not None: + self.quantized_build( + kernel_shape, + mode=self.quantization_mode, + ) + # Skip creating a duplicate kernel variable when the layer is already + # quantized to int8 or int4, because `quantized_build` has created the + # appropriate kernel variable. For other modes (e.g., float8 or no + # quantization), we still need the floating-point kernel. + if self.quantization_mode not in ("int8", "int4", "gptq"): + # If the layer is quantized to int8, `self._kernel` will be added + # in `self._int8_build`. Therefore, we skip it here. + self._kernel = self.add_weight( + name="kernel", + shape=tuple(kernel_shape), + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, + constraint=self.kernel_constraint, + dtype=self.dtype, + trainable=True, + ) + if bias_shape is not None: + self.bias = self.add_weight( + name="bias", + shape=tuple(bias_shape), + initializer=self.bias_initializer, + regularizer=self.bias_regularizer, + constraint=self.bias_constraint, + dtype=self.dtype, + trainable=True, + ) + else: + self.bias = None + self.built = True + if self.lora_rank: + self.enable_lora(self.lora_rank, lora_alpha=self.lora_alpha) + + @property + def kernel(self): + from keras.src.quantizers import gptq_core + + if not self.built: + raise AttributeError( + "You must build the layer before accessing `kernel`." + ) + + mode = self.quantization_mode + is_gptq = mode == "gptq" + is_int4 = mode == "int4" + calibrated = bool(getattr(self, "is_gptq_calibrated", False)) + gptq_bits = ( + gptq_core.get_weight_bits_for_layer(self, None) if is_gptq else None + ) + + # Decide the source tensor first (packed vs already-quantized vs plain + # kernel) + if is_gptq and calibrated and gptq_bits != 4: + # calibrated GPTQ, not 4-bit, no unpacking needed + kernel = self.quantized_kernel + else: + # Start with the stored kernel + kernel = getattr(self, "_kernel", None) + + # Handle int4 unpacking cases in one place + if is_int4: + kernel = quantizers.unpack_int4( + kernel, + self._orig_length_along_pack_axis, + self._int4_pack_axis, + ) + elif is_gptq and calibrated and gptq_bits == 4: + kernel = quantizers.unpack_int4( + self.quantized_kernel, + orig_len=self.gptq_unpacked_column_size, + axis=0, + dtype="uint8", + ) + + # Apply LoRA if enabled + if self.lora_enabled: + kernel = kernel + (self.lora_alpha / self.lora_rank) * ops.matmul( + self.lora_kernel_a, self.lora_kernel_b + ) + + return kernel + + def compute_output_shape(self, _): + return self.full_output_shape + + def call(self, inputs, training=None): + x = ops.einsum(self.equation, inputs, self.kernel) + if self.bias is not None: + x = ops.add(x, self.bias) + if self.activation is not None: + x = self.activation(x) + return x + + def enable_lora( + self, + rank, + lora_alpha=None, + a_initializer="he_uniform", + b_initializer="zeros", + ): + if self.kernel_constraint: + raise ValueError( + "Lora is incompatible with kernel constraints. " + "In order to enable lora on this layer, remove the " + "`kernel_constraint` argument." + ) + if not self.built: + raise ValueError( + "Cannot enable lora on a layer that isn't yet built." + ) + if self.lora_enabled: + raise ValueError( + "lora is already enabled. This can only be done once per layer." + ) + if self.quantization_mode == "gptq": + raise NotImplementedError( + "lora is not currently supported with GPTQ quantization." + ) + self._tracker.unlock() + # Determine the appropriate (unpacked) kernel shape for LoRA. + if self.quantization_mode == "int4": + # When int4-quantized, `self._kernel` is packed along + # `self._int4_pack_axis` and its length equals + # `(orig_len + 1) // 2`. Recover the original length so that + # the LoRA matrices operate in the full-precision space. + kernel_shape_for_lora = list(self._kernel.shape) + pack_axis = getattr(self, "_int4_pack_axis", 0) + orig_len = getattr(self, "_orig_length_along_pack_axis", None) + if orig_len is not None: + kernel_shape_for_lora[pack_axis] = orig_len + kernel_shape_for_lora = tuple(kernel_shape_for_lora) + else: + kernel_shape_for_lora = self.kernel.shape + + self.lora_kernel_a = self.add_weight( + name="lora_kernel_a", + shape=(kernel_shape_for_lora[:-1] + (rank,)), + initializer=initializers.get(a_initializer), + regularizer=self.kernel_regularizer, + ) + self.lora_kernel_b = self.add_weight( + name="lora_kernel_b", + shape=(rank, self.kernel.shape[-1]), + initializer=initializers.get(b_initializer), + regularizer=self.kernel_regularizer, + ) + self._kernel.trainable = False + self._tracker.lock() + self.lora_enabled = True + self.lora_rank = rank + self.lora_alpha = lora_alpha if lora_alpha is not None else rank + + def save_own_variables(self, store): + # Do nothing if the layer isn't yet built + if not self.built: + return + mode = self.quantization_mode + if mode not in self.quantization_variable_spec: + raise self._quantization_mode_error(mode) + + # Kernel plus optional merged LoRA-aware scale (returns (kernel, None) + # for None/gptq) + kernel_value, merged_kernel_scale = self._get_kernel_with_merged_lora() + + # Save the variables using the name as the key. + if mode != "gptq": + store["kernel"] = kernel_value + if self.bias is not None: + store["bias"] = self.bias + for name in self.quantization_variable_spec[mode]: + if name == "kernel_scale" and mode in ("int4", "int8"): + # For int4/int8, the merged LoRA scale (if any) comes from + # `_get_kernel_with_merged_lora()` + store[name] = merged_kernel_scale + else: + store[name] = getattr(self, name) + + def load_own_variables(self, store): + if not self.lora_enabled: + self._check_load_own_variables(store) + # Do nothing if the layer isn't yet built + if not self.built: + return + mode = self.quantization_mode + if mode not in self.quantization_variable_spec: + raise self._quantization_mode_error(mode) + + # Determine whether to use the legacy loading method. + if "0" in store: + return self._legacy_load_own_variables(store) + + # Load the variables using the name as the key. + if mode != "gptq": + self._kernel.assign(store["kernel"]) + if self.bias is not None: + self.bias.assign(store["bias"]) + for name in self.quantization_variable_spec[mode]: + getattr(self, name).assign(store[name]) + if self.lora_enabled: + self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape)) + self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape)) + + def _legacy_load_own_variables(self, store): + # The keys of the `store` will be saved as determined because the + # default ordering will change after quantization + mode = self.quantization_mode + targets = [] + if mode != "gptq": + targets.append(self._kernel) + if self.bias is not None: + targets.append(self.bias) + targets.extend( + getattr(self, name) + for name in self.quantization_variable_spec[mode] + ) + for i, variable in enumerate(targets): + variable.assign(store[str(i)]) + if self.lora_enabled: + self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape)) + self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape)) + + def get_config(self): + base_config = super().get_config() + config = { + "output_shape": self.partial_output_shape, + "equation": self.equation, + "activation": activations.serialize(self.activation), + "bias_axes": self.bias_axes, + "kernel_initializer": initializers.serialize( + self.kernel_initializer + ), + "bias_initializer": initializers.serialize(self.bias_initializer), + "kernel_regularizer": regularizers.serialize( + self.kernel_regularizer + ), + "bias_regularizer": regularizers.serialize(self.bias_regularizer), + "activity_regularizer": regularizers.serialize( + self.activity_regularizer + ), + "kernel_constraint": constraints.serialize(self.kernel_constraint), + "bias_constraint": constraints.serialize(self.bias_constraint), + } + if self.lora_rank: + config["lora_rank"] = self.lora_rank + config["lora_alpha"] = self.lora_alpha + if self.gptq_unpacked_column_size: + config["gptq_unpacked_column_size"] = self.gptq_unpacked_column_size + return {**base_config, **config} + + def _check_load_own_variables(self, store): + all_vars = self._trainable_variables + self._non_trainable_variables + if len(store.keys()) != len(all_vars): + if len(all_vars) == 0 and not self.built: + raise ValueError( + f"Layer '{self.name}' was never built " + "and thus it doesn't have any variables. " + f"However the weights file lists {len(store.keys())} " + "variables for this layer.\n" + "In most cases, this error indicates that either:\n\n" + "1. The layer is owned by a parent layer that " + "implements a `build()` method, but calling the " + "parent's `build()` method did NOT create the state of " + f"the child layer '{self.name}'. A `build()` method " + "must create ALL state for the layer, including " + "the state of any children layers.\n\n" + "2. You need to implement " + "the `def build_from_config(self, config)` method " + f"on layer '{self.name}', to specify how to rebuild " + "it during loading. " + "In this case, you might also want to implement the " + "method that generates the build config at saving time, " + "`def get_build_config(self)`. " + "The method `build_from_config()` is meant " + "to create the state " + "of the layer (i.e. its variables) upon deserialization.", + ) + raise ValueError( + f"Layer '{self.name}' expected {len(all_vars)} variables, " + "but received " + f"{len(store.keys())} variables during loading. " + f"Expected: {[v.name for v in all_vars]}" + ) + + @property + def quantization_variable_spec(self): + """Returns a dict mapping quantization modes to variable names. + + This spec is used by `save_own_variables` and `load_own_variables` to + determine which variables should be saved/loaded for each quantization + mode. + """ + return { + None: [], + "int8": ["kernel_scale"], + "int4": ["kernel_scale"], + "float8": [ + "inputs_scale", + "inputs_amax_history", + "kernel_scale", + "kernel_amax_history", + "outputs_grad_scale", + "outputs_grad_amax_history", + ], + "gptq": [ + "quantized_kernel", + "kernel_scale", + "kernel_zero", + "g_idx", + ], + } + + def quantized_build(self, kernel_shape, mode, config=None): + if mode == "int8": + self._int8_build(kernel_shape) + elif mode == "int4": + self._int4_build(kernel_shape) + elif mode == "float8": + self._float8_build() + elif mode == "gptq": + self._gptq_build(kernel_shape, config) + else: + raise self._quantization_mode_error(mode) + self._is_quantized = True + + def _int8_build(self, kernel_shape): + self._set_quantization_info() + self.inputs_quantizer = quantizers.AbsMaxQuantizer( + axis=self._input_reduced_axes + ) + self._kernel = self.add_weight( + name="kernel", + shape=kernel_shape, + initializer="zeros", + dtype="int8", + trainable=False, + ) + kernel_scale_shape = self._get_kernel_scale_shape(kernel_shape) + self.kernel_scale = self.add_weight( + name="kernel_scale", + shape=kernel_scale_shape, + initializer="ones", + trainable=False, + ) + + def _gptq_build(self, kernel_shape, config): + """ + Allocate quantized kernel & params for EinsumDense. + + Args: + kernel_shape: tuple/list; the layer's original kernel shape, e.g. + [in_features, out_features] or [in_features, heads, head_dim]. + group_size: int; contiguous input-group size for quantization + (=-1 means per-output-channel with no grouping). + """ + from keras.src.quantizers import gptq_core + + # Ensures the forward pass uses the original high-precision kernel + # until calibration has been performed. + self.is_gptq_calibrated = False + + self.original_kernel_shape = kernel_shape + if len(kernel_shape) == 2: + rows = kernel_shape[0] + columns = kernel_shape[1] + elif len(kernel_shape) == 3: + shape = list(self.original_kernel_shape) + try: + d_model_dim_index = shape.index(max(shape)) + except ValueError: + raise TypeError( + f"Could not determine hidden dimension from shape {shape}" + ) + + if d_model_dim_index == 0: # QKV projection case + in_features, heads, head_dim = shape + rows, columns = ( + in_features, + heads * head_dim, + ) + elif d_model_dim_index in [1, 2]: # Attention Output case + heads, head_dim, out_features = shape + rows, columns = ( + heads * head_dim, + out_features, + ) + else: + raise ValueError("Could not determine row/column split.") + + group_size = gptq_core.get_group_size_for_layer(self, config) + n_groups = 1 if group_size == -1 else math.ceil(rows / group_size) + + self.gptq_unpacked_column_size = columns + + weight_bits = gptq_core.get_weight_bits_for_layer(self, config) + # For 4-bit weights, we pack two values per byte. + kernel_columns = (columns + 1) // 2 if weight_bits == 4 else columns + + if hasattr(self, "_set_quantization_info"): + self._set_quantization_info() + + self.quantized_kernel = self.add_weight( + name="kernel", + shape=(kernel_columns, rows), + initializer="zeros", + dtype="uint8", + trainable=False, + ) + + self.kernel_scale = self.add_weight( + name="kernel_scale", + shape=(columns, n_groups), + initializer="ones", + trainable=False, + ) + self.kernel_zero = self.add_weight( + name="zero_point", + shape=(columns, n_groups), + initializer="zeros", + dtype="uint8", + trainable=False, + ) + + self.g_idx = self.add_weight( + name="g_idx", + shape=(rows,), + initializer="zeros", + dtype="float32", + trainable=False, + ) + + def _gptq_call(self, inputs, training=False): + from keras.src.quantizers import gptq_core + + if not self.is_gptq_calibrated: + W = self._kernel + else: + should_unpack = ( + gptq_core.get_weight_bits_for_layer(self, config=None) == 4 + ) + W = ( + quantizers.unpack_int4( + self.quantized_kernel, + orig_len=self.gptq_unpacked_column_size, + axis=0, + dtype="uint8", + ) + if should_unpack + else self.quantized_kernel + ) + W = dequantize_with_sz_map( + W, + self.kernel_scale, + self.kernel_zero, + self.g_idx, + ) + W = ops.transpose(W) + + W = ops.reshape(W, self.original_kernel_shape) + + y = ops.einsum(self.equation, inputs, W) + if self.bias is not None: + y = ops.add(y, self.bias) + if self.activation is not None: + y = self.activation(y) + return y + + def _int4_build(self, kernel_shape): + """Build variables for int4 quantization. + + The packed int4 kernel stores two int4 values within a single int8 + byte. Packing is performed along the first axis contained in + `self._kernel_reduced_axes` (which is the axis that gets reduced in + the einsum and thus analogous to the input-dim axis of a `Dense` + layer). + """ + self._set_quantization_info() + + # Quantizer for the inputs (per the reduced axes) + self.inputs_quantizer = quantizers.AbsMaxQuantizer( + axis=self._input_reduced_axes + ) + + # Choose the axis to perform int4 packing - use the first reduced axis + # for the kernel (analogous to the input dimension of a Dense layer). + self._int4_pack_axis = ( + self._kernel_reduced_axes[0] if self._kernel_reduced_axes else 0 + ) + + # Original length along the packing axis (needed for unpacking). + self._orig_length_along_pack_axis = kernel_shape[self._int4_pack_axis] + + # Packed length (ceil division by 2). Note: assumes static integer. + packed_len = (self._orig_length_along_pack_axis + 1) // 2 + + # Derive packed kernel shape by replacing the pack axis dimension. + packed_kernel_shape = list(kernel_shape) + packed_kernel_shape[self._int4_pack_axis] = packed_len + packed_kernel_shape = tuple(packed_kernel_shape) + + # Add packed int4 kernel variable (stored as int8 dtype). + self._kernel = self.add_weight( + name="kernel", + shape=packed_kernel_shape, + initializer="zeros", + dtype="int8", + trainable=False, + ) + + # Kernel scale + kernel_scale_shape = self._get_kernel_scale_shape(kernel_shape) + self.kernel_scale = self.add_weight( + name="kernel_scale", + shape=kernel_scale_shape, + initializer="ones", + trainable=False, + ) + + def _float8_build(self): + from keras.src.dtype_policies import QuantizedFloat8DTypePolicy + + # If `self.dtype_policy` is not QuantizedFloat8DTypePolicy, then set + # `amax_history_length` to its default value. + amax_history_length = getattr( + self.dtype_policy, + "amax_history_length", + QuantizedFloat8DTypePolicy.default_amax_history_length, + ) + # We set `trainable=True` because we will use the gradients to overwrite + # these variables + scale_kwargs = { + "shape": (), + "initializer": "ones", + "dtype": "float32", # Always be float32 + "trainable": True, + "autocast": False, + "overwrite_with_gradient": True, + } + amax_history_kwargs = { + "shape": (amax_history_length,), + "initializer": "zeros", + "dtype": "float32", # Always be float32 + "trainable": True, + "autocast": False, + "overwrite_with_gradient": True, + } + self.inputs_scale = self.add_weight(name="inputs_scale", **scale_kwargs) + self.inputs_amax_history = self.add_weight( + name="inputs_amax_history", **amax_history_kwargs + ) + self.kernel_scale = self.add_weight(name="kernel_scale", **scale_kwargs) + self.kernel_amax_history = self.add_weight( + name="kernel_amax_history", **amax_history_kwargs + ) + self.outputs_grad_scale = self.add_weight( + name="outputs_grad_scale", **scale_kwargs + ) + self.outputs_grad_amax_history = self.add_weight( + name="outputs_grad_amax_history", **amax_history_kwargs + ) + + def _int8_call(self, inputs, training=None): + @ops.custom_gradient + def einsum_with_inputs_gradient(inputs, kernel, kernel_scale): + """Performs int8 quantized einsum with a custom gradient. + + Computes the einsum operation with quantized inputs and a quantized + kernel, then de-quantizes the result. + + Also computes the gradient with respect to the original, + full-precision inputs by using a de-quantized kernel. + + Args: + inputs: The full-precision input tensor. + kernel: The int8 quantized kernel tensor. + kernel_scale: The float32 scale factor for the kernel. + + Returns: + A tuple `(output, grad_fn)`: + `output`: The de-quantized result of the einsum operation. + `grad_fn`: The custom gradient function for the backward + pass. + + Raises: + ValueError: If the quantization mode is not supported. + """ + + def grad_fn(*args, upstream=None): + if upstream is None: + (upstream,) = args + # De-scale kernel + _kernel_scale = kernel_scale + _kernel_scale = self._adjust_scale_for_dequant(_kernel_scale) + float_kernel = ops.divide( + ops.cast(kernel, dtype=self.compute_dtype), + _kernel_scale, + ) + # From https://stackoverflow.com/a/47609896 + inputs_grad = ops.einsum( + self._custom_gradient_equation, upstream, float_kernel + ) + return (inputs_grad, None, None) + + inputs, inputs_scale = self.inputs_quantizer(inputs) + x = ops.einsum(self.equation, inputs, kernel) + # Deal with `inputs_scale` + inputs_scale = self._adjust_scale_for_quant(inputs_scale, "input") + # De-scale outputs + x = ops.cast(x, self.compute_dtype) + x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale)) + return x, grad_fn + + x = einsum_with_inputs_gradient( + inputs, + ops.convert_to_tensor(self._kernel), + ops.convert_to_tensor(self.kernel_scale), + ) + if self.lora_enabled: + lora_x = ops.einsum(self.equation, inputs, self.lora_kernel_a) + lora_x = ops.matmul(lora_x, self.lora_kernel_b) + x = ops.add(x, (self.lora_alpha / self.lora_rank) * lora_x) + if self.bias is not None: + x = ops.add(x, self.bias) + if self.activation is not None: + x = self.activation(x) + return x + + def _int4_call(self, inputs, training=None): + """Forward pass for int4 quantized `EinsumDense`.""" + + pack_axis = getattr(self, "_int4_pack_axis", 0) + orig_len = getattr(self, "_orig_length_along_pack_axis", None) + + @ops.custom_gradient + def einsum_with_inputs_gradient(inputs, packed_kernel, kernel_scale): + """Performs int4 quantized einsum with a custom gradient. + + Computes the einsum operation with quantized inputs and a quantized + kernel, then de-quantizes the result. + + Also computes the gradient with respect to the original, + full-precision inputs by using a de-quantized kernel. + + Args: + inputs: The full-precision input tensor. + packed_kernel: The int4-packed kernel tensor. + kernel_scale: The float32 scale factor for the kernel. + + Returns: + A tuple `(output, grad_fn)`: + `output`: The de-quantized result of the einsum operation. + `grad_fn`: The custom gradient function for the backward + pass. + + Raises: + ValueError: If the quantization mode is not supported. + """ + # Unpack the int4-packed kernel back to int8 values [-8, 7]. + unpacked_kernel = quantizers.unpack_int4( + packed_kernel, orig_len, axis=pack_axis + ) + + def grad_fn(*args, upstream=None): + if upstream is None: + (upstream,) = args + # Align `kernel_scale` to the same layout as `unpacked_kernel`. + _kernel_scale = kernel_scale + _kernel_scale = self._adjust_scale_for_dequant(_kernel_scale) + + float_kernel = ops.divide( + ops.cast(unpacked_kernel, dtype=self.compute_dtype), + _kernel_scale, + ) + inputs_grad = ops.einsum( + self._custom_gradient_equation, upstream, float_kernel + ) + return (inputs_grad, None, None) + + # Quantize inputs per `self.inputs_quantizer`. + inputs_q, inputs_scale = self.inputs_quantizer(inputs) + + # Compute einsum on quantized inputs and unpacked int4 kernel. + x = ops.einsum(self.equation, inputs_q, unpacked_kernel) + + # Align `inputs_scale` axes with the output for correct broadcasting + inputs_scale = self._adjust_scale_for_quant(inputs_scale, "input") + + # De-scale outputs. + x = ops.cast(x, self.compute_dtype) + x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale)) + return x, grad_fn + + x = einsum_with_inputs_gradient( + inputs, + ops.convert_to_tensor(self._kernel), + ops.convert_to_tensor(self.kernel_scale), + ) + + # Add LoRA contribution if enabled + if self.lora_enabled: + lora_x = ops.einsum(self.equation, inputs, self.lora_kernel_a) + lora_x = ops.matmul(lora_x, self.lora_kernel_b) + x = ops.add(x, (self.lora_alpha / self.lora_rank) * lora_x) + + # Bias & activation + if self.bias is not None: + x = ops.add(x, self.bias) + if self.activation is not None: + x = self.activation(x) + return x + + def _float8_call(self, inputs, training=None): + if self.lora_enabled: + raise NotImplementedError( + "Currently, `_float8_call` doesn't support LoRA" + ) + + @ops.custom_gradient + def quantized_dequantize_inputs(inputs, scale, amax_history): + if training: + new_scale = quantizers.compute_float8_scale( + ops.max(amax_history, axis=0), + scale, + ops.cast( + float(ml_dtypes.finfo("float8_e4m3fn").max), "float32" + ), + ) + new_amax_history = quantizers.compute_float8_amax_history( + inputs, amax_history + ) + else: + new_scale = None + new_amax_history = None + qdq_inputs = quantizers.quantize_and_dequantize( + inputs, scale, "float8_e4m3fn", self.compute_dtype + ) + + def grad(*args, upstream=None, variables=None): + if upstream is None: + (upstream,) = args + return upstream, new_scale, new_amax_history + + return qdq_inputs, grad + + @ops.custom_gradient + def quantized_dequantize_outputs(outputs, scale, amax_history): + """Quantize-dequantize the output gradient but not the output.""" + + def grad(*args, upstream=None, variables=None): + if upstream is None: + (upstream,) = args + new_scale = quantizers.compute_float8_scale( + ops.max(amax_history, axis=0), + scale, + ops.cast( + float(ml_dtypes.finfo("float8_e5m2").max), "float32" + ), + ) + qdq_upstream = quantizers.quantize_and_dequantize( + upstream, scale, "float8_e5m2", self.compute_dtype + ) + new_amax_history = quantizers.compute_float8_amax_history( + upstream, amax_history + ) + return qdq_upstream, new_scale, new_amax_history + + return outputs, grad + + x = ops.einsum( + self.equation, + quantized_dequantize_inputs( + inputs, + ops.convert_to_tensor(self.inputs_scale), + ops.convert_to_tensor(self.inputs_amax_history), + ), + quantized_dequantize_inputs( + ops.convert_to_tensor(self._kernel), + ops.convert_to_tensor(self.kernel_scale), + ops.convert_to_tensor(self.kernel_amax_history), + ), + ) + # `quantized_dequantize_outputs` is placed immediately after + # `ops.einsum` for the sake of pattern matching in gemm_rewrite. That + # way, the qdq will be adjacent to the corresponding einsum_bprop in the + # bprop. + x = quantized_dequantize_outputs( + x, + ops.convert_to_tensor(self.outputs_grad_scale), + ops.convert_to_tensor(self.outputs_grad_amax_history), + ) + if self.bias is not None: + # Under non-mixed precision cases, F32 bias has to be converted to + # BF16 first to get the biasAdd fusion support. ref. PR + # https://github.com/tensorflow/tensorflow/pull/60306 + bias = self.bias + if self.dtype_policy.compute_dtype == "float32": + bias_bf16 = ops.cast(bias, "bfloat16") + bias = ops.cast(bias_bf16, bias.dtype) + x = ops.add(x, bias) + if self.activation is not None: + x = self.activation(x) + return x + + def quantize(self, mode, type_check=True, config=None): + # Prevent quantization of the subclasses + if type_check and (type(self) is not EinsumDense): + raise self._not_implemented_error(self.quantize) + + kernel_shape = self._kernel.shape + if mode in ("int8", "int4", "gptq"): + self._set_quantization_info() + + if mode == "int8": + # Quantize `self._kernel` to int8 and compute corresponding scale + kernel_value, kernel_scale = quantizers.abs_max_quantize( + self._kernel, axis=self._kernel_reduced_axes, to_numpy=True + ) + kernel_scale = self._adjust_scale_for_quant(kernel_scale, "kernel") + del self._kernel + elif mode == "int4": + # Quantize to int4 values (stored in int8 dtype, range [-8, 7]) + kernel_value_int4, kernel_scale = quantizers.abs_max_quantize( + self._kernel, + axis=self._kernel_reduced_axes, + value_range=(-8, 7), + dtype="int8", + to_numpy=True, + ) + kernel_scale = self._adjust_scale_for_quant(kernel_scale, "kernel") + + # Pack along the first kernel-reduced axis. + pack_axis = self._kernel_reduced_axes[0] + packed_kernel_value, _, _ = quantizers.pack_int4( + kernel_value_int4, axis=pack_axis + ) + kernel_value = packed_kernel_value + del self._kernel + self.quantized_build(kernel_shape, mode, config) + + # Assign values to the newly created variables. + if mode in ("int8", "int4"): + self._kernel.assign(kernel_value) + self.kernel_scale.assign(kernel_scale) + + # Set new dtype policy + if self.dtype_policy.quantization_mode is None: + policy_name = mode + if mode == "gptq": + policy_name = config.dtype_policy_string() + policy = dtype_policies.get( + f"{policy_name}_from_{self.dtype_policy.name}" + ) + self.dtype_policy = policy + + def _get_kernel_scale_shape(self, kernel_shape): + """Get the shape of the kernel scale tensor. + + The kernel scale tensor is used to scale the kernel tensor. + The shape of the kernel scale tensor is the same as the shape of the + kernel tensor, but with the reduced axes set to 1 and the transpose + axes set to the original axes + + Args: + kernel_shape: The shape of the kernel tensor. + + Returns: + The shape of the kernel scale tensor. + """ + kernel_scale_shape = np.array(kernel_shape) + kernel_scale_shape[self._kernel_reduced_axes] = 1 + kernel_scale_shape = kernel_scale_shape[self._kernel_transpose_axes] + kernel_scale_shape = kernel_scale_shape.tolist() + for a in sorted(self._kernel_expand_axes): + kernel_scale_shape.insert(a, 1) + for a in sorted(self._kernel_squeeze_axes, reverse=True): + kernel_scale_shape.pop(a) + return kernel_scale_shape + + def _get_kernel_with_merged_lora(self): + """Returns the kernel with LoRA matrices merged, for serialization. + + This method is called by `save_own_variables` to produce a single + kernel tensor that includes the adaptations from LoRA. This is useful + for deploying the model or for continuing training after permanently + applying the LoRA update. + + If the layer is quantized (`int8` or `int4`), the process is: + 1. Dequantize the base kernel to float. + 2. Adjust the scale tensor layout for dequantization. This is the + reverse order of operations used when building the layer. + 3. Compute the LoRA delta (`lora_kernel_a @ lora_kernel_b`) and add + it to the dequantized kernel. + 4. Re-quantize the merged result back to the original quantized + type (`int8` or packed `int4`), calculating a new scale factor. + 5. Adjust the scale tensor layout for quantization. This is the forward + order of operations used when building the layer. + + If the layer is not quantized, this method returns the result of the + `kernel` property (which computes the merge in floating-point) and a + scale of `None`. + + If LoRA is not enabled, it returns the original kernel and scale + without modification. + + Returns: + A tuple `(kernel_value, kernel_scale)`: + `kernel_value`: The merged kernel. A quantized tensor if + quantization is active, otherwise a high precision tensor. + `kernel_scale`: The quantization scale for the merged kernel. + This is `None` if the layer is not quantized. + """ + # If not a quantized layer, return the full-precision kernel directly. + if self.dtype_policy.quantization_mode in (None, "gptq"): + return self.kernel, None + + # If quantized but LoRA is not enabled, return the original quantized + # kernel. + if not self.lora_enabled: + return self._kernel, self.kernel_scale + + # Dequantize, Merge, and Re-quantize + + # 1. Dequantize the kernel + if self.quantization_mode == "int4": + unpacked_kernel = quantizers.unpack_int4( + self._kernel, + self._orig_length_along_pack_axis, + axis=self._int4_pack_axis, + ) + # Adjust scale for dequantization (reverse the transformations). + adjusted_scale = self._adjust_scale_for_dequant(self.kernel_scale) + kernel_fp = ops.divide(unpacked_kernel, adjusted_scale) + elif self.quantization_mode == "int8": + adjusted_scale = self._adjust_scale_for_dequant(self.kernel_scale) + kernel_fp = ops.divide(self._kernel, adjusted_scale) + else: + raise ValueError( + f"Unsupported quantization mode: {self.quantization_mode}" + ) + + # 2. Merge the LoRA update in the float domain + lora_update = (self.lora_alpha / self.lora_rank) * ops.matmul( + self.lora_kernel_a, self.lora_kernel_b + ) + merged_kernel_fp = ops.add(kernel_fp, lora_update) + + # 3. Re-quantize the merged float kernel back to the target format + if self.quantization_mode == "int4": + kernel_quant, new_scale = quantizers.abs_max_quantize( + merged_kernel_fp, + axis=self._kernel_reduced_axes, + value_range=(-8, 7), + dtype="int8", + to_numpy=True, + ) + # Pack back to int4 + new_kernel, _, _ = quantizers.pack_int4( + kernel_quant, axis=self._int4_pack_axis + ) + elif self.quantization_mode == "int8": + new_kernel, new_scale = quantizers.abs_max_quantize( + merged_kernel_fp, + axis=self._kernel_reduced_axes, + to_numpy=True, + ) + + # Adjust the new scale tensor to the required layout. + new_scale = self._adjust_scale_for_quant(new_scale, "kernel") + + return new_kernel, new_scale + + def _adjust_scale_for_dequant(self, scale): + """Adjusts scale tensor layout for dequantization. + + Helper method to handle scale adjustments before dequantization. + This is the reverse order of operations used when building the layer. + + Args: + scale: The scale tensor to adjust. + + Returns: + The adjusted scale tensor. + """ + if self._kernel_squeeze_axes: + scale = ops.expand_dims(scale, axis=self._kernel_squeeze_axes) + if self._kernel_expand_axes: + scale = ops.squeeze(scale, axis=self._kernel_expand_axes) + if self._kernel_transpose_axes: + # We need to reverse the transpose operation. + reverse_transpose = sorted( + range(len(self._kernel_transpose_axes)), + key=self._kernel_transpose_axes.__getitem__, + ) + scale = ops.transpose(scale, axes=reverse_transpose) + return scale + + def _adjust_scale_for_quant(self, scale, tensor_type="kernel"): + """Adjusts scale tensor layout after quantization. + + Helper method to handle scale adjustments after re-quantization. + This is the forward order of operations used when building the layer. + + Args: + scale: The scale tensor to adjust. + tensor_type: The type of tensor to adjust the scale for. + "kernel" or "input". + Returns: + The adjusted scale tensor. + """ + if tensor_type == "kernel": + transpose_axes = self._kernel_transpose_axes + expand_axes = self._kernel_expand_axes + squeeze_axes = self._kernel_squeeze_axes + elif tensor_type == "input": + transpose_axes = self._input_transpose_axes + expand_axes = self._input_expand_axes + squeeze_axes = self._input_squeeze_axes + else: + raise ValueError(f"Invalid tensor type: {tensor_type}") + + if transpose_axes: + scale = ops.transpose(scale, transpose_axes) + if expand_axes: + scale = ops.expand_dims(scale, axis=expand_axes) + if squeeze_axes: + scale = ops.squeeze(scale, axis=squeeze_axes) + return scale + + def _set_quantization_info(self): + if hasattr(self, "_input_reduced_axes"): + # Already set. + return + ( + self._input_reduced_axes, + self._kernel_reduced_axes, + self._input_transpose_axes, + self._kernel_transpose_axes, + self._input_expand_axes, + self._kernel_expand_axes, + self._input_squeeze_axes, + self._kernel_squeeze_axes, + self._custom_gradient_equation, + self._kernel_reverse_transpose_axes, + ) = _analyze_quantization_info(self.equation, self.input_spec.ndim) + + +def _analyze_einsum_string(equation, bias_axes, input_shape, output_shape): + """Parses an einsum string to determine the shapes of the weights. + + This function is the main entry point for analyzing the einsum equation. + It handles equations with and without ellipses (`...`) by converting them + to a standard format and then delegating to `_analyze_split_string` for + the core logic. + + Args: + equation: The einsum equation string, e.g., "ab,bc->ac" or + "...ab,bc->...ac". + bias_axes: A string indicating which output axes to apply a bias to. + input_shape: The shape of the input tensor. + output_shape: The user-specified shape of the output tensor (may be + partial). + + Returns: + A tuple `(kernel_shape, bias_shape, full_output_shape)` where: + `kernel_shape`: The calculated shape of the einsum kernel. + `bias_shape`: The calculated shape of the bias, or `None`. + `full_output_shape`: The fully-resolved shape of the output tensor. + + Raises: + ValueError: If the einsum `equation` is not in a supported format. + """ + + dot_replaced_string = re.sub(r"\.\.\.", "0", equation) + + # This is the case where no ellipses are present in the string. + split_string = re.match( + "([a-zA-Z]+),([a-zA-Z]+)->([a-zA-Z]+)", dot_replaced_string + ) + if split_string: + return _analyze_split_string( + split_string, bias_axes, input_shape, output_shape + ) + + # This is the case where ellipses are present on the left. + split_string = re.match( + "0([a-zA-Z]+),([a-zA-Z]+)->0([a-zA-Z]+)", dot_replaced_string + ) + if split_string: + return _analyze_split_string( + split_string, bias_axes, input_shape, output_shape, left_elided=True + ) + + # This is the case where ellipses are present on the right. + split_string = re.match( + "([a-zA-Z]{2,})0,([a-zA-Z]+)->([a-zA-Z]+)0", dot_replaced_string + ) + if split_string: + return _analyze_split_string( + split_string, bias_axes, input_shape, output_shape + ) + + raise ValueError( + f"Invalid einsum equation '{equation}'. Equations must be in the form " + "[X],[Y]->[Z], ...[X],[Y]->...[Z], or [X]...,[Y]->[Z]...." + ) + + +def _analyze_split_string( + split_string, bias_axes, input_shape, output_shape, left_elided=False +): + """Computes kernel and bias shapes from a parsed einsum equation. + + This function takes the components of an einsum equation, validates them, + and calculates the required shapes for the kernel and bias weights. + + Args: + split_string: A regex match object containing the input, weight, and + output specifications. + bias_axes: A string indicating which output axes to apply a bias to. + input_shape: The shape of the input tensor. + output_shape: The user-specified partial shape of the output tensor. + left_elided: A boolean indicating if the ellipsis "..." was on the + left side of the equation. + + Returns: + A tuple `(kernel_shape, bias_shape, full_output_shape)` where: + `kernel_shape`: The calculated shape of the einsum kernel. + `bias_shape`: The calculated shape of the bias, or `None`. + `full_output_shape`: The fully-resolved shape of the output tensor. + + Raises: + ValueError: If there are inconsistencies between the input and output + shapes or if the equation specifications are invalid. + """ + input_spec = split_string.group(1) + weight_spec = split_string.group(2) + output_spec = split_string.group(3) + elided = len(input_shape) - len(input_spec) + + if isinstance(output_shape, int): + output_shape = [output_shape] + else: + output_shape = list(output_shape) + + output_shape.insert(0, input_shape[0]) + + if elided > 0 and left_elided: + for i in range(1, elided): + # We already inserted the 0th input dimension at dim 0, so we need + # to start at location 1 here. + output_shape.insert(1, input_shape[i]) + elif elided > 0 and not left_elided: + for i in range(len(input_shape) - elided, len(input_shape)): + output_shape.append(input_shape[i]) + + if left_elided: + # If we have beginning dimensions elided, we need to use negative + # indexing to determine where in the input dimension our values are. + input_dim_map = { + dim: (i + elided) - len(input_shape) + for i, dim in enumerate(input_spec) + } + # Because we've constructed the full output shape already, we don't need + # to do negative indexing. + output_dim_map = { + dim: (i + elided) for i, dim in enumerate(output_spec) + } + else: + input_dim_map = {dim: i for i, dim in enumerate(input_spec)} + output_dim_map = {dim: i for i, dim in enumerate(output_spec)} + + for dim in input_spec: + input_shape_at_dim = input_shape[input_dim_map[dim]] + if dim in output_dim_map: + output_shape_at_dim = output_shape[output_dim_map[dim]] + if ( + output_shape_at_dim is not None + and output_shape_at_dim != input_shape_at_dim + ): + raise ValueError( + "Input shape and output shape do not match at shared " + f"dimension '{dim}'. Input shape is {input_shape_at_dim}, " + "and output shape " + f"is {output_shape[output_dim_map[dim]]}." + ) + + for dim in output_spec: + if dim not in input_spec and dim not in weight_spec: + raise ValueError( + f"Dimension '{dim}' was specified in the output " + f"'{output_spec}' but has no corresponding dim in the input " + f"spec '{input_spec}' or weight spec '{output_spec}'" + ) + + weight_shape = [] + for dim in weight_spec: + if dim in input_dim_map: + weight_shape.append(input_shape[input_dim_map[dim]]) + elif dim in output_dim_map: + weight_shape.append(output_shape[output_dim_map[dim]]) + else: + raise ValueError( + f"Weight dimension '{dim}' did not have a match in either " + f"the input spec '{input_spec}' or the output " + f"spec '{output_spec}'. For this layer, the weight must " + "be fully specified." + ) + + if bias_axes is not None: + num_left_elided = elided if left_elided else 0 + idx_map = { + char: output_shape[i + num_left_elided] + for i, char in enumerate(output_spec) + } + + for char in bias_axes: + if char not in output_spec: + raise ValueError( + f"Bias dimension '{char}' was requested, but is not part " + f"of the output spec '{output_spec}'" + ) + + first_bias_location = min( + [output_spec.find(char) for char in bias_axes] + ) + bias_output_spec = output_spec[first_bias_location:] + + bias_shape = [ + idx_map[char] if char in bias_axes else 1 + for char in bias_output_spec + ] + + if not left_elided: + for _ in range(elided): + bias_shape.append(1) + else: + bias_shape = None + + return weight_shape, bias_shape, output_shape + + +def _analyze_quantization_info(equation, input_shape): + """Analyzes an einsum equation to derive information for quantization. + + This function canonicalizes the einsum equation (handling ellipses) and + determines the necessary tensor manipulations (reduction, transposition, + expansion, squeezing) required to correctly apply per-axis quantization + to the inputs and kernel. It also derives the einsum equation needed for + the custom gradient. + + Args: + equation: The einsum equation string. + input_shape: The shape of the input tensor. + + Returns: + A tuple containing metadata for quantization operations: + `input_reduced_axes`: Axes to reduce for input quantization. + `kernel_reduced_axes`: Axes to reduce for kernel quantization. + `input_transpose_axes`: Permutation for transposing the input scale. + `kernel_transpose_axes`: Permutation for transposing the kernel scale. + `input_expand_axes`: Axes to expand for the input scale. + `kernel_expand_axes`: Axes to expand for the kernel scale. + `input_squeeze_axes`: Axes to squeeze from the input scale. + `kernel_squeeze_axes`: Axes to squeeze from the kernel scale. + `custom_gradient_equation`: Einsum equation for the backward pass. + `kernel_reverse_transpose_axes`: Permutation to reverse the kernel + scale transpose. + """ + + def get_specs(equation, input_shape): + possible_labels = string.ascii_letters + dot_replaced_string = re.sub(r"\.\.\.", "0", equation) + + # This is the case where no ellipses are present in the string. + split_string = re.match( + "([a-zA-Z]+),([a-zA-Z]+)->([a-zA-Z]+)", dot_replaced_string + ) + if split_string is not None: + input_spec = split_string.group(1) + weight_spec = split_string.group(2) + output_spec = split_string.group(3) + return input_spec, weight_spec, output_spec + + # This is the case where ellipses are present on the left. + split_string = re.match( + "0([a-zA-Z]+),([a-zA-Z]+)->0([a-zA-Z]+)", dot_replaced_string + ) + if split_string is not None: + input_spec = split_string.group(1) + weight_spec = split_string.group(2) + output_spec = split_string.group(3) + elided = len(input_shape) - len(input_spec) + possible_labels = sorted( + set(possible_labels) + - set(input_spec) + - set(weight_spec) + - set(output_spec) + ) + # Pad labels on the left to `input_spec` and `output_spec` + for i in range(elided): + input_spec = possible_labels[i] + input_spec + output_spec = possible_labels[i] + output_spec + return input_spec, weight_spec, output_spec + + # This is the case where ellipses are present on the right. + split_string = re.match( + "([a-zA-Z]{2,})0,([a-zA-Z]+)->([a-zA-Z]+)0", dot_replaced_string + ) + if split_string is not None: + input_spec = split_string.group(1) + weight_spec = split_string.group(2) + output_spec = split_string.group(3) + elided = len(input_shape) - len(input_spec) + possible_labels = sorted( + set(possible_labels) + - set(input_spec) + - set(weight_spec) + - set(output_spec) + ) + # Pad labels on the right to `input_spec` and `output_spec` + for i in range(elided): + input_spec = input_spec + possible_labels[i] + output_spec = output_spec + possible_labels[i] + return input_spec, weight_spec, output_spec + + raise ValueError( + f"Invalid einsum equation '{equation}'. Equations must be in the " + "form [X],[Y]->[Z], ...[X],[Y]->...[Z], or [X]...,[Y]->[Z]...." + ) + + input_spec, weight_spec, output_spec = get_specs(equation, input_shape) + + # Determine the axes that should be reduced by the quantizer + input_reduced_axes = [] + weight_reduced_axes = [] + for i, label in enumerate(input_spec): + index = output_spec.find(label) + if index == -1: + input_reduced_axes.append(i) + for i, label in enumerate(weight_spec): + index = output_spec.find(label) + if index == -1: + weight_reduced_axes.append(i) + + # Determine the axes of `ops.expand_dims` + input_expand_axes = [] + weight_expand_axes = [] + for i, label in enumerate(output_spec): + index_input = input_spec.find(label) + index_weight = weight_spec.find(label) + if index_input == -1: + input_expand_axes.append(i) + if index_weight == -1: + weight_expand_axes.append(i) + + # Determine the axes of `ops.transpose` + input_transpose_axes = [] + weight_transpose_axes = [] + for i, label in enumerate(output_spec): + index_input = input_spec.find(label) + index_weight = weight_spec.find(label) + if index_input != -1: + input_transpose_axes.append(index_input) + if index_weight != -1: + weight_transpose_axes.append(index_weight) + # Postprocess the information: + # 1. Add dummy axes (1) to transpose_axes + # 2. Add axis to squeeze_axes if 1. failed + input_squeeze_axes = [] + weight_squeeze_axes = [] + for ori_index in input_reduced_axes: + try: + index = input_expand_axes.pop(0) + except IndexError: + input_squeeze_axes.append(ori_index) + input_transpose_axes.insert(index, ori_index) + for ori_index in weight_reduced_axes: + try: + index = weight_expand_axes.pop(0) + except IndexError: + weight_squeeze_axes.append(ori_index) + weight_transpose_axes.insert(index, ori_index) + # Prepare equation for `einsum_with_inputs_gradient` + custom_gradient_equation = f"{output_spec},{weight_spec}->{input_spec}" + weight_reverse_transpose_axes = [ + i + for (_, i) in sorted( + (v, i) for (i, v) in enumerate(weight_transpose_axes) + ) + ] + return ( + input_reduced_axes, + weight_reduced_axes, + input_transpose_axes, + weight_transpose_axes, + input_expand_axes, + weight_expand_axes, + input_squeeze_axes, + weight_squeeze_axes, + custom_gradient_equation, + weight_reverse_transpose_axes, + ) diff --git a/keras/src/layers/core/einsum_dense_test.py b/keras/src/layers/core/einsum_dense_test.py new file mode 100644 index 000000000000..22e453ecddc4 --- /dev/null +++ b/keras/src/layers/core/einsum_dense_test.py @@ -0,0 +1,1183 @@ +import os + +import numpy as np +import pytest +from absl.testing import parameterized + +from keras.src import backend +from keras.src import constraints +from keras.src import export +from keras.src import layers +from keras.src import models +from keras.src import ops +from keras.src import optimizers +from keras.src import quantizers +from keras.src import random +from keras.src import saving +from keras.src import testing +from keras.src.quantizers.gptq_config import GPTQConfig + + +class EinsumDenseTest(testing.TestCase): + @parameterized.named_parameters( + { + "testcase_name": "_1d_end_weight", + "equation": "ab,b->a", + "bias_axes": None, + "input_shape": (2, 32), + "output_shape": (), + "expected_kernel_shape": (32,), + "expected_bias_shape": None, + "expected_output_shape": (2,), + }, + { + "testcase_name": "_2d_middle_weight", + "equation": "ab,bc->ac", + "bias_axes": None, + "input_shape": (2, 32), + "output_shape": (64), + "expected_kernel_shape": (32, 64), + "expected_bias_shape": None, + "expected_output_shape": (2, 64), + }, + { + "testcase_name": "_3d_bert", + "equation": "abc,cde->abde", + "bias_axes": None, + "input_shape": (2, 1, 2), + "output_shape": (1, 3, 4), + "expected_kernel_shape": (2, 3, 4), + "expected_bias_shape": None, + "expected_output_shape": (2, 1, 3, 4), + }, + { + "testcase_name": "_3d_3_bias", + "equation": "abc,cde->abde", + "bias_axes": "e", + "input_shape": (2, 1, 2), + "output_shape": (1, 3, 4), + "expected_kernel_shape": (2, 3, 4), + "expected_bias_shape": (4,), + "expected_output_shape": (2, 1, 3, 4), + }, + { + "testcase_name": "_3d_2_bias", + "equation": "abc,cde->abde", + "bias_axes": "d", + "input_shape": (2, 1, 2), + "output_shape": (1, 3, 4), + "expected_kernel_shape": (2, 3, 4), + "expected_bias_shape": (3, 1), + "expected_output_shape": (2, 1, 3, 4), + }, + { + "testcase_name": "_3d_1_3_bias", + "equation": "abc,cde->abde", + "bias_axes": "be", + "input_shape": (2, 7, 2), + "output_shape": (7, 3, 4), + "expected_kernel_shape": (2, 3, 4), + "expected_bias_shape": (7, 1, 4), + "expected_output_shape": (2, 7, 3, 4), + }, + { + "testcase_name": "_3d_bert_projection", + "equation": "BFNH,NHD->BFD", + "bias_axes": None, + "input_shape": (2, 1, 2, 3), + "output_shape": (1, 4), + "expected_kernel_shape": (2, 3, 4), + "expected_bias_shape": None, + "expected_output_shape": (2, 1, 4), + }, + { + "testcase_name": "_2d_bert", + "equation": "abc,cd->abd", + "bias_axes": None, + "input_shape": (2, 1, 2), + "output_shape": (1, 4), + "expected_kernel_shape": (2, 4), + "expected_bias_shape": None, + "expected_output_shape": (2, 1, 4), + }, + { + "testcase_name": "_embedding_1d", + "equation": "i,d->id", + "bias_axes": None, + "input_shape": (2,), + "output_shape": (2,), + "expected_kernel_shape": (2,), + "expected_bias_shape": None, + "expected_output_shape": (2, 2), + }, + { + "testcase_name": "_xlnet_lm", + "equation": "ibd,nd->ibn", + "bias_axes": None, + "input_shape": (2, 2, 1), + "output_shape": (2, 2), + "expected_kernel_shape": (2, 1), + "expected_bias_shape": None, + "expected_output_shape": (2, 2, 2), + }, + { + "testcase_name": "_2d_precast", + "equation": "...b,bc->...c", + "bias_axes": None, + "input_shape": (2, 32), + "output_shape": (64,), + "expected_kernel_shape": (32, 64), + "expected_bias_shape": None, + "expected_output_shape": (2, 64), + }, + { + "testcase_name": "_2d_precast_elided_input_used_in_output", + "equation": "...bc,bc->...b", + "bias_axes": None, + "input_shape": (2, 32, 64), + "output_shape": (32,), + "expected_kernel_shape": (32, 64), + "expected_bias_shape": None, + "expected_output_shape": (2, 32), + }, + { + "testcase_name": "_2d_precast_multiple_elided_dims", + "equation": "...b,bc->...c", + "bias_axes": None, + "input_shape": (2, 3, 32), + "output_shape": (64,), + "expected_kernel_shape": (32, 64), + "expected_bias_shape": None, + "expected_output_shape": (2, 3, 64), + }, + { + "testcase_name": "_3d_precast", + "equation": "...c,cde->...de", + "bias_axes": None, + "input_shape": (2, 1, 2), + "output_shape": (3, 4), + "expected_kernel_shape": (2, 3, 4), + "expected_bias_shape": None, + "expected_output_shape": (2, 1, 3, 4), + }, + { + "testcase_name": "_3d_precast_3_bias", + "equation": "...c,cde->...de", + "bias_axes": "e", + "input_shape": (2, 1, 2), + "output_shape": (3, 4), + "expected_kernel_shape": (2, 3, 4), + "expected_bias_shape": (4,), + "expected_output_shape": (2, 1, 3, 4), + }, + { + "testcase_name": "_3d_precast_2_bias", + "equation": "...c,cde->...de", + "bias_axes": "d", + "input_shape": (2, 1, 2), + "output_shape": (3, 4), + "expected_kernel_shape": (2, 3, 4), + "expected_bias_shape": (3, 1), + "expected_output_shape": (2, 1, 3, 4), + }, + { + "testcase_name": "_3d_precast_2_3_bias", + "equation": "...c,cde->...de", + "bias_axes": "de", + "input_shape": (2, 1, 2), + "output_shape": (3, 4), + "expected_kernel_shape": (2, 3, 4), + "expected_bias_shape": (3, 4), + "expected_output_shape": (2, 1, 3, 4), + }, + { + "testcase_name": "_2d_postcast", + "equation": "bc...,cd->bd...", + "bias_axes": None, + "input_shape": (2, 1, 2, 3), + "output_shape": (4,), + "expected_kernel_shape": (1, 4), + "expected_bias_shape": None, + "expected_output_shape": (2, 4, 2, 3), + }, + { + "testcase_name": "_3d_postcast", + "equation": "bc...,cde->bde...", + "bias_axes": None, + "input_shape": (2, 1, 2), + "output_shape": (3, 4), + "expected_kernel_shape": (1, 3, 4), + "expected_bias_shape": None, + "expected_output_shape": (2, 3, 4, 2), + }, + { + "testcase_name": "_3d_postcast_1_bias", + "equation": "bc...,cde->bde...", + "bias_axes": "d", + "input_shape": (2, 1, 2), + "output_shape": (3, 4), + "expected_kernel_shape": (1, 3, 4), + "expected_bias_shape": (3, 1, 1), + "expected_output_shape": (2, 3, 4, 2), + }, + { + "testcase_name": "_3d_postcast_2_bias", + "equation": "bc...,cde->bde...", + "bias_axes": "e", + "input_shape": (2, 1, 2), + "output_shape": (3, 4), + "expected_kernel_shape": (1, 3, 4), + "expected_bias_shape": (4, 1), + "expected_output_shape": (2, 3, 4, 2), + }, + { + "testcase_name": "_3d_postcast_1_2_bias", + "equation": "bc...,cde->bde...", + "bias_axes": "de", + "input_shape": (2, 1, 2), + "output_shape": (3, 4), + "expected_kernel_shape": (1, 3, 4), + "expected_bias_shape": (3, 4, 1), + "expected_output_shape": (2, 3, 4, 2), + }, + ) + @pytest.mark.requires_trainable_backend + def test_einsum_dense_basics( + self, + equation, + bias_axes, + input_shape, + output_shape, + expected_kernel_shape, + expected_bias_shape, + expected_output_shape, + ): + self.run_layer_test( + layers.EinsumDense, + init_kwargs={ + "equation": equation, + "output_shape": output_shape, + "bias_axes": bias_axes, + }, + input_shape=input_shape, + expected_output_shape=expected_output_shape, + expected_num_trainable_weights=( + 2 if expected_bias_shape is not None else 1 + ), + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=False, + ) + layer = layers.EinsumDense( + equation, output_shape=output_shape, bias_axes=bias_axes + ) + layer.build(input_shape) + self.assertEqual(layer.kernel.shape, expected_kernel_shape) + if expected_bias_shape is not None: + self.assertEqual(layer.bias.shape, expected_bias_shape) + + def test_einsum_dense_constraints(self): + layer = layers.EinsumDense( + "abc,cde->abde", (1, 3, 4), kernel_constraint="non_neg" + ) + layer.build((2, 1, 2)) + self.assertIsInstance(layer.kernel.constraint, constraints.NonNeg) + layer = layers.EinsumDense( + "ab,b->a", (1, 3, 4), bias_axes="a", bias_constraint="non_neg" + ) + layer.build((2, 1, 2)) + self.assertIsInstance(layer.bias.constraint, constraints.NonNeg) + + @pytest.mark.requires_trainable_backend + def test_enable_lora(self): + layer = layers.EinsumDense( + equation="ab,bcd->acd", + output_shape=(8, 32), + bias_axes=None, + ) + layer.build((None, 3)) + layer.enable_lora(2) + self.assertLen(layer.trainable_weights, 2) + self.assertLen(layer.non_trainable_weights, 1) + if backend.backend() == "torch": + self.assertLen(layer.torch_params, 3) + # Try eager call + x = np.random.random((64, 3)) + y = np.random.random((64, 8, 32)) + _ = layer(x[:2]) + + init_lora_a_kernel_value = layer.lora_kernel_a.numpy() + init_lora_b_kernel_value = layer.lora_kernel_b.numpy() + + # Try calling fit() + model = models.Sequential( + [ + layer, + ] + ) + model.compile(optimizer="sgd", loss="mse") + model.fit(x, y, epochs=2) + + final_lora_a_kernel_value = layer.lora_kernel_a.numpy() + final_lora_b_kernel_value = layer.lora_kernel_b.numpy() + diff_a = np.max( + np.abs(init_lora_a_kernel_value - final_lora_a_kernel_value) + ) + diff_b = np.max( + np.abs(init_lora_b_kernel_value - final_lora_b_kernel_value) + ) + self.assertGreater(diff_a, 0.0) + self.assertGreater(diff_b, 0.0) + + # Try saving and reloading the model + temp_filepath = os.path.join(self.get_temp_dir(), "lora_model.keras") + model.save(temp_filepath) + + new_model = saving.load_model(temp_filepath) + self.assertTrue(new_model.layers[0].lora_enabled) + self.assertAllClose(model.predict(x), new_model.predict(x)) + + # Try saving and reloading the model's weights only + temp_filepath = os.path.join( + self.get_temp_dir(), "lora_model.weights.h5" + ) + model.save_weights(temp_filepath) + + # Load the file into a fresh, non-lora model + new_model = models.Sequential( + [ + layers.EinsumDense( + equation="ab,bcd->acd", + output_shape=(8, 32), + bias_axes=None, + ), + ] + ) + new_model.build((None, 3)) + new_model.load_weights(temp_filepath) + self.assertAllClose(model.predict(x), new_model.predict(x)) + + # Try loading a normal checkpoint into a lora model + new_model.save_weights(temp_filepath) + model.load_weights(temp_filepath) + self.assertAllClose(model.predict(x), new_model.predict(x)) + + @pytest.mark.requires_trainable_backend + def test_enable_lora_with_alpha(self): + # Use a simple equation that mimics a `Dense` layer behavior. + equation = "ab,bc->ac" + output_shape = 3 # This means the kernel shape will be (input_dim, 3). + bias_axes = None + + # Create and build the `EinsumDense` layer + # with an input shape (None, 2). + layer = layers.EinsumDense( + equation=equation, output_shape=output_shape, bias_axes=bias_axes + ) + # Build the layer with an input shape of (batch, 2). + layer.build((None, 2)) + + # Set the base kernel weights to a known value. + base_kernel = np.array( + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32 + ) + layer._kernel.assign(base_kernel) + + # Enable LoRA with `rank`=2 and a custom `lora_alpha`=3.0. + layer.enable_lora(rank=2, lora_alpha=3.0) + self.assertEqual(layer.lora_rank, 2) + self.assertEqual(layer.lora_alpha, 3.0) + + # The expected shapes are: + # `base_kernel`: (2, 3) + # `lora_kernel_a`: (2, 2) and `lora_kernel_b`: (2, 3) + a_val = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32) + b_val = np.array([[0.5, 0.6, 0.7], [0.8, 0.9, 1.0]], dtype=np.float32) + layer.lora_kernel_a.assign(a_val) + layer.lora_kernel_b.assign(b_val) + + # Compute expected effective kernel. + # Scaling factor is `lora_alpha / lora_rank` = 3.0 / 2 = 1.5 + expected_delta = 1.5 * np.matmul(a_val, b_val) + expected_kernel = base_kernel + expected_delta + + # Verify that the effective kernel property returns the expected value. + actual_kernel = ops.convert_to_numpy(layer.kernel) + self.assertAllClose(actual_kernel, expected_kernel) + + @pytest.mark.requires_trainable_backend + def test_lora_rank_argument(self): + self.run_layer_test( + layers.EinsumDense, + init_kwargs={ + "equation": "ab,bcd->acd", + "output_shape": (8, 32), + "bias_axes": None, + "lora_rank": 2, + }, + input_shape=(2, 3), + expected_output_shape=(2, 8, 32), + expected_num_trainable_weights=2, + expected_num_non_trainable_weights=1, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=False, + ) + + # Test quantization-related methods. + + @parameterized.named_parameters( + ("int8", "int8", 1e-3), + ("int4", "int4", 3e-3), + ) + def test_quantize_int(self, mode, error_threshold): + layer = layers.EinsumDense( + equation="ab,bcd->acd", + output_shape=(8, 32), + bias_axes="d", + ) + layer.build((None, 3)) + x = np.random.random((2, 3)) + y_float = layer(x) + layer.quantize(mode) + + # Verify weights dtype + self.assertEqual(backend.standardize_dtype(layer._kernel.dtype), "int8") + self.assertEqual( + backend.standardize_dtype(layer.kernel_scale.dtype), + layer.variable_dtype, + ) + + # Try eager call and verify output correctness + y_quantized = layer(x) + mse = ops.mean(ops.square(y_float - y_quantized)) + self.assertLess(mse, error_threshold) # A weak correctness test + + # Try saving and reloading the model + model = models.Sequential([layer]) + temp_filepath = os.path.join( + self.get_temp_dir(), "quantized_model.keras" + ) + model.save(temp_filepath) + new_model = saving.load_model(temp_filepath) + self.assertAllClose(model.predict(x), new_model.predict(x)) + + # Try saving and reloading the model's weights only + temp_filepath = os.path.join( + self.get_temp_dir(), "quantized_model.weights.h5" + ) + model.save_weights(temp_filepath) + + # Try building with quantized dtype policy + layer = layers.EinsumDense( + equation="abcde,afce->acdbf", # Test reduce and transpose + output_shape=(2, 4, 8, 16), + bias_axes="d", + dtype=f"{mode}_from_mixed_bfloat16", + ) + layer.build((1, 8, 2, 4, 32)) + self.assertEqual(backend.standardize_dtype(layer._kernel.dtype), "int8") + self.assertEqual( + backend.standardize_dtype(layer.kernel_scale.dtype), "float32" + ) + layer = layers.EinsumDense( + equation="a,b->ab", # Test expand + output_shape=(4,), + dtype=f"{mode}_from_float32", + ) + layer.build((None,)) + self.assertEqual(backend.standardize_dtype(layer._kernel.dtype), "int8") + self.assertEqual( + backend.standardize_dtype(layer.kernel_scale.dtype), "float32" + ) + layer = layers.EinsumDense( + equation="ab,ab->a", # Test squeeze + output_shape=(2,), + dtype="int8_from_float32", + ) + layer.build((2, 4)) + self.assertEqual(backend.standardize_dtype(layer._kernel.dtype), "int8") + self.assertEqual( + backend.standardize_dtype(layer.kernel_scale.dtype), "float32" + ) + + @parameterized.named_parameters( + ( + "int8_btnh,nhd->btd", + "int8", + "btnh,nhd->btd", + (None, 8), + (1, 2, 2, 4), + 1e-3, + ), + ( + "int8_btd,ndh->btnh", + "int8", + "btd,ndh->btnh", + (None, 2, 8), + (1, 2, 4), + 1e-3, + ), + ("int8_btd,df->btf", "int8", "btd,df->btf", (None, 4), (1, 2, 4), 1e-3), + ( + "int4_btnh,nhd->btd", + "int4", + "btnh,nhd->btd", + (None, 8), + (1, 2, 2, 4), + 3e-3, + ), + ( + "int4_btd,ndh->btnh", + "int4", + "btd,ndh->btnh", + (None, 2, 8), + (1, 2, 4), + 3e-3, + ), + ( + "int4_btd,df->btf", + "int4", + "btd,df->btf", + (None, 4), + (1, 2, 4), + 3e-3, + ), + ) + def test_quantize_with_specific_equations( + self, + quantization_mode, + equation, + output_shape, + input_shape, + error_threshold, + ): + layer = layers.EinsumDense(equation=equation, output_shape=output_shape) + layer.build(input_shape) + x = ops.random.uniform(input_shape) + y_float = layer(x) + + layer.quantize(quantization_mode) + y_quantized = layer(x) + mse = ops.mean(ops.square(y_float - y_quantized)) + self.assertLess(mse, error_threshold) # A weak correctness test + + @parameterized.named_parameters( + ("int8", "int8"), + ("float8", "float8"), + ("int4", "int4"), + ) + def test_quantize_on_unbuilt_layer(self, mode): + layer = layers.EinsumDense( + equation="ab,bcd->acd", + output_shape=(8, 32), + bias_axes="d", + ) + with self.assertRaisesRegex( + ValueError, "Cannot quantize a layer that isn't yet built." + ): + layer.quantize(mode) + + @parameterized.named_parameters( + ("int8", "int8"), + ("float8", "float8"), + ("int4", "int4"), + ) + def test_quantize_on_subclass(self, mode): + class MyEinsumDense(layers.EinsumDense): + pass + + layer = MyEinsumDense( + equation="ab,bcd->acd", + output_shape=(8, 32), + bias_axes="d", + ) + layer.build((None, 3)) + with self.assertRaises(NotImplementedError): + layer.quantize(mode) + + layer.quantize(mode, type_check=False) # No error + + @parameterized.named_parameters( + ("int8", "int8"), + ("float8", "float8"), + ("int4", "int4"), + ) + def test_quantize_when_already_quantized(self, mode): + layer = layers.EinsumDense( + equation="ab,bcd->acd", + output_shape=(8, 16), + bias_axes="d", + ) + layer.build((None, 3)) + layer.quantize(mode) + for m in ["int8", "float8"]: + with self.assertRaisesRegex( + ValueError, "is already quantized with dtype_policy=" + ): + layer.quantize(m) + + layer = layers.EinsumDense( + equation="ab,bcd->acd", + output_shape=(8, 16), + bias_axes="d", + dtype=f"{mode}_from_float32", + ) + layer.build((None, 3)) + for m in ["int8", "float8"]: + with self.assertRaisesRegex( + ValueError, "is already quantized with dtype_policy=" + ): + layer.quantize(m) + + @parameterized.named_parameters( + ("int8", "int8_from_float32", 3), + ("float8", "float8_from_float32", 8), + ("int4", "int4_from_float32", 3), + ) + def test_quantize_by_setting_dtype_policy( + self, policy, expected_num_variables + ): + layer = layers.EinsumDense( + equation="ab,bcd->acd", + output_shape=(8, 32), + bias_axes="d", + ) + layer.build((None, 3)) + layer.dtype_policy = policy + self.assertLen(layer.variables, expected_num_variables) + + @parameterized.named_parameters( + ("int7", "int7"), + ("float7", "float7"), + ("int3", "int3"), + ) + def test_quantize_invalid_mode(self, mode): + layer = layers.EinsumDense( + equation="ab,bcd->acd", + output_shape=(8, 32), + bias_axes="d", + ) + layer.build((None, 3)) + x = np.random.random((1, 3)) + # dtype_policy should not be altered by failed quantization + original_dtype_policy = layer.dtype_policy + + # Test quantize + with self.assertRaisesRegex(ValueError, "Invalid quantization mode."): + layer.quantize(mode) + self.assertEqual(layer.dtype_policy, original_dtype_policy) + + # Test quantized_build + with self.assertRaisesRegex( + NotImplementedError, "Invalid quantization mode." + ): + layer.quantized_build((None, 2), mode) + self.assertEqual(layer.dtype_policy, original_dtype_policy) + + # Test quantized_call + with self.assertRaisesRegex( + NotImplementedError, "Invalid quantization mode." + ): + # Explicitly set quantization_mode + layer._dtype_policy._quantization_mode = mode + layer.quantized_call(x) + self.assertEqual(layer.dtype_policy, original_dtype_policy) + + @parameterized.named_parameters( + ("int8", "int8_from_mixed_bfloat16", 1, 2), + ("float8", "float8_from_mixed_bfloat16", 8, 0), + ("int4", "int4_from_mixed_bfloat16", 1, 2), + ) + @pytest.mark.requires_trainable_backend + def test_quantize_dtype_argument( + self, dtype, num_trainable_weights, num_non_trainable_weights + ): + self.run_layer_test( + layers.EinsumDense, + init_kwargs={ + "equation": "ab,bcd->acd", + "output_shape": (8, 32), + "bias_axes": "d", + "dtype": dtype, + }, + input_shape=(2, 3), + expected_output_shape=(2, 8, 32), + expected_num_trainable_weights=num_trainable_weights, + expected_num_non_trainable_weights=num_non_trainable_weights, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=False, + ) + + @parameterized.named_parameters( + ("int8_ab,bcd->acd", "int8", "ab,bcd->acd", (64, 3), (64, 8, 32)), + ( + "int8_btd,ndh->btnh", + "int8", + "btd,ndh->btnh", + (1, 4, 32), + (1, 4, 8, 16), + ), + ("int4_ab,bcd->acd", "int4", "ab,bcd->acd", (64, 3), (64, 8, 32)), + ( + "int4_btd,ndh->btnh", + "int4", + "btd,ndh->btnh", + (1, 4, 32), + (1, 4, 8, 16), + ), + ) + @pytest.mark.requires_trainable_backend + def test_quantize_lora_integration( + self, quantization_mode, equation, input_shape, output_shape + ): + config = dict( + equation=equation, output_shape=output_shape[1:], bias_axes=None + ) + layer = layers.EinsumDense(**config) + layer.build(input_shape) + layer.enable_lora(2) + layer.quantize(quantization_mode) + self.assertLen(layer.trainable_weights, 2) + self.assertLen(layer.non_trainable_weights, 2) + if backend.backend() == "torch": + self.assertLen(layer.torch_params, 4) + + # Try calling fit() + init_lora_a_kernel_value = layer.lora_kernel_a.numpy() + init_lora_b_kernel_value = layer.lora_kernel_b.numpy() + x = np.random.random(input_shape) + y = np.random.random(output_shape) + model = models.Sequential([layer]) + model.compile(optimizer="sgd", loss="mse") + model.fit(x, y, epochs=2) + + final_lora_a_kernel_value = layer.lora_kernel_a.numpy() + final_lora_b_kernel_value = layer.lora_kernel_b.numpy() + diff_a = np.max( + np.abs(init_lora_a_kernel_value - final_lora_a_kernel_value) + ) + diff_b = np.max( + np.abs(init_lora_b_kernel_value - final_lora_b_kernel_value) + ) + self.assertGreater(diff_a, 0.0) + self.assertGreater(diff_b, 0.0) + + # Try saving and reloading the model + temp_filepath = os.path.join( + self.get_temp_dir(), "quantized_lora_model.keras" + ) + model.save(temp_filepath) + new_model = saving.load_model(temp_filepath) + self.assertTrue(new_model.layers[0].lora_enabled) + self.assertAllClose(model.predict(x), new_model.predict(x), atol=0.5) + + # Try saving and reloading the model's weights only + temp_filepath = os.path.join( + self.get_temp_dir(), "quantized_lora_model.weights.h5" + ) + model.save_weights(temp_filepath) + new_model = models.Sequential([layers.EinsumDense(**config)]) + new_model.build(input_shape) + new_model.quantize(quantization_mode) + new_model.load_weights(temp_filepath) + self.assertFalse(new_model.layers[0].lora_enabled) + self.assertAllClose(model.predict(x), new_model.predict(x), atol=0.5) + + # Try loading a normal checkpoint into a lora model + new_model.save_weights(temp_filepath) + model.load_weights(temp_filepath) + self.assertAllClose(model.predict(x), new_model.predict(x), atol=0.5) + + # Test export and TFSMLayer reloading when using tensorflow backend + if backend.backend() == "tensorflow": + import tensorflow as tf + + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + ref_input = tf.random.normal(input_shape) + ref_output = model(ref_input) + model.export(temp_filepath, format="tf_saved_model") + reloaded_layer = export.TFSMLayer(temp_filepath) + self.assertAllClose( + reloaded_layer(ref_input), ref_output, atol=1e-7 + ) + self.assertLen(reloaded_layer.weights, len(model.weights)) + self.assertLen( + reloaded_layer.trainable_weights, len(model.trainable_weights) + ) + self.assertLen( + reloaded_layer.non_trainable_weights, + len(model.non_trainable_weights), + ) + + @pytest.mark.requires_trainable_backend + def test_quantize_float8(self): + import ml_dtypes + + from keras.src import quantizers + + layer = layers.EinsumDense( + "ab,bc->ac", + output_shape=[32], + bias_axes="c", + ) + layer.build((None, 16)) + layer.quantize("float8") + optimizer = optimizers.AdamW(learning_rate=0.1) + optimizer.build(layer.trainable_variables) + + def loss_fn(x, dy): + y = layer(x, training=True) + loss = y * ops.cast(dy, y.dtype) + return ops.sum(loss) + + if backend.backend() == "tensorflow": + import tensorflow as tf + + @tf.function(jit_compile=True) + def train_one_step(x, dy): + with tf.GradientTape() as tape: + loss = loss_fn(x, dy) + grads = tape.gradient(loss, layer.trainable_variables) + optimizer.apply(grads, layer.trainable_variables) + + elif backend.backend() == "jax": + import jax + + def stateless_loss_fn(trainable_variables, x, dy): + y = layer.stateless_call( + trainable_variables, [], x, training=True + )[0] + loss = y * ops.cast(dy, y.dtype) + return ops.sum(loss) + + grad_fn = jax.jit(jax.grad(stateless_loss_fn)) + + def train_one_step(x, dy): + trainable_variables = [ + v.value for v in layer.trainable_variables + ] + optimizer_variables = [v.value for v in optimizer.variables] + grads = grad_fn(trainable_variables, x, dy) + trainable_variables, optimizer_variables = ( + optimizer.stateless_apply( + optimizer_variables, grads, trainable_variables + ) + ) + for variable, value in zip( + layer.trainable_variables, trainable_variables + ): + variable.assign(value) + for variable, value in zip( + optimizer.variables, optimizer_variables + ): + variable.assign(value) + + elif backend.backend() == "torch": + + def train_one_step(x, dy): + layer.zero_grad() + loss = loss_fn(x, dy) + loss.backward() + grads = [v.value.grad for v in layer.trainable_variables] + optimizer.apply(grads, layer.trainable_variables) + + scale_x, amax_history_x = ops.ones(()), ops.zeros((1024,)) + scale_k, amax_history_k = ops.ones(()), ops.zeros((1024,)) + scale_g, amax_history_g = ops.ones(()), ops.zeros((1024,)) + e4m3_max = ops.cast( + float(ml_dtypes.finfo("float8_e4m3fn").max), "float32" + ) + e5m2_max = ops.cast( + float(ml_dtypes.finfo("float8_e5m2").max), "float32" + ) + + for _ in range(3): + x = random.normal((16, 16), dtype="float32") + g = random.normal((16, 32), dtype="float32") + k = ops.convert_to_tensor(layer._kernel) + + # Manually compute the expected amax history and scaling factors. + amax_from_history_x = ops.max(amax_history_x) + amax_from_history_k = ops.max(amax_history_k) + amax_from_history_g = ops.max(amax_history_g) + scale_x = quantizers.compute_float8_scale( + amax_from_history_x, scale_x, e4m3_max + ) + scale_k = quantizers.compute_float8_scale( + amax_from_history_k, scale_k, e4m3_max + ) + scale_g = quantizers.compute_float8_scale( + amax_from_history_g, scale_g, e5m2_max + ) + amax_history_x = quantizers.compute_float8_amax_history( + x, amax_history_x + ) + amax_history_k = quantizers.compute_float8_amax_history( + k, amax_history_k + ) + amax_history_g = quantizers.compute_float8_amax_history( + g, amax_history_g + ) + + train_one_step(x, g) + + self.assertAllClose(layer.inputs_amax_history, amax_history_x) + self.assertAllClose(layer.kernel_amax_history, amax_history_k) + self.assertAllClose(layer.outputs_grad_amax_history, amax_history_g) + self.assertAllClose(layer.inputs_scale, scale_x) + self.assertAllClose(layer.kernel_scale, scale_k) + self.assertAllClose(layer.outputs_grad_scale, scale_g) + + @pytest.mark.requires_trainable_backend + def test_quantize_float8_fitting(self): + config = dict( + equation="ab,bcd->acd", + output_shape=(8, 32), + bias_axes="d", + ) + layer = layers.EinsumDense(**config) + layer.build((None, 3)) + layer.quantize("float8") + self.assertLen(layer.trainable_weights, 8) + self.assertLen(layer.non_trainable_weights, 0) + + # Try calling fit() + x = np.random.random((64, 3)) + y = np.random.random((64, 8, 32)) + model = models.Sequential([layer]) + model.compile(optimizer="sgd", loss="mse") + model.fit(x, y, epochs=2) + + # Try saving and reloading the model + temp_filepath = os.path.join( + self.get_temp_dir(), "quantized_lora_model.keras" + ) + model.save(temp_filepath) + new_model = saving.load_model(temp_filepath) + self.assertAllClose(model.predict(x), new_model.predict(x)) + + # Try saving and reloading the model's weights only + temp_filepath = os.path.join( + self.get_temp_dir(), "quantized_lora_model.weights.h5" + ) + model.save_weights(temp_filepath) + new_model = models.Sequential([layers.EinsumDense(**config)]) + new_model.build((None, 3)) + new_model.quantize("float8") + new_model.load_weights(temp_filepath) + self.assertAllClose(model.predict(x), new_model.predict(x)) + + # Test export and TFSMLayer reloading when using tensorflow backend + if backend.backend() == "tensorflow": + import tensorflow as tf + + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + ref_input = tf.random.normal((2, 3)) + ref_output = model(ref_input) + model.export(temp_filepath, format="tf_saved_model") + reloaded_layer = export.TFSMLayer(temp_filepath) + self.assertAllClose(reloaded_layer(ref_input), ref_output) + self.assertLen(reloaded_layer.weights, len(model.weights)) + self.assertLen( + reloaded_layer.trainable_weights, len(model.trainable_weights) + ) + self.assertLen( + reloaded_layer.non_trainable_weights, + len(model.non_trainable_weights), + ) + + def test_quantize_float8_inference(self): + config = dict( + equation="ab,bcd->acd", + output_shape=(8, 32), + bias_axes="d", + ) + layer = layers.EinsumDense(**config) + layer.build((None, 3)) + layer.quantize("float8") + + # Try calling with `training=False` and the result must match + # `training=True` because there is no update. + x = np.random.random((64, 3)) + y_inference = layer(x, training=False) + y_training = layer(x, training=True) + self.assertAllClose(y_inference, y_training) + + def test_gptq_serialization(self): + """Test that a GPTQ-quantized layer can be serialized and deserialized + correctly.""" + config = dict( + equation="ab,bcd->acd", + output_shape=(8, 32), + bias_axes="d", + ) + layer = layers.EinsumDense(**config) + layer.build((None, 3)) + layer.quantize( + "gptq", + config=GPTQConfig( + dataset=None, tokenizer=None, weight_bits=4, group_size=8 + ), + ) + config = layer.get_config() + new_layer = layers.EinsumDense.from_config(config) + new_layer.build((None, 3)) + self.assertEqual(new_layer.quantization_mode, "gptq") + + def test_int4_kernel_returns_unpacked_form(self): + """Test that the `kernel` property returns the unpacked int4 kernel.""" + layer = layers.EinsumDense( + equation="ab,bc->ac", + output_shape=(2,), + ) + layer.build((None, 2)) + layer.quantize("int4") + packed_kernel = layer._kernel + self.assertAllClose( + layer.kernel, quantizers.unpack_int4(packed_kernel, 2) + ) + + def test_legacy_load_own_variables(self): + # In previous versions, `load_own_variables` accepted a store with + # numeric keys. + float32_store = { + "0": np.random.random((3, 8, 32)).astype("float32"), + "1": np.random.random((32,)).astype("float32"), + } + int8_store = { + "0": np.random.randint(-128, 127, size=(3, 8, 32), dtype="int8"), + "1": np.random.random((32,)).astype("float32"), + "2": np.random.random((1, 8, 32)).astype("float32"), + } + int4_store = { + "0": np.random.randint(-128, 127, size=(2, 8, 32), dtype="int8"), + "1": np.random.random((32,)).astype("float32"), + "2": np.random.random((1, 8, 32)).astype("float32"), + } + float8_store = { + "0": np.random.random((3, 8, 32)).astype("float32"), + "1": np.random.random((32,)).astype("float32"), + # inputs_scale. + "2": np.random.random(()).astype("float32"), + # inputs_amax_history. + "3": np.random.random((1024,)).astype("float32"), + # kernel_scale. + "4": np.random.random(()).astype("float32"), + # kernel_amax_history. + "5": np.random.random((1024,)).astype("float32"), + # outputs_grad_scale. + "6": np.random.random(()).astype("float32"), + # outputs_grad_amax_history. + "7": np.random.random((1024,)).astype("float32"), + } + gptq_store = { + # bias + "0": np.random.random((32,)).astype("float32"), + # quantized_kernel + "1": np.random.randint(0, 16, size=(16, 24), dtype="uint8"), + # kernel_scale. + "2": np.random.random((32, 3)).astype("float32"), + # kernel_zero + "3": np.random.random((32, 3)).astype("uint8"), + # g_idx + "4": np.random.random((24,)).astype("float32"), + } + config = dict( + equation="ab,bcd->acd", + output_shape=(8, 32), + bias_axes="d", + ) + + # Test float32 layer. + layer = layers.EinsumDense(**config) + layer.build((None, 3)) + layer.load_own_variables(float32_store) + self.assertAllClose(layer._kernel, float32_store["0"]) + self.assertAllClose(layer.bias, float32_store["1"]) + + # Test int8-quantized layer. + layer = layers.EinsumDense(**config, dtype="int8_from_float32") + layer.build((None, 3)) + layer.load_own_variables(int8_store) + self.assertAllClose(layer._kernel, int8_store["0"]) + self.assertAllClose(layer.bias, int8_store["1"]) + self.assertAllClose(layer.kernel_scale, int8_store["2"]) + + # Test int4-quantized layer. + layer = layers.EinsumDense(**config, dtype="int4_from_float32") + layer.build((None, 3)) + layer.load_own_variables(int4_store) + self.assertAllClose(layer._kernel, int4_store["0"]) + self.assertAllClose(layer.bias, int4_store["1"]) + self.assertAllClose(layer.kernel_scale, int4_store["2"]) + + # Test float8-quantized layer. + layer = layers.EinsumDense(**config, dtype="float8_from_float32") + layer.build((None, 3)) + layer.load_own_variables(float8_store) + self.assertAllClose(layer._kernel, float8_store["0"]) + self.assertAllClose(layer.bias, float8_store["1"]) + self.assertAllClose(layer.inputs_scale, float8_store["2"]) + self.assertAllClose(layer.inputs_amax_history, float8_store["3"]) + self.assertAllClose(layer.kernel_scale, float8_store["4"]) + self.assertAllClose(layer.kernel_amax_history, float8_store["5"]) + self.assertAllClose(layer.outputs_grad_scale, float8_store["6"]) + self.assertAllClose(layer.outputs_grad_amax_history, float8_store["7"]) + + # Test gptq-quantized layer. + layer = layers.EinsumDense(**config, dtype="gptq/4/8_from_float32") + layer.build((None, 3)) + layer.load_own_variables(gptq_store) + self.assertAllClose(layer.bias, gptq_store["0"]) + self.assertAllClose(layer.quantized_kernel, gptq_store["1"]) + self.assertAllClose(layer.kernel_scale, gptq_store["2"]) + self.assertAllClose(layer.kernel_zero, gptq_store["3"]) + self.assertAllClose(layer.g_idx, gptq_store["4"]) + + def test_int4_gptq_kernel_returns_unpacked_form(self): + """Test that the `kernel` property returns the unpacked int4 GPTQ + kernel.""" + layer = layers.EinsumDense( + equation="ab,bc->ac", + output_shape=(2,), + ) + layer.build((None, 2)) + layer.quantize( + "gptq", + config=GPTQConfig( + dataset=None, tokenizer=None, weight_bits=4, group_size=8 + ), + ) + layer.is_gptq_calibrated = True # Bypass calibration check + packed_kernel = layer.quantized_kernel + self.assertAllClose( + layer.kernel, quantizers.unpack_int4(packed_kernel, 2) + ) + + def test_gptq_kernel_packing(self): + """Validates that 4-bit GPTQ packing reduces the kernel size.""" + config = dict( + equation="ab,bcd->acd", + output_shape=(8, 32), + bias_axes="d", + ) + layer = layers.EinsumDense(**config) + layer.build((None, 3)) + + original_kernel_params = ops.prod(layer._kernel.shape) + + layer.quantize( + "gptq", + config=GPTQConfig( + dataset=None, tokenizer=None, weight_bits=4, group_size=8 + ), + ) + + quantized_kernel_params = ops.prod(layer.quantized_kernel.shape) + self.assertEqual( + quantized_kernel_params, + original_kernel_params // 2, + ) diff --git a/keras/src/layers/core/embedding.py b/keras/src/layers/core/embedding.py new file mode 100644 index 000000000000..aa809be63f34 --- /dev/null +++ b/keras/src/layers/core/embedding.py @@ -0,0 +1,579 @@ +import warnings + +from keras.src import backend +from keras.src import constraints +from keras.src import dtype_policies +from keras.src import initializers +from keras.src import ops +from keras.src import quantizers +from keras.src import regularizers +from keras.src.api_export import keras_export +from keras.src.backend import KerasTensor +from keras.src.layers.layer import Layer + + +@keras_export("keras.layers.Embedding") +class Embedding(Layer): + """Turns nonnegative integers (indexes) into dense vectors of fixed size. + + e.g. `[[4], [20]] -> [[0.25, 0.1], [0.6, -0.2]]` + + This layer can only be used on nonnegative integer inputs of a fixed range. + + Example: + + >>> model = keras.Sequential() + >>> model.add(keras.layers.Embedding(1000, 64)) + >>> # The model will take as input an integer matrix of size (batch, + >>> # input_length), and the largest integer (i.e. word index) in the input + >>> # should be no larger than 999 (vocabulary size). + >>> # Now model.output_shape is (None, 10, 64), where `None` is the batch + >>> # dimension. + >>> input_array = np.random.randint(1000, size=(32, 10)) + >>> model.compile('rmsprop', 'mse') + >>> output_array = model.predict(input_array) + >>> print(output_array.shape) + (32, 10, 64) + + Args: + input_dim: Integer. Size of the vocabulary, + i.e. maximum integer index + 1. + output_dim: Integer. Dimension of the dense embedding. + embeddings_initializer: Initializer for the `embeddings` + matrix (see `keras.initializers`). + embeddings_regularizer: Regularizer function applied to + the `embeddings` matrix (see `keras.regularizers`). + embeddings_constraint: Constraint function applied to + the `embeddings` matrix (see `keras.constraints`). + mask_zero: Boolean, whether or not the input value 0 is a special + "padding" value that should be masked out. + This is useful when using recurrent layers which + may take variable length input. If this is `True`, + then all subsequent layers in the model need + to support masking or an exception will be raised. + If `mask_zero` is set to `True`, as a consequence, + index 0 cannot be used in the vocabulary (`input_dim` should + equal size of vocabulary + 1). + weights: Optional floating-point matrix of size + `(input_dim, output_dim)`. The initial embeddings values + to use. + lora_rank: Optional integer. If set, the layer's forward pass + will implement LoRA (Low-Rank Adaptation) + with the provided rank. LoRA sets the layer's embeddings + matrix to non-trainable and replaces it with a delta over the + original matrix, obtained via multiplying two lower-rank + trainable matrices. This can be useful to reduce the + computation cost of fine-tuning large embedding layers. + You can also enable LoRA on an existing + `Embedding` layer by calling `layer.enable_lora(rank)`. + lora_alpha: Optional integer. If set, this parameter scales the + low-rank adaptation delta (computed as the product of two lower-rank + trainable matrices) during the forward pass. The delta is scaled by + `lora_alpha / lora_rank`, allowing you to fine-tune the strength of + the LoRA adjustment independently of `lora_rank`. + + Input shape: + 2D tensor with shape: `(batch_size, input_length)`. + + Output shape: + 3D tensor with shape: `(batch_size, input_length, output_dim)`. + """ + + def __init__( + self, + input_dim, + output_dim, + embeddings_initializer="uniform", + embeddings_regularizer=None, + embeddings_constraint=None, + mask_zero=False, + weights=None, + lora_rank=None, + lora_alpha=None, + **kwargs, + ): + input_length = kwargs.pop("input_length", None) + if input_length is not None: + warnings.warn( + "Argument `input_length` is deprecated. Just remove it." + ) + super().__init__(**kwargs) + self.input_dim = input_dim + self.output_dim = output_dim + self.embeddings_initializer = initializers.get(embeddings_initializer) + self.embeddings_regularizer = regularizers.get(embeddings_regularizer) + self.embeddings_constraint = constraints.get(embeddings_constraint) + self.mask_zero = mask_zero + self.supports_masking = mask_zero + self.autocast = False + self.lora_rank = lora_rank + self.lora_alpha = lora_alpha if lora_alpha is not None else lora_rank + self.lora_enabled = False + + if weights is not None: + self.build() + if not (isinstance(weights, list) and len(weights) == 1): + weights = [weights] + self.set_weights(weights) + + def build(self, input_shape=None): + if self.built: + return + embeddings_shape = (self.input_dim, self.output_dim) + if self.quantization_mode: + self.quantized_build(embeddings_shape, mode=self.quantization_mode) + if self.quantization_mode not in ("int8", "int4"): + self._embeddings = self.add_weight( + shape=embeddings_shape, + initializer=self.embeddings_initializer, + name="embeddings", + regularizer=self.embeddings_regularizer, + constraint=self.embeddings_constraint, + trainable=True, + ) + self.built = True + if self.lora_rank: + self.enable_lora(self.lora_rank) + + @property + def embeddings(self): + if not self.built: + raise AttributeError( + "You must build the layer before accessing `embeddings`." + ) + embeddings = self._embeddings + if self.quantization_mode == "int4": + embeddings = quantizers.unpack_int4( + embeddings, self._orig_output_dim, axis=-1 + ) + if self.lora_enabled: + return embeddings + (self.lora_alpha / self.lora_rank) * ops.matmul( + self.lora_embeddings_a, self.lora_embeddings_b + ) + return embeddings + + def call(self, inputs): + if inputs.dtype != "int32" and inputs.dtype != "int64": + inputs = ops.cast(inputs, "int32") + outputs = ops.take(self.embeddings, inputs, axis=0) + return ops.cast(outputs, dtype=self.compute_dtype) + + def compute_mask(self, inputs, mask=None): + if not self.mask_zero: + return None + return ops.not_equal(inputs, 0) + + def compute_output_shape(self, input_shape): + return (*input_shape, self.output_dim) + + def compute_output_spec(self, inputs): + output_shape = self.compute_output_shape(inputs.shape) + ragged = getattr(inputs, "ragged", False) + return KerasTensor( + output_shape, dtype=self.compute_dtype, ragged=ragged + ) + + def enable_lora( + self, + rank, + lora_alpha=None, + a_initializer="he_uniform", + b_initializer="zeros", + ): + if self.embeddings_constraint: + raise ValueError( + "Lora is incompatible with embedding constraints. " + "In order to enable lora on this layer, remove the " + "`embeddings_constraint` argument." + ) + if not self.built: + raise ValueError( + "Cannot enable lora on a layer that isn't yet built." + ) + if self.lora_enabled: + raise ValueError( + "lora is already enabled. This can only be done once per layer." + ) + self._tracker.unlock() + self.lora_embeddings_a = self.add_weight( + name="lora_embeddings_a", + shape=(self.input_dim, rank), + initializer=initializers.get(a_initializer), + regularizer=self.embeddings_regularizer, + ) + self.lora_embeddings_b = self.add_weight( + name="lora_embeddings_b", + shape=(rank, self.output_dim), + initializer=initializers.get(b_initializer), + regularizer=self.embeddings_regularizer, + ) + self.embeddings.trainable = False + self._tracker.lock() + self.lora_enabled = True + self.lora_rank = rank + self.lora_alpha = lora_alpha if lora_alpha is not None else rank + + def save_own_variables(self, store): + # Do nothing if the layer isn't yet built + if not self.built: + return + mode = self.quantization_mode + if mode not in self.quantization_variable_spec: + raise self._quantization_mode_error(mode) + + # Embeddings plus optional merged LoRA-aware scale + # (returns (kernel, None) for None/gptq). + embeddings_value, merged_kernel_scale = ( + self._get_embeddings_with_merged_lora() + ) + + # Save the variables using the name as the key. + store["embeddings"] = embeddings_value + for name in self.quantization_variable_spec[mode]: + if name == "embeddings_scale" and mode in ("int4", "int8"): + # For int4/int8, the merged LoRA scale (if any) comes from + # `_get_embeddings_with_merged_lora()` + store[name] = merged_kernel_scale + else: + store[name] = getattr(self, name) + + def load_own_variables(self, store): + if not self.lora_enabled: + self._check_load_own_variables(store) + # Do nothing if the layer isn't yet built + if not self.built: + return + mode = self.quantization_mode + if mode not in self.quantization_variable_spec: + raise self._quantization_mode_error(mode) + + # Determine whether to use the legacy loading method. + if "0" in store: + return self._legacy_load_own_variables(store) + + # Load the variables using the name as the key. + self._embeddings.assign(store["embeddings"]) + for name in self.quantization_variable_spec[mode]: + getattr(self, name).assign(store[name]) + if self.lora_enabled: + self.lora_embeddings_a.assign( + ops.zeros(self.lora_embeddings_a.shape) + ) + self.lora_embeddings_b.assign( + ops.zeros(self.lora_embeddings_b.shape) + ) + + def _legacy_load_own_variables(self, store): + # The keys of the `store` will be saved as determined because the + # default ordering will change after quantization + mode = self.quantization_mode + targets = [self._embeddings] + targets.extend( + getattr(self, name) + for name in self.quantization_variable_spec[mode] + ) + for i, variable in enumerate(targets): + variable.assign(store[str(i)]) + if self.lora_enabled: + self.lora_embeddings_a.assign( + ops.zeros(self.lora_embeddings_a.shape) + ) + self.lora_embeddings_b.assign( + ops.zeros(self.lora_embeddings_b.shape) + ) + + def get_config(self): + base_config = super().get_config() + config = { + "input_dim": self.input_dim, + "output_dim": self.output_dim, + "embeddings_initializer": initializers.serialize( + self.embeddings_initializer + ), + "embeddings_regularizer": regularizers.serialize( + self.embeddings_regularizer + ), + "activity_regularizer": regularizers.serialize( + self.activity_regularizer + ), + "embeddings_constraint": constraints.serialize( + self.embeddings_constraint + ), + "mask_zero": self.mask_zero, + } + if self.lora_rank: + config["lora_rank"] = self.lora_rank + config["lora_alpha"] = self.lora_alpha + return {**base_config, **config} + + def _check_load_own_variables(self, store): + all_vars = self._trainable_variables + self._non_trainable_variables + if len(store.keys()) != len(all_vars): + if len(all_vars) == 0 and not self.built: + raise ValueError( + f"Layer '{self.name}' was never built " + "and thus it doesn't have any variables. " + f"However the weights file lists {len(store.keys())} " + "variables for this layer.\n" + "In most cases, this error indicates that either:\n\n" + "1. The layer is owned by a parent layer that " + "implements a `build()` method, but calling the " + "parent's `build()` method did NOT create the state of " + f"the child layer '{self.name}'. A `build()` method " + "must create ALL state for the layer, including " + "the state of any children layers.\n\n" + "2. You need to implement " + "the `def build_from_config(self, config)` method " + f"on layer '{self.name}', to specify how to rebuild " + "it during loading. " + "In this case, you might also want to implement the " + "method that generates the build config at saving time, " + "`def get_build_config(self)`. " + "The method `build_from_config()` is meant " + "to create the state " + "of the layer (i.e. its variables) upon deserialization.", + ) + raise ValueError( + f"Layer '{self.name}' expected {len(all_vars)} variables, " + "but received " + f"{len(store.keys())} variables during loading. " + f"Expected: {[v.name for v in all_vars]}" + ) + + def _quantization_mode_error(self, mode): + return NotImplementedError( + "Invalid quantization mode. Expected one of ('int8', 'int4'). " + f"Received: quantization_mode={mode}" + ) + + @property + def quantization_variable_spec(self): + """Returns a dict mapping quantization modes to variable names. + + This spec is used by `save_own_variables` and `load_own_variables` to + determine which variables should be saved/loaded for each quantization + mode. + """ + return { + None: [], + "int8": ["embeddings_scale"], + "int4": ["embeddings_scale"], + } + + def quantized_build(self, embeddings_shape, mode): + if mode == "int8": + self._int8_build(embeddings_shape) + elif mode == "int4": + self._int4_build(embeddings_shape) + else: + raise self._quantization_mode_error(mode) + self._is_quantized = True + + def _int8_build(self, embeddings_shape): + self._embeddings = self.add_weight( + name="embeddings", + shape=embeddings_shape, + initializer="zeros", + dtype="int8", + trainable=False, + ) + # We choose to reduce the axis of `output_dim` because, typically, + # `input_dim` is larger than `output_dim`. This reduces quantization + # error. + self.embeddings_scale = self.add_weight( + name="embeddings_scale", + shape=(self.input_dim,), + initializer="ones", + trainable=False, + ) + + def _int4_build(self, embeddings_shape): + input_dim, output_dim = embeddings_shape + packed_rows = (output_dim + 1) // 2 # ceil for odd dims + + # Embeddings are stored *packed*: each int8 byte contains two int4 + # values. + self._embeddings = self.add_weight( + name="embeddings", + shape=(input_dim, packed_rows), + initializer="zeros", + dtype="int8", + trainable=False, + ) + self.embeddings_scale = self.add_weight( + name="embeddings_scale", + shape=(self.input_dim,), + initializer="ones", + trainable=False, + ) + # Record original output_dim for unpacking at runtime. + self._orig_output_dim = output_dim + + def _int8_call(self, inputs, training=None): + # We cannot update quantized self._embeddings, so the custom gradient is + # not needed + if backend.standardize_dtype(inputs.dtype) not in ("int32", "int64"): + inputs = ops.cast(inputs, "int32") + embeddings_scale = ops.take(self.embeddings_scale, inputs, axis=0) + outputs = ops.take(self._embeddings, inputs, axis=0) + # De-scale outputs + outputs = ops.divide( + ops.cast(outputs, dtype=self.compute_dtype), + ops.expand_dims(embeddings_scale, axis=-1), + ) + if self.lora_enabled: + lora_outputs = ops.take(self.lora_embeddings_a, inputs, axis=0) + lora_outputs = ops.matmul(lora_outputs, self.lora_embeddings_b) + outputs = ops.add( + outputs, (self.lora_alpha / self.lora_rank) * lora_outputs + ) + return outputs + + def _int4_call(self, inputs, training=None): + # We cannot update quantized self._embeddings, so the custom gradient is + # not needed + if backend.standardize_dtype(inputs.dtype) not in ("int32", "int64"): + inputs = ops.cast(inputs, "int32") + embeddings_scale = ops.take(self.embeddings_scale, inputs, axis=0) + unpacked_embeddings = quantizers.unpack_int4( + self._embeddings, self._orig_output_dim, axis=-1 + ) + outputs = ops.take(unpacked_embeddings, inputs, axis=0) + # De-scale outputs + outputs = ops.divide( + ops.cast(outputs, dtype=self.compute_dtype), + ops.expand_dims(embeddings_scale, axis=-1), + ) + if self.lora_enabled: + lora_outputs = ops.take(self.lora_embeddings_a, inputs, axis=0) + lora_outputs = ops.matmul(lora_outputs, self.lora_embeddings_b) + outputs = ops.add( + outputs, (self.lora_alpha / self.lora_rank) * lora_outputs + ) + return outputs + + def quantize(self, mode, type_check=True, config=None): + # Prevent quantization of the subclasses. + if type_check and (type(self) is not Embedding): + raise self._not_implemented_error(self.quantize) + + embeddings_shape = (self.input_dim, self.output_dim) + if mode == "int8": + # Quantize `self._embeddings` to int8 and compute corresponding + # scale. + embeddings_value, embeddings_scale = quantizers.abs_max_quantize( + self._embeddings, axis=-1, to_numpy=True + ) + embeddings_scale = ops.squeeze(embeddings_scale, axis=-1) + del self._embeddings + self.quantized_build(embeddings_shape, mode) + self._embeddings.assign(embeddings_value) + self.embeddings_scale.assign(embeddings_scale) + elif mode == "int4": + # Quantize to int4 values (stored in int8 dtype, range [-8, 7]). + embeddings_value, embeddings_scale = quantizers.abs_max_quantize( + self._embeddings, + axis=-1, + value_range=(-8, 7), + dtype="int8", + to_numpy=True, + ) + embeddings_scale = ops.squeeze(embeddings_scale, axis=-1) + # 2. Pack two int4 values into a single int8 byte. + packed_embeddings_value, _, _ = quantizers.pack_int4( + embeddings_value, axis=-1 + ) + del self._embeddings + self.quantized_build(embeddings_shape, mode) + self._embeddings.assign(packed_embeddings_value) + self.embeddings_scale.assign(embeddings_scale) + else: + raise self._quantization_mode_error(mode) + + # Set new dtype policy. + if self.dtype_policy.quantization_mode is None: + policy = dtype_policies.get(f"{mode}_from_{self.dtype_policy.name}") + self.dtype_policy = policy + + def _get_embeddings_with_merged_lora(self): + """Returns the embeddings with LoRA matrices merged, for serialization. + + This method is called by `save_own_variables` to produce a single + embeddings tensor that includes the adaptations from LoRA. This is + useful for deploying the model or for continuing training after + permanently applying the LoRA update. + + If the layer is quantized (`int8` or `int4`), the process is: + 1. Dequantize the base embeddings to float. + 2. Compute the LoRA delta (`lora_embeddings_a @ lora_embeddings_b`) and + add it to the dequantized embeddings. + 3. Re-quantize the merged result back to the original quantized + type (`int8` or packed `int4`), calculating a new scale factor. + + If the layer is not quantized, this method returns the result of the + `embeddings` property (which computes the merge in floating-point) and a + scale of `None`. + + If LoRA is not enabled, it returns the original embeddings and scale + without modification. + + Returns: + A tuple `(embeddings_value, embeddings_scale)`: + `embeddings_value`: The merged embeddings. A quantized tensor if + quantization is active, otherwise a high precision tensor. + `embeddings_scale`: The quantization scale for the merged + embeddings. This is `None` if the layer is not quantized. + """ + if self.dtype_policy.quantization_mode in (None, "gptq"): + return self.embeddings, None + + embeddings_value = self._embeddings + embeddings_scale = self.embeddings_scale + if not self.lora_enabled: + return embeddings_value, embeddings_scale + + # Dequantize embeddings to float. + if self.quantization_mode == "int4": + unpacked_embeddings = quantizers.unpack_int4( + embeddings_value, self._orig_output_dim, axis=-1 + ) + float_embeddings = ops.divide( + ops.cast(unpacked_embeddings, self.compute_dtype), + ops.expand_dims(embeddings_scale, axis=-1), + ) + quant_range = (-8, 7) + elif self.quantization_mode == "int8": + float_embeddings = ops.divide( + ops.cast(embeddings_value, self.compute_dtype), + ops.expand_dims(embeddings_scale, axis=-1), + ) + quant_range = (-127, 127) + else: + raise ValueError( + f"Unsupported quantization mode: {self.quantization_mode}" + ) + + # Merge LoRA weights in float domain. + lora_delta = (self.lora_alpha / self.lora_rank) * ops.matmul( + self.lora_embeddings_a, self.lora_embeddings_b + ) + merged_float_embeddings = ops.add(float_embeddings, lora_delta) + + # Requantize. + requantized_embeddings, embeddings_scale = quantizers.abs_max_quantize( + merged_float_embeddings, + axis=-1, + value_range=quant_range, + dtype="int8", + to_numpy=True, + ) + embeddings_scale = ops.squeeze(embeddings_scale, axis=-1) + + # Pack if int4. + if self.quantization_mode == "int4": + embeddings_value, _, _ = quantizers.pack_int4( + requantized_embeddings, axis=-1 + ) + else: + embeddings_value = requantized_embeddings + return embeddings_value, embeddings_scale diff --git a/keras/src/layers/core/embedding_test.py b/keras/src/layers/core/embedding_test.py new file mode 100644 index 000000000000..a22cab911caa --- /dev/null +++ b/keras/src/layers/core/embedding_test.py @@ -0,0 +1,585 @@ +import os + +import numpy as np +import pytest +from absl.testing import parameterized + +from keras.src import backend +from keras.src import constraints +from keras.src import export +from keras.src import layers +from keras.src import models +from keras.src import ops +from keras.src import quantizers +from keras.src import saving +from keras.src.testing import test_case + + +class EmbeddingTest(test_case.TestCase): + @pytest.mark.requires_trainable_backend + def test_embedding_basics(self): + self.run_layer_test( + layers.Embedding, + {"input_dim": 4, "output_dim": 3}, + input_shape=(2,), + input_dtype="int32", + expected_output_shape=(2, 3), + expected_num_trainable_weights=1, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=False, + ) + self.run_layer_test( + layers.Embedding, + {"input_dim": 5, "output_dim": 4, "mask_zero": True}, + input_shape=(2, 3), + input_dtype="int64", + expected_output_shape=(2, 3, 4), + expected_num_trainable_weights=1, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=True, + ) + + @pytest.mark.skipif( + not backend.SUPPORTS_SPARSE_TENSORS, + reason="Backend does not support sparse tensors.", + ) + def test_sparse(self): + self.run_layer_test( + layers.Embedding, + {"input_dim": 5, "output_dim": 4}, + input_shape=(2, 3), + input_dtype="int32", + input_sparse=True, + expected_output_shape=(2, 3, 4), + expected_num_trainable_weights=1, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=False, + ) + + @pytest.mark.skipif( + not backend.SUPPORTS_RAGGED_TENSORS, + reason="Backend does not support ragged tensors.", + ) + def test_ragged(self): + self.run_layer_test( + layers.Embedding, + {"input_dim": 5, "output_dim": 4}, + input_shape=(2, 3), + input_dtype="int32", + input_ragged=True, + expected_output_shape=(2, None, 4), + expected_output_ragged=True, + expected_num_trainable_weights=1, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=False, + # run_training_check=False, + ) + + def test_correctness(self): + layer = layers.Embedding(input_dim=3, output_dim=2) + layer.build() + layer.embeddings.assign(np.array([[0.0, 0.0], [2.0, 2.0], [3.0, 3.0]])) + out = layer(np.array([2, 1, 0])) + self.assertAllClose(out, np.array([[3.0, 3.0], [2.0, 2.0], [0.0, 0.0]])) + + @pytest.mark.skipif( + not backend.SUPPORTS_SPARSE_TENSORS, + reason="Backend does not support sparse tensors.", + ) + def test_correctness_sparse(self): + layer = layers.Embedding(input_dim=3, output_dim=2) + layer.build() + layer.embeddings.assign(np.array([[0.0, 0.0], [2.0, 2.0], [3.0, 3.0]])) + + if backend.backend() == "tensorflow": + import tensorflow as tf + + x = tf.SparseTensor([[0, 0], [1, 2]], [2, 1], (2, 3)) + elif backend.backend() == "jax": + import jax.experimental.sparse as jax_sparse + + x = jax_sparse.BCOO(([2, 1], [[0, 0], [1, 2]]), shape=(2, 3)) + else: + self.fail(f"Sparse is unsupported with backend {backend.backend()}") + + self.assertAllClose( + layer(x), + np.array( + [ + [[3.0, 3.0], [0.0, 0.0], [0.0, 0.0]], + [[0.0, 0.0], [0.0, 0.0], [2.0, 2.0]], + ] + ), + ) + + def test_masking(self): + layer = layers.Embedding(input_dim=3, output_dim=2, mask_zero=True) + layer.build() + out = layer.compute_mask(np.array(([2, 1, 0]))) + self.assertAllClose(out, np.array([True, True, False])) + + def test_compute_mask_no_masking(self): + layer = layers.Embedding(input_dim=3, output_dim=2, mask_zero=False) + input_data = np.array([2, 1, 0]) + mask = layer.compute_mask(input_data) + self.assertIsNone(mask) + + def test_embedding_constraints(self): + layer = layers.Embedding(3, 2, embeddings_constraint="non_neg") + layer.build((None, 2)) + self.assertIsInstance(layer.embeddings.constraint, constraints.NonNeg) + + def test_weights_constructor_arg(self): + layer = layers.Embedding(3, 4, weights=np.ones((3, 4))) + self.assertAllClose(layer.embeddings.numpy(), np.ones((3, 4))) + layer = layers.Embedding(3, 4, weights=[np.ones((3, 4))]) + self.assertAllClose(layer.embeddings.numpy(), np.ones((3, 4))) + + @pytest.mark.requires_trainable_backend + def test_enable_lora(self): + layer = layers.Embedding(10, 16) + layer.build() + layer.enable_lora(4) + self.assertLen(layer.trainable_weights, 2) + self.assertLen(layer.non_trainable_weights, 1) + if backend.backend() == "torch": + self.assertLen(layer.torch_params, 3) + # Try eager call + x = np.random.randint(0, 9, size=(64, 3)) + y = np.random.random((64, 3, 16)) + _ = layer(x[:2]) + + init_lora_a_embeddings_value = layer.lora_embeddings_a.numpy() + init_lora_b_embeddings_value = layer.lora_embeddings_b.numpy() + + # Try calling fit() + model = models.Sequential( + [ + layer, + ] + ) + model.compile(optimizer="sgd", loss="mse") + model.fit(x, y) + + final_lora_a_embeddings_value = layer.lora_embeddings_a.numpy() + final_lora_b_embeddings_value = layer.lora_embeddings_b.numpy() + diff_a = np.max( + np.abs(init_lora_a_embeddings_value - final_lora_a_embeddings_value) + ) + diff_b = np.max( + np.abs(init_lora_b_embeddings_value - final_lora_b_embeddings_value) + ) + self.assertGreater(diff_a, 0.0) + self.assertGreater(diff_b, 0.0) + + # Try saving and reloading the model + temp_filepath = os.path.join(self.get_temp_dir(), "lora_model.keras") + model.save(temp_filepath) + + new_model = saving.load_model(temp_filepath) + self.assertTrue(new_model.layers[0].lora_enabled) + self.assertAllClose(model.predict(x), new_model.predict(x)) + + # Try saving and reloading the model's weights only + temp_filepath = os.path.join( + self.get_temp_dir(), "lora_model.weights.h5" + ) + model.save_weights(temp_filepath) + + # Load the file into a fresh, non-lora model + new_model = models.Sequential( + [ + layers.Input((3,), dtype="int32"), + layers.Embedding(10, 16), + ] + ) + new_model.load_weights(temp_filepath) + self.assertAllClose(model.predict(x), new_model.predict(x)) + + # Try loading a normal checkpoint into a lora model + new_model.save_weights(temp_filepath) + model.load_weights(temp_filepath) + self.assertAllClose(model.predict(x), new_model.predict(x)) + + @pytest.mark.requires_trainable_backend + def test_enable_lora_with_alpha(self): + # Create an `Embedding` layer without specifying `lora_rank` + layer = layers.Embedding(input_dim=3, output_dim=2) + layer.build((None,)) # Build the layer + + # Set the base embeddings to known values. + base_emb = np.array( + [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]], dtype=np.float32 + ) + layer.embeddings.assign(base_emb) + + # Enable LoRA with a custom alpha: `rank`=2, `lora_alpha`=3.0. + layer.enable_lora(2, lora_alpha=3.0) + self.assertEqual(layer.lora_rank, 2) + self.assertEqual(layer.lora_alpha, 3.0) + + # Manually assign known values to lora weights. + a_val = np.array([[0.1, 0.1], [0.2, 0.2], [0.3, 0.3]], dtype=np.float32) + b_val = np.array([[0.5, 0.5], [0.6, 0.6]], dtype=np.float32) + layer.lora_embeddings_a.assign(a_val) + layer.lora_embeddings_b.assign(b_val) + + # Compute the expected delta. + # Scaling factor: (3.0 / 2) = 1.5 + effective_delta = 1.5 * np.matmul(a_val, b_val) + expected_embeddings = base_emb + effective_delta + + # Verify that the effective embeddings match expectation. + actual_embeddings = ops.convert_to_numpy(layer.embeddings) + self.assertAllClose(actual_embeddings, expected_embeddings) + + @pytest.mark.requires_trainable_backend + def test_lora_rank_argument(self): + self.run_layer_test( + layers.Embedding, + init_kwargs={"input_dim": 5, "output_dim": 4, "lora_rank": 2}, + input_shape=(2, 3), + input_dtype="int32", + expected_output_shape=(2, 3, 4), + expected_num_trainable_weights=2, + expected_num_non_trainable_weights=1, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=False, + ) + + def test_enable_lora_with_embeddings_constraint(self): + layer = layers.Embedding( + input_dim=10, output_dim=16, embeddings_constraint="max_norm" + ) + with self.assertRaisesRegex( + ValueError, "incompatible with embedding constraints" + ): + layer.enable_lora(rank=2) + + def test_enable_lora_when_already_enabled(self): + layer = layers.Embedding(input_dim=10, output_dim=16) + layer.build() + layer.enable_lora(rank=2) + with self.assertRaisesRegex(ValueError, "lora is already enabled"): + layer.enable_lora(rank=2) + + # Test quantization-related methods. + + @parameterized.named_parameters( + ("int8", "int8"), + ("int4", "int4"), + ) + def test_quantize_int(self, mode): + layer = layers.Embedding(10, 16) + layer.build() + x = np.random.randint(0, 9, size=(64, 3)) + y_float = layer(x) + layer.quantize(mode) + + # Verify the dtype of the weights. + # The embeddings's dtype is int8, despite the int4 quantization, because + # we pack the int4 values into int8. + self.assertEqual( + backend.standardize_dtype(layer._embeddings.dtype), "int8" + ) + self.assertEqual( + backend.standardize_dtype(layer.embeddings_scale.dtype), + layer.variable_dtype, + ) + + # Verify the unpacked embeddings for int4 quantization. + if mode == "int4": + self.assertAllClose( + layer.embeddings, + quantizers.unpack_int4( + layer._embeddings, layer.output_dim, axis=-1 + ), + ) + + # Verify the correctness of the outputs. + y_quantized = layer(x) + mse = ops.mean(ops.square(y_float - y_quantized)) + self.assertLess(mse, 1e-3) # A weak correctness test + + # Check model save / load round-trip. + model = models.Sequential([layer]) + temp_filepath = os.path.join( + self.get_temp_dir(), "quantized_model.keras" + ) + model.save(temp_filepath) + new_model = saving.load_model(temp_filepath) + self.assertAllClose(model.predict(x), new_model.predict(x)) + + # Check weights-only save / load round-trip. + temp_filepath = os.path.join( + self.get_temp_dir(), "quantized_model.weights.h5" + ) + model.save_weights(temp_filepath) + new_model = models.Sequential([layers.Embedding(10, 16)]) + new_model.build((None, 3)) + new_model.quantize(mode) + new_model.load_weights(temp_filepath) + self.assertAllClose(model.predict(x), new_model.predict(x)) + + @parameterized.named_parameters( + ("int8", "int8"), + ("int4", "int4"), + ) + def test_quantize_on_unbuilt_layer(self, mode): + layer = layers.Embedding(10, 16) + with self.assertRaisesRegex( + ValueError, "Cannot quantize a layer that isn't yet built." + ): + layer.quantize(mode) + + @parameterized.named_parameters( + ("int8", "int8"), + ("int4", "int4"), + ) + def test_quantize_on_subclass(self, mode): + class MyEmbedding(layers.Embedding): + pass + + layer = MyEmbedding(10, 16) + layer.build() + with self.assertRaises(NotImplementedError): + layer.quantize(mode) + + layer.quantize(mode, type_check=False) # No error + + @parameterized.named_parameters( + ("int8", "int8"), + ("int4", "int4"), + ) + def test_quantize_when_already_quantized(self, mode): + layer = layers.Embedding(10, 16) + layer.build() + layer.quantize(mode) + for m in ("int8", "int4"): + with self.assertRaisesRegex( + ValueError, "is already quantized with dtype_policy=" + ): + layer.quantize(m) + + layer = layers.Embedding(10, 16, dtype=f"{mode}_from_float32") + layer.build() + for m in ("int8", "int4"): + with self.assertRaisesRegex( + ValueError, "is already quantized with dtype_policy=" + ): + layer.quantize(m) + + @parameterized.named_parameters( + ("int8", "int8_from_float32", 2), + ("int4", "int4_from_float32", 2), + ) + def test_quantize_by_setting_dtype_policy( + self, policy, expected_num_variables + ): + layer = layers.Embedding(10, 16) + layer.build() + layer.dtype_policy = policy + self.assertLen(layer.variables, expected_num_variables) + + @parameterized.named_parameters( + ("int7", "int7"), + ("float7", "float7"), + ) + def test_quantize_invalid_mode(self, mode): + layer = layers.Embedding(10, 16) + layer.build() + x = np.random.randint(0, 9, size=(1, 3)) + # dtype_policy should not be altered by failed quantization + original_dtype_policy = layer.dtype_policy + + # Test quantize + with self.assertRaisesRegex(ValueError, "Invalid quantization mode."): + layer.quantize(mode) + self.assertEqual(layer.dtype_policy, original_dtype_policy) + + # Test quantized_build + with self.assertRaisesRegex( + NotImplementedError, "Invalid quantization mode." + ): + layer.quantized_build((None, 2), mode) + self.assertEqual(layer.dtype_policy, original_dtype_policy) + + # Test quantized_call + with self.assertRaisesRegex( + NotImplementedError, "Invalid quantization mode." + ): + # Explicitly set quantization_mode + layer._dtype_policy._quantization_mode = mode + layer.quantized_call(x) + self.assertEqual(layer.dtype_policy, original_dtype_policy) + + @parameterized.named_parameters( + ("int8", "int8_from_mixed_bfloat16", 0, 2), + ("int4", "int4_from_mixed_bfloat16", 0, 2), + ) + @pytest.mark.requires_trainable_backend + def test_quantize_dtype_argument( + self, dtype, num_trainable_weights, num_non_trainable_weights + ): + self.run_layer_test( + layers.Embedding, + {"input_dim": 4, "output_dim": 3, "dtype": dtype}, + input_shape=(2,), + input_dtype="int32", + expected_output_shape=(2, 3), + expected_num_trainable_weights=num_trainable_weights, + expected_num_non_trainable_weights=num_non_trainable_weights, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=False, + ) + self.run_layer_test( + layers.Embedding, + { + "input_dim": 5, + "output_dim": 4, + "mask_zero": True, + "dtype": dtype, + }, + input_shape=(2, 3), + input_dtype="int64", + expected_output_shape=(2, 3, 4), + expected_num_trainable_weights=num_trainable_weights, + expected_num_non_trainable_weights=num_non_trainable_weights, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=True, + ) + + @parameterized.named_parameters( + ("int8", "int8", 2, 2, 4), + ("int4", "int4", 2, 2, 4), + ) + @pytest.mark.requires_trainable_backend + def test_quantize_lora_integration( + self, + mode, + num_trainable_weights, + num_non_trainable_weights, + num_torch_params, + ): + layer = layers.Embedding(10, 16) + layer.build() + layer.enable_lora(4) + layer.quantize(mode) + self.assertLen(layer.trainable_weights, num_trainable_weights) + self.assertLen(layer.non_trainable_weights, num_non_trainable_weights) + if backend.backend() == "torch": + self.assertLen(layer.torch_params, num_torch_params) + + # Try calling fit() + init_lora_a_embeddings_value = layer.lora_embeddings_a.numpy() + init_lora_b_embeddings_value = layer.lora_embeddings_b.numpy() + x = np.random.randint(0, 9, size=(64, 3)) + y = np.random.random((64, 3, 16)) + model = models.Sequential([layer]) + model.compile(optimizer="sgd", loss="mse") + model.fit(x, y) + + final_lora_a_embeddings_value = layer.lora_embeddings_a.numpy() + final_lora_b_embeddings_value = layer.lora_embeddings_b.numpy() + diff_a = np.max( + np.abs(init_lora_a_embeddings_value - final_lora_a_embeddings_value) + ) + diff_b = np.max( + np.abs(init_lora_b_embeddings_value - final_lora_b_embeddings_value) + ) + self.assertGreater(diff_a, 0.0) + self.assertGreater(diff_b, 0.0) + + # Try saving and reloading the model + temp_filepath = os.path.join( + self.get_temp_dir(), "quantized_lora_model.keras" + ) + model.save(temp_filepath) + new_model = saving.load_model(temp_filepath) + self.assertTrue(new_model.layers[0].lora_enabled) + self.assertAllClose(model.predict(x), new_model.predict(x), atol=0.5) + + # Try saving and reloading the model's weights only + temp_filepath = os.path.join( + self.get_temp_dir(), "quantized_lora_model.weights.h5" + ) + model.save_weights(temp_filepath) + new_model = models.Sequential( + [layers.Input((3,), dtype="int32"), layers.Embedding(10, 16)] + ) + new_model.quantize(mode) + new_model.load_weights(temp_filepath) + self.assertFalse(new_model.layers[0].lora_enabled) + self.assertAllClose(model.predict(x), new_model.predict(x), atol=0.5) + + # Try loading a normal checkpoint into a lora model + new_model.save_weights(temp_filepath) + model.load_weights(temp_filepath) + self.assertAllClose(model.predict(x), new_model.predict(x), atol=0.5) + + # Test export and TFSMLayer reloading when using tensorflow backend + if backend.backend() == "tensorflow": + import tensorflow as tf + + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + ref_input = tf.random.normal((32, 3)) + ref_output = model(ref_input) + model.export(temp_filepath, format="tf_saved_model") + reloaded_layer = export.TFSMLayer(temp_filepath) + self.assertAllClose( + reloaded_layer(ref_input), ref_output, atol=1e-7 + ) + self.assertLen(reloaded_layer.weights, len(model.weights)) + self.assertLen( + reloaded_layer.trainable_weights, len(model.trainable_weights) + ) + self.assertLen( + reloaded_layer.non_trainable_weights, + len(model.non_trainable_weights), + ) + + def test_legacy_load_own_variables(self): + # In previous versions, `load_own_variables` accepted a store with + # numeric keys. + float32_store = { + "0": np.random.random((10, 16)).astype("float32"), + } + int8_store = { + "0": np.random.randint(-128, 127, size=(10, 16), dtype="int8"), + "1": np.random.random((10,)).astype("float32"), + } + int4_store = { + "0": np.random.randint(-128, 127, size=(10, 8), dtype="int8"), + "1": np.random.random((10,)).astype("float32"), + } + + # Test float32 layer. + layer = layers.Embedding(10, 16) + layer.build() + layer.load_own_variables(float32_store) + self.assertAllClose(layer._embeddings, float32_store["0"]) + + # Test int8-quantized layer. + layer = layers.Embedding(10, 16, dtype="int8_from_float32") + layer.build() + layer.load_own_variables(int8_store) + self.assertAllClose(layer._embeddings, int8_store["0"]) + self.assertAllClose(layer.embeddings_scale, int8_store["1"]) + + # Test int4-quantized layer. + layer = layers.Embedding(10, 16, dtype="int4_from_float32") + layer.build() + layer.load_own_variables(int4_store) + self.assertAllClose(layer._embeddings, int4_store["0"]) + self.assertAllClose(layer.embeddings_scale, int4_store["1"]) diff --git a/keras/src/layers/core/identity.py b/keras/src/layers/core/identity.py new file mode 100644 index 000000000000..206835831bcd --- /dev/null +++ b/keras/src/layers/core/identity.py @@ -0,0 +1,31 @@ +from keras.src import tree +from keras.src.api_export import keras_export +from keras.src.backend import KerasTensor +from keras.src.layers.layer import Layer + + +@keras_export("keras.layers.Identity") +class Identity(Layer): + """Identity layer. + + This layer should be used as a placeholder when no operation is to be + performed. The layer just returns its `inputs` argument as output. + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.supports_masking = True + + self._build_at_init() + + def call(self, inputs): + return inputs + + def compute_output_shape(self, input_shape): + return input_shape + + def compute_output_spec(self, inputs): + return tree.map_structure( + lambda x: KerasTensor(x.shape, dtype=x.dtype, sparse=x.sparse), + inputs, + ) diff --git a/keras/src/layers/core/identity_test.py b/keras/src/layers/core/identity_test.py new file mode 100644 index 000000000000..c0b5af9214c3 --- /dev/null +++ b/keras/src/layers/core/identity_test.py @@ -0,0 +1,34 @@ +import pytest +from absl.testing import parameterized + +from keras.src import backend +from keras.src import layers +from keras.src import testing + + +class IdentityTest(testing.TestCase): + @parameterized.named_parameters( + [ + {"testcase_name": "dense", "sparse": False}, + {"testcase_name": "sparse", "sparse": True}, + ] + ) + @pytest.mark.requires_trainable_backend + def test_identity_basics(self, sparse): + if sparse and not backend.SUPPORTS_SPARSE_TENSORS: + pytest.skip("Backend does not support sparse tensors.") + self.run_layer_test( + layers.Identity, + init_kwargs={}, + input_shape=(2, 3), + input_sparse=sparse, + expected_output_shape=(2, 3), + expected_output_sparse=sparse, + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + run_training_check=not sparse, + supports_masking=True, + assert_built_after_instantiation=True, + ) diff --git a/keras/src/layers/core/input_layer.py b/keras/src/layers/core/input_layer.py new file mode 100644 index 000000000000..abad4617e90b --- /dev/null +++ b/keras/src/layers/core/input_layer.py @@ -0,0 +1,220 @@ +import warnings + +from keras.src import backend +from keras.src.api_export import keras_export +from keras.src.layers.layer import Layer +from keras.src.ops.node import Node + + +@keras_export("keras.layers.InputLayer") +class InputLayer(Layer): + def __init__( + self, + shape=None, + batch_size=None, + dtype=None, + sparse=None, + ragged=None, + batch_shape=None, + input_tensor=None, + optional=False, + name=None, + **kwargs, + ): + super().__init__(name=name) + + if "input_shape" in kwargs: + warnings.warn( + "Argument `input_shape` is deprecated. Use `shape` instead." + ) + shape = kwargs.pop("input_shape") + if "batch_input_shape" in kwargs: + batch_shape = kwargs.pop("batch_input_shape") + + if input_tensor is not None: + if not isinstance(input_tensor, backend.KerasTensor): + raise ValueError( + "Argument `input_tensor` must be a KerasTensor. " + f"Received invalid type: input_tensor={input_tensor} " + f"(of type {type(input_tensor)})" + ) + if batch_size is not None: + if ( + len(input_tensor.shape) < 1 + or input_tensor.shape[0] != batch_size + ): + raise ValueError( + "When providing the `input_tensor` argument, you " + "cannot provide an incompatible `batch_size` argument." + ) + if shape is not None: + if ( + len(shape) != len(input_tensor.shape) - 1 + or shape != input_tensor.shape[1:] + ): + raise ValueError( + "When providing the `input_tensor` argument, you " + "cannot provide an incompatible `shape` argument." + ) + if batch_shape is not None and batch_shape != input_tensor.shape: + raise ValueError( + "When providing the `input_tensor` argument, you " + "cannot provide an incompatible `batch_shape` argument." + ) + if dtype is not None and input_tensor.dtype != dtype: + raise ValueError( + "When providing the `input_tensor` argument, you " + "cannot provide an incompatible `dtype` argument." + ) + if sparse is not None and input_tensor.sparse != sparse: + raise ValueError( + "When providing the `input_tensor` argument, you " + "cannot provide an incompatible `sparse` argument." + ) + batch_shape = input_tensor.shape + dtype = input_tensor.dtype + sparse = input_tensor.sparse + else: + if shape is not None and batch_shape is not None: + raise ValueError( + "You cannot pass both `shape` and `batch_shape` at the " + "same time." + ) + if batch_size is not None and batch_shape is not None: + raise ValueError( + "You cannot pass both `batch_size` and `batch_shape` " + "at the same time." + ) + if shape is None and batch_shape is None: + raise ValueError("You must pass a `shape` argument.") + + if shape is not None: + shape = backend.standardize_shape(shape) + batch_shape = (batch_size,) + shape + + self._batch_shape = backend.standardize_shape(batch_shape) + self._dtype = backend.standardize_dtype(dtype) + self.sparse = bool(sparse) + if self.sparse and not backend.SUPPORTS_SPARSE_TENSORS: + raise ValueError( + f"`sparse=True` is not supported with the {backend.backend()} " + "backend" + ) + self.ragged = bool(ragged) + if self.ragged and not backend.SUPPORTS_RAGGED_TENSORS: + raise ValueError( + f"`ragged=True` is not supported with the {backend.backend()} " + "backend" + ) + + if input_tensor is None: + input_tensor = backend.KerasTensor( + shape=batch_shape, + dtype=dtype, + sparse=sparse, + ragged=ragged, + name=name, + ) + self._input_tensor = input_tensor + Node(operation=self, call_args=(), call_kwargs={}, outputs=input_tensor) + self.built = True + self.optional = optional + + def call(self): + return + + @property + def batch_shape(self): + return self._batch_shape + + @property + def dtype(self): + return self._dtype + + def get_config(self): + return { + "batch_shape": self.batch_shape, + "dtype": self.dtype, + "sparse": self.sparse, + "ragged": self.ragged, + "name": self.name, + } + + +@keras_export(["keras.layers.Input", "keras.Input"]) +def Input( + shape=None, + batch_size=None, + dtype=None, + sparse=None, + ragged=None, + batch_shape=None, + name=None, + tensor=None, + optional=False, +): + """Used to instantiate a Keras tensor. + + A Keras tensor is a symbolic tensor-like object, which we augment with + certain attributes that allow us to build a Keras model just by knowing the + inputs and outputs of the model. + + For instance, if `a`, `b` and `c` are Keras tensors, + it becomes possible to do: + `model = Model(input=[a, b], output=c)` + + Args: + shape: A shape tuple (tuple of integers or `None` objects), + not including the batch size. + For instance, `shape=(32,)` indicates that the expected input + will be batches of 32-dimensional vectors. Elements of this tuple + can be `None`; `None` elements represent dimensions where the shape + is not known and may vary (e.g. sequence length). + batch_size: Optional static batch size (integer). + dtype: The data type expected by the input, as a string + (e.g. `"float32"`, `"int32"`...) + sparse: A boolean specifying whether the expected input will be sparse + tensors. Note that, if `sparse` is `False`, sparse tensors can still + be passed into the input - they will be densified with a default + value of 0. This feature is only supported with the TensorFlow and + the JAX backends. Defaults to `False`. + ragged: A boolean specifying whether the expected input will be ragged + tensors. Note that, if `ragged` is `False`, ragged tensors can still + be passed into the input - they will be densified with a default + value of 0. This feature is only supported with the TensorFlow + backend. Defaults to `False`. + batch_shape: Optional shape tuple (tuple of integers or `None` objects), + including the batch size. + name: Optional name string for the layer. + Should be unique in a model (do not reuse the same name twice). + It will be autogenerated if it isn't provided. + tensor: Optional existing tensor to wrap into the `Input` layer. + If set, the layer will use this tensor rather + than creating a new placeholder tensor. + optional: Boolean, whether the input is optional or not. + An optional input can accept `None` values. + + Returns: + A Keras tensor. + + Example: + + ```python + # This is a logistic regression in Keras + x = Input(shape=(32,)) + y = Dense(16, activation='softmax')(x) + model = Model(x, y) + ``` + """ + layer = InputLayer( + shape=shape, + batch_size=batch_size, + dtype=dtype, + sparse=sparse, + ragged=ragged, + batch_shape=batch_shape, + name=name, + input_tensor=tensor, + optional=optional, + ) + return layer.output diff --git a/keras/src/layers/core/input_layer_test.py b/keras/src/layers/core/input_layer_test.py new file mode 100644 index 000000000000..766a07edb634 --- /dev/null +++ b/keras/src/layers/core/input_layer_test.py @@ -0,0 +1,190 @@ +import numpy as np +from absl.testing import parameterized + +from keras.src import backend +from keras.src import testing +from keras.src.backend import KerasTensor +from keras.src.layers import InputLayer + + +class InputLayerTest(testing.TestCase): + # Testing happy path for layer without input tensor + @parameterized.named_parameters( + [ + {"testcase_name": "dense"}, + {"testcase_name": "sparse", "sparse": True}, + {"testcase_name": "ragged", "ragged": True}, + ] + ) + def test_input_basic(self, sparse=False, ragged=False): + input_shape = (2, 3) + batch_size = 4 + dtype = "float32" + ndim = len(tuple((batch_size,) + input_shape)) + + init_kwargs = { + "shape": input_shape, + "batch_size": batch_size, + "dtype": dtype, + "sparse": sparse, + "ragged": ragged, + } + + if sparse and not backend.SUPPORTS_SPARSE_TENSORS: + with self.assertRaisesRegex( + ValueError, "`sparse=True` is not supported" + ): + InputLayer(**init_kwargs) + return + if ragged and not backend.SUPPORTS_RAGGED_TENSORS: + with self.assertRaisesRegex( + ValueError, "`ragged=True` is not supported" + ): + InputLayer(**init_kwargs) + return + + values = InputLayer(**init_kwargs) + + self.assertEqual(values.dtype, dtype) + self.assertEqual(values.batch_shape[0], batch_size) + self.assertEqual(values.batch_shape[1:], input_shape) + self.assertEqual(values.sparse, sparse) + self.assertEqual(values.ragged, ragged) + self.assertEqual(values.trainable, True) + self.assertIsInstance(values.output, KerasTensor) + self.assertEqual(values.output.ndim, ndim) + self.assertEqual(values.output.dtype, dtype) + self.assertEqual(values.output.sparse, sparse) + self.assertEqual(values.output.ragged, ragged) + + # Testing shape is not None and batch_shape is not None condition + def test_input_error1(self): + input_shape = (2, 3) + + with self.assertRaisesRegex( + ValueError, "cannot pass both `shape` and `batch_shape`" + ): + InputLayer(shape=input_shape, batch_shape=input_shape) + + # Testing batch_size is not None and batch_shape is not None + def test_input_error2(self): + input_shape = (2, 3) + batch_size = 4 + + with self.assertRaisesRegex( + ValueError, "cannot pass both `batch_size` and `batch_shape`" + ): + InputLayer(batch_size=batch_size, batch_shape=input_shape) + + # Testing shape is None and batch_shape is None + def test_input_error3(self): + with self.assertRaisesRegex(ValueError, "pass a `shape` argument."): + InputLayer(shape=None, batch_shape=None) + + # Testing Input tensor is not Keras tensor + def test_input_tensor_error(self): + input_shape = (2, 3) + batch_size = 4 + input_tensor = np.zeros(input_shape) + + with self.assertRaisesRegex( + ValueError, "Argument `input_tensor` must be a KerasTensor" + ): + InputLayer( + shape=input_shape, + batch_size=batch_size, + input_tensor=input_tensor, + ) + + # Testing happy path for layer with input tensor + def testing_input_tensor(self): + input_shape = (2, 3) + dtype = "float32" + input_tensor = KerasTensor(shape=input_shape, dtype=dtype) + + layer = InputLayer( + input_tensor=input_tensor, + ) + + self.assertEqual(layer.dtype, dtype) + self.assertEqual(layer.batch_shape, (2, 3)) + self.assertEqual(layer.trainable, True) + self.assertIsInstance(layer.output, KerasTensor) + self.assertEqual(layer.output, input_tensor) + self.assertEqual(layer.output.ndim, input_tensor.ndim) + self.assertEqual(layer.output.dtype, dtype) + + def test_input_shape_deprecated(self): + input_shape = (2, 3) + batch_size = 4 + dtype = "float32" + + with self.assertWarnsRegex( + UserWarning, + "Argument `input_shape` is deprecated. Use `shape` instead.", + ): + layer = InputLayer( + input_shape=input_shape, batch_size=batch_size, dtype=dtype + ) + + self.assertEqual(layer.batch_shape[0], batch_size) + self.assertEqual(layer.batch_shape[1:], input_shape) + self.assertEqual(layer.dtype, dtype) + self.assertIsInstance(layer.output, KerasTensor) + + def test_call_method(self): + layer = InputLayer(shape=(32,)) + output = layer.call() + self.assertIsNone(output) + + def test_numpy_shape(self): + # non-python int type shapes should be ok + InputLayer(shape=(np.int64(32),)) + + def test_invalid_arg_combinations(self): + input_tensor = KerasTensor(shape=(2, 3), dtype="float32") + + with self.assertRaisesRegex( + ValueError, "cannot provide an incompatible `shape`" + ): + _ = InputLayer( + shape=(2, 4), + input_tensor=input_tensor, + ) + with self.assertRaisesRegex( + ValueError, "cannot provide an incompatible `batch_shape`" + ): + _ = InputLayer( + batch_shape=(2, 4), + input_tensor=input_tensor, + ) + with self.assertRaisesRegex( + ValueError, "cannot provide an incompatible `batch_size`" + ): + _ = InputLayer( + batch_size=5, + input_tensor=input_tensor, + ) + with self.assertRaisesRegex( + ValueError, "cannot provide an incompatible `dtype`" + ): + _ = InputLayer( + dtype="float16", + input_tensor=input_tensor, + ) + with self.assertRaisesRegex( + ValueError, "cannot provide an incompatible `sparse`" + ): + _ = InputLayer( + sparse=True, + input_tensor=input_tensor, + ) + + # This works + _ = InputLayer( + shape=(3,), + batch_size=2, + sparse=False, + dtype="float32", + input_tensor=input_tensor, + ) diff --git a/keras/src/layers/core/lambda_layer.py b/keras/src/layers/core/lambda_layer.py new file mode 100644 index 000000000000..f782f4e0b22f --- /dev/null +++ b/keras/src/layers/core/lambda_layer.py @@ -0,0 +1,232 @@ +import inspect +import types + +from keras.src import backend +from keras.src import tree +from keras.src.api_export import keras_export +from keras.src.layers.layer import Layer +from keras.src.saving import serialization_lib +from keras.src.utils import python_utils + + +@keras_export("keras.layers.Lambda") +class Lambda(Layer): + """Wraps arbitrary expressions as a `Layer` object. + + The `Lambda` layer exists so that arbitrary expressions can be used + as a `Layer` when constructing Sequential + and Functional API models. `Lambda` layers are best suited for simple + operations or quick experimentation. For more advanced use cases, + prefer writing new subclasses of `Layer`. + + WARNING: `Lambda` layers have (de)serialization limitations! + + The main reason to subclass `Layer` instead of using a + `Lambda` layer is saving and inspecting a model. `Lambda` layers + are saved by serializing the Python bytecode, which is fundamentally + non-portable and potentially unsafe. + They should only be loaded in the same environment where + they were saved. Subclassed layers can be saved in a more portable way + by overriding their `get_config()` method. Models that rely on + subclassed Layers are also often easier to visualize and reason about. + + Example: + + ```python + # add a x -> x^2 layer + model.add(Lambda(lambda x: x ** 2)) + ``` + + Args: + function: The function to be evaluated. Takes input tensor as first + argument. + output_shape: Expected output shape from function. This argument + can usually be inferred if not explicitly provided. + Can be a tuple or function. If a tuple, it only specifies + the first dimension onward; sample dimension is assumed + either the same as the input: + `output_shape = (input_shape[0], ) + output_shape` or, + the input is `None` and the sample dimension is also `None`: + `output_shape = (None, ) + output_shape`. + If a function, it specifies the + entire shape as a function of the input shape: + `output_shape = f(input_shape)`. + mask: Either None (indicating no masking) or a callable with the same + signature as the `compute_mask` layer method, or a tensor + that will be returned as output mask regardless + of what the input is. + arguments: Optional dictionary of keyword arguments to be passed to the + function. + """ + + def __init__( + self, function, output_shape=None, mask=None, arguments=None, **kwargs + ): + super().__init__(**kwargs) + + self.arguments = arguments or {} + self.function = function + + if mask is not None: + self.supports_masking = True + else: + self.supports_masking = False + self.mask = mask + self._output_shape = output_shape + + # Warning on every invocation will be quite irksome in Eager mode. + self._already_warned = False + + function_args = inspect.getfullargspec(function).args + self._fn_expects_training_arg = "training" in function_args + self._fn_expects_mask_arg = "mask" in function_args + + def compute_output_shape(self, input_shape): + if self._output_shape is None: + # Leverage backend shape inference + try: + inputs = tree.map_shape_structure( + lambda x: backend.KerasTensor(x, dtype=self.compute_dtype), + input_shape, + ) + output_spec = backend.compute_output_spec(self.call, inputs) + return tree.map_structure(lambda x: x.shape, output_spec) + except: + raise NotImplementedError( + "We could not automatically infer the shape of " + "the Lambda's output. Please specify the `output_shape` " + "argument for this Lambda layer." + ) + + if callable(self._output_shape): + return self._output_shape(input_shape) + + # Output shapes are passed directly and don't include batch dimension. + batch_size = tree.flatten(input_shape)[0] + + def _add_batch(shape): + return (batch_size,) + shape + + return tree.map_shape_structure(_add_batch, self._output_shape) + + def call(self, inputs, mask=None, training=None): + # We must copy for thread safety, + # but it only needs to be a shallow copy. + kwargs = {k: v for k, v in self.arguments.items()} + if self._fn_expects_mask_arg: + kwargs["mask"] = mask + if self._fn_expects_training_arg: + kwargs["training"] = training + return self.function(inputs, **kwargs) + + def compute_mask(self, inputs, mask=None): + if callable(self.mask): + return self.mask(inputs, mask) + return self.mask + + def get_config(self): + config = { + "function": self._serialize_function_to_config(self.function), + } + if self._output_shape is not None: + if callable(self._output_shape): + output_shape = self._serialize_function_to_config( + self._output_shape + ) + else: + output_shape = self._output_shape + config["output_shape"] = output_shape + if self.mask is not None: + if callable(self.mask): + mask = self._serialize_function_to_config(self.mask) + else: + mask = serialization_lib.serialize_keras_object(self.mask) + config["mask"] = mask + config["arguments"] = serialization_lib.serialize_keras_object( + self.arguments + ) + base_config = super().get_config() + return {**base_config, **config} + + def _serialize_function_to_config(self, fn): + if isinstance(fn, types.LambdaType) and fn.__name__ == "": + code, defaults, closure = python_utils.func_dump(fn) + return { + "class_name": "__lambda__", + "config": { + "code": code, + "defaults": defaults, + "closure": closure, + }, + } + elif callable(fn): + return serialization_lib.serialize_keras_object(fn) + raise ValueError( + "Invalid input type for serialization. " + f"Received: {fn} of type {type(fn)}." + ) + + @staticmethod + def _raise_for_lambda_deserialization(safe_mode): + if safe_mode: + raise ValueError( + "Requested the deserialization of a `Lambda` layer whose " + "`function` is a Python lambda. This carries a potential risk " + "of arbitrary code execution and thus it is disallowed by " + "default. If you trust the source of the artifact, you can " + "override this error by passing `safe_mode=False` to the " + "loading function, or calling " + "`keras.config.enable_unsafe_deserialization()." + ) + + @classmethod + def from_config(cls, config, custom_objects=None, safe_mode=None): + safe_mode = safe_mode or serialization_lib.in_safe_mode() + fn_config = config["function"] + if ( + isinstance(fn_config, dict) + and "class_name" in fn_config + and fn_config["class_name"] == "__lambda__" + ): + cls._raise_for_lambda_deserialization(safe_mode) + inner_config = fn_config["config"] + fn = python_utils.func_load( + inner_config["code"], + defaults=inner_config["defaults"], + closure=inner_config["closure"], + ) + config["function"] = fn + else: + config["function"] = serialization_lib.deserialize_keras_object( + fn_config, custom_objects=custom_objects + ) + if "output_shape" in config: + fn_config = config["output_shape"] + if ( + isinstance(fn_config, dict) + and "class_name" in fn_config + and fn_config["class_name"] == "__lambda__" + ): + cls._raise_for_lambda_deserialization(safe_mode) + inner_config = fn_config["config"] + fn = python_utils.func_load( + inner_config["code"], + defaults=inner_config["defaults"], + closure=inner_config["closure"], + ) + config["output_shape"] = fn + else: + output_shape = serialization_lib.deserialize_keras_object( + fn_config, custom_objects=custom_objects + ) + if isinstance(output_shape, list) and all( + isinstance(e, (int, type(None))) for e in output_shape + ): + output_shape = tuple(output_shape) + config["output_shape"] = output_shape + + if "arguments" in config: + config["arguments"] = serialization_lib.deserialize_keras_object( + config["arguments"], custom_objects=custom_objects + ) + return cls(**config) diff --git a/keras/src/layers/core/lambda_layer_test.py b/keras/src/layers/core/lambda_layer_test.py new file mode 100644 index 000000000000..1f80bcb0206b --- /dev/null +++ b/keras/src/layers/core/lambda_layer_test.py @@ -0,0 +1,94 @@ +import numpy as np +import pytest + +from keras.src import layers +from keras.src import ops +from keras.src import testing + + +class LambdaTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_lambda_basics(self): + self.run_layer_test( + layers.Lambda, + init_kwargs={ + "function": ops.square, + }, + input_shape=(2, 3), + expected_output_shape=(2, 3), + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=False, + custom_objects={"square": ops.square}, + ) + self.run_layer_test( + layers.Lambda, + init_kwargs={"function": ops.square, "mask": ops.ones((2, 3))}, + input_shape=(2, 3, 4), + expected_output_shape=(2, 3, 4), + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=True, + custom_objects={"square": ops.square}, + ) + + def stacker(x): + return ops.concatenate([x, x], axis=1) + + self.run_layer_test( + layers.Lambda, + init_kwargs={"function": stacker, "output_shape": (6,)}, + input_shape=(2, 3), + expected_output_shape=(2, 6), + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=False, + custom_objects={"stacker": stacker}, + ) + + def stacker_shape(s): + return (s[0], s[1] * 2) + + self.run_layer_test( + layers.Lambda, + init_kwargs={ + "function": stacker, + "output_shape": stacker_shape, + }, + input_shape=(2, 3), + expected_output_shape=(2, 6), + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=False, + custom_objects={"stacker": stacker, "stacker_shape": stacker_shape}, + ) + + def test_correctness(self): + layer = layers.Lambda(lambda x: x**2) + output = layer(2 * np.ones((2, 3))) + self.assertAllClose(4 * np.ones((2, 3)), output) + + # Test serialization roundtrip + config = layer.get_config() + layer = layers.Lambda.from_config(config, safe_mode=False) + output = layer(2 * np.ones((2, 3))) + self.assertAllClose(4 * np.ones((2, 3)), output) + + def test_correctness_lambda_shape(self): + layer = layers.Lambda(lambda x: x**2, output_shape=lambda x: x) + output = layer(2 * np.ones((2, 3))) + self.assertAllClose(4 * np.ones((2, 3)), output) + + # Test serialization roundtrip + config = layer.get_config() + layer = layers.Lambda.from_config(config, safe_mode=False) + output = layer(2 * np.ones((2, 3))) + self.assertAllClose(4 * np.ones((2, 3)), output) diff --git a/keras/src/layers/core/masking.py b/keras/src/layers/core/masking.py new file mode 100644 index 000000000000..692c322d0aae --- /dev/null +++ b/keras/src/layers/core/masking.py @@ -0,0 +1,76 @@ +from keras.src import backend +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.layers.layer import Layer +from keras.src.saving.serialization_lib import deserialize_keras_object + + +@keras_export("keras.layers.Masking") +class Masking(Layer): + """Masks a sequence by using a mask value to skip timesteps. + + For each timestep in the input tensor (dimension #1 in the tensor), + if all values in the input tensor at that timestep + are equal to `mask_value`, then the timestep will be masked (skipped) + in all downstream layers (as long as they support masking). + + If any downstream layer does not support masking yet receives such + an input mask, an exception will be raised. + + Example: + + Consider a NumPy data array `x` of shape `(samples, timesteps, features)`, + to be fed to an LSTM layer. You want to mask timestep #3 and #5 because you + lack data for these timesteps. You can: + + - Set `x[:, 3, :] = 0.` and `x[:, 5, :] = 0.` + - Insert a `Masking` layer with `mask_value=0.` before the LSTM layer: + + ```python + samples, timesteps, features = 32, 10, 8 + inputs = np.random.random([samples, timesteps, features]).astype(np.float32) + inputs[:, 3, :] = 0. + inputs[:, 5, :] = 0. + + model = keras.models.Sequential() + model.add(keras.layers.Masking(mask_value=0.0)) + model.add(keras.layers.LSTM(32)) + output = model(inputs) + # The time step 3 and 5 will be skipped from LSTM calculation. + ``` + + Note: in the Keras masking convention, a masked timestep is denoted by + a mask value of `False`, while a non-masked (i.e. usable) timestep + is denoted by a mask value of `True`. + """ + + def __init__(self, mask_value=0.0, **kwargs): + super().__init__(**kwargs) + # `mask_value` can be a serialized tensor, hence verify it + if isinstance(mask_value, dict) and mask_value.get("config", None): + mask_value = deserialize_keras_object(mask_value) + self.mask_value = mask_value + self.supports_masking = True + + self._build_at_init() + + def compute_mask(self, inputs, mask=None): + return ops.any(ops.not_equal(inputs, self.mask_value), axis=-1) + + def call(self, inputs): + boolean_mask = ops.any( + ops.not_equal(inputs, self.mask_value), axis=-1, keepdims=True + ) + # Set masked outputs to 0 + outputs = inputs * backend.cast(boolean_mask, dtype=inputs.dtype) + # Compute the mask and outputs simultaneously. + backend.set_keras_mask(outputs, mask=ops.squeeze(boolean_mask, axis=-1)) + return outputs + + def compute_output_shape(self, input_shape): + return input_shape + + def get_config(self): + base_config = super().get_config() + config = {"mask_value": self.mask_value} + return {**base_config, **config} diff --git a/keras/src/layers/core/masking_test.py b/keras/src/layers/core/masking_test.py new file mode 100644 index 000000000000..224e7c7906db --- /dev/null +++ b/keras/src/layers/core/masking_test.py @@ -0,0 +1,83 @@ +import os + +import numpy as np +import pytest + +from keras.src import layers +from keras.src import models +from keras.src import ops +from keras.src import testing +from keras.src.saving import load_model + + +class MaskingTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_masking_basics(self): + self.run_layer_test( + layers.Masking, + init_kwargs={"mask_value": 0.0}, + input_shape=(2, 3, 2), + expected_output_shape=(2, 3, 2), + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=True, + assert_built_after_instantiation=True, + ) + + @pytest.mark.requires_trainable_backend + def test_masking_correctness(self): + x = np.array( + [ + [[0.0, 0.0], [1.0, 2.0], [0.0, 0.0]], + [[2.0, 2.0], [0.0, 0.0], [2.0, 1.0]], + ] + ) + expected_mask = [[False, True, False], [True, False, True]] + + layer = layers.Masking(mask_value=0.0) + self.assertAllClose(layer.compute_mask(x), expected_mask) + + test_obj = self + + class TestLayer(layers.Layer): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.supports_masking = True + + def compute_output_shape(self, input_shape): + return input_shape + + def call(self, inputs, mask=None): + assert mask is not None + test_obj.assertAllClose(mask, expected_mask) + return inputs + + model = models.Sequential( + [ + layers.Masking(mask_value=0.0), + TestLayer(), + ] + ) + model(x) + + @pytest.mark.requires_trainable_backend + def test_masking_with_tensor(self): + model = models.Sequential( + [ + layers.Masking(mask_value=ops.convert_to_tensor([0.0])), + layers.LSTM(1), + ] + ) + x = np.array( + [ + [[0.0, 0.0], [1.0, 2.0], [0.0, 0.0]], + [[2.0, 2.0], [0.0, 0.0], [2.0, 1.0]], + ] + ) + model(x) + temp_filepath = os.path.join(self.get_temp_dir(), "model.keras") + model.save(temp_filepath) + reload_model = load_model(temp_filepath) + reload_model(x) diff --git a/keras/src/layers/core/wrapper.py b/keras/src/layers/core/wrapper.py new file mode 100644 index 000000000000..8f4878919360 --- /dev/null +++ b/keras/src/layers/core/wrapper.py @@ -0,0 +1,46 @@ +from keras.src.api_export import keras_export +from keras.src.layers.layer import Layer +from keras.src.saving import serialization_lib + + +@keras_export("keras.layers.Wrapper") +class Wrapper(Layer): + """Abstract wrapper base class. + + Wrappers take another layer and augment it in various ways. + Do not use this class as a layer, it is only an abstract base class. + Two usable wrappers are the `TimeDistributed` and `Bidirectional` layers. + + Args: + layer: The layer to be wrapped. + """ + + def __init__(self, layer, **kwargs): + try: + assert isinstance(layer, Layer) + except Exception: + raise ValueError( + f"Layer {layer} supplied to Wrapper isn't " + "a supported layer type. Please " + "ensure wrapped layer is a valid Keras layer." + ) + super().__init__(**kwargs) + self.layer = layer + + def build(self, input_shape=None): + if not self.layer.built: + self.layer.build(input_shape) + self.layer.built = True + + def get_config(self): + config = {"layer": serialization_lib.serialize_keras_object(self.layer)} + base_config = super().get_config() + return {**base_config, **config} + + @classmethod + def from_config(cls, config, custom_objects=None): + layer = serialization_lib.deserialize_keras_object( + config.pop("layer"), + custom_objects=custom_objects, + ) + return cls(layer, **config) diff --git a/keras/src/layers/core/wrapper_test.py b/keras/src/layers/core/wrapper_test.py new file mode 100644 index 000000000000..9302ca784240 --- /dev/null +++ b/keras/src/layers/core/wrapper_test.py @@ -0,0 +1,80 @@ +import pytest + +from keras.src import layers +from keras.src import ops +from keras.src import testing + + +class ExampleWrapper(layers.Wrapper): + """Simple Wrapper subclass.""" + + def call(self, inputs, **kwargs): + return ops.cast(self.layer(inputs, **kwargs), self.compute_dtype) + + +class WrapperTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_wrapper_basics(self): + self.run_layer_test( + ExampleWrapper, + init_kwargs={ + "layer": layers.Dense(2), + }, + input_shape=(2, 3), + expected_output_shape=(2, 2), + expected_num_trainable_weights=2, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=False, + ) + self.run_layer_test( + ExampleWrapper, + init_kwargs={ + "layer": layers.Dense(2, activity_regularizer="l2"), + }, + input_shape=(2, 3), + expected_output_shape=(2, 2), + expected_num_trainable_weights=2, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=1, + supports_masking=False, + ) + self.run_layer_test( + ExampleWrapper, + init_kwargs={ + "layer": layers.Dense(2), + "activity_regularizer": "l2", + }, + input_shape=(2, 3), + expected_output_shape=(2, 2), + expected_num_trainable_weights=2, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=1, + supports_masking=False, + ) + self.run_layer_test( + ExampleWrapper, + init_kwargs={ + "layer": layers.BatchNormalization(), + }, + input_shape=(2, 3), + expected_output_shape=(2, 3), + expected_num_trainable_weights=2, + expected_num_non_trainable_weights=2, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=False, + ) + + def test_wrapper_invalid_layer(self): + invalid_layer = "This is not a valid Keras layer." + + with self.assertRaisesRegex( + ValueError, + "Layer .* supplied to Wrapper isn't a supported layer type. " + "Please ensure wrapped layer is a valid Keras layer.", + ): + layers.Wrapper(invalid_layer) diff --git a/keras/src/layers/input_spec.py b/keras/src/layers/input_spec.py new file mode 100644 index 000000000000..abc767fba5aa --- /dev/null +++ b/keras/src/layers/input_spec.py @@ -0,0 +1,250 @@ +from keras.src import backend +from keras.src import tree +from keras.src.api_export import keras_export + + +@keras_export(["keras.InputSpec", "keras.layers.InputSpec"]) +class InputSpec: + """Specifies the rank, dtype and shape of every input to a layer. + + Layers can expose (if appropriate) an `input_spec` attribute: + an instance of `InputSpec`, or a nested structure of `InputSpec` instances + (one per input tensor). These objects enable the layer to run input + compatibility checks for input structure, input rank, input shape, and + input dtype for the first argument of `Layer.__call__`. + + A `None` entry in a shape is compatible with any dimension. + + Args: + dtype: Expected dtype of the input. + shape: Shape tuple, expected shape of the input + (may include `None` for dynamic axes). + Includes the batch size. + ndim: Integer, expected rank of the input. + max_ndim: Integer, maximum rank of the input. + min_ndim: Integer, minimum rank of the input. + axes: Dictionary mapping integer axes to + a specific dimension value. + allow_last_axis_squeeze: If `True`, allow inputs of rank N+1 as long + as the last axis of the input is 1, as well as inputs of rank N-1 + as long as the last axis of the spec is 1. + name: Expected key corresponding to this input when passing data as + a dictionary. + optional: Boolean, whether the input is optional or not. + An optional input can accept `None` values. + + Example: + + ```python + class MyLayer(Layer): + def __init__(self): + super().__init__() + # The layer will accept inputs with + # shape (*, 28, 28) & (*, 28, 28, 1) + # and raise an appropriate error message otherwise. + self.input_spec = InputSpec( + shape=(None, 28, 28, 1), + allow_last_axis_squeeze=True) + ``` + """ + + def __init__( + self, + dtype=None, + shape=None, + ndim=None, + max_ndim=None, + min_ndim=None, + axes=None, + allow_last_axis_squeeze=False, + name=None, + optional=False, + ): + self.dtype = ( + backend.standardize_dtype(dtype) if dtype is not None else None + ) + if shape is not None: + self.shape = backend.standardize_shape(shape) + self.ndim = len(shape) + else: + self.ndim = ndim + self.shape = None + self.max_ndim = max_ndim + self.min_ndim = min_ndim + self.name = name + self.optional = optional + self.allow_last_axis_squeeze = allow_last_axis_squeeze + try: + axes = axes or {} + self.axes = {int(k): axes[k] for k in axes} + except (ValueError, TypeError): + raise TypeError( + "Argument `axes` must be a dict with integer keys. " + f"Received: axes={axes}" + ) + + if self.axes and (self.ndim is not None or self.max_ndim is not None): + max_dim = (self.ndim if self.ndim else self.max_ndim) - 1 + max_axis = max(self.axes) + if max_axis > max_dim: + raise ValueError( + "Axis {} is greater than the maximum " + "allowed value: {}".format(max_axis, max_dim) + ) + + def __repr__(self): + spec = [ + (f"dtype={str(self.dtype)}") if self.dtype else "", + (f"shape={str(self.shape)}") if self.shape else "", + (f"ndim={str(self.ndim)}") if self.ndim else "", + (f"max_ndim={str(self.max_ndim)}") if self.max_ndim else "", + (f"min_ndim={str(self.min_ndim)}") if self.min_ndim else "", + (f"axes={str(self.axes)}") if self.axes else "", + ] + return f"InputSpec({', '.join(x for x in spec if x)})" + + def get_config(self): + return { + "dtype": self.dtype, + "shape": self.shape, + "ndim": self.ndim, + "max_ndim": self.max_ndim, + "min_ndim": self.min_ndim, + "axes": self.axes, + } + + @classmethod + def from_config(cls, config): + return cls(**config) + + +def assert_input_compatibility(input_spec, inputs, layer_name): + """Checks compatibility between the layer and provided inputs. + + This checks that the tensor(s) `inputs` verify the input assumptions + of a layer (if any). If not, a clear and actional exception gets raised. + + Args: + input_spec: An InputSpec instance, list of InputSpec instances, a nested + structure of InputSpec instances, or None. + inputs: Input tensor, list of input tensors, or a nested structure of + input tensors. + layer_name: String, name of the layer (for error message formatting). + + Raises: + ValueError: in case of mismatch between + the provided inputs and the expectations of the layer. + """ + if not input_spec: + return + + input_spec = tree.flatten(input_spec) + if isinstance(inputs, dict): + # Flatten `inputs` by reference order if input spec names are provided + names = [spec.name for spec in input_spec] + if all(names): + list_inputs = [] + for name in names: + if name not in inputs: + raise ValueError( + f'Missing data for input "{name}". ' + "You passed a data dictionary with keys " + f"{list(inputs.keys())}. " + f"Expected the following keys: {names}" + ) + list_inputs.append(inputs[name]) + inputs = list_inputs + + inputs = tree.flatten(inputs) + if len(inputs) != len(input_spec): + raise ValueError( + f'Layer "{layer_name}" expects {len(input_spec)} input(s),' + f" but it received {len(inputs)} input tensors. " + f"Inputs received: {inputs}" + ) + for input_index, (x, spec) in enumerate(zip(inputs, input_spec)): + if spec is None: + continue + if x is None and spec.optional: + continue + + # Having a shape/dtype is the only commonality of the various + # tensor-like objects that may be passed. The most common kind of + # invalid type we are guarding for is a Layer instance (Functional API), + # which does not have a `shape` attribute. + if not hasattr(x, "shape"): + raise ValueError( + f"Inputs to a layer should be tensors. Got '{x}' " + f"(of type {type(x)}) as input for layer '{layer_name}'." + ) + + shape = backend.standardize_shape(x.shape) + ndim = len(shape) + # Check ndim. + if spec.ndim is not None and not spec.allow_last_axis_squeeze: + if ndim != spec.ndim: + raise ValueError( + f'Input {input_index} of layer "{layer_name}" ' + "is incompatible with the layer: " + f"expected ndim={spec.ndim}, found ndim={ndim}. " + f"Full shape received: {shape}" + ) + if spec.max_ndim is not None: + if ndim is not None and ndim > spec.max_ndim: + raise ValueError( + f'Input {input_index} of layer "{layer_name}" ' + "is incompatible with the layer: " + f"expected max_ndim={spec.max_ndim}, " + f"found ndim={ndim}" + ) + if spec.min_ndim is not None: + if ndim is not None and ndim < spec.min_ndim: + raise ValueError( + f'Input {input_index} of layer "{layer_name}" ' + "is incompatible with the layer: " + f"expected min_ndim={spec.min_ndim}, " + f"found ndim={ndim}. " + f"Full shape received: {shape}" + ) + # Check dtype. + if spec.dtype is not None: + dtype = backend.standardize_dtype(x.dtype) + if dtype != spec.dtype: + raise ValueError( + f'Input {input_index} of layer "{layer_name}" ' + "is incompatible with the layer: " + f"expected dtype={spec.dtype}, " + f"found dtype={dtype}" + ) + + # Check specific shape axes. + if spec.axes: + for axis, value in spec.axes.items(): + if value is not None and shape[axis] not in { + value, + None, + }: + raise ValueError( + f'Input {input_index} of layer "{layer_name}" is ' + f"incompatible with the layer: expected axis {axis} " + f"of input shape to have value {value}, " + "but received input with " + f"shape {shape}" + ) + # Check shape. + if spec.shape is not None: + spec_shape = spec.shape + if spec.allow_last_axis_squeeze: + if shape and shape[-1] == 1: + shape = shape[:-1] + if spec_shape and spec_shape[-1] == 1: + spec_shape = spec_shape[:-1] + for spec_dim, dim in zip(spec_shape, shape): + if spec_dim is not None and dim is not None: + if spec_dim != dim: + raise ValueError( + f'Input {input_index} of layer "{layer_name}" is ' + "incompatible with the layer: " + f"expected shape={spec.shape}, " + f"found shape={shape}" + ) diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py new file mode 100644 index 000000000000..11e4046c7b8a --- /dev/null +++ b/keras/src/layers/layer.py @@ -0,0 +1,2005 @@ +"""Layer is an Operation with state. + +Takes care of: + +- Weights / variables (and tracking thereof) +- deferred build +- trainable argument value inference +- masking +- autocasting + +And some more magic: + +- add_loss +- metric tracking +- RNG seed tracking +- activity regularization +""" + +import collections +import functools +import inspect +import math +import warnings +from functools import wraps + +from keras.src import backend +from keras.src import constraints +from keras.src import dtype_policies +from keras.src import initializers +from keras.src import regularizers +from keras.src import tree +from keras.src import utils +from keras.src.api_export import keras_export +from keras.src.backend import KerasTensor +from keras.src.backend.common import global_state +from keras.src.backend.common import remat +from keras.src.backend.common.keras_tensor import any_symbolic_tensors +from keras.src.backend.common.name_scope import current_path +from keras.src.backend.common.remat import get_current_remat_mode +from keras.src.backend.common.symbolic_scope import in_symbolic_scope +from keras.src.backend.config import is_nnx_enabled +from keras.src.distribution import distribution_lib +from keras.src.dtype_policies import DTypePolicyMap +from keras.src.layers import input_spec +from keras.src.metrics.metric import Metric +from keras.src.ops.node import Node +from keras.src.ops.operation import Operation +from keras.src.utils import python_utils +from keras.src.utils import summary_utils +from keras.src.utils import traceback_utils +from keras.src.utils import tracking + +if backend.backend() == "tensorflow": + from keras.src.backend.tensorflow.layer import TFLayer as BackendLayer +elif backend.backend() == "jax": + from keras.src.backend.jax.layer import JaxLayer as BackendLayer +elif backend.backend() == "torch": + from keras.src.backend.torch.layer import TorchLayer as BackendLayer +elif backend.backend() == "numpy": + from keras.src.backend.numpy.layer import NumpyLayer as BackendLayer +elif backend.backend() == "openvino": + from keras.src.backend.openvino.layer import OpenvinoLayer as BackendLayer +else: + raise RuntimeError( + f"Backend '{backend.backend()}' must implement a layer mixin class." + ) + + +@keras_export(["keras.Layer", "keras.layers.Layer"]) +class Layer(BackendLayer, Operation): + """This is the class from which all layers inherit. + + A layer is a callable object that takes as input one or more tensors and + that outputs one or more tensors. It involves *computation*, defined + in the `call()` method, and a *state* (weight variables). State can be + created: + + * in `__init__()`, for instance via `self.add_weight()`; + * in the optional `build()` method, which is invoked by the first + `__call__()` to the layer, and supplies the shape(s) of the input(s), + which may not have been known at initialization time. + + Layers are recursively composable: If you assign a Layer instance as an + attribute of another Layer, the outer layer will start tracking the weights + created by the inner layer. Nested layers should be instantiated in the + `__init__()` method or `build()` method. + + Users will just instantiate a layer and then treat it as a callable. + + Args: + trainable: Boolean, whether the layer's variables should be trainable. + name: String name of the layer. + dtype: The dtype of the layer's computations and weights. Can also be a + `keras.DTypePolicy`, which allows the computation and weight dtype + to differ. Defaults to `None`. `None` means to use + `keras.config.dtype_policy()`, which is a `float32` policy unless + set to different value (via `keras.config.set_dtype_policy()`). + + Attributes: + name: The name of the layer (string). + dtype: Dtype of the layer's weights. Alias of `layer.variable_dtype`. + variable_dtype: Dtype of the layer's weights. + compute_dtype: The dtype of the layer's computations. + Layers automatically cast inputs to this dtype, which causes + the computations and output to also be in this dtype. + When mixed precision is used with a + `keras.DTypePolicy`, this will be different + than `variable_dtype`. + trainable_weights: List of variables to be included in backprop. + non_trainable_weights: List of variables that should not be + included in backprop. + weights: The concatenation of the lists trainable_weights and + non_trainable_weights (in this order). + trainable: Whether the layer should be trained (boolean), i.e. + whether its potentially-trainable weights should be returned + as part of `layer.trainable_weights`. + input_spec: Optional (list of) `InputSpec` object(s) specifying the + constraints on inputs that can be accepted by the layer. + + We recommend that descendants of `Layer` implement the following methods: + + * `__init__()`: Defines custom layer attributes, and creates layer weights + that do not depend on input shapes, using `add_weight()`, + or other state. + * `build(self, input_shape)`: This method can be used to create weights that + depend on the shape(s) of the input(s), using `add_weight()`, or other + state. `__call__()` will automatically build the layer + (if it has not been built yet) by calling `build()`. + * `call(self, *args, **kwargs)`: Called in `__call__` after making + sure `build()` has been called. `call()` performs the logic of applying + the layer to the input arguments. + Two reserved keyword arguments you can optionally use in `call()` are: + 1. `training` (boolean, whether the call is in inference mode or + training mode). + 2. `mask` (boolean tensor encoding masked timesteps in the input, + used e.g. in RNN layers). + A typical signature for this method is `call(self, inputs)`, and user + could optionally add `training` and `mask` if the layer need them. + * `get_config(self)`: Returns a dictionary containing the configuration + used to initialize this layer. If the keys differ from the arguments + in `__init__()`, then override `from_config(self)` as well. + This method is used when saving + the layer or a model that contains this layer. + + Examples: + + Here's a basic example: a layer with two variables, `w` and `b`, + that returns `y = w . x + b`. + It shows how to implement `build()` and `call()`. + Variables set as attributes of a layer are tracked as weights + of the layers (in `layer.weights`). + + ```python + class SimpleDense(Layer): + def __init__(self, units=32): + super().__init__() + self.units = units + + # Create the state of the layer (weights) + def build(self, input_shape): + self.kernel = self.add_weight( + shape=(input_shape[-1], self.units), + initializer="glorot_uniform", + trainable=True, + name="kernel", + ) + self.bias = self.add_weight( + shape=(self.units,), + initializer="zeros", + trainable=True, + name="bias", + ) + + # Defines the computation + def call(self, inputs): + return ops.matmul(inputs, self.kernel) + self.bias + + # Instantiates the layer. + linear_layer = SimpleDense(4) + + # This will also call `build(input_shape)` and create the weights. + y = linear_layer(ops.ones((2, 2))) + assert len(linear_layer.weights) == 2 + + # These weights are trainable, so they're listed in `trainable_weights`: + assert len(linear_layer.trainable_weights) == 2 + ``` + + Besides trainable weights, updated via backpropagation during training, + layers can also have non-trainable weights. These weights are meant to + be updated manually during `call()`. Here's a example layer that computes + the running sum of its inputs: + + ```python + class ComputeSum(Layer): + + def __init__(self, input_dim): + super(ComputeSum, self).__init__() + # Create a non-trainable weight. + self.total = self.add_weight( + shape=(), + initializer="zeros", + trainable=False, + name="total", + ) + + def call(self, inputs): + self.total.assign(self.total + ops.sum(inputs)) + return self.total + + my_sum = ComputeSum(2) + x = ops.ones((2, 2)) + y = my_sum(x) + + assert my_sum.weights == [my_sum.total] + assert my_sum.non_trainable_weights == [my_sum.total] + assert my_sum.trainable_weights == [] + ``` + """ + + def __new__(cls, *args, **kwargs): + obj = super().__new__(cls, *args, **kwargs) + # Wrap the user-provided `build` method in the `build_wrapper` + # to add name scope support and serialization support. + original_build_method = obj.build + + @wraps(original_build_method) + def build_wrapper(*args, **kwargs): + with obj._open_name_scope(): + obj._path = current_path() + original_build_method(*args, **kwargs) + # Record build config. + signature = inspect.signature(original_build_method) + obj._build_shapes_dict = signature.bind(*args, **kwargs).arguments + # Set built, post build actions, and lock state. + obj.built = True + obj._post_build() + obj._lock_state() + + obj.build = build_wrapper + + # Wrap the user-provided `quantize` method in the `quantize_wrapper` + # to add tracker support. + original_quantize_method = obj.quantize + + @wraps(original_quantize_method) + def quantize_wrapper(mode, **kwargs): + obj._check_quantize_args(mode, obj.compute_dtype) + obj._tracker.unlock() + try: + original_quantize_method(mode, **kwargs) + except Exception: + raise + finally: + obj._tracker.lock() + + obj.quantize = quantize_wrapper + + return obj + + def __init__( + self, + *, + activity_regularizer=None, + trainable=True, + dtype=None, + autocast=True, + name=None, + **kwargs, + ): + BackendLayer.__init__(self) + self._lock = False + Operation.__init__(self, name=name) + self._dtype_policy = dtype_policies.get(dtype) + self.activity_regularizer = regularizers.get(activity_regularizer) + input_dim_arg = kwargs.pop("input_dim", None) + if input_dim_arg is not None: + input_shape_arg = (input_dim_arg,) + else: + input_shape_arg = kwargs.pop("input_shape", None) + if input_shape_arg is not None: + warnings.warn( + "Do not pass an `input_shape`/`input_dim` argument to " + "a layer. When using Sequential models, " + "prefer using an `Input(shape)` object as the " + "first layer in the model instead.", + stacklevel=2, + ) + self._input_shape_arg = input_shape_arg + if kwargs: + raise ValueError( + "Unrecognized keyword arguments " + f"passed to {self.__class__.__name__}: {kwargs}" + ) + + self._path = None # Will be determined in `build_wrapper` + self.built = False + self.autocast = autocast + self._input_spec = None + self._called = False + self.supports_jit = True + + self._trainable = trainable + self._losses = [] + self._loss_ids = set() + self._losses_override = [] + + self._call_signature = inspect.signature(self.call) + self.call_signature_parameters = [ + p.name for p in self._call_signature.parameters.values() + ] + self._call_has_training_arg = ( + "training" in self.call_signature_parameters + ) + self._call_has_mask_arg = "mask" in self.call_signature_parameters + + # 1. collect names that should be auto‑propagated + self._call_context_args = {"training"} + + # 2. remember which of them exist in *this* call signature + self._call_has_context_arg = { + arg: (arg in self.call_signature_parameters) + for arg in self._call_context_args + } + + self._supports_masking = not utils.is_default(self.compute_mask) + # Whether to automatically convert (+ auto-cast) inputs to `call()`. + self._convert_input_args = True + # Whether to allow non-tensors as positional arguments in `call()`. + self._allow_non_tensor_positional_args = False + # Dict of shapes that were used to call `build()`. + self._build_shapes_dict = None + # Parent path + self._parent_path = None + self._remat_mode = get_current_remat_mode() + self._initialize_tracker() + + @tracking.no_automatic_dependency_tracking + def _initialize_tracker(self): + if hasattr(self, "_tracker"): + return + + trainable_variables = [] + non_trainable_variables = [] + layers = [] + metrics = [] + seed_generators = [] + self._tracker = tracking.Tracker( + { + "trainable_variables": ( + lambda x: isinstance(x, backend.Variable) and x.trainable, + trainable_variables, + ), + "non_trainable_variables": ( + lambda x: isinstance(x, backend.Variable) + and not x.trainable, + non_trainable_variables, + ), + "metrics": (lambda x: isinstance(x, Metric), metrics), + "layers": ( + lambda x: isinstance(x, Layer) + and not isinstance(x, Metric), + layers, + ), + "seed_generators": ( + lambda x: isinstance(x, backend.random.SeedGenerator), + seed_generators, + ), + }, + exclusions={"non_trainable_variables": ["trainable_variables"]}, + ) + if backend.backend() == "tensorflow": + # Remove attribute tracking for lists (TF-specific attribute) + _self_setattr_tracking = getattr( + self, "_self_setattr_tracking", True + ) + self._self_setattr_tracking = False + + self._trainable_variables = trainable_variables + self._non_trainable_variables = non_trainable_variables + self._layers = layers + self._metrics = metrics + self._seed_generators = seed_generators + + if backend.backend() == "tensorflow": + # Reset attribute tracking (TF-specific) + self._self_setattr_tracking = _self_setattr_tracking + + def _build_at_init(self): + """Build the layer at `Layer.__init__`. + + We can only safely mark the layer as `built=True` in `Layer.__init__` if + `build` is not overridden. Otherwise, it might cause the subclasses to + ignore the user's `build`. + """ + if utils.is_default(self.build): + self.built = True + self._post_build() + self._lock_state() + + @property + def path(self): + """The path of the layer. + + If the layer has not been built yet, it will be `None`. + """ + return self._path + + @property + def input_spec(self): + return self._input_spec + + @input_spec.setter + def input_spec(self, value): + self._input_spec = value + + @utils.default + def build(self, input_shape): + self._check_super_called() + if utils.is_default(self.build) and might_have_unbuilt_state(self): + warnings.warn( + f"`build()` was called on layer '{self.name}', however " + "the layer does not have a `build()` method implemented " + "and it looks like it has unbuilt state. This will cause " + "the layer to be marked as built, despite not being " + "actually built, which may cause failures down the line. " + "Make sure to implement a proper `build()` method." + ) + self.built = True + + def _lock_state(self): + """Prevent further state updates, called automatically in `build()`.""" + if not self._tracker.locked: + self._tracker.lock( + msg=( + "You cannot add new elements of state " + "(variables or sub-layers) " + "to a layer that is already built. All state " + "must be created in the `__init__()` method or " + "in the `build()` method." + ) + ) + + def get_build_config(self): + """Returns a dictionary with the layer's input shape. + + This method returns a config dict that can be used by + `build_from_config(config)` to create all states (e.g. Variables and + Lookup tables) needed by the layer. + + By default, the config only contains the input shape that the layer + was built with. If you're writing a custom layer that creates state in + an unusual way, you should override this method to make sure this state + is already created when Keras attempts to load its value upon model + loading. + + Returns: + A dict containing the input shape associated with the layer. + """ + if self._build_shapes_dict is not None: + if len(self._build_shapes_dict) == 1: + return { + "input_shape": tuple(self._build_shapes_dict.values())[0], + } + else: + return {"shapes_dict": self._build_shapes_dict} + + def build_from_config(self, config): + """Builds the layer's states with the supplied config dict. + + By default, this method calls the `build(config["input_shape"])` method, + which creates weights based on the layer's input shape in the supplied + config. If your config contains other information needed to load the + layer's state, you should override this method. + + Args: + config: Dict containing the input shape associated with this layer. + """ + if config: + if "input_shape" in config: + self.build(config["input_shape"]) + elif "shapes_dict" in config: + self.build(**config["shapes_dict"]) + + def _obj_type(self): + return "Layer" + + def add_variable( + self, + shape, + initializer, + dtype=None, + trainable=True, + autocast=True, + regularizer=None, + constraint=None, + name=None, + ): + """Add a weight variable to the layer. + + Alias of `add_weight()`. + """ + return self.add_weight( + shape=shape, + initializer=initializer, + dtype=dtype, + trainable=trainable, + autocast=autocast, + regularizer=regularizer, + constraint=constraint, + name=name, + ) + + def add_weight( + self, + shape=None, + initializer=None, + dtype=None, + trainable=True, + autocast=True, + regularizer=None, + constraint=None, + aggregation="none", + overwrite_with_gradient=False, + name=None, + ): + """Add a weight variable to the layer. + + Args: + shape: Shape tuple for the variable. Must be fully-defined + (no `None` entries). Defaults to `()` (scalar) if unspecified. + initializer: Initializer object to use to populate the initial + variable value, or string name of a built-in initializer + (e.g. `"random_normal"`). If unspecified, defaults to + `"glorot_uniform"` for floating-point variables and to `"zeros"` + for all other types (e.g. int, bool). + dtype: Dtype of the variable to create, e.g. `"float32"`. If + unspecified, defaults to the layer's variable dtype + (which itself defaults to `"float32"` if unspecified). + trainable: Boolean, whether the variable should be trainable via + backprop or whether its updates are managed manually. Defaults + to `True`. + autocast: Boolean, whether to autocast layers variables when + accessing them. Defaults to `True`. + regularizer: Regularizer object to call to apply penalty on the + weight. These penalties are summed into the loss function + during optimization. Defaults to `None`. + constraint: Contrainst object to call on the variable after any + optimizer update, or string name of a built-in constraint. + Defaults to `None`. + aggregation: Optional string, one of `None`, `"none"`, `"mean"`, + `"sum"` or `"only_first_replica"`. Annotates the variable with + the type of multi-replica aggregation to be used for this + variable when writing custom data parallel training loops. + Defaults to `"none"`. + overwrite_with_gradient: Boolean, whether to overwrite the variable + with the computed gradient. This is useful for float8 training. + Defaults to `False`. + name: String name of the variable. Useful for debugging purposes. + """ + self._check_super_called() + if shape is None: + shape = () + if dtype is not None: + dtype = backend.standardize_dtype(dtype) + else: + dtype = self.variable_dtype + if initializer is None: + if "float" in dtype: + initializer = "glorot_uniform" + else: + initializer = "zeros" + initializer = initializers.get(initializer) + with backend.name_scope(self.name, caller=self): + variable = backend.Variable( + initializer=initializer, + shape=shape, + dtype=dtype, + trainable=trainable, + autocast=autocast, + aggregation=aggregation, + name=name, + ) + # Will be added to layer.losses + variable.regularizer = regularizers.get(regularizer) + variable.constraint = constraints.get(constraint) + variable.overwrite_with_gradient = overwrite_with_gradient + self._track_variable(variable) + return variable + + @property + def trainable(self): + """Settable boolean, whether this layer should be trainable or not.""" + return self._trainable + + @trainable.setter + def trainable(self, value): + """Sets trainable attribute for the layer and its sublayers. + + When this value is changed during training (e.g. with a + `Callback`) you need to call the parent + `Model.make_train_function` with `force=True` in order to + recompile the training graph. + + Args: + value: Boolean with the desired state for the layer's trainable + attribute. + """ + value = bool(value) + self._trainable = value + for v in self._trainable_variables: + v.trainable = value + for layer in self._layers: + layer.trainable = value + + @property + def variables(self): + """List of all layer state, including random seeds. + + This extends `layer.weights` to include all state used by the layer + including `SeedGenerator`s. + + Note that metrics variables are not included here, use + `metrics_variables` to visit all the metric variables. + """ + # Return all `Variables` associate with the layer including metrics + # and random seeds. Also deduplicate them. + variables = [] + seen_ids = set() + for v in self._trainable_variables + self._non_trainable_variables: + if id(v) not in seen_ids: + variables.append(v) + seen_ids.add(id(v)) + for sg in self._seed_generators: + variables.append(sg.state) + for layer in self._layers: + for v in layer.variables: + if id(v) not in seen_ids: + variables.append(v) + seen_ids.add(id(v)) + return variables + + @property + def trainable_variables(self): + """List of all trainable layer state. + + This is equivalent to `layer.trainable_weights`. + """ + if not self.trainable: + return [] + return [v for v in self.variables if v.trainable] + + @property + def non_trainable_variables(self): + """List of all non-trainable layer state. + + This extends `layer.non_trainable_weights` to include all state used by + the layer including state for metrics and `SeedGenerator`s. + """ + if not self.trainable: + return self.variables + return [v for v in self.variables if not v.trainable] + + @property + def weights(self): + """List of all weight variables of the layer. + + Unlike, `layer.variables` this excludes metric state and random seeds. + """ + # Return only `Variables` directly owned by layers and sub-layers. + # Also deduplicate them. + weights = [] + seen_ids = set() + for w in self._trainable_variables + self._non_trainable_variables: + if id(w) not in seen_ids: + weights.append(w) + seen_ids.add(id(w)) + for layer in self._layers: + for w in layer.weights: + if id(w) not in seen_ids: + weights.append(w) + seen_ids.add(id(w)) + return weights + + @property + def trainable_weights(self): + """List of all trainable weight variables of the layer. + + These are the weights that get updated by the optimizer during training. + """ + if not self.trainable: + return [] + return [v for v in self.weights if v.trainable] + + @property + def non_trainable_weights(self): + """List of all non-trainable weight variables of the layer. + + These are the weights that should not be updated by the optimizer during + training. Unlike, `layer.non_trainable_variables` this excludes metric + state and random seeds. + """ + if not self.trainable: + return self.weights + return [v for v in self.weights if not v.trainable] + + @property + def metrics(self): + """List of all metrics.""" + metrics = list(self._metrics) + for layer in self._layers: + metrics.extend(layer.metrics) + return metrics + + @property + def metrics_variables(self): + """List of all metric variables.""" + vars = [] + for metric in self.metrics: + vars.extend(metric.variables) + return vars + + def get_weights(self): + """Return the values of `layer.weights` as a list of NumPy arrays.""" + return [v.numpy() for v in self.weights] + + def set_weights(self, weights): + """Sets the values of `layer.weights` from a list of NumPy arrays.""" + layer_weights = self.weights + if len(layer_weights) != len(weights): + raise ValueError( + f"You called `set_weights(weights)` on layer '{self.name}' " + f"with a weight list of length {len(weights)}, but the layer " + f"was expecting {len(layer_weights)} weights." + ) + for variable, value in zip(layer_weights, weights): + if variable.shape != value.shape: + raise ValueError( + f"Layer {self.name} weight shape {variable.shape} " + "is not compatible with provided weight " + f"shape {value.shape}." + ) + variable.assign(value) + + @property + def dtype_policy(self): + return self._dtype_policy + + @dtype_policy.setter + def dtype_policy(self, value): + policy = dtype_policies.get(value) + if isinstance(self._dtype_policy, DTypePolicyMap) and self.path: + if self.path in self._dtype_policy: + del self._dtype_policy[self.path] + self._dtype_policy[self.path] = policy + else: + self._dtype_policy = policy + if policy.quantization_mode is not None: + if self.built and not getattr(self, "_is_quantized", False): + self.quantize(policy.quantization_mode) + + @property + def dtype(self): + """Alias of `layer.variable_dtype`.""" + return self.variable_dtype + + @property + def compute_dtype(self): + """The dtype of the computations performed by the layer.""" + if isinstance(self._dtype_policy, DTypePolicyMap) and self.path: + policy = self._dtype_policy[self.path] + else: + policy = self._dtype_policy + return policy.compute_dtype + + @property + def variable_dtype(self): + """The dtype of the state (weights) of the layer.""" + if isinstance(self._dtype_policy, DTypePolicyMap) and self.path: + policy = self._dtype_policy[self.path] + else: + policy = self._dtype_policy + return policy.variable_dtype + + @property + def quantization_mode(self): + """The quantization mode of this layer, `None` if not quantized.""" + if isinstance(self._dtype_policy, DTypePolicyMap) and self.path: + policy = self._dtype_policy[self.path] + else: + policy = self._dtype_policy + return policy.quantization_mode + + @property + def input_dtype(self): + """The dtype layer inputs should be converted to.""" + return self.compute_dtype + + @property + def supports_masking(self): + """Whether this layer supports computing a mask using `compute_mask`.""" + return self._supports_masking + + @supports_masking.setter + def supports_masking(self, value): + self._supports_masking = value + + @utils.default + def compute_mask(self, inputs, previous_mask): + return previous_mask + + def symbolic_call(self, *args, **kwargs): + # Node is created at the end of `__call__` instead of `symbolic_call`. + return self.compute_output_spec(*args, **kwargs) + + @traceback_utils.filter_traceback + def __call__(self, *args, **kwargs): + self._check_super_called() + self._called = True + + original_args = args + original_kwargs = kwargs + + ############################################################# + # 1. Convert any array arguments to tensors of correct dtype. + def maybe_convert(x): + return self.dtype_policy.convert_input( + x, self.autocast, self.input_dtype + ) + + # Used to avoid expensive `tree` operations in the most common case. + if ( + kwargs + or len(args) != 1 + or not is_backend_tensor_or_symbolic(args[0], allow_none=False) + or backend.standardize_dtype(args[0].dtype) != self.input_dtype + ) and self._convert_input_args: + args = tree.map_structure(maybe_convert, args) + kwargs = tree.map_structure(maybe_convert, kwargs) + + ########################################################## + # 2. Enforce that only tensors can be passed positionally. + if not self._allow_non_tensor_positional_args: + for arg in tree.flatten(args): + if not is_backend_tensor_or_symbolic(arg, allow_none=True): + raise ValueError( + "Only input tensors may be passed as " + "positional arguments. The following argument value " + f"should be passed as a keyword argument: {arg} " + f"(of type {type(arg)})" + ) + + # Caches info about `call()` signature, args, kwargs. + call_spec = CallSpec( + self._call_signature, self._call_context_args, args, kwargs + ) + + ############################################ + # 3. Check input spec for 1st positional arg. + # TODO: consider extending this to all args and kwargs. + self._assert_input_compatibility(call_spec.first_arg) + + ################ + # 4. Call build + with self._open_name_scope(): + self._maybe_build(call_spec) + + ########################## + # 5. Infer training value + # Training phase for `Layer.call` is set via (in order of priority): + # (1) The `training` argument passed to this `Layer.call`, if not None + # (2) The training argument of an outer `Layer.call`. + # (4) Any non-None default value for `training` in the call signature + # (5) False (treating the layer as if it's in inference) + + # Maintains info about the `Layer.call` stack + # across nested calls. + call_context = self._get_call_context() + + for context_arg in self._call_context_args: + self._resolve_and_populate_arg( + context_arg, call_spec, call_context, kwargs + ) + + ############################## + # 6. Populate mask argument(s) + if len(call_spec.tensor_arguments_dict) == 1: + if ( + "mask" in call_spec.argument_names + and call_spec.arguments_dict["mask"] is None + ): + arg_name = list(call_spec.tensor_arguments_dict.keys())[0] + only_tensor_arg = call_spec.tensor_arguments_dict[arg_name] + mask = tree.map_structure( + backend.get_keras_mask, + only_tensor_arg, + ) + kwargs["mask"] = mask + elif len(call_spec.tensor_arguments_dict) > 1: + for k, v in call_spec.tensor_arguments_dict.items(): + expected_mask_arg_name = f"{k}_mask" + if expected_mask_arg_name in call_spec.argument_names: + if call_spec.arguments_dict[expected_mask_arg_name] is None: + mask = tree.map_structure(backend.get_keras_mask, v) + kwargs[expected_mask_arg_name] = mask + + # We need to cache the `previous_mask` before `__call__` because the + # mask might be removed during the call, such as `MultiHeadAttention`. + if "mask" in kwargs and kwargs["mask"] is not None: + # Case 1: Mask was explicitly passed or auto-populated in step 6. + previous_mask = kwargs["mask"] + else: + # Case 2: Fallback to the mask attached to the first input tensor. + previous_mask = tree.map_structure( + backend.get_keras_mask, call_spec.first_arg + ) + + #################### + # 7. Call the layer. + try: + with self._open_name_scope(): + current_scope = backend.get_autocast_scope() + new_scope = None + if current_scope is not None: + # Clear or update the current scope if necessary. + if not self.autocast: + new_scope = backend.AutocastScope(None) + elif not backend.is_float_dtype(self.compute_dtype): + # Some preprocessing layers might have a non-float + # dtype, we should not autocast in this case. + new_scope = backend.AutocastScope(None) + elif current_scope.dtype != self.compute_dtype: + new_scope = backend.AutocastScope(self.compute_dtype) + elif self.compute_dtype != self.variable_dtype: + # Enter a new scope if our dtypes are "mixed". + new_scope = backend.AutocastScope(self.compute_dtype) + if new_scope is not None: + with new_scope: + outputs = super().__call__(*args, **kwargs) + else: + outputs = super().__call__(*args, **kwargs) + # Change the layout for the layer output if needed. + # This is useful for relayout intermediate tensor in the model + # to achieve the optimal performance. + distribution = distribution_lib.distribution() + if distribution is not None: + current_layer_path = current_path() + current_layer_path += "/output" + layout = distribution.get_tensor_layout(current_layer_path) + if layout: + outputs = distribution_lib.distribute_tensor( + outputs, layout + ) + + self.built = True + # Record activity regularizer loss. + if self.activity_regularizer is not None: + for output in tree.flatten(outputs): + if backend.is_tensor(output): + self.add_loss(self.activity_regularizer(output)) + + # Set `previous_mask` on outputs if available. It is provided only + # for the first positional input arg and its mask. + # TODO: consider extending this to all args and kwargs. + if self.supports_masking: + self._set_mask_metadata( + call_spec.first_arg, outputs, previous_mask + ) + elif any(m is not None for m in tree.flatten(previous_mask)): + warnings.warn( + f"Layer '{self.name}' (of type {self.__class__.__name__}) " + "was passed an input with a mask attached to it. " + "However, this layer does not support masking and will " + "therefore destroy the mask information. Downstream " + "layers will not see the mask." + ) + finally: + # Destroy call context if we created it + self._maybe_reset_call_context() + + ################################################ + # 8. Add a node in the graph for symbolic calls. + if any_symbolic_tensors(original_args, original_kwargs): + Node( + operation=self, + call_args=original_args, + call_kwargs=original_kwargs, + outputs=outputs, + ) + + return outputs + + def call(self, *args, **kwargs): + raise self._not_implemented_error(self.call) + + def _resolve_and_populate_arg( + self, arg_name, call_spec, call_context, kwargs + ): + # 1) user explicitly passed it? + if arg_name in call_spec.user_arguments_dict: + value = call_spec.user_arguments_dict[arg_name] + # 2) else: inherited from outer layer call? + elif call_context.get_value(arg_name) is not None: + value = call_context.get_value(arg_name) + # 3) else: default from the call() signature + else: + value = call_spec.arguments_dict.get(arg_name, None) + + # stash it for downstream layers + call_context.set_value(arg_name, value) + + # only inject it if this layer actually accepts it and it's not None + if ( + self._call_has_context_arg.get(arg_name, False) + and value is not None + ): + kwargs[arg_name] = value + + @traceback_utils.filter_traceback + def stateless_call( + self, + trainable_variables, + non_trainable_variables, + *args, + return_losses=False, + **kwargs, + ): + """Call the layer without any side effects. + + Args: + trainable_variables: List of trainable variables of the model. + non_trainable_variables: List of non-trainable variables of the + model. + *args: Positional arguments to be passed to `call()`. + return_losses: If `True`, `stateless_call()` will return the list of + losses created during `call()` as part of its return values. + **kwargs: Keyword arguments to be passed to `call()`. + + Returns: + A tuple. By default, returns `(outputs, non_trainable_variables)`. + If `return_losses = True`, then returns + `(outputs, non_trainable_variables, losses)`. + + Note: `non_trainable_variables` include not only non-trainable weights + such as `BatchNormalization` statistics, but also RNG seed state + (if there are any random operations part of the layer, such as dropout), + and `Metric` state (if there are any metrics attached to the layer). + These are all elements of state of the layer. + + Example: + + ```python + model = ... + data = ... + trainable_variables = model.trainable_variables + non_trainable_variables = model.non_trainable_variables + # Call the model with zero side effects + outputs, non_trainable_variables = model.stateless_call( + trainable_variables, + non_trainable_variables, + data, + ) + # Attach the updated state to the model + # (until you do this, the model is still in its pre-call state). + for ref_var, value in zip( + model.non_trainable_variables, non_trainable_variables + ): + ref_var.assign(value) + ``` + """ + self._check_super_called() + if not self.built: + raise ValueError( + f"To call stateless_call, {self.__class__.__name__} must be " + "built (i.e. its variables must have been already created). " + "You can build it by calling it on some data." + ) + if len(trainable_variables) != len(self.trainable_variables): + raise ValueError( + "Argument `trainable_variables` must be a list of tensors " + "corresponding 1:1 to " + f"{self.__class__.__name__}().trainable_variables. " + f"Received list with length {len(trainable_variables)}, " + f"but expected {len(self.trainable_variables)} variables." + ) + if len(non_trainable_variables) != len(self.non_trainable_variables): + raise ValueError( + "Argument `non_trainable_variables` must be a list of tensors " + "corresponding 1:1 to " + f"{self.__class__.__name__}().non_trainable_variables. " + f"Received list with length {len(non_trainable_variables)}, " + f"but expected {len(self.non_trainable_variables)} variables." + ) + + # Gather variable mapping + trainable_mapping = zip(self.trainable_variables, trainable_variables) + non_trainable_mapping = zip( + self.non_trainable_variables, non_trainable_variables + ) + mapping = list(trainable_mapping) + list(non_trainable_mapping) + + # Call in stateless scope + losses = None + with backend.StatelessScope( + state_mapping=mapping, collect_losses=return_losses + ) as scope: + if self.dtype_policy.quantization_mode is not None: + if self._remat_mode is not None: + outputs = self.rematerialized_call( + self.quantized_call, *args, **kwargs + )(*args, **kwargs) + else: + outputs = self.quantized_call(*args, **kwargs) + elif self._remat_mode is not None: + outputs = self.rematerialized_call(self.call, *args, **kwargs)( + *args, **kwargs + ) + else: + outputs = self.call(*args, **kwargs) + if return_losses: + losses = self.losses + + # Gather updated non-trainable variables + non_trainable_variables = [] + for v in self.non_trainable_variables: + new_v = scope.get_current_value(v) + non_trainable_variables.append(new_v) + + if return_losses: + return outputs, non_trainable_variables, losses + return outputs, non_trainable_variables + + def compute_output_spec(self, *args, **kwargs): + if utils.is_default(self.compute_output_shape): + return super().compute_output_spec(*args, **kwargs) + else: + # Use compute_output_shape() to return the right output spec + call_spec = CallSpec( + self._call_signature, self._call_context_args, args, kwargs + ) + shapes_dict = get_shapes_dict(call_spec) + shapes_dict = update_shapes_dict_for_target_fn( + self.compute_output_shape, + shapes_dict=shapes_dict, + call_spec=call_spec, + class_name=self.__class__.__name__, + ) + output_shape = self.compute_output_shape(**shapes_dict) + + if ( + isinstance(output_shape, list) + and output_shape + and isinstance(output_shape[0], (int, type(None))) + ): + output_shape = tuple(output_shape) + if not isinstance(output_shape, (list, tuple, dict)): + try: + output_shape = tuple(output_shape) + except: + raise ValueError( + "Method `compute_output_shape()` of layer " + f"{self.__class__.__name__} is returning " + "a type that cannot be interpreted as a shape. " + "It should return a shape tuple. " + f"Received: {output_shape}" + ) + if ( + isinstance(output_shape, tuple) + and output_shape + and isinstance(output_shape[0], (int, type(None))) + ): + return KerasTensor(output_shape, dtype=self.compute_dtype) + # Case: nested. Could be a tuple/list of shapes, or a dict of + # shapes. Could be deeply nested. + return tree.map_shape_structure( + lambda s: KerasTensor(s, dtype=self.compute_dtype), output_shape + ) + + @utils.default + def compute_output_shape(self, *args, **kwargs): + raise self._not_implemented_error( + self.compute_output_shape, + "Should implement `def compute_output_shape(self, input_shape)`.", + ) + + def add_loss(self, loss): + """Can be called inside of the `call()` method to add a scalar loss. + + Example: + + ```python + class MyLayer(Layer): + ... + def call(self, x): + self.add_loss(ops.sum(x)) + return x + ``` + """ + # Eager only. + losses = tree.flatten(loss) + for x in losses: + if not backend.is_tensor(x): + raise ValueError( + "`add_loss()` can only be called from inside `build()` or " + f"`call()`, on a tensor input. Received invalid value: {x}" + ) + if backend.in_stateless_scope(): + scope = backend.get_stateless_scope() + if scope.collect_losses: + for x in losses: + scope.add_loss(x) + self._loss_ids.add(id(x)) + else: + self._losses.extend(losses) + + def _get_own_losses(self): + if backend.in_stateless_scope(): + losses = [] + scope = backend.get_stateless_scope() + for loss in scope.losses: + if id(loss) in self._loss_ids: + losses.append(loss) + return losses + else: + return self._losses[:] + + def _get_regularization_losses(self): + weight_regularization_losses = [] + for variable in self.trainable_weights: + if variable.regularizer is None: + continue + if backend.in_stateless_scope() and not in_symbolic_scope(): + # If in symbolic scope, we might get `None` from + # `get_current_value` in `backend.compute_output_spec`. So we + # assign `variable` instead. + v = backend.get_stateless_scope().get_current_value(variable) + else: + v = variable + weight_regularization_losses.append(variable.regularizer(v)) + return weight_regularization_losses + + @property + def losses(self): + """List of scalar losses from `add_loss`, regularizers and sublayers.""" + if self._losses_override: + return self._losses_override + losses = self._get_own_losses() + for layer in self._flatten_layers(include_self=False): + losses.extend(layer._get_own_losses()) + weight_regularization_losses = self._get_regularization_losses() + losses.extend(weight_regularization_losses) + return losses + + def _clear_losses(self): + if backend.in_stateless_scope(): + scope = backend.get_stateless_scope() + if scope.collect_losses: + for x in scope.losses: + if id(x) in self._loss_ids: + scope.losses.remove(x) + self._losses.clear() + self._loss_ids.clear() + for layer in self._layers: + layer._clear_losses() + + # Quantization-related (int8 and float8) methods + + def quantized_build(self, input_shape, mode): + raise self._not_implemented_error(self.quantized_build) + + def quantize(self, mode, type_check=True, config=None): + raise self._not_implemented_error(self.quantize) + + def _check_quantize_args(self, mode, compute_dtype): + if not self.built: + raise ValueError( + "Cannot quantize a layer that isn't yet built. " + f"Layer '{self.name}' (of type '{self.__class__.__name__}') " + "is not built yet." + ) + if getattr(self, "_is_quantized", False): + raise ValueError( + f"Layer '{self.name}' is already quantized with " + f"dtype_policy='{self.dtype_policy.name}'. " + f"Received: mode={mode}" + ) + if mode not in dtype_policies.QUANTIZATION_MODES: + raise ValueError( + "Invalid quantization mode. " + f"Expected one of {dtype_policies.QUANTIZATION_MODES}. " + f"Received: mode={mode}" + ) + if mode == "int8" and compute_dtype == "float16": + raise ValueError( + f"Quantization mode='{mode}' doesn't work well with " + "compute_dtype='float16'. Consider loading model/layer with " + "another dtype policy such as 'mixed_bfloat16' or " + "'mixed_float16' before calling `quantize()`." + ) + + def quantized_call(self, *args, **kwargs): + current_remat_mode = get_current_remat_mode() + + if ( + current_remat_mode != self._remat_mode + and current_remat_mode is not None + ): + warnings.warn( + f"The RematScope at call time ({current_remat_mode}) differs " + f"the one set during layer initialization " + f"({self._remat_mode}). " + f"Restoring the correct rematerialization mode " + f"{self._remat_mode} for this layer." + ) + if self.quantization_mode == "int8": + return self._int8_call(*args, **kwargs) + elif self.quantization_mode == "float8": + return self._float8_call(*args, **kwargs) + elif self.quantization_mode == "int4": + return self._int4_call(*args, **kwargs) + elif self.quantization_mode == "gptq": + return self._gptq_call(*args, **kwargs) + else: + raise self._quantization_mode_error(self.quantization_mode) + + def _int4_call(self, *args, **kwargs): + raise self._not_implemented_error(self._int4_call) + + def _int8_call(self, *args, **kwargs): + raise self._not_implemented_error(self._int8_call) + + def _float8_call(self, *args, **kwargs): + raise self._not_implemented_error(self._float8_call) + + def _gptq_call(self, *args, **kwargs): + raise self._not_implemented_error(self._gptq_call) + + def _not_implemented_error(self, attr, msg=None): + if callable(attr): + attr_name = attr.__name__ + attr_type = "method" + else: + attr_name = str(attr) + attr_type = "attribute" + msg = f" {msg}" if msg is not None else "" + return NotImplementedError( + f"Layer {self.__class__.__name__} does not have a `{attr_name}` " + f"{attr_type} implemented.{msg}" + ) + + def _quantization_mode_error(self, mode): + return NotImplementedError( + "Invalid quantization mode. Expected one of " + f"{dtype_policies.QUANTIZATION_MODES}. " + f"Received: quantization_mode={mode}" + ) + + def save_own_variables(self, store): + """Saves the state of the layer. + + You can override this method to take full control of how the state of + the layer is saved upon calling `model.save()`. + + Args: + store: Dict where the state of the model will be saved. + """ + all_vars = self._trainable_variables + self._non_trainable_variables + for i, v in enumerate(all_vars): + store[f"{i}"] = v + + def load_own_variables(self, store): + """Loads the state of the layer. + + You can override this method to take full control of how the state of + the layer is loaded upon calling `keras.models.load_model()`. + + Args: + store: Dict from which the state of the model will be loaded. + """ + all_vars = self._trainable_variables + self._non_trainable_variables + if len(store.keys()) != len(all_vars): + if len(all_vars) == 0 and not self.built: + raise ValueError( + f"Layer '{self.name}' was never built " + "and thus it doesn't have any variables. " + f"However the weights file lists {len(store.keys())} " + "variables for this layer.\n" + "In most cases, this error indicates that either:\n\n" + "1. The layer is owned by a parent layer that " + "implements a `build()` method, but calling the " + "parent's `build()` method did NOT create the state of " + f"the child layer '{self.name}'. A `build()` method " + "must create ALL state for the layer, including " + "the state of any children layers.\n\n" + "2. You need to implement " + "the `def build_from_config(self, config)` method " + f"on layer '{self.name}', to specify how to rebuild " + "it during loading. " + "In this case, you might also want to implement the " + "method that generates the build config at saving time, " + "`def get_build_config(self)`. " + "The method `build_from_config()` is meant " + "to create the state " + "of the layer (i.e. its variables) upon deserialization.", + ) + raise ValueError( + f"Layer '{self.name}' expected {len(all_vars)} variables, " + "but received " + f"{len(store.keys())} variables during loading. " + f"Expected: {[v.name for v in all_vars]}" + ) + for i, v in enumerate(all_vars): + v.assign(store[f"{i}"]) + + def _track_variable(self, variable): + if variable.trainable: + self._tracker.add_to_store("trainable_variables", variable) + else: + self._tracker.add_to_store("non_trainable_variables", variable) + if not self.trainable: + variable.trainable = False + self._post_track_variable(variable) + + def _untrack_variable(self, variable): + previous_lock_state = self._tracker.locked + self._tracker.unlock() + self._tracker.untrack(variable) + if previous_lock_state is True: + self._tracker.lock() + self._post_untrack_variable(variable) + + def add_metric(self, *args, **kwargs): + # Permanently disabled + raise NotImplementedError( + "Layer `add_metric()` method is deprecated. " + "Add your metric in `Model.compile(metrics=[...])`, " + "or create metric trackers in init() or build() " + "when subclassing the layer or model, then call " + "`metric.update_state()` whenever necessary." + ) + + def count_params(self): + """Count the total number of scalars composing the weights. + + Returns: + An integer count. + """ + if not self.built: + raise ValueError( + "You tried to call `count_params` " + f"on layer '{self.name}', " + "but the layer isn't built. " + "You can build it manually via: " + f"`layer.build(input_shape)`." + ) + return summary_utils.count_params(self.weights) + + def _maybe_build(self, call_spec): + if self.built: + return + + shapes_dict = get_shapes_dict(call_spec) + first_shape = next(iter(shapes_dict.values()), None) + + # If the layer has a build method, call it with our input shapes. + if not utils.is_default(self.build): + shapes_dict = update_shapes_dict_for_target_fn( + self.build, + shapes_dict=shapes_dict, + call_spec=call_spec, + class_name=self.__class__.__name__, + ) + self.build(**shapes_dict) + # Check input spec again (after build, since self.input_spec + # may have been updated + self._assert_input_compatibility(call_spec.first_arg) + return + + # Otherwise, attempt to build the layer by calling it on symbolic input. + if might_have_unbuilt_state(self): + try: + backend.compute_output_spec( + self.call, **call_spec.arguments_dict + ) + except Exception as e: + if call_spec.eager: + # Will let the actual eager call do state-building + return + warnings.warn( + f"Layer '{self.name}' looks like it has unbuilt state, but " + "Keras is not able to trace the layer `call()` in order to " + "build it automatically. Possible causes:\n" + "1. The `call()` method of your layer may be crashing. Try " + "to `__call__()` the layer eagerly on some test input " + "first to see if it works. " + "E.g. `x = np.random.random((3, 4)); y = layer(x)`\n" + "2. If the `call()` method is correct, then you may need " + "to implement the `def build(self, input_shape)` method on " + "your layer. It should create all variables used by the " + "layer (e.g. by calling `layer.build()` on all its " + "children layers).\n" + f"Exception encountered: ''{e}''" + ) + self.build(first_shape) + + def _build_by_run_for_single_pos_arg(self, input_shape): + # Case: all inputs are in the first arg (possibly nested). + input_tensors = tree.map_shape_structure( + lambda s: backend.KerasTensor(s), input_shape + ) + try: + backend.compute_output_spec(self.call, input_tensors) + return True + except: + return False + + def _build_by_run_for_kwargs(self, shapes_dict): + # Case: inputs were recorded as multiple keyword arguments. + if all(is_shape_tuple(s) for s in shapes_dict.values()): + # Case: all input keyword arguments were plain tensors. + input_tensors = { + # We strip the `_shape` suffix to recover kwarg names. + utils.removesuffix(k, "_shape"): backend.KerasTensor(shape) + for k, shape in shapes_dict.items() + } + try: + backend.compute_output_spec(self.call, **input_tensors) + return True + except: + return False + else: + # Not supported: nested input keyword arguments. + return False + + def __repr__(self): + return ( + f"<{self.__class__.__name__} name={self.name}, built={self.built}>" + ) + + def __str__(self): + return self.__repr__() + + def __setattr__(self, name, value): + # Track Variables, Layers, Metrics, SeedGenerators. + name, value = self._setattr_hook(name, value) + if name != "_tracker": + if not hasattr(self, "_tracker"): + self._initialize_tracker() + value = self._tracker.track(value) + + # NNX-specific bypass for `_called` and `built` attributes + # bypass nnx.Module.__setattr__ which cannot be called while tracing + if ( + backend.backend() == "jax" + and is_nnx_enabled() + and (name == "_called" or name == "built") + ): + object.__setattr__(self, name, value) + return + + super().__setattr__(name, value) + + def __delattr__(self, name): + obj = getattr(self, name) + if isinstance(obj, backend.Variable): + import gc + + # It will take a short amount of time for the corresponding buffer + # to be actually removed from the device. + # https://stackoverflow.com/a/74631949 + self._untrack_variable(obj) + super().__delattr__(name) + gc.collect() + else: + super().__delattr__(name) + + def _check_super_called(self): + if getattr(self, "_lock", True): + raise RuntimeError( + f"In layer '{self.__class__.__name__}', you forgot to call " + "`super().__init__()` as the first statement " + "in the `__init__()` method. Go add it!" + ) + + def _assert_input_compatibility(self, arg_0): + if self.input_spec: + try: + input_spec.assert_input_compatibility( + self.input_spec, arg_0, layer_name=self.name + ) + except SystemError: + if backend.backend() == "torch": + # TODO: The torch backend failed the ONNX CI with the error: + # SystemError: returned a result with an exception set + # As a workaround, we are skipping this for now. + pass + else: + raise + + def _get_call_context(self): + """Returns currently active `CallContext`.""" + layer_call_ctx = global_state.get_global_attribute("current_call_ctx") + if layer_call_ctx is None: + # Enter new call context. + layer_call_ctx = CallContext(entry_layer=self) + global_state.set_global_attribute( + "current_call_ctx", layer_call_ctx + ) + self._clear_losses() + return layer_call_ctx + + def _maybe_reset_call_context(self): + layer_call_ctx = global_state.get_global_attribute("current_call_ctx") + if layer_call_ctx is None or layer_call_ctx.entry_layer == self: + global_state.set_global_attribute("current_call_ctx", None) + + def _flatten_layers(self, include_self=True, recursive=True): + layers = [] + if include_self: + layers.append(self) + seen_object_ids = set() + deque = collections.deque(self._layers) + while deque: + layer = deque.popleft() + if id(layer) in seen_object_ids: + continue + seen_object_ids.add(id(layer)) + layers.append(layer) + # Introspect recursively through sublayers. + if recursive: + deque.extendleft(layer._layers) + return layers + + def _set_mask_metadata(self, inputs, outputs, previous_mask): + flat_outputs = tree.flatten(outputs) + + mask_already_computed = all( + backend.get_keras_mask(x) is not None for x in flat_outputs + ) + if mask_already_computed: + return + + output_masks = self.compute_mask(inputs, previous_mask) + if output_masks is None: + return + + flat_masks = tree.flatten(output_masks) + for tensor, mask in zip(flat_outputs, flat_masks): + if backend.get_keras_mask(tensor) is None and mask is not None: + if backend.backend() == "numpy": + warnings.warn( + "The NumPy backend does not support masking at this" + "time. Masks will be ignored." + ) + else: + backend.set_keras_mask(tensor, mask) + + @python_utils.default + def get_config(self): + self._check_super_called() + base_config = super().get_config() + config = { + "trainable": self.trainable, + "dtype": dtype_policies.serialize(self.dtype_policy), + } + if self.activity_regularizer is not None: + config["activity_regularizer"] = regularizers.serialize( + self.activity_regularizer + ) + return {**base_config, **config} + + def _open_name_scope(self): + from keras.src.utils import jax_utils # avoid circular imports + + if self._parent_path is None: + # Avoid mutating _parent_path during a JAX trace if it's part of + # nnx.Object state and the object was created at a different trace + # level. We check if we are in NNX mode and if we are in a JAX + # trace. + if not (is_nnx_enabled() and jax_utils.is_in_jax_tracing_scope()): + self._parent_path = current_path() + + return backend.name_scope(self.name, caller=self) + + def rematerialized_call(self, layer_call, *args, **kwargs): + """Enable rematerialization dynamically for layer's call method. + + Args: + layer_call: The original `call` method of a layer. + + Returns: + Rematerialized layer's `call` method. + """ + + def compute_size(x): + return ( + math.prod([d or 1 for d in x.shape]) + if isinstance(x, KerasTensor) + else 0 + ) + + # Full rematerialization + if self._remat_mode.mode == "full": + return remat.remat(layer_call) + + # Apply rematerialization to specific layers + elif self._remat_mode.mode == "list_of_layers" and ( + self.name in self._remat_mode.layer_names + ): + return remat.remat(layer_call) + + # Apply rematerialization based on output size threshold + elif self._remat_mode.mode == "larger_than": + output_spec = self.compute_output_spec(*args, **kwargs) + output_size = sum( + tree.flatten(tree.map_structure(compute_size, output_spec)) + ) + if ( + output_size + and output_size > self._remat_mode.output_size_threshold + ): + return remat.remat(layer_call) + elif self._remat_mode.mode == "activations": + has_activation = ( + hasattr(self, "activation") and self.activation is not None + ) + if has_activation: + + @functools.wraps(layer_call) + def rematerialized_activation_call_wrapper(*args, **kwargs): + original_activation = self.activation + self.activation = remat.remat(original_activation) + try: + return layer_call(*args, **kwargs) + finally: + self.activation = original_activation + + return rematerialized_activation_call_wrapper + return layer_call + + def _register_call_context_args(self, *names): + """Registers call-context args for this layer. + + If this layer declares a `call()` method that accepts + one or more of the given args, those args will be + automatically injected into the call signature of this + layer. This layer will also propagate the args to any + nested sublayers that are called from within this layer. + + If this layer doesn't declare a `call()` method that + accepts one or more of the given args, these args will + simply be propagated to any nested sublayers without + being injected into the call signature of this layer. + This is useful for propagating custom arguments + from top-level layers/models to sublayers. + + Example: + ``` + class Inner(layers.Layer): + + def __init__(self): + super().__init__() + # Register `foo_mode` as a call-context arg + self._register_call_context_args("foo_mode") + + def call(self, x, foo_mode=False): + # If foo_mode=True add 1, otherwise add 0 + add_val = ops.where(foo_mode, 1.0, 0.0) + return x + add_val + + class Outer(layers.Layer): + def __init__(self): + super().__init__() + self.inner = Inner() + + def call(self, x): + # We don't explicitly pass foo_mode here—Base Layer.__call__ + # should inject it into `self.inner` + return self.inner(x) + + sample_input = np.array([[1.0], [2.0]]) + + # Sequential model + seq = models.Sequential([Outer()]) + + # Tell the Sequential model to propagate foo_mode down + # the call-stack + seq.register_call_context_args("foo_mode") + + # foo_mode=True -> input + 1 + out_true = seq(sample_input, foo_mode=True) + """ + if self._called: + raise RuntimeError( + "Cannot add call-context args after the layer has been called." + ) + self._call_context_args = self._call_context_args | set(names) + + self._call_has_context_arg.update( + {arg: (arg in self.call_signature_parameters) for arg in names} + ) + + +def is_backend_tensor_or_symbolic(x, allow_none=False): + if allow_none and x is None: + return True + return backend.is_tensor(x) or isinstance(x, backend.KerasTensor) + + +class CallSpec: + def __init__(self, signature, call_context_args, args, kwargs): + # Strip out user-supplied call-context args that this layer’s `call()` + # does not accept (otherwise `signature.bind` would raise). + # This includes built-in args like `training`, and user-defined args. + call_args = { + context_arg: kwargs.pop(context_arg) + for context_arg in call_context_args + if context_arg in kwargs and context_arg not in signature.parameters + } + + bound_args = signature.bind(*args, **kwargs) + + # Combine the two dicts. + self.user_arguments_dict = {**call_args, **bound_args.arguments} + + bound_args.apply_defaults() + arg_dict = {} + arg_names = [] + tensor_arg_dict = {} + tensor_args = [] + tensor_arg_names = [] + nested_tensor_arg_names = [] + for name, value in bound_args.arguments.items(): + arg_dict[name] = value + arg_names.append(name) + if is_backend_tensor_or_symbolic(value): + tensor_args.append(value) + tensor_arg_names.append(name) + tensor_arg_dict[name] = value + elif tree.is_nested(value) and len(value) > 0: + flat_values = tree.flatten(value) + if all( + is_backend_tensor_or_symbolic(x, allow_none=True) + for x in flat_values + ): + tensor_args.append(value) + tensor_arg_names.append(name) + tensor_arg_dict[name] = value + nested_tensor_arg_names.append(name) + elif any(is_backend_tensor_or_symbolic(x) for x in flat_values): + raise ValueError( + "In a nested call() argument, " + "you cannot mix tensors and non-tensors. " + "Received invalid mixed argument: " + f"{name}={value}" + ) + self.arguments_dict = arg_dict + self.argument_names = arg_names + self.tensor_arguments_dict = tensor_arg_dict + self.tensor_arguments_names = tensor_arg_names + self.nested_tensor_argument_names = nested_tensor_arg_names + self.first_arg = arg_dict[arg_names[0]] + if all( + backend.is_tensor(x) for x in self.tensor_arguments_dict.values() + ): + self.eager = True + else: + self.eager = False + + +def get_arguments_dict(fn, args, kwargs): + """Return a dict mapping argument names to their values.""" + sig = inspect.signature(fn) + bound_args = sig.bind(*args, **kwargs) + arg_dict = {} + for name, value in bound_args.arguments.items(): + arg_dict[name] = value + return arg_dict + + +def get_shapes_dict(call_spec): + """Convert the call() arguments dict into a dict of input shape arguments. + + Example: + + ``` + >>> get_shapes_dict(call_spec) + {"input_a_shape": (2, 3)} + ``` + """ + shapes_dict = {} + for k, v in call_spec.tensor_arguments_dict.items(): + if k == "mask" or k.endswith("_mask"): + # Do not include mask tensors in shapes dict + continue + if k == "kwargs" or k == "args": + # Do not include catch-alls in shapes dict + continue + if k in call_spec.nested_tensor_argument_names: + shapes_dict[f"{k}_shape"] = tree.map_structure( + lambda x: backend.standardize_shape(x.shape), v + ) + else: + shapes_dict[f"{k}_shape"] = backend.standardize_shape(v.shape) + return shapes_dict + + +def update_shapes_dict_for_target_fn( + target_fn, + shapes_dict, + call_spec, + class_name, +): + """Updates a `shapes_dict` for `build()` or `compute_output_shape()`. + + This function will align a dictionary of the shapes of all tensor + passed to `call`, with the signatures of `build()` or + `compute_output_shape()`. + + The alignment is a follows: + + - If `build()` or `compute_output_shape()` accept only one argument, + forward the shape of the first positional argument from call without + checking any argument names. + - If `build()` or `compute_output_shape()` accept multiple arguments, + enforce that all argument names match a call argument name, e.g. + `foo_shape` would match call argument `foo`. + + Returns: + An updated `shapes_dict` that can be used to invoke + `target_fn(**shapes_dict)`. + """ + if utils.is_default(target_fn): + return None + sig = inspect.signature(target_fn) + expected_names = [] + for name, param in sig.parameters.items(): + if param.kind in ( + param.POSITIONAL_OR_KEYWORD, + param.POSITIONAL_ONLY, + param.KEYWORD_ONLY, + ): + expected_names.append(name) + + # Single arg: don't check names, pass first shape. + if len(expected_names) == 1: + key = expected_names[0] + values = tuple(shapes_dict.values()) + if values: + input_shape = values[0] + else: + input_shape = None + return {key: input_shape} + + # Multiple args: check that all names line up. + kwargs = {} + for name in expected_names: + method_name = target_fn.__name__ + error_preamble = ( + f"For a `{method_name}()` method with more than one argument, all " + "arguments should have a `_shape` suffix and match an argument " + f"from `call()`. E.g. `{method_name}(self, foo_shape, bar_shape)` " + ) + if not name.endswith("_shape"): + raise ValueError( + f"{error_preamble} For layer '{class_name}', " + f"Received `{method_name}()` argument " + f"`{name}`, which does not end in `_shape`." + ) + expected_call_arg = utils.removesuffix(name, "_shape") + if expected_call_arg not in call_spec.arguments_dict: + raise ValueError( + f"{error_preamble} For layer '{class_name}', " + f"received `{method_name}()` argument " + f"`{name}`, but `call()` does not have argument " + f"`{expected_call_arg}`." + ) + if name in shapes_dict: + kwargs[name] = shapes_dict[name] + + return kwargs + + +class CallContext: + def __init__(self, entry_layer): + self.entry_layer = entry_layer + + def get_value(self, arg_name, default=None): + """Get the context value for `arg_name`, or `default` if unset.""" + return getattr(self, arg_name, default) + + def set_value(self, arg_name, value): + """Set `arg_name` = `value` on this context object.""" + setattr(self, arg_name, value) + + +def is_shape_tuple(s): + return isinstance(s, (list, tuple)) and all( + d is None or isinstance(d, int) for d in s + ) + + +def might_have_unbuilt_state(layer): + return any(not lr.built for lr in layer._layers) diff --git a/keras/src/layers/layer_test.py b/keras/src/layers/layer_test.py new file mode 100644 index 000000000000..53531b679cc5 --- /dev/null +++ b/keras/src/layers/layer_test.py @@ -0,0 +1,1760 @@ +import pickle +from unittest import mock + +import numpy as np +import pytest +from absl.testing import parameterized + +from keras.src import Input +from keras.src import backend +from keras.src import dtype_policies +from keras.src import layers +from keras.src import metrics +from keras.src import models +from keras.src import ops +from keras.src import testing +from keras.src.backend.common import global_state +from keras.src.backend.common.remat import RematScope +from keras.src.models import Model +from keras.src.utils import traceback_utils + + +class MockRemat: + """Mock remat by returning a wrapper Mock calling the original function""" + + def __init__(self): + self.rematted_functions = {} + + def __call__(self, func): + if func in self.rematted_functions: + return self.rematted_functions[func] + + wrapped_func = mock.Mock(wraps=func) + self.rematted_functions[func] = wrapped_func + return wrapped_func + + +class LayerTest(testing.TestCase): + def test_compute_output_spec(self): + # Test that implementing compute_output_shape + # is enough to make compute_output_spec work. + + # Case: single output + class TestLayer(layers.Layer): + def call(self, x): + assert False # Should never be called. + + def compute_output_shape(self, input_shape): + return input_shape + + layer = TestLayer() + self.assertEqual( + layer.compute_output_spec(backend.KerasTensor((2, 3))).shape, (2, 3) + ) + + # Case: tuple output + class TestLayer(layers.Layer): + def call(self, x): + assert False # Should never be called. + + def compute_output_shape(self, input_shape): + return (input_shape, input_shape) + + layer = TestLayer() + out = layer.compute_output_spec(backend.KerasTensor((2, 3))) + self.assertIsInstance(out, tuple) + self.assertEqual(len(out), 2) + self.assertEqual(out[0].shape, (2, 3)) + self.assertEqual(out[1].shape, (2, 3)) + + # Case: list output + class TestLayer(layers.Layer): + def call(self, x): + assert False # Should never be called. + + def compute_output_shape(self, input_shape): + return [input_shape, input_shape] + + layer = TestLayer() + out = layer.compute_output_spec(backend.KerasTensor((2, 3))) + self.assertIsInstance(out, list) + self.assertEqual(len(out), 2) + self.assertEqual(out[0].shape, (2, 3)) + self.assertEqual(out[1].shape, (2, 3)) + + # Case: dict output + class TestLayer(layers.Layer): + def call(self, x): + assert False # Should never be called. + + def compute_output_shape(self, input_shape): + return {"1": input_shape, "2": input_shape} + + layer = TestLayer() + out = layer.compute_output_spec(backend.KerasTensor((2, 3))) + self.assertIsInstance(out, dict) + self.assertEqual(len(out), 2) + self.assertEqual(out["1"].shape, (2, 3)) + self.assertEqual(out["2"].shape, (2, 3)) + + # Case: nested tuple output + class TestLayer(layers.Layer): + def call(self, x): + assert False # Should never be called. + + def compute_output_shape(self, input_shape): + return ( + input_shape, + (input_shape, input_shape), + (input_shape, input_shape), + ) + + layer = TestLayer() + out = layer.compute_output_spec(backend.KerasTensor((2, 3))) + self.assertIsInstance(out, tuple) + self.assertEqual(len(out), 3) + self.assertEqual(out[0].shape, (2, 3)) + self.assertIsInstance(out[1], tuple) + self.assertEqual(len(out[1]), 2) + self.assertEqual(out[1][0].shape, (2, 3)) + self.assertEqual(out[1][1].shape, (2, 3)) + self.assertIsInstance(out[2], tuple) + self.assertEqual(len(out[2]), 2) + self.assertEqual(out[2][0].shape, (2, 3)) + self.assertEqual(out[2][1].shape, (2, 3)) + + # Case: nested dict output + class TestLayer(layers.Layer): + def call(self, x): + assert False # Should never be called. + + def compute_output_shape(self, input_shape): + return { + "1": input_shape, + "2": {"11": input_shape, "22": input_shape}, + } + + layer = TestLayer() + out = layer.compute_output_spec(backend.KerasTensor((2, 3))) + self.assertIsInstance(out, dict) + self.assertEqual(len(out), 2) + self.assertEqual(out["1"].shape, (2, 3)) + self.assertIsInstance(out["2"], dict) + self.assertEqual(len(out["2"]), 2) + self.assertEqual(out["2"]["11"].shape, (2, 3)) + self.assertEqual(out["2"]["22"].shape, (2, 3)) + + def test_positional_arg_error(self): + class SomeLayer(layers.Layer): + def call(self, x, bool_arg): + if bool_arg: + return x + return x + 1 + + x = backend.KerasTensor(shape=(2, 3), name="x") + with self.assertRaisesRegex( + ValueError, "Only input tensors may be passed as" + ): + SomeLayer()(x, True) + + # This works + SomeLayer()(x, bool_arg=True) + + @parameterized.named_parameters( + ("call", "call", None), + ("compute_output_shape", "compute_output_shape", None), + ( + "quantized_build", + "quantized_build", + {"input_shape": None, "mode": None}, + ), + ("quantize", "quantize", {"mode": "int8"}), + ("_int8_call", "_int8_call", None), + ("_float8_call", "_float8_call", None), + ) + def test_not_implemented_error(self, method, args): + layer = layers.Layer() + layer.built = True + + with self.assertRaisesRegex( + NotImplementedError, + f"does not have a `{method}` method implemented.", + ): + if isinstance(args, dict): + getattr(layer, method)(**args) + else: + getattr(layer, method)(args) + + def test_layer_with_remat(self): + """Test rematerialization on a simple layer.""" + # Create a mock to track calls to remat + mock_remat = MockRemat() + with mock.patch( + "keras.src.backend.common.remat.remat", wraps=mock_remat + ): + + class SomeLayer(layers.Layer): + def call(self, x): + return x + 1 + + input_tensor = backend.random.uniform((2, 4)) + layer = SomeLayer() + # Case 1: Without rematerialization + output_no_remat = layer(input_tensor) + + # Case 2: With rematerialization + with RematScope(mode="full"): + layer = SomeLayer() + output_with_remat = layer(input_tensor) + + # Assert outputs are the same + self.assertAllClose(output_no_remat, output_with_remat) + + # Ensure remat was applied in the second case + self.assertLen(mock_remat.rematted_functions, 1) + next(iter(mock_remat.rematted_functions.values())).assert_called() + + def test_quantized_layer_with_remat(self): + """Test rematerialization on a quantized layer.""" + mock_remat = MockRemat() + with mock.patch( + "keras.src.backend.common.remat.remat", wraps=mock_remat + ): + input_tensor = backend.random.uniform((2, 4)) + + # Case 2: With rematerialization + with RematScope(mode="full"): + layer = layers.Dense(3) + layer.build((2, 4)) + layer.quantize("float8") + layer(input_tensor) + + # Ensure remat was applied + self.assertLen(mock_remat.rematted_functions, 1) + next(iter(mock_remat.rematted_functions.values())).assert_called() + + def test_functional_model_with_remat(self): + if backend.backend() in ("openvino", "numpy"): + self.skipTest( + "remat is not supported in openvino and numpy backends." + ) + traceback_utils.enable_traceback_filtering() + mock_remat = MockRemat() + with mock.patch( + "keras.src.backend.common.remat.remat", wraps=mock_remat + ): + # Define model inputs + inputs = Input(shape=(32, 32, 3)) + + # just one layer in remat scope + with RematScope(mode="activations"): + layer = layers.Dense(64, activation="relu") + output = layer(layers.Flatten()(inputs)) + + # Build the functional model + model = Model(inputs=inputs, outputs=output) + + # Compile the model + model.compile(optimizer="adam", loss="mse") + + # Generate dummy data for testing + x_train = np.random.random((10, 32, 32, 3)).astype(np.float32) + y_train = np.random.random((10, 64)).astype(np.float32) + + # Run training to ensure `RematScope` is applied correctly + model.fit(x_train, y_train, epochs=1, batch_size=2, verbose=0) + + self.assertLen(mock_remat.rematted_functions, 1) + next(iter(mock_remat.rematted_functions.values())).assert_called() + + def test_remat_wrapper_list_of_layers(self): + """Test rematerialization using list_of_layers mode.""" + mock_remat = MockRemat() + with mock.patch( + "keras.src.backend.common.remat.remat", wraps=mock_remat + ): + + class TestLayer(layers.Layer): + def call(self, x): + return x + 1 + + class OtherLayer(layers.Layer): + def call(self, x): + return x * 2 + + remat_layers = ["test_layer"] + input_tensor = backend.random.uniform((4, 4)) + + with RematScope(mode="list_of_layers", layer_names=remat_layers): + test_layer = TestLayer(name="test_layer") + other_layer = OtherLayer(name="other_layer") + output_test = test_layer(input_tensor) + output_other = other_layer(input_tensor) + + self.assertAllClose(output_test, input_tensor + 1) + self.assertAllClose(output_other, input_tensor * 2) + + # Ensure remat was applied to the correct layer + self.assertLen(mock_remat.rematted_functions, 1) + next(iter(mock_remat.rematted_functions.values())).assert_called() + + def test_remat_larger_than_mode(self): + """Test rematerialization using larger_than mode.""" + mock_remat = MockRemat() + with mock.patch( + "keras.src.backend.common.remat.remat", wraps=mock_remat + ): + + class TestLayer(layers.Layer): + def compute_output_shape(self, input_shape): + return input_shape + + def call(self, x): + return x + 1 + + input_tensor = backend.random.uniform((100, 100)) # Large tensor + + with RematScope(mode="larger_than", output_size_threshold=5000): + layer = TestLayer() + output = layer(input_tensor) + + self.assertAllClose(output, input_tensor + 1) + + # Ensure remat was applied + self.assertLen(mock_remat.rematted_functions, 1) + next(iter(mock_remat.rematted_functions.values())).assert_called() + + def test_remat_larger_than_mode_high_threshold(self): + """Test rematerialization using larger_than mode.""" + mock_remat = MockRemat() + with mock.patch( + "keras.src.backend.common.remat.remat", wraps=mock_remat + ): + + class TestLayer(layers.Layer): + def compute_output_shape(self, input_shape): + return input_shape + + def call(self, x): + return x + 1 + + input_tensor = backend.random.uniform((100, 100)) # Large tensor + + with RematScope(mode="larger_than", output_size_threshold=50000): + layer = TestLayer() + output = layer(input_tensor) + + self.assertAllClose(output, input_tensor + 1) + + # Ensure remat was not applied + self.assertLen(mock_remat.rematted_functions, 0) + + def test_rng_seed_tracking(self): + class RNGLayer(layers.Layer): + def __init__(self): + super().__init__() + self.seed_gen = backend.random.SeedGenerator(seed=1337) + + def call(self, x): + return x * backend.random.normal(x.shape, seed=self.seed_gen) + + layer = RNGLayer() + self.assertEqual(layer.variables, [layer.seed_gen.state]) + self.assertAllClose(layer.variables[0], [1337, 0]) + layer(np.ones((3, 4))) + self.assertAllClose(layer.variables[0], [1337, 1]) + + # Test tracking in list attributes. + class RNGListLayer(layers.Layer): + def __init__(self): + super().__init__() + self.seed_gens = [] + self.seed_gens.append(backend.random.SeedGenerator(seed=1)) + self.seed_gens.append(backend.random.SeedGenerator(seed=10)) + + def call(self, x): + x = x * backend.random.normal(x.shape, seed=self.seed_gens[0]) + x = x * backend.random.normal(x.shape, seed=self.seed_gens[1]) + return x + + layer = RNGListLayer() + self.assertEqual( + layer.variables, + [layer.seed_gens[0].state, layer.seed_gens[1].state], + ) + self.assertAllClose(layer.variables[0], [1, 0]) + self.assertAllClose(layer.variables[1], [10, 0]) + layer(np.ones((3, 4))) + self.assertAllClose(layer.variables[0], [1, 1]) + self.assertAllClose(layer.variables[1], [10, 1]) + + def test_layer_tracking(self): + class LayerWithDenseLayers(layers.Layer): + def __init__(self, units): + super().__init__() + self.dense1 = layers.Dense(units) + self.layer_dict = { + "dense2": layers.Dense(units), + } + self.layer_list = [layers.Dense(units)] + self.units = units + self.seed_generator = backend.random.SeedGenerator(seed=1) + + def build(self, input_shape): + self.layer_list.append(layers.Dense(self.units)) + + def call(self, x): + x = self.dense1(x) + x = self.layer_dict["dense2"](x) + x = self.layer_list[0](x) + x = self.layer_list[1](x) + return x + + class ParentLayer(layers.Layer): + def __init__(self, inner_layer): + super().__init__() + self.inner_layer = inner_layer + + def call(self, x): + return self.inner_layer(x) + + layer = LayerWithDenseLayers(3) + layer.build((1, 3)) + self.assertLen(layer._layers, 4) + layer(np.zeros((1, 3))) + self.assertLen(layer.variables, 9) + self.assertLen(layer.weights, 8) + + layer = ParentLayer(LayerWithDenseLayers(3)) + self.assertLen(layer._layers, 1) + layer(np.zeros((1, 3))) + self.assertLen(layer.variables, 9) + self.assertLen(layer.weights, 8) + + layer = ParentLayer(ParentLayer(LayerWithDenseLayers(3))) + self.assertLen(layer._layers, 1) + layer(np.zeros((1, 3))) + self.assertLen(layer.variables, 9) + self.assertLen(layer.weights, 8) + + def test_metric_tracking(self): + class LayerWithMetric(layers.Layer): + def __init__(self, units): + super().__init__() + self.dense = layers.Dense(units) + self.metric = metrics.MeanSquaredError(name="my_metric") + + def build(self, input_shape): + self.dense.build(input_shape) + + def call(self, x): + return self.dense(x) + + class ParentLayerWithMetric(layers.Layer): + def __init__(self, inner_layer): + super().__init__() + self.inner_layer = inner_layer + self.metric = metrics.MeanSquaredError(name="my_metric") + + def build(self, input_shape): + self.inner_layer.build(input_shape) + + def call(self, x): + return self.inner_layer(x) + + layer = LayerWithMetric(3) + layer.build((1, 3)) + + self.assertLen(layer.metrics, 1) + self.assertLen(layer.metrics_variables, 2) + self.assertLen(layer.trainable_variables, 2) + self.assertLen(layer.non_trainable_variables, 0) + + layer = ParentLayerWithMetric(LayerWithMetric(3)) + layer.build((1, 3)) + + self.assertLen(layer.metrics, 2) + self.assertLen(layer.metrics_variables, 4) + self.assertLen(layer.trainable_variables, 2) + self.assertLen(layer.non_trainable_variables, 0) + + layer = ParentLayerWithMetric(ParentLayerWithMetric(LayerWithMetric(3))) + layer.build((1, 3)) + + self.assertLen(layer.metrics, 3) + self.assertLen(layer.metrics_variables, 6) + self.assertLen(layer.trainable_variables, 2) + self.assertLen(layer.non_trainable_variables, 0) + + def test_build_on_call(self): + class LayerWithUnbuiltState(layers.Layer): + def __init__(self, units): + super().__init__() + self.dense1 = layers.Dense(units) + + def call(self, x): + return self.dense1(x) + + layer = LayerWithUnbuiltState(2) + layer(backend.KerasTensor((3, 4))) + self.assertLen(layer.weights, 2) + + class KwargsLayerWithUnbuiltState(layers.Layer): + def __init__(self, units): + super().__init__() + self.dense1 = layers.Dense(units) + self.dense2 = layers.Dense(units) + + def call(self, x1, x2): + return self.dense1(x1) + self.dense2(x2) + + layer = KwargsLayerWithUnbuiltState(2) + layer(backend.KerasTensor((3, 4)), backend.KerasTensor((3, 4))) + self.assertLen(layer.weights, 4) + + layer = KwargsLayerWithUnbuiltState(2) + layer(x1=backend.KerasTensor((3, 4)), x2=backend.KerasTensor((3, 4))) + self.assertLen(layer.weights, 4) + + def test_activity_regularization(self): + class ActivityRegularizer(layers.Layer): + def call(self, x): + return x + + layer = ActivityRegularizer(activity_regularizer="l1") + layer(np.ones((1,))) + self.assertLen(layer.losses, 1) + self.assertAllClose(layer.losses[0], 0.01) + + # losses are reset upon call + layer(np.ones((1,))) + self.assertLen(layer.losses, 1) + self.assertAllClose(layer.losses[0], 0.01) + + # KerasTensors are no op + layer = ActivityRegularizer(activity_regularizer="l1") + layer(layers.Input(batch_shape=(2, 2))) + self.assertLen(layer.losses, 0) + + @pytest.mark.requires_trainable_backend + def test_add_loss(self): + class LossLayer(layers.Layer): + def call(self, x): + self.add_loss(ops.sum(x)) + return x + + layer = LossLayer() + layer(np.ones((1,))) + self.assertLen(layer.losses, 1) + self.assertAllClose(layer.losses[0], 1.0) + + # losses are reset upon call + layer = LossLayer() + layer(np.ones((1,))) + self.assertLen(layer.losses, 1) + self.assertAllClose(layer.losses[0], 1.0) + + # It works inside a model + model = models.Sequential([layer]) + model(np.ones((1,))) + self.assertLen(model.losses, 1) + self.assertAllClose(model.losses[0], 1.0) + + # It works recursively in nested models + model = models.Sequential([model]) + model(np.ones((1,))) + self.assertLen(model.losses, 1) + self.assertAllClose(model.losses[0], 1.0) + + def test_training_arg_value_resolution(self): + # Check that even if `training` is not passed + # to an inner layer, the outer value gets propagated + # in __call__. + class TrainingLayer(layers.Layer): + def __init__(self): + super().__init__() + self.dp = layers.Dropout(0.9) + + def call(self, x, training=False): + return self.dp(x) + + layer = TrainingLayer() + x = np.ones((4, 4)) + y = layer(x) + self.assertEqual(ops.min(y), 1) + y = layer(x, training=True) + self.assertEqual(ops.min(y), 0) + + # Check that it still works one level deeper. + class WrappedTrainingLayer(layers.Layer): + def __init__(self): + super().__init__() + self.dp = TrainingLayer() + + def call(self, x, training=False): + return self.dp(x) + + layer = WrappedTrainingLayer() + x = np.ones((4, 4)) + y = layer(x) + self.assertEqual(ops.min(y), 1) + y = layer(x, training=True) + self.assertEqual(ops.min(y), 0) + + # Check that if `training` is passed + # to an inner layer in call(), the explicitly + # passed value is what the layer sees. + class TrainingLayerExplicit(layers.Layer): + def __init__(self): + super().__init__() + self.dp = layers.Dropout(0.9) + + def call(self, x, training=False): + return self.dp(x, training=True) + + layer = TrainingLayerExplicit() + x = np.ones((4, 4)) + y = layer(x, training=False) + self.assertEqual(ops.min(y), 0) + + # Test that layer interruption does not cause + # the call context to linger + class BadLayer(layers.Layer): + def call(self, x, training=False): + raise RuntimeError("oops!") + + x = np.ones((4, 4)) + layer = BadLayer() + try: + # training=True will be recorded + # in the call context + layer(x, training=True) + except RuntimeError: + pass + layer = TrainingLayer() + # But this layer call should not see it + y = layer(x) + self.assertEqual(ops.min(y), 1) + + @pytest.mark.skipif( + backend.backend() == "torch", + reason="Some torch ops not implemented for float16 on CPU.", + ) + def test_mixed_precision(self): + x = np.ones((4, 4)) + + layer = layers.Dense(2, dtype="float16") + y = layer(x) + self.assertEqual(layer.compute_dtype, "float16") + self.assertEqual(layer.variable_dtype, "float16") + self.assertDType(y, "float16") + + layer = layers.Dense(2, dtype="mixed_float16") + y = layer(x) + self.assertEqual(layer.compute_dtype, "float16") + self.assertEqual(layer.variable_dtype, "float32") + self.assertDType(y, "float16") + self.assertEqual(layer.kernel.dtype, "float32") + + @pytest.mark.skipif( + backend.backend() == "torch", + reason="Some torch ops not implemented for float16 on CPU.", + ) + def test_autocast(self): + assertDType = self.assertDType + + # A layer with a int dtype (some preprocessing layers do this). + class InnerLayerOne(layers.Layer): + def __init__(self): + super().__init__(dtype="int") + self.v = self.add_weight( + shape=(), + initializer="ones", + trainable=True, + dtype="float32", + ) + self._build_at_init() + + def call(self, x): + # Should not autocast. + assertDType(self.v, "float32") + return ops.add(ops.cast(x, "float32"), self.v) + + # A layer that is explicitly full precision. + class InnerLayerTwo(layers.Layer): + def __init__(self): + super().__init__(dtype="float32") + self.v = self.add_weight( + shape=(), + initializer="ones", + trainable=True, + ) + self._build_at_init() + + def call(self, x): + # Should not autocast. + assertDType(self.v, "float32") + return ops.add(x, self.v) + + # A layer that is explicitly mixed precision but with autocast=False + # weight. + class InnerLayerThree(layers.Layer): + def __init__(self): + super().__init__(dtype="mixed_float16") + self.v = self.add_weight( + shape=(), + initializer="ones", + trainable=True, + autocast=False, + ) + self._build_at_init() + + def call(self, x): + # Should not autocast `self.v`. + assertDType(self.v, "float32") + return ops.add(x, self.v) + + # A layer that is explicitly mixed precision with inner layers. + class MixedPrecisionLayer(layers.Layer): + def __init__(self): + super().__init__(dtype="mixed_float16") + self.v = self.add_weight( + shape=(), + initializer="ones", + trainable=True, + ) + self.inner_one = InnerLayerOne() + self.inner_two = InnerLayerTwo() + self.inner_three = InnerLayerThree() + self._build_at_init() + + def call(self, x): + # Should autocast. + assertDType(self.v, "float16") + return self.inner_three( + self.inner_two(self.inner_one(ops.add(x, self.v))) + ) + + layer = MixedPrecisionLayer() + y = layer(np.array(0.0)) + self.assertEqual(y, 4.0) + + def test_autocast_with_np_array(self): + assertDType = self.assertDType + + class CustomLayer(layers.Layer): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def call(self, x): + # Here are the assertions. + assertDType(x[0], "float32") # Cast to compute_dtype + assertDType(x[1], "int32") # Untouched + + x = [np.zeros(1, dtype="float64"), np.zeros(1, dtype="int32")] + CustomLayer()(x) + + @pytest.mark.skipif( + backend.backend() == "numpy", reason="masking not supported with numpy" + ) + def test_end_to_end_masking(self): + # Check that masking survives compilation + model = models.Sequential( + [ + layers.Embedding( + 2, 2, mask_zero=True, embeddings_initializer="ones" + ), + ] + ) + model.compile(loss="mse") + targets = np.array([[[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [1.0, 1.0]]]) + loss = model.evaluate(np.array([[1, 0, 0, 1]]), targets, verbose=0) + self.assertAllClose(loss, 0.0) + + @pytest.mark.skipif( + backend.backend() == "numpy", reason="masking not supported with numpy" + ) + def test_masking(self): + class BasicMaskedLayer(layers.Layer): + def __init__(self): + super().__init__() + self.supports_masking = True + + def call(self, x, mask=None): + assert mask is not None + return x + + layer = BasicMaskedLayer() + x = backend.numpy.ones((4, 4)) + mask = backend.numpy.ones((4,)) + backend.set_keras_mask(x, mask) + layer(x) + + layer(backend.numpy.ones((4, 4)), mask=backend.numpy.ones((4,))) + + class NestedInputMaskedLayer(layers.Layer): + def __init__(self): + super().__init__() + self.supports_masking = True + + def call(self, x, mask=None): + assert isinstance(x, list) + assert len(x) == 2 + assert isinstance(mask, list) + assert len(mask) == 2 + return x + + layer = NestedInputMaskedLayer() + x1 = backend.numpy.ones((4, 4)) + mask1 = backend.numpy.ones((4,)) + backend.set_keras_mask(x1, mask1) + x2 = backend.numpy.ones((4, 4)) + mask2 = backend.numpy.ones((4,)) + backend.set_keras_mask(x2, mask2) + layer([x1, x2]) + + layer( + [backend.numpy.ones((4, 4)), backend.numpy.ones((4, 4))], + mask=[backend.numpy.ones((4,)), backend.numpy.ones((4,))], + ) + + class PositionalInputsMaskedLayer(layers.Layer): + def __init__(self): + super().__init__() + self.supports_masking = True + + def call(self, x1, x2, x1_mask=None, x2_mask=None): + assert x1_mask is not None + assert x2_mask is not None + return x1 + x2 + + layer = PositionalInputsMaskedLayer() + layer(x1, x2) + layer(x1=x1, x2=x2) + + class PositionalNestedInputsMaskedLayer(layers.Layer): + def __init__(self): + super().__init__() + self.supports_masking = True + + def call(self, x1, x2, x1_mask=None, x2_mask=None): + assert isinstance(x1, tuple) + assert x1_mask is not None + assert x2_mask is not None + assert isinstance(x1_mask, tuple) + return x1[0] + x1[1] + x2 + + layer = PositionalNestedInputsMaskedLayer() + x1_1 = backend.numpy.ones((4, 4)) + mask1 = backend.numpy.ones((4,)) + backend.set_keras_mask(x1_1, mask1) + x1_2 = backend.numpy.ones((4, 4)) + mask2 = backend.numpy.ones((4,)) + backend.set_keras_mask(x1_2, mask2) + x2 = backend.numpy.ones((4, 4)) + mask2 = backend.numpy.ones((4,)) + backend.set_keras_mask(x2, mask2) + layer((x1_1, x1_2), x2) + layer(x1=(x1_1, x1_2), x2=x2) + + class MaskUnsetDuringCallLayer(layers.Layer): + def __init__(self): + super().__init__() + self.supports_masking = True + + def call(self, x, mask=None): + assert mask is not None + backend.set_keras_mask(x, None) # Unset mask + return x + + layer = MaskUnsetDuringCallLayer() + x = backend.numpy.ones((4, 4)) + mask = backend.numpy.ones((4,)) + backend.set_keras_mask(x, mask) + y = layer(x) + self.assertAllClose(y._keras_mask, mask) + + @pytest.mark.skipif( + backend.backend() == "numpy", reason="masking not supported with numpy" + ) + def test_masking_with_explicit_kwarg_propagation(self): + """This test validates that an explicit `mask` kwarg is correctly + used to compute the output mask. + """ + + class PassthroughMaskLayer(layers.Layer): + def __init__(self): + super().__init__() + self.supports_masking = True + + def call(self, x, mask=None): + # The layer itself can use the mask. + self.used_mask = mask is not None + return x + + layer = PassthroughMaskLayer() + # Create an input tensor WITHOUT an attached mask. + x = backend.numpy.ones((4, 4)) + self.assertIsNone(getattr(x, "_keras_mask", None)) + + # Create a mask to be passed explicitly. + explicit_mask = backend.numpy.array([True, True, False, False]) + + # Call the layer, passing the mask as a keyword argument. + y = layer(x, mask=explicit_mask) + + # Assert that the layer's internal call received the mask. + self.assertTrue(layer.used_mask) + + # Assert that the output tensor 'y' now has the explicit mask attached + # for propagation to the next layer. + self.assertAllClose(backend.get_keras_mask(y), explicit_mask) + + def test_stateless_call(self): + class TestLayer(layers.Layer): + def __init__(self): + super().__init__() + self._seed_generator = backend.random.SeedGenerator(1337) + self.ntw = self.add_weight( + shape=(), + initializer="zeros", + trainable=False, + ) + self.tw = self.add_weight( + shape=(), + initializer="zeros", + trainable=True, + regularizer="l1", + ) + self._build_at_init() + + def call(self, x): + x = backend.convert_to_tensor(x, dtype="float32") + self.add_loss(ops.sum(x)) + self.ntw.assign(ops.sum(x)) + x = x + backend.random.normal( + shape=(), seed=self._seed_generator + ) + return ops.add(x, ops.add(self.tw, self.ntw)) + + data = np.random.random((3, 4)) + layer = TestLayer() + out = layer(data) + layer1 = TestLayer() + out1 = layer1(data) + # Check that the layer is in fact deterministic + self.assertAllClose(out, out1) + + # Test stateless_call correctness + layer2 = TestLayer() + trainable_variables = layer2.trainable_variables + non_trainable_variables = layer2.non_trainable_variables + out2, non_trainable_variables = layer2.stateless_call( + trainable_variables, non_trainable_variables, data + ) + self.assertAllClose(out1, out2) + self.assertEqual( + len(layer1.non_trainable_variables), len(non_trainable_variables) + ) + for ref_v, v in zip( + layer1.non_trainable_variables, non_trainable_variables + ): + self.assertAllClose(ref_v, v) + + # Test with loss collection + layer3 = TestLayer() + trainable_variables = layer3.trainable_variables + non_trainable_variables = layer3.non_trainable_variables + out3, non_trainable_variables, losses = layer3.stateless_call( + trainable_variables, + non_trainable_variables, + data, + return_losses=True, + ) + self.assertAllClose(out1, out3) + for ref_v, v in zip( + layer1.non_trainable_variables, non_trainable_variables + ): + self.assertAllClose(ref_v, v) + self.assertLen(losses, 2) + for ref_loss, loss in zip(layer1.losses, losses): + self.assertAllClose(ref_loss, loss) + + def test_trainable_setting(self): + class NonTrainableWeightsLayer(layers.Layer): + def build(self, _): + self.w1 = self.add_weight( + shape=(), + initializer="ones", + trainable=True, + ) + self.w2 = self.add_weight( + shape=(), + initializer="ones", + trainable=False, + ) + self.seed = backend.random.SeedGenerator(123) + + def call(self, inputs): + return inputs + + class NestedNonTrainableWeightsLayer(layers.Layer): + def build(self, _): + self.w1 = self.add_weight( + shape=(), + initializer="ones", + trainable=True, + ) + self.w2 = self.add_weight( + shape=(), + initializer="ones", + trainable=False, + ) + self.nested = NonTrainableWeightsLayer() + self.nested.build(None) + + def call(self, inputs): + return inputs + + layer = NestedNonTrainableWeightsLayer() + layer.build(None) + self.assertEqual(len(layer.trainable_weights), 2) + self.assertEqual(len(layer.trainable_variables), 2) + self.assertEqual(len(layer.non_trainable_weights), 2) + self.assertEqual(len(layer.non_trainable_variables), 3) + + layer.trainable = False + self.assertEqual(len(layer.trainable_weights), 0) + self.assertEqual(len(layer.trainable_variables), 0) + self.assertEqual(len(layer.non_trainable_weights), 4) + self.assertEqual(len(layer.non_trainable_variables), 5) + self.assertFalse(layer.w1.trainable) + self.assertFalse(layer.nested.w1.trainable) + + layer.trainable = True + self.assertEqual(len(layer.trainable_weights), 2) + self.assertEqual(len(layer.trainable_variables), 2) + self.assertEqual(len(layer.non_trainable_weights), 2) + self.assertEqual(len(layer.non_trainable_variables), 3) + self.assertTrue(layer.w1.trainable) + self.assertTrue(layer.nested.w1.trainable) + + layer = NestedNonTrainableWeightsLayer(trainable=False) + layer.build(None) + self.assertEqual(len(layer.trainable_weights), 0) + self.assertEqual(len(layer.trainable_variables), 0) + self.assertEqual(len(layer.non_trainable_weights), 4) + self.assertEqual(len(layer.non_trainable_variables), 5) + + layer.trainable = True + self.assertEqual(len(layer.trainable_weights), 2) + self.assertEqual(len(layer.trainable_variables), 2) + self.assertEqual(len(layer.non_trainable_weights), 2) + self.assertEqual(len(layer.non_trainable_variables), 3) + + def test_build_signature_errors(self): + class NoShapeSuffix(layers.Layer): + def build(self, foo_shape, bar): + self.built = True + + def call(self, foo, bar): + return foo + bar + + class NonMatchingArgument(layers.Layer): + def build(self, foo_shape, baz_shape): + self.built = True + + def call(self, foo, bar): + return foo[:, 0] + bar[:, 0] + + class MatchingArguments(layers.Layer): + def build(self, bar_shape, foo_shape): + self.foo_shape = foo_shape + self.bar_shape = bar_shape + + def call(self, foo, bar): + return foo[:, 0] + bar[:, 0] + + class SubsetArguments(layers.Layer): + def build(self, baz_shape, foo_shape): + self.foo_shape = foo_shape + self.baz_shape = baz_shape + + def call(self, foo, bar=None, baz=None): + return foo[:, 0] + bar[:, 0] + baz[:, 0] + + class SingleArgument(layers.Layer): + def build(self, anything_whatsoever): + self.foo_shape = anything_whatsoever + + def call(self, foo, bar): + return foo[:, 0] + bar[:, 0] + + foo = backend.numpy.ones((4, 1)) + bar = backend.numpy.ones((4, 2)) + baz = backend.numpy.ones((4, 3)) + with self.assertRaisesRegex( + ValueError, + r"argument `bar`, which does not end in `_shape`", + ): + layer = NoShapeSuffix() + layer(foo, bar) + + with self.assertRaisesRegex( + ValueError, + r"`baz_shape`, but `call\(\)` does not have argument `baz`", + ): + layer = NonMatchingArgument() + layer(foo, bar) + + # Align by name when build and call arguments match. + layer = MatchingArguments() + layer(foo, bar) + self.assertEqual(layer.foo_shape, foo.shape) + self.assertEqual(layer.bar_shape, bar.shape) + + # Align by name when build supports a subset of call arguments. + layer = SubsetArguments() + layer(foo, bar, baz) + self.assertEqual(layer.foo_shape, foo.shape) + self.assertEqual(layer.baz_shape, baz.shape) + + # When build has only one argument, match the first call argument. + layer = SingleArgument() + layer(foo, bar) + self.assertEqual(layer.foo_shape, foo.shape) + + def test_training_arg_not_specified(self): + class NoTrainingSpecified(layers.Layer): + def __init__(self): + super().__init__() + + def build(self, input_shape): + self.activation = layers.Activation("linear") + + def call(self, inputs): + return self.activation(inputs) + + layer = NoTrainingSpecified() + inputs = ops.random.uniform(shape=(1, 100, 100, 3)) + layer(inputs, training=True) + + def test_tracker_locking(self): + class BadLayer(layers.Layer): + def call(self, x): + self.w = self.add_weight(initializer="zeros", shape=()) + return x + + layer = BadLayer() + with self.assertRaisesRegex( + ValueError, + "cannot add new elements of state", + ): + layer(np.random.random((3, 2))) + + def test_init_after_state_tracking(self): + class MyLayer(layers.Layer): + def __init__(self): + self.some_attr = True + self.w = backend.Variable(np.random.random((2,))) + super().__init__() + + layer = MyLayer() + self.assertEqual(len(layer.weights), 1) + + def test_add_weight_defaults(self): + class MyLayer(layers.Layer): + def __init__(self): + super().__init__() + self.w1 = self.add_weight() + self.w2 = self.add_weight(dtype="int32", trainable=False) + self.w3 = self.add_weight(dtype="bool", trainable=False) + self.w4 = self.add_weight( + dtype="int32", shape=(2, 2), trainable=False + ) + self.w5 = self.add_weight(initializer="ones", shape=(2, 2)) + + layer = MyLayer() + self.assertEqual(layer.w1.shape, ()) + self.assertEqual(layer.w1.dtype, "float32") + + self.assertEqual(layer.w2.shape, ()) + self.assertEqual(layer.w2.dtype, "int32") + self.assertAllClose(backend.convert_to_numpy(layer.w2), 0) + + self.assertEqual(layer.w3.shape, ()) + self.assertEqual(layer.w3.dtype, "bool") + self.assertAllClose(backend.convert_to_numpy(layer.w3), False) + + self.assertEqual(layer.w4.shape, (2, 2)) + self.assertEqual(layer.w4.dtype, "int32") + self.assertAllClose( + backend.convert_to_numpy(layer.w4), np.zeros((2, 2)) + ) + + self.assertEqual(layer.w5.shape, (2, 2)) + self.assertEqual(layer.w5.dtype, "float32") + self.assertAllClose(backend.convert_to_numpy(layer.w5), np.ones((2, 2))) + + def test_remove_weight(self): + class MyLayer(layers.Layer): + def __init__(self): + super().__init__() + self.w = self.add_weight() + + def custom_remove_w(self): + self.w = self._untrack_variable(self.w) + + def custom_change_dtype(self): + self.w = self._untrack_variable(self.w) + self.w = self.add_weight( + initializer="zeros", dtype="int8", trainable=False + ) + + layer = MyLayer() + self.assertEqual(len(layer.weights), 1) + layer.custom_remove_w() + self.assertEqual(len(layer.weights), 0) + self.assertEqual(layer.w, None) + + layer = MyLayer() + self.assertEqual(layer.w.dtype, "float32") + self.assertEqual(layer.w.trainable, True) + layer.custom_change_dtype() + self.assertEqual(layer.w.dtype, "int8") + self.assertEqual(layer.w.trainable, False) + + def test_trainable_init_arg(self): + inputs = layers.Input(shape=(1,)) + layer = layers.Dense(2, trainable=False) + outputs = layer(inputs) + model = models.Model(inputs, outputs) + + self.assertFalse(layer.trainable) + self.assertLen(layer._trainable_variables, 2) + self.assertLen(layer._non_trainable_variables, 0) + self.assertLen(layer.trainable_weights, 0) + self.assertLen(model.trainable_weights, 0) + self.assertLen(model.non_trainable_weights, 2) + + layer.trainable = True + self.assertTrue(layer.trainable) + self.assertLen(layer._trainable_variables, 2) + self.assertLen(layer._non_trainable_variables, 0) + self.assertLen(layer.trainable_weights, 2) + self.assertLen(model.trainable_weights, 2) + self.assertLen(model.non_trainable_weights, 0) + + def test_dtype_policy_setter(self): + layer = layers.Dense(2) + # Set by string + layer.dtype_policy = "mixed_bfloat16" + self.assertEqual(layer.dtype_policy.name, "mixed_bfloat16") + self.assertEqual(layer.dtype_policy.compute_dtype, "bfloat16") + self.assertEqual(layer.dtype_policy.variable_dtype, "float32") + # Set by DTypePolicy + layer.dtype_policy = dtype_policies.DTypePolicy("mixed_float16") + self.assertEqual(layer.dtype_policy.name, "mixed_float16") + self.assertEqual(layer.dtype_policy.compute_dtype, "float16") + self.assertEqual(layer.dtype_policy.variable_dtype, "float32") + # Set with DTypePolicyMap + dtype_policy_map = dtype_policies.DTypePolicyMap() + layer = layers.Dense(2, dtype=dtype_policy_map) + layer.build([None, 1]) + layer.dtype_policy = "mixed_bfloat16" + self.assertIsInstance( + layer._dtype_policy, dtype_policies.DTypePolicyMap + ) + self.assertEqual( + layer._dtype_policy[layer.path], + dtype_policies.DTypePolicy("mixed_bfloat16"), + ) + + def test_pickle_layer(self): + layer = layers.Dense(2) + reloaded = pickle.loads(pickle.dumps(layer)) + self.assertEqual(layer.get_config(), reloaded.get_config()) + + def test_serialize_dtype(self): + assertIsNone = self.assertIsNone + assertIsNotNone = self.assertIsNotNone + + class AssertionDense(layers.Dense): + def __init__(self, *args, **kwargs): + dtype = kwargs["dtype"] + if isinstance(dtype, str): + # `dtype` is a plain string, it should be the `name` from a + # `DTypePolicy` + dtype = dtype_policies.get(dtype) + assertIsNone(dtype.quantization_mode) + else: + # `dtype` is a DTypePolicy instance, it should be an + # instance of `QuantizedDTypePolicy` + assertIsNotNone(dtype.quantization_mode) + super().__init__(*args, **kwargs) + + # Test floating dtype serialization + layer = layers.Dense(2, dtype="bfloat16") + config = layer.get_config() + self.assertIn("dtype", config) + self.assertEqual( + config["dtype"], + dtype_policies.serialize(dtype_policies.DTypePolicy("bfloat16")), + ) + AssertionDense.from_config(config) # Assertion inside + + # Test quantized dtype serialization + layer = layers.Dense(2, dtype="int8_from_bfloat16") + config = layer.get_config() + self.assertIn("dtype", config) + self.assertEqual( + config["dtype"], + dtype_policies.serialize(dtype_policies.get("int8_from_bfloat16")), + ) + AssertionDense.from_config(config) # Assertion inside + + def test_serialize_activity_regularizer(self): + layer = layers.Dense(2, activity_regularizer="l2") + config = layer.get_config() + self.assertIn("activity_regularizer", config) + new_layer = layers.Dense.from_config(config) + self.assertEqual( + new_layer.activity_regularizer.__class__.__name__, "L2" + ) + + layer = layers.Dense(2) + config = layer.get_config() + self.assertNotIn("activity_regularizer", config) + + def test_custom_layer_add_weight_in_init_name(self): + class TrainingLayer(layers.Layer): + def __init__(self): + super().__init__() + self.inner = InnerLayer() + + class InnerLayer(layers.Layer): + def __init__(self): + super().__init__() + self.var = self.add_weight( + shape=(1,), + name="inner", + ) + self.inner = InnerInnerLayer() + + class InnerInnerLayer(layers.Layer): + def __init__(self): + super().__init__() + self.var = self.add_weight( + shape=(1,), + name="inner", + ) + + layer = TrainingLayer() + layer.build(None) + self.assertEqual(len(layer.variables), 2) + variable_paths = set(v.path for v in layer.variables) + self.assertTrue("inner_layer/inner" in variable_paths) + self.assertTrue("inner_inner_layer/inner" in variable_paths) + if backend.backend() == "torch": + parameter_names = set( + param_name.replace("_torch_params.", "") + for param_name, _ in layer.named_parameters() + ) + self.assertSetEqual(variable_paths, parameter_names) + + def test_custom_layer_add_weight_in_build_name(self): + class TrainingLayer(layers.Layer): + def __init__(self): + super().__init__() + self.inner = InnerLayer() + + def call(self, input): + return self.inner(input) + + class InnerLayer(layers.Layer): + def __init__(self): + super().__init__() + self.inner = InnerInnerLayer() + + def build(self, _): + self.var = self.add_weight( + shape=(1,), + name="inner", + ) + + def call(self, input): + return self.var + self.inner(input) + + class InnerInnerLayer(layers.Layer): + def __init__(self): + super().__init__() + + def build(self, _): + self.var = self.add_weight( + shape=(1,), + name="inner", + ) + + def call(self, input): + return self.var + input + + layer = TrainingLayer() + output = layer( + backend.KerasTensor( + (4, 1), + ) + ) + self.assertEqual(output.shape, (4, 1)) + self.assertEqual(len(layer.variables), 2) + variable_paths = set(v.path for v in layer.variables) + self.assertTrue("training_layer/inner_layer/inner" in variable_paths) + self.assertTrue( + "training_layer/inner_layer/inner_inner_layer/inner" + in variable_paths + ) + if backend.backend() == "torch": + parameter_names = set( + param_name.replace("_torch_params.", "") + for param_name, _ in layer.named_parameters() + ) + self.assertSetEqual(variable_paths, parameter_names) + + def test_layer_variable_tracking_correct(self): + class TrainingLayer(layers.Layer): + def __init__(self): + super().__init__() + self.post_build_modify_layer = PostBuildModifyLayer() + + def call(self, input): + return self.post_build_modify_layer(input) + + class PostBuildModifyLayer(layers.Layer): + def call(self, input): + return self.var + input + + def build(self, _): + self.var = self.add_weight( + shape=(2,), + name="var", + ) + + def post_build_add(self): + self._tracker.unlock() + self.additional_var = self.add_weight( + shape=(2,), + name="var2", + ) + self._tracker.lock() + + def post_build_remove(self): + self._tracker.unlock() + self._untrack_variable(self.var) + del self.var + self._tracker.lock() + + layer = TrainingLayer() + output = layer(backend.KerasTensor((4, 2))) + + self.assertEqual(output.shape, (4, 2)) + self.assertEqual(len(layer.variables), 1) + self.assertEqual( + layer.variables[0].path, + "training_layer/post_build_modify_layer/var", + ) + if backend.backend() == "torch": + parameter_names = [pname for pname, _ in layer.named_parameters()] + self.assertEqual(len(parameter_names), 1) + self.assertEqual( + parameter_names[0], + "_torch_params.training_layer/post_build_modify_layer/var", + ) + + layer.post_build_modify_layer.post_build_add() + self.assertEqual(len(layer.variables), 2) + self.assertEqual( + layer.variables[0].path, + "training_layer/post_build_modify_layer/var", + ) + self.assertEqual( + layer.variables[1].path, + "training_layer/post_build_modify_layer/var2", + ) + if backend.backend() == "torch": + # TODO (haohuanw, fchollet): Needs further discussion on how to + # properly manage torch params. Post build modification cannot + # propagate to parent torch params. + parameter_names = [pname for pname, _ in layer.named_parameters()] + # Below check should have 2 parameters instead of 1. + self.assertEqual(len(parameter_names), 1) + self.assertEqual( + parameter_names[0], + "_torch_params.training_layer/post_build_modify_layer/var", + ) + + parameter_names = [ + pname + for pname, _ in layer.post_build_modify_layer.named_parameters() + ] + self.assertEqual(len(parameter_names), 2) + self.assertEqual( + parameter_names[0], + "_torch_params.training_layer/post_build_modify_layer/var", + ) + self.assertEqual( + parameter_names[1], + "_torch_params.training_layer/post_build_modify_layer/var2", + ) + + layer.post_build_modify_layer.post_build_remove() + self.assertEqual(len(layer.variables), 1) + self.assertEqual( + layer.variables[0].path, + "training_layer/post_build_modify_layer/var2", + ) + if backend.backend() == "torch": + # TODO (haohuanw, fchollet): Needs further discussion on how to + # properly manage torch params. Post build modification cannot + # propagate to parent torch params. + parameter_names = [pname for pname, _ in layer.named_parameters()] + # Below check should have 1 parameters instead of 2, torch_params + # in parent layer is wrong. + self.assertEqual(len(parameter_names), 2) + self.assertEqual( + parameter_names[0], + "post_build_modify_layer._torch_params.training_layer/" + "post_build_modify_layer/var2", + ) + self.assertEqual( + parameter_names[1], + "_torch_params.training_layer/post_build_modify_layer/var", + ) + + parameter_names = [ + pname + for pname, _ in layer.post_build_modify_layer.named_parameters() + ] + self.assertEqual(len(parameter_names), 1) + self.assertEqual( + parameter_names[0], + "_torch_params.training_layer/post_build_modify_layer/var2", + ) + + @pytest.mark.skipif(backend.backend() != "torch", reason="Torch only test.") + def test_torch_params_create_deterministic(self): + class MyLayer(layers.Layer): + def __init__(self): + super().__init__() + self.w1 = self.add_weight() + self.w2 = self.add_weight(dtype="int32", trainable=False) + self.w3 = self.add_weight(dtype="bool", trainable=False) + self.w4 = self.add_weight( + dtype="int32", shape=(2, 2), trainable=False + ) + self.w5 = self.add_weight(initializer="ones", shape=(2, 2)) + + layer1 = MyLayer() + layer1.build(None) + layer1_names = list(pname for pname, _ in layer1.named_parameters()) + global_state.clear_session() + layer2 = MyLayer() + layer2.build(None) + layer2_names = list(pname for pname, _ in layer2.named_parameters()) + self.assertListEqual(layer1_names, layer2_names) + + def test_complex_dtype_support(self): + class MyDenseLayer(layers.Layer): + def __init__(self, num_outputs): + super(MyDenseLayer, self).__init__() + self.num_outputs = num_outputs + + def build(self, input_shape): + self.kernel = self.add_weight( + shape=[int(input_shape[-1]), self.num_outputs], + ) + + def call(self, inputs): + kernel = ops.cast(self.kernel, "complex64") + return ops.matmul(inputs, kernel) + + inputs = ops.zeros([10, 5], dtype="complex64") + layer = MyDenseLayer(10) + output = layer(inputs) + self.assertAllEqual(output.shape, (10, 10)) + + def test_call_context_args_with_custom_layers(self): + class Inner(layers.Layer): + def __init__(self): + super().__init__() + self._register_call_context_args("foo_mode") + + def call(self, x, foo_mode=None): + return x + (1 if foo_mode else 0) + + class Outer(layers.Layer): + def __init__(self): + super().__init__() + self._register_call_context_args("foo_mode") + self.inner = Inner() + + def call(self, x): + # Outer doesn’t even need to re‑inject explicitly: + # our base class will propagate foo_mode automatically + return self.inner(x) + + layer = Outer() + self.assertEqual(int(layer(np.array(0), foo_mode=True)), 1) + self.assertEqual(int(layer(np.array(0))), 0) + + def test_register_call_context_arguments(self): + """Validate that registering call-context args works as expected.""" + + class MyLayer(layers.Layer): + def call(self, x): + return x + + layer = MyLayer() + + layer._register_call_context_args("foo_mode") + + self.assertCountEqual( + layer._call_context_args, ("foo_mode", "training") + ) + + def test_register_call_context_arguments_after_call(self): + """Validate that registering call-context args after the layer has + been called raises an error.""" + + class MyLayer(layers.Layer): + def call(self, x): + return x + + layer = MyLayer() + layer(np.array(0)) + with self.assertRaisesRegex( + RuntimeError, + "Cannot add call-context args after the layer has been called.", + ): + layer._register_call_context_args("foo_mode") + + def test_context_args_with_triple_nesting_and_priority(self): + """Validate that call-context args are propagated correctly + through multiple layers, and that the most specific value is used + when multiple values are passed down the call-stack. + """ + + class Inner(layers.Layer): + def __init__(self): + super().__init__() + self._register_call_context_args("foo_mode") + + def call(self, x, foo_mode=None): + return x + (1 if foo_mode else 0) + + class Middle(layers.Layer): + def __init__(self): + super().__init__() + self.inner = Inner() + + def call(self, x): + return self.inner(x) + + class Outer(layers.Layer): + def __init__(self): + super().__init__() + self.middle = Middle() + + def call(self, x): + # Outer explicitly sets foo_mode=False when calling Inner, + # so the value being passed here should be ignored. + return self.middle(x) + + layer = Outer() + layer._register_call_context_args("foo_mode") + + # The value of foo_mode is set to True in the call to Outer, + # so it should automatically propagate to Inner through Middle. + self.assertEqual(int(layer(np.array(0), foo_mode=True)), 1) + self.assertEqual(int(layer(np.array(0))), 0) + + def test_context_arg_propagation_without_declaration(self): + """Validate that layer does not resolve a propagated arg if it is not + declared as a call-context arg in the layer itself.""" + + class Inner(layers.Layer): + def call(self, x, foo_mode=None): + return x + (1 if foo_mode else 0) + + class Wrapper(layers.Layer): + def __init__(self): + super().__init__() + self.inner = Inner() + + def call(self, x): + return self.inner(x) + + layer = Wrapper() + layer._register_call_context_args("foo_mode") + + # The value of foo_mode is set to True in the call to Wrapper, + # However, it is not declared as a call-context arg in Inner, + # so it should not resolve to True inside Inner (and instead + # default to False). + self.assertEqual(int(layer(np.array(0), foo_mode=True)), 0) + + def test_call_context_args_with_func_seq_models_as_layers(self): + """Validate that call-context args are propagated correctly + through functional and sequential models when used as layers. + """ + + class Inner(layers.Layer): + def __init__(self): + super().__init__() + self._register_call_context_args("foo_mode") + + def call(self, x, foo_mode=False): + # If foo_mode=True add 1, otherwise add 0 + add_val = ops.where(foo_mode, 1.0, 0.0) + return x + add_val + + class Outer(layers.Layer): + def __init__(self): + super().__init__() + self.inner = Inner() + + def call(self, x): + # We don’t explicitly pass foo_mode here—Base Layer.__call__ + # should inject it into `self.inner` + return self.inner(x) + + sample_input = np.array([[1.0], [2.0]]) + + # Sequential model + seq = models.Sequential([layers.Identity(), Outer()]) + # Tell the Sequential model to propagate foo_mode down + # the call-stack + seq._register_call_context_args("foo_mode") + + # foo_mode=True -> input + 1 + out_true = seq(sample_input, foo_mode=True) + self.assertAllClose(out_true, sample_input + 1.0) + + # foo_mode omitted -> foo_mode defaults to False -> no change + out_false = seq(sample_input) + self.assertAllClose(out_false, sample_input) + + # Functional model + inp = Input(shape=(1,)) + out = layers.Identity()(inp) + out = Outer()(out) + model = models.Model(inp, out) + # Tell the Functional model to propagate foo_mode down + # the call-stack + model._register_call_context_args("foo_mode") + + # foo_mode=True -> input + 1 + y1 = model(sample_input, foo_mode=True) + self.assertAllClose(y1, sample_input + 1.0) + + # foo_mode omitted -> foo_mode defaults to False -> no change + y2 = model(sample_input) + self.assertAllClose(y2, sample_input) diff --git a/keras/src/layers/merging/__init__.py b/keras/src/layers/merging/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/keras/src/layers/merging/add.py b/keras/src/layers/merging/add.py new file mode 100644 index 000000000000..bf5f1b2a6aac --- /dev/null +++ b/keras/src/layers/merging/add.py @@ -0,0 +1,69 @@ +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.layers.merging.base_merge import Merge + + +@keras_export("keras.layers.Add") +class Add(Merge): + """Performs elementwise addition operation. + + It takes as input a list of tensors, all of the same shape, + and returns a single tensor (also of the same shape). + + Examples: + + >>> input_shape = (2, 3, 4) + >>> x1 = np.random.rand(*input_shape) + >>> x2 = np.random.rand(*input_shape) + >>> y = keras.layers.Add()([x1, x2]) + + Usage in a Keras model: + + >>> input1 = keras.layers.Input(shape=(16,)) + >>> x1 = keras.layers.Dense(8, activation='relu')(input1) + >>> input2 = keras.layers.Input(shape=(32,)) + >>> x2 = keras.layers.Dense(8, activation='relu')(input2) + >>> # equivalent to `added = keras.layers.add([x1, x2])` + >>> added = keras.layers.Add()([x1, x2]) + >>> out = keras.layers.Dense(4)(added) + >>> model = keras.models.Model(inputs=[input1, input2], outputs=out) + + """ + + def _merge_function(self, inputs): + output = inputs[0] + for i in range(1, len(inputs)): + output = ops.add(output, inputs[i]) + return output + + +@keras_export("keras.layers.add") +def add(inputs, **kwargs): + """Functional interface to the `keras.layers.Add` layer. + + Args: + inputs: A list of input tensors with the same shape. + **kwargs: Standard layer keyword arguments. + + Returns: + A tensor as the sum of the inputs. It has the same shape as the inputs. + + Examples: + + >>> input_shape = (2, 3, 4) + >>> x1 = np.random.rand(*input_shape) + >>> x2 = np.random.rand(*input_shape) + >>> y = keras.layers.add([x1, x2]) + + Usage in a Keras model: + + >>> input1 = keras.layers.Input(shape=(16,)) + >>> x1 = keras.layers.Dense(8, activation='relu')(input1) + >>> input2 = keras.layers.Input(shape=(32,)) + >>> x2 = keras.layers.Dense(8, activation='relu')(input2) + >>> added = keras.layers.add([x1, x2]) + >>> out = keras.layers.Dense(4)(added) + >>> model = keras.models.Model(inputs=[input1, input2], outputs=out) + + """ + return Add(**kwargs)(inputs) diff --git a/keras/src/layers/merging/average.py b/keras/src/layers/merging/average.py new file mode 100644 index 000000000000..f90f75beead0 --- /dev/null +++ b/keras/src/layers/merging/average.py @@ -0,0 +1,70 @@ +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.layers.merging.base_merge import Merge + + +@keras_export("keras.layers.Average") +class Average(Merge): + """Averages a list of inputs element-wise.. + + It takes as input a list of tensors, all of the same shape, + and returns a single tensor (also of the same shape). + + Examples: + + >>> input_shape = (2, 3, 4) + >>> x1 = np.random.rand(*input_shape) + >>> x2 = np.random.rand(*input_shape) + >>> y = keras.layers.Average()([x1, x2]) + + Usage in a Keras model: + + >>> input1 = keras.layers.Input(shape=(16,)) + >>> x1 = keras.layers.Dense(8, activation='relu')(input1) + >>> input2 = keras.layers.Input(shape=(32,)) + >>> x2 = keras.layers.Dense(8, activation='relu')(input2) + >>> # equivalent to `y = keras.layers.average([x1, x2])` + >>> y = keras.layers.Average()([x1, x2]) + >>> out = keras.layers.Dense(4)(y) + >>> model = keras.models.Model(inputs=[input1, input2], outputs=out) + + """ + + def _merge_function(self, inputs): + output = inputs[0] + for i in range(1, len(inputs)): + output = ops.add(output, inputs[i]) + return output / len(inputs) + + +@keras_export("keras.layers.average") +def average(inputs, **kwargs): + """Functional interface to the `keras.layers.Average` layer. + + Args: + inputs: A list of input tensors , all of the same shape. + **kwargs: Standard layer keyword arguments. + + Returns: + A tensor as the element-wise product of the inputs with the same + shape as the inputs. + + Examples: + + >>> input_shape = (2, 3, 4) + >>> x1 = np.random.rand(*input_shape) + >>> x2 = np.random.rand(*input_shape) + >>> y = keras.layers.average([x1, x2]) + + Usage in a Keras model: + + >>> input1 = keras.layers.Input(shape=(16,)) + >>> x1 = keras.layers.Dense(8, activation='relu')(input1) + >>> input2 = keras.layers.Input(shape=(32,)) + >>> x2 = keras.layers.Dense(8, activation='relu')(input2) + >>> y = keras.layers.average([x1, x2]) + >>> out = keras.layers.Dense(4)(y) + >>> model = keras.models.Model(inputs=[input1, input2], outputs=out) + + """ + return Average(**kwargs)(inputs) diff --git a/keras/src/layers/merging/base_merge.py b/keras/src/layers/merging/base_merge.py new file mode 100644 index 000000000000..10689b54208d --- /dev/null +++ b/keras/src/layers/merging/base_merge.py @@ -0,0 +1,280 @@ +from keras.src import backend +from keras.src import ops +from keras.src.backend.common.keras_tensor import KerasTensor +from keras.src.layers.layer import Layer + + +class Merge(Layer): + """Generic merge layer for elementwise merge functions. + + Used to implement `Sum`, `Average`, etc. + + Args: + **kwargs: standard layer keyword arguments. + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.supports_masking = True + + def _merge_function(self, inputs): + raise NotImplementedError + + def _apply_merge_op_and_or_mask(self, op_fn, inputs): + """Merge a set of inputs by applying `op_fn` and ORing the masks. + + We use this for `Minimum` and `Maximum` as it handles the fact that + there is no identity element. If applicable, the mask obtained by ORing + all masks is set on the output. + + Args: + op_fn: binary operation to apply to tensor pair. + inputs: array of tensors to apply operation on. + """ + output = None + output_mask = None + + for x in inputs: + mask = backend.get_keras_mask(x) + if mask is not None: + mask = ops.broadcast_to(ops.expand_dims(mask, -1), ops.shape(x)) + if output is None: + output = x + output_mask = mask + continue + if mask is not None: + x = ops.where(mask, x, output) + if output_mask is not None: + output = ops.where(output_mask, output, x) + if mask is not None and output_mask is not None: + output_mask = ops.logical_or(output_mask, mask) + else: + output_mask = None + output = op_fn(output, x) + + if output_mask is not None: + output_mask = ops.any(output_mask, axis=-1, keepdims=False) + backend.set_keras_mask(output, output_mask) + return output + + def _compute_elemwise_op_output_shape(self, shape1, shape2): + """Computes the shape of the resultant of an elementwise operation. + + Args: + shape1: Tuple or None. Shape of the first tensor + shape2: Tuple or None. Shape of the second tensor + + Returns: + Expected output shape when an element-wise operation is + carried out on 2 tensors with shapes shape1 and shape2. + tuple or None. + + Raises: + ValueError: If shape1 and shape2 are not compatible for + element-wise operations. + """ + + if None in [shape1, shape2]: + return None + elif len(shape1) < len(shape2): + return self._compute_elemwise_op_output_shape(shape2, shape1) + elif not shape2: + return shape1 + output_shape = list(shape1[: -len(shape2)]) + for i, j in zip(shape1[-len(shape2) :], shape2): + if i is None or j is None: + output_shape.append(None) + elif i == 1: + output_shape.append(j) + elif j == 1: + output_shape.append(i) + else: + if i != j: + raise ValueError( + "Inputs have incompatible shapes. " + f"Received shapes {shape1} and {shape2}" + ) + output_shape.append(i) + return tuple(output_shape) + + def build(self, input_shape): + # Used purely for shape validation. + if not isinstance(input_shape[0], (tuple, list)): + raise ValueError( + "A merge layer should be called on a list of inputs. " + f"Received: input_shape={input_shape} (not a list of shapes)" + ) + if len(input_shape) < 1: + raise ValueError( + "A merge layer should be called " + "on a list of at least 1 input. " + f"Received {len(input_shape)} inputs. " + f"Full input_shape received: {input_shape}" + ) + + batch_sizes = {s[0] for s in input_shape if s} - {None} + if len(batch_sizes) > 1: + raise ValueError( + "Cannot merge tensors with different batch sizes. " + f"Received tensors with shapes {input_shape}" + ) + + if input_shape[0] is None: + output_shape = None + else: + output_shape = input_shape[0][1:] + + for i in range(1, len(input_shape)): + if input_shape[i] is None: + shape = None + else: + shape = input_shape[i][1:] + output_shape = self._compute_elemwise_op_output_shape( + output_shape, shape + ) + + # If the inputs have different ranks, we have to reshape them + # to make them broadcastable. + if None not in input_shape and len(set(map(len, input_shape))) == 1: + self._reshape_required = False + else: + self._reshape_required = True + + def call(self, inputs): + if not isinstance(inputs, (list, tuple)): + raise ValueError( + "A merge layer should be called on a list of inputs. " + f"Received: inputs={inputs} (not a list of tensors)" + ) + if self._reshape_required: + reshaped_inputs = [] + input_ndims = list(map(ops.ndim, inputs)) + if None not in input_ndims: + # If ranks of all inputs are available, + # we simply expand each of them at axis=1 + # until all of them have the same rank. + max_ndim = max(input_ndims) + for x in inputs: + x_ndim = ops.ndim(x) + for _ in range(max_ndim - x_ndim): + x = ops.expand_dims(x, axis=1) + reshaped_inputs.append(x) + return self._merge_function(reshaped_inputs) + else: + # Transpose all inputs so that batch size is the last dimension. + # (batch_size, dim1, dim2, ... ) -> (dim1, dim2, ... , + # batch_size) + transposed = False + for x in inputs: + x_ndim = ops.ndim(x) + + if x_ndim is None: + x_shape = ops.shape(x) + batch_size = x_shape[0] + + new_shape = backend.concatenate( + [x_shape[1:], ops.expand_dims(batch_size, axis=-1)] + ) + x_transposed = ops.reshape( + x, + ops.stack( + [batch_size, ops.prod(x_shape[1:])], + axis=0, + ), + ) + x_transposed = ops.transpose(x_transposed, perm=(1, 0)) + x_transposed = ops.reshape(x_transposed, new_shape) + + reshaped_inputs.append(x_transposed) + transposed = True + + elif x_ndim > 1: + dims = list(range(1, x_ndim)) + [0] + reshaped_inputs.append(ops.transpose(x, perm=dims)) + print(dims) + transposed = True + else: + # We don't transpose inputs if they are 1D vectors or + # scalars. + reshaped_inputs.append(x) + + y = self._merge_function(reshaped_inputs) + y_ndim = ops.ndim(y) + + if transposed: + # If inputs have been transposed, we have to transpose the + # output too. + if y_ndim is None: + y_shape = ops.shape(y) + y_ndim = ops.shape(y_shape)[0] + batch_size = y_shape[y_ndim - 1] + new_shape = ops.concatenate( + [ + ops.expand_dims(batch_size, axis=-1), + y_shape[: y_ndim - 1], + ] + ) + y = ops.reshape(y, (-1, batch_size)) + y = ops.transpose(y, perm=(1, 0)) + y = ops.reshape(y, new_shape) + elif y_ndim > 1: + dims = [y_ndim - 1] + list(range(y_ndim - 1)) + y = ops.transpose(y, perm=dims) + return y + else: + return self._merge_function(inputs) + + def compute_output_shape(self, input_shape): + if input_shape[0] is None: + output_shape = None + else: + output_shape = input_shape[0][1:] + + for i in range(1, len(input_shape)): + if input_shape[i] is None: + shape = None + else: + shape = input_shape[i][1:] + output_shape = self._compute_elemwise_op_output_shape( + output_shape, shape + ) + batch_sizes = {s[0] for s in input_shape if s is not None} - {None} + if len(batch_sizes) == 1: + output_shape = (list(batch_sizes)[0],) + output_shape + else: + output_shape = (None,) + output_shape + return output_shape + + def compute_output_spec(self, inputs): + output_shape = self.compute_output_shape([x.shape for x in inputs]) + output_sparse = all(x.sparse for x in inputs) + return KerasTensor( + output_shape, dtype=self.compute_dtype, sparse=output_sparse + ) + + def compute_mask(self, inputs, mask=None): + if mask is None: + return None + if not isinstance(mask, (tuple, list)): + raise ValueError(f"`mask` should be a list. Received: mask={mask}") + if not isinstance(inputs, (tuple, list)): + raise ValueError( + f"`inputs` should be a list. Received: inputs={inputs}" + ) + if len(mask) != len(inputs): + raise ValueError( + "The lists `inputs` and `mask` should have the same length. " + f"Received: inputs={inputs} of length {len(inputs)}, and " + f"mask={mask} of length {len(mask)}" + ) + # Default implementation does an OR between the masks, which works + # for `Add`, `Subtract`, `Average`, `Maximum`, `Minimum`, `Multiply`. + if any(m is None for m in mask): + return None + output_mask = mask[0] + for m in mask[1:]: + output_mask = ops.logical_or(output_mask, m) + return output_mask + + def get_config(self): + return super().get_config() diff --git a/keras/src/layers/merging/concatenate.py b/keras/src/layers/merging/concatenate.py new file mode 100644 index 000000000000..1ee3913b6581 --- /dev/null +++ b/keras/src/layers/merging/concatenate.py @@ -0,0 +1,178 @@ +import copy + +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.layers.merging.base_merge import Merge + + +@keras_export("keras.layers.Concatenate") +class Concatenate(Merge): + """Concatenates a list of inputs. + + It takes as input a list of tensors, all of the same shape except + for the concatenation axis, and returns a single tensor that is the + concatenation of all inputs. + + Examples: + + >>> x = np.arange(20).reshape(2, 2, 5) + >>> y = np.arange(20, 30).reshape(2, 1, 5) + >>> keras.layers.Concatenate(axis=1)([x, y]) + + Usage in a Keras model: + + >>> x1 = keras.layers.Dense(8)(np.arange(10).reshape(5, 2)) + >>> x2 = keras.layers.Dense(8)(np.arange(10, 20).reshape(5, 2)) + >>> y = keras.layers.Concatenate()([x1, x2]) + + Args: + axis: Axis along which to concatenate. + **kwargs: Standard layer keyword arguments. + + Returns: + A tensor, the concatenation of the inputs alongside axis `axis`. + """ + + def __init__(self, axis=-1, **kwargs): + super().__init__(**kwargs) + self.axis = axis + self.supports_masking = True + self._reshape_required = False + + def build(self, input_shape): + # Used purely for shape validation. + if len(input_shape) < 1 or not isinstance( + input_shape[0], (tuple, list) + ): + raise ValueError( + "A `Concatenate` layer should be called on a list of " + f"at least 1 input. Received: input_shape={input_shape}" + ) + if all(shape is None for shape in input_shape): + return + + reduced_inputs_shapes = [list(shape) for shape in input_shape] + reduced_inputs_shapes_copy = copy.copy(reduced_inputs_shapes) + shape_set = set() + for i in range(len(reduced_inputs_shapes_copy)): + # Convert self.axis to positive axis for each input + # in case self.axis is a negative number + concat_axis = self.axis % len(reduced_inputs_shapes_copy[i]) + # Skip batch axis. + for axis, axis_value in enumerate( + reduced_inputs_shapes_copy, start=1 + ): + # Remove squeezable axes (axes with value of 1) + # if not in the axis that will be used for concatenation + # otherwise leave it. + # This approach allows building the layer, + # but if tensor shapes are not the same when + # calling, an exception will be raised. + if axis != concat_axis and axis_value == 1: + del reduced_inputs_shapes[i][axis] + + if len(reduced_inputs_shapes[i]) > self.axis: + del reduced_inputs_shapes[i][self.axis] + shape_set.add(tuple(reduced_inputs_shapes[i])) + + if len(shape_set) != 1: + err_msg = ( + "A `Concatenate` layer requires inputs with matching shapes " + "except for the concatenation axis. " + f"Received: input_shape={input_shape}" + ) + # Make sure all the shapes have same ranks. + ranks = set(len(shape) for shape in shape_set) + if len(ranks) != 1: + raise ValueError(err_msg) + # Get the only rank for the set. + (rank,) = ranks + for axis in range(rank): + # Skip the Nones in the shape since they are dynamic, also the + # axis for concat has been removed above. + unique_dims = set( + shape[axis] + for shape in shape_set + if shape[axis] is not None + ) + if len(unique_dims) > 1: + raise ValueError(err_msg) + + def _merge_function(self, inputs): + return ops.concatenate(inputs, axis=self.axis) + + def compute_output_shape(self, input_shape): + if (not isinstance(input_shape, (tuple, list))) or ( + not isinstance(input_shape[0], (tuple, list)) + ): + raise ValueError( + "A `Concatenate` layer should be called on a list of inputs. " + f"Received: input_shape={input_shape}" + ) + input_shapes = input_shape + output_shape = list(input_shapes[0]) + + for shape in input_shapes[1:]: + if output_shape[self.axis] is None or shape[self.axis] is None: + output_shape[self.axis] = None + break + output_shape[self.axis] += shape[self.axis] + return tuple(output_shape) + + def compute_mask(self, inputs, mask=None): + if mask is None: + return None + if not isinstance(mask, (tuple, list)): + raise ValueError(f"`mask` should be a list. Received mask={mask}") + if not isinstance(inputs, (tuple, list)): + raise ValueError( + f"`inputs` should be a list. Received: inputs={inputs}" + ) + if len(mask) != len(inputs): + raise ValueError( + "The lists `inputs` and `mask` should have the same length. " + f"Received: inputs={inputs} of length {len(inputs)}, and " + f"mask={mask} of length {len(mask)}" + ) + if all(m is None for m in mask): + return None + # Make a list of masks while making sure + # the dimensionality of each mask + # is the same as the corresponding input. + masks = [] + for input_i, mask_i in zip(inputs, mask): + if mask_i is None: + # Input is unmasked. Append all 1s to masks, + masks.append(ops.ones_like(input_i, dtype="bool")) + elif mask_i.ndim < input_i.ndim: + # Broadcast mask shape to match in a way where we capture the + # input as a symbolic input in the op graph. + mask_i = ops.logical_or( + ops.expand_dims(mask_i, axis=-1), + ops.zeros_like(input_i, dtype="bool"), + ) + masks.append(mask_i) + else: + masks.append(mask_i) + concatenated = ops.concatenate(masks, axis=self.axis) + return ops.any(concatenated, axis=-1, keepdims=False) + + def get_config(self): + config = {"axis": self.axis} + base_config = super().get_config() + return dict(list(base_config.items()) + list(config.items())) + + +@keras_export("keras.layers.concatenate") +def concatenate(inputs, axis=-1, **kwargs): + """Functional interface to the `Concatenate` layer. + + Args: + inputs: A list of input tensors. + axis: Concatenation axis. + **kwargs: Standard layer keyword arguments. + + Returns: + A tensor, the concatenation of the inputs alongside axis `axis`. + """ + return Concatenate(axis=axis, **kwargs)(inputs) diff --git a/keras/src/layers/merging/dot.py b/keras/src/layers/merging/dot.py new file mode 100644 index 000000000000..b49b965828ce --- /dev/null +++ b/keras/src/layers/merging/dot.py @@ -0,0 +1,378 @@ +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.layers.merging.base_merge import Merge +from keras.src.utils.numerical_utils import normalize + + +def batch_dot(x, y, axes=None): + """Batchwise dot product. + + `batch_dot` is used to compute dot product of `x` and `y` when + `x` and `y` are data in batch, i.e. in a shape of `(batch_size, :)`. + `batch_dot` results in a tensor or variable with less dimensions + than the input. If the number of dimensions is reduced to 1, + we use `expand_dims` to make sure that ndim is at least 2. + + Shape inference: + + Let `x`'s shape be `(100, 20)` and `y`'s shape be `(100, 30, 20)`. + If `axes` is (1, 2), to find the output shape of resultant tensor, + loop through each dimension in `x`'s shape and `y`'s shape: + + * `x.shape[0]` : 100 : append to output shape + * `x.shape[1]` : 20 : do not append to output shape, dimension 1 of + `x` has been summed over. (`dot_axes[0]` = 1) + * `y.shape[0]` : 100 : do not append to output shape, always ignore + first dimension of `y` + * `y.shape[1]` : 30 : append to output shape + * `y.shape[2]` : 20 : do not append to output shape, dimension 2 of + `y` has been summed over. + (`dot_axes[1]` = 2) `output_shape` = `(100, 30)` + + Example: + + >>> x_batch = np.ones(shape=(32, 20, 1)) + >>> y_batch = np.ones(shape=(32, 30, 20)) + >>> xy_batch_dot = batch_dot(x_batch, y_batch, axes=(1, 2)) + + Args: + x: Keras tensor or variable with `ndim >= 2`. + y: Keras tensor or variable with `ndim >= 2`. + axes: Tuple or list of integers with target dimensions, or single + integer. The sizes of `x.shape[axes[0]]` and `y.shape[axes[1]]` + should be equal. + Note that axis `0` (the batch axis) cannot be included. + + Returns: + A tensor with shape equal to the concatenation of `x`'s shape + (less the dimension that was summed over) and `y`'s shape (less the + batch dimension and the dimension that was summed over). If the final + rank is 1, we reshape it to `(batch_size, 1)`. + """ + + x_shape = x.shape + y_shape = y.shape + + x_ndim = len(x_shape) + y_ndim = len(y_shape) + + if x_ndim < 2 or y_ndim < 2: + raise ValueError( + f"Cannot do batch_dot on inputs " + f"with rank < 2. " + f"Received inputs with shapes " + f"{x_shape} and {y_shape}." + ) + + x_batch_size = x_shape[0] + y_batch_size = y_shape[0] + + if x_batch_size is not None and y_batch_size is not None: + if x_batch_size != y_batch_size: + raise ValueError( + f"Cannot do batch_dot on inputs " + f"with different batch sizes. " + f"Received inputs with shapes " + f"{x_shape} and {y_shape}." + ) + if isinstance(axes, int): + axes = [axes, axes] + + if axes is None: + if y_ndim == 2: + axes = [x_ndim - 1, y_ndim - 1] + else: + axes = [x_ndim - 1, y_ndim - 2] + + if any(isinstance(a, (list, tuple)) for a in axes): + raise ValueError( + f"Multiple target dimensions are not supported. " + f"Expected: None, int, (int, int), " + f"Provided: {axes} " + ) + + # if tuple, convert to list. + axes = list(axes) + + # convert negative indices. + if axes[0] < 0: + axes[0] += x_ndim + if axes[1] < 0: + axes[1] += y_ndim + + # sanity checks + if 0 in axes: + raise ValueError( + "Cannot perform batch_dot over axis 0. " + "If your inputs are not batched, " + "add a dummy batch dimension to your " + "inputs using keras.ops.expand_dims(x, 0)" + ) + a0, a1 = axes + d1 = x_shape[a0] + d2 = y_shape[a1] + + if d1 is not None and d2 is not None and d1 != d2: + raise ValueError( + f"Cannot do batch_dot on inputs with shapes " + f"{x_shape} and {y_shape} with axes={axes}. " + f"x.shape[{axes[0]}] != y.shape[{axes[1]}] ({d1} != {d2})." + ) + + # backup ndims. Need them later. + orig_x_ndim = x_ndim + orig_y_ndim = y_ndim + + # if rank is 2, expand to 3. + if x_ndim == 2: + x = ops.expand_dims(x, 1) + a0 += 1 + x_ndim += 1 + if y_ndim == 2: + y = ops.expand_dims(y, 2) + y_ndim += 1 + + # bring x's dimension to be reduced to last axis. + if a0 != x_ndim - 1: + pattern = list(range(x_ndim)) + for i in range(a0, x_ndim - 1): + pattern[i] = pattern[i + 1] + pattern[-1] = a0 + x = ops.transpose(x, pattern) + + # bring y's dimension to be reduced to axis 1. + if a1 != 1: + pattern = list(range(y_ndim)) + for i in range(a1, 1, -1): + pattern[i] = pattern[i - 1] + pattern[1] = a1 + y = ops.transpose(y, pattern) + + # normalize both inputs to rank 3. + if x_ndim > 3: + # squash middle dimensions of x. + x_shape = ops.shape(x) + x_mid_dims = x_shape[1:-1] + x_squashed_shape = (x_shape[0], -1, x_shape[-1]) + x = ops.reshape(x, x_squashed_shape) + x_squashed = True + else: + x_squashed = False + + if y_ndim > 3: + # squash trailing dimensions of y. + y_shape = ops.shape(y) + y_trail_dims = y_shape[2:] + y_squashed_shape = (y_shape[0], y_shape[1], -1) + y = ops.reshape(y, y_squashed_shape) + y_squashed = True + else: + y_squashed = False + + result = ops.matmul(x, y) + + # if inputs were squashed, we have to reshape the matmul output. + output_shape = ops.shape(result) + do_reshape = False + + if x_squashed: + output_shape = output_shape[:1] + x_mid_dims + output_shape[-1:] + do_reshape = True + + if y_squashed: + output_shape = output_shape[:-1] + y_trail_dims + do_reshape = True + + if do_reshape: + result = ops.reshape(result, output_shape) + + # if the inputs were originally rank 2, we remove the added 1 dim. + if orig_x_ndim == 2: + result = ops.squeeze(result, 1) + elif orig_y_ndim == 2: + result = ops.squeeze(result, -1) + + return result + + +@keras_export("keras.layers.Dot") +class Dot(Merge): + """Computes element-wise dot product of two tensors. + + It takes a list of inputs of size 2, and the axes + corresponding to each input along with the dot product + is to be performed. + + Let's say `x` and `y` are the two input tensors with shapes + `(2, 3, 5)` and `(2, 10, 3)`. The batch dimension should be + of same size for both the inputs, and `axes` should correspond + to the dimensions that have the same size in the corresponding + inputs. e.g. with `axes=(1, 2)`, the dot product of `x`, and `y` + will result in a tensor with shape `(2, 5, 10)` + + Example: + + >>> x = np.arange(10).reshape(1, 5, 2) + >>> y = np.arange(10, 20).reshape(1, 2, 5) + >>> keras.layers.Dot(axes=(1, 2))([x, y]) + + Usage in a Keras model: + + >>> x1 = keras.layers.Dense(8)(np.arange(10).reshape(5, 2)) + >>> x2 = keras.layers.Dense(8)(np.arange(10, 20).reshape(5, 2)) + >>> y = keras.layers.Dot(axes=1)([x1, x2]) + + Args: + axes: Integer or tuple of integers, axis or axes along which to + take the dot product. If a tuple, should be two integers + corresponding to the desired axis from the first input and the + desired axis from the second input, respectively. Note that the + size of the two selected axes must match, and that + axis `0` (the batch axis) cannot be included. + normalize: Whether to L2-normalize samples along the dot product axis + before taking the dot product. If set to `True`, then + the output of the dot product is the cosine proximity + between the two samples. + **kwargs: Standard layer keyword arguments. + + Returns: + A tensor, the dot product of the samples from the inputs. + """ + + def __init__(self, axes, normalize=False, **kwargs): + super().__init__(**kwargs) + if not isinstance(axes, int): + if not isinstance(axes, (list, tuple)): + raise TypeError( + f"Invalid type for argument `axes`: it should be " + f"a list or an int. Received: axes={axes}" + ) + if len(axes) != 2: + raise ValueError( + f"Invalid format for argument `axes`: it should contain " + f"two elements. Received: axes={axes}" + ) + if not isinstance(axes[0], int) or not isinstance(axes[1], int): + raise ValueError( + f"Invalid format for argument `axes`: list elements should " + f"be integers. Received: axes={axes}" + ) + self.axes = axes + self.normalize = normalize + self.supports_masking = True + self._reshape_required = False + + def build(self, input_shape): + # Used purely for shape validation. + if ( + not isinstance(input_shape[0], (tuple, list)) + or len(input_shape) != 2 + ): + raise ValueError( + f"A `Dot` layer should be called on a list of 2 inputs. " + f"Received: input_shape={input_shape}" + ) + shape1 = input_shape[0] + shape2 = input_shape[1] + if shape1 is None or shape2 is None: + return + if isinstance(self.axes, int): + if self.axes < 0: + axes = [self.axes % len(shape1), self.axes % len(shape2)] + else: + axes = [self.axes] * 2 + else: + axes = self.axes + if shape1[axes[0]] != shape2[axes[1]]: + raise ValueError( + f"Incompatible input shapes: " + f"axis values {shape1[axes[0]]} (at axis {axes[0]}) != " + f"{shape2[axes[1]]} (at axis {axes[1]}). " + f"Full input shapes: {shape1}, {shape2}" + ) + + def _merge_function(self, inputs): + if len(inputs) != 2: + raise ValueError( + f"A `Dot` layer should be called on exactly 2 inputs. " + f"Received: inputs={inputs}" + ) + x1 = inputs[0] + x2 = inputs[1] + + if isinstance(self.axes, int): + if self.axes < 0: + axes = [ + self.axes % len(x1.shape), + self.axes % len(x2.shape), + ] + else: + axes = [self.axes] * 2 + else: + axes = [] + for i in range(len(self.axes)): + if self.axes[i] < 0: + axes.append(self.axes[i] % len(inputs[i].shape)) + else: + axes.append(self.axes[i]) + + if self.normalize: + x1 = normalize(x1, axis=axes[0]) + x2 = normalize(x2, axis=axes[1]) + output = batch_dot(x1, x2, axes) + return output + + def compute_output_shape(self, input_shape): + if not isinstance(input_shape, (tuple, list)) or len(input_shape) != 2: + raise ValueError( + f"A `Dot` layer should be called on a list of 2 inputs. " + f"Received: input_shape={input_shape}" + ) + shape1 = list(input_shape[0]) + shape2 = list(input_shape[1]) + if isinstance(self.axes, int): + if self.axes < 0: + axes = [self.axes % len(shape1), self.axes % len(shape2)] + else: + axes = [self.axes] * 2 + else: + axes = self.axes + shape1.pop(axes[0]) + shape2.pop(axes[1]) + shape2.pop(0) + output_shape = shape1 + shape2 + if len(output_shape) == 1: + output_shape += [1] + return tuple(output_shape) + + def compute_mask(self, inputs, mask=None): + return None + + def get_config(self): + config = { + "axes": self.axes, + "normalize": self.normalize, + } + base_config = super().get_config() + return dict(list(base_config.items()) + list(config.items())) + + +@keras_export("keras.layers.dot") +def dot(inputs, axes=-1, **kwargs): + """Functional interface to the `Dot` layer. + + Args: + inputs: A list of input tensors (at least 2). + axes: Integer or tuple of integers, + axis or axes along which to take the dot product. + Note that axis `0` (the batch axis) cannot be included. + normalize: Whether to L2-normalize samples along the + dot product axis before taking the dot product. + If set to `True`, then the output of the dot product + is the cosine proximity between the two samples. + **kwargs: Standard layer keyword arguments. + + Returns: + A tensor, the dot product of the samples from the inputs. + """ + return Dot(axes=axes, **kwargs)(inputs) diff --git a/keras/src/layers/merging/maximum.py b/keras/src/layers/merging/maximum.py new file mode 100644 index 000000000000..3072ecb625a9 --- /dev/null +++ b/keras/src/layers/merging/maximum.py @@ -0,0 +1,67 @@ +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.layers.merging.base_merge import Merge + + +@keras_export("keras.layers.Maximum") +class Maximum(Merge): + """Computes element-wise maximum on a list of inputs. + + It takes as input a list of tensors, all of the same shape, + and returns a single tensor (also of the same shape). + + Examples: + + >>> input_shape = (2, 3, 4) + >>> x1 = np.random.rand(*input_shape) + >>> x2 = np.random.rand(*input_shape) + >>> y = keras.layers.Maximum()([x1, x2]) + + Usage in a Keras model: + + >>> input1 = keras.layers.Input(shape=(16,)) + >>> x1 = keras.layers.Dense(8, activation='relu')(input1) + >>> input2 = keras.layers.Input(shape=(32,)) + >>> x2 = keras.layers.Dense(8, activation='relu')(input2) + >>> # equivalent to `y = keras.layers.maximum([x1, x2])` + >>> y = keras.layers.Maximum()([x1, x2]) + >>> out = keras.layers.Dense(4)(y) + >>> model = keras.models.Model(inputs=[input1, input2], outputs=out) + + """ + + def _merge_function(self, inputs): + return self._apply_merge_op_and_or_mask(ops.maximum, inputs) + + +@keras_export("keras.layers.maximum") +def maximum(inputs, **kwargs): + """Functional interface to the `keras.layers.Maximum` layer. + + Args: + inputs: A list of input tensors , all of the same shape. + **kwargs: Standard layer keyword arguments. + + Returns: + A tensor as the element-wise product of the inputs with the same + shape as the inputs. + + Examples: + + >>> input_shape = (2, 3, 4) + >>> x1 = np.random.rand(*input_shape) + >>> x2 = np.random.rand(*input_shape) + >>> y = keras.layers.maximum([x1, x2]) + + Usage in a Keras model: + + >>> input1 = keras.layers.Input(shape=(16,)) + >>> x1 = keras.layers.Dense(8, activation='relu')(input1) + >>> input2 = keras.layers.Input(shape=(32,)) + >>> x2 = keras.layers.Dense(8, activation='relu')(input2) + >>> y = keras.layers.maximum([x1, x2]) + >>> out = keras.layers.Dense(4)(y) + >>> model = keras.models.Model(inputs=[input1, input2], outputs=out) + + """ + return Maximum(**kwargs)(inputs) diff --git a/keras/src/layers/merging/merging_test.py b/keras/src/layers/merging/merging_test.py new file mode 100644 index 000000000000..977ad9c2cc1d --- /dev/null +++ b/keras/src/layers/merging/merging_test.py @@ -0,0 +1,441 @@ +import numpy as np +import pytest +from absl.testing import parameterized + +from keras.src import backend +from keras.src import layers +from keras.src import models +from keras.src import ops +from keras.src import testing + + +def np_dot(a, b, axes): + if isinstance(axes, int): + axes = (axes, axes) + axes = [axis if axis < 0 else axis - 1 for axis in axes] + res = np.stack([np.tensordot(a[i], b[i], axes) for i in range(a.shape[0])]) + if len(res.shape) == 1: + res = np.expand_dims(res, axis=1) + return res + + +TEST_PARAMETERS = [ + { + "testcase_name": "add", + "layer_class": layers.Add, + "np_op": np.add, + }, + { + "testcase_name": "subtract", + "layer_class": layers.Subtract, + "np_op": np.subtract, + }, + { + "testcase_name": "minimum", + "layer_class": layers.Minimum, + "np_op": np.minimum, + }, + { + "testcase_name": "maximum", + "layer_class": layers.Maximum, + "np_op": np.maximum, + }, + { + "testcase_name": "multiply", + "layer_class": layers.Multiply, + "np_op": np.multiply, + }, + { + "testcase_name": "average", + "layer_class": layers.Average, + "np_op": lambda a, b: np.multiply(np.add(a, b), 0.5), + }, + { + "testcase_name": "concat", + "layer_class": layers.Concatenate, + "np_op": lambda a, b, **kwargs: np.concatenate((a, b), **kwargs), + "init_kwargs": {"axis": -1}, + "expected_output_shape": (2, 4, 10), + }, + { + "testcase_name": "dot_2d", + "layer_class": layers.Dot, + "np_op": np_dot, + "init_kwargs": {"axes": -1}, + "input_shape": (2, 4), + "expected_output_shape": (2, 1), + "skip_mask_test": True, + }, + { + "testcase_name": "dot_3d", + "layer_class": layers.Dot, + "np_op": np_dot, + "init_kwargs": {"axes": -1}, + "expected_output_shape": (2, 4, 4), + "skip_mask_test": True, + }, +] + + +@pytest.mark.requires_trainable_backend +class MergingLayersTest(testing.TestCase): + @parameterized.named_parameters(TEST_PARAMETERS) + def test_basic( + self, + layer_class, + init_kwargs={}, + input_shape=(2, 4, 5), + expected_output_shape=(2, 4, 5), + **kwargs, + ): + self.run_layer_test( + layer_class, + init_kwargs=init_kwargs, + input_shape=(input_shape, input_shape), + expected_output_shape=expected_output_shape, + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=True, + ) + + @parameterized.named_parameters(TEST_PARAMETERS) + def test_correctness_static( + self, + layer_class, + np_op, + init_kwargs={}, + input_shape=(2, 4, 5), + expected_output_shape=(2, 4, 5), + skip_mask_test=False, + ): + batch_size = input_shape[0] + shape = input_shape[1:] + x1 = np.random.rand(*input_shape) + x2 = np.random.rand(*input_shape) + x3 = np_op(x1, x2, **init_kwargs) + + input_1 = layers.Input(shape=shape, batch_size=batch_size) + input_2 = layers.Input(shape=shape, batch_size=batch_size) + layer = layer_class(**init_kwargs) + out = layer([input_1, input_2]) + model = models.Model([input_1, input_2], out) + res = model([x1, x2]) + + self.assertEqual(res.shape, expected_output_shape) + self.assertAllClose(res, x3, atol=1e-4) + self.assertIsNone(layer.compute_mask([input_1, input_2], [None, None])) + self.assertIsNone(layer.compute_mask([x1, x2], [None, None])) + if not skip_mask_test: + mask1 = np.ones(input_shape[:-1], dtype=np.bool_) + mask2 = np.ones(input_shape[:-1], dtype=np.bool_) + self.assertTrue( + np.all( + backend.convert_to_numpy( + layer.compute_mask([x1, x2], [mask1, mask2]) + ) + ) + ) + + @parameterized.named_parameters(TEST_PARAMETERS) + def test_correctness_dynamic( + self, + layer_class, + np_op, + init_kwargs={}, + input_shape=(2, 4, 5), + expected_output_shape=(2, 4, 5), + skip_mask_test=False, + ): + shape = input_shape[1:] + x1 = np.random.rand(*input_shape) + x2 = np.random.rand(*input_shape) + x3 = np_op(x1, x2, **init_kwargs) + + input_1 = layers.Input(shape=shape) + input_2 = layers.Input(shape=shape) + layer = layer_class(**init_kwargs) + out = layer([input_1, input_2]) + model = models.Model([input_1, input_2], out) + res = model([x1, x2]) + + self.assertEqual(res.shape, expected_output_shape) + self.assertAllClose(res, x3, atol=1e-4) + self.assertIsNone(layer.compute_mask([input_1, input_2], [None, None])) + if not skip_mask_test: + self.assertTrue( + np.all( + backend.convert_to_numpy( + layer.compute_mask( + [input_1, input_2], + [backend.Variable(x1), backend.Variable(x2)], + ) + ) + ) + ) + + @parameterized.named_parameters(TEST_PARAMETERS) + def test_errors( + self, + layer_class, + init_kwargs={}, + input_shape=(2, 4, 5), + skip_mask_test=False, + **kwargs, + ): + if skip_mask_test: + pytest.skip("Masking not supported") + + batch_size = input_shape[0] + shape = input_shape[1:] + x1 = np.random.rand(*input_shape) + x1 = np.random.rand(batch_size, *shape) + + input_1 = layers.Input(shape=shape, batch_size=batch_size) + input_2 = layers.Input(shape=shape, batch_size=batch_size) + layer = layer_class(**init_kwargs) + + with self.assertRaisesRegex(ValueError, "`mask` should be a list."): + layer.compute_mask([input_1, input_2], x1) + + with self.assertRaisesRegex(ValueError, "`inputs` should be a list."): + layer.compute_mask(input_1, [None, None]) + + with self.assertRaisesRegex( + ValueError, " should have the same length." + ): + layer.compute_mask([input_1, input_2], [None]) + + def test_subtract_layer_inputs_length_errors(self): + shape = (4, 5) + input_1 = layers.Input(shape=shape) + input_2 = layers.Input(shape=shape) + input_3 = layers.Input(shape=shape) + + with self.assertRaisesRegex( + ValueError, "layer should be called on exactly 2 inputs" + ): + layers.Subtract()([input_1, input_2, input_3]) + with self.assertRaisesRegex( + ValueError, "layer should be called on exactly 2 inputs" + ): + layers.Subtract()([input_1]) + + def test_dot_higher_dim(self): + a_shape = (1, 3, 2) + b_shape = (1, 1, 2, 3) + # Test symbolic call + a = layers.Input(batch_shape=a_shape) + b = layers.Input(batch_shape=b_shape) + c = layers.Dot(axes=(-2, -1))([a, b]) + self.assertEqual(c.shape, (1, 2, 1, 2)) + a = np.random.random(a_shape) + b = np.random.random(b_shape) + c = layers.Dot(axes=(-2, -1))([a, b]) + self.assertEqual(backend.standardize_shape(c.shape), (1, 2, 1, 2)) + + def test_add_with_mask(self): + mask = layers.Masking() + x1 = mask(backend.convert_to_tensor([[[0, 0], [1, 2], [0, 0], [3, 4]]])) + x2 = backend.convert_to_tensor([[[0, 0], [0, 0], [1, 2], [3, 4]]]) + + output = layers.Add()([x1, x2]) + self.assertAllClose(output, [[[0, 0], [1, 2], [1, 2], [6, 8]]]) + self.assertIsNone(getattr(output, "_keras_mask", None)) + + x2 = mask(x2) + output = layers.Add()([x1, x2]) + self.assertAllClose(output, [[[0, 0], [1, 2], [1, 2], [6, 8]]]) + self.assertAllClose(output._keras_mask, [[0, 1, 1, 1]]) + + def test_subtract_with_mask(self): + mask = layers.Masking() + x1 = mask(backend.convert_to_tensor([[[0, 0], [1, 2], [0, 0], [3, 4]]])) + x2 = backend.convert_to_tensor([[[0, 0], [0, 0], [1, 2], [3, 4]]]) + + output = layers.Subtract()([x1, x2]) + self.assertAllClose(output, [[[0, 0], [1, 2], [-1, -2], [0, 0]]]) + self.assertIsNone(getattr(output, "_keras_mask", None)) + + x2 = mask(x2) + output = layers.Subtract()([x1, x2]) + self.assertAllClose(output, [[[0, 0], [1, 2], [-1, -2], [0, 0]]]) + self.assertAllClose(output._keras_mask, [[0, 1, 1, 1]]) + + def test_average_with_mask(self): + mask = layers.Masking() + x1 = mask(backend.convert_to_tensor([[[0, 0], [1, 2], [0, 0], [3, 4]]])) + x2 = backend.convert_to_tensor([[[0, 0], [0, 0], [1, 2], [3, 4]]]) + + output = layers.Average()([x1, x2]) + self.assertAllClose(output, [[[0, 0], [0.5, 1], [0.5, 1], [3, 4]]]) + self.assertIsNone(getattr(output, "_keras_mask", None)) + + x2 = mask(x2) + output = layers.Average()([x1, x2]) + self.assertAllClose(output, [[[0, 0], [0.5, 1], [0.5, 1], [3, 4]]]) + self.assertAllClose(output._keras_mask, [[0, 1, 1, 1]]) + + def test_multiply_with_mask(self): + mask = layers.Masking() + x1 = mask(backend.convert_to_tensor([[[0, 0], [1, 2], [0, 0], [3, 4]]])) + x2 = backend.convert_to_tensor([[[0, 0], [0, 0], [1, 2], [3, 4]]]) + + output = layers.Multiply()([x1, x2]) + self.assertAllClose(output, [[[0, 0], [0, 0], [1, 2], [9, 16]]]) + self.assertIsNone(getattr(output, "_keras_mask", None)) + + x2 = mask(x2) + output = layers.Multiply()([x1, x2]) + self.assertAllClose(output, [[[0, 0], [1, 2], [1, 2], [9, 16]]]) + self.assertAllClose(output._keras_mask, [[0, 1, 1, 1]]) + + def test_maximum_with_mask(self): + mask = layers.Masking() + x1 = mask( + backend.convert_to_tensor([[[0, 0], [-1, -2], [0, 0], [-3, -4]]]) + ) + x2 = backend.convert_to_tensor([[[0, 0], [0, 0], [-1, -2], [-3, -4]]]) + + output = layers.Maximum()([x1, x2]) + self.assertAllClose(output, [[[0, 0], [0, 0], [-1, -2], [-3, -4]]]) + self.assertIsNone(getattr(output, "_keras_mask", None)) + + x2 = mask(x2) + output = layers.Maximum()([x1, x2]) + self.assertAllClose(output, [[[0, 0], [-1, -2], [-1, -2], [-3, -4]]]) + self.assertAllClose(output._keras_mask, [[0, 1, 1, 1]]) + + def test_minimum_with_mask(self): + mask = layers.Masking() + x1 = mask(backend.convert_to_tensor([[[0, 0], [1, 2], [0, 0], [3, 4]]])) + x2 = backend.convert_to_tensor([[[0, 0], [0, 0], [1, 2], [3, 4]]]) + + output = layers.Minimum()([x1, x2]) + self.assertAllClose(output, [[[0, 0], [0, 0], [1, 2], [3, 4]]]) + self.assertIsNone(getattr(output, "_keras_mask", None)) + + x2 = mask(x2) + output = layers.Minimum()([x1, x2]) + self.assertAllClose(output, [[[0, 0], [1, 2], [1, 2], [3, 4]]]) + self.assertAllClose(output._keras_mask, [[0, 1, 1, 1]]) + + def test_concatenate_with_mask(self): + mask = layers.Masking() + x1 = mask(backend.convert_to_tensor([[[0, 0], [1, 2], [0, 0], [3, 4]]])) + x2 = backend.convert_to_tensor([[[0, 0], [0, 0], [1, 2], [3, 4]]]) + + output = layers.Concatenate(axis=1)([x1, x2]) + self.assertAllClose( + output, + [[[0, 0], [1, 2], [0, 0], [3, 4], [0, 0], [0, 0], [1, 2], [3, 4]]], + ) + self.assertAllClose(output._keras_mask, [[0, 1, 0, 1, 1, 1, 1, 1]]) + + output = layers.Concatenate(axis=2)([x1, x2]) + self.assertAllClose( + output, + [[[0, 0, 0, 0], [1, 2, 0, 0], [0, 0, 1, 2], [3, 4, 3, 4]]], + ) + self.assertAllClose(output._keras_mask, [[1, 1, 1, 1]]) + + def test_concatenate_with_mask_symbolic(self): + input1 = layers.Input((4, 2)) + input2 = layers.Input((4, 2)) + mask = layers.Masking() + output = layers.Concatenate(axis=1)([mask(input1), input2]) + model = models.Model( + inputs=[input1, input2], outputs=output._keras_mask + ) + x1 = backend.convert_to_tensor([[[0, 0], [1, 2], [0, 0], [3, 4]]]) + x2 = backend.convert_to_tensor([[[0, 0], [0, 0], [1, 2], [3, 4]]]) + self.assertAllClose(model([x1, x2]), [[0, 1, 0, 1, 1, 1, 1, 1]]) + + def test_concatenate_errors(self): + # This should work + x1 = np.ones((1, 1, 1, 1, 5)) + x2 = np.ones((1, 1, 1, 1, 4)) + out = layers.Concatenate(axis=-1)([x1, x2]) + self.assertEqual(ops.shape(out), (1, 1, 1, 1, 9)) + + # This won't + x1 = np.ones((1, 2, 1, 1, 5)) + x2 = np.ones((1, 1, 1, 1, 4)) + with self.assertRaisesRegex( + ValueError, + ( + "requires inputs with matching shapes " + "except for the concatenation axis" + ), + ): + out = layers.Concatenate(axis=-1)([x1, x2]) + x1 = np.ones((1, 2, 1, 2, 1)) + x2 = np.ones((1, 1, 1, 3, 1)) + with self.assertRaisesRegex( + ValueError, + ( + "requires inputs with matching shapes " + "except for the concatenation axis" + ), + ): + out = layers.Concatenate(axis=1)([x1, x2]) + + @parameterized.named_parameters(TEST_PARAMETERS) + @pytest.mark.skipif( + not backend.SUPPORTS_SPARSE_TENSORS, + reason="Backend does not support sparse tensors.", + ) + def test_sparse( + self, + layer_class, + np_op, + init_kwargs={}, + input_shape=(2, 4, 5), + expected_output_shape=(2, 4, 5), + **kwargs, + ): + self.run_layer_test( + layer_class, + init_kwargs=init_kwargs, + input_shape=[input_shape, input_shape], + input_sparse=True, + expected_output_shape=expected_output_shape, + expected_output_sparse=True, + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=True, + run_training_check=False, + run_mixed_precision_check=False, + ) + + layer = layer_class(**init_kwargs) + + # Merging a sparse tensor with a dense tensor, or a dense tensor with a + # sparse tensor produces a dense tensor + if backend.backend() == "tensorflow": + import tensorflow as tf + + x1 = tf.SparseTensor([[0, 0], [1, 2]], [1.0, 2.0], (2, 3)) + x3 = tf.SparseTensor([[0, 0], [1, 1]], [4.0, 5.0], (2, 3)) + elif backend.backend() == "jax": + import jax.experimental.sparse as jax_sparse + + # Use n_batch of 1 to be compatible with all ops. + x1 = jax_sparse.BCOO(([[1.0, 2.0]], [[[0], [2]]]), shape=(2, 3)) + x3 = jax_sparse.BCOO(([[4.0, 5.0]], [[[0], [1]]]), shape=(2, 3)) + else: + self.fail(f"Sparse is unsupported with backend {backend.backend()}") + + x1_np = backend.convert_to_numpy(x1) + x2 = np.random.rand(2, 3) + self.assertAllClose(layer([x1, x2]), np_op(x1_np, x2, **init_kwargs)) + self.assertAllClose(layer([x2, x1]), np_op(x2, x1_np, **init_kwargs)) + + # Merging a sparse tensor with a sparse tensor produces a sparse tensor + x3_np = backend.convert_to_numpy(x3) + + self.assertSparse(layer([x1, x3])) + self.assertAllClose(layer([x1, x3]), np_op(x1_np, x3_np, **init_kwargs)) diff --git a/keras/src/layers/merging/minimum.py b/keras/src/layers/merging/minimum.py new file mode 100644 index 000000000000..dad5997ef656 --- /dev/null +++ b/keras/src/layers/merging/minimum.py @@ -0,0 +1,67 @@ +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.layers.merging.base_merge import Merge + + +@keras_export("keras.layers.Minimum") +class Minimum(Merge): + """Computes elementwise minimum on a list of inputs. + + It takes as input a list of tensors, all of the same shape, + and returns a single tensor (also of the same shape). + + Examples: + + >>> input_shape = (2, 3, 4) + >>> x1 = np.random.rand(*input_shape) + >>> x2 = np.random.rand(*input_shape) + >>> y = keras.layers.Minimum()([x1, x2]) + + Usage in a Keras model: + + >>> input1 = keras.layers.Input(shape=(16,)) + >>> x1 = keras.layers.Dense(8, activation='relu')(input1) + >>> input2 = keras.layers.Input(shape=(32,)) + >>> x2 = keras.layers.Dense(8, activation='relu')(input2) + >>> # equivalent to `y = keras.layers.minimum([x1, x2])` + >>> y = keras.layers.Minimum()([x1, x2]) + >>> out = keras.layers.Dense(4)(y) + >>> model = keras.models.Model(inputs=[input1, input2], outputs=out) + + """ + + def _merge_function(self, inputs): + return self._apply_merge_op_and_or_mask(ops.minimum, inputs) + + +@keras_export("keras.layers.minimum") +def minimum(inputs, **kwargs): + """Functional interface to the `keras.layers.Minimum` layer. + + Args: + inputs: A list of input tensors , all of the same shape. + **kwargs: Standard layer keyword arguments. + + Returns: + A tensor as the elementwise product of the inputs with the same + shape as the inputs. + + Examples: + + >>> input_shape = (2, 3, 4) + >>> x1 = np.random.rand(*input_shape) + >>> x2 = np.random.rand(*input_shape) + >>> y = keras.layers.minimum([x1, x2]) + + Usage in a Keras model: + + >>> input1 = keras.layers.Input(shape=(16,)) + >>> x1 = keras.layers.Dense(8, activation='relu')(input1) + >>> input2 = keras.layers.Input(shape=(32,)) + >>> x2 = keras.layers.Dense(8, activation='relu')(input2) + >>> y = keras.layers.minimum([x1, x2]) + >>> out = keras.layers.Dense(4)(y) + >>> model = keras.models.Model(inputs=[input1, input2], outputs=out) + + """ + return Minimum(**kwargs)(inputs) diff --git a/keras/src/layers/merging/multiply.py b/keras/src/layers/merging/multiply.py new file mode 100644 index 000000000000..72fbe1e831dc --- /dev/null +++ b/keras/src/layers/merging/multiply.py @@ -0,0 +1,91 @@ +from keras.src import backend +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.layers.merging.base_merge import Merge + + +@keras_export("keras.layers.Multiply") +class Multiply(Merge): + """Performs elementwise multiplication. + + It takes as input a list of tensors, all of the same shape, + and returns a single tensor (also of the same shape). + + Examples: + + >>> input_shape = (2, 3, 4) + >>> x1 = np.random.rand(*input_shape) + >>> x2 = np.random.rand(*input_shape) + >>> y = keras.layers.Multiply()([x1, x2]) + + Usage in a Keras model: + + >>> input1 = keras.layers.Input(shape=(16,)) + >>> x1 = keras.layers.Dense(8, activation='relu')(input1) + >>> input2 = keras.layers.Input(shape=(32,)) + >>> x2 = keras.layers.Dense(8, activation='relu')(input2) + >>> # equivalent to `y = keras.layers.multiply([x1, x2])` + >>> y = keras.layers.Multiply()([x1, x2]) + >>> out = keras.layers.Dense(4)(y) + >>> model = keras.models.Model(inputs=[input1, input2], outputs=out) + + """ + + def _merge_function(self, inputs): + masks = [backend.get_keras_mask(x) for x in inputs] + has_output_mask = all(mask is not None for mask in masks) + output = None + output_mask = None + + for x, mask in zip(inputs, masks): + if mask is not None: + mask = ops.broadcast_to(ops.expand_dims(mask, -1), ops.shape(x)) + # Replace 0s with 1s outside of mask. + x = ops.where(mask, x, ops.cast(1, x.dtype)) + if has_output_mask: + output_mask = ( + mask + if output_mask is None + else ops.logical_or(output_mask, mask) + ) + output = x if output is None else ops.multiply(output, x) + + if has_output_mask: + # Replace 1s with 0s outside of mask per standard masking rules. + output = ops.where(output_mask, output, ops.cast(0, output.dtype)) + output_mask = ops.any(output_mask, axis=-1, keepdims=False) + backend.set_keras_mask(output, output_mask) + return output + + +@keras_export("keras.layers.multiply") +def multiply(inputs, **kwargs): + """Functional interface to the `keras.layers.Multiply` layer. + + Args: + inputs: A list of input tensors , all of the same shape. + **kwargs: Standard layer keyword arguments. + + Returns: + A tensor as the elementwise product of the inputs with the same + shape as the inputs. + + Examples: + + >>> input_shape = (2, 3, 4) + >>> x1 = np.random.rand(*input_shape) + >>> x2 = np.random.rand(*input_shape) + >>> y = keras.layers.multiply([x1, x2]) + + Usage in a Keras model: + + >>> input1 = keras.layers.Input(shape=(16,)) + >>> x1 = keras.layers.Dense(8, activation='relu')(input1) + >>> input2 = keras.layers.Input(shape=(32,)) + >>> x2 = keras.layers.Dense(8, activation='relu')(input2) + >>> y = keras.layers.multiply([x1, x2]) + >>> out = keras.layers.Dense(4)(y) + >>> model = keras.models.Model(inputs=[input1, input2], outputs=out) + + """ + return Multiply(**kwargs)(inputs) diff --git a/keras/src/layers/merging/subtract.py b/keras/src/layers/merging/subtract.py new file mode 100644 index 000000000000..78036adaf233 --- /dev/null +++ b/keras/src/layers/merging/subtract.py @@ -0,0 +1,82 @@ +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.layers.merging.base_merge import Merge + + +@keras_export("keras.layers.Subtract") +class Subtract(Merge): + """Performs elementwise subtraction. + + It takes as input a list of tensors of size 2 both of the + same shape, and returns a single tensor (inputs[0] - inputs[1]) + of same shape. + + Examples: + + >>> input_shape = (2, 3, 4) + >>> x1 = np.random.rand(*input_shape) + >>> x2 = np.random.rand(*input_shape) + >>> y = keras.layers.Subtract()([x1, x2]) + + Usage in a Keras model: + + >>> input1 = keras.layers.Input(shape=(16,)) + >>> x1 = keras.layers.Dense(8, activation='relu')(input1) + >>> input2 = keras.layers.Input(shape=(32,)) + >>> x2 = keras.layers.Dense(8, activation='relu')(input2) + >>> # equivalent to `subtracted = keras.layers.subtract([x1, x2])` + >>> subtracted = keras.layers.Subtract()([x1, x2]) + >>> out = keras.layers.Dense(4)(subtracted) + >>> model = keras.models.Model(inputs=[input1, input2], outputs=out) + + """ + + def build(self, input_shape): + super().build(input_shape) + if len(input_shape) != 2: + raise ValueError( + "A `Subtract` layer should be called on exactly 2 inputs. " + f"Received: input_shape={input_shape}" + ) + + def _merge_function(self, inputs): + if len(inputs) != 2: + raise ValueError( + "A `Subtract` layer should be called on exactly 2 inputs. " + f"Received: inputs={inputs}" + ) + return ops.subtract(inputs[0], inputs[1]) + + +@keras_export("keras.layers.subtract") +def subtract(inputs, **kwargs): + """Functional interface to the `keras.layers.Subtract` layer. + + Args: + inputs: A list of input tensors of size 2, each tensor of + the same shape. + **kwargs: Standard layer keyword arguments. + + Returns: + A tensor as the difference of the inputs. It has the same shape + as the inputs. + + Examples: + + >>> input_shape = (2, 3, 4) + >>> x1 = np.random.rand(*input_shape) + >>> x2 = np.random.rand(*input_shape) + >>> y = keras.layers.subtract([x1, x2]) + + Usage in a Keras model: + + >>> input1 = keras.layers.Input(shape=(16,)) + >>> x1 = keras.layers.Dense(8, activation='relu')(input1) + >>> input2 = keras.layers.Input(shape=(32,)) + >>> x2 = keras.layers.Dense(8, activation='relu')(input2) + >>> subtracted = keras.layers.subtract([x1, x2]) + >>> out = keras.layers.Dense(4)(subtracted) + >>> model = keras.models.Model(inputs=[input1, input2], outputs=out) + + """ + return Subtract(**kwargs)(inputs) diff --git a/keras/src/layers/normalization/__init__.py b/keras/src/layers/normalization/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/keras/src/layers/normalization/batch_normalization.py b/keras/src/layers/normalization/batch_normalization.py new file mode 100644 index 000000000000..c7b5e492ca1e --- /dev/null +++ b/keras/src/layers/normalization/batch_normalization.py @@ -0,0 +1,348 @@ +from keras.src import backend +from keras.src import constraints +from keras.src import initializers +from keras.src import ops +from keras.src import regularizers +from keras.src.api_export import keras_export +from keras.src.layers.input_spec import InputSpec +from keras.src.layers.layer import Layer + + +@keras_export("keras.layers.BatchNormalization") +class BatchNormalization(Layer): + """Layer that normalizes its inputs. + + Batch normalization applies a transformation that maintains the mean output + close to 0 and the output standard deviation close to 1. + + Importantly, batch normalization works differently during training and + during inference. + + **During training** (i.e. when using `fit()` or when calling the layer/model + with the argument `training=True`), the layer normalizes its output using + the mean and standard deviation of the current batch of inputs. That is to + say, for each channel being normalized, the layer returns + `gamma * (batch - mean(batch)) / sqrt(var(batch) + epsilon) + beta`, where: + + - `epsilon` is small constant (configurable as part of the constructor + arguments) + - `gamma` is a learned scaling factor (initialized as 1), which + can be disabled by passing `scale=False` to the constructor. + - `beta` is a learned offset factor (initialized as 0), which + can be disabled by passing `center=False` to the constructor. + + **During inference** (i.e. when using `evaluate()` or `predict()` or when + calling the layer/model with the argument `training=False` (which is the + default), the layer normalizes its output using a moving average of the + mean and standard deviation of the batches it has seen during training. That + is to say, it returns + `gamma * (batch - self.moving_mean) / sqrt(self.moving_var+epsilon) + beta`. + + `self.moving_mean` and `self.moving_var` are non-trainable variables that + are updated each time the layer in called in training mode, as such: + + - `moving_mean = moving_mean * momentum + mean(batch) * (1 - momentum)` + - `moving_var = moving_var * momentum + var(batch) * (1 - momentum)` + + As such, the layer will only normalize its inputs during inference + *after having been trained on data that has similar statistics as the + inference data*. + + Args: + axis: Integer, the axis that should be normalized + (typically the features axis). For instance, after a `Conv2D` layer + with `data_format="channels_first"`, use `axis=1`. + momentum: Momentum for the moving average. + epsilon: Small float added to variance to avoid dividing by zero. + center: If `True`, add offset of `beta` to normalized tensor. + If `False`, `beta` is ignored. + scale: If `True`, multiply by `gamma`. If `False`, `gamma` is not used. + When the next layer is linear this can be disabled + since the scaling will be done by the next layer. + beta_initializer: Initializer for the beta weight. + gamma_initializer: Initializer for the gamma weight. + moving_mean_initializer: Initializer for the moving mean. + moving_variance_initializer: Initializer for the moving variance. + beta_regularizer: Optional regularizer for the beta weight. + gamma_regularizer: Optional regularizer for the gamma weight. + beta_constraint: Optional constraint for the beta weight. + gamma_constraint: Optional constraint for the gamma weight. + synchronized: Only applicable with the TensorFlow backend. + If `True`, synchronizes the global batch statistics (mean and + variance) for the layer across all devices at each training step + in a distributed training strategy. + If `False`, each replica uses its own local batch statistics. + **kwargs: Base layer keyword arguments (e.g. `name` and `dtype`). + + Call arguments: + inputs: Input tensor (of any rank). + training: Python boolean indicating whether the layer should behave in + training mode or in inference mode. + - `training=True`: The layer will normalize its inputs using + the mean and variance of the current batch of inputs. + - `training=False`: The layer will normalize its inputs using + the mean and variance of its moving statistics, learned during + training. + mask: Binary tensor of shape broadcastable to `inputs` tensor, with + `True` values indicating the positions for which mean and variance + should be computed. Masked elements of the current inputs are not + taken into account for mean and variance computation during + training. Any prior unmasked element values will be taken into + account until their momentum expires. + + Reference: + + - [Ioffe and Szegedy, 2015](https://arxiv.org/abs/1502.03167). + + **About setting `layer.trainable = False` on a `BatchNormalization` layer:** + + The meaning of setting `layer.trainable = False` is to freeze the layer, + i.e. its internal state will not change during training: + its trainable weights will not be updated + during `fit()` or `train_on_batch()`, and its state updates will not be run. + + Usually, this does not necessarily mean that the layer is run in inference + mode (which is normally controlled by the `training` argument that can + be passed when calling a layer). "Frozen state" and "inference mode" + are two separate concepts. + + However, in the case of the `BatchNormalization` layer, **setting + `trainable = False` on the layer means that the layer will be + subsequently run in inference mode** (meaning that it will use + the moving mean and the moving variance to normalize the current batch, + rather than using the mean and variance of the current batch). + + Note that: + + - Setting `trainable` on an model containing other layers will recursively + set the `trainable` value of all inner layers. + - If the value of the `trainable` attribute is changed after calling + `compile()` on a model, the new value doesn't take effect for this model + until `compile()` is called again. + """ + + def __init__( + self, + axis=-1, + momentum=0.99, + epsilon=1e-3, + center=True, + scale=True, + beta_initializer="zeros", + gamma_initializer="ones", + moving_mean_initializer="zeros", + moving_variance_initializer="ones", + beta_regularizer=None, + gamma_regularizer=None, + beta_constraint=None, + gamma_constraint=None, + synchronized=False, + **kwargs, + ): + super().__init__(**kwargs) + self.axis = int(axis) + + if synchronized and backend.backend() != "tensorflow": + raise ValueError( + "Argument synchronized=True is only supported " + "with the TensorFlow backend." + ) + self.synchronized = synchronized + + self.momentum = float(momentum) + self.epsilon = float(epsilon) + self.center = center + self.scale = scale + self.beta_initializer = initializers.get(beta_initializer) + self.gamma_initializer = initializers.get(gamma_initializer) + self.moving_mean_initializer = initializers.get(moving_mean_initializer) + self.moving_variance_initializer = initializers.get( + moving_variance_initializer + ) + self.beta_regularizer = regularizers.get(beta_regularizer) + self.gamma_regularizer = regularizers.get(gamma_regularizer) + self.beta_constraint = constraints.get(beta_constraint) + self.gamma_constraint = constraints.get(gamma_constraint) + self.supports_masking = True + + self.gamma = None + self.beta = None + self.moving_mean = None + self.moving_variance = None + self._reduction_axes = None + + def build(self, input_shape): + shape = (input_shape[self.axis],) + if self.scale: + self.gamma = self.add_weight( + shape=shape, + name="gamma", + initializer=self.gamma_initializer, + regularizer=self.gamma_regularizer, + constraint=self.gamma_constraint, + trainable=True, + autocast=False, + ) + if self.center: + self.beta = self.add_weight( + shape=shape, + name="beta", + initializer=self.beta_initializer, + regularizer=self.beta_regularizer, + constraint=self.beta_constraint, + trainable=True, + autocast=False, + ) + self.moving_mean = self.add_weight( + shape=shape, + name="moving_mean", + initializer=self.moving_mean_initializer, + trainable=False, + autocast=False, + ) + self.moving_variance = self.add_weight( + shape=shape, + name="moving_variance", + initializer=self.moving_variance_initializer, + trainable=False, + autocast=False, + ) + + self.input_spec = InputSpec( + ndim=len(input_shape), axes={self.axis: input_shape[self.axis]} + ) + + reduction_axes = list(range(len(input_shape))) + del reduction_axes[self.axis] + self._reduction_axes = reduction_axes + + def compute_output_shape(self, input_shape): + if isinstance(self.axis, int): + axes = [self.axis] + else: + axes = self.axis + + for axis in axes: + if axis >= len(input_shape) or axis < -len(input_shape): + raise ValueError( + f"Axis {axis} is out of bounds for " + f"input shape {input_shape}. " + f"Received: axis={self.axis}" + ) + return input_shape + + def call(self, inputs, training=None, mask=None): + # Check if the mask has one less dimension than the inputs. + if mask is not None: + if len(mask.shape) != len(inputs.shape) - 1: + # Raise a value error + raise ValueError( + "The mask provided should be one dimension less " + "than the inputs. Received: " + f"mask.shape={mask.shape}, inputs.shape={inputs.shape}" + ) + + compute_dtype = backend.result_type(inputs.dtype, "float32") + # BN is prone to overflow with float16/bfloat16 inputs, so we upcast to + # float32 for the subsequent computations. + inputs = ops.cast(inputs, compute_dtype) + + moving_mean = ops.cast(self.moving_mean, inputs.dtype) + moving_variance = ops.cast(self.moving_variance, inputs.dtype) + + if training and self.trainable: + mean, variance = self._moments(inputs, mask) + + self.moving_mean.assign( + moving_mean * self.momentum + mean * (1.0 - self.momentum) + ) + self.moving_variance.assign( + moving_variance * self.momentum + + variance * (1.0 - self.momentum) + ) + else: + mean = moving_mean + variance = moving_variance + + if self.scale: + gamma = ops.cast(self.gamma, inputs.dtype) + else: + gamma = None + + if self.center: + beta = ops.cast(self.beta, inputs.dtype) + else: + beta = None + + outputs = ops.batch_normalization( + x=inputs, + mean=mean, + variance=variance, + axis=self.axis, + offset=beta, + scale=gamma, + epsilon=self.epsilon, + ) + return ops.cast(outputs, self.compute_dtype) + + def get_config(self): + base_config = super().get_config() + config = { + "axis": self.axis, + "momentum": self.momentum, + "epsilon": self.epsilon, + "center": self.center, + "scale": self.scale, + "beta_initializer": initializers.serialize(self.beta_initializer), + "gamma_initializer": initializers.serialize(self.gamma_initializer), + "moving_mean_initializer": initializers.serialize( + self.moving_mean_initializer + ), + "moving_variance_initializer": initializers.serialize( + self.moving_variance_initializer + ), + "beta_regularizer": regularizers.serialize(self.beta_regularizer), + "gamma_regularizer": regularizers.serialize(self.gamma_regularizer), + "beta_constraint": constraints.serialize(self.beta_constraint), + "gamma_constraint": constraints.serialize(self.gamma_constraint), + "synchronized": self.synchronized, + } + return {**base_config, **config} + + def _moments(self, inputs, mask): + if mask is None: + return ops.moments( + inputs, + axes=self._reduction_axes, + synchronized=self.synchronized, + ) + + mask_weights = ops.cast(mask, inputs.dtype) + mask_weights_broadcasted = ops.expand_dims(mask_weights, axis=-1) + broadcasted_mask = ops.broadcast_to( + mask_weights_broadcasted, ops.shape(inputs) + ) + weighted_inputs = broadcasted_mask * inputs + + weighted_input_sum = ops.sum( + weighted_inputs, + self._reduction_axes, + keepdims=True, + ) + sum_of_weights = ops.sum( + broadcasted_mask, + self._reduction_axes, + keepdims=True, + ) + mean = weighted_input_sum / (sum_of_weights + backend.epsilon()) + + difference = weighted_inputs - mean + squared_difference = ops.square(difference) + weighted_distsq = ops.sum( + broadcasted_mask * squared_difference, + self._reduction_axes, + keepdims=True, + ) + variance = weighted_distsq / (sum_of_weights + backend.epsilon()) + + return ops.squeeze(mean), ops.squeeze(variance) diff --git a/keras/src/layers/normalization/batch_normalization_test.py b/keras/src/layers/normalization/batch_normalization_test.py new file mode 100644 index 000000000000..d713670aae5c --- /dev/null +++ b/keras/src/layers/normalization/batch_normalization_test.py @@ -0,0 +1,241 @@ +import numpy as np +import pytest +from absl.testing import parameterized + +from keras.src import backend +from keras.src import layers +from keras.src import ops +from keras.src import testing +from keras.src.losses import MeanSquaredError +from keras.src.models import Model + + +class BatchNormalizationTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_bn_basics(self): + # vector case + self.run_layer_test( + layers.BatchNormalization, + init_kwargs={ + "center": True, + "scale": True, + }, + call_kwargs={"training": True}, + input_shape=(2, 3), + expected_output_shape=(2, 3), + expected_num_trainable_weights=2, + expected_num_non_trainable_weights=2, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=True, + ) + self.run_layer_test( + layers.BatchNormalization, + init_kwargs={ + "center": False, + "scale": False, + }, + call_kwargs={"training": True}, + input_shape=(2, 3), + expected_output_shape=(2, 3), + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=2, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=True, + ) + # image case, with regularizers + self.run_layer_test( + layers.BatchNormalization, + init_kwargs={ + "center": True, + "scale": True, + "beta_regularizer": "l2", + "gamma_regularizer": "l2", + }, + call_kwargs={"training": True}, + input_shape=(2, 4, 4, 3), + expected_output_shape=(2, 4, 4, 3), + expected_num_trainable_weights=2, + expected_num_non_trainable_weights=2, + expected_num_seed_generators=0, + expected_num_losses=2, # we have 2 regularizers. + supports_masking=True, + ) + + @parameterized.product( + axis=(-1, 1), + input_shape=((5, 2, 3), (5, 3, 3, 2)), + moving_mean_initializer=("zeros", "ones"), + moving_variance_initializer=("zeros", "ones"), + ) + def test_correctness( + self, + axis, + input_shape, + moving_mean_initializer, + moving_variance_initializer, + ): + # Training + layer = layers.BatchNormalization( + axis=axis, + momentum=0, + moving_mean_initializer=moving_mean_initializer, + moving_variance_initializer=moving_variance_initializer, + ) + # Random data centered on 5.0, variance 10.0 + x = np.random.normal(loc=5.0, scale=10.0, size=input_shape) + out = x + for _ in range(3): + out = layer(out, training=True) + + # Assert the normalization is correct. + broadcast_shape = [1] * len(input_shape) + broadcast_shape[axis] = input_shape[axis] + out = backend.convert_to_numpy(out) + out = out - np.reshape( + backend.convert_to_numpy(layer.beta), broadcast_shape + ) + out = out / np.reshape( + backend.convert_to_numpy(layer.gamma), broadcast_shape + ) + + reduction_axes = list(range(len(input_shape))) + del reduction_axes[axis] + reduction_axes = tuple(reduction_axes) + self.assertAllClose(np.mean(out, axis=reduction_axes), 0.0, atol=1e-3) + self.assertAllClose(np.std(out, axis=reduction_axes), 1.0, atol=1e-3) + self.assertAllClose(layer.moving_mean, 0.0, atol=1e-3) + self.assertAllClose(layer.moving_variance, 1.0, atol=1e-3) + + # Inference done before training shouldn't match. + inference_out = layer(x, training=False) + training_out = layer(x, training=True) + self.assertNotAllClose(inference_out, training_out) + + # Since momentum is zero, inference after training should match. + training_out = layer(x, training=True) + inference_out = layer(x, training=False) + self.assertAllClose(inference_out, training_out) + + # Masked result with no training should not differ + x[:, 1, :] = 0.0 + unmasked_out = layer(x, training=False) + masked = layers.Masking()(x) + masked_out = layer(masked, training=False) + self.assertAllClose(unmasked_out, masked_out) + + # Masked result should differ from unmasked result + unmasked_out = layer(x, training=False) + x[:, 1, :] = 0.0 + masked = layers.Masking()(x) + masked_out = layer(masked, training=True) + self.assertNotAllClose(unmasked_out, masked_out) + + @parameterized.product( + synchronized=( + (False, True) if backend.backend == "tensorflow" else (False,) + ), + ) + def test_input_fully_masked(self, synchronized): + norm = layers.BatchNormalization( + scale=False, + center=False, + synchronized=synchronized, + ) + x = np.zeros((4, 5)) + mask = np.zeros((4,), dtype=np.float32) + y = norm(x, mask=mask, training=True) + self.assertAllClose(y, np.zeros_like(x, dtype=np.float32)) + + @parameterized.product(run_eagerly=(True, False), mask_value=(0.0, 0.1, 1)) + @pytest.mark.requires_trainable_backend + def test_bachnorm_ignore_masked_values(self, run_eagerly, mask_value): + padded_data = np.array( + [ + [ + [1, 5], + [2, 5], + [mask_value, mask_value], + [mask_value, mask_value], + ] + for _ in range(10) + ], + dtype="float32", + ) + + inputs = layers.Input((None, 2)) + masked = layers.Masking(mask_value=mask_value)(inputs) + normed = layers.BatchNormalization(momentum=0.0)(masked) + model = Model(inputs, normed) + loss = MeanSquaredError() + model.compile( + "rmsprop", + loss=loss, + run_eagerly=run_eagerly, + ) + model.fit(x=padded_data, y=padded_data, batch_size=10, epochs=5) + self.assertAllClose(model.layers[2].moving_mean.numpy(), [1.5, 5.0]) + self.assertAllClose( + model.layers[2].moving_variance.numpy(), [0.25, 0.0] + ) + + def test_trainable_behavior(self): + layer = layers.BatchNormalization(axis=-1, momentum=0.8, epsilon=1e-7) + layer.build((1, 4, 4, 3)) + layer.trainable = False + self.assertEqual(len(layer.weights), 4) + self.assertEqual(len(layer.trainable_weights), 0) + self.assertEqual(len(layer.non_trainable_weights), 4) + + # Random data centered on 5.0, variance 10.0 + x = np.random.normal(loc=5.0, scale=10.0, size=(200, 4, 4, 3)) + + out = layer(x, training=True) + self.assertAllClose(out, x) + + layer.trainable = True + self.assertEqual(len(layer.weights), 4) + self.assertEqual(len(layer.trainable_weights), 2) + self.assertEqual(len(layer.non_trainable_weights), 2) + + for _ in range(10): + out = layer(x, training=True) + + out = backend.convert_to_numpy(out) + out = out - np.reshape( + backend.convert_to_numpy(layer.beta), (1, 1, 1, 3) + ) + out = out / np.reshape( + backend.convert_to_numpy(layer.gamma), (1, 1, 1, 3) + ) + + self.assertAllClose(np.mean(out, axis=(0, 1, 2)), 0.0, atol=1e-3) + self.assertAllClose(np.std(out, axis=(0, 1, 2)), 1.0, atol=1e-3) + + def test_large_value_within_autocast_scope(self): + layer = layers.BatchNormalization() + layer.build((1, 4, 4, 3)) + # Use 70000 to trigger overflow for float16 + large_value = ops.full(layer.moving_variance.shape, 70000) + with backend.AutocastScope("float16"): + layer.moving_variance.assign(large_value) + self.assertAllClose(layer.moving_variance.value, large_value) + + def test_masked_broadcast_normalization(self): + input_shape = (1, 2, 3, 4) + mask_shape = (1, 2, 1) + x = ops.ones(input_shape) + mask = ops.ones(mask_shape) + + layer = layers.BatchNormalization(axis=-1, momentum=0.0, epsilon=1e-3) + + y = layer(x, training=True, mask=mask) + + mean_y = ops.mean(y, axis=[0, 1, 2]) + + self.assertAllClose(mean_y, ops.zeros((4,)), atol=1e-6) + self.assertAllClose(y, ops.zeros_like(y), atol=1e-6) + + self.assertAllClose(layer.moving_mean, ops.ones((4,)), atol=1e-6) + self.assertAllClose(layer.moving_variance, ops.zeros((4,)), atol=1e-6) diff --git a/keras/src/layers/normalization/group_normalization.py b/keras/src/layers/normalization/group_normalization.py new file mode 100644 index 000000000000..9d91d1f9944e --- /dev/null +++ b/keras/src/layers/normalization/group_normalization.py @@ -0,0 +1,240 @@ +from keras.src import backend +from keras.src import constraints +from keras.src import initializers +from keras.src import ops +from keras.src import regularizers +from keras.src.api_export import keras_export +from keras.src.layers.input_spec import InputSpec +from keras.src.layers.layer import Layer + + +@keras_export("keras.layers.GroupNormalization") +class GroupNormalization(Layer): + """Group normalization layer. + + Group Normalization divides the channels into groups and computes + within each group the mean and variance for normalization. + Empirically, its accuracy is more stable than batch norm in a wide + range of small batch sizes, if learning rate is adjusted linearly + with batch sizes. + + Relation to Layer Normalization: + If the number of groups is set to 1, then this operation becomes nearly + identical to Layer Normalization (see Layer Normalization docs for details). + + Relation to Instance Normalization: + If the number of groups is set to the input dimension (number of groups is + equal to number of channels), then this operation becomes identical to + Instance Normalization. You can achieve this via `groups=-1`. + + Args: + groups: Integer, the number of groups for Group Normalization. Can be in + the range `[1, N]` where N is the input dimension. The input + dimension must be divisible by the number of groups. + Defaults to 32. + axis: Integer or List/Tuple. The axis or axes to normalize across. + Typically, this is the features axis/axes. The left-out axes are + typically the batch axis/axes. -1 is the last dimension in the + input. Defaults to `-1`. + epsilon: Small float added to variance to avoid dividing by zero. + Defaults to 1e-3. + center: If `True`, add offset of `beta` to normalized tensor. + If `False`, `beta` is ignored. Defaults to `True`. + scale: If `True`, multiply by `gamma`. If `False`, `gamma` is not used. + When the next layer is linear (also e.g. `relu`), this can be + disabled since the scaling will be done by the next layer. + Defaults to `True`. + beta_initializer: Initializer for the beta weight. Defaults to zeros. + gamma_initializer: Initializer for the gamma weight. Defaults to ones. + beta_regularizer: Optional regularizer for the beta weight. None by + default. + gamma_regularizer: Optional regularizer for the gamma weight. None by + default. + beta_constraint: Optional constraint for the beta weight. + None by default. + gamma_constraint: Optional constraint for the gamma weight. None by + default. Input shape: Arbitrary. Use the keyword argument + `input_shape` (tuple of integers, does not include the samples + axis) when using this layer as the first layer in a model. + Output shape: Same shape as input. + **kwargs: Base layer keyword arguments (e.g. `name` and `dtype`). + + Reference: + + - [Yuxin Wu & Kaiming He, 2018](https://arxiv.org/abs/1803.08494) + """ + + def __init__( + self, + groups=32, + axis=-1, + epsilon=1e-3, + center=True, + scale=True, + beta_initializer="zeros", + gamma_initializer="ones", + beta_regularizer=None, + gamma_regularizer=None, + beta_constraint=None, + gamma_constraint=None, + **kwargs, + ): + super().__init__(**kwargs) + self.supports_masking = True + self.groups = groups + self.axis = axis + self.epsilon = epsilon + self.center = center + self.scale = scale + self.beta_initializer = initializers.get(beta_initializer) + self.gamma_initializer = initializers.get(gamma_initializer) + self.beta_regularizer = regularizers.get(beta_regularizer) + self.gamma_regularizer = regularizers.get(gamma_regularizer) + self.beta_constraint = constraints.get(beta_constraint) + self.gamma_constraint = constraints.get(gamma_constraint) + + def build(self, input_shape): + dim = input_shape[self.axis] + + if dim is None: + raise ValueError( + f"Axis {self.axis} of input tensor should have a defined " + "dimension but the layer received an input with shape " + f"{input_shape}." + ) + + if self.groups == -1: + self.groups = dim + + if dim < self.groups: + raise ValueError( + f"Number of groups ({self.groups}) cannot be more than the " + f"number of channels ({dim})." + ) + + if dim % self.groups != 0: + raise ValueError( + f"Number of groups ({self.groups}) must be a multiple " + f"of the number of channels ({dim})." + ) + + self.input_spec = InputSpec( + ndim=len(input_shape), axes={self.axis: dim} + ) + + if self.scale: + self.gamma = self.add_weight( + shape=(dim,), + name="gamma", + initializer=self.gamma_initializer, + regularizer=self.gamma_regularizer, + constraint=self.gamma_constraint, + ) + else: + self.gamma = None + + if self.center: + self.beta = self.add_weight( + shape=(dim,), + name="beta", + initializer=self.beta_initializer, + regularizer=self.beta_regularizer, + constraint=self.beta_constraint, + ) + else: + self.beta = None + + super().build(input_shape) + + def call(self, inputs): + reshaped_inputs = self._reshape_into_groups(inputs) + normalized_inputs = self._apply_normalization( + reshaped_inputs, inputs.shape + ) + return ops.reshape(normalized_inputs, ops.shape(inputs)) + + def _reshape_into_groups(self, inputs): + input_shape = ops.shape(inputs) + group_shape = list(inputs.shape) + group_shape[0] = -1 + for i, e in enumerate(group_shape[1:]): + if e is None: + group_shape[i + 1] = input_shape[i + 1] + + group_shape[self.axis] = input_shape[self.axis] // self.groups + group_shape.insert(self.axis, self.groups) + reshaped_inputs = ops.reshape(inputs, group_shape) + return reshaped_inputs + + def _apply_normalization(self, reshaped_inputs, input_shape): + inputs_dtype = reshaped_inputs.dtype + compute_dtype = backend.result_type(inputs_dtype, "float32") + # GN is prone to overflow with float16/bfloat16 inputs, so we upcast to + # float32 for the subsequent computations. + reshaped_inputs = ops.cast(reshaped_inputs, compute_dtype) + + group_reduction_axes = list(range(1, len(reshaped_inputs.shape))) + + axis = -2 if self.axis == -1 else self.axis - 1 + group_reduction_axes.pop(axis) + + broadcast_shape = self._create_broadcast_shape(input_shape) + mean, variance = ops.moments( + reshaped_inputs, axes=group_reduction_axes, keepdims=True + ) + + # Compute the batch normalization. + inv = ops.rsqrt(variance + self.epsilon) + if self.scale: + gamma = ops.reshape(self.gamma, broadcast_shape) + gamma = ops.cast(gamma, reshaped_inputs.dtype) + inv = inv * gamma + + res = -mean * inv + if self.center: + beta = ops.reshape(self.beta, broadcast_shape) + beta = ops.cast(beta, reshaped_inputs.dtype) + res = res + beta + + normalized_inputs = reshaped_inputs * inv + res + normalized_inputs = ops.cast(normalized_inputs, inputs_dtype) + + return normalized_inputs + + def _create_broadcast_shape(self, input_shape): + broadcast_shape = [1] * len(input_shape) + broadcast_shape[self.axis] = input_shape[self.axis] // self.groups + broadcast_shape.insert(self.axis, self.groups) + return broadcast_shape + + def compute_output_shape(self, input_shape): + if isinstance(self.axis, int): + axes = [self.axis] + else: + axes = self.axis + + for axis in axes: + if axis >= len(input_shape) or axis < -len(input_shape): + raise ValueError( + f"Axis {axis} is out of bounds for " + f"input shape {input_shape}. " + f"Received: axis={self.axis}" + ) + return input_shape + + def get_config(self): + config = { + "groups": self.groups, + "axis": self.axis, + "epsilon": self.epsilon, + "center": self.center, + "scale": self.scale, + "beta_initializer": initializers.serialize(self.beta_initializer), + "gamma_initializer": initializers.serialize(self.gamma_initializer), + "beta_regularizer": regularizers.serialize(self.beta_regularizer), + "gamma_regularizer": regularizers.serialize(self.gamma_regularizer), + "beta_constraint": constraints.serialize(self.beta_constraint), + "gamma_constraint": constraints.serialize(self.gamma_constraint), + } + base_config = super().get_config() + return {**base_config, **config} diff --git a/keras/src/layers/normalization/group_normalization_test.py b/keras/src/layers/normalization/group_normalization_test.py new file mode 100644 index 000000000000..76e4eae280a8 --- /dev/null +++ b/keras/src/layers/normalization/group_normalization_test.py @@ -0,0 +1,179 @@ +import numpy as np +import pytest + +from keras.src import constraints +from keras.src import layers +from keras.src import regularizers +from keras.src import testing + + +class GroupNormalizationTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_groupnorm(self): + self.run_layer_test( + layers.GroupNormalization, + init_kwargs={ + "gamma_regularizer": regularizers.L2(0.01), + "beta_regularizer": regularizers.L2(0.01), + }, + input_shape=(3, 4, 32), + expected_output_shape=(3, 4, 32), + expected_num_trainable_weights=2, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=2, + supports_masking=True, + ) + + self.run_layer_test( + layers.GroupNormalization, + init_kwargs={ + "groups": 4, + "gamma_constraint": constraints.UnitNorm(), + "beta_constraint": constraints.UnitNorm(), + }, + input_shape=(3, 4, 4), + expected_output_shape=(3, 4, 4), + expected_num_trainable_weights=2, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=True, + ) + + def test_undefined_dim_error(self): + inputs = layers.Input(shape=(2, 2, 2, None)) + layer = layers.GroupNormalization() + with self.assertRaisesRegex( + ValueError, + ( + "input tensor should have a defined dimension but the layer " + "received an input with shape" + ), + ): + _ = layer(inputs) + + def test_groups_bigger_than_dim_error(self): + inputs = np.ones(shape=(2, 2, 2, 4)) + layer = layers.GroupNormalization(groups=5) + with self.assertRaisesRegex( + ValueError, + "cannot be more than the number of channels", + ): + _ = layer(inputs) + + def test_groups_not_a_multiple_of_dim_error(self): + inputs = np.ones(shape=(2, 2, 2, 4)) + layer = layers.GroupNormalization(groups=3) + with self.assertRaisesRegex( + ValueError, + "must be a multiple of the number of channels", + ): + _ = layer(inputs) + + def test_groups_instance_norm(self): + # GroupNormalization with groups=-1 will become InstanceNormalization + instance_norm_layer_1 = layers.GroupNormalization( + groups=-1, axis=-1, scale=False, center=False + ) + instance_norm_layer_2 = layers.GroupNormalization( + groups=4, axis=-1, scale=False, center=False + ) + inputs = np.array([[[-1.0, 1.0, 0, 2.0], [1.0, 3.0, -4, -2.0]]]) + + outputs_1 = instance_norm_layer_1(inputs) + outputs_2 = instance_norm_layer_2(inputs) + + self.assertAllClose(outputs_1, outputs_2) + + def test_correctness_instance_norm(self): + instance_norm_layer = layers.GroupNormalization( + groups=4, axis=-1, scale=False, center=False + ) + + inputs = np.array([[[-1.0, 1.0, 0, 2.0], [1.0, 3.0, -4, -2.0]]]) + + expected_instance_norm_output = np.array( + [[[-1.0, -1.0, 1.0, 1.0], [1.0, 1.0, -1.0, -1.0]]] + ) + + self.assertAllClose( + instance_norm_layer(inputs), + expected_instance_norm_output, + atol=1e-3, + ) + + def test_correctness_1d(self): + layer_with_1_group = layers.GroupNormalization( + groups=1, axis=-1, scale=False, center=False + ) + layer_with_2_groups = layers.GroupNormalization( + groups=2, axis=1, scale=False, center=False + ) + + inputs = np.array([[-1.0, -1.0, 1.0, 1.0, 2.0, 2.0, 0, -2.0]]) + + expected_output_1_group = np.array( + [[-0.898, -0.898, 0.539, 0.539, 1.257, 1.257, -0.180, -1.616]], + ) + self.assertAllClose( + layer_with_1_group(inputs), + expected_output_1_group, + atol=1e-3, + ) + + expected_output_2_groups = np.array( + [[-1.0, -1.0, 1.0, 1.0, 0.904, 0.904, -0.301, -1.507]] + ) + self.assertAllClose( + layer_with_2_groups(inputs), + expected_output_2_groups, + atol=1e-3, + ) + + def test_correctness_2d(self): + layer_with_1_group = layers.GroupNormalization( + groups=1, axis=-1, scale=False, center=False + ) + layer_with_2_groups = layers.GroupNormalization( + groups=2, axis=2, scale=False, center=False + ) + + inputs = np.array([[[-1.0, -1.0, 2.0, 2.0], [1.0, 1.0, 0, -2.0]]]) + + expected_output_1_group = np.array( + [[[-0.898, -0.898, 1.257, 1.257], [0.539, 0.539, -0.180, -1.616]]] + ) + + self.assertAllClose( + layer_with_1_group(inputs), + expected_output_1_group, + atol=1e-3, + ) + + expected_output_2_groups = np.array( + [[[-1.0, -1.0, 0.904, 0.904], [1.0, 1.0, -0.301, -1.507]]] + ) + self.assertAllClose( + layer_with_2_groups(inputs), + expected_output_2_groups, + atol=1e-3, + ) + + def test_broadcasting_2d_channels_first(self): + x = np.arange(16).reshape((1, 4, 2, 2)).astype("float32") + x = layers.GroupNormalization(groups=2, axis=1)(x) + self.assertAllClose( + x, + np.array( + [ + [ + [[-1.5274, -1.0910], [-0.6546, -0.2182]], + [[0.2182, 0.6546], [1.0910, 1.5274]], + [[-1.5274, -1.0910], [-0.6546, -0.2182]], + [[0.2182, 0.6546], [1.0910, 1.5274]], + ] + ] + ), + atol=1e-3, + ) diff --git a/keras/src/layers/normalization/layer_normalization.py b/keras/src/layers/normalization/layer_normalization.py new file mode 100644 index 000000000000..4df59b498049 --- /dev/null +++ b/keras/src/layers/normalization/layer_normalization.py @@ -0,0 +1,226 @@ +import warnings + +from keras.src import constraints +from keras.src import initializers +from keras.src import ops +from keras.src import regularizers +from keras.src.api_export import keras_export +from keras.src.layers.layer import Layer + + +@keras_export("keras.layers.LayerNormalization") +class LayerNormalization(Layer): + """Layer normalization layer (Ba et al., 2016). + + Normalize the activations of the previous layer for each given example in a + batch independently, rather than across a batch like Batch Normalization. + i.e. applies a transformation that maintains the mean activation within each + example close to 0 and the activation standard deviation close to 1. + + If `scale` or `center` are enabled, the layer will scale the normalized + outputs by broadcasting them with a trainable variable `gamma`, and center + the outputs by broadcasting with a trainable variable `beta`. `gamma` will + default to a ones tensor and `beta` will default to a zeros tensor, so that + centering and scaling are no-ops before training has begun. + + So, with scaling and centering enabled the normalization equations + are as follows: + + Let the intermediate activations for a mini-batch to be the `inputs`. + + For each sample `x_i` in `inputs` with `k` features, we compute the mean and + variance of the sample: + + ```python + mean_i = sum(x_i[j] for j in range(k)) / k + var_i = sum((x_i[j] - mean_i) ** 2 for j in range(k)) / k + ``` + + and then compute a normalized `x_i_normalized`, including a small factor + `epsilon` for numerical stability. + + ```python + x_i_normalized = (x_i - mean_i) / sqrt(var_i + epsilon) + ``` + + And finally `x_i_normalized ` is linearly transformed by `gamma` and `beta`, + which are learned parameters: + + ```python + output_i = x_i_normalized * gamma + beta + ``` + + `gamma` and `beta` will span the axes of `inputs` specified in `axis`, and + this part of the inputs' shape must be fully defined. + + For example: + + >>> layer = keras.layers.LayerNormalization(axis=[1, 2, 3]) + >>> layer.build([5, 20, 30, 40]) + >>> print(layer.beta.shape) + (20, 30, 40) + >>> print(layer.gamma.shape) + (20, 30, 40) + + Note that other implementations of layer normalization may choose to define + `gamma` and `beta` over a separate set of axes from the axes being + normalized across. For example, Group Normalization + ([Wu et al. 2018](https://arxiv.org/abs/1803.08494)) with group size of 1 + corresponds to a Layer Normalization that normalizes across height, width, + and channel and has `gamma` and `beta` span only the channel dimension. + So, this Layer Normalization implementation will not match a Group + Normalization layer with group size set to 1. + + Args: + axis: Integer or List/Tuple. The axis or axes to normalize across. + Typically, this is the features axis/axes. The left-out axes are + typically the batch axis/axes. `-1` is the last dimension in the + input. Defaults to `-1`. + epsilon: Small float added to variance to avoid dividing by zero. + Defaults to 1e-3. + center: If True, add offset of `beta` to normalized tensor. If False, + `beta` is ignored. Defaults to `True`. + scale: If True, multiply by `gamma`. If False, `gamma` is not used. + When the next layer is linear (also e.g. `nn.relu`), this can be + disabled since the scaling will be done by the next layer. + Defaults to `True`. + beta_initializer: Initializer for the beta weight. Defaults to zeros. + gamma_initializer: Initializer for the gamma weight. Defaults to ones. + beta_regularizer: Optional regularizer for the beta weight. + None by default. + gamma_regularizer: Optional regularizer for the gamma weight. + None by default. + beta_constraint: Optional constraint for the beta weight. + None by default. + gamma_constraint: Optional constraint for the gamma weight. + None by default. + **kwargs: Base layer keyword arguments (e.g. `name` and `dtype`). + + + Reference: + + - [Lei Ba et al., 2016](https://arxiv.org/abs/1607.06450). + """ + + def __init__( + self, + axis=-1, + epsilon=1e-3, + center=True, + scale=True, + beta_initializer="zeros", + gamma_initializer="ones", + beta_regularizer=None, + gamma_regularizer=None, + beta_constraint=None, + gamma_constraint=None, + **kwargs, + ): + rms_scaling = kwargs.pop("rms_scaling", False) + if rms_scaling: + warnings.warn( + "You passed `rms_scaling=True`, which is deprecated. This " + "argument incorrectly scales the input by the variance, not " + "the root mean square. To correctly use RMS Normalization, " + "please use `keras.layers.RMSNormalization` instead." + ) + + super().__init__(**kwargs) + if isinstance(axis, (list, tuple)): + self.axis = list(axis) + elif isinstance(axis, int): + self.axis = axis + else: + raise TypeError( + "Expected an int or a list/tuple of ints for the " + "argument 'axis', but received: %r" % axis + ) + + self.epsilon = epsilon + self.center = center + self.scale = scale + self.rms_scaling = rms_scaling + self.beta_initializer = initializers.get(beta_initializer) + self.gamma_initializer = initializers.get(gamma_initializer) + self.beta_regularizer = regularizers.get(beta_regularizer) + self.gamma_regularizer = regularizers.get(gamma_regularizer) + self.beta_constraint = constraints.get(beta_constraint) + self.gamma_constraint = constraints.get(gamma_constraint) + + self.supports_masking = True + self.autocast = False + + def build(self, input_shape): + if isinstance(self.axis, list): + shape = tuple([input_shape[dim] for dim in self.axis]) + else: + shape = (input_shape[self.axis],) + self.axis = [self.axis] + if self.scale or self.rms_scaling: + self.gamma = self.add_weight( + name="gamma", + shape=shape, + initializer=self.gamma_initializer, + regularizer=self.gamma_regularizer, + constraint=self.gamma_constraint, + trainable=True, + autocast=False, + ) + else: + self.gamma = None + + if self.center and not self.rms_scaling: + self.beta = self.add_weight( + name="beta", + shape=shape, + initializer=self.beta_initializer, + regularizer=self.beta_regularizer, + constraint=self.beta_constraint, + trainable=True, + autocast=False, + ) + else: + self.beta = None + + def call(self, inputs): + outputs = ops.layer_normalization( + inputs, + self.gamma, + self.beta, + self.axis, + self.epsilon, + rms_scaling=self.rms_scaling, + ) + return ops.cast(outputs, self.compute_dtype) + + def compute_output_shape(self, input_shape): + if isinstance(self.axis, int): + axes = [self.axis] + else: + axes = self.axis + + for axis in axes: + if axis >= len(input_shape) or axis < -len(input_shape): + raise ValueError( + f"Axis {axis} is out of bounds for " + f"input shape {input_shape}. " + f"Received: axis={self.axis}" + ) + return input_shape + + def get_config(self): + config = { + "axis": self.axis, + "epsilon": self.epsilon, + "center": self.center, + "scale": self.scale, + "rms_scaling": self.rms_scaling, + "beta_initializer": initializers.serialize(self.beta_initializer), + "gamma_initializer": initializers.serialize(self.gamma_initializer), + "beta_regularizer": regularizers.serialize(self.beta_regularizer), + "gamma_regularizer": regularizers.serialize(self.gamma_regularizer), + "beta_constraint": constraints.serialize(self.beta_constraint), + "gamma_constraint": constraints.serialize(self.gamma_constraint), + } + base_config = super().get_config() + return {**base_config, **config} diff --git a/keras/src/layers/normalization/layer_normalization_test.py b/keras/src/layers/normalization/layer_normalization_test.py new file mode 100644 index 000000000000..ad2c72006204 --- /dev/null +++ b/keras/src/layers/normalization/layer_normalization_test.py @@ -0,0 +1,135 @@ +import numpy as np +import pytest + +from keras.src import backend +from keras.src import layers +from keras.src import ops +from keras.src import regularizers +from keras.src import testing + + +class LayerNormalizationTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_ln_basics(self): + self.run_layer_test( + layers.LayerNormalization, + init_kwargs={ + "gamma_regularizer": regularizers.L2(0.01), + "beta_regularizer": regularizers.L2(0.01), + }, + input_shape=(3, 4, 2), + expected_output_shape=(3, 4, 2), + expected_num_trainable_weights=2, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=2, + supports_masking=True, + ) + self.run_layer_test( + layers.LayerNormalization, + init_kwargs={ + "gamma_initializer": "ones", + "beta_initializer": "ones", + }, + input_shape=(3, 4, 2), + expected_output_shape=(3, 4, 2), + expected_num_trainable_weights=2, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=True, + ) + self.run_layer_test( + layers.LayerNormalization, + init_kwargs={"scale": False, "center": False}, + input_shape=(3, 3), + expected_output_shape=(3, 3), + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=True, + ) + self.run_layer_test( + layers.LayerNormalization, + init_kwargs={"rms_scaling": True}, + input_shape=(3, 3), + expected_output_shape=(3, 3), + expected_num_trainable_weights=1, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=True, + ) + self.run_layer_test( + layers.LayerNormalization, + init_kwargs={"axis": (-3, -2, -1)}, + input_shape=(2, 8, 8, 3), + expected_output_shape=(2, 8, 8, 3), + expected_num_trainable_weights=2, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=True, + ) + self.run_layer_test( + layers.LayerNormalization, + init_kwargs={}, + input_shape=(1, 0, 10), + expected_output_shape=(1, 0, 10), + expected_num_trainable_weights=2, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=True, + ) + + def test_invalid_axis(self): + with self.assertRaisesRegex( + TypeError, + ("Expected an int or a list/tuple of ints for the argument 'axis'"), + ): + layers.LayerNormalization(axis={"axis": -1}) + + def test_correctness(self): + layer = layers.LayerNormalization(dtype="float32") + layer.build(input_shape=(2, 2, 2)) + inputs = np.random.normal( + loc=5.0, scale=10.0, size=(1000, 2, 2, 2) + ).astype("float32") + + out = layer(inputs) + out = ops.subtract(out, layer.beta) + out = ops.divide(out, layer.gamma) + + self.assertAllClose(ops.mean(out), 0.0, atol=1e-1) + self.assertAllClose(ops.std(out), 1.0, atol=1e-1) + + def test_output(self): + layer = layers.LayerNormalization( + dtype="float32", + beta_initializer="ones", + gamma_initializer="ones", + ) + inputs = np.arange(5).astype("float32")[None, :] + out = layer(inputs) + self.assertAllClose(out, [[-0.41386, 0.29307, 1.0, 1.70693, 2.41386]]) + + def test_output_with_rms_scaling(self): + layer = layers.LayerNormalization( + dtype="float32", + rms_scaling=True, + gamma_initializer="ones", + ) + inputs = np.arange(5).astype("float32")[None, :] + out = layer(inputs) + self.assertAllClose(out, [[0.0, 0.70693, 1.41386, 2.12079, 2.82772]]) + + def test_large_value_within_autocast_scope(self): + layer = layers.LayerNormalization() + layer.build((1, 4, 4, 3)) + # Use 70000 to trigger overflow for float16 + large_value = ops.full(layer.gamma.shape, 70000) + with backend.AutocastScope("float16"): + layer.gamma.assign(large_value) + self.assertAllClose(layer.gamma.value, large_value) diff --git a/keras/src/layers/normalization/rms_normalization.py b/keras/src/layers/normalization/rms_normalization.py new file mode 100644 index 000000000000..6af57ef8f073 --- /dev/null +++ b/keras/src/layers/normalization/rms_normalization.py @@ -0,0 +1,98 @@ +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.layers.layer import Layer + + +@keras_export("keras.layers.RMSNormalization") +class RMSNormalization(Layer): + """Root Mean Square (RMS) Normalization layer. + + This layer normalizes the input tensor based on its RMS value. + + The Keras layer performs the operation as described in + [Root Mean Square Layer Normalization](https://arxiv.org/pdf/1910.07467) + by Biao Zhang et al. + + + If `scale` is enabled, the layer will scale the normalized outputs via + a learnable scaling factor. + + So, with scaling enabled, the normalization equations + are as follows: + + Let the intermediate activations for a mini-batch to be the `inputs`. + + ```python + rms_normalization(x) = x * rsqrt(mean(square(x))) * scale + ``` + + For example: + + >>> layer = keras.layers.RMSNormalization() + >>> layer.build([5, 20, 30, 10]) + >>> print(layer.scale.shape) + (10,) + >>> layer(np.random.rand(1, 10)).numpy() + array([[0.35098287, 1.0495652 , 1.4645109 , 1.2944688 , 0.31124955, + 1.2768592 , 1.184331 , 0.17474432, 0.49955517, 1.2428929 ]], + dtype=float32) + + Args: + axis: int. The axis on which to perform the normalization. + epsilon: float. A small number to add to avoid division by zero. + """ + + def __init__(self, axis=-1, epsilon=1e-6, **kwargs): + super().__init__(**kwargs) + self.axis = axis + self.epsilon = epsilon + + def build(self, input_shape): + if isinstance(self.axis, list): + shape = tuple([input_shape[dim] for dim in self.axis]) + else: + shape = (input_shape[self.axis],) + self.axis = [self.axis] + + self.scale = self.add_weight( + name="scale", shape=shape, initializer="ones" + ) + + self.built = True + + def call(self, x): + """Applies RMS normalization to the input tensor. + + Args: + x: Input tensor of shape (batch_size, input_dim). + + Returns: + The RMS-normalized tensor of the same shape (batch_size, input_dim), + scaled by the learned `scale` parameter. + """ + return ops.rms_normalization( + x, scale=self.scale, axis=self.axis, epsilon=self.epsilon + ) + + def compute_output_shape(self, input_shape): + if isinstance(self.axis, int): + axes = [self.axis] + else: + axes = self.axis + + for axis in axes: + if axis >= len(input_shape) or axis < -len(input_shape): + raise ValueError( + f"Axis {axis} is out of bounds for " + f"input shape {input_shape}. " + f"Received: axis={self.axis}" + ) + return input_shape + + def get_config(self): + config = { + "axis": self.axis, + "epsilon": self.epsilon, + } + base_config = super().get_config() + return {**base_config, **config} diff --git a/keras/src/layers/normalization/rms_normalization_test.py b/keras/src/layers/normalization/rms_normalization_test.py new file mode 100644 index 000000000000..5e56fa94634b --- /dev/null +++ b/keras/src/layers/normalization/rms_normalization_test.py @@ -0,0 +1,71 @@ +import numpy as np +import pytest + +from keras.src import layers +from keras.src import ops +from keras.src import testing + + +class RMSNormalizationTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_ln_basics(self): + self.run_layer_test( + layers.RMSNormalization, + init_kwargs={}, + input_shape=(4, 2), + expected_output_shape=(4, 2), + expected_num_trainable_weights=1, + expected_num_seed_generators=0, + ) + self.run_layer_test( + layers.RMSNormalization, + init_kwargs={ + "axis": -1, + }, + input_shape=(4, 2), + expected_output_shape=(4, 2), + expected_num_trainable_weights=1, + expected_num_seed_generators=0, + ) + + def test_correctness(self): + layer = layers.RMSNormalization() + layer.build(input_shape=(2, 2, 2)) + inputs = np.random.normal( + loc=5.0, scale=10.0, size=(1000, 2, 2, 2) + ).astype("float32") + + inputs = ops.convert_to_tensor(inputs) + + out = layer(inputs) + expected = ops.multiply( + ops.multiply( + inputs, + ops.rsqrt(ops.mean(ops.square(inputs), axis=-1, keepdims=True)), + ), + layer.scale, + ) + + self.assertAllClose(out, expected, atol=1e-1) + + def test_output(self): + layer = layers.RMSNormalization() + inputs = np.arange(10).astype("float32")[None, :] + out = layer(inputs) + self.assertAllClose( + out, + [ + [ + 0.0, + 0.18731716, + 0.37463433, + 0.5619515, + 0.74926865, + 0.9365858, + 1.123903, + 1.3112202, + 1.4985373, + 1.6858544, + ] + ], + ) diff --git a/keras/src/layers/normalization/spectral_normalization.py b/keras/src/layers/normalization/spectral_normalization.py new file mode 100644 index 000000000000..70b81c75627c --- /dev/null +++ b/keras/src/layers/normalization/spectral_normalization.py @@ -0,0 +1,121 @@ +from keras.src import initializers +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.layers import Wrapper +from keras.src.layers.input_spec import InputSpec +from keras.src.utils.numerical_utils import normalize + + +@keras_export("keras.layers.SpectralNormalization") +class SpectralNormalization(Wrapper): + """Performs spectral normalization on the weights of a target layer. + + This wrapper controls the Lipschitz constant of the weights of a layer by + constraining their spectral norm, which can stabilize the training of GANs. + + Args: + layer: A `keras.layers.Layer` instance that + has either a `kernel` (e.g. `Conv2D`, `Dense`...) + or an `embeddings` attribute (`Embedding` layer). + power_iterations: int, the number of iterations during normalization. + **kwargs: Base wrapper keyword arguments. + + Examples: + + Wrap `keras.layers.Conv2D`: + >>> x = np.random.rand(1, 10, 10, 1) + >>> conv2d = SpectralNormalization(keras.layers.Conv2D(2, 2)) + >>> y = conv2d(x) + >>> y.shape + (1, 9, 9, 2) + + Wrap `keras.layers.Dense`: + >>> x = np.random.rand(1, 10, 10, 1) + >>> dense = SpectralNormalization(keras.layers.Dense(10)) + >>> y = dense(x) + >>> y.shape + (1, 10, 10, 10) + + Reference: + + - [Spectral Normalization for GAN](https://arxiv.org/abs/1802.05957). + """ + + def __init__(self, layer, power_iterations=1, **kwargs): + super().__init__(layer, **kwargs) + if power_iterations <= 0: + raise ValueError( + "`power_iterations` should be greater than zero. Received: " + f"`power_iterations={power_iterations}`" + ) + self.power_iterations = power_iterations + + def build(self, input_shape): + super().build(input_shape) + self.input_spec = InputSpec(min_ndim=1, axes={-1: input_shape[-1]}) + + if hasattr(self.layer, "kernel"): + self.kernel = self.layer.kernel + elif hasattr(self.layer, "embeddings"): + self.kernel = self.layer.embeddings + else: + raise ValueError( + f"{type(self.layer).__name__} object has no attribute 'kernel' " + "nor 'embeddings'" + ) + + self.kernel_shape = self.kernel.shape + + self.vector_u = self.add_weight( + shape=(1, self.kernel_shape[-1]), + initializer=initializers.TruncatedNormal(stddev=0.02), + trainable=False, + name="vector_u", + dtype=self.kernel.dtype, + ) + + def call(self, inputs, training=False): + if training: + new_vector_u, new_kernel = ops.cond( + ops.all(ops.equal(self.kernel.value, 0)), + lambda: (self.vector_u.value, self.kernel.value), + self.normalized_weights, + ) + self.vector_u.assign(new_vector_u) + self.kernel.assign(new_kernel) + + output = self.layer(inputs) + return ops.cast(output, inputs.dtype) + + def compute_output_shape(self, input_shape): + return self.layer.compute_output_shape(input_shape) + + def normalized_weights(self): + """Generate spectral normalized weights. + + This method returns the updated value for `self.kernel` with the + spectral normalized value, so that the layer is ready for `call()`. + """ + + weights = ops.reshape(self.kernel, [-1, self.kernel_shape[-1]]) + vector_u = self.vector_u.value + + for _ in range(self.power_iterations): + vector_v = normalize( + ops.matmul(vector_u, ops.transpose(weights)), axis=None + ) + vector_u = normalize(ops.matmul(vector_v, weights), axis=None) + vector_u = ops.stop_gradient(vector_u) + vector_v = ops.stop_gradient(vector_v) + sigma = ops.matmul( + ops.matmul(vector_v, weights), ops.transpose(vector_u) + ) + kernel = ops.reshape(ops.divide(self.kernel, sigma), self.kernel_shape) + return ops.cast(vector_u, self.vector_u.dtype), ops.cast( + kernel, self.kernel.dtype + ) + + def get_config(self): + config = {"power_iterations": self.power_iterations} + base_config = super().get_config() + return {**base_config, **config} diff --git a/keras/src/layers/normalization/spectral_normalization_test.py b/keras/src/layers/normalization/spectral_normalization_test.py new file mode 100644 index 000000000000..bf7f459e62b6 --- /dev/null +++ b/keras/src/layers/normalization/spectral_normalization_test.py @@ -0,0 +1,97 @@ +import numpy as np +import pytest + +from keras.src import backend +from keras.src import initializers +from keras.src import layers +from keras.src import models +from keras.src import testing + + +class SpectralNormalizationTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_basic_spectralnorm(self): + self.run_layer_test( + layers.SpectralNormalization, + init_kwargs={"layer": layers.Dense(2)}, + input_data=np.random.uniform(size=(10, 3, 4)), + expected_output_shape=(10, 3, 2), + expected_num_trainable_weights=2, + expected_num_non_trainable_weights=1, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=False, + ) + self.run_layer_test( + layers.SpectralNormalization, + init_kwargs={"layer": layers.Embedding(10, 4)}, + input_data=np.random.randint(10, size=(10,)).astype("float32"), + expected_output_shape=(10, 4), + expected_num_trainable_weights=1, + expected_num_non_trainable_weights=1, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=False, + run_training_check=False, + ) + + @pytest.mark.requires_trainable_backend + def test_spectralnorm_higher_dim(self): + self.run_layer_test( + layers.SpectralNormalization, + init_kwargs={"layer": layers.Dense(2)}, + input_data=np.random.uniform(size=(10, 3, 4, 5)), + expected_output_shape=(10, 3, 4, 2), + expected_num_trainable_weights=2, + expected_num_non_trainable_weights=1, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=False, + ) + + def test_invalid_power_iterations(self): + with self.assertRaisesRegex( + ValueError, "`power_iterations` should be greater than zero." + ): + layers.SpectralNormalization(layers.Dense(2), power_iterations=0) + + def test_invalid_layer(self): + layer = layers.SpectralNormalization(layers.ReLU()) + inputs = np.ones(shape=(4, 2)) + with self.assertRaisesRegex( + ValueError, "object has no attribute 'kernel' nor 'embeddings'" + ): + layer(inputs) + + def test_apply_layer(self): + if backend.config.image_data_format() == "channels_last": + images = np.ones((1, 2, 2, 1)) + else: + images = np.ones((1, 1, 2, 2)) + sn_wrapper = layers.SpectralNormalization( + layers.Conv2D( + 1, (2, 2), kernel_initializer=initializers.Constant(value=1) + ), + power_iterations=8, + ) + + result = sn_wrapper(images, training=False) + result_train = sn_wrapper(images, training=True) + expected_output = np.array([[[[4.0]]]], dtype=np.float32) + self.assertAllClose(result, expected_output) + # max eigen value of 2x2 matrix of ones is 2 + self.assertAllClose(result_train, expected_output / 2) + + @pytest.mark.requires_trainable_backend + def test_end_to_end(self): + sn_wrapper = layers.SpectralNormalization( + layers.Conv2D( + 3, (2, 2), padding="same", data_format="channels_last" + ), + power_iterations=2, + ) + model = models.Sequential([sn_wrapper]) + model.compile("rmsprop", loss="mse") + x = np.random.random((4, 8, 8, 3)) + y = np.random.random((4, 8, 8, 3)) + model.fit(x, y) diff --git a/keras/src/layers/normalization/unit_normalization.py b/keras/src/layers/normalization/unit_normalization.py new file mode 100644 index 000000000000..15ba884f1bbc --- /dev/null +++ b/keras/src/layers/normalization/unit_normalization.py @@ -0,0 +1,64 @@ +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.layers.layer import Layer + + +@keras_export("keras.layers.UnitNormalization") +class UnitNormalization(Layer): + """Unit normalization layer. + + Normalize a batch of inputs so that each input in the batch has a L2 norm + equal to 1 (across the axes specified in `axis`). + + Example: + + >>> data = np.arange(6).reshape(2, 3) + >>> normalized_data = keras.layers.UnitNormalization()(data) + >>> np.sum(normalized_data[0, :] ** 2) + 1.0 + + Args: + axis: Integer or list/tuple. The axis or axes to normalize across. + Typically, this is the features axis or axes. The left-out axes are + typically the batch axis or axes. `-1` is the last dimension + in the input. Defaults to `-1`. + """ + + def __init__(self, axis=-1, **kwargs): + super().__init__(**kwargs) + if isinstance(axis, (list, tuple)): + self.axis = list(axis) + elif isinstance(axis, int): + self.axis = axis + else: + raise TypeError( + "Invalid value for `axis` argument: " + "expected an int or a list/tuple of ints. " + f"Received: axis={axis}" + ) + self.supports_masking = True + + self._build_at_init() + + def call(self, inputs): + return ops.normalize(inputs, axis=self.axis, order=2, epsilon=1e-12) + + def compute_output_shape(self, input_shape): + # Ensure axis is always treated as a list + if isinstance(self.axis, int): + axes = [self.axis] + else: + axes = self.axis + + for axis in axes: + if axis >= len(input_shape) or axis < -len(input_shape): + raise ValueError( + f"Axis {self.axis} is out of bounds for " + f"input shape {input_shape}." + ) + return input_shape + + def get_config(self): + config = super().get_config() + config.update({"axis": self.axis}) + return config diff --git a/keras/src/layers/normalization/unit_normalization_test.py b/keras/src/layers/normalization/unit_normalization_test.py new file mode 100644 index 000000000000..8e43ee64f58d --- /dev/null +++ b/keras/src/layers/normalization/unit_normalization_test.py @@ -0,0 +1,63 @@ +import numpy as np +import pytest + +from keras.src import backend +from keras.src import layers +from keras.src import testing + + +def squared_l2_norm(x): + x = backend.convert_to_numpy(x) + return np.sum(x**2) + + +class UnitNormalizationTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_un_basics(self): + self.run_layer_test( + layers.UnitNormalization, + init_kwargs={"axis": -1}, + input_shape=(2, 3), + expected_output_shape=(2, 3), + supports_masking=True, + assert_built_after_instantiation=True, + ) + self.run_layer_test( + layers.UnitNormalization, + init_kwargs={"axis": (1, 2)}, + input_shape=(1, 3, 3), + expected_output_shape=(1, 3, 3), + supports_masking=True, + assert_built_after_instantiation=True, + ) + + def test_invalid_axis(self): + with self.assertRaisesRegex( + TypeError, + ( + "Invalid value for `axis` argument: expected an int or a " + "list/tuple of ints." + ), + ): + layers.UnitNormalization(axis={"axis": -1}) + + def test_correctness(self): + layer = layers.UnitNormalization(axis=-1) + inputs = np.random.normal(size=(2, 3)) + outputs = layer(inputs) + self.assertAllClose(squared_l2_norm(outputs[0, :]), 1.0) + self.assertAllClose(squared_l2_norm(outputs[1, :]), 1.0) + + layer = layers.UnitNormalization(axis=(1, 2)) + inputs = np.random.normal(size=(2, 3, 3)) + outputs = layer(inputs) + self.assertAllClose(squared_l2_norm(outputs[0, :, :]), 1.0) + self.assertAllClose(squared_l2_norm(outputs[1, :, :]), 1.0) + + layer = layers.UnitNormalization(axis=1) + inputs = np.random.normal(size=(2, 3, 2)) + outputs = layer(inputs) + self.assertAllClose(squared_l2_norm(outputs[0, :, 0]), 1.0) + self.assertAllClose(squared_l2_norm(outputs[1, :, 0]), 1.0) + self.assertAllClose(squared_l2_norm(outputs[0, :, 1]), 1.0) + self.assertAllClose(squared_l2_norm(outputs[1, :, 1]), 1.0) diff --git a/keras/src/layers/pooling/__init__.py b/keras/src/layers/pooling/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/keras/src/layers/pooling/average_pooling1d.py b/keras/src/layers/pooling/average_pooling1d.py new file mode 100644 index 000000000000..0450149c0473 --- /dev/null +++ b/keras/src/layers/pooling/average_pooling1d.py @@ -0,0 +1,92 @@ +from keras.src.api_export import keras_export +from keras.src.layers.pooling.base_pooling import BasePooling + + +@keras_export(["keras.layers.AveragePooling1D", "keras.layers.AvgPool1D"]) +class AveragePooling1D(BasePooling): + """Average pooling for temporal data. + + Downsamples the input representation by taking the average value over the + window defined by `pool_size`. The window is shifted by `strides`. The + resulting output when using "valid" padding option has a shape of: + `output_shape = (input_shape - pool_size + 1) / strides)` + + The resulting output shape when using the "same" padding option is: + `output_shape = input_shape / strides` + + Args: + pool_size: int, size of the max pooling window. + strides: int or None. Specifies how much the pooling window moves + for each pooling step. If None, it will default to `pool_size`. + padding: string, either `"valid"` or `"same"` (case-insensitive). + `"valid"` means no padding. `"same"` results in padding evenly to + the left/right or up/down of the input such that output has the same + height/width dimension as the input. + data_format: string, either `"channels_last"` or `"channels_first"`. + The ordering of the dimensions in the inputs. `"channels_last"` + corresponds to inputs with shape `(batch, steps, features)` + while `"channels_first"` corresponds to inputs with shape + `(batch, features, steps)`. It defaults to the `image_data_format` + value found in your Keras config file at `~/.keras/keras.json`. + If you never set it, then it will be `"channels_last"`. + + Input shape: + + - If `data_format="channels_last"`: + 3D tensor with shape `(batch_size, steps, features)`. + - If `data_format="channels_first"`: + 3D tensor with shape `(batch_size, features, steps)`. + + Output shape: + + - If `data_format="channels_last"`: + 3D tensor with shape `(batch_size, downsampled_steps, features)`. + - If `data_format="channels_first"`: + 3D tensor with shape `(batch_size, features, downsampled_steps)`. + + Examples: + + `strides=1` and `padding="valid"`: + + >>> x = np.array([1., 2., 3., 4., 5.]) + >>> x = np.reshape(x, [1, 5, 1]) + >>> avg_pool_1d = keras.layers.AveragePooling1D(pool_size=2, + ... strides=1, padding="valid") + >>> avg_pool_1d(x) + + `strides=2` and `padding="valid"`: + + >>> x = np.array([1., 2., 3., 4., 5.]) + >>> x = np.reshape(x, [1, 5, 1]) + >>> avg_pool_1d = keras.layers.AveragePooling1D(pool_size=2, + ... strides=2, padding="valid") + >>> avg_pool_1d(x) + + `strides=1` and `padding="same"`: + + >>> x = np.array([1., 2., 3., 4., 5.]) + >>> x = np.reshape(x, [1, 5, 1]) + >>> avg_pool_1d = keras.layers.AveragePooling1D(pool_size=2, + ... strides=1, padding="same") + >>> avg_pool_1d(x) + """ + + def __init__( + self, + pool_size, + strides=None, + padding="valid", + data_format=None, + name=None, + **kwargs, + ): + super().__init__( + pool_size, + strides, + pool_dimensions=1, + pool_mode="average", + padding=padding, + data_format=data_format, + name=name, + **kwargs, + ) diff --git a/keras/src/layers/pooling/average_pooling2d.py b/keras/src/layers/pooling/average_pooling2d.py new file mode 100644 index 000000000000..a32972779f1f --- /dev/null +++ b/keras/src/layers/pooling/average_pooling2d.py @@ -0,0 +1,109 @@ +from keras.src.api_export import keras_export +from keras.src.layers.pooling.base_pooling import BasePooling + + +@keras_export(["keras.layers.AveragePooling2D", "keras.layers.AvgPool2D"]) +class AveragePooling2D(BasePooling): + """Average pooling operation for 2D spatial data. + + Downsamples the input along its spatial dimensions (height and width) + by taking the average value over an input window + (of size defined by `pool_size`) for each channel of the input. + The window is shifted by `strides` along each dimension. + + The resulting output when using the `"valid"` padding option has a spatial + shape (number of rows or columns) of: + `output_shape = math.floor((input_shape - pool_size) / strides) + 1` + (when `input_shape >= pool_size`) + + The resulting output shape when using the `"same"` padding option is: + `output_shape = input_shape` + + Args: + pool_size: int or tuple of 2 integers, factors by which to downscale + (dim1, dim2). If only one integer is specified, the same + window length will be used for all dimensions. + strides: int or tuple of 2 integers, or None. Strides values. If None, + it will default to `pool_size`. If only one int is specified, the + same stride size will be used for all dimensions. + padding: string, either `"valid"` or `"same"` (case-insensitive). + `"valid"` means no padding. `"same"` results in padding evenly to + the left/right or up/down of the input such that output has the same + height/width dimension as the input. + data_format: string, either `"channels_last"` or `"channels_first"`. + The ordering of the dimensions in the inputs. `"channels_last"` + corresponds to inputs with shape `(batch, height, width, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch, channels, height, width)`. It defaults to the + `image_data_format` value found in your Keras config file at + `~/.keras/keras.json`. If you never set it, then it will be + `"channels_last"`. + + Input shape: + + - If `data_format="channels_last"`: + 4D tensor with shape `(batch_size, height, width, channels)`. + - If `data_format="channels_first"`: + 4D tensor with shape `(batch_size, channels, height, width)`. + + Output shape: + + - If `data_format="channels_last"`: + 4D tensor with shape + `(batch_size, pooled_height, pooled_width, channels)`. + - If `data_format="channels_first"`: + 4D tensor with shape + `(batch_size, channels, pooled_height, pooled_width)`. + + Examples: + + `strides=(1, 1)` and `padding="valid"`: + + >>> x = np.array([[1., 2., 3.], + ... [4., 5., 6.], + ... [7., 8., 9.]]) + >>> x = np.reshape(x, [1, 3, 3, 1]) + >>> avg_pool_2d = keras.layers.AveragePooling2D(pool_size=(2, 2), + ... strides=(1, 1), padding="valid") + >>> avg_pool_2d(x) + + `strides=(2, 2)` and `padding="valid"`: + + >>> x = np.array([[1., 2., 3., 4.], + ... [5., 6., 7., 8.], + ... [9., 10., 11., 12.]]) + >>> x = np.reshape(x, [1, 3, 4, 1]) + >>> avg_pool_2d = keras.layers.AveragePooling2D(pool_size=(2, 2), + ... strides=(2, 2), padding="valid") + >>> avg_pool_2d(x) + + `stride=(1, 1)` and `padding="same"`: + + >>> x = np.array([[1., 2., 3.], + ... [4., 5., 6.], + ... [7., 8., 9.]]) + >>> x = np.reshape(x, [1, 3, 3, 1]) + >>> avg_pool_2d = keras.layers.AveragePooling2D(pool_size=(2, 2), + ... strides=(1, 1), padding="same") + >>> avg_pool_2d(x) + """ + + def __init__( + self, + pool_size, + strides=None, + padding="valid", + data_format=None, + name=None, + **kwargs, + ): + super().__init__( + pool_size, + strides, + pool_dimensions=2, + pool_mode="average", + padding=padding, + data_format=data_format, + name=name, + **kwargs, + ) diff --git a/keras/src/layers/pooling/average_pooling3d.py b/keras/src/layers/pooling/average_pooling3d.py new file mode 100644 index 000000000000..2e5c7448d332 --- /dev/null +++ b/keras/src/layers/pooling/average_pooling3d.py @@ -0,0 +1,85 @@ +from keras.src.api_export import keras_export +from keras.src.layers.pooling.base_pooling import BasePooling + + +@keras_export(["keras.layers.AveragePooling3D", "keras.layers.AvgPool3D"]) +class AveragePooling3D(BasePooling): + """Average pooling operation for 3D data (spatial or spatio-temporal). + + Downsamples the input along its spatial dimensions (depth, height, and + width) by taking the average value over an input window (of size defined by + `pool_size`) for each channel of the input. The window is shifted by + `strides` along each dimension. + + Args: + pool_size: int or tuple of 3 integers, factors by which to downscale + (dim1, dim2, dim3). If only one integer is specified, the same + window length will be used for all dimensions. + strides: int or tuple of 3 integers, or None. Strides values. If None, + it will default to `pool_size`. If only one int is specified, the + same stride size will be used for all dimensions. + padding: string, either `"valid"` or `"same"` (case-insensitive). + `"valid"` means no padding. `"same"` results in padding evenly to + the left/right or up/down of the input such that output has the same + height/width dimension as the input. + data_format: string, either `"channels_last"` or `"channels_first"`. + The ordering of the dimensions in the inputs. `"channels_last"` + corresponds to inputs with shape + `(batch, spatial_dim1, spatial_dim2, spatial_dim3, channels)` while + `"channels_first"` corresponds to inputs with shape + `(batch, channels, spatial_dim1, spatial_dim2, spatial_dim3)`. + It defaults to the `image_data_format` value found in your Keras + config file at `~/.keras/keras.json`. If you never set it, then it + will be `"channels_last"`. + + Input shape: + + - If `data_format="channels_last"`: + 5D tensor with shape: + `(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)` + - If `data_format="channels_first"`: + 5D tensor with shape: + `(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)` + + Output shape: + + - If `data_format="channels_last"`: + 5D tensor with shape: + `(batch_size, pooled_dim1, pooled_dim2, pooled_dim3, channels)` + - If `data_format="channels_first"`: + 5D tensor with shape: + `(batch_size, channels, pooled_dim1, pooled_dim2, pooled_dim3)` + + Example: + + ```python + depth = 30 + height = 30 + width = 30 + channels = 3 + + inputs = keras.layers.Input(shape=(depth, height, width, channels)) + layer = keras.layers.AveragePooling3D(pool_size=3) + outputs = layer(inputs) # Shape: (batch_size, 10, 10, 10, 3) + ``` + """ + + def __init__( + self, + pool_size, + strides=None, + padding="valid", + data_format=None, + name=None, + **kwargs, + ): + super().__init__( + pool_size, + strides, + pool_dimensions=3, + pool_mode="average", + padding=padding, + data_format=data_format, + name=name, + **kwargs, + ) diff --git a/keras/src/layers/pooling/average_pooling_test.py b/keras/src/layers/pooling/average_pooling_test.py new file mode 100644 index 000000000000..02bbdd301989 --- /dev/null +++ b/keras/src/layers/pooling/average_pooling_test.py @@ -0,0 +1,385 @@ +import numpy as np +import pytest +from absl.testing import parameterized +from numpy.lib.stride_tricks import as_strided + +from keras.src import backend +from keras.src import layers +from keras.src import testing + + +def _same_padding(input_size, pool_size, stride): + if input_size % stride == 0: + return max(pool_size - stride, 0) + else: + return max(pool_size - (input_size % stride), 0) + + +def np_avgpool1d(x, pool_size, strides, padding, data_format): + if data_format == "channels_first": + x = x.swapaxes(1, 2) + if isinstance(pool_size, (tuple, list)): + pool_size = pool_size[0] + if isinstance(strides, (tuple, list)): + h_stride = strides[0] + else: + h_stride = strides + + if padding == "same": + n_batch, h_x, ch_x = x.shape + pad_value = _same_padding(h_x, pool_size, h_stride) + npad = [(0, 0)] * x.ndim + npad[1] = (0, pad_value) + x = np.pad(x, pad_width=npad, mode="edge") + + n_batch, h_x, ch_x = x.shape + out_h = int((h_x - pool_size) / h_stride) + 1 + + stride_shape = (n_batch, out_h, ch_x, pool_size) + strides = ( + x.strides[0], + h_stride * x.strides[1], + x.strides[2], + x.strides[1], + ) + windows = as_strided(x, shape=stride_shape, strides=strides) + out = np.mean(windows, axis=(3,)) + if data_format == "channels_first": + out = out.swapaxes(1, 2) + return out + + +def np_avgpool2d(x, pool_size, strides, padding, data_format): + if data_format == "channels_first": + x = x.transpose((0, 2, 3, 1)) + if isinstance(pool_size, int): + pool_size = (pool_size, pool_size) + if isinstance(strides, int): + strides = (strides, strides) + + h_pool_size, w_pool_size = pool_size + h_stride, w_stride = strides + if padding == "same": + n_batch, h_x, w_x, ch_x = x.shape + h_padding = _same_padding(h_x, h_pool_size, h_stride) + w_padding = _same_padding(w_x, w_pool_size, w_stride) + npad = [(0, 0)] * x.ndim + npad[1] = (0, h_padding) + npad[2] = (0, w_padding) + x = np.pad(x, pad_width=npad, mode="edge") + + n_batch, h_x, w_x, ch_x = x.shape + out_h = int((h_x - h_pool_size) / h_stride) + 1 + out_w = int((w_x - w_pool_size) / w_stride) + 1 + + stride_shape = (n_batch, out_h, out_w, ch_x, *pool_size) + strides = ( + x.strides[0], + h_stride * x.strides[1], + w_stride * x.strides[2], + x.strides[3], + x.strides[1], + x.strides[2], + ) + windows = as_strided(x, shape=stride_shape, strides=strides) + out = np.mean(windows, axis=(4, 5)) + if data_format == "channels_first": + out = out.transpose((0, 3, 1, 2)) + return out + + +def np_avgpool3d(x, pool_size, strides, padding, data_format): + if data_format == "channels_first": + x = x.transpose((0, 2, 3, 4, 1)) + + if isinstance(pool_size, int): + pool_size = (pool_size, pool_size, pool_size) + if isinstance(strides, int): + strides = (strides, strides, strides) + + h_pool_size, w_pool_size, d_pool_size = pool_size + h_stride, w_stride, d_stride = strides + + if padding == "same": + n_batch, h_x, w_x, d_x, ch_x = x.shape + h_padding = _same_padding(h_x, h_pool_size, h_stride) + w_padding = _same_padding(w_x, w_pool_size, w_stride) + d_padding = _same_padding(d_x, d_pool_size, d_stride) + npad = [(0, 0)] * x.ndim + npad[1] = (0, h_padding) + npad[2] = (0, w_padding) + npad[3] = (0, d_padding) + x = np.pad(x, pad_width=npad, mode="symmetric") + + n_batch, h_x, w_x, d_x, ch_x = x.shape + out_h = int((h_x - h_pool_size) / h_stride) + 1 + out_w = int((w_x - w_pool_size) / w_stride) + 1 + out_d = int((d_x - d_pool_size) / d_stride) + 1 + + stride_shape = (n_batch, out_h, out_w, out_d, ch_x, *pool_size) + strides = ( + x.strides[0], + h_stride * x.strides[1], + w_stride * x.strides[2], + d_stride * x.strides[3], + x.strides[4], + x.strides[1], + x.strides[2], + x.strides[3], + ) + windows = as_strided(x, shape=stride_shape, strides=strides) + out = np.mean(windows, axis=(5, 6, 7)) + if data_format == "channels_first": + out = out.transpose((0, 4, 1, 2, 3)) + return out + + +@pytest.mark.requires_trainable_backend +class AveragePoolingBasicTest(testing.TestCase): + @parameterized.parameters( + (2, 1, "valid", "channels_last", (3, 5, 4), (3, 4, 4)), + (2, 1, "same", "channels_first", (3, 5, 4), (3, 5, 4)), + ((2,), (2,), "valid", "channels_last", (3, 5, 4), (3, 2, 4)), + ) + def test_average_pooling1d( + self, + pool_size, + strides, + padding, + data_format, + input_shape, + output_shape, + ): + self.run_layer_test( + layers.AveragePooling1D, + init_kwargs={ + "pool_size": pool_size, + "strides": strides, + "padding": padding, + "data_format": data_format, + }, + input_shape=input_shape, + expected_output_shape=output_shape, + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_losses=0, + supports_masking=False, + assert_built_after_instantiation=True, + ) + + @parameterized.parameters( + (2, 1, "valid", "channels_last", (3, 5, 5, 4), (3, 4, 4, 4)), + (2, 1, "same", "channels_last", (3, 5, 5, 4), (3, 5, 5, 4)), + (2, 1, "valid", "channels_first", (3, 5, 5, 4), (3, 5, 4, 3)), + (2, 1, "same", "channels_first", (3, 5, 5, 4), (3, 5, 5, 4)), + ((2, 3), (2, 2), "valid", "channels_last", (3, 5, 5, 4), (3, 2, 2, 4)), + ((2, 3), (2, 2), "same", "channels_last", (3, 5, 5, 4), (3, 3, 3, 4)), + ((2, 3), (3, 3), "same", "channels_first", (3, 5, 5, 4), (3, 5, 2, 2)), + ) + def test_average_pooling2d( + self, + pool_size, + strides, + padding, + data_format, + input_shape, + output_shape, + ): + self.run_layer_test( + layers.AveragePooling2D, + init_kwargs={ + "pool_size": pool_size, + "strides": strides, + "padding": padding, + "data_format": data_format, + }, + input_shape=input_shape, + expected_output_shape=output_shape, + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_losses=0, + supports_masking=False, + assert_built_after_instantiation=True, + ) + + @parameterized.parameters( + (2, 1, "valid", "channels_last", (3, 5, 5, 5, 4), (3, 4, 4, 4, 4)), + (2, 1, "same", "channels_first", (3, 5, 5, 5, 4), (3, 5, 5, 5, 4)), + ( + (2, 3, 2), + (2, 2, 1), + "valid", + "channels_last", + (3, 5, 5, 5, 4), + (3, 2, 2, 4, 4), + ), + ) + def test_average_pooling3d( + self, + pool_size, + strides, + padding, + data_format, + input_shape, + output_shape, + ): + self.run_layer_test( + layers.AveragePooling3D, + init_kwargs={ + "pool_size": pool_size, + "strides": strides, + "padding": padding, + "data_format": data_format, + }, + input_shape=input_shape, + expected_output_shape=output_shape, + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_losses=0, + supports_masking=False, + # Incomplete op support on tensorflow. + run_mixed_precision_check=False, + assert_built_after_instantiation=True, + ) + + +class AveragePoolingCorrectnessTest(testing.TestCase): + @parameterized.parameters( + (2, 1, "valid", "channels_last"), + (2, 1, "valid", "channels_first"), + ((2,), (2,), "valid", "channels_last"), + ((2,), (2,), "valid", "channels_first"), + ) + def test_average_pooling1d(self, pool_size, strides, padding, data_format): + inputs = np.arange(24, dtype="float32").reshape((2, 3, 4)) + + layer = layers.AveragePooling1D( + pool_size=pool_size, + strides=strides, + padding=padding, + data_format=data_format, + ) + outputs = layer(inputs) + expected = np_avgpool1d( + inputs, pool_size, strides, padding, data_format + ) + self.assertAllClose(outputs, expected) + + @parameterized.parameters( + (2, 1, "same", "channels_last"), + (2, 1, "same", "channels_first"), + ((2,), (2,), "same", "channels_last"), + ((2,), (2,), "same", "channels_first"), + ) + @pytest.mark.skipif( + backend.backend() == "torch", + reason="Same padding in Torch backend produces different results.", + ) + def test_average_pooling1d_same_padding( + self, pool_size, strides, padding, data_format + ): + inputs = np.arange(24, dtype="float32").reshape((2, 3, 4)) + + layer = layers.AveragePooling1D( + pool_size=pool_size, + strides=strides, + padding=padding, + data_format=data_format, + ) + outputs = layer(inputs) + expected = np_avgpool1d( + inputs, pool_size, strides, padding, data_format + ) + self.assertAllClose(outputs, expected) + + @parameterized.parameters( + (2, 1, "valid", "channels_last"), + ((2, 3), (2, 2), "valid", "channels_last"), + ) + def test_average_pooling2d(self, pool_size, strides, padding, data_format): + inputs = np.arange(16, dtype="float32").reshape((1, 4, 4, 1)) + layer = layers.AveragePooling2D( + pool_size=pool_size, + strides=strides, + padding=padding, + data_format=data_format, + ) + outputs = layer(inputs) + expected = np_avgpool2d( + inputs, pool_size, strides, padding, data_format + ) + self.assertAllClose(outputs, expected) + + @parameterized.parameters( + (2, (2, 1), "same", "channels_last"), + (2, (2, 1), "same", "channels_first"), + ((2, 2), (2, 2), "same", "channels_last"), + ((2, 2), (2, 2), "same", "channels_first"), + ) + @pytest.mark.skipif( + backend.backend() == "torch", + reason="Same padding in Torch backend produces different results.", + ) + def test_average_pooling2d_same_padding( + self, pool_size, strides, padding, data_format + ): + inputs = np.arange(16, dtype="float32").reshape((1, 4, 4, 1)) + layer = layers.AveragePooling2D( + pool_size=pool_size, + strides=strides, + padding=padding, + data_format=data_format, + ) + outputs = layer(inputs) + expected = np_avgpool2d( + inputs, pool_size, strides, padding, data_format + ) + self.assertAllClose(outputs, expected) + + @parameterized.parameters( + (2, 1, "valid", "channels_last"), + (2, 1, "valid", "channels_first"), + ((2, 3, 2), (2, 2, 1), "valid", "channels_last"), + ((2, 3, 2), (2, 2, 1), "valid", "channels_first"), + ) + def test_average_pooling3d(self, pool_size, strides, padding, data_format): + inputs = np.arange(240, dtype="float32").reshape((2, 3, 4, 5, 2)) + + layer = layers.AveragePooling3D( + pool_size=pool_size, + strides=strides, + padding=padding, + data_format=data_format, + ) + outputs = layer(inputs) + expected = np_avgpool3d( + inputs, pool_size, strides, padding, data_format + ) + self.assertAllClose(outputs, expected) + + @parameterized.parameters( + (2, 1, "same", "channels_last"), + (2, 1, "same", "channels_first"), + ((2, 2, 2), (2, 2, 1), "same", "channels_last"), + ((2, 2, 2), (2, 2, 1), "same", "channels_first"), + ) + @pytest.mark.skipif( + backend.backend() == "torch", + reason="Same padding in Torch backend produces different results.", + ) + def test_average_pooling3d_same_padding( + self, pool_size, strides, padding, data_format + ): + inputs = np.arange(240, dtype="float32").reshape((2, 3, 4, 5, 2)) + + layer = layers.AveragePooling3D( + pool_size=pool_size, + strides=strides, + padding=padding, + data_format=data_format, + ) + outputs = layer(inputs) + expected = np_avgpool3d( + inputs, pool_size, strides, padding, data_format + ) + self.assertAllClose(outputs, expected) diff --git a/keras/src/layers/pooling/base_global_pooling.py b/keras/src/layers/pooling/base_global_pooling.py new file mode 100644 index 000000000000..95e9ddca550f --- /dev/null +++ b/keras/src/layers/pooling/base_global_pooling.py @@ -0,0 +1,50 @@ +from keras.src import backend +from keras.src.layers.input_spec import InputSpec +from keras.src.layers.layer import Layer + + +class BaseGlobalPooling(Layer): + """Base global pooling layer.""" + + def __init__( + self, pool_dimensions, data_format=None, keepdims=False, **kwargs + ): + super().__init__(**kwargs) + + self.data_format = backend.standardize_data_format(data_format) + self.keepdims = keepdims + self.input_spec = InputSpec(ndim=pool_dimensions + 2) + + self._build_at_init() + + def call(self, inputs): + raise NotImplementedError + + def compute_output_shape(self, input_shape): + num_spatial_dims = len(input_shape) - 2 + if self.data_format == "channels_last": + if self.keepdims: + return ( + (input_shape[0],) + + (1,) * num_spatial_dims + + (input_shape[-1],) + ) + else: + return (input_shape[0],) + (input_shape[-1],) + else: + if self.keepdims: + return (input_shape[0], input_shape[1]) + ( + 1, + ) * num_spatial_dims + else: + return (input_shape[0], input_shape[1]) + + def get_config(self): + config = super().get_config() + config.update( + { + "data_format": self.data_format, + "keepdims": self.keepdims, + } + ) + return config diff --git a/keras/src/layers/pooling/base_pooling.py b/keras/src/layers/pooling/base_pooling.py new file mode 100644 index 000000000000..b427f86ac82a --- /dev/null +++ b/keras/src/layers/pooling/base_pooling.py @@ -0,0 +1,82 @@ +from keras.src import backend +from keras.src import ops +from keras.src.layers.input_spec import InputSpec +from keras.src.layers.layer import Layer +from keras.src.ops.operation_utils import compute_pooling_output_shape +from keras.src.utils import argument_validation + + +class BasePooling(Layer): + """Base pooling layer.""" + + def __init__( + self, + pool_size, + strides, + pool_dimensions, + pool_mode="max", + padding="valid", + data_format=None, + name=None, + **kwargs, + ): + super().__init__(name=name, **kwargs) + + self.pool_size = argument_validation.standardize_tuple( + pool_size, pool_dimensions, "pool_size" + ) + strides = pool_size if strides is None else strides + self.strides = argument_validation.standardize_tuple( + strides, pool_dimensions, "strides", allow_zero=True + ) + self.pool_mode = pool_mode + self.padding = padding + self.data_format = backend.standardize_data_format(data_format) + + self.input_spec = InputSpec(ndim=pool_dimensions + 2) + + self._build_at_init() + + def call(self, inputs): + if self.pool_mode == "max": + return ops.max_pool( + inputs, + pool_size=self.pool_size, + strides=self.strides, + padding=self.padding, + data_format=self.data_format, + ) + elif self.pool_mode == "average": + return ops.average_pool( + inputs, + pool_size=self.pool_size, + strides=self.strides, + padding=self.padding, + data_format=self.data_format, + ) + else: + raise ValueError( + "`pool_mode` must be either 'max' or 'average'. Received: " + f"{self.pool_mode}." + ) + + def compute_output_shape(self, input_shape): + return compute_pooling_output_shape( + input_shape, + self.pool_size, + self.strides, + self.padding, + self.data_format, + ) + + def get_config(self): + config = super().get_config() + config.update( + { + "pool_size": self.pool_size, + "padding": self.padding, + "strides": self.strides, + "data_format": self.data_format, + } + ) + return config diff --git a/keras/src/layers/pooling/global_average_pooling1d.py b/keras/src/layers/pooling/global_average_pooling1d.py new file mode 100644 index 000000000000..6db5fb923c8c --- /dev/null +++ b/keras/src/layers/pooling/global_average_pooling1d.py @@ -0,0 +1,86 @@ +from keras.src import backend +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.layers.pooling.base_global_pooling import BaseGlobalPooling + + +@keras_export( + [ + "keras.layers.GlobalAveragePooling1D", + "keras.layers.GlobalAvgPool1D", + ] +) +class GlobalAveragePooling1D(BaseGlobalPooling): + """Global average pooling operation for temporal data. + + Args: + data_format: string, either `"channels_last"` or `"channels_first"`. + The ordering of the dimensions in the inputs. `"channels_last"` + corresponds to inputs with shape `(batch, steps, features)` + while `"channels_first"` corresponds to inputs with shape + `(batch, features, steps)`. It defaults to the `image_data_format` + value found in your Keras config file at `~/.keras/keras.json`. + If you never set it, then it will be `"channels_last"`. + keepdims: A boolean, whether to keep the temporal dimension or not. + If `keepdims` is `False` (default), the rank of the tensor is + reduced for spatial dimensions. If `keepdims` is `True`, the + temporal dimension are retained with length 1. + The behavior is the same as for `tf.reduce_mean` or `np.mean`. + + Call arguments: + inputs: A 3D tensor. + mask: Binary tensor of shape `(batch_size, steps)` indicating whether + a given step should be masked (excluded from the average). + + Input shape: + + - If `data_format='channels_last'`: + 3D tensor with shape: + `(batch_size, steps, features)` + - If `data_format='channels_first'`: + 3D tensor with shape: + `(batch_size, features, steps)` + + Output shape: + + - If `keepdims=False`: + 2D tensor with shape `(batch_size, features)`. + - If `keepdims=True`: + - If `data_format="channels_last"`: + 3D tensor with shape `(batch_size, 1, features)` + - If `data_format="channels_first"`: + 3D tensor with shape `(batch_size, features, 1)` + + Example: + + >>> x = np.random.rand(2, 3, 4) + >>> y = keras.layers.GlobalAveragePooling1D()(x) + >>> y.shape + (2, 4) + """ + + def __init__(self, data_format=None, keepdims=False, **kwargs): + super().__init__( + pool_dimensions=1, + data_format=data_format, + keepdims=keepdims, + **kwargs, + ) + self.supports_masking = True + + def call(self, inputs, mask=None): + steps_axis = 1 if self.data_format == "channels_last" else 2 + if mask is not None: + mask = backend.cast(mask, inputs[0].dtype) + mask = ops.expand_dims( + mask, 2 if self.data_format == "channels_last" else 1 + ) + inputs *= mask + return ops.sum( + inputs, axis=steps_axis, keepdims=self.keepdims + ) / ops.sum(mask, axis=steps_axis, keepdims=self.keepdims) + else: + return ops.mean(inputs, axis=steps_axis, keepdims=self.keepdims) + + def compute_mask(self, inputs, mask=None): + return None diff --git a/keras/src/layers/pooling/global_average_pooling2d.py b/keras/src/layers/pooling/global_average_pooling2d.py new file mode 100644 index 000000000000..1536c3c302e8 --- /dev/null +++ b/keras/src/layers/pooling/global_average_pooling2d.py @@ -0,0 +1,68 @@ +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.layers.pooling.base_global_pooling import BaseGlobalPooling + + +@keras_export( + [ + "keras.layers.GlobalAveragePooling2D", + "keras.layers.GlobalAvgPool2D", + ] +) +class GlobalAveragePooling2D(BaseGlobalPooling): + """Global average pooling operation for 2D data. + + Args: + data_format: string, either `"channels_last"` or `"channels_first"`. + The ordering of the dimensions in the inputs. `"channels_last"` + corresponds to inputs with shape `(batch, height, width, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch, features, height, weight)`. It defaults to the + `image_data_format` value found in your Keras config file at + `~/.keras/keras.json`. If you never set it, then it will be + `"channels_last"`. + keepdims: A boolean, whether to keep the temporal dimension or not. + If `keepdims` is `False` (default), the rank of the tensor is + reduced for spatial dimensions. If `keepdims` is `True`, the + spatial dimension are retained with length 1. + The behavior is the same as for `tf.reduce_mean` or `np.mean`. + + Input shape: + + - If `data_format='channels_last'`: + 4D tensor with shape: + `(batch_size, height, width, channels)` + - If `data_format='channels_first'`: + 4D tensor with shape: + `(batch_size, channels, height, width)` + + Output shape: + + - If `keepdims=False`: + 2D tensor with shape `(batch_size, channels)`. + - If `keepdims=True`: + - If `data_format="channels_last"`: + 4D tensor with shape `(batch_size, 1, 1, channels)` + - If `data_format="channels_first"`: + 4D tensor with shape `(batch_size, channels, 1, 1)` + + Example: + + >>> x = np.random.rand(2, 4, 5, 3) + >>> y = keras.layers.GlobalAveragePooling2D()(x) + >>> y.shape + (2, 3) + """ + + def __init__(self, data_format=None, keepdims=False, **kwargs): + super().__init__( + pool_dimensions=2, + data_format=data_format, + keepdims=keepdims, + **kwargs, + ) + + def call(self, inputs): + if self.data_format == "channels_last": + return ops.mean(inputs, axis=[1, 2], keepdims=self.keepdims) + return ops.mean(inputs, axis=[2, 3], keepdims=self.keepdims) diff --git a/keras/src/layers/pooling/global_average_pooling3d.py b/keras/src/layers/pooling/global_average_pooling3d.py new file mode 100644 index 000000000000..14ffc5bfc4d0 --- /dev/null +++ b/keras/src/layers/pooling/global_average_pooling3d.py @@ -0,0 +1,69 @@ +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.layers.pooling.base_global_pooling import BaseGlobalPooling + + +@keras_export( + [ + "keras.layers.GlobalAveragePooling3D", + "keras.layers.GlobalAvgPool3D", + ] +) +class GlobalAveragePooling3D(BaseGlobalPooling): + """Global average pooling operation for 3D data. + + Args: + data_format: string, either `"channels_last"` or `"channels_first"`. + The ordering of the dimensions in the inputs. `"channels_last"` + corresponds to inputs with shape + `(batch, spatial_dim1, spatial_dim2, spatial_dim3, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch, channels, spatial_dim1, spatial_dim2, spatial_dim3)`. + It defaults to the `image_data_format` value found in your Keras + config file at `~/.keras/keras.json`. If you never set it, then it + will be `"channels_last"`. + keepdims: A boolean, whether to keep the temporal dimension or not. + If `keepdims` is `False` (default), the rank of the tensor is + reduced for spatial dimensions. If `keepdims` is `True`, the + spatial dimension are retained with length 1. + The behavior is the same as for `tf.reduce_mean` or `np.mean`. + + Input shape: + + - If `data_format='channels_last'`: + 5D tensor with shape: + `(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)` + - If `data_format='channels_first'`: + 5D tensor with shape: + `(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)` + + Output shape: + + - If `keepdims=False`: + 2D tensor with shape `(batch_size, channels)`. + - If `keepdims=True`: + - If `data_format="channels_last"`: + 5D tensor with shape `(batch_size, 1, 1, 1, channels)` + - If `data_format="channels_first"`: + 5D tensor with shape `(batch_size, channels, 1, 1, 1)` + + Example: + + >>> x = np.random.rand(2, 4, 5, 4, 3) + >>> y = keras.layers.GlobalAveragePooling3D()(x) + >>> y.shape + (2, 3) + """ + + def __init__(self, data_format=None, keepdims=False, **kwargs): + super().__init__( + pool_dimensions=3, + data_format=data_format, + keepdims=keepdims, + **kwargs, + ) + + def call(self, inputs): + if self.data_format == "channels_last": + return ops.mean(inputs, axis=[1, 2, 3], keepdims=self.keepdims) + return ops.mean(inputs, axis=[2, 3, 4], keepdims=self.keepdims) diff --git a/keras/src/layers/pooling/global_average_pooling_test.py b/keras/src/layers/pooling/global_average_pooling_test.py new file mode 100644 index 000000000000..77b2359b67fa --- /dev/null +++ b/keras/src/layers/pooling/global_average_pooling_test.py @@ -0,0 +1,178 @@ +import numpy as np +import pytest +from absl.testing import parameterized + +from keras.src import layers +from keras.src import testing + + +@pytest.mark.requires_trainable_backend +class GlobalAveragePoolingBasicTest(testing.TestCase): + @parameterized.parameters( + ("channels_last", False, (3, 5, 4), (3, 4)), + ("channels_last", True, (3, 5, 4), (3, 1, 4)), + ("channels_first", False, (3, 5, 4), (3, 5)), + ) + def test_global_average_pooling1d( + self, + data_format, + keepdims, + input_shape, + output_shape, + ): + self.run_layer_test( + layers.GlobalAveragePooling1D, + init_kwargs={ + "data_format": data_format, + "keepdims": keepdims, + }, + input_shape=input_shape, + expected_output_shape=output_shape, + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_losses=0, + supports_masking=True, + assert_built_after_instantiation=True, + ) + + @parameterized.parameters( + ("channels_last", False, (3, 5, 6, 4), (3, 4)), + ("channels_last", True, (3, 5, 6, 4), (3, 1, 1, 4)), + ("channels_first", False, (3, 5, 6, 4), (3, 5)), + ) + def test_global_average_pooling2d( + self, + data_format, + keepdims, + input_shape, + output_shape, + ): + self.run_layer_test( + layers.GlobalAveragePooling2D, + init_kwargs={ + "data_format": data_format, + "keepdims": keepdims, + }, + input_shape=input_shape, + expected_output_shape=output_shape, + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_losses=0, + supports_masking=False, + assert_built_after_instantiation=True, + ) + + @parameterized.parameters( + ("channels_last", False, (3, 5, 6, 5, 4), (3, 4)), + ("channels_last", True, (3, 5, 6, 5, 4), (3, 1, 1, 1, 4)), + ("channels_first", False, (3, 5, 6, 5, 4), (3, 5)), + ) + def test_global_average_pooling3d( + self, + data_format, + keepdims, + input_shape, + output_shape, + ): + self.run_layer_test( + layers.GlobalAveragePooling3D, + init_kwargs={ + "data_format": data_format, + "keepdims": keepdims, + }, + input_shape=input_shape, + expected_output_shape=output_shape, + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_losses=0, + supports_masking=False, + assert_built_after_instantiation=True, + ) + + +class GlobalAveragePoolingCorrectnessTest(testing.TestCase): + @parameterized.parameters( + ("channels_last", False), + ("channels_last", True), + ("channels_first", False), + ("channels_first", True), + ) + def test_global_average_pooling1d(self, data_format, keepdims): + def np_gap1d(x, data_format, keepdims, mask=None): + steps_axis = 1 if data_format == "channels_last" else 2 + if mask is not None: + mask = np.expand_dims( + mask, 2 if data_format == "channels_last" else 1 + ) + x *= mask + res = np.sum(x, axis=steps_axis) / np.sum(mask, axis=steps_axis) + else: + res = np.mean(x, axis=steps_axis) + if keepdims: + res = np.expand_dims(res, axis=steps_axis) + return res + + inputs = np.arange(24, dtype="float32").reshape((2, 3, 4)) + layer = layers.GlobalAveragePooling1D( + data_format=data_format, + keepdims=keepdims, + ) + outputs = layer(inputs) + expected = np_gap1d(inputs, data_format, keepdims) + self.assertAllClose(outputs, expected) + + if data_format == "channels_last": + mask = np.array([[1, 1, 0], [0, 1, 0]], dtype="int32") + else: + mask = np.array([[1, 1, 0, 0], [0, 1, 0, 1]], dtype="int32") + outputs = layer(inputs, mask) + expected = np_gap1d(inputs, data_format, keepdims, mask) + self.assertAllClose(outputs, expected) + + @parameterized.parameters( + ("channels_last", False), + ("channels_last", True), + ("channels_first", False), + ("channels_first", True), + ) + def test_global_average_pooling2d(self, data_format, keepdims): + def np_gap2d(x, data_format, keepdims): + steps_axis = [1, 2] if data_format == "channels_last" else [2, 3] + res = np.apply_over_axes(np.mean, x, steps_axis) + if not keepdims: + res = res.squeeze() + return res + + inputs = np.arange(96, dtype="float32").reshape((2, 3, 4, 4)) + layer = layers.GlobalAveragePooling2D( + data_format=data_format, + keepdims=keepdims, + ) + outputs = layer(inputs) + expected = np_gap2d(inputs, data_format, keepdims) + self.assertAllClose(outputs, expected) + + @parameterized.parameters( + ("channels_last", False), + ("channels_last", True), + ("channels_first", False), + ("channels_first", True), + ) + def test_global_average_pooling3d(self, data_format, keepdims): + def np_gap3d(x, data_format, keepdims): + steps_axis = ( + [1, 2, 3] if data_format == "channels_last" else [2, 3, 4] + ) + res = np.apply_over_axes(np.mean, x, steps_axis) + if not keepdims: + res = res.squeeze() + return res + + inputs = np.arange(360, dtype="float32").reshape((2, 3, 3, 5, 4)) + layer = layers.GlobalAveragePooling3D( + data_format=data_format, + keepdims=keepdims, + ) + outputs = layer(inputs) + expected = np_gap3d(inputs, data_format, keepdims) + self.assertAllClose(outputs, expected) diff --git a/keras/src/layers/pooling/global_max_pooling1d.py b/keras/src/layers/pooling/global_max_pooling1d.py new file mode 100644 index 000000000000..7c6d9ff79692 --- /dev/null +++ b/keras/src/layers/pooling/global_max_pooling1d.py @@ -0,0 +1,66 @@ +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.layers.pooling.base_global_pooling import BaseGlobalPooling + + +@keras_export( + [ + "keras.layers.GlobalMaxPooling1D", + "keras.layers.GlobalMaxPool1D", + ] +) +class GlobalMaxPooling1D(BaseGlobalPooling): + """Global max pooling operation for temporal data. + + Args: + data_format: string, either `"channels_last"` or `"channels_first"`. + The ordering of the dimensions in the inputs. `"channels_last"` + corresponds to inputs with shape `(batch, steps, features)` + while `"channels_first"` corresponds to inputs with shape + `(batch, features, steps)`. It defaults to the `image_data_format` + value found in your Keras config file at `~/.keras/keras.json`. + If you never set it, then it will be `"channels_last"`. + keepdims: A boolean, whether to keep the temporal dimension or not. + If `keepdims` is `False` (default), the rank of the tensor is + reduced for spatial dimensions. If `keepdims` is `True`, the + temporal dimension are retained with length 1. + The behavior is the same as for `tf.reduce_mean` or `np.mean`. + + Input shape: + + - If `data_format='channels_last'`: + 3D tensor with shape: + `(batch_size, steps, features)` + - If `data_format='channels_first'`: + 3D tensor with shape: + `(batch_size, features, steps)` + + Output shape: + + - If `keepdims=False`: + 2D tensor with shape `(batch_size, features)`. + - If `keepdims=True`: + - If `data_format="channels_last"`: + 3D tensor with shape `(batch_size, 1, features)` + - If `data_format="channels_first"`: + 3D tensor with shape `(batch_size, features, 1)` + + Example: + + >>> x = np.random.rand(2, 3, 4) + >>> y = keras.layers.GlobalMaxPooling1D()(x) + >>> y.shape + (2, 4) + """ + + def __init__(self, data_format=None, keepdims=False, **kwargs): + super().__init__( + pool_dimensions=1, + data_format=data_format, + keepdims=keepdims, + **kwargs, + ) + + def call(self, inputs): + steps_axis = 1 if self.data_format == "channels_last" else 2 + return ops.max(inputs, axis=steps_axis, keepdims=self.keepdims) diff --git a/keras/src/layers/pooling/global_max_pooling2d.py b/keras/src/layers/pooling/global_max_pooling2d.py new file mode 100644 index 000000000000..289ebe0a87d6 --- /dev/null +++ b/keras/src/layers/pooling/global_max_pooling2d.py @@ -0,0 +1,68 @@ +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.layers.pooling.base_global_pooling import BaseGlobalPooling + + +@keras_export( + [ + "keras.layers.GlobalMaxPooling2D", + "keras.layers.GlobalMaxPool2D", + ] +) +class GlobalMaxPooling2D(BaseGlobalPooling): + """Global max pooling operation for 2D data. + + Args: + data_format: string, either `"channels_last"` or `"channels_first"`. + The ordering of the dimensions in the inputs. `"channels_last"` + corresponds to inputs with shape `(batch, height, width, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch, features, height, weight)`. It defaults to the + `image_data_format` value found in your Keras config file at + `~/.keras/keras.json`. If you never set it, then it will be + `"channels_last"`. + keepdims: A boolean, whether to keep the temporal dimension or not. + If `keepdims` is `False` (default), the rank of the tensor is + reduced for spatial dimensions. If `keepdims` is `True`, the + spatial dimension are retained with length 1. + The behavior is the same as for `tf.reduce_mean` or `np.mean`. + + Input shape: + + - If `data_format='channels_last'`: + 4D tensor with shape: + `(batch_size, height, width, channels)` + - If `data_format='channels_first'`: + 4D tensor with shape: + `(batch_size, channels, height, width)` + + Output shape: + + - If `keepdims=False`: + 2D tensor with shape `(batch_size, channels)`. + - If `keepdims=True`: + - If `data_format="channels_last"`: + 4D tensor with shape `(batch_size, 1, 1, channels)` + - If `data_format="channels_first"`: + 4D tensor with shape `(batch_size, channels, 1, 1)` + + Example: + + >>> x = np.random.rand(2, 4, 5, 3) + >>> y = keras.layers.GlobalMaxPooling2D()(x) + >>> y.shape + (2, 3) + """ + + def __init__(self, data_format=None, keepdims=False, **kwargs): + super().__init__( + pool_dimensions=2, + data_format=data_format, + keepdims=keepdims, + **kwargs, + ) + + def call(self, inputs): + if self.data_format == "channels_last": + return ops.max(inputs, axis=[1, 2], keepdims=self.keepdims) + return ops.max(inputs, axis=[2, 3], keepdims=self.keepdims) diff --git a/keras/src/layers/pooling/global_max_pooling3d.py b/keras/src/layers/pooling/global_max_pooling3d.py new file mode 100644 index 000000000000..07e1eb065bc7 --- /dev/null +++ b/keras/src/layers/pooling/global_max_pooling3d.py @@ -0,0 +1,69 @@ +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.layers.pooling.base_global_pooling import BaseGlobalPooling + + +@keras_export( + [ + "keras.layers.GlobalMaxPooling3D", + "keras.layers.GlobalMaxPool3D", + ] +) +class GlobalMaxPooling3D(BaseGlobalPooling): + """Global max pooling operation for 3D data. + + Args: + data_format: string, either `"channels_last"` or `"channels_first"`. + The ordering of the dimensions in the inputs. `"channels_last"` + corresponds to inputs with shape + `(batch, spatial_dim1, spatial_dim2, spatial_dim3, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch, channels, spatial_dim1, spatial_dim2, spatial_dim3)`. + It defaults to the `image_data_format` value found in your Keras + config file at `~/.keras/keras.json`. If you never set it, then it + will be `"channels_last"`. + keepdims: A boolean, whether to keep the temporal dimension or not. + If `keepdims` is `False` (default), the rank of the tensor is + reduced for spatial dimensions. If `keepdims` is `True`, the + spatial dimension are retained with length 1. + The behavior is the same as for `tf.reduce_mean` or `np.mean`. + + Input shape: + + - If `data_format='channels_last'`: + 5D tensor with shape: + `(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)` + - If `data_format='channels_first'`: + 5D tensor with shape: + `(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)` + + Output shape: + + - If `keepdims=False`: + 2D tensor with shape `(batch_size, channels)`. + - If `keepdims=True`: + - If `data_format="channels_last"`: + 5D tensor with shape `(batch_size, 1, 1, 1, channels)` + - If `data_format="channels_first"`: + 5D tensor with shape `(batch_size, channels, 1, 1, 1)` + + Example: + + >>> x = np.random.rand(2, 4, 5, 4, 3) + >>> y = keras.layers.GlobalMaxPooling3D()(x) + >>> y.shape + (2, 3) + """ + + def __init__(self, data_format=None, keepdims=False, **kwargs): + super().__init__( + pool_dimensions=3, + data_format=data_format, + keepdims=keepdims, + **kwargs, + ) + + def call(self, inputs): + if self.data_format == "channels_last": + return ops.max(inputs, axis=[1, 2, 3], keepdims=self.keepdims) + return ops.max(inputs, axis=[2, 3, 4], keepdims=self.keepdims) diff --git a/keras/src/layers/pooling/global_max_pooling_test.py b/keras/src/layers/pooling/global_max_pooling_test.py new file mode 100644 index 000000000000..f88fac43140e --- /dev/null +++ b/keras/src/layers/pooling/global_max_pooling_test.py @@ -0,0 +1,163 @@ +import numpy as np +import pytest +from absl.testing import parameterized + +from keras.src import layers +from keras.src import testing + + +@pytest.mark.requires_trainable_backend +class GlobalMaxPoolingBasicTest(testing.TestCase): + @parameterized.parameters( + ("channels_last", False, (3, 5, 4), (3, 4)), + ("channels_last", True, (3, 5, 4), (3, 1, 4)), + ("channels_first", False, (3, 5, 4), (3, 5)), + ) + def test_global_max_pooling1d( + self, + data_format, + keepdims, + input_shape, + output_shape, + ): + self.run_layer_test( + layers.GlobalMaxPooling1D, + init_kwargs={ + "data_format": data_format, + "keepdims": keepdims, + }, + input_shape=input_shape, + expected_output_shape=output_shape, + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_losses=0, + supports_masking=False, + assert_built_after_instantiation=True, + ) + + @parameterized.parameters( + ("channels_last", False, (3, 5, 6, 4), (3, 4)), + ("channels_last", True, (3, 5, 6, 4), (3, 1, 1, 4)), + ("channels_first", False, (3, 5, 6, 4), (3, 5)), + ) + def test_global_max_pooling2d( + self, + data_format, + keepdims, + input_shape, + output_shape, + ): + self.run_layer_test( + layers.GlobalMaxPooling2D, + init_kwargs={ + "data_format": data_format, + "keepdims": keepdims, + }, + input_shape=input_shape, + expected_output_shape=output_shape, + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_losses=0, + supports_masking=False, + assert_built_after_instantiation=True, + ) + + @parameterized.parameters( + ("channels_last", False, (3, 5, 6, 5, 4), (3, 4)), + ("channels_last", True, (3, 5, 6, 5, 4), (3, 1, 1, 1, 4)), + ("channels_first", False, (3, 5, 6, 5, 4), (3, 5)), + ) + def test_global_max_pooling3d( + self, + data_format, + keepdims, + input_shape, + output_shape, + ): + self.run_layer_test( + layers.GlobalMaxPooling3D, + init_kwargs={ + "data_format": data_format, + "keepdims": keepdims, + }, + input_shape=input_shape, + expected_output_shape=output_shape, + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_losses=0, + supports_masking=False, + assert_built_after_instantiation=True, + ) + + +class GlobalMaxPoolingCorrectnessTest(testing.TestCase): + @parameterized.parameters( + ("channels_last", False), + ("channels_last", True), + ("channels_first", False), + ("channels_first", True), + ) + def test_global_max_pooling1d(self, data_format, keepdims): + def np_global_max_pool1d(x, data_format, keepdims): + steps_axis = [1] if data_format == "channels_last" else [2] + res = np.apply_over_axes(np.max, x, steps_axis) + if not keepdims: + res = res.squeeze() + return res + + inputs = np.arange(24, dtype="float32").reshape((2, 3, 4)) + layer = layers.GlobalMaxPooling1D( + data_format=data_format, + keepdims=keepdims, + ) + outputs = layer(inputs) + expected = np_global_max_pool1d(inputs, data_format, keepdims) + self.assertAllClose(outputs, expected) + + @parameterized.parameters( + ("channels_last", False), + ("channels_last", True), + ("channels_first", False), + ("channels_first", True), + ) + def test_global_max_pooling2d(self, data_format, keepdims): + def np_global_max_pool2d(x, data_format, keepdims): + steps_axis = [1, 2] if data_format == "channels_last" else [2, 3] + res = np.apply_over_axes(np.max, x, steps_axis) + if not keepdims: + res = res.squeeze() + return res + + inputs = np.arange(96, dtype="float32").reshape((2, 3, 4, 4)) + layer = layers.GlobalMaxPooling2D( + data_format=data_format, + keepdims=keepdims, + ) + outputs = layer(inputs) + expected = np_global_max_pool2d(inputs, data_format, keepdims) + self.assertAllClose(outputs, expected) + + @parameterized.parameters( + ("channels_last", False), + ("channels_last", True), + ("channels_first", False), + ("channels_first", True), + ) + def test_global_max_pooling3d(self, data_format, keepdims): + def np_global_max_pool3d(x, data_format, keepdims): + steps_axis = ( + [1, 2, 3] if data_format == "channels_last" else [2, 3, 4] + ) + res = np.apply_over_axes(np.max, x, steps_axis) + if not keepdims: + res = res.squeeze() + return res + + inputs = np.arange(360, dtype="float32").reshape((2, 3, 3, 5, 4)) + layer = layers.GlobalMaxPooling3D( + data_format=data_format, + keepdims=keepdims, + ) + outputs = layer(inputs) + expected = np_global_max_pool3d(inputs, data_format, keepdims) + self.assertAllClose(outputs, expected) diff --git a/keras/src/layers/pooling/max_pooling1d.py b/keras/src/layers/pooling/max_pooling1d.py new file mode 100644 index 000000000000..c6c35d105f8f --- /dev/null +++ b/keras/src/layers/pooling/max_pooling1d.py @@ -0,0 +1,93 @@ +from keras.src.api_export import keras_export +from keras.src.layers.pooling.base_pooling import BasePooling + + +@keras_export(["keras.layers.MaxPooling1D", "keras.layers.MaxPool1D"]) +class MaxPooling1D(BasePooling): + """Max pooling operation for 1D temporal data. + + Downsamples the input representation by taking the maximum value over a + spatial window of size `pool_size`. The window is shifted by `strides`. + + The resulting output when using the `"valid"` padding option has a shape of: + `output_shape = (input_shape - pool_size + 1) / strides)`. + + The resulting output shape when using the `"same"` padding option is: + `output_shape = input_shape / strides` + + Args: + pool_size: int, size of the max pooling window. + strides: int or None. Specifies how much the pooling window moves + for each pooling step. If None, it will default to `pool_size`. + padding: string, either `"valid"` or `"same"` (case-insensitive). + `"valid"` means no padding. `"same"` results in padding evenly to + the left/right or up/down of the input such that output has the same + height/width dimension as the input. + data_format: string, either `"channels_last"` or `"channels_first"`. + The ordering of the dimensions in the inputs. `"channels_last"` + corresponds to inputs with shape `(batch, steps, features)` + while `"channels_first"` corresponds to inputs with shape + `(batch, features, steps)`. It defaults to the `image_data_format` + value found in your Keras config file at `~/.keras/keras.json`. + If you never set it, then it will be `"channels_last"`. + + Input shape: + + - If `data_format="channels_last"`: + 3D tensor with shape `(batch_size, steps, features)`. + - If `data_format="channels_first"`: + 3D tensor with shape `(batch_size, features, steps)`. + + Output shape: + + - If `data_format="channels_last"`: + 3D tensor with shape `(batch_size, downsampled_steps, features)`. + - If `data_format="channels_first"`: + 3D tensor with shape `(batch_size, features, downsampled_steps)`. + + Examples: + + `strides=1` and `padding="valid"`: + + >>> x = np.array([1., 2., 3., 4., 5.]) + >>> x = np.reshape(x, [1, 5, 1]) + >>> max_pool_1d = keras.layers.MaxPooling1D(pool_size=2, + ... strides=1, padding="valid") + >>> max_pool_1d(x) + + `strides=2` and `padding="valid"`: + + >>> x = np.array([1., 2., 3., 4., 5.]) + >>> x = np.reshape(x, [1, 5, 1]) + >>> max_pool_1d = keras.layers.MaxPooling1D(pool_size=2, + ... strides=2, padding="valid") + >>> max_pool_1d(x) + + `strides=1` and `padding="same"`: + + >>> x = np.array([1., 2., 3., 4., 5.]) + >>> x = np.reshape(x, [1, 5, 1]) + >>> max_pool_1d = keras.layers.MaxPooling1D(pool_size=2, + ... strides=1, padding="same") + >>> max_pool_1d(x) + """ + + def __init__( + self, + pool_size=2, + strides=None, + padding="valid", + data_format=None, + name=None, + **kwargs, + ): + super().__init__( + pool_size, + strides, + pool_dimensions=1, + pool_mode="max", + padding=padding, + data_format=data_format, + name=name, + **kwargs, + ) diff --git a/keras/src/layers/pooling/max_pooling2d.py b/keras/src/layers/pooling/max_pooling2d.py new file mode 100644 index 000000000000..237da0670ab1 --- /dev/null +++ b/keras/src/layers/pooling/max_pooling2d.py @@ -0,0 +1,109 @@ +from keras.src.api_export import keras_export +from keras.src.layers.pooling.base_pooling import BasePooling + + +@keras_export(["keras.layers.MaxPooling2D", "keras.layers.MaxPool2D"]) +class MaxPooling2D(BasePooling): + """Max pooling operation for 2D spatial data. + + Downsamples the input along its spatial dimensions (height and width) + by taking the maximum value over an input window + (of size defined by `pool_size`) for each channel of the input. + The window is shifted by `strides` along each dimension. + + The resulting output when using the `"valid"` padding option has a spatial + shape (number of rows or columns) of: + `output_shape = math.floor((input_shape - pool_size) / strides) + 1` + (when `input_shape >= pool_size`) + + The resulting output shape when using the `"same"` padding option is: + `output_shape = math.floor((input_shape - 1) / strides) + 1` + + Args: + pool_size: int or tuple of 2 integers, factors by which to downscale + (dim1, dim2). If only one integer is specified, the same + window length will be used for all dimensions. + strides: int or tuple of 2 integers, or None. Strides values. If None, + it will default to `pool_size`. If only one int is specified, the + same stride size will be used for all dimensions. + padding: string, either `"valid"` or `"same"` (case-insensitive). + `"valid"` means no padding. `"same"` results in padding evenly to + the left/right or up/down of the input such that output has the same + height/width dimension as the input. + data_format: string, either `"channels_last"` or `"channels_first"`. + The ordering of the dimensions in the inputs. `"channels_last"` + corresponds to inputs with shape `(batch, height, width, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch, channels, height, width)`. It defaults to the + `image_data_format` value found in your Keras config file at + `~/.keras/keras.json`. If you never set it, then it will be + `"channels_last"`. + + Input shape: + + - If `data_format="channels_last"`: + 4D tensor with shape `(batch_size, height, width, channels)`. + - If `data_format="channels_first"`: + 4D tensor with shape `(batch_size, channels, height, width)`. + + Output shape: + + - If `data_format="channels_last"`: + 4D tensor with shape + `(batch_size, pooled_height, pooled_width, channels)`. + - If `data_format="channels_first"`: + 4D tensor with shape + `(batch_size, channels, pooled_height, pooled_width)`. + + Examples: + + `strides=(1, 1)` and `padding="valid"`: + + >>> x = np.array([[1., 2., 3.], + ... [4., 5., 6.], + ... [7., 8., 9.]]) + >>> x = np.reshape(x, [1, 3, 3, 1]) + >>> max_pool_2d = keras.layers.MaxPooling2D(pool_size=(2, 2), + ... strides=(1, 1), padding="valid") + >>> max_pool_2d(x) + + `strides=(2, 2)` and `padding="valid"`: + + >>> x = np.array([[1., 2., 3., 4.], + ... [5., 6., 7., 8.], + ... [9., 10., 11., 12.]]) + >>> x = np.reshape(x, [1, 3, 4, 1]) + >>> max_pool_2d = keras.layers.MaxPooling2D(pool_size=(2, 2), + ... strides=(2, 2), padding="valid") + >>> max_pool_2d(x) + + `stride=(1, 1)` and `padding="same"`: + + >>> x = np.array([[1., 2., 3.], + ... [4., 5., 6.], + ... [7., 8., 9.]]) + >>> x = np.reshape(x, [1, 3, 3, 1]) + >>> max_pool_2d = keras.layers.MaxPooling2D(pool_size=(2, 2), + ... strides=(1, 1), padding="same") + >>> max_pool_2d(x) + """ + + def __init__( + self, + pool_size=(2, 2), + strides=None, + padding="valid", + data_format=None, + name=None, + **kwargs, + ): + super().__init__( + pool_size, + strides, + pool_dimensions=2, + pool_mode="max", + padding=padding, + data_format=data_format, + name=name, + **kwargs, + ) diff --git a/keras/src/layers/pooling/max_pooling3d.py b/keras/src/layers/pooling/max_pooling3d.py new file mode 100644 index 000000000000..d6487e87f321 --- /dev/null +++ b/keras/src/layers/pooling/max_pooling3d.py @@ -0,0 +1,85 @@ +from keras.src.api_export import keras_export +from keras.src.layers.pooling.base_pooling import BasePooling + + +@keras_export(["keras.layers.MaxPooling3D", "keras.layers.MaxPool3D"]) +class MaxPooling3D(BasePooling): + """Max pooling operation for 3D data (spatial or spatio-temporal). + + Downsamples the input along its spatial dimensions (depth, height, and + width) by taking the maximum value over an input window (of size defined by + `pool_size`) for each channel of the input. The window is shifted by + `strides` along each dimension. + + Args: + pool_size: int or tuple of 3 integers, factors by which to downscale + (dim1, dim2, dim3). If only one integer is specified, the same + window length will be used for all dimensions. + strides: int or tuple of 3 integers, or None. Strides values. If None, + it will default to `pool_size`. If only one int is specified, the + same stride size will be used for all dimensions. + padding: string, either `"valid"` or `"same"` (case-insensitive). + `"valid"` means no padding. `"same"` results in padding evenly to + the left/right or up/down of the input such that output has the same + height/width dimension as the input. + data_format: string, either `"channels_last"` or `"channels_first"`. + The ordering of the dimensions in the inputs. `"channels_last"` + corresponds to inputs with shape + `(batch, spatial_dim1, spatial_dim2, spatial_dim3, channels)` while + `"channels_first"` corresponds to inputs with shape + `(batch, channels, spatial_dim1, spatial_dim2, spatial_dim3)`. + It defaults to the `image_data_format` value found in your Keras + config file at `~/.keras/keras.json`. If you never set it, then it + will be `"channels_last"`. + + Input shape: + + - If `data_format="channels_last"`: + 5D tensor with shape: + `(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)` + - If `data_format="channels_first"`: + 5D tensor with shape: + `(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)` + + Output shape: + + - If `data_format="channels_last"`: + 5D tensor with shape: + `(batch_size, pooled_dim1, pooled_dim2, pooled_dim3, channels)` + - If `data_format="channels_first"`: + 5D tensor with shape: + `(batch_size, channels, pooled_dim1, pooled_dim2, pooled_dim3)` + + Example: + + ```python + depth = 30 + height = 30 + width = 30 + channels = 3 + + inputs = keras.layers.Input(shape=(depth, height, width, channels)) + layer = keras.layers.MaxPooling3D(pool_size=3) + outputs = layer(inputs) # Shape: (batch_size, 10, 10, 10, 3) + ``` + """ + + def __init__( + self, + pool_size=(2, 2, 2), + strides=None, + padding="valid", + data_format=None, + name=None, + **kwargs, + ): + super().__init__( + pool_size, + strides, + pool_dimensions=3, + pool_mode="max", + padding=padding, + data_format=data_format, + name=name, + **kwargs, + ) diff --git a/keras/src/layers/pooling/max_pooling_test.py b/keras/src/layers/pooling/max_pooling_test.py new file mode 100644 index 000000000000..0e8e49d84879 --- /dev/null +++ b/keras/src/layers/pooling/max_pooling_test.py @@ -0,0 +1,306 @@ +import numpy as np +import pytest +from absl.testing import parameterized +from numpy.lib.stride_tricks import as_strided + +from keras.src import layers +from keras.src import testing + + +def _same_padding(input_size, pool_size, stride): + if input_size % stride == 0: + return max(pool_size - stride, 0) + else: + return max(pool_size - (input_size % stride), 0) + + +def np_maxpool1d(x, pool_size, strides, padding, data_format): + if data_format == "channels_first": + x = x.swapaxes(1, 2) + if isinstance(pool_size, (tuple, list)): + pool_size = pool_size[0] + if isinstance(strides, (tuple, list)): + h_stride = strides[0] + else: + h_stride = strides + + if padding == "same": + n_batch, h_x, ch_x = x.shape + pad_value = _same_padding(h_x, pool_size, h_stride) + npad = [(0, 0)] * x.ndim + npad[1] = (0, pad_value) + x = np.pad(x, pad_width=npad, mode="constant", constant_values=-np.inf) + + n_batch, h_x, ch_x = x.shape + out_h = int((h_x - pool_size) / h_stride) + 1 + + stride_shape = (n_batch, out_h, ch_x, pool_size) + strides = ( + x.strides[0], + h_stride * x.strides[1], + x.strides[2], + x.strides[1], + ) + windows = as_strided(x, shape=stride_shape, strides=strides) + out = np.max(windows, axis=(3,)) + if data_format == "channels_first": + out = out.swapaxes(1, 2) + return out + + +def np_maxpool2d(x, pool_size, strides, padding, data_format): + if data_format == "channels_first": + x = x.transpose((0, 2, 3, 1)) + if isinstance(pool_size, int): + pool_size = (pool_size, pool_size) + if isinstance(strides, int): + strides = (strides, strides) + + h_pool_size, w_pool_size = pool_size + h_stride, w_stride = strides + if padding == "same": + n_batch, h_x, w_x, ch_x = x.shape + h_padding = _same_padding(h_x, h_pool_size, h_stride) + w_padding = _same_padding(w_x, w_pool_size, w_stride) + npad = [(0, 0)] * x.ndim + npad[1] = (0, h_padding) + npad[2] = (0, w_padding) + x = np.pad(x, pad_width=npad, mode="constant", constant_values=-np.inf) + + n_batch, h_x, w_x, ch_x = x.shape + out_h = int((h_x - h_pool_size) / h_stride) + 1 + out_w = int((w_x - w_pool_size) / w_stride) + 1 + + stride_shape = (n_batch, out_h, out_w, ch_x, *pool_size) + strides = ( + x.strides[0], + h_stride * x.strides[1], + w_stride * x.strides[2], + x.strides[3], + x.strides[1], + x.strides[2], + ) + windows = as_strided(x, shape=stride_shape, strides=strides) + out = np.max(windows, axis=(4, 5)) + if data_format == "channels_first": + out = out.transpose((0, 3, 1, 2)) + return out + + +def np_maxpool3d(x, pool_size, strides, padding, data_format): + if data_format == "channels_first": + x = x.transpose((0, 2, 3, 4, 1)) + + if isinstance(pool_size, int): + pool_size = (pool_size, pool_size, pool_size) + if isinstance(strides, int): + strides = (strides, strides, strides) + + h_pool_size, w_pool_size, d_pool_size = pool_size + h_stride, w_stride, d_stride = strides + + if padding == "same": + n_batch, h_x, w_x, d_x, ch_x = x.shape + h_padding = _same_padding(h_x, h_pool_size, h_stride) + w_padding = _same_padding(w_x, w_pool_size, w_stride) + d_padding = _same_padding(d_x, d_pool_size, d_stride) + npad = [(0, 0)] * x.ndim + npad[1] = (0, h_padding) + npad[2] = (0, w_padding) + npad[3] = (0, d_padding) + x = np.pad(x, pad_width=npad, mode="constant", constant_values=-np.inf) + + n_batch, h_x, w_x, d_x, ch_x = x.shape + out_h = int((h_x - h_pool_size) / h_stride) + 1 + out_w = int((w_x - w_pool_size) / w_stride) + 1 + out_d = int((d_x - d_pool_size) / d_stride) + 1 + + stride_shape = (n_batch, out_h, out_w, out_d, ch_x, *pool_size) + strides = ( + x.strides[0], + h_stride * x.strides[1], + w_stride * x.strides[2], + d_stride * x.strides[3], + x.strides[4], + x.strides[1], + x.strides[2], + x.strides[3], + ) + windows = as_strided(x, shape=stride_shape, strides=strides) + out = np.max(windows, axis=(5, 6, 7)) + if data_format == "channels_first": + out = out.transpose((0, 4, 1, 2, 3)) + return out + + +@pytest.mark.requires_trainable_backend +class MaxPoolingBasicTest(testing.TestCase): + @parameterized.parameters( + (2, 1, "valid", "channels_last", (3, 5, 4), (3, 4, 4)), + (2, 1, "same", "channels_first", (3, 5, 4), (3, 5, 4)), + ((2,), (2,), "valid", "channels_last", (3, 5, 4), (3, 2, 4)), + ) + def test_max_pooling1d( + self, + pool_size, + strides, + padding, + data_format, + input_shape, + output_shape, + ): + self.run_layer_test( + layers.MaxPooling1D, + init_kwargs={ + "pool_size": pool_size, + "strides": strides, + "padding": padding, + "data_format": data_format, + }, + input_shape=input_shape, + expected_output_shape=output_shape, + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_losses=0, + supports_masking=False, + assert_built_after_instantiation=True, + ) + + @parameterized.parameters( + (2, 1, "valid", "channels_last", (3, 5, 5, 4), (3, 4, 4, 4)), + (2, 1, "same", "channels_first", (3, 5, 5, 4), (3, 5, 5, 4)), + ((2, 3), (2, 2), "valid", "channels_last", (3, 5, 5, 4), (3, 2, 2, 4)), + ) + def test_max_pooling2d( + self, + pool_size, + strides, + padding, + data_format, + input_shape, + output_shape, + ): + self.run_layer_test( + layers.MaxPooling2D, + init_kwargs={ + "pool_size": pool_size, + "strides": strides, + "padding": padding, + "data_format": data_format, + }, + input_shape=input_shape, + expected_output_shape=output_shape, + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_losses=0, + supports_masking=False, + assert_built_after_instantiation=True, + ) + + @parameterized.parameters( + (2, 1, "valid", "channels_last", (3, 5, 5, 5, 4), (3, 4, 4, 4, 4)), + (2, 1, "same", "channels_first", (3, 5, 5, 5, 4), (3, 5, 5, 5, 4)), + ( + (2, 3, 2), + (2, 2, 1), + "valid", + "channels_last", + (3, 5, 5, 5, 4), + (3, 2, 2, 4, 4), + ), + ) + def test_max_pooling3d( + self, + pool_size, + strides, + padding, + data_format, + input_shape, + output_shape, + ): + self.run_layer_test( + layers.MaxPooling3D, + init_kwargs={ + "pool_size": pool_size, + "strides": strides, + "padding": padding, + "data_format": data_format, + }, + input_shape=input_shape, + expected_output_shape=output_shape, + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_losses=0, + supports_masking=False, + # Incomplete op support on tensorflow. + run_mixed_precision_check=False, + assert_built_after_instantiation=True, + ) + + +class MaxPoolingCorrectnessTest(testing.TestCase): + @parameterized.parameters( + (2, 1, "valid", "channels_last"), + (2, 1, "valid", "channels_first"), + (2, 1, "same", "channels_last"), + (2, 1, "same", "channels_first"), + ((2,), (2,), "valid", "channels_last"), + ((2,), (2,), "valid", "channels_first"), + ) + def test_max_pooling1d(self, pool_size, strides, padding, data_format): + inputs = np.arange(24, dtype="float32").reshape((2, 3, 4)) + + layer = layers.MaxPooling1D( + pool_size=pool_size, + strides=strides, + padding=padding, + data_format=data_format, + ) + outputs = layer(inputs) + expected = np_maxpool1d( + inputs, pool_size, strides, padding, data_format + ) + self.assertAllClose(outputs, expected) + + @parameterized.parameters( + (2, 1, "valid", "channels_last"), + (2, 1, "valid", "channels_first"), + ((2, 2), (2, 2), "same", "channels_last"), + ((2, 2), (2, 2), "same", "channels_first"), + ((2, 3), (3, 3), "same", "channels_last"), + ) + def test_max_pooling2d(self, pool_size, strides, padding, data_format): + inputs = np.arange(100, dtype="float32").reshape((1, 5, 5, 4)) + + layer = layers.MaxPooling2D( + pool_size=pool_size, + strides=strides, + padding=padding, + data_format=data_format, + ) + outputs = layer(inputs) + expected = np_maxpool2d( + inputs, pool_size, strides, padding, data_format + ) + self.assertAllClose(outputs, expected) + + @parameterized.parameters( + (2, 1, "valid", "channels_last"), + (2, 1, "same", "channels_first"), + ((2, 3, 2), (2, 2, 1), "valid", "channels_last"), + ((2, 3, 2), (2, 2, 1), "valid", "channels_first"), + ) + def test_max_pooling3d(self, pool_size, strides, padding, data_format): + inputs = np.arange(240, dtype="float32").reshape((2, 3, 4, 5, 2)) + + layer = layers.MaxPooling3D( + pool_size=pool_size, + strides=strides, + padding=padding, + data_format=data_format, + ) + outputs = layer(inputs) + expected = np_maxpool3d( + inputs, pool_size, strides, padding, data_format + ) + self.assertAllClose(outputs, expected) diff --git a/keras/src/layers/preprocessing/__init__.py b/keras/src/layers/preprocessing/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/keras/src/layers/preprocessing/category_encoding.py b/keras/src/layers/preprocessing/category_encoding.py new file mode 100644 index 000000000000..681f567a4d21 --- /dev/null +++ b/keras/src/layers/preprocessing/category_encoding.py @@ -0,0 +1,166 @@ +from keras.src.api_export import keras_export +from keras.src.backend import KerasTensor +from keras.src.layers.preprocessing.data_layer import DataLayer +from keras.src.utils import backend_utils +from keras.src.utils import numerical_utils + + +@keras_export("keras.layers.CategoryEncoding") +class CategoryEncoding(DataLayer): + """A preprocessing layer which encodes integer features. + + This layer provides options for condensing data into a categorical encoding + when the total number of tokens are known in advance. It accepts integer + values as inputs, and it outputs a dense or sparse representation of those + inputs. For integer inputs where the total number of tokens is not known, + use `keras.layers.IntegerLookup` instead. + + **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline + (independently of which backend you're using). + + Examples: + + **One-hot encoding data** + + >>> layer = keras.layers.CategoryEncoding( + ... num_tokens=4, output_mode="one_hot") + >>> layer([3, 2, 0, 1]) + array([[0., 0., 0., 1.], + [0., 0., 1., 0.], + [1., 0., 0., 0.], + [0., 1., 0., 0.]]> + + **Multi-hot encoding data** + + >>> layer = keras.layers.CategoryEncoding( + ... num_tokens=4, output_mode="multi_hot") + >>> layer([[0, 1], [0, 0], [1, 2], [3, 1]]) + array([[1., 1., 0., 0.], + [1., 0., 0., 0.], + [0., 1., 1., 0.], + [0., 1., 0., 1.]]> + + **Using weighted inputs in `"count"` mode** + + >>> layer = keras.layers.CategoryEncoding( + ... num_tokens=4, output_mode="count") + >>> count_weights = np.array([[.1, .2], [.1, .1], [.2, .3], [.4, .2]]) + >>> layer([[0, 1], [0, 0], [1, 2], [3, 1]], count_weights=count_weights) + array([[0.1, 0.2, 0. , 0. ], + [0.2, 0. , 0. , 0. ], + [0. , 0.2, 0.3, 0. ], + [0. , 0.2, 0. , 0.4]]> + + Args: + num_tokens: The total number of tokens the layer should support. All + inputs to the layer must integers in the range `0 <= value < + num_tokens`, or an error will be thrown. + output_mode: Specification for the output of the layer. + Values can be `"one_hot"`, `"multi_hot"` or `"count"`, + configuring the layer as follows: + - `"one_hot"`: Encodes each individual element in the input + into an array of `num_tokens` size, containing a 1 at the + element index. If the last dimension is size 1, will encode + on that dimension. If the last dimension is not size 1, + will append a new dimension for the encoded output. + - `"multi_hot"`: Encodes each sample in the input into a single + array of `num_tokens` size, containing a 1 for each + vocabulary term present in the sample. Treats the last + dimension as the sample dimension, if input shape is + `(..., sample_length)`, output shape will be + `(..., num_tokens)`. + - `"count"`: Like `"multi_hot"`, but the int array contains a + count of the number of times the token at that index + appeared in the sample. + For all output modes, currently only output up to rank 2 is + supported. + Defaults to `"multi_hot"`. + sparse: Whether to return a sparse tensor; for backends that support + sparse tensors. + + Call arguments: + inputs: A 1D or 2D tensor of integer inputs. + count_weights: A tensor in the same shape as `inputs` indicating the + weight for each sample value when summing up in `count` mode. + Not used in `"multi_hot"` or `"one_hot"` modes. + """ + + def __init__( + self, num_tokens=None, output_mode="multi_hot", sparse=False, **kwargs + ): + super().__init__(**kwargs) + + # Support deprecated names for output_modes. + if output_mode == "binary": + output_mode = "multi_hot" + + # 'output_mode' must be one of ("count", "one_hot", "multi_hot") + if output_mode not in ("count", "one_hot", "multi_hot"): + raise ValueError(f"Unknown arg for output_mode: {output_mode}") + + if num_tokens is None: + raise ValueError( + "num_tokens must be set to use this layer. If the " + "number of tokens is not known beforehand, use the " + "IntegerLookup layer instead." + ) + if num_tokens < 1: + raise ValueError( + f"`num_tokens` must be >= 1. Received: num_tokens={num_tokens}." + ) + self.num_tokens = num_tokens + self.output_mode = output_mode + self.sparse = sparse + self._allow_non_tensor_positional_args = True + self._convert_input_args = False + + def _encode(self, inputs, count_weights=None): + inputs = self.backend.core.convert_to_tensor(inputs) + return numerical_utils.encode_categorical_inputs( + inputs, + output_mode=self.output_mode, + depth=self.num_tokens, + dtype=self.dtype, + sparse=self.sparse, + count_weights=count_weights, + backend_module=self.backend, + ) + + def compute_output_shape(self, input_shape): + if (input_shape is not None) & (len(input_shape) == 0): + return (self.num_tokens,) + if self.output_mode == "one_hot": + if input_shape[-1] != 1: + return tuple(input_shape) + (self.num_tokens,) + elif len(input_shape) == 1: + return tuple(input_shape) + (self.num_tokens,) + else: + return tuple(input_shape[:-1]) + (self.num_tokens,) + return tuple(input_shape[:-1]) + (self.num_tokens,) + + def compute_output_spec(self, inputs, count_weights=None): + output_shape = self.compute_output_shape(inputs.shape) + return KerasTensor( + output_shape, dtype=self.compute_dtype, sparse=self.sparse + ) + + def get_config(self): + config = { + "num_tokens": self.num_tokens, + "output_mode": self.output_mode, + } + base_config = super().get_config() + return {**base_config, **config} + + def call(self, inputs, count_weights=None): + if count_weights is not None: + if self.output_mode != "count": + raise ValueError( + "`count_weights` is not used when `output_mode` is not " + f"`'count'`. Received `count_weights={count_weights}`." + ) + count_weights = self.backend.convert_to_tensor( + count_weights, dtype=self.compute_dtype + ) + outputs = self._encode(inputs, count_weights) + return backend_utils.convert_tf_tensor(outputs) diff --git a/keras/src/layers/preprocessing/category_encoding_test.py b/keras/src/layers/preprocessing/category_encoding_test.py new file mode 100644 index 000000000000..4c5a2b929da4 --- /dev/null +++ b/keras/src/layers/preprocessing/category_encoding_test.py @@ -0,0 +1,340 @@ +import numpy as np +from absl.testing import parameterized +from tensorflow import data as tf_data + +from keras.src import backend +from keras.src import layers +from keras.src import testing + +TEST_CASES = [{"testcase_name": "dense", "sparse": False}] +if backend.SUPPORTS_SPARSE_TENSORS: + TEST_CASES += [{"testcase_name": "sparse", "sparse": True}] + + +class CategoryEncodingTest(testing.TestCase): + @parameterized.named_parameters(TEST_CASES) + def test_count_output(self, sparse): + input_array = np.array([1, 2, 3, 1]) + expected_output = np.array([0, 2, 1, 1, 0, 0]) + + num_tokens = 6 + expected_output_shape = (num_tokens,) + + layer = layers.CategoryEncoding( + num_tokens=num_tokens, output_mode="count", sparse=sparse + ) + int_data = layer(input_array) + self.assertEqual(expected_output_shape, int_data.shape) + self.assertAllClose(int_data, expected_output) + self.assertSparse(int_data, sparse) + + # Test symbolic call. + output = layer( + layers.Input(batch_shape=input_array.shape, dtype="int32") + ) + self.assertEqual(expected_output_shape, output.shape) + self.assertEqual("float32", output.dtype) + self.assertSparse(output, sparse) + + @parameterized.named_parameters(TEST_CASES) + def test_count_weighted_output(self, sparse): + input_array = np.array([[0, 1], [0, 0], [1, 2], [3, 1]]) + count_weights = np.array( + [[0.1, 0.2], [0.1, 0.1], [0.2, 0.3], [0.4, 0.2]] + ) + expected_output = np.array( + [ + [0.1, 0.2, 0.0, 0.0, 0.0, 0.0], + [0.2, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.2, 0.3, 0.0, 0.0, 0.0], + [0.0, 0.2, 0.0, 0.4, 0.0, 0.0], + ] + ) + + num_tokens = 6 + expected_output_shape = (input_array.shape[0], num_tokens) + + layer = layers.CategoryEncoding( + num_tokens=num_tokens, output_mode="count", sparse=sparse + ) + int_data = layer(input_array, count_weights=count_weights) + self.assertEqual(expected_output_shape, int_data.shape) + self.assertAllClose(int_data, expected_output) + self.assertSparse(int_data, sparse) + + # Test symbolic call. + output = layer( + layers.Input(batch_shape=input_array.shape, dtype="int32"), + count_weights=layers.Input( + batch_shape=input_array.shape, dtype="float32" + ), + ) + self.assertEqual(expected_output_shape, output.shape) + self.assertEqual("float32", output.dtype) + self.assertSparse(output, sparse) + + @parameterized.named_parameters(TEST_CASES) + def test_batched_count_output(self, sparse): + input_array = np.array([[1, 2, 3, 1], [0, 3, 1, 0]]) + expected_output = np.array([[0, 2, 1, 1, 0, 0], [2, 1, 0, 1, 0, 0]]) + + num_tokens = 6 + expected_output_shape = (2, num_tokens) + + layer = layers.CategoryEncoding( + num_tokens=num_tokens, output_mode="count", sparse=sparse + ) + int_data = layer(input_array) + self.assertEqual(expected_output_shape, int_data.shape) + self.assertAllClose(int_data, expected_output) + self.assertSparse(int_data, sparse) + + # Test symbolic call. + output = layer( + layers.Input(batch_shape=input_array.shape, dtype="int32") + ) + self.assertEqual(expected_output_shape, output.shape) + self.assertEqual("float32", output.dtype) + self.assertSparse(output, sparse) + + @parameterized.named_parameters(TEST_CASES) + def test_multi_hot(self, sparse): + input_data = np.array([3, 2, 0, 1]) + expected_output = np.array([1, 1, 1, 1, 0, 0]) + num_tokens = 6 + expected_output_shape = (num_tokens,) + + # Test call on layer directly. + layer = layers.CategoryEncoding( + num_tokens=num_tokens, output_mode="multi_hot", sparse=sparse + ) + output_data = layer(input_data) + self.assertAllClose(expected_output, output_data) + self.assertEqual(expected_output_shape, output_data.shape) + self.assertSparse(output_data, sparse) + + # Test symbolic call. + output = layer( + layers.Input(batch_shape=input_data.shape, dtype="int32") + ) + self.assertEqual(expected_output_shape, output.shape) + self.assertEqual("float32", output.dtype) + self.assertSparse(output, sparse) + + @parameterized.named_parameters(TEST_CASES) + def test_batched_multi_hot(self, sparse): + input_data = np.array([[3, 2, 0, 1], [3, 2, 0, 1]]) + expected_output = np.array([[1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 0, 0]]) + num_tokens = 6 + expected_output_shape = (input_data.shape[0], num_tokens) + + # Test call on layer directly. + layer = layers.CategoryEncoding( + num_tokens=num_tokens, output_mode="multi_hot", sparse=sparse + ) + output_data = layer(input_data) + self.assertAllClose(expected_output, output_data) + self.assertEqual(expected_output_shape, output_data.shape) + self.assertSparse(output_data, sparse) + + # Test symbolic call. + output = layer( + layers.Input(batch_shape=input_data.shape, dtype="int32") + ) + self.assertEqual(expected_output_shape, output.shape) + self.assertEqual("float32", output.dtype) + self.assertSparse(output, sparse) + + # Test compute_output_shape + input_data = np.array((4)) + layer = layers.CategoryEncoding( + num_tokens=num_tokens, output_mode="multi_hot", sparse=sparse + ) + self.assertEqual( + layer(input_data).shape, + layer.compute_output_shape(input_data.shape), + ) + + @parameterized.named_parameters(TEST_CASES) + def test_one_hot(self, sparse): + input_data = np.array([3, 2, 0, 1]) + expected_output = np.array( + [ + [0, 0, 0, 1, 0, 0], + [0, 0, 1, 0, 0, 0], + [1, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + ] + ) + num_tokens = 6 + expected_output_shape = (input_data.shape[0], num_tokens) + + # Test call on layer directly. + layer = layers.CategoryEncoding( + num_tokens=num_tokens, output_mode="one_hot", sparse=sparse + ) + output_data = layer(input_data) + self.assertAllClose(expected_output, output_data) + self.assertEqual(expected_output_shape, output_data.shape) + self.assertSparse(output_data, sparse) + + # Test symbolic call. + output = layer( + layers.Input(batch_shape=input_data.shape, dtype="int32") + ) + self.assertEqual(expected_output_shape, output.shape) + self.assertEqual("float32", output.dtype) + self.assertSparse(output, sparse) + + # Test compute_output_shape + layer = layers.CategoryEncoding( + num_tokens=num_tokens, output_mode="one_hot", sparse=sparse + ) + self.assertEqual( + layer(input_data).shape, + layer.compute_output_shape(input_data.shape), + ) + + # Test compute_output_shape with 1 extra dimension + input_data = np.array([[3], [2], [0], [1]]) + layer = layers.CategoryEncoding( + num_tokens=num_tokens, output_mode="one_hot", sparse=sparse + ) + self.assertEqual( + layer(input_data).shape, + layer.compute_output_shape(input_data.shape), + ) + + input_data = np.array((4,)) + layer = layers.CategoryEncoding( + num_tokens=num_tokens, output_mode="one_hot", sparse=sparse + ) + self.assertEqual( + layer(input_data).shape, + layer.compute_output_shape(input_data.shape), + ) + + @parameterized.named_parameters(TEST_CASES) + def test_batched_one_hot(self, sparse): + input_data = np.array([[3, 2, 0, 1], [3, 2, 0, 1]]) + expected_output = np.array( + [ + [ + [0, 0, 0, 1, 0, 0], + [0, 0, 1, 0, 0, 0], + [1, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + ], + [ + [0, 0, 0, 1, 0, 0], + [0, 0, 1, 0, 0, 0], + [1, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + ], + ] + ) + num_tokens = 6 + expected_output_shape = input_data.shape[0:2] + (num_tokens,) + + # Test call on layer directly. + layer = layers.CategoryEncoding( + num_tokens=num_tokens, output_mode="one_hot", sparse=sparse + ) + output_data = layer(input_data) + self.assertAllClose(expected_output, output_data) + self.assertEqual(expected_output_shape, output_data.shape) + self.assertSparse(output_data, sparse) + + # Test symbolic call. + output = layer( + layers.Input(batch_shape=input_data.shape, dtype="int32") + ) + self.assertEqual(expected_output_shape, output.shape) + self.assertEqual("float32", output.dtype) + self.assertSparse(output, sparse) + + def test_tf_data_compatibility(self): + layer = layers.CategoryEncoding( + num_tokens=4, output_mode="one_hot", dtype="int32" + ) + input_data = np.array([3, 2, 0, 1]) + expected_output = np.array( + [ + [0, 0, 0, 1], + [0, 0, 1, 0], + [1, 0, 0, 0], + [0, 1, 0, 0], + ] + ) + ds = tf_data.Dataset.from_tensor_slices(input_data).batch(4).map(layer) + for output in ds.take(1): + output = output.numpy() + self.assertAllClose(output, expected_output) + + def test_category_encoding_without_num_tokens(self): + with self.assertRaisesRegex( + ValueError, r"num_tokens must be set to use this layer" + ): + layers.CategoryEncoding(output_mode="multi_hot") + + def test_category_encoding_with_invalid_num_tokens(self): + with self.assertRaisesRegex(ValueError, r"`num_tokens` must be >= 1"): + layers.CategoryEncoding(num_tokens=0, output_mode="multi_hot") + + with self.assertRaisesRegex(ValueError, r"`num_tokens` must be >= 1"): + layers.CategoryEncoding(num_tokens=-1, output_mode="multi_hot") + + def test_category_encoding_with_unnecessary_count_weights(self): + layer = layers.CategoryEncoding(num_tokens=4, output_mode="multi_hot") + input_data = np.array([0, 1, 2, 3]) + count_weights = np.array([0.1, 0.2, 0.3, 0.4]) + with self.assertRaisesRegex( + ValueError, r"`count_weights` is not used when `output_mode`" + ): + layer(input_data, count_weights=count_weights) + + def test_invalid_output_mode_raises_error(self): + with self.assertRaisesRegex( + ValueError, r"Unknown arg for output_mode: invalid_mode" + ): + layers.CategoryEncoding(num_tokens=4, output_mode="invalid_mode") + + def test_encode_one_hot_single_sample(self): + layer = layers.CategoryEncoding(num_tokens=4, output_mode="one_hot") + input_array = np.array([1, 2, 3, 1]) + expected_output = np.array( + [ + [0, 1, 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 1], + [0, 1, 0, 0], + ] + ) + output = layer._encode(input_array) + self.assertAllClose(expected_output, output) + + def test_encode_one_hot_batched_samples(self): + layer = layers.CategoryEncoding(num_tokens=4, output_mode="one_hot") + input_array = np.array([[3, 2, 0, 1], [3, 2, 0, 1]]) + expected_output = np.array( + [ + [[0, 0, 0, 1], [0, 0, 1, 0], [1, 0, 0, 0], [0, 1, 0, 0]], + [[0, 0, 0, 1], [0, 0, 1, 0], [1, 0, 0, 0], [0, 1, 0, 0]], + ] + ) + output = layer._encode(input_array) + self.assertAllClose(expected_output, output) + + def test_count_single_sample(self): + layer = layers.CategoryEncoding(num_tokens=4, output_mode="count") + input_array = np.array([1, 2, 3, 1]) + expected_output = np.array([0, 2, 1, 1]) + output = layer(input_array) + self.assertAllClose(expected_output, output) + + def test_count_batched_samples(self): + layer = layers.CategoryEncoding(num_tokens=4, output_mode="count") + input_array = np.array([[1, 2, 3, 1], [0, 3, 1, 0]]) + expected_output = np.array([[0, 2, 1, 1], [2, 1, 0, 1]]) + output = layer(input_array) + self.assertAllClose(expected_output, output) diff --git a/keras/src/layers/preprocessing/data_layer.py b/keras/src/layers/preprocessing/data_layer.py new file mode 100644 index 000000000000..437377248fb8 --- /dev/null +++ b/keras/src/layers/preprocessing/data_layer.py @@ -0,0 +1,159 @@ +import keras.src.backend +from keras.src import tree +from keras.src.layers.layer import Layer +from keras.src.random.seed_generator import SeedGenerator +from keras.src.utils import backend_utils +from keras.src.utils import jax_utils +from keras.src.utils import tracking + + +class DataLayer(Layer): + """Layer designed for safe use in `tf.data` or `grain` pipeline. + + This layer overrides the `__call__` method to ensure that the correct + backend is used and that computation is performed on the CPU. + + The `call()` method in subclasses should use `self.backend` ops. If + randomness is needed, define both `seed` and `generator` in `__init__` and + retrieve the running seed using `self._get_seed_generator()`. If the layer + has weights in `__init__` or `build()`, use `convert_weight()` to ensure + they are in the correct backend. + + **Note:** This layer and its subclasses only support a single input tensor. + + Examples: + + **Custom `DataLayer` subclass:** + + ```python + from keras.src.layers.preprocessing.data_layer import DataLayer + from keras.src.random import SeedGenerator + + + class BiasedRandomRGBToHSVLayer(DataLayer): + def __init__(self, seed=None, **kwargs): + super().__init__(**kwargs) + self.probability_bias = ops.convert_to_tensor(0.01) + self.seed = seed + self.generator = SeedGenerator(seed) + + def call(self, inputs): + images_shape = self.backend.shape(inputs) + batch_size = 1 if len(images_shape) == 3 else images_shape[0] + seed = self._get_seed_generator(self.backend._backend) + + probability = self.backend.random.uniform( + shape=(batch_size,), + minval=0.0, + maxval=1.0, + seed=seed, + ) + probability = self.backend.numpy.add( + probability, self.convert_weight(self.probability_bias) + ) + hsv_images = self.backend.image.rgb_to_hsv(inputs) + return self.backend.numpy.where( + probability[:, None, None, None] > 0.5, + hsv_images, + inputs, + ) + + def compute_output_shape(self, input_shape): + return input_shape + ``` + + **Using as a regular Keras layer:** + + ```python + import numpy as np + + x = np.random.uniform(size=(1, 16, 16, 3)).astype("float32") + print(BiasedRandomRGBToHSVLayer()(x).shape) # (1, 16, 16, 3) + ``` + + **Using in a `tf.data` pipeline:** + + ```python + import tensorflow as tf + + tf_ds = tf.data.Dataset.from_tensors(x) + tf_ds = tf_ds.map(BiasedRandomRGBToHSVLayer()) + print([x.shape for x in tf_ds]) # [(1, 16, 16, 3)] + ``` + + **Using in a `grain` pipeline:** + + ```python + import grain + + grain_ds = grain.MapDataset.source([x]) + grain_ds = grain_ds.map(BiasedRandomRGBToHSVLayer()) + print([x.shape for x in grain_ds]) # [(1, 16, 16, 3)] + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.backend = backend_utils.DynamicBackend() + self._allow_non_tensor_positional_args = True + + def __call__(self, inputs, **kwargs): + sample_input = tree.flatten(inputs)[0] + if ( + not isinstance(sample_input, keras.KerasTensor) + and backend_utils.in_tf_graph() + and not jax_utils.is_in_jax_tracing_scope(sample_input) + ): + # We're in a TF graph, e.g. a tf.data pipeline. + self.backend.set_backend("tensorflow") + inputs = tree.map_structure( + lambda x: self.backend.convert_to_tensor( + x, dtype=self.compute_dtype + ), + inputs, + ) + switch_convert_input_args = False + if self._convert_input_args: + self._convert_input_args = False + switch_convert_input_args = True + try: + outputs = super().__call__(inputs, **kwargs) + finally: + self.backend.reset() + if switch_convert_input_args: + self._convert_input_args = True + return outputs + elif ( + not isinstance(sample_input, keras.KerasTensor) + and backend_utils.in_grain_data_pipeline() + ): + # We're in a Grain data pipeline. Force computation and data + # placement to CPU. + with keras.src.backend.device_scope("cpu"): + return super().__call__(inputs, **kwargs) + else: + return super().__call__(inputs, **kwargs) + + @tracking.no_automatic_dependency_tracking + def _get_seed_generator(self, backend=None): + if not hasattr(self, "seed") or not hasattr(self, "generator"): + raise ValueError( + "The `seed` and `generator` variable must be set in the " + "`__init__` method before calling `_get_seed_generator()`." + ) + if backend is None or backend == keras.backend.backend(): + return self.generator + if not hasattr(self, "_backend_generators"): + self._backend_generators = {} + if backend in self._backend_generators: + return self._backend_generators[backend] + seed_generator = SeedGenerator(self.seed, backend=self.backend) + self._backend_generators[backend] = seed_generator + return seed_generator + + def convert_weight(self, weight): + """Convert the weight if it is from the a different backend.""" + if self.backend.name == keras.backend.backend(): + return weight + else: + weight = keras.ops.convert_to_numpy(weight) + return self.backend.convert_to_tensor(weight) diff --git a/keras/src/layers/preprocessing/data_layer_test.py b/keras/src/layers/preprocessing/data_layer_test.py new file mode 100644 index 000000000000..01f5945777fc --- /dev/null +++ b/keras/src/layers/preprocessing/data_layer_test.py @@ -0,0 +1,90 @@ +import grain +import numpy as np +import pytest +from tensorflow import data as tf_data + +from keras.src import backend +from keras.src import testing +from keras.src.layers.preprocessing.data_layer import DataLayer +from keras.src.random import SeedGenerator + + +class RandomRGBToHSVLayer(DataLayer): + def __init__(self, data_format=None, seed=None, **kwargs): + super().__init__(**kwargs) + self.data_format = backend.standardize_data_format(data_format) + self.seed = seed + self.generator = SeedGenerator(seed) + + def call(self, inputs): + images_shape = self.backend.shape(inputs) + batch_size = 1 if len(images_shape) == 3 else images_shape[0] + seed = self._get_seed_generator(self.backend._backend) + + probability = self.backend.random.uniform( + shape=(batch_size,), + minval=0.0, + maxval=1.0, + seed=seed, + ) + hsv_images = self.backend.image.rgb_to_hsv( + inputs, data_format=self.data_format + ) + return self.backend.numpy.where( + probability[:, None, None, None] > 0.5, hsv_images, inputs + ) + + def compute_output_shape(self, input_shape): + return input_shape + + +class DataLayerTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_layer(self): + self.run_layer_test( + RandomRGBToHSVLayer, + init_kwargs={ + "seed": 1337, + "data_format": "channels_last", + }, + input_shape=(1, 2, 2, 3), + supports_masking=False, + expected_output_shape=(1, 2, 2, 3), + ) + + self.run_layer_test( + RandomRGBToHSVLayer, + init_kwargs={ + "seed": 1337, + "data_format": "channels_first", + }, + input_shape=(1, 3, 2, 2), + supports_masking=False, + expected_output_shape=(1, 3, 2, 2), + ) + + def test_tf_data_compatibility(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + input_data = np.random.random((2, 8, 8, 3)).astype("float32") + else: + input_data = np.random.random((2, 3, 8, 8)).astype("float32") + layer = RandomRGBToHSVLayer(data_format=data_format, seed=1337) + + ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) + for output in ds.take(1): + self.assertDType(output, "float32") + self.assertEqual(list(output.shape), list(input_data.shape)) + + def test_grain_compatibility(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + input_data = np.random.random((2, 8, 8, 3)).astype("float32") + else: + input_data = np.random.random((2, 3, 8, 8)).astype("float32") + layer = RandomRGBToHSVLayer(data_format=data_format, seed=1337) + + ds = grain.MapDataset.source(input_data).batch(2).map(layer) + for output in ds[:1]: + self.assertDType(output, "float32") + self.assertEqual(list(output.shape), list(input_data.shape)) diff --git a/keras/src/layers/preprocessing/discretization.py b/keras/src/layers/preprocessing/discretization.py new file mode 100644 index 000000000000..50e79ba2f49f --- /dev/null +++ b/keras/src/layers/preprocessing/discretization.py @@ -0,0 +1,358 @@ +import numpy as np + +from keras.src import backend +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.data_layer import DataLayer +from keras.src.utils import argument_validation +from keras.src.utils import numerical_utils +from keras.src.utils.module_utils import tensorflow as tf + + +@keras_export("keras.layers.Discretization") +class Discretization(DataLayer): + """A preprocessing layer which buckets continuous features by ranges. + + This layer will place each element of its input data into one of several + contiguous ranges and output an integer index indicating which range each + element was placed in. + + **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline + (independently of which backend you're using). + + Input shape: + Any array of dimension 2 or higher. + + Output shape: + Same as input shape. + + Arguments: + bin_boundaries: A list of bin boundaries. + The leftmost and rightmost bins + will always extend to `-inf` and `inf`, + so `bin_boundaries=[0., 1., 2.]` + generates bins `(-inf, 0.)`, `[0., 1.)`, `[1., 2.)`, + and `[2., +inf)`. + If this option is set, `adapt()` should not be called. + num_bins: The integer number of bins to compute. + If this option is set, `bin_boundaries` should not be set and + `adapt()` should be called to learn the bin boundaries. + epsilon: Error tolerance, typically a small fraction + close to zero (e.g. 0.01). Higher values of epsilon increase + the quantile approximation, and hence result in more + unequal buckets, but could improve performance + and resource consumption. + output_mode: Specification for the output of the layer. + Values can be `"int"`, `"one_hot"`, `"multi_hot"`, or + `"count"` configuring the layer as follows: + - `"int"`: Return the discretized bin indices directly. + - `"one_hot"`: Encodes each individual element in the + input into an array the same size as `num_bins`, + containing a 1 at the input's bin + index. If the last dimension is size 1, will encode on that + dimension. If the last dimension is not size 1, + will append a new dimension for the encoded output. + - `"multi_hot"`: Encodes each sample in the input into a + single array the same size as `num_bins`, + containing a 1 for each bin index + index present in the sample. + Treats the last dimension as the sample + dimension, if input shape is `(..., sample_length)`, + output shape will be `(..., num_tokens)`. + - `"count"`: As `"multi_hot"`, but the int array contains + a count of the number of times the bin index appeared + in the sample. + Defaults to `"int"`. + sparse: Boolean. Only applicable to `"one_hot"`, `"multi_hot"`, + and `"count"` output modes. Only supported with TensorFlow + backend. If `True`, returns a `SparseTensor` instead of + a dense `Tensor`. Defaults to `False`. + + Examples: + + Discretize float values based on provided buckets. + >>> input = np.array([[-1.5, 1.0, 3.4, .5], [0.0, 3.0, 1.3, 0.0]]) + >>> layer = Discretization(bin_boundaries=[0., 1., 2.]) + >>> layer(input) + array([[0, 2, 3, 1], + [1, 3, 2, 1]]) + + Discretize float values based on a number of buckets to compute. + >>> input = np.array([[-1.5, 1.0, 3.4, .5], [0.0, 3.0, 1.3, 0.0]]) + >>> layer = Discretization(num_bins=4, epsilon=0.01) + >>> layer.adapt(input) + >>> layer(input) + array([[0, 2, 3, 2], + [1, 3, 3, 1]]) + """ + + def __init__( + self, + bin_boundaries=None, + num_bins=None, + epsilon=0.01, + output_mode="int", + sparse=False, + dtype=None, + name=None, + ): + if dtype is None: + dtype = "int64" if output_mode == "int" else backend.floatx() + + super().__init__(name=name, dtype=dtype) + + if sparse and not backend.SUPPORTS_SPARSE_TENSORS: + raise ValueError( + f"`sparse=True` cannot be used with backend {backend.backend()}" + ) + if sparse and output_mode == "int": + raise ValueError( + "`sparse=True` may only be used if `output_mode` is " + "`'one_hot'`, `'multi_hot'`, or `'count'`. " + f"Received: sparse={sparse} and " + f"output_mode={output_mode}" + ) + + argument_validation.validate_string_arg( + output_mode, + allowable_strings=( + "int", + "one_hot", + "multi_hot", + "count", + ), + caller_name=self.__class__.__name__, + arg_name="output_mode", + ) + + if num_bins is not None and num_bins < 0: + raise ValueError( + "`num_bins` must be greater than or equal to 0. " + f"Received: `num_bins={num_bins}`" + ) + if num_bins is not None and bin_boundaries is not None: + raise ValueError( + "Both `num_bins` and `bin_boundaries` should not be set. " + f"Received: `num_bins={num_bins}` and " + f"`bin_boundaries={bin_boundaries}`" + ) + if num_bins is None and bin_boundaries is None: + raise ValueError( + "You need to set either `num_bins` or `bin_boundaries`." + ) + + self.bin_boundaries = bin_boundaries + self.num_bins = num_bins + self.epsilon = epsilon + self.output_mode = output_mode + self.sparse = sparse + + if self.bin_boundaries: + self.summary = None + else: + self.summary = np.array([[], []], dtype="float32") + + @property + def input_dtype(self): + return backend.floatx() + + def adapt(self, data, steps=None): + """Computes bin boundaries from quantiles in a input dataset. + + Calling `adapt()` on a `Discretization` layer is an alternative to + passing in a `bin_boundaries` argument during construction. A + `Discretization` layer should always be either adapted over a dataset or + passed `bin_boundaries`. + + During `adapt()`, the layer will estimate the quantile boundaries of the + input dataset. The number of quantiles can be controlled via the + `num_bins` argument, and the error tolerance for quantile boundaries can + be controlled via the `epsilon` argument. + + Arguments: + data: The data to train on. It can be passed either as a + batched `tf.data.Dataset`, + or as a NumPy array. + steps: Integer or `None`. + Total number of steps (batches of samples) to process. + If `data` is a `tf.data.Dataset`, and `steps` is `None`, + `adapt()` will run until the input dataset is exhausted. + When passing an infinitely + repeating dataset, you must specify the `steps` argument. This + argument is not supported with array inputs or list inputs. + """ + if self.num_bins is None: + raise ValueError( + "Cannot adapt a Discretization layer that has been initialized " + "with `bin_boundaries`, use `num_bins` instead." + ) + self.reset_state() + if isinstance(data, tf.data.Dataset): + if steps is not None: + data = data.take(steps) + for batch in data: + self.update_state(batch) + else: + self.update_state(data) + self.finalize_state() + + def update_state(self, data): + data = np.array(data).astype("float32") + summary = summarize(data, self.epsilon) + self.summary = merge_summaries(summary, self.summary, self.epsilon) + + def finalize_state(self): + if self.num_bins is None: + return + self.bin_boundaries = get_bin_boundaries( + self.summary, self.num_bins + ).tolist() + + def reset_state(self): + if self.num_bins is None: + return + self.summary = np.array([[], []], dtype="float32") + + def compute_output_spec(self, inputs): + return backend.KerasTensor(shape=inputs.shape, dtype=self.compute_dtype) + + def load_own_variables(self, store): + if len(store) == 1: + # Legacy format case + self.summary = store["0"] + return + + def call(self, inputs): + if self.bin_boundaries is None: + raise ValueError( + "You need to either pass the `bin_boundaries` argument at " + "construction time or call `adapt(dataset)` before you can " + "start using the `Discretization` layer." + ) + + indices = self.backend.numpy.digitize(inputs, self.bin_boundaries) + return numerical_utils.encode_categorical_inputs( + indices, + output_mode=self.output_mode, + depth=len(self.bin_boundaries) + 1, + dtype=self.compute_dtype, + sparse=self.sparse, + backend_module=self.backend, + ) + + def get_config(self): + return { + "bin_boundaries": self.bin_boundaries, + "num_bins": self.num_bins, + "epsilon": self.epsilon, + "output_mode": self.output_mode, + "sparse": self.sparse, + "name": self.name, + "dtype": self.dtype, + } + + @classmethod + def from_config(cls, config, custom_objects=None): + if ( + config.get("bin_boundaries", None) is not None + and config.get("num_bins", None) is not None + ): + # After `adapt` was called, both `bin_boundaries` and `num_bins` are + # populated, but `__init__` won't let us create a new layer with + # both `bin_boundaries` and `num_bins`. We therefore apply + # `bin_boundaries` after creation. + config = config.copy() + bin_boundaries = config.pop("bin_boundaries") + discretization = cls(**config) + discretization.bin_boundaries = bin_boundaries + return discretization + return cls(**config) + + +def summarize(values, epsilon): + """Reduce a 1D sequence of values to a summary. + + This algorithm is based on numpy.quantiles but modified to allow for + intermediate steps between multiple data sets. It first finds the target + number of bins as the reciprocal of epsilon and then takes the individual + values spaced at appropriate intervals to arrive at that target. + The final step is to return the corresponding counts between those values + If the target num_bins is larger than the size of values, the whole array is + returned (with weights of 1). + + Args: + values: 1D `np.ndarray` to be summarized. + epsilon: A `'float32'` that determines the approximate desired + precision. + + Returns: + A 2D `np.ndarray` that is a summary of the inputs. First column is the + interpolated partition values, the second is the weights (counts). + """ + values = np.reshape(values, [-1]) + values = np.sort(values) + elements = np.size(values) + num_buckets = 1.0 / epsilon + increment = elements / num_buckets + start = increment + step = max(increment, 1) + boundaries = values[int(start) :: int(step)] + weights = np.ones_like(boundaries) + weights = weights * step + return np.stack([boundaries, weights]) + + +def merge_summaries(prev_summary, next_summary, epsilon): + """Weighted merge sort of summaries. + + Given two summaries of distinct data, this function merges (and compresses) + them to stay within `epsilon` error tolerance. + + Args: + prev_summary: 2D `np.ndarray` summary to be merged with `next_summary`. + next_summary: 2D `np.ndarray` summary to be merged with `prev_summary`. + epsilon: A float that determines the approximate desired precision. + + Returns: + A 2-D `np.ndarray` that is a merged summary. First column is the + interpolated partition values, the second is the weights (counts). + """ + merged = np.concatenate((prev_summary, next_summary), axis=1) + merged = np.take(merged, np.argsort(merged[0]), axis=1) + return compress_summary(merged, epsilon) + + +def get_bin_boundaries(summary, num_bins): + return compress_summary(summary, 1.0 / num_bins)[0, :-1] + + +def compress_summary(summary, epsilon): + """Compress a summary to within `epsilon` accuracy. + + The compression step is needed to keep the summary sizes small after + merging, and also used to return the final target boundaries. It finds the + new bins based on interpolating cumulative weight percentages from the large + summary. Taking the difference of the cumulative weights from the previous + bin's cumulative weight will give the new weight for that bin. + + Args: + summary: 2D `np.ndarray` summary to be compressed. + epsilon: A `'float32'` that determines the approximate desired + precision. + + Returns: + A 2D `np.ndarray` that is a compressed summary. First column is the + interpolated partition values, the second is the weights (counts). + """ + if summary.shape[1] * epsilon < 1: + return summary + + percents = epsilon + np.arange(0.0, 1.0, epsilon) + cum_weights = summary[1].cumsum() + cum_weight_percents = cum_weights / cum_weights[-1] + new_bins = np.interp(percents, cum_weight_percents, summary[0]) + cum_weights = np.interp(percents, cum_weight_percents, cum_weights) + new_weights = cum_weights - np.concatenate( + (np.array([0]), cum_weights[:-1]) + ) + summary = np.stack((new_bins, new_weights)) + return summary.astype("float32") diff --git a/keras/src/layers/preprocessing/discretization_test.py b/keras/src/layers/preprocessing/discretization_test.py new file mode 100644 index 000000000000..500c6e9ca039 --- /dev/null +++ b/keras/src/layers/preprocessing/discretization_test.py @@ -0,0 +1,207 @@ +import os + +import numpy as np +import pytest +from absl.testing import parameterized +from tensorflow import data as tf_data + +from keras.src import backend +from keras.src import layers +from keras.src import models +from keras.src import testing +from keras.src.saving import saving_api +from keras.src.testing.test_utils import named_product + + +class DiscretizationTest(testing.TestCase): + def test_discretization_basics(self): + self.run_layer_test( + layers.Discretization, + init_kwargs={ + "bin_boundaries": [0.0, 0.5, 1.0], + }, + input_shape=(2, 3), + expected_output_shape=(2, 3), + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=False, + run_training_check=False, + ) + + def test_adapt_flow(self): + layer = layers.Discretization(num_bins=4) + layer.adapt( + np.random.random((32, 3)), + ) + output = layer(np.array([[0.0, 0.1, 0.3]])) + self.assertTrue(output.dtype, "int32") + + @parameterized.named_parameters( + named_product( + [ + { + "testcase_name": "int", + "output_mode": "int", + "input_array": [[-1.0, 0.0, 0.1, 0.8, 1.2]], + "expected_output": [[0, 1, 1, 2, 3]], + }, + { + "testcase_name": "one_hot_rank_1", + "output_mode": "one_hot", + "input_array": [0.1, 0.8], + "expected_output": [[0, 1, 0, 0], [0, 0, 1, 0]], + }, + { + "testcase_name": "multi_hot_rank_2", + "output_mode": "multi_hot", + "input_array": [[0.1, 0.8]], + "expected_output": [[0, 1, 1, 0]], + }, + { + "testcase_name": "one_hot_rank_3", + "output_mode": "one_hot", + "input_array": [[[0.15, 0.75], [0.85, 0.45]]], + "expected_output": [ + [ + [[0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]], + [[0.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 0.0]], + ] + ], + }, + { + "testcase_name": "multi_hot_rank_3", + "output_mode": "multi_hot", + "input_array": [[[0.15, 0.75], [0.85, 0.45]]], + "expected_output": [ + [[0.0, 1.0, 1.0, 0.0], [0.0, 1.0, 1.0, 0.0]] + ], + }, + { + "testcase_name": "count", + "output_mode": "count", + "input_array": [[0.1, 0.8, 0.9]], + "expected_output": [[0, 1, 2, 0]], + }, + ], + sparse=( + [True, False] if backend.SUPPORTS_SPARSE_TENSORS else [False] + ), + ) + ) + def test_correctness( + self, output_mode, input_array, expected_output, sparse + ): + if output_mode == "int" and sparse: + pytest.skip("sparse=True cannot be combined with output_mode=int") + + input_array = np.array(input_array) + expected_output = np.array(expected_output) + + layer = layers.Discretization( + bin_boundaries=[0.0, 0.5, 1.0], + output_mode=output_mode, + sparse=sparse, + ) + output = layer(input_array) + self.assertSparse(output, sparse) + self.assertTrue(backend.is_tensor(output)) + self.assertAllClose(output, expected_output) + + def test_tf_data_compatibility(self): + # With fixed bins + layer = layers.Discretization( + bin_boundaries=[0.0, 0.35, 0.5, 1.0], dtype="float32" + ) + x = np.array([[-1.0, 0.0, 0.1, 0.2, 0.4, 0.5, 1.0, 1.2, 0.98]]) + self.assertAllClose(layer(x), np.array([[0, 1, 1, 1, 2, 3, 4, 4, 3]])) + ds = tf_data.Dataset.from_tensor_slices(x).batch(1).map(layer) + for output in ds.take(1): + output = output.numpy() + self.assertAllClose(output, np.array([[0, 1, 1, 1, 2, 3, 4, 4, 3]])) + + # With adapt flow + layer = layers.Discretization(num_bins=4) + layer.adapt( + np.random.random((32, 3)), + ) + x = np.array([[0.0, 0.1, 0.3]]) + ds = tf_data.Dataset.from_tensor_slices(x).batch(1).map(layer) + for output in ds.take(1): + output.numpy() + + def test_serialization(self): + layer = layers.Discretization(num_bins=5) + + # Serialization before `adapt` is called. + config = layer.get_config() + revived_layer = layers.Discretization.from_config(config) + self.assertEqual(config, revived_layer.get_config()) + + # Serialization after `adapt` is called but `num_bins` was not reached. + layer.adapt(np.array([0.0, 1.0, 5.0])) + config = layer.get_config() + revived_layer = layers.Discretization.from_config(config) + self.assertEqual(config, revived_layer.get_config()) + + # Serialization after `adapt` is called and `num_bins` is reached. + layer.adapt(np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0])) + config = layer.get_config() + revived_layer = layers.Discretization.from_config(config) + self.assertEqual(config, revived_layer.get_config()) + + # Serialization with `bin_boundaries`. + layer = layers.Discretization(bin_boundaries=[0.0, 0.35, 0.5, 1.0]) + config = layer.get_config() + revived_layer = layers.Discretization.from_config(config) + self.assertEqual(config, revived_layer.get_config()) + + def test_saving(self): + # With fixed bins + layer = layers.Discretization(bin_boundaries=[0.0, 0.35, 0.5, 1.0]) + model = models.Sequential( + [ + layers.Input((2,)), + layer, + ] + ) + fpath = os.path.join(self.get_temp_dir(), "model.keras") + model.save(fpath) + model = saving_api.load_model(fpath) + x = np.array([[-1.0, 0.0, 0.1, 0.2, 0.4, 0.5, 1.0, 1.2, 0.98]]) + self.assertAllClose(layer(x), np.array([[0, 1, 1, 1, 2, 3, 4, 4, 3]])) + + # With adapt flow + layer = layers.Discretization(num_bins=4) + layer.adapt( + np.random.random((32, 3)), + ) + ref_input = np.random.random((1, 2)) + ref_output = layer(ref_input) + model = models.Sequential( + [ + layers.Input((2,)), + layer, + ] + ) + fpath = os.path.join(self.get_temp_dir(), "model.keras") + model.save(fpath) + model = saving_api.load_model(fpath) + self.assertAllClose(layer(ref_input), ref_output) + + def test_init_num_bins_and_bin_boundaries_raises(self): + with self.assertRaisesRegex( + ValueError, "Both `num_bins` and `bin_boundaries`" + ): + layers.Discretization(num_bins=3, bin_boundaries=[0.0, 1.0]) + + with self.assertRaisesRegex( + ValueError, "either `num_bins` or `bin_boundaries`" + ): + layers.Discretization() + + def test_call_before_adapt_raises(self): + layer = layers.Discretization(num_bins=3) + with self.assertRaisesRegex(ValueError, "You need .* call .*adapt"): + layer([[0.1, 0.8, 0.9]]) diff --git a/keras/src/layers/preprocessing/feature_space.py b/keras/src/layers/preprocessing/feature_space.py new file mode 100644 index 000000000000..578bc8cc55f5 --- /dev/null +++ b/keras/src/layers/preprocessing/feature_space.py @@ -0,0 +1,822 @@ +from keras.src import backend +from keras.src import layers +from keras.src import tree +from keras.src.api_export import keras_export +from keras.src.layers.layer import Layer +from keras.src.layers.preprocessing.data_layer import DataLayer +from keras.src.saving import saving_lib +from keras.src.saving import serialization_lib +from keras.src.saving.keras_saveable import KerasSaveable +from keras.src.utils import backend_utils +from keras.src.utils.module_utils import tensorflow as tf +from keras.src.utils.naming import auto_name + + +class Cross(KerasSaveable): + def __init__(self, feature_names, crossing_dim, output_mode="one_hot"): + if output_mode not in {"int", "one_hot"}: + raise ValueError( + "Invalid value for argument `output_mode`. " + "Expected one of {'int', 'one_hot'}. " + f"Received: output_mode={output_mode}" + ) + self.feature_names = tuple(feature_names) + self.crossing_dim = crossing_dim + self.output_mode = output_mode + + def _obj_type(self): + return "Cross" + + @property + def name(self): + return "_X_".join(self.feature_names) + + def get_config(self): + return { + "feature_names": self.feature_names, + "crossing_dim": self.crossing_dim, + "output_mode": self.output_mode, + } + + @classmethod + def from_config(cls, config): + return cls(**config) + + +class Feature(KerasSaveable): + def __init__(self, dtype, preprocessor, output_mode): + if output_mode not in {"int", "one_hot", "float"}: + raise ValueError( + "Invalid value for argument `output_mode`. " + "Expected one of {'int', 'one_hot', 'float'}. " + f"Received: output_mode={output_mode}" + ) + self.dtype = dtype + if isinstance(preprocessor, dict): + preprocessor = serialization_lib.deserialize_keras_object( + preprocessor + ) + self.preprocessor = preprocessor + self.output_mode = output_mode + + def _obj_type(self): + return "Feature" + + def get_config(self): + return { + "dtype": self.dtype, + "preprocessor": serialization_lib.serialize_keras_object( + self.preprocessor + ), + "output_mode": self.output_mode, + } + + @classmethod + def from_config(cls, config): + return cls(**config) + + +@keras_export("keras.utils.FeatureSpace") +class FeatureSpace(Layer): + """One-stop utility for preprocessing and encoding structured data. + + Arguments: + feature_names: Dict mapping the names of your features to their + type specification, e.g. `{"my_feature": "integer_categorical"}` + or `{"my_feature": FeatureSpace.integer_categorical()}`. + For a complete list of all supported types, see + "Available feature types" paragraph below. + output_mode: One of `"concat"` or `"dict"`. In concat mode, all + features get concatenated together into a single vector. + In dict mode, the FeatureSpace returns a dict of individually + encoded features (with the same keys as the input dict keys). + crosses: List of features to be crossed together, e.g. + `crosses=[("feature_1", "feature_2")]`. The features will be + "crossed" by hashing their combined value into + a fixed-length vector. + crossing_dim: Default vector size for hashing crossed features. + Defaults to `32`. + hashing_dim: Default vector size for hashing features of type + `"integer_hashed"` and `"string_hashed"`. Defaults to `32`. + num_discretization_bins: Default number of bins to be used for + discretizing features of type `"float_discretized"`. + Defaults to `32`. + + **Available feature types:** + + Note that all features can be referred to by their string name, + e.g. `"integer_categorical"`. When using the string name, the default + argument values are used. + + ```python + # Plain float values. + FeatureSpace.float(name=None) + + # Float values to be preprocessed via featurewise standardization + # (i.e. via a `keras.layers.Normalization` layer). + FeatureSpace.float_normalized(name=None) + + # Float values to be preprocessed via linear rescaling + # (i.e. via a `keras.layers.Rescaling` layer). + FeatureSpace.float_rescaled(scale=1., offset=0., name=None) + + # Float values to be discretized. By default, the discrete + # representation will then be one-hot encoded. + FeatureSpace.float_discretized( + num_bins, bin_boundaries=None, output_mode="one_hot", name=None) + + # Integer values to be indexed. By default, the discrete + # representation will then be one-hot encoded. + FeatureSpace.integer_categorical( + max_tokens=None, num_oov_indices=1, output_mode="one_hot", name=None) + + # String values to be indexed. By default, the discrete + # representation will then be one-hot encoded. + FeatureSpace.string_categorical( + max_tokens=None, num_oov_indices=1, output_mode="one_hot", name=None) + + # Integer values to be hashed into a fixed number of bins. + # By default, the discrete representation will then be one-hot encoded. + FeatureSpace.integer_hashed(num_bins, output_mode="one_hot", name=None) + + # String values to be hashed into a fixed number of bins. + # By default, the discrete representation will then be one-hot encoded. + FeatureSpace.string_hashed(num_bins, output_mode="one_hot", name=None) + ``` + + Examples: + + **Basic usage with a dict of input data:** + + ```python + raw_data = { + "float_values": [0.0, 0.1, 0.2, 0.3], + "string_values": ["zero", "one", "two", "three"], + "int_values": [0, 1, 2, 3], + } + dataset = tf.data.Dataset.from_tensor_slices(raw_data) + + feature_space = FeatureSpace( + features={ + "float_values": "float_normalized", + "string_values": "string_categorical", + "int_values": "integer_categorical", + }, + crosses=[("string_values", "int_values")], + output_mode="concat", + ) + # Before you start using the FeatureSpace, + # you must `adapt()` it on some data. + feature_space.adapt(dataset) + + # You can call the FeatureSpace on a dict of data (batched or unbatched). + output_vector = feature_space(raw_data) + ``` + + **Basic usage with `tf.data`:** + + ```python + # Unlabeled data + preprocessed_ds = unlabeled_dataset.map(feature_space) + + # Labeled data + preprocessed_ds = labeled_dataset.map(lambda x, y: (feature_space(x), y)) + ``` + + **Basic usage with the Keras Functional API:** + + ```python + # Retrieve a dict Keras Input objects + inputs = feature_space.get_inputs() + # Retrieve the corresponding encoded Keras tensors + encoded_features = feature_space.get_encoded_features() + # Build a Functional model + outputs = keras.layers.Dense(1, activation="sigmoid")(encoded_features) + model = keras.Model(inputs, outputs) + ``` + + **Customizing each feature or feature cross:** + + ```python + feature_space = FeatureSpace( + features={ + "float_values": FeatureSpace.float_normalized(), + "string_values": FeatureSpace.string_categorical(max_tokens=10), + "int_values": FeatureSpace.integer_categorical(max_tokens=10), + }, + crosses=[ + FeatureSpace.cross(("string_values", "int_values"), crossing_dim=32) + ], + output_mode="concat", + ) + ``` + + **Returning a dict of integer-encoded features:** + + ```python + feature_space = FeatureSpace( + features={ + "string_values": FeatureSpace.string_categorical(output_mode="int"), + "int_values": FeatureSpace.integer_categorical(output_mode="int"), + }, + crosses=[ + FeatureSpace.cross( + feature_names=("string_values", "int_values"), + crossing_dim=32, + output_mode="int", + ) + ], + output_mode="dict", + ) + ``` + + **Specifying your own Keras preprocessing layer:** + + ```python + # Let's say that one of the features is a short text paragraph that + # we want to encode as a vector (one vector per paragraph) via TF-IDF. + data = { + "text": ["1st string", "2nd string", "3rd string"], + } + + # There's a Keras layer for this: TextVectorization. + custom_layer = layers.TextVectorization(output_mode="tf_idf") + + # We can use FeatureSpace.feature to create a custom feature + # that will use our preprocessing layer. + feature_space = FeatureSpace( + features={ + "text": FeatureSpace.feature( + preprocessor=custom_layer, dtype="string", output_mode="float" + ), + }, + output_mode="concat", + ) + feature_space.adapt(tf.data.Dataset.from_tensor_slices(data)) + output_vector = feature_space(data) + ``` + + **Retrieving the underlying Keras preprocessing layers:** + + ```python + # The preprocessing layer of each feature is available in `.preprocessors`. + preprocessing_layer = feature_space.preprocessors["feature1"] + + # The crossing layer of each feature cross is available in `.crossers`. + # It's an instance of keras.layers.HashedCrossing. + crossing_layer = feature_space.crossers["feature1_X_feature2"] + ``` + + **Saving and reloading a FeatureSpace:** + + ```python + feature_space.save("featurespace.keras") + reloaded_feature_space = keras.models.load_model("featurespace.keras") + ``` + """ + + @classmethod + def cross(cls, feature_names, crossing_dim, output_mode="one_hot"): + return Cross(feature_names, crossing_dim, output_mode=output_mode) + + @classmethod + def feature(cls, dtype, preprocessor, output_mode): + return Feature(dtype, preprocessor, output_mode) + + @classmethod + def float(cls, name=None): + name = name or auto_name("float") + preprocessor = TFDIdentity(dtype="float32", name=f"{name}_preprocessor") + return Feature( + dtype="float32", preprocessor=preprocessor, output_mode="float" + ) + + @classmethod + def float_rescaled(cls, scale=1.0, offset=0.0, name=None): + name = name or auto_name("float_rescaled") + preprocessor = layers.Rescaling( + scale=scale, offset=offset, name=f"{name}_preprocessor" + ) + return Feature( + dtype="float32", preprocessor=preprocessor, output_mode="float" + ) + + @classmethod + def float_normalized(cls, name=None): + name = name or auto_name("float_normalized") + preprocessor = layers.Normalization( + axis=-1, name=f"{name}_preprocessor" + ) + return Feature( + dtype="float32", preprocessor=preprocessor, output_mode="float" + ) + + @classmethod + def float_discretized( + cls, num_bins, bin_boundaries=None, output_mode="one_hot", name=None + ): + name = name or auto_name("float_discretized") + preprocessor = layers.Discretization( + num_bins=num_bins, + bin_boundaries=bin_boundaries, + name=f"{name}_preprocessor", + ) + return Feature( + dtype="float32", preprocessor=preprocessor, output_mode=output_mode + ) + + @classmethod + def integer_categorical( + cls, + max_tokens=None, + num_oov_indices=1, + output_mode="one_hot", + name=None, + ): + name = name or auto_name("integer_categorical") + preprocessor = layers.IntegerLookup( + name=f"{name}_preprocessor", + max_tokens=max_tokens, + num_oov_indices=num_oov_indices, + ) + return Feature( + dtype="int32", preprocessor=preprocessor, output_mode=output_mode + ) + + @classmethod + def string_categorical( + cls, + max_tokens=None, + num_oov_indices=1, + output_mode="one_hot", + name=None, + ): + name = name or auto_name("string_categorical") + preprocessor = layers.StringLookup( + name=f"{name}_preprocessor", + max_tokens=max_tokens, + num_oov_indices=num_oov_indices, + ) + return Feature( + dtype="string", preprocessor=preprocessor, output_mode=output_mode + ) + + @classmethod + def string_hashed(cls, num_bins, output_mode="one_hot", name=None): + name = name or auto_name("string_hashed") + preprocessor = layers.Hashing( + name=f"{name}_preprocessor", num_bins=num_bins + ) + return Feature( + dtype="string", preprocessor=preprocessor, output_mode=output_mode + ) + + @classmethod + def integer_hashed(cls, num_bins, output_mode="one_hot", name=None): + name = name or auto_name("integer_hashed") + preprocessor = layers.Hashing( + name=f"{name}_preprocessor", num_bins=num_bins + ) + return Feature( + dtype="int32", preprocessor=preprocessor, output_mode=output_mode + ) + + def __init__( + self, + features, + output_mode="concat", + crosses=None, + crossing_dim=32, + hashing_dim=32, + num_discretization_bins=32, + name=None, + ): + super().__init__(name=name) + if not features: + raise ValueError("The `features` argument cannot be None or empty.") + self.crossing_dim = crossing_dim + self.hashing_dim = hashing_dim + self.num_discretization_bins = num_discretization_bins + self.features = { + name: self._standardize_feature(name, value) + for name, value in features.items() + } + self.crosses = [] + if crosses: + feature_set = set(features.keys()) + for cross in crosses: + if isinstance(cross, dict): + cross = serialization_lib.deserialize_keras_object(cross) + if isinstance(cross, Cross): + self.crosses.append(cross) + else: + if not crossing_dim: + raise ValueError( + "When specifying `crosses`, the argument " + "`crossing_dim` " + "(dimensionality of the crossing space) " + "should be specified as well." + ) + for key in cross: + if key not in feature_set: + raise ValueError( + "All features referenced " + "in the `crosses` argument " + "should be present in the `features` dict. " + f"Received unknown features: {cross}" + ) + self.crosses.append(Cross(cross, crossing_dim=crossing_dim)) + self.crosses_by_name = {cross.name: cross for cross in self.crosses} + + if output_mode not in {"dict", "concat"}: + raise ValueError( + "Invalid value for argument `output_mode`. " + "Expected one of {'dict', 'concat'}. " + f"Received: output_mode={output_mode}" + ) + self.output_mode = output_mode + + self.inputs = { + name: self._feature_to_input(name, value) + for name, value in self.features.items() + } + self.preprocessors = { + name: value.preprocessor for name, value in self.features.items() + } + self.encoded_features = None + self.crossers = { + cross.name: self._cross_to_crosser(cross) for cross in self.crosses + } + self.one_hot_encoders = {} + self._is_adapted = False + self.concat = None + self._preprocessed_features_names = None + self._crossed_features_names = None + self._sublayers_built = False + + def _feature_to_input(self, name, feature): + return layers.Input(shape=(1,), dtype=feature.dtype, name=name) + + def _standardize_feature(self, name, feature): + if isinstance(feature, Feature): + return feature + + if isinstance(feature, dict): + return serialization_lib.deserialize_keras_object(feature) + + if feature == "float": + return self.float(name=name) + elif feature == "float_normalized": + return self.float_normalized(name=name) + elif feature == "float_rescaled": + return self.float_rescaled(name=name) + elif feature == "float_discretized": + return self.float_discretized( + name=name, num_bins=self.num_discretization_bins + ) + elif feature == "integer_categorical": + return self.integer_categorical(name=name) + elif feature == "string_categorical": + return self.string_categorical(name=name) + elif feature == "integer_hashed": + return self.integer_hashed(self.hashing_dim, name=name) + elif feature == "string_hashed": + return self.string_hashed(self.hashing_dim, name=name) + else: + raise ValueError(f"Invalid feature type: {feature}") + + def _cross_to_crosser(self, cross): + return layers.HashedCrossing(cross.crossing_dim, name=cross.name) + + def _list_adaptable_preprocessors(self): + adaptable_preprocessors = [] + for name in self.features.keys(): + preprocessor = self.preprocessors[name] + # Special case: a Normalization layer with preset mean/variance. + # Not adaptable. + if isinstance(preprocessor, layers.Normalization): + if preprocessor.input_mean is not None: + continue + # Special case: a TextVectorization layer with provided vocabulary. + elif isinstance(preprocessor, layers.TextVectorization): + if preprocessor._has_input_vocabulary: + continue + if hasattr(preprocessor, "adapt"): + adaptable_preprocessors.append(name) + return adaptable_preprocessors + + def adapt(self, dataset): + if not isinstance(dataset, tf.data.Dataset): + raise ValueError( + "`adapt()` can only be called on a tf.data.Dataset. " + f"Received instead: {dataset} (of type {type(dataset)})" + ) + + for name in self._list_adaptable_preprocessors(): + # Call adapt() on each individual adaptable layer. + + # TODO: consider rewriting this to instead iterate on the + # dataset once, split each batch into individual features, + # and call the layer's `_adapt_function` on each batch + # to simulate the behavior of adapt() in a more performant fashion. + + feature_dataset = dataset.map(lambda x: x[name]) + preprocessor = self.preprocessors[name] + # TODO: consider adding an adapt progress bar. + # Sample 1 element to check the rank + x = next(iter(feature_dataset)) + if len(x.shape) == 0: + # The dataset yields unbatched scalars; batch it. + feature_dataset = feature_dataset.batch(32) + if len(x.shape) in {0, 1}: + # If the rank is 1, add a dimension + # so we can reduce on axis=-1. + # Note: if rank was previously 0, it is now 1. + feature_dataset = feature_dataset.map( + lambda x: tf.expand_dims(x, -1) + ) + preprocessor.adapt(feature_dataset) + self._is_adapted = True + self.get_encoded_features() # Finish building the layer + self.built = True + self._sublayers_built = True + + def get_inputs(self): + self._check_if_built() + return self.inputs + + def get_encoded_features(self): + self._check_if_adapted() + + if self.encoded_features is None: + preprocessed_features = self._preprocess_features(self.inputs) + crossed_features = self._cross_features(preprocessed_features) + merged_features = self._merge_features( + preprocessed_features, crossed_features + ) + self.encoded_features = merged_features + return self.encoded_features + + def _preprocess_features(self, features): + return { + name: self.preprocessors[name](features[name]) + for name in features.keys() + } + + def _cross_features(self, features): + all_outputs = {} + for cross in self.crosses: + inputs = [features[name] for name in cross.feature_names] + outputs = self.crossers[cross.name](inputs) + all_outputs[cross.name] = outputs + return all_outputs + + def _merge_features(self, preprocessed_features, crossed_features): + if not self._preprocessed_features_names: + self._preprocessed_features_names = sorted( + preprocessed_features.keys() + ) + self._crossed_features_names = sorted(crossed_features.keys()) + + all_names = ( + self._preprocessed_features_names + self._crossed_features_names + ) + all_features = [ + preprocessed_features[name] + for name in self._preprocessed_features_names + ] + [crossed_features[name] for name in self._crossed_features_names] + + if self.output_mode == "dict": + output_dict = {} + else: + features_to_concat = [] + + if self._sublayers_built: + # Fast mode. + for name, feature in zip(all_names, all_features): + encoder = self.one_hot_encoders.get(name, None) + if encoder: + feature = encoder(feature) + if self.output_mode == "dict": + output_dict[name] = feature + else: + features_to_concat.append(feature) + if self.output_mode == "dict": + return output_dict + else: + return self.concat(features_to_concat) + + # If the object isn't built, + # we create the encoder and concat layers below + all_specs = [ + self.features[name] for name in self._preprocessed_features_names + ] + [ + self.crosses_by_name[name] for name in self._crossed_features_names + ] + + for name, feature, spec in zip(all_names, all_features, all_specs): + if tree.is_nested(feature): + dtype = tree.flatten(feature)[0].dtype + else: + dtype = feature.dtype + dtype = backend.standardize_dtype(dtype) + + if spec.output_mode == "one_hot": + preprocessor = self.preprocessors.get( + name + ) or self.crossers.get(name) + + cardinality = None + if not dtype.startswith("int"): + raise ValueError( + f"Feature '{name}' has `output_mode='one_hot'`. " + "Thus its preprocessor should return an integer dtype. " + f"Instead it returns a {dtype} dtype." + ) + + if isinstance( + preprocessor, (layers.IntegerLookup, layers.StringLookup) + ): + cardinality = preprocessor.vocabulary_size() + elif isinstance(preprocessor, layers.CategoryEncoding): + cardinality = preprocessor.num_tokens + elif isinstance(preprocessor, layers.Discretization): + cardinality = preprocessor.num_bins + elif isinstance( + preprocessor, (layers.HashedCrossing, layers.Hashing) + ): + cardinality = preprocessor.num_bins + else: + raise ValueError( + f"Feature '{name}' has `output_mode='one_hot'`. " + "However it isn't a standard feature and the " + "dimensionality of its output space is not known, " + "thus it cannot be one-hot encoded. " + "Try using `output_mode='int'`." + ) + if cardinality is not None: + encoder = layers.CategoryEncoding( + num_tokens=cardinality, output_mode="multi_hot" + ) + self.one_hot_encoders[name] = encoder + feature = encoder(feature) + + if self.output_mode == "concat": + dtype = feature.dtype + if dtype.startswith("int") or dtype == "string": + raise ValueError( + f"Cannot concatenate features because feature '{name}' " + f"has not been encoded (it has dtype {dtype}). " + "Consider using `output_mode='dict'`." + ) + features_to_concat.append(feature) + else: + output_dict[name] = feature + + if self.output_mode == "concat": + self.concat = TFDConcat(axis=-1) + return self.concat(features_to_concat) + else: + return output_dict + + def _check_if_adapted(self): + if not self._is_adapted: + if not self._list_adaptable_preprocessors(): + self._is_adapted = True + else: + raise ValueError( + "You need to call `.adapt(dataset)` on the FeatureSpace " + "before you can start using it." + ) + + def _check_if_built(self): + if not self._sublayers_built: + self._check_if_adapted() + # Finishes building + self.get_encoded_features() + self._sublayers_built = True + + def _convert_input(self, x): + if not isinstance(x, (tf.Tensor, tf.SparseTensor, tf.RaggedTensor)): + if not isinstance(x, (list, tuple, int, float)): + x = backend.convert_to_numpy(x) + x = tf.convert_to_tensor(x) + return x + + def __call__(self, data): + self._check_if_built() + if not isinstance(data, dict): + raise ValueError( + "A FeatureSpace can only be called with a dict. " + f"Received: data={data} (of type {type(data)}" + ) + + # Many preprocessing layers support all backends but many do not. + # Switch to TF to make FeatureSpace work universally. + data = {key: self._convert_input(value) for key, value in data.items()} + rebatched = False + for name, x in data.items(): + if len(x.shape) == 0: + data[name] = tf.reshape(x, (1, 1)) + rebatched = True + elif len(x.shape) == 1: + data[name] = tf.expand_dims(x, -1) + + with backend_utils.TFGraphScope(): + # This scope is to make sure that inner DataLayers + # will not convert outputs back to backend-native -- + # they should be TF tensors throughout + preprocessed_data = self._preprocess_features(data) + preprocessed_data = tree.map_structure( + lambda x: self._convert_input(x), preprocessed_data + ) + + crossed_data = self._cross_features(preprocessed_data) + crossed_data = tree.map_structure( + lambda x: self._convert_input(x), crossed_data + ) + + merged_data = self._merge_features(preprocessed_data, crossed_data) + + if rebatched: + if self.output_mode == "concat": + assert merged_data.shape[0] == 1 + if ( + backend.backend() != "tensorflow" + and not backend_utils.in_tf_graph() + ): + merged_data = backend.convert_to_numpy(merged_data) + merged_data = tf.squeeze(merged_data, axis=0) + else: + for name, x in merged_data.items(): + if len(x.shape) == 2 and x.shape[0] == 1: + merged_data[name] = tf.squeeze(x, axis=0) + + if ( + backend.backend() != "tensorflow" + and not backend_utils.in_tf_graph() + ): + merged_data = tree.map_structure( + lambda x: backend.convert_to_tensor(x, dtype=x.dtype), + merged_data, + ) + return merged_data + + def get_config(self): + return { + "features": serialization_lib.serialize_keras_object(self.features), + "output_mode": self.output_mode, + "crosses": serialization_lib.serialize_keras_object(self.crosses), + "crossing_dim": self.crossing_dim, + "hashing_dim": self.hashing_dim, + "num_discretization_bins": self.num_discretization_bins, + } + + @classmethod + def from_config(cls, config): + return cls(**config) + + def get_build_config(self): + return { + name: feature.preprocessor.get_build_config() + for name, feature in self.features.items() + } + + def build_from_config(self, config): + for name in config.keys(): + preprocessor = self.features[name].preprocessor + if not preprocessor.built: + preprocessor.build_from_config(config[name]) + self._is_adapted = True + + def save(self, filepath): + """Save the `FeatureSpace` instance to a `.keras` file. + + You can reload it via `keras.models.load_model()`: + + ```python + feature_space.save("featurespace.keras") + reloaded_fs = keras.models.load_model("featurespace.keras") + ``` + """ + saving_lib.save_model(self, filepath) + + def save_own_variables(self, store): + return + + def load_own_variables(self, store): + return + + +class TFDConcat(DataLayer): + def __init__(self, axis, **kwargs): + super().__init__(**kwargs) + self.axis = axis + + def call(self, xs): + return self.backend.numpy.concatenate(xs, axis=self.axis) + + +class TFDIdentity(DataLayer): + def call(self, x): + return x diff --git a/keras/src/layers/preprocessing/feature_space_test.py b/keras/src/layers/preprocessing/feature_space_test.py new file mode 100644 index 000000000000..a1efe3821b0f --- /dev/null +++ b/keras/src/layers/preprocessing/feature_space_test.py @@ -0,0 +1,640 @@ +import os + +import pytest +from tensorflow import data as tf_data + +from keras.src import backend +from keras.src import layers +from keras.src import models +from keras.src import ops +from keras.src import testing +from keras.src.layers.preprocessing import feature_space +from keras.src.saving import saving_api + + +class FeatureSpaceTest(testing.TestCase): + def _get_train_data_dict( + self, + as_dataset=False, + as_tensors=False, + as_labeled_dataset=False, + include_strings=True, + ): + data = { + "float_1": [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], + "float_2": [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], + "float_3": [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], + "int_1": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + "int_2": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + "int_3": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + } + if include_strings: + data["string_1"] = [ + "0", + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + ] + data["string_2"] = [ + "0", + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + ] + + if as_dataset: + return tf_data.Dataset.from_tensor_slices(data) + elif as_tensors: + return { + key: ops.convert_to_tensor(value) for key, value in data.items() + } + elif as_labeled_dataset: + labels = [0, 1, 0, 1, 0, 0, 1, 0, 1, 1] + return tf_data.Dataset.from_tensor_slices((data, labels)) + return data + + def test_basic_usage_no_strings(self): + fs = feature_space.FeatureSpace( + features={ + "float_1": "float", + "float_2": "float_normalized", + "float_3": "float_discretized", + "int_1": "integer_categorical", + "int_2": "integer_hashed", + "int_3": "integer_categorical", + }, + crosses=[("int_1", "int_2"), ("int_2", "int_3")], + output_mode="concat", + ) + # Test unbatched adapt + fs.adapt( + self._get_train_data_dict(as_dataset=True, include_strings=False) + ) + # Test batched adapt + fs.adapt( + self._get_train_data_dict( + as_dataset=True, include_strings=False + ).batch(4) + ) + + # Test unbatched call on raw data + data = { + key: value[0] + for key, value in self._get_train_data_dict( + include_strings=False + ).items() + } + out = fs(data) + out_dim = 152 + self.assertEqual(out.shape, (out_dim,)) + + # Test unbatched call on backend tensors + data = self._get_train_data_dict(as_tensors=True, include_strings=False) + data = {key: value[0] for key, value in data.items()} + out = fs(data) + self.assertEqual(out.shape, (out_dim,)) + + # Test batched call on raw data + out = fs(self._get_train_data_dict(include_strings=False)) + self.assertEqual(out.shape, (10, out_dim)) + + # Test batched call on backend tensors + out = fs( + self._get_train_data_dict(as_tensors=True, include_strings=False) + ) + self.assertEqual(out.shape, (10, out_dim)) + + def test_output_mode_dict_no_strings(self): + fs = feature_space.FeatureSpace( + features={ + "float_1": "float", + "float_2": "float_normalized", + "float_3": "float_discretized", + "int_1": "integer_categorical", + "int_2": "integer_hashed", + "int_3": "integer_categorical", + }, + crosses=[("int_1", "int_2")], + output_mode="dict", + ) + fs.adapt( + self._get_train_data_dict(as_dataset=True, include_strings=False) + ) + + # Test unbatched call on raw data + data = { + key: value[0] + for key, value in self._get_train_data_dict( + include_strings=False + ).items() + } + out = fs(data) + self.assertIsInstance(out, dict) + self.assertLen(out, 7) + self.assertEqual(out["int_2"].shape, (32,)) + self.assertEqual(out["int_1_X_int_2"].shape, (32,)) + + # Test batched call on raw data + out = fs(self._get_train_data_dict(include_strings=False)) + self.assertIsInstance(out, dict) + self.assertLen(out, 7) + self.assertEqual(out["int_2"].shape, (10, 32)) + + # Test batched call on backend tensors + out = fs( + self._get_train_data_dict(as_tensors=True, include_strings=False) + ) + self.assertIsInstance(out, dict) + self.assertLen(out, 7) + self.assertEqual(out["int_2"].shape, (10, 32)) + + def test_output_mode_dict_of_ints_no_strings(self): + cls = feature_space.FeatureSpace + fs = feature_space.FeatureSpace( + features={ + "float_1": "float", + "float_2": "float_normalized", + "float_3": "float_discretized", + "int_1": cls.integer_categorical(output_mode="int"), + "int_2": cls.integer_hashed(num_bins=32, output_mode="int"), + "int_3": cls.integer_categorical(output_mode="int"), + }, + crosses=[ + cls.cross( + ("int_1", "int_2"), output_mode="int", crossing_dim=32 + ), + ], + output_mode="dict", + ) + fs.adapt( + self._get_train_data_dict(as_dataset=True, include_strings=False) + ) + data = { + key: value[0] + for key, value in self._get_train_data_dict( + include_strings=False + ).items() + } + out = fs(data) + self.assertIsInstance(out, dict) + self.assertLen(out, 7) + self.assertEqual(out["int_2"].shape, (1,)) + self.assertTrue( + backend.standardize_dtype(out["int_2"].dtype).startswith("int") + ) + self.assertEqual(out["int_1_X_int_2"].shape, (1,)) + self.assertTrue( + backend.standardize_dtype(out["int_1_X_int_2"].dtype).startswith( + "int" + ) + ) + + def test_basic_usage(self): + fs = feature_space.FeatureSpace( + features={ + "float_1": "float", + "float_2": "float_normalized", + "float_3": "float_discretized", + "string_1": "string_categorical", + "string_2": "string_hashed", + "int_1": "integer_categorical", + "int_2": "integer_hashed", + "int_3": "integer_categorical", + }, + crosses=[("float_3", "string_1"), ("string_2", "int_2")], + output_mode="concat", + ) + # Test unbatched adapt + fs.adapt(self._get_train_data_dict(as_dataset=True)) + # Test batched adapt + fs.adapt(self._get_train_data_dict(as_dataset=True).batch(4)) + + # Test unbatched call on raw data + data = { + key: value[0] for key, value in self._get_train_data_dict().items() + } + out = fs(data) + out_dim = 195 + self.assertEqual(out.shape, (out_dim,)) + + # Test unbatched call on tensors + if backend.backend() == "tensorflow": + data = self._get_train_data_dict(as_tensors=True) + data = {key: value[0] for key, value in data.items()} + out = fs(data) + self.assertEqual(out.shape, (out_dim,)) + + # Test batched call on raw data + out = fs(self._get_train_data_dict()) + self.assertEqual(out.shape, (10, out_dim)) + + # Test batched call on tensors + if backend.backend() == "tensorflow": + out = fs(self._get_train_data_dict(as_tensors=True)) + self.assertEqual(out.shape, (10, out_dim)) + + def test_output_mode_dict(self): + fs = feature_space.FeatureSpace( + features={ + "float_1": "float", + "float_2": "float_normalized", + "float_3": "float_discretized", + "string_1": "string_categorical", + "string_2": "string_hashed", + "int_1": "integer_categorical", + "int_2": "integer_hashed", + "int_3": "integer_categorical", + }, + crosses=[("float_3", "string_1"), ("string_2", "int_2")], + output_mode="dict", + ) + fs.adapt(self._get_train_data_dict(as_dataset=True)) + + # Test unbatched call on raw data + data = { + key: value[0] for key, value in self._get_train_data_dict().items() + } + out = fs(data) + self.assertIsInstance(out, dict) + self.assertLen(out, 10) + self.assertEqual(out["string_1"].shape, (11,)) + self.assertEqual(out["int_2"].shape, (32,)) + self.assertEqual(out["string_2_X_int_2"].shape, (32,)) + + # Test batched call on raw data + out = fs(self._get_train_data_dict()) + self.assertIsInstance(out, dict) + self.assertLen(out, 10) + self.assertEqual(out["string_1"].shape, (10, 11)) + self.assertEqual(out["int_2"].shape, (10, 32)) + self.assertEqual(out["string_2_X_int_2"].shape, (10, 32)) + + # Test batched call on tensors + if backend.backend() == "tensorflow": + out = fs(self._get_train_data_dict(as_tensors=True)) + self.assertIsInstance(out, dict) + self.assertLen(out, 10) + self.assertEqual(out["string_1"].shape, (10, 11)) + self.assertEqual(out["int_2"].shape, (10, 32)) + self.assertEqual(out["string_2_X_int_2"].shape, (10, 32)) + + def test_output_mode_dict_of_ints(self): + cls = feature_space.FeatureSpace + fs = feature_space.FeatureSpace( + features={ + "float_1": "float", + "float_2": "float_normalized", + "float_3": "float_discretized", + "string_1": cls.string_categorical(output_mode="int"), + "string_2": cls.string_hashed(num_bins=32, output_mode="int"), + "int_1": cls.integer_categorical(output_mode="int"), + "int_2": cls.integer_hashed(num_bins=32, output_mode="int"), + "int_3": cls.integer_categorical(output_mode="int"), + }, + crosses=[ + cls.cross( + ("float_3", "string_1"), output_mode="int", crossing_dim=32 + ), + cls.cross( + ("string_2", "int_2"), output_mode="int", crossing_dim=32 + ), + ], + output_mode="dict", + ) + fs.adapt(self._get_train_data_dict(as_dataset=True)) + data = { + key: value[0] for key, value in self._get_train_data_dict().items() + } + out = fs(data) + self.assertIsInstance(out, dict) + self.assertLen(out, 10) + self.assertEqual(out["string_1"].shape, (1,)) + self.assertTrue( + backend.standardize_dtype(out["string_1"].dtype).startswith("int") + ) + self.assertEqual(out["int_2"].shape, (1,)) + self.assertTrue( + backend.standardize_dtype(out["int_2"].dtype).startswith("int") + ) + self.assertEqual(out["string_2_X_int_2"].shape, (1,)) + self.assertTrue( + backend.standardize_dtype(out["string_2_X_int_2"].dtype).startswith( + "int" + ) + ) + + @pytest.mark.skipif( + backend.backend() != "tensorflow", reason="Requires string dtype." + ) + def test_functional_api_sync_processing(self): + fs = feature_space.FeatureSpace( + features={ + "float_1": "float", + "float_2": "float_normalized", + "float_3": "float_discretized", + "string_1": "string_categorical", + "string_2": "string_hashed", + "int_1": "integer_categorical", + "int_2": "integer_hashed", + "int_3": "integer_categorical", + }, + crosses=[("float_3", "string_1"), ("string_2", "int_2")], + output_mode="concat", + ) + fs.adapt(self._get_train_data_dict(as_dataset=True)) + inputs = fs.get_inputs() + features = fs.get_encoded_features() + outputs = layers.Dense(1)(features) + model = models.Model(inputs=inputs, outputs=outputs) + model.compile("adam", "mse") + ds = self._get_train_data_dict(as_labeled_dataset=True) + model.fit(ds.batch(4)) + model.evaluate(ds.batch(4)) + ds = self._get_train_data_dict(as_dataset=True) + model.predict(ds.batch(4)) + + @pytest.mark.requires_trainable_backend + def test_tf_data_async_processing(self): + fs = feature_space.FeatureSpace( + features={ + "float_1": "float", + "float_2": "float_normalized", + "float_3": "float_discretized", + "int_1": "integer_categorical", + "int_2": "integer_hashed", + "int_3": "integer_categorical", + }, + crosses=[("float_3", "int_1"), ("int_1", "int_2")], + output_mode="concat", + ) + fs.adapt( + self._get_train_data_dict(as_dataset=True, include_strings=False) + ) + features = fs.get_encoded_features() + outputs = layers.Dense(1)(features) + model = models.Model(inputs=features, outputs=outputs) + model.compile("adam", "mse") + ds = self._get_train_data_dict( + as_labeled_dataset=True, include_strings=False + ) + # Try map before batch + ds = ds.map(lambda x, y: (fs(x), y)) + model.fit(ds.batch(4)) + # Try map after batch + ds = self._get_train_data_dict( + as_labeled_dataset=True, include_strings=False + ) + ds = ds.batch(4) + ds = ds.map(lambda x, y: (fs(x), y)) + model.evaluate(ds) + ds = self._get_train_data_dict(as_dataset=True, include_strings=False) + ds = ds.map(fs) + model.predict(ds.batch(4)) + + def test_advanced_usage(self): + cls = feature_space.FeatureSpace + fs = feature_space.FeatureSpace( + features={ + "float_1": cls.float(), + "float_2": cls.float_normalized(), + "float_3": cls.float_discretized(num_bins=3), + "string_1": cls.string_categorical(max_tokens=5), + "string_2": cls.string_hashed(num_bins=32), + "int_1": cls.integer_categorical( + max_tokens=5, num_oov_indices=2 + ), + "int_2": cls.integer_hashed(num_bins=32), + "int_3": cls.integer_categorical(max_tokens=5), + }, + crosses=[ + cls.cross(("float_3", "string_1"), crossing_dim=32), + cls.cross(("string_2", "int_2"), crossing_dim=32), + ], + output_mode="concat", + ) + fs.adapt(self._get_train_data_dict(as_dataset=True)) + data = { + key: value[0] for key, value in self._get_train_data_dict().items() + } + out = fs(data) + self.assertEqual(out.shape, (148,)) + + def test_manual_kpl(self): + data = { + "text": ["1st string", "2nd string", "3rd string"], + } + cls = feature_space.FeatureSpace + + # Test with a tf-idf TextVectorization layer + tv = layers.TextVectorization(output_mode="tf_idf") + fs = feature_space.FeatureSpace( + features={ + "text": cls.feature( + preprocessor=tv, dtype="string", output_mode="float" + ), + }, + output_mode="concat", + ) + fs.adapt(tf_data.Dataset.from_tensor_slices(data)) + out = fs(data) + self.assertEqual(list(out.shape), [3, 5]) + + def test_no_adapt(self): + data = { + "int_1": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + "text_1": [ + "This is", + "not just", + "an example", + "of random words.", + "these are", + "some words", + "in", + "a random", + "example.", + "Bye!", + ], + "float_1": [ + -1.2, + 0.0, + 2.4, + 1.2, + 15.0, + -100.0, + 23.1, + 3.12, + 0.1, + -0.01, + ], + } + cls = feature_space.FeatureSpace + # Pre-defined vocabulary. No need to adapt. + tv_vocab = [ + "this", + "is", + "just", + "an", + "example", + "with", + "some", + "words", + ] + tv_with_vocab = layers.TextVectorization( + vocabulary=tv_vocab, output_mode="int", output_sequence_length=3 + ) + + # Pre-defined mean and variance. No need to adapt. + mean, variance = 12.0, 5.0 + normalization = layers.Normalization(mean=mean, variance=variance) + + fs = feature_space.FeatureSpace( + { + "int_1": "integer_hashed", + "text_1": cls.feature( + dtype="string", + preprocessor=tv_with_vocab, + output_mode="int", + ), + "float_1": cls.feature( + dtype="float32", + preprocessor=normalization, + output_mode="float", + ), + }, + output_mode="dict", + ) + + out = fs(data) + float_out = ops.divide( + ops.convert_to_tensor(data["float_1"]) - mean, ops.sqrt(variance) + ) + float_out = ops.reshape(float_out, (10, -1)) + + self.assertEqual(tuple(out["int_1"].shape), (10, 32)) + self.assertEqual(tuple(out["text_1"].shape), (10, 3)) + self.assertAllClose(out["float_1"], float_out, atol=1e-3) + + @pytest.mark.skipif( + backend.backend() in ("numpy", "torch"), + reason=( + "TODO: When using FeatureSpace as a Model in torch and numpy, " + "the error is large." + ), + ) + def test_saving(self): + cls = feature_space.FeatureSpace + fs = feature_space.FeatureSpace( + features={ + "float_1": cls.float(), + "float_2": cls.float_normalized(), + "float_3": cls.float_discretized(num_bins=3), + "int_1": cls.integer_categorical( + max_tokens=5, num_oov_indices=2 + ), + "int_2": cls.integer_hashed(num_bins=32), + "int_3": cls.integer_categorical(max_tokens=5), + }, + crosses=[ + cls.cross(("float_3", "int_1"), crossing_dim=32), + cls.cross(("int_1", "int_2"), crossing_dim=32), + ], + output_mode="concat", + ) + fs.adapt( + self._get_train_data_dict(as_dataset=True, include_strings=False) + ) + data = { + key: value[0] + for key, value in self._get_train_data_dict( + include_strings=False + ).items() + } + ref_out = fs(data) + + temp_filepath = os.path.join(self.get_temp_dir(), "fs.keras") + fs.save(temp_filepath) + fs = saving_api.load_model(temp_filepath) + + # Save again immediately after loading to test idempotency + temp_filepath = os.path.join(self.get_temp_dir(), "fs2.keras") + fs.save(temp_filepath) + + # Test correctness of the first saved FS + out = fs(data) + self.assertAllClose(out, ref_out) + + inputs = fs.get_inputs() + outputs = fs.get_encoded_features() + model = models.Model(inputs=inputs, outputs=outputs) + ds = self._get_train_data_dict(as_dataset=True, include_strings=False) + out = model.predict(ds.batch(4)) + self.assertAllClose(out[0], ref_out) + + # Test correctness of the re-saved FS + fs = saving_api.load_model(temp_filepath) + out = fs(data) + self.assertAllClose(out, ref_out) + + def test_errors(self): + # Test no features + with self.assertRaisesRegex(ValueError, "cannot be None or empty"): + feature_space.FeatureSpace(features={}) + # Test no crossing dim + with self.assertRaisesRegex(ValueError, "`crossing_dim`"): + feature_space.FeatureSpace( + features={ + "f1": "integer_categorical", + "f2": "integer_categorical", + }, + crosses=[("f1", "f2")], + crossing_dim=None, + ) + # Test wrong cross feature name + with self.assertRaisesRegex(ValueError, "should be present in "): + feature_space.FeatureSpace( + features={ + "f1": "integer_categorical", + "f2": "integer_categorical", + }, + crosses=[("f1", "unknown")], + crossing_dim=32, + ) + # Test wrong output mode + with self.assertRaisesRegex(ValueError, "for argument `output_mode`"): + feature_space.FeatureSpace( + features={ + "f1": "integer_categorical", + "f2": "integer_categorical", + }, + output_mode="unknown", + ) + # Test call before adapt + with self.assertRaisesRegex(ValueError, "You need to call `.adapt"): + fs = feature_space.FeatureSpace( + features={ + "f1": "integer_categorical", + "f2": "integer_categorical", + } + ) + fs({"f1": [0], "f2": [0]}) + # Test get_encoded_features before adapt + with self.assertRaisesRegex(ValueError, "You need to call `.adapt"): + fs = feature_space.FeatureSpace( + features={ + "f1": "integer_categorical", + "f2": "integer_categorical", + } + ) + fs.get_encoded_features() diff --git a/keras/src/layers/preprocessing/hashed_crossing.py b/keras/src/layers/preprocessing/hashed_crossing.py new file mode 100644 index 000000000000..9a794e4beea7 --- /dev/null +++ b/keras/src/layers/preprocessing/hashed_crossing.py @@ -0,0 +1,227 @@ +from keras.src import backend +from keras.src.api_export import keras_export +from keras.src.layers.layer import Layer +from keras.src.utils import argument_validation +from keras.src.utils import backend_utils +from keras.src.utils import numerical_utils +from keras.src.utils import tf_utils +from keras.src.utils.module_utils import tensorflow as tf + + +@keras_export("keras.layers.HashedCrossing") +class HashedCrossing(Layer): + """A preprocessing layer which crosses features using the "hashing trick". + + This layer performs crosses of categorical features using the "hashing + trick". Conceptually, the transformation can be thought of as: + `hash(concatenate(features)) % num_bins`. + + This layer currently only performs crosses of scalar inputs and batches of + scalar inputs. Valid input shapes are `(batch_size, 1)`, `(batch_size,)` and + `()`. + + **Note:** This layer wraps `tf.keras.layers.HashedCrossing`. It cannot + be used as part of the compiled computation graph of a model with + any backend other than TensorFlow. + It can however be used with any backend when running eagerly. + It can also always be used as part of an input preprocessing pipeline + with any backend (outside the model itself), which is how we recommend + to use this layer. + + **Note:** This layer is safe to use inside a `tf.data` pipeline + (independently of which backend you're using). + + Args: + num_bins: Number of hash bins. + output_mode: Specification for the output of the layer. Values can be + `"int"`, or `"one_hot"` configuring the layer as follows: + - `"int"`: Return the integer bin indices directly. + - `"one_hot"`: Encodes each individual element in the input into an + array the same size as `num_bins`, containing a 1 at the input's + bin index. Defaults to `"int"`. + sparse: Boolean. Only applicable to `"one_hot"` mode and only valid + when using the TensorFlow backend. If `True`, returns + a `SparseTensor` instead of a dense `Tensor`. Defaults to `False`. + **kwargs: Keyword arguments to construct a layer. + + Examples: + + **Crossing two scalar features.** + + >>> layer = keras.layers.HashedCrossing( + ... num_bins=5) + >>> feat1 = np.array(['A', 'B', 'A', 'B', 'A']) + >>> feat2 = np.array([101, 101, 101, 102, 102]) + >>> layer((feat1, feat2)) + array([1, 4, 1, 1, 3]) + + **Crossing and one-hotting two scalar features.** + + >>> layer = keras.layers.HashedCrossing( + ... num_bins=5, output_mode='one_hot') + >>> feat1 = np.array(['A', 'B', 'A', 'B', 'A']) + >>> feat2 = np.array([101, 101, 101, 102, 102]) + >>> layer((feat1, feat2)) + array([[0., 1., 0., 0., 0.], + [0., 0., 0., 0., 1.], + [0., 1., 0., 0., 0.], + [0., 1., 0., 0., 0.], + [0., 0., 0., 1., 0.]], dtype=float32) + """ + + def __init__( + self, + num_bins, + output_mode="int", + sparse=False, + name=None, + dtype=None, + **kwargs, + ): + if not tf.available: + raise ImportError( + "Layer HashedCrossing requires TensorFlow. " + "Install it via `pip install tensorflow`." + ) + + if output_mode == "int" and dtype is None: + dtype = "int64" + + super().__init__(name=name, dtype=dtype) + if sparse and backend.backend() != "tensorflow": + raise ValueError( + "`sparse=True` can only be used with the TensorFlow backend." + ) + + argument_validation.validate_string_arg( + output_mode, + allowable_strings=("int", "one_hot"), + caller_name=self.__class__.__name__, + arg_name="output_mode", + ) + + self.num_bins = num_bins + self.output_mode = output_mode + self.sparse = sparse + self._allow_non_tensor_positional_args = True + self._convert_input_args = False + self.supports_jit = False + + def compute_output_shape(self, input_shape): + if ( + not len(input_shape) == 2 + or not isinstance(input_shape[0], tuple) + or not isinstance(input_shape[1], tuple) + ): + raise ValueError( + "Expected as input a list/tuple of 2 tensors. " + f"Received input_shape={input_shape}" + ) + if input_shape[0][-1] != input_shape[1][-1]: + raise ValueError( + "Expected the two input tensors to have identical shapes. " + f"Received input_shape={input_shape}" + ) + + if not input_shape: + if self.output_mode == "int": + return () + return (self.num_bins,) + if self.output_mode == "int": + return tuple(input_shape[0]) + + if self.output_mode == "one_hot" and input_shape[0][-1] != 1: + return tuple(input_shape[0]) + (self.num_bins,) + + return tuple(input_shape[0])[:-1] + (self.num_bins,) + + def call(self, inputs): + from keras.src.backend import tensorflow as tf_backend + + self._check_at_least_two_inputs(inputs) + inputs = [tf_utils.ensure_tensor(x) for x in inputs] + self._check_input_shape_and_type(inputs) + + # Uprank to rank 2 for the cross_hashed op. + first_shape = tuple(inputs[0].shape) + rank = len(first_shape) + if rank < 2: + inputs = [tf_backend.numpy.expand_dims(x, -1) for x in inputs] + if rank < 1: + inputs = [tf_backend.numpy.expand_dims(x, -1) for x in inputs] + + # Perform the cross and convert to dense + outputs = tf.sparse.cross_hashed(inputs, self.num_bins) + outputs = tf.sparse.to_dense(outputs) + + # tf.sparse.cross_hashed output shape will always have None dimensions. + # Re-apply the known static shape and downrank to match input rank. + if rank == 2: + outputs.set_shape(first_shape) + elif rank == 1: + outputs.set_shape(first_shape + (1,)) + outputs = tf.squeeze(outputs, axis=1) + elif rank == 0: + outputs = tf.reshape(outputs, []) + + # Encode outputs. + outputs = numerical_utils.encode_categorical_inputs( + outputs, + output_mode=self.output_mode, + depth=self.num_bins, + sparse=self.sparse, + dtype=self.compute_dtype, + backend_module=tf_backend, + ) + return backend_utils.convert_tf_tensor(outputs, dtype=self.dtype) + + def get_config(self): + return { + "num_bins": self.num_bins, + "output_mode": self.output_mode, + "sparse": self.sparse, + "name": self.name, + "dtype": self.dtype, + } + + def _check_at_least_two_inputs(self, inputs): + if not isinstance(inputs, (list, tuple)): + raise ValueError( + "`HashedCrossing` should be called on a list or tuple of " + f"inputs. Received: inputs={inputs}" + ) + if len(inputs) < 2: + raise ValueError( + "`HashedCrossing` should be called on at least two inputs. " + f"Received: inputs={inputs}" + ) + + def _check_input_shape_and_type(self, inputs): + first_shape = tuple(inputs[0].shape) + rank = len(first_shape) + if rank > 2 or (rank == 2 and first_shape[-1] != 1): + raise ValueError( + "All `HashedCrossing` inputs should have shape `()`, " + "`(batch_size)` or `(batch_size, 1)`. " + f"Received: inputs={inputs}" + ) + if not all(tuple(x.shape) == first_shape for x in inputs[1:]): + raise ValueError( + "All `HashedCrossing` inputs should have equal shape. " + f"Received: inputs={inputs}" + ) + if any( + isinstance(x, (tf.RaggedTensor, tf.SparseTensor)) for x in inputs + ): + raise ValueError( + "All `HashedCrossing` inputs should be dense tensors. " + f"Received: inputs={inputs}" + ) + if not all( + tf.as_dtype(x.dtype).is_integer or x.dtype == tf.string + for x in inputs + ): + raise ValueError( + "All `HashedCrossing` inputs should have an integer or " + f"string dtype. Received: inputs={inputs}" + ) diff --git a/keras/src/layers/preprocessing/hashed_crossing_test.py b/keras/src/layers/preprocessing/hashed_crossing_test.py new file mode 100644 index 000000000000..b8eed977a316 --- /dev/null +++ b/keras/src/layers/preprocessing/hashed_crossing_test.py @@ -0,0 +1,212 @@ +import numpy as np +import pytest +import tensorflow as tf +from absl.testing import parameterized + +from keras.src import backend +from keras.src import layers +from keras.src import testing +from keras.src.testing.test_utils import named_product + + +class HashedCrossingTest(testing.TestCase): + def test_basics(self): + self.run_layer_test( + layers.HashedCrossing, + init_kwargs={ + "num_bins": 3, + "output_mode": "int", + }, + input_data=(np.array([1, 2]), np.array([4, 5])), + expected_output_shape=(2,), + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=False, + run_training_check=False, + # Incomplete op support on tensorflow. + run_mixed_precision_check=False, + ) + self.run_layer_test( + layers.HashedCrossing, + init_kwargs={"num_bins": 4, "output_mode": "one_hot"}, + input_data=(np.array([1, 2]), np.array([4, 5])), + expected_output_shape=(2, 4), + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=False, + run_training_check=False, + # Incomplete op support on tensorflow. + run_mixed_precision_check=False, + ) + + @parameterized.named_parameters( + named_product( + sparse=( + [True, False] if backend.backend() == "tensorflow" else [False] + ) + ) + ) + def test_correctness(self, sparse): + layer = layers.HashedCrossing(num_bins=5) + feat1 = np.array(["A", "B", "A", "B", "A"]) + feat2 = np.array([101, 101, 101, 102, 102]) + output = layer((feat1, feat2)) + self.assertAllClose(tf.constant([1, 4, 1, 1, 3]), output) + + layer = layers.HashedCrossing( + num_bins=5, output_mode="one_hot", sparse=sparse + ) + feat1 = np.array(["A", "B", "A", "B", "A"]) + feat2 = np.array([101, 101, 101, 102, 102]) + output = layer((feat1, feat2)) + self.assertSparse(output, sparse) + self.assertAllClose( + np.array( + [ + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 0.0], + ] + ), + output, + ) + + def test_tf_data_compatibility(self): + layer = layers.HashedCrossing(num_bins=5) + feat1 = np.array(["A", "B", "A", "B", "A"]) + feat2 = np.array([101, 101, 101, 102, 102]) + ds = ( + tf.data.Dataset.from_tensor_slices((feat1, feat2)) + .batch(5) + .map(lambda x1, x2: layer((x1, x2))) + ) + output = next(iter(ds)).numpy() + self.assertAllClose(np.array([1, 4, 1, 1, 3]), output) + + def test_static_shape_preserved(self): + layer = layers.HashedCrossing(num_bins=5) + + def call_layer(x1, x2): + result = layer((x1, x2)) + self.assertEqual(result.shape, (5,)) + return result + + feat1 = np.array(["A", "B", "A", "B", "A"]) + feat2 = np.array([101, 101, 101, 102, 102]) + ds = ( + tf.data.Dataset.from_tensor_slices((feat1, feat2)) + .batch(5, drop_remainder=True) + .map(call_layer) + ) + next(iter(ds)) + + def test_unsupported_shape_input_fails(self): + with self.assertRaisesRegex(ValueError, "inputs should have shape"): + layers.HashedCrossing(num_bins=10)( + (np.array([[[1.0]]]), np.array([[[1.0]]])) + ) + + @pytest.mark.xfail + def test_cross_output_dtype(self): + input_1, input_2 = np.array([1]), np.array([1]) + + layer = layers.HashedCrossing(num_bins=2) + output_dtype = backend.standardize_dtype( + layer((input_1, input_2)).dtype + ) + self.assertEqual(output_dtype, "int64") + layer = layers.HashedCrossing(num_bins=2, dtype="int32") + output_dtype = backend.standardize_dtype( + layer((input_1, input_2)).dtype + ) + self.assertEqual(output_dtype, "int32") + layer = layers.HashedCrossing(num_bins=2, output_mode="one_hot") + output_dtype = backend.standardize_dtype( + layer((input_1, input_2)).dtype + ) + self.assertEqual(output_dtype, "float32") + layer = layers.HashedCrossing( + num_bins=2, output_mode="one_hot", dtype="float64" + ) + output_dtype = backend.standardize_dtype( + layer((input_1, input_2)).dtype + ) + self.assertEqual(output_dtype, "float64") + + def test_non_list_input_fails(self): + with self.assertRaisesRegex(ValueError, "should be called on a list"): + layers.HashedCrossing(num_bins=10)(np.array(1)) + + def test_single_input_fails(self): + with self.assertRaisesRegex(ValueError, "at least two inputs"): + layers.HashedCrossing(num_bins=10)([np.array(1)]) + + @pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="Need sparse tensor support.", + ) + def test_sparse_input_fails(self): + with self.assertRaisesRegex( + ValueError, "inputs should be dense tensors" + ): + sparse_in = tf.sparse.from_dense(np.array([1])) + layers.HashedCrossing(num_bins=10)((sparse_in, sparse_in)) + + def test_float_input_fails(self): + with self.assertRaisesRegex( + ValueError, "should have an integer or string" + ): + layers.HashedCrossing(num_bins=10)( + (np.array([1.0]), np.array([1.0])) + ) + + @pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="Need string tensor support.", + ) + def test_tf_string(self): + layer = layers.HashedCrossing(num_bins=10) + feat1 = tf.constant("A") + feat2 = tf.constant(101) + outputs = layer((feat1, feat2)) + self.assertAllClose(outputs, 1) + + layer = layers.HashedCrossing(num_bins=5, output_mode="one_hot") + feat1 = tf.constant(["A", "B", "A", "B", "A"]) + feat2 = tf.constant([101, 101, 101, 102, 102]) + self.assertAllClose( + tf.constant( + [ + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 0.0], + ] + ), + layer((feat1, feat2)), + ) + + layer = layers.HashedCrossing(num_bins=5) + feat1 = tf.constant(["A", "B", "A", "B", "A"]) + feat2 = tf.constant([101, 101, 101, 102, 102]) + self.assertAllClose(tf.constant([1, 4, 1, 1, 3]), layer((feat1, feat2))) + + layer = layers.HashedCrossing( + num_bins=5, output_mode="one_hot", sparse=True + ) + cloned_layer = layers.HashedCrossing.from_config(layer.get_config()) + feat1 = tf.constant([["A"], ["B"], ["A"], ["B"], ["A"]]) + feat2 = tf.constant([[101], [101], [101], [102], [102]]) + original_outputs = layer((feat1, feat2)) + cloned_outputs = cloned_layer((feat1, feat2)) + self.assertAllClose( + tf.sparse.to_dense(cloned_outputs), + tf.sparse.to_dense(original_outputs), + ) diff --git a/keras/src/layers/preprocessing/hashing.py b/keras/src/layers/preprocessing/hashing.py new file mode 100644 index 000000000000..395bfc673502 --- /dev/null +++ b/keras/src/layers/preprocessing/hashing.py @@ -0,0 +1,287 @@ +from keras.src import backend +from keras.src.api_export import keras_export +from keras.src.layers.layer import Layer +from keras.src.utils import backend_utils +from keras.src.utils import numerical_utils +from keras.src.utils import tf_utils +from keras.src.utils.module_utils import tensorflow as tf + + +@keras_export("keras.layers.Hashing") +class Hashing(Layer): + """A preprocessing layer which hashes and bins categorical features. + + This layer transforms categorical inputs to hashed output. It element-wise + converts a ints or strings to ints in a fixed range. The stable hash + function uses `tensorflow::ops::Fingerprint` to produce the same output + consistently across all platforms. + + This layer uses [FarmHash64](https://github.com/google/farmhash) by default, + which provides a consistent hashed output across different platforms and is + stable across invocations, regardless of device and context, by mixing the + input bits thoroughly. + + If you want to obfuscate the hashed output, you can also pass a random + `salt` argument in the constructor. In that case, the layer will use the + [SipHash64](https://github.com/google/highwayhash) hash function, with + the `salt` value serving as additional input to the hash function. + + **Note:** This layer internally uses TensorFlow. It cannot + be used as part of the compiled computation graph of a model with + any backend other than TensorFlow. + It can however be used with any backend when running eagerly. + It can also always be used as part of an input preprocessing pipeline + with any backend (outside the model itself), which is how we recommend + to use this layer. + + **Note:** This layer is safe to use inside a `tf.data` pipeline + (independently of which backend you're using). + + **Example (FarmHash64)** + + >>> layer = keras.layers.Hashing(num_bins=3) + >>> inp = [['A'], ['B'], ['C'], ['D'], ['E']] + >>> layer(inp) + array([[1], + [0], + [1], + [1], + [2]])> + + **Example (FarmHash64) with a mask value** + + >>> layer = keras.layers.Hashing(num_bins=3, mask_value='') + >>> inp = [['A'], ['B'], [''], ['C'], ['D']] + >>> layer(inp) + array([[1], + [1], + [0], + [2], + [2]]) + + **Example (SipHash64)** + + >>> layer = keras.layers.Hashing(num_bins=3, salt=[133, 137]) + >>> inp = [['A'], ['B'], ['C'], ['D'], ['E']] + >>> layer(inp) + array([[1], + [2], + [1], + [0], + [2]]) + + **Example (Siphash64 with a single integer, same as `salt=[133, 133]`)** + + >>> layer = keras.layers.Hashing(num_bins=3, salt=133) + >>> inp = [['A'], ['B'], ['C'], ['D'], ['E']] + >>> layer(inp) + array([[0], + [0], + [2], + [1], + [0]]) + + Args: + num_bins: Number of hash bins. Note that this includes the `mask_value` + bin, so the effective number of bins is `(num_bins - 1)` + if `mask_value` is set. + mask_value: A value that represents masked inputs, which are mapped to + index 0. `None` means no mask term will be added and the + hashing will start at index 0. Defaults to `None`. + salt: A single unsigned integer or None. + If passed, the hash function used will be SipHash64, + with these values used as an additional input + (known as a "salt" in cryptography). + These should be non-zero. If `None`, uses the FarmHash64 hash + function. It also supports tuple/list of 2 unsigned + integer numbers, see reference paper for details. + Defaults to `None`. + output_mode: Specification for the output of the layer. Values can be + `"int"`, `"one_hot"`, `"multi_hot"`, or + `"count"` configuring the layer as follows: + - `"int"`: Return the integer bin indices directly. + - `"one_hot"`: Encodes each individual element in the input into an + array the same size as `num_bins`, containing a 1 + at the input's bin index. If the last dimension is size 1, + will encode on that dimension. + If the last dimension is not size 1, will append a new + dimension for the encoded output. + - `"multi_hot"`: Encodes each sample in the input into a + single array the same size as `num_bins`, + containing a 1 for each bin index + index present in the sample. Treats the last dimension + as the sample dimension, if input shape is + `(..., sample_length)`, output shape will be + `(..., num_tokens)`. + - `"count"`: As `"multi_hot"`, but the int array contains a count of + the number of times the bin index appeared in the sample. + Defaults to `"int"`. + sparse: Boolean. Only applicable to `"one_hot"`, `"multi_hot"`, + and `"count"` output modes. Only supported with TensorFlow + backend. If `True`, returns a `SparseTensor` instead of + a dense `Tensor`. Defaults to `False`. + **kwargs: Keyword arguments to construct a layer. + + Input shape: + A single string, a list of strings, or an `int32` or `int64` tensor + of shape `(batch_size, ...,)`. + + Output shape: + An `int32` tensor of shape `(batch_size, ...)`. + + Reference: + + - [SipHash with salt](https://www.131002.net/siphash/siphash.pdf) + """ + + def __init__( + self, + num_bins, + mask_value=None, + salt=None, + output_mode="int", + sparse=False, + **kwargs, + ): + if not tf.available: + raise ImportError( + "Layer Hashing requires TensorFlow. " + "Install it via `pip install tensorflow`." + ) + + # By default, output int32 when output_mode='int' and floats otherwise. + if "dtype" not in kwargs or kwargs["dtype"] is None: + kwargs["dtype"] = ( + "int64" if output_mode == "int" else backend.floatx() + ) + + super().__init__(**kwargs) + + if num_bins is None or num_bins <= 0: + raise ValueError( + "The `num_bins` for `Hashing` cannot be `None` or " + f"non-positive values. Received: num_bins={num_bins}." + ) + + if output_mode == "int" and ( + self.dtype_policy.name not in ("int32", "int64") + ): + raise ValueError( + 'When `output_mode="int"`, `dtype` should be an integer ' + f"type, 'int32' or 'in64'. Received: dtype={kwargs['dtype']}" + ) + + # 'output_mode' must be one of (INT, ONE_HOT, MULTI_HOT, COUNT) + accepted_output_modes = ("int", "one_hot", "multi_hot", "count") + if output_mode not in accepted_output_modes: + raise ValueError( + "Invalid value for argument `output_mode`. " + f"Expected one of {accepted_output_modes}. " + f"Received: output_mode={output_mode}" + ) + + if sparse and output_mode == "int": + raise ValueError( + "`sparse` may only be true if `output_mode` is " + '`"one_hot"`, `"multi_hot"`, or `"count"`. ' + f"Received: sparse={sparse} and " + f"output_mode={output_mode}" + ) + + self.num_bins = num_bins + self.mask_value = mask_value + self.strong_hash = True if salt is not None else False + self.output_mode = output_mode + self.sparse = sparse + self.salt = None + if salt is not None: + if isinstance(salt, (tuple, list)) and len(salt) == 2: + self.salt = list(salt) + elif isinstance(salt, int): + self.salt = [salt, salt] + else: + raise ValueError( + "The `salt` argument for `Hashing` can only be a tuple of " + "size 2 integers, or a single integer. " + f"Received: salt={salt}." + ) + self._convert_input_args = False + self._allow_non_tensor_positional_args = True + self.supports_jit = False + + def call(self, inputs): + from keras.src.backend import tensorflow as tf_backend + + inputs = tf_utils.ensure_tensor(inputs) + if self.output_mode == "one_hot" and inputs.shape[-1] == 1: + # One hot only upranks if the final dimension is not 1. + inputs = tf_backend.numpy.squeeze(inputs, axis=-1) + if isinstance(inputs, tf.SparseTensor): + indices = tf.SparseTensor( + indices=inputs.indices, + values=self._hash_values_to_bins(inputs.values), + dense_shape=inputs.dense_shape, + ) + else: + indices = self._hash_values_to_bins(inputs) + outputs = numerical_utils.encode_categorical_inputs( + indices, + output_mode=self.output_mode, + depth=self.num_bins, + sparse=self.sparse, + dtype=self.dtype, + backend_module=tf_backend, + ) + return backend_utils.convert_tf_tensor(outputs) + + def _hash_values_to_bins(self, values): + """Converts a non-sparse tensor of values to bin indices.""" + hash_bins = self.num_bins + mask = None + # If mask_value is set, the zeroth bin is reserved for it. + if self.mask_value is not None and hash_bins > 1: + hash_bins -= 1 + mask = tf.equal(values, self.mask_value) + # Convert all values to strings before hashing. + # Floats are first normalized to int64. + if values.dtype.is_floating: + values = tf.cast(values, dtype="int64") + if values.dtype != tf.string: + values = tf.as_string(values) + # Hash the strings. + if self.strong_hash: + values = tf.strings.to_hash_bucket_strong( + values, hash_bins, name="hash", key=self.salt + ) + else: + values = tf.strings.to_hash_bucket_fast( + values, hash_bins, name="hash" + ) + if mask is not None: + values = tf.add(values, tf.ones_like(values)) + values = tf.where(mask, tf.zeros_like(values), values) + return values + + def compute_output_spec(self, inputs): + if self.output_mode == "int": + return backend.KerasTensor(shape=inputs.shape, dtype=self.dtype) + if len(inputs.shape) >= 1: + base_shape = tuple(inputs.shape)[:-1] + else: + base_shape = () + return backend.KerasTensor( + shape=base_shape + (self.num_bins,), dtype=self.dtype + ) + + def get_config(self): + config = super().get_config() + config.update( + { + "num_bins": self.num_bins, + "salt": self.salt, + "mask_value": self.mask_value, + "output_mode": self.output_mode, + "sparse": self.sparse, + } + ) + return config diff --git a/keras/src/layers/preprocessing/hashing_test.py b/keras/src/layers/preprocessing/hashing_test.py new file mode 100644 index 000000000000..3a7966f81617 --- /dev/null +++ b/keras/src/layers/preprocessing/hashing_test.py @@ -0,0 +1,533 @@ +import os + +import numpy as np +import pytest +import tensorflow as tf +from absl.testing import parameterized + +from keras.src import backend +from keras.src import layers +from keras.src import models +from keras.src import testing +from keras.src.saving import load_model + + +class ArrayLike: + def __init__(self, values): + self.values = values + + def __array__(self): + return np.array(self.values) + + +@pytest.mark.skipif( + backend.backend() == "numpy", reason="Broken with NumPy backend." +) +class HashingTest(testing.TestCase): + def test_config(self): + layer = layers.Hashing( + num_bins=8, + output_mode="int", + ) + self.run_class_serialization_test(layer) + + def test_correctness(self): + layer = layers.Hashing(num_bins=3) + inp = [["A"], ["B"], ["C"], ["D"], ["E"]] + output = layer(inp) + self.assertTrue(backend.is_tensor(output)) + self.assertAllClose(output, np.array([[1], [0], [1], [1], [2]])) + + layer = layers.Hashing(num_bins=3, mask_value="") + inp = [["A"], ["B"], [""], ["C"], ["D"]] + output = layer(inp) + self.assertTrue(backend.is_tensor(output)) + self.assertAllClose(output, np.array([[1], [1], [0], [2], [2]])) + + layer = layers.Hashing(num_bins=3, salt=[133, 137]) + inp = [["A"], ["B"], ["C"], ["D"], ["E"]] + output = layer(inp) + self.assertTrue(backend.is_tensor(output)) + self.assertAllClose(output, np.array([[1], [2], [1], [0], [2]])) + + layer = layers.Hashing(num_bins=3, salt=133) + inp = [["A"], ["B"], ["C"], ["D"], ["E"]] + output = layer(inp) + self.assertTrue(backend.is_tensor(output)) + self.assertAllClose(output, np.array([[0], [0], [2], [1], [0]])) + + def test_tf_data_compatibility(self): + layer = layers.Hashing(num_bins=3) + inp = [["A"], ["B"], ["C"], ["D"], ["E"]] + ds = tf.data.Dataset.from_tensor_slices(inp).batch(5).map(layer) + output = next(iter(ds)).numpy() + self.assertAllClose(output, np.array([[1], [0], [1], [1], [2]])) + + @parameterized.named_parameters( + ("list", list), + ("tuple", tuple), + ("numpy", np.array), + ("array_like", ArrayLike), + ) + def test_tensor_like_inputs(self, data_fn): + input_data = data_fn([0, 1, 2, 3, 4]) + expected_output = [1, 0, 1, 0, 2] + + layer = layers.Hashing(num_bins=3) + output_data = layer(input_data) + self.assertAllEqual(output_data, expected_output) + + def test_hash_single_bin(self): + layer = layers.Hashing(num_bins=1) + inp = np.asarray([["A"], ["B"], ["C"], ["D"], ["E"]]) + output = layer(inp) + self.assertAllClose([[0], [0], [0], [0], [0]], output) + + def test_hash_dense_input_farmhash(self): + layer = layers.Hashing(num_bins=2) + inp = np.asarray( + [["omar"], ["stringer"], ["marlo"], ["wire"], ["skywalker"]] + ) + output = layer(inp) + # Assert equal for hashed output that should be true on all platforms. + self.assertAllClose([[0], [0], [1], [0], [0]], output) + + def test_hash_dense_input_mask_value_farmhash(self): + empty_mask_layer = layers.Hashing(num_bins=3, mask_value="") + omar_mask_layer = layers.Hashing(num_bins=3, mask_value="omar") + inp = np.asarray( + [["omar"], ["stringer"], ["marlo"], ["wire"], ["skywalker"]] + ) + empty_mask_output = empty_mask_layer(inp) + omar_mask_output = omar_mask_layer(inp) + # Outputs should be one more than test_hash_dense_input_farmhash (the + # zeroth bin is now reserved for masks). + self.assertAllClose([[1], [1], [2], [1], [1]], empty_mask_output) + # 'omar' should map to 0. + self.assertAllClose([[0], [1], [2], [1], [1]], omar_mask_output) + + def test_hash_dense_list_input_farmhash(self): + layer = layers.Hashing(num_bins=2) + inp = [["omar"], ["stringer"], ["marlo"], ["wire"], ["skywalker"]] + output = layer(inp) + # Assert equal for hashed output that should be true on all platforms. + self.assertAllClose([[0], [0], [1], [0], [0]], output) + + inp = ["omar", "stringer", "marlo", "wire", "skywalker"] + output = layer(inp) + # Assert equal for hashed output that should be true on all platforms. + self.assertAllClose([0, 0, 1, 0, 0], output) + + def test_hash_dense_int_input_farmhash(self): + layer = layers.Hashing(num_bins=3) + inp = np.asarray([[0], [1], [2], [3], [4]]) + output = layer(inp) + # Assert equal for hashed output that should be true on all platforms. + self.assertAllClose([[1], [0], [1], [0], [2]], output) + + def test_hash_dense_input_siphash(self): + layer = layers.Hashing(num_bins=2, salt=[133, 137]) + inp = np.asarray( + [["omar"], ["stringer"], ["marlo"], ["wire"], ["skywalker"]] + ) + output = layer(inp) + # Assert equal for hashed output that should be true on all platforms. + # Note the result is different from FarmHash. + self.assertAllClose([[0], [1], [0], [1], [0]], output) + + layer_2 = layers.Hashing(num_bins=2, salt=[211, 137]) + output_2 = layer_2(inp) + # Note the result is different from (133, 137). + self.assertAllClose([[1], [0], [1], [0], [1]], output_2) + + def test_hash_dense_int_input_siphash(self): + layer = layers.Hashing(num_bins=3, salt=[133, 137]) + inp = np.asarray([[0], [1], [2], [3], [4]]) + output = layer(inp) + # Assert equal for hashed output that should be true on all platforms. + self.assertAllClose([[1], [1], [2], [0], [1]], output) + + @pytest.mark.skipif( + backend.backend() != "tensorflow", reason="Uses tf.SparseTensor." + ) + def test_hash_sparse_input_farmhash(self): + layer = layers.Hashing(num_bins=2) + indices = [[0, 0], [1, 0], [1, 1], [2, 0], [2, 1]] + inp = tf.SparseTensor( + indices=indices, + values=["omar", "stringer", "marlo", "wire", "skywalker"], + dense_shape=[3, 2], + ) + output = layer(inp) + self.assertAllClose(indices, output.indices) + self.assertAllClose([0, 0, 1, 0, 0], output.values) + + @pytest.mark.skipif( + backend.backend() != "tensorflow", reason="Uses tf.SparseTensor." + ) + def test_hash_sparse_input_mask_value_farmhash(self): + empty_mask_layer = layers.Hashing(num_bins=3, mask_value="") + omar_mask_layer = layers.Hashing(num_bins=3, mask_value="omar") + indices = [[0, 0], [1, 0], [1, 1], [2, 0], [2, 1]] + inp = tf.SparseTensor( + indices=indices, + values=["omar", "stringer", "marlo", "wire", "skywalker"], + dense_shape=[3, 2], + ) + empty_mask_output = empty_mask_layer(inp) + omar_mask_output = omar_mask_layer(inp) + self.assertAllClose(indices, omar_mask_output.indices) + self.assertAllClose(indices, empty_mask_output.indices) + # Outputs should be one more than test_hash_sparse_input_farmhash (the + # zeroth bin is now reserved for masks). + self.assertAllClose([1, 1, 2, 1, 1], empty_mask_output.values) + # 'omar' should map to 0. + self.assertAllClose([0, 1, 2, 1, 1], omar_mask_output.values) + + @pytest.mark.skipif( + backend.backend() != "tensorflow", reason="Uses tf.SparseTensor." + ) + def test_hash_sparse_int_input_farmhash(self): + layer = layers.Hashing(num_bins=3) + indices = [[0, 0], [1, 0], [1, 1], [2, 0], [2, 1]] + inp = tf.SparseTensor( + indices=indices, values=[0, 1, 2, 3, 4], dense_shape=[3, 2] + ) + output = layer(inp) + self.assertAllClose(indices, output.indices) + self.assertAllClose([1, 0, 1, 0, 2], output.values) + + @pytest.mark.skipif( + backend.backend() != "tensorflow", reason="Uses tf.SparseTensor." + ) + def test_hash_sparse_input_siphash(self): + layer = layers.Hashing(num_bins=2, salt=[133, 137]) + indices = [[0, 0], [1, 0], [1, 1], [2, 0], [2, 1]] + inp = tf.SparseTensor( + indices=indices, + values=["omar", "stringer", "marlo", "wire", "skywalker"], + dense_shape=[3, 2], + ) + output = layer(inp) + self.assertAllClose(output.indices, indices) + # The result should be same with test_hash_dense_input_siphash. + self.assertAllClose([0, 1, 0, 1, 0], output.values) + + layer_2 = layers.Hashing(num_bins=2, salt=[211, 137]) + output = layer_2(inp) + # The result should be same with test_hash_dense_input_siphash. + self.assertAllClose([1, 0, 1, 0, 1], output.values) + + @pytest.mark.skipif( + backend.backend() != "tensorflow", reason="Uses tf.SparseTensor." + ) + def test_hash_sparse_int_input_siphash(self): + layer = layers.Hashing(num_bins=3, salt=[133, 137]) + indices = [[0, 0], [1, 0], [1, 1], [2, 0], [2, 1]] + inp = tf.SparseTensor( + indices=indices, values=[0, 1, 2, 3, 4], dense_shape=[3, 2] + ) + output = layer(inp) + self.assertAllClose(indices, output.indices) + self.assertAllClose([1, 1, 2, 0, 1], output.values) + + def test_invalid_inputs(self): + with self.assertRaisesRegex(ValueError, "cannot be `None`"): + _ = layers.Hashing(num_bins=None) + with self.assertRaisesRegex(ValueError, "cannot be `None`"): + _ = layers.Hashing(num_bins=-1) + with self.assertRaisesRegex( + ValueError, "can only be a tuple of size 2" + ): + _ = layers.Hashing(num_bins=2, salt="string") + with self.assertRaisesRegex( + ValueError, "can only be a tuple of size 2" + ): + _ = layers.Hashing(num_bins=2, salt=[1]) + with self.assertRaisesRegex( + ValueError, "can only be a tuple of size 2" + ): + _ = layers.Hashing(num_bins=1, salt=[133, 137, 177]) + + def test_one_hot_output(self): + input_array = np.array([0, 1, 2, 3, 4]) + + expected_output = [ + [0.0, 1.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 0.0, 1.0], + ] + expected_output_shape = [None, 3] + + inputs = layers.Input(shape=(1,), dtype="int32") + layer = layers.Hashing(num_bins=3, output_mode="one_hot") + outputs = layer(inputs) + self.assertAllEqual(expected_output_shape, outputs.shape) + + model = models.Model(inputs, outputs) + output_data = model(input_array) + self.assertAllClose(expected_output, output_data) + + def test_multi_hot_output(self): + input_array = np.array([[0, 1, 2, 3, 4]]) + + expected_output = [[1.0, 1.0, 1.0]] + expected_output_shape = [None, 3] + + inputs = layers.Input(shape=(None,), dtype="int32") + layer = layers.Hashing(num_bins=3, output_mode="multi_hot") + outputs = layer(inputs) + self.assertAllEqual(expected_output_shape, outputs.shape) + + model = models.Model(inputs, outputs) + output_data = model(input_array) + self.assertAllClose(expected_output, output_data) + + @parameterized.named_parameters( + ( + "1d_input", + [0, 1, 2, 3, 4], + [2.0, 2.0, 1.0], + [3], + ), + ( + "2d_input", + [[0, 1, 2, 3, 4]], + [[2.0, 2.0, 1.0]], + [None, 3], + ), + ) + def test_count_output(self, input_value, expected_output, output_shape): + input_array = np.array(input_value) + if input_array.ndim == 1: + symbolic_sample_shape = () + elif input_array.ndim == 2: + symbolic_sample_shape = (None,) + else: + raise TypeError("Unknown `symbolic_sample_shape`") + inputs = layers.Input(shape=symbolic_sample_shape, dtype="int32") + layer = layers.Hashing(num_bins=3, output_mode="count") + outputs = layer(inputs) + self.assertAllEqual(output_shape, outputs.shape) + output_data = layer(input_array) + self.assertAllEqual(expected_output, output_data) + + @parameterized.named_parameters( + ("int32", "int32"), + ("int64", "int64"), + ) + def test_int_output_dtype(self, dtype): + input_data = layers.Input(batch_size=16, shape=(4,), dtype="string") + layer = layers.Hashing(num_bins=3, output_mode="int", dtype=dtype) + output = layer(input_data) + self.assertEqual(output.dtype, dtype) + + @parameterized.named_parameters( + ("float32", "float32"), + ("float64", "float64"), + ) + def test_one_hot_output_dtype(self, dtype): + input_data = layers.Input(batch_size=16, shape=(1,), dtype="string") + layer = layers.Hashing(num_bins=3, output_mode="one_hot", dtype=dtype) + output = layer(input_data) + self.assertEqual(output.dtype, dtype) + + def test_config_with_custom_name(self): + layer = layers.Hashing(num_bins=2, name="hashing") + config = layer.get_config() + layer_1 = layers.Hashing.from_config(config) + self.assertEqual(layer_1.name, layer.name) + + @pytest.mark.skipif( + backend.backend() != "tensorflow", reason="Uses string dtype." + ) + def test_saving(self): + input_data = np.array( + ["omar", "stringer", "marlo", "wire", "skywalker"] + ) + inputs = layers.Input(shape=(), dtype="string") + outputs = layers.Hashing(num_bins=100)(inputs) + model = models.Model(inputs=inputs, outputs=outputs) + + original_output_data = model(input_data) + + # Save the model to disk. + output_path = os.path.join(self.get_temp_dir(), "keras_model.keras") + model.save(output_path) + loaded_model = load_model(output_path) + + # Ensure that the loaded model is unique (so that the save/load is real) + self.assertIsNot(model, loaded_model) + + # Validate correctness of the new model. + new_output_data = loaded_model(input_data) + self.assertAllClose(new_output_data, original_output_data) + + @parameterized.named_parameters( + ( + "list_input", + [1, 2, 3], + [1, 1, 1], + ), + ( + "list_input_2d", + [[1], [2], [3]], + [[1], [1], [1]], + ), + ( + "list_input_2d_multiple", + [[1, 2], [2, 3], [3, 4]], + [[1, 1], [1, 1], [1, 1]], + ), + ( + "list_input_3d", + [[[1], [2]], [[2], [3]], [[3], [4]]], + [[[1], [1]], [[1], [1]], [[1], [1]]], + ), + ) + def test_hash_list_input(self, input_data, expected): + layer = layers.Hashing(num_bins=2) + out_data = layer(input_data) + self.assertAllEqual( + expected, backend.convert_to_numpy(out_data).tolist() + ) + + def test_hashing_invalid_num_bins(self): + # Test with `num_bins` set to None + with self.assertRaisesRegex( + ValueError, + "The `num_bins` for `Hashing` cannot be `None` or non-positive", + ): + layers.Hashing(num_bins=None) + + # Test with `num_bins` set to 0 + with self.assertRaisesRegex( + ValueError, + "The `num_bins` for `Hashing` cannot be `None` or non-positive", + ): + layers.Hashing(num_bins=0) + + def test_hashing_invalid_output_mode(self): + # Test with an unsupported `output_mode` + with self.assertRaisesRegex( + ValueError, + "Invalid value for argument `output_mode`. Expected one of", + ): + layers.Hashing(num_bins=3, output_mode="unsupported_mode") + + def test_hashing_invalid_dtype_for_int_mode(self): + with self.assertRaisesRegex( + ValueError, + 'When `output_mode="int"`, `dtype` should be an integer type,', + ): + layers.Hashing(num_bins=3, output_mode="int", dtype="float32") + + def test_hashing_sparse_with_int_mode(self): + # Test setting `sparse=True` with `output_mode='int'` + with self.assertRaisesRegex( + ValueError, "`sparse` may only be true if `output_mode` is" + ): + layers.Hashing(num_bins=3, output_mode="int", sparse=True) + + +# TODO: support tf.RaggedTensor. +# def test_hash_ragged_string_input_farmhash(self): +# layer = layers.Hashing(num_bins=2) +# inp_data = tf.ragged.constant( +# [ +# ["omar", "stringer", "marlo", "wire"], +# ["marlo", "skywalker", "wire"], +# ], +# dtype="string", +# ) +# out_data = layer(inp_data) +# # Same hashed output as test_hash_sparse_input_farmhash +# expected_output = [[0, 0, 1, 0], [1, 0, 0]] +# self.assertAllEqual(expected_output, out_data) + +# inp_t = layers.Input(shape=(None,), ragged=True, dtype="string") +# out_t = layer(inp_t) +# model = models.Model(inputs=inp_t, outputs=out_t) +# self.assertAllClose(out_data, model.predict(inp_data)) + +# TODO: support tf.RaggedTensor. +# def test_hash_ragged_input_mask_value(self): +# empty_mask_layer = layers.Hashing(num_bins=3, mask_value="") +# omar_mask_layer = layers.Hashing(num_bins=3, mask_value="omar") +# inp_data = tf.ragged.constant( +# [ +# ["omar", "stringer", "marlo", "wire"], +# ["marlo", "skywalker", "wire"], +# ], +# dtype="string", +# ) +# empty_mask_output = empty_mask_layer(inp_data) +# omar_mask_output = omar_mask_layer(inp_data) +# # Outputs should be one more than test_hash_ragged_string_input_farmhash +# # (the zeroth bin is now reserved for masks). +# expected_output = [[1, 1, 2, 1], [2, 1, 1]] +# self.assertAllClose(expected_output[0], empty_mask_output[1]) +# self.assertAllClose(expected_output[1], empty_mask_output[2]) +# # 'omar' should map to 0. +# expected_output = [[0, 1, 2, 1], [2, 1, 1]] +# self.assertAllClose(expected_output[0], omar_mask_output[0]) +# self.assertAllClose(expected_output[1], omar_mask_output[1]) + +# TODO: support tf.RaggedTensor. +# def test_hash_ragged_int_input_farmhash(self): +# layer = layers.Hashing(num_bins=3) +# inp_data = tf.ragged.constant([[0, 1, 3, 4], [2, 1, 0]], dtype="int64") +# out_data = layer(inp_data) +# # Same hashed output as test_hash_sparse_input_farmhash +# expected_output = [[1, 0, 0, 2], [1, 0, 1]] +# self.assertAllEqual(expected_output[0], out_data[0]) +# self.assertAllEqual(expected_output[1], out_data[1]) +# inp_t = layers.Input(shape=(None,), ragged=True, dtype="int64") +# out_t = layer(inp_t) +# model = models.Model(inputs=inp_t, outputs=out_t) +# self.assertAllClose(out_data, model.predict(inp_data)) + +# TODO: support tf.RaggedTensor. +# def test_hash_ragged_string_input_siphash(self): +# layer = layers.Hashing(num_bins=2, salt=[133, 137]) +# inp_data = tf.ragged.constant( +# [ +# ["omar", "stringer", "marlo", "wire"], +# ["marlo", "skywalker", "wire"], +# ], +# dtype="string", +# ) +# out_data = layer(inp_data) +# # Same hashed output as test_hash_dense_input_siphash +# expected_output = [[0, 1, 0, 1], [0, 0, 1]] +# self.assertAllEqual(expected_output, out_data) + +# inp_t = layers.Input(shape=(None,), ragged=True, dtype="string") +# out_t = layer(inp_t) +# model = models.Model(inputs=inp_t, outputs=out_t) +# self.assertAllClose(out_data, model.predict(inp_data)) + +# layer_2 = layers.Hashing(num_bins=2, salt=[211, 137]) +# out_data = layer_2(inp_data) +# expected_output = [[1, 0, 1, 0], [1, 1, 0]] +# self.assertAllEqual(expected_output, out_data) + +# out_t = layer_2(inp_t) +# model = models.Model(inputs=inp_t, outputs=out_t) +# self.assertAllClose(out_data, model.predict(inp_data)) + +# TODO: support tf.RaggedTensor. +# def test_hash_ragged_int_input_siphash(self): +# layer = layers.Hashing(num_bins=3, salt=[133, 137]) +# inp_data = tf.ragged.constant([[0, 1, 3, 4], [2, 1, 0]], dtype="int64") +# out_data = layer(inp_data) +# # Same hashed output as test_hash_sparse_input_farmhash +# expected_output = [[1, 1, 0, 1], [2, 1, 1]] +# self.assertAllEqual(expected_output, out_data) + +# inp_t = layers.Input(shape=(None,), ragged=True, dtype="int64") +# out_t = layer(inp_t) +# model = models.Model(inputs=inp_t, outputs=out_t) +# self.assertAllClose(out_data, model.predict(inp_data)) diff --git a/keras/src/layers/preprocessing/image_preprocessing/__init__.py b/keras/src/layers/preprocessing/image_preprocessing/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/keras/src/layers/preprocessing/image_preprocessing/aug_mix.py b/keras/src/layers/preprocessing/image_preprocessing/aug_mix.py new file mode 100644 index 000000000000..fa7dd33297b1 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/aug_mix.py @@ -0,0 +1,328 @@ +import random as py_random + +import keras.src.layers as layers +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 + BaseImagePreprocessingLayer, +) +from keras.src.random import SeedGenerator +from keras.src.utils import backend_utils + +AUGMENT_LAYERS_ALL = [ + "random_shear", + "random_translation", + "random_rotation", + "random_posterization", + "solarization", + "auto_contrast", + "equalization", + "random_brightness", + "random_color_degeneration", + "random_contrast", + "random_sharpness", +] + +AUGMENT_LAYERS = [ + "random_shear", + "random_translation", + "random_rotation", + "random_posterization", + "solarization", + "auto_contrast", + "equalization", +] + + +@keras_export("keras.layers.AugMix") +class AugMix(BaseImagePreprocessingLayer): + """Performs the AugMix data augmentation technique. + + AugMix aims to produce images with variety while preserving the image + semantics and local statistics. During the augmentation process, + the same augmentation is applied across all images in the batch + in num_chains different ways, with each chain consisting of + chain_depth augmentations. + + **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline + (independently of which backend you're using). + + References: + - [AugMix paper](https://arxiv.org/pdf/1912.02781) + - [Official Code](https://github.com/google-research/augmix) + + Args: + value_range: the range of values the incoming images will have. + Represented as a two number tuple written (low, high). + This is typically either `(0, 1)` or `(0, 255)` depending + on how your preprocessing pipeline is set up. + num_chains: an integer representing the number of different chains to + be mixed, defaults to 3. + chain_depth: an integer representing the maximum number of + transformations to be applied in each chain. The actual number + of transformations in each chain will be sampled randomly + from the range `[0, `chain_depth`]`. Defaults to 3. + factor: The strength of the augmentation as a normalized value + between 0 and 1. Default is 0.3. + alpha: a float value used as the probability coefficients for the + Beta and Dirichlet distributions, defaults to 1.0. + all_ops: Use all operations (including random_brightness, + random_color_degeneration, random_contrast and random_sharpness). + Default is True. + interpolation: The interpolation method to use for resizing operations. + Options include `"nearest"`, `"bilinear"`. Default is `"bilinear"`. + seed: Integer. Used to create a random seed. + """ + + _USE_BASE_FACTOR = False + _FACTOR_BOUNDS = (0, 1) + + def __init__( + self, + value_range=(0, 255), + num_chains=3, + chain_depth=3, + factor=0.3, + alpha=1.0, + all_ops=True, + interpolation="bilinear", + seed=None, + data_format=None, + **kwargs, + ): + super().__init__(data_format=data_format, **kwargs) + + self.value_range = value_range + self.num_chains = num_chains + self.chain_depth = chain_depth + self._set_factor(factor) + self.alpha = alpha + self.all_ops = all_ops + self.interpolation = interpolation + self.seed = seed + self.generator = SeedGenerator(seed) + + if self.all_ops: + self._augment_layers = AUGMENT_LAYERS_ALL + else: + self._augment_layers = AUGMENT_LAYERS + + self.random_shear = layers.RandomShear( + x_factor=self.factor, + y_factor=self.factor, + interpolation=interpolation, + seed=self.seed, + data_format=data_format, + **kwargs, + ) + + self.random_translation = layers.RandomTranslation( + height_factor=self.factor, + width_factor=self.factor, + interpolation=interpolation, + seed=self.seed, + data_format=data_format, + **kwargs, + ) + + self.random_rotation = layers.RandomRotation( + factor=self.factor, + interpolation=interpolation, + seed=self.seed, + data_format=data_format, + **kwargs, + ) + + self.solarization = layers.Solarization( + addition_factor=self.factor, + threshold_factor=self.factor, + value_range=self.value_range, + seed=self.seed, + data_format=data_format, + **kwargs, + ) + + self.random_posterization = layers.RandomPosterization( + factor=max(1, int(8 * self.factor[1])), + value_range=self.value_range, + seed=self.seed, + data_format=data_format, + **kwargs, + ) + + self.auto_contrast = layers.AutoContrast( + value_range=self.value_range, data_format=data_format, **kwargs + ) + + self.equalization = layers.Equalization( + value_range=self.value_range, data_format=data_format, **kwargs + ) + + if self.all_ops: + self.random_brightness = layers.RandomBrightness( + factor=self.factor, + value_range=self.value_range, + seed=self.seed, + data_format=data_format, + **kwargs, + ) + + self.random_color_degeneration = layers.RandomColorDegeneration( + factor=self.factor, + value_range=self.value_range, + seed=self.seed, + data_format=data_format, + **kwargs, + ) + + self.random_contrast = layers.RandomContrast( + factor=self.factor, + value_range=self.value_range, + seed=self.seed, + data_format=data_format, + **kwargs, + ) + + self.random_sharpness = layers.RandomSharpness( + factor=self.factor, + value_range=self.value_range, + seed=self.seed, + data_format=data_format, + **kwargs, + ) + + def build(self, input_shape): + for layer_name in self._augment_layers: + augmentation_layer = getattr(self, layer_name) + augmentation_layer.build(input_shape) + + def _sample_from_dirichlet(self, shape, alpha, seed): + gamma_sample = self.backend.random.gamma( + shape=shape, + alpha=alpha, + seed=seed, + ) + return gamma_sample / self.backend.numpy.sum( + gamma_sample, axis=-1, keepdims=True + ) + + def get_random_transformation(self, data, training=True, seed=None): + if not training: + return None + + if backend_utils.in_tf_graph(): + self.backend.set_backend("tensorflow") + + for layer_name in self._augment_layers: + augmentation_layer = getattr(self, layer_name) + augmentation_layer.backend.set_backend("tensorflow") + + seed = seed or self._get_seed_generator(self.backend._backend) + + chain_mixing_weights = self._sample_from_dirichlet( + [self.num_chains], self.alpha, seed + ) + weight_sample = self.backend.random.beta( + shape=(), + alpha=self.alpha, + beta=self.alpha, + seed=seed, + ) + + chain_transforms = [] + for _ in range(self.num_chains): + depth_transforms = [] + for _ in range(self.chain_depth): + layer_name = py_random.choice(self._augment_layers + [None]) + if layer_name is None: + continue + augmentation_layer = getattr(self, layer_name) + depth_transforms.append( + { + "layer_name": layer_name, + "transformation": ( + augmentation_layer.get_random_transformation( + data, + seed=self._get_seed_generator( + self.backend._backend + ), + ) + ), + } + ) + chain_transforms.append(depth_transforms) + + transformation = { + "chain_mixing_weights": chain_mixing_weights, + "weight_sample": weight_sample, + "chain_transforms": chain_transforms, + } + + return transformation + + def transform_images(self, images, transformation, training=True): + if training: + images = self.backend.cast(images, self.compute_dtype) + + chain_mixing_weights = self.backend.cast( + transformation["chain_mixing_weights"], dtype=self.compute_dtype + ) + weight_sample = self.backend.cast( + transformation["weight_sample"], dtype=self.compute_dtype + ) + chain_transforms = transformation["chain_transforms"] + + aug_images = self.backend.numpy.zeros_like(images) + for idx, chain_transform in enumerate(chain_transforms): + copied_images = self.backend.numpy.copy(images) + for depth_transform in chain_transform: + layer_name = depth_transform["layer_name"] + layer_transform = depth_transform["transformation"] + + augmentation_layer = getattr(self, layer_name) + copied_images = augmentation_layer.transform_images( + copied_images, layer_transform + ) + aug_images += copied_images * chain_mixing_weights[idx] + images = weight_sample * images + (1 - weight_sample) * aug_images + + images = self.backend.numpy.clip( + images, self.value_range[0], self.value_range[1] + ) + + images = self.backend.cast(images, self.compute_dtype) + return images + + def transform_labels(self, labels, transformation, training=True): + return labels + + def transform_bounding_boxes( + self, + bounding_boxes, + transformation, + training=True, + ): + return bounding_boxes + + def transform_segmentation_masks( + self, segmentation_masks, transformation, training=True + ): + return self.transform_images( + segmentation_masks, transformation, training=training + ) + + def compute_output_shape(self, input_shape): + return input_shape + + def get_config(self): + config = { + "value_range": self.value_range, + "num_chains": self.chain_depth, + "chain_depth": self.num_chains, + "factor": self.factor, + "alpha": self.alpha, + "all_ops": self.all_ops, + "interpolation": self.interpolation, + "seed": self.seed, + } + base_config = super().get_config() + return {**base_config, **config} diff --git a/keras/src/layers/preprocessing/image_preprocessing/aug_mix_test.py b/keras/src/layers/preprocessing/image_preprocessing/aug_mix_test.py new file mode 100644 index 000000000000..2513642b68e8 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/aug_mix_test.py @@ -0,0 +1,66 @@ +import numpy as np +import pytest +from tensorflow import data as tf_data + +from keras.src import backend +from keras.src import layers +from keras.src import testing + + +class RandAugmentTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_layer(self): + self.run_layer_test( + layers.AugMix, + init_kwargs={ + "value_range": (0, 255), + "num_chains": 2, + "chain_depth": 2, + "factor": 1, + "alpha": 1.0, + "all_ops": True, + "interpolation": "nearest", + "seed": 43, + "data_format": "channels_last", + }, + input_shape=(8, 3, 4, 3), + supports_masking=False, + expected_output_shape=(8, 3, 4, 3), + ) + + def test_aug_mix_inference(self): + seed = 3481 + layer = layers.AugMix() + + np.random.seed(seed) + inputs = np.random.randint(0, 255, size=(224, 224, 3)) + output = layer(inputs, training=False) + self.assertAllClose(inputs, output) + + def test_random_augment_randomness(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + input_data = np.random.random((2, 8, 8, 3)) + else: + input_data = np.random.random((2, 3, 8, 8)) + + layer = layers.AugMix( + num_chains=11, all_ops=True, data_format=data_format + ) + augmented_image = layer(input_data) + + self.assertNotAllClose( + backend.convert_to_numpy(augmented_image), input_data + ) + + def test_tf_data_compatibility(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + input_data = np.random.random((2, 8, 8, 3)) + else: + input_data = np.random.random((2, 3, 8, 8)) + layer = layers.AugMix(data_format=data_format) + + ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) + for output in ds.take(1): + output.numpy() diff --git a/keras/src/layers/preprocessing/image_preprocessing/auto_contrast.py b/keras/src/layers/preprocessing/image_preprocessing/auto_contrast.py new file mode 100644 index 000000000000..b24f3fb737ff --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/auto_contrast.py @@ -0,0 +1,112 @@ +from keras.src import backend +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 + BaseImagePreprocessingLayer, +) +from keras.src.ops.core import _saturate_cast + + +@keras_export("keras.layers.AutoContrast") +class AutoContrast(BaseImagePreprocessingLayer): + """Performs the auto-contrast operation on an image. + + Auto contrast stretches the values of an image across the entire available + `value_range`. This makes differences between pixels more obvious. An + example of this is if an image only has values `[0, 1]` out of the range + `[0, 255]`, auto contrast will change the `1` values to be `255`. + + This layer is active at both training and inference time. + + **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline + (independently of which backend you're using). + + Args: + value_range: Range of values the incoming images will have. + Represented as a two number tuple written `(low, high)`. + This is typically either `(0, 1)` or `(0, 255)` depending + on how your preprocessing pipeline is set up. + Defaults to `(0, 255)`. + """ + + _USE_BASE_FACTOR = False + _VALUE_RANGE_VALIDATION_ERROR = ( + "The `value_range` argument should be a list of two numbers. " + ) + + def __init__( + self, + value_range=(0, 255), + **kwargs, + ): + super().__init__(**kwargs) + self._set_value_range(value_range) + + def _set_value_range(self, value_range): + if not isinstance(value_range, (tuple, list)): + raise ValueError( + self._VALUE_RANGE_VALIDATION_ERROR + + f"Received: value_range={value_range}" + ) + if len(value_range) != 2: + raise ValueError( + self._VALUE_RANGE_VALIDATION_ERROR + + f"Received: value_range={value_range}" + ) + self.value_range = sorted(value_range) + + def transform_images(self, images, transformation=None, training=True): + original_images = images + images = self._transform_value_range( + images, + original_range=self.value_range, + target_range=(0, 255), + dtype=self.compute_dtype, + ) + + images = self.backend.cast(images, self.compute_dtype) + low = self.backend.numpy.min(images, axis=(1, 2), keepdims=True) + high = self.backend.numpy.max(images, axis=(1, 2), keepdims=True) + scale = 255.0 / (high - low) + offset = -low * scale + + images = images * scale + offset + results = self.backend.numpy.clip(images, 0.0, 255.0) + results = self._transform_value_range( + results, + original_range=(0, 255), + target_range=self.value_range, + dtype=self.compute_dtype, + ) + # don't process NaN channels + results = self.backend.numpy.where( + self.backend.numpy.isnan(results), original_images, results + ) + if results.dtype == images.dtype: + return results + if backend.is_int_dtype(images.dtype): + results = self.backend.numpy.round(results) + return _saturate_cast(results, images.dtype, self.backend) + + def transform_labels(self, labels, transformation, training=True): + return labels + + def transform_bounding_boxes( + self, + bounding_boxes, + transformation, + training=True, + ): + return bounding_boxes + + def transform_segmentation_masks( + self, segmentation_masks, transformation, training=True + ): + return segmentation_masks + + def get_config(self): + config = super().get_config() + config.update({"value_range": self.value_range}) + return config + + def compute_output_shape(self, input_shape): + return input_shape diff --git a/keras/src/layers/preprocessing/image_preprocessing/auto_contrast_test.py b/keras/src/layers/preprocessing/image_preprocessing/auto_contrast_test.py new file mode 100644 index 000000000000..f448819929e8 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/auto_contrast_test.py @@ -0,0 +1,94 @@ +import numpy as np +import pytest + +from keras.src import layers +from keras.src import ops +from keras.src import testing + + +class AutoContrastTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_layer(self): + self.run_layer_test( + layers.AutoContrast, + init_kwargs={ + "value_range": (20, 200), + }, + input_shape=(8, 3, 4, 3), + supports_masking=False, + expected_output_shape=(8, 3, 4, 3), + ) + + def test_constant_channels_dont_get_nanned(self): + img = np.array([1, 1], dtype="float32") + img = np.expand_dims(img, axis=-1) + img = np.expand_dims(img, axis=-1) + img = np.expand_dims(img, axis=0) + + layer = layers.AutoContrast(value_range=(0, 255)) + ys = layer(img) + + self.assertTrue(np.any(ops.convert_to_numpy(ys[0]) == 1.0)) + self.assertTrue(np.any(ops.convert_to_numpy(ys[0]) == 1.0)) + + def test_auto_contrast_expands_value_range(self): + img = np.array([0, 128], dtype="float32") + img = np.expand_dims(img, axis=-1) + img = np.expand_dims(img, axis=-1) + img = np.expand_dims(img, axis=0) + + layer = layers.AutoContrast(value_range=(0, 255)) + ys = layer(img) + + self.assertTrue(np.any(ops.convert_to_numpy(ys[0]) == 0.0)) + self.assertTrue(np.any(ops.convert_to_numpy(ys[0]) == 255.0)) + + def test_auto_contrast_different_values_per_channel(self): + img = np.array( + [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], + dtype="float32", + ) + img = np.expand_dims(img, axis=0) + + layer = layers.AutoContrast(value_range=(0, 255)) + ys = layer(img) + + self.assertTrue(np.any(ops.convert_to_numpy(ys[0, ..., 0]) == 0.0)) + self.assertTrue(np.any(ops.convert_to_numpy(ys[0, ..., 1]) == 0.0)) + + self.assertTrue(np.any(ops.convert_to_numpy(ys[0, ..., 0]) == 255.0)) + self.assertTrue(np.any(ops.convert_to_numpy(ys[0, ..., 1]) == 255.0)) + + self.assertAllClose( + ys, + [ + [ + [[0.0, 0.0, 0.0], [85.0, 85.0, 85.0]], + [[170.0, 170.0, 170.0], [255.0, 255.0, 255.0]], + ] + ], + ) + + def test_auto_contrast_expands_value_range_uint8(self): + img = np.array([0, 128], dtype="uint8") + img = np.expand_dims(img, axis=-1) + img = np.expand_dims(img, axis=-1) + img = np.expand_dims(img, axis=0) + + layer = layers.AutoContrast(value_range=(0, 255)) + ys = layer(img) + + self.assertTrue(np.any(ops.convert_to_numpy(ys[0]) == 0.0)) + self.assertTrue(np.any(ops.convert_to_numpy(ys[0]) == 255.0)) + + def test_auto_contrast_properly_converts_value_range(self): + img = np.array([0, 0.5], dtype="float32") + img = np.expand_dims(img, axis=-1) + img = np.expand_dims(img, axis=-1) + img = np.expand_dims(img, axis=0) + + layer = layers.AutoContrast(value_range=(0, 1)) + ys = layer(img) + self.assertAllClose( + ops.convert_to_numpy(ys[0]), np.array([[[0.0]], [[1]]]) + ) diff --git a/keras/src/layers/preprocessing/image_preprocessing/base_image_preprocessing_layer.py b/keras/src/layers/preprocessing/image_preprocessing/base_image_preprocessing_layer.py new file mode 100644 index 000000000000..6cd3bc43cc3e --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/base_image_preprocessing_layer.py @@ -0,0 +1,385 @@ +import math + +from keras.src.backend import config as backend_config +from keras.src.layers.preprocessing.data_layer import DataLayer +from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.validation import ( # noqa: E501 + densify_bounding_boxes, +) + + +class BaseImagePreprocessingLayer(DataLayer): + _USE_BASE_FACTOR = True + _FACTOR_BOUNDS = (-1, 1) + + def __init__( + self, factor=None, bounding_box_format=None, data_format=None, **kwargs + ): + super().__init__(**kwargs) + self.bounding_box_format = bounding_box_format + self.data_format = backend_config.standardize_data_format(data_format) + if self._USE_BASE_FACTOR: + factor = factor or 0.0 + self._set_factor(factor) + elif factor is not None: + raise ValueError( + f"Layer {self.__class__.__name__} does not take " + f"a `factor` argument. Received: factor={factor}" + ) + + def _set_factor(self, factor): + error_msg = ( + "The `factor` argument should be a number " + "(or a list of two numbers) " + "in the range " + f"[{self._FACTOR_BOUNDS[0]}, {self._FACTOR_BOUNDS[1]}]. " + f"Received: factor={factor}" + ) + if isinstance(factor, (tuple, list)): + if len(factor) != 2: + raise ValueError(error_msg) + if ( + factor[0] > self._FACTOR_BOUNDS[1] + or factor[1] < self._FACTOR_BOUNDS[0] + ): + raise ValueError(error_msg) + lower, upper = sorted(factor) + elif isinstance(factor, (int, float)): + if ( + factor < self._FACTOR_BOUNDS[0] + or factor > self._FACTOR_BOUNDS[1] + ): + raise ValueError(error_msg) + factor = abs(factor) + lower, upper = [max(-factor, self._FACTOR_BOUNDS[0]), factor] + else: + raise ValueError(error_msg) + self.factor = lower, upper + + def get_random_transformation(self, data, training=True, seed=None): + return None + + def transform_images(self, images, transformation, training=True): + raise NotImplementedError() + + def transform_labels(self, labels, transformation, training=True): + raise NotImplementedError() + + def transform_bounding_boxes( + self, + bounding_boxes, + transformation, + training=True, + ): + raise NotImplementedError() + + def transform_segmentation_masks( + self, segmentation_masks, transformation, training=True + ): + raise NotImplementedError() + + def transform_single_image(self, image, transformation, training=True): + images = self.backend.numpy.expand_dims(image, axis=0) + outputs = self.transform_images( + images, transformation=transformation, training=training + ) + return self.backend.numpy.squeeze(outputs, axis=0) + + def transform_single_label(self, label, transformation, training=True): + labels = self.backend.numpy.expand_dims(label, axis=0) + outputs = self.transform_labels( + labels, transformation=transformation, training=training + ) + return self.backend.numpy.squeeze(outputs, axis=0) + + def transform_single_bounding_box( + self, + bounding_box, + transformation, + training=True, + ): + bounding_boxes = self._format_single_input_bounding_box(bounding_box) + outputs = self.transform_bounding_boxes( + bounding_boxes, + transformation=transformation, + training=training, + ) + bounding_box = self._format_single_output_bounding_box(outputs) + return bounding_box + + def transform_single_segmentation_mask( + self, segmentation_mask, transformation, training=True + ): + segmentation_masks = self.backend.numpy.expand_dims( + segmentation_mask, axis=0 + ) + outputs = self.transform_segmentation_masks( + segmentation_masks, transformation=transformation, training=training + ) + return self.backend.numpy.squeeze(outputs, axis=0) + + def _is_batched(self, maybe_image_batch): + shape = self.backend.core.shape(maybe_image_batch) + if len(shape) == 3: + return False + if len(shape) == 4: + return True + raise ValueError( + "Expected image tensor to have rank 3 (single image) " + f"or 4 (batch of images). Received: data.shape={shape}" + ) + + def call(self, data, training=True): + transformation = self.get_random_transformation(data, training=training) + if isinstance(data, dict): + is_batched = self._is_batched(data["images"]) + if is_batched: + data["images"] = self.transform_images( + self.backend.convert_to_tensor(data["images"]), + transformation=transformation, + training=training, + ) + else: + data["images"] = self.transform_single_image( + self.backend.convert_to_tensor(data["images"]), + transformation=transformation, + training=training, + ) + if "bounding_boxes" in data: + if not self.bounding_box_format: + raise ValueError( + "You passed an input with a 'bounding_boxes' key, " + "but you didn't specify a bounding box format. " + "Pass a `bounding_box_format` argument to your " + f"{self.__class__.__name__} layer, e.g. " + "`bounding_box_format='xyxy'`." + ) + bounding_boxes = densify_bounding_boxes( + data["bounding_boxes"], + is_batched=is_batched, + backend=self.backend, + ) + + if is_batched: + data["bounding_boxes"] = self.transform_bounding_boxes( + bounding_boxes, + transformation=transformation, + training=training, + ) + else: + data["bounding_boxes"] = self.transform_single_bounding_box( + bounding_boxes, + transformation=transformation, + training=training, + ) + if "labels" in data: + if is_batched: + data["labels"] = self.transform_labels( + self.backend.convert_to_tensor(data["labels"]), + transformation=transformation, + training=training, + ) + else: + data["labels"] = self.transform_single_label( + self.backend.convert_to_tensor(data["labels"]), + transformation=transformation, + training=training, + ) + if "segmentation_masks" in data: + if is_batched: + data["segmentation_masks"] = ( + self.transform_segmentation_masks( + data["segmentation_masks"], + transformation=transformation, + training=training, + ) + ) + else: + data["segmentation_masks"] = ( + self.transform_single_segmentation_mask( + data["segmentation_masks"], + transformation=transformation, + training=training, + ) + ) + return data + + # `data` is just images. + if self._is_batched(data): + return self.transform_images( + self.backend.convert_to_tensor(data), + transformation=transformation, + training=training, + ) + return self.transform_single_image( + self.backend.convert_to_tensor(data), + transformation=transformation, + training=training, + ) + + def _format_single_input_bounding_box(self, bounding_box): + for key in bounding_box: + if key == "labels": + bounding_box[key] = self.backend.numpy.expand_dims( + bounding_box[key], axis=0 + ) + if key == "boxes": + bounding_box[key] = self.backend.numpy.expand_dims( + bounding_box[key], axis=0 + ) + + return bounding_box + + def _format_single_output_bounding_box(self, bounding_boxes): + for key in bounding_boxes: + if key == "labels": + bounding_boxes[key] = self.backend.numpy.squeeze( + bounding_boxes[key], axis=0 + ) + if key == "boxes": + bounding_boxes[key] = self.backend.numpy.squeeze( + bounding_boxes[key], axis=0 + ) + + return bounding_boxes + + def get_config(self): + config = super().get_config() + if self.bounding_box_format is not None: + config.update( + { + "bounding_box_format": self.bounding_box_format, + } + ) + return config + + def _transform_value_range( + self, images, original_range, target_range, dtype="float32" + ): + """Convert input values from `original_range` to `target_range`. + + This function is intended to be used in preprocessing layers that + rely upon color values. This allows us to assume internally that + the input tensor is always in the range `(0, 255)`. + + Args: + images: the set of images to transform to the target range. + original_range: the value range to transform from. + target_range: the value range to transform to. + dtype: the dtype to compute the conversion with, + defaults to "float32". + + Returns: + a new Tensor with values in the target range. + + Example: + + ```python + original_range = [0, 1] + target_range = [0, 255] + images = layer.preprocessing.transform_value_range( + images, + original_range, + target_range + ) + images = ops.minimum(images + 10, 255) + images = layer.preprocessing.transform_value_range( + images, + target_range, + original_range + ) + ``` + """ + if ( + original_range[0] == target_range[0] + and original_range[1] == target_range[1] + ): + return images + + images = self.backend.cast(images, dtype=dtype) + original_min_value, original_max_value = self._unwrap_value_range( + original_range, dtype=dtype + ) + target_min_value, target_max_value = self._unwrap_value_range( + target_range, dtype=dtype + ) + + # images in the [0, 1] scale + images = (images - original_min_value) / ( + original_max_value - original_min_value + ) + + scale_factor = target_max_value - target_min_value + return (images * scale_factor) + target_min_value + + def _unwrap_value_range(self, value_range, dtype="float32"): + min_value, max_value = value_range + min_value = self.backend.cast(min_value, dtype=dtype) + max_value = self.backend.cast(max_value, dtype=dtype) + return min_value, max_value + + def _compute_affine_matrix( + self, + center_x, + center_y, + angle, + translate_x, + translate_y, + scale, + shear_x, + shear_y, + height, + width, + ): + """ + # Scaling Shear Rotation + # [sx 0 0] [1 shx 0] [cos(θ) -sin(θ) 0] + # M = [0 sy 0] * [shy 1 0] * [sin(θ) cos(θ) 0] + # [0 0 1] [0 0 1] [0 0 1] + + # a0 = sx * (cos(θ) + shx * sin(θ)) + # a1 = sx * (-sin(θ) + shx * cos(θ)) + # a2 = tx + cx - cx * a0 - cy * a1 + # b0 = sy * (shy * cos(θ) + sin(θ)) + # b1 = sy * (shy * -sin(θ) + cos(θ)) + # b2 = ty + cy - cx * b0 - cy * b1 + """ + ops = self.backend + + degree_to_radian_factor = ops.convert_to_tensor(math.pi / 180.0) + + angle = angle * degree_to_radian_factor + shear_x = shear_x * degree_to_radian_factor + shear_y = shear_y * degree_to_radian_factor + + batch_size = ops.shape(angle)[0] + dtype = angle.dtype + width = ops.cast(width, dtype) + height = ops.cast(height, dtype) + cx = center_x * (width - 1) + cy = center_y * (height - 1) + + cos_theta = ops.numpy.cos(angle) + sin_theta = ops.numpy.sin(angle) + shear_x = ops.numpy.tan(shear_x) + shear_y = ops.numpy.tan(shear_y) + + a0 = scale * (cos_theta + shear_x * sin_theta) + a1 = scale * (-sin_theta + shear_x * cos_theta) + a2 = translate_x + cx - cx * a0 - cy * a1 + b0 = scale * (shear_y * cos_theta + sin_theta) + b1 = scale * (shear_y * -sin_theta + cos_theta) + b2 = translate_y + cy - cx * b0 - cy * b1 + affine_matrix = ops.numpy.concatenate( + [ + a0[:, None], + a1[:, None], + a2[:, None], + b0[:, None], + b1[:, None], + b2[:, None], + ops.numpy.zeros((batch_size, 2)), + ], + axis=1, + ) + + return affine_matrix diff --git a/keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/__init__.py b/keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/bounding_box.py b/keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/bounding_box.py new file mode 100644 index 000000000000..1c9515bd1f62 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/bounding_box.py @@ -0,0 +1,468 @@ +import math + +from keras.src.utils import backend_utils + +SUPPORTED_FORMATS = ( + "xyxy", + "yxyx", + "xywh", + "center_xywh", + "center_yxhw", + "rel_xyxy", + "rel_yxyx", + "rel_xywh", + "rel_center_xywh", +) + + +class BoundingBox: + def __init__(self): + self.backend = backend_utils.DynamicBackend() + + def convert_format( + self, + boxes, + source, + target, + height=None, + width=None, + dtype="float32", + ): + if isinstance(boxes, dict): + boxes["boxes"] = self.convert_format( + boxes["boxes"], + source=source, + target=target, + height=height, + width=width, + dtype=dtype, + ) + return boxes + + to_xyxy_converters = { + "xyxy": self._xyxy_to_xyxy, + "yxyx": self._yxyx_to_xyxy, + "xywh": self._xywh_to_xyxy, + "center_xywh": self._center_xywh_to_xyxy, + "center_yxhw": self._center_yxhw_to_xyxy, + "rel_xyxy": self._rel_xyxy_to_xyxy, + "rel_yxyx": self._rel_yxyx_to_xyxy, + "rel_xywh": self._rel_xywh_to_xyxy, + "rel_center_xywh": self._rel_center_xywh_to_xyxy, + } + from_xyxy_converters = { + "xyxy": self._xyxy_to_xyxy, + "yxyx": self._xyxy_to_yxyx, + "xywh": self._xyxy_to_xywh, + "center_xywh": self._xyxy_to_center_xywh, + "center_yxhw": self._xyxy_to_center_yxhw, + "rel_xyxy": self._xyxy_to_rel_xyxy, + "rel_yxyx": self._xyxy_to_rel_yxyx, + "rel_xywh": self._xyxy_to_rel_xywh, + "rel_center_xywh": self._xyxy_to_rel_center_xywh, + } + + ops = self.backend + boxes_shape = ops.shape(boxes) + if boxes_shape[-1] != 4: + raise ValueError( + "`boxes` must be a tensor with the last dimension of 4. " + f"Received: boxes.shape={boxes_shape}" + ) + source = source.lower() + target = target.lower() + if source not in SUPPORTED_FORMATS or target not in SUPPORTED_FORMATS: + raise ValueError( + f"Invalid source or target format. " + f"Supported formats: {SUPPORTED_FORMATS}" + ) + + if (source.startswith("rel_") or target.startswith("rel_")) and ( + width is None or height is None + ): + raise ValueError( + "convert_format() must receive `height` and `width` " + "transforming between relative and absolute formats." + f"convert_format() received source=`{source}`, " + f"target=`{target}, " + f"but height={height} and width={width}." + ) + boxes = ops.cast(boxes, dtype) + if source == target: + return boxes + if width is not None: + width = ops.cast(width, dtype) + if height is not None: + height = ops.cast(height, dtype) + + if source.startswith("rel_") and target.startswith("rel_"): + source = source.replace("rel_", "", 1) + target = target.replace("rel_", "", 1) + to_xyxy_converter = to_xyxy_converters[source] + from_xyxy_converter = from_xyxy_converters[target] + in_xyxy_boxes = to_xyxy_converter(boxes, height, width) + return from_xyxy_converter(in_xyxy_boxes, height, width) + + def clip_to_image_size( + self, + bounding_boxes, + height=None, + width=None, + bounding_box_format="xyxy", + ): + if bounding_box_format not in ("xyxy", "rel_xyxy"): + raise NotImplementedError + if bounding_box_format == "xyxy" and (height is None or width is None): + raise ValueError( + "`height` and `width` must be set if `format='xyxy'`." + ) + + ops = self.backend + boxes, labels = bounding_boxes["boxes"], bounding_boxes["labels"] + if width is not None: + width = ops.cast(width, boxes.dtype) + if height is not None: + height = ops.cast(height, boxes.dtype) + + if bounding_box_format == "xyxy": + x1, y1, x2, y2 = ops.numpy.split(boxes, 4, axis=-1) + x1 = ops.numpy.clip(x1, 0, width) + y1 = ops.numpy.clip(y1, 0, height) + x2 = ops.numpy.clip(x2, 0, width) + y2 = ops.numpy.clip(y2, 0, height) + boxes = ops.numpy.concatenate([x1, y1, x2, y2], axis=-1) + + areas = self._compute_area(boxes) + areas = ops.numpy.squeeze(areas, axis=-1) + labels = ops.numpy.where(areas > 0, labels, -1) + elif bounding_box_format == "rel_xyxy": + x1, y1, x2, y2 = ops.numpy.split(boxes, 4, axis=-1) + x1 = ops.numpy.clip(x1, 0.0, 1.0) + y1 = ops.numpy.clip(y1, 0.0, 1.0) + x2 = ops.numpy.clip(x2, 0.0, 1.0) + y2 = ops.numpy.clip(y2, 0.0, 1.0) + boxes = ops.numpy.concatenate([x1, y1, x2, y2], axis=-1) + + areas = self._compute_area(boxes) + areas = ops.numpy.squeeze(areas, axis=-1) + labels = ops.numpy.where(areas > 0, labels, -1) + + result = bounding_boxes.copy() + result["boxes"] = boxes + result["labels"] = labels + return result + + def affine( + self, + boxes, + angle, + translate_x, + translate_y, + scale, + shear_x, + shear_y, + height, + width, + center_x=None, + center_y=None, + ): + ops = self.backend + + boxes_shape = ops.shape(boxes) + batch_size = boxes_shape[0] + n_boxes = boxes_shape[1] + if center_x is None: + center_x = 0.5 + if center_y is None: + center_y = 0.5 + matrix = self._compute_inverse_affine_matrix( + center_x, + center_y, + angle, + translate_x, + translate_y, + scale, + shear_x, + shear_y, + height, + width, + ) + boxes = ops.cast(boxes, dtype=matrix.dtype) + transposed_matrix = ops.numpy.transpose(matrix[:, :2, :], [0, 2, 1]) + points = boxes # [B, N, 4] + points = ops.numpy.stack( + [ + points[..., 0], + points[..., 1], + points[..., 2], + points[..., 1], + points[..., 2], + points[..., 3], + points[..., 0], + points[..., 3], + ], + axis=-1, + ) + points = ops.numpy.reshape(points, [batch_size, n_boxes, 4, 2]) + points = ops.numpy.concatenate( + [ + points, + ops.numpy.ones([batch_size, n_boxes, 4, 1], points.dtype), + ], + axis=-1, + ) + transformed_points = ops.numpy.einsum( + "bnxy,byz->bnxz", points, transposed_matrix + ) + boxes_min = ops.numpy.amin(transformed_points, axis=2) + boxes_max = ops.numpy.amax(transformed_points, axis=2) + outputs = ops.numpy.concatenate([boxes_min, boxes_max], axis=-1) + return outputs + + def crop(self, boxes, top, left, height, width): + ops = self.backend + + x1, y1, x2, y2 = ops.numpy.split(boxes, 4, axis=-1) + x1 = x1 - left + y1 = y1 - top + x2 = x2 - left + y2 = y2 - top + x1 = ops.numpy.clip(x1, 0, width) + y1 = ops.numpy.clip(y1, 0, height) + x2 = ops.numpy.clip(x2, 0, width) + y2 = ops.numpy.clip(y2, 0, height) + outputs = ops.numpy.concatenate([x1, y1, x2, y2], axis=-1) + return outputs + + def pad(self, boxes, top, left): + ops = self.backend + + x1, y1, x2, y2 = ops.numpy.split(boxes, 4, axis=-1) + x1 = x1 + left + y1 = y1 + top + x2 = x2 + left + y2 = y2 + top + outputs = ops.numpy.concatenate([x1, y1, x2, y2], axis=-1) + return outputs + + # Converters + + def _xyxy_to_xyxy(self, boxes, height=None, width=None): + return boxes + + def _yxyx_to_xyxy(self, boxes, height=None, width=None): + y1, x1, y2, x2 = self.backend.numpy.split(boxes, 4, axis=-1) + return self.backend.numpy.concatenate([x1, y1, x2, y2], axis=-1) + + def _xywh_to_xyxy(self, boxes, height=None, width=None): + x1, y1, w, h = self.backend.numpy.split(boxes, 4, axis=-1) + x2 = x1 + w + y2 = y1 + h + return self.backend.numpy.concatenate([x1, y1, x2, y2], axis=-1) + + def _center_xywh_to_xyxy(self, boxes, height=None, width=None): + ops = self.backend + cx, cy, w, h = ops.numpy.split(boxes, 4, axis=-1) + half_w = w / 2.0 + half_h = h / 2.0 + x1 = cx - half_w + y1 = cy - half_h + x2 = cx + half_w + y2 = cy + half_h + return self.backend.numpy.concatenate([x1, y1, x2, y2], axis=-1) + + def _center_yxhw_to_xyxy(self, boxes, height=None, width=None): + ops = self.backend + cy, cx, h, w = ops.numpy.split(boxes, 4, axis=-1) + half_w = w / 2.0 + half_h = h / 2.0 + x1 = cx - half_w + y1 = cy - half_h + x2 = cx + half_w + y2 = cy + half_h + return self.backend.numpy.concatenate([x1, y1, x2, y2], axis=-1) + + def _rel_xyxy_to_xyxy(self, boxes, height=None, width=None): + ops = self.backend + rel_x1, rel_y1, rel_x2, rel_y2 = ops.numpy.split(boxes, 4, axis=-1) + x1 = rel_x1 * width + y1 = rel_y1 * height + x2 = rel_x2 * width + y2 = rel_y2 * height + return self.backend.numpy.concatenate([x1, y1, x2, y2], axis=-1) + + def _rel_yxyx_to_xyxy(self, boxes, height=None, width=None): + ops = self.backend + rel_y1, rel_x1, rel_y2, rel_x2 = ops.numpy.split(boxes, 4, axis=-1) + x1 = rel_x1 * width + y1 = rel_y1 * height + x2 = rel_x2 * width + y2 = rel_y2 * height + return self.backend.numpy.concatenate([x1, y1, x2, y2], axis=-1) + + def _rel_xywh_to_xyxy(self, boxes, height=None, width=None): + ops = self.backend + rel_x1, rel_y1, rel_w, rel_h = ops.numpy.split(boxes, 4, axis=-1) + x1 = rel_x1 * width + y1 = rel_y1 * height + x2 = (rel_x1 + rel_w) * width + y2 = (rel_y1 + rel_h) * height + return self.backend.numpy.concatenate([x1, y1, x2, y2], axis=-1) + + def _rel_center_xywh_to_xyxy(self, boxes, height=None, width=None): + ops = self.backend + rel_cx, rel_cy, rel_w, rel_h = ops.numpy.split(boxes, 4, axis=-1) + half_rel_w = rel_w / 2.0 + half_rel_h = rel_h / 2.0 + x1 = (rel_cx - half_rel_w) * height + y1 = (rel_cy - half_rel_h) * width + x2 = (rel_cx + half_rel_w) * height + y2 = (rel_cy + half_rel_h) * width + return self.backend.numpy.concatenate([x1, y1, x2, y2], axis=-1) + + def _xyxy_to_yxyx(self, boxes, height=None, width=None): + x1, y1, x2, y2 = self.backend.numpy.split(boxes, 4, axis=-1) + return self.backend.numpy.concatenate([y1, x1, y2, x2], axis=-1) + + def _xyxy_to_xywh(self, boxes, height=None, width=None): + x1, y1, x2, y2 = self.backend.numpy.split(boxes, 4, axis=-1) + w = x2 - x1 + h = y2 - y1 + return self.backend.numpy.concatenate([x1, y1, w, h], axis=-1) + + def _xyxy_to_center_xywh(self, boxes, height=None, width=None): + x1, y1, x2, y2 = self.backend.numpy.split(boxes, 4, axis=-1) + cx = x1 + ((x2 - x1) / 2.0) + cy = y1 + ((y2 - y1) / 2.0) + w = x2 - x1 + h = y2 - y1 + return self.backend.numpy.concatenate([cx, cy, w, h], axis=-1) + + def _xyxy_to_center_yxhw(self, boxes, height=None, width=None): + x1, y1, x2, y2 = self.backend.numpy.split(boxes, 4, axis=-1) + cx = x1 + ((x2 - x1) / 2.0) + cy = y1 + ((y2 - y1) / 2.0) + w = x2 - x1 + h = y2 - y1 + return self.backend.numpy.concatenate([cy, cx, h, w], axis=-1) + + def _xyxy_to_rel_xyxy(self, boxes, height=None, width=None): + x1, y1, x2, y2 = self.backend.numpy.split(boxes, 4, axis=-1) + rel_x1 = self.backend.numpy.divide(x1, width) + rel_y1 = self.backend.numpy.divide(y1, height) + rel_x2 = self.backend.numpy.divide(x2, width) + rel_y2 = self.backend.numpy.divide(y2, height) + return self.backend.numpy.concatenate( + [rel_x1, rel_y1, rel_x2, rel_y2], axis=-1 + ) + + def _xyxy_to_rel_yxyx(self, boxes, height=None, width=None): + x1, y1, x2, y2 = self.backend.numpy.split(boxes, 4, axis=-1) + rel_x1 = self.backend.numpy.divide(x1, width) + rel_y1 = self.backend.numpy.divide(y1, height) + rel_x2 = self.backend.numpy.divide(x2, width) + rel_y2 = self.backend.numpy.divide(y2, height) + return self.backend.numpy.concatenate( + [rel_y1, rel_x1, rel_y2, rel_x2], axis=-1 + ) + + def _xyxy_to_rel_xywh(self, boxes, height=None, width=None): + x1, y1, x2, y2 = self.backend.numpy.split(boxes, 4, axis=-1) + rel_x1 = x1 / width + rel_y1 = y1 / height + rel_w = (x2 - x1) / width + rel_h = (y2 - y1) / height + return self.backend.numpy.concatenate( + [rel_x1, rel_y1, rel_w, rel_h], axis=-1 + ) + + def _xyxy_to_rel_center_xywh(self, boxes, height=None, width=None): + x1, y1, x2, y2 = self.backend.numpy.split(boxes, 4, axis=-1) + rel_cx = (x1 + ((x2 - x1) / 2.0)) / width + rel_cy = (y1 + ((y2 - y1) / 2.0)) / height + rel_w = (x2 - x1) / width + rel_h = (y2 - y1) / height + return self.backend.numpy.concatenate( + [rel_cx, rel_cy, rel_w, rel_h], axis=-1 + ) + + # Clip + def _compute_area(self, boxes, format="xyxy"): + if format not in ("xyxy", "rel_xyxy"): + raise NotImplementedError + + ops = self.backend + x1, y1, x2, y2 = ops.numpy.split(boxes, 4, axis=-1) + widths = x2 - x1 + heights = y2 - y1 + return widths * heights + + def _compute_inverse_affine_matrix( + self, + center_x, + center_y, + angle, + translate_x, + translate_y, + scale, + shear_x, + shear_y, + height, + width, + ): + # Ref: TF._geometry._get_inverse_affine_matrix + ops = self.backend + batch_size = ops.shape(angle)[0] + dtype = angle.dtype + + angle = -angle + shear_x = -shear_x + shear_y = -shear_y + + cx = ops.numpy.multiply(center_x, (width - 1)) + cy = ops.numpy.multiply(center_y, (height - 1)) + rot = ops.numpy.multiply(angle, 1.0 / 180.0 * math.pi) + tx = ops.numpy.multiply(-translate_x, (width - 1)) + ty = ops.numpy.multiply(-translate_y, (height - 1)) + sx = ops.numpy.multiply(shear_x, 1.0 / 180.0 * math.pi) + sy = ops.numpy.multiply(shear_y, 1.0 / 180.0 * math.pi) + + # Cached results + cos_sy = ops.numpy.cos(sy) + tan_sx = ops.numpy.tan(sx) + rot_minus_sy = rot - sy + cx_plus_tx = cx + tx + cy_plus_ty = cy + ty + + # Rotate Scale Shear (RSS) without scaling + a = ops.numpy.cos(rot_minus_sy) / cos_sy + b = a * tan_sx + ops.numpy.sin(rot) + c = -ops.numpy.sin(rot_minus_sy) / cos_sy + d = ops.numpy.cos(rot) - c * tan_sx + + # Inverted rotation matrix with scale and shear + # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1 + a0 = ops.numpy.multiply(d, scale) + a1 = ops.numpy.multiply(-b, scale) + b0 = ops.numpy.multiply(-c, scale) + b1 = ops.numpy.multiply(a, scale) + a2 = cx - a0 * cx_plus_tx - a1 * cy_plus_ty + b2 = cy - b0 * cx_plus_tx - b1 * cy_plus_ty + + # Shape of matrix: [[batch_size], ...] -> [batch_size, 6] + matrix = ops.numpy.stack( + [ + a0, + a1, + a2, + b0, + b1, + b2, + ops.numpy.zeros([batch_size], dtype), + ops.numpy.zeros([batch_size], dtype), + ops.numpy.ones([batch_size], dtype), + ], + axis=-1, + ) + matrix = ops.numpy.reshape(matrix, [batch_size, 3, 3]) + return matrix diff --git a/keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/converters.py b/keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/converters.py new file mode 100644 index 000000000000..6a6d6f9867b9 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/converters.py @@ -0,0 +1,448 @@ +from keras.src import backend +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.bounding_box import ( # noqa: E501 + BoundingBox, +) +from keras.src.utils import backend_utils + + +@keras_export("keras.utils.bounding_boxes.convert_format") +def convert_format( + boxes, source, target, height=None, width=None, dtype="float32" +): + """Converts bounding boxes between formats. + + Supported formats (case-insensitive): + `"xyxy"`: [left, top, right, bottom] + `"yxyx"`: [top, left, bottom, right] + `"xywh"`: [left, top, width, height] + `"center_xywh"`: [center_x, center_y, width, height] + `"center_yxhw"`: [center_y, center_x, height, width] + `"rel_xyxy"`, `"rel_yxyx"`, `"rel_xywh"`, `"rel_center_xywh"`: Relative + versions of the above formats, where coordinates are normalized + to the range [0, 1] based on the image `height` and `width`. + + Args: + boxes: Bounding boxes tensor/array or dictionary of `boxes` and + `labels`. + source: Source format string. + target: Target format string. + height: Image height (required for relative target format). + width: Image width (required for relative target format). + dtype: Data type for conversion (optional). + + Returns: + Converted boxes. + + Raises: + ValueError: For invalid formats, shapes, or missing dimensions. + + Example: + ```python + boxes = np.array([[10, 20, 30, 40], [50, 60, 70, 80]]) + # Convert from 'xyxy' to 'xywh' format + boxes_xywh = keras.utils.bounding_boxes.convert_format( + boxes, source='xyxy', target='xywh' + ) # Output: [[10. 20. 20. 20.], [50. 60. 20. 20.]] + + # Convert to relative 'rel_xyxy' format + boxes_rel_xyxy = keras.utils.bounding_boxes.convert_format( + boxes, source='xyxy', target='rel_xyxy', height=200, width=300 + ) # Output: [[0.03333334 0.1 0.1 0.2 ], + #[0.16666667 0.3 0.23333333 0.4 ]] + ``` + """ + box_utils = BoundingBox() + # Switch to tensorflow backend if we are in tf.data pipe + if backend_utils.in_tf_graph(): + box_utils.backend.set_backend("tensorflow") + boxes = box_utils.convert_format( + boxes=boxes, + source=source, + target=target, + height=height, + width=width, + dtype=dtype, + ) + # Switch back to original backend + box_utils.backend.reset() + return boxes + + +@keras_export("keras.utils.bounding_boxes.clip_to_image_size") +def clip_to_image_size( + bounding_boxes, height=None, width=None, bounding_box_format="xyxy" +): + """Clips bounding boxes to be within the image dimensions. + Args: + bounding_boxes: A dictionary with 'boxes' shape `(N, 4)` or + `(batch, N, 4)` and 'labels' shape `(N,)` or `(batch, N,)`. + height: Image height. + width: Image width. + bounding_box_format: The format of the input bounding boxes. Defaults to + `"xyxy"`. + + Returns: + Clipped bounding boxes. + + Example: + ```python + boxes = {"boxes": np.array([[-10, -20, 150, 160], [50, 40, 70, 80]]), + "labels": np.array([0, 1])} + clipped_boxes = keras.utils.bounding_boxes.clip_to_image_size( + boxes, height=100, width=120, + ) + # Output will have boxes clipped to the image boundaries, and labels + # potentially adjusted if the clipped area becomes zero + ``` + """ + + box_utils = BoundingBox() + # Switch to tensorflow backend if we are in tf.data pipe + if backend_utils.in_tf_graph(): + box_utils.backend.set_backend("tensorflow") + bounding_boxes = box_utils.clip_to_image_size( + bounding_boxes, + height=height, + width=width, + bounding_box_format=bounding_box_format, + ) + # Switch back to original backend + box_utils.backend.reset() + return bounding_boxes + + +@keras_export("keras.utils.bounding_boxes.affine_transform") +def affine_transform( + boxes, + angle, + translate_x, + translate_y, + scale, + shear_x, + shear_y, + height, + width, + center_x=None, + center_y=None, + bounding_box_format="xyxy", +): + """Applies an affine transformation to the bounding boxes. + + The `height` and `width` parameters are used to normalize the + translation and scaling factors. + + Args: + boxes: The bounding boxes to transform, a tensor/array of shape + `(N, 4)` or `(batch_size, N, 4)`. + angle: Rotation angle in degrees. + translate_x: Horizontal translation fraction. + translate_y: Vertical translation fraction. + scale: Scaling factor. + shear_x: Shear angle in x-direction (degrees). + shear_y: Shear angle in y-direction (degrees). + height: Height of the image/data. + width: Width of the image/data. + center_x: x-coordinate of the transformation center (fraction). + center_y: y-coordinate of the transformation center (fraction). + bounding_box_format: The format of the input bounding boxes. Defaults to + `"xyxy"`. + + Returns: + The transformed bounding boxes, a tensor/array with the same shape + as the input `boxes`. + """ + if bounding_box_format != "xyxy": + raise NotImplementedError + box_utils = BoundingBox() + # Switch to tensorflow backend if we are in tf.data pipe + if backend_utils.in_tf_graph(): + box_utils.backend.set_backend("tensorflow") + + boxes = box_utils.affine( + boxes, + angle, + translate_x, + translate_y, + scale, + shear_x, + shear_y, + height, + width, + center_x=center_x, + center_y=center_y, + ) + box_utils.backend.reset() + return boxes + + +@keras_export("keras.utils.bounding_boxes.crop") +def crop(boxes, top, left, height, width, bounding_box_format="xyxy"): + """Crops bounding boxes based on the given offsets and dimensions. + + This function crops bounding boxes to a specified region defined by + `top`, `left`, `height`, and `width`. The boxes are first converted to + `xyxy` format, cropped, and then returned. + + Args: + boxes: The bounding boxes to crop. A NumPy array or tensor of shape + `(N, 4)` or `(batch_size, N, 4)`. + top: The vertical offset of the top-left corner of the cropping region. + left: The horizontal offset of the top-left corner of the cropping + region. + height: The height of the cropping region. Defaults to `None`. + width: The width of the cropping region. Defaults to `None`. + bounding_box_format: The format of the input bounding boxes. Defaults to + `"xyxy"`. + + Returns: + The cropped bounding boxes. + + Example: + ```python + boxes = np.array([[10, 20, 50, 60], [70, 80, 100, 120]]) # xyxy format + cropped_boxes = keras.utils.bounding_boxes.crop( + boxes, bounding_box_format="xyxy", top=10, left=20, height=40, width=30 + ) # Cropping a 30x40 region starting at (20, 10) + print(cropped_boxes) + # Expected output: + # array([[ 0., 10., 30., 50.], + # [50., 70., 80., 110.]]) + """ + if bounding_box_format != "xyxy": + raise NotImplementedError + box_utils = BoundingBox() + # Switch to tensorflow backend if we are in tf.data pipe + if backend_utils.in_tf_graph(): + box_utils.backend.set_backend("tensorflow") + outputs = box_utils.crop(boxes, top, left, height, width) + box_utils.backend.reset() + return outputs + + +@keras_export("keras.utils.bounding_boxes.pad") +def pad(boxes, top, left, height=None, width=None, bounding_box_format="xyxy"): + """Pads bounding boxes by adding top and left offsets. + + This function adds padding to the bounding boxes by increasing the 'top' + and 'left' coordinates by the specified amounts. The method assume the + input bounding_box_format is `xyxy`. + + Args: + boxes: Bounding boxes to pad. Shape `(N, 4)` or `(batch, N, 4)`. + top: Vertical padding to add. + left: Horizontal padding to add. + height: Image height. Defaults to None. + width: Image width. Defaults to None. + bounding_box_format: The format of the input bounding boxes. Defaults to + `"xyxy"`. + + Returns: + Padded bounding boxes in the original format. + """ + if bounding_box_format != "xyxy": + raise NotImplementedError + box_utils = BoundingBox() + # Switch to tensorflow backend if we are in tf.data pipe + if backend_utils.in_tf_graph(): + box_utils.backend.set_backend("tensorflow") + outputs = box_utils.pad(boxes, top, left) + box_utils.backend.reset() + return outputs + + +@keras_export("keras.utils.bounding_boxes.encode_box_to_deltas") +def encode_box_to_deltas( + anchors, + boxes, + anchor_format, + box_format, + encoding_format="center_yxhw", + variance=None, + image_shape=None, +): + """Encodes bounding boxes relative to anchors as deltas. + + This function calculates the deltas that represent the difference between + bounding boxes and provided anchors. Deltas encode the offsets and scaling + factors to apply to anchors to obtain the target boxes. + + Boxes and anchors are first converted to the specified `encoding_format` + (defaulting to `center_yxhw`) for consistent delta representation. + + Args: + anchors: `Tensors`. Anchor boxes with shape of `(N, 4)` where N is the + number of anchors. + boxes: `Tensors` Bounding boxes to encode. Boxes can be of shape + `(B, N, 4)` or `(N, 4)`. + anchor_format: str. The format of the input `anchors` + (e.g., "xyxy", "xywh", etc.). + box_format: str. The format of the input `boxes` + (e.g., "xyxy", "xywh", etc.). + encoding_format: str. The intermediate format to which boxes and anchors + are converted before delta calculation. Defaults to "center_yxhw". + variance: `List[float]`. A 4-element array/tensor representing variance + factors to scale the box deltas. If provided, the calculated deltas + are divided by the variance. Defaults to None. + image_shape: `Tuple[int]`. The shape of the image (height, width, 3). + When using relative bounding box format for `box_format` the + `image_shape` is used for normalization. + Returns: + Encoded box deltas. The return type matches the `encode_format`. + + Raises: + ValueError: If `variance` is not None and its length is not 4. + ValueError: If `encoding_format` is not `"center_xywh"` or + `"center_yxhw"`. + + """ + if variance is not None: + variance = ops.convert_to_tensor(variance, "float32") + var_len = variance.shape[-1] + + if var_len != 4: + raise ValueError(f"`variance` must be length 4, got {variance}") + + if encoding_format not in ["center_xywh", "center_yxhw"]: + raise ValueError( + "`encoding_format` should be one of 'center_xywh' or " + f"'center_yxhw', got {encoding_format}" + ) + + if image_shape is None: + height, width = None, None + else: + height, width, _ = image_shape + + encoded_anchors = convert_format( + anchors, + source=anchor_format, + target=encoding_format, + height=height, + width=width, + ) + boxes = convert_format( + boxes, + source=box_format, + target=encoding_format, + height=height, + width=width, + ) + anchor_dimensions = ops.maximum(encoded_anchors[..., 2:], backend.epsilon()) + box_dimensions = ops.maximum(boxes[..., 2:], backend.epsilon()) + # anchors be unbatched, boxes can either be batched or unbatched. + boxes_delta = ops.concatenate( + [ + (boxes[..., :2] - encoded_anchors[..., :2]) / anchor_dimensions, + ops.log(box_dimensions / anchor_dimensions), + ], + axis=-1, + ) + if variance is not None: + boxes_delta /= variance + return boxes_delta + + +@keras_export("keras.utils.bounding_boxes.decode_deltas_to_boxes") +def decode_deltas_to_boxes( + anchors, + boxes_delta, + anchor_format, + box_format, + encoded_format="center_yxhw", + variance=None, + image_shape=None, +): + """Converts bounding boxes from delta format to the specified `box_format`. + + This function decodes bounding box deltas relative to anchors to obtain the + final bounding box coordinates. The boxes are encoded in a specific + `encoded_format` (center_yxhw by default) during the decoding process. + This allows flexibility in how the deltas are applied to the anchors. + + Args: + anchors: Can be `Tensors` or `Dict[Tensors]` where keys are level + indices and values are corresponding anchor boxes. + The shape of the array/tensor should be `(N, 4)` where N is the + number of anchors. + boxes_delta Can be `Tensors` or `Dict[Tensors]` Bounding box deltas + must have the same type and structure as `anchors`. The + shape of the array/tensor can be `(N, 4)` or `(B, N, 4)` where N is + the number of boxes. + anchor_format: str. The format of the input `anchors`. + (e.g., `"xyxy"`, `"xywh"`, etc.) + box_format: str. The desired format for the output boxes. + (e.g., `"xyxy"`, `"xywh"`, etc.) + encoded_format: str. Raw output format from regression head. Defaults + to `"center_yxhw"`. + variance: `List[floats]`. A 4-element array/tensor representing + variance factors to scale the box deltas. If provided, the deltas + are multiplied by the variance before being applied to the anchors. + Defaults to None. + image_shape: `Tuple[int]`. The shape of the image (height, width, 3). + When using relative bounding box format for `box_format` the + `image_shape` is used for normalization. + + Returns: + Decoded box coordinates. The return type matches the `box_format`. + + Raises: + ValueError: If `variance` is not None and its length is not 4. + ValueError: If `encoded_format` is not `"center_xywh"` or + `"center_yxhw"`. + + """ + if variance is not None: + variance = ops.convert_to_tensor(variance, "float32") + var_len = variance.shape[-1] + + if var_len != 4: + raise ValueError(f"`variance` must be length 4, got {variance}") + + if encoded_format not in ["center_xywh", "center_yxhw"]: + raise ValueError( + f"`encoded_format` should be 'center_xywh' or 'center_yxhw', " + f"but got '{encoded_format}'." + ) + + if image_shape is None: + height, width = None, None + else: + height, width, _ = image_shape + + def decode_single_level(anchor, box_delta): + encoded_anchor = convert_format( + anchor, + source=anchor_format, + target=encoded_format, + height=height, + width=width, + ) + if variance is not None: + box_delta = box_delta * variance + # anchors be unbatched, boxes can either be batched or unbatched. + box = ops.concatenate( + [ + box_delta[..., :2] * encoded_anchor[..., 2:] + + encoded_anchor[..., :2], + ops.exp(box_delta[..., 2:]) * encoded_anchor[..., 2:], + ], + axis=-1, + ) + box = convert_format( + box, + source=encoded_format, + target=box_format, + height=height, + width=width, + ) + return box + + if isinstance(anchors, dict) and isinstance(boxes_delta, dict): + boxes = {} + for lvl, anchor in anchors.items(): + boxes[lvl] = decode_single_level(anchor, boxes_delta[lvl]) + return boxes + else: + return decode_single_level(anchors, boxes_delta) diff --git a/keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/converters_test.py b/keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/converters_test.py new file mode 100644 index 000000000000..9c6638698cc3 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/converters_test.py @@ -0,0 +1,144 @@ +import itertools + +import numpy as np +from absl.testing import parameterized + +from keras.src import ops +from keras.src import testing +from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import ( # noqa: E501 + affine_transform, +) +from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import ( # noqa: E501 + clip_to_image_size, +) +from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import ( # noqa: E501 + convert_format, +) + + +class ConvertersTest(testing.TestCase): + def setUp(self): + xyxy_box = np.array( + [[[10, 20, 110, 120], [20, 30, 120, 130]]], dtype="float32" + ) + yxyx_box = np.array( + [[[20, 10, 120, 110], [30, 20, 130, 120]]], dtype="float32" + ) + rel_xyxy_box = np.array( + [[[0.01, 0.02, 0.11, 0.12], [0.02, 0.03, 0.12, 0.13]]], + dtype="float32", + ) + rel_yxyx_box = np.array( + [[[0.02, 0.01, 0.12, 0.11], [0.03, 0.02, 0.13, 0.12]]], + dtype="float32", + ) + center_xywh_box = np.array( + [[[60, 70, 100, 100], [70, 80, 100, 100]]], dtype="float32" + ) + center_yxhw_box = np.array( + [[[70, 60, 100, 100], [80, 70, 100, 100]]], dtype="float32" + ) + xywh_box = np.array( + [[[10, 20, 100, 100], [20, 30, 100, 100]]], dtype="float32" + ) + rel_xywh_box = np.array( + [[[0.01, 0.02, 0.1, 0.1], [0.02, 0.03, 0.1, 0.1]]], dtype="float32" + ) + + self.images = np.ones([2, 1000, 1000, 3], dtype="float32") + self.height = 1000 + self.width = 1000 + + self.boxes = { + "xyxy": xyxy_box, + "center_xywh": center_xywh_box, + "rel_xywh": rel_xywh_box, + "xywh": xywh_box, + "rel_xyxy": rel_xyxy_box, + "yxyx": yxyx_box, + "rel_yxyx": rel_yxyx_box, + "center_yxhw": center_yxhw_box, + } + + @parameterized.named_parameters( + *[ + (f"{source}_{target}", source, target) + for (source, target) in itertools.permutations( + [ + "xyxy", + "yxyx", + "xywh", + "rel_xyxy", + "rel_yxyx", + "center_xywh", + "center_yxhw", + ], + 2, + ) + ] + + [("xyxy_xyxy", "xyxy", "xyxy")] + ) + def test_convert_all_formats(self, source, target): + source_box = self.boxes[source] + target_box = self.boxes[target] + self.assertAllClose( + convert_format( + source_box, + source=source, + target=target, + height=self.height, + width=self.width, + ), + target_box, + ) + + def test_convert_format_invalid_source(self): + boxes = self.boxes["xywh"] + with self.assertRaises(ValueError): + convert_format(boxes, source="invalid", target="xywh") + + def test_convert_format_invalid_target(self): + boxes = self.boxes["xyxy"] + with self.assertRaises(ValueError): + convert_format(boxes, source="xyxy", target="invalid") + + def test_convert_format_missing_dimensions(self): + boxes = self.boxes["xyxy"] + with self.assertRaisesRegex( + ValueError, r"must receive `height` and `width`" + ): + convert_format(boxes, source="xyxy", target="rel_xyxy") + + def test_clip_to_image_size(self): + boxes = { + "boxes": np.array([[0.0, 0.0, 1.5, 1.6], [0.5, 0.4, 0.7, 0.8]]), + "labels": np.array([0, 1]), + } + + expected_clipped = { + "boxes": np.array([[0.0, 0.0, 1.0, 1.0], [0.5, 0.4, 0.7, 0.8]]), + "labels": np.array([0, 1]), + } + + clipped_boxes = clip_to_image_size( + boxes, bounding_box_format="rel_xyxy" + ) + + self.assertAllEqual(clipped_boxes, expected_clipped) + + def test_affine_identity(self): + # Test identity transform (no change) + batch_size = self.boxes["xyxy"].shape[0] + transformed_boxes = affine_transform( + boxes=self.boxes["xyxy"], + angle=np.zeros([batch_size], dtype="float32"), + translate_x=np.zeros([batch_size], dtype="float32"), + translate_y=np.zeros([batch_size], dtype="float32"), + scale=np.ones([batch_size], dtype="float32"), + shear_x=np.zeros([batch_size], dtype="float32"), + shear_y=np.zeros([batch_size], dtype="float32"), + height=self.height, + width=self.width, + ) + transformed_boxes = ops.convert_to_numpy(transformed_boxes) + self.assertAllClose(self.boxes["xyxy"], transformed_boxes) diff --git a/keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/formats.py b/keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/formats.py new file mode 100644 index 000000000000..38baf4964b1e --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/formats.py @@ -0,0 +1,135 @@ +class XYXY: + """XYXY contains axis indices for the XYXY format. + + All values in the XYXY format should be absolute pixel values. + + The XYXY format consists of the following required indices: + + - LEFT: left of the bounding box + - TOP: top of the bounding box + - RIGHT: right of the bounding box + - BOTTOM: bottom of the bounding box + """ + + LEFT = 0 + TOP = 1 + RIGHT = 2 + BOTTOM = 3 + + +class REL_XYXY: + """REL_XYXY contains axis indices for the REL_XYXY format. + + REL_XYXY is like XYXY, but each value is relative to the width and height of + the origin image. Values are percentages of the origin images' width and + height respectively. + + The REL_XYXY format consists of the following required indices: + + - LEFT: left of the bounding box + - TOP: top of the bounding box + - RIGHT: right of the bounding box + - BOTTOM: bottom of the bounding box + """ + + LEFT = 0 + TOP = 1 + RIGHT = 2 + BOTTOM = 3 + + +class CENTER_XYWH: + """CENTER_XYWH contains axis indices for the CENTER_XYWH format. + + All values in the CENTER_XYWH format should be absolute pixel values. + + The CENTER_XYWH format consists of the following required indices: + + - X: X coordinate of the center of the bounding box + - Y: Y coordinate of the center of the bounding box + - WIDTH: width of the bounding box + - HEIGHT: height of the bounding box + """ + + X = 0 + Y = 1 + WIDTH = 2 + HEIGHT = 3 + + +class XYWH: + """XYWH contains axis indices for the XYWH format. + + All values in the XYWH format should be absolute pixel values. + + The XYWH format consists of the following required indices: + + - X: X coordinate of the left of the bounding box + - Y: Y coordinate of the top of the bounding box + - WIDTH: width of the bounding box + - HEIGHT: height of the bounding box + """ + + X = 0 + Y = 1 + WIDTH = 2 + HEIGHT = 3 + + +class REL_XYWH: + """REL_XYWH contains axis indices for the XYWH format. + + REL_XYXY is like XYWH, but each value is relative to the width and height of + the origin image. Values are percentages of the origin images' width and + height respectively. + + - X: X coordinate of the left of the bounding box + - Y: Y coordinate of the top of the bounding box + - WIDTH: width of the bounding box + - HEIGHT: height of the bounding box + """ + + X = 0 + Y = 1 + WIDTH = 2 + HEIGHT = 3 + + +class YXYX: + """YXYX contains axis indices for the YXYX format. + + All values in the YXYX format should be absolute pixel values. + + The YXYX format consists of the following required indices: + + - TOP: top of the bounding box + - LEFT: left of the bounding box + - BOTTOM: bottom of the bounding box + - RIGHT: right of the bounding box + """ + + TOP = 0 + LEFT = 1 + BOTTOM = 2 + RIGHT = 3 + + +class REL_YXYX: + """REL_YXYX contains axis indices for the REL_YXYX format. + + REL_YXYX is like YXYX, but each value is relative to the width and height of + the origin image. Values are percentages of the origin images' width and + height respectively. + + The REL_YXYX format consists of the following required indices: + + - TOP: top of the bounding box + - LEFT: left of the bounding box + - BOTTOM: bottom of the bounding box + - RIGHT: right of the bounding box + """ + + TOP = 0 + LEFT = 1 + BOTTOM = 2 + RIGHT = 3 diff --git a/keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/iou.py b/keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/iou.py new file mode 100644 index 000000000000..8e4006ea9713 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/iou.py @@ -0,0 +1,281 @@ +"""Contains functions to compute ious of bounding boxes.""" + +import math + +import keras +from keras.src import backend +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes import ( + converters, +) + + +def _compute_area(box): + """Computes area for bounding boxes + + Args: + box: [N, 4] or [batch_size, N, 4] float Tensor, either batched + or unbatched boxes. + Returns: + a float Tensor of [N] or [batch_size, N] + """ + y_min, x_min, y_max, x_max = ops.split(box[..., :4], 4, axis=-1) + return ops.squeeze((y_max - y_min) * (x_max - x_min), axis=-1) + + +def _compute_intersection(boxes1, boxes2): + """Computes intersection area between two sets of boxes. + + Args: + boxes1: [N, 4] or [batch_size, N, 4] float Tensor boxes. + boxes2: [M, 4] or [batch_size, M, 4] float Tensor boxes. + Returns: + a [N, M] or [batch_size, N, M] float Tensor. + """ + y_min1, x_min1, y_max1, x_max1 = ops.split(boxes1[..., :4], 4, axis=-1) + y_min2, x_min2, y_max2, x_max2 = ops.split(boxes2[..., :4], 4, axis=-1) + boxes2_rank = len(boxes2.shape) + perm = [1, 0] if boxes2_rank == 2 else [0, 2, 1] + # [N, M] or [batch_size, N, M] + intersect_ymax = ops.minimum(y_max1, ops.transpose(y_max2, perm)) + intersect_ymin = ops.maximum(y_min1, ops.transpose(y_min2, perm)) + intersect_xmax = ops.minimum(x_max1, ops.transpose(x_max2, perm)) + intersect_xmin = ops.maximum(x_min1, ops.transpose(x_min2, perm)) + + intersect_height = intersect_ymax - intersect_ymin + intersect_width = intersect_xmax - intersect_xmin + zeros_t = ops.cast(0, intersect_height.dtype) + intersect_height = ops.maximum(zeros_t, intersect_height) + intersect_width = ops.maximum(zeros_t, intersect_width) + + return intersect_height * intersect_width + + +@keras_export("keras.utils.bounding_boxes.compute_iou") +def compute_iou( + boxes1, + boxes2, + bounding_box_format, + use_masking=False, + mask_val=-1, + image_shape=None, +): + """Computes a lookup table vector containing the ious for a given set boxes. + + The lookup vector is to be indexed by [`boxes1_index`,`boxes2_index`] if + boxes are unbatched and by [`batch`, `boxes1_index`,`boxes2_index`] if the + boxes are batched. + + The users can pass `boxes1` and `boxes2` to be different ranks. For example: + 1) `boxes1`: [batch_size, M, 4], `boxes2`: [batch_size, N, 4] -> return + [batch_size, M, N]. + 2) `boxes1`: [batch_size, M, 4], `boxes2`: [N, 4] -> return + [batch_size, M, N] + 3) `boxes1`: [M, 4], `boxes2`: [batch_size, N, 4] -> return + [batch_size, M, N] + 4) `boxes1`: [M, 4], `boxes2`: [N, 4] -> return [M, N] + + Args: + boxes1: a list of bounding boxes in 'corners' format. Can be batched or + unbatched. + boxes2: a list of bounding boxes in 'corners' format. Can be batched or + unbatched. + bounding_box_format: a case-insensitive string which is one of `"xyxy"`, + `"rel_xyxy"`, `"xyWH"`, `"center_xyWH"`, `"yxyx"`, `"rel_yxyx"`. + For detailed information on the supported format, see the + use_masking: whether masking will be applied. This will mask all + `boxes1` or `boxes2` that have values less than 0 in all its 4 + dimensions. Default to `False`. + mask_val: int to mask those returned IOUs if the masking is True, + defaults to -1. + image_shape: `Tuple[int]`. The shape of the image (height, width, 3). + When using relative bounding box format for `box_format` the + `image_shape` is used for normalization. + + Returns: + iou_lookup_table: a vector containing the pairwise ious of boxes1 and + boxes2. + """ # noqa: E501 + + boxes1_rank = len(ops.shape(boxes1)) + boxes2_rank = len(ops.shape(boxes2)) + + if boxes1_rank not in [2, 3]: + raise ValueError( + "compute_iou() expects boxes1 to be batched, or to be unbatched. " + f"Received len(boxes1.shape)={boxes1_rank}, " + f"len(boxes2.shape)={boxes2_rank}. Expected either " + "len(boxes1.shape)=2 AND or len(boxes1.shape)=3." + ) + if boxes2_rank not in [2, 3]: + raise ValueError( + "compute_iou() expects boxes2 to be batched, or to be unbatched. " + f"Received len(boxes1.shape)={boxes1_rank}, " + f"len(boxes2.shape)={boxes2_rank}. Expected either " + "len(boxes2.shape)=2 AND or len(boxes2.shape)=3." + ) + + target_format = "yxyx" + if "rel" in bounding_box_format and image_shape is None: + raise ValueError( + "When using relative bounding box formats (e.g. `rel_yxyx`) " + "the `image_shape` argument must be provided." + f"Received `image_shape`: {image_shape}" + ) + + if image_shape is None: + height, width = None, None + else: + height, width, _ = image_shape + + boxes1 = converters.convert_format( + boxes1, + source=bounding_box_format, + target=target_format, + height=height, + width=width, + ) + + boxes2 = converters.convert_format( + boxes2, + source=bounding_box_format, + target=target_format, + height=height, + width=width, + ) + + intersect_area = _compute_intersection(boxes1, boxes2) + boxes1_area = _compute_area(boxes1) + boxes2_area = _compute_area(boxes2) + boxes2_area_rank = len(boxes2_area.shape) + boxes2_axis = 1 if (boxes2_area_rank == 2) else 0 + boxes1_area = ops.expand_dims(boxes1_area, axis=-1) + boxes2_area = ops.expand_dims(boxes2_area, axis=boxes2_axis) + union_area = boxes1_area + boxes2_area - intersect_area + res = ops.divide(intersect_area, union_area + backend.epsilon()) + + if boxes1_rank == 2: + perm = [1, 0] + else: + perm = [0, 2, 1] + + if not use_masking: + return res + + mask_val_t = ops.cast(mask_val, res.dtype) * ops.ones_like(res) + boxes1_mask = ops.less(ops.max(boxes1, axis=-1, keepdims=True), 0.0) + boxes2_mask = ops.less(ops.max(boxes2, axis=-1, keepdims=True), 0.0) + background_mask = ops.logical_or( + boxes1_mask, ops.transpose(boxes2_mask, perm) + ) + iou_lookup_table = ops.where(background_mask, mask_val_t, res) + return iou_lookup_table + + +@keras_export("keras.utils.bounding_boxes.compute_ciou") +def compute_ciou(boxes1, boxes2, bounding_box_format, image_shape=None): + """ + Computes the Complete IoU (CIoU) between two bounding boxes or between + two batches of bounding boxes. + + CIoU loss is an extension of GIoU loss, which further improves the IoU + optimization for object detection. CIoU loss not only penalizes the + bounding box coordinates but also considers the aspect ratio and center + distance of the boxes. The length of the last dimension should be 4 to + represent the bounding boxes. + + Args: + box1 (tensor): tensor representing the first bounding box with + shape (..., 4). + box2 (tensor): tensor representing the second bounding box with + shape (..., 4). + bounding_box_format: a case-insensitive string (for example, "xyxy"). + Each bounding box is defined by these 4 values. For detailed + information on the supported formats, see the [KerasCV bounding box + documentation](https://keras.io/api/keras_cv/bounding_box/formats/). + image_shape: `Tuple[int]`. The shape of the image (height, width, 3). + When using relative bounding box format for `box_format` the + `image_shape` is used for normalization. + + Returns: + tensor: The CIoU distance between the two bounding boxes. + """ + target_format = "xyxy" + if "rel" in bounding_box_format: + raise ValueError( + "When using relative bounding box formats (e.g. `rel_yxyx`) " + "the `image_shape` argument must be provided." + f"Received `image_shape`: {image_shape}" + ) + + if image_shape is None: + height, width = None, None + else: + height, width, _ = image_shape + + boxes1 = converters.convert_format( + boxes1, + source=bounding_box_format, + target=target_format, + height=height, + width=width, + ) + + boxes2 = converters.convert_format( + boxes2, + source=bounding_box_format, + target=target_format, + height=height, + width=width, + ) + + x_min1, y_min1, x_max1, y_max1 = ops.split(boxes1[..., :4], 4, axis=-1) + x_min2, y_min2, x_max2, y_max2 = ops.split(boxes2[..., :4], 4, axis=-1) + + width_1 = x_max1 - x_min1 + height_1 = y_max1 - y_min1 + keras.backend.epsilon() + width_2 = x_max2 - x_min2 + height_2 = y_max2 - y_min2 + keras.backend.epsilon() + + intersection_area = ops.maximum( + ops.minimum(x_max1, x_max2) - ops.maximum(x_min1, x_min2), 0 + ) * ops.maximum( + ops.minimum(y_max1, y_max2) - ops.maximum(y_min1, y_min2), 0 + ) + union_area = ( + width_1 * height_1 + + width_2 * height_2 + - intersection_area + + keras.backend.epsilon() + ) + iou = ops.squeeze( + ops.divide(intersection_area, union_area + keras.backend.epsilon()), + axis=-1, + ) + + convex_width = ops.maximum(x_max1, x_max2) - ops.minimum(x_min1, x_min2) + convex_height = ops.maximum(y_max1, y_max2) - ops.minimum(y_min1, y_min2) + convex_diagonal_squared = ops.squeeze( + convex_width**2 + convex_height**2 + keras.backend.epsilon(), + axis=-1, + ) + centers_distance_squared = ops.squeeze( + ((x_min1 + x_max1) / 2 - (x_min2 + x_max2) / 2) ** 2 + + ((y_min1 + y_max1) / 2 - (y_min2 + y_max2) / 2) ** 2, + axis=-1, + ) + + v = ops.squeeze( + (4 / math.pi**2) + * ops.power( + (ops.arctan(width_2 / height_2) - ops.arctan(width_1 / height_1)), + 2, + ), + axis=-1, + ) + alpha = v / (v - iou + (1 + keras.backend.epsilon())) + + return iou - ( + centers_distance_squared / convex_diagonal_squared + v * alpha + ) diff --git a/keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/iou_test.py b/keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/iou_test.py new file mode 100644 index 000000000000..d66267f91ef5 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/iou_test.py @@ -0,0 +1,233 @@ +"""Tests for iou functions.""" + +import numpy as np + +from keras.src import testing +from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes import ( + iou as iou_lib, +) + + +class IoUTest(testing.TestCase): + def test_compute_single_iou(self): + bb1 = np.array([[100, 101, 200, 201]]) + bb1_off_by_1 = np.array([[101, 102, 201, 202]]) + # area of bb1 and bb1_off_by_1 are each 10000. + # intersection area is 99*99=9801 + # iou=9801/(2*10000 - 9801)=0.96097656633 + self.assertAllClose( + iou_lib.compute_iou(bb1, bb1_off_by_1, "yxyx")[0], [0.96097656633] + ) + + def test_compute_iou(self): + bb1 = [100, 101, 200, 201] + bb1_off_by_1_pred = [101, 102, 201, 202] + iou_bb1_bb1_off = 0.96097656633 + top_left_bounding_box = [0, 2, 1, 3] + far_away_box = [1300, 1400, 1500, 1401] + another_far_away_pred = [1000, 1400, 1200, 1401] + + # Rows represent predictions, columns ground truths + expected_result = np.array( + [[iou_bb1_bb1_off, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]], + dtype=np.float32, + ) + + sample_y_true = np.array([bb1, top_left_bounding_box, far_away_box]) + sample_y_pred = np.array( + [bb1_off_by_1_pred, top_left_bounding_box, another_far_away_pred], + ) + + result = iou_lib.compute_iou(sample_y_true, sample_y_pred, "yxyx") + self.assertAllClose(expected_result, result) + + def test_batched_compute_iou(self): + bb1 = [100, 101, 200, 201] + bb1_off_by_1_pred = [101, 102, 201, 202] + iou_bb1_bb1_off = 0.96097656633 + top_left_bounding_box = [0, 2, 1, 3] + far_away_box = [1300, 1400, 1500, 1401] + another_far_away_pred = [1000, 1400, 1200, 1401] + + # Rows represent predictions, columns ground truths + expected_result = np.array( + [ + [[iou_bb1_bb1_off, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]], + [[iou_bb1_bb1_off, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]], + ], + ) + + sample_y_true = np.array( + [ + [bb1, top_left_bounding_box, far_away_box], + [bb1, top_left_bounding_box, far_away_box], + ], + ) + sample_y_pred = np.array( + [ + [ + bb1_off_by_1_pred, + top_left_bounding_box, + another_far_away_pred, + ], + [ + bb1_off_by_1_pred, + top_left_bounding_box, + another_far_away_pred, + ], + ], + ) + + result = iou_lib.compute_iou(sample_y_true, sample_y_pred, "yxyx") + self.assertAllClose(expected_result, result) + + def test_batched_boxes1_unbatched_boxes2(self): + bb1 = [100, 101, 200, 201] + bb1_off_by_1_pred = [101, 102, 201, 202] + iou_bb1_bb1_off = 0.96097656633 + top_left_bounding_box = [0, 2, 1, 3] + far_away_box = [1300, 1400, 1500, 1401] + another_far_away_pred = [1000, 1400, 1200, 1401] + + # Rows represent predictions, columns ground truths + expected_result = np.array( + [ + [[iou_bb1_bb1_off, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]], + [[iou_bb1_bb1_off, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]], + ], + ) + + sample_y_true = np.array( + [ + [bb1, top_left_bounding_box, far_away_box], + [bb1, top_left_bounding_box, far_away_box], + ], + ) + sample_y_pred = np.array( + [bb1_off_by_1_pred, top_left_bounding_box, another_far_away_pred], + ) + + result = iou_lib.compute_iou(sample_y_true, sample_y_pred, "yxyx") + self.assertAllClose(expected_result, result) + + def test_unbatched_boxes1_batched_boxes2(self): + bb1 = [100, 101, 200, 201] + bb1_off_by_1_pred = [101, 102, 201, 202] + iou_bb1_bb1_off = 0.96097656633 + top_left_bounding_box = [0, 2, 1, 3] + far_away_box = [1300, 1400, 1500, 1401] + another_far_away_pred = [1000, 1400, 1200, 1401] + + # Rows represent predictions, columns ground truths + expected_result = np.array( + [ + [[iou_bb1_bb1_off, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]], + [[iou_bb1_bb1_off, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]], + ], + ) + + sample_y_true = np.array( + [ + [bb1, top_left_bounding_box, far_away_box], + ], + ) + sample_y_pred = np.array( + [ + [ + bb1_off_by_1_pred, + top_left_bounding_box, + another_far_away_pred, + ], + [ + bb1_off_by_1_pred, + top_left_bounding_box, + another_far_away_pred, + ], + ], + ) + + result = iou_lib.compute_iou(sample_y_true, sample_y_pred, "yxyx") + self.assertAllClose(expected_result, result) + + +class CIoUTest(testing.TestCase): + def test_compute_single_ciou(self): + bb1 = np.array([[100, 101, 200, 201]]) + bb2 = np.array([[101, 102, 201, 202]]) + self.assertAllClose( + iou_lib.compute_ciou(bb1, bb2, "yxyx")[0], [0.96087853672] + ) + + def test_compute_ciou(self): + bb1 = np.array([100, 101, 200, 201]) + bb2 = np.array([150, 150, 250, 250]) + ciou_bb1_bb2 = 0.036492417 + + # non overlapping case + far_away_bb1 = np.array([1000, 1000, 1500, 1500]) + far_away_bb2 = np.array([2000, 2000, 2500, 2500]) + ciou_far_away_bb1_bb2 = -0.44444444435 + + sample_y_true = np.array([bb1, far_away_bb1]) + sample_y_pred = np.array([bb2, far_away_bb2]) + + result = iou_lib.compute_ciou(sample_y_true, sample_y_pred, "yxyx") + self.assertAllClose(ciou_bb1_bb2, result[0]) + self.assertAllClose(ciou_far_away_bb1_bb2, result[1]) + + def test_batched_compute_ciou(self): + bb1 = np.array([100, 101, 200, 201]) + bb2 = np.array([150, 150, 250, 250]) + ciou_bb1_bb2 = 0.036492417 + + # non overlapping case + far_away_bb1 = np.array([1000, 1000, 1500, 1500]) + far_away_bb2 = np.array([2000, 2000, 2500, 2500]) + ciou_far_away_bb1_bb2 = -0.44444444435 + + sample_y_true = np.array([[bb1, far_away_bb1], [bb1, far_away_bb1]]) + sample_y_pred = np.array([[bb2, far_away_bb2], [bb2, far_away_bb2]]) + + result = iou_lib.compute_ciou(sample_y_true, sample_y_pred, "yxyx") + self.assertAllClose(ciou_bb1_bb2, result[0][0]) + self.assertAllClose(ciou_bb1_bb2, result[1][0]) + self.assertAllClose(ciou_far_away_bb1_bb2, result[0][1]) + self.assertAllClose(ciou_far_away_bb1_bb2, result[1][1]) + + def test_batched_boxes1_unbatched_boxes2(self): + bb1 = np.array([100, 101, 200, 201]) + bb2 = np.array([150, 150, 250, 250]) + ciou_bb1_bb2 = 0.036492417 + + # non overlapping case + far_away_bb1 = np.array([1000, 1000, 1500, 1500]) + far_away_bb2 = np.array([2000, 2000, 2500, 2500]) + ciou_far_away_bb1_bb2 = -0.44444444435 + + sample_y_true = np.array([[bb1, far_away_bb1], [bb1, far_away_bb1]]) + sample_y_pred = np.array([bb2, far_away_bb2]) + + result = iou_lib.compute_ciou(sample_y_true, sample_y_pred, "yxyx") + self.assertAllClose(ciou_bb1_bb2, result[0][0]) + self.assertAllClose(ciou_bb1_bb2, result[1][0]) + self.assertAllClose(ciou_far_away_bb1_bb2, result[0][1]) + self.assertAllClose(ciou_far_away_bb1_bb2, result[1][1]) + + def test_unbatched_boxes1_batched_boxes2(self): + bb1 = np.array([100, 101, 200, 201]) + bb2 = np.array([150, 150, 250, 250]) + ciou_bb1_bb2 = 0.036492417 + + # non overlapping case + far_away_bb1 = np.array([1000, 1000, 1500, 1500]) + far_away_bb2 = np.array([2000, 2000, 2500, 2500]) + ciou_far_away_bb1_bb2 = -0.44444444435 + + sample_y_true = np.array([bb1, far_away_bb1]) + sample_y_pred = np.array([[bb2, far_away_bb2], [bb2, far_away_bb2]]) + + result = iou_lib.compute_ciou(sample_y_true, sample_y_pred, "yxyx") + self.assertAllClose(ciou_bb1_bb2, result[0][0]) + self.assertAllClose(ciou_bb1_bb2, result[1][0]) + self.assertAllClose(ciou_far_away_bb1_bb2, result[0][1]) + self.assertAllClose(ciou_far_away_bb1_bb2, result[1][1]) diff --git a/keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation.py b/keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation.py new file mode 100644 index 000000000000..43aacde89785 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation.py @@ -0,0 +1,182 @@ +from keras.src import backend as current_backend +from keras.src.utils import tf_utils + + +def _classes_shape(batched, classes_shape, max_boxes): + if max_boxes is None: + return None + if batched: + return [None, max_boxes] + classes_shape[2:] + return [max_boxes] + classes_shape[1:] + + +def _box_shape(batched, boxes_shape, max_boxes): + # ensure we dont drop the final axis in RaggedTensor mode + if max_boxes is None: + shape = list(boxes_shape) + shape[-1] = 4 + return shape + if batched: + return [None, max_boxes, 4] + return [max_boxes, 4] + + +def densify_bounding_boxes( + bounding_boxes, + is_batched=False, + max_boxes=None, + boxes_default_value=0, + labels_default_value=-1, + backend=None, +): + validate_bounding_boxes(bounding_boxes) + boxes = bounding_boxes["boxes"] + labels = bounding_boxes["labels"] + backend = backend or current_backend + if isinstance(boxes, list): + if boxes and isinstance(boxes[0], list): + if boxes[0] and isinstance(boxes[0][0], list): + # Batched case + if not isinstance(labels[0][0], int): + raise ValueError( + "If providing `bounding_boxes['labels']` as a list, " + "it should contain integers labels. Received: " + f"bounding_boxes['labels']={labels}" + ) + if max_boxes is not None: + max_boxes = max([len(b) for b in boxes]) + new_boxes = [] + new_labels = [] + for b, l in zip(boxes, labels): + if len(b) >= max_boxes: + new_boxes.append(b[:max_boxes]) + new_labels.append(l[:max_boxes]) + else: + num_boxes_to_add = max_boxes - len(b) + added_boxes = [ + [ + boxes_default_value, + boxes_default_value, + boxes_default_value, + boxes_default_value, + ] + for _ in range(num_boxes_to_add) + ] + new_boxes.append(b + added_boxes) + new_labels.append( + l + + [ + labels_default_value + for _ in range(num_boxes_to_add) + ] + ) + else: + # Unbatched case + if max_boxes and len(b) >= max_boxes: + new_boxes = b[:max_boxes] + new_labels = l[:max_boxes] + else: + num_boxes_to_add = max_boxes - len(b) + added_boxes = [ + [ + boxes_default_value, + boxes_default_value, + boxes_default_value, + boxes_default_value, + ] + for _ in range(num_boxes_to_add) + ] + new_boxes = b + added_boxes + new_labels = l + [ + labels_default_value for _ in range(num_boxes_to_add) + ] + return { + "boxes": backend.convert_to_tensor(new_boxes, dtype="float32"), + "labels": backend.convert_to_tensor(new_labels, dtype="int32"), + } + + if tf_utils.is_ragged_tensor(boxes): + bounding_boxes["boxes"] = bounding_boxes["boxes"].to_tensor( + default_value=boxes_default_value, + shape=_box_shape( + is_batched, bounding_boxes["boxes"].shape, max_boxes + ), + ) + bounding_boxes["labels"] = bounding_boxes["labels"].to_tensor( + default_value=labels_default_value, + shape=_classes_shape( + is_batched, bounding_boxes["labels"].shape, max_boxes + ), + ) + return bounding_boxes + + bounding_boxes["boxes"] = backend.convert_to_tensor(boxes, dtype="float32") + bounding_boxes["labels"] = backend.convert_to_tensor(labels) + return bounding_boxes + + +def validate_bounding_boxes(bounding_boxes): + if ( + not isinstance(bounding_boxes, dict) + or "labels" not in bounding_boxes + or "boxes" not in bounding_boxes + ): + raise ValueError( + "Expected `bounding_boxes` agurment to be a " + "dict with keys 'boxes' and 'labels'. Received: " + f"bounding_boxes={bounding_boxes}" + ) + boxes = bounding_boxes["boxes"] + labels = bounding_boxes["labels"] + if isinstance(boxes, list): + if not isinstance(labels, list): + raise ValueError( + "If `bounding_boxes['boxes']` is a list, then " + "`bounding_boxes['labels']` must also be a list." + f"Received: bounding_boxes['labels']={labels}" + ) + if len(boxes) != len(labels): + raise ValueError( + "If `bounding_boxes['boxes']` and " + "`bounding_boxes['labels']` are both lists, " + "they must have the same length. Received: " + f"len(bounding_boxes['boxes'])={len(boxes)} and " + f"len(bounding_boxes['labels'])={len(labels)} and " + ) + elif tf_utils.is_ragged_tensor(boxes): + if not tf_utils.is_ragged_tensor(labels): + raise ValueError( + "If `bounding_boxes['boxes']` is a Ragged tensor, " + " `bounding_boxes['labels']` must also be a " + "Ragged tensor. " + f"Received: bounding_boxes['labels']={labels}" + ) + else: + boxes_shape = current_backend.shape(boxes) + labels_shape = current_backend.shape(labels) + if len(boxes_shape) == 2: # (boxes, 4) + if len(labels_shape) not in {1, 2}: + raise ValueError( + "Found " + f"bounding_boxes['boxes'].shape={boxes_shape} " + "and expected bounding_boxes['labels'] to have " + "rank 1 or 2, but received: " + f"bounding_boxes['labels'].shape={labels_shape} " + ) + elif len(boxes_shape) == 3: + if len(labels_shape) not in {2, 3}: + raise ValueError( + "Found " + f"bounding_boxes['boxes'].shape={boxes_shape} " + "and expected bounding_boxes['labels'] to have " + "rank 2 or 3, but received: " + f"bounding_boxes['labels'].shape={labels_shape} " + ) + else: + raise ValueError( + "Expected `bounding_boxes['boxes']` " + "to have rank 2 or 3, with shape " + "(num_boxes, 4) or (batch_size, num_boxes, 4). " + "Received: " + f"bounding_boxes['boxes'].shape={boxes_shape}" + ) diff --git a/keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation_test.py b/keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation_test.py new file mode 100644 index 000000000000..0a25a05df7d1 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation_test.py @@ -0,0 +1,75 @@ +import pytest +import tensorflow as tf + +from keras.src import backend +from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes import ( + validation, +) +from keras.src.testing import test_case + + +@pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="The test targets TensorFlow-specific ragged tensors.", +) +class DensifyBoundingBoxesTest(test_case.TestCase): + def test_densify_ragged_bounding_boxes_batched(self): + ragged_boxes = tf.ragged.constant( + [ + [[0.1, 0.1, 0.2, 0.2], [0.3, 0.3, 0.4, 0.4]], + [[0.5, 0.5, 0.6, 0.6]], + ], + dtype=tf.float32, + ) + ragged_labels = tf.ragged.constant( + [ + [0, 1], + [2], + ], + dtype=tf.int32, + ) + bounding_boxes = {"boxes": ragged_boxes, "labels": ragged_labels} + max_boxes = 3 + densified_data = validation.densify_bounding_boxes( + bounding_boxes.copy(), is_batched=True, max_boxes=max_boxes + ) + densified_boxes = densified_data["boxes"] + densified_labels = densified_data["labels"] + self.assertEqual(densified_boxes.shape, (2, max_boxes, 4)) + self.assertEqual(densified_labels.shape, (2, max_boxes)) + expected_boxes = [ + [[0.1, 0.1, 0.2, 0.2], [0.3, 0.3, 0.4, 0.4], [0.0, 0.0, 0.0, 0.0]], + [[0.5, 0.5, 0.6, 0.6], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], + ] + expected_labels = [ + [0, 1, -1], + [2, -1, -1], + ] + self.assertAllClose(densified_boxes, expected_boxes) + self.assertAllEqual(densified_labels, expected_labels) + + def test_densify_ragged_bounding_boxes_unbatched(self): + ragged_boxes = tf.ragged.constant( + [[0.1, 0.1, 0.2, 0.2], [0.3, 0.3, 0.4, 0.4]], + dtype=tf.float32, + ) + ragged_labels = tf.ragged.constant([[0], [1]], dtype=tf.int32) + bounding_boxes = {"boxes": ragged_boxes, "labels": ragged_labels} + max_boxes = 4 + densified_data = validation.densify_bounding_boxes( + bounding_boxes.copy(), is_batched=False, max_boxes=max_boxes + ) + densified_boxes = densified_data["boxes"] + densified_labels = densified_data["labels"] + + self.assertEqual(densified_boxes.shape, (max_boxes, 4)) + self.assertEqual(densified_labels.shape, (max_boxes, 1)) + expected_boxes = [ + [0.1, 0.1, 0.2, 0.2], + [0.3, 0.3, 0.4, 0.4], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ] + expected_labels = [[0], [1], [-1], [-1]] + self.assertAllClose(densified_boxes, expected_boxes) + self.assertAllEqual(densified_labels, expected_labels) diff --git a/keras/src/layers/preprocessing/image_preprocessing/center_crop.py b/keras/src/layers/preprocessing/image_preprocessing/center_crop.py new file mode 100644 index 000000000000..f32c3bddbb4d --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/center_crop.py @@ -0,0 +1,273 @@ +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 + BaseImagePreprocessingLayer, +) +from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import ( # noqa: E501 + clip_to_image_size, +) +from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import ( # noqa: E501 + convert_format, +) +from keras.src.utils import image_utils + + +@keras_export("keras.layers.CenterCrop") +class CenterCrop(BaseImagePreprocessingLayer): + """A preprocessing layer which crops images. + + This layers crops the central portion of the images to a target size. If an + image is smaller than the target size, it will be resized and cropped + so as to return the largest possible window in the image that matches + the target aspect ratio. + + Input pixel values can be of any range (e.g. `[0., 1.)` or `[0, 255]`). + + Input shape: + 3D (unbatched) or 4D (batched) tensor with shape: + `(..., height, width, channels)`, in `"channels_last"` format, + or `(..., channels, height, width)`, in `"channels_first"` format. + + Output shape: + 3D (unbatched) or 4D (batched) tensor with shape: + `(..., target_height, target_width, channels)`, + or `(..., channels, target_height, target_width)`, + in `"channels_first"` format. + + If the input height/width is even and the target height/width is odd (or + inversely), the input image is left-padded by 1 pixel. + + **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline + (independently of which backend you're using). + + Args: + height: Integer, the height of the output shape. + width: Integer, the width of the output shape. + data_format: string, either `"channels_last"` or `"channels_first"`. + The ordering of the dimensions in the inputs. `"channels_last"` + corresponds to inputs with shape `(batch, height, width, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch, channels, height, width)`. It defaults to the + `image_data_format` value found in your Keras config file at + `~/.keras/keras.json`. If you never set it, then it will be + `"channels_last"`. + """ + + _USE_BASE_FACTOR = False + + def __init__(self, height, width, data_format=None, **kwargs): + super().__init__(data_format=data_format, **kwargs) + self.height = height + self.width = width + + def get_random_transformation(self, data, training=True, seed=None): + if isinstance(data, dict): + images = data["images"] + else: + images = data + shape = self.backend.core.shape(images) + return {"input_shape": shape} + + def transform_labels(self, labels, transformation, training=True): + return labels + + def transform_bounding_boxes( + self, bounding_boxes, transformation, training=True + ): + def _get_height_width(input_shape): + if self.data_format == "channels_first": + input_height = input_shape[-2] + input_width = input_shape[-1] + else: + input_height = input_shape[-3] + input_width = input_shape[-2] + return input_height, input_width + + def _get_clipped_bbox(bounding_boxes, h_end, h_start, w_end, w_start): + bboxes = bounding_boxes["boxes"] + x1, y1, x2, y2 = self.backend.numpy.split(bboxes, 4, axis=-1) + x1 = self.backend.numpy.clip(x1, w_start, w_end) - w_start + y1 = self.backend.numpy.clip(y1, h_start, h_end) - h_start + x2 = self.backend.numpy.clip(x2, w_start, w_end) - w_start + y2 = self.backend.numpy.clip(y2, h_start, h_end) - h_start + bounding_boxes["boxes"] = self.backend.numpy.concatenate( + [x1, y1, x2, y2], axis=-1 + ) + return bounding_boxes + + input_shape = transformation["input_shape"] + + init_height, init_width = _get_height_width(input_shape) + + bounding_boxes = convert_format( + bounding_boxes, + source=self.bounding_box_format, + target="xyxy", + height=init_height, + width=init_width, + ) + + h_diff = init_height - self.height + w_diff = init_width - self.width + + if h_diff >= 0 and w_diff >= 0: + h_start = int(h_diff / 2) + w_start = int(w_diff / 2) + + h_end = h_start + self.height + w_end = w_start + self.width + + bounding_boxes = _get_clipped_bbox( + bounding_boxes, h_end, h_start, w_end, w_start + ) + else: + width = init_width + height = init_height + target_height = self.height + target_width = self.width + + crop_height = int(float(width * target_height) / target_width) + crop_height = max(min(height, crop_height), 1) + crop_width = int(float(height * target_width) / target_height) + crop_width = max(min(width, crop_width), 1) + crop_box_hstart = int(float(height - crop_height) / 2) + crop_box_wstart = int(float(width - crop_width) / 2) + + h_start = crop_box_hstart + w_start = crop_box_wstart + + h_end = crop_box_hstart + crop_height + w_end = crop_box_wstart + crop_width + bounding_boxes = _get_clipped_bbox( + bounding_boxes, h_end, h_start, w_end, w_start + ) + + bounding_boxes = convert_format( + bounding_boxes, + source="xyxy", + target="rel_xyxy", + height=crop_height, + width=crop_width, + ) + + bounding_boxes = convert_format( + bounding_boxes, + source="rel_xyxy", + target="xyxy", + height=self.height, + width=self.width, + ) + + bounding_boxes = clip_to_image_size( + bounding_boxes=bounding_boxes, + height=self.height, + width=self.width, + bounding_box_format="xyxy", + ) + + bounding_boxes = convert_format( + bounding_boxes, + source="xyxy", + target=self.bounding_box_format, + height=self.height, + width=self.width, + ) + + return bounding_boxes + + def transform_segmentation_masks( + self, segmentation_masks, transformation, training=True + ): + return self.transform_images( + segmentation_masks, transformation, training=training + ) + + def transform_images(self, images, transformation=None, training=True): + inputs = self.backend.cast(images, self.compute_dtype) + if self.data_format == "channels_first": + init_height = inputs.shape[-2] + init_width = inputs.shape[-1] + else: + init_height = inputs.shape[-3] + init_width = inputs.shape[-2] + + if init_height is None or init_width is None: + # Dynamic size case. TODO. + raise ValueError( + "At this time, CenterCrop can only " + "process images with a static spatial " + f"shape. Received: inputs.shape={inputs.shape}" + ) + + h_diff = init_height - self.height + w_diff = init_width - self.width + + h_start = int(h_diff / 2) + w_start = int(w_diff / 2) + + if h_diff >= 0 and w_diff >= 0: + if len(inputs.shape) == 4: + if self.data_format == "channels_first": + return inputs[ + :, + :, + h_start : h_start + self.height, + w_start : w_start + self.width, + ] + return inputs[ + :, + h_start : h_start + self.height, + w_start : w_start + self.width, + :, + ] + elif len(inputs.shape) == 3: + if self.data_format == "channels_first": + return inputs[ + :, + h_start : h_start + self.height, + w_start : w_start + self.width, + ] + return inputs[ + h_start : h_start + self.height, + w_start : w_start + self.width, + :, + ] + return image_utils.smart_resize( + inputs, + [self.height, self.width], + data_format=self.data_format, + backend_module=self.backend, + ) + + def compute_output_shape(self, input_shape): + input_shape = list(input_shape) + if isinstance(input_shape[0], (list, tuple)) or len( + input_shape + ) not in (3, 4): + raise ValueError( + "`input_shape` must be a non-nested tuple or list " + "of rank-1 with size 3 (unbatched) or 4 (batched). " + ) + if len(input_shape) == 4: + if self.data_format == "channels_last": + input_shape[1] = self.height + input_shape[2] = self.width + else: + input_shape[2] = self.height + input_shape[3] = self.width + else: + if self.data_format == "channels_last": + input_shape[0] = self.height + input_shape[1] = self.width + else: + input_shape[1] = self.height + input_shape[2] = self.width + return tuple(input_shape) + + def get_config(self): + base_config = super().get_config() + config = { + "height": self.height, + "width": self.width, + "data_format": self.data_format, + } + return {**base_config, **config} diff --git a/keras/src/layers/preprocessing/image_preprocessing/center_crop_test.py b/keras/src/layers/preprocessing/image_preprocessing/center_crop_test.py new file mode 100644 index 000000000000..82451fa35285 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/center_crop_test.py @@ -0,0 +1,296 @@ +import numpy as np +import pytest +from absl.testing import parameterized +from tensorflow import data as tf_data + +from keras.src import backend +from keras.src import layers +from keras.src import testing + + +class CenterCropTest(testing.TestCase): + def np_center_crop(self, img, h_new, w_new, data_format="channels_last"): + img = np.array(img) + if img.ndim == 4: + if data_format == "channels_last": + _, h, w = img.shape[:3] + else: + _, h, w = img.shape[1:] + else: + if data_format == "channels_last": + h, w = img.shape[:2] + else: + h, w = img.shape[1:] + h_start = (h - h_new) // 2 + w_start = (w - w_new) // 2 + if data_format == "channels_last": + return img[ + ..., h_start : h_start + h_new, w_start : w_start + w_new, : + ] + else: + return img[ + ..., h_start : h_start + h_new, w_start : w_start + w_new + ] + + @pytest.mark.requires_trainable_backend + def test_center_crop_basics(self): + self.run_layer_test( + layers.CenterCrop, + init_kwargs={ + "height": 6, + "width": 6, + "data_format": "channels_last", + }, + input_shape=(2, 12, 12, 3), + expected_output_shape=(2, 6, 6, 3), + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=False, + ) + self.run_layer_test( + layers.CenterCrop, + init_kwargs={ + "height": 7, + "width": 7, + "data_format": "channels_first", + }, + input_shape=(2, 3, 13, 13), + expected_output_shape=(2, 3, 7, 7), + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=False, + ) + + @parameterized.parameters( + [ + ((5, 7), "channels_first"), + ((5, 7), "channels_last"), + ((4, 9), "channels_first"), + ((9, 4), "channels_last"), + ] + ) + def test_center_crop_correctness(self, size, data_format): + # batched case + if data_format == "channels_first": + img = np.random.random((2, 3, 9, 11)) + else: + img = np.random.random((2, 9, 11, 3)) + out = layers.CenterCrop( + size[0], + size[1], + data_format=data_format, + )(img) + if data_format == "channels_first": + img_transpose = np.transpose(img, (0, 2, 3, 1)) + + ref_out = np.transpose( + self.np_center_crop(img_transpose, size[0], size[1]), + (0, 3, 1, 2), + ) + else: + ref_out = self.np_center_crop(img, size[0], size[1]) + self.assertAllClose(ref_out, out) + + # unbatched case + if data_format == "channels_first": + img = np.random.random((3, 9, 11)) + else: + img = np.random.random((9, 11, 3)) + out = layers.CenterCrop( + size[0], + size[1], + data_format=data_format, + )(img) + if data_format == "channels_first": + img_transpose = np.transpose(img, (1, 2, 0)) + ref_out = np.transpose( + self.np_center_crop( + img_transpose, + size[0], + size[1], + ), + (2, 0, 1), + ) + else: + ref_out = self.np_center_crop( + img, + size[0], + size[1], + ) + self.assertAllClose(ref_out, out) + + @parameterized.parameters( + [ + ((15, 10), "channels_first"), + ((10, 17), "channels_last"), + ] + ) + def test_input_smaller_than_crop_box(self, size, data_format): + """Output should equal resizing with crop_to_aspect ratio.""" + # batched case + if data_format == "channels_first": + img = np.random.random((2, 3, 9, 11)) + else: + img = np.random.random((2, 9, 11, 3)) + out = layers.CenterCrop( + size[0], + size[1], + data_format=data_format, + )(img) + ref_out = layers.Resizing( + size[0], size[1], data_format=data_format, crop_to_aspect_ratio=True + )(img) + self.assertAllClose(ref_out, out) + + # unbatched case + if data_format == "channels_first": + img = np.random.random((3, 9, 11)) + else: + img = np.random.random((9, 11, 3)) + out = layers.CenterCrop( + size[0], + size[1], + data_format=data_format, + )(img) + ref_out = layers.Resizing( + size[0], size[1], data_format=data_format, crop_to_aspect_ratio=True + )(img) + self.assertAllClose(ref_out, out) + + def test_tf_data_compatibility(self): + if backend.config.image_data_format() == "channels_last": + input_shape = (2, 10, 12, 3) + output_shape = (2, 8, 9, 3) + else: + input_shape = (2, 3, 10, 12) + output_shape = (2, 3, 8, 9) + layer = layers.CenterCrop(8, 9) + input_data = np.random.random(input_shape) + ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) + output = next(iter(ds)).numpy() + self.assertEqual(tuple(output.shape), output_shape) + + # TODO + # def test_list_compatibility(self): + # if backend.config.image_data_format() == "channels_last": + # images = [ + # np.random.rand(10, 10, 3), + # np.random.rand(10, 10, 3), + # ] + # output_shape = (2, 6, 5, 3) + # else: + # images = [ + # np.random.rand(3, 10, 10), + # np.random.rand(3, 10, 10), + # ] + # output_shape = (2, 3, 6, 5) + # output = layers.CenterCrop(height=6, width=5)(images) + # ref_output = self.np_center_crop( + # images, 6, 5, data_format=backend.config.image_data_format() + # ) + # self.assertEqual(tuple(output.shape), output_shape) + # self.assertAllClose(ref_output, output) + + @parameterized.parameters( + [((5, 17), "channels_last"), ((5, 100), "channels_last")] + ) + def test_image_stretch(self, size, data_format): + img = np.random.rand(2, 11, 3, 9) + out = layers.CenterCrop( + size[0], + size[1], + data_format=data_format, + )(img) + ref_out = layers.Resizing( + size[0], size[1], data_format=data_format, crop_to_aspect_ratio=True + )(img) + self.assertAllClose(ref_out, out) + + @parameterized.named_parameters( + ( + "normal", + 5, + 5, + [[1.0, 0.0, 3.0, 1.0], [5.0, 2.0, 5.0, 4.0]], + ), + ( + "with_stretch", + 20, + 20, + [[5.0, 0.0, 10.0, 5.0], [15.0, 7.5, 20.0, 12.5]], + ), + ) + def test_center_crop_bounding_boxes(self, height, width, expected_boxes): + if backend.config.image_data_format() == "channels_last": + image_shape = (10, 8, 3) + else: + image_shape = (3, 10, 8) + input_image = np.random.random(image_shape) + bounding_boxes = { + "boxes": np.array( + [ + [2, 1, 4, 3], + [6, 4, 8, 6], + ] + ), + "labels": np.array([[1, 2]]), + } + input_data = {"images": input_image, "bounding_boxes": bounding_boxes} + center_crop_layer = layers.CenterCrop( + height=height, + width=width, + bounding_box_format="xyxy", + ) + output = center_crop_layer(input_data) + self.assertAllClose(output["bounding_boxes"]["boxes"], expected_boxes) + + @parameterized.named_parameters( + ( + "normal", + 5, + 5, + [[1.0, 0.0, 3.0, 1.0], [5.0, 2.0, 5.0, 4.0]], + ), + ( + "with_stretch", + 20, + 20, + [[5.0, 0.0, 10.0, 5.0], [15.0, 7.5, 20.0, 12.5]], + ), + ) + def test_center_crop_tf_data_bounding_boxes( + self, height, width, expected_boxes + ): + if backend.config.image_data_format() == "channels_last": + image_shape = (1, 10, 8, 3) + else: + image_shape = (1, 3, 10, 8) + input_image = np.random.random(image_shape) + bounding_boxes = { + "boxes": np.array( + [ + [ + [2, 1, 4, 3], + [6, 4, 8, 6], + ] + ] + ), + "labels": np.array([[1, 2]]), + } + + input_data = {"images": input_image, "bounding_boxes": bounding_boxes} + + ds = tf_data.Dataset.from_tensor_slices(input_data) + center_crop_layer = layers.CenterCrop( + height=height, + width=width, + bounding_box_format="xyxy", + ) + ds = ds.map(center_crop_layer) + output = next(iter(ds)) + expected_boxes = np.array(expected_boxes) + self.assertAllClose(output["bounding_boxes"]["boxes"], expected_boxes) diff --git a/keras/src/layers/preprocessing/image_preprocessing/cut_mix.py b/keras/src/layers/preprocessing/image_preprocessing/cut_mix.py new file mode 100644 index 000000000000..a1d07320af4d --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/cut_mix.py @@ -0,0 +1,229 @@ +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 + BaseImagePreprocessingLayer, +) +from keras.src.random import SeedGenerator + + +@keras_export("keras.layers.CutMix") +class CutMix(BaseImagePreprocessingLayer): + """CutMix data augmentation technique. + + CutMix is a data augmentation method where patches are cut and pasted + between two images in the dataset, while the labels are also mixed + proportionally to the area of the patches. + + **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline + (independently of which backend you're using). + + References: + - [CutMix paper]( https://arxiv.org/abs/1905.04899). + + Args: + factor: A single float or a tuple of two floats between 0 and 1. + If a tuple of numbers is passed, a `factor` is sampled + between the two values. + If a single float is passed, a value between 0 and the passed + float is sampled. These values define the range from which the + mixing weight is sampled. A higher factor increases the variability + in patch sizes, leading to more diverse and larger mixed patches. + Defaults to 1. + seed: Integer. Used to create a random seed. + """ + + _USE_BASE_FACTOR = False + _FACTOR_BOUNDS = (0, 1) + + def __init__(self, factor=1.0, seed=None, data_format=None, **kwargs): + super().__init__(data_format=data_format, **kwargs) + self._set_factor(factor) + self.seed = seed + self.generator = SeedGenerator(seed) + + if self.data_format == "channels_first": + self.height_axis = -2 + self.width_axis = -1 + self.channel_axis = -3 + else: + self.height_axis = -3 + self.width_axis = -2 + self.channel_axis = -1 + + def get_random_transformation(self, data, training=True, seed=None): + if not training: + return None + + if isinstance(data, dict): + images = data["images"] + else: + images = data + + images_shape = self.backend.shape(images) + if len(images_shape) == 3: + return None + + batch_size = images_shape[0] + image_height = images_shape[self.height_axis] + image_width = images_shape[self.width_axis] + + seed = seed or self._get_seed_generator(self.backend._backend) + + mix_weight = self._generate_mix_weight(batch_size, seed) + ratio = self.backend.numpy.sqrt(1.0 - mix_weight) + + x0, x1 = self._compute_crop_bounds(batch_size, image_width, ratio, seed) + y0, y1 = self._compute_crop_bounds( + batch_size, image_height, ratio, seed + ) + + batch_masks, mix_weight = self._generate_batch_mask( + images_shape, + (x0, x1, y0, y1), + ) + + permutation_order = self.backend.random.shuffle( + self.backend.numpy.arange(0, batch_size, dtype="int32"), + seed=seed, + ) + + return { + "permutation_order": permutation_order, + "batch_masks": batch_masks, + "mix_weight": mix_weight, + } + + def _generate_batch_mask(self, images_shape, box_corners): + def _generate_grid_xy(image_height, image_width): + grid_y, grid_x = self.backend.numpy.meshgrid( + self.backend.numpy.arange( + image_height, dtype=self.compute_dtype + ), + self.backend.numpy.arange( + image_width, dtype=self.compute_dtype + ), + indexing="ij", + ) + if self.data_format == "channels_last": + grid_y = self.backend.cast( + grid_y[None, :, :, None], dtype=self.compute_dtype + ) + grid_x = self.backend.cast( + grid_x[None, :, :, None], dtype=self.compute_dtype + ) + else: + grid_y = self.backend.cast( + grid_y[None, None, :, :], dtype=self.compute_dtype + ) + grid_x = self.backend.cast( + grid_x[None, None, :, :], dtype=self.compute_dtype + ) + return grid_x, grid_y + + image_height, image_width = ( + images_shape[self.height_axis], + images_shape[self.width_axis], + ) + grid_x, grid_y = _generate_grid_xy(image_height, image_width) + + x0, x1, y0, y1 = box_corners + + x0 = x0[:, None, None, None] + y0 = y0[:, None, None, None] + x1 = x1[:, None, None, None] + y1 = y1[:, None, None, None] + + batch_masks = ( + (grid_x >= x0) & (grid_x < x1) & (grid_y >= y0) & (grid_y < y1) + ) + batch_masks = self.backend.numpy.repeat( + batch_masks, images_shape[self.channel_axis], axis=self.channel_axis + ) + mix_weight = 1.0 - (x1 - x0) * (y1 - y0) / (image_width * image_height) + return batch_masks, mix_weight + + def _compute_crop_bounds(self, batch_size, image_length, crop_ratio, seed): + crop_length = self.backend.cast( + crop_ratio * image_length, dtype=self.compute_dtype + ) + + start_pos = self.backend.random.uniform( + shape=[batch_size], + minval=0, + maxval=1, + dtype=self.compute_dtype, + seed=seed, + ) * (image_length - crop_length) + + end_pos = start_pos + crop_length + + return start_pos, end_pos + + def _generate_mix_weight(self, batch_size, seed): + alpha = ( + self.backend.random.uniform( + shape=(), + minval=self.factor[0], + maxval=self.factor[1], + dtype=self.compute_dtype, + seed=seed, + ) + + 1e-6 + ) + mix_weight = self.backend.random.beta( + (batch_size,), alpha, alpha, seed=seed, dtype=self.compute_dtype + ) + return mix_weight + + def transform_images(self, images, transformation=None, training=True): + if training and transformation is not None: + images = self.backend.cast(images, self.compute_dtype) + + permutation_order = transformation["permutation_order"] + batch_masks = transformation["batch_masks"] + + images = self.backend.numpy.where( + batch_masks, + self.backend.numpy.take(images, permutation_order, axis=0), + images, + ) + images = self.backend.cast(images, self.compute_dtype) + return images + + def transform_labels(self, labels, transformation, training=True): + if training and transformation is not None: + permutation_order = transformation["permutation_order"] + mix_weight = transformation["mix_weight"] + + cutout_labels = self.backend.numpy.take( + labels, permutation_order, axis=0 + ) + mix_weight = self.backend.numpy.reshape(mix_weight, [-1, 1]) + labels = mix_weight * labels + (1.0 - mix_weight) * cutout_labels + + return labels + + def transform_bounding_boxes( + self, + bounding_boxes, + transformation, + training=True, + ): + raise NotImplementedError() + + def transform_segmentation_masks( + self, segmentation_masks, transformation, training=True + ): + return self.transform_images( + segmentation_masks, transformation, training + ) + + def compute_output_shape(self, input_shape): + return input_shape + + def get_config(self): + config = { + "factor": self.factor, + "seed": self.seed, + } + base_config = super().get_config() + return {**base_config, **config} diff --git a/keras/src/layers/preprocessing/image_preprocessing/cut_mix_test.py b/keras/src/layers/preprocessing/image_preprocessing/cut_mix_test.py new file mode 100644 index 000000000000..61f09b2a3d80 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/cut_mix_test.py @@ -0,0 +1,85 @@ +import numpy as np +import pytest +from tensorflow import data as tf_data + +from keras.src import backend +from keras.src import layers +from keras.src import testing + + +class CutMixTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_layer(self): + self.run_layer_test( + layers.CutMix, + init_kwargs={ + "factor": 1.0, + "seed": 1, + }, + input_shape=(8, 3, 4, 3), + supports_masking=False, + expected_output_shape=(8, 3, 4, 3), + # StatelessRandomGammaV3 is not supported on XLA_GPU_JIT + run_training_check=not testing.tensorflow_uses_gpu(), + ) + + def test_cut_mix_inference(self): + seed = 3481 + layer = layers.CutMix() + + np.random.seed(seed) + inputs = np.random.randint(0, 255, size=(224, 224, 3)) + output = layer(inputs, training=False) + self.assertAllClose(inputs, output) + + def test_cut_mix_basic(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + image1 = np.ones((2, 2, 1)) + image2 = np.zeros((2, 2, 1)) + inputs = np.asarray([image1, image2]) + expected_output = np.array( + [ + [[[1.0], [1.0]], [[1.0], [1.0]]], + [[[0.0], [0.0]], [[0.0], [0.0]]], + ] + ) + else: + image1 = np.ones((1, 2, 2)) + image2 = np.zeros((1, 2, 2)) + inputs = np.asarray([image1, image2]) + expected_output = np.asarray( + [ + [[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]], + [[[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]], + ] + ) + + layer = layers.CutMix(data_format=data_format) + + transformation = { + "batch_masks": np.asarray( + [ + [[[False], [True]], [[False], [False]]], + [[[False], [False]], [[True], [False]]], + ] + ), + "mix_weight": np.asarray([[[[0.7826548]]], [[[0.8133545]]]]), + "permutation_order": np.asarray([0, 1]), + } + + output = layer.transform_images(inputs, transformation) + + self.assertAllClose(expected_output, output) + + def test_tf_data_compatibility(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + input_data = np.random.random((2, 8, 8, 3)) + else: + input_data = np.random.random((2, 3, 8, 8)) + layer = layers.CutMix(data_format=data_format) + + ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) + for output in ds.take(1): + output.numpy() diff --git a/keras/src/layers/preprocessing/image_preprocessing/equalization.py b/keras/src/layers/preprocessing/image_preprocessing/equalization.py new file mode 100644 index 000000000000..4116419cee93 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/equalization.py @@ -0,0 +1,224 @@ +from keras.src import backend +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 + BaseImagePreprocessingLayer, +) + + +@keras_export("keras.layers.Equalization") +class Equalization(BaseImagePreprocessingLayer): + """Preprocessing layer for histogram equalization on image channels. + + Histogram equalization is a technique to adjust image intensities to + enhance contrast by effectively spreading out the most frequent + intensity values. This layer applies equalization on a channel-wise + basis, which can improve the visibility of details in images. + + This layer works with both grayscale and color images, performing + equalization independently on each color channel. At inference time, + the equalization is consistently applied. + + **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline + (independently of which backend you're using). + + Args: + value_range: Optional list/tuple of 2 floats specifying the lower + and upper limits of the input data values. Defaults to `[0, 255]`. + If the input image has been scaled, use the appropriate range + (e.g., `[0.0, 1.0]`). The equalization will be scaled to this + range, and output values will be clipped accordingly. + bins: Integer specifying the number of histogram bins to use for + equalization. Defaults to 256, which is suitable for 8-bit images. + Larger values can provide more granular intensity redistribution. + + Input shape: + 3D (unbatched) or 4D (batched) tensor with shape: + `(..., height, width, channels)`, in `"channels_last"` format, + or `(..., channels, height, width)`, in `"channels_first"` format. + + Output shape: + 3D (unbatched) or 4D (batched) tensor with shape: + `(..., target_height, target_width, channels)`, + or `(..., channels, target_height, target_width)`, + in `"channels_first"` format. + + Example: + + ```python + # Create an equalization layer for standard 8-bit images + equalizer = keras.layers.Equalization() + + # An image with uneven intensity distribution + image = [...] # your input image + + # Apply histogram equalization + equalized_image = equalizer(image) + + # For images with custom value range + custom_equalizer = keras.layers.Equalization( + value_range=[0.0, 1.0], # for normalized images + bins=128 # fewer bins for more subtle equalization + ) + custom_equalized = custom_equalizer(normalized_image) + ``` + """ + + def __init__( + self, value_range=(0, 255), bins=256, data_format=None, **kwargs + ): + super().__init__(**kwargs) + self.bins = bins + self._set_value_range(value_range) + self.data_format = backend.standardize_data_format(data_format) + + def _set_value_range(self, value_range): + if not isinstance(value_range, (tuple, list)): + raise ValueError( + self._VALUE_RANGE_VALIDATION_ERROR + + f"Received: value_range={value_range}" + ) + if len(value_range) != 2: + raise ValueError( + self._VALUE_RANGE_VALIDATION_ERROR + + f"Received: value_range={value_range}" + ) + self.value_range = sorted(value_range) + + def _custom_histogram_fixed_width(self, values, value_range, nbins): + values = self.backend.cast(values, "float32") + value_min, value_max = value_range + value_min = self.backend.cast(value_min, "float32") + value_max = self.backend.cast(value_max, "float32") + + scaled = (values - value_min) * (nbins - 1) / (value_max - value_min) + indices = self.backend.cast(scaled, "int32") + indices = self.backend.numpy.clip(indices, 0, nbins - 1) + flat_indices = self.backend.numpy.reshape(indices, [-1]) + + if backend.backend() == "jax": + # for JAX bincount is never jittable because of output shape + histogram = self.backend.numpy.zeros(nbins, dtype="int32") + for i in range(nbins): + matches = self.backend.cast( + self.backend.numpy.equal(flat_indices, i), "int32" + ) + bin_count = self.backend.numpy.sum(matches) + one_hot = self.backend.cast( + self.backend.numpy.arange(nbins) == i, "int32" + ) + histogram = histogram + (bin_count * one_hot) + return histogram + else: + # TensorFlow/PyTorch/NumPy implementation using bincount + return self.backend.numpy.bincount( + flat_indices, + minlength=nbins, + ) + + def _scale_values(self, values, source_range, target_range): + source_min, source_max = source_range + target_min, target_max = target_range + scale = (target_max - target_min) / (source_max - source_min) + offset = target_min - source_min * scale + return values * scale + offset + + def _equalize_channel(self, channel, value_range): + if value_range != (0, 255): + channel = self._scale_values(channel, value_range, (0, 255)) + + hist = self._custom_histogram_fixed_width( + channel, value_range=(0, 255), nbins=self.bins + ) + + nonzero_bins = self.backend.numpy.count_nonzero(hist) + equalized = self.backend.numpy.where( + nonzero_bins <= 1, channel, self._apply_equalization(channel, hist) + ) + + if value_range != (0, 255): + equalized = self._scale_values(equalized, (0, 255), value_range) + + return equalized + + def _apply_equalization(self, channel, hist): + cdf = self.backend.numpy.cumsum(hist) + + if self.backend.name == "jax": + mask = cdf > 0 + first_nonzero_idx = self.backend.numpy.argmax(mask) + cdf_min = self.backend.numpy.take(cdf, first_nonzero_idx) + else: + cdf_min = self.backend.numpy.take( + cdf, self.backend.numpy.nonzero(cdf)[0][0] + ) + + denominator = cdf[-1] - cdf_min + denominator = self.backend.numpy.where( + denominator == 0, + self.backend.numpy.ones_like(1, dtype=denominator.dtype), + denominator, + ) + + lookup_table = ((cdf - cdf_min) * 255) / denominator + lookup_table = self.backend.numpy.clip( + self.backend.numpy.round(lookup_table), 0, 255 + ) + + scaled_channel = (channel / 255.0) * (self.bins - 1) + indices = self.backend.cast( + self.backend.numpy.clip(scaled_channel, 0, self.bins - 1), "int32" + ) + return self.backend.numpy.take(lookup_table, indices) + + def transform_images(self, images, transformation, training=True): + if training: + images = self.backend.cast(images, self.compute_dtype) + + if self.data_format == "channels_first": + channels = [] + for i in range(self.backend.core.shape(images)[-3]): + channel = images[..., i, :, :] + equalized = self._equalize_channel( + channel, self.value_range + ) + channels.append(equalized) + equalized_images = self.backend.numpy.stack(channels, axis=-3) + else: + channels = [] + for i in range(self.backend.core.shape(images)[-1]): + channel = images[..., i] + equalized = self._equalize_channel( + channel, self.value_range + ) + channels.append(equalized) + equalized_images = self.backend.numpy.stack(channels, axis=-1) + + return self.backend.cast(equalized_images, self.compute_dtype) + return images + + def compute_output_shape(self, input_shape): + return input_shape + + def compute_output_spec(self, inputs, **kwargs): + return inputs + + def transform_bounding_boxes( + self, + bounding_boxes, + transformation, + training=True, + ): + return bounding_boxes + + def transform_labels(self, labels, transformation, training=True): + return labels + + def transform_segmentation_masks( + self, segmentation_masks, transformation, training=True + ): + return segmentation_masks + + def get_config(self): + config = super().get_config() + config.update({"bins": self.bins, "value_range": self.value_range}) + return config diff --git a/keras/src/layers/preprocessing/image_preprocessing/equalization_test.py b/keras/src/layers/preprocessing/image_preprocessing/equalization_test.py new file mode 100644 index 000000000000..5c669ea2f13b --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/equalization_test.py @@ -0,0 +1,135 @@ +import numpy as np +import pytest +from absl.testing import parameterized +from tensorflow import data as tf_data + +from keras.src import layers +from keras.src import ops +from keras.src import testing + + +class EqualizationTest(testing.TestCase): + def assertAllInRange(self, array, min_val, max_val): + self.assertTrue(np.all(array >= min_val)) + self.assertTrue(np.all(array <= max_val)) + + @pytest.mark.requires_trainable_backend + def test_layer(self): + self.run_layer_test( + layers.Equalization, + init_kwargs={ + "value_range": (0, 255), + "data_format": "channels_last", + }, + input_shape=(1, 2, 2, 3), + supports_masking=False, + expected_output_shape=(1, 2, 2, 3), + ) + + self.run_layer_test( + layers.Equalization, + init_kwargs={ + "value_range": (0, 255), + "data_format": "channels_first", + }, + input_shape=(1, 3, 2, 2), + supports_masking=False, + expected_output_shape=(1, 3, 2, 2), + ) + + def test_equalizes_to_all_bins(self): + xs = np.random.uniform(size=(2, 512, 512, 3), low=0, high=255).astype( + np.float32 + ) + layer = layers.Equalization(value_range=(0, 255)) + xs = layer(xs) + + for i in range(0, 256): + self.assertTrue(np.any(ops.convert_to_numpy(xs) == i)) + + @parameterized.named_parameters( + ("float32", np.float32), ("int32", np.int32), ("int64", np.int64) + ) + def test_input_dtypes(self, dtype): + xs = np.random.uniform(size=(2, 512, 512, 3), low=0, high=255).astype( + dtype + ) + layer = layers.Equalization(value_range=(0, 255)) + xs = ops.convert_to_numpy(layer(xs)) + + for i in range(0, 256): + self.assertTrue(np.any(xs == i)) + self.assertAllInRange(xs, 0, 255) + + @parameterized.named_parameters(("0_255", 0, 255), ("0_1", 0, 1)) + def test_output_range(self, lower, upper): + xs = np.random.uniform( + size=(2, 512, 512, 3), low=lower, high=upper + ).astype(np.float32) + layer = layers.Equalization(value_range=(lower, upper)) + xs = ops.convert_to_numpy(layer(xs)) + self.assertAllInRange(xs, lower, upper) + + def test_constant_regions(self): + xs = np.zeros((1, 64, 64, 3), dtype=np.float32) + xs[:, :21, :, :] = 50 + xs[:, 21:42, :, :] = 100 + xs[:, 42:, :, :] = 200 + + layer = layers.Equalization(value_range=(0, 255)) + equalized = ops.convert_to_numpy(layer(xs)) + + self.assertTrue(len(np.unique(equalized)) >= 3) + self.assertAllInRange(equalized, 0, 255) + + def test_grayscale_images(self): + xs_last = np.random.uniform(0, 255, size=(2, 64, 64, 1)).astype( + np.float32 + ) + layer_last = layers.Equalization( + value_range=(0, 255), data_format="channels_last" + ) + equalized_last = ops.convert_to_numpy(layer_last(xs_last)) + self.assertEqual(equalized_last.shape[-1], 1) + self.assertAllInRange(equalized_last, 0, 255) + + xs_first = np.random.uniform(0, 255, size=(2, 1, 64, 64)).astype( + np.float32 + ) + layer_first = layers.Equalization( + value_range=(0, 255), data_format="channels_first" + ) + equalized_first = ops.convert_to_numpy(layer_first(xs_first)) + self.assertEqual(equalized_first.shape[1], 1) + self.assertAllInRange(equalized_first, 0, 255) + + def test_single_color_image(self): + xs_last = np.full((1, 64, 64, 3), 128, dtype=np.float32) + layer_last = layers.Equalization( + value_range=(0, 255), data_format="channels_last" + ) + equalized_last = ops.convert_to_numpy(layer_last(xs_last)) + self.assertAllClose(equalized_last, 128.0) + + xs_first = np.full((1, 3, 64, 64), 128, dtype=np.float32) + layer_first = layers.Equalization( + value_range=(0, 255), data_format="channels_first" + ) + equalized_first = ops.convert_to_numpy(layer_first(xs_first)) + self.assertAllClose(equalized_first, 128.0) + + def test_different_bin_sizes(self): + xs = np.random.uniform(0, 255, size=(1, 64, 64, 3)).astype(np.float32) + bin_sizes = [16, 64, 128, 256] + for bins in bin_sizes: + layer = layers.Equalization(value_range=(0, 255), bins=bins) + equalized = ops.convert_to_numpy(layer(xs)) + self.assertAllInRange(equalized, 0, 255) + + def test_tf_data_compatibility(self): + layer = layers.Equalization(value_range=(0, 255)) + input_data = np.random.random((2, 8, 8, 3)) * 255 + ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) + for output in ds.take(1): + output_array = output.numpy() + self.assertAllInRange(output_array, 0, 255) diff --git a/keras/src/layers/preprocessing/image_preprocessing/max_num_bounding_box.py b/keras/src/layers/preprocessing/image_preprocessing/max_num_bounding_box.py new file mode 100644 index 000000000000..f7ef37fd66a0 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/max_num_bounding_box.py @@ -0,0 +1,92 @@ +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 + BaseImagePreprocessingLayer, +) + + +@keras_export("keras.layers.MaxNumBoundingBoxes") +class MaxNumBoundingBoxes(BaseImagePreprocessingLayer): + """Ensure the maximum number of bounding boxes. + + **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline + (independently of which backend you're using). + + Args: + max_number: Desired output number of bounding boxes. + padding_value: The padding value of the `boxes` and `labels` in + `bounding_boxes`. Defaults to `-1`. + """ + + def __init__(self, max_number, fill_value=-1, **kwargs): + super().__init__(**kwargs) + self.max_number = int(max_number) + self.fill_value = int(fill_value) + + def transform_images(self, images, transformation=None, training=True): + return images + + def transform_labels(self, labels, transformation=None, training=True): + return labels + + def transform_bounding_boxes( + self, bounding_boxes, transformation, training=True + ): + ops = self.backend + boxes = bounding_boxes["boxes"] + labels = bounding_boxes["labels"] + boxes_shape = ops.shape(boxes) + batch_size = boxes_shape[0] + num_boxes = boxes_shape[1] + + # Get pad size + pad_size = ops.numpy.maximum( + ops.numpy.subtract(self.max_number, num_boxes), 0 + ) + boxes = boxes[:, : self.max_number, ...] + boxes = ops.numpy.pad( + boxes, + [[0, 0], [0, pad_size], [0, 0]], + constant_values=self.fill_value, + ) + labels = labels[:, : self.max_number] + labels = ops.numpy.pad( + labels, [[0, 0], [0, pad_size]], constant_values=self.fill_value + ) + + # Ensure shape + boxes = ops.numpy.reshape(boxes, [batch_size, self.max_number, 4]) + labels = ops.numpy.reshape(labels, [batch_size, self.max_number]) + + bounding_boxes = bounding_boxes.copy() + bounding_boxes["boxes"] = boxes + bounding_boxes["labels"] = labels + return bounding_boxes + + def transform_segmentation_masks( + self, segmentation_masks, transformation=None, training=True + ): + return self.transform_images(segmentation_masks) + + def compute_output_shape(self, input_shape): + if isinstance(input_shape, dict) and "bounding_boxes" in input_shape: + input_keys = set(input_shape["bounding_boxes"].keys()) + extra_keys = input_keys - set(("boxes", "labels")) + if extra_keys: + raise KeyError( + "There are unsupported keys in `bounding_boxes`: " + f"{list(extra_keys)}. " + "Only `boxes` and `labels` are supported." + ) + + boxes_shape = list(input_shape["bounding_boxes"]["boxes"]) + boxes_shape[1] = self.max_number + labels_shape = list(input_shape["bounding_boxes"]["labels"]) + labels_shape[1] = self.max_number + input_shape["bounding_boxes"]["boxes"] = boxes_shape + input_shape["bounding_boxes"]["labels"] = labels_shape + return input_shape + + def get_config(self): + config = super().get_config() + config.update({"max_number": self.max_number}) + return config diff --git a/keras/src/layers/preprocessing/image_preprocessing/max_num_bounding_box_test.py b/keras/src/layers/preprocessing/image_preprocessing/max_num_bounding_box_test.py new file mode 100644 index 000000000000..efc8037aecea --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/max_num_bounding_box_test.py @@ -0,0 +1,77 @@ +import numpy as np +from tensorflow import data as tf_data + +from keras.src import backend +from keras.src import layers +from keras.src import testing + + +class MaxNumBoundingBoxesTest(testing.TestCase): + def test_max_num_bounding_boxes_basics(self): + self.run_layer_test( + layers.MaxNumBoundingBoxes, + init_kwargs={ + "max_number": 40, + "fill_value": -1, + }, + input_shape=(12, 12, 3), + expected_output_shape=(12, 12, 3), + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=False, + run_training_check=False, + ) + + def test_output_shapes(self): + if backend.config.image_data_format() == "channels_last": + image_shape = (10, 8, 3) + else: + image_shape = (3, 10, 8) + input_image = np.random.random(image_shape) + bounding_boxes = { + "boxes": np.array( + [ + [2, 1, 4, 3], + [6, 4, 8, 6], + ] + ), # Example boxes (normalized) + "labels": np.array([1, 2]), # Dummy labels + } + layer = layers.MaxNumBoundingBoxes( + max_number=40, bounding_box_format="xyxy" + ) + + input_data = {"images": input_image, "bounding_boxes": bounding_boxes} + output = layer(input_data) + self.assertAllEqual(output["bounding_boxes"]["boxes"].shape, (40, 4)) + self.assertAllEqual(output["bounding_boxes"]["labels"].shape, (40,)) + + def test_output_shapes_with_tf_data(self): + if backend.config.image_data_format() == "channels_last": + image_shape = (1, 10, 8, 3) + else: + image_shape = (1, 3, 10, 8) + input_image = np.random.random(image_shape) + bounding_boxes = { + "boxes": np.array( + [ + [ + [2, 1, 4, 3], + [6, 4, 8, 6], + ] + ] + ), # Example boxes (normalized) + "labels": np.array([[1, 2]]), # Dummy labels + } + layer = layers.MaxNumBoundingBoxes( + max_number=40, bounding_box_format="xyxy" + ) + input_data = {"images": input_image, "bounding_boxes": bounding_boxes} + ds = tf_data.Dataset.from_tensor_slices(input_data) + ds = ds.map(layer) + ds = ds.batch(1) + output = next(iter(ds)) + self.assertAllEqual(output["bounding_boxes"]["boxes"].shape, (1, 40, 4)) + self.assertAllEqual(output["bounding_boxes"]["labels"].shape, (1, 40)) diff --git a/keras/src/layers/preprocessing/image_preprocessing/mix_up.py b/keras/src/layers/preprocessing/image_preprocessing/mix_up.py new file mode 100644 index 000000000000..064ae58279f7 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/mix_up.py @@ -0,0 +1,183 @@ +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 + BaseImagePreprocessingLayer, +) +from keras.src.random import SeedGenerator +from keras.src.utils import backend_utils + + +@keras_export("keras.layers.MixUp") +class MixUp(BaseImagePreprocessingLayer): + """MixUp implements the MixUp data augmentation technique. + + **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline + (independently of which backend you're using). + + References: + - [MixUp paper](https://arxiv.org/abs/1710.09412). + - [MixUp for Object Detection paper](https://arxiv.org/pdf/1902.04103). + + Args: + alpha: Float between 0 and 1. Controls the blending strength. + Smaller values mean less mixing, while larger values allow + for more blending between images. Defaults to 0.2, + recommended for ImageNet1k classification. + seed: Integer. Used to create a random seed. + + Example: + ```python + (images, labels), _ = keras.datasets.cifar10.load_data() + images, labels = images[:8], labels[:8] + labels = keras.ops.cast(keras.ops.one_hot(labels.flatten(), 10), "float32") + mix_up = keras.layers.MixUp(alpha=0.2) + output = mix_up({"images": images, "labels": labels}) + ``` + """ + + def __init__(self, alpha=0.2, data_format=None, seed=None, **kwargs): + super().__init__(data_format=data_format, **kwargs) + self.alpha = alpha + self.seed = seed + self.generator = SeedGenerator(seed) + + def get_random_transformation(self, data, training=True, seed=None): + if isinstance(data, dict): + images = data["images"] + else: + images = data + + images_shape = self.backend.shape(images) + + if len(images_shape) == 3: + batch_size = 1 + else: + batch_size = self.backend.shape(images)[0] + + if seed is None: + seed = self._get_seed_generator(self.backend._backend) + + permutation_order = self.backend.random.shuffle( + self.backend.numpy.arange(0, batch_size, dtype="int64"), + seed=seed, + ) + + mix_weight = self.backend.random.beta( + (batch_size,), self.alpha, self.alpha, seed=seed + ) + return { + "mix_weight": mix_weight, + "permutation_order": permutation_order, + } + + def transform_images(self, images, transformation=None, training=True): + def _mix_up_input(images, transformation): + images = self.backend.cast(images, self.compute_dtype) + mix_weight = transformation["mix_weight"] + permutation_order = transformation["permutation_order"] + mix_weight = self.backend.cast( + self.backend.numpy.reshape(mix_weight, [-1, 1, 1, 1]), + dtype=self.compute_dtype, + ) + mix_up_images = self.backend.cast( + self.backend.numpy.take(images, permutation_order, axis=0), + dtype=self.compute_dtype, + ) + images = mix_weight * images + (1.0 - mix_weight) * mix_up_images + return images + + if training: + images = _mix_up_input(images, transformation) + return images + + def transform_labels(self, labels, transformation, training=True): + def _mix_up_labels(labels, transformation): + mix_weight = transformation["mix_weight"] + permutation_order = transformation["permutation_order"] + labels_for_mix_up = self.backend.numpy.take( + labels, permutation_order, axis=0 + ) + mix_weight = self.backend.numpy.reshape(mix_weight, [-1, 1]) + labels = ( + mix_weight * labels + (1.0 - mix_weight) * labels_for_mix_up + ) + return labels + + if training: + labels = _mix_up_labels(labels, transformation) + return labels + + def transform_bounding_boxes( + self, + bounding_boxes, + transformation, + training=True, + ): + def _mix_up_bounding_boxes(bounding_boxes, transformation): + if backend_utils.in_tf_graph(): + self.backend.set_backend("tensorflow") + + permutation_order = transformation["permutation_order"] + # Make sure we are on cpu for torch tensors. + permutation_order = ops.convert_to_numpy(permutation_order) + + boxes, labels = bounding_boxes["boxes"], bounding_boxes["labels"] + boxes_for_mix_up = self.backend.numpy.take( + boxes, permutation_order, axis=0 + ) + + labels_for_mix_up = self.backend.numpy.take( + labels, permutation_order, axis=0 + ) + boxes = self.backend.numpy.concatenate( + [boxes, boxes_for_mix_up], axis=1 + ) + + labels = self.backend.numpy.concatenate( + [labels, labels_for_mix_up], axis=0 + ) + + self.backend.reset() + + return {"boxes": boxes, "labels": labels} + + if training: + bounding_boxes = _mix_up_bounding_boxes( + bounding_boxes, transformation + ) + return bounding_boxes + + def transform_segmentation_masks( + self, segmentation_masks, transformation, training=True + ): + def _mix_up_segmentation_masks(segmentation_masks, transformation): + mix_weight = transformation["mix_weight"] + # Make sure we are on cpu for torch tensors. + mix_weight = ops.convert_to_numpy(mix_weight) + permutation_order = transformation["permutation_order"] + mix_weight = self.backend.numpy.reshape(mix_weight, [-1, 1, 1, 1]) + segmentation_masks_for_mix_up = self.backend.numpy.take( + segmentation_masks, permutation_order + ) + segmentation_masks = ( + mix_weight * segmentation_masks + + (1.0 - mix_weight) * segmentation_masks_for_mix_up + ) + return segmentation_masks + + if training: + segmentation_masks = _mix_up_segmentation_masks( + segmentation_masks, transformation + ) + return segmentation_masks + + def compute_output_shape(self, input_shape): + return input_shape + + def get_config(self): + config = { + "alpha": self.alpha, + "seed": self.seed, + } + base_config = super().get_config() + return {**base_config, **config} diff --git a/keras/src/layers/preprocessing/image_preprocessing/mix_up_test.py b/keras/src/layers/preprocessing/image_preprocessing/mix_up_test.py new file mode 100644 index 000000000000..eff9e0b3a72a --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/mix_up_test.py @@ -0,0 +1,157 @@ +import numpy as np +import pytest +from tensorflow import data as tf_data + +from keras.src import backend +from keras.src import layers +from keras.src import testing +from keras.src.backend import convert_to_tensor + + +class MixUpTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_layer(self): + self.run_layer_test( + layers.MixUp, + init_kwargs={ + "alpha": 0.2, + }, + input_shape=(8, 3, 4, 3), + supports_masking=False, + expected_output_shape=(8, 3, 4, 3), + # StatelessRandomGammaV3 is not supported on XLA_GPU_JIT + run_training_check=not testing.tensorflow_uses_gpu(), + ) + + def test_mix_up_inference(self): + seed = 3481 + layer = layers.MixUp(alpha=0.2) + np.random.seed(seed) + inputs = np.random.randint(0, 255, size=(224, 224, 3)) + output = layer(inputs, training=False) + self.assertAllClose(inputs, output) + + def test_mix_up_basic_functionality(self): + image = np.random.random((64, 64, 3)) + mix_up_layer = layers.MixUp(alpha=1) + transformation = {"mix_weight": 1, "permutation_order": [0]} + output = mix_up_layer.transform_images( + image, transformation=transformation + )[0] + self.assertAllClose(output, image) + + image = np.random.random((4, 64, 64, 3)) + mix_up_layer = layers.MixUp(alpha=0.2) + transformation = {"mix_weight": 0.2, "permutation_order": [1, 0, 2, 3]} + output = mix_up_layer.transform_images( + image, transformation=transformation + ) + self.assertNotAllClose(output, image) + self.assertAllClose(output.shape, image.shape) + + def test_mix_up_basic_functionality_channel_first(self): + image = np.random.random((3, 64, 64)) + mix_up_layer = layers.MixUp(alpha=1) + transformation = {"mix_weight": 1, "permutation_order": [0]} + output = mix_up_layer.transform_images( + image, transformation=transformation + )[0] + self.assertAllClose(output, image) + + image = np.random.random((4, 3, 64, 64)) + mix_up_layer = layers.MixUp(alpha=0.2) + transformation = {"mix_weight": 0.2, "permutation_order": [1, 0, 2, 3]} + output = mix_up_layer.transform_images( + image, transformation=transformation + ) + self.assertNotAllClose(output, image) + self.assertAllClose(output.shape, image.shape) + + def test_tf_data_compatibility(self): + layer = layers.MixUp() + input_data = np.random.random((2, 8, 8, 3)) + ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) + for output in ds.take(1): + output.numpy() + + def test_mix_up_bounding_boxes(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + image_shape = (10, 8, 3) + else: + image_shape = (3, 10, 8) + input_image = np.random.random(image_shape) + bounding_boxes = { + "boxes": np.array( + [ + [2, 1, 4, 3], + [6, 4, 8, 6], + ] + ), + "labels": np.array([1, 2]), + } + input_data = {"images": input_image, "bounding_boxes": bounding_boxes} + + expected_boxes = [[2, 1, 4, 3, 6, 4, 8, 6], [6, 4, 8, 6, 2, 1, 4, 3]] + + random_flip_layer = layers.MixUp( + data_format=data_format, + seed=42, + bounding_box_format="xyxy", + ) + + transformation = { + "mix_weight": convert_to_tensor([0.5, 0.5]), + "permutation_order": convert_to_tensor([1, 0]), + } + output = random_flip_layer.transform_bounding_boxes( + input_data["bounding_boxes"], + transformation=transformation, + training=True, + ) + self.assertAllClose(output["boxes"], expected_boxes) + + def test_mix_up_tf_data_bounding_boxes(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + image_shape = (1, 10, 8, 3) + else: + image_shape = (1, 3, 10, 8) + input_image = np.random.random(image_shape) + bounding_boxes = { + "boxes": np.array( + [ + [ + [2, 1, 4, 3], + [6, 4, 8, 6], + ] + ] + ), + "labels": np.array([[1, 2]]), + } + + input_data = {"images": input_image, "bounding_boxes": bounding_boxes} + expected_boxes = [[2, 1, 4, 3, 6, 4, 8, 6], [6, 4, 8, 6, 2, 1, 4, 3]] + + ds = tf_data.Dataset.from_tensor_slices(input_data) + layer = layers.MixUp( + data_format=data_format, + seed=42, + bounding_box_format="xyxy", + ) + + transformation = { + "mix_weight": convert_to_tensor([0.5, 0.5]), + "permutation_order": convert_to_tensor([1, 0]), + } + ds = ds.map( + lambda x: layer.transform_bounding_boxes( + x["bounding_boxes"], + transformation=transformation, + training=True, + ) + ) + + output = next(iter(ds)) + expected_boxes = np.array(expected_boxes) + self.assertAllClose(output["boxes"], expected_boxes) diff --git a/keras/src/layers/preprocessing/image_preprocessing/rand_augment.py b/keras/src/layers/preprocessing/image_preprocessing/rand_augment.py new file mode 100644 index 000000000000..b0dedf5ec63e --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/rand_augment.py @@ -0,0 +1,267 @@ +import keras.src.layers as layers +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 + BaseImagePreprocessingLayer, +) +from keras.src.random import SeedGenerator +from keras.src.utils import backend_utils + + +@keras_export("keras.layers.RandAugment") +class RandAugment(BaseImagePreprocessingLayer): + """RandAugment performs the Rand Augment operation on input images. + + This layer can be thought of as an all-in-one image augmentation layer. The + policy implemented by this layer has been benchmarked extensively and is + effective on a wide variety of datasets. + + **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline + (independently of which backend you're using). + + References: + - [RandAugment](https://arxiv.org/abs/1909.13719) + + Args: + value_range: The range of values the input image can take. + Default is `(0, 255)`. Typically, this would be `(0, 1)` + for normalized images or `(0, 255)` for raw images. + num_ops: The number of augmentation operations to apply sequentially + to each image. Default is 2. + factor: The strength of the augmentation as a normalized value + between 0 and 1. Default is 0.5. + interpolation: The interpolation method to use for resizing operations. + Options include `nearest`, `bilinear`. Default is `bilinear`. + seed: Integer. Used to create a random seed. + """ + + _USE_BASE_FACTOR = False + _FACTOR_BOUNDS = (0, 1) + + _AUGMENT_LAYERS = [ + "random_shear", + "random_translation", + "random_rotation", + "random_brightness", + "random_color_degeneration", + "random_contrast", + "random_sharpness", + "random_posterization", + "solarization", + "auto_contrast", + "equalization", + ] + + def __init__( + self, + value_range=(0, 255), + num_ops=2, + factor=0.5, + interpolation="bilinear", + seed=None, + data_format=None, + **kwargs, + ): + super().__init__(data_format=data_format, **kwargs) + + self.value_range = value_range + self.num_ops = num_ops + self._set_factor(factor) + self.interpolation = interpolation + self.seed = seed + self.generator = SeedGenerator(seed) + + self.random_shear = layers.RandomShear( + x_factor=self.factor, + y_factor=self.factor, + interpolation=interpolation, + seed=self.seed, + data_format=data_format, + **kwargs, + ) + + self.random_translation = layers.RandomTranslation( + height_factor=self.factor, + width_factor=self.factor, + interpolation=interpolation, + seed=self.seed, + data_format=data_format, + **kwargs, + ) + + self.random_rotation = layers.RandomRotation( + factor=self.factor, + interpolation=interpolation, + seed=self.seed, + data_format=data_format, + **kwargs, + ) + + self.random_brightness = layers.RandomBrightness( + factor=self.factor, + value_range=self.value_range, + seed=self.seed, + data_format=data_format, + **kwargs, + ) + + self.random_color_degeneration = layers.RandomColorDegeneration( + factor=self.factor, + value_range=self.value_range, + seed=self.seed, + data_format=data_format, + **kwargs, + ) + + self.random_contrast = layers.RandomContrast( + factor=self.factor, + value_range=self.value_range, + seed=self.seed, + data_format=data_format, + **kwargs, + ) + + self.random_sharpness = layers.RandomSharpness( + factor=self.factor, + value_range=self.value_range, + seed=self.seed, + data_format=data_format, + **kwargs, + ) + + self.solarization = layers.Solarization( + addition_factor=self.factor, + threshold_factor=self.factor, + value_range=self.value_range, + seed=self.seed, + data_format=data_format, + **kwargs, + ) + + self.random_posterization = layers.RandomPosterization( + factor=max(1, int(8 * self.factor[1])), + value_range=self.value_range, + seed=self.seed, + data_format=data_format, + **kwargs, + ) + + self.auto_contrast = layers.AutoContrast( + value_range=self.value_range, data_format=data_format, **kwargs + ) + + self.equalization = layers.Equalization( + value_range=self.value_range, data_format=data_format, **kwargs + ) + + def build(self, input_shape): + for layer_name in self._AUGMENT_LAYERS: + augmentation_layer = getattr(self, layer_name) + augmentation_layer.build(input_shape) + + def get_random_transformation(self, data, training=True, seed=None): + if not training: + return None + + if backend_utils.in_tf_graph(): + self.backend.set_backend("tensorflow") + + for layer_name in self._AUGMENT_LAYERS: + augmentation_layer = getattr(self, layer_name) + augmentation_layer.backend.set_backend("tensorflow") + + layer_idxes = self.backend.random.randint( + (self.num_ops,), + 0, + len(self._AUGMENT_LAYERS), + seed=self._get_seed_generator(self.backend._backend), + ) + + transformation = {} + for layer_name in self._AUGMENT_LAYERS: + augmentation_layer = getattr(self, layer_name) + transformation[layer_name] = ( + augmentation_layer.get_random_transformation( + data, + training=training, + seed=self._get_seed_generator(self.backend._backend), + ) + ) + + return { + "transforms": transformation, + "layer_idxes": layer_idxes, + } + + def transform_images(self, images, transformation, training=True): + if training: + images = self.backend.cast(images, self.compute_dtype) + + layer_idxes = transformation["layer_idxes"] + transforms = transformation["transforms"] + for i in range(self.num_ops): + for idx, (key, value) in enumerate(transforms.items()): + augmentation_layer = getattr(self, key) + images = self.backend.numpy.where( + layer_idxes[i] == idx, + augmentation_layer.transform_images(images, value), + images, + ) + + images = self.backend.cast(images, self.compute_dtype) + return images + + def transform_labels(self, labels, transformation, training=True): + return labels + + def transform_bounding_boxes( + self, + bounding_boxes, + transformation, + training=True, + ): + if training: + layer_idxes = transformation["layer_idxes"] + transforms = transformation["transforms"] + for idx, (key, value) in enumerate(transforms.items()): + augmentation_layer = getattr(self, key) + + transformed_bounding_box = ( + augmentation_layer.transform_bounding_boxes( + bounding_boxes.copy(), value + ) + ) + for i in range(self.num_ops): + bounding_boxes["boxes"] = self.backend.numpy.where( + layer_idxes[i] == idx, + transformed_bounding_box["boxes"], + bounding_boxes["boxes"], + ) + + bounding_boxes["labels"] = self.backend.numpy.where( + layer_idxes[i] == idx, + transformed_bounding_box["labels"], + bounding_boxes["labels"], + ) + + return bounding_boxes + + def transform_segmentation_masks( + self, segmentation_masks, transformation, training=True + ): + return self.transform_images( + segmentation_masks, transformation, training=training + ) + + def compute_output_shape(self, input_shape): + return input_shape + + def get_config(self): + config = { + "value_range": self.value_range, + "num_ops": self.num_ops, + "factor": self.factor, + "interpolation": self.interpolation, + "seed": self.seed, + } + base_config = super().get_config() + return {**base_config, **config} diff --git a/keras/src/layers/preprocessing/image_preprocessing/rand_augment_test.py b/keras/src/layers/preprocessing/image_preprocessing/rand_augment_test.py new file mode 100644 index 000000000000..91929d666ce0 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/rand_augment_test.py @@ -0,0 +1,129 @@ +import numpy as np +import pytest +from tensorflow import data as tf_data + +from keras.src import backend +from keras.src import layers +from keras.src import testing + + +class RandAugmentTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_layer(self): + self.run_layer_test( + layers.RandAugment, + init_kwargs={ + "value_range": (0, 255), + "num_ops": 2, + "factor": 1, + "interpolation": "nearest", + "seed": 1, + "data_format": "channels_last", + }, + input_shape=(8, 3, 4, 3), + supports_masking=False, + expected_output_shape=(8, 3, 4, 3), + ) + + def test_rand_augment_inference(self): + seed = 3481 + layer = layers.RandAugment() + + np.random.seed(seed) + inputs = np.random.randint(0, 255, size=(224, 224, 3)) + output = layer(inputs, training=False) + self.assertAllClose(inputs, output) + + def test_rand_augment_basic(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + input_data = np.random.random((2, 8, 8, 3)) + else: + input_data = np.random.random((2, 3, 8, 8)) + layer = layers.RandAugment(data_format=data_format) + + augmented_image = layer(input_data) + self.assertEqual(augmented_image.shape, input_data.shape) + + def test_rand_augment_no_operations(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + input_data = np.random.random((2, 8, 8, 3)) + else: + input_data = np.random.random((2, 3, 8, 8)) + layer = layers.RandAugment(num_ops=0, data_format=data_format) + + augmented_image = layer(input_data) + self.assertAllClose( + backend.convert_to_numpy(augmented_image), input_data + ) + + def test_random_augment_randomness(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + input_data = np.random.random((2, 8, 8, 3)) + else: + input_data = np.random.random((2, 3, 8, 8)) + + layer = layers.RandAugment(num_ops=11, data_format=data_format) + augmented_image = layer(input_data) + + self.assertNotAllClose( + backend.convert_to_numpy(augmented_image), input_data + ) + + def test_tf_data_compatibility(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + input_data = np.random.random((2, 8, 8, 3)) + else: + input_data = np.random.random((2, 3, 8, 8)) + layer = layers.RandAugment(data_format=data_format) + + ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) + for output in ds.take(1): + output.numpy() + + def test_rand_augment_tf_data_bounding_boxes(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + image_shape = (1, 10, 8, 3) + else: + image_shape = (1, 3, 10, 8) + input_image = np.random.random(image_shape) + bounding_boxes = { + "boxes": np.array( + [ + [ + [2, 1, 4, 3], + [6, 4, 8, 6], + ] + ] + ), + "labels": np.array([[1, 2]]), + } + + input_data = {"images": input_image, "bounding_boxes": bounding_boxes} + + ds = tf_data.Dataset.from_tensor_slices(input_data) + layer = layers.RandAugment( + data_format=data_format, + seed=42, + bounding_box_format="xyxy", + ) + ds.map(layer) + + def test_graph_issue(self): + input_data = np.random.random((10, 8, 8, 3)) + layer = layers.RandAugment() + ds = ( + tf_data.Dataset.from_tensor_slices(input_data) + .batch(2) + .map(lambda x: layer.get_random_transformation(x)) + ) + + key_list = [] + for output in ds: + key_list.append(output["layer_idxes"]) + + self.assertNotEqual(len(np.unique(key_list)), 1) diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_brightness.py b/keras/src/layers/preprocessing/image_preprocessing/random_brightness.py new file mode 100644 index 000000000000..01071728d9d5 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_brightness.py @@ -0,0 +1,158 @@ +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 + BaseImagePreprocessingLayer, +) +from keras.src.random.seed_generator import SeedGenerator + + +@keras_export("keras.layers.RandomBrightness") +class RandomBrightness(BaseImagePreprocessingLayer): + """A preprocessing layer which randomly adjusts brightness during training. + + This layer will randomly increase/reduce the brightness for the input RGB + images. At inference time, the output will be identical to the input. + Call the layer with `training=True` to adjust the brightness of the input. + + **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline + (independently of which backend you're using). + + Args: + factor: Float or a list/tuple of 2 floats between -1.0 and 1.0. The + factor is used to determine the lower bound and upper bound of the + brightness adjustment. A float value will be chosen randomly between + the limits. When -1.0 is chosen, the output image will be black, and + when 1.0 is chosen, the image will be fully white. + When only one float is provided, eg, 0.2, + then -0.2 will be used for lower bound and 0.2 + will be used for upper bound. + value_range: Optional list/tuple of 2 floats + for the lower and upper limit + of the values of the input data. + To make no change, use `[0.0, 1.0]`, e.g., if the image input + has been scaled before this layer. Defaults to `[0.0, 255.0]`. + The brightness adjustment will be scaled to this range, and the + output values will be clipped to this range. + seed: optional integer, for fixed RNG behavior. + + Inputs: 3D (HWC) or 4D (NHWC) tensor, with float or int dtype. Input pixel + values can be of any range (e.g. `[0., 1.)` or `[0, 255]`) + + Output: 3D (HWC) or 4D (NHWC) tensor with brightness adjusted based on the + `factor`. By default, the layer will output floats. + The output value will be clipped to the range `[0, 255]`, + the valid range of RGB colors, and + rescaled based on the `value_range` if needed. + + Example: + + ```python + random_bright = keras.layers.RandomBrightness(factor=0.2) + + # An image with shape [2, 2, 3] + image = [[[1, 2, 3], [4 ,5 ,6]], [[7, 8, 9], [10, 11, 12]]] + + # Assume we randomly select the factor to be 0.1, then it will apply + # 0.1 * 255 to all the channel + output = random_bright(image, training=True) + + # output will be int64 with 25.5 added to each channel and round down. + >>> array([[[26.5, 27.5, 28.5] + [29.5, 30.5, 31.5]] + [[32.5, 33.5, 34.5] + [35.5, 36.5, 37.5]]], + shape=(2, 2, 3), dtype=int64) + ``` + """ + + _VALUE_RANGE_VALIDATION_ERROR = ( + "The `value_range` argument should be a list of two numbers. " + ) + + def __init__(self, factor, value_range=(0, 255), seed=None, **kwargs): + super().__init__(factor=factor, **kwargs) + self.seed = seed + self.generator = SeedGenerator(seed) + self._set_value_range(value_range) + + def _set_value_range(self, value_range): + if not isinstance(value_range, (tuple, list)): + raise ValueError( + self._VALUE_RANGE_VALIDATION_ERROR + + f"Received: value_range={value_range}" + ) + if len(value_range) != 2: + raise ValueError( + self._VALUE_RANGE_VALIDATION_ERROR + + f"Received: value_range={value_range}" + ) + self.value_range = sorted(value_range) + + def get_random_transformation(self, data, training=True, seed=None): + if isinstance(data, dict): + images = data["images"] + else: + images = data + images_shape = self.backend.shape(images) + rank = len(images_shape) + if rank == 3: + rgb_delta_shape = (1, 1, 1) + elif rank == 4: + # Keep only the batch dim. This will ensure to have same adjustment + # with in one image, but different across the images. + rgb_delta_shape = [images_shape[0], 1, 1, 1] + else: + raise ValueError( + "Expected the input image to be rank 3 or 4. Received " + f"inputs.shape={images_shape}" + ) + if not training: + return {"rgb_delta": self.backend.numpy.zeros(rgb_delta_shape)} + + if seed is None: + seed = self._get_seed_generator(self.backend._backend) + rgb_delta = self.backend.random.uniform( + minval=self.factor[0], + maxval=self.factor[1], + shape=rgb_delta_shape, + seed=seed, + ) + rgb_delta = rgb_delta * (self.value_range[1] - self.value_range[0]) + return {"rgb_delta": rgb_delta} + + def transform_images(self, images, transformation, training=True): + if training: + rgb_delta = transformation["rgb_delta"] + rgb_delta = self.backend.cast(rgb_delta, images.dtype) + images += rgb_delta + return self.backend.numpy.clip( + images, self.value_range[0], self.value_range[1] + ) + return images + + def transform_labels(self, labels, transformation, training=True): + return labels + + def transform_bounding_boxes( + self, + bounding_boxes, + transformation, + training=True, + ): + return bounding_boxes + + def transform_segmentation_masks( + self, segmentation_masks, transformation, training=True + ): + return segmentation_masks + + def compute_output_shape(self, input_shape): + return input_shape + + def get_config(self): + config = { + "factor": self.factor, + "value_range": self.value_range, + "seed": self.seed, + } + base_config = super().get_config() + return {**base_config, **config} diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_brightness_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_brightness_test.py new file mode 100644 index 000000000000..b33bb439c53d --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_brightness_test.py @@ -0,0 +1,142 @@ +import numpy as np +import pytest +from tensorflow import data as tf_data + +from keras.src import backend +from keras.src import layers +from keras.src import testing + + +class RandomBrightnessTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_layer(self): + self.run_layer_test( + layers.RandomBrightness, + init_kwargs={ + "factor": 0.75, + "value_range": (20, 200), + "seed": 1, + }, + input_shape=(8, 3, 4, 3), + supports_masking=False, + expected_output_shape=(8, 3, 4, 3), + ) + + def test_random_brightness_inference(self): + seed = 3481 + layer = layers.RandomBrightness([0, 1.0]) + np.random.seed(seed) + inputs = np.random.randint(0, 255, size=(224, 224, 3)) + output = layer(inputs, training=False) + self.assertAllClose(inputs, output) + + def test_correctness(self): + seed = 2390 + + # Always scale up, but randomly between 0 ~ 255 + layer = layers.RandomBrightness([0.1, 1.0]) + np.random.seed(seed) + inputs = np.random.randint(0, 255, size=(224, 224, 3)) + output = backend.convert_to_numpy(layer(inputs)) + diff = output - inputs + diff = backend.convert_to_numpy(diff) + self.assertTrue(np.amin(diff) >= 0) + self.assertTrue(np.mean(diff) > 0) + + # Always scale down, but randomly between 0 ~ 255 + layer = layers.RandomBrightness([-1.0, -0.1]) + np.random.seed(seed) + inputs = np.random.randint(0, 255, size=(224, 224, 3)) + output = backend.convert_to_numpy(layer(inputs)) + diff = output - inputs + self.assertTrue(np.amax(diff) <= 0) + self.assertTrue(np.mean(diff) < 0) + + def test_tf_data_compatibility(self): + layer = layers.RandomBrightness(factor=0.5, seed=1337) + input_data = np.random.random((2, 8, 8, 3)) + ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) + for output in ds.take(1): + output.numpy() + + def test_value_range_incorrect_type(self): + with self.assertRaisesRegex( + ValueError, + "The `value_range` argument should be a list of two numbers.*", + ): + layers.RandomBrightness(factor=0.1, value_range="incorrect_type") + + def test_value_range_incorrect_length(self): + with self.assertRaisesRegex( + ValueError, + "The `value_range` argument should be a list of two numbers.*", + ): + layers.RandomBrightness(factor=0.1, value_range=[10]) + + def test_set_factor_incorrect_length(self): + layer = layers.RandomBrightness(factor=0.5) + with self.assertRaisesRegex( + ValueError, "The `factor` argument should be a number.*" + ): + layer._set_factor([0.1]) # Only one element in list + + def test_set_factor_incorrect_type(self): + layer = layers.RandomBrightness(factor=0.5) + with self.assertRaisesRegex( + ValueError, "The `factor` argument should be a number.*" + ): + layer._set_factor( + "invalid_type" + ) # Passing a string instead of a number or a list/tuple of numbers + + def test_factor_range_below_lower_bound(self): + with self.assertRaisesRegex( + ValueError, "The `factor` argument should be a number.*" + ): + # Passing a value less than -1.0 + layers.RandomBrightness(factor=-1.1) + + def test_factor_range_above_upper_bound(self): + with self.assertRaisesRegex( + ValueError, "The `factor` argument should be a number.*" + ): + # Passing a value more than 1.0 + layers.RandomBrightness(factor=1.1) + + def test_randomly_adjust_brightness_input_incorrect_rank(self): + layer = layers.RandomBrightness(factor=0.1) + wrong_rank_input = np.random.rand(10, 10) + + with self.assertRaisesRegex( + ValueError, + "Expected the input image to be rank 3 or 4.", + ): + layer( + wrong_rank_input, training=True + ) # Call the method that triggers the error + + def test_dict_input(self): + layer = layers.RandomBrightness(factor=0.1, bounding_box_format="xyxy") + data = { + "images": np.random.random((2, 4, 5, 3)), + "labels": np.random.random((2, 7)), + "segmentation_masks": np.random.random((2, 4, 5, 7)), + "bounding_boxes": { + "boxes": np.array([[1, 2, 2, 3]]), + "labels": np.array([0]), + }, + } + transformed_data = layer(data) + self.assertEqual( + data["images"].shape[:-1], + transformed_data["segmentation_masks"].shape[:-1], + ) + self.assertAllClose(data["labels"], transformed_data["labels"]) + self.assertAllClose( + data["bounding_boxes"]["boxes"], + transformed_data["bounding_boxes"]["boxes"], + ) + self.assertAllClose( + data["bounding_boxes"]["labels"], + transformed_data["bounding_boxes"]["labels"], + ) diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_color_degeneration.py b/keras/src/layers/preprocessing/image_preprocessing/random_color_degeneration.py new file mode 100644 index 000000000000..94bce40ad174 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_color_degeneration.py @@ -0,0 +1,135 @@ +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 + BaseImagePreprocessingLayer, +) +from keras.src.random import SeedGenerator + + +@keras_export("keras.layers.RandomColorDegeneration") +class RandomColorDegeneration(BaseImagePreprocessingLayer): + """Randomly performs the color degeneration operation on given images. + + The sharpness operation first converts an image to gray scale, then back to + color. It then takes a weighted average between original image and the + degenerated image. This makes colors appear more dull. + + **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline + (independently of which backend you're using). + + Args: + factor: A tuple of two floats or a single float. + `factor` controls the extent to which the + image sharpness is impacted. `factor=0.0` makes this layer perform a + no-op operation, while a value of 1.0 uses the degenerated result + entirely. Values between 0 and 1 result in linear interpolation + between the original image and the sharpened image. + Values should be between `0.0` and `1.0`. If a tuple is used, a + `factor` is sampled between the two values for every image + augmented. If a single float is used, a value between `0.0` and the + passed float is sampled. In order to ensure the value is always the + same, please pass a tuple with two identical floats: `(0.5, 0.5)`. + seed: Integer. Used to create a random seed. + """ + + _VALUE_RANGE_VALIDATION_ERROR = ( + "The `value_range` argument should be a list of two numbers. " + ) + + def __init__( + self, + factor, + value_range=(0, 255), + data_format=None, + seed=None, + **kwargs, + ): + super().__init__(data_format=data_format, **kwargs) + self._set_factor(factor) + self._set_value_range(value_range) + self.seed = seed + self.generator = SeedGenerator(seed) + + def _set_value_range(self, value_range): + if not isinstance(value_range, (tuple, list)): + raise ValueError( + self._VALUE_RANGE_VALIDATION_ERROR + + f"Received: value_range={value_range}" + ) + if len(value_range) != 2: + raise ValueError( + self._VALUE_RANGE_VALIDATION_ERROR + + f"Received: value_range={value_range}" + ) + self.value_range = sorted(value_range) + + def get_random_transformation(self, data, training=True, seed=None): + if isinstance(data, dict): + images = data["images"] + else: + images = data + images_shape = self.backend.shape(images) + rank = len(images_shape) + if rank == 3: + batch_size = 1 + elif rank == 4: + batch_size = images_shape[0] + else: + raise ValueError( + "Expected the input image to be rank 3 or 4. Received: " + f"inputs.shape={images_shape}" + ) + + if seed is None: + seed = self._get_seed_generator(self.backend._backend) + + factor = self.backend.random.uniform( + (batch_size, 1, 1, 1), + minval=self.factor[0], + maxval=self.factor[1], + seed=seed, + ) + factor = factor + return {"factor": factor} + + def transform_images(self, images, transformation=None, training=True): + if training: + images = self.backend.cast(images, self.compute_dtype) + factor = self.backend.cast( + transformation["factor"], self.compute_dtype + ) + degenerates = self.backend.image.rgb_to_grayscale( + images, data_format=self.data_format + ) + images = images + factor * (degenerates - images) + images = self.backend.numpy.clip( + images, self.value_range[0], self.value_range[1] + ) + images = self.backend.cast(images, self.compute_dtype) + return images + + def transform_labels(self, labels, transformation, training=True): + return labels + + def transform_segmentation_masks( + self, segmentation_masks, transformation, training=True + ): + return segmentation_masks + + def transform_bounding_boxes( + self, bounding_boxes, transformation, training=True + ): + return bounding_boxes + + def get_config(self): + config = super().get_config() + config.update( + { + "factor": self.factor, + "value_range": self.value_range, + "seed": self.seed, + } + ) + return config + + def compute_output_shape(self, input_shape): + return input_shape diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_color_degeneration_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_color_degeneration_test.py new file mode 100644 index 000000000000..18a0adc7c1f6 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_color_degeneration_test.py @@ -0,0 +1,77 @@ +import numpy as np +import pytest +from tensorflow import data as tf_data + +import keras +from keras.src import backend +from keras.src import layers +from keras.src import testing + + +class RandomColorDegenerationTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_layer(self): + self.run_layer_test( + layers.RandomColorDegeneration, + init_kwargs={ + "factor": 0.75, + "value_range": (0, 1), + "seed": 1, + }, + input_shape=(8, 3, 4, 3), + supports_masking=False, + expected_output_shape=(8, 3, 4, 3), + ) + + def test_random_color_degeneration_value_range(self): + image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1) + + layer = layers.RandomColorDegeneration(0.2, value_range=(0, 1)) + adjusted_image = layer(image) + + self.assertTrue(keras.ops.numpy.all(adjusted_image >= 0)) + self.assertTrue(keras.ops.numpy.all(adjusted_image <= 1)) + + def test_random_color_degeneration_no_op(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + inputs = np.random.random((2, 8, 8, 3)) + else: + inputs = np.random.random((2, 3, 8, 8)) + + layer = layers.RandomColorDegeneration((0.5, 0.5)) + output = layer(inputs, training=False) + self.assertAllClose(inputs, output, atol=1e-3, rtol=1e-5) + + def test_random_color_degeneration_factor_zero(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + inputs = np.random.random((2, 8, 8, 3)) + else: + inputs = np.random.random((2, 3, 8, 8)) + layer = layers.RandomColorDegeneration(factor=(0.0, 0.0)) + result = layer(inputs) + + self.assertAllClose(inputs, result, atol=1e-3, rtol=1e-5) + + def test_random_color_degeneration_randomness(self): + image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1)[:5] + + layer = layers.RandomColorDegeneration(0.2) + adjusted_images = layer(image) + + self.assertNotAllClose(adjusted_images, image) + + def test_tf_data_compatibility(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + input_data = np.random.random((2, 8, 8, 3)) + else: + input_data = np.random.random((2, 3, 8, 8)) + layer = layers.RandomColorDegeneration( + factor=0.5, data_format=data_format, seed=1337 + ) + + ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) + for output in ds.take(1): + output.numpy() diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_color_jitter.py b/keras/src/layers/preprocessing/image_preprocessing/random_color_jitter.py new file mode 100644 index 000000000000..72a9024b10bc --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_color_jitter.py @@ -0,0 +1,213 @@ +import keras.src.layers.preprocessing.image_preprocessing.random_brightness as random_brightness # noqa: E501 +import keras.src.layers.preprocessing.image_preprocessing.random_contrast as random_contrast # noqa: E501 +import keras.src.layers.preprocessing.image_preprocessing.random_hue as random_hue # noqa: E501 +import keras.src.layers.preprocessing.image_preprocessing.random_saturation as random_saturation # noqa: E501 +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 + BaseImagePreprocessingLayer, +) +from keras.src.random.seed_generator import SeedGenerator +from keras.src.utils import backend_utils + + +@keras_export("keras.layers.RandomColorJitter") +class RandomColorJitter(BaseImagePreprocessingLayer): + """RandomColorJitter class randomly apply brightness, contrast, saturation + and hue image processing operation sequentially and randomly on the + input. + + **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline + (independently of which backend you're using). + + Args: + value_range: the range of values the incoming images will have. + Represented as a two number tuple written [low, high]. + This is typically either `[0, 1]` or `[0, 255]` depending + on how your preprocessing pipeline is set up. + brightness_factor: Float or a list/tuple of 2 floats between -1.0 + and 1.0. The factor is used to determine the lower bound and + upper bound of the brightness adjustment. A float value will + be chosen randomly between the limits. When -1.0 is chosen, + the output image will be black, and when 1.0 is chosen, the + image will be fully white. When only one float is provided, + eg, 0.2, then -0.2 will be used for lower bound and 0.2 will + be used for upper bound. + contrast_factor: a positive float represented as fraction of value, + or a tuple of size 2 representing lower and upper bound. When + represented as a single float, lower = upper. The contrast + factor will be randomly picked between `[1.0 - lower, 1.0 + + upper]`. For any pixel x in the channel, the output will be + `(x - mean) * factor + mean` where `mean` is the mean value + of the channel. + saturation_factor: A tuple of two floats or a single float. `factor` + controls the extent to which the image saturation is impacted. + `factor=0.5` makes this layer perform a no-op operation. + `factor=0.0` makes the image fully grayscale. `factor=1.0` + makes the image fully saturated. Values should be between + `0.0` and `1.0`. If a tuple is used, a `factor` is sampled + between the two values for every image augmented. If a single + float is used, a value between `0.0` and the passed float is + sampled. To ensure the value is always the same, pass a tuple + with two identical floats: `(0.5, 0.5)`. + hue_factor: A single float or a tuple of two floats. `factor` + controls the extent to which the image hue is impacted. + `factor=0.0` makes this layer perform a no-op operation, + while a value of `1.0` performs the most aggressive contrast + adjustment available. If a tuple is used, a `factor` is + sampled between the two values for every image augmented. + If a single float is used, a value between `0.0` and the + passed float is sampled. In order to ensure the value is + always the same, please pass a tuple with two identical + floats: `(0.5, 0.5)`. + seed: Integer. Used to create a random seed. + """ + + def __init__( + self, + value_range=(0, 255), + brightness_factor=None, + contrast_factor=None, + saturation_factor=None, + hue_factor=None, + seed=None, + data_format=None, + **kwargs, + ): + super().__init__(data_format=data_format, **kwargs) + self.value_range = value_range + self.brightness_factor = brightness_factor + self.contrast_factor = contrast_factor + self.saturation_factor = saturation_factor + self.hue_factor = hue_factor + self.seed = seed + self.generator = SeedGenerator(seed) + + self.random_brightness = None + self.random_contrast = None + self.random_saturation = None + self.random_hue = None + + if self.brightness_factor is not None: + self.random_brightness = random_brightness.RandomBrightness( + factor=self.brightness_factor, + value_range=self.value_range, + seed=self.seed, + ) + + if self.contrast_factor is not None: + self.random_contrast = random_contrast.RandomContrast( + factor=self.contrast_factor, + value_range=self.value_range, + seed=self.seed, + ) + + if self.saturation_factor is not None: + self.random_saturation = random_saturation.RandomSaturation( + factor=self.saturation_factor, + value_range=self.value_range, + seed=self.seed, + ) + + if self.hue_factor is not None: + self.random_hue = random_hue.RandomHue( + factor=self.hue_factor, + value_range=self.value_range, + seed=self.seed, + ) + + def build(self, input_shape): + if self.brightness_factor is not None: + self.random_brightness.build(input_shape) + + if self.contrast_factor is not None: + self.random_contrast.build(input_shape) + + if self.saturation_factor is not None: + self.random_saturation.build(input_shape) + + if self.hue_factor is not None: + self.random_hue.build(input_shape) + + def transform_images(self, images, transformation, training=True): + if training: + if backend_utils.in_tf_graph(): + self.backend.set_backend("tensorflow") + images = self.backend.cast(images, self.compute_dtype) + if self.brightness_factor is not None: + if backend_utils.in_tf_graph(): + self.random_brightness.backend.set_backend("tensorflow") + transformation = ( + self.random_brightness.get_random_transformation( + images, + seed=self._get_seed_generator(self.backend._backend), + ) + ) + images = self.random_brightness.transform_images( + images, transformation + ) + if self.contrast_factor is not None: + if backend_utils.in_tf_graph(): + self.random_contrast.backend.set_backend("tensorflow") + transformation = self.random_contrast.get_random_transformation( + images, seed=self._get_seed_generator(self.backend._backend) + ) + transformation["contrast_factor"] = self.backend.cast( + transformation["contrast_factor"], dtype=self.compute_dtype + ) + images = self.random_contrast.transform_images( + images, transformation + ) + if self.saturation_factor is not None: + if backend_utils.in_tf_graph(): + self.random_saturation.backend.set_backend("tensorflow") + transformation = ( + self.random_saturation.get_random_transformation( + images, + seed=self._get_seed_generator(self.backend._backend), + ) + ) + images = self.random_saturation.transform_images( + images, transformation + ) + if self.hue_factor is not None: + if backend_utils.in_tf_graph(): + self.random_hue.backend.set_backend("tensorflow") + transformation = self.random_hue.get_random_transformation( + images, seed=self._get_seed_generator(self.backend._backend) + ) + images = self.random_hue.transform_images( + images, transformation + ) + images = self.backend.cast(images, self.compute_dtype) + return images + + def transform_labels(self, labels, transformation, training=True): + return labels + + def transform_bounding_boxes( + self, + bounding_boxes, + transformation, + training=True, + ): + return bounding_boxes + + def transform_segmentation_masks( + self, segmentation_masks, transformation, training=True + ): + return segmentation_masks + + def compute_output_shape(self, input_shape): + return input_shape + + def get_config(self): + config = { + "value_range": self.value_range, + "brightness_factor": self.brightness_factor, + "contrast_factor": self.contrast_factor, + "saturation_factor": self.saturation_factor, + "hue_factor": self.hue_factor, + "seed": self.seed, + } + base_config = super().get_config() + return {**base_config, **config} diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_color_jitter_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_color_jitter_test.py new file mode 100644 index 000000000000..a465970b6b45 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_color_jitter_test.py @@ -0,0 +1,135 @@ +import numpy as np +import pytest +from tensorflow import data as tf_data + +from keras.src import backend +from keras.src import layers +from keras.src import testing + + +class RandomColorJitterTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_layer(self): + self.run_layer_test( + layers.RandomColorJitter, + init_kwargs={ + "value_range": (20, 200), + "brightness_factor": 0.2, + "contrast_factor": 0.2, + "saturation_factor": 0.2, + "hue_factor": 0.2, + "seed": 1, + }, + input_shape=(8, 3, 4, 3), + supports_masking=False, + expected_output_shape=(8, 3, 4, 3), + ) + + def test_random_color_jitter_inference(self): + seed = 3481 + layer = layers.RandomColorJitter( + value_range=(0, 1), + brightness_factor=0.1, + contrast_factor=0.2, + saturation_factor=0.9, + hue_factor=0.1, + ) + + np.random.seed(seed) + inputs = np.random.randint(0, 255, size=(224, 224, 3)) + output = layer(inputs, training=False) + self.assertAllClose(inputs, output) + + def test_brightness_only(self): + seed = 2390 + np.random.seed(seed) + + data_format = backend.config.image_data_format() + if data_format == "channels_last": + inputs = np.random.random((12, 8, 16, 3)) + else: + inputs = np.random.random((12, 3, 8, 16)) + + layer = layers.RandomColorJitter( + brightness_factor=[0.5, 0.5], seed=seed + ) + output = backend.convert_to_numpy(layer(inputs)) + + layer = layers.RandomBrightness(factor=[0.5, 0.5], seed=seed) + sub_output = backend.convert_to_numpy(layer(inputs)) + + self.assertAllClose(output, sub_output) + + def test_saturation_only(self): + seed = 2390 + np.random.seed(seed) + + data_format = backend.config.image_data_format() + if data_format == "channels_last": + inputs = np.random.random((12, 8, 16, 3)) + else: + inputs = np.random.random((12, 3, 8, 16)) + + layer = layers.RandomColorJitter( + saturation_factor=[0.5, 0.5], seed=seed + ) + output = layer(inputs) + + layer = layers.RandomSaturation(factor=[0.5, 0.5], seed=seed) + sub_output = layer(inputs) + + self.assertAllClose(output, sub_output) + + def test_hue_only(self): + seed = 2390 + np.random.seed(seed) + + data_format = backend.config.image_data_format() + if data_format == "channels_last": + inputs = np.random.random((12, 8, 16, 3)) + else: + inputs = np.random.random((12, 3, 8, 16)) + + layer = layers.RandomColorJitter(hue_factor=[0.5, 0.5], seed=seed) + output = layer(inputs) + + layer = layers.RandomHue(factor=[0.5, 0.5], seed=seed) + sub_output = layer(inputs) + + self.assertAllClose(output, sub_output) + + def test_contrast_only(self): + seed = 2390 + np.random.seed(seed) + + data_format = backend.config.image_data_format() + if data_format == "channels_last": + inputs = np.random.random((12, 8, 16, 3)) + else: + inputs = np.random.random((12, 3, 8, 16)) + + layer = layers.RandomColorJitter(contrast_factor=[0.5, 0.5], seed=seed) + output = layer(inputs) + + layer = layers.RandomContrast(factor=[0.5, 0.5], seed=seed) + sub_output = layer(inputs) + + self.assertAllClose(output, sub_output) + + def test_tf_data_compatibility(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + input_data = np.random.random((2, 8, 8, 3)) + else: + input_data = np.random.random((2, 3, 8, 8)) + layer = layers.RandomColorJitter( + value_range=(0, 1), + brightness_factor=0.1, + contrast_factor=0.2, + saturation_factor=0.9, + hue_factor=0.1, + ) + + ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) + for output in ds.take(1): + output.numpy() diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_contrast.py b/keras/src/layers/preprocessing/image_preprocessing/random_contrast.py new file mode 100644 index 000000000000..ec6e2207a69f --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_contrast.py @@ -0,0 +1,149 @@ +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 + BaseImagePreprocessingLayer, +) +from keras.src.random.seed_generator import SeedGenerator + + +@keras_export("keras.layers.RandomContrast") +class RandomContrast(BaseImagePreprocessingLayer): + """A preprocessing layer which randomly adjusts contrast during training. + + This layer will randomly adjust the contrast of an image or images + by a random factor. Contrast is adjusted independently + for each channel of each image during training. + + For each channel, this layer computes the mean of the image pixels in the + channel and then adjusts each component `x` of each pixel to + `(x - mean) * contrast_factor + mean`. + + Input pixel values can be of any range (e.g. `[0., 1.)` or `[0, 255]`) and + in integer or floating point dtype. + By default, the layer will output floats. + + **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline + (independently of which backend you're using). + + Input shape: + 3D (unbatched) or 4D (batched) tensor with shape: + `(..., height, width, channels)`, in `"channels_last"` format. + + Output shape: + 3D (unbatched) or 4D (batched) tensor with shape: + `(..., height, width, channels)`, in `"channels_last"` format. + + Args: + factor: a positive float represented as fraction of value, or a tuple of + size 2 representing lower and upper bound. + When represented as a single float, lower = upper. + The contrast factor will be randomly picked between + `[1.0 - lower, 1.0 + upper]`. For any pixel x in the channel, + the output will be `(x - mean) * factor + mean` + where `mean` is the mean value of the channel. + value_range: the range of values the incoming images will have. + Represented as a two-number tuple written `[low, high]`. This is + typically either `[0, 1]` or `[0, 255]` depending on how your + preprocessing pipeline is set up. + seed: Integer. Used to create a random seed. + """ + + _FACTOR_BOUNDS = (0, 1) + + def __init__(self, factor, value_range=(0, 255), seed=None, **kwargs): + super().__init__(**kwargs) + self._set_factor(factor) + self.value_range = value_range + self.seed = seed + self.generator = SeedGenerator(seed) + + def get_random_transformation(self, data, training=True, seed=None): + if isinstance(data, dict): + images = data["images"] + else: + images = data + images_shape = self.backend.shape(images) + rank = len(images_shape) + if rank == 3: + factor_shape = (1, 1, 1) + elif rank == 4: + # Keep only the batch dim. This will ensure to have same adjustment + # with in one image, but different across the images. + factor_shape = [images_shape[0], 1, 1, 1] + else: + raise ValueError( + "Expected the input image to be rank 3 or 4. Received " + f"inputs.shape={images_shape}" + ) + + if not training: + return {"contrast_factor": self.backend.numpy.zeros(factor_shape)} + + if seed is None: + seed = self._get_seed_generator(self.backend._backend) + + factor = self.backend.random.uniform( + shape=factor_shape, + minval=1.0 - self.factor[0], + maxval=1.0 + self.factor[1], + seed=seed, + dtype=self.compute_dtype, + ) + return {"contrast_factor": factor} + + def transform_images(self, images, transformation, training=True): + if training: + constrast_factor = transformation["contrast_factor"] + outputs = self._adjust_constrast(images, constrast_factor) + outputs = self.backend.numpy.clip( + outputs, self.value_range[0], self.value_range[1] + ) + self.backend.numpy.reshape(outputs, self.backend.shape(images)) + return outputs + return images + + def transform_labels(self, labels, transformation, training=True): + return labels + + def transform_bounding_boxes( + self, + bounding_boxes, + transformation, + training=True, + ): + return bounding_boxes + + def transform_segmentation_masks( + self, segmentation_masks, transformation, training=True + ): + return segmentation_masks + + def _adjust_constrast(self, inputs, contrast_factor): + if self.data_format == "channels_first": + height_axis = -2 + width_axis = -1 + else: + height_axis = -3 + width_axis = -2 + # reduce mean on height + inp_mean = self.backend.numpy.mean( + inputs, axis=height_axis, keepdims=True + ) + # reduce mean on width + inp_mean = self.backend.numpy.mean( + inp_mean, axis=width_axis, keepdims=True + ) + + outputs = (inputs - inp_mean) * contrast_factor + inp_mean + return outputs + + def compute_output_shape(self, input_shape): + return input_shape + + def get_config(self): + config = { + "factor": self.factor, + "value_range": self.value_range, + "seed": self.seed, + } + base_config = super().get_config() + return {**base_config, **config} diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_contrast_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_contrast_test.py new file mode 100644 index 000000000000..a0f9cc24cf57 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_contrast_test.py @@ -0,0 +1,131 @@ +import numpy as np +import pytest +from tensorflow import data as tf_data + +from keras.src import backend +from keras.src import layers +from keras.src import testing + + +class RandomContrastTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_layer(self): + self.run_layer_test( + layers.RandomContrast, + init_kwargs={ + "factor": 0.75, + "value_range": (0, 255), + "seed": 1, + }, + input_shape=(8, 3, 4, 3), + supports_masking=False, + expected_output_shape=(8, 3, 4, 3), + ) + self.run_layer_test( + layers.RandomContrast, + init_kwargs={ + "factor": 0.75, + "value_range": (0, 255), + "seed": 1, + "data_format": "channels_first", + }, + input_shape=(8, 3, 4, 4), + supports_masking=False, + expected_output_shape=(8, 3, 4, 4), + ) + + def test_random_contrast_with_value_range_0_to_255(self): + seed = 9809 + np.random.seed(seed) + + data_format = backend.config.image_data_format() + if data_format == "channels_last": + inputs = np.random.random((12, 8, 16, 3)) + height_axis = -3 + width_axis = -2 + else: + inputs = np.random.random((12, 3, 8, 16)) + height_axis = -2 + width_axis = -1 + + inputs = backend.convert_to_tensor(inputs, dtype="float32") + layer = layers.RandomContrast( + factor=0.5, value_range=(0, 255), seed=seed + ) + transformation = layer.get_random_transformation(inputs, training=True) + outputs = layer.transform_images(inputs, transformation, training=True) + + # Actual contrast arithmetic + np.random.seed(seed) + factor = backend.convert_to_numpy(transformation["contrast_factor"]) + inputs = backend.convert_to_numpy(inputs) + inp_mean = np.mean(inputs, axis=height_axis, keepdims=True) + inp_mean = np.mean(inp_mean, axis=width_axis, keepdims=True) + actual_outputs = (inputs - inp_mean) * factor + inp_mean + outputs = backend.convert_to_numpy(outputs) + actual_outputs = np.clip(actual_outputs, 0, 255) + + self.assertAllClose(outputs, actual_outputs) + + def test_random_contrast_with_value_range_0_to_1(self): + seed = 9809 + np.random.seed(seed) + + data_format = backend.config.image_data_format() + if data_format == "channels_last": + inputs = np.random.random((12, 8, 16, 3)) + height_axis = -3 + width_axis = -2 + else: + inputs = np.random.random((12, 3, 8, 16)) + height_axis = -2 + width_axis = -1 + + inputs = backend.convert_to_tensor(inputs, dtype="float32") + layer = layers.RandomContrast(factor=0.5, value_range=(0, 1), seed=seed) + transformation = layer.get_random_transformation(inputs, training=True) + outputs = layer.transform_images(inputs, transformation, training=True) + + # Actual contrast arithmetic + np.random.seed(seed) + factor = backend.convert_to_numpy(transformation["contrast_factor"]) + inputs = backend.convert_to_numpy(inputs) + inp_mean = np.mean(inputs, axis=height_axis, keepdims=True) + inp_mean = np.mean(inp_mean, axis=width_axis, keepdims=True) + actual_outputs = (inputs - inp_mean) * factor + inp_mean + outputs = backend.convert_to_numpy(outputs) + actual_outputs = np.clip(actual_outputs, 0, 1) + + self.assertAllClose(outputs, actual_outputs) + + def test_tf_data_compatibility(self): + layer = layers.RandomContrast(factor=0.5, seed=1337) + input_data = np.random.random((2, 8, 8, 3)) + ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) + next(iter(ds)).numpy() + + def test_dict_input(self): + layer = layers.RandomContrast(factor=0.1, bounding_box_format="xyxy") + data = { + "images": np.random.random((2, 4, 5, 3)), + "labels": np.random.random((2, 7)), + "segmentation_masks": np.random.random((2, 4, 5, 7)), + "bounding_boxes": { + "boxes": np.array([[1, 2, 2, 3]]), + "labels": np.array([0]), + }, + } + transformed_data = layer(data) + self.assertEqual( + data["images"].shape[:-1], + transformed_data["segmentation_masks"].shape[:-1], + ) + self.assertAllClose(data["labels"], transformed_data["labels"]) + self.assertAllClose( + data["bounding_boxes"]["boxes"], + transformed_data["bounding_boxes"]["boxes"], + ) + self.assertAllClose( + data["bounding_boxes"]["labels"], + transformed_data["bounding_boxes"]["labels"], + ) diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_crop.py b/keras/src/layers/preprocessing/image_preprocessing/random_crop.py new file mode 100644 index 000000000000..2dc8aec5a105 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_crop.py @@ -0,0 +1,276 @@ +from keras.src import backend +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 + BaseImagePreprocessingLayer, +) +from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import ( # noqa: E501 + convert_format, +) +from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.validation import ( # noqa: E501 + densify_bounding_boxes, +) +from keras.src.random.seed_generator import SeedGenerator + + +@keras_export("keras.layers.RandomCrop") +class RandomCrop(BaseImagePreprocessingLayer): + """A preprocessing layer which randomly crops images during training. + + During training, this layer will randomly choose a location to crop images + down to a target size. The layer will crop all the images in the same batch + to the same cropping location. + + At inference time, and during training if an input image is smaller than the + target size, the input will be resized and cropped so as to return the + largest possible window in the image that matches the target aspect ratio. + If you need to apply random cropping at inference time, set `training` to + True when calling the layer. + + Input pixel values can be of any range (e.g. `[0., 1.)` or `[0, 255]`) and + of integer or floating point dtype. By default, the layer will output + floats. + + **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline + (independently of which backend you're using). + + Input shape: + 3D (unbatched) or 4D (batched) tensor with shape: + `(..., height, width, channels)`, in `"channels_last"` format. + + Output shape: + 3D (unbatched) or 4D (batched) tensor with shape: + `(..., target_height, target_width, channels)`. + + Args: + height: Integer, the height of the output shape. + width: Integer, the width of the output shape. + seed: Integer. Used to create a random seed. + **kwargs: Base layer keyword arguments, such as + `name` and `dtype`. + """ + + def __init__( + self, height, width, seed=None, data_format=None, name=None, **kwargs + ): + super().__init__(name=name, **kwargs) + self.height = height + self.width = width + self.seed = ( + seed if seed is not None else backend.random.make_default_seed() + ) + self.generator = SeedGenerator(seed) + self.data_format = backend.standardize_data_format(data_format) + + if self.data_format == "channels_first": + self.height_axis = -2 + self.width_axis = -1 + elif self.data_format == "channels_last": + self.height_axis = -3 + self.width_axis = -2 + + self.supports_masking = False + self.supports_jit = False + self._convert_input_args = False + self._allow_non_tensor_positional_args = True + + def get_random_transformation(self, data, training=True, seed=None): + if seed is None: + seed = self._get_seed_generator(self.backend._backend) + + if isinstance(data, dict): + input_shape = self.backend.shape(data["images"]) + else: + input_shape = self.backend.shape(data) + + input_height, input_width = ( + input_shape[self.height_axis], + input_shape[self.width_axis], + ) + if input_height is None or input_width is None: + raise ValueError( + "RandomCrop requires the input to have a fully defined " + f"height and width. Received: images.shape={input_shape}" + ) + + if training and input_height > self.height and input_width > self.width: + h_start = self.backend.cast( + self.backend.random.uniform( + (), + 0, + maxval=float(input_height - self.height + 1), + seed=seed, + ), + "int32", + ) + w_start = self.backend.cast( + self.backend.random.uniform( + (), + 0, + maxval=float(input_width - self.width + 1), + seed=seed, + ), + "int32", + ) + else: + crop_height = int(float(input_width * self.height) / self.width) + crop_height = max(min(input_height, crop_height), 1) + crop_width = int(float(input_height * self.width) / self.height) + crop_width = max(min(input_width, crop_width), 1) + h_start = int(float(input_height - crop_height) / 2) + w_start = int(float(input_width - crop_width) / 2) + + return h_start, w_start + + def transform_images(self, images, transformation, training=True): + if training: + images = self.backend.cast(images, self.compute_dtype) + crop_box_hstart, crop_box_wstart = transformation + crop_height = self.height + crop_width = self.width + + if self.data_format == "channels_last": + if len(images.shape) == 4: + images = images[ + :, + crop_box_hstart : crop_box_hstart + crop_height, + crop_box_wstart : crop_box_wstart + crop_width, + :, + ] + else: + images = images[ + crop_box_hstart : crop_box_hstart + crop_height, + crop_box_wstart : crop_box_wstart + crop_width, + :, + ] + else: + if len(images.shape) == 4: + images = images[ + :, + :, + crop_box_hstart : crop_box_hstart + crop_height, + crop_box_wstart : crop_box_wstart + crop_width, + ] + else: + images = images[ + :, + crop_box_hstart : crop_box_hstart + crop_height, + crop_box_wstart : crop_box_wstart + crop_width, + ] + + shape = self.backend.shape(images) + new_height = shape[self.height_axis] + new_width = shape[self.width_axis] + if ( + not isinstance(new_height, int) + or not isinstance(new_width, int) + or new_height != self.height + or new_width != self.width + ): + # Resize images if size mismatch or + # if size mismatch cannot be determined + # (in the case of a TF dynamic shape). + images = self.backend.image.resize( + images, + size=(self.height, self.width), + data_format=self.data_format, + ) + # Resize may have upcasted the outputs + images = self.backend.cast(images, self.compute_dtype) + return images + + def transform_labels(self, labels, transformation, training=True): + return labels + + def transform_bounding_boxes( + self, + bounding_boxes, + transformation, + training=True, + ): + """ + bounding_boxes = { + "boxes": (batch, num_boxes, 4), # left-top-right-bottom (xyxy) + "labels": (batch, num_boxes, num_classes), + } + or + bounding_boxes = { + "boxes": (num_boxes, 4), + "labels": (num_boxes, num_classes), + } + """ + + if training: + h_start, w_start = transformation + if not self.backend.is_tensor(bounding_boxes["boxes"]): + bounding_boxes = densify_bounding_boxes( + bounding_boxes, backend=self.backend + ) + boxes = bounding_boxes["boxes"] + # Convert to a standard xyxy as operations are done xyxy by default. + boxes = convert_format( + boxes=boxes, + source=self.bounding_box_format, + target="xyxy", + height=self.height, + width=self.width, + ) + h_start = self.backend.cast(h_start, boxes.dtype) + w_start = self.backend.cast(w_start, boxes.dtype) + if len(self.backend.shape(boxes)) == 3: + boxes = self.backend.numpy.stack( + [ + self.backend.numpy.maximum(boxes[:, :, 0] - h_start, 0), + self.backend.numpy.maximum(boxes[:, :, 1] - w_start, 0), + self.backend.numpy.maximum(boxes[:, :, 2] - h_start, 0), + self.backend.numpy.maximum(boxes[:, :, 3] - w_start, 0), + ], + axis=-1, + ) + else: + boxes = self.backend.numpy.stack( + [ + self.backend.numpy.maximum(boxes[:, 0] - h_start, 0), + self.backend.numpy.maximum(boxes[:, 1] - w_start, 0), + self.backend.numpy.maximum(boxes[:, 2] - h_start, 0), + self.backend.numpy.maximum(boxes[:, 3] - w_start, 0), + ], + axis=-1, + ) + + # Convert to user defined bounding box format + boxes = convert_format( + boxes=boxes, + source="xyxy", + target=self.bounding_box_format, + height=self.height, + width=self.width, + ) + + return { + "boxes": boxes, + "labels": bounding_boxes["labels"], + } + return bounding_boxes + + def transform_segmentation_masks( + self, segmentation_masks, transformation, training=True + ): + return self.transform_images(segmentation_masks, transformation) + + def compute_output_shape(self, input_shape, *args, **kwargs): + input_shape = list(input_shape) + input_shape[self.height_axis] = self.height + input_shape[self.width_axis] = self.width + return tuple(input_shape) + + def get_config(self): + config = super().get_config() + config.update( + { + "height": self.height, + "width": self.width, + "seed": self.seed, + "data_format": self.data_format, + } + ) + return config diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_crop_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_crop_test.py new file mode 100644 index 000000000000..c4796a2b2248 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_crop_test.py @@ -0,0 +1,165 @@ +import numpy as np +from tensorflow import data as tf_data + +from keras.src import backend +from keras.src import layers +from keras.src import testing + + +class RandomCropTest(testing.TestCase): + def test_random_crop(self): + self.run_layer_test( + layers.RandomCrop, + init_kwargs={ + "height": 2, + "width": 2, + "data_format": "channels_last", + }, + input_shape=(1, 3, 4, 3), + supports_masking=False, + run_training_check=False, + expected_output_shape=(1, 2, 2, 3), + ) + self.run_layer_test( + layers.RandomCrop, + init_kwargs={ + "height": 2, + "width": 2, + "data_format": "channels_last", + }, + input_shape=(3, 4, 3), + supports_masking=False, + run_training_check=False, + expected_output_shape=(2, 2, 3), + ) + self.run_layer_test( + layers.RandomCrop, + init_kwargs={ + "height": 2, + "width": 2, + "data_format": "channels_first", + }, + input_shape=(1, 3, 3, 4), + supports_masking=False, + run_training_check=False, + expected_output_shape=(1, 3, 2, 2), + ) + self.run_layer_test( + layers.RandomCrop, + init_kwargs={ + "height": 2, + "width": 2, + "data_format": "channels_first", + }, + input_shape=(3, 3, 4), + supports_masking=False, + run_training_check=False, + expected_output_shape=(3, 2, 2), + ) + + def test_random_crop_full(self): + np.random.seed(1337) + height, width = 8, 16 + if backend.config.image_data_format() == "channels_last": + input_shape = (12, 8, 16, 3) + else: + input_shape = (12, 3, 8, 16) + inp = np.random.random(input_shape) + layer = layers.RandomCrop(height, width) + actual_output = layer(inp, training=False) + self.assertAllClose(inp, actual_output) + + def test_random_crop_partial(self): + if backend.config.image_data_format() == "channels_last": + input_shape = (12, 8, 16, 3) + output_shape = (12, 8, 8, 3) + else: + input_shape = (12, 3, 8, 16) + output_shape = (12, 3, 8, 8) + self.run_layer_test( + layers.RandomCrop, + init_kwargs={ + "height": 8, + "width": 8, + }, + input_shape=input_shape, + expected_output_shape=output_shape, + supports_masking=False, + run_training_check=False, + ) + + def test_predicting_with_longer_height(self): + if backend.config.image_data_format() == "channels_last": + input_shape = (12, 8, 16, 3) + output_shape = (12, 10, 8, 3) + else: + input_shape = (12, 3, 8, 16) + output_shape = (12, 3, 10, 8) + self.run_layer_test( + layers.RandomCrop, + init_kwargs={ + "height": 10, + "width": 8, + }, + input_shape=input_shape, + expected_output_shape=output_shape, + supports_masking=False, + run_training_check=False, + ) + + def test_predicting_with_longer_width(self): + if backend.config.image_data_format() == "channels_last": + input_shape = (12, 8, 16, 3) + output_shape = (12, 8, 18, 3) + else: + input_shape = (12, 3, 8, 16) + output_shape = (12, 3, 8, 18) + self.run_layer_test( + layers.RandomCrop, + init_kwargs={ + "height": 8, + "width": 18, + }, + input_shape=input_shape, + expected_output_shape=output_shape, + supports_masking=False, + run_training_check=False, + ) + + def test_tf_data_compatibility(self): + layer = layers.RandomCrop(8, 9) + if backend.config.image_data_format() == "channels_last": + input_shape = (2, 10, 12, 3) + output_shape = (2, 8, 9, 3) + else: + input_shape = (2, 3, 10, 12) + output_shape = (2, 3, 8, 9) + input_data = np.random.random(input_shape) + ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) + output = next(iter(ds)).numpy() + self.assertEqual(tuple(output.shape), output_shape) + + def test_dict_input(self): + layer = layers.RandomCrop( + 3, 3, data_format="channels_last", bounding_box_format="xyxy" + ) + data = { + "images": np.random.random((2, 4, 5, 3)), + "labels": np.random.random((2, 7)), + "segmentation_masks": np.random.random((2, 4, 5, 7)), + "bounding_boxes": { + "boxes": np.array([[1, 2, 2, 3]]), + "labels": np.array([0]), + }, + } + transformed_data = layer(data) + self.assertEqual( + data["images"].shape[:-1], + transformed_data["segmentation_masks"].shape[:-1], + ) + self.assertAllClose(data["labels"], transformed_data["labels"]) + self.assertEqual(data["bounding_boxes"]["boxes"].shape, (1, 4)) + self.assertAllClose( + data["bounding_boxes"]["labels"], + transformed_data["bounding_boxes"]["labels"], + ) diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_elastic_transform.py b/keras/src/layers/preprocessing/image_preprocessing/random_elastic_transform.py new file mode 100644 index 000000000000..6f2e4e15080e --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_elastic_transform.py @@ -0,0 +1,279 @@ +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 + BaseImagePreprocessingLayer, +) +from keras.src.random.seed_generator import SeedGenerator + + +@keras_export("keras.layers.RandomElasticTransform") +class RandomElasticTransform(BaseImagePreprocessingLayer): + """A preprocessing layer that applies random elastic transformations. + + This layer distorts input images by applying elastic deformations, + simulating a physically realistic transformation. The magnitude of the + distortion is controlled by the `scale` parameter, while the `factor` + determines the probability of applying the transformation. + + **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline + (independently of which backend you're using). + + Args: + factor: A single float or a tuple of two floats. + `factor` controls the probability of applying the transformation. + - `factor=0.0` ensures no erasing is applied. + - `factor=1.0` means erasing is always applied. + - If a tuple `(min, max)` is provided, a probability value + is sampled between `min` and `max` for each image. + - If a single float is provided, a probability is sampled + between `0.0` and the given float. + Default is 1.0. + scale: A float or a tuple of two floats defining the magnitude of + the distortion applied. + - If a tuple `(min, max)` is provided, a random scale value is + sampled within this range. + - If a single float is provided, a random scale value is sampled + between `0.0` and the given float. + Default is 1.0. + interpolation: Interpolation mode. Supported values: `"nearest"`, + `"bilinear"`. + fill_mode: Points outside the boundaries of the input are filled + according to the given mode. Available methods are `"constant"`, + `"nearest"`, `"wrap"` and `"reflect"`. Defaults to `"constant"`. + - `"reflect"`: `(d c b a | a b c d | d c b a)` + The input is extended by reflecting about the edge of the last + pixel. + - `"constant"`: `(k k k k | a b c d | k k k k)` + The input is extended by filling all values beyond + the edge with the same constant value k specified by + `fill_value`. + - `"wrap"`: `(a b c d | a b c d | a b c d)` + The input is extended by wrapping around to the opposite edge. + - `"nearest"`: `(a a a a | a b c d | d d d d)` + The input is extended by the nearest pixel. + Note that when using torch backend, `"reflect"` is redirected to + `"mirror"` `(c d c b | a b c d | c b a b)` because torch does not + support `"reflect"`. + Note that torch backend does not support `"wrap"`. + fill_value: a float represents the value to be filled outside the + boundaries when `fill_mode="constant"`. + value_range: the range of values the incoming images will have. + Represented as a two-number tuple written `[low, high]`. This is + typically either `[0, 1]` or `[0, 255]` depending on how your + preprocessing pipeline is set up. + seed: Integer. Used to create a random seed. + + """ + + _USE_BASE_FACTOR = False + _FACTOR_BOUNDS = (0, 1) + _SUPPORTED_INTERPOLATION = ("nearest", "bilinear") + _SUPPORTED_FILL_MODES = { + "constant", + "nearest", + "wrap", + "mirror", + "reflect", + } + + def __init__( + self, + factor=1.0, + scale=1.0, + interpolation="bilinear", + fill_mode="reflect", + fill_value=0.0, + value_range=(0, 255), + seed=None, + data_format=None, + **kwargs, + ): + super().__init__(data_format=data_format, **kwargs) + self._set_factor(factor) + self.scale = self._set_factor_by_name(scale, "scale") + self.interpolation = interpolation + self.fill_mode = fill_mode + self.fill_value = fill_value + self.value_range = value_range + self.seed = seed + self.generator = SeedGenerator(seed) + + if interpolation not in self._SUPPORTED_INTERPOLATION: + raise NotImplementedError( + f"Unknown `interpolation` {interpolation}. Expected of one " + f"{self._SUPPORTED_INTERPOLATION}." + ) + + if fill_mode not in self._SUPPORTED_FILL_MODES: + raise NotImplementedError( + f"Unknown `fill_mode` {fill_mode}. Expected of one " + f"{self._SUPPORTED_FILL_MODES}." + ) + + if self.data_format == "channels_first": + self.height_axis = -2 + self.width_axis = -1 + self.channel_axis = -3 + else: + self.height_axis = -3 + self.width_axis = -2 + self.channel_axis = -1 + + def _set_factor_by_name(self, factor, name): + error_msg = ( + f"The `{name}` argument should be a number " + "(or a list of two numbers) " + "in the range " + f"[{self._FACTOR_BOUNDS[0]}, {self._FACTOR_BOUNDS[1]}]. " + f"Received: factor={factor}" + ) + if isinstance(factor, (tuple, list)): + if len(factor) != 2: + raise ValueError(error_msg) + if ( + factor[0] > self._FACTOR_BOUNDS[1] + or factor[1] < self._FACTOR_BOUNDS[0] + ): + raise ValueError(error_msg) + lower, upper = sorted(factor) + elif isinstance(factor, (int, float)): + if ( + factor < self._FACTOR_BOUNDS[0] + or factor > self._FACTOR_BOUNDS[1] + ): + raise ValueError(error_msg) + factor = abs(factor) + lower, upper = [max(-factor, self._FACTOR_BOUNDS[0]), factor] + else: + raise ValueError(error_msg) + return lower, upper + + def get_random_transformation(self, data, training=True, seed=None): + if not training: + return None + + if (self.scale[1] == 0) or (self.factor[1] == 0): + return None + + if isinstance(data, dict): + images = data["images"] + else: + images = data + + images_shape = self.backend.shape(images) + unbatched = len(images_shape) == 3 + if unbatched: + batch_size = 1 + else: + batch_size = images_shape[0] + + seed = seed or self._get_seed_generator(self.backend._backend) + + transformation_probability = self.backend.random.uniform( + shape=(batch_size,), + minval=self.factor[0], + maxval=self.factor[1], + seed=seed, + ) + + random_threshold = self.backend.random.uniform( + shape=(batch_size,), + minval=0.0, + maxval=1.0, + seed=seed, + ) + apply_transform = random_threshold < transformation_probability + + distortion_factor = self.backend.random.uniform( + shape=(), + minval=self.scale[0], + maxval=self.scale[1], + seed=seed, + dtype=self.compute_dtype, + ) + + return { + "apply_transform": apply_transform, + "distortion_factor": distortion_factor, + "seed": seed, + } + + def get_elastic_transform_params(self, height, width, factor): + alpha_scale = 0.1 * factor + sigma_scale = 0.05 * factor + + alpha = max(height, width) * alpha_scale + sigma = min(height, width) * sigma_scale + + return alpha, sigma + + def transform_images(self, images, transformation, training=True): + images = self.backend.cast(images, self.compute_dtype) + if training and transformation is not None: + apply_transform = transformation["apply_transform"] + distortion_factor = transformation["distortion_factor"] + seed = transformation["seed"] + + height, width = ( + images.shape[self.height_axis], + images.shape[self.width_axis], + ) + + alpha, sigma = self.get_elastic_transform_params( + height, width, distortion_factor + ) + + transformed_images = self.backend.image.elastic_transform( + images, + alpha=alpha, + sigma=sigma, + interpolation=self.interpolation, + fill_mode=self.fill_mode, + fill_value=self.fill_value, + seed=seed, + data_format=self.data_format, + ) + + apply_transform = ( + apply_transform[:, None, None] + if len(images.shape) == 3 + else apply_transform[:, None, None, None] + ) + + images = self.backend.numpy.where( + apply_transform, + transformed_images, + images, + ) + + images = self.backend.numpy.clip( + images, self.value_range[0], self.value_range[1] + ) + + images = self.backend.cast(images, self.compute_dtype) + return images + + def transform_labels(self, labels, transformation, training=True): + return labels + + def transform_segmentation_masks( + self, segmentation_masks, transformation, training=True + ): + return self.transform_images( + segmentation_masks, transformation, training=training + ) + + def compute_output_shape(self, input_shape): + return input_shape + + def get_config(self): + base_config = super().get_config() + config = { + "factor": self.factor, + "scale": self.scale, + "interpolation": self.interpolation, + "fill_mode": self.fill_mode, + "fill_value": self.fill_value, + "value_range": self.value_range, + "seed": self.seed, + } + return {**base_config, **config} diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_elastic_transform_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_elastic_transform_test.py new file mode 100644 index 000000000000..b0500808d2a6 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_elastic_transform_test.py @@ -0,0 +1,89 @@ +import numpy as np +import pytest +from tensorflow import data as tf_data + +from keras.src import backend +from keras.src import layers +from keras.src import testing + + +class RandomElasticTransformTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_layer(self): + self.run_layer_test( + layers.RandomElasticTransform, + init_kwargs={ + "factor": 1.0, + "scale": 0.5, + "interpolation": "bilinear", + "fill_mode": "reflect", + "fill_value": 0, + "value_range": (0, 255), + "seed": 1, + }, + input_shape=(8, 3, 4, 3), + supports_masking=False, + expected_output_shape=(8, 3, 4, 3), + run_training_check=False, + ) + + def test_random_elastic_transform_inference(self): + seed = 3481 + layer = layers.RandomElasticTransform() + + np.random.seed(seed) + inputs = np.random.randint(0, 255, size=(224, 224, 3)) + output = layer(inputs, training=False) + self.assertAllClose(inputs, output) + + def test_random_elastic_transform_no_op(self): + seed = 3481 + layer = layers.RandomElasticTransform(factor=0) + + np.random.seed(seed) + inputs = np.random.randint(0, 255, size=(224, 224, 3)) + output = layer(inputs) + self.assertAllClose(inputs, output) + + layer = layers.RandomElasticTransform(scale=0) + + np.random.seed(seed) + inputs = np.random.randint(0, 255, size=(224, 224, 3)) + output = layer(inputs) + self.assertAllClose(inputs, output) + + def test_random_elastic_transform_basic(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + inputs = np.zeros((8, 8, 1)) + inputs[3:5, 3:5, :] = 1.0 + else: + inputs = np.zeros((1, 8, 8)) + inputs[:, 3:5, 3:5] = 1.0 + + layer = layers.RandomElasticTransform(data_format=data_format) + + transformation = { + "apply_transform": np.array([True]), + "distortion_factor": np.float32(0.9109325), + "seed": 42, + } + + output = layer.transform_images(inputs, transformation) + + self.assertNotAllClose(inputs, output) + self.assertEqual(inputs.shape, output.shape) + + def test_tf_data_compatibility(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + input_data = np.random.random((2, 8, 8, 3)) + else: + input_data = np.random.random((2, 3, 8, 8)) + layer = layers.RandomElasticTransform(data_format=data_format) + + ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) + for output in ds.take(1): + print("Output shape:", output.shape) # Debugging line + output_numpy = output.numpy() + print("Output numpy shape:", output_numpy.shape) diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_erasing.py b/keras/src/layers/preprocessing/image_preprocessing/random_erasing.py new file mode 100644 index 000000000000..b593c7cbad2b --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_erasing.py @@ -0,0 +1,328 @@ +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 + BaseImagePreprocessingLayer, +) +from keras.src.random import SeedGenerator + + +@keras_export("keras.layers.RandomErasing") +class RandomErasing(BaseImagePreprocessingLayer): + """Random Erasing data augmentation technique. + + Random Erasing is a data augmentation method where random patches of + an image are erased (replaced by a constant value or noise) + during training to improve generalization. + + **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline + (independently of which backend you're using). + + References: + - [Random Erasing paper](https://arxiv.org/abs/1708.04896). + + Args: + factor: A single float or a tuple of two floats. + `factor` controls the probability of applying the transformation. + - `factor=0.0` ensures no erasing is applied. + - `factor=1.0` means erasing is always applied. + - If a tuple `(min, max)` is provided, a probability value + is sampled between `min` and `max` for each image. + - If a single float is provided, a probability is sampled + between `0.0` and the given float. + Default is 1.0. + scale: A tuple of two floats representing the aspect ratio range of + the erased patch. This defines the width-to-height ratio of + the patch to be erased. It can help control the rw shape of + the erased region. Default is (0.02, 0.33). + fill_value: A value to fill the erased region with. This can be set to + a constant value or `None` to sample a random value + from a normal distribution. Default is `None`. + value_range: the range of values the incoming images will have. + Represented as a two-number tuple written `[low, high]`. This is + typically either `[0, 1]` or `[0, 255]` depending on how your + preprocessing pipeline is set up. + seed: Integer. Used to create a random seed. + """ + + _USE_BASE_FACTOR = False + _FACTOR_BOUNDS = (0, 1) + + def __init__( + self, + factor=1.0, + scale=(0.02, 0.33), + fill_value=None, + value_range=(0, 255), + seed=None, + data_format=None, + **kwargs, + ): + super().__init__(data_format=data_format, **kwargs) + self._set_factor(factor) + self.scale = self._set_factor_by_name(scale, "scale") + self.fill_value = fill_value + self.value_range = value_range + self.seed = seed + self.generator = SeedGenerator(seed) + + if self.data_format == "channels_first": + self.height_axis = -2 + self.width_axis = -1 + self.channel_axis = -3 + else: + self.height_axis = -3 + self.width_axis = -2 + self.channel_axis = -1 + + def _set_factor_by_name(self, factor, name): + error_msg = ( + f"The `{name}` argument should be a number " + "(or a list of two numbers) " + "in the range " + f"[{self._FACTOR_BOUNDS[0]}, {self._FACTOR_BOUNDS[1]}]. " + f"Received: factor={factor}" + ) + if isinstance(factor, (tuple, list)): + if len(factor) != 2: + raise ValueError(error_msg) + if ( + factor[0] > self._FACTOR_BOUNDS[1] + or factor[1] < self._FACTOR_BOUNDS[0] + ): + raise ValueError(error_msg) + lower, upper = sorted(factor) + elif isinstance(factor, (int, float)): + if ( + factor < self._FACTOR_BOUNDS[0] + or factor > self._FACTOR_BOUNDS[1] + ): + raise ValueError(error_msg) + factor = abs(factor) + lower, upper = [max(-factor, self._FACTOR_BOUNDS[0]), factor] + else: + raise ValueError(error_msg) + return lower, upper + + def _compute_crop_bounds(self, batch_size, image_length, crop_ratio, seed): + crop_length = self.backend.cast( + crop_ratio * image_length, dtype=self.compute_dtype + ) + + start_pos = self.backend.random.uniform( + shape=[batch_size], + minval=0, + maxval=1, + dtype=self.compute_dtype, + seed=seed, + ) * (image_length - crop_length) + + end_pos = start_pos + crop_length + + return start_pos, end_pos + + def _generate_batch_mask(self, images_shape, box_corners): + def _generate_grid_xy(image_height, image_width): + grid_y, grid_x = self.backend.numpy.meshgrid( + self.backend.numpy.arange( + image_height, dtype=self.compute_dtype + ), + self.backend.numpy.arange( + image_width, dtype=self.compute_dtype + ), + indexing="ij", + ) + if self.data_format == "channels_last": + grid_y = self.backend.cast( + grid_y[None, :, :, None], dtype=self.compute_dtype + ) + grid_x = self.backend.cast( + grid_x[None, :, :, None], dtype=self.compute_dtype + ) + else: + grid_y = self.backend.cast( + grid_y[None, None, :, :], dtype=self.compute_dtype + ) + grid_x = self.backend.cast( + grid_x[None, None, :, :], dtype=self.compute_dtype + ) + return grid_x, grid_y + + image_height, image_width = ( + images_shape[self.height_axis], + images_shape[self.width_axis], + ) + grid_x, grid_y = _generate_grid_xy(image_height, image_width) + + x0, x1, y0, y1 = box_corners + + x0 = x0[:, None, None, None] + y0 = y0[:, None, None, None] + x1 = x1[:, None, None, None] + y1 = y1[:, None, None, None] + + batch_masks = ( + (grid_x >= x0) & (grid_x < x1) & (grid_y >= y0) & (grid_y < y1) + ) + batch_masks = self.backend.numpy.repeat( + batch_masks, images_shape[self.channel_axis], axis=self.channel_axis + ) + + return batch_masks + + def _get_fill_value(self, images, images_shape, seed): + fill_value = self.fill_value + if fill_value is None: + fill_value = ( + self.backend.random.normal( + images_shape, + dtype=self.compute_dtype, + seed=seed, + ) + * self.value_range[1] + ) + else: + error_msg = ( + "The `fill_value` argument should be a number " + "(or a list of three numbers) " + ) + if isinstance(fill_value, (tuple, list)): + if len(fill_value) != 3: + raise ValueError(error_msg) + fill_value = self.backend.numpy.full_like( + images, fill_value, dtype=self.compute_dtype + ) + elif isinstance(fill_value, (int, float)): + fill_value = ( + self.backend.numpy.ones( + images_shape, dtype=self.compute_dtype + ) + * fill_value + ) + else: + raise ValueError(error_msg) + fill_value = self.backend.numpy.clip( + fill_value, self.value_range[0], self.value_range[1] + ) + return fill_value + + def get_random_transformation(self, data, training=True, seed=None): + if not training: + return None + + if isinstance(data, dict): + images = data["images"] + else: + images = data + + images_shape = self.backend.shape(images) + rank = len(images_shape) + if rank == 3: + batch_size = 1 + elif rank == 4: + batch_size = images_shape[0] + else: + raise ValueError( + "Expected the input image to be rank 3 or 4. Received " + f"inputs.shape={images_shape}" + ) + + image_height = images_shape[self.height_axis] + image_width = images_shape[self.width_axis] + + seed = seed or self._get_seed_generator(self.backend._backend) + + mix_weight = self.backend.random.uniform( + shape=(batch_size, 2), + minval=self.scale[0], + maxval=self.scale[1], + dtype=self.compute_dtype, + seed=seed, + ) + + mix_weight = self.backend.numpy.sqrt(mix_weight) + + x0, x1 = self._compute_crop_bounds( + batch_size, image_width, mix_weight[:, 0], seed + ) + y0, y1 = self._compute_crop_bounds( + batch_size, image_height, mix_weight[:, 1], seed + ) + + batch_masks = self._generate_batch_mask( + images_shape, + (x0, x1, y0, y1), + ) + + erase_probability = self.backend.random.uniform( + shape=(batch_size,), + minval=self.factor[0], + maxval=self.factor[1], + seed=seed, + ) + + random_threshold = self.backend.random.uniform( + shape=(batch_size,), + minval=0.0, + maxval=1.0, + seed=seed, + ) + apply_erasing = random_threshold < erase_probability + + fill_value = self._get_fill_value(images, images_shape, seed) + + return { + "apply_erasing": apply_erasing, + "batch_masks": batch_masks, + "fill_value": fill_value, + } + + def transform_images(self, images, transformation=None, training=True): + if training: + images = self.backend.cast(images, self.compute_dtype) + batch_masks = transformation["batch_masks"] + apply_erasing = transformation["apply_erasing"] + fill_value = transformation["fill_value"] + + erased_images = self.backend.numpy.where( + batch_masks, + fill_value, + images, + ) + + images = self.backend.numpy.where( + apply_erasing[:, None, None, None], + erased_images, + images, + ) + + images = self.backend.cast(images, self.compute_dtype) + return images + + def transform_labels(self, labels, transformation, training=True): + return labels + + def transform_bounding_boxes( + self, + bounding_boxes, + transformation, + training=True, + ): + return bounding_boxes + + def transform_segmentation_masks( + self, segmentation_masks, transformation, training=True + ): + return segmentation_masks + + def compute_output_shape(self, input_shape): + return input_shape + + def get_config(self): + config = { + "factor": self.factor, + "scale": self.scale, + "fill_value": self.fill_value, + "value_range": self.value_range, + "seed": self.seed, + } + base_config = super().get_config() + return {**base_config, **config} diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_erasing_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_erasing_test.py new file mode 100644 index 000000000000..1db6ae654eaa --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_erasing_test.py @@ -0,0 +1,91 @@ +import numpy as np +import pytest +from tensorflow import data as tf_data + +from keras.src import backend +from keras.src import layers +from keras.src import testing + + +class RandomErasingTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_layer(self): + self.run_layer_test( + layers.RandomErasing, + init_kwargs={ + "factor": 1.0, + "scale": 0.5, + "fill_value": 0, + "value_range": (0, 255), + "seed": 1, + }, + input_shape=(8, 3, 4, 3), + supports_masking=False, + expected_output_shape=(8, 3, 4, 3), + ) + + def test_random_erasing_inference(self): + seed = 3481 + layer = layers.RandomErasing() + + np.random.seed(seed) + inputs = np.random.randint(0, 255, size=(224, 224, 3)) + output = layer(inputs, training=False) + self.assertAllClose(inputs, output) + + def test_random_erasing_no_op(self): + seed = 3481 + layer = layers.RandomErasing(factor=0) + + np.random.seed(seed) + inputs = np.random.randint(0, 255, size=(224, 224, 3)) + output = layer(inputs) + self.assertAllClose(inputs, output) + + layer = layers.RandomErasing(scale=0) + + np.random.seed(seed) + inputs = np.random.randint(0, 255, size=(224, 224, 3)) + output = layer(inputs) + self.assertAllClose(inputs, output) + + def test_random_erasing_basic(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + inputs = np.ones((2, 2, 1)) + expected_output = np.array([[[[0.0], [1.0]], [[1.0], [1.0]]]]) + + else: + inputs = np.ones((1, 2, 2)) + + expected_output = np.array( + [[[[0.0, 0.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]] + ) + + layer = layers.RandomErasing(data_format=data_format) + + transformation = { + "apply_erasing": np.asarray([True]), + "batch_masks": np.asarray( + [[[[True], [False]], [[False], [False]]]] + ), + "fill_value": 0, + } + + output = layer.transform_images(inputs, transformation) + + print(output) + + self.assertAllClose(expected_output, output) + + def test_tf_data_compatibility(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + input_data = np.random.random((2, 8, 8, 3)) + else: + input_data = np.random.random((2, 3, 8, 8)) + layer = layers.RandomErasing(data_format=data_format) + + ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) + for output in ds.take(1): + output.numpy() diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_flip.py b/keras/src/layers/preprocessing/image_preprocessing/random_flip.py new file mode 100644 index 000000000000..553b2a48e0b9 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_flip.py @@ -0,0 +1,236 @@ +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 + BaseImagePreprocessingLayer, +) +from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import ( # noqa: E501 + clip_to_image_size, +) +from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import ( # noqa: E501 + convert_format, +) +from keras.src.random.seed_generator import SeedGenerator +from keras.src.utils import backend_utils + +HORIZONTAL = "horizontal" +VERTICAL = "vertical" +HORIZONTAL_AND_VERTICAL = "horizontal_and_vertical" + + +@keras_export("keras.layers.RandomFlip") +class RandomFlip(BaseImagePreprocessingLayer): + """A preprocessing layer which randomly flips images during training. + + This layer will flip the images horizontally and or vertically based on the + `mode` attribute. During inference time, the output will be identical to + input. Call the layer with `training=True` to flip the input. + Input pixel values can be of any range (e.g. `[0., 1.)` or `[0, 255]`) and + of integer or floating point dtype. + By default, the layer will output floats. + + **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline + (independently of which backend you're using). + + Input shape: + 3D (unbatched) or 4D (batched) tensor with shape: + `(..., height, width, channels)`, in `"channels_last"` format. + + Output shape: + 3D (unbatched) or 4D (batched) tensor with shape: + `(..., height, width, channels)`, in `"channels_last"` format. + + Args: + mode: String indicating which flip mode to use. Can be `"horizontal"`, + `"vertical"`, or `"horizontal_and_vertical"`. `"horizontal"` is a + left-right flip and `"vertical"` is a top-bottom flip. Defaults to + `"horizontal_and_vertical"` + seed: Integer. Used to create a random seed. + **kwargs: Base layer keyword arguments, such as + `name` and `dtype`. + """ + + _USE_BASE_FACTOR = False + + def __init__( + self, + mode=HORIZONTAL_AND_VERTICAL, + seed=None, + data_format=None, + **kwargs, + ): + super().__init__(data_format=data_format, **kwargs) + self.seed = seed + self.generator = SeedGenerator(seed) + self.mode = mode + self._convert_input_args = False + self._allow_non_tensor_positional_args = True + + def get_random_transformation(self, data, training=True, seed=None): + if not training: + return None + + if isinstance(data, dict): + images = data["images"] + else: + images = data + shape = self.backend.core.shape(images) + if len(shape) == 3: + flips_shape = (1, 1, 1) + else: + flips_shape = (shape[0], 1, 1, 1) + + if seed is None: + seed = self._get_seed_generator(self.backend._backend) + + flips = self.backend.numpy.less_equal( + self.backend.random.uniform(shape=flips_shape, seed=seed), 0.5 + ) + return {"flips": flips, "input_shape": shape} + + def transform_images(self, images, transformation, training=True): + images = self.backend.cast(images, self.compute_dtype) + if training: + return self._flip_inputs(images, transformation) + return images + + def transform_labels(self, labels, transformation, training=True): + return labels + + def transform_bounding_boxes( + self, + bounding_boxes, + transformation, + training=True, + ): + def _flip_boxes_horizontal(boxes): + x1, x2, x3, x4 = self.backend.numpy.split(boxes, 4, axis=-1) + outputs = self.backend.numpy.concatenate( + [1 - x3, x2, 1 - x1, x4], axis=-1 + ) + return outputs + + def _flip_boxes_vertical(boxes): + x1, x2, x3, x4 = self.backend.numpy.split(boxes, 4, axis=-1) + outputs = self.backend.numpy.concatenate( + [x1, 1 - x4, x3, 1 - x2], axis=-1 + ) + return outputs + + def _transform_xyxy(boxes, box_flips): + bboxes = boxes["boxes"] + if self.mode in {HORIZONTAL, HORIZONTAL_AND_VERTICAL}: + bboxes = self.backend.numpy.where( + box_flips, + _flip_boxes_horizontal(bboxes), + bboxes, + ) + if self.mode in {VERTICAL, HORIZONTAL_AND_VERTICAL}: + bboxes = self.backend.numpy.where( + box_flips, + _flip_boxes_vertical(bboxes), + bboxes, + ) + return bboxes + + if training: + if backend_utils.in_tf_graph(): + self.backend.set_backend("tensorflow") + + flips = self.backend.numpy.squeeze(transformation["flips"], axis=-1) + + if self.data_format == "channels_first": + height_axis = -2 + width_axis = -1 + else: + height_axis = -3 + width_axis = -2 + + input_height, input_width = ( + transformation["input_shape"][height_axis], + transformation["input_shape"][width_axis], + ) + + bounding_boxes = convert_format( + bounding_boxes, + source=self.bounding_box_format, + target="rel_xyxy", + height=input_height, + width=input_width, + ) + + bounding_boxes["boxes"] = _transform_xyxy(bounding_boxes, flips) + + bounding_boxes = clip_to_image_size( + bounding_boxes=bounding_boxes, + height=input_height, + width=input_width, + bounding_box_format="xyxy", + ) + + bounding_boxes = convert_format( + bounding_boxes, + source="rel_xyxy", + target=self.bounding_box_format, + height=input_height, + width=input_width, + ) + + self.backend.reset() + + return bounding_boxes + + def transform_segmentation_masks( + self, segmentation_masks, transformation, training=True + ): + return self.transform_images( + segmentation_masks, transformation, training=training + ) + + def _flip_inputs(self, inputs, transformation): + if transformation is None: + return inputs + + flips = transformation["flips"] + inputs_shape = self.backend.shape(inputs) + unbatched = len(inputs_shape) == 3 + if unbatched: + inputs = self.backend.numpy.expand_dims(inputs, axis=0) + + flipped_outputs = inputs + if self.data_format == "channels_last": + horizontal_axis = -2 + vertical_axis = -3 + else: + horizontal_axis = -1 + vertical_axis = -2 + + if self.mode == HORIZONTAL or self.mode == HORIZONTAL_AND_VERTICAL: + flipped_outputs = self.backend.numpy.where( + flips, + self.backend.numpy.flip(flipped_outputs, axis=horizontal_axis), + flipped_outputs, + ) + if self.mode == VERTICAL or self.mode == HORIZONTAL_AND_VERTICAL: + flipped_outputs = self.backend.numpy.where( + flips, + self.backend.numpy.flip(flipped_outputs, axis=vertical_axis), + flipped_outputs, + ) + if unbatched: + flipped_outputs = self.backend.numpy.squeeze( + flipped_outputs, axis=0 + ) + return flipped_outputs + + def compute_output_shape(self, input_shape): + return input_shape + + def get_config(self): + config = super().get_config() + config.update( + { + "seed": self.seed, + "mode": self.mode, + "data_format": self.data_format, + } + ) + return config diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_flip_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_flip_test.py new file mode 100644 index 000000000000..c169ca754419 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_flip_test.py @@ -0,0 +1,285 @@ +import unittest.mock + +import numpy as np +from absl.testing import parameterized +from tensorflow import data as tf_data + +from keras.src import backend +from keras.src import layers +from keras.src import testing +from keras.src import utils + + +class MockedRandomFlip(layers.RandomFlip): + def call(self, inputs, training=True): + unbatched = len(inputs.shape) == 3 + batch_size = 1 if unbatched else self.backend.shape(inputs)[0] + mocked_value = self.backend.numpy.full( + (batch_size, 1, 1, 1), 0.1, dtype="float32" + ) + with unittest.mock.patch.object( + self.backend.random, + "uniform", + return_value=mocked_value, + ): + out = super().call(inputs, training=training) + return out + + +class RandomFlipTest(testing.TestCase): + @parameterized.named_parameters( + ("random_flip_horizontal", "horizontal"), + ("random_flip_vertical", "vertical"), + ("random_flip_both", "horizontal_and_vertical"), + ) + def test_random_flip(self, mode): + run_training_check = False if backend.backend() == "numpy" else True + self.run_layer_test( + layers.RandomFlip, + init_kwargs={ + "mode": mode, + }, + input_shape=(2, 3, 4), + expected_output_shape=(2, 3, 4), + supports_masking=False, + run_training_check=run_training_check, + ) + + def test_random_flip_horizontal(self): + run_training_check = False if backend.backend() == "numpy" else True + utils.set_random_seed(0) + # Test 3D input: shape (1*2*3) + self.run_layer_test( + MockedRandomFlip, + init_kwargs={ + "mode": "horizontal", + "data_format": "channels_last", + "seed": 42, + }, + input_data=np.asarray([[[2, 3, 4], [5, 6, 7]]]), + expected_output=backend.convert_to_tensor([[[5, 6, 7], [2, 3, 4]]]), + supports_masking=False, + run_training_check=run_training_check, + ) + # Test 4D input: shape (2*1*2*3) + self.run_layer_test( + MockedRandomFlip, + init_kwargs={ + "mode": "horizontal", + "data_format": "channels_last", + "seed": 42, + }, + input_data=np.asarray( + [ + [[[2, 3, 4], [5, 6, 7]]], + [[[2, 3, 4], [5, 6, 7]]], + ] + ), + expected_output=backend.convert_to_tensor( + [ + [[[5, 6, 7], [2, 3, 4]]], + [[[5, 6, 7], [2, 3, 4]]], + ] + ), + supports_masking=False, + run_training_check=run_training_check, + ) + + def test_random_flip_vertical(self): + run_training_check = False if backend.backend() == "numpy" else True + utils.set_random_seed(0) + # Test 3D input: shape (2*1*3) + self.run_layer_test( + MockedRandomFlip, + init_kwargs={ + "mode": "vertical", + "data_format": "channels_last", + "seed": 42, + }, + input_data=np.asarray([[[2, 3, 4]], [[5, 6, 7]]]), + expected_output=backend.convert_to_tensor( + [[[5, 6, 7]], [[2, 3, 4]]] + ), + supports_masking=False, + run_training_check=run_training_check, + ) + # Test 4D input: shape (2*2*1*3) + self.run_layer_test( + MockedRandomFlip, + init_kwargs={ + "mode": "vertical", + "data_format": "channels_last", + "seed": 42, + }, + input_data=np.asarray( + [ + [ + [[2, 3, 4]], + [[5, 6, 7]], + ], + [ + [[2, 3, 4]], + [[5, 6, 7]], + ], + ] + ), + expected_output=backend.convert_to_tensor( + [ + [[[5, 6, 7]], [[2, 3, 4]]], + [[[5, 6, 7]], [[2, 3, 4]]], + ] + ), + supports_masking=False, + run_training_check=run_training_check, + ) + + def test_tf_data_compatibility(self): + # Test 3D input: shape (2, 1, 3) + layer = layers.RandomFlip( + "vertical", data_format="channels_last", seed=42 + ) + input_data = np.array([[[2, 3, 4]], [[5, 6, 7]]]) + expected_output = np.array([[[5, 6, 7]], [[2, 3, 4]]]) + ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) + output = next(iter(ds)).numpy() + self.assertAllClose(output, expected_output) + # Test 4D input: shape (2, 2, 1, 3) + layer = layers.RandomFlip( + "vertical", data_format="channels_last", seed=42 + ) + input_data = np.array( + [ + [ + [[2, 3, 4]], + [[5, 6, 7]], + ], + [ + [[2, 3, 4]], + [[5, 6, 7]], + ], + ] + ) + expected_output = np.array( + [ + [[[5, 6, 7]], [[2, 3, 4]]], + [[[5, 6, 7]], [[2, 3, 4]]], + ] + ) + ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) + output = next(iter(ds)).numpy() + self.assertAllClose(output, expected_output) + + @parameterized.named_parameters( + ( + "with_horizontal", + "horizontal", + [[4, 1, 6, 3], [0, 4, 2, 6]], + ), + ( + "with_vertical", + "vertical", + [[2, 7, 4, 9], [6, 4, 8, 6]], + ), + ( + "with_horizontal_and_vertical", + "horizontal_and_vertical", + [[4, 7, 6, 9], [0, 4, 2, 6]], + ), + ) + def test_random_flip_bounding_boxes(self, mode, expected_boxes): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + image_shape = (10, 8, 3) + else: + image_shape = (3, 10, 8) + input_image = np.random.random(image_shape) + bounding_boxes = { + "boxes": np.array( + [ + [2, 1, 4, 3], + [6, 4, 8, 6], + ] + ), + "labels": np.array([[1, 2]]), + } + input_data = {"images": input_image, "bounding_boxes": bounding_boxes} + random_flip_layer = layers.RandomFlip( + mode, + data_format=data_format, + seed=42, + bounding_box_format="xyxy", + ) + + transformation = { + "flips": np.asarray([[True]]), + "input_shape": input_image.shape, + } + output = random_flip_layer.transform_bounding_boxes( + input_data["bounding_boxes"], + transformation=transformation, + training=True, + ) + + self.assertAllClose(output["boxes"], expected_boxes) + + @parameterized.named_parameters( + ( + "with_horizontal", + "horizontal", + [[4, 1, 6, 3], [0, 4, 2, 6]], + ), + ( + "with_vertical", + "vertical", + [[2, 7, 4, 9], [6, 4, 8, 6]], + ), + ( + "with_horizontal_and_vertical", + "horizontal_and_vertical", + [[4, 7, 6, 9], [0, 4, 2, 6]], + ), + ) + def test_random_flip_tf_data_bounding_boxes(self, mode, expected_boxes): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + image_shape = (1, 10, 8, 3) + else: + image_shape = (1, 3, 10, 8) + input_image = np.random.random(image_shape) + bounding_boxes = { + "boxes": np.array( + [ + [ + [2, 1, 4, 3], + [6, 4, 8, 6], + ] + ] + ), + "labels": np.array([[1, 2]]), + } + + input_data = {"images": input_image, "bounding_boxes": bounding_boxes} + + ds = tf_data.Dataset.from_tensor_slices(input_data) + random_flip_layer = layers.RandomFlip( + mode, + data_format=data_format, + seed=42, + bounding_box_format="xyxy", + ) + + transformation = { + "flips": np.asarray([[True]]), + "input_shape": input_image.shape, + } + ds = ds.map( + lambda x: random_flip_layer.transform_bounding_boxes( + x["bounding_boxes"], + transformation=transformation, + training=True, + ) + ) + + output = next(iter(ds)) + expected_boxes = np.array(expected_boxes) + self.assertAllClose(output["boxes"], expected_boxes) diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_gaussian_blur.py b/keras/src/layers/preprocessing/image_preprocessing/random_gaussian_blur.py new file mode 100644 index 000000000000..d5d47039d8f7 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_gaussian_blur.py @@ -0,0 +1,220 @@ +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 + BaseImagePreprocessingLayer, +) +from keras.src.random import SeedGenerator + + +@keras_export("keras.layers.RandomGaussianBlur") +class RandomGaussianBlur(BaseImagePreprocessingLayer): + """Applies random Gaussian blur to images for data augmentation. + + This layer performs a Gaussian blur operation on input images with a + randomly selected degree of blurring, controlled by the `factor` and + `sigma` arguments. + + **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline + (independently of which backend you're using). + + Args: + factor: A single float or a tuple of two floats. + `factor` controls the extent to which the image hue is impacted. + `factor=0.0` makes this layer perform a no-op operation, + while a value of `1.0` performs the most aggressive + blurring available. If a tuple is used, a `factor` is + sampled between the two values for every image augmented. If a + single float is used, a value between `0.0` and the passed float is + sampled. Default is 1.0. + kernel_size: Integer. Size of the Gaussian kernel used for blurring. + Must be an odd integer. Default is 3. + sigma: Float or tuple of two floats. Standard deviation of the Gaussian + kernel. Controls the intensity of the blur. If a tuple is provided, + a value is sampled between the two for each image. Default is 1.0. + value_range: the range of values the incoming images will have. + Represented as a two-number tuple written `[low, high]`. This is + typically either `[0, 1]` or `[0, 255]` depending on how your + preprocessing pipeline is set up. + seed: Integer. Used to create a random seed. + """ + + _USE_BASE_FACTOR = False + _FACTOR_BOUNDS = (0, 1) + + def __init__( + self, + factor=1.0, + kernel_size=3, + sigma=1.0, + value_range=(0, 255), + data_format=None, + seed=None, + **kwargs, + ): + super().__init__(data_format=data_format, **kwargs) + self._set_factor(factor) + self.kernel_size = self._set_kernel_size(kernel_size, "kernel_size") + self.sigma = self._set_factor_by_name(sigma, "sigma") + self.value_range = value_range + self.seed = seed + self.generator = SeedGenerator(seed) + + def _set_kernel_size(self, factor, name): + error_msg = f"{name} must be an odd number. Received: {name}={factor}" + if isinstance(factor, (tuple, list)): + if len(factor) != 2: + error_msg = ( + f"The `{name}` argument should be a number " + "(or a list of two numbers) " + f"Received: {name}={factor}" + ) + raise ValueError(error_msg) + if (factor[0] % 2 == 0) or (factor[1] % 2 == 0): + raise ValueError(error_msg) + lower, upper = factor + elif isinstance(factor, (int, float)): + if factor % 2 == 0: + raise ValueError(error_msg) + lower, upper = factor, factor + else: + raise ValueError(error_msg) + + return lower, upper + + def _set_factor_by_name(self, factor, name): + error_msg = ( + f"The `{name}` argument should be a number " + "(or a list of two numbers) " + "in the range " + f"[{self._FACTOR_BOUNDS[0]}, {self._FACTOR_BOUNDS[1]}]. " + f"Received: factor={factor}" + ) + if isinstance(factor, (tuple, list)): + if len(factor) != 2: + raise ValueError(error_msg) + if ( + factor[0] > self._FACTOR_BOUNDS[1] + or factor[1] < self._FACTOR_BOUNDS[0] + ): + raise ValueError(error_msg) + lower, upper = sorted(factor) + elif isinstance(factor, (int, float)): + if ( + factor < self._FACTOR_BOUNDS[0] + or factor > self._FACTOR_BOUNDS[1] + ): + raise ValueError(error_msg) + factor = abs(factor) + lower, upper = [max(-factor, self._FACTOR_BOUNDS[0]), factor] + else: + raise ValueError(error_msg) + return lower, upper + + def get_random_transformation(self, data, training=True, seed=None): + if not training: + return None + + if isinstance(data, dict): + images = data["images"] + else: + images = data + + images_shape = self.backend.shape(images) + rank = len(images_shape) + if rank == 3: + batch_size = 1 + elif rank == 4: + batch_size = images_shape[0] + else: + raise ValueError( + "Expected the input image to be rank 3 or 4. Received " + f"inputs.shape={images_shape}" + ) + + seed = seed or self._get_seed_generator(self.backend._backend) + + blur_probability = self.backend.random.uniform( + shape=(batch_size,), + minval=self.factor[0], + maxval=self.factor[1], + seed=seed, + ) + + random_threshold = self.backend.random.uniform( + shape=(batch_size,), + minval=0.0, + maxval=1.0, + seed=seed, + ) + should_apply_blur = random_threshold < blur_probability + + blur_factor = ( + self.backend.random.uniform( + shape=(2,), + minval=self.sigma[0], + maxval=self.sigma[1], + seed=seed, + dtype=self.compute_dtype, + ) + + 1e-6 + ) + + return { + "should_apply_blur": should_apply_blur, + "blur_factor": blur_factor, + } + + def transform_images(self, images, transformation=None, training=True): + images = self.backend.cast(images, self.compute_dtype) + if training and transformation is not None: + blur_factor = transformation["blur_factor"] + should_apply_blur = transformation["should_apply_blur"] + + blur_images = self.backend.image.gaussian_blur( + images, + kernel_size=self.kernel_size, + sigma=blur_factor, + data_format=self.data_format, + ) + + images = self.backend.numpy.where( + should_apply_blur[:, None, None, None], + blur_images, + images, + ) + + images = self.backend.numpy.clip( + images, self.value_range[0], self.value_range[1] + ) + + images = self.backend.cast(images, dtype=self.compute_dtype) + + return images + + def transform_labels(self, labels, transformation, training=True): + return labels + + def transform_segmentation_masks( + self, segmentation_masks, transformation, training=True + ): + return segmentation_masks + + def transform_bounding_boxes( + self, bounding_boxes, transformation, training=True + ): + return bounding_boxes + + def compute_output_shape(self, input_shape): + return input_shape + + def get_config(self): + config = super().get_config() + config.update( + { + "factor": self.factor, + "kernel_size": self.kernel_size, + "sigma": self.sigma, + "value_range": self.value_range, + "seed": self.seed, + } + ) + return config diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_gaussian_blur_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_gaussian_blur_test.py new file mode 100644 index 000000000000..7b69d87d412a --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_gaussian_blur_test.py @@ -0,0 +1,91 @@ +import numpy as np +import pytest +from tensorflow import data as tf_data + +from keras.src import backend +from keras.src import layers +from keras.src import testing +from keras.src.backend import convert_to_tensor + + +class RandomGaussianBlurTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_layer(self): + self.run_layer_test( + layers.RandomGaussianBlur, + init_kwargs={ + "factor": 1.0, + "kernel_size": 3, + "sigma": 0, + "value_range": (0, 255), + "seed": 1, + }, + input_shape=(8, 3, 4, 3), + supports_masking=False, + expected_output_shape=(8, 3, 4, 3), + ) + + def test_random_erasing_inference(self): + seed = 3481 + layer = layers.RandomGaussianBlur() + + np.random.seed(seed) + inputs = np.random.randint(0, 255, size=(224, 224, 3)) + output = layer(inputs, training=False) + self.assertAllClose(inputs, output) + + def test_random_erasing_no_op(self): + seed = 3481 + layer = layers.RandomGaussianBlur(factor=0) + + np.random.seed(seed) + inputs = np.random.randint(0, 255, size=(224, 224, 3)) + output = layer(inputs) + self.assertAllClose(inputs, output) + + def test_random_erasing_basic(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + inputs = np.ones((1, 2, 2, 3)) + expected_output = np.asarray( + [ + [ + [[0.7273, 0.7273, 0.7273], [0.7273, 0.7273, 0.7273]], + [[0.7273, 0.7273, 0.7273], [0.7273, 0.7273, 0.7273]], + ] + ] + ) + + else: + inputs = np.ones((1, 3, 2, 2)) + expected_output = np.asarray( + [ + [ + [[0.7273, 0.7273], [0.7273, 0.7273]], + [[0.7273, 0.7273], [0.7273, 0.7273]], + [[0.7273, 0.7273], [0.7273, 0.7273]], + ] + ] + ) + + layer = layers.RandomGaussianBlur(data_format=data_format) + + transformation = { + "blur_factor": convert_to_tensor([0.3732, 0.8654]), + "should_apply_blur": convert_to_tensor([True]), + } + output = layer.transform_images(inputs, transformation) + + self.assertAllClose(expected_output, output, atol=1e-4, rtol=1e-4) + + def test_tf_data_compatibility(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + input_data = np.random.random((2, 8, 8, 3)) + else: + input_data = np.random.random((2, 3, 8, 8)) + layer = layers.RandomGaussianBlur(data_format=data_format) + + ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) + for output in ds.take(1): + output.numpy() diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py b/keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py new file mode 100644 index 000000000000..238f43f3bdac --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py @@ -0,0 +1,117 @@ +from keras.src import backend +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 + BaseImagePreprocessingLayer, +) + + +@keras_export("keras.layers.RandomGrayscale") +class RandomGrayscale(BaseImagePreprocessingLayer): + """Preprocessing layer for random conversion of RGB images to grayscale. + + This layer randomly converts input images to grayscale with a specified + factor. When applied, it maintains the original number of channels + but sets all channels to the same grayscale value. This can be useful + for data augmentation and training models to be robust to color + variations. + + The conversion preserves the perceived luminance of the original color + image using standard RGB to grayscale conversion coefficients. Images + that are not selected for conversion remain unchanged. + + **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline + (independently of which backend you're using). + + Args: + factor: Float between 0 and 1, specifying the factor of + converting each image to grayscale. Defaults to 0.5. A value of + 1.0 means all images will be converted, while 0.0 means no images + will be converted. + data_format: String, one of `"channels_last"` (default) or + `"channels_first"`. The ordering of the dimensions in the inputs. + `"channels_last"` corresponds to inputs with shape + `(batch, height, width, channels)` while `"channels_first"` + corresponds to inputs with shape + `(batch, channels, height, width)`. + + Input shape: + 3D (unbatched) or 4D (batched) tensor with shape: + `(..., height, width, channels)`, in `"channels_last"` format, + or `(..., channels, height, width)`, in `"channels_first"` format. + + Output shape: + Same as input shape. The output maintains the same number of channels + as the input, even for grayscale-converted images where all channels + will have the same value. + """ + + def __init__(self, factor=0.5, data_format=None, seed=None, **kwargs): + super().__init__(**kwargs) + if factor < 0 or factor > 1: + raise ValueError( + f"`factor` should be between 0 and 1. Received: factor={factor}" + ) + self.factor = factor + self.data_format = backend.standardize_data_format(data_format) + self.seed = seed + self.generator = self.backend.random.SeedGenerator(seed) + + def get_random_transformation(self, images, training=True, seed=None): + if seed is None: + seed = self._get_seed_generator(self.backend._backend) + # Base case: Unbatched data + batch_size = 1 + if len(images.shape) == 4: + # This is a batch of images (4D input) + batch_size = self.backend.core.shape(images)[0] + + random_values = self.backend.random.uniform( + shape=(batch_size,), + minval=0, + maxval=1, + seed=seed, + ) + should_apply = self.backend.numpy.expand_dims( + random_values < self.factor, axis=[1, 2, 3] + ) + return should_apply + + def transform_images(self, images, transformation, training=True): + if training: + should_apply = ( + transformation + if transformation is not None + else self.get_random_transformation(images) + ) + + grayscale_images = self.backend.image.rgb_to_grayscale( + images, data_format=self.data_format + ) + return self.backend.numpy.where( + should_apply, grayscale_images, images + ) + return images + + def compute_output_shape(self, input_shape): + return input_shape + + def compute_output_spec(self, inputs, **kwargs): + return backend.KerasTensor( + inputs.shape, dtype=inputs.dtype, sparse=inputs.sparse + ) + + def transform_bounding_boxes(self, bounding_boxes, **kwargs): + return bounding_boxes + + def transform_labels(self, labels, transformations=None, **kwargs): + return labels + + def transform_segmentation_masks( + self, segmentation_masks, transformations=None, **kwargs + ): + return segmentation_masks + + def get_config(self): + config = super().get_config() + config.update({"factor": self.factor}) + return config diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_grayscale_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_grayscale_test.py new file mode 100644 index 000000000000..a43dfc55694a --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_grayscale_test.py @@ -0,0 +1,114 @@ +import numpy as np +import pytest +from absl.testing import parameterized +from tensorflow import data as tf_data + +from keras.src import backend +from keras.src import layers +from keras.src import ops +from keras.src import testing + + +class RandomGrayscaleTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_layer(self): + self.run_layer_test( + layers.RandomGrayscale, + init_kwargs={ + "factor": 0.5, + "data_format": "channels_last", + }, + input_shape=(1, 2, 2, 3), + supports_masking=False, + expected_output_shape=(1, 2, 2, 3), + ) + + self.run_layer_test( + layers.RandomGrayscale, + init_kwargs={ + "factor": 0.5, + "data_format": "channels_first", + }, + input_shape=(1, 3, 2, 2), + supports_masking=False, + expected_output_shape=(1, 3, 2, 2), + ) + + @parameterized.named_parameters( + ("channels_last", "channels_last"), ("channels_first", "channels_first") + ) + def test_grayscale_conversion(self, data_format): + if data_format == "channels_last": + xs = np.random.uniform(0, 255, size=(2, 4, 4, 3)).astype(np.float32) + layer = layers.RandomGrayscale(factor=1.0, data_format=data_format) + transformed = ops.convert_to_numpy(layer(xs)) + self.assertEqual(transformed.shape[-1], 3) + for img in transformed: + r, g, b = img[:, :, 0], img[:, :, 1], img[:, :, 2] + self.assertTrue(np.allclose(r, g) and np.allclose(g, b)) + else: + xs = np.random.uniform(0, 255, size=(2, 3, 4, 4)).astype(np.float32) + layer = layers.RandomGrayscale(factor=1.0, data_format=data_format) + transformed = ops.convert_to_numpy(layer(xs)) + self.assertEqual(transformed.shape[1], 3) + for img in transformed: + r, g, b = img[0], img[1], img[2] + self.assertTrue(np.allclose(r, g) and np.allclose(g, b)) + + def test_invalid_factor(self): + with self.assertRaises(ValueError): + layers.RandomGrayscale(factor=-0.1) + + with self.assertRaises(ValueError): + layers.RandomGrayscale(factor=1.1) + + def test_tf_data_compatibility(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + input_data = np.random.random((2, 8, 8, 3)) * 255 + else: + input_data = np.random.random((2, 3, 8, 8)) * 255 + + layer = layers.RandomGrayscale(factor=0.5, data_format=data_format) + ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) + + for output in ds.take(1): + output_array = output.numpy() + self.assertEqual(output_array.shape, input_data.shape) + + def test_grayscale_with_single_color_image(self): + test_cases = [ + # batched inputs + (np.full((1, 4, 4, 3), 128, dtype=np.float32), "channels_last"), + (np.full((1, 3, 4, 4), 128, dtype=np.float32), "channels_first"), + # unbatched inputs + (np.full((4, 4, 3), 128, dtype=np.float32), "channels_last"), + (np.full((3, 4, 4), 128, dtype=np.float32), "channels_first"), + ] + + for xs, data_format in test_cases: + layer = layers.RandomGrayscale(factor=1.0, data_format=data_format) + transformed = ops.convert_to_numpy(layer(xs)) + + # Determine if the input was batched + is_batched = len(xs.shape) == 4 + + # If batched, select the first image from the batch for inspection. + # Otherwise, use the transformed image directly. + # `image_to_inspect` will always be a 3D tensor. + if is_batched: + image_to_inspect = transformed[0] + else: + image_to_inspect = transformed + + if data_format == "channels_last": + # image_to_inspect has shape (H, W, C), + # get the first channel [:, :, 0] + channel_data = image_to_inspect[:, :, 0] + else: # data_format == "channels_first" + # image_to_inspect has shape (C, H, W), + # get the first channel [0, :, :] + channel_data = image_to_inspect[0, :, :] + + unique_vals = np.unique(channel_data) + self.assertEqual(len(unique_vals), 1) diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_hue.py b/keras/src/layers/preprocessing/image_preprocessing/random_hue.py new file mode 100644 index 000000000000..b3a61ebfe803 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_hue.py @@ -0,0 +1,171 @@ +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 + BaseImagePreprocessingLayer, +) + + +@keras_export("keras.layers.RandomHue") +class RandomHue(BaseImagePreprocessingLayer): + """Randomly adjusts the hue on given images. + + This layer will randomly increase/reduce the hue for the input RGB + images. + + The image hue is adjusted by converting the image(s) to HSV and rotating the + hue channel (H) by delta. The image is then converted back to RGB. + + **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline + (independently of which backend you're using). + + Args: + factor: A single float or a tuple of two floats. + `factor` controls the extent to which the + image hue is impacted. `factor=0.0` makes this layer perform a + no-op operation, while a value of `1.0` performs the most aggressive + contrast adjustment available. If a tuple is used, a `factor` is + sampled between the two values for every image augmented. If a + single float is used, a value between `0.0` and the passed float is + sampled. In order to ensure the value is always the same, please + pass a tuple with two identical floats: `(0.5, 0.5)`. + value_range: the range of values the incoming images will have. + Represented as a two-number tuple written `[low, high]`. This is + typically either `[0, 1]` or `[0, 255]` depending on how your + preprocessing pipeline is set up. + seed: Integer. Used to create a random seed. + + Example: + + ```python + (images, labels), _ = keras.datasets.cifar10.load_data() + random_hue = keras.layers.RandomHue(factor=0.5, value_range=[0, 1]) + images = keras.ops.cast(images, "float32") + augmented_images_batch = random_hue(images[:8]) + ``` + """ + + _USE_BASE_FACTOR = True + _FACTOR_BOUNDS = (0, 1) + + def __init__( + self, + factor, + value_range=(0, 255), + data_format=None, + seed=None, + **kwargs, + ): + super().__init__(data_format=data_format, **kwargs) + self._set_factor(factor) + self.value_range = value_range + self.seed = seed + self.generator = self.backend.random.SeedGenerator(seed) + + def get_random_transformation(self, data, training=True, seed=None): + if isinstance(data, dict): + images = data["images"] + else: + images = data + images_shape = self.backend.shape(images) + rank = len(images_shape) + if rank == 3: + batch_size = 1 + elif rank == 4: + batch_size = images_shape[0] + else: + raise ValueError( + "Expected the input image to be rank 3 or 4. Received " + f"inputs.shape={images_shape}" + ) + + if seed is None: + seed = self._get_seed_generator(self.backend._backend) + invert = self.backend.random.uniform((batch_size,), seed=seed) + + invert = self.backend.numpy.where( + invert > 0.5, + -self.backend.numpy.ones_like(invert), + self.backend.numpy.ones_like(invert), + ) + factor = self.backend.random.uniform( + (batch_size,), + minval=self.factor[0], + maxval=self.factor[1], + seed=seed, + ) + return {"factor": invert * factor * 0.5} + + def transform_images(self, images, transformation=None, training=True): + def _apply_random_hue(images, transformation): + images = self.backend.cast(images, self.compute_dtype) + images = self._transform_value_range( + images, self.value_range, (0, 1) + ) + adjust_factors = transformation["factor"] + adjust_factors = self.backend.cast(adjust_factors, images.dtype) + adjust_factors = self.backend.numpy.expand_dims(adjust_factors, -1) + adjust_factors = self.backend.numpy.expand_dims(adjust_factors, -1) + images = self.backend.image.rgb_to_hsv( + images, data_format=self.data_format + ) + if self.data_format == "channels_first": + h_channel = images[:, 0, :, :] + adjust_factors + h_channel = self.backend.numpy.where( + h_channel > 1.0, h_channel - 1.0, h_channel + ) + h_channel = self.backend.numpy.where( + h_channel < 0.0, h_channel + 1.0, h_channel + ) + images = self.backend.numpy.stack( + [h_channel, images[:, 1, :, :], images[:, 2, :, :]], axis=1 + ) + else: + h_channel = images[..., 0] + adjust_factors + h_channel = self.backend.numpy.where( + h_channel > 1.0, h_channel - 1.0, h_channel + ) + h_channel = self.backend.numpy.where( + h_channel < 0.0, h_channel + 1.0, h_channel + ) + images = self.backend.numpy.stack( + [h_channel, images[..., 1], images[..., 2]], axis=-1 + ) + images = self.backend.image.hsv_to_rgb( + images, data_format=self.data_format + ) + images = self.backend.numpy.clip(images, 0, 1) + images = self._transform_value_range( + images, (0, 1), self.value_range + ) + images = self.backend.cast(images, self.compute_dtype) + return images + + if training: + images = _apply_random_hue(images, transformation) + return images + + def transform_labels(self, labels, transformation, training=True): + return labels + + def transform_segmentation_masks( + self, segmentation_masks, transformation, training=True + ): + return segmentation_masks + + def transform_bounding_boxes( + self, bounding_boxes, transformation, training=True + ): + return bounding_boxes + + def get_config(self): + config = super().get_config() + config.update( + { + "factor": self.factor, + "value_range": self.value_range, + "seed": self.seed, + } + ) + return config + + def compute_output_shape(self, input_shape): + return input_shape diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_hue_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_hue_test.py new file mode 100644 index 000000000000..f115612309d9 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_hue_test.py @@ -0,0 +1,83 @@ +import numpy as np +import pytest +from tensorflow import data as tf_data + +import keras +from keras.src import backend +from keras.src import layers +from keras.src import testing + + +class RandomHueTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_layer(self): + self.run_layer_test( + layers.RandomHue, + init_kwargs={ + "factor": 0.75, + "value_range": (20, 200), + "seed": 1, + }, + input_shape=(8, 3, 4, 3), + supports_masking=False, + expected_output_shape=(8, 3, 4, 3), + ) + + def test_random_hue_inference(self): + seed = 3481 + layer = layers.RandomHue(0.2, [0, 1.0]) + np.random.seed(seed) + inputs = np.random.randint(0, 255, size=(224, 224, 3)) + output = layer(inputs, training=False) + self.assertAllClose(inputs, output) + + def test_random_hue_value_range_0_to_1(self): + image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1) + + layer = layers.RandomHue(0.2, (0, 1)) + adjusted_image = layer(image) + + self.assertTrue(keras.ops.numpy.all(adjusted_image >= 0)) + self.assertTrue(keras.ops.numpy.all(adjusted_image <= 1)) + + def test_random_hue_value_range_0_to_255(self): + image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=255) + + layer = layers.RandomHue(0.2, (0, 255)) + adjusted_image = layer(image) + + self.assertTrue(keras.ops.numpy.all(adjusted_image >= 0)) + self.assertTrue(keras.ops.numpy.all(adjusted_image <= 255)) + + def test_random_hue_no_change_with_zero_factor(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + inputs = keras.random.randint((224, 224, 3), 0, 255) + else: + inputs = keras.random.randint((3, 224, 224), 0, 255) + + layer = layers.RandomHue(0, (0, 255), data_format=data_format) + output = layer(inputs, training=False) + self.assertAllClose(inputs, output, atol=1e-3, rtol=1e-5) + + def test_random_hue_randomness(self): + image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1)[:5] + + layer = layers.RandomHue(0.2, (0, 255)) + adjusted_images = layer(image) + + self.assertNotAllClose(adjusted_images, image) + + def test_tf_data_compatibility(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + input_data = np.random.random((2, 8, 8, 3)) + else: + input_data = np.random.random((2, 3, 8, 8)) + layer = layers.RandomHue( + factor=0.5, value_range=[0, 1], data_format=data_format, seed=1337 + ) + + ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) + for output in ds.take(1): + output.numpy() diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_invert.py b/keras/src/layers/preprocessing/image_preprocessing/random_invert.py new file mode 100644 index 000000000000..b180d83944c7 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_invert.py @@ -0,0 +1,129 @@ +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 + BaseImagePreprocessingLayer, +) + + +@keras_export("keras.layers.RandomInvert") +class RandomInvert(BaseImagePreprocessingLayer): + """Preprocessing layer for random inversion of image colors. + + This layer randomly inverts the colors of input images with a specified + probability range. When applied, each image has a chance of having its + colors inverted, where the pixel values are transformed to their + complementary values. Images that are not selected for inversion + remain unchanged. + + **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline + (independently of which backend you're using). + + Args: + factor: A single float or a tuple of two floats. + `factor` controls the probability of inverting the image colors. + If a tuple is provided, the value is sampled between the two values + for each image, where `factor[0]` is the minimum and `factor[1]` is + the maximum probability. If a single float is provided, a value + between `0.0` and the provided float is sampled. + Defaults to `(0, 1)`. + value_range: a tuple or a list of two elements. The first value + represents the lower bound for values in passed images, the second + represents the upper bound. Images passed to the layer should have + values within `value_range`. Defaults to `(0, 255)`. + seed: Integer. Used to create a random seed. + """ + + _USE_BASE_FACTOR = False + _FACTOR_BOUNDS = (0, 1) + + def __init__( + self, + factor=1.0, + value_range=(0, 255), + seed=None, + data_format=None, + **kwargs, + ): + super().__init__(data_format=data_format, **kwargs) + self._set_factor(factor) + self.value_range = value_range + self.seed = seed + self.generator = self.backend.random.SeedGenerator(seed) + + def get_random_transformation(self, data, training=True, seed=None): + if not training: + return None + + if isinstance(data, dict): + images = data["images"] + else: + images = data + + seed = seed or self._get_seed_generator(self.backend._backend) + + images_shape = self.backend.shape(images) + rank = len(images_shape) + if rank == 3: + batch_size = 1 + elif rank == 4: + batch_size = images_shape[0] + else: + raise ValueError( + "Expected the input image to be rank 3 or 4. Received " + f"inputs.shape={images_shape}" + ) + + invert_probability = self.backend.random.uniform( + shape=(batch_size,), + minval=self.factor[0], + maxval=self.factor[1], + seed=seed, + ) + + random_threshold = self.backend.random.uniform( + shape=(batch_size,), + minval=0, + maxval=1, + seed=seed, + ) + + apply_inversion = random_threshold < invert_probability + return {"apply_inversion": apply_inversion} + + def transform_images(self, images, transformation, training=True): + if training: + images = self.backend.cast(images, self.compute_dtype) + apply_inversion = transformation["apply_inversion"] + return self.backend.numpy.where( + apply_inversion[:, None, None, None], + self.value_range[1] - images, + images, + ) + return images + + def transform_labels(self, labels, transformation, training=True): + return labels + + def transform_bounding_boxes( + self, + bounding_boxes, + transformation, + training=True, + ): + return bounding_boxes + + def transform_segmentation_masks( + self, segmentation_masks, transformation, training=True + ): + return segmentation_masks + + def compute_output_shape(self, input_shape): + return input_shape + + def get_config(self): + config = { + "factor": self.factor, + "value_range": self.value_range, + "seed": self.seed, + } + base_config = super().get_config() + return {**base_config, **config} diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_invert_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_invert_test.py new file mode 100644 index 000000000000..0b0d186ab339 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_invert_test.py @@ -0,0 +1,68 @@ +import numpy as np +import pytest +from tensorflow import data as tf_data + +from keras.src import backend +from keras.src import layers +from keras.src import testing + + +class RandomInvertTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_layer(self): + self.run_layer_test( + layers.RandomInvert, + init_kwargs={ + "factor": 0.75, + "value_range": (20, 200), + "seed": 1, + }, + input_shape=(8, 3, 4, 3), + supports_masking=False, + expected_output_shape=(8, 3, 4, 3), + ) + + def test_random_invert_inference(self): + seed = 3481 + layer = layers.RandomInvert() + np.random.seed(seed) + inputs = np.random.randint(0, 255, size=(224, 224, 3)) + output = layer(inputs, training=False) + self.assertAllClose(inputs, output) + + def test_random_invert_no_op(self): + seed = 3481 + layer = layers.RandomInvert(factor=0) + np.random.seed(seed) + inputs = np.random.randint(0, 255, size=(224, 224, 3)) + output = layer(inputs) + self.assertAllClose(inputs, output) + + def test_random_invert_basic(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + input_data = np.random.random((1, 8, 8, 3)) + else: + input_data = np.random.random((1, 3, 8, 8)) + layer = layers.RandomInvert( + factor=(1, 1), + value_range=[0, 1], + data_format=data_format, + seed=1337, + ) + output = layer(input_data) + self.assertAllClose(1 - input_data, output) + + def test_tf_data_compatibility(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + input_data = np.random.random((2, 8, 8, 3)) + else: + input_data = np.random.random((2, 3, 8, 8)) + layer = layers.RandomInvert( + factor=0.5, value_range=[0, 1], data_format=data_format, seed=1337 + ) + + ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) + for output in ds.take(1): + output.numpy() diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_perspective.py b/keras/src/layers/preprocessing/image_preprocessing/random_perspective.py new file mode 100644 index 000000000000..9702edc7b6db --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_perspective.py @@ -0,0 +1,339 @@ +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 + BaseImagePreprocessingLayer, +) +from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import ( # noqa: E501 + clip_to_image_size, +) +from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import ( # noqa: E501 + convert_format, +) +from keras.src.random.seed_generator import SeedGenerator +from keras.src.utils import backend_utils + + +@keras_export("keras.layers.RandomPerspective") +class RandomPerspective(BaseImagePreprocessingLayer): + """A preprocessing layer that applies random perspective transformations. + + This layer distorts the perspective of input images by shifting their + corner points, simulating a 3D-like transformation. The amount of distortion + is controlled by the `factor` and `scale` parameters. + + **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline + (independently of which backend you're using). + + Args: + factor: A float or a tuple of two floats. + Represents the probability of applying the perspective + transformation to each image in the batch. + - `factor=0.0` ensures no transformation is applied. + - `factor=1.0` means the transformation is always applied. + - If a tuple `(min, max)` is provided, a probability is randomly + sampled between `min` and `max` for each image. + - If a single float is given, the probability is sampled between + `0.0` and the provided float. + Default is 1.0. + scale: A float defining the relative amount of perspective shift. + Determines how much the image corners are displaced, affecting + the intensity of the perspective effect. + interpolation: Interpolation mode. Supported values: `"nearest"`, + `"bilinear"`. + fill_value: a float represents the value to be filled outside the + boundaries when `fill_mode="constant"`. + seed: Integer. Used to create a random seed. + + """ + + _USE_BASE_FACTOR = False + _FACTOR_BOUNDS = (0, 1) + _SUPPORTED_INTERPOLATION = ("nearest", "bilinear") + + def __init__( + self, + factor=1.0, + scale=1.0, + interpolation="bilinear", + fill_value=0.0, + seed=None, + data_format=None, + **kwargs, + ): + super().__init__(data_format=data_format, **kwargs) + self._set_factor(factor) + self.scale = scale + self.fill_value = fill_value + self.interpolation = interpolation + self.seed = seed + self.generator = SeedGenerator(seed) + self.supports_jit = False + + if scale < 0.0 or scale > 1.0: + raise ValueError( + "The `scale` argument should be a number " + "in the range " + f"[0,1]. " + f"Received: scale={scale}" + ) + + if interpolation not in self._SUPPORTED_INTERPOLATION: + raise NotImplementedError( + f"Unknown `interpolation` {interpolation}. Expected of one " + f"{self._SUPPORTED_INTERPOLATION}." + ) + + if self.data_format == "channels_first": + self.height_axis = -2 + self.width_axis = -1 + self.channel_axis = -3 + else: + self.height_axis = -3 + self.width_axis = -2 + self.channel_axis = -1 + + def get_random_transformation(self, data, training=True, seed=None): + if not training: + return None + + if isinstance(data, dict): + images = data["images"] + else: + images = data + + images_shape = self.backend.shape(images) + unbatched = len(images_shape) == 3 + if unbatched: + batch_size = 1 + else: + batch_size = images_shape[0] + height, width = ( + images.shape[self.height_axis], + images.shape[self.width_axis], + ) + + seed = seed or self._get_seed_generator(self.backend._backend) + + transformation_probability = self.backend.random.uniform( + shape=(batch_size,), + minval=self.factor[0], + maxval=self.factor[1], + seed=seed, + ) + + random_threshold = self.backend.random.uniform( + shape=(batch_size,), + minval=0.0, + maxval=1.0, + seed=seed, + ) + apply_perspective = random_threshold < transformation_probability + + perspective_factor = self.backend.random.uniform( + shape=(batch_size, 4, 2), + minval=-0.5 * self.scale, + maxval=0.5 * self.scale, + seed=seed, + dtype=self.compute_dtype, + ) + + start_points = self.backend.convert_to_tensor( + [ + [ + [0.0, 0.0], + [width - 1, 0.0], + [0.0, height - 1], + [width - 1, height - 1], + ] + ], + dtype=self.compute_dtype, + ) + + start_points = self.backend.numpy.repeat( + start_points, batch_size, axis=0 + ) + end_points = start_points + start_points * perspective_factor + + return { + "apply_perspective": apply_perspective, + "start_points": start_points, + "end_points": end_points, + "input_shape": images_shape, + } + + def transform_images(self, images, transformation, training=True): + images = self.backend.cast(images, self.compute_dtype) + if training and transformation is not None: + images = self._perspective_inputs(images, transformation) + images = self.backend.cast(images, self.compute_dtype) + return images + + def _perspective_inputs(self, inputs, transformation): + if transformation is None: + return inputs + + inputs_shape = self.backend.shape(inputs) + unbatched = len(inputs_shape) == 3 + if unbatched: + inputs = self.backend.numpy.expand_dims(inputs, axis=0) + + start_points = transformation["start_points"] + end_points = transformation["end_points"] + + outputs = self.backend.image.perspective_transform( + inputs, + start_points, + end_points, + interpolation=self.interpolation, + fill_value=self.fill_value, + data_format=self.data_format, + ) + + apply_perspective = transformation["apply_perspective"] + outputs = self.backend.numpy.where( + apply_perspective[:, None, None, None], + outputs, + inputs, + ) + + if unbatched: + outputs = self.backend.numpy.squeeze(outputs, axis=0) + return outputs + + def transform_bounding_boxes( + self, + bounding_boxes, + transformation, + training=True, + ): + if training and transformation is not None: + if backend_utils.in_tf_graph(): + self.backend.set_backend("tensorflow") + + input_height, input_width = ( + transformation["input_shape"][self.height_axis], + transformation["input_shape"][self.width_axis], + ) + + bounding_boxes = convert_format( + bounding_boxes, + source=self.bounding_box_format, + target="xyxy", + height=input_height, + width=input_width, + ) + + boxes = bounding_boxes["boxes"] + x0, y0, x1, y1 = self.backend.numpy.split(boxes, 4, axis=-1) + + start_points = transformation["start_points"] + end_points = transformation["end_points"] + transform = self.backend.image.compute_homography_matrix( + start_points, end_points + ) + transform = self.backend.numpy.expand_dims(transform, axis=1) + transform = self.backend.cast(transform, dtype=self.compute_dtype) + + corners = [ + self._get_transformed_coordinates(x, y, transform) + for x, y in [(x0, y0), (x1, y1), (x0, y1), (x1, y0)] + ] + x_corners, y_corners = zip(*corners) + + xs = self.backend.numpy.stack(x_corners, axis=-1) + ys = self.backend.numpy.stack(y_corners, axis=-1) + + min_x, max_x = ( + self.backend.numpy.min(xs, axis=-1), + self.backend.numpy.max(xs, axis=-1), + ) + min_y, max_y = ( + self.backend.numpy.min(ys, axis=-1), + self.backend.numpy.max(ys, axis=-1), + ) + + min_x = self.backend.numpy.expand_dims(min_x, axis=-1) + max_x = self.backend.numpy.expand_dims(max_x, axis=-1) + min_y = self.backend.numpy.expand_dims(min_y, axis=-1) + max_y = self.backend.numpy.expand_dims(max_y, axis=-1) + + boxes = self.backend.numpy.concatenate( + [min_x, min_y, max_x, max_y], axis=-1 + ) + + apply_perspective = self.backend.core.convert_to_tensor( + transformation["apply_perspective"], dtype=boxes.dtype + ) + + bounding_boxes["boxes"] = self.backend.numpy.where( + apply_perspective[:, None, None], + boxes, + bounding_boxes["boxes"], + ) + + bounding_boxes = clip_to_image_size( + bounding_boxes=bounding_boxes, + height=input_height, + width=input_width, + bounding_box_format="xyxy", + ) + + self.backend.reset() + + return bounding_boxes + + def _get_transformed_coordinates( + self, x_coords, y_coords, transformation_matrix + ): + backend = self.backend + + batch_size = backend.shape(transformation_matrix)[0] + + homogeneous_transform = backend.numpy.concatenate( + [transformation_matrix, backend.numpy.ones((batch_size, 1, 1))], + axis=-1, + ) + homogeneous_transform = backend.numpy.reshape( + homogeneous_transform, (batch_size, 3, 3) + ) + + inverse_transform = backend.linalg.inv(homogeneous_transform) + + ones_column = backend.numpy.ones_like(x_coords) + homogeneous_coords = backend.numpy.concatenate( + [x_coords, y_coords, ones_column], axis=-1 + ) + + homogeneous_coords = backend.numpy.moveaxis(homogeneous_coords, -1, -2) + transformed_coords = backend.numpy.matmul( + inverse_transform, homogeneous_coords + ) + transformed_coords = backend.numpy.moveaxis(transformed_coords, -1, -2) + + x_transformed = transformed_coords[..., 0] / transformed_coords[..., 2] + y_transformed = transformed_coords[..., 1] / transformed_coords[..., 2] + + return x_transformed, y_transformed + + def transform_labels(self, labels, transformation, training=True): + return labels + + def transform_segmentation_masks( + self, segmentation_masks, transformation, training=True + ): + return self.transform_images( + segmentation_masks, transformation, training=training + ) + + def compute_output_shape(self, input_shape): + return input_shape + + def get_config(self): + base_config = super().get_config() + config = { + "factor": self.factor, + "scale": self.scale, + "interpolation": self.interpolation, + "fill_value": self.fill_value, + "seed": self.seed, + } + return {**base_config, **config} diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_perspective_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_perspective_test.py new file mode 100644 index 000000000000..b29c5a679132 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_perspective_test.py @@ -0,0 +1,268 @@ +import numpy as np +import pytest +from absl.testing import parameterized +from tensorflow import data as tf_data + +from keras.src import backend +from keras.src import layers +from keras.src import testing + + +class RandomPerspectiveTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_layer(self): + self.run_layer_test( + layers.RandomPerspective, + init_kwargs={ + "factor": 1.0, + "scale": 0.5, + "interpolation": "bilinear", + "fill_value": 0, + "seed": 1, + }, + input_shape=(8, 3, 4, 3), + supports_masking=False, + expected_output_shape=(8, 3, 4, 3), + ) + + def test_random_perspective_inference(self): + seed = 3481 + layer = layers.RandomPerspective() + + np.random.seed(seed) + inputs = np.random.randint(0, 255, size=(224, 224, 3)) + output = layer(inputs, training=False) + self.assertAllClose(inputs, output) + + def test_random_perspective_no_op(self): + seed = 3481 + layer = layers.RandomPerspective(factor=0) + + np.random.seed(seed) + inputs = np.random.randint(0, 255, size=(224, 224, 3)) + output = layer(inputs) + self.assertAllClose(inputs, output) + + def test_random_perspective_basic(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + inputs = np.ones((4, 4, 1)) + expected_output = np.asarray( + [ + [[1.0], [1.0], [0.0], [0.0]], + [[1.0], [1.0], [0.0], [0.0]], + [[0.0], [0.0], [0.0], [0.0]], + [[0.0], [0.0], [0.0], [0.0]], + ], + ) + + else: + inputs = np.ones((1, 4, 4)) + expected_output = np.array( + [ + [ + [1.0, 1.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ] + ] + ) + + layer = layers.RandomPerspective(data_format=data_format) + + transformation = { + "apply_perspective": np.array([True]), + "start_points": np.array( + [[[0.0, 0.0], [3.0, 0.0], [0.0, 3.0], [3.0, 3.0]]] + ), + "end_points": np.array([[[0.0, 0.0], [1, 0.0], [0.0, 1], [1, 1]]]), + "input_shape": np.array((4, 4, 1)), + } + output = layer.transform_images(inputs, transformation) + + self.assertAllClose(expected_output, output, atol=1e-4, rtol=1e-4) + + def test_tf_data_compatibility(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + input_data = np.random.random((2, 8, 8, 3)) + else: + input_data = np.random.random((2, 3, 8, 8)) + layer = layers.RandomPerspective(data_format=data_format) + + ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) + for output in ds.take(1): + output.numpy() + + @parameterized.named_parameters( + ( + "with_large_scale", + [ + [ + [0.0, 0.0], + [8.151311, 0.0], + [0.0, 12.695701], + [9.2712054, 10.524198], + ] + ], + [ + [ + [2.6490488, 1.1149256, 5.2026834, 3.6187303], + [7.5547166, 4.2492595, 8.0, 6.869391], + ] + ], + ), + ( + "with_small_scale", + [ + [ + [0.0, 0.0], + [4.151311, 0.0], + [0.0, 6.695701], + [4.2712054, 7.524198], + ] + ], + [ + [ + [1.095408, 0.7504317, 2.2761598, 2.3389952], + [3.5416048, 3.2349987, 4.920989, 5.0568376], + ] + ], + ), + ) + def test_random_perspective_bounding_boxes( + self, end_points, expected_boxes + ): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + image_shape = (10, 8, 3) + else: + image_shape = (3, 10, 8) + input_image = np.random.random(image_shape) + bounding_boxes = { + "boxes": np.array( + [ + [ + [2, 1, 4, 3], + [6, 4, 8, 6], + ] + ] + ), + "labels": np.array([[1, 2]]), + } + input_data = {"images": input_image, "bounding_boxes": bounding_boxes} + layer = layers.RandomPerspective( + # data_format=data_format, + seed=42, + bounding_box_format="xyxy", + ) + + transformation = { + "apply_perspective": np.array([True]), + "end_points": np.array(end_points), + "input_shape": np.array(image_shape), + "start_points": np.array( + [[[0.0, 0.0], [7.0, 0.0], [0.0, 9.0], [7.0, 9.0]]] + ), + } + + output = layer.transform_bounding_boxes( + input_data["bounding_boxes"], + transformation, + ) + + self.assertAllClose( + output["boxes"], expected_boxes, atol=1e-3, rtol=1e-3 + ) + + @parameterized.named_parameters( + ( + "with_large_scale", + [ + [ + [0.0, 0.0], + [8.151311, 0.0], + [0.0, 12.695701], + [9.2712054, 10.524198], + ] + ], + [ + [ + [2.6490488, 1.1149256, 5.2026834, 3.6187303], + [7.5547166, 4.2492595, 8.0, 6.869391], + ] + ], + ), + ( + "with_small_scale", + [ + [ + [0.0, 0.0], + [4.151311, 0.0], + [0.0, 6.695701], + [4.2712054, 7.524198], + ] + ], + [ + [ + [1.095408, 0.7504317, 2.2761598, 2.3389952], + [3.5416048, 3.2349987, 4.920989, 5.0568376], + ] + ], + ), + ) + def test_random_flip_tf_data_bounding_boxes( + self, end_points, expected_boxes + ): + data_format = backend.config.image_data_format() + if backend.config.image_data_format() == "channels_last": + image_shape = (1, 10, 8, 3) + else: + image_shape = (1, 3, 10, 8) + input_image = np.random.random(image_shape) + bounding_boxes = { + "boxes": np.array( + [ + [ + [ + [2, 1, 4, 3], + [6, 4, 8, 6], + ] + ] + ] + ), + "labels": np.array([[1, 2]]), + } + + input_data = {"images": input_image, "bounding_boxes": bounding_boxes} + + ds = tf_data.Dataset.from_tensor_slices(input_data) + layer = layers.RandomPerspective( + data_format=data_format, + seed=42, + bounding_box_format="xyxy", + ) + + transformation = { + "apply_perspective": np.array([True]), + "end_points": np.array(end_points), + "input_shape": np.array(image_shape), + "start_points": np.array( + [[[0.0, 0.0], [7.0, 0.0], [0.0, 9.0], [7.0, 9.0]]] + ), + } + + ds = ds.map( + lambda x: layer.transform_bounding_boxes( + x["bounding_boxes"], + transformation=transformation, + training=True, + ) + ) + + output = next(iter(ds)) + expected_boxes = np.array(expected_boxes) + self.assertAllClose( + output["boxes"], expected_boxes, atol=1e-3, rtol=1e-3 + ) diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_posterization.py b/keras/src/layers/preprocessing/image_preprocessing/random_posterization.py new file mode 100644 index 000000000000..83ae04a165ec --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_posterization.py @@ -0,0 +1,154 @@ +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 + BaseImagePreprocessingLayer, +) + + +@keras_export("keras.layers.RandomPosterization") +class RandomPosterization(BaseImagePreprocessingLayer): + """Reduces the number of bits for each color channel. + + **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline + (independently of which backend you're using). + + References: + - [AutoAugment: Learning Augmentation Policies from Data](https://arxiv.org/abs/1805.09501) + - [RandAugment: Practical automated data augmentation with a reduced search space](https://arxiv.org/abs/1909.13719) + + Args: + value_range: a tuple or a list of two elements. The first value + represents the lower bound for values in passed images, the second + represents the upper bound. Images passed to the layer should have + values within `value_range`. Defaults to `(0, 255)`. + factor: integer, the number of bits to keep for each channel. Must be a + value between 1-8. + """ + + _USE_BASE_FACTOR = False + _FACTOR_BOUNDS = (1, 8) + _MAX_FACTOR = 8 + _VALUE_RANGE_VALIDATION_ERROR = ( + "The `value_range` argument should be a list of two numbers. " + ) + + def __init__( + self, + factor, + value_range=(0, 255), + data_format=None, + seed=None, + **kwargs, + ): + super().__init__(data_format=data_format, **kwargs) + self._set_factor(factor) + self._set_value_range(value_range) + self.seed = seed + self.generator = self.backend.random.SeedGenerator(seed) + + def _set_value_range(self, value_range): + if not isinstance(value_range, (tuple, list)): + raise ValueError( + self._VALUE_RANGE_VALIDATION_ERROR + + f"Received: value_range={value_range}" + ) + if len(value_range) != 2: + raise ValueError( + self._VALUE_RANGE_VALIDATION_ERROR + + f"Received: value_range={value_range}" + ) + self.value_range = sorted(value_range) + + def get_random_transformation(self, data, training=True, seed=None): + if isinstance(data, dict): + images = data["images"] + else: + images = data + images_shape = self.backend.shape(images) + rank = len(images_shape) + if rank == 3: + batch_size = 1 + elif rank == 4: + batch_size = images_shape[0] + else: + raise ValueError( + "Expected the input image to be rank 3 or 4. Received: " + f"inputs.shape={images_shape}" + ) + + if seed is None: + seed = self._get_seed_generator(self.backend._backend) + + if self.factor[0] != self.factor[1]: + factor = self.backend.random.randint( + (batch_size,), + minval=self.factor[0], + maxval=self.factor[1], + seed=seed, + dtype="uint8", + ) + else: + factor = ( + self.backend.numpy.ones((batch_size,), dtype="uint8") + * self.factor[0] + ) + + shift_factor = self._MAX_FACTOR - factor + return {"shift_factor": shift_factor} + + def transform_images(self, images, transformation=None, training=True): + if training: + shift_factor = transformation["shift_factor"] + + shift_factor = self.backend.numpy.reshape( + shift_factor, self.backend.shape(shift_factor) + (1, 1, 1) + ) + + images = self._transform_value_range( + images, + original_range=self.value_range, + target_range=(0, 255), + dtype=self.compute_dtype, + ) + + images = self.backend.cast(images, "uint8") + images = self.backend.numpy.bitwise_left_shift( + self.backend.numpy.bitwise_right_shift(images, shift_factor), + shift_factor, + ) + images = self.backend.cast(images, self.compute_dtype) + + images = self._transform_value_range( + images, + original_range=(0, 255), + target_range=self.value_range, + dtype=self.compute_dtype, + ) + + return images + + def transform_labels(self, labels, transformation, training=True): + return labels + + def transform_segmentation_masks( + self, segmentation_masks, transformation, training=True + ): + return segmentation_masks + + def transform_bounding_boxes( + self, bounding_boxes, transformation, training=True + ): + return bounding_boxes + + def get_config(self): + config = super().get_config() + config.update( + { + "factor": self.factor, + "value_range": self.value_range, + "seed": self.seed, + } + ) + return config + + def compute_output_shape(self, input_shape): + return input_shape diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_posterization_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_posterization_test.py new file mode 100644 index 000000000000..347f82a3a962 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_posterization_test.py @@ -0,0 +1,85 @@ +import numpy as np +import pytest +from tensorflow import data as tf_data + +import keras +from keras.src import backend +from keras.src import layers +from keras.src import testing + + +class RandomPosterizationTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_layer(self): + self.run_layer_test( + layers.RandomPosterization, + init_kwargs={ + "factor": 1, + "value_range": (20, 200), + "seed": 1, + }, + input_shape=(8, 3, 4, 3), + supports_masking=False, + expected_output_shape=(8, 3, 4, 3), + ) + + def test_random_posterization_inference(self): + seed = 3481 + layer = layers.RandomPosterization(1, [0, 255]) + np.random.seed(seed) + inputs = np.random.randint(0, 255, size=(224, 224, 3)) + output = layer(inputs, training=False) + self.assertAllClose(inputs, output) + + def test_random_posterization_basic(self): + seed = 3481 + layer = layers.RandomPosterization( + 1, [0, 255], data_format="channels_last", seed=seed + ) + np.random.seed(seed) + inputs = np.asarray( + [[[128.0, 235.0, 87.0], [12.0, 1.0, 23.0], [24.0, 18.0, 121.0]]] + ) + output = layer(inputs) + expected_output = np.asarray( + [[[128.0, 128.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]] + ) + self.assertAllClose(expected_output, output) + + def test_random_posterization_value_range_0_to_1(self): + image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1) + + layer = layers.RandomPosterization(1, [0, 1.0]) + adjusted_image = layer(image) + + self.assertTrue(keras.ops.numpy.all(adjusted_image >= 0)) + self.assertTrue(keras.ops.numpy.all(adjusted_image <= 1)) + + def test_random_posterization_value_range_0_to_255(self): + image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=255) + + layer = layers.RandomPosterization(1, [0, 255]) + adjusted_image = layer(image) + + self.assertTrue(keras.ops.numpy.all(adjusted_image >= 0)) + self.assertTrue(keras.ops.numpy.all(adjusted_image <= 255)) + + def test_random_posterization_randomness(self): + image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1) + + layer = layers.RandomPosterization(1, [0, 255]) + adjusted_images = layer(image) + + self.assertNotAllClose(adjusted_images, image) + + def test_tf_data_compatibility(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + input_data = np.random.random((2, 8, 8, 3)) + else: + input_data = np.random.random((2, 3, 8, 8)) + layer = layers.RandomPosterization(1, [0, 255]) + + ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) + for output in ds.take(1): + output.numpy() diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_rotation.py b/keras/src/layers/preprocessing/image_preprocessing/random_rotation.py new file mode 100644 index 000000000000..9d36f4281cc5 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_rotation.py @@ -0,0 +1,249 @@ +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 + BaseImagePreprocessingLayer, +) +from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes import ( + converters, +) +from keras.src.random.seed_generator import SeedGenerator + + +@keras_export("keras.layers.RandomRotation") +class RandomRotation(BaseImagePreprocessingLayer): + """A preprocessing layer which randomly rotates images during training. + + This layer will apply random rotations to each image, filling empty space + according to `fill_mode`. + + By default, random rotations are only applied during training. + At inference time, the layer does nothing. If you need to apply random + rotations at inference time, pass `training=True` when calling the layer. + + Input pixel values can be of any range (e.g. `[0., 1.)` or `[0, 255]`) and + of integer or floating point dtype. + By default, the layer will output floats. + + **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline + (independently of which backend you're using). + + Input shape: + 3D (unbatched) or 4D (batched) tensor with shape: + `(..., height, width, channels)`, in `"channels_last"` format + + Output shape: + 3D (unbatched) or 4D (batched) tensor with shape: + `(..., height, width, channels)`, in `"channels_last"` format + + Args: + factor: a float represented as fraction of 2 Pi, or a tuple of size 2 + representing lower and upper bound for rotating clockwise and + counter-clockwise. A positive values means rotating + counter clock-wise, + while a negative value means clock-wise. + When represented as a single + float, this value is used for both the upper and lower bound. + For instance, `factor=(-0.2, 0.3)` + results in an output rotation by a random + amount in the range `[-20% * 360, 30% * 360]`. + `factor=0.2` results in an + output rotating by a random amount + in the range `[-20% * 360, 20% * 360]`. + fill_mode: Points outside the boundaries of the input are filled + according to the given mode + (one of `{"constant", "reflect", "wrap", "nearest"}`). + - *reflect*: `(d c b a | a b c d | d c b a)` + The input is extended by reflecting about + the edge of the last pixel. + - *constant*: `(k k k k | a b c d | k k k k)` + The input is extended by + filling all values beyond the edge with + the same constant value k = 0. + - *wrap*: `(a b c d | a b c d | a b c d)` The input is extended by + wrapping around to the opposite edge. + - *nearest*: `(a a a a | a b c d | d d d d)` + The input is extended by the nearest pixel. + interpolation: Interpolation mode. Supported values: `"nearest"`, + `"bilinear"`. + seed: Integer. Used to create a random seed. + fill_value: a float represents the value to be filled outside + the boundaries when `fill_mode="constant"`. + data_format: string, either `"channels_last"` or `"channels_first"`. + The ordering of the dimensions in the inputs. `"channels_last"` + corresponds to inputs with shape `(batch, height, width, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch, channels, height, width)`. It defaults to the + `image_data_format` value found in your Keras config file at + `~/.keras/keras.json`. If you never set it, then it will be + `"channels_last"`. + """ + + _SUPPORTED_FILL_MODE = ("reflect", "wrap", "constant", "nearest") + _SUPPORTED_INTERPOLATION = ("nearest", "bilinear") + + def __init__( + self, + factor, + fill_mode="reflect", + interpolation="bilinear", + seed=None, + fill_value=0.0, + data_format=None, + **kwargs, + ): + super().__init__(factor=factor, data_format=data_format, **kwargs) + self.seed = seed + self.generator = SeedGenerator(seed) + self.fill_mode = fill_mode + self.interpolation = interpolation + self.fill_value = fill_value + self.supports_jit = False + + if self.fill_mode not in self._SUPPORTED_FILL_MODE: + raise NotImplementedError( + f"Unknown `fill_mode` {fill_mode}. Expected of one " + f"{self._SUPPORTED_FILL_MODE}." + ) + if self.interpolation not in self._SUPPORTED_INTERPOLATION: + raise NotImplementedError( + f"Unknown `interpolation` {interpolation}. Expected of one " + f"{self._SUPPORTED_INTERPOLATION}." + ) + + def transform_images(self, images, transformation, training=True): + images = self.backend.cast(images, self.compute_dtype) + if training: + return self.backend.image.affine_transform( + images=images, + transform=transformation["rotation_matrix"], + interpolation=self.interpolation, + fill_mode=self.fill_mode, + fill_value=self.fill_value, + data_format=self.data_format, + ) + return images + + def transform_labels(self, labels, transformation, training=True): + return labels + + def transform_bounding_boxes( + self, + bounding_boxes, + transformation, + training=True, + ): + if training: + ops = self.backend + boxes = bounding_boxes["boxes"] + height = transformation["image_height"] + width = transformation["image_width"] + batch_size = transformation["batch_size"] + boxes = converters.affine_transform( + boxes=boxes, + angle=transformation["angle"], + translate_x=ops.numpy.zeros([batch_size]), + translate_y=ops.numpy.zeros([batch_size]), + scale=ops.numpy.ones([batch_size]), + shear_x=ops.numpy.zeros([batch_size]), + shear_y=ops.numpy.zeros([batch_size]), + height=height, + width=width, + ) + + bounding_boxes["boxes"] = boxes + bounding_boxes = converters.clip_to_image_size( + bounding_boxes, + height=height, + width=width, + bounding_box_format="xyxy", + ) + bounding_boxes = converters.convert_format( + bounding_boxes, + source="xyxy", + target=self.bounding_box_format, + height=height, + width=width, + ) + return bounding_boxes + + def transform_segmentation_masks( + self, segmentation_masks, transformation, training=True + ): + return self.transform_images( + segmentation_masks, transformation, training=training + ) + + def get_random_transformation(self, data, training=True, seed=None): + ops = self.backend + if not training: + return None + if isinstance(data, dict): + images = data["images"] + else: + images = data + shape = ops.core.shape(images) + if len(shape) == 4: + batch_size = shape[0] + if self.data_format == "channels_last": + image_height = shape[1] + image_width = shape[2] + else: + image_height = shape[2] + image_width = shape[3] + else: + batch_size = 1 + if self.data_format == "channels_last": + image_height = shape[0] + image_width = shape[1] + else: + image_height = shape[1] + image_width = shape[2] + + if seed is None: + seed = self._get_seed_generator(ops._backend) + lower = self.factor[0] * 360.0 + upper = self.factor[1] * 360.0 + angle = ops.random.uniform( + shape=(batch_size,), + minval=lower, + maxval=upper, + seed=seed, + ) + center_x, center_y = 0.5, 0.5 + rotation_matrix = self._compute_affine_matrix( + center_x=center_x, + center_y=center_y, + angle=angle, + translate_x=ops.numpy.zeros([batch_size]), + translate_y=ops.numpy.zeros([batch_size]), + scale=ops.numpy.ones([batch_size]), + shear_x=ops.numpy.zeros([batch_size]), + shear_y=ops.numpy.zeros([batch_size]), + height=image_height, + width=image_width, + ) + if len(shape) == 3: + rotation_matrix = self.backend.numpy.squeeze( + rotation_matrix, axis=0 + ) + return { + "angle": angle, + "rotation_matrix": rotation_matrix, + "image_height": image_height, + "image_width": image_width, + "batch_size": batch_size, + } + + def compute_output_shape(self, input_shape): + return input_shape + + def get_config(self): + config = { + "factor": self.factor, + "data_format": self.data_format, + "fill_mode": self.fill_mode, + "fill_value": self.fill_value, + "interpolation": self.interpolation, + "seed": self.seed, + } + base_config = super().get_config() + return {**base_config, **config} diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_rotation_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_rotation_test.py new file mode 100644 index 000000000000..7350c550ede6 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_rotation_test.py @@ -0,0 +1,77 @@ +import numpy as np +from absl.testing import parameterized +from tensorflow import data as tf_data + +from keras.src import backend +from keras.src import layers +from keras.src import testing + + +class RandomRotationTest(testing.TestCase): + @parameterized.named_parameters( + ("random_rotate_neg4", -0.4), + ("random_rotate_neg2", -0.2), + ("random_rotate_4", 0.4), + ("random_rotate_2", 0.2), + ("random_rotate_tuple", (-0.2, 0.4)), + ) + def test_random_rotation_shapes(self, factor): + self.run_layer_test( + layers.RandomRotation, + init_kwargs={ + "factor": factor, + }, + input_shape=(2, 3, 4), + expected_output_shape=(2, 3, 4), + supports_masking=False, + run_training_check=False, + ) + + def test_random_rotation_correctness(self): + if backend.config.image_data_format() == "channels_last": + input_shape = (1, 5, 5, 1) + else: + input_shape = (1, 1, 5, 5) + input_image = np.reshape(np.arange(0, 25), input_shape) + layer = layers.RandomRotation(factor=(0.5, 0.5)) + actual_output = layer(input_image) + expected_output = np.asarray( + [ + [24, 23, 22, 21, 20], + [19, 18, 17, 16, 15], + [14, 13, 12, 11, 10], + [9, 8, 7, 6, 5], + [4, 3, 2, 1, 0], + ] + ).reshape(input_shape) + + self.assertAllClose( + backend.convert_to_tensor(expected_output), actual_output, atol=1e-5 + ) + + def test_training_false(self): + input_image = np.reshape(np.arange(0, 25), (1, 5, 5, 1)) + layer = layers.RandomRotation(factor=(0.5, 0.5)) + actual_output = layer(input_image, training=False) + self.assertAllClose(actual_output, input_image) + + def test_tf_data_compatibility(self): + if backend.config.image_data_format() == "channels_last": + input_shape = (1, 5, 5, 1) + else: + input_shape = (1, 1, 5, 5) + input_image = np.reshape(np.arange(0, 25), input_shape) + layer = layers.RandomRotation(factor=(0.5, 0.5)) + + ds = tf_data.Dataset.from_tensor_slices(input_image).map(layer) + expected_output = np.asarray( + [ + [24, 23, 22, 21, 20], + [19, 18, 17, 16, 15], + [14, 13, 12, 11, 10], + [9, 8, 7, 6, 5], + [4, 3, 2, 1, 0], + ] + ).reshape(input_shape[1:]) + output = next(iter(ds)).numpy() + self.assertAllClose(expected_output, output) diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_saturation.py b/keras/src/layers/preprocessing/image_preprocessing/random_saturation.py new file mode 100644 index 000000000000..e930bd687adf --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_saturation.py @@ -0,0 +1,167 @@ +from keras.src.api_export import keras_export +from keras.src.backend import epsilon +from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 + BaseImagePreprocessingLayer, +) +from keras.src.random import SeedGenerator + + +@keras_export("keras.layers.RandomSaturation") +class RandomSaturation(BaseImagePreprocessingLayer): + """Randomly adjusts the saturation on given images. + + This layer will randomly increase/reduce the saturation for the input RGB + images. + + **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline + (independently of which backend you're using). + + Args: + factor: A tuple of two floats or a single float. + `factor` controls the extent to which the image saturation + is impacted. `factor=0.5` makes this layer perform a no-op + operation. `factor=0.0` makes the image fully grayscale. + `factor=1.0` makes the image fully saturated. Values should + be between `0.0` and `1.0`. If a tuple is used, a `factor` + is sampled between the two values for every image augmented. + If a single float is used, a value between `0.0` and the passed + float is sampled. To ensure the value is always the same, + pass a tuple with two identical floats: `(0.5, 0.5)`. + value_range: the range of values the incoming images will have. + Represented as a two-number tuple written `[low, high]`. This is + typically either `[0, 1]` or `[0, 255]` depending on how your + preprocessing pipeline is set up. + seed: Integer. Used to create a random seed. + + Example: + ```python + (images, labels), _ = keras.datasets.cifar10.load_data() + images = images.astype("float32") + random_saturation = keras.layers.RandomSaturation(factor=0.2) + augmented_images = random_saturation(images) + ``` + """ + + _VALUE_RANGE_VALIDATION_ERROR = ( + "The `value_range` argument should be a list of two numbers. " + ) + + def __init__( + self, + factor, + value_range=(0, 255), + data_format=None, + seed=None, + **kwargs, + ): + super().__init__(data_format=data_format, **kwargs) + self._set_factor(factor) + self._set_value_range(value_range) + self.seed = seed + self.generator = SeedGenerator(seed) + + def _set_value_range(self, value_range): + if not isinstance(value_range, (tuple, list)): + raise ValueError( + self._VALUE_RANGE_VALIDATION_ERROR + + f"Received: value_range={value_range}" + ) + if len(value_range) != 2: + raise ValueError( + self._VALUE_RANGE_VALIDATION_ERROR + + f"Received: value_range={value_range}" + ) + self.value_range = sorted(value_range) + + def get_random_transformation(self, data, training=True, seed=None): + if isinstance(data, dict): + images = data["images"] + else: + images = data + images_shape = self.backend.shape(images) + rank = len(images_shape) + if rank == 3: + batch_size = 1 + elif rank == 4: + batch_size = images_shape[0] + else: + raise ValueError( + "Expected the input image to be rank 3 or 4. Received: " + f"inputs.shape={images_shape}" + ) + + if seed is None: + seed = self._get_seed_generator(self.backend._backend) + + factor = self.backend.random.uniform( + (batch_size,), + minval=self.factor[0], + maxval=self.factor[1], + seed=seed, + ) + factor = factor / (1 - factor + epsilon()) + return {"factor": factor} + + def transform_images(self, images, transformation=None, training=True): + if training: + adjust_factors = transformation["factor"] + adjust_factors = self.backend.cast( + adjust_factors, self.compute_dtype + ) + adjust_factors = self.backend.numpy.reshape( + adjust_factors, self.backend.shape(adjust_factors) + (1, 1) + ) + images = self.backend.image.rgb_to_hsv( + images, data_format=self.data_format + ) + if self.data_format == "channels_first": + s_channel = self.backend.numpy.multiply( + images[:, 1, :, :], adjust_factors + ) + s_channel = self.backend.numpy.clip( + s_channel, self.value_range[0], self.value_range[1] + ) + images = self.backend.numpy.stack( + [images[:, 0, :, :], s_channel, images[:, 2, :, :]], axis=1 + ) + else: + s_channel = self.backend.numpy.multiply( + images[..., 1], adjust_factors + ) + s_channel = self.backend.numpy.clip( + s_channel, self.value_range[0], self.value_range[1] + ) + images = self.backend.numpy.stack( + [images[..., 0], s_channel, images[..., 2]], axis=-1 + ) + images = self.backend.image.hsv_to_rgb( + images, data_format=self.data_format + ) + return images + + def transform_labels(self, labels, transformation, training=True): + return labels + + def transform_segmentation_masks( + self, segmentation_masks, transformation, training=True + ): + return segmentation_masks + + def transform_bounding_boxes( + self, bounding_boxes, transformation, training=True + ): + return bounding_boxes + + def get_config(self): + config = super().get_config() + config.update( + { + "factor": self.factor, + "value_range": self.value_range, + "seed": self.seed, + } + ) + return config + + def compute_output_shape(self, input_shape): + return input_shape diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_saturation_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_saturation_test.py new file mode 100644 index 000000000000..42ed613ab913 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_saturation_test.py @@ -0,0 +1,97 @@ +import numpy as np +import pytest +from tensorflow import data as tf_data + +import keras +from keras.src import backend +from keras.src import layers +from keras.src import testing + + +class RandomSaturationTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_layer(self): + self.run_layer_test( + layers.RandomSaturation, + init_kwargs={ + "factor": 0.75, + "seed": 1, + }, + input_shape=(8, 3, 4, 3), + supports_masking=False, + expected_output_shape=(8, 3, 4, 3), + ) + + def test_random_saturation_value_range(self): + image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1) + + layer = layers.RandomSaturation(0.2) + adjusted_image = layer(image) + + self.assertTrue(keras.ops.numpy.all(adjusted_image >= 0)) + self.assertTrue(keras.ops.numpy.all(adjusted_image <= 1)) + + def test_random_saturation_no_op(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + inputs = np.random.random((2, 8, 8, 3)) + else: + inputs = np.random.random((2, 3, 8, 8)) + + layer = layers.RandomSaturation((0.5, 0.5)) + output = layer(inputs, training=False) + self.assertAllClose(inputs, output, atol=1e-3, rtol=1e-5) + + def test_random_saturation_full_grayscale(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + inputs = np.random.random((2, 8, 8, 3)) + else: + inputs = np.random.random((2, 3, 8, 8)) + layer = layers.RandomSaturation(factor=(0.0, 0.0)) + result = layer(inputs) + + if data_format == "channels_last": + self.assertAllClose(result[..., 0], result[..., 1]) + self.assertAllClose(result[..., 1], result[..., 2]) + else: + self.assertAllClose(result[:, 0, :, :], result[:, 1, :, :]) + self.assertAllClose(result[:, 1, :, :], result[:, 2, :, :]) + + def test_random_saturation_full_saturation(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + inputs = np.random.random((2, 8, 8, 3)) + else: + inputs = np.random.random((2, 3, 8, 8)) + layer = layers.RandomSaturation(factor=(1.0, 1.0)) + result = layer(inputs) + + hsv = backend.image.rgb_to_hsv(result) + s_channel = hsv[..., 1] + + self.assertAllClose( + keras.ops.numpy.max(s_channel), layer.value_range[1] + ) + + def test_random_saturation_randomness(self): + image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1)[:5] + + layer = layers.RandomSaturation(0.2) + adjusted_images = layer(image) + + self.assertNotAllClose(adjusted_images, image) + + def test_tf_data_compatibility(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + input_data = np.random.random((2, 8, 8, 3)) + else: + input_data = np.random.random((2, 3, 8, 8)) + layer = layers.RandomSaturation( + factor=0.5, data_format=data_format, seed=1337 + ) + + ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) + for output in ds.take(1): + output.numpy() diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_sharpness.py b/keras/src/layers/preprocessing/image_preprocessing/random_sharpness.py new file mode 100644 index 000000000000..0ddc38d22b47 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_sharpness.py @@ -0,0 +1,171 @@ +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 + BaseImagePreprocessingLayer, +) +from keras.src.random import SeedGenerator + + +@keras_export("keras.layers.RandomSharpness") +class RandomSharpness(BaseImagePreprocessingLayer): + """Randomly performs the sharpness operation on given images. + + The sharpness operation first performs a blur, then blends between the + original image and the processed image. This operation adjusts the clarity + of the edges in an image, ranging from blurred to enhanced sharpness. + + **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline + (independently of which backend you're using). + + Args: + factor: A tuple of two floats or a single float. + `factor` controls the extent to which the image sharpness + is impacted. `factor=0.0` results in a fully blurred image, + `factor=0.5` applies no operation (preserving the original image), + and `factor=1.0` enhances the sharpness beyond the original. Values + should be between `0.0` and `1.0`. If a tuple is used, a `factor` + is sampled between the two values for every image augmented. + If a single float is used, a value between `0.0` and the passed + float is sampled. To ensure the value is always the same, + pass a tuple with two identical floats: `(0.5, 0.5)`. + value_range: the range of values the incoming images will have. + Represented as a two-number tuple written `[low, high]`. This is + typically either `[0, 1]` or `[0, 255]` depending on how your + preprocessing pipeline is set up. + seed: Integer. Used to create a random seed. + """ + + _USE_BASE_FACTOR = False + _FACTOR_BOUNDS = (0, 1) + + _VALUE_RANGE_VALIDATION_ERROR = ( + "The `value_range` argument should be a list of two numbers. " + ) + + def __init__( + self, + factor, + value_range=(0, 255), + data_format=None, + seed=None, + **kwargs, + ): + super().__init__(data_format=data_format, **kwargs) + self._set_factor(factor) + self._set_value_range(value_range) + self.seed = seed + self.generator = SeedGenerator(seed) + + def _set_value_range(self, value_range): + if not isinstance(value_range, (tuple, list)): + raise ValueError( + self._VALUE_RANGE_VALIDATION_ERROR + + f"Received: value_range={value_range}" + ) + if len(value_range) != 2: + raise ValueError( + self._VALUE_RANGE_VALIDATION_ERROR + + f"Received: value_range={value_range}" + ) + self.value_range = sorted(value_range) + + def get_random_transformation(self, data, training=True, seed=None): + if isinstance(data, dict): + images = data["images"] + else: + images = data + images_shape = self.backend.shape(images) + rank = len(images_shape) + if rank == 3: + batch_size = 1 + elif rank == 4: + batch_size = images_shape[0] + else: + raise ValueError( + "Expected the input image to be rank 3 or 4. Received: " + f"inputs.shape={images_shape}" + ) + + if seed is None: + seed = self._get_seed_generator(self.backend._backend) + + factor = self.backend.random.uniform( + (batch_size,), + minval=self.factor[0], + maxval=self.factor[1], + seed=seed, + ) + return {"factor": factor} + + def transform_images(self, images, transformation=None, training=True): + images = self.backend.cast(images, self.compute_dtype) + if training: + if self.data_format == "channels_first": + images = self.backend.numpy.swapaxes(images, -3, -1) + + sharpness_factor = self.backend.cast( + transformation["factor"] * 2, dtype=self.compute_dtype + ) + sharpness_factor = self.backend.numpy.reshape( + sharpness_factor, (-1, 1, 1, 1) + ) + + num_channels = self.backend.shape(images)[-1] + + a, b = 1.0 / 13.0, 5.0 / 13.0 + kernel = self.backend.convert_to_tensor( + [[a, a, a], [a, b, a], [a, a, a]], dtype=self.compute_dtype + ) + kernel = self.backend.numpy.reshape(kernel, (3, 3, 1, 1)) + kernel = self.backend.numpy.tile(kernel, [1, 1, num_channels, 1]) + kernel = self.backend.cast(kernel, self.compute_dtype) + + smoothed_image = self.backend.nn.depthwise_conv( + images, + kernel, + strides=1, + padding="same", + data_format="channels_last", + ) + + smoothed_image = self.backend.cast( + smoothed_image, dtype=self.compute_dtype + ) + images = images + (1.0 - sharpness_factor) * ( + smoothed_image - images + ) + + images = self.backend.numpy.clip( + images, self.value_range[0], self.value_range[1] + ) + + if self.data_format == "channels_first": + images = self.backend.numpy.swapaxes(images, -3, -1) + + return images + + def transform_labels(self, labels, transformation, training=True): + return labels + + def transform_segmentation_masks( + self, segmentation_masks, transformation, training=True + ): + return segmentation_masks + + def transform_bounding_boxes( + self, bounding_boxes, transformation, training=True + ): + return bounding_boxes + + def get_config(self): + config = super().get_config() + config.update( + { + "factor": self.factor, + "value_range": self.value_range, + "seed": self.seed, + } + ) + return config + + def compute_output_shape(self, input_shape): + return input_shape diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_sharpness_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_sharpness_test.py new file mode 100644 index 000000000000..5cf3b10c8674 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_sharpness_test.py @@ -0,0 +1,65 @@ +import numpy as np +import pytest +from tensorflow import data as tf_data + +import keras +from keras.src import backend +from keras.src import layers +from keras.src import testing + + +class RandomSharpnessTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_layer(self): + self.run_layer_test( + layers.RandomSharpness, + init_kwargs={ + "factor": 0.75, + "seed": 1, + }, + input_shape=(8, 3, 4, 3), + supports_masking=False, + expected_output_shape=(8, 3, 4, 3), + ) + + def test_random_sharpness_value_range(self): + image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1) + + layer = layers.RandomSharpness(0.2) + adjusted_image = layer(image) + + self.assertTrue(keras.ops.numpy.all(adjusted_image >= 0)) + self.assertTrue(keras.ops.numpy.all(adjusted_image <= 1)) + + def test_random_sharpness_no_op(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + inputs = np.random.random((2, 8, 8, 3)) + else: + inputs = np.random.random((2, 3, 8, 8)) + + layer = layers.RandomSharpness((0.5, 0.5)) + output = layer(inputs, training=False) + self.assertAllClose(inputs, output, atol=1e-3, rtol=1e-5) + + def test_random_sharpness_randomness(self): + image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1)[:5] + + layer = layers.RandomSharpness(0.2) + adjusted_images = layer(image) + + self.assertNotAllClose(adjusted_images, image) + + def test_tf_data_compatibility(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + input_data = np.random.random((2, 8, 8, 3)) + else: + input_data = np.random.random((2, 3, 8, 8)) + layer = layers.RandomSharpness( + factor=0.5, data_format=data_format, seed=1337 + ) + + ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) + for output in ds.take(1): + output.numpy() diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_shear.py b/keras/src/layers/preprocessing/image_preprocessing/random_shear.py new file mode 100644 index 000000000000..71ecc6b81278 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_shear.py @@ -0,0 +1,404 @@ +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 + BaseImagePreprocessingLayer, +) +from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import ( # noqa: E501 + clip_to_image_size, +) +from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import ( # noqa: E501 + convert_format, +) +from keras.src.random.seed_generator import SeedGenerator +from keras.src.utils import backend_utils + + +@keras_export("keras.layers.RandomShear") +class RandomShear(BaseImagePreprocessingLayer): + """A preprocessing layer that randomly applies shear transformations to + images. + + This layer shears the input images along the x-axis and/or y-axis by a + randomly selected factor within the specified range. The shear + transformation is applied to each image independently in a batch. Empty + regions created during the transformation are filled according to the + `fill_mode` and `fill_value` parameters. + + **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline + (independently of which backend you're using). + + Args: + x_factor: A tuple of two floats. For each augmented image, a value + is sampled from the provided range. If a float is passed, the + range is interpreted as `(0, x_factor)`. Values represent a + percentage of the image to shear over. For example, 0.3 shears + pixels up to 30% of the way across the image. All provided values + should be positive. + y_factor: A tuple of two floats. For each augmented image, a value + is sampled from the provided range. If a float is passed, the + range is interpreted as `(0, y_factor)`. Values represent a + percentage of the image to shear over. For example, 0.3 shears + pixels up to 30% of the way across the image. All provided values + should be positive. + interpolation: Interpolation mode. Supported values: `"nearest"`, + `"bilinear"`. + fill_mode: Points outside the boundaries of the input are filled + according to the given mode. Available methods are `"constant"`, + `"nearest"`, `"wrap"` and `"reflect"`. Defaults to `"constant"`. + - `"reflect"`: `(d c b a | a b c d | d c b a)` + The input is extended by reflecting about the edge of the + last pixel. + - `"constant"`: `(k k k k | a b c d | k k k k)` + The input is extended by filling all values beyond the edge + with the same constant value `k` specified by `fill_value`. + - `"wrap"`: `(a b c d | a b c d | a b c d)` + The input is extended by wrapping around to the opposite edge. + - `"nearest"`: `(a a a a | a b c d | d d d d)` + The input is extended by the nearest pixel. + Note that when using torch backend, `"reflect"` is redirected to + `"mirror"` `(c d c b | a b c d | c b a b)` because torch does + not support `"reflect"`. + Note that torch backend does not support `"wrap"`. + fill_value: A float representing the value to be filled outside the + boundaries when `fill_mode="constant"`. + seed: Integer. Used to create a random seed. + """ + + _USE_BASE_FACTOR = False + _FACTOR_BOUNDS = (0, 1) + _FACTOR_VALIDATION_ERROR = ( + "The `factor` argument should be a number (or a list of two numbers) " + "in the range [0, 1.0]. " + ) + _SUPPORTED_FILL_MODE = ("reflect", "wrap", "constant", "nearest") + _SUPPORTED_INTERPOLATION = ("nearest", "bilinear") + + def __init__( + self, + x_factor=0.0, + y_factor=0.0, + interpolation="bilinear", + fill_mode="reflect", + fill_value=0.0, + data_format=None, + seed=None, + **kwargs, + ): + super().__init__(data_format=data_format, **kwargs) + self.x_factor = self._set_factor_with_name(x_factor, "x_factor") + self.y_factor = self._set_factor_with_name(y_factor, "y_factor") + + if fill_mode not in self._SUPPORTED_FILL_MODE: + raise NotImplementedError( + f"Unknown `fill_mode` {fill_mode}. Expected of one " + f"{self._SUPPORTED_FILL_MODE}." + ) + if interpolation not in self._SUPPORTED_INTERPOLATION: + raise NotImplementedError( + f"Unknown `interpolation` {interpolation}. Expected of one " + f"{self._SUPPORTED_INTERPOLATION}." + ) + + self.fill_mode = fill_mode + self.fill_value = fill_value + self.interpolation = interpolation + self.seed = seed + self.generator = SeedGenerator(seed) + self.supports_jit = False + + def _set_factor_with_name(self, factor, factor_name): + if isinstance(factor, (tuple, list)): + if len(factor) != 2: + raise ValueError( + self._FACTOR_VALIDATION_ERROR + + f"Received: {factor_name}={factor}" + ) + self._check_factor_range(factor[0]) + self._check_factor_range(factor[1]) + lower, upper = sorted(factor) + elif isinstance(factor, (int, float)): + self._check_factor_range(factor) + factor = abs(factor) + lower, upper = [-factor, factor] + else: + raise ValueError( + self._FACTOR_VALIDATION_ERROR + + f"Received: {factor_name}={factor}" + ) + return lower, upper + + def _check_factor_range(self, input_number): + if input_number > 1.0 or input_number < 0.0: + raise ValueError( + self._FACTOR_VALIDATION_ERROR + + f"Received: input_number={input_number}" + ) + + def get_random_transformation(self, data, training=True, seed=None): + if not training: + return None + + if isinstance(data, dict): + images = data["images"] + else: + images = data + + images_shape = self.backend.shape(images) + if len(images_shape) == 3: + batch_size = 1 + else: + batch_size = images_shape[0] + + if seed is None: + seed = self._get_seed_generator(self.backend._backend) + + invert = self.backend.random.uniform( + minval=0, + maxval=1, + shape=[batch_size, 1], + seed=seed, + dtype=self.compute_dtype, + ) + invert = self.backend.numpy.where( + invert > 0.5, + -self.backend.numpy.ones_like(invert), + self.backend.numpy.ones_like(invert), + ) + + shear_y = self.backend.random.uniform( + minval=self.y_factor[0], + maxval=self.y_factor[1], + shape=[batch_size, 1], + seed=seed, + dtype=self.compute_dtype, + ) + shear_x = self.backend.random.uniform( + minval=self.x_factor[0], + maxval=self.x_factor[1], + shape=[batch_size, 1], + seed=seed, + dtype=self.compute_dtype, + ) + shear_factor = ( + self.backend.cast( + self.backend.numpy.concatenate([shear_x, shear_y], axis=1), + dtype=self.compute_dtype, + ) + * invert + ) + return {"shear_factor": shear_factor, "input_shape": images_shape} + + def transform_images(self, images, transformation, training=True): + images = self.backend.cast(images, self.compute_dtype) + if training: + return self._shear_inputs(images, transformation) + return images + + def _shear_inputs(self, inputs, transformation): + if transformation is None: + return inputs + + inputs_shape = self.backend.shape(inputs) + unbatched = len(inputs_shape) == 3 + if unbatched: + inputs = self.backend.numpy.expand_dims(inputs, axis=0) + + shear_factor = transformation["shear_factor"] + outputs = self.backend.image.affine_transform( + inputs, + transform=self._get_shear_matrix(shear_factor), + interpolation=self.interpolation, + fill_mode=self.fill_mode, + fill_value=self.fill_value, + data_format=self.data_format, + ) + + if unbatched: + outputs = self.backend.numpy.squeeze(outputs, axis=0) + return outputs + + def _get_shear_matrix(self, shear_factors): + num_shear_factors = self.backend.shape(shear_factors)[0] + + # The shear matrix looks like: + # [[1 s_x 0] + # [s_y 1 0] + # [0 0 1]] + + return self.backend.numpy.stack( + [ + self.backend.numpy.ones((num_shear_factors,)), + shear_factors[:, 0], + self.backend.numpy.zeros((num_shear_factors,)), + shear_factors[:, 1], + self.backend.numpy.ones((num_shear_factors,)), + self.backend.numpy.zeros((num_shear_factors,)), + self.backend.numpy.zeros((num_shear_factors,)), + self.backend.numpy.zeros((num_shear_factors,)), + ], + axis=1, + ) + + def transform_labels(self, labels, transformation, training=True): + return labels + + def get_transformed_x_y(self, x, y, transform): + a0, a1, a2, b0, b1, b2, c0, c1 = self.backend.numpy.split( + transform, 8, axis=-1 + ) + + k = c0 * x + c1 * y + 1 + x_transformed = (a0 * x + a1 * y + a2) / k + y_transformed = (b0 * x + b1 * y + b2) / k + return x_transformed, y_transformed + + def get_shifted_bbox(self, bounding_boxes, w_shift_factor, h_shift_factor): + bboxes = bounding_boxes["boxes"] + x1, x2, x3, x4 = self.backend.numpy.split(bboxes, 4, axis=-1) + + w_shift_factor = self.backend.convert_to_tensor( + w_shift_factor, dtype=x1.dtype + ) + h_shift_factor = self.backend.convert_to_tensor( + h_shift_factor, dtype=x1.dtype + ) + + if len(bboxes.shape) == 3: + w_shift_factor = self.backend.numpy.expand_dims(w_shift_factor, -1) + h_shift_factor = self.backend.numpy.expand_dims(h_shift_factor, -1) + + bounding_boxes["boxes"] = self.backend.numpy.concatenate( + [ + x1 - w_shift_factor, + x2 - h_shift_factor, + x3 - w_shift_factor, + x4 - h_shift_factor, + ], + axis=-1, + ) + return bounding_boxes + + def transform_bounding_boxes( + self, + bounding_boxes, + transformation, + training=True, + ): + def _get_height_width(transformation): + if self.data_format == "channels_first": + height_axis = -2 + width_axis = -1 + else: + height_axis = -3 + width_axis = -2 + input_height, input_width = ( + transformation["input_shape"][height_axis], + transformation["input_shape"][width_axis], + ) + return input_height, input_width + + if training: + if backend_utils.in_tf_graph(): + self.backend.set_backend("tensorflow") + + input_height, input_width = _get_height_width(transformation) + + bounding_boxes = convert_format( + bounding_boxes, + source=self.bounding_box_format, + target="rel_xyxy", + height=input_height, + width=input_width, + dtype=self.compute_dtype, + ) + + bounding_boxes = self._shear_bboxes(bounding_boxes, transformation) + + bounding_boxes = clip_to_image_size( + bounding_boxes=bounding_boxes, + height=input_height, + width=input_width, + bounding_box_format="rel_xyxy", + ) + + bounding_boxes = convert_format( + bounding_boxes, + source="rel_xyxy", + target=self.bounding_box_format, + height=input_height, + width=input_width, + dtype=self.compute_dtype, + ) + + self.backend.reset() + + return bounding_boxes + + def _shear_bboxes(self, bounding_boxes, transformation): + shear_factor = self.backend.cast( + transformation["shear_factor"], dtype=self.compute_dtype + ) + shear_x_amount, shear_y_amount = self.backend.numpy.split( + shear_factor, 2, axis=-1 + ) + + x1, y1, x2, y2 = self.backend.numpy.split( + bounding_boxes["boxes"], 4, axis=-1 + ) + x1 = self.backend.numpy.squeeze(x1, axis=-1) + y1 = self.backend.numpy.squeeze(y1, axis=-1) + x2 = self.backend.numpy.squeeze(x2, axis=-1) + y2 = self.backend.numpy.squeeze(y2, axis=-1) + + if shear_x_amount is not None: + x1_top = x1 - (shear_x_amount * y1) + x1_bottom = x1 - (shear_x_amount * y2) + x1 = self.backend.numpy.where(shear_x_amount < 0, x1_top, x1_bottom) + + x2_top = x2 - (shear_x_amount * y1) + x2_bottom = x2 - (shear_x_amount * y2) + x2 = self.backend.numpy.where(shear_x_amount < 0, x2_bottom, x2_top) + + if shear_y_amount is not None: + y1_left = y1 - (shear_y_amount * x1) + y1_right = y1 - (shear_y_amount * x2) + y1 = self.backend.numpy.where(shear_y_amount > 0, y1_right, y1_left) + + y2_left = y2 - (shear_y_amount * x1) + y2_right = y2 - (shear_y_amount * x2) + y2 = self.backend.numpy.where(shear_y_amount > 0, y2_left, y2_right) + + boxes = self.backend.numpy.concatenate( + [ + self.backend.numpy.expand_dims(x1, axis=-1), + self.backend.numpy.expand_dims(y1, axis=-1), + self.backend.numpy.expand_dims(x2, axis=-1), + self.backend.numpy.expand_dims(y2, axis=-1), + ], + axis=-1, + ) + bounding_boxes["boxes"] = boxes + + return bounding_boxes + + def transform_segmentation_masks( + self, segmentation_masks, transformation, training=True + ): + return self.transform_images( + segmentation_masks, transformation, training=training + ) + + def get_config(self): + base_config = super().get_config() + config = { + "x_factor": self.x_factor, + "y_factor": self.y_factor, + "fill_mode": self.fill_mode, + "interpolation": self.interpolation, + "seed": self.seed, + "fill_value": self.fill_value, + "data_format": self.data_format, + } + return {**base_config, **config} + + def compute_output_shape(self, input_shape): + return input_shape diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_shear_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_shear_test.py new file mode 100644 index 000000000000..9d5592ff491d --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_shear_test.py @@ -0,0 +1,200 @@ +import numpy as np +import pytest +from absl.testing import parameterized +from tensorflow import data as tf_data + +import keras +from keras.src import backend +from keras.src import layers +from keras.src import testing +from keras.src.utils import backend_utils + + +class RandomShearTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_layer(self): + self.run_layer_test( + layers.RandomShear, + init_kwargs={ + "x_factor": (0.5, 1), + "y_factor": (0.5, 1), + "interpolation": "bilinear", + "fill_mode": "reflect", + "data_format": "channels_last", + "seed": 1, + }, + input_shape=(8, 3, 4, 3), + supports_masking=False, + expected_output_shape=(8, 3, 4, 3), + ) + + def test_random_posterization_inference(self): + seed = 3481 + layer = layers.RandomShear(1, 1) + np.random.seed(seed) + inputs = np.random.randint(0, 255, size=(224, 224, 3)) + output = layer(inputs, training=False) + self.assertAllClose(inputs, output) + + def test_shear_pixel_level(self): + image = np.zeros((1, 5, 5, 3)) + image[0, 1:4, 1:4, :] = 1.0 + image[0, 2, 2, :] = [0.0, 1.0, 0.0] + image = keras.ops.convert_to_tensor(image, dtype="float32") + + data_format = backend.config.image_data_format() + if data_format == "channels_first": + image = keras.ops.transpose(image, (0, 3, 1, 2)) + + shear_layer = layers.RandomShear( + x_factor=(0.2, 0.3), + y_factor=(0.2, 0.3), + interpolation="bilinear", + fill_mode="constant", + fill_value=0.0, + seed=42, + data_format=data_format, + ) + + sheared_image = shear_layer(image) + + if data_format == "channels_first": + sheared_image = keras.ops.transpose(sheared_image, (0, 2, 3, 1)) + + original_pixel = image[0, 2, 2, :] + sheared_pixel = sheared_image[0, 2, 2, :] + self.assertNotAllClose(original_pixel, sheared_pixel) + + def test_tf_data_compatibility(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + input_data = np.random.random((2, 8, 8, 3)) + else: + input_data = np.random.random((2, 3, 8, 8)) + layer = layers.RandomShear(1, 1) + + ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) + for output in ds.take(1): + output.numpy() + + @parameterized.named_parameters( + ( + "with_x_shift", + [[1.0, 0.0]], + [[[0.0, 1.0, 3.2, 3.0], [1.2, 4.0, 4.8, 6.0]]], + ), + ( + "with_y_shift", + [[0.0, 1.0]], + [[[2.0, 0.0, 4.0, 0.5], [6.0, 0.0, 8.0, 0.0]]], + ), + ( + "with_xy_shift", + [[1.0, 1.0]], + [[[0.0, 0.0, 3.2, 3.5], [1.2, 0.0, 4.8, 4.5]]], + ), + ) + def test_random_shear_bounding_boxes(self, translation, expected_boxes): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + image_shape = (10, 8, 3) + else: + image_shape = (3, 10, 8) + input_image = np.random.random(image_shape) + bounding_boxes = { + "boxes": np.array( + [ + [2, 1, 4, 3], + [6, 4, 8, 6], + ] + ), + "labels": np.array([[1, 2]]), + } + input_data = {"images": input_image, "bounding_boxes": bounding_boxes} + layer = layers.RandomShear( + x_factor=0.5, + y_factor=0.5, + data_format=data_format, + seed=42, + bounding_box_format="xyxy", + ) + + transformation = { + "shear_factor": backend_utils.convert_tf_tensor( + np.array(translation) + ), + "input_shape": image_shape, + } + output = layer.transform_bounding_boxes( + input_data["bounding_boxes"], + transformation=transformation, + training=True, + ) + + self.assertAllClose(output["boxes"], expected_boxes) + + @parameterized.named_parameters( + ( + "with_x_shift", + [[1.0, 0.0]], + [[[0.0, 1.0, 3.2, 3.0], [1.2, 4.0, 4.8, 6.0]]], + ), + ( + "with_y_shift", + [[0.0, 1.0]], + [[[2.0, 0.0, 4.0, 0.5], [6.0, 0.0, 8.0, 0.0]]], + ), + ( + "with_xy_shift", + [[1.0, 1.0]], + [[[0.0, 0.0, 3.2, 3.5], [1.2, 0.0, 4.8, 4.5]]], + ), + ) + def test_random_shear_tf_data_bounding_boxes( + self, translation, expected_boxes + ): + data_format = backend.config.image_data_format() + if backend.config.image_data_format() == "channels_last": + image_shape = (1, 10, 8, 3) + else: + image_shape = (1, 3, 10, 8) + input_image = np.random.random(image_shape) + bounding_boxes = { + "boxes": np.array( + [ + [ + [2, 1, 4, 3], + [6, 4, 8, 6], + ] + ] + ), + "labels": np.array([[1, 2]]), + } + + input_data = {"images": input_image, "bounding_boxes": bounding_boxes} + + ds = tf_data.Dataset.from_tensor_slices(input_data) + layer = layers.RandomShear( + x_factor=0.5, + y_factor=0.5, + data_format=data_format, + seed=42, + bounding_box_format="xyxy", + ) + + transformation = { + "shear_factor": np.array(translation), + "input_shape": image_shape, + } + + ds = ds.map( + lambda x: layer.transform_bounding_boxes( + x["bounding_boxes"], + transformation=transformation, + training=True, + ) + ) + + output = next(iter(ds)) + expected_boxes = np.array(expected_boxes) + self.assertAllClose(output["boxes"], expected_boxes) diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_translation.py b/keras/src/layers/preprocessing/image_preprocessing/random_translation.py new file mode 100644 index 000000000000..488c0e0e50c2 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_translation.py @@ -0,0 +1,384 @@ +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 + BaseImagePreprocessingLayer, +) +from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import ( # noqa: E501 + clip_to_image_size, +) +from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import ( # noqa: E501 + convert_format, +) +from keras.src.random.seed_generator import SeedGenerator +from keras.src.utils import backend_utils + + +@keras_export("keras.layers.RandomTranslation") +class RandomTranslation(BaseImagePreprocessingLayer): + """A preprocessing layer which randomly translates images during training. + + This layer will apply random translations to each image during training, + filling empty space according to `fill_mode`. + + Input pixel values can be of any range (e.g. `[0., 1.)` or `[0, 255]`) and + of integer or floating point dtype. By default, the layer will output + floats. + + **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline + (independently of which backend you're using). + + Input shape: + 3D (unbatched) or 4D (batched) tensor with shape: + `(..., height, width, channels)`, in `"channels_last"` format, + or `(..., channels, height, width)`, in `"channels_first"` format. + + Output shape: + 3D (unbatched) or 4D (batched) tensor with shape: + `(..., target_height, target_width, channels)`, + or `(..., channels, target_height, target_width)`, + in `"channels_first"` format. + + Args: + height_factor: a float represented as fraction of value, or a tuple of + size 2 representing lower and upper bound for shifting vertically. A + negative value means shifting image up, while a positive value means + shifting image down. When represented as a single positive float, + this value is used for both the upper and lower bound. For instance, + `height_factor=(-0.2, 0.3)` results in an output shifted by a random + amount in the range `[-20%, +30%]`. `height_factor=0.2` results in + an output height shifted by a random amount in the range + `[-20%, +20%]`. + width_factor: a float represented as fraction of value, or a tuple of + size 2 representing lower and upper bound for shifting horizontally. + A negative value means shifting image left, while a positive value + means shifting image right. When represented as a single positive + float, this value is used for both the upper and lower bound. For + instance, `width_factor=(-0.2, 0.3)` results in an output shifted + left by 20%, and shifted right by 30%. `width_factor=0.2` results + in an output height shifted left or right by 20%. + fill_mode: Points outside the boundaries of the input are filled + according to the given mode. Available methods are `"constant"`, + `"nearest"`, `"wrap"` and `"reflect"`. Defaults to `"constant"`. + - `"reflect"`: `(d c b a | a b c d | d c b a)` + The input is extended by reflecting about the edge of the last + pixel. + - `"constant"`: `(k k k k | a b c d | k k k k)` + The input is extended by filling all values beyond + the edge with the same constant value k specified by + `fill_value`. + - `"wrap"`: `(a b c d | a b c d | a b c d)` + The input is extended by wrapping around to the opposite edge. + - `"nearest"`: `(a a a a | a b c d | d d d d)` + The input is extended by the nearest pixel. + Note that when using torch backend, `"reflect"` is redirected to + `"mirror"` `(c d c b | a b c d | c b a b)` because torch does not + support `"reflect"`. + Note that torch backend does not support `"wrap"`. + interpolation: Interpolation mode. Supported values: `"nearest"`, + `"bilinear"`. + seed: Integer. Used to create a random seed. + fill_value: a float represents the value to be filled outside the + boundaries when `fill_mode="constant"`. + data_format: string, either `"channels_last"` or `"channels_first"`. + The ordering of the dimensions in the inputs. `"channels_last"` + corresponds to inputs with shape `(batch, height, width, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch, channels, height, width)`. It defaults to the + `image_data_format` value found in your Keras config file at + `~/.keras/keras.json`. If you never set it, then it will be + `"channels_last"`. + **kwargs: Base layer keyword arguments, such as `name` and `dtype`. + """ + + _USE_BASE_FACTOR = False + _FACTOR_VALIDATION_ERROR = ( + "The `factor` argument should be a number (or a list of two numbers) " + "in the range [-1.0, 1.0]. " + ) + _SUPPORTED_FILL_MODE = ("reflect", "wrap", "constant", "nearest") + _SUPPORTED_INTERPOLATION = ("nearest", "bilinear") + + def __init__( + self, + height_factor, + width_factor, + fill_mode="reflect", + interpolation="bilinear", + seed=None, + fill_value=0.0, + data_format=None, + **kwargs, + ): + super().__init__(data_format=data_format, **kwargs) + self.height_factor = height_factor + self.height_lower, self.height_upper = self._set_factor( + height_factor, "height_factor" + ) + self.width_factor = width_factor + self.width_lower, self.width_upper = self._set_factor( + width_factor, "width_factor" + ) + + if fill_mode not in self._SUPPORTED_FILL_MODE: + raise NotImplementedError( + f"Unknown `fill_mode` {fill_mode}. Expected of one " + f"{self._SUPPORTED_FILL_MODE}." + ) + if interpolation not in self._SUPPORTED_INTERPOLATION: + raise NotImplementedError( + f"Unknown `interpolation` {interpolation}. Expected of one " + f"{self._SUPPORTED_INTERPOLATION}." + ) + + self.fill_mode = fill_mode + self.fill_value = fill_value + self.interpolation = interpolation + self.seed = seed + self.generator = SeedGenerator(seed) + self.supports_jit = False + + def _set_factor(self, factor, factor_name): + if isinstance(factor, (tuple, list)): + if len(factor) != 2: + raise ValueError( + self._FACTOR_VALIDATION_ERROR + + f"Received: {factor_name}={factor}" + ) + self._check_factor_range(factor[0]) + self._check_factor_range(factor[1]) + lower, upper = sorted(factor) + elif isinstance(factor, (int, float)): + self._check_factor_range(factor) + factor = abs(factor) + lower, upper = [-factor, factor] + else: + raise ValueError( + self._FACTOR_VALIDATION_ERROR + + f"Received: {factor_name}={factor}" + ) + return lower, upper + + def _check_factor_range(self, input_number): + if input_number > 1.0 or input_number < -1.0: + raise ValueError( + self._FACTOR_VALIDATION_ERROR + + f"Received: input_number={input_number}" + ) + + def transform_images(self, images, transformation, training=True): + images = self.backend.cast(images, self.compute_dtype) + if training: + return self._translate_inputs(images, transformation) + return images + + def transform_labels(self, labels, transformation, training=True): + return labels + + def get_transformed_x_y(self, x, y, transform): + a0, a1, a2, b0, b1, b2, c0, c1 = self.backend.numpy.split( + transform, 8, axis=-1 + ) + + k = c0 * x + c1 * y + 1 + x_transformed = (a0 * x + a1 * y + a2) / k + y_transformed = (b0 * x + b1 * y + b2) / k + return x_transformed, y_transformed + + def get_shifted_bbox(self, bounding_boxes, w_shift_factor, h_shift_factor): + bboxes = bounding_boxes["boxes"] + x1, x2, x3, x4 = self.backend.numpy.split(bboxes, 4, axis=-1) + + w_shift_factor = self.backend.convert_to_tensor( + w_shift_factor, dtype=x1.dtype + ) + h_shift_factor = self.backend.convert_to_tensor( + h_shift_factor, dtype=x1.dtype + ) + + if len(bboxes.shape) == 3: + w_shift_factor = self.backend.numpy.expand_dims(w_shift_factor, -1) + h_shift_factor = self.backend.numpy.expand_dims(h_shift_factor, -1) + + bounding_boxes["boxes"] = self.backend.numpy.concatenate( + [ + x1 - w_shift_factor, + x2 - h_shift_factor, + x3 - w_shift_factor, + x4 - h_shift_factor, + ], + axis=-1, + ) + return bounding_boxes + + def transform_bounding_boxes( + self, + bounding_boxes, + transformation, + training=True, + ): + if training: + if backend_utils.in_tf_graph(): + self.backend.set_backend("tensorflow") + + if self.data_format == "channels_first": + height_axis = -2 + width_axis = -1 + else: + height_axis = -3 + width_axis = -2 + + input_height, input_width = ( + transformation["input_shape"][height_axis], + transformation["input_shape"][width_axis], + ) + + bounding_boxes = convert_format( + bounding_boxes, + source=self.bounding_box_format, + target="xyxy", + height=input_height, + width=input_width, + ) + + translations = transformation["translations"] + transform = self._get_translation_matrix(translations) + + w_shift_factor, h_shift_factor = self.get_transformed_x_y( + 0, 0, transform + ) + bounding_boxes = self.get_shifted_bbox( + bounding_boxes, w_shift_factor, h_shift_factor + ) + + bounding_boxes = clip_to_image_size( + bounding_boxes=bounding_boxes, + height=input_height, + width=input_width, + bounding_box_format="xyxy", + ) + + bounding_boxes = convert_format( + bounding_boxes, + source="xyxy", + target=self.bounding_box_format, + height=input_height, + width=input_width, + ) + + self.backend.reset() + + return bounding_boxes + + def transform_segmentation_masks( + self, segmentation_masks, transformation, training=True + ): + return self.transform_images( + segmentation_masks, transformation, training=training + ) + + def get_random_transformation(self, data, training=True, seed=None): + if not training: + return None + + if isinstance(data, dict): + images = data["images"] + else: + images = data + + images_shape = self.backend.shape(images) + unbatched = len(images_shape) == 3 + if unbatched: + images_shape = self.backend.shape(images) + batch_size = 1 + else: + batch_size = images_shape[0] + if self.data_format == "channels_first": + height = images_shape[-2] + width = images_shape[-1] + else: + height = images_shape[-3] + width = images_shape[-2] + + if seed is None: + seed = self._get_seed_generator(self.backend._backend) + + height_translate = self.backend.random.uniform( + minval=self.height_lower, + maxval=self.height_upper, + shape=[batch_size, 1], + seed=seed, + ) + height_translate = self.backend.numpy.multiply(height_translate, height) + width_translate = self.backend.random.uniform( + minval=self.width_lower, + maxval=self.width_upper, + shape=[batch_size, 1], + seed=seed, + ) + width_translate = self.backend.numpy.multiply(width_translate, width) + translations = self.backend.cast( + self.backend.numpy.concatenate( + [width_translate, height_translate], axis=1 + ), + dtype="float32", + ) + return {"translations": translations, "input_shape": images_shape} + + def _translate_inputs(self, inputs, transformation): + if transformation is None: + return inputs + + inputs_shape = self.backend.shape(inputs) + unbatched = len(inputs_shape) == 3 + if unbatched: + inputs = self.backend.numpy.expand_dims(inputs, axis=0) + + translations = transformation["translations"] + outputs = self.backend.image.affine_transform( + inputs, + transform=self._get_translation_matrix(translations), + interpolation=self.interpolation, + fill_mode=self.fill_mode, + fill_value=self.fill_value, + data_format=self.data_format, + ) + + if unbatched: + outputs = self.backend.numpy.squeeze(outputs, axis=0) + return outputs + + def _get_translation_matrix(self, translations): + num_translations = self.backend.shape(translations)[0] + # The translation matrix looks like: + # [[1 0 -dx] + # [0 1 -dy] + # [0 0 1]] + # where the last entry is implicit. + # translation matrices are always float32. + return self.backend.numpy.concatenate( + [ + self.backend.numpy.ones((num_translations, 1)), + self.backend.numpy.zeros((num_translations, 1)), + -translations[:, 0:1], + self.backend.numpy.zeros((num_translations, 1)), + self.backend.numpy.ones((num_translations, 1)), + -translations[:, 1:], + self.backend.numpy.zeros((num_translations, 2)), + ], + axis=1, + ) + + def compute_output_shape(self, input_shape): + return input_shape + + def get_config(self): + base_config = super().get_config() + config = { + "height_factor": self.height_factor, + "width_factor": self.width_factor, + "fill_mode": self.fill_mode, + "interpolation": self.interpolation, + "seed": self.seed, + "fill_value": self.fill_value, + "data_format": self.data_format, + } + return {**base_config, **config} diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_translation_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_translation_test.py new file mode 100644 index 000000000000..350f3b957458 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_translation_test.py @@ -0,0 +1,443 @@ +import numpy as np +from absl.testing import parameterized +from tensorflow import data as tf_data + +from keras.src import backend +from keras.src import layers +from keras.src import testing +from keras.src.utils import backend_utils + + +class RandomTranslationTest(testing.TestCase): + @parameterized.named_parameters( + ("random_translate_4_by_6", 0.4, 0.6), + ("random_translate_3_by_2", 0.3, 0.2), + ("random_translate_tuple_factor", (-0.5, 0.4), (0.2, 0.3)), + ) + def test_random_translation(self, height_factor, width_factor): + self.run_layer_test( + layers.RandomTranslation, + init_kwargs={ + "height_factor": height_factor, + "width_factor": width_factor, + }, + input_shape=(2, 3, 4), + expected_output_shape=(2, 3, 4), + supports_masking=False, + run_training_check=False, + ) + + @parameterized.named_parameters( + ("bad_len", [0.1, 0.2, 0.3], 0.0), + ("bad_type", {"dummy": 0.3}, 0.0), + ("exceed_range_single", -1.1, 0.0), + ("exceed_range_tuple", (-1.1, 0.0), 0.0), + ) + def test_random_translation_with_bad_factor( + self, height_factor, width_factor + ): + with self.assertRaises(ValueError): + self.run_layer_test( + layers.RandomTranslation, + init_kwargs={ + "height_factor": height_factor, + "width_factor": width_factor, + }, + input_shape=(2, 3, 4), + expected_output_shape=(2, 3, 4), + supports_masking=False, + run_training_check=False, + ) + + def test_random_translation_with_inference_mode(self): + input_data = np.random.random((1, 4, 4, 3)) + expected_output = input_data + layer = layers.RandomTranslation(0.2, 0.1) + output = layer(input_data, training=False) + self.assertAllClose(output, expected_output) + + @parameterized.parameters(["channels_first", "channels_last"]) + def test_random_translation_up_numeric_reflect(self, data_format): + input_image = np.arange(0, 25) + expected_output = np.asarray( + [ + [5, 6, 7, 8, 9], + [10, 11, 12, 13, 14], + [15, 16, 17, 18, 19], + [20, 21, 22, 23, 24], + [20, 21, 22, 23, 24], + ] + ) + if data_format == "channels_last": + input_image = np.reshape(input_image, (1, 5, 5, 1)) + expected_output = backend.convert_to_tensor( + np.reshape(expected_output, (1, 5, 5, 1)) + ) + else: + input_image = np.reshape(input_image, (1, 1, 5, 5)) + expected_output = backend.convert_to_tensor( + np.reshape(expected_output, (1, 1, 5, 5)) + ) + self.run_layer_test( + layers.RandomTranslation, + init_kwargs={ + "height_factor": (-0.2, -0.2), + "width_factor": 0.0, + "data_format": data_format, + }, + input_shape=None, + input_data=input_image, + expected_output=expected_output, + supports_masking=False, + run_training_check=False, + ) + + @parameterized.parameters(["channels_first", "channels_last"]) + def test_random_translation_up_numeric_constant(self, data_format): + input_image = np.arange(0, 25).astype("float32") + # Shifting by -.2 * 5 = 1 pixel. + expected_output = np.asarray( + [ + [5, 6, 7, 8, 9], + [10, 11, 12, 13, 14], + [15, 16, 17, 18, 19], + [20, 21, 22, 23, 24], + [0, 0, 0, 0, 0], + ] + ) + if data_format == "channels_last": + input_image = np.reshape(input_image, (1, 5, 5, 1)) + expected_output = backend.convert_to_tensor( + np.reshape(expected_output, (1, 5, 5, 1)), dtype="float32" + ) + else: + input_image = np.reshape(input_image, (1, 1, 5, 5)) + expected_output = backend.convert_to_tensor( + np.reshape(expected_output, (1, 1, 5, 5)), dtype="float32" + ) + self.run_layer_test( + layers.RandomTranslation, + init_kwargs={ + "height_factor": (-0.2, -0.2), + "width_factor": 0.0, + "fill_mode": "constant", + "data_format": data_format, + }, + input_shape=None, + input_data=input_image, + expected_output=expected_output, + supports_masking=False, + run_training_check=False, + ) + + @parameterized.parameters(["channels_first", "channels_last"]) + def test_random_translation_down_numeric_reflect(self, data_format): + input_image = np.arange(0, 25) + # Shifting by .2 * 5 = 1 pixel. + expected_output = np.asarray( + [ + [0, 1, 2, 3, 4], + [0, 1, 2, 3, 4], + [5, 6, 7, 8, 9], + [10, 11, 12, 13, 14], + [15, 16, 17, 18, 19], + ] + ) + if data_format == "channels_last": + input_image = np.reshape(input_image, (1, 5, 5, 1)) + expected_output = backend.convert_to_tensor( + np.reshape(expected_output, (1, 5, 5, 1)) + ) + else: + input_image = np.reshape(input_image, (1, 1, 5, 5)) + expected_output = backend.convert_to_tensor( + np.reshape(expected_output, (1, 1, 5, 5)) + ) + self.run_layer_test( + layers.RandomTranslation, + init_kwargs={ + "height_factor": (0.2, 0.2), + "width_factor": 0.0, + "data_format": data_format, + }, + input_shape=None, + input_data=input_image, + expected_output=expected_output, + supports_masking=False, + run_training_check=False, + ) + + @parameterized.parameters(["channels_first", "channels_last"]) + def test_random_translation_asymmetric_size_numeric_reflect( + self, data_format + ): + input_image = np.arange(0, 16) + # Shifting by .2 * 5 = 1 pixel. + expected_output = np.asarray( + [ + [6, 7], + [4, 5], + [2, 3], + [0, 1], + [0, 1], + [2, 3], + [4, 5], + [6, 7], + ] + ) + if data_format == "channels_last": + input_image = np.reshape(input_image, (1, 8, 2, 1)) + expected_output = backend.convert_to_tensor( + np.reshape(expected_output, (1, 8, 2, 1)) + ) + else: + input_image = np.reshape(input_image, (1, 1, 8, 2)) + expected_output = backend.convert_to_tensor( + np.reshape(expected_output, (1, 1, 8, 2)) + ) + self.run_layer_test( + layers.RandomTranslation, + init_kwargs={ + "height_factor": (0.5, 0.5), + "width_factor": 0.0, + "data_format": data_format, + }, + input_shape=None, + input_data=input_image, + expected_output=expected_output, + supports_masking=False, + run_training_check=False, + ) + + @parameterized.parameters(["channels_first", "channels_last"]) + def test_random_translation_down_numeric_constant(self, data_format): + input_image = np.arange(0, 25) + # Shifting by .2 * 5 = 1 pixel. + expected_output = np.asarray( + [ + [0, 0, 0, 0, 0], + [0, 1, 2, 3, 4], + [5, 6, 7, 8, 9], + [10, 11, 12, 13, 14], + [15, 16, 17, 18, 19], + ] + ) + if data_format == "channels_last": + input_image = np.reshape(input_image, (1, 5, 5, 1)) + expected_output = backend.convert_to_tensor( + np.reshape(expected_output, (1, 5, 5, 1)) + ) + else: + input_image = np.reshape(input_image, (1, 1, 5, 5)) + expected_output = backend.convert_to_tensor( + np.reshape(expected_output, (1, 1, 5, 5)) + ) + self.run_layer_test( + layers.RandomTranslation, + init_kwargs={ + "height_factor": (0.2, 0.2), + "width_factor": 0.0, + "fill_mode": "constant", + "fill_value": 0.0, + "data_format": data_format, + }, + input_shape=None, + input_data=input_image, + expected_output=expected_output, + supports_masking=False, + run_training_check=False, + ) + + @parameterized.parameters(["channels_first", "channels_last"]) + def test_random_translation_left_numeric_reflect(self, data_format): + input_image = np.arange(0, 25) + # Shifting by .2 * 5 = 1 pixel. + expected_output = np.asarray( + [ + [1, 2, 3, 4, 4], + [6, 7, 8, 9, 9], + [11, 12, 13, 14, 14], + [16, 17, 18, 19, 19], + [21, 22, 23, 24, 24], + ] + ) + if data_format == "channels_last": + input_image = np.reshape(input_image, (1, 5, 5, 1)) + expected_output = backend.convert_to_tensor( + np.reshape(expected_output, (1, 5, 5, 1)) + ) + else: + input_image = np.reshape(input_image, (1, 1, 5, 5)) + expected_output = backend.convert_to_tensor( + np.reshape(expected_output, (1, 1, 5, 5)) + ) + self.run_layer_test( + layers.RandomTranslation, + init_kwargs={ + "height_factor": 0.0, + "width_factor": (-0.2, -0.2), + "data_format": data_format, + }, + input_shape=None, + input_data=input_image, + expected_output=expected_output, + supports_masking=False, + run_training_check=False, + ) + + @parameterized.parameters(["channels_first", "channels_last"]) + def test_random_translation_left_numeric_constant(self, data_format): + input_image = np.arange(0, 25) + # Shifting by .2 * 5 = 1 pixel. + expected_output = np.asarray( + [ + [1, 2, 3, 4, 0], + [6, 7, 8, 9, 0], + [11, 12, 13, 14, 0], + [16, 17, 18, 19, 0], + [21, 22, 23, 24, 0], + ] + ) + if data_format == "channels_last": + input_image = np.reshape(input_image, (1, 5, 5, 1)) + expected_output = backend.convert_to_tensor( + np.reshape(expected_output, (1, 5, 5, 1)) + ) + else: + input_image = np.reshape(input_image, (1, 1, 5, 5)) + expected_output = backend.convert_to_tensor( + np.reshape(expected_output, (1, 1, 5, 5)) + ) + self.run_layer_test( + layers.RandomTranslation, + init_kwargs={ + "height_factor": 0.0, + "width_factor": (-0.2, -0.2), + "fill_mode": "constant", + "fill_value": 0.0, + "data_format": data_format, + }, + input_shape=None, + input_data=input_image, + expected_output=expected_output, + supports_masking=False, + run_training_check=False, + ) + + def test_tf_data_compatibility(self): + layer = layers.RandomTranslation(0.2, 0.1) + input_data = np.random.random((1, 4, 4, 3)) + ds = tf_data.Dataset.from_tensor_slices(input_data).batch(1).map(layer) + next(iter(ds)).numpy() + + @parameterized.named_parameters( + ( + "with_positive_shift", + [[1.0, 2.0]], + [[3.0, 3.0, 5.0, 5.0], [7.0, 6.0, 8.0, 8.0]], + ), + ( + "with_negative_shift", + [[-1.0, -2.0]], + [[1.0, 0.0, 3.0, 1.0], [5.0, 2.0, 7.0, 4.0]], + ), + ) + def test_random_flip_bounding_boxes(self, translation, expected_boxes): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + image_shape = (10, 8, 3) + else: + image_shape = (3, 10, 8) + input_image = np.random.random(image_shape) + bounding_boxes = { + "boxes": np.array( + [ + [2, 1, 4, 3], + [6, 4, 8, 6], + ] + ), + "labels": np.array([[1, 2]]), + } + input_data = {"images": input_image, "bounding_boxes": bounding_boxes} + random_translation_layer = layers.RandomTranslation( + height_factor=0.5, + width_factor=0.5, + data_format=data_format, + seed=42, + bounding_box_format="xyxy", + ) + + transformation = { + "translations": backend_utils.convert_tf_tensor( + np.array(translation) + ), + "input_shape": image_shape, + } + output = random_translation_layer.transform_bounding_boxes( + input_data["bounding_boxes"], + transformation=transformation, + training=True, + ) + + self.assertAllClose(output["boxes"], expected_boxes) + + @parameterized.named_parameters( + ( + "with_positive_shift", + [[1.0, 2.0]], + [[3.0, 3.0, 5.0, 5.0], [7.0, 6.0, 8.0, 8.0]], + ), + ( + "with_negative_shift", + [[-1.0, -2.0]], + [[1.0, 0.0, 3.0, 1.0], [5.0, 2.0, 7.0, 4.0]], + ), + ) + def test_random_flip_tf_data_bounding_boxes( + self, translation, expected_boxes + ): + data_format = backend.config.image_data_format() + if backend.config.image_data_format() == "channels_last": + image_shape = (1, 10, 8, 3) + else: + image_shape = (1, 3, 10, 8) + input_image = np.random.random(image_shape) + bounding_boxes = { + "boxes": np.array( + [ + [ + [2, 1, 4, 3], + [6, 4, 8, 6], + ] + ] + ), + "labels": np.array([[1, 2]]), + } + + input_data = {"images": input_image, "bounding_boxes": bounding_boxes} + + ds = tf_data.Dataset.from_tensor_slices(input_data) + random_translation_layer = layers.RandomTranslation( + height_factor=0.5, + width_factor=0.5, + data_format=data_format, + seed=42, + bounding_box_format="xyxy", + ) + + transformation = { + "translations": np.array(translation), + "input_shape": image_shape, + } + + ds = ds.map( + lambda x: random_translation_layer.transform_bounding_boxes( + x["bounding_boxes"], + transformation=transformation, + training=True, + ) + ) + + output = next(iter(ds)) + expected_boxes = np.array(expected_boxes) + self.assertAllClose(output["boxes"], expected_boxes) diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_zoom.py b/keras/src/layers/preprocessing/image_preprocessing/random_zoom.py new file mode 100644 index 000000000000..0fe9ca82713d --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_zoom.py @@ -0,0 +1,430 @@ +from keras.src import backend +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 + BaseImagePreprocessingLayer, +) +from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import ( # noqa: E501 + clip_to_image_size, +) +from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import ( # noqa: E501 + convert_format, +) +from keras.src.random.seed_generator import SeedGenerator +from keras.src.utils import backend_utils + + +@keras_export("keras.layers.RandomZoom") +class RandomZoom(BaseImagePreprocessingLayer): + """A preprocessing layer which randomly zooms images during training. + + This layer will randomly zoom in or out on each axis of an image + independently, filling empty space according to `fill_mode`. + + Input pixel values can be of any range (e.g. `[0., 1.)` or `[0, 255]`) and + of integer or floating point dtype. + By default, the layer will output floats. + + **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline + (independently of which backend you're using). + + Input shape: + 3D (unbatched) or 4D (batched) tensor with shape: + `(..., height, width, channels)`, in `"channels_last"` format, + or `(..., channels, height, width)`, in `"channels_first"` format. + + Output shape: + 3D (unbatched) or 4D (batched) tensor with shape: + `(..., target_height, target_width, channels)`, + or `(..., channels, target_height, target_width)`, + in `"channels_first"` format. + + Args: + height_factor: a float represented as fraction of value, or a tuple of + size 2 representing lower and upper bound for zooming vertically. + When represented as a single float, this value is used for both the + upper and lower bound. A positive value means zooming out, while a + negative value means zooming in. For instance, + `height_factor=(0.2, 0.3)` result in an output zoomed out by a + random amount in the range `[+20%, +30%]`. + `height_factor=(-0.3, -0.2)` result in an output zoomed in by a + random amount in the range `[+20%, +30%]`. + width_factor: a float represented as fraction of value, or a tuple of + size 2 representing lower and upper bound for zooming horizontally. + When represented as a single float, this value is used for both the + upper and lower bound. For instance, `width_factor=(0.2, 0.3)` + result in an output zooming out between 20% to 30%. + `width_factor=(-0.3, -0.2)` result in an output zooming in between + 20% to 30%. `None` means i.e., zooming vertical and horizontal + directions by preserving the aspect ratio. Defaults to `None`. + fill_mode: Points outside the boundaries of the input are filled + according to the given mode. Available methods are `"constant"`, + `"nearest"`, `"wrap"` and `"reflect"`. Defaults to `"reflect"`. + - `"reflect"`: `(d c b a | a b c d | d c b a)` + The input is extended by reflecting about the edge of the last + pixel. + - `"constant"`: `(k k k k | a b c d | k k k k)` + The input is extended by filling all values beyond + the edge with the same constant value k specified by + `fill_value`. + - `"wrap"`: `(a b c d | a b c d | a b c d)` + The input is extended by wrapping around to the opposite edge. + - `"nearest"`: `(a a a a | a b c d | d d d d)` + The input is extended by the nearest pixel. + Note that when using torch backend, `"reflect"` is redirected to + `"mirror"` `(c d c b | a b c d | c b a b)` because torch does not + support `"reflect"`. + Note that torch backend does not support `"wrap"`. + interpolation: Interpolation mode. Supported values: `"nearest"`, + `"bilinear"`. + seed: Integer. Used to create a random seed. + fill_value: a float that represents the value to be filled outside + the boundaries when `fill_mode="constant"`. + data_format: string, either `"channels_last"` or `"channels_first"`. + The ordering of the dimensions in the inputs. `"channels_last"` + corresponds to inputs with shape `(batch, height, width, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch, channels, height, width)`. It defaults to the + `image_data_format` value found in your Keras config file at + `~/.keras/keras.json`. If you never set it, then it will be + `"channels_last"`. + **kwargs: Base layer keyword arguments, such as `name` and `dtype`. + + Example: + + >>> input_img = np.random.random((32, 224, 224, 3)) + >>> layer = keras.layers.RandomZoom(.5, .2) + >>> out_img = layer(input_img) + """ + + _USE_BASE_FACTOR = False + _FACTOR_VALIDATION_ERROR = ( + "The `height_factor` and `width_factor` arguments " + "should be a number (or a list of two numbers) " + "in the range [-1.0, 1.0]. " + ) + _SUPPORTED_FILL_MODE = ("reflect", "wrap", "constant", "nearest") + _SUPPORTED_INTERPOLATION = ("nearest", "bilinear") + + def __init__( + self, + height_factor, + width_factor=None, + fill_mode="reflect", + interpolation="bilinear", + seed=None, + fill_value=0.0, + data_format=None, + **kwargs, + ): + super().__init__(**kwargs) + self.height_factor = height_factor + self.height_lower, self.height_upper = self._set_factor( + height_factor, "height_factor" + ) + self.width_factor = width_factor + if width_factor is not None: + self.width_lower, self.width_upper = self._set_factor( + width_factor, "width_factor" + ) + if fill_mode not in self._SUPPORTED_FILL_MODE: + raise NotImplementedError( + f"Unknown `fill_mode` {fill_mode}. Expected of one " + f"{self._SUPPORTED_FILL_MODE}." + ) + if interpolation not in self._SUPPORTED_INTERPOLATION: + raise NotImplementedError( + f"Unknown `interpolation` {interpolation}. Expected of one " + f"{self._SUPPORTED_INTERPOLATION}." + ) + + self.fill_mode = fill_mode + self.fill_value = fill_value + self.interpolation = interpolation + self.seed = seed + self.generator = SeedGenerator(seed) + self.data_format = backend.standardize_data_format(data_format) + self.supports_jit = False + + def _set_factor(self, factor, factor_name): + if isinstance(factor, (tuple, list)): + if len(factor) != 2: + raise ValueError( + self._FACTOR_VALIDATION_ERROR + + f"Received: {factor_name}={factor}" + ) + self._check_factor_range(factor[0]) + self._check_factor_range(factor[1]) + lower, upper = sorted(factor) + elif isinstance(factor, (int, float)): + self._check_factor_range(factor) + factor = abs(factor) + lower, upper = [-factor, factor] + else: + raise ValueError( + self._FACTOR_VALIDATION_ERROR + + f"Received: {factor_name}={factor}" + ) + return lower, upper + + def _check_factor_range(self, input_number): + if input_number > 1.0 or input_number < -1.0: + raise ValueError( + self._FACTOR_VALIDATION_ERROR + + f"Received: input_number={input_number}" + ) + + def transform_images(self, images, transformation, training=True): + images = self.backend.cast(images, self.compute_dtype) + if training: + return self._zoom_inputs(images, transformation) + return images + + def transform_labels(self, labels, transformation, training=True): + return labels + + def get_transformed_x_y(self, x, y, transform): + a0, a1, a2, b0, b1, b2, c0, c1 = self.backend.numpy.split( + transform, 8, axis=-1 + ) + + k = c0 * x + c1 * y + 1 + x_transformed = (a0 * x + a1 * y + a2) / k + y_transformed = (b0 * x + b1 * y + b2) / k + return x_transformed, y_transformed + + def get_clipped_bbox(self, bounding_boxes, h_end, h_start, w_end, w_start): + bboxes = bounding_boxes["boxes"] + x1, y1, x2, y2 = self.backend.numpy.split(bboxes, 4, axis=-1) + + if len(bboxes.shape) == 3: + h_end = self.backend.numpy.expand_dims(h_end, -1) + h_start = self.backend.numpy.expand_dims(h_start, -1) + w_end = self.backend.numpy.expand_dims(w_end, -1) + w_start = self.backend.numpy.expand_dims(w_start, -1) + + x1 = self.backend.numpy.clip(x1, w_start, w_end) - w_start + y1 = self.backend.numpy.clip(y1, h_start, h_end) - h_start + x2 = self.backend.numpy.clip(x2, w_start, w_end) - w_start + y2 = self.backend.numpy.clip(y2, h_start, h_end) - h_start + bounding_boxes["boxes"] = self.backend.numpy.concatenate( + [x1, y1, x2, y2], axis=-1 + ) + return bounding_boxes + + def transform_bounding_boxes( + self, + bounding_boxes, + transformation, + training=True, + ): + if training: + if backend_utils.in_tf_graph(): + self.backend.set_backend("tensorflow") + + width_zoom = transformation["width_zoom"] + height_zoom = transformation["height_zoom"] + inputs_shape = transformation["input_shape"] + + if self.data_format == "channels_first": + height = inputs_shape[-2] + width = inputs_shape[-1] + else: + height = inputs_shape[-3] + width = inputs_shape[-2] + + bounding_boxes = convert_format( + bounding_boxes, + source=self.bounding_box_format, + target="xyxy", + height=height, + width=width, + ) + + zooms = self.backend.cast( + self.backend.numpy.concatenate( + [width_zoom, height_zoom], axis=1 + ), + dtype="float32", + ) + transform = self._get_zoom_matrix(zooms, height, width) + + w_start, h_start = self.get_transformed_x_y( + 0, + 0, + transform, + ) + + w_end, h_end = self.get_transformed_x_y( + width, + height, + transform, + ) + + bounding_boxes = self.get_clipped_bbox( + bounding_boxes, h_end, h_start, w_end, w_start + ) + + height_transformed = h_end - h_start + width_transformed = w_end - w_start + + height_transformed = self.backend.numpy.expand_dims( + height_transformed, -1 + ) + width_transformed = self.backend.numpy.expand_dims( + width_transformed, -1 + ) + + bounding_boxes = convert_format( + bounding_boxes, + source="xyxy", + target="rel_xyxy", + height=height_transformed, + width=width_transformed, + ) + + bounding_boxes = clip_to_image_size( + bounding_boxes=bounding_boxes, + height=height_transformed, + width=width_transformed, + bounding_box_format="rel_xyxy", + ) + + bounding_boxes = convert_format( + bounding_boxes, + source="rel_xyxy", + target=self.bounding_box_format, + height=height, + width=width, + ) + + self.backend.reset() + + return bounding_boxes + + def transform_segmentation_masks( + self, segmentation_masks, transformation, training=True + ): + return self.transform_images( + segmentation_masks, transformation, training=training + ) + + def get_random_transformation(self, data, training=True, seed=None): + if not training: + return None + if isinstance(data, dict): + images = data["images"] + else: + images = data + images_shape = self.backend.shape(images) + if len(images_shape) == 4: + zoom_factor_shape = (images_shape[0], 1) + else: + zoom_factor_shape = (1, 1) + + if not training: + return { + "height_zoom": self.backend.numpy.zeros(zoom_factor_shape), + "width_zoom": self.backend.numpy.zeros(zoom_factor_shape), + } + if seed is None: + seed = self._get_seed_generator(self.backend._backend) + + height_zoom = self.backend.random.uniform( + minval=1.0 + self.height_lower, + maxval=1.0 + self.height_upper, + shape=zoom_factor_shape, + seed=seed, + ) + if self.width_factor is not None: + width_zoom = self.backend.random.uniform( + minval=1.0 + self.width_lower, + maxval=1.0 + self.width_upper, + shape=zoom_factor_shape, + seed=seed, + ) + else: + width_zoom = height_zoom + return { + "height_zoom": height_zoom, + "width_zoom": width_zoom, + "input_shape": images_shape, + } + + def _zoom_inputs(self, inputs, transformation): + if transformation is None: + return inputs + + width_zoom = transformation["width_zoom"] + height_zoom = transformation["height_zoom"] + zooms = self.backend.cast( + self.backend.numpy.concatenate([width_zoom, height_zoom], axis=1), + dtype="float32", + ) + + inputs_shape = self.backend.shape(inputs) + unbatched = len(inputs_shape) == 3 + if unbatched: + inputs = self.backend.numpy.expand_dims(inputs, axis=0) + inputs_shape = self.backend.shape(inputs) + if self.data_format == "channels_first": + height = inputs_shape[-2] + width = inputs_shape[-1] + else: + height = inputs_shape[-3] + width = inputs_shape[-2] + + outputs = self.backend.image.affine_transform( + inputs, + transform=self._get_zoom_matrix(zooms, height, width), + interpolation=self.interpolation, + fill_mode=self.fill_mode, + fill_value=self.fill_value, + data_format=self.data_format, + ) + + if unbatched: + outputs = self.backend.numpy.squeeze(outputs, axis=0) + return outputs + + def _get_zoom_matrix(self, zooms, image_height, image_width): + num_zooms = self.backend.shape(zooms)[0] + # The zoom matrix looks like: + # [[zx 0 0] + # [0 zy 0] + # [0 0 1]] + # where the last entry is implicit. + # zoom matrices are always float32. + x_offset = ((self.backend.cast(image_width, "float32") - 1.0) / 2.0) * ( + 1.0 - zooms[:, 0:1] + ) + y_offset = ( + (self.backend.cast(image_height, "float32") - 1.0) / 2.0 + ) * (1.0 - zooms[:, 1:]) + return self.backend.numpy.concatenate( + [ + zooms[:, 0:1], + self.backend.numpy.zeros((num_zooms, 1)), + x_offset, + self.backend.numpy.zeros((num_zooms, 1)), + zooms[:, 1:], + y_offset, + self.backend.numpy.zeros((num_zooms, 2)), + ], + axis=1, + ) + + def compute_output_shape(self, input_shape): + return input_shape + + def get_config(self): + base_config = super().get_config() + config = { + "height_factor": self.height_factor, + "width_factor": self.width_factor, + "fill_mode": self.fill_mode, + "interpolation": self.interpolation, + "seed": self.seed, + "fill_value": self.fill_value, + "data_format": self.data_format, + } + return {**base_config, **config} diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_zoom_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_zoom_test.py new file mode 100644 index 000000000000..96407e960c60 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_zoom_test.py @@ -0,0 +1,269 @@ +import numpy as np +import pytest +from absl.testing import parameterized +from tensorflow import data as tf_data + +from keras.src import backend +from keras.src import layers +from keras.src import models +from keras.src import testing +from keras.src.utils import backend_utils + + +class RandomZoomTest(testing.TestCase): + @parameterized.named_parameters( + ("random_zoom_in_4_by_6", -0.4, -0.6), + ("random_zoom_in_2_by_3", -0.2, -0.3), + ("random_zoom_in_tuple_factor", (-0.4, -0.5), (-0.2, -0.3)), + ("random_zoom_out_4_by_6", 0.4, 0.6), + ("random_zoom_out_2_by_3", 0.2, 0.3), + ("random_zoom_out_tuple_factor", (0.4, 0.5), (0.2, 0.3)), + ) + def test_random_zoom(self, height_factor, width_factor): + self.run_layer_test( + layers.RandomZoom, + init_kwargs={ + "height_factor": height_factor, + "width_factor": width_factor, + }, + input_shape=(2, 3, 4), + expected_output_shape=(2, 3, 4), + supports_masking=False, + run_training_check=False, + ) + + def test_random_zoom_out_correctness(self): + if backend.config.image_data_format() == "channels_last": + input_shape = (1, 5, 5, 1) + else: + input_shape = (1, 1, 5, 5) + input_image = np.reshape(np.arange(0, 25), input_shape) + expected_output = np.asarray( + [ + [0, 0, 0, 0, 0], + [0, 2.7, 4.5, 6.3, 0], + [0, 10.2, 12.0, 13.8, 0], + [0, 17.7, 19.5, 21.3, 0], + [0, 0, 0, 0, 0], + ] + ) + expected_output = backend.convert_to_tensor( + np.reshape(expected_output, input_shape) + ) + self.run_layer_test( + layers.RandomZoom, + init_kwargs={ + "height_factor": (0.5, 0.5), + "width_factor": (0.8, 0.8), + "interpolation": "bilinear", + "fill_mode": "constant", + }, + input_shape=None, + input_data=input_image, + expected_output=expected_output, + supports_masking=False, + run_training_check=False, + ) + + def test_random_zoom_in_correctness(self): + if backend.config.image_data_format() == "channels_last": + input_shape = (1, 5, 5, 1) + else: + input_shape = (1, 1, 5, 5) + input_image = np.reshape(np.arange(0, 25), input_shape) + expected_output = np.asarray( + [ + [6.0, 6.5, 7.0, 7.5, 8.0], + [8.5, 9.0, 9.5, 10.0, 10.5], + [11.0, 11.5, 12.0, 12.5, 13.0], + [13.5, 14.0, 14.5, 15.0, 15.5], + [16.0, 16.5, 17.0, 17.5, 18.0], + ] + ) + expected_output = backend.convert_to_tensor( + np.reshape(expected_output, input_shape) + ) + self.run_layer_test( + layers.RandomZoom, + init_kwargs={ + "height_factor": (-0.5, -0.5), + "width_factor": (-0.5, -0.5), + "interpolation": "bilinear", + "fill_mode": "constant", + }, + input_shape=None, + input_data=input_image, + expected_output=expected_output, + supports_masking=False, + run_training_check=False, + ) + + def test_tf_data_compatibility(self): + if backend.config.image_data_format() == "channels_last": + input_shape = (1, 5, 5, 1) + else: + input_shape = (1, 1, 5, 5) + input_image = np.reshape(np.arange(0, 25), input_shape) + layer = layers.RandomZoom( + height_factor=(0.5, 0.5), + width_factor=(0.8, 0.8), + interpolation="nearest", + fill_mode="constant", + ) + ds = tf_data.Dataset.from_tensor_slices(input_image).batch(1).map(layer) + expected_output = np.asarray( + [ + [0, 0, 0, 0, 0], + [0, 5, 7, 9, 0], + [0, 10, 12, 14, 0], + [0, 20, 22, 24, 0], + [0, 0, 0, 0, 0], + ] + ).reshape(input_shape) + output = next(iter(ds)).numpy() + self.assertAllClose(expected_output, output) + + def test_dynamic_shape(self): + inputs = layers.Input((None, None, 3)) + outputs = layers.RandomZoom( + height_factor=(0.5, 0.5), + width_factor=(0.8, 0.8), + interpolation="nearest", + fill_mode="constant", + )(inputs) + model = models.Model(inputs, outputs) + model.predict(np.random.random((1, 6, 6, 3))) + + @pytest.mark.skipif( + backend.backend() == "numpy", + reason="The NumPy backend does not implement fit.", + ) + def test_connect_with_flatten(self): + model = models.Sequential( + [ + layers.RandomZoom((-0.5, 0.0), (-0.5, 0.0)), + layers.Flatten(), + layers.Dense(1, activation="relu"), + ], + ) + + model.compile(loss="mse") + model.fit(np.random.random((2, 2, 2, 1)), y=np.random.random((2,))) + + @parameterized.named_parameters( + ( + "with_zoom_in", + [[[0.1]], [[0.1]]], + [[[0.0, 0.0, 8.0, 0.0], [8.0, 0.0, 8.0, 10.0]]], + ), + ( + "with_zoom_out", + [[[1.9]], [[1.9]]], + [ + [ + [2.710526, 2.657895, 3.763158, 3.710526], + [4.815789, 4.236842, 5.868421, 5.289474], + ] + ], + ), + ) + def test_random_flip_bounding_boxes(self, zoom, expected_boxes): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + image_shape = (10, 8, 3) + else: + image_shape = (3, 10, 8) + input_image = np.random.random(image_shape) + bounding_boxes = { + "boxes": np.array( + [ + [2, 1, 4, 3], + [6, 4, 8, 6], + ] + ), + "labels": np.array([[1, 2]]), + } + input_data = {"images": input_image, "bounding_boxes": bounding_boxes} + random_zoom_layer = layers.RandomZoom( + height_factor=(0.5, 0.5), + data_format=data_format, + seed=42, + bounding_box_format="xyxy", + ) + + transformation = { + "height_zoom": backend_utils.convert_tf_tensor(np.array(zoom[0])), + "width_zoom": backend_utils.convert_tf_tensor(np.array(zoom[1])), + "input_shape": image_shape, + } + output = random_zoom_layer.transform_bounding_boxes( + input_data["bounding_boxes"], + transformation=transformation, + training=True, + ) + + self.assertAllClose(output["boxes"], expected_boxes) + + @parameterized.named_parameters( + ( + "with_zoom_in", + [[[0.1]], [[0.1]]], + [[[0.0, 0.0, 8.0, 0.0], [8.0, 0.0, 8.0, 10.0]]], + ), + ( + "with_zoom_out", + [[[1.9]], [[1.9]]], + [ + [ + [2.710526, 2.657895, 3.763158, 3.710526], + [4.815789, 4.236842, 5.868421, 5.289474], + ] + ], + ), + ) + def test_random_flip_tf_data_bounding_boxes(self, zoom, expected_boxes): + data_format = backend.config.image_data_format() + if backend.config.image_data_format() == "channels_last": + image_shape = (1, 10, 8, 3) + else: + image_shape = (1, 3, 10, 8) + input_image = np.random.random(image_shape) + bounding_boxes = { + "boxes": np.array( + [ + [ + [2, 1, 4, 3], + [6, 4, 8, 6], + ] + ] + ), + "labels": np.array([[1, 2]]), + } + + input_data = {"images": input_image, "bounding_boxes": bounding_boxes} + + ds = tf_data.Dataset.from_tensor_slices(input_data) + random_zoom_layer = layers.RandomZoom( + height_factor=0.5, + data_format=data_format, + seed=42, + bounding_box_format="xyxy", + ) + + transformation = { + "height_zoom": np.array(zoom[0]), + "width_zoom": np.array(zoom[1]), + "input_shape": image_shape, + } + + ds = ds.map( + lambda x: random_zoom_layer.transform_bounding_boxes( + x["bounding_boxes"], + transformation=transformation, + training=True, + ) + ) + + output = next(iter(ds)) + expected_boxes = np.array(expected_boxes) + self.assertAllClose(output["boxes"], expected_boxes) diff --git a/keras/src/layers/preprocessing/image_preprocessing/resizing.py b/keras/src/layers/preprocessing/image_preprocessing/resizing.py new file mode 100644 index 000000000000..83460175ee54 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/resizing.py @@ -0,0 +1,308 @@ +from keras.src import backend +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 + BaseImagePreprocessingLayer, +) +from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import ( # noqa: E501 + clip_to_image_size, +) +from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import ( # noqa: E501 + convert_format, +) +from keras.src.ops.core import _saturate_cast + + +@keras_export("keras.layers.Resizing") +class Resizing(BaseImagePreprocessingLayer): + """A preprocessing layer which resizes images. + + This layer resizes an image input to a target height and width. The input + should be a 4D (batched) or 3D (unbatched) tensor in `"channels_last"` + format. Input pixel values can be of any range + (e.g. `[0., 1.)` or `[0, 255]`). + + **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline + (independently of which backend you're using). + + Input shape: + 3D (unbatched) or 4D (batched) tensor with shape: + `(..., height, width, channels)`, in `"channels_last"` format, + or `(..., channels, height, width)`, in `"channels_first"` format. + + Output shape: + 3D (unbatched) or 4D (batched) tensor with shape: + `(..., target_height, target_width, channels)`, + or `(..., channels, target_height, target_width)`, + in `"channels_first"` format. + + Args: + height: Integer, the height of the output shape. + width: Integer, the width of the output shape. + interpolation: String, the interpolation method. + Supports `"bilinear"`, `"nearest"`, `"bicubic"`, + `"lanczos3"`, `"lanczos5"`. Defaults to `"bilinear"`. + crop_to_aspect_ratio: If `True`, resize the images without aspect + ratio distortion. When the original aspect ratio differs + from the target aspect ratio, the output image will be + cropped so as to return the + largest possible window in the image (of size `(height, width)`) + that matches the target aspect ratio. By default + (`crop_to_aspect_ratio=False`), aspect ratio may not be preserved. + pad_to_aspect_ratio: If `True`, pad the images without aspect + ratio distortion. When the original aspect ratio differs + from the target aspect ratio, the output image will be + evenly padded on the short side. + fill_mode: When using `pad_to_aspect_ratio=True`, padded areas + are filled according to the given mode. Only `"constant"` is + supported at this time + (fill with constant value, equal to `fill_value`). + fill_value: Float. Padding value to use when `pad_to_aspect_ratio=True`. + data_format: string, either `"channels_last"` or `"channels_first"`. + The ordering of the dimensions in the inputs. `"channels_last"` + corresponds to inputs with shape `(batch, height, width, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch, channels, height, width)`. It defaults to the + `image_data_format` value found in your Keras config file at + `~/.keras/keras.json`. If you never set it, then it will be + `"channels_last"`. + **kwargs: Base layer keyword arguments, such as `name` and `dtype`. + """ + + _USE_BASE_FACTOR = False + + def __init__( + self, + height, + width, + interpolation="bilinear", + crop_to_aspect_ratio=False, + pad_to_aspect_ratio=False, + fill_mode="constant", + fill_value=0.0, + antialias=False, + data_format=None, + **kwargs, + ): + super().__init__(**kwargs) + self.height = height + self.width = width + self.interpolation = interpolation + self.data_format = backend.standardize_data_format(data_format) + self.crop_to_aspect_ratio = crop_to_aspect_ratio + self.pad_to_aspect_ratio = pad_to_aspect_ratio + self.fill_mode = fill_mode + self.fill_value = fill_value + self.antialias = bool(antialias) + if self.data_format == "channels_first": + self.height_axis = -2 + self.width_axis = -1 + elif self.data_format == "channels_last": + self.height_axis = -3 + self.width_axis = -2 + + def transform_images(self, images, transformation=None, training=True): + size = (self.height, self.width) + resized = self.backend.image.resize( + images, + size=size, + interpolation=self.interpolation, + antialias=self.antialias, + data_format=self.data_format, + crop_to_aspect_ratio=self.crop_to_aspect_ratio, + pad_to_aspect_ratio=self.pad_to_aspect_ratio, + fill_mode=self.fill_mode, + fill_value=self.fill_value, + ) + if resized.dtype == images.dtype: + return resized + if backend.is_int_dtype(images.dtype): + resized = self.backend.numpy.round(resized) + return _saturate_cast(resized, images.dtype, self.backend) + + def transform_segmentation_masks( + self, segmentation_masks, transformation=None, training=True + ): + return self.transform_images(segmentation_masks) + + def transform_labels(self, labels, transformation=None, training=True): + return labels + + def get_random_transformation(self, data, training=True, seed=None): + if isinstance(data, dict): + input_shape = self.backend.shape(data["images"]) + else: + input_shape = self.backend.shape(data) + + input_height, input_width = ( + input_shape[self.height_axis], + input_shape[self.width_axis], + ) + + return input_height, input_width + + def transform_bounding_boxes( + self, + bounding_boxes, + transformation, + training=True, + ): + ops = self.backend + input_height, input_width = transformation + mask_negative_1s = ops.numpy.all(bounding_boxes["boxes"] == -1, axis=-1) + mask_zeros = ops.numpy.all(bounding_boxes["boxes"] == 0, axis=-1) + boxes_mask = ops.numpy.logical_or(mask_negative_1s, mask_zeros) + + bounding_boxes = convert_format( + bounding_boxes, + source=self.bounding_box_format, + target="xyxy", + height=input_height, + width=input_width, + ) + + bounding_boxes["boxes"] = self._transform_xyxy( + bounding_boxes["boxes"], + input_height=input_height, + input_width=input_width, + ) + + bounding_boxes = clip_to_image_size( + bounding_boxes=bounding_boxes, + height=self.height, + width=self.width, + ) + + bounding_boxes["boxes"] = ops.numpy.where( + ops.numpy.expand_dims(boxes_mask, axis=-1), + ops.convert_to_tensor( + [0.0, 0.0, 0.0, 0.0], dtype=bounding_boxes["boxes"].dtype + ), + bounding_boxes["boxes"], + ) + + bounding_boxes = convert_format( + bounding_boxes, + source="xyxy", + target=self.bounding_box_format, + height=self.height, + width=self.width, + ) + + return bounding_boxes + + def _transform_xyxy(self, boxes, input_height, input_width): + ops = self.backend + input_height = ops.cast(input_height, dtype=boxes.dtype) + input_width = ops.cast(input_width, dtype=boxes.dtype) + + if self.pad_to_aspect_ratio: + return self._transform_boxes_pad_to_aspect_ratio( + boxes, input_height, input_width + ) + elif self.crop_to_aspect_ratio: + return self._transform_boxes_crop_to_aspect_ratio( + boxes, input_height, input_width + ) + else: + return self._transform_boxes_stretch( + boxes, input_height, input_width + ) + + def _transform_boxes_pad_to_aspect_ratio( + self, boxes, input_height, input_width + ): + """Transforms bounding boxes for padding to aspect ratio.""" + ops = self.backend + height_ratio = ops.cast(self.height / input_height, dtype=boxes.dtype) + width_ratio = ops.cast(self.width / input_width, dtype=boxes.dtype) + min_aspect_ratio = ops.numpy.minimum(height_ratio, width_ratio) + y_offset = (self.height - input_height * min_aspect_ratio) // 2 + x_offset = (self.width - input_width * min_aspect_ratio) // 2 + return ops.numpy.stack( + [ + boxes[..., 0] * min_aspect_ratio + x_offset, + boxes[..., 1] * min_aspect_ratio + y_offset, + boxes[..., 2] * min_aspect_ratio + x_offset, + boxes[..., 3] * min_aspect_ratio + y_offset, + ], + axis=-1, + ) + + def _transform_boxes_crop_to_aspect_ratio( + self, boxes, input_height, input_width + ): + """Transforms bounding boxes for cropping to aspect ratio.""" + ops = self.backend + source_aspect_ratio = input_width / input_height + target_aspect_ratio = self.width / self.height + new_width = ops.numpy.where( + source_aspect_ratio > target_aspect_ratio, + self.height * source_aspect_ratio, + self.width, + ) + new_height = ops.numpy.where( + source_aspect_ratio > target_aspect_ratio, + self.height, + self.width / source_aspect_ratio, + ) + scale_x = new_width / input_width + scale_y = new_height / input_height + crop_left = (new_width - self.width) // 2 + crop_top = (new_height - self.height) // 2 + return ops.numpy.stack( + [ + boxes[..., 0] * scale_x - crop_left, + boxes[..., 1] * scale_y - crop_top, + boxes[..., 2] * scale_x - crop_left, + boxes[..., 3] * scale_y - crop_top, + ], + axis=-1, + ) + + def _transform_boxes_stretch(self, boxes, input_height, input_width): + """Transforms bounding boxes by simple stretching.""" + ops = self.backend + height_ratio = ops.cast(self.height / input_height, dtype=boxes.dtype) + width_ratio = ops.cast(self.width / input_width, dtype=boxes.dtype) + return ops.numpy.stack( + [ + boxes[..., 0] * width_ratio, + boxes[..., 1] * height_ratio, + boxes[..., 2] * width_ratio, + boxes[..., 3] * height_ratio, + ], + axis=-1, + ) + + def compute_output_shape(self, input_shape): + input_shape = list(input_shape) + if len(input_shape) == 4: + if self.data_format == "channels_last": + input_shape[1] = self.height + input_shape[2] = self.width + else: + input_shape[2] = self.height + input_shape[3] = self.width + else: + if self.data_format == "channels_last": + input_shape[0] = self.height + input_shape[1] = self.width + else: + input_shape[1] = self.height + input_shape[2] = self.width + return tuple(input_shape) + + def get_config(self): + base_config = super().get_config() + config = { + "height": self.height, + "width": self.width, + "interpolation": self.interpolation, + "crop_to_aspect_ratio": self.crop_to_aspect_ratio, + "pad_to_aspect_ratio": self.pad_to_aspect_ratio, + "fill_mode": self.fill_mode, + "fill_value": self.fill_value, + "antialias": self.antialias, + "data_format": self.data_format, + } + return {**base_config, **config} diff --git a/keras/src/layers/preprocessing/image_preprocessing/resizing_test.py b/keras/src/layers/preprocessing/image_preprocessing/resizing_test.py new file mode 100644 index 000000000000..38dfafbeaab0 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/resizing_test.py @@ -0,0 +1,324 @@ +import grain +import numpy as np +import pytest +from absl.testing import parameterized +from tensorflow import data as tf_data + +from keras.src import Sequential +from keras.src import backend +from keras.src import layers +from keras.src import testing +from keras.src.testing.test_utils import named_product + + +class ResizingTest(testing.TestCase): + @parameterized.named_parameters( + named_product( + interpolation=["nearest", "bilinear", "bicubic", "lanczos5"], + crop_pad=[(False, False), (True, False), (False, True)], + antialias=[False, True], + data_format=["channels_last", "channels_first"], + ) + ) + def test_resizing_basics( + self, + interpolation, + crop_pad, + antialias, + data_format, + ): + if interpolation == "lanczos5" and backend.backend() == "torch": + self.skipTest("Torch does not support lanczos.") + + crop_to_aspect_ratio, pad_to_aspect_ratio = crop_pad + if data_format == "channels_last": + input_shape = (2, 12, 12, 3) + expected_output_shape = (2, 6, 6, 3) + else: + input_shape = (2, 3, 12, 12) + expected_output_shape = (2, 3, 6, 6) + + self.run_layer_test( + layers.Resizing, + init_kwargs={ + "height": 6, + "width": 6, + "interpolation": interpolation, + "crop_to_aspect_ratio": crop_to_aspect_ratio, + "pad_to_aspect_ratio": pad_to_aspect_ratio, + "antialias": antialias, + "data_format": data_format, + }, + input_shape=input_shape, + expected_output_shape=expected_output_shape, + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=False, + run_training_check=False, + ) + + @parameterized.parameters([("channels_first",), ("channels_last",)]) + def test_down_sampling_numeric(self, data_format): + img = np.reshape(np.arange(0, 16), (1, 4, 4, 1)).astype(np.float32) + if data_format == "channels_first": + img = img.transpose(0, 3, 1, 2) + out = layers.Resizing( + height=2, width=2, interpolation="nearest", data_format=data_format + )(img) + ref_out = ( + np.asarray([[5, 7], [13, 15]]) + .astype(np.float32) + .reshape((1, 2, 2, 1)) + ) + if data_format == "channels_first": + ref_out = ref_out.transpose(0, 3, 1, 2) + self.assertAllClose(ref_out, out) + + @parameterized.parameters([("channels_first",), ("channels_last",)]) + def test_up_sampling_numeric(self, data_format): + img = np.reshape(np.arange(0, 4), (1, 2, 2, 1)).astype(np.float32) + if data_format == "channels_first": + img = img.transpose(0, 3, 1, 2) + out = layers.Resizing( + height=4, + width=4, + interpolation="nearest", + data_format=data_format, + )(img) + ref_out = ( + np.asarray([[0, 0, 1, 1], [0, 0, 1, 1], [2, 2, 3, 3], [2, 2, 3, 3]]) + .astype(np.float32) + .reshape((1, 4, 4, 1)) + ) + if data_format == "channels_first": + ref_out = ref_out.transpose(0, 3, 1, 2) + self.assertAllClose(ref_out, out) + + @parameterized.parameters([("channels_first",), ("channels_last",)]) + def test_crop_to_aspect_ratio(self, data_format): + img = np.reshape(np.arange(0, 16), (1, 4, 4, 1)).astype("float32") + if data_format == "channels_first": + img = img.transpose(0, 3, 1, 2) + out = layers.Resizing( + height=4, + width=2, + interpolation="nearest", + data_format=data_format, + crop_to_aspect_ratio=True, + )(img) + ref_out = ( + np.asarray( + [ + [1, 2], + [5, 6], + [9, 10], + [13, 14], + ] + ) + .astype("float32") + .reshape((1, 4, 2, 1)) + ) + if data_format == "channels_first": + ref_out = ref_out.transpose(0, 3, 1, 2) + self.assertAllClose(ref_out, out) + + @parameterized.parameters([("channels_first",), ("channels_last",)]) + def test_unbatched_image(self, data_format): + img = np.reshape(np.arange(0, 16), (4, 4, 1)).astype("float32") + if data_format == "channels_first": + img = img.transpose(2, 0, 1) + out = layers.Resizing( + 2, 2, interpolation="nearest", data_format=data_format + )(img) + ref_out = ( + np.asarray( + [ + [5, 7], + [13, 15], + ] + ) + .astype("float32") + .reshape((2, 2, 1)) + ) + if data_format == "channels_first": + ref_out = ref_out.transpose(2, 0, 1) + self.assertAllClose(ref_out, out) + + def test_tf_data_compatibility(self): + if backend.config.image_data_format() == "channels_last": + input_shape = (2, 10, 12, 3) + output_shape = (2, 8, 9, 3) + else: + input_shape = (2, 3, 10, 12) + output_shape = (2, 3, 8, 9) + layer = layers.Resizing(8, 9) + input_data = np.random.random(input_shape) + ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) + output = next(iter(ds)).numpy() + self.assertEqual(tuple(output.shape), output_shape) + + def test_grain_compatibility(self): + if backend.config.image_data_format() == "channels_last": + input_shape = (2, 10, 12, 3) + output_shape = (2, 8, 9, 3) + else: + input_shape = (2, 3, 10, 12) + output_shape = (2, 3, 8, 9) + layer = layers.Resizing(8, 9) + input_data = np.random.random(input_shape) + ds = ( + grain.MapDataset.source(input_data) + .to_iter_dataset() + .batch(2) + .map(layer) + ) + output = next(iter(ds)) + output_np = backend.convert_to_numpy(output) + + self.assertEqual(tuple(output_np.shape), output_shape) + self.assertTrue(backend.is_tensor(output)) + # Ensure the device of the data is on CPU. + if backend.backend() == "tensorflow": + self.assertIn("CPU", str(output.device)) + elif backend.backend() == "jax": + self.assertIn("CPU", str(output.device)) + elif backend.backend() == "torch": + self.assertEqual("cpu", str(output.device)) + + @pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="Sequential + tf.data only works with TF backend", + ) + def test_tf_data_compatibility_sequential(self): + # Test compatibility when wrapping in a Sequential + # https://github.com/keras-team/keras/issues/347 + if backend.config.image_data_format() == "channels_last": + input_shape = (2, 10, 12, 3) + output_shape = (2, 8, 9, 3) + else: + input_shape = (2, 3, 10, 12) + output_shape = (2, 3, 8, 9) + layer = layers.Resizing(8, 9) + input_data = np.random.random(input_shape) + ds = ( + tf_data.Dataset.from_tensor_slices(input_data) + .batch(2) + .map(Sequential([layer])) + ) + output = next(iter(ds)).numpy() + self.assertEqual(tuple(output.shape), output_shape) + + @parameterized.parameters( + [((15, 10), "channels_last"), ((15, 100), "channels_last")] + ) + def test_data_stretch(self, size, data_format): + img = np.random.rand(1, 1, 4, 4) + output = layers.Resizing( + size[0], size[1], data_format=data_format, crop_to_aspect_ratio=True + )(img) + self.assertEqual(output.shape, (1, *size, 4)) + + @parameterized.named_parameters( + ( + "with_pad_to_aspect_ratio", + True, + False, + [[6.0, 2.0, 10.0, 6.0], [14.0, 8.0, 18.0, 12.0]], + ), + ( + "with_crop_to_aspect_ratio", + False, + True, + [[5.0, 0.5, 10.0, 5.5], [15.0, 8.0, 20.0, 13.0]], + ), + ( + "boxes_stretch", + False, + False, + [[5.0, 2.0, 10.0, 6.0], [15.0, 8.0, 20.0, 12.0]], + ), + ) + def test_resize_bounding_boxes( + self, pad_to_aspect_ratio, crop_to_aspect_ratio, expected_boxes + ): + if backend.config.image_data_format() == "channels_last": + image_shape = (10, 8, 3) + else: + image_shape = (3, 10, 8) + input_image = np.random.random(image_shape) + bounding_boxes = { + "boxes": np.array( + [ + [2, 1, 4, 3], + [6, 4, 8, 6], + ] + ), # Example boxes (normalized) + "labels": np.array([[1, 2]]), # Dummy labels + } + input_data = {"images": input_image, "bounding_boxes": bounding_boxes} + resizing_layer = layers.Resizing( + height=20, + width=20, + pad_to_aspect_ratio=pad_to_aspect_ratio, + crop_to_aspect_ratio=crop_to_aspect_ratio, + bounding_box_format="xyxy", + ) + output = resizing_layer(input_data) + self.assertAllClose(output["bounding_boxes"]["boxes"], expected_boxes) + + @parameterized.named_parameters( + ( + "with_pad_to_aspect_ratio", + True, + False, + [[6.0, 2.0, 10.0, 6.0], [14.0, 8.0, 18.0, 12.0]], + ), + ( + "with_crop_to_aspect_ratio", + False, + True, + [[5.0, 0.5, 10.0, 5.5], [15.0, 8.0, 20.0, 13.0]], + ), + ( + "boxes_stretch", + False, + False, + [[5.0, 2.0, 10.0, 6.0], [15.0, 8.0, 20.0, 12.0]], + ), + ) + def test_resize_tf_data_bounding_boxes( + self, pad_to_aspect_ratio, crop_to_aspect_ratio, expected_boxes + ): + if backend.config.image_data_format() == "channels_last": + image_shape = (1, 10, 8, 3) + else: + image_shape = (1, 3, 10, 8) + input_image = np.random.random(image_shape) + bounding_boxes = { + "boxes": np.array( + [ + [ + [2, 1, 4, 3], + [6, 4, 8, 6], + ] + ] + ), # Example boxes (normalized) + "labels": np.array([[1, 2]]), # Dummy labels + } + + input_data = {"images": input_image, "bounding_boxes": bounding_boxes} + + ds = tf_data.Dataset.from_tensor_slices(input_data) + resizing_layer = layers.Resizing( + height=20, + width=20, + pad_to_aspect_ratio=pad_to_aspect_ratio, + crop_to_aspect_ratio=crop_to_aspect_ratio, + bounding_box_format="xyxy", + ) + ds = ds.map(resizing_layer) + output = next(iter(ds)) + expected_boxes = np.array(expected_boxes) + self.assertAllClose(output["bounding_boxes"]["boxes"], expected_boxes) diff --git a/keras/src/layers/preprocessing/image_preprocessing/solarization.py b/keras/src/layers/preprocessing/image_preprocessing/solarization.py new file mode 100644 index 000000000000..ae182f8e18fd --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/solarization.py @@ -0,0 +1,217 @@ +from keras.src import backend +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 + BaseImagePreprocessingLayer, +) +from keras.src.ops.core import _saturate_cast +from keras.src.random.seed_generator import SeedGenerator + + +@keras_export("keras.layers.Solarization") +class Solarization(BaseImagePreprocessingLayer): + """Applies `(max_value - pixel + min_value)` for each pixel in the image. + + When created without `threshold` parameter, the layer performs solarization + to all values. When created with specified `threshold` the layer only + augments pixels that are above the `threshold` value. + + **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline + (independently of which backend you're using). + + Args: + addition_factor: (Optional) A tuple of two floats or a single float, + between 0 and 1. + For each augmented image a value is + sampled from the provided range. If a float is passed, the range is + interpreted as `(0, addition_factor)`. If specified, this value + (times the value range of input images, e.g. 255), is + added to each pixel before solarization and thresholding. + Defaults to 0.0. + threshold_factor: (Optional) A tuple of two floats or a single float. + For each augmented image a value is + sampled from the provided range. If a float is passed, the range is + interpreted as `(0, threshold_factor)`. If specified, only pixel + values above this threshold will be solarized. + value_range: a tuple or a list of two elements. The first value + represents the lower bound for values in input images, the second + represents the upper bound. Images passed to the layer should have + values within `value_range`. Typical values to pass + are `(0, 255)` (RGB image) or `(0., 1.)` (scaled image). + seed: Integer. Used to create a random seed. + **kwargs: Base layer keyword arguments, such as `name` and `dtype`. + + Example: + + ```python + (images, labels), _ = keras.datasets.cifar10.load_data() + print(images[0, 0, 0]) + # [59 62 63] + # Note that images are Tensor with values in the range [0, 255] + solarization = Solarization(value_range=(0, 255)) + images = solarization(images) + print(images[0, 0, 0]) + # [196, 193, 192] + ``` + """ + + _USE_BASE_FACTOR = False + _VALUE_RANGE_VALIDATION_ERROR = ( + "The `value_range` argument should be a list of two numbers. " + ) + _FACTOR_VALIDATION_ERROR = ( + "The `addition_factor` and `threshold_factor` arguments " + "should be a number (or a list of two numbers) " + "in the range [0, 1]. " + ) + + def __init__( + self, + addition_factor=0.0, + threshold_factor=0.0, + value_range=(0, 255), + seed=None, + **kwargs, + ): + super().__init__(**kwargs) + self.seed = seed + self.generator = SeedGenerator(seed) + self.addition_factor = self._set_factor( + addition_factor, "addition_factor" + ) + self.threshold_factor = self._set_factor( + threshold_factor, "threshold_factor" + ) + self._set_value_range(value_range) + + def _set_value_range(self, value_range): + if not isinstance(value_range, (tuple, list)): + raise ValueError( + self._VALUE_RANGE_VALIDATION_ERROR + + f"Received: value_range={value_range}" + ) + if len(value_range) != 2: + raise ValueError( + self._VALUE_RANGE_VALIDATION_ERROR + + f"Received: value_range={value_range}" + ) + self.value_range = sorted(value_range) + + def _set_factor(self, factor, factor_name): + if isinstance(factor, (tuple, list)): + if len(factor) != 2: + raise ValueError( + self._FACTOR_VALIDATION_ERROR + + f"Received: {factor_name}={factor}" + ) + self._check_factor_range(factor[0]) + self._check_factor_range(factor[1]) + lower, upper = sorted(factor) + elif isinstance(factor, (int, float)): + self._check_factor_range(factor) + lower, upper = [0, factor] + else: + raise ValueError( + self._FACTOR_VALIDATION_ERROR + + f"Received: {factor_name}={factor}" + ) + return lower, upper + + def _check_factor_range(self, input_number): + if input_number > 1.0 or input_number < 0: + raise ValueError( + self._FACTOR_VALIDATION_ERROR + + f"Received: input_number={input_number}" + ) + + def get_random_transformation(self, data, training=True, seed=None): + if not training: + return None + + if isinstance(data, dict): + images = data["images"] + else: + images = data + images_shape = self.backend.shape(images) + if len(images_shape) == 4: + factor_shape = (images_shape[0], 1, 1, 1) + else: + factor_shape = (1, 1, 1) + + if seed is None: + seed = self._get_seed_generator(self.backend._backend) + + return { + "additions": self.backend.random.uniform( + minval=self.addition_factor[0], + maxval=self.addition_factor[1] * 255, + shape=factor_shape, + seed=seed, + dtype=self.compute_dtype, + ), + "thresholds": self.backend.random.uniform( + minval=self.threshold_factor[0], + maxval=self.threshold_factor[1] * 255, + shape=factor_shape, + seed=seed, + dtype=self.compute_dtype, + ), + } + + def transform_images(self, images, transformation, training=True): + images = self.backend.cast(images, self.compute_dtype) + + if training: + if transformation is None: + return images + + thresholds = transformation["thresholds"] + additions = transformation["additions"] + images = self._transform_value_range( + images, + original_range=self.value_range, + target_range=(0, 255), + dtype=self.compute_dtype, + ) + results = images + additions + results = self.backend.numpy.clip(results, 0, 255) + results = self.backend.numpy.where( + results < thresholds, results, 255 - results + ) + results = self._transform_value_range( + results, + original_range=(0, 255), + target_range=self.value_range, + dtype=self.compute_dtype, + ) + if results.dtype == images.dtype: + return results + if backend.is_int_dtype(images.dtype): + results = self.backend.numpy.round(results) + return _saturate_cast(results, images.dtype, self.backend) + return images + + def transform_labels(self, labels, transformation, training=True): + return labels + + def transform_bounding_boxes( + self, bounding_boxes, transformation, training=True + ): + return bounding_boxes + + def transform_segmentation_masks( + self, segmentation_masks, transformation, training=True + ): + return segmentation_masks + + def get_config(self): + base_config = super().get_config() + config = { + "value_range": self.value_range, + "addition_factor": self.addition_factor, + "threshold_factor": self.threshold_factor, + "seed": self.seed, + } + return {**base_config, **config} + + def compute_output_shape(self, input_shape): + return input_shape diff --git a/keras/src/layers/preprocessing/image_preprocessing/solarization_test.py b/keras/src/layers/preprocessing/image_preprocessing/solarization_test.py new file mode 100644 index 000000000000..d562dd7a0e6c --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/solarization_test.py @@ -0,0 +1,84 @@ +import numpy as np +import pytest +from absl.testing import parameterized + +from keras.src import layers +from keras.src import ops +from keras.src import random +from keras.src import testing + + +class SolarizationTest(testing.TestCase): + def _test_input_output(self, layer, input_value, expected_value, dtype): + input = np.ones(shape=(2, 224, 224, 3), dtype=dtype) * input_value + expected_output = ops.clip( + ( + np.ones(shape=(2, 224, 224, 3), dtype=layer.compute_dtype) + * expected_value + ), + 0, + 255, + ) + output = layer(input) + self.assertAllClose(output, expected_output) + + @pytest.mark.requires_trainable_backend + def test_layer(self): + self.run_layer_test( + layers.Solarization, + init_kwargs={ + "addition_factor": 0.75, + "value_range": (20, 200), + "threshold_factor": (0, 1), + "seed": 1, + }, + input_shape=(8, 3, 4, 3), + supports_masking=False, + expected_output_shape=(8, 3, 4, 3), + ) + + @parameterized.named_parameters( + ("0_255", 0, 255), + ("64_191", 64, 191), + ("127_128", 127, 128), + ("191_64", 191, 64), + ("255_0", 255, 0), + ) + def test_output_values(self, input_value, expected_value): + solarization = layers.Solarization(value_range=(0, 255)) + + self._test_input_output( + layer=solarization, + input_value=input_value, + expected_value=expected_value, + dtype="uint8", + ) + + @parameterized.named_parameters( + ("0_0", 0, 0), + ("191_64", 191, 64), + ("255_0", 255, 0), + ) + def test_only_values_above_threshold_are_solarized( + self, input_value, output_value + ): + solarization = layers.Solarization( + threshold_factor=(128.0 / 255.0, 128.0 / 255.0), + value_range=(0, 255), + ) + + self._test_input_output( + layer=solarization, + input_value=input_value, + expected_value=output_value, + dtype="uint8", + ) + + def test_random_augmentation_applied_per_sample(self): + image = random.uniform((16, 16, 3), minval=0, maxval=255) + images = ops.stack([image, image]) + layer = layers.Solarization( + value_range=(0, 255), threshold_factor=0.5, addition_factor=0.5 + ) + outputs = layer(images) + self.assertNotAllClose(outputs[0], outputs[1]) diff --git a/keras/src/layers/preprocessing/index_lookup.py b/keras/src/layers/preprocessing/index_lookup.py new file mode 100644 index 000000000000..3fe55a07e703 --- /dev/null +++ b/keras/src/layers/preprocessing/index_lookup.py @@ -0,0 +1,1009 @@ +import collections + +import numpy as np + +from keras.src import backend +from keras.src.layers.layer import Layer +from keras.src.utils import argument_validation +from keras.src.utils import numerical_utils +from keras.src.utils import tf_utils +from keras.src.utils.module_utils import tensorflow as tf + + +class IndexLookup(Layer): + """Maps values from a vocabulary to integer indices. + + This layer translates a set of arbitrary hashables into an integer output + via a table-based lookup, with optional out-of-vocabulary handling. This is + the basis layer for both IntegerLookup and StringLookup; it holds the common + logic but is not intended to be exported as part of the Keras API. + + Args: + max_tokens: The maximum size of the vocabulary for this layer. + If `None`, there is no cap on the size of the vocabulary. + Note that this size includes the OOV and mask tokens. + num_oov_indices: The number of out-of-vocabulary tokens to use. + If this value is more than 1, OOV inputs are hashed to determine + their OOV value. If this value is 0, + OOV inputs will cause an error when calling the layer. + mask_token: A token that represents masked inputs. + When `output_mode` is `"int"`, + the token is included in vocabulary and mapped to index 0. + In other output modes, the token will not appear in the vocabulary + and instances of the mask token in the input will be dropped. + If set to `None`, no mask term will be added. + oov_token: Only used when `invert` is `True`. + The token to return for OOV indices. + vocabulary: Optional. Either an array or a string path to a text file. + If passing an array, can pass a tuple, list, 1D numpy array, + or 1D tensor containing the vocbulary terms. + If passing a file path, the file should contain one line per term + in the vocabulary. If this argument is set, + there is no need to `adapt` the layer. + vocabulary_dtype: The dtype of the vocabulary terms. + For example, `"int64"` or `"string"`. + idf_weights: Only valid when `output_mode` is `"tf_idf"`. + A tuple, list, 1D numpy array, or 1D tensor or the same length + as the vocabulary, containing the floating point + inverse document frequency weights, which will be multiplied + by per sample term counts for the final TF-IDF + weight. If the `vocabulary` argument is set, and `output_mode` + is `"tf_idf"`, this argument must be supplied. + invert: Only valid when `output_mode` is `"int"`. + If `True`, this layer will map indices to vocabulary items + instead of mapping vocabulary items to indices. + Defaults to `False`. + output_mode: Specification for the output of the layer. Values can be + `"int"`, `"one_hot"`, `"multi_hot"`, `"count"`, or `"tf_idf"` + configuring the layer as follows: + - `"int"`: Return the raw integer indices of the input tokens. + - `"one_hot"`: Encodes each individual element in the input into an + array the same size as the vocabulary, containing a 1 + at the element index. If the last dimension is size 1, + will encode on that dimension. + If the last dimension is not size 1, + will append a new dimension for the encoded output. + - `"multi_hot"`: Encodes each sample in the input into + a single array the same size as the vocabulary, + containing a 1 for each vocabulary term present in the sample. + Treats the last dimension as the sample dimension, + if input shape is `(..., sample_length)`, output shape will + be `(..., num_tokens)`. + - `"count"`: As `"multi_hot"`, but the int array contains a count + of the number of times the token at that index appeared + in the sample. + - `"tf_idf"`: As `"multi_hot"`, but the TF-IDF algorithm + is applied to find the value in each token slot. + Defaults to `"int"`. + pad_to_max_tokens: Only valid when `output_mode` is `"multi_hot"`, + `"count"`, or `"tf_idf"`. If `True`, the output will have its + feature axis padded to `max_tokens` even if the number + of unique tokens in the vocabulary is less than max_tokens, + resulting in a tensor of shape `(batch_size, max_tokens)` + regardless of vocabulary size. Defaults to `False`. + sparse: Boolean. Only applicable to `"one_hot"`, `"multi_hot"`, + `"count"` and `"tf-idf"` output modes. + If `True`, returns a `SparseTensor` instead of a dense `Tensor`. + Defaults to `False`. + """ + + def __init__( + self, + max_tokens, + num_oov_indices, + mask_token, + oov_token, + vocabulary_dtype, + vocabulary=None, + idf_weights=None, + invert=False, + output_mode="int", + sparse=False, + pad_to_max_tokens=False, + name=None, + **kwargs, + ): + # If max_tokens is set, the value must be greater than 1 - otherwise we + # are creating a 0-element vocab, which doesn't make sense. + if max_tokens is not None and max_tokens <= 1: + raise ValueError( + "If set, `max_tokens` must be greater than 1. " + f"Received: max_tokens={max_tokens}" + ) + + if pad_to_max_tokens and max_tokens is None: + raise ValueError( + "If pad_to_max_tokens is True, must set `max_tokens`. " + f"Received: max_tokens={max_tokens}" + ) + + if num_oov_indices < 0: + raise ValueError( + "`num_oov_indices` must be greater than or equal to 0. " + f"Received: num_oov_indices={num_oov_indices}" + ) + + # Support deprecated names for output_modes. + if output_mode == "binary": + output_mode = "multi_hot" + if output_mode == "tf-idf": + output_mode = "tf_idf" + argument_validation.validate_string_arg( + output_mode, + allowable_strings=( + "int", + "one_hot", + "multi_hot", + "count", + "tf_idf", + ), + caller_name=self.__class__.__name__, + arg_name="output_mode", + ) + + if invert and output_mode != "int": + raise ValueError( + "`output_mode` must be `'int'` when `invert` is true. " + f"Received: output_mode={output_mode}" + ) + + if sparse and output_mode == "int": + raise ValueError( + "`sparse` may only be true if `output_mode` is " + "`'one_hot'`, `'multi_hot'`, `'count'` or `'tf_idf'`. " + f"Received: sparse={sparse} and " + f"output_mode={output_mode}" + ) + + if idf_weights is not None and output_mode != "tf_idf": + raise ValueError( + "`idf_weights` should only be set if `output_mode` is " + f"`'tf_idf'`. Received: idf_weights={idf_weights} and " + f"output_mode={output_mode}" + ) + + super().__init__(name=name) + self._convert_input_args = False + self._allow_non_tensor_positional_args = True + self.supports_jit = False + + self.invert = invert + self.max_tokens = max_tokens + self.num_oov_indices = num_oov_indices + self.mask_token = mask_token + self.oov_token = oov_token + self.output_mode = output_mode + self.sparse = sparse + self.pad_to_max_tokens = pad_to_max_tokens + self.vocabulary_dtype = tf.as_dtype(vocabulary_dtype).name + self._frozen_vocab_size = kwargs.pop("vocabulary_size", None) + + self.input_vocabulary = vocabulary + self.input_idf_weights = idf_weights + + # We set this hidden attr to + # persist the fact that we have have a non-adaptable layer with a + # manually set vocab. + self._has_input_vocabulary = kwargs.pop( + "has_input_vocabulary", (vocabulary is not None) + ) + kwargs.pop("trainable", None) + kwargs.pop("dtype", None) + if kwargs: + raise ValueError(f"Unrecognized keyword argument(s): {kwargs}") + + if invert: + self._key_dtype = "int64" + self._value_dtype = self.vocabulary_dtype + mask_key = 0 + mask_value = mask_token + self._default_value = self.oov_token + else: + self._key_dtype = self.vocabulary_dtype + self._value_dtype = "int64" + mask_key = mask_token + # Masks should map to 0 for int output and be dropped otherwise. Max + # ints will be dropped from the bincount op. + mask_value = ( + 0 + if self.output_mode == "int" + else tf.as_dtype(self._value_dtype).max + ) + if self.num_oov_indices == 0: + # If there are no OOV indices, we map OOV tokens to -1 and error + # out during call if we find a negative index. + self._default_value = -1 + elif self.num_oov_indices == 1: + # If there is only one OOV index, we can set that index as the + # default value of the index_lookup table. + self._default_value = self._oov_start_index() + else: + # If we have multiple OOV values, we need to do a further + # hashing step; to make this easier, we set the OOV value to -1. + # (This lets us do a vectorized add and cast to boolean to + # determine locations where we need to do extra hashing.) + self._default_value = -1 + if self.mask_token is not None: + self._mask_key = tf.convert_to_tensor(mask_key, self._key_dtype) + self._mask_value = tf.convert_to_tensor( + mask_value, self._value_dtype + ) + + if self.output_mode == "tf_idf": + if self._has_input_vocabulary and idf_weights is None: + raise ValueError( + "When specifying the `vocabulary` argument, " + "in TF-IDF output mode, the `idf_weights` argument " + "must also be provided." + ) + if idf_weights is not None: + self.idf_weights = tf.Variable( + idf_weights, + dtype=backend.floatx(), + trainable=False, + ) + self.idf_weights_const = self.idf_weights.value() + + if vocabulary is not None: + self.set_vocabulary(vocabulary, idf_weights) + else: + # When restoring from a keras SavedModel, the loading code will + # expect to find and restore a lookup_table attribute on the layer. + # This table needs to be uninitialized as a StaticHashTable cannot + # be initialized twice. + self.lookup_table = self._uninitialized_lookup_table() + + # Only set up adapt state if we did not receive a vocab on construction. + if not self._has_input_vocabulary: + # Set adapt state. + self.token_counts = tf.lookup.experimental.MutableHashTable( + key_dtype=vocabulary_dtype, + value_dtype="int64", + default_value=0, + ) + if self.output_mode == "tf_idf": + self.token_document_counts = ( + tf.lookup.experimental.MutableHashTable( + key_dtype=vocabulary_dtype, + value_dtype="int64", + default_value=0, + ) + ) + self.num_documents = tf.Variable( + 0, dtype="int64", trainable=False + ) + + def get_vocabulary(self, include_special_tokens=True): + """Returns the current vocabulary of the layer. + + Args: + include_special_tokens: If `True`, the returned vocabulary + will include mask and OOV tokens, + and a term's index in the vocabulary + will equal the term's index when calling the layer. + If `False`, the returned vocabulary will not include + any mask or OOV tokens. + """ + # The lookup table data will not be sorted, so we will create a inverted + # lookup here, and use that to lookup a range of indices + # [0, vocab_size). + if self.lookup_table.size() == 0: + vocab, indices = [], [] + else: + keys, values = self.lookup_table.export() + vocab, indices = (values, keys) if self.invert else (keys, values) + vocab, indices = ( + self._tensor_vocab_to_numpy(vocab), + indices.numpy(), + ) + lookup = collections.defaultdict( + lambda: self.oov_token, zip(indices, vocab) + ) + vocab = [lookup[x] for x in range(self.vocabulary_size())] + if self.mask_token is not None and self.output_mode == "int": + vocab[0] = self.mask_token + if not include_special_tokens: + vocab = vocab[self._token_start_index() :] + if self.vocabulary_dtype == "string": + return [ + i.decode("utf-8") if isinstance(i, bytes) else i for i in vocab + ] + else: + return vocab + + def vocabulary_size(self): + """Gets the current size of the layer's vocabulary. + + Returns: + The integer size of the vocabulary, including optional mask and oov + indices. + """ + if tf.executing_eagerly(): + return ( + int(self.lookup_table.size().numpy()) + + self._token_start_index() + ) + else: + return self.lookup_table.size() + self._token_start_index() + + def get_config(self): + config = { + "invert": self.invert, + "max_tokens": self.max_tokens, + "num_oov_indices": self.num_oov_indices, + "oov_token": self.oov_token, + "mask_token": self.mask_token, + "output_mode": self.output_mode, + "sparse": self.sparse, + "pad_to_max_tokens": self.pad_to_max_tokens, + "vocabulary_dtype": self.vocabulary_dtype, + "idf_weights": listify_tensors(self.input_idf_weights), + "vocabulary": listify_tensors(self.input_vocabulary), + "vocabulary_size": self._frozen_vocab_size, + } + base_config = super().get_config() + return dict(list(base_config.items()) + list(config.items())) + + def _record_vocabulary_size(self): + self._ensure_vocab_size_unchanged() + with tf.init_scope(): + self._frozen_vocab_size = self.vocabulary_size() + + def set_vocabulary(self, vocabulary, idf_weights=None): + """Sets vocabulary (and optionally document frequency) for this layer. + + This method sets the vocabulary and idf weights for this layer directly, + instead of analyzing a dataset through `adapt`. It should be used + whenever the vocab (and optionally document frequency) information is + already known. If vocabulary data is already present in the layer, this + method will replace it. + + Args: + vocabulary: Either an array or a string path to a text file. + If passing an array, can pass a tuple, list, + 1D numpy array, or 1D tensor containing the vocbulary terms. + If passing a file path, the file should contain one line + per term in the vocabulary. + idf_weights: A tuple, list, 1D numpy array, or 1D tensor + of inverse document frequency weights with equal + length to vocabulary. Must be set if `output_mode` + is `"tf_idf"`. Should not be set otherwise. + """ + if self.output_mode == "tf_idf": + if idf_weights is None: + raise ValueError( + "`idf_weights` must be set if output_mode is 'tf_idf'." + ) + elif idf_weights is not None: + raise ValueError( + "`idf_weights` should only be set if output_mode is " + f"`'tf_idf'`. Received: output_mode={self.output_mode} " + f"and idf_weights={idf_weights}" + ) + + if isinstance(vocabulary, str): + if not tf.io.gfile.exists(vocabulary): + raise ValueError( + f"Vocabulary file {vocabulary} does not exist." + ) + if self.output_mode == "tf_idf": + raise ValueError( + "output_mode `'tf_idf'` does not support loading a " + "vocabulary from file." + ) + self.lookup_table = self._lookup_table_from_file(vocabulary) + self._record_vocabulary_size() + return + + if not tf.executing_eagerly() and ( + tf.is_tensor(vocabulary) or tf.is_tensor(idf_weights) + ): + raise RuntimeError( + f"Cannot set a tensor vocabulary on layer {self.name} " + "when not executing eagerly. " + "Create this layer or call `set_vocabulary()` " + "outside of any traced function." + ) + + # TODO(mattdangerw): for better performance we should rewrite this + # entire function to operate on tensors and convert vocabulary to a + # tensor here. + if tf.is_tensor(vocabulary): + vocabulary = self._tensor_vocab_to_numpy(vocabulary) + elif isinstance(vocabulary, (list, tuple)): + vocabulary = np.array(vocabulary) + if tf.is_tensor(idf_weights): + idf_weights = idf_weights.numpy() + elif isinstance(idf_weights, (list, tuple)): + idf_weights = np.array(idf_weights) + + if vocabulary.size == 0: + raise ValueError( + "Cannot set an empty vocabulary. " + f"Received: vocabulary={vocabulary}" + ) + + oov_start = self._oov_start_index() + token_start = self._token_start_index() + special_tokens = [self.mask_token] * oov_start + [ + self.oov_token + ] * self.num_oov_indices + found_special_tokens = np.array_equal( + special_tokens, vocabulary[:token_start] + ) + if found_special_tokens: + tokens = vocabulary[token_start:] + else: + tokens = vocabulary + + repeated_tokens = self._find_repeated_tokens(tokens) + if repeated_tokens: + raise ValueError( + "The passed vocabulary has at least one repeated " + "term. Please uniquify your dataset. The repeated terms " + f"are: {repeated_tokens}" + ) + + if self.mask_token is not None and self.mask_token in tokens: + mask_index = np.argwhere(vocabulary == self.mask_token)[-1] + raise ValueError( + "Found reserved mask token at unexpected location in " + "`vocabulary`. Note that passed `vocabulary` does not need to " + "include the OOV and mask tokens. Either remove all mask and " + "OOV tokens, or include them only at the start of the " + f"vocabulary in precisely this order: {special_tokens}. " + f"Received: mask_token={self.mask_token} at " + f"vocabulary index {mask_index}" + ) + # Only error out for oov_token when invert=True. When invert=False, + # oov_token is unused during lookup. + if ( + self.oov_token is not None + and self.invert + and self.oov_token in tokens + ): + oov_index = np.argwhere(vocabulary == self.oov_token)[-1] + raise ValueError( + "Found reserved OOV token at unexpected location in " + "`vocabulary`. Note that passed `vocabulary` does not need to " + "include the OOV and mask tokens. Either remove all mask and " + "OOV tokens, or include them only at the start of the " + f"vocabulary in precisely this order: {special_tokens}. " + f"Received: oov_token={self.oov_token} at " + f"vocabulary index {oov_index}" + ) + + new_vocab_size = token_start + len(tokens) + if self.max_tokens is not None and (new_vocab_size > self.max_tokens): + raise ValueError( + "Attempted to set a vocabulary larger than the maximum vocab " + f"size. Received vocabulary size is {new_vocab_size}; " + f"`max_tokens` is {self.max_tokens}." + ) + self.lookup_table = self._lookup_table_from_tokens(tokens) + self._record_vocabulary_size() + + if self.output_mode == "tf_idf" and idf_weights is not None: + if len(vocabulary) != len(idf_weights): + raise ValueError( + "`idf_weights` must be the same length as vocabulary. " + f"len(idf_weights) is {len(idf_weights)}; " + f"len(vocabulary) is {len(vocabulary)}" + ) + idf_weights = self._convert_to_ndarray(idf_weights) + if idf_weights.ndim != 1: + raise ValueError( + "TF-IDF data must be a 1-index array. " + f"Received: type(idf_weights)={type(idf_weights)}" + ) + + # If the passed vocabulary has no special tokens, we need to pad the + # front of idf_weights. We don't have real document frequencies for + # these tokens so we will use an average of all idf_weights passed + # in as a reasonable default. + if found_special_tokens: + front_padding = 0 + front_padding_value = 0 + else: + front_padding = token_start + front_padding_value = np.average(idf_weights) + # If pad_to_max_tokens is true, and max_tokens is greater than our + # total vocab size, we need to pad the back of idf_weights with + # zeros as well. + back_padding_value = 0 + if self.pad_to_max_tokens and self.max_tokens is not None: + back_padding = ( + self.max_tokens - front_padding - len(idf_weights) + ) + else: + back_padding = 0 + weights = np.pad( + idf_weights, + (front_padding, back_padding), + "constant", + constant_values=(front_padding_value, back_padding_value), + ) + weights = tf.convert_to_tensor(weights, dtype=backend.floatx()) + self.idf_weights = tf.Variable( + weights, + trainable=False, + ) + self.idf_weights_const = self.idf_weights.value() + + def get_build_config(self): + return {} + + def build_from_config(self, config): + self.build(None) + + @property + def compute_dtype(self): + return self.vocabulary_dtype + + @property + def variable_dtype(self): + return self.vocabulary_dtype + + def compute_output_shape(self, input_shape): + if self.output_mode == "int": + return input_shape + depth = ( + self.max_tokens + if self.pad_to_max_tokens + else self._frozen_vocab_size + ) + return (input_shape[0], depth) + + def compute_output_spec(self, inputs): + if self.output_mode == "int": + output_dtype = "int64" + else: + output_dtype = backend.floatx() + output_shape = self.compute_output_shape(inputs.shape) + return backend.KerasTensor(output_shape, dtype=output_dtype) + + def adapt(self, data, steps=None): + self.reset_state() + if isinstance(data, tf.data.Dataset): + if steps is not None: + data = data.take(steps) + for batch in data: + self.update_state(batch) + else: + data = tf_utils.ensure_tensor(data, dtype=self.vocabulary_dtype) + if data.shape.rank == 1: + # A plain list of strings + # is treated as as many documents + data = tf.expand_dims(data, -1) + self.update_state(data) + self.finalize_state() + + def update_state(self, data): + if self._has_input_vocabulary: + raise ValueError( + f"Cannot adapt layer '{self.name}' after setting a static " + "vocabulary via `vocabulary` argument or " + "`set_vocabulary()` method." + ) + + data = tf_utils.ensure_tensor(data, dtype=self.vocabulary_dtype) + if data.shape.rank == 0: + data = tf.expand_dims(data, 0) + if data.shape.rank == 1: + # Expand dims on axis 0 for tf-idf. A 1-d tensor + # is a single document. + data = tf.expand_dims(data, 0) + + tokens, counts = self._num_tokens(data) + self.token_counts.insert( + tokens, counts + self.token_counts.lookup(tokens) + ) + + if self.output_mode == "tf_idf": + # Dedupe each row of our dataset. + if isinstance(data, tf.RaggedTensor): + deduped_doc_data = tf.map_fn(lambda x: tf.unique(x)[0], data) + else: + deduped_doc_data = [tf.unique(x)[0] for x in data] + deduped_doc_data = tf.concat(deduped_doc_data, axis=0) + # Flatten and count tokens. + tokens, counts = self._num_tokens(deduped_doc_data) + + self.token_document_counts.insert( + tokens, counts + self.token_document_counts.lookup(tokens) + ) + if isinstance(data, tf.RaggedTensor): + self.num_documents.assign_add(data.nrows()) + else: + self.num_documents.assign_add( + tf.shape(data, out_type="int64")[0] + ) + + def finalize_state(self): + if self._has_input_vocabulary or tf.equal(self.token_counts.size(), 0): + # Finalize idf_weights to a const for call even if we don't need to + # compute a new vocabulary. + if self.output_mode == "tf_idf": + self.idf_weights_const = self.idf_weights.value() + self._record_vocabulary_size() + return + + # Remove special tokens from our counts. + if self.mask_token is not None: + self.token_counts.remove( + tf.convert_to_tensor([self.mask_token], self.vocabulary_dtype) + ) + if self.oov_token is not None: + self.token_counts.remove( + tf.convert_to_tensor([self.oov_token], self.vocabulary_dtype) + ) + + tokens, counts = self.token_counts.export() + # To keep vocabs deterministic, we sort our tokens by count and break + # ties by sorting the tokens themselves. Tensorflow has no ops for + # sorting strings, so we need to use numpy for the sort. + sorted_indices = np.lexsort((tokens.numpy(), counts.numpy()))[::-1] + token_start = self._token_start_index() + if self.max_tokens: + max_learned_tokens = self.max_tokens - token_start + sorted_indices = sorted_indices[:max_learned_tokens] + tokens = tf.gather(tokens, sorted_indices) + self.lookup_table = self._lookup_table_from_tokens(tokens) + + if self.output_mode == "tf_idf": + token_document_counts = self.token_document_counts.lookup(tokens) + idf_weights = self._inverse_document_frequency( + token_document_counts, self.num_documents + ) + idf_weights = tf.cast(idf_weights, backend.floatx()) + # Pad the front of idf_weights with the average idf weight for OOV + # tokens. We cannot compute the real idf weight of OOV in a single + # pass. + idf_weights = tf.pad( + idf_weights, + [[self._token_start_index(), 0]], + constant_values=tf.reduce_mean(idf_weights), + ) + if self.pad_to_max_tokens and self.max_tokens is not None: + # Pad the back of idf_weights with zeros. + idf_weights = tf.pad( + idf_weights, + [[0, self.max_tokens - tf.size(idf_weights)]], + constant_values=0, + ) + self.idf_weights = tf.Variable( + idf_weights, + dtype=backend.floatx(), + trainable=False, + ) + self.idf_weights_const = self.idf_weights.value() + + # We call this here to save memory, now that we've built our vocabulary, + # we don't want to keep every token we've seen in separate lookup + # tables. + self.reset_state() + self._record_vocabulary_size() + + def reset_state(self): + if self._has_input_vocabulary: + return + + self.token_counts.remove(self.token_counts.export()[0]) + if self.output_mode == "tf_idf": + self.token_document_counts.remove( + self.token_document_counts.export()[0] + ) + self.num_documents.assign(0) + + def call(self, inputs): + from keras.src.backend import tensorflow as tf_backend + + self._ensure_known_vocab_size() + + inputs = tf_utils.ensure_tensor(inputs, dtype=self._key_dtype) + original_shape = inputs.shape + # Some ops will not handle scalar input, so uprank to rank 1. + if inputs.shape.rank == 0: + inputs = self._expand_dims(inputs, -1) + + if isinstance(inputs, tf.SparseTensor): + lookups = tf.SparseTensor( + inputs.indices, + self._lookup_dense(inputs.values), + inputs.dense_shape, + ) + elif isinstance(inputs, tf.RaggedTensor): + lookups = tf.ragged.map_flat_values(self._lookup_dense, inputs) + else: + lookups = self._lookup_dense(inputs) + + if self.output_mode == "int": + # If we received a scalar input, downrank back to a scalar. + if original_shape.rank == 0: + lookups = tf.squeeze(lookups, -1) + return lookups + + depth = ( + self.max_tokens + if self.pad_to_max_tokens + else self._frozen_vocab_size + ) + idf_weights = ( + self.idf_weights_const if self.output_mode == "tf_idf" else None + ) + output = numerical_utils.encode_categorical_inputs( + lookups, + output_mode=( + "count" if self.output_mode == "tf_idf" else self.output_mode + ), + depth=depth, + dtype=self._value_dtype, + sparse=self.sparse, + backend_module=tf_backend, + ) + if self.output_mode == "tf_idf": + if idf_weights is None: + raise ValueError( + "When `output_mode` is `'tf_idf'`, `idf_weights` must be " + "provided." + ) + output = tf_backend.numpy.multiply( + tf_backend.core.cast(output, idf_weights.dtype), idf_weights + ) + return output + + def _lookup_dense(self, inputs): + """Lookup table values for a dense Tensor, handling masking and OOV.""" + # When executing eagerly and tracing keras.Input objects, + # do not call lookup. + # This is critical for restoring SavedModel, which will first trace + # layer.call and then attempt to restore the table. We need the table to + # be uninitialized for the restore to work, but calling the table + # uninitialized would error. + if tf.executing_eagerly() and backend.is_keras_tensor(inputs): + lookups = tf.zeros_like(inputs, dtype=self._value_dtype) + else: + lookups = self.lookup_table.lookup(inputs) + + if self.mask_token is not None: + mask_locations = tf.equal(inputs, self._mask_key) + lookups = tf.where(mask_locations, self._mask_value, lookups) + + if self.invert: + return lookups + + lookup_checks = [] + + if self.num_oov_indices == 0: + # If we have zero oov indices, we need to check for oov inputs. + oov_indices = tf.where(tf.equal(lookups, -1)) + oov_inputs = tf.gather_nd(inputs, oov_indices) + msg = tf.strings.format( + "When `num_oov_indices=0` all inputs should be in vocabulary, " + "found OOV values {}, consider setting `num_oov_indices=1`.", + (oov_inputs,), + ) + assertion = tf.Assert(tf.equal(tf.size(oov_indices), 0), [msg]) + lookup_checks.append(assertion) + elif self.num_oov_indices > 1: + # If we have multiple oov indices, we need a further hashing step. + if tf.as_dtype(self._key_dtype).is_integer: + oov_indices = tf.math.floormod(inputs, self.num_oov_indices) + else: + oov_indices = tf.strings.to_hash_bucket_fast( + inputs, num_buckets=self.num_oov_indices + ) + oov_indices = oov_indices + self._oov_start_index() + oov_locations = tf.equal(lookups, self._default_value) + lookups = tf.where(oov_locations, oov_indices, lookups) + + with tf.control_dependencies(lookup_checks): + return tf.identity(lookups) + + def save_own_variables(self, store): + if self.output_mode == "tf_idf": + store["idf_weights"] = self.idf_weights_const.numpy() + + def load_own_variables(self, store): + if self.output_mode == "tf_idf": + self.idf_weights.assign(store["idf_weights"]) + self.idf_weights_const = self.idf_weights.value() + + def save_assets(self, dir_path): + if self.input_vocabulary is not None: + # Vocab saved in config. + # TODO: consider unifying both paths. + return + vocabulary = self.get_vocabulary(include_special_tokens=True) + vocabulary_filepath = tf.io.gfile.join(dir_path, "vocabulary.txt") + with open(vocabulary_filepath, "w") as f: + f.write("\n".join([str(w) for w in vocabulary])) + + def load_assets(self, dir_path): + if self.input_vocabulary is not None: + # Vocab saved in config. + # TODO: consider unifying both paths. + return + vocabulary_filepath = tf.io.gfile.join(dir_path, "vocabulary.txt") + # TODO: fix bug with include_special_tokens and set reload from file. + with open(vocabulary_filepath, "r") as f: + lines = f.read().split("\n") + if tf.as_dtype(self.vocabulary_dtype) == tf.string: + values = [str(line) for line in lines] + else: + values = [int(line) for line in lines] + if self.output_mode == "tf_idf": + self.set_vocabulary(values, idf_weights=False) + else: + self.set_vocabulary(values) + + def _uninitialized_lookup_table(self): + with tf.init_scope(): + initializer = get_null_initializer( + self._key_dtype, self._value_dtype + ) + return tf.lookup.StaticHashTable(initializer, self._default_value) + + def _lookup_table_from_tokens(self, tokens): + with tf.init_scope(): + token_start = self._token_start_index() + token_end = token_start + tf.size(tokens) + indices_dtype = ( + self._key_dtype if self.invert else self._value_dtype + ) + indices = tf.range(token_start, token_end, dtype=indices_dtype) + keys, values = ( + (indices, tokens) if self.invert else (tokens, indices) + ) + initializer = tf.lookup.KeyValueTensorInitializer( + keys, values, self._key_dtype, self._value_dtype + ) + return tf.lookup.StaticHashTable(initializer, self._default_value) + + def _lookup_table_from_file(self, filename): + if self.invert: + key_index = tf.lookup.TextFileIndex.LINE_NUMBER + value_index = tf.lookup.TextFileIndex.WHOLE_LINE + else: + key_index = tf.lookup.TextFileIndex.WHOLE_LINE + value_index = tf.lookup.TextFileIndex.LINE_NUMBER + with tf.init_scope(): + initializer = tf.lookup.TextFileInitializer( + filename=filename, + key_dtype=self._key_dtype, + key_index=key_index, + value_dtype=self._value_dtype, + value_index=value_index, + value_index_offset=self._token_start_index(), + ) + return tf.lookup.StaticHashTable(initializer, self._default_value) + + def _convert_to_ndarray(self, x): + return np.array(x) if isinstance(x, (list, tuple)) else x + + def _expand_dims(self, inputs, axis): + if isinstance(inputs, tf.SparseTensor): + return tf.sparse.expand_dims(inputs, axis) + else: + return tf.expand_dims(inputs, axis) + + def _oov_start_index(self): + return ( + 1 + if self.mask_token is not None and self.output_mode == "int" + else 0 + ) + + def _token_start_index(self): + return self._oov_start_index() + self.num_oov_indices + + def _ensure_known_vocab_size(self): + if self.output_mode == "int" or self.pad_to_max_tokens: + return + if self._frozen_vocab_size is None: + raise RuntimeError( + f"When using `output_mode={self.output_mode}` " + "and `pad_to_max_tokens=False`, " + "you must set the layer's vocabulary before calling it. Either " + "pass a `vocabulary` argument to the layer, or call `adapt` " + "with some sample data." + ) + + def _ensure_vocab_size_unchanged(self): + if self.output_mode == "int" or self.pad_to_max_tokens: + return + + with tf.init_scope(): + new_vocab_size = self.vocabulary_size() + + if ( + self._frozen_vocab_size is not None + and new_vocab_size != self._frozen_vocab_size + ): + raise RuntimeError( + f"When using `output_mode={self.output_mode}` " + "and `pad_to_max_tokens=False`, " + "the vocabulary size cannot be changed after the layer is " + f"called. Old vocab size is {self._frozen_vocab_size}, " + f"new vocab size is {new_vocab_size}" + ) + + def _find_repeated_tokens(self, vocabulary): + """Return all repeated tokens in a vocabulary.""" + vocabulary_set = set(vocabulary) + if len(vocabulary) != len(vocabulary_set): + return [ + item + for item, count in collections.Counter(vocabulary).items() + if count > 1 + ] + else: + return [] + + def _num_tokens(self, data): + """Count the number of tokens in a ragged, sparse or dense tensor.""" + if isinstance(data, tf.SparseTensor): + flat_values = data.values + elif isinstance(data, tf.RaggedTensor): + flat_values = data.flat_values + else: + flat_values = tf.reshape(data, [-1]) + tokens, _, counts = tf.unique_with_counts(flat_values, out_idx="int64") + return tokens, counts + + def _inverse_document_frequency(self, token_document_counts, num_documents): + """Computes the inverse-document-frequency (IDF) component of "tf_idf". + Args: + token_document_counts: An array of the # of documents each token + appears in. + num_documents: An int representing the total number of documents + + Returns: + An array of "inverse document frequency" weights. + """ + return tf.math.log(1 + num_documents / (1 + token_document_counts)) + + # Override points for IntegerLookup and StringLookup. + def _tensor_vocab_to_numpy(self, vocabulary): + """Converts a tensor vocabulary to a numpy vocabulary.""" + return vocabulary.numpy() + + +def get_null_initializer(key_dtype, value_dtype): + class NullInitializer(tf.lookup.KeyValueTensorInitializer): + """A placeholder initializer for restoring from a SavedModel.""" + + def __init__(self, key_dtype, value_dtype): + """Construct a table initializer object. + + Args: + key_dtype: Type of the table keys. + value_dtype: Type of the table values. + """ + self._key_dtype = key_dtype + self._value_dtype = value_dtype + + @property + def key_dtype(self): + """The expected table key dtype.""" + return self._key_dtype + + @property + def value_dtype(self): + """The expected table value dtype.""" + return self._value_dtype + + def initialize(self, table): + """Returns the table initialization op.""" + pass + + return NullInitializer(key_dtype, value_dtype) + + +def listify_tensors(x): + """Convert any tensors or numpy arrays to lists for config serialization.""" + if tf.is_tensor(x): + x = x.numpy() + if isinstance(x, np.ndarray): + x = x.tolist() + return x diff --git a/keras/src/layers/preprocessing/index_lookup_test.py b/keras/src/layers/preprocessing/index_lookup_test.py new file mode 100644 index 000000000000..7fe7ff113cbf --- /dev/null +++ b/keras/src/layers/preprocessing/index_lookup_test.py @@ -0,0 +1,618 @@ +import os + +import numpy as np +import pytest +from tensorflow import data as tf_data + +from keras.src import backend +from keras.src import layers +from keras.src import models +from keras.src import testing +from keras.src.saving import saving_api + + +@pytest.mark.skipif( + backend.backend() == "numpy", reason="Failing for numpy backend." +) +class IndexLookupLayerTest(testing.TestCase): + def test_basics_string_vocab(self): + # Case: adapt + list inputs + adapt_data = ["one", "one", "one", "two", "two", "three"] + input_data = ["one", "two", "four"] + kwargs = { + "max_tokens": 7, + "num_oov_indices": 1, + "mask_token": "", + "oov_token": "[OOV]", + "vocabulary_dtype": "string", + } + layer = layers.IndexLookup(**kwargs) + layer.adapt(adapt_data) + self.assertEqual( + layer.get_vocabulary(), ["", "[OOV]", "one", "two", "three"] + ) + self.assertEqual( + layer.get_vocabulary(include_special_tokens=False), + ["one", "two", "three"], + ) + output = layer(input_data) + self.assertEqual(list(output), [2, 3, 1]) + if backend.backend() != "torch": + self.run_class_serialization_test(layer) + + # Case: numpy array input + output = layer(np.array(input_data)) + self.assertEqual(list(output), [2, 3, 1]) + + # Case: fixed vocab + list inputs + vocabulary = ["one", "two", "three"] + layer = layers.IndexLookup(vocabulary=vocabulary, **kwargs) + self.assertEqual( + layer.get_vocabulary(), ["", "[OOV]", "one", "two", "three"] + ) + self.assertEqual( + layer.get_vocabulary(include_special_tokens=False), + ["one", "two", "three"], + ) + output = layer(input_data) + self.assertEqual(list(output), [2, 3, 1]) + if backend.backend() != "torch": + self.run_class_serialization_test(layer) + + # Case: fixed vocab with special tokens + list inputs + vocabulary_with_special_tokens = ["", "[OOV]", "one", "two", "three"] + layer = layers.IndexLookup( + vocabulary=vocabulary_with_special_tokens, **kwargs + ) + self.assertEqual( + layer.get_vocabulary(), ["", "[OOV]", "one", "two", "three"] + ) + self.assertEqual( + layer.get_vocabulary(include_special_tokens=False), + ["one", "two", "three"], + ) + output = layer(input_data) + self.assertEqual(list(output), [2, 3, 1]) + if backend.backend() != "torch": + self.run_class_serialization_test(layer) + + # Case: set vocabulary + layer = layers.IndexLookup(**kwargs) + layer.set_vocabulary(vocabulary) + self.assertEqual( + layer.get_vocabulary(), ["", "[OOV]", "one", "two", "three"] + ) + self.assertEqual( + layer.get_vocabulary(include_special_tokens=False), + ["one", "two", "three"], + ) + output = layer(input_data) + self.assertEqual(list(output), [2, 3, 1]) + if backend.backend() != "torch": + self.run_class_serialization_test(layer) + + # Case: set vocabulary (with special tokens) + layer = layers.IndexLookup(**kwargs) + layer.set_vocabulary(vocabulary_with_special_tokens) + self.assertEqual( + layer.get_vocabulary(), ["", "[OOV]", "one", "two", "three"] + ) + self.assertEqual( + layer.get_vocabulary(include_special_tokens=False), + ["one", "two", "three"], + ) + output = layer(input_data) + self.assertEqual(list(output), [2, 3, 1]) + if backend.backend() != "torch": + self.run_class_serialization_test(layer) + + def test_basics_integer_vocab(self): + # Case: adapt + list inputs + adapt_data = [1, 1, 1, 2, 2, 3] + input_data = [1, 2, 4] + kwargs = { + "max_tokens": 7, + "num_oov_indices": 1, + "mask_token": 0, + "oov_token": -1, + "vocabulary_dtype": "int64", + } + layer = layers.IndexLookup(**kwargs) + layer.adapt(adapt_data) + self.assertEqual(layer.get_vocabulary(), [0, -1, 1, 2, 3]) + self.assertEqual( + layer.get_vocabulary(include_special_tokens=False), + [1, 2, 3], + ) + output = layer(input_data) + self.assertEqual(list(output), [2, 3, 1]) + if backend.backend() != "torch": + self.run_class_serialization_test(layer) + + # Case: numpy array input + output = layer(np.array(input_data)) + self.assertEqual(list(output), [2, 3, 1]) + + # Case: fixed vocab + list inputs + vocabulary = [1, 2, 3] + layer = layers.IndexLookup(vocabulary=vocabulary, **kwargs) + self.assertEqual(layer.get_vocabulary(), [0, -1, 1, 2, 3]) + self.assertEqual( + layer.get_vocabulary(include_special_tokens=False), + [1, 2, 3], + ) + output = layer(input_data) + self.assertEqual(list(output), [2, 3, 1]) + if backend.backend() != "torch": + self.run_class_serialization_test(layer) + + # Case: fixed vocab with special tokens + list inputs + vocabulary_with_special_tokens = [0, -1, 1, 2, 3] + layer = layers.IndexLookup( + vocabulary=vocabulary_with_special_tokens, **kwargs + ) + self.assertEqual(layer.get_vocabulary(), [0, -1, 1, 2, 3]) + self.assertEqual( + layer.get_vocabulary(include_special_tokens=False), + [1, 2, 3], + ) + output = layer(input_data) + self.assertEqual(list(output), [2, 3, 1]) + if backend.backend() != "torch": + self.run_class_serialization_test(layer) + + # Case: set vocabulary + layer = layers.IndexLookup(**kwargs) + layer.set_vocabulary(vocabulary) + self.assertEqual(layer.get_vocabulary(), [0, -1, 1, 2, 3]) + self.assertEqual( + layer.get_vocabulary(include_special_tokens=False), + [1, 2, 3], + ) + output = layer(input_data) + self.assertEqual(list(output), [2, 3, 1]) + if backend.backend() != "torch": + self.run_class_serialization_test(layer) + + # Case: set vocabulary (with special tokens) + layer = layers.IndexLookup(**kwargs) + layer.set_vocabulary(vocabulary_with_special_tokens) + self.assertEqual(layer.get_vocabulary(), [0, -1, 1, 2, 3]) + self.assertEqual( + layer.get_vocabulary(include_special_tokens=False), + [1, 2, 3], + ) + output = layer(input_data) + self.assertEqual(list(output), [2, 3, 1]) + if backend.backend() != "torch": + self.run_class_serialization_test(layer) + + def test_max_tokens_adapt(self): + adapt_data = [1, 1, 1, 2, 2, 3] + input_data = [1, 2, 3, 4] + kwargs = { + "max_tokens": 4, + "num_oov_indices": 1, + "mask_token": 0, + "oov_token": -1, + "vocabulary_dtype": "int64", + } + layer = layers.IndexLookup(**kwargs) + layer.adapt(adapt_data) + self.assertEqual(layer.get_vocabulary(), [0, -1, 1, 2]) + self.assertEqual( + layer.get_vocabulary(include_special_tokens=False), + [1, 2], + ) + output = layer(input_data) + self.assertEqual(list(output), [2, 3, 1, 1]) + if backend.backend() != "torch": + self.run_class_serialization_test(layer) + + def test_pad_to_max_tokens(self): + vocabulary = [1, 2] + input_data = [1, 2] + kwargs = { + "max_tokens": 5, + "num_oov_indices": 1, + "mask_token": 0, + "oov_token": -1, + "vocabulary_dtype": "int64", + "vocabulary": vocabulary, + "pad_to_max_tokens": True, + "output_mode": "multi_hot", + } + layer = layers.IndexLookup(**kwargs) + output = layer(input_data) + self.assertAllClose(output, [0, 1, 1, 0, 0]) + if backend.backend() != "torch": + self.run_class_serialization_test(layer) + + def test_output_modes(self): + vocabulary = ["one", "two", "three"] + single_sample_input_data = ["one", "two", "four"] + batch_input_data = [["one", "two", "four", "two"]] + kwargs = { + "max_tokens": 7, + "num_oov_indices": 1, + "mask_token": "", + "oov_token": "[OOV]", + "vocabulary_dtype": "string", + "vocabulary": vocabulary, + } + + # int + kwargs["output_mode"] = "int" + layer = layers.IndexLookup(**kwargs) + output = layer(single_sample_input_data) + self.assertAllClose(output, [2, 3, 1]) + output = layer(batch_input_data) + self.assertAllClose(output, [[2, 3, 1, 3]]) + + # multi-hot + kwargs["output_mode"] = "multi_hot" + layer = layers.IndexLookup(**kwargs) + output = layer(single_sample_input_data) + self.assertAllClose(output, [1, 1, 1, 0]) + output = layer(batch_input_data) + self.assertAllClose(output, [[1, 1, 1, 0]]) + + # one-hot + kwargs["output_mode"] = "one_hot" + layer = layers.IndexLookup(**kwargs) + output = layer(single_sample_input_data) + self.assertAllClose(output, [[0, 1, 0, 0], [0, 0, 1, 0], [1, 0, 0, 0]]) + + # count + kwargs["output_mode"] = "count" + layer = layers.IndexLookup(**kwargs) + output = layer(single_sample_input_data) + self.assertAllClose(output, [1, 1, 1, 0]) + output = layer(batch_input_data) + self.assertAllClose(output, [[1, 1, 2, 0]]) + + # tf-idf + kwargs["output_mode"] = "tf_idf" + kwargs["idf_weights"] = np.array([0.1, 0.2, 0.3]) + layer = layers.IndexLookup(**kwargs) + output = layer(single_sample_input_data) + self.assertAllClose(output, [0.2, 0.1, 0.2, 0.0]) + output = layer(batch_input_data) + self.assertAllClose(output, [[0.2, 0.1, 0.4, 0.0]]) + + def test_sparse_outputs(self): + # TODO + pass + + def test_adapt_tf_idf(self): + # Case: unbatched data + adapt_data = ["one", "one", "one", "two", "two", "three"] + input_data = ["one", "two", "four"] + kwargs = { + "max_tokens": 7, + "num_oov_indices": 1, + "mask_token": "", + "oov_token": "[OOV]", + "vocabulary_dtype": "string", + "output_mode": "tf_idf", + } + layer = layers.IndexLookup(**kwargs) + layer.adapt(adapt_data) + output = layer(input_data) + # Document counts for one, two, three = [3, 2, 1] + idf_weights = np.log(1 + len(adapt_data) / (1 + np.array([3, 2, 1]))) + self.assertAllClose(layer.idf_weights[1:], idf_weights) + self.assertAllClose(output, [1.1337324, 0.91629076, 1.0986123, 0.0]) + # Case: batched data + adapt_data = [["one", "one"], ["one", "two"], ["two", "three"]] + input_data = [["one", "two"], ["two", "four"]] + kwargs = { + "max_tokens": 7, + "num_oov_indices": 1, + "mask_token": "", + "oov_token": "[OOV]", + "vocabulary_dtype": "string", + "output_mode": "tf_idf", + } + layer = layers.IndexLookup(**kwargs) + layer.adapt(adapt_data) + # Document counts for one, two, three = [2, 2, 1] + idf_weights = np.log(1 + len(adapt_data) / (1 + np.array([2, 2, 1]))) + self.assertAllClose(layer.idf_weights[1:], idf_weights) + output = layer(input_data) + self.assertAllClose( + output, + [ + [0.0, 0.6931472, 0.6931472, 0.0], + [0.76752836, 0.0, 0.6931472, 0.0], + ], + ) + + def test_invert(self): + vocabulary = ["one", "two", "three"] + single_sample_input_data = [2, 3, 1] + batch_input_data = [[2, 3, 1, 3]] + kwargs = { + "max_tokens": 7, + "num_oov_indices": 1, + "mask_token": "", + "oov_token": "[OOV]", + "vocabulary_dtype": "string", + "vocabulary": vocabulary, + "invert": True, + "output_mode": "int", + } + layer = layers.IndexLookup(**kwargs) + output = layer(single_sample_input_data) + self.assertEqual( + [w.decode("utf-8") for w in output.numpy()], ["one", "two", "[OOV]"] + ) + output = layer(batch_input_data) + self.assertEqual( + [w.decode("utf-8") for w in output.numpy()[0]], + ["one", "two", "[OOV]", "two"], + ) + + @pytest.mark.skipif( + backend.backend() != "tensorflow", reason="Requires string input dtype" + ) + def test_saving(self): + # Test with adapt() + vocabulary = ["one", "two", "three"] + adapt_data = ["one", "one", "one", "two", "two", "three"] + batch_input_data = np.array([["one", "two", "four"]]) + kwargs = { + "max_tokens": 7, + "num_oov_indices": 1, + "mask_token": "", + "oov_token": "[OOV]", + "vocabulary_dtype": "string", + "output_mode": "int", + } + layer = layers.IndexLookup(**kwargs) + layer.adapt(adapt_data) + model = models.Sequential( + [ + layers.Input(shape=(None,), dtype="string"), + layer, + ] + ) + output_1 = model(batch_input_data) + path = os.path.join(self.get_temp_dir(), "model.keras") + model.save(path) + model = saving_api.load_model(path) + output_2 = model(batch_input_data) + self.assertAllClose(output_1, output_2) + + # Test when vocabulary is provided + kwargs["vocabulary"] = vocabulary + layer = layers.IndexLookup(**kwargs) + model = models.Sequential( + [ + layers.Input(shape=(None,), dtype="string"), + layer, + ] + ) + output_1 = model(batch_input_data) + path = os.path.join(self.get_temp_dir(), "model.keras") + model.save(path) + model = saving_api.load_model(path) + output_2 = model(batch_input_data) + self.assertAllClose(output_1, output_2) + + def test_adapt_with_tf_data(self): + # Case: adapt + list inputs + adapt_data = tf_data.Dataset.from_tensor_slices( + ["one", "one", "one", "two", "two", "three"] + ).batch(2) + input_data = ["one", "two", "four"] + kwargs = { + "max_tokens": 7, + "num_oov_indices": 1, + "mask_token": "", + "oov_token": "[OOV]", + "vocabulary_dtype": "string", + } + layer = layers.IndexLookup(**kwargs) + layer.adapt(adapt_data) + self.assertEqual( + layer.get_vocabulary(), ["", "[OOV]", "one", "two", "three"] + ) + self.assertEqual( + layer.get_vocabulary(include_special_tokens=False), + ["one", "two", "three"], + ) + output = layer(input_data) + self.assertEqual(list(output), [2, 3, 1]) + if backend.backend() != "torch": + self.run_class_serialization_test(layer) + + def test_max_tokens_less_than_two(self): + with self.assertRaisesRegex( + ValueError, + "If set, `max_tokens` must be greater than 1.", + ): + layers.IndexLookup( + max_tokens=1, + num_oov_indices=1, + mask_token=None, + oov_token=None, + vocabulary_dtype="int64", + ) + + def test_max_tokens_none_with_pad_to_max_tokens(self): + with self.assertRaisesRegex( + ValueError, + "If pad_to_max_tokens is True, must set `max_tokens`.", + ): + layers.IndexLookup( + num_oov_indices=1, + max_tokens=None, + mask_token=None, + oov_token=None, + vocabulary_dtype="int64", + pad_to_max_tokens=True, + ) + + def test_negative_num_oov_indices(self): + with self.assertRaisesRegex( + ValueError, + "`num_oov_indices` must be greater than or equal to 0.", + ): + layers.IndexLookup( + max_tokens=10, + num_oov_indices=-1, + mask_token=None, + oov_token=None, + vocabulary_dtype="int64", + ) + + def test_invert_with_non_int_output_mode(self): + with self.assertRaisesRegex( + ValueError, r"`output_mode` must be `'int'` when `invert` is true." + ): + layers.IndexLookup( + num_oov_indices=1, + max_tokens=None, + mask_token=None, + oov_token=None, + vocabulary_dtype="string", + invert=True, + output_mode="one_hot", # Invalid combination + ) + + def test_sparse_true_with_int_output_mode(self): + with self.assertRaisesRegex( + ValueError, + r"`sparse` may only be true if `output_mode` is `'one_hot'`", + ): + layers.IndexLookup( + num_oov_indices=1, + max_tokens=None, + mask_token=None, + oov_token=None, + vocabulary_dtype="string", + sparse=True, + output_mode="int", # Invalid combination + ) + + def test_idf_weights_set_with_non_tfidf_output_mode(self): + with self.assertRaisesRegex( + ValueError, + r"`idf_weights` should only be set if `output_mode` is `'tf_idf'`", + ): + layers.IndexLookup( + num_oov_indices=1, + max_tokens=None, + mask_token=None, + oov_token=None, + vocabulary_dtype="string", + idf_weights=[ + 0.5, + 0.1, + 0.3, + ], # Should not be set for non-TF-IDF modes + output_mode="int", + ) + + def test_unrecognized_kwargs(self): + with self.assertRaisesRegex( + ValueError, "Unrecognized keyword argument" + ): + layers.IndexLookup( + num_oov_indices=1, + max_tokens=None, + mask_token=None, + oov_token=None, + vocabulary_dtype="string", + output_mode="int", + # This is an unrecognized argument + extra_arg=True, + ) + + def test_non_tf_idf_with_idf_weights(self): + with self.assertRaisesRegex( + ValueError, + "`idf_weights` should only be set if `output_mode` is", + ): + layers.IndexLookup( + num_oov_indices=1, + max_tokens=None, + mask_token=None, + oov_token=None, + vocabulary_dtype="string", + output_mode="multi_hot", + idf_weights=[ + 0.5, + 0.1, + 0.3, + ], # idf_weights not valid for multi_hot mode + ) + + def test_vocabulary_file_does_not_exist(self): + with self.assertRaisesRegex( + ValueError, + "Vocabulary file path/to/missing_vocab.txt does not exist", + ): + layers.IndexLookup( + num_oov_indices=1, + max_tokens=None, + mask_token=None, + oov_token=None, + vocabulary_dtype="string", + output_mode="int", + # Nonexistent file path + vocabulary="path/to/missing_vocab.txt", + ) + + def test_repeated_tokens_in_vocabulary(self): + with self.assertRaisesRegex( + ValueError, "The passed vocabulary has at least one repeated term." + ): + layers.IndexLookup( + num_oov_indices=1, + max_tokens=None, + mask_token=None, + oov_token=None, + vocabulary_dtype="string", + vocabulary=["token", "token", "unique"], + ) + + def test_mask_token_in_wrong_position(self): + with self.assertRaisesRegex( + ValueError, + "Found reserved mask token at unexpected location in `vocabulary`.", + ): + layers.IndexLookup( + num_oov_indices=1, + max_tokens=None, + mask_token="mask", + oov_token=None, + vocabulary_dtype="string", + vocabulary=[ + "token", + "mask", + "unique", + ], # 'mask' should be at the start if included explicitly + ) + + def test_ensure_known_vocab_size_without_vocabulary(self): + kwargs = { + "num_oov_indices": 1, + # Assume empty string or some default token is valid. + "mask_token": "", + # Assume [OOV] or some default token is valid. + "oov_token": "[OOV]", + "output_mode": "multi_hot", + "pad_to_max_tokens": False, + "vocabulary_dtype": "string", + "max_tokens": None, + } + layer = layers.IndexLookup(**kwargs) + + # Try calling the layer without setting the vocabulary. + with self.assertRaisesRegex( + RuntimeError, "When using `output_mode=multi_hot` and" + ): + input_data = ["sample", "data"] + layer(input_data) diff --git a/keras/src/layers/preprocessing/integer_lookup.py b/keras/src/layers/preprocessing/integer_lookup.py new file mode 100644 index 000000000000..b99da00b3941 --- /dev/null +++ b/keras/src/layers/preprocessing/integer_lookup.py @@ -0,0 +1,409 @@ +import numpy as np + +from keras.src import backend +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.index_lookup import IndexLookup +from keras.src.utils import backend_utils +from keras.src.utils.module_utils import tensorflow as tf + + +@keras_export("keras.layers.IntegerLookup") +class IntegerLookup(IndexLookup): + """A preprocessing layer that maps integers to (possibly encoded) indices. + + This layer maps a set of arbitrary integer input tokens into indexed integer + output via a table-based vocabulary lookup. The layer's output indices will + be contiguously arranged up to the maximum vocab size, even if the input + tokens are non-continguous or unbounded. The layer supports multiple options + for encoding the output via `output_mode`, and has optional support for + out-of-vocabulary (OOV) tokens and masking. + + The vocabulary for the layer must be either supplied on construction or + learned via `adapt()`. During `adapt()`, the layer will analyze a data set, + determine the frequency of individual integer tokens, and create a + vocabulary from them. If the vocabulary is capped in size, the most frequent + tokens will be used to create the vocabulary and all others will be treated + as OOV. + + There are two possible output modes for the layer. When `output_mode` is + `"int"`, input integers are converted to their index in the vocabulary (an + integer). When `output_mode` is `"multi_hot"`, `"count"`, or `"tf_idf"`, + input integers are encoded into an array where each dimension corresponds to + an element in the vocabulary. + + The vocabulary can optionally contain a mask token as well as an OOV token + (which can optionally occupy multiple indices in the vocabulary, as set + by `num_oov_indices`). + The position of these tokens in the vocabulary is fixed. When `output_mode` + is `"int"`, the vocabulary will begin with the mask token at index 0, + followed by OOV indices, followed by the rest of the vocabulary. When + `output_mode` is `"multi_hot"`, `"count"`, or `"tf_idf"` the vocabulary will + begin with OOV indices and instances of the mask token will be dropped. + + **Note:** This layer uses TensorFlow internally. It cannot + be used as part of the compiled computation graph of a model with + any backend other than TensorFlow. + It can however be used with any backend when running eagerly. + It can also always be used as part of an input preprocessing pipeline + with any backend (outside the model itself), which is how we recommend + to use this layer. + + **Note:** This layer is safe to use inside a `tf.data` pipeline + (independently of which backend you're using). + + Args: + max_tokens: Maximum size of the vocabulary for this layer. This should + only be specified when adapting the vocabulary or when setting + `pad_to_max_tokens=True`. If None, there is no cap on the size of + the vocabulary. Note that this size includes the OOV + and mask tokens. Defaults to `None`. + num_oov_indices: The number of out-of-vocabulary tokens to use. + If this value is more than 1, OOV inputs are modulated to + determine their OOV value. + If this value is 0, OOV inputs will cause an error when calling + the layer. Defaults to `1`. + mask_token: An integer token that represents masked inputs. When + `output_mode` is `"int"`, the token is included in vocabulary + and mapped to index 0. In other output modes, + the token will not appear in the vocabulary and instances + of the mask token in the input will be dropped. + If set to None, no mask term will be added. Defaults to `None`. + oov_token: Only used when `invert` is `True`. The token to return + for OOV indices. Defaults to `-1`. + vocabulary: Optional. Either an array of integers or a string path to a + text file. If passing an array, can pass a tuple, list, + 1D NumPy array, or 1D tensor containing the integer vocbulary terms. + If passing a file path, the file should contain one line per term + in the vocabulary. If this argument is set, + there is no need to `adapt()` the layer. + vocabulary_dtype: The dtype of the vocabulary terms. + Only `vocabulary_dtype='int64'` is supported at this time. + Defaults to `"int64"`. + idf_weights: Only valid when `output_mode` is `"tf_idf"`. + A tuple, list, 1D NumPy array, or 1D tensor or the same length + as the vocabulary, containing the floating point inverse document + frequency weights, which will be multiplied by per sample term + counts for the final TF-IDF weight. + If the `vocabulary` argument is set, and `output_mode` is + `"tf_idf"`, this argument must be supplied. + invert: Only valid when `output_mode` is `"int"`. + If `True`, this layer will map indices to vocabulary items + instead of mapping vocabulary items to indices. + Defaults to `False`. + output_mode: Specification for the output of the layer. Values can be + `"int"`, `"one_hot"`, `"multi_hot"`, `"count"`, or `"tf_idf"` + configuring the layer as follows: + - `"int"`: Return the vocabulary indices of the input tokens. + - `"one_hot"`: Encodes each individual element in the input into an + array the same size as the vocabulary, + containing a 1 at the element index. If the last dimension + is size 1, will encode on that dimension. + If the last dimension is not size 1, will append a new + dimension for the encoded output. + - `"multi_hot"`: Encodes each sample in the input into a single + array the same size as the vocabulary, + containing a 1 for each vocabulary term present in the sample. + Treats the last dimension as the sample dimension, + if input shape is `(..., sample_length)`, + output shape will be `(..., num_tokens)`. + - `"count"`: As `"multi_hot"`, but the int array contains + a count of the number of times the token at that index + appeared in the sample. + - `"tf_idf"`: As `"multi_hot"`, but the TF-IDF algorithm is + applied to find the value in each token slot. + For `"int"` output, the output shape matches the input shape. + For `"one_hot"` output, the output shape is + `input_shape + (vocabulary_size,)`, where `input_shape` may + have arbitrary rank. For other output modes (`"multi_hot"`, + `"count"`, `"tf_idf"`), the output shape is `(batch_size, + vocabulary_size)`. Defaults to `"int"`. + pad_to_max_tokens: Only applicable when `output_mode` is `"multi_hot"`, + `"count"`, or `"tf_idf"`. If `True`, the output will have + its feature axis padded to `max_tokens` even if the number + of unique tokens in the vocabulary is less than `max_tokens`, + resulting in a tensor of shape `(batch_size, max_tokens)` + regardless of vocabulary size. Defaults to `False`. + sparse: Boolean. Only applicable to `"multi_hot"`, `"count"`, and + `"tf_idf"` output modes. Only supported with TensorFlow + backend. If `True`, returns a `SparseTensor` + instead of a dense `Tensor`. Defaults to `False`. + + Examples: + + **Creating a lookup layer with a known vocabulary** + + This example creates a lookup layer with a pre-existing vocabulary. + + >>> vocab = [12, 36, 1138, 42] + >>> data = np.array([[12, 1138, 42], [42, 1000, 36]]) # Note OOV tokens + >>> layer = IntegerLookup(vocabulary=vocab) + >>> layer(data) + array([[1, 3, 4], + [4, 0, 2]]) + + **Creating a lookup layer with an adapted vocabulary** + + This example creates a lookup layer and generates the vocabulary by + analyzing the dataset. + + >>> data = np.array([[12, 1138, 42], [42, 1000, 36]]) + >>> layer = IntegerLookup() + >>> layer.adapt(data) + >>> layer.get_vocabulary() + [-1, 42, 1138, 1000, 36, 12] + + Note that the OOV token -1 have been added to the vocabulary. The remaining + tokens are sorted by frequency (42, which has 2 occurrences, is first) then + by inverse sort order. + + >>> data = np.array([[12, 1138, 42], [42, 1000, 36]]) + >>> layer = IntegerLookup() + >>> layer.adapt(data) + >>> layer(data) + array([[5, 2, 1], + [1, 3, 4]]) + + **Lookups with multiple OOV indices** + + This example demonstrates how to use a lookup layer with multiple OOV + indices. When a layer is created with more than one OOV index, any OOV + tokens are hashed into the number of OOV buckets, distributing OOV tokens in + a deterministic fashion across the set. + + >>> vocab = [12, 36, 1138, 42] + >>> data = np.array([[12, 1138, 42], [37, 1000, 36]]) + >>> layer = IntegerLookup(vocabulary=vocab, num_oov_indices=2) + >>> layer(data) + array([[2, 4, 5], + [1, 0, 3]]) + + Note that the output for OOV token 37 is 1, while the output for OOV token + 1000 is 0. The in-vocab terms have their output index increased by 1 from + earlier examples (12 maps to 2, etc) in order to make space for the extra + OOV token. + + **One-hot output** + + Configure the layer with `output_mode='one_hot'`. Note that the first + `num_oov_indices` dimensions in the ont_hot encoding represent OOV values. + + >>> vocab = [12, 36, 1138, 42] + >>> data = np.array([12, 36, 1138, 42, 7]) # Note OOV tokens + >>> layer = IntegerLookup(vocabulary=vocab, output_mode='one_hot') + >>> layer(data) + array([[0., 1., 0., 0., 0.], + [0., 0., 1., 0., 0.], + [0., 0., 0., 1., 0.], + [0., 0., 0., 0., 1.], + [1., 0., 0., 0., 0.]], dtype=float32) + + **Multi-hot output** + + Configure the layer with `output_mode='multi_hot'`. Note that the first + `num_oov_indices` dimensions in the multi_hot encoding represent OOV tokens + + >>> vocab = [12, 36, 1138, 42] + >>> data = np.array([[12, 1138, 42, 42], + ... [42, 7, 36, 7]]) # Note OOV tokens + >>> layer = IntegerLookup(vocabulary=vocab, output_mode='multi_hot') + >>> layer(data) + array([[0., 1., 0., 1., 1.], + [1., 0., 1., 0., 1.]], dtype=float32) + + **Token count output** + + Configure the layer with `output_mode='count'`. As with multi_hot output, + the first `num_oov_indices` dimensions in the output represent OOV tokens. + + >>> vocab = [12, 36, 1138, 42] + >>> data = np.array([[12, 1138, 42, 42], + ... [42, 7, 36, 7]]) # Note OOV tokens + >>> layer = IntegerLookup(vocabulary=vocab, output_mode='count') + >>> layer(data) + array([[0., 1., 0., 1., 2.], + [2., 0., 1., 0., 1.]], dtype=float32) + + **TF-IDF output** + + Configure the layer with `output_mode='tf_idf'`. As with multi_hot output, + the first `num_oov_indices` dimensions in the output represent OOV tokens. + + Each token bin will output `token_count * idf_weight`, where the idf weights + are the inverse document frequency weights per token. These should be + provided along with the vocabulary. Note that the `idf_weight` for OOV + tokens will default to the average of all idf weights passed in. + + >>> vocab = [12, 36, 1138, 42] + >>> idf_weights = [0.25, 0.75, 0.6, 0.4] + >>> data = np.array([[12, 1138, 42, 42], + ... [42, 7, 36, 7]]) # Note OOV tokens + >>> layer = IntegerLookup( + ... output_mode='tf_idf', vocabulary=vocab, idf_weights=idf_weights) + >>> layer(data) + array([[0. , 0.25, 0. , 0.6 , 0.8 ], + [1.0 , 0. , 0.75, 0. , 0.4 ]], dtype=float32) + + To specify the idf weights for oov tokens, you will need to pass the entire + vocabulary including the leading oov token. + + >>> vocab = [-1, 12, 36, 1138, 42] + >>> idf_weights = [0.9, 0.25, 0.75, 0.6, 0.4] + >>> data = np.array([[12, 1138, 42, 42], + ... [42, 7, 36, 7]]) # Note OOV tokens + >>> layer = IntegerLookup( + ... output_mode='tf_idf', vocabulary=vocab, idf_weights=idf_weights) + >>> layer(data) + array([[0. , 0.25, 0. , 0.6 , 0.8 ], + [1.8 , 0. , 0.75, 0. , 0.4 ]], dtype=float32) + + When adapting the layer in `"tf_idf"` mode, each input sample will + be considered a document, and IDF weight per token will be + calculated as: + `log(1 + num_documents / (1 + token_document_count))`. + + **Inverse lookup** + + This example demonstrates how to map indices to tokens using this layer. + (You can also use `adapt()` with `inverse=True`, but for simplicity we'll + pass the vocab in this example.) + + >>> vocab = [12, 36, 1138, 42] + >>> data = np.array([[1, 3, 4], [4, 0, 2]]) + >>> layer = IntegerLookup(vocabulary=vocab, invert=True) + >>> layer(data) + array([[ 12, 1138, 42], + [ 42, -1, 36]]) + + Note that the first index correspond to the oov token by default. + + **Forward and inverse lookup pairs** + + This example demonstrates how to use the vocabulary of a standard lookup + layer to create an inverse lookup layer. + + >>> vocab = [12, 36, 1138, 42] + >>> data = np.array([[12, 1138, 42], [42, 1000, 36]]) + >>> layer = IntegerLookup(vocabulary=vocab) + >>> i_layer = IntegerLookup( + ... vocabulary=layer.get_vocabulary(), invert=True) + >>> int_data = layer(data) + >>> i_layer(int_data) + array([[ 12, 1138, 42], + [ 42, -1, 36]]) + + In this example, the input token 1000 resulted in an output of -1, since + 1000 was not in the vocabulary - it got represented as an OOV, and all OOV + tokens are returned as -1 in the inverse layer. Also, note that for the + inverse to work, you must have already set the forward layer vocabulary + either directly or via `adapt()` before calling `get_vocabulary()`. + """ + + def __init__( + self, + max_tokens=None, + num_oov_indices=1, + mask_token=None, + oov_token=-1, + vocabulary=None, + vocabulary_dtype="int64", + idf_weights=None, + invert=False, + output_mode="int", + sparse=False, + pad_to_max_tokens=False, + name=None, + **kwargs, + ): + if not tf.available: + raise ImportError( + "Layer IntegerLookup requires TensorFlow. " + "Install it via `pip install tensorflow`." + ) + if max_tokens is not None and max_tokens <= 1: + raise ValueError( + "If `max_tokens` is set for `IntegerLookup`, it must be " + f"greater than 1. Received: max_tokens={max_tokens}" + ) + if num_oov_indices < 0: + raise ValueError( + "The value of `num_oov_indices` argument for `IntegerLookup` " + "must >= 0. Received: num_oov_indices=" + f"{num_oov_indices}" + ) + if sparse and backend.backend() != "tensorflow": + raise ValueError( + "`sparse=True` can only be used with the TensorFlow backend." + ) + if vocabulary_dtype != "int64": + raise ValueError( + "Only `vocabulary_dtype='int64'` is supported " + "at this time. Received: " + f"vocabulary_dtype={vocabulary_dtype}" + ) + super().__init__( + max_tokens=max_tokens, + num_oov_indices=num_oov_indices, + mask_token=mask_token, + oov_token=oov_token, + vocabulary=vocabulary, + vocabulary_dtype=vocabulary_dtype, + idf_weights=idf_weights, + invert=invert, + output_mode=output_mode, + sparse=sparse, + pad_to_max_tokens=pad_to_max_tokens, + name=name, + **kwargs, + ) + self._convert_input_args = False + self._allow_non_tensor_positional_args = True + self.supports_jit = False + + def adapt(self, data, steps=None): + """Computes a vocabulary of integer terms from tokens in a dataset. + + Calling `adapt()` on an `IntegerLookup` layer is an alternative to + passing in a precomputed vocabulary on construction via the + `vocabulary` argument. An `IntegerLookup` layer should always be either + adapted over a dataset or supplied with a vocabulary. + + During `adapt()`, the layer will build a vocabulary of all integer + tokens seen in the dataset, sorted by occurrence count, with ties broken + by sort order of the tokens (high to low). At the end of `adapt()`, if + `max_tokens` is set, the vocabulary will be truncated to `max_tokens` + size. For example, adapting a layer with `max_tokens=1000` will compute + the 1000 most frequent tokens occurring in the input dataset. If + `output_mode='tf-idf'`, `adapt()` will also learn the document + frequencies of each token in the input dataset. + + Arguments: + data: The data to train on. It can be passed either as a + batched `tf.data.Dataset`, as a list of integers, + or as a NumPy array. + steps: Integer or `None`. + Total number of steps (batches of samples) to process. + If `data` is a `tf.data.Dataset`, and `steps` is `None`, + `adapt()` will run until the input dataset is exhausted. + When passing an infinitely + repeating dataset, you must specify the `steps` argument. This + argument is not supported with array inputs or list inputs. + """ + super().adapt(data, steps=steps) + + def get_config(self): + config = super().get_config() + if config["oov_token"] is not None: + config["oov_token"] = int(config["oov_token"]) + if config["mask_token"] is not None: + config["mask_token"] = int(config["mask_token"]) + if config["vocabulary"] is not None: + config["vocabulary"] = [int(v) for v in config["vocabulary"]] + return config + + def call(self, inputs): + if not isinstance( + inputs, (tf.Tensor, tf.RaggedTensor, np.ndarray, list, tuple) + ): + inputs = tf.convert_to_tensor(backend.convert_to_numpy(inputs)) + outputs = super().call(inputs) + return backend_utils.convert_tf_tensor(outputs) diff --git a/keras/src/layers/preprocessing/integer_lookup_test.py b/keras/src/layers/preprocessing/integer_lookup_test.py new file mode 100644 index 000000000000..9e2ed6482b26 --- /dev/null +++ b/keras/src/layers/preprocessing/integer_lookup_test.py @@ -0,0 +1,155 @@ +import numpy as np +from tensorflow import data as tf_data + +from keras.src import backend +from keras.src import layers +from keras.src import testing + + +class IntegerLookupTest(testing.TestCase): + # TODO: increase coverage. Most features aren't being tested. + + def test_config(self): + layer = layers.IntegerLookup( + output_mode="int", + vocabulary=[1, 2, 3], + oov_token=1, + mask_token=0, + ) + self.run_class_serialization_test(layer) + + def test_adapt_flow(self): + adapt_data = [1, 1, 1, 2, 2, 3] + single_sample_input_data = [1, 2, 4] + batch_input_data = [[1, 2, 4], [2, 3, 5]] + + # int mode + layer = layers.IntegerLookup( + output_mode="int", + ) + layer.adapt(adapt_data) + output = layer(single_sample_input_data) + self.assertTrue(backend.is_tensor(output)) + self.assertAllClose(output, np.array([1, 2, 0])) + output = layer(batch_input_data) + self.assertTrue(backend.is_tensor(output)) + self.assertAllClose(output, np.array([[1, 2, 0], [2, 3, 0]])) + + # one_hot mode + layer = layers.IntegerLookup( + output_mode="one_hot", + ) + layer.adapt(adapt_data) + output = layer(single_sample_input_data) + self.assertTrue(backend.is_tensor(output)) + self.assertAllClose( + output, np.array([[0, 1, 0, 0], [0, 0, 1, 0], [1, 0, 0, 0]]) + ) + + # multi_hot mode + layer = layers.IntegerLookup( + output_mode="multi_hot", + ) + layer.adapt(adapt_data) + output = layer(single_sample_input_data) + self.assertTrue(backend.is_tensor(output)) + self.assertAllClose(output, np.array([1, 1, 1, 0])) + + # tf_idf mode + layer = layers.IntegerLookup( + output_mode="tf_idf", + ) + layer.adapt(adapt_data) + output = layer(single_sample_input_data) + self.assertTrue(backend.is_tensor(output)) + self.assertAllClose( + output, np.array([1.133732, 0.916291, 1.098612, 0.0]) + ) + + # count mode + layer = layers.IntegerLookup( + output_mode="count", + ) + layer.adapt(adapt_data) + output = layer([1, 2, 3, 4, 1, 2, 1]) + self.assertTrue(backend.is_tensor(output)) + self.assertAllClose(output, np.array([1, 3, 2, 1])) + + def test_fixed_vocabulary(self): + layer = layers.IntegerLookup( + output_mode="int", + vocabulary=[1, 2, 3, 4], + ) + input_data = [2, 3, 4, 5] + output = layer(input_data) + self.assertTrue(backend.is_tensor(output)) + self.assertAllClose(output, np.array([2, 3, 4, 0])) + + def test_set_vocabulary(self): + layer = layers.IntegerLookup( + output_mode="int", + ) + layer.set_vocabulary([1, 2, 3, 4]) + input_data = [2, 3, 4, 5] + output = layer(input_data) + self.assertTrue(backend.is_tensor(output)) + self.assertAllClose(output, np.array([2, 3, 4, 0])) + + def test_tf_data_compatibility(self): + layer = layers.IntegerLookup( + output_mode="int", + vocabulary=[1, 2, 3, 4], + ) + input_data = [2, 3, 4, 5] + ds = tf_data.Dataset.from_tensor_slices(input_data).batch(4).map(layer) + output = next(iter(ds)).numpy() + self.assertAllClose(output, np.array([2, 3, 4, 0])) + + def test_one_hot_output_with_higher_rank_input(self): + input_data = np.array([[1, 2], [3, 0]]) + vocabulary = [1, 2, 3] + layer = layers.IntegerLookup( + vocabulary=vocabulary, output_mode="one_hot" + ) + output_data = layer(input_data) + self.assertEqual(output_data.shape, (2, 2, 4)) + expected_output = np.array( + [ + [[0, 1, 0, 0], [0, 0, 1, 0]], + [[0, 0, 0, 1], [1, 0, 0, 0]], + ] + ) + self.assertAllClose(output_data, expected_output) + output_data_3d = layer(np.expand_dims(input_data, axis=0)) + self.assertEqual(output_data_3d.shape, (1, 2, 2, 4)) + self.assertAllClose( + output_data_3d, np.expand_dims(expected_output, axis=0) + ) + + def test_multi_hot_output_shape(self): + input_data = np.array([[1, 2], [3, 0]]) + vocabulary = [1, 2, 3] + layer = layers.IntegerLookup( + vocabulary=vocabulary, output_mode="multi_hot" + ) + output_data = layer(input_data) + self.assertEqual(output_data.shape, (2, 4)) + + def test_count_output_shape(self): + input_data = np.array([[1, 2], [3, 0]]) + vocabulary = [1, 2, 3] + layer = layers.IntegerLookup(vocabulary=vocabulary, output_mode="count") + output_data = layer(input_data) + self.assertEqual(output_data.shape, (2, 4)) + + def test_tf_idf_output_shape(self): + input_data = np.array([[1, 2], [3, 0]]) + vocabulary = [1, 2, 3] + idf_weights = [1.0, 1.0, 1.0] + layer = layers.IntegerLookup( + vocabulary=vocabulary, + idf_weights=idf_weights, + output_mode="tf_idf", + ) + output_data = layer(input_data) + self.assertEqual(output_data.shape, (2, 4)) diff --git a/keras/src/layers/preprocessing/mel_spectrogram.py b/keras/src/layers/preprocessing/mel_spectrogram.py new file mode 100644 index 000000000000..ed3022d86b9a --- /dev/null +++ b/keras/src/layers/preprocessing/mel_spectrogram.py @@ -0,0 +1,374 @@ +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.data_layer import DataLayer + +# mel spectrum constants. +_MEL_BREAK_FREQUENCY_HERTZ = 700.0 +_MEL_HIGH_FREQUENCY_Q = 1127.0 + + +@keras_export("keras.layers.MelSpectrogram") +class MelSpectrogram(DataLayer): + """A preprocessing layer to convert raw audio signals to Mel spectrograms. + + This layer takes `float32`/`float64` single or batched audio signal as + inputs and computes the Mel spectrogram using Short-Time Fourier Transform + and Mel scaling. The input should be a 1D (unbatched) or 2D (batched) tensor + representing audio signals. The output will be a 2D or 3D tensor + representing Mel spectrograms. + + A spectrogram is an image-like representation that shows the frequency + spectrum of a signal over time. It uses x-axis to represent time, y-axis to + represent frequency, and each pixel to represent intensity. + Mel spectrograms are a special type of spectrogram that use the mel scale, + which approximates how humans perceive sound. They are commonly used in + speech and music processing tasks like speech recognition, speaker + identification, and music genre classification. + + **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline + (independently of which backend you're using). + + References: + - [Spectrogram](https://en.wikipedia.org/wiki/Spectrogram), + - [Mel scale](https://en.wikipedia.org/wiki/Mel_scale). + + Args: + fft_length: Integer, size of the FFT window. + sequence_stride: Integer, number of samples between successive STFT + columns. + sequence_length: Integer, size of the window used for applying + `window` to each audio frame. If `None`, defaults to `fft_length`. + window: String, name of the window function to use. Available values + are `"hann"` and `"hamming"`. If `window` is a tensor, it will be + used directly as the window and its length must be + `sequence_length`. If `window` is `None`, no windowing is + used. Defaults to `"hann"`. + sampling_rate: Integer, sample rate of the input signal. + num_mel_bins: Integer, number of mel bins to generate. + min_freq: Float, minimum frequency of the mel bins. + max_freq: Float, maximum frequency of the mel bins. + If `None`, defaults to `sampling_rate / 2`. + power_to_db: If True, convert the power spectrogram to decibels. + top_db: Float, minimum negative cut-off `max(10 * log10(S)) - top_db`. + mag_exp: Float, exponent for the magnitude spectrogram. + 1 for magnitude, 2 for power, etc. Default is 2. + ref_power: Float, the power is scaled relative to it + `10 * log10(S / ref_power)`. + min_power: Float, minimum value for power and `ref_power`. + + Examples: + + **Unbatched audio signal** + + >>> layer = keras.layers.MelSpectrogram(num_mel_bins=64, + ... sampling_rate=8000, + ... sequence_stride=256, + ... fft_length=2048) + >>> layer(keras.random.uniform(shape=(16000,))).shape + (64, 63) + + **Batched audio signal** + + >>> layer = keras.layers.MelSpectrogram(num_mel_bins=80, + ... sampling_rate=8000, + ... sequence_stride=128, + ... fft_length=2048) + >>> layer(keras.random.uniform(shape=(2, 16000))).shape + (2, 80, 125) + + Input shape: + 1D (unbatched) or 2D (batched) tensor with shape:`(..., samples)`. + + Output shape: + 2D (unbatched) or 3D (batched) tensor with + shape:`(..., num_mel_bins, time)`. + + """ + + def __init__( + self, + fft_length=2048, + sequence_stride=512, + sequence_length=None, + window="hann", + sampling_rate=16000, + num_mel_bins=128, + min_freq=20.0, + max_freq=None, + power_to_db=True, + top_db=80.0, + mag_exp=2.0, + min_power=1e-10, + ref_power=1.0, + **kwargs, + ): + self.fft_length = fft_length + self.sequence_stride = sequence_stride + self.sequence_length = sequence_length or fft_length + self.window = window + self.sampling_rate = sampling_rate + self.num_mel_bins = num_mel_bins + self.min_freq = min_freq + self.max_freq = max_freq or int(sampling_rate / 2) + self.power_to_db = power_to_db + self.top_db = top_db + self.mag_exp = mag_exp + self.min_power = min_power + self.ref_power = ref_power + super().__init__(**kwargs) + + def call(self, inputs): + dtype = ( + "float32" + if self.compute_dtype not in ["float32", "float64"] + else self.compute_dtype + ) # jax, tf supports only "float32" and "float64" in stft + inputs = self.backend.convert_to_tensor(inputs, dtype=dtype) + outputs = self._spectrogram(inputs) + outputs = self._melscale(outputs) + if self.power_to_db: + outputs = self._dbscale(outputs) + # swap time & freq axis to have shape of (..., num_mel_bins, time) + outputs = self.backend.numpy.swapaxes(outputs, -1, -2) + outputs = self.backend.cast(outputs, self.compute_dtype) + return outputs + + def _spectrogram(self, inputs): + real, imag = self.backend.math.stft( + inputs, + sequence_length=self.sequence_length, + sequence_stride=self.sequence_stride, + fft_length=self.fft_length, + window=self.window, + center=True, + ) + # abs of complex = sqrt(real^2 + imag^2) + spec = self.backend.numpy.sqrt( + self.backend.numpy.add( + self.backend.numpy.square(real), self.backend.numpy.square(imag) + ) + ) + spec = self.backend.numpy.power(spec, self.mag_exp) + return spec + + def _melscale(self, inputs): + matrix = self.linear_to_mel_weight_matrix( + num_mel_bins=self.num_mel_bins, + num_spectrogram_bins=self.backend.shape(inputs)[-1], + sampling_rate=self.sampling_rate, + lower_edge_hertz=self.min_freq, + upper_edge_hertz=self.max_freq, + ) + return self.backend.numpy.tensordot(inputs, matrix, axes=1) + + def _dbscale(self, inputs): + log_spec = 10.0 * ( + self.backend.numpy.log10( + self.backend.numpy.maximum(inputs, self.min_power) + ) + ) + ref_value = self.backend.numpy.abs( + self.backend.convert_to_tensor(self.ref_power) + ) + log_spec -= 10.0 * self.backend.numpy.log10( + self.backend.numpy.maximum(ref_value, self.min_power) + ) + log_spec = self.backend.numpy.maximum( + log_spec, self.backend.numpy.max(log_spec) - self.top_db + ) + return log_spec + + def _hertz_to_mel(self, frequencies_hertz): + """Converts frequencies in `frequencies_hertz` in Hertz to the + mel scale. + + Args: + frequencies_hertz: A tensor of frequencies in Hertz. + name: An optional name for the operation. + + Returns: + A tensor of the same shape and type of `frequencies_hertz` + containing frequencies in the mel scale. + """ + return _MEL_HIGH_FREQUENCY_Q * self.backend.numpy.log( + 1.0 + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ) + ) + + def linear_to_mel_weight_matrix( + self, + num_mel_bins=20, + num_spectrogram_bins=129, + sampling_rate=8000, + lower_edge_hertz=125.0, + upper_edge_hertz=3800.0, + dtype="float32", + ): + """Returns a matrix to warp linear scale spectrograms to the mel scale. + + Returns a weight matrix that can be used to re-weight a tensor + containing `num_spectrogram_bins` linearly sampled frequency information + from `[0, sampling_rate / 2]` into `num_mel_bins` frequency information + from `[lower_edge_hertz, upper_edge_hertz]` on the mel scale. + + This function follows the [Hidden Markov Model Toolkit (HTK)]( + http://htk.eng.cam.ac.uk/) convention, defining the mel scale in + terms of a frequency in hertz according to the following formula: + + ```mel(f) = 2595 * log10( 1 + f/700)``` + + In the returned matrix, all the triangles (filterbanks) have a peak + value of 1.0. + + For example, the returned matrix `A` can be used to right-multiply a + spectrogram `S` of shape `[frames, num_spectrogram_bins]` of linear + scale spectrum values (e.g. STFT magnitudes) to generate a + "mel spectrogram" `M` of shape `[frames, num_mel_bins]`. + + ``` + # `S` has shape [frames, num_spectrogram_bins] + # `M` has shape [frames, num_mel_bins] + M = keras.ops.matmul(S, A) + ``` + + The matrix can be used with `keras.ops.tensordot` to convert an + arbitrary rank `Tensor` of linear-scale spectral bins into the + mel scale. + + ``` + # S has shape [..., num_spectrogram_bins]. + # M has shape [..., num_mel_bins]. + M = keras.ops.tensordot(S, A, 1) + ``` + + References: + - [Mel scale (Wikipedia)](https://en.wikipedia.org/wiki/Mel_scale) + + Args: + num_mel_bins: Python int. How many bands in the resulting + mel spectrum. + num_spectrogram_bins: An integer `Tensor`. How many bins there are + in the source spectrogram data, which is understood to be + `fft_size // 2 + 1`, i.e. the spectrogram only contains the + nonredundant FFT bins. + sampling_rate: An integer or float `Tensor`. Samples per second of + the input signal used to create the spectrogram. Used to figure + out the frequencies corresponding to each spectrogram bin, + which dictates how they are mapped into the mel scale. + lower_edge_hertz: Python float. Lower bound on the frequencies to be + included in the mel spectrum. This corresponds to the lower + edge of the lowest triangular band. + upper_edge_hertz: Python float. The desired top edge of the highest + frequency band. + dtype: The `DType` of the result matrix. Must be a floating point + type. + + Returns: + A tensor of shape `[num_spectrogram_bins, num_mel_bins]`. + """ + + # This function can be constant folded by graph optimization since + # there are no Tensor inputs. + sampling_rate = self.backend.cast(sampling_rate, dtype) + lower_edge_hertz = self.backend.convert_to_tensor( + lower_edge_hertz, + dtype, + ) + upper_edge_hertz = self.backend.convert_to_tensor( + upper_edge_hertz, + dtype, + ) + zero = self.backend.convert_to_tensor(0.0, dtype) + + # HTK excludes the spectrogram DC bin. + bands_to_zero = 1 + nyquist_hertz = sampling_rate / 2.0 + linear_frequencies = self.backend.numpy.linspace( + zero, nyquist_hertz, num_spectrogram_bins + )[bands_to_zero:] + spectrogram_bins_mel = self.backend.numpy.expand_dims( + self._hertz_to_mel(linear_frequencies), 1 + ) + + # Compute num_mel_bins triples of (lower_edge, center, upper_edge). The + # center of each band is the lower and upper edge of the adjacent bands. + # Accordingly, we divide [lower_edge_hertz, upper_edge_hertz] into + # num_mel_bins + 2 pieces. + band_edges_mel = self.backend.math.extract_sequences( + self.backend.numpy.linspace( + self._hertz_to_mel(lower_edge_hertz), + self._hertz_to_mel(upper_edge_hertz), + num_mel_bins + 2, + ), + sequence_length=3, + sequence_stride=1, + ) + + # Split the triples up and reshape them into [1, num_mel_bins] tensors. + lower_edge_mel, center_mel, upper_edge_mel = tuple( + self.backend.numpy.reshape(t, [1, num_mel_bins]) + for t in self.backend.numpy.split(band_edges_mel, 3, axis=1) + ) + + # Calculate lower and upper slopes for every spectrogram bin. + # Line segments are linear in the mel domain, not Hertz. + lower_slopes = (spectrogram_bins_mel - lower_edge_mel) / ( + center_mel - lower_edge_mel + ) + upper_slopes = (upper_edge_mel - spectrogram_bins_mel) / ( + upper_edge_mel - center_mel + ) + + # Intersect the line segments with each other and zero. + mel_weights_matrix = self.backend.numpy.maximum( + zero, self.backend.numpy.minimum(lower_slopes, upper_slopes) + ) + + # Re-add the zeroed lower bins we sliced out above. + return self.backend.numpy.pad( + mel_weights_matrix, + [[bands_to_zero, 0], [0, 0]], + ) + + def compute_output_shape(self, input_shape): + if len(input_shape) == 1: + output_shape = [ + self.num_mel_bins, + ( + (input_shape[0] + self.sequence_stride + 1) + // self.sequence_stride + if input_shape[0] is not None + else None + ), + ] + else: + output_shape = [ + input_shape[0], + self.num_mel_bins, + ( + (input_shape[1] + self.sequence_stride + 1) + // self.sequence_stride + if input_shape[1] is not None + else None + ), + ] + return output_shape + + def get_config(self): + config = super().get_config() + config.update( + { + "fft_length": self.fft_length, + "sequence_stride": self.sequence_stride, + "sequence_length": self.sequence_length, + "window": self.window, + "sampling_rate": self.sampling_rate, + "num_mel_bins": self.num_mel_bins, + "min_freq": self.min_freq, + "max_freq": self.max_freq, + "power_to_db": self.power_to_db, + "top_db": self.top_db, + "mag_exp": self.mag_exp, + "min_power": self.min_power, + "ref_power": self.ref_power, + } + ) + return config diff --git a/keras/src/layers/preprocessing/mel_spectrogram_test.py b/keras/src/layers/preprocessing/mel_spectrogram_test.py new file mode 100644 index 000000000000..dcd3f7edd08c --- /dev/null +++ b/keras/src/layers/preprocessing/mel_spectrogram_test.py @@ -0,0 +1,103 @@ +import numpy as np +import pytest +from absl.testing import parameterized +from tensorflow import data as tf_data + +from keras.src import layers +from keras.src import testing + + +class MelSpectrogramTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_mel_spectrogram_basics(self): + self.run_layer_test( + layers.MelSpectrogram, + init_kwargs={ + "num_mel_bins": 80, + "sampling_rate": 8000, + "sequence_stride": 128, + "fft_length": 2048, + }, + input_shape=(2, 16000), + expected_output_shape=(2, 80, 126), + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=False, + ) + self.run_layer_test( + layers.MelSpectrogram, + init_kwargs={ + "num_mel_bins": 80, + "sampling_rate": 8000, + "sequence_stride": 128, + "fft_length": 2048, + }, + input_shape=(16000,), + expected_output_shape=(80, 126), + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=False, + ) + + @parameterized.parameters( + [ + ((2, 16000), 80, 128, 2048, 8000, False), + ((16000,), 80, 128, 2048, 8000, False), + ((2, 16001), 80, 128, 2048, 16000, False), + ((16001,), 80, 128, 2048, 8000, False), + ((2, 8000), 128, 64, 512, 32000, False), + ((8000,), 128, 64, 512, 32000, False), + ((2, 8000), 128, 64, 512, 32000, True), + ((8000,), 128, 64, 512, 32000, True), + ] + ) + def test_output_shape( + self, + input_shape, + num_mel_bins, + sequence_stride, + fft_length, + sampling_rate, + all_zero, + ): + if all_zero: + audios = np.zeros(input_shape) + else: + audios = np.random.random(input_shape) + out = layers.MelSpectrogram( + num_mel_bins=num_mel_bins, + sequence_stride=sequence_stride, + fft_length=fft_length, + sampling_rate=sampling_rate, + )(audios) + if len(input_shape) == 1: + ref_shape = ( + num_mel_bins, + (input_shape[0] + sequence_stride + 1) // sequence_stride, + ) + else: + ref_shape = ( + input_shape[0], + num_mel_bins, + (input_shape[1] + sequence_stride + 1) // sequence_stride, + ) + self.assertEqual(tuple(out.shape), ref_shape) + + def test_tf_data_compatibility(self): + input_shape = (2, 16000) + output_shape = (2, 80, 126) + layer = layers.MelSpectrogram( + num_mel_bins=80, + sampling_rate=8000, + sequence_stride=128, + fft_length=2048, + ) + input_data = np.random.random(input_shape) + ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) + for output in ds.take(1): + output = output.numpy() + self.assertEqual(tuple(output.shape), output_shape) diff --git a/keras/src/layers/preprocessing/normalization.py b/keras/src/layers/preprocessing/normalization.py new file mode 100644 index 000000000000..8ea0d439b31b --- /dev/null +++ b/keras/src/layers/preprocessing/normalization.py @@ -0,0 +1,363 @@ +import math + +import numpy as np + +from keras.src import backend +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.data_layer import DataLayer +from keras.src.utils.module_utils import tensorflow as tf + + +@keras_export("keras.layers.Normalization") +class Normalization(DataLayer): + """A preprocessing layer that normalizes continuous features. + + This layer will shift and scale inputs into a distribution centered around + 0 with standard deviation 1. It accomplishes this by precomputing the mean + and variance of the data, and calling `(input - mean) / sqrt(var)` at + runtime. + + The mean and variance values for the layer must be either supplied on + construction or learned via `adapt()`. `adapt()` will compute the mean and + variance of the data and store them as the layer's weights. `adapt()` should + be called before `fit()`, `evaluate()`, or `predict()`. + + **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline + (independently of which backend you're using). + + Args: + axis: Integer, tuple of integers, or None. The axis or axes that should + have a separate mean and variance for each index in the shape. + For example, if shape is `(None, 5)` and `axis=1`, the layer will + track 5 separate mean and variance values for the last axis. + If `axis` is set to `None`, the layer will normalize + all elements in the input by a scalar mean and variance. + When `-1`, the last axis of the input is assumed to be a + feature dimension and is normalized per index. + Note that in the specific case of batched scalar inputs where + the only axis is the batch axis, the default will normalize + each index in the batch separately. + In this case, consider passing `axis=None`. Defaults to `-1`. + mean: The mean value(s) to use during normalization. The passed value(s) + will be broadcast to the shape of the kept axes above; + if the value(s) cannot be broadcast, an error will be raised when + this layer's `build()` method is called. + `mean` and `variance` must be specified together. + variance: The variance value(s) to use during normalization. The passed + value(s) will be broadcast to the shape of the kept axes above; + if the value(s) cannot be broadcast, an error will be raised when + this layer's `build()` method is called. + `mean` and `variance` must be specified together. + invert: If `True`, this layer will apply the inverse transformation + to its inputs: it would turn a normalized input back into its + original form. + + Examples: + + Calculate a global mean and variance by analyzing the dataset in `adapt()`. + + >>> adapt_data = np.array([1., 2., 3., 4., 5.], dtype='float32') + >>> input_data = np.array([1., 2., 3.], dtype='float32') + >>> layer = keras.layers.Normalization(axis=None) + >>> layer.adapt(adapt_data) + >>> layer(input_data) + array([-1.4142135, -0.70710677, 0.], dtype=float32) + + Calculate a mean and variance for each index on the last axis. + + >>> adapt_data = np.array([[0., 7., 4.], + ... [2., 9., 6.], + ... [0., 7., 4.], + ... [2., 9., 6.]], dtype='float32') + >>> input_data = np.array([[0., 7., 4.]], dtype='float32') + >>> layer = keras.layers.Normalization(axis=-1) + >>> layer.adapt(adapt_data) + >>> layer(input_data) + array([-1., -1., -1.], dtype=float32) + + Pass the mean and variance directly. + + >>> input_data = np.array([[1.], [2.], [3.]], dtype='float32') + >>> layer = keras.layers.Normalization(mean=3., variance=2.) + >>> layer(input_data) + array([[-1.4142135 ], + [-0.70710677], + [ 0. ]], dtype=float32) + + Use the layer to de-normalize inputs (after adapting the layer). + + >>> adapt_data = np.array([[0., 7., 4.], + ... [2., 9., 6.], + ... [0., 7., 4.], + ... [2., 9., 6.]], dtype='float32') + >>> input_data = np.array([[1., 2., 3.]], dtype='float32') + >>> layer = keras.layers.Normalization(axis=-1, invert=True) + >>> layer.adapt(adapt_data) + >>> layer(input_data) + array([2., 10., 8.], dtype=float32) + """ + + def __init__( + self, axis=-1, mean=None, variance=None, invert=False, **kwargs + ): + super().__init__(**kwargs) + # Standardize `axis` to a tuple. + if axis is None: + axis = () + elif isinstance(axis, int): + axis = (axis,) + else: + axis = tuple(axis) + self.axis = axis + + self.input_mean = mean + self.input_variance = variance + self.invert = invert + self.supports_masking = True + self._build_input_shape = None + self.mean = None + + # Set `mean` and `variance` if passed. + if (mean is not None) != (variance is not None): + raise ValueError( + "When setting values directly, both `mean` and `variance` " + f"must be set. Received: mean={mean} and variance={variance}" + ) + + def build(self, input_shape): + if input_shape is None: + return + + ndim = len(input_shape) + self._build_input_shape = input_shape + + if any(a < -ndim or a >= ndim for a in self.axis): + raise ValueError( + "All `axis` values must be in the range [-ndim, ndim). " + f"Received inputs with ndim={ndim}, while axis={self.axis}" + ) + + # Axes to be kept, replacing negative values with positive equivalents. + # Sorted to avoid transposing axes. + self._keep_axis = tuple( + sorted([d if d >= 0 else d + ndim for d in self.axis]) + ) + # All axes to be kept should have known shape. + for d in self._keep_axis: + if input_shape[d] is None: + raise ValueError( + "All `axis` values to be kept must have a known shape. " + f"Received axis={self.axis}, " + f"inputs.shape={input_shape}, " + f"with unknown axis at index {d}" + ) + # Axes to be reduced. + self._reduce_axis = tuple( + d for d in range(ndim) if d not in self._keep_axis + ) + # 1 if an axis should be reduced, 0 otherwise. + self._reduce_axis_mask = [ + 0 if d in self._keep_axis else 1 for d in range(ndim) + ] + # Broadcast any reduced axes. + self._broadcast_shape = [ + input_shape[d] if d in self._keep_axis else 1 for d in range(ndim) + ] + mean_and_var_shape = tuple(input_shape[d] for d in self._keep_axis) + self._mean_and_var_shape = mean_and_var_shape + + if self.input_mean is None: + self.adapt_mean = self.add_weight( + name="mean", + shape=mean_and_var_shape, + initializer="zeros", + trainable=False, + ) + self.adapt_variance = self.add_weight( + name="variance", + shape=mean_and_var_shape, + initializer="ones", + trainable=False, + ) + # For backwards compatibility with older saved models. + self.count = self.add_weight( + name="count", + shape=(), + dtype="int", + initializer="zeros", + trainable=False, + ) + self.built = True + self.finalize_state() + else: + # In the no adapt case, make constant tensors for mean and variance + # with proper broadcast shape for use during call. + mean = ops.convert_to_tensor(self.input_mean) + variance = ops.convert_to_tensor(self.input_variance) + mean = ops.broadcast_to(mean, self._broadcast_shape) + variance = ops.broadcast_to(variance, self._broadcast_shape) + self.mean = ops.cast(mean, dtype=self.compute_dtype) + self.variance = ops.cast(variance, dtype=self.compute_dtype) + + def adapt(self, data): + """Computes the mean and variance of values in a dataset. + + Calling `adapt()` on a `Normalization` layer is an alternative to + passing in `mean` and `variance` arguments during layer construction. A + `Normalization` layer should always either be adapted over a dataset or + passed `mean` and `variance`. + + During `adapt()`, the layer will compute a `mean` and `variance` + separately for each position in each axis specified by the `axis` + argument. To calculate a single `mean` and `variance` over the input + data, simply pass `axis=None` to the layer. + + Arg: + data: The data to train on. It can be passed either as a + `tf.data.Dataset`, as a NumPy array, or as a backend-native + eager tensor. + If a dataset, *it must be batched*. Keras will assume that the + data is batched, and if that assumption doesn't hold, the mean + and variance may be incorrectly computed. + """ + if isinstance(data, np.ndarray) or backend.is_tensor(data): + input_shape = data.shape + elif isinstance(data, tf.data.Dataset): + input_shape = tuple(data.element_spec.shape) + if len(input_shape) == 1: + # Batch dataset if it isn't batched + data = data.batch(128) + input_shape = tuple(data.element_spec.shape) + + if not self.built: + self.build(input_shape) + else: + for d in self._keep_axis: + if input_shape[d] != self._build_input_shape[d]: + raise ValueError( + "The layer was built with " + f"input_shape={self._build_input_shape}, " + "but adapt() is being called with data with " + f"an incompatible shape, data.shape={input_shape}" + ) + + if isinstance(data, np.ndarray): + total_mean = np.mean(data, axis=self._reduce_axis) + total_var = np.var(data, axis=self._reduce_axis) + elif backend.is_tensor(data): + total_mean = ops.mean(data, axis=self._reduce_axis) + total_var = ops.var(data, axis=self._reduce_axis) + elif isinstance(data, tf.data.Dataset): + total_mean = ops.zeros(self._mean_and_var_shape) + total_var = ops.zeros(self._mean_and_var_shape) + total_count = 0 + for batch in data: + batch = backend.convert_to_tensor( + batch, dtype=self.compute_dtype + ) + batch_mean = ops.mean(batch, axis=self._reduce_axis) + batch_var = ops.var(batch, axis=self._reduce_axis) + if self._reduce_axis: + batch_reduce_shape = ( + batch.shape[d] for d in self._reduce_axis + ) + batch_count = math.prod(batch_reduce_shape) + else: + batch_count = 1 + + total_count += batch_count + batch_weight = float(batch_count) / total_count + existing_weight = 1.0 - batch_weight + new_total_mean = ( + total_mean * existing_weight + batch_mean * batch_weight + ) + # The variance is computed using the lack-of-fit sum of squares + # formula (see + # https://en.wikipedia.org/wiki/Lack-of-fit_sum_of_squares). + total_var = ( + total_var + (total_mean - new_total_mean) ** 2 + ) * existing_weight + ( + batch_var + (batch_mean - new_total_mean) ** 2 + ) * batch_weight + total_mean = new_total_mean + else: + raise NotImplementedError(f"Unsupported data type: {type(data)}") + + self.adapt_mean.assign(total_mean) + self.adapt_variance.assign(total_var) + self.finalize_state() + + def finalize_state(self): + if self.input_mean is not None or not self.built: + return + + # In the adapt case, we make constant tensors for mean and variance with + # proper broadcast shape and dtype each time `finalize_state` is called. + self.mean = ops.reshape(self.adapt_mean, self._broadcast_shape) + self.mean = ops.cast(self.mean, self.compute_dtype) + self.variance = ops.reshape(self.adapt_variance, self._broadcast_shape) + self.variance = ops.cast(self.variance, self.compute_dtype) + + def call(self, inputs): + # This layer can be called in tf.data + # even with another backend after it has been adapted. + # However it must use backend-native logic for adapt(). + if self.mean is None: + # May happen when in tf.data when mean/var was passed explicitly + raise ValueError( + "You must call `.build(input_shape)` " + "on the layer before using it." + ) + inputs = self.backend.core.convert_to_tensor( + inputs, dtype=self.compute_dtype + ) + # Ensure the weights are in the correct backend. Without this, it is + # possible to cause breakage when using this layer in tf.data. + mean = self.convert_weight(self.mean) + variance = self.convert_weight(self.variance) + if self.invert: + return self.backend.numpy.add( + mean, + self.backend.numpy.multiply( + inputs, + self.backend.numpy.maximum( + self.backend.numpy.sqrt(variance), backend.epsilon() + ), + ), + ) + else: + return self.backend.numpy.divide( + self.backend.numpy.subtract(inputs, mean), + self.backend.numpy.maximum( + self.backend.numpy.sqrt(variance), backend.epsilon() + ), + ) + + def compute_output_shape(self, input_shape): + return input_shape + + def get_config(self): + config = super().get_config() + config.update( + { + "axis": self.axis, + "invert": self.invert, + "mean": np.array(self.input_mean).tolist(), + "variance": np.array(self.input_variance).tolist(), + } + ) + return config + + def load_own_variables(self, store): + super().load_own_variables(store) + # Ensure that we call finalize_state after variable loading. + self.finalize_state() + + def get_build_config(self): + if self._build_input_shape: + return {"input_shape": self._build_input_shape} + + def build_from_config(self, config): + if config: + self.build(config["input_shape"]) diff --git a/keras/src/layers/preprocessing/normalization_test.py b/keras/src/layers/preprocessing/normalization_test.py new file mode 100644 index 000000000000..70dea3787002 --- /dev/null +++ b/keras/src/layers/preprocessing/normalization_test.py @@ -0,0 +1,171 @@ +import numpy as np +import pytest +from absl.testing import parameterized +from tensorflow import data as tf_data + +from keras.src import backend +from keras.src import layers +from keras.src import testing + + +class NormalizationTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_normalization_basics(self): + self.run_layer_test( + layers.Normalization, + init_kwargs={ + "axis": -1, + }, + input_shape=(2, 3), + expected_output_shape=(2, 3), + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=3, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=True, + ) + self.run_layer_test( + layers.Normalization, + init_kwargs={ + "axis": -1, + "mean": np.array([0.5, 0.2, -0.1]), + "variance": np.array([0.1, 0.2, 0.3]), + }, + input_shape=(2, 3), + expected_output_shape=(2, 3), + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=True, + ) + self.run_layer_test( + layers.Normalization, + init_kwargs={ + "axis": -1, + "mean": np.array([0.5, 0.2, -0.1]), + "variance": np.array([0.1, 0.2, 0.3]), + "invert": True, + }, + input_shape=(2, 3), + expected_output_shape=(2, 3), + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=True, + ) + + @parameterized.parameters([("np",), ("tensor",), ("tf.data")]) + def test_normalization_adapt(self, input_type): + x = np.random.random((32, 4)) + if input_type == "np": + data = x + elif input_type == "tensor": + data = backend.convert_to_tensor(x) + elif input_type == "tf.data": + data = tf_data.Dataset.from_tensor_slices(x).batch(8) + else: + raise NotImplementedError(input_type) + + layer = layers.Normalization() + layer.adapt(data) + self.assertTrue(layer.built) + output = layer(x) + output = backend.convert_to_numpy(output) + self.assertAllClose(np.var(output, axis=0), 1.0, atol=1e-5) + self.assertAllClose(np.mean(output, axis=0), 0.0, atol=1e-5) + + # Test in high-dim and with tuple axis. + x = np.random.random((32, 4, 3, 5)) + if input_type == "np": + data = x + elif input_type == "tensor": + data = backend.convert_to_tensor(x) + elif input_type == "tf.data": + data = tf_data.Dataset.from_tensor_slices(x).batch(8) + + layer = layers.Normalization(axis=(1, 2)) + layer.adapt(data) + self.assertTrue(layer.built) + output = layer(x) + output = backend.convert_to_numpy(output) + self.assertAllClose(np.var(output, axis=(0, 3)), 1.0, atol=1e-5) + self.assertAllClose(np.mean(output, axis=(0, 3)), 0.0, atol=1e-5) + + @pytest.mark.skipif( + backend.backend() != "torch", + reason="Test symbolic call for torch meta device.", + ) + def test_call_on_meta_device_after_built(self): + layer = layers.Normalization() + data = np.random.random((32, 4)) + layer.adapt(data) + with backend.device("meta"): + layer(data) + + def test_normalization_with_mean_only_raises_error(self): + # Test error when only `mean` is provided + with self.assertRaisesRegex( + ValueError, "both `mean` and `variance` must be set" + ): + layers.Normalization(mean=0.5) + + def test_normalization_with_variance_only_raises_error(self): + # Test error when only `variance` is provided + with self.assertRaisesRegex( + ValueError, "both `mean` and `variance` must be set" + ): + layers.Normalization(variance=0.1) + + def test_normalization_axis_too_high(self): + with self.assertRaisesRegex( + ValueError, "All `axis` values must be in the range" + ): + layer = layers.Normalization(axis=3) + layer.build((2, 2)) + + def test_normalization_axis_too_low(self): + with self.assertRaisesRegex( + ValueError, "All `axis` values must be in the range" + ): + layer = layers.Normalization(axis=-4) + layer.build((2, 3, 4)) + + def test_normalization_unknown_axis_shape(self): + with self.assertRaisesRegex(ValueError, "All `axis` values to be kept"): + layer = layers.Normalization(axis=1) + layer.build((None, None)) + + def test_normalization_adapt_with_incompatible_shape(self): + layer = layers.Normalization(axis=-1) + initial_shape = (10, 5) + layer.build(initial_shape) + new_shape_data = np.random.random((10, 3)) + with self.assertRaisesRegex(ValueError, "an incompatible shape"): + layer.adapt(new_shape_data) + + def test_tf_data_compatibility(self): + x = np.random.random((32, 3)) + ds = tf_data.Dataset.from_tensor_slices(x).batch(1) + + # With built-in values + layer = layers.Normalization( + mean=[0.1, 0.2, 0.3], variance=[0.1, 0.2, 0.3], axis=-1 + ) + layer.build((None, 3)) + for output in ds.map(layer).take(1): + output.numpy() + + # With adapt flow + layer = layers.Normalization(axis=-1) + layer.adapt( + np.random.random((32, 3)), + ) + for output in ds.map(layer).take(1): + output.numpy() + + def test_normalization_with_scalar_mean_var(self): + input_data = np.array([[1, 2, 3]], dtype="float32") + layer = layers.Normalization(mean=3.0, variance=2.0) + layer(input_data) diff --git a/keras/src/layers/preprocessing/pipeline.py b/keras/src/layers/preprocessing/pipeline.py new file mode 100644 index 000000000000..7890eff95533 --- /dev/null +++ b/keras/src/layers/preprocessing/pipeline.py @@ -0,0 +1,84 @@ +from keras.src import tree +from keras.src.api_export import keras_export +from keras.src.layers.layer import Layer +from keras.src.saving import serialization_lib + + +@keras_export("keras.layers.Pipeline") +class Pipeline(Layer): + """Applies a series of layers to an input. + + This class is useful to build a preprocessing pipeline, + in particular an image data augmentation pipeline. + Compared to a `Sequential` model, `Pipeline` features + a few important differences: + + - It's not a `Model`, just a plain layer. + - When the layers in the pipeline are compatible + with `tf.data`, the pipeline will also + remain `tf.data` compatible. That is to say, + the pipeline will not attempt to convert + its inputs to backend-native tensors + when in a tf.data context (unlike a `Sequential` + model). + + Example: + + ```python + from keras import layers + preprocessing_pipeline = layers.Pipeline([ + layers.AutoContrast(), + layers.RandomZoom(0.2), + layers.RandomRotation(0.2), + ]) + + # `ds` is a tf.data.Dataset + preprocessed_ds = ds.map( + preprocessing_pipeline, + num_parallel_calls=4, + ) + ``` + """ + + def __init__(self, layers, name=None): + super().__init__(name=name) + self._pipeline_layers = layers + self._convert_input_args = False + self._allow_non_tensor_positional_args = True + + @property + def layers(self): + return self._pipeline_layers + + def call(self, inputs, training=True, mask=None): + for layer in self._pipeline_layers: + kwargs = {} + if layer._call_has_mask_arg: + kwargs["mask"] = mask + if layer._call_has_training_arg and training is not None: + kwargs["training"] = training + outputs = layer(inputs, **kwargs) + inputs = outputs + + def _get_mask_from_keras_tensor(kt): + return getattr(kt, "_keras_mask", None) + + mask = tree.map_structure(_get_mask_from_keras_tensor, outputs) + return outputs + + @classmethod + def from_config(cls, config): + config["layers"] = [ + serialization_lib.deserialize_keras_object(x) + for x in config["layers"] + ] + return cls(**config) + + def get_config(self): + config = { + "layers": serialization_lib.serialize_keras_object( + self._pipeline_layers + ), + "name": self.name, + } + return config diff --git a/keras/src/layers/preprocessing/pipeline_test.py b/keras/src/layers/preprocessing/pipeline_test.py new file mode 100644 index 000000000000..dc02d75966c1 --- /dev/null +++ b/keras/src/layers/preprocessing/pipeline_test.py @@ -0,0 +1,92 @@ +import numpy as np +import pytest +from tensorflow import data as tf_data + +from keras.src import backend +from keras.src import layers +from keras.src import testing + + +class CanaryLayer(layers.Layer): + def __init__(self): + super().__init__() + self.training = None + self.received_mask = False + + def call(self, x, training=False, mask=None): + self.training = training + if mask is not None: + self.received_mask = True + return x + + def compute_mask(self, x, mask=None): + return x + + def compute_output_shape(self, input_shape): + return input_shape + + +class PipelineTest(testing.TestCase): + def test_basics(self): + run_training_check = False if backend.backend() == "numpy" else True + self.run_layer_test( + layers.Pipeline, + init_kwargs={ + "layers": [layers.AutoContrast(), layers.RandomBrightness(0.1)], + }, + input_shape=(8, 3, 4, 3), + supports_masking=False, + expected_output_shape=(8, 3, 4, 3), + run_mixed_precision_check=False, + run_training_check=run_training_check, + ) + + @pytest.mark.skipif( + backend.backend() == "numpy", reason="masking not working in numpy" + ) + def test_correctness(self): + pipeline = layers.Pipeline([CanaryLayer(), CanaryLayer()]) + x = np.array([0]) + mask = np.array([0]) + pipeline(x, training=True, mask=mask) + self.assertTrue(pipeline.layers[0].training) + self.assertTrue(pipeline.layers[0].received_mask) + self.assertTrue(pipeline.layers[1].training) + self.assertTrue(pipeline.layers[1].received_mask) + + def test_tf_data_compatibility(self): + if backend.config.image_data_format() == "channels_last": + input_shape = (2, 10, 12, 3) + output_shape = (2, 8, 9, 3) + else: + input_shape = (2, 3, 10, 12) + output_shape = (2, 3, 8, 9) + layer = layers.Pipeline( + [ + layers.AutoContrast(), + layers.CenterCrop(8, 9), + ] + ) + input_data = np.random.random(input_shape) + ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) + for output in ds.take(1): + output = output.numpy() + self.assertEqual(tuple(output.shape), output_shape) + + @pytest.mark.skipif( + backend.backend() == "torch", + reason="Fails on CI, passes locally. TODO: debug", + ) + def test_from_config(self): + pipeline = layers.Pipeline( + [ + layers.AutoContrast(), + layers.CenterCrop(8, 9), + ] + ) + x = np.ones((2, 10, 12, 3)) + output = pipeline(x) + restored = layers.Pipeline.from_config(pipeline.get_config()) + restored_output = restored(x) + self.assertEqual(tuple(output.shape), (2, 8, 9, 3)) + self.assertAllClose(output, restored_output) diff --git a/keras/src/layers/preprocessing/rescaling.py b/keras/src/layers/preprocessing/rescaling.py new file mode 100644 index 000000000000..77b34150c22e --- /dev/null +++ b/keras/src/layers/preprocessing/rescaling.py @@ -0,0 +1,78 @@ +from keras.src import backend +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.data_layer import DataLayer +from keras.src.saving import serialization_lib + + +@keras_export("keras.layers.Rescaling") +class Rescaling(DataLayer): + """A preprocessing layer which rescales input values to a new range. + + This layer rescales every value of an input (often an image) by multiplying + by `scale` and adding `offset`. + + For instance: + + 1. To rescale an input in the `[0, 255]` range + to be in the `[0, 1]` range, you would pass `scale=1./255`. + + 2. To rescale an input in the `[0, 255]` range to be in the `[-1, 1]` range, + you would pass `scale=1./127.5, offset=-1`. + + The rescaling is applied both during training and inference. Inputs can be + of integer or floating point dtype, and by default the layer will output + floats. + + **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline + (independently of which backend you're using). + + Args: + scale: Float, the scale to apply to the inputs. + offset: Float, the offset to apply to the inputs. + **kwargs: Base layer keyword arguments, such as `name` and `dtype`. + """ + + def __init__(self, scale, offset=0.0, **kwargs): + super().__init__(**kwargs) + self.scale = scale + self.offset = offset + self.supports_masking = True + + def call(self, inputs): + dtype = self.compute_dtype + scale = self.backend.cast(self.scale, dtype) + offset = self.backend.cast(self.offset, dtype) + scale_shape = self.backend.core.shape(scale) + if ( + len(scale_shape) > 0 + and backend.image_data_format() == "channels_first" + ): + scale = self.backend.numpy.reshape( + scale, scale_shape + (1,) * (3 - len(scale_shape)) + ) + return self.backend.cast(inputs, dtype) * scale + offset + + def compute_output_shape(self, input_shape): + return input_shape + + def get_config(self): + config = super().get_config() + config.update( + { + # `scale` and `offset` might be numpy array. + "scale": serialization_lib.serialize_keras_object(self.scale), + "offset": serialization_lib.serialize_keras_object(self.offset), + } + ) + return config + + @classmethod + def from_config(cls, config, custom_objects=None): + config = config.copy() + config["scale"] = serialization_lib.deserialize_keras_object( + config["scale"], custom_objects=custom_objects + ) + config["offset"] = serialization_lib.deserialize_keras_object( + config["offset"], custom_objects=custom_objects + ) + return cls(**config) diff --git a/keras/src/layers/preprocessing/rescaling_test.py b/keras/src/layers/preprocessing/rescaling_test.py new file mode 100644 index 000000000000..a2863821f28e --- /dev/null +++ b/keras/src/layers/preprocessing/rescaling_test.py @@ -0,0 +1,119 @@ +import grain +import numpy as np +import pytest +from tensorflow import data as tf_data + +from keras.src import backend +from keras.src import layers +from keras.src import testing + + +class RescalingTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_rescaling_basics(self): + self.run_layer_test( + layers.Rescaling, + init_kwargs={"scale": 1.0 / 255, "offset": 0.5}, + input_shape=(2, 3), + expected_output_shape=(2, 3), + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=True, + ) + + @pytest.mark.requires_trainable_backend + def test_rescaling_dtypes(self): + # int scale + self.run_layer_test( + layers.Rescaling, + init_kwargs={"scale": 2, "offset": 0.5}, + input_shape=(2, 3), + expected_output_shape=(2, 3), + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=True, + ) + # int offset + self.run_layer_test( + layers.Rescaling, + init_kwargs={"scale": 1.0, "offset": 2}, + input_shape=(2, 3), + expected_output_shape=(2, 3), + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=True, + ) + # int inputs + self.run_layer_test( + layers.Rescaling, + init_kwargs={"scale": 1.0 / 255, "offset": 0.5}, + input_shape=(2, 3), + input_dtype="int16", + expected_output_shape=(2, 3), + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=True, + ) + + def test_rescaling_correctness(self): + layer = layers.Rescaling(scale=1.0 / 255, offset=0.5) + x = np.random.random((3, 10, 10, 3)) * 255 + out = layer(x) + self.assertAllClose(out, x / 255 + 0.5) + + def test_tf_data_compatibility(self): + layer = layers.Rescaling(scale=1.0 / 255, offset=0.5) + x = np.random.random((3, 10, 10, 3)) * 255 + ds = tf_data.Dataset.from_tensor_slices(x).batch(3).map(layer) + next(iter(ds)).numpy() + + def test_grain_compatibility(self): + layer = layers.Rescaling(scale=1.0 / 255, offset=0.5) + x = np.random.random((3, 10, 10, 3)) * 255 + ds = grain.MapDataset.source(x).to_iter_dataset().batch(3).map(layer) + output = next(iter(ds)) + + self.assertTrue(backend.is_tensor(output)) + # Ensure the device of the data is on CPU. + if backend.backend() == "tensorflow": + self.assertIn("CPU", str(output.device)) + elif backend.backend() == "jax": + self.assertIn("CPU", str(output.device)) + elif backend.backend() == "torch": + self.assertEqual("cpu", str(output.device)) + + def test_rescaling_with_channels_first_and_vector_scale(self): + config = backend.image_data_format() + backend.set_image_data_format("channels_first") + layer = layers.Rescaling( + scale=[1.0 / 255, 1.5 / 255, 2.0 / 255], offset=0.5 + ) + x = np.random.random((2, 3, 10, 10)) * 255 + layer(x) + backend.set_image_data_format(config) + + @pytest.mark.requires_trainable_backend + def test_numpy_args(self): + # https://github.com/keras-team/keras/issues/20072 + self.run_layer_test( + layers.Rescaling, + init_kwargs={ + "scale": np.array(1.0 / 255.0), + "offset": np.array(0.5), + }, + input_shape=(2, 3), + expected_output_shape=(2, 3), + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=True, + ) diff --git a/keras/src/layers/preprocessing/stft_spectrogram.py b/keras/src/layers/preprocessing/stft_spectrogram.py new file mode 100644 index 000000000000..f8ef0db98281 --- /dev/null +++ b/keras/src/layers/preprocessing/stft_spectrogram.py @@ -0,0 +1,383 @@ +import math +import warnings + +from keras.src import backend +from keras.src import initializers +from keras.src import layers +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.utils.module_utils import scipy + + +@keras_export("keras.layers.STFTSpectrogram") +class STFTSpectrogram(layers.Layer): + """Layer to compute the Short-Time Fourier Transform (STFT) on a 1D signal. + + A layer that computes Spectrograms of the input signal to produce + a spectrogram. This layers utilizes Short-Time Fourier Transform (STFT) by + The layer computes Spectrograms based on STFT by utilizing convolution + kernels, which allows parallelization on GPUs and trainable kernels for + fine-tuning support. This layer allows different modes of output + (e.g., log-scaled magnitude, phase, power spectral density, etc.) and + provides flexibility in windowing, padding, and scaling options for the + STFT calculation. + + Examples: + + Apply it as a non-trainable preprocessing layer on 3 audio tracks of + 1 channel, 10 seconds and sampled at 16 kHz. + + >>> layer = keras.layers.STFTSpectrogram( + ... mode='log', + ... frame_length=256, + ... frame_step=128, # 50% overlap + ... fft_length=512, + ... window="hann", + ... padding="valid", + ... trainable=False, # non-trainable, preprocessing only + ... ) + >>> layer(keras.random.uniform(shape=(3, 160000, 1))).shape + (3, 1249, 257) + + Apply it as a trainable processing layer on 3 stereo audio tracks of + 2 channels, 10 seconds and sampled at 16 kHz. This is initialized as the + non-trainable layer, but then can be trained jointly within a model. + + >>> layer = keras.layers.STFTSpectrogram( + ... mode='log', + ... frame_length=256, + ... frame_step=128, # 50% overlap + ... fft_length=512, + ... window="hamming", # hamming windowing function + ... padding="same", # padding to preserve the time dimension + ... trainable=True, # trainable, this is the default in keras + ... ) + >>> layer(keras.random.uniform(shape=(3, 160000, 2))).shape + (3, 1250, 514) + + Similar to the last example, but add an extra dimension so the output is + an image to be used with image models. We apply this here on a signal of + 3 input channels to output an image tensor, hence is directly applicable + with an image model. + + >>> layer = keras.layers.STFTSpectrogram( + ... mode='log', + ... frame_length=256, + ... frame_step=128, + ... fft_length=512, + ... padding="same", + ... expand_dims=True, # this adds the extra dimension + ... ) + >>> layer(keras.random.uniform(shape=(3, 160000, 3))).shape + (3, 1250, 257, 3) + + Args: + mode: String, the output type of the spectrogram. Can be one of + `"log"`, `"magnitude`", `"psd"`, `"real`", `"imag`", `"angle`", + `"stft`". Defaults to `"log`". + frame_length: Integer, The length of each frame (window) for STFT in + samples. Defaults to 256. + frame_step: Integer, the step size (hop length) between + consecutive frames. If not provided, defaults to half the + frame_length. Defaults to `frame_length // 2`. + fft_length: Integer, the size of frequency bins used in the Fast-Fourier + Transform (FFT) to apply to each frame. Should be greater than or + equal to `frame_length`. Recommended to be a power of two. Defaults + to the smallest power of two that is greater than or equal + to `frame_length`. + window: (String or array_like), the windowing function to apply to each + frame. Can be `"hann`" (default), `"hamming`", or a custom window + provided as an array_like. + periodic: Boolean, if True, the window function will be treated as + periodic. Defaults to `False`. + scaling: String, type of scaling applied to the window. Can be + `"density`", `"spectrum`", or None. Default is `"density`". + padding: String, padding strategy. Can be `"valid`" or `"same`". + Defaults to `"valid"`. + expand_dims: Boolean, if True, will expand the output into spectrograms + into two dimensions to be compatible with image models. + Defaults to `False`. + data_format: String, either `"channels_last"` or `"channels_first"`. + The ordering of the dimensions in the inputs. `"channels_last"` + corresponds to inputs with shape `(batch, height, width, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch, channels, height, weight)`. Defaults to `"channels_last"`. + + Raises: + ValueError: If an invalid value is provided for `"mode`", `"scaling`", + `"padding`", or other input arguments. + TypeError: If the input data type is not one of `"float16`", + `"float32`", or `"float64`". + + Input shape: + A 3D tensor of shape `(batch_size, time_length, input_channels)`, if + `data_format=="channels_last"`, and of shape + `(batch_size, input_channels, time_length)` if + `data_format=="channels_first"`, where `time_length` is the length of + the input signal, and `input_channels` is the number of input channels. + The same kernels are applied to each channel independently. + + Output shape: + If `data_format=="channels_first" and not expand_dims`, a 3D tensor: + `(batch_size, input_channels * freq_channels, new_time_length)` + If `data_format=="channels_last" and not expand_dims`, a 3D tensor: + `(batch_size, new_time_length, input_channels * freq_channels)` + If `data_format=="channels_first" and expand_dims`, a 4D tensor: + `(batch_size, input_channels, new_time_length, freq_channels)` + If `data_format=="channels_last" and expand_dims`, a 4D tensor: + `(batch_size, new_time_length, freq_channels, input_channels)` + + where `new_time_length` depends on the padding, and `freq_channels` is + the number of FFT bins `(fft_length // 2 + 1)`. + """ + + def __init__( + self, + mode="log", + frame_length=256, + frame_step=None, + fft_length=None, + window="hann", + periodic=False, + scaling="density", + padding="valid", + expand_dims=False, + data_format=None, + **kwargs, + ): + if frame_step is not None and ( + frame_step > frame_length or frame_step < 1 + ): + raise ValueError( + "`frame_step` should be a positive integer not greater than " + f"`frame_length`. Received frame_step={frame_step}, " + f"frame_length={frame_length}" + ) + + if fft_length is not None and fft_length < frame_length: + raise ValueError( + "`fft_length` should be not less than `frame_length`. " + f"Received fft_length={fft_length}, frame_length={frame_length}" + ) + + if fft_length is not None and (fft_length & -fft_length) != fft_length: + warnings.warn( + "`fft_length` is recommended to be a power of two. " + f"Received fft_length={fft_length}" + ) + + all_modes = ["log", "magnitude", "psd", "real", "imag", "angle", "stft"] + + if mode not in all_modes: + raise ValueError( + "Output mode is invalid, it must be one of " + f"{', '.join(all_modes)}. Received: mode={mode}" + ) + + if scaling is not None and scaling not in ["density", "spectrum"]: + raise ValueError( + "Scaling is invalid, it must be `None`, 'density' " + f"or 'spectrum'. Received scaling={scaling}" + ) + + if padding not in ["valid", "same"]: + raise ValueError( + "Padding is invalid, it should be 'valid', 'same'. " + f"Received: padding={padding}" + ) + + if isinstance(window, str): + # throws an exception for invalid window function + scipy.signal.get_window(window, 1) + + super().__init__(**kwargs) + + self.mode = mode + + self.frame_length = frame_length + self.frame_step = frame_step + self._frame_step = frame_step or self.frame_length // 2 + self.fft_length = fft_length + self._fft_length = fft_length or ( + 2 ** int(math.ceil(math.log2(frame_length))) + ) + + self.window = window + self.periodic = periodic + self.scaling = scaling + self.padding = padding + self.expand_dims = expand_dims + self.data_format = backend.standardize_data_format(data_format) + self.input_spec = layers.input_spec.InputSpec(ndim=3) + + def build(self, input_shape): + shape = (self.frame_length, 1, self._fft_length // 2 + 1) + + if self.mode != "imag": + self.real_kernel = self.add_weight( + name="real_kernel", + shape=shape, + initializer=initializers.STFT( + "real", self.window, self.scaling, self.periodic + ), + ) + if self.mode != "real": + self.imag_kernel = self.add_weight( + name="imag_kernel", + shape=shape, + initializer=initializers.STFT( + "imag", self.window, self.scaling, self.periodic + ), + ) + + def _adjust_shapes(self, outputs): + _, channels, freq_channels, time_seq = ops.shape(outputs) + batch_size = -1 + if self.data_format == "channels_last": + if self.expand_dims: + outputs = ops.transpose(outputs, [0, 3, 2, 1]) + # [batch_size, time_seq, freq_channels, input_channels] + else: + outputs = ops.reshape( + outputs, + [batch_size, channels * freq_channels, time_seq], + ) + # [batch_size, input_channels * freq_channels, time_seq] + outputs = ops.transpose(outputs, [0, 2, 1]) + else: + if self.expand_dims: + outputs = ops.transpose(outputs, [0, 1, 3, 2]) + # [batch_size, channels, time_seq, freq_channels] + else: + outputs = ops.reshape( + outputs, + [batch_size, channels * freq_channels, time_seq], + ) + return outputs + + def _apply_conv(self, inputs, kernel): + if self.data_format == "channels_last": + _, time_seq, channels = ops.shape(inputs) + inputs = ops.transpose(inputs, [0, 2, 1]) + inputs = ops.reshape(inputs, [-1, time_seq, 1]) + else: + _, channels, time_seq = ops.shape(inputs) + inputs = ops.reshape(inputs, [-1, 1, time_seq]) + + outputs = ops.conv( + inputs, + ops.cast(kernel, backend.standardize_dtype(inputs.dtype)), + padding=self.padding, + strides=self._frame_step, + data_format=self.data_format, + ) + batch_size = -1 + if self.data_format == "channels_last": + _, time_seq, freq_channels = ops.shape(outputs) + outputs = ops.transpose(outputs, [0, 2, 1]) + outputs = ops.reshape( + outputs, + [batch_size, channels, freq_channels, time_seq], + ) + else: + _, freq_channels, time_seq = ops.shape(outputs) + outputs = ops.reshape( + outputs, + [batch_size, channels, freq_channels, time_seq], + ) + return outputs + + def call(self, inputs): + dtype = inputs.dtype + if backend.standardize_dtype(dtype) not in { + "float16", + "float32", + "float64", + }: + raise TypeError( + "Invalid input type. Expected `float16`, `float32` or " + f"`float64`. Received: input type={dtype}" + ) + + real_signal = None + imag_signal = None + power = None + + if self.mode != "imag": + real_signal = self._apply_conv(inputs, self.real_kernel) + if self.mode != "real": + imag_signal = self._apply_conv(inputs, self.imag_kernel) + + if self.mode == "real": + return self._adjust_shapes(real_signal) + elif self.mode == "imag": + return self._adjust_shapes(imag_signal) + elif self.mode == "angle": + return self._adjust_shapes(ops.arctan2(imag_signal, real_signal)) + elif self.mode == "stft": + return self._adjust_shapes( + ops.concatenate([real_signal, imag_signal], axis=2) + ) + else: + power = ops.square(real_signal) + ops.square(imag_signal) + + if self.mode == "psd": + return self._adjust_shapes( + power + + ops.pad( + power[:, :, 1:-1, :], [[0, 0], [0, 0], [1, 1], [0, 0]] + ) + ) + linear_stft = self._adjust_shapes( + ops.sqrt(ops.maximum(power, backend.epsilon())) + ) + + if self.mode == "magnitude": + return linear_stft + else: + return ops.log(ops.maximum(linear_stft, backend.epsilon())) + + def compute_output_shape(self, input_shape): + if self.data_format == "channels_last": + channels = input_shape[-1] + else: + channels = input_shape[1] + freq_channels = self._fft_length // 2 + 1 + if self.mode == "stft": + freq_channels *= 2 + shape = ops.operation_utils.compute_conv_output_shape( + input_shape, + freq_channels * channels, + (self.frame_length,), + strides=self._frame_step, + padding=self.padding, + data_format=self.data_format, + ) + if self.data_format == "channels_last": + batch_size, time_seq, _ = shape + else: + batch_size, _, time_seq = shape + if self.expand_dims: + if self.data_format == "channels_last": + return (batch_size, time_seq, freq_channels, channels) + else: + return (batch_size, channels, time_seq, freq_channels) + return shape + + def get_config(self): + config = super().get_config() + config.update( + { + "mode": self.mode, + "frame_length": self.frame_length, + "frame_step": self.frame_step, + "fft_length": self.fft_length, + "window": self.window, + "periodic": self.periodic, + "scaling": self.scaling, + "padding": self.padding, + "data_format": self.data_format, + "expand_dims": self.expand_dims, + } + ) + return config diff --git a/keras/src/layers/preprocessing/stft_spectrogram_test.py b/keras/src/layers/preprocessing/stft_spectrogram_test.py new file mode 100644 index 000000000000..a363393d776e --- /dev/null +++ b/keras/src/layers/preprocessing/stft_spectrogram_test.py @@ -0,0 +1,393 @@ +import numpy as np +import pytest +import scipy.signal +import tensorflow as tf + +from keras import Input +from keras import Sequential +from keras.src import backend +from keras.src import layers +from keras.src import testing + + +class TestSpectrogram(testing.TestCase): + DTYPE = "float32" + + @staticmethod + def _calc_spectrograms( + x, mode, scaling, window, periodic, frame_length, frame_step, fft_length + ): + data_format = backend.image_data_format() + input_shape = (None, 1) if data_format == "channels_last" else (1, None) + + layer = Sequential( + [ + Input(shape=input_shape, dtype=TestSpectrogram.DTYPE), + layers.STFTSpectrogram( + mode=mode, + frame_length=frame_length, + frame_step=frame_step, + fft_length=fft_length, + window=window, + scaling=scaling, + periodic=periodic, + dtype=TestSpectrogram.DTYPE, + ), + ] + ) + if data_format == "channels_first": + y = layer.predict(np.transpose(x, [0, 2, 1]), verbose=0) + y = np.transpose(y, [0, 2, 1]) + else: + y = layer.predict(x, verbose=0) + + window_arr = scipy.signal.get_window(window, frame_length, periodic) + _, _, spec = scipy.signal.spectrogram( + x[..., 0].astype(TestSpectrogram.DTYPE), + window=window_arr.astype(TestSpectrogram.DTYPE), + nperseg=frame_length, + noverlap=frame_length - frame_step, + mode=mode, + scaling=scaling, + detrend=False, + nfft=fft_length, + ) + y_true = np.transpose(spec, [0, 2, 1]) + return y_true, y + + @pytest.mark.requires_trainable_backend + def test_spectrogram_channels_broadcasting(self): + rnd = np.random.RandomState(41) + audio = rnd.uniform(-1, 1, size=(3, 16000, 7)) + + layer_last = Sequential( + [ + Input(shape=(None, 7), dtype=self.DTYPE), + layers.STFTSpectrogram( + mode="psd", dtype=self.DTYPE, data_format="channels_last" + ), + ] + ) + layer_single = Sequential( + [ + Input(shape=(None, 1), dtype=self.DTYPE), + layers.STFTSpectrogram( + mode="psd", dtype=self.DTYPE, data_format="channels_last" + ), + ] + ) + + layer_expand = Sequential( + [ + Input(shape=(None, 7), dtype=self.DTYPE), + layers.STFTSpectrogram( + mode="psd", + dtype=self.DTYPE, + data_format="channels_last", + expand_dims=True, + ), + ] + ) + + y_last = layer_last.predict(audio, verbose=0) + y_expanded = layer_expand.predict(audio, verbose=0) + y_singles = [ + layer_single.predict(audio[..., i : i + 1], verbose=0) + for i in range(audio.shape[-1]) + ] + + self.assertAllClose(y_last, np.concatenate(y_singles, axis=-1)) + self.assertAllClose(y_expanded, np.stack(y_singles, axis=-1)) + + @pytest.mark.skipif( + backend.backend() == "tensorflow", + reason="TF doesn't support channels_first", + ) + @pytest.mark.requires_trainable_backend + def test_spectrogram_channels_first(self): + rnd = np.random.RandomState(41) + audio = rnd.uniform(-1, 1, size=(3, 16000, 7)) + + layer_first = Sequential( + [ + Input(shape=(7, None), dtype=self.DTYPE), + layers.STFTSpectrogram( + mode="psd", dtype=self.DTYPE, data_format="channels_first" + ), + ] + ) + layer_last = Sequential( + [ + Input(shape=(None, 7), dtype=self.DTYPE), + layers.STFTSpectrogram( + mode="psd", dtype=self.DTYPE, data_format="channels_last" + ), + ] + ) + layer_single = Sequential( + [ + Input(shape=(None, 1), dtype=self.DTYPE), + layers.STFTSpectrogram( + mode="psd", dtype=self.DTYPE, data_format="channels_last" + ), + ] + ) + layer_expand = Sequential( + [ + Input(shape=(7, None), dtype=self.DTYPE), + layers.STFTSpectrogram( + mode="psd", + dtype=self.DTYPE, + data_format="channels_first", + expand_dims=True, + ), + ] + ) + + y_singles = [ + layer_single.predict(audio[..., i : i + 1], verbose=0) + for i in range(audio.shape[-1]) + ] + y_expanded = layer_expand.predict( + np.transpose(audio, [0, 2, 1]), verbose=0 + ) + y_last = layer_last.predict(audio, verbose=0) + y_first = layer_first.predict(np.transpose(audio, [0, 2, 1]), verbose=0) + self.assertAllClose(np.transpose(y_first, [0, 2, 1]), y_last) + self.assertAllClose(y_expanded, np.stack(y_singles, axis=1)) + self.assertAllClose( + y_first, + np.transpose(np.concatenate(y_singles, axis=-1), [0, 2, 1]), + ) + self.run_layer_test( + layers.STFTSpectrogram, + init_kwargs={ + "frame_length": 150, + "frame_step": 10, + "fft_length": 512, + "trainable": False, + "padding": "same", + "expand_dims": True, + "data_format": "channels_first", + }, + input_shape=(2, 3, 160000), + expected_output_shape=(2, 3, 160000 // 10, 257), + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=2, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=False, + ) + + @pytest.mark.requires_trainable_backend + def test_spectrogram_basics(self): + self.run_layer_test( + layers.STFTSpectrogram, + init_kwargs={ + "frame_length": 500, + "frame_step": 25, + "fft_length": 1024, + "mode": "stft", + "data_format": "channels_last", + }, + input_shape=(2, 16000, 1), + expected_output_shape=(2, 15500 // 25 + 1, 513 * 2), + expected_num_trainable_weights=2, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=False, + ) + + self.run_layer_test( + layers.STFTSpectrogram, + init_kwargs={ + "frame_length": 150, + "frame_step": 71, + "fft_length": 4096, + "mode": "real", + "data_format": "channels_last", + }, + input_shape=(2, 160000, 1), + expected_output_shape=(2, 159850 // 71 + 1, 2049), + expected_num_trainable_weights=1, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=False, + ) + + self.run_layer_test( + layers.STFTSpectrogram, + init_kwargs={ + "frame_length": 150, + "frame_step": 43, + "fft_length": 512, + "mode": "imag", + "padding": "same", + "data_format": "channels_last", + }, + input_shape=(2, 160000, 1), + expected_output_shape=(2, 160000 // 43 + 1, 257), + expected_num_trainable_weights=1, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=False, + ) + self.run_layer_test( + layers.STFTSpectrogram, + init_kwargs={ + "frame_length": 150, + "frame_step": 10, + "fft_length": 512, + "trainable": False, + "padding": "same", + "data_format": "channels_last", + }, + input_shape=(2, 160000, 3), + expected_output_shape=(2, 160000 // 10, 257 * 3), + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=2, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=False, + ) + self.run_layer_test( + layers.STFTSpectrogram, + init_kwargs={ + "frame_length": 150, + "frame_step": 10, + "fft_length": 512, + "trainable": False, + "padding": "same", + "expand_dims": True, + "data_format": "channels_last", + }, + input_shape=(2, 160000, 3), + expected_output_shape=(2, 160000 // 10, 257, 3), + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=2, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=False, + ) + + @pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="Backend does not support dynamic shapes", + ) + def test_spectrogram_dynamic_shape(self): + model = Sequential( + [ + Input(shape=(None, 1), dtype=TestSpectrogram.DTYPE), + layers.STFTSpectrogram( + frame_length=500, + frame_step=25, + fft_length=1024, + mode="stft", + data_format="channels_last", + ), + ] + ) + + def generator(): + yield (np.random.random((2, 16000, 1)),) + yield (np.random.random((3, 8000, 1)),) + + model.predict(generator()) + + @pytest.mark.requires_trainable_backend + def test_spectrogram_error(self): + rnd = np.random.RandomState(41) + x = rnd.uniform(low=-1, high=1, size=(4, 160000, 1)).astype(self.DTYPE) + names = [ + "scaling", + "window", + "periodic", + "frame_length", + "frame_step", + "fft_length", + ] + for args in [ + ("density", "hann", False, 512, 256, 1024), + ("spectrum", "blackman", True, 512, 32, 1024), + ("spectrum", "hamming", True, 256, 192, 512), + ("spectrum", "tukey", False, 512, 128, 512), + ("density", "hamming", True, 256, 256, 256), + ("density", "hann", True, 256, 128, 256), + ]: + init_args = dict(zip(names, args)) + + tol_kwargs = {"atol": 5e-4, "rtol": 1e-6} + + init_args["mode"] = "magnitude" + y_true, y = self._calc_spectrograms(x, **init_args) + self.assertEqual(np.shape(y_true), np.shape(y)) + self.assertAllClose(y_true, y, **tol_kwargs) + + init_args["mode"] = "psd" + y_true, y = self._calc_spectrograms(x, **init_args) + self.assertEqual(np.shape(y_true), np.shape(y)) + self.assertAllClose(y_true, y, **tol_kwargs) + + init_args["mode"] = "angle" + y_true, y = self._calc_spectrograms(x, **init_args) + + mask = np.isclose(y, y_true, **tol_kwargs) + mask |= np.isclose(y + 2 * np.pi, y_true, **tol_kwargs) + mask |= np.isclose(y - 2 * np.pi, y_true, **tol_kwargs) + mask |= np.isclose(np.cos(y), np.cos(y_true), **tol_kwargs) + mask |= np.isclose(np.sin(y), np.sin(y_true), **tol_kwargs) + + self.assertLess(np.mean(~mask), 2e-4) + + @pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="Requires TF tensors for TF-data module.", + ) + def test_tf_data_compatibility(self): + input_shape = (2, 16000, 1) + output_shape = (2, 16000 // 128, 358) + layer = layers.STFTSpectrogram( + frame_length=256, + frame_step=128, + fft_length=715, + padding="same", + scaling=None, + ) + input_data = np.random.random(input_shape) + ds = tf.data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) + for output in ds.take(1): + output = output.numpy() + self.assertEqual(tuple(output.shape), output_shape) + + def test_exceptions(self): + with self.assertRaises(ValueError): + layers.STFTSpectrogram( + frame_length=256, frame_step=1024, fft_length=512 + ) + with self.assertRaises(ValueError): + layers.STFTSpectrogram( + frame_length=256, frame_step=0, fft_length=512 + ) + with self.assertRaises(ValueError): + layers.STFTSpectrogram( + frame_length=256, frame_step=32, fft_length=128 + ) + with self.assertRaises(ValueError): + layers.STFTSpectrogram(padding="mypadding") + with self.assertRaises(ValueError): + layers.STFTSpectrogram(scaling="l2") + with self.assertRaises(ValueError): + layers.STFTSpectrogram(mode="spectrogram") + with self.assertRaises(ValueError): + layers.STFTSpectrogram(window="unknowable") + with self.assertRaises(ValueError): + layers.STFTSpectrogram(scaling="l2") + with self.assertRaises(ValueError): + layers.STFTSpectrogram(padding="divide") + with self.assertRaises(TypeError): + layers.STFTSpectrogram()( + np.random.randint(0, 255, size=(2, 16000, 1)) + ) diff --git a/keras/src/layers/preprocessing/string_lookup.py b/keras/src/layers/preprocessing/string_lookup.py new file mode 100644 index 000000000000..2b03e50987bc --- /dev/null +++ b/keras/src/layers/preprocessing/string_lookup.py @@ -0,0 +1,421 @@ +import numpy as np + +from keras.src import backend +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.index_lookup import IndexLookup +from keras.src.utils import backend_utils +from keras.src.utils.module_utils import tensorflow as tf + +if backend.backend() == "torch": + import torch + + +@keras_export("keras.layers.StringLookup") +class StringLookup(IndexLookup): + """A preprocessing layer that maps strings to (possibly encoded) indices. + + This layer translates a set of arbitrary strings into integer output via a + table-based vocabulary lookup. This layer will perform no splitting or + transformation of input strings. For a layer that can split and tokenize + natural language, see the `keras.layers.TextVectorization` layer. + + The vocabulary for the layer must be either supplied on construction or + learned via `adapt()`. During `adapt()`, the layer will analyze a data set, + determine the frequency of individual strings tokens, and create a + vocabulary from them. If the vocabulary is capped in size, the most frequent + tokens will be used to create the vocabulary and all others will be treated + as out-of-vocabulary (OOV). + + There are two possible output modes for the layer. When `output_mode` is + `"int"`, input strings are converted to their index in the vocabulary (an + integer). + When `output_mode` is `"multi_hot"`, `"count"`, or `"tf_idf"`, input strings + are encoded into an array where each dimension corresponds to an element in + the vocabulary. + + The vocabulary can optionally contain a mask token as well as an OOV token + (which can optionally occupy multiple indices in the vocabulary, as set + by `num_oov_indices`). + The position of these tokens in the vocabulary is fixed. When `output_mode` + is `"int"`, the vocabulary will begin with the mask token (if set), followed + by OOV indices, followed by the rest of the vocabulary. When `output_mode` + is `"multi_hot"`, `"count"`, or `"tf_idf"` the vocabulary will begin with + OOV indices and instances of the mask token will be dropped. + + **Note:** This layer uses TensorFlow internally. It cannot + be used as part of the compiled computation graph of a model with + any backend other than TensorFlow. + It can however be used with any backend when running eagerly. + It can also always be used as part of an input preprocessing pipeline + with any backend (outside the model itself), which is how we recommend + using this layer. + + **Note:** This layer is safe to use inside a `tf.data` pipeline + (independently of which backend you're using). + + Args: + max_tokens: Maximum size of the vocabulary for this layer. This should + only be specified when adapting the vocabulary or when setting + `pad_to_max_tokens=True`. If None, there is no cap on the size of + the vocabulary. Note that this size includes the OOV + and mask tokens. Defaults to `None`. + num_oov_indices: The number of out-of-vocabulary tokens to use. + If this value is more than 1, OOV inputs are modulated to + determine their OOV value. + If this value is 0, OOV inputs will cause an error when calling + the layer. Defaults to `1`. + mask_token: A token that represents masked inputs. When `output_mode` is + `"int"`, the token is included in the vocabulary and mapped to index + 0. + In other output modes, the token will not appear in the vocabulary + and instances of the mask token in the input will be dropped. + If set to `None`, no mask term will be added. Defaults to `None`. + oov_token: Only used when `invert` is True. The token to return for OOV + indices. Defaults to `"[UNK]"`. + vocabulary: Optional. Either an array of strings or a string path to a + text file. If passing an array, you can pass a tuple, list, 1D NumPy + array, or 1D tensor containing the string vocabulary terms. + If passing a file path, the file should contain one line per term in + the vocabulary. If this argument is set, there is no need to + `adapt()` the layer. + idf_weights: Only valid when `output_mode` is `"tf_idf"`. + A tuple, list, 1D NumPy array, or 1D tensor or the same length + as the vocabulary, containing the floating point inverse document + frequency weights, which will be multiplied by per sample term + counts for the final TF-IDF weight. + If the `vocabulary` argument is set and `output_mode` is `"tf_idf"`, + this argument must be supplied. + invert: Only valid when `output_mode` is `"int"`. + If `True`, this layer will map indices to vocabulary items + instead of mapping vocabulary items to indices. + Defaults to `False`. + output_mode: Specification for the output of the layer. Values can be + `"int"`, `"one_hot"`, `"multi_hot"`, `"count"`, or `"tf_idf"` + configuring the layer as follows: + - `"int"`: Return the vocabulary indices of the input tokens. + - `"one_hot"`: Encodes each individual element in the input into an + array the same size as the vocabulary, + containing a 1 at the element index. If the last dimension + is size 1, will encode on that dimension. + If the last dimension is not size 1, will append a new + dimension for the encoded output. + - `"multi_hot"`: Encodes each sample in the input into a single + array the same size as the vocabulary containing a 1 for each + vocabulary term present in the sample. + Treats the last dimension as the sample dimension, if the input + shape is `(..., sample_length)`, the output shape will be + `(..., num_tokens)`. + - `"count"`: As `"multi_hot"`, but the int array contains + a count of the number of times the token at that index + appeared in the sample. + - `"tf_idf"`: As `"multi_hot"`, but the TF-IDF algorithm is + applied to find the value in each token slot. + For `"int"` output, any shape of input and output is supported. + For all other output modes, currently only output up to rank 2 + is supported. Defaults to `"int"`. + pad_to_max_tokens: Only applicable when `output_mode` is `"multi_hot"`, + `"count"`, or `"tf_idf"`. If `True`, the output will have + its feature axis padded to `max_tokens` even if the number + of unique tokens in the vocabulary is less than `max_tokens`, + resulting in a tensor of shape `(batch_size, max_tokens)` + regardless of vocabulary size. Defaults to `False`. + sparse: Boolean. Only applicable to `"multi_hot"`, `"count"`, and + `"tf_idf"` output modes. Only supported with TensorFlow + backend. If `True`, returns a `SparseTensor` + instead of a dense `Tensor`. Defaults to `False`. + encoding: Optional. The text encoding to use to interpret the input + strings. Defaults to `"utf-8"`. + + Examples: + + **Creating a lookup layer with a known vocabulary** + + This example creates a lookup layer with a pre-existing vocabulary. + + >>> vocab = ["a", "b", "c", "d"] + >>> data = [["a", "c", "d"], ["d", "z", "b"]] + >>> layer = StringLookup(vocabulary=vocab) + >>> layer(data) + array([[1, 3, 4], + [4, 0, 2]]) + + **Creating a lookup layer with an adapted vocabulary** + + This example creates a lookup layer and generates the vocabulary by + analyzing the dataset. + + >>> data = [["a", "c", "d"], ["d", "z", "b"]] + >>> layer = StringLookup() + >>> layer.adapt(data) + >>> layer.get_vocabulary() + ['[UNK]', 'd', 'z', 'c', 'b', 'a'] + + Note that the OOV token `"[UNK]"` has been added to the vocabulary. + The remaining tokens are sorted by frequency + (`"d"`, which has 2 occurrences, is first) then by inverse sort order. + + >>> data = [["a", "c", "d"], ["d", "z", "b"]] + >>> layer = StringLookup() + >>> layer.adapt(data) + >>> layer(data) + array([[5, 3, 1], + [1, 2, 4]]) + + **Lookups with multiple OOV indices** + + This example demonstrates how to use a lookup layer with multiple OOV + indices. When a layer is created with more than one OOV index, any OOV + values are hashed into the number of OOV buckets, distributing OOV values in + a deterministic fashion across the set. + + >>> vocab = ["a", "b", "c", "d"] + >>> data = [["a", "c", "d"], ["m", "z", "b"]] + >>> layer = StringLookup(vocabulary=vocab, num_oov_indices=2) + >>> layer(data) + array([[2, 4, 5], + [0, 1, 3]]) + + Note that the output for OOV value 'm' is 0, while the output for OOV value + `"z"` is 1. The in-vocab terms have their output index increased by 1 from + earlier examples (a maps to 2, etc) in order to make space for the extra OOV + value. + + **One-hot output** + + Configure the layer with `output_mode='one_hot'`. Note that the first + `num_oov_indices` dimensions in the ont_hot encoding represent OOV values. + + >>> vocab = ["a", "b", "c", "d"] + >>> data = ["a", "b", "c", "d", "z"] + >>> layer = StringLookup(vocabulary=vocab, output_mode='one_hot') + >>> layer(data) + array([[0., 1., 0., 0., 0.], + [0., 0., 1., 0., 0.], + [0., 0., 0., 1., 0.], + [0., 0., 0., 0., 1.], + [1., 0., 0., 0., 0.]], dtype=int64) + + **Multi-hot output** + + Configure the layer with `output_mode='multi_hot'`. Note that the first + `num_oov_indices` dimensions in the multi_hot encoding represent OOV values. + + >>> vocab = ["a", "b", "c", "d"] + >>> data = [["a", "c", "d", "d"], ["d", "z", "b", "z"]] + >>> layer = StringLookup(vocabulary=vocab, output_mode='multi_hot') + >>> layer(data) + array([[0., 1., 0., 1., 1.], + [1., 0., 1., 0., 1.]], dtype=int64) + + **Token count output** + + Configure the layer with `output_mode='count'`. As with multi_hot output, + the first `num_oov_indices` dimensions in the output represent OOV values. + + >>> vocab = ["a", "b", "c", "d"] + >>> data = [["a", "c", "d", "d"], ["d", "z", "b", "z"]] + >>> layer = StringLookup(vocabulary=vocab, output_mode='count') + >>> layer(data) + array([[0., 1., 0., 1., 2.], + [2., 0., 1., 0., 1.]], dtype=int64) + + **TF-IDF output** + + Configure the layer with `output_mode="tf_idf"`. As with multi_hot output, + the first `num_oov_indices` dimensions in the output represent OOV values. + + Each token bin will output `token_count * idf_weight`, where the idf weights + are the inverse document frequency weights per token. These should be + provided along with the vocabulary. Note that the `idf_weight` for OOV + values will default to the average of all idf weights passed in. + + >>> vocab = ["a", "b", "c", "d"] + >>> idf_weights = [0.25, 0.75, 0.6, 0.4] + >>> data = [["a", "c", "d", "d"], ["d", "z", "b", "z"]] + >>> layer = StringLookup(output_mode="tf_idf") + >>> layer.set_vocabulary(vocab, idf_weights=idf_weights) + >>> layer(data) + array([[0. , 0.25, 0. , 0.6 , 0.8 ], + [1.0 , 0. , 0.75, 0. , 0.4 ]], dtype=float32) + + To specify the idf weights for OOV values, you will need to pass the entire + vocabulary including the leading OOV token. + + >>> vocab = ["[UNK]", "a", "b", "c", "d"] + >>> idf_weights = [0.9, 0.25, 0.75, 0.6, 0.4] + >>> data = [["a", "c", "d", "d"], ["d", "z", "b", "z"]] + >>> layer = StringLookup(output_mode="tf_idf") + >>> layer.set_vocabulary(vocab, idf_weights=idf_weights) + >>> layer(data) + array([[0. , 0.25, 0. , 0.6 , 0.8 ], + [1.8 , 0. , 0.75, 0. , 0.4 ]], dtype=float32) + + When adapting the layer in `"tf_idf"` mode, each input sample will be + considered a document, and IDF weight per token will be calculated as + `log(1 + num_documents / (1 + token_document_count))`. + + **Inverse lookup** + + This example demonstrates how to map indices to strings using this layer. + (You can also use `adapt()` with `inverse=True`, but for simplicity we'll + pass the vocab in this example.) + + >>> vocab = ["a", "b", "c", "d"] + >>> data = [[1, 3, 4], [4, 0, 2]] + >>> layer = StringLookup(vocabulary=vocab, invert=True) + >>> layer(data) + array([[b'a', b'c', b'd'], + [b'd', b'[UNK]', b'b']], dtype=object) + + Note that the first index corresponds to the OOV token by default. + + + **Forward and inverse lookup pairs** + + This example demonstrates how to use the vocabulary of a standard lookup + layer to create an inverse lookup layer. + + >>> vocab = ["a", "b", "c", "d"] + >>> data = [["a", "c", "d"], ["d", "z", "b"]] + >>> layer = StringLookup(vocabulary=vocab) + >>> i_layer = StringLookup(vocabulary=vocab, invert=True) + >>> int_data = layer(data) + >>> i_layer(int_data) + array([[b'a', b'c', b'd'], + [b'd', b'[UNK]', b'b']], dtype=object) + + In this example, the input value `"z"` resulted in an output of `"[UNK]"`, + since 1000 was not in the vocabulary - it got represented as an OOV, and all + OOV values are returned as `"[UNK]"` in the inverse layer. Also, note that + for the inverse to work, you must have already set the forward layer + vocabulary either directly or via `adapt()` before calling + `get_vocabulary()`. + """ + + def __init__( + self, + max_tokens=None, + num_oov_indices=1, + mask_token=None, + oov_token="[UNK]", + vocabulary=None, + idf_weights=None, + invert=False, + output_mode="int", + pad_to_max_tokens=False, + sparse=False, + encoding="utf-8", + name=None, + **kwargs, + ): + if not tf.available: + raise ImportError( + "Layer StringLookup requires TensorFlow. " + "Install it via `pip install tensorflow`." + ) + if sparse and backend.backend() != "tensorflow": + raise ValueError( + "`sparse=True` can only be used with the TensorFlow backend." + ) + self.encoding = encoding + super().__init__( + max_tokens=max_tokens, + num_oov_indices=num_oov_indices, + mask_token=mask_token, + oov_token=oov_token, + vocabulary=vocabulary, + idf_weights=idf_weights, + invert=invert, + output_mode=output_mode, + pad_to_max_tokens=pad_to_max_tokens, + sparse=sparse, + name=name, + vocabulary_dtype="string", + **kwargs, + ) + self._convert_input_args = False + self._allow_non_tensor_positional_args = True + self.supports_jit = False + + def adapt(self, data, steps=None): + """Computes a vocabulary of terms from tokens in a dataset. + + Calling `adapt()` on a `StringLookup` layer is an alternative to passing + in a precomputed vocabulary on construction via the `vocabulary` + argument. A `StringLookup` layer should always be either adapted over a + dataset or supplied with a vocabulary. + + During `adapt()`, the layer will build a vocabulary of all string tokens + seen in the dataset, sorted by occurrence count, with ties broken by + sort order of the tokens (high to low). At the end of `adapt()`, if + `max_tokens` is set, the vocabulary will be truncated to `max_tokens` + size. For example, adapting a layer with `max_tokens=1000` will compute + the 1000 most frequent tokens occurring in the input dataset. If + `output_mode='tf-idf'`, `adapt()` will also learn the document + frequencies of each token in the input dataset. + + Arguments: + data: The data to train on. It can be passed either as a + batched `tf.data.Dataset`, as a list of strings, + or as a NumPy array. + steps: Integer or `None`. + Total number of steps (batches of samples) to process. + If `data` is a `tf.data.Dataset`, and `steps` is `None`, + `adapt()` will run until the input dataset is exhausted. + When passing an infinitely + repeating dataset, you must specify the `steps` argument. This + argument is not supported with array inputs or list inputs. + """ + super().adapt(data, steps=steps) + + # Overridden methods from IndexLookup. + def _tensor_vocab_to_numpy(self, vocabulary): + vocabulary = vocabulary.numpy() + return np.array( + [tf.compat.as_text(x, self.encoding) for x in vocabulary] + ) + + def get_config(self): + config = {"encoding": self.encoding} + base_config = super().get_config() + # There is only one valid dtype for strings, so we don't expose this. + del base_config["vocabulary_dtype"] + return {**base_config, **config} + + def call(self, inputs): + is_torch_backend = backend.backend() == "torch" + + # Handle input conversion + inputs_for_processing = inputs + was_tf_input = isinstance( + inputs, (tf.Tensor, tf.RaggedTensor, tf.SparseTensor) + ) + + if is_torch_backend and isinstance(inputs, torch.Tensor): + inputs_for_processing = tf.convert_to_tensor( + inputs.detach().cpu().numpy() + ) + elif isinstance(inputs, (np.ndarray, list, tuple)): + inputs_for_processing = tf.convert_to_tensor(inputs) + elif not was_tf_input: + inputs_for_processing = tf.convert_to_tensor( + backend.convert_to_numpy(inputs) + ) + + output = super().call(inputs_for_processing) + + # Handle torch backend output conversion + if is_torch_backend and isinstance( + inputs, (torch.Tensor, np.ndarray, list, tuple) + ): + numpy_outputs = output.numpy() + if self.invert: + return [n.decode(self.encoding) for n in numpy_outputs] + else: + return torch.from_numpy(numpy_outputs) + + # other backends + if not was_tf_input: + output = backend_utils.convert_tf_tensor(output) + + return output diff --git a/keras/src/layers/preprocessing/string_lookup_test.py b/keras/src/layers/preprocessing/string_lookup_test.py new file mode 100644 index 000000000000..307591d9c8ad --- /dev/null +++ b/keras/src/layers/preprocessing/string_lookup_test.py @@ -0,0 +1,121 @@ +import numpy as np +import pytest +from tensorflow import data as tf_data + +from keras.src import backend +from keras.src import layers +from keras.src import testing +from keras.src.ops import convert_to_tensor + + +class StringLookupTest(testing.TestCase): + # TODO: increase coverage. Most features aren't being tested. + + def test_config(self): + layer = layers.StringLookup( + output_mode="int", + vocabulary=["a", "b", "c"], + oov_token="[OOV]", + mask_token="[MASK]", + ) + self.run_class_serialization_test(layer) + + def test_adapt_flow(self): + layer = layers.StringLookup( + output_mode="int", + ) + layer.adapt(["a", "a", "a", "b", "b", "c"]) + input_data = ["b", "c", "d"] + output = layer(input_data) + self.assertTrue(backend.is_tensor(output)) + self.assertAllClose(output, np.array([2, 3, 0])) + + def test_fixed_vocabulary(self): + layer = layers.StringLookup( + output_mode="int", + vocabulary=["a", "b", "c"], + ) + input_data = ["b", "c", "d"] + output = layer(input_data) + self.assertTrue(backend.is_tensor(output)) + self.assertAllClose(output, np.array([2, 3, 0])) + + @pytest.mark.skipif( + not backend.backend() == "tensorflow", reason="Requires tf.SparseTensor" + ) + def test_sparse_inputs(self): + import tensorflow as tf + + layer = layers.StringLookup( + output_mode="int", + vocabulary=["a", "b", "c"], + ) + input_data = tf.SparseTensor( + indices=[[0, 0], [1, 1], [2, 2]], + values=["b", "c", "d"], + dense_shape=(3, 3), + ) + output = layer(input_data) + self.assertIsInstance(output, tf.SparseTensor) + self.assertAllClose(output, np.array([[2, 0, 0], [0, 3, 0], [0, 0, 0]])) + self.assertAllClose(output.values, np.array([2, 3, 0])) + + def test_set_vocabulary(self): + layer = layers.StringLookup( + output_mode="int", + ) + layer.set_vocabulary(["a", "b", "c"]) + input_data = ["b", "c", "d"] + output = layer(input_data) + self.assertTrue(backend.is_tensor(output)) + self.assertAllClose(output, np.array([2, 3, 0])) + + def test_tf_data_compatibility(self): + layer = layers.StringLookup( + output_mode="int", + vocabulary=["a", "b", "c"], + ) + input_data = ["b", "c", "d"] + ds = tf_data.Dataset.from_tensor_slices(input_data).batch(3).map(layer) + output = next(iter(ds)).numpy() + self.assertAllClose(output, np.array([2, 3, 0])) + + @pytest.mark.skipif(not backend.backend() == "tensorflow", reason="tf only") + def test_tensor_as_vocab(self): + vocab = convert_to_tensor(["a", "b", "c", "d"]) + data = [["a", "c", "d"], ["d", "z", "b"]] + layer = layers.StringLookup( + vocabulary=vocab, + ) + output = layer(data) + self.assertAllClose(output, np.array([[1, 3, 4], [4, 0, 2]])) + + @pytest.mark.skipif(backend.backend() != "torch", reason="Only torch") + def test_torch_backend_compatibility(self): + import torch + + # Forward lookup: String -> number + forward_lookup = layers.StringLookup( + vocabulary=["a", "b", "c"], oov_token="[OOV]" + ) + input_data_str = ["a", "b", "[OOV]", "d"] + output_numeric = forward_lookup(input_data_str) + + # assert instance of output is torch.Tensor + self.assertIsInstance(output_numeric, torch.Tensor) + expected_numeric = torch.tensor([1, 2, 0, 0]) + self.assertAllClose(output_numeric.cpu(), expected_numeric) + + oov = "[OOV]" + # Inverse lookup: Number -> string + inverse_lookup = layers.StringLookup( + vocabulary=["a", "b", "c"], oov_token=oov, invert=True + ) + input_data_int = torch.tensor([1, 2, 0], dtype=torch.int64) + output_string = inverse_lookup(input_data_int) + # Assert that the output is a list + # See : https://docs.pytorch.org/text/stable/_modules/torchtext/vocab/vocab.html#Vocab.lookup_tokens + # The torch equivalent implementation of this returns a list of strings + self.assertIsInstance(output_string, list) + expected_string = ["a", "b", "[OOV]"] + self.assertEqual(output_string, expected_string) diff --git a/keras/src/layers/preprocessing/text_vectorization.py b/keras/src/layers/preprocessing/text_vectorization.py new file mode 100644 index 000000000000..bb04e023a496 --- /dev/null +++ b/keras/src/layers/preprocessing/text_vectorization.py @@ -0,0 +1,630 @@ +import numpy as np + +from keras.src import backend +from keras.src.api_export import keras_export +from keras.src.layers.layer import Layer +from keras.src.layers.preprocessing.index_lookup import listify_tensors +from keras.src.layers.preprocessing.string_lookup import StringLookup +from keras.src.saving import serialization_lib +from keras.src.utils import argument_validation +from keras.src.utils import backend_utils +from keras.src.utils import tf_utils +from keras.src.utils.module_utils import tensorflow as tf + + +@keras_export("keras.layers.TextVectorization") +class TextVectorization(Layer): + """A preprocessing layer which maps text features to integer sequences. + + This layer has basic options for managing text in a Keras model. It + transforms a batch of strings (one example = one string) into either a list + of token indices (one example = 1D tensor of integer token indices) or a + dense representation (one example = 1D tensor of float values representing + data about the example's tokens). This layer is meant to handle natural + language inputs. To handle simple string inputs (categorical strings or + pre-tokenized strings) see `kers_core.layers.StringLookup`. + + The vocabulary for the layer must be either supplied on construction or + learned via `adapt()`. When this layer is adapted, it will analyze the + dataset, determine the frequency of individual string values, and create a + vocabulary from them. This vocabulary can have unlimited size or be capped, + depending on the configuration options for this layer; if there are more + unique values in the input than the maximum vocabulary size, the most + frequent terms will be used to create the vocabulary. + + The processing of each example contains the following steps: + + 1. Standardize each example (usually lowercasing + punctuation stripping) + 2. Split each example into substrings (usually words) + 3. Recombine substrings into tokens (usually ngrams) + 4. Index tokens (associate a unique int value with each token) + 5. Transform each example using this index, either into a vector of ints or + a dense float vector. + + Some notes on passing callables to customize splitting and normalization for + this layer: + + 1. Any callable can be passed to this Layer, but if you want to serialize + this object you should only pass functions that are registered Keras + serializables (see `keras.saving.register_keras_serializable` + for more details). + 2. When using a custom callable for `standardize`, the data received + by the callable will be exactly as passed to this layer. The callable + should return a tensor of the same shape as the input. + 3. When using a custom callable for `split`, the data received by the + callable will have the 1st dimension squeezed out - instead of + `[["string to split"], ["another string to split"]]`, the Callable will + see `["string to split", "another string to split"]`. + The callable should return a `tf.Tensor` of dtype `string` + with the first dimension containing the split tokens - + in this example, we should see something like `[["string", "to", + "split"], ["another", "string", "to", "split"]]`. + + **Note:** This layer uses TensorFlow internally. It cannot + be used as part of the compiled computation graph of a model with + any backend other than TensorFlow. + It can however be used with any backend when running eagerly. + It can also always be used as part of an input preprocessing pipeline + with any backend (outside the model itself), which is how we recommend + to use this layer. + + **Note:** This layer is safe to use inside a `tf.data` pipeline + (independently of which backend you're using). + + Args: + max_tokens: Maximum size of the vocabulary for this layer. This should + only be specified when adapting a vocabulary or when setting + `pad_to_max_tokens=True`. Note that this vocabulary + contains 1 OOV token, so the effective number of tokens is + `(max_tokens - 1 - (1 if output_mode == "int" else 0))`. + standardize: Optional specification for standardization to apply to the + input text. Values can be: + - `None`: No standardization. + - `"lower_and_strip_punctuation"`: Text will be lowercased and all + punctuation removed. + - `"lower"`: Text will be lowercased. + - `"strip_punctuation"`: All punctuation will be removed. + - Callable: Inputs will passed to the callable function, + which should be standardized and returned. + split: Optional specification for splitting the input text. + Values can be: + - `None`: No splitting. + - `"whitespace"`: Split on whitespace. + - `"character"`: Split on each unicode character. + - Callable: Standardized inputs will passed to the callable + function, which should be split and returned. + ngrams: Optional specification for ngrams to create from the + possibly-split input text. Values can be `None`, an integer + or tuple of integers; passing an integer will create ngrams + up to that integer, and passing a tuple of integers will + create ngrams for the specified values in the tuple. + Passing `None` means that no ngrams will be created. + output_mode: Optional specification for the output of the layer. + Values can be `"int"`, `"multi_hot"`, `"count"` or `"tf_idf"`, + configuring the layer as follows: + - `"int"`: Outputs integer indices, one integer index per split + string token. When `output_mode == "int"`, + 0 is reserved for masked locations; + this reduces the vocab size to `max_tokens - 2` + instead of `max_tokens - 1`. + - `"multi_hot"`: Outputs a single int array per batch, of either + vocab_size or max_tokens size, containing 1s in all elements + where the token mapped to that index exists at least + once in the batch item. + - `"count"`: Like `"multi_hot"`, but the int array contains + a count of the number of times the token at that index + appeared in the batch item. + - `"tf_idf"`: Like `"multi_hot"`, but the TF-IDF algorithm + is applied to find the value in each token slot. + For `"int"` output, any shape of input and output is supported. + For all other output modes, currently only rank 1 inputs + (and rank 2 outputs after splitting) are supported. + output_sequence_length: Only valid in INT mode. If set, the output will + have its time dimension padded or truncated to exactly + `output_sequence_length` values, resulting in a tensor of shape + `(batch_size, output_sequence_length)` regardless of how many tokens + resulted from the splitting step. Defaults to `None`. If `ragged` + is `True` then `output_sequence_length` may still truncate the + output. + pad_to_max_tokens: Only valid in `"multi_hot"`, `"count"`, + and `"tf_idf"` modes. If `True`, the output will have + its feature axis padded to `max_tokens` even if the number + of unique tokens in the vocabulary is less than `max_tokens`, + resulting in a tensor of shape `(batch_size, max_tokens)` + regardless of vocabulary size. Defaults to `False`. + vocabulary: Optional. Either an array of strings or a string path to a + text file. If passing an array, can pass a tuple, list, + 1D NumPy array, or 1D tensor containing the string vocabulary terms. + If passing a file path, the file should contain one line per term + in the vocabulary. If this argument is set, + there is no need to `adapt()` the layer. + idf_weights: Only valid when `output_mode` is `"tf_idf"`. A tuple, list, + 1D NumPy array, or 1D tensor of the same length as the vocabulary, + containing the floating point inverse document frequency weights, + which will be multiplied by per sample term counts for + the final `tf_idf` weight. If the `vocabulary` argument is set, + and `output_mode` is `"tf_idf"`, this argument must be supplied. + ragged: Boolean. Only applicable to `"int"` output mode. + Only supported with TensorFlow backend. + If `True`, returns a `RaggedTensor` instead of a dense `Tensor`, + where each sequence may have a different length + after string splitting. Defaults to `False`. + sparse: Boolean. Only applicable to `"multi_hot"`, `"count"`, and + `"tf_idf"` output modes. Only supported with TensorFlow + backend. If `True`, returns a `SparseTensor` + instead of a dense `Tensor`. Defaults to `False`. + encoding: Optional. The text encoding to use to interpret the input + strings. Defaults to `"utf-8"`. + + Examples: + + This example instantiates a `TextVectorization` layer that lowercases text, + splits on whitespace, strips punctuation, and outputs integer vocab indices. + + >>> max_tokens = 5000 # Maximum vocab size. + >>> max_len = 4 # Sequence length to pad the outputs to. + >>> # Create the layer. + >>> vectorize_layer = TextVectorization( + ... max_tokens=max_tokens, + ... output_mode='int', + ... output_sequence_length=max_len) + + >>> # Now that the vocab layer has been created, call `adapt` on the + >>> # list of strings to create the vocabulary. + >>> vectorize_layer.adapt(["foo bar", "bar baz", "baz bada boom"]) + + >>> # Now, the layer can map strings to integers -- you can use an + >>> # embedding layer to map these integers to learned embeddings. + >>> input_data = [["foo qux bar"], ["qux baz"]] + >>> vectorize_layer(input_data) + array([[4, 1, 3, 0], + [1, 2, 0, 0]]) + + This example instantiates a `TextVectorization` layer by passing a list + of vocabulary terms to the layer's `__init__()` method. + + >>> vocab_data = ["earth", "wind", "and", "fire"] + >>> max_len = 4 # Sequence length to pad the outputs to. + >>> # Create the layer, passing the vocab directly. You can also pass the + >>> # vocabulary arg a path to a file containing one vocabulary word per + >>> # line. + >>> vectorize_layer = keras.layers.TextVectorization( + ... max_tokens=max_tokens, + ... output_mode='int', + ... output_sequence_length=max_len, + ... vocabulary=vocab_data) + + >>> # Because we've passed the vocabulary directly, we don't need to adapt + >>> # the layer - the vocabulary is already set. The vocabulary contains the + >>> # padding token ('') and OOV token ('[UNK]') + >>> # as well as the passed tokens. + >>> vectorize_layer.get_vocabulary() + ['', '[UNK]', 'earth', 'wind', 'and', 'fire'] + """ + + def __init__( + self, + max_tokens=None, + standardize="lower_and_strip_punctuation", + split="whitespace", + ngrams=None, + output_mode="int", + output_sequence_length=None, + pad_to_max_tokens=False, + vocabulary=None, + idf_weights=None, + sparse=False, + ragged=False, + encoding="utf-8", + name=None, + **kwargs, + ): + if not tf.available: + raise ImportError( + "Layer TextVectorization requires TensorFlow. " + "Install it via `pip install tensorflow`." + ) + if sparse and backend.backend() != "tensorflow": + raise ValueError( + "`sparse=True` can only be used with the TensorFlow backend." + ) + if ragged and backend.backend() != "tensorflow": + raise ValueError( + "`ragged=True` can only be used with the TensorFlow backend." + ) + + # 'standardize' must be one of + # (None, "lower_and_strip_punctuation", "lower", "strip_punctuation", + # callable) + argument_validation.validate_string_arg( + standardize, + allowable_strings=( + "lower_and_strip_punctuation", + "lower", + "strip_punctuation", + ), + caller_name=self.__class__.__name__, + arg_name="standardize", + allow_none=True, + allow_callables=True, + ) + + # 'split' must be one of (None, "whitespace", "character", callable) + argument_validation.validate_string_arg( + split, + allowable_strings=("whitespace", "character"), + caller_name=self.__class__.__name__, + arg_name="split", + allow_none=True, + allow_callables=True, + ) + + # Support deprecated names for output_modes. + if output_mode == "binary": + output_mode = "multi_hot" + if output_mode == "tf-idf": + output_mode = "tf_idf" + argument_validation.validate_string_arg( + output_mode, + allowable_strings=( + "int", + "one_hot", + "multi_hot", + "count", + "tf_idf", + ), + caller_name=self.__class__.__name__, + arg_name="output_mode", + ) + + # 'ngrams' must be one of (None, int, tuple(int)) + if not ( + ngrams is None + or isinstance(ngrams, int) + or isinstance(ngrams, tuple) + and all(isinstance(item, int) for item in ngrams) + ): + raise ValueError( + "`ngrams` must be None, an integer, or a tuple of " + f"integers. Received: ngrams={ngrams}" + ) + + # 'output_sequence_length' must be one of (None, int) and is only + # set if output_mode is "int"". + if output_mode == "int" and not ( + isinstance(output_sequence_length, int) + or (output_sequence_length is None) + ): + raise ValueError( + "`output_sequence_length` must be either None or an " + "integer when `output_mode` is 'int'. Received: " + f"output_sequence_length={output_sequence_length}" + ) + + if output_mode != "int" and output_sequence_length is not None: + raise ValueError( + "`output_sequence_length` must not be set if `output_mode` is " + "not 'int'. " + f"Received output_sequence_length={output_sequence_length}." + ) + + if ragged and output_mode != "int": + raise ValueError( + "`ragged` must not be true if `output_mode` is " + f"`'int'`. Received: ragged={ragged} and " + f"output_mode={output_mode}" + ) + + self._max_tokens = max_tokens + self._standardize = standardize + self._split = split + self._ngrams_arg = ngrams + if isinstance(ngrams, int): + self._ngrams = tuple(range(1, ngrams + 1)) + else: + self._ngrams = ngrams + self._ragged = ragged + + self._output_mode = output_mode + self._output_sequence_length = output_sequence_length + self._encoding = encoding + + # We save this hidden option to persist the fact + # that we have a non-adaptable layer with a + # manually set vocab. + self._has_input_vocabulary = kwargs.pop( + "has_input_vocabulary", (vocabulary is not None) + ) + vocabulary_size = kwargs.pop("vocabulary_size", None) + + super().__init__(name=name, **kwargs) + + self._lookup_layer = StringLookup( + max_tokens=max_tokens, + vocabulary=vocabulary, + idf_weights=idf_weights, + pad_to_max_tokens=pad_to_max_tokens, + mask_token="", + output_mode=output_mode, + sparse=sparse, + has_input_vocabulary=self._has_input_vocabulary, + encoding=encoding, + vocabulary_size=vocabulary_size, + ) + self._convert_input_args = False + self._allow_non_tensor_positional_args = True + self.supports_jit = False + + @property + def compute_dtype(self): + return "string" + + @property + def variable_dtype(self): + return "string" + + def build(self, input_shape=None): + pass + + def compute_output_shape(self, input_shape): + if self._output_mode == "int": + return (input_shape[0], self._output_sequence_length) + if self._split is None: + if len(input_shape) <= 1: + input_shape = tuple(input_shape) + (1,) + else: + input_shape = tuple(input_shape) + (None,) + return self._lookup_layer.compute_output_shape(input_shape) + + def compute_output_spec(self, inputs): + output_shape = self.compute_output_shape(inputs.shape) + if self._output_mode == "int": + output_dtype = "int64" + else: + output_dtype = backend.floatx() + return backend.KerasTensor(output_shape, dtype=output_dtype) + + def adapt(self, data, batch_size=None, steps=None): + """Computes a vocabulary of string terms from tokens in a dataset. + + Calling `adapt()` on a `TextVectorization` layer is an alternative to + passing in a precomputed vocabulary on construction via the `vocabulary` + argument. A `TextVectorization` layer should always be either adapted + over a dataset or supplied with a vocabulary. + + During `adapt()`, the layer will build a vocabulary of all string tokens + seen in the dataset, sorted by occurrence count, with ties broken by + sort order of the tokens (high to low). At the end of `adapt()`, if + `max_tokens` is set, the vocabulary will be truncated to `max_tokens` + size. For example, adapting a layer with `max_tokens=1000` will compute + the 1000 most frequent tokens occurring in the input dataset. If + `output_mode='tf-idf'`, `adapt()` will also learn the document + frequencies of each token in the input dataset. + + Arguments: + data: The data to train on. It can be passed either as a + batched `tf.data.Dataset`, as a list of strings, + or as a NumPy array. + steps: Integer or `None`. + Total number of steps (batches of samples) to process. + If `data` is a `tf.data.Dataset`, and `steps` is `None`, + `adapt()` will run until the input dataset is exhausted. + When passing an infinitely + repeating dataset, you must specify the `steps` argument. This + argument is not supported with array inputs or list inputs. + """ + self.reset_state() + if isinstance(data, tf.data.Dataset): + if steps is not None: + data = data.take(steps) + for batch in data: + self.update_state(batch) + else: + data = tf_utils.ensure_tensor(data, dtype="string") + if data.shape.rank == 1: + # A plain list of strings + # is treated as as many documents + data = tf.expand_dims(data, -1) + self.update_state(data) + self.finalize_state() + + def update_state(self, data): + self._lookup_layer.update_state(self._preprocess(data)) + + def finalize_state(self): + self._lookup_layer.finalize_state() + + def reset_state(self): + self._lookup_layer.reset_state() + + def get_vocabulary(self, include_special_tokens=True): + """Returns the current vocabulary of the layer. + + Args: + include_special_tokens: If `True`, the returned vocabulary + will include the padding and OOV tokens, + and a term's index in the vocabulary will equal + the term's index when calling the layer. If `False`, the + returned vocabulary will not include any padding + or OOV tokens. + """ + return self._lookup_layer.get_vocabulary(include_special_tokens) + + def vocabulary_size(self): + """Gets the current size of the layer's vocabulary. + + Returns: + The integer size of the vocabulary, including optional + mask and OOV indices. + """ + return self._lookup_layer.vocabulary_size() + + def get_config(self): + config = { + "max_tokens": self._lookup_layer.max_tokens, + "standardize": self._standardize, + "split": self._split, + "ngrams": self._ngrams_arg, + "output_mode": self._output_mode, + "output_sequence_length": self._output_sequence_length, + "pad_to_max_tokens": self._lookup_layer.pad_to_max_tokens, + "sparse": self._lookup_layer.sparse, + "ragged": self._ragged, + "vocabulary": listify_tensors(self._lookup_layer.input_vocabulary), + "idf_weights": listify_tensors( + self._lookup_layer.input_idf_weights + ), + "encoding": self._encoding, + "vocabulary_size": self.vocabulary_size(), + } + base_config = super().get_config() + return {**base_config, **config} + + @classmethod + def from_config(cls, config): + if not isinstance(config["standardize"], str): + config["standardize"] = serialization_lib.deserialize_keras_object( + config["standardize"] + ) + if not isinstance(config["split"], str): + config["split"] = serialization_lib.deserialize_keras_object( + config["split"] + ) + + if isinstance(config["ngrams"], list): + config["ngrams"] = tuple(config["ngrams"]) + + return cls(**config) + + def set_vocabulary(self, vocabulary, idf_weights=None): + """Sets vocabulary (and optionally document frequency) for this layer. + + This method sets the vocabulary and IDF weights for this layer directly, + instead of analyzing a dataset through `adapt()`. It should be used + whenever the vocab (and optionally document frequency) information is + already known. If vocabulary data is already present in the layer, this + method will replace it. + + Args: + vocabulary: Either an array or a string path to a text file. + If passing an array, can pass a tuple, list, 1D NumPy array, + or 1D tensor containing the vocabulary terms. + If passing a file path, the file should contain one line + per term in the vocabulary. + idf_weights: A tuple, list, 1D NumPy array, or 1D tensor of inverse + document frequency weights with equal length to vocabulary. + Must be set if `output_mode` is `"tf_idf"`. + Should not be set otherwise. + """ + self._lookup_layer.set_vocabulary(vocabulary, idf_weights=idf_weights) + + def _preprocess(self, inputs): + inputs = tf_utils.ensure_tensor(inputs, dtype=tf.string) + if self._standardize in ("lower", "lower_and_strip_punctuation"): + inputs = tf.strings.lower(inputs) + if self._standardize in ( + "strip_punctuation", + "lower_and_strip_punctuation", + ): + inputs = tf.strings.regex_replace( + inputs, r'[!"#$%&()\*\+,-\./:;<=>?@\[\\\]^_`{|}~\']', "" + ) + if callable(self._standardize): + inputs = self._standardize(inputs) + + if self._split is not None: + # If we are splitting, we validate that the 1st axis is of dimension + # 1 and so can be squeezed out. We do this here instead of after + # splitting for performance reasons - it's more expensive to squeeze + # a ragged tensor. + if inputs.shape.rank > 1: + if inputs.shape[-1] != 1: + raise ValueError( + "When using `TextVectorization` to tokenize strings, " + "the input rank must be 1 or the last shape dimension " + f"must be 1. Received: inputs.shape={inputs.shape} " + f"with rank={inputs.shape.rank}" + ) + else: + inputs = tf.squeeze(inputs, axis=-1) + if self._split == "whitespace": + # This treats multiple whitespaces as one whitespace, and strips + # leading and trailing whitespace. + inputs = tf.strings.split(inputs) + elif self._split == "character": + inputs = tf.strings.unicode_split(inputs, "UTF-8") + elif callable(self._split): + inputs = self._split(inputs) + + # Note that 'inputs' here can be either ragged or dense depending on the + # configuration choices for this Layer. The strings.ngrams op, however, + # does support both ragged and dense inputs. + if self._ngrams is not None: + inputs = tf.strings.ngrams( + inputs, ngram_width=self._ngrams, separator=" " + ) + return inputs + + def call(self, inputs): + if not isinstance( + inputs, (tf.Tensor, tf.RaggedTensor, np.ndarray, list, tuple) + ): + inputs = tf.convert_to_tensor(backend.convert_to_numpy(inputs)) + + inputs = self._preprocess(inputs) + + # If we're not doing any output processing, return right away. + if self._output_mode is None: + outputs = inputs + + lookup_data = self._lookup_layer.call(inputs) + + # For non-int output, we can return directly from the underlying layer. + if self._output_mode != "int": + return backend_utils.convert_tf_tensor(lookup_data) + + # If we have a ragged tensor, we can pad during the conversion to dense. + if isinstance(lookup_data, tf.RaggedTensor) and not self._ragged: + shape = lookup_data.shape.as_list() + # If output sequence length is None, to_tensor will pad the last + # dimension to the bounding shape of the ragged dimension. + shape[-1] = self._output_sequence_length + outputs = lookup_data.to_tensor(default_value=0, shape=shape) + # If we have a dense tensor, we need to pad/trim directly. + elif self._output_sequence_length is not None: + # Maybe trim the output. + outputs = lookup_data[..., : self._output_sequence_length] + + # Maybe pad the output. We need to be careful to use dynamic shape + # here as required_space_to_batch_paddings requires a fully known + # shape. + if not self._ragged: + shape = tf.shape(outputs) + padded_shape = tf.concat( + (shape[:-1], [self._output_sequence_length]), 0 + ) + padding, _ = tf.required_space_to_batch_paddings( + shape, padded_shape + ) + outputs = tf.pad(outputs, padding) + # Because `tf.pad` used a dynamic shape, the output shape is + # dynamic. Apply the known static `_output_sequence_length`. + static_padded_shape = lookup_data.shape.as_list() + static_padded_shape[-1] = self._output_sequence_length + outputs.set_shape(static_padded_shape) + else: + outputs = lookup_data + + return backend_utils.convert_tf_tensor(outputs) + + def save_own_variables(self, store): + self._lookup_layer.save_own_variables(store) + + def load_own_variables(self, store): + self._lookup_layer.load_own_variables(store) + + def save_assets(self, dir_path): + self._lookup_layer.save_assets(dir_path) + + def load_assets(self, dir_path): + self._lookup_layer.load_assets(dir_path) diff --git a/keras/src/layers/preprocessing/text_vectorization_test.py b/keras/src/layers/preprocessing/text_vectorization_test.py new file mode 100644 index 000000000000..341b4b5b7f10 --- /dev/null +++ b/keras/src/layers/preprocessing/text_vectorization_test.py @@ -0,0 +1,316 @@ +import os + +import numpy as np +import pytest +import tensorflow as tf +from absl.testing import parameterized +from tensorflow import data as tf_data + +from keras.src import Sequential +from keras.src import backend +from keras.src import layers +from keras.src import models +from keras.src import saving +from keras.src import testing + + +class TextVectorizationTest(testing.TestCase, parameterized.TestCase): + # TODO: increase coverage. Most features aren't being tested. + + def test_config(self): + layer = layers.TextVectorization( + output_mode="int", + vocabulary=["one", "two"], + output_sequence_length=5, + ) + self.run_class_serialization_test(layer) + + def test_adapt_flow(self): + max_tokens = 5000 + max_len = 4 + layer = layers.TextVectorization( + max_tokens=max_tokens, + output_mode="int", + output_sequence_length=max_len, + ) + layer.adapt(["foo bar", "bar baz", "baz bada boom"]) + input_data = [["foo qux bar"], ["qux baz"]] + output = layer(input_data) + self.assertTrue(backend.is_tensor(output)) + self.assertAllClose(output, np.array([[4, 1, 3, 0], [1, 2, 0, 0]])) + + def test_fixed_vocabulary(self): + max_tokens = 5000 + max_len = 4 + layer = layers.TextVectorization( + max_tokens=max_tokens, + output_mode="int", + output_sequence_length=max_len, + vocabulary=["baz", "bar", "foo"], + ) + input_data = [["foo qux bar"], ["qux baz"]] + output = layer(input_data) + self.assertTrue(backend.is_tensor(output)) + self.assertAllClose(output, np.array([[4, 1, 3, 0], [1, 2, 0, 0]])) + + def test_set_vocabulary(self): + max_tokens = 5000 + max_len = 4 + layer = layers.TextVectorization( + max_tokens=max_tokens, + output_mode="int", + output_sequence_length=max_len, + ) + layer.set_vocabulary(["baz", "bar", "foo"]) + input_data = [["foo qux bar"], ["qux baz"]] + output = layer(input_data) + self.assertTrue(backend.is_tensor(output)) + self.assertAllClose(output, np.array([[4, 1, 3, 0], [1, 2, 0, 0]])) + + @pytest.mark.skipif( + backend.backend() != "tensorflow", reason="Requires string input dtype" + ) + def test_save_load_with_ngrams_flow(self): + input_data = np.array(["foo bar", "bar baz", "baz bada boom"]) + model = Sequential( + [ + layers.Input(dtype="string", shape=(1,)), + layers.TextVectorization(ngrams=(1, 2)), + ] + ) + model.layers[0].adapt(input_data) + output = model(input_data) + temp_filepath = os.path.join(self.get_temp_dir(), "model.keras") + model.save(temp_filepath) + model = saving.load_model(temp_filepath) + self.assertAllClose(output, model(input_data)) + + def test_tf_data_compatibility(self): + max_tokens = 5000 + max_len = 4 + layer = layers.TextVectorization( + max_tokens=max_tokens, + output_mode="int", + output_sequence_length=max_len, + vocabulary=["baz", "bar", "foo"], + ) + input_data = [["foo qux bar"], ["qux baz"]] + ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) + output = next(iter(ds)).numpy() + self.assertAllClose(output, np.array([[4, 1, 3, 0], [1, 2, 0, 0]])) + + # Test adapt flow + layer = layers.TextVectorization( + max_tokens=max_tokens, + output_mode="int", + output_sequence_length=max_len, + ) + layer.adapt(input_data) + ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) + next(iter(ds)).numpy() + + @parameterized.named_parameters( + [ + ("from_ragged", "whitespace"), # intermediate tensor is ragged + ("from_dense", None), # intermediate tensor is dense + ] + ) + def test_static_output_sequence_length(self, split): + max_tokens = 5000 + max_len = 4 + layer = layers.TextVectorization( + max_tokens=max_tokens, + output_mode="int", + output_sequence_length=max_len, + split=split, + vocabulary=["baz", "bar", "foo"], + ) + if split: + input_data = [["foo qux bar"], ["qux baz"]] + else: + input_data = [["foo"], ["baz"]] + + def call_layer(x): + result = layer(x) + self.assertEqual(result.shape, (None, 4)) + return result + + ds = ( + tf_data.Dataset.from_tensor_slices(input_data) + .batch(2) + .map(call_layer) + ) + next(iter(ds)) + + @pytest.mark.skipif( + backend.backend() != "tensorflow", reason="Requires string tensors." + ) + def test_tf_as_first_sequential_layer(self): + layer = layers.TextVectorization( + max_tokens=10, + output_mode="int", + output_sequence_length=3, + ) + layer.set_vocabulary(["baz", "bar", "foo"]) + model = models.Sequential( + [ + layer, + layers.Embedding(5, 4), + ] + ) + model(backend.convert_to_tensor([["foo qux bar"], ["qux baz"]])) + + @pytest.mark.skipif( + backend.backend() != "tensorflow", reason="Requires ragged tensors." + ) + def test_ragged_tensor(self): + layer = layers.TextVectorization( + output_mode="int", + vocabulary=["baz", "bar", "foo"], + ragged=True, + ) + input_data = [["foo qux bar"], ["qux baz"], ["foo"]] + output = layer(input_data) + self.assertIsInstance(output, tf.RaggedTensor) + self.assertEqual(output.shape, (3, None)) + self.assertEqual(output.to_list(), [[4, 1, 3], [1, 2], [4]]) + + @pytest.mark.skipif( + backend.backend() != "tensorflow", reason="Requires ragged tensors." + ) + def test_ragged_tensor_output_length(self): + layer = layers.TextVectorization( + output_mode="int", + vocabulary=["baz", "bar", "foo"], + ragged=True, + output_sequence_length=2, + ) + input_data = [["foo qux bar"], ["qux baz"], ["foo"]] + output = layer(input_data) + self.assertIsInstance(output, tf.RaggedTensor) + self.assertEqual(output.shape, (3, None)) + self.assertEqual(output.to_list(), [[4, 1], [1, 2], [4]]) + + @pytest.mark.skipif( + backend.backend() == "tensorflow", + reason="Verify raises exception for non-TF backends", + ) + def test_raises_exception_ragged_tensor(self): + with self.assertRaises(ValueError): + _ = layers.TextVectorization( + output_mode="int", + vocabulary=["baz", "bar", "foo"], + ragged=True, + ) + + def test_multi_hot_output(self): + layer = layers.TextVectorization( + output_mode="multi_hot", vocabulary=["foo", "bar", "baz"] + ) + input_data = [["foo bar"], ["baz foo foo"]] + output = layer(input_data) + + """ + First batch + Tokens present: ["foo", "bar"] + For each token in vocabulary: + foo (index 1): present -> 1 + bar (index 2): present -> 1 + baz (index 3): absent -> 0 + Result: [0, 1, 1, 0] + + Second batch + Tokens: ["baz", "foo", "foo"] + For each token in vocabulary: + foo (index 1): present -> 1 + bar (index 2): absent -> 0 + baz (index 3): present -> 1 + Result: [0, 1, 0, 1] + """ + self.assertAllClose(output, [[0, 1, 1, 0], [0, 1, 0, 1]]) + + def test_output_mode_count_output(self): + layer = layers.TextVectorization( + output_mode="count", vocabulary=["foo", "bar", "baz"] + ) + output = layer(["foo bar", "baz foo foo"]) + self.assertAllClose(output, [[0, 1, 1, 0], [0, 2, 0, 1]]) + + def test_output_mode_tf_idf_output(self): + layer = layers.TextVectorization( + output_mode="tf_idf", + vocabulary=["foo", "bar", "baz"], + idf_weights=[0.3, 0.5, 0.2], + ) + output = layer(["foo bar", "baz foo foo"]) + self.assertAllClose( + output, [[0.0, 0.3, 0.5, 0.0], [0.0, 0.6, 0.0, 0.2]] + ) + + def test_lower_and_strip_punctuation_standardization(self): + layer = layers.TextVectorization( + standardize="lower_and_strip_punctuation", + vocabulary=["hello", "world", "this", "is", "nice", "test"], + ) + output = layer(["Hello, World!. This is just a nice test!"]) + self.assertTrue(backend.is_tensor(output)) + + # test output sequence length, taking first batch. + self.assertEqual(len(output[0]), 8) + + self.assertAllEqual(output, [[2, 3, 4, 5, 1, 1, 6, 7]]) + + def test_lower_standardization(self): + layer = layers.TextVectorization( + standardize="lower", + vocabulary=[ + "hello,", + "hello", + "world", + "this", + "is", + "nice", + "test", + ], + ) + output = layer(["Hello, World!. This is just a nice test!"]) + self.assertTrue(backend.is_tensor(output)) + self.assertEqual(len(output[0]), 8) + """ + The input is lowercased and tokenized into words. The vocab is: + {0: '', + 1: '[UNK]', + 2: 'hello,', + 3: 'hello', + 4: 'world', + 5: 'this', + 6: 'is', + 7: 'nice', + 8: 'test'} + """ + self.assertAllEqual(output, [[2, 1, 5, 6, 1, 1, 7, 1]]) + + def test_char_splitting(self): + layer = layers.TextVectorization( + split="character", vocabulary=list("abcde"), output_mode="int" + ) + output = layer(["abcf"]) + self.assertTrue(backend.is_tensor(output)) + self.assertEqual(len(output[0]), 4) + self.assertAllEqual(output, [[2, 3, 4, 1]]) + + def test_custom_splitting(self): + def custom_split(text): + return tf.strings.split(text, sep="|") + + layer = layers.TextVectorization( + split=custom_split, + vocabulary=["foo", "bar", "foobar"], + output_mode="int", + ) + output = layer(["foo|bar"]) + self.assertTrue(backend.is_tensor(output)) + + # after custom split, the outputted index should be the last + # token in the vocab. + self.assertAllEqual(output, [[4]]) diff --git a/keras/src/layers/regularization/__init__.py b/keras/src/layers/regularization/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/keras/src/layers/regularization/activity_regularization.py b/keras/src/layers/regularization/activity_regularization.py new file mode 100644 index 000000000000..a9d663c6d46f --- /dev/null +++ b/keras/src/layers/regularization/activity_regularization.py @@ -0,0 +1,43 @@ +from keras.src import regularizers +from keras.src.api_export import keras_export +from keras.src.layers.layer import Layer + + +@keras_export("keras.layers.ActivityRegularization") +class ActivityRegularization(Layer): + """Layer that applies an update to the cost function based input activity. + + Args: + l1: L1 regularization factor (positive float). + l2: L2 regularization factor (positive float). + + Input shape: + Arbitrary. Use the keyword argument `input_shape` + (tuple of integers, does not include the samples axis) + when using this layer as the first layer in a model. + + Output shape: + Same shape as input. + """ + + def __init__(self, l1=0.0, l2=0.0, **kwargs): + super().__init__( + activity_regularizer=regularizers.L1L2(l1=l1, l2=l2), **kwargs + ) + self.supports_masking = True + self.l1 = l1 + self.l2 = l2 + + self._build_at_init() + + def call(self, inputs): + return inputs + + def compute_output_shape(self, input_shape): + return input_shape + + def get_config(self): + base_config = super().get_config() + base_config.pop("activity_regularizer", None) + config = {"l1": self.l1, "l2": self.l2} + return {**base_config, **config} diff --git a/keras/src/layers/regularization/activity_regularization_test.py b/keras/src/layers/regularization/activity_regularization_test.py new file mode 100644 index 000000000000..f0b5c92fede1 --- /dev/null +++ b/keras/src/layers/regularization/activity_regularization_test.py @@ -0,0 +1,29 @@ +import numpy as np +import pytest + +from keras.src import layers +from keras.src.testing import test_case + + +class ActivityRegularizationTest(test_case.TestCase): + def test_correctness(self): + layer = layers.ActivityRegularization(l1=0.2, l2=0.3) + layer(2 * np.ones((1,))) + self.assertLen(layer.losses, 1) + self.assertAllClose(layer.losses[0], 4 * 0.3 + 2 * 0.2) + + @pytest.mark.requires_trainable_backend + def test_activity_regularization_basics(self): + self.run_layer_test( + layers.ActivityRegularization, + {"l1": 0.1, "l2": 0.2}, + input_shape=(2, 3), + input_dtype="float32", + expected_output_shape=(2, 3), + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=1, + supports_masking=True, + assert_built_after_instantiation=True, + ) diff --git a/keras/src/layers/regularization/alpha_dropout.py b/keras/src/layers/regularization/alpha_dropout.py new file mode 100644 index 000000000000..ebfd68e15917 --- /dev/null +++ b/keras/src/layers/regularization/alpha_dropout.py @@ -0,0 +1,99 @@ +from keras.src import backend +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.layers.layer import Layer + + +@keras_export("keras.layers.AlphaDropout") +class AlphaDropout(Layer): + """Applies Alpha Dropout to the input. + + Alpha Dropout is a `Dropout` that keeps mean and variance of inputs + to their original values, in order to ensure the self-normalizing property + even after this dropout. + Alpha Dropout fits well to Scaled Exponential Linear Units (SELU) by + randomly setting activations to the negative saturation value. + + Args: + rate: Float between 0 and 1. The multiplicative noise will have + standard deviation `sqrt(rate / (1 - rate))`. + noise_shape: 1D integer tensor representing the shape of the + binary alpha dropout mask that will be multiplied with the input. + For instance, if your inputs have shape + `(batch_size, timesteps, features)` and + you want the alpha dropout mask to be the same for all timesteps, + you can use `noise_shape=(batch_size, 1, features)`. + seed: A Python integer to use as random seed. + + Call arguments: + inputs: Input tensor (of any rank). + training: Python boolean indicating whether the layer should behave in + training mode (adding alpha dropout) or in inference mode + (doing nothing). + """ + + def __init__(self, rate, noise_shape=None, seed=None, **kwargs): + super().__init__(**kwargs) + if not 0 <= rate <= 1: + raise ValueError( + f"Invalid value received for argument " + "`rate`. Expected a float value between 0 and 1. " + f"Received: rate={rate}" + ) + self.rate = rate + self.seed = seed + self.noise_shape = noise_shape + if rate > 0: + self.seed_generator = backend.random.SeedGenerator(seed) + self.supports_masking = True + + self._build_at_init() + + def call(self, inputs, training=False): + if training and self.rate > 0: + noise_shape = self._get_concrete_noise_shape( + inputs, self.noise_shape + ) + alpha = 1.6732632423543772848170429916717 + scale = 1.0507009873554804934193349852946 + alpha_p = -alpha * scale + + kept_idx = ops.greater_equal( + ops.random.uniform(noise_shape, seed=self.seed_generator), + self.rate, + ) + kept_idx = ops.cast(kept_idx, inputs.dtype) + + # Compute affine transformation parameters + a = ((1 - self.rate) * (1 + self.rate * alpha_p**2)) ** -0.5 + b = -a * alpha_p * self.rate + + # Apply mask + x = inputs * kept_idx + alpha_p * (1 - kept_idx) + return a * x + b + + return inputs + + def compute_output_shape(self, input_shape): + return input_shape + + def _get_concrete_noise_shape(self, inputs, noise_shape): + if noise_shape is None: + return ops.shape(inputs) + + concrete_inputs_shape = ops.shape(inputs) + concrete_noise_shape = [] + for i, value in enumerate(noise_shape): + concrete_noise_shape.append( + concrete_inputs_shape[i] if value is None else value + ) + return concrete_noise_shape + + def get_config(self): + base_config = super().get_config() + config = { + "rate": self.rate, + "seed": self.seed, + "noise_shape": self.noise_shape, + } + return {**base_config, **config} diff --git a/keras/src/layers/regularization/alpha_dropout_test.py b/keras/src/layers/regularization/alpha_dropout_test.py new file mode 100644 index 000000000000..2f0408eb6cfc --- /dev/null +++ b/keras/src/layers/regularization/alpha_dropout_test.py @@ -0,0 +1,62 @@ +import numpy as np +import pytest + +from keras.src import backend +from keras.src import layers +from keras.src import testing + + +class AlphaDropoutTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_alpha_dropout_basics(self): + self.run_layer_test( + layers.AlphaDropout, + init_kwargs={ + "rate": 0.2, + }, + input_shape=(2, 3), + call_kwargs={"training": True}, + expected_output_shape=(2, 3), + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=1, + expected_num_losses=0, + supports_masking=True, + assert_built_after_instantiation=True, + ) + + def test_alpha_dropout_correctness(self): + inputs = np.ones((20, 500)).astype("float32") + layer = layers.AlphaDropout(0.3, seed=1337) + outputs = layer(inputs, training=True) + self.assertAllClose( + np.std(backend.convert_to_numpy(outputs)), 1.0, atol=1e-1 + ) + + def test_alpha_dropout_partial_noise_shape_dynamic(self): + inputs = np.ones((20, 5, 10)) + layer = layers.AlphaDropout(0.5, noise_shape=(None, 1, None)) + outputs = layer(inputs, training=True) + self.assertAllClose(outputs[:, 0, :], outputs[:, 1, :]) + + def test_alpha_dropout_partial_noise_shape_static(self): + inputs = np.ones((20, 5, 10)) + layer = layers.AlphaDropout(0.5, noise_shape=(20, 1, 10)) + outputs = layer(inputs, training=True) + self.assertAllClose(outputs[:, 0, :], outputs[:, 1, :]) + + def test_alpha_dropout_negative_rate(self): + with self.assertRaisesRegex( + ValueError, + "Invalid value received for argument `rate`. " + "Expected a float value between 0 and 1.", + ): + _ = layers.AlphaDropout(rate=-0.5) + + def test_alpha_dropout_rate_greater_than_one(self): + with self.assertRaisesRegex( + ValueError, + "Invalid value received for argument `rate`. " + "Expected a float value between 0 and 1.", + ): + _ = layers.AlphaDropout(rate=1.5) diff --git a/keras/src/layers/regularization/dropout.py b/keras/src/layers/regularization/dropout.py new file mode 100644 index 000000000000..0041e65c152c --- /dev/null +++ b/keras/src/layers/regularization/dropout.py @@ -0,0 +1,78 @@ +from keras.src import backend +from keras.src.api_export import keras_export +from keras.src.layers.layer import Layer + + +@keras_export("keras.layers.Dropout") +class Dropout(Layer): + """Applies dropout to the input. + + The `Dropout` layer randomly sets input units to 0 with a frequency of + `rate` at each step during training time, which helps prevent overfitting. + Inputs not set to 0 are scaled up by `1 / (1 - rate)` such that the sum over + all inputs is unchanged. + + Note that the `Dropout` layer only applies when `training` is set to `True` + in `call()`, such that no values are dropped during inference. + When using `model.fit`, `training` will be appropriately set to `True` + automatically. In other contexts, you can set the argument explicitly + to `True` when calling the layer. + + (This is in contrast to setting `trainable=False` for a `Dropout` layer. + `trainable` does not affect the layer's behavior, as `Dropout` does + not have any variables/weights that can be frozen during training.) + + Args: + rate: Float between 0 and 1. Fraction of the input units to drop. + noise_shape: 1D integer tensor representing the shape of the + binary dropout mask that will be multiplied with the input. + For instance, if your inputs have shape + `(batch_size, timesteps, features)` and + you want the dropout mask to be the same for all timesteps, + you can use `noise_shape=(batch_size, 1, features)`. + seed: A Python integer to use as random seed. + + Call arguments: + inputs: Input tensor (of any rank). + training: Python boolean indicating whether the layer should behave in + training mode (adding dropout) or in inference mode (doing nothing). + """ + + def __init__(self, rate, noise_shape=None, seed=None, **kwargs): + super().__init__(**kwargs) + if not 0 <= rate <= 1: + raise ValueError( + f"Invalid value received for argument " + "`rate`. Expected a float value between 0 and 1. " + f"Received: rate={rate}" + ) + self.rate = rate + self.seed = seed + self.noise_shape = noise_shape + if rate > 0: + self.seed_generator = backend.random.SeedGenerator(seed) + self.supports_masking = True + + self._build_at_init() + + def call(self, inputs, training=False): + if training and self.rate > 0: + return backend.random.dropout( + inputs, + self.rate, + noise_shape=self.noise_shape, + seed=self.seed_generator, + ) + return inputs + + def compute_output_shape(self, input_shape): + return input_shape + + def get_config(self): + base_config = super().get_config() + config = { + "rate": self.rate, + "seed": self.seed, + "noise_shape": self.noise_shape, + } + return {**base_config, **config} diff --git a/keras/src/layers/regularization/dropout_test.py b/keras/src/layers/regularization/dropout_test.py new file mode 100644 index 000000000000..b20a0e4330dd --- /dev/null +++ b/keras/src/layers/regularization/dropout_test.py @@ -0,0 +1,62 @@ +import numpy as np +import pytest + +from keras.src import backend +from keras.src import layers +from keras.src import testing + + +class DropoutTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_dropout_basics(self): + self.run_layer_test( + layers.Dropout, + init_kwargs={ + "rate": 0.2, + }, + input_shape=(2, 3), + call_kwargs={"training": True}, + expected_output_shape=(2, 3), + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=1, + expected_num_losses=0, + supports_masking=True, + assert_built_after_instantiation=True, + ) + + def test_dropout_rescaling(self): + inputs = np.ones((20, 500)) + layer = layers.Dropout(0.5, seed=1337) + outputs = layer(inputs, training=True) + outputs = backend.convert_to_numpy(outputs) + self.assertAllClose(np.mean(outputs), 1.0, atol=0.02) + self.assertAllClose(np.max(outputs), 2.0) + + def test_dropout_partial_noise_shape_dynamic(self): + inputs = np.ones((20, 5, 10)) + layer = layers.Dropout(0.5, noise_shape=(None, 1, None)) + outputs = layer(inputs, training=True) + self.assertAllClose(outputs[:, 0, :], outputs[:, 1, :]) + + def test_dropout_partial_noise_shape_static(self): + inputs = np.ones((20, 5, 10)) + layer = layers.Dropout(0.5, noise_shape=(20, 1, 10)) + outputs = layer(inputs, training=True) + self.assertAllClose(outputs[:, 0, :], outputs[:, 1, :]) + + def test_dropout_negative_rate(self): + with self.assertRaisesRegex( + ValueError, + "Invalid value received for argument `rate`. " + "Expected a float value between 0 and 1.", + ): + _ = layers.Dropout(rate=-0.5) + + def test_dropout_rate_greater_than_one(self): + with self.assertRaisesRegex( + ValueError, + "Invalid value received for argument `rate`. " + "Expected a float value between 0 and 1.", + ): + _ = layers.Dropout(rate=1.5) diff --git a/keras/src/layers/regularization/gaussian_dropout.py b/keras/src/layers/regularization/gaussian_dropout.py new file mode 100644 index 000000000000..dae82edd168d --- /dev/null +++ b/keras/src/layers/regularization/gaussian_dropout.py @@ -0,0 +1,64 @@ +import math + +from keras.src import backend +from keras.src import layers +from keras.src import ops +from keras.src.api_export import keras_export + + +@keras_export("keras.layers.GaussianDropout") +class GaussianDropout(layers.Layer): + """Apply multiplicative 1-centered Gaussian noise. + + As it is a regularization layer, it is only active at training time. + + Args: + rate: Float, drop probability (as with `Dropout`). + The multiplicative noise will have + standard deviation `sqrt(rate / (1 - rate))`. + seed: Integer, optional random seed to enable deterministic behavior. + + Call arguments: + inputs: Input tensor (of any rank). + training: Python boolean indicating whether the layer should behave in + training mode (adding dropout) or in inference mode (doing nothing). + """ + + def __init__(self, rate, seed=None, **kwargs): + super().__init__(**kwargs) + if not 0 <= rate <= 1: + raise ValueError( + f"Invalid value received for argument " + "`rate`. Expected a float value between 0 and 1. " + f"Received: rate={rate}" + ) + self.rate = rate + self.seed = seed + if rate > 0: + self.seed_generator = backend.random.SeedGenerator(seed) + self.supports_masking = True + + self._build_at_init() + + def call(self, inputs, training=False): + if training and self.rate > 0: + stddev = math.sqrt(self.rate / (1.0 - self.rate)) + return inputs * backend.random.normal( + shape=ops.shape(inputs), + mean=1.0, + stddev=stddev, + dtype=self.compute_dtype, + seed=self.seed_generator, + ) + return inputs + + def compute_output_shape(self, input_shape): + return input_shape + + def get_config(self): + base_config = super().get_config() + config = { + "rate": self.rate, + "seed": self.seed, + } + return {**base_config, **config} diff --git a/keras/src/layers/regularization/gaussian_dropout_test.py b/keras/src/layers/regularization/gaussian_dropout_test.py new file mode 100644 index 000000000000..4f376cac7f89 --- /dev/null +++ b/keras/src/layers/regularization/gaussian_dropout_test.py @@ -0,0 +1,36 @@ +import numpy as np +import pytest + +from keras.src import backend +from keras.src import layers +from keras.src import testing + + +class GaussianDropoutTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_gaussian_dropout_basics(self): + self.run_layer_test( + layers.GaussianDropout, + init_kwargs={ + "rate": 0.2, + }, + input_shape=(2, 3), + call_kwargs={"training": True}, + expected_output_shape=(2, 3), + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=1, + expected_num_losses=0, + supports_masking=True, + assert_built_after_instantiation=True, + ) + + def test_gaussian_dropout_correctness(self): + inputs = np.ones((20, 500)) + layer = layers.GaussianDropout(0.3, seed=1337) + outputs = layer(inputs, training=True) + self.assertAllClose( + np.std(backend.convert_to_numpy(outputs)), + np.sqrt(0.3 / (1 - 0.3)), + atol=0.02, + ) diff --git a/keras/src/layers/regularization/gaussian_noise.py b/keras/src/layers/regularization/gaussian_noise.py new file mode 100644 index 000000000000..561541d4d4dc --- /dev/null +++ b/keras/src/layers/regularization/gaussian_noise.py @@ -0,0 +1,64 @@ +from keras.src import backend +from keras.src import layers +from keras.src import ops +from keras.src.api_export import keras_export + + +@keras_export("keras.layers.GaussianNoise") +class GaussianNoise(layers.Layer): + """Apply additive zero-centered Gaussian noise. + + This is useful to mitigate overfitting + (you could see it as a form of random data augmentation). + Gaussian Noise (GS) is a natural choice as corruption process + for real valued inputs. + + As it is a regularization layer, it is only active at training time. + + Args: + stddev: Float, standard deviation of the noise distribution. + seed: Integer, optional random seed to enable deterministic behavior. + + Call arguments: + inputs: Input tensor (of any rank). + training: Python boolean indicating whether the layer should behave in + training mode (adding noise) or in inference mode (doing nothing). + """ + + def __init__(self, stddev, seed=None, **kwargs): + super().__init__(**kwargs) + if not 0 <= stddev <= 1: + raise ValueError( + f"Invalid value received for argument " + "`stddev`. Expected a float value between 0 and 1. " + f"Received: stddev={stddev}" + ) + self.stddev = stddev + self.seed = seed + if stddev > 0: + self.seed_generator = backend.random.SeedGenerator(seed) + self.supports_masking = True + + self._build_at_init() + + def call(self, inputs, training=False): + if training and self.stddev > 0: + return inputs + backend.random.normal( + shape=ops.shape(inputs), + mean=0.0, + stddev=self.stddev, + dtype=self.compute_dtype, + seed=self.seed_generator, + ) + return inputs + + def compute_output_shape(self, input_shape): + return input_shape + + def get_config(self): + base_config = super().get_config() + config = { + "stddev": self.stddev, + "seed": self.seed, + } + return {**base_config, **config} diff --git a/keras/src/layers/regularization/gaussian_noise_test.py b/keras/src/layers/regularization/gaussian_noise_test.py new file mode 100644 index 000000000000..e47f6b182e52 --- /dev/null +++ b/keras/src/layers/regularization/gaussian_noise_test.py @@ -0,0 +1,34 @@ +import numpy as np +import pytest + +from keras.src import backend +from keras.src import layers +from keras.src import testing + + +class GaussianNoiseTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_gaussian_noise_basics(self): + self.run_layer_test( + layers.GaussianNoise, + init_kwargs={ + "stddev": 0.2, + }, + input_shape=(2, 3), + call_kwargs={"training": True}, + expected_output_shape=(2, 3), + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=1, + expected_num_losses=0, + supports_masking=True, + assert_built_after_instantiation=True, + ) + + def test_gaussian_noise_correctness(self): + inputs = np.ones((20, 500)) + layer = layers.GaussianNoise(0.3, seed=1337) + outputs = layer(inputs, training=True) + self.assertAllClose( + np.std(backend.convert_to_numpy(outputs)), 0.3, atol=0.02 + ) diff --git a/keras/src/layers/regularization/spatial_dropout.py b/keras/src/layers/regularization/spatial_dropout.py new file mode 100644 index 000000000000..5f440164f40d --- /dev/null +++ b/keras/src/layers/regularization/spatial_dropout.py @@ -0,0 +1,192 @@ +from keras.src import backend +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.layers.input_spec import InputSpec +from keras.src.layers.regularization.dropout import Dropout + + +class BaseSpatialDropout(Dropout): + def __init__(self, rate, seed=None, name=None, dtype=None): + super().__init__(rate, seed=seed, name=name, dtype=dtype) + + def call(self, inputs, training=False): + if training and self.rate > 0: + return backend.random.dropout( + inputs, + self.rate, + noise_shape=self._get_noise_shape(inputs), + seed=self.seed_generator, + ) + return inputs + + def get_config(self): + return { + "rate": self.rate, + "seed": self.seed, + "name": self.name, + "dtype": self.dtype, + } + + +@keras_export("keras.layers.SpatialDropout1D") +class SpatialDropout1D(BaseSpatialDropout): + """Spatial 1D version of Dropout. + + This layer performs the same function as Dropout, however, it drops + entire 1D feature maps instead of individual elements. If adjacent frames + within feature maps are strongly correlated (as is normally the case in + early convolution layers) then regular dropout will not regularize the + activations and will otherwise just result in an effective learning rate + decrease. In this case, `SpatialDropout1D` will help promote independence + between feature maps and should be used instead. + + Args: + rate: Float between 0 and 1. Fraction of the input units to drop. + + Call arguments: + inputs: A 3D tensor. + training: Python boolean indicating whether the layer + should behave in training mode (applying dropout) + or in inference mode (pass-through). + + Input shape: + 3D tensor with shape: `(samples, timesteps, channels)` + + Output shape: Same as input. + + Reference: + + - [Tompson et al., 2014](https://arxiv.org/abs/1411.4280) + """ + + def __init__(self, rate, seed=None, name=None, dtype=None): + super().__init__(rate, seed=seed, name=name, dtype=dtype) + self.input_spec = InputSpec(ndim=3) + + def _get_noise_shape(self, inputs): + input_shape = ops.shape(inputs) + return (input_shape[0], 1, input_shape[2]) + + +@keras_export("keras.layers.SpatialDropout2D") +class SpatialDropout2D(BaseSpatialDropout): + """Spatial 2D version of Dropout. + + This version performs the same function as Dropout, however, it drops + entire 2D feature maps instead of individual elements. If adjacent pixels + within feature maps are strongly correlated (as is normally the case in + early convolution layers) then regular dropout will not regularize the + activations and will otherwise just result in an effective learning rate + decrease. In this case, `SpatialDropout2D` will help promote independence + between feature maps and should be used instead. + + Args: + rate: Float between 0 and 1. Fraction of the input units to drop. + data_format: `"channels_first"` or `"channels_last"`. + In `"channels_first"` mode, the channels dimension (the depth) + is at index 1, in `"channels_last"` mode is it at index 3. + It defaults to the `image_data_format` value found in your + Keras config file at `~/.keras/keras.json`. + If you never set it, then it will be `"channels_last"`. + + Call arguments: + inputs: A 4D tensor. + training: Python boolean indicating whether the layer + should behave in training mode (applying dropout) + or in inference mode (pass-through). + + Input shape: + 4D tensor with shape: `(samples, channels, rows, cols)` if + data_format='channels_first' + or 4D tensor with shape: `(samples, rows, cols, channels)` if + data_format='channels_last'. + + Output shape: Same as input. + + Reference: + + - [Tompson et al., 2014](https://arxiv.org/abs/1411.4280) + """ + + def __init__( + self, rate, data_format=None, seed=None, name=None, dtype=None + ): + super().__init__(rate, seed=seed, name=name, dtype=dtype) + self.data_format = backend.standardize_data_format(data_format) + self.input_spec = InputSpec(ndim=4) + + def _get_noise_shape(self, inputs): + input_shape = ops.shape(inputs) + if self.data_format == "channels_first": + return (input_shape[0], input_shape[1], 1, 1) + elif self.data_format == "channels_last": + return (input_shape[0], 1, 1, input_shape[3]) + + def get_config(self): + base_config = super().get_config() + config = { + "data_format": self.data_format, + } + return {**base_config, **config} + + +@keras_export("keras.layers.SpatialDropout3D") +class SpatialDropout3D(BaseSpatialDropout): + """Spatial 3D version of Dropout. + + This version performs the same function as Dropout, however, it drops + entire 3D feature maps instead of individual elements. If adjacent voxels + within feature maps are strongly correlated (as is normally the case in + early convolution layers) then regular dropout will not regularize the + activations and will otherwise just result in an effective learning rate + decrease. In this case, SpatialDropout3D will help promote independence + between feature maps and should be used instead. + + Args: + rate: Float between 0 and 1. Fraction of the input units to drop. + data_format: `"channels_first"` or `"channels_last"`. + In `"channels_first"` mode, the channels dimension (the depth) + is at index 1, in `"channels_last"` mode is it at index 4. + It defaults to the `image_data_format` value found in your + Keras config file at `~/.keras/keras.json`. + If you never set it, then it will be `"channels_last"`. + + Call arguments: + inputs: A 5D tensor. + training: Python boolean indicating whether the layer + should behave in training mode (applying dropout) + or in inference mode (pass-through). + + Input shape: + 5D tensor with shape: `(samples, channels, dim1, dim2, dim3)` if + data_format='channels_first' + or 5D tensor with shape: `(samples, dim1, dim2, dim3, channels)` if + data_format='channels_last'. + + Output shape: Same as input. + + Reference: + + - [Tompson et al., 2014](https://arxiv.org/abs/1411.4280) + """ + + def __init__( + self, rate, data_format=None, seed=None, name=None, dtype=None + ): + super().__init__(rate, seed=seed, name=name, dtype=dtype) + self.data_format = backend.standardize_data_format(data_format) + self.input_spec = InputSpec(ndim=5) + + def _get_noise_shape(self, inputs): + input_shape = ops.shape(inputs) + if self.data_format == "channels_first": + return (input_shape[0], input_shape[1], 1, 1, 1) + elif self.data_format == "channels_last": + return (input_shape[0], 1, 1, 1, input_shape[4]) + + def get_config(self): + base_config = super().get_config() + config = { + "data_format": self.data_format, + } + return {**base_config, **config} diff --git a/keras/src/layers/regularization/spatial_dropout_test.py b/keras/src/layers/regularization/spatial_dropout_test.py new file mode 100644 index 000000000000..8b3e664bc6c0 --- /dev/null +++ b/keras/src/layers/regularization/spatial_dropout_test.py @@ -0,0 +1,107 @@ +import numpy as np +import pytest + +from keras.src import backend +from keras.src import layers +from keras.src.testing import test_case + + +class SpatialDropoutTest(test_case.TestCase): + @pytest.mark.requires_trainable_backend + def test_spatial_dropout_1d(self): + self.run_layer_test( + layers.SpatialDropout1D, + init_kwargs={"rate": 0.5}, + call_kwargs={"training": True}, + input_shape=(2, 3, 4), + assert_built_after_instantiation=True, + ) + + self.run_layer_test( + layers.SpatialDropout1D, + init_kwargs={"rate": 0.5}, + call_kwargs={"training": False}, + input_shape=(2, 3, 4), + assert_built_after_instantiation=True, + ) + + @pytest.mark.requires_trainable_backend + def test_spatial_dropout_2d(self): + self.run_layer_test( + layers.SpatialDropout2D, + init_kwargs={"rate": 0.5}, + call_kwargs={"training": True}, + input_shape=(2, 3, 4, 5), + assert_built_after_instantiation=True, + ) + + self.run_layer_test( + layers.SpatialDropout2D, + init_kwargs={"rate": 0.5, "data_format": "channels_first"}, + call_kwargs={"training": True}, + input_shape=(2, 3, 4, 5), + assert_built_after_instantiation=True, + ) + + @pytest.mark.requires_trainable_backend + def test_spatial_dropout_3d(self): + self.run_layer_test( + layers.SpatialDropout3D, + init_kwargs={"rate": 0.5}, + call_kwargs={"training": True}, + input_shape=(2, 3, 4, 4, 5), + assert_built_after_instantiation=True, + ) + + self.run_layer_test( + layers.SpatialDropout3D, + init_kwargs={"rate": 0.5, "data_format": "channels_first"}, + call_kwargs={"training": True}, + input_shape=(2, 3, 4, 4, 5), + assert_built_after_instantiation=True, + ) + + def test_spatial_dropout_1D_dynamic(self): + inputs = layers.Input((3, 2)) + layer = layers.SpatialDropout1D(0.5) + layer(inputs, training=True) + + def test_spatial_dropout_1D_correctness(self): + inputs = np.ones((10, 3, 10)) + layer = layers.SpatialDropout1D(0.5) + outputs = layer(inputs, training=True) + self.assertAllClose(outputs[:, 0, :], outputs[:, 1, :]) + + def test_spatial_dropout_2D_dynamic(self): + inputs = layers.Input((3, 2, 4)) + layer = layers.SpatialDropout2D(0.5) + layer(inputs, training=True) + + def test_spatial_dropout_2D_correctness(self): + if backend.config.image_data_format() == "channels_last": + inputs = np.ones((10, 3, 3, 10)) + else: + inputs = np.ones((10, 10, 3, 3)) + layer = layers.SpatialDropout2D(0.5) + outputs = layer(inputs, training=True) + if backend.config.image_data_format() == "channels_last": + self.assertAllClose(outputs[:, 0, 0, :], outputs[:, 1, 1, :]) + else: + self.assertAllClose(outputs[:, :, 0, 0], outputs[:, :, 1, 1]) + + def test_spatial_dropout_3D_dynamic(self): + inputs = layers.Input((3, 2, 4, 2)) + layer = layers.SpatialDropout3D(0.5) + layer(inputs, training=True) + + def test_spatial_dropout_3D_correctness(self): + if backend.config.image_data_format() == "channels_last": + inputs = np.ones((10, 3, 3, 3, 10)) + else: + inputs = np.ones((10, 10, 3, 3, 3)) + layer = layers.SpatialDropout3D(0.5) + outputs = layer(inputs, training=True) + if backend.config.image_data_format() == "channels_last": + self.assertAllClose(outputs[:, 0, 0, 0, :], outputs[:, 1, 1, 1, :]) + else: + self.assertAllClose(outputs[:, :, 0, 0, 0], outputs[:, :, 1, 1, 1]) diff --git a/keras/src/layers/reshaping/__init__.py b/keras/src/layers/reshaping/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/keras/src/layers/reshaping/cropping1d.py b/keras/src/layers/reshaping/cropping1d.py new file mode 100644 index 000000000000..abce618dff65 --- /dev/null +++ b/keras/src/layers/reshaping/cropping1d.py @@ -0,0 +1,82 @@ +from keras.src.api_export import keras_export +from keras.src.layers.input_spec import InputSpec +from keras.src.layers.layer import Layer +from keras.src.utils import argument_validation + + +@keras_export("keras.layers.Cropping1D") +class Cropping1D(Layer): + """Cropping layer for 1D input (e.g. temporal sequence). + + It crops along the time dimension (axis 1). + + Example: + + >>> input_shape = (2, 3, 2) + >>> x = np.arange(np.prod(input_shape)).reshape(input_shape) + >>> x + [[[ 0 1] + [ 2 3] + [ 4 5]] + [[ 6 7] + [ 8 9] + [10 11]]] + >>> y = keras.layers.Cropping1D(cropping=1)(x) + >>> y + [[[2 3]] + [[8 9]]] + + Args: + cropping: Int, or tuple of int (length 2), or dictionary. + - If int: how many units should be trimmed off at the beginning and + end of the cropping dimension (axis 1). + - If tuple of 2 ints: how many units should be trimmed off at the + beginning and end of the cropping dimension + (`(left_crop, right_crop)`). + + Input shape: + 3D tensor with shape `(batch_size, axis_to_crop, features)` + + Output shape: + 3D tensor with shape `(batch_size, cropped_axis, features)` + """ + + def __init__(self, cropping=(1, 1), **kwargs): + super().__init__(**kwargs) + self.cropping = argument_validation.standardize_tuple( + cropping, 2, "cropping", allow_zero=True + ) + self.input_spec = InputSpec(ndim=3) + + def compute_output_shape(self, input_shape): + if input_shape[1] is not None: + length = input_shape[1] - self.cropping[0] - self.cropping[1] + if length <= 0: + raise ValueError( + "`cropping` parameter of `Cropping1D` layer must be " + "smaller than the input length. Received: input_shape=" + f"{input_shape}, cropping={self.cropping}" + ) + else: + length = None + return (input_shape[0], length, input_shape[2]) + + def call(self, inputs): + if ( + inputs.shape[1] is not None + and sum(self.cropping) >= inputs.shape[1] + ): + raise ValueError( + "`cropping` parameter of `Cropping1D` layer must be " + "smaller than the input length. Received: inputs.shape=" + f"{inputs.shape}, cropping={self.cropping}" + ) + if self.cropping[1] == 0: + return inputs[:, self.cropping[0] :, :] + else: + return inputs[:, self.cropping[0] : -self.cropping[1], :] + + def get_config(self): + config = {"cropping": self.cropping} + base_config = super().get_config() + return {**base_config, **config} diff --git a/keras/src/layers/reshaping/cropping1d_test.py b/keras/src/layers/reshaping/cropping1d_test.py new file mode 100644 index 000000000000..cceb5922d92e --- /dev/null +++ b/keras/src/layers/reshaping/cropping1d_test.py @@ -0,0 +1,80 @@ +import numpy as np +import pytest + +from keras.src import layers +from keras.src import ops +from keras.src import testing + + +class Cropping1DTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_cropping_1d(self): + inputs = np.random.rand(3, 5, 7) + + # Cropping with different values on the left and the right. + self.run_layer_test( + layers.Cropping1D, + init_kwargs={"cropping": (1, 2)}, + input_data=inputs, + expected_output=ops.convert_to_tensor(inputs[:, 1:3, :]), + ) + # Same cropping on the left and the right. + self.run_layer_test( + layers.Cropping1D, + init_kwargs={"cropping": (1, 1)}, + input_data=inputs, + expected_output=ops.convert_to_tensor(inputs[:, 1:4, :]), + ) + # Same cropping on the left and the right provided as an int. + self.run_layer_test( + layers.Cropping1D, + init_kwargs={"cropping": 1}, + input_data=inputs, + expected_output=ops.convert_to_tensor(inputs[:, 1:4, :]), + ) + # Cropping on the right only. + self.run_layer_test( + layers.Cropping1D, + init_kwargs={"cropping": (0, 1)}, + input_data=inputs, + expected_output=ops.convert_to_tensor(inputs[:, 0:4, :]), + ) + # Cropping on the left only. + self.run_layer_test( + layers.Cropping1D, + init_kwargs={"cropping": (1, 0)}, + input_data=inputs, + expected_output=ops.convert_to_tensor(inputs[:, 1:5, :]), + ) + + @pytest.mark.requires_trainable_backend + def test_cropping_1d_with_dynamic_spatial_dim(self): + input_layer = layers.Input(batch_shape=(1, None, 7)) + cropped = layers.Cropping1D((1, 2))(input_layer) + self.assertEqual(cropped.shape, (1, None, 7)) + + def test_cropping_1d_errors_if_cropping_argument_invalid(self): + with self.assertRaises(ValueError): + layers.Cropping1D(cropping=(1,)) + with self.assertRaises(ValueError): + layers.Cropping1D(cropping=(1, 2, 3)) + with self.assertRaises(ValueError): + layers.Cropping1D(cropping="1") + + def test_cropping_1d_errors_if_cropping_more_than_available(self): + with self.assertRaisesRegex( + ValueError, + "`cropping` parameter of `Cropping1D` layer must be smaller than", + ): + input_layer = layers.Input(batch_shape=(3, 5, 7)) + layers.Cropping1D(cropping=(2, 3))(input_layer) + + def test_cropping_1d_error_on_excessive_cropping(self): + inputs = np.random.rand(3, 5, 7) + + with self.assertRaisesRegex( + ValueError, + "`cropping` parameter of `Cropping1D` layer must be smaller than", + ): + layer = layers.Cropping1D(cropping=(3, 3)) + _ = layer(inputs) diff --git a/keras/src/layers/reshaping/cropping2d.py b/keras/src/layers/reshaping/cropping2d.py new file mode 100644 index 000000000000..aec6813a861f --- /dev/null +++ b/keras/src/layers/reshaping/cropping2d.py @@ -0,0 +1,224 @@ +from keras.src import backend +from keras.src.api_export import keras_export +from keras.src.layers.input_spec import InputSpec +from keras.src.layers.layer import Layer +from keras.src.utils import argument_validation + + +@keras_export("keras.layers.Cropping2D") +class Cropping2D(Layer): + """Cropping layer for 2D input (e.g. picture). + + It crops along spatial dimensions, i.e. height and width. + + Example: + + >>> input_shape = (2, 28, 28, 3) + >>> x = np.arange(np.prod(input_shape)).reshape(input_shape) + >>> y = keras.layers.Cropping2D(cropping=((2, 2), (4, 4)))(x) + >>> y.shape + (2, 24, 20, 3) + + Args: + cropping: Int, or tuple of 2 ints, or tuple of 2 tuples of 2 ints. + - If int: the same symmetric cropping is applied to height and + width. + - If tuple of 2 ints: interpreted as two different symmetric + cropping values for height and width: + `(symmetric_height_crop, symmetric_width_crop)`. + - If tuple of 2 tuples of 2 ints: interpreted as + `((top_crop, bottom_crop), (left_crop, right_crop))`. + data_format: A string, one of `"channels_last"` (default) or + `"channels_first"`. The ordering of the dimensions in the inputs. + `"channels_last"` corresponds to inputs with shape + `(batch_size, height, width, channels)` while `"channels_first"` + corresponds to inputs with shape + `(batch_size, channels, height, width)`. + When unspecified, uses `image_data_format` value found in your Keras + config file at `~/.keras/keras.json` (if exists). Defaults to + `"channels_last"`. + + Input shape: + 4D tensor with shape: + - If `data_format` is `"channels_last"`: + `(batch_size, height, width, channels)` + - If `data_format` is `"channels_first"`: + `(batch_size, channels, height, width)` + + Output shape: + 4D tensor with shape: + - If `data_format` is `"channels_last"`: + `(batch_size, cropped_height, cropped_width, channels)` + - If `data_format` is `"channels_first"`: + `(batch_size, channels, cropped_height, cropped_width)` + """ + + def __init__(self, cropping=((0, 0), (0, 0)), data_format=None, **kwargs): + super().__init__(**kwargs) + self.data_format = backend.standardize_data_format(data_format) + if isinstance(cropping, int): + if cropping < 0: + raise ValueError( + "`cropping` cannot be negative. " + f"Received: cropping={cropping}." + ) + self.cropping = ((cropping, cropping), (cropping, cropping)) + elif hasattr(cropping, "__len__"): + if len(cropping) != 2: + raise ValueError( + "`cropping` should have two elements. " + f"Received: cropping={cropping}." + ) + height_cropping = argument_validation.standardize_tuple( + cropping[0], 2, "1st entry of cropping", allow_zero=True + ) + width_cropping = argument_validation.standardize_tuple( + cropping[1], 2, "2nd entry of cropping", allow_zero=True + ) + self.cropping = (height_cropping, width_cropping) + else: + raise ValueError( + "`cropping` should be either an int, a tuple of 2 ints " + "(symmetric_height_crop, symmetric_width_crop), " + "or a tuple of 2 tuples of 2 ints " + "((top_crop, bottom_crop), (left_crop, right_crop)). " + f"Received: cropping={cropping}." + ) + self.input_spec = InputSpec(ndim=4) + + def compute_output_shape(self, input_shape): + if self.data_format == "channels_first": + if ( + input_shape[2] is not None + and sum(self.cropping[0]) >= input_shape[2] + ) or ( + input_shape[3] is not None + and sum(self.cropping[1]) >= input_shape[3] + ): + raise ValueError( + "Values in `cropping` argument should be smaller than the " + "corresponding spatial dimension of the input. Received: " + f"input_shape={input_shape}, cropping={self.cropping}" + ) + return ( + input_shape[0], + input_shape[1], + ( + input_shape[2] - self.cropping[0][0] - self.cropping[0][1] + if input_shape[2] is not None + else None + ), + ( + input_shape[3] - self.cropping[1][0] - self.cropping[1][1] + if input_shape[3] is not None + else None + ), + ) + else: + if ( + input_shape[1] is not None + and sum(self.cropping[0]) >= input_shape[1] + ) or ( + input_shape[2] is not None + and sum(self.cropping[1]) >= input_shape[2] + ): + raise ValueError( + "Values in `cropping` argument should be smaller than the " + "corresponding spatial dimension of the input. Received: " + f"input_shape={input_shape}, cropping={self.cropping}" + ) + return ( + input_shape[0], + ( + input_shape[1] - self.cropping[0][0] - self.cropping[0][1] + if input_shape[1] is not None + else None + ), + ( + input_shape[2] - self.cropping[1][0] - self.cropping[1][1] + if input_shape[2] is not None + else None + ), + input_shape[3], + ) + + def call(self, inputs): + if self.data_format == "channels_first": + if ( + inputs.shape[2] is not None + and sum(self.cropping[0]) >= inputs.shape[2] + ) or ( + inputs.shape[3] is not None + and sum(self.cropping[1]) >= inputs.shape[3] + ): + raise ValueError( + "Values in `cropping` argument should be smaller than the " + "corresponding spatial dimension of the input. Received: " + f"inputs.shape={inputs.shape}, cropping={self.cropping}" + ) + if self.cropping[0][1] == self.cropping[1][1] == 0: + return inputs[ + :, :, self.cropping[0][0] :, self.cropping[1][0] : + ] + elif self.cropping[0][1] == 0: + return inputs[ + :, + :, + self.cropping[0][0] :, + self.cropping[1][0] : -self.cropping[1][1], + ] + elif self.cropping[1][1] == 0: + return inputs[ + :, + :, + self.cropping[0][0] : -self.cropping[0][1], + self.cropping[1][0] :, + ] + return inputs[ + :, + :, + self.cropping[0][0] : -self.cropping[0][1], + self.cropping[1][0] : -self.cropping[1][1], + ] + else: + if ( + inputs.shape[1] is not None + and sum(self.cropping[0]) >= inputs.shape[1] + ) or ( + inputs.shape[2] is not None + and sum(self.cropping[1]) >= inputs.shape[2] + ): + raise ValueError( + "Values in `cropping` argument should be smaller than the " + "corresponding spatial dimension of the input. Received: " + f"inputs.shape={inputs.shape}, cropping={self.cropping}" + ) + if self.cropping[0][1] == self.cropping[1][1] == 0: + return inputs[ + :, self.cropping[0][0] :, self.cropping[1][0] :, : + ] + elif self.cropping[0][1] == 0: + return inputs[ + :, + self.cropping[0][0] :, + self.cropping[1][0] : -self.cropping[1][1], + :, + ] + elif self.cropping[1][1] == 0: + return inputs[ + :, + self.cropping[0][0] : -self.cropping[0][1], + self.cropping[1][0] :, + :, + ] + return inputs[ + :, + self.cropping[0][0] : -self.cropping[0][1], + self.cropping[1][0] : -self.cropping[1][1], + :, + ] + + def get_config(self): + config = {"cropping": self.cropping, "data_format": self.data_format} + base_config = super().get_config() + return {**base_config, **config} diff --git a/keras/src/layers/reshaping/cropping2d_test.py b/keras/src/layers/reshaping/cropping2d_test.py new file mode 100644 index 000000000000..5a04dc5a78f2 --- /dev/null +++ b/keras/src/layers/reshaping/cropping2d_test.py @@ -0,0 +1,135 @@ +import numpy as np +import pytest +from absl.testing import parameterized + +from keras.src import backend +from keras.src import layers +from keras.src import ops +from keras.src import testing + + +class Cropping2DTest(testing.TestCase): + @parameterized.product( + ( + # different cropping values + {"cropping": ((1, 2), (3, 4)), "expected_ranges": ((1, 5), (3, 5))}, + # same cropping values with 2 tuples + {"cropping": ((2, 2), (2, 2)), "expected_ranges": ((2, 5), (2, 7))}, + # same cropping values with 1 tuple + {"cropping": (2, 2), "expected_ranges": ((2, 5), (2, 7))}, + # same cropping values with an integer + {"cropping": 2, "expected_ranges": ((2, 5), (2, 7))}, + # cropping right only in both dimensions + {"cropping": ((0, 2), (0, 4)), "expected_ranges": ((0, 5), (0, 5))}, + # cropping left only in both dimensions + {"cropping": ((1, 0), (3, 0)), "expected_ranges": ((1, 7), (3, 9))}, + # cropping left only in rows dimension + {"cropping": ((1, 0), (3, 4)), "expected_ranges": ((1, 7), (3, 5))}, + # cropping left only in cols dimension + {"cropping": ((1, 2), (3, 0)), "expected_ranges": ((1, 5), (3, 9))}, + ), + ( + {"data_format": "channels_first"}, + {"data_format": "channels_last"}, + ), + ) + @pytest.mark.requires_trainable_backend + def test_cropping_2d(self, cropping, data_format, expected_ranges): + if data_format == "channels_first": + inputs = np.random.rand(3, 5, 7, 9) + expected_output = ops.convert_to_tensor( + inputs[ + :, + :, + expected_ranges[0][0] : expected_ranges[0][1], + expected_ranges[1][0] : expected_ranges[1][1], + ] + ) + else: + inputs = np.random.rand(3, 7, 9, 5) + expected_output = ops.convert_to_tensor( + inputs[ + :, + expected_ranges[0][0] : expected_ranges[0][1], + expected_ranges[1][0] : expected_ranges[1][1], + :, + ] + ) + + self.run_layer_test( + layers.Cropping2D, + init_kwargs={"cropping": cropping, "data_format": data_format}, + input_data=inputs, + expected_output=expected_output, + ) + + def test_cropping_2d_with_dynamic_spatial_dim(self): + if backend.config.image_data_format() == "channels_last": + input_layer = layers.Input(batch_shape=(1, 7, None, 5)) + else: + input_layer = layers.Input(batch_shape=(1, 5, 7, None)) + cropped = layers.Cropping2D(((1, 2), (3, 4)))(input_layer) + if backend.config.image_data_format() == "channels_last": + self.assertEqual(cropped.shape, (1, 4, None, 5)) + else: + self.assertEqual(cropped.shape, (1, 5, 4, None)) + + @parameterized.product( + ( + {"cropping": ((3, 6), (0, 0))}, + {"cropping": ((0, 0), (5, 4))}, + ), + ( + {"data_format": "channels_first"}, + {"data_format": "channels_last"}, + ), + ) + def test_cropping_2d_errors_if_cropping_more_than_available( + self, cropping, data_format + ): + input_layer = layers.Input(batch_shape=(3, 7, 9, 5)) + with self.assertRaises(ValueError): + layers.Cropping2D(cropping=cropping, data_format=data_format)( + input_layer + ) + + def test_cropping_2d_errors_if_cropping_argument_invalid(self): + with self.assertRaises(ValueError): + layers.Cropping2D(cropping=(1,)) + with self.assertRaises(ValueError): + layers.Cropping2D(cropping=(1, 2, 3)) + with self.assertRaises(ValueError): + layers.Cropping2D(cropping="1") + with self.assertRaises(ValueError): + layers.Cropping2D(cropping=((1, 2), (3, 4, 5))) + with self.assertRaises(ValueError): + layers.Cropping2D(cropping=((1, 2), (3, -4))) + with self.assertRaises(ValueError): + layers.Cropping2D(cropping=((1, 2), "3")) + + @parameterized.product( + ( + {"cropping": ((4, 5), (0, 0)), "input_shape": (3, 8, 9, 5)}, + {"cropping": ((0, 0), (5, 5)), "input_shape": (3, 8, 9, 5)}, + {"cropping": ((6, 3), (0, 0)), "input_shape": (3, 8, 9, 5)}, + {"cropping": ((0, 0), (7, 3)), "input_shape": (3, 8, 9, 5)}, + ), + ( + {"data_format": "channels_first"}, + {"data_format": "channels_last"}, + ), + ) + def test_cropping_2d_error_on_excessive_cropping( + self, cropping, input_shape, data_format + ): + inputs = np.random.rand(*input_shape) + + with self.assertRaisesRegex( + ValueError, + "Values in `cropping` argument should be smaller than the " + "corresponding spatial dimension of the input.", + ): + layer = layers.Cropping2D( + cropping=cropping, data_format=data_format + ) + _ = layer(inputs) diff --git a/keras/src/layers/reshaping/cropping3d.py b/keras/src/layers/reshaping/cropping3d.py new file mode 100644 index 000000000000..724d0cf72635 --- /dev/null +++ b/keras/src/layers/reshaping/cropping3d.py @@ -0,0 +1,284 @@ +from keras.src import backend +from keras.src.api_export import keras_export +from keras.src.layers.input_spec import InputSpec +from keras.src.layers.layer import Layer +from keras.src.utils import argument_validation + + +@keras_export("keras.layers.Cropping3D") +class Cropping3D(Layer): + """Cropping layer for 3D data (e.g. spatial or spatio-temporal). + + Example: + + >>> input_shape = (2, 28, 28, 10, 3) + >>> x = np.arange(np.prod(input_shape)).reshape(input_shape) + >>> y = keras.layers.Cropping3D(cropping=(2, 4, 2))(x) + >>> y.shape + (2, 24, 20, 6, 3) + + Args: + cropping: Int, or tuple of 3 ints, or tuple of 3 tuples of 2 ints. + - If int: the same symmetric cropping is applied to depth, height, + and width. + - If tuple of 3 ints: interpreted as three different symmetric + cropping values for depth, height, and width: + `(symmetric_dim1_crop, symmetric_dim2_crop, symmetric_dim3_crop)`. + - If tuple of 3 tuples of 2 ints: interpreted as + `((left_dim1_crop, right_dim1_crop), (left_dim2_crop, + right_dim2_crop), (left_dim3_crop, right_dim3_crop))`. + data_format: A string, one of `"channels_last"` (default) or + `"channels_first"`. The ordering of the dimensions in the inputs. + `"channels_last"` corresponds to inputs with shape + `(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)`. + When unspecified, uses `image_data_format` value found in your Keras + config file at `~/.keras/keras.json` (if exists). Defaults to + `"channels_last"`. + + Input shape: + 5D tensor with shape: + - If `data_format` is `"channels_last"`: + `(batch_size, first_axis_to_crop, second_axis_to_crop, + third_axis_to_crop, channels)` + - If `data_format` is `"channels_first"`: + `(batch_size, channels, first_axis_to_crop, second_axis_to_crop, + third_axis_to_crop)` + + Output shape: + 5D tensor with shape: + - If `data_format` is `"channels_last"`: + `(batch_size, first_cropped_axis, second_cropped_axis, + third_cropped_axis, channels)` + - If `data_format` is `"channels_first"`: + `(batch_size, channels, first_cropped_axis, second_cropped_axis, + third_cropped_axis)` + """ + + def __init__( + self, cropping=((1, 1), (1, 1), (1, 1)), data_format=None, **kwargs + ): + super().__init__(**kwargs) + self.data_format = backend.standardize_data_format(data_format) + if isinstance(cropping, int): + if cropping < 0: + raise ValueError( + "`cropping` cannot be negative. " + f"Received: cropping={cropping}." + ) + self.cropping = ( + (cropping, cropping), + (cropping, cropping), + (cropping, cropping), + ) + elif hasattr(cropping, "__len__"): + if len(cropping) != 3: + raise ValueError( + f"`cropping` should have 3 elements. Received: {cropping}." + ) + dim1_cropping = argument_validation.standardize_tuple( + cropping[0], 2, "1st entry of cropping", allow_zero=True + ) + dim2_cropping = argument_validation.standardize_tuple( + cropping[1], 2, "2nd entry of cropping", allow_zero=True + ) + dim3_cropping = argument_validation.standardize_tuple( + cropping[2], 2, "3rd entry of cropping", allow_zero=True + ) + self.cropping = (dim1_cropping, dim2_cropping, dim3_cropping) + else: + raise ValueError( + "`cropping` should be either an int, a tuple of 3 ints " + "(symmetric_dim1_crop, symmetric_dim2_crop, " + "symmetric_dim3_crop), " + "or a tuple of 3 tuples of 2 ints " + "((left_dim1_crop, right_dim1_crop)," + " (left_dim2_crop, right_dim2_crop)," + " (left_dim3_crop, right_dim2_crop)). " + f"Received: {cropping}." + ) + self.input_spec = InputSpec(ndim=5) + + def compute_output_shape(self, input_shape): + if self.data_format == "channels_first": + spatial_dims = list(input_shape[2:5]) + else: + spatial_dims = list(input_shape[1:4]) + + for index in range(0, 3): + if spatial_dims[index] is None: + continue + spatial_dims[index] -= sum(self.cropping[index]) + if spatial_dims[index] <= 0: + raise ValueError( + "Values in `cropping` argument should be smaller than the " + "corresponding spatial dimension of the input. Received: " + f"input_shape={input_shape}, cropping={self.cropping}" + ) + + if self.data_format == "channels_first": + return (input_shape[0], input_shape[1], *spatial_dims) + else: + return (input_shape[0], *spatial_dims, input_shape[4]) + + def call(self, inputs): + if self.data_format == "channels_first": + spatial_dims = list(inputs.shape[2:5]) + else: + spatial_dims = list(inputs.shape[1:4]) + + for index in range(0, 3): + if spatial_dims[index] is None: + continue + spatial_dims[index] -= sum(self.cropping[index]) + if spatial_dims[index] <= 0: + raise ValueError( + "Values in `cropping` argument should be smaller than the " + "corresponding spatial dimension of the input. Received: " + f"inputs.shape={inputs.shape}, cropping={self.cropping}" + ) + + if self.data_format == "channels_first": + if ( + self.cropping[0][1] + == self.cropping[1][1] + == self.cropping[2][1] + == 0 + ): + return inputs[ + :, + :, + self.cropping[0][0] :, + self.cropping[1][0] :, + self.cropping[2][0] :, + ] + elif self.cropping[0][1] == self.cropping[1][1] == 0: + return inputs[ + :, + :, + self.cropping[0][0] :, + self.cropping[1][0] :, + self.cropping[2][0] : -self.cropping[2][1], + ] + elif self.cropping[1][1] == self.cropping[2][1] == 0: + return inputs[ + :, + :, + self.cropping[0][0] : -self.cropping[0][1], + self.cropping[1][0] :, + self.cropping[2][0] :, + ] + elif self.cropping[0][1] == self.cropping[2][1] == 0: + return inputs[ + :, + :, + self.cropping[0][0] :, + self.cropping[1][0] : -self.cropping[1][1], + self.cropping[2][0] :, + ] + elif self.cropping[0][1] == 0: + return inputs[ + :, + :, + self.cropping[0][0] :, + self.cropping[1][0] : -self.cropping[1][1], + self.cropping[2][0] : -self.cropping[2][1], + ] + elif self.cropping[1][1] == 0: + return inputs[ + :, + :, + self.cropping[0][0] : -self.cropping[0][1], + self.cropping[1][0] :, + self.cropping[2][0] : -self.cropping[2][1], + ] + elif self.cropping[2][1] == 0: + return inputs[ + :, + :, + self.cropping[0][0] : -self.cropping[0][1], + self.cropping[1][0] : -self.cropping[1][1], + self.cropping[2][0] :, + ] + return inputs[ + :, + :, + self.cropping[0][0] : -self.cropping[0][1], + self.cropping[1][0] : -self.cropping[1][1], + self.cropping[2][0] : -self.cropping[2][1], + ] + else: + if ( + self.cropping[0][1] + == self.cropping[1][1] + == self.cropping[2][1] + == 0 + ): + return inputs[ + :, + self.cropping[0][0] :, + self.cropping[1][0] :, + self.cropping[2][0] :, + :, + ] + elif self.cropping[0][1] == self.cropping[1][1] == 0: + return inputs[ + :, + self.cropping[0][0] :, + self.cropping[1][0] :, + self.cropping[2][0] : -self.cropping[2][1], + :, + ] + elif self.cropping[1][1] == self.cropping[2][1] == 0: + return inputs[ + :, + self.cropping[0][0] : -self.cropping[0][1], + self.cropping[1][0] :, + self.cropping[2][0] :, + :, + ] + elif self.cropping[0][1] == self.cropping[2][1] == 0: + return inputs[ + :, + self.cropping[0][0] :, + self.cropping[1][0] : -self.cropping[1][1], + self.cropping[2][0] :, + :, + ] + elif self.cropping[0][1] == 0: + return inputs[ + :, + self.cropping[0][0] :, + self.cropping[1][0] : -self.cropping[1][1], + self.cropping[2][0] : -self.cropping[2][1], + :, + ] + elif self.cropping[1][1] == 0: + return inputs[ + :, + self.cropping[0][0] : -self.cropping[0][1], + self.cropping[1][0] :, + self.cropping[2][0] : -self.cropping[2][1], + :, + ] + elif self.cropping[2][1] == 0: + return inputs[ + :, + self.cropping[0][0] : -self.cropping[0][1], + self.cropping[1][0] : -self.cropping[1][1], + self.cropping[2][0] :, + :, + ] + return inputs[ + :, + self.cropping[0][0] : -self.cropping[0][1], + self.cropping[1][0] : -self.cropping[1][1], + self.cropping[2][0] : -self.cropping[2][1], + :, + ] + + def get_config(self): + config = {"cropping": self.cropping, "data_format": self.data_format} + base_config = super().get_config() + return {**base_config, **config} diff --git a/keras/src/layers/reshaping/cropping3d_test.py b/keras/src/layers/reshaping/cropping3d_test.py new file mode 100644 index 000000000000..30b540ae226d --- /dev/null +++ b/keras/src/layers/reshaping/cropping3d_test.py @@ -0,0 +1,199 @@ +import numpy as np +import pytest +from absl.testing import parameterized + +from keras.src import backend +from keras.src import layers +from keras.src import ops +from keras.src import testing + + +class Cropping3DTest(testing.TestCase): + @parameterized.product( + ( + {"dim1_cropping": (1, 2), "dim1_expected": (1, 5)}, # both + {"dim1_cropping": (0, 2), "dim1_expected": (0, 5)}, # left only + {"dim1_cropping": (1, 0), "dim1_expected": (1, 7)}, # right only + ), + ( + {"dim2_cropping": (3, 4), "dim2_expected": (3, 5)}, # both + {"dim2_cropping": (0, 4), "dim2_expected": (0, 5)}, # left only + {"dim2_cropping": (3, 0), "dim2_expected": (3, 9)}, # right only + ), + ( + {"dim3_cropping": (5, 6), "dim3_expected": (5, 7)}, # both + {"dim3_cropping": (0, 6), "dim3_expected": (0, 7)}, # left only + {"dim3_cropping": (5, 0), "dim3_expected": (5, 13)}, # right only + ), + ( + {"data_format": "channels_first"}, + {"data_format": "channels_last"}, + ), + ) + @pytest.mark.requires_trainable_backend + def test_cropping_3d( + self, + dim1_cropping, + dim2_cropping, + dim3_cropping, + data_format, + dim1_expected, + dim2_expected, + dim3_expected, + ): + if data_format == "channels_first": + inputs = np.random.rand(3, 5, 7, 9, 13) + expected_output = ops.convert_to_tensor( + inputs[ + :, + :, + dim1_expected[0] : dim1_expected[1], + dim2_expected[0] : dim2_expected[1], + dim3_expected[0] : dim3_expected[1], + ] + ) + else: + inputs = np.random.rand(3, 7, 9, 13, 5) + expected_output = ops.convert_to_tensor( + inputs[ + :, + dim1_expected[0] : dim1_expected[1], + dim2_expected[0] : dim2_expected[1], + dim3_expected[0] : dim3_expected[1], + :, + ] + ) + + cropping = (dim1_cropping, dim2_cropping, dim3_cropping) + self.run_layer_test( + layers.Cropping3D, + init_kwargs={"cropping": cropping, "data_format": data_format}, + input_data=inputs, + expected_output=expected_output, + ) + + @parameterized.product( + ( + # same cropping values with 3 tuples + { + "cropping": ((2, 2), (2, 2), (2, 2)), + "expected": ((2, 5), (2, 7), (2, 11)), + }, + # same cropping values with 1 tuple + {"cropping": (2, 2, 2), "expected": ((2, 5), (2, 7), (2, 11))}, + # same cropping values with an integer + {"cropping": 2, "expected": ((2, 5), (2, 7), (2, 11))}, + ), + ( + {"data_format": "channels_first"}, + {"data_format": "channels_last"}, + ), + ) + @pytest.mark.requires_trainable_backend + def test_cropping_3d_with_same_cropping( + self, cropping, data_format, expected + ): + if data_format == "channels_first": + inputs = np.random.rand(3, 5, 7, 9, 13) + expected_output = ops.convert_to_tensor( + inputs[ + :, + :, + expected[0][0] : expected[0][1], + expected[1][0] : expected[1][1], + expected[2][0] : expected[2][1], + ] + ) + else: + inputs = np.random.rand(3, 7, 9, 13, 5) + expected_output = ops.convert_to_tensor( + inputs[ + :, + expected[0][0] : expected[0][1], + expected[1][0] : expected[1][1], + expected[2][0] : expected[2][1], + :, + ] + ) + + self.run_layer_test( + layers.Cropping3D, + init_kwargs={"cropping": cropping, "data_format": data_format}, + input_data=inputs, + expected_output=expected_output, + ) + + def test_cropping_3d_with_dynamic_spatial_dim(self): + if backend.config.image_data_format() == "channels_last": + input_layer = layers.Input(batch_shape=(1, 7, None, 13, 5)) + else: + input_layer = layers.Input(batch_shape=(1, 5, 7, None, 13)) + cropped = layers.Cropping3D(((1, 2), (3, 4), (5, 6)))(input_layer) + if backend.config.image_data_format() == "channels_last": + self.assertEqual(cropped.shape, (1, 4, None, 2, 5)) + else: + self.assertEqual(cropped.shape, (1, 5, 4, None, 2)) + + @parameterized.product( + ( + {"cropping": ((3, 6), (0, 0), (0, 0))}, + {"cropping": ((0, 0), (5, 8), (0, 0))}, + {"cropping": ((0, 0), (0, 0), (7, 6))}, + ), + ( + {"data_format": "channels_first"}, + {"data_format": "channels_last"}, + ), + ) + def test_cropping_3d_errors_if_cropping_more_than_available( + self, cropping, data_format + ): + input_layer = layers.Input(batch_shape=(3, 7, 9, 13, 5)) + with self.assertRaises(ValueError): + layers.Cropping3D(cropping=cropping, data_format=data_format)( + input_layer + ) + + def test_cropping_3d_errors_if_cropping_argument_invalid(self): + with self.assertRaises(ValueError): + layers.Cropping3D(cropping=(1,)) + with self.assertRaises(ValueError): + layers.Cropping3D(cropping=(1, 2)) + with self.assertRaises(ValueError): + layers.Cropping3D(cropping=(1, 2, 3, 4)) + with self.assertRaises(ValueError): + layers.Cropping3D(cropping="1") + with self.assertRaises(ValueError): + layers.Cropping3D(cropping=((1, 2), (3, 4), (5, 6, 7))) + with self.assertRaises(ValueError): + layers.Cropping3D(cropping=((1, 2), (3, 4), (5, -6))) + with self.assertRaises(ValueError): + layers.Cropping3D(cropping=((1, 2), (3, 4), "5")) + + @parameterized.product( + ( + {"cropping": ((8, 1), (1, 1), (1, 1))}, + {"cropping": ((1, 1), (10, 1), (1, 1))}, + {"cropping": ((1, 1), (1, 1), (14, 1))}, + ), + ( + {"data_format": "channels_first"}, + {"data_format": "channels_last"}, + ), + ) + def test_cropping_3d_with_excessive_cropping(self, cropping, data_format): + if data_format == "channels_first": + shape = (3, 5, 7, 9, 13) + input_layer = layers.Input(batch_shape=shape) + else: + shape = (3, 7, 9, 13, 5) + input_layer = layers.Input(batch_shape=shape) + + expected_error_msg = ( + "Values in `cropping` argument should be smaller than the" + ) + + with self.assertRaisesRegex(ValueError, expected_error_msg): + layers.Cropping3D(cropping=cropping, data_format=data_format)( + input_layer + ) diff --git a/keras/src/layers/reshaping/flatten.py b/keras/src/layers/reshaping/flatten.py new file mode 100644 index 000000000000..e941e48eb1e2 --- /dev/null +++ b/keras/src/layers/reshaping/flatten.py @@ -0,0 +1,84 @@ +import math + +from keras.src import backend +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.backend.common.keras_tensor import KerasTensor +from keras.src.layers.input_spec import InputSpec +from keras.src.layers.layer import Layer + + +@keras_export("keras.layers.Flatten") +class Flatten(Layer): + """Flattens the input. Does not affect the batch size. + + Note: If inputs are shaped `(batch,)` without a feature axis, then + flattening adds an extra channel dimension and output shape is `(batch, 1)`. + + Args: + data_format: A string, one of `"channels_last"` (default) or + `"channels_first"`. The ordering of the dimensions in the inputs. + `"channels_last"` corresponds to inputs with shape + `(batch, ..., channels)` while `"channels_first"` corresponds to + inputs with shape `(batch, channels, ...)`. + When unspecified, uses `image_data_format` value found in your Keras + config file at `~/.keras/keras.json` (if exists). Defaults to + `"channels_last"`. + + Example: + + >>> x = keras.Input(shape=(10, 64)) + >>> y = keras.layers.Flatten()(x) + >>> y.shape + (None, 640) + """ + + def __init__(self, data_format=None, **kwargs): + super().__init__(**kwargs) + self.data_format = backend.standardize_data_format(data_format) + self.input_spec = InputSpec(min_ndim=1) + self._channels_first = self.data_format == "channels_first" + + def call(self, inputs): + input_shape = ops.shape(inputs) + rank = len(input_shape) + + if self._channels_first and rank > 1: + # Switch to channels-last format. + inputs = ops.transpose(inputs, axes=(0, *range(2, rank), 1)) + + non_batch_dims = input_shape[1:] + if len(non_batch_dims) == 0: + flattened_dim = 1 + elif any(not isinstance(d, int) for d in non_batch_dims): + flattened_dim = -1 + else: + flattened_dim = math.prod(non_batch_dims) + + return ops.reshape(inputs, (input_shape[0], flattened_dim)) + + def compute_output_shape(self, input_shape): + non_batch_dims = input_shape[1:] + if len(non_batch_dims) == 0: + flattened_dim = 1 + elif any(d is None for d in non_batch_dims): + # NB: we cannot use the shorter `None in non_batch_dims` here b/c + # torchdynamo errors when calling `__contains__` op with + # a constant (in this case `None`) operand since it assumes + # that the elements in the collection are also `ConstantVariable`s + # but tensor shapes can be `SymNodeVariable`s (e.g. `SymInt`) + flattened_dim = None + else: + flattened_dim = math.prod(non_batch_dims) + return (input_shape[0], flattened_dim) + + def compute_output_spec(self, inputs): + output_shape = self.compute_output_shape(inputs.shape) + return KerasTensor( + shape=output_shape, dtype=inputs.dtype, sparse=inputs.sparse + ) + + def get_config(self): + config = {"data_format": self.data_format} + base_config = super().get_config() + return {**base_config, **config} diff --git a/keras/src/layers/reshaping/flatten_test.py b/keras/src/layers/reshaping/flatten_test.py new file mode 100644 index 000000000000..4f8d283022f0 --- /dev/null +++ b/keras/src/layers/reshaping/flatten_test.py @@ -0,0 +1,134 @@ +import numpy as np +import pytest +from absl.testing import parameterized + +from conftest import skip_if_backend +from keras.src import backend +from keras.src import layers +from keras.src import models +from keras.src import ops +from keras.src import testing + + +class FlattenTest(testing.TestCase): + @parameterized.named_parameters( + [ + {"testcase_name": "dense", "sparse": False}, + {"testcase_name": "sparse", "sparse": True}, + ] + ) + @pytest.mark.requires_trainable_backend + def test_flatten(self, sparse): + if sparse and not backend.SUPPORTS_SPARSE_TENSORS: + pytest.skip("Backend does not support sparse tensors.") + + inputs = np.random.random((10, 3, 5, 5)).astype("float32") + # Make the ndarray relatively sparse + inputs = np.multiply(inputs, inputs >= 0.8) + expected_output_channels_last = ops.convert_to_tensor( + np.reshape(inputs, (-1, 5 * 5 * 3)) + ) + expected_output_channels_first = ops.convert_to_tensor( + np.reshape(np.transpose(inputs, (0, 2, 3, 1)), (-1, 5 * 5 * 3)) + ) + if sparse: + if backend.backend() == "tensorflow": + import tensorflow as tf + + dense_to_sparse = tf.sparse.from_dense + elif backend.backend() == "jax": + import jax.experimental.sparse as jax_sparse + + dense_to_sparse = jax_sparse.BCOO.fromdense + else: + self.fail( + f"Sparse is unsupported with backend {backend.backend()}" + ) + inputs = dense_to_sparse(inputs) + expected_output_channels_last = dense_to_sparse( + expected_output_channels_last + ) + expected_output_channels_first = dense_to_sparse( + expected_output_channels_first + ) + + # Test default data_format and channels_last + self.run_layer_test( + layers.Flatten, + init_kwargs={}, + input_data=inputs, + input_sparse=True, + expected_output=( + expected_output_channels_last + if backend.config.image_data_format() == "channels_last" + else expected_output_channels_first + ), + expected_output_sparse=sparse, + run_training_check=not sparse, + ) + self.run_layer_test( + layers.Flatten, + init_kwargs={"data_format": "channels_last"}, + input_data=inputs, + input_sparse=True, + expected_output=expected_output_channels_last, + expected_output_sparse=sparse, + run_training_check=not sparse, + ) + + # Test channels_first + self.run_layer_test( + layers.Flatten, + init_kwargs={"data_format": "channels_first"}, + input_data=inputs, + input_sparse=True, + expected_output=expected_output_channels_first, + expected_output_sparse=sparse, + run_training_check=not sparse, + ) + + @pytest.mark.requires_trainable_backend + def test_flatten_with_scalar_channels(self): + inputs = np.random.random((10,)).astype("float32") + expected_output = ops.convert_to_tensor(np.expand_dims(inputs, -1)) + + # Test default data_format and channels_last + self.run_layer_test( + layers.Flatten, + init_kwargs={}, + input_data=inputs, + expected_output=expected_output, + ) + self.run_layer_test( + layers.Flatten, + init_kwargs={"data_format": "channels_last"}, + input_data=inputs, + expected_output=expected_output, + ) + + # Test channels_first + self.run_layer_test( + layers.Flatten, + init_kwargs={"data_format": "channels_first"}, + input_data=inputs, + expected_output=expected_output, + ) + + def test_flatten_symbolic_with_dynamic_batch_size(self): + input_layer = layers.Input(batch_shape=(None, 2, 3)) + flattened = layers.Flatten()(input_layer) + self.assertEqual(flattened.shape, (None, 2 * 3)) + + def test_flatten_symbolic_with_dynamic_dimension(self): + input_layer = layers.Input(batch_shape=(5, 2, None)) + flattened = layers.Flatten()(input_layer) + self.assertEqual(flattened.shape, (5, None)) + + @skip_if_backend("openvino", "Dynamic dimensions not supported by OpenVino") + def test_flatten_with_dynamic_batch_size_and_dynamic_dimenstions(self): + def generator(): + yield (np.ones((3, 5, 7), dtype="float32"),) + yield (np.ones((2, 7, 5), dtype="float32"),) + + model = models.Sequential([layers.Flatten()]) + model.predict(generator()) diff --git a/keras/src/layers/reshaping/permute.py b/keras/src/layers/reshaping/permute.py new file mode 100644 index 000000000000..86580dfa0820 --- /dev/null +++ b/keras/src/layers/reshaping/permute.py @@ -0,0 +1,64 @@ +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.backend.common.keras_tensor import KerasTensor +from keras.src.layers.input_spec import InputSpec +from keras.src.layers.layer import Layer + + +@keras_export("keras.layers.Permute") +class Permute(Layer): + """Permutes the dimensions of the input according to a given pattern. + + Useful e.g. connecting RNNs and convnets. + + Args: + dims: Tuple of integers. Permutation pattern does not include the + batch dimension. Indexing starts at 1. + For instance, `(1, 3, 2)` permutes the second and third dimensions + of the input. + + Input shape: + Arbitrary. + + Output shape: + Same as the input shape, but with the dimensions re-ordered according + to the specified pattern. + + Example: + + >>> x = keras.Input(shape=(10, 64)) + >>> y = keras.layers.Permute((2, 1))(x) + >>> y.shape + (None, 64, 10) + """ + + def __init__(self, dims, **kwargs): + super().__init__(**kwargs) + self.dims = tuple(dims) + if sorted(dims) != list(range(1, len(dims) + 1)): + raise ValueError( + "Invalid permutation argument `dims` for Permute Layer. " + "The set of indices in `dims` must be consecutive and start " + f"from 1. Received dims={dims}" + ) + self.input_spec = InputSpec(ndim=len(self.dims) + 1) + + def compute_output_shape(self, input_shape): + output_shape = [input_shape[0]] + for dim in self.dims: + output_shape.append(input_shape[dim]) + return tuple(output_shape) + + def compute_output_spec(self, inputs): + output_shape = self.compute_output_shape(inputs.shape) + return KerasTensor( + shape=output_shape, dtype=inputs.dtype, sparse=inputs.sparse + ) + + def call(self, inputs): + return ops.transpose(inputs, axes=(0,) + self.dims) + + def get_config(self): + config = {"dims": self.dims} + base_config = super().get_config() + return {**base_config, **config} diff --git a/keras/src/layers/reshaping/permute_test.py b/keras/src/layers/reshaping/permute_test.py new file mode 100644 index 000000000000..324d7a0d5354 --- /dev/null +++ b/keras/src/layers/reshaping/permute_test.py @@ -0,0 +1,78 @@ +import numpy as np +import pytest +from absl.testing import parameterized + +from keras.src import backend +from keras.src import layers +from keras.src import ops +from keras.src import testing + + +class PermuteTest(testing.TestCase): + @parameterized.named_parameters( + [ + {"testcase_name": "dense", "sparse": False}, + {"testcase_name": "sparse", "sparse": True}, + ] + ) + @pytest.mark.requires_trainable_backend + def test_permute(self, sparse): + if sparse and not backend.SUPPORTS_SPARSE_TENSORS: + pytest.skip("Backend does not support sparse tensors.") + + inputs = np.random.random((10, 3, 5, 5)).astype("float32") + # Make the ndarray relatively sparse + inputs = np.multiply(inputs, inputs >= 0.8) + expected_output = ops.convert_to_tensor( + np.transpose(inputs, axes=(0, 3, 1, 2)) + ) + if sparse: + if backend.backend() == "tensorflow": + import tensorflow as tf + + inputs = tf.sparse.from_dense(inputs) + expected_output = tf.sparse.from_dense(expected_output) + elif backend.backend() == "jax": + import jax.experimental.sparse as jax_sparse + + inputs = jax_sparse.BCOO.fromdense(inputs) + expected_output = jax_sparse.BCOO.fromdense(expected_output) + else: + self.fail( + f"Backend {backend.backend()} does not support sparse" + ) + + self.run_layer_test( + layers.Permute, + init_kwargs={"dims": (3, 1, 2)}, + input_data=inputs, + input_sparse=sparse, + expected_output=expected_output, + expected_output_sparse=sparse, + run_training_check=not sparse, + ) + + def test_permute_with_dynamic_batch_size(self): + input_layer = layers.Input(batch_shape=(None, 3, 5)) + permuted = layers.Permute((2, 1))(input_layer) + self.assertEqual(permuted.shape, (None, 5, 3)) + + def test_permute_errors_on_invalid_starting_dims_index(self): + with self.assertRaisesRegex( + ValueError, r"Invalid permutation .*dims.*" + ): + self.run_layer_test( + layers.Permute, + init_kwargs={"dims": (0, 1, 2)}, + input_shape=(3, 2, 4), + ) + + def test_permute_errors_on_invalid_set_of_dims_indices(self): + with self.assertRaisesRegex( + ValueError, r"Invalid permutation .*dims.*" + ): + self.run_layer_test( + layers.Permute, + init_kwargs={"dims": (1, 4, 2)}, + input_shape=(3, 2, 4), + ) diff --git a/keras/src/layers/reshaping/repeat_vector.py b/keras/src/layers/reshaping/repeat_vector.py new file mode 100644 index 000000000000..d8914d10fce7 --- /dev/null +++ b/keras/src/layers/reshaping/repeat_vector.py @@ -0,0 +1,48 @@ +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.layers.input_spec import InputSpec +from keras.src.layers.layer import Layer + + +@keras_export("keras.layers.RepeatVector") +class RepeatVector(Layer): + """Repeats the input n times. + + Example: + + >>> x = keras.Input(shape=(32,)) + >>> y = keras.layers.RepeatVector(3)(x) + >>> y.shape + (None, 3, 32) + + Args: + n: Integer, repetition factor. + + Input shape: + 2D tensor with shape `(batch_size, features)`. + + Output shape: + 3D tensor with shape `(batch_size, n, features)`. + """ + + def __init__(self, n, **kwargs): + super().__init__(**kwargs) + self.n = n + if not isinstance(n, int): + raise TypeError( + f"Expected an integer value for `n`, got {type(n)}." + ) + self.input_spec = InputSpec(ndim=2) + + def compute_output_shape(self, input_shape): + return (input_shape[0], self.n, input_shape[1]) + + def call(self, inputs): + input_shape = ops.shape(inputs) + reshaped = ops.reshape(inputs, (input_shape[0], 1, input_shape[1])) + return ops.repeat(reshaped, self.n, axis=1) + + def get_config(self): + config = {"n": self.n} + base_config = super().get_config() + return {**base_config, **config} diff --git a/keras/src/layers/reshaping/repeat_vector_test.py b/keras/src/layers/reshaping/repeat_vector_test.py new file mode 100644 index 000000000000..3d1d1a59624a --- /dev/null +++ b/keras/src/layers/reshaping/repeat_vector_test.py @@ -0,0 +1,47 @@ +import numpy as np +import pytest + +from keras.src import layers +from keras.src import ops +from keras.src import testing + + +class FlattenTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_repeat_vector(self): + inputs = np.random.random((2, 5)).astype("float32") + expected_output = ops.convert_to_tensor( + np.repeat(np.reshape(inputs, (2, 1, 5)), 3, axis=1) + ) + self.run_layer_test( + layers.RepeatVector, + init_kwargs={"n": 3}, + input_data=inputs, + expected_output=expected_output, + ) + + def test_repeat_vector_with_dynamic_batch_size(self): + input_layer = layers.Input(batch_shape=(None, 5)) + repeated = layers.RepeatVector(n=3)(input_layer) + self.assertEqual(repeated.shape, (None, 3, 5)) + + def test_repeat_vector_with_dynamic_dimension(self): + input_layer = layers.Input(batch_shape=(2, None)) + repeated = layers.RepeatVector(n=3)(input_layer) + self.assertEqual(repeated.shape, (2, 3, None)) + + def test_repeat_vector_with_invalid_n(self): + with self.assertRaisesRegex( + TypeError, "Expected an integer value for `n`" + ): + layers.RepeatVector(n="3") + + with self.assertRaisesRegex( + TypeError, "Expected an integer value for `n`" + ): + layers.RepeatVector(n=3.5) + + with self.assertRaisesRegex( + TypeError, "Expected an integer value for `n`" + ): + layers.RepeatVector(n=[3]) diff --git a/keras/src/layers/reshaping/reshape.py b/keras/src/layers/reshaping/reshape.py new file mode 100644 index 000000000000..46cfb3ec507e --- /dev/null +++ b/keras/src/layers/reshaping/reshape.py @@ -0,0 +1,79 @@ +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.backend.common.keras_tensor import KerasTensor +from keras.src.layers.layer import Layer +from keras.src.ops import operation_utils + + +@keras_export("keras.layers.Reshape") +class Reshape(Layer): + """Layer that reshapes inputs into the given shape. + + Args: + target_shape: Target shape. Tuple of integers, does not include the + samples dimension (batch size). One element of the `target_shape` + can be -1 in which case the missing value is inferred from the + size of the array and remaining dimensions. + + Input shape: + Arbitrary, but required to be compatible with `target_shape`. + + Output shape: + `(batch_size, *target_shape)` + + Example: + + >>> x = keras.Input(shape=(12,)) + >>> y = keras.layers.Reshape((3, 4))(x) + >>> y.shape + (None, 3, 4) + + >>> # another example with shape inference using `-1` as dimension + >>> y = keras.layers.Reshape((-1, 2, 2))(x) + >>> y.shape + (None, 3, 2, 2) + """ + + def __init__(self, target_shape, **kwargs): + super().__init__(**kwargs) + target_shape = tuple(target_shape) + # test validity of target_shape + if target_shape.count(-1) > 1: + raise ValueError( + "The `target_shape` argument must not contain more than one " + f"`-1` value. Received: target_shape={target_shape}" + ) + self.target_shape = target_shape + self.built = True + + def compute_output_shape(self, input_shape): + return ( + input_shape[0], + *operation_utils.compute_reshape_output_shape( + input_shape[1:], self.target_shape, "target_shape" + ), + ) + + def compute_output_spec(self, inputs): + output_shape = self.compute_output_shape(inputs.shape) + return KerasTensor( + shape=output_shape, dtype=inputs.dtype, sparse=inputs.sparse + ) + + def call(self, inputs): + potentially_resolved_target_shape = ( + operation_utils.compute_reshape_output_shape( + tuple(inputs.shape)[1:], self.target_shape, "target_shape" + ) + ) + potentially_resolved_target_shape = tuple( + -1 if d is None else d for d in potentially_resolved_target_shape + ) + return ops.reshape( + inputs, (ops.shape(inputs)[0],) + potentially_resolved_target_shape + ) + + def get_config(self): + config = {"target_shape": self.target_shape} + base_config = super().get_config() + return {**base_config, **config} diff --git a/keras/src/layers/reshaping/reshape_test.py b/keras/src/layers/reshaping/reshape_test.py new file mode 100644 index 000000000000..823fb8fc672d --- /dev/null +++ b/keras/src/layers/reshaping/reshape_test.py @@ -0,0 +1,138 @@ +import pytest +from absl.testing import parameterized + +from keras.src import Sequential +from keras.src import backend +from keras.src import layers +from keras.src import ops +from keras.src import testing +from keras.src.backend.common.keras_tensor import KerasTensor + + +class ReshapeTest(testing.TestCase): + @parameterized.named_parameters( + [ + {"testcase_name": "dense", "sparse": False}, + {"testcase_name": "sparse", "sparse": True}, + ] + ) + @pytest.mark.requires_trainable_backend + def test_reshape(self, sparse): + if sparse and not backend.SUPPORTS_SPARSE_TENSORS: + pytest.skip("Backend does not support sparse tensors.") + + self.run_layer_test( + layers.Reshape, + init_kwargs={"target_shape": (8, 1)}, + input_shape=(3, 2, 4), + input_sparse=sparse, + expected_output_shape=(3, 8, 1), + expected_output_sparse=sparse, + run_training_check=not sparse, + ) + + self.run_layer_test( + layers.Reshape, + init_kwargs={"target_shape": (8,)}, + input_shape=(3, 2, 4), + input_sparse=sparse, + expected_output_shape=(3, 8), + expected_output_sparse=sparse, + run_training_check=not sparse, + ) + + self.run_layer_test( + layers.Reshape, + init_kwargs={"target_shape": (2, 4)}, + input_shape=(3, 8), + input_sparse=sparse, + expected_output_shape=(3, 2, 4), + expected_output_sparse=sparse, + run_training_check=not sparse, + ) + self.run_layer_test( + layers.Reshape, + init_kwargs={"target_shape": (-1, 1)}, + input_shape=(3, 2, 4), + input_sparse=sparse, + expected_output_shape=(3, 8, 1), + expected_output_sparse=sparse, + run_training_check=not sparse, + ) + + self.run_layer_test( + layers.Reshape, + init_kwargs={"target_shape": (1, -1)}, + input_shape=(3, 2, 4), + input_sparse=sparse, + expected_output_shape=(3, 1, 8), + expected_output_sparse=sparse, + run_training_check=not sparse, + ) + + self.run_layer_test( + layers.Reshape, + init_kwargs={"target_shape": (-1,)}, + input_shape=(3, 2, 4), + input_sparse=sparse, + expected_output_shape=(3, 8), + expected_output_sparse=sparse, + run_training_check=not sparse, + ) + + self.run_layer_test( + layers.Reshape, + init_kwargs={"target_shape": (2, -1)}, + input_shape=(3, 2, 4), + input_sparse=sparse, + expected_output_shape=(3, 2, 4), + expected_output_sparse=sparse, + run_training_check=not sparse, + ) + + def test_reshape_with_dynamic_batch_size(self): + input_layer = layers.Input(shape=(2, 4)) + reshaped = layers.Reshape((8,))(input_layer) + self.assertEqual(reshaped.shape, (None, 8)) + + def test_reshape_with_dynamic_batch_size_and_minus_one(self): + input = KerasTensor((None, 6, 4)) + layer = layers.Reshape((-1, 8)) + reshaped = backend.compute_output_spec(layer.__call__, input) + self.assertEqual(reshaped.shape, (None, 3, 8)) + + def test_reshape_layer_with_varying_input_size_and_minus_one(self): + layer = layers.Reshape((-1, 8)) + res = layer(ops.ones((1, 6, 4), dtype="float32")) + self.assertEqual(res.shape, (1, 3, 8)) + res = layer(ops.ones((1, 10, 4), dtype="float32")) + self.assertEqual(res.shape, (1, 5, 8)) + + def test_reshape_with_dynamic_dim_and_minus_one(self): + input = KerasTensor((4, 6, None, 3)) + layer = layers.Reshape((-1, 3)) + reshaped = backend.compute_output_spec(layer.__call__, input) + self.assertEqual(reshaped.shape, (4, None, 3)) + + def test_reshape_sets_static_shape(self): + input_layer = layers.Input(batch_shape=(2, None)) + reshaped = layers.Reshape((3, 5))(input_layer) + # Also make sure the batch dim is not lost after reshape. + self.assertEqual(reshaped.shape, (2, 3, 5)) + + @pytest.mark.requires_trainable_backend + def test_reshape_model_fit_with_varying_input_size_and_minus_one(self): + def generator(): + yield ( + ops.ones((1, 12, 2), dtype="float32"), + ops.zeros((1, 3, 8), dtype="float32"), + ) + yield ( + ops.ones((1, 20, 2), dtype="float32"), + ops.zeros((1, 5, 8), dtype="float32"), + ) + + layer = layers.Reshape((-1, 8)) + model = Sequential([layer]) + model.compile(loss="mean_squared_error") + model.fit(generator(), steps_per_epoch=2, epochs=1) diff --git a/keras/src/layers/reshaping/up_sampling1d.py b/keras/src/layers/reshaping/up_sampling1d.py new file mode 100644 index 000000000000..47a16b9824f4 --- /dev/null +++ b/keras/src/layers/reshaping/up_sampling1d.py @@ -0,0 +1,60 @@ +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.layers.input_spec import InputSpec +from keras.src.layers.layer import Layer + + +@keras_export("keras.layers.UpSampling1D") +class UpSampling1D(Layer): + """Upsampling layer for 1D inputs. + + Repeats each temporal step `size` times along the time axis. + + Example: + + >>> input_shape = (2, 2, 3) + >>> x = np.arange(np.prod(input_shape)).reshape(input_shape) + >>> x + [[[ 0 1 2] + [ 3 4 5]] + [[ 6 7 8] + [ 9 10 11]]] + >>> y = keras.layers.UpSampling1D(size=2)(x) + >>> y + [[[ 0. 1. 2.] + [ 0. 1. 2.] + [ 3. 4. 5.] + [ 3. 4. 5.]] + [[ 6. 7. 8.] + [ 6. 7. 8.] + [ 9. 10. 11.] + [ 9. 10. 11.]]] + + Args: + size: Integer. Upsampling factor. + + Input shape: + 3D tensor with shape: `(batch_size, steps, features)`. + + Output shape: + 3D tensor with shape: `(batch_size, upsampled_steps, features)`. + """ + + def __init__(self, size=2, **kwargs): + super().__init__(**kwargs) + self.size = int(size) + self.input_spec = InputSpec(ndim=3) + + def compute_output_shape(self, input_shape): + size = ( + self.size * input_shape[1] if input_shape[1] is not None else None + ) + return [input_shape[0], size, input_shape[2]] + + def call(self, inputs): + return ops.repeat(x=inputs, repeats=self.size, axis=1) + + def get_config(self): + config = {"size": self.size} + base_config = super().get_config() + return {**base_config, **config} diff --git a/keras/src/layers/reshaping/up_sampling1d_test.py b/keras/src/layers/reshaping/up_sampling1d_test.py new file mode 100644 index 000000000000..978401fd7157 --- /dev/null +++ b/keras/src/layers/reshaping/up_sampling1d_test.py @@ -0,0 +1,63 @@ +import numpy as np +import pytest + +from keras.src import layers +from keras.src import testing +from keras.src.backend.common.keras_tensor import KerasTensor + + +class UpSamplingTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_upsampling_1d(self): + self.run_layer_test( + layers.UpSampling1D, + init_kwargs={"size": 2}, + input_shape=(3, 5, 4), + expected_output_shape=(3, 10, 4), + expected_output_dtype="float32", + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=False, + ) + + def test_upsampling_1d_correctness(self): + self.assertAllClose( + layers.UpSampling1D(size=2)(np.arange(12).reshape((2, 2, 3))), + np.array( + [ + [ + [0.0, 1.0, 2.0], + [0.0, 1.0, 2.0], + [3.0, 4.0, 5.0], + [3.0, 4.0, 5.0], + ], + [ + [6.0, 7.0, 8.0], + [6.0, 7.0, 8.0], + [9.0, 10.0, 11.0], + [9.0, 10.0, 11.0], + ], + ] + ), + ) + + def test_upsampling_1d_correctness_with_ones(self): + self.assertAllClose( + layers.UpSampling1D(size=3)(np.ones((2, 1, 5))), np.ones((2, 3, 5)) + ) + + def test_upsampling_1d_with_dynamic_batch_size(self): + x = KerasTensor([None, 2, 3]) + self.assertEqual(layers.UpSampling1D(size=2)(x).shape, (None, 4, 3)) + self.assertEqual(layers.UpSampling1D(size=4)(x).shape, (None, 8, 3)) + + def test_upsampling_1d_with_dynamic_shape(self): + y = KerasTensor([2, None, 3]) + self.assertEqual(layers.UpSampling1D(size=2)(y).shape, (2, None, 3)) + self.assertEqual(layers.UpSampling1D(size=4)(y).shape, (2, None, 3)) + + z = KerasTensor([2, 3, None]) + self.assertEqual(layers.UpSampling1D(size=2)(z).shape, (2, 6, None)) + self.assertEqual(layers.UpSampling1D(size=4)(z).shape, (2, 12, None)) diff --git a/keras/src/layers/reshaping/up_sampling2d.py b/keras/src/layers/reshaping/up_sampling2d.py new file mode 100644 index 000000000000..769f1cd7c003 --- /dev/null +++ b/keras/src/layers/reshaping/up_sampling2d.py @@ -0,0 +1,175 @@ +from keras.src import backend +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.layers.input_spec import InputSpec +from keras.src.layers.layer import Layer +from keras.src.utils import argument_validation + + +@keras_export("keras.layers.UpSampling2D") +class UpSampling2D(Layer): + """Upsampling layer for 2D inputs. + + The implementation uses interpolative resizing, given the resize method + (specified by the `interpolation` argument). Use `interpolation=nearest` + to repeat the rows and columns of the data. + + Example: + + >>> input_shape = (2, 2, 1, 3) + >>> x = np.arange(np.prod(input_shape)).reshape(input_shape) + >>> print(x) + [[[[ 0 1 2]] + [[ 3 4 5]]] + [[[ 6 7 8]] + [[ 9 10 11]]]] + >>> y = keras.layers.UpSampling2D(size=(1, 2))(x) + >>> print(y) + [[[[ 0 1 2] + [ 0 1 2]] + [[ 3 4 5] + [ 3 4 5]]] + [[[ 6 7 8] + [ 6 7 8]] + [[ 9 10 11] + [ 9 10 11]]]] + + Args: + size: Int, or tuple of 2 integers. + The upsampling factors for rows and columns. + data_format: A string, + one of `"channels_last"` (default) or `"channels_first"`. + The ordering of the dimensions in the inputs. + `"channels_last"` corresponds to inputs with shape + `(batch_size, height, width, channels)` while `"channels_first"` + corresponds to inputs with shape + `(batch_size, channels, height, width)`. + When unspecified, uses + `image_data_format` value found in your Keras config file at + `~/.keras/keras.json` (if exists) else `"channels_last"`. + Defaults to `"channels_last"`. + interpolation: A string, one of `"bicubic"`, `"bilinear"`, `"lanczos3"`, + `"lanczos5"`, `"nearest"`. + + Input shape: + 4D tensor with shape: + - If `data_format` is `"channels_last"`: + `(batch_size, rows, cols, channels)` + - If `data_format` is `"channels_first"`: + `(batch_size, channels, rows, cols)` + + Output shape: + 4D tensor with shape: + - If `data_format` is `"channels_last"`: + `(batch_size, upsampled_rows, upsampled_cols, channels)` + - If `data_format` is `"channels_first"`: + `(batch_size, channels, upsampled_rows, upsampled_cols)` + """ + + def __init__( + self, size=(2, 2), data_format=None, interpolation="nearest", **kwargs + ): + super().__init__(**kwargs) + self.data_format = backend.standardize_data_format(data_format) + self.size = argument_validation.standardize_tuple(size, 2, "size") + self.interpolation = interpolation.lower() + self.input_spec = InputSpec(ndim=4) + + def compute_output_shape(self, input_shape): + if self.data_format == "channels_first": + height = ( + self.size[0] * input_shape[2] + if input_shape[2] is not None + else None + ) + width = ( + self.size[1] * input_shape[3] + if input_shape[3] is not None + else None + ) + return (input_shape[0], input_shape[1], height, width) + else: + height = ( + self.size[0] * input_shape[1] + if input_shape[1] is not None + else None + ) + width = ( + self.size[1] * input_shape[2] + if input_shape[2] is not None + else None + ) + return (input_shape[0], height, width, input_shape[3]) + + def call(self, inputs): + return self._resize_images( + inputs, + self.size[0], + self.size[1], + self.data_format, + interpolation=self.interpolation, + ) + + def get_config(self): + config = { + "size": self.size, + "data_format": self.data_format, + "interpolation": self.interpolation, + } + base_config = super().get_config() + return {**base_config, **config} + + def _resize_images( + self, + x, + height_factor, + width_factor, + data_format, + interpolation="nearest", + ): + """Resizes the images contained in a 4D tensor. + + Args: + x: Tensor or variable to resize. + height_factor: Positive integer. + width_factor: Positive integer. + data_format: One of `"channels_first"`, `"channels_last"`. + interpolation: A string, one of `"bicubic"`, `"bilinear"`, + `"lanczos3"`, `"lanczos5"`, or `"nearest"`. + + Returns: + A tensor. + """ + if data_format not in {"channels_last", "channels_first"}: + raise ValueError(f"Invalid `data_format` argument: {data_format}") + + if data_format == "channels_first": + x = ops.transpose(x, [0, 2, 3, 1]) + # https://github.com/keras-team/keras/issues/294 + # Use `ops.repeat` for `nearest` interpolation to enable XLA + if interpolation == "nearest": + x = ops.repeat(x, height_factor, axis=1) + x = ops.repeat(x, width_factor, axis=2) + else: + # multiply the height and width factor on each dim + # by hand (versus using element-wise multiplication + # by np.array([height_factor, width_factor]) then + # list-ifying the tensor by calling `.tolist()`) + # since when running under torchdynamo, `new_shape` + # will be traced as a symbolic variable (specifically + # a `FakeTensor`) which does not have a `tolist()` method. + shape = ops.shape(x) + new_shape = ( + shape[1] * height_factor, + shape[2] * width_factor, + ) + x = ops.image.resize( + x, + new_shape, + data_format="channels_last", + interpolation=interpolation, + ) + if data_format == "channels_first": + x = ops.transpose(x, [0, 3, 1, 2]) + + return x diff --git a/keras/src/layers/reshaping/up_sampling2d_test.py b/keras/src/layers/reshaping/up_sampling2d_test.py new file mode 100644 index 000000000000..6757e4fb615c --- /dev/null +++ b/keras/src/layers/reshaping/up_sampling2d_test.py @@ -0,0 +1,158 @@ +# flake8: noqa +import numpy as np +import pytest +from absl.testing import parameterized + +from keras.src import backend +from keras.src import layers +from keras.src import testing +from keras.backend import set_image_data_format + + +class UpSampling2dTest(testing.TestCase): + @classmethod + def setUpClass(cls): + cls.original_image_data_format = backend.image_data_format() + + @classmethod + def tearDownClass(cls): + backend.set_image_data_format(cls.original_image_data_format) + + @parameterized.product( + data_format=["channels_first", "channels_last"], + length_row=[2], + length_col=[2, 3], + ) + @pytest.mark.requires_trainable_backend + def test_upsampling_2d(self, data_format, length_row, length_col): + num_samples = 2 + stack_size = 2 + input_num_row = 11 + input_num_col = 12 + + if data_format == "channels_first": + inputs = np.random.rand( + num_samples, stack_size, input_num_row, input_num_col + ) + else: + inputs = np.random.rand( + num_samples, input_num_row, input_num_col, stack_size + ) + + # basic test + self.run_layer_test( + layers.UpSampling2D, + init_kwargs={"size": (2, 2), "data_format": data_format}, + input_shape=inputs.shape, + ) + + layer = layers.UpSampling2D( + size=(length_row, length_col), + data_format=data_format, + ) + layer.build(inputs.shape) + np_output = layer(inputs=backend.Variable(inputs)) + if data_format == "channels_first": + assert np_output.shape[2] == length_row * input_num_row + assert np_output.shape[3] == length_col * input_num_col + else: + assert np_output.shape[1] == length_row * input_num_row + assert np_output.shape[2] == length_col * input_num_col + + # compare with numpy + if data_format == "channels_first": + expected_out = np.repeat(inputs, length_row, axis=2) + expected_out = np.repeat(expected_out, length_col, axis=3) + else: + expected_out = np.repeat(inputs, length_row, axis=1) + expected_out = np.repeat(expected_out, length_col, axis=2) + + self.assertAllClose(np_output, expected_out) + + @parameterized.product( + data_format=["channels_first", "channels_last"], + use_set_image_data_format=[True, False], + length_row=[2], + length_col=[2, 3], + ) + @pytest.mark.requires_trainable_backend + def test_upsampling_2d_bilinear( + self, data_format, use_set_image_data_format, length_row, length_col + ): + num_samples = 2 + stack_size = 2 + input_num_row = 11 + input_num_col = 12 + + if use_set_image_data_format: + set_image_data_format(data_format) + + if data_format == "channels_first": + inputs = np.random.rand( + num_samples, stack_size, input_num_row, input_num_col + ) + else: + inputs = np.random.rand( + num_samples, input_num_row, input_num_col, stack_size + ) + + self.run_layer_test( + layers.UpSampling2D, + init_kwargs={ + "size": (2, 2), + "data_format": data_format, + "interpolation": "bilinear", + }, + input_shape=inputs.shape, + ) + + layer = layers.UpSampling2D( + size=(length_row, length_col), + data_format=data_format, + interpolation="bilinear", + ) + layer.build(inputs.shape) + np_output = layer(inputs=backend.Variable(inputs)) + if data_format == "channels_first": + self.assertEqual(np_output.shape[2], length_row * input_num_row) + self.assertEqual(np_output.shape[3], length_col * input_num_col) + else: + self.assertEqual(np_output.shape[1], length_row * input_num_row) + self.assertEqual(np_output.shape[2], length_col * input_num_col) + + def test_upsampling_2d_correctness(self): + input_shape = (2, 2, 1, 3) + x = np.arange(np.prod(input_shape)).reshape(input_shape) + # fmt: off + expected_output = np.array( + [[[[ 0., 1., 2.], + [ 0., 1., 2.]], + [[ 3., 4., 5.], + [ 3., 4., 5.]]], + [[[ 6., 7., 8.], + [ 6., 7., 8.]], + [[ 9., 10., 11.], + [ 9., 10., 11.]]]] + ) + # fmt: on + if backend.config.image_data_format() == "channels_first": + expected_output = expected_output.transpose((0, 3, 1, 2)) + x = x.transpose((0, 3, 1, 2)) + self.assertAllClose( + layers.UpSampling2D(size=(1, 2))(x), expected_output + ) + + def test_upsampling_2d_various_interpolation_methods(self): + input_shape = (2, 2, 1, 3) + x = np.arange(np.prod(input_shape)).reshape(input_shape) + for interpolation in ["nearest", "bilinear", "bicubic"]: + layers.UpSampling2D(size=(1, 2), interpolation=interpolation)(x) + + @pytest.mark.skipif( + backend.backend() == "torch", reason="Torch does not support lanczos." + ) + def test_upsampling_2d_lanczos_interpolation_methods(self): + input_shape = (2, 2, 1, 3) + x = np.arange(np.prod(input_shape)).reshape(input_shape) + for interpolation in ["lanczos3", "lanczos5"]: + layers.UpSampling2D(size=(1, 2), interpolation=interpolation)(x) diff --git a/keras/src/layers/reshaping/up_sampling3d.py b/keras/src/layers/reshaping/up_sampling3d.py new file mode 100644 index 000000000000..3b642e48ef6a --- /dev/null +++ b/keras/src/layers/reshaping/up_sampling3d.py @@ -0,0 +1,134 @@ +from keras.src import backend +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.layers.input_spec import InputSpec +from keras.src.layers.layer import Layer +from keras.src.utils import argument_validation + + +@keras_export("keras.layers.UpSampling3D") +class UpSampling3D(Layer): + """Upsampling layer for 3D inputs. + + Repeats the 1st, 2nd and 3rd dimensions + of the data by `size[0]`, `size[1]` and `size[2]` respectively. + + Example: + + >>> input_shape = (2, 1, 2, 1, 3) + >>> x = np.ones(input_shape) + >>> y = keras.layers.UpSampling3D(size=(2, 2, 2))(x) + >>> y.shape + (2, 2, 4, 2, 3) + + Args: + size: Int, or tuple of 3 integers. + The upsampling factors for dim1, dim2 and dim3. + data_format: A string, + one of `"channels_last"` (default) or `"channels_first"`. + The ordering of the dimensions in the inputs. + `"channels_last"` corresponds to inputs with shape + `(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)`. + When unspecified, uses + `image_data_format` value found in your Keras config file at + `~/.keras/keras.json` (if exists) else `"channels_last"`. + Defaults to `"channels_last"`. + + Input shape: + 5D tensor with shape: + - If `data_format` is `"channels_last"`: + `(batch_size, dim1, dim2, dim3, channels)` + - If `data_format` is `"channels_first"`: + `(batch_size, channels, dim1, dim2, dim3)` + + Output shape: + 5D tensor with shape: + - If `data_format` is `"channels_last"`: + `(batch_size, upsampled_dim1, upsampled_dim2, upsampled_dim3, + channels)` + - If `data_format` is `"channels_first"`: + `(batch_size, channels, upsampled_dim1, upsampled_dim2, + upsampled_dim3)` + """ + + def __init__(self, size=(2, 2, 2), data_format=None, **kwargs): + super().__init__(**kwargs) + self.data_format = backend.standardize_data_format(data_format) + self.size = argument_validation.standardize_tuple(size, 3, "size") + self.input_spec = InputSpec(ndim=5) + + def compute_output_shape(self, input_shape): + if self.data_format == "channels_first": + dim1 = ( + self.size[0] * input_shape[2] + if input_shape[2] is not None + else None + ) + dim2 = ( + self.size[1] * input_shape[3] + if input_shape[3] is not None + else None + ) + dim3 = ( + self.size[2] * input_shape[4] + if input_shape[4] is not None + else None + ) + return (input_shape[0], input_shape[1], dim1, dim2, dim3) + else: + dim1 = ( + self.size[0] * input_shape[1] + if input_shape[1] is not None + else None + ) + dim2 = ( + self.size[1] * input_shape[2] + if input_shape[2] is not None + else None + ) + dim3 = ( + self.size[2] * input_shape[3] + if input_shape[3] is not None + else None + ) + return (input_shape[0], dim1, dim2, dim3, input_shape[4]) + + def call(self, inputs): + return self._resize_volumes( + inputs, self.size[0], self.size[1], self.size[2], self.data_format + ) + + def get_config(self): + config = {"size": self.size, "data_format": self.data_format} + base_config = super().get_config() + return {**base_config, **config} + + def _resize_volumes( + self, x, depth_factor, height_factor, width_factor, data_format + ): + """Resizes the volume contained in a 5D tensor. + + Args: + x: Tensor or variable to resize. + depth_factor: Positive integer. + height_factor: Positive integer. + width_factor: Positive integer. + data_format: One of `"channels_first"`, `"channels_last"`. + + Returns: + Resized tensor. + """ + if data_format == "channels_first": + output = ops.repeat(x, depth_factor, axis=2) + output = ops.repeat(output, height_factor, axis=3) + output = ops.repeat(output, width_factor, axis=4) + return output + elif data_format == "channels_last": + output = ops.repeat(x, depth_factor, axis=1) + output = ops.repeat(output, height_factor, axis=2) + output = ops.repeat(output, width_factor, axis=3) + return output + else: + raise ValueError(f"Invalid data_format: {data_format}") diff --git a/keras/src/layers/reshaping/up_sampling3d_test.py b/keras/src/layers/reshaping/up_sampling3d_test.py new file mode 100644 index 000000000000..3d4a763032b8 --- /dev/null +++ b/keras/src/layers/reshaping/up_sampling3d_test.py @@ -0,0 +1,130 @@ +import numpy as np +import pytest +from absl.testing import parameterized + +from keras.src import backend +from keras.src import layers +from keras.src import testing + + +class UpSampling3dTest(testing.TestCase): + @parameterized.product( + data_format=["channels_first", "channels_last"], + length_dim1=[2, 3], + length_dim2=[2], + length_dim3=[3], + ) + @pytest.mark.requires_trainable_backend + def test_upsampling_3d( + self, data_format, length_dim1, length_dim2, length_dim3 + ): + num_samples = 2 + stack_size = 2 + input_len_dim1 = 10 + input_len_dim2 = 11 + input_len_dim3 = 12 + + if data_format == "channels_first": + inputs = np.random.rand( + num_samples, + stack_size, + input_len_dim1, + input_len_dim2, + input_len_dim3, + ) + else: + inputs = np.random.rand( + num_samples, + input_len_dim1, + input_len_dim2, + input_len_dim3, + stack_size, + ) + + # basic test + if data_format == "channels_first": + expected_output_shape = (2, 2, 20, 22, 24) + else: + expected_output_shape = (2, 20, 22, 24, 2) + + self.run_layer_test( + layers.UpSampling3D, + init_kwargs={"size": (2, 2, 2), "data_format": data_format}, + input_shape=inputs.shape, + expected_output_shape=expected_output_shape, + expected_output_dtype="float32", + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=False, + ) + + layer = layers.UpSampling3D( + size=(length_dim1, length_dim2, length_dim3), + data_format=data_format, + ) + layer.build(inputs.shape) + np_output = layer(inputs=backend.Variable(inputs)) + if data_format == "channels_first": + assert np_output.shape[2] == length_dim1 * input_len_dim1 + assert np_output.shape[3] == length_dim2 * input_len_dim2 + assert np_output.shape[4] == length_dim3 * input_len_dim3 + else: # tf + assert np_output.shape[1] == length_dim1 * input_len_dim1 + assert np_output.shape[2] == length_dim2 * input_len_dim2 + assert np_output.shape[3] == length_dim3 * input_len_dim3 + + # compare with numpy + if data_format == "channels_first": + expected_out = np.repeat(inputs, length_dim1, axis=2) + expected_out = np.repeat(expected_out, length_dim2, axis=3) + expected_out = np.repeat(expected_out, length_dim3, axis=4) + else: # tf + expected_out = np.repeat(inputs, length_dim1, axis=1) + expected_out = np.repeat(expected_out, length_dim2, axis=2) + expected_out = np.repeat(expected_out, length_dim3, axis=3) + + self.assertAllClose(np_output, expected_out) + + def test_upsampling_3d_correctness(self): + input_shape = (2, 1, 2, 1, 3) + x = np.arange(np.prod(input_shape)).reshape(input_shape) + expected_output = np.array( + [ + [ + [ + [[0.0, 1.0, 2.0], [0.0, 1.0, 2.0]], + [[0.0, 1.0, 2.0], [0.0, 1.0, 2.0]], + [[3.0, 4.0, 5.0], [3.0, 4.0, 5.0]], + [[3.0, 4.0, 5.0], [3.0, 4.0, 5.0]], + ], + [ + [[0.0, 1.0, 2.0], [0.0, 1.0, 2.0]], + [[0.0, 1.0, 2.0], [0.0, 1.0, 2.0]], + [[3.0, 4.0, 5.0], [3.0, 4.0, 5.0]], + [[3.0, 4.0, 5.0], [3.0, 4.0, 5.0]], + ], + ], + [ + [ + [[6.0, 7.0, 8.0], [6.0, 7.0, 8.0]], + [[6.0, 7.0, 8.0], [6.0, 7.0, 8.0]], + [[9.0, 10.0, 11.0], [9.0, 10.0, 11.0]], + [[9.0, 10.0, 11.0], [9.0, 10.0, 11.0]], + ], + [ + [[6.0, 7.0, 8.0], [6.0, 7.0, 8.0]], + [[6.0, 7.0, 8.0], [6.0, 7.0, 8.0]], + [[9.0, 10.0, 11.0], [9.0, 10.0, 11.0]], + [[9.0, 10.0, 11.0], [9.0, 10.0, 11.0]], + ], + ], + ] + ) + if backend.config.image_data_format() == "channels_first": + expected_output = expected_output.transpose((0, 4, 1, 2, 3)) + x = x.transpose((0, 4, 1, 2, 3)) + self.assertAllClose( + layers.UpSampling3D(size=(2, 2, 2))(x), expected_output + ) diff --git a/keras/src/layers/reshaping/zero_padding1d.py b/keras/src/layers/reshaping/zero_padding1d.py new file mode 100644 index 000000000000..c9e50d8897b3 --- /dev/null +++ b/keras/src/layers/reshaping/zero_padding1d.py @@ -0,0 +1,93 @@ +from keras.src import backend +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.layers.input_spec import InputSpec +from keras.src.layers.layer import Layer +from keras.src.utils import argument_validation + + +@keras_export("keras.layers.ZeroPadding1D") +class ZeroPadding1D(Layer): + """Zero-padding layer for 1D input (e.g. temporal sequence). + + Example: + + >>> input_shape = (2, 2, 3) + >>> x = np.arange(np.prod(input_shape)).reshape(input_shape) + >>> x + [[[ 0 1 2] + [ 3 4 5]] + [[ 6 7 8] + [ 9 10 11]]] + >>> y = keras.layers.ZeroPadding1D(padding=2)(x) + >>> y + [[[ 0 0 0] + [ 0 0 0] + [ 0 1 2] + [ 3 4 5] + [ 0 0 0] + [ 0 0 0]] + [[ 0 0 0] + [ 0 0 0] + [ 6 7 8] + [ 9 10 11] + [ 0 0 0] + [ 0 0 0]]] + + Args: + padding: Int, or tuple of int (length 2), or dictionary. + - If int: how many zeros to add at the beginning and end of + the padding dimension (axis 1). + - If tuple of 2 ints: how many zeros to add at the beginning and the + end of the padding dimension (`(left_pad, right_pad)`). + data_format: A string, one of `"channels_last"` (default) or + `"channels_first"`. The ordering of the dimensions in the inputs. + `"channels_last"` corresponds to inputs with shape + `(batch_size, axis_to_pad, channels)` while `"channels_first"` + corresponds to inputs with shape + `(batch_size, channels, axis_to_pad)`. + When unspecified, uses `image_data_format` value found in your Keras + config file at `~/.keras/keras.json` (if exists). Defaults to + `"channels_last"`. + + Input shape: + 3D tensor with shape: + - If `data_format` is `"channels_last"`: + `(batch_size, axis_to_pad, features)` + - If `data_format` is `"channels_first"`: + `(batch_size, features, axis_to_pad)` + + Output shape: + 3D tensor with shape: + - If `data_format` is `"channels_last"`: + `(batch_size, padded_axis, features)` + - If `data_format` is `"channels_first"`: + `(batch_size, features, padded_axis)` + """ + + def __init__(self, padding=1, data_format=None, **kwargs): + super().__init__(**kwargs) + self.data_format = backend.standardize_data_format(data_format) + self.padding = argument_validation.standardize_tuple( + padding, 2, "padding", allow_zero=True + ) + self.input_spec = InputSpec(ndim=3) + + def compute_output_shape(self, input_shape): + output_shape = list(input_shape) + padding_dim = 2 if self.data_format == "channels_first" else 1 + if output_shape[padding_dim] is not None: + output_shape[padding_dim] += self.padding[0] + self.padding[1] + return tuple(output_shape) + + def call(self, inputs): + if self.data_format == "channels_first": + all_dims_padding = ((0, 0), (0, 0), self.padding) + else: + all_dims_padding = ((0, 0), self.padding, (0, 0)) + return ops.pad(inputs, all_dims_padding) + + def get_config(self): + config = {"padding": self.padding, "data_format": self.data_format} + base_config = super().get_config() + return {**base_config, **config} diff --git a/keras/src/layers/reshaping/zero_padding1d_test.py b/keras/src/layers/reshaping/zero_padding1d_test.py new file mode 100644 index 000000000000..fd14ae085be7 --- /dev/null +++ b/keras/src/layers/reshaping/zero_padding1d_test.py @@ -0,0 +1,74 @@ +import numpy as np +from absl.testing import parameterized + +from keras.src import dtype_policies +from keras.src import layers +from keras.src import testing + + +class ZeroPadding1DTest(testing.TestCase): + @parameterized.parameters( + {"data_format": "channels_first"}, + {"data_format": "channels_last"}, + ) + def test_zero_padding_1d(self, data_format): + inputs = np.random.rand(1, 2, 3) + outputs = layers.ZeroPadding1D(padding=(1, 2), data_format=data_format)( + inputs + ) + if data_format == "channels_last": + for index in [0, -1, -2]: + self.assertAllClose(outputs[:, index, :], 0.0) + self.assertAllClose(outputs[:, 1:-2, :], inputs) + else: + for index in [0, -1, -2]: + self.assertAllClose(outputs[:, :, index], 0.0) + self.assertAllClose(outputs[:, :, 1:-2], inputs) + + @parameterized.named_parameters(("one_tuple", (2, 2)), ("one_int", 2)) + def test_zero_padding_1d_with_same_padding(self, padding): + inputs = np.random.rand(1, 2, 3) + outputs = layers.ZeroPadding1D( + padding=padding, data_format="channels_last" + )(inputs) + + for index in [0, 1, -1, -2]: + self.assertAllClose(outputs[:, index, :], 0.0) + self.assertAllClose(outputs[:, 2:-2, :], inputs) + + def test_zero_padding_1d_with_dynamic_spatial_dim(self): + input_layer = layers.Input(batch_shape=(1, None, 3)) + padded = layers.ZeroPadding1D((1, 2), data_format="channels_last")( + input_layer + ) + self.assertEqual(padded.shape, (1, None, 3)) + + input_layer = layers.Input(batch_shape=(1, 2, 3)) + padded = layers.ZeroPadding1D((1, 2), data_format="channels_last")( + input_layer + ) + self.assertEqual(padded.shape, (1, 5, 3)) + + @parameterized.parameters( + {"padding": (1,)}, + {"padding": (1, 2, 3)}, + {"padding": "1"}, + ) + def test_zero_padding_1d_errors_if_padding_argument_invalid(self, padding): + with self.assertRaises(ValueError): + layers.ZeroPadding1D(padding) + + @parameterized.parameters( + {"data_format": "channels_first"}, + {"data_format": "channels_last"}, + ) + def test_zero_padding_1d_get_config(self, data_format): + layer = layers.ZeroPadding1D(padding=(1, 2), data_format=data_format) + expected_config = { + "dtype": dtype_policies.serialize(layer.dtype_policy), + "data_format": data_format, + "name": layer.name, + "padding": (1, 2), + "trainable": layer.trainable, + } + self.assertEqual(layer.get_config(), expected_config) diff --git a/keras/src/layers/reshaping/zero_padding2d.py b/keras/src/layers/reshaping/zero_padding2d.py new file mode 100644 index 000000000000..e5d88d16d76d --- /dev/null +++ b/keras/src/layers/reshaping/zero_padding2d.py @@ -0,0 +1,119 @@ +from keras.src import backend +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.layers.input_spec import InputSpec +from keras.src.layers.layer import Layer +from keras.src.utils import argument_validation + + +@keras_export("keras.layers.ZeroPadding2D") +class ZeroPadding2D(Layer): + """Zero-padding layer for 2D input (e.g. picture). + + This layer can add rows and columns of zeros at the top, bottom, left and + right side of an image tensor. + + Example: + + >>> input_shape = (1, 1, 2, 2) + >>> x = np.arange(np.prod(input_shape)).reshape(input_shape) + >>> x + [[[[0 1] + [2 3]]]] + >>> y = keras.layers.ZeroPadding2D(padding=1)(x) + >>> y + [[[[0 0] + [0 0] + [0 0] + [0 0]] + [[0 0] + [0 1] + [2 3] + [0 0]] + [[0 0] + [0 0] + [0 0] + [0 0]]]] + + Args: + padding: Int, or tuple of 2 ints, or tuple of 2 tuples of 2 ints. + - If int: the same symmetric padding is applied to height and width. + - If tuple of 2 ints: interpreted as two different symmetric padding + values for height and width: + `(symmetric_height_pad, symmetric_width_pad)`. + - If tuple of 2 tuples of 2 ints: interpreted as + `((top_pad, bottom_pad), (left_pad, right_pad))`. + data_format: A string, one of `"channels_last"` (default) or + `"channels_first"`. The ordering of the dimensions in the inputs. + `"channels_last"` corresponds to inputs with shape + `(batch_size, height, width, channels)` while `"channels_first"` + corresponds to inputs with shape + `(batch_size, channels, height, width)`. + When unspecified, uses `image_data_format` value found in your Keras + config file at `~/.keras/keras.json` (if exists). Defaults to + `"channels_last"`. + + Input shape: + 4D tensor with shape: + - If `data_format` is `"channels_last"`: + `(batch_size, height, width, channels)` + - If `data_format` is `"channels_first"`: + `(batch_size, channels, height, width)` + + Output shape: + 4D tensor with shape: + - If `data_format` is `"channels_last"`: + `(batch_size, padded_height, padded_width, channels)` + - If `data_format` is `"channels_first"`: + `(batch_size, channels, padded_height, padded_width)` + """ + + def __init__(self, padding=(1, 1), data_format=None, **kwargs): + super().__init__(**kwargs) + self.data_format = backend.standardize_data_format(data_format) + if isinstance(padding, int): + self.padding = ((padding, padding), (padding, padding)) + elif hasattr(padding, "__len__"): + if len(padding) != 2: + raise ValueError( + "`padding` should have two elements. " + f"Received: padding={padding}." + ) + height_padding = argument_validation.standardize_tuple( + padding[0], 2, "1st entry of padding", allow_zero=True + ) + width_padding = argument_validation.standardize_tuple( + padding[1], 2, "2nd entry of padding", allow_zero=True + ) + self.padding = (height_padding, width_padding) + else: + raise ValueError( + "`padding` should be either an int, a tuple of 2 ints " + "(symmetric_height_crop, symmetric_width_crop), " + "or a tuple of 2 tuples of 2 ints " + "((top_crop, bottom_crop), (left_crop, right_crop)). " + f"Received: padding={padding}." + ) + self.input_spec = InputSpec(ndim=4) + + def compute_output_shape(self, input_shape): + output_shape = list(input_shape) + spatial_dims_offset = 2 if self.data_format == "channels_first" else 1 + for index in range(0, 2): + if output_shape[index + spatial_dims_offset] is not None: + output_shape[index + spatial_dims_offset] += ( + self.padding[index][0] + self.padding[index][1] + ) + return tuple(output_shape) + + def call(self, inputs): + if self.data_format == "channels_first": + all_dims_padding = ((0, 0), (0, 0), *self.padding) + else: + all_dims_padding = ((0, 0), *self.padding, (0, 0)) + return ops.pad(inputs, all_dims_padding) + + def get_config(self): + config = {"padding": self.padding, "data_format": self.data_format} + base_config = super().get_config() + return {**base_config, **config} diff --git a/keras/src/layers/reshaping/zero_padding2d_test.py b/keras/src/layers/reshaping/zero_padding2d_test.py new file mode 100644 index 000000000000..c373f50a9b1a --- /dev/null +++ b/keras/src/layers/reshaping/zero_padding2d_test.py @@ -0,0 +1,98 @@ +import numpy as np +from absl.testing import parameterized + +from keras.src import backend +from keras.src import dtype_policies +from keras.src import layers +from keras.src import testing + + +class ZeroPadding2DTest(testing.TestCase): + @parameterized.parameters( + {"data_format": "channels_first"}, + {"data_format": "channels_last"}, + ) + def test_zero_padding_2d(self, data_format): + inputs = np.random.rand(1, 2, 3, 4) + outputs = layers.ZeroPadding2D( + padding=((1, 2), (3, 4)), data_format=data_format + )(inputs) + + if data_format == "channels_first": + for index in [0, -1, -2]: + self.assertAllClose(outputs[:, :, index, :], 0.0) + for index in [0, 1, 2, -1, -2, -3, -4]: + self.assertAllClose(outputs[:, :, :, index], 0.0) + self.assertAllClose(outputs[:, :, 1:-2, 3:-4], inputs) + else: + for index in [0, -1, -2]: + self.assertAllClose(outputs[:, index, :, :], 0.0) + for index in [0, 1, 2, -1, -2, -3, -4]: + self.assertAllClose(outputs[:, :, index, :], 0.0) + self.assertAllClose(outputs[:, 1:-2, 3:-4, :], inputs) + + @parameterized.product( + ( + {"padding": ((2, 2), (2, 2))}, # 2 tuples + {"padding": (2, 2)}, # 1 tuple + {"padding": 2}, # 1 int + ), + ( + {"data_format": "channels_first"}, + {"data_format": "channels_last"}, + ), + ) + def test_zero_padding_2d_with_same_padding(self, padding, data_format): + inputs = np.random.rand(1, 2, 3, 4) + outputs = layers.ZeroPadding2D( + padding=padding, data_format=data_format + )(inputs) + + if data_format == "channels_first": + for index in [0, 1, -1, -2]: + self.assertAllClose(outputs[:, :, index, :], 0.0) + self.assertAllClose(outputs[:, :, :, index], 0.0) + self.assertAllClose(outputs[:, :, 2:-2, 2:-2], inputs) + else: + for index in [0, 1, -1, -2]: + self.assertAllClose(outputs[:, index, :, :], 0.0) + self.assertAllClose(outputs[:, :, index, :], 0.0) + self.assertAllClose(outputs[:, 2:-2, 2:-2, :], inputs) + + def test_zero_padding_2d_with_dynamic_spatial_dim(self): + if backend.config.image_data_format() == "channels_last": + input_layer = layers.Input(batch_shape=(1, 2, None, 4)) + else: + input_layer = layers.Input(batch_shape=(1, 4, 2, None)) + padded = layers.ZeroPadding2D(((1, 2), (3, 4)))(input_layer) + if backend.config.image_data_format() == "channels_last": + self.assertEqual(padded.shape, (1, 5, None, 4)) + else: + self.assertEqual(padded.shape, (1, 4, 5, None)) + + @parameterized.parameters( + {"padding": (1,)}, + {"padding": (1, 2, 3)}, + {"padding": "1"}, + {"padding": ((1, 2), (3, 4, 5))}, + {"padding": ((1, 2), (3, -4))}, + {"padding": ((1, 2), "3")}, + ) + def test_zero_padding_2d_errors_if_padding_argument_invalid(self, padding): + with self.assertRaises(ValueError): + layers.ZeroPadding2D(padding) + + @parameterized.parameters( + {"data_format": "channels_first"}, + {"data_format": "channels_last"}, + ) + def test_zero_padding_2d_get_config(self, data_format): + layer = layers.ZeroPadding2D(padding=(1, 2), data_format=data_format) + expected_config = { + "data_format": data_format, + "dtype": dtype_policies.serialize(layer.dtype_policy), + "name": layer.name, + "padding": ((1, 1), (2, 2)), + "trainable": layer.trainable, + } + self.assertEqual(layer.get_config(), expected_config) diff --git a/keras/src/layers/reshaping/zero_padding3d.py b/keras/src/layers/reshaping/zero_padding3d.py new file mode 100644 index 000000000000..87e39bf00060 --- /dev/null +++ b/keras/src/layers/reshaping/zero_padding3d.py @@ -0,0 +1,118 @@ +from keras.src import backend +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.layers.input_spec import InputSpec +from keras.src.layers.layer import Layer +from keras.src.utils import argument_validation + + +@keras_export("keras.layers.ZeroPadding3D") +class ZeroPadding3D(Layer): + """Zero-padding layer for 3D data (spatial or spatio-temporal). + + Example: + + >>> input_shape = (1, 1, 2, 2, 3) + >>> x = np.arange(np.prod(input_shape)).reshape(input_shape) + >>> y = keras.layers.ZeroPadding3D(padding=2)(x) + >>> y.shape + (1, 5, 6, 6, 3) + + Args: + padding: Int, or tuple of 3 ints, or tuple of 3 tuples of 2 ints. + - If int: the same symmetric padding is applied to depth, height, + and width. + - If tuple of 3 ints: interpreted as three different symmetric + padding values for depth, height, and width: + `(symmetric_dim1_pad, symmetric_dim2_pad, symmetric_dim3_pad)`. + - If tuple of 3 tuples of 2 ints: interpreted as + `((left_dim1_pad, right_dim1_pad), (left_dim2_pad, + right_dim2_pad), (left_dim3_pad, right_dim3_pad))`. + data_format: A string, one of `"channels_last"` (default) or + `"channels_first"`. The ordering of the dimensions in the inputs. + `"channels_last"` corresponds to inputs with shape + `(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)`. + When unspecified, uses `image_data_format` value found in your Keras + config file at `~/.keras/keras.json` (if exists). Defaults to + `"channels_last"`. + + Input shape: + 5D tensor with shape: + - If `data_format` is `"channels_last"`: + `(batch_size, first_axis_to_pad, second_axis_to_pad, + third_axis_to_pad, depth)` + - If `data_format` is `"channels_first"`: + `(batch_size, depth, first_axis_to_pad, second_axis_to_pad, + third_axis_to_pad)` + + Output shape: + 5D tensor with shape: + - If `data_format` is `"channels_last"`: + `(batch_size, first_padded_axis, second_padded_axis, + third_axis_to_pad, depth)` + - If `data_format` is `"channels_first"`: + `(batch_size, depth, first_padded_axis, second_padded_axis, + third_axis_to_pad)` + """ + + def __init__( + self, padding=((1, 1), (1, 1), (1, 1)), data_format=None, **kwargs + ): + super().__init__(**kwargs) + self.data_format = backend.standardize_data_format(data_format) + if isinstance(padding, int): + self.padding = ( + (padding, padding), + (padding, padding), + (padding, padding), + ) + elif hasattr(padding, "__len__"): + if len(padding) != 3: + raise ValueError( + f"`padding` should have 3 elements. Received: {padding}." + ) + dim1_padding = argument_validation.standardize_tuple( + padding[0], 2, "1st entry of padding", allow_zero=True + ) + dim2_padding = argument_validation.standardize_tuple( + padding[1], 2, "2nd entry of padding", allow_zero=True + ) + dim3_padding = argument_validation.standardize_tuple( + padding[2], 2, "3rd entry of padding", allow_zero=True + ) + self.padding = (dim1_padding, dim2_padding, dim3_padding) + else: + raise ValueError( + "`padding` should be either an int, a tuple of 3 ints " + "(symmetric_dim1_pad, symmetric_dim2_pad, symmetric_dim3_pad), " + "or a tuple of 3 tuples of 2 ints " + "((left_dim1_pad, right_dim1_pad)," + " (left_dim2_pad, right_dim2_pad)," + " (left_dim3_pad, right_dim2_pad)). " + f"Received: padding={padding}." + ) + self.input_spec = InputSpec(ndim=5) + + def compute_output_shape(self, input_shape): + output_shape = list(input_shape) + spatial_dims_offset = 2 if self.data_format == "channels_first" else 1 + for index in range(0, 3): + if output_shape[index + spatial_dims_offset] is not None: + output_shape[index + spatial_dims_offset] += ( + self.padding[index][0] + self.padding[index][1] + ) + return tuple(output_shape) + + def call(self, inputs): + if self.data_format == "channels_first": + all_dims_padding = ((0, 0), (0, 0), *self.padding) + else: + all_dims_padding = ((0, 0), *self.padding, (0, 0)) + return ops.pad(inputs, all_dims_padding) + + def get_config(self): + config = {"padding": self.padding, "data_format": self.data_format} + base_config = super().get_config() + return {**base_config, **config} diff --git a/keras/src/layers/reshaping/zero_padding3d_test.py b/keras/src/layers/reshaping/zero_padding3d_test.py new file mode 100644 index 000000000000..ba2d152be44c --- /dev/null +++ b/keras/src/layers/reshaping/zero_padding3d_test.py @@ -0,0 +1,104 @@ +import numpy as np +from absl.testing import parameterized + +from keras.src import backend +from keras.src import dtype_policies +from keras.src import layers +from keras.src import testing + + +class ZeroPadding3DTest(testing.TestCase): + @parameterized.parameters( + {"data_format": "channels_first"}, {"data_format": "channels_last"} + ) + def test_zero_padding_3d(self, data_format): + inputs = np.random.rand(1, 2, 3, 4, 5) + outputs = layers.ZeroPadding3D( + padding=((1, 2), (3, 4), (0, 2)), data_format=data_format + )(inputs) + + if data_format == "channels_first": + for index in [0, -1, -2]: + self.assertAllClose(outputs[:, :, index, :, :], 0.0) + for index in [0, 1, 2, -1, -2, -3, -4]: + self.assertAllClose(outputs[:, :, :, index, :], 0.0) + for index in [-1, -2]: + self.assertAllClose(outputs[:, :, :, :, index], 0.0) + self.assertAllClose(outputs[:, :, 1:-2, 3:-4, 0:-2], inputs) + else: + for index in [0, -1, -2]: + self.assertAllClose(outputs[:, index, :, :, :], 0.0) + for index in [0, 1, 2, -1, -2, -3, -4]: + self.assertAllClose(outputs[:, :, index, :, :], 0.0) + for index in [-1, -2]: + self.assertAllClose(outputs[:, :, :, index, :], 0.0) + self.assertAllClose(outputs[:, 1:-2, 3:-4, 0:-2, :], inputs) + + @parameterized.product( + ( + {"padding": ((2, 2), (2, 2), (2, 2))}, # 3 tuples + {"padding": (2, 2, 2)}, # 1 tuple + {"padding": 2}, # 1 int + ), + ( + {"data_format": "channels_first"}, + {"data_format": "channels_last"}, + ), + ) + def test_zero_padding_3d_with_same_padding(self, padding, data_format): + inputs = np.random.rand(1, 2, 3, 4, 5) + outputs = layers.ZeroPadding3D( + padding=padding, data_format=data_format + )(inputs) + + if data_format == "channels_first": + for index in [0, 1, -1, -2]: + self.assertAllClose(outputs[:, :, index, :, :], 0.0) + self.assertAllClose(outputs[:, :, :, index, :], 0.0) + self.assertAllClose(outputs[:, :, :, :, index], 0.0) + self.assertAllClose(outputs[:, :, 2:-2, 2:-2, 2:-2], inputs) + else: + for index in [0, 1, -1, -2]: + self.assertAllClose(outputs[:, index, :, :, :], 0.0) + self.assertAllClose(outputs[:, :, index, :, :], 0.0) + self.assertAllClose(outputs[:, :, :, index, :], 0.0) + self.assertAllClose(outputs[:, 2:-2, 2:-2, 2:-2, :], inputs) + + def test_zero_padding_3d_with_dynamic_spatial_dim(self): + if backend.config.image_data_format() == "channels_last": + input_layer = layers.Input(batch_shape=(1, 2, None, 4, 5)) + else: + input_layer = layers.Input(batch_shape=(1, 5, 2, None, 4)) + padded = layers.ZeroPadding3D(((1, 2), (3, 4), (5, 6)))(input_layer) + if backend.config.image_data_format() == "channels_last": + self.assertEqual(padded.shape, (1, 5, None, 15, 5)) + else: + self.assertEqual(padded.shape, (1, 5, 5, None, 15)) + + @parameterized.parameters( + {"padding": (1,)}, + {"padding": (1, 2)}, + {"padding": (1, 2, 3, 4)}, + {"padding": "1"}, + {"padding": ((1, 2), (3, 4), (5, 6, 7))}, + {"padding": ((1, 2), (3, 4), (5, -6))}, + {"padding": ((1, 2), (3, 4), "5")}, + ) + def test_zero_padding_3d_errors_if_padding_argument_invalid(self, padding): + with self.assertRaises(ValueError): + layers.ZeroPadding3D(padding=padding) + + @parameterized.parameters( + {"data_format": "channels_first"}, + {"data_format": "channels_last"}, + ) + def test_zero_padding_3d_get_config(self, data_format): + layer = layers.ZeroPadding3D(padding=(1, 2, 3), data_format=data_format) + expected_config = { + "data_format": data_format, + "dtype": dtype_policies.serialize(layer.dtype_policy), + "name": layer.name, + "padding": ((1, 1), (2, 2), (3, 3)), + "trainable": layer.trainable, + } + self.assertEqual(layer.get_config(), expected_config) diff --git a/keras/src/layers/rnn/__init__.py b/keras/src/layers/rnn/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/keras/src/layers/rnn/bidirectional.py b/keras/src/layers/rnn/bidirectional.py new file mode 100644 index 000000000000..39cbbcb52ee4 --- /dev/null +++ b/keras/src/layers/rnn/bidirectional.py @@ -0,0 +1,329 @@ +import copy + +from keras.src import ops +from keras.src import utils +from keras.src.api_export import keras_export +from keras.src.layers.layer import Layer +from keras.src.saving import serialization_lib + + +@keras_export("keras.layers.Bidirectional") +class Bidirectional(Layer): + """Bidirectional wrapper for RNNs. + + Args: + layer: `keras.layers.RNN` instance, such as + `keras.layers.LSTM` or `keras.layers.GRU`. + It could also be a `keras.layers.Layer` instance + that meets the following criteria: + 1. Be a sequence-processing layer (accepts 3D+ inputs). + 2. Have a `go_backwards`, `return_sequences` and `return_state` + attribute (with the same semantics as for the `RNN` class). + 3. Have an `input_spec` attribute. + 4. Implement serialization via `get_config()` and `from_config()`. + Note that the recommended way to create new RNN layers is to write a + custom RNN cell and use it with `keras.layers.RNN`, instead of + subclassing `keras.layers.Layer` directly. + When `return_sequences` is `True`, the output of the masked + timestep will be zero regardless of the layer's original + `zero_output_for_mask` value. + merge_mode: Mode by which outputs of the forward and backward RNNs + will be combined. One of `{"sum", "mul", "concat", "ave", None}`. + If `None`, the outputs will not be combined, + they will be returned as a list. Defaults to `"concat"`. + backward_layer: Optional `keras.layers.RNN`, + or `keras.layers.Layer` instance to be used to handle + backwards input processing. + If `backward_layer` is not provided, the layer instance passed + as the `layer` argument will be used to generate the backward layer + automatically. + Note that the provided `backward_layer` layer should have properties + matching those of the `layer` argument, in particular + it should have the same values for `stateful`, `return_states`, + `return_sequences`, etc. In addition, `backward_layer` + and `layer` should have different `go_backwards` argument values. + A `ValueError` will be raised if these requirements are not met. + + Call arguments: + The call arguments for this layer are the same as those of the + wrapped RNN layer. Beware that when passing the `initial_state` + argument during the call of this layer, the first half in the + list of elements in the `initial_state` list will be passed to + the forward RNN call and the last half in the list of elements + will be passed to the backward RNN call. + + Note: instantiating a `Bidirectional` layer from an existing RNN layer + instance will not reuse the weights state of the RNN layer instance -- the + `Bidirectional` layer will have freshly initialized weights. + + Examples: + + ```python + model = Sequential([ + Input(shape=(5, 10)), + Bidirectional(LSTM(10, return_sequences=True), + Bidirectional(LSTM(10)), + Dense(5, activation="softmax"), + ]) + model.compile(loss='categorical_crossentropy', optimizer='rmsprop') + + # With custom backward layer + forward_layer = LSTM(10, return_sequences=True) + backward_layer = LSTM(10, activation='relu', return_sequences=True, + go_backwards=True) + model = Sequential([ + Input(shape=(5, 10)), + Bidirectional(forward_layer, backward_layer=backward_layer), + Dense(5, activation="softmax"), + ]) + model.compile(loss='categorical_crossentropy', optimizer='rmsprop') + ``` + """ + + def __init__( + self, + layer, + merge_mode="concat", + weights=None, + backward_layer=None, + **kwargs, + ): + if not isinstance(layer, Layer): + raise ValueError( + "Please initialize `Bidirectional` layer with a " + f"`keras.layers.Layer` instance. Received: {layer}" + ) + if backward_layer is not None and not isinstance(backward_layer, Layer): + raise ValueError( + "`backward_layer` need to be a `keras.layers.Layer` " + f"instance. Received: {backward_layer}" + ) + if merge_mode not in ["sum", "mul", "ave", "concat", None]: + raise ValueError( + f"Invalid merge mode. Received: {merge_mode}. " + "Merge mode should be one of " + '{"sum", "mul", "ave", "concat", None}' + ) + super().__init__(**kwargs) + + # Recreate the forward layer from the original layer config, so that it + # will not carry over any state from the layer. + config = serialization_lib.serialize_keras_object(layer) + config["config"]["name"] = ( + f"forward_{utils.removeprefix(layer.name, 'forward_')}" + ) + self.forward_layer = serialization_lib.deserialize_keras_object(config) + + if backward_layer is None: + config = serialization_lib.serialize_keras_object(layer) + config["config"]["go_backwards"] = True + config["config"]["name"] = ( + f"backward_{utils.removeprefix(layer.name, 'backward_')}" + ) + self.backward_layer = serialization_lib.deserialize_keras_object( + config + ) + else: + self.backward_layer = backward_layer + # Keep the use_cudnn attribute if defined (not serialized). + if hasattr(layer, "use_cudnn"): + self.forward_layer.use_cudnn = layer.use_cudnn + self.backward_layer.use_cudnn = layer.use_cudnn + self._verify_layer_config() + + def force_zero_output_for_mask(layer): + # Force the zero_output_for_mask to be True if returning sequences. + if getattr(layer, "zero_output_for_mask", None) is not None: + layer.zero_output_for_mask = layer.return_sequences + + force_zero_output_for_mask(self.forward_layer) + force_zero_output_for_mask(self.backward_layer) + + self.merge_mode = merge_mode + if weights: + nw = len(weights) + self.forward_layer.initial_weights = weights[: nw // 2] + self.backward_layer.initial_weights = weights[nw // 2 :] + self.stateful = layer.stateful + self.return_sequences = layer.return_sequences + self.return_state = layer.return_state + self.supports_masking = True + self.input_spec = layer.input_spec + + def _verify_layer_config(self): + """Ensure the forward and backward layers have valid common property.""" + if self.forward_layer.go_backwards == self.backward_layer.go_backwards: + raise ValueError( + "Forward layer and backward layer should have different " + "`go_backwards` value. Received: " + "forward_layer.go_backwards " + f"{self.forward_layer.go_backwards}, " + "backward_layer.go_backwards=" + f"{self.backward_layer.go_backwards}" + ) + + common_attributes = ("stateful", "return_sequences", "return_state") + for a in common_attributes: + forward_value = getattr(self.forward_layer, a) + backward_value = getattr(self.backward_layer, a) + if forward_value != backward_value: + raise ValueError( + "Forward layer and backward layer are expected to have " + f'the same value for attribute "{a}", got ' + f'"{forward_value}" for forward layer and ' + f'"{backward_value}" for backward layer' + ) + + def compute_output_shape(self, sequences_shape, initial_state_shape=None): + output_shape = self.forward_layer.compute_output_shape(sequences_shape) + + if self.return_state: + output_shape, state_shape = output_shape[0], output_shape[1:] + + if self.merge_mode == "concat": + output_shape = list(output_shape) + output_shape[-1] *= 2 + output_shape = tuple(output_shape) + elif self.merge_mode is None: + output_shape = [output_shape, output_shape] + + if self.return_state: + if self.merge_mode is None: + return tuple(output_shape) + state_shape + state_shape + return tuple([output_shape]) + (state_shape) + (state_shape) + return tuple(output_shape) + + def call( + self, + sequences, + initial_state=None, + mask=None, + training=None, + ): + kwargs = {} + if self.forward_layer._call_has_training_arg: + kwargs["training"] = training + if self.forward_layer._call_has_mask_arg: + kwargs["mask"] = mask + + if initial_state is not None: + # initial_states are not keras tensors, eg eager tensor from np + # array. They are only passed in from kwarg initial_state, and + # should be passed to forward/backward layer via kwarg + # initial_state as well. + forward_inputs, backward_inputs = sequences, sequences + half = len(initial_state) // 2 + forward_state = initial_state[:half] + backward_state = initial_state[half:] + else: + forward_inputs, backward_inputs = sequences, sequences + forward_state, backward_state = None, None + + y = self.forward_layer( + forward_inputs, initial_state=forward_state, **kwargs + ) + y_rev = self.backward_layer( + backward_inputs, initial_state=backward_state, **kwargs + ) + + if self.return_state: + states = tuple(y[1:] + y_rev[1:]) + y = y[0] + y_rev = y_rev[0] + + y = ops.cast(y, self.compute_dtype) + y_rev = ops.cast(y_rev, self.compute_dtype) + + if self.return_sequences: + y_rev = ops.flip(y_rev, axis=1) + if self.merge_mode == "concat": + output = ops.concatenate([y, y_rev], axis=-1) + elif self.merge_mode == "sum": + output = y + y_rev + elif self.merge_mode == "ave": + output = (y + y_rev) / 2 + elif self.merge_mode == "mul": + output = y * y_rev + elif self.merge_mode is None: + output = (y, y_rev) + else: + raise ValueError( + "Unrecognized value for `merge_mode`. " + f"Received: {self.merge_mode}" + 'Expected one of {"concat", "sum", "ave", "mul"}.' + ) + if self.return_state: + if self.merge_mode is None: + return output + states + return (output,) + states + return output + + def reset_states(self): + # Compatibility alias. + self.reset_state() + + def reset_state(self): + if not self.stateful: + raise AttributeError("Layer must be stateful.") + self.forward_layer.reset_state() + self.backward_layer.reset_state() + + @property + def states(self): + if self.forward_layer.states and self.backward_layer.states: + return tuple(self.forward_layer.states + self.backward_layer.states) + return None + + def build(self, sequences_shape, initial_state_shape=None): + if not self.forward_layer.built: + self.forward_layer.build(sequences_shape) + if not self.backward_layer.built: + self.backward_layer.build(sequences_shape) + + def compute_mask(self, _, mask): + if isinstance(mask, list): + mask = mask[0] + if self.return_sequences: + if not self.merge_mode: + output_mask = (mask, mask) + else: + output_mask = mask + else: + output_mask = (None, None) if not self.merge_mode else None + + if self.return_state and self.states is not None: + state_mask = (None for _ in self.states) + if isinstance(output_mask, list): + return output_mask + state_mask * 2 + return (output_mask,) + state_mask * 2 + return output_mask + + def get_config(self): + config = {"merge_mode": self.merge_mode} + config["layer"] = serialization_lib.serialize_keras_object( + self.forward_layer + ) + config["backward_layer"] = serialization_lib.serialize_keras_object( + self.backward_layer + ) + base_config = super().get_config() + return {**base_config, **config} + + @classmethod + def from_config(cls, config, custom_objects=None): + # Instead of updating the input, create a copy and use that. + config = copy.deepcopy(config) + + config["layer"] = serialization_lib.deserialize_keras_object( + config["layer"], custom_objects=custom_objects + ) + # Handle (optional) backward layer instantiation. + backward_layer_config = config.pop("backward_layer", None) + if backward_layer_config is not None: + backward_layer = serialization_lib.deserialize_keras_object( + backward_layer_config, custom_objects=custom_objects + ) + config["backward_layer"] = backward_layer + # Instantiate the wrapper, adjust it and return it. + layer = cls(**config) + return layer diff --git a/keras/src/layers/rnn/bidirectional_test.py b/keras/src/layers/rnn/bidirectional_test.py new file mode 100644 index 000000000000..aed4127c95ce --- /dev/null +++ b/keras/src/layers/rnn/bidirectional_test.py @@ -0,0 +1,277 @@ +import numpy as np +import pytest + +from keras.src import initializers +from keras.src import layers +from keras.src import testing + + +class SimpleRNNTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_basics(self): + self.run_layer_test( + layers.Bidirectional, + init_kwargs={"layer": layers.SimpleRNN(4)}, + input_shape=(3, 2, 4), + expected_output_shape=(3, 8), + expected_num_trainable_weights=6, + expected_num_non_trainable_weights=0, + supports_masking=True, + ) + self.run_layer_test( + layers.Bidirectional, + init_kwargs={ + "layer": layers.SimpleRNN(4), + "backward_layer": layers.SimpleRNN(4, go_backwards=True), + "merge_mode": "sum", + }, + input_shape=(3, 2, 4), + expected_output_shape=(3, 4), + expected_num_trainable_weights=6, + expected_num_non_trainable_weights=0, + supports_masking=True, + ) + + def test_correctness(self): + sequence = np.arange(24).reshape((2, 3, 4)).astype("float32") + forward_layer = layers.SimpleRNN( + 2, + kernel_initializer=initializers.Constant(0.01), + recurrent_initializer=initializers.Constant(0.02), + bias_initializer=initializers.Constant(0.03), + ) + layer = layers.Bidirectional( + layer=forward_layer, + ) + output = layer(sequence) + self.assertAllClose( + np.array( + [ + [0.39687276, 0.39687276, 0.10004295, 0.10004295], + [0.7237238, 0.7237238, 0.53391594, 0.53391594], + ] + ), + output, + ) + + layer = layers.Bidirectional(layer=forward_layer, merge_mode="ave") + output = layer(sequence) + self.assertAllClose( + np.array([[0.24845785, 0.24845785], [0.6288199, 0.6288199]]), + output, + ) + + layer = layers.Bidirectional(layer=forward_layer, merge_mode=None) + output1, output2 = layer(sequence) + self.assertAllClose( + np.array([[0.39687276, 0.39687276], [0.7237238, 0.7237238]]), + output1, + ) + self.assertAllClose( + np.array([[0.10004295, 0.10004295], [0.53391594, 0.53391594]]), + output2, + ) + + backward_layer = layers.SimpleRNN( + 2, + kernel_initializer=initializers.Constant(0.03), + recurrent_initializer=initializers.Constant(0.02), + bias_initializer=initializers.Constant(0.01), + go_backwards=True, + ) + layer = layers.Bidirectional( + layer=forward_layer, backward_layer=backward_layer, merge_mode="mul" + ) + output = layer(sequence) + self.assertAllClose( + np.array([[0.08374989, 0.08374989], [0.6740834, 0.6740834]]), + output, + ) + + forward_layer = layers.GRU( + 2, + kernel_initializer=initializers.Constant(0.01), + recurrent_initializer=initializers.Constant(0.02), + bias_initializer=initializers.Constant(0.03), + return_sequences=True, + ) + layer = layers.Bidirectional(layer=forward_layer, merge_mode="sum") + output = layer(sequence) + self.assertAllClose( + np.array( + [ + [ + [0.20937867, 0.20937867], + [0.34462988, 0.34462988], + [0.40290534, 0.40290534], + ], + [ + [0.59829646, 0.59829646], + [0.6734641, 0.6734641], + [0.6479671, 0.6479671], + ], + ] + ), + output, + ) + + def test_statefulness(self): + sequence = np.arange(24).reshape((2, 4, 3)).astype("float32") + forward_layer = layers.LSTM( + 2, + kernel_initializer=initializers.Constant(0.01), + recurrent_initializer=initializers.Constant(0.02), + bias_initializer=initializers.Constant(0.03), + stateful=True, + ) + layer = layers.Bidirectional(layer=forward_layer) + layer(sequence) + output = layer(sequence) + self.assertAllClose( + np.array( + [ + [0.26234663, 0.26234663, 0.16959146, 0.16959146], + [0.6137073, 0.6137073, 0.5381646, 0.5381646], + ] + ), + output, + ) + layer.reset_state() + layer(sequence) + output = layer(sequence) + self.assertAllClose( + np.array( + [ + [0.26234663, 0.26234663, 0.16959146, 0.16959146], + [0.6137073, 0.6137073, 0.5381646, 0.5381646], + ] + ), + output, + ) + + def test_pass_initial_state(self): + sequence = np.arange(24).reshape((2, 4, 3)).astype("float32") + initial_state = [ + np.arange(4).reshape((2, 2)).astype("float32") * 1, + np.arange(4).reshape((2, 2)).astype("float32") * 2, + np.arange(4).reshape((2, 2)).astype("float32") * 3, + np.arange(4).reshape((2, 2)).astype("float32") * 4, + ] + forward_layer = layers.LSTM( + 2, + kernel_initializer=initializers.Constant(0.01), + recurrent_initializer=initializers.Constant(0.02), + bias_initializer=initializers.Constant(0.03), + ) + layer = layers.Bidirectional( + layer=forward_layer, + ) + output = layer(sequence, initial_state=initial_state) + self.assertAllClose( + np.array( + [ + [0.20794602, 0.4577124, 0.14046375, 0.48191673], + [0.6682636, 0.6711909, 0.60943645, 0.60950446], + ] + ), + output, + ) + + def test_masking(self): + sequence = np.arange(24).reshape((2, 4, 3)).astype("float32") + forward_layer = layers.GRU( + 2, + kernel_initializer=initializers.Constant(0.01), + recurrent_initializer=initializers.Constant(0.02), + bias_initializer=initializers.Constant(0.03), + ) + layer = layers.Bidirectional(layer=forward_layer) + mask = np.array([[True, True, False, True], [True, False, False, True]]) + output = layer(sequence, mask=mask) + self.assertAllClose( + np.array( + [ + [0.19393763, 0.19393763, 0.11669192, 0.11669192], + [0.30818558, 0.30818558, 0.28380975, 0.28380975], + ] + ), + output, + ) + + def test_return_state(self): + sequence = np.arange(24).reshape((2, 4, 3)).astype("float32") + forward_layer = layers.LSTM( + 2, + kernel_initializer=initializers.Constant(0.01), + recurrent_initializer=initializers.Constant(0.02), + bias_initializer=initializers.Constant(0.03), + return_state=True, + ) + layer = layers.Bidirectional(layer=forward_layer) + output, h1, c1, h2, c2 = layer(sequence) + self.assertAllClose( + np.array( + [ + [0.1990008, 0.1990008, 0.12659755, 0.12659755], + [0.52335435, 0.52335435, 0.44717982, 0.44717982], + ] + ), + output, + ) + self.assertAllClose( + np.array([[0.1990008, 0.1990008], [0.52335435, 0.52335435]]), + h1, + ) + self.assertAllClose( + np.array([[0.35567185, 0.35567185], [1.0492687, 1.0492687]]), + c1, + ) + self.assertAllClose( + np.array([[0.12659755, 0.12659755], [0.44717982, 0.44717982]]), + h2, + ) + self.assertAllClose( + np.array([[0.2501858, 0.2501858], [0.941473, 0.941473]]), + c2, + ) + + @pytest.mark.requires_trainable_backend + def test_output_shape(self): + x = np.array([[[101, 202], [303, 404]]]) + for merge_mode in ["ave", "concat", "mul", "sum", None]: + sub_layer = layers.LSTM(2, return_state=True) + layer = layers.Bidirectional(sub_layer, merge_mode=merge_mode) + output = layer(x) + output_shape = layer.compute_output_shape(x.shape) + for out, shape in zip(output, output_shape): + self.assertEqual(out.shape, shape) + + for merge_mode in ["concat", "ave", "mul", "sum"]: + sub_layer = layers.LSTM(2, return_state=False) + layer = layers.Bidirectional(sub_layer, merge_mode=merge_mode) + output = layer(x) + output_shape = layer.compute_output_shape(x.shape) + self.assertEqual(output.shape, output_shape) + + # return_state=False & merge_mode=None + sub_layer = layers.LSTM(2, return_state=False) + layer = layers.Bidirectional(sub_layer, merge_mode=None) + output = layer(x) + output_shape = layer.compute_output_shape(x.shape) + for out, shape in zip(output, output_shape): + self.assertEqual(out.shape, shape) + + def test_keeps_use_cudnn(self): + # keep use_cudnn if the layer has it + for rnn_class in [layers.GRU, layers.LSTM]: + for use_cudnn in [True, False, "auto"]: + rnn = rnn_class(1, use_cudnn=use_cudnn) + bidi = layers.Bidirectional(rnn) + self.assertEqual(bidi.forward_layer.use_cudnn, use_cudnn) + self.assertEqual(bidi.backward_layer.use_cudnn, use_cudnn) + + # otherwise ignore it + rnn = layers.SimpleRNN(1) + bidi = layers.Bidirectional(rnn) + self.assertFalse(hasattr(bidi.forward_layer, "use_cudnn")) + self.assertFalse(hasattr(bidi.backward_layer, "use_cudnn")) diff --git a/keras/src/layers/rnn/conv_lstm.py b/keras/src/layers/rnn/conv_lstm.py new file mode 100644 index 000000000000..df82e5e5bf74 --- /dev/null +++ b/keras/src/layers/rnn/conv_lstm.py @@ -0,0 +1,695 @@ +from keras.src import activations +from keras.src import backend +from keras.src import constraints +from keras.src import initializers +from keras.src import ops +from keras.src import regularizers +from keras.src import tree +from keras.src.layers.input_spec import InputSpec +from keras.src.layers.layer import Layer +from keras.src.layers.rnn.dropout_rnn_cell import DropoutRNNCell +from keras.src.layers.rnn.rnn import RNN +from keras.src.ops import operation_utils +from keras.src.utils import argument_validation + + +class ConvLSTMCell(Layer, DropoutRNNCell): + """Cell class for the ConvLSTM layer. + + Args: + rank: Integer, rank of the convolution, e.g. "2" for 2D convolutions. + filters: Integer, the dimensionality of the output space + (i.e. the number of output filters in the convolution). + kernel_size: An integer or tuple/list of n integers, specifying the + dimensions of the convolution window. + strides: An integer or tuple/list of n integers, specifying the strides + of the convolution. Specifying any stride value != 1 + is incompatible with specifying any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + `"valid"` means no padding. `"same"` results in padding evenly + to the left/right or up/down of the input such that output + has the same height/width dimension as the input. + data_format: A string, one of `channels_last` (default) or + `channels_first`. When unspecified, uses + `image_data_format` value found in your Keras config file at + `~/.keras/keras.json` (if exists) else 'channels_last'. + Defaults to `'channels_last'`. + dilation_rate: An integer or tuple/list of n integers, specifying the + dilation rate to use for dilated convolution. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any `strides` value != 1. + activation: Activation function. If `None`, no activation is applied. + recurrent_activation: Activation function to use for the recurrent step. + use_bias: Boolean, (default `True`), whether the layer + should use a bias vector. + kernel_initializer: Initializer for the `kernel` weights matrix, + used for the linear transformation of the inputs. Default: + `"glorot_uniform"`. + recurrent_initializer: Initializer for the `recurrent_kernel` + weights matrix, used for the linear transformation of the recurrent + state. Default: `"orthogonal"`. + bias_initializer: Initializer for the bias vector. Default: `"zeros"`. + unit_forget_bias: Boolean (default `True`). If `True`, + add 1 to the bias of the forget gate at initialization. + Setting it to `True` will also force `bias_initializer="zeros"`. + This is recommended in [Jozefowicz et al.]( + https://github.com/mlresearch/v37/blob/gh-pages/jozefowicz15.pdf) + kernel_regularizer: Regularizer function applied to the `kernel` weights + matrix. Default: `None`. + recurrent_regularizer: Regularizer function applied to the + `recurrent_kernel` weights matrix. Default: `None`. + bias_regularizer: Regularizer function applied to the bias vector. + Default: `None`. + activity_regularizer: Regularizer function applied to the output of the + layer (its "activation"). Default: `None`. + kernel_constraint: Constraint function applied to the `kernel` weights + matrix. Default: `None`. + recurrent_constraint: Constraint function applied to the + `recurrent_kernel` weights matrix. Default: `None`. + bias_constraint: Constraint function applied to the bias vector. + Default: `None`. + dropout: Float between 0 and 1. Fraction of the units to drop for the + linear transformation of the inputs. Default: 0. + recurrent_dropout: Float between 0 and 1. Fraction of the units to drop + for the linear transformation of the recurrent state. Default: 0. + seed: Random seed for dropout. + + Call arguments: + inputs: A (2+ `rank`)D tensor. + states: List of state tensors corresponding to the previous timestep. + training: Python boolean indicating whether the layer should behave in + training mode or in inference mode. Only relevant when `dropout` or + `recurrent_dropout` is used. + """ + + def __init__( + self, + rank, + filters, + kernel_size, + strides=1, + padding="valid", + data_format=None, + dilation_rate=1, + activation="tanh", + recurrent_activation="sigmoid", + use_bias=True, + kernel_initializer="glorot_uniform", + recurrent_initializer="orthogonal", + bias_initializer="zeros", + unit_forget_bias=True, + kernel_regularizer=None, + recurrent_regularizer=None, + bias_regularizer=None, + kernel_constraint=None, + recurrent_constraint=None, + bias_constraint=None, + dropout=0.0, + recurrent_dropout=0.0, + seed=None, + **kwargs, + ): + super().__init__(**kwargs) + self.seed = seed + self.seed_generator = backend.random.SeedGenerator(seed=seed) + self.rank = rank + if self.rank > 3: + raise ValueError( + f"Rank {rank} convolutions are not currently " + f"implemented. Received: rank={rank}" + ) + self.filters = filters + self.kernel_size = argument_validation.standardize_tuple( + kernel_size, self.rank, "kernel_size" + ) + self.strides = argument_validation.standardize_tuple( + strides, self.rank, "strides", allow_zero=True + ) + self.padding = argument_validation.standardize_padding(padding) + self.data_format = backend.standardize_data_format(data_format) + self.dilation_rate = argument_validation.standardize_tuple( + dilation_rate, self.rank, "dilation_rate" + ) + self.activation = activations.get(activation) + self.recurrent_activation = activations.get(recurrent_activation) + self.use_bias = use_bias + + self.kernel_initializer = initializers.get(kernel_initializer) + self.recurrent_initializer = initializers.get(recurrent_initializer) + self.bias_initializer = initializers.get(bias_initializer) + self.unit_forget_bias = unit_forget_bias + + self.kernel_regularizer = regularizers.get(kernel_regularizer) + self.recurrent_regularizer = regularizers.get(recurrent_regularizer) + self.bias_regularizer = regularizers.get(bias_regularizer) + + self.kernel_constraint = constraints.get(kernel_constraint) + self.recurrent_constraint = constraints.get(recurrent_constraint) + self.bias_constraint = constraints.get(bias_constraint) + + self.dropout = min(1.0, max(0.0, dropout)) + self.recurrent_dropout = min(1.0, max(0.0, recurrent_dropout)) + self.dropout_mask_count = 4 + self.input_spec = InputSpec(ndim=rank + 2) + self.state_size = -1 # Custom, defined in methods + + def build(self, inputs_shape, states_shape=None): + if self.data_format == "channels_first": + channel_axis = 1 + self.spatial_dims = inputs_shape[2:] + else: + channel_axis = -1 + self.spatial_dims = inputs_shape[1:-1] + if None in self.spatial_dims: + raise ValueError( + "ConvLSTM layers only support static " + "input shapes for the spatial dimension. " + f"Received invalid input shape: input_shape={inputs_shape}" + ) + if inputs_shape[channel_axis] is None: + raise ValueError( + "The channel dimension of the inputs (last axis) should be " + "defined. Found None. Full input shape received: " + f"input_shape={inputs_shape}" + ) + self.input_spec = InputSpec( + ndim=self.rank + 3, shape=(None,) + inputs_shape[1:] + ) + + input_dim = inputs_shape[channel_axis] + self.input_dim = input_dim + self.kernel_shape = self.kernel_size + (input_dim, self.filters * 4) + recurrent_kernel_shape = self.kernel_size + ( + self.filters, + self.filters * 4, + ) + + self.kernel = self.add_weight( + shape=self.kernel_shape, + initializer=self.kernel_initializer, + name="kernel", + regularizer=self.kernel_regularizer, + constraint=self.kernel_constraint, + ) + self.recurrent_kernel = self.add_weight( + shape=recurrent_kernel_shape, + initializer=self.recurrent_initializer, + name="recurrent_kernel", + regularizer=self.recurrent_regularizer, + constraint=self.recurrent_constraint, + ) + + if self.use_bias: + if self.unit_forget_bias: + + def bias_initializer(_, *args, **kwargs): + return ops.concatenate( + [ + self.bias_initializer( + (self.filters,), *args, **kwargs + ), + initializers.get("ones")( + (self.filters,), *args, **kwargs + ), + self.bias_initializer( + (self.filters * 2,), *args, **kwargs + ), + ] + ) + + else: + bias_initializer = self.bias_initializer + self.bias = self.add_weight( + shape=(self.filters * 4,), + name="bias", + initializer=bias_initializer, + regularizer=self.bias_regularizer, + constraint=self.bias_constraint, + ) + else: + self.bias = None + + def call(self, inputs, states, training=False): + h_tm1 = states[0] # previous memory state + c_tm1 = states[1] # previous carry state + + if training and 0.0 < self.dropout < 1.0: + dp_mask = self.get_dropout_mask(inputs) + inputs_i = inputs * dp_mask[0] + inputs_f = inputs * dp_mask[1] + inputs_c = inputs * dp_mask[2] + inputs_o = inputs * dp_mask[3] + else: + inputs_i = inputs + inputs_f = inputs + inputs_c = inputs + inputs_o = inputs + + if training and 0.0 < self.recurrent_dropout < 1.0: + rec_dp_mask = self.get_recurrent_dropout_mask(h_tm1) + h_tm1_i = h_tm1 * rec_dp_mask[0] + h_tm1_f = h_tm1 * rec_dp_mask[1] + h_tm1_c = h_tm1 * rec_dp_mask[2] + h_tm1_o = h_tm1 * rec_dp_mask[3] + else: + h_tm1_i = h_tm1 + h_tm1_f = h_tm1 + h_tm1_c = h_tm1 + h_tm1_o = h_tm1 + + (kernel_i, kernel_f, kernel_c, kernel_o) = ops.split( + self.kernel, 4, axis=self.rank + 1 + ) + ( + recurrent_kernel_i, + recurrent_kernel_f, + recurrent_kernel_c, + recurrent_kernel_o, + ) = ops.split(self.recurrent_kernel, 4, axis=self.rank + 1) + + if self.use_bias: + bias_i, bias_f, bias_c, bias_o = ops.split(self.bias, 4) + else: + bias_i, bias_f, bias_c, bias_o = None, None, None, None + + x_i = self.input_conv(inputs_i, kernel_i, bias_i, padding=self.padding) + x_f = self.input_conv(inputs_f, kernel_f, bias_f, padding=self.padding) + x_c = self.input_conv(inputs_c, kernel_c, bias_c, padding=self.padding) + x_o = self.input_conv(inputs_o, kernel_o, bias_o, padding=self.padding) + + h_i = self.recurrent_conv(h_tm1_i, recurrent_kernel_i) + h_f = self.recurrent_conv(h_tm1_f, recurrent_kernel_f) + h_c = self.recurrent_conv(h_tm1_c, recurrent_kernel_c) + h_o = self.recurrent_conv(h_tm1_o, recurrent_kernel_o) + + i = self.recurrent_activation(x_i + h_i) + f = self.recurrent_activation(x_f + h_f) + c = f * c_tm1 + i * self.activation(x_c + h_c) + o = self.recurrent_activation(x_o + h_o) + h = o * self.activation(c) + return h, [h, c] + + def compute_output_shape(self, inputs_shape, states_shape=None): + conv_output_shape = operation_utils.compute_conv_output_shape( + inputs_shape, + self.filters, + self.kernel_size, + strides=self.strides, + padding=self.padding, + data_format=self.data_format, + dilation_rate=self.dilation_rate, + ) + return conv_output_shape, [conv_output_shape, conv_output_shape] + + def get_initial_state(self, batch_size=None): + if self.data_format == "channels_last": + input_shape = (batch_size,) + self.spatial_dims + (self.input_dim,) + else: + input_shape = (batch_size, self.input_dim) + self.spatial_dims + state_shape = self.compute_output_shape(input_shape)[0] + return [ + ops.zeros(state_shape, dtype=self.compute_dtype), + ops.zeros(state_shape, dtype=self.compute_dtype), + ] + + def input_conv(self, x, w, b=None, padding="valid"): + conv_out = ops.conv( + x, + w, + strides=self.strides, + padding=padding, + data_format=self.data_format, + dilation_rate=self.dilation_rate, + ) + if b is not None: + if self.data_format == "channels_last": + bias_shape = (1,) * (self.rank + 1) + (self.filters,) + else: + bias_shape = (1, self.filters) + (1,) * self.rank + bias = ops.reshape(b, bias_shape) + conv_out += bias + return conv_out + + def recurrent_conv(self, x, w): + strides = argument_validation.standardize_tuple( + 1, self.rank, "strides", allow_zero=True + ) + conv_out = ops.conv( + x, w, strides=strides, padding="same", data_format=self.data_format + ) + return conv_out + + def get_config(self): + config = { + "filters": self.filters, + "kernel_size": self.kernel_size, + "strides": self.strides, + "padding": self.padding, + "data_format": self.data_format, + "dilation_rate": self.dilation_rate, + "activation": activations.serialize(self.activation), + "recurrent_activation": activations.serialize( + self.recurrent_activation + ), + "use_bias": self.use_bias, + "kernel_initializer": initializers.serialize( + self.kernel_initializer + ), + "recurrent_initializer": initializers.serialize( + self.recurrent_initializer + ), + "bias_initializer": initializers.serialize(self.bias_initializer), + "unit_forget_bias": self.unit_forget_bias, + "kernel_regularizer": regularizers.serialize( + self.kernel_regularizer + ), + "recurrent_regularizer": regularizers.serialize( + self.recurrent_regularizer + ), + "bias_regularizer": regularizers.serialize(self.bias_regularizer), + "kernel_constraint": constraints.serialize(self.kernel_constraint), + "recurrent_constraint": constraints.serialize( + self.recurrent_constraint + ), + "bias_constraint": constraints.serialize(self.bias_constraint), + "dropout": self.dropout, + "recurrent_dropout": self.recurrent_dropout, + "seed": self.seed, + } + base_config = super().get_config() + return {**base_config, **config} + + +class ConvLSTM(RNN): + """Abstract N-D Convolutional LSTM layer (used as implementation base). + + Similar to an LSTM layer, but the input transformations + and recurrent transformations are both convolutional. + + Args: + rank: Integer, rank of the convolution, e.g. "2" for 2D convolutions. + filters: Integer, the dimensionality of the output space + (i.e. the number of output filters in the convolution). + kernel_size: An integer or tuple/list of n integers, specifying the + dimensions of the convolution window. + strides: An integer or tuple/list of n integers, + specifying the strides of the convolution. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + `"valid"` means no padding. `"same"` results in padding evenly to + the left/right or up/down of the input such that output has the same + height/width dimension as the input. + data_format: A string, + one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, time, ..., channels)` + while `channels_first` corresponds to + inputs with shape `(batch, time, channels, ...)`. + When unspecified, uses + `image_data_format` value found in your Keras config file at + `~/.keras/keras.json` (if exists) else 'channels_last'. + Defaults to `'channels_last'`. + dilation_rate: An integer or tuple/list of n integers, specifying + the dilation rate to use for dilated convolution. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any `strides` value != 1. + activation: Activation function to use. + By default hyperbolic tangent activation function is applied + (`tanh(x)`). + recurrent_activation: Activation function to use + for the recurrent step. + use_bias: Boolean, whether the layer uses a bias vector. + kernel_initializer: Initializer for the `kernel` weights matrix, + used for the linear transformation of the inputs. + recurrent_initializer: Initializer for the `recurrent_kernel` + weights matrix, + used for the linear transformation of the recurrent state. + bias_initializer: Initializer for the bias vector. + unit_forget_bias: Boolean. + If True, add 1 to the bias of the forget gate at initialization. + Use in combination with `bias_initializer="zeros"`. + This is recommended in [Jozefowicz et al., 2015]( + http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf) + kernel_regularizer: Regularizer function applied to + the `kernel` weights matrix. + recurrent_regularizer: Regularizer function applied to + the `recurrent_kernel` weights matrix. + bias_regularizer: Regularizer function applied to the bias vector. + activity_regularizer: Regularizer function applied to. + kernel_constraint: Constraint function applied to + the `kernel` weights matrix. + recurrent_constraint: Constraint function applied to + the `recurrent_kernel` weights matrix. + bias_constraint: Constraint function applied to the bias vector. + dropout: Float between 0 and 1. + Fraction of the units to drop for + the linear transformation of the inputs. + recurrent_dropout: Float between 0 and 1. + Fraction of the units to drop for + the linear transformation of the recurrent state. + seed: Random seed for dropout. + return_sequences: Boolean. Whether to return the last output + in the output sequence, or the full sequence. (default False) + return_state: Boolean Whether to return the last state + in addition to the output. (default False) + go_backwards: Boolean (default False). + If True, process the input sequence backwards. + stateful: Boolean (default False). If True, the last state + for each sample at index i in a batch will be used as initial + state for the sample of index i in the following batch. + """ + + def __init__( + self, + rank, + filters, + kernel_size, + strides=1, + padding="valid", + data_format=None, + dilation_rate=1, + activation="tanh", + recurrent_activation="sigmoid", + use_bias=True, + kernel_initializer="glorot_uniform", + recurrent_initializer="orthogonal", + bias_initializer="zeros", + unit_forget_bias=True, + kernel_regularizer=None, + recurrent_regularizer=None, + bias_regularizer=None, + kernel_constraint=None, + recurrent_constraint=None, + bias_constraint=None, + dropout=0.0, + recurrent_dropout=0.0, + seed=None, + return_sequences=False, + return_state=False, + go_backwards=False, + stateful=False, + **kwargs, + ): + cell = ConvLSTMCell( + rank=rank, + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + recurrent_activation=recurrent_activation, + use_bias=use_bias, + kernel_initializer=kernel_initializer, + recurrent_initializer=recurrent_initializer, + bias_initializer=bias_initializer, + unit_forget_bias=unit_forget_bias, + kernel_regularizer=kernel_regularizer, + recurrent_regularizer=recurrent_regularizer, + bias_regularizer=bias_regularizer, + kernel_constraint=kernel_constraint, + recurrent_constraint=recurrent_constraint, + bias_constraint=bias_constraint, + dropout=dropout, + recurrent_dropout=recurrent_dropout, + seed=seed, + name="conv_lstm_cell", + dtype=kwargs.get("dtype"), + ) + super().__init__( + cell, + return_sequences=return_sequences, + return_state=return_state, + go_backwards=go_backwards, + stateful=stateful, + **kwargs, + ) + self.input_spec = InputSpec(ndim=rank + 3) + + def call(self, sequences, initial_state=None, mask=None, training=False): + return super().call( + sequences, initial_state=initial_state, mask=mask, training=training + ) + + def compute_output_shape(self, sequences_shape, initial_state_shape=None): + batch_size = sequences_shape[0] + steps = sequences_shape[1] + step_shape = (batch_size,) + sequences_shape[2:] + state_shape = self.cell.compute_output_shape(step_shape)[0][1:] + + if self.return_sequences: + output_shape = ( + batch_size, + steps, + ) + state_shape + else: + output_shape = (batch_size,) + state_shape + + if self.return_state: + batched_state_shape = (batch_size,) + state_shape + return output_shape, batched_state_shape, batched_state_shape + return output_shape + + def compute_mask(self, _, mask): + mask = tree.flatten(mask)[0] + output_mask = mask if self.return_sequences else None + if self.return_state: + state_mask = [None, None] + return [output_mask] + state_mask + else: + return output_mask + + @property + def filters(self): + return self.cell.filters + + @property + def kernel_size(self): + return self.cell.kernel_size + + @property + def strides(self): + return self.cell.strides + + @property + def padding(self): + return self.cell.padding + + @property + def data_format(self): + return self.cell.data_format + + @property + def dilation_rate(self): + return self.cell.dilation_rate + + @property + def activation(self): + return self.cell.activation + + @property + def recurrent_activation(self): + return self.cell.recurrent_activation + + @property + def use_bias(self): + return self.cell.use_bias + + @property + def kernel_initializer(self): + return self.cell.kernel_initializer + + @property + def recurrent_initializer(self): + return self.cell.recurrent_initializer + + @property + def bias_initializer(self): + return self.cell.bias_initializer + + @property + def unit_forget_bias(self): + return self.cell.unit_forget_bias + + @property + def kernel_regularizer(self): + return self.cell.kernel_regularizer + + @property + def recurrent_regularizer(self): + return self.cell.recurrent_regularizer + + @property + def bias_regularizer(self): + return self.cell.bias_regularizer + + @property + def kernel_constraint(self): + return self.cell.kernel_constraint + + @property + def recurrent_constraint(self): + return self.cell.recurrent_constraint + + @property + def bias_constraint(self): + return self.cell.bias_constraint + + @property + def dropout(self): + return self.cell.dropout + + @property + def recurrent_dropout(self): + return self.cell.recurrent_dropout + + def get_config(self): + config = { + "filters": self.filters, + "kernel_size": self.kernel_size, + "strides": self.strides, + "padding": self.padding, + "data_format": self.data_format, + "dilation_rate": self.dilation_rate, + "activation": activations.serialize(self.activation), + "recurrent_activation": activations.serialize( + self.recurrent_activation + ), + "use_bias": self.use_bias, + "kernel_initializer": initializers.serialize( + self.kernel_initializer + ), + "recurrent_initializer": initializers.serialize( + self.recurrent_initializer + ), + "bias_initializer": initializers.serialize(self.bias_initializer), + "unit_forget_bias": self.unit_forget_bias, + "kernel_regularizer": regularizers.serialize( + self.kernel_regularizer + ), + "recurrent_regularizer": regularizers.serialize( + self.recurrent_regularizer + ), + "bias_regularizer": regularizers.serialize(self.bias_regularizer), + "activity_regularizer": regularizers.serialize( + self.activity_regularizer + ), + "kernel_constraint": constraints.serialize(self.kernel_constraint), + "recurrent_constraint": constraints.serialize( + self.recurrent_constraint + ), + "bias_constraint": constraints.serialize(self.bias_constraint), + "dropout": self.dropout, + "recurrent_dropout": self.recurrent_dropout, + "seed": self.cell.seed, + } + base_config = super().get_config() + del base_config["cell"] + return {**base_config, **config} + + @classmethod + def from_config(cls, config): + return cls(**config) diff --git a/keras/src/layers/rnn/conv_lstm1d.py b/keras/src/layers/rnn/conv_lstm1d.py new file mode 100644 index 000000000000..2d68eb748a40 --- /dev/null +++ b/keras/src/layers/rnn/conv_lstm1d.py @@ -0,0 +1,184 @@ +from keras.src.api_export import keras_export +from keras.src.layers.rnn.conv_lstm import ConvLSTM + + +@keras_export("keras.layers.ConvLSTM1D") +class ConvLSTM1D(ConvLSTM): + """1D Convolutional LSTM. + + Similar to an LSTM layer, but the input transformations + and recurrent transformations are both convolutional. + + Args: + filters: int, the dimension of the output space (the number of filters + in the convolution). + kernel_size: int or tuple/list of 1 integer, specifying the size of + the convolution window. + strides: int or tuple/list of 1 integer, specifying the stride length + of the convolution. `strides > 1` is incompatible with + `dilation_rate > 1`. + padding: string, `"valid"` or `"same"` (case-insensitive). + `"valid"` means no padding. `"same"` results in padding evenly to + the left/right or up/down of the input such that output has the + same height/width dimension as the input. + data_format: string, either `"channels_last"` or `"channels_first"`. + The ordering of the dimensions in the inputs. `"channels_last"` + corresponds to inputs with shape `(batch, steps, features)` + while `"channels_first"` corresponds to inputs with shape + `(batch, features, steps)`. It defaults to the `image_data_format` + value found in your Keras config file at `~/.keras/keras.json`. + If you never set it, then it will be `"channels_last"`. + dilation_rate: int or tuple/list of 1 integers, specifying the dilation + rate to use for dilated convolution. + activation: Activation function to use. By default hyperbolic tangent + activation function is applied (`tanh(x)`). + recurrent_activation: Activation function to use for the recurrent step. + use_bias: Boolean, whether the layer uses a bias vector. + kernel_initializer: Initializer for the `kernel` weights matrix, + used for the linear transformation of the inputs. + recurrent_initializer: Initializer for the `recurrent_kernel` weights + matrix, used for the linear transformation of the recurrent state. + bias_initializer: Initializer for the bias vector. + unit_forget_bias: Boolean. If `True`, add 1 to the bias of + the forget gate at initialization. + Use in combination with `bias_initializer="zeros"`. + This is recommended in [Jozefowicz et al., 2015]( + http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf) + kernel_regularizer: Regularizer function applied to the `kernel` weights + matrix. + recurrent_regularizer: Regularizer function applied to the + `recurrent_kernel` weights matrix. + bias_regularizer: Regularizer function applied to the bias vector. + activity_regularizer: Regularizer function applied to. + kernel_constraint: Constraint function applied to the `kernel` weights + matrix. + recurrent_constraint: Constraint function applied to the + `recurrent_kernel` weights matrix. + bias_constraint: Constraint function applied to the bias vector. + dropout: Float between 0 and 1. Fraction of the units to drop for the + linear transformation of the inputs. + recurrent_dropout: Float between 0 and 1. Fraction of the units to drop + for the linear transformation of the recurrent state. + seed: Random seed for dropout. + return_sequences: Boolean. Whether to return the last output + in the output sequence, or the full sequence. Default: `False`. + return_state: Boolean. Whether to return the last state in addition + to the output. Default: `False`. + go_backwards: Boolean (default: `False`). + If `True`, process the input sequence backwards and return the + reversed sequence. + stateful: Boolean (default False). If `True`, the last state + for each sample at index i in a batch will be used as initial + state for the sample of index i in the following batch. + unroll: Boolean (default: `False`). + If `True`, the network will be unrolled, + else a symbolic loop will be used. + Unrolling can speed-up a RNN, + although it tends to be more memory-intensive. + Unrolling is only suitable for short sequences. + + + Call arguments: + inputs: A 4D tensor. + initial_state: List of initial state tensors to be passed to the first + call of the cell. + mask: Binary tensor of shape `(samples, timesteps)` indicating whether a + given timestep should be masked. + training: Python boolean indicating whether the layer should behave in + training mode or in inference mode. + This is only relevant if `dropout` or `recurrent_dropout` are set. + + Input shape: + + - If `data_format="channels_first"`: + 4D tensor with shape: `(samples, time, channels, rows)` + - If `data_format="channels_last"`: + 4D tensor with shape: `(samples, time, rows, channels)` + + Output shape: + + - If `return_state`: a list of tensors. The first tensor is the output. + The remaining tensors are the last states, + each 3D tensor with shape: `(samples, filters, new_rows)` if + `data_format='channels_first'` + or shape: `(samples, new_rows, filters)` if + `data_format='channels_last'`. + `rows` values might have changed due to padding. + - If `return_sequences`: 4D tensor with shape: `(samples, timesteps, + filters, new_rows)` if data_format='channels_first' + or shape: `(samples, timesteps, new_rows, filters)` if + `data_format='channels_last'`. + - Else, 3D tensor with shape: `(samples, filters, new_rows)` if + `data_format='channels_first'` + or shape: `(samples, new_rows, filters)` if + `data_format='channels_last'`. + + References: + + - [Shi et al., 2015](http://arxiv.org/abs/1506.04214v1) + (the current implementation does not include the feedback loop on the + cells output). + """ + + def __init__( + self, + filters, + kernel_size, + strides=1, + padding="valid", + data_format=None, + dilation_rate=1, + activation="tanh", + recurrent_activation="sigmoid", + use_bias=True, + kernel_initializer="glorot_uniform", + recurrent_initializer="orthogonal", + bias_initializer="zeros", + unit_forget_bias=True, + kernel_regularizer=None, + recurrent_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + recurrent_constraint=None, + bias_constraint=None, + dropout=0.0, + recurrent_dropout=0.0, + seed=None, + return_sequences=False, + return_state=False, + go_backwards=False, + stateful=False, + **kwargs, + ): + super().__init__( + rank=1, + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + recurrent_activation=recurrent_activation, + use_bias=use_bias, + kernel_initializer=kernel_initializer, + recurrent_initializer=recurrent_initializer, + bias_initializer=bias_initializer, + unit_forget_bias=unit_forget_bias, + kernel_regularizer=kernel_regularizer, + recurrent_regularizer=recurrent_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + kernel_constraint=kernel_constraint, + recurrent_constraint=recurrent_constraint, + bias_constraint=bias_constraint, + return_sequences=return_sequences, + return_state=return_state, + go_backwards=go_backwards, + stateful=stateful, + dropout=dropout, + recurrent_dropout=recurrent_dropout, + seed=seed, + **kwargs, + ) diff --git a/keras/src/layers/rnn/conv_lstm1d_test.py b/keras/src/layers/rnn/conv_lstm1d_test.py new file mode 100644 index 000000000000..b69cbf8b55aa --- /dev/null +++ b/keras/src/layers/rnn/conv_lstm1d_test.py @@ -0,0 +1,77 @@ +import numpy as np +import pytest + +from keras.src import backend +from keras.src import initializers +from keras.src import layers +from keras.src import testing + + +class ConvLSTM1DTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_basics(self): + channels_last = backend.config.image_data_format() == "channels_last" + self.run_layer_test( + layers.ConvLSTM1D, + init_kwargs={"filters": 5, "kernel_size": 3, "padding": "same"}, + input_shape=(3, 2, 4, 3) if channels_last else (3, 2, 3, 4), + expected_output_shape=(3, 4, 5) if channels_last else (3, 5, 4), + expected_num_trainable_weights=3, + expected_num_non_trainable_weights=0, + supports_masking=True, + ) + self.run_layer_test( + layers.ConvLSTM1D, + init_kwargs={ + "filters": 5, + "kernel_size": 3, + "padding": "valid", + "recurrent_dropout": 0.5, + }, + input_shape=(3, 2, 8, 3) if channels_last else (3, 2, 3, 8), + call_kwargs={"training": True}, + expected_output_shape=(3, 6, 5) if channels_last else (3, 5, 6), + expected_num_trainable_weights=3, + expected_num_non_trainable_weights=0, + supports_masking=True, + ) + self.run_layer_test( + layers.ConvLSTM1D, + init_kwargs={ + "filters": 5, + "kernel_size": 3, + "padding": "valid", + "return_sequences": True, + }, + input_shape=(3, 2, 8, 3) if channels_last else (3, 2, 3, 8), + expected_output_shape=( + (3, 2, 6, 5) if channels_last else (3, 2, 5, 6) + ), + expected_num_trainable_weights=3, + expected_num_non_trainable_weights=0, + supports_masking=True, + ) + + def test_correctness(self): + sequence = np.arange(120).reshape((2, 3, 4, 5)).astype("float32") / 10 + expected_output = np.array( + [ + [[0.40807986, 0.40807986], [0.46421072, 0.46421072]], + [[0.80933154, 0.80933154], [0.8233646, 0.8233646]], + ] + ) + if backend.config.image_data_format() == "channels_first": + sequence = sequence.transpose((0, 1, 3, 2)) + expected_output = expected_output.transpose((0, 2, 1)) + layer = layers.ConvLSTM1D( + filters=2, + kernel_size=3, + kernel_initializer=initializers.Constant(0.01), + recurrent_initializer=initializers.Constant(0.02), + bias_initializer=initializers.Constant(0.03), + ) + output = layer(sequence) + self.assertAllClose( + expected_output, + output, + ) diff --git a/keras/src/layers/rnn/conv_lstm2d.py b/keras/src/layers/rnn/conv_lstm2d.py new file mode 100644 index 000000000000..5e14eadc25aa --- /dev/null +++ b/keras/src/layers/rnn/conv_lstm2d.py @@ -0,0 +1,184 @@ +from keras.src.api_export import keras_export +from keras.src.layers.rnn.conv_lstm import ConvLSTM + + +@keras_export("keras.layers.ConvLSTM2D") +class ConvLSTM2D(ConvLSTM): + """2D Convolutional LSTM. + + Similar to an LSTM layer, but the input transformations + and recurrent transformations are both convolutional. + + Args: + filters: int, the dimension of the output space (the number of filters + in the convolution). + kernel_size: int or tuple/list of 2 integers, specifying the size of the + convolution window. + strides: int or tuple/list of 2 integers, specifying the stride length + of the convolution. `strides > 1` is incompatible with + `dilation_rate > 1`. + padding: string, `"valid"` or `"same"` (case-insensitive). + `"valid"` means no padding. `"same"` results in padding evenly to + the left/right or up/down of the input such that output has the same + height/width dimension as the input. + data_format: string, either `"channels_last"` or `"channels_first"`. + The ordering of the dimensions in the inputs. `"channels_last"` + corresponds to inputs with shape `(batch, steps, features)` + while `"channels_first"` corresponds to inputs with shape + `(batch, features, steps)`. It defaults to the `image_data_format` + value found in your Keras config file at `~/.keras/keras.json`. + If you never set it, then it will be `"channels_last"`. + dilation_rate: int or tuple/list of 2 integers, specifying the dilation + rate to use for dilated convolution. + activation: Activation function to use. By default hyperbolic tangent + activation function is applied (`tanh(x)`). + recurrent_activation: Activation function to use for the recurrent step. + use_bias: Boolean, whether the layer uses a bias vector. + kernel_initializer: Initializer for the `kernel` weights matrix, + used for the linear transformation of the inputs. + recurrent_initializer: Initializer for the `recurrent_kernel` weights + matrix, used for the linear transformation of the recurrent state. + bias_initializer: Initializer for the bias vector. + unit_forget_bias: Boolean. If `True`, add 1 to the bias of the forget + gate at initialization. + Use in combination with `bias_initializer="zeros"`. + This is recommended in [Jozefowicz et al., 2015]( + http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf) + kernel_regularizer: Regularizer function applied to the `kernel` weights + matrix. + recurrent_regularizer: Regularizer function applied to the + `recurrent_kernel` weights matrix. + bias_regularizer: Regularizer function applied to the bias vector. + activity_regularizer: Regularizer function applied to. + kernel_constraint: Constraint function applied to the `kernel` weights + matrix. + recurrent_constraint: Constraint function applied to the + `recurrent_kernel` weights matrix. + bias_constraint: Constraint function applied to the bias vector. + dropout: Float between 0 and 1. Fraction of the units to drop for the + linear transformation of the inputs. + recurrent_dropout: Float between 0 and 1. Fraction of the units to drop + for the linear transformation of the recurrent state. + seed: Random seed for dropout. + return_sequences: Boolean. Whether to return the last output + in the output sequence, or the full sequence. Default: `False`. + return_state: Boolean. Whether to return the last state in addition + to the output. Default: `False`. + go_backwards: Boolean (default: `False`). + If `True`, process the input sequence backwards and return the + reversed sequence. + stateful: Boolean (default False). If `True`, the last state + for each sample at index i in a batch will be used as initial + state for the sample of index i in the following batch. + unroll: Boolean (default: `False`). + If `True`, the network will be unrolled, + else a symbolic loop will be used. + Unrolling can speed-up a RNN, + although it tends to be more memory-intensive. + Unrolling is only suitable for short sequences. + + + Call arguments: + inputs: A 5D tensor. + mask: Binary tensor of shape `(samples, timesteps)` indicating whether a + given timestep should be masked. + training: Python boolean indicating whether the layer should behave in + training mode or in inference mode. + This is only relevant if `dropout` or `recurrent_dropout` are set. + initial_state: List of initial state tensors to be passed to the first + call of the cell. + + Input shape: + + - If `data_format='channels_first'`: + 5D tensor with shape: `(samples, time, channels, rows, cols)` + - If `data_format='channels_last'`: + 5D tensor with shape: `(samples, time, rows, cols, channels)` + + Output shape: + + - If `return_state`: a list of tensors. The first tensor is the output. + The remaining tensors are the last states, + each 4D tensor with shape: `(samples, filters, new_rows, new_cols)` if + `data_format='channels_first'` + or shape: `(samples, new_rows, new_cols, filters)` if + `data_format='channels_last'`. `rows` and `cols` values might have + changed due to padding. + - If `return_sequences`: 5D tensor with shape: `(samples, timesteps, + filters, new_rows, new_cols)` if data_format='channels_first' + or shape: `(samples, timesteps, new_rows, new_cols, filters)` if + `data_format='channels_last'`. + - Else, 4D tensor with shape: `(samples, filters, new_rows, new_cols)` if + `data_format='channels_first'` + or shape: `(samples, new_rows, new_cols, filters)` if + `data_format='channels_last'`. + + References: + + - [Shi et al., 2015](http://arxiv.org/abs/1506.04214v1) + (the current implementation does not include the feedback loop on the + cells output). + """ + + def __init__( + self, + filters, + kernel_size, + strides=1, + padding="valid", + data_format=None, + dilation_rate=1, + activation="tanh", + recurrent_activation="sigmoid", + use_bias=True, + kernel_initializer="glorot_uniform", + recurrent_initializer="orthogonal", + bias_initializer="zeros", + unit_forget_bias=True, + kernel_regularizer=None, + recurrent_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + recurrent_constraint=None, + bias_constraint=None, + dropout=0.0, + recurrent_dropout=0.0, + seed=None, + return_sequences=False, + return_state=False, + go_backwards=False, + stateful=False, + **kwargs, + ): + super().__init__( + rank=2, + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + recurrent_activation=recurrent_activation, + use_bias=use_bias, + kernel_initializer=kernel_initializer, + recurrent_initializer=recurrent_initializer, + bias_initializer=bias_initializer, + unit_forget_bias=unit_forget_bias, + kernel_regularizer=kernel_regularizer, + recurrent_regularizer=recurrent_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + kernel_constraint=kernel_constraint, + recurrent_constraint=recurrent_constraint, + bias_constraint=bias_constraint, + return_sequences=return_sequences, + return_state=return_state, + go_backwards=go_backwards, + stateful=stateful, + dropout=dropout, + recurrent_dropout=recurrent_dropout, + seed=seed, + **kwargs, + ) diff --git a/keras/src/layers/rnn/conv_lstm2d_test.py b/keras/src/layers/rnn/conv_lstm2d_test.py new file mode 100644 index 000000000000..b3846b64058c --- /dev/null +++ b/keras/src/layers/rnn/conv_lstm2d_test.py @@ -0,0 +1,89 @@ +import numpy as np +import pytest + +from keras.src import backend +from keras.src import initializers +from keras.src import layers +from keras.src import testing + + +class ConvLSTM2DTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_basics(self): + channels_last = backend.config.image_data_format() == "channels_last" + self.run_layer_test( + layers.ConvLSTM2D, + init_kwargs={"filters": 5, "kernel_size": 3, "padding": "same"}, + input_shape=(3, 2, 4, 4, 3) if channels_last else (3, 2, 3, 4, 4), + expected_output_shape=( + (3, 4, 4, 5) if channels_last else (3, 5, 4, 4) + ), + expected_num_trainable_weights=3, + expected_num_non_trainable_weights=0, + supports_masking=True, + ) + self.run_layer_test( + layers.ConvLSTM2D, + init_kwargs={ + "filters": 5, + "kernel_size": 3, + "padding": "valid", + "recurrent_dropout": 0.5, + }, + input_shape=(3, 2, 8, 8, 3) if channels_last else (3, 2, 3, 8, 8), + call_kwargs={"training": True}, + expected_output_shape=( + (3, 6, 6, 5) if channels_last else (3, 5, 6, 6) + ), + expected_num_trainable_weights=3, + expected_num_non_trainable_weights=0, + supports_masking=True, + ) + self.run_layer_test( + layers.ConvLSTM2D, + init_kwargs={ + "filters": 5, + "kernel_size": 3, + "padding": "valid", + "return_sequences": True, + }, + input_shape=(3, 2, 8, 8, 3) if channels_last else (3, 2, 3, 8, 8), + expected_output_shape=( + (3, 2, 6, 6, 5) if channels_last else (3, 2, 5, 6, 6) + ), + expected_num_trainable_weights=3, + expected_num_non_trainable_weights=0, + supports_masking=True, + ) + + def test_correctness(self): + sequence = ( + np.arange(480).reshape((2, 3, 4, 4, 5)).astype("float32") / 100 + ) + expected_output = np.array( + [ + [ + [[0.48694518, 0.48694518], [0.50237733, 0.50237733]], + [[0.5461202, 0.5461202], [0.5598283, 0.5598283]], + ], + [ + [[0.8661607, 0.8661607], [0.86909103, 0.86909103]], + [[0.8774414, 0.8774414], [0.8800861, 0.8800861]], + ], + ] + ) + if backend.config.image_data_format() == "channels_first": + sequence = sequence.transpose((0, 1, 4, 2, 3)) + expected_output = expected_output.transpose((0, 3, 1, 2)) + layer = layers.ConvLSTM2D( + filters=2, + kernel_size=3, + kernel_initializer=initializers.Constant(0.01), + recurrent_initializer=initializers.Constant(0.02), + bias_initializer=initializers.Constant(0.03), + ) + output = layer(sequence) + self.assertAllClose( + expected_output, + output, + ) diff --git a/keras/src/layers/rnn/conv_lstm3d.py b/keras/src/layers/rnn/conv_lstm3d.py new file mode 100644 index 000000000000..a36ed1dddf92 --- /dev/null +++ b/keras/src/layers/rnn/conv_lstm3d.py @@ -0,0 +1,183 @@ +from keras.src.api_export import keras_export +from keras.src.layers.rnn.conv_lstm import ConvLSTM + + +@keras_export("keras.layers.ConvLSTM3D") +class ConvLSTM3D(ConvLSTM): + """3D Convolutional LSTM. + + Similar to an LSTM layer, but the input transformations + and recurrent transformations are both convolutional. + + Args: + filters: int, the dimension of the output space (the number of filters + in the convolution). + kernel_size: int or tuple/list of 3 integers, specifying the size of the + convolution window. + strides: int or tuple/list of 3 integers, specifying the stride length + of the convolution. `strides > 1` is incompatible with + `dilation_rate > 1`. + padding: string, `"valid"` or `"same"` (case-insensitive). + `"valid"` means no padding. `"same"` results in padding evenly to + the left/right or up/down of the input such that output has the same + height/width dimension as the input. + data_format: string, either `"channels_last"` or `"channels_first"`. + The ordering of the dimensions in the inputs. `"channels_last"` + corresponds to inputs with shape `(batch, steps, features)` + while `"channels_first"` corresponds to inputs with shape + `(batch, features, steps)`. It defaults to the `image_data_format` + value found in your Keras config file at `~/.keras/keras.json`. + If you never set it, then it will be `"channels_last"`. + dilation_rate: int or tuple/list of 3 integers, specifying the dilation + rate to use for dilated convolution. + activation: Activation function to use. By default hyperbolic tangent + activation function is applied (`tanh(x)`). + recurrent_activation: Activation function to use for the recurrent step. + use_bias: Boolean, whether the layer uses a bias vector. + kernel_initializer: Initializer for the `kernel` weights matrix, + used for the linear transformation of the inputs. + recurrent_initializer: Initializer for the `recurrent_kernel` weights + matrix, used for the linear transformation of the recurrent state. + bias_initializer: Initializer for the bias vector. + unit_forget_bias: Boolean. If `True`, add 1 to the bias of the forget + gate at initialization. + Use in combination with `bias_initializer="zeros"`. + This is recommended in [Jozefowicz et al., 2015]( + http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf) + kernel_regularizer: Regularizer function applied to the `kernel` weights + matrix. + recurrent_regularizer: Regularizer function applied to the + `recurrent_kernel` weights matrix. + bias_regularizer: Regularizer function applied to the bias vector. + activity_regularizer: Regularizer function applied to. + kernel_constraint: Constraint function applied to the `kernel` weights + matrix. + recurrent_constraint: Constraint function applied to the + `recurrent_kernel` weights matrix. + bias_constraint: Constraint function applied to the bias vector. + dropout: Float between 0 and 1. Fraction of the units to drop for the + linear transformation of the inputs. + recurrent_dropout: Float between 0 and 1. Fraction of the units to drop + for the linear transformation of the recurrent state. + seed: Random seed for dropout. + return_sequences: Boolean. Whether to return the last output + in the output sequence, or the full sequence. Default: `False`. + return_state: Boolean. Whether to return the last state in addition + to the output. Default: `False`. + go_backwards: Boolean (default: `False`). + If `True`, process the input sequence backwards and return the + reversed sequence. + stateful: Boolean (default False). If `True`, the last state + for each sample at index i in a batch will be used as initial + state for the sample of index i in the following batch. + unroll: Boolean (default: `False`). + If `True`, the network will be unrolled, + else a symbolic loop will be used. + Unrolling can speed-up a RNN, + although it tends to be more memory-intensive. + Unrolling is only suitable for short sequences. + + + Call arguments: + inputs: A 6D tensor. + mask: Binary tensor of shape `(samples, timesteps)` indicating whether a + given timestep should be masked. + training: Python boolean indicating whether the layer should behave in + training mode or in inference mode. + This is only relevant if `dropout` or `recurrent_dropout` are set. + initial_state: List of initial state tensors to be passed to the first + call of the cell. + + Input shape: + + - If `data_format='channels_first'`: + 5D tensor with shape: `(samples, time, channels, *spatial_dims)` + - If `data_format='channels_last'`: + 5D tensor with shape: `(samples, time, *spatial_dims, channels)` + + Output shape: + + - If `return_state`: a list of tensors. The first tensor is the output. + The remaining tensors are the last states, + each 4D tensor with shape: `(samples, filters, *spatial_dims)` if + `data_format='channels_first'` + or shape: `(samples, *spatial_dims, filters)` if + `data_format='channels_last'`. + - If `return_sequences`: 5D tensor with shape: `(samples, timesteps, + filters, *spatial_dims)` if data_format='channels_first' + or shape: `(samples, timesteps, *spatial_dims, filters)` if + `data_format='channels_last'`. + - Else, 4D tensor with shape: `(samples, filters, *spatial_dims)` if + `data_format='channels_first'` + or shape: `(samples, *spatial_dims, filters)` if + `data_format='channels_last'`. + + References: + + - [Shi et al., 2015](http://arxiv.org/abs/1506.04214v1) + (the current implementation does not include the feedback loop on the + cells output). + """ + + def __init__( + self, + filters, + kernel_size, + strides=1, + padding="valid", + data_format=None, + dilation_rate=1, + activation="tanh", + recurrent_activation="sigmoid", + use_bias=True, + kernel_initializer="glorot_uniform", + recurrent_initializer="orthogonal", + bias_initializer="zeros", + unit_forget_bias=True, + kernel_regularizer=None, + recurrent_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + recurrent_constraint=None, + bias_constraint=None, + dropout=0.0, + recurrent_dropout=0.0, + seed=None, + return_sequences=False, + return_state=False, + go_backwards=False, + stateful=False, + **kwargs, + ): + super().__init__( + rank=3, + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + recurrent_activation=recurrent_activation, + use_bias=use_bias, + kernel_initializer=kernel_initializer, + recurrent_initializer=recurrent_initializer, + bias_initializer=bias_initializer, + unit_forget_bias=unit_forget_bias, + kernel_regularizer=kernel_regularizer, + recurrent_regularizer=recurrent_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + kernel_constraint=kernel_constraint, + recurrent_constraint=recurrent_constraint, + bias_constraint=bias_constraint, + return_sequences=return_sequences, + return_state=return_state, + go_backwards=go_backwards, + stateful=stateful, + dropout=dropout, + recurrent_dropout=recurrent_dropout, + seed=seed, + **kwargs, + ) diff --git a/keras/src/layers/rnn/conv_lstm3d_test.py b/keras/src/layers/rnn/conv_lstm3d_test.py new file mode 100644 index 000000000000..b6c23326539f --- /dev/null +++ b/keras/src/layers/rnn/conv_lstm3d_test.py @@ -0,0 +1,107 @@ +import numpy as np +import pytest + +from keras.src import backend +from keras.src import initializers +from keras.src import layers +from keras.src import testing + + +class ConvLSTM1DTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_basics(self): + channels_last = backend.config.image_data_format() == "channels_last" + self.run_layer_test( + layers.ConvLSTM3D, + init_kwargs={"filters": 5, "kernel_size": 3, "padding": "same"}, + input_shape=( + (3, 2, 4, 4, 4, 3) if channels_last else (3, 2, 3, 4, 4, 4) + ), + expected_output_shape=( + (3, 4, 4, 4, 5) if channels_last else (3, 5, 4, 4, 4) + ), + expected_num_trainable_weights=3, + expected_num_non_trainable_weights=0, + supports_masking=True, + ) + self.run_layer_test( + layers.ConvLSTM3D, + init_kwargs={ + "filters": 5, + "kernel_size": 3, + "padding": "valid", + "recurrent_dropout": 0.5, + }, + input_shape=( + (3, 2, 8, 8, 8, 3) if channels_last else (3, 2, 3, 8, 8, 8) + ), + call_kwargs={"training": True}, + expected_output_shape=( + (3, 6, 6, 6, 5) if channels_last else (3, 5, 6, 6, 6) + ), + expected_num_trainable_weights=3, + expected_num_non_trainable_weights=0, + supports_masking=True, + ) + self.run_layer_test( + layers.ConvLSTM3D, + init_kwargs={ + "filters": 5, + "kernel_size": 3, + "padding": "valid", + "return_sequences": True, + }, + input_shape=( + (3, 2, 8, 8, 8, 3) if channels_last else (3, 2, 3, 8, 8, 8) + ), + expected_output_shape=( + (3, 2, 6, 6, 6, 5) if channels_last else (3, 2, 5, 6, 6, 6) + ), + expected_num_trainable_weights=3, + expected_num_non_trainable_weights=0, + supports_masking=True, + ) + + def test_correctness(self): + sequence = ( + np.arange(1920).reshape((2, 3, 4, 4, 4, 5)).astype("float32") / 100 + ) + expected_output = np.array( + [ + [ + [ + [[0.99149036, 0.99149036], [0.99180907, 0.99180907]], + [[0.99258363, 0.99258363], [0.9927925, 0.9927925]], + ], + [ + [[0.99413764, 0.99413764], [0.99420583, 0.99420583]], + [[0.9943788, 0.9943788], [0.9944278, 0.9944278]], + ], + ], + [ + [ + [[0.9950547, 0.9950547], [0.9950547, 0.9950547]], + [[0.9950547, 0.9950547], [0.9950547, 0.9950547]], + ], + [ + [[0.9950547, 0.9950547], [0.9950547, 0.9950547]], + [[0.9950547, 0.9950547], [0.9950547, 0.9950547]], + ], + ], + ] + ) + if backend.config.image_data_format() == "channels_first": + sequence = sequence.transpose((0, 1, 5, 2, 3, 4)) + expected_output = expected_output.transpose((0, 4, 1, 2, 3)) + layer = layers.ConvLSTM3D( + filters=2, + kernel_size=3, + kernel_initializer=initializers.Constant(0.01), + recurrent_initializer=initializers.Constant(0.02), + bias_initializer=initializers.Constant(0.03), + ) + output = layer(sequence) + self.assertAllClose( + expected_output, + output, + ) diff --git a/keras/src/layers/rnn/conv_lstm_test.py b/keras/src/layers/rnn/conv_lstm_test.py new file mode 100644 index 000000000000..e66fed91b62c --- /dev/null +++ b/keras/src/layers/rnn/conv_lstm_test.py @@ -0,0 +1,57 @@ +import numpy as np + +from keras.src import backend +from keras.src import initializers +from keras.src import testing +from keras.src.layers.rnn.conv_lstm import ConvLSTM +from keras.src.layers.rnn.conv_lstm import ConvLSTMCell + + +class ConvLSTMCellTest(testing.TestCase): + def test_correctness(self): + x = np.arange(150).reshape((2, 5, 5, 3)).astype("float32") / 10 + s1 = np.arange(200).reshape((2, 5, 5, 4)).astype("float32") / 10 + s2 = np.arange(200).reshape((2, 5, 5, 4)).astype("float32") / 10 + + if backend.config.image_data_format() == "channels_first": + x = x.transpose((0, 3, 1, 2)) + s1 = s1.transpose((0, 3, 1, 2)) + s2 = s2.transpose((0, 3, 1, 2)) + layer = ConvLSTMCell( + rank=2, + filters=4, + kernel_size=3, + padding="same", + kernel_initializer=initializers.Constant(0.01), + recurrent_initializer=initializers.Constant(0.02), + ) + output = layer(x, [s1, s2]) + checksum_0 = np.sum(backend.convert_to_numpy(output[0])) + self.assertAllClose(checksum_0, 188.89502) + checksum_1 = np.sum(backend.convert_to_numpy(output[1][0])) + self.assertAllClose(checksum_1, 188.89502) + checksum_2 = np.sum(backend.convert_to_numpy(output[1][1])) + self.assertAllClose(checksum_2, 2170.444) + + +class ConvLSTMTest(testing.TestCase): + def test_correctness(self): + x = np.arange(450).reshape((2, 3, 5, 5, 3)).astype("float32") / 100 + s1 = np.arange(200).reshape((2, 5, 5, 4)).astype("float32") / 100 + s2 = np.arange(200).reshape((2, 5, 5, 4)).astype("float32") / 100 + + if backend.config.image_data_format() == "channels_first": + x = x.transpose((0, 1, 4, 2, 3)) + s1 = s1.transpose((0, 3, 1, 2)) + s2 = s2.transpose((0, 3, 1, 2)) + layer = ConvLSTM( + rank=2, + filters=4, + kernel_size=3, + padding="same", + kernel_initializer=initializers.Constant(0.01), + recurrent_initializer=initializers.Constant(0.02), + ) + output = layer(x, initial_state=[s1, s2]) + output = backend.convert_to_numpy(output) + self.assertAllClose(np.sum(output), 119.812454) diff --git a/keras/src/layers/rnn/dropout_rnn_cell.py b/keras/src/layers/rnn/dropout_rnn_cell.py new file mode 100644 index 000000000000..3dd39b0ca6b2 --- /dev/null +++ b/keras/src/layers/rnn/dropout_rnn_cell.py @@ -0,0 +1,66 @@ +from keras.src import backend +from keras.src import ops + + +class DropoutRNNCell: + """Object that holds dropout-related functionality for RNN cells. + + This class is not a standalone RNN cell. It suppose to be used with a RNN + cell by multiple inheritance. Any cell that mix with class should have + following fields: + + - `dropout`: a float number in the range `[0, 1]`. + Dropout rate for the input tensor. + - `recurrent_dropout`: a float number in the range `[0, 1]`. + Dropout rate for the recurrent connections. + - `seed_generator`, an instance of `backend.random.SeedGenerator`. + + This object will create and cache dropout masks, and reuse them for + all incoming steps, so that the same mask is used for every step. + """ + + def _create_dropout_mask(self, step_input, dropout_rate): + count = getattr(self, "dropout_mask_count", None) + ones = ops.ones_like(step_input) + if count is None: + return backend.random.dropout( + ones, rate=dropout_rate, seed=self.seed_generator + ) + else: + return [ + backend.random.dropout( + ones, rate=dropout_rate, seed=self.seed_generator + ) + for _ in range(count) + ] + + def get_dropout_mask(self, step_input): + if not hasattr(self, "_dropout_mask"): + self._dropout_mask = None + if self._dropout_mask is None and self.dropout > 0: + self._dropout_mask = self._create_dropout_mask( + step_input, self.dropout + ) + return self._dropout_mask + + def get_recurrent_dropout_mask(self, step_input): + if not hasattr(self, "_recurrent_dropout_mask"): + self._recurrent_dropout_mask = None + if self._recurrent_dropout_mask is None and self.recurrent_dropout > 0: + self._recurrent_dropout_mask = self._create_dropout_mask( + step_input, self.recurrent_dropout + ) + return self._recurrent_dropout_mask + + def reset_dropout_mask(self): + """Reset the cached dropout mask if any. + + The RNN layer invokes this in the `call()` method + so that the cached mask is cleared after calling `cell.call()`. The + mask should be cached across all timestep within the same batch, but + shouldn't be cached between batches. + """ + self._dropout_mask = None + + def reset_recurrent_dropout_mask(self): + self._recurrent_dropout_mask = None diff --git a/keras/src/layers/rnn/dropout_rnn_cell_test.py b/keras/src/layers/rnn/dropout_rnn_cell_test.py new file mode 100644 index 000000000000..cf94aa67fd52 --- /dev/null +++ b/keras/src/layers/rnn/dropout_rnn_cell_test.py @@ -0,0 +1,92 @@ +import pytest + +from keras.src import backend +from keras.src import layers +from keras.src import ops +from keras.src import testing +from keras.src.layers.rnn.dropout_rnn_cell import DropoutRNNCell + + +class RNNCellWithDropout(layers.Layer, DropoutRNNCell): + def __init__( + self, units, dropout=0.5, recurrent_dropout=0.5, seed=None, **kwargs + ): + super().__init__(**kwargs) + self.seed = seed + self.seed_generator = backend.random.SeedGenerator(seed) + self.units = units + self.state_size = units + self.dropout = dropout + self.recurrent_dropout = recurrent_dropout + + def build(self, input_shape): + self.kernel = self.add_weight( + shape=(input_shape[-1], self.units), + initializer="ones", + name="kernel", + ) + self.recurrent_kernel = self.add_weight( + shape=(self.units, self.units), + initializer="ones", + name="recurrent_kernel", + ) + + def call(self, inputs, states, training=False): + if training: + dp_mask = self.get_dropout_mask(inputs) + inputs = inputs * dp_mask + prev_output = states[0] + h = ops.matmul(inputs, self.kernel) + if training: + rdp_mask = self.get_recurrent_dropout_mask(prev_output) + prev_output = prev_output * rdp_mask + output = h + ops.matmul(prev_output, self.recurrent_kernel) + return output, [output] + + +class DropoutRNNCellTest(testing.TestCase): + def test_seed_tracking(self): + cell = RNNCellWithDropout(3, seed=1337) + self.assertEqual(len(cell.non_trainable_variables), 1) + layer = layers.RNN(cell) + self.assertEqual(len(layer.non_trainable_variables), 1) + + @pytest.mark.requires_trainable_backend + def test_basics(self): + self.run_layer_test( + layers.RNN, + init_kwargs={"cell": RNNCellWithDropout(5, seed=1337)}, + input_shape=(3, 2, 4), + call_kwargs={"training": True}, + expected_output_shape=(3, 5), + expected_num_trainable_weights=2, + expected_num_non_trainable_weights=0, + expected_num_non_trainable_variables=1, + supports_masking=True, + run_mixed_precision_check=False, + ) + + # manually set dtype to mixed_float16 to run mixed precision check + run_mixed_precision_check = True + if backend.backend() == "torch": + import torch + + run_mixed_precision_check = torch.cuda.is_available() + if run_mixed_precision_check: + self.run_layer_test( + layers.RNN, + init_kwargs={ + "cell": RNNCellWithDropout( + 5, seed=1337, dtype="mixed_float16" + ), + "dtype": "mixed_float16", + }, + input_shape=(3, 2, 4), + call_kwargs={"training": True}, + expected_output_shape=(3, 5), + expected_num_trainable_weights=2, + expected_num_non_trainable_weights=0, + expected_num_non_trainable_variables=1, + supports_masking=True, + run_mixed_precision_check=False, + ) diff --git a/keras/src/layers/rnn/gru.py b/keras/src/layers/rnn/gru.py new file mode 100644 index 000000000000..3a6abd2d1cbb --- /dev/null +++ b/keras/src/layers/rnn/gru.py @@ -0,0 +1,710 @@ +from keras.src import activations +from keras.src import backend +from keras.src import constraints +from keras.src import initializers +from keras.src import ops +from keras.src import regularizers +from keras.src import tree +from keras.src.api_export import keras_export +from keras.src.layers.input_spec import InputSpec +from keras.src.layers.layer import Layer +from keras.src.layers.rnn.dropout_rnn_cell import DropoutRNNCell +from keras.src.layers.rnn.rnn import RNN + + +@keras_export("keras.layers.GRUCell") +class GRUCell(Layer, DropoutRNNCell): + """Cell class for the GRU layer. + + This class processes one step within the whole time sequence input, whereas + `keras.layer.GRU` processes the whole sequence. + + Args: + units: Positive integer, dimensionality of the output space. + activation: Activation function to use. Default: hyperbolic tangent + (`tanh`). If you pass None, no activation is applied + (ie. "linear" activation: `a(x) = x`). + recurrent_activation: Activation function to use for the recurrent step. + Default: sigmoid (`sigmoid`). If you pass `None`, no activation is + applied (ie. "linear" activation: `a(x) = x`). + use_bias: Boolean, (default `True`), whether the layer + should use a bias vector. + kernel_initializer: Initializer for the `kernel` weights matrix, + used for the linear transformation of the inputs. Default: + `"glorot_uniform"`. + recurrent_initializer: Initializer for the `recurrent_kernel` + weights matrix, used for the linear transformation + of the recurrent state. Default: `"orthogonal"`. + bias_initializer: Initializer for the bias vector. Default: `"zeros"`. + kernel_regularizer: Regularizer function applied to the `kernel` weights + matrix. Default: `None`. + recurrent_regularizer: Regularizer function applied to the + `recurrent_kernel` weights matrix. Default: `None`. + bias_regularizer: Regularizer function applied to the bias vector. + Default: `None`. + kernel_constraint: Constraint function applied to the `kernel` weights + matrix. Default: `None`. + recurrent_constraint: Constraint function applied to the + `recurrent_kernel` weights matrix. Default: `None`. + bias_constraint: Constraint function applied to the bias vector. + Default: `None`. + dropout: Float between 0 and 1. Fraction of the units to drop for the + linear transformation of the inputs. Default: 0. + recurrent_dropout: Float between 0 and 1. Fraction of the units to drop + for the linear transformation of the recurrent state. Default: 0. + reset_after: GRU convention (whether to apply reset gate after or + before matrix multiplication). False = "before", + True = "after" (default and cuDNN compatible). + seed: Random seed for dropout. + + Call arguments: + inputs: A 2D tensor, with shape `(batch, features)`. + states: A 2D tensor with shape `(batch, units)`, which is the state + from the previous time step. + training: Python boolean indicating whether the layer should behave in + training mode or in inference mode. Only relevant when `dropout` or + `recurrent_dropout` is used. + + Example: + + >>> inputs = np.random.random((32, 10, 8)) + >>> rnn = keras.layers.RNN(keras.layers.GRUCell(4)) + >>> output = rnn(inputs) + >>> output.shape + (32, 4) + >>> rnn = keras.layers.RNN( + ... keras.layers.GRUCell(4), + ... return_sequences=True, + ... return_state=True) + >>> whole_sequence_output, final_state = rnn(inputs) + >>> whole_sequence_output.shape + (32, 10, 4) + >>> final_state.shape + (32, 4) + """ + + def __init__( + self, + units, + activation="tanh", + recurrent_activation="sigmoid", + use_bias=True, + kernel_initializer="glorot_uniform", + recurrent_initializer="orthogonal", + bias_initializer="zeros", + kernel_regularizer=None, + recurrent_regularizer=None, + bias_regularizer=None, + kernel_constraint=None, + recurrent_constraint=None, + bias_constraint=None, + dropout=0.0, + recurrent_dropout=0.0, + reset_after=True, + seed=None, + **kwargs, + ): + if units <= 0: + raise ValueError( + "Received an invalid value for argument `units`, " + f"expected a positive integer, got {units}." + ) + implementation = kwargs.pop("implementation", 2) + super().__init__(**kwargs) + self.implementation = implementation + self.units = units + self.activation = activations.get(activation) + self.recurrent_activation = activations.get(recurrent_activation) + self.use_bias = use_bias + + self.kernel_initializer = initializers.get(kernel_initializer) + self.recurrent_initializer = initializers.get(recurrent_initializer) + self.bias_initializer = initializers.get(bias_initializer) + + self.kernel_regularizer = regularizers.get(kernel_regularizer) + self.recurrent_regularizer = regularizers.get(recurrent_regularizer) + self.bias_regularizer = regularizers.get(bias_regularizer) + + self.kernel_constraint = constraints.get(kernel_constraint) + self.recurrent_constraint = constraints.get(recurrent_constraint) + self.bias_constraint = constraints.get(bias_constraint) + + self.dropout = min(1.0, max(0.0, dropout)) + self.recurrent_dropout = min(1.0, max(0.0, recurrent_dropout)) + if self.recurrent_dropout != 0.0: + self.implementation = 1 + if self.implementation == 1: + self.dropout_mask_count = 3 + self.seed = seed + self.seed_generator = backend.random.SeedGenerator(seed=seed) + + self.reset_after = reset_after + self.state_size = self.units + self.output_size = self.units + + def build(self, input_shape): + super().build(input_shape) + input_dim = input_shape[-1] + self.kernel = self.add_weight( + shape=(input_dim, self.units * 3), + name="kernel", + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, + constraint=self.kernel_constraint, + ) + self.recurrent_kernel = self.add_weight( + shape=(self.units, self.units * 3), + name="recurrent_kernel", + initializer=self.recurrent_initializer, + regularizer=self.recurrent_regularizer, + constraint=self.recurrent_constraint, + ) + + if self.use_bias: + if not self.reset_after: + bias_shape = (3 * self.units,) + else: + # separate biases for input and recurrent kernels + # Note: the shape is intentionally different from CuDNNGRU + # biases `(2 * 3 * self.units,)`, so that we can distinguish the + # classes when loading and converting saved weights. + bias_shape = (2, 3 * self.units) + self.bias = self.add_weight( + shape=bias_shape, + name="bias", + initializer=self.bias_initializer, + regularizer=self.bias_regularizer, + constraint=self.bias_constraint, + ) + else: + self.bias = None + + def call(self, inputs, states, training=False): + h_tm1 = ( + states[0] if tree.is_nested(states) else states + ) # previous state + + if self.use_bias: + if not self.reset_after: + input_bias, recurrent_bias = self.bias, None + else: + input_bias, recurrent_bias = ( + ops.squeeze(e, axis=0) + for e in ops.split(self.bias, self.bias.shape[0], axis=0) + ) + + if self.implementation == 1: + if training and 0.0 < self.dropout < 1.0: + dp_mask = self.get_dropout_mask(inputs) + inputs_z = inputs * dp_mask[0] + inputs_r = inputs * dp_mask[1] + inputs_h = inputs * dp_mask[2] + else: + inputs_z = inputs + inputs_r = inputs + inputs_h = inputs + + x_z = ops.matmul(inputs_z, self.kernel[:, : self.units]) + x_r = ops.matmul( + inputs_r, self.kernel[:, self.units : self.units * 2] + ) + x_h = ops.matmul(inputs_h, self.kernel[:, self.units * 2 :]) + + if self.use_bias: + x_z += input_bias[: self.units] + x_r += input_bias[self.units : self.units * 2] + x_h += input_bias[self.units * 2 :] + + if training and 0.0 < self.recurrent_dropout < 1.0: + rec_dp_mask = self.get_recurrent_dropout_mask(h_tm1) + h_tm1_z = h_tm1 * rec_dp_mask[0] + h_tm1_r = h_tm1 * rec_dp_mask[1] + h_tm1_h = h_tm1 * rec_dp_mask[2] + else: + h_tm1_z = h_tm1 + h_tm1_r = h_tm1 + h_tm1_h = h_tm1 + + recurrent_z = ops.matmul( + h_tm1_z, self.recurrent_kernel[:, : self.units] + ) + recurrent_r = ops.matmul( + h_tm1_r, self.recurrent_kernel[:, self.units : self.units * 2] + ) + if self.reset_after and self.use_bias: + recurrent_z += recurrent_bias[: self.units] + recurrent_r += recurrent_bias[self.units : self.units * 2] + + z = self.recurrent_activation(x_z + recurrent_z) + r = self.recurrent_activation(x_r + recurrent_r) + + # reset gate applied after/before matrix multiplication + if self.reset_after: + recurrent_h = ops.matmul( + h_tm1_h, self.recurrent_kernel[:, self.units * 2 :] + ) + if self.use_bias: + recurrent_h += recurrent_bias[self.units * 2 :] + recurrent_h = r * recurrent_h + else: + recurrent_h = ops.matmul( + r * h_tm1_h, self.recurrent_kernel[:, self.units * 2 :] + ) + + hh = self.activation(x_h + recurrent_h) + else: + if training and 0.0 < self.dropout < 1.0: + dp_mask = self.get_dropout_mask(inputs) + inputs = inputs * dp_mask + + # inputs projected by all gate matrices at once + matrix_x = ops.matmul(inputs, self.kernel) + if self.use_bias: + # biases: bias_z_i, bias_r_i, bias_h_i + matrix_x = ops.add(matrix_x, input_bias) + + x_z, x_r, x_h = ops.split(matrix_x, 3, axis=-1) + + if self.reset_after: + # hidden state projected by all gate matrices at once + matrix_inner = ops.matmul(h_tm1, self.recurrent_kernel) + if self.use_bias: + matrix_inner += recurrent_bias + else: + # hidden state projected separately for update/reset and new + matrix_inner = ops.matmul( + h_tm1, self.recurrent_kernel[:, : 2 * self.units] + ) + + recurrent_z = matrix_inner[:, : self.units] + recurrent_r = matrix_inner[:, self.units : self.units * 2] + recurrent_h = matrix_inner[:, self.units * 2 :] + + z = self.recurrent_activation(x_z + recurrent_z) + r = self.recurrent_activation(x_r + recurrent_r) + + if self.reset_after: + recurrent_h = r * recurrent_h + else: + recurrent_h = ops.matmul( + r * h_tm1, self.recurrent_kernel[:, 2 * self.units :] + ) + + hh = self.activation(x_h + recurrent_h) + + # previous and candidate state mixed by update gate + h = z * h_tm1 + (1 - z) * hh + new_state = [h] if tree.is_nested(states) else h + return h, new_state + + def get_config(self): + config = { + "units": self.units, + "activation": activations.serialize(self.activation), + "recurrent_activation": activations.serialize( + self.recurrent_activation + ), + "use_bias": self.use_bias, + "kernel_initializer": initializers.serialize( + self.kernel_initializer + ), + "recurrent_initializer": initializers.serialize( + self.recurrent_initializer + ), + "bias_initializer": initializers.serialize(self.bias_initializer), + "kernel_regularizer": regularizers.serialize( + self.kernel_regularizer + ), + "recurrent_regularizer": regularizers.serialize( + self.recurrent_regularizer + ), + "bias_regularizer": regularizers.serialize(self.bias_regularizer), + "kernel_constraint": constraints.serialize(self.kernel_constraint), + "recurrent_constraint": constraints.serialize( + self.recurrent_constraint + ), + "bias_constraint": constraints.serialize(self.bias_constraint), + "dropout": self.dropout, + "recurrent_dropout": self.recurrent_dropout, + "reset_after": self.reset_after, + "seed": self.seed, + } + base_config = super().get_config() + return {**base_config, **config} + + def get_initial_state(self, batch_size=None): + return [ + ops.zeros((batch_size, self.state_size), dtype=self.compute_dtype) + ] + + +@keras_export("keras.layers.GRU") +class GRU(RNN): + """Gated Recurrent Unit - Cho et al. 2014. + + Based on available runtime hardware and constraints, this layer + will choose different implementations (cuDNN-based or backend-native) + to maximize the performance. If a GPU is available and all + the arguments to the layer meet the requirement of the cuDNN kernel + (see below for details), the layer will use a fast cuDNN implementation + when using the TensorFlow backend. + + The requirements to use the cuDNN implementation are: + + 1. `activation` == `tanh` + 2. `recurrent_activation` == `sigmoid` + 3. `recurrent_dropout` == 0 + 4. `unroll` is `False` + 5. `use_bias` is `True` + 6. `reset_after` is `True` + 7. Inputs, if use masking, are strictly right-padded. + 8. Eager execution is enabled in the outermost context. + + There are two variants of the GRU implementation. The default one is based + on [v3](https://arxiv.org/abs/1406.1078v3) and has reset gate applied to + hidden state before matrix multiplication. The other one is based on + [original](https://arxiv.org/abs/1406.1078v1) and has the order reversed. + + The second variant is compatible with CuDNNGRU (GPU-only) and allows + inference on CPU. Thus it has separate biases for `kernel` and + `recurrent_kernel`. To use this variant, set `reset_after=True` and + `recurrent_activation='sigmoid'`. + + For example: + + >>> inputs = np.random.random((32, 10, 8)) + >>> gru = keras.layers.GRU(4) + >>> output = gru(inputs) + >>> output.shape + (32, 4) + >>> gru = keras.layers.GRU(4, return_sequences=True, return_state=True) + >>> whole_sequence_output, final_state = gru(inputs) + >>> whole_sequence_output.shape + (32, 10, 4) + >>> final_state.shape + (32, 4) + + Args: + units: Positive integer, dimensionality of the output space. + activation: Activation function to use. + Default: hyperbolic tangent (`tanh`). + If you pass `None`, no activation is applied + (ie. "linear" activation: `a(x) = x`). + recurrent_activation: Activation function to use + for the recurrent step. + Default: sigmoid (`sigmoid`). + If you pass `None`, no activation is applied + (ie. "linear" activation: `a(x) = x`). + use_bias: Boolean, (default `True`), whether the layer + should use a bias vector. + kernel_initializer: Initializer for the `kernel` weights matrix, + used for the linear transformation of the inputs. Default: + `"glorot_uniform"`. + recurrent_initializer: Initializer for the `recurrent_kernel` + weights matrix, used for the linear transformation of the recurrent + state. Default: `"orthogonal"`. + bias_initializer: Initializer for the bias vector. Default: `"zeros"`. + kernel_regularizer: Regularizer function applied to the `kernel` weights + matrix. Default: `None`. + recurrent_regularizer: Regularizer function applied to the + `recurrent_kernel` weights matrix. Default: `None`. + bias_regularizer: Regularizer function applied to the bias vector. + Default: `None`. + activity_regularizer: Regularizer function applied to the output of the + layer (its "activation"). Default: `None`. + kernel_constraint: Constraint function applied to the `kernel` weights + matrix. Default: `None`. + recurrent_constraint: Constraint function applied to the + `recurrent_kernel` weights matrix. Default: `None`. + bias_constraint: Constraint function applied to the bias vector. + Default: `None`. + dropout: Float between 0 and 1. Fraction of the units to drop for the + linear transformation of the inputs. Default: 0. + recurrent_dropout: Float between 0 and 1. Fraction of the units to drop + for the linear transformation of the recurrent state. Default: 0. + seed: Random seed for dropout. + return_sequences: Boolean. Whether to return the last output + in the output sequence, or the full sequence. Default: `False`. + return_state: Boolean. Whether to return the last state in addition + to the output. Default: `False`. + go_backwards: Boolean (default `False`). + If `True`, process the input sequence backwards and return the + reversed sequence. + stateful: Boolean (default: `False`). If `True`, the last state + for each sample at index i in a batch will be used as initial + state for the sample of index i in the following batch. + unroll: Boolean (default: `False`). + If `True`, the network will be unrolled, + else a symbolic loop will be used. + Unrolling can speed-up a RNN, + although it tends to be more memory-intensive. + Unrolling is only suitable for short sequences. + reset_after: GRU convention (whether to apply reset gate after or + before matrix multiplication). `False` is `"before"`, + `True` is `"after"` (default and cuDNN compatible). + use_cudnn: Whether to use a cuDNN-backed implementation. `"auto"` will + attempt to use cuDNN when feasible, and will fallback to the + default implementation if not. + + Call arguments: + inputs: A 3D tensor, with shape `(batch, timesteps, feature)`. + mask: Binary tensor of shape `(samples, timesteps)` indicating whether + a given timestep should be masked (optional). + An individual `True` entry indicates that the corresponding timestep + should be utilized, while a `False` entry indicates that the + corresponding timestep should be ignored. Defaults to `None`. + training: Python boolean indicating whether the layer should behave in + training mode or in inference mode. This argument is passed to the + cell when calling it. This is only relevant if `dropout` or + `recurrent_dropout` is used (optional). Defaults to `None`. + initial_state: List of initial state tensors to be passed to the first + call of the cell (optional, `None` causes creation + of zero-filled initial state tensors). Defaults to `None`. + """ + + def __init__( + self, + units, + activation="tanh", + recurrent_activation="sigmoid", + use_bias=True, + kernel_initializer="glorot_uniform", + recurrent_initializer="orthogonal", + bias_initializer="zeros", + kernel_regularizer=None, + recurrent_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + recurrent_constraint=None, + bias_constraint=None, + dropout=0.0, + recurrent_dropout=0.0, + seed=None, + return_sequences=False, + return_state=False, + go_backwards=False, + stateful=False, + unroll=False, + reset_after=True, + use_cudnn="auto", + **kwargs, + ): + cell = GRUCell( + units, + activation=activation, + recurrent_activation=recurrent_activation, + use_bias=use_bias, + kernel_initializer=kernel_initializer, + recurrent_initializer=recurrent_initializer, + bias_initializer=bias_initializer, + kernel_regularizer=kernel_regularizer, + recurrent_regularizer=recurrent_regularizer, + bias_regularizer=bias_regularizer, + kernel_constraint=kernel_constraint, + recurrent_constraint=recurrent_constraint, + bias_constraint=bias_constraint, + dropout=dropout, + recurrent_dropout=recurrent_dropout, + reset_after=reset_after, + dtype=kwargs.get("dtype", None), + trainable=kwargs.get("trainable", True), + name="gru_cell", + seed=seed, + implementation=kwargs.pop("implementation", 2), + ) + super().__init__( + cell, + return_sequences=return_sequences, + return_state=return_state, + go_backwards=go_backwards, + stateful=stateful, + unroll=unroll, + activity_regularizer=activity_regularizer, + **kwargs, + ) + self.input_spec = InputSpec(ndim=3) + if use_cudnn not in ("auto", True, False): + raise ValueError( + "Invalid valid received for argument `use_cudnn`. " + "Expected one of {'auto', True, False}. " + f"Received: use_cudnn={use_cudnn}" + ) + self.use_cudnn = use_cudnn + if ( + backend.backend() == "tensorflow" + and backend.cudnn_ok( + cell.activation, + cell.recurrent_activation, + self.unroll, + cell.use_bias, + reset_after=reset_after, + ) + and use_cudnn in (True, "auto") + ): + self.supports_jit = False + + def inner_loop(self, sequences, initial_state, mask, training=False): + if tree.is_nested(initial_state): + initial_state = initial_state[0] + if tree.is_nested(mask): + mask = mask[0] + if self.use_cudnn in ("auto", True): + if not self.recurrent_dropout: + try: + if training and self.dropout: + dp_mask = self.cell.get_dropout_mask(sequences[:, 0, :]) + dp_mask = ops.expand_dims(dp_mask, axis=1) + dp_mask = ops.broadcast_to( + dp_mask, ops.shape(sequences) + ) + dp_sequences = sequences * dp_mask + else: + dp_sequences = sequences + # Backends are allowed to specify (optionally) optimized + # implementation of the inner GRU loop. In the case of + # TF for instance, it will leverage cuDNN when feasible, and + # it will raise NotImplementedError otherwise. + out = backend.gru( + dp_sequences, + initial_state, + mask, + kernel=self.cell.kernel, + recurrent_kernel=self.cell.recurrent_kernel, + bias=self.cell.bias, + activation=self.cell.activation, + recurrent_activation=self.cell.recurrent_activation, + return_sequences=self.return_sequences, + go_backwards=self.go_backwards, + unroll=self.unroll, + reset_after=self.cell.reset_after, + ) + # We disable jit_compile for the model in this case, + # since cuDNN ops aren't XLA compatible. + if backend.backend() == "tensorflow": + self.supports_jit = False + return out + except NotImplementedError: + pass + if self.use_cudnn is True: + raise ValueError( + "use_cudnn=True was specified, " + "but cuDNN is not supported for this layer configuration " + "with this backend. Pass use_cudnn='auto' to fallback " + "to a non-cuDNN implementation." + ) + return super().inner_loop( + sequences, initial_state, mask=mask, training=training + ) + + def call(self, sequences, initial_state=None, mask=None, training=False): + return super().call( + sequences, mask=mask, training=training, initial_state=initial_state + ) + + @property + def units(self): + return self.cell.units + + @property + def activation(self): + return self.cell.activation + + @property + def recurrent_activation(self): + return self.cell.recurrent_activation + + @property + def use_bias(self): + return self.cell.use_bias + + @property + def kernel_initializer(self): + return self.cell.kernel_initializer + + @property + def recurrent_initializer(self): + return self.cell.recurrent_initializer + + @property + def bias_initializer(self): + return self.cell.bias_initializer + + @property + def kernel_regularizer(self): + return self.cell.kernel_regularizer + + @property + def recurrent_regularizer(self): + return self.cell.recurrent_regularizer + + @property + def bias_regularizer(self): + return self.cell.bias_regularizer + + @property + def kernel_constraint(self): + return self.cell.kernel_constraint + + @property + def recurrent_constraint(self): + return self.cell.recurrent_constraint + + @property + def bias_constraint(self): + return self.cell.bias_constraint + + @property + def dropout(self): + return self.cell.dropout + + @property + def recurrent_dropout(self): + return self.cell.recurrent_dropout + + @property + def reset_after(self): + return self.cell.reset_after + + def get_config(self): + config = { + "units": self.units, + "activation": activations.serialize(self.activation), + "recurrent_activation": activations.serialize( + self.recurrent_activation + ), + "use_bias": self.use_bias, + "kernel_initializer": initializers.serialize( + self.kernel_initializer + ), + "recurrent_initializer": initializers.serialize( + self.recurrent_initializer + ), + "bias_initializer": initializers.serialize(self.bias_initializer), + "kernel_regularizer": regularizers.serialize( + self.kernel_regularizer + ), + "recurrent_regularizer": regularizers.serialize( + self.recurrent_regularizer + ), + "bias_regularizer": regularizers.serialize(self.bias_regularizer), + "activity_regularizer": regularizers.serialize( + self.activity_regularizer + ), + "kernel_constraint": constraints.serialize(self.kernel_constraint), + "recurrent_constraint": constraints.serialize( + self.recurrent_constraint + ), + "bias_constraint": constraints.serialize(self.bias_constraint), + "dropout": self.dropout, + "recurrent_dropout": self.recurrent_dropout, + "reset_after": self.reset_after, + "seed": self.cell.seed, + } + base_config = super().get_config() + del base_config["cell"] + return {**base_config, **config} + + @classmethod + def from_config(cls, config): + return cls(**config) diff --git a/keras/src/layers/rnn/gru_test.py b/keras/src/layers/rnn/gru_test.py new file mode 100644 index 000000000000..7fc0d6c35b7e --- /dev/null +++ b/keras/src/layers/rnn/gru_test.py @@ -0,0 +1,356 @@ +import numpy as np +import pytest +from absl.testing import parameterized + +from keras.src import initializers +from keras.src import layers +from keras.src import testing + + +class GRUTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_basics(self): + self.run_layer_test( + layers.GRU, + init_kwargs={"units": 3, "dropout": 0.5}, + input_shape=(3, 2, 4), + call_kwargs={"training": True}, + expected_output_shape=(3, 3), + expected_num_trainable_weights=3, + expected_num_non_trainable_weights=0, + supports_masking=True, + ) + self.run_layer_test( + layers.GRU, + init_kwargs={"units": 3, "dropout": 0.5, "recurrent_dropout": 0.5}, + input_shape=(3, 2, 4), + call_kwargs={"training": True}, + expected_output_shape=(3, 3), + expected_num_trainable_weights=3, + expected_num_non_trainable_weights=0, + supports_masking=True, + ) + self.run_layer_test( + layers.GRU, + init_kwargs={ + "units": 3, + "return_sequences": True, + "bias_regularizer": "l1", + "kernel_regularizer": "l2", + "recurrent_regularizer": "l2", + }, + input_shape=(3, 2, 4), + expected_output_shape=(3, 2, 3), + expected_num_losses=3, + expected_num_trainable_weights=3, + expected_num_non_trainable_weights=0, + supports_masking=True, + ) + + @parameterized.parameters([1, 2]) + def test_correctness(self, implementation): + sequence = np.arange(72).reshape((3, 6, 4)).astype("float32") + layer = layers.GRU( + 3, + kernel_initializer=initializers.Constant(0.01), + recurrent_initializer=initializers.Constant(0.02), + bias_initializer=initializers.Constant(0.03), + ) + output = layer(sequence) + self.assertAllClose( + np.array( + [ + [0.5217289, 0.5217289, 0.5217289], + [0.6371659, 0.6371659, 0.6371659], + [0.39384964, 0.39384964, 0.3938496], + ] + ), + output, + ) + + layer = layers.GRU( + 3, + kernel_initializer=initializers.Constant(0.01), + recurrent_initializer=initializers.Constant(0.02), + bias_initializer=initializers.Constant(0.03), + go_backwards=True, + ) + output = layer(sequence) + self.assertAllClose( + np.array( + [ + [0.24406259, 0.24406259, 0.24406259], + [0.611516, 0.611516, 0.611516], + [0.3928808, 0.3928808, 0.3928808], + ] + ), + output, + ) + + layer = layers.GRU( + 3, + kernel_initializer=initializers.Constant(0.01), + recurrent_initializer=initializers.Constant(0.02), + bias_initializer=initializers.Constant(0.03), + unroll=True, + ) + output = layer(sequence) + self.assertAllClose( + np.array( + [ + [0.5217289, 0.5217289, 0.5217289], + [0.6371659, 0.6371659, 0.6371659], + [0.39384964, 0.39384964, 0.3938496], + ] + ), + output, + ) + + layer = layers.GRU( + 3, + kernel_initializer=initializers.Constant(0.01), + recurrent_initializer=initializers.Constant(0.02), + bias_initializer=initializers.Constant(0.03), + reset_after=False, + ) + output = layer(sequence) + self.assertAllClose( + np.array( + [ + [0.51447755, 0.51447755, 0.51447755], + [0.6426879, 0.6426879, 0.6426879], + [0.40208298, 0.40208298, 0.40208298], + ] + ), + output, + ) + + layer = layers.GRU( + 3, + kernel_initializer=initializers.Constant(0.01), + recurrent_initializer=initializers.Constant(0.02), + bias_initializer=initializers.Constant(0.03), + use_bias=False, + ) + output = layer(sequence) + self.assertAllClose( + np.array( + [ + [0.49988455, 0.49988455, 0.49988455], + [0.64701194, 0.64701194, 0.64701194], + [0.4103359, 0.4103359, 0.4103359], + ] + ), + output, + ) + + def test_statefulness(self): + sequence = np.arange(24).reshape((2, 3, 4)).astype("float32") + layer = layers.GRU( + 4, + stateful=True, + kernel_initializer=initializers.Constant(0.01), + recurrent_initializer=initializers.Constant(0.02), + bias_initializer=initializers.Constant(0.03), + ) + layer(sequence) + output = layer(sequence) + self.assertAllClose( + np.array( + [ + [0.29542392, 0.29542392, 0.29542392, 0.29542392], + [0.5885018, 0.5885018, 0.5885018, 0.5885018], + ] + ), + output, + ) + layer.reset_state() + layer(sequence) + output = layer(sequence) + self.assertAllClose( + np.array( + [ + [0.29542392, 0.29542392, 0.29542392, 0.29542392], + [0.5885018, 0.5885018, 0.5885018, 0.5885018], + ] + ), + output, + ) + + def test_pass_initial_state(self): + sequence = np.arange(24).reshape((2, 4, 3)).astype("float32") + initial_state = np.arange(4).reshape((2, 2)).astype("float32") + layer = layers.GRU( + 2, + kernel_initializer=initializers.Constant(0.01), + recurrent_initializer=initializers.Constant(0.02), + bias_initializer=initializers.Constant(0.03), + ) + output = layer(sequence, initial_state=initial_state) + self.assertAllClose( + np.array([[0.23774096, 0.33508456], [0.83659905, 1.0227708]]), + output, + ) + + layer = layers.GRU( + 2, + kernel_initializer=initializers.Constant(0.01), + recurrent_initializer=initializers.Constant(0.02), + bias_initializer=initializers.Constant(0.03), + go_backwards=True, + ) + output = layer(sequence, initial_state=initial_state) + self.assertAllClose( + np.array([[0.13486053, 0.23261218], [0.78257304, 0.9691353]]), + output, + ) + + def test_pass_return_state(self): + sequence = np.arange(24).reshape((2, 4, 3)).astype("float32") + initial_state = np.arange(4).reshape((2, 2)).astype("float32") + + # Test with go_backwards=False + layer = layers.GRU( + 2, + kernel_initializer=initializers.Constant(0.01), + recurrent_initializer=initializers.Constant(0.02), + bias_initializer=initializers.Constant(0.03), + return_state=True, + ) + output, state = layer(sequence, initial_state=initial_state) + self.assertAllClose( + np.array([[0.23774096, 0.33508456], [0.83659905, 1.0227708]]), + output, + ) + self.assertAllClose(output, state) + + # Test with go_backwards=True + layer = layers.GRU( + 2, + kernel_initializer=initializers.Constant(0.01), + recurrent_initializer=initializers.Constant(0.02), + bias_initializer=initializers.Constant(0.03), + return_state=True, + go_backwards=True, + ) + output, state = layer(sequence, initial_state=initial_state) + self.assertAllClose( + np.array([[0.13486053, 0.23261218], [0.78257304, 0.9691353]]), + output, + ) + self.assertAllClose(output, state) + + def test_masking(self): + sequence = np.arange(24).reshape((2, 4, 3)).astype("float32") + mask = np.array([[True, True, False, True], [True, False, False, True]]) + layer = layers.GRU( + 2, + kernel_initializer=initializers.Constant(0.01), + recurrent_initializer=initializers.Constant(0.02), + bias_initializer=initializers.Constant(0.03), + unroll=True, + ) + output = layer(sequence, mask=mask) + self.assertAllClose( + np.array([[0.19393763, 0.19393763], [0.30818558, 0.30818558]]), + output, + ) + + layer = layers.GRU( + 2, + kernel_initializer=initializers.Constant(0.01), + recurrent_initializer=initializers.Constant(0.02), + bias_initializer=initializers.Constant(0.03), + return_sequences=True, + ) + output = layer(sequence, mask=mask) + self.assertAllClose( + np.array( + [ + [0.03606692, 0.03606692], + [0.09497581, 0.09497581], + [0.09497581, 0.09497581], + [0.19393763, 0.19393763], + ], + ), + output[0], + ) + self.assertAllClose( + np.array( + [ + [0.16051409, 0.16051409], + [0.16051409, 0.16051409], + [0.16051409, 0.16051409], + [0.30818558, 0.30818558], + ], + ), + output[1], + ) + + layer = layers.GRU( + 2, + kernel_initializer=initializers.Constant(0.01), + recurrent_initializer=initializers.Constant(0.02), + bias_initializer=initializers.Constant(0.03), + return_sequences=True, + zero_output_for_mask=True, + ) + output = layer(sequence, mask=mask) + self.assertAllClose( + np.array( + [ + [0.03606692, 0.03606692], + [0.09497581, 0.09497581], + [0.0, 0.0], + [0.19393763, 0.19393763], + ], + ), + output[0], + ) + self.assertAllClose( + np.array( + [ + [0.16051409, 0.16051409], + [0.0, 0.0], + [0.0, 0.0], + [0.30818558, 0.30818558], + ], + ), + output[1], + ) + + layer = layers.GRU( + 2, + kernel_initializer=initializers.Constant(0.01), + recurrent_initializer=initializers.Constant(0.02), + bias_initializer=initializers.Constant(0.03), + go_backwards=True, + ) + output = layer(sequence, mask=mask) + self.assertAllClose( + np.array([[0.11669192, 0.11669192], [0.28380975, 0.28380975]]), + output, + ) + + def test_legacy_implementation_argument(self): + sequence = np.arange(72).reshape((3, 6, 4)).astype("float32") + layer = layers.GRU( + 3, + kernel_initializer=initializers.Constant(0.01), + recurrent_initializer=initializers.Constant(0.02), + bias_initializer=initializers.Constant(0.03), + ) + config = layer.get_config() + config["implementation"] = 0 # Add legacy argument + layer = layers.GRU.from_config(config) + output = layer(sequence) + self.assertAllClose( + np.array( + [ + [0.5217289, 0.5217289, 0.5217289], + [0.6371659, 0.6371659, 0.6371659], + [0.39384964, 0.39384964, 0.3938496], + ] + ), + output, + ) diff --git a/keras/src/layers/rnn/lstm.py b/keras/src/layers/rnn/lstm.py new file mode 100644 index 000000000000..32a426a8ee50 --- /dev/null +++ b/keras/src/layers/rnn/lstm.py @@ -0,0 +1,692 @@ +from keras.src import activations +from keras.src import backend +from keras.src import constraints +from keras.src import initializers +from keras.src import ops +from keras.src import regularizers +from keras.src import tree +from keras.src.api_export import keras_export +from keras.src.layers.input_spec import InputSpec +from keras.src.layers.layer import Layer +from keras.src.layers.rnn.dropout_rnn_cell import DropoutRNNCell +from keras.src.layers.rnn.rnn import RNN + + +@keras_export("keras.layers.LSTMCell") +class LSTMCell(Layer, DropoutRNNCell): + """Cell class for the LSTM layer. + + This class processes one step within the whole time sequence input, whereas + `keras.layer.LSTM` processes the whole sequence. + + Args: + units: Positive integer, dimensionality of the output space. + activation: Activation function to use. Default: hyperbolic tangent + (`tanh`). If you pass None, no activation is applied + (ie. "linear" activation: `a(x) = x`). + recurrent_activation: Activation function to use for the recurrent step. + Default: sigmoid (`sigmoid`). If you pass `None`, no activation is + applied (ie. "linear" activation: `a(x) = x`). + use_bias: Boolean, (default `True`), whether the layer + should use a bias vector. + kernel_initializer: Initializer for the `kernel` weights matrix, + used for the linear transformation of the inputs. Default: + `"glorot_uniform"`. + recurrent_initializer: Initializer for the `recurrent_kernel` + weights matrix, used for the linear transformation + of the recurrent state. Default: `"orthogonal"`. + bias_initializer: Initializer for the bias vector. Default: `"zeros"`. + unit_forget_bias: Boolean (default `True`). If `True`, + add 1 to the bias of the forget gate at initialization. + Setting it to `True` will also force `bias_initializer="zeros"`. + This is recommended in [Jozefowicz et al.]( + https://github.com/mlresearch/v37/blob/gh-pages/jozefowicz15.pdf) + kernel_regularizer: Regularizer function applied to the `kernel` weights + matrix. Default: `None`. + recurrent_regularizer: Regularizer function applied to the + `recurrent_kernel` weights matrix. Default: `None`. + bias_regularizer: Regularizer function applied to the bias vector. + Default: `None`. + kernel_constraint: Constraint function applied to the `kernel` weights + matrix. Default: `None`. + recurrent_constraint: Constraint function applied to the + `recurrent_kernel` weights matrix. Default: `None`. + bias_constraint: Constraint function applied to the bias vector. + Default: `None`. + dropout: Float between 0 and 1. Fraction of the units to drop for the + linear transformation of the inputs. Default: 0. + recurrent_dropout: Float between 0 and 1. Fraction of the units to drop + for the linear transformation of the recurrent state. Default: 0. + seed: Random seed for dropout. + + Call arguments: + inputs: A 2D tensor, with shape `(batch, features)`. + states: A 2D tensor with shape `(batch, units)`, which is the state + from the previous time step. + training: Python boolean indicating whether the layer should behave in + training mode or in inference mode. Only relevant when `dropout` or + `recurrent_dropout` is used. + + Example: + + >>> inputs = np.random.random((32, 10, 8)) + >>> rnn = keras.layers.RNN(keras.layers.LSTMCell(4)) + >>> output = rnn(inputs) + >>> output.shape + (32, 4) + >>> rnn = keras.layers.RNN( + ... keras.layers.LSTMCell(4), + ... return_sequences=True, + ... return_state=True) + >>> whole_sequence_output, final_state = rnn(inputs) + >>> whole_sequence_output.shape + (32, 10, 4) + >>> final_state.shape + (32, 4) + """ + + def __init__( + self, + units, + activation="tanh", + recurrent_activation="sigmoid", + use_bias=True, + kernel_initializer="glorot_uniform", + recurrent_initializer="orthogonal", + bias_initializer="zeros", + unit_forget_bias=True, + kernel_regularizer=None, + recurrent_regularizer=None, + bias_regularizer=None, + kernel_constraint=None, + recurrent_constraint=None, + bias_constraint=None, + dropout=0.0, + recurrent_dropout=0.0, + seed=None, + **kwargs, + ): + if units <= 0: + raise ValueError( + "Received an invalid value for argument `units`, " + f"expected a positive integer, got {units}." + ) + implementation = kwargs.pop("implementation", 2) + super().__init__(**kwargs) + self.implementation = implementation + self.units = units + self.activation = activations.get(activation) + self.recurrent_activation = activations.get(recurrent_activation) + self.use_bias = use_bias + + self.kernel_initializer = initializers.get(kernel_initializer) + self.recurrent_initializer = initializers.get(recurrent_initializer) + self.bias_initializer = initializers.get(bias_initializer) + + self.kernel_regularizer = regularizers.get(kernel_regularizer) + self.recurrent_regularizer = regularizers.get(recurrent_regularizer) + self.bias_regularizer = regularizers.get(bias_regularizer) + + self.kernel_constraint = constraints.get(kernel_constraint) + self.recurrent_constraint = constraints.get(recurrent_constraint) + self.bias_constraint = constraints.get(bias_constraint) + + self.dropout = min(1.0, max(0.0, dropout)) + self.recurrent_dropout = min(1.0, max(0.0, recurrent_dropout)) + if self.recurrent_dropout != 0.0: + self.implementation = 1 + if self.implementation == 1: + self.dropout_mask_count = 4 + self.seed = seed + self.seed_generator = backend.random.SeedGenerator(seed=seed) + + self.unit_forget_bias = unit_forget_bias + self.state_size = [self.units, self.units] + self.output_size = self.units + + def build(self, input_shape): + super().build(input_shape) + input_dim = input_shape[-1] + self.kernel = self.add_weight( + shape=(input_dim, self.units * 4), + name="kernel", + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, + constraint=self.kernel_constraint, + ) + self.recurrent_kernel = self.add_weight( + shape=(self.units, self.units * 4), + name="recurrent_kernel", + initializer=self.recurrent_initializer, + regularizer=self.recurrent_regularizer, + constraint=self.recurrent_constraint, + ) + + if self.use_bias: + if self.unit_forget_bias: + + def bias_initializer(_, *args, **kwargs): + return ops.concatenate( + [ + self.bias_initializer( + (self.units,), *args, **kwargs + ), + initializers.get("ones")( + (self.units,), *args, **kwargs + ), + self.bias_initializer( + (self.units * 2,), *args, **kwargs + ), + ] + ) + + else: + bias_initializer = self.bias_initializer + self.bias = self.add_weight( + shape=(self.units * 4,), + name="bias", + initializer=bias_initializer, + regularizer=self.bias_regularizer, + constraint=self.bias_constraint, + ) + else: + self.bias = None + + def _compute_carry_and_output(self, x, h_tm1, c_tm1): + """Computes carry and output using split kernels.""" + x_i, x_f, x_c, x_o = x + h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = h_tm1 + i = self.recurrent_activation( + x_i + ops.matmul(h_tm1_i, self.recurrent_kernel[:, : self.units]) + ) + f = self.recurrent_activation( + x_f + + ops.matmul( + h_tm1_f, self.recurrent_kernel[:, self.units : self.units * 2] + ) + ) + c = f * c_tm1 + i * self.activation( + x_c + + ops.matmul( + h_tm1_c, + self.recurrent_kernel[:, self.units * 2 : self.units * 3], + ) + ) + o = self.recurrent_activation( + x_o + + ops.matmul(h_tm1_o, self.recurrent_kernel[:, self.units * 3 :]) + ) + return c, o + + def _compute_carry_and_output_fused(self, z, c_tm1): + """Computes carry and output using fused kernels.""" + z0, z1, z2, z3 = z + i = self.recurrent_activation(z0) + f = self.recurrent_activation(z1) + c = f * c_tm1 + i * self.activation(z2) + o = self.recurrent_activation(z3) + return c, o + + def call(self, inputs, states, training=False): + h_tm1 = states[0] # previous memory state + c_tm1 = states[1] # previous carry state + + if self.implementation == 1: + if training and 0.0 < self.dropout < 1.0: + dp_mask = self.get_dropout_mask(inputs) + inputs_i = inputs * dp_mask[0] + inputs_f = inputs * dp_mask[1] + inputs_c = inputs * dp_mask[2] + inputs_o = inputs * dp_mask[3] + else: + inputs_i = inputs + inputs_f = inputs + inputs_c = inputs + inputs_o = inputs + k_i, k_f, k_c, k_o = ops.split(self.kernel, 4, axis=1) + x_i = ops.matmul(inputs_i, k_i) + x_f = ops.matmul(inputs_f, k_f) + x_c = ops.matmul(inputs_c, k_c) + x_o = ops.matmul(inputs_o, k_o) + if self.use_bias: + b_i, b_f, b_c, b_o = ops.split(self.bias, 4, axis=0) + x_i += b_i + x_f += b_f + x_c += b_c + x_o += b_o + + if training and 0.0 < self.recurrent_dropout < 1.0: + rec_dp_mask = self.get_recurrent_dropout_mask(h_tm1) + h_tm1_i = h_tm1 * rec_dp_mask[0] + h_tm1_f = h_tm1 * rec_dp_mask[1] + h_tm1_c = h_tm1 * rec_dp_mask[2] + h_tm1_o = h_tm1 * rec_dp_mask[3] + else: + h_tm1_i = h_tm1 + h_tm1_f = h_tm1 + h_tm1_c = h_tm1 + h_tm1_o = h_tm1 + x = (x_i, x_f, x_c, x_o) + h_tm1 = (h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o) + c, o = self._compute_carry_and_output(x, h_tm1, c_tm1) + else: + if training and 0.0 < self.dropout < 1.0: + dp_mask = self.get_dropout_mask(inputs) + inputs = inputs * dp_mask + + z = ops.matmul(inputs, self.kernel) + + z = ops.add(z, ops.matmul(h_tm1, self.recurrent_kernel)) + if self.use_bias: + z = ops.add(z, self.bias) + + z = ops.split(z, 4, axis=1) + c, o = self._compute_carry_and_output_fused(z, c_tm1) + + h = o * self.activation(c) + return h, [h, c] + + def get_config(self): + config = { + "units": self.units, + "activation": activations.serialize(self.activation), + "recurrent_activation": activations.serialize( + self.recurrent_activation + ), + "use_bias": self.use_bias, + "unit_forget_bias": self.unit_forget_bias, + "kernel_initializer": initializers.serialize( + self.kernel_initializer + ), + "recurrent_initializer": initializers.serialize( + self.recurrent_initializer + ), + "bias_initializer": initializers.serialize(self.bias_initializer), + "kernel_regularizer": regularizers.serialize( + self.kernel_regularizer + ), + "recurrent_regularizer": regularizers.serialize( + self.recurrent_regularizer + ), + "bias_regularizer": regularizers.serialize(self.bias_regularizer), + "kernel_constraint": constraints.serialize(self.kernel_constraint), + "recurrent_constraint": constraints.serialize( + self.recurrent_constraint + ), + "bias_constraint": constraints.serialize(self.bias_constraint), + "dropout": self.dropout, + "recurrent_dropout": self.recurrent_dropout, + "seed": self.seed, + } + base_config = super().get_config() + return {**base_config, **config} + + def get_initial_state(self, batch_size=None): + return [ + ops.zeros((batch_size, d), dtype=self.compute_dtype) + for d in self.state_size + ] + + +@keras_export("keras.layers.LSTM") +class LSTM(RNN): + """Long Short-Term Memory layer - Hochreiter 1997. + + Based on available runtime hardware and constraints, this layer + will choose different implementations (cuDNN-based or backend-native) + to maximize the performance. If a GPU is available and all + the arguments to the layer meet the requirement of the cuDNN kernel + (see below for details), the layer will use a fast cuDNN implementation + when using the TensorFlow backend. + The requirements to use the cuDNN implementation are: + + 1. `activation` == `tanh` + 2. `recurrent_activation` == `sigmoid` + 3. `recurrent_dropout` == 0 + 4. `unroll` is `False` + 5. `use_bias` is `True` + 6. Inputs, if use masking, are strictly right-padded. + 7. Eager execution is enabled in the outermost context. + + For example: + + >>> inputs = np.random.random((32, 10, 8)) + >>> lstm = keras.layers.LSTM(4) + >>> output = lstm(inputs) + >>> output.shape + (32, 4) + >>> lstm = keras.layers.LSTM( + ... 4, return_sequences=True, return_state=True) + >>> whole_seq_output, final_memory_state, final_carry_state = lstm(inputs) + >>> whole_seq_output.shape + (32, 10, 4) + >>> final_memory_state.shape + (32, 4) + >>> final_carry_state.shape + (32, 4) + + Args: + units: Positive integer, dimensionality of the output space. + activation: Activation function to use. + Default: hyperbolic tangent (`tanh`). + If you pass `None`, no activation is applied + (ie. "linear" activation: `a(x) = x`). + recurrent_activation: Activation function to use + for the recurrent step. + Default: sigmoid (`sigmoid`). + If you pass `None`, no activation is applied + (ie. "linear" activation: `a(x) = x`). + use_bias: Boolean, (default `True`), whether the layer + should use a bias vector. + kernel_initializer: Initializer for the `kernel` weights matrix, + used for the linear transformation of the inputs. Default: + `"glorot_uniform"`. + recurrent_initializer: Initializer for the `recurrent_kernel` + weights matrix, used for the linear transformation of the recurrent + state. Default: `"orthogonal"`. + bias_initializer: Initializer for the bias vector. Default: `"zeros"`. + unit_forget_bias: Boolean (default `True`). If `True`, + add 1 to the bias of the forget gate at initialization. + Setting it to `True` will also force `bias_initializer="zeros"`. + This is recommended in [Jozefowicz et al.]( + https://github.com/mlresearch/v37/blob/gh-pages/jozefowicz15.pdf) + kernel_regularizer: Regularizer function applied to the `kernel` weights + matrix. Default: `None`. + recurrent_regularizer: Regularizer function applied to the + `recurrent_kernel` weights matrix. Default: `None`. + bias_regularizer: Regularizer function applied to the bias vector. + Default: `None`. + activity_regularizer: Regularizer function applied to the output of the + layer (its "activation"). Default: `None`. + kernel_constraint: Constraint function applied to the `kernel` weights + matrix. Default: `None`. + recurrent_constraint: Constraint function applied to the + `recurrent_kernel` weights matrix. Default: `None`. + bias_constraint: Constraint function applied to the bias vector. + Default: `None`. + dropout: Float between 0 and 1. Fraction of the units to drop for the + linear transformation of the inputs. Default: 0. + recurrent_dropout: Float between 0 and 1. Fraction of the units to drop + for the linear transformation of the recurrent state. Default: 0. + seed: Random seed for dropout. + return_sequences: Boolean. Whether to return the last output + in the output sequence, or the full sequence. Default: `False`. + return_state: Boolean. Whether to return the last state in addition + to the output. Default: `False`. + go_backwards: Boolean (default: `False`). + If `True`, process the input sequence backwards and return the + reversed sequence. + stateful: Boolean (default: `False`). If `True`, the last state + for each sample at index i in a batch will be used as initial + state for the sample of index i in the following batch. + unroll: Boolean (default False). + If `True`, the network will be unrolled, + else a symbolic loop will be used. + Unrolling can speed-up a RNN, + although it tends to be more memory-intensive. + Unrolling is only suitable for short sequences. + use_cudnn: Whether to use a cuDNN-backed implementation. `"auto"` will + attempt to use cuDNN when feasible, and will fallback to the + default implementation if not. + + Call arguments: + inputs: A 3D tensor, with shape `(batch, timesteps, feature)`. + mask: Binary tensor of shape `(samples, timesteps)` indicating whether + a given timestep should be masked (optional). + An individual `True` entry indicates that the corresponding timestep + should be utilized, while a `False` entry indicates that the + corresponding timestep should be ignored. Defaults to `None`. + training: Python boolean indicating whether the layer should behave in + training mode or in inference mode. This argument is passed to the + cell when calling it. This is only relevant if `dropout` or + `recurrent_dropout` is used (optional). Defaults to `None`. + initial_state: List of initial state tensors to be passed to the first + call of the cell (optional, `None` causes creation + of zero-filled initial state tensors). Defaults to `None`. + """ + + def __init__( + self, + units, + activation="tanh", + recurrent_activation="sigmoid", + use_bias=True, + kernel_initializer="glorot_uniform", + recurrent_initializer="orthogonal", + bias_initializer="zeros", + unit_forget_bias=True, + kernel_regularizer=None, + recurrent_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + recurrent_constraint=None, + bias_constraint=None, + dropout=0.0, + recurrent_dropout=0.0, + seed=None, + return_sequences=False, + return_state=False, + go_backwards=False, + stateful=False, + unroll=False, + use_cudnn="auto", + **kwargs, + ): + cell = LSTMCell( + units, + activation=activation, + recurrent_activation=recurrent_activation, + use_bias=use_bias, + kernel_initializer=kernel_initializer, + unit_forget_bias=unit_forget_bias, + recurrent_initializer=recurrent_initializer, + bias_initializer=bias_initializer, + kernel_regularizer=kernel_regularizer, + recurrent_regularizer=recurrent_regularizer, + bias_regularizer=bias_regularizer, + kernel_constraint=kernel_constraint, + recurrent_constraint=recurrent_constraint, + bias_constraint=bias_constraint, + dropout=dropout, + recurrent_dropout=recurrent_dropout, + dtype=kwargs.get("dtype", None), + trainable=kwargs.get("trainable", True), + name="lstm_cell", + seed=seed, + implementation=kwargs.pop("implementation", 2), + ) + super().__init__( + cell, + return_sequences=return_sequences, + return_state=return_state, + go_backwards=go_backwards, + stateful=stateful, + unroll=unroll, + activity_regularizer=activity_regularizer, + **kwargs, + ) + self.input_spec = InputSpec(ndim=3) + if use_cudnn not in ("auto", True, False): + raise ValueError( + "Invalid valid received for argument `use_cudnn`. " + "Expected one of {'auto', True, False}. " + f"Received: use_cudnn={use_cudnn}" + ) + self.use_cudnn = use_cudnn + if ( + backend.backend() == "tensorflow" + and backend.cudnn_ok( + cell.activation, + cell.recurrent_activation, + self.unroll, + cell.use_bias, + ) + and use_cudnn in (True, "auto") + ): + self.supports_jit = False + + def inner_loop(self, sequences, initial_state, mask, training=False): + if tree.is_nested(mask): + mask = mask[0] + + if self.use_cudnn in ("auto", True): + if not self.recurrent_dropout: + try: + if training and self.dropout: + dp_mask = self.cell.get_dropout_mask(sequences[:, 0, :]) + dp_mask = ops.expand_dims(dp_mask, axis=1) + dp_mask = ops.broadcast_to( + dp_mask, ops.shape(sequences) + ) + dp_sequences = sequences * dp_mask + else: + dp_sequences = sequences + + # Backends are allowed to specify (optionally) optimized + # implementation of the inner LSTM loop. In the case of + # TF for instance, it will leverage cuDNN when feasible, and + # it will raise NotImplementedError otherwise. + out = backend.lstm( + dp_sequences, + initial_state[0], + initial_state[1], + mask, + kernel=self.cell.kernel, + recurrent_kernel=self.cell.recurrent_kernel, + bias=self.cell.bias, + activation=self.cell.activation, + recurrent_activation=self.cell.recurrent_activation, + return_sequences=self.return_sequences, + go_backwards=self.go_backwards, + unroll=self.unroll, + ) + # We disable jit_compile for the model in this case, + # since cuDNN ops aren't XLA compatible. + if backend.backend() == "tensorflow": + self.supports_jit = False + return out + except NotImplementedError: + pass + if self.use_cudnn is True: + raise ValueError( + "use_cudnn=True was specified, " + "but cuDNN is not supported for this layer configuration " + "with this backend. Pass use_cudnn='auto' to fallback " + "to a non-cuDNN implementation." + ) + return super().inner_loop( + sequences, initial_state, mask=mask, training=training + ) + + def call(self, sequences, initial_state=None, mask=None, training=False): + return super().call( + sequences, mask=mask, training=training, initial_state=initial_state + ) + + @property + def units(self): + return self.cell.units + + @property + def activation(self): + return self.cell.activation + + @property + def recurrent_activation(self): + return self.cell.recurrent_activation + + @property + def use_bias(self): + return self.cell.use_bias + + @property + def unit_forget_bias(self): + return self.cell.unit_forget_bias + + @property + def kernel_initializer(self): + return self.cell.kernel_initializer + + @property + def recurrent_initializer(self): + return self.cell.recurrent_initializer + + @property + def bias_initializer(self): + return self.cell.bias_initializer + + @property + def kernel_regularizer(self): + return self.cell.kernel_regularizer + + @property + def recurrent_regularizer(self): + return self.cell.recurrent_regularizer + + @property + def bias_regularizer(self): + return self.cell.bias_regularizer + + @property + def kernel_constraint(self): + return self.cell.kernel_constraint + + @property + def recurrent_constraint(self): + return self.cell.recurrent_constraint + + @property + def bias_constraint(self): + return self.cell.bias_constraint + + @property + def dropout(self): + return self.cell.dropout + + @property + def recurrent_dropout(self): + return self.cell.recurrent_dropout + + def get_config(self): + config = { + "units": self.units, + "activation": activations.serialize(self.activation), + "recurrent_activation": activations.serialize( + self.recurrent_activation + ), + "use_bias": self.use_bias, + "kernel_initializer": initializers.serialize( + self.kernel_initializer + ), + "recurrent_initializer": initializers.serialize( + self.recurrent_initializer + ), + "bias_initializer": initializers.serialize(self.bias_initializer), + "unit_forget_bias": self.unit_forget_bias, + "kernel_regularizer": regularizers.serialize( + self.kernel_regularizer + ), + "recurrent_regularizer": regularizers.serialize( + self.recurrent_regularizer + ), + "bias_regularizer": regularizers.serialize(self.bias_regularizer), + "activity_regularizer": regularizers.serialize( + self.activity_regularizer + ), + "kernel_constraint": constraints.serialize(self.kernel_constraint), + "recurrent_constraint": constraints.serialize( + self.recurrent_constraint + ), + "bias_constraint": constraints.serialize(self.bias_constraint), + "dropout": self.dropout, + "recurrent_dropout": self.recurrent_dropout, + "seed": self.cell.seed, + } + base_config = super().get_config() + del base_config["cell"] + return {**base_config, **config} + + @classmethod + def from_config(cls, config): + return cls(**config) diff --git a/keras/src/layers/rnn/lstm_test.py b/keras/src/layers/rnn/lstm_test.py new file mode 100644 index 000000000000..0486c196e4fc --- /dev/null +++ b/keras/src/layers/rnn/lstm_test.py @@ -0,0 +1,306 @@ +import numpy as np +import pytest +from absl.testing import parameterized + +from keras.src import initializers +from keras.src import layers +from keras.src import testing + + +class LSTMTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_basics(self): + self.run_layer_test( + layers.LSTM, + init_kwargs={"units": 3, "dropout": 0.5}, + input_shape=(3, 2, 4), + call_kwargs={"training": True}, + expected_output_shape=(3, 3), + expected_num_trainable_weights=3, + expected_num_non_trainable_weights=0, + supports_masking=True, + ) + self.run_layer_test( + layers.LSTM, + init_kwargs={"units": 3, "dropout": 0.5, "recurrent_dropout": 0.5}, + input_shape=(3, 2, 4), + call_kwargs={"training": True}, + expected_output_shape=(3, 3), + expected_num_trainable_weights=3, + expected_num_non_trainable_weights=0, + supports_masking=True, + ) + self.run_layer_test( + layers.LSTM, + init_kwargs={ + "units": 3, + "return_sequences": True, + "bias_regularizer": "l1", + "kernel_regularizer": "l2", + "recurrent_regularizer": "l2", + }, + input_shape=(3, 2, 4), + expected_output_shape=(3, 2, 3), + expected_num_losses=3, + expected_num_trainable_weights=3, + expected_num_non_trainable_weights=0, + supports_masking=True, + ) + + @parameterized.parameters([1, 2]) + def test_correctness(self, implementation): + sequence = np.arange(72).reshape((3, 6, 4)).astype("float32") + layer = layers.LSTM( + 3, + kernel_initializer=initializers.Constant(0.01), + recurrent_initializer=initializers.Constant(0.02), + bias_initializer=initializers.Constant(0.03), + implementation=implementation, + ) + output = layer(sequence) + self.assertAllClose( + np.array( + [ + [0.6288687, 0.6288687, 0.6288687], + [0.86899155, 0.86899155, 0.86899155], + [0.9460773, 0.9460773, 0.9460773], + ] + ), + output, + ) + + layer = layers.LSTM( + 3, + kernel_initializer=initializers.Constant(0.01), + recurrent_initializer=initializers.Constant(0.02), + bias_initializer=initializers.Constant(0.03), + go_backwards=True, + implementation=implementation, + ) + output = layer(sequence) + self.assertAllClose( + np.array( + [ + [0.35622165, 0.35622165, 0.35622165], + [0.74789524, 0.74789524, 0.74789524], + [0.8872726, 0.8872726, 0.8872726], + ] + ), + output, + ) + + layer = layers.LSTM( + 3, + kernel_initializer=initializers.Constant(0.01), + recurrent_initializer=initializers.Constant(0.02), + bias_initializer=initializers.Constant(0.03), + unroll=True, + implementation=implementation, + ) + output = layer(sequence) + self.assertAllClose( + np.array( + [ + [0.6288687, 0.6288687, 0.6288687], + [0.86899155, 0.86899155, 0.86899155], + [0.9460773, 0.9460773, 0.9460773], + ] + ), + output, + ) + + layer = layers.LSTM( + 3, + kernel_initializer=initializers.Constant(0.01), + recurrent_initializer=initializers.Constant(0.02), + bias_initializer=initializers.Constant(0.03), + unit_forget_bias=False, + implementation=implementation, + ) + output = layer(sequence) + self.assertAllClose( + np.array( + [ + [0.57019705, 0.57019705, 0.57019705], + [0.8661914, 0.8661914, 0.8661914], + [0.9459622, 0.9459622, 0.9459622], + ] + ), + output, + ) + + layer = layers.LSTM( + 3, + kernel_initializer=initializers.Constant(0.01), + recurrent_initializer=initializers.Constant(0.02), + bias_initializer=initializers.Constant(0.03), + use_bias=False, + implementation=implementation, + ) + output = layer(sequence) + self.assertAllClose( + np.array( + [ + [0.54986924, 0.54986924, 0.54986924], + [0.86226785, 0.86226785, 0.86226785], + [0.9443936, 0.9443936, 0.9443936], + ] + ), + output, + ) + + def test_statefulness(self): + sequence = np.arange(24).reshape((2, 3, 4)).astype("float32") + layer = layers.LSTM( + 4, + stateful=True, + kernel_initializer=initializers.Constant(0.01), + recurrent_initializer=initializers.Constant(0.02), + bias_initializer=initializers.Constant(0.03), + ) + layer(sequence) + output = layer(sequence) + self.assertAllClose( + np.array( + [ + [0.3124785, 0.3124785, 0.3124785, 0.3124785], + [0.6863672, 0.6863672, 0.6863672, 0.6863672], + ] + ), + output, + ) + layer.reset_state() + layer(sequence) + output = layer(sequence) + self.assertAllClose( + np.array( + [ + [0.3124785, 0.3124785, 0.3124785, 0.3124785], + [0.6863672, 0.6863672, 0.6863672, 0.6863672], + ] + ), + output, + ) + + def test_pass_initial_state(self): + sequence = np.arange(24).reshape((2, 4, 3)).astype("float32") + initial_state = [ + np.arange(4).reshape((2, 2)).astype("float32"), + np.arange(4).reshape((2, 2)).astype("float32"), + ] + layer = layers.LSTM( + 2, + kernel_initializer=initializers.Constant(0.01), + recurrent_initializer=initializers.Constant(0.02), + bias_initializer=initializers.Constant(0.03), + ) + output = layer(sequence, initial_state=initial_state) + self.assertAllClose( + np.array([[0.20574439, 0.3558822], [0.64930826, 0.66276]]), + output, + ) + + layer = layers.LSTM( + 2, + kernel_initializer=initializers.Constant(0.01), + recurrent_initializer=initializers.Constant(0.02), + bias_initializer=initializers.Constant(0.03), + go_backwards=True, + ) + output = layer(sequence, initial_state=initial_state) + self.assertAllClose( + np.array([[0.13281618, 0.2790356], [0.5839337, 0.5992567]]), + output, + ) + + def test_masking(self): + sequence = np.arange(24).reshape((2, 4, 3)).astype("float32") + mask = np.array([[True, True, False, True], [True, False, False, True]]) + layer = layers.LSTM( + 2, + kernel_initializer=initializers.Constant(0.01), + recurrent_initializer=initializers.Constant(0.02), + bias_initializer=initializers.Constant(0.03), + unroll=True, + ) + output = layer(sequence, mask=mask) + self.assertAllClose( + np.array([[0.1524914, 0.1524914], [0.35969394, 0.35969394]]), + output, + ) + + layer = layers.LSTM( + 2, + kernel_initializer=initializers.Constant(0.01), + recurrent_initializer=initializers.Constant(0.02), + bias_initializer=initializers.Constant(0.03), + return_sequences=True, + ) + output = layer(sequence, mask=mask) + self.assertAllClose( + np.array( + [ + [0.0158891, 0.0158891], + [0.05552047, 0.05552047], + [0.05552047, 0.05552047], + [0.1524914, 0.1524914], + ], + ), + output[0], + ) + self.assertAllClose( + np.array( + [ + [0.14185596, 0.14185596], + [0.14185596, 0.14185596], + [0.14185596, 0.14185596], + [0.35969394, 0.35969394], + ], + ), + output[1], + ) + + layer = layers.LSTM( + 2, + kernel_initializer=initializers.Constant(0.01), + recurrent_initializer=initializers.Constant(0.02), + bias_initializer=initializers.Constant(0.03), + return_sequences=True, + zero_output_for_mask=True, + ) + output = layer(sequence, mask=mask) + self.assertAllClose( + np.array( + [ + [0.0158891, 0.0158891], + [0.05552047, 0.05552047], + [0.0, 0.0], + [0.1524914, 0.1524914], + ], + ), + output[0], + ) + self.assertAllClose( + np.array( + [ + [0.14185596, 0.14185596], + [0.0, 0.0], + [0.0, 0.0], + [0.35969394, 0.35969394], + ], + ), + output[1], + ) + + layer = layers.LSTM( + 2, + kernel_initializer=initializers.Constant(0.01), + recurrent_initializer=initializers.Constant(0.02), + bias_initializer=initializers.Constant(0.03), + go_backwards=True, + ) + output = layer(sequence, mask=mask) + self.assertAllClose( + np.array([[0.10056866, 0.10056866], [0.31006062, 0.31006062]]), + output, + ) diff --git a/keras/src/layers/rnn/rnn.py b/keras/src/layers/rnn/rnn.py new file mode 100644 index 000000000000..3f86daaba26a --- /dev/null +++ b/keras/src/layers/rnn/rnn.py @@ -0,0 +1,479 @@ +from keras.src import backend +from keras.src import ops +from keras.src import tree +from keras.src.api_export import keras_export +from keras.src.layers.layer import Layer +from keras.src.layers.rnn.dropout_rnn_cell import DropoutRNNCell +from keras.src.layers.rnn.stacked_rnn_cells import StackedRNNCells +from keras.src.saving import serialization_lib +from keras.src.utils import tracking + + +@keras_export("keras.layers.RNN") +class RNN(Layer): + """Base class for recurrent layers. + + Args: + cell: A RNN cell instance or a list of RNN cell instances. + A RNN cell is a class that has: + - A `call(input_at_t, states_at_t)` method, returning + `(output_at_t, states_at_t_plus_1)`. The call method of the + cell can also take the optional argument `constants`, see + section "Note on passing external constants" below. + - A `state_size` attribute. This can be a single integer + (single state) in which case it is the size of the recurrent + state. This can also be a list/tuple of integers + (one size per state). + - A `output_size` attribute, a single integer. + - A `get_initial_state(batch_size=None)` + method that creates a tensor meant to be fed to `call()` as the + initial state, if the user didn't specify any initial state + via other means. The returned initial state should have + shape `(batch_size, cell.state_size)`. + The cell might choose to create a tensor full of zeros, + or other values based on the cell's implementation. + `inputs` is the input tensor to the RNN layer, with shape + `(batch_size, timesteps, features)`. + If this method is not implemented + by the cell, the RNN layer will create a zero filled tensor + with shape `(batch_size, cell.state_size)`. + In the case that `cell` is a list of RNN cell instances, the cells + will be stacked on top of each other in the RNN, resulting in an + efficient stacked RNN. + return_sequences: Boolean (default `False`). Whether to return the last + output in the output sequence, or the full sequence. + return_state: Boolean (default `False`). + Whether to return the last state in addition to the output. + go_backwards: Boolean (default `False`). + If `True`, process the input sequence backwards and return the + reversed sequence. + stateful: Boolean (default `False`). If True, the last state + for each sample at index `i` in a batch will be used as initial + state for the sample of index `i` in the following batch. + unroll: Boolean (default `False`). + If True, the network will be unrolled, else a symbolic loop will be + used. Unrolling can speed-up a RNN, although it tends to be more + memory-intensive. Unrolling is only suitable for short sequences. + zero_output_for_mask: Boolean (default `False`). + Whether the output should use zeros for the masked timesteps. + Note that this field is only used when `return_sequences` + is `True` and `mask` is provided. + It can useful if you want to reuse the raw output sequence of + the RNN without interference from the masked timesteps, e.g., + merging bidirectional RNNs. + + Call arguments: + sequences: A 3-D tensor with shape `(batch_size, timesteps, features)`. + initial_state: List of initial state tensors to be passed to the first + call of the cell. + mask: Binary tensor of shape `[batch_size, timesteps]` + indicating whether a given timestep should be masked. + An individual `True` entry indicates that the corresponding + timestep should be utilized, while a `False` entry indicates + that the corresponding timestep should be ignored. + training: Python boolean indicating whether the layer should behave in + training mode or in inference mode. This argument is passed + to the cell when calling it. + This is for use with cells that use dropout. + + Output shape: + + - If `return_state`: a list of tensors. The first tensor is + the output. The remaining tensors are the last states, + each with shape `(batch_size, state_size)`, where `state_size` could + be a high dimension tensor shape. + - If `return_sequences`: 3D tensor with shape + `(batch_size, timesteps, output_size)`. + + Masking: + + This layer supports masking for input data with a variable number + of timesteps. To introduce masks to your data, + use a `keras.layers.Embedding` layer with the `mask_zero` parameter + set to `True`. + + Note on using statefulness in RNNs: + + You can set RNN layers to be 'stateful', which means that the states + computed for the samples in one batch will be reused as initial states + for the samples in the next batch. This assumes a one-to-one mapping + between samples in different successive batches. + + To enable statefulness: + + - Specify `stateful=True` in the layer constructor. + - Specify a fixed batch size for your model, by passing + `batch_size=...` to the `Input` layer(s) of your model. + Remember to also specify the same `batch_size=...` when + calling `fit()`, or otherwise use a generator-like + data source like a `keras.utils.PyDataset` or a + `tf.data.Dataset`. + - Specify `shuffle=False` when calling `fit()`, since your + batches are expected to be temporally ordered. + + To reset the states of your model, call `.reset_state()` on either + a specific layer, or on your entire model. + + Note on specifying the initial state of RNNs: + + You can specify the initial state of RNN layers symbolically by + calling them with the keyword argument `initial_state`. The value of + `initial_state` should be a tensor or list of tensors representing + the initial state of the RNN layer. + + You can specify the initial state of RNN layers numerically by + calling `reset_state()` with the keyword argument `states`. The value of + `states` should be a numpy array or list of numpy arrays representing + the initial state of the RNN layer. + + Examples: + + ```python + from keras.layers import RNN + from keras import ops + + # First, let's define a RNN Cell, as a layer subclass. + class MinimalRNNCell(keras.Layer): + + def __init__(self, units, **kwargs): + super().__init__(**kwargs) + self.units = units + self.state_size = units + + def build(self, input_shape): + self.kernel = self.add_weight(shape=(input_shape[-1], self.units), + initializer='uniform', + name='kernel') + self.recurrent_kernel = self.add_weight( + shape=(self.units, self.units), + initializer='uniform', + name='recurrent_kernel') + + def call(self, inputs, states): + prev_output = states[0] + h = ops.matmul(inputs, self.kernel) + output = h + ops.matmul(prev_output, self.recurrent_kernel) + return output, [output] + + # Let's use this cell in a RNN layer: + + cell = MinimalRNNCell(32) + x = keras.Input((None, 5)) + layer = RNN(cell) + y = layer(x) + + # Here's how to use the cell to build a stacked RNN: + + cells = [MinimalRNNCell(32), MinimalRNNCell(64)] + x = keras.Input((None, 5)) + layer = RNN(cells) + y = layer(x) + ``` + """ + + def __init__( + self, + cell, + return_sequences=False, + return_state=False, + go_backwards=False, + stateful=False, + unroll=False, + zero_output_for_mask=False, + **kwargs, + ): + if isinstance(cell, (list, tuple)): + cell = StackedRNNCells(cell) + if "call" not in dir(cell): + raise ValueError( + "Argument `cell` should have a `call` method. " + f"Received: cell={cell}" + ) + if "state_size" not in dir(cell): + raise ValueError( + "The RNN cell should have a `state_size` attribute " + "(single integer or list of integers, " + "one integer per RNN state). " + f"Received: cell={cell}" + ) + super().__init__(**kwargs) + + # If True, the output for masked timestep will be zeros, whereas in the + # False case, output from previous timestep is returned for masked + # timestep. + self.zero_output_for_mask = zero_output_for_mask + self.cell = cell + self.return_sequences = return_sequences + self.return_state = return_state + self.go_backwards = go_backwards + self.stateful = stateful + self.unroll = unroll + + self.supports_masking = True + self.input_spec = None + self.states = None + + state_size = getattr(self.cell, "state_size", None) + if state_size is None: + raise ValueError( + "state_size must be specified as property on the RNN cell." + ) + if not isinstance(state_size, (list, tuple, int)): + raise ValueError( + "state_size must be an integer, or a list/tuple of integers " + "(one for each state tensor)." + ) + if isinstance(state_size, int): + self.state_size = [state_size] + self.single_state = True + else: + self.state_size = list(state_size) + self.single_state = False + + def compute_output_shape(self, sequences_shape, initial_state_shape=None): + batch_size = sequences_shape[0] + length = sequences_shape[1] + states_shape = [] + for state_size in self.state_size: + if isinstance(state_size, int): + states_shape.append((batch_size, state_size)) + elif isinstance(state_size, (list, tuple)): + states_shape.append([(batch_size, s) for s in state_size]) + + output_size = getattr(self.cell, "output_size", None) + if output_size is None: + output_size = self.state_size[0] + if not isinstance(output_size, int): + raise ValueError("output_size must be an integer.") + if self.return_sequences: + output_shape = (batch_size, length, output_size) + else: + output_shape = (batch_size, output_size) + if self.return_state: + return output_shape, *states_shape + return output_shape + + def compute_mask(self, _, mask): + # Time step masks must be the same for each input. + # This is because the mask for an RNN is of size [batch, time_steps, 1], + # and specifies which time steps should be skipped, and a time step + # must be skipped for all inputs. + mask = tree.flatten(mask)[0] + output_mask = mask if self.return_sequences else None + if self.return_state: + state_mask = [None for _ in self.state_size] + return [output_mask] + state_mask + else: + return output_mask + + def build(self, sequences_shape, initial_state_shape=None): + # Build cell (if layer). + step_input_shape = (sequences_shape[0],) + tuple(sequences_shape[2:]) + if isinstance(self.cell, Layer) and not self.cell.built: + self.cell.build(step_input_shape) + self.cell.built = True + if self.stateful: + if self.states is not None: + self.reset_state() + else: + if sequences_shape[0] is None: + raise ValueError( + "When using `stateful=True` in a RNN, the " + "batch size must be static. Found dynamic " + f"batch size: sequence.shape={sequences_shape}" + ) + self._create_state_variables(sequences_shape[0]) + + @tracking.no_automatic_dependency_tracking + def _create_state_variables(self, batch_size): + with backend.name_scope(self.name, caller=self): + self.states = tree.map_structure( + lambda value: backend.Variable( + value, + trainable=False, + dtype=self.variable_dtype, + name="rnn_state", + ), + self.get_initial_state(batch_size), + ) + + def get_initial_state(self, batch_size): + get_initial_state_fn = getattr(self.cell, "get_initial_state", None) + if get_initial_state_fn: + init_state = get_initial_state_fn(batch_size=batch_size) + else: + return [ + ops.zeros((batch_size, d), dtype=self.cell.compute_dtype) + for d in self.state_size + ] + + # RNN expect the states in a list, even if single state. + if not tree.is_nested(init_state): + init_state = [init_state] + # Force the state to be a list in case it is a namedtuple eg + # LSTMStateTuple. + return list(init_state) + + def reset_states(self): + # Compatibility alias. + self.reset_state() + + def reset_state(self): + if self.states is not None: + for v in self.states: + v.assign(ops.zeros_like(v.value)) + + def inner_loop(self, sequences, initial_state, mask, training=False): + cell_kwargs = {} + if isinstance(self.cell, Layer) and self.cell._call_has_training_arg: + cell_kwargs["training"] = training + + def step(inputs, states): + # Create new tensor copies when using PyTorch backend + # with stateful=True. This prevents in-place modifications + # that would otherwise break PyTorch's autograd functionality + # by modifying tensors needed for gradient computation. + if backend.backend() == "torch" and self.stateful: + states = tree.map_structure(ops.copy, states) + output, new_states = self.cell(inputs, states, **cell_kwargs) + if not tree.is_nested(new_states): + new_states = [new_states] + return output, new_states + + if not tree.is_nested(initial_state): + initial_state = [initial_state] + + return backend.rnn( + step, + sequences, + initial_state, + go_backwards=self.go_backwards, + mask=mask, + unroll=self.unroll, + input_length=sequences.shape[1], + zero_output_for_mask=self.zero_output_for_mask, + return_all_outputs=self.return_sequences, + ) + + def call( + self, + sequences, + initial_state=None, + mask=None, + training=False, + ): + timesteps = sequences.shape[1] + if self.unroll and timesteps is None: + raise ValueError( + "Cannot unroll a RNN if the " + "time dimension is undefined. \n" + "- If using a Sequential model, " + "specify the time dimension by passing " + "an `Input()` as your first layer.\n" + "- If using the functional API, specify " + "the time dimension by passing a `shape` " + "or `batch_shape` argument to your `Input()`." + ) + + if initial_state is None: + if self.stateful: + initial_state = self.states + else: + initial_state = self.get_initial_state( + batch_size=ops.shape(sequences)[0] + ) + # RNN expect the states in a list, even if single state. + if not tree.is_nested(initial_state): + initial_state = [initial_state] + initial_state = list(initial_state) + + # Cast states to compute dtype. + # Note that states may be deeply nested + # (e.g. in the stacked cells case). + initial_state = tree.map_structure( + lambda x: backend.convert_to_tensor( + x, dtype=self.cell.compute_dtype + ), + initial_state, + ) + + # Prepopulate the dropout state so that the inner_loop is stateless + # this is particularly important for JAX backend. + self._maybe_config_dropout_masks( + self.cell, sequences[:, 0, :], initial_state + ) + + last_output, outputs, states = self.inner_loop( + sequences=sequences, + initial_state=initial_state, + mask=mask, + training=training, + ) + last_output = ops.cast(last_output, self.compute_dtype) + outputs = ops.cast(outputs, self.compute_dtype) + states = tree.map_structure( + lambda x: ops.cast(x, dtype=self.compute_dtype), states + ) + self._maybe_reset_dropout_masks(self.cell) + + if self.stateful: + for self_state, state in zip( + tree.flatten(self.states), tree.flatten(states) + ): + self_state.assign(state) + + if self.return_sequences: + output = outputs + else: + output = last_output + + if self.return_state: + return output, *states + return output + + def _maybe_config_dropout_masks(self, cell, input_sequence, input_state): + state = ( + input_state[0] + if isinstance(input_state, (list, tuple)) + else input_state + ) + if isinstance(cell, DropoutRNNCell): + cell.get_dropout_mask(input_sequence) + cell.get_recurrent_dropout_mask(state) + if isinstance(cell, StackedRNNCells): + for c, s in zip(cell.cells, input_state): + self._maybe_config_dropout_masks(c, input_sequence, s) + # Replicate the behavior of `StackedRNNCells.call` to compute + # the inputs for the next cell. + s = list(s) if tree.is_nested(s) else [s] + cell_call_fn = c.__call__ if callable(c) else c.call + input_sequence, _ = cell_call_fn(input_sequence, s) + + def _maybe_reset_dropout_masks(self, cell): + if isinstance(cell, DropoutRNNCell): + cell.reset_dropout_mask() + cell.reset_recurrent_dropout_mask() + if isinstance(cell, StackedRNNCells): + for c in cell.cells: + self._maybe_reset_dropout_masks(c) + + def get_config(self): + config = { + "return_sequences": self.return_sequences, + "return_state": self.return_state, + "go_backwards": self.go_backwards, + "stateful": self.stateful, + "unroll": self.unroll, + "zero_output_for_mask": self.zero_output_for_mask, + } + config["cell"] = serialization_lib.serialize_keras_object(self.cell) + base_config = super().get_config() + return {**base_config, **config} + + @classmethod + def from_config(cls, config, custom_objects=None): + cell = serialization_lib.deserialize_keras_object( + config.pop("cell"), custom_objects=custom_objects + ) + layer = cls(cell, **config) + return layer diff --git a/keras/src/layers/rnn/rnn_test.py b/keras/src/layers/rnn/rnn_test.py new file mode 100644 index 000000000000..3562f6c0bb96 --- /dev/null +++ b/keras/src/layers/rnn/rnn_test.py @@ -0,0 +1,384 @@ +import numpy as np +import pytest + +from keras.src import layers +from keras.src import ops +from keras.src import testing + + +class OneStateRNNCell(layers.Layer): + def __init__(self, units, state_size=None, **kwargs): + super().__init__(**kwargs) + self.units = units + self.state_size = state_size if state_size else units + + def build(self, input_shape): + self.kernel = self.add_weight( + shape=(input_shape[-1], self.units), + initializer="ones", + name="kernel", + ) + self.recurrent_kernel = self.add_weight( + shape=(self.units, self.units), + initializer="ones", + name="recurrent_kernel", + ) + + def call(self, inputs, states): + prev_output = states[0] + h = ops.matmul(inputs, self.kernel) + output = h + ops.matmul(prev_output, self.recurrent_kernel) + return output, [output] + + +class TwoStatesRNNCell(layers.Layer): + def __init__(self, units, state_size=None, **kwargs): + super().__init__(**kwargs) + self.units = units + self.state_size = state_size if state_size else [units, units] + self.output_size = units + + def build(self, input_shape): + self.kernel = self.add_weight( + shape=(input_shape[-1], self.units), + initializer="ones", + name="kernel", + ) + self.recurrent_kernel_1 = self.add_weight( + shape=(self.units, self.units), + initializer="ones", + name="recurrent_kernel_1", + ) + self.recurrent_kernel_2 = self.add_weight( + shape=(self.units, self.units), + initializer="ones", + name="recurrent_kernel_2", + ) + + def call(self, inputs, states): + prev_1 = states[0] + prev_2 = states[0] + h = ops.matmul(inputs, self.kernel) + output_1 = h + ops.matmul(prev_1, self.recurrent_kernel_1) + output_2 = h + ops.matmul(prev_2, self.recurrent_kernel_2) + output = output_1 + output_2 + return output, [output_1, output_2] + + +class RNNTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_basics(self): + self.run_layer_test( + layers.RNN, + init_kwargs={"cell": OneStateRNNCell(5, state_size=5)}, + input_shape=(3, 2, 4), + expected_output_shape=(3, 5), + expected_num_trainable_weights=2, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + supports_masking=True, + ) + self.run_layer_test( + layers.RNN, + init_kwargs={"cell": OneStateRNNCell(5, state_size=[5])}, + input_shape=(3, 2, 4), + expected_output_shape=(3, 5), + expected_num_trainable_weights=2, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + supports_masking=True, + ) + self.run_layer_test( + layers.RNN, + init_kwargs={"cell": OneStateRNNCell(5, state_size=(5,))}, + input_shape=(3, 2, 4), + expected_output_shape=(3, 5), + expected_num_trainable_weights=2, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + supports_masking=True, + ) + self.run_layer_test( + layers.RNN, + init_kwargs={"cell": OneStateRNNCell(5), "return_sequences": True}, + input_shape=(3, 2, 4), + expected_output_shape=(3, 2, 5), + expected_num_trainable_weights=2, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + supports_masking=True, + ) + self.run_layer_test( + layers.RNN, + init_kwargs={ + "cell": OneStateRNNCell(5), + "go_backwards": True, + "unroll": True, + }, + input_shape=(3, 2, 4), + expected_output_shape=(3, 5), + expected_num_trainable_weights=2, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + supports_masking=True, + ) + self.run_layer_test( + layers.RNN, + init_kwargs={"cell": TwoStatesRNNCell(5, state_size=[5, 5])}, + input_shape=(3, 2, 4), + expected_output_shape=(3, 5), + expected_num_trainable_weights=3, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + supports_masking=True, + ) + self.run_layer_test( + layers.RNN, + init_kwargs={"cell": TwoStatesRNNCell(5, state_size=(5, 5))}, + input_shape=(3, 2, 4), + expected_output_shape=(3, 5), + expected_num_trainable_weights=3, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + supports_masking=True, + ) + self.run_layer_test( + layers.RNN, + init_kwargs={"cell": TwoStatesRNNCell(5), "return_sequences": True}, + input_shape=(3, 2, 4), + expected_output_shape=(3, 2, 5), + expected_num_trainable_weights=3, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + supports_masking=True, + ) + + def test_compute_output_shape_single_state(self): + sequence = np.ones((3, 4, 5)) + layer = layers.RNN(OneStateRNNCell(8), return_sequences=False) + output_shape = layer.compute_output_shape(sequence.shape) + self.assertEqual(output_shape, (3, 8)) + + layer = layers.RNN(OneStateRNNCell(8), return_sequences=True) + output_shape = layer.compute_output_shape(sequence.shape) + self.assertEqual(output_shape, (3, 4, 8)) + + layer = layers.RNN( + OneStateRNNCell(8), return_sequences=False, return_state=True + ) + output_shape = layer.compute_output_shape(sequence.shape) + self.assertEqual(output_shape[0], (3, 8)) + self.assertEqual(output_shape[1], (3, 8)) + + layer = layers.RNN( + OneStateRNNCell(8), return_sequences=True, return_state=True + ) + output_shape = layer.compute_output_shape(sequence.shape) + self.assertEqual(output_shape[0], (3, 4, 8)) + self.assertEqual(output_shape[1], (3, 8)) + + def test_compute_output_shape_two_states(self): + sequence = np.ones((3, 4, 5)) + layer = layers.RNN(TwoStatesRNNCell(8), return_sequences=False) + output_shape = layer.compute_output_shape(sequence.shape) + self.assertEqual(output_shape, (3, 8)) + + layer = layers.RNN(TwoStatesRNNCell(8), return_sequences=True) + output_shape = layer.compute_output_shape(sequence.shape) + self.assertEqual(output_shape, (3, 4, 8)) + + layer = layers.RNN( + TwoStatesRNNCell(8), return_sequences=False, return_state=True + ) + output_shape = layer.compute_output_shape(sequence.shape) + self.assertEqual(output_shape[0], (3, 8)) + self.assertEqual(output_shape[1], (3, 8)) + self.assertEqual(output_shape[2], (3, 8)) + + layer = layers.RNN( + TwoStatesRNNCell(8), return_sequences=True, return_state=True + ) + output_shape = layer.compute_output_shape(sequence.shape) + self.assertEqual(output_shape[0], (3, 4, 8)) + self.assertEqual(output_shape[1], (3, 8)) + self.assertEqual(output_shape[2], (3, 8)) + + def test_dynamic_shapes(self): + sequence_shape = (None, None, 3) + layer = layers.RNN(OneStateRNNCell(8), return_sequences=False) + output_shape = layer.compute_output_shape(sequence_shape) + self.assertEqual(output_shape, (None, 8)) + + layer = layers.RNN(OneStateRNNCell(8), return_sequences=True) + output_shape = layer.compute_output_shape(sequence_shape) + self.assertEqual(output_shape, (None, None, 8)) + + layer = layers.RNN( + OneStateRNNCell(8), return_sequences=False, return_state=True + ) + output_shape = layer.compute_output_shape(sequence_shape) + self.assertEqual(output_shape[0], (None, 8)) + self.assertEqual(output_shape[1], (None, 8)) + + layer = layers.RNN( + OneStateRNNCell(8), return_sequences=True, return_state=True + ) + output_shape = layer.compute_output_shape(sequence_shape) + self.assertEqual(output_shape[0], (None, None, 8)) + self.assertEqual(output_shape[1], (None, 8)) + + layer = layers.RNN(TwoStatesRNNCell(8), return_sequences=False) + output_shape = layer.compute_output_shape(sequence_shape) + self.assertEqual(output_shape, (None, 8)) + + layer = layers.RNN(TwoStatesRNNCell(8), return_sequences=True) + output_shape = layer.compute_output_shape(sequence_shape) + self.assertEqual(output_shape, (None, None, 8)) + + layer = layers.RNN( + TwoStatesRNNCell(8), return_sequences=False, return_state=True + ) + output_shape = layer.compute_output_shape(sequence_shape) + self.assertEqual(output_shape[0], (None, 8)) + self.assertEqual(output_shape[1], (None, 8)) + self.assertEqual(output_shape[2], (None, 8)) + + layer = layers.RNN( + TwoStatesRNNCell(8), return_sequences=True, return_state=True + ) + output_shape = layer.compute_output_shape(sequence_shape) + self.assertEqual(output_shape[0], (None, None, 8)) + self.assertEqual(output_shape[1], (None, 8)) + self.assertEqual(output_shape[2], (None, 8)) + + def test_forward_pass_single_state(self): + sequence = np.ones((1, 2, 3)) + layer = layers.RNN(OneStateRNNCell(2), return_sequences=False) + output = layer(sequence) + self.assertAllClose(np.array([[9.0, 9.0]]), output) + + layer = layers.RNN(OneStateRNNCell(2), return_sequences=True) + output = layer(sequence) + self.assertAllClose(np.array([[[3.0, 3.0], [9.0, 9.0]]]), output) + + layer = layers.RNN( + OneStateRNNCell(2), return_sequences=False, return_state=True + ) + output, state = layer(sequence) + self.assertAllClose(np.array([[9.0, 9.0]]), output) + self.assertAllClose(np.array([[9.0, 9.0]]), state) + + layer = layers.RNN( + OneStateRNNCell(2), return_sequences=True, return_state=True + ) + output, state = layer(sequence) + self.assertAllClose(np.array([[[3.0, 3.0], [9.0, 9.0]]]), output) + self.assertAllClose(np.array([[9.0, 9.0]]), state) + + def test_forward_pass_two_states(self): + sequence = np.ones((1, 2, 3)) + layer = layers.RNN(TwoStatesRNNCell(2), return_sequences=False) + output = layer(sequence) + self.assertAllClose(np.array([[18.0, 18.0]]), output) + + layer = layers.RNN(TwoStatesRNNCell(2), return_sequences=True) + output = layer(sequence) + self.assertAllClose(np.array([[[6.0, 6.0], [18.0, 18.0]]]), output) + + layer = layers.RNN( + TwoStatesRNNCell(2), return_sequences=False, return_state=True + ) + output, state1, state2 = layer(sequence) + self.assertAllClose(np.array([[18.0, 18.0]]), output) + self.assertAllClose(np.array([[9.0, 9.0]]), state1) + self.assertAllClose(np.array([[9.0, 9.0]]), state2) + + layer = layers.RNN( + TwoStatesRNNCell(2), return_sequences=True, return_state=True + ) + output, state1, state2 = layer(sequence) + self.assertAllClose(np.array([[[6.0, 6.0], [18.0, 18.0]]]), output) + self.assertAllClose(np.array([[9.0, 9.0]]), state1) + self.assertAllClose(np.array([[9.0, 9.0]]), state2) + + def test_passing_initial_state_single_state(self): + sequence = np.ones((2, 3, 2)) + state = np.ones((2, 2)) + layer = layers.RNN(OneStateRNNCell(2), return_sequences=False) + output = layer(sequence, initial_state=state) + self.assertAllClose(np.array([[22.0, 22.0], [22.0, 22.0]]), output) + + layer = layers.RNN( + OneStateRNNCell(2), return_sequences=False, return_state=True + ) + output, state = layer(sequence, initial_state=state) + self.assertAllClose(np.array([[22.0, 22.0], [22.0, 22.0]]), output) + self.assertAllClose(np.array([[22.0, 22.0], [22.0, 22.0]]), state) + + def test_passing_initial_state_two_states(self): + sequence = np.ones((2, 3, 2)) + state = [np.ones((2, 2)), np.ones((2, 2))] + layer = layers.RNN(TwoStatesRNNCell(2), return_sequences=False) + output = layer(sequence, initial_state=state) + self.assertAllClose(np.array([[44.0, 44.0], [44.0, 44.0]]), output) + + layer = layers.RNN( + TwoStatesRNNCell(2), return_sequences=False, return_state=True + ) + output, state_1, state_2 = layer(sequence, initial_state=state) + self.assertAllClose(np.array([[44.0, 44.0], [44.0, 44.0]]), output) + self.assertAllClose(np.array([[22.0, 22.0], [22.0, 22.0]]), state_1) + self.assertAllClose(np.array([[22.0, 22.0], [22.0, 22.0]]), state_2) + + def test_statefulness_single_state(self): + sequence = np.ones((1, 2, 3)) + layer = layers.RNN(OneStateRNNCell(2), stateful=True) + layer(sequence) + output = layer(sequence) + self.assertAllClose(np.array([[45.0, 45.0]]), output) + + layer = layers.RNN(OneStateRNNCell(2), stateful=True, return_state=True) + layer(sequence) + output, state = layer(sequence) + self.assertAllClose(np.array([[45.0, 45.0]]), output) + self.assertAllClose(np.array([[45.0, 45.0]]), state) + + def test_statefulness_two_states(self): + sequence = np.ones((1, 2, 3)) + layer = layers.RNN(TwoStatesRNNCell(2), stateful=True) + layer(sequence) + output = layer(sequence) + self.assertAllClose(np.array([[90.0, 90.0]]), output) + + layer = layers.RNN( + TwoStatesRNNCell(2), stateful=True, return_state=True + ) + layer(sequence) + output, state_1, state_2 = layer(sequence) + self.assertAllClose(np.array([[90.0, 90.0]]), output) + self.assertAllClose(np.array([[45.0, 45.0]]), state_1) + self.assertAllClose(np.array([[45.0, 45.0]]), state_2) + + def test_go_backwards(self): + sequence = np.arange(24).reshape((2, 3, 4)).astype("float32") + layer = layers.RNN(OneStateRNNCell(2), go_backwards=True) + layer(sequence) + output = layer(sequence) + self.assertAllClose(np.array([[202.0, 202.0], [538.0, 538.0]]), output) + + layer = layers.RNN(OneStateRNNCell(2), stateful=True, return_state=True) + layer(sequence) + output, state = layer(sequence) + self.assertAllClose( + np.array([[954.0, 954.0], [3978.0, 3978.0]]), output + ) + self.assertAllClose(np.array([[954.0, 954.0], [3978.0, 3978.0]]), state) + + def test_serialization(self): + layer = layers.RNN(TwoStatesRNNCell(2), return_sequences=False) + self.run_class_serialization_test(layer) + + layer = layers.RNN(OneStateRNNCell(2), return_sequences=False) + self.run_class_serialization_test(layer) + + # TODO: test masking diff --git a/keras/src/layers/rnn/simple_rnn.py b/keras/src/layers/rnn/simple_rnn.py new file mode 100644 index 000000000000..b811baf88234 --- /dev/null +++ b/keras/src/layers/rnn/simple_rnn.py @@ -0,0 +1,449 @@ +from keras.src import activations +from keras.src import backend +from keras.src import constraints +from keras.src import initializers +from keras.src import ops +from keras.src import regularizers +from keras.src.api_export import keras_export +from keras.src.layers.input_spec import InputSpec +from keras.src.layers.layer import Layer +from keras.src.layers.rnn.dropout_rnn_cell import DropoutRNNCell +from keras.src.layers.rnn.rnn import RNN + + +@keras_export("keras.layers.SimpleRNNCell") +class SimpleRNNCell(Layer, DropoutRNNCell): + """Cell class for SimpleRNN. + + This class processes one step within the whole time sequence input, whereas + `keras.layer.SimpleRNN` processes the whole sequence. + + Args: + units: Positive integer, dimensionality of the output space. + activation: Activation function to use. + Default: hyperbolic tangent (`tanh`). + If you pass `None`, no activation is applied + (ie. "linear" activation: `a(x) = x`). + use_bias: Boolean, (default `True`), whether the layer + should use a bias vector. + kernel_initializer: Initializer for the `kernel` weights matrix, + used for the linear transformation of the inputs. Default: + `"glorot_uniform"`. + recurrent_initializer: Initializer for the `recurrent_kernel` + weights matrix, used for the linear transformation + of the recurrent state. Default: `"orthogonal"`. + bias_initializer: Initializer for the bias vector. Default: `"zeros"`. + kernel_regularizer: Regularizer function applied to the `kernel` weights + matrix. Default: `None`. + recurrent_regularizer: Regularizer function applied to the + `recurrent_kernel` weights matrix. Default: `None`. + bias_regularizer: Regularizer function applied to the bias vector. + Default: `None`. + kernel_constraint: Constraint function applied to the `kernel` weights + matrix. Default: `None`. + recurrent_constraint: Constraint function applied to the + `recurrent_kernel` weights matrix. Default: `None`. + bias_constraint: Constraint function applied to the bias vector. + Default: `None`. + dropout: Float between 0 and 1. Fraction of the units to drop for the + linear transformation of the inputs. Default: 0. + recurrent_dropout: Float between 0 and 1. Fraction of the units to drop + for the linear transformation of the recurrent state. Default: 0. + seed: Random seed for dropout. + + Call arguments: + sequence: A 2D tensor, with shape `(batch, features)`. + states: A 2D tensor with shape `(batch, units)`, which is the state + from the previous time step. + training: Python boolean indicating whether the layer should behave in + training mode or in inference mode. Only relevant when `dropout` or + `recurrent_dropout` is used. + + Example: + + ```python + inputs = np.random.random([32, 10, 8]).astype(np.float32) + rnn = keras.layers.RNN(keras.layers.SimpleRNNCell(4)) + output = rnn(inputs) # The output has shape `(32, 4)`. + rnn = keras.layers.RNN( + keras.layers.SimpleRNNCell(4), + return_sequences=True, + return_state=True + ) + # whole_sequence_output has shape `(32, 10, 4)`. + # final_state has shape `(32, 4)`. + whole_sequence_output, final_state = rnn(inputs) + ``` + """ + + def __init__( + self, + units, + activation="tanh", + use_bias=True, + kernel_initializer="glorot_uniform", + recurrent_initializer="orthogonal", + bias_initializer="zeros", + kernel_regularizer=None, + recurrent_regularizer=None, + bias_regularizer=None, + kernel_constraint=None, + recurrent_constraint=None, + bias_constraint=None, + dropout=0.0, + recurrent_dropout=0.0, + seed=None, + **kwargs, + ): + if units <= 0: + raise ValueError( + "Received an invalid value for argument `units`, " + f"expected a positive integer, got {units}." + ) + super().__init__(**kwargs) + self.seed = seed + self.seed_generator = backend.random.SeedGenerator(seed) + + self.units = units + self.activation = activations.get(activation) + self.use_bias = use_bias + + self.kernel_initializer = initializers.get(kernel_initializer) + self.recurrent_initializer = initializers.get(recurrent_initializer) + self.bias_initializer = initializers.get(bias_initializer) + + self.kernel_regularizer = regularizers.get(kernel_regularizer) + self.recurrent_regularizer = regularizers.get(recurrent_regularizer) + self.bias_regularizer = regularizers.get(bias_regularizer) + + self.kernel_constraint = constraints.get(kernel_constraint) + self.recurrent_constraint = constraints.get(recurrent_constraint) + self.bias_constraint = constraints.get(bias_constraint) + + self.dropout = min(1.0, max(0.0, dropout)) + self.recurrent_dropout = min(1.0, max(0.0, recurrent_dropout)) + self.state_size = self.units + self.output_size = self.units + + def build(self, input_shape): + self.kernel = self.add_weight( + shape=(input_shape[-1], self.units), + name="kernel", + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, + constraint=self.kernel_constraint, + ) + self.recurrent_kernel = self.add_weight( + shape=(self.units, self.units), + name="recurrent_kernel", + initializer=self.recurrent_initializer, + regularizer=self.recurrent_regularizer, + constraint=self.recurrent_constraint, + ) + if self.use_bias: + self.bias = self.add_weight( + shape=(self.units,), + name="bias", + initializer=self.bias_initializer, + regularizer=self.bias_regularizer, + constraint=self.bias_constraint, + ) + else: + self.bias = None + + def call(self, sequence, states, training=False): + prev_output = states[0] if isinstance(states, (list, tuple)) else states + dp_mask = self.get_dropout_mask(sequence) + rec_dp_mask = self.get_recurrent_dropout_mask(prev_output) + + if training and dp_mask is not None: + sequence = sequence * dp_mask + h = ops.matmul(sequence, self.kernel) + if self.bias is not None: + h = ops.add(h, self.bias) + + if training and rec_dp_mask is not None: + prev_output = prev_output * rec_dp_mask + output = h + ops.matmul(prev_output, self.recurrent_kernel) + if self.activation is not None: + output = self.activation(output) + + new_state = [output] if isinstance(states, (list, tuple)) else output + return output, new_state + + def get_initial_state(self, batch_size=None): + return [ + ops.zeros((batch_size, self.state_size), dtype=self.compute_dtype) + ] + + def get_config(self): + config = { + "units": self.units, + "activation": activations.serialize(self.activation), + "use_bias": self.use_bias, + "kernel_initializer": initializers.serialize( + self.kernel_initializer + ), + "recurrent_initializer": initializers.serialize( + self.recurrent_initializer + ), + "bias_initializer": initializers.serialize(self.bias_initializer), + "kernel_regularizer": regularizers.serialize( + self.kernel_regularizer + ), + "recurrent_regularizer": regularizers.serialize( + self.recurrent_regularizer + ), + "bias_regularizer": regularizers.serialize(self.bias_regularizer), + "kernel_constraint": constraints.serialize(self.kernel_constraint), + "recurrent_constraint": constraints.serialize( + self.recurrent_constraint + ), + "bias_constraint": constraints.serialize(self.bias_constraint), + "dropout": self.dropout, + "recurrent_dropout": self.recurrent_dropout, + "seed": self.seed, + } + base_config = super().get_config() + return {**base_config, **config} + + +@keras_export("keras.layers.SimpleRNN") +class SimpleRNN(RNN): + """Fully-connected RNN where the output is to be fed back as the new input. + + Args: + units: Positive integer, dimensionality of the output space. + activation: Activation function to use. + Default: hyperbolic tangent (`tanh`). + If you pass None, no activation is applied + (ie. "linear" activation: `a(x) = x`). + use_bias: Boolean, (default `True`), whether the layer uses + a bias vector. + kernel_initializer: Initializer for the `kernel` weights matrix, + used for the linear transformation of the inputs. Default: + `"glorot_uniform"`. + recurrent_initializer: Initializer for the `recurrent_kernel` + weights matrix, used for the linear transformation of the recurrent + state. Default: `"orthogonal"`. + bias_initializer: Initializer for the bias vector. Default: `"zeros"`. + kernel_regularizer: Regularizer function applied to the `kernel` weights + matrix. Default: `None`. + recurrent_regularizer: Regularizer function applied to the + `recurrent_kernel` weights matrix. Default: `None`. + bias_regularizer: Regularizer function applied to the bias vector. + Default: `None`. + activity_regularizer: Regularizer function applied to the output of the + layer (its "activation"). Default: `None`. + kernel_constraint: Constraint function applied to the `kernel` weights + matrix. Default: `None`. + recurrent_constraint: Constraint function applied to the + `recurrent_kernel` weights matrix. Default: `None`. + bias_constraint: Constraint function applied to the bias vector. + Default: `None`. + dropout: Float between 0 and 1. + Fraction of the units to drop for the linear transformation + of the inputs. Default: 0. + recurrent_dropout: Float between 0 and 1. + Fraction of the units to drop for the linear transformation of the + recurrent state. Default: 0. + return_sequences: Boolean. Whether to return the last output + in the output sequence, or the full sequence. Default: `False`. + return_state: Boolean. Whether to return the last state + in addition to the output. Default: `False`. + go_backwards: Boolean (default: `False`). + If `True`, process the input sequence backwards and return the + reversed sequence. + stateful: Boolean (default: `False`). If `True`, the last state + for each sample at index i in a batch will be used as the + initial state for the sample of index i in the following batch. + unroll: Boolean (default: `False`). + If `True`, the network will be unrolled, + else a symbolic loop will be used. + Unrolling can speed-up an RNN, + although it tends to be more memory-intensive. + Unrolling is only suitable for short sequences. + + Call arguments: + sequence: A 3D tensor, with shape `[batch, timesteps, feature]`. + mask: Binary tensor of shape `[batch, timesteps]` indicating whether + a given timestep should be masked. An individual `True` entry + indicates that the corresponding timestep should be utilized, + while a `False` entry indicates that the corresponding timestep + should be ignored. + training: Python boolean indicating whether the layer should behave in + training mode or in inference mode. + This argument is passed to the cell when calling it. + This is only relevant if `dropout` or `recurrent_dropout` is used. + initial_state: List of initial state tensors to be passed to the first + call of the cell. + + Example: + + ```python + inputs = np.random.random((32, 10, 8)) + simple_rnn = keras.layers.SimpleRNN(4) + output = simple_rnn(inputs) # The output has shape `(32, 4)`. + simple_rnn = keras.layers.SimpleRNN( + 4, return_sequences=True, return_state=True + ) + # whole_sequence_output has shape `(32, 10, 4)`. + # final_state has shape `(32, 4)`. + whole_sequence_output, final_state = simple_rnn(inputs) + ``` + """ + + def __init__( + self, + units, + activation="tanh", + use_bias=True, + kernel_initializer="glorot_uniform", + recurrent_initializer="orthogonal", + bias_initializer="zeros", + kernel_regularizer=None, + recurrent_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + recurrent_constraint=None, + bias_constraint=None, + dropout=0.0, + recurrent_dropout=0.0, + return_sequences=False, + return_state=False, + go_backwards=False, + stateful=False, + unroll=False, + seed=None, + **kwargs, + ): + cell = SimpleRNNCell( + units, + activation=activation, + use_bias=use_bias, + kernel_initializer=kernel_initializer, + recurrent_initializer=recurrent_initializer, + bias_initializer=bias_initializer, + kernel_regularizer=kernel_regularizer, + recurrent_regularizer=recurrent_regularizer, + bias_regularizer=bias_regularizer, + kernel_constraint=kernel_constraint, + recurrent_constraint=recurrent_constraint, + bias_constraint=bias_constraint, + dropout=dropout, + recurrent_dropout=recurrent_dropout, + seed=seed, + dtype=kwargs.get("dtype", None), + trainable=kwargs.get("trainable", True), + name="simple_rnn_cell", + ) + super().__init__( + cell, + return_sequences=return_sequences, + return_state=return_state, + go_backwards=go_backwards, + stateful=stateful, + unroll=unroll, + **kwargs, + ) + self.input_spec = [InputSpec(ndim=3)] + + def call(self, sequences, initial_state=None, mask=None, training=False): + return super().call( + sequences, mask=mask, training=training, initial_state=initial_state + ) + + @property + def units(self): + return self.cell.units + + @property + def activation(self): + return self.cell.activation + + @property + def use_bias(self): + return self.cell.use_bias + + @property + def kernel_initializer(self): + return self.cell.kernel_initializer + + @property + def recurrent_initializer(self): + return self.cell.recurrent_initializer + + @property + def bias_initializer(self): + return self.cell.bias_initializer + + @property + def kernel_regularizer(self): + return self.cell.kernel_regularizer + + @property + def recurrent_regularizer(self): + return self.cell.recurrent_regularizer + + @property + def bias_regularizer(self): + return self.cell.bias_regularizer + + @property + def kernel_constraint(self): + return self.cell.kernel_constraint + + @property + def recurrent_constraint(self): + return self.cell.recurrent_constraint + + @property + def bias_constraint(self): + return self.cell.bias_constraint + + @property + def dropout(self): + return self.cell.dropout + + @property + def recurrent_dropout(self): + return self.cell.recurrent_dropout + + def get_config(self): + config = { + "units": self.units, + "activation": activations.serialize(self.activation), + "use_bias": self.use_bias, + "kernel_initializer": initializers.serialize( + self.kernel_initializer + ), + "recurrent_initializer": initializers.serialize( + self.recurrent_initializer + ), + "bias_initializer": initializers.serialize(self.bias_initializer), + "kernel_regularizer": regularizers.serialize( + self.kernel_regularizer + ), + "recurrent_regularizer": regularizers.serialize( + self.recurrent_regularizer + ), + "bias_regularizer": regularizers.serialize(self.bias_regularizer), + "activity_regularizer": regularizers.serialize( + self.activity_regularizer + ), + "kernel_constraint": constraints.serialize(self.kernel_constraint), + "recurrent_constraint": constraints.serialize( + self.recurrent_constraint + ), + "bias_constraint": constraints.serialize(self.bias_constraint), + "dropout": self.dropout, + "recurrent_dropout": self.recurrent_dropout, + } + base_config = super().get_config() + del base_config["cell"] + return {**base_config, **config} + + @classmethod + def from_config(cls, config): + return cls(**config) diff --git a/keras/src/layers/rnn/simple_rnn_test.py b/keras/src/layers/rnn/simple_rnn_test.py new file mode 100644 index 000000000000..8493bdbee8a8 --- /dev/null +++ b/keras/src/layers/rnn/simple_rnn_test.py @@ -0,0 +1,283 @@ +import numpy as np +import pytest + +from keras.src import initializers +from keras.src import layers +from keras.src import testing + + +class SimpleRNNTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_basics(self): + self.run_layer_test( + layers.SimpleRNN, + init_kwargs={"units": 3, "dropout": 0.5, "recurrent_dropout": 0.5}, + input_shape=(3, 2, 4), + call_kwargs={"training": True}, + expected_output_shape=(3, 3), + expected_num_trainable_weights=3, + expected_num_non_trainable_weights=0, + expected_num_non_trainable_variables=1, + supports_masking=True, + ) + self.run_layer_test( + layers.SimpleRNN, + init_kwargs={ + "units": 3, + "return_sequences": True, + "bias_regularizer": "l1", + "kernel_regularizer": "l2", + "recurrent_regularizer": "l2", + }, + input_shape=(3, 2, 4), + expected_output_shape=(3, 2, 3), + expected_num_losses=3, + expected_num_trainable_weights=3, + expected_num_non_trainable_weights=0, + supports_masking=True, + ) + + def test_correctness(self): + sequence = np.arange(24).reshape((2, 3, 4)).astype("float32") + layer = layers.SimpleRNN( + 4, + kernel_initializer=initializers.Constant(0.01), + recurrent_initializer=initializers.Constant(0.02), + bias_initializer=initializers.Constant(0.03), + ) + output = layer(sequence) + self.assertAllClose( + np.array( + [ + [0.405432, 0.405432, 0.405432, 0.405432], + [0.73605347, 0.73605347, 0.73605347, 0.73605347], + ] + ), + output, + ) + layer = layers.SimpleRNN( + 4, + kernel_initializer=initializers.Constant(0.01), + recurrent_initializer=initializers.Constant(0.02), + bias_initializer=initializers.Constant(0.03), + unroll=True, + ) + output = layer(sequence) + self.assertAllClose( + np.array( + [ + [0.405432, 0.405432, 0.405432, 0.405432], + [0.73605347, 0.73605347, 0.73605347, 0.73605347], + ] + ), + output, + ) + + layer = layers.SimpleRNN( + 4, + kernel_initializer=initializers.Constant(0.01), + recurrent_initializer=initializers.Constant(0.02), + bias_initializer=initializers.Constant(0.03), + go_backwards=True, + ) + output = layer(sequence) + self.assertAllClose( + np.array( + [ + [0.11144729, 0.11144729, 0.11144729, 0.11144729], + [0.5528889, 0.5528889, 0.5528889, 0.5528889], + ] + ), + output, + ) + layer = layers.SimpleRNN( + 4, + kernel_initializer=initializers.Constant(0.01), + recurrent_initializer=initializers.Constant(0.02), + bias_initializer=initializers.Constant(0.03), + go_backwards=True, + unroll=True, + ) + output = layer(sequence) + self.assertAllClose( + np.array( + [ + [0.11144729, 0.11144729, 0.11144729, 0.11144729], + [0.5528889, 0.5528889, 0.5528889, 0.5528889], + ] + ), + output, + ) + + def test_statefulness(self): + sequence = np.arange(24).reshape((2, 3, 4)).astype("float32") + layer = layers.SimpleRNN( + 4, + stateful=True, + kernel_initializer=initializers.Constant(0.01), + recurrent_initializer=initializers.Constant(0.02), + bias_initializer=initializers.Constant(0.03), + ) + layer(sequence) + output = layer(sequence) + self.assertAllClose( + np.array( + [ + [0.40559256, 0.40559256, 0.40559256, 0.40559256], + [0.7361247, 0.7361247, 0.7361247, 0.7361247], + ] + ), + output, + ) + layer.reset_state() + layer(sequence) + output = layer(sequence) + self.assertAllClose( + np.array( + [ + [0.40559256, 0.40559256, 0.40559256, 0.40559256], + [0.7361247, 0.7361247, 0.7361247, 0.7361247], + ] + ), + output, + ) + + def test_pass_initial_state(self): + sequence = np.arange(24).reshape((2, 4, 3)).astype("float32") + initial_state = np.arange(8).reshape((2, 4)).astype("float32") + layer = layers.SimpleRNN( + 4, + kernel_initializer=initializers.Constant(0.01), + recurrent_initializer=initializers.Constant(0.02), + bias_initializer=initializers.Constant(0.03), + ) + output = layer(sequence, initial_state=initial_state) + self.assertAllClose( + np.array( + [ + [0.33621645, 0.33621645, 0.33621645, 0.33621645], + [0.6262637, 0.6262637, 0.6262637, 0.6262637], + ] + ), + output, + ) + + layer = layers.SimpleRNN( + 4, + kernel_initializer=initializers.Constant(0.01), + recurrent_initializer=initializers.Constant(0.02), + bias_initializer=initializers.Constant(0.03), + go_backwards=True, + ) + output = layer(sequence, initial_state=initial_state) + self.assertAllClose( + np.array( + [ + [0.07344437, 0.07344437, 0.07344437, 0.07344437], + [0.43043602, 0.43043602, 0.43043602, 0.43043602], + ] + ), + output, + ) + + def test_masking(self): + sequence = np.arange(24).reshape((2, 4, 3)).astype("float32") + mask = np.array([[True, True, False, True], [True, False, False, True]]) + layer = layers.SimpleRNN( + 4, + kernel_initializer=initializers.Constant(0.01), + recurrent_initializer=initializers.Constant(0.02), + bias_initializer=initializers.Constant(0.03), + unroll=True, + ) + output = layer(sequence, mask=mask) + self.assertAllClose( + np.array( + [ + [0.32951632, 0.32951632, 0.32951632, 0.32951632], + [0.61799484, 0.61799484, 0.61799484, 0.61799484], + ] + ), + output, + ) + + layer = layers.SimpleRNN( + 2, + kernel_initializer=initializers.Constant(0.01), + recurrent_initializer=initializers.Constant(0.02), + bias_initializer=initializers.Constant(0.03), + return_sequences=True, + ) + output = layer(sequence, mask=mask) + self.assertAllClose( + np.array( + [ + [0.0599281, 0.0599281], + [0.15122814, 0.15122814], + [0.15122814, 0.15122814], + [0.32394567, 0.32394567], + ], + ), + output[0], + ) + self.assertAllClose( + np.array( + [ + [0.3969304, 0.3969304], + [0.3969304, 0.3969304], + [0.3969304, 0.3969304], + [0.608085, 0.608085], + ], + ), + output[1], + ) + + layer = layers.SimpleRNN( + 2, + kernel_initializer=initializers.Constant(0.01), + recurrent_initializer=initializers.Constant(0.02), + bias_initializer=initializers.Constant(0.03), + return_sequences=True, + zero_output_for_mask=True, + ) + output = layer(sequence, mask=mask) + self.assertAllClose( + np.array( + [ + [0.0599281, 0.0599281], + [0.15122814, 0.15122814], + [0.0, 0.0], + [0.32394567, 0.32394567], + ], + ), + output[0], + ) + self.assertAllClose( + np.array( + [ + [0.3969304, 0.3969304], + [0.0, 0.0], + [0.0, 0.0], + [0.608085, 0.608085], + ], + ), + output[1], + ) + + layer = layers.SimpleRNN( + 4, + kernel_initializer=initializers.Constant(0.01), + recurrent_initializer=initializers.Constant(0.02), + bias_initializer=initializers.Constant(0.03), + go_backwards=True, + ) + output = layer(sequence, mask=mask) + self.assertAllClose( + np.array( + [ + [0.07376196, 0.07376196, 0.07376196, 0.07376196], + [0.43645123, 0.43645123, 0.43645123, 0.43645123], + ] + ), + output, + ) diff --git a/keras/src/layers/rnn/stacked_rnn_cells.py b/keras/src/layers/rnn/stacked_rnn_cells.py new file mode 100644 index 000000000000..613ec7f2b1ee --- /dev/null +++ b/keras/src/layers/rnn/stacked_rnn_cells.py @@ -0,0 +1,138 @@ +from keras.src import ops +from keras.src import tree +from keras.src.api_export import keras_export +from keras.src.layers.layer import Layer +from keras.src.saving import serialization_lib + + +@keras_export("keras.layers.StackedRNNCells") +class StackedRNNCells(Layer): + """Wrapper allowing a stack of RNN cells to behave as a single cell. + + Used to implement efficient stacked RNNs. + + Args: + cells: List of RNN cell instances. + + Example: + + ```python + batch_size = 3 + sentence_length = 5 + num_features = 2 + new_shape = (batch_size, sentence_length, num_features) + x = np.reshape(np.arange(30), new_shape) + + rnn_cells = [keras.layers.LSTMCell(128) for _ in range(2)] + stacked_lstm = keras.layers.StackedRNNCells(rnn_cells) + lstm_layer = keras.layers.RNN(stacked_lstm) + + result = lstm_layer(x) + ``` + """ + + def __init__(self, cells, **kwargs): + super().__init__(**kwargs) + for cell in cells: + if "call" not in dir(cell): + raise ValueError( + "All cells must have a `call` method. " + f"Received cell without a `call` method: {cell}" + ) + if "state_size" not in dir(cell): + raise ValueError( + "All cells must have a `state_size` attribute. " + f"Received cell without a `state_size`: {cell}" + ) + self.cells = cells + + @property + def state_size(self): + return [c.state_size for c in self.cells] + + @property + def output_size(self): + if getattr(self.cells[-1], "output_size", None) is not None: + return self.cells[-1].output_size + elif isinstance(self.cells[-1].state_size, (list, tuple)): + return self.cells[-1].state_size[0] + else: + return self.cells[-1].state_size + + def get_initial_state(self, batch_size=None): + initial_states = [] + for cell in self.cells: + get_initial_state_fn = getattr(cell, "get_initial_state", None) + if get_initial_state_fn: + initial_states.append( + get_initial_state_fn(batch_size=batch_size) + ) + else: + if isinstance(cell.state_size, int): + initial_states.append( + ops.zeros( + (batch_size, cell.state_size), + dtype=self.compute_dtype, + ) + ) + else: + initial_states.append( + [ + ops.zeros((batch_size, d), dtype=self.compute_dtype) + for d in cell.state_size + ] + ) + return initial_states + + def call(self, inputs, states, training=False, **kwargs): + # Call the cells in order and store the returned states. + new_states = [] + for cell, states in zip(self.cells, states): + state_is_list = tree.is_nested(states) + states = list(states) if tree.is_nested(states) else [states] + if isinstance(cell, Layer) and cell._call_has_training_arg: + kwargs["training"] = training + else: + kwargs.pop("training", None) + cell_call_fn = cell.__call__ if callable(cell) else cell.call + inputs, states = cell_call_fn(inputs, states, **kwargs) + if len(states) == 1 and not state_is_list: + states = states[0] + new_states.append(states) + + if len(new_states) == 1: + new_states = new_states[0] + return inputs, new_states + + def build(self, input_shape): + for cell in self.cells: + if isinstance(cell, Layer) and not cell.built: + cell.build(input_shape) + cell.built = True + if getattr(cell, "output_size", None) is not None: + output_dim = cell.output_size + elif isinstance(cell.state_size, (list, tuple)): + output_dim = cell.state_size[0] + else: + output_dim = cell.state_size + batch_size = tree.flatten(input_shape)[0] + input_shape = (batch_size, output_dim) + + def get_config(self): + cells = [] + for cell in self.cells: + cells.append(serialization_lib.serialize_keras_object(cell)) + config = {"cells": cells} + base_config = super().get_config() + return {**base_config, **config} + + @classmethod + def from_config(cls, config, custom_objects=None): + cells = [] + for cell_config in config.pop("cells"): + cells.append( + serialization_lib.deserialize_keras_object( + cell_config, custom_objects=custom_objects + ) + ) + return cls(cells, **config) diff --git a/keras/src/layers/rnn/stacked_rnn_cells_test.py b/keras/src/layers/rnn/stacked_rnn_cells_test.py new file mode 100644 index 000000000000..1b87b177f64b --- /dev/null +++ b/keras/src/layers/rnn/stacked_rnn_cells_test.py @@ -0,0 +1,288 @@ +import numpy as np +import pytest + +from keras.src import layers +from keras.src import testing +from keras.src.layers.rnn.rnn_test import OneStateRNNCell +from keras.src.layers.rnn.rnn_test import TwoStatesRNNCell + + +class StackedRNNTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_basics(self): + self.run_layer_test( + layers.RNN, + init_kwargs={ + "cell": [ + OneStateRNNCell(3), + OneStateRNNCell(4), + OneStateRNNCell(5), + ], + }, + input_shape=(2, 3, 4), + expected_output_shape=(2, 5), + expected_num_trainable_weights=6, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + supports_masking=True, + custom_objects={"OneStateRNNCell": OneStateRNNCell}, + ) + self.run_layer_test( + layers.RNN, + init_kwargs={ + "cell": [ + OneStateRNNCell(3), + OneStateRNNCell(4), + OneStateRNNCell(5), + ], + "return_sequences": True, + }, + input_shape=(2, 3, 4), + expected_output_shape=(2, 3, 5), + expected_num_trainable_weights=6, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + supports_masking=True, + custom_objects={"OneStateRNNCell": OneStateRNNCell}, + ) + # Two-state case. + self.run_layer_test( + layers.RNN, + init_kwargs={ + "cell": [ + TwoStatesRNNCell(3), + TwoStatesRNNCell(4), + TwoStatesRNNCell(5), + ], + }, + input_shape=(2, 3, 4), + expected_output_shape=(2, 5), + expected_num_trainable_weights=9, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + supports_masking=True, + custom_objects={"TwoStatesRNNCell": TwoStatesRNNCell}, + ) + self.run_layer_test( + layers.RNN, + init_kwargs={ + "cell": [ + TwoStatesRNNCell(3), + TwoStatesRNNCell(4), + TwoStatesRNNCell(5), + ], + "return_sequences": True, + }, + input_shape=(2, 3, 4), + expected_output_shape=(2, 3, 5), + expected_num_trainable_weights=9, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + supports_masking=True, + custom_objects={"TwoStatesRNNCell": TwoStatesRNNCell}, + ) + self.run_layer_test( + layers.RNN, + init_kwargs={ + "cell": [ + layers.SimpleRNNCell(3, dropout=0.1, recurrent_dropout=0.1), + layers.SimpleRNNCell(4, dropout=0.1, recurrent_dropout=0.1), + layers.SimpleRNNCell(5, dropout=0.1, recurrent_dropout=0.1), + ], + "return_sequences": True, + }, + input_shape=(2, 3, 4), + expected_output_shape=(2, 3, 5), + expected_num_trainable_weights=9, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=3, + supports_masking=True, + ) + self.run_layer_test( + layers.RNN, + init_kwargs={ + "cell": [ + layers.GRUCell(3, dropout=0.1, recurrent_dropout=0.1), + layers.GRUCell(4, dropout=0.1, recurrent_dropout=0.1), + layers.GRUCell(5, dropout=0.1, recurrent_dropout=0.1), + ], + "return_sequences": True, + }, + input_shape=(2, 3, 4), + expected_output_shape=(2, 3, 5), + expected_num_trainable_weights=9, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=3, + supports_masking=True, + ) + self.run_layer_test( + layers.RNN, + init_kwargs={ + "cell": [ + layers.LSTMCell(3, dropout=0.1, recurrent_dropout=0.1), + layers.LSTMCell(4, dropout=0.1, recurrent_dropout=0.1), + layers.LSTMCell(5, dropout=0.1, recurrent_dropout=0.1), + ], + "return_sequences": True, + }, + input_shape=(2, 3, 4), + expected_output_shape=(2, 3, 5), + expected_num_trainable_weights=9, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=3, + supports_masking=True, + ) + + def test_correctness_single_state_stack(self): + sequence = np.arange(24).reshape((2, 3, 4)).astype("float32") + layer = layers.RNN([OneStateRNNCell(3), OneStateRNNCell(2)]) + output = layer(sequence) + self.assertAllClose( + np.array([[786.0, 786.0], [4386.0, 4386.0]]), output + ) + + layer = layers.RNN( + [OneStateRNNCell(3), OneStateRNNCell(2)], return_sequences=True + ) + output = layer(sequence) + self.assertAllClose( + np.array( + [ + [[18.0, 18.0], [156.0, 156.0], [786.0, 786.0]], + [[162.0, 162.0], [1020.0, 1020.0], [4386.0, 4386.0]], + ] + ), + output, + ) + + layer = layers.RNN( + [OneStateRNNCell(3), OneStateRNNCell(2)], return_state=True + ) + output, state_1, state_2 = layer(sequence) + self.assertAllClose( + np.array([[786.0, 786.0], [4386.0, 4386.0]]), output + ) + self.assertAllClose( + np.array([[158.0, 158.0, 158.0], [782.0, 782.0, 782.0]]), state_1 + ) + self.assertAllClose( + np.array([[786.0, 786.0], [4386.0, 4386.0]]), state_2 + ) + + layer = layers.RNN( + [OneStateRNNCell(3), OneStateRNNCell(2)], + return_sequences=True, + return_state=True, + ) + output, state_1, state_2 = layer(sequence) + self.assertAllClose( + np.array( + [ + [[18.0, 18.0], [156.0, 156.0], [786.0, 786.0]], + [[162.0, 162.0], [1020.0, 1020.0], [4386.0, 4386.0]], + ] + ), + output, + ) + self.assertAllClose( + np.array([[158.0, 158.0, 158.0], [782.0, 782.0, 782.0]]), state_1 + ) + self.assertAllClose( + np.array([[786.0, 786.0], [4386.0, 4386.0]]), state_2 + ) + + def test_correctness_two_states_stack(self): + sequence = np.arange(24).reshape((2, 3, 4)).astype("float32") + layer = layers.RNN([TwoStatesRNNCell(3), TwoStatesRNNCell(2)]) + output = layer(sequence) + self.assertAllClose( + np.array([[3144.0, 3144.0], [17544.0, 17544.0]]), output + ) + + layer = layers.RNN( + [TwoStatesRNNCell(3), TwoStatesRNNCell(2)], return_sequences=True + ) + output = layer(sequence) + self.assertAllClose( + np.array( + [ + [[72.0, 72.0], [624.0, 624.0], [3144.0, 3144.0]], + [[648.0, 648.0], [4080.0, 4080.0], [17544.0, 17544.0]], + ] + ), + output, + ) + + layer = layers.RNN( + [TwoStatesRNNCell(3), TwoStatesRNNCell(2)], return_state=True + ) + output, state_1, state_2 = layer(sequence) + + self.assertAllClose( + np.array([[3144.0, 3144.0], [17544.0, 17544.0]]), output + ) + self.assertAllClose( + np.array([[158.0, 158.0, 158.0], [782.0, 782.0, 782.0]]), state_1[0] + ) + self.assertAllClose( + np.array([[158.0, 158.0, 158.0], [782.0, 782.0, 782.0]]), state_1[1] + ) + self.assertAllClose( + np.array([[1572.0, 1572.0], [8772.0, 8772.0]]), state_2[0] + ) + self.assertAllClose( + np.array([[1572.0, 1572.0], [8772.0, 8772.0]]), state_2[1] + ) + + def test_statefullness_single_state_stack(self): + sequence = np.arange(24).reshape((2, 3, 4)).astype("float32") + layer = layers.RNN( + [OneStateRNNCell(3), OneStateRNNCell(2)], stateful=True + ) + layer(sequence) + output = layer(sequence) + self.assertAllClose( + np.array([[34092.0, 34092.0], [173196.0, 173196.0]]), output + ) + + def test_statefullness_two_states_stack(self): + sequence = np.arange(24).reshape((2, 3, 4)).astype("float32") + layer = layers.RNN( + [TwoStatesRNNCell(3), TwoStatesRNNCell(2)], stateful=True + ) + layer(sequence) + output = layer(sequence) + self.assertAllClose( + np.array([[136368.0, 136368.0], [692784.0, 692784.0]]), output + ) + + def test_return_state_stacked_lstm_cell(self): + layer = layers.RNN( + [layers.LSTMCell(10), layers.LSTMCell(10)], return_state=True + ) + out = layer(np.zeros((2, 3, 5))) + self.assertLen(out, 3) + self.assertEqual(out[0].shape, (2, 10)) + self.assertEqual(out[1][0].shape, (2, 10)) + self.assertEqual(out[1][1].shape, (2, 10)) + self.assertEqual(out[2][0].shape, (2, 10)) + self.assertEqual(out[2][1].shape, (2, 10)) + + shape = layer.compute_output_shape((2, 3, 5)) + self.assertLen(shape, 3) + self.assertEqual(shape[0], (2, 10)) + self.assertEqual(shape[1][0], (2, 10)) + self.assertEqual(shape[1][1], (2, 10)) + self.assertEqual(shape[2][0], (2, 10)) + self.assertEqual(shape[2][1], (2, 10)) + + def test_stacked_lstm_cell_mask(self): + sequence = np.ones((2, 3, 4)) + mask = np.array([[True, True, True], [True, True, False]]) + cell_kwargs = dict( + units=1, kernel_initializer="ones", recurrent_initializer="ones" + ) + rnn_cells = [layers.LSTMCell(**cell_kwargs) for _ in range(2)] + stacked_rnn = layers.RNN(rnn_cells) + output = stacked_rnn(sequence, mask=mask) + self.assertAllClose(np.array([[0.7793], [0.5998]]), output, atol=1e-4) diff --git a/keras/src/layers/rnn/time_distributed.py b/keras/src/layers/rnn/time_distributed.py new file mode 100644 index 000000000000..51aec7893f1d --- /dev/null +++ b/keras/src/layers/rnn/time_distributed.py @@ -0,0 +1,133 @@ +"""Wrapper layer to apply every temporal slice of an input.""" + +from keras.src import backend +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.layers.core.wrapper import Wrapper +from keras.src.layers.layer import Layer + + +@keras_export("keras.layers.TimeDistributed") +class TimeDistributed(Wrapper): + """This wrapper allows to apply a layer to every temporal slice of an input. + + Every input should be at least 3D, and the dimension of index one of the + first input will be considered to be the temporal dimension. + + Consider a batch of 32 video samples, where each sample is a 128x128 RGB + image with `channels_last` data format, across 10 timesteps. + The batch input shape is `(32, 10, 128, 128, 3)`. + + You can then use `TimeDistributed` to apply the same `Conv2D` layer to each + of the 10 timesteps, independently: + + >>> inputs = layers.Input(shape=(10, 128, 128, 3), batch_size=32) + >>> conv_2d_layer = layers.Conv2D(64, (3, 3)) + >>> outputs = layers.TimeDistributed(conv_2d_layer)(inputs) + >>> outputs.shape + (32, 10, 126, 126, 64) + + Because `TimeDistributed` applies the same instance of `Conv2D` to each of + the timestamps, the same set of weights are used at each timestamp. + + Args: + layer: a `keras.layers.Layer` instance. + + Call arguments: + inputs: Input tensor of shape (batch, time, ...) or nested tensors, + and each of which has shape (batch, time, ...). + training: Python boolean indicating whether the layer should behave in + training mode or in inference mode. This argument is passed to the + wrapped layer (only if the layer supports this argument). + mask: Binary tensor of shape `(samples, timesteps)` indicating whether + a given timestep should be masked. This argument is passed to the + wrapped layer (only if the layer supports this argument). + """ + + def __init__(self, layer, **kwargs): + if not isinstance(layer, Layer): + raise ValueError( + "Please initialize `TimeDistributed` layer with a " + f"`keras.layers.Layer` instance. Received: {layer}" + ) + super().__init__(layer, **kwargs) + self.supports_masking = True + + def _get_child_input_shape(self, input_shape): + if not isinstance(input_shape, (tuple, list)) or len(input_shape) < 3: + raise ValueError( + "`TimeDistributed` Layer should be passed an `input_shape` " + f"with at least 3 dimensions, received: {input_shape}" + ) + return (input_shape[0], *input_shape[2:]) + + def compute_output_shape(self, input_shape): + child_input_shape = self._get_child_input_shape(input_shape) + child_output_shape = self.layer.compute_output_shape(child_input_shape) + return (child_output_shape[0], input_shape[1], *child_output_shape[1:]) + + def build(self, input_shape): + child_input_shape = self._get_child_input_shape(input_shape) + super().build(child_input_shape) + + def call(self, inputs, training=None, mask=None): + input_shape = ops.shape(inputs) + mask_shape = None if mask is None else ops.shape(mask) + batch_size = input_shape[0] + timesteps = input_shape[1] + + # For TF backend with graph mode and `partial_batch_size`, skip + # evaluation of `batch_size` as it can be a `strided_slice` and + # not a constant. + if backend.backend() == "tensorflow": + from keras.src.utils.module_utils import tensorflow as tf + + if ( + not tf.executing_eagerly + and mask_shape is not None + and mask_shape[1:2] != (timesteps,) + ): + raise ValueError( + "`TimeDistributed` Layer should be passed a `mask` of " + f"shape ({batch_size}, {timesteps}, ...), " + f"received: mask.shape={mask_shape}" + ) + elif mask_shape is not None and mask_shape[:2] != ( + batch_size, + timesteps, + ): + raise ValueError( + "`TimeDistributed` Layer should be passed a `mask` of " + f"shape ({batch_size}, {timesteps}, ...), " + f"received: mask.shape={mask_shape}" + ) + + def time_distributed_transpose(data): + """Swaps the timestep and batch dimensions of a tensor.""" + axes = [1, 0, *range(2, len(data.shape))] + return ops.transpose(data, axes=axes) + + inputs = time_distributed_transpose(inputs) + if mask is not None: + mask = time_distributed_transpose(mask) + + def step_function(i): + kwargs = {} + if self.layer._call_has_mask_arg and mask is not None: + kwargs["mask"] = mask[i] + if self.layer._call_has_training_arg: + kwargs["training"] = training + return self.layer.call(inputs[i], **kwargs) + + # Implementation #1: is the time axis is static, use a Python for loop. + + if inputs.shape[0] is not None: + outputs = ops.stack( + [step_function(i) for i in range(inputs.shape[0])] + ) + return time_distributed_transpose(outputs) + + # Implementation #2: use backend.vectorized_map. + + outputs = backend.vectorized_map(step_function, ops.arange(timesteps)) + return time_distributed_transpose(outputs) diff --git a/keras/src/layers/rnn/time_distributed_test.py b/keras/src/layers/rnn/time_distributed_test.py new file mode 100644 index 000000000000..87cc31fe6197 --- /dev/null +++ b/keras/src/layers/rnn/time_distributed_test.py @@ -0,0 +1,101 @@ +import numpy as np +import pytest + +from keras.src import backend +from keras.src import initializers +from keras.src import layers +from keras.src import ops +from keras.src import testing +from keras.src.models import Sequential + + +class TimeDistributedTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_basics(self): + self.run_layer_test( + layers.TimeDistributed, + init_kwargs={"layer": layers.Dense(1, use_bias=False)}, + input_shape=(3, 2, 4), + expected_output_shape=(3, 2, 1), + expected_num_trainable_weights=1, + expected_num_non_trainable_weights=0, + supports_masking=True, + ) + + def test_build(self): + if backend.config.image_data_format() == "channels_last": + input_shape = (10, 128, 128, 3) + output_shape = (32, 10, 126, 126, 64) + else: + input_shape = (10, 3, 128, 128) + output_shape = (32, 10, 64, 126, 126) + inputs = layers.Input(shape=input_shape, batch_size=32) + conv_2d_layer = layers.Conv2D(64, (3, 3)) + outputs = layers.TimeDistributed(conv_2d_layer)(inputs) + self.assertEqual(outputs.shape, output_shape) + + def test_correctness(self): + sequence = np.arange(24).reshape((3, 2, 4)).astype("float32") + layer = layers.Dense( + 1, + kernel_initializer=initializers.Constant(0.01), + use_bias=False, + ) + layer = layers.TimeDistributed(layer=layer) + output = layer(sequence) + self.assertAllClose( + np.array( + [[[0.06], [0.22]], [[0.38], [0.53999996]], [[0.7], [0.86]]] + ), + output, + ) + + def test_masking(self): + class MaskedDense(layers.Wrapper): + def __init__(self, units, **kwargs): + layer = layers.Dense( + units, + kernel_initializer=initializers.Constant(0.01), + use_bias=False, + ) + super().__init__(layer, **kwargs) + self.supports_masking = True + + def call(self, inputs, training=False, mask=None): + unmasked = self.layer.call(inputs) + if mask is None: + return unmasked + else: + return ops.transpose( + ops.transpose(unmasked) * ops.cast(mask, inputs.dtype) + ) + + sequence = np.arange(24).reshape((3, 2, 4)).astype("float32") + layer = layers.TimeDistributed(layer=MaskedDense(1)) + mask = np.array([[False, True], [True, False], [True, True]]) + output = layer(sequence, mask=mask) + self.assertAllClose( + np.array([[[0], [0.22]], [[0.38], [0]], [[0.7], [0.86]]]), + output, + ) + + @pytest.mark.requires_trainable_backend + def test_with_mask_zero(self): + model = Sequential( + [ + layers.Input(shape=(20,)), + layers.Embedding(input_dim=10, output_dim=5, mask_zero=True), + layers.TimeDistributed( + layers.Dense(units=5, activation="softmax") + ), + ] + ) + model.compile( + optimizer="adam", + loss="sparse_categorical_crossentropy", + metrics=["accuracy"], + ) + X_train = np.random.uniform(1, 10, size=(22, 20)) + Y_train = np.random.randint(1, 2, size=(22, 20)) + + model.fit(X_train, Y_train, epochs=1, batch_size=16) diff --git a/keras/src/legacy/__init__.py b/keras/src/legacy/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/keras/src/legacy/backend.py b/keras/src/legacy/backend.py new file mode 100644 index 000000000000..9c361c7f33e5 --- /dev/null +++ b/keras/src/legacy/backend.py @@ -0,0 +1,2277 @@ +"""Legacy Keras 1/2 backend functions.""" + +import itertools + +import numpy as np + +from keras.src import backend +from keras.src.api_export import keras_export +from keras.src.utils.module_utils import tensorflow as tf + +py_any = any +py_all = all + + +@keras_export("keras._legacy.backend.abs") +def abs(x): + """DEPRECATED.""" + return tf.abs(x) + + +@keras_export("keras._legacy.backend.all") +def all(x, axis=None, keepdims=False): + """DEPRECATED.""" + x = tf.cast(x, tf.bool) + return tf.reduce_all(x, axis, keepdims) + + +@keras_export("keras._legacy.backend.any") +def any(x, axis=None, keepdims=False): + """DEPRECATED.""" + x = tf.cast(x, tf.bool) + return tf.reduce_any(x, axis, keepdims) + + +@keras_export("keras._legacy.backend.argmax") +def argmax(x, axis=-1): + """DEPRECATED.""" + return tf.argmax(x, axis) + + +@keras_export("keras._legacy.backend.argmin") +def argmin(x, axis=-1): + """DEPRECATED.""" + return tf.argmin(x, axis) + + +@keras_export("keras._legacy.backend.arange") +def arange(start, stop=None, step=1, dtype="int32"): + """DEPRECATED.""" + if stop is None and start < 0: + start = 0 + result = tf.range(start, limit=stop, delta=step, name="arange") + if dtype != "int32": + result = tf.cast(result, dtype) + return result + + +@keras_export("keras._legacy.backend.batch_dot") +def batch_dot(x, y, axes=None): + """DEPRECATED.""" + x_shape = x.shape + y_shape = y.shape + + x_ndim = len(x_shape) + y_ndim = len(y_shape) + + if x_ndim < 2 or y_ndim < 2: + raise ValueError( + "Cannot do batch_dot on inputs " + "with rank < 2. " + f"Received inputs with tf.shapes {x_shape} and {y_shape}." + ) + + x_batch_size = x_shape[0] + y_batch_size = y_shape[0] + + if x_batch_size is not None and y_batch_size is not None: + if x_batch_size != y_batch_size: + raise ValueError( + "Cannot do batch_dot on inputs " + "with different batch sizes. " + "Received inputs with tf.shapes " + f"{x_shape} and {y_shape}." + ) + if isinstance(axes, int): + axes = [axes, axes] + + if axes is None: + if y_ndim == 2: + axes = [x_ndim - 1, y_ndim - 1] + else: + axes = [x_ndim - 1, y_ndim - 2] + + if py_any(isinstance(a, (list, tuple)) for a in axes): + raise ValueError( + "Multiple target dimensions are not supported. " + "Expected: None, int, (int, int), " + f"Provided: {axes}" + ) + + # if tuple, convert to list. + axes = list(axes) + + # convert negative indices. + if axes[0] < 0: + axes[0] += x_ndim + if axes[1] < 0: + axes[1] += y_ndim + + # sanity checks + if 0 in axes: + raise ValueError( + "Cannot perform batch_dot over axis 0. " + "If your inputs are not batched, " + "add a dummy batch dimension to your " + "inputs using K.expand_dims(x, 0)" + ) + a0, a1 = axes + d1 = x_shape[a0] + d2 = y_shape[a1] + + if d1 is not None and d2 is not None and d1 != d2: + raise ValueError( + "Cannot do batch_dot on inputs with tf.shapes " + f"{x_shape} and {y_shape} with axes={axes}. " + "x.shape[%d] != y.shape[%d] (%d != %d)." + % (axes[0], axes[1], d1, d2) + ) + + # backup ndims. Need them later. + orig_x_ndim = x_ndim + orig_y_ndim = y_ndim + + # if rank is 2, expand to 3. + if x_ndim == 2: + x = tf.expand_dims(x, 1) + a0 += 1 + x_ndim += 1 + if y_ndim == 2: + y = tf.expand_dims(y, 2) + y_ndim += 1 + + # bring x's dimension to be reduced to last axis. + if a0 != x_ndim - 1: + pattern = list(range(x_ndim)) + for i in range(a0, x_ndim - 1): + pattern[i] = pattern[i + 1] + pattern[-1] = a0 + x = tf.transpose(x, pattern) + + # bring y's dimension to be reduced to axis 1. + if a1 != 1: + pattern = list(range(y_ndim)) + for i in range(a1, 1, -1): + pattern[i] = pattern[i - 1] + pattern[1] = a1 + y = tf.transpose(y, pattern) + + # normalize both inputs to rank 3. + if x_ndim > 3: + # squash middle dimensions of x. + x_shape = tf.shape(x) + x_mid_dims = x_shape[1:-1] + x_squashed_shape = tf.stack([x_shape[0], -1, x_shape[-1]]) + x = tf.reshape(x, x_squashed_shape) + x_squashed = True + else: + x_squashed = False + + if y_ndim > 3: + # squash trailing dimensions of y. + y_shape = tf.shape(y) + y_trail_dims = y_shape[2:] + y_squashed_shape = tf.stack([y_shape[0], y_shape[1], -1]) + y = tf.reshape(y, y_squashed_shape) + y_squashed = True + else: + y_squashed = False + + result = tf.matmul(x, y) + + # if inputs were squashed, we have to reshape the matmul output. + output_shape = tf.shape(result) + do_reshape = False + + if x_squashed: + output_shape = tf.concat( + [output_shape[:1], x_mid_dims, output_shape[-1:]], 0 + ) + do_reshape = True + + if y_squashed: + output_shape = tf.concat([output_shape[:-1], y_trail_dims], 0) + do_reshape = True + + if do_reshape: + result = tf.reshape(result, output_shape) + + # if the inputs were originally rank 2, we remove the added 1 dim. + if orig_x_ndim == 2: + result = tf.squeeze(result, 1) + elif orig_y_ndim == 2: + result = tf.squeeze(result, -1) + + return result + + +@keras_export("keras._legacy.backend.batch_flatten") +def batch_flatten(x): + """DEPRECATED.""" + x = tf.reshape(x, tf.stack([-1, prod(tf.shape(x)[1:])])) + return x + + +@keras_export("keras._legacy.backend.batch_get_value") +def batch_get_value(tensors): + """DEPRECATED.""" + return [x.numpy() for x in tensors] + + +@keras_export("keras._legacy.backend.batch_set_value") +def batch_set_value(tuples): + """DEPRECATED.""" + if tf.executing_eagerly() or tf.inside_function(): + for x, value in tuples: + value = np.asarray(value, dtype=x.dtype.name) + x.assign(value) + + +@keras_export("keras._legacy.backend.batch_normalization") +def batch_normalization(x, mean, var, beta, gamma, axis=-1, epsilon=1e-3): + """DEPRECATED.""" + return tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon) + + +@keras_export("keras._legacy.backend.bias_add") +def bias_add(x, bias, data_format=None): + """DEPRECATED.""" + if data_format is None: + data_format = backend.image_data_format() + if data_format not in {"channels_first", "channels_last"}: + raise ValueError(f"Unknown data_format: {data_format}") + bias_shape = bias.shape + if len(bias_shape) != 1 and len(bias_shape) != ndim(x) - 1: + raise ValueError( + f"Unexpected bias dimensions {len(bias_shape)}. " + f"Expected it to be 1 or {ndim(x) - 1} dimensions" + ) + + if len(bias_shape) == 1: + if data_format == "channels_first": + return tf.nn.bias_add(x, bias, data_format="NCHW") + return tf.nn.bias_add(x, bias, data_format="NHWC") + if ndim(x) in (3, 4, 5): + if data_format == "channels_first": + bias_reshape_axis = (1, bias_shape[-1]) + bias_shape[:-1] + return x + reshape(bias, bias_reshape_axis) + return x + reshape(bias, (1,) + bias_shape) + return tf.nn.bias_add(x, bias) + + +@keras_export("keras._legacy.backend.binary_crossentropy") +def binary_crossentropy(target, output, from_logits=False): + """DEPRECATED.""" + target = tf.convert_to_tensor(target) + output = tf.convert_to_tensor(output) + + if from_logits: + return tf.nn.sigmoid_cross_entropy_with_logits( + labels=target, logits=output + ) + + epsilon_ = tf.convert_to_tensor(backend.epsilon(), output.dtype) + output = tf.clip_by_value(output, epsilon_, 1.0 - epsilon_) + + # Compute cross entropy from probabilities. + bce = target * tf.math.log(output + backend.epsilon()) + bce += (1 - target) * tf.math.log(1 - output + backend.epsilon()) + return -bce + + +@keras_export("keras._legacy.backend.binary_focal_crossentropy") +def binary_focal_crossentropy( + target, + output, + apply_class_balancing=False, + alpha=0.25, + gamma=2.0, + from_logits=False, +): + """DEPRECATED.""" + sigmoidal = tf.sigmoid(output) if from_logits else output + + p_t = target * sigmoidal + (1 - target) * (1 - sigmoidal) + + # Calculate focal factor + focal_factor = tf.pow(1.0 - p_t, gamma) + + # Binary crossentropy + bce = binary_crossentropy( + target=target, + output=output, + from_logits=from_logits, + ) + focal_bce = focal_factor * bce + + if apply_class_balancing: + weight = target * alpha + (1 - target) * (1 - alpha) + focal_bce = weight * focal_bce + + return focal_bce + + +@keras_export("keras._legacy.backend.cast") +def cast(x, dtype): + """DEPRECATED.""" + return tf.cast(x, dtype) + + +@keras_export("keras._legacy.backend.cast_to_floatx") +def cast_to_floatx(x): + """DEPRECATED.""" + if isinstance(x, (tf.Tensor, tf.Variable, tf.SparseTensor)): + return tf.cast(x, dtype=backend.floatx()) + return np.asarray(x, dtype=backend.floatx()) + + +@keras_export("keras._legacy.backend.categorical_crossentropy") +def categorical_crossentropy(target, output, from_logits=False, axis=-1): + """DEPRECATED.""" + target = tf.convert_to_tensor(target) + output = tf.convert_to_tensor(output) + target.shape.assert_is_compatible_with(output.shape) + + if from_logits: + return tf.nn.softmax_cross_entropy_with_logits( + labels=target, logits=output, axis=axis + ) + + # Adjust the predictions so that the probability of + # each class for every sample adds up to 1 + # This is needed to ensure that the cross entropy is + # computed correctly. + output = output / tf.reduce_sum(output, axis, True) + + # Compute cross entropy from probabilities. + epsilon_ = tf.convert_to_tensor(backend.epsilon(), output.dtype) + output = tf.clip_by_value(output, epsilon_, 1.0 - epsilon_) + return -tf.reduce_sum(target * tf.math.log(output), axis) + + +@keras_export("keras._legacy.backend.categorical_focal_crossentropy") +def categorical_focal_crossentropy( + target, + output, + alpha=0.25, + gamma=2.0, + from_logits=False, + axis=-1, +): + """DEPRECATED.""" + target = tf.convert_to_tensor(target) + output = tf.convert_to_tensor(output) + target.shape.assert_is_compatible_with(output.shape) + + if from_logits: + output = tf.nn.softmax(output, axis=axis) + + # Adjust the predictions so that the probability of + # each class for every sample adds up to 1 + # This is needed to ensure that the cross entropy is + # computed correctly. + output = output / tf.reduce_sum(output, axis=axis, keepdims=True) + + epsilon_ = tf.convert_to_tensor(backend.epsilon(), output.dtype) + output = tf.clip_by_value(output, epsilon_, 1.0 - epsilon_) + + # Calculate cross entropy + cce = -target * tf.math.log(output) + + # Calculate factors + modulating_factor = tf.pow(1.0 - output, gamma) + weighting_factor = tf.multiply(modulating_factor, alpha) + + # Apply weighting factor + focal_cce = tf.multiply(weighting_factor, cce) + focal_cce = tf.reduce_sum(focal_cce, axis=axis) + return focal_cce + + +@keras_export("keras._legacy.backend.clip") +def clip(x, min_value, max_value): + """DEPRECATED.""" + if isinstance(min_value, (int, float)) and isinstance( + max_value, (int, float) + ): + if max_value < min_value: + max_value = min_value + if min_value is None: + min_value = -np.inf + if max_value is None: + max_value = np.inf + return tf.clip_by_value(x, min_value, max_value) + + +@keras_export("keras._legacy.backend.concatenate") +def concatenate(tensors, axis=-1): + """DEPRECATED.""" + if axis < 0: + rank = ndim(tensors[0]) + if rank: + axis %= rank + else: + axis = 0 + + if py_all(is_sparse(x) for x in tensors): + return tf.compat.v1.sparse_concat(axis, tensors) + elif py_all(isinstance(x, tf.RaggedTensor) for x in tensors): + return tf.concat(tensors, axis) + else: + return tf.concat([to_dense(x) for x in tensors], axis) + + +@keras_export("keras._legacy.backend.constant") +def constant(value, dtype=None, shape=None, name=None): + """DEPRECATED.""" + if dtype is None: + dtype = backend.floatx() + + return tf.constant(value, dtype=dtype, shape=shape, name=name) + + +def _preprocess_conv1d_input(x, data_format): + tf_data_format = "NWC" # to pass TF Conv2dNative operations + if data_format == "channels_first": + tf_data_format = "NCW" + return x, tf_data_format + + +def _preprocess_conv2d_input(x, data_format, force_transpose=False): + tf_data_format = "NHWC" + if data_format == "channels_first": + if force_transpose: + x = tf.transpose(x, (0, 2, 3, 1)) # NCHW -> NHWC + else: + tf_data_format = "NCHW" + return x, tf_data_format + + +def _preprocess_conv3d_input(x, data_format): + tf_data_format = "NDHWC" + if data_format == "channels_first": + tf_data_format = "NCDHW" + return x, tf_data_format + + +def _preprocess_padding(padding): + if padding == "same": + padding = "SAME" + elif padding == "valid": + padding = "VALID" + else: + raise ValueError(f"Invalid padding: {padding}") + return padding + + +@keras_export("keras._legacy.backend.conv1d") +def conv1d( + x, kernel, strides=1, padding="valid", data_format=None, dilation_rate=1 +): + """DEPRECATED.""" + if data_format is None: + data_format = backend.image_data_format() + if data_format not in {"channels_first", "channels_last"}: + raise ValueError(f"Unknown data_format: {data_format}") + + kernel_shape = kernel.shape.as_list() + if padding == "causal": + # causal (dilated) convolution: + left_pad = dilation_rate * (kernel_shape[0] - 1) + x = temporal_padding(x, (left_pad, 0)) + padding = "valid" + padding = _preprocess_padding(padding) + + x, tf_data_format = _preprocess_conv1d_input(x, data_format) + x = tf.compat.v1.nn.convolution( + input=x, + filter=kernel, + dilation_rate=dilation_rate, + strides=strides, + padding=padding, + data_format=tf_data_format, + ) + if data_format == "channels_first" and tf_data_format == "NWC": + x = tf.transpose(x, (0, 2, 1)) # NWC -> NCW + return x + + +@keras_export("keras._legacy.backend.conv2d") +def conv2d( + x, + kernel, + strides=(1, 1), + padding="valid", + data_format=None, + dilation_rate=(1, 1), +): + """DEPRECATED.""" + if data_format is None: + data_format = backend.image_data_format() + if data_format not in {"channels_first", "channels_last"}: + raise ValueError(f"Unknown data_format: {data_format}") + + x, tf_data_format = _preprocess_conv2d_input(x, data_format) + padding = _preprocess_padding(padding) + x = tf.compat.v1.nn.convolution( + input=x, + filter=kernel, + dilation_rate=dilation_rate, + strides=strides, + padding=padding, + data_format=tf_data_format, + ) + if data_format == "channels_first" and tf_data_format == "NHWC": + x = tf.transpose(x, (0, 3, 1, 2)) # NHWC -> NCHW + return x + + +@keras_export("keras._legacy.backend.conv2d_transpose") +def conv2d_transpose( + x, + kernel, + output_shape, + strides=(1, 1), + padding="valid", + data_format=None, + dilation_rate=(1, 1), +): + """DEPRECATED.""" + if data_format is None: + data_format = backend.image_data_format() + if data_format not in {"channels_first", "channels_last"}: + raise ValueError(f"Unknown data_format: {data_format}") + + # `atrous_conv2d_transpose` only supports NHWC format, even on GPU. + if data_format == "channels_first" and dilation_rate != (1, 1): + force_transpose = True + else: + force_transpose = False + + x, tf_data_format = _preprocess_conv2d_input( + x, data_format, force_transpose + ) + + if data_format == "channels_first" and tf_data_format == "NHWC": + output_shape = ( + output_shape[0], + output_shape[2], + output_shape[3], + output_shape[1], + ) + if output_shape[0] is None: + output_shape = (tf.shape(x)[0],) + tuple(output_shape[1:]) + + if isinstance(output_shape, (tuple, list)): + output_shape = tf.stack(list(output_shape)) + + padding = _preprocess_padding(padding) + if tf_data_format == "NHWC": + strides = (1,) + strides + (1,) + else: + strides = (1, 1) + strides + + if dilation_rate == (1, 1): + x = tf.compat.v1.nn.conv2d_transpose( + x, + kernel, + output_shape, + strides, + padding=padding, + data_format=tf_data_format, + ) + else: + if dilation_rate[0] != dilation_rate[1]: + raise ValueError( + "Expected the 2 dimensions of the `dilation_rate` argument " + "to be equal to each other. " + f"Received: dilation_rate={dilation_rate}" + ) + x = tf.nn.atrous_conv2d_transpose( + x, kernel, output_shape, rate=dilation_rate[0], padding=padding + ) + if data_format == "channels_first" and tf_data_format == "NHWC": + x = tf.transpose(x, (0, 3, 1, 2)) # NHWC -> NCHW + return x + + +@keras_export("keras._legacy.backend.conv3d") +def conv3d( + x, + kernel, + strides=(1, 1, 1), + padding="valid", + data_format=None, + dilation_rate=(1, 1, 1), +): + """DEPRECATED.""" + if data_format is None: + data_format = backend.image_data_format() + if data_format not in {"channels_first", "channels_last"}: + raise ValueError(f"Unknown data_format: {data_format}") + + x, tf_data_format = _preprocess_conv3d_input(x, data_format) + padding = _preprocess_padding(padding) + x = tf.compat.v1.nn.convolution( + input=x, + filter=kernel, + dilation_rate=dilation_rate, + strides=strides, + padding=padding, + data_format=tf_data_format, + ) + if data_format == "channels_first" and tf_data_format == "NDHWC": + x = tf.transpose(x, (0, 4, 1, 2, 3)) + return x + + +@keras_export("keras._legacy.backend.cos") +def cos(x): + """DEPRECATED.""" + return tf.cos(x) + + +@keras_export("keras._legacy.backend.count_params") +def count_params(x): + """DEPRECATED.""" + return np.prod(x.shape.as_list()) + + +@keras_export("keras._legacy.backend.ctc_batch_cost") +def ctc_batch_cost(y_true, y_pred, input_length, label_length): + """DEPRECATED.""" + label_length = tf.cast(tf.squeeze(label_length, axis=-1), tf.int32) + input_length = tf.cast(tf.squeeze(input_length, axis=-1), tf.int32) + sparse_labels = tf.cast( + ctc_label_dense_to_sparse(y_true, label_length), tf.int32 + ) + + y_pred = tf.math.log( + tf.transpose(y_pred, perm=[1, 0, 2]) + backend.epsilon() + ) + + return tf.expand_dims( + tf.compat.v1.nn.ctc_loss( + inputs=y_pred, labels=sparse_labels, sequence_length=input_length + ), + 1, + ) + + +@keras_export("keras._legacy.backend.ctc_label_dense_to_sparse") +def ctc_label_dense_to_sparse(labels, label_lengths): + """DEPRECATED.""" + label_shape = tf.shape(labels) + num_batches_tns = tf.stack([label_shape[0]]) + max_num_labels_tns = tf.stack([label_shape[1]]) + + def range_less_than(old_input, current_input): + return tf.expand_dims(tf.range(tf.shape(old_input)[1]), 0) < tf.fill( + max_num_labels_tns, current_input + ) + + init = tf.cast(tf.fill([1, label_shape[1]], 0), tf.bool) + dense_mask = tf.compat.v1.scan( + range_less_than, label_lengths, initializer=init, parallel_iterations=1 + ) + dense_mask = dense_mask[:, 0, :] + + label_array = tf.reshape( + tf.tile(tf.range(0, label_shape[1]), num_batches_tns), label_shape + ) + label_ind = tf.compat.v1.boolean_mask(label_array, dense_mask) + + batch_array = tf.transpose( + tf.reshape( + tf.tile(tf.range(0, label_shape[0]), max_num_labels_tns), + reverse(label_shape, 0), + ) + ) + batch_ind = tf.compat.v1.boolean_mask(batch_array, dense_mask) + indices = tf.transpose( + tf.reshape(concatenate([batch_ind, label_ind], axis=0), [2, -1]) + ) + + vals_sparse = tf.compat.v1.gather_nd(labels, indices) + + return tf.SparseTensor( + tf.cast(indices, tf.int64), vals_sparse, tf.cast(label_shape, tf.int64) + ) + + +@keras_export("keras._legacy.backend.ctc_decode") +def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1): + """DEPRECATED.""" + input_shape = tf.shape(y_pred) + num_samples, num_steps = input_shape[0], input_shape[1] + y_pred = tf.math.log( + tf.transpose(y_pred, perm=[1, 0, 2]) + backend.epsilon() + ) + input_length = tf.cast(input_length, tf.int32) + + if greedy: + (decoded, log_prob) = tf.nn.ctc_greedy_decoder( + inputs=y_pred, sequence_length=input_length + ) + else: + (decoded, log_prob) = tf.compat.v1.nn.ctc_beam_search_decoder( + inputs=y_pred, + sequence_length=input_length, + beam_width=beam_width, + top_paths=top_paths, + ) + decoded_dense = [] + for st in decoded: + st = tf.SparseTensor(st.indices, st.values, (num_samples, num_steps)) + decoded_dense.append(tf.sparse.to_dense(sp_input=st, default_value=-1)) + return (decoded_dense, log_prob) + + +@keras_export("keras._legacy.backend.cumsum") +def cumsum(x, axis=0): + """DEPRECATED.""" + return tf.cumsum(x, axis=axis) + + +@keras_export("keras._legacy.backend.cumprod") +def cumprod(x, axis=0): + """DEPRECATED.""" + return tf.math.cumprod(x, axis=axis) + + +@keras_export("keras._legacy.backend.depthwise_conv2d") +def depthwise_conv2d( + x, + depthwise_kernel, + strides=(1, 1), + padding="valid", + data_format=None, + dilation_rate=(1, 1), +): + """DEPRECATED.""" + if data_format is None: + data_format = backend.image_data_format() + if data_format not in {"channels_first", "channels_last"}: + raise ValueError(f"Unknown data_format: {data_format}") + + x, tf_data_format = _preprocess_conv2d_input(x, data_format) + padding = _preprocess_padding(padding) + if tf_data_format == "NHWC": + strides = (1,) + strides + (1,) + else: + strides = (1, 1) + strides + + x = tf.nn.depthwise_conv2d( + x, + depthwise_kernel, + strides=strides, + padding=padding, + dilations=dilation_rate, + data_format=tf_data_format, + ) + if data_format == "channels_first" and tf_data_format == "NHWC": + x = tf.transpose(x, (0, 3, 1, 2)) # NHWC -> NCHW + return x + + +@keras_export("keras._legacy.backend.dot") +def dot(x, y): + """DEPRECATED.""" + if ndim(x) is not None and (ndim(x) > 2 or ndim(y) > 2): + x_shape = [] + for i, s in zip(x.shape, tf.unstack(tf.shape(x))): + if i is not None: + x_shape.append(i) + else: + x_shape.append(s) + x_shape = tuple(x_shape) + y_shape = [] + for i, s in zip(y.shape, tf.unstack(tf.shape(y))): + if i is not None: + y_shape.append(i) + else: + y_shape.append(s) + y_shape = tuple(y_shape) + y_permute_dim = list(range(ndim(y))) + y_permute_dim = [y_permute_dim.pop(-2)] + y_permute_dim + xt = tf.reshape(x, [-1, x_shape[-1]]) + yt = tf.reshape(tf.transpose(y, perm=y_permute_dim), [y_shape[-2], -1]) + return tf.reshape( + tf.matmul(xt, yt), x_shape[:-1] + y_shape[:-2] + y_shape[-1:] + ) + if is_sparse(x): + out = tf.sparse.sparse_dense_matmul(x, y) + else: + out = tf.matmul(x, y) + return out + + +@keras_export("keras._legacy.backend.dropout") +def dropout(x, level, noise_shape=None, seed=None): + """DEPRECATED.""" + if seed is None: + seed = np.random.randint(10e6) + return tf.nn.dropout(x, rate=level, noise_shape=noise_shape, seed=seed) + + +@keras_export("keras._legacy.backend.dtype") +def dtype(x): + """DEPRECATED.""" + return x.dtype.base_dtype.name + + +@keras_export("keras._legacy.backend.elu") +def elu(x, alpha=1.0): + """DEPRECATED.""" + res = tf.nn.elu(x) + if alpha == 1: + return res + else: + return tf.where(x > 0, res, alpha * res) + + +@keras_export("keras._legacy.backend.equal") +def equal(x, y): + """DEPRECATED.""" + return tf.equal(x, y) + + +@keras_export("keras._legacy.backend.eval") +def eval(x): + """DEPRECATED.""" + return get_value(to_dense(x)) + + +@keras_export("keras._legacy.backend.exp") +def exp(x): + """DEPRECATED.""" + return tf.exp(x) + + +@keras_export("keras._legacy.backend.expand_dims") +def expand_dims(x, axis=-1): + """DEPRECATED.""" + return tf.expand_dims(x, axis) + + +@keras_export("keras._legacy.backend.eye") +def eye(size, dtype=None, name=None): + """DEPRECATED.""" + if dtype is None: + dtype = backend.floatx() + tf_dtype = tf.as_dtype(dtype) + return variable(tf.eye(size, dtype=tf_dtype), dtype, name) + + +@keras_export("keras._legacy.backend.flatten") +def flatten(x): + """DEPRECATED.""" + return tf.reshape(x, [-1]) + + +@keras_export("keras._legacy.backend.foldl") +def foldl(fn, elems, initializer=None, name=None): + """DEPRECATED.""" + return tf.compat.v1.foldl(fn, elems, initializer=initializer, name=name) + + +@keras_export("keras._legacy.backend.foldr") +def foldr(fn, elems, initializer=None, name=None): + """DEPRECATED.""" + return tf.compat.v1.foldr(fn, elems, initializer=initializer, name=name) + + +@keras_export("keras._legacy.backend.gather") +def gather(reference, indices): + """DEPRECATED.""" + return tf.compat.v1.gather(reference, indices) + + +@keras_export("keras._legacy.backend.get_value") +def get_value(x): + """DEPRECATED.""" + if not tf.is_tensor(x): + return x + if tf.executing_eagerly() or isinstance(x, tf.__internal__.EagerTensor): + return x.numpy() + if not getattr(x, "_in_graph_mode", True): + # This is a variable which was created in an eager context, but is being + # evaluated from a Graph. + with tf.__internal__.eager_context.eager_mode(): + return x.numpy() + with tf.init_scope(): + return x.numpy() + + +@keras_export("keras._legacy.backend.gradients") +def gradients(loss, variables): + """DEPRECATED.""" + return tf.compat.v1.gradients( + loss, variables, colocate_gradients_with_ops=True + ) + + +@keras_export("keras._legacy.backend.greater") +def greater(x, y): + """DEPRECATED.""" + return tf.greater(x, y) + + +@keras_export("keras._legacy.backend.greater_equal") +def greater_equal(x, y): + """DEPRECATED.""" + return tf.greater_equal(x, y) + + +@keras_export("keras._legacy.backend.hard_sigmoid") +def hard_sigmoid(x): + """DEPRECATED.""" + point_two = tf.convert_to_tensor(0.2, dtype=x.dtype) + point_five = tf.convert_to_tensor(0.5, dtype=x.dtype) + x = tf.multiply(x, point_two) + x = tf.add(x, point_five) + x = tf.clip_by_value(x, 0.0, 1.0) + return x + + +@keras_export("keras._legacy.backend.in_top_k") +def in_top_k(predictions, targets, k): + """DEPRECATED.""" + return tf.compat.v1.math.in_top_k(predictions, targets, k) + + +@keras_export("keras._legacy.backend.int_shape") +def int_shape(x): + """DEPRECATED.""" + try: + shape = x.shape + if not isinstance(shape, tuple): + shape = tuple(shape.as_list()) + return shape + except ValueError: + return None + + +@keras_export("keras._legacy.backend.is_sparse") +def is_sparse(tensor): + """DEPRECATED.""" + spec = getattr(tensor, "_type_spec", None) + if spec is not None: + return isinstance(spec, tf.SparseTensorSpec) + return isinstance(tensor, tf.SparseTensor) + + +@keras_export("keras._legacy.backend.l2_normalize") +def l2_normalize(x, axis=None): + """DEPRECATED.""" + return tf.linalg.l2_normalize(x, axis=axis) + + +@keras_export("keras._legacy.backend.less") +def less(x, y): + """DEPRECATED.""" + return tf.less(x, y) + + +@keras_export("keras._legacy.backend.less_equal") +def less_equal(x, y): + """DEPRECATED.""" + return tf.less_equal(x, y) + + +@keras_export("keras._legacy.backend.log") +def log(x): + """DEPRECATED.""" + return tf.math.log(x) + + +@keras_export("keras._legacy.backend.map_fn") +def map_fn(fn, elems, name=None, dtype=None): + """DEPRECATED.""" + return tf.compat.v1.map_fn(fn, elems, name=name, dtype=dtype) + + +@keras_export("keras._legacy.backend.max") +def max(x, axis=None, keepdims=False): + """DEPRECATED.""" + return tf.reduce_max(x, axis, keepdims) + + +@keras_export("keras._legacy.backend.maximum") +def maximum(x, y): + """DEPRECATED.""" + return tf.maximum(x, y) + + +@keras_export("keras._legacy.backend.mean") +def mean(x, axis=None, keepdims=False): + """DEPRECATED.""" + if x.dtype.base_dtype == tf.bool: + x = tf.cast(x, backend.floatx()) + return tf.reduce_mean(x, axis, keepdims) + + +@keras_export("keras._legacy.backend.min") +def min(x, axis=None, keepdims=False): + """DEPRECATED.""" + return tf.reduce_min(x, axis, keepdims) + + +@keras_export("keras._legacy.backend.minimum") +def minimum(x, y): + """DEPRECATED.""" + return tf.minimum(x, y) + + +@keras_export("keras._legacy.backend.moving_average_update") +def moving_average_update(x, value, momentum): + """DEPRECATED.""" + momentum = tf.cast(momentum, x.dtype) + value = tf.cast(value, x.dtype) + return x.assign_sub((x - value) * (1 - momentum)) + + +@keras_export("keras._legacy.backend.name_scope") +def name_scope(name): + """DEPRECATED.""" + return tf.name_scope(name) + + +@keras_export("keras._legacy.backend.ndim") +def ndim(x): + """DEPRECATED.""" + return x.shape.rank + + +@keras_export("keras._legacy.backend.not_equal") +def not_equal(x, y): + """DEPRECATED.""" + return tf.not_equal(x, y) + + +@keras_export("keras._legacy.backend.one_hot") +def one_hot(indices, num_classes): + """DEPRECATED.""" + return tf.one_hot(indices, depth=num_classes, axis=-1) + + +@keras_export("keras._legacy.backend.ones") +def ones(shape, dtype=None, name=None): + """DEPRECATED.""" + with tf.init_scope(): + if dtype is None: + dtype = backend.floatx() + tf_dtype = tf.as_dtype(dtype) + v = tf.ones(shape=shape, dtype=tf_dtype, name=name) + if py_all(v.shape.as_list()): + return variable(v, dtype=dtype, name=name) + return v + + +@keras_export("keras._legacy.backend.ones_like") +def ones_like(x, dtype=None, name=None): + """DEPRECATED.""" + return tf.ones_like(x, dtype=dtype, name=name) + + +@keras_export("keras._legacy.backend.permute_dimensions") +def permute_dimensions(x, pattern): + """DEPRECATED.""" + return tf.transpose(x, perm=pattern) + + +@keras_export("keras._legacy.backend.pool2d") +def pool2d( + x, + pool_size, + strides=(1, 1), + padding="valid", + data_format=None, + pool_mode="max", +): + """DEPRECATED.""" + if data_format is None: + data_format = backend.image_data_format() + if data_format not in {"channels_first", "channels_last"}: + raise ValueError(f"Unknown data_format: {data_format}") + if len(pool_size) != 2: + raise ValueError("`pool_size` must be a tuple of 2 integers.") + if len(strides) != 2: + raise ValueError("`strides` must be a tuple of 2 integers.") + + x, tf_data_format = _preprocess_conv2d_input(x, data_format) + padding = _preprocess_padding(padding) + if tf_data_format == "NHWC": + strides = (1,) + strides + (1,) + pool_size = (1,) + pool_size + (1,) + else: + strides = (1, 1) + strides + pool_size = (1, 1) + pool_size + + if pool_mode == "max": + x = tf.compat.v1.nn.max_pool( + x, pool_size, strides, padding=padding, data_format=tf_data_format + ) + elif pool_mode == "avg": + x = tf.compat.v1.nn.avg_pool( + x, pool_size, strides, padding=padding, data_format=tf_data_format + ) + else: + raise ValueError(f"Invalid pooling mode: {str(pool_mode)}") + + if data_format == "channels_first" and tf_data_format == "NHWC": + x = tf.transpose(x, (0, 3, 1, 2)) # NHWC -> NCHW + return x + + +@keras_export("keras._legacy.backend.pool3d") +def pool3d( + x, + pool_size, + strides=(1, 1, 1), + padding="valid", + data_format=None, + pool_mode="max", +): + """DEPRECATED.""" + if data_format is None: + data_format = backend.image_data_format() + if data_format not in {"channels_first", "channels_last"}: + raise ValueError(f"Unknown data_format: {data_format}") + + x, tf_data_format = _preprocess_conv3d_input(x, data_format) + padding = _preprocess_padding(padding) + if tf_data_format == "NDHWC": + strides = (1,) + strides + (1,) + pool_size = (1,) + pool_size + (1,) + else: + strides = (1, 1) + strides + pool_size = (1, 1) + pool_size + + if pool_mode == "max": + x = tf.nn.max_pool3d( + x, pool_size, strides, padding=padding, data_format=tf_data_format + ) + elif pool_mode == "avg": + x = tf.nn.avg_pool3d( + x, pool_size, strides, padding=padding, data_format=tf_data_format + ) + else: + raise ValueError(f"Invalid pooling mode: {str(pool_mode)}") + + if data_format == "channels_first" and tf_data_format == "NDHWC": + x = tf.transpose(x, (0, 4, 1, 2, 3)) + return x + + +@keras_export("keras._legacy.backend.pow") +def pow(x, a): + """DEPRECATED.""" + return tf.pow(x, a) + + +@keras_export("keras._legacy.backend.prod") +def prod(x, axis=None, keepdims=False): + """DEPRECATED.""" + return tf.reduce_prod(x, axis, keepdims) + + +@keras_export("keras._legacy.backend.random_bernoulli") +def random_bernoulli(shape, p=0.0, dtype=None, seed=None): + """DEPRECATED.""" + if dtype is None: + dtype = backend.floatx() + if seed is None: + seed = np.random.randint(10e6) + return tf.where( + tf.random.uniform(shape, dtype=dtype, seed=seed) <= p, + tf.ones(shape, dtype=dtype), + tf.zeros(shape, dtype=dtype), + ) + + +@keras_export("keras._legacy.backend.random_normal") +def random_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): + """DEPRECATED.""" + if dtype is None: + dtype = backend.floatx() + if seed is None: + seed = np.random.randint(10e6) + return tf.random.normal( + shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed + ) + + +@keras_export("keras._legacy.backend.random_normal_variable") +def random_normal_variable( + shape, mean, scale, dtype=None, name=None, seed=None +): + """DEPRECATED.""" + if dtype is None: + dtype = backend.floatx() + tf_dtype = tf.as_dtype(dtype) + if seed is None: + # ensure that randomness is conditioned by the Numpy RNG + seed = np.random.randint(10e8) + value = tf.compat.v1.random_normal_initializer( + mean, scale, dtype=tf_dtype, seed=seed + )(shape) + return variable(value, dtype=dtype, name=name) + + +@keras_export("keras._legacy.backend.random_uniform") +def random_uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None): + """DEPRECATED.""" + if dtype is None: + dtype = backend.floatx() + if seed is None: + seed = np.random.randint(10e6) + return tf.random.uniform( + shape, minval=minval, maxval=maxval, dtype=dtype, seed=seed + ) + + +@keras_export("keras._legacy.backend.random_uniform_variable") +def random_uniform_variable(shape, low, high, dtype=None, name=None, seed=None): + """DEPRECATED.""" + if dtype is None: + dtype = backend.floatx() + tf_dtype = tf.as_dtype(dtype) + if seed is None: + # ensure that randomness is conditioned by the Numpy RNG + seed = np.random.randint(10e8) + value = tf.compat.v1.random_uniform_initializer( + low, high, dtype=tf_dtype, seed=seed + )(shape) + return variable(value, dtype=dtype, name=name) + + +@keras_export("keras._legacy.backend.reshape") +def reshape(x, shape): + """DEPRECATED.""" + return tf.reshape(x, shape) + + +@keras_export("keras._legacy.backend.relu") +def relu(x, alpha=0.0, max_value=None, threshold=0.0): + """DEPRECATED.""" + # While x can be a tensor or variable, we also see cases where + # numpy arrays, lists, tuples are passed as well. + # lists, tuples do not have 'dtype' attribute. + dtype = getattr(x, "dtype", backend.floatx()) + if alpha != 0.0: + if max_value is None and threshold == 0: + return tf.nn.leaky_relu(x, alpha=alpha) + + if threshold != 0: + negative_part = tf.nn.relu(-x + threshold) + else: + negative_part = tf.nn.relu(-x) + else: + negative_part = 1 + + clip_max = max_value is not None + + if threshold != 0: + # computes x for x > threshold else 0 + x = x * tf.cast(tf.greater(x, threshold), dtype=dtype) + elif max_value == 6: + # if no threshold, then can use nn.relu6 native TF op for performance + x = tf.nn.relu6(x) + clip_max = False + else: + x = tf.nn.relu(x) + + if clip_max: + max_value = tf.convert_to_tensor(max_value, dtype=x.dtype) + zero = tf.convert_to_tensor(0, dtype=x.dtype) + x = tf.clip_by_value(x, zero, max_value) + + if alpha != 0.0: + alpha = tf.convert_to_tensor(alpha, dtype=x.dtype) + x -= alpha * negative_part + return x + + +@keras_export("keras._legacy.backend.repeat") +def repeat(x, n): + """DEPRECATED.""" + assert ndim(x) == 2 + x = tf.expand_dims(x, 1) + pattern = tf.stack([1, n, 1]) + return tf.tile(x, pattern) + + +@keras_export("keras._legacy.backend.repeat_elements") +def repeat_elements(x, rep, axis): + """DEPRECATED.""" + x_shape = x.shape.as_list() + # For static axis + if x_shape[axis] is not None: + # slices along the repeat axis + splits = tf.split(value=x, num_or_size_splits=x_shape[axis], axis=axis) + # repeat each slice the given number of reps + x_rep = [s for s in splits for _ in range(rep)] + return concatenate(x_rep, axis) + + # Here we use tf.tile to mimic behavior of np.repeat so that + # we can handle dynamic shapes (that include None). + # To do that, we need an auxiliary axis to repeat elements along + # it and then merge them along the desired axis. + + # Repeating + auxiliary_axis = axis + 1 + x_shape = tf.shape(x) + x_rep = tf.expand_dims(x, axis=auxiliary_axis) + reps = np.ones(len(x.shape) + 1) + reps[auxiliary_axis] = rep + x_rep = tf.tile(x_rep, reps) + + # Merging + reps = np.delete(reps, auxiliary_axis) + reps[axis] = rep + reps = tf.constant(reps, dtype="int32") + x_shape *= reps + x_rep = tf.reshape(x_rep, x_shape) + + # Fix shape representation + x_shape = x.shape.as_list() + x_rep.set_shape(x_shape) + return x_rep + + +@keras_export("keras._legacy.backend.resize_images") +def resize_images( + x, height_factor, width_factor, data_format, interpolation="nearest" +): + """DEPRECATED.""" + if data_format == "channels_first": + rows, cols = 2, 3 + elif data_format == "channels_last": + rows, cols = 1, 2 + else: + raise ValueError(f"Invalid `data_format` argument: {data_format}") + + new_shape = x.shape[rows : cols + 1] + if new_shape.is_fully_defined(): + new_shape = tf.constant(new_shape.as_list(), dtype="int32") + else: + new_shape = tf.shape(x)[rows : cols + 1] + new_shape *= tf.constant( + np.array([height_factor, width_factor], dtype="int32") + ) + + if data_format == "channels_first": + x = permute_dimensions(x, [0, 2, 3, 1]) + interpolations = { + "area": tf.image.ResizeMethod.AREA, + "bicubic": tf.image.ResizeMethod.BICUBIC, + "bilinear": tf.image.ResizeMethod.BILINEAR, + "gaussian": tf.image.ResizeMethod.GAUSSIAN, + "lanczos3": tf.image.ResizeMethod.LANCZOS3, + "lanczos5": tf.image.ResizeMethod.LANCZOS5, + "mitchellcubic": tf.image.ResizeMethod.MITCHELLCUBIC, + "nearest": tf.image.ResizeMethod.NEAREST_NEIGHBOR, + } + interploations_list = '"' + '", "'.join(interpolations.keys()) + '"' + if interpolation in interpolations: + x = tf.image.resize(x, new_shape, method=interpolations[interpolation]) + else: + raise ValueError( + "`interpolation` argument should be one of: " + f'{interploations_list}. Received: "{interpolation}".' + ) + if data_format == "channels_first": + x = permute_dimensions(x, [0, 3, 1, 2]) + + return x + + +@keras_export("keras._legacy.backend.resize_volumes") +def resize_volumes(x, depth_factor, height_factor, width_factor, data_format): + """DEPRECATED.""" + if data_format == "channels_first": + output = repeat_elements(x, depth_factor, axis=2) + output = repeat_elements(output, height_factor, axis=3) + output = repeat_elements(output, width_factor, axis=4) + return output + elif data_format == "channels_last": + output = repeat_elements(x, depth_factor, axis=1) + output = repeat_elements(output, height_factor, axis=2) + output = repeat_elements(output, width_factor, axis=3) + return output + else: + raise ValueError(f"Invalid data_format: {data_format}") + + +@keras_export("keras._legacy.backend.reverse") +def reverse(x, axes): + """DEPRECATED.""" + if isinstance(axes, int): + axes = [axes] + return tf.reverse(x, axes) + + +@keras_export("keras._legacy.backend.rnn") +def rnn( + step_function, + inputs, + initial_states, + go_backwards=False, + mask=None, + constants=None, + unroll=False, + input_length=None, + time_major=False, + zero_output_for_mask=False, + return_all_outputs=True, +): + """DEPRECATED.""" + if not tf.__internal__.tf2.enabled(): + return_all_outputs = True # Not supported in TF1. + + def swap_batch_timestep(input_t): + # Swap the batch and timestep dim for the incoming tensor. + axes = list(range(len(input_t.shape))) + axes[0], axes[1] = 1, 0 + return tf.transpose(input_t, axes) + + if not time_major: + inputs = tf.nest.map_structure(swap_batch_timestep, inputs) + + flatted_inputs = tf.nest.flatten(inputs) + time_steps = flatted_inputs[0].shape[0] + batch = flatted_inputs[0].shape[1] + time_steps_t = tf.shape(flatted_inputs[0])[0] + + for input_ in flatted_inputs: + input_.shape.with_rank_at_least(3) + + if mask is not None: + if mask.dtype != tf.bool: + mask = tf.cast(mask, tf.bool) + if len(mask.shape) == 2: + mask = expand_dims(mask) + if not time_major: + mask = swap_batch_timestep(mask) + + if constants is None: + constants = [] + + # tf.where needs its condition tensor to be the same shape as its two + # result tensors, but in our case the condition (mask) tensor is + # (nsamples, 1), and inputs are (nsamples, ndimensions) or even more. + # So we need to broadcast the mask to match the shape of inputs. + # That's what the tile call does, it just repeats the mask along its + # second dimension n times. + def _expand_mask(mask_t, input_t, fixed_dim=1): + if tf.nest.is_nested(mask_t): + raise ValueError( + f"mask_t is expected to be tensor, but got {mask_t}" + ) + if tf.nest.is_nested(input_t): + raise ValueError( + f"input_t is expected to be tensor, but got {input_t}" + ) + rank_diff = len(input_t.shape) - len(mask_t.shape) + for _ in range(rank_diff): + mask_t = tf.expand_dims(mask_t, -1) + multiples = [1] * fixed_dim + input_t.shape.as_list()[fixed_dim:] + return tf.tile(mask_t, multiples) + + if unroll: + if not time_steps: + raise ValueError("Unrolling requires a fixed number of timesteps.") + states = tuple(initial_states) + successive_states = [] + successive_outputs = [] + + # Process the input tensors. The input tensor need to be split on the + # time_step dim, and reverse if go_backwards is True. In the case of + # nested input, the input is flattened and then transformed + # individually. The result of this will be a tuple of lists, each of + # the item in tuple is list of the tensor with shape (batch, feature) + def _process_single_input_t(input_t): + input_t = tf.unstack(input_t) # unstack for time_step dim + if go_backwards: + input_t.reverse() + return input_t + + if tf.nest.is_nested(inputs): + processed_input = tf.nest.map_structure( + _process_single_input_t, inputs + ) + else: + processed_input = (_process_single_input_t(inputs),) + + def _get_input_tensor(time): + inp = [t_[time] for t_ in processed_input] + return tf.nest.pack_sequence_as(inputs, inp) + + if mask is not None: + mask_list = tf.unstack(mask) + if go_backwards: + mask_list.reverse() + + for i in range(time_steps): + inp = _get_input_tensor(i) + mask_t = mask_list[i] + output, new_states = step_function( + inp, tuple(states) + tuple(constants) + ) + tiled_mask_t = _expand_mask(mask_t, output) + + if not successive_outputs: + prev_output = zeros_like(output) + else: + prev_output = successive_outputs[-1] + + output = tf.where(tiled_mask_t, output, prev_output) + + flat_states = tf.nest.flatten(states) + flat_new_states = tf.nest.flatten(new_states) + tiled_mask_t = tuple( + _expand_mask(mask_t, s) for s in flat_states + ) + flat_final_states = tuple( + tf.where(m, s, ps) + for m, s, ps in zip( + tiled_mask_t, flat_new_states, flat_states + ) + ) + states = tf.nest.pack_sequence_as(states, flat_final_states) + + if return_all_outputs: + successive_outputs.append(output) + successive_states.append(states) + else: + successive_outputs = [output] + successive_states = [states] + last_output = successive_outputs[-1] + new_states = successive_states[-1] + outputs = tf.stack(successive_outputs) + + if zero_output_for_mask: + last_output = tf.where( + _expand_mask(mask_list[-1], last_output), + last_output, + zeros_like(last_output), + ) + outputs = tf.where( + _expand_mask(mask, outputs, fixed_dim=2), + outputs, + zeros_like(outputs), + ) + + else: # mask is None + for i in range(time_steps): + inp = _get_input_tensor(i) + output, states = step_function( + inp, tuple(states) + tuple(constants) + ) + if return_all_outputs: + successive_outputs.append(output) + successive_states.append(states) + else: + successive_outputs = [output] + successive_states = [states] + last_output = successive_outputs[-1] + new_states = successive_states[-1] + outputs = tf.stack(successive_outputs) + + else: # Unroll == False + states = tuple(initial_states) + + # Create input tensor array, if the inputs is nested tensors, then it + # will be flattened first, and tensor array will be created one per + # flattened tensor. + input_ta = tuple( + tf.TensorArray( + dtype=inp.dtype, + size=time_steps_t, + tensor_array_name=f"input_ta_{i}", + ) + for i, inp in enumerate(flatted_inputs) + ) + input_ta = tuple( + ( + ta.unstack(input_) + if not go_backwards + else ta.unstack(reverse(input_, 0)) + ) + for ta, input_ in zip(input_ta, flatted_inputs) + ) + + # Get the time(0) input and compute the output for that, the output will + # be used to determine the dtype of output tensor array. Don't read from + # input_ta due to TensorArray clear_after_read default to True. + input_time_zero = tf.nest.pack_sequence_as( + inputs, [inp[0] for inp in flatted_inputs] + ) + # output_time_zero is used to determine the cell output shape and its + # dtype. the value is discarded. + output_time_zero, _ = step_function( + input_time_zero, tuple(initial_states) + tuple(constants) + ) + + output_ta_size = time_steps_t if return_all_outputs else 1 + output_ta = tuple( + tf.TensorArray( + dtype=out.dtype, + size=output_ta_size, + element_shape=out.shape, + tensor_array_name=f"output_ta_{i}", + ) + for i, out in enumerate(tf.nest.flatten(output_time_zero)) + ) + + time = tf.constant(0, dtype="int32", name="time") + + if input_length is None: + max_iterations = time_steps_t + else: + max_iterations = tf.reduce_max(input_length) + + while_loop_kwargs = { + "cond": lambda time, *_: time < time_steps_t, + "maximum_iterations": max_iterations, + "parallel_iterations": 32, + "swap_memory": True, + } + if mask is not None: + if go_backwards: + mask = reverse(mask, 0) + + mask_ta = tf.TensorArray( + dtype=tf.bool, size=time_steps_t, tensor_array_name="mask_ta" + ) + mask_ta = mask_ta.unstack(mask) + + def masking_fn(time): + return mask_ta.read(time) + + def compute_masked_output(mask_t, flat_out, flat_mask): + tiled_mask_t = tuple( + _expand_mask(mask_t, o, fixed_dim=len(mask_t.shape)) + for o in flat_out + ) + return tuple( + tf.where(m, o, fm) + for m, o, fm in zip(tiled_mask_t, flat_out, flat_mask) + ) + + elif isinstance(input_length, tf.Tensor): + if go_backwards: + max_len = tf.reduce_max(input_length, axis=0) + rev_input_length = tf.subtract(max_len - 1, input_length) + + def masking_fn(time): + return tf.less(rev_input_length, time) + + else: + + def masking_fn(time): + return tf.greater(input_length, time) + + def compute_masked_output(mask_t, flat_out, flat_mask): + return tuple( + tf.compat.v1.where(mask_t, o, zo) + for (o, zo) in zip(flat_out, flat_mask) + ) + + else: + masking_fn = None + + if masking_fn is not None: + # Mask for the T output will be base on the output of T - 1. In the + # case T = 0, a zero filled tensor will be used. + flat_zero_output = tuple( + tf.zeros_like(o) for o in tf.nest.flatten(output_time_zero) + ) + + def _step(time, output_ta_t, prev_output, *states): + """RNN step function. + + Args: + time: Current timestep value. + output_ta_t: TensorArray. + prev_output: tuple of outputs from time - 1. + *states: List of states. + + Returns: + Tuple: `(time + 1, output_ta_t, output) + tuple(new_states)` + """ + current_input = tuple(ta.read(time) for ta in input_ta) + # maybe set shape. + current_input = tf.nest.pack_sequence_as(inputs, current_input) + mask_t = masking_fn(time) + output, new_states = step_function( + current_input, tuple(states) + tuple(constants) + ) + # mask output + flat_output = tf.nest.flatten(output) + flat_mask_output = ( + flat_zero_output + if zero_output_for_mask + else tf.nest.flatten(prev_output) + ) + flat_new_output = compute_masked_output( + mask_t, flat_output, flat_mask_output + ) + + # mask states + flat_state = tf.nest.flatten(states) + flat_new_state = tf.nest.flatten(new_states) + for state, new_state in zip(flat_state, flat_new_state): + if isinstance(new_state, tf.Tensor): + new_state.set_shape(state.shape) + flat_final_state = compute_masked_output( + mask_t, flat_new_state, flat_state + ) + new_states = tf.nest.pack_sequence_as( + new_states, flat_final_state + ) + + ta_index_to_write = time if return_all_outputs else 0 + output_ta_t = tuple( + ta.write(ta_index_to_write, out) + for ta, out in zip(output_ta_t, flat_new_output) + ) + + return (time + 1, output_ta_t, tuple(flat_new_output)) + tuple( + new_states + ) + + final_outputs = tf.compat.v1.while_loop( + body=_step, + loop_vars=(time, output_ta, flat_zero_output) + states, + **while_loop_kwargs, + ) + # Skip final_outputs[2] which is the output for final timestep. + new_states = final_outputs[3:] + else: + + def _step(time, output_ta_t, *states): + """RNN step function. + + Args: + time: Current timestep value. + output_ta_t: TensorArray. + *states: List of states. + + Returns: + Tuple: `(time + 1,output_ta_t) + tuple(new_states)` + """ + current_input = tuple(ta.read(time) for ta in input_ta) + current_input = tf.nest.pack_sequence_as(inputs, current_input) + output, new_states = step_function( + current_input, tuple(states) + tuple(constants) + ) + flat_state = tf.nest.flatten(states) + flat_new_state = tf.nest.flatten(new_states) + for state, new_state in zip(flat_state, flat_new_state): + if isinstance(new_state, tf.Tensor): + new_state.set_shape(state.shape) + + flat_output = tf.nest.flatten(output) + ta_index_to_write = time if return_all_outputs else 0 + output_ta_t = tuple( + ta.write(ta_index_to_write, out) + for ta, out in zip(output_ta_t, flat_output) + ) + + new_states = tf.nest.pack_sequence_as( + initial_states, flat_new_state + ) + return (time + 1, output_ta_t) + tuple(new_states) + + final_outputs = tf.compat.v1.while_loop( + body=_step, + loop_vars=(time, output_ta) + states, + **while_loop_kwargs, + ) + new_states = final_outputs[2:] + + output_ta = final_outputs[1] + + outputs = tuple(o.stack() for o in output_ta) + last_output = tuple(o[-1] for o in outputs) + + outputs = tf.nest.pack_sequence_as(output_time_zero, outputs) + last_output = tf.nest.pack_sequence_as(output_time_zero, last_output) + + # static shape inference + def set_shape(output_): + if isinstance(output_, tf.Tensor): + shape = output_.shape.as_list() + if return_all_outputs: + shape[0] = time_steps + else: + shape[0] = 1 + shape[1] = batch + output_.set_shape(shape) + return output_ + + outputs = tf.nest.map_structure(set_shape, outputs) + + if not time_major: + outputs = tf.nest.map_structure(swap_batch_timestep, outputs) + + return last_output, outputs, new_states + + +@keras_export("keras._legacy.backend.round") +def round(x): + """DEPRECATED.""" + return tf.round(x) + + +@keras_export("keras._legacy.backend.separable_conv2d") +def separable_conv2d( + x, + depthwise_kernel, + pointwise_kernel, + strides=(1, 1), + padding="valid", + data_format=None, + dilation_rate=(1, 1), +): + """DEPRECATED.""" + if data_format is None: + data_format = backend.image_data_format() + if data_format not in {"channels_first", "channels_last"}: + raise ValueError(f"Unknown data_format: {data_format}") + if len(strides) != 2: + raise ValueError("`strides` must be a tuple of 2 integers.") + + x, tf_data_format = _preprocess_conv2d_input(x, data_format) + padding = _preprocess_padding(padding) + if not isinstance(strides, tuple): + strides = tuple(strides) + if tf_data_format == "NHWC": + strides = (1,) + strides + (1,) + else: + strides = (1, 1) + strides + + x = tf.nn.separable_conv2d( + x, + depthwise_kernel, + pointwise_kernel, + strides=strides, + padding=padding, + dilations=dilation_rate, + data_format=tf_data_format, + ) + if data_format == "channels_first" and tf_data_format == "NHWC": + x = tf.transpose(x, (0, 3, 1, 2)) # NHWC -> NCHW + return x + + +@keras_export("keras._legacy.backend.set_value") +def set_value(x, value): + """DEPRECATED.""" + value = np.asarray(value, dtype=x.dtype.name) + x.assign(value) + + +@keras_export("keras._legacy.backend.shape") +def shape(x): + """DEPRECATED.""" + return tf.shape(x) + + +@keras_export("keras._legacy.backend.sigmoid") +def sigmoid(x): + """DEPRECATED.""" + output = tf.sigmoid(x) + return output + + +@keras_export("keras._legacy.backend.sign") +def sign(x): + """DEPRECATED.""" + return tf.sign(x) + + +@keras_export("keras._legacy.backend.sin") +def sin(x): + """DEPRECATED.""" + return tf.sin(x) + + +@keras_export("keras._legacy.backend.softmax") +def softmax(x, axis=-1): + """DEPRECATED.""" + if x.shape.rank <= 1: + raise ValueError( + f"Cannot apply softmax to a tensor that is 1D. Received input: {x}" + ) + + if isinstance(axis, int): + output = tf.nn.softmax(x, axis=axis) + else: + # nn.softmax does not support tuple axis. + numerator = tf.exp(x - tf.reduce_max(x, axis=axis, keepdims=True)) + denominator = tf.reduce_sum(numerator, axis=axis, keepdims=True) + output = numerator / denominator + + # Cache the logits to use for crossentropy loss. + output._keras_logits = x + return output + + +@keras_export("keras._legacy.backend.softplus") +def softplus(x): + """DEPRECATED.""" + return tf.math.softplus(x) + + +@keras_export("keras._legacy.backend.softsign") +def softsign(x): + """DEPRECATED.""" + return tf.math.softsign(x) + + +@keras_export("keras._legacy.backend.sparse_categorical_crossentropy") +def sparse_categorical_crossentropy( + target, output, from_logits=False, axis=-1, ignore_class=None +): + """DEPRECATED.""" + target = tf.convert_to_tensor(target) + output = tf.convert_to_tensor(output) + + target = cast(target, "int64") + + if not from_logits: + epsilon_ = tf.convert_to_tensor(backend.epsilon(), output.dtype) + output = tf.clip_by_value(output, epsilon_, 1 - epsilon_) + output = tf.math.log(output) + + # Permute output so that the last axis contains the logits/probabilities. + if isinstance(output.shape, (tuple, list)): + output_rank = len(output.shape) + else: + output_rank = output.shape.ndims + if output_rank is not None: + axis %= output_rank + if axis != output_rank - 1: + permutation = list( + itertools.chain( + range(axis), range(axis + 1, output_rank), [axis] + ) + ) + output = tf.transpose(output, perm=permutation) + elif axis != -1: + raise ValueError( + "Cannot compute sparse categorical crossentropy with `axis={}` " + "on an output tensor with unknown rank".format(axis) + ) + + # Try to adjust the shape so that rank of labels = rank of logits - 1. + output_shape = tf.shape(output) + target_rank = target.shape.ndims + + update_shape = ( + target_rank is not None + and output_rank is not None + and target_rank != output_rank - 1 + ) + if update_shape: + target = flatten(target) + output = tf.reshape(output, [-1, output_shape[-1]]) + + if ignore_class is not None: + valid_mask = tf.not_equal(target, cast(ignore_class, target.dtype)) + target = target[valid_mask] + output = output[valid_mask] + + res = tf.nn.sparse_softmax_cross_entropy_with_logits( + labels=target, logits=output + ) + + if ignore_class is not None: + res_shape = cast(output_shape[:-1], "int64") + valid_mask = tf.reshape(valid_mask, res_shape) + res = tf.scatter_nd(tf.where(valid_mask), res, res_shape) + res._keras_mask = valid_mask + + return res + + if update_shape and output_rank >= 3: + # If our output includes timesteps or + # spatial dimensions we need to reshape + res = tf.reshape(res, output_shape[:-1]) + + return res + + +@keras_export("keras._legacy.backend.spatial_2d_padding") +def spatial_2d_padding(x, padding=((1, 1), (1, 1)), data_format=None): + """DEPRECATED.""" + assert len(padding) == 2 + assert len(padding[0]) == 2 + assert len(padding[1]) == 2 + if data_format is None: + data_format = backend.image_data_format() + if data_format not in {"channels_first", "channels_last"}: + raise ValueError(f"Unknown data_format: {data_format}") + + if data_format == "channels_first": + pattern = [[0, 0], [0, 0], list(padding[0]), list(padding[1])] + else: + pattern = [[0, 0], list(padding[0]), list(padding[1]), [0, 0]] + return tf.compat.v1.pad(x, pattern) + + +@keras_export("keras._legacy.backend.spatial_3d_padding") +def spatial_3d_padding(x, padding=((1, 1), (1, 1), (1, 1)), data_format=None): + """DEPRECATED.""" + assert len(padding) == 3 + assert len(padding[0]) == 2 + assert len(padding[1]) == 2 + assert len(padding[2]) == 2 + if data_format is None: + data_format = backend.image_data_format() + if data_format not in {"channels_first", "channels_last"}: + raise ValueError(f"Unknown data_format: {data_format}") + + if data_format == "channels_first": + pattern = [ + [0, 0], + [0, 0], + [padding[0][0], padding[0][1]], + [padding[1][0], padding[1][1]], + [padding[2][0], padding[2][1]], + ] + else: + pattern = [ + [0, 0], + [padding[0][0], padding[0][1]], + [padding[1][0], padding[1][1]], + [padding[2][0], padding[2][1]], + [0, 0], + ] + return tf.compat.v1.pad(x, pattern) + + +@keras_export("keras._legacy.backend.sqrt") +def sqrt(x): + """DEPRECATED.""" + zero = tf.convert_to_tensor(0.0, x.dtype) + x = tf.maximum(x, zero) + return tf.sqrt(x) + + +@keras_export("keras._legacy.backend.square") +def square(x): + """DEPRECATED.""" + return tf.square(x) + + +@keras_export("keras._legacy.backend.squeeze") +def squeeze(x, axis): + """DEPRECATED.""" + return tf.squeeze(x, [axis]) + + +@keras_export("keras._legacy.backend.stack") +def stack(x, axis=0): + """DEPRECATED.""" + return tf.stack(x, axis=axis) + + +@keras_export("keras._legacy.backend.std") +def std(x, axis=None, keepdims=False): + """DEPRECATED.""" + if x.dtype.base_dtype == tf.bool: + x = tf.cast(x, backend.floatx()) + return tf.math.reduce_std(x, axis=axis, keepdims=keepdims) + + +@keras_export("keras._legacy.backend.stop_gradient") +def stop_gradient(variables): + """DEPRECATED.""" + if isinstance(variables, (list, tuple)): + return map(tf.stop_gradient, variables) + return tf.stop_gradient(variables) + + +@keras_export("keras._legacy.backend.sum") +def sum(x, axis=None, keepdims=False): + """DEPRECATED.""" + return tf.reduce_sum(x, axis, keepdims) + + +@keras_export("keras._legacy.backend.switch") +def switch(condition, then_expression, else_expression): + """DEPRECATED.""" + if condition.dtype != tf.bool: + condition = tf.cast(condition, "bool") + cond_ndim = ndim(condition) + if not cond_ndim: + if not callable(then_expression): + + def then_expression_fn(): + return then_expression + + else: + then_expression_fn = then_expression + if not callable(else_expression): + + def else_expression_fn(): + return else_expression + + else: + else_expression_fn = else_expression + x = tf.compat.v1.cond(condition, then_expression_fn, else_expression_fn) + else: + # tf.where needs its condition tensor + # to be the same shape as its two + # result tensors + if callable(then_expression): + then_expression = then_expression() + if callable(else_expression): + else_expression = else_expression() + expr_ndim = ndim(then_expression) + if cond_ndim > expr_ndim: + raise ValueError( + "Rank of `condition` should be less than or" + " equal to rank of `then_expression` and " + "`else_expression`. ndim(condition)=" + f"{cond_ndim}, ndim(then_expression)={expr_ndim}" + ) + if cond_ndim > 1: + ndim_diff = expr_ndim - cond_ndim + cond_shape = tf.concat( + [tf.shape(condition), [1] * ndim_diff], axis=0 + ) + condition = tf.reshape(condition, cond_shape) + expr_shape = tf.shape(then_expression) + shape_diff = expr_shape - cond_shape + tile_shape = tf.where( + shape_diff > 0, expr_shape, tf.ones_like(expr_shape) + ) + condition = tf.tile(condition, tile_shape) + x = tf.where(condition, then_expression, else_expression) + return x + + +@keras_export("keras._legacy.backend.tanh") +def tanh(x): + """DEPRECATED.""" + return tf.tanh(x) + + +@keras_export("keras._legacy.backend.temporal_padding") +def temporal_padding(x, padding=(1, 1)): + """DEPRECATED.""" + assert len(padding) == 2 + pattern = [[0, 0], [padding[0], padding[1]], [0, 0]] + return tf.compat.v1.pad(x, pattern) + + +@keras_export("keras._legacy.backend.tile") +def tile(x, n): + """DEPRECATED.""" + if isinstance(n, int): + n = [n] + return tf.tile(x, n) + + +@keras_export("keras._legacy.backend.to_dense") +def to_dense(tensor): + """DEPRECATED.""" + if is_sparse(tensor): + return tf.sparse.to_dense(tensor) + else: + return tensor + + +@keras_export("keras._legacy.backend.transpose") +def transpose(x): + """DEPRECATED.""" + return tf.transpose(x) + + +@keras_export("keras._legacy.backend.truncated_normal") +def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): + """DEPRECATED.""" + if dtype is None: + dtype = backend.floatx() + if seed is None: + seed = np.random.randint(10e6) + return tf.random.truncated_normal( + shape, mean, stddev, dtype=dtype, seed=seed + ) + + +@keras_export("keras._legacy.backend.update") +def update(x, new_x): + """DEPRECATED.""" + return tf.compat.v1.assign(x, new_x) + + +@keras_export("keras._legacy.backend.update_add") +def update_add(x, increment): + """DEPRECATED.""" + return tf.compat.v1.assign_add(x, increment) + + +@keras_export("keras._legacy.backend.update_sub") +def update_sub(x, decrement): + """DEPRECATED.""" + return tf.compat.v1.assign_sub(x, decrement) + + +@keras_export("keras._legacy.backend.var") +def var(x, axis=None, keepdims=False): + """DEPRECATED.""" + if x.dtype.base_dtype == tf.bool: + x = tf.cast(x, backend.floatx()) + return tf.math.reduce_variance(x, axis=axis, keepdims=keepdims) + + +@keras_export("keras._legacy.backend.variable") +def variable(value, dtype=None, name=None, constraint=None): + """DEPRECATED.""" + if dtype is None: + dtype = backend.floatx() + if hasattr(value, "tocoo"): + sparse_coo = value.tocoo() + indices = np.concatenate( + ( + np.expand_dims(sparse_coo.row, 1), + np.expand_dims(sparse_coo.col, 1), + ), + 1, + ) + v = tf.SparseTensor( + indices=indices, + values=sparse_coo.data, + dense_shape=sparse_coo.shape, + ) + v._keras_shape = sparse_coo.shape + return v + v = tf.Variable( + value, dtype=tf.as_dtype(dtype), name=name, constraint=constraint + ) + return v + + +@keras_export("keras._legacy.backend.zeros") +def zeros(shape, dtype=None, name=None): + """DEPRECATED.""" + with tf.init_scope(): + if dtype is None: + dtype = backend.floatx() + tf_dtype = tf.as_dtype(dtype) + v = tf.zeros(shape=shape, dtype=tf_dtype, name=name) + if py_all(v.shape.as_list()): + return variable(v, dtype=dtype, name=name) + return v + + +@keras_export("keras._legacy.backend.zeros_like") +def zeros_like(x, dtype=None, name=None): + """DEPRECATED.""" + return tf.zeros_like(x, dtype=dtype, name=name) diff --git a/keras/src/legacy/layers.py b/keras/src/legacy/layers.py new file mode 100644 index 000000000000..b51ecf86c751 --- /dev/null +++ b/keras/src/legacy/layers.py @@ -0,0 +1,244 @@ +"""Legacy Keras 1/2 layers. + +AlphaDropout +RandomHeight +RandomWidth +ThresholdedReLU +""" + +from keras.src import backend +from keras.src.api_export import keras_export +from keras.src.layers.layer import Layer +from keras.src.utils.module_utils import tensorflow as tf + + +@keras_export("keras._legacy.layers.AlphaDropout") +class AlphaDropout(Layer): + """DEPRECATED.""" + + def __init__(self, rate, noise_shape=None, seed=None, **kwargs): + super().__init__(**kwargs) + self.rate = rate + self.seed = seed + self.noise_shape = noise_shape + self.seed_generator = backend.random.SeedGenerator(seed) + self.supports_masking = True + self.built = True + + def call(self, inputs, training=False): + if training and self.rate > 0: + alpha = 1.6732632423543772848170429916717 + scale = 1.0507009873554804934193349852946 + alpha_p = -alpha * scale + + if self.noise_shape is None: + noise_shape = tf.shape(inputs) + else: + noise_shape = self.noise_shape + kept_idx = tf.greater_equal( + backend.random.uniform(noise_shape, seed=self.seed_generator), + self.rate, + ) + kept_idx = tf.cast(kept_idx, inputs.dtype) + + # Get affine transformation params + a = ((1 - self.rate) * (1 + self.rate * alpha_p**2)) ** -0.5 + b = -a * alpha_p * self.rate + + # Apply mask + x = inputs * kept_idx + alpha_p * (1 - kept_idx) + + # Do affine transformation + return a * x + b + return inputs + + def get_config(self): + config = {"rate": self.rate, "seed": self.seed} + base_config = super().get_config() + return {**base_config, **config} + + def compute_output_shape(self, input_shape): + return input_shape + + +@keras_export("keras._legacy.layers.RandomHeight") +class RandomHeight(Layer): + """DEPRECATED.""" + + def __init__(self, factor, interpolation="bilinear", seed=None, **kwargs): + super().__init__(**kwargs) + self.seed_generator = backend.random.SeedGenerator(seed) + self.factor = factor + if isinstance(factor, (tuple, list)): + self.height_lower = factor[0] + self.height_upper = factor[1] + else: + self.height_lower = -factor + self.height_upper = factor + + if self.height_upper < self.height_lower: + raise ValueError( + "`factor` argument cannot have an upper bound lesser than the " + f"lower bound. Received: factor={factor}" + ) + if self.height_lower < -1.0 or self.height_upper < -1.0: + raise ValueError( + "`factor` argument must have values larger than -1. " + f"Received: factor={factor}" + ) + self.interpolation = interpolation + self.seed = seed + + def call(self, inputs, training=True): + inputs = tf.convert_to_tensor(inputs, dtype=self.compute_dtype) + + def random_height_inputs(inputs): + """Inputs height-adjusted with random ops.""" + inputs_shape = tf.shape(inputs) + img_hd = tf.cast(inputs_shape[-3], tf.float32) + img_wd = inputs_shape[-2] + height_factor = backend.random.uniform( + shape=[], + minval=(1.0 + self.height_lower), + maxval=(1.0 + self.height_upper), + seed=self.seed_generator, + ) + adjusted_height = tf.cast(height_factor * img_hd, tf.int32) + adjusted_size = tf.stack([adjusted_height, img_wd]) + output = tf.image.resize( + images=inputs, + size=adjusted_size, + method=self.interpolation, + ) + # tf.resize will output float32 regardless of input type. + output = tf.cast(output, self.compute_dtype) + output_shape = inputs.shape.as_list() + output_shape[-3] = None + output.set_shape(output_shape) + return output + + if training: + return random_height_inputs(inputs) + else: + return inputs + + def compute_output_shape(self, input_shape): + input_shape = list(input_shape) + input_shape[-3] = None + return tuple(input_shape) + + def get_config(self): + config = { + "factor": self.factor, + "interpolation": self.interpolation, + "seed": self.seed, + } + base_config = super().get_config() + return {**base_config, **config} + + +@keras_export("keras._legacy.layers.RandomWidth") +class RandomWidth(Layer): + """DEPRECATED.""" + + def __init__(self, factor, interpolation="bilinear", seed=None, **kwargs): + super().__init__(**kwargs) + self.seed_generator = backend.random.SeedGenerator(seed) + self.factor = factor + if isinstance(factor, (tuple, list)): + self.width_lower = factor[0] + self.width_upper = factor[1] + else: + self.width_lower = -factor + self.width_upper = factor + if self.width_upper < self.width_lower: + raise ValueError( + "`factor` argument cannot have an upper bound less than the " + f"lower bound. Received: factor={factor}" + ) + if self.width_lower < -1.0 or self.width_upper < -1.0: + raise ValueError( + "`factor` argument must have values larger than -1. " + f"Received: factor={factor}" + ) + self.interpolation = interpolation + self.seed = seed + + def call(self, inputs, training=True): + inputs = tf.convert_to_tensor(inputs, dtype=self.compute_dtype) + + def random_width_inputs(inputs): + """Inputs width-adjusted with random ops.""" + inputs_shape = tf.shape(inputs) + img_hd = inputs_shape[-3] + img_wd = tf.cast(inputs_shape[-2], tf.float32) + width_factor = backend.random.uniform( + shape=[], + minval=(1.0 + self.width_lower), + maxval=(1.0 + self.width_upper), + seed=self.seed_generator, + ) + adjusted_width = tf.cast(width_factor * img_wd, tf.int32) + adjusted_size = tf.stack([img_hd, adjusted_width]) + output = tf.image.resize( + images=inputs, + size=adjusted_size, + method=self.interpolation, + ) + # tf.resize will output float32 regardless of input type. + output = tf.cast(output, self.compute_dtype) + output_shape = inputs.shape.as_list() + output_shape[-2] = None + output.set_shape(output_shape) + return output + + if training: + return random_width_inputs(inputs) + else: + return inputs + + def compute_output_shape(self, input_shape): + input_shape = list(input_shape) + input_shape[-2] = None + return tuple(input_shape) + + def get_config(self): + config = { + "factor": self.factor, + "interpolation": self.interpolation, + "seed": self.seed, + } + base_config = super().get_config() + return {**base_config, **config} + + +@keras_export("keras._legacy.layers.ThresholdedReLU") +class ThresholdedReLU(Layer): + """DEPRECATED.""" + + def __init__(self, theta=1.0, **kwargs): + super().__init__(**kwargs) + if theta is None: + raise ValueError( + "Theta of a Thresholded ReLU layer cannot be None, expecting a " + f"float. Received: {theta}" + ) + if theta < 0: + raise ValueError( + "The theta value of a Thresholded ReLU layer " + f"should be >=0. Received: {theta}" + ) + self.supports_masking = True + self.theta = tf.convert_to_tensor(theta, dtype=self.compute_dtype) + + def call(self, inputs): + dtype = self.compute_dtype + return inputs * tf.cast(tf.greater(inputs, self.theta), dtype) + + def get_config(self): + config = {"theta": float(self.theta)} + base_config = super().get_config() + return {**base_config, **config} + + def compute_output_shape(self, input_shape): + return input_shape diff --git a/keras/src/legacy/losses.py b/keras/src/legacy/losses.py new file mode 100644 index 000000000000..a84284bfc38d --- /dev/null +++ b/keras/src/legacy/losses.py @@ -0,0 +1,20 @@ +from keras.src.api_export import keras_export + + +@keras_export("keras._legacy.losses.Reduction") +class Reduction: + AUTO = "auto" + NONE = "none" + SUM = "sum" + SUM_OVER_BATCH_SIZE = "sum_over_batch_size" + + @classmethod + def all(cls): + return (cls.AUTO, cls.NONE, cls.SUM, cls.SUM_OVER_BATCH_SIZE) + + @classmethod + def validate(cls, key): + if key not in cls.all(): + raise ValueError( + f'Invalid Reduction Key: {key}. Expected keys are "{cls.all()}"' + ) diff --git a/keras/src/legacy/preprocessing/__init__.py b/keras/src/legacy/preprocessing/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/keras/src/legacy/preprocessing/image.py b/keras/src/legacy/preprocessing/image.py new file mode 100644 index 000000000000..497bb95909b2 --- /dev/null +++ b/keras/src/legacy/preprocessing/image.py @@ -0,0 +1,1884 @@ +"""Deprecated image preprocessing APIs from Keras 1.""" + +import collections +import multiprocessing +import os +import threading +import warnings + +import numpy as np + +from keras.src import backend +from keras.src.api_export import keras_export +from keras.src.trainers.data_adapters.py_dataset_adapter import PyDataset +from keras.src.utils import image_utils +from keras.src.utils import io_utils +from keras.src.utils.module_utils import scipy + + +@keras_export("keras._legacy.preprocessing.image.Iterator") +class Iterator(PyDataset): + """Base class for image data iterators. + + DEPRECATED. + + Every `Iterator` must implement the `_get_batches_of_transformed_samples` + method. + + Args: + n: Integer, total number of samples in the dataset to loop over. + batch_size: Integer, size of a batch. + shuffle: Boolean, whether to shuffle the data between epochs. + seed: Random seeding for data shuffling. + **kwargs: Additional keyword arguments for the `PyDataset` base class, + such as `workers`, `use_multiprocessing`, and `max_queue_size`. + """ + + white_list_formats = ("png", "jpg", "jpeg", "bmp", "ppm", "tif", "tiff") + + def __init__(self, n, batch_size, shuffle, seed, **kwargs): + super().__init__(**kwargs) + self.n = n + self.batch_size = batch_size + self.seed = seed + self.shuffle = shuffle + self.batch_index = 0 + self.total_batches_seen = 0 + self.lock = threading.Lock() + self.index_array = None + self.index_generator = self._flow_index() + + def _set_index_array(self): + self.index_array = np.arange(self.n) + if self.shuffle: + self.index_array = np.random.permutation(self.n) + + def __getitem__(self, idx): + if idx >= len(self): + raise ValueError( + "Asked to retrieve element {idx}, " + "but the Sequence " + "has length {length}".format(idx=idx, length=len(self)) + ) + if self.seed is not None: + np.random.seed(self.seed + self.total_batches_seen) + self.total_batches_seen += 1 + if self.index_array is None: + self._set_index_array() + index_array = self.index_array[ + self.batch_size * idx : self.batch_size * (idx + 1) + ] + return self._get_batches_of_transformed_samples(index_array) + + def __len__(self): + return (self.n + self.batch_size - 1) // self.batch_size # round up + + def on_epoch_end(self): + self._set_index_array() + + def reset(self): + self.batch_index = 0 + + def _flow_index(self): + # Ensure self.batch_index is 0. + self.reset() + while 1: + if self.seed is not None: + np.random.seed(self.seed + self.total_batches_seen) + if self.batch_index == 0: + self._set_index_array() + + if self.n == 0: + # Avoiding modulo by zero error + current_index = 0 + else: + current_index = (self.batch_index * self.batch_size) % self.n + if self.n > current_index + self.batch_size: + self.batch_index += 1 + else: + self.batch_index = 0 + self.total_batches_seen += 1 + yield self.index_array[ + current_index : current_index + self.batch_size + ] + + def __iter__(self): + # Needed if we want to do something like: + # for x, y in data_gen.flow(...): + return self + + def __next__(self): + with self.lock: + index_array = next(self.index_generator) + # The transformation of images is not under thread lock + # so it can be done in parallel + return self._get_batches_of_transformed_samples(index_array) + + def _get_batches_of_transformed_samples(self, index_array): + """Gets a batch of transformed samples. + + Args: + index_array: Array of sample indices to include in batch. + Returns: + A batch of transformed samples. + """ + raise NotImplementedError + + +def _iter_valid_files(directory, white_list_formats, follow_links): + """Iterates on files with extension. + + Args: + directory: Absolute path to the directory + containing files to be counted + white_list_formats: Set of strings containing allowed extensions for + the files to be counted. + follow_links: Boolean, follow symbolic links to subdirectories. + Yields: + Tuple of (root, filename) with extension in `white_list_formats`. + """ + + def _recursive_list(subpath): + return sorted( + os.walk(subpath, followlinks=follow_links), key=lambda x: x[0] + ) + + for root, _, files in _recursive_list(directory): + for fname in sorted(files): + if fname.lower().endswith(".tiff"): + warnings.warn( + 'Using ".tiff" files with multiple bands ' + "will cause distortion. Please verify your output." + ) + if fname.lower().endswith(white_list_formats): + yield root, fname + + +def _list_valid_filenames_in_directory( + directory, white_list_formats, split, class_indices, follow_links +): + """Lists paths of files in `subdir` with extensions in `white_list_formats`. + + Args: + directory: absolute path to a directory containing the files to list. + The directory name is used as class label + and must be a key of `class_indices`. + white_list_formats: set of strings containing allowed extensions for + the files to be counted. + split: tuple of floats (e.g. `(0.2, 0.6)`) to only take into + account a certain fraction of files in each directory. + E.g.: `segment=(0.6, 1.0)` would only account for last 40 percent + of images in each directory. + class_indices: dictionary mapping a class name to its index. + follow_links: boolean, follow symbolic links to subdirectories. + + Returns: + classes: a list of class indices + filenames: the path of valid files in `directory`, relative from + `directory`'s parent (e.g., if `directory` is "dataset/class1", + the filenames will be + `["class1/file1.jpg", "class1/file2.jpg", ...]`). + """ + dirname = os.path.basename(directory) + if split: + all_files = list( + _iter_valid_files(directory, white_list_formats, follow_links) + ) + num_files = len(all_files) + start, stop = int(split[0] * num_files), int(split[1] * num_files) + valid_files = all_files[start:stop] + else: + valid_files = _iter_valid_files( + directory, white_list_formats, follow_links + ) + classes = [] + filenames = [] + for root, fname in valid_files: + classes.append(class_indices[dirname]) + absolute_path = os.path.join(root, fname) + relative_path = os.path.join( + dirname, os.path.relpath(absolute_path, directory) + ) + filenames.append(relative_path) + + return classes, filenames + + +class BatchFromFilesMixin: + """Adds methods related to getting batches from filenames. + + It includes the logic to transform image files to batches. + """ + + def set_processing_attrs( + self, + image_data_generator, + target_size, + color_mode, + data_format, + save_to_dir, + save_prefix, + save_format, + subset, + interpolation, + keep_aspect_ratio, + ): + """Sets attributes to use later for processing files into a batch. + + Args: + image_data_generator: Instance of `ImageDataGenerator` + to use for random transformations and normalization. + target_size: tuple of integers, dimensions to resize input images + to. + color_mode: One of `"rgb"`, `"rgba"`, `"grayscale"`. + Color mode to read images. + data_format: String, one of `channels_first`, `channels_last`. + save_to_dir: Optional directory where to save the pictures + being yielded, in a viewable format. This is useful + for visualizing the random transformations being + applied, for debugging purposes. + save_prefix: String prefix to use for saving sample + images (if `save_to_dir` is set). + save_format: Format to use for saving sample images + (if `save_to_dir` is set). + subset: Subset of data (`"training"` or `"validation"`) if + validation_split is set in ImageDataGenerator. + interpolation: Interpolation method used to resample the image if + the target size is different from that of the loaded image. + Supported methods are "nearest", "bilinear", and "bicubic". If + PIL version 1.1.3 or newer is installed, "lanczos" is also + supported. If PIL version 3.4.0 or newer is installed, "box" and + "hamming" are also supported. By default, "nearest" is used. + keep_aspect_ratio: Boolean, whether to resize images to a target + size without aspect ratio distortion. The image is cropped in + the center with target aspect ratio before resizing. + """ + self.image_data_generator = image_data_generator + self.target_size = tuple(target_size) + self.keep_aspect_ratio = keep_aspect_ratio + if color_mode not in {"rgb", "rgba", "grayscale"}: + raise ValueError( + f"Invalid color mode: {color_mode}" + '; expected "rgb", "rgba", or "grayscale".' + ) + self.color_mode = color_mode + self.data_format = data_format + if self.color_mode == "rgba": + if self.data_format == "channels_last": + self.image_shape = self.target_size + (4,) + else: + self.image_shape = (4,) + self.target_size + elif self.color_mode == "rgb": + if self.data_format == "channels_last": + self.image_shape = self.target_size + (3,) + else: + self.image_shape = (3,) + self.target_size + else: + if self.data_format == "channels_last": + self.image_shape = self.target_size + (1,) + else: + self.image_shape = (1,) + self.target_size + self.save_to_dir = save_to_dir + self.save_prefix = save_prefix + self.save_format = save_format + self.interpolation = interpolation + if subset is not None: + validation_split = self.image_data_generator._validation_split + if subset == "validation": + split = (0, validation_split) + elif subset == "training": + split = (validation_split, 1) + else: + raise ValueError( + f"Invalid subset name: {subset};" + 'expected "training" or "validation"' + ) + else: + split = None + self.split = split + self.subset = subset + + def _get_batches_of_transformed_samples(self, index_array): + """Gets a batch of transformed samples. + + Args: + index_array: Array of sample indices to include in batch. + Returns: + A batch of transformed samples. + """ + batch_x = np.zeros( + (len(index_array),) + self.image_shape, dtype=self.dtype + ) + # build batch of image data + # self.filepaths is dynamic, is better to call it once outside the loop + filepaths = self.filepaths + for i, j in enumerate(index_array): + img = image_utils.load_img( + filepaths[j], + color_mode=self.color_mode, + target_size=self.target_size, + interpolation=self.interpolation, + keep_aspect_ratio=self.keep_aspect_ratio, + ) + x = image_utils.img_to_array(img, data_format=self.data_format) + # Pillow images should be closed after `load_img`, + # but not PIL images. + if hasattr(img, "close"): + img.close() + if self.image_data_generator: + params = self.image_data_generator.get_random_transform(x.shape) + x = self.image_data_generator.apply_transform(x, params) + x = self.image_data_generator.standardize(x) + batch_x[i] = x + # optionally save augmented images to disk for debugging purposes + if self.save_to_dir: + for i, j in enumerate(index_array): + img = image_utils.array_to_img( + batch_x[i], self.data_format, scale=True + ) + fname = "{prefix}_{index}_{hash}.{format}".format( + prefix=self.save_prefix, + index=j, + hash=np.random.randint(1e7), + format=self.save_format, + ) + img.save(os.path.join(self.save_to_dir, fname)) + # build batch of labels + if self.class_mode == "input": + batch_y = batch_x.copy() + elif self.class_mode in {"binary", "sparse"}: + batch_y = np.empty(len(batch_x), dtype=self.dtype) + for i, n_observation in enumerate(index_array): + batch_y[i] = self.classes[n_observation] + elif self.class_mode == "categorical": + batch_y = np.zeros( + (len(batch_x), len(self.class_indices)), dtype=self.dtype + ) + for i, n_observation in enumerate(index_array): + batch_y[i, self.classes[n_observation]] = 1.0 + elif self.class_mode == "multi_output": + batch_y = [output[index_array] for output in self.labels] + elif self.class_mode == "raw": + batch_y = self.labels[index_array] + else: + return batch_x + if self.sample_weight is None: + return batch_x, batch_y + else: + return batch_x, batch_y, self.sample_weight[index_array] + + @property + def filepaths(self): + """List of absolute paths to image files.""" + raise NotImplementedError( + "`filepaths` property method has not " + "been implemented in {}.".format(type(self).__name__) + ) + + @property + def labels(self): + """Class labels of every observation.""" + raise NotImplementedError( + "`labels` property method has not been implemented in {}.".format( + type(self).__name__ + ) + ) + + @property + def sample_weight(self): + raise NotImplementedError( + "`sample_weight` property method has not " + "been implemented in {}.".format(type(self).__name__) + ) + + +@keras_export("keras._legacy.preprocessing.image.DirectoryIterator") +class DirectoryIterator(BatchFromFilesMixin, Iterator): + """Iterator capable of reading images from a directory on disk. + + DEPRECATED. + """ + + allowed_class_modes = {"categorical", "binary", "sparse", "input", None} + + def __init__( + self, + directory, + image_data_generator, + target_size=(256, 256), + color_mode="rgb", + classes=None, + class_mode="categorical", + batch_size=32, + shuffle=True, + seed=None, + data_format=None, + save_to_dir=None, + save_prefix="", + save_format="png", + follow_links=False, + subset=None, + interpolation="nearest", + keep_aspect_ratio=False, + dtype=None, + ): + if data_format is None: + data_format = backend.image_data_format() + if dtype is None: + dtype = backend.floatx() + super().set_processing_attrs( + image_data_generator, + target_size, + color_mode, + data_format, + save_to_dir, + save_prefix, + save_format, + subset, + interpolation, + keep_aspect_ratio, + ) + self.directory = directory + self.classes = classes + if class_mode not in self.allowed_class_modes: + raise ValueError( + "Invalid class_mode: {}; expected one of: {}".format( + class_mode, self.allowed_class_modes + ) + ) + self.class_mode = class_mode + self.dtype = dtype + # First, count the number of samples and classes. + self.samples = 0 + + if not classes: + classes = [] + for subdir in sorted(os.listdir(directory)): + if os.path.isdir(os.path.join(directory, subdir)): + classes.append(subdir) + self.num_classes = len(classes) + self.class_indices = dict(zip(classes, range(len(classes)))) + + pool = multiprocessing.pool.ThreadPool() + + # Second, build an index of the images + # in the different class subfolders. + results = [] + self.filenames = [] + i = 0 + for dirpath in (os.path.join(directory, subdir) for subdir in classes): + results.append( + pool.apply_async( + _list_valid_filenames_in_directory, + ( + dirpath, + self.white_list_formats, + self.split, + self.class_indices, + follow_links, + ), + ) + ) + classes_list = [] + for res in results: + classes, filenames = res.get() + classes_list.append(classes) + self.filenames += filenames + self.samples = len(self.filenames) + self.classes = np.zeros((self.samples,), dtype="int32") + for classes in classes_list: + self.classes[i : i + len(classes)] = classes + i += len(classes) + + io_utils.print_msg( + f"Found {self.samples} images belonging to " + f"{self.num_classes} classes." + ) + pool.close() + pool.join() + self._filepaths = [ + os.path.join(self.directory, fname) for fname in self.filenames + ] + super().__init__(self.samples, batch_size, shuffle, seed) + + @property + def filepaths(self): + return self._filepaths + + @property + def labels(self): + return self.classes + + @property # mixin needs this property to work + def sample_weight(self): + # no sample weights will be returned + return None + + +@keras_export("keras._legacy.preprocessing.image.NumpyArrayIterator") +class NumpyArrayIterator(Iterator): + """Iterator yielding data from a Numpy array. + + DEPRECATED. + """ + + def __init__( + self, + x, + y, + image_data_generator, + batch_size=32, + shuffle=False, + sample_weight=None, + seed=None, + data_format=None, + save_to_dir=None, + save_prefix="", + save_format="png", + subset=None, + ignore_class_split=False, + dtype=None, + ): + if data_format is None: + data_format = backend.image_data_format() + if dtype is None: + dtype = backend.floatx() + self.dtype = dtype + if isinstance(x, tuple) or isinstance(x, list): + if not isinstance(x[1], list): + x_misc = [np.asarray(x[1])] + else: + x_misc = [np.asarray(xx) for xx in x[1]] + x = x[0] + for xx in x_misc: + if len(x) != len(xx): + raise ValueError( + "All of the arrays in `x` " + "should have the same length. " + "Found a pair with: " + f"len(x[0]) = {len(x)}, len(x[?]) = {len(xx)}" + ) + else: + x_misc = [] + + if y is not None and len(x) != len(y): + raise ValueError( + "`x` (images tensor) and `y` (labels) " + "should have the same length. " + f"Found: x.shape = {np.asarray(x).shape}, " + f"y.shape = {np.asarray(y).shape}" + ) + if sample_weight is not None and len(x) != len(sample_weight): + raise ValueError( + "`x` (images tensor) and `sample_weight` " + "should have the same length. " + f"Found: x.shape = {np.asarray(x).shape}, " + f"sample_weight.shape = {np.asarray(sample_weight).shape}" + ) + if subset is not None: + if subset not in {"training", "validation"}: + raise ValueError( + f"Invalid subset name: {subset}" + '; expected "training" or "validation".' + ) + split_idx = int(len(x) * image_data_generator._validation_split) + + if ( + y is not None + and not ignore_class_split + and not np.array_equal( + np.unique(y[:split_idx]), np.unique(y[split_idx:]) + ) + ): + raise ValueError( + "Training and validation subsets " + "have different number of classes after " + "the split. If your numpy arrays are " + "sorted by the label, you might want " + "to shuffle them." + ) + + if subset == "validation": + x = x[:split_idx] + x_misc = [np.asarray(xx[:split_idx]) for xx in x_misc] + if y is not None: + y = y[:split_idx] + else: + x = x[split_idx:] + x_misc = [np.asarray(xx[split_idx:]) for xx in x_misc] + if y is not None: + y = y[split_idx:] + + self.x = np.asarray(x, dtype=self.dtype) + self.x_misc = x_misc + if self.x.ndim != 4: + raise ValueError( + "Input data in `NumpyArrayIterator` " + "should have rank 4. You passed an array " + f"with shape {self.x.shape}" + ) + channels_axis = 3 if data_format == "channels_last" else 1 + if self.x.shape[channels_axis] not in {1, 3, 4}: + warnings.warn( + f"NumpyArrayIterator is set to use the data format convention" + f' "{data_format}" (channels on axis {channels_axis})' + ", i.e. expected either 1, 3, or 4 channels " + f"on axis {channels_axis}. " + f"However, it was passed an array with shape {self.x.shape}" + f" ({self.x.shape[channels_axis]} channels)." + ) + if y is not None: + self.y = np.asarray(y) + else: + self.y = None + if sample_weight is not None: + self.sample_weight = np.asarray(sample_weight) + else: + self.sample_weight = None + self.image_data_generator = image_data_generator + self.data_format = data_format + self.save_to_dir = save_to_dir + self.save_prefix = save_prefix + self.save_format = save_format + super().__init__(x.shape[0], batch_size, shuffle, seed) + + def _get_batches_of_transformed_samples(self, index_array): + batch_x = np.zeros( + tuple([len(index_array)] + list(self.x.shape)[1:]), dtype=self.dtype + ) + for i, j in enumerate(index_array): + x = self.x[j] + params = self.image_data_generator.get_random_transform(x.shape) + x = self.image_data_generator.apply_transform( + x.astype(self.dtype), params + ) + x = self.image_data_generator.standardize(x) + batch_x[i] = x + + if self.save_to_dir: + for i, j in enumerate(index_array): + img = image_utils.array_to_img( + batch_x[i], self.data_format, scale=True + ) + fname = "{prefix}_{index}_{hash}.{format}".format( + prefix=self.save_prefix, + index=j, + hash=np.random.randint(1e4), + format=self.save_format, + ) + img.save(os.path.join(self.save_to_dir, fname)) + batch_x_miscs = [xx[index_array] for xx in self.x_misc] + output = (batch_x if not batch_x_miscs else [batch_x] + batch_x_miscs,) + if self.y is None: + return output[0] + output += (self.y[index_array],) + if self.sample_weight is not None: + output += (self.sample_weight[index_array],) + return output + + +def validate_filename(filename, white_list_formats): + """Check if a filename refers to a valid file. + + Args: + filename: String, absolute path to a file + white_list_formats: Set, allowed file extensions + Returns: + A boolean value indicating if the filename is valid or not + """ + return filename.lower().endswith(white_list_formats) and os.path.isfile( + filename + ) + + +class DataFrameIterator(BatchFromFilesMixin, Iterator): + """Iterator capable of reading images from a directory as a dataframe.""" + + allowed_class_modes = { + "binary", + "categorical", + "input", + "multi_output", + "raw", + "sparse", + None, + } + + def __init__( + self, + dataframe, + directory=None, + image_data_generator=None, + x_col="filename", + y_col="class", + weight_col=None, + target_size=(256, 256), + color_mode="rgb", + classes=None, + class_mode="categorical", + batch_size=32, + shuffle=True, + seed=None, + data_format="channels_last", + save_to_dir=None, + save_prefix="", + save_format="png", + subset=None, + interpolation="nearest", + keep_aspect_ratio=False, + dtype="float32", + validate_filenames=True, + ): + super().set_processing_attrs( + image_data_generator, + target_size, + color_mode, + data_format, + save_to_dir, + save_prefix, + save_format, + subset, + interpolation, + keep_aspect_ratio, + ) + df = dataframe.copy() + self.directory = directory or "" + self.class_mode = class_mode + self.dtype = dtype + # check that inputs match the required class_mode + self._check_params(df, x_col, y_col, weight_col, classes) + if ( + validate_filenames + ): # check which image files are valid and keep them + df = self._filter_valid_filepaths(df, x_col) + if class_mode not in ["input", "multi_output", "raw", None]: + df, classes = self._filter_classes(df, y_col, classes) + num_classes = len(classes) + # build an index of all the unique classes + self.class_indices = dict(zip(classes, range(len(classes)))) + # retrieve only training or validation set + if self.split: + num_files = len(df) + start = int(self.split[0] * num_files) + stop = int(self.split[1] * num_files) + df = df.iloc[start:stop, :] + # get labels for each observation + if class_mode not in ["input", "multi_output", "raw", None]: + self.classes = self.get_classes(df, y_col) + self.filenames = df[x_col].tolist() + self._sample_weight = df[weight_col].values if weight_col else None + + if class_mode == "multi_output": + self._targets = [np.array(df[col].tolist()) for col in y_col] + if class_mode == "raw": + self._targets = df[y_col].values + self.samples = len(self.filenames) + validated_string = ( + "validated" if validate_filenames else "non-validated" + ) + if class_mode in ["input", "multi_output", "raw", None]: + io_utils.print_msg( + f"Found {self.samples} {validated_string} image filenames." + ) + else: + io_utils.print_msg( + f"Found {self.samples} {validated_string} image filenames " + f"belonging to {num_classes} classes." + ) + self._filepaths = [ + os.path.join(self.directory, fname) for fname in self.filenames + ] + super().__init__(self.samples, batch_size, shuffle, seed) + + def _check_params(self, df, x_col, y_col, weight_col, classes): + # check class mode is one of the currently supported + if self.class_mode not in self.allowed_class_modes: + raise ValueError( + "Invalid class_mode: {}; expected one of: {}".format( + self.class_mode, self.allowed_class_modes + ) + ) + # check that y_col has several column names if class_mode is + # multi_output + if (self.class_mode == "multi_output") and not isinstance(y_col, list): + raise TypeError( + 'If class_mode="{}", y_col must be a list. Received {}.'.format( + self.class_mode, type(y_col).__name__ + ) + ) + # check that filenames/filepaths column values are all strings + if not all(df[x_col].apply(lambda x: isinstance(x, str))): + raise TypeError( + f"All values in column x_col={x_col} must be strings." + ) + # check labels are string if class_mode is binary or sparse + if self.class_mode in {"binary", "sparse"}: + if not all(df[y_col].apply(lambda x: isinstance(x, str))): + raise TypeError( + 'If class_mode="{}", y_col="{}" column ' + "values must be strings.".format(self.class_mode, y_col) + ) + # check that if binary there are only 2 different classes + if self.class_mode == "binary": + if classes: + classes = set(classes) + if len(classes) != 2: + raise ValueError( + 'If class_mode="binary" there must be 2 ' + "classes. {} class/es were given.".format(len(classes)) + ) + elif df[y_col].nunique() != 2: + raise ValueError( + 'If class_mode="binary" there must be 2 classes. ' + "Found {} classes.".format(df[y_col].nunique()) + ) + # check values are string, list or tuple if class_mode is categorical + if self.class_mode == "categorical": + types = (str, list, tuple) + if not all(df[y_col].apply(lambda x: isinstance(x, types))): + raise TypeError( + 'If class_mode="{}", y_col="{}" column ' + "values must be type string, list or tuple.".format( + self.class_mode, y_col + ) + ) + # raise warning if classes are given but will be unused + if classes and self.class_mode in { + "input", + "multi_output", + "raw", + None, + }: + warnings.warn( + '`classes` will be ignored given the class_mode="{}"'.format( + self.class_mode + ) + ) + # check that if weight column that the values are numerical + if weight_col and not issubclass(df[weight_col].dtype.type, np.number): + raise TypeError(f"Column weight_col={weight_col} must be numeric.") + + def get_classes(self, df, y_col): + labels = [] + for label in df[y_col]: + if isinstance(label, (list, tuple)): + labels.append([self.class_indices[lbl] for lbl in label]) + else: + labels.append(self.class_indices[label]) + return labels + + @staticmethod + def _filter_classes(df, y_col, classes): + df = df.copy() + + def remove_classes(labels, classes): + if isinstance(labels, (list, tuple)): + labels = [cls for cls in labels if cls in classes] + return labels or None + elif isinstance(labels, str): + return labels if labels in classes else None + else: + raise TypeError( + "Expect string, list or tuple " + "but found {} in {} column ".format(type(labels), y_col) + ) + + if classes: + # prepare for membership lookup + classes = list(collections.OrderedDict.fromkeys(classes).keys()) + df[y_col] = df[y_col].apply(lambda x: remove_classes(x, classes)) + else: + classes = set() + for v in df[y_col]: + if isinstance(v, (list, tuple)): + classes.update(v) + else: + classes.add(v) + classes = sorted(classes) + return df.dropna(subset=[y_col]), classes + + def _filter_valid_filepaths(self, df, x_col): + """Keep only dataframe rows with valid filenames. + + Args: + df: Pandas dataframe containing filenames in a column + x_col: string, column in `df` that contains the filenames or + filepaths + Returns: + absolute paths to image files + """ + filepaths = df[x_col].map( + lambda fname: os.path.join(self.directory, fname) + ) + mask = filepaths.apply( + validate_filename, args=(self.white_list_formats,) + ) + n_invalid = (~mask).sum() + if n_invalid: + warnings.warn( + 'Found {} invalid image filename(s) in x_col="{}". ' + "These filename(s) will be ignored.".format(n_invalid, x_col) + ) + return df[mask] + + @property + def filepaths(self): + return self._filepaths + + @property + def labels(self): + if self.class_mode in {"multi_output", "raw"}: + return self._targets + else: + return self.classes + + @property + def sample_weight(self): + return self._sample_weight + + +def flip_axis(x, axis): + x = np.asarray(x).swapaxes(axis, 0) + x = x[::-1, ...] + x = x.swapaxes(0, axis) + return x + + +@keras_export("keras._legacy.preprocessing.image.ImageDataGenerator") +class ImageDataGenerator: + """DEPRECATED.""" + + def __init__( + self, + featurewise_center=False, + samplewise_center=False, + featurewise_std_normalization=False, + samplewise_std_normalization=False, + zca_whitening=False, + zca_epsilon=1e-6, + rotation_range=0, + width_shift_range=0.0, + height_shift_range=0.0, + brightness_range=None, + shear_range=0.0, + zoom_range=0.0, + channel_shift_range=0.0, + fill_mode="nearest", + cval=0.0, + horizontal_flip=False, + vertical_flip=False, + rescale=None, + preprocessing_function=None, + data_format=None, + validation_split=0.0, + interpolation_order=1, + dtype=None, + ): + if data_format is None: + data_format = backend.image_data_format() + if dtype is None: + dtype = backend.floatx() + + self.featurewise_center = featurewise_center + self.samplewise_center = samplewise_center + self.featurewise_std_normalization = featurewise_std_normalization + self.samplewise_std_normalization = samplewise_std_normalization + self.zca_whitening = zca_whitening + self.zca_epsilon = zca_epsilon + self.rotation_range = rotation_range + self.width_shift_range = width_shift_range + self.height_shift_range = height_shift_range + self.shear_range = shear_range + self.zoom_range = zoom_range + self.channel_shift_range = channel_shift_range + self.fill_mode = fill_mode + self.cval = cval + self.horizontal_flip = horizontal_flip + self.vertical_flip = vertical_flip + self.rescale = rescale + self.preprocessing_function = preprocessing_function + self.dtype = dtype + self.interpolation_order = interpolation_order + + if data_format not in {"channels_last", "channels_first"}: + raise ValueError( + '`data_format` should be `"channels_last"` ' + "(channel after row and column) or " + '`"channels_first"` (channel before row and column). ' + f"Received: {data_format}" + ) + self.data_format = data_format + if data_format == "channels_first": + self.channel_axis = 1 + self.row_axis = 2 + self.col_axis = 3 + if data_format == "channels_last": + self.channel_axis = 3 + self.row_axis = 1 + self.col_axis = 2 + if validation_split and not 0 < validation_split < 1: + raise ValueError( + "`validation_split` must be strictly between 0 and 1. " + f" Received: {validation_split}" + ) + self._validation_split = validation_split + + self.mean = None + self.std = None + self.zca_whitening_matrix = None + + if isinstance(zoom_range, (float, int)): + self.zoom_range = [1 - zoom_range, 1 + zoom_range] + elif len(zoom_range) == 2 and all( + isinstance(val, (float, int)) for val in zoom_range + ): + self.zoom_range = [zoom_range[0], zoom_range[1]] + else: + raise ValueError( + "`zoom_range` should be a float or " + "a tuple or list of two floats. " + f"Received: {zoom_range}" + ) + if zca_whitening: + if not featurewise_center: + self.featurewise_center = True + warnings.warn( + "This ImageDataGenerator specifies " + "`zca_whitening`, which overrides " + "setting of `featurewise_center`." + ) + if featurewise_std_normalization: + self.featurewise_std_normalization = False + warnings.warn( + "This ImageDataGenerator specifies " + "`zca_whitening` " + "which overrides setting of" + "`featurewise_std_normalization`." + ) + if featurewise_std_normalization: + if not featurewise_center: + self.featurewise_center = True + warnings.warn( + "This ImageDataGenerator specifies " + "`featurewise_std_normalization`, " + "which overrides setting of " + "`featurewise_center`." + ) + if samplewise_std_normalization: + if not samplewise_center: + self.samplewise_center = True + warnings.warn( + "This ImageDataGenerator specifies " + "`samplewise_std_normalization`, " + "which overrides setting of " + "`samplewise_center`." + ) + if brightness_range is not None: + if ( + not isinstance(brightness_range, (tuple, list)) + or len(brightness_range) != 2 + ): + raise ValueError( + "`brightness_range should be tuple or list of two floats. " + f"Received: {brightness_range}" + ) + self.brightness_range = brightness_range + + def flow( + self, + x, + y=None, + batch_size=32, + shuffle=True, + sample_weight=None, + seed=None, + save_to_dir=None, + save_prefix="", + save_format="png", + ignore_class_split=False, + subset=None, + ): + return NumpyArrayIterator( + x, + y, + self, + batch_size=batch_size, + shuffle=shuffle, + sample_weight=sample_weight, + seed=seed, + data_format=self.data_format, + save_to_dir=save_to_dir, + save_prefix=save_prefix, + save_format=save_format, + ignore_class_split=ignore_class_split, + subset=subset, + dtype=self.dtype, + ) + + def flow_from_directory( + self, + directory, + target_size=(256, 256), + color_mode="rgb", + classes=None, + class_mode="categorical", + batch_size=32, + shuffle=True, + seed=None, + save_to_dir=None, + save_prefix="", + save_format="png", + follow_links=False, + subset=None, + interpolation="nearest", + keep_aspect_ratio=False, + ): + return DirectoryIterator( + directory, + self, + target_size=target_size, + color_mode=color_mode, + keep_aspect_ratio=keep_aspect_ratio, + classes=classes, + class_mode=class_mode, + data_format=self.data_format, + batch_size=batch_size, + shuffle=shuffle, + seed=seed, + save_to_dir=save_to_dir, + save_prefix=save_prefix, + save_format=save_format, + follow_links=follow_links, + subset=subset, + interpolation=interpolation, + dtype=self.dtype, + ) + + def flow_from_dataframe( + self, + dataframe, + directory=None, + x_col="filename", + y_col="class", + weight_col=None, + target_size=(256, 256), + color_mode="rgb", + classes=None, + class_mode="categorical", + batch_size=32, + shuffle=True, + seed=None, + save_to_dir=None, + save_prefix="", + save_format="png", + subset=None, + interpolation="nearest", + validate_filenames=True, + **kwargs, + ): + if "has_ext" in kwargs: + warnings.warn( + "has_ext is deprecated, filenames in the dataframe have " + "to match the exact filenames in disk.", + DeprecationWarning, + ) + if "sort" in kwargs: + warnings.warn( + "sort is deprecated, batches will be created in the" + "same order than the filenames provided if `shuffle`" + "is set to `False`.", + DeprecationWarning, + ) + if class_mode == "other": + warnings.warn( + '`class_mode="other"` is deprecated, please use ' + '`class_mode="raw"`.', + DeprecationWarning, + ) + class_mode = "raw" + if "drop_duplicates" in kwargs: + warnings.warn( + "drop_duplicates is deprecated, you can drop duplicates " + "by using the pandas.DataFrame.drop_duplicates method.", + DeprecationWarning, + ) + + return DataFrameIterator( + dataframe, + directory, + self, + x_col=x_col, + y_col=y_col, + weight_col=weight_col, + target_size=target_size, + color_mode=color_mode, + classes=classes, + class_mode=class_mode, + data_format=self.data_format, + batch_size=batch_size, + shuffle=shuffle, + seed=seed, + save_to_dir=save_to_dir, + save_prefix=save_prefix, + save_format=save_format, + subset=subset, + interpolation=interpolation, + validate_filenames=validate_filenames, + dtype=self.dtype, + ) + + def standardize(self, x): + """Applies the normalization configuration in-place to a batch of + inputs. + + `x` is changed in-place since the function is mainly used internally + to standardize images and feed them to your network. If a copy of `x` + would be created instead it would have a significant performance cost. + If you want to apply this method without changing the input in-place + you can call the method creating a copy before: + + standardize(np.copy(x)) + + Args: + x: Batch of inputs to be normalized. + + Returns: + The inputs, normalized. + """ + if self.preprocessing_function: + x = self.preprocessing_function(x) + if self.rescale: + x *= self.rescale + if self.samplewise_center: + x -= np.mean(x, keepdims=True) + if self.samplewise_std_normalization: + x /= np.std(x, keepdims=True) + 1e-6 + + if self.featurewise_center: + if self.mean is not None: + x -= self.mean + else: + warnings.warn( + "This ImageDataGenerator specifies " + "`featurewise_center`, but it hasn't " + "been fit on any training data. Fit it " + "first by calling `.fit(numpy_data)`." + ) + if self.featurewise_std_normalization: + if self.std is not None: + x /= self.std + 1e-6 + else: + warnings.warn( + "This ImageDataGenerator specifies " + "`featurewise_std_normalization`, " + "but it hasn't " + "been fit on any training data. Fit it " + "first by calling `.fit(numpy_data)`." + ) + if self.zca_whitening: + if self.zca_whitening_matrix is not None: + flat_x = x.reshape(-1, np.prod(x.shape[-3:])) + white_x = flat_x @ self.zca_whitening_matrix + x = np.reshape(white_x, x.shape) + else: + warnings.warn( + "This ImageDataGenerator specifies " + "`zca_whitening`, but it hasn't " + "been fit on any training data. Fit it " + "first by calling `.fit(numpy_data)`." + ) + return x + + def get_random_transform(self, img_shape, seed=None): + """Generates random parameters for a transformation. + + Args: + img_shape: Tuple of integers. + Shape of the image that is transformed. + seed: Random seed. + + Returns: + A dictionary containing randomly chosen parameters describing the + transformation. + """ + img_row_axis = self.row_axis - 1 + img_col_axis = self.col_axis - 1 + + if seed is not None: + np.random.seed(seed) + + if self.rotation_range: + theta = np.random.uniform(-self.rotation_range, self.rotation_range) + else: + theta = 0 + + if self.height_shift_range: + try: # 1-D array-like or int + tx = np.random.choice(self.height_shift_range) + tx *= np.random.choice([-1, 1]) + except ValueError: # floating point + tx = np.random.uniform( + -self.height_shift_range, self.height_shift_range + ) + if np.max(self.height_shift_range) < 1: + tx *= img_shape[img_row_axis] + else: + tx = 0 + + if self.width_shift_range: + try: # 1-D array-like or int + ty = np.random.choice(self.width_shift_range) + ty *= np.random.choice([-1, 1]) + except ValueError: # floating point + ty = np.random.uniform( + -self.width_shift_range, self.width_shift_range + ) + if np.max(self.width_shift_range) < 1: + ty *= img_shape[img_col_axis] + else: + ty = 0 + + if self.shear_range: + shear = np.random.uniform(-self.shear_range, self.shear_range) + else: + shear = 0 + + if self.zoom_range[0] == 1 and self.zoom_range[1] == 1: + zx, zy = 1, 1 + else: + zx, zy = np.random.uniform( + self.zoom_range[0], self.zoom_range[1], 2 + ) + + flip_horizontal = (np.random.random() < 0.5) * self.horizontal_flip + flip_vertical = (np.random.random() < 0.5) * self.vertical_flip + + channel_shift_intensity = None + if self.channel_shift_range != 0: + channel_shift_intensity = np.random.uniform( + -self.channel_shift_range, self.channel_shift_range + ) + + brightness = None + if self.brightness_range is not None: + brightness = np.random.uniform( + self.brightness_range[0], self.brightness_range[1] + ) + + transform_parameters = { + "theta": theta, + "tx": tx, + "ty": ty, + "shear": shear, + "zx": zx, + "zy": zy, + "flip_horizontal": flip_horizontal, + "flip_vertical": flip_vertical, + "channel_shift_intensity": channel_shift_intensity, + "brightness": brightness, + } + + return transform_parameters + + def apply_transform(self, x, transform_parameters): + """Applies a transformation to an image according to given parameters. + + Args: + x: 3D tensor, single image. + transform_parameters: Dictionary with string - parameter pairs + describing the transformation. + Currently, the following parameters + from the dictionary are used: + - `'theta'`: Float. Rotation angle in degrees. + - `'tx'`: Float. Shift in the x direction. + - `'ty'`: Float. Shift in the y direction. + - `'shear'`: Float. Shear angle in degrees. + - `'zx'`: Float. Zoom in the x direction. + - `'zy'`: Float. Zoom in the y direction. + - `'flip_horizontal'`: Boolean. Horizontal flip. + - `'flip_vertical'`: Boolean. Vertical flip. + - `'channel_shift_intensity'`: Float. Channel shift intensity. + - `'brightness'`: Float. Brightness shift intensity. + + Returns: + A transformed version of the input (same shape). + """ + # x is a single image, so it doesn't have image number at index 0 + img_row_axis = self.row_axis - 1 + img_col_axis = self.col_axis - 1 + img_channel_axis = self.channel_axis - 1 + + x = apply_affine_transform( + x, + transform_parameters.get("theta", 0), + transform_parameters.get("tx", 0), + transform_parameters.get("ty", 0), + transform_parameters.get("shear", 0), + transform_parameters.get("zx", 1), + transform_parameters.get("zy", 1), + row_axis=img_row_axis, + col_axis=img_col_axis, + channel_axis=img_channel_axis, + fill_mode=self.fill_mode, + cval=self.cval, + order=self.interpolation_order, + ) + + if transform_parameters.get("channel_shift_intensity") is not None: + x = apply_channel_shift( + x, + transform_parameters["channel_shift_intensity"], + img_channel_axis, + ) + + if transform_parameters.get("flip_horizontal", False): + x = flip_axis(x, img_col_axis) + + if transform_parameters.get("flip_vertical", False): + x = flip_axis(x, img_row_axis) + + if transform_parameters.get("brightness") is not None: + x = apply_brightness_shift( + x, transform_parameters["brightness"], False + ) + + return x + + def random_transform(self, x, seed=None): + """Applies a random transformation to an image. + + Args: + x: 3D tensor, single image. + seed: Random seed. + + Returns: + A randomly transformed version of the input (same shape). + """ + params = self.get_random_transform(x.shape, seed) + return self.apply_transform(x, params) + + def fit(self, x, augment=False, rounds=1, seed=None): + """Fits the data generator to some sample data. + + This computes the internal data stats related to the + data-dependent transformations, based on an array of sample data. + + Only required if `featurewise_center` or + `featurewise_std_normalization` or `zca_whitening` + are set to `True`. + + When `rescale` is set to a value, rescaling is applied to + sample data before computing the internal data stats. + + Args: + x: Sample data. Should have rank 4. + In case of grayscale data, + the channels axis should have value 1, in case + of RGB data, it should have value 3, and in case + of RGBA data, it should have value 4. + augment: Boolean (default: False). + Whether to fit on randomly augmented samples. + rounds: Int (default: 1). + If using data augmentation (`augment=True`), + this is how many augmentation passes over the data to use. + seed: Int (default: None). Random seed. + """ + x = np.asarray(x, dtype=self.dtype) + if x.ndim != 4: + raise ValueError( + "Input to `.fit()` should have rank 4. Got array with shape: " + + str(x.shape) + ) + if x.shape[self.channel_axis] not in {1, 3, 4}: + warnings.warn( + "Expected input to be images (as Numpy array) " + f'following the data format convention "{self.data_format}' + f'" (channels on axis {self.channel_axis})' + ", i.e. expected either 1, 3 or 4 channels on axis " + f"{self.channel_axis}. However, it was passed an array with" + f" shape {x.shape} ({x.shape[self.channel_axis]} channels)." + ) + + if seed is not None: + np.random.seed(seed) + + x = np.copy(x) + if self.rescale: + x *= self.rescale + + if augment: + ax = np.zeros( + tuple([rounds * x.shape[0]] + list(x.shape)[1:]), + dtype=self.dtype, + ) + for r in range(rounds): + for i in range(x.shape[0]): + ax[i + r * x.shape[0]] = self.random_transform(x[i]) + x = ax + + if self.featurewise_center: + self.mean = np.mean(x, axis=(0, self.row_axis, self.col_axis)) + broadcast_shape = [1, 1, 1] + broadcast_shape[self.channel_axis - 1] = x.shape[self.channel_axis] + self.mean = np.reshape(self.mean, broadcast_shape) + x -= self.mean + + if self.featurewise_std_normalization: + self.std = np.std(x, axis=(0, self.row_axis, self.col_axis)) + broadcast_shape = [1, 1, 1] + broadcast_shape[self.channel_axis - 1] = x.shape[self.channel_axis] + self.std = np.reshape(self.std, broadcast_shape) + x /= self.std + 1e-6 + + if self.zca_whitening: + n = len(x) + flat_x = np.reshape(x, (n, -1)) + + u, s, _ = np.linalg.svd(flat_x.T, full_matrices=False) + s_inv = np.sqrt(n) / (s + self.zca_epsilon) + self.zca_whitening_matrix = (u * s_inv).dot(u.T) + + +@keras_export("keras._legacy.preprocessing.image.random_rotation") +def random_rotation( + x, + rg, + row_axis=1, + col_axis=2, + channel_axis=0, + fill_mode="nearest", + cval=0.0, + interpolation_order=1, +): + """DEPRECATED.""" + theta = np.random.uniform(-rg, rg) + x = apply_affine_transform( + x, + theta=theta, + row_axis=row_axis, + col_axis=col_axis, + channel_axis=channel_axis, + fill_mode=fill_mode, + cval=cval, + order=interpolation_order, + ) + return x + + +@keras_export("keras._legacy.preprocessing.image.random_shift") +def random_shift( + x, + wrg, + hrg, + row_axis=1, + col_axis=2, + channel_axis=0, + fill_mode="nearest", + cval=0.0, + interpolation_order=1, +): + """DEPRECATED.""" + h, w = x.shape[row_axis], x.shape[col_axis] + tx = np.random.uniform(-hrg, hrg) * h + ty = np.random.uniform(-wrg, wrg) * w + x = apply_affine_transform( + x, + tx=tx, + ty=ty, + row_axis=row_axis, + col_axis=col_axis, + channel_axis=channel_axis, + fill_mode=fill_mode, + cval=cval, + order=interpolation_order, + ) + return x + + +@keras_export("keras._legacy.preprocessing.image.random_shear") +def random_shear( + x, + intensity, + row_axis=1, + col_axis=2, + channel_axis=0, + fill_mode="nearest", + cval=0.0, + interpolation_order=1, +): + """DEPRECATED.""" + shear = np.random.uniform(-intensity, intensity) + x = apply_affine_transform( + x, + shear=shear, + row_axis=row_axis, + col_axis=col_axis, + channel_axis=channel_axis, + fill_mode=fill_mode, + cval=cval, + order=interpolation_order, + ) + return x + + +@keras_export("keras._legacy.preprocessing.image.random_zoom") +def random_zoom( + x, + zoom_range, + row_axis=1, + col_axis=2, + channel_axis=0, + fill_mode="nearest", + cval=0.0, + interpolation_order=1, +): + """DEPRECATED.""" + if len(zoom_range) != 2: + raise ValueError( + "`zoom_range` should be a tuple or list of two floats. " + f"Received: {zoom_range}" + ) + + if zoom_range[0] == 1 and zoom_range[1] == 1: + zx, zy = 1, 1 + else: + zx, zy = np.random.uniform(zoom_range[0], zoom_range[1], 2) + x = apply_affine_transform( + x, + zx=zx, + zy=zy, + row_axis=row_axis, + col_axis=col_axis, + channel_axis=channel_axis, + fill_mode=fill_mode, + cval=cval, + order=interpolation_order, + ) + return x + + +@keras_export("keras._legacy.preprocessing.image.apply_channel_shift") +def apply_channel_shift(x, intensity, channel_axis=0): + """Performs a channel shift. + + DEPRECATED. + + Args: + x: Input tensor. Must be 3D. + intensity: Transformation intensity. + channel_axis: Index of axis for channels in the input tensor. + + Returns: + Numpy image tensor. + """ + x = np.rollaxis(x, channel_axis, 0) + min_x, max_x = np.min(x), np.max(x) + channel_images = [ + np.clip(x_channel + intensity, min_x, max_x) for x_channel in x + ] + x = np.stack(channel_images, axis=0) + x = np.rollaxis(x, 0, channel_axis + 1) + return x + + +@keras_export("keras._legacy.preprocessing.image.random_channel_shift") +def random_channel_shift(x, intensity_range, channel_axis=0): + """Performs a random channel shift. + + DEPRECATED. + + Args: + x: Input tensor. Must be 3D. + intensity_range: Transformation intensity. + channel_axis: Index of axis for channels in the input tensor. + + Returns: + Numpy image tensor. + """ + intensity = np.random.uniform(-intensity_range, intensity_range) + return apply_channel_shift(x, intensity, channel_axis=channel_axis) + + +@keras_export("keras._legacy.preprocessing.image.apply_brightness_shift") +def apply_brightness_shift(x, brightness, scale=True): + """Performs a brightness shift. + + DEPRECATED. + + Args: + x: Input tensor. Must be 3D. + brightness: Float. The new brightness value. + scale: Whether to rescale the image such that minimum and maximum values + are 0 and 255 respectively. Default: True. + + Returns: + Numpy image tensor. + + Raises: + ImportError: if PIL is not available. + """ + from PIL import ImageEnhance + + x_min, x_max = np.min(x), np.max(x) + local_scale = (x_min < 0) or (x_max > 255) + x = image_utils.array_to_img(x, scale=local_scale or scale) + x = imgenhancer_Brightness = ImageEnhance.Brightness(x) + x = imgenhancer_Brightness.enhance(brightness) + x = image_utils.img_to_array(x) + if not scale and local_scale: + x = x / 255 * (x_max - x_min) + x_min + return x + + +@keras_export("keras._legacy.preprocessing.image.random_brightness") +def random_brightness(x, brightness_range, scale=True): + """Performs a random brightness shift. + + DEPRECATED. + + Args: + x: Input tensor. Must be 3D. + brightness_range: Tuple of floats; brightness range. + scale: Whether to rescale the image such that minimum and maximum values + are 0 and 255 respectively. Default: True. + + Returns: + Numpy image tensor. + + Raises: + ValueError if `brightness_range` isn't a tuple. + """ + if len(brightness_range) != 2: + raise ValueError( + "`brightness_range should be tuple or list of two floats. " + f"Received: {brightness_range}" + ) + + u = np.random.uniform(brightness_range[0], brightness_range[1]) + return apply_brightness_shift(x, u, scale) + + +def transform_matrix_offset_center(matrix, x, y): + o_x = float(x) / 2 - 0.5 + o_y = float(y) / 2 - 0.5 + offset_matrix = np.array([[1, 0, o_x], [0, 1, o_y], [0, 0, 1]]) + reset_matrix = np.array([[1, 0, -o_x], [0, 1, -o_y], [0, 0, 1]]) + transform_matrix = np.dot(np.dot(offset_matrix, matrix), reset_matrix) + return transform_matrix + + +@keras_export("keras._legacy.preprocessing.image.apply_affine_transform") +def apply_affine_transform( + x, + theta=0, + tx=0, + ty=0, + shear=0, + zx=1, + zy=1, + row_axis=1, + col_axis=2, + channel_axis=0, + fill_mode="nearest", + cval=0.0, + order=1, +): + """Applies an affine transformation specified by the parameters given. + + DEPRECATED. + """ + # Input sanity checks: + # 1. x must 2D image with one or more channels (i.e., a 3D tensor) + # 2. channels must be either first or last dimension + if np.unique([row_axis, col_axis, channel_axis]).size != 3: + raise ValueError( + "'row_axis', 'col_axis', and 'channel_axis' must be distinct" + ) + + # shall we support negative indices? + valid_indices = set([0, 1, 2]) + actual_indices = set([row_axis, col_axis, channel_axis]) + if actual_indices != valid_indices: + raise ValueError( + f"Invalid axis' indices: {actual_indices - valid_indices}" + ) + + if x.ndim != 3: + raise ValueError("Input arrays must be multi-channel 2D images.") + if channel_axis not in [0, 2]: + raise ValueError( + "Channels are allowed and the first and last dimensions." + ) + + transform_matrix = None + if theta != 0: + theta = np.deg2rad(theta) + rotation_matrix = np.array( + [ + [np.cos(theta), -np.sin(theta), 0], + [np.sin(theta), np.cos(theta), 0], + [0, 0, 1], + ] + ) + transform_matrix = rotation_matrix + + if tx != 0 or ty != 0: + shift_matrix = np.array([[1, 0, tx], [0, 1, ty], [0, 0, 1]]) + if transform_matrix is None: + transform_matrix = shift_matrix + else: + transform_matrix = np.dot(transform_matrix, shift_matrix) + + if shear != 0: + shear = np.deg2rad(shear) + shear_matrix = np.array( + [[1, -np.sin(shear), 0], [0, np.cos(shear), 0], [0, 0, 1]] + ) + if transform_matrix is None: + transform_matrix = shear_matrix + else: + transform_matrix = np.dot(transform_matrix, shear_matrix) + + if zx != 1 or zy != 1: + zoom_matrix = np.array([[zx, 0, 0], [0, zy, 0], [0, 0, 1]]) + if transform_matrix is None: + transform_matrix = zoom_matrix + else: + transform_matrix = np.dot(transform_matrix, zoom_matrix) + + if transform_matrix is not None: + h, w = x.shape[row_axis], x.shape[col_axis] + transform_matrix = transform_matrix_offset_center( + transform_matrix, h, w + ) + x = np.rollaxis(x, channel_axis, 0) + + # Matrix construction assumes that coordinates are x, y (in that order). + # However, regular numpy arrays use y,x (aka i,j) indexing. + # Possible solution is: + # 1. Swap the x and y axes. + # 2. Apply transform. + # 3. Swap the x and y axes again to restore image-like data ordering. + # Mathematically, it is equivalent to the following transformation: + # M' = PMP, where P is the permutation matrix, M is the original + # transformation matrix. + if col_axis > row_axis: + transform_matrix[:, [0, 1]] = transform_matrix[:, [1, 0]] + transform_matrix[[0, 1]] = transform_matrix[[1, 0]] + final_affine_matrix = transform_matrix[:2, :2] + final_offset = transform_matrix[:2, 2] + + channel_images = [ + scipy.ndimage.interpolation.affine_transform( + x_channel, + final_affine_matrix, + final_offset, + order=order, + mode=fill_mode, + cval=cval, + ) + for x_channel in x + ] + x = np.stack(channel_images, axis=0) + x = np.rollaxis(x, 0, channel_axis + 1) + return x diff --git a/keras/src/legacy/preprocessing/sequence.py b/keras/src/legacy/preprocessing/sequence.py new file mode 100644 index 000000000000..18e21d944262 --- /dev/null +++ b/keras/src/legacy/preprocessing/sequence.py @@ -0,0 +1,328 @@ +"""Deprecated sequence preprocessing APIs from Keras 1.""" + +import json +import random + +import numpy as np + +from keras.src.api_export import keras_export +from keras.src.trainers.data_adapters.py_dataset_adapter import PyDataset + + +@keras_export("keras._legacy.preprocessing.sequence.TimeseriesGenerator") +class TimeseriesGenerator(PyDataset): + """Utility class for generating batches of temporal data. + + DEPRECATED. + + This class takes in a sequence of data-points gathered at + equal intervals, along with time series parameters such as + stride, length of history, etc., to produce batches for + training/validation. + + Arguments: + data: Indexable generator (such as list or Numpy array) + containing consecutive data points (timesteps). + The data should be at 2D, and axis 0 is expected + to be the time dimension. + targets: Targets corresponding to timesteps in `data`. + It should have same length as `data`. + length: Length of the output sequences (in number of timesteps). + sampling_rate: Period between successive individual timesteps + within sequences. For rate `r`, timesteps + `data[i]`, `data[i-r]`, ... `data[i - length]` + are used for create a sample sequence. + stride: Period between successive output sequences. + For stride `s`, consecutive output samples would + be centered around `data[i]`, `data[i+s]`, `data[i+2*s]`, etc. + start_index: Data points earlier than `start_index` will not be used + in the output sequences. This is useful to reserve part of the + data for test or validation. + end_index: Data points later than `end_index` will not be used + in the output sequences. This is useful to reserve part of the + data for test or validation. + shuffle: Whether to shuffle output samples, + or instead draw them in chronological order. + reverse: Boolean: if `true`, timesteps in each output sample will be + in reverse chronological order. + batch_size: Number of timeseries samples in each batch + (except maybe the last one). + **kwargs: Additional keyword arguments for the `PyDataset` base class, + such as `workers`, `use_multiprocessing`, and `max_queue_size`. + + Returns: + A PyDataset instance. + """ + + def __init__( + self, + data, + targets, + length, + sampling_rate=1, + stride=1, + start_index=0, + end_index=None, + shuffle=False, + reverse=False, + batch_size=128, + **kwargs, + ): + super().__init__(**kwargs) + if len(data) != len(targets): + raise ValueError( + "Data and targets have to be " + f"of same length. Data length is {len(data)} " + f"while target length is {len(targets)}" + ) + + self.data = data + self.targets = targets + self.length = length + self.sampling_rate = sampling_rate + self.stride = stride + self.start_index = start_index + length + if end_index is None: + end_index = len(data) - 1 + self.end_index = end_index + self.shuffle = shuffle + self.reverse = reverse + self.batch_size = batch_size + + if self.start_index > self.end_index: + raise ValueError( + f"`start_index+length={self.start_index} " + f"> end_index={self.end_index}` " + "is disallowed, as no part of the sequence " + "would be left to be used as current step." + ) + + def __len__(self): + return ( + self.end_index - self.start_index + self.batch_size * self.stride + ) // (self.batch_size * self.stride) + + def __getitem__(self, index): + if self.shuffle: + rows = np.random.randint( + self.start_index, self.end_index + 1, size=self.batch_size + ) + else: + i = self.start_index + self.batch_size * self.stride * index + rows = np.arange( + i, + min(i + self.batch_size * self.stride, self.end_index + 1), + self.stride, + ) + + samples = np.array( + [ + self.data[row - self.length : row : self.sampling_rate] + for row in rows + ] + ) + targets = np.array([self.targets[row] for row in rows]) + + if self.reverse: + return samples[:, ::-1, ...], targets + return samples, targets + + def get_config(self): + """Returns the TimeseriesGenerator configuration as Python dictionary. + + Returns: + A Python dictionary with the TimeseriesGenerator configuration. + """ + data = self.data + if type(self.data).__module__ == np.__name__: + data = self.data.tolist() + try: + json_data = json.dumps(data) + except TypeError as e: + raise TypeError(f"Data not JSON Serializable: {data}") from e + + targets = self.targets + if type(self.targets).__module__ == np.__name__: + targets = self.targets.tolist() + try: + json_targets = json.dumps(targets) + except TypeError as e: + raise TypeError(f"Targets not JSON Serializable: {targets}") from e + + config = super().get_config() + config.update( + { + "data": json_data, + "targets": json_targets, + "length": self.length, + "sampling_rate": self.sampling_rate, + "stride": self.stride, + "start_index": self.start_index, + "end_index": self.end_index, + "shuffle": self.shuffle, + "reverse": self.reverse, + "batch_size": self.batch_size, + } + ) + return config + + def to_json(self, **kwargs): + """Returns a JSON string containing the generator's configuration. + + Args: + **kwargs: Additional keyword arguments to be passed + to `json.dumps()`. + + Returns: + A JSON string containing the tokenizer configuration. + """ + config = self.get_config() + timeseries_generator_config = { + "class_name": self.__class__.__name__, + "config": config, + } + return json.dumps(timeseries_generator_config, **kwargs) + + +@keras_export("keras._legacy.preprocessing.sequence.make_sampling_table") +def make_sampling_table(size, sampling_factor=1e-5): + """Generates a word rank-based probabilistic sampling table. + + DEPRECATED. + + Used for generating the `sampling_table` argument for `skipgrams`. + `sampling_table[i]` is the probability of sampling + the word i-th most common word in a dataset + (more common words should be sampled less frequently, for balance). + + The sampling probabilities are generated according + to the sampling distribution used in word2vec: + + ``` + p(word) = (min(1, sqrt(word_frequency / sampling_factor) / + (word_frequency / sampling_factor))) + ``` + + We assume that the word frequencies follow Zipf's law (s=1) to derive + a numerical approximation of frequency(rank): + + `frequency(rank) ~ 1/(rank * (log(rank) + gamma) + 1/2 - 1/(12*rank))` + where `gamma` is the Euler-Mascheroni constant. + + Args: + size: Int, number of possible words to sample. + sampling_factor: The sampling factor in the word2vec formula. + + Returns: + A 1D Numpy array of length `size` where the ith entry + is the probability that a word of rank i should be sampled. + """ + gamma = 0.577 + rank = np.arange(size) + rank[0] = 1 + inv_fq = rank * (np.log(rank) + gamma) + 0.5 - 1.0 / (12.0 * rank) + f = sampling_factor * inv_fq + + return np.minimum(1.0, f / np.sqrt(f)) + + +@keras_export("keras._legacy.preprocessing.sequence.skipgrams") +def skipgrams( + sequence, + vocabulary_size, + window_size=4, + negative_samples=1.0, + shuffle=True, + categorical=False, + sampling_table=None, + seed=None, +): + """Generates skipgram word pairs. + + DEPRECATED. + + This function transforms a sequence of word indexes (list of integers) + into tuples of words of the form: + + - (word, word in the same window), with label 1 (positive samples). + - (word, random word from the vocabulary), with label 0 (negative samples). + + Read more about Skipgram in this gnomic paper by Mikolov et al.: + [Efficient Estimation of Word Representations in + Vector Space](http://arxiv.org/pdf/1301.3781v3.pdf) + + Args: + sequence: A word sequence (sentence), encoded as a list + of word indices (integers). If using a `sampling_table`, + word indices are expected to match the rank + of the words in a reference dataset (e.g. 10 would encode + the 10-th most frequently occurring token). + Note that index 0 is expected to be a non-word and will be skipped. + vocabulary_size: Int, maximum possible word index + 1 + window_size: Int, size of sampling windows (technically half-window). + The window of a word `w_i` will be + `[i - window_size, i + window_size+1]`. + negative_samples: Float >= 0. 0 for no negative (i.e. random) samples. + 1 for same number as positive samples. + shuffle: Whether to shuffle the word couples before returning them. + categorical: bool. if False, labels will be + integers (eg. `[0, 1, 1 .. ]`), + if `True`, labels will be categorical, e.g. + `[[1,0],[0,1],[0,1] .. ]`. + sampling_table: 1D array of size `vocabulary_size` where the entry i + encodes the probability to sample a word of rank i. + seed: Random seed. + + Returns: + couples, labels: where `couples` are int pairs and + `labels` are either 0 or 1. + + Note: + By convention, index 0 in the vocabulary is + a non-word and will be skipped. + """ + couples = [] + labels = [] + for i, wi in enumerate(sequence): + if not wi: + continue + if sampling_table is not None: + if sampling_table[wi] < random.random(): + continue + + window_start = max(0, i - window_size) + window_end = min(len(sequence), i + window_size + 1) + for j in range(window_start, window_end): + if j != i: + wj = sequence[j] + if not wj: + continue + couples.append([wi, wj]) + if categorical: + labels.append([0, 1]) + else: + labels.append(1) + + if negative_samples > 0: + num_negative_samples = int(len(labels) * negative_samples) + words = [c[0] for c in couples] + random.shuffle(words) + + couples += [ + [words[i % len(words)], random.randint(1, vocabulary_size - 1)] + for i in range(num_negative_samples) + ] + if categorical: + labels += [[1, 0]] * num_negative_samples + else: + labels += [0] * num_negative_samples + + if shuffle: + if seed is None: + seed = random.randint(0, 10e6) + random.seed(seed) + random.shuffle(couples) + random.seed(seed) + random.shuffle(labels) + + return couples, labels diff --git a/keras/src/legacy/preprocessing/text.py b/keras/src/legacy/preprocessing/text.py new file mode 100644 index 000000000000..bcf59a870256 --- /dev/null +++ b/keras/src/legacy/preprocessing/text.py @@ -0,0 +1,336 @@ +"""Deprecated text preprocessing APIs from Keras 1.""" + +import collections +import hashlib +import json +import warnings + +import numpy as np + +from keras.src.api_export import keras_export + + +@keras_export("keras._legacy.preprocessing.text.text_to_word_sequence") +def text_to_word_sequence( + input_text, + filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n', + lower=True, + split=" ", +): + """DEPRECATED.""" + if lower: + input_text = input_text.lower() + + translate_dict = {c: split for c in filters} + translate_map = str.maketrans(translate_dict) + input_text = input_text.translate(translate_map) + + seq = input_text.split(split) + return [i for i in seq if i] + + +@keras_export("keras._legacy.preprocessing.text.one_hot") +def one_hot( + input_text, + n, + filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n', + lower=True, + split=" ", + analyzer=None, +): + """DEPRECATED.""" + return hashing_trick( + input_text, + n, + hash_function=hash, + filters=filters, + lower=lower, + split=split, + analyzer=analyzer, + ) + + +@keras_export("keras._legacy.preprocessing.text.hashing_trick") +def hashing_trick( + text, + n, + hash_function=None, + filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n', + lower=True, + split=" ", + analyzer=None, +): + """DEPRECATED.""" + if hash_function is None: + hash_function = hash + elif hash_function == "md5": + + def hash_function(w): + return int(hashlib.md5(w.encode()).hexdigest(), 16) + + if analyzer is None: + seq = text_to_word_sequence( + text, filters=filters, lower=lower, split=split + ) + else: + seq = analyzer(text) + + return [(hash_function(w) % (n - 1) + 1) for w in seq] + + +@keras_export("keras._legacy.preprocessing.text.Tokenizer") +class Tokenizer: + """DEPRECATED.""" + + def __init__( + self, + num_words=None, + filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n', + lower=True, + split=" ", + char_level=False, + oov_token=None, + analyzer=None, + **kwargs, + ): + # Legacy support + if "nb_words" in kwargs: + warnings.warn( + "The `nb_words` argument in `Tokenizer` " + "has been renamed `num_words`." + ) + num_words = kwargs.pop("nb_words") + document_count = kwargs.pop("document_count", 0) + if kwargs: + raise TypeError(f"Unrecognized keyword arguments: {str(kwargs)}") + + self.word_counts = collections.OrderedDict() + self.word_docs = collections.defaultdict(int) + self.filters = filters + self.split = split + self.lower = lower + self.num_words = num_words + self.document_count = document_count + self.char_level = char_level + self.oov_token = oov_token + self.index_docs = collections.defaultdict(int) + self.word_index = {} + self.index_word = {} + self.analyzer = analyzer + + def fit_on_texts(self, texts): + for text in texts: + self.document_count += 1 + if self.char_level or isinstance(text, list): + if self.lower: + if isinstance(text, list): + text = [text_elem.lower() for text_elem in text] + else: + text = text.lower() + seq = text + else: + if self.analyzer is None: + seq = text_to_word_sequence( + text, + filters=self.filters, + lower=self.lower, + split=self.split, + ) + else: + seq = self.analyzer(text) + for w in seq: + if w in self.word_counts: + self.word_counts[w] += 1 + else: + self.word_counts[w] = 1 + for w in set(seq): + # In how many documents each word occurs + self.word_docs[w] += 1 + + wcounts = list(self.word_counts.items()) + wcounts.sort(key=lambda x: x[1], reverse=True) + # forcing the oov_token to index 1 if it exists + if self.oov_token is None: + sorted_voc = [] + else: + sorted_voc = [self.oov_token] + sorted_voc.extend(wc[0] for wc in wcounts) + + # note that index 0 is reserved, never assigned to an existing word + self.word_index = dict( + zip(sorted_voc, list(range(1, len(sorted_voc) + 1))) + ) + + self.index_word = {c: w for w, c in self.word_index.items()} + + for w, c in list(self.word_docs.items()): + self.index_docs[self.word_index[w]] = c + + def fit_on_sequences(self, sequences): + self.document_count += len(sequences) + for seq in sequences: + seq = set(seq) + for i in seq: + self.index_docs[i] += 1 + + def texts_to_sequences(self, texts): + return list(self.texts_to_sequences_generator(texts)) + + def texts_to_sequences_generator(self, texts): + num_words = self.num_words + oov_token_index = self.word_index.get(self.oov_token) + for text in texts: + if self.char_level or isinstance(text, list): + if self.lower: + if isinstance(text, list): + text = [text_elem.lower() for text_elem in text] + else: + text = text.lower() + seq = text + else: + if self.analyzer is None: + seq = text_to_word_sequence( + text, + filters=self.filters, + lower=self.lower, + split=self.split, + ) + else: + seq = self.analyzer(text) + vect = [] + for w in seq: + i = self.word_index.get(w) + if i is not None: + if num_words and i >= num_words: + if oov_token_index is not None: + vect.append(oov_token_index) + else: + vect.append(i) + elif self.oov_token is not None: + vect.append(oov_token_index) + yield vect + + def sequences_to_texts(self, sequences): + return list(self.sequences_to_texts_generator(sequences)) + + def sequences_to_texts_generator(self, sequences): + num_words = self.num_words + oov_token_index = self.word_index.get(self.oov_token) + for seq in sequences: + vect = [] + for num in seq: + word = self.index_word.get(num) + if word is not None: + if num_words and num >= num_words: + if oov_token_index is not None: + vect.append(self.index_word[oov_token_index]) + else: + vect.append(word) + elif self.oov_token is not None: + vect.append(self.index_word[oov_token_index]) + vect = " ".join(vect) + yield vect + + def texts_to_matrix(self, texts, mode="binary"): + sequences = self.texts_to_sequences(texts) + return self.sequences_to_matrix(sequences, mode=mode) + + def sequences_to_matrix(self, sequences, mode="binary"): + if not self.num_words: + if self.word_index: + num_words = len(self.word_index) + 1 + else: + raise ValueError( + "Specify a dimension (`num_words` argument), " + "or fit on some text data first." + ) + else: + num_words = self.num_words + + if mode == "tfidf" and not self.document_count: + raise ValueError( + "Fit the Tokenizer on some data before using tfidf mode." + ) + + x = np.zeros((len(sequences), num_words)) + for i, seq in enumerate(sequences): + if not seq: + continue + counts = collections.defaultdict(int) + for j in seq: + if j >= num_words: + continue + counts[j] += 1 + for j, c in list(counts.items()): + if mode == "count": + x[i][j] = c + elif mode == "freq": + x[i][j] = c / len(seq) + elif mode == "binary": + x[i][j] = 1 + elif mode == "tfidf": + # Use weighting scheme 2 in + # https://en.wikipedia.org/wiki/Tf%E2%80%93idf + tf = 1 + np.log(c) + idf = np.log( + 1 + + self.document_count / (1 + self.index_docs.get(j, 0)) + ) + x[i][j] = tf * idf + else: + raise ValueError("Unknown vectorization mode:", mode) + return x + + def get_config(self): + json_word_counts = json.dumps(self.word_counts) + json_word_docs = json.dumps(self.word_docs) + json_index_docs = json.dumps(self.index_docs) + json_word_index = json.dumps(self.word_index) + json_index_word = json.dumps(self.index_word) + + return { + "num_words": self.num_words, + "filters": self.filters, + "lower": self.lower, + "split": self.split, + "char_level": self.char_level, + "oov_token": self.oov_token, + "document_count": self.document_count, + "word_counts": json_word_counts, + "word_docs": json_word_docs, + "index_docs": json_index_docs, + "index_word": json_index_word, + "word_index": json_word_index, + } + + def to_json(self, **kwargs): + config = self.get_config() + tokenizer_config = { + "class_name": self.__class__.__name__, + "config": config, + } + return json.dumps(tokenizer_config, **kwargs) + + +@keras_export("keras._legacy.preprocessing.text.tokenizer_from_json") +def tokenizer_from_json(json_string): + """DEPRECATED.""" + tokenizer_config = json.loads(json_string) + config = tokenizer_config.get("config") + + word_counts = json.loads(config.pop("word_counts")) + word_docs = json.loads(config.pop("word_docs")) + index_docs = json.loads(config.pop("index_docs")) + # Integer indexing gets converted to strings with json.dumps() + index_docs = {int(k): v for k, v in index_docs.items()} + index_word = json.loads(config.pop("index_word")) + index_word = {int(k): v for k, v in index_word.items()} + word_index = json.loads(config.pop("word_index")) + + tokenizer = Tokenizer(**config) + tokenizer.word_counts = word_counts + tokenizer.word_docs = word_docs + tokenizer.index_docs = index_docs + tokenizer.word_index = word_index + tokenizer.index_word = index_word + return tokenizer diff --git a/keras/src/legacy/saving/__init__.py b/keras/src/legacy/saving/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/keras/src/legacy/saving/json_utils.py b/keras/src/legacy/saving/json_utils.py new file mode 100644 index 000000000000..0dbc578d25ab --- /dev/null +++ b/keras/src/legacy/saving/json_utils.py @@ -0,0 +1,220 @@ +"""JSON utilities for legacy saving formats (h5 and SavedModel)""" + +import collections +import enum +import functools +import json + +import numpy as np + +from keras.src.legacy.saving import serialization +from keras.src.saving import serialization_lib +from keras.src.utils.module_utils import tensorflow as tf + +_EXTENSION_TYPE_SPEC = "_EXTENSION_TYPE_SPEC" + + +class Encoder(json.JSONEncoder): + """JSON encoder and decoder that handles TensorShapes and tuples.""" + + def default(self, obj): + """Encodes objects for types that aren't handled by the default + encoder.""" + if tf.available and isinstance(obj, tf.TensorShape): + items = obj.as_list() if obj.rank is not None else None + return {"class_name": "TensorShape", "items": items} + return get_json_type(obj) + + def encode(self, obj): + return super().encode(_encode_tuple(obj)) + + +def _encode_tuple(x): + if isinstance(x, tuple): + return { + "class_name": "__tuple__", + "items": tuple(_encode_tuple(i) for i in x), + } + elif isinstance(x, list): + return [_encode_tuple(i) for i in x] + elif isinstance(x, dict): + return {key: _encode_tuple(value) for key, value in x.items()} + else: + return x + + +def decode(json_string): + return json.loads(json_string, object_hook=_decode_helper) + + +def decode_and_deserialize( + json_string, module_objects=None, custom_objects=None +): + """Decodes the JSON and deserializes any Keras objects found in the dict.""" + return json.loads( + json_string, + object_hook=functools.partial( + _decode_helper, + deserialize=True, + module_objects=module_objects, + custom_objects=custom_objects, + ), + ) + + +def _decode_helper( + obj, deserialize=False, module_objects=None, custom_objects=None +): + """A decoding helper that is TF-object aware. + + Args: + obj: A decoded dictionary that may represent an object. + deserialize: Boolean. When True, deserializes any Keras + objects found in `obj`. Defaults to `False`. + module_objects: A dictionary of built-in objects to look the name up in. + Generally, `module_objects` is provided by midlevel library + implementers. + custom_objects: A dictionary of custom objects to look the name up in. + Generally, `custom_objects` is provided by the end user. + + Returns: + The decoded object. + """ + if isinstance(obj, dict) and "class_name" in obj: + if tf.available: + if obj["class_name"] == "TensorShape": + return tf.TensorShape(obj["items"]) + elif obj["class_name"] == "TypeSpec": + from tensorflow.python.framework import type_spec_registry + + return type_spec_registry.lookup(obj["type_spec"])._deserialize( + _decode_helper(obj["serialized"]) + ) + elif obj["class_name"] == "CompositeTensor": + spec = obj["spec"] + tensors = [] + for dtype, tensor in obj["tensors"]: + tensors.append( + tf.constant(tensor, dtype=tf.dtypes.as_dtype(dtype)) + ) + return tf.nest.pack_sequence_as( + _decode_helper(spec), tensors, expand_composites=True + ) + + if obj["class_name"] == "__tuple__": + return tuple(_decode_helper(i) for i in obj["items"]) + elif obj["class_name"] == "__ellipsis__": + return Ellipsis + elif deserialize and "__passive_serialization__" in obj: + # __passive_serialization__ is added by the JSON encoder when + # encoding an object that has a `get_config()` method. + try: + if ( + "module" not in obj + ): # TODO(nkovela): Add TF SavedModel scope + return serialization.deserialize_keras_object( + obj, + module_objects=module_objects, + custom_objects=custom_objects, + ) + else: + return serialization_lib.deserialize_keras_object( + obj, + module_objects=module_objects, + custom_objects=custom_objects, + ) + except ValueError: + pass + elif obj["class_name"] == "__bytes__": + return obj["value"].encode("utf-8") + return obj + + +def get_json_type(obj): + """Serializes any object to a JSON-serializable structure. + + Args: + obj: the object to serialize + + Returns: + JSON-serializable structure representing `obj`. + + Raises: + TypeError: if `obj` cannot be serialized. + """ + # if obj is a serializable Keras class instance + # e.g. optimizer, layer + if hasattr(obj, "get_config"): + # TODO(nkovela): Replace with legacy serialization + serialized = serialization.serialize_keras_object(obj) + serialized["__passive_serialization__"] = True + return serialized + + # if obj is any numpy type + if type(obj).__module__ == np.__name__: + if isinstance(obj, np.ndarray): + return obj.tolist() + else: + return obj.item() + + # misc functions (e.g. loss function) + if callable(obj): + return obj.__name__ + + # if obj is a python 'type' + if type(obj).__name__ == type.__name__: + return obj.__name__ + + if tf.available and isinstance(obj, tf.compat.v1.Dimension): + return obj.value + + if tf.available and isinstance(obj, tf.TensorShape): + return obj.as_list() + + if tf.available and isinstance(obj, tf.DType): + return obj.name + + if isinstance(obj, collections.abc.Mapping): + return dict(obj) + + if obj is Ellipsis: + return {"class_name": "__ellipsis__"} + + # if isinstance(obj, wrapt.ObjectProxy): + # return obj.__wrapped__ + + if tf.available and isinstance(obj, tf.TypeSpec): + from tensorflow.python.framework import type_spec_registry + + try: + type_spec_name = type_spec_registry.get_name(type(obj)) + return { + "class_name": "TypeSpec", + "type_spec": type_spec_name, + "serialized": obj._serialize(), + } + except ValueError: + raise ValueError( + f"Unable to serialize {obj} to JSON, because the TypeSpec " + f"class {type(obj)} has not been registered." + ) + if tf.available and isinstance(obj, tf.__internal__.CompositeTensor): + spec = tf.type_spec_from_value(obj) + tensors = [] + for tensor in tf.nest.flatten(obj, expand_composites=True): + tensors.append((tensor.dtype.name, tensor.numpy().tolist())) + return { + "class_name": "CompositeTensor", + "spec": get_json_type(spec), + "tensors": tensors, + } + + if isinstance(obj, enum.Enum): + return obj.value + + if isinstance(obj, bytes): + return {"class_name": "__bytes__", "value": obj.decode("utf-8")} + + raise TypeError( + f"Unable to serialize {obj} to JSON. Unrecognized type {type(obj)}." + ) diff --git a/keras/src/legacy/saving/json_utils_test.py b/keras/src/legacy/saving/json_utils_test.py new file mode 100644 index 000000000000..def0111441b3 --- /dev/null +++ b/keras/src/legacy/saving/json_utils_test.py @@ -0,0 +1,94 @@ +import enum + +import pytest + +from keras.src import backend +from keras.src import testing +from keras.src.legacy.saving import json_utils + +if backend.backend() == "tensorflow": + import tensorflow as tf + + +class JsonUtilsTestAllBackends(testing.TestCase): + def test_encode_decode_tuple(self): + metadata = {"key1": (3, 5), "key2": [(1, (3, 4)), (1,)]} + string = json_utils.Encoder().encode(metadata) + loaded = json_utils.decode(string) + + self.assertEqual(set(loaded.keys()), {"key1", "key2"}) + self.assertAllEqual(loaded["key1"], (3, 5)) + self.assertAllEqual(loaded["key2"], [(1, (3, 4)), (1,)]) + + def test_encode_decode_enum(self): + class Enum(enum.Enum): + CLASS_A = "a" + CLASS_B = "b" + + config = {"key": Enum.CLASS_A, "key2": Enum.CLASS_B} + string = json_utils.Encoder().encode(config) + loaded = json_utils.decode(string) + self.assertAllEqual({"key": "a", "key2": "b"}, loaded) + + def test_encode_decode_bytes(self): + b_string = b"abc" + json_string = json_utils.Encoder().encode(b_string) + loaded = json_utils.decode(json_string) + self.assertAllEqual(b_string, loaded) + + +@pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="These JSON serialization tests are specific to TF components.", +) +class JsonUtilsTestTF(testing.TestCase): + def test_encode_decode_tensor_shape(self): + metadata = { + "key1": tf.TensorShape(None), + "key2": [tf.TensorShape([None]), tf.TensorShape([3, None, 5])], + } + string = json_utils.Encoder().encode(metadata) + loaded = json_utils.decode(string) + + self.assertEqual(set(loaded.keys()), {"key1", "key2"}) + self.assertEqual(loaded["key1"].rank, None) + self.assertAllEqual(loaded["key2"][0].as_list(), [None]) + self.assertAllEqual(loaded["key2"][1].as_list(), [3, None, 5]) + + def test_encode_decode_type_spec(self): + spec = tf.TensorSpec((1, 5), tf.float32) + string = json_utils.Encoder().encode(spec) + loaded = json_utils.decode(string) + self.assertEqual(spec, loaded) + + invalid_type_spec = { + "class_name": "TypeSpec", + "type_spec": "Invalid Type", + "serialized": None, + } + string = json_utils.Encoder().encode(invalid_type_spec) + with self.assertRaisesRegex( + ValueError, "No TypeSpec has been registered" + ): + loaded = json_utils.decode(string) + + def test_encode_decode_ragged_tensor(self): + x = tf.ragged.constant([[1.0, 2.0], [3.0]]) + string = json_utils.Encoder().encode(x) + loaded = json_utils.decode(string) + self.assertAllClose(loaded.values, x.values) + + def test_encode_decode_extension_type_tensor(self): + class MaskedTensor(tf.experimental.ExtensionType): + __name__ = "MaskedTensor" + values: tf.Tensor + mask: tf.Tensor + + x = MaskedTensor( + values=[[1, 2, 3], [4, 5, 6]], + mask=[[True, True, False], [True, False, True]], + ) + string = json_utils.Encoder().encode(x) + loaded = json_utils.decode(string) + self.assertAllClose(loaded.values, x.values) + self.assertAllClose(loaded.mask, x.mask) diff --git a/keras/src/legacy/saving/legacy_h5_format.py b/keras/src/legacy/saving/legacy_h5_format.py new file mode 100644 index 000000000000..7cb0ed8d1dbe --- /dev/null +++ b/keras/src/legacy/saving/legacy_h5_format.py @@ -0,0 +1,652 @@ +import json +import os +import warnings + +import numpy as np +from absl import logging + +from keras.src import backend +from keras.src.backend.common import global_state +from keras.src.legacy.saving import json_utils +from keras.src.legacy.saving import saving_options +from keras.src.legacy.saving import saving_utils +from keras.src.saving import object_registration +from keras.src.saving import serialization_lib +from keras.src.utils import io_utils + +try: + import h5py +except ImportError: + h5py = None + + +HDF5_OBJECT_HEADER_LIMIT = 64512 + + +def save_model_to_hdf5(model, filepath, overwrite=True, include_optimizer=True): + if h5py is None: + raise ImportError( + "`save_model()` using h5 format requires h5py. Could not " + "import h5py." + ) + + if not isinstance(filepath, h5py.File): + # If file exists and should not be overwritten. + if not overwrite and os.path.isfile(filepath): + proceed = io_utils.ask_to_proceed_with_overwrite(filepath) + if not proceed: + return + + dirpath = os.path.dirname(filepath) + if dirpath and not os.path.exists(dirpath): + os.makedirs(dirpath, exist_ok=True) + + f = h5py.File(filepath, mode="w") + opened_new_file = True + else: + f = filepath + opened_new_file = False + try: + with saving_options.keras_option_scope(use_legacy_config=True): + model_metadata = saving_utils.model_metadata( + model, include_optimizer + ) + for k, v in model_metadata.items(): + if isinstance(v, (dict, list, tuple)): + f.attrs[k] = json.dumps( + v, default=json_utils.get_json_type + ).encode("utf8") + else: + f.attrs[k] = v + + model_weights_group = f.create_group("model_weights") + save_weights_to_hdf5_group(model_weights_group, model) + + # TODO(b/128683857): Add integration tests between tf.keras and + # external Keras, to avoid breaking TF.js users. + if include_optimizer and hasattr(model, "optimizer"): + save_optimizer_weights_to_hdf5_group(f, model.optimizer) + + f.flush() + finally: + if opened_new_file: + f.close() + + +def load_model_from_hdf5( + filepath, custom_objects=None, compile=True, safe_mode=True +): + """Loads a model saved via `save_model_to_hdf5`. + + Args: + filepath: One of the following: + - String, path to the saved model + - `h5py.File` object from which to load the model + custom_objects: Optional dictionary mapping names + (strings) to custom classes or functions to be + considered during deserialization. + compile: Boolean, whether to compile the model + after loading. + + Returns: + A Keras model instance. If an optimizer was found + as part of the saved model, the model is already + compiled. Otherwise, the model is uncompiled and + a warning will be displayed. When `compile` is set + to `False`, the compilation is omitted without any + warning. + + Raises: + ImportError: if h5py is not available. + ValueError: In case of an invalid savefile. + """ + if h5py is None: + raise ImportError( + "`load_model()` using h5 format requires h5py. Could not " + "import h5py." + ) + + if not custom_objects: + custom_objects = {} + + gco = object_registration.GLOBAL_CUSTOM_OBJECTS + tlco = global_state.get_global_attribute("custom_objects_scope_dict", {}) + custom_objects = {**custom_objects, **gco, **tlco} + + opened_new_file = not isinstance(filepath, h5py.File) + if opened_new_file: + f = h5py.File(filepath, mode="r") + else: + f = filepath + + model = None + try: + # instantiate model + model_config = f.attrs.get("model_config") + if model_config is None: + raise ValueError( + f"No model config found in the file at {filepath}." + ) + if hasattr(model_config, "decode"): + model_config = model_config.decode("utf-8") + model_config = json_utils.decode(model_config) + + legacy_scope = saving_options.keras_option_scope(use_legacy_config=True) + safe_mode_scope = serialization_lib.SafeModeScope(safe_mode) + with legacy_scope, safe_mode_scope: + model = saving_utils.model_from_config( + model_config, custom_objects=custom_objects + ) + + # set weights + load_weights_from_hdf5_group(f["model_weights"], model) + + if compile: + # instantiate optimizer + training_config = f.attrs.get("training_config") + if hasattr(training_config, "decode"): + training_config = training_config.decode("utf-8") + if training_config is None: + logging.warning( + "No training configuration found in the save file, so " + "the model was *not* compiled. Compile it manually." + ) + return model + training_config = json_utils.decode(training_config) + + # Compile model. + model.compile( + **saving_utils.compile_args_from_training_config( + training_config, custom_objects + ) + ) + saving_utils.try_build_compiled_arguments(model) + + # Set optimizer weights. + if "optimizer_weights" in f: + try: + from keras.src import optimizers + + if isinstance(model.optimizer, optimizers.Optimizer): + model.optimizer.build(model._trainable_variables) + else: + model.optimizer._create_all_weights( + model._trainable_variables + ) + except (NotImplementedError, AttributeError): + logging.warning( + "Error when creating the weights of optimizer {}, " + "making it impossible to restore the saved optimizer " + "state. As a result, your model is starting with " + "a freshly initialized optimizer." + ) + + optimizer_weight_values = ( + load_optimizer_weights_from_hdf5_group(f) + ) + try: + model.optimizer.set_weights(optimizer_weight_values) + except ValueError: + logging.warning( + "Error in loading the saved optimizer " + "state. As a result, your model is " + "starting with a freshly initialized " + "optimizer." + ) + finally: + if opened_new_file: + f.close() + return model + + +def save_weights_to_hdf5_group(f, model): + """Saves the weights of a list of layers to a HDF5 group. + + Args: + f: HDF5 group. + model: Model instance. + """ + from keras.src import __version__ as keras_version + + save_attributes_to_hdf5_group( + f, "layer_names", [layer.name.encode("utf8") for layer in model.layers] + ) + f.attrs["backend"] = backend.backend().encode("utf8") + f.attrs["keras_version"] = str(keras_version).encode("utf8") + + # Sort model layers by layer name to ensure that group names are strictly + # growing to avoid prefix issues. + for layer in sorted(model.layers, key=lambda x: x.name): + g = f.create_group(layer.name) + weights = _legacy_weights(layer) + save_subset_weights_to_hdf5_group(g, weights) + weights = list( + v + for v in model._trainable_variables + model._non_trainable_variables + if v in model.weights + ) + g = f.create_group("top_level_model_weights") + save_subset_weights_to_hdf5_group(g, weights) + + +def save_subset_weights_to_hdf5_group(f, weights): + """Save top-level weights of a model to a HDF5 group. + + Args: + f: HDF5 group. + weights: List of weight variables. + """ + weight_values = [backend.convert_to_numpy(w) for w in weights] + weight_names = [str(w.path).encode("utf8") for w in weights] + save_attributes_to_hdf5_group(f, "weight_names", weight_names) + for name, val in zip(weight_names, weight_values): + param_dset = f.create_dataset(name, val.shape, dtype=val.dtype) + if not val.shape: + # scalar + param_dset[()] = val + else: + param_dset[:] = val + + +def save_optimizer_weights_to_hdf5_group(hdf5_group, optimizer): + """Saves optimizer weights of a optimizer to a HDF5 group. + + Args: + hdf5_group: HDF5 group. + optimizer: optimizer instance. + """ + from keras.src import optimizers + + if isinstance(optimizer, optimizers.Optimizer): + symbolic_weights = optimizer.variables + else: + symbolic_weights = getattr(optimizer, "weights") + if symbolic_weights: + weights_group = hdf5_group.create_group("optimizer_weights") + weight_names = [str(w.path).encode("utf8") for w in symbolic_weights] + save_attributes_to_hdf5_group( + weights_group, "weight_names", weight_names + ) + weight_values = [backend.convert_to_numpy(w) for w in symbolic_weights] + for name, val in zip(weight_names, weight_values): + param_dset = weights_group.create_dataset( + name, val.shape, dtype=val.dtype + ) + if not val.shape: + # scalar + param_dset[()] = val + else: + param_dset[:] = val + + +def save_attributes_to_hdf5_group(group, name, data): + """Saves attributes (data) of the specified name into the HDF5 group. + + This method deals with an inherent problem of HDF5 file which is not + able to store data larger than HDF5_OBJECT_HEADER_LIMIT bytes. + + Args: + group: A pointer to a HDF5 group. + name: A name of the attributes to save. + data: Attributes data to store. + + Raises: + RuntimeError: If any single attribute is too large to be saved. + """ + # Check that no item in `data` is larger than `HDF5_OBJECT_HEADER_LIMIT` + # because in that case even chunking the array would not make the saving + # possible. + bad_attributes = [x for x in data if len(x) > HDF5_OBJECT_HEADER_LIMIT] + + # Expecting this to never be true. + if bad_attributes: + raise RuntimeError( + "The following attributes cannot be saved to HDF5 file because " + f"they are larger than {HDF5_OBJECT_HEADER_LIMIT} " + f"bytes: {bad_attributes}" + ) + + data_npy = np.asarray(data) + + num_chunks = 1 + chunked_data = np.array_split(data_npy, num_chunks) + + # This will never loop forever thanks to the test above. + while any(x.nbytes > HDF5_OBJECT_HEADER_LIMIT for x in chunked_data): + num_chunks += 1 + chunked_data = np.array_split(data_npy, num_chunks) + + if num_chunks > 1: + for chunk_id, chunk_data in enumerate(chunked_data): + group.attrs["%s%d" % (name, chunk_id)] = chunk_data + else: + group.attrs[name] = data + + +def load_weights_from_hdf5_group(f, model, skip_mismatch=False): + """Implements topological (order-based) weight loading. + + Args: + f: A pointer to a HDF5 group. + model: Model instance. + skip_mismatch: Boolean, whether to skip loading of weights + where there is a mismatch in the shape of the weights, + + Raises: + ValueError: in case of mismatch between provided layers + and weights file. + """ + if "keras_version" in f.attrs: + original_keras_version = f.attrs["keras_version"] + if hasattr(original_keras_version, "decode"): + original_keras_version = original_keras_version.decode("utf8") + else: + original_keras_version = "1" + if "backend" in f.attrs: + original_backend = f.attrs["backend"] + if hasattr(original_backend, "decode"): + original_backend = original_backend.decode("utf8") + else: + original_backend = None + + filtered_layers = [] + for layer in model.layers: + weights = _legacy_weights(layer) + if weights: + filtered_layers.append(layer) + + layer_names = load_attributes_from_hdf5_group(f, "layer_names") + filtered_layer_names = [] + for name in layer_names: + g = f[name] + weight_names = load_attributes_from_hdf5_group(g, "weight_names") + if weight_names: + filtered_layer_names.append(name) + layer_names = filtered_layer_names + if len(layer_names) != len(filtered_layers): + raise ValueError( + "Layer count mismatch when loading weights from file. " + f"Model expected {len(filtered_layers)} layers, found " + f"{len(layer_names)} saved layers." + ) + + for k, name in enumerate(layer_names): + g = f[name] + layer = filtered_layers[k] + symbolic_weights = _legacy_weights(layer) + weight_values = load_subset_weights_from_hdf5_group(g) + if len(weight_values) != len(symbolic_weights): + raise ValueError( + f"Weight count mismatch for layer #{k} (named {layer.name} in " + f"the current model, {name} in the save file). " + f"Layer expects {len(symbolic_weights)} weight(s). Received " + f"{len(weight_values)} saved weight(s)" + ) + _set_weights( + layer, + symbolic_weights, + weight_values, + skip_mismatch=skip_mismatch, + name=f"layer #{k} (named {layer.name})", + ) + + if "top_level_model_weights" in f: + symbolic_weights = list( + # model.weights + v + for v in model._trainable_variables + model._non_trainable_variables + if v in model.weights + ) + weight_values = load_subset_weights_from_hdf5_group( + f["top_level_model_weights"] + ) + if len(weight_values) != len(symbolic_weights): + raise ValueError( + "Weight count mismatch for top-level weights when loading " + "weights from file. " + f"Model expects {len(symbolic_weights)} top-level weight(s). " + f"Received {len(weight_values)} saved top-level weight(s)" + ) + _set_weights( + model, + symbolic_weights, + weight_values, + skip_mismatch=skip_mismatch, + name="top-level model", + ) + + +def _set_weights( + instance, symbolic_weights, weight_values, name, skip_mismatch=False +): + """Safely set weights into a model or a layer. + + Args: + instance: Model or layer instance, + symbolic_weights: symbolic tensors representing + the weights of the variables to load, + weight_values: values of the weights to load, + skip_mismatch: Boolean, whether to skip loading of weights + where there is a mismatch in the shape of the weights, + name: name used to identify the group. + + Raises: + ValueError: in case of mismatch between provided + model/layer and weights. + """ + for i, weight_value in enumerate(weight_values): + expected_shape = symbolic_weights[i].shape + received_shape = weight_value.shape + if expected_shape != received_shape: + if skip_mismatch: + warnings.warn( + f"Skipping loading weights for {name}" + f"due to mismatch in shape for " + f"weight {symbolic_weights[i].path}. " + f"Weight expects shape {expected_shape}. " + "Received saved weight " + f"with shape {received_shape}", + stacklevel=2, + ) + continue + raise ValueError( + f"Shape mismatch in {name}" + f"for weight {symbolic_weights[i].path}. " + f"Weight expects shape {expected_shape}. " + "Received saved weight " + f"with shape {received_shape}" + ) + symbolic_weights[i].assign(weight_value) + + if hasattr(instance, "finalize_state") and symbolic_weights: + instance.finalize_state() + + +def load_weights_from_hdf5_group_by_name(f, model, skip_mismatch=False): + """Implements name-based weight loading (instead of topological loading). + + Layers that have no matching name are skipped. + + Args: + f: A pointer to a HDF5 group. + model: Model instance. + skip_mismatch: Boolean, whether to skip loading of layers + where there is a mismatch in the number of weights, + or a mismatch in the shape of the weights. + + Raises: + ValueError: in case of mismatch between provided layers + and weights file and skip_match=False. + """ + if "keras_version" in f.attrs: + original_keras_version = f.attrs["keras_version"] + if hasattr(original_keras_version, "decode"): + original_keras_version = original_keras_version.decode("utf8") + else: + original_keras_version = "1" + if "backend" in f.attrs: + original_backend = f.attrs["backend"] + if hasattr(original_backend, "decode"): + original_backend = original_backend.decode("utf8") + else: + original_backend = None + + # New file format. + layer_names = load_attributes_from_hdf5_group(f, "layer_names") + + # Reverse index of layer name to list of layers with name. + index = {} + for layer in model.layers: + if layer.name: + index.setdefault(layer.name, []).append(layer) + + for k, name in enumerate(layer_names): + g = f[name] + weight_values = load_subset_weights_from_hdf5_group(g) + for layer in index.get(name, []): + symbolic_weights = _legacy_weights(layer) + if len(weight_values) != len(symbolic_weights): + if skip_mismatch: + warnings.warn( + f"Skipping loading of weights for layer #{k} (named " + f"{layer.name}) due to mismatch in number of weights. " + f"Layer expects {len(symbolic_weights)} weight(s). " + f"Received {len(weight_values)} saved weight(s)", + stacklevel=2, + ) + continue + raise ValueError( + f"Weight count mismatch for layer #{k} " + f"(named {layer.name}). " + f"Layer expects {len(symbolic_weights)} weight(s). " + f"Received {len(weight_values)} saved weight(s)" + ) + # Set values. + _set_weights( + layer, + symbolic_weights, + weight_values, + skip_mismatch=skip_mismatch, + name=f"layer #{k} (named {layer.name})", + ) + + if "top_level_model_weights" in f: + symbolic_weights = ( + model._trainable_variables + model._non_trainable_variables + ) + weight_values = load_subset_weights_from_hdf5_group( + f["top_level_model_weights"] + ) + + if len(weight_values) != len(symbolic_weights): + if skip_mismatch: + warnings.warn( + "Skipping loading top-level weights for model due to " + "mismatch in number of weights. " + f"Model expects {len(symbolic_weights)} " + "top-level weight(s). " + f"Received {len(weight_values)} saved top-level weight(s)", + stacklevel=2, + ) + else: + raise ValueError( + "Weight count mismatch for top-level weights of model. " + f"Model expects {len(symbolic_weights)} " + "top-level weight(s). " + f"Received {len(weight_values)} saved top-level weight(s)" + ) + else: + _set_weights( + model, + symbolic_weights, + weight_values, + skip_mismatch=skip_mismatch, + name="top-level model", + ) + + +def load_subset_weights_from_hdf5_group(f): + """Load layer weights of a model from hdf5. + + Args: + f: A pointer to a HDF5 group. + + Returns: + List of NumPy arrays of the weight values. + + Raises: + ValueError: in case of mismatch between provided model + and weights file. + """ + weight_names = load_attributes_from_hdf5_group(f, "weight_names") + return [np.asarray(f[weight_name]) for weight_name in weight_names] + + +def load_optimizer_weights_from_hdf5_group(hdf5_group): + """Load optimizer weights from a HDF5 group. + + Args: + hdf5_group: A pointer to a HDF5 group. + + Returns: + data: List of optimizer weight names. + """ + weights_group = hdf5_group["optimizer_weights"] + optimizer_weight_names = load_attributes_from_hdf5_group( + weights_group, "weight_names" + ) + return [ + weights_group[weight_name] for weight_name in optimizer_weight_names + ] + + +def load_attributes_from_hdf5_group(group, name): + """Loads attributes of the specified name from the HDF5 group. + + This method deals with an inherent problem + of HDF5 file which is not able to store + data larger than HDF5_OBJECT_HEADER_LIMIT bytes. + + Args: + group: A pointer to a HDF5 group. + name: A name of the attributes to load. + + Returns: + data: Attributes data. + """ + if name in group.attrs: + data = [ + n.decode("utf8") if hasattr(n, "decode") else n + for n in group.attrs[name] + ] + else: + data = [] + chunk_id = 0 + while f"{name}{chunk_id}" in group.attrs: + data.extend( + [ + n.decode("utf8") if hasattr(n, "decode") else n + for n in group.attrs[f"{name}{chunk_id}"] + ] + ) + chunk_id += 1 + return data + + +def _legacy_weights(layer): + """Legacy weight order converter. + + For legacy reason, the layer.weights was in the order of + [self.trainable_weights + self.non_trainable_weights], and this order was + used for preserving the weights in h5 format. The new order of layer.weights + are the same as layer.get_weights() which is more intuitive for user. To + keep supporting the existing saved h5 file, this method should be used to + save/load weights. + + Args: + layer: a `Model` or `Layer` instance. + + Returns: + A list of variables with the legacy weight order. + """ + return layer.trainable_weights + layer.non_trainable_weights diff --git a/keras/src/legacy/saving/legacy_h5_format_test.py b/keras/src/legacy/saving/legacy_h5_format_test.py new file mode 100644 index 000000000000..1588150300cf --- /dev/null +++ b/keras/src/legacy/saving/legacy_h5_format_test.py @@ -0,0 +1,547 @@ +import os + +import numpy as np +import pytest + +import keras +from keras.src import layers +from keras.src import models +from keras.src import ops +from keras.src import testing +from keras.src.legacy.saving import legacy_h5_format +from keras.src.saving import object_registration +from keras.src.saving import serialization_lib + +# TODO: more thorough testing. Correctness depends +# on exact weight ordering for each layer, so we need +# to test across all types of layers. + +try: + import tf_keras +except: + tf_keras = None + + +def get_sequential_model(keras): + return keras.Sequential( + [ + keras.layers.Input((3,), batch_size=2), + keras.layers.Dense(4, activation="relu"), + keras.layers.BatchNormalization( + moving_mean_initializer="uniform", gamma_initializer="uniform" + ), + keras.layers.Dense(5, activation="softmax"), + ] + ) + + +def get_functional_model(keras): + inputs = keras.Input((3,), batch_size=2) + x = keras.layers.Dense(4, activation="relu")(inputs) + residual = x + x = keras.layers.BatchNormalization( + moving_mean_initializer="uniform", gamma_initializer="uniform" + )(x) + x = keras.layers.Dense(4, activation="relu")(x) + x = keras.layers.add([x, residual]) + outputs = keras.layers.Dense(5, activation="softmax")(x) + return keras.Model(inputs, outputs) + + +def get_subclassed_model(keras): + class MyModel(keras.Model): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.dense_1 = keras.layers.Dense(3, activation="relu") + self.dense_2 = keras.layers.Dense(1, activation="sigmoid") + + # top_level_model_weights + self.bias = self.add_weight( + name="bias", + shape=[1], + trainable=True, + initializer=keras.initializers.Zeros(), + ) + + def call(self, x): + x = self.dense_1(x) + x = self.dense_2(x) + + # top_level_model_weights + x += ops.cast(self.bias, x.dtype) + return x + + model = MyModel() + model(np.random.random((2, 3))) + return model + + +@pytest.mark.requires_trainable_backend +@pytest.mark.skipif(tf_keras is None, reason="Test requires tf_keras") +class LegacyH5WeightsTest(testing.TestCase): + def _check_reloading_weights(self, ref_input, model, tf_keras_model): + ref_output = tf_keras_model(ref_input) + initial_weights = model.get_weights() + # Check weights only file + temp_filepath = os.path.join(self.get_temp_dir(), "weights.h5") + tf_keras_model.save_weights(temp_filepath) + model.load_weights(temp_filepath) + output = model(ref_input) + self.assertAllClose(ref_output, output, atol=1e-5) + model.set_weights(initial_weights) + model.load_weights(temp_filepath) + output = model(ref_input) + self.assertAllClose(ref_output, output, atol=1e-5) + + def test_sequential_model_weights(self): + model = get_sequential_model(keras) + tf_keras_model = get_sequential_model(tf_keras) + ref_input = np.random.random((2, 3)) + self._check_reloading_weights(ref_input, model, tf_keras_model) + + def test_functional_model_weights(self): + model = get_functional_model(keras) + tf_keras_model = get_functional_model(tf_keras) + ref_input = np.random.random((2, 3)) + self._check_reloading_weights(ref_input, model, tf_keras_model) + + def test_subclassed_model_weights(self): + model = get_subclassed_model(keras) + tf_keras_model = get_subclassed_model(tf_keras) + ref_input = np.random.random((2, 3)) + self._check_reloading_weights(ref_input, model, tf_keras_model) + + +@pytest.mark.requires_trainable_backend +class LegacyH5WholeModelTest(testing.TestCase): + def _check_reloading_model(self, ref_input, model): + # Whole model file + ref_output = model(ref_input) + temp_filepath = os.path.join(self.get_temp_dir(), "model.h5") + legacy_h5_format.save_model_to_hdf5(model, temp_filepath) + loaded = legacy_h5_format.load_model_from_hdf5(temp_filepath) + output = loaded(ref_input) + self.assertAllClose(ref_output, output, atol=1e-5) + + def test_sequential_model(self): + model = get_sequential_model(keras) + ref_input = np.random.random((2, 3)) + self._check_reloading_model(ref_input, model) + + def test_functional_model(self): + model = get_functional_model(keras) + ref_input = np.random.random((2, 3)) + self._check_reloading_model(ref_input, model) + + def test_compiled_model_with_various_layers(self): + model = models.Sequential() + model.add(layers.Dense(2, input_shape=(3,))) + model.add(layers.RepeatVector(3)) + model.add(layers.TimeDistributed(layers.Dense(3))) + + model.compile(optimizer="rmsprop", loss="mean_squared_error") + ref_input = np.random.random((1, 3)) + self._check_reloading_model(ref_input, model) + + def test_saving_lambda(self): + mean = ops.random.uniform((4, 2, 3)) + std = ops.abs(ops.random.uniform((4, 2, 3))) + 1e-5 + inputs = layers.Input(shape=(4, 2, 3)) + output = layers.Lambda( + lambda image, mu, std: (image - mu) / std, + arguments={"mu": mean, "std": std}, + )(inputs) + model = models.Model(inputs, output) + model.compile( + loss="mean_squared_error", optimizer="sgd", metrics=["acc"] + ) + + temp_filepath = os.path.join(self.get_temp_dir(), "lambda_model.h5") + legacy_h5_format.save_model_to_hdf5(model, temp_filepath) + + with self.assertRaisesRegex(ValueError, "arbitrary code execution"): + legacy_h5_format.load_model_from_hdf5(temp_filepath) + + loaded = legacy_h5_format.load_model_from_hdf5( + temp_filepath, safe_mode=False + ) + self.assertAllClose(mean, loaded.layers[1].arguments["mu"]) + self.assertAllClose(std, loaded.layers[1].arguments["std"]) + + def test_saving_include_optimizer_false(self): + model = models.Sequential() + model.add(layers.Dense(1)) + model.compile("adam", loss="mean_squared_error") + x, y = np.ones((10, 10)), np.ones((10, 1)) + model.fit(x, y) + ref_output = model(x) + + temp_filepath = os.path.join(self.get_temp_dir(), "model.h5") + legacy_h5_format.save_model_to_hdf5( + model, temp_filepath, include_optimizer=False + ) + loaded = legacy_h5_format.load_model_from_hdf5(temp_filepath) + output = loaded(x) + + # Assert that optimizer does not exist in loaded model + with self.assertRaises(AttributeError): + _ = loaded.optimizer + + # Compare output + self.assertAllClose(ref_output, output, atol=1e-5) + + def test_custom_sequential_registered_no_scope(self): + @object_registration.register_keras_serializable(package="my_package") + class MyDense(layers.Dense): + def __init__(self, units, **kwargs): + super().__init__(units, **kwargs) + + inputs = layers.Input(shape=[1]) + custom_layer = MyDense(1) + model = models.Sequential(layers=[inputs, custom_layer]) + + ref_input = np.array([5]) + self._check_reloading_model(ref_input, model) + + def test_custom_functional_registered_no_scope(self): + @object_registration.register_keras_serializable(package="my_package") + class MyDense(layers.Dense): + def __init__(self, units, **kwargs): + super().__init__(units, **kwargs) + + inputs = layers.Input(shape=[1]) + outputs = MyDense(1)(inputs) + model = models.Model(inputs, outputs) + + ref_input = np.array([5]) + self._check_reloading_model(ref_input, model) + + def test_nested_layers(self): + class MyLayer(layers.Layer): + def __init__(self, sublayers, **kwargs): + super().__init__(**kwargs) + self.sublayers = sublayers + + def call(self, x): + prev_input = x + for layer in self.sublayers: + prev_input = layer(prev_input) + return prev_input + + def get_config(self): + config = super().get_config() + config["sublayers"] = serialization_lib.serialize_keras_object( + self.sublayers + ) + return config + + @classmethod + def from_config(cls, config): + config["sublayers"] = ( + serialization_lib.deserialize_keras_object( + config["sublayers"] + ) + ) + return cls(**config) + + @object_registration.register_keras_serializable(package="Foo") + class RegisteredSubLayer(layers.Layer): + pass + + layer = MyLayer( + [ + layers.Dense(2, name="MyDense"), + RegisteredSubLayer(name="MySubLayer"), + ] + ) + model = models.Sequential([layer]) + with self.subTest("test_JSON"): + from keras.src.models.model import model_from_json + + model_json = model.to_json() + self.assertIn("Foo>RegisteredSubLayer", model_json) + + loaded_model = model_from_json( + model_json, custom_objects={"MyLayer": MyLayer} + ) + loaded_layer = loaded_model.layers[0] + + self.assertIsInstance(loaded_layer.sublayers[0], layers.Dense) + self.assertEqual(loaded_layer.sublayers[0].name, "MyDense") + self.assertIsInstance(loaded_layer.sublayers[1], RegisteredSubLayer) + self.assertEqual(loaded_layer.sublayers[1].name, "MySubLayer") + + with self.subTest("test_H5"): + temp_filepath = os.path.join(self.get_temp_dir(), "model.h5") + legacy_h5_format.save_model_to_hdf5(model, temp_filepath) + loaded_model = legacy_h5_format.load_model_from_hdf5( + temp_filepath, custom_objects={"MyLayer": MyLayer} + ) + loaded_layer = loaded_model.layers[0] + + self.assertIsInstance(loaded_layer.sublayers[0], layers.Dense) + self.assertEqual(loaded_layer.sublayers[0].name, "MyDense") + self.assertIsInstance(loaded_layer.sublayers[1], RegisteredSubLayer) + self.assertEqual(loaded_layer.sublayers[1].name, "MySubLayer") + + def test_model_loading_with_axis_arg(self): + input1 = layers.Input(shape=(1, 4), name="input1") + input2 = layers.Input(shape=(1, 4), name="input2") + concat1 = layers.Concatenate(axis=1)([input1, input2]) + output = layers.Dense(1, activation="sigmoid")(concat1) + model = models.Model(inputs=[input1, input2], outputs=output) + model.compile( + optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"] + ) + temp_filepath = os.path.join( + self.get_temp_dir(), "model_with_axis_arg.h5" + ) + legacy_h5_format.save_model_to_hdf5(model, temp_filepath) + legacy_h5_format.load_model_from_hdf5(temp_filepath) + + +@pytest.mark.requires_trainable_backend +@pytest.mark.skipif(tf_keras is None, reason="Test requires tf_keras") +class LegacyH5BackwardsCompatTest(testing.TestCase): + def _check_reloading_model(self, ref_input, model, tf_keras_model): + # Whole model file + ref_output = tf_keras_model(ref_input) + temp_filepath = os.path.join(self.get_temp_dir(), "model.h5") + tf_keras_model.save(temp_filepath) + loaded = legacy_h5_format.load_model_from_hdf5(temp_filepath) + output = loaded(ref_input) + self.assertAllClose(ref_output, output, atol=1e-5) + + def test_sequential_model(self): + model = get_sequential_model(keras) + tf_keras_model = get_sequential_model(tf_keras) + ref_input = np.random.random((2, 3)) + self._check_reloading_model(ref_input, model, tf_keras_model) + + def test_functional_model(self): + tf_keras_model = get_functional_model(tf_keras) + model = get_functional_model(keras) + ref_input = np.random.random((2, 3)) + self._check_reloading_model(ref_input, model, tf_keras_model) + + def test_compiled_model_with_various_layers(self): + model = models.Sequential() + model.add(layers.Dense(2, input_shape=(3,))) + model.add(layers.RepeatVector(3)) + model.add(layers.TimeDistributed(layers.Dense(3))) + model.compile(optimizer="rmsprop", loss="mse") + + tf_keras_model = tf_keras.Sequential() + tf_keras_model.add(tf_keras.layers.Dense(2, input_shape=(3,))) + tf_keras_model.add(tf_keras.layers.RepeatVector(3)) + tf_keras_model.add( + tf_keras.layers.TimeDistributed(tf_keras.layers.Dense(3)) + ) + tf_keras_model.compile(optimizer="rmsprop", loss="mean_squared_error") + + ref_input = np.random.random((1, 3)) + self._check_reloading_model(ref_input, model, tf_keras_model) + + def test_saving_lambda(self): + mean = np.random.random((4, 2, 3)) + std = np.abs(np.random.random((4, 2, 3))) + 1e-5 + inputs = tf_keras.layers.Input(shape=(4, 2, 3)) + output = tf_keras.layers.Lambda( + lambda image, mu, std: (image - mu) / std, + arguments={"mu": mean, "std": std}, + output_shape=inputs.shape, + )(inputs) + tf_keras_model = tf_keras.Model(inputs, output) + tf_keras_model.compile( + loss="mean_squared_error", optimizer="sgd", metrics=["acc"] + ) + + temp_filepath = os.path.join(self.get_temp_dir(), "lambda_model.h5") + tf_keras_model.save(temp_filepath) + + with self.assertRaisesRegex(ValueError, "arbitrary code execution"): + legacy_h5_format.load_model_from_hdf5(temp_filepath) + + loaded = legacy_h5_format.load_model_from_hdf5( + temp_filepath, safe_mode=False + ) + self.assertAllClose(mean, loaded.layers[1].arguments["mu"]) + self.assertAllClose(std, loaded.layers[1].arguments["std"]) + + def test_saving_include_optimizer_false(self): + tf_keras_model = tf_keras.Sequential() + tf_keras_model.add(tf_keras.layers.Dense(1)) + tf_keras_model.compile("adam", loss="mse") + x, y = np.ones((10, 10)), np.ones((10, 1)) + tf_keras_model.fit(x, y) + ref_output = tf_keras_model(x) + + temp_filepath = os.path.join(self.get_temp_dir(), "model.h5") + tf_keras_model.save(temp_filepath, include_optimizer=False) + loaded = legacy_h5_format.load_model_from_hdf5(temp_filepath) + output = loaded(x) + + # Assert that optimizer does not exist in loaded model + with self.assertRaises(AttributeError): + _ = loaded.optimizer + + # Compare output + self.assertAllClose(ref_output, output, atol=1e-5) + + def test_custom_sequential_registered_no_scope(self): + @tf_keras.saving.register_keras_serializable(package="my_package") + class MyDense(tf_keras.layers.Dense): + def __init__(self, units, **kwargs): + super().__init__(units, **kwargs) + + inputs = tf_keras.layers.Input(shape=[1]) + custom_layer = MyDense(1) + tf_keras_model = tf_keras.Sequential(layers=[inputs, custom_layer]) + + # Re-implement and re-register in Keras 3 + @object_registration.register_keras_serializable(package="my_package") + class MyDense(layers.Dense): + def __init__(self, units, **kwargs): + super().__init__(units, **kwargs) + + inputs = layers.Input(shape=[1]) + custom_layer = MyDense(1) + model = models.Sequential(layers=[inputs, custom_layer]) + + ref_input = np.array([5]) + self._check_reloading_model(ref_input, model, tf_keras_model) + + def test_custom_functional_registered_no_scope(self): + @tf_keras.saving.register_keras_serializable(package="my_package") + class MyDense(tf_keras.layers.Dense): + def __init__(self, units, **kwargs): + super().__init__(units, **kwargs) + + inputs = tf_keras.layers.Input(shape=[1]) + outputs = MyDense(1)(inputs) + tf_keras_model = tf_keras.Model(inputs, outputs) + + # Re-implement and re-register in Keras 3 + @object_registration.register_keras_serializable(package="my_package") + class MyDense(layers.Dense): + def __init__(self, units, **kwargs): + super().__init__(units, **kwargs) + + inputs = layers.Input(shape=[1]) + outputs = MyDense(1)(inputs) + model = models.Model(inputs, outputs) + + ref_input = np.array([5]) + self._check_reloading_model(ref_input, model, tf_keras_model) + + def test_nested_layers(self): + class MyLayer(tf_keras.layers.Layer): + def __init__(self, sublayers, **kwargs): + super().__init__(**kwargs) + self.sublayers = sublayers + + def call(self, x): + prev_input = x + for layer in self.sublayers: + prev_input = layer(prev_input) + return prev_input + + def get_config(self): + config = super().get_config() + config["sublayers"] = tf_keras.saving.serialize_keras_object( + self.sublayers + ) + return config + + @classmethod + def from_config(cls, config): + config["sublayers"] = tf_keras.saving.deserialize_keras_object( + config["sublayers"] + ) + return cls(**config) + + @tf_keras.saving.register_keras_serializable(package="Foo") + class RegisteredSubLayer(layers.Layer): + def call(self, x): + return x + + layer = MyLayer( + [ + tf_keras.layers.Dense(2, name="MyDense"), + RegisteredSubLayer(name="MySubLayer"), + ] + ) + tf_keras_model = tf_keras.Sequential([layer]) + + x = np.random.random((4, 2)) + ref_output = tf_keras_model(x) + + # Save TF Keras model to H5 file + temp_filepath = os.path.join(self.get_temp_dir(), "model.h5") + tf_keras_model.save(temp_filepath) + + # Re-implement in Keras 3 + class MyLayer(layers.Layer): + def __init__(self, sublayers, **kwargs): + super().__init__(**kwargs) + self.sublayers = sublayers + + def call(self, x): + prev_input = x + for layer in self.sublayers: + prev_input = layer(prev_input) + return prev_input + + def get_config(self): + config = super().get_config() + config["sublayers"] = serialization_lib.serialize_keras_object( + self.sublayers + ) + return config + + @classmethod + def from_config(cls, config): + config["sublayers"] = ( + serialization_lib.deserialize_keras_object( + config["sublayers"] + ) + ) + return cls(**config) + + # Re-implement and re-register in Keras 3 + @object_registration.register_keras_serializable(package="Foo") + class RegisteredSubLayer(layers.Layer): + def call(self, x): + return x + + # Load in Keras 3 + loaded_model = legacy_h5_format.load_model_from_hdf5( + temp_filepath, custom_objects={"MyLayer": MyLayer} + ) + loaded_layer = loaded_model.layers[0] + output = loaded_model(x) + + # Ensure nested layer structure + self.assertIsInstance(loaded_layer.sublayers[0], layers.Dense) + self.assertEqual(loaded_layer.sublayers[0].name, "MyDense") + self.assertIsInstance(loaded_layer.sublayers[1], RegisteredSubLayer) + self.assertEqual(loaded_layer.sublayers[1].name, "MySubLayer") + + # Compare output + self.assertAllClose(ref_output, output, atol=1e-5) + + +@pytest.mark.requires_trainable_backend +class DirectoryCreationTest(testing.TestCase): + def test_directory_creation_on_save(self): + """Test if directory is created on model save.""" + model = get_sequential_model(keras) + nested_dirpath = os.path.join( + self.get_temp_dir(), "dir1", "dir2", "dir3" + ) + filepath = os.path.join(nested_dirpath, "model.h5") + self.assertFalse(os.path.exists(nested_dirpath)) + legacy_h5_format.save_model_to_hdf5(model, filepath) + self.assertTrue(os.path.exists(nested_dirpath)) + loaded_model = legacy_h5_format.load_model_from_hdf5(filepath) + self.assertEqual(model.to_json(), loaded_model.to_json()) diff --git a/keras/src/legacy/saving/saving_options.py b/keras/src/legacy/saving/saving_options.py new file mode 100644 index 000000000000..6f270fb23290 --- /dev/null +++ b/keras/src/legacy/saving/saving_options.py @@ -0,0 +1,17 @@ +import contextlib + +from keras.src.backend.common import global_state + + +@contextlib.contextmanager +def keras_option_scope(use_legacy_config=True): + use_legacy_config_prev_value = global_state.get_global_attribute( + "use_legacy_config", None + ) + global_state.set_global_attribute("use_legacy_config", use_legacy_config) + try: + yield + finally: + global_state.set_global_attribute( + "use_legacy_config", use_legacy_config_prev_value + ) diff --git a/keras/src/legacy/saving/saving_utils.py b/keras/src/legacy/saving/saving_utils.py new file mode 100644 index 000000000000..62d1222aed4b --- /dev/null +++ b/keras/src/legacy/saving/saving_utils.py @@ -0,0 +1,253 @@ +import threading + +from absl import logging + +from keras.src import backend +from keras.src import losses +from keras.src import metrics as metrics_module +from keras.src import tree +from keras.src.legacy.saving import serialization +from keras.src.saving import object_registration + +MODULE_OBJECTS = threading.local() + +# Legacy lambda arguments not found in Keras 3 +LAMBDA_DEP_ARGS = ( + "module", + "function_type", + "output_shape_type", + "output_shape_module", +) + + +def model_from_config(config, custom_objects=None): + """Instantiates a Keras model from its config. + + Args: + config: Configuration dictionary. + custom_objects: Optional dictionary mapping names + (strings) to custom classes or functions to be + considered during deserialization. + + Returns: + A Keras model instance (uncompiled). + + Raises: + TypeError: if `config` is not a dictionary. + """ + if isinstance(config, list): + raise TypeError( + "`model_from_config` expects a dictionary, not a list. " + f"Received: config={config}. Did you meant to use " + "`Sequential.from_config(config)`?" + ) + + global MODULE_OBJECTS + + if not hasattr(MODULE_OBJECTS, "ALL_OBJECTS"): + from keras.src import layers + from keras.src import models + + MODULE_OBJECTS.ALL_OBJECTS = layers.__dict__ + MODULE_OBJECTS.ALL_OBJECTS["InputLayer"] = layers.InputLayer + MODULE_OBJECTS.ALL_OBJECTS["Functional"] = models.Functional + MODULE_OBJECTS.ALL_OBJECTS["Model"] = models.Model + MODULE_OBJECTS.ALL_OBJECTS["Sequential"] = models.Sequential + + batch_input_shape = config["config"].pop("batch_input_shape", None) + if batch_input_shape is not None: + if config["class_name"] == "InputLayer": + config["config"]["batch_shape"] = batch_input_shape + else: + config["config"]["input_shape"] = batch_input_shape + + axis = config["config"].pop("axis", None) + if axis is not None: + if isinstance(axis, list) and len(axis) == 1: + config["config"]["axis"] = int(axis[0]) + elif isinstance(axis, (int, float)): + config["config"]["axis"] = int(axis) + + # Handle backwards compatibility for Keras lambdas + if config["class_name"] == "Lambda": + for dep_arg in LAMBDA_DEP_ARGS: + _ = config["config"].pop(dep_arg, None) + function_config = config["config"]["function"] + if isinstance(function_config, list): + function_dict = {"class_name": "__lambda__", "config": {}} + function_dict["config"]["code"] = function_config[0] + function_dict["config"]["defaults"] = function_config[1] + function_dict["config"]["closure"] = function_config[2] + config["config"]["function"] = function_dict + + return serialization.deserialize_keras_object( + config, + module_objects=MODULE_OBJECTS.ALL_OBJECTS, + custom_objects=custom_objects, + printable_module_name="layer", + ) + + +def model_metadata(model, include_optimizer=True, require_config=True): + """Returns a dictionary containing the model metadata.""" + from keras.src import __version__ as keras_version + + model_config = {"class_name": model.__class__.__name__} + try: + model_config["config"] = model.get_config() + except NotImplementedError as e: + if require_config: + raise e + + metadata = dict( + keras_version=str(keras_version), + backend=backend.backend(), + model_config=model_config, + ) + if getattr(model, "optimizer", False) and include_optimizer: + if model.compiled: + training_config = model._compile_config.config + training_config.pop("optimizer", None) # Handled separately. + metadata["training_config"] = _serialize_nested_config( + training_config + ) + optimizer_config = { + "class_name": object_registration.get_registered_name( + model.optimizer.__class__ + ), + "config": model.optimizer.get_config(), + } + metadata["training_config"]["optimizer_config"] = optimizer_config + return metadata + + +def compile_args_from_training_config(training_config, custom_objects=None): + """Return model.compile arguments from training config.""" + if custom_objects is None: + custom_objects = {} + + with object_registration.CustomObjectScope(custom_objects): + from keras.src import optimizers + + optimizer_config = training_config["optimizer_config"] + optimizer = optimizers.deserialize(optimizer_config) + # Ensure backwards compatibility for optimizers in legacy H5 files + optimizer = _resolve_compile_arguments_compat( + optimizer, optimizer_config, optimizers + ) + + # Recover losses. + loss = None + loss_config = training_config.get("loss", None) + if loss_config is not None: + loss = _deserialize_nested_config(losses.deserialize, loss_config) + # Ensure backwards compatibility for losses in legacy H5 files + loss = _resolve_compile_arguments_compat(loss, loss_config, losses) + + # Recover metrics. + metrics = None + metrics_config = training_config.get("metrics", None) + if metrics_config is not None: + metrics = _deserialize_nested_config( + _deserialize_metric, metrics_config + ) + # Ensure backwards compatibility for metrics in legacy H5 files + metrics = _resolve_compile_arguments_compat( + metrics, metrics_config, metrics_module + ) + + # Recover weighted metrics. + weighted_metrics = None + weighted_metrics_config = training_config.get("weighted_metrics", None) + if weighted_metrics_config is not None: + weighted_metrics = _deserialize_nested_config( + _deserialize_metric, weighted_metrics_config + ) + + loss_weights = training_config["loss_weights"] + + return dict( + optimizer=optimizer, + loss=loss, + metrics=metrics, + weighted_metrics=weighted_metrics, + loss_weights=loss_weights, + ) + + +def _serialize_nested_config(config): + """Serialized a nested structure of Keras objects.""" + + def _serialize_fn(obj): + if callable(obj): + return serialization.serialize_keras_object(obj) + return obj + + return tree.map_structure(_serialize_fn, config) + + +def _deserialize_nested_config(deserialize_fn, config): + """Deserializes arbitrary Keras `config` using `deserialize_fn`.""" + + def _is_single_object(obj): + if isinstance(obj, dict) and "class_name" in obj: + return True # Serialized Keras object. + if isinstance(obj, str): + return True # Serialized function or string. + return False + + if config is None: + return None + if _is_single_object(config): + return deserialize_fn(config) + elif isinstance(config, dict): + return { + k: _deserialize_nested_config(deserialize_fn, v) + for k, v in config.items() + } + elif isinstance(config, (tuple, list)): + return [ + _deserialize_nested_config(deserialize_fn, obj) for obj in config + ] + + raise ValueError( + "Saved configuration not understood. Configuration should be a " + f"dictionary, string, tuple or list. Received: config={config}." + ) + + +def _deserialize_metric(metric_config): + """Deserialize metrics, leaving special strings untouched.""" + if metric_config in ["accuracy", "acc", "crossentropy", "ce"]: + # Do not deserialize accuracy and cross-entropy strings as we have + # special case handling for these in compile, based on model output + # shape. + return metric_config + return metrics_module.deserialize(metric_config) + + +def _resolve_compile_arguments_compat(obj, obj_config, module): + """Resolves backwards compatibility issues with training config arguments. + + This helper function accepts built-in Keras modules such as optimizers, + losses, and metrics to ensure an object being deserialized is compatible + with Keras 3 built-ins. For legacy H5 files saved within Keras 3, + this does nothing. + """ + if isinstance(obj, str) and obj not in module.ALL_OBJECTS_DICT: + obj = module.get(obj_config["config"]["name"]) + return obj + + +def try_build_compiled_arguments(model): + try: + if not model.compiled_loss.built: + model.compiled_loss.build(model.outputs) + if not model.compiled_metrics.built: + model.compiled_metrics.build(model.outputs, model.outputs) + except: + logging.warning( + "Compiled the loaded model, but the compiled metrics have " + "yet to be built. `model.compile_metrics` will be empty " + "until you train or evaluate the model." + ) diff --git a/keras/src/legacy/saving/serialization.py b/keras/src/legacy/saving/serialization.py new file mode 100644 index 000000000000..8474363895f2 --- /dev/null +++ b/keras/src/legacy/saving/serialization.py @@ -0,0 +1,560 @@ +"""Legacy serialization logic for Keras models.""" + +import contextlib +import inspect +import threading +import weakref + +# isort: off +from keras.src.api_export import keras_export +from keras.src.saving import object_registration + +# Flag that determines whether to skip the NotImplementedError when calling +# get_config in custom models and layers. This is only enabled when saving to +# SavedModel, when the config isn't required. +_SKIP_FAILED_SERIALIZATION = False +# If a layer does not have a defined config, then the returned config will be a +# dictionary with the below key. +_LAYER_UNDEFINED_CONFIG_KEY = "layer was saved without config" + +# Store a unique, per-object ID for shared objects. +# +# We store a unique ID for each object so that we may, at loading time, +# re-create the network properly. Without this ID, we would have no way of +# determining whether a config is a description of a new object that +# should be created or is merely a reference to an already-created object. +SHARED_OBJECT_KEY = "shared_object_id" + +SHARED_OBJECT_DISABLED = threading.local() +SHARED_OBJECT_LOADING = threading.local() +SHARED_OBJECT_SAVING = threading.local() + + +# Attributes on the threadlocal variable must be set per-thread, thus we +# cannot initialize these globally. Instead, we have accessor functions with +# default values. +def _shared_object_disabled(): + """Get whether shared object handling is disabled in a threadsafe manner.""" + return getattr(SHARED_OBJECT_DISABLED, "disabled", False) + + +def _shared_object_loading_scope(): + """Get the current shared object saving scope in a threadsafe manner.""" + return getattr(SHARED_OBJECT_LOADING, "scope", NoopLoadingScope()) + + +def _shared_object_saving_scope(): + """Get the current shared object saving scope in a threadsafe manner.""" + return getattr(SHARED_OBJECT_SAVING, "scope", None) + + +class DisableSharedObjectScope: + """A context manager for disabling handling of shared objects. + + Disables shared object handling for both saving and loading. + + Created primarily for use with `clone_model`, which does extra surgery that + is incompatible with shared objects. + """ + + def __enter__(self): + SHARED_OBJECT_DISABLED.disabled = True + self._orig_loading_scope = _shared_object_loading_scope() + self._orig_saving_scope = _shared_object_saving_scope() + + def __exit__(self, *args, **kwargs): + SHARED_OBJECT_DISABLED.disabled = False + SHARED_OBJECT_LOADING.scope = self._orig_loading_scope + SHARED_OBJECT_SAVING.scope = self._orig_saving_scope + + +class NoopLoadingScope: + """The default shared object loading scope. It does nothing. + + Created to simplify serialization code that doesn't care about shared + objects (e.g. when serializing a single object). + """ + + def get(self, unused_object_id): + return None + + def set(self, object_id, obj): + pass + + +class SharedObjectLoadingScope: + """A context manager for keeping track of loaded objects. + + During the deserialization process, we may come across objects that are + shared across multiple layers. In order to accurately restore the network + structure to its original state, `SharedObjectLoadingScope` allows us to + re-use shared objects rather than cloning them. + """ + + def __enter__(self): + if _shared_object_disabled(): + return NoopLoadingScope() + + global SHARED_OBJECT_LOADING + SHARED_OBJECT_LOADING.scope = self + self._obj_ids_to_obj = {} + return self + + def get(self, object_id): + """Given a shared object ID, returns a previously instantiated object. + + Args: + object_id: shared object ID to use when attempting to find + already-loaded object. + + Returns: + The object, if we've seen this ID before. Else, `None`. + """ + # Explicitly check for `None` internally to make external calling code a + # bit cleaner. + if object_id is None: + return + return self._obj_ids_to_obj.get(object_id) + + def set(self, object_id, obj): + """Stores an instantiated object for future lookup and sharing.""" + if object_id is None: + return + self._obj_ids_to_obj[object_id] = obj + + def __exit__(self, *args, **kwargs): + global SHARED_OBJECT_LOADING + SHARED_OBJECT_LOADING.scope = NoopLoadingScope() + + +class SharedObjectConfig(dict): + """A configuration container that keeps track of references. + + `SharedObjectConfig` will automatically attach a shared object ID to any + configs which are referenced more than once, allowing for proper shared + object reconstruction at load time. + + In most cases, it would be more proper to subclass something like + `collections.UserDict` or `collections.Mapping` rather than `dict` directly. + Unfortunately, python's json encoder does not support `Mapping`s. This is + important functionality to retain, since we are dealing with serialization. + + We should be safe to subclass `dict` here, since we aren't actually + overriding any core methods, only augmenting with a new one for reference + counting. + """ + + def __init__(self, base_config, object_id, **kwargs): + self.ref_count = 1 + self.object_id = object_id + super().__init__(base_config, **kwargs) + + def increment_ref_count(self): + # As soon as we've seen the object more than once, we want to attach the + # shared object ID. This allows us to only attach the shared object ID + # when it's strictly necessary, making backwards compatibility breakage + # less likely. + if self.ref_count == 1: + self[SHARED_OBJECT_KEY] = self.object_id + self.ref_count += 1 + + +class SharedObjectSavingScope: + """Keeps track of shared object configs when serializing.""" + + def __enter__(self): + if _shared_object_disabled(): + return None + + global SHARED_OBJECT_SAVING + + # Serialization can happen at a number of layers for a number of + # reasons. We may end up with a case where we're opening a saving scope + # within another saving scope. In that case, we'd like to use the + # outermost scope available and ignore inner scopes, since there is not + # (yet) a reasonable use case for having these nested and distinct. + if _shared_object_saving_scope() is not None: + self._passthrough = True + return _shared_object_saving_scope() + else: + self._passthrough = False + + SHARED_OBJECT_SAVING.scope = self + self._shared_objects_config = weakref.WeakKeyDictionary() + self._next_id = 0 + return self + + def get_config(self, obj): + """Gets a `SharedObjectConfig` if one has already been seen for `obj`. + + Args: + obj: The object for which to retrieve the `SharedObjectConfig`. + + Returns: + The SharedObjectConfig for a given object, if already seen. Else, + `None`. + """ + try: + shared_object_config = self._shared_objects_config[obj] + except (TypeError, KeyError): + # If the object is unhashable (e.g. a subclass of + # `AbstractBaseClass` that has not overridden `__hash__`), a + # `TypeError` will be thrown. We'll just continue on without shared + # object support. + return None + shared_object_config.increment_ref_count() + return shared_object_config + + def create_config(self, base_config, obj): + """Create a new SharedObjectConfig for a given object.""" + shared_object_config = SharedObjectConfig(base_config, self._next_id) + self._next_id += 1 + try: + self._shared_objects_config[obj] = shared_object_config + except TypeError: + # If the object is unhashable (e.g. a subclass of + # `AbstractBaseClass` that has not overridden `__hash__`), a + # `TypeError` will be thrown. We'll just continue on without shared + # object support. + pass + return shared_object_config + + def __exit__(self, *args, **kwargs): + if not getattr(self, "_passthrough", False): + global SHARED_OBJECT_SAVING + SHARED_OBJECT_SAVING.scope = None + + +def serialize_keras_class_and_config( + cls_name, cls_config, obj=None, shared_object_id=None +): + """Returns the serialization of the class with the given config.""" + base_config = {"class_name": cls_name, "config": cls_config} + + # We call `serialize_keras_class_and_config` for some branches of the load + # path. In that case, we may already have a shared object ID we'd like to + # retain. + if shared_object_id is not None: + base_config[SHARED_OBJECT_KEY] = shared_object_id + + # If we have an active `SharedObjectSavingScope`, check whether we've + # already serialized this config. If so, just use that config. This will + # store an extra ID field in the config, allowing us to re-create the shared + # object relationship at load time. + if _shared_object_saving_scope() is not None and obj is not None: + shared_object_config = _shared_object_saving_scope().get_config(obj) + if shared_object_config is None: + return _shared_object_saving_scope().create_config(base_config, obj) + return shared_object_config + + return base_config + + +@contextlib.contextmanager +def skip_failed_serialization(): + global _SKIP_FAILED_SERIALIZATION + prev = _SKIP_FAILED_SERIALIZATION + try: + _SKIP_FAILED_SERIALIZATION = True + yield + finally: + _SKIP_FAILED_SERIALIZATION = prev + + +@keras_export( + [ + "keras.legacy.saving.serialize_keras_object", + "keras.utils.legacy.serialize_keras_object", + ] +) +def serialize_keras_object(instance): + """Serialize a Keras object into a JSON-compatible representation. + + Calls to `serialize_keras_object` while underneath the + `SharedObjectSavingScope` context manager will cause any objects re-used + across multiple layers to be saved with a special shared object ID. This + allows the network to be re-created properly during deserialization. + + Args: + instance: The object to serialize. + + Returns: + A dict-like, JSON-compatible representation of the object's config. + """ + + # _, instance = tf.__internal__.decorator.unwrap(instance) + instance = inspect.unwrap(instance) + if instance is None: + return None + + if hasattr(instance, "get_config"): + name = object_registration.get_registered_name(instance.__class__) + try: + config = instance.get_config() + except NotImplementedError as e: + if _SKIP_FAILED_SERIALIZATION: + return serialize_keras_class_and_config( + name, {_LAYER_UNDEFINED_CONFIG_KEY: True} + ) + raise e + serialization_config = {} + for key, item in config.items(): + if isinstance(item, str): + serialization_config[key] = item + continue + + # Any object of a different type needs to be converted to string or + # dict for serialization (e.g. custom functions, custom classes) + try: + serialized_item = serialize_keras_object(item) + if isinstance(serialized_item, dict) and not isinstance( + item, dict + ): + serialized_item["__passive_serialization__"] = True + serialization_config[key] = serialized_item + except ValueError: + serialization_config[key] = item + + name = object_registration.get_registered_name(instance.__class__) + return serialize_keras_class_and_config( + name, serialization_config, instance + ) + if hasattr(instance, "__name__"): + return object_registration.get_registered_name(instance) + raise ValueError( + f"Cannot serialize {instance} because it doesn't implement " + "`get_config()`." + ) + + +def class_and_config_for_serialized_keras_object( + config, + module_objects=None, + custom_objects=None, + printable_module_name="object", +): + """Returns the class name and config for a serialized keras object.""" + + if ( + not isinstance(config, dict) + or "class_name" not in config + or "config" not in config + ): + raise ValueError( + f"Improper config format for {config}. " + "Expecting python dict contains `class_name` and `config` as keys" + ) + + class_name = config["class_name"] + cls = object_registration.get_registered_object( + class_name, custom_objects, module_objects + ) + if cls is None: + raise ValueError( + f"Unknown {printable_module_name}: '{class_name}'. " + "Please ensure you are using a `keras.utils.custom_object_scope` " + "and that this object is included in the scope. See " + "https://www.tensorflow.org/guide/keras/save_and_serialize" + "#registering_the_custom_object for details." + ) + + cls_config = config["config"] + # Check if `cls_config` is a list. If it is a list, return the class and the + # associated class configs for recursively deserialization. This case will + # happen on the old version of sequential model (e.g. `keras_version` == + # "2.0.6"), which is serialized in a different structure, for example + # "{'class_name': 'Sequential', + # 'config': [{'class_name': 'Embedding', 'config': ...}, {}, ...]}". + if isinstance(cls_config, list): + return (cls, cls_config) + + deserialized_objects = {} + for key, item in cls_config.items(): + if key == "name": + # Assume that the value of 'name' is a string that should not be + # deserialized as a function. This avoids the corner case where + # cls_config['name'] has an identical name to a custom function and + # gets converted into that function. + deserialized_objects[key] = item + elif isinstance(item, dict) and "__passive_serialization__" in item: + deserialized_objects[key] = deserialize_keras_object( + item, + module_objects=module_objects, + custom_objects=custom_objects, + printable_module_name="config_item", + ) + # TODO(momernick): Should this also have 'module_objects'? + elif isinstance(item, str) and inspect.isfunction( + object_registration.get_registered_object(item, custom_objects) + ): + # Handle custom functions here. When saving functions, we only save + # the function's name as a string. If we find a matching string in + # the custom objects during deserialization, we convert the string + # back to the original function. + # Note that a potential issue is that a string field could have a + # naming conflict with a custom function name, but this should be a + # rare case. This issue does not occur if a string field has a + # naming conflict with a custom object, since the config of an + # object will always be a dict. + deserialized_objects[key] = ( + object_registration.get_registered_object(item, custom_objects) + ) + for key, item in deserialized_objects.items(): + cls_config[key] = deserialized_objects[key] + + return (cls, cls_config) + + +@keras_export( + [ + "keras.legacy.saving.deserialize_keras_object", + "keras.utils.legacy.deserialize_keras_object", + ] +) +def deserialize_keras_object( + identifier, + module_objects=None, + custom_objects=None, + printable_module_name="object", +): + """Turns the serialized form of a Keras object back into an actual object. + + This function is for mid-level library implementers rather than end users. + + Importantly, this utility requires you to provide the dict of + `module_objects` to use for looking up the object config; this is not + populated by default. If you need a deserialization utility that has + preexisting knowledge of built-in Keras objects, use e.g. + `keras.layers.deserialize(config)`, `keras.metrics.deserialize(config)`, + etc. + + Calling `deserialize_keras_object` while underneath the + `SharedObjectLoadingScope` context manager will cause any already-seen + shared objects to be returned as-is rather than creating a new object. + + Args: + identifier: the serialized form of the object. + module_objects: A dictionary of built-in objects to look the name up in. + Generally, `module_objects` is provided by midlevel library + implementers. + custom_objects: A dictionary of custom objects to look the name up in. + Generally, `custom_objects` is provided by the end user. + printable_module_name: A human-readable string representing the type of + the object. Printed in case of exception. + + Returns: + The deserialized object. + + Example: + + A mid-level library implementer might want to implement a utility for + retrieving an object from its config, as such: + + ```python + def deserialize(config, custom_objects=None): + return deserialize_keras_object( + identifier, + module_objects=globals(), + custom_objects=custom_objects, + name="MyObjectType", + ) + ``` + + This is how e.g. `keras.layers.deserialize()` is implemented. + """ + + if identifier is None: + return None + + if isinstance(identifier, dict): + # In this case we are dealing with a Keras config dictionary. + config = identifier + (cls, cls_config) = class_and_config_for_serialized_keras_object( + config, module_objects, custom_objects, printable_module_name + ) + + # If this object has already been loaded (i.e. it's shared between + # multiple objects), return the already-loaded object. + shared_object_id = config.get(SHARED_OBJECT_KEY) + shared_object = _shared_object_loading_scope().get(shared_object_id) + if shared_object is not None: + return shared_object + + if hasattr(cls, "from_config"): + arg_spec = inspect.getfullargspec(cls.from_config) + custom_objects = custom_objects or {} + + if "custom_objects" in arg_spec.args: + deserialized_obj = cls.from_config( + cls_config, + custom_objects={ + **object_registration.GLOBAL_CUSTOM_OBJECTS, + **custom_objects, + }, + ) + else: + with object_registration.CustomObjectScope(custom_objects): + deserialized_obj = cls.from_config(cls_config) + else: + # Then `cls` may be a function returning a class. + # in this case by convention `config` holds + # the kwargs of the function. + custom_objects = custom_objects or {} + with object_registration.CustomObjectScope(custom_objects): + deserialized_obj = cls(**cls_config) + + # Add object to shared objects, in case we find it referenced again. + _shared_object_loading_scope().set(shared_object_id, deserialized_obj) + + return deserialized_obj + + elif isinstance(identifier, str): + object_name = identifier + if custom_objects and object_name in custom_objects: + obj = custom_objects.get(object_name) + elif ( + object_name + in object_registration._THREAD_LOCAL_CUSTOM_OBJECTS.__dict__ + ): + obj = object_registration._THREAD_LOCAL_CUSTOM_OBJECTS.__dict__[ + object_name + ] + elif object_name in object_registration._GLOBAL_CUSTOM_OBJECTS: + obj = object_registration._GLOBAL_CUSTOM_OBJECTS[object_name] + else: + obj = module_objects.get(object_name) + if obj is None: + raise ValueError( + f"Unknown {printable_module_name}: '{object_name}'. " + "Please ensure you are using a " + "`keras.utils.custom_object_scope` " + "and that this object is included in the scope. See " + "https://www.tensorflow.org/guide/keras/save_and_serialize" + "#registering_the_custom_object for details." + ) + + # Classes passed by name are instantiated with no args, functions are + # returned as-is. + if inspect.isclass(obj): + return obj() + return obj + elif inspect.isfunction(identifier): + # If a function has already been deserialized, return as is. + return identifier + else: + raise ValueError( + "Could not interpret serialized " + f"{printable_module_name}: {identifier}" + ) + + +def validate_config(config): + """Determines whether config appears to be a valid layer config.""" + return ( + isinstance(config, dict) and _LAYER_UNDEFINED_CONFIG_KEY not in config + ) + + +def is_default(method): + """Check if a method is decorated with the `default` wrapper.""" + return getattr(method, "_is_default", False) diff --git a/keras/src/losses/__init__.py b/keras/src/losses/__init__.py new file mode 100644 index 000000000000..7afeb55a01d1 --- /dev/null +++ b/keras/src/losses/__init__.py @@ -0,0 +1,207 @@ +import inspect + +from keras.src.api_export import keras_export +from keras.src.losses.loss import Loss +from keras.src.losses.losses import CTC +from keras.src.losses.losses import BinaryCrossentropy +from keras.src.losses.losses import BinaryFocalCrossentropy +from keras.src.losses.losses import CategoricalCrossentropy +from keras.src.losses.losses import CategoricalFocalCrossentropy +from keras.src.losses.losses import CategoricalHinge +from keras.src.losses.losses import Circle +from keras.src.losses.losses import CosineSimilarity +from keras.src.losses.losses import Dice +from keras.src.losses.losses import Hinge +from keras.src.losses.losses import Huber +from keras.src.losses.losses import KLDivergence +from keras.src.losses.losses import LogCosh +from keras.src.losses.losses import LossFunctionWrapper +from keras.src.losses.losses import MeanAbsoluteError +from keras.src.losses.losses import MeanAbsolutePercentageError +from keras.src.losses.losses import MeanSquaredError +from keras.src.losses.losses import MeanSquaredLogarithmicError +from keras.src.losses.losses import Poisson +from keras.src.losses.losses import SparseCategoricalCrossentropy +from keras.src.losses.losses import SquaredHinge +from keras.src.losses.losses import Tversky +from keras.src.losses.losses import binary_crossentropy +from keras.src.losses.losses import binary_focal_crossentropy +from keras.src.losses.losses import categorical_crossentropy +from keras.src.losses.losses import categorical_focal_crossentropy +from keras.src.losses.losses import categorical_hinge +from keras.src.losses.losses import circle +from keras.src.losses.losses import cosine_similarity +from keras.src.losses.losses import ctc +from keras.src.losses.losses import dice +from keras.src.losses.losses import hinge +from keras.src.losses.losses import huber +from keras.src.losses.losses import kl_divergence +from keras.src.losses.losses import log_cosh +from keras.src.losses.losses import mean_absolute_error +from keras.src.losses.losses import mean_absolute_percentage_error +from keras.src.losses.losses import mean_squared_error +from keras.src.losses.losses import mean_squared_logarithmic_error +from keras.src.losses.losses import poisson +from keras.src.losses.losses import sparse_categorical_crossentropy +from keras.src.losses.losses import squared_hinge +from keras.src.losses.losses import tversky +from keras.src.saving import serialization_lib + +ALL_OBJECTS = { + # Base + Loss, + LossFunctionWrapper, + # Probabilistic + KLDivergence, + Poisson, + BinaryCrossentropy, + BinaryFocalCrossentropy, + CategoricalCrossentropy, + CategoricalFocalCrossentropy, + SparseCategoricalCrossentropy, + # Regression + MeanSquaredError, + MeanAbsoluteError, + MeanAbsolutePercentageError, + MeanSquaredLogarithmicError, + CosineSimilarity, + LogCosh, + Huber, + # Hinge + Hinge, + SquaredHinge, + CategoricalHinge, + # Image segmentation + Dice, + Tversky, + # Similarity + Circle, + # Sequence + CTC, + # Probabilistic + kl_divergence, + poisson, + binary_crossentropy, + binary_focal_crossentropy, + categorical_crossentropy, + categorical_focal_crossentropy, + sparse_categorical_crossentropy, + # Regression + mean_squared_error, + mean_absolute_error, + mean_absolute_percentage_error, + mean_squared_logarithmic_error, + cosine_similarity, + log_cosh, + huber, + # Hinge + hinge, + squared_hinge, + categorical_hinge, + # Image segmentation + dice, + tversky, + # Similarity + circle, + # Sequence + ctc, +} + +ALL_OBJECTS_DICT = {cls.__name__: cls for cls in ALL_OBJECTS} +ALL_OBJECTS_DICT.update( + { + "bce": binary_crossentropy, + "BCE": binary_crossentropy, + "kld": kl_divergence, + "KLD": kl_divergence, + "mae": mean_absolute_error, + "MAE": mean_absolute_error, + "mse": mean_squared_error, + "MSE": mean_squared_error, + "mape": mean_absolute_percentage_error, + "MAPE": mean_absolute_percentage_error, + "msle": mean_squared_logarithmic_error, + "MSLE": mean_squared_logarithmic_error, + } +) + + +@keras_export("keras.losses.serialize") +def serialize(loss): + """Serializes loss function or `Loss` instance. + + Args: + loss: A Keras `Loss` instance or a loss function. + + Returns: + Loss configuration dictionary. + """ + return serialization_lib.serialize_keras_object(loss) + + +@keras_export("keras.losses.deserialize") +def deserialize(name, custom_objects=None): + """Deserializes a serialized loss class/function instance. + + Args: + name: Loss configuration. + custom_objects: Optional dictionary mapping names (strings) to custom + objects (classes and functions) to be considered during + deserialization. + + Returns: + A Keras `Loss` instance or a loss function. + """ + return serialization_lib.deserialize_keras_object( + name, + module_objects=ALL_OBJECTS_DICT, + custom_objects=custom_objects, + ) + + +@keras_export("keras.losses.get") +def get(identifier): + """Retrieves a Keras loss as a `function`/`Loss` class instance. + + The `identifier` may be the string name of a loss function or `Loss` class. + + >>> loss = losses.get("categorical_crossentropy") + >>> type(loss) + + >>> loss = losses.get("CategoricalCrossentropy") + >>> type(loss) + + + You can also specify `config` of the loss to this function by passing dict + containing `class_name` and `config` as an identifier. Also note that the + `class_name` must map to a `Loss` class + + >>> identifier = {"class_name": "CategoricalCrossentropy", + ... "config": {"from_logits": True}} + >>> loss = losses.get(identifier) + >>> type(loss) + + + Args: + identifier: A loss identifier. One of None or string name of a loss + function/class or loss configuration dictionary or a loss function + or a loss class instance. + + Returns: + A Keras loss as a `function`/ `Loss` class instance. + """ + if identifier is None: + return None + if isinstance(identifier, dict): + obj = deserialize(identifier) + elif isinstance(identifier, str): + obj = ALL_OBJECTS_DICT.get(identifier, None) + else: + obj = identifier + + if callable(obj): + if inspect.isclass(obj): + obj = obj() + return obj + else: + raise ValueError(f"Could not interpret loss identifier: {identifier}") diff --git a/keras/src/losses/loss.py b/keras/src/losses/loss.py new file mode 100644 index 000000000000..6af73902d0fd --- /dev/null +++ b/keras/src/losses/loss.py @@ -0,0 +1,256 @@ +from keras.src import backend +from keras.src import dtype_policies +from keras.src import ops +from keras.src import tree +from keras.src.api_export import keras_export +from keras.src.saving.keras_saveable import KerasSaveable +from keras.src.utils.naming import auto_name + + +@keras_export(["keras.Loss", "keras.losses.Loss"]) +class Loss(KerasSaveable): + """Loss base class. + + This is the class to subclass in order to create new custom losses. + + Args: + reduction: Type of reduction to apply to the loss. In almost all cases + this should be `"sum_over_batch_size"`. Supported options are + `"sum"`, `"sum_over_batch_size"`, `"mean"`, + `"mean_with_sample_weight"` or `None`. `"sum"` sums the loss, + `"sum_over_batch_size"` and `"mean"` sum the loss and divide by the + sample size, and `"mean_with_sample_weight"` sums the loss and + divides by the sum of the sample weights. `"none"` and `None` + perform no aggregation. Defaults to `"sum_over_batch_size"`. + name: Optional name for the loss instance. + dtype: The dtype of the loss's computations. Defaults to `None`, which + means using `keras.backend.floatx()`. `keras.backend.floatx()` is a + `"float32"` unless set to different value + (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is + provided, then the `compute_dtype` will be utilized. + + To be implemented by subclasses: + + * `call()`: Contains the logic for loss calculation using `y_true`, + `y_pred`. + + Example subclass implementation: + + ```python + class MeanSquaredError(Loss): + def call(self, y_true, y_pred): + return ops.mean(ops.square(y_pred - y_true), axis=-1) + ``` + """ + + def __init__(self, name=None, reduction="sum_over_batch_size", dtype=None): + self.name = name or auto_name(self.__class__.__name__) + self.reduction = standardize_reduction(reduction) + self._dtype_policy = dtype_policies.get(dtype or backend.floatx()) + self._dtype = self._dtype_policy.compute_dtype + + @property + def dtype(self): + return self._dtype + + def __call__(self, y_true, y_pred, sample_weight=None): + in_mask = backend.get_keras_mask(y_pred) + + with ops.name_scope(self.name): + y_pred = tree.map_structure( + lambda x: ops.convert_to_tensor(x, dtype=self.dtype), y_pred + ) + y_true = tree.map_structure( + lambda x: ops.convert_to_tensor(x, dtype=self.dtype), y_true + ) + + losses = self.call(y_true, y_pred) + out_mask = backend.get_keras_mask(losses) + + if in_mask is not None and out_mask is not None: + mask = in_mask & out_mask + elif in_mask is not None: + mask = in_mask + elif out_mask is not None: + mask = out_mask + else: + mask = None + + return reduce_weighted_values( + losses, + sample_weight=sample_weight, + mask=mask, + reduction=self.reduction, + dtype=self.dtype, + ) + + def call(self, y_true, y_pred): + raise NotImplementedError + + def get_config(self): + return {"name": self.name, "reduction": self.reduction} + + @classmethod + def from_config(cls, config): + return cls(**config) + + def _obj_type(self): + return "Loss" + + +def standardize_reduction(reduction): + allowed = { + "sum_over_batch_size", + "sum", + None, + "none", + "mean", + "mean_with_sample_weight", + } + if reduction not in allowed: + raise ValueError( + "Invalid value for argument `reduction`. " + f"Expected one of {allowed}. Received: " + f"reduction={reduction}" + ) + return reduction + + +def squeeze_or_expand_to_same_rank(x1, x2, expand_rank_1=True): + """Squeeze/expand last dim if ranks differ from expected by exactly 1.""" + x1_rank = len(x1.shape) + x2_rank = len(x2.shape) + if x1_rank == x2_rank: + return x1, x2 + if x1_rank == x2_rank + 1: + if x1.shape[-1] == 1: + if x2_rank == 1 and expand_rank_1: + x2 = ops.expand_dims(x2, axis=-1) + else: + x1 = ops.squeeze(x1, axis=-1) + if x2_rank == x1_rank + 1: + if x2.shape[-1] == 1: + if x1_rank == 1 and expand_rank_1: + x1 = ops.expand_dims(x1, axis=-1) + else: + x2 = ops.squeeze(x2, axis=-1) + return x1, x2 + + +def reduce_values(values, sample_weight=None, reduction="sum_over_batch_size"): + if ( + reduction is None + or reduction == "none" + or tuple(values.shape) == () + or tuple(values.shape) == (0,) + ): + return values + loss = ops.sum(values) + if reduction in ("sum_over_batch_size", "mean", "mean_with_sample_weight"): + if reduction == "mean_with_sample_weight" and sample_weight is not None: + divisor = ops.cast(ops.sum(sample_weight), loss.dtype) + else: + divisor = ops.cast( + ops.prod( + ops.convert_to_tensor(ops.shape(values), dtype="int32") + ), + loss.dtype, + ) + loss = ops.divide_no_nan(loss, divisor) + loss = scale_loss_for_distribution(loss) + return loss + + +def reduce_weighted_values( + values, + sample_weight=None, + mask=None, + reduction="sum_over_batch_size", + dtype=None, +): + reduction = standardize_reduction(reduction) + + values = ops.convert_to_tensor(values, dtype=dtype) + if sample_weight is not None: + sample_weight = ops.convert_to_tensor(sample_weight, dtype=dtype) + if mask is not None: + mask = ops.convert_to_tensor(mask, dtype=dtype) + + # Merge mask and sample weight into sample weight. + sample_weight = apply_mask( + sample_weight, mask, dtype=values.dtype, reduction=reduction + ) + + if sample_weight is not None: + sample_weight = ops.cast(sample_weight, values.dtype) + # Update dimensions of `sample_weight` to match `losses`. + values, sample_weight = squeeze_or_expand_to_same_rank( + values, sample_weight + ) + values = values * sample_weight + + # Apply reduction function to the individual weighted losses. + loss = reduce_values(values, sample_weight, reduction) + return loss + + +def apply_mask(sample_weight, mask, dtype, reduction): + """Applies any mask on predictions to sample weights.""" + if mask is not None: + mask = ops.cast(mask, dtype=dtype) + if reduction in ("mean", "sum_over_batch_size"): + # Valid entries have weight `total/valid`, while invalid ones + # have 0. When summed over batch, they will be reduced to: + # + # mean(loss * sample_weight * total / valid) + # = sum(loss * sample_weight * total / valid) / total + # = sum(loss * sample_weight) / total * total / valid + # = sum(loss * sample_weight) / valid + total = ops.cast( + ops.prod(ops.convert_to_tensor(ops.shape(mask), dtype="int32")), + dtype, + ) + valid = ops.sum(mask) # May be 0! + mask *= total / (valid + backend.epsilon()) + + if sample_weight is not None: + sample_weight = ops.cast(sample_weight, dtype=dtype) + mask, sample_weight = squeeze_or_expand_to_same_rank( + mask, sample_weight + ) + sample_weight *= mask + else: + sample_weight = mask + return sample_weight + + +def scale_loss_for_distribution(value): + """Scales the given value by the number of replicas in the strategy. + + Currently, this function is only effective when using the tensorflow backend + and `tf.distribute`. + """ + if backend.backend() == "tensorflow": + import tensorflow as tf + + num_replicas = tf.distribute.get_strategy().num_replicas_in_sync + if num_replicas > 1: + value = ops.multiply( + value, ops.cast(1.0 / num_replicas, value.dtype) + ) + return value + + +def unscale_loss_for_distribution(value): + """Unscales the given value by the number of replicas in the strategy. + + Currently, this function is only effective when using the tensorflow backend + and `tf.distribute`. + """ + if backend.backend() == "tensorflow": + import tensorflow as tf + + num_replicas = tf.distribute.get_strategy().num_replicas_in_sync + if num_replicas > 1: + value = ops.multiply(value, ops.cast(num_replicas, value.dtype)) + return value diff --git a/keras/src/losses/loss_test.py b/keras/src/losses/loss_test.py new file mode 100644 index 000000000000..849e553ff9cf --- /dev/null +++ b/keras/src/losses/loss_test.py @@ -0,0 +1,289 @@ +import pickle + +import numpy as np +import pytest + +from keras.src import backend +from keras.src import dtype_policies +from keras.src import losses as losses_module +from keras.src import ops +from keras.src import testing +from keras.src.losses.loss import Loss +from keras.src.losses.loss import squeeze_or_expand_to_same_rank + + +class ExampleLoss(Loss): + def call(self, y_true, y_pred): + return (y_true - y_pred) ** 2 + + +class LossTest(testing.TestCase): + def setUp(self): + self._global_dtype_policy = dtype_policies.dtype_policy.dtype_policy() + self._floatx = backend.floatx() + return super().setUp() + + def tearDown(self): + dtype_policies.dtype_policy.set_dtype_policy(self._global_dtype_policy) + backend.set_floatx(self._floatx) + return super().tearDown() + + def test_squeeze_or_expand(self): + x1 = ops.ones((3,)) + x2 = ops.ones((3, 1)) + x1, x2 = squeeze_or_expand_to_same_rank(x1, x2) + self.assertEqual(ops.shape(x1), (3, 1)) + self.assertEqual(ops.shape(x2), (3, 1)) + + x1 = ops.ones((3, 2)) + x2 = ops.ones((3, 2, 1)) + x1, x2 = squeeze_or_expand_to_same_rank(x1, x2) + self.assertEqual(ops.shape(x1), (3, 2)) + self.assertEqual(ops.shape(x2), (3, 2)) + + x1 = ops.ones((3,)) + x2 = ops.ones((3, 1)) + x2, x1 = squeeze_or_expand_to_same_rank(x2, x1) + self.assertEqual(ops.shape(x1), (3, 1)) + self.assertEqual(ops.shape(x2), (3, 1)) + + x1 = ops.ones((3, 2)) + x2 = ops.ones((3, 2, 1)) + x2, x1 = squeeze_or_expand_to_same_rank(x2, x1) + self.assertEqual(ops.shape(x1), (3, 2)) + self.assertEqual(ops.shape(x2), (3, 2)) + + def test_reduction(self): + y_true = np.array([1.0, 0.0, 1.0, 0.0]) + y_pred = np.array([0.1, 0.2, 0.3, 0.4]) + + # No reduction + loss_fn = ExampleLoss(reduction=None) + loss = loss_fn(y_true, y_pred) + self.assertEqual(backend.standardize_dtype(loss.dtype), "float32") + self.assertAllClose((y_true - y_pred) ** 2, loss) + + # sum + loss_fn = ExampleLoss(reduction="sum") + loss = loss_fn(y_true, y_pred) + self.assertEqual(backend.standardize_dtype(loss.dtype), "float32") + self.assertAllClose(np.sum((y_true - y_pred) ** 2), loss) + + # sum_over_batch_size or mean + loss_fn = ExampleLoss(reduction="sum_over_batch_size") + loss = loss_fn(y_true, y_pred) + self.assertEqual(backend.standardize_dtype(loss.dtype), "float32") + self.assertAllClose(np.sum((y_true - y_pred) ** 2) / 4, loss) + + # bad reduction + with self.assertRaisesRegex(ValueError, "Invalid value for argument"): + ExampleLoss(reduction="abc") + + @pytest.mark.skipif( + backend.backend() == "numpy", + reason="Numpy backend does not support masking.", + ) + def test_mask(self): + mask = np.array([True, False, True, True]) + y_true = np.array([1.0, 0.0, 1.0, 0.0]) + y_pred = np.array([0.1, 0.2, 0.3, 0.4]) + + masked_y_true = np.array([1.0, 1.0, 0.0]) + masked_y_pred = np.array([0.1, 0.3, 0.4]) + + mask = ops.convert_to_tensor(mask) + y_true = ops.convert_to_tensor(y_true) + y_pred = ops.convert_to_tensor(y_pred) + y_pred._keras_mask = mask + + loss_fn = ExampleLoss() + loss = loss_fn(y_true, y_pred) + self.assertEqual(backend.standardize_dtype(loss.dtype), "float32") + self.assertAllClose( + np.sum((masked_y_true - masked_y_pred) ** 2) / 3, loss + ) + + # Test edge case where everything is masked. + mask = np.array([False, False, False, False]) + y_pred._keras_mask = mask + loss = loss_fn(y_true, y_pred) + self.assertEqual(backend.standardize_dtype(loss.dtype), "float32") + self.assertAllClose(loss, 0) # No NaN. + + def test_sample_weight(self): + sample_weight = np.array([0.4, 0.3, 0.2, 0.1]) + y_true = np.array([1.0, 0.0, 1.0, 0.0]) + y_pred = np.array([0.1, 0.2, 0.3, 0.4]) + + loss_fn = ExampleLoss() + loss = loss_fn(y_true, y_pred, sample_weight=sample_weight) + self.assertEqual(backend.standardize_dtype(loss.dtype), "float32") + self.assertAllClose( + np.sum(sample_weight * (y_true - y_pred) ** 2) / 4, loss + ) + + # Test edge case where every weight is 0. + sample_weight = np.array([0.0, 0.0, 0.0, 0.0]) + loss = loss_fn(y_true, y_pred, sample_weight=sample_weight) + self.assertEqual(backend.standardize_dtype(loss.dtype), "float32") + self.assertAllClose(loss, 0) # No NaN. + + @pytest.mark.skipif( + backend.backend() == "numpy", + reason="Numpy backend does not support masking.", + ) + def test_mask_and_sample_weight(self): + sample_weight = np.array([0.4, 0.3, 0.2, 0.1]) + y_true = np.array([1.0, 0.0, 1.0, 0.0]) + y_pred = np.array([0.1, 0.2, 0.3, 0.4]) + mask = np.array([True, False, True, True]) + + masked_sample_weight = np.array([0.4, 0.2, 0.1]) + masked_y_true = np.array([1.0, 1.0, 0.0]) + masked_y_pred = np.array([0.1, 0.3, 0.4]) + + mask = ops.convert_to_tensor(mask) + y_true = ops.convert_to_tensor(y_true) + y_pred = ops.convert_to_tensor(y_pred) + y_pred._keras_mask = mask + + loss_fn = ExampleLoss() + loss = loss_fn(y_true, y_pred, sample_weight=sample_weight) + self.assertEqual(backend.standardize_dtype(loss.dtype), "float32") + self.assertAllClose( + np.sum(masked_sample_weight * (masked_y_true - masked_y_pred) ** 2) + / 3, + loss, + ) + + @pytest.mark.skipif( + backend.backend() == "numpy", + reason="Numpy backend does not support masking.", + ) + def test_mask_and_sample_weight_rank2(self): + # check loss of inputs with duplicate rows doesn't change + sample_weight = np.array([0.4, 0.3, 0.2, 0.1]) + y_true = np.array([1.0, 0.0, 1.0, 0.0]) + y_pred = np.array([0.1, 0.2, 0.3, 0.4]) + mask = np.array([True, False, True, True]) + + mask = ops.convert_to_tensor(mask) + y_true = ops.convert_to_tensor(y_true) + y_pred = ops.convert_to_tensor(y_pred) + y_pred._keras_mask = mask + + loss_fn = ExampleLoss() + rank1_loss = loss_fn(y_true, y_pred, sample_weight=sample_weight) + + # duplicate rows + mask = ops.tile(ops.expand_dims(mask, axis=0), (2, 1)) + y_true = ops.tile(ops.expand_dims(y_true, axis=0), (2, 1)) + y_pred = ops.tile(ops.expand_dims(y_pred, axis=0), (2, 1)) + sample_weight = ops.tile(ops.expand_dims(sample_weight, axis=0), (2, 1)) + y_pred._keras_mask = mask + rank2_loss = loss_fn(y_true, y_pred, sample_weight=sample_weight) + self.assertAllClose(rank1_loss, rank2_loss) + + # @testing.parametrize( + # "uprank", ["mask", "sample_weight", "y_true", "y_pred"]) + # TODO: use parameterization decorator + @pytest.mark.skipif( + backend.backend() == "numpy", + reason="Numpy backend does not support masking.", + ) + def test_rank_adjustment(self): + for uprank in ["mask", "sample_weight", "ys"]: + sample_weight = np.array([0.4, 0.3, 0.2, 0.1]) + y_true = np.array([1.0, 0.0, 1.0, 0.0]) + y_pred = np.array([0.1, 0.2, 0.3, 0.4]) + mask = np.array([True, False, True, True]) + + if uprank == "mask": + mask = np.expand_dims(mask, -1) + elif uprank == "sample_weight": + sample_weight = np.expand_dims(sample_weight, -1) + elif uprank == "ys": + y_true = np.expand_dims(y_true, -1) + y_pred = np.expand_dims(y_pred, -1) + + masked_sample_weight = np.array([0.4, 0.2, 0.1]) + masked_y_true = np.array([1.0, 1.0, 0.0]) + masked_y_pred = np.array([0.1, 0.3, 0.4]) + + mask = ops.convert_to_tensor(mask) + y_true = ops.convert_to_tensor(y_true) + y_pred = ops.convert_to_tensor(y_pred) + y_pred._keras_mask = mask + + loss_fn = ExampleLoss() + loss = loss_fn(y_true, y_pred, sample_weight=sample_weight) + self.assertEqual(backend.standardize_dtype(loss.dtype), "float32") + self.assertAllClose( + np.sum( + masked_sample_weight * (masked_y_true - masked_y_pred) ** 2 + ) + / 3, + loss, + ) + + def test_mixed_dtypes(self): + sample_weight = np.array([0.4, 0.3, 0.2, 0.1], dtype="float64") + y_true = np.array([1.0, 0.0, 1.0, 0.0], dtype="int32") + y_pred = np.array([0.1, 0.2, 0.3, 0.4], dtype="float32") + + loss_fn = ExampleLoss() + loss = loss_fn(y_true, y_pred, sample_weight=sample_weight) + self.assertEqual(backend.standardize_dtype(loss.dtype), "float32") + self.assertAllClose( + np.sum(sample_weight * (y_true - y_pred) ** 2) / 4, + loss, + ) + + def test_pickle(self): + loss = losses_module.get("mse") + loss = pickle.loads(pickle.dumps(loss)) + self.assertEqual(loss, losses_module.mean_squared_error) + + def test_get_method(self): + loss = losses_module.get("mse") + self.assertEqual(loss, losses_module.mean_squared_error) + + loss = losses_module.get(None) + self.assertEqual(loss, None) + + with self.assertRaises(ValueError): + losses_module.get("typo") + + def test_dtype_arg(self): + y_true = np.array([1.0, 0.0, 1.0, 0.0], dtype="float32") + y_pred = np.array([0.1, 0.2, 0.3, 0.4], dtype="float32") + + # Note: we use float16 and not float64 to test this because + # JAX will map float64 to float32. + loss_fn = ExampleLoss(dtype="float16") + loss = loss_fn(y_true, y_pred) + self.assertDType(loss, "float16") + + # Test DTypePolicy for `dtype` argument + loss_fn = ExampleLoss(dtype=dtype_policies.DTypePolicy("mixed_float16")) + loss = loss_fn(y_true, y_pred) + self.assertDType(loss, "float16") + + # `dtype` setter should raise AttributeError + with self.assertRaises(AttributeError): + loss_fn.dtype = "bfloat16" + + def test_default_dtype(self): + y_true = np.array([1.0, 0.0, 1.0, 0.0], dtype="float32") + y_pred = np.array([0.1, 0.2, 0.3, 0.4], dtype="float32") + + # Defaults to `keras.config.floatx()` not global `dtype_policy` + dtype_policies.dtype_policy.set_dtype_policy("mixed_float16") + loss_fn = ExampleLoss() + loss = loss_fn(y_true, y_pred) + self.assertDType(loss, "float32") + + backend.set_floatx("float16") + loss_fn = ExampleLoss() + loss = loss_fn(y_true, y_pred) + self.assertDType(loss, backend.floatx()) diff --git a/keras/src/losses/losses.py b/keras/src/losses/losses.py new file mode 100644 index 000000000000..4bf2ba062253 --- /dev/null +++ b/keras/src/losses/losses.py @@ -0,0 +1,2764 @@ +import warnings + +from keras.src import backend +from keras.src import ops +from keras.src import tree +from keras.src.api_export import keras_export +from keras.src.losses.loss import Loss +from keras.src.losses.loss import squeeze_or_expand_to_same_rank +from keras.src.saving import serialization_lib +from keras.src.utils.numerical_utils import build_pos_neg_masks +from keras.src.utils.numerical_utils import normalize + + +class LossFunctionWrapper(Loss): + def __init__( + self, + fn, + reduction="sum_over_batch_size", + name=None, + dtype=None, + **kwargs, + ): + super().__init__(name=name, reduction=reduction, dtype=dtype) + self.fn = fn + self._fn_kwargs = kwargs + + def call(self, y_true, y_pred): + y_true_y_pred = tree.map_structure( + squeeze_or_expand_to_same_rank, y_true, y_pred + ) + y_true = tree.map_structure_up_to(y_true, lambda x: x[0], y_true_y_pred) + y_pred = tree.map_structure_up_to(y_pred, lambda x: x[1], y_true_y_pred) + return self.fn(y_true, y_pred, **self._fn_kwargs) + + def get_config(self): + config = super().get_config() + config.update({"fn": serialization_lib.serialize_keras_object(self.fn)}) + config.update(serialization_lib.serialize_keras_object(self._fn_kwargs)) + return config + + @classmethod + def from_config(cls, config): + if "fn" in config: + config = serialization_lib.deserialize_keras_object(config) + return cls(**config) + + def __repr__(self): + return f"" + + +@keras_export("keras.losses.MeanSquaredError") +class MeanSquaredError(LossFunctionWrapper): + """Computes the mean of squares of errors between labels and predictions. + + Formula: + + ```python + loss = mean(square(y_true - y_pred)) + ``` + + Args: + reduction: Type of reduction to apply to the loss. In almost all cases + this should be `"sum_over_batch_size"`. Supported options are + `"sum"`, `"sum_over_batch_size"`, `"mean"`, + `"mean_with_sample_weight"` or `None`. `"sum"` sums the loss, + `"sum_over_batch_size"` and `"mean"` sum the loss and divide by the + sample size, and `"mean_with_sample_weight"` sums the loss and + divides by the sum of the sample weights. `"none"` and `None` + perform no aggregation. Defaults to `"sum_over_batch_size"`. + name: Optional name for the loss instance. + dtype: The dtype of the loss's computations. Defaults to `None`, which + means using `keras.backend.floatx()`. `keras.backend.floatx()` is a + `"float32"` unless set to different value + (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is + provided, then the `compute_dtype` will be utilized. + """ + + def __init__( + self, + reduction="sum_over_batch_size", + name="mean_squared_error", + dtype=None, + ): + super().__init__( + mean_squared_error, name=name, reduction=reduction, dtype=dtype + ) + + def get_config(self): + return Loss.get_config(self) + + +@keras_export("keras.losses.MeanAbsoluteError") +class MeanAbsoluteError(LossFunctionWrapper): + """Computes the mean of absolute difference between labels and predictions. + + Formula: + + ```python + loss = mean(abs(y_true - y_pred)) + ``` + + Args: + reduction: Type of reduction to apply to the loss. In almost all cases + this should be `"sum_over_batch_size"`. Supported options are + `"sum"`, `"sum_over_batch_size"`, `"mean"`, + `"mean_with_sample_weight"` or `None`. `"sum"` sums the loss, + `"sum_over_batch_size"` and `"mean"` sum the loss and divide by the + sample size, and `"mean_with_sample_weight"` sums the loss and + divides by the sum of the sample weights. `"none"` and `None` + perform no aggregation. Defaults to `"sum_over_batch_size"`. + name: Optional name for the loss instance. + dtype: The dtype of the loss's computations. Defaults to `None`, which + means using `keras.backend.floatx()`. `keras.backend.floatx()` is a + `"float32"` unless set to different value + (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is + provided, then the `compute_dtype` will be utilized. + """ + + def __init__( + self, + reduction="sum_over_batch_size", + name="mean_absolute_error", + dtype=None, + ): + super().__init__( + mean_absolute_error, name=name, reduction=reduction, dtype=dtype + ) + + def get_config(self): + return Loss.get_config(self) + + +@keras_export("keras.losses.MeanAbsolutePercentageError") +class MeanAbsolutePercentageError(LossFunctionWrapper): + """Computes the mean absolute percentage error between `y_true` & `y_pred`. + + Formula: + + ```python + loss = 100 * mean(abs((y_true - y_pred) / y_true)) + ``` + + Args: + reduction: Type of reduction to apply to the loss. In almost all cases + this should be `"sum_over_batch_size"`. Supported options are + `"sum"`, `"sum_over_batch_size"`, `"mean"`, + `"mean_with_sample_weight"` or `None`. `"sum"` sums the loss, + `"sum_over_batch_size"` and `"mean"` sum the loss and divide by the + sample size, and `"mean_with_sample_weight"` sums the loss and + divides by the sum of the sample weights. `"none"` and `None` + perform no aggregation. Defaults to `"sum_over_batch_size"`. + name: Optional name for the loss instance. + dtype: The dtype of the loss's computations. Defaults to `None`, which + means using `keras.backend.floatx()`. `keras.backend.floatx()` is a + `"float32"` unless set to different value + (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is + provided, then the `compute_dtype` will be utilized. + """ + + def __init__( + self, + reduction="sum_over_batch_size", + name="mean_absolute_percentage_error", + dtype=None, + ): + super().__init__( + mean_absolute_percentage_error, + name=name, + reduction=reduction, + dtype=dtype, + ) + + def get_config(self): + return Loss.get_config(self) + + +@keras_export("keras.losses.MeanSquaredLogarithmicError") +class MeanSquaredLogarithmicError(LossFunctionWrapper): + """Computes the mean squared logarithmic error between `y_true` & `y_pred`. + + Formula: + + ```python + loss = mean(square(log(y_true + 1) - log(y_pred + 1))) + ``` + + Args: + reduction: Type of reduction to apply to the loss. In almost all cases + this should be `"sum_over_batch_size"`. Supported options are + `"sum"`, `"sum_over_batch_size"`, `"mean"`, + `"mean_with_sample_weight"` or `None`. `"sum"` sums the loss, + `"sum_over_batch_size"` and `"mean"` sum the loss and divide by the + sample size, and `"mean_with_sample_weight"` sums the loss and + divides by the sum of the sample weights. `"none"` and `None` + perform no aggregation. Defaults to `"sum_over_batch_size"`. + name: Optional name for the loss instance. + dtype: The dtype of the loss's computations. Defaults to `None`, which + means using `keras.backend.floatx()`. `keras.backend.floatx()` is a + `"float32"` unless set to different value + (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is + provided, then the `compute_dtype` will be utilized. + """ + + def __init__( + self, + reduction="sum_over_batch_size", + name="mean_squared_logarithmic_error", + dtype=None, + ): + super().__init__( + mean_squared_logarithmic_error, + name=name, + reduction=reduction, + dtype=dtype, + ) + + def get_config(self): + return Loss.get_config(self) + + +@keras_export("keras.losses.CosineSimilarity") +class CosineSimilarity(LossFunctionWrapper): + """Computes the cosine similarity between `y_true` & `y_pred`. + + Note that it is a number between -1 and 1. When it is a negative number + between -1 and 0, 0 indicates orthogonality and values closer to -1 + indicate greater similarity. This makes it usable as a loss function in a + setting where you try to maximize the proximity between predictions and + targets. If either `y_true` or `y_pred` is a zero vector, cosine similarity + will be 0 regardless of the proximity between predictions and targets. + + Formula: + + ```python + loss = -sum(l2_norm(y_true) * l2_norm(y_pred)) + ``` + + Args: + axis: The axis along which the cosine similarity is computed + (the features axis). Defaults to `-1`. + reduction: Type of reduction to apply to the loss. In almost all cases + this should be `"sum_over_batch_size"`. Supported options are + `"sum"`, `"sum_over_batch_size"`, `"mean"`, + `"mean_with_sample_weight"` or `None`. `"sum"` sums the loss, + `"sum_over_batch_size"` and `"mean"` sum the loss and divide by the + sample size, and `"mean_with_sample_weight"` sums the loss and + divides by the sum of the sample weights. `"none"` and `None` + perform no aggregation. Defaults to `"sum_over_batch_size"`. + name: Optional name for the loss instance. + dtype: The dtype of the loss's computations. Defaults to `None`, which + means using `keras.backend.floatx()`. `keras.backend.floatx()` is a + `"float32"` unless set to different value + (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is + provided, then the `compute_dtype` will be utilized. + """ + + def __init__( + self, + axis=-1, + reduction="sum_over_batch_size", + name="cosine_similarity", + dtype=None, + ): + super().__init__( + cosine_similarity, + name=name, + reduction=reduction, + dtype=dtype, + axis=axis, + ) + + def get_config(self): + return Loss.get_config(self) + + +@keras_export("keras.losses.Huber") +class Huber(LossFunctionWrapper): + """Computes the Huber loss between `y_true` & `y_pred`. + + Formula: + + ```python + for x in error: + if abs(x) <= delta: + loss.append(0.5 * x^2) + elif abs(x) > delta: + loss.append(delta * abs(x) - 0.5 * delta^2) + + loss = mean(loss, axis=-1) + ``` + See: [Huber loss](https://en.wikipedia.org/wiki/Huber_loss). + + Args: + delta: A float, the point where the Huber loss function changes from a + quadratic to linear. + reduction: Type of reduction to apply to the loss. In almost all cases + this should be `"sum_over_batch_size"`. Supported options are + `"sum"`, `"sum_over_batch_size"`, `"mean"`, + `"mean_with_sample_weight"` or `None`. `"sum"` sums the loss, + `"sum_over_batch_size"` and `"mean"` sum the loss and divide by the + sample size, and `"mean_with_sample_weight"` sums the loss and + divides by the sum of the sample weights. `"none"` and `None` + perform no aggregation. Defaults to `"sum_over_batch_size"`. + name: Optional name for the instance. + dtype: The dtype of the loss's computations. Defaults to `None`, which + means using `keras.backend.floatx()`. `keras.backend.floatx()` is a + `"float32"` unless set to different value + (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is + provided, then the `compute_dtype` will be utilized. + """ + + def __init__( + self, + delta=1.0, + reduction="sum_over_batch_size", + name="huber_loss", + dtype=None, + ): + super().__init__( + huber, + name=name, + reduction=reduction, + dtype=dtype, + delta=delta, + ) + + def get_config(self): + return Loss.get_config(self) + + +@keras_export("keras.losses.LogCosh") +class LogCosh(LossFunctionWrapper): + """Computes the logarithm of the hyperbolic cosine of the prediction error. + + Formula: + + ```python + error = y_pred - y_true + logcosh = mean(log((exp(error) + exp(-error))/2), axis=-1)` + ``` + where x is the error `y_pred - y_true`. + + Args: + reduction: Type of reduction to apply to the loss. In almost all cases + this should be `"sum_over_batch_size"`. Supported options are + `"sum"`, `"sum_over_batch_size"`, `"mean"`, + `"mean_with_sample_weight"` or `None`. `"sum"` sums the loss, + `"sum_over_batch_size"` and `"mean"` sum the loss and divide by the + sample size, and `"mean_with_sample_weight"` sums the loss and + divides by the sum of the sample weights. `"none"` and `None` + perform no aggregation. Defaults to `"sum_over_batch_size"`. + name: Optional name for the instance. + dtype: The dtype of the loss's computations. Defaults to `None`, which + means using `keras.backend.floatx()`. `keras.backend.floatx()` is a + `"float32"` unless set to different value + (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is + provided, then the `compute_dtype` will be utilized. + """ + + def __init__( + self, + reduction="sum_over_batch_size", + name="log_cosh", + dtype=None, + ): + super().__init__(log_cosh, name=name, reduction=reduction, dtype=dtype) + + def get_config(self): + return Loss.get_config(self) + + +@keras_export("keras.losses.Hinge") +class Hinge(LossFunctionWrapper): + """Computes the hinge loss between `y_true` & `y_pred`. + + Formula: + + ```python + loss = maximum(1 - y_true * y_pred, 0) + ``` + + `y_true` values are expected to be -1 or 1. If binary (0 or 1) labels are + provided we will convert them to -1 or 1. + + Args: + reduction: Type of reduction to apply to the loss. In almost all cases + this should be `"sum_over_batch_size"`. Supported options are + `"sum"`, `"sum_over_batch_size"`, `"mean"`, + `"mean_with_sample_weight"` or `None`. `"sum"` sums the loss, + `"sum_over_batch_size"` and `"mean"` sum the loss and divide by the + sample size, and `"mean_with_sample_weight"` sums the loss and + divides by the sum of the sample weights. `"none"` and `None` + perform no aggregation. Defaults to `"sum_over_batch_size"`. + name: Optional name for the loss instance. + dtype: The dtype of the loss's computations. Defaults to `None`, which + means using `keras.backend.floatx()`. `keras.backend.floatx()` is a + `"float32"` unless set to different value + (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is + provided, then the `compute_dtype` will be utilized. + """ + + def __init__( + self, + reduction="sum_over_batch_size", + name="hinge", + dtype=None, + ): + super().__init__(hinge, name=name, reduction=reduction, dtype=dtype) + + def get_config(self): + return Loss.get_config(self) + + +@keras_export("keras.losses.SquaredHinge") +class SquaredHinge(LossFunctionWrapper): + """Computes the squared hinge loss between `y_true` & `y_pred`. + + Formula: + + ```python + loss = square(maximum(1 - y_true * y_pred, 0)) + ``` + + `y_true` values are expected to be -1 or 1. If binary (0 or 1) labels are + provided we will convert them to -1 or 1. + + Args: + reduction: Type of reduction to apply to the loss. In almost all cases + this should be `"sum_over_batch_size"`. Supported options are + `"sum"`, `"sum_over_batch_size"`, `"mean"`, + `"mean_with_sample_weight"` or `None`. `"sum"` sums the loss, + `"sum_over_batch_size"` and `"mean"` sum the loss and divide by the + sample size, and `"mean_with_sample_weight"` sums the loss and + divides by the sum of the sample weights. `"none"` and `None` + perform no aggregation. Defaults to `"sum_over_batch_size"`. + name: Optional name for the loss instance. + dtype: The dtype of the loss's computations. Defaults to `None`, which + means using `keras.backend.floatx()`. `keras.backend.floatx()` is a + `"float32"` unless set to different value + (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is + provided, then the `compute_dtype` will be utilized. + """ + + def __init__( + self, reduction="sum_over_batch_size", name="squared_hinge", dtype=None + ): + super().__init__( + squared_hinge, name=name, reduction=reduction, dtype=dtype + ) + + def get_config(self): + return Loss.get_config(self) + + +@keras_export("keras.losses.CategoricalHinge") +class CategoricalHinge(LossFunctionWrapper): + """Computes the categorical hinge loss between `y_true` & `y_pred`. + + Formula: + + ```python + loss = maximum(neg - pos + 1, 0) + ``` + + where `neg=maximum((1-y_true)*y_pred)` and `pos=sum(y_true*y_pred)` + + Args: + reduction: Type of reduction to apply to the loss. In almost all cases + this should be `"sum_over_batch_size"`. Supported options are + `"sum"`, `"sum_over_batch_size"`, `"mean"`, + `"mean_with_sample_weight"` or `None`. `"sum"` sums the loss, + `"sum_over_batch_size"` and `"mean"` sum the loss and divide by the + sample size, and `"mean_with_sample_weight"` sums the loss and + divides by the sum of the sample weights. `"none"` and `None` + perform no aggregation. Defaults to `"sum_over_batch_size"`. + name: Optional name for the loss instance. + dtype: The dtype of the loss's computations. Defaults to `None`, which + means using `keras.backend.floatx()`. `keras.backend.floatx()` is a + `"float32"` unless set to different value + (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is + provided, then the `compute_dtype` will be utilized. + """ + + def __init__( + self, + reduction="sum_over_batch_size", + name="categorical_hinge", + dtype=None, + ): + super().__init__( + categorical_hinge, name=name, reduction=reduction, dtype=dtype + ) + + def get_config(self): + return Loss.get_config(self) + + +@keras_export("keras.losses.KLDivergence") +class KLDivergence(LossFunctionWrapper): + """Computes Kullback-Leibler divergence loss between `y_true` & `y_pred`. + + Formula: + + ```python + loss = y_true * log(y_true / y_pred) + ``` + + `y_true` and `y_pred` are expected to be probability + distributions, with values between 0 and 1. They will get + clipped to the `[0, 1]` range. + + Args: + reduction: Type of reduction to apply to the loss. In almost all cases + this should be `"sum_over_batch_size"`. Supported options are + `"sum"`, `"sum_over_batch_size"`, `"mean"`, + `"mean_with_sample_weight"` or `None`. `"sum"` sums the loss, + `"sum_over_batch_size"` and `"mean"` sum the loss and divide by the + sample size, and `"mean_with_sample_weight"` sums the loss and + divides by the sum of the sample weights. `"none"` and `None` + perform no aggregation. Defaults to `"sum_over_batch_size"`. + name: Optional name for the loss instance. + dtype: The dtype of the loss's computations. Defaults to `None`, which + means using `keras.backend.floatx()`. `keras.backend.floatx()` is a + `"float32"` unless set to different value + (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is + provided, then the `compute_dtype` will be utilized. + """ + + def __init__( + self, reduction="sum_over_batch_size", name="kl_divergence", dtype=None + ): + super().__init__( + kl_divergence, name=name, reduction=reduction, dtype=dtype + ) + + def get_config(self): + return Loss.get_config(self) + + +@keras_export("keras.losses.Poisson") +class Poisson(LossFunctionWrapper): + """Computes the Poisson loss between `y_true` & `y_pred`. + + Formula: + + ```python + loss = y_pred - y_true * log(y_pred) + ``` + + Args: + reduction: Type of reduction to apply to the loss. In almost all cases + this should be `"sum_over_batch_size"`. Supported options are + `"sum"`, `"sum_over_batch_size"`, `"mean"`, + `"mean_with_sample_weight"` or `None`. `"sum"` sums the loss, + `"sum_over_batch_size"` and `"mean"` sum the loss and divide by the + sample size, and `"mean_with_sample_weight"` sums the loss and + divides by the sum of the sample weights. `"none"` and `None` + perform no aggregation. Defaults to `"sum_over_batch_size"`. + name: Optional name for the loss instance. + dtype: The dtype of the loss's computations. Defaults to `None`, which + means using `keras.backend.floatx()`. `keras.backend.floatx()` is a + `"float32"` unless set to different value + (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is + provided, then the `compute_dtype` will be utilized. + """ + + def __init__( + self, reduction="sum_over_batch_size", name="poisson", dtype=None + ): + super().__init__(poisson, name=name, reduction=reduction, dtype=dtype) + + def get_config(self): + return Loss.get_config(self) + + +@keras_export("keras.losses.BinaryCrossentropy") +class BinaryCrossentropy(LossFunctionWrapper): + """Computes the cross-entropy loss between true labels and predicted labels. + + Use this cross-entropy loss for binary (0 or 1) classification applications. + The loss function requires the following inputs: + + - `y_true` (true label): This is either 0 or 1. + - `y_pred` (predicted value): This is the model's prediction, i.e, a single + floating-point value which either represents a + [logit](https://en.wikipedia.org/wiki/Logit), (i.e, value in [-inf, inf] + when `from_logits=True`) or a probability (i.e, value in [0., 1.] when + `from_logits=False`). + + Args: + from_logits: Whether to interpret `y_pred` as a tensor of + [logit](https://en.wikipedia.org/wiki/Logit) values. By default, we + assume that `y_pred` is probabilities (i.e., values in [0, 1]). + label_smoothing: Float in range [0, 1]. When 0, no smoothing occurs. + When > 0, we compute the loss between the predicted labels + and a smoothed version of the true labels, where the smoothing + squeezes the labels towards 0.5. Larger values of + `label_smoothing` correspond to heavier smoothing. + axis: The axis along which to compute crossentropy (the features axis). + Defaults to `-1`. + reduction: Type of reduction to apply to the loss. In almost all cases + this should be `"sum_over_batch_size"`. Supported options are + `"sum"`, `"sum_over_batch_size"`, `"mean"`, + `"mean_with_sample_weight"` or `None`. `"sum"` sums the loss, + `"sum_over_batch_size"` and `"mean"` sum the loss and divide by the + sample size, and `"mean_with_sample_weight"` sums the loss and + divides by the sum of the sample weights. `"none"` and `None` + perform no aggregation. Defaults to `"sum_over_batch_size"`. + name: Optional name for the loss instance. + dtype: The dtype of the loss's computations. Defaults to `None`, which + means using `keras.backend.floatx()`. `keras.backend.floatx()` is a + `"float32"` unless set to different value + (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is + provided, then the `compute_dtype` will be utilized. + + Examples: + + **Recommended Usage:** (set `from_logits=True`) + + With `compile()` API: + + ```python + model.compile( + loss=keras.losses.BinaryCrossentropy(from_logits=True), + ... + ) + ``` + + As a standalone function: + + >>> # Example 1: (batch_size = 1, number of samples = 4) + >>> y_true = np.array([0, 1, 0, 0]) + >>> y_pred = np.array([-18.6, 0.51, 2.94, -12.8]) + >>> bce = keras.losses.BinaryCrossentropy(from_logits=True) + >>> bce(y_true, y_pred) + 0.8654 + + >>> # Example 2: (batch_size = 2, number of samples = 4) + >>> y_true = np.array([[0, 1], [0, 0]]) + >>> y_pred = np.array([[-18.6, 0.51], [2.94, -12.8]]) + >>> # Using default 'auto'/'sum_over_batch_size' reduction type. + >>> bce = keras.losses.BinaryCrossentropy(from_logits=True) + >>> bce(y_true, y_pred) + 0.8654 + >>> # Using 'sample_weight' attribute + >>> bce(y_true, y_pred, sample_weight=[0.8, 0.2]) + 0.243 + >>> # Using 'sum' reduction` type. + >>> bce = keras.losses.BinaryCrossentropy(from_logits=True, + ... reduction="sum") + >>> bce(y_true, y_pred) + 1.730 + >>> # Using 'none' reduction type. + >>> bce = keras.losses.BinaryCrossentropy(from_logits=True, + ... reduction=None) + >>> bce(y_true, y_pred) + array([0.235, 1.496], dtype=float32) + + **Default Usage:** (set `from_logits=False`) + + >>> # Make the following updates to the above "Recommended Usage" section + >>> # 1. Set `from_logits=False` + >>> keras.losses.BinaryCrossentropy() # OR ...('from_logits=False') + >>> # 2. Update `y_pred` to use probabilities instead of logits + >>> y_pred = [0.6, 0.3, 0.2, 0.8] # OR [[0.6, 0.3], [0.2, 0.8]] + """ + + def __init__( + self, + from_logits=False, + label_smoothing=0.0, + axis=-1, + reduction="sum_over_batch_size", + name="binary_crossentropy", + dtype=None, + ): + super().__init__( + binary_crossentropy, + name=name, + reduction=reduction, + dtype=dtype, + from_logits=from_logits, + label_smoothing=label_smoothing, + axis=axis, + ) + self.from_logits = from_logits + self.label_smoothing = label_smoothing + self.axis = axis + + def get_config(self): + config = Loss.get_config(self) + config.update( + { + "from_logits": self.from_logits, + "label_smoothing": self.label_smoothing, + "axis": self.axis, + } + ) + return config + + +@keras_export("keras.losses.BinaryFocalCrossentropy") +class BinaryFocalCrossentropy(LossFunctionWrapper): + """Computes focal cross-entropy loss between true labels and predictions. + + Binary cross-entropy loss is often used for binary (0 or 1) classification + tasks. The loss function requires the following inputs: + + - `y_true` (true label): This is either 0 or 1. + - `y_pred` (predicted value): This is the model's prediction, i.e, a single + floating-point value which either represents a + [logit](https://en.wikipedia.org/wiki/Logit), (i.e, value in [-inf, inf] + when `from_logits=True`) or a probability (i.e, value in `[0., 1.]` when + `from_logits=False`). + + According to [Lin et al., 2018](https://arxiv.org/pdf/1708.02002.pdf), it + helps to apply a "focal factor" to down-weight easy examples and focus more + on hard examples. By default, the focal tensor is computed as follows: + + `focal_factor = (1 - output) ** gamma` for class 1 + `focal_factor = output ** gamma` for class 0 + where `gamma` is a focusing parameter. When `gamma=0`, this function is + equivalent to the binary crossentropy loss. + + Args: + apply_class_balancing: A bool, whether to apply weight balancing on the + binary classes 0 and 1. + alpha: A weight balancing factor for class 1, default is `0.25` as + mentioned in reference [Lin et al., 2018]( + https://arxiv.org/pdf/1708.02002.pdf). The weight for class 0 is + `1.0 - alpha`. + gamma: A focusing parameter used to compute the focal factor, default is + `2.0` as mentioned in the reference + [Lin et al., 2018](https://arxiv.org/pdf/1708.02002.pdf). + from_logits: Whether to interpret `y_pred` as a tensor of + [logit](https://en.wikipedia.org/wiki/Logit) values. By default, we + assume that `y_pred` are probabilities (i.e., values in `[0, 1]`). + label_smoothing: Float in `[0, 1]`. When `0`, no smoothing occurs. + When > `0`, we compute the loss between the predicted labels + and a smoothed version of the true labels, where the smoothing + squeezes the labels towards `0.5`. + Larger values of `label_smoothing` correspond to heavier smoothing. + axis: The axis along which to compute crossentropy (the features axis). + Defaults to `-1`. + reduction: Type of reduction to apply to the loss. In almost all cases + this should be `"sum_over_batch_size"`. Supported options are + `"sum"`, `"sum_over_batch_size"`, `"mean"`, + `"mean_with_sample_weight"` or `None`. `"sum"` sums the loss, + `"sum_over_batch_size"` and `"mean"` sum the loss and divide by the + sample size, and `"mean_with_sample_weight"` sums the loss and + divides by the sum of the sample weights. `"none"` and `None` + perform no aggregation. Defaults to `"sum_over_batch_size"`. + name: Optional name for the loss instance. + dtype: The dtype of the loss's computations. Defaults to `None`, which + means using `keras.backend.floatx()`. `keras.backend.floatx()` is a + `"float32"` unless set to different value + (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is + provided, then the `compute_dtype` will be utilized. + + Examples: + + With the `compile()` API: + + ```python + model.compile( + loss=keras.losses.BinaryFocalCrossentropy( + gamma=2.0, from_logits=True), + ... + ) + ``` + + As a standalone function: + + >>> # Example 1: (batch_size = 1, number of samples = 4) + >>> y_true = np.array([0, 1, 0, 0]) + >>> y_pred = np.array([-18.6, 0.51, 2.94, -12.8]) + >>> loss = keras.losses.BinaryFocalCrossentropy( + ... gamma=2, from_logits=True) + >>> loss(y_true, y_pred) + 0.691 + + >>> # Apply class weight + >>> loss = keras.losses.BinaryFocalCrossentropy( + ... apply_class_balancing=True, gamma=2, from_logits=True) + >>> loss(y_true, y_pred) + 0.51 + + >>> # Example 2: (batch_size = 2, number of samples = 4) + >>> y_true = np.array([[0, 1], [0, 0]]) + >>> y_pred = np.array([[-18.6, 0.51], [2.94, -12.8]]) + >>> # Using default 'auto'/'sum_over_batch_size' reduction type. + >>> loss = keras.losses.BinaryFocalCrossentropy( + ... gamma=3, from_logits=True) + >>> loss(y_true, y_pred) + 0.647 + + >>> # Apply class weight + >>> loss = keras.losses.BinaryFocalCrossentropy( + ... apply_class_balancing=True, gamma=3, from_logits=True) + >>> loss(y_true, y_pred) + 0.482 + + >>> # Using 'sample_weight' attribute with focal effect + >>> loss = keras.losses.BinaryFocalCrossentropy( + ... gamma=3, from_logits=True) + >>> loss(y_true, y_pred, sample_weight=[0.8, 0.2]) + 0.133 + + >>> # Apply class weight + >>> loss = keras.losses.BinaryFocalCrossentropy( + ... apply_class_balancing=True, gamma=3, from_logits=True) + >>> loss(y_true, y_pred, sample_weight=[0.8, 0.2]) + 0.097 + + >>> # Using 'sum' reduction` type. + >>> loss = keras.losses.BinaryFocalCrossentropy( + ... gamma=4, from_logits=True, + ... reduction="sum") + >>> loss(y_true, y_pred) + 1.222 + + >>> # Apply class weight + >>> loss = keras.losses.BinaryFocalCrossentropy( + ... apply_class_balancing=True, gamma=4, from_logits=True, + ... reduction="sum") + >>> loss(y_true, y_pred) + 0.914 + + >>> # Using 'none' reduction type. + >>> loss = keras.losses.BinaryFocalCrossentropy( + ... gamma=5, from_logits=True, + ... reduction=None) + >>> loss(y_true, y_pred) + array([0.0017 1.1561], dtype=float32) + + >>> # Apply class weight + >>> loss = keras.losses.BinaryFocalCrossentropy( + ... apply_class_balancing=True, gamma=5, from_logits=True, + ... reduction=None) + >>> loss(y_true, y_pred) + array([0.0004 0.8670], dtype=float32) + """ + + def __init__( + self, + apply_class_balancing=False, + alpha=0.25, + gamma=2.0, + from_logits=False, + label_smoothing=0.0, + axis=-1, + reduction="sum_over_batch_size", + name="binary_focal_crossentropy", + dtype=None, + ): + super().__init__( + binary_focal_crossentropy, + name=name, + reduction=reduction, + dtype=dtype, + apply_class_balancing=apply_class_balancing, + alpha=alpha, + gamma=gamma, + from_logits=from_logits, + label_smoothing=label_smoothing, + axis=axis, + ) + self.from_logits = from_logits + self.label_smoothing = label_smoothing + self.axis = axis + self.apply_class_balancing = apply_class_balancing + self.alpha = alpha + self.gamma = gamma + + def get_config(self): + config = Loss.get_config(self) + config.update( + { + "from_logits": self.from_logits, + "label_smoothing": self.label_smoothing, + "axis": self.axis, + "apply_class_balancing": self.apply_class_balancing, + "alpha": self.alpha, + "gamma": self.gamma, + } + ) + return config + + +@keras_export("keras.losses.CategoricalCrossentropy") +class CategoricalCrossentropy(LossFunctionWrapper): + """Computes the crossentropy loss between the labels and predictions. + + Use this crossentropy loss function when there are two or more label + classes. We expect labels to be provided in a `one_hot` representation. If + you want to provide labels as integers, please use + `SparseCategoricalCrossentropy` loss. There should be `num_classes` floating + point values per feature, i.e., the shape of both `y_pred` and `y_true` are + `[batch_size, num_classes]`. + + Args: + from_logits: Whether `y_pred` is expected to be a logits tensor. By + default, we assume that `y_pred` encodes a probability distribution. + label_smoothing: Float in [0, 1]. When > 0, label values are smoothed, + meaning the confidence on label values are relaxed. For example, if + `0.1`, use `0.1 / num_classes` for non-target labels and + `0.9 + 0.1 / num_classes` for target labels. + axis: The axis along which to compute crossentropy (the features + axis). Defaults to `-1`. + reduction: Type of reduction to apply to the loss. In almost all cases + this should be `"sum_over_batch_size"`. Supported options are + `"sum"`, `"sum_over_batch_size"`, `"mean"`, + `"mean_with_sample_weight"` or `None`. `"sum"` sums the loss, + `"sum_over_batch_size"` and `"mean"` sum the loss and divide by the + sample size, and `"mean_with_sample_weight"` sums the loss and + divides by the sum of the sample weights. `"none"` and `None` + perform no aggregation. Defaults to `"sum_over_batch_size"`. + name: Optional name for the loss instance. + dtype: The dtype of the loss's computations. Defaults to `None`, which + means using `keras.backend.floatx()`. `keras.backend.floatx()` is a + `"float32"` unless set to different value + (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is + provided, then the `compute_dtype` will be utilized. + + Examples: + + Standalone usage: + + >>> y_true = np.array([[0, 1, 0], [0, 0, 1]]) + >>> y_pred = np.array([[0.05, 0.95, 0], [0.1, 0.8, 0.1]]) + >>> # Using 'auto'/'sum_over_batch_size' reduction type. + >>> cce = keras.losses.CategoricalCrossentropy() + >>> cce(y_true, y_pred) + 1.177 + + >>> # Calling with 'sample_weight'. + >>> cce(y_true, y_pred, sample_weight=np.array([0.3, 0.7])) + 0.814 + + >>> # Using 'sum' reduction type. + >>> cce = keras.losses.CategoricalCrossentropy( + ... reduction="sum") + >>> cce(y_true, y_pred) + 2.354 + + >>> # Using 'none' reduction type. + >>> cce = keras.losses.CategoricalCrossentropy( + ... reduction=None) + >>> cce(y_true, y_pred) + array([0.0513, 2.303], dtype=float32) + + Usage with the `compile()` API: + + ```python + model.compile(optimizer='sgd', + loss=keras.losses.CategoricalCrossentropy()) + ``` + """ + + def __init__( + self, + from_logits=False, + label_smoothing=0.0, + axis=-1, + reduction="sum_over_batch_size", + name="categorical_crossentropy", + dtype=None, + ): + super().__init__( + categorical_crossentropy, + name=name, + reduction=reduction, + dtype=dtype, + from_logits=from_logits, + label_smoothing=label_smoothing, + axis=axis, + ) + self.from_logits = from_logits + self.label_smoothing = label_smoothing + self.axis = axis + + def get_config(self): + config = Loss.get_config(self) + config.update( + { + "from_logits": self.from_logits, + "label_smoothing": self.label_smoothing, + "axis": self.axis, + } + ) + return config + + +@keras_export("keras.losses.CategoricalFocalCrossentropy") +class CategoricalFocalCrossentropy(LossFunctionWrapper): + """Computes the alpha balanced focal crossentropy loss. + + Use this crossentropy loss function when there are two or more label + classes and if you want to handle class imbalance without using + `class_weights`. We expect labels to be provided in a `one_hot` + representation. + + According to [Lin et al., 2018](https://arxiv.org/pdf/1708.02002.pdf), it + helps to apply a focal factor to down-weight easy examples and focus more on + hard examples. The general formula for the focal loss (FL) + is as follows: + + `FL(p_t) = (1 - p_t) ** gamma * log(p_t)` + + where `p_t` is defined as follows: + `p_t = output if y_true == 1, else 1 - output` + + `(1 - p_t) ** gamma` is the `modulating_factor`, where `gamma` is a focusing + parameter. When `gamma` = 0, there is no focal effect on the cross entropy. + `gamma` reduces the importance given to simple examples in a smooth manner. + + The authors use alpha-balanced variant of focal loss (FL) in the paper: + `FL(p_t) = -alpha * (1 - p_t) ** gamma * log(p_t)` + + where `alpha` is the weight factor for the classes. If `alpha` = 1, the + loss won't be able to handle class imbalance properly as all + classes will have the same weight. This can be a constant or a list of + constants. If alpha is a list, it must have the same length as the number + of classes. + + The formula above can be generalized to: + `FL(p_t) = alpha * (1 - p_t) ** gamma * CrossEntropy(y_true, y_pred)` + + where minus comes from `CrossEntropy(y_true, y_pred)` (CE). + + Extending this to multi-class case is straightforward: + `FL(p_t) = alpha * (1 - p_t) ** gamma * CategoricalCE(y_true, y_pred)` + + In the snippet below, there is `num_classes` floating pointing values per + example. The shape of both `y_pred` and `y_true` are + `(batch_size, num_classes)`. + + Args: + alpha: A weight balancing factor for all classes, default is `0.25` as + mentioned in the reference. It can be a list of floats or a scalar. + In the multi-class case, alpha may be set by inverse class + frequency by using `compute_class_weight` from `sklearn.utils`. + gamma: A focusing parameter, default is `2.0` as mentioned in the + reference. It helps to gradually reduce the importance given to + simple (easy) examples in a smooth manner. + from_logits: Whether `output` is expected to be a logits tensor. By + default, we consider that `output` encodes a probability + distribution. + label_smoothing: Float in [0, 1]. When > 0, label values are smoothed, + meaning the confidence on label values are relaxed. For example, if + `0.1`, use `0.1 / num_classes` for non-target labels and + `0.9 + 0.1 / num_classes` for target labels. + axis: The axis along which to compute crossentropy (the features + axis). Defaults to `-1`. + reduction: Type of reduction to apply to the loss. In almost all cases + this should be `"sum_over_batch_size"`. Supported options are + `"sum"`, `"sum_over_batch_size"`, `"mean"`, + `"mean_with_sample_weight"` or `None`. `"sum"` sums the loss, + `"sum_over_batch_size"` and `"mean"` sum the loss and divide by the + sample size, and `"mean_with_sample_weight"` sums the loss and + divides by the sum of the sample weights. `"none"` and `None` + perform no aggregation. Defaults to `"sum_over_batch_size"`. + name: Optional name for the loss instance. + dtype: The dtype of the loss's computations. Defaults to `None`, which + means using `keras.backend.floatx()`. `keras.backend.floatx()` is a + `"float32"` unless set to different value + (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is + provided, then the `compute_dtype` will be utilized. + + Examples: + + Standalone usage: + + >>> y_true = [[0., 1., 0.], [0., 0., 1.]] + >>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]] + >>> # Using 'auto'/'sum_over_batch_size' reduction type. + >>> cce = keras.losses.CategoricalFocalCrossentropy() + >>> cce(y_true, y_pred) + 0.23315276 + + >>> # Calling with 'sample_weight'. + >>> cce(y_true, y_pred, sample_weight=np.array([0.3, 0.7])) + 0.1632 + + >>> # Using 'sum' reduction type. + >>> cce = keras.losses.CategoricalFocalCrossentropy( + ... reduction="sum") + >>> cce(y_true, y_pred) + 0.46631 + + >>> # Using 'none' reduction type. + >>> cce = keras.losses.CategoricalFocalCrossentropy( + ... reduction=None) + >>> cce(y_true, y_pred) + array([3.2058331e-05, 4.6627346e-01], dtype=float32) + + Usage with the `compile()` API: + + ```python + model.compile(optimizer='adam', + loss=keras.losses.CategoricalFocalCrossentropy()) + ``` + """ + + def __init__( + self, + alpha=0.25, + gamma=2.0, + from_logits=False, + label_smoothing=0.0, + axis=-1, + reduction="sum_over_batch_size", + name="categorical_focal_crossentropy", + dtype=None, + ): + """Initializes `CategoricalFocalCrossentropy` instance.""" + super().__init__( + categorical_focal_crossentropy, + name=name, + reduction=reduction, + dtype=dtype, + alpha=alpha, + gamma=gamma, + from_logits=from_logits, + label_smoothing=label_smoothing, + axis=axis, + ) + self.from_logits = from_logits + self.label_smoothing = label_smoothing + self.axis = axis + self.alpha = alpha + self.gamma = gamma + + def get_config(self): + config = Loss.get_config(self) + config.update( + { + "from_logits": self.from_logits, + "label_smoothing": self.label_smoothing, + "axis": self.axis, + "alpha": self.alpha, + "gamma": self.gamma, + } + ) + return config + + +@keras_export("keras.losses.SparseCategoricalCrossentropy") +class SparseCategoricalCrossentropy(LossFunctionWrapper): + """Computes the crossentropy loss between the labels and predictions. + + Use this crossentropy loss function when there are two or more label + classes. We expect labels to be provided as integers. If you want to + provide labels using `one-hot` representation, please use + `CategoricalCrossentropy` loss. There should be `# classes` floating point + values per feature for `y_pred` and a single floating point value per + feature for `y_true`. + + In the snippet below, there is a single floating point value per example for + `y_true` and `num_classes` floating pointing values per example for + `y_pred`. The shape of `y_true` is `[batch_size]` and the shape of `y_pred` + is `[batch_size, num_classes]`. + + Args: + from_logits: Whether `y_pred` is expected to be a logits tensor. By + default, we assume that `y_pred` encodes a probability distribution. + reduction: Type of reduction to apply to the loss. In almost all cases + this should be `"sum_over_batch_size"`. Supported options are + `"sum"`, `"sum_over_batch_size"`, `"mean"`, + `"mean_with_sample_weight"` or `None`. `"sum"` sums the loss, + `"sum_over_batch_size"` and `"mean"` sum the loss and divide by the + sample size, and `"mean_with_sample_weight"` sums the loss and + divides by the sum of the sample weights. `"none"` and `None` + perform no aggregation. Defaults to `"sum_over_batch_size"`. + axis: The axis along which to compute crossentropy (the features + axis). Defaults to `-1`. + name: Optional name for the loss instance. + dtype: The dtype of the loss's computations. Defaults to `None`, which + means using `keras.backend.floatx()`. `keras.backend.floatx()` is a + `"float32"` unless set to different value + (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is + provided, then the `compute_dtype` will be utilized. + + Examples: + + >>> y_true = np.array([1, 2]) + >>> y_pred = np.array([[0.05, 0.95, 0], [0.1, 0.8, 0.1]]) + >>> # Using 'auto'/'sum_over_batch_size' reduction type. + >>> scce = keras.losses.SparseCategoricalCrossentropy() + >>> scce(y_true, y_pred) + 1.177 + + >>> # Calling with 'sample_weight'. + >>> scce(y_true, y_pred, sample_weight=np.array([0.3, 0.7])) + 0.814 + + >>> # Using 'sum' reduction type. + >>> scce = keras.losses.SparseCategoricalCrossentropy( + ... reduction="sum") + >>> scce(y_true, y_pred) + 2.354 + + >>> # Using 'none' reduction type. + >>> scce = keras.losses.SparseCategoricalCrossentropy( + ... reduction=None) + >>> scce(y_true, y_pred) + array([0.0513, 2.303], dtype=float32) + + Usage with the `compile()` API: + + ```python + model.compile(optimizer='sgd', + loss=keras.losses.SparseCategoricalCrossentropy()) + ``` + """ + + def __init__( + self, + from_logits=False, + ignore_class=None, + reduction="sum_over_batch_size", + axis=-1, + name="sparse_categorical_crossentropy", + dtype=None, + ): + super().__init__( + sparse_categorical_crossentropy, + name=name, + reduction=reduction, + dtype=dtype, + from_logits=from_logits, + ignore_class=ignore_class, + axis=axis, + ) + self.from_logits = from_logits + self.ignore_class = ignore_class + + def get_config(self): + config = Loss.get_config(self) + config.update( + { + "from_logits": self.from_logits, + "ignore_class": self.ignore_class, + } + ) + return config + + +@keras_export("keras.losses.CTC") +class CTC(LossFunctionWrapper): + """CTC (Connectionist Temporal Classification) loss. + + Args: + reduction: Type of reduction to apply to the loss. In almost all cases + this should be `"sum_over_batch_size"`. Supported options are + `"sum"`, `"sum_over_batch_size"`, `"mean"`, + `"mean_with_sample_weight"` or `None`. `"sum"` sums the loss, + `"sum_over_batch_size"` and `"mean"` sum the loss and divide by the + sample size, and `"mean_with_sample_weight"` sums the loss and + divides by the sum of the sample weights. `"none"` and `None` + perform no aggregation. Defaults to `"sum_over_batch_size"`. + name: Optional name for the loss instance. + dtype: The dtype of the loss's computations. Defaults to `None`, which + means using `keras.backend.floatx()`. `keras.backend.floatx()` is a + `"float32"` unless set to different value + (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is + provided, then the `compute_dtype` will be utilized. + """ + + def __init__(self, reduction="sum_over_batch_size", name="ctc", dtype=None): + super().__init__(ctc, name=name, reduction=reduction, dtype=dtype) + + def get_config(self): + return Loss.get_config(self) + + +@keras_export("keras.losses.Dice") +class Dice(LossFunctionWrapper): + """Computes the Dice loss value between `y_true` and `y_pred`. + + Formula: + ```python + loss = 1 - (2 * sum(y_true * y_pred)) / (sum(y_true) + sum(y_pred)) + ``` + + Args: + reduction: Type of reduction to apply to the loss. In almost all cases + this should be `"sum_over_batch_size"`. Supported options are + `"sum"`, `"sum_over_batch_size"`, `"mean"`, + `"mean_with_sample_weight"` or `None`. `"sum"` sums the loss, + `"sum_over_batch_size"` and `"mean"` sum the loss and divide by the + sample size, and `"mean_with_sample_weight"` sums the loss and + divides by the sum of the sample weights. `"none"` and `None` + perform no aggregation. Defaults to `"sum_over_batch_size"`. + name: Optional name for the loss instance. + axis: Tuple for which dimensions the loss is calculated. Defaults to + `None`. + dtype: The dtype of the loss's computations. Defaults to `None`, which + means using `keras.backend.floatx()`. `keras.backend.floatx()` is a + `"float32"` unless set to different value + (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is + provided, then the `compute_dtype` will be utilized. + + Returns: + Dice loss value. + + Example: + + >>> y_true = [[[[1.0], [1.0]], [[0.0], [0.0]]], + ... [[[1.0], [1.0]], [[0.0], [0.0]]]] + >>> y_pred = [[[[0.0], [1.0]], [[0.0], [1.0]]], + ... [[[0.4], [0.0]], [[0.0], [0.9]]]] + >>> axis = (1, 2, 3) + >>> loss = keras.losses.Dice(axis=axis, reduction=None)(y_true, y_pred) + >>> assert loss.shape == (2,) + >>> loss + array([0.5, 0.75757575], shape=(2,), dtype=float32) + + >>> loss = keras.losses.Dice()(y_true, y_pred) + >>> assert loss.shape == () + >>> loss + array(0.6164384, shape=(), dtype=float32) + + >>> y_true = np.array(y_true) + >>> y_pred = np.array(y_pred) + >>> loss = keras.losses.Dice(axis=axis, reduction=None)(y_true, y_pred) + >>> assert loss.shape == (2,) + >>> loss + array([0.5, 0.75757575], shape=(2,), dtype=float32) + + """ + + def __init__( + self, + reduction="sum_over_batch_size", + name="dice", + axis=None, + dtype=None, + ): + super().__init__( + dice, name=name, reduction=reduction, dtype=dtype, axis=axis + ) + self.axis = axis + + def get_config(self): + config = Loss.get_config(self) + config.update({"axis": self.axis}) + return config + + +@keras_export("keras.losses.Tversky") +class Tversky(LossFunctionWrapper): + """Computes the Tversky loss value between `y_true` and `y_pred`. + + This loss function is weighted by the alpha and beta coefficients + that penalize false positives and false negatives. + + With `alpha=0.5` and `beta=0.5`, the loss value becomes equivalent to + Dice Loss. + + Args: + alpha: The coefficient controlling incidence of false positives. + Defaults to `0.5`. + beta: The coefficient controlling incidence of false negatives. + Defaults to `0.5`. + reduction: Type of reduction to apply to the loss. In almost all cases + this should be `"sum_over_batch_size"`. Supported options are + `"sum"`, `"sum_over_batch_size"`, `"mean"`, + `"mean_with_sample_weight"` or `None`. `"sum"` sums the loss, + `"sum_over_batch_size"` and `"mean"` sum the loss and divide by the + sample size, and `"mean_with_sample_weight"` sums the loss and + divides by the sum of the sample weights. `"none"` and `None` + perform no aggregation. Defaults to `"sum_over_batch_size"`. + name: Optional name for the loss instance. + dtype: The dtype of the loss's computations. Defaults to `None`, which + means using `keras.backend.floatx()`. `keras.backend.floatx()` is a + `"float32"` unless set to different value + (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is + provided, then the `compute_dtype` will be utilized. + + Returns: + Tversky loss value. + + Reference: + + - [Salehi et al., 2017](https://arxiv.org/abs/1706.05721) + """ + + def __init__( + self, + alpha=0.5, + beta=0.5, + reduction="sum_over_batch_size", + name="tversky", + axis=None, + dtype=None, + ): + super().__init__( + tversky, + name=name, + reduction=reduction, + dtype=dtype, + alpha=alpha, + beta=beta, + axis=axis, + ) + self.alpha = alpha + self.beta = beta + self.axis = axis + + def get_config(self): + config = Loss.get_config(self) + config.update( + {"alpha": self.alpha, "beta": self.beta, "axis": self.axis} + ) + return config + + +@keras_export("keras.losses.Circle") +class Circle(LossFunctionWrapper): + """Computes Circle Loss between integer labels and L2-normalized embeddings. + + This is a metric learning loss designed to minimize within-class distance + and maximize between-class distance in a flexible manner by dynamically + adjusting the penalty strength based on optimization status of each + similarity score. + + To use Circle Loss effectively, the model should output embeddings without + an activation function (such as a `Dense` layer with `activation=None`) + followed by UnitNormalization layer to ensure unit-norm embeddings. + + Args: + gamma: Scaling factor that determines the largest scale of each + similarity score. Defaults to `80`. + margin: The relaxation factor, below this distance, negatives are + up weighted and positives are down weighted. Similarly, above this + distance negatives are down weighted and positive are up weighted. + Defaults to `0.4`. + remove_diagonal: Boolean, whether to remove self-similarities from the + positive mask. Defaults to `True`. + reduction: Type of reduction to apply to the loss. In almost all cases + this should be `"sum_over_batch_size"`. Supported options are + `"sum"`, `"sum_over_batch_size"`, `"mean"`, + `"mean_with_sample_weight"` or `None`. `"sum"` sums the loss, + `"sum_over_batch_size"` and `"mean"` sum the loss and divide by the + sample size, and `"mean_with_sample_weight"` sums the loss and + divides by the sum of the sample weights. `"none"` and `None` + perform no aggregation. Defaults to `"sum_over_batch_size"`. + name: Optional name for the loss instance. + dtype: The dtype of the loss's computations. Defaults to `None`, which + means using `keras.backend.floatx()`. `keras.backend.floatx()` is a + `"float32"` unless set to different value + (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is + provided, then the `compute_dtype` will be utilized. + + Examples: + + Usage with the `compile()` API: + + ```python + model = models.Sequential([ + keras.layers.Input(shape=(224, 224, 3)), + keras.layers.Conv2D(16, (3, 3), activation='relu'), + keras.layers.Flatten(), + keras.layers.Dense(64, activation=None), # No activation + keras.layers.UnitNormalization() # L2 normalization + ]) + + model.compile(optimizer="adam", loss=keras.losses.Circle()) + ``` + + Reference: + - [Yifan Sun et al., 2020](https://arxiv.org/abs/2002.10857) + + """ + + def __init__( + self, + gamma=80.0, + margin=0.4, + remove_diagonal=True, + reduction="sum_over_batch_size", + name="circle", + dtype=None, + ): + super().__init__( + circle, + name=name, + reduction=reduction, + dtype=dtype, + gamma=gamma, + margin=margin, + remove_diagonal=remove_diagonal, + ) + self.gamma = gamma + self.margin = margin + self.remove_diagonal = remove_diagonal + + def get_config(self): + config = Loss.get_config(self) + config.update( + { + "gamma": self.gamma, + "margin": self.margin, + "remove_diagonal": self.remove_diagonal, + } + ) + return config + + +@keras_export("keras.losses.CategoricalGeneralizedCrossEntropy") +class CategoricalGeneralizedCrossEntropy(LossFunctionWrapper): + """Computes the Generalized Cross Entropy loss between `y_true` & `y_pred`. + + Generalized Cross Entropy (GCE) is a noise-robust loss function + that provides better robustness against noisy labels than + standard cross entropy. + It generalizes both cross entropy and mean absolute error through + the parameter q, where values closer to 1 make the loss more robust + to noisy labels. + + Formula: + ```python + loss = (1 - p**q) / q + ``` + where `p` is the predicted probability for the true class and `q` + is the noise parameter. + + Args: + q: Float in range `(0, 1)`. It is the noise parameter. + Controls the behavior of the loss: + - As `q` approaches 0: Behaves more like cross entropy + - As `q` approaches 1: Behaves more like mean absolute error + Defaults to `0.5` + reduction: Type of reduction to apply to the loss. In almost all cases + this should be `"sum_over_batch_size"`. Supported options are + `"sum"`, `"sum_over_batch_size"`, `"mean"`, + `"mean_with_sample_weight"` or `None`. `"sum"` sums the loss, + `"sum_over_batch_size"` and `"mean"` sum the loss and divide by the + sample size, and `"mean_with_sample_weight"` sums the loss and + divides by the sum of the sample weights. `"none"` and `None` + perform no aggregation. Defaults to `"sum_over_batch_size"`. + name: Optional name for the loss instance. + dtype: The dtype of the loss's computations. Defaults to `None`, which + means using `keras.backend.floatx()`. `keras.backend.floatx()` is a + `"float32"` unless set to different value + (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is + provided, then the `compute_dtype` will be utilized. + + Example: + ```python + y_true = np.array([0, 1, 0, 1]) + y_pred = np.array([[0.7, 0.3], [0.2, 0.8], [0.6, 0.4], [0.4, 0.6]]) + keras.losses.CategoricalGeneralizedCrossEntropy()(y_true, y_pred) + ``` + + References: + - [Zhang, Sabuncu, 2018](https://arxiv.org/abs/1805.07836) + ("Generalized Cross Entropy Loss for Training + Deep Neural Networks with Noisy Labels") + """ + + def __init__( + self, + q=0.5, + reduction="sum_over_batch_size", + name="categorical_generalized_cross_entropy", + dtype=None, + ): + if not 0 < q < 1: + raise ValueError("q must be in the interval (0, 1)") + super().__init__( + categorical_generalized_cross_entropy, + name=name, + reduction=reduction, + dtype=dtype, + q=q, + ) + self.q = q + + def get_config(self): + config = Loss.get_config(self) + config.update( + { + "q": self.q, + } + ) + return config + + +def convert_binary_labels_to_hinge(y_true): + """Converts binary labels into -1/1 for hinge loss/metric calculation.""" + are_zeros = ops.equal(y_true, 0) + are_ones = ops.equal(y_true, 1) + is_binary = ops.all((ops.logical_or(are_zeros, are_ones))) + + def _convert_binary_labels(): + # Convert the binary labels to -1 or 1. + return 2.0 * y_true - 1.0 + + def _return_labels_unconverted(): + # Returns the labels unchanged if they are non-binary + return y_true + + updated_y_true = ops.cond( + is_binary, _convert_binary_labels, _return_labels_unconverted + ) + return updated_y_true + + +@keras_export( + [ + "keras.metrics.hinge", + "keras.losses.hinge", + ] +) +def hinge(y_true, y_pred): + """Computes the hinge loss between `y_true` & `y_pred`. + + Formula: + + ```python + loss = mean(maximum(1 - y_true * y_pred, 0), axis=-1) + ``` + + Args: + y_true: The ground truth values. `y_true` values are expected to be -1 + or 1. If binary (0 or 1) labels are provided they will be converted + to -1 or 1 with shape = `[batch_size, d0, .. dN]`. + y_pred: The predicted values with shape = `[batch_size, d0, .. dN]`. + + Returns: + Hinge loss values with shape = `[batch_size, d0, .. dN-1]`. + + Example: + + >>> y_true = np.random.choice([-1, 1], size=(2, 3)) + >>> y_pred = np.random.random(size=(2, 3)) + >>> loss = keras.losses.hinge(y_true, y_pred) + """ + y_pred = ops.convert_to_tensor(y_pred) + y_true = ops.cast(y_true, dtype=y_pred.dtype) + y_true = ops.convert_to_tensor(y_true) + y_true = convert_binary_labels_to_hinge(y_true) + return ops.mean(ops.maximum(1.0 - y_true * y_pred, 0.0), axis=-1) + + +@keras_export( + [ + "keras.metrics.squared_hinge", + "keras.losses.squared_hinge", + ] +) +def squared_hinge(y_true, y_pred): + """Computes the squared hinge loss between `y_true` & `y_pred`. + + Formula: + + ```python + loss = mean(square(maximum(1 - y_true * y_pred, 0)), axis=-1) + ``` + + Args: + y_true: The ground truth values. `y_true` values are expected to be -1 + or 1. If binary (0 or 1) labels are provided we will convert them + to -1 or 1 with shape = `[batch_size, d0, .. dN]`. + y_pred: The predicted values with shape = `[batch_size, d0, .. dN]`. + + Returns: + Squared hinge loss values with shape = `[batch_size, d0, .. dN-1]`. + + Example: + + >>> y_true = np.random.choice([-1, 1], size=(2, 3)) + >>> y_pred = np.random.random(size=(2, 3)) + >>> loss = keras.losses.squared_hinge(y_true, y_pred) + """ + y_pred = ops.convert_to_tensor(y_pred) + y_true = ops.cast(y_true, y_pred.dtype) + y_true = convert_binary_labels_to_hinge(y_true) + return ops.mean( + ops.square(ops.maximum(1.0 - y_true * y_pred, 0.0)), axis=-1 + ) + + +@keras_export( + [ + "keras.metrics.categorical_hinge", + "keras.losses.categorical_hinge", + ] +) +def categorical_hinge(y_true, y_pred): + """Computes the categorical hinge loss between `y_true` & `y_pred`. + + Formula: + + ```python + loss = maximum(neg - pos + 1, 0) + ``` + + where `neg=maximum((1-y_true)*y_pred)` and `pos=sum(y_true*y_pred)` + + Args: + y_true: The ground truth values. `y_true` values are expected to be + either `{-1, +1}` or `{0, 1}` (i.e. a one-hot-encoded tensor) with + shape = `[batch_size, d0, .. dN]`. + y_pred: The predicted values with shape = `[batch_size, d0, .. dN]`. + + Returns: + Categorical hinge loss values with shape = `[batch_size, d0, .. dN-1]`. + + Example: + + >>> y_true = np.random.randint(0, 3, size=(2,)) + >>> y_true = np.eye(np.max(y_true) + 1)[y_true] + >>> y_pred = np.random.random(size=(2, 3)) + >>> loss = keras.losses.categorical_hinge(y_true, y_pred) + """ + y_pred = ops.convert_to_tensor(y_pred) + y_true = ops.cast(y_true, y_pred.dtype) + pos = ops.sum(y_true * y_pred, axis=-1) + neg = ops.max((1.0 - y_true) * y_pred, axis=-1) + zero = ops.cast(0.0, y_pred.dtype) + return ops.maximum(neg - pos + 1.0, zero) + + +@keras_export( + [ + "keras.metrics.mean_squared_error", + "keras.losses.mean_squared_error", + # Legacy aliases + "keras._legacy.losses.mse", + "keras._legacy.losses.MSE", + "keras._legacy.metrics.mse", + "keras._legacy.metrics.MSE", + ] +) +def mean_squared_error(y_true, y_pred): + """Computes the mean squared error between labels and predictions. + + Formula: + + ```python + loss = mean(square(y_true - y_pred), axis=-1) + ``` + + Example: + + >>> y_true = np.random.randint(0, 2, size=(2, 3)) + >>> y_pred = np.random.random(size=(2, 3)) + >>> loss = keras.losses.mean_squared_error(y_true, y_pred) + + Args: + y_true: Ground truth values with shape = `[batch_size, d0, .. dN]`. + y_pred: The predicted values with shape = `[batch_size, d0, .. dN]`. + + Returns: + Mean squared error values with shape = `[batch_size, d0, .. dN-1]`. + """ + y_pred = ops.convert_to_tensor(y_pred) + y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype) + y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred) + return ops.mean(ops.square(y_true - y_pred), axis=-1) + + +@keras_export( + [ + "keras.metrics.mean_absolute_error", + "keras.losses.mean_absolute_error", + # Legacy aliases + "keras._legacy.losses.MAE", + "keras._legacy.losses.mae", + "keras._legacy.metrics.MAE", + "keras._legacy.metrics.mae", + ] +) +def mean_absolute_error(y_true, y_pred): + """Computes the mean absolute error between labels and predictions. + + ```python + loss = mean(abs(y_true - y_pred), axis=-1) + ``` + + Args: + y_true: Ground truth values with shape = `[batch_size, d0, .. dN]`. + y_pred: The predicted values with shape = `[batch_size, d0, .. dN]`. + + Returns: + Mean absolute error values with shape = `[batch_size, d0, .. dN-1]`. + + Example: + + >>> y_true = np.random.randint(0, 2, size=(2, 3)) + >>> y_pred = np.random.random(size=(2, 3)) + >>> loss = keras.losses.mean_absolute_error(y_true, y_pred) + """ + y_pred = ops.convert_to_tensor(y_pred) + y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype) + y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred) + return ops.mean(ops.abs(y_true - y_pred), axis=-1) + + +@keras_export( + [ + "keras.metrics.mean_absolute_percentage_error", + "keras.losses.mean_absolute_percentage_error", + # Legacy aliases + "keras._legacy.losses.mape", + "keras._legacy.losses.MAPE", + "keras._legacy.metrics.mape", + "keras._legacy.metrics.MAPE", + ] +) +def mean_absolute_percentage_error(y_true, y_pred): + """Computes the mean absolute percentage error between `y_true` & `y_pred`. + + Formula: + + ```python + loss = 100 * mean(abs((y_true - y_pred) / y_true), axis=-1) + ``` + + Division by zero is prevented by dividing by `maximum(y_true, epsilon)` + where `epsilon = keras.backend.epsilon()` + (default to `1e-7`). + + Args: + y_true: Ground truth values with shape = `[batch_size, d0, .. dN]`. + y_pred: The predicted values with shape = `[batch_size, d0, .. dN]`. + + Returns: + Mean absolute percentage error values with shape = `[batch_size, d0, .. + dN-1]`. + + Example: + + >>> y_true = np.random.random(size=(2, 3)) + >>> y_pred = np.random.random(size=(2, 3)) + >>> loss = keras.losses.mean_absolute_percentage_error(y_true, y_pred) + """ + y_pred = ops.convert_to_tensor(y_pred) + y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype) + epsilon = ops.convert_to_tensor(backend.epsilon(), dtype=y_pred.dtype) + y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred) + diff = ops.abs((y_true - y_pred) / ops.maximum(ops.abs(y_true), epsilon)) + return 100.0 * ops.mean(diff, axis=-1) + + +@keras_export( + [ + "keras.metrics.mean_squared_logarithmic_error", + "keras.losses.mean_squared_logarithmic_error", + # Legacy aliases + "keras._legacy.losses.msle", + "keras._legacy.losses.MSLE", + "keras._legacy.metrics.msle", + "keras._legacy.metrics.MSLE", + ] +) +def mean_squared_logarithmic_error(y_true, y_pred): + """Computes the mean squared logarithmic error between `y_true` & `y_pred`. + + Formula: + + ```python + loss = mean(square(log(y_true + 1) - log(y_pred + 1)), axis=-1) + ``` + + Note that `y_pred` and `y_true` cannot be less or equal to 0. Negative + values and 0 values will be replaced with `keras.backend.epsilon()` + (default to `1e-7`). + + Args: + y_true: Ground truth values with shape = `[batch_size, d0, .. dN]`. + y_pred: The predicted values with shape = `[batch_size, d0, .. dN]`. + + Returns: + Mean squared logarithmic error values with shape = `[batch_size, d0, .. + dN-1]`. + + Example: + + >>> y_true = np.random.randint(0, 2, size=(2, 3)) + >>> y_pred = np.random.random(size=(2, 3)) + >>> loss = keras.losses.mean_squared_logarithmic_error(y_true, y_pred) + """ + epsilon = ops.convert_to_tensor(backend.epsilon()) + y_pred = ops.convert_to_tensor(y_pred) + y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype) + y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred) + first_log = ops.log(ops.maximum(y_pred, epsilon) + 1.0) + second_log = ops.log(ops.maximum(y_true, epsilon) + 1.0) + return ops.mean(ops.square(first_log - second_log), axis=-1) + + +@keras_export("keras.losses.cosine_similarity") +def cosine_similarity(y_true, y_pred, axis=-1): + """Computes the cosine similarity between labels and predictions. + + Formula: + ```python + loss = -sum(l2_norm(y_true) * l2_norm(y_pred)) + ``` + + Note that it is a number between -1 and 1. When it is a negative number + between -1 and 0, 0 indicates orthogonality and values closer to -1 + indicate greater similarity. This makes it usable as a loss function in a + setting where you try to maximize the proximity between predictions and + targets. If either `y_true` or `y_pred` is a zero vector, cosine + similarity will be 0 regardless of the proximity between predictions + and targets. + + Args: + y_true: Tensor of true targets. + y_pred: Tensor of predicted targets. + axis: Axis along which to determine similarity. Defaults to `-1`. + + Returns: + Cosine similarity tensor. + + Example: + + >>> y_true = [[0., 1.], [1., 1.], [1., 1.]] + >>> y_pred = [[1., 0.], [1., 1.], [-1., -1.]] + >>> loss = keras.losses.cosine_similarity(y_true, y_pred, axis=-1) + [-0., -0.99999994, 0.99999994] + """ + y_pred = ops.convert_to_tensor(y_pred) + y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype) + y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred) + y_pred = normalize(y_pred, axis=axis) + y_true = normalize(y_true, axis=axis) + return -ops.sum(y_true * y_pred, axis=axis) + + +@keras_export(["keras.losses.huber", "keras.metrics.huber"]) +def huber(y_true, y_pred, delta=1.0): + """Computes Huber loss value. + + Formula: + ```python + for x in error: + if abs(x) <= delta: + loss.append(0.5 * x^2) + elif abs(x) > delta: + loss.append(delta * abs(x) - 0.5 * delta^2) + + loss = mean(loss, axis=-1) + ``` + See: [Huber loss](https://en.wikipedia.org/wiki/Huber_loss). + + Example: + + >>> y_true = [[0, 1], [0, 0]] + >>> y_pred = [[0.6, 0.4], [0.4, 0.6]] + >>> loss = keras.losses.huber(y_true, y_pred) + 0.155 + + + Args: + y_true: tensor of true targets. + y_pred: tensor of predicted targets. + delta: A float, the point where the Huber loss function changes from a + quadratic to linear. Defaults to `1.0`. + + Returns: + Tensor with one scalar loss entry per sample. + """ + y_pred = ops.convert_to_tensor(y_pred) + y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype) + y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred) + delta = ops.convert_to_tensor(delta, dtype=y_pred.dtype) + error = ops.subtract(y_pred, y_true) + abs_error = ops.abs(error) + half = ops.convert_to_tensor(0.5, dtype=abs_error.dtype) + return ops.mean( + ops.where( + abs_error <= delta, + half * ops.square(error), + delta * abs_error - half * ops.square(delta), + ), + axis=-1, + ) + + +@keras_export( + [ + "keras.losses.log_cosh", + "keras.metrics.log_cosh", + # Legacy aliases + "keras._legacy.losses.logcosh", + "keras._legacy.metrics.logcosh", + ] +) +def log_cosh(y_true, y_pred): + """Logarithm of the hyperbolic cosine of the prediction error. + + Formula: + ```python + loss = mean(log(cosh(y_pred - y_true)), axis=-1) + ``` + + Note that `log(cosh(x))` is approximately equal to `(x ** 2) / 2` for small + `x` and to `abs(x) - log(2)` for large `x`. This means that 'logcosh' works + mostly like the mean squared error, but will not be so strongly affected by + the occasional wildly incorrect prediction. + + Example: + + >>> y_true = [[0., 1.], [0., 0.]] + >>> y_pred = [[1., 1.], [0., 0.]] + >>> loss = keras.losses.log_cosh(y_true, y_pred) + 0.108 + + Args: + y_true: Ground truth values with shape = `[batch_size, d0, .. dN]`. + y_pred: The predicted values with shape = `[batch_size, d0, .. dN]`. + + Returns: + Logcosh error values with shape = `[batch_size, d0, .. dN-1]`. + """ + y_pred = ops.convert_to_tensor(y_pred) + y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype) + y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred) + log2 = ops.convert_to_tensor(ops.log(2.0), dtype=y_pred.dtype) + + def _logcosh(x): + return x + ops.softplus(x * -2.0) - log2 + + return ops.mean(_logcosh(y_pred - y_true), axis=-1) + + +@keras_export( + [ + "keras.metrics.kl_divergence", + "keras.losses.kl_divergence", + # Legacy aliases + "keras._legacy.losses.KLD", + "keras._legacy.losses.kld", + "keras._legacy.losses.kullback_leibler_divergence", + "keras._legacy.metrics.KLD", + "keras._legacy.metrics.kld", + "keras._legacy.metrics.kullback_leibler_divergence", + ] +) +def kl_divergence(y_true, y_pred): + """Computes Kullback-Leibler divergence loss between `y_true` & `y_pred`. + + Formula: + + ```python + loss = y_true * log(y_true / y_pred) + ``` + + `y_true` and `y_pred` are expected to be probability + distributions, with values between 0 and 1. They will get + clipped to the `[0, 1]` range. + + Args: + y_true: Tensor of true targets. + y_pred: Tensor of predicted targets. + + Returns: + KL Divergence loss values with shape = `[batch_size, d0, .. dN-1]`. + + Example: + + >>> y_true = np.random.randint(0, 2, size=(2, 3)).astype(np.float32) + >>> y_pred = np.random.random(size=(2, 3)) + >>> loss = keras.losses.kl_divergence(y_true, y_pred) + >>> assert loss.shape == (2,) + >>> y_true = ops.clip(y_true, 1e-7, 1) + >>> y_pred = ops.clip(y_pred, 1e-7, 1) + >>> assert np.array_equal( + ... loss, np.sum(y_true * np.log(y_true / y_pred), axis=-1)) + """ + y_pred = ops.convert_to_tensor(y_pred) + y_true = ops.convert_to_tensor(y_true, y_pred.dtype) + y_true = ops.clip(y_true, backend.epsilon(), 1) + y_pred = ops.clip(y_pred, backend.epsilon(), 1) + return ops.sum(y_true * ops.log(y_true / y_pred), axis=-1) + + +@keras_export( + [ + "keras.metrics.poisson", + "keras.losses.poisson", + ] +) +def poisson(y_true, y_pred): + """Computes the Poisson loss between y_true and y_pred. + + Formula: + + ```python + loss = y_pred - y_true * log(y_pred) + ``` + + Args: + y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`. + y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`. + + Returns: + Poisson loss values with shape = `[batch_size, d0, .. dN-1]`. + + Example: + + >>> y_true = np.random.randint(0, 2, size=(2, 3)) + >>> y_pred = np.random.random(size=(2, 3)) + >>> loss = keras.losses.poisson(y_true, y_pred) + >>> assert loss.shape == (2,) + >>> y_pred = y_pred + 1e-7 + >>> assert np.allclose( + ... loss, np.mean(y_pred - y_true * np.log(y_pred), axis=-1), + ... atol=1e-5) + """ + y_pred = ops.convert_to_tensor(y_pred) + y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype) + epsilon = ops.convert_to_tensor(backend.epsilon(), dtype=y_pred.dtype) + return ops.mean(y_pred - y_true * ops.log(y_pred + epsilon), axis=-1) + + +@keras_export( + [ + "keras.metrics.categorical_crossentropy", + "keras.losses.categorical_crossentropy", + ] +) +def categorical_crossentropy( + y_true, y_pred, from_logits=False, label_smoothing=0.0, axis=-1 +): + """Computes the categorical crossentropy loss. + + Args: + y_true: Tensor of one-hot true targets. + y_pred: Tensor of predicted targets. + from_logits: Whether `y_pred` is expected to be a logits tensor. By + default, we assume that `y_pred` encodes a probability distribution. + label_smoothing: Float in [0, 1]. If > `0` then smooth the labels. For + example, if `0.1`, use `0.1 / num_classes` for non-target labels + and `0.9 + 0.1 / num_classes` for target labels. + axis: Defaults to `-1`. The dimension along which the entropy is + computed. + + Returns: + Categorical crossentropy loss value. + + Example: + + >>> y_true = [[0, 1, 0], [0, 0, 1]] + >>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]] + >>> loss = keras.losses.categorical_crossentropy(y_true, y_pred) + >>> assert loss.shape == (2,) + >>> loss + array([0.0513, 2.303], dtype=float32) + """ + if isinstance(axis, bool): + raise ValueError( + "`axis` must be of type `int`. " + f"Received: axis={axis} of type {type(axis)}" + ) + y_pred = ops.convert_to_tensor(y_pred) + y_true = ops.cast(y_true, y_pred.dtype) + + if y_pred.shape[-1] == 1: + warnings.warn( + "In loss categorical_crossentropy, expected " + "y_pred.shape to be (batch_size, num_classes) " + f"with num_classes > 1. Received: y_pred.shape={y_pred.shape}. " + "Consider using 'binary_crossentropy' if you only have 2 classes.", + SyntaxWarning, + stacklevel=2, + ) + + if label_smoothing: + num_classes = ops.cast(ops.shape(y_true)[-1], y_pred.dtype) + y_true = y_true * (1.0 - label_smoothing) + ( + label_smoothing / num_classes + ) + + return ops.categorical_crossentropy( + y_true, y_pred, from_logits=from_logits, axis=axis + ) + + +@keras_export( + [ + "keras.metrics.categorical_focal_crossentropy", + "keras.losses.categorical_focal_crossentropy", + ] +) +def categorical_focal_crossentropy( + y_true, + y_pred, + alpha=0.25, + gamma=2.0, + from_logits=False, + label_smoothing=0.0, + axis=-1, +): + """Computes the categorical focal crossentropy loss. + + Args: + y_true: Tensor of one-hot true targets. + y_pred: Tensor of predicted targets. + alpha: A weight balancing factor for all classes, default is `0.25` as + mentioned in the reference. It can be a list of floats or a scalar. + In the multi-class case, alpha may be set by inverse class + frequency by using `compute_class_weight` from `sklearn.utils`. + gamma: A focusing parameter, default is `2.0` as mentioned in the + reference. It helps to gradually reduce the importance given to + simple examples in a smooth manner. When `gamma` = 0, there is + no focal effect on the categorical crossentropy. + from_logits: Whether `y_pred` is expected to be a logits tensor. By + default, we assume that `y_pred` encodes a probability + distribution. + label_smoothing: Float in [0, 1]. If > `0` then smooth the labels. For + example, if `0.1`, use `0.1 / num_classes` for non-target labels + and `0.9 + 0.1 / num_classes` for target labels. + axis: Defaults to `-1`. The dimension along which the entropy is + computed. + + Returns: + Categorical focal crossentropy loss value. + + Example: + + >>> y_true = [[0, 1, 0], [0, 0, 1]] + >>> y_pred = [[0.05, 0.9, 0.05], [0.1, 0.85, 0.05]] + >>> loss = keras.losses.categorical_focal_crossentropy(y_true, y_pred) + >>> assert loss.shape == (2,) + >>> loss + array([2.63401289e-04, 6.75912094e-01], dtype=float32) + """ + if isinstance(axis, bool): + raise ValueError( + "`axis` must be of type `int`. " + f"Received: axis={axis} of type {type(axis)}" + ) + y_pred = ops.convert_to_tensor(y_pred) + y_true = ops.cast(y_true, y_pred.dtype) + + if y_pred.shape[-1] == 1: + warnings.warn( + "In loss categorical_focal_crossentropy, expected " + "y_pred.shape to be (batch_size, num_classes) " + f"with num_classes > 1. Received: y_pred.shape={y_pred.shape}. " + "Consider using 'binary_crossentropy' if you only have 2 classes.", + SyntaxWarning, + stacklevel=2, + ) + + if label_smoothing: + num_classes = ops.cast(ops.shape(y_true)[-1], y_pred.dtype) + y_true = y_true * (1.0 - label_smoothing) + ( + label_smoothing / num_classes + ) + + if from_logits: + y_pred = ops.softmax(y_pred, axis=axis) + + # Adjust the predictions so that the probability of + # each class for every sample adds up to 1 + # This is needed to ensure that the cross entropy is + # computed correctly. + output = y_pred / ops.sum(y_pred, axis=axis, keepdims=True) + output = ops.clip(output, backend.epsilon(), 1.0 - backend.epsilon()) + + # Calculate cross entropy + cce = -y_true * ops.log(output) + + # Calculate factors + modulating_factor = ops.power(1.0 - output, gamma) + weighting_factor = ops.multiply(modulating_factor, alpha) + + # Apply weighting factor + focal_cce = ops.multiply(weighting_factor, cce) + focal_cce = ops.sum(focal_cce, axis=axis) + return focal_cce + + +@keras_export( + [ + "keras.metrics.sparse_categorical_crossentropy", + "keras.losses.sparse_categorical_crossentropy", + ] +) +def sparse_categorical_crossentropy( + y_true, y_pred, from_logits=False, ignore_class=None, axis=-1 +): + """Computes the sparse categorical crossentropy loss. + + Args: + y_true: Ground truth values. + y_pred: The predicted values. + from_logits: Whether `y_pred` is expected to be a logits tensor. By + default, we assume that `y_pred` encodes a probability distribution. + ignore_class: Optional integer. The ID of a class to be ignored during + loss computation. This is useful, for example, in segmentation + problems featuring a "void" class (commonly -1 or 255) in + segmentation maps. By default (`ignore_class=None`), all classes are + considered. + axis: Defaults to `-1`. The dimension along which the entropy is + computed. + + Returns: + Sparse categorical crossentropy loss value. + + Examples: + + >>> y_true = [1, 2] + >>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]] + >>> loss = keras.losses.sparse_categorical_crossentropy(y_true, y_pred) + >>> assert loss.shape == (2,) + >>> loss + array([0.0513, 2.303], dtype=float32) + """ + + if len(y_true.shape) == len(y_pred.shape) and y_true.shape[-1] == 1: + y_true = ops.squeeze(y_true, axis=-1) + + if ignore_class is not None: + res_shape = ops.shape(y_pred)[:-1] + valid_mask = ops.not_equal(y_true, ops.cast(ignore_class, y_pred.dtype)) + y_true = y_true * ops.cast(valid_mask, y_true.dtype) + y_pred = y_pred * ops.cast( + ops.expand_dims(valid_mask, -1), y_pred.dtype + ) + + res = ops.sparse_categorical_crossentropy( + y_true, + y_pred, + from_logits=from_logits, + axis=axis, + ) + + if ignore_class is not None: + valid_mask = ops.reshape(valid_mask, res_shape) + res = ops.where(valid_mask, res, 0.0) + backend.set_keras_mask(res, mask=valid_mask) + + return res + + +@keras_export( + [ + "keras.metrics.binary_crossentropy", + "keras.losses.binary_crossentropy", + ] +) +def binary_crossentropy( + y_true, y_pred, from_logits=False, label_smoothing=0.0, axis=-1 +): + """Computes the binary crossentropy loss. + + Args: + y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`. + y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`. + from_logits: Whether `y_pred` is expected to be a logits tensor. By + default, we assume that `y_pred` encodes a probability distribution. + label_smoothing: Float in `[0, 1]`. If > `0` then smooth the labels by + squeezing them towards 0.5, that is, + using `1. - 0.5 * label_smoothing` for the target class + and `0.5 * label_smoothing` for the non-target class. + axis: The axis along which the mean is computed. Defaults to `-1`. + + Returns: + Binary crossentropy loss value. shape = `[batch_size, d0, .. dN-1]`. + + Example: + + >>> y_true = [[0, 1], [0, 0]] + >>> y_pred = [[0.6, 0.4], [0.4, 0.6]] + >>> loss = keras.losses.binary_crossentropy(y_true, y_pred) + >>> assert loss.shape == (2,) + >>> loss + array([0.916 , 0.714], dtype=float32) + """ + y_pred = ops.convert_to_tensor(y_pred) + y_true = ops.cast(y_true, y_pred.dtype) + + if label_smoothing: + y_true = y_true * (1.0 - label_smoothing) + 0.5 * label_smoothing + + return ops.mean( + ops.binary_crossentropy(y_true, y_pred, from_logits=from_logits), + axis=axis, + ) + + +@keras_export( + [ + "keras.metrics.binary_focal_crossentropy", + "keras.losses.binary_focal_crossentropy", + ] +) +def binary_focal_crossentropy( + y_true, + y_pred, + apply_class_balancing=False, + alpha=0.25, + gamma=2.0, + from_logits=False, + label_smoothing=0.0, + axis=-1, +): + """Computes the binary focal crossentropy loss. + + According to [Lin et al., 2018](https://arxiv.org/pdf/1708.02002.pdf), it + helps to apply a focal factor to down-weight easy examples and focus more on + hard examples. By default, the focal tensor is computed as follows: + + `focal_factor = (1 - output) ** gamma` for class 1 + `focal_factor = output ** gamma` for class 0 + where `gamma` is a focusing parameter. When `gamma` = 0, there is no focal + effect on the binary crossentropy loss. + + If `apply_class_balancing == True`, this function also takes into account a + weight balancing factor for the binary classes 0 and 1 as follows: + + `weight = alpha` for class 1 (`target == 1`) + `weight = 1 - alpha` for class 0 + where `alpha` is a float in the range of `[0, 1]`. + + Args: + y_true: Ground truth values, of shape `(batch_size, d0, .. dN)`. + y_pred: The predicted values, of shape `(batch_size, d0, .. dN)`. + apply_class_balancing: A bool, whether to apply weight balancing on the + binary classes 0 and 1. + alpha: A weight balancing factor for class 1, default is `0.25` as + mentioned in the reference. The weight for class 0 is `1.0 - alpha`. + gamma: A focusing parameter, default is `2.0` as mentioned in the + reference. + from_logits: Whether `y_pred` is expected to be a logits tensor. By + default, we assume that `y_pred` encodes a probability distribution. + label_smoothing: Float in `[0, 1]`. If > `0` then smooth the labels by + squeezing them towards 0.5, that is, + using `1. - 0.5 * label_smoothing` for the target class + and `0.5 * label_smoothing` for the non-target class. + axis: The axis along which the mean is computed. Defaults to `-1`. + + Returns: + Binary focal crossentropy loss value + with shape = `[batch_size, d0, .. dN-1]`. + + Example: + + >>> y_true = [[0, 1], [0, 0]] + >>> y_pred = [[0.6, 0.4], [0.4, 0.6]] + >>> # In this instance, the first sample in the second batch is the + >>> # 'easier' example. + >>> focal_loss = keras.losses.binary_focal_crossentropy( + ... y_true, y_pred, gamma=2) + >>> assert loss.shape == (2,) + >>> focal_loss + array([0.330, 0.206], dtype=float32) + >>> # Compare with binary_crossentropy + >>> bce_loss = keras.losses.binary_focal_crossentropy( + ... y_true, y_pred) + >>> bce_loss + array([0.916, 0.714], dtype=float32) + >>> # Binary focal crossentropy loss attributes more importance to the + >>> # harder example which results in a higher loss for the first batch + >>> # when normalized by binary cross entropy loss + >>> focal_loss/bce_loss + array([0.360, 0.289] + """ + y_pred = ops.convert_to_tensor(y_pred) + y_true = ops.cast(y_true, y_pred.dtype) + + if label_smoothing: + y_true = y_true * (1.0 - label_smoothing) + 0.5 * label_smoothing + + if from_logits: + y_pred = ops.sigmoid(y_pred) + + bce = ops.binary_crossentropy( + target=y_true, + output=y_pred, + from_logits=False, + ) + + # Calculate focal factor + p_t = y_true * y_pred + (1 - y_true) * (1 - y_pred) + focal_factor = ops.power(1.0 - p_t, gamma) + + focal_bce = focal_factor * bce + + if apply_class_balancing: + weight = y_true * alpha + (1 - y_true) * (1 - alpha) + focal_bce = weight * focal_bce + + return ops.mean(focal_bce, axis=axis) + + +@keras_export("keras.losses.ctc") +def ctc(y_true, y_pred): + """CTC (Connectionist Temporal Classification) loss. + + Args: + y_true: A tensor of shape `(batch_size, max_length)` containing + the true labels in integer format. `0` always represents + the blank/mask index and should not be used for classes. + y_pred: A tensor of shape `(batch_size, max_length, num_classes)` + containing logits (the output of your model). + They should *not* be normalized via softmax. + """ + if len(ops.shape(y_true)) != 2: + raise ValueError( + "Targets `y_true` are expected to be a tensor of shape " + "`(batch_size, max_length)` in integer format. " + f"Received: y_true.shape={ops.shape(y_true)}" + ) + if len(ops.shape(y_pred)) != 3: + raise ValueError( + "Logits `y_pred` are expected to be a tensor of shape " + "`(batch_size, max_length, num_classes)`. " + f"Received: y_pred.shape={ops.shape(y_pred)}" + ) + + mask_index = 0 + batch_length = ops.shape(y_pred)[0] + input_length = ops.shape(y_pred)[1] + input_length = input_length * ops.ones((batch_length,), dtype="int32") + label_length = ops.cast( + ops.sum(y_true != mask_index, axis=-1), dtype="int32" + ) + + return ops.ctc_loss( + y_true, y_pred, label_length, input_length, mask_index=mask_index + ) + + +@keras_export("keras.losses.dice") +def dice(y_true, y_pred, axis=None): + """Computes the Dice loss value between `y_true` and `y_pred`. + + Formula: + ```python + loss = 1 - (2 * sum(y_true * y_pred)) / (sum(y_true) + sum(y_pred)) + ``` + + Args: + y_true: tensor of true targets. + y_pred: tensor of predicted targets. + axis: tuple for which dimensions the loss is calculated + + Returns: + Dice loss value. + + Example: + + >>> y_true = [[[[1.0], [1.0]], [[0.0], [0.0]]], + ... [[[1.0], [1.0]], [[0.0], [0.0]]]] + >>> y_pred = [[[[0.0], [1.0]], [[0.0], [1.0]]], + ... [[[0.4], [0.0]], [[0.0], [0.9]]]] + >>> axis = (1, 2, 3) + >>> loss = keras.losses.dice(y_true, y_pred, axis=axis) + >>> assert loss.shape == (2,) + >>> loss + array([0.5, 0.75757575], shape=(2,), dtype=float32) + + >>> loss = keras.losses.dice(y_true, y_pred) + >>> assert loss.shape == () + >>> loss + array(0.6164384, shape=(), dtype=float32) + + """ + y_pred = ops.convert_to_tensor(y_pred) + y_true = ops.cast(y_true, y_pred.dtype) + + inputs = y_true + targets = y_pred + + intersection = ops.sum(inputs * targets, axis=axis) + dice = ops.divide( + 2.0 * intersection, + ops.sum(y_true, axis=axis) + + ops.sum(y_pred, axis=axis) + + backend.epsilon(), + ) + + return 1 - dice + + +@keras_export("keras.losses.tversky") +def tversky(y_true, y_pred, alpha=0.5, beta=0.5, axis=None): + """Computes the Tversky loss value between `y_true` and `y_pred`. + + This loss function is weighted by the alpha and beta coefficients + that penalize false positives and false negatives. + + With `alpha=0.5` and `beta=0.5`, the loss value becomes equivalent to + Dice Loss. + + Args: + y_true: tensor of true targets. + y_pred: tensor of predicted targets. + alpha: coefficient controlling incidence of false positives. + beta: coefficient controlling incidence of false negatives. + axis: tuple for which dimensions the loss is calculated. + + Returns: + Tversky loss value. + + Reference: + + - [Salehi et al., 2017](https://arxiv.org/abs/1706.05721) + """ + y_pred = ops.convert_to_tensor(y_pred) + y_true = ops.cast(y_true, y_pred.dtype) + + inputs = y_true + targets = y_pred + + intersection = ops.sum(inputs * targets, axis=axis) + fp = ops.sum((1 - targets) * inputs, axis=axis) + fn = ops.sum(targets * (1 - inputs), axis=axis) + + tversky = ops.divide( + intersection, + intersection + fp * alpha + fn * beta + backend.epsilon(), + ) + + return 1 - tversky + + +@keras_export("keras.losses.circle") +def circle( + y_true, + y_pred, + ref_labels=None, + ref_embeddings=None, + remove_diagonal=True, + gamma=80, + margin=0.4, +): + """Computes the Circle loss. + + It is designed to minimize within-class distances and maximize between-class + distances in L2 normalized embedding space. + + Args: + y_true: Tensor with ground truth labels in integer format. + y_pred: Tensor with predicted L2 normalized embeddings. + ref_labels: Optional integer tensor with labels for reference + embeddings. If `None`, defaults to `y_true`. + ref_embeddings: Optional tensor with L2 normalized reference embeddings. + If `None`, defaults to `y_pred`. + remove_diagonal: Boolean, whether to remove self-similarities from + positive mask. Defaults to `True`. + gamma: Float, scaling factor for the loss. Defaults to `80`. + margin: Float, relaxation factor for the loss. Defaults to `0.4`. + + Returns: + Circle loss value. + """ + y_pred = ops.convert_to_tensor(y_pred) + y_true = ops.cast(y_true, "int32") + ref_embeddings = ( + y_pred + if ref_embeddings is None + else ops.convert_to_tensor(ref_embeddings) + ) + ref_labels = y_true if ref_labels is None else ops.cast(ref_labels, "int32") + + optim_pos = margin + optim_neg = 1 + margin + delta_pos = margin + delta_neg = 1 - margin + + pairwise_cosine_distances = 1 - ops.matmul( + y_pred, ops.transpose(ref_embeddings) + ) + + pairwise_cosine_distances = ops.maximum(pairwise_cosine_distances, 0.0) + positive_mask, negative_mask = build_pos_neg_masks( + y_true, + ref_labels, + remove_diagonal=remove_diagonal, + ) + positive_mask = ops.cast( + positive_mask, dtype=pairwise_cosine_distances.dtype + ) + negative_mask = ops.cast( + negative_mask, dtype=pairwise_cosine_distances.dtype + ) + + pos_weights = optim_pos + pairwise_cosine_distances + pos_weights = pos_weights * positive_mask + pos_weights = ops.maximum(pos_weights, 0.0) + neg_weights = optim_neg - pairwise_cosine_distances + neg_weights = neg_weights * negative_mask + neg_weights = ops.maximum(neg_weights, 0.0) + + pos_dists = delta_pos - pairwise_cosine_distances + neg_dists = delta_neg - pairwise_cosine_distances + + pos_wdists = -1 * gamma * pos_weights * pos_dists + neg_wdists = gamma * neg_weights * neg_dists + + p_loss = ops.logsumexp( + ops.where(positive_mask, pos_wdists, float("-inf")), + axis=1, + ) + n_loss = ops.logsumexp( + ops.where(negative_mask, neg_wdists, float("-inf")), + axis=1, + ) + + circle_loss = ops.softplus(p_loss + n_loss) + backend.set_keras_mask(circle_loss, circle_loss > 0) + return circle_loss + + +@keras_export("keras.losses.categorical_generalized_cross_entropy") +def categorical_generalized_cross_entropy(y_true, y_pred, q): + """Computes the Generalized Cross Entropy loss. + + Generalized Cross Entropy (GCE) is a noise-robust loss function that + provides better robustness against noisy labels than standard cross entropy. + It generalizes both cross entropy and mean absolute error through + the parameter q, where values closer to 1 make the loss more robust + to noisy labels. + + Formula: + ```python + loss = (1 - p**q) / q + ``` + where `p` is the predicted probability for the true class and `q` + is the noise parameter. + + Args: + y_true: Ground truth labels. Expected to contain *integer class indices* + with shape `[batch_size]` or `[batch_size, 1]`. + y_pred: The predicted class probabilities, with shape + `[batch_size, num_classes]`. + q: Float in range `(0, 1)`. It is the noise parameter. + Controls the behavior of the loss: + - As `q` approaches 0: Behaves more like cross entropy + - As `q` approaches 1: Behaves more like mean absolute error + + Returns: + GCE loss values with shape `[batch_size]`. + ``` + + References: + - [Zhang, Sabuncu, 2018](https://arxiv.org/abs/1805.07836) + ("Generalized Cross Entropy Loss for Training + Deep Neural Networks with Noisy Labels") + """ + + # Convert y_true to integer type and one-hot encode + y_true_one_hot = ops.one_hot( + ops.cast(y_true, "int"), num_classes=ops.shape(y_pred)[-1] + ) + y_true_one_hot = ops.cast(y_true_one_hot, y_pred.dtype) + # Calculate the probability of the true class + p = ops.sum(y_pred * y_true_one_hot, axis=-1) + + # Compute the GCE loss for q in (0,1) + gce_loss = (1 - ops.power(p, q)) / q + + return gce_loss diff --git a/keras/src/losses/losses_test.py b/keras/src/losses/losses_test.py new file mode 100644 index 000000000000..fe0d557d96c9 --- /dev/null +++ b/keras/src/losses/losses_test.py @@ -0,0 +1,2263 @@ +import re + +import numpy as np +import pytest + +from keras.src import backend +from keras.src import testing +from keras.src.losses import losses + + +class MeanSquaredErrorTest(testing.TestCase): + def test_config(self): + self.run_class_serialization_test(losses.MeanSquaredError(name="mymse")) + + def test_base_function_reduction(self): + mse_fn = losses.mean_squared_error + y_true = np.array([4, 8, 12]) + y_pred = np.array([[3], [0], [1]]) + loss = mse_fn(y_true, y_pred) + self.assertEqual(backend.shape(loss), (3,)) + + def test_all_correct_unweighted(self): + mse_obj = losses.MeanSquaredError() + y_true = np.array([[4, 8, 12], [8, 1, 3]]) + loss = mse_obj(y_true, y_true) + self.assertAlmostEqual(loss, 0.0) + + def test_unweighted(self): + mse_obj = losses.MeanSquaredError() + y_true = np.array([[1, 9, 2], [-5, -2, 6]]) + y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype="float32") + loss = mse_obj(y_true, y_pred) + self.assertAlmostEqual(loss, 49.5) + + def test_scalar_weighted(self): + mse_obj = losses.MeanSquaredError() + y_true = np.array([[1, 9, 2], [-5, -2, 6]]) + y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype="float32") + loss = mse_obj(y_true, y_pred, sample_weight=2.3) + self.assertAlmostEqual(loss, 113.85) + + def test_sample_weighted(self): + mse_obj = losses.MeanSquaredError() + y_true = np.array([[1, 9, 2], [-5, -2, 6]]) + y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype="float32") + sample_weight = np.array([[1.2], [3.4]]) + loss = mse_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAlmostEqual(loss, 767.8 / 6) + + def test_timestep_weighted(self): + mse_obj = losses.MeanSquaredError() + y_true = np.asarray([1, 9, 2, -5, -2, 6]).reshape(2, 3, 1) + y_pred = np.asarray([4, 8, 12, 8, 1, 3]).reshape(2, 3, 1) + sample_weight = np.array([3, 6, 5, 0, 4, 2]).reshape((2, 3)) + loss = mse_obj( + y_true, + y_pred, + sample_weight=sample_weight, + ) + self.assertAlmostEqual(loss, 97.833336) + + def test_zero_weighted(self): + mse_obj = losses.MeanSquaredError() + y_true = np.array([[1, 9, 2], [-5, -2, 6]]) + y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype="float32") + loss = mse_obj(y_true, y_pred, sample_weight=0) + self.assertAlmostEqual(loss, 0.0) + + def test_no_reduction(self): + mse_obj = losses.MeanSquaredError(reduction=None) + y_true = np.array([[1, 9, 2], [-5, -2, 6]]) + y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype="float32") + loss = mse_obj(y_true, y_pred, sample_weight=2.3) + self.assertAlmostEqual(loss, [84.3333, 143.3666]) + + def test_sum_reduction(self): + mse_obj = losses.MeanSquaredError(reduction="sum") + y_true = np.array([[1, 9, 2], [-5, -2, 6]]) + y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype="float32") + loss = mse_obj(y_true, y_pred, sample_weight=2.3) + self.assertAlmostEqual(loss, 227.69998) + + def test_mean_with_sample_weight_reduction(self): + mse_obj = losses.MeanSquaredError(reduction="mean_with_sample_weight") + y_true = np.array([[1, 9, 2], [-5, -2, 6]]) + y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype="float32") + sample_weight = np.array([[1.2], [3.4]]) + loss = mse_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAlmostEqual( + loss, (110 / 3 * 1.2 + 187 / 3 * 3.4) / (1.2 + 3.4) + ) + + def test_dtype_arg(self): + mse_obj = losses.MeanSquaredError(dtype="bfloat16") + y_true = np.array([[1, 9, 2], [-5, -2, 6]]) + y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype="float32") + loss = mse_obj(y_true, y_pred) + self.assertDType(loss, "bfloat16") + + +class MeanAbsoluteErrorTest(testing.TestCase): + def test_config(self): + self.run_class_serialization_test( + losses.MeanAbsoluteError(name="myname") + ) + + def test_all_correct_unweighted(self): + mae_obj = losses.MeanAbsoluteError() + y_true = np.array([[4, 8, 12], [8, 1, 3]]) + loss = mae_obj(y_true, y_true) + self.assertAlmostEqual(loss, 0.0) + + def test_unweighted(self): + mae_obj = losses.MeanAbsoluteError() + y_true = np.array([[1, 9, 2], [-5, -2, 6]]) + y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype="float32") + loss = mae_obj(y_true, y_pred) + self.assertAlmostEqual(loss, 5.5) + + def test_scalar_weighted(self): + mae_obj = losses.MeanAbsoluteError() + y_true = np.array([[1, 9, 2], [-5, -2, 6]]) + y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype="float32") + loss = mae_obj(y_true, y_pred, sample_weight=2.3) + self.assertAlmostEqual(loss, 12.65) + + def test_sample_weighted(self): + mae_obj = losses.MeanAbsoluteError() + y_true = np.array([[1, 9, 2], [-5, -2, 6]]) + y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype="float32") + sample_weight = np.array([[1.2], [3.4]]) + loss = mae_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAlmostEqual(loss, 81.4 / 6) + + def test_timestep_weighted(self): + mae_obj = losses.MeanAbsoluteError() + y_true = np.asarray([1, 9, 2, -5, -2, 6]).reshape(2, 3, 1) + y_pred = np.asarray([4, 8, 12, 8, 1, 3]).reshape(2, 3, 1) + sample_weight = np.array([3, 6, 5, 0, 4, 2]).reshape((2, 3)) + loss = mae_obj( + y_true, + y_pred, + sample_weight=sample_weight, + ) + self.assertAlmostEqual(loss, 13.833333) + + def test_zero_weighted(self): + mae_obj = losses.MeanAbsoluteError() + y_true = np.array([[1, 9, 2], [-5, -2, 6]]) + y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype="float32") + loss = mae_obj(y_true, y_pred, sample_weight=0) + self.assertAlmostEqual(loss, 0.0) + + def test_no_reduction(self): + mae_obj = losses.MeanAbsoluteError(reduction=None) + y_true = np.array([[1, 9, 2], [-5, -2, 6]]) + y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype="float32") + loss = mae_obj(y_true, y_pred, sample_weight=2.3) + self.assertAlmostEqual(loss, [10.7333, 14.5666]) + + def test_sum_reduction(self): + mae_obj = losses.MeanAbsoluteError(reduction="sum") + y_true = np.array([[1, 9, 2], [-5, -2, 6]]) + y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype="float32") + loss = mae_obj(y_true, y_pred, sample_weight=2.3) + self.assertAlmostEqual(loss, 25.29999) + + def test_mean_with_sample_weight_reduction(self): + mae_obj = losses.MeanAbsoluteError(reduction="mean_with_sample_weight") + y_true = np.array([[1, 9, 2], [-5, -2, 6]]) + y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype="float32") + sample_weight = np.array([[1.2], [3.4]]) + loss = mae_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAlmostEqual( + loss, (14 / 3 * 1.2 + 19 / 3 * 3.4) / (1.2 + 3.4) + ) + + def test_dtype_arg(self): + mae_obj = losses.MeanAbsoluteError(dtype="bfloat16") + y_true = np.array([[1, 9, 2], [-5, -2, 6]]) + y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype="float32") + loss = mae_obj(y_true, y_pred) + self.assertDType(loss, "bfloat16") + + +class MeanAbsolutePercentageErrorTest(testing.TestCase): + def test_config(self): + self.run_class_serialization_test( + losses.MeanAbsolutePercentageError(name="mymape") + ) + + def test_all_correct_unweighted(self): + mape_obj = losses.MeanAbsolutePercentageError() + y_true = np.array([[4, 8, 12], [8, 1, 3]]) + loss = mape_obj(y_true, y_true) + self.assertAlmostEqual(loss, 0.0) + + def test_unweighted(self): + mape_obj = losses.MeanAbsolutePercentageError() + y_true = np.array([[1, 9, 2], [-5, -2, 6]]) + y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype="float32") + loss = mape_obj(y_true, y_pred) + self.assertAlmostEqual(loss, 211.8518, 3) + + def test_scalar_weighted(self): + mape_obj = losses.MeanAbsolutePercentageError() + y_true = np.array([[1, 9, 2], [-5, -2, 6]]) + y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype="float32") + loss = mape_obj(y_true, y_pred, sample_weight=2.3) + self.assertAlmostEqual(loss, 487.259, 3) + + def test_sample_weighted(self): + mape_obj = losses.MeanAbsolutePercentageError() + y_true = np.array([[1, 9, 2], [-5, -2, 6]]) + y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype="float32") + sample_weight = np.array([[1.2], [3.4]]) + loss = mape_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAlmostEqual(loss, 422.8888, 3) + + def test_timestep_weighted(self): + mape_obj = losses.MeanAbsolutePercentageError() + y_true = np.asarray([1, 9, 2, -5, -2, 6]).reshape(2, 3, 1) + y_pred = np.asarray([4, 8, 12, 8, 1, 3]).reshape(2, 3, 1) + sample_weight = np.array([3, 6, 5, 0, 4, 2]).reshape((2, 3)) + loss = mape_obj( + y_true, + y_pred, + sample_weight=sample_weight, + ) + self.assertAlmostEqual(loss, 694.4444) + + def test_zero_weighted(self): + mape_obj = losses.MeanAbsolutePercentageError() + y_true = np.array([[1, 9, 2], [-5, -2, 6]]) + y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype="float32") + loss = mape_obj(y_true, y_pred, sample_weight=0) + self.assertAlmostEqual(loss, 0.0, 3) + + def test_no_reduction(self): + mape_obj = losses.MeanAbsolutePercentageError(reduction=None) + y_true = np.array([[1, 9, 2], [-5, -2, 6]]) + y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype="float32") + loss = mape_obj(y_true, y_pred, sample_weight=2.3) + self.assertAlmostEqual(loss, [621.8518, 352.6666]) + + def test_mean_with_sample_weight_reduction(self): + mape_obj = losses.MeanAbsolutePercentageError( + reduction="mean_with_sample_weight" + ) + y_true = np.array([[1, 9, 2], [-5, -2, 6]]) + y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype="float32") + sample_weight = np.array([[1.2], [3.4]]) + loss = mape_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAlmostEqual(loss, 183.865) + + def test_dtype_arg(self): + mape_obj = losses.MeanAbsolutePercentageError(dtype="bfloat16") + y_true = np.array([[1, 9, 2], [-5, -2, 6]]) + y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype="float32") + loss = mape_obj(y_true, y_pred) + self.assertDType(loss, "bfloat16") + + +class MeanSquaredLogarithmicErrorTest(testing.TestCase): + def test_config(self): + self.run_class_serialization_test( + losses.MeanSquaredLogarithmicError(name="mysloge") + ) + + def test_unweighted(self): + msle_obj = losses.MeanSquaredLogarithmicError() + y_true = np.array([[1, 9, 2], [-5, -2, 6]]) + y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype="float32") + loss = msle_obj(y_true, y_pred) + self.assertAlmostEqual(loss, 1.4370, 3) + + def test_scalar_weighted(self): + msle_obj = losses.MeanSquaredLogarithmicError() + y_true = np.array([[1, 9, 2], [-5, -2, 6]]) + y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype="float32") + loss = msle_obj(y_true, y_pred, sample_weight=2.3) + self.assertAlmostEqual(loss, 3.3051, 3) + + def test_sample_weighted(self): + msle_obj = losses.MeanSquaredLogarithmicError() + y_true = np.array([[1, 9, 2], [-5, -2, 6]]) + y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype="float32") + sample_weight = np.array([[1.2], [3.4]]) + loss = msle_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAlmostEqual(loss, 3.7856, 3) + + def test_timestep_weighted(self): + msle_obj = losses.MeanSquaredLogarithmicError() + y_true = np.asarray([1, 9, 2, -5, -2, 6]).reshape(2, 3, 1) + y_pred = np.asarray([4, 8, 12, 8, 1, 3]).reshape(2, 3, 1) + sample_weight = np.array([3, 6, 5, 0, 4, 2]).reshape((2, 3)) + loss = msle_obj( + y_true, + y_pred, + sample_weight=sample_weight, + ) + self.assertAlmostEqual(loss, 2.647374) + + def test_zero_weighted(self): + msle_obj = losses.MeanSquaredLogarithmicError() + y_true = np.array([[1, 9, 2], [-5, -2, 6]]) + y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype="float32") + loss = msle_obj(y_true, y_pred, sample_weight=0) + self.assertAlmostEqual(loss, 0.0, 3) + + def test_mean_with_sample_weight_reduction(self): + msle_obj = losses.MeanSquaredLogarithmicError( + reduction="mean_with_sample_weight" + ) + y_true = np.array([[1, 9, 2], [-5, -2, 6]]) + y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype="float32") + sample_weight = np.array([[1.2], [3.4]]) + loss = msle_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAlmostEqual(loss, 1.646) + + def test_dtype_arg(self): + msle_obj = losses.MeanSquaredLogarithmicError(dtype="bfloat16") + y_true = np.array([[1, 9, 2], [-5, -2, 6]]) + y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype="float32") + loss = msle_obj(y_true, y_pred, sample_weight=2.3) + self.assertDType(loss, "bfloat16") + + +class HingeTest(testing.TestCase): + def test_unweighted(self): + y_true = np.array([[0.0, 1.0], [0.0, 0.0]]) + y_pred = np.array([[0.6, 0.4], [0.4, 0.6]]) + + # Reduction = "sum_over_batch_size" + hinge_obj = losses.Hinge(reduction="sum_over_batch_size") + loss = hinge_obj(y_true, y_pred) + self.assertAlmostEqual(loss, 1.3, 3) + + # Reduction = "sum" + hinge_obj = losses.Hinge(reduction="sum") + loss = hinge_obj(y_true, y_pred) + self.assertAlmostEqual(loss, 2.6, 3) + + # Reduction = None + hinge_obj = losses.Hinge(reduction=None) + loss = hinge_obj(y_true, y_pred) + self.assertAllClose(loss, [1.1, 1.5]) + + # Bad reduction + with self.assertRaisesRegex(ValueError, "Invalid value for argument"): + losses.Hinge(reduction="abc") + + def test_weighted(self): + y_true = np.array([[0.0, 1.0], [0.0, 0.0]]) + y_pred = np.array([[0.6, 0.4], [0.4, 0.6]]) + sample_weight = [1, 0] + + # Reduction = "sum_over_batch_size" + hinge_obj = losses.Hinge(reduction="sum_over_batch_size") + loss = hinge_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAlmostEqual(loss, 0.55, 3) + + # Reduction = "sum" + hinge_obj = losses.Hinge(reduction="sum") + loss = hinge_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAlmostEqual(loss, 1.1, 3) + + # Reduction = None + hinge_obj = losses.Hinge(reduction=None) + loss = hinge_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAlmostEqual(loss, [1.1, 0.0]) + + def test_zero_weighted(self): + y_true = np.array([[0.0, 1.0], [0.0, 0.0]]) + y_pred = np.array([[0.6, 0.4], [0.4, 0.6]]) + sample_weight = 0.0 + + hinge_obj = losses.Hinge() + loss = hinge_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertEqual(loss, 0.0) + + def test_dtype_arg(self): + hinge_obj = losses.Hinge(dtype="bfloat16") + y_true = np.array([[0.0, 1.0], [0.0, 0.0]]) + y_pred = np.array([[0.6, 0.4], [0.4, 0.6]]) + loss = hinge_obj(y_true, y_pred) + self.assertDType(loss, "bfloat16") + + +class SquaredHingeTest(testing.TestCase): + def test_unweighted(self): + y_true = np.array([[0.0, 1.0], [0.0, 0.0]]) + y_pred = np.array([[0.6, 0.4], [0.4, 0.6]]) + + # Reduction = "sum_over_batch_size" + hinge_obj = losses.SquaredHinge(reduction="sum_over_batch_size") + loss = hinge_obj(y_true, y_pred) + self.assertAlmostEqual(loss, 1.86, 3) + + # Reduction = "sum" + hinge_obj = losses.SquaredHinge(reduction="sum") + loss = hinge_obj(y_true, y_pred) + self.assertAlmostEqual(loss, 3.72, 3) + + # Reduction = None + hinge_obj = losses.SquaredHinge(reduction=None) + loss = hinge_obj(y_true, y_pred) + self.assertAllClose(loss, [1.46, 2.26]) + + # Bad reduction + with self.assertRaisesRegex(ValueError, "Invalid value for argument"): + losses.SquaredHinge(reduction="abc") + + def test_weighted(self): + y_true = np.array([[0.0, 1.0], [0.0, 0.0]]) + y_pred = np.array([[0.6, 0.4], [0.4, 0.6]]) + sample_weight = [1, 0] + + # Reduction = "sum_over_batch_size" + hinge_obj = losses.SquaredHinge(reduction="sum_over_batch_size") + loss = hinge_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAlmostEqual(loss, 0.73, 3) + + # Reduction = "sum" + hinge_obj = losses.SquaredHinge(reduction="sum") + loss = hinge_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAlmostEqual(loss, 1.46, 3) + + # Reduction = None + hinge_obj = losses.SquaredHinge(reduction=None) + loss = hinge_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAlmostEqual(loss, [1.46, 0.0]) + + def test_zero_weighted(self): + y_true = np.array([[0.0, 1.0], [0.0, 0.0]]) + y_pred = np.array([[0.6, 0.4], [0.4, 0.6]]) + sample_weight = 0.0 + + hinge_obj = losses.SquaredHinge() + loss = hinge_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertEqual(loss, 0.0) + + def test_dtype_arg(self): + hinge_obj = losses.SquaredHinge(dtype="bfloat16") + y_true = np.array([[0.0, 1.0], [0.0, 0.0]]) + y_pred = np.array([[0.6, 0.4], [0.4, 0.6]]) + loss = hinge_obj(y_true, y_pred) + self.assertDType(loss, "bfloat16") + + +class CategoricalHingeTest(testing.TestCase): + def test_unweighted(self): + y_true = np.array([[0.0, 1.0], [0.0, 0.0]]) + y_pred = np.array([[0.6, 0.4], [0.4, 0.6]]) + + # Reduction = "sum_over_batch_size" + hinge_obj = losses.CategoricalHinge(reduction="sum_over_batch_size") + loss = hinge_obj(y_true, y_pred) + self.assertAlmostEqual(loss, 1.4, 3) + + # Reduction = "sum" + hinge_obj = losses.CategoricalHinge(reduction="sum") + loss = hinge_obj(y_true, y_pred) + self.assertAlmostEqual(loss, 2.8, 3) + + # Reduction = None + hinge_obj = losses.CategoricalHinge(reduction=None) + loss = hinge_obj(y_true, y_pred) + self.assertAllClose(loss, [1.2, 1.6]) + + # Bad reduction + with self.assertRaisesRegex(ValueError, "Invalid value for argument"): + losses.CategoricalHinge(reduction="abc") + + def test_weighted(self): + y_true = np.array([[0.0, 1.0], [0.0, 0.0]]) + y_pred = np.array([[0.6, 0.4], [0.4, 0.6]]) + sample_weight = [1, 0] + + # Reduction = "sum_over_batch_size" + hinge_obj = losses.CategoricalHinge(reduction="sum_over_batch_size") + loss = hinge_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAlmostEqual(loss, 0.6, 3) + + # Reduction = "sum" + hinge_obj = losses.CategoricalHinge(reduction="sum") + loss = hinge_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAlmostEqual(loss, 1.2, 3) + + # Reduction = None + hinge_obj = losses.CategoricalHinge(reduction=None) + loss = hinge_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAlmostEqual(loss, [1.2, 0.0]) + + def test_zero_weighted(self): + y_true = np.array([[0.0, 1.0], [0.0, 0.0]]) + y_pred = np.array([[0.6, 0.4], [0.4, 0.6]]) + sample_weight = 0.0 + + hinge_obj = losses.CategoricalHinge() + loss = hinge_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertEqual(loss, 0.0) + + def test_dtype_arg(self): + hinge_obj = losses.CategoricalHinge(dtype="bfloat16") + y_true = np.array([[0.0, 1.0], [0.0, 0.0]]) + y_pred = np.array([[0.6, 0.4], [0.4, 0.6]]) + loss = hinge_obj(y_true, y_pred) + self.assertDType(loss, "bfloat16") + + +class CosineSimilarityTest(testing.TestCase): + def l2_norm(self, x, axis): + epsilon = 1e-12 + square_sum = np.sum(np.square(x), axis=axis, keepdims=True) + x_inv_norm = 1 / np.sqrt(np.maximum(square_sum, epsilon)) + return np.multiply(x, x_inv_norm) + + def setup(self, axis=1): + self.np_y_true = np.asarray([[1, 9, 2], [-5, -2, 6]], dtype=np.float32) + self.np_y_pred = np.asarray([[4, 8, 12], [8, 1, 3]], dtype=np.float32) + + y_true = self.l2_norm(self.np_y_true, axis) + y_pred = self.l2_norm(self.np_y_pred, axis) + self.expected_loss = np.sum(np.multiply(y_true, y_pred), axis=(axis,)) + + self.y_true = self.np_y_true + self.y_pred = self.np_y_pred + + def test_config(self): + cosine_obj = losses.CosineSimilarity( + axis=2, reduction="sum", name="cosine_loss" + ) + self.assertEqual(cosine_obj.name, "cosine_loss") + self.assertEqual(cosine_obj.reduction, "sum") + config = cosine_obj.get_config() + self.assertEqual(config, {"name": "cosine_loss", "reduction": "sum"}) + + def test_unweighted(self): + self.setup() + cosine_obj = losses.CosineSimilarity() + loss = cosine_obj(self.y_true, self.y_pred) + expected_loss = -np.mean(self.expected_loss) + self.assertAlmostEqual(loss, expected_loss, 3) + + def test_scalar_weighted(self): + self.setup() + cosine_obj = losses.CosineSimilarity() + sample_weight = 2.3 + loss = cosine_obj(self.y_true, self.y_pred, sample_weight=sample_weight) + expected_loss = -np.mean(self.expected_loss * sample_weight) + self.assertAlmostEqual(loss, expected_loss, 3) + + def test_sample_weighted(self): + self.setup() + cosine_obj = losses.CosineSimilarity() + sample_weight = np.asarray([1.2, 3.4]) + loss = cosine_obj(self.y_true, self.y_pred, sample_weight=sample_weight) + expected_loss = -np.mean(self.expected_loss * sample_weight) + self.assertAlmostEqual(loss, expected_loss, 3) + + def test_timestep_weighted(self): + self.setup() + cosine_obj = losses.CosineSimilarity() + np_y_true = self.np_y_true.reshape((2, 3, 1)) + np_y_pred = self.np_y_pred.reshape((2, 3, 1)) + sample_weight = np.asarray([3, 6, 5, 0, 4, 2]).reshape((2, 3)) + + y_true = self.l2_norm(np_y_true, 2) + y_pred = self.l2_norm(np_y_pred, 2) + expected_loss = np.sum(np.multiply(y_true, y_pred), axis=(2,)) + + y_true = np_y_true + y_pred = np_y_pred + loss = cosine_obj(y_true, y_pred, sample_weight=sample_weight) + + expected_loss = -np.mean(expected_loss * sample_weight) + self.assertAlmostEqual(loss, expected_loss, 3) + + def test_zero_weighted(self): + self.setup() + cosine_obj = losses.CosineSimilarity() + loss = cosine_obj(self.y_true, self.y_pred, sample_weight=0) + self.assertAlmostEqual(loss, 0.0, 3) + + def test_axis(self): + self.setup(axis=1) + cosine_obj = losses.CosineSimilarity(axis=1) + loss = cosine_obj(self.y_true, self.y_pred) + expected_loss = -np.mean(self.expected_loss) + self.assertAlmostEqual(loss, expected_loss, 3) + + def test_dtype_arg(self): + self.setup() + cosine_obj = losses.CosineSimilarity(dtype="bfloat16") + loss = cosine_obj(self.y_true, self.y_pred) + self.assertDType(loss, "bfloat16") + + +class HuberLossTest(testing.TestCase): + def huber_loss(self, y_true, y_pred, delta=1.0): + error = y_pred - y_true + abs_error = np.abs(error) + + quadratic = np.minimum(abs_error, delta) + linear = np.subtract(abs_error, quadratic) + return np.add( + np.multiply(0.5, np.multiply(quadratic, quadratic)), + np.multiply(delta, linear), + ) + + def setup(self, delta=1.0): + self.np_y_pred = np.array([[0.9, 0.2, 0.2], [0.8, 0.4, 0.6]]) + self.np_y_true = np.array([[1.0, 0.0, 1.0], [1.0, 0.0, 0.0]]) + + self.batch_size = 6 + self.expected_losses = self.huber_loss( + self.np_y_true, self.np_y_pred, delta + ) + + self.y_pred = self.np_y_pred + self.y_true = self.np_y_true + + def test_config(self): + h_obj = losses.Huber(reduction="sum", name="huber") + self.assertEqual(h_obj.name, "huber") + self.assertEqual(h_obj.reduction, "sum") + config = h_obj.get_config() + self.assertEqual(config, {"name": "huber", "reduction": "sum"}) + + def test_all_correct(self): + self.setup() + h_obj = losses.Huber() + loss = h_obj(self.y_true, self.y_true) + self.assertAlmostEqual(loss, 0.0, 3) + + def test_unweighted(self): + self.setup() + h_obj = losses.Huber() + loss = h_obj(self.y_true, self.y_pred) + actual_loss = np.sum(self.expected_losses) / self.batch_size + self.assertAlmostEqual(loss, actual_loss, 3) + + def test_scalar_weighted(self): + self.setup() + h_obj = losses.Huber() + sample_weight = 2.3 + loss = h_obj(self.y_true, self.y_pred, sample_weight=sample_weight) + actual_loss = ( + sample_weight * np.sum(self.expected_losses) / self.batch_size + ) + self.assertAlmostEqual(loss, actual_loss, 3) + + # Verify we get the same output when the same input is given + loss_2 = h_obj(self.y_true, self.y_pred, sample_weight=sample_weight) + self.assertAlmostEqual(loss, loss_2, 3) + + def test_sample_weighted(self): + self.setup() + h_obj = losses.Huber() + sample_weight = np.array([[1.2], [3.4]]) + + loss = h_obj(self.y_true, self.y_pred, sample_weight=sample_weight) + actual_loss = np.multiply( + self.expected_losses, + np.asarray([1.2, 1.2, 1.2, 3.4, 3.4, 3.4]).reshape((2, 3)), + ) + actual_loss = np.sum(actual_loss) / self.batch_size + self.assertAlmostEqual(loss, actual_loss, 3) + + def test_timestep_weighted(self): + self.setup() + h_obj = losses.Huber() + y_pred = self.np_y_pred.reshape((2, 3, 1)) + y_true = self.np_y_true.reshape((2, 3, 1)) + expected_losses = self.huber_loss(y_true, y_pred) + + sample_weight = np.array([3, 6, 5, 0, 4, 2]).reshape((2, 3, 1)) + loss = h_obj( + y_true, + y_pred, + sample_weight=sample_weight, + ) + actual_loss = np.multiply(expected_losses, sample_weight) + actual_loss = np.sum(actual_loss) / self.batch_size + self.assertAlmostEqual(loss, actual_loss, 3) + + def test_zero_weighted(self): + self.setup() + h_obj = losses.Huber() + sample_weight = 0 + loss = h_obj(self.y_true, self.y_pred, sample_weight=sample_weight) + self.assertAlmostEqual(loss, 0.0, 3) + + def test_non_default_delta(self): + self.setup(delta=0.8) + h_obj = losses.Huber(delta=0.8) + sample_weight = 2.3 + loss = h_obj(self.y_true, self.y_pred, sample_weight=sample_weight) + actual_loss = ( + sample_weight * np.sum(self.expected_losses) / self.batch_size + ) + self.assertAlmostEqual(loss, actual_loss, 3) + + def test_dtype_arg(self): + self.setup() + h_obj = losses.Huber(dtype="bfloat16") + loss = h_obj(self.y_true, self.y_pred) + self.assertDType(loss, "bfloat16") + + +class LogCoshTest(testing.TestCase): + def setup(self): + y_true = np.asarray([[1, 9, 2], [-5, -2, 6]], dtype=np.float32) + y_pred = np.asarray([[4, 8, 12], [8, 1, 3]], dtype=np.float32) + + self.batch_size = 6 + error = y_pred - y_true + self.expected_losses = np.log((np.exp(error) + np.exp(-error)) / 2) + + self.y_true = y_true + self.y_pred = y_pred + + def test_config(self): + logcosh_obj = losses.LogCosh(reduction="sum", name="logcosh_loss") + self.assertEqual(logcosh_obj.name, "logcosh_loss") + self.assertEqual(logcosh_obj.reduction, "sum") + config = logcosh_obj.get_config() + self.assertEqual(config, {"name": "logcosh_loss", "reduction": "sum"}) + + def test_unweighted(self): + self.setup() + logcosh_obj = losses.LogCosh() + + loss = logcosh_obj(self.y_true, self.y_pred) + expected_loss = np.sum(self.expected_losses) / self.batch_size + self.assertAlmostEqual(loss, expected_loss, 3) + + def test_scalar_weighted(self): + self.setup() + logcosh_obj = losses.LogCosh() + sample_weight = 2.3 + + loss = logcosh_obj( + self.y_true, self.y_pred, sample_weight=sample_weight + ) + expected_loss = ( + sample_weight * np.sum(self.expected_losses) / self.batch_size + ) + self.assertAlmostEqual(loss, expected_loss, 3) + + # Verify we get the same output when the same input is given + loss_2 = logcosh_obj( + self.y_true, self.y_pred, sample_weight=sample_weight + ) + self.assertAlmostEqual(loss, loss_2, 3) + + def test_sample_weighted(self): + self.setup() + logcosh_obj = losses.LogCosh() + + sample_weight = np.asarray([1.2, 3.4]) + loss = logcosh_obj( + self.y_true, self.y_pred, sample_weight=sample_weight + ) + + expected_loss = np.multiply( + self.expected_losses, + np.asarray([1.2, 1.2, 1.2, 3.4, 3.4, 3.4]).reshape((2, 3)), + ) + expected_loss = np.sum(expected_loss) / self.batch_size + self.assertAlmostEqual(loss, expected_loss, 3) + + def test_timestep_weighted(self): + self.setup() + logcosh_obj = losses.LogCosh() + y_true = np.asarray([1, 9, 2, -5, -2, 6]).reshape(2, 3, 1) + y_pred = np.asarray([4, 8, 12, 8, 1, 3]).reshape(2, 3, 1) + error = y_pred - y_true + expected_losses = np.log((np.exp(error) + np.exp(-error)) / 2) + sample_weight = np.array([3, 6, 5, 0, 4, 2]).reshape((2, 3, 1)) + + loss = logcosh_obj( + y_true, + y_pred, + sample_weight=sample_weight, + ) + expected_loss = ( + np.sum(expected_losses * sample_weight) / self.batch_size + ) + self.assertAlmostEqual(loss, expected_loss, 3) + + def test_zero_weighted(self): + self.setup() + logcosh_obj = losses.LogCosh() + sample_weight = 0 + loss = logcosh_obj( + self.y_true, self.y_pred, sample_weight=sample_weight + ) + self.assertAlmostEqual(loss, 0.0, 3) + + def test_dtype_arg(self): + self.setup() + logcosh_obj = losses.LogCosh(dtype="bfloat16") + loss = logcosh_obj(self.y_true, self.y_pred) + self.assertDType(loss, "bfloat16") + + +class KLDivergenceTest(testing.TestCase): + def setup(self): + self.y_pred = np.asarray( + [0.4, 0.9, 0.12, 0.36, 0.3, 0.4], dtype=np.float32 + ).reshape((2, 3)) + self.y_true = np.asarray( + [0.5, 0.8, 0.12, 0.7, 0.43, 0.8], dtype=np.float32 + ).reshape((2, 3)) + + self.batch_size = 2 + self.expected_losses = np.multiply( + self.y_true, np.log(self.y_true / self.y_pred) + ) + + def test_config(self): + k_obj = losses.KLDivergence(reduction="sum", name="kld") + self.assertEqual(k_obj.name, "kld") + self.assertEqual(k_obj.reduction, "sum") + + def test_unweighted(self): + self.setup() + k_obj = losses.KLDivergence() + + loss = k_obj(self.y_true, self.y_pred) + expected_loss = np.sum(self.expected_losses) / self.batch_size + self.assertAlmostEqual(loss, expected_loss, 3) + + def test_scalar_weighted(self): + self.setup() + k_obj = losses.KLDivergence() + sample_weight = 2.3 + + loss = k_obj(self.y_true, self.y_pred, sample_weight=sample_weight) + expected_loss = ( + sample_weight * np.sum(self.expected_losses) / self.batch_size + ) + self.assertAlmostEqual(loss, expected_loss, 3) + + # Verify we get the same output when the same input is given + loss_2 = k_obj(self.y_true, self.y_pred, sample_weight=sample_weight) + self.assertAlmostEqual(loss, loss_2, 3) + + def test_sample_weighted(self): + self.setup() + k_obj = losses.KLDivergence() + sample_weight = np.asarray([1.2, 3.4], dtype=np.float32).reshape((2, 1)) + loss = k_obj(self.y_true, self.y_pred, sample_weight=sample_weight) + + expected_loss = np.multiply( + self.expected_losses, + np.asarray( + [1.2, 1.2, 1.2, 3.4, 3.4, 3.4], dtype=np.float32 + ).reshape(2, 3), + ) + expected_loss = np.sum(expected_loss) / self.batch_size + self.assertAlmostEqual(loss, expected_loss, 3) + + def test_timestep_weighted(self): + self.setup() + k_obj = losses.KLDivergence() + y_true = self.y_true.reshape(2, 3, 1) + y_pred = self.y_pred.reshape(2, 3, 1) + sample_weight = np.asarray([3, 6, 5, 0, 4, 2]).reshape(2, 3) + expected_losses = np.sum( + np.multiply(y_true, np.log(y_true / y_pred)), axis=-1 + ) + loss = k_obj(y_true, y_pred, sample_weight=sample_weight) + + num_timesteps = 3 + expected_loss = np.sum(expected_losses * sample_weight) / ( + self.batch_size * num_timesteps + ) + self.assertAlmostEqual(loss, expected_loss, 3) + + def test_zero_weighted(self): + self.setup() + k_obj = losses.KLDivergence() + loss = k_obj(self.y_true, self.y_pred, sample_weight=0) + self.assertAlmostEqual(loss, 0.0, 3) + + def test_dtype_arg(self): + self.setup() + k_obj = losses.KLDivergence(dtype="bfloat16") + loss = k_obj(self.y_true, self.y_pred) + self.assertDType(loss, "bfloat16") + + +class PoissonTest(testing.TestCase): + def setup(self): + self.y_pred = np.asarray([1, 9, 2, 5, 2, 6], dtype=np.float32).reshape( + (2, 3) + ) + self.y_true = np.asarray([4, 8, 12, 8, 1, 3], dtype=np.float32).reshape( + (2, 3) + ) + + self.batch_size = 6 + self.expected_losses = self.y_pred - np.multiply( + self.y_true, np.log(self.y_pred) + ) + + def test_config(self): + poisson_obj = losses.Poisson(reduction="sum", name="poisson") + self.assertEqual(poisson_obj.name, "poisson") + self.assertEqual(poisson_obj.reduction, "sum") + + def test_unweighted(self): + self.setup() + poisson_obj = losses.Poisson() + + loss = poisson_obj(self.y_true, self.y_pred) + expected_loss = np.sum(self.expected_losses) / self.batch_size + self.assertAlmostEqual(loss, expected_loss, 3) + + def test_scalar_weighted(self): + self.setup() + poisson_obj = losses.Poisson() + sample_weight = 2.3 + loss = poisson_obj( + self.y_true, self.y_pred, sample_weight=sample_weight + ) + expected_loss = ( + sample_weight * np.sum(self.expected_losses) / self.batch_size + ) + self.assertAlmostEqual(loss, expected_loss, 3) + self.assertAlmostEqual(loss, expected_loss, 3) + + # Verify we get the same output when the same input is given + loss_2 = poisson_obj( + self.y_true, self.y_pred, sample_weight=sample_weight + ) + self.assertAlmostEqual(loss, loss_2, 3) + + def test_sample_weighted(self): + self.setup() + poisson_obj = losses.Poisson() + + sample_weight = np.asarray([1.2, 3.4]).reshape((2, 1)) + loss = poisson_obj( + self.y_true, self.y_pred, sample_weight=sample_weight + ) + + expected_loss = np.multiply( + self.expected_losses, + np.asarray([1.2, 1.2, 1.2, 3.4, 3.4, 3.4]).reshape((2, 3)), + ) + expected_loss = np.sum(expected_loss) / self.batch_size + self.assertAlmostEqual(loss, expected_loss, 3) + + def test_timestep_weighted(self): + self.setup() + poisson_obj = losses.Poisson() + y_true = self.y_true.reshape(2, 3, 1) + y_pred = self.y_pred.reshape(2, 3, 1) + sample_weight = np.asarray([3, 6, 5, 0, 4, 2]).reshape(2, 3, 1) + expected_losses = y_pred - np.multiply(y_true, np.log(y_pred)) + + loss = poisson_obj( + y_true, + y_pred, + sample_weight=np.asarray(sample_weight).reshape((2, 3)), + ) + expected_loss = ( + np.sum(expected_losses * sample_weight) / self.batch_size + ) + self.assertAlmostEqual(loss, expected_loss, 3) + + def test_zero_weighted(self): + self.setup() + poisson_obj = losses.Poisson() + loss = poisson_obj(self.y_true, self.y_pred, sample_weight=0) + self.assertAlmostEqual(loss, 0.0, 3) + + def test_dtype_arg(self): + self.setup() + poisson_obj = losses.Poisson(dtype="bfloat16") + loss = poisson_obj(self.y_true, self.y_pred) + self.assertDType(loss, "bfloat16") + + +class BinaryCrossentropyTest(testing.TestCase): + def test_config(self): + self.run_class_serialization_test( + losses.BinaryCrossentropy(name="bce", axis=-1) + ) + + def test_all_correct_unweighted(self): + y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype="float32") + bce_obj = losses.BinaryCrossentropy() + loss = bce_obj(y_true, y_true) + self.assertAlmostEqual(loss, 0.0) + + # Test with logits. + logits = np.array( + [ + [10.0, -10.0, -10.0], + [-10.0, 10.0, -10.0], + [-10.0, -10.0, 10.0], + ] + ) + bce_obj = losses.BinaryCrossentropy(from_logits=True) + loss = bce_obj(y_true, logits) + self.assertAlmostEqual(loss, 0.0) + + def test_unweighted(self): + y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype="float32") + y_pred = np.array( + [[0.9, 0.1, 0.2], [0.3, 0.8, 0.1], [0.1, 0.2, 0.7]], dtype="float32" + ) + bce_obj = losses.BinaryCrossentropy() + loss = bce_obj(y_true, y_pred) + self.assertAllClose(loss, 0.20046903) + + y_true = np.array([1, 0, 1, 0]).reshape([2, 2]) + y_pred = np.array([1, 1, 1, 0], dtype=np.float32).reshape([2, 2]) + bce_obj = losses.BinaryCrossentropy() + loss = bce_obj(y_true, y_pred) + self.assertAlmostEqual(loss, 3.98559) + + # Test with logits. + y_true = np.array([[1, 0, 1], [0, 1, 1]]) + logits = np.array([[10.0, -10.0, 10.0], [10.0, 10.0, -10.0]]) + bce_obj = losses.BinaryCrossentropy(from_logits=True) + loss = bce_obj(y_true, logits) + self.assertAlmostEqual(loss, 3.3333) + + def test_scalar_weighted(self): + bce_obj = losses.BinaryCrossentropy() + y_true = np.array([1, 0, 1, 0]).reshape([2, 2]) + y_pred = np.array([1, 1, 1, 0], dtype="float32").reshape([2, 2]) + loss = bce_obj(y_true, y_pred, sample_weight=2.3) + self.assertAlmostEqual(loss, 9.1668) + + # Test with logits. + y_true = np.array([[1, 0, 1], [0, 1, 1]]) + logits = np.array([[10.0, -10.0, 10.0], [10.0, 10.0, -10.0]]) + bce_obj = losses.BinaryCrossentropy(from_logits=True) + loss = bce_obj(y_true, logits, sample_weight=2.3) + self.assertAlmostEqual(loss, 7.666) + + def test_sample_weighted(self): + bce_obj = losses.BinaryCrossentropy() + y_true = np.array([1, 0, 1, 0]).reshape([2, 2]) + y_pred = np.array([1, 1, 1, 0], dtype="float32").reshape([2, 2]) + sample_weight = np.array([1.2, 3.4]).reshape((2, 1)) + loss = bce_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAlmostEqual(loss, 4.7827) + + # Test with logits. + y_true = np.array([[1, 0, 1], [0, 1, 1]]) + logits = np.array([[10.0, -10.0, 10.0], [10.0, 10.0, -10.0]]) + weights = np.array([4, 3]) + bce_obj = losses.BinaryCrossentropy(from_logits=True) + loss = bce_obj(y_true, logits, sample_weight=weights) + self.assertAlmostEqual(loss, 10.0) + + def test_no_reduction(self): + y_true = np.array([[1, 0, 1], [0, 1, 1]]) + logits = np.array([[10.0, -10.0, 10.0], [10.0, 10.0, -10.0]]) + bce_obj = losses.BinaryCrossentropy(from_logits=True, reduction=None) + loss = bce_obj(y_true, logits) + self.assertAllClose(loss, [0.0, 6.666], atol=1e-3) + + def test_label_smoothing(self): + logits = np.array([[10.0, -10.0, -10.0]]) + y_true = np.array([[1, 0, 1]]) + label_smoothing = 0.1 + bce_obj = losses.BinaryCrossentropy( + from_logits=True, label_smoothing=label_smoothing + ) + loss = bce_obj(y_true, logits) + expected_value = (10.0 + 5.0 * label_smoothing) / 3.0 + self.assertAlmostEqual(loss, expected_value) + + def test_shape_mismatch(self): + y_true = np.array([[0], [1], [2]]) + y_pred = np.array( + [[0.9, 0.05, 0.05], [0.5, 0.89, 0.6], [0.05, 0.01, 0.94]] + ) + cce_obj = losses.BinaryCrossentropy() + with self.assertRaisesRegex(ValueError, "must have the same shape"): + cce_obj(y_true, y_pred) + + @pytest.mark.skipif( + backend.backend() == "torch", + reason="Torch doesn't support bfloat16 for BinaryCrossentropy", + ) + def test_dtype_arg(self): + y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype="float32") + y_pred = np.array( + [[0.9, 0.1, 0.2], [0.3, 0.8, 0.1], [0.1, 0.2, 0.7]], dtype="float32" + ) + bce_obj = losses.BinaryCrossentropy(dtype="bfloat16") + loss = bce_obj(y_true, y_pred) + self.assertDType(loss, "bfloat16") + + +class CategoricalCrossentropyTest(testing.TestCase): + def test_config(self): + self.run_class_serialization_test( + losses.CategoricalCrossentropy(name="cce", axis=-1) + ) + + def test_all_correct_unweighted(self): + y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype="int64") + y_pred = np.array( + [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], + dtype="float32", + ) + cce_obj = losses.CategoricalCrossentropy() + loss = cce_obj(y_true, y_pred) + self.assertAlmostEqual(loss, 0.0) + + # Test with logits. + logits = np.array( + [[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]] + ) + cce_obj = losses.CategoricalCrossentropy(from_logits=True) + loss = cce_obj(y_true, logits) + self.assertAlmostEqual(loss, 0.0) + + def test_unweighted(self): + cce_obj = losses.CategoricalCrossentropy() + y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + y_pred = np.array( + [[0.9, 0.05, 0.05], [0.5, 0.89, 0.6], [0.05, 0.01, 0.94]], + dtype="float32", + ) + loss = cce_obj(y_true, y_pred) + self.assertAlmostEqual(loss, 0.3239) + + # Test with logits. + logits = np.array([[8.0, 1.0, 1.0], [0.0, 9.0, 1.0], [2.0, 3.0, 5.0]]) + cce_obj = losses.CategoricalCrossentropy(from_logits=True) + loss = cce_obj(y_true, logits) + self.assertAlmostEqual(loss, 0.0573) + + def test_scalar_weighted(self): + cce_obj = losses.CategoricalCrossentropy() + y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + y_pred = np.array( + [[0.9, 0.05, 0.05], [0.5, 0.89, 0.6], [0.05, 0.01, 0.94]], + dtype="float32", + ) + loss = cce_obj(y_true, y_pred, sample_weight=2.3) + self.assertAlmostEqual(loss, 0.7449) + + # Test with logits. + logits = np.array([[8.0, 1.0, 1.0], [0.0, 9.0, 1.0], [2.0, 3.0, 5.0]]) + cce_obj = losses.CategoricalCrossentropy(from_logits=True) + loss = cce_obj(y_true, logits, sample_weight=2.3) + self.assertAlmostEqual(loss, 0.1317) + + def test_sample_weighted(self): + cce_obj = losses.CategoricalCrossentropy() + y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + y_pred = np.array( + [[0.9, 0.05, 0.05], [0.5, 0.89, 0.6], [0.05, 0.01, 0.94]], + dtype="float32", + ) + sample_weight = np.array([[1.2], [3.4], [5.6]]).reshape((3, 1)) + loss = cce_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAlmostEqual(loss, 1.0696) + + # Test with logits. + logits = np.array([[8.0, 1.0, 1.0], [0.0, 9.0, 1.0], [2.0, 3.0, 5.0]]) + cce_obj = losses.CategoricalCrossentropy(from_logits=True) + loss = cce_obj(y_true, logits, sample_weight=sample_weight) + self.assertAlmostEqual(loss, 0.31829) + + def test_no_reduction(self): + y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + logits = np.array([[8.0, 1.0, 1.0], [0.0, 9.0, 1.0], [2.0, 3.0, 5.0]]) + cce_obj = losses.CategoricalCrossentropy( + from_logits=True, reduction=None + ) + loss = cce_obj(y_true, logits) + self.assertAllClose((0.001822, 0.000459, 0.169846), loss) + + def test_label_smoothing(self): + logits = np.array([[100.0, -100.0, -100.0]]) + y_true = np.array([[1, 0, 0]]) + label_smoothing = 0.1 + cce_obj = losses.CategoricalCrossentropy( + from_logits=True, label_smoothing=label_smoothing + ) + loss = cce_obj(y_true, logits) + expected_value = 400.0 * label_smoothing / 3.0 + self.assertAlmostEqual(loss, expected_value) + + def test_label_smoothing_ndarray(self): + logits = np.asarray([[100.0, -100.0, -100.0]]) + y_true = np.asarray([[1, 0, 0]]) + label_smoothing = 0.1 + cce_obj = losses.CategoricalCrossentropy( + from_logits=True, label_smoothing=label_smoothing + ) + loss = cce_obj(y_true, logits) + expected_value = 400.0 * label_smoothing / 3.0 + self.assertAlmostEqual(loss, expected_value) + + def test_shape_mismatch(self): + y_true = np.array([[0], [1], [2]]) + y_pred = np.array( + [[0.9, 0.05, 0.05], [0.5, 0.89, 0.6], [0.05, 0.01, 0.94]] + ) + + cce_obj = losses.CategoricalCrossentropy() + with self.assertRaisesRegex(ValueError, "must have the same shape"): + cce_obj(y_true, y_pred) + + def test_dtype_arg(self): + y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype="int64") + y_pred = np.array( + [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], + dtype="float32", + ) + cce_obj = losses.CategoricalCrossentropy(dtype="bfloat16") + loss = cce_obj(y_true, y_pred) + self.assertDType(loss, "bfloat16") + + +class SparseCategoricalCrossentropyTest(testing.TestCase): + def test_config(self): + self.run_class_serialization_test( + losses.SparseCategoricalCrossentropy(name="scce") + ) + + def test_all_correct_unweighted(self): + y_true = np.array([[0], [1], [2]], dtype="int64") + y_pred = np.array( + [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], + dtype="float32", + ) + cce_obj = losses.SparseCategoricalCrossentropy() + loss = cce_obj(y_true, y_pred) + self.assertAlmostEqual(loss, 0.0, 3) + + # Test with logits. + logits = np.array( + [[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]] + ) + cce_obj = losses.SparseCategoricalCrossentropy(from_logits=True) + loss = cce_obj(y_true, logits) + self.assertAlmostEqual(loss, 0.0, 3) + + def test_unweighted(self): + cce_obj = losses.SparseCategoricalCrossentropy() + y_true = np.array([0, 1, 2]) + y_pred = np.array( + [[0.9, 0.05, 0.05], [0.5, 0.89, 0.6], [0.05, 0.01, 0.94]], + dtype="float32", + ) + loss = cce_obj(y_true, y_pred) + self.assertAlmostEqual(loss, 0.3239, 3) + + # Test with logits. + logits = np.array([[8.0, 1.0, 1.0], [0.0, 9.0, 1.0], [2.0, 3.0, 5.0]]) + cce_obj = losses.SparseCategoricalCrossentropy(from_logits=True) + loss = cce_obj(y_true, logits) + self.assertAlmostEqual(loss, 0.0573, 3) + + def test_scalar_weighted(self): + cce_obj = losses.SparseCategoricalCrossentropy() + y_true = np.array([[0], [1], [2]]) + y_pred = np.array( + [[0.9, 0.05, 0.05], [0.5, 0.89, 0.6], [0.05, 0.01, 0.94]], + dtype="float32", + ) + loss = cce_obj(y_true, y_pred, sample_weight=2.3) + self.assertAlmostEqual(loss, 0.7449, 3) + + # Test with logits. + logits = np.array([[8.0, 1.0, 1.0], [0.0, 9.0, 1.0], [2.0, 3.0, 5.0]]) + cce_obj = losses.SparseCategoricalCrossentropy(from_logits=True) + loss = cce_obj(y_true, logits, sample_weight=2.3) + self.assertAlmostEqual(loss, 0.1317, 3) + + def test_sample_weighted(self): + cce_obj = losses.SparseCategoricalCrossentropy() + y_true = np.array([[0], [1], [2]]) + y_pred = np.array( + [[0.9, 0.05, 0.05], [0.5, 0.89, 0.6], [0.05, 0.01, 0.94]], + dtype="float32", + ) + sample_weight = np.array([[1.2], [3.4], [5.6]]).reshape((3, 1)) + loss = cce_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAlmostEqual(loss, 1.0696, 3) + + # Test with logits. + logits = np.array([[8.0, 1.0, 1.0], [0.0, 9.0, 1.0], [2.0, 3.0, 5.0]]) + cce_obj = losses.SparseCategoricalCrossentropy(from_logits=True) + loss = cce_obj(y_true, logits, sample_weight=sample_weight) + self.assertAlmostEqual(loss, 0.31829, 3) + + def test_no_reduction(self): + y_true = np.array([[0], [1], [2]]) + logits = np.array([[8.0, 1.0, 1.0], [0.0, 9.0, 1.0], [2.0, 3.0, 5.0]]) + cce_obj = losses.SparseCategoricalCrossentropy( + from_logits=True, reduction=None + ) + loss = cce_obj(y_true, logits) + self.assertAllClose((0.001822, 0.000459, 0.169846), loss) + + def test_ignore_class(self): + y_true = np.array([[-1, 2]]) + logits = np.array([[[0.854, 0.698, 0.598], [0.088, 0.86, 0.018]]]) + cce_obj = losses.SparseCategoricalCrossentropy( + from_logits=True, ignore_class=-1, reduction=None + ) + loss = cce_obj(y_true, logits) + self.assertAllClose([[0.0, 1.480129]], loss) + + y_true = np.array([[[-1], [2]]]) + logits = np.array([[[0.854, 0.698, 0.598], [0.088, 0.86, 0.018]]]) + cce_obj = losses.SparseCategoricalCrossentropy( + from_logits=True, ignore_class=-1, reduction=None + ) + loss = cce_obj(y_true, logits) + self.assertAllClose([[0.0, 1.480129]], loss) + + def test_binary_segmentation(self): + y_true = np.array( + [[0, 1, 1, 0], [1, 0, 1, 0], [0, 0, 1, 1], [1, 1, 0, 1]] + ) + y_pred = np.array( + [ + [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0]], + [[0.0, 1.0], [1.0, 0.0], [0.0, 1.0], [1.0, 0.0]], + [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]], + [[0.0, 1.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]], + ] + ) + output = losses.SparseCategoricalCrossentropy()(y_true, y_pred) + self.assertAllClose(output, 0.0) + + y_true = np.array( + [[0, 1, 1, 0], [1, 0, 1, 0], [0, 0, 1, 1], [1, 1, 0, 1]] + ) + y_pred = np.array( + [ + [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [0.2, 0.8]], + [[0.0, 1.0], [1.0, 0.0], [0.0, 1.0], [1.0, 0.0]], + [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]], + [[0.0, 1.0], [0.0, 1.0], [1.0, 0.0], [0.6, 0.4]], + ] + ) + expected = np.array([-np.log(0.2), -np.log(0.4)]) + output = losses.SparseCategoricalCrossentropy()(y_true, y_pred) + self.assertAllClose(output, expected.sum() / 16.0) # 16 pixels + + def test_binary_segmentation_different_axis(self): + y_true = np.array( + [[0, 1, 1, 0], [1, 0, 1, 0], [0, 0, 1, 1], [1, 1, 0, 1]] + ) + y_pred = np.array( + [ + [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0]], + [[0.0, 1.0], [1.0, 0.0], [0.0, 1.0], [1.0, 0.0]], + [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]], + [[0.0, 1.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]], + ] + ) + y_pred_reshaped = np.moveaxis(y_pred, source=2, destination=0) + if backend.backend() == "tensorflow": + expected_message = ( + "Only axis=-1 is currently supported. Received: axis=0" + ) + escaped_message = re.escape(expected_message) + + with pytest.raises(ValueError, match=escaped_message): + losses.SparseCategoricalCrossentropy(axis=0)( + y_true, y_pred_reshaped + ) + elif backend.backend() == "jax": + expected_message = ( + "Arguments `target` and `output` " + "must have the same shape up until" + " the last dimension: target.shape=(4, 4)," + " output.shape=(2, 4, 4)" + ) + escaped_message = re.escape(expected_message) + + with pytest.raises(ValueError, match=escaped_message): + losses.SparseCategoricalCrossentropy(axis=0)( + y_true, y_pred_reshaped + ) + elif backend.backend() == "torch": + output = losses.SparseCategoricalCrossentropy(axis=0)( + y_true, y_pred_reshaped + ) + self.assertAllClose(output, 0.0) + + if backend.backend() == "torch": + y_true = np.array( + [[0, 1, 1, 0], [1, 0, 1, 0], [0, 0, 1, 1], [1, 1, 0, 1]] + ) + y_pred = np.array( + [ + [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [0.2, 0.8]], + [[0.0, 1.0], [1.0, 0.0], [0.0, 1.0], [1.0, 0.0]], + [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]], + [[0.0, 1.0], [0.0, 1.0], [1.0, 0.0], [0.6, 0.4]], + ] + ) + y_pred_reshaped = np.moveaxis(y_pred, source=2, destination=0) + expected = np.array([-np.log(0.2), -np.log(0.4)]) + output = losses.SparseCategoricalCrossentropy(axis=0)( + y_true, y_pred_reshaped + ) + self.assertAllClose(output, expected.sum() / 16.0) + + y_true = np.array([y_true, y_true, y_true]) + y_pred_reshaped = np.array( + [y_pred_reshaped, y_pred_reshaped, y_pred_reshaped] + ) + output = losses.SparseCategoricalCrossentropy(axis=1)( + y_true, y_pred_reshaped + ) + self.assertAllClose(output, expected.sum() / 16.0) + + def test_multi_class_segmentation(self): + y_true = np.array( + [[0, 1, 2, 0], [1, 0, 1, 0], [0, 0, 1, 1], [1, 1, 0, 1]] + ) + y_pred = np.array( + [ + [ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + [1.0, 0.0, 0.0], + ], + [ + [0.0, 1.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [1.0, 0.0, 0.0], + ], + [ + [1.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 1.0, 0.0], + ], + [ + [0.0, 1.0, 0.0], + [0.0, 1.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + ], + ] + ) + output = losses.SparseCategoricalCrossentropy()(y_true, y_pred) + self.assertAllClose(output, 0.0) + + y_true = np.array( + [[0, 1, 2, 0], [1, 0, 1, 0], [0, 0, 1, 1], [1, 1, 0, 1]] + ) + y_pred = np.array( + [ + [ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + [0.2, 0.0, 0.8], + ], + [ + [0.7, 0.3, 0.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [1.0, 0.0, 0.0], + ], + [ + [1.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 1.0, 0.0], + ], + [ + [0.0, 1.0, 0.0], + [0.0, 1.0, 0.0], + [0.5, 0.5, 0.0], + [0.0, 1.0, 0.0], + ], + ] + ) + expected = np.array( + [ + -np.log(0.2), + -np.log(0.3), + -np.log(0.5), + ] + ) + output = losses.SparseCategoricalCrossentropy()(y_true, y_pred) + self.assertAllClose(output, expected.sum() / 16.0) # 16 pixels + + def test_multi_class_segmentation_different_axis(self): + y_true = np.array( + [[0, 1, 2, 0], [1, 0, 1, 0], [0, 0, 1, 1], [1, 1, 0, 1]] + ) + y_pred = np.array( + [ + [ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + [1.0, 0.0, 0.0], + ], + [ + [0.0, 1.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [1.0, 0.0, 0.0], + ], + [ + [1.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 1.0, 0.0], + ], + [ + [0.0, 1.0, 0.0], + [0.0, 1.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + ], + ] + ) + y_pred_reshaped = np.moveaxis(y_pred, source=2, destination=0) + if backend.backend() == "tensorflow": + expected_message = ( + "Only axis=-1 is currently supported. Received: axis=0" + ) + escaped_message = re.escape(expected_message) + + with pytest.raises(ValueError, match=escaped_message): + losses.SparseCategoricalCrossentropy(axis=0)( + y_true, y_pred_reshaped + ) + elif backend.backend() == "jax": + expected_message = ( + "Arguments `target` and `output` " + "must have the same shape up until" + " the last dimension: target.shape=(4, 4)," + " output.shape=(3, 4, 4)" + ) + escaped_message = re.escape(expected_message) + + with pytest.raises(ValueError, match=escaped_message): + losses.SparseCategoricalCrossentropy(axis=0)( + y_true, y_pred_reshaped + ) + elif backend.backend() == "torch": + output = losses.SparseCategoricalCrossentropy(axis=0)( + y_true, y_pred_reshaped + ) + self.assertAllClose(output, 0.0) + + if backend.backend() == "torch": + y_true = np.array( + [[0, 1, 2, 0], [1, 0, 1, 0], [0, 0, 1, 1], [1, 1, 0, 1]] + ) + y_pred = np.array( + [ + [ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + [0.2, 0.0, 0.8], + ], + [ + [0.7, 0.3, 0.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [1.0, 0.0, 0.0], + ], + [ + [1.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 1.0, 0.0], + ], + [ + [0.0, 1.0, 0.0], + [0.0, 1.0, 0.0], + [0.5, 0.5, 0.0], + [0.0, 1.0, 0.0], + ], + ] + ) + expected = np.array( + [ + -np.log(0.2), + -np.log(0.3), + -np.log(0.5), + ] + ) + y_pred_reshaped = np.moveaxis(y_pred, source=2, destination=0) + output = losses.SparseCategoricalCrossentropy(axis=0)( + y_true, y_pred_reshaped + ) + self.assertAllClose(output, expected.sum() / 16.0) + y_true = np.array([y_true, y_true, y_true]) + y_pred_reshaped = np.array( + [y_pred_reshaped, y_pred_reshaped, y_pred_reshaped] + ) + output = losses.SparseCategoricalCrossentropy(axis=1)( + y_true, y_pred_reshaped + ) + self.assertAllClose(output, expected.sum() / 16.0) + + def test_dtype_arg(self): + y_true = np.array([[0], [1], [2]], dtype="int64") + y_pred = np.array( + [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], + dtype="float32", + ) + cce_obj = losses.SparseCategoricalCrossentropy(dtype="bfloat16") + loss = cce_obj(y_true, y_pred) + self.assertDType(loss, "bfloat16") + + +class BinaryFocalCrossentropyTest(testing.TestCase): + def test_config(self): + self.run_class_serialization_test( + losses.BinaryFocalCrossentropy(name="bfce") + ) + + def test_all_correct_unweighted(self): + y_true = np.array( + [ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1], + ], + dtype="float32", + ) + obj = losses.BinaryFocalCrossentropy(gamma=1.5) + loss = obj(y_true, y_true) + self.assertAlmostEqual(loss, 0.0, 3) + + # Test with logits. + logits = np.array( + [ + [100.0, -100.0, -100.0], + [-100.0, 100.0, -100.0], + [-100.0, -100.0, 100.0], + ] + ) + obj = losses.BinaryFocalCrossentropy(gamma=2.0, from_logits=True) + loss = obj(y_true, logits) + self.assertAlmostEqual(loss, 0.0, 3) + + def test_unweighted(self): + y_true = np.asarray([1, 0, 1, 0]).reshape([2, 2]) + y_pred = np.asarray([0.9, 0.8, 0.7, 0.2], dtype=np.float32).reshape( + [2, 2] + ) + obj = losses.BinaryFocalCrossentropy(gamma=2.0) + loss = obj(y_true, y_pred) + self.assertAlmostEqual(loss, 0.268, 3) + + # Test with logits. + y_true = np.array([[1, 1, 0], [0, 1, 0]], dtype="float32") + logits = np.array([[1.5, -2.7, 2.9], [-3.8, 1.2, -4.5]]) + obj = losses.BinaryFocalCrossentropy(gamma=3.0, from_logits=True) + loss = obj(y_true, logits) + self.assertAlmostEqual(loss, 0.799, 3) + + def test_scalar_weighted(self): + y_true = np.asarray([1, 0, 1, 0]).reshape([2, 2]) + y_pred = np.asarray([0.9, 0.8, 0.7, 0.2], dtype=np.float32).reshape( + [2, 2] + ) + obj = losses.BinaryFocalCrossentropy(gamma=2.0) + loss = obj(y_true, y_pred, sample_weight=1.23) + self.assertAlmostEqual(loss, 0.3296, 3) + + # Test with logits. + y_true = np.array([[1, 1, 0], [0, 1, 0]], dtype="float32") + logits = np.array([[1.5, -2.7, 2.9], [-3.8, 1.2, -4.5]]) + obj = losses.BinaryFocalCrossentropy(gamma=3.0, from_logits=True) + loss = obj(y_true, logits, sample_weight=3.21) + self.assertAlmostEqual(loss, 2.565, 3) + + def test_sample_weighted(self): + y_true = np.asarray([1, 0, 1, 0]).reshape([2, 2]) + y_pred = np.asarray([0.9, 0.8, 0.7, 0.2], dtype=np.float32).reshape( + [2, 2] + ) + sample_weight = np.array([1.2, 3.4]).reshape((2, 1)) + obj = losses.BinaryFocalCrossentropy(gamma=2.0) + loss = obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAlmostEqual(loss, 0.34415, 3) + + # Test with logits. + y_true = np.array([[1, 1, 0], [0, 1, 0]], dtype="float32") + logits = np.array([[1.5, -2.7, 2.9], [-3.8, 1.2, -4.5]]) + obj = losses.BinaryFocalCrossentropy(gamma=3.0, from_logits=True) + loss = obj(y_true, logits, sample_weight=sample_weight) + self.assertAlmostEqual(loss, 0.95977, 3) + + def test_no_reduction(self): + y_true = np.asarray([1, 0, 1, 0]).reshape([2, 2]) + y_pred = np.asarray([0.9, 0.8, 0.7, 0.2], dtype=np.float32).reshape( + [2, 2] + ) + obj = losses.BinaryFocalCrossentropy( + gamma=2.0, + reduction=None, + ) + loss = obj(y_true, y_pred) + self.assertAllClose(loss, (0.515547, 0.020513)) + + @pytest.mark.skipif( + backend.backend() == "torch", + reason="Torch doesn't support bfloat16 for BinaryFocalCrossentropy", + ) + def test_dtype_arg(self): + y_true = np.asarray([1, 0, 1, 0]).reshape([2, 2]) + y_pred = np.asarray([0.9, 0.8, 0.7, 0.2], dtype=np.float32).reshape( + [2, 2] + ) + obj = losses.BinaryFocalCrossentropy(dtype="bfloat16") + loss = obj(y_true, y_pred) + self.assertDType(loss, "bfloat16") + + +class CategoricalFocalCrossentropyTest(testing.TestCase): + def test_config(self): + self.run_class_serialization_test( + losses.CategoricalFocalCrossentropy(name="cfce") + ) + + def test_all_correct_unweighted(self): + y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype="int64") + y_pred = np.array( + [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], + dtype="float32", + ) + cce_obj = losses.CategoricalFocalCrossentropy(alpha=0.25, gamma=2.0) + loss = cce_obj(y_true, y_pred) + self.assertAlmostEqual(loss, 0.0, 3) + + # Test with logits. + logits = np.array( + [[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]] + ) + cce_obj = losses.CategoricalFocalCrossentropy(from_logits=True) + loss = cce_obj(y_true, logits) + self.assertAlmostEqual(loss, 0.0, 3) + + def test_unweighted(self): + cce_obj = losses.CategoricalFocalCrossentropy() + y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + y_pred = np.array( + [[0.9, 0.05, 0.05], [0.5, 0.89, 0.6], [0.05, 0.01, 0.94]], + dtype="float32", + ) + loss = cce_obj(y_true, y_pred) + self.assertAlmostEqual(loss, 0.02059, 3) + + # Test with logits. + logits = np.array([[8.0, 1.0, 1.0], [0.0, 9.0, 1.0], [2.0, 3.0, 5.0]]) + cce_obj = losses.CategoricalFocalCrossentropy(from_logits=True) + loss = cce_obj(y_true, logits) + self.assertAlmostEqual(loss, 0.000345, 3) + + def test_scalar_weighted(self): + cce_obj = losses.CategoricalFocalCrossentropy() + y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + y_pred = np.array( + [[0.9, 0.05, 0.05], [0.5, 0.89, 0.6], [0.05, 0.01, 0.94]], + dtype="float32", + ) + loss = cce_obj(y_true, y_pred, sample_weight=2.3) + self.assertAlmostEqual(loss, 0.047368, 3) + + # Test with logits. + logits = np.array([[8.0, 1.0, 1.0], [0.0, 9.0, 1.0], [2.0, 3.0, 5.0]]) + cce_obj = losses.CategoricalFocalCrossentropy(from_logits=True) + loss = cce_obj(y_true, logits, sample_weight=2.3) + self.assertAlmostEqual(loss, 0.000794, 4) + + def test_sample_weighted(self): + cce_obj = losses.CategoricalFocalCrossentropy() + y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + y_pred = np.array( + [[0.9, 0.05, 0.05], [0.5, 0.89, 0.6], [0.05, 0.01, 0.94]], + dtype="float32", + ) + sample_weight = np.array([[1.2], [3.4], [5.6]]).reshape((3, 1)) + loss = cce_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAlmostEqual(loss, 0.06987, 3) + + # Test with logits. + logits = np.array([[8.0, 1.0, 1.0], [0.0, 9.0, 1.0], [2.0, 3.0, 5.0]]) + cce_obj = losses.CategoricalFocalCrossentropy(from_logits=True) + loss = cce_obj(y_true, logits, sample_weight=sample_weight) + self.assertAlmostEqual(loss, 0.001933, 3) + + def test_no_reduction(self): + y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + logits = np.array([[8.0, 1.0, 1.0], [0.0, 9.0, 1.0], [2.0, 3.0, 5.0]]) + cce_obj = losses.CategoricalFocalCrossentropy( + from_logits=True, reduction=None + ) + loss = cce_obj(y_true, logits) + self.assertAllClose( + (1.5096224e-09, 2.4136547e-11, 1.0360638e-03), + loss, + ) + + def test_label_smoothing(self): + logits = np.array([[4.9, -0.5, 2.05]]) + y_true = np.array([[1, 0, 0]]) + label_smoothing = 0.1 + + cce_obj = losses.CategoricalFocalCrossentropy( + from_logits=True, label_smoothing=label_smoothing + ) + loss = cce_obj(y_true, logits) + + expected_value = 0.06685 + self.assertAlmostEqual(loss, expected_value, 3) + + def test_dtype_arg(self): + logits = np.array([[4.9, -0.5, 2.05]]) + y_true = np.array([[1, 0, 0]]) + cce_obj = losses.CategoricalFocalCrossentropy( + from_logits=True, dtype="bfloat16" + ) + loss = cce_obj(y_true, logits) + self.assertDType(loss, "bfloat16") + + +class CTCTest(testing.TestCase): + def test_config(self): + self.run_class_serialization_test(losses.CTC(name="myctc")) + + def test_correctness(self): + logits = (np.arange(24).reshape((2, 4, 3)).astype("float32") - 12) / 100 + y_true = np.array(([[1, 2, 1, 0], [1, 2, 0, 2]])) + output = losses.CTC()(y_true, logits) + self.assertAllClose(output, 2.448645) + + def test_dtype_arg(self): + logits = (np.arange(24).reshape((2, 4, 3)).astype("float32") - 12) / 100 + y_true = np.array(([[1, 2, 1, 0], [1, 2, 0, 2]])) + output = losses.CTC(dtype="bfloat16")(y_true, logits) + self.assertDType(output, "bfloat16") + + +class DiceTest(testing.TestCase): + def test_config(self): + self.run_class_serialization_test(losses.Dice(name="mydice")) + + def test_correctness(self): + y_true = np.array(([[1, 2], [1, 2]])) + y_pred = np.array(([[4, 1], [6, 1]])) + output = losses.Dice()(y_true, y_pred) + self.assertAllClose(output, -0.55555546) + + def test_binary_segmentation(self): + y_true = np.array( + ([[1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]]) + ) + y_pred = np.array( + ([[0, 1, 0, 1], [1, 0, 1, 1], [0, 1, 0, 1], [1, 0, 1, 1]]) + ) + output = losses.Dice()(y_true, y_pred) + self.assertAllClose(output, 0.77777773) + + def test_binary_segmentation_with_axis(self): + y_true = np.array( + [[[[1.0], [1.0]], [[0.0], [0.0]]], [[[1.0], [1.0]], [[0.0], [0.0]]]] + ) + y_pred = np.array( + [[[[0.0], [1.0]], [[0.0], [1.0]]], [[[0.4], [0.0]], [[0.0], [0.9]]]] + ) + output = losses.Dice(axis=(1, 2, 3), reduction=None)(y_true, y_pred) + self.assertAllClose(output, [0.5, 0.75757575]) + + def test_dtype_arg(self): + y_true = np.array(([[1, 2], [1, 2]])) + y_pred = np.array(([[4, 1], [6, 1]])) + output = losses.Dice(dtype="bfloat16")(y_true, y_pred) + self.assertDType(output, "bfloat16") + + +class TverskyTest(testing.TestCase): + def test_config(self): + self.run_class_serialization_test(losses.Tversky(name="mytversky")) + + def test_correctness(self): + y_true = np.array(([[1, 2], [1, 2]])) + y_pred = np.array(([[4, 1], [6, 1]])) + output = losses.Tversky()(y_true, y_pred) + self.assertAllClose(output, -0.55555546) + + def test_correctness_custom_coefficients(self): + y_true = np.array(([[1, 2], [1, 2]])) + y_pred = np.array(([[4, 1], [6, 1]])) + output = losses.Tversky(alpha=0.2, beta=0.8)(y_true, y_pred) + self.assertAllClose(output, -0.29629636) + + def test_binary_segmentation(self): + y_true = np.array( + ([[1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]]) + ) + y_pred = np.array( + ([[0, 1, 0, 1], [1, 0, 1, 1], [0, 1, 0, 1], [1, 0, 1, 1]]) + ) + output = losses.Tversky()(y_true, y_pred) + self.assertAllClose(output, 0.77777773) + + def test_binary_segmentation_with_axis(self): + y_true = np.array( + [[[[1.0], [1.0]], [[0.0], [0.0]]], [[[1.0], [1.0]], [[0.0], [0.0]]]] + ) + y_pred = np.array( + [[[[0.0], [1.0]], [[0.0], [1.0]]], [[[0.4], [0.0]], [[0.0], [0.9]]]] + ) + output = losses.Tversky(axis=(1, 2, 3), reduction=None)(y_true, y_pred) + self.assertAllClose(output, [0.5, 0.75757575]) + + def test_binary_segmentation_custom_coefficients(self): + y_true = np.array( + ([[1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]]) + ) + y_pred = np.array( + ([[0, 1, 0, 1], [1, 0, 1, 1], [0, 1, 0, 1], [1, 0, 1, 1]]) + ) + output = losses.Tversky(alpha=0.2, beta=0.8)(y_true, y_pred) + self.assertAllClose(output, 0.7916667) + + def test_binary_segmentation_custom_coefficients_with_axis(self): + y_true = np.array( + [[[[1.0], [1.0]], [[0.0], [0.0]]], [[[1.0], [1.0]], [[0.0], [0.0]]]] + ) + y_pred = np.array( + [[[[0.0], [1.0]], [[0.0], [1.0]]], [[[0.4], [0.0]], [[0.0], [0.9]]]] + ) + output = losses.Tversky( + alpha=0.2, beta=0.8, axis=(1, 2, 3), reduction=None + )(y_true, y_pred) + self.assertAllClose(output, [0.5, 0.7222222]) + + def test_dtype_arg(self): + y_true = np.array(([[1, 2], [1, 2]])) + y_pred = np.array(([[4, 1], [6, 1]])) + output = losses.Tversky(dtype="bfloat16")(y_true, y_pred) + self.assertDType(output, "bfloat16") + + +class CircleTest(testing.TestCase): + def setup(self): + self.y_true = np.array([1, 1, 2, 2, 3]) + self.y_pred = np.array( + [ + [0.70014004, -0.42008403, 0.14002801, 0.56011203], + [0.17609018, 0.70436073, -0.52827054, 0.44022545], + [-0.34050261, 0.25537696, -0.68100522, 0.59587957], + [0.32163376, -0.75047877, 0.53605627, -0.21442251], + [0.51261459, -0.34174306, 0.17087153, 0.76892189], + ] + ) + self.ref_labels = np.array([1, 1, 2, 2, 3, 4]) + self.ref_embeddings = np.array( + [ + [0.40824829, -0.54433105, 0.27216553, 0.68041382], + [0.76376261, 0.10910895, -0.54554473, 0.32732684], + [-0.74420841, 0.24806947, 0.49613894, -0.3721042], + [0.52981294, -0.13245324, 0.79471941, -0.26490647], + [0.54554473, -0.32732684, 0.10910895, 0.76376261], + [-0.27216553, 0.68041382, 0.40824829, -0.54433105], + ] + ) + + def test_config(self): + self.run_class_serialization_test( + losses.Circle(name="mycircle", gamma=80.0, margin=0.4) + ) + + def test_correctness(self): + self.setup() + circle_loss = losses.Circle(gamma=80.0, margin=0.4) + loss = circle_loss(self.y_true, self.y_pred) + self.assertAlmostEqual(loss, 188.3883) + + circle_loss = losses.Circle(gamma=256, margin=0.25) + loss = circle_loss(self.y_true, self.y_pred) + self.assertAlmostEqual(loss, 652.7617) + + loss = losses.circle( + self.y_true, + self.y_pred, + ref_labels=self.ref_labels, + ref_embeddings=self.ref_embeddings, + gamma=80.0, + margin=0.4, + remove_diagonal=False, + ) + + self.assertAllClose( + loss, (61.5844, 94.3465, 276.9344, 90.9873, 48.8963) + ) + + def test_correctness_weighted(self): + self.setup() + sample_weight = np.array([2.0, 2.0, 1.0, 1.0, 0.5]) + circle_loss = losses.Circle(gamma=80.0, margin=0.4) + loss = circle_loss( + self.y_true, self.y_pred, sample_weight=sample_weight + ) + self.assertAlmostEqual(loss, 244.91918) + + def test_no_reduction(self): + self.setup() + circle_loss = losses.Circle(gamma=80.0, margin=0.4, reduction=None) + loss = circle_loss(self.ref_labels, self.ref_embeddings) + + self.assertAllClose( + loss, [82.9116, 36.7942, 92.4590, 52.6798, 0.0, 0.0] + ) + + def test_sum_reduction(self): + self.setup() + circle_loss = losses.Circle(gamma=80.0, margin=0.4, reduction="sum") + loss = circle_loss(self.ref_labels, self.ref_embeddings) + + self.assertAlmostEqual(loss, 264.845) + + def test_mean_with_sample_weight_reduction(self): + self.setup() + sample_weight = np.array([2.0, 2.0, 1.0, 1.0, 0.5]) + circle_loss = losses.Circle( + gamma=80.0, margin=0.4, reduction="mean_with_sample_weight" + ) + loss = circle_loss( + self.y_true, self.y_pred, sample_weight=sample_weight + ) + self.assertAlmostEqual(loss, 163.27948) + + def test_dtype_arg(self): + self.setup() + circle_loss = losses.Circle(dtype="bfloat16") + loss = circle_loss(self.y_true, self.y_pred) + self.assertDType(loss, "bfloat16") + + +class CategoricalGeneralizedCrossEntropyTest(testing.TestCase): + def test_config(self): + self.run_class_serialization_test( + losses.CategoricalGeneralizedCrossEntropy(name="gce") + ) + self.run_class_serialization_test( + losses.CategoricalGeneralizedCrossEntropy(q=0.1, name="gce") + ) + + def test_basic_correctness_for_binary(self): + y_true = np.array([0, 1, 0, 1]) + y_pred = np.array([[0.7, 0.3], [0.2, 0.8], [0.6, 0.4], [0.4, 0.6]]) + # Calculate expected GCE loss manually + # For q=0.5: + # First sample (class 0): gce = (1 - 0.7^0.5) / 0.5 + # Second sample (class 1): gce = (1 - 0.8^0.5) / 0.5 + # Third sample (class 0): gce = (1 - 0.6^0.5) / 0.5 + # Fourth sample (class 1): gce = (1 - 0.6^0.5) / 0.5 + expected = np.array( + [ + (1 - np.power(0.7, 0.5)) / 0.5, + (1 - np.power(0.8, 0.5)) / 0.5, + (1 - np.power(0.6, 0.5)) / 0.5, + (1 - np.power(0.6, 0.5)) / 0.5, + ] + ) + output = losses.CategoricalGeneralizedCrossEntropy()(y_true, y_pred) + self.assertAllClose(output, expected.sum() / len(expected)) + + expected_q_08 = np.array( + [ + (1 - np.power(0.7, 0.8)) / 0.8, + (1 - np.power(0.8, 0.8)) / 0.8, + (1 - np.power(0.6, 0.8)) / 0.8, + (1 - np.power(0.6, 0.8)) / 0.8, + ] + ) + output = losses.CategoricalGeneralizedCrossEntropy(q=0.8)( + y_true, y_pred + ) + self.assertAllClose(output, expected_q_08.sum() / len(expected_q_08)) + + def test_basic_correctness_for_multi_class(self): + y_true = np.array([0, 1, 0, 1]) + y_pred = np.array( + [[0.7, 0.3, 0.0], [0.2, 0.2, 0.6], [0.6, 0.4, 0.0], [0.2, 0.2, 0.6]] + ) + # Calculate expected GCE loss manually + # For q=0.5: + # First sample (class 0): gce = (1 - 0.7^0.5) / 0.5 + # Second sample (class 1): gce = (1 - 0^0.5) / 0.5 + # Third sample (class 0): gce = (1 - 0.6^0.5) / 0.5 + # Fourth sample (class 1): gce = (1 - 0.0^0.5) / 0.5 + expected = np.array( + [ + (1 - np.power(0.7, 0.5)) / 0.5, + (1 - np.power(0.2, 0.5)) / 0.5, + (1 - np.power(0.6, 0.5)) / 0.5, + (1 - np.power(0.2, 0.5)) / 0.5, + ] + ) + output = losses.CategoricalGeneralizedCrossEntropy()(y_true, y_pred) + self.assertAllClose(output, expected.sum() / len(expected)) + + expected_q_08 = np.array( + [ + (1 - np.power(0.7, 0.8)) / 0.8, + (1 - np.power(0.2, 0.8)) / 0.8, + (1 - np.power(0.6, 0.8)) / 0.8, + (1 - np.power(0.2, 0.8)) / 0.8, + ] + ) + output = losses.CategoricalGeneralizedCrossEntropy(q=0.8)( + y_true, y_pred + ) + self.assertAllClose(output, expected_q_08.sum() / len(expected_q_08)) + + def test_binary_segmentation(self): + y_true = np.array( + [[0, 1, 1, 0], [1, 0, 1, 0], [0, 0, 1, 1], [1, 1, 0, 1]] + ) + y_pred = np.array( + [ + [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0]], + [[0.0, 1.0], [1.0, 0.0], [0.0, 1.0], [1.0, 0.0]], + [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]], + [[0.0, 1.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]], + ] + ) + output = losses.CategoricalGeneralizedCrossEntropy(q=0.5)( + y_true, y_pred + ) + self.assertAllClose(output, 0.0) + + y_true = np.array( + [[0, 1, 1, 0], [1, 0, 1, 0], [0, 0, 1, 1], [1, 1, 0, 1]] + ) + y_pred = np.array( + [ + [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [0.2, 0.8]], + [[0.0, 1.0], [1.0, 0.0], [0.0, 1.0], [1.0, 0.0]], + [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]], + [[0.0, 1.0], [0.0, 1.0], [1.0, 0.0], [0.6, 0.4]], + ] + ) + expected = np.array( + [ + (1 - np.power(0.2, 0.5)) / 0.5, + (1 - np.power(0.4, 0.5)) / 0.5, + ] + ) + output = losses.CategoricalGeneralizedCrossEntropy(q=0.5)( + y_true, y_pred + ) + self.assertAllClose(output, expected.sum() / 16.0) # 16 pixels + + def test_multi_class_segmentation(self): + y_true = np.array( + [[0, 1, 2, 0], [1, 0, 1, 0], [0, 0, 1, 1], [1, 1, 0, 1]] + ) + y_pred = np.array( + [ + [ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + [1.0, 0.0, 0.0], + ], + [ + [0.0, 1.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [1.0, 0.0, 0.0], + ], + [ + [1.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 1.0, 0.0], + ], + [ + [0.0, 1.0, 0.0], + [0.0, 1.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + ], + ] + ) + output = losses.CategoricalGeneralizedCrossEntropy(q=0.5)( + y_true, y_pred + ) + self.assertAllClose(output, 0.0) + + y_true = np.array( + [[0, 1, 2, 0], [1, 0, 1, 0], [0, 0, 1, 1], [1, 1, 0, 1]] + ) + y_pred = np.array( + [ + [ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + [0.2, 0.0, 0.8], + ], + [ + [1.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [1.0, 0.0, 0.0], + ], + [ + [1.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 1.0, 0.0], + ], + [ + [0.0, 1.0, 0.0], + [0.0, 1.0, 0.0], + [0.5, 0.5, 0.0], + [0.0, 1.0, 0.0], + ], + ] + ) + expected = np.array( + [ + (1 - np.power(0.2, 0.5)) / 0.5, + (1 - np.power(0.0, 0.5)) / 0.5, + (1 - np.power(0.5, 0.5)) / 0.5, + ] + ) + output = losses.CategoricalGeneralizedCrossEntropy(q=0.5)( + y_true, y_pred + ) + self.assertAllClose(output, expected.sum() / 16.0) # 16 pixels + + def test_dtype_arg(self): + y_true = np.array([0, 1, 0, 1]) + y_pred = np.array([[0.7, 0.3], [0.2, 0.8], [0.6, 0.4], [0.4, 0.6]]) + output = losses.CategoricalGeneralizedCrossEntropy(dtype="bfloat16")( + y_true, y_pred + ) + self.assertDType(output, "bfloat16") diff --git a/keras/src/metrics/__init__.py b/keras/src/metrics/__init__.py new file mode 100644 index 000000000000..4cb9dc42cd5c --- /dev/null +++ b/keras/src/metrics/__init__.py @@ -0,0 +1,211 @@ +import inspect + +from keras.src.api_export import keras_export +from keras.src.metrics.accuracy_metrics import Accuracy +from keras.src.metrics.accuracy_metrics import BinaryAccuracy +from keras.src.metrics.accuracy_metrics import CategoricalAccuracy +from keras.src.metrics.accuracy_metrics import SparseCategoricalAccuracy +from keras.src.metrics.accuracy_metrics import SparseTopKCategoricalAccuracy +from keras.src.metrics.accuracy_metrics import TopKCategoricalAccuracy +from keras.src.metrics.confusion_metrics import AUC +from keras.src.metrics.confusion_metrics import FalseNegatives +from keras.src.metrics.confusion_metrics import FalsePositives +from keras.src.metrics.confusion_metrics import Precision +from keras.src.metrics.confusion_metrics import PrecisionAtRecall +from keras.src.metrics.confusion_metrics import Recall +from keras.src.metrics.confusion_metrics import RecallAtPrecision +from keras.src.metrics.confusion_metrics import SensitivityAtSpecificity +from keras.src.metrics.confusion_metrics import SpecificityAtSensitivity +from keras.src.metrics.confusion_metrics import TrueNegatives +from keras.src.metrics.confusion_metrics import TruePositives +from keras.src.metrics.correlation_metrics import ConcordanceCorrelation +from keras.src.metrics.correlation_metrics import PearsonCorrelation +from keras.src.metrics.f_score_metrics import F1Score +from keras.src.metrics.f_score_metrics import FBetaScore +from keras.src.metrics.hinge_metrics import CategoricalHinge +from keras.src.metrics.hinge_metrics import Hinge +from keras.src.metrics.hinge_metrics import SquaredHinge +from keras.src.metrics.iou_metrics import BinaryIoU +from keras.src.metrics.iou_metrics import IoU +from keras.src.metrics.iou_metrics import MeanIoU +from keras.src.metrics.iou_metrics import OneHotIoU +from keras.src.metrics.iou_metrics import OneHotMeanIoU +from keras.src.metrics.metric import Metric +from keras.src.metrics.probabilistic_metrics import BinaryCrossentropy +from keras.src.metrics.probabilistic_metrics import CategoricalCrossentropy +from keras.src.metrics.probabilistic_metrics import KLDivergence +from keras.src.metrics.probabilistic_metrics import Poisson +from keras.src.metrics.probabilistic_metrics import ( + SparseCategoricalCrossentropy, +) +from keras.src.metrics.reduction_metrics import Mean +from keras.src.metrics.reduction_metrics import MeanMetricWrapper +from keras.src.metrics.reduction_metrics import Sum +from keras.src.metrics.regression_metrics import CosineSimilarity +from keras.src.metrics.regression_metrics import LogCoshError +from keras.src.metrics.regression_metrics import MeanAbsoluteError +from keras.src.metrics.regression_metrics import MeanAbsolutePercentageError +from keras.src.metrics.regression_metrics import MeanSquaredError +from keras.src.metrics.regression_metrics import MeanSquaredLogarithmicError +from keras.src.metrics.regression_metrics import R2Score +from keras.src.metrics.regression_metrics import RootMeanSquaredError +from keras.src.saving import serialization_lib +from keras.src.utils.naming import to_snake_case + +ALL_OBJECTS = { + # Base + Metric, + Mean, + Sum, + MeanMetricWrapper, + # Regression + MeanSquaredError, + RootMeanSquaredError, + MeanAbsoluteError, + MeanAbsolutePercentageError, + MeanSquaredLogarithmicError, + CosineSimilarity, + LogCoshError, + R2Score, + # Classification + AUC, + FalseNegatives, + FalsePositives, + Precision, + PrecisionAtRecall, + Recall, + RecallAtPrecision, + SensitivityAtSpecificity, + SpecificityAtSensitivity, + TrueNegatives, + TruePositives, + # Correlation + ConcordanceCorrelation, + PearsonCorrelation, + # Hinge + Hinge, + SquaredHinge, + CategoricalHinge, + # Probabilistic + KLDivergence, + Poisson, + BinaryCrossentropy, + CategoricalCrossentropy, + SparseCategoricalCrossentropy, + # Accuracy + Accuracy, + BinaryAccuracy, + CategoricalAccuracy, + SparseCategoricalAccuracy, + TopKCategoricalAccuracy, + SparseTopKCategoricalAccuracy, + # F-Score + F1Score, + FBetaScore, + # IoU + IoU, + BinaryIoU, + MeanIoU, + OneHotIoU, + OneHotMeanIoU, +} +ALL_OBJECTS_DICT = {cls.__name__: cls for cls in ALL_OBJECTS} +ALL_OBJECTS_DICT.update( + {to_snake_case(cls.__name__): cls for cls in ALL_OBJECTS} +) +# TODO: Align with `tf.keras` and set the name attribute of metrics +# with the key name. Currently it uses default name of class definitions. +ALL_OBJECTS_DICT.update( + { + "bce": BinaryCrossentropy, + "BCE": BinaryCrossentropy, + "mse": MeanSquaredError, + "MSE": MeanSquaredError, + "mae": MeanAbsoluteError, + "MAE": MeanAbsoluteError, + "mape": MeanAbsolutePercentageError, + "MAPE": MeanAbsolutePercentageError, + "msle": MeanSquaredLogarithmicError, + "MSLE": MeanSquaredLogarithmicError, + } +) + + +@keras_export("keras.metrics.serialize") +def serialize(metric): + """Serializes metric function or `Metric` instance. + + Args: + metric: A Keras `Metric` instance or a metric function. + + Returns: + Metric configuration dictionary. + """ + return serialization_lib.serialize_keras_object(metric) + + +@keras_export("keras.metrics.deserialize") +def deserialize(config, custom_objects=None): + """Deserializes a serialized metric class/function instance. + + Args: + config: Metric configuration. + custom_objects: Optional dictionary mapping names (strings) + to custom objects (classes and functions) to be + considered during deserialization. + + Returns: + A Keras `Metric` instance or a metric function. + """ + return serialization_lib.deserialize_keras_object( + config, + module_objects=ALL_OBJECTS_DICT, + custom_objects=custom_objects, + ) + + +@keras_export("keras.metrics.get") +def get(identifier): + """Retrieves a Keras metric as a `function`/`Metric` class instance. + + The `identifier` may be the string name of a metric function or class. + + >>> metric = metrics.get("categorical_crossentropy") + >>> type(metric) + + >>> metric = metrics.get("CategoricalCrossentropy") + >>> type(metric) + + + You can also specify `config` of the metric to this function by passing dict + containing `class_name` and `config` as an identifier. Also note that the + `class_name` must map to a `Metric` class + + >>> identifier = {"class_name": "CategoricalCrossentropy", + ... "config": {"from_logits": True}} + >>> metric = metrics.get(identifier) + >>> type(metric) + + + Args: + identifier: A metric identifier. One of None or string name of a metric + function/class or metric configuration dictionary or a metric + function or a metric class instance + + Returns: + A Keras metric as a `function`/ `Metric` class instance. + """ + if identifier is None: + return None + if isinstance(identifier, dict): + obj = deserialize(identifier) + elif isinstance(identifier, str): + obj = ALL_OBJECTS_DICT.get(identifier, None) + else: + obj = identifier + if callable(obj): + if inspect.isclass(obj): + obj = obj() + return obj + else: + raise ValueError(f"Could not interpret metric identifier: {identifier}") diff --git a/keras/src/metrics/accuracy_metrics.py b/keras/src/metrics/accuracy_metrics.py new file mode 100644 index 000000000000..817d2a5ae33d --- /dev/null +++ b/keras/src/metrics/accuracy_metrics.py @@ -0,0 +1,522 @@ +from keras.src import backend +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.losses.loss import squeeze_or_expand_to_same_rank +from keras.src.metrics import reduction_metrics + + +def accuracy(y_true, y_pred): + y_pred = ops.convert_to_tensor(y_pred) + y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype) + y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred) + return ops.cast(ops.equal(y_true, y_pred), dtype=backend.floatx()) + + +@keras_export("keras.metrics.Accuracy") +class Accuracy(reduction_metrics.MeanMetricWrapper): + """Calculates how often predictions equal labels. + + This metric creates two local variables, `total` and `count` that are used + to compute the frequency with which `y_pred` matches `y_true`. This + frequency is ultimately returned as `binary accuracy`: an idempotent + operation that simply divides `total` by `count`. + + If `sample_weight` is `None`, weights default to 1. + Use `sample_weight` of 0 to mask values. + + Args: + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + + Examples: + + >>> m = keras.metrics.Accuracy() + >>> m.update_state([[1], [2], [3], [4]], [[0], [2], [3], [4]]) + >>> m.result() + 0.75 + + >>> m.reset_state() + >>> m.update_state([[1], [2], [3], [4]], [[0], [2], [3], [4]], + ... sample_weight=[1, 1, 0, 0]) + >>> m.result() + 0.5 + + Usage with `compile()` API: + + ```python + model.compile(optimizer='sgd', + loss='binary_crossentropy', + metrics=[keras.metrics.Accuracy()]) + ``` + """ + + def __init__(self, name="accuracy", dtype=None): + super().__init__(fn=accuracy, name=name, dtype=dtype) + # Metric should be maximized during optimization. + self._direction = "up" + + def get_config(self): + return {"name": self.name, "dtype": self.dtype} + + +@keras_export("keras.metrics.binary_accuracy") +def binary_accuracy(y_true, y_pred, threshold=0.5): + y_pred = ops.convert_to_tensor(y_pred) + y_true = ops.convert_to_tensor(y_true) + threshold = ops.convert_to_tensor(threshold) + y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred) + y_pred = ops.cast(ops.greater(y_pred, threshold), y_true.dtype) + return ops.cast(ops.equal(y_true, y_pred), dtype=backend.floatx()) + + +@keras_export("keras.metrics.BinaryAccuracy") +class BinaryAccuracy(reduction_metrics.MeanMetricWrapper): + """Calculates how often predictions match binary labels. + + This metric creates two local variables, `total` and `count` that are used + to compute the frequency with which `y_pred` matches `y_true`. This + frequency is ultimately returned as `binary accuracy`: an idempotent + operation that simply divides `total` by `count`. + + If `sample_weight` is `None`, weights default to 1. + Use `sample_weight` of 0 to mask values. + + Args: + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + threshold: (Optional) Float representing the threshold for deciding + whether prediction values are 1 or 0. + + Example: + + >>> m = keras.metrics.BinaryAccuracy() + >>> m.update_state([[1], [1], [0], [0]], [[0.98], [1], [0], [0.6]]) + >>> m.result() + 0.75 + + >>> m.reset_state() + >>> m.update_state([[1], [1], [0], [0]], [[0.98], [1], [0], [0.6]], + ... sample_weight=[1, 0, 0, 1]) + >>> m.result() + 0.5 + + Usage with `compile()` API: + + ```python + model.compile(optimizer='sgd', + loss='binary_crossentropy', + metrics=[keras.metrics.BinaryAccuracy()]) + ``` + """ + + def __init__(self, name="binary_accuracy", dtype=None, threshold=0.5): + super().__init__( + fn=binary_accuracy, name=name, dtype=dtype, threshold=threshold + ) + self.threshold = threshold + # Metric should be maximized during optimization. + self._direction = "up" + + def get_config(self): + return { + "name": self.name, + "dtype": self.dtype, + "threshold": self.threshold, + } + + +@keras_export("keras.metrics.categorical_accuracy") +def categorical_accuracy(y_true, y_pred): + y_true = ops.argmax(y_true, axis=-1) + + reshape_matches = False + y_pred = ops.convert_to_tensor(y_pred) + y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype) + + y_true_org_shape = ops.shape(y_true) + y_pred_rank = len(y_pred.shape) + y_true_rank = len(y_true.shape) + + # If the shape of y_true is (num_samples, 1), squeeze to (num_samples,) + if ( + (y_true_rank is not None) + and (y_pred_rank is not None) + and (len(y_true.shape) == len(y_pred.shape)) + ): + y_true = ops.squeeze(y_true, -1) + reshape_matches = True + y_pred = ops.argmax(y_pred, axis=-1) + + # If the predicted output and actual output types don't match, force cast + # them to match. + if y_pred.dtype is not y_true.dtype: + y_pred = ops.cast(y_pred, dtype=y_true.dtype) + matches = ops.cast(ops.equal(y_true, y_pred), backend.floatx()) + if reshape_matches: + matches = ops.reshape(matches, y_true_org_shape) + return matches + + +@keras_export("keras.metrics.CategoricalAccuracy") +class CategoricalAccuracy(reduction_metrics.MeanMetricWrapper): + """Calculates how often predictions match one-hot labels. + + You can provide logits of classes as `y_pred`, since argmax of + logits and probabilities are same. + + This metric creates two local variables, `total` and `count` that are used + to compute the frequency with which `y_pred` matches `y_true`. This + frequency is ultimately returned as `categorical accuracy`: an idempotent + operation that simply divides `total` by `count`. + + `y_pred` and `y_true` should be passed in as vectors of probabilities, + rather than as labels. If necessary, use `ops.one_hot` to expand `y_true` as + a vector. + + If `sample_weight` is `None`, weights default to 1. + Use `sample_weight` of 0 to mask values. + + Args: + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + + Example: + + >>> m = keras.metrics.CategoricalAccuracy() + >>> m.update_state([[0, 0, 1], [0, 1, 0]], [[0.1, 0.9, 0.8], + ... [0.05, 0.95, 0]]) + >>> m.result() + 0.5 + + >>> m.reset_state() + >>> m.update_state([[0, 0, 1], [0, 1, 0]], [[0.1, 0.9, 0.8], + ... [0.05, 0.95, 0]], + ... sample_weight=[0.7, 0.3]) + >>> m.result() + 0.3 + + Usage with `compile()` API: + + ```python + model.compile(optimizer='sgd', + loss='categorical_crossentropy', + metrics=[keras.metrics.CategoricalAccuracy()]) + ``` + """ + + def __init__(self, name="categorical_accuracy", dtype=None): + super().__init__(fn=categorical_accuracy, name=name, dtype=dtype) + # Metric should be maximized during optimization. + self._direction = "up" + + def get_config(self): + return {"name": self.name, "dtype": self.dtype} + + +@keras_export("keras.metrics.sparse_categorical_accuracy") +def sparse_categorical_accuracy(y_true, y_pred): + reshape_matches = False + y_pred = ops.convert_to_tensor(y_pred) + y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype) + y_true_org_shape = ops.shape(y_true) + y_pred_rank = len(y_pred.shape) + y_true_rank = len(y_true.shape) + + # If the shape of y_true is (num_samples, 1), squeeze to (num_samples,) + if ( + (y_true_rank is not None) + and (y_pred_rank is not None) + and (len(y_true.shape) == len(y_pred.shape)) + and ops.shape(y_true)[-1] == 1 + ): + y_true = ops.squeeze(y_true, -1) + reshape_matches = True + y_pred = ops.argmax(y_pred, axis=-1) + + # If the predicted output and actual output types don't match, force cast + # them to match. + if y_pred.dtype is not y_true.dtype: + y_pred = ops.cast(y_pred, y_true.dtype) + matches = ops.cast(ops.equal(y_true, y_pred), backend.floatx()) + if reshape_matches: + matches = ops.reshape(matches, y_true_org_shape) + # if shape is (num_samples, 1) squeeze + if len(matches.shape) > 1 and matches.shape[-1] == 1: + matches = ops.squeeze(matches, -1) + return matches + + +@keras_export("keras.metrics.SparseCategoricalAccuracy") +class SparseCategoricalAccuracy(reduction_metrics.MeanMetricWrapper): + """Calculates how often predictions match integer labels. + + ```python + acc = np.dot(sample_weight, np.equal(y_true, np.argmax(y_pred, axis=1)) + ``` + + You can provide logits of classes as `y_pred`, since argmax of + logits and probabilities are same. + + This metric creates two local variables, `total` and `count` that are used + to compute the frequency with which `y_pred` matches `y_true`. This + frequency is ultimately returned as `sparse categorical accuracy`: an + idempotent operation that simply divides `total` by `count`. + + If `sample_weight` is `None`, weights default to 1. + Use `sample_weight` of 0 to mask values. + + Args: + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + + Example: + + >>> m = keras.metrics.SparseCategoricalAccuracy() + >>> m.update_state([[2], [1]], [[0.1, 0.6, 0.3], [0.05, 0.95, 0]]) + >>> m.result() + 0.5 + + >>> m.reset_state() + >>> m.update_state([[2], [1]], [[0.1, 0.6, 0.3], [0.05, 0.95, 0]], + ... sample_weight=[0.7, 0.3]) + >>> m.result() + 0.3 + + Usage with `compile()` API: + + ```python + model.compile(optimizer='sgd', + loss='sparse_categorical_crossentropy', + metrics=[keras.metrics.SparseCategoricalAccuracy()]) + ``` + """ + + def __init__(self, name="sparse_categorical_accuracy", dtype=None): + super().__init__(fn=sparse_categorical_accuracy, name=name, dtype=dtype) + # Metric should be maximized during optimization. + self._direction = "up" + + def get_config(self): + return {"name": self.name, "dtype": self.dtype} + + +@keras_export("keras.metrics.top_k_categorical_accuracy") +def top_k_categorical_accuracy(y_true, y_pred, k=5): + reshape_matches = False + y_pred = ops.convert_to_tensor(y_pred) + y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype) + y_true = ops.argmax(y_true, axis=-1) + y_true_rank = len(y_true.shape) + y_pred_rank = len(y_pred.shape) + y_true_org_shape = ops.shape(y_true) + + # Flatten y_pred to (batch_size, num_samples) and y_true to (num_samples,) + if (y_true_rank is not None) and (y_pred_rank is not None): + if y_pred_rank > 2: + y_pred = ops.reshape(y_pred, [-1, y_pred.shape[-1]]) + if y_true_rank > 1: + reshape_matches = True + y_true = ops.reshape(y_true, [-1]) + + matches = ops.cast( + ops.in_top_k(ops.cast(y_true, "int32"), y_pred, k=k), + dtype=backend.floatx(), + ) + + # returned matches is expected to have same shape as y_true input + if reshape_matches: + matches = ops.reshape(matches, y_true_org_shape) + + return matches + + +@keras_export("keras.metrics.TopKCategoricalAccuracy") +class TopKCategoricalAccuracy(reduction_metrics.MeanMetricWrapper): + """Computes how often targets are in the top `K` predictions. + + Args: + k: (Optional) Number of top elements to look at for computing accuracy. + Defaults to `5`. + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + + Example: + + >>> m = keras.metrics.TopKCategoricalAccuracy(k=1) + >>> m.update_state([[0, 0, 1], [0, 1, 0]], + ... [[0.1, 0.9, 0.8], [0.05, 0.95, 0]]) + >>> m.result() + 0.5 + + >>> m.reset_state() + >>> m.update_state([[0, 0, 1], [0, 1, 0]], + ... [[0.1, 0.9, 0.8], [0.05, 0.95, 0]], + ... sample_weight=[0.7, 0.3]) + >>> m.result() + 0.3 + + Usage with `compile()` API: + + ```python + model.compile(optimizer='sgd', + loss='categorical_crossentropy', + metrics=[keras.metrics.TopKCategoricalAccuracy()]) + ``` + """ + + def __init__(self, k=5, name="top_k_categorical_accuracy", dtype=None): + super().__init__( + fn=top_k_categorical_accuracy, + name=name, + dtype=dtype, + k=k, + ) + self.k = k + # Metric should be maximized during optimization. + self._direction = "up" + + def get_config(self): + return {"name": self.name, "dtype": self.dtype, "k": self.k} + + +@keras_export("keras.metrics.sparse_top_k_categorical_accuracy") +def sparse_top_k_categorical_accuracy( + y_true, y_pred, k=5, from_sorted_ids=False +): + """Computes how often integer targets are in the top `K` predictions. + + Args: + y_true: A tensor of shape `(batch_size)` representing indices or IDs of + true categories. + y_pred: If `from_sorted_ids=False`, a tensor of shape + `(batch_size, num_categories)` containing the scores for each sample + for all possible categories. If `from_sorted_ids=True`, a tensor of + shape `(batch_size, N)` containing indices or IDs of the top `N` + categories in order from highest score to lowest score. + k: (Optional) Number of top elements to look at for computing accuracy. + Defaults to `5`. + from_sorted_ids: (Optional) Whether `y_pred` is sorted category IDs or + scores for all categories (the default). + + Returns: + A tensor with the same shape as `y_true` containing ones where `y_true` + is in the top `k` and zeros elsewhere. + """ + reshape_matches = False + y_pred = ops.convert_to_tensor(y_pred) + y_true_dtype = y_pred.dtype if from_sorted_ids else "int32" + y_true = ops.convert_to_tensor(y_true, dtype=y_true_dtype) + y_true_rank = len(y_true.shape) + y_pred_rank = len(y_pred.shape) + y_true_org_shape = ops.shape(y_true) + + # Flatten y_pred to (batch_size, num_samples) and y_true to (num_samples,) + if (y_true_rank is not None) and (y_pred_rank is not None): + if y_pred_rank > 2: + y_pred = ops.reshape(y_pred, [-1, y_pred.shape[-1]]) + if y_true_rank > 1: + reshape_matches = True + y_true = ops.reshape(y_true, [-1]) + + if from_sorted_ids: + # By slicing the first k items, we assume they are sorted by score. + # Reduce with `any` to count multiple matches only once. + matches = ops.any( + ops.equal(ops.expand_dims(y_true, axis=1), y_pred[:, :k]), axis=1 + ) + else: + matches = ops.in_top_k(y_true, y_pred, k=k) + + matches = ops.cast(matches, dtype=backend.floatx()) + + # returned matches is expected to have same shape as y_true input + if reshape_matches: + matches = ops.reshape(matches, y_true_org_shape) + + return matches + + +@keras_export("keras.metrics.SparseTopKCategoricalAccuracy") +class SparseTopKCategoricalAccuracy(reduction_metrics.MeanMetricWrapper): + """Computes how often integer targets are in the top `K` predictions. + + By default, the arguments expected by `update_state()` are: + - `y_true`: a tensor of shape `(batch_size)` representing indices of true + categories. + - `y_pred`: a tensor of shape `(batch_size, num_categories)` containing the + scores for each sample for all possible categories. + + With `from_sorted_ids=True`, the arguments expected by `update_state` are: + - `y_true`: a tensor of shape `(batch_size)` representing indices or IDs of + true categories. + - `y_pred`: a tensor of shape `(batch_size, N)` containing the indices or + IDs of the top `N` categories sorted in order from highest score to + lowest score. `N` must be greater or equal to `k`. + + The `from_sorted_ids=True` option can be more efficient when the set of + categories is very large and the model has an optimized way to retrieve the + top ones either without scoring or without maintaining the scores for all + the possible categories. + + Args: + k: (Optional) Number of top elements to look at for computing accuracy. + Defaults to `5`. + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + from_sorted_ids: (Optional) When `False`, the default, the tensor passed + in `y_pred` contains the unsorted scores of all possible categories. + When `True`, `y_pred` contains a the indices or IDs for the top + categories. + + Example: + + >>> m = keras.metrics.SparseTopKCategoricalAccuracy(k=1) + >>> m.update_state([2, 1], [[0.1, 0.9, 0.8], [0.05, 0.95, 0]]) + >>> m.result() + 0.5 + + >>> m.reset_state() + >>> m.update_state([2, 1], [[0.1, 0.9, 0.8], [0.05, 0.95, 0]], + ... sample_weight=[0.7, 0.3]) + >>> m.result() + 0.3 + + >>> m = keras.metrics.SparseTopKCategoricalAccuracy(k=1, + ... from_sorted_ids=True) + >>> m.update_state([2, 1], [[1, 0, 3], [1, 2, 3]]) + >>> m.result() + 0.5 + + Usage with `compile()` API: + + ```python + model.compile(optimizer='sgd', + loss='sparse_categorical_crossentropy', + metrics=[keras.metrics.SparseTopKCategoricalAccuracy()]) + ``` + """ + + def __init__( + self, + k=5, + name="sparse_top_k_categorical_accuracy", + dtype=None, + from_sorted_ids=False, + ): + super().__init__( + fn=sparse_top_k_categorical_accuracy, + name=name, + dtype=dtype, + k=k, + from_sorted_ids=from_sorted_ids, + ) + self.k = k + self.from_sorted_ids = from_sorted_ids + # Metric should be maximized during optimization. + self._direction = "up" + + def get_config(self): + config = {"name": self.name, "dtype": self.dtype, "k": self.k} + if self.from_sorted_ids: + config["from_sorted_ids"] = True + return config diff --git a/keras/src/metrics/accuracy_metrics_test.py b/keras/src/metrics/accuracy_metrics_test.py new file mode 100644 index 000000000000..74a48f276824 --- /dev/null +++ b/keras/src/metrics/accuracy_metrics_test.py @@ -0,0 +1,515 @@ +import numpy as np + +from keras.src import testing +from keras.src.metrics import accuracy_metrics + + +class AccuracyTest(testing.TestCase): + def test_config(self): + acc_obj = accuracy_metrics.Accuracy(name="accuracy", dtype="float32") + self.assertEqual(acc_obj.name, "accuracy") + self.assertEqual(len(acc_obj.variables), 2) + self.assertEqual(acc_obj._dtype, "float32") + + # Test get_config + acc_obj_config = acc_obj.get_config() + self.assertEqual(acc_obj_config["name"], "accuracy") + self.assertEqual(acc_obj_config["dtype"], "float32") + + # Check save and restore config + acc_obj2 = accuracy_metrics.Accuracy.from_config(acc_obj_config) + self.assertEqual(acc_obj2.name, "accuracy") + self.assertEqual(len(acc_obj2.variables), 2) + self.assertEqual(acc_obj2._dtype, "float32") + + def test_unweighted(self): + acc_obj = accuracy_metrics.Accuracy(name="accuracy", dtype="float32") + y_true = np.array([[1], [2], [3], [4]]) + y_pred = np.array([[0], [2], [3], [4]]) + acc_obj.update_state(y_true, y_pred) + result = acc_obj.result() + self.assertAllClose(result, 0.75, atol=1e-3) + + def test_weighted(self): + acc_obj = accuracy_metrics.Accuracy(name="accuracy", dtype="float32") + y_true = np.array([[1], [2], [3], [4]]) + y_pred = np.array([[0], [2], [3], [4]]) + sample_weight = np.array([1, 1, 0, 0]) + acc_obj.update_state(y_true, y_pred, sample_weight=sample_weight) + result = acc_obj.result() + self.assertAllClose(result, 0.5, atol=1e-3) + + def test_weighted_rank_1(self): + acc_obj = accuracy_metrics.Accuracy(name="accuracy", dtype="float32") + y_true = np.array([1, 2, 3, 4]) + y_pred = np.array([0, 2, 3, 4]) + sample_weight = np.array([1, 1, 0, 0]) + acc_obj.update_state(y_true, y_pred, sample_weight=sample_weight) + result = acc_obj.result() + self.assertAllClose(result, 0.5, atol=1e-3) + + def test_weighted_nd_weights(self): + acc_obj = accuracy_metrics.Accuracy(name="accuracy", dtype="float32") + y_true = np.array([[1, 2], [3, 4]]) + y_pred = np.array([[0, 2], [3, 4]]) + sample_weight = np.array([[1, 0], [0, 1]]) + acc_obj.update_state(y_true, y_pred, sample_weight=sample_weight) + result = acc_obj.result() + self.assertAllClose(result, 0.5, atol=1e-3) + + def test_weighted_nd_broadcast_weights(self): + acc_obj = accuracy_metrics.Accuracy(name="accuracy", dtype="float32") + y_true = np.array([[1, 2], [3, 4]]) + y_pred = np.array([[0, 2], [3, 4]]) + sample_weight = np.array([[1, 0]]) + acc_obj.update_state(y_true, y_pred, sample_weight=sample_weight) + result = acc_obj.result() + self.assertAllClose(result, 0.5, atol=1e-3) + + +class BinaryAccuracyTest(testing.TestCase): + def test_config(self): + bin_acc_obj = accuracy_metrics.BinaryAccuracy( + name="binary_accuracy", dtype="float32" + ) + self.assertEqual(bin_acc_obj.name, "binary_accuracy") + self.assertEqual(len(bin_acc_obj.variables), 2) + self.assertEqual(bin_acc_obj._dtype, "float32") + + # Test get_config + bin_acc_obj_config = bin_acc_obj.get_config() + self.assertEqual(bin_acc_obj_config["name"], "binary_accuracy") + self.assertEqual(bin_acc_obj_config["dtype"], "float32") + + # Check save and restore config + bin_acc_obj2 = accuracy_metrics.BinaryAccuracy.from_config( + bin_acc_obj_config + ) + self.assertEqual(bin_acc_obj2.name, "binary_accuracy") + self.assertEqual(len(bin_acc_obj2.variables), 2) + self.assertEqual(bin_acc_obj2._dtype, "float32") + + def test_unweighted(self): + bin_acc_obj = accuracy_metrics.BinaryAccuracy( + name="binary_accuracy", dtype="float32" + ) + y_true = np.array([[1], [1], [0], [0]]) + y_pred = np.array([[0.98], [1], [0], [0.6]]) + bin_acc_obj.update_state(y_true, y_pred) + result = bin_acc_obj.result() + self.assertAllClose(result, 0.75, atol=1e-3) + + # Test broadcasting case + bin_acc_obj = accuracy_metrics.BinaryAccuracy( + name="binary_accuracy", dtype="float32" + ) + y_true = np.array([1, 1, 0, 0]) + y_pred = np.array([[0.98], [1], [0], [0.6]]) + bin_acc_obj.update_state(y_true, y_pred) + result = bin_acc_obj.result() + self.assertAllClose(result, 0.75, atol=1e-3) + + def test_weighted(self): + bin_acc_obj = accuracy_metrics.BinaryAccuracy( + name="binary_accuracy", dtype="float32" + ) + y_true = np.array([[1], [1], [0], [0]]) + y_pred = np.array([[0.98], [1], [0], [0.6]]) + sample_weight = np.array([1, 0, 0, 1]) + bin_acc_obj.update_state(y_true, y_pred, sample_weight=sample_weight) + result = bin_acc_obj.result() + self.assertAllClose(result, 0.5, atol=1e-3) + + def test_weighted_rank_1(self): + bin_acc_obj = accuracy_metrics.BinaryAccuracy( + name="binary_accuracy", dtype="float32" + ) + y_true = np.array([1, 1, 0, 0]) + y_pred = np.array([0.98, 1, 0, 0.6]) + sample_weight = np.array([1, 0, 0, 1]) + bin_acc_obj.update_state(y_true, y_pred, sample_weight=sample_weight) + result = bin_acc_obj.result() + self.assertAllClose(result, 0.5, atol=1e-3) + + def test_weighted_nd_weights(self): + bin_acc_obj = accuracy_metrics.BinaryAccuracy( + name="binary_accuracy", dtype="float32" + ) + y_true = np.array([[1, 1], [0, 0]]) + y_pred = np.array([[0.98, 1], [0, 0.6]]) + sample_weight = np.array([[1, 0], [0, 1]]) + bin_acc_obj.update_state(y_true, y_pred, sample_weight=sample_weight) + result = bin_acc_obj.result() + self.assertAllClose(result, 0.5, atol=1e-3) + + def test_weighted_nd_broadcast_weights(self): + bin_acc_obj = accuracy_metrics.BinaryAccuracy( + name="binary_accuracy", dtype="float32" + ) + y_true = np.array([[1, 1], [0, 0]]) + y_pred = np.array([[0.98, 1], [0, 0.6]]) + sample_weight = np.array([[1, 0]]) + bin_acc_obj.update_state(y_true, y_pred, sample_weight=sample_weight) + result = bin_acc_obj.result() + self.assertAllClose(result, 1.0, atol=1e-3) + + def test_threshold(self): + bin_acc_obj_1 = accuracy_metrics.BinaryAccuracy( + name="binary_accuracy", dtype="float32", threshold=0.3 + ) + bin_acc_obj_2 = accuracy_metrics.BinaryAccuracy( + name="binary_accuracy", dtype="float32", threshold=0.9 + ) + y_true = np.array([[1], [1], [0], [0]]) + y_pred = np.array([[0.98], [0.5], [0.1], [0.2]]) + + bin_acc_obj_1.update_state(y_true, y_pred) + bin_acc_obj_2.update_state(y_true, y_pred) + result_1 = bin_acc_obj_1.result() + result_2 = bin_acc_obj_2.result() + + # Higher threshold must result in lower measured accuracy. + self.assertAllClose(result_1, 1.0) + self.assertAllClose(result_2, 0.75) + + +class CategoricalAccuracyTest(testing.TestCase): + def test_config(self): + cat_acc_obj = accuracy_metrics.CategoricalAccuracy( + name="categorical_accuracy", dtype="float32" + ) + self.assertEqual(cat_acc_obj.name, "categorical_accuracy") + self.assertEqual(len(cat_acc_obj.variables), 2) + self.assertEqual(cat_acc_obj._dtype, "float32") + + # Test get_config + cat_acc_obj_config = cat_acc_obj.get_config() + self.assertEqual(cat_acc_obj_config["name"], "categorical_accuracy") + self.assertEqual(cat_acc_obj_config["dtype"], "float32") + + # Check save and restore config + cat_acc_obj2 = accuracy_metrics.CategoricalAccuracy.from_config( + cat_acc_obj_config + ) + self.assertEqual(cat_acc_obj2.name, "categorical_accuracy") + self.assertEqual(len(cat_acc_obj2.variables), 2) + self.assertEqual(cat_acc_obj2._dtype, "float32") + + def test_unweighted(self): + cat_acc_obj = accuracy_metrics.CategoricalAccuracy( + name="categorical_accuracy", dtype="float32" + ) + y_true = np.array([[0, 0, 1], [0, 1, 0]]) + y_pred = np.array([[0.1, 0.9, 0.8], [0.05, 0.95, 0]]) + cat_acc_obj.update_state(y_true, y_pred) + result = cat_acc_obj.result() + self.assertAllClose(result, 0.5, atol=1e-3) + + def test_weighted(self): + cat_acc_obj = accuracy_metrics.CategoricalAccuracy( + name="categorical_accuracy", dtype="float32" + ) + y_true = np.array([[0, 0, 1], [0, 1, 0]]) + y_pred = np.array([[0.1, 0.9, 0.8], [0.05, 0.95, 0]]) + sample_weight = np.array([0.7, 0.3]) + cat_acc_obj.update_state(y_true, y_pred, sample_weight=sample_weight) + result = cat_acc_obj.result() + self.assertAllClose(result, 0.3, atol=1e-3) + + +class SparseCategoricalAccuracyTest(testing.TestCase): + def test_config(self): + sp_cat_acc_obj = accuracy_metrics.SparseCategoricalAccuracy( + name="sparse_categorical_accuracy", dtype="float32" + ) + self.assertEqual(sp_cat_acc_obj.name, "sparse_categorical_accuracy") + self.assertEqual(len(sp_cat_acc_obj.variables), 2) + self.assertEqual(sp_cat_acc_obj._dtype, "float32") + + # Test get_config + sp_cat_acc_obj_config = sp_cat_acc_obj.get_config() + self.assertEqual( + sp_cat_acc_obj_config["name"], "sparse_categorical_accuracy" + ) + self.assertEqual(sp_cat_acc_obj_config["dtype"], "float32") + + # Check save and restore config + sp_cat_acc_obj2 = ( + accuracy_metrics.SparseCategoricalAccuracy.from_config( + sp_cat_acc_obj_config + ) + ) + self.assertEqual(sp_cat_acc_obj2.name, "sparse_categorical_accuracy") + self.assertEqual(len(sp_cat_acc_obj2.variables), 2) + self.assertEqual(sp_cat_acc_obj2._dtype, "float32") + + def test_unweighted(self): + sp_cat_acc_obj = accuracy_metrics.SparseCategoricalAccuracy( + name="sparse_categorical_accuracy", dtype="float32" + ) + y_true = np.array([[2], [1]]) + y_pred = np.array([[0.1, 0.6, 0.3], [0.05, 0.95, 0]]) + sp_cat_acc_obj.update_state(y_true, y_pred) + result = sp_cat_acc_obj.result() + self.assertAllClose(result, 0.5, atol=1e-3) + + def test_weighted(self): + sp_cat_acc_obj = accuracy_metrics.SparseCategoricalAccuracy( + name="sparse_categorical_accuracy", dtype="float32" + ) + y_true = np.array([[2], [1]]) + y_pred = np.array([[0.1, 0.6, 0.3], [0.05, 0.95, 0]]) + sample_weight = np.array([0.7, 0.3]) + sp_cat_acc_obj.update_state(y_true, y_pred, sample_weight=sample_weight) + result = sp_cat_acc_obj.result() + self.assertAllClose(result, 0.3, atol=1e-3) + + def test_squeeze_y_true(self): + sp_cat_acc_obj = accuracy_metrics.SparseCategoricalAccuracy( + name="sparse_categorical_accuracy", dtype="float32" + ) + # Scenario with 100% accuracy for simplicity. + # y_true is a 2D tensor with shape (3, 1) to test squeeze. + y_true = np.array([[0], [1], [2]]) + y_pred = np.array( + [[0.9, 0.05, 0.05], [0.05, 0.9, 0.05], [0.05, 0.05, 0.9]] + ) + sp_cat_acc_obj.update_state(y_true, y_pred) + result = sp_cat_acc_obj.result() + self.assertAllClose(result, 1.0, atol=1e-4) + + def test_cast_y_pred_dtype(self): + sp_cat_acc_obj = accuracy_metrics.SparseCategoricalAccuracy( + name="sparse_categorical_accuracy", dtype="float32" + ) + # Scenario with 100% accuracy for simplicity. + # y_true is a 1D tensor with shape (2,) to test cast. + y_true = np.array([0, 1], dtype=np.int64) + y_pred = np.array([[0.9, 0.1], [0.1, 0.9]], dtype=np.float32) + sp_cat_acc_obj.update_state(y_true, y_pred) + result = sp_cat_acc_obj.result() + self.assertAllClose(result, 1.0, atol=1e-4) + + def test_reshape_matches(self): + sp_cat_acc_obj = accuracy_metrics.SparseCategoricalAccuracy( + name="sparse_categorical_accuracy", dtype="float32" + ) + # Scenario with 100% accuracy for simplicity. + # y_true is a 2D tensor with shape (2, 1) to test reshape. + y_true = np.array([[0], [0]], dtype=np.int64) + y_pred = np.array( + [[[0.9, 0.1, 0.0], [0.8, 0.15, 0.05]]], dtype=np.float32 + ) + sp_cat_acc_obj.update_state(y_true, y_pred) + result = sp_cat_acc_obj.result() + self.assertAllClose(result, np.array([1.0, 1.0])) + + def test_squeeze_y_true_shape(self): + sp_cat_acc_obj = accuracy_metrics.SparseCategoricalAccuracy( + name="sparse_categorical_accuracy", dtype="float32" + ) + # True labels are in the shape (num_samples, 1) should be squeezed. + y_true = np.array([[0], [1], [2]]) + y_pred = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) + sp_cat_acc_obj.update_state(y_true, y_pred) + result = sp_cat_acc_obj.result() + self.assertAllClose(result, 1.0, atol=1e-4) + + def test_cast_y_pred_to_match_y_true_dtype(self): + sp_cat_acc_obj = accuracy_metrics.SparseCategoricalAccuracy( + name="sparse_categorical_accuracy", dtype="float32" + ) + # True labels are integers, while predictions are floats. + y_true = np.array([0, 1, 2], dtype=np.int32) + y_pred = np.array( + [[0.9, 0.1, 0.0], [0.0, 0.9, 0.1], [0.1, 0.0, 0.9]], + dtype=np.float64, + ) + sp_cat_acc_obj.update_state(y_true, y_pred) + result = sp_cat_acc_obj.result() + self.assertAllClose(result, 1.0, atol=1e-4) + + def test_reshape_matches_to_original_y_true_shape(self): + sp_cat_acc_obj = accuracy_metrics.SparseCategoricalAccuracy( + name="sparse_categorical_accuracy", dtype="float32" + ) + # True labels have an additional dimension that needs to be squeezed. + y_true = np.array([[0], [1]]) + # Predictions must trigger a reshape of matches. + y_pred = np.array([[0.9, 0.1], [0.1, 0.9]]) + sp_cat_acc_obj.update_state(y_true, y_pred) + result = sp_cat_acc_obj.result() + self.assertAllClose(result, 1.0, atol=1e-4) + + def test_matching_shapes_without_squeeze(self): + sp_cat_acc_obj = accuracy_metrics.SparseCategoricalAccuracy( + name="sparse_categorical_accuracy", dtype="float32" + ) + y_true = np.array([2, 1, 0], dtype=np.int32) + y_pred = np.array( + [[0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0]], + dtype=np.float32, + ) + # No need to squeeze or reshape. + sp_cat_acc_obj.update_state(y_true, y_pred) + result = sp_cat_acc_obj.result() + self.assertAllClose(result, 1.0, atol=1e-4) + + +class TopKCategoricalAccuracyTest(testing.TestCase): + def test_config(self): + top_k_cat_acc_obj = accuracy_metrics.TopKCategoricalAccuracy( + k=1, name="top_k_categorical_accuracy", dtype="float32" + ) + self.assertEqual(top_k_cat_acc_obj.name, "top_k_categorical_accuracy") + self.assertEqual(len(top_k_cat_acc_obj.variables), 2) + self.assertEqual(top_k_cat_acc_obj._dtype, "float32") + + # Test get_config + top_k_cat_acc_obj_config = top_k_cat_acc_obj.get_config() + self.assertEqual( + top_k_cat_acc_obj_config["name"], "top_k_categorical_accuracy" + ) + self.assertEqual(top_k_cat_acc_obj_config["dtype"], "float32") + self.assertEqual(top_k_cat_acc_obj_config["k"], 1) + + # Check save and restore config + top_k_cat_acc_obj2 = ( + accuracy_metrics.TopKCategoricalAccuracy.from_config( + top_k_cat_acc_obj_config + ) + ) + self.assertEqual(top_k_cat_acc_obj2.name, "top_k_categorical_accuracy") + self.assertEqual(len(top_k_cat_acc_obj2.variables), 2) + self.assertEqual(top_k_cat_acc_obj2._dtype, "float32") + self.assertEqual(top_k_cat_acc_obj2.k, 1) + + def test_unweighted(self): + top_k_cat_acc_obj = accuracy_metrics.TopKCategoricalAccuracy( + k=1, name="top_k_categorical_accuracy", dtype="float32" + ) + y_true = np.array([[0, 0, 1], [0, 1, 0]]) + y_pred = np.array([[0.1, 0.9, 0.8], [0.05, 0.95, 0]], dtype="float32") + top_k_cat_acc_obj.update_state(y_true, y_pred) + result = top_k_cat_acc_obj.result() + self.assertAllClose(result, 0.5, atol=1e-3) + + def test_weighted(self): + top_k_cat_acc_obj = accuracy_metrics.TopKCategoricalAccuracy( + k=1, name="top_k_categorical_accuracy", dtype="float32" + ) + y_true = np.array([[0, 0, 1], [0, 1, 0]]) + y_pred = np.array([[0.1, 0.9, 0.8], [0.05, 0.95, 0]], dtype="float32") + sample_weight = np.array([0.7, 0.3]) + top_k_cat_acc_obj.update_state( + y_true, y_pred, sample_weight=sample_weight + ) + result = top_k_cat_acc_obj.result() + self.assertAllClose(result, 0.3, atol=1e-3) + + +class SparseTopKCategoricalAccuracyTest(testing.TestCase): + def test_config(self): + sp_top_k_cat_acc_obj = accuracy_metrics.SparseTopKCategoricalAccuracy( + k=1, name="sparse_top_k_categorical_accuracy", dtype="float32" + ) + self.assertEqual( + sp_top_k_cat_acc_obj.name, "sparse_top_k_categorical_accuracy" + ) + self.assertEqual(len(sp_top_k_cat_acc_obj.variables), 2) + self.assertEqual(sp_top_k_cat_acc_obj._dtype, "float32") + + # Test get_config + sp_top_k_cat_acc_obj_config = sp_top_k_cat_acc_obj.get_config() + self.assertEqual( + sp_top_k_cat_acc_obj_config["name"], + "sparse_top_k_categorical_accuracy", + ) + self.assertEqual(sp_top_k_cat_acc_obj_config["dtype"], "float32") + self.assertEqual(sp_top_k_cat_acc_obj_config["k"], 1) + + # Check save and restore config + sp_top_k_cat_acc_obj2 = ( + accuracy_metrics.SparseTopKCategoricalAccuracy.from_config( + sp_top_k_cat_acc_obj_config + ) + ) + self.assertEqual( + sp_top_k_cat_acc_obj2.name, "sparse_top_k_categorical_accuracy" + ) + self.assertEqual(len(sp_top_k_cat_acc_obj2.variables), 2) + self.assertEqual(sp_top_k_cat_acc_obj2._dtype, "float32") + self.assertEqual(sp_top_k_cat_acc_obj2.k, 1) + self.assertFalse(sp_top_k_cat_acc_obj2.from_sorted_ids) + + def test_config_from_sorted_ids(self): + sp_top_k_cat_acc_obj = accuracy_metrics.SparseTopKCategoricalAccuracy( + k=1, + name="sparse_top_k_categorical_accuracy", + dtype="float32", + from_sorted_ids=True, + ) + + # Test get_config + sp_top_k_cat_acc_obj_config = sp_top_k_cat_acc_obj.get_config() + self.assertTrue(sp_top_k_cat_acc_obj_config["from_sorted_ids"]) + + # Check save and restore config + sp_top_k_cat_acc_obj2 = ( + accuracy_metrics.SparseTopKCategoricalAccuracy.from_config( + sp_top_k_cat_acc_obj_config + ) + ) + self.assertTrue(sp_top_k_cat_acc_obj2.from_sorted_ids) + + def test_unweighted(self): + sp_top_k_cat_acc_obj = accuracy_metrics.SparseTopKCategoricalAccuracy( + k=1, name="sparse_top_k_categorical_accuracy", dtype="float32" + ) + y_true = np.array([2, 1]) + y_pred = np.array([[0.1, 0.9, 0.8], [0.05, 0.95, 0]], dtype="float32") + sp_top_k_cat_acc_obj.update_state(y_true, y_pred) + result = sp_top_k_cat_acc_obj.result() + self.assertAllClose(result, 0.5, atol=1e-3) + + def test_weighted(self): + sp_top_k_cat_acc_obj = accuracy_metrics.SparseTopKCategoricalAccuracy( + k=1, name="sparse_top_k_categorical_accuracy", dtype="float32" + ) + y_true = np.array([2, 1]) + y_pred = np.array([[0.1, 0.9, 0.8], [0.05, 0.95, 0]], dtype="float32") + sample_weight = np.array([0.7, 0.3]) + sp_top_k_cat_acc_obj.update_state( + y_true, y_pred, sample_weight=sample_weight + ) + result = sp_top_k_cat_acc_obj.result() + self.assertAllClose(result, 0.3, atol=1e-3) + + def test_from_sorted_ids_unweighted(self): + sp_top_k_cat_acc_obj = accuracy_metrics.SparseTopKCategoricalAccuracy( + k=1, + name="sparse_top_k_categorical_accuracy", + dtype="float32", + from_sorted_ids=True, + ) + y_true = np.array([2, 1]) + y_pred = np.array([[1, 0, 3], [1, 2, 3]]) + sp_top_k_cat_acc_obj.update_state(y_true, y_pred) + result = sp_top_k_cat_acc_obj.result() + self.assertAllClose(result, 0.5, atol=1e-3) + + def test_from_sorted_ids_weighted(self): + sp_top_k_cat_acc_obj = accuracy_metrics.SparseTopKCategoricalAccuracy( + k=1, + name="sparse_top_k_categorical_accuracy", + dtype="float32", + from_sorted_ids=True, + ) + y_true = np.array([2, 1]) + y_pred = np.array([[1, 0, 3], [1, 2, 3]]) + sample_weight = np.array([0.7, 0.3]) + sp_top_k_cat_acc_obj.update_state( + y_true, y_pred, sample_weight=sample_weight + ) + result = sp_top_k_cat_acc_obj.result() + self.assertAllClose(result, 0.3, atol=1e-3) diff --git a/keras/src/metrics/confusion_metrics.py b/keras/src/metrics/confusion_metrics.py new file mode 100644 index 000000000000..b03dbdb29352 --- /dev/null +++ b/keras/src/metrics/confusion_metrics.py @@ -0,0 +1,1597 @@ +import numpy as np + +from keras.src import activations +from keras.src import backend +from keras.src import initializers +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.metrics import metrics_utils +from keras.src.metrics.metric import Metric +from keras.src.utils.python_utils import to_list + + +class _ConfusionMatrixConditionCount(Metric): + """Calculates the number of the given confusion matrix condition. + + Args: + confusion_matrix_cond: One of `metrics_utils.ConfusionMatrix` + conditions. + thresholds: (Optional) Defaults to `0.5`. A float value or a python list + / tuple of float threshold values in `[0, 1]`. A threshold is + compared with prediction values to determine the truth value of + predictions (i.e., above the threshold is `True`, below is `False`). + One metric value is generated for each threshold value. + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + """ + + def __init__( + self, confusion_matrix_cond, thresholds=None, name=None, dtype=None + ): + super().__init__(name=name, dtype=dtype) + self._confusion_matrix_cond = confusion_matrix_cond + self.init_thresholds = thresholds + self.thresholds = metrics_utils.parse_init_thresholds( + thresholds, default_threshold=0.5 + ) + self._thresholds_distributed_evenly = ( + metrics_utils.is_evenly_distributed_thresholds(self.thresholds) + ) + self.accumulator = self.add_variable( + shape=(len(self.thresholds),), + initializer=initializers.Zeros(), + name="accumulator", + ) + + def update_state(self, y_true, y_pred, sample_weight=None): + """Accumulates the metric statistics. + + Args: + y_true: The ground truth values. + y_pred: The predicted values. + sample_weight: Optional weighting of each example. Defaults to `1`. + Can be a tensor whose rank is either 0, or the same rank as + `y_true`, and must be broadcastable to `y_true`. + """ + return metrics_utils.update_confusion_matrix_variables( + {self._confusion_matrix_cond: self.accumulator}, + y_true, + y_pred, + thresholds=self.thresholds, + thresholds_distributed_evenly=self._thresholds_distributed_evenly, + sample_weight=sample_weight, + ) + + def result(self): + if len(self.thresholds) == 1: + result = self.accumulator[0] + else: + result = self.accumulator + return backend.convert_to_tensor(result) + + def get_config(self): + config = {"thresholds": self.init_thresholds} + base_config = super().get_config() + return {**base_config, **config} + + +@keras_export("keras.metrics.FalsePositives") +class FalsePositives(_ConfusionMatrixConditionCount): + """Calculates the number of false positives. + + If `sample_weight` is given, calculates the sum of the weights of + false positives. This metric creates one local variable, `accumulator` + that is used to keep track of the number of false positives. + + If `sample_weight` is `None`, weights default to 1. + Use `sample_weight` of 0 to mask values. + + Args: + thresholds: (Optional) Defaults to `0.5`. A float value, or a Python + list/tuple of float threshold values in `[0, 1]`. A threshold is + compared with prediction values to determine the truth value of + predictions (i.e., above the threshold is `True`, below is `False`). + If used with a loss function that sets `from_logits=True` (i.e. no + sigmoid applied to predictions), `thresholds` should be set to 0. + One metric value is generated for each threshold value. + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + + Examples: + + >>> m = keras.metrics.FalsePositives() + >>> m.update_state([0, 1, 0, 0], [0, 0, 1, 1]) + >>> m.result() + 2.0 + + >>> m.reset_state() + >>> m.update_state([0, 1, 0, 0], [0, 0, 1, 1], sample_weight=[0, 0, 1, 0]) + >>> m.result() + 1.0 + """ + + def __init__(self, thresholds=None, name=None, dtype=None): + super().__init__( + confusion_matrix_cond=metrics_utils.ConfusionMatrix.FALSE_POSITIVES, + thresholds=thresholds, + name=name, + dtype=dtype, + ) + + +@keras_export("keras.metrics.FalseNegatives") +class FalseNegatives(_ConfusionMatrixConditionCount): + """Calculates the number of false negatives. + + If `sample_weight` is given, calculates the sum of the weights of + false negatives. This metric creates one local variable, `accumulator` + that is used to keep track of the number of false negatives. + + If `sample_weight` is `None`, weights default to 1. + Use `sample_weight` of 0 to mask values. + + Args: + thresholds: (Optional) Defaults to `0.5`. A float value, or a Python + list/tuple of float threshold values in `[0, 1]`. A threshold is + compared with prediction values to determine the truth value of + predictions (i.e., above the threshold is `True`, below is `False`). + If used with a loss function that sets `from_logits=True` (i.e. no + sigmoid applied to predictions), `thresholds` should be set to 0. + One metric value is generated for each threshold value. + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + + Example: + + >>> m = keras.metrics.FalseNegatives() + >>> m.update_state([0, 1, 1, 1], [0, 1, 0, 0]) + >>> m.result() + 2.0 + + >>> m.reset_state() + >>> m.update_state([0, 1, 1, 1], [0, 1, 0, 0], sample_weight=[0, 0, 1, 0]) + >>> m.result() + 1.0 + """ + + def __init__(self, thresholds=None, name=None, dtype=None): + super().__init__( + confusion_matrix_cond=metrics_utils.ConfusionMatrix.FALSE_NEGATIVES, + thresholds=thresholds, + name=name, + dtype=dtype, + ) + + +@keras_export("keras.metrics.TrueNegatives") +class TrueNegatives(_ConfusionMatrixConditionCount): + """Calculates the number of true negatives. + + If `sample_weight` is given, calculates the sum of the weights of + true negatives. This metric creates one local variable, `accumulator` + that is used to keep track of the number of true negatives. + + If `sample_weight` is `None`, weights default to 1. + Use `sample_weight` of 0 to mask values. + + Args: + thresholds: (Optional) Defaults to `0.5`. A float value, or a Python + list/tuple of float threshold values in `[0, 1]`. A threshold is + compared with prediction values to determine the truth value of + predictions (i.e., above the threshold is `True`, below is `False`). + If used with a loss function that sets `from_logits=True` (i.e. no + sigmoid applied to predictions), `thresholds` should be set to 0. + One metric value is generated for each threshold value. + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + + Example: + + >>> m = keras.metrics.TrueNegatives() + >>> m.update_state([0, 1, 0, 0], [1, 1, 0, 0]) + >>> m.result() + 2.0 + + >>> m.reset_state() + >>> m.update_state([0, 1, 0, 0], [1, 1, 0, 0], sample_weight=[0, 0, 1, 0]) + >>> m.result() + 1.0 + """ + + def __init__(self, thresholds=None, name=None, dtype=None): + super().__init__( + confusion_matrix_cond=metrics_utils.ConfusionMatrix.TRUE_NEGATIVES, + thresholds=thresholds, + name=name, + dtype=dtype, + ) + + +@keras_export("keras.metrics.TruePositives") +class TruePositives(_ConfusionMatrixConditionCount): + """Calculates the number of true positives. + + If `sample_weight` is given, calculates the sum of the weights of + true positives. This metric creates one local variable, `true_positives` + that is used to keep track of the number of true positives. + + If `sample_weight` is `None`, weights default to 1. + Use `sample_weight` of 0 to mask values. + + Args: + thresholds: (Optional) Defaults to `0.5`. A float value, or a Python + list/tuple of float threshold values in `[0, 1]`. A threshold is + compared with prediction values to determine the truth value of + predictions (i.e., above the threshold is `True`, below is `False`). + If used with a loss function that sets `from_logits=True` (i.e. no + sigmoid applied to predictions), `thresholds` should be set to 0. + One metric value is generated for each threshold value. + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + + Example: + + >>> m = keras.metrics.TruePositives() + >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1]) + >>> m.result() + 2.0 + + >>> m.reset_state() + >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1], sample_weight=[0, 0, 1, 0]) + >>> m.result() + 1.0 + """ + + def __init__(self, thresholds=None, name=None, dtype=None): + super().__init__( + confusion_matrix_cond=metrics_utils.ConfusionMatrix.TRUE_POSITIVES, + thresholds=thresholds, + name=name, + dtype=dtype, + ) + + +@keras_export("keras.metrics.Precision") +class Precision(Metric): + """Computes the precision of the predictions with respect to the labels. + + The metric creates two local variables, `true_positives` and + `false_positives` that are used to compute the precision. This value is + ultimately returned as `precision`, an idempotent operation that simply + divides `true_positives` by the sum of `true_positives` and + `false_positives`. + + If `sample_weight` is `None`, weights default to 1. + Use `sample_weight` of 0 to mask values. + + If `top_k` is set, we'll calculate precision as how often on average a class + among the top-k classes with the highest predicted values of a batch entry + is correct and can be found in the label for that entry. + + If `class_id` is specified, we calculate precision by considering only the + entries in the batch for which `class_id` is above the threshold and/or in + the top-k highest predictions, and computing the fraction of them for which + `class_id` is indeed a correct label. + + Args: + thresholds: (Optional) A float value, or a Python list/tuple of float + threshold values in `[0, 1]`. A threshold is compared with + prediction values to determine the truth value of predictions (i.e., + above the threshold is `True`, below is `False`). If used with a + loss function that sets `from_logits=True` (i.e. no sigmoid applied + to predictions), `thresholds` should be set to 0. One metric value + is generated for each threshold value. If neither `thresholds` nor + `top_k` are set, the default is to calculate precision with + `thresholds=0.5`. + top_k: (Optional) Unset by default. An int value specifying the top-k + predictions to consider when calculating precision. + class_id: (Optional) Integer class ID for which we want binary metrics. + This must be in the half-open interval `[0, num_classes)`, where + `num_classes` is the last dimension of predictions. + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + + Example: + + >>> m = keras.metrics.Precision() + >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1]) + >>> m.result() + 0.6666667 + + >>> m.reset_state() + >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1], sample_weight=[0, 0, 1, 0]) + >>> m.result() + 1.0 + + >>> # With top_k=2, it will calculate precision over y_true[:2] + >>> # and y_pred[:2] + >>> m = keras.metrics.Precision(top_k=2) + >>> m.update_state([0, 0, 1, 1], [1, 1, 1, 1]) + >>> m.result() + 0.0 + + >>> # With top_k=4, it will calculate precision over y_true[:4] + >>> # and y_pred[:4] + >>> m = keras.metrics.Precision(top_k=4) + >>> m.update_state([0, 0, 1, 1], [1, 1, 1, 1]) + >>> m.result() + 0.5 + + Usage with `compile()` API: + + ```python + model.compile(optimizer='sgd', + loss='binary_crossentropy', + metrics=[keras.metrics.Precision()]) + ``` + + Usage with a loss with `from_logits=True`: + + ```python + model.compile(optimizer='adam', + loss=keras.losses.BinaryCrossentropy(from_logits=True), + metrics=[keras.metrics.Precision(thresholds=0)]) + ``` + """ + + def __init__( + self, thresholds=None, top_k=None, class_id=None, name=None, dtype=None + ): + super().__init__(name=name, dtype=dtype) + # Metric should be maximized during optimization. + self._direction = "up" + + self.init_thresholds = thresholds + self.top_k = top_k + self.class_id = class_id + + default_threshold = 0.5 if top_k is None else metrics_utils.NEG_INF + self.thresholds = metrics_utils.parse_init_thresholds( + thresholds, default_threshold=default_threshold + ) + self._thresholds_distributed_evenly = ( + metrics_utils.is_evenly_distributed_thresholds(self.thresholds) + ) + self.true_positives = self.add_variable( + shape=(len(self.thresholds),), + initializer=initializers.Zeros(), + name="true_positives", + ) + self.false_positives = self.add_variable( + shape=(len(self.thresholds),), + initializer=initializers.Zeros(), + name="false_positives", + ) + + def update_state(self, y_true, y_pred, sample_weight=None): + """Accumulates true positive and false positive statistics. + + Args: + y_true: The ground truth values, with the same dimensions as + `y_pred`. Will be cast to `bool`. + y_pred: The predicted values. Each element must be in the range + `[0, 1]`. + sample_weight: Optional weighting of each example. Defaults to `1`. + Can be a tensor whose rank is either 0, or the same rank as + `y_true`, and must be broadcastable to `y_true`. + """ + metrics_utils.update_confusion_matrix_variables( + { + metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives, # noqa: E501 + metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives, # noqa: E501 + }, + y_true, + y_pred, + thresholds=self.thresholds, + thresholds_distributed_evenly=self._thresholds_distributed_evenly, + top_k=self.top_k, + class_id=self.class_id, + sample_weight=sample_weight, + ) + + def result(self): + result = ops.divide_no_nan( + self.true_positives, + ops.add(self.true_positives, self.false_positives), + ) + return result[0] if len(self.thresholds) == 1 else result + + def reset_state(self): + num_thresholds = len(to_list(self.thresholds)) + self.true_positives.assign(ops.zeros((num_thresholds,))) + self.false_positives.assign(ops.zeros((num_thresholds,))) + + def get_config(self): + config = { + "thresholds": self.init_thresholds, + "top_k": self.top_k, + "class_id": self.class_id, + } + base_config = super().get_config() + return {**base_config, **config} + + +@keras_export("keras.metrics.Recall") +class Recall(Metric): + """Computes the recall of the predictions with respect to the labels. + + This metric creates two local variables, `true_positives` and + `false_negatives`, that are used to compute the recall. This value is + ultimately returned as `recall`, an idempotent operation that simply divides + `true_positives` by the sum of `true_positives` and `false_negatives`. + + If `sample_weight` is `None`, weights default to 1. + Use `sample_weight` of 0 to mask values. + + If `top_k` is set, recall will be computed as how often on average a class + among the labels of a batch entry is in the top-k predictions. + + If `class_id` is specified, we calculate recall by considering only the + entries in the batch for which `class_id` is in the label, and computing the + fraction of them for which `class_id` is above the threshold and/or in the + top-k predictions. + + Args: + thresholds: (Optional) A float value, or a Python list/tuple of float + threshold values in `[0, 1]`. A threshold is compared with + prediction values to determine the truth value of predictions (i.e., + above the threshold is `True`, below is `False`). If used with a + loss function that sets `from_logits=True` (i.e. no sigmoid + applied to predictions), `thresholds` should be set to 0. + One metric value is generated for each threshold value. + If neither `thresholds` nor `top_k` are set, + the default is to calculate recall with `thresholds=0.5`. + top_k: (Optional) Unset by default. An int value specifying the top-k + predictions to consider when calculating recall. + class_id: (Optional) Integer class ID for which we want binary metrics. + This must be in the half-open interval `[0, num_classes)`, where + `num_classes` is the last dimension of predictions. + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + + Example: + + >>> m = keras.metrics.Recall() + >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1]) + >>> m.result() + 0.6666667 + + >>> m.reset_state() + >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1], sample_weight=[0, 0, 1, 0]) + >>> m.result() + 1.0 + + Usage with `compile()` API: + + ```python + model.compile(optimizer='sgd', + loss='binary_crossentropy', + metrics=[keras.metrics.Recall()]) + ``` + + Usage with a loss with `from_logits=True`: + + ```python + model.compile(optimizer='adam', + loss=keras.losses.BinaryCrossentropy(from_logits=True), + metrics=[keras.metrics.Recall(thresholds=0)]) + ``` + """ + + def __init__( + self, thresholds=None, top_k=None, class_id=None, name=None, dtype=None + ): + super().__init__(name=name, dtype=dtype) + # Metric should be maximized during optimization. + self._direction = "up" + + self.init_thresholds = thresholds + self.top_k = top_k + self.class_id = class_id + + default_threshold = 0.5 if top_k is None else metrics_utils.NEG_INF + self.thresholds = metrics_utils.parse_init_thresholds( + thresholds, default_threshold=default_threshold + ) + self._thresholds_distributed_evenly = ( + metrics_utils.is_evenly_distributed_thresholds(self.thresholds) + ) + self.true_positives = self.add_variable( + shape=(len(self.thresholds),), + initializer=initializers.Zeros(), + name="true_positives", + ) + self.false_negatives = self.add_variable( + shape=(len(self.thresholds),), + initializer=initializers.Zeros(), + name="false_negatives", + ) + + def update_state(self, y_true, y_pred, sample_weight=None): + """Accumulates true positive and false negative statistics. + + Args: + y_true: The ground truth values, with the same dimensions as + `y_pred`. Will be cast to `bool`. + y_pred: The predicted values. Each element must be in the range + `[0, 1]`. + sample_weight: Optional weighting of each example. Defaults to `1`. + Can be a tensor whose rank is either 0, or the same rank as + `y_true`, and must be broadcastable to `y_true`. + """ + metrics_utils.update_confusion_matrix_variables( + { + metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives, # noqa: E501 + metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives, # noqa: E501 + }, + y_true, + y_pred, + thresholds=self.thresholds, + thresholds_distributed_evenly=self._thresholds_distributed_evenly, + top_k=self.top_k, + class_id=self.class_id, + sample_weight=sample_weight, + ) + + def result(self): + result = ops.divide_no_nan( + self.true_positives, + ops.add(self.true_positives, self.false_negatives), + ) + return result[0] if len(self.thresholds) == 1 else result + + def reset_state(self): + num_thresholds = len(to_list(self.thresholds)) + self.true_positives.assign(ops.zeros((num_thresholds,))) + self.false_negatives.assign(ops.zeros((num_thresholds,))) + + def get_config(self): + config = { + "thresholds": self.init_thresholds, + "top_k": self.top_k, + "class_id": self.class_id, + } + base_config = super().get_config() + return {**base_config, **config} + + +class SensitivitySpecificityBase(Metric): + """Abstract base class for computing sensitivity and specificity. + + For additional information about specificity and sensitivity, see + [the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity). + """ + + def __init__( + self, value, num_thresholds=200, class_id=None, name=None, dtype=None + ): + super().__init__(name=name, dtype=dtype) + # Metric should be maximized during optimization. + self._direction = "up" + + if num_thresholds <= 0: + raise ValueError( + "Argument `num_thresholds` must be an integer > 0. " + f"Received: num_thresholds={num_thresholds}" + ) + self.value = value + self.class_id = class_id + + # Compute `num_thresholds` thresholds in [0, 1] + if num_thresholds == 1: + self.thresholds = [0.5] + self._thresholds_distributed_evenly = False + else: + thresholds = [ + (i + 1) * 1.0 / (num_thresholds - 1) + for i in range(num_thresholds - 2) + ] + self.thresholds = [0.0] + thresholds + [1.0] + self._thresholds_distributed_evenly = True + + self.true_positives = self.add_variable( + shape=(len(self.thresholds),), + initializer=initializers.Zeros(), + name="true_positives", + ) + self.false_positives = self.add_variable( + shape=(len(self.thresholds),), + initializer=initializers.Zeros(), + name="false_positives", + ) + self.true_negatives = self.add_variable( + shape=(len(self.thresholds),), + initializer=initializers.Zeros(), + name="true_negatives", + ) + self.false_negatives = self.add_variable( + shape=(len(self.thresholds),), + initializer=initializers.Zeros(), + name="false_negatives", + ) + + def update_state(self, y_true, y_pred, sample_weight=None): + """Accumulates confusion matrix statistics. + + Args: + y_true: The ground truth values. + y_pred: The predicted values. + sample_weight: Optional weighting of each example. Defaults to `1`. + Can be a tensor whose rank is either 0, or the same rank as + `y_true`, and must be broadcastable to `y_true`. + """ + metrics_utils.update_confusion_matrix_variables( + { + metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives, # noqa: E501 + metrics_utils.ConfusionMatrix.TRUE_NEGATIVES: self.true_negatives, # noqa: E501 + metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives, # noqa: E501 + metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives, # noqa: E501 + }, + y_true, + y_pred, + thresholds=self.thresholds, + thresholds_distributed_evenly=self._thresholds_distributed_evenly, + class_id=self.class_id, + sample_weight=sample_weight, + ) + + def reset_state(self): + num_thresholds = len(self.thresholds) + self.true_positives.assign(ops.zeros((num_thresholds,))) + self.false_positives.assign(ops.zeros((num_thresholds,))) + self.true_negatives.assign(ops.zeros((num_thresholds,))) + self.false_negatives.assign(ops.zeros((num_thresholds,))) + + def get_config(self): + config = {"class_id": self.class_id} + base_config = super().get_config() + return {**base_config, **config} + + def _find_max_under_constraint(self, constrained, dependent, predicate): + """Returns the maximum of dependent_statistic that satisfies the + constraint. + + Args: + constrained: Over these values the constraint is specified. A rank-1 + tensor. + dependent: From these values the maximum that satiesfies the + constraint is selected. Values in this tensor and in + `constrained` are linked by having the same threshold at each + position, hence this tensor must have the same shape. + predicate: A binary boolean functor to be applied to arguments + `constrained` and `self.value`, e.g. `ops.greater`. + + Returns: + maximal dependent value, if no value satisfies the constraint 0.0. + """ + feasible = ops.nonzero(predicate(constrained, self.value)) + feasible_exists = ops.greater(ops.size(feasible), 0) + max_dependent = ops.max(ops.take(dependent, feasible), initial=0) + + return ops.where(feasible_exists, max_dependent, 0.0) + + +@keras_export("keras.metrics.SensitivityAtSpecificity") +class SensitivityAtSpecificity(SensitivitySpecificityBase): + """Computes best sensitivity where specificity is >= specified value. + + `Sensitivity` measures the proportion of actual positives that are correctly + identified as such `(tp / (tp + fn))`. + `Specificity` measures the proportion of actual negatives that are correctly + identified as such `(tn / (tn + fp))`. + + This metric creates four local variables, `true_positives`, + `true_negatives`, `false_positives` and `false_negatives` that are used to + compute the sensitivity at the given specificity. The threshold for the + given specificity value is computed and used to evaluate the corresponding + sensitivity. + + If `sample_weight` is `None`, weights default to 1. + Use `sample_weight` of 0 to mask values. + + If `class_id` is specified, we calculate precision by considering only the + entries in the batch for which `class_id` is above the threshold + predictions, and computing the fraction of them for which `class_id` is + indeed a correct label. + + For additional information about specificity and sensitivity, see + [the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity). + + Args: + specificity: A scalar value in range `[0, 1]`. + num_thresholds: (Optional) Defaults to 200. The number of thresholds to + use for matching the given specificity. + class_id: (Optional) Integer class ID for which we want binary metrics. + This must be in the half-open interval `[0, num_classes)`, where + `num_classes` is the last dimension of predictions. + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + + Example: + + >>> m = keras.metrics.SensitivityAtSpecificity(0.5) + >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8]) + >>> m.result() + 0.5 + + >>> m.reset_state() + >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8], + ... sample_weight=[1, 1, 2, 2, 1]) + >>> m.result() + 0.333333 + + Usage with `compile()` API: + + ```python + model.compile( + optimizer='sgd', + loss='binary_crossentropy', + metrics=[keras.metrics.SensitivityAtSpecificity(specificity=0.5)]) + ``` + """ + + def __init__( + self, + specificity, + num_thresholds=200, + class_id=None, + name=None, + dtype=None, + ): + if specificity < 0 or specificity > 1: + raise ValueError( + "Argument `specificity` must be in the range [0, 1]. " + f"Received: specificity={specificity}" + ) + self.specificity = specificity + self.num_thresholds = num_thresholds + super().__init__( + specificity, + num_thresholds=num_thresholds, + class_id=class_id, + name=name, + dtype=dtype, + ) + + def result(self): + sensitivities = ops.divide_no_nan( + self.true_positives, + ops.add(self.true_positives, self.false_negatives), + ) + specificities = ops.divide_no_nan( + self.true_negatives, + ops.add(self.true_negatives, self.false_positives), + ) + return self._find_max_under_constraint( + specificities, sensitivities, ops.greater_equal + ) + + def get_config(self): + config = { + "num_thresholds": self.num_thresholds, + "specificity": self.specificity, + } + base_config = super().get_config() + return {**base_config, **config} + + +@keras_export("keras.metrics.SpecificityAtSensitivity") +class SpecificityAtSensitivity(SensitivitySpecificityBase): + """Computes best specificity where sensitivity is >= specified value. + + `Sensitivity` measures the proportion of actual positives that are correctly + identified as such `(tp / (tp + fn))`. + `Specificity` measures the proportion of actual negatives that are correctly + identified as such `(tn / (tn + fp))`. + + This metric creates four local variables, `true_positives`, + `true_negatives`, `false_positives` and `false_negatives` that are used to + compute the specificity at the given sensitivity. The threshold for the + given sensitivity value is computed and used to evaluate the corresponding + specificity. + + If `sample_weight` is `None`, weights default to 1. + Use `sample_weight` of 0 to mask values. + + If `class_id` is specified, we calculate precision by considering only the + entries in the batch for which `class_id` is above the threshold + predictions, and computing the fraction of them for which `class_id` is + indeed a correct label. + + For additional information about specificity and sensitivity, see + [the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity). + + Args: + sensitivity: A scalar value in range `[0, 1]`. + num_thresholds: (Optional) Defaults to 200. The number of thresholds to + use for matching the given sensitivity. + class_id: (Optional) Integer class ID for which we want binary metrics. + This must be in the half-open interval `[0, num_classes)`, where + `num_classes` is the last dimension of predictions. + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + + Example: + + >>> m = keras.metrics.SpecificityAtSensitivity(0.5) + >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8]) + >>> m.result() + 0.66666667 + + >>> m.reset_state() + >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8], + ... sample_weight=[1, 1, 2, 2, 2]) + >>> m.result() + 0.5 + + Usage with `compile()` API: + + ```python + model.compile( + optimizer='sgd', + loss='binary_crossentropy', + metrics=[keras.metrics.SpecificityAtSensitivity(sensitivity=0.3)]) + ``` + """ + + def __init__( + self, + sensitivity, + num_thresholds=200, + class_id=None, + name=None, + dtype=None, + ): + if sensitivity < 0 or sensitivity > 1: + raise ValueError( + "Argument `sensitivity` must be in the range [0, 1]. " + f"Received: sensitivity={sensitivity}" + ) + self.sensitivity = sensitivity + self.num_thresholds = num_thresholds + super().__init__( + sensitivity, + num_thresholds=num_thresholds, + class_id=class_id, + name=name, + dtype=dtype, + ) + + def result(self): + sensitivities = ops.divide_no_nan( + self.true_positives, + ops.add(self.true_positives, self.false_negatives), + ) + specificities = ops.divide_no_nan( + self.true_negatives, + ops.add(self.true_negatives, self.false_positives), + ) + return self._find_max_under_constraint( + sensitivities, specificities, ops.greater_equal + ) + + def get_config(self): + config = { + "num_thresholds": self.num_thresholds, + "sensitivity": self.sensitivity, + } + base_config = super().get_config() + return {**base_config, **config} + + +@keras_export("keras.metrics.PrecisionAtRecall") +class PrecisionAtRecall(SensitivitySpecificityBase): + """Computes best precision where recall is >= specified value. + + This metric creates four local variables, `true_positives`, + `true_negatives`, `false_positives` and `false_negatives` that are used to + compute the precision at the given recall. The threshold for the given + recall value is computed and used to evaluate the corresponding precision. + + If `sample_weight` is `None`, weights default to 1. + Use `sample_weight` of 0 to mask values. + + If `class_id` is specified, we calculate precision by considering only the + entries in the batch for which `class_id` is above the threshold + predictions, and computing the fraction of them for which `class_id` is + indeed a correct label. + + Args: + recall: A scalar value in range `[0, 1]`. + num_thresholds: (Optional) Defaults to 200. The number of thresholds to + use for matching the given recall. + class_id: (Optional) Integer class ID for which we want binary metrics. + This must be in the half-open interval `[0, num_classes)`, where + `num_classes` is the last dimension of predictions. + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + + Example: + + >>> m = keras.metrics.PrecisionAtRecall(0.5) + >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8]) + >>> m.result() + 0.5 + + >>> m.reset_state() + >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8], + ... sample_weight=[2, 2, 2, 1, 1]) + >>> m.result() + 0.33333333 + + Usage with `compile()` API: + + ```python + model.compile( + optimizer='sgd', + loss='binary_crossentropy', + metrics=[keras.metrics.PrecisionAtRecall(recall=0.8)]) + ``` + """ + + def __init__( + self, recall, num_thresholds=200, class_id=None, name=None, dtype=None + ): + if recall < 0 or recall > 1: + raise ValueError( + "Argument `recall` must be in the range [0, 1]. " + f"Received: recall={recall}" + ) + self.recall = recall + self.num_thresholds = num_thresholds + super().__init__( + value=recall, + num_thresholds=num_thresholds, + class_id=class_id, + name=name, + dtype=dtype, + ) + + def result(self): + recalls = ops.divide_no_nan( + self.true_positives, + ops.add(self.true_positives, self.false_negatives), + ) + precisions = ops.divide_no_nan( + self.true_positives, + ops.add(self.true_positives, self.false_positives), + ) + return self._find_max_under_constraint( + recalls, precisions, ops.greater_equal + ) + + def get_config(self): + config = {"num_thresholds": self.num_thresholds, "recall": self.recall} + base_config = super().get_config() + return {**base_config, **config} + + +@keras_export("keras.metrics.RecallAtPrecision") +class RecallAtPrecision(SensitivitySpecificityBase): + """Computes best recall where precision is >= specified value. + + For a given score-label-distribution the required precision might not + be achievable, in this case 0.0 is returned as recall. + + This metric creates four local variables, `true_positives`, + `true_negatives`, `false_positives` and `false_negatives` that are used to + compute the recall at the given precision. The threshold for the given + precision value is computed and used to evaluate the corresponding recall. + + If `sample_weight` is `None`, weights default to 1. + Use `sample_weight` of 0 to mask values. + + If `class_id` is specified, we calculate precision by considering only the + entries in the batch for which `class_id` is above the threshold + predictions, and computing the fraction of them for which `class_id` is + indeed a correct label. + + Args: + precision: A scalar value in range `[0, 1]`. + num_thresholds: (Optional) Defaults to 200. The number of thresholds + to use for matching the given precision. + class_id: (Optional) Integer class ID for which we want binary metrics. + This must be in the half-open interval `[0, num_classes)`, where + `num_classes` is the last dimension of predictions. + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + + Example: + + >>> m = keras.metrics.RecallAtPrecision(0.8) + >>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9]) + >>> m.result() + 0.5 + + >>> m.reset_state() + >>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9], + ... sample_weight=[1, 0, 0, 1]) + >>> m.result() + 1.0 + + Usage with `compile()` API: + + ```python + model.compile( + optimizer='sgd', + loss='binary_crossentropy', + metrics=[keras.metrics.RecallAtPrecision(precision=0.8)]) + ``` + """ + + def __init__( + self, + precision, + num_thresholds=200, + class_id=None, + name=None, + dtype=None, + ): + if precision < 0 or precision > 1: + raise ValueError( + "Argument `precision` must be in the range [0, 1]. " + f"Received: precision={precision}" + ) + self.precision = precision + self.num_thresholds = num_thresholds + super().__init__( + value=precision, + num_thresholds=num_thresholds, + class_id=class_id, + name=name, + dtype=dtype, + ) + + def result(self): + recalls = ops.divide_no_nan( + self.true_positives, + ops.add(self.true_positives, self.false_negatives), + ) + precisions = ops.divide_no_nan( + self.true_positives, + ops.add(self.true_positives, self.false_positives), + ) + return self._find_max_under_constraint( + precisions, recalls, ops.greater_equal + ) + + def get_config(self): + config = { + "num_thresholds": self.num_thresholds, + "precision": self.precision, + } + base_config = super().get_config() + return {**base_config, **config} + + +@keras_export("keras.metrics.AUC") +class AUC(Metric): + """Approximates the AUC (Area under the curve) of the ROC or PR curves. + + The AUC (Area under the curve) of the ROC (Receiver operating + characteristic; default) or PR (Precision Recall) curves are quality + measures of binary classifiers. Unlike the accuracy, and like cross-entropy + losses, ROC-AUC and PR-AUC evaluate all the operational points of a model. + + This class approximates AUCs using a Riemann sum. During the metric + accumulation phrase, predictions are accumulated within predefined buckets + by value. The AUC is then computed by interpolating per-bucket averages. + These buckets define the evaluated operational points. + + This metric creates four local variables, `true_positives`, + `true_negatives`, `false_positives` and `false_negatives` that are used to + compute the AUC. To discretize the AUC curve, a linearly spaced set of + thresholds is used to compute pairs of recall and precision values. The area + under the ROC-curve is therefore computed using the height of the recall + values by the false positive rate, while the area under the PR-curve is the + computed using the height of the precision values by the recall. + + This value is ultimately returned as `auc`, an idempotent operation that + computes the area under a discretized curve of precision versus recall + values (computed using the aforementioned variables). The `num_thresholds` + variable controls the degree of discretization with larger numbers of + thresholds more closely approximating the true AUC. The quality of the + approximation may vary dramatically depending on `num_thresholds`. The + `thresholds` parameter can be used to manually specify thresholds which + split the predictions more evenly. + + For a best approximation of the real AUC, `predictions` should be + distributed approximately uniformly in the range `[0, 1]` (if + `from_logits=False`). The quality of the AUC approximation may be poor if + this is not the case. Setting `summation_method` to 'minoring' or 'majoring' + can help quantify the error in the approximation by providing lower or upper + bound estimate of the AUC. + + If `sample_weight` is `None`, weights default to 1. + Use `sample_weight` of 0 to mask values. + + Args: + num_thresholds: (Optional) The number of thresholds to + use when discretizing the roc curve. Values must be > 1. + Defaults to `200`. + curve: (Optional) Specifies the name of the curve to be computed, + `'ROC'` (default) or `'PR'` for the Precision-Recall-curve. + summation_method: (Optional) Specifies the [Riemann summation method]( + https://en.wikipedia.org/wiki/Riemann_sum) used. + 'interpolation' (default) applies mid-point summation scheme for + `ROC`. For PR-AUC, interpolates (true/false) positives but not + the ratio that is precision (see Davis & Goadrich 2006 for + details); 'minoring' applies left summation for increasing + intervals and right summation for decreasing intervals; 'majoring' + does the opposite. + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + thresholds: (Optional) A list of floating point values to use as the + thresholds for discretizing the curve. If set, the `num_thresholds` + parameter is ignored. Values should be in `[0, 1]`. Endpoint + thresholds equal to {`-epsilon`, `1+epsilon`} for a small positive + epsilon value will be automatically included with these to correctly + handle predictions equal to exactly 0 or 1. + multi_label: boolean indicating whether multilabel data should be + treated as such, wherein AUC is computed separately for each label + and then averaged across labels, or (when `False`) if the data + should be flattened into a single label before AUC computation. In + the latter case, when multilabel data is passed to AUC, each + label-prediction pair is treated as an individual data point. Should + be set to `False` for multi-class data. + num_labels: (Optional) The number of labels, used when `multi_label` is + True. If `num_labels` is not specified, then state variables get + created on the first call to `update_state`. + label_weights: (Optional) list, array, or tensor of non-negative weights + used to compute AUCs for multilabel data. When `multi_label` is + True, the weights are applied to the individual label AUCs when they + are averaged to produce the multi-label AUC. When it's False, they + are used to weight the individual label predictions in computing the + confusion matrix on the flattened data. Note that this is unlike + `class_weights` in that `class_weights` weights the example + depending on the value of its label, whereas `label_weights` depends + only on the index of that label before flattening; therefore + `label_weights` should not be used for multi-class data. + from_logits: boolean indicating whether the predictions (`y_pred` in + `update_state`) are probabilities or sigmoid logits. As a rule of thumb, + when using a keras loss, the `from_logits` constructor argument of the + loss should match the AUC `from_logits` constructor argument. + + Example: + + >>> m = keras.metrics.AUC(num_thresholds=3) + >>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9]) + >>> # threshold values are [0 - 1e-7, 0.5, 1 + 1e-7] + >>> # tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2] + >>> # tp_rate = recall = [1, 0.5, 0], fp_rate = [1, 0, 0] + >>> # auc = ((((1 + 0.5) / 2) * (1 - 0)) + (((0.5 + 0) / 2) * (0 - 0))) + >>> # = 0.75 + >>> m.result() + 0.75 + + >>> m.reset_state() + >>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9], + ... sample_weight=[1, 0, 0, 1]) + >>> m.result() + 1.0 + + Usage with `compile()` API: + + ```python + # Reports the AUC of a model outputting a probability. + model.compile(optimizer='sgd', + loss=keras.losses.BinaryCrossentropy(), + metrics=[keras.metrics.AUC()]) + + # Reports the AUC of a model outputting a logit. + model.compile(optimizer='sgd', + loss=keras.losses.BinaryCrossentropy(from_logits=True), + metrics=[keras.metrics.AUC(from_logits=True)]) + ``` + """ + + def __init__( + self, + num_thresholds=200, + curve="ROC", + summation_method="interpolation", + name=None, + dtype=None, + thresholds=None, + multi_label=False, + num_labels=None, + label_weights=None, + from_logits=False, + ): + # Metric should be maximized during optimization. + self._direction = "up" + + # Validate configurations. + if isinstance(curve, metrics_utils.AUCCurve) and curve not in list( + metrics_utils.AUCCurve + ): + raise ValueError( + f'Invalid `curve` argument value "{curve}". ' + f"Expected one of: {list(metrics_utils.AUCCurve)}" + ) + if isinstance( + summation_method, metrics_utils.AUCSummationMethod + ) and summation_method not in list(metrics_utils.AUCSummationMethod): + raise ValueError( + "Invalid `summation_method` " + f'argument value "{summation_method}". ' + f"Expected one of: {list(metrics_utils.AUCSummationMethod)}" + ) + + # Update properties. + self._init_from_thresholds = thresholds is not None + if thresholds is not None: + # If specified, use the supplied thresholds. + self.num_thresholds = len(thresholds) + 2 + thresholds = sorted(thresholds) + self._thresholds_distributed_evenly = ( + metrics_utils.is_evenly_distributed_thresholds( + np.array([0.0] + thresholds + [1.0]) + ) + ) + else: + if num_thresholds <= 1: + raise ValueError( + "Argument `num_thresholds` must be an integer > 1. " + f"Received: num_thresholds={num_thresholds}" + ) + + # Otherwise, linearly interpolate (num_thresholds - 2) thresholds in + # (0, 1). + self.num_thresholds = num_thresholds + thresholds = [ + (i + 1) * 1.0 / (num_thresholds - 1) + for i in range(num_thresholds - 2) + ] + self._thresholds_distributed_evenly = True + + # Add an endpoint "threshold" below zero and above one for either + # threshold method to account for floating point imprecisions. + self._thresholds = np.array( + [0.0 - backend.epsilon()] + thresholds + [1.0 + backend.epsilon()] + ) + + if isinstance(curve, metrics_utils.AUCCurve): + self.curve = curve + else: + self.curve = metrics_utils.AUCCurve.from_str(curve) + if isinstance(summation_method, metrics_utils.AUCSummationMethod): + self.summation_method = summation_method + else: + self.summation_method = metrics_utils.AUCSummationMethod.from_str( + summation_method + ) + super().__init__(name=name, dtype=dtype) + + # Handle multilabel arguments. + self.multi_label = multi_label + self.num_labels = num_labels + if label_weights is not None: + label_weights = ops.array(label_weights, dtype=self.dtype) + self.label_weights = label_weights + + else: + self.label_weights = None + + self._from_logits = from_logits + + self._built = False + if self.multi_label: + if num_labels: + shape = [None, num_labels] + self._build(shape) + else: + if num_labels: + raise ValueError( + "`num_labels` is needed only when `multi_label` is True." + ) + self._build(None) + + @property + def thresholds(self): + """The thresholds used for evaluating AUC.""" + return list(self._thresholds) + + def _build(self, shape): + """Initialize TP, FP, TN, and FN tensors, given the shape of the + data.""" + if self.multi_label: + if len(shape) != 2: + raise ValueError( + "`y_pred` must have rank 2 when `multi_label=True`. " + f"Found rank {len(shape)}. " + f"Full shape received for `y_pred`: {shape}" + ) + self._num_labels = shape[1] + variable_shape = [self.num_thresholds, self._num_labels] + else: + variable_shape = [self.num_thresholds] + + self._build_input_shape = shape + # Create metric variables + self.true_positives = self.add_variable( + shape=variable_shape, + initializer=initializers.Zeros(), + name="true_positives", + ) + self.false_positives = self.add_variable( + shape=variable_shape, + initializer=initializers.Zeros(), + name="false_positives", + ) + self.true_negatives = self.add_variable( + shape=variable_shape, + initializer=initializers.Zeros(), + name="true_negatives", + ) + self.false_negatives = self.add_variable( + shape=variable_shape, + initializer=initializers.Zeros(), + name="false_negatives", + ) + + self._built = True + + def update_state(self, y_true, y_pred, sample_weight=None): + """Accumulates confusion matrix statistics. + + Args: + y_true: The ground truth values. + y_pred: The predicted values. + sample_weight: Optional weighting of each example. Can + be a tensor whose rank is either 0, or the same rank as + `y_true`, and must be broadcastable to `y_true`. Defaults to + `1`. + """ + if not self._built: + self._build(y_pred.shape) + + # Only forward label_weights to update_confusion_matrix_variables when + # multi_label is False. Otherwise the averaging of individual label AUCs + # is handled in AUC.result + label_weights = None if self.multi_label else self.label_weights + + if self._from_logits: + y_pred = activations.sigmoid(y_pred) + + metrics_utils.update_confusion_matrix_variables( + { + metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives, # noqa: E501 + metrics_utils.ConfusionMatrix.TRUE_NEGATIVES: self.true_negatives, # noqa: E501 + metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives, # noqa: E501 + metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives, # noqa: E501 + }, + y_true, + y_pred, + self._thresholds, + thresholds_distributed_evenly=self._thresholds_distributed_evenly, + sample_weight=sample_weight, + multi_label=self.multi_label, + label_weights=label_weights, + ) + + def interpolate_pr_auc(self): + """Interpolation formula inspired by section 4 of Davis & Goadrich 2006. + + https://www.biostat.wisc.edu/~page/rocpr.pdf + + Note here we derive & use a closed formula not present in the paper + as follows: + + Precision = TP / (TP + FP) = TP / P + + Modeling all of TP (true positive), FP (false positive) and their sum + P = TP + FP (predicted positive) as varying linearly within each + interval [A, B] between successive thresholds, we get + + Precision slope = dTP / dP + = (TP_B - TP_A) / (P_B - P_A) + = (TP - TP_A) / (P - P_A) + Precision = (TP_A + slope * (P - P_A)) / P + + The area within the interval is (slope / total_pos_weight) times + + int_A^B{Precision.dP} = int_A^B{(TP_A + slope * (P - P_A)) * dP / P} + int_A^B{Precision.dP} = int_A^B{slope * dP + intercept * dP / P} + + where intercept = TP_A - slope * P_A = TP_B - slope * P_B, resulting in + + int_A^B{Precision.dP} = TP_B - TP_A + intercept * log(P_B / P_A) + + Bringing back the factor (slope / total_pos_weight) we'd put aside, we + get + + slope * [dTP + intercept * log(P_B / P_A)] / total_pos_weight + + where dTP == TP_B - TP_A. + + Note that when P_A == 0 the above calculation simplifies into + + int_A^B{Precision.dTP} = int_A^B{slope * dTP} + = slope * (TP_B - TP_A) + + which is really equivalent to imputing constant precision throughout the + first bucket having >0 true positives. + + Returns: + pr_auc: an approximation of the area under the P-R curve. + """ + + dtp = ops.subtract( + self.true_positives[: self.num_thresholds - 1], + self.true_positives[1:], + ) + p = ops.add(self.true_positives, self.false_positives) + dp = ops.subtract(p[: self.num_thresholds - 1], p[1:]) + prec_slope = ops.divide_no_nan(dtp, ops.maximum(dp, 0)) + intercept = ops.subtract( + self.true_positives[1:], ops.multiply(prec_slope, p[1:]) + ) + + safe_p_ratio = ops.where( + ops.logical_and(p[: self.num_thresholds - 1] > 0, p[1:] > 0), + ops.divide_no_nan( + p[: self.num_thresholds - 1], ops.maximum(p[1:], 0) + ), + ops.ones_like(p[1:]), + ) + + pr_auc_increment = ops.divide_no_nan( + ops.multiply( + prec_slope, + (ops.add(dtp, ops.multiply(intercept, ops.log(safe_p_ratio)))), + ), + ops.maximum( + ops.add(self.true_positives[1:], self.false_negatives[1:]), 0 + ), + ) + + if self.multi_label: + by_label_auc = ops.sum(pr_auc_increment, axis=0) + if self.label_weights is None: + # Evenly weighted average of the label AUCs. + return ops.mean(by_label_auc) + else: + # Weighted average of the label AUCs. + return ops.divide_no_nan( + ops.sum(ops.multiply(by_label_auc, self.label_weights)), + ops.sum(self.label_weights), + ) + else: + return ops.sum(pr_auc_increment) + + def result(self): + if ( + self.curve == metrics_utils.AUCCurve.PR + and self.summation_method + == metrics_utils.AUCSummationMethod.INTERPOLATION + ): + # This use case is different and is handled separately. + return self.interpolate_pr_auc() + + # Set `x` and `y` values for the curves based on `curve` config. + recall = ops.divide_no_nan( + self.true_positives, + ops.add(self.true_positives, self.false_negatives), + ) + if self.curve == metrics_utils.AUCCurve.ROC: + fp_rate = ops.divide_no_nan( + self.false_positives, + ops.add(self.false_positives, self.true_negatives), + ) + x = fp_rate + y = recall + elif self.curve == metrics_utils.AUCCurve.PR: # curve == 'PR'. + precision = ops.divide_no_nan( + self.true_positives, + ops.add(self.true_positives, self.false_positives), + ) + x = recall + y = precision + else: # curve == 'PRGAIN'. + # Due to the hyperbolic transform, this formula is less robust than + # ROC and PR values. In particular + # 1) Both measures diverge when there are no negative values; + # 2) Both measures diverge when there are no true positives; + # 3) Recall gain becomes negative when the recall is lower than the + # label average (i.e. when more negative examples are + # classified positive than real positives). + # + # We ignore case 1 as it is easily understood that metrics would be + # badly defined then. For case 2 we set recall_gain to 0 and + # precision_gain to 1. For case 3 we set recall_gain to 0. These + # fixes will result in an overestimation of the AUC for estimators + # that are anti-correlated with the label (at some threshold). + + # The scaling factor $\frac{P}{N}$ that is used to for both gain + # values. + scaling_factor = ops.divide_no_nan( + ops.add(self.true_positives, self.false_negatives), + ops.add(self.true_negatives, self.false_positives), + ) + + recall_gain = 1.0 - scaling_factor * ops.divide_no_nan( + self.false_negatives, self.true_positives + ) + precision_gain = 1.0 - scaling_factor * ops.divide_no_nan( + self.false_positives, self.true_positives + ) + # Handle case 2. + recall_gain = ops.where( + ops.equal(self.true_positives, 0.0), 0.0, recall_gain + ) + precision_gain = ops.where( + ops.equal(self.true_positives, 0.0), 1.0, precision_gain + ) + # Handle case 3. + recall_gain = ops.maximum(recall_gain, 0.0) + + x = recall_gain + y = precision_gain + + # Find the rectangle heights based on `summation_method`. + if ( + self.summation_method + == metrics_utils.AUCSummationMethod.INTERPOLATION + ): + # Note: the case ('PR', 'interpolation') has been handled above. + heights = ops.divide( + ops.add(y[: self.num_thresholds - 1], y[1:]), 2.0 + ) + elif self.summation_method == metrics_utils.AUCSummationMethod.MINORING: + heights = ops.minimum(y[: self.num_thresholds - 1], y[1:]) + # self.summation_method = metrics_utils.AUCSummationMethod.MAJORING: + else: + heights = ops.maximum(y[: self.num_thresholds - 1], y[1:]) + + # Sum up the areas of all the rectangles. + riemann_terms = ops.multiply( + ops.subtract(x[: self.num_thresholds - 1], x[1:]), heights + ) + if self.multi_label: + by_label_auc = ops.sum(riemann_terms, axis=0) + + if self.label_weights is None: + # Unweighted average of the label AUCs. + return ops.mean(by_label_auc) + else: + # Weighted average of the label AUCs. + return ops.divide_no_nan( + ops.sum(ops.multiply(by_label_auc, self.label_weights)), + ops.sum(self.label_weights), + ) + else: + return ops.sum(riemann_terms) + + def reset_state(self): + if self._built: + if self.multi_label: + variable_shape = (self.num_thresholds, self._num_labels) + else: + variable_shape = (self.num_thresholds,) + + self.true_positives.assign(ops.zeros(variable_shape)) + self.false_positives.assign(ops.zeros(variable_shape)) + self.true_negatives.assign(ops.zeros(variable_shape)) + self.false_negatives.assign(ops.zeros(variable_shape)) + + def get_config(self): + label_weights = self.label_weights + config = { + "num_thresholds": self.num_thresholds, + "curve": self.curve.value, + "summation_method": self.summation_method.value, + "multi_label": self.multi_label, + "num_labels": self.num_labels, + "label_weights": label_weights, + "from_logits": self._from_logits, + } + # optimization to avoid serializing a large number of generated + # thresholds + if self._init_from_thresholds: + # We remove the endpoint thresholds as an inverse of how the + # thresholds were initialized. This ensures that a metric + # initialized from this config has the same thresholds. + config["thresholds"] = self.thresholds[1:-1] + base_config = super().get_config() + return {**base_config, **config} diff --git a/keras/src/metrics/confusion_metrics_test.py b/keras/src/metrics/confusion_metrics_test.py new file mode 100644 index 000000000000..16941fb3be66 --- /dev/null +++ b/keras/src/metrics/confusion_metrics_test.py @@ -0,0 +1,1763 @@ +import json + +import numpy as np +import pytest +from absl import logging +from absl.testing import parameterized + +from keras.src import layers +from keras.src import metrics +from keras.src import models +from keras.src import ops +from keras.src import testing +from keras.src.metrics import metrics_utils + + +class FalsePositivesTest(testing.TestCase): + def test_config(self): + fp_obj = metrics.FalsePositives(name="my_fp", thresholds=[0.4, 0.9]) + self.assertEqual(fp_obj.name, "my_fp") + self.assertLen(fp_obj.variables, 1) + self.assertEqual(fp_obj.thresholds, [0.4, 0.9]) + + # Check save and restore config + fp_obj2 = metrics.FalsePositives.from_config(fp_obj.get_config()) + self.assertEqual(fp_obj2.name, "my_fp") + self.assertLen(fp_obj2.variables, 1) + self.assertEqual(fp_obj2.thresholds, [0.4, 0.9]) + + def test_unweighted(self): + fp_obj = metrics.FalsePositives() + + y_true = np.array( + ((0, 1, 0, 1, 0), (0, 0, 1, 1, 1), (1, 1, 1, 1, 0), (0, 0, 0, 0, 1)) + ) + y_pred = np.array( + ((0, 0, 1, 1, 0), (1, 1, 1, 1, 1), (0, 1, 0, 1, 0), (1, 1, 1, 1, 1)) + ) + + fp_obj.update_state(y_true, y_pred) + self.assertAllClose(7.0, fp_obj.result()) + + def test_weighted(self): + fp_obj = metrics.FalsePositives() + y_true = np.array( + ((0, 1, 0, 1, 0), (0, 0, 1, 1, 1), (1, 1, 1, 1, 0), (0, 0, 0, 0, 1)) + ) + y_pred = np.array( + ((0, 0, 1, 1, 0), (1, 1, 1, 1, 1), (0, 1, 0, 1, 0), (1, 1, 1, 1, 1)) + ) + sample_weight = np.array((1.0, 1.5, 2.0, 2.5)) + result = fp_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAllClose(14.0, result) + + def test_unweighted_with_thresholds(self): + fp_obj = metrics.FalsePositives(thresholds=[0.15, 0.5, 0.85]) + + y_pred = np.array( + ( + (0.9, 0.2, 0.8, 0.1), + (0.2, 0.9, 0.7, 0.6), + (0.1, 0.2, 0.4, 0.3), + (0, 1, 0.7, 0.3), + ) + ) + y_true = np.array( + ((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0), (1, 1, 1, 1)) + ) + + fp_obj.update_state(y_true, y_pred) + self.assertAllClose([7.0, 4.0, 2.0], fp_obj.result()) + + def test_weighted_with_thresholds(self): + fp_obj = metrics.FalsePositives(thresholds=[0.15, 0.5, 0.85]) + + y_pred = np.array( + ( + (0.9, 0.2, 0.8, 0.1), + (0.2, 0.9, 0.7, 0.6), + (0.1, 0.2, 0.4, 0.3), + (0, 1, 0.7, 0.3), + ) + ) + y_true = np.array( + ((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0), (1, 1, 1, 1)) + ) + sample_weight = ( + (1.0, 2.0, 3.0, 5.0), + (7.0, 11.0, 13.0, 17.0), + (19.0, 23.0, 29.0, 31.0), + (5.0, 15.0, 10.0, 0), + ) + + result = fp_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAllClose([125.0, 42.0, 12.0], result) + + def test_threshold_limit(self): + with self.assertRaisesRegex( + ValueError, + r"Threshold values must be in \[0, 1\]. Received: \[-1, 2\]", + ): + metrics.FalsePositives(thresholds=[-1, 0.5, 2]) + + with self.assertRaisesRegex( + ValueError, + r"Threshold values must be in \[0, 1\]. Received: \[None\]", + ): + metrics.FalsePositives(thresholds=[None]) + + +class FalseNegativesTest(testing.TestCase): + def test_config(self): + fn_obj = metrics.FalseNegatives(name="my_fn", thresholds=[0.4, 0.9]) + self.assertEqual(fn_obj.name, "my_fn") + self.assertLen(fn_obj.variables, 1) + self.assertEqual(fn_obj.thresholds, [0.4, 0.9]) + + # Check save and restore config + fn_obj2 = metrics.FalseNegatives.from_config(fn_obj.get_config()) + self.assertEqual(fn_obj2.name, "my_fn") + self.assertLen(fn_obj2.variables, 1) + self.assertEqual(fn_obj2.thresholds, [0.4, 0.9]) + + def test_unweighted(self): + fn_obj = metrics.FalseNegatives() + + y_true = np.array( + ((0, 1, 0, 1, 0), (0, 0, 1, 1, 1), (1, 1, 1, 1, 0), (0, 0, 0, 0, 1)) + ) + y_pred = np.array( + ((0, 0, 1, 1, 0), (1, 1, 1, 1, 1), (0, 1, 0, 1, 0), (1, 1, 1, 1, 1)) + ) + + fn_obj.update_state(y_true, y_pred) + self.assertAllClose(3.0, fn_obj.result()) + + def test_weighted(self): + fn_obj = metrics.FalseNegatives() + y_true = np.array( + ((0, 1, 0, 1, 0), (0, 0, 1, 1, 1), (1, 1, 1, 1, 0), (0, 0, 0, 0, 1)) + ) + y_pred = np.array( + ((0, 0, 1, 1, 0), (1, 1, 1, 1, 1), (0, 1, 0, 1, 0), (1, 1, 1, 1, 1)) + ) + sample_weight = np.array((1.0, 1.5, 2.0, 2.5)) + result = fn_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAllClose(5.0, result) + + def test_unweighted_with_thresholds(self): + fn_obj = metrics.FalseNegatives(thresholds=[0.15, 0.5, 0.85]) + + y_pred = np.array( + ( + (0.9, 0.2, 0.8, 0.1), + (0.2, 0.9, 0.7, 0.6), + (0.1, 0.2, 0.4, 0.3), + (0, 1, 0.7, 0.3), + ) + ) + y_true = np.array( + ((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0), (1, 1, 1, 1)) + ) + + fn_obj.update_state(y_true, y_pred) + self.assertAllClose([1.0, 4.0, 6.0], fn_obj.result()) + + def test_weighted_with_thresholds(self): + fn_obj = metrics.FalseNegatives(thresholds=[0.15, 0.5, 0.85]) + + y_pred = np.array( + ( + (0.9, 0.2, 0.8, 0.1), + (0.2, 0.9, 0.7, 0.6), + (0.1, 0.2, 0.4, 0.3), + (0, 1, 0.7, 0.3), + ) + ) + y_true = np.array( + ((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0), (1, 1, 1, 1)) + ) + sample_weight = ((3.0,), (5.0,), (7.0,), (4.0,)) + + result = fn_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAllClose([4.0, 16.0, 23.0], result) + + def test_threshold_limit(self): + with self.assertRaisesRegex( + ValueError, + r"Threshold values must be in \[0, 1\]. Received: \[-1, 2\]", + ): + metrics.FalseNegatives(thresholds=[-1, 0.5, 2]) + + with self.assertRaisesRegex( + ValueError, + r"Threshold values must be in \[0, 1\]. Received: \[None\]", + ): + metrics.FalseNegatives(thresholds=[None]) + + +class TrueNegativesTest(testing.TestCase): + def test_config(self): + tn_obj = metrics.TrueNegatives(name="my_tn", thresholds=[0.4, 0.9]) + self.assertEqual(tn_obj.name, "my_tn") + self.assertLen(tn_obj.variables, 1) + self.assertEqual(tn_obj.thresholds, [0.4, 0.9]) + + # Check save and restore config + tn_obj2 = metrics.TrueNegatives.from_config(tn_obj.get_config()) + self.assertEqual(tn_obj2.name, "my_tn") + self.assertLen(tn_obj2.variables, 1) + self.assertEqual(tn_obj2.thresholds, [0.4, 0.9]) + + def test_unweighted(self): + tn_obj = metrics.TrueNegatives() + + y_true = np.array( + ((0, 1, 0, 1, 0), (0, 0, 1, 1, 1), (1, 1, 1, 1, 0), (0, 0, 0, 0, 1)) + ) + y_pred = np.array( + ((0, 0, 1, 1, 0), (1, 1, 1, 1, 1), (0, 1, 0, 1, 0), (1, 1, 1, 1, 1)) + ) + + tn_obj.update_state(y_true, y_pred) + self.assertAllClose(3.0, tn_obj.result()) + + def test_weighted(self): + tn_obj = metrics.TrueNegatives() + y_true = np.array( + ((0, 1, 0, 1, 0), (0, 0, 1, 1, 1), (1, 1, 1, 1, 0), (0, 0, 0, 0, 1)) + ) + y_pred = np.array( + ((0, 0, 1, 1, 0), (1, 1, 1, 1, 1), (0, 1, 0, 1, 0), (1, 1, 1, 1, 1)) + ) + sample_weight = np.array((1.0, 1.5, 2.0, 2.5)) + result = tn_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAllClose(4.0, result) + + def test_unweighted_with_thresholds(self): + tn_obj = metrics.TrueNegatives(thresholds=[0.15, 0.5, 0.85]) + + y_pred = np.array( + ( + (0.9, 0.2, 0.8, 0.1), + (0.2, 0.9, 0.7, 0.6), + (0.1, 0.2, 0.4, 0.3), + (0, 1, 0.7, 0.3), + ) + ) + y_true = np.array( + ((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0), (1, 1, 1, 1)) + ) + + tn_obj.update_state(y_true, y_pred) + self.assertAllClose([2.0, 5.0, 7.0], tn_obj.result()) + + def test_weighted_with_thresholds(self): + tn_obj = metrics.TrueNegatives(thresholds=[0.15, 0.5, 0.85]) + + y_pred = np.array( + ( + (0.9, 0.2, 0.8, 0.1), + (0.2, 0.9, 0.7, 0.6), + (0.1, 0.2, 0.4, 0.3), + (0, 1, 0.7, 0.3), + ) + ) + y_true = np.array( + ((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0), (1, 1, 1, 1)) + ) + sample_weight = ((0.0, 2.0, 3.0, 5.0),) + + result = tn_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAllClose([5.0, 15.0, 23.0], result) + + def test_threshold_limit(self): + with self.assertRaisesRegex( + ValueError, + r"Threshold values must be in \[0, 1\]. Received: \[-1, 2\]", + ): + metrics.TrueNegatives(thresholds=[-1, 0.5, 2]) + + with self.assertRaisesRegex( + ValueError, + r"Threshold values must be in \[0, 1\]. Received: \[None\]", + ): + metrics.TrueNegatives(thresholds=[None]) + + +class TruePositiveTest(testing.TestCase): + def test_config(self): + tp_obj = metrics.TruePositives(name="my_tp", thresholds=[0.4, 0.9]) + self.assertEqual(tp_obj.name, "my_tp") + self.assertLen(tp_obj.variables, 1) + self.assertEqual(tp_obj.thresholds, [0.4, 0.9]) + + # Check save and restore config + tp_obj2 = metrics.TruePositives.from_config(tp_obj.get_config()) + self.assertEqual(tp_obj2.name, "my_tp") + self.assertLen(tp_obj2.variables, 1) + self.assertEqual(tp_obj2.thresholds, [0.4, 0.9]) + + def test_unweighted(self): + tp_obj = metrics.TruePositives() + + y_true = np.array( + ((0, 1, 0, 1, 0), (0, 0, 1, 1, 1), (1, 1, 1, 1, 0), (0, 0, 0, 0, 1)) + ) + y_pred = np.array( + ((0, 0, 1, 1, 0), (1, 1, 1, 1, 1), (0, 1, 0, 1, 0), (1, 1, 1, 1, 1)) + ) + + tp_obj.update_state(y_true, y_pred) + self.assertAllClose(7.0, tp_obj.result()) + + def test_weighted(self): + tp_obj = metrics.TruePositives() + y_true = np.array( + ((0, 1, 0, 1, 0), (0, 0, 1, 1, 1), (1, 1, 1, 1, 0), (0, 0, 0, 0, 1)) + ) + y_pred = np.array( + ((0, 0, 1, 1, 0), (1, 1, 1, 1, 1), (0, 1, 0, 1, 0), (1, 1, 1, 1, 1)) + ) + sample_weight = np.array((1.0, 1.5, 2.0, 2.5)) + result = tp_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAllClose(12.0, result) + + def test_unweighted_with_thresholds(self): + tp_obj = metrics.TruePositives(thresholds=[0.15, 0.5, 0.85]) + + y_pred = np.array( + ( + (0.9, 0.2, 0.8, 0.1), + (0.2, 0.9, 0.7, 0.6), + (0.1, 0.2, 0.4, 0.3), + (0, 1, 0.7, 0.3), + ) + ) + y_true = np.array( + ((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0), (1, 1, 1, 1)) + ) + + tp_obj.update_state(y_true, y_pred) + self.assertAllClose([6.0, 3.0, 1.0], tp_obj.result()) + + def test_weighted_with_thresholds(self): + tp_obj = metrics.TruePositives(thresholds=[0.15, 0.5, 0.85]) + + y_pred = np.array( + ( + (0.9, 0.2, 0.8, 0.1), + (0.2, 0.9, 0.7, 0.6), + (0.1, 0.2, 0.4, 0.3), + (0, 1, 0.7, 0.3), + ) + ) + y_true = np.array( + ((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0), (1, 1, 1, 1)) + ) + sample_weight = 37.0 + + result = tp_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAllClose([222.0, 111.0, 37.0], result) + + def test_threshold_limit(self): + with self.assertRaisesRegex( + ValueError, + r"Threshold values must be in \[0, 1\]. Received: \[-1, 2\]", + ): + metrics.TruePositives(thresholds=[-1, 0.5, 2]) + + with self.assertRaisesRegex( + ValueError, + r"Threshold values must be in \[0, 1\]. Received: \[None\]", + ): + metrics.TruePositives(thresholds=[None]) + + +class PrecisionTest(testing.TestCase): + def test_config(self): + p_obj = metrics.Precision( + name="my_precision", thresholds=[0.4, 0.9], top_k=15, class_id=12 + ) + self.assertEqual(p_obj.name, "my_precision") + self.assertLen(p_obj.variables, 2) + self.assertEqual( + [v.name for v in p_obj.variables], + ["true_positives", "false_positives"], + ) + self.assertEqual(p_obj.thresholds, [0.4, 0.9]) + self.assertEqual(p_obj.top_k, 15) + self.assertEqual(p_obj.class_id, 12) + + # Check save and restore config + p_obj2 = metrics.Precision.from_config(p_obj.get_config()) + self.assertEqual(p_obj2.name, "my_precision") + self.assertLen(p_obj2.variables, 2) + self.assertEqual(p_obj2.thresholds, [0.4, 0.9]) + self.assertEqual(p_obj2.top_k, 15) + self.assertEqual(p_obj2.class_id, 12) + + def test_unweighted(self): + p_obj = metrics.Precision() + y_pred = np.array([1, 0, 1, 0]) + y_true = np.array([0, 1, 1, 0]) + result = p_obj(y_true, y_pred) + self.assertAlmostEqual(0.5, result) + + def test_unweighted_all_incorrect(self): + p_obj = metrics.Precision(thresholds=[0.5]) + inputs = np.random.randint(0, 2, size=(100, 1)) + y_pred = np.array(inputs) + y_true = np.array(1 - inputs) + result = p_obj(y_true, y_pred) + self.assertAlmostEqual(0, result) + + def test_weighted(self): + p_obj = metrics.Precision() + y_pred = np.array([[1, 0, 1, 0], [1, 0, 1, 0]]) + y_true = np.array([[0, 1, 1, 0], [1, 0, 0, 1]]) + result = p_obj( + y_true, + y_pred, + sample_weight=np.array([[1, 2, 3, 4], [4, 3, 2, 1]]), + ) + weighted_tp = 3.0 + 4.0 + weighted_positives = (1.0 + 3.0) + (4.0 + 2.0) + expected_precision = weighted_tp / weighted_positives + self.assertAlmostEqual(expected_precision, result) + + def test_div_by_zero(self): + p_obj = metrics.Precision() + y_pred = np.array([0, 0, 0, 0]) + y_true = np.array([0, 0, 0, 0]) + result = p_obj(y_true, y_pred) + self.assertEqual(0, result) + + def test_unweighted_with_threshold(self): + p_obj = metrics.Precision(thresholds=[0.5, 0.7]) + y_pred = np.array([1, 0, 0.6, 0]) + y_true = np.array([0, 1, 1, 0]) + result = p_obj(y_true, y_pred) + self.assertAlmostEqual([0.5, 0.0], result, 0) + + def test_weighted_with_threshold(self): + p_obj = metrics.Precision(thresholds=[0.5, 1.0]) + y_true = np.array([[0, 1], [1, 0]]) + y_pred = np.array([[1, 0], [0.6, 0]], dtype="float32") + weights = np.array([[4, 0], [3, 1]], dtype="float32") + result = p_obj(y_true, y_pred, sample_weight=weights) + weighted_tp = 0 + 3.0 + weighted_positives = (0 + 3.0) + (4.0 + 0.0) + expected_precision = weighted_tp / weighted_positives + self.assertAlmostEqual([expected_precision, 0], result, 1e-3) + + def test_multiple_updates(self): + p_obj = metrics.Precision(thresholds=[0.5, 1.0]) + y_true = np.array([[0, 1], [1, 0]]) + y_pred = np.array([[1, 0], [0.6, 0]], dtype="float32") + weights = np.array([[4, 0], [3, 1]], dtype="float32") + for _ in range(2): + p_obj.update_state(y_true, y_pred, sample_weight=weights) + + weighted_tp = (0 + 3.0) + (0 + 3.0) + weighted_positives = ((0 + 3.0) + (4.0 + 0.0)) + ( + (0 + 3.0) + (4.0 + 0.0) + ) + expected_precision = weighted_tp / weighted_positives + self.assertAlmostEqual([expected_precision, 0], p_obj.result(), 1e-3) + + def test_unweighted_top_k(self): + p_obj = metrics.Precision(top_k=3) + y_pred = np.array([0.2, 0.1, 0.5, 0, 0.2]) + y_true = np.array([0, 1, 1, 0, 0]) + result = p_obj(y_true, y_pred) + self.assertAlmostEqual(1.0 / 3, result) + + def test_weighted_top_k(self): + p_obj = metrics.Precision(top_k=3) + y_pred1 = np.array([[0.2, 0.1, 0.4, 0, 0.2]]) + y_true1 = np.array([[0, 1, 1, 0, 1]]) + p_obj(y_true1, y_pred1, sample_weight=np.array([[1, 4, 2, 3, 5]])) + + y_pred2 = np.array([0.2, 0.6, 0.4, 0.2, 0.2]) + y_true2 = np.array([1, 0, 1, 1, 1]) + result = p_obj(y_true2, y_pred2, sample_weight=np.array(3)) + + tp = (2 + 5) + (3 + 3) + predicted_positives = (1 + 2 + 5) + (3 + 3 + 3) + expected_precision = tp / predicted_positives + self.assertAlmostEqual(expected_precision, result) + + def test_unweighted_class_id_should_throw_error_1d(self): + p_obj = metrics.Precision(class_id=2) + + y_pred = np.array([0.2, 0.1, 0.6, 0, 0.2]) + y_true = np.array([0, 1, 1, 0, 0]) + + with self.assertRaisesRegex( + ValueError, + r"When class_id is provided, y_pred must be a 2D array " + r"with shape \(num_samples, num_classes\), found shape:.*", + ): + p_obj(y_true, y_pred) + + def test_unweighted_class_id_multiclass(self): + p_obj = metrics.Precision(class_id=1) + + y_pred = np.array( + [ + [0.1, 0.2, 0.7], + [0.5, 0.3, 0.2], + [0.2, 0.6, 0.2], + [0.7, 0.2, 0.1], + [0.1, 0.1, 0.8], + ] + ) + + y_true = np.array( + [ + [0.0, 0.0, 1.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 0.0, 1.0], + ] + ) + + result = p_obj(y_true, y_pred) + self.assertAlmostEqual(1.0, result) + self.assertAlmostEqual(1.0, p_obj.true_positives) + self.assertAlmostEqual(0.0, p_obj.false_positives) + + def test_unweighted_top_k_and_threshold(self): + p_obj = metrics.Precision(thresholds=0.7, top_k=2) + + y_pred = np.array([0.2, 0.8, 0.6, 0, 0.2]) + y_true = np.array([0, 1, 1, 0, 1]) + result = p_obj(y_true, y_pred) + self.assertAlmostEqual(1, result) + self.assertAlmostEqual(1, p_obj.true_positives) + self.assertAlmostEqual(0, p_obj.false_positives) + + +class RecallTest(testing.TestCase): + def test_config(self): + r_obj = metrics.Recall( + name="my_recall", thresholds=[0.4, 0.9], top_k=15, class_id=12 + ) + self.assertEqual(r_obj.name, "my_recall") + self.assertLen(r_obj.variables, 2) + self.assertEqual( + [v.name for v in r_obj.variables], + ["true_positives", "false_negatives"], + ) + self.assertEqual(r_obj.thresholds, [0.4, 0.9]) + self.assertEqual(r_obj.top_k, 15) + self.assertEqual(r_obj.class_id, 12) + + # Check save and restore config + r_obj2 = metrics.Recall.from_config(r_obj.get_config()) + self.assertEqual(r_obj2.name, "my_recall") + self.assertLen(r_obj2.variables, 2) + self.assertEqual(r_obj2.thresholds, [0.4, 0.9]) + self.assertEqual(r_obj2.top_k, 15) + self.assertEqual(r_obj2.class_id, 12) + + def test_unweighted(self): + r_obj = metrics.Recall() + y_pred = np.array([1, 0, 1, 0]) + y_true = np.array([0, 1, 1, 0]) + self.assertAlmostEqual(0.5, r_obj(y_true, y_pred)) + + def test_unweighted_all_incorrect(self): + r_obj = metrics.Recall(thresholds=[0.5]) + inputs = np.random.randint(0, 2, size=(100, 1)) + y_pred = np.array(inputs) + y_true = np.array(1 - inputs) + self.assertAlmostEqual(0, r_obj(y_true, y_pred)) + + def test_weighted(self): + r_obj = metrics.Recall() + y_pred = np.array([[1, 0, 1, 0], [0, 1, 0, 1]]) + y_true = np.array([[0, 1, 1, 0], [1, 0, 0, 1]]) + result = r_obj( + y_true, + y_pred, + sample_weight=np.array([[1, 2, 3, 4], [4, 3, 2, 1]]), + ) + weighted_tp = 3.0 + 1.0 + weighted_t = (2.0 + 3.0) + (4.0 + 1.0) + expected_recall = weighted_tp / weighted_t + self.assertAlmostEqual(expected_recall, result) + + def test_div_by_zero(self): + r_obj = metrics.Recall() + y_pred = np.array([0, 0, 0, 0]) + y_true = np.array([0, 0, 0, 0]) + self.assertEqual(0, r_obj(y_true, y_pred)) + + def test_unweighted_with_threshold(self): + r_obj = metrics.Recall(thresholds=[0.5, 0.7]) + y_pred = np.array([1, 0, 0.6, 0]) + y_true = np.array([0, 1, 1, 0]) + self.assertAllClose([0.5, 0.0], r_obj(y_true, y_pred), 0) + + def test_weighted_with_threshold(self): + r_obj = metrics.Recall(thresholds=[0.5, 1.0]) + y_true = np.array([[0, 1], [1, 0]]) + y_pred = np.array([[1, 0], [0.6, 0]], dtype="float32") + weights = np.array([[1, 4], [3, 2]], dtype="float32") + result = r_obj(y_true, y_pred, sample_weight=weights) + weighted_tp = 0 + 3.0 + weighted_positives = (0 + 3.0) + (4.0 + 0.0) + expected_recall = weighted_tp / weighted_positives + self.assertAllClose([expected_recall, 0], result, 1e-3) + + def test_multiple_updates(self): + r_obj = metrics.Recall(thresholds=[0.5, 1.0]) + y_true = np.array([[0, 1], [1, 0]]) + y_pred = np.array([[1, 0], [0.6, 0]], dtype="float32") + weights = np.array([[1, 4], [3, 2]], dtype="float32") + for _ in range(2): + r_obj.update_state(y_true, y_pred, sample_weight=weights) + + weighted_tp = (0 + 3.0) + (0 + 3.0) + weighted_positives = ((0 + 3.0) + (4.0 + 0.0)) + ( + (0 + 3.0) + (4.0 + 0.0) + ) + expected_recall = weighted_tp / weighted_positives + self.assertAllClose([expected_recall, 0], r_obj.result(), 1e-3) + + def test_unweighted_top_k(self): + r_obj = metrics.Recall(top_k=3) + y_pred = np.array([0.2, 0.1, 0.5, 0, 0.2]) + y_true = np.array([0, 1, 1, 0, 0]) + self.assertAlmostEqual(0.5, r_obj(y_true, y_pred)) + + def test_weighted_top_k(self): + r_obj = metrics.Recall(top_k=3) + y_pred1 = np.array([[0.2, 0.1, 0.4, 0, 0.2]]) + y_true1 = np.array([[0, 1, 1, 0, 1]]) + r_obj(y_true1, y_pred1, sample_weight=np.array([[1, 4, 2, 3, 5]])) + + y_pred2 = np.array([0.2, 0.6, 0.4, 0.2, 0.2]) + y_true2 = np.array([1, 0, 1, 1, 1]) + result = r_obj(y_true2, y_pred2, sample_weight=np.array(3)) + + tp = (2 + 5) + (3 + 3) + positives = (4 + 2 + 5) + (3 + 3 + 3 + 3) + expected_recall = tp / positives + self.assertAlmostEqual(expected_recall, result) + + def test_unweighted_class_id_should_throw_error_1d(self): + r_obj = metrics.Recall(class_id=2) + + y_pred = np.array([0.2, 0.1, 0.6, 0, 0.2]) + y_true = np.array([0, 1, 1, 0, 0]) + + with self.assertRaisesRegex( + ValueError, + r"When class_id is provided, y_pred must be a 2D array " + r"with shape \(num_samples, num_classes\), found shape:.*", + ): + r_obj(y_true, y_pred) + + def test_unweighted_class_id_multiclass(self): + r_obj = metrics.Recall(class_id=1) + + y_pred = np.array( + [ + [0.1, 0.2, 0.7], + [0.5, 0.3, 0.2], + [0.2, 0.6, 0.2], + [0.7, 0.2, 0.1], + [0.1, 0.1, 0.8], + ] + ) + + y_true = np.array( + [ + [0.0, 0.0, 1.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 0.0, 1.0], + ] + ) + + result = r_obj(y_true, y_pred) + self.assertAlmostEqual(1.0, result) + self.assertAlmostEqual(1.0, r_obj.true_positives) + self.assertAlmostEqual(0.0, r_obj.false_negatives) + + def test_unweighted_top_k_and_threshold(self): + r_obj = metrics.Recall(thresholds=0.7, top_k=2) + + y_pred = np.array([0.2, 0.8, 0.6, 0, 0.2]) + y_true = np.array([1, 1, 1, 0, 1]) + self.assertAlmostEqual(0.25, r_obj(y_true, y_pred)) + self.assertAlmostEqual(1, r_obj.true_positives) + self.assertAlmostEqual(3, r_obj.false_negatives) + + +class SensitivityAtSpecificityTest(testing.TestCase): + def test_config(self): + s_obj = metrics.SensitivityAtSpecificity( + 0.4, + num_thresholds=100, + class_id=12, + name="sensitivity_at_specificity_1", + ) + self.assertEqual(s_obj.name, "sensitivity_at_specificity_1") + self.assertLen(s_obj.variables, 4) + self.assertEqual(s_obj.specificity, 0.4) + self.assertEqual(s_obj.num_thresholds, 100) + self.assertEqual(s_obj.class_id, 12) + + # Check save and restore config + s_obj2 = metrics.SensitivityAtSpecificity.from_config( + s_obj.get_config() + ) + self.assertEqual(s_obj2.name, "sensitivity_at_specificity_1") + self.assertLen(s_obj2.variables, 4) + self.assertEqual(s_obj2.specificity, 0.4) + self.assertEqual(s_obj2.num_thresholds, 100) + self.assertEqual(s_obj.class_id, 12) + + def test_unweighted_all_correct(self): + s_obj = metrics.SensitivityAtSpecificity(0.7) + inputs = np.random.randint(0, 2, size=(100, 1)) + y_pred = np.array(inputs, dtype="float32") + y_true = np.array(inputs) + self.assertAlmostEqual(1, s_obj(y_true, y_pred)) + + def test_unweighted_high_specificity(self): + s_obj = metrics.SensitivityAtSpecificity(0.8) + pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.1, 0.45, 0.5, 0.8, 0.9] + label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1] + + y_pred = np.array(pred_values, dtype="float32") + y_true = np.array(label_values) + + self.assertAlmostEqual(0.8, s_obj(y_true, y_pred)) + + def test_unweighted_low_specificity(self): + s_obj = metrics.SensitivityAtSpecificity(0.4) + pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26] + label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1] + + y_pred = np.array(pred_values, dtype="float32") + y_true = np.array(label_values) + + self.assertAlmostEqual(0.6, s_obj(y_true, y_pred)) + + def test_unweighted_class_id(self): + s_obj = metrics.SpecificityAtSensitivity(0.4, class_id=2) + pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26] + label_values = [0, 0, 0, 0, 0, 2, 2, 2, 2, 2] + + y_pred = ops.transpose(np.array([pred_values] * 3)) + y_true = ops.one_hot(np.array(label_values), num_classes=3) + + self.assertAlmostEqual(0.6, s_obj(y_true, y_pred)) + + @parameterized.parameters(["bool", "int32", "float32"]) + def test_weighted(self, label_dtype): + s_obj = metrics.SensitivityAtSpecificity(0.4) + pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26] + label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1] + weight_values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + + y_pred = np.array(pred_values, dtype="float32") + y_true = ops.cast(label_values, dtype=label_dtype) + weights = np.array(weight_values) + + result = s_obj(y_true, y_pred, sample_weight=weights) + self.assertAlmostEqual(0.675, result) + + def test_invalid_specificity(self): + with self.assertRaisesRegex( + ValueError, r"`specificity` must be in the range \[0, 1\]." + ): + metrics.SensitivityAtSpecificity(-1) + + def test_invalid_num_thresholds(self): + with self.assertRaisesRegex( + ValueError, "Argument `num_thresholds` must be an integer > 0" + ): + metrics.SensitivityAtSpecificity(0.4, num_thresholds=-1) + + +class SpecificityAtSensitivityTest(testing.TestCase): + def test_config(self): + s_obj = metrics.SpecificityAtSensitivity( + 0.4, + num_thresholds=100, + class_id=12, + name="specificity_at_sensitivity_1", + ) + self.assertEqual(s_obj.name, "specificity_at_sensitivity_1") + self.assertLen(s_obj.variables, 4) + self.assertEqual(s_obj.sensitivity, 0.4) + self.assertEqual(s_obj.num_thresholds, 100) + self.assertEqual(s_obj.class_id, 12) + + # Check save and restore config + s_obj2 = metrics.SpecificityAtSensitivity.from_config( + s_obj.get_config() + ) + self.assertEqual(s_obj2.name, "specificity_at_sensitivity_1") + self.assertLen(s_obj2.variables, 4) + self.assertEqual(s_obj2.sensitivity, 0.4) + self.assertEqual(s_obj2.num_thresholds, 100) + self.assertEqual(s_obj.class_id, 12) + + def test_unweighted_all_correct(self): + s_obj = metrics.SpecificityAtSensitivity(0.7) + inputs = np.random.randint(0, 2, size=(100, 1)) + y_pred = np.array(inputs, dtype="float32") + y_true = np.array(inputs) + + self.assertAlmostEqual(1, s_obj(y_true, y_pred)) + + def test_unweighted_high_sensitivity(self): + s_obj = metrics.SpecificityAtSensitivity(1.0) + pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26] + label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1] + + y_pred = np.array(pred_values, dtype="float32") + y_true = np.array(label_values) + + self.assertAlmostEqual(0.2, s_obj(y_true, y_pred)) + + def test_unweighted_low_sensitivity(self): + s_obj = metrics.SpecificityAtSensitivity(0.4) + pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26] + label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1] + + y_pred = np.array(pred_values, dtype="float32") + y_true = np.array(label_values) + + self.assertAlmostEqual(0.6, s_obj(y_true, y_pred)) + + def test_unweighted_class_id(self): + s_obj = metrics.SpecificityAtSensitivity(0.4, class_id=2) + pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26] + label_values = [0, 0, 0, 0, 0, 2, 2, 2, 2, 2] + + y_pred = ops.transpose(np.array([pred_values] * 3)) + y_true = ops.one_hot(np.array(label_values), num_classes=3) + + self.assertAlmostEqual(0.6, s_obj(y_true, y_pred)) + + @parameterized.parameters(["bool", "int32", "float32"]) + def test_weighted(self, label_dtype): + s_obj = metrics.SpecificityAtSensitivity(0.4) + pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26] + label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1] + weight_values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + + y_pred = np.array(pred_values, dtype="float32") + y_true = ops.cast(label_values, dtype=label_dtype) + weights = np.array(weight_values) + + result = s_obj(y_true, y_pred, sample_weight=weights) + self.assertAlmostEqual(0.4, result) + + def test_invalid_sensitivity(self): + with self.assertRaisesRegex( + ValueError, r"`sensitivity` must be in the range \[0, 1\]." + ): + metrics.SpecificityAtSensitivity(-1) + + def test_invalid_num_thresholds(self): + with self.assertRaisesRegex( + ValueError, "Argument `num_thresholds` must be an integer > 0" + ): + metrics.SpecificityAtSensitivity(0.4, num_thresholds=-1) + + +class PrecisionAtRecallTest(testing.TestCase): + def test_config(self): + s_obj = metrics.PrecisionAtRecall( + 0.4, num_thresholds=100, class_id=12, name="precision_at_recall_1" + ) + self.assertEqual(s_obj.name, "precision_at_recall_1") + self.assertLen(s_obj.variables, 4) + self.assertEqual(s_obj.recall, 0.4) + self.assertEqual(s_obj.num_thresholds, 100) + self.assertEqual(s_obj.class_id, 12) + + # Check save and restore config + s_obj2 = metrics.PrecisionAtRecall.from_config(s_obj.get_config()) + self.assertEqual(s_obj2.name, "precision_at_recall_1") + self.assertLen(s_obj2.variables, 4) + self.assertEqual(s_obj2.recall, 0.4) + self.assertEqual(s_obj2.num_thresholds, 100) + self.assertEqual(s_obj.class_id, 12) + + def test_unweighted_all_correct(self): + s_obj = metrics.PrecisionAtRecall(0.7) + inputs = np.random.randint(0, 2, size=(100, 1)) + y_pred = np.array(inputs, dtype="float32") + y_true = np.array(inputs) + + self.assertAlmostEqual(1, s_obj(y_true, y_pred)) + + def test_unweighted_high_recall(self): + s_obj = metrics.PrecisionAtRecall(0.8) + pred_values = [0.0, 0.1, 0.2, 0.5, 0.6, 0.2, 0.5, 0.6, 0.8, 0.9] + label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1] + + y_pred = np.array(pred_values, dtype="float32") + y_true = np.array(label_values) + + # For 0.5 < decision threshold < 0.6. + self.assertAlmostEqual(2.0 / 3, s_obj(y_true, y_pred)) + + def test_unweighted_low_recall(self): + s_obj = metrics.PrecisionAtRecall(0.6) + pred_values = [0.0, 0.1, 0.2, 0.5, 0.6, 0.2, 0.5, 0.6, 0.8, 0.9] + label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1] + + y_pred = np.array(pred_values, dtype="float32") + y_true = np.array(label_values) + + # For 0.2 < decision threshold < 0.5. + self.assertAlmostEqual(0.75, s_obj(y_true, y_pred)) + + def test_unweighted_class_id(self): + s_obj = metrics.PrecisionAtRecall(0.6, class_id=2) + pred_values = [0.0, 0.1, 0.2, 0.5, 0.6, 0.2, 0.5, 0.6, 0.8, 0.9] + label_values = [0, 0, 0, 0, 0, 2, 2, 2, 2, 2] + + y_pred = ops.transpose(np.array([pred_values] * 3)) + y_true = ops.one_hot(np.array(label_values), num_classes=3) + + # For 0.2 < decision threshold < 0.5. + self.assertAlmostEqual(0.75, s_obj(y_true, y_pred)) + + @parameterized.parameters(["bool", "int32", "float32"]) + def test_weighted(self, label_dtype): + s_obj = metrics.PrecisionAtRecall(7.0 / 8) + pred_values = [0.0, 0.1, 0.2, 0.5, 0.6, 0.2, 0.5, 0.6, 0.8, 0.9] + label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1] + weight_values = [2, 1, 2, 1, 2, 1, 2, 2, 1, 2] + + y_pred = np.array(pred_values, dtype="float32") + y_true = ops.cast(label_values, dtype=label_dtype) + weights = np.array(weight_values) + + result = s_obj(y_true, y_pred, sample_weight=weights) + # For 0.0 < decision threshold < 0.2. + self.assertAlmostEqual(0.7, result) + + def test_invalid_sensitivity(self): + with self.assertRaisesRegex( + ValueError, r"`recall` must be in the range \[0, 1\]." + ): + metrics.PrecisionAtRecall(-1) + + def test_invalid_num_thresholds(self): + with self.assertRaisesRegex( + ValueError, "Argument `num_thresholds` must be an integer > 0" + ): + metrics.PrecisionAtRecall(0.4, num_thresholds=-1) + + +class RecallAtPrecisionTest(testing.TestCase): + def test_config(self): + s_obj = metrics.RecallAtPrecision( + 0.4, num_thresholds=100, class_id=12, name="recall_at_precision_1" + ) + self.assertEqual(s_obj.name, "recall_at_precision_1") + self.assertLen(s_obj.variables, 4) + self.assertEqual(s_obj.precision, 0.4) + self.assertEqual(s_obj.num_thresholds, 100) + self.assertEqual(s_obj.class_id, 12) + + # Check save and restore config + s_obj2 = metrics.RecallAtPrecision.from_config(s_obj.get_config()) + self.assertEqual(s_obj2.name, "recall_at_precision_1") + self.assertLen(s_obj2.variables, 4) + self.assertEqual(s_obj2.precision, 0.4) + self.assertEqual(s_obj2.num_thresholds, 100) + self.assertEqual(s_obj.class_id, 12) + + def test_unweighted_all_correct(self): + s_obj = metrics.RecallAtPrecision(0.7) + inputs = np.random.randint(0, 2, size=(100, 1)) + y_pred = np.array(inputs, dtype="float32") + y_true = np.array(inputs) + + self.assertAlmostEqual(1, s_obj(y_true, y_pred)) + + def test_unweighted_high_precision(self): + s_obj = metrics.RecallAtPrecision(0.75) + pred_values = [ + 0.05, + 0.1, + 0.2, + 0.3, + 0.3, + 0.35, + 0.4, + 0.45, + 0.5, + 0.6, + 0.9, + 0.95, + ] + label_values = [0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1] + # precisions: [1/2, 6/11, 1/2, 5/9, 5/8, 5/7, 2/3, 3/5, 3/5, 2/3, 1/2, + # 1]. + # recalls: [1, 1, 5/6, 5/6, 5/6, 5/6, 2/3, 1/2, 1/2, 1/3, 1/6, + # 1/6]. + y_pred = np.array(pred_values, dtype="float32") + y_true = np.array(label_values) + + # The precision 0.75 can be reached at thresholds 0.4<=t<0.45. + self.assertAlmostEqual(0.5, s_obj(y_true, y_pred)) + + def test_unweighted_low_precision(self): + s_obj = metrics.RecallAtPrecision(2.0 / 3) + pred_values = [ + 0.05, + 0.1, + 0.2, + 0.3, + 0.3, + 0.35, + 0.4, + 0.45, + 0.5, + 0.6, + 0.9, + 0.95, + ] + label_values = [0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1] + # precisions: [1/2, 6/11, 1/2, 5/9, 5/8, 5/7, 2/3, 3/5, 3/5, 2/3, 1/2, + # 1]. + # recalls: [1, 1, 5/6, 5/6, 5/6, 5/6, 2/3, 1/2, 1/2, 1/3, 1/6, + # 1/6]. + y_pred = np.array(pred_values, dtype="float32") + y_true = np.array(label_values) + + # The precision 5/7 can be reached at thresholds 00.3<=t<0.35. + self.assertAlmostEqual(5.0 / 6, s_obj(y_true, y_pred)) + + def test_unweighted_class_id(self): + s_obj = metrics.RecallAtPrecision(2.0 / 3, class_id=2) + pred_values = [ + 0.05, + 0.1, + 0.2, + 0.3, + 0.3, + 0.35, + 0.4, + 0.45, + 0.5, + 0.6, + 0.9, + 0.95, + ] + label_values = [0, 2, 0, 0, 0, 2, 2, 0, 2, 2, 0, 2] + # precisions: [1/2, 6/11, 1/2, 5/9, 5/8, 5/7, 2/3, 3/5, 3/5, 2/3, 1/2, + # 1]. + # recalls: [1, 1, 5/6, 5/6, 5/6, 5/6, 2/3, 1/2, 1/2, 1/3, 1/6, + # 1/6]. + y_pred = ops.transpose(np.array([pred_values] * 3)) + y_true = ops.one_hot(np.array(label_values), num_classes=3) + + # The precision 5/7 can be reached at thresholds 00.3<=t<0.35. + self.assertAlmostEqual(5.0 / 6, s_obj(y_true, y_pred)) + + @parameterized.parameters(["bool", "int32", "float32"]) + def test_weighted(self, label_dtype): + s_obj = metrics.RecallAtPrecision(0.75) + pred_values = [0.1, 0.2, 0.3, 0.5, 0.6, 0.9, 0.9] + label_values = [0, 1, 0, 0, 0, 1, 1] + weight_values = [1, 2, 1, 2, 1, 2, 1] + y_pred = np.array(pred_values, dtype="float32") + y_true = ops.cast(label_values, dtype=label_dtype) + weights = np.array(weight_values) + + result = s_obj(y_true, y_pred, sample_weight=weights) + self.assertAlmostEqual(0.6, result) + + def test_unachievable_precision(self): + s_obj = metrics.RecallAtPrecision(2.0 / 3) + pred_values = [0.1, 0.2, 0.3, 0.9] + label_values = [1, 1, 0, 0] + y_pred = np.array(pred_values, dtype="float32") + y_true = np.array(label_values) + + # The highest possible precision is 1/2 which is below the required + # value, expect 0 recall. + self.assertAlmostEqual(0, s_obj(y_true, y_pred)) + + def test_invalid_sensitivity(self): + with self.assertRaisesRegex( + ValueError, r"`precision` must be in the range \[0, 1\]." + ): + metrics.RecallAtPrecision(-1) + + def test_invalid_num_thresholds(self): + with self.assertRaisesRegex( + ValueError, "Argument `num_thresholds` must be an integer > 0" + ): + metrics.RecallAtPrecision(0.4, num_thresholds=-1) + + @pytest.mark.requires_trainable_backend + def test_end_to_end(self): + # Test for https://github.com/keras-team/keras/issues/718 + model = models.Sequential( + [ + layers.Input((1,)), + layers.Dense(1), + ] + ) + model.compile( + optimizer="rmsprop", loss="mse", metrics=[metrics.Precision()] + ) + model.fit(np.ones((5, 1)), np.ones((5, 1))) + + +class AUCTest(testing.TestCase): + def setUp(self): + self.num_thresholds = 3 + self.y_pred = np.array([0, 0.5, 0.3, 0.9], dtype="float32") + self.y_pred_multi_label = np.array( + [[0.0, 0.4], [0.5, 0.7], [0.3, 0.2], [0.9, 0.3]], dtype="float32" + ) + epsilon = 1e-12 + self.y_pred_logits = -ops.log(1.0 / (self.y_pred + epsilon) - 1.0) + self.y_true = np.array([0, 0, 1, 1]) + self.y_true_multi_label = np.array([[0, 0], [1, 1], [1, 1], [1, 0]]) + self.sample_weight = [1, 2, 3, 4] + + # threshold values are [0 - 1e-7, 0.5, 1 + 1e-7] + # y_pred when threshold = 0 - 1e-7 : [1, 1, 1, 1] + # y_pred when threshold = 0.5 : [0, 0, 0, 1] + # y_pred when threshold = 1 + 1e-7 : [0, 0, 0, 0] + + # without sample_weight: + # tp = np.sum([[0, 0, 1, 1], [0, 0, 0, 1], [0, 0, 0, 0]], axis=1) + # fp = np.sum([[1, 1, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], axis=1) + # fn = np.sum([[0, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 1]], axis=1) + # tn = np.sum([[0, 0, 0, 0], [1, 1, 0, 0], [1, 1, 0, 0]], axis=1) + + # tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2] + + # with sample_weight: + # tp = np.sum([[0, 0, 3, 4], [0, 0, 0, 4], [0, 0, 0, 0]], axis=1) + # fp = np.sum([[1, 2, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], axis=1) + # fn = np.sum([[0, 0, 0, 0], [0, 0, 3, 0], [0, 0, 3, 4]], axis=1) + # tn = np.sum([[0, 0, 0, 0], [1, 2, 0, 0], [1, 2, 0, 0]], axis=1) + + # tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3] + + def test_config(self): + auc_obj = metrics.AUC( + num_thresholds=100, + curve="PR", + summation_method="majoring", + name="auc_1", + dtype="float64", + multi_label=True, + num_labels=2, + from_logits=True, + ) + auc_obj.update_state(self.y_true_multi_label, self.y_pred_multi_label) + self.assertEqual(auc_obj.name, "auc_1") + self.assertEqual(auc_obj._dtype, "float64") + self.assertLen(auc_obj.variables, 4) + self.assertEqual(auc_obj.num_thresholds, 100) + self.assertEqual(auc_obj.curve, metrics_utils.AUCCurve.PR) + self.assertEqual( + auc_obj.summation_method, metrics_utils.AUCSummationMethod.MAJORING + ) + self.assertTrue(auc_obj.multi_label) + self.assertEqual(auc_obj.num_labels, 2) + self.assertTrue(auc_obj._from_logits) + old_config = auc_obj.get_config() + self.assertNotIn("thresholds", old_config) + self.assertDictEqual(old_config, json.loads(json.dumps(old_config))) + + # Check save and restore config. + auc_obj2 = metrics.AUC.from_config(auc_obj.get_config()) + auc_obj2.update_state(self.y_true_multi_label, self.y_pred_multi_label) + self.assertEqual(auc_obj2.name, "auc_1") + self.assertLen(auc_obj2.variables, 4) + self.assertEqual(auc_obj2.num_thresholds, 100) + self.assertEqual(auc_obj2.curve, metrics_utils.AUCCurve.PR) + self.assertEqual( + auc_obj2.summation_method, metrics_utils.AUCSummationMethod.MAJORING + ) + self.assertTrue(auc_obj2.multi_label) + self.assertEqual(auc_obj2.num_labels, 2) + self.assertTrue(auc_obj2._from_logits) + new_config = auc_obj2.get_config() + self.assertNotIn("thresholds", new_config) + self.assertDictEqual(old_config, new_config) + self.assertAllClose(auc_obj.thresholds, auc_obj2.thresholds) + + def test_config_manual_thresholds(self): + auc_obj = metrics.AUC( + num_thresholds=None, + curve="PR", + summation_method="majoring", + name="auc_1", + thresholds=[0.3, 0.5], + ) + auc_obj.update_state(self.y_true, self.y_pred) + self.assertEqual(auc_obj.name, "auc_1") + self.assertLen(auc_obj.variables, 4) + self.assertEqual(auc_obj.num_thresholds, 4) + self.assertAllClose(auc_obj.thresholds, [0.0, 0.3, 0.5, 1.0]) + self.assertEqual(auc_obj.curve, metrics_utils.AUCCurve.PR) + self.assertEqual( + auc_obj.summation_method, metrics_utils.AUCSummationMethod.MAJORING + ) + old_config = auc_obj.get_config() + self.assertDictEqual(old_config, json.loads(json.dumps(old_config))) + + # Check save and restore config. + auc_obj2 = metrics.AUC.from_config(auc_obj.get_config()) + auc_obj2.update_state(self.y_true, self.y_pred) + self.assertEqual(auc_obj2.name, "auc_1") + self.assertLen(auc_obj2.variables, 4) + self.assertEqual(auc_obj2.num_thresholds, 4) + self.assertEqual(auc_obj2.curve, metrics_utils.AUCCurve.PR) + self.assertEqual( + auc_obj2.summation_method, metrics_utils.AUCSummationMethod.MAJORING + ) + new_config = auc_obj2.get_config() + self.assertDictEqual(old_config, new_config) + self.assertAllClose(auc_obj.thresholds, auc_obj2.thresholds) + + def test_unweighted_all_correct(self): + auc_obj = metrics.AUC() + self.assertEqual(auc_obj(self.y_true, self.y_true), 1) + + def test_unweighted(self): + auc_obj = metrics.AUC(num_thresholds=self.num_thresholds) + result = auc_obj(self.y_true, self.y_pred) + + # tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2] + # recall = [2/2, 1/(1+1), 0] = [1, 0.5, 0] + # fp_rate = [2/2, 0, 0] = [1, 0, 0] + # heights = [(1 + 0.5)/2, (0.5 + 0)/2] = [0.75, 0.25] + # widths = [(1 - 0), (0 - 0)] = [1, 0] + expected_result = 0.75 * 1 + 0.25 * 0 + self.assertAllClose(result, expected_result, 1e-3) + + def test_unweighted_from_logits(self): + auc_obj = metrics.AUC( + num_thresholds=self.num_thresholds, from_logits=True + ) + result = auc_obj(self.y_true, self.y_pred_logits) + + # tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2] + # recall = [2/2, 1/(1+1), 0] = [1, 0.5, 0] + # fp_rate = [2/2, 0, 0] = [1, 0, 0] + # heights = [(1 + 0.5)/2, (0.5 + 0)/2] = [0.75, 0.25] + # widths = [(1 - 0), (0 - 0)] = [1, 0] + expected_result = 0.75 * 1 + 0.25 * 0 + self.assertAllClose(result, expected_result, 1e-3) + + def test_manual_thresholds(self): + # Verify that when specified, thresholds are used instead of + # num_thresholds. + auc_obj = metrics.AUC(num_thresholds=2, thresholds=[0.5]) + self.assertEqual(auc_obj.num_thresholds, 3) + self.assertAllClose(auc_obj.thresholds, [0.0, 0.5, 1.0]) + result = auc_obj(self.y_true, self.y_pred) + + # tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2] + # recall = [2/2, 1/(1+1), 0] = [1, 0.5, 0] + # fp_rate = [2/2, 0, 0] = [1, 0, 0] + # heights = [(1 + 0.5)/2, (0.5 + 0)/2] = [0.75, 0.25] + # widths = [(1 - 0), (0 - 0)] = [1, 0] + expected_result = 0.75 * 1 + 0.25 * 0 + self.assertAllClose(result, expected_result, 1e-3) + + def test_weighted_roc_interpolation(self): + auc_obj = metrics.AUC(num_thresholds=self.num_thresholds) + result = auc_obj( + self.y_true, self.y_pred, sample_weight=self.sample_weight + ) + + # tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3] + # recall = [7/7, 4/(4+3), 0] = [1, 0.571, 0] + # fp_rate = [3/3, 0, 0] = [1, 0, 0] + # heights = [(1 + 0.571)/2, (0.571 + 0)/2] = [0.7855, 0.2855] + # widths = [(1 - 0), (0 - 0)] = [1, 0] + expected_result = 0.7855 * 1 + 0.2855 * 0 + self.assertAllClose(result, expected_result, 1e-3) + + def test_weighted_roc_majoring(self): + auc_obj = metrics.AUC( + num_thresholds=self.num_thresholds, summation_method="majoring" + ) + result = auc_obj( + self.y_true, self.y_pred, sample_weight=self.sample_weight + ) + + # tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3] + # recall = [7/7, 4/(4+3), 0] = [1, 0.571, 0] + # fp_rate = [3/3, 0, 0] = [1, 0, 0] + # heights = [max(1, 0.571), max(0.571, 0)] = [1, 0.571] + # widths = [(1 - 0), (0 - 0)] = [1, 0] + expected_result = 1 * 1 + 0.571 * 0 + self.assertAllClose(result, expected_result, 1e-3) + + def test_weighted_roc_minoring(self): + auc_obj = metrics.AUC( + num_thresholds=self.num_thresholds, summation_method="minoring" + ) + result = auc_obj( + self.y_true, self.y_pred, sample_weight=self.sample_weight + ) + + # tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3] + # recall = [7/7, 4/(4+3), 0] = [1, 0.571, 0] + # fp_rate = [3/3, 0, 0] = [1, 0, 0] + # heights = [min(1, 0.571), min(0.571, 0)] = [0.571, 0] + # widths = [(1 - 0), (0 - 0)] = [1, 0] + expected_result = 0.571 * 1 + 0 * 0 + self.assertAllClose(result, expected_result, 1e-3) + + def test_weighted_pr_majoring(self): + auc_obj = metrics.AUC( + num_thresholds=self.num_thresholds, + curve="PR", + summation_method="majoring", + ) + result = auc_obj( + self.y_true, self.y_pred, sample_weight=self.sample_weight + ) + + # tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3] + # precision = [7/(7+3), 4/4, 0] = [0.7, 1, 0] + # recall = [7/7, 4/(4+3), 0] = [1, 0.571, 0] + # heights = [max(0.7, 1), max(1, 0)] = [1, 1] + # widths = [(1 - 0.571), (0.571 - 0)] = [0.429, 0.571] + expected_result = 1 * 0.429 + 1 * 0.571 + self.assertAllClose(result, expected_result, 1e-3) + + def test_weighted_pr_minoring(self): + auc_obj = metrics.AUC( + num_thresholds=self.num_thresholds, + curve="PR", + summation_method="minoring", + ) + result = auc_obj( + self.y_true, self.y_pred, sample_weight=self.sample_weight + ) + + # tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3] + # precision = [7/(7+3), 4/4, 0] = [0.7, 1, 0] + # recall = [7/7, 4/(4+3), 0] = [1, 0.571, 0] + # heights = [min(0.7, 1), min(1, 0)] = [0.7, 0] + # widths = [(1 - 0.571), (0.571 - 0)] = [0.429, 0.571] + expected_result = 0.7 * 0.429 + 0 * 0.571 + self.assertAllClose(result, expected_result, 1e-3) + + def test_weighted_pr_interpolation(self): + auc_obj = metrics.AUC(num_thresholds=self.num_thresholds, curve="PR") + result = auc_obj( + self.y_true, self.y_pred, sample_weight=self.sample_weight + ) + + # auc = (slope / Total Pos) * [dTP - intercept * log(Pb/Pa)] + + # tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3] + # P = tp + fp = [10, 4, 0] + # dTP = [7-4, 4-0] = [3, 4] + # dP = [10-4, 4-0] = [6, 4] + # slope = dTP/dP = [0.5, 1] + # intercept = (TPa+(slope*Pa) = [(4 - 0.5*4), (0 - 1*0)] = [2, 0] + # (Pb/Pa) = (Pb/Pa) if Pb > 0 AND Pa > 0 else 1 = [10/4, 4/0] = [2.5, 1] + # auc * TotalPos = [(0.5 * (3 + 2 * log(2.5))), (1 * (4 + 0))] + # = [2.416, 4] + # auc = [2.416, 4]/(tp[1:]+fn[1:]) + expected_result = 2.416 / 7 + 4 / 7 + self.assertAllClose(result, expected_result, 1e-3) + + def test_weighted_pr_interpolation_negative_weights(self): + auc_obj = metrics.AUC(num_thresholds=self.num_thresholds, curve="PR") + sample_weight = [-1, -2, -3, -4] + result = auc_obj(self.y_true, self.y_pred, sample_weight=sample_weight) + + # Divisor in auc formula is max(tp[1:]+fn[1:], 0), which is all zeros + # because the all values in tp and fn are negative, divide_no_nan will + # produce all zeros. + self.assertAllClose(result, 0.0, 1e-3) + + def test_weighted_prgain_majoring(self): + auc_obj = metrics.AUC( + num_thresholds=self.num_thresholds, + curve="PRGAIN", + summation_method="majoring", + ) + result = auc_obj( + self.y_true, self.y_pred, sample_weight=self.sample_weight + ) + + # tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3] + # scaling_factor (P/N) = 7/3 + # recall_gain = 1 - 7/3 [0/7, 3/4, 7/0] = [1, -3/4, -inf] -> [1, 0, 0] + # precision_gain = 1 - 7/3 [3/7, 0/4, 0/0] = [0, 1, NaN] -> [0, 1, 1] + # heights = [max(0, 1), max(1, 1)] = [1, 1] + # widths = [(1 - 0), (0 - 0)] = [1, 0] + expected_result = 1 * 1 + 0 * 1 + self.assertAllClose(result, expected_result, 1e-3) + + def test_weighted_prgain_minoring(self): + auc_obj = metrics.AUC( + num_thresholds=self.num_thresholds, + curve="PRGAIN", + summation_method="minoring", + ) + result = auc_obj( + self.y_true, self.y_pred, sample_weight=self.sample_weight + ) + + # tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3] + # scaling_factor (P/N) = 7/3 + # recall_gain = 1 - 7/3 [0/7, 3/4, 7/0] = [1, -3/4, -inf] -> [1, 0, 0] + # precision_gain = 1 - 7/3 [3/7, 0/4, 0/0] = [0, 1, NaN] -> [0, 1, 1] + # heights = [min(0, 1), min(1, 1)] = [0, 1] + # widths = [(1 - 0), (0 - 0)] = [1, 0] + expected_result = 1 * 0 + 0 * 1 + self.assertAllClose(result, expected_result, 1e-3) + + def test_weighted_prgain_interpolation(self): + auc_obj = metrics.AUC( + num_thresholds=self.num_thresholds, curve="PRGAIN" + ) + result = auc_obj( + self.y_true, self.y_pred, sample_weight=self.sample_weight + ) + + # tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3] + # scaling_factor (P/N) = 7/3 + # recall_gain = 1 - 7/3 [0/7, 3/4, 7/0] = [1, -3/4, -inf] -> [1, 0, 0] + # precision_gain = 1 - 7/3 [3/7, 0/4, 0/0] = [0, 1, NaN] -> [0, 1, 1] + # heights = [(0+1)/2, (1+1)/2] = [0.5, 1] + # widths = [(1 - 0), (0 - 0)] = [1, 0] + expected_result = 1 * 0.5 + 0 * 1 + self.assertAllClose(result, expected_result, 1e-3) + + def test_prgain_interpolation(self): + auc_obj = metrics.AUC( + num_thresholds=self.num_thresholds, curve="PRGAIN" + ) + + y_true = np.array([0, 0, 0, 1, 0, 1, 0, 1, 1, 1]) + y_pred = np.array([0.1, 0.2, 0.3, 0.3, 0.4, 0.4, 0.6, 0.6, 0.8, 0.9]) + result = auc_obj(y_true, y_pred) + + # tp = [5, 3, 0], fp = [5, 1, 0], fn = [0, 2, 5], tn = [0, 4, 4] + # scaling_factor (P/N) = 5/5 = 1 + # recall_gain = 1 - [0/5, 2/3, 5/0] = [1, 1/3, -inf] -> [1, 1/3, 0] + # precision_gain = 1 - [5/5, 1/3, 0/0] = [1, 1/3, NaN] -> [0, 2/3, 1] + # heights = [(0+2/3)/2, (2/3+1)/2] = [0.333333, 0.833333] + # widths = [(1 - 1/3), (1/3 - 0)] = [0.666666, 0.333333] + expected_result = 0.666666 * 0.333333 + 0.333333 * 0.833333 + self.assertAllClose(result, expected_result, 1e-3) + + def test_invalid_num_thresholds(self): + with self.assertRaisesRegex( + ValueError, "Argument `num_thresholds` must be an integer > 1" + ): + metrics.AUC(num_thresholds=-1) + + with self.assertRaisesRegex( + ValueError, "Argument `num_thresholds` must be an integer > 1." + ): + metrics.AUC(num_thresholds=1) + + def test_invalid_curve(self): + with self.assertRaisesRegex( + ValueError, 'Invalid AUC curve value: "Invalid".' + ): + metrics.AUC(curve="Invalid") + + def test_invalid_summation_method(self): + with self.assertRaisesRegex( + ValueError, 'Invalid AUC summation method value: "Invalid".' + ): + metrics.AUC(summation_method="Invalid") + + def test_extra_dims(self): + try: + from scipy import special + + logits = special.expit( + -np.array( + [ + [[-10.0, 10.0, -10.0], [10.0, -10.0, 10.0]], + [[-12.0, 12.0, -12.0], [12.0, -12.0, 12.0]], + ], + dtype=np.float32, + ) + ) + labels = np.array( + [[[1, 0, 0], [1, 0, 0]], [[0, 1, 1], [0, 1, 1]]], dtype=np.int64 + ) + auc_obj = metrics.AUC() + result = auc_obj(labels, logits) + self.assertEqual(result, 0.5) + except ImportError as e: + logging.warning(f"Cannot test special functions: {str(e)}") + + +class MultiAUCTest(testing.TestCase): + def setUp(self): + self.num_thresholds = 5 + self.y_pred = np.array( + [[0, 0.5, 0.3, 0.9], [0.1, 0.2, 0.3, 0.4]], dtype="float32" + ).T + + epsilon = 1e-12 + self.y_pred_logits = -ops.log(1.0 / (self.y_pred + epsilon) - 1.0) + + self.y_true_good = np.array([[0, 0, 1, 1], [0, 0, 1, 1]]).T + self.y_true_bad = np.array([[0, 0, 1, 1], [1, 1, 0, 0]]).T + self.sample_weight = [1, 2, 3, 4] + + # threshold values are [0 - 1e-7, 0.25, 0.5, 0.75, 1 + 1e-7] + # y_pred when threshold = 0 - 1e-7 : [[1, 1, 1, 1], [1, 1, 1, 1]] + # y_pred when threshold = 0.25 : [[0, 1, 1, 1], [0, 0, 1, 1]] + # y_pred when threshold = 0.5 : [[0, 0, 0, 1], [0, 0, 0, 0]] + # y_pred when threshold = 0.75 : [[0, 0, 0, 1], [0, 0, 0, 0]] + # y_pred when threshold = 1 + 1e-7 : [[0, 0, 0, 0], [0, 0, 0, 0]] + + # for y_true_good, over thresholds: + # tp = [[2, 2, 1, 1, 0], [2, 2, 0, 0, 0]] + # fp = [[2, 1, 0, 0 , 0], [2, 0, 0 ,0, 0]] + # fn = [[0, 0, 1, 1, 2], [0, 0, 2, 2, 2]] + # tn = [[0, 1, 2, 2, 2], [0, 2, 2, 2, 2]] + + # tpr = [[1, 1, 0.5, 0.5, 0], [1, 1, 0, 0, 0]] + # fpr = [[1, 0.5, 0, 0, 0], [1, 0, 0, 0, 0]] + + # for y_true_bad: + # tp = [[2, 2, 1, 1, 0], [2, 0, 0, 0, 0]] + # fp = [[2, 1, 0, 0 , 0], [2, 2, 0 ,0, 0]] + # fn = [[0, 0, 1, 1, 2], [0, 2, 2, 2, 2]] + # tn = [[0, 1, 2, 2, 2], [0, 0, 2, 2, 2]] + + # tpr = [[1, 1, 0.5, 0.5, 0], [1, 0, 0, 0, 0]] + # fpr = [[1, 0.5, 0, 0, 0], [1, 1, 0, 0, 0]] + + # for y_true_good with sample_weights: + + # tp = [[7, 7, 4, 4, 0], [7, 7, 0, 0, 0]] + # fp = [[3, 2, 0, 0, 0], [3, 0, 0, 0, 0]] + # fn = [[0, 0, 3, 3, 7], [0, 0, 7, 7, 7]] + # tn = [[0, 1, 3, 3, 3], [0, 3, 3, 3, 3]] + + # tpr = [[1, 1, 0.57, 0.57, 0], [1, 1, 0, 0, 0]] + # fpr = [[1, 0.67, 0, 0, 0], [1, 0, 0, 0, 0]] + + def test_unweighted_all_correct(self): + auc_obj = metrics.AUC(multi_label=True) + result = auc_obj(self.y_true_good, self.y_true_good) + self.assertEqual(result, 1) + + def test_unweighted_all_correct_flat(self): + auc_obj = metrics.AUC(multi_label=False) + result = auc_obj(self.y_true_good, self.y_true_good) + self.assertEqual(result, 1) + + def test_unweighted(self): + auc_obj = metrics.AUC( + num_thresholds=self.num_thresholds, multi_label=True + ) + result = auc_obj(self.y_true_good, self.y_pred) + + # tpr = [[1, 1, 0.5, 0.5, 0], [1, 1, 0, 0, 0]] + # fpr = [[1, 0.5, 0, 0, 0], [1, 0, 0, 0, 0]] + expected_result = (0.875 + 1.0) / 2.0 + self.assertAllClose(result, expected_result, 1e-3) + + def test_unweighted_from_logits(self): + auc_obj = metrics.AUC( + num_thresholds=self.num_thresholds, + multi_label=True, + from_logits=True, + ) + result = auc_obj(self.y_true_good, self.y_pred_logits) + + # tpr = [[1, 1, 0.5, 0.5, 0], [1, 1, 0, 0, 0]] + # fpr = [[1, 0.5, 0, 0, 0], [1, 0, 0, 0, 0]] + expected_result = (0.875 + 1.0) / 2.0 + self.assertAllClose(result, expected_result, 1e-3) + + def test_sample_weight_flat(self): + auc_obj = metrics.AUC( + num_thresholds=self.num_thresholds, multi_label=False + ) + result = auc_obj( + self.y_true_good, self.y_pred, sample_weight=[1, 2, 3, 4] + ) + + # tpr = [1, 1, 0.2857, 0.2857, 0] + # fpr = [1, 0.3333, 0, 0, 0] + expected_result = 1.0 - (0.3333 * (1.0 - 0.2857) / 2.0) + self.assertAllClose(result, expected_result, 1e-3) + + def test_full_sample_weight_flat(self): + auc_obj = metrics.AUC( + num_thresholds=self.num_thresholds, multi_label=False + ) + sw = np.arange(4 * 2) + sw = sw.reshape(4, 2) + result = auc_obj(self.y_true_good, self.y_pred, sample_weight=sw) + + # tpr = [1, 1, 0.2727, 0.2727, 0] + # fpr = [1, 0.3333, 0, 0, 0] + expected_result = 1.0 - (0.3333 * (1.0 - 0.2727) / 2.0) + self.assertAllClose(result, expected_result, 1e-3) + + def test_label_weights(self): + auc_obj = metrics.AUC( + num_thresholds=self.num_thresholds, + multi_label=True, + label_weights=[0.75, 0.25], + ) + result = auc_obj(self.y_true_good, self.y_pred) + + # tpr = [[1, 1, 0.5, 0.5, 0], [1, 1, 0, 0, 0]] + # fpr = [[1, 0.5, 0, 0, 0], [1, 0, 0, 0, 0]] + expected_result = (0.875 * 0.75 + 1.0 * 0.25) / (0.75 + 0.25) + self.assertAllClose(result, expected_result, 1e-3) + + def test_label_weights_flat(self): + auc_obj = metrics.AUC( + num_thresholds=self.num_thresholds, + multi_label=False, + label_weights=[0.75, 0.25], + ) + result = auc_obj(self.y_true_good, self.y_pred) + + # tpr = [1, 1, 0.375, 0.375, 0] + # fpr = [1, 0.375, 0, 0, 0] + expected_result = 1.0 - ((1.0 - 0.375) * 0.375 / 2.0) + self.assertAllClose(result, expected_result, 1e-2) + + def test_unweighted_flat(self): + auc_obj = metrics.AUC( + num_thresholds=self.num_thresholds, multi_label=False + ) + result = auc_obj(self.y_true_good, self.y_pred) + + # tp = [4, 4, 1, 1, 0] + # fp = [4, 1, 0, 0, 0] + # fn = [0, 0, 3, 3, 4] + # tn = [0, 3, 4, 4, 4] + + # tpr = [1, 1, 0.25, 0.25, 0] + # fpr = [1, 0.25, 0, 0, 0] + expected_result = 1.0 - (3.0 / 32.0) + self.assertAllClose(result, expected_result, 1e-3) + + def test_unweighted_flat_from_logits(self): + auc_obj = metrics.AUC( + num_thresholds=self.num_thresholds, + multi_label=False, + from_logits=True, + ) + result = auc_obj(self.y_true_good, self.y_pred_logits) + + # tp = [4, 4, 1, 1, 0] + # fp = [4, 1, 0, 0, 0] + # fn = [0, 0, 3, 3, 4] + # tn = [0, 3, 4, 4, 4] + + # tpr = [1, 1, 0.25, 0.25, 0] + # fpr = [1, 0.25, 0, 0, 0] + expected_result = 1.0 - (3.0 / 32.0) + self.assertAllClose(result, expected_result, 1e-3) + + def test_manual_thresholds(self): + # Verify that when specified, thresholds are used instead of + # num_thresholds. + auc_obj = metrics.AUC( + num_thresholds=2, thresholds=[0.5], multi_label=True + ) + self.assertEqual(auc_obj.num_thresholds, 3) + self.assertAllClose(auc_obj.thresholds, [0.0, 0.5, 1.0]) + result = auc_obj(self.y_true_good, self.y_pred) + + # tp = [[2, 1, 0], [2, 0, 0]] + # fp = [2, 0, 0], [2, 0, 0]] + # fn = [[0, 1, 2], [0, 2, 2]] + # tn = [[0, 2, 2], [0, 2, 2]] + + # tpr = [[1, 0.5, 0], [1, 0, 0]] + # fpr = [[1, 0, 0], [1, 0, 0]] + + # auc by slice = [0.75, 0.5] + expected_result = (0.75 + 0.5) / 2.0 + + self.assertAllClose(result, expected_result, 1e-3) + + def test_weighted_roc_interpolation(self): + auc_obj = metrics.AUC( + num_thresholds=self.num_thresholds, multi_label=True + ) + result = auc_obj( + self.y_true_good, self.y_pred, sample_weight=self.sample_weight + ) + + # tpr = [[1, 1, 0.57, 0.57, 0], [1, 1, 0, 0, 0]] + # fpr = [[1, 0.67, 0, 0, 0], [1, 0, 0, 0, 0]] + expected_result = 1.0 - 0.5 * 0.43 * 0.67 + self.assertAllClose(result, expected_result, 1e-1) + + def test_pr_interpolation_unweighted(self): + auc_obj = metrics.AUC( + num_thresholds=self.num_thresholds, curve="PR", multi_label=True + ) + good_result = auc_obj(self.y_true_good, self.y_pred) + with self.subTest(name="good"): + # PR AUCs are 0.917 and 1.0 respectively + self.assertAllClose(good_result, (0.91667 + 1.0) / 2.0, 1e-1) + bad_result = auc_obj(self.y_true_bad, self.y_pred) + with self.subTest(name="bad"): + # PR AUCs are 0.917 and 0.5 respectively + self.assertAllClose(bad_result, (0.91667 + 0.5) / 2.0, 1e-1) + + def test_pr_interpolation(self): + auc_obj = metrics.AUC( + num_thresholds=self.num_thresholds, curve="PR", multi_label=True + ) + good_result = auc_obj( + self.y_true_good, self.y_pred, sample_weight=self.sample_weight + ) + # PR AUCs are 0.939 and 1.0 respectively + self.assertAllClose(good_result, (0.939 + 1.0) / 2.0, 1e-1) + + @pytest.mark.requires_trainable_backend + def test_keras_model_compiles(self): + inputs = layers.Input(shape=(10,), batch_size=1) + output = layers.Dense(3, activation="sigmoid")(inputs) + model = models.Model(inputs=inputs, outputs=output) + model.compile( + optimizer="adam", + loss="binary_crossentropy", + metrics=[metrics.AUC(multi_label=True)], + ) + + def test_reset_state(self): + auc_obj = metrics.AUC( + num_thresholds=self.num_thresholds, multi_label=True + ) + auc_obj(self.y_true_good, self.y_pred) + auc_obj.reset_state() + self.assertAllClose(auc_obj.true_positives, np.zeros((5, 2))) diff --git a/keras/src/metrics/correlation_metrics.py b/keras/src/metrics/correlation_metrics.py new file mode 100644 index 000000000000..1d2c8efea6c7 --- /dev/null +++ b/keras/src/metrics/correlation_metrics.py @@ -0,0 +1,215 @@ +from keras.src import backend +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.losses.loss import squeeze_or_expand_to_same_rank +from keras.src.metrics import reduction_metrics + + +@keras_export("keras.metrics.pearson_correlation") +def pearson_correlation(y_true, y_pred, axis=-1): + """Computes the Pearson coefficient between labels and predictions. + + Formula: + + ```python + loss = mean(l2norm(y_true - mean(y_true) * l2norm(y_pred - mean(y_pred))) + ``` + + Args: + y_true: Tensor of true targets. + y_pred: Tensor of predicted targets. + axis: Axis along which to determine similarity. Defaults to `-1`. + + Returns: + Pearson Correlation Coefficient tensor. + + Example: + + >>> y_true = [[0, 1, 0.5], [1, 1, 0.2]] + >>> y_pred = [[0.1, 0.9, 0.5], [1, 0.9, 0.2]] + >>> loss = keras.losses.concordance_correlation( + ... y_true, y_pred, axis=-1 + ... ).numpy() + [1. 0.99339927] + """ + y_pred = ops.convert_to_tensor(y_pred) + y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype) + y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred) + + y_true_norm = y_true - ops.mean(y_true, axis=axis, keepdims=True) + y_pred_norm = y_pred - ops.mean(y_pred, axis=axis, keepdims=True) + + y_true_norm = y_true_norm / ops.std(y_true_norm, axis=axis, keepdims=True) + y_pred_norm = y_pred_norm / ops.std(y_pred_norm, axis=axis, keepdims=True) + + return ops.mean(y_true_norm * y_pred_norm, axis=axis) + + +@keras_export("keras.metrics.concordance_correlation") +def concordance_correlation(y_true, y_pred, axis=-1): + """Computes the Concordance coefficient between labels and predictions. + + Formula: + + ```python + loss = mean( + 2 * (y_true - mean(y_true) * (y_pred - mean(y_pred)) / ( + var(y_true) + var(y_pred) + square(mean(y_true) - mean(y_pred)) + ) + ) + ``` + + Args: + y_true: Tensor of true targets. + y_pred: Tensor of predicted targets. + axis: Axis along which to determine similarity. Defaults to `-1`. + + Returns: + Concordance Correlation Coefficient tensor. + + Example: + + >>> y_true = [[0, 1, 0.5], [1, 1, 0.2]] + >>> y_pred = [[0.1, 0.9, 0.5], [1, 0.9, 0.2]] + >>> loss = keras.losses.concordance_correlation( + ... y_true, y_pred, axis=-1 + ... ).numpy() + [0.97560976 0.98765432] + """ + y_pred = ops.convert_to_tensor(y_pred) + y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype) + y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred) + + y_true_mean = ops.mean(y_true, axis=axis, keepdims=True) + y_pred_mean = ops.mean(y_pred, axis=axis, keepdims=True) + + y_true_var = ops.var(y_true - y_true_mean, axis=axis, keepdims=True) + y_pred_var = ops.var(y_pred - y_pred_mean, axis=axis, keepdims=True) + + covar = (y_true - y_pred_mean) * (y_pred - y_pred_mean) + norm = y_true_var + y_pred_var + ops.square(y_true_mean - y_pred_mean) + + return ops.mean(2 * covar / (norm + backend.epsilon()), axis=axis) + + +@keras_export("keras.metrics.PearsonCorrelation") +class PearsonCorrelation(reduction_metrics.MeanMetricWrapper): + """Calculates the Pearson Correlation Coefficient (PCC). + + PCC measures the linear relationship between the true values (`y_true`) and + the predicted values (`y_pred`). The coefficient ranges from -1 to 1, where + a value of 1 implies a perfect positive linear correlation, 0 indicates no + linear correlation, and -1 indicates a perfect negative linear correlation. + + This metric is widely used in regression tasks where the strength of the + linear relationship between predictions and true labels is an + important evaluation criterion. + + Args: + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + axis: (Optional) integer or tuple of integers of the axis/axes along + which to compute the metric. Defaults to `-1`. + + Example: + + >>> pcc = keras.metrics.PearsonCorrelation(axis=-1) + >>> y_true = [[0, 1, 0.5], [1, 1, 0.2]] + >>> y_pred = [[0.1, 0.9, 0.5], [1, 0.9, 0.2]] + >>> pcc.update_state(y_true, y_pred) + >>> pcc.result() + 0.9966996338993913 + + Usage with `compile()` API: + + ```python + model.compile(optimizer='sgd', + loss='mean_squared_error', + metrics=[keras.metrics.PearsonCorrelation()]) + ``` + """ + + def __init__( + self, + name="pearson_correlation", + dtype=None, + axis=-1, + ): + super().__init__( + fn=pearson_correlation, + name=name, + dtype=dtype, + axis=axis, + ) + self.axis = axis + # Metric should be maximized during optimization. + self._direction = "up" + + def get_config(self): + return { + "name": self.name, + "dtype": self.dtype, + "axis": self.axis, + } + + +@keras_export("keras.metrics.ConcordanceCorrelation") +class ConcordanceCorrelation(reduction_metrics.MeanMetricWrapper): + """Calculates the Concordance Correlation Coefficient (CCC). + + CCC evaluates the agreement between true values (`y_true`) and predicted + values (`y_pred`) by considering both precision and accuracy. The + coefficient ranges from -1 to 1, where a value of 1 indicates perfect + agreement. + + This metric is useful in regression tasks where it is important to assess + how well the predictions match the true values, taking into account both + their correlation and proximity to the 45-degree line of perfect + concordance. + + Args: + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + axis: (Optional) integer or tuple of integers of the axis/axes along + which to compute the metric. Defaults to `-1`. + + Example: + + >>> ccc = keras.metrics.ConcordanceCorrelation(axis=-1) + >>> y_true = [[0, 1, 0.5], [1, 1, 0.2]] + >>> y_pred = [[0.1, 0.9, 0.5], [1, 0.9, 0.2]] + >>> ccc.update_state(y_true, y_pred) + >>> ccc.result() + 0.9816320385426076 + + Usage with `compile()` API: + + ```python + model.compile(optimizer='sgd', + loss='mean_squared_error', + metrics=[keras.metrics.ConcordanceCorrelation()]) + ``` + """ + + def __init__( + self, + name="concordance_correlation", + dtype=None, + axis=-1, + ): + super().__init__( + fn=concordance_correlation, + name=name, + dtype=dtype, + axis=axis, + ) + self.axis = axis + # Metric should be maximized during optimization. + self._direction = "up" + + def get_config(self): + return { + "name": self.name, + "dtype": self.dtype, + "axis": self.axis, + } diff --git a/keras/src/metrics/correlation_metrics_test.py b/keras/src/metrics/correlation_metrics_test.py new file mode 100644 index 000000000000..d7150985aa52 --- /dev/null +++ b/keras/src/metrics/correlation_metrics_test.py @@ -0,0 +1,79 @@ +import numpy as np +from scipy.stats import pearsonr + +from keras.src import testing +from keras.src.metrics import ConcordanceCorrelation +from keras.src.metrics import PearsonCorrelation +from keras.src.metrics import correlation_metrics + + +class CorrelationsTest(testing.TestCase): + def _get_data(self): + # Sample data for testing + y_true = np.array( + [[0, 1, 0.5], [1, 1, 0.2], [1, 1, 0.1], [0.1, 0.7, 0.0]], + dtype="float32", + ) + y_pred = np.array( + [[0.1, 0.9, 0.5], [1, 0.9, 0.2], [0.2, 0.8, 0], [0.3, 0.3, 0.9]], + dtype="float32", + ) + + ccc_expected = np.array( + [0.97560976, 0.98765432, 0.46511628, -0.46376812] + ) + # pcc_expected = np.array([1, 0.99339927, 0.69337525, -0.60999428]) + pcc_expected = np.array( + [pearsonr(yt, yp).statistic for yt, yp in zip(y_true, y_pred)] + ) + return y_true, y_pred, ccc_expected, pcc_expected + + def test_pearson_function(self): + """Test the functional API for Pearson Correlation Coefficient.""" + y_true, y_pred, _, pcc_expected = self._get_data() + result = correlation_metrics.pearson_correlation( + y_true, y_pred, axis=-1 + ) + self.assertAllClose(result, pcc_expected) + + def test_concordance_function(self): + """Test the functional API for Concordance Correlation Coefficient.""" + y_true, y_pred, ccc_expected, _ = self._get_data() + result = correlation_metrics.concordance_correlation( + y_true, y_pred, axis=-1 + ) + self.assertAllClose(result, ccc_expected) + + def test_pearson_class(self): + """Test the PearsonCorrelation metric class.""" + y_true, y_pred, _, pcc_expected = self._get_data() + m = PearsonCorrelation(axis=-1, dtype="float32") + m.update_state(y_true[:2], y_pred[:2]) + self.assertAllClose(m.result(), np.mean(pcc_expected[:2])) + m.update_state(y_true[2:], y_pred[2:]) + self.assertAllClose(m.result(), np.mean(pcc_expected)) + + def test_concordance_class(self): + """Test the ConcordanceCorrelation metric class.""" + y_true, y_pred, ccc_expected, _ = self._get_data() + m = ConcordanceCorrelation(axis=-1, dtype="float32") + m.update_state(y_true[:2], y_pred[:2]) + self.assertAllClose(m.result(), np.mean(ccc_expected[:2])) + m.update_state(y_true[2:], y_pred[2:]) + self.assertAllClose(m.result(), np.mean(ccc_expected)) + + def test_pearson_config(self): + """Test the get_config method for PearsonCorrelation.""" + m = PearsonCorrelation(axis=-1, dtype="float16") + config = m.get_config() + self.assertEqual(config["axis"], -1) + self.assertEqual(config["dtype"], "float16") + self.assertEqual(config["name"], "pearson_correlation") + + def test_concordance_config(self): + """Test the get_config method for ConcordanceCorrelation.""" + m = ConcordanceCorrelation(axis=-1, dtype="float32") + config = m.get_config() + self.assertEqual(config["axis"], -1) + self.assertEqual(config["dtype"], "float32") + self.assertEqual(config["name"], "concordance_correlation") diff --git a/keras/src/metrics/f_score_metrics.py b/keras/src/metrics/f_score_metrics.py new file mode 100644 index 000000000000..a51119cb48e4 --- /dev/null +++ b/keras/src/metrics/f_score_metrics.py @@ -0,0 +1,320 @@ +from keras.src import backend +from keras.src import initializers +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.metrics.metric import Metric + + +@keras_export("keras.metrics.FBetaScore") +class FBetaScore(Metric): + """Computes F-Beta score. + + Formula: + + ```python + b2 = beta ** 2 + f_beta_score = (1 + b2) * (precision * recall) / (precision * b2 + recall) + ``` + This is the weighted harmonic mean of precision and recall. + Its output range is `[0, 1]`. It works for both multi-class + and multi-label classification. + + Args: + average: Type of averaging to be performed across per-class results + in the multi-class case. + Acceptable values are `None`, `"micro"`, `"macro"` and + `"weighted"`. Defaults to `None`. + If `None`, no averaging is performed and `result()` will return + the score for each class. + If `"micro"`, compute metrics globally by counting the total + true positives, false negatives and false positives. + If `"macro"`, compute metrics for each label, + and return their unweighted mean. + This does not take label imbalance into account. + If `"weighted"`, compute metrics for each label, + and return their average weighted by support + (the number of true instances for each label). + This alters `"macro"` to account for label imbalance. + It can result in an score that is not between precision and recall. + beta: Determines the weight of given to recall + in the harmonic mean between precision and recall (see pseudocode + equation above). Defaults to `1`. + threshold: Elements of `y_pred` greater than `threshold` are + converted to be 1, and the rest 0. If `threshold` is + `None`, the argmax of `y_pred` is converted to 1, and the rest to 0. + name: Optional. String name of the metric instance. + dtype: Optional. Data type of the metric result. + + Returns: + F-Beta Score: float. + + Example: + + >>> metric = keras.metrics.FBetaScore(beta=2.0, threshold=0.5) + >>> y_true = np.array([[1, 1, 1], + ... [1, 0, 0], + ... [1, 1, 0]], np.int32) + >>> y_pred = np.array([[0.2, 0.6, 0.7], + ... [0.2, 0.6, 0.6], + ... [0.6, 0.8, 0.0]], np.float32) + >>> metric.update_state(y_true, y_pred) + >>> result = metric.result() + >>> result + [0.3846154 , 0.90909094, 0.8333334 ] + """ + + def __init__( + self, + average=None, + beta=1.0, + threshold=None, + name="fbeta_score", + dtype=None, + ): + super().__init__(name=name, dtype=dtype) + # Metric should be maximized during optimization. + self._direction = "up" + + if average not in (None, "micro", "macro", "weighted"): + raise ValueError( + "Invalid `average` argument value. Expected one of: " + "{None, 'micro', 'macro', 'weighted'}. " + f"Received: average={average}" + ) + + if not isinstance(beta, float): + raise ValueError( + "Invalid `beta` argument value. " + "It should be a Python float. " + f"Received: beta={beta} of type '{type(beta)}'" + ) + if beta <= 0.0: + raise ValueError( + "Invalid `beta` argument value. " + "It should be > 0. " + f"Received: beta={beta}" + ) + + if threshold is not None: + if not isinstance(threshold, float): + raise ValueError( + "Invalid `threshold` argument value. " + "It should be a Python float. " + f"Received: threshold={threshold} " + f"of type '{type(threshold)}'" + ) + if threshold > 1.0 or threshold <= 0.0: + raise ValueError( + "Invalid `threshold` argument value. " + "It should verify 0 < threshold <= 1. " + f"Received: threshold={threshold}" + ) + + self.average = average + self.beta = beta + self.threshold = threshold + self.axis = None + self._built = False + + if self.average != "micro": + self.axis = 0 + + def _build(self, y_true_shape, y_pred_shape): + if len(y_pred_shape) != 2 or len(y_true_shape) != 2: + raise ValueError( + "FBetaScore expects 2D inputs with shape " + "(batch_size, output_dim). Received input " + f"shapes: y_pred.shape={y_pred_shape} and " + f"y_true.shape={y_true_shape}." + ) + if y_pred_shape[-1] is None or y_true_shape[-1] is None: + raise ValueError( + "FBetaScore expects 2D inputs with shape " + "(batch_size, output_dim), with output_dim fully " + "defined (not None). Received input " + f"shapes: y_pred.shape={y_pred_shape} and " + f"y_true.shape={y_true_shape}." + ) + num_classes = y_pred_shape[-1] + if self.average != "micro": + init_shape = (num_classes,) + else: + init_shape = () + + def _add_zeros_variable(name): + return self.add_variable( + name=name, + shape=init_shape, + initializer=initializers.Zeros(), + dtype=self.dtype, + ) + + self.true_positives = _add_zeros_variable("true_positives") + self.false_positives = _add_zeros_variable("false_positives") + self.false_negatives = _add_zeros_variable("false_negatives") + self.intermediate_weights = _add_zeros_variable("intermediate_weights") + self._built = True + + def update_state(self, y_true, y_pred, sample_weight=None): + y_true = ops.convert_to_tensor(y_true, dtype=self.dtype) + y_pred = ops.convert_to_tensor(y_pred, dtype=self.dtype) + if not self._built: + self._build(y_true.shape, y_pred.shape) + + if self.threshold is None: + threshold = ops.max(y_pred, axis=-1, keepdims=True) + # make sure [0, 0, 0] doesn't become [1, 1, 1] + # Use abs(x) > eps, instead of x != 0 to check for zero + y_pred = ops.logical_and( + y_pred >= threshold, ops.abs(y_pred) > 1e-9 + ) + else: + y_pred = y_pred > self.threshold + + y_pred = ops.cast(y_pred, dtype=self.dtype) + y_true = ops.cast(y_true, dtype=self.dtype) + if sample_weight is not None: + sample_weight = ops.convert_to_tensor( + sample_weight, dtype=self.dtype + ) + + def _weighted_sum(val, sample_weight): + if sample_weight is not None: + val = ops.multiply(val, ops.expand_dims(sample_weight, 1)) + return ops.sum(val, axis=self.axis) + + self.true_positives.assign( + self.true_positives + _weighted_sum(y_pred * y_true, sample_weight) + ) + self.false_positives.assign( + self.false_positives + + _weighted_sum(y_pred * (1 - y_true), sample_weight) + ) + self.false_negatives.assign( + self.false_negatives + + _weighted_sum((1 - y_pred) * y_true, sample_weight) + ) + self.intermediate_weights.assign( + self.intermediate_weights + _weighted_sum(y_true, sample_weight) + ) + + def result(self): + precision = ops.divide( + self.true_positives, + self.true_positives + self.false_positives + backend.epsilon(), + ) + recall = ops.divide( + self.true_positives, + self.true_positives + self.false_negatives + backend.epsilon(), + ) + + precision = ops.convert_to_tensor(precision, dtype=self.dtype) + recall = ops.convert_to_tensor(recall, dtype=self.dtype) + + mul_value = precision * recall + add_value = ((self.beta**2) * precision) + recall + mean = ops.divide(mul_value, add_value + backend.epsilon()) + f1_score = mean * (1 + (self.beta**2)) + + if self.average == "weighted": + weights = ops.divide( + self.intermediate_weights, + ops.sum(self.intermediate_weights) + backend.epsilon(), + ) + f1_score = ops.sum(f1_score * weights) + + elif self.average is not None: # [micro, macro] + f1_score = ops.mean(f1_score) + + return f1_score + + def get_config(self): + """Returns the serializable config of the metric.""" + + config = { + "name": self.name, + "dtype": self.dtype, + "average": self.average, + "beta": self.beta, + "threshold": self.threshold, + } + + base_config = super().get_config() + return {**base_config, **config} + + def reset_state(self): + for v in self.variables: + v.assign(ops.zeros(v.shape, dtype=v.dtype)) + + +@keras_export("keras.metrics.F1Score") +class F1Score(FBetaScore): + r"""Computes F-1 Score. + + Formula: + + ```python + f1_score = 2 * (precision * recall) / (precision + recall) + ``` + This is the harmonic mean of precision and recall. + Its output range is `[0, 1]`. It works for both multi-class + and multi-label classification. + + Args: + average: Type of averaging to be performed on data. + Acceptable values are `None`, `"micro"`, `"macro"` + and `"weighted"`. Defaults to `None`. + If `None`, no averaging is performed and `result()` will return + the score for each class. + If `"micro"`, compute metrics globally by counting the total + true positives, false negatives and false positives. + If `"macro"`, compute metrics for each label, + and return their unweighted mean. + This does not take label imbalance into account. + If `"weighted"`, compute metrics for each label, + and return their average weighted by support + (the number of true instances for each label). + This alters `"macro"` to account for label imbalance. + It can result in an score that is not between precision and recall. + threshold: Elements of `y_pred` greater than `threshold` are + converted to be 1, and the rest 0. If `threshold` is + `None`, the argmax of `y_pred` is converted to 1, and the rest to 0. + name: Optional. String name of the metric instance. + dtype: Optional. Data type of the metric result. + + Returns: + F-1 Score: float. + + Example: + + >>> metric = keras.metrics.F1Score(threshold=0.5) + >>> y_true = np.array([[1, 1, 1], + ... [1, 0, 0], + ... [1, 1, 0]], np.int32) + >>> y_pred = np.array([[0.2, 0.6, 0.7], + ... [0.2, 0.6, 0.6], + ... [0.6, 0.8, 0.0]], np.float32) + >>> metric.update_state(y_true, y_pred) + >>> result = metric.result() + array([0.5 , 0.8 , 0.6666667], dtype=float32) + """ + + def __init__( + self, + average=None, + threshold=None, + name="f1_score", + dtype=None, + ): + super().__init__( + average=average, + beta=1.0, + threshold=threshold, + name=name, + dtype=dtype, + ) + + def get_config(self): + base_config = super().get_config() + del base_config["beta"] + return base_config diff --git a/keras/src/metrics/f_score_metrics_test.py b/keras/src/metrics/f_score_metrics_test.py new file mode 100644 index 000000000000..352ebe316e22 --- /dev/null +++ b/keras/src/metrics/f_score_metrics_test.py @@ -0,0 +1,422 @@ +import numpy as np +from absl.testing import parameterized + +from keras.src import testing +from keras.src.metrics import f_score_metrics + + +class FBetaScoreTest(testing.TestCase): + def _run_test( + self, + y_true, + y_pred, + sample_weights, + average, + beta, + threshold, + reference_result, + ): + fbeta = f_score_metrics.FBetaScore( + average, beta, threshold, dtype="float32" + ) + fbeta.update_state(y_true, y_pred, sample_weights) + result = fbeta.result() + self.assertAllClose(result, reference_result, atol=1e-6) + + def test_config(self): + fbeta_obj = f_score_metrics.FBetaScore( + beta=0.5, threshold=0.3, average=None, dtype="float32" + ) + self.assertEqual(fbeta_obj.beta, 0.5) + self.assertEqual(fbeta_obj.average, None) + self.assertEqual(fbeta_obj.threshold, 0.3) + self.assertEqual(fbeta_obj.dtype, "float32") + + # Check save and restore config + fbeta_obj2 = f_score_metrics.FBetaScore.from_config( + fbeta_obj.get_config() + ) + self.assertEqual(fbeta_obj2.beta, 0.5) + self.assertEqual(fbeta_obj2.average, None) + self.assertEqual(fbeta_obj2.threshold, 0.3) + self.assertEqual(fbeta_obj2.dtype, "float32") + + @parameterized.parameters( + ("micro", 0.5), + ("micro", 1.0), + ("micro", 2.0), + ("macro", 0.5), + ("macro", 1.0), + ("macro", 2.0), + ("weighted", 0.5), + ("weighted", 1.0), + ("weighted", 2.0), + ) + def test_fbeta_perfect_score(self, average, beta): + y_true = [[1, 1, 1], [1, 0, 0], [1, 1, 0]] + y_pred = [[0.7, 0.7, 0.7], [1, 0, 0], [0.9, 0.8, 0]] + self._run_test( + y_true, + y_pred, + None, + average=average, + beta=beta, + threshold=0.66, + reference_result=1.0, + ) + + @parameterized.parameters( + ("micro", 0.5), + ("micro", 1.0), + ("micro", 2.0), + ("macro", 0.5), + ("macro", 1.0), + ("macro", 2.0), + ("weighted", 0.5), + ("weighted", 1.0), + ("weighted", 2.0), + ) + def test_fbeta_worst_score(self, average, beta): + y_true = [[0, 0, 0], [0, 1, 0], [0, 0, 1]] + y_pred = [[0.7, 0.7, 0.7], [1, 0, 0], [0.9, 0.8, 0]] + self._run_test( + y_true, + y_pred, + None, + average=average, + beta=beta, + threshold=0.66, + reference_result=0.0, + ) + + @parameterized.parameters( + # average, beta, result + (None, 0.5, [0.71428573, 0.5, 0.833334]), + (None, 1.0, [0.8, 0.5, 0.6666667]), + (None, 2.0, [0.9090904, 0.5, 0.555556]), + ("micro", 0.5, 0.6666667), + ("micro", 1.0, 0.6666667), + ("micro", 2.0, 0.6666667), + ("macro", 0.5, 0.6825397), + ("macro", 1.0, 0.6555555), + ("macro", 2.0, 0.6548822), + ("weighted", 0.5, 0.6825397), + ("weighted", 1.0, 0.6555555), + ("weighted", 2.0, 0.6548822), + ) + def test_fbeta_random_score(self, average, beta, result): + y_pred = [[0.7, 0.7, 0.7], [1, 0, 0], [0.9, 0.8, 0]] + y_true = [[0, 0, 1], [1, 1, 0], [1, 1, 1]] + self._run_test( + y_true, + y_pred, + None, + average=average, + beta=beta, + threshold=0.66, + reference_result=result, + ) + + @parameterized.parameters( + # average, beta, result + (None, 0.5, [0.9090904, 0.555556, 1.0]), + (None, 1.0, [0.8, 0.6666667, 1.0]), + (None, 2.0, [0.71428573, 0.833334, 1.0]), + ("micro", 0.5, 0.833334), + ("micro", 1.0, 0.833334), + ("micro", 2.0, 0.833334), + ("macro", 0.5, 0.821549), + ("macro", 1.0, 0.822222), + ("macro", 2.0, 0.849206), + ("weighted", 0.5, 0.880471), + ("weighted", 1.0, 0.844445), + ("weighted", 2.0, 0.829365), + ) + def test_fbeta_random_score_none(self, average, beta, result): + y_true = [ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1], + [1, 0, 0], + [1, 0, 0], + [0, 0, 1], + ] + y_pred = [ + [0.9, 0.1, 0], + [0.2, 0.6, 0.2], + [0, 0, 1], + [0.4, 0.3, 0.3], + [0, 0.9, 0.1], + [0, 0, 1], + ] + self._run_test( + y_true, + y_pred, + None, + average=average, + beta=beta, + threshold=None, + reference_result=result, + ) + + @parameterized.parameters( + # average, beta, sample_weights, result + (None, 0.5, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [0.909091, 0.555556, 1.0]), + (None, 0.5, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], [1.0, 0.0, 1.0]), + (None, 0.5, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], [0.9375, 0.714286, 1.0]), + (None, 1.0, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [0.8, 0.666667, 1.0]), + (None, 1.0, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], [1.0, 0.0, 1.0]), + (None, 1.0, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], [0.857143, 0.8, 1.0]), + (None, 2.0, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [0.714286, 0.833333, 1.0]), + (None, 2.0, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], [1.0, 0.0, 1.0]), + (None, 2.0, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], [0.789474, 0.909091, 1.0]), + ("micro", 0.5, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.833333), + ("micro", 0.5, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], 1.0), + ("micro", 0.5, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], 0.9), + ("micro", 1.0, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.833333), + ("micro", 1.0, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], 1.0), + ("micro", 1.0, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], 0.9), + ("micro", 2.0, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.833333), + ("micro", 2.0, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], 1.0), + ("micro", 2.0, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], 0.9), + ("macro", 0.5, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.821549), + ("macro", 0.5, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], 0.666667), + ("macro", 0.5, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], 0.883929), + ("macro", 1.0, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.822222), + ("macro", 1.0, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], 0.666667), + ("macro", 1.0, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], 0.885714), + ("macro", 2.0, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.849206), + ("macro", 2.0, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], 0.666667), + ("macro", 2.0, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], 0.899522), + ("weighted", 0.5, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.880471), + ("weighted", 0.5, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], 1.0), + ("weighted", 0.5, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], 0.917857), + ("weighted", 1.0, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.844444), + ("weighted", 1.0, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], 1.0), + ("weighted", 1.0, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], 0.902857), + ("weighted", 2.0, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.829365), + ("weighted", 2.0, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], 1.0), + ("weighted", 2.0, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], 0.897608), + ) + def test_fbeta_weighted_random_score_none( + self, average, beta, sample_weights, result + ): + y_true = [ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1], + [1, 0, 0], + [1, 0, 0], + [0, 0, 1], + ] + y_pred = [ + [0.9, 0.1, 0], + [0.2, 0.6, 0.2], + [0, 0, 1], + [0.4, 0.3, 0.3], + [0, 0.9, 0.1], + [0, 0, 1], + ] + self._run_test( + y_true, + y_pred, + sample_weights, + average=average, + beta=beta, + threshold=None, + reference_result=result, + ) + + def test_invalid_average_raises_value_error(self): + expected_message = ( + "Invalid `average` argument value. Expected one of: " + r"\{None, 'micro', 'macro', 'weighted'\}. " + "Received: average=invalid_average" + ) + with self.assertRaisesRegex(ValueError, expected_message): + f_score_metrics.FBetaScore( + average="invalid_average", + beta=1.0, + threshold=None, + dtype="float32", + ) + + def test_beta_integer_type_raises_value_error(self): + with self.assertRaisesRegex( + ValueError, + "Invalid `beta` argument value. It should be a Python float.", + ): + f_score_metrics.FBetaScore( + average="macro", beta=1, threshold=None, dtype="float32" + ) + + def test_beta_string_type_raises_value_error(self): + with self.assertRaisesRegex( + ValueError, + "Invalid `beta` argument value. It should be a Python float.", + ): + f_score_metrics.FBetaScore( + average="macro", beta="1.0", threshold=None, dtype="float32" + ) + + def test_beta_none_type_raises_value_error(self): + with self.assertRaisesRegex( + ValueError, + "Invalid `beta` argument value. It should be a Python float.", + ): + f_score_metrics.FBetaScore( + average="macro", beta=None, threshold=None, dtype="float32" + ) + + def test_beta_zero_raises_value_error(self): + expected_message = ( + "Invalid `beta` argument value. It should be > 0. " + "Received: beta=0.0" + ) + with self.assertRaisesRegex(ValueError, expected_message): + f_score_metrics.FBetaScore( + average="macro", beta=0.0, threshold=None, dtype="float32" + ) + + def test_beta_negative_one_raises_value_error(self): + expected_message = ( + "Invalid `beta` argument value. It should be > 0. " + "Received: beta=-1.0" + ) + with self.assertRaisesRegex(ValueError, expected_message): + f_score_metrics.FBetaScore( + average="macro", beta=-1.0, threshold=None, dtype="float32" + ) + + def test_beta_negative_half_raises_value_error(self): + expected_message = ( + "Invalid `beta` argument value. It should be > 0. " + "Received: beta=-0.5" + ) + with self.assertRaisesRegex(ValueError, expected_message): + f_score_metrics.FBetaScore( + average="macro", beta=-0.5, threshold=None, dtype="float32" + ) + + def test_threshold_not_float_raises_value_error(self): + expected_message_pattern = ( + "Invalid `threshold` argument value. " + "It should be a Python float. " + "Received: threshold=1 of type ''" + ) + with self.assertRaisesRegex(ValueError, expected_message_pattern): + f_score_metrics.FBetaScore( + average="macro", beta=1.0, threshold=1, dtype="float32" + ) + + def test_threshold_string_raises_value_error(self): + expected_message_pattern = ( + "Invalid `threshold` argument value. " + "It should be a Python float. " + "Received: threshold=0.5 of type ''" + ) + with self.assertRaisesRegex(ValueError, expected_message_pattern): + f_score_metrics.FBetaScore( + average="macro", beta=1.0, threshold="0.5", dtype="float32" + ) + + def test_threshold_above_one_raises_value_error(self): + expected_message = ( + "Invalid `threshold` argument value. " + "It should verify 0 < threshold <= 1. " + "Received: threshold=1.1" + ) + with self.assertRaisesRegex(ValueError, expected_message): + f_score_metrics.FBetaScore( + average="macro", beta=1.0, threshold=1.1, dtype="float32" + ) + + def test_threshold_zero_raises_value_error(self): + expected_message = ( + "Invalid `threshold` argument value. " + "It should verify 0 < threshold <= 1. " + "Received: threshold=0.0" + ) + with self.assertRaisesRegex(ValueError, expected_message): + f_score_metrics.FBetaScore( + average="macro", beta=1.0, threshold=0.0, dtype="float32" + ) + + def test_threshold_negative_raises_value_error(self): + expected_message = ( + "Invalid `threshold` argument value. " + "It should verify 0 < threshold <= 1. " + "Received: threshold=-0.5" + ) + with self.assertRaisesRegex(ValueError, expected_message): + f_score_metrics.FBetaScore( + average="macro", beta=1.0, threshold=-0.5, dtype="float32" + ) + + def test_non_2d_input_shapes_raises_value_error(self): + fbeta = f_score_metrics.FBetaScore(beta=1.0, dtype="float32") + y_true_shape = (2, 3, 4) + y_pred_shape = (2, 3, 4) + expected_error_message = ( + "FBetaScore expects 2D inputs with shape " + r"\(batch_size, output_dim\)\. Received input " + r"shapes: y_pred\.shape=\(2, 3, 4\) and " + r"y_true\.shape=\(2, 3, 4\)\." + ) + with self.assertRaisesRegex(ValueError, expected_error_message): + fbeta._build(y_true_shape, y_pred_shape) + + def test_undefined_output_dim_raises_value_error(self): + fbeta = f_score_metrics.FBetaScore(beta=1.0, dtype="float32") + y_true_shape = (2, None) + y_pred_shape = (2, None) + expected_error_message = ( + "FBetaScore expects 2D inputs with shape " + r"\(batch_size, output_dim\), with output_dim fully " + r"defined \(not None\)\. Received input " + r"shapes: y_pred\.shape=\(2, None\) and " + r"y_true\.shape=\(2, None\)\." + ) + with self.assertRaisesRegex(ValueError, expected_error_message): + fbeta._build(y_true_shape, y_pred_shape) + + +class F1ScoreTest(testing.TestCase): + def test_config(self): + f1_obj = f_score_metrics.F1Score(dtype="float32") + config = f1_obj.get_config() + self.assertNotIn("beta", config) + + # Check save and restore config + f1_obj = f_score_metrics.F1Score.from_config(config) + self.assertEqual(f1_obj.average, None) + self.assertEqual(f1_obj.dtype, "float32") + + def test_correctness(self): + f1 = f_score_metrics.F1Score() + fbeta = f_score_metrics.FBetaScore(beta=1.0) + + y_true = np.array( + [ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1], + [1, 0, 0], + [1, 0, 0], + [0, 0, 1], + ] + ) + y_pred = np.array( + [ + [0.9, 0.1, 0], + [0.2, 0.6, 0.2], + [0, 0, 1], + [0.4, 0.3, 0.3], + [0, 0.9, 0.1], + [0, 0, 1], + ] + ) + + fbeta.update_state(y_true, y_pred) + f1.update_state(y_true, y_pred) + self.assertAllClose(fbeta.result(), f1.result(), atol=1e-6) diff --git a/keras/src/metrics/hinge_metrics.py b/keras/src/metrics/hinge_metrics.py new file mode 100644 index 000000000000..4678b3fa1718 --- /dev/null +++ b/keras/src/metrics/hinge_metrics.py @@ -0,0 +1,100 @@ +from keras.src.api_export import keras_export +from keras.src.losses.losses import categorical_hinge +from keras.src.losses.losses import hinge +from keras.src.losses.losses import squared_hinge +from keras.src.metrics import reduction_metrics + + +@keras_export("keras.metrics.Hinge") +class Hinge(reduction_metrics.MeanMetricWrapper): + """Computes the hinge metric between `y_true` and `y_pred`. + + `y_true` values are expected to be -1 or 1. If binary (0 or 1) labels are + provided we will convert them to -1 or 1. + + Args: + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + + Examples: + + >>> m = keras.metrics.Hinge() + >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]]) + >>> m.result() + 1.3 + >>> m.reset_state() + >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]], + ... sample_weight=[1, 0]) + >>> m.result() + 1.1 + """ + + def __init__(self, name="hinge", dtype=None): + super().__init__(fn=hinge, name=name, dtype=dtype) + # Metric should be minimized during optimization. + self._direction = "down" + + def get_config(self): + return {"name": self.name, "dtype": self.dtype} + + +@keras_export("keras.metrics.SquaredHinge") +class SquaredHinge(reduction_metrics.MeanMetricWrapper): + """Computes the hinge metric between `y_true` and `y_pred`. + + `y_true` values are expected to be -1 or 1. If binary (0 or 1) labels are + provided we will convert them to -1 or 1. + + Args: + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + + Example: + + >>> m = keras.metrics.SquaredHinge() + >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]]) + >>> m.result() + 1.86 + >>> m.reset_state() + >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]], + ... sample_weight=[1, 0]) + >>> m.result() + 1.46 + """ + + def __init__(self, name="squared_hinge", dtype=None): + super().__init__(fn=squared_hinge, name=name, dtype=dtype) + # Metric should be minimized during optimization. + self._direction = "down" + + def get_config(self): + return {"name": self.name, "dtype": self.dtype} + + +@keras_export("keras.metrics.CategoricalHinge") +class CategoricalHinge(reduction_metrics.MeanMetricWrapper): + """Computes the categorical hinge metric between `y_true` and `y_pred`. + + Args: + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + + Example: + >>> m = keras.metrics.CategoricalHinge() + >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]]) + >>> m.result().numpy() + 1.4000001 + >>> m.reset_state() + >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]], + ... sample_weight=[1, 0]) + >>> m.result() + 1.2 + """ + + def __init__(self, name="categorical_hinge", dtype=None): + super().__init__(fn=categorical_hinge, name=name, dtype=dtype) + # Metric should be minimized during optimization. + self._direction = "down" + + def get_config(self): + return {"name": self.name, "dtype": self.dtype} diff --git a/keras/src/metrics/hinge_metrics_test.py b/keras/src/metrics/hinge_metrics_test.py new file mode 100644 index 000000000000..26d67b98ee6d --- /dev/null +++ b/keras/src/metrics/hinge_metrics_test.py @@ -0,0 +1,131 @@ +import numpy as np + +from keras.src import testing +from keras.src.metrics import hinge_metrics + + +class HingeTest(testing.TestCase): + def test_config(self): + hinge_obj = hinge_metrics.Hinge(name="hinge", dtype="int32") + self.assertEqual(hinge_obj.name, "hinge") + self.assertEqual(hinge_obj._dtype, "int32") + + # Check save and restore config + hinge_obj2 = hinge_metrics.Hinge.from_config(hinge_obj.get_config()) + self.assertEqual(hinge_obj2.name, "hinge") + self.assertEqual(len(hinge_obj2.variables), 2) + self.assertEqual(hinge_obj2._dtype, "int32") + + def test_unweighted(self): + hinge_obj = hinge_metrics.Hinge() + y_true = np.array([[0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 1.0]]) + y_pred = np.array([[-0.3, 0.2, -0.1, 1.6], [-0.25, -1.0, 0.5, 0.6]]) + hinge_obj.update_state(y_true, y_pred) + result = hinge_obj.result() + self.assertAllClose(0.506, result, atol=1e-3) + + def test_weighted(self): + hinge_obj = hinge_metrics.Hinge() + y_true = np.array([[-1, 1, -1, 1], [-1, -1, 1, 1]]) + y_pred = np.array([[-0.3, 0.2, -0.1, 1.6], [-0.25, -1.0, 0.5, 0.6]]) + sample_weight = np.array([1.5, 2.0]) + result = hinge_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAllClose(0.493, result, atol=1e-3) + + +class SquaredHingeTest(testing.TestCase): + def test_config(self): + sq_hinge_obj = hinge_metrics.SquaredHinge( + name="squared_hinge", dtype="int32" + ) + self.assertEqual(sq_hinge_obj.name, "squared_hinge") + self.assertEqual(sq_hinge_obj._dtype, "int32") + + # Check save and restore config + sq_hinge_obj2 = hinge_metrics.SquaredHinge.from_config( + sq_hinge_obj.get_config() + ) + self.assertEqual(sq_hinge_obj2.name, "squared_hinge") + self.assertEqual(len(sq_hinge_obj2.variables), 2) + self.assertEqual(sq_hinge_obj2._dtype, "int32") + + def test_unweighted(self): + sq_hinge_obj = hinge_metrics.SquaredHinge() + y_true = np.array([[0, 1, 0, 1], [0, 0, 1, 1]], dtype="float32") + y_pred = np.array([[-0.3, 0.2, -0.1, 1.6], [-0.25, -1.0, 0.5, 0.6]]) + sq_hinge_obj.update_state(y_true, y_pred) + result = sq_hinge_obj.result() + self.assertAllClose(0.364, result, atol=1e-3) + + def test_weighted(self): + sq_hinge_obj = hinge_metrics.SquaredHinge() + y_true = np.array([[-1, 1, -1, 1], [-1, -1, 1, 1]], dtype="float32") + y_pred = np.array([[-0.3, 0.2, -0.1, 1.6], [-0.25, -1.0, 0.5, 0.6]]) + sample_weight = np.array([1.5, 2.0]) + result = sq_hinge_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAllClose(0.347, result, atol=1e-3) + + +class CategoricalHingeTest(testing.TestCase): + def test_config(self): + cat_hinge_obj = hinge_metrics.CategoricalHinge( + name="cat_hinge", dtype="int32" + ) + self.assertEqual(cat_hinge_obj.name, "cat_hinge") + self.assertEqual(cat_hinge_obj._dtype, "int32") + + # Check save and restore config + cat_hinge_obj2 = hinge_metrics.CategoricalHinge.from_config( + cat_hinge_obj.get_config() + ) + self.assertEqual(cat_hinge_obj2.name, "cat_hinge") + self.assertEqual(len(cat_hinge_obj2.variables), 2) + self.assertEqual(cat_hinge_obj2._dtype, "int32") + + def test_unweighted(self): + cat_hinge_obj = hinge_metrics.CategoricalHinge() + y_true = np.array( + ( + (0, 1, 0, 1, 0), + (0, 0, 1, 1, 1), + (1, 1, 1, 1, 0), + (0, 0, 0, 0, 1), + ), + dtype="float32", + ) + y_pred = np.array( + ( + (0, 0, 1, 1, 0), + (1, 1, 1, 1, 1), + (0, 1, 0, 1, 0), + (1, 1, 1, 1, 1), + ), + dtype="float32", + ) + cat_hinge_obj.update_state(y_true, y_pred) + result = cat_hinge_obj.result() + self.assertAllClose(0.5, result, atol=1e-5) + + def test_weighted(self): + cat_hinge_obj = hinge_metrics.CategoricalHinge() + y_true = np.array( + ( + (0, 1, 0, 1, 0), + (0, 0, 1, 1, 1), + (1, 1, 1, 1, 0), + (0, 0, 0, 0, 1), + ), + dtype="float32", + ) + y_pred = np.array( + ( + (0, 0, 1, 1, 0), + (1, 1, 1, 1, 1), + (0, 1, 0, 1, 0), + (1, 1, 1, 1, 1), + ), + dtype="float32", + ) + sample_weight = np.array((1.0, 1.5, 2.0, 2.5)) + result = cat_hinge_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAllClose(0.5, result, atol=1e-5) diff --git a/keras/src/metrics/iou_metrics.py b/keras/src/metrics/iou_metrics.py new file mode 100644 index 000000000000..0208381431d1 --- /dev/null +++ b/keras/src/metrics/iou_metrics.py @@ -0,0 +1,759 @@ +import warnings + +from keras.src import backend +from keras.src import initializers +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.metrics.metric import Metric +from keras.src.metrics.metrics_utils import confusion_matrix + + +class _IoUBase(Metric): + """Computes the confusion matrix for Intersection-Over-Union metrics. + + Formula: + + ```python + iou = true_positives / (true_positives + false_positives + false_negatives) + ``` + Intersection-Over-Union is a common evaluation metric for semantic image + segmentation. + + From IoUs of individual classes, the MeanIoU can be computed as the mean of + the individual IoUs. + + To compute IoUs, the predictions are accumulated in a confusion matrix, + weighted by `sample_weight` and the metric is then calculated from it. + + If `sample_weight` is `None`, weights default to 1. + Use `sample_weight` of 0 to mask values. + + Args: + num_classes: The possible number of labels the prediction task can have. + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + ignore_class: Optional integer. The ID of a class to be ignored during + metric computation. This is useful, for example, in segmentation + problems featuring a "void" class (commonly -1 or 255) in + segmentation maps. By default (`ignore_class=None`), all classes are + considered. + sparse_y_true: Whether labels are encoded using integers or + dense floating point vectors. If `False`, the `argmax` function + is used to determine each sample's most likely associated label. + sparse_y_pred: Whether predictions are encoded using integers or + dense floating point vectors. If `False`, the `argmax` function + is used to determine each sample's most likely associated label. + axis: (Optional) -1 is the dimension containing the logits. + Defaults to `-1`. + """ + + def __init__( + self, + num_classes, + name=None, + dtype=None, + ignore_class=None, + sparse_y_true=True, + sparse_y_pred=True, + axis=-1, + ): + # defaulting to int to avoid issues with confusion matrix + super().__init__(name=name, dtype=dtype or "int") + # Metric should be maximized during optimization. + self._direction = "up" + self.num_classes = num_classes + self.ignore_class = ignore_class + self.sparse_y_true = sparse_y_true + self.sparse_y_pred = sparse_y_pred + self.axis = axis + + self.total_cm = self.add_variable( + name="total_confusion_matrix", + shape=(num_classes, num_classes), + initializer=initializers.Zeros(), + dtype=self.dtype, + ) + + def update_state(self, y_true, y_pred, sample_weight=None): + """Accumulates the confusion matrix statistics. + + Args: + y_true: The ground truth values. + y_pred: The predicted values. + sample_weight: Optional weighting of each example. Can + be a `Tensor` whose rank is either 0, or the same as `y_true`, + and must be broadcastable to `y_true`. Defaults to `1`. + + Returns: + Update op. + """ + + if not self.sparse_y_true: + y_true = ops.argmax(y_true, axis=self.axis) + if not self.sparse_y_pred: + y_pred = ops.argmax(y_pred, axis=self.axis) + + y_true = ops.convert_to_tensor(y_true, dtype=self.dtype) + y_pred = ops.convert_to_tensor(y_pred, dtype=self.dtype) + + # Flatten the input if its rank > 1. + if len(y_pred.shape) > 1: + y_pred = ops.reshape(y_pred, [-1]) + + if len(y_true.shape) > 1: + y_true = ops.reshape(y_true, [-1]) + + if sample_weight is None: + sample_weight = 1 + else: + if ( + hasattr(sample_weight, "dtype") + and "float" in str(sample_weight.dtype) + and "int" in str(self.dtype) + ): + warnings.warn( + "You are passing weight as `float`, but dtype is `int`. " + "This may result in an incorrect weight due to type casting" + " Consider using integer weights." + ) + sample_weight = ops.convert_to_tensor(sample_weight, dtype=self.dtype) + + if len(sample_weight.shape) > 1: + sample_weight = ops.reshape(sample_weight, [-1]) + + sample_weight = ops.broadcast_to(sample_weight, ops.shape(y_true)) + + if self.ignore_class is not None: + ignore_class = ops.convert_to_tensor( + self.ignore_class, y_true.dtype + ) + valid_mask = ops.not_equal(y_true, ignore_class) + y_true = y_true * ops.cast(valid_mask, y_true.dtype) + y_pred = y_pred * ops.cast(valid_mask, y_pred.dtype) + if sample_weight is not None: + sample_weight = sample_weight * ops.cast( + valid_mask, sample_weight.dtype + ) + + y_pred = ops.cast(y_pred, dtype=self.dtype) + y_true = ops.cast(y_true, dtype=self.dtype) + sample_weight = ops.cast(sample_weight, dtype=self.dtype) + + current_cm = confusion_matrix( + y_true, + y_pred, + self.num_classes, + weights=sample_weight, + dtype=self.dtype, + ) + + return self.total_cm.assign(self.total_cm + current_cm) + + def reset_state(self): + self.total_cm.assign( + ops.zeros(self.total_cm.shape, dtype=self.total_cm.dtype) + ) + + +@keras_export("keras.metrics.IoU") +class IoU(_IoUBase): + """Computes the Intersection-Over-Union metric for specific target classes. + + Formula: + + ```python + iou = true_positives / (true_positives + false_positives + false_negatives) + ``` + Intersection-Over-Union is a common evaluation metric for semantic image + segmentation. + + To compute IoUs, the predictions are accumulated in a confusion matrix, + weighted by `sample_weight` and the metric is then calculated from it. + + If `sample_weight` is `None`, weights default to 1. + Use `sample_weight` of 0 to mask values. + + Note, this class first computes IoUs for all individual classes, then + returns the mean of IoUs for the classes that are specified by + `target_class_ids`. If `target_class_ids` has only one id value, the IoU of + that specific class is returned. + + Args: + num_classes: The possible number of labels the prediction task can have. + target_class_ids: A tuple or list of target class ids for which the + metric is returned. To compute IoU for a specific class, a list + (or tuple) of a single id value should be provided. + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + ignore_class: Optional integer. The ID of a class to be ignored during + metric computation. This is useful, for example, in segmentation + problems featuring a "void" class (commonly -1 or 255) in + segmentation maps. By default (`ignore_class=None`), all classes are + considered. + sparse_y_true: Whether labels are encoded using integers or + dense floating point vectors. If `False`, the `argmax` function + is used to determine each sample's most likely associated label. + sparse_y_pred: Whether predictions are encoded using integers or + dense floating point vectors. If `False`, the `argmax` function + is used to determine each sample's most likely associated label. + axis: (Optional) -1 is the dimension containing the logits. + Defaults to `-1`. + + Examples: + + >>> # cm = [[1, 1], + >>> # [1, 1]] + >>> # sum_row = [2, 2], sum_col = [2, 2], true_positives = [1, 1] + >>> # iou = true_positives / (sum_row + sum_col - true_positives)) + >>> # iou = [0.33, 0.33] + >>> m = keras.metrics.IoU(num_classes=2, target_class_ids=[0]) + >>> m.update_state([0, 0, 1, 1], [0, 1, 0, 1]) + >>> m.result() + 0.33333334 + + >>> m.reset_state() + >>> m.update_state([0, 0, 1, 1], [0, 1, 0, 1], + ... sample_weight=[0.3, 0.3, 0.3, 0.1]) + >>> # cm = [[0.3, 0.3], + >>> # [0.3, 0.1]] + >>> # sum_row = [0.6, 0.4], sum_col = [0.6, 0.4], + >>> # true_positives = [0.3, 0.1] + >>> # iou = [0.33, 0.14] + >>> m.result() + 0.33333334 + + Usage with `compile()` API: + + ```python + model.compile( + optimizer='sgd', + loss='mse', + metrics=[keras.metrics.IoU(num_classes=2, target_class_ids=[0])]) + ``` + """ + + def __init__( + self, + num_classes, + target_class_ids, + name=None, + dtype=None, + ignore_class=None, + sparse_y_true=True, + sparse_y_pred=True, + axis=-1, + ): + super().__init__( + name=name, + num_classes=num_classes, + ignore_class=ignore_class, + sparse_y_true=sparse_y_true, + sparse_y_pred=sparse_y_pred, + axis=axis, + dtype=dtype, + ) + if max(target_class_ids) >= num_classes: + raise ValueError( + f"Target class id {max(target_class_ids)} " + "is out of range, which is " + f"[{0}, {num_classes})." + ) + self.target_class_ids = list(target_class_ids) + + def result(self): + """Compute the intersection-over-union via the confusion matrix.""" + sum_over_row = ops.cast( + ops.sum(self.total_cm, axis=0), dtype=self.dtype + ) + sum_over_col = ops.cast( + ops.sum(self.total_cm, axis=1), dtype=self.dtype + ) + true_positives = ops.cast(ops.diag(self.total_cm), dtype=self.dtype) + + # sum_over_row + sum_over_col = + # 2 * true_positives + false_positives + false_negatives. + denominator = sum_over_row + sum_over_col - true_positives + + target_class_ids = ops.convert_to_tensor( + self.target_class_ids, dtype="int32" + ) + + # Only keep the target classes + true_positives = ops.take_along_axis( + true_positives, target_class_ids, axis=-1 + ) + denominator = ops.take_along_axis( + denominator, target_class_ids, axis=-1 + ) + denominator = ops.cast(denominator, dtype="float32") + + # If the denominator is 0, we need to ignore the class. + num_valid_entries = ops.sum( + ops.cast(ops.greater(denominator, 1e-9), dtype="float32") + ) + + iou = ops.divide(true_positives, denominator + backend.epsilon()) + + return ops.divide( + ops.sum(iou, axis=self.axis), num_valid_entries + backend.epsilon() + ) + + def get_config(self): + config = { + "num_classes": self.num_classes, + "target_class_ids": self.target_class_ids, + "ignore_class": self.ignore_class, + "sparse_y_true": self.sparse_y_true, + "sparse_y_pred": self.sparse_y_pred, + "axis": self.axis, + } + base_config = super().get_config() + return dict(list(base_config.items()) + list(config.items())) + + +@keras_export("keras.metrics.BinaryIoU") +class BinaryIoU(IoU): + """Computes the Intersection-Over-Union metric for class 0 and/or 1. + + Formula: + + ```python + iou = true_positives / (true_positives + false_positives + false_negatives) + ``` + Intersection-Over-Union is a common evaluation metric for semantic image + segmentation. + + To compute IoUs, the predictions are accumulated in a confusion matrix, + weighted by `sample_weight` and the metric is then calculated from it. + + If `sample_weight` is `None`, weights default to 1. + Use `sample_weight` of 0 to mask values. + + This class can be used to compute IoUs for a binary classification task + where the predictions are provided as logits. First a `threshold` is applied + to the predicted values such that those that are below the `threshold` are + converted to class 0 and those that are above the `threshold` are converted + to class 1. + + IoUs for classes 0 and 1 are then computed, the mean of IoUs for the classes + that are specified by `target_class_ids` is returned. + + Note: with `threshold=0`, this metric has the same behavior as `IoU`. + + Args: + target_class_ids: A tuple or list of target class ids for which the + metric is returned. Options are `[0]`, `[1]`, or `[0, 1]`. With + `[0]` (or `[1]`), the IoU metric for class 0 (or class 1, + respectively) is returned. With `[0, 1]`, the mean of IoUs for the + two classes is returned. + threshold: A threshold that applies to the prediction logits to convert + them to either predicted class 0 if the logit is below `threshold` + or predicted class 1 if the logit is above `threshold`. + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + + Example: + + >>> m = keras.metrics.BinaryIoU(target_class_ids=[0, 1], threshold=0.3) + >>> m.update_state([0, 1, 0, 1], [0.1, 0.2, 0.4, 0.7]) + >>> m.result() + 0.33333334 + + >>> m.reset_state() + >>> m.update_state([0, 1, 0, 1], [0.1, 0.2, 0.4, 0.7], + ... sample_weight=[0.2, 0.3, 0.4, 0.1]) + >>> # cm = [[0.2, 0.4], + >>> # [0.3, 0.1]] + >>> # sum_row = [0.6, 0.4], sum_col = [0.5, 0.5], + >>> # true_positives = [0.2, 0.1] + >>> # iou = [0.222, 0.125] + >>> m.result() + 0.17361112 + + Usage with `compile()` API: + + ```python + model.compile( + optimizer='sgd', + loss='mse', + metrics=[keras.metrics.BinaryIoU( + target_class_ids=[0], + threshold=0.5 + )] + ) + ``` + """ + + def __init__( + self, + target_class_ids=(0, 1), + threshold=0.5, + name=None, + dtype=None, + ): + super().__init__( + num_classes=2, + target_class_ids=target_class_ids, + name=name, + dtype=dtype, + ) + self.threshold = threshold + + def update_state(self, y_true, y_pred, sample_weight=None): + """Accumulates the confusion matrix statistics. + + Before the confusion matrix is updated, the predicted values are + thresholded to be: + 0 for values that are smaller than the `threshold` + 1 for values that are larger or equal to the `threshold` + + Args: + y_true: The ground truth values. + y_pred: The predicted values. + sample_weight: Optional weighting of each example. Can + be a `Tensor` whose rank is either 0, or the same as `y_true`, + and must be broadcastable to `y_true`. Defaults to `1`. + + Returns: + Update op. + """ + y_true = ops.convert_to_tensor(y_true, dtype=self.dtype) + # convert y_pred on float 32 and cast just after to dtype + y_pred = ops.convert_to_tensor(y_pred, dtype="float32") + y_pred = ops.cast(y_pred >= self.threshold, self.dtype) + return super().update_state(y_true, y_pred, sample_weight) + + def get_config(self): + return { + "target_class_ids": self.target_class_ids, + "threshold": self.threshold, + "name": self.name, + "dtype": self._dtype, + } + + +@keras_export("keras.metrics.MeanIoU") +class MeanIoU(IoU): + """Computes the mean Intersection-Over-Union metric. + + Formula: + + ```python + iou = true_positives / (true_positives + false_positives + false_negatives) + ``` + Intersection-Over-Union is a common evaluation metric for semantic image + segmentation. + + To compute IoUs, the predictions are accumulated in a confusion matrix, + weighted by `sample_weight` and the metric is then calculated from it. + + If `sample_weight` is `None`, weights default to 1. + Use `sample_weight` of 0 to mask values. + + Note that this class first computes IoUs for all individual classes, then + returns the mean of these values. + + Args: + num_classes: The possible number of labels the prediction task can have. + This value must be provided, since a confusion matrix of dimension = + [num_classes, num_classes] will be allocated. + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + ignore_class: Optional integer. The ID of a class to be ignored during + metric computation. This is useful, for example, in segmentation + problems featuring a "void" class (commonly -1 or 255) in + segmentation maps. By default (`ignore_class=None`), all classes are + considered. + sparse_y_true: Whether labels are encoded using integers or + dense floating point vectors. If `False`, the `argmax` function + is used to determine each sample's most likely associated label. + sparse_y_pred: Whether predictions are encoded using integers or + dense floating point vectors. If `False`, the `argmax` function + is used to determine each sample's most likely associated label. + axis: (Optional) The dimension containing the logits. Defaults to `-1`. + + + Example: + + >>> # cm = [[1, 1], + >>> # [1, 1]] + >>> # sum_row = [2, 2], sum_col = [2, 2], true_positives = [1, 1] + >>> # iou = true_positives / (sum_row + sum_col - true_positives)) + >>> # result = (1 / (2 + 2 - 1) + 1 / (2 + 2 - 1)) / 2 = 0.33 + >>> m = keras.metrics.MeanIoU(num_classes=2) + >>> m.update_state([0, 0, 1, 1], [0, 1, 0, 1]) + >>> m.result() + 0.33333334 + + >>> m.reset_state() + >>> m.update_state([0, 0, 1, 1], [0, 1, 0, 1], + ... sample_weight=[0.3, 0.3, 0.3, 0.1]) + >>> m.result().numpy() + 0.23809525 + + Usage with `compile()` API: + + ```python + model.compile( + optimizer='sgd', + loss='mse', + metrics=[keras.metrics.MeanIoU(num_classes=2)]) + ``` + """ + + def __init__( + self, + num_classes, + name=None, + dtype=None, + ignore_class=None, + sparse_y_true=True, + sparse_y_pred=True, + axis=-1, + ): + target_class_ids = list(range(num_classes)) + super().__init__( + name=name, + num_classes=num_classes, + target_class_ids=target_class_ids, + axis=axis, + dtype=dtype, + ignore_class=ignore_class, + sparse_y_true=sparse_y_true, + sparse_y_pred=sparse_y_pred, + ) + + def get_config(self): + return { + "num_classes": self.num_classes, + "name": self.name, + "dtype": self._dtype, + "ignore_class": self.ignore_class, + "sparse_y_true": self.sparse_y_true, + "sparse_y_pred": self.sparse_y_pred, + "axis": self.axis, + } + + +@keras_export("keras.metrics.OneHotIoU") +class OneHotIoU(IoU): + """Computes the Intersection-Over-Union metric for one-hot encoded labels. + + Formula: + + ```python + iou = true_positives / (true_positives + false_positives + false_negatives) + ``` + Intersection-Over-Union is a common evaluation metric for semantic image + segmentation. + + To compute IoUs, the predictions are accumulated in a confusion matrix, + weighted by `sample_weight` and the metric is then calculated from it. + + If `sample_weight` is `None`, weights default to 1. + Use `sample_weight` of 0 to mask values. + + This class can be used to compute IoU for multi-class classification tasks + where the labels are one-hot encoded (the last axis should have one + dimension per class). Note that the predictions should also have the same + shape. To compute the IoU, first the labels and predictions are converted + back into integer format by taking the argmax over the class axis. Then the + same computation steps as for the base `IoU` class apply. + + Note, if there is only one channel in the labels and predictions, this class + is the same as class `IoU`. In this case, use `IoU` instead. + + Also, make sure that `num_classes` is equal to the number of classes in the + data, to avoid a "labels out of bound" error when the confusion matrix is + computed. + + Args: + num_classes: The possible number of labels the prediction task can have. + target_class_ids: A tuple or list of target class ids for which the + metric is returned. To compute IoU for a specific class, a list + (or tuple) of a single id value should be provided. + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + ignore_class: Optional integer. The ID of a class to be ignored during + metric computation. This is useful, for example, in segmentation + problems featuring a "void" class (commonly -1 or 255) in + segmentation maps. By default (`ignore_class=None`), all classes are + considered. + sparse_y_pred: Whether predictions are encoded using integers or + dense floating point vectors. If `False`, the `argmax` function + is used to determine each sample's most likely associated label. + axis: (Optional) The dimension containing the logits. Defaults to `-1`. + + + Example: + + >>> y_true = np.array([[0, 0, 1], [1, 0, 0], [0, 1, 0], [1, 0, 0]]) + >>> y_pred = np.array([[0.2, 0.3, 0.5], [0.1, 0.2, 0.7], [0.5, 0.3, 0.1], + ... [0.1, 0.4, 0.5]]) + >>> sample_weight = [0.1, 0.2, 0.3, 0.4] + >>> m = keras.metrics.OneHotIoU(num_classes=3, target_class_ids=[0, 2]) + >>> m.update_state( + ... y_true=y_true, y_pred=y_pred, sample_weight=sample_weight) + >>> # cm = [[0, 0, 0.2+0.4], + >>> # [0.3, 0, 0], + >>> # [0, 0, 0.1]] + >>> # sum_row = [0.3, 0, 0.7], sum_col = [0.6, 0.3, 0.1] + >>> # true_positives = [0, 0, 0.1] + >>> # single_iou = true_positives / (sum_row + sum_col - true_positives)) + >>> # mean_iou = (0 / (0.3 + 0.6 - 0) + 0.1 / (0.7 + 0.1 - 0.1)) / 2 + >>> m.result() + 0.071 + + Usage with `compile()` API: + + ```python + model.compile( + optimizer='sgd', + loss='mse', + metrics=[keras.metrics.OneHotIoU( + num_classes=3, + target_class_id=[1] + )] + ) + ``` + """ + + def __init__( + self, + num_classes, + target_class_ids, + name=None, + dtype=None, + ignore_class=None, + sparse_y_pred=False, + axis=-1, + ): + super().__init__( + num_classes=num_classes, + target_class_ids=target_class_ids, + name=name, + dtype=dtype, + ignore_class=ignore_class, + sparse_y_true=False, + sparse_y_pred=sparse_y_pred, + axis=axis, + ) + + def get_config(self): + return { + "num_classes": self.num_classes, + "target_class_ids": self.target_class_ids, + "name": self.name, + "dtype": self._dtype, + "ignore_class": self.ignore_class, + "sparse_y_pred": self.sparse_y_pred, + "axis": self.axis, + } + + +@keras_export("keras.metrics.OneHotMeanIoU") +class OneHotMeanIoU(MeanIoU): + """Computes mean Intersection-Over-Union metric for one-hot encoded labels. + + Formula: + + ```python + iou = true_positives / (true_positives + false_positives + false_negatives) + ``` + Intersection-Over-Union is a common evaluation metric for semantic image + segmentation. + + To compute IoUs, the predictions are accumulated in a confusion matrix, + weighted by `sample_weight` and the metric is then calculated from it. + + If `sample_weight` is `None`, weights default to 1. + Use `sample_weight` of 0 to mask values. + + This class can be used to compute the mean IoU for multi-class + classification tasks where the labels are one-hot encoded (the last axis + should have one dimension per class). Note that the predictions should also + have the same shape. To compute the mean IoU, first the labels and + predictions are converted back into integer format by taking the argmax over + the class axis. Then the same computation steps as for the base `MeanIoU` + class apply. + + Note, if there is only one channel in the labels and predictions, this class + is the same as class `MeanIoU`. In this case, use `MeanIoU` instead. + + Also, make sure that `num_classes` is equal to the number of classes in the + data, to avoid a "labels out of bound" error when the confusion matrix is + computed. + + Args: + num_classes: The possible number of labels the prediction task can have. + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + ignore_class: Optional integer. The ID of a class to be ignored during + metric computation. This is useful, for example, in segmentation + problems featuring a "void" class (commonly -1 or 255) in + segmentation maps. By default (`ignore_class=None`), all classes are + considered. + sparse_y_pred: Whether predictions are encoded using natural numbers or + probability distribution vectors. If `False`, the `argmax` + function will be used to determine each sample's most likely + associated label. + axis: (Optional) The dimension containing the logits. Defaults to `-1`. + + + Example: + + >>> y_true = np.array([[0, 0, 1], [1, 0, 0], [0, 1, 0], [1, 0, 0]]) + >>> y_pred = np.array([[0.2, 0.3, 0.5], [0.1, 0.2, 0.7], [0.5, 0.3, 0.1], + ... [0.1, 0.4, 0.5]]) + >>> sample_weight = [0.1, 0.2, 0.3, 0.4] + >>> m = keras.metrics.OneHotMeanIoU(num_classes=3) + >>> m.update_state( + ... y_true=y_true, y_pred=y_pred, sample_weight=sample_weight) + >>> # cm = [[0, 0, 0.2+0.4], + >>> # [0.3, 0, 0], + >>> # [0, 0, 0.1]] + >>> # sum_row = [0.3, 0, 0.7], sum_col = [0.6, 0.3, 0.1] + >>> # true_positives = [0, 0, 0.1] + >>> # single_iou = true_positives / (sum_row + sum_col - true_positives)) + >>> # mean_iou = (0 + 0 + 0.1 / (0.7 + 0.1 - 0.1)) / 3 + >>> m.result() + 0.048 + + Usage with `compile()` API: + + ```python + model.compile( + optimizer='sgd', + loss='mse', + metrics=[keras.metrics.OneHotMeanIoU(num_classes=3)]) + ``` + """ + + def __init__( + self, + num_classes, + name=None, + dtype=None, + ignore_class=None, + sparse_y_pred=False, + axis=-1, + ): + super().__init__( + num_classes=num_classes, + axis=axis, + name=name, + dtype=dtype, + ignore_class=ignore_class, + sparse_y_true=False, + sparse_y_pred=sparse_y_pred, + ) + + def get_config(self): + return { + "num_classes": self.num_classes, + "name": self.name, + "dtype": self._dtype, + "ignore_class": self.ignore_class, + "sparse_y_pred": self.sparse_y_pred, + "axis": self.axis, + } diff --git a/keras/src/metrics/iou_metrics_test.py b/keras/src/metrics/iou_metrics_test.py new file mode 100644 index 000000000000..172c3b02f089 --- /dev/null +++ b/keras/src/metrics/iou_metrics_test.py @@ -0,0 +1,565 @@ +import numpy as np +import pytest + +from keras.src import layers +from keras.src import models +from keras.src import testing +from keras.src.metrics import iou_metrics as metrics +from keras.src.ops import convert_to_tensor + + +class IoUTest(testing.TestCase): + def test_config(self): + obj = metrics.IoU( + num_classes=2, target_class_ids=[1, 0], name="iou_class_1_0" + ) + self.assertEqual(obj.name, "iou_class_1_0") + self.assertEqual(obj.num_classes, 2) + self.assertEqual(obj.target_class_ids, [1, 0]) + + obj2 = metrics.IoU.from_config(obj.get_config()) + self.assertEqual(obj2.name, "iou_class_1_0") + self.assertEqual(obj2.num_classes, 2) + self.assertEqual(obj2.target_class_ids, [1, 0]) + + def test_unweighted(self): + y_pred = [0, 1, 0, 1] + y_true = [0, 0, 1, 1] + + obj = metrics.IoU(num_classes=2, target_class_ids=[0, 1]) + + result = obj(y_true, y_pred) + + # cm = [[1, 1], + # [1, 1]] + # sum_row = [2, 2], sum_col = [2, 2], true_positives = [1, 1] + # iou = true_positives / (sum_row + sum_col - true_positives)) + expected_result = (1 / (2 + 2 - 1) + 1 / (2 + 2 - 1)) / 2 + self.assertAllClose(result, expected_result, atol=1e-3) + + def test_weighted(self): + y_pred = np.array([0, 1, 0, 1], dtype=np.float32) + y_true = np.array([0, 0, 1, 1]) + sample_weight = np.array([0.2, 0.3, 0.4, 0.1]) + + obj = metrics.IoU( + num_classes=2, target_class_ids=[1, 0], dtype="float32" + ) + + result = obj(y_true, y_pred, sample_weight=sample_weight) + + # cm = [[0.2, 0.3], + # [0.4, 0.1]] + # sum_row = [0.6, 0.4], sum_col = [0.5, 0.5], true_positives = [0.2, + # 0.1] + # iou = true_positives / (sum_row + sum_col - true_positives)) + expected_result = ( + 0.1 / (0.4 + 0.5 - 0.1) + 0.2 / (0.6 + 0.5 - 0.2) + ) / 2 + self.assertAllClose(result, expected_result, atol=1e-3) + + def test_multi_dim_input(self): + y_pred = np.array([[0, 1], [0, 1]], dtype=np.float32) + y_true = np.array([[0, 0], [1, 1]]) + sample_weight = np.array([[0.2, 0.3], [0.4, 0.1]]) + + obj = metrics.IoU( + num_classes=2, target_class_ids=[0, 1], dtype="float32" + ) + + result = obj(y_true, y_pred, sample_weight=sample_weight) + + # cm = [[0.2, 0.3], + # [0.4, 0.1]] + # sum_row = [0.6, 0.4], sum_col = [0.5, 0.5], true_positives = [0.2, + # 0.1] + # iou = true_positives / (sum_row + sum_col - true_positives)) + expected_result = ( + 0.2 / (0.6 + 0.5 - 0.2) + 0.1 / (0.4 + 0.5 - 0.1) + ) / 2 + self.assertAllClose(result, expected_result, atol=1e-3) + + def test_zero_valid_entries(self): + obj = metrics.IoU(num_classes=2, target_class_ids=[0, 1]) + self.assertAllClose(obj.result(), 0, atol=1e-3) + + def test_zero_and_non_zero_entries(self): + y_pred = np.array([1], dtype=np.float32) + y_true = np.array([1]) + + obj = metrics.IoU(num_classes=2, target_class_ids=[0, 1]) + result = obj(y_true, y_pred) + + # cm = [[0, 0], + # [0, 1]] + # sum_row = [0, 1], sum_col = [0, 1], true_positives = [0, 1] + # iou = true_positives / (sum_row + sum_col - true_positives)) + expected_result = (1 / (1 + 1 - 1)) / 1 + self.assertAllClose(result, expected_result, atol=1e-3) + + @pytest.mark.requires_trainable_backend + def test_compilation(self): + m_obj = metrics.MeanIoU(num_classes=2, ignore_class=0) + model = models.Sequential( + [ + layers.Dense(2, activation="softmax"), + ] + ) + model.compile(optimizer="rmsprop", loss="mse", metrics=[m_obj]) + model.fit(np.array([[1.0, 1.0]]), np.array([[1.0, 0.0]])) + + +class BinaryIoUTest(testing.TestCase): + def test_config(self): + obj = metrics.BinaryIoU( + target_class_ids=[1, 0], threshold=0.1, name="iou_class_1_0" + ) + self.assertEqual(obj.name, "iou_class_1_0") + self.assertAlmostEqual(obj.threshold, 0.1) + self.assertEqual(obj.target_class_ids, [1, 0]) + + obj2 = metrics.BinaryIoU.from_config(obj.get_config()) + self.assertEqual(obj.name, "iou_class_1_0") + self.assertAlmostEqual(obj2.threshold, 0.1) + self.assertEqual(obj.target_class_ids, [1, 0]) + + def test_different_thresholds_weighted(self): + y_true = [0, 1, 0, 1] + y_pred = [0.1, 0.2, 0.4, 0.7] + + sample_weight = np.array([0.2, 0.3, 0.4, 0.1]) + # with threshold = 0.3, y_pred will be converted to [0, 0, 1, 1] + # cm = [[0.2, 0.4], + # [0.3, 0.1]] + # sum_row = [0.6, 0.4], sum_col = [0.5, 0.5], true_positives = [0.2, + # 0.1] + # iou = true_positives / (sum_row + sum_col - true_positives)) + expected_result = ( + 0.2 / (0.6 + 0.5 - 0.2) + 0.1 / (0.4 + 0.5 - 0.1) + ) / 2 + obj = metrics.BinaryIoU( + target_class_ids=[0, 1], threshold=0.3, dtype="float32" + ) + result = obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAllClose(result, expected_result, atol=1e-3) + + sample_weight = np.array([0.1, 0.2, 0.4, 0.3]) + # with threshold = 0.5, y_pred will be converted to [0, 0, 0, 1] + # cm = [[0.1+0.4, 0], + # [0.2, 0.3]] + # sum_row = [0.5, 0.5], sum_col = [0.7, 0.3], true_positives = [0.5, + # 0.3] + # iou = true_positives / (sum_row + sum_col - true_positives)) + expected_result = ( + 0.5 / (0.5 + 0.7 - 0.5) + 0.3 / (0.5 + 0.3 - 0.3) + ) / 2 + obj = metrics.BinaryIoU( + target_class_ids=[0, 1], threshold=0.5, dtype="float32" + ) + result = obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAllClose(result, expected_result, atol=1e-3) + + def test_different_thresholds_unweighted(self): + y_true = [0, 1, 0, 1] + y_pred = [0.1, 0.2, 0.4, 0.7] + + # with threshold = 0.3, y_pred will be converted to [0, 0, 1, 1] + # cm = [[1, 1], + # [1, 1]] + # sum_row = [2, 2], sum_col = [2, 2], true_positives = [1, 1] + # iou = true_positives / (sum_row + sum_col - true_positives)) + expected_result = (1 / (2 + 2 - 1) + 1 / (2 + 2 - 1)) / 2 + obj = metrics.BinaryIoU(target_class_ids=[0, 1], threshold=0.3) + result = obj(y_true, y_pred) + self.assertAllClose(result, expected_result, atol=1e-3) + + # with threshold = 0.5, y_pred will be converted to [0, 0, 0, 1] + # cm = [[2, 0], + # [1, 1]] + # sum_row = [2, 2], sum_col = [3, 1], true_positives = [2, 1] + # iou = true_positives / (sum_row + sum_col - true_positives)) + expected_result = (2 / (2 + 3 - 2) + 1 / (2 + 1 - 1)) / 2 + obj = metrics.BinaryIoU(target_class_ids=[0, 1], threshold=0.5) + result = obj(y_true, y_pred) + self.assertAllClose(result, expected_result, atol=1e-3) + + def test_multi_dim_input(self): + y_true = np.array([[0, 1], [0, 1]], dtype=np.float32) + y_pred = np.array([[0.1, 0.7], [0.9, 0.3]]) + threshold = 0.4 # y_pred will become [[0, 1], [1, 0]] + sample_weight = np.array([[0.2, 0.3], [0.4, 0.1]]) + # cm = [[0.2, 0.4], + # [0.1, 0.3]] + # sum_row = [0.6, 0.4], sum_col = [0.3, 0.7], true_positives = [0.2, + # 0.3] + # iou = true_positives / (sum_row + sum_col - true_positives)) + expected_result = ( + 0.2 / (0.6 + 0.3 - 0.2) + 0.3 / (0.4 + 0.7 - 0.3) + ) / 2 + obj = metrics.BinaryIoU( + target_class_ids=[0, 1], threshold=threshold, dtype="float32" + ) + result = obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAllClose(result, expected_result, atol=1e-3) + + def test_zero_valid_entries(self): + obj = metrics.BinaryIoU(target_class_ids=[0, 1]) + self.assertAllClose(obj.result(), 0, atol=1e-3) + + def test_zero_and_non_zero_entries(self): + y_pred = np.array([0.6], dtype=np.float32) + threshold = 0.5 + y_true = np.array([1]) + + obj = metrics.BinaryIoU(target_class_ids=[0, 1], threshold=threshold) + result = obj(y_true, y_pred) + + # cm = [[0, 0], + # [0, 1]] + # sum_row = [0, 1], sum_col = [0, 1], true_positives = [0, 1] + # iou = true_positives / (sum_row + sum_col - true_positives)) + expected_result = 1 / (1 + 1 - 1) + self.assertAllClose(result, expected_result, atol=1e-3) + + +class MeanIoUTest(testing.TestCase): + def test_config(self): + m_obj = metrics.MeanIoU(num_classes=2, name="mean_iou") + self.assertEqual(m_obj.name, "mean_iou") + self.assertEqual(m_obj.num_classes, 2) + + m_obj2 = metrics.MeanIoU.from_config(m_obj.get_config()) + self.assertEqual(m_obj2.name, "mean_iou") + self.assertEqual(m_obj2.num_classes, 2) + + def test_unweighted(self): + y_pred = [0, 1, 0, 1] + y_true = [0, 0, 1, 1] + + m_obj = metrics.MeanIoU(num_classes=2) + + result = m_obj(y_true, y_pred) + + # cm = [[1, 1], + # [1, 1]] + # sum_row = [2, 2], sum_col = [2, 2], true_positives = [1, 1] + # iou = true_positives / (sum_row + sum_col - true_positives)) + expected_result = (1 / (2 + 2 - 1) + 1 / (2 + 2 - 1)) / 2 + self.assertAllClose(result, expected_result, atol=1e-3) + + def test_unweighted_ignore_class_255(self): + y_pred = [0, 1, 1, 1] + y_true = [0, 1, 2, 255] + + m_obj = metrics.MeanIoU(num_classes=3, ignore_class=255) + + result = m_obj(y_true, y_pred) + + # cm = [[1, 0, 0], + # [0, 1, 0], + # [0, 1, 0]] + # sum_row = [1, 1, 1], sum_col = [1, 2, 0], true_positives = [1, 1, 0] + # iou = true_positives / (sum_row + sum_col - true_positives)) + expected_result = ( + 1 / (1 + 1 - 1) + 1 / (2 + 1 - 1) + 0 / (0 + 1 - 0) + ) / 3 + self.assertAllClose(result, expected_result, atol=1e-3) + + def test_unweighted_ignore_class_1(self): + y_pred = [0, 1, 1, 1] + y_true = [0, 1, 2, -1] + + m_obj = metrics.MeanIoU(num_classes=3, ignore_class=-1) + + result = m_obj(y_true, y_pred) + + # cm = [[1, 0, 0], + # [0, 1, 0], + # [0, 1, 0]] + # sum_row = [1, 1, 1], sum_col = [1, 2, 0], true_positives = [1, 1, 0] + # iou = true_positives / (sum_row + sum_col - true_positives)) + expected_result = ( + 1 / (1 + 1 - 1) + 1 / (2 + 1 - 1) + 0 / (0 + 1 - 0) + ) / 3 + self.assertAllClose(result, expected_result, atol=1e-3) + + def test_weighted(self): + y_pred = np.array([0, 1, 0, 1], dtype=np.float32) + y_true = np.array([0, 0, 1, 1]) + sample_weight = np.array([0.2, 0.3, 0.4, 0.1]) + + m_obj = metrics.MeanIoU(num_classes=2, dtype="float32") + + result = m_obj(y_true, y_pred, sample_weight=sample_weight) + + # cm = [[0.2, 0.3], + # [0.4, 0.1]] + # sum_row = [0.6, 0.4], sum_col = [0.5, 0.5], true_positives = [0.2, + # 0.1] + # iou = true_positives / (sum_row + sum_col - true_positives)) + expected_result = ( + 0.2 / (0.6 + 0.5 - 0.2) + 0.1 / (0.4 + 0.5 - 0.1) + ) / 2 + self.assertAllClose(result, expected_result, atol=1e-3) + + def test_weighted_ignore_class_1(self): + y_pred = np.array([0, 1, 0, 1], dtype=np.float32) + y_true = np.array([0, 0, 1, -1]) + sample_weight = np.array([0.2, 0.3, 0.4, 0.1]) + + m_obj = metrics.MeanIoU(num_classes=2, ignore_class=-1, dtype="float32") + + result = m_obj(y_true, y_pred, sample_weight=sample_weight) + + # cm = [[0.2, 0.3], + # [0.4, 0.0]] + # sum_row = [0.6, 0.3], sum_col = [0.5, 0.4], true_positives = [0.2, + # 0.0] + # iou = true_positives / (sum_row + sum_col - true_positives)) + expected_result = ( + 0.2 / (0.6 + 0.5 - 0.2) + 0.0 / (0.3 + 0.4 - 0.0) + ) / 2 + self.assertAllClose(result, expected_result, atol=1e-3) + + def test_multi_dim_input(self): + y_pred = np.array([[0, 1], [0, 1]], dtype=np.float32) + y_true = np.array([[0, 0], [1, 1]]) + sample_weight = np.array([[0.2, 0.3], [0.4, 0.1]]) + + m_obj = metrics.MeanIoU(num_classes=2, dtype="float32") + + result = m_obj(y_true, y_pred, sample_weight=sample_weight) + + # cm = [[0.2, 0.3], + # [0.4, 0.1]] + # sum_row = [0.6, 0.4], sum_col = [0.5, 0.5], true_positives = [0.2, + # 0.1] + # iou = true_positives / (sum_row + sum_col - true_positives)) + expected_result = ( + 0.2 / (0.6 + 0.5 - 0.2) + 0.1 / (0.4 + 0.5 - 0.1) + ) / 2 + self.assertAllClose(result, expected_result, atol=1e-3) + + def test_zero_valid_entries(self): + m_obj = metrics.MeanIoU(num_classes=2) + self.assertAllClose(m_obj.result(), 0, atol=1e-3) + + def test_zero_and_non_zero_entries(self): + y_pred = np.array([1], dtype=np.float32) + y_true = np.array([1]) + + m_obj = metrics.MeanIoU(num_classes=2) + result = m_obj(y_true, y_pred) + + # cm = [[0, 0], + # [0, 1]] + # sum_row = [0, 1], sum_col = [0, 1], true_positives = [0, 1] + # iou = true_positives / (sum_row + sum_col - true_positives)) + expected_result = (0 + 1 / (1 + 1 - 1)) / 1 + self.assertAllClose(result, expected_result, atol=1e-3) + + @staticmethod + def _confusion_matrix(y_true, y_pred, num_classes): + """ + Creates a confusion matrix as a numpy array using vectorized operations. + + Parameters: + - y_true: array-like, true class labels. + - y_pred: array-like, predicted class labels. + - num_classes: int, number of classes. + + Returns: + - conf_matrix: np.ndarray, confusion matrix of shape (num_classes, + num_classes). + """ + # Map pairs of (y_true, y_pred) to indices in the confusion matrix + indices = y_true * num_classes + y_pred + # Count occurrences of each index + conf_matrix = np.bincount(indices, minlength=num_classes * num_classes) + # Reshape the flat array into a 2D confusion matrix + conf_matrix = conf_matrix.reshape((num_classes, num_classes)) + return conf_matrix + + @staticmethod + def _get_big_chunk(dtype): + np.random.seed(14) + all_y_true = np.random.choice([0, 1, 2], size=(10, 530, 530)) + # Generate random probabilities for each channel + random_probs = np.random.rand(10, 530, 530, 3) + # Normalize to ensure the last dimension sums to 1 + all_y_pred = random_probs / random_probs.sum(axis=-1, keepdims=True) + # Convert predictions to class indices + all_y_pred_arg = np.argmax(all_y_pred, axis=-1) + mean_iou_metric = metrics.MeanIoU(num_classes=3, dtype=dtype) + conf_matrix_start_point = np.array( + [ + [18729664, 18728760, 18731196], + [18727297, 18726105, 18728071], + [18727917, 18717835, 18723155], + ] + ) + mean_iou_metric.total_cm = mean_iou_metric.add_variable( + name="total_confusion_matrix", + shape=(3, 3), + initializer=convert_to_tensor(conf_matrix_start_point), + dtype=dtype or "int", + ) + mean_iou_metric.update_state(all_y_true, all_y_pred_arg) + tmp_true = np.reshape(all_y_true, -1) + tmp_pred = np.reshape(all_y_pred_arg, -1) + return ( + all_y_true, + all_y_pred_arg, + mean_iou_metric, + tmp_true, + tmp_pred, + conf_matrix_start_point, + ) + + def test_big_chunk(self): + # Init. process with dtype=None which will default to int + ( + all_y_true, + all_y_pred_arg, + mean_iou_metric_all, + tmp_true, + tmp_pred, + conf_matrix_start_point, + ) = self._get_big_chunk(dtype=None) + conf_matrix_from_keras = np.array(mean_iou_metric_all.total_cm) + # Validate confusion matrices and results + conf_matrix_manual = ( + self._confusion_matrix(tmp_true, tmp_pred, 3) + + conf_matrix_start_point + ) + self.assertTrue( + np.array_equal(conf_matrix_from_keras, conf_matrix_manual), + msg="Confusion matrices do not match!", + ) + # Now same but with float32 dtype, in here the confusion matrix + # should not match. Likely this can be removed + ( + all_y_true, + all_y_pred_arg, + mean_iou_metric_all, + tmp_true, + tmp_pred, + conf_matrix_start_point, + ) = self._get_big_chunk(dtype="float32") + conf_matrix_from_keras = np.array(mean_iou_metric_all.total_cm) + # Validate confusion matrices and results + conf_matrix_manual = ( + self._confusion_matrix(tmp_true, tmp_pred, 3) + + conf_matrix_start_point + ) + self.assertFalse( + np.array_equal(conf_matrix_from_keras, conf_matrix_manual), + msg="Confusion matrices match, but they should not!", + ) + + def test_user_warning_float_weight(self): + y_pred = [0, 1, 1, 1] + y_true = [0, 1, 1, 0] + m_obj = metrics.MeanIoU(num_classes=3) + with pytest.warns(Warning, match=r"weight.*float.*int.*casting"): + m_obj(y_true, y_pred, sample_weight=np.array([0.2, 0.3, 0.4, 0.1])) + + +class OneHotIoUTest(testing.TestCase): + def test_unweighted(self): + y_true = np.array([[0, 0, 1], [1, 0, 0], [0, 1, 0], [1, 0, 0]]) + # y_true will be converted to [2, 0, 1, 0] + y_pred = np.array( + [[0.2, 0.3, 0.5], [0.1, 0.2, 0.7], [0.5, 0.3, 0.1], [0.1, 0.4, 0.5]] + ) + # y_pred will be converted to [2, 2, 0, 2] + # cm = [[0, 0, 2], + # [1, 0, 0], + # [0, 0, 1] + # sum_row = [1, 0, 3], sum_col = [2, 1, 1], true_positives = [0, 0, 1] + # iou = true_positives / (sum_row + sum_col - true_positives)) + expected_result = (0 / (1 + 2 - 0) + 1 / (3 + 1 - 1)) / 2 + obj = metrics.OneHotIoU(num_classes=3, target_class_ids=[0, 2]) + result = obj(y_true, y_pred) + self.assertAllClose(result, expected_result, atol=1e-3) + + def test_weighted(self): + y_true = np.array([[0, 0, 1], [1, 0, 0], [0, 1, 0], [1, 0, 0]]) + # y_true will be converted to [2, 0, 1, 0] + y_pred = np.array( + [[0.2, 0.3, 0.5], [0.1, 0.2, 0.7], [0.5, 0.3, 0.1], [0.1, 0.4, 0.5]] + ) + # y_pred will be converted to [2, 2, 0, 2] + sample_weight = [0.1, 0.2, 0.3, 0.4] + # cm = [[0, 0, 0.2+0.4], + # [0.3, 0, 0], + # [0, 0, 0.1]] + # sum_row = [0.3, 0, 0.7], sum_col = [0.6, 0.3, 0.1] + # true_positives = [0, 0, 0.1] + # iou = true_positives / (sum_row + sum_col - true_positives)) + expected_result = (0 / (0.3 + 0.6 - 0) + 0.1 / (0.7 + 0.1 - 0.1)) / 2 + obj = metrics.OneHotIoU( + num_classes=3, target_class_ids=[0, 2], dtype="float32" + ) + result = obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAllClose(result, expected_result, atol=1e-3) + + +class OneHotMeanIoUTest(testing.TestCase): + def test_unweighted(self): + y_true = np.array([[0, 0, 1], [1, 0, 0], [0, 1, 0], [1, 0, 0]]) + # y_true will be converted to [2, 0, 1, 0] + y_pred = np.array( + [[0.2, 0.3, 0.5], [0.1, 0.2, 0.7], [0.5, 0.3, 0.1], [0.1, 0.4, 0.5]] + ) + # y_pred will be converted to [2, 2, 0, 2] + # cm = [[0, 0, 2], + # [1, 0, 0], + # [0, 0, 1] + # sum_row = [1, 0, 3], sum_col = [2, 1, 1], true_positives = [0, 0, 1] + # iou = true_positives / (sum_row + sum_col - true_positives)) + expected_result = (0 + 0 + 1 / (3 + 1 - 1)) / 3 + obj = metrics.OneHotMeanIoU(num_classes=3) + result = obj(y_true, y_pred) + self.assertAllClose(result, expected_result, atol=1e-3) + + def test_weighted(self): + y_true = np.array( + [ + [0, 0, 1], + [1, 0, 0], + [0, 1, 0], + [1, 0, 0], + [1, 0, 0], + ] + ) + # y_true will be converted to [2, 0, 1, 0, 0] + y_pred = np.array( + [ + [0.2, 0.3, 0.5], + [0.1, 0.2, 0.7], + [0.5, 0.3, 0.1], + [0.1, 0.4, 0.5], + [0.6, 0.2, 0.2], + ] + ) + # y_pred will be converted to [2, 2, 0, 2, 0] + sample_weight = [0.1, 0.2, 0.3, 0.3, 0.1] + # cm = [[0.1, 0, 0.2+0.3], + # [0.3, 0, 0], + # [0, 0, 0.1]] + # sum_row = [0.4, 0, 0.6], sum_col = [0.6, 0.3, 0.1] + # true_positives = [0.1, 0, 0.1] + # iou = true_positives / (sum_row + sum_col - true_positives)) + expected_result = ( + 0.1 / (0.4 + 0.6 - 0.1) + 0 + 0.1 / (0.6 + 0.1 - 0.1) + ) / 3 + obj = metrics.OneHotMeanIoU(num_classes=3, dtype="float32") + result = obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAllClose(result, expected_result, atol=1e-3) + + # Check same result with int weights + sample_weight_int = [1, 2, 3, 3, 1] + obj_int = metrics.OneHotMeanIoU(num_classes=3) + result_int = obj_int(y_true, y_pred, sample_weight=sample_weight_int) + self.assertAllClose(result_int, expected_result, atol=1e-3) diff --git a/keras/src/metrics/metric.py b/keras/src/metrics/metric.py new file mode 100644 index 000000000000..eb777c943907 --- /dev/null +++ b/keras/src/metrics/metric.py @@ -0,0 +1,254 @@ +from keras.src import backend +from keras.src import dtype_policies +from keras.src import initializers +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.saving.keras_saveable import KerasSaveable +from keras.src.utils.naming import auto_name +from keras.src.utils.tracking import Tracker + + +@keras_export(["keras.Metric", "keras.metrics.Metric"]) +class Metric(KerasSaveable): + """Encapsulates metric logic and state. + + Args: + name: Optional name for the metric instance. + dtype: The dtype of the metric's computations. Defaults to `None`, which + means using `keras.backend.floatx()`. `keras.backend.floatx()` is a + `"float32"` unless set to different value + (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is + provided, then the `compute_dtype` will be utilized. + + Example: + + ```python + m = SomeMetric(...) + for input in ...: + m.update_state(input) + print('Final result: ', m.result()) + ``` + + Usage with `compile()` API: + + ```python + model = keras.Sequential() + model.add(keras.layers.Dense(64, activation='relu')) + model.add(keras.layers.Dense(64, activation='relu')) + model.add(keras.layers.Dense(10, activation='softmax')) + + model.compile(optimizer=keras.optimizers.RMSprop(0.01), + loss=keras.losses.CategoricalCrossentropy(), + metrics=[keras.metrics.CategoricalAccuracy()]) + + data = np.random.random((1000, 32)) + labels = np.random.random((1000, 10)) + + model.fit(data, labels, epochs=10) + ``` + + To be implemented by subclasses: + + * `__init__()`: All state variables should be created in this method by + calling `self.add_variable()` like: `self.var = self.add_variable(...)` + * `update_state()`: Has all updates to the state variables like: + `self.var.assign(...)`. + * `result()`: Computes and returns a scalar value or a dict of scalar values + for the metric from the state variables. + + Example subclass implementation: + + ```python + class BinaryTruePositives(Metric): + + def __init__(self, name='binary_true_positives', **kwargs): + super().__init__(name=name, **kwargs) + self.true_positives = self.add_variable( + shape=(), + initializer='zeros', + name='true_positives' + ) + + def update_state(self, y_true, y_pred, sample_weight=None): + y_true = ops.cast(y_true, "bool") + y_pred = ops.cast(y_pred, "bool") + + values = ops.logical_and( + ops.equal(y_true, True), ops.equal(y_pred, True)) + values = ops.cast(values, self.dtype) + if sample_weight is not None: + sample_weight = ops.cast(sample_weight, self.dtype) + sample_weight = ops.broadcast_to( + sample_weight, ops.shape(values) + ) + values = ops.multiply(values, sample_weight) + self.true_positives.assign(self.true_positives + ops.sum(values)) + + def result(self): + return self.true_positives + ``` + """ + + def __init__(self, dtype=None, name=None): + self.name = name or auto_name(self.__class__.__name__) + self._dtype_policy = dtype_policies.get(dtype or backend.floatx()) + self._dtype = self._dtype_policy.compute_dtype + self._metrics = [] + self._variables = [] + self._tracker = Tracker( + { + "variables": ( + lambda x: isinstance(x, backend.Variable), + self._variables, + ), + "metrics": (lambda x: isinstance(x, Metric), self._metrics), + } + ) + + def reset_state(self): + """Reset all of the metric state variables. + + This function is called between epochs/steps, + when a metric is evaluated during training. + """ + for v in self.variables: + v.assign(ops.zeros(v.shape, dtype=v.dtype)) + + def update_state(self, *args, **kwargs): + """Accumulate statistics for the metric.""" + raise NotImplementedError + + def stateless_update_state(self, metric_variables, *args, **kwargs): + if len(metric_variables) != len(self.variables): + raise ValueError( + "Argument `metric_variables` must be a list of tensors " + f"corresponding 1:1 to {self.__class__.__name__}().variables. " + f"Received list with length {len(metric_variables)}, but " + f"expected {len(self.variables)} variables." + ) + # Gather variable mapping + mapping = list(zip(self.variables, metric_variables)) + + # Call in stateless scope + with backend.StatelessScope(state_mapping=mapping) as scope: + self.update_state(*args, **kwargs) + + # Gather updated variables + metric_variables = [] + for v in self.variables: + new_v = scope.get_current_value(v) + if new_v is not None: + metric_variables.append(new_v) + else: + metric_variables.append(v) + return metric_variables + + def result(self): + """Compute the current metric value. + + Returns: + A scalar tensor, or a dictionary of scalar tensors. + """ + raise NotImplementedError + + def stateless_result(self, metric_variables): + if len(metric_variables) != len(self.variables): + raise ValueError( + "Argument `metric_variables` must be a list of tensors " + f"corresponding 1:1 to {self.__class__.__name__}().variables. " + f"Received list with length {len(metric_variables)}, but " + f"expected {len(self.variables)} variables." + ) + # Gather variable mapping + mapping = list(zip(self.variables, metric_variables)) + + # Call in stateless scope + with backend.StatelessScope(state_mapping=mapping): + res = self.result() + return res + + def stateless_reset_state(self): + # Call in stateless scope + with backend.StatelessScope() as scope: + self.reset_state() + + # Gather updated variables + metric_variables = [] + for v in self.variables: + new_v = scope.get_current_value(v) + if new_v is not None: + metric_variables.append(new_v) + else: + metric_variables.append(v) + return metric_variables + + @property + def dtype(self): + return self._dtype + + def _obj_type(self): + return "Metric" + + def add_variable( + self, shape, initializer, dtype=None, aggregation="sum", name=None + ): + self._check_super_called() + with backend.name_scope(self.name.replace("/", ">"), caller=self): + initializer = initializers.get(initializer) + variable = backend.Variable( + initializer=initializer, + shape=shape, + dtype=dtype, + trainable=False, + aggregation=aggregation, + synchronization="on_read", + name=name, + ) + # Prevent double-tracking + self._tracker.add_to_store("variables", variable) + return variable + + def add_weight(self, shape=(), initializer=None, dtype=None, name=None): + # Backwards compatibility alias + return self.add_variable( + shape=shape, initializer=initializer, dtype=dtype, name=name + ) + + @property + def variables(self): + variables = list(self._variables) + for metric in self._metrics: + variables.extend(metric.variables) + return variables + + def __call__(self, *args, **kwargs): + self._check_super_called() + self.update_state(*args, **kwargs) + return self.result() + + def get_config(self): + """Return the serializable config of the metric.""" + return {"name": self.name, "dtype": self.dtype} + + @classmethod + def from_config(cls, config): + return cls(**config) + + def __setattr__(self, name, value): + # Track Variables, Layers, Metrics + if hasattr(self, "_tracker"): + value = self._tracker.track(value) + return super().__setattr__(name, value) + + def _check_super_called(self): + if not hasattr(self, "_tracker"): + raise RuntimeError( + "You forgot to call `super().__init__()` " + "in the `__init__()` method. Go add it!" + ) + + def __repr__(self): + return f"<{self.__class__.__name__} name={self.name}>" + + def __str__(self): + return self.__repr__() diff --git a/keras/src/metrics/metric_test.py b/keras/src/metrics/metric_test.py new file mode 100644 index 000000000000..28903cb56831 --- /dev/null +++ b/keras/src/metrics/metric_test.py @@ -0,0 +1,257 @@ +import pickle + +import numpy as np + +from keras.src import backend +from keras.src import dtype_policies +from keras.src import initializers +from keras.src import metrics as metrics_module +from keras.src import ops +from keras.src import testing +from keras.src.metrics.metric import Metric + + +class ExampleMetric(Metric): + def __init__(self, name="mean_square_error", dtype=None): + super().__init__(name=name, dtype=dtype) + self.sum = self.add_variable( + name="sum", shape=(), initializer=initializers.Zeros() + ) + self.total = self.add_variable( + name="total", + shape=(), + initializer=initializers.Zeros(), + dtype="int32", + ) + + def update_state(self, y_true, y_pred): + y_true = ops.convert_to_tensor(y_true, dtype=self.dtype) + y_pred = ops.convert_to_tensor(y_pred, dtype=self.dtype) + sum = ops.sum((y_true - y_pred) ** 2) + self.sum.assign(self.sum + sum) + batch_size = ops.shape(y_true)[0] + self.total.assign(self.total + batch_size) + + def result(self): + _sum = ops.cast(self.sum, dtype=self.dtype) + _total = ops.cast(self.total, dtype=self.dtype) + _epsilon = ops.cast(backend.epsilon(), dtype=self.dtype) + return _sum / (_total + _epsilon) + + def reset_state(self): + self.sum.assign(0) + self.total.assign(0) + + +class MetricTest(testing.TestCase): + def setUp(self): + self._global_dtype_policy = dtype_policies.dtype_policy.dtype_policy() + self._floatx = backend.floatx() + return super().setUp() + + def tearDown(self): + dtype_policies.dtype_policy.set_dtype_policy(self._global_dtype_policy) + backend.set_floatx(self._floatx) + return super().tearDown() + + def test_end_to_end_flow(self): + metric = ExampleMetric(name="mse") + self.assertEqual(metric.name, "mse") + self.assertEqual(len(metric.variables), 2) + + num_samples = 20 + y_true = np.random.random((num_samples, 3)) + y_pred = np.random.random((num_samples, 3)) + batch_size = 8 + for b in range(0, num_samples // batch_size + 1): + y_true_batch = y_true[b * batch_size : (b + 1) * batch_size] + y_pred_batch = y_pred[b * batch_size : (b + 1) * batch_size] + metric.update_state(y_true_batch, y_pred_batch) + + self.assertAllClose(metric.total, 20) + result = metric.result() + self.assertAllClose( + result, np.sum((y_true - y_pred) ** 2) / num_samples + ) + metric.reset_state() + self.assertEqual(metric.result(), 0.0) + + def test_stateless_update_state(self): + metric = ExampleMetric(name="mse") + self.assertEqual(len(metric.variables), 2) + original_variable_values = ( + metric.variables[0].numpy(), + metric.variables[1].numpy(), + ) + + num_samples = 20 + y_true = np.random.random((num_samples, 3)) + y_pred = np.random.random((num_samples, 3)) + batch_size = 8 + metric_variables = metric.variables + for b in range(0, num_samples // batch_size + 1): + y_true_batch = y_true[b * batch_size : (b + 1) * batch_size] + y_pred_batch = y_pred[b * batch_size : (b + 1) * batch_size] + metric_variables = metric.stateless_update_state( + metric_variables, y_true_batch, y_pred_batch + ) + + self.assertAllClose(metric.variables[0], original_variable_values[0]) + self.assertAllClose(metric.variables[1], original_variable_values[1]) + metric.variables[0].assign(metric_variables[0]) + metric.variables[1].assign(metric_variables[1]) + self.assertAllClose(metric.total, 20) + result = metric.result() + self.assertAllClose( + result, np.sum((y_true - y_pred) ** 2) / num_samples + ) + + if backend.backend() == "jax": + # Check no side effects. + import jax + + @jax.jit + def update(metric_variables, y_true_batch, y_pred_batch): + metric_variables = metric.stateless_update_state( + metric_variables, y_true_batch, y_pred_batch + ) + + update(metric_variables, y_true_batch, y_pred_batch) + + def test_stateless_result(self): + metric = ExampleMetric(name="mse") + res = metric.stateless_result([ops.ones(()) * 12, ops.ones(()) * 3]) + self.assertAllClose(res, 4.0) + + def test_stateless_reset_state(self): + metric = ExampleMetric(name="mse") + num_samples = 20 + y_true = np.random.random((num_samples, 3)) + y_pred = np.random.random((num_samples, 3)) + metric.update_state(y_true, y_pred) + vars = metric.stateless_reset_state() + self.assertLen(vars, 2) + self.assertEqual(vars[0], 0) + self.assertEqual(vars[1], 0) + + def test_variable_tracking(self): + # In list + metric = ExampleMetric(name="mse") + metric.more_vars = [backend.Variable(0.0), backend.Variable(1.0)] + self.assertEqual(len(metric.variables), 4) + + # In dict + metric = ExampleMetric(name="mse") + metric.more_vars = { + "a": backend.Variable(0.0), + "b": backend.Variable(1.0), + } + self.assertEqual(len(metric.variables), 4) + + # In nested structured + metric = ExampleMetric(name="mse") + metric.more_vars = {"a": [backend.Variable(0.0), backend.Variable(1.0)]} + self.assertEqual(len(metric.variables), 4) + + def test_submetric_tracking(self): + # Plain attr + metric = ExampleMetric(name="mse") + metric.submetric = ExampleMetric(name="submse") + self.assertEqual(len(metric.variables), 4) + + # In list + metric = ExampleMetric(name="mse") + metric.submetrics = [ + ExampleMetric(name="submse1"), + ExampleMetric(name="submse2"), + ] + self.assertEqual(len(metric.variables), 6) + + # In dict + metric = ExampleMetric(name="mse") + metric.submetrics = { + "1": ExampleMetric(name="submse1"), + "2": ExampleMetric(name="submse2"), + } + self.assertEqual(len(metric.variables), 6) + + # Two levels deep + metric = ExampleMetric(name="mse") + metric.submetric = ExampleMetric(name="submse") + metric.submetric.submetric = ExampleMetric(name="subsubmse") + self.assertEqual(len(metric.variables), 6) + + def test_serialization(self): + self.run_class_serialization_test( + ExampleMetric(name="mse"), + custom_objects={"ExampleMetric": ExampleMetric}, + ) + + def test_pickle(self): + metric = metrics_module.get("mse") + reloaded = pickle.loads(pickle.dumps(metric)) + self.assertIsInstance(reloaded, metrics_module.MeanSquaredError) + + def test_get_method(self): + metric = metrics_module.get("mse") + self.assertIsInstance(metric, metrics_module.MeanSquaredError) + + metric = metrics_module.get("mean_squared_error") + self.assertIsInstance(metric, metrics_module.MeanSquaredError) + + metric = metrics_module.get("categorical_accuracy") + self.assertIsInstance(metric, metrics_module.CategoricalAccuracy) + + metric = metrics_module.get(None) + self.assertEqual(metric, None) + + with self.assertRaises(ValueError): + metrics_module.get("typo") + + def test_dtype_arg(self): + metric = ExampleMetric(name="mse", dtype="float16") + self.assertEqual(metric.name, "mse") + self.assertEqual(len(metric.variables), 2) + + num_samples = 10 + y_true = np.random.random((num_samples, 3)) + y_pred = np.random.random((num_samples, 3)) + metric.update_state(y_true, y_pred) + result = metric.result() + self.assertAllClose( + result, np.sum((y_true - y_pred) ** 2) / num_samples, atol=1e-3 + ) + self.assertDType(result, "float16") + + # Test DTypePolicy for `dtype` argument + metric = ExampleMetric( + dtype=dtype_policies.DTypePolicy("mixed_float16") + ) + metric.update_state(y_true, y_pred) + metric.update_state(y_true, y_pred) + result = metric.result() + self.assertAllClose( + result, np.sum((y_true - y_pred) ** 2) / num_samples, atol=1e-3 + ) + self.assertDType(result, "float16") + + # `dtype` setter should raise AttributeError + with self.assertRaises(AttributeError): + metric.dtype = "bfloat16" + + def test_default_dtype(self): + y_true = np.random.random((10, 3)) + y_pred = np.random.random((10, 3)) + + # Defaults to `keras.config.floatx()` not global `dtype_policy` + dtype_policies.dtype_policy.set_dtype_policy("mixed_float16") + metric = ExampleMetric() + metric.update_state(y_true, y_pred) + result = metric.result() + self.assertDType(result, "float32") + + backend.set_floatx("float16") + metric = ExampleMetric() + metric.update_state(y_true, y_pred) + result = metric.result() + self.assertDType(result, backend.floatx()) diff --git a/keras/src/metrics/metrics_utils.py b/keras/src/metrics/metrics_utils.py new file mode 100644 index 000000000000..d6f6df61d097 --- /dev/null +++ b/keras/src/metrics/metrics_utils.py @@ -0,0 +1,686 @@ +from enum import Enum + +import numpy as np + +from keras.src import backend +from keras.src import ops +from keras.src.losses.loss import squeeze_or_expand_to_same_rank +from keras.src.utils.python_utils import to_list + +NEG_INF = -1e10 + + +def assert_thresholds_range(thresholds): + if thresholds is not None: + invalid_thresholds = [ + t for t in thresholds if t is None or t < 0 or t > 1 + ] + if invalid_thresholds: + raise ValueError( + "Threshold values must be in [0, 1]. " + f"Received: {invalid_thresholds}" + ) + + +def parse_init_thresholds(thresholds, default_threshold=0.5): + if thresholds is not None: + assert_thresholds_range(to_list(thresholds)) + thresholds = to_list( + default_threshold if thresholds is None else thresholds + ) + return thresholds + + +class ConfusionMatrix(Enum): + TRUE_POSITIVES = "tp" + FALSE_POSITIVES = "fp" + TRUE_NEGATIVES = "tn" + FALSE_NEGATIVES = "fn" + + +class AUCCurve(Enum): + """Type of AUC Curve (ROC or PR).""" + + ROC = "ROC" + PR = "PR" + PRGAIN = "PRGAIN" + + @staticmethod + def from_str(key): + if key in ("pr", "PR"): + return AUCCurve.PR + elif key in ("roc", "ROC"): + return AUCCurve.ROC + elif key in ("prgain", "PRGAIN"): + return AUCCurve.PRGAIN + else: + raise ValueError( + f'Invalid AUC curve value: "{key}". ' + 'Expected values are ["PR", "ROC", "PRGAIN"]' + ) + + +class AUCSummationMethod(Enum): + """Type of AUC summation method. + + https://en.wikipedia.org/wiki/Riemann_sum) + + Contains the following values: + * 'interpolation': Applies mid-point summation scheme for `ROC` curve. For + `PR` curve, interpolates (true/false) positives but not the ratio that is + precision (see Davis & Goadrich 2006 for details). + * 'minoring': Applies left summation for increasing intervals and right + summation for decreasing intervals. + * 'majoring': Applies right summation for increasing intervals and left + summation for decreasing intervals. + """ + + INTERPOLATION = "interpolation" + MAJORING = "majoring" + MINORING = "minoring" + + @staticmethod + def from_str(key): + if key in ("interpolation", "Interpolation"): + return AUCSummationMethod.INTERPOLATION + elif key in ("majoring", "Majoring"): + return AUCSummationMethod.MAJORING + elif key in ("minoring", "Minoring"): + return AUCSummationMethod.MINORING + else: + raise ValueError( + f'Invalid AUC summation method value: "{key}". ' + 'Expected values are ["interpolation", "majoring", "minoring"]' + ) + + +def _update_confusion_matrix_variables_optimized( + variables_to_update, + y_true, + y_pred, + thresholds, + multi_label=False, + sample_weights=None, + label_weights=None, + thresholds_with_epsilon=False, +): + """Update confusion matrix variables with memory efficient alternative. + + Note that the thresholds need to be evenly distributed within the list, eg, + the diff between consecutive elements are the same. + + To compute TP/FP/TN/FN, we are measuring a binary classifier + C(t) = (predictions >= t) + at each threshold 't'. So we have + TP(t) = sum( C(t) * true_labels ) + FP(t) = sum( C(t) * false_labels ) + + But, computing C(t) requires computation for each t. To make it fast, + observe that C(t) is a cumulative integral, and so if we have + thresholds = [t_0, ..., t_{n-1}]; t_0 < ... < t_{n-1} + where n = num_thresholds, and if we can compute the bucket function + B(i) = Sum( (predictions == t), t_i <= t < t{i+1} ) + then we get + C(t_i) = sum( B(j), j >= i ) + which is the reversed cumulative sum in ops.cumsum(). + + We can compute B(i) efficiently by taking advantage of the fact that + our thresholds are evenly distributed, in that + width = 1.0 / (num_thresholds - 1) + thresholds = [0.0, 1*width, 2*width, 3*width, ..., 1.0] + Given a prediction value p, we can map it to its bucket by + bucket_index(p) = floor( p * (num_thresholds - 1) ) + so we can use ops.segment_sum() to update the buckets in one pass. + + Consider following example: + y_true = [0, 0, 1, 1] + y_pred = [0.1, 0.5, 0.3, 0.9] + thresholds = [0.0, 0.5, 1.0] + num_buckets = 2 # [0.0, 1.0], (1.0, 2.0] + bucket_index(y_pred) = ops.floor(y_pred * num_buckets) + = ops.floor([0.2, 1.0, 0.6, 1.8]) + = [0, 0, 0, 1] + # The meaning of this bucket is that if any of the label is true, + # then 1 will be added to the corresponding bucket with the index. + # Eg, if the label for 0.2 is true, then 1 will be added to bucket 0. If the + # label for 1.8 is true, then 1 will be added to bucket 1. + # + # Note the second item "1.0" is floored to 0, since the value need to be + # strictly larger than the bucket lower bound. + # In the implementation, we use ops.ceil() - 1 to achieve this. + tp_bucket_value = ops.segment_sum(true_labels, bucket_indices, + num_segments=num_thresholds) + = [1, 1, 0] + # For [1, 1, 0] here, it means there is 1 true value contributed by bucket + # 0, and 1 value contributed by bucket 1. When we aggregate them to + # together, the result become [a + b + c, b + c, c], since large thresholds + # will always contribute to the value for smaller thresholds. + true_positive = ops.cumsum(tp_bucket_value, reverse=True) + = [2, 1, 0] + + This implementation exhibits a run time and space complexity of O(T + N), + where T is the number of thresholds and N is the size of predictions. + Metrics that rely on standard implementation instead exhibit a complexity of + O(T * N). + + Args: + variables_to_update: Dictionary with 'tp', 'fn', 'tn', 'fp' as valid + keys and corresponding variables to update as values. + y_true: A floating point `Tensor` whose shape matches `y_pred`. Will be + cast to `bool`. + y_pred: A floating point `Tensor` of arbitrary shape and whose values + are in the range `[0, 1]`. + thresholds: A sorted floating point `Tensor` with value in `[0, 1]`. + It need to be evenly distributed (the diff between each element need + to be the same). + multi_label: Optional boolean indicating whether multidimensional + prediction/labels should be treated as multilabel responses, or + flattened into a single label. When True, the values of + `variables_to_update` must have a second dimension equal to the + number of labels in y_true and y_pred, and those tensors must not be + RaggedTensors. + sample_weights: Optional `Tensor` whose rank is either 0, or the same + rank as `y_true`, and must be broadcastable to `y_true` (i.e., all + dimensions must be either `1`, or the same as the corresponding + `y_true` dimension). + label_weights: Optional tensor of non-negative weights for multilabel + data. The weights are applied when calculating TP, FP, FN, and TN + without explicit multilabel handling (i.e. when the data is to be + flattened). + thresholds_with_epsilon: Optional boolean indicating whether the leading + and tailing thresholds has any epsilon added for floating point + imprecisions. It will change how we handle the leading and tailing + bucket. + """ + num_thresholds = ops.shape(thresholds)[0] + + if sample_weights is None: + sample_weights = 1.0 + else: + sample_weights = ops.broadcast_to( + ops.cast(sample_weights, dtype=y_pred.dtype), ops.shape(y_pred) + ) + if not multi_label: + sample_weights = ops.reshape(sample_weights, [-1]) + if label_weights is None: + label_weights = 1.0 + else: + label_weights = ops.expand_dims(label_weights, 0) + label_weights = ops.broadcast_to(label_weights, ops.shape(y_pred)) + if not multi_label: + label_weights = ops.reshape(label_weights, [-1]) + weights = ops.cast( + ops.multiply(sample_weights, label_weights), y_true.dtype + ) + + # We shouldn't need this, but in case there are predict value that is out of + # the range of [0.0, 1.0] + y_pred = ops.clip(y_pred, x_min=0.0, x_max=1.0) + + y_true = ops.cast(ops.cast(y_true, "bool"), y_true.dtype) + if not multi_label: + y_true = ops.reshape(y_true, [-1]) + y_pred = ops.reshape(y_pred, [-1]) + + true_labels = ops.multiply(y_true, weights) + false_labels = ops.multiply((1.0 - y_true), weights) + + # Compute the bucket indices for each prediction value. + # Since the predict value has to be strictly greater than the thresholds, + # eg, buckets like [0, 0.5], (0.5, 1], and 0.5 belongs to first bucket. + # We have to use math.ceil(val) - 1 for the bucket. + bucket_indices = ( + ops.ceil(y_pred * (ops.cast(num_thresholds, dtype=y_pred.dtype) - 1)) + - 1 + ) + + if thresholds_with_epsilon: + # In this case, the first bucket should actually take into account since + # the any prediction between [0.0, 1.0] should be larger than the first + # threshold. We change the bucket value from -1 to 0. + bucket_indices = ops.relu(bucket_indices) + + bucket_indices = ops.cast(bucket_indices, "int32") + + if multi_label: + # We need to run bucket segment sum for each of the label class. In the + # multi_label case, the rank of the label is 2. We first transpose it so + # that the label dim becomes the first and we can parallel run though + # them. + true_labels = ops.transpose(true_labels) + false_labels = ops.transpose(false_labels) + bucket_indices = ops.transpose(bucket_indices) + + def gather_bucket(label_and_bucket_index): + label, bucket_index = ( + label_and_bucket_index[0], + label_and_bucket_index[1], + ) + return ops.segment_sum( + data=label, + segment_ids=bucket_index, + num_segments=num_thresholds, + ) + + tp_bucket_v = backend.vectorized_map( + gather_bucket, + (true_labels, bucket_indices), + ) + fp_bucket_v = backend.vectorized_map( + gather_bucket, (false_labels, bucket_indices) + ) + tp = ops.transpose(ops.flip(ops.cumsum(ops.flip(tp_bucket_v), axis=1))) + fp = ops.transpose(ops.flip(ops.cumsum(ops.flip(fp_bucket_v), axis=1))) + else: + tp_bucket_v = ops.segment_sum( + data=true_labels, + segment_ids=bucket_indices, + num_segments=num_thresholds, + ) + fp_bucket_v = ops.segment_sum( + data=false_labels, + segment_ids=bucket_indices, + num_segments=num_thresholds, + ) + tp = ops.flip(ops.cumsum(ops.flip(tp_bucket_v))) + fp = ops.flip(ops.cumsum(ops.flip(fp_bucket_v))) + + # fn = sum(true_labels) - tp + # tn = sum(false_labels) - fp + if ( + ConfusionMatrix.TRUE_NEGATIVES in variables_to_update + or ConfusionMatrix.FALSE_NEGATIVES in variables_to_update + ): + if multi_label: + total_true_labels = ops.sum(true_labels, axis=1) + total_false_labels = ops.sum(false_labels, axis=1) + else: + total_true_labels = ops.sum(true_labels) + total_false_labels = ops.sum(false_labels) + + if ConfusionMatrix.TRUE_POSITIVES in variables_to_update: + variable = variables_to_update[ConfusionMatrix.TRUE_POSITIVES] + variable.assign(variable + tp) + if ConfusionMatrix.FALSE_POSITIVES in variables_to_update: + variable = variables_to_update[ConfusionMatrix.FALSE_POSITIVES] + variable.assign(variable + fp) + if ConfusionMatrix.TRUE_NEGATIVES in variables_to_update: + variable = variables_to_update[ConfusionMatrix.TRUE_NEGATIVES] + tn = total_false_labels - fp + variable.assign(variable + tn) + if ConfusionMatrix.FALSE_NEGATIVES in variables_to_update: + variable = variables_to_update[ConfusionMatrix.FALSE_NEGATIVES] + fn = total_true_labels - tp + variable.assign(variable + fn) + + +def is_evenly_distributed_thresholds(thresholds): + """Check if the thresholds list is evenly distributed. + + We could leverage evenly distributed thresholds to use less memory when + calculate metrics like AUC where each individual threshold need to be + evaluated. + + Args: + thresholds: A python list or tuple, or 1D numpy array whose value is + ranged in [0, 1]. + + Returns: + boolean, whether the values in the inputs are evenly distributed. + """ + # Check the list value and see if it is evenly distributed. + num_thresholds = len(thresholds) + if num_thresholds < 3: + return False + even_thresholds = np.arange(num_thresholds, dtype=np.float32) / ( + num_thresholds - 1 + ) + return np.allclose(thresholds, even_thresholds, atol=backend.epsilon()) + + +def update_confusion_matrix_variables( + variables_to_update, + y_true, + y_pred, + thresholds, + top_k=None, + class_id=None, + sample_weight=None, + multi_label=False, + label_weights=None, + thresholds_distributed_evenly=False, +): + """Updates the given confusion matrix variables. + + For every pair of values in y_true and y_pred: + + true_positive: y_true == True and y_pred > thresholds + false_negatives: y_true == True and y_pred <= thresholds + true_negatives: y_true == False and y_pred <= thresholds + false_positive: y_true == False and y_pred > thresholds + + The results will be weighted and added together. When multiple thresholds + are provided, we will repeat the same for every threshold. + + For estimation of these metrics over a stream of data, the function creates + an `update_op` operation that updates the given variables. + + If `sample_weight` is `None`, weights default to 1. + Use weights of 0 to mask values. + + Args: + variables_to_update: Dictionary with 'tp', 'fn', 'tn', 'fp' as valid keys + and corresponding variables to update as values. + y_true: A `Tensor` whose shape matches `y_pred`. Will be cast to `bool`. + y_pred: A floating point `Tensor` of arbitrary shape and whose values are + in the range `[0, 1]`. + thresholds: A float value, float tensor, python list, or tuple of float + thresholds in `[0, 1]`, or NEG_INF (used when top_k is set). + top_k: Optional int, indicates that the positive labels should be limited + to the top k predictions. + class_id: Optional int, limits the prediction and labels to the class + specified by this argument. + sample_weight: Optional `Tensor` whose rank is either 0, or the same rank + as `y_true`, and must be broadcastable to `y_true` (i.e., all dimensions + must be either `1`, or the same as the corresponding `y_true` + dimension). + multi_label: Optional boolean indicating whether multidimensional + prediction/labels should be treated as multilabel responses, or + flattened into a single label. When True, the values of + `variables_to_update` must have a second dimension equal to the number + of labels in y_true and y_pred, and those tensors must not be + RaggedTensors. + label_weights: (optional) tensor of non-negative weights for multilabel + data. The weights are applied when calculating TP, FP, FN, and TN + without explicit multilabel handling (i.e. when the data is to be + flattened). + thresholds_distributed_evenly: Boolean, whether the thresholds are evenly + distributed within the list. An optimized method will be used if this is + the case. See _update_confusion_matrix_variables_optimized() for more + details. + + Raises: + ValueError: If `y_pred` and `y_true` have mismatched shapes, or if + `sample_weight` is not `None` and its shape doesn't match `y_pred`, or + if `variables_to_update` contains invalid keys. + """ + if multi_label and label_weights is not None: + raise ValueError( + "`label_weights` for multilabel data should be handled " + "outside of `update_confusion_matrix_variables` when " + "`multi_label` is True." + ) + if variables_to_update is None: + return + if not any( + key for key in variables_to_update if key in list(ConfusionMatrix) + ): + raise ValueError( + "Please provide at least one valid confusion matrix " + "variable to update. Valid variable key options are: " + f'"{list(ConfusionMatrix)}". ' + f'Received: "{variables_to_update.keys()}"' + ) + + variable_dtype = list(variables_to_update.values())[0].dtype + + y_true = ops.cast(y_true, dtype=variable_dtype) + y_pred = ops.cast(y_pred, dtype=variable_dtype) + + if thresholds_distributed_evenly: + # Check whether the thresholds has any leading or tailing epsilon added + # for floating point imprecision. The leading and tailing threshold will + # be handled bit differently as the corner case. At this point, + # thresholds should be a list/array with more than 2 items, and ranged + # between [0, 1]. See is_evenly_distributed_thresholds() for more + # details. + thresholds_with_epsilon = thresholds[0] < 0.0 or thresholds[-1] > 1.0 + + thresholds = ops.convert_to_tensor(thresholds, dtype=variable_dtype) + num_thresholds = ops.shape(thresholds)[0] + + if multi_label: + one_thresh = ops.equal( + np.array(1, dtype="int32"), + len(thresholds.shape), + ) + else: + one_thresh = np.array(True, dtype="bool") + + invalid_keys = [ + key for key in variables_to_update if key not in list(ConfusionMatrix) + ] + if invalid_keys: + raise ValueError( + f'Invalid keys: "{invalid_keys}". ' + f'Valid variable key options are: "{list(ConfusionMatrix)}"' + ) + + y_pred, y_true = squeeze_or_expand_to_same_rank(y_pred, y_true) + if sample_weight is not None: + sample_weight = ops.expand_dims( + ops.cast(sample_weight, dtype=variable_dtype), axis=-1 + ) + _, sample_weight = squeeze_or_expand_to_same_rank( + y_true, sample_weight, expand_rank_1=False + ) + + if top_k is not None: + y_pred = _filter_top_k(y_pred, top_k) + + if class_id is not None: + if len(y_pred.shape) == 1: + raise ValueError( + "When class_id is provided, y_pred must be a 2D array " + "with shape (num_samples, num_classes), found shape: " + f"{y_pred.shape}" + ) + + # Preserve dimension to match with sample_weight + y_true = y_true[..., class_id, None] + y_pred = y_pred[..., class_id, None] + + if thresholds_distributed_evenly: + return _update_confusion_matrix_variables_optimized( + variables_to_update, + y_true, + y_pred, + thresholds, + multi_label=multi_label, + sample_weights=sample_weight, + label_weights=label_weights, + thresholds_with_epsilon=thresholds_with_epsilon, + ) + + if None in y_pred.shape: + pred_shape = ops.shape(y_pred) + num_predictions = pred_shape[0] + if len(y_pred.shape) == 1: + num_labels = 1 + else: + num_labels = ops.cast( + ops.prod(ops.array(pred_shape[1:]), axis=0), "int32" + ) + thresh_label_tile = ops.where(one_thresh, num_labels, 1) + else: + pred_shape = ops.shape(y_pred) + num_predictions = pred_shape[0] + if len(y_pred.shape) == 1: + num_labels = 1 + else: + num_labels = np.prod(pred_shape[1:], axis=0).astype("int32") + thresh_label_tile = np.where(one_thresh, num_labels, 1) + + # Reshape predictions and labels, adding a dim for thresholding. + if multi_label: + predictions_extra_dim = ops.expand_dims(y_pred, 0) + labels_extra_dim = ops.expand_dims(ops.cast(y_true, dtype="bool"), 0) + else: + # Flatten predictions and labels when not multilabel. + predictions_extra_dim = ops.reshape(y_pred, [1, -1]) + labels_extra_dim = ops.reshape(ops.cast(y_true, dtype="bool"), [1, -1]) + + # Tile the thresholds for every prediction. + if multi_label: + thresh_pretile_shape = [num_thresholds, 1, -1] + thresh_tiles = [1, num_predictions, thresh_label_tile] + data_tiles = [num_thresholds, 1, 1] + else: + thresh_pretile_shape = [num_thresholds, -1] + thresh_tiles = [1, num_predictions * num_labels] + data_tiles = [num_thresholds, 1] + + thresh_tiled = ops.tile( + ops.reshape(thresholds, thresh_pretile_shape), thresh_tiles + ) + + # Tile the predictions for every threshold. + preds_tiled = ops.tile(predictions_extra_dim, data_tiles) + + # Compare predictions and threshold. + pred_is_pos = ops.greater(preds_tiled, thresh_tiled) + + # Tile labels by number of thresholds + label_is_pos = ops.tile(labels_extra_dim, data_tiles) + + if sample_weight is not None: + sample_weight = ops.broadcast_to( + ops.cast(sample_weight, dtype=y_pred.dtype), ops.shape(y_pred) + ) + weights_tiled = ops.tile( + ops.reshape(sample_weight, thresh_tiles), data_tiles + ) + else: + weights_tiled = None + + if label_weights is not None and not multi_label: + label_weights = ops.expand_dims(label_weights, 0) + label_weights = ops.broadcast_to(label_weights, ops.shape(y_pred)) + label_weights_tiled = ops.tile( + ops.reshape(label_weights, thresh_tiles), data_tiles + ) + if weights_tiled is None: + weights_tiled = label_weights_tiled + else: + weights_tiled = ops.multiply(weights_tiled, label_weights_tiled) + + def weighted_assign_add(label, pred, weights, var): + label_and_pred = ops.cast(ops.logical_and(label, pred), dtype=var.dtype) + if weights is not None: + label_and_pred *= ops.cast(weights, dtype=var.dtype) + var.assign(var + ops.sum(label_and_pred, 1)) + + loop_vars = { + ConfusionMatrix.TRUE_POSITIVES: (label_is_pos, pred_is_pos), + } + update_tn = ConfusionMatrix.TRUE_NEGATIVES in variables_to_update + update_fp = ConfusionMatrix.FALSE_POSITIVES in variables_to_update + update_fn = ConfusionMatrix.FALSE_NEGATIVES in variables_to_update + + if update_fn or update_tn: + pred_is_neg = ops.logical_not(pred_is_pos) + loop_vars[ConfusionMatrix.FALSE_NEGATIVES] = (label_is_pos, pred_is_neg) + + if update_fp or update_tn: + label_is_neg = ops.logical_not(label_is_pos) + loop_vars[ConfusionMatrix.FALSE_POSITIVES] = (label_is_neg, pred_is_pos) + if update_tn: + loop_vars[ConfusionMatrix.TRUE_NEGATIVES] = ( + label_is_neg, + pred_is_neg, + ) + + for matrix_cond, (label, pred) in loop_vars.items(): + if matrix_cond in variables_to_update: + weighted_assign_add( + label, pred, weights_tiled, variables_to_update[matrix_cond] + ) + + +def _filter_top_k(x, k): + """Filters top-k values in the last dim of x and set the rest to NEG_INF. + + Used for computing top-k prediction values in dense labels (which has the + same shape as predictions) for recall and precision top-k metrics. + + Args: + x: tensor with any dimensions. + k: the number of values to keep. + + Returns: + tensor with same shape and dtype as x. + """ + _, top_k_idx = ops.top_k(x, k) + top_k_mask = ops.sum( + ops.one_hot(top_k_idx, ops.shape(x)[-1], axis=-1), axis=-2 + ) + return x * top_k_mask + NEG_INF * (1 - top_k_mask) + + +def confusion_matrix( + labels, + predictions, + num_classes, + weights=None, + dtype="int32", +): + """Computes the confusion matrix from predictions and labels. + + The matrix columns represent the prediction labels and the rows represent + the real labels. The confusion matrix is always a 2-D array of shape + `(n, n)`, where `n` is the number of valid labels for a given classification + task. Both prediction and labels must be 1-D arrays of the same shape in + order for this function to work. + + If `num_classes` is `None`, then `num_classes` will be set to one plus the + maximum value in either predictions or labels. Class labels are expected to + start at 0. For example, if `num_classes` is 3, then the possible labels + would be `[0, 1, 2]`. + + If `weights` is not `None`, then each prediction contributes its + corresponding weight to the total value of the confusion matrix cell. + + For example: + + ```python + keras.metrics.metrics_utils.confusion_matrix([1, 2, 4], [2, 2, 4]) ==> + [[0 0 0 0 0] + [0 0 1 0 0] + [0 0 1 0 0] + [0 0 0 0 0] + [0 0 0 0 1]] + ``` + + Note that the possible labels are assumed to be `[0, 1, 2, 3, 4]`, + resulting in a 5x5 confusion matrix. + + Args: + labels: 1-D tensor of real labels for the classification task. + predictions: 1-D tensor of predictions for a given classification. + num_classes: The possible number of labels the classification + task can have. + weights: An optional tensor whose shape matches `predictions`. + dtype: Data type of the confusion matrix. + + Returns: + A tensor of type `dtype` with shape `(n, n)` representing the confusion + matrix, where `n` is the number of possible labels in the classification + task. + """ + labels = ops.convert_to_tensor(labels, dtype) + predictions = ops.convert_to_tensor(predictions, dtype) + labels, predictions = squeeze_or_expand_to_same_rank(labels, predictions) + + predictions = ops.cast(predictions, dtype) + labels = ops.cast(labels, dtype) + + if weights is not None: + weights = ops.convert_to_tensor(weights, dtype) + + indices = ops.stack([labels, predictions], axis=1) + values = ops.ones_like(predictions, dtype) if weights is None else weights + indices = ops.cast(indices, dtype="int64") + values = ops.cast(values, dtype=dtype) + num_classes = int(num_classes) + confusion_matrix = ops.scatter(indices, values, (num_classes, num_classes)) + return confusion_matrix diff --git a/keras/src/metrics/probabilistic_metrics.py b/keras/src/metrics/probabilistic_metrics.py new file mode 100644 index 000000000000..2f719d84630e --- /dev/null +++ b/keras/src/metrics/probabilistic_metrics.py @@ -0,0 +1,339 @@ +from keras.src.api_export import keras_export +from keras.src.losses.losses import binary_crossentropy +from keras.src.losses.losses import categorical_crossentropy +from keras.src.losses.losses import kl_divergence +from keras.src.losses.losses import poisson +from keras.src.losses.losses import sparse_categorical_crossentropy +from keras.src.metrics import reduction_metrics + + +@keras_export("keras.metrics.KLDivergence") +class KLDivergence(reduction_metrics.MeanMetricWrapper): + """Computes Kullback-Leibler divergence metric between `y_true` and + `y_pred`. + + Formula: + + ```python + metric = y_true * log(y_true / y_pred) + ``` + + `y_true` and `y_pred` are expected to be probability + distributions, with values between 0 and 1. They will get + clipped to the `[0, 1]` range. + + Args: + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + + Examples: + + >>> m = keras.metrics.KLDivergence() + >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]]) + >>> m.result() + 0.45814306 + + >>> m.reset_state() + >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]], + ... sample_weight=[1, 0]) + >>> m.result() + 0.9162892 + + Usage with `compile()` API: + + ```python + model.compile(optimizer='sgd', + loss='mse', + metrics=[keras.metrics.KLDivergence()]) + ``` + """ + + def __init__(self, name="kl_divergence", dtype=None): + super().__init__(fn=kl_divergence, name=name, dtype=dtype) + + def get_config(self): + return {"name": self.name, "dtype": self.dtype} + + +@keras_export("keras.metrics.Poisson") +class Poisson(reduction_metrics.MeanMetricWrapper): + """Computes the Poisson metric between `y_true` and `y_pred`. + + Formula: + + ```python + metric = y_pred - y_true * log(y_pred) + ``` + + Args: + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + + Examples: + + >>> m = keras.metrics.Poisson() + >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]]) + >>> m.result() + 0.49999997 + + >>> m.reset_state() + >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]], + ... sample_weight=[1, 0]) + >>> m.result() + 0.99999994 + + Usage with `compile()` API: + + ```python + model.compile(optimizer='sgd', + loss='mse', + metrics=[keras.metrics.Poisson()]) + ``` + """ + + def __init__(self, name="poisson", dtype=None): + super().__init__(fn=poisson, name=name, dtype=dtype) + + def get_config(self): + return {"name": self.name, "dtype": self.dtype} + + +@keras_export("keras.metrics.BinaryCrossentropy") +class BinaryCrossentropy(reduction_metrics.MeanMetricWrapper): + """Computes the crossentropy metric between the labels and predictions. + + This is the crossentropy metric class to be used when there are only two + label classes (0 and 1). + + Args: + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + from_logits: (Optional) Whether output is expected + to be a logits tensor. By default, we consider + that output encodes a probability distribution. + label_smoothing: (Optional) Float in `[0, 1]`. + When > 0, label values are smoothed, + meaning the confidence on label values are relaxed. + e.g. `label_smoothing=0.2` means that we will use + a value of 0.1 for label "0" and 0.9 for label "1". + + Examples: + + >>> m = keras.metrics.BinaryCrossentropy() + >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]]) + >>> m.result() + 0.81492424 + + >>> m.reset_state() + >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]], + ... sample_weight=[1, 0]) + >>> m.result() + 0.9162905 + + Usage with `compile()` API: + + ```python + model.compile( + optimizer='sgd', + loss='mse', + metrics=[keras.metrics.BinaryCrossentropy()]) + ``` + """ + + def __init__( + self, + name="binary_crossentropy", + dtype=None, + from_logits=False, + label_smoothing=0, + ): + super().__init__( + binary_crossentropy, + name, + dtype=dtype, + from_logits=from_logits, + label_smoothing=label_smoothing, + ) + self.from_logits = from_logits + self.label_smoothing = label_smoothing + # Metric should be minimized during optimization. + self._direction = "down" + + def get_config(self): + return { + "name": self.name, + "dtype": self.dtype, + "from_logits": self.from_logits, + "label_smoothing": self.label_smoothing, + } + + +@keras_export("keras.metrics.CategoricalCrossentropy") +class CategoricalCrossentropy(reduction_metrics.MeanMetricWrapper): + """Computes the crossentropy metric between the labels and predictions. + + This is the crossentropy metric class to be used when there are multiple + label classes (2 or more). It assumes that labels are one-hot encoded, + e.g., when labels values are `[2, 0, 1]`, then + `y_true` is `[[0, 0, 1], [1, 0, 0], [0, 1, 0]]`. + + Args: + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + from_logits: (Optional) Whether output is expected to be + a logits tensor. By default, we consider that output + encodes a probability distribution. + label_smoothing: (Optional) Float in `[0, 1]`. + When > 0, label values are smoothed, meaning the confidence + on label values are relaxed. e.g. `label_smoothing=0.2` means + that we will use a value of 0.1 for label + "0" and 0.9 for label "1". + axis: (Optional) Defaults to `-1`. + The dimension along which entropy is computed. + + Examples: + + >>> # EPSILON = 1e-7, y = y_true, y` = y_pred + >>> # y` = clip_ops.clip_by_value(output, EPSILON, 1. - EPSILON) + >>> # y` = [[0.05, 0.95, EPSILON], [0.1, 0.8, 0.1]] + >>> # xent = -sum(y * log(y'), axis = -1) + >>> # = -((log 0.95), (log 0.1)) + >>> # = [0.051, 2.302] + >>> # Reduced xent = (0.051 + 2.302) / 2 + >>> m = keras.metrics.CategoricalCrossentropy() + >>> m.update_state([[0, 1, 0], [0, 0, 1]], + ... [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]) + >>> m.result() + 1.1769392 + + >>> m.reset_state() + >>> m.update_state([[0, 1, 0], [0, 0, 1]], + ... [[0.05, 0.95, 0], [0.1, 0.8, 0.1]], + ... sample_weight=np.array([0.3, 0.7])) + >>> m.result() + 1.6271976 + + Usage with `compile()` API: + + ```python + model.compile( + optimizer='sgd', + loss='mse', + metrics=[keras.metrics.CategoricalCrossentropy()]) + ``` + """ + + def __init__( + self, + name="categorical_crossentropy", + dtype=None, + from_logits=False, + label_smoothing=0, + axis=-1, + ): + super().__init__( + categorical_crossentropy, + name, + dtype=dtype, + from_logits=from_logits, + label_smoothing=label_smoothing, + axis=axis, + ) + self.from_logits = from_logits + self.label_smoothing = label_smoothing + self.axis = axis + # Metric should be minimized during optimization. + self._direction = "down" + + def get_config(self): + return { + "name": self.name, + "dtype": self.dtype, + "from_logits": self.from_logits, + "label_smoothing": self.label_smoothing, + "axis": self.axis, + } + + +@keras_export("keras.metrics.SparseCategoricalCrossentropy") +class SparseCategoricalCrossentropy(reduction_metrics.MeanMetricWrapper): + """Computes the crossentropy metric between the labels and predictions. + + Use this crossentropy metric when there are two or more label classes. + It expects labels to be provided as integers. If you want to provide labels + that are one-hot encoded, please use the `CategoricalCrossentropy` + metric instead. + + There should be `num_classes` floating point values per feature for `y_pred` + and a single floating point value per feature for `y_true`. + + Args: + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + from_logits: (Optional) Whether output is expected + to be a logits tensor. By default, we consider that output + encodes a probability distribution. + axis: (Optional) Defaults to `-1`. + The dimension along which entropy is computed. + + Examples: + + >>> # y_true = one_hot(y_true) = [[0, 1, 0], [0, 0, 1]] + >>> # logits = log(y_pred) + >>> # softmax = exp(logits) / sum(exp(logits), axis=-1) + >>> # softmax = [[0.05, 0.95, EPSILON], [0.1, 0.8, 0.1]] + >>> # xent = -sum(y * log(softmax), 1) + >>> # log(softmax) = [[-2.9957, -0.0513, -16.1181], + >>> # [-2.3026, -0.2231, -2.3026]] + >>> # y_true * log(softmax) = [[0, -0.0513, 0], [0, 0, -2.3026]] + >>> # xent = [0.0513, 2.3026] + >>> # Reduced xent = (0.0513 + 2.3026) / 2 + >>> m = keras.metrics.SparseCategoricalCrossentropy() + >>> m.update_state([1, 2], + ... [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]) + >>> m.result() + 1.1769392 + + >>> m.reset_state() + >>> m.update_state([1, 2], + ... [[0.05, 0.95, 0], [0.1, 0.8, 0.1]], + ... sample_weight=np.array([0.3, 0.7])) + >>> m.result() + 1.6271976 + + Usage with `compile()` API: + + ```python + model.compile( + optimizer='sgd', + loss='mse', + metrics=[keras.metrics.SparseCategoricalCrossentropy()]) + ``` + """ + + def __init__( + self, + name="sparse_categorical_crossentropy", + dtype=None, + from_logits=False, + axis=-1, + ): + super().__init__( + sparse_categorical_crossentropy, + name=name, + dtype=dtype, + from_logits=from_logits, + axis=axis, + ) + self.from_logits = from_logits + self.axis = axis + # Metric should be minimized during optimization. + self._direction = "down" + + def get_config(self): + return { + "name": self.name, + "dtype": self.dtype, + "from_logits": self.from_logits, + "axis": self.axis, + } diff --git a/keras/src/metrics/probabilistic_metrics_test.py b/keras/src/metrics/probabilistic_metrics_test.py new file mode 100644 index 000000000000..0950b277893b --- /dev/null +++ b/keras/src/metrics/probabilistic_metrics_test.py @@ -0,0 +1,236 @@ +import numpy as np + +from keras.src import metrics +from keras.src import testing + + +class KLDivergenceTest(testing.TestCase): + def setup(self): + self.y_pred = np.asarray( + [0.4, 0.9, 0.12, 0.36, 0.3, 0.4], dtype=np.float32 + ).reshape((2, 3)) + self.y_true = np.asarray( + [0.5, 0.8, 0.12, 0.7, 0.43, 0.8], dtype=np.float32 + ).reshape((2, 3)) + + self.batch_size = 2 + self.expected_results = np.multiply( + self.y_true, np.log(self.y_true / self.y_pred) + ) + + def test_config(self): + k_obj = metrics.KLDivergence(name="kld", dtype="int32") + self.assertEqual(k_obj.name, "kld") + self.assertEqual(k_obj._dtype, "int32") + + k_obj2 = metrics.KLDivergence.from_config(k_obj.get_config()) + self.assertEqual(k_obj2.name, "kld") + self.assertEqual(k_obj2._dtype, "int32") + + def test_unweighted(self): + self.setup() + k_obj = metrics.KLDivergence() + + k_obj.update_state(self.y_true, self.y_pred) + result = k_obj.result() + expected_result = np.sum(self.expected_results) / self.batch_size + self.assertAllClose(result, expected_result, atol=1e-3) + + def test_weighted(self): + self.setup() + k_obj = metrics.KLDivergence() + + sample_weight = np.asarray([1.2, 3.4], dtype=np.float32).reshape((2, 1)) + result = k_obj(self.y_true, self.y_pred, sample_weight=sample_weight) + + sample_weight = np.asarray( + [1.2, 1.2, 1.2, 3.4, 3.4, 3.4], dtype=np.float32 + ).reshape((2, 3)) + expected_result = np.multiply(self.expected_results, sample_weight) + expected_result = np.sum(expected_result) / (1.2 + 3.4) + self.assertAllClose(result, expected_result, atol=1e-3) + + +class PoissonTest(testing.TestCase): + def setup(self): + self.y_pred = np.asarray([1, 9, 2, 5, 2, 6], dtype=np.float32).reshape( + (2, 3) + ) + self.y_true = np.asarray([4, 8, 12, 8, 1, 3], dtype=np.float32).reshape( + (2, 3) + ) + self.batch_size = 6 + self.expected_results = self.y_pred - np.multiply( + self.y_true, np.log(self.y_pred) + ) + + def test_config(self): + self.run_class_serialization_test(metrics.Poisson(name="poisson")) + + def test_unweighted(self): + self.setup() + poisson_obj = metrics.Poisson() + poisson_obj.update_state(self.y_true, self.y_pred) + + result = poisson_obj.result() + expected_result = np.sum(self.expected_results) / self.batch_size + self.assertAllClose(result, expected_result, atol=1e-3) + + def test_weighted(self): + self.setup() + poisson_obj = metrics.Poisson() + sample_weight = np.asarray([1.2, 3.4], dtype=np.float32).reshape((2, 1)) + + result = poisson_obj( + self.y_true, self.y_pred, sample_weight=sample_weight + ) + sample_weight = np.asarray( + [1.2, 1.2, 1.2, 3.4, 3.4, 3.4], dtype=np.float32 + ).reshape((2, 3)) + expected_result = np.multiply(self.expected_results, sample_weight) + expected_result = np.sum(expected_result) / np.sum(sample_weight) + self.assertAllClose(result, expected_result, atol=1e-3) + + +class BinaryCrossentropyTest(testing.TestCase): + def test_config(self): + self.run_class_serialization_test( + metrics.BinaryCrossentropy( + name="bce", dtype="int32", label_smoothing=0.2 + ) + ) + + def test_unweighted(self): + bce_obj = metrics.BinaryCrossentropy() + y_true = np.array([1, 0, 1, 0]).reshape([2, 2]) + y_pred = np.array([1, 1, 1, 0], dtype=np.float32).reshape([2, 2]) + result = bce_obj(y_true, y_pred) + self.assertAllClose(result, 3.9855, atol=1e-3) + + def test_unweighted_with_logits(self): + bce_obj = metrics.BinaryCrossentropy(from_logits=True) + + y_true = np.array([[1, 0, 1], [0, 1, 1]]) + y_pred = np.array([[10.0, -10.0, 10.0], [10.0, 10.0, -10.0]]) + result = bce_obj(y_true, y_pred) + self.assertAllClose(result, 3.333, atol=1e-3) + + def test_weighted(self): + bce_obj = metrics.BinaryCrossentropy() + y_true = np.array([1, 0, 1, 0]).reshape([2, 2]) + y_pred = np.array([1, 1, 1, 0], dtype=np.float32).reshape([2, 2]) + sample_weight = np.array([1.5, 2.0]) + result = bce_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAllClose(result, 3.4162, atol=1e-3) + + def test_weighted_from_logits(self): + bce_obj = metrics.BinaryCrossentropy(from_logits=True) + y_true = np.array([[1, 0, 1], [0, 1, 1]]) + y_pred = np.array([[10.0, -10.0, 10.0], [10.0, 10.0, -10.0]]) + sample_weight = np.array([2.0, 2.5]) + result = bce_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAllClose(result, 3.7037, atol=1e-3) + + def test_label_smoothing(self): + logits = np.array(((10.0, -10.0, -10.0))) + y_true = np.array(((1, 0, 1))) + label_smoothing = 0.1 + bce_obj = metrics.BinaryCrossentropy( + from_logits=True, label_smoothing=label_smoothing + ) + result = bce_obj(y_true, logits) + expected_value = (10.0 + 5.0 * label_smoothing) / 3.0 + self.assertAllClose(expected_value, result, atol=1e-3) + + +class CategoricalCrossentropyTest(testing.TestCase): + def test_config(self): + self.run_class_serialization_test( + metrics.CategoricalCrossentropy( + name="cce", dtype="int32", label_smoothing=0.2 + ) + ) + + def test_unweighted(self): + cce_obj = metrics.CategoricalCrossentropy() + y_true = np.array([[0, 1, 0], [0, 0, 1]]) + y_pred = np.array([[0.05, 0.95, 0], [0.1, 0.8, 0.1]]) + result = cce_obj(y_true, y_pred) + self.assertAllClose(result, 1.176, atol=1e-3) + + def test_unweighted_from_logits(self): + cce_obj = metrics.CategoricalCrossentropy(from_logits=True) + + y_true = np.array([[0, 1, 0], [0, 0, 1]]) + logits = np.array([[1, 9, 0], [1, 8, 1]], dtype=np.float32) + result = cce_obj(y_true, logits) + self.assertAllClose(result, 3.5011, atol=1e-3) + + def test_weighted(self): + cce_obj = metrics.CategoricalCrossentropy() + + y_true = np.array([[0, 1, 0], [0, 0, 1]]) + y_pred = np.array([[0.05, 0.95, 0], [0.1, 0.8, 0.1]]) + sample_weight = np.array([1.5, 2.0]) + result = cce_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAllClose(result, 1.338, atol=1e-3) + + def test_weighted_from_logits(self): + cce_obj = metrics.CategoricalCrossentropy(from_logits=True) + + y_true = np.array([[0, 1, 0], [0, 0, 1]]) + logits = np.array([[1, 9, 0], [1, 8, 1]], dtype=np.float32) + sample_weight = np.array([1.5, 2.0]) + result = cce_obj(y_true, logits, sample_weight=sample_weight) + self.assertAllClose(result, 4.0012, atol=1e-3) + + def test_label_smoothing(self): + y_true = np.array([[0, 1, 0], [0, 0, 1]]) + logits = np.array([[1, 9, 0], [1, 8, 1]], dtype=np.float32) + label_smoothing = 0.1 + cce_obj = metrics.CategoricalCrossentropy( + from_logits=True, label_smoothing=label_smoothing + ) + loss = cce_obj(y_true, logits) + self.assertAllClose(loss, 3.667, atol=1e-3) + + +class SparseCategoricalCrossentropyTest(testing.TestCase): + def test_config(self): + self.run_class_serialization_test( + metrics.SparseCategoricalCrossentropy(name="scce", dtype="int32") + ) + + def test_unweighted(self): + scce_obj = metrics.SparseCategoricalCrossentropy() + + y_true = np.array([1, 2]) + y_pred = np.array([[0.05, 0.95, 0], [0.1, 0.8, 0.1]]) + result = scce_obj(y_true, y_pred) + self.assertAllClose(result, 1.176, atol=1e-3) + + def test_unweighted_from_logits(self): + scce_obj = metrics.SparseCategoricalCrossentropy(from_logits=True) + + y_true = np.array([1, 2]) + logits = np.array([[1, 9, 0], [1, 8, 1]], dtype=np.float32) + result = scce_obj(y_true, logits) + self.assertAllClose(result, 3.5011, atol=1e-3) + + def test_weighted(self): + scce_obj = metrics.SparseCategoricalCrossentropy() + + y_true = np.array([1, 2]) + y_pred = np.array([[0.05, 0.95, 0], [0.1, 0.8, 0.1]]) + sample_weight = np.array([1.5, 2.0]) + result = scce_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAllClose(result, 1.338, atol=1e-3) + + def test_weighted_from_logits(self): + scce_obj = metrics.SparseCategoricalCrossentropy(from_logits=True) + + y_true = np.array([1, 2]) + logits = np.array([[1, 9, 0], [1, 8, 1]], dtype=np.float32) + sample_weight = np.array([1.5, 2.0]) + result = scce_obj(y_true, logits, sample_weight=sample_weight) + self.assertAllClose(result, 4.0012, atol=1e-3) diff --git a/keras/src/metrics/reduction_metrics.py b/keras/src/metrics/reduction_metrics.py new file mode 100644 index 000000000000..d9bcddfd59cb --- /dev/null +++ b/keras/src/metrics/reduction_metrics.py @@ -0,0 +1,219 @@ +from keras.src import backend +from keras.src import initializers +from keras.src import losses +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.metrics.metric import Metric +from keras.src.saving import serialization_lib + + +def reduce_to_samplewise_values(values, sample_weight, reduce_fn, dtype): + dtype = dtype or backend.floatx() + mask = backend.get_keras_mask(values) + values = ops.cast(values, dtype=dtype) + if sample_weight is not None: + sample_weight = ops.convert_to_tensor(sample_weight, dtype=dtype) + + if mask is not None: + sample_weight = losses.loss.apply_mask( + sample_weight, mask, dtype=dtype, reduction="sum" + ) + # Update dimensions of weights to match with values if possible. + values, sample_weight = losses.loss.squeeze_or_expand_to_same_rank( + values, sample_weight + ) + # Reduce values to same ndim as weight array. + weight_ndim = len(sample_weight.shape) + values_ndim = len(values.shape) + if values_ndim > weight_ndim: + values = reduce_fn( + values, axis=list(range(weight_ndim, values_ndim)) + ) + # Broadcast sample_weight. It doesn't change the multiplication below + # but changes the sample_weight reduction applied later. + sample_weight = ops.broadcast_to(sample_weight, ops.shape(values)) + values = values * sample_weight + if weight_ndim > 1: + sample_weight = reduce_fn( + sample_weight, axis=list(range(1, weight_ndim)) + ) + + values_ndim = len(values.shape) + if values_ndim > 1: + values = reduce_fn(values, axis=list(range(1, values_ndim))) + return values, sample_weight + + +@keras_export("keras.metrics.Sum") +class Sum(Metric): + """Compute the (weighted) sum of the given values. + + For example, if `values` is `[1, 3, 5, 7]` then their sum is 16. + If `sample_weight` was specified as `[1, 1, 0, 0]` then the sum would be 4. + + This metric creates one variable, `total`. + This is ultimately returned as the sum value. + + Args: + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + + Example: + + >>> m = metrics.Sum() + >>> m.update_state([1, 3, 5, 7]) + >>> m.result() + 16.0 + + >>> m = metrics.Sum() + >>> m.update_state([1, 3, 5, 7], sample_weight=[1, 1, 0, 0]) + >>> m.result() + 4.0 + """ + + def __init__(self, name="sum", dtype=None): + super().__init__(name=name, dtype=dtype) + self.total = self.add_variable( + shape=(), + initializer=initializers.Zeros(), + dtype=self.dtype, + name="total", + ) + + def update_state(self, values, sample_weight=None): + values, _ = reduce_to_samplewise_values( + values, sample_weight, reduce_fn=ops.sum, dtype=self.dtype + ) + self.total.assign_add(ops.sum(values)) + + def reset_state(self): + self.total.assign(0) + + def result(self): + return ops.cast(self.total, self.dtype) + + +@keras_export("keras.metrics.Mean") +class Mean(Metric): + """Compute the (weighted) mean of the given values. + + For example, if values is `[1, 3, 5, 7]` then the mean is 4. + If `sample_weight` was specified as `[1, 1, 0, 0]` then the mean would be 2. + + This metric creates two variables, `total` and `count`. + The mean value returned is simply `total` divided by `count`. + + Args: + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + + Example: + + >>> m = Mean() + >>> m.update_state([1, 3, 5, 7]) + >>> m.result() + 4.0 + + >>> m.reset_state() + >>> m.update_state([1, 3, 5, 7], sample_weight=[1, 1, 0, 0]) + >>> m.result() + 2.0 + """ + + def __init__(self, name="mean", dtype=None): + super().__init__(name=name, dtype=dtype) + self.total = self.add_variable( + shape=(), + initializer=initializers.Zeros(), + dtype=self.dtype, + name="total", + ) + self.count = self.add_variable( + shape=(), + initializer=initializers.Zeros(), + dtype=self.dtype, + name="count", + ) + + def update_state(self, values, sample_weight=None): + values, sample_weight = reduce_to_samplewise_values( + values, sample_weight, reduce_fn=ops.mean, dtype=self.dtype + ) + self.total.assign_add(ops.sum(values)) + if sample_weight is not None: + num_samples = ops.sum(sample_weight) + elif len(values.shape) >= 1: + num_samples = ops.shape(values)[0] + else: + num_samples = 1 + self.count.assign_add(ops.cast(num_samples, dtype=self.dtype)) + + def reset_state(self): + self.total.assign(0) + self.count.assign(0) + + def result(self): + return ops.divide_no_nan( + self.total, ops.cast(self.count, dtype=self.dtype) + ) + + +@keras_export("keras.metrics.MeanMetricWrapper") +class MeanMetricWrapper(Mean): + """Wrap a stateless metric function with the `Mean` metric. + + You could use this class to quickly build a mean metric from a function. The + function needs to have the signature `fn(y_true, y_pred)` and return a + per-sample loss array. `MeanMetricWrapper.result()` will return + the average metric value across all samples seen so far. + + For example: + + ```python + def mse(y_true, y_pred): + return (y_true - y_pred) ** 2 + + mse_metric = MeanMetricWrapper(fn=mse) + ``` + + Args: + fn: The metric function to wrap, with signature + `fn(y_true, y_pred, **kwargs)`. + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + **kwargs: Keyword arguments to pass on to `fn`. + """ + + def __init__(self, fn, name=None, dtype=None, **kwargs): + super().__init__(name=name, dtype=dtype) + self._fn = fn + self._fn_kwargs = kwargs + + # If we are wrapping a Keras loss, register the metric's + # direction as "down" (needs to be minimized during training). + if ( + self._fn in losses.ALL_OBJECTS + or hasattr(self._fn, "__class__") + and self._fn.__class__ in losses.ALL_OBJECTS + ): + self._direction = "down" + + def update_state(self, y_true, y_pred, sample_weight=None): + mask = backend.get_keras_mask(y_pred) + values = self._fn(y_true, y_pred, **self._fn_kwargs) + sample_weight = losses.loss.apply_mask( + sample_weight, mask, dtype=self.dtype, reduction="sum" + ) + return super().update_state(values, sample_weight=sample_weight) + + def get_config(self): + base_config = super().get_config() + config = {"fn": serialization_lib.serialize_keras_object(self._fn)} + config.update(serialization_lib.serialize_keras_object(self._fn_kwargs)) + return {**base_config, **config} + + @classmethod + def from_config(cls, config): + if "fn" in config: + config = serialization_lib.deserialize_keras_object(config) + return cls(**config) diff --git a/keras/src/metrics/reduction_metrics_test.py b/keras/src/metrics/reduction_metrics_test.py new file mode 100644 index 000000000000..679bed081804 --- /dev/null +++ b/keras/src/metrics/reduction_metrics_test.py @@ -0,0 +1,193 @@ +import numpy as np + +from keras.src import backend +from keras.src import layers +from keras.src import metrics +from keras.src import models +from keras.src import testing +from keras.src.backend.common.keras_tensor import KerasTensor +from keras.src.metrics import reduction_metrics +from keras.src.saving import register_keras_serializable + + +class SumTest(testing.TestCase): + def test_config(self): + sum_obj = reduction_metrics.Sum(name="sum", dtype="float32") + self.assertEqual(sum_obj.name, "sum") + self.assertEqual(len(sum_obj.variables), 1) + self.assertEqual(sum_obj._dtype, "float32") + + # Check save and restore config + sum_obj2 = reduction_metrics.Sum.from_config(sum_obj.get_config()) + self.assertEqual(sum_obj2.name, "sum") + self.assertEqual(len(sum_obj2.variables), 1) + self.assertEqual(sum_obj2._dtype, "float32") + + def test_unweighted(self): + sum_obj = reduction_metrics.Sum(name="sum", dtype="float32") + sum_obj.update_state([1, 3, 5, 7]) + result = sum_obj.result() + self.assertAllClose(result, 16.0, atol=1e-3) + + def test_weighted(self): + sum_obj = reduction_metrics.Sum(name="sum", dtype="float32") + sum_obj.update_state([1, 3, 5, 7], sample_weight=[1, 1, 0, 0]) + result = sum_obj.result() + self.assertAllClose(result, 4.0, atol=1e-3) + + def test_weighted_nd(self): + sum_obj = reduction_metrics.Sum(name="sum", dtype="float32") + sum_obj.update_state([[1, 3], [5, 7]], sample_weight=[[1, 1], [1, 0]]) + result = sum_obj.result() + self.assertAllClose(result, 9.0, atol=1e-3) + + def test_weighted_nd_broadcast(self): + sum_obj = reduction_metrics.Sum(name="sum", dtype="float32") + sum_obj.update_state([[1, 3], [5, 7]], sample_weight=[[1, 0]]) + result = sum_obj.result() + self.assertAllClose(result, 6.0, atol=1e-3) + + +class MeanTest(testing.TestCase): + def test_config(self): + mean_obj = reduction_metrics.Mean(name="mean", dtype="float32") + self.assertEqual(mean_obj.name, "mean") + self.assertEqual(len(mean_obj.variables), 2) + self.assertEqual(mean_obj._dtype, "float32") + + # Check save and restore config + mean_obj2 = reduction_metrics.Mean.from_config(mean_obj.get_config()) + self.assertEqual(mean_obj2.name, "mean") + self.assertEqual(len(mean_obj2.variables), 2) + self.assertEqual(mean_obj2._dtype, "float32") + + def test_unweighted(self): + mean_obj = reduction_metrics.Mean(name="mean", dtype="float32") + mean_obj.update_state([1, 3, 5, 7]) + result = mean_obj.result() + self.assertAllClose(result, 4.0, atol=1e-3) + + def test_weighted(self): + mean_obj = reduction_metrics.Mean(name="mean", dtype="float32") + mean_obj.update_state([1, 3, 5, 7], sample_weight=[1, 1, 0, 0]) + result = mean_obj.result() + self.assertAllClose(result, 2.0, atol=1e-3) + + def test_weighted_negative_weights(self): + mean_obj = reduction_metrics.Mean(name="mean", dtype="float32") + mean_obj.update_state([1, 3, 5, 7], sample_weight=[-1, -1, 0, 0]) + result = mean_obj.result() + self.assertAllClose(result, 2.0, atol=1e-3) + + def test_weighted_nd(self): + mean_obj = reduction_metrics.Mean(name="mean", dtype="float32") + mean_obj.update_state([[1, 3], [5, 7]], sample_weight=[[1, 1], [1, 0]]) + result = mean_obj.result() + self.assertAllClose(result, 3.0, atol=1e-3) + + def test_weighted_nd_broadcast(self): + mean_obj = reduction_metrics.Mean(name="mean", dtype="float32") + mean_obj.update_state([[1, 3], [5, 7]], sample_weight=[[1, 0]]) + result = mean_obj.result() + self.assertAllClose(result, 3.0, atol=1e-3) + + def test_weighted_dynamic_shapes(self): + mean_obj = reduction_metrics.Mean(name="mean", dtype="float32") + result = backend.compute_output_spec( + mean_obj, KerasTensor((None, 2)), KerasTensor((None, 2)) + ) + self.assertAllEqual(result.shape, ()) + + +# How users would register a custom function or class to use with +# MeanMetricWrapper. +@register_keras_serializable(package="test", name="mse") +def mse(y_true, y_pred): + return (y_true - y_pred) ** 2 + + +class MetricWrapperTest(testing.TestCase): + def test_config(self): + mse_obj = reduction_metrics.MeanMetricWrapper( + fn=mse, name="mse", dtype="float32" + ) + self.assertEqual(mse_obj.name, "mse") + self.assertEqual(len(mse_obj.variables), 2) + self.assertEqual(mse_obj._dtype, "float32") + # Check save and restore config + mse_obj2 = reduction_metrics.MeanMetricWrapper.from_config( + mse_obj.get_config() + ) + self.assertEqual(mse_obj2.name, "mse") + self.assertEqual(len(mse_obj2.variables), 2) + self.assertEqual(mse_obj2._dtype, "float32") + self.assertTrue("fn" in mse_obj2.get_config()) + + def test_unweighted(self): + mse_obj = reduction_metrics.MeanMetricWrapper( + fn=mse, name="mse", dtype="float32" + ) + y_true = np.array( + [[0, 1, 0, 1, 0], [0, 0, 1, 1, 1], [1, 1, 1, 1, 0], [0, 0, 0, 0, 1]] + ) + y_pred = np.array( + [[0, 0, 1, 1, 0], [1, 1, 1, 1, 1], [0, 1, 0, 1, 0], [1, 1, 1, 1, 1]] + ) + + mse_obj.update_state(y_true, y_pred) + result = mse_obj.result() + self.assertAllClose(0.5, result, atol=1e-5) + + def test_weighted(self): + mse_obj = reduction_metrics.MeanMetricWrapper( + fn=mse, name="mse", dtype="float32" + ) + y_true = np.array( + [[0, 1, 0, 1, 0], [0, 0, 1, 1, 1], [1, 1, 1, 1, 0], [0, 0, 0, 0, 1]] + ) + y_pred = np.array( + [[0, 0, 1, 1, 0], [1, 1, 1, 1, 1], [0, 1, 0, 1, 0], [1, 1, 1, 1, 1]] + ) + sample_weight = np.array([1.0, 1.5, 2.0, 2.5]) + result = mse_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAllClose(0.54285, result, atol=1e-5) + + def test_weighted_broadcast(self): + mse_obj = reduction_metrics.MeanMetricWrapper( + fn=mse, name="mse", dtype="float32" + ) + y_true = np.array( + [[0, 1, 0, 1, 0], [0, 0, 1, 1, 1], [1, 1, 1, 1, 0], [0, 0, 0, 0, 1]] + ) + y_pred = np.array( + [[0, 0, 1, 1, 0], [1, 1, 1, 1, 1], [0, 1, 0, 1, 0], [1, 1, 1, 1, 1]] + ) + sample_weight = np.array([[1.0, 0.0, 0.5, 0.0, 1.0]]) + result = mse_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAllClose(0.45, result, atol=1e-5) + + def test_weighted_dynamic_shape(self): + mse_obj = reduction_metrics.MeanMetricWrapper( + fn=mse, name="mse", dtype="float32" + ) + result = backend.compute_output_spec( + mse_obj, + KerasTensor((None, 5)), + KerasTensor((None, 5)), + KerasTensor((None, 5)), + ) + self.assertAllEqual(result.shape, ()) + + def test_binary_accuracy_with_boolean_inputs(self): + inp = layers.Input(shape=(1,)) + out = inp > 0.5 + model = models.Model(inputs=inp, outputs=out) + + x = np.random.rand(32, 1) + y = x > 0.5 + + res = model.predict(x) + metric = metrics.BinaryAccuracy() + metric.update_state(y, res) + result = metric.result() + assert result == 1.0 diff --git a/keras/src/metrics/regression_metrics.py b/keras/src/metrics/regression_metrics.py new file mode 100644 index 000000000000..1ec0f86c6373 --- /dev/null +++ b/keras/src/metrics/regression_metrics.py @@ -0,0 +1,608 @@ +import warnings + +from keras.src import initializers +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.losses.loss import squeeze_or_expand_to_same_rank +from keras.src.losses.losses import log_cosh +from keras.src.losses.losses import mean_absolute_error +from keras.src.losses.losses import mean_absolute_percentage_error +from keras.src.losses.losses import mean_squared_error +from keras.src.losses.losses import mean_squared_logarithmic_error +from keras.src.metrics import reduction_metrics +from keras.src.utils.numerical_utils import normalize + + +@keras_export("keras.metrics.MeanSquaredError") +class MeanSquaredError(reduction_metrics.MeanMetricWrapper): + """Computes the mean squared error between `y_true` and `y_pred`. + + Formula: + + ```python + loss = mean(square(y_true - y_pred)) + ``` + + Args: + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + + Example: + >>> m = keras.metrics.MeanSquaredError() + >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]]) + >>> m.result() + 0.25 + """ + + def __init__(self, name="mean_squared_error", dtype=None): + super().__init__(fn=mean_squared_error, name=name, dtype=dtype) + # Metric should be minimized during optimization. + self._direction = "down" + + def get_config(self): + return {"name": self.name, "dtype": self.dtype} + + +@keras_export("keras.metrics.MeanAbsoluteError") +class MeanAbsoluteError(reduction_metrics.MeanMetricWrapper): + """Computes the mean absolute error between the labels and predictions. + + Formula: + + ```python + loss = mean(abs(y_true - y_pred)) + ``` + + Args: + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + + Examples: + + >>> m = keras.metrics.MeanAbsoluteError() + >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]]) + >>> m.result() + 0.25 + + >>> m.reset_state() + >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]], + ... sample_weight=[1, 0]) + >>> m.result() + 0.5 + + Usage with `compile()` API: + + ```python + model.compile( + optimizer='sgd', + loss='mse', + metrics=[keras.metrics.MeanAbsoluteError()]) + ``` + """ + + def __init__(self, name="mean_absolute_error", dtype=None): + super().__init__(mean_absolute_error, name, dtype=dtype) + # Metric should be minimized during optimization. + self._direction = "down" + + def get_config(self): + return {"name": self.name, "dtype": self.dtype} + + +@keras_export("keras.metrics.MeanAbsolutePercentageError") +class MeanAbsolutePercentageError(reduction_metrics.MeanMetricWrapper): + """Computes mean absolute percentage error between `y_true` and `y_pred`. + + Formula: + + ```python + loss = 100 * mean(abs((y_true - y_pred) / y_true)) + ``` + + Args: + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + + Examples: + >>> m = keras.metrics.MeanAbsolutePercentageError() + >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]]) + >>> m.result() + 250000000.0 + + >>> m.reset_state() + >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]], + ... sample_weight=[1, 0]) + >>> m.result() + 500000000.0 + + Usage with `compile()` API: + + ```python + model.compile( + optimizer='sgd', + loss='mse', + metrics=[keras.metrics.MeanAbsolutePercentageError()]) + ``` + """ + + def __init__(self, name="mean_absolute_percentage_error", dtype=None): + super().__init__(mean_absolute_percentage_error, name, dtype=dtype) + # Metric should be minimized during optimization. + self._direction = "down" + + def get_config(self): + return {"name": self.name, "dtype": self.dtype} + + +@keras_export("keras.metrics.MeanSquaredLogarithmicError") +class MeanSquaredLogarithmicError(reduction_metrics.MeanMetricWrapper): + """Computes mean squared logarithmic error between `y_true` and `y_pred`. + + Formula: + + ```python + loss = mean(square(log(y_true + 1) - log(y_pred + 1))) + ``` + + Args: + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + + Examples: + + >>> m = keras.metrics.MeanSquaredLogarithmicError() + >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]]) + >>> m.result() + 0.12011322 + + >>> m.reset_state() + >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]], + ... sample_weight=[1, 0]) + >>> m.result() + 0.24022643 + + Usage with `compile()` API: + + ```python + model.compile( + optimizer='sgd', + loss='mse', + metrics=[keras.metrics.MeanSquaredLogarithmicError()]) + ``` + """ + + def __init__(self, name="mean_squared_logarithmic_error", dtype=None): + super().__init__(mean_squared_logarithmic_error, name, dtype=dtype) + # Metric should be minimized during optimization. + self._direction = "down" + + def get_config(self): + return {"name": self.name, "dtype": self.dtype} + + +@keras_export("keras.metrics.RootMeanSquaredError") +class RootMeanSquaredError(reduction_metrics.Mean): + """Computes root mean squared error metric between `y_true` and `y_pred`. + + Formula: + + ```python + loss = sqrt(mean((y_pred - y_true) ** 2)) + ``` + + Args: + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + + Examples: + + >>> m = keras.metrics.RootMeanSquaredError() + >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]]) + >>> m.result() + 0.5 + + >>> m.reset_state() + >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]], + ... sample_weight=[1, 0]) + >>> m.result() + 0.70710677 + + Usage with `compile()` API: + + ```python + model.compile( + optimizer='sgd', + loss='mse', + metrics=[keras.metrics.RootMeanSquaredError()]) + ``` + """ + + def __init__(self, name="root_mean_squared_error", dtype=None): + super().__init__(name, dtype=dtype) + # Metric should be minimized during optimization. + self._direction = "down" + + def update_state(self, y_true, y_pred, sample_weight=None): + """Accumulates root mean squared error statistics. + + Args: + y_true: The ground truth values. + y_pred: The predicted values. + sample_weight: Optional weighting of each example. Can + be a `Tensor` whose rank is either 0, or the same rank as + `y_true`, and must be broadcastable to `y_true`. + Defaults to `1`. + + Returns: + Update op. + """ + y_true = ops.convert_to_tensor(y_true, self._dtype) + y_pred = ops.convert_to_tensor(y_pred, self._dtype) + y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred) + error_sq = ops.square(y_pred - y_true) + return super().update_state(error_sq, sample_weight=sample_weight) + + def result(self): + return ops.sqrt(super().result()) + + +@keras_export("keras.metrics.CosineSimilarity") +class CosineSimilarity(reduction_metrics.MeanMetricWrapper): + """Computes the cosine similarity between the labels and predictions. + + Formula: + + ```python + loss = sum(l2_norm(y_true) * l2_norm(y_pred)) + ``` + See: [Cosine Similarity](https://en.wikipedia.org/wiki/Cosine_similarity). + This metric keeps the average cosine similarity between `predictions` and + `labels` over a stream of data. + + Args: + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + axis: (Optional) Defaults to `-1`. The dimension along which the cosine + similarity is computed. + + Examples: + + >>> # l2_norm(y_true) = [[0., 1.], [1./1.414, 1./1.414]] + >>> # l2_norm(y_pred) = [[1., 0.], [1./1.414, 1./1.414]] + >>> # l2_norm(y_true) . l2_norm(y_pred) = [[0., 0.], [0.5, 0.5]] + >>> # result = mean(sum(l2_norm(y_true) . l2_norm(y_pred), axis=1)) + >>> # = ((0. + 0.) + (0.5 + 0.5)) / 2 + >>> m = keras.metrics.CosineSimilarity(axis=1) + >>> m.update_state([[0., 1.], [1., 1.]], [[1., 0.], [1., 1.]]) + >>> m.result() + 0.49999997 + + >>> m.reset_state() + >>> m.update_state([[0., 1.], [1., 1.]], [[1., 0.], [1., 1.]], + ... sample_weight=[0.3, 0.7]) + >>> m.result() + 0.6999999 + + Usage with `compile()` API: + + ```python + model.compile( + optimizer='sgd', + loss='mse', + metrics=[keras.metrics.CosineSimilarity(axis=1)]) + ``` + """ + + def __init__(self, name="cosine_similarity", dtype=None, axis=-1): + super().__init__(cosine_similarity, name, dtype=dtype, axis=axis) + # Metric should be maximized during optimization. + self._direction = "up" + + def get_config(self): + return {"name": self.name, "dtype": self.dtype} + + +@keras_export("keras.metrics.LogCoshError") +class LogCoshError(reduction_metrics.MeanMetricWrapper): + """Computes the logarithm of the hyperbolic cosine of the prediction error. + + Formula: + + ```python + error = y_pred - y_true + logcosh = mean(log((exp(error) + exp(-error))/2), axis=-1) + ``` + + Args: + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + + Examples: + + >>> m = keras.metrics.LogCoshError() + >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]]) + >>> m.result() + 0.10844523 + + >>> m.reset_state() + >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]], + ... sample_weight=[1, 0]) + >>> m.result() + 0.21689045 + + Usage with `compile()` API: + + ```python + model.compile(optimizer='sgd', + loss='mse', + metrics=[keras.metrics.LogCoshError()]) + ``` + """ + + def __init__(self, name="logcosh", dtype=None): + super().__init__(log_cosh, name, dtype=dtype) + # Metric should be minimized during optimization. + self._direction = "down" + + def get_config(self): + return {"name": self.name, "dtype": self.dtype} + + +# Adapted from TF-Addons implementation (RSquare class). +@keras_export("keras.metrics.R2Score") +class R2Score(reduction_metrics.Metric): + """Computes R2 score. + + Formula: + + ```python + sum_squares_residuals = sum((y_true - y_pred) ** 2) + sum_squares = sum((y_true - mean(y_true)) ** 2) + R2 = 1 - sum_squares_residuals / sum_squares + ``` + + This is also called the + [coefficient of determination]( + https://en.wikipedia.org/wiki/Coefficient_of_determination). + + It indicates how close the fitted regression line + is to ground-truth data. + + - The highest score possible is 1.0. It indicates that the predictors + perfectly accounts for variation in the target. + - A score of 0.0 indicates that the predictors do not + account for variation in the target. + - It can also be negative if the model is worse than random. + + This metric can also compute the "Adjusted R2" score. + + Args: + class_aggregation: Specifies how to aggregate scores corresponding to + different output classes (or target dimensions), + i.e. different dimensions on the last axis of the predictions. + Equivalent to `multioutput` argument in Scikit-Learn. + Should be one of + `None` (no aggregation), `"uniform_average"`, + `"variance_weighted_average"`. + num_regressors: Number of independent regressors used + ("Adjusted R2" score). 0 is the standard R2 score. + Defaults to `0`. + name: Optional. string name of the metric instance. + dtype: Optional. data type of the metric result. + + Example: + + >>> y_true = np.array([[1], [4], [3]], dtype=np.float32) + >>> y_pred = np.array([[2], [4], [4]], dtype=np.float32) + >>> metric = keras.metrics.R2Score() + >>> metric.update_state(y_true, y_pred) + >>> result = metric.result() + >>> result + 0.57142854 + """ + + def __init__( + self, + class_aggregation="uniform_average", + num_regressors=0, + name="r2_score", + dtype=None, + ): + super().__init__(name=name, dtype=dtype) + # Metric should be maximized during optimization. + self._direction = "up" + + valid_class_aggregation_values = ( + None, + "uniform_average", + "variance_weighted_average", + ) + if class_aggregation not in valid_class_aggregation_values: + raise ValueError( + "Invalid value for argument `class_aggregation`. Expected " + f"one of {valid_class_aggregation_values}. " + f"Received: class_aggregation={class_aggregation}" + ) + if num_regressors < 0: + raise ValueError( + "Invalid value for argument `num_regressors`. " + "Expected a value >= 0. " + f"Received: num_regressors={num_regressors}" + ) + self.class_aggregation = class_aggregation + self.num_regressors = num_regressors + self.num_samples = self.add_variable( + shape=(), + initializer=initializers.Zeros(), + name="num_samples", + ) + self._built = False + + def _build(self, y_true_shape, y_pred_shape): + if len(y_pred_shape) != 2 or len(y_true_shape) != 2: + raise ValueError( + "R2Score expects 2D inputs with shape " + "(batch_size, output_dim). Received input " + f"shapes: y_pred.shape={y_pred_shape} and " + f"y_true.shape={y_true_shape}." + ) + if y_pred_shape[-1] is None or y_true_shape[-1] is None: + raise ValueError( + "R2Score expects 2D inputs with shape " + "(batch_size, output_dim), with output_dim fully " + "defined (not None). Received input " + f"shapes: y_pred.shape={y_pred_shape} and " + f"y_true.shape={y_true_shape}." + ) + num_classes = y_pred_shape[-1] + self.squared_sum = self.add_variable( + name="squared_sum", + shape=[num_classes], + initializer=initializers.Zeros(), + ) + self.sum = self.add_variable( + name="sum", + shape=[num_classes], + initializer=initializers.Zeros(), + ) + self.total_mse = self.add_variable( + name="residual", + shape=[num_classes], + initializer=initializers.Zeros(), + ) + self.count = self.add_variable( + name="count", + shape=[num_classes], + initializer=initializers.Zeros(), + ) + self._built = True + + def update_state(self, y_true, y_pred, sample_weight=None): + """Accumulates root mean squared error statistics. + + Args: + y_true: The ground truth values. + y_pred: The predicted values. + sample_weight: Optional weighting of each example. Can + be a `Tensor` whose rank is either 0, or the same rank as + `y_true`, and must be broadcastable to `y_true`. + Defaults to `1`. + + Returns: + Update op. + """ + y_true = ops.convert_to_tensor(y_true, dtype=self._dtype) + y_pred = ops.convert_to_tensor(y_pred, dtype=self._dtype) + y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred) + if not self._built: + self._build(y_true.shape, y_pred.shape) + + if sample_weight is None: + sample_weight = 1 + + sample_weight = ops.convert_to_tensor(sample_weight, dtype=self.dtype) + + if len(sample_weight.shape) == 1: + # Make sure there's a features dimension + sample_weight = ops.expand_dims(sample_weight, axis=1) + + sample_weight = ops.broadcast_to(sample_weight, ops.shape(y_true)) + + weighted_y_true = y_true * ops.cast(sample_weight, y_true.dtype) + self.sum.assign(self.sum + ops.sum(weighted_y_true, axis=0)) + self.squared_sum.assign( + self.squared_sum + ops.sum(y_true * weighted_y_true, axis=0) + ) + self.total_mse.assign( + self.total_mse + + ops.sum( + (y_true - y_pred) ** 2 * ops.cast(sample_weight, y_true.dtype), + axis=0, + ) + ) + self.count.assign(self.count + ops.sum(sample_weight, axis=0)) + self.num_samples.assign(self.num_samples + ops.size(y_true)) + + def result(self): + mean = self.sum / self.count + total = self.squared_sum - self.sum * mean + raw_scores = 1 - (self.total_mse / total) + raw_scores = ops.where(ops.isinf(raw_scores), 0.0, raw_scores) + + if self.class_aggregation == "uniform_average": + r2_score = ops.mean(raw_scores) + elif self.class_aggregation == "variance_weighted_average": + weighted_sum = ops.sum(total * raw_scores) + sum_of_weights = ops.sum(total) + r2_score = weighted_sum / sum_of_weights + else: + r2_score = raw_scores + + if self.num_regressors != 0: + if self.num_regressors > self.num_samples - 1: + warnings.warn( + "More independent predictors than datapoints " + "in adjusted R2 score. Falling back to standard R2 score.", + stacklevel=2, + ) + elif self.num_regressors == self.num_samples - 1: + warnings.warn( + "Division by zero in Adjusted R2 score. " + "Falling back to standard R2 score.", + stacklevel=2, + ) + else: + n = ops.convert_to_tensor(self.num_samples, dtype="float32") + p = ops.convert_to_tensor(self.num_regressors, dtype="float32") + num = ops.multiply( + ops.subtract(1.0, r2_score), ops.subtract(n, 1.0) + ) + den = ops.subtract(ops.subtract(n, p), 1.0) + r2_score = ops.subtract(1.0, ops.divide(num, den)) + return r2_score + + def reset_state(self): + for v in self.variables: + v.assign(ops.zeros(v.shape, dtype=v.dtype)) + + def get_config(self): + config = { + "name": self.name, + "dtype": self.dtype, + "class_aggregation": self.class_aggregation, + "num_regressors": self.num_regressors, + } + base_config = super().get_config() + return {**base_config, **config} + + +def cosine_similarity(y_true, y_pred, axis=-1): + """Computes the cosine similarity between labels and predictions. + + Formula: + + ```python + loss = sum(l2_norm(y_true) * l2_norm(y_pred)) + ``` + + Args: + y_true: Tensor of true targets. + y_pred: Tensor of predicted targets. + axis: Axis along which to determine similarity. Defaults to `-1`. + + Returns: + Cosine similarity tensor. + + Example: + + >>> y_true = [[0., 1.], [1., 1.], [1., 1.]] + >>> y_pred = [[1., 0.], [1., 1.], [-1., -1.]] + >>> loss = keras.losses.cosine_similarity(y_true, y_pred, axis=-1) + [0., 0.99999994, -0.99999994] + """ + y_pred = ops.convert_to_tensor(y_pred) + y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype) + y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred) + y_pred = normalize(y_pred, axis=axis) + y_true = normalize(y_true, axis=axis) + return ops.sum(y_true * y_pred, axis=axis) diff --git a/keras/src/metrics/regression_metrics_test.py b/keras/src/metrics/regression_metrics_test.py new file mode 100644 index 000000000000..4ad9899d9b5f --- /dev/null +++ b/keras/src/metrics/regression_metrics_test.py @@ -0,0 +1,392 @@ +import numpy as np +from absl.testing import parameterized + +from keras.src import testing +from keras.src.metrics import regression_metrics as metrics + + +class MeanSquaredErrorTest(testing.TestCase): + def test_config(self): + # TODO + pass + + def test_unweighted(self): + mse_obj = metrics.MeanSquaredError() + y_true = np.array( + [[0, 1, 0, 1, 0], [0, 0, 1, 1, 1], [1, 1, 1, 1, 0], [0, 0, 0, 0, 1]] + ) + y_pred = np.array( + [[0, 0, 1, 1, 0], [1, 1, 1, 1, 1], [0, 1, 0, 1, 0], [1, 1, 1, 1, 1]] + ) + + mse_obj.update_state(y_true, y_pred) + result = mse_obj.result() + self.assertAllClose(0.5, result, atol=1e-5) + + def test_weighted(self): + mse_obj = metrics.MeanSquaredError() + y_true = np.array( + [[0, 1, 0, 1, 0], [0, 0, 1, 1, 1], [1, 1, 1, 1, 0], [0, 0, 0, 0, 1]] + ) + y_pred = np.array( + [[0, 0, 1, 1, 0], [1, 1, 1, 1, 1], [0, 1, 0, 1, 0], [1, 1, 1, 1, 1]] + ) + sample_weight = np.array([1.0, 1.5, 2.0, 2.5]) + result = mse_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAllClose(0.54285, result, atol=1e-5) + + +class CosineSimilarityTest(testing.TestCase): + def l2_norm(self, x, axis): + epsilon = 1e-12 + square_sum = np.sum(np.square(x), axis=axis, keepdims=True) + x_inv_norm = 1 / np.sqrt(np.maximum(square_sum, epsilon)) + return np.multiply(x, x_inv_norm) + + def setup(self, axis=1): + self.np_y_true = np.asarray([[1, 9, 2], [-5, -2, 6]], dtype=np.float32) + self.np_y_pred = np.asarray([[4, 8, 12], [8, 1, 3]], dtype=np.float32) + + y_true = self.l2_norm(self.np_y_true, axis) + y_pred = self.l2_norm(self.np_y_pred, axis) + self.expected_loss = np.sum(np.multiply(y_true, y_pred), axis=(axis,)) + + self.y_true = self.np_y_true + self.y_pred = self.np_y_pred + + def test_config(self): + cosine_obj = metrics.CosineSimilarity( + axis=2, name="my_cos", dtype="int32" + ) + self.assertEqual(cosine_obj.name, "my_cos") + self.assertEqual(cosine_obj.dtype, "int32") + + # Check save and restore config + cosine_obj2 = metrics.CosineSimilarity.from_config( + cosine_obj.get_config() + ) + self.assertEqual(cosine_obj2.name, "my_cos") + self.assertEqual(cosine_obj2._dtype, "int32") + + def test_unweighted(self): + self.setup() + cosine_obj = metrics.CosineSimilarity() + loss = cosine_obj(self.y_true, self.y_pred) + expected_loss = np.mean(self.expected_loss) + self.assertAlmostEqual(loss, expected_loss, 3) + + def test_weighted(self): + self.setup() + cosine_obj = metrics.CosineSimilarity() + sample_weight = np.asarray([1.2, 3.4]) + loss = cosine_obj(self.y_true, self.y_pred, sample_weight=sample_weight) + expected_loss = np.sum(self.expected_loss * sample_weight) / np.sum( + sample_weight + ) + self.assertAlmostEqual(loss, expected_loss, 3) + + def test_axis(self): + self.setup(axis=1) + cosine_obj = metrics.CosineSimilarity(axis=1) + loss = cosine_obj(self.y_true, self.y_pred) + expected_loss = np.mean(self.expected_loss) + self.assertAlmostEqual(loss, expected_loss, 3) + + +class MeanAbsoluteErrorTest(testing.TestCase): + def test_config(self): + mae_obj = metrics.MeanAbsoluteError(name="my_mae", dtype="int32") + self.assertEqual(mae_obj.name, "my_mae") + self.assertEqual(mae_obj._dtype, "int32") + + # Check save and restore config + mae_obj2 = metrics.MeanAbsoluteError.from_config(mae_obj.get_config()) + self.assertEqual(mae_obj2.name, "my_mae") + self.assertEqual(mae_obj2._dtype, "int32") + + def test_unweighted(self): + mae_obj = metrics.MeanAbsoluteError() + y_true = np.array( + [[0, 1, 0, 1, 0], [0, 0, 1, 1, 1], [1, 1, 1, 1, 0], [0, 0, 0, 0, 1]] + ) + y_pred = np.array( + [[0, 0, 1, 1, 0], [1, 1, 1, 1, 1], [0, 1, 0, 1, 0], [1, 1, 1, 1, 1]] + ) + + mae_obj.update_state(y_true, y_pred) + result = mae_obj.result() + self.assertAllClose(0.5, result, atol=1e-5) + + def test_weighted(self): + mae_obj = metrics.MeanAbsoluteError() + y_true = np.array( + [[0, 1, 0, 1, 0], [0, 0, 1, 1, 1], [1, 1, 1, 1, 0], [0, 0, 0, 0, 1]] + ) + y_pred = np.array( + [[0, 0, 1, 1, 0], [1, 1, 1, 1, 1], [0, 1, 0, 1, 0], [1, 1, 1, 1, 1]] + ) + sample_weight = np.array([1.0, 1.5, 2.0, 2.5]) + result = mae_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAllClose(0.54285, result, atol=1e-5) + + +class MeanAbsolutePercentageErrorTest(testing.TestCase): + def test_config(self): + mape_obj = metrics.MeanAbsolutePercentageError( + name="my_mape", dtype="int32" + ) + self.assertEqual(mape_obj.name, "my_mape") + self.assertEqual(mape_obj._dtype, "int32") + + # Check save and restore config + mape_obj2 = metrics.MeanAbsolutePercentageError.from_config( + mape_obj.get_config() + ) + self.assertEqual(mape_obj2.name, "my_mape") + self.assertEqual(mape_obj2._dtype, "int32") + + def test_unweighted(self): + mape_obj = metrics.MeanAbsolutePercentageError() + y_true = np.array( + [[0, 1, 0, 1, 0], [0, 0, 1, 1, 1], [1, 1, 1, 1, 0], [0, 0, 0, 0, 1]] + ) + y_pred = np.array( + [ + [0, 0, 1, 1, 0], + [1, 1, 1, 1, 1], + [0, 1, 0, 1, 0], + [1, 1, 1, 1, 1], + ], + dtype="float32", + ) + + result = mape_obj(y_true, y_pred) + self.assertAllClose(35e7, result, atol=1e-5) + + def test_weighted(self): + mape_obj = metrics.MeanAbsolutePercentageError() + y_true = np.array( + [[0, 1, 0, 1, 0], [0, 0, 1, 1, 1], [1, 1, 1, 1, 0], [0, 0, 0, 0, 1]] + ) + y_pred = np.array( + [ + [0, 0, 1, 1, 0], + [1, 1, 1, 1, 1], + [0, 1, 0, 1, 0], + [1, 1, 1, 1, 1], + ], + dtype="float32", + ) + + sample_weight = np.array([1.0, 1.5, 2.0, 2.5]) + result = mape_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAllClose(40e7, result, atol=1e-5) + + +class MeanSquaredLogarithmicErrorTest(testing.TestCase): + def test_config(self): + msle_obj = metrics.MeanSquaredLogarithmicError( + name="my_msle", dtype="int32" + ) + self.assertEqual(msle_obj.name, "my_msle") + self.assertEqual(msle_obj._dtype, "int32") + + # Check save and restore config + msle_obj2 = metrics.MeanSquaredLogarithmicError.from_config( + msle_obj.get_config() + ) + self.assertEqual(msle_obj2.name, "my_msle") + self.assertEqual(msle_obj2._dtype, "int32") + + def test_unweighted(self): + msle_obj = metrics.MeanSquaredLogarithmicError() + y_true = np.array( + [[0, 1, 0, 1, 0], [0, 0, 1, 1, 1], [1, 1, 1, 1, 0], [0, 0, 0, 0, 1]] + ) + y_pred = np.array( + [[0, 0, 1, 1, 0], [1, 1, 1, 1, 1], [0, 1, 0, 1, 0], [1, 1, 1, 1, 1]] + ) + + msle_obj.update_state(y_true, y_pred) + result = msle_obj.result() + self.assertAllClose(0.24022, result, atol=1e-5) + + def test_weighted(self): + msle_obj = metrics.MeanSquaredLogarithmicError() + y_true = np.array( + [[0, 1, 0, 1, 0], [0, 0, 1, 1, 1], [1, 1, 1, 1, 0], [0, 0, 0, 0, 1]] + ) + y_pred = np.array( + [[0, 0, 1, 1, 0], [1, 1, 1, 1, 1], [0, 1, 0, 1, 0], [1, 1, 1, 1, 1]] + ) + sample_weight = np.array([1.0, 1.5, 2.0, 2.5]) + result = msle_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAllClose(0.26082, result, atol=1e-5) + + +class RootMeanSquaredErrorTest(testing.TestCase): + def test_config(self): + rmse_obj = metrics.RootMeanSquaredError(name="rmse", dtype="int32") + self.assertEqual(rmse_obj.name, "rmse") + self.assertEqual(rmse_obj._dtype, "int32") + + rmse_obj2 = metrics.RootMeanSquaredError.from_config( + rmse_obj.get_config() + ) + self.assertEqual(rmse_obj2.name, "rmse") + self.assertEqual(rmse_obj2._dtype, "int32") + + def test_unweighted(self): + rmse_obj = metrics.RootMeanSquaredError() + y_true = np.array([2, 4, 6]) + y_pred = np.array([1, 3, 2]) + + rmse_obj.update_state(y_true, y_pred) + result = rmse_obj.result() + # error = [-1, -1, -4], square(error) = [1, 1, 16], mean = 18/3 = 6 + self.assertAllClose(np.sqrt(6), result, atol=1e-3) + + def test_weighted(self): + rmse_obj = metrics.RootMeanSquaredError() + y_true = np.array([2, 4, 6]) + y_pred = np.array([1, 3, 2]) + y_true = np.array([2, 4, 6, 8]) + y_pred = np.array([1, 3, 2, 3]) + sample_weight = np.array([0, 1, 0, 1]) + result = rmse_obj(y_true, y_pred, sample_weight=sample_weight) + self.assertAllClose(np.sqrt(13), result, atol=1e-3) + + +class LogCoshErrorTest(testing.TestCase): + def setup(self): + y_true = np.asarray([[1, 9, 2], [-5, -2, 6]], dtype=np.float32) + y_pred = np.asarray([[4, 8, 12], [8, 1, 3]], dtype=np.float32) + + self.batch_size = 6 + error = y_pred - y_true + self.expected_results = np.log((np.exp(error) + np.exp(-error)) / 2) + + self.y_pred = y_pred + self.y_true = y_true + + def test_config(self): + logcosh_obj = metrics.LogCoshError(name="logcosh", dtype="int32") + self.assertEqual(logcosh_obj.name, "logcosh") + self.assertEqual(logcosh_obj._dtype, "int32") + + def test_unweighted(self): + self.setup() + logcosh_obj = metrics.LogCoshError() + + logcosh_obj.update_state(self.y_true, self.y_pred) + result = logcosh_obj.result() + expected_result = np.sum(self.expected_results) / self.batch_size + self.assertAllClose(result, expected_result, atol=1e-3) + + def test_weighted(self): + self.setup() + logcosh_obj = metrics.LogCoshError(dtype="float32") + sample_weight = np.array([[1.2], [3.4]]) + result = logcosh_obj( + self.y_true, self.y_pred, sample_weight=sample_weight + ) + + sample_weight = np.asarray([1.2, 1.2, 1.2, 3.4, 3.4, 3.4]).reshape( + (2, 3) + ) + expected_result = np.multiply(self.expected_results, sample_weight) + expected_result = np.sum(expected_result) / np.sum(sample_weight) + self.assertAllClose(result, expected_result, atol=1e-3) + + +class R2ScoreTest(testing.TestCase): + def _run_test( + self, + y_true, + y_pred, + sample_weights, + class_aggregation, + num_regressors, + reference_result, + ): + r2 = metrics.R2Score(class_aggregation, num_regressors, dtype="float32") + r2.update_state(y_true, y_pred, sample_weights) + result = r2.result() + self.assertAllClose(result, reference_result, atol=1e-6) + + def test_config(self): + r2_obj = metrics.R2Score( + class_aggregation=None, num_regressors=2, dtype="float32" + ) + self.assertEqual(r2_obj.class_aggregation, None) + self.assertEqual(r2_obj.num_regressors, 2) + self.assertEqual(r2_obj.dtype, "float32") + + # Check save and restore config + r2_obj2 = metrics.R2Score.from_config(r2_obj.get_config()) + self.assertEqual(r2_obj2.class_aggregation, None) + self.assertEqual(r2_obj2.num_regressors, 2) + self.assertEqual(r2_obj2.dtype, "float32") + + @parameterized.parameters( + # class_aggregation, num_regressors, result + (None, 0, [0.37, -1.295, 0.565]), + ("uniform_average", 0, -0.12), + ("variance_weighted_average", 0, -0.12), + ) + def test_r2_sklearn_comparison( + self, class_aggregation, num_regressors, result + ): + y_true = [[0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0]] + y_pred = [[0.4, 0.5, 0.6], [0.1, 0.2, 0.3], [0.5, 0.8, 0.2]] + self._run_test( + y_true, + y_pred, + None, + class_aggregation=class_aggregation, + num_regressors=num_regressors, + reference_result=result, + ) + + @parameterized.parameters( + # class_aggregation, num_regressors, result + (None, 0, [0.17305559, -8.836666, -0.521]), + (None, 1, [0.054920673, -10.241904, -0.7382858]), + (None, 2, [-0.10259259, -12.115555, -1.0280001]), + ("uniform_average", 0, -3.0615367889404297), + ("uniform_average", 1, -3.641756534576416), + ("uniform_average", 2, -4.415382385253906), + ("variance_weighted_average", 0, -1.3710224628448486), + ("variance_weighted_average", 1, -1.7097399234771729), + ("variance_weighted_average", 2, -2.161363363265991), + ) + def test_r2_tfa_comparison(self, class_aggregation, num_regressors, result): + y_true = [[0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0]] + y_pred = [[0.4, 0.9, 1.6], [0.1, 1.2, 0.6], [1.5, 0.8, 0.6]] + sample_weights = [0.8, 0.1, 0.4] + self._run_test( + y_true, + y_pred, + sample_weights, + class_aggregation=class_aggregation, + num_regressors=num_regressors, + reference_result=result, + ) + + def test_errors(self): + # Bad class_aggregation value + with self.assertRaisesRegex( + ValueError, "Invalid value for argument `class_aggregation`" + ): + metrics.R2Score(class_aggregation="wrong") + + # Bad num_regressors value + with self.assertRaisesRegex( + ValueError, "Invalid value for argument `num_regressors`" + ): + metrics.R2Score(num_regressors=-1) + + # Bad input shape + with self.assertRaisesRegex(ValueError, "expects 2D inputs with shape"): + r2 = metrics.R2Score() + r2.update_state([0.0, 1.0], [0.0, 1.0]) diff --git a/keras/src/models/__init__.py b/keras/src/models/__init__.py new file mode 100644 index 000000000000..1f3f73c99961 --- /dev/null +++ b/keras/src/models/__init__.py @@ -0,0 +1,3 @@ +from keras.src.models.functional import Functional +from keras.src.models.model import Model +from keras.src.models.sequential import Sequential diff --git a/keras/src/models/cloning.py b/keras/src/models/cloning.py new file mode 100644 index 000000000000..30bc8940bd4b --- /dev/null +++ b/keras/src/models/cloning.py @@ -0,0 +1,419 @@ +from keras.src import backend +from keras.src import tree +from keras.src import utils +from keras.src.api_export import keras_export +from keras.src.layers import Input +from keras.src.layers import InputLayer +from keras.src.models.functional import Functional +from keras.src.models.functional import functional_like_constructor +from keras.src.models.sequential import Sequential +from keras.src.saving import serialization_lib + + +@keras_export("keras.models.clone_model") +def clone_model( + model, + input_tensors=None, + clone_function=None, + call_function=None, + recursive=False, + **kwargs, +): + """Clone a Functional or Sequential `Model` instance. + + Model cloning is similar to calling a model on new inputs, + except that it creates new layers (and thus new weights) instead + of sharing the weights of the existing layers. + + Note that + `clone_model` will not preserve the uniqueness of shared objects within the + model (e.g. a single variable attached to two distinct layers will be + restored as two separate variables). + + Args: + model: Instance of `Model` + (could be a Functional model or a Sequential model). + input_tensors: optional list of input tensors or InputLayer objects + to build the model upon. If not provided, + new `Input` objects will be created. + clone_function: Callable with signature `fn(layer)` + to be used to clone each layer in the target + model (except `Input` instances). It takes as argument the + layer instance to be cloned, and returns the corresponding layer + instance to be used in the model copy. If unspecified, this callable + defaults to the following serialization/deserialization function: + `lambda layer: layer.__class__.from_config(layer.get_config())`. + By passing a custom callable, you can customize your copy of the + model, e.g. by wrapping certain layers of interest (you might want + to replace all `LSTM` instances with equivalent + `Bidirectional(LSTM(...))` instances, for example). + Defaults to `None`. + call_function: Callable with signature + `fn(layer, *args, **kwargs)` to be used to call each + cloned layer and a set of inputs. It takes the layer instance, + the call arguments and keyword arguments, and returns the + call outputs. If unspecified, this callable defaults to + the regular `__call__()` method: + `def fn(layer, *args, **kwargs): return layer(*args, **kwargs)`. + By passing a custom callable, you can insert new layers before or + after a given layer. Note: this argument can only be used with + Functional models. + recursive: Boolean. Whether to recursively clone any Sequential + or Functional models encountered in the original + Sequential/Functional model. If `False`, + then inner models are cloned by calling `clone_function()`. + If `True`, then inner models are cloned by calling `clone_model()` + with the same `clone_function`, `call_function`, and `recursive` + arguments. Note that in this case, `call_function` + will not be propagated to any Sequential model + (since it is not applicable to Sequential models). + + Returns: + An instance of `Model` reproducing the behavior + of the original model, on top of new inputs tensors, + using newly instantiated weights. The cloned model may behave + differently from the original model if a custom `clone_function` + or `call_function` modifies a layer or layer call. + + Example: + + ```python + # Create a test Sequential model. + model = keras.Sequential([ + keras.layers.Input(shape=(728,)), + keras.layers.Dense(32, activation='relu'), + keras.layers.Dense(1, activation='sigmoid'), + ]) + # Create a copy of the test model (with freshly initialized weights). + new_model = clone_model(model) + ``` + + Using a `clone_function` to make a model deterministic by setting the + random seed everywhere: + + ```python + def clone_function(layer): + config = layer.get_config() + if "seed" in config: + config["seed"] = 1337 + return layer.__class__.from_config(config) + + new_model = clone_model(model, clone_function=clone_function) + ``` + + Using a `call_function` to add a `Dropout` layer after each `Dense` layer + (without recreating new layers): + + ```python + def call_function(layer, *args, **kwargs): + out = layer(*args, **kwargs) + if isinstance(layer, keras.layers.Dense): + out = keras.layers.Dropout(0.5)(out) + return out + + new_model = clone_model( + model, + clone_function=lambda x: x, # Reuse the same layers. + call_function=call_function, + ) + ``` + + Note that subclassed models cannot be cloned by default, + since their internal layer structure is not known. + To achieve equivalent functionality + as `clone_model` in the case of a subclassed model, simply make sure + that the model class implements `get_config()` + (and optionally `from_config()`), and call: + + ```python + new_model = model.__class__.from_config(model.get_config()) + ``` + + In the case of a subclassed model, you cannot using a custom + `clone_function`. + """ + cache = kwargs.pop("cache", None) + if kwargs: + raise ValueError( + f"Unexpected keyword argument(s): {tuple(kwargs.keys())}" + ) + + if isinstance(model, Sequential): + # Wrap clone_function to handle recursiveness and layer sharing. + clone_function = _wrap_clone_function( + clone_function, + call_function=call_function, + recursive=recursive, + cache=cache, + ) + if call_function is not None: + raise ValueError( + "`call_function` argument is not supported with Sequential " + "models. In a Sequential model, layers aren't called " + "at model-construction time (they're merely listed). " + "Use `call_function` with Functional models only. " + "Received model of " + f"type '{model.__class__.__name__}', with " + f"call_function={clone_function}" + ) + return _clone_sequential_model( + model, + clone_function=clone_function, + input_tensors=input_tensors, + ) + if isinstance(model, Functional): + # Wrap clone_function to handle recursiveness and layer sharing. + clone_function = _wrap_clone_function( + clone_function, + call_function=call_function, + recursive=recursive, + cache=cache, + ) + + # If the get_config() method is the same as a regular Functional + # model, we're safe to use _clone_functional_model (which relies + # on a Functional constructor). In the case where the get_config + # is custom, this may not necessarily work, but if clone_function + # or input_tensors are passed, we attempt it anyway + # in order to preserve backwards compatibility. + if utils.is_default(model.get_config) or ( + clone_function or input_tensors + ): + return _clone_functional_model( + model, + clone_function=clone_function, + call_function=call_function, + input_tensors=input_tensors, + ) + + # Case of a custom model class + if clone_function or input_tensors: + raise ValueError( + "Arguments `clone_function` and `input_tensors` " + "are only supported for Sequential models " + "or Functional models. Received model of " + f"type '{model.__class__.__name__}', with " + f"clone_function={clone_function} and " + f"input_tensors={input_tensors}" + ) + if call_function is not None: + raise ValueError( + "Argument `call_function` is only supported " + "for Functional models. Received model of " + f"type '{model.__class__.__name__}', with " + f"call_function={clone_function}" + ) + config = serialization_lib.serialize_keras_object(model) + return serialization_lib.deserialize_keras_object( + config, custom_objects={model.__class__.__name__: model.__class__} + ) + + +def _wrap_clone_function( + clone_function, call_function=None, recursive=False, cache=None +): + """Wrapper to handle recursiveness and layer sharing.""" + if clone_function is None: + + def _clone_layer(layer): + return layer.__class__.from_config(layer.get_config()) + + clone_function = _clone_layer + + if cache is None: + cache = {} + + def wrapped_clone_function(layer): + if id(layer) in cache: + return cache[id(layer)] + if recursive: + if isinstance(layer, Sequential): + # Note: Sequential doesn't support call_function. + clone = clone_model( + layer, + clone_function=clone_function, + cache=cache, + ) + cache[id(layer)] = clone + return clone + elif isinstance(layer, Functional): + clone = clone_model( + layer, + clone_function=clone_function, + call_function=call_function, + cache=cache, + ) + cache[id(layer)] = clone + return clone + clone = clone_function(layer) + cache[id(layer)] = clone + return clone + + return wrapped_clone_function + + +def _clone_sequential_model(model, clone_function, input_tensors=None): + """Clone a `Sequential` model instance. + + Model cloning is similar to calling a model on new inputs, + except that it creates new layers (and thus new weights) instead + of sharing the weights of the existing layers. + + Args: + model: Instance of `Sequential`. + input_tensors: optional list of input tensors + to build the model upon. If not provided, + placeholders will be created. + clone_function: callable to be applied on non-input layers in the model. + By default, it clones the layer (without copying the weights). + + Returns: + An instance of `Sequential` reproducing the behavior + of the original model, on top of new inputs tensors, + using newly instantiated weights. + """ + + if not isinstance(model, Sequential): + raise ValueError( + "Expected `model` argument " + "to be a `Sequential` model instance. " + f"Received: model={model}" + ) + + if not callable(clone_function): + raise ValueError( + "Expected `clone_function` argument to be a callable. " + f"Received: clone_function={clone_function}" + ) + + new_layers = [clone_function(layer) for layer in model.layers] + + if isinstance(model._layers[0], InputLayer): + ref_input_layer = model._layers[0] + input_name = ref_input_layer.name + input_batch_shape = ref_input_layer.batch_shape + input_dtype = ref_input_layer._dtype + else: + input_name = None + input_dtype = None + input_batch_shape = None + + if input_tensors is not None: + if isinstance(input_tensors, (list, tuple)): + if len(input_tensors) != 1: + raise ValueError( + "Argument `input_tensors` must contain a single tensor." + ) + input_tensors = input_tensors[0] + if not isinstance(input_tensors, backend.KerasTensor): + raise ValueError( + "Argument `input_tensors` must be a KerasTensor. " + f"Received invalid value: input_tensors={input_tensors}" + ) + inputs = Input( + tensor=input_tensors, + name=input_name, + ) + new_layers = [inputs] + new_layers + else: + if input_batch_shape is not None: + inputs = Input( + batch_shape=input_batch_shape, + dtype=input_dtype, + name=input_name, + ) + new_layers = [inputs] + new_layers + cloned_model = Sequential( + new_layers, name=model.name, trainable=model.trainable + ) + + # If model compiled already then set same to cloned model + if model.compiled: + compiled_config = model.get_compile_config() + cloned_model.compile_from_config(compiled_config) + return cloned_model + + +def _clone_functional_model( + model, clone_function, input_tensors=None, call_function=None +): + """Clone a `Functional` model instance. + + Model cloning is similar to calling a model on new inputs, + except that it creates new layers (and thus new weights) instead + of sharing the weights of the existing layers. + + Input layers are always cloned. + + Args: + model: Instance of `Functional`. + input_tensors: optional list of input tensors + to build the model upon. If not provided, + placeholders will be created. + clone_function: callable to be applied on non-input layers in the model. + By default, it clones the layer (without copying the weights). + + Returns: + An instance of `Functional` reproducing the behavior + of the original model, on top of new inputs tensors, + using newly instantiated weights. + """ + + if not callable(clone_function): + raise ValueError( + "Expected `clone_function` argument to be a callable. " + f"Received: clone_function={clone_function}" + ) + + if not isinstance(model, Functional): + raise ValueError( + "Expected `model` argument " + f"to be a Functional Model instance. Received: model={model}" + ) + + if input_tensors is not None: + if not all( + isinstance(x, backend.KerasTensor) + for x in tree.flatten(input_tensors) + ): + raise ValueError( + "All entries in `input_tensors` must be KerasTensors. " + f"Received invalid values: inputs_tensors={input_tensors}" + ) + try: + tree.assert_same_structure(input_tensors, model.input) + except ValueError as e: + raise ValueError( + "`input_tensors` must have the same structure as model.input" + f"\nReference structure: {model.input}" + f"\nReceived structure: {input_tensors}" + ) from e + else: + input_tensors = tree.map_structure( + lambda x: Input(batch_shape=x.shape, dtype=x.dtype, name=x.name), + model.input, + ) + + def operation_fn(layer): + new_layer = clone_function(layer) + return new_layer + + output_tensors = model._run_through_graph( + input_tensors, + operation_fn=operation_fn, + call_fn=call_function, + ) + + if functional_like_constructor(model.__class__): + new_model = model.__class__( + input_tensors, output_tensors, name=model.name + ) + else: + # This may be incorrect: the new model will end up having a different + # class than the original. However various existing models rely + # on this behavior, so we keep it. + new_model = Functional(input_tensors, output_tensors, name=model.name) + if model.compiled: + compiled_config = model.get_compile_config() + new_model.compile_from_config(compiled_config) + return new_model diff --git a/keras/src/models/cloning_test.py b/keras/src/models/cloning_test.py new file mode 100644 index 000000000000..b370332c87e2 --- /dev/null +++ b/keras/src/models/cloning_test.py @@ -0,0 +1,253 @@ +import numpy as np +import pytest +from absl.testing import parameterized + +from keras.src import layers +from keras.src import models +from keras.src import ops +from keras.src import testing +from keras.src import tree +from keras.src.models.cloning import clone_model + + +def get_mlp_functional_model(shared_layers=False): + inputs = layers.Input(shape=(3,)) + x = layers.Dense(2)(inputs) + if shared_layers: + layer = layers.Dense(2, name="shared") + x = layer(x) + x = layer(x) + outputs = layers.Dense(2)(x) + model = models.Model(inputs, outputs) + return model + + +def get_nested_functional_model(): + inputs = layers.Input(shape=(4,)) + x = layers.Dense(3)(inputs) + mlp = get_mlp_functional_model() + x = mlp(x) + outputs = layers.Dense(2)(x) + model = models.Model(inputs, outputs) + return model + + +def get_nested_sequential_model(): + model = models.Sequential() + model.add(layers.Dense(2)) + model.add(get_sequential_model(explicit_input=False)) + model.add(layers.Dense(2)) + return model + + +def get_cnn_functional_model(shared_layers=False): + inputs = layers.Input(shape=(7, 3)) + x = layers.Conv1D(2, 2, padding="same")(inputs) + if shared_layers: + layer = layers.Conv1D(2, 2, padding="same", name="shared") + x = layer(x) + x = layer(x) + outputs = layers.Conv1D(2, 2, padding="same")(x) + model = models.Model(inputs, outputs) + return model + + +def get_sequential_model(explicit_input=True): + model = models.Sequential() + if explicit_input: + model.add(layers.Input(shape=(3,))) + model.add(layers.Dense(2)) + model.add(layers.Dense(2)) + return model + + +def get_cnn_sequential_model(explicit_input=True): + model = models.Sequential() + if explicit_input: + model.add(layers.Input(shape=(7, 3))) + model.add(layers.Conv1D(2, 2, padding="same")) + model.add(layers.Conv1D(2, 2, padding="same")) + return model + + +def get_subclassed_model(): + class ExampleModel(models.Model): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.d1 = layers.Dense(2) + self.d2 = layers.Dense(2) + + def call(self, x): + return self.d2(self.d1(x)) + + return ExampleModel() + + +@pytest.mark.requires_trainable_backend +class CloneModelTest(testing.TestCase): + def assert_models_equal(self, model1, model2, ref_input): + result1 = model1(ref_input) + result2 = model2(ref_input) + for r1, r2 in zip(tree.flatten(result1), tree.flatten(result2)): + self.assertAllClose( + ops.convert_to_numpy(r1), ops.convert_to_numpy(r2) + ) + + def assert_weights_equal(self, model1, model2): + for a, b in zip(model1.weights, model2.weights): + self.assertAllClose(a.numpy(), b.numpy()) + + @parameterized.named_parameters( + ("mlp_functional", get_mlp_functional_model), + ("cnn_functional", get_cnn_functional_model, True), + ("sequential", get_sequential_model), + ( + "deferred_sequential", + lambda: get_sequential_model(explicit_input=False), + ), + ("subclassed", get_subclassed_model), + ) + def test_cloning_correctness(self, model_fn, is_conv=False): + ref_input = np.random.random((2, 7, 3) if is_conv else (2, 3)) + model = model_fn() + new_model = clone_model(model) + model(ref_input) # Maybe needed to build the model + new_model(ref_input) # Maybe needed to build the model + new_model.set_weights(model.get_weights()) + self.assert_models_equal(model, new_model, ref_input) + + @parameterized.named_parameters( + ("mlp_functional", get_mlp_functional_model), + ("cnn_functional", get_cnn_functional_model), + ("sequential", get_sequential_model), + ) + def test_custom_clone_function(self, model_fn): + def clone_function(layer): + config = layer.get_config() + config["name"] = f"{config['name']}_custom" + return layer.__class__.from_config(config) + + model = model_fn() + new_model = clone_model(model, clone_function=clone_function) + for l1, l2 in zip(model.layers, new_model.layers): + if not isinstance(l1, layers.InputLayer): + self.assertEqual(l2.name, f"{l1.name}_custom") + + @parameterized.named_parameters( + ("cnn_functional", get_cnn_functional_model), + ("cnn_sequential", get_cnn_sequential_model), + ( + "cnn_sequential_noinputlayer", + lambda: get_cnn_sequential_model(explicit_input=False), + ), + ) + def test_input_tensors(self, model_fn): + ref_input = np.random.random((2, 7, 3)) + model = model_fn() + model(ref_input) # Maybe needed to get model inputs if no Input layer + input_tensor = model.inputs[0] + new_model = clone_model(model, input_tensors=input_tensor) + tree.assert_same_structure(model.inputs, new_model.inputs) + tree.assert_same_structure(model.outputs, new_model.outputs) + + def test_shared_layers_cloning(self): + model = get_mlp_functional_model(shared_layers=True) + new_model = clone_model(model) + self.assertLen(new_model.layers, 4) + + def test_structured_io_cloning(self): + x = layers.Input((3,)) + y = layers.Input((3,)) + z1 = x + y + z2 = layers.Dense(5)(z1) + inputs = dict(x=x, y=y) + outputs = dict(z1=z1, z2=z2) + model0 = models.Model(inputs, outputs) + + model = clone_model(model0) + tree.assert_same_structure(model.input, inputs) + tree.assert_same_structure(model.output, outputs) + + model = clone_model(model0, input_tensors=inputs) + tree.assert_same_structure(model.input, inputs) + tree.assert_same_structure(model.output, outputs) + + with self.assertRaisesRegex( + ValueError, + "`input_tensors` must have the same structure as model.input", + ): + model = clone_model(model0, input_tensors=(x, y)) + + def test_call_fn(self): + model = get_mlp_functional_model(shared_layers=False) + + def call_function(layer, *args, **kwargs): + out = layer(*args, **kwargs) + if isinstance(layer, layers.Dense): + out = layers.Dropout(0.5)(out) + return out + + new_model = clone_model( + model, + clone_function=lambda x: x, # Reuse the same layers. + call_function=call_function, + ) + self.assertLen(model.layers, 3) + self.assertLen(new_model.layers, 5) + self.assertIsInstance(new_model.layers[2], layers.Dropout) + self.assertIsInstance(new_model.layers[4], layers.Dropout) + ref_input = np.random.random((2, 3)) + self.assert_models_equal(model, new_model, ref_input) + + def test_recursive(self): + model = get_nested_functional_model() + + def call_function(layer, *args, **kwargs): + out = layer(*args, **kwargs) + if isinstance(layer, layers.Dense): + out = layers.Dropout(0.5)(out) + return out + + new_model = clone_model( + model, + clone_function=lambda x: x, # Reuse the same layers. + call_function=call_function, + recursive=True, + ) + self.assertLen(model._flatten_layers(), 8) + self.assertLen(new_model._flatten_layers(), 12) + self.assertIsInstance(new_model.layers[3].layers[2], layers.Dropout) + self.assertIsInstance(new_model.layers[3].layers[4], layers.Dropout) + ref_input = np.random.random((2, 4)) + self.assert_models_equal(model, new_model, ref_input) + + # Sequential. + def clone_function(layer): + layer = layer.__class__.from_config(layer.get_config()) + layer.flag = True + return layer + + model = get_nested_sequential_model() + new_model = clone_model( + model, + clone_function=clone_function, + recursive=True, + ) + ref_input = np.random.random((2, 3)) + model(ref_input) # Maybe needed to build the model + new_model(ref_input) # Maybe needed to build the model + new_model.set_weights(model.get_weights()) + self.assert_models_equal(model, new_model, ref_input) + for l1, l2 in zip(model._flatten_layers(), new_model._flatten_layers()): + if isinstance(l2, layers.Dense): + self.assertFalse(hasattr(l1, "flag")) + self.assertTrue(hasattr(l2, "flag")) + + def test_compiled_model_cloning(self): + model = models.Sequential() + model.add(layers.Input((3,))) + model.add(layers.Dense(5, activation="relu")) + model.add(layers.Dense(1, activation="sigmoid")) + model.compile(optimizer="adam", loss="binary_crossentropy") + cloned_model = clone_model(model) + self.assertEqual(model.compiled, cloned_model.compiled) diff --git a/keras/src/models/functional.py b/keras/src/models/functional.py new file mode 100644 index 000000000000..4cbdb44cf31f --- /dev/null +++ b/keras/src/models/functional.py @@ -0,0 +1,885 @@ +import copy +import inspect +import typing +import warnings + +from keras.src import backend +from keras.src import ops +from keras.src import tree +from keras.src.backend.common import global_state +from keras.src.layers.core.input_layer import Input +from keras.src.layers.core.input_layer import InputLayer +from keras.src.layers.input_spec import InputSpec +from keras.src.layers.layer import Layer +from keras.src.legacy.saving import saving_utils +from keras.src.legacy.saving import serialization as legacy_serialization +from keras.src.models.model import Model +from keras.src.ops.function import Function +from keras.src.ops.function import _build_map +from keras.src.ops.function import make_node_key +from keras.src.ops.node import KerasHistory +from keras.src.ops.node import Node +from keras.src.ops.operation import Operation +from keras.src.saving import serialization_lib +from keras.src.utils import tracking + + +class Functional(Function, Model): + """A `Functional` model is a `Model` defined as a directed graph of layers. + + Three types of `Model` exist: subclassed `Model`, `Functional` model, + and `Sequential` (a special case of `Functional`). + + A `Functional` model can be instantiated by passing two arguments to + `__init__()`. The first argument is the `keras.Input` objects + that represent the inputs to the model. + The second argument specifies the output tensors that represent + the outputs of this model. Both arguments can be a nested structure + of tensors. + + Example: + + ``` + inputs = {'x1': keras.Input(shape=(10,), name='x1'), + 'x2': keras.Input(shape=(1,), name='x2')} + t = keras.layers.Dense(1, activation='relu')(inputs['x1']) + outputs = keras.layers.Add()([t, inputs['x2']]) + model = keras.Model(inputs, outputs) + ``` + + A `Functional` model constructed using the Functional API can also + include raw Keras 3 ops. + + Example: + + ```python + inputs = keras.Input(shape=(10,)) + x = keras.layers.Dense(1)(inputs) + outputs = ops.nn.relu(x) + model = keras.Model(inputs, outputs) + ``` + + A new `Functional` model can also be created by using the + intermediate tensors. This enables you to quickly extract sub-components + of the model. + + Example: + + ```python + inputs = keras.Input(shape=(None, None, 3)) + processed = keras.layers.RandomCrop(width=32, height=32)(inputs) + conv = keras.layers.Conv2D(filters=2, kernel_size=3)(processed) + pooling = keras.layers.GlobalAveragePooling2D()(conv) + feature = keras.layers.Dense(10)(pooling) + + full_model = keras.Model(inputs, feature) + backbone = keras.Model(processed, conv) + activations = keras.Model(conv, feature) + ``` + + Note that the `backbone` and `activations` models are not + created with `keras.Input` objects, but with the tensors + that are originated from `keras.Input` objects. + Under the hood, the layers and weights will + be shared across these models, so that user can train the `full_model`, and + use `backbone` or `activations` to do feature extraction. + The inputs and outputs of the model can be nested structures of tensors as + well, and the created models are standard `Functional` model that support + all the existing API. + + Args: + inputs: List of input tensors (must be created via `keras.Input()` + or originated from `keras.Input()`). + outputs: List of output tensors. + name: String, optional. Name of the model. + trainable: Boolean, optional. If the model's variables should be + trainable. + """ + + def __new__(cls, *args, **kwargs): + return typing.cast(cls, super().__new__(cls)) + + @tracking.no_automatic_dependency_tracking + def __init__(self, inputs, outputs, name=None, **kwargs): + if isinstance(inputs, dict): + for k, v in inputs.items(): + if isinstance(v, backend.KerasTensor) and k != v.name: + warnings.warn( + "When providing `inputs` as a dict, all keys in the " + "dict must match the names of the corresponding " + f"tensors. Received key '{k}' mapping to value {v} " + f"which has name '{v.name}'. Change the tensor name to " + f"'{k}' (via `Input(..., name='{k}')`)" + ) + + trainable = kwargs.pop("trainable", None) + flat_inputs = tree.flatten(inputs) + flat_outputs = tree.flatten(outputs) + for x in flat_inputs: + if not isinstance(x, backend.KerasTensor): + raise ValueError( + "All `inputs` values must be KerasTensors. Received: " + f"inputs={inputs} including invalid value {x} of " + f"type {type(x)}" + ) + for x in flat_outputs: + if not isinstance(x, backend.KerasTensor): + raise ValueError( + "All `outputs` values must be KerasTensors. Received: " + f"outputs={outputs} including invalid value {x} of " + f"type {type(x)}" + ) + + if not all(is_input_keras_tensor(t) for t in flat_inputs): + inputs, outputs = clone_graph_nodes(inputs, outputs) + + Function.__init__(self, inputs, outputs, name=name) + + if trainable is not None: + self.trainable = trainable + + self._layers = self.layers + self.build(None) + # We will convert directly (to the correct dtype per input). + self._convert_input_args = False + self._allow_non_tensor_positional_args = True + output_layers = [x._keras_history[0] for x in self.outputs] + self.output_names = [x.name for x in output_layers] + + def _lock_state(self): + # Unlike other layers, we allow Functional state to be mutable after + # build. E.g. to attach a layer to a model that is not part of the + # functional DAG. + pass + + def _obj_type(self): + return "Functional" + + @property + def layers(self): + layers = [] + for operation in self._operations: + if isinstance(operation, Layer): + layers.append(operation) + return layers + + @layers.setter + def layers(self, _): + raise AttributeError( + "`Model.layers` attribute is reserved and should not be used. " + "Please use another name." + ) + + def call(self, inputs, training=None, mask=None, **kwargs): + # Add support for training, masking + inputs = self._standardize_inputs(inputs) + if mask is None: + masks = [None] * len(inputs) + else: + masks = tree.flatten(mask) + for x, mask in zip(inputs, masks): + if mask is not None: + backend.set_keras_mask(x, mask) + outputs = self._run_through_graph( + inputs, + operation_fn=lambda op: operation_fn( + op, training=training, **kwargs + ), + ) + return unpack_singleton(outputs) + + def compute_output_spec(self, inputs, training=None, mask=None): + # From Function + return super().compute_output_spec(inputs) + + def compute_output_shape(self, input_shape): + # From Function + return super().compute_output_shape(input_shape) + + def build(self, input_shape): + self.built = True + + @property + def input_shape(self): + input_shapes = tree.map_structure(lambda x: x.shape, self.inputs) + if isinstance(input_shapes, list) and len(input_shapes) == 1: + return input_shapes[0] + return input_shapes + + @property + def output_shape(self): + output_shapes = tree.map_structure(lambda x: x.shape, self.outputs) + if isinstance(output_shapes, list) and len(output_shapes) == 1: + return output_shapes[0] + return output_shapes + + def _assert_input_compatibility(self, *args): + return super(Model, self)._assert_input_compatibility(*args) + + def _maybe_warn_inputs_struct_mismatch(self, inputs, raise_exception=False): + try: + # We first normalize to tuples before performing the check to + # suppress warnings when encountering mismatched tuples and lists. + tree.assert_same_structure( + tree.lists_to_tuples(inputs), + tree.lists_to_tuples(self._inputs_struct), + ) + except: + model_inputs_struct = tree.map_structure( + lambda x: x.name, self._inputs_struct + ) + inputs_struct = tree.map_structure( + lambda x: f"Tensor(shape={x.shape})", inputs + ) + msg = ( + "The structure of `inputs` doesn't match the expected " + f"structure.\nExpected: {model_inputs_struct}\n" + f"Received: inputs={inputs_struct}" + ) + if raise_exception: + raise ValueError(msg) + warnings.warn(msg) + + def _convert_inputs_to_tensors(self, flat_inputs): + converted = [] + for x, input in zip(flat_inputs, self._inputs): + if x is None: # TODO: check if optional + converted.append(x) + else: + converted.append( + ops.convert_to_tensor( + x, dtype=input.dtype, sparse=input.sparse + ) + ) + return converted + + def _adjust_input_rank(self, flat_inputs): + flat_ref_shapes = [x.shape for x in self._inputs] + adjusted = [] + for x, ref_shape in zip(flat_inputs, flat_ref_shapes): + if x is None: + adjusted.append(x) + continue + x_rank = len(x.shape) + ref_rank = len(ref_shape) + if x_rank == ref_rank: + adjusted.append(x) + continue + if x_rank == ref_rank + 1: + if x.shape[-1] == 1: + adjusted.append(ops.squeeze(x, axis=-1)) + continue + if x_rank == ref_rank - 1: + if ref_shape[-1] == 1: + adjusted.append(ops.expand_dims(x, axis=-1)) + continue + raise ValueError( + f"Invalid input shape for input {x}. Expected shape " + f"{ref_shape}, but input has incompatible shape {x.shape}" + ) + # Add back metadata. + for i in range(len(flat_inputs)): + if hasattr(flat_inputs[i], "_keras_history"): + adjusted[i]._keras_history = flat_inputs[i]._keras_history + mask = backend.get_keras_mask(flat_inputs[i]) + if mask is not None: + backend.set_keras_mask(adjusted[i], mask) + return adjusted + + def _standardize_inputs(self, inputs): + raise_exception = False + if ( + isinstance(self._inputs_struct, list) + and len(self._inputs_struct) == 1 + and ops.is_tensor(inputs) + ): + inputs = [inputs] + elif isinstance(inputs, dict) and not isinstance( + self._inputs_struct, dict + ): + # This is to avoid warning + # when we have reconcilable dict/list structs + if hasattr(self._inputs_struct, "__len__") and all( + isinstance(i, backend.KerasTensor) for i in self._inputs_struct + ): + expected_keys = set(i.name for i in self._inputs_struct) + keys = set(inputs.keys()) + if expected_keys.issubset(keys): + inputs = [inputs[i.name] for i in self._inputs_struct] + else: + raise_exception = True + elif isinstance(self._inputs_struct, backend.KerasTensor): + if self._inputs_struct.name in inputs: + inputs = [inputs[self._inputs_struct.name]] + else: + raise_exception = True + else: + raise_exception = True + if ( + isinstance(self._inputs_struct, dict) + and not isinstance(inputs, dict) + and list(self._inputs_struct.keys()) + != sorted(self._inputs_struct.keys()) + ): + raise_exception = True + self._maybe_warn_inputs_struct_mismatch( + inputs, raise_exception=raise_exception + ) + + flat_inputs = tree.flatten(inputs) + flat_inputs = self._convert_inputs_to_tensors(flat_inputs) + return self._adjust_input_rank(flat_inputs) + + @property + def input(self): + # For backwards compatibility, + # override `input` to retrieve the used-provided + # constructor inputs + return self._inputs_struct + + @property + def output(self): + return self._outputs_struct + + def add_loss(self, loss): + # Symbolic only. TODO + raise NotImplementedError + + @property + def input_spec(self): + if hasattr(self, "_manual_input_spec"): + return self._manual_input_spec + + def shape_with_no_batch_size(x): + x = list(x) + if x: + x[0] = None + return tuple(x) + + def make_spec_for_tensor(x, name=None): + optional = False + if isinstance(x._keras_history[0], InputLayer): + if x._keras_history[0].optional: + optional = True + return InputSpec( + shape=shape_with_no_batch_size(x.shape), + allow_last_axis_squeeze=True, + name=x._keras_history[0].name if name is None else name, + optional=optional, + ) + + if isinstance(self._inputs_struct, dict): + if all( + isinstance(x, backend.KerasTensor) + for x in self._inputs_struct.values() + ): + # Case where `_nested_inputs` is a plain dict of Inputs. + names = sorted(self._inputs_struct.keys()) + return [ + make_spec_for_tensor(self._inputs_struct[name], name=name) + for name in names + ] + return None # Deeply nested dict: skip checks. + return [make_spec_for_tensor(x) for x in self.inputs] + + @input_spec.setter + def input_spec(self, value): + self._manual_input_spec = value + + def get_config(self): + if not functional_like_constructor(self.__class__): + # Subclassed networks are not serializable + # (unless serialization is implemented by + # the author of the subclassed network). + return Model.get_config(self) + + config = { + "name": self.name, + "trainable": self.trainable, + } + # Build a map from a layer unique name (make_node_key) + # to the index of the nodes that are saved in the config. + # Only nodes in network_nodes are saved. + node_reindexing_map = {} + for operation in self.operations: + if issubclass(operation.__class__, Functional): + # Functional models start with a pre-existing node + # linking their input to output. + kept_nodes = 1 + else: + kept_nodes = 0 + for original_node_index, node in enumerate( + operation._inbound_nodes + ): + node_key = make_node_key(operation, original_node_index) + if node_key in self._nodes: + # i.e. we mark it to be saved + node_reindexing_map[node_key] = kept_nodes + kept_nodes += 1 + + # serialize and save the layers in layer_configs + layer_configs = [] + for operation in self.operations: # From the earliest layers on. + filtered_inbound_nodes = [] + for original_node_index, node in enumerate( + operation._inbound_nodes + ): + node_key = make_node_key(operation, original_node_index) + if node_key in self._nodes: + # The node is relevant to the model: + # add to filtered_inbound_nodes. + node_data = serialize_node(node, own_nodes=self._nodes) + if node_data is not None: + filtered_inbound_nodes.append(node_data) + + serialize_obj_fn = serialization_lib.serialize_keras_object + if global_state.get_global_attribute("use_legacy_config", False): + # Legacy format serialization used for H5 and SavedModel + serialize_obj_fn = legacy_serialization.serialize_keras_object + layer_config = serialize_obj_fn(operation) + layer_config["name"] = operation.name + layer_config["inbound_nodes"] = filtered_inbound_nodes + layer_configs.append(layer_config) + config["layers"] = layer_configs + + # Gather info about inputs and outputs. + def get_tensor_config(tensor): + operation = tensor._keras_history[0] + node_index = tensor._keras_history[1] + tensor_index = tensor._keras_history[2] + node_key = make_node_key(operation, node_index) + assert node_key in self._nodes + new_node_index = node_reindexing_map[node_key] + return [operation.name, new_node_index, tensor_index] + + def map_tensors(tensors): + return tree.map_structure(get_tensor_config, tensors) + + config["input_layers"] = map_tensors(self._inputs_struct) + config["output_layers"] = map_tensors(self._outputs_struct) + return copy.deepcopy(config) + + +def functional_from_config(cls, config, custom_objects=None): + """Instantiates a Functional model from its config (from `get_config()`). + + Args: + cls: Class of the model, e.g. a custom subclass of `Model`. + config: Output of `get_config()` for the original model instance. + custom_objects: Optional dict of custom objects. + + Returns: + An instance of `cls`. + """ + # Layer instances created during + # the graph reconstruction process + created_layers = {} + + # Dictionary mapping layer instances to + # node data that specifies a layer call. + # It acts as a queue that maintains any unprocessed + # layer call until it becomes possible to process it + # (i.e. until the input tensors to the call all exist). + unprocessed_nodes = {} + + def add_unprocessed_node(layer, node_data): + """Add node to layer list + + Arg: + layer: layer object + node_data: Node data specifying layer call + """ + if layer not in unprocessed_nodes: + unprocessed_nodes[layer] = [node_data] + else: + unprocessed_nodes[layer].append(node_data) + + def process_node(layer, node_data): + """Reconstruct node by linking to inbound layers + + Args: + layer: Layer to process + node_data: List of layer configs + """ + args, kwargs = deserialize_node(node_data, created_layers) + # Call layer on its inputs, thus creating the node + # and building the layer if needed. + layer(*args, **kwargs) + + def process_layer(layer_data): + """Deserializes a layer and index its inbound nodes. + + Args: + layer_data: layer config dict. + """ + layer_name = layer_data["name"] + + # Instantiate layer. + if "module" not in layer_data: + # Legacy format deserialization (no "module" key) + # used for H5 and SavedModel formats + layer = saving_utils.model_from_config( + layer_data, custom_objects=custom_objects + ) + else: + layer = serialization_lib.deserialize_keras_object( + layer_data, custom_objects=custom_objects + ) + if not isinstance(layer, Operation): + raise ValueError( + "Unexpected object from deserialization, expected a layer or " + f"operation, got a {type(layer)}" + ) + created_layers[layer_name] = layer + + # Gather layer inputs. + inbound_nodes_data = layer_data["inbound_nodes"] + for node_data in inbound_nodes_data: + # We don't process nodes (i.e. make layer calls) + # on the fly because the inbound node may not yet exist, + # in case of layer shared at different topological depths + # (e.g. a model such as A(B(A(B(x))))) + add_unprocessed_node(layer, node_data) + + # Extract config used to instantiate Functional model from the config. The + # remaining config will be passed as keyword arguments to the Model + # constructor. + functional_config = {} + for key in ["layers", "input_layers", "output_layers"]: + functional_config[key] = config.pop(key) + for key in ["name", "trainable"]: + if key in config: + functional_config[key] = config.pop(key) + else: + functional_config[key] = None + + # First, we create all layers and enqueue nodes to be processed + for layer_data in functional_config["layers"]: + process_layer(layer_data) + + # Then we process nodes in order of layer depth. + # Nodes that cannot yet be processed (if the inbound node + # does not yet exist) are re-enqueued, and the process + # is repeated until all nodes are processed. + while unprocessed_nodes: + for layer_data in functional_config["layers"]: + layer = created_layers[layer_data["name"]] + + # Process all nodes in layer, if not yet processed + if layer in unprocessed_nodes: + node_data_list = unprocessed_nodes[layer] + + # Process nodes in order + node_index = 0 + while node_index < len(node_data_list): + node_data = node_data_list[node_index] + try: + process_node(layer, node_data) + + # If the node does not have all inbound layers + # available, stop processing and continue later + except IndexError: + break + + node_index += 1 + + # If not all nodes processed then store unprocessed nodes + if node_index < len(node_data_list): + unprocessed_nodes[layer] = node_data_list[node_index:] + # If all nodes processed remove the layer + else: + del unprocessed_nodes[layer] + + # Create list of input and output tensors and return new class + name = functional_config["name"] + trainable = functional_config["trainable"] + + def get_tensor(layer_name, node_index, tensor_index): + assert layer_name in created_layers + layer = created_layers[layer_name] + if isinstance(layer, Functional): + # Functional models start out with a built-in node. + node_index -= 1 + layer_output_tensors = layer._inbound_nodes[node_index].output_tensors + return layer_output_tensors[tensor_index] + + def map_tensors(tensors): + if ( + isinstance(tensors, list) + and len(tensors) == 3 + and isinstance(tensors[0], str) + ): + # Leaf + return get_tensor(*tensors) + if isinstance(tensors, dict): + return {k: map_tensors(v) for k, v in tensors.items()} + if isinstance(tensors, tuple): + return tuple([map_tensors(v) for v in tensors]) + return [map_tensors(v) for v in tensors] + + input_tensors = map_tensors(functional_config["input_layers"]) + output_tensors = map_tensors(functional_config["output_layers"]) + + return cls( + inputs=input_tensors, + outputs=output_tensors, + name=name, + trainable=trainable, + **config, + ) + + +def operation_fn(operation, **call_context_args): + """Wraps each op to inject the call-context args.""" + + def call(*args, **kwargs): + # Propagate all registered call-context args + for name, value in call_context_args.items(): + if ( + name in getattr(operation, "_call_context_args", {}) + and value is not None + ): + kwargs[name] = value + + return operation(*args, **kwargs) + + return call + + +def functional_like_constructor(cls): + init_args = inspect.getfullargspec(cls.__init__).args[1:] + functional_init_args = inspect.getfullargspec(Functional.__init__).args[1:] + if init_args == functional_init_args: + return True + return False + + +def unpack_singleton(x): + if isinstance(x, (list, tuple)) and len(x) == 1: + return x[0] + return x + + +def serialize_node(node, own_nodes=()): + if not node.input_tensors: + # Does not need to be serialized. + return + + def serialize_keras_tensor(x): + # Serialize KerasTensor while converting + # node indices to only include nodes relevant to `own_nodes`. + if isinstance(x, backend.KerasTensor): + operation, node_index, tensor_index = x._keras_history + irrelevant_node_count = 0 + for i, node in enumerate(operation._inbound_nodes[:node_index]): + node_key = make_node_key(operation, i) + if node_key not in own_nodes: + irrelevant_node_count += 1 + x._keras_history = KerasHistory( + operation, node_index - irrelevant_node_count, tensor_index + ) + serialized = serialization_lib.serialize_keras_object(x) + x._keras_history = KerasHistory(operation, node_index, tensor_index) + return serialized + return x + + args = node.arguments.args + kwargs = node.arguments.kwargs + + args = tree.map_structure(serialize_keras_tensor, args) + kwargs = tree.map_structure(serialize_keras_tensor, kwargs) + return { + "args": serialization_lib.serialize_keras_object(args), + "kwargs": serialization_lib.serialize_keras_object(kwargs), + } + + +def deserialize_node(node_data, created_layers): + """Return (args, kwargs) for calling the node layer.""" + if not node_data: + return [], {} + + if isinstance(node_data, list): + # Legacy case. + input_tensors = [] + for input_data in node_data: + inbound_layer_name = input_data[0] + inbound_node_index = input_data[1] + inbound_tensor_index = input_data[2] + if len(input_data) == 3: + kwargs = {} + elif len(input_data) == 4: + kwargs = input_data[3] + else: + raise ValueError( + "Cannot deserialize the model (invalid config data?)" + ) + inbound_layer = created_layers[inbound_layer_name] + + # Raise an error if the corresponding layer node + # has not yet been created + if len(inbound_layer._inbound_nodes) <= inbound_node_index: + raise IndexError( + "Layer node index out of bounds.\n" + f"inbound_layer = {inbound_layer}\n" + "inbound_layer._inbound_nodes = " + f"{inbound_layer._inbound_nodes}\n" + f"inbound_node_index = {inbound_node_index}" + ) + inbound_node = inbound_layer._inbound_nodes[inbound_node_index] + input_tensors.append( + inbound_node.output_tensors[inbound_tensor_index] + ) + return [unpack_singleton(input_tensors)], kwargs + + args = serialization_lib.deserialize_keras_object(node_data["args"]) + kwargs = serialization_lib.deserialize_keras_object(node_data["kwargs"]) + + def convert_revived_tensor(x): + if isinstance(x, backend.KerasTensor): + history = x._pre_serialization_keras_history + if history is None: + return x + layer = created_layers.get(history[0], None) + if layer is None: + raise ValueError(f"Unknown layer: {history[0]}") + inbound_node_index = history[1] + inbound_tensor_index = history[2] + if len(layer._inbound_nodes) <= inbound_node_index: + raise IndexError( + "Layer node index out of bounds.\n" + f"inbound_layer = {layer}\n" + f"inbound_layer._inbound_nodes = {layer._inbound_nodes}\n" + f"inbound_node_index = {inbound_node_index}" + ) + inbound_node = layer._inbound_nodes[inbound_node_index] + return inbound_node.output_tensors[inbound_tensor_index] + return x + + args = tree.map_structure(convert_revived_tensor, args) + kwargs = tree.map_structure(convert_revived_tensor, kwargs) + return args, kwargs + + +def is_input_keras_tensor(x): + ( + operation, + node_index, + _, + ) = x._keras_history + node = operation._inbound_nodes[node_index] + return node.is_input + + +def clone_single_keras_tensor(x): + return backend.KerasTensor( + shape=x.shape, dtype=x.dtype, sparse=x.sparse, name=f"{x.name}_clone" + ) + + +def clone_keras_tensors(tensors, kt_id_mapping): + def swap(x): + if not isinstance(x, backend.KerasTensor): + return x + if id(x) in kt_id_mapping: + return kt_id_mapping[id(x)] + new_x = clone_single_keras_tensor(x) + kt_id_mapping[id(x)] = new_x + return new_x + + return tree.map_structure(swap, tensors) + + +def find_nodes_by_inputs_and_outputs(inputs, outputs): + nodes, _ = _build_map(inputs, outputs) + return nodes + + +def clone_graph_nodes(inputs, outputs): + """Clone the `Node` between the inputs and output tensors. + + This function is used to create a new functional model from any intermediate + Keras tensors. The clone of the nodes mimic the behavior of reconstructing + the functional graph network by re-executing all the `__call__()` methods. + The cloned nodes will be appended to the layers. + + Note that a new `keras.Input` will be created for any items in the + `inputs` + + Args: + inputs: A nested structure of `KerasTensor` instances. + outputs: A nested structure of `KerasTensor` instances. + + Returns: + A pair of inputs and outputs, with cloned `KerasTensor` instances. + They can be used to create a new functional model. + """ + nodes_to_clone = find_nodes_by_inputs_and_outputs(inputs, outputs) + cloned_inputs = [] + cloned_outputs = [] + # We not only need to create copies of Nodes (mimic the calls), also need to + # clone Keras tensors to avoid the override of _keras_history attached on + # the Keras tensor. The following dict is used to track any keras tensor we + # cloned The key is the string ID of the original keras tensor, and value is + # the cloned Keras tensor instance. + kt_id_mapping = {} + op_id_mapping = {} + + for kt_input in tree.flatten(inputs): + if is_input_keras_tensor(kt_input): + # For any existing Keras tensor from keras.Input, leave them as is. + cloned_inputs.append(kt_input) + kt_id_mapping[id(kt_input)] = kt_input + else: + # We need to create a new Keras tensor for any intermediate tensor + cloned_input = Input( + batch_shape=kt_input.shape, + dtype=kt_input.dtype, + sparse=kt_input.sparse, + name=f"{kt_input.name}CLONE", + ) + cloned_inputs.append(cloned_input) + kt_id_mapping[id(kt_input)] = cloned_input + op_id_mapping[id(kt_input._keras_history[0])] = ( + cloned_input._keras_history[0] + ) + cloned_inputs = tree.pack_sequence_as(inputs, cloned_inputs) + + for kt_output in tree.flatten(outputs): + cpy = clone_single_keras_tensor(kt_output) + # We reuse the _keras_history here, which contains the old information. + cpy._keras_history = kt_output._keras_history + cloned_outputs.append(cpy) + kt_id_mapping[id(kt_output)] = cpy + cloned_outputs = tree.pack_sequence_as(outputs, cloned_outputs) + + for node in nodes_to_clone: + if id(node.operation) in op_id_mapping: + operation = op_id_mapping[id(node.operation)] + else: + operation = node.operation + # Clone any Keras tensor to avoid override of _keras_history + # Or reuse an existing Keras tensor if it has already been cloned. + output_copy = clone_keras_tensors(node.output_tensors, kt_id_mapping) + if not isinstance(operation, InputLayer): + call_args_copy = clone_keras_tensors( + node.arguments.args, kt_id_mapping + ) + call_kwargs_copy = clone_keras_tensors( + node.arguments.kwargs, kt_id_mapping + ) + else: + call_args_copy = () + call_kwargs_copy = {} + # Creating new nodes based on the existing node information. Node wires + # itself to inbound and outbound layers. The Node constructor actually + # updates this layer's self._inbound_nodes, sets _keras_history on the + # outputs, and adds itself to the `_outbound_nodes` of the layers that + # produced the inputs to this layer call. + Node( + operation, + call_args=call_args_copy, + call_kwargs=call_kwargs_copy, + outputs=output_copy, + ) + return cloned_inputs, cloned_outputs diff --git a/keras/src/models/functional_test.py b/keras/src/models/functional_test.py new file mode 100644 index 000000000000..50adef15cb20 --- /dev/null +++ b/keras/src/models/functional_test.py @@ -0,0 +1,776 @@ +import os +import warnings + +import numpy as np +import pytest +from absl.testing import parameterized + +from keras.src import applications +from keras.src import backend +from keras.src import layers +from keras.src import ops +from keras.src import saving +from keras.src import testing +from keras.src.backend.common.keras_tensor import KerasTensor +from keras.src.dtype_policies import dtype_policy +from keras.src.layers.core.input_layer import Input +from keras.src.layers.input_spec import InputSpec +from keras.src.models import Functional +from keras.src.models import Model +from keras.src.models import Sequential +from keras.src.models.model import model_from_json + + +class FunctionalTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_basic_flow_multi_input(self): + input_a = Input(shape=(3,), batch_size=2, name="input_a") + input_b = Input(shape=(3,), batch_size=2, name="input_b") + x = input_a + input_b + x = layers.Dense(5)(x) + outputs = layers.Dense(4)(x) + model = Functional([input_a, input_b], outputs, name="basic") + model.summary() + + self.assertEqual(model.name, "basic") + self.assertIsInstance(model, Functional) + self.assertIsInstance(model, Model) + + # Eager call + in_val = [np.random.random((2, 3)), np.random.random((2, 3))] + out_val = model(in_val) + self.assertEqual(out_val.shape, (2, 4)) + + # Symbolic call + input_a_2 = Input(shape=(3,), batch_size=2, name="input_a_2") + input_b_2 = Input(shape=(3,), batch_size=2, name="input_b_2") + in_val = [input_a_2, input_b_2] + out_val = model(in_val) + self.assertEqual(out_val.shape, (2, 4)) + + @pytest.mark.requires_trainable_backend + def test_scalar_input(self): + input_a = Input(shape=(3,), batch_size=2, name="input_a") + input_b = Input(shape=(), batch_size=2, name="input_b") + outputs = input_a + input_b[:, None] + model = Functional([input_a, input_b], outputs) + model.summary() + + in_val = [np.zeros((2, 3)), np.ones((2,))] + out_val = model(in_val) + self.assertAllClose(out_val, np.ones((2, 3))) + + @pytest.mark.requires_trainable_backend + def test_mutable_state(self): + inputs = Input(shape=(3,), batch_size=2, name="input") + x = layers.Dense(5)(inputs) + outputs = layers.Dense(5)(x) + model = Functional(inputs, outputs) + # Allow attaching state to a model that isn't directly part of the DAG. + # Most useful for functional subclasses. + model.extra_layer = layers.Dense(5) + + @pytest.mark.requires_trainable_backend + def test_basic_flow_multi_output(self): + inputs = Input(shape=(3,), batch_size=2, name="input") + x = layers.Dense(5)(inputs) + output_a = layers.Dense(4)(x) + output_b = layers.Dense(5)(x) + model = Functional(inputs, [output_a, output_b]) + + # Eager call + in_val = np.random.random((2, 3)) + out_val = model(in_val) + self.assertIsInstance(out_val, list) + self.assertEqual(len(out_val), 2) + self.assertEqual(out_val[0].shape, (2, 4)) + self.assertEqual(out_val[1].shape, (2, 5)) + + # Symbolic call + out_val = model(Input(shape=(3,), batch_size=2)) + self.assertIsInstance(out_val, list) + self.assertEqual(len(out_val), 2) + self.assertEqual(out_val[0].shape, (2, 4)) + self.assertEqual(out_val[1].shape, (2, 5)) + + @pytest.mark.requires_trainable_backend + def test_basic_flow_dict_io(self): + input_a = Input(shape=(3,), batch_size=2, name="a") + input_b = Input(shape=(3,), batch_size=2, name="b") + x = input_a + input_b + x = layers.Dense(5)(x) + outputs = layers.Dense(4)(x) + + with self.assertRaisesRegex( + ValueError, "All `inputs` values must be KerasTensors" + ): + model = Functional({"a": "input_a", "b": input_b}, outputs) + + with self.assertRaisesRegex( + ValueError, "All `outputs` values must be KerasTensors" + ): + model = Functional({"a": input_a, "b": input_b}, "outputs") + + model = Functional({"a": input_a, "b": input_b}, outputs) + + # Eager call + in_val = {"a": np.random.random((2, 3)), "b": np.random.random((2, 3))} + out_val = model(in_val) + self.assertEqual(out_val.shape, (2, 4)) + + # Symbolic call + input_a_2 = Input(shape=(3,), batch_size=2) + input_b_2 = Input(shape=(3,), batch_size=2) + in_val = {"a": input_a_2, "b": input_b_2} + out_val = model(in_val) + self.assertEqual(out_val.shape, (2, 4)) + + def test_basic_flow_as_a_submodel(self): + # Build submodel + submodel_inputs = Input([4]) + submodel_outputs = layers.Flatten()(submodel_inputs) + submodel = Model(submodel_inputs, submodel_outputs) + + inputs = Input((None, 4)) + outputs = layers.TimeDistributed(submodel)(inputs) + model = Model(inputs=inputs, outputs=outputs) + + x = np.random.random((2, 3, 4)) + y = model(x) + self.assertEqual(y.shape, (2, 3, 4)) + + @pytest.mark.requires_trainable_backend + def test_named_input_dict_io(self): + # Single input + input_a = Input(shape=(3,), batch_size=2, name="a") + x = layers.Dense(5)(input_a) + outputs = layers.Dense(4)(x) + model = Functional(input_a, outputs) + + # Eager call + in_val = {"a": np.random.random((2, 3))} + out_val = model(in_val) + self.assertEqual(out_val.shape, (2, 4)) + + # Symbolic call + input_a_2 = Input(shape=(3,), batch_size=2) + in_val = {"a": input_a_2} + out_val = model(in_val) + self.assertEqual(out_val.shape, (2, 4)) + + # ---- + # Two inputs, input is list + input_a = Input(shape=(3,), batch_size=2, name="a") + input_b = Input(shape=(4,), batch_size=2, name="b") + a = layers.Dense(5)(input_a) + b = layers.Dense(5)(input_b) + x = layers.Concatenate()([a, b]) + outputs = layers.Dense(4)(x) + model = Functional([input_a, input_b], outputs) + + # Eager call + in_val = {"a": np.random.random((2, 3)), "b": np.random.random((2, 4))} + out_val = model(in_val) + self.assertEqual(out_val.shape, (2, 4)) + + # Symbolic call + input_a_2 = Input(shape=(3,), batch_size=2) + input_b_2 = Input(shape=(4,), batch_size=2) + in_val = {"a": input_a_2, "b": input_b_2} + out_val = model(in_val) + self.assertEqual(out_val.shape, (2, 4)) + + # ---- + # Two inputs, input is dict + model = Functional({"a": input_a, "b": input_b}, outputs) + + # Eager call + in_val = {"a": np.random.random((2, 3)), "b": np.random.random((2, 4))} + out_val = model(in_val) + self.assertEqual(out_val.shape, (2, 4)) + + # Symbolic call + input_a_2 = Input(shape=(3,), batch_size=2) + input_b_2 = Input(shape=(4,), batch_size=2) + in_val = {"a": input_a_2, "b": input_b_2} + out_val = model(in_val) + self.assertEqual(out_val.shape, (2, 4)) + + # ---- + # Two inputs, input is dict with incorrect names + model = Functional({"c": input_a, "d": input_b}, outputs) + + # Eager call + in_val = {"c": np.random.random((2, 3)), "d": np.random.random((2, 4))} + out_val = model(in_val) + self.assertEqual(out_val.shape, (2, 4)) + + # Symbolic call + input_a_2 = Input(shape=(3,), batch_size=2) + input_b_2 = Input(shape=(4,), batch_size=2) + in_val = {"c": input_a_2, "d": input_b_2} + out_val = model(in_val) + self.assertEqual(out_val.shape, (2, 4)) + + # Now we can't use the input names: + with self.assertRaises(ValueError): + in_val = { + "a": np.random.random((2, 3)), + "b": np.random.random((2, 4)), + } + out_val = model(in_val) + + @pytest.mark.requires_trainable_backend + def test_input_dict_with_extra_field(self): + input_a = Input(shape=(3,), batch_size=2, name="a") + x = input_a * 5 + outputs = x + 2 + + model = Functional({"a": input_a}, outputs) + + with pytest.warns() as record: + # Eager call + in_val = { + "a": np.random.random((2, 3)), + "b": np.random.random((2, 1)), + } + out_val = model(in_val) + self.assertEqual(out_val.shape, (2, 3)) + + # Symbolic call + input_a_2 = Input(shape=(3,), batch_size=2) + input_b_2 = Input(shape=(1,), batch_size=2) + in_val = {"a": input_a_2, "b": input_b_2} + out_val = model(in_val) + self.assertEqual(out_val.shape, (2, 3)) + self.assertLen(record, 1) + self.assertStartsWith( + str(record[0].message), + r"The structure of `inputs` doesn't match the expected structure", + ) + + @parameterized.named_parameters( + ("list", list), + ("tuple", tuple), + ("dict", dict), + ) + def test_restored_multi_output_type(self, out_type): + inputs = Input(shape=(3,), batch_size=2, name="input") + x = layers.Dense(5)(inputs) + output_a = layers.Dense(4)(x) + output_b = layers.Dense(5)(x) + if out_type is dict: + outputs = {"a": output_a, "b": output_b} + else: + outputs = out_type([output_a, output_b]) + model = Functional(inputs, outputs) + model_restored = Functional.from_config(model.get_config()) + + # Eager call + in_val = np.random.random((2, 3)) + out_val = model_restored(in_val) + self.assertIsInstance(out_val, out_type) + + # Symbolic call + out_val = model_restored(Input(shape=(3,), batch_size=2)) + self.assertIsInstance(out_val, out_type) + + def test_restored_nested_input(self): + input_a = Input(shape=(3,), batch_size=2, name="input_a") + x = layers.Dense(5)(input_a) + outputs = layers.Dense(4)(x) + model = Functional([[input_a]], outputs) + + # Serialize and deserialize the model + json_config = model.to_json() + restored_json_config = model_from_json(json_config).to_json() + + # Check that the serialized model is the same as the original + self.assertEqual(json_config, restored_json_config) + + def test_functional_input_shape_and_type(self): + input = layers.Input((1024, 4)) + conv = layers.Conv1D(32, 3)(input) + model = Functional(input, conv) + + self.assertIsInstance(model.input, KerasTensor) + self.assertEqual(model.input_shape, (None, 1024, 4)) + + @pytest.mark.requires_trainable_backend + def test_layer_getters(self): + # Test mixing ops and layers + input_a = Input(shape=(3,), batch_size=2, name="input_a") + input_b = Input(shape=(3,), batch_size=2, name="input_b") + x = input_a + input_b + x = layers.Dense(5, name="dense_1")(x) + outputs = layers.Dense(4, name="dense_2")(x) + model = Functional([input_a, input_b], outputs) + + self.assertEqual(len(model.layers), 4) + self.assertEqual(len(model._operations), 5) + self.assertEqual(model.get_layer(index=0).name, "input_a") + self.assertEqual(model.get_layer(index=1).name, "input_b") + self.assertEqual(model.get_layer(index=2).name, "dense_1") + self.assertEqual(model.get_layer(index=3).name, "dense_2") + self.assertEqual(model.get_layer(name="dense_1").name, "dense_1") + + @pytest.mark.requires_trainable_backend + def test_training_arg(self): + class Canary(layers.Layer): + def call(self, x, training=False): + assert training + return x + + def compute_output_spec(self, x, training=False): + return backend.KerasTensor(x.shape, dtype=x.dtype) + + inputs = Input(shape=(3,), batch_size=2) + outputs = Canary()(inputs) + model = Functional(inputs, outputs) + model(np.random.random((2, 3)), training=True) + + def test_mask_arg(self): + # TODO + pass + + @pytest.mark.requires_trainable_backend + def test_passing_inputs_by_name(self): + input_a = Input(shape=(3,), batch_size=2, name="input_a") + input_b = Input(shape=(3,), batch_size=2, name="input_b") + x = input_a + input_b + x = layers.Dense(5)(x) + outputs = layers.Dense(4)(x) + model = Functional([input_a, input_b], outputs) + + # Eager call + in_val = { + "input_a": np.random.random((2, 3)), + "input_b": np.random.random((2, 3)), + } + out_val = model(in_val) + self.assertEqual(out_val.shape, (2, 4)) + + # Symbolic call + input_a_2 = Input(shape=(3,), batch_size=2, name="input_a_2") + input_b_2 = Input(shape=(3,), batch_size=2, name="input_b_2") + in_val = {"input_a": input_a_2, "input_b": input_b_2} + out_val = model(in_val) + self.assertEqual(out_val.shape, (2, 4)) + + @pytest.mark.requires_trainable_backend + def test_rank_standardization(self): + # Downranking + inputs = Input(shape=(3,), batch_size=2) + outputs = layers.Dense(3)(inputs) + model = Functional(inputs, outputs) + out_val = model(np.random.random((2, 3, 1))) + self.assertEqual(out_val.shape, (2, 3)) + + # Upranking + inputs = Input(shape=(3, 1), batch_size=2) + outputs = layers.Dense(3)(inputs) + model = Functional(inputs, outputs) + out_val = model(np.random.random((2, 3))) + self.assertEqual(out_val.shape, (2, 3, 3)) + + @pytest.mark.requires_trainable_backend + def test_dtype_standardization(self): + float_input = Input(shape=(2,), dtype="float16") + int_input = Input(shape=(2,), dtype="int32") + float_output = float_input + 2 + int_output = int_input + 2 + model = Functional((float_input, int_input), (float_output, int_output)) + float_data, int_data = model((np.ones((2, 2)), np.ones((2, 2)))) + + self.assertEqual(backend.standardize_dtype(float_data.dtype), "float16") + self.assertEqual(backend.standardize_dtype(int_data.dtype), "int32") + + @pytest.mark.requires_trainable_backend + def test_serialization(self): + # Test basic model + inputs = Input(shape=(3,), batch_size=2) + outputs = layers.Dense(3)(inputs) + model = Functional(inputs, outputs, trainable=False) + self.run_class_serialization_test(model) + + # Test multi-io model + input_a = Input(shape=(3,), batch_size=2, name="input_a") + input_b = Input(shape=(3,), batch_size=2, name="input_b") + xa = layers.Dense(5, name="middle_a")(input_a) + xb = layers.Dense(5, name="middle_b")(input_b) + output_a = layers.Dense(4, name="output_a")(xa) + output_b = layers.Dense(4, name="output_b")(xb) + model = Functional( + [input_a, input_b], [output_a, output_b], name="func" + ) + self.run_class_serialization_test(model) + + # Test model that includes floating ops + input_a = Input(shape=(3,), batch_size=2, name="input_a") + input_b = Input(shape=(3,), batch_size=2, name="input_b") + x = input_a + input_b + x = layers.Dense(5, name="middle")(x) + output_a = layers.Dense(4, name="output_a")(x) + output_b = layers.Dense(4, name="output_b")(x) + model = Functional( + [input_a, input_b], [output_a, output_b], name="func" + ) + self.run_class_serialization_test(model) + + # Test model with dict i/o + input_a = Input(shape=(3,), batch_size=2, name="a") + input_b = Input(shape=(3,), batch_size=2, name="b") + x = input_a + input_b + x = layers.Dense(5)(x) + outputs = layers.Dense(4)(x) + model = Functional({"a": input_a, "b": input_b}, outputs) + self.run_class_serialization_test(model) + + @pytest.mark.requires_trainable_backend + def test_bad_input_spec(self): + # Single input + inputs = Input(shape=(4,)) + outputs = layers.Dense(2)(inputs) + model = Functional(inputs, outputs) + with self.assertRaisesRegex( + ValueError, r"expected shape=\(None, 4\), found shape=\(2, 3\)" + ): + model(np.zeros((2, 3))) + with self.assertRaisesRegex(ValueError, "expects 1 input"): + model([np.zeros((2, 4)), np.zeros((2, 4))]) + + # List input + input_a = Input(shape=(4,), name="a") + input_b = Input(shape=(4,), name="b") + x = input_a + input_b + outputs = layers.Dense(2)(x) + model = Functional([input_a, input_b], outputs) + with self.assertRaisesRegex(ValueError, "expects 2 input"): + model(np.zeros((2, 3))) + with self.assertRaisesRegex( + ValueError, r"expected shape=\(None, 4\), found shape=\(2, 3\)" + ): + model([np.zeros((2, 3)), np.zeros((2, 4))]) + + # Dict input + model = Functional({"a": input_a, "b": input_b}, outputs) + with self.assertRaisesRegex(ValueError, "expects 2 input"): + model(np.zeros((2, 3))) + with self.assertRaisesRegex( + ValueError, r"expected shape=\(None, 4\), found shape=\(2, 3\)" + ): + model({"a": np.zeros((2, 3)), "b": np.zeros((2, 4))}) + + @pytest.mark.requires_trainable_backend + def test_manual_input_spec(self): + inputs = Input(shape=(None, 3)) + outputs = layers.Dense(2)(inputs) + model = Functional(inputs, outputs) + model.input_spec = InputSpec(shape=(None, 4, 3)) + with self.assertRaisesRegex( + ValueError, + r"expected shape=\(None, 4, 3\), found shape=\(2, 3, 3\)", + ): + model(np.zeros((2, 3, 3))) + model(np.zeros((2, 4, 3))) + + def test_functional_slicing(self): + inputs = Input(shape=(None, 2), name="input") + x1 = layers.Dense(3, name="dense1")(inputs) + x2 = layers.Dense(4, name="dense2")(x1) + outputs = layers.Dense(5, name="dense3")(x2) + + full_model = Functional(inputs, outputs, name="full_model") + self.assertLen(full_model.layers, 4) + + partial_model_1 = Functional(x2, outputs, name="partial1") + self.assertLen(partial_model_1.layers, 2) # input_layer, dense3 + self.assertIsInstance(partial_model_1.layers[0], layers.InputLayer) + self.assertEqual(partial_model_1.layers[1].name, "dense3") + + partial_model_2 = Functional(x1, x2, name="partial2") + self.assertLen(partial_model_2.layers, 2) # input_layer, dense2 + self.assertIsInstance(partial_model_2.layers[0], layers.InputLayer) + self.assertEqual(partial_model_2.layers[1].name, "dense2") + + partial_model_3 = Functional( + full_model.get_layer("dense2").input, outputs, name="partial3" + ) + self.assertLen(partial_model_3.layers, 3) # input_layer, dense2, dense3 + self.assertIsInstance(partial_model_3.layers[0], layers.InputLayer) + self.assertEqual(partial_model_3.layers[1].name, "dense2") + self.assertEqual(partial_model_3.layers[2].name, "dense3") + + partial_model_4 = Functional( + full_model.get_layer("dense1").input, + full_model.get_layer("dense2").output, + name="partial4", + ) + self.assertLen(partial_model_4.layers, 3) # input_layer, dense1, dense2 + self.assertIsInstance(partial_model_4.layers[0], layers.InputLayer) + self.assertEqual(partial_model_4.layers[1].name, "dense1") + self.assertEqual(partial_model_4.layers[2].name, "dense2") + + def test_deeply_nested_model(self): + i1, i2, i3 = Input((1,)), Input((2,)), Input((3,)) + o1, o2, o3 = ( + layers.Dense(1)(i1), + layers.Dense(2)(i2), + layers.Dense(3)(i3), + ) + model = Model( + {"1": i1, "others": {"2": i2, "3": i3}}, + {"1": o1, "others": {"2": o2, "3": o3}}, + ) + out_eager = model( + { + "1": np.ones((2, 1)), + "others": {"2": np.ones((2, 2)), "3": np.ones((2, 3))}, + } + ) + out_symbolic = model( + { + "1": Input((1,), batch_size=2), + "others": { + "2": Input((2,), batch_size=2), + "3": Input((3,), batch_size=2), + }, + } + ) + for out in [out_eager, out_symbolic]: + self.assertIsInstance(out, dict) + self.assertEqual(set(out.keys()), {"1", "others"}) + self.assertEqual(out["1"].shape, (2, 1)) + self.assertIsInstance(out["others"], dict) + self.assertEqual(set(out["others"].keys()), {"2", "3"}) + self.assertEqual(out["others"]["2"].shape, (2, 2)) + self.assertEqual(out["others"]["3"].shape, (2, 3)) + + # Test serialization boundaries + temp_filepath = os.path.join(self.get_temp_dir(), "deeply_nested.keras") + model.save(temp_filepath) + loaded_model = saving.load_model(temp_filepath) + new_out_eager = loaded_model( + { + "1": np.ones((2, 1)), + "others": {"2": np.ones((2, 2)), "3": np.ones((2, 3))}, + } + ) + self.assertAllClose(out_eager["1"], new_out_eager["1"]) + self.assertAllClose( + out_eager["others"]["2"], new_out_eager["others"]["2"] + ) + self.assertAllClose( + out_eager["others"]["3"], new_out_eager["others"]["3"] + ) + + def test_optional_inputs(self): + class OptionalInputLayer(layers.Layer): + def call(self, x, y=None): + if y is not None: + return x + y + return x + + def compute_output_shape(self, x_shape): + return x_shape + + i1 = Input((2,)) + i2 = Input((2,), optional=True) + outputs = OptionalInputLayer()(i1, i2) + model = Model([i1, i2], outputs) + + # Eager test + out = model([np.ones((2, 2)), None]) + self.assertAllClose(out, np.ones((2, 2))) + # Note: it's not intended to work in symbolic mode (yet). + + def test_optional_dict_inputs(self): + class OptionalInputLayer(layers.Layer): + def call(self, x, y=None): + if y is not None: + return x + y + return x + + def compute_output_shape(self, x_shape): + return x_shape + + i1 = Input((2,), name="input1") + i2 = Input((2,), name="input2", optional=True) + outputs = OptionalInputLayer()(i1, i2) + model = Model({"input1": i1, "input2": i2}, outputs) + + # Eager test + out = model({"input1": np.ones((2, 2)), "input2": None}) + self.assertAllClose(out, np.ones((2, 2))) + # Note: it's not intended to work in symbolic mode (yet). + + def test_warning_for_mismatched_inputs_structure(self): + def is_input_warning(w): + return str(w.message).startswith( + "The structure of `inputs` doesn't match the expected structure" + ) + + i1 = Input((2,)) + i2 = Input((2,)) + outputs = layers.Add()([i1, i2]) + + model = Model({"i1": i1, "i2": i2}, outputs) + with pytest.warns() as warning_logs: + model.predict([np.ones((2, 2)), np.zeros((2, 2))], verbose=0) + self.assertLen(list(filter(is_input_warning, warning_logs)), 1) + # No warning for mismatched tuples and lists. + model = Model([i1, i2], outputs) + with warnings.catch_warnings(record=True) as warning_logs: + model.predict((np.ones((2, 2)), np.zeros((2, 2))), verbose=0) + self.assertLen(list(filter(is_input_warning, warning_logs)), 0) + + def test_for_functional_in_sequential(self): + # Test for a v3.4.1 regression. + if backend.image_data_format() == "channels_first": + image_size = (3, 100, 100) + else: + image_size = (100, 100, 3) + base_model = applications.mobilenet.MobileNet( + include_top=False, weights=None + ) + model = Sequential() + model.add(layers.Input(shape=image_size)) + model.add(base_model) + model.add(layers.GlobalAveragePooling2D()) + model.add(layers.Dense(7, activation="softmax")) + config = model.get_config() + model = Sequential.from_config(config) + + def test_add_loss(self): + # TODO + pass + + def test_layers_setter(self): + inputs = Input(shape=(3,), batch_size=2, name="input") + outputs = layers.Dense(5)(inputs) + model = Functional(inputs, outputs) + with self.assertRaisesRegex( + AttributeError, "`Model.layers` attribute is reserved" + ): + model.layers = [layers.Dense(4)] + + @pytest.mark.requires_trainable_backend + def test_dict_input_to_list_model(self): + vocabulary_size = 100 + num_tags = 10 + num_departments = 3 + num_samples = 128 + + title = layers.Input(shape=(vocabulary_size,), name="title") + text_body = layers.Input(shape=(vocabulary_size,), name="text_body") + tags = layers.Input(shape=(num_tags,), name="tags") + features = layers.Concatenate()([title, text_body, tags]) + features = layers.Dense(64, activation="relu")(features) + priority = layers.Dense(1, activation="sigmoid", name="priority")( + features + ) + department = layers.Dense( + num_departments, activation="softmax", name="department" + )(features) + model = Functional( + inputs=[title, text_body, tags], outputs=[priority, department] + ) + + title_data = np.random.randint( + 0, 2, size=(num_samples, vocabulary_size) + ) + text_body_data = np.random.randint( + 0, 2, size=(num_samples, vocabulary_size) + ) + tags_data = np.random.randint(0, 2, size=(num_samples, num_tags)) + priority_data = np.random.random(size=(num_samples, 1)) + department_data = np.random.randint( + 0, 2, size=(num_samples, num_departments) + ) + + # List style fit + model.compile( + optimizer="adam", + loss=["mean_squared_error", "categorical_crossentropy"], + metrics=[["mean_absolute_error"], ["accuracy"]], + ) + model.fit( + [title_data, text_body_data, tags_data], + [priority_data, department_data], + epochs=1, + ) + model.evaluate( + [title_data, text_body_data, tags_data], + [priority_data, department_data], + ) + priority_preds, department_preds = model.predict( + [title_data, text_body_data, tags_data] + ) + + # Dict style fit + model.compile( + optimizer="adam", + loss={ + "priority": "mean_squared_error", + "department": "categorical_crossentropy", + }, + metrics={ + "priority": ["mean_absolute_error"], + "department": ["accuracy"], + }, + ) + model.fit( + { + "title": title_data, + "text_body": text_body_data, + "tags": tags_data, + }, + {"priority": priority_data, "department": department_data}, + epochs=1, + ) + model.evaluate( + { + "title": title_data, + "text_body": text_body_data, + "tags": tags_data, + }, + {"priority": priority_data, "department": department_data}, + ) + priority_preds, department_preds = model.predict( + { + "title": title_data, + "text_body": text_body_data, + "tags": tags_data, + } + ) + + def test_list_input_with_dict_build(self): + x1 = Input((10,), name="IT") + x2 = Input((10,), name="IS") + y = layers.subtract([x1, x2]) + model = Model(inputs={"IT": x1, "IS": x2}, outputs=y) + x1 = ops.ones((1, 10)) + x2 = ops.zeros((1, 10)) + # Works + _ = model({"IT": x1, "IS": x2}) + with self.assertRaisesRegex( + ValueError, + "The structure of `inputs` doesn't match the expected structure", + ): + model([x1, x2]) + + def test_functional_with_dtype_policy(self): + original_dtype_policy = dtype_policy.dtype_policy() + try: + dtype_policy.set_dtype_policy("mixed_float16") + + inputs = Input((10,), name="input") + outputs = layers.Dense(5)(inputs) + model = Model(inputs=inputs, outputs=outputs) + + # Verify that no cast node appears in the graph. + self.assertLen(model.operations, 2) + self.assertIsInstance(model.operations[0], layers.InputLayer) + self.assertIsInstance(model.operations[1], layers.Dense) + finally: + dtype_policy.set_dtype_policy(original_dtype_policy) diff --git a/keras/src/models/model.py b/keras/src/models/model.py new file mode 100644 index 000000000000..e8fa6415b103 --- /dev/null +++ b/keras/src/models/model.py @@ -0,0 +1,951 @@ +import inspect +import json +import typing +import warnings + +from keras.src import backend +from keras.src import utils +from keras.src.api_export import keras_export +from keras.src.layers.layer import Layer +from keras.src.models.variable_mapping import map_saveable_variables +from keras.src.quantizers.gptq_config import GPTQConfig +from keras.src.quantizers.gptq_core import gptq_quantize +from keras.src.saving import saving_api +from keras.src.trainers import trainer as base_trainer +from keras.src.utils import summary_utils +from keras.src.utils import traceback_utils + +if backend.backend() == "tensorflow": + from keras.src.backend.tensorflow.trainer import ( + TensorFlowTrainer as Trainer, + ) +elif backend.backend() == "jax": + from keras.src.backend.jax.trainer import JAXTrainer as Trainer +elif backend.backend() == "torch": + from keras.src.backend.torch.trainer import TorchTrainer as Trainer +elif backend.backend() == "numpy": + from keras.src.backend.numpy.trainer import NumpyTrainer as Trainer +elif backend.backend() == "openvino": + from keras.src.backend.openvino.trainer import OpenVINOTrainer as Trainer +else: + raise RuntimeError( + f"Backend '{backend.backend()}' must implement the Trainer class." + ) + + +@keras_export(["keras.Model", "keras.models.Model"]) +class Model(Trainer, base_trainer.Trainer, Layer): + """A model grouping layers into an object with training/inference features. + + There are three ways to instantiate a `Model`: + + ## With the "Functional API" + + You start from `Input`, + you chain layer calls to specify the model's forward pass, + and finally, you create your model from inputs and outputs: + + ```python + inputs = keras.Input(shape=(37,)) + x = keras.layers.Dense(32, activation="relu")(inputs) + outputs = keras.layers.Dense(5, activation="softmax")(x) + model = keras.Model(inputs=inputs, outputs=outputs) + ``` + + Note: Only dicts, lists, and tuples of input tensors are supported. Nested + inputs are not supported (e.g. lists of list or dicts of dict). + + A new Functional API model can also be created by using the + intermediate tensors. This enables you to quickly extract sub-components + of the model. + + Example: + + ```python + inputs = keras.Input(shape=(None, None, 3)) + processed = keras.layers.RandomCrop(width=128, height=128)(inputs) + conv = keras.layers.Conv2D(filters=32, kernel_size=3)(processed) + pooling = keras.layers.GlobalAveragePooling2D()(conv) + feature = keras.layers.Dense(10)(pooling) + + full_model = keras.Model(inputs, feature) + backbone = keras.Model(processed, conv) + activations = keras.Model(conv, feature) + ``` + + Note that the `backbone` and `activations` models are not + created with `keras.Input` objects, but with the tensors that originate + from `keras.Input` objects. Under the hood, the layers and weights will + be shared across these models, so that user can train the `full_model`, and + use `backbone` or `activations` to do feature extraction. + The inputs and outputs of the model can be nested structures of tensors as + well, and the created models are standard Functional API models that support + all the existing APIs. + + ## By subclassing the `Model` class + + In that case, you should define your + layers in `__init__()` and you should implement the model's forward pass + in `call()`. + + ```python + class MyModel(keras.Model): + def __init__(self): + super().__init__() + self.dense1 = keras.layers.Dense(32, activation="relu") + self.dense2 = keras.layers.Dense(5, activation="softmax") + + def call(self, inputs): + x = self.dense1(inputs) + return self.dense2(x) + + model = MyModel() + ``` + + If you subclass `Model`, you can optionally have + a `training` argument (boolean) in `call()`, which you can use to specify + a different behavior in training and inference: + + ```python + class MyModel(keras.Model): + def __init__(self): + super().__init__() + self.dense1 = keras.layers.Dense(32, activation="relu") + self.dense2 = keras.layers.Dense(5, activation="softmax") + self.dropout = keras.layers.Dropout(0.5) + + def call(self, inputs, training=False): + x = self.dense1(inputs) + x = self.dropout(x, training=training) + return self.dense2(x) + + model = MyModel() + ``` + + Once the model is created, you can config the model with losses and metrics + with `model.compile()`, train the model with `model.fit()`, or use the model + to do prediction with `model.predict()`. + + ## With the `Sequential` class + + In addition, `keras.Sequential` is a special case of model where + the model is purely a stack of single-input, single-output layers. + + ```python + model = keras.Sequential([ + keras.Input(shape=(None, None, 3)), + keras.layers.Conv2D(filters=32, kernel_size=3), + ]) + ``` + """ + + def __new__(cls, *args, **kwargs): + # Signature detection for usage of `Model` as a `Functional` + if functional_init_arguments(args, kwargs) and cls == Model: + from keras.src.models.functional import Functional + + return Functional.__new__(Functional, *args, **kwargs) + return typing.cast(cls, super().__new__(cls)) + + def __init__(self, *args, **kwargs): + Trainer.__init__(self) + from keras.src.models import functional + + # Signature detection for usage of a `Model` subclass + # as a `Functional` subclass + if functional_init_arguments(args, kwargs): + inject_functional_model_class(self.__class__) + functional.Functional.__init__(self, *args, **kwargs) + else: + Layer.__init__(self, *args, **kwargs) + + def call(self, *args, **kwargs): + raise NotImplementedError( + f"Model {self.__class__.__name__} does not have a `call()` " + "method implemented." + ) + + @property + def layers(self): + return list(self._flatten_layers(include_self=False, recursive=False)) + + @layers.setter + def layers(self, _): + raise AttributeError( + "`Model.layers` attribute is reserved and should not be used. " + "Please use another name." + ) + + @traceback_utils.filter_traceback + def get_layer(self, name=None, index=None): + """Retrieves a layer based on either its name (unique) or index. + + If `name` and `index` are both provided, `index` will take precedence. + Indices are based on order of horizontal graph traversal (bottom-up). + + Args: + name: String, name of layer. + index: Integer, index of layer. + + Returns: + A layer instance. + """ + if index is not None and name is not None: + raise ValueError( + "Provide only a layer name or a layer index. Received: " + f"index={index}, name={name}." + ) + if index is not None: + if len(self.layers) <= index: + raise ValueError( + f"Was asked to retrieve layer at index {index}" + f" but model only has {len(self.layers)}" + " layers." + ) + else: + return self.layers[index] + + if name is not None: + for layer in self.layers: + if layer.name == name: + return layer + raise ValueError( + f"No such layer: {name}. Existing layers are: " + f"{list(layer.name for layer in self.layers)}." + ) + raise ValueError( + "Provide either a layer name or layer index at `get_layer`." + ) + + @traceback_utils.filter_traceback + def summary( + self, + line_length=None, + positions=None, + print_fn=None, + expand_nested=False, + show_trainable=False, + layer_range=None, + ): + """Prints a string summary of the network. + + Args: + line_length: Total length of printed lines + (e.g. set this to adapt the display to different + terminal window sizes). + positions: Relative or absolute positions of log elements + in each line. If not provided, becomes + `[0.3, 0.6, 0.70, 1.]`. Defaults to `None`. + print_fn: Print function to use. By default, prints to `stdout`. + If `stdout` doesn't work in your environment, change to `print`. + It will be called on each line of the summary. + You can set it to a custom function + in order to capture the string summary. + expand_nested: Whether to expand the nested models. + Defaults to `False`. + show_trainable: Whether to show if a layer is trainable. + Defaults to `False`. + layer_range: a list or tuple of 2 strings, + which is the starting layer name and ending layer name + (both inclusive) indicating the range of layers to be printed + in summary. It also accepts regex patterns instead of exact + names. In this case, the start predicate will be + the first element that matches `layer_range[0]` + and the end predicate will be the last element + that matches `layer_range[1]`. + By default `None` considers all layers of the model. + + Raises: + ValueError: if `summary()` is called before the model is built. + """ + summary_utils.print_summary( + self, + line_length=line_length, + positions=positions, + print_fn=print_fn, + expand_nested=expand_nested, + show_trainable=show_trainable, + layer_range=layer_range, + ) + + @traceback_utils.filter_traceback + def save(self, filepath, overwrite=True, zipped=None, **kwargs): + """Saves a model as a `.keras` file. + + Note that `model.save()` is an alias for `keras.saving.save_model()`. + + The saved `.keras` file contains: + + - The model's configuration (architecture) + - The model's weights + - The model's optimizer's state (if any) + + Thus models can be reinstantiated in the exact same state. + + Args: + filepath: `str` or `pathlib.Path` object. + The path where to save the model. Must end in `.keras` + (unless saving the model as an unzipped directory + via `zipped=False`). + overwrite: Whether we should overwrite any existing model at + the target location, or instead ask the user via + an interactive prompt. + zipped: Whether to save the model as a zipped `.keras` + archive (default when saving locally), or as an + unzipped directory (default when saving on the + Hugging Face Hub). + + Example: + + ```python + model = keras.Sequential( + [ + keras.layers.Dense(5, input_shape=(3,)), + keras.layers.Softmax(), + ], + ) + model.save("model.keras") + loaded_model = keras.saving.load_model("model.keras") + x = keras.random.uniform((10, 3)) + assert np.allclose(model.predict(x), loaded_model.predict(x)) + ``` + """ + return saving_api.save_model( + self, filepath, overwrite=overwrite, zipped=zipped, **kwargs + ) + + @traceback_utils.filter_traceback + def save_weights(self, filepath, overwrite=True, max_shard_size=None): + """Saves all weights to a single file or sharded files. + + By default, the weights will be saved in a single `.weights.h5` file. + If sharding is enabled (`max_shard_size` is not `None`), the weights + will be saved in multiple files, each with a size at most + `max_shard_size` (in GB). Additionally, a configuration file + `.weights.json` will contain the metadata for the sharded files. + + The saved sharded files contain: + + - `*.weights.json`: The configuration file containing 'metadata' and + 'weight_map'. + - `*_xxxxxx.weights.h5`: The sharded files containing only the + weights. + + Args: + filepath: `str` or `pathlib.Path` object. Path where the weights + will be saved. When sharding, the filepath must end in + `.weights.json`. If `.weights.h5` is provided, it will be + overridden. + overwrite: Whether to overwrite any existing weights at the target + location or instead ask the user via an interactive prompt. + max_shard_size: `int` or `float`. Maximum size in GB for each + sharded file. If `None`, no sharding will be done. Defaults to + `None`. + + Example: + + ```python + # Instantiate a EfficientNetV2L model with about 454MB of weights. + model = keras.applications.EfficientNetV2L(weights=None) + + # Save the weights in a single file. + model.save_weights("model.weights.h5") + + # Save the weights in sharded files. Use `max_shard_size=0.25` means + # each sharded file will be at most ~250MB. + model.save_weights("model.weights.json", max_shard_size=0.25) + + # Load the weights in a new model with the same architecture. + loaded_model = keras.applications.EfficientNetV2L(weights=None) + loaded_model.load_weights("model.weights.h5") + x = keras.random.uniform((1, 480, 480, 3)) + assert np.allclose(model.predict(x), loaded_model.predict(x)) + + # Load the sharded weights in a new model with the same architecture. + loaded_model = keras.applications.EfficientNetV2L(weights=None) + loaded_model.load_weights("model.weights.json") + x = keras.random.uniform((1, 480, 480, 3)) + assert np.allclose(model.predict(x), loaded_model.predict(x)) + ``` + """ + return saving_api.save_weights( + self, filepath, overwrite=overwrite, max_shard_size=max_shard_size + ) + + @traceback_utils.filter_traceback + def load_weights(self, filepath, skip_mismatch=False, **kwargs): + """Load the weights from a single file or sharded files. + + Weights are loaded based on the network's topology. This means the + architecture should be the same as when the weights were saved. Note + that layers that don't have weights are not taken into account in the + topological ordering, so adding or removing layers is fine as long as + they don't have weights. + + **Partial weight loading** + + If you have modified your model, for instance by adding a new layer + (with weights) or by changing the shape of the weights of a layer, you + can choose to ignore errors and continue loading by setting + `skip_mismatch=True`. In this case any layer with mismatching weights + will be skipped. A warning will be displayed for each skipped layer. + + **Sharding** + + When loading sharded weights, it is important to specify `filepath` that + ends with `*.weights.json` which is used as the configuration file. + Additionally, the sharded files `*_xxxxx.weights.h5` must be in the same + directory as the configuration file. + + Args: + filepath: `str` or `pathlib.Path` object. Path where the weights + will be saved. When sharding, the filepath must end in + `.weights.json`. + skip_mismatch: Boolean, whether to skip loading of layers where + there is a mismatch in the number of weights, or a mismatch in + the shape of the weights. + + Example: + + ```python + # Load the weights in a single file. + model.load_weights("model.weights.h5") + + # Load the weights in sharded files. + model.load_weights("model.weights.json") + ``` + """ + saving_api.load_weights( + self, + filepath, + skip_mismatch=skip_mismatch, + **kwargs, + ) + + def quantize(self, mode, config=None, **kwargs): + """Quantize the weights of the model. + + Note that the model must be built first before calling this method. + `quantize` will recursively call `quantize(mode)` in all layers and + will be skipped if the layer doesn't implement the function. + + Args: + mode: The mode of the quantization. Only 'int8' is supported at this + time. + """ + from keras.src.dtype_policies import QUANTIZATION_MODES + + # Validate inputs. + type_check = kwargs.pop("type_check", True) + if kwargs: + raise ValueError( + "Unrecognized keyword arguments " + f"passed to {self.__class__.__name__}: {kwargs}" + ) + + if mode not in QUANTIZATION_MODES: + raise ValueError( + "Invalid quantization mode. " + f"Expected one of {QUANTIZATION_MODES}. Received: mode={mode}" + ) + + if mode == "gptq": + if not isinstance(config, GPTQConfig): + raise ValueError( + "Mode 'gptq' requires a valid `config` argument of type " + f"`GPTQConfig`. Received: {type(config)}" + ) + elif config is not None: + # All other modes must not receive a config + raise ValueError( + f"The `config` argument is only supported for 'gptq' mode, " + f"but received mode='{mode}' and a non-None config." + ) + + graph_modified = False + for layer in self._flatten_layers(): + if len(list(layer._flatten_layers())) == 1: + try: + layer.quantize(mode, type_check=type_check, config=config) + graph_modified = True + except NotImplementedError as e: + warnings.warn(str(e)) + except AttributeError: + pass + + if mode == "gptq": + gptq_quantize(self, config) + + # If any layer was changed, we must rebuild the execution functions. + if graph_modified: + self.train_function = None + self.test_function = None + self.predict_function = None + self._post_quantize(mode, **kwargs) + + def _post_quantize(self, mode, **kwargs): + if backend.backend() == "torch": + # We need to manually retrack `torch_params`. + # The reason is that after quantization, the removed variables are + # still referenced by `torch_params` and cannot be gc. + for layer in self._flatten_layers(): + layer._track_variables() + + def build_from_config(self, config): + if not config: + return + status = False + if "input_shape" in config: + # Case: all inputs are in the first arg (possibly nested). + if utils.is_default(self.build): + status = self._build_by_run_for_single_pos_arg( + config["input_shape"] + ) + else: + try: + self.build(config["input_shape"]) + status = True + except: + pass + self._build_shapes_dict = config + + elif "shapes_dict" in config: + # Case: inputs were recorded as multiple keyword arguments. + if utils.is_default(self.build): + status = self._build_by_run_for_kwargs(config["shapes_dict"]) + else: + try: + self.build(**config["shapes_dict"]) + status = True + except: + pass + self._build_shapes_dict = config["shapes_dict"] + + if not status: + warnings.warn( + f"Model '{self.name}' had a build config, but the model " + "cannot be built automatically in " + "`build_from_config(config)`. " + "You should implement " + "`def build_from_config(self, config)`, " + "and you might also want to implement the method " + " that generates the config at saving time, " + "`def get_build_config(self)`. " + "The method `build_from_config()` is meant to " + "create the state of the model (i.e. its variables) " + "upon deserialization.", + stacklevel=2, + ) + + def to_json(self, **kwargs): + """Returns a JSON string containing the network configuration. + + To load a network from a JSON save file, use + `keras.models.model_from_json(json_string, custom_objects={...})`. + + Args: + **kwargs: Additional keyword arguments to be passed to + `json.dumps()`. + + Returns: + A JSON string. + """ + from keras.src.saving import serialization_lib + + model_config = serialization_lib.serialize_keras_object(self) + return json.dumps(model_config, **kwargs) + + def export( + self, + filepath, + format="tf_saved_model", + verbose=None, + input_signature=None, + **kwargs, + ): + """Export the model as an artifact for inference. + + Args: + filepath: `str` or `pathlib.Path` object. The path to save the + artifact. + format: `str`. The export format. Supported values: + `"tf_saved_model"` and `"onnx"`. Defaults to + `"tf_saved_model"`. + verbose: `bool`. Whether to print a message during export. Defaults + to `None`, which uses the default value set by different + backends and formats. + input_signature: Optional. Specifies the shape and dtype of the + model inputs. Can be a structure of `keras.InputSpec`, + `tf.TensorSpec`, `backend.KerasTensor`, or backend tensor. If + not provided, it will be automatically computed. Defaults to + `None`. + **kwargs: Additional keyword arguments. + - `is_static`: Optional `bool`. Specific to the JAX backend and + `format="tf_saved_model"`. Indicates whether `fn` is static. + Set to `False` if `fn` involves state updates (e.g., RNG + seeds and counters). + - `jax2tf_kwargs`: Optional `dict`. Specific to the JAX backend + and `format="tf_saved_model"`. Arguments for + `jax2tf.convert`. See the documentation for + [`jax2tf.convert`]( + https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md). + If `native_serialization` and `polymorphic_shapes` are not + provided, they will be automatically computed. + - `opset_version`: Optional `int`. Specific to `format="onnx"`. + An integer value that specifies the ONNX opset version. + + **Note:** This feature is currently supported only with TensorFlow, JAX + and Torch backends. + + **Note:** Be aware that the exported artifact may contain information + from the local file system when using `format="onnx"`, `verbose=True` + and Torch backend. + + Examples: + + Here's how to export a TensorFlow SavedModel for inference. + + ```python + # Export the model as a TensorFlow SavedModel artifact + model.export("path/to/location", format="tf_saved_model") + + # Load the artifact in a different process/environment + reloaded_artifact = tf.saved_model.load("path/to/location") + predictions = reloaded_artifact.serve(input_data) + ``` + + Here's how to export an ONNX for inference. + + ```python + # Export the model as a ONNX artifact + model.export("path/to/location", format="onnx") + + # Load the artifact in a different process/environment + ort_session = onnxruntime.InferenceSession("path/to/location") + ort_inputs = { + k.name: v for k, v in zip(ort_session.get_inputs(), input_data) + } + predictions = ort_session.run(None, ort_inputs) + ``` + """ + from keras.src.export import export_onnx + from keras.src.export import export_openvino + from keras.src.export import export_saved_model + + available_formats = ("tf_saved_model", "onnx", "openvino") + if format not in available_formats: + raise ValueError( + f"Unrecognized format={format}. Supported formats are: " + f"{list(available_formats)}." + ) + + if format == "tf_saved_model": + export_saved_model( + self, + filepath, + verbose, + input_signature=input_signature, + **kwargs, + ) + elif format == "onnx": + export_onnx( + self, + filepath, + verbose, + input_signature=input_signature, + **kwargs, + ) + elif format == "openvino": + export_openvino( + self, + filepath, + verbose, + input_signature=input_signature, + **kwargs, + ) + + @classmethod + def from_config(cls, config, custom_objects=None): + from keras.src.models.functional import Functional + + functional_config_keys = [ + "name", + "layers", + "input_layers", + "output_layers", + ] + is_functional_config = all( + key in config for key in functional_config_keys + ) + argspec = inspect.getfullargspec(cls.__init__) + functional_init_args = inspect.getfullargspec(Functional.__init__).args[ + 1: + ] + revivable_as_functional = ( + cls in {Functional, Model} + or argspec.args[1:] == functional_init_args + or (argspec.varargs == "args" and argspec.varkw == "kwargs") + ) + if is_functional_config and revivable_as_functional: + # Revive Functional model + # (but not Functional subclasses with a custom __init__) + from keras.src.models.functional import functional_from_config + + return functional_from_config( + cls, config, custom_objects=custom_objects + ) + + # Either the model has a custom __init__, or the config + # does not contain all the information necessary to + # revive a Functional model. This happens when the user creates + # subclassed models where `get_config()` is returning + # insufficient information to be considered a Functional model. + # In this case, we fall back to provide all config into the + # constructor of the class. + try: + return cls(**config) + except TypeError as e: + raise TypeError( + "Unable to revive model from config. When overriding " + "the `get_config()` method, make sure that the " + "returned config contains all items used as arguments " + f"in the constructor to {cls}, " + "which is the default behavior. " + "You can override this default behavior by defining a " + "`from_config(cls, config)` class method to specify " + "how to create an " + f"instance of {cls.__name__} from its config.\n\n" + f"Received config={config}\n\n" + f"Error encountered during deserialization: {e}" + ) + + def _get_variable_map(self): + store = {} + map_saveable_variables(self, store=store, visited_saveables=set()) + return store + + def get_state_tree(self, value_format="backend_tensor"): + """Retrieves tree-like structure of model variables. + + This method allows retrieval of different model variables (trainable, + non-trainable, optimizer, and metrics). The variables are returned in a + nested dictionary format, where the keys correspond to the variable + names and the values are the nested representations of the variables. + + Returns: + dict: A dictionary containing the nested representations of the + requested variables. The keys are the variable names, and the + values are the corresponding nested dictionaries. + value_format: One of `"backend_tensor"`, `"numpy_array"`. + The kind of array to return as the leaves of the nested + state tree. + + Example: + + ```python + model = keras.Sequential([ + keras.Input(shape=(1,), name="my_input"), + keras.layers.Dense(1, activation="sigmoid", name="my_dense"), + ], name="my_sequential") + model.compile(optimizer="adam", loss="mse", metrics=["mae"]) + model.fit(np.array([[1.0]]), np.array([[1.0]])) + state_tree = model.get_state_tree() + ``` + + The `state_tree` dictionary returned looks like: + + ``` + { + 'metrics_variables': { + 'loss': { + 'count': ..., + 'total': ..., + }, + 'mean_absolute_error': { + 'count': ..., + 'total': ..., + } + }, + 'trainable_variables': { + 'my_sequential': { + 'my_dense': { + 'bias': ..., + 'kernel': ..., + } + } + }, + 'non_trainable_variables': {}, + 'optimizer_variables': { + 'adam': { + 'iteration': ..., + 'learning_rate': ..., + 'my_sequential_my_dense_bias_momentum': ..., + 'my_sequential_my_dense_bias_velocity': ..., + 'my_sequential_my_dense_kernel_momentum': ..., + 'my_sequential_my_dense_kernel_velocity': ..., + } + } + } + } + ``` + """ + variables = {} + variables["trainable_variables"] = self._create_nested_dict( + self.trainable_variables, value_format + ) + variables["non_trainable_variables"] = self._create_nested_dict( + self.non_trainable_variables, value_format + ) + variables["optimizer_variables"] = self._create_nested_dict( + self.optimizer.variables, value_format + ) + variables["metrics_variables"] = self._create_nested_dict( + self.metrics_variables, value_format + ) + return variables + + def _create_nested_dict(self, variables, value_format): + flat_dict = {} + for v in variables: + if v.path in flat_dict: + raise ValueError( + "The following variable path is found twice in the model: " + f"'{v.path}'. `get_state_tree()` can only be called when " + "all variable paths are unique. Make sure to give unique " + "names to your layers (and other objects)." + ) + if value_format == "backend_tensor": + flat_dict[v.path] = v.value + elif value_format == "numpy_array": + flat_dict[v.path] = v.numpy() + else: + raise ValueError( + "Invalid `value_format` argument. Expected one of " + "{'numpy_array', 'backend_tensor'}. Received: " + f"value_format={value_format}" + ) + + nested_dict = {} + for path, value in flat_dict.items(): + parts = path.split("/") + current_dict = nested_dict + for part in parts[:-1]: + if part not in current_dict: + current_dict[part] = {} + current_dict = current_dict[part] + current_dict[parts[-1]] = value + + return nested_dict + + def set_state_tree(self, state_tree): + """Assigns values to variables of the model. + + This method takes a dictionary of nested variable values, which + represents the state tree of the model, and assigns them to the + corresponding variables of the model. The dictionary keys represent the + variable names (e.g., `'trainable_variables'`, `'optimizer_variables'`), + and the values are nested dictionaries containing the variable + paths and their corresponding values. + + Args: + state_tree: A dictionary representing the state tree of the model. + The keys are the variable names, and the values are nested + dictionaries representing the variable paths and their values. + """ + for k, v in state_tree.items(): + path_value_dict = self._flatten_nested_dict(v) + if k == "trainable_variables": + self._assign_variable_values( + self.trainable_variables, path_value_dict + ) + elif k == "non_trainable_variables": + self._assign_variable_values( + self.non_trainable_variables, path_value_dict + ) + elif k == "optimizer_variables": + self._assign_variable_values( + self.optimizer.variables, path_value_dict + ) + elif k == "metrics_variables": + self._assign_variable_values( + self.metrics_variables, path_value_dict + ) + else: + raise ValueError(f"Unknown variable name: {k}") + + def _assign_variable_values(self, variables, path_value_dict): + for path, value in path_value_dict.items(): + for variable in variables: + if variable.path == path: + variable.assign(value) + + def _flatten_nested_dict(self, nested_dict): + flat_dict = {} + + def _flatten(current_dict, prefix=""): + for key, value in current_dict.items(): + if isinstance(value, dict): + _flatten(value, f"{prefix}{key}/") + else: + flat_dict[f"{prefix}{key}"] = value + + _flatten(nested_dict) + return flat_dict + + +@keras_export("keras.models.model_from_json") +def model_from_json(json_string, custom_objects=None): + """Parses a JSON model configuration string and returns a model instance. + + Example: + + >>> model = keras.Sequential([ + ... keras.layers.Dense(5, input_shape=(3,)), + ... keras.layers.Softmax()]) + >>> config = model.to_json() + >>> loaded_model = keras.models.model_from_json(config) + + Args: + json_string: JSON string encoding a model configuration. + custom_objects: Optional dictionary mapping names + (strings) to custom classes or functions to be + considered during deserialization. + + Returns: + A Keras model instance (uncompiled). + """ + from keras.src.saving import serialization_lib + + model_config = json.loads(json_string) + return serialization_lib.deserialize_keras_object( + model_config, custom_objects=custom_objects + ) + + +def functional_init_arguments(args, kwargs): + return ( + (len(args) == 2) + or (len(args) == 1 and "outputs" in kwargs) + or ("inputs" in kwargs and "outputs" in kwargs) + ) + + +def inject_functional_model_class(cls): + """Inject `Functional` into the hierarchy of this class if needed.""" + from keras.src.models import functional + + if cls is Model: + return functional.Functional + # In case there is any multiple inheritance, we stop injecting the + # class if keras model is not in its class hierarchy. + if cls is object: + return object + + cls.__bases__ = tuple( + inject_functional_model_class(base) for base in cls.__bases__ + ) + # Trigger any `__new__` class swapping that needed to happen on `Functional` + # but did not because functional was not in the class hierarchy. + cls.__new__(cls) + + return cls diff --git a/keras/src/models/model_test.py b/keras/src/models/model_test.py new file mode 100644 index 000000000000..4b2b5ce00081 --- /dev/null +++ b/keras/src/models/model_test.py @@ -0,0 +1,1296 @@ +import os +import pickle +from collections import namedtuple + +import numpy as np +import pytest +from absl.testing import parameterized + +from keras.src import backend +from keras.src import layers +from keras.src import losses +from keras.src import testing +from keras.src import tree +from keras.src.layers.core.input_layer import Input +from keras.src.models.functional import Functional +from keras.src.models.model import Model +from keras.src.models.model import model_from_json + + +def _get_model(): + input_a = Input(shape=(3,), batch_size=2, name="input_a") + input_b = Input(shape=(3,), batch_size=2, name="input_b") + x = input_a + input_b + x = layers.Dense(5)(x) + outputs = layers.Dense(4)(x) + model = Model([input_a, input_b], outputs) + return model + + +def _get_model_multi_outputs_list(): + x = Input(shape=(3,), name="input_a") + output_a = layers.Dense(1, name="output_a")(x) + output_b = layers.Dense(1, name="output_b", activation="sigmoid")(x) + model = Model(x, [output_a, output_b]) + return model + + +def _get_model_multi_outputs_list_no_output_names(): + x = Input(shape=(3,), name="input_a") + output_a = layers.Dense(1)(x) + output_b = layers.Dense(1, activation="sigmoid")(x) + model = Model(x, [output_a, output_b]) + return model + + +def _get_model_single_output(): + x = Input(shape=(3,), name="input_a") + output_a = layers.Dense(1, name="output_a")(x) + model = Model(x, output_a) + return model + + +def _get_model_single_output_list(): + x = Input(shape=(3,), name="input_a") + output_a = layers.Dense(1, name="output_a")(x) + model = Model(x, [output_a]) + return model + + +def _get_model_single_output_dict(): + x = Input(shape=(3,), name="input_a") + output_a = layers.Dense(1, name="output_a")(x) + model = Model(x, {"output_a": output_a}) + return model + + +def _get_model_multi_outputs_dict(): + x = Input(shape=(3,), name="input_a") + output_a = layers.Dense(1, name="output_a")(x) + output_b = layers.Dense(1, name="output_b", activation="sigmoid")(x) + model = Model(x, {"output_a": output_a, "output_b": output_b}) + return model + + +def _get_model_multi_outputs_struct_list_like(_type): + x = Input(shape=(3,), name="x") + y1 = layers.Dense(1, name="y1", activation="sigmoid")(x) + y2 = layers.Dense(1, name="y2", activation="sigmoid")(x) + model = Model(x, _type([y1, y2])) + return model + + +def _get_model_multi_outputs_struct_namedtuple(): + Y = namedtuple("Y", ["y1", "y2"]) + x = Input(shape=(3,), name="x") + y1 = layers.Dense(1, name="y1", activation="sigmoid")(x) + y2 = layers.Dense(1, name="y2", activation="sigmoid")(x) + model = Model(x, Y(y1, y2)) + return model, Y + + +def _get_model_multi_outputs_struct_dict(): + x = Input(shape=(3,), name="x") + y1 = layers.Dense(1, name="y1", activation="sigmoid")(x) + y2 = layers.Dense(1, name="y2", activation="sigmoid")(x) + model = Model(x, {"a": y1, "b": y2}) + return model + + +def _get_model_multi_outputs_struct(): + x = Input(shape=(3,), name="x") + y1 = layers.Dense(1, name="y1", activation="sigmoid")(x) + y2 = layers.Dense(1, name="y2", activation="sigmoid")(x) + y3 = layers.Dense(1, name="y3", activation="sigmoid")(x) + model = Model( + x, + { + "a": (y1, y2), + "b": {"b1": y1, "b2": y2}, + "c": {"c1": (y1, y2), "c2": y2}, + "d": y3, + }, + ) + return model + + +def _get_model_multi_outputs_dict_with_single_tensor(): + x = Input(shape=(3,), name="input_a") + output = layers.Dense(1, name="output_a")(x) + model = Model(x, {"output_a": output, "output_b": output}) + return model + + +def _get_model_with_custom_compute_loss(): + class MyModel(Model): + def __init__(self): + inputs = Input(shape=(3,), name="inputs") + outputs = layers.Dense(1, name="a")(inputs) + super().__init__(inputs=inputs, outputs=outputs) + + def compute_loss(self, x, y, y_pred, sample_weight=None, **kwargs): + y_pred = [y_pred, y_pred] # To list + return super().compute_loss( + x=x, y=y, y_pred=y_pred, sample_weight=sample_weight, **kwargs + ) + + model = MyModel() + return model + + +def _get_model_with_duplicate_variable_path(): + class MyModel(Model): + def __init__(self): + super().__init__() + self.dense1 = layers.Dense(4, activation="relu", name="layer1") + self.dense2 = layers.Dense(4, activation="relu", name="layer1") + self.dense3 = layers.Dense(2) + + def call(self, x): + x = self.dense1(x) + x = self.dense2(x) + return self.dense3(x) + + model = MyModel() + x = np.random.random((1, 16)) + model(x) + return model + + +def _get_model_optional_inputs(): + class OptionalInputLayer(layers.Layer): + def __init__(self): + super().__init__() + self.dense = layers.Dense(2) + + def call(self, x, o=None): + z = x if o is None else x + o + return self.dense(z) + + x = Input((2,), name="x") + o = Input((2,), name="o", optional=True) + y = OptionalInputLayer()(x, o) + model = Model({"x": x, "o": o}, y) + return model + + +def _get_variable_value_by_path(variables, path): + for v in variables: + if v.path == path: + return v.value + raise ValueError(f"No variable was find with path = {path}") + + +@pytest.mark.requires_trainable_backend +class ModelTest(testing.TestCase): + def test_functional_rerouting(self): + model = _get_model() + self.assertIsInstance(model, Functional) + + def test_json_serialization(self): + model = _get_model() + json_string = model.to_json() + new_model = model_from_json(json_string) + self.assertEqual(json_string, new_model.to_json()) + + def test_tuple_input_model_subclass(self): + # https://github.com/keras-team/keras/issues/324 + + class MultiInputModel(Model): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.dense1 = layers.Dense(4) + + def call(self, inputs): + a, b = inputs + r = self.dense1(a) + return layers.concatenate([r, b]) + + model = MultiInputModel() + x1 = np.random.rand(3, 3) + x2 = np.random.rand(3, 2) + out = model((x1, x2)) + self.assertEqual(out.shape, (3, 6)) + + def test_reviving_functional_from_config_custom_layer(self): + class CustomDense(layers.Layer): + def __init__(self, units, **kwargs): + super().__init__(**kwargs) + self.dense = layers.Dense(units) + + def call(self, x): + return self.dense(x) + + inputs = layers.Input((4,)) + outputs = CustomDense(10)(inputs) + model = Model(inputs, outputs) + config = model.get_config() + + new_model = Model.from_config( + config, custom_objects={"CustomDense": CustomDense} + ) + self.assertIsInstance(new_model, Functional) + + def test_reviving_functional_from_config_custom_model(self): + class CustomModel(Model): + def __init__(self, *args, param=1, **kwargs): + super().__init__(*args, **kwargs) + self.param = param + + def get_config(self): + base_config = super().get_config() + config = {"param": self.param} + return base_config | config + + inputs = layers.Input((3,)) + outputs = layers.Dense(5)(inputs) + model = CustomModel(inputs=inputs, outputs=outputs, param=3) + + new_model = CustomModel.from_config(model.get_config()) + self.assertEqual(new_model.param, 3) + + @parameterized.named_parameters( + ("single_output_1", _get_model_single_output), + ("single_output_2", _get_model_single_output), + ("single_output_3", _get_model_single_output), + ("single_output_4", _get_model_single_output), + ("single_list_output_1", _get_model_single_output_list), + ("single_list_output_2", _get_model_single_output_list), + ("single_list_output_3", _get_model_single_output_list), + ("single_list_output_4", _get_model_single_output_list), + ) + def test_functional_pickling(self, model_fn): + model = model_fn() + self.assertIsInstance(model, Functional) + model.compile() + x = np.random.rand(8, 3) + + reloaded_pickle = pickle.loads(pickle.dumps(model)) + + pred_reloaded = reloaded_pickle.predict(x) + pred = model.predict(x) + + self.assertAllClose(np.array(pred_reloaded), np.array(pred)) + + @parameterized.named_parameters( + ("single_output_1", _get_model_single_output, None), + ("single_output_2", _get_model_single_output, "list"), + ("single_output_3", _get_model_single_output, "dict"), + ("single_output_4", _get_model_single_output, "dict_list"), + ("single_list_output_1", _get_model_single_output_list, None), + ("single_list_output_2", _get_model_single_output_list, "list"), + ("single_list_output_3", _get_model_single_output_list, "dict"), + ("single_list_output_4", _get_model_single_output_list, "dict_list"), + ("single_dict_output_1", _get_model_single_output_dict, None), + ("single_dict_output_2", _get_model_single_output_dict, "list"), + ("single_dict_output_3", _get_model_single_output_dict, "dict"), + ("single_dict_output_4", _get_model_single_output_dict, "dict_list"), + ) + def test_functional_single_output(self, model_fn, loss_type): + model = model_fn() + self.assertIsInstance(model, Functional) + loss = "mean_squared_error" + if loss_type == "list": + loss = [loss] + elif loss_type == "dict": + loss = {"output_a": loss} + elif loss_type == "dict_list": + loss = {"output_a": [loss]} + model.compile( + optimizer="sgd", + loss=loss, + metrics={ + "output_a": ["mean_squared_error", "mean_absolute_error"], + }, + weighted_metrics={ + "output_a": "mean_squared_error", + }, + ) + # Fit the model to make sure compile_metrics are built + x = np.random.rand(8, 3) + y = np.random.rand(8, 1) + hist = model.fit( + x, + y, + batch_size=2, + epochs=1, + verbose=0, + ) + hist_keys = sorted(hist.history.keys()) + ref_keys = sorted( + [ + "loss", + "mean_absolute_error", + "mean_squared_error", + "weighted_mean_squared_error", + ] + ) + self.assertListEqual(hist_keys, ref_keys) + + def test_functional_list_outputs_list_losses(self): + model = _get_model_multi_outputs_list() + self.assertIsInstance(model, Functional) + x = np.random.rand(8, 3) + y1 = np.random.rand(8, 1) + y2 = np.random.randint(0, 2, (8, 1)) + model.compile( + optimizer="sgd", + loss=["mean_squared_error", "binary_crossentropy"], + metrics=[ + "mean_squared_error", + ["mean_squared_error", "accuracy"], + ], + loss_weights=[0.1, 2], + ) + # Fit the model to make sure compile_metrics are built + hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0) + hist_keys = sorted(hist.history.keys()) + ref_keys = sorted( + [ + "loss", + "output_a_loss", + "output_a_mean_squared_error", + "output_b_accuracy", + "output_b_loss", + "output_b_mean_squared_error", + ] + ) + self.assertListEqual(hist_keys, ref_keys) + + def test_functional_list_outputs_list_losses_abbr(self): + model = _get_model_multi_outputs_list() + self.assertIsInstance(model, Functional) + x = np.random.rand(8, 3) + y1 = np.random.rand(8, 1) + y2 = np.random.randint(0, 2, (8, 1)) + model.compile( + optimizer="sgd", + loss=["mse", "bce"], + metrics=[ + ["bce", "mse", "mae"], + ["mse", "acc"], + ], + loss_weights=[0.1, 2], + ) + # Fit the model to make sure compile_metrics are built + hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0) + hist_keys = sorted(hist.history.keys()) + ref_keys = sorted( + [ + "loss", + "output_a_loss", + "output_a_bce", + "output_a_mae", + "output_a_mse", + "output_b_acc", + "output_b_loss", + "output_b_mse", + ] + ) + self.assertListEqual(hist_keys, ref_keys) + + def test_functional_list_outputs_nested_list_losses(self): + model = _get_model_multi_outputs_list() + self.assertIsInstance(model, Functional) + x = np.random.rand(8, 3) + y1 = np.random.rand(8, 1) + y2 = np.random.randint(0, 2, (8, 1)) + model.compile( + optimizer="sgd", + loss=["mean_squared_error", ["binary_crossentropy"]], + metrics=[ + "mean_squared_error", + ["mean_squared_error", "accuracy"], + ], + loss_weights=[0.1, 2], + ) + # Fit the model to make sure compile_metrics are built + hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0) + hist_keys = sorted(hist.history.keys()) + ref_keys = sorted( + [ + "loss", + "output_a_loss", + "output_a_mean_squared_error", + "output_b_accuracy", + "output_b_loss", + "output_b_mean_squared_error", + ] + ) + self.assertListEqual(hist_keys, ref_keys) + + def test_functional_dict_outputs_dict_losses(self): + model = _get_model_multi_outputs_dict() + self.assertIsInstance(model, Functional) + x = np.random.rand(8, 3) + y1 = np.random.rand(8, 1) + y2 = np.random.randint(0, 2, (8, 1)) + model.compile( + optimizer="sgd", + loss={ + "output_a": "mean_squared_error", + "output_b": ["binary_crossentropy"], + }, + metrics={ + "output_a": ["mean_squared_error"], + "output_b": ["mean_squared_error", "accuracy"], + }, + weighted_metrics={ + "output_a": ["mean_squared_error"], + "output_b": ["mean_squared_error", "accuracy"], + }, + ) + # Check dict outputs. + outputs = model.predict(x) + self.assertIsInstance(outputs, dict) + self.assertEqual(outputs["output_a"].shape, (8, 1)) + self.assertEqual(outputs["output_b"].shape, (8, 1)) + # Fit the model to make sure compile_metrics are built + hist = model.fit( + x, + {"output_a": y1, "output_b": y2}, + batch_size=2, + epochs=1, + verbose=0, + ) + hist_keys = sorted(hist.history.keys()) + ref_keys = sorted( + [ + "loss", + "output_a_loss", + "output_a_mean_squared_error", + "output_a_weighted_mean_squared_error", + "output_b_accuracy", + "output_b_loss", + "output_b_mean_squared_error", + "output_b_weighted_accuracy", + "output_b_weighted_mean_squared_error", + ] + ) + self.assertListEqual(hist_keys, ref_keys) + + def test_functional_dict_outputs_dict_losses_with_undefined_loss(self): + model = _get_model_multi_outputs_dict() + self.assertIsInstance(model, Functional) + x = np.random.rand(8, 3) + y1 = np.random.rand(8, 1) + y2 = np.random.randint(0, 2, (8, 1)) + model.compile( + optimizer="sgd", + loss={ + "output_b": ["binary_crossentropy"], + }, + metrics={ + "output_b": ["mean_squared_error", "accuracy"], + }, + weighted_metrics={ + "output_b": ["mean_squared_error", "accuracy"], + }, + ) + # Check dict outputs. + outputs = model.predict(x) + self.assertIsInstance(outputs, dict) + self.assertEqual(outputs["output_a"].shape, (8, 1)) + self.assertEqual(outputs["output_b"].shape, (8, 1)) + # Fit the model to make sure compile_metrics are built + hist = model.fit( + x, + {"output_a": y1, "output_b": y2}, + batch_size=2, + epochs=1, + verbose=0, + ) + hist_keys = sorted(hist.history.keys()) + ref_keys = sorted( + [ + "loss", + "output_b_accuracy", + "output_b_mean_squared_error", + "output_b_weighted_accuracy", + "output_b_weighted_mean_squared_error", + ] + ) + self.assertListEqual(hist_keys, ref_keys) + + def test_functional_list_outputs_dict_losses_metrics(self): + model = _get_model_multi_outputs_list() + self.assertIsInstance(model, Functional) + x = np.random.rand(8, 3) + y1 = np.random.rand(8, 1) + y2 = np.random.randint(0, 2, (8, 1)) + model.compile( + optimizer="sgd", + loss={ + "output_a": "mean_squared_error", + "output_b": "binary_crossentropy", + }, + metrics={ + "output_a": ["mean_squared_error"], + "output_b": ["mean_squared_error", "accuracy"], + }, + weighted_metrics={ + "output_a": ["mean_squared_error"], + "output_b": ["mean_squared_error", "accuracy"], + }, + ) + # Check list outputs. + outputs = model.predict(x) + self.assertIsInstance(outputs, list) + self.assertEqual(outputs[0].shape, (8, 1)) + self.assertEqual(outputs[1].shape, (8, 1)) + # Fit the model to make sure compile_metrics are built + hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0) + hist_keys = sorted(hist.history.keys()) + ref_keys = sorted( + [ + "loss", + "output_a_loss", + "output_a_mean_squared_error", + "output_a_weighted_mean_squared_error", + "output_b_accuracy", + "output_b_loss", + "output_b_mean_squared_error", + "output_b_weighted_accuracy", + "output_b_weighted_mean_squared_error", + ] + ) + self.assertListEqual(hist_keys, ref_keys) + + def test_functional_list_outputs_dict_losses_metrics_uniq_weighted(self): + model = _get_model_multi_outputs_list() + self.assertIsInstance(model, Functional) + x = np.random.rand(8, 3) + y1 = np.random.rand(8, 1) + y2 = np.random.randint(0, 2, (8, 1)) + model.compile( + optimizer="sgd", + loss={ + "output_a": "mean_squared_error", + "output_b": "binary_crossentropy", + }, + metrics={ + "output_a": ["mean_squared_error"], + "output_b": ["mean_squared_error"], + }, + weighted_metrics={ + "output_a": ["mean_squared_error"], + "output_b": ["accuracy"], + }, + ) + # Fit the model to make sure compile_metrics are built + hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0) + hist_keys = sorted(hist.history.keys()) + # `output_b_accuracy` doesn't have `weighted_` in metric name. + # When a metric is only in weighted metrics, it skips `weighted_` + # prefix. This behavior matches`tf.keras`. + ref_keys = sorted( + [ + "loss", + "output_a_loss", + "output_a_mean_squared_error", + "output_a_weighted_mean_squared_error", + "output_b_accuracy", + "output_b_loss", + "output_b_mean_squared_error", + ] + ) + self.assertListEqual(hist_keys, ref_keys) + + def test_functional_list_outputs_dict_losses_partial_metrics(self): + model = _get_model_multi_outputs_list() + self.assertIsInstance(model, Functional) + x = np.random.rand(8, 3) + y1 = np.random.rand(8, 1) + y2 = np.random.randint(0, 2, (8, 1)) + model.compile( + optimizer="sgd", + loss={ + "output_a": "mean_squared_error", + "output_b": "binary_crossentropy", + }, + metrics={ + "output_b": ["mean_squared_error", "accuracy"], + }, + ) + # Fit the model to make sure compile_metrics are built + hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0) + hist_keys = sorted(hist.history.keys()) + ref_keys = sorted( + [ + "loss", + "output_a_loss", + "output_b_accuracy", + "output_b_loss", + "output_b_mean_squared_error", + ] + ) + self.assertListEqual(hist_keys, ref_keys) + + def test_functional_dict_outputs_with_single_tensor(self): + model = _get_model_multi_outputs_dict_with_single_tensor() + self.assertIsInstance(model, Functional) + x = np.random.rand(8, 3) + y1 = np.random.rand(8, 1) + y2 = np.random.randint(0, 2, (8, 1)) + + # `model` has 2 outputs, but there is actually only 1 output tensor. + self.assertLen(model.outputs, 2) + model.compile( + optimizer="sgd", + loss={ + "output_a": "mean_squared_error", + "output_b": "binary_crossentropy", + }, + ) + hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0) + hist_keys = sorted(hist.history.keys()) + ref_keys = sorted(["loss", "output_a_loss", "output_b_loss"]) + self.assertListEqual(hist_keys, ref_keys) + + def test_functional_list_outputs_with_custom_compute_loss(self): + model = _get_model_with_custom_compute_loss() + self.assertIsInstance(model, Functional) + x = np.random.rand(8, 3) + y1 = np.random.rand(8, 1) + y2 = np.random.randint(0, 2, (8, 1)) + + # `model` has 1 output, but in `compute_loss` it is separated into 2. + self.assertLen(model.outputs, 1) + model.compile( + optimizer="sgd", loss=["mean_squared_error", "binary_crossentropy"] + ) + hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0) + hist_keys = sorted(hist.history.keys()) + ref_keys = sorted( + ["binary_crossentropy_loss", "loss", "mean_squared_error_loss"] + ) + self.assertListEqual(hist_keys, ref_keys) + + def test_functional_list_outputs_dict_losses_invalid_keys(self): + model = _get_model_multi_outputs_list() + self.assertIsInstance(model, Functional) + x = np.random.rand(8, 3) + y1 = np.random.rand(8, 1) + y2 = np.random.randint(0, 2, (8, 1)) + model.compile( + optimizer="sgd", + loss={ + "output_a": "mean_squared_error", + "output_c": "binary_crossentropy", + }, + ) + + # Fit the model to make sure compile_metrics are built + with self.assertRaisesRegex( + ValueError, + "Expected keys", + ): + model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0) + + def test_functional_list_outputs_dict_losses_no_output_names(self): + model = _get_model_multi_outputs_list_no_output_names() + self.assertIsInstance(model, Functional) + x = np.random.rand(8, 3) + y1 = np.random.rand(8, 1) + y2 = np.random.randint(0, 2, (8, 1)) + model.compile( + optimizer="sgd", + loss={"output_a": "mean_squared_error"}, + ) + # Fit the model to make sure compile_metrics are built + with self.assertRaisesRegex( + ValueError, + "Expected keys", + ): + model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0) + + def test_functional_list_outputs_dict_metrics_invalid_keys(self): + model = _get_model_multi_outputs_list() + self.assertIsInstance(model, Functional) + x = np.random.rand(8, 3) + y1 = np.random.rand(8, 1) + y2 = np.random.randint(0, 2, (8, 1)) + model.compile( + optimizer="sgd", + loss={ + "output_a": "mean_squared_error", + "output_b": "binary_crossentropy", + }, + metrics={ + "output_c": ["mean_squared_error", "accuracy"], + }, + ) + # Fit the model to make sure compile_metrics are built + with self.assertRaisesRegex( + ValueError, + "In the dict argument `metrics`, " + "key 'output_c' does not correspond to any model output", + ): + model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0) + + def test_functional_dict_outputs_dict_losses_invalid_keys(self): + model = _get_model_multi_outputs_dict() + self.assertIsInstance(model, Functional) + x = np.random.rand(8, 3) + y1 = np.random.rand(8, 1) + y2 = np.random.randint(0, 2, (8, 1)) + model.compile( + optimizer="sgd", + loss={ + "output_a": "mean_squared_error", + "output_c": "binary_crossentropy", + }, + ) + # Fit the model to make sure compile_metrics are built + with self.assertRaisesRegex( + KeyError, + "in the `loss` argument, can't be found " + "in either the model's output", + ): + model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0) + + def test_functional_dict_outputs_dict_metrics_invalid_keys(self): + model = _get_model_multi_outputs_dict() + self.assertIsInstance(model, Functional) + x = np.random.rand(8, 3) + y1 = np.random.rand(8, 1) + y2 = np.random.randint(0, 2, (8, 1)) + model.compile( + optimizer="sgd", + loss={ + "output_a": "mean_squared_error", + "output_b": "binary_crossentropy", + }, + metrics={ + "output_c": ["mean_squared_error", "accuracy"], + }, + ) + # Fit the model to make sure compile_metrics are built + with self.assertRaisesRegex( + ValueError, + "In the dict argument `metrics`, " + "key 'output_c' does not correspond to any model output", + ): + model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0) + + def test_functional_list_outputs_invalid_nested_list_losses(self): + model = _get_model_multi_outputs_list() + self.assertIsInstance(model, Functional) + x = np.random.rand(8, 3) + y1 = np.random.rand(8, 1) + y2 = np.random.randint(0, 2, (8, 1)) + model.compile( + optimizer="sgd", + loss=[ + "mean_squared_error", + ["mean_squared_error", "binary_crossentropy"], + ], + ) + hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0) + hist_keys = sorted(hist.history.keys()) + ref_keys = sorted(["loss", "output_a_loss", "output_b_loss"]) + self.assertListEqual(hist_keys, ref_keys) + + @parameterized.named_parameters( + ("int8", "int8"), + ("float8", "float8"), + ) + def test_quantize(self, mode): + model = _get_model() + x1 = np.random.rand(2, 3) + x2 = np.random.rand(2, 3) + model.quantize(mode) + _ = model((x1, x2)) + + for layer in model._flatten_layers(): + if isinstance(layer, (layers.Dense, layers.EinsumDense)): + self.assertEqual( + layer.dtype_policy.name, f"{mode}_from_float32" + ) + self.assertEqual(layer.dtype_policy.quantization_mode, mode) + if mode == "int8": + self.assertLen(model.variables, 6) + if backend.backend() == "torch": + self.assertLen(list(model.named_parameters()), 6) + elif mode == "float8": + self.assertLen(model.variables, 16) + if backend.backend() == "torch": + self.assertLen(list(model.named_parameters()), 16) + + @parameterized.named_parameters( + ("int8", "int8"), + ("float8", "float8"), + ) + def test_quantize_unbuilt(self, mode): + class MyModel(Model): + def __init__(self): + super().__init__() + self.dense1 = layers.Dense(32, activation="relu") + self.dense2 = layers.Dense(5, activation="softmax") + self.dropout = layers.Dropout(0.5) + + def call(self, inputs, training=False): + x = self.dense1(inputs) + x = self.dropout(x, training=training) + return self.dense2(x) + + model = MyModel() + with self.assertRaisesRegex( + ValueError, "Cannot quantize a layer that isn't yet built." + ): + model.quantize(mode) + + x = np.random.rand(2, 3) + _ = model(x) + model.quantize(mode) + + def test_quantize_invalid_args(self): + model = _get_model() + with self.assertRaisesRegex( + ValueError, "Invalid quantization mode. Expected one of" + ): + model.quantize("abc") + + with self.assertRaisesRegex( + ValueError, "Unrecognized keyword arguments" + ): + model.quantize("int8", unrecognized_kwargs=None) + + with self.assertRaisesRegex(ValueError, "Invalid quantization mode"): + model.quantize("int7") + + @parameterized.named_parameters( + ("int8", "int8"), + ("float8", "float8"), + ) + def test_quantize_nested_model(self, mode): + class NestedLayer(layers.Layer): + def __init__(self, units): + super().__init__() + self.dense = layers.Dense(units) + + def call(self, x): + x = self.dense(x) + return x + + class DoubleNestedLayer(layers.Layer): + def __init__(self, units): + super().__init__() + self.nested_dense1 = NestedLayer(units) + self.nested_dense2 = NestedLayer(units) + self.dense = layers.Dense(units) + + def call(self, x): + x = self.nested_dense1(x) + x = self.nested_dense2(x) + x = self.dense(x) + return x + + inputs = layers.Input([3]) + outputs = DoubleNestedLayer(8)(inputs) + model = Model(inputs, outputs) + model.quantize(mode) + + if mode == "int8": + kernel_count = 0 + for weight in model.weights: + if weight.name == "kernel": + kernel_count += 1 + self.assertEqual( + backend.standardize_dtype(weight.dtype), "int8" + ) + self.assertEqual(kernel_count, 3) + if mode == "float8": + # kernel + bias + scale * 3 + amax_history * 3 == 8 + self.assertEqual(len(model.weights), 3 * 8) + + def test_get_state_tree(self): + model = _get_model_single_output() + model.compile(loss="mse", optimizer="adam") + state_tree = model.get_state_tree() + self.assertAllClose( + state_tree["trainable_variables"]["output_a"]["kernel"], + _get_variable_value_by_path( + model.trainable_variables, "output_a/kernel" + ), + ) + self.assertAllClose( + state_tree["trainable_variables"]["output_a"]["bias"], + _get_variable_value_by_path( + model.trainable_variables, "output_a/bias" + ), + ) + self.assertEqual( + state_tree["non_trainable_variables"], + {}, + ) + self.assertEqual( + state_tree["metrics_variables"]["loss"]["count"], + _get_variable_value_by_path(model.metrics_variables, "loss/count"), + ) + self.assertEqual( + state_tree["metrics_variables"]["loss"]["total"], + _get_variable_value_by_path(model.metrics_variables, "loss/total"), + ) + self.assertEqual( + state_tree["optimizer_variables"]["adam"]["iteration"], + _get_variable_value_by_path( + model.optimizer.variables, "adam/iteration" + ), + ) + self.assertEqual( + state_tree["optimizer_variables"]["adam"]["learning_rate"], + _get_variable_value_by_path( + model.optimizer.variables, "adam/learning_rate" + ), + ) + + # Test with numpy + state_tree = model.get_state_tree(value_format="numpy_array") + self.assertIsInstance( + state_tree["trainable_variables"]["output_a"]["kernel"], np.ndarray + ) + + def test_set_state_tree(self): + variables = { + "optimizer_variables": { + "adam": { + "iteration": 0, + "learning_rate": 0.00001, + } + }, + "trainable_variables": { + "output_a": { + "bias": [0.5], + "kernel": [[0.6], [0.7], [1.8]], + } + }, + } + + model = _get_model_single_output() + model.compile(optimizer="adam") + model.set_state_tree(variables) + + self.assertEqual( + variables["optimizer_variables"]["adam"]["iteration"], + _get_variable_value_by_path( + model.optimizer.variables, "adam/iteration" + ), + ) + self.assertEqual( + variables["optimizer_variables"]["adam"]["learning_rate"], + _get_variable_value_by_path( + model.optimizer.variables, "adam/learning_rate" + ), + ) + self.assertAllClose( + variables["trainable_variables"]["output_a"]["bias"], + _get_variable_value_by_path( + model.trainable_variables, "output_a/bias" + ), + ) + self.assertAllClose( + variables["trainable_variables"]["output_a"]["kernel"], + _get_variable_value_by_path( + model.trainable_variables, "output_a/kernel" + ), + ) + + def test_get_state_tree_with_duplicate_path(self): + model = _get_model_with_duplicate_variable_path() + with self.assertRaisesRegex( + ValueError, + "The following variable path is found twice in the model", + ): + model.get_state_tree() + + def test_layers_setter(self): + model = Model() + with self.assertRaisesRegex( + AttributeError, "`Model.layers` attribute is reserved" + ): + model.layers = [layers.Dense(4)] + + def get_struct_loss(self, structure): + def loss_fn(y_true, y_pred): + tree.assert_same_structure(structure, y_true) + tree.assert_same_structure(structure, y_pred) + tree.map_structure( + lambda spec, tensor: self.assertEqual(spec.ndim, tensor.ndim), + structure, + y_true, + ) + tree.map_structure( + lambda spec, tensor: self.assertEqual(spec.ndim, tensor.ndim), + structure, + y_pred, + ) + flat_y_pred = tree.flatten(y_pred) + flat_y_true = tree.flatten(y_true) + diff = 0 + for y_p, y_t in zip(flat_y_pred, flat_y_true): + diff += losses.mean_absolute_error(y_t, y_p) + return diff + + return loss_fn + + @parameterized.product( + _type=[tuple, list], other_type=[list, tuple], weighted=[False, True] + ) + def test_functional_struct_outputs_struct_losses( + self, _type, other_type, weighted + ): + model = _get_model_multi_outputs_struct_list_like(_type) + self.assertIsInstance(model, Functional) + x = np.random.rand(8, 3) + y1 = np.random.rand(8, 1) + y2 = np.random.rand(8, 1) + y = _type([y1, y2]) + loss = other_type( + [ + self.get_struct_loss(model.output), + _type( + [ + self.get_struct_loss(model.output[0]), + self.get_struct_loss(model.output[1]), + ] + ), + ] + ) + if weighted: + loss_weights = tree.map_structure(lambda _: np.random.rand(), loss) + else: + loss_weights = None + + model.compile( + optimizer="sgd", + loss=loss, + loss_weights=loss_weights, + ) + + if _type is other_type: + with self.assertRaisesRegex( + ValueError, f"[Ee]xpected.*{_type.__name__}" + ): + model.fit(x, y, batch_size=2, epochs=1, verbose=0) + else: + # Check dict outputs. + outputs = model.predict(x) + self.assertIsInstance(outputs, _type) + # Fit the model to make sure compile_metrics are built + hist = model.fit( + x, + y, + batch_size=2, + epochs=1, + verbose=0, + ) + hist_keys = sorted(hist.history.keys()) + ref_keys = sorted( + [ + "loss", + "y1_loss", + "y2_loss", + "y1_y2_loss", + ] + ) + self.assertListEqual(hist_keys, ref_keys) + + @parameterized.named_parameters(("weighted", True), ("not_weighted", False)) + def test_functional_struct_outputs_dict_struct_losses(self, weighted): + model = _get_model_multi_outputs_struct_dict() + self.assertIsInstance(model, Functional) + x = np.random.rand(8, 3) + y1 = np.random.rand(8, 1) + y2 = np.random.rand(8, 1) + + y = {"a": y1, "b": y2} + loss = [ + self.get_struct_loss(model.output), + { + "a": self.get_struct_loss(model.output["a"]), + "b": self.get_struct_loss(model.output["a"]), + }, + ] + if weighted: + loss_weights = tree.map_structure(lambda _: np.random.rand(), loss) + else: + loss_weights = None + + model.compile( + optimizer="sgd", + loss=loss, + loss_weights=loss_weights, + ) + # Check dict outputs. + outputs = model.predict(x) + self.assertIsInstance(outputs, dict) + + # Fit the model to make sure compile_metrics are built + hist = model.fit( + x, + y, + batch_size=2, + epochs=1, + verbose=0, + ) + hist_keys = sorted(hist.history.keys()) + ref_keys = sorted( + [ + "loss", + "a_loss", + "b_loss", + "a_b_loss", + ] + ) + self.assertListEqual(hist_keys, ref_keys) + + def test_functional_struct_outputs_namedtuple_struct_losses(self): + model, Y = _get_model_multi_outputs_struct_namedtuple() + self.assertIsInstance(model, Functional) + x = np.random.rand(8, 3) + y1 = np.random.rand(8, 1) + y2 = np.random.rand(8, 1) + + y = Y(y1, y2) + model.compile( + optimizer="sgd", + loss=[ + self.get_struct_loss(model.output), + Y( + self.get_struct_loss(model.output.y1), + self.get_struct_loss(model.output.y2), + ), + ], + ) + # Check dict outputs. + outputs = model.predict(x) + self.assertIsInstance(outputs, tuple) + + # Fit the model to make sure compile_metrics are built + hist = model.fit( + x, + y, + batch_size=2, + epochs=1, + verbose=0, + ) + hist_keys = sorted(hist.history.keys()) + ref_keys = sorted( + [ + "loss", + "y1_loss", + "y2_loss", + "y1_y2_loss", + ] + ) + self.assertListEqual(hist_keys, ref_keys) + + def test_functional_deeply_nested_outputs_struct_losses(self): + model = _get_model_multi_outputs_struct() + self.assertIsInstance(model, Functional) + x = np.random.rand(8, 3) + y1 = np.random.rand(8, 1) + y2 = np.random.rand(8, 1) + y3 = np.random.rand(8, 1) + y = { + "a": (y1, y2), + "b": {"b1": y1, "b2": y2}, + "c": {"c1": (y1, y2), "c2": y2}, + "d": y3, + } + model.compile( + optimizer="sgd", + loss={ + "a": [ + self.get_struct_loss(model.output["a"]), + (None, self.get_struct_loss(model.output["a"][1])), + ], + "b": [ + self.get_struct_loss(model.output["b"]), + {"b1": self.get_struct_loss(model.output["b"]["b1"])}, + ], + "c": [ + self.get_struct_loss(model.output["c"]), + {"c1": self.get_struct_loss(model.output["c"]["c1"])}, + ], + "d": self.get_struct_loss(model.output["d"]), + }, + ) + # Check dict outputs. + outputs = model.predict(x) + self.assertIsInstance(outputs, dict) + + # Fit the model to make sure compile_metrics are built + hist = model.fit( + x, + y, + batch_size=2, + epochs=1, + verbose=0, + ) + hist_keys = sorted(hist.history.keys()) + ref_keys = sorted( + [ + "a/y2_loss", + "a_loss", + "b/b1_loss", + "b_loss", + "c/c1_loss", + "c_loss", + "d_loss", + "loss", + ] + ) + self.assertListEqual(hist_keys, ref_keys) + + @parameterized.named_parameters( + ("optional_none", True), ("optional_tensor", False) + ) + def test_functional_optional_inputs(self, is_optional_none): + model = _get_model_optional_inputs() + x = np.ones((2, 2)) + o = None if is_optional_none else np.ones((2, 2)) + y_true = np.ones((2, 2)) + + model.compile(loss="mse", optimizer="adam") + model.fit(x={"x": x, "o": o}, y=y_true) + model.evaluate(x={"x": x, "o": o}, y=y_true) + model.predict(x={"x": x, "o": o}) + + @parameterized.named_parameters( + ("optional_none", True), ("optional_tensor", False) + ) + def test_functional_optional_inputs_generator(self, is_optional_none): + model = _get_model_optional_inputs() + x = np.ones((2, 2)) + o = None if is_optional_none else np.ones((2, 2)) + y_true = np.ones((2, 2)) + + def data_generator(with_y=True): + for _ in range(4): + yield ({"x": x, "o": o},) + ((y_true,) if with_y else ()) + + model.compile(loss="mse", optimizer="adam") + model.fit(data_generator()) + model.evaluate(data_generator()) + model.predict(data_generator(with_y=False)) + + def test_export_error(self): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + model = _get_model() + + # Bad format + with self.assertRaisesRegex(ValueError, "Unrecognized format="): + model.export(temp_filepath, format="bad_format") + + # Bad backend + if backend.backend() not in ("tensorflow", "jax", "torch"): + with self.assertRaisesRegex( + NotImplementedError, + ( + r"`export_saved_model` only currently supports the " + r"tensorflow, jax and torch backends." + ), + ): + model.export(temp_filepath, format="tf_saved_model") diff --git a/keras/src/models/sequential.py b/keras/src/models/sequential.py new file mode 100644 index 000000000000..7d7daf6f1d2b --- /dev/null +++ b/keras/src/models/sequential.py @@ -0,0 +1,383 @@ +import copy +import inspect +import typing + +from keras.src import backend +from keras.src import tree +from keras.src.api_export import keras_export +from keras.src.backend.common import global_state +from keras.src.backend.common import standardize_shape +from keras.src.layers.core.input_layer import InputLayer +from keras.src.layers.layer import Layer +from keras.src.legacy.saving import saving_utils +from keras.src.legacy.saving import serialization as legacy_serialization +from keras.src.models.functional import Functional +from keras.src.models.model import Model +from keras.src.saving import serialization_lib + + +@keras_export(["keras.Sequential", "keras.models.Sequential"]) +class Sequential(Model): + """`Sequential` groups a linear stack of layers into a `Model`. + + Examples: + + ```python + model = keras.Sequential() + model.add(keras.Input(shape=(16,))) + model.add(keras.layers.Dense(8)) + + # Note that you can also omit the initial `Input`. + # In that case the model doesn't have any weights until the first call + # to a training/evaluation method (since it isn't yet built): + model = keras.Sequential() + model.add(keras.layers.Dense(8)) + model.add(keras.layers.Dense(4)) + # model.weights not created yet + + # Whereas if you specify an `Input`, the model gets built + # continuously as you are adding layers: + model = keras.Sequential() + model.add(keras.Input(shape=(16,))) + model.add(keras.layers.Dense(8)) + len(model.weights) # Returns "2" + + # When using the delayed-build pattern (no input shape specified), you can + # choose to manually build your model by calling + # `build(batch_input_shape)`: + model = keras.Sequential() + model.add(keras.layers.Dense(8)) + model.add(keras.layers.Dense(4)) + model.build((None, 16)) + len(model.weights) # Returns "4" + + # Note that when using the delayed-build pattern (no input shape specified), + # the model gets built the first time you call `fit`, `eval`, or `predict`, + # or the first time you call the model on some input data. + model = keras.Sequential() + model.add(keras.layers.Dense(8)) + model.add(keras.layers.Dense(1)) + model.compile(optimizer='sgd', loss='mse') + # This builds the model for the first time: + model.fit(x, y, batch_size=32, epochs=10) + ``` + """ + + def __new__(cls, *args, **kwargs): + return typing.cast(cls, super().__new__(cls)) + + def __init__(self, layers=None, trainable=True, name=None): + super().__init__(trainable=trainable, name=name) + self._functional = None + self._layers = [] + if layers: + for layer in layers: + self.add(layer, rebuild=False) + self._maybe_rebuild() + + def add(self, layer, rebuild=True): + """Adds a layer instance on top of the layer stack. + + Args: + layer: layer instance. + """ + # Legacy case: if the first layer has an input_shape arg, + # use it to build an InputLayer. + if not self._layers: + if getattr(layer, "_input_shape_arg", None) is not None: + self.add(InputLayer(shape=layer._input_shape_arg)) + + # If we are passed a Keras tensor created by keras.Input(), we + # extract the input layer from its keras history and use that. + if hasattr(layer, "_keras_history"): + origin_layer = layer._keras_history[0] + if isinstance(origin_layer, InputLayer): + layer = origin_layer + if not isinstance(layer, Layer): + raise ValueError( + "Only instances of `keras.Layer` can be " + f"added to a Sequential model. Received: {layer} " + f"(of type {type(layer)})" + ) + if not self._is_layer_name_unique(layer): + raise ValueError( + "All layers added to a Sequential model " + f"should have unique names. Name '{layer.name}' is already " + "the name of a layer in this model. Update the `name` argument " + "to pass a unique name." + ) + if ( + isinstance(layer, InputLayer) + and self._layers + and isinstance(self._layers[0], InputLayer) + ): + raise ValueError( + f"Sequential model '{self.name}' has already been configured " + f"to use input shape {self._layers[0].batch_shape}. You cannot " + f"add a different Input layer to it." + ) + + self._layers.append(layer) + if rebuild: + self._maybe_rebuild() + else: + self.built = False + self._functional = None + + def pop(self, rebuild=True): + """Removes the last layer in the model. + + Args: + rebuild: `bool`. Whether to rebuild the model after removing + the layer. Defaults to `True`. + + Returns: + layer: layer instance. + """ + layer = self._layers.pop() + self.built = False + self._functional = None + if rebuild: + self._maybe_rebuild() + return layer + + def _maybe_rebuild(self): + self.built = False + self._functional = None + if isinstance(self._layers[0], InputLayer) and len(self._layers) > 1: + input_shape = self._layers[0].batch_shape + self.build(input_shape) + elif hasattr(self._layers[0], "input_shape") and len(self._layers) > 1: + # We can build the Sequential model if the first layer has the + # `input_shape` property. This is most commonly found in Functional + # model. + input_shape = self._layers[0].input_shape + self.build(input_shape) + + def _lock_state(self): + # Unlike other layers, Sequential is mutable after build. + pass + + def _obj_type(self): + return "Sequential" + + def build(self, input_shape=None): + try: + input_shape = standardize_shape(input_shape) + except: + # Do not attempt to build if the model does not have a single + # input tensor. + return + if not self._layers: + raise ValueError( + f"Sequential model {self.name} cannot be built because it has " + "no layers. Call `model.add(layer)`." + ) + if isinstance(self._layers[0], InputLayer): + if self._layers[0].batch_shape != input_shape: + raise ValueError( + f"Sequential model '{self.name}' has already been " + "configured to use input shape " + f"{self._layers[0].batch_shape}. You cannot build it " + f"with input_shape {input_shape}" + ) + else: + dtype = self._layers[0].compute_dtype + self._layers = [ + InputLayer(batch_shape=input_shape, dtype=dtype) + ] + self._layers + + # Build functional model + inputs = self._layers[0].output + x = inputs + for layer in self._layers[1:]: + try: + x = layer(x) + except NotImplementedError: + # Can happen if shape inference is not implemented. + # TODO: consider reverting inbound nodes on layers processed. + return + except TypeError as e: + signature = inspect.signature(layer.call) + positional_args = [ + param + for param in signature.parameters.values() + if param.default == inspect.Parameter.empty + ] + if len(positional_args) != 1: + raise ValueError( + "Layers added to a Sequential model " + "can only have a single positional argument, " + f"the input tensor. Layer {layer.__class__.__name__} " + f"has multiple positional arguments: {positional_args}" + ) + raise e + outputs = x + self._functional = Functional(inputs=inputs, outputs=outputs) + + def call(self, inputs, training=None, mask=None, **kwargs): + if self._functional: + return self._functional.call( + inputs, training=training, mask=mask, **kwargs + ) + + # Fallback: Just apply the layer sequence. + # This typically happens if `inputs` is a nested struct. + for layer in self.layers: + # During each iteration, `inputs` are the inputs to `layer`, and + # `outputs` are the outputs of `layer` applied to `inputs`. At the + # end of each iteration `inputs` is set to `outputs` to prepare for + # the next layer. + layer_kwargs = { + k: kwargs[k] + # only inject if this layer’s signature actually has that arg + for k in getattr(layer, "_call_has_context_arg", {}) + if k in kwargs + } + if layer._call_has_mask_arg: + layer_kwargs["mask"] = mask + if layer._call_has_training_arg and training is not None: + layer_kwargs["training"] = training + outputs = layer(inputs, **layer_kwargs) + inputs = outputs + + mask = tree.map_structure(backend.get_keras_mask, outputs) + return outputs + + @property + def layers(self): + # Historically, `sequential.layers` only returns layers that were added + # via `add`, and omits the auto-generated `InputLayer` that comes at the + # bottom of the stack. + layers = self._layers + if layers and isinstance(layers[0], InputLayer): + return layers[1:] + return layers[:] + + @layers.setter + def layers(self, _): + raise AttributeError( + "`Sequential.layers` attribute is reserved and should not be used. " + "Use `add()` and `pop()` to change the layers in this model." + ) + + def compute_output_spec(self, inputs, training=None, mask=None, **kwargs): + if self._functional: + return self._functional.compute_output_spec( + inputs, training=training, mask=mask, **kwargs + ) + # Direct application + for layer in self.layers: + outputs = layer.compute_output_spec( + inputs, + training=training, + **kwargs, + ) # Ignore mask + inputs = outputs + return outputs + + def compute_output_shape(self, input_shape): + if self._functional: + return self._functional.compute_output_shape(input_shape) + # Direct application + for layer in self.layers: + output_shape = layer.compute_output_shape(input_shape) + input_shape = output_shape + return output_shape + + @property + def input_shape(self): + if self._functional: + return self._functional.input_shape + raise AttributeError( + f"Sequential model '{self.name}' has no defined input shape yet." + ) + + @property + def output_shape(self): + if self._functional: + return self._functional.output_shape + raise AttributeError( + f"Sequential model '{self.name}' has no defined output shape yet." + ) + + @property + def inputs(self): + if self._functional: + return self._functional.inputs + raise AttributeError( + f"Sequential model '{self.name}' has no defined inputs yet." + ) + + @property + def outputs(self): + if self._functional: + return self._functional.outputs + raise AttributeError( + f"Sequential model '{self.name}' has no defined outputs yet." + ) + + @property + def input_dtype(self): + # Sequential.__call__ will try to convert its inputs + # to the dtype expected by its input layer, if any. + layers = self._layers + if layers and isinstance(layers[0], InputLayer): + return layers[0].dtype + return super().input_dtype + + def _is_layer_name_unique(self, layer): + for ref_layer in self._layers: + if layer.name == ref_layer.name and ref_layer is not layer: + return False + return True + + def get_config(self): + serialize_fn = serialization_lib.serialize_keras_object + if global_state.get_global_attribute("use_legacy_config", False): + # Legacy format serialization used for H5 and SavedModel formats + serialize_fn = legacy_serialization.serialize_keras_object + layer_configs = [] + for layer in super().layers: + # `super().layers` include the InputLayer if available (it is + # filtered out of `self.layers`). + layer_configs.append(serialize_fn(layer)) + config = Model.get_config(self) + config["name"] = self.name + config["layers"] = copy.deepcopy(layer_configs) + if self._functional is not None: + config["build_input_shape"] = self._layers[0].batch_shape + return config + + @classmethod + def from_config(cls, config, custom_objects=None): + if "name" in config: + name = config["name"] + build_input_shape = config.get("build_input_shape") + layer_configs = config["layers"] + else: + name = None + layer_configs = config + model = cls(name=name) + for layer_config in layer_configs: + if "module" not in layer_config: + # Legacy format deserialization (no "module" key) + # used for H5 and SavedModel formats + layer = saving_utils.model_from_config( + layer_config, + custom_objects=custom_objects, + ) + else: + layer = serialization_lib.deserialize_keras_object( + layer_config, + custom_objects=custom_objects, + ) + model.add(layer) + if ( + not model._functional + and "build_input_shape" in locals() + and build_input_shape + and isinstance(build_input_shape, (tuple, list)) + ): + model.build(build_input_shape) + return model diff --git a/keras/src/models/sequential_test.py b/keras/src/models/sequential_test.py new file mode 100644 index 000000000000..d18cf4edc455 --- /dev/null +++ b/keras/src/models/sequential_test.py @@ -0,0 +1,387 @@ +import pickle + +import numpy as np +import pytest + +from keras.src import backend +from keras.src import layers +from keras.src import saving +from keras.src import testing +from keras.src.layers.core.input_layer import Input +from keras.src.models.functional import Functional +from keras.src.models.model import Model +from keras.src.models.sequential import Sequential + + +@pytest.mark.requires_trainable_backend +class SequentialTest(testing.TestCase): + def test_basic_flow_with_input(self): + model = Sequential(name="seq") + model.add(Input(shape=(2,), batch_size=3)) + model.add(layers.Dense(4)) + model.add(layers.Dense(5)) + model.summary() + + self.assertEqual(len(model.layers), 2) + self.assertTrue(model.built) + self.assertEqual(len(model.weights), 4) + + # Test eager call + x = np.random.random((3, 2)) + y = model(x) + + self.assertEqual(type(model._functional), Functional) + self.assertEqual(y.shape, (3, 5)) + + # Test symbolic call + x = backend.KerasTensor((3, 2)) + y = model(x) + self.assertEqual(y.shape, (3, 5)) + + # Test `layers` constructor arg + model = Sequential( + layers=[ + Input(shape=(2,), batch_size=3), + layers.Dense(4), + layers.Dense(5), + ] + ) + self.assertEqual(len(model.layers), 2) + self.assertTrue(model.built) + self.assertEqual(len(model.weights), 4) + + x = np.random.random((3, 2)) + y = model(x) + self.assertEqual(y.shape, (3, 5)) + + # Test pop + model.pop() + self.assertEqual(len(model.layers), 1) + self.assertTrue(model.built) + self.assertEqual(len(model.weights), 2) + + x = np.random.random((3, 2)) + y = model(x) + self.assertEqual(y.shape, (3, 4)) + + def test_legacy_flow_with_input_shape(self): + model = Sequential(name="seq") + model.add(layers.Dense(4, input_shape=(2,))) + model.add(layers.Dense(5)) + + self.assertEqual(len(model.layers), 2) + self.assertTrue(model.built) + self.assertEqual(len(model.weights), 4) + self.assertEqual(type(model._functional), Functional) + + # Input_dim works too + model = Sequential(name="seq") + model.add(layers.Dense(4, input_dim=2)) + model.add(layers.Dense(5)) + + self.assertEqual(len(model.layers), 2) + self.assertTrue(model.built) + self.assertEqual(len(model.weights), 4) + self.assertEqual(type(model._functional), Functional) + + # Subsequent input_shapes are ignored + model = Sequential(name="seq") + model.add(layers.Dense(4, input_shape=(2,))) + model.add(layers.Dense(5, input_shape=(3, 4))) + + self.assertEqual(len(model.layers), 2) + self.assertTrue(model.built) + self.assertEqual(len(model.weights), 4) + self.assertEqual(type(model._functional), Functional) + + def test_basic_flow_deferred(self): + model = Sequential(name="seq") + model.add(layers.Dense(4)) + model.add(layers.Dense(5)) + model.summary() + + self.assertEqual(len(model.layers), 2) + + # Test eager call + x = np.random.random((3, 2)) + y = model(x) + self.assertTrue(model.built) + model.summary() + + self.assertEqual(type(model._functional), Functional) + self.assertEqual(y.shape, (3, 5)) + + # Test symbolic call + x = backend.KerasTensor((3, 2)) + y = model(x) + self.assertEqual(y.shape, (3, 5)) + + # Test `layers` constructor arg + model = Sequential( + layers=[ + layers.Dense(4), + layers.Dense(5), + ] + ) + x = np.random.random((3, 2)) + y = model(x) + self.assertEqual(y.shape, (3, 5)) + + # Test pop + model.pop() + self.assertEqual(len(model.layers), 1) + self.assertTrue(model.built) + self.assertEqual(len(model.weights), 2) + + x = np.random.random((3, 2)) + y = model(x) + self.assertEqual(y.shape, (3, 4)) + + def test_basic_flow_as_a_submodel(self): + # Build submodel + submodel = Sequential() + submodel.add(layers.Flatten()) + self.assertFalse(submodel.built) + + inputs = Input((None, 4)) + outputs = layers.TimeDistributed(submodel)(inputs) + model = Model(inputs=inputs, outputs=outputs) + + x = np.random.random((2, 3, 4)) + y = model(x) + self.assertEqual(y.shape, (2, 3, 4)) + + def test_basic_flow_with_functional_model_as_first_layer(self): + # Build functional model + inputs = Input((16, 16, 3)) + outputs = layers.Conv2D(4, 3, padding="same")(inputs) + functional_model = Model(inputs=inputs, outputs=outputs) + + model = Sequential( + [functional_model, layers.Flatten(), layers.Dense(1)] + ) + model.summary() + self.assertEqual(len(model.layers), 3) + self.assertTrue(model.built) + for layer in model.layers: + self.assertTrue(layer.built) + + # Test eager call + x = np.random.random((1, 16, 16, 3)) + y = model(x) + self.assertEqual(type(model._functional), Functional) + self.assertEqual(tuple(y.shape), (1, 1)) + + # Test symbolic call + x = backend.KerasTensor((1, 16, 16, 3)) + y = model(x) + self.assertEqual(y.shape, (1, 1)) + + def test_basic_flow_with_sequential_model_as_first_layer(self): + # Build sequential model + sequential_model = Sequential( + [Input((16, 16, 3)), layers.Conv2D(4, 3, padding="same")] + ) + + model = Sequential( + [sequential_model, layers.Flatten(), layers.Dense(1)] + ) + model.summary() + self.assertEqual(len(model.layers), 3) + self.assertTrue(model.built) + for layer in model.layers: + self.assertTrue(layer.built) + + # Test eager call + x = np.random.random((1, 16, 16, 3)) + y = model(x) + self.assertEqual(type(model._functional), Functional) + self.assertEqual(tuple(y.shape), (1, 1)) + + # Test symbolic call + x = backend.KerasTensor((1, 16, 16, 3)) + y = model(x) + self.assertEqual(y.shape, (1, 1)) + + def test_dict_inputs(self): + class DictLayer(layers.Layer): + def call(self, inputs): + assert isinstance(inputs, dict) + return inputs + + model = Sequential([DictLayer()]) + x = {"a": np.random.random((3, 2)), "b": np.random.random((3, 2))} + y = model(x) + self.assertEqual(type(y), dict) + model.summary() + + def test_list_inputs(self): + class ListLayer(layers.Layer): + def call(self, inputs): + assert isinstance(inputs, list) + return inputs + + model = Sequential([ListLayer()]) + x = [np.random.random((3, 2)), np.random.random((3, 2))] + y = model(x) + self.assertEqual(type(y), list) + model.summary() + + def test_nested_sequential(self): + # https://github.com/keras-team/keras/issues/20203 + model = Sequential() + model.add(Input(shape=(16,))) + Sequential([model]) + + def test_errors(self): + # Trying to pass 2 Inputs + model = Sequential() + model.add(Input(shape=(2,), batch_size=3)) + with self.assertRaisesRegex(ValueError, "already been configured"): + model.add(Input(shape=(2,), batch_size=3)) + with self.assertRaisesRegex(ValueError, "already been configured"): + model.add(layers.InputLayer(shape=(2,), batch_size=3)) + + # Same name 2x + model = Sequential() + model.add(layers.Dense(2, name="dense")) + with self.assertRaisesRegex(ValueError, "should have unique names"): + model.add(layers.Dense(2, name="dense")) + + # No layers + model = Sequential() + x = np.random.random((3, 2)) + with self.assertRaisesRegex(ValueError, "no layers"): + model(x) + + # Build conflict + model = Sequential() + model.add(Input(shape=(2,), batch_size=3)) + model.add(layers.Dense(2)) + with self.assertRaisesRegex(ValueError, "already been configured"): + model.build((3, 4)) + # But this works + model.build((3, 2)) + + def test_shape_inference_failure(self): + class DynamicLayer(layers.Layer): + def call(self, inputs): + return inputs + 1.0 + + def compute_output_spec(self, *args, **kwargs): + raise NotImplementedError + + model = Sequential([DynamicLayer()]) + x = np.random.random((3, 2)) + y = model(x) + self.assertAllClose(y, x + 1) + model.summary() + + def test_serialization(self): + # Unbuilt deferred + model = Sequential(name="seq") + model.add(layers.Dense(4)) + model.add(layers.Dense(5)) + revived = self.run_class_serialization_test(model) + self.assertLen(revived.layers, 2) + + # Built deferred + model.build((2, 3)) + revived = self.run_class_serialization_test(model) + self.assertLen(revived.layers, 2) + + # Regular + model = Sequential(name="seq") + model.add(Input(shape=(2,), batch_size=3)) + model.add(layers.Dense(4)) + model.add(layers.Dense(5)) + model.add(layers.Dense(6)) + revived = self.run_class_serialization_test(model) + self.assertLen(revived.layers, 3) + + # Weird + class DictLayer(layers.Layer): + def call(self, inputs): + assert isinstance(inputs, dict) + return inputs + + model = Sequential([DictLayer()]) + revived = self.run_class_serialization_test( + model, custom_objects={"DictLayer": DictLayer} + ) + self.assertLen(revived.layers, 1) + + def test_serialization_with_lambda_layer(self): + # https://github.com/keras-team/keras/issues/20074 + inputs = np.random.random(size=(1, 10, 4)).astype("float32") + CONV_WIDTH = 3 + model = Sequential([layers.Lambda(lambda x: x[:, -CONV_WIDTH:, :])]) + outputs = model(inputs) + + temp = self.get_temp_dir() + save_path = f"{temp}/model.keras" + model.save(save_path) + revived = saving.load_model(save_path, safe_mode=False) + revived_outputs = revived(inputs) + self.assertLen(revived.layers, 1) + self.assertAllClose(revived_outputs, outputs) + + def test_functional_properties(self): + model = Sequential(name="seq") + inputs = Input(shape=(2,)) + model.add(inputs) + model.add(layers.Dense(4)) + + self.assertEqual(model.inputs, [inputs]) + self.assertEqual(model.outputs, [model.layers[-1].output]) + self.assertEqual(model.input_shape, (None, 2)) + self.assertEqual(model.output_shape, (None, 4)) + + def test_pickleable(self): + model = Sequential(name="seq") + model.add(layers.Dense(4)) + + result = pickle.loads(pickle.dumps(model)) + assert len(result.layers) == 1 + + def test_bad_layer(self): + model = Sequential(name="seq") + with self.assertRaisesRegex(ValueError, "Only instances of"): + model.add({}) + + model = Sequential(name="seq") + + class BadLayer(layers.Layer): + def call(self, inputs, training): + return inputs + + model.add(BadLayer()) + with self.assertRaisesRegex( + ValueError, "can only have a single positional" + ): + model.build((None, 2)) + + def test_compute_output_shape(self): + layer = Sequential([layers.Dense(4), layers.Dense(8)]) + output_shape = layer.compute_output_shape((1, 2)) + self.assertEqual(output_shape, (1, 8)) + + def test_hasattr(self): + model = Sequential() + self.assertFalse(hasattr(model, "input_shape")) + self.assertFalse(hasattr(model, "output_shape")) + self.assertFalse(hasattr(model, "inputs")) + self.assertFalse(hasattr(model, "outputs")) + + model = Sequential([layers.Input((4,)), layers.Dense(8)]) + self.assertTrue(hasattr(model, "input_shape")) + self.assertTrue(hasattr(model, "output_shape")) + self.assertTrue(hasattr(model, "inputs")) + self.assertTrue(hasattr(model, "outputs")) + + def test_layers_setter(self): + model = Sequential() + with self.assertRaisesRegex( + AttributeError, r"Use `add\(\)` and `pop\(\)`" + ): + model.layers = [layers.Dense(4)] diff --git a/keras/src/models/variable_mapping.py b/keras/src/models/variable_mapping.py new file mode 100644 index 000000000000..e06ea5b09395 --- /dev/null +++ b/keras/src/models/variable_mapping.py @@ -0,0 +1,61 @@ +from keras.src.layers.layer import Layer +from keras.src.metrics.metric import Metric +from keras.src.optimizers.optimizer import Optimizer +from keras.src.saving import saving_lib +from keras.src.saving.keras_saveable import KerasSaveable + + +def map_saveable_variables(saveable, store, visited_saveables): + # If the saveable has already been seen, skip it. + if id(saveable) in visited_saveables: + return + + visited_saveables.add(id(saveable)) + + variables = [] + if isinstance(saveable, Layer): + variables = ( + saveable._trainable_variables + saveable._non_trainable_variables + ) + elif isinstance(saveable, Optimizer): + variables = saveable._variables + elif isinstance(saveable, Metric): + variables = saveable._variables + for v in variables: + if v.path in store: + raise ValueError( + "The model contains two variables with a duplicate path: " + f"path='{v.path}' appears at least twice. " + f"This path is used for {v} and for {store[v.path]}. " + "In order to get a variable map, make sure to use " + "unique paths/names for each variable." + ) + store[v.path] = v + + # Recursively save state of children saveables (layers, optimizers, etc.) + for child_attr, child_obj in saving_lib._walk_saveable(saveable): + if isinstance(child_obj, KerasSaveable): + map_saveable_variables( + child_obj, + store, + visited_saveables=visited_saveables, + ) + elif isinstance(child_obj, (list, dict, tuple, set)): + map_container_variables( + child_obj, + store, + visited_saveables=visited_saveables, + ) + + +def map_container_variables(container, store, visited_saveables): + if isinstance(container, dict): + container = list(container.values()) + + for saveable in container: + if isinstance(saveable, KerasSaveable): + map_saveable_variables( + saveable, + store, + visited_saveables=visited_saveables, + ) diff --git a/keras/src/models/variable_mapping_test.py b/keras/src/models/variable_mapping_test.py new file mode 100644 index 000000000000..652e578289ce --- /dev/null +++ b/keras/src/models/variable_mapping_test.py @@ -0,0 +1,33 @@ +import numpy as np + +from keras.src import testing +from keras.src.saving import saving_lib_test + + +class VariableMappingTest(testing.TestCase): + def test_basics(self): + model = saving_lib_test._get_basic_functional_model() + model.optimizer.build(model.trainable_variables) + variable_map = model._get_variable_map() + + self.assertIn("first_dense/kernel", variable_map) + self.assertIn("second_dense/bias", variable_map) + self.assertIn("adam/learning_rate", variable_map) + + model = saving_lib_test._get_basic_sequential_model() + model.build((None, 1)) + model.optimizer.build(model.trainable_variables) + variable_map = model._get_variable_map() + self.assertIn("sequential/dense_1/bias", variable_map) + self.assertIn("adam/learning_rate", variable_map) + + model = saving_lib_test._get_subclassed_model() + model(np.ones((1, 1))) + model.optimizer.build(model.trainable_variables) + variable_map = model._get_variable_map() + self.assertIn("custom_model_x/my_dense_1/dense/kernel", variable_map) + self.assertIn("custom_model_x/my_dense_1/my_dict_weight", variable_map) + self.assertIn( + "custom_model_x/my_dense_1/my_additional_weight", variable_map + ) + self.assertIn("adam/learning_rate", variable_map) diff --git a/keras/src/ops/__init__.py b/keras/src/ops/__init__.py new file mode 100644 index 000000000000..754923c9fea5 --- /dev/null +++ b/keras/src/ops/__init__.py @@ -0,0 +1,16 @@ +# from keras.src.ops.numpy import Matmul, matmul +# from keras.src.ops.numpy import Add, add +# from keras.src.ops.numpy import Multiply, multiply + +from keras.src.backend import cast +from keras.src.backend import cond +from keras.src.backend import is_tensor +from keras.src.backend import name_scope +from keras.src.backend import random +from keras.src.ops import image +from keras.src.ops import operation_utils +from keras.src.ops.core import * # noqa: F403 +from keras.src.ops.linalg import * # noqa: F403 +from keras.src.ops.math import * # noqa: F403 +from keras.src.ops.nn import * # noqa: F403 +from keras.src.ops.numpy import * # noqa: F403 diff --git a/keras/src/ops/core.py b/keras/src/ops/core.py new file mode 100644 index 000000000000..03cbcccd296e --- /dev/null +++ b/keras/src/ops/core.py @@ -0,0 +1,1255 @@ +import ml_dtypes +import numpy as np + +from keras.src import backend +from keras.src import tree +from keras.src.api_export import keras_export +from keras.src.backend import KerasTensor +from keras.src.backend import any_symbolic_tensors +from keras.src.backend.common.backend_utils import slice_along_axis +from keras.src.ops.operation import Operation +from keras.src.saving import serialization_lib +from keras.src.utils import traceback_utils + + +class Map(Operation): + def call(self, f, xs): + return backend.core.map(f, xs) + + def compute_output_spec(self, f, xs): + x = tree.map_structure(lambda t: t[0], xs) + n = tree.flatten(xs)[0].shape[0] + y = backend.compute_output_spec(f, x) + + def append_batch_axis(t): + return KerasTensor( + shape=(n,) + t.shape, + dtype=t.dtype, + sparse=t.sparse, + ragged=t.ragged, + ) + + y = tree.map_structure(append_batch_axis, y) + return y + + +@keras_export("keras.ops.map") +def map(f, xs): + """Map a function over leading array axes. + + Like Python’s builtin map, except inputs and outputs are in the form of + stacked arrays. Consider using the `vectorized_map()` transform instead, + unless you need to apply a function element by element for reduced memory + usage or heterogeneous computation with other control flow primitives. + + When `xs` is an array type, the semantics of `map()` are given by this + Python implementation: + + ```python + def map(f, xs): + return np.stack([f(x) for x in xs]) + ``` + + Args: + f: Callable defines the function to apply element-wise over the first + axis or axes of `xs`. + xs: Values over which to map along the leading axis. + + Returns: + Mapped values. + + Examples: + + >>> f = lambda x: x**2 + >>> xs = keras.ops.arange(10) + >>> ys = keras.ops.map(f, xs) + >>> ys + [0, 1, 4, 9, 16, 25, 36, 49, 64, 81] + + >>> f = lambda x: {"y1": x**2, "y2": x * 10} # Can have nested outputs + >>> ys = keras.ops.map(f, xs) + >>> ys["y1"] + [0, 1, 4, 9, 16, 25, 36, 49, 64, 81] + >>> ys["y2"] + [0, 10, 20, 30, 40, 50, 60, 70, 80, 90] + """ + if any_symbolic_tensors((xs,)): + return Map().symbolic_call(f, xs) + return backend.core.map(f, xs) + + +class Scan(Operation): + def __init__(self, length=None, reverse=False, unroll=1, *, name=None): + super().__init__(name=name) + self.length = length + self.reverse = reverse + self.unroll = unroll + + def call(self, f, init, xs=None): + return backend.core.scan( + f, + init, + xs, + length=self.length, + reverse=self.reverse, + unroll=self.unroll, + ) + + def compute_output_spec(self, f, init, xs=None): + if xs is None: + n = int(self.length) + x = None + else: + n = ( + int(self.length) + if self.length is not None + else tree.flatten(xs)[0].shape[0] + ) + x = xs[0] + + carry, y = backend.compute_output_spec(f, init, x) + y = KerasTensor(shape=(n,) + y.shape, dtype=y.dtype, sparse=y.sparse) + return carry, y + + +@keras_export("keras.ops.scan") +def scan(f, init, xs=None, length=None, reverse=False, unroll=1): + """Scan a function over leading array axes while carrying along state. + + When the type of `xs` is an array type or `None`, and the type of `ys` is an + array type, the semantics of `scan()` are given roughly by this Python + implementation: + + ```python + def scan(f, init, xs, length=None): + if xs is None: + xs = [None] * length + carry = init + ys = [] + for x in xs: + carry, y = f(carry, x) + ys.append(y) + return carry, np.stack(ys) + ``` + + The loop-carried value `carry` (`init`) must hold a fixed shape and dtype + across all iterations. + + In TensorFlow, `y` must match `carry` in shape and dtype. This is not + required in other backends. + + Args: + f: Callable defines the logic for each loop iteration. This accepts two + arguments where the first is a value of the loop carry and the + second is a slice of `xs` along its leading axis. + This callable returns a pair where the first represents a new value + for the loop carry and the second represents a slice of the output. + init: The initial loop carry value. This can be a scalar, tensor, or any + nested structure. It must match the structure of the first element + returned by `f`. + xs: Optional value to scan along its leading axis. This can be a tensor + or any nested structure. If `xs` is not provided, you must specify + `length` to define the number of loop iterations. + Defaults to `None`. + length: Optional integer specifying the number of loop iterations. + If `length` is not provided, it defaults to the sizes of leading + axis of the arrays in `xs`. Defaults to `None`. + reverse: Optional boolean specifying whether to run the scan iteration + forward or in reverse, equivalent to reversing the leading axes of + the arrays in both `xs` and in `ys`. + unroll: Optional positive integer or boolean specifying how many scan + iterations to unroll within a single iteration of a loop. If an + integer is provided, it determines how many unrolled loop iterations + to run within a single rolled iteration of the loop. If a boolean is + provided, it will determine if the loop is completely unrolled + (`unroll=True`) or left completely unrolled (`unroll=False`). + Note that unrolling is only supported by JAX and TensorFlow + backends. + + Returns: + A pair where the first element represents the final loop carry value and + the second element represents the stacked outputs of `f` when scanned + over the leading axis of the inputs. + + Examples: + + >>> sum_fn = lambda c, x: (c + x, c + x) + >>> init = keras.ops.array(0) + >>> xs = keras.ops.array([1, 2, 3, 4, 5]) + >>> carry, result = keras.ops.scan(sum_fn, init, xs) + >>> carry + 15 + >>> result + [1, 3, 6, 10, 15] + """ + if any_symbolic_tensors((init, xs)): + return Scan( + length=length, reverse=reverse, unroll=unroll + ).symbolic_call(f, init, xs) + return backend.core.scan( + f, init, xs, length, reverse=reverse, unroll=unroll + ) + + +class AssociativeScan(Operation): + def __init__(self, reverse=False, axis=0, *, name=None): + super().__init__(name=name) + self.reverse = reverse + self.axis = axis + + def call(self, f, elems): + return backend.core.associative_scan( + f, elems, reverse=self.reverse, axis=self.axis + ) + + def compute_output_spec(self, f, elems): + elems_flat = tree.flatten(elems) + lens = [elem.shape[self.axis] for elem in elems_flat] + if len(set(lens)) != 1: + raise ValueError( + "Array inputs to associative_scan must have the same " + "first dimension. (saw: {})".format( + [elem.shape for elem in elems_flat] + ) + ) + + x = tree.pack_sequence_as( + elems, + [slice_along_axis(x, 0, 1, axis=self.axis) for x in elems_flat], + ) + y_spec = backend.compute_output_spec(f, x, x) + + def _restore_shape(x): + return KerasTensor( + shape=elems_flat[0].shape, dtype=x.dtype, sparse=x.sparse + ) + + y_spec = tree.map_structure(_restore_shape, y_spec) + return y_spec + + +@keras_export("keras.ops.associative_scan") +def associative_scan(f, elems, reverse=False, axis=0): + """Performs a scan with an associative binary operation, in parallel. + + This operation his similar to `scan`, with the key difference that + `associative_scan` is a parallel implementation with + potentially significant performance benefits, especially when jit compiled. + The catch is that it can only be used when `f` is a binary associative + operation (i.e. it must verify `f(a, f(b, c)) == f(f(a, b), c)`). + + For an introduction to associative scans, refer to this paper: + Blelloch, Guy E. 1990. + [Prefix Sums and Their Applications]( + https://www.cs.cmu.edu/~guyb/papers/Ble93.pdf). + + Args: + f: A Python callable implementing an associative binary operation with + signature `r = f(a, b)`. Function `f` must be associative, i.e., + it must satisfy the equation + `f(a, f(b, c)) == f(f(a, b), c)`. + The inputs and result are (possibly nested Python tree structures + of) array(s) matching `elems`. Each array has a dimension in place + of the `axis` dimension. `f` should be applied elementwise over + the `axis` dimension. + The result `r` has the same shape (and structure) as the + two inputs `a` and `b`. + elems: A (possibly nested Python tree structure of) array(s), each with + an `axis` dimension of size `num_elems`. + reverse: A boolean stating if the scan should be reversed with respect + to the `axis` dimension. + axis: an integer identifying the axis over which the scan should occur. + + Returns: + A (possibly nested Python tree structure of) array(s) of the same shape + and structure as `elems`, in which the `k`'th element of `axis` is + the result of recursively applying `f` to combine the first `k` + elements of `elems` along `axis`. For example, given + `elems = [a, b, c, ...]`, the result would be + `[a, f(a, b), f(f(a, b), c), ...]`. + + Examples: + + >>> sum_fn = lambda x, y: x + y + >>> xs = keras.ops.arange(5) + >>> ys = keras.ops.associative_scan(sum_fn, xs, axis=0) + >>> ys + [0, 1, 3, 6, 10] + + >>> sum_fn = lambda x, y: [x[0] + y[0], x[1] + y[1], x[2] + y[2]] + >>> xs = [keras.ops.array([[1, 2]]) for _ in range(3)] + >>> ys = keras.ops.associative_scan(sum_fn, xs, axis=0) + >>> ys + [[1, 3], [1, 3], [1, 3]] + """ + if any_symbolic_tensors((elems,)): + return AssociativeScan(reverse=reverse, axis=axis).symbolic_call( + f, elems + ) + return backend.core.associative_scan(f, elems, reverse=reverse, axis=axis) + + +class Scatter(Operation): + def __init__(self, shape, *, name=None): + super().__init__(name=name) + self.shape = shape + + def call(self, indices, values): + return backend.core.scatter(indices, values, self.shape) + + def compute_output_spec(self, indices, values): + return KerasTensor(self.shape, dtype=values.dtype) + + +@keras_export("keras.ops.scatter") +def scatter(indices, values, shape): + """Returns a tensor of shape `shape` where `indices` are set to `values`. + + At a high level, this operation does `zeros[indices] = updates` and + returns the output. It is equivalent to: + + ```python + zeros = keras.ops.zeros(shape) + output = keras.ops.scatter_update(zeros, indices, values) + ``` + + Args: + indices: A tensor or list/tuple specifying + indices for the values in `values`. + values: A tensor, the values to be set at `indices`. + shape: Shape of the output tensor. + + Example: + + >>> indices = [[0, 1], [1, 1]] + >>> values = np.array([1., 1.]) + >>> keras.ops.scatter(indices, values, shape=(2, 2)) + array([[0., 1.], + [0., 1.]]) + """ + if any_symbolic_tensors((indices, values)): + return Scatter(shape=shape).symbolic_call(indices, values) + return backend.core.scatter(indices, values, shape) + + +class ScatterUpdate(Operation): + def call(self, inputs, indices, updates): + return backend.core.scatter_update(inputs, indices, updates) + + def compute_output_spec(self, inputs, indices, updates): + return KerasTensor(inputs.shape, dtype=inputs.dtype) + + +@keras_export("keras.ops.scatter_update") +def scatter_update(inputs, indices, updates): + """Update inputs via updates at scattered (sparse) indices. + + At a high level, this operation does `inputs[indices] = updates`. + Assume `inputs` is a tensor of shape `(D0, D1, ..., Dn)`, there are 2 main + usages of `scatter_update`. + + 1. `indices` is a 2D tensor of shape `(num_updates, n)`, where `num_updates` + is the number of updates to perform, and `updates` is a 1D tensor of + shape `(num_updates,)`. For example, if `inputs` is `zeros((4, 4, 4))`, + and we want to update `inputs[1, 2, 3]` and `inputs[0, 1, 3]` as 1, then + we can use: + + ```python + inputs = np.zeros((4, 4, 4)) + indices = [[1, 2, 3], [0, 1, 3]] + updates = np.array([1., 1.]) + inputs = keras.ops.scatter_update(inputs, indices, updates) + ``` + + 2 `indices` is a 2D tensor of shape `(num_updates, k)`, where `num_updates` + is the number of updates to perform, and `k` (`k < n`) is the size of + each index in `indices`. `updates` is a `n - k`-D tensor of shape + `(num_updates, inputs.shape[k:])`. For example, if + `inputs = np.zeros((4, 4, 4))`, and we want to update `inputs[1, 2, :]` + and `inputs[2, 3, :]` as `[1, 1, 1, 1]`, then `indices` would have shape + `(num_updates, 2)` (`k = 2`), and `updates` would have shape + `(num_updates, 4)` (`inputs.shape[2:] = 4`). See the code below: + + ```python + inputs = np.zeros((4, 4, 4)) + indices = [[1, 2], [2, 3]] + updates = np.array([[1., 1., 1, 1,], [1., 1., 1, 1,]) + inputs = keras.ops.scatter_update(inputs, indices, updates) + ``` + + Args: + inputs: A tensor, the tensor to be updated. + indices: A tensor or list/tuple of shape `(N, inputs.ndim)`, specifying + indices to update. `N` is the number of indices to update, must be + equal to the first dimension of `updates`. + updates: A tensor, the new values to be put to `inputs` at `indices`. + + Returns: + A tensor, has the same shape and dtype as `inputs`. + """ + if any_symbolic_tensors((inputs, indices, updates)): + return ScatterUpdate().symbolic_call(inputs, indices, updates) + return backend.core.scatter_update(inputs, indices, updates) + + +class Slice(Operation): + def __init__(self, shape, *, name=None): + super().__init__(name=name) + self.shape = shape + + def call(self, inputs, start_indices): + return backend.core.slice(inputs, start_indices, self.shape) + + def compute_output_spec(self, inputs, start_indices): + if any(s == -1 for s in self.shape) and isinstance( + start_indices, KerasTensor + ): + raise ValueError( + "When using -1 in `shape`, `start_indices` should not be a " + "KerasTensor. " + ) + # If self.shape[i] is -1, all remaining elements in dimension i are + # included in the slice. + final_shape = tuple( + inputs.shape[i] - start_indices[i] if s == -1 else s + for i, s in enumerate(self.shape) + ) + return KerasTensor(final_shape, dtype=inputs.dtype) + + +@keras_export("keras.ops.slice") +def slice(inputs, start_indices, shape): + """Return a slice of an input tensor. + + At a high level, this operation is an explicit replacement for array slicing + e.g. `inputs[start_indices: start_indices + shape]`. + Unlike slicing via brackets, this operation will accept tensor start + indices on all backends, which is useful when indices dynamically computed + via other tensor operations. + + ```python + inputs = np.zeros((5, 5)) + start_indices = np.array([3, 3]) + shape = np.array([2, 2]) + inputs = keras.ops.slice(inputs, start_indices, shape) + ``` + + Args: + inputs: A tensor, the tensor to be updated. + start_indices: A list/tuple of shape `(inputs.ndim,)`, specifying + the starting indices for updating. + shape: The full shape of the returned slice. + + Returns: + A tensor, has the same shape and dtype as `inputs`. + """ + if any_symbolic_tensors((inputs, start_indices)): + return Slice(shape=shape).symbolic_call(inputs, start_indices) + return backend.core.slice(inputs, start_indices, shape) + + +class SliceUpdate(Operation): + def call(self, inputs, start_indices, updates): + return backend.core.slice_update(inputs, start_indices, updates) + + def compute_output_spec(self, inputs, start_indices, updates): + return KerasTensor(inputs.shape, dtype=inputs.dtype) + + +@keras_export("keras.ops.slice_update") +def slice_update(inputs, start_indices, updates): + """Update an input by slicing in a tensor of updated values. + + At a high level, this operation does + `inputs[start_indices: start_indices + updates.shape] = updates`. + Assume inputs is a tensor of shape `(D0, D1, ..., Dn)`, + `start_indices` must be a list/tuple of n integers, specifying the starting + indices. `updates` must have the same rank as `inputs`, and the size of each + dim must not exceed `Di - start_indices[i]`. For example, if we have 2D + inputs `inputs = np.zeros((5, 5))`, and we want to update the intersection + of last 2 rows and last 2 columns as 1, i.e., + `inputs[3:, 3:] = np.ones((2, 2))`, then we can use the code below: + + ```python + inputs = np.zeros((5, 5)) + start_indices = [3, 3] + updates = np.ones((2, 2)) + inputs = keras.ops.slice_update(inputs, start_indices, updates) + ``` + + Args: + inputs: A tensor, the tensor to be updated. + start_indices: A list/tuple of shape `(inputs.ndim,)`, specifying + the starting indices for updating. + updates: A tensor, the new values to be put to `inputs` at `indices`. + `updates` must have the same rank as `inputs`. + + Returns: + A tensor, has the same shape and dtype as `inputs`. + """ + if any_symbolic_tensors((inputs, start_indices, updates)): + return SliceUpdate().symbolic_call(inputs, start_indices, updates) + return backend.core.slice_update(inputs, start_indices, updates) + + +class Switch(Operation): + def call(self, index, branches, *operands): + return backend.core.switch(index, branches, *operands) + + def compute_output_spec(self, index, branches, *operands): + # We use first branch for output_spec + spec = backend.compute_output_spec(branches[0], *operands) + return spec + + +@keras_export("keras.ops.switch") +def switch(index, branches, *operands): + """Apply exactly one of the `branches` given by `index`. + + If `index` is out of bounds, it is clamped to within bounds. + + The semantics of `switch` are given roughly by this Python implementation: + + ```python + def switch(index, branches, *operands): + index = clamp(0, index, len(branches) - 1) + return branches[index](*operands) + ``` + + Args: + index: An integer scalar indicating which branch function to apply. + branches: A sequence of functions to be applied based on `index`. + operands: Inputs to whichever branch is applied. + + Returns: + The outputs of `branch(*operands)` for the branch that was selected + based on `index`. + + Examples: + + >>> add_fn = lambda x, y: x + y + >>> subtract_fn = lambda x, y: x - y + >>> x = keras.ops.array(2.0) + >>> y = keras.ops.array(0.5) + >>> branches = [add_fn, subtract_fn] + >>> keras.ops.switch(0, branches, x, y) + 2.5 + + >>> keras.ops.switch(1, branches, x, y) + 1.5 + """ + if any_symbolic_tensors(operands): + return Switch().symbolic_call(index, branches, *operands) + return backend.core.switch(index, branches, *operands) + + +class WhileLoop(Operation): + def __init__(self, cond, body, maximum_iterations=None, *, name=None): + super().__init__(name=name) + self.cond = cond + self.body = body + self.maximum_iterations = maximum_iterations + + def call(self, loop_vars): + return backend.core.while_loop( + self.cond, + self.body, + loop_vars, + maximum_iterations=self.maximum_iterations, + ) + + def compute_output_spec(self, loop_vars): + return tree.map_structure( + lambda v: KerasTensor(v.shape, dtype=v.dtype), loop_vars + ) + + +@keras_export("keras.ops.while_loop") +def while_loop( + cond, + body, + loop_vars, + maximum_iterations=None, +): + """While loop implementation. + + Args: + cond: A callable that represents the termination condition of the loop. + Must accept a `loop_vars` like structure as an argument. If + `loop_vars` is a tuple or list, each element of `loop_vars` will be + passed positionally to the callable. + body: A callable that represents the loop body. Must accept a + `loop_vars` like structure as an argument, and return update value + with the same structure. If `loop_vars` is a tuple or list, each + element of `loop_vars` will be passed positionally to the callable. + loop_vars: An arbitrary nested structure of tensor state to persist + across loop iterations. + maximum_iterations: Optional maximum number of iterations of the while + loop to run. If provided, the `cond` output is AND-ed with an + additional condition ensuring the number of iterations executed is + no greater than `maximum_iterations`. + + Returns: + A list/tuple of tensors, has the same shape and dtype as `inputs`. + + Examples: + + >>> i = 0 + >>> cond = lambda i: i < 10 + >>> body = lambda i: i + 1 + >>> keras.ops.while_loop(cond, body, i) + 10 + + >>> x, y = 0, 1 + >>> cond = lambda x, y: x < 10 + >>> body = lambda x, y: (x + 1, y + 1) + >>> keras.ops.while_loop(cond, body, (x, y)) + 10, 11 + """ + if any_symbolic_tensors((loop_vars,)): + return WhileLoop( + cond, body, maximum_iterations=maximum_iterations + ).symbolic_call(loop_vars) + return backend.core.while_loop( + cond, + body, + loop_vars, + maximum_iterations=maximum_iterations, + ) + + +class StopGradient(Operation): + def call(self, variable): + return backend.core.stop_gradient(variable) + + def compute_output_spec(self, variable): + return KerasTensor(variable.shape, dtype=variable.dtype) + + +@keras_export("keras.ops.stop_gradient") +def stop_gradient(variable): + """Stops gradient computation. + + Args: + variable: A tensor variable for which the gradient + computation is to be disabled. + + Returns: + The variable with gradient computation disabled. + + Examples: + + >>> var = keras.backend.convert_to_tensor( + ... [1., 2., 3.], + ... dtype="float32" + ... ) + >>> var = keras.ops.stop_gradient(var) + """ + if any_symbolic_tensors((variable,)): + return StopGradient().symbolic_call(variable) + return backend.core.stop_gradient(variable) + + +class ForiLoop(Operation): + def __init__(self, lower, upper, body_fun, *, name=None): + super().__init__(name=name) + self.lower = lower + self.upper = upper + self.body_fun = body_fun + + def call(self, init_val): + return backend.core.fori_loop( + self.lower, + self.upper, + self.body_fun, + init_val, + ) + + def compute_output_spec(self, init_val): + return KerasTensor(init_val.shape, dtype=init_val.dtype) + + +@keras_export("keras.ops.fori_loop") +def fori_loop(lower, upper, body_fun, init_val): + """For loop implementation. + + Args: + lower: The initial value of the loop variable. + upper: The upper bound of the loop variable. + body_fun: A callable that represents the loop body. Must take two + arguments: the loop variable and the loop state. The loop state + should be updated and returned by this function. + init_val: The initial value of the loop state. + + Returns: + The final state after the loop. + + Example: + + >>> lower = 0 + >>> upper = 10 + >>> body_fun = lambda i, s: (i + 1, s + i) + >>> init_val = 0 + >>> keras.ops.fori_loop(lower, upper, body_fun, init_val) + 45 + """ + if any_symbolic_tensors((lower, upper, init_val)): + return ForiLoop(lower, upper, body_fun).symbolic_call(init_val) + return backend.core.fori_loop(lower, upper, body_fun, init_val) + + +class Unstack(Operation): + def __init__(self, num=None, axis=0, *, name=None): + super().__init__(name=name) + self.num = num + self.axis = axis + + def call(self, x): + return backend.core.unstack(x, self.num, self.axis) + + def compute_output_spec(self, x): + axis = self.axis + if axis < 0: + axis = len(x.shape) + axis + output_shapes = x.shape[:axis] + x.shape[axis + 1 :] + num = self.num + if num is None: + num = x.shape[axis] + if num is None: + raise ValueError( + "Cannot infer argument `num` from shape " + f"{x.shape}. Either provide a tensor with a " + "concrete shape in the `axis` dimension or " + "explicitly pass the `num` argument." + ) + output = [ + KerasTensor(shape=output_shapes, dtype=x.dtype) for _ in range(num) + ] + return output + + +@keras_export("keras.ops.unstack") +def unstack(x, num=None, axis=0): + """Unpacks the given dimension of a rank-R tensor into rank-(R-1) tensors. + + Args: + x: The input tensor. + num: The length of the dimension axis. Automatically inferred + if `None`. + axis: The axis along which to unpack. + + Returns: + A list of tensors unpacked along the given axis. + + Example: + + >>> x = keras.ops.array([[1, 2], [3, 4]]) + >>> keras.ops.unstack(x, axis=0) + [array([1, 2]), array([3, 4])] + """ + if any_symbolic_tensors((x,)): + return Unstack(num, axis).symbolic_call(x) + return backend.core.unstack(x, num=num, axis=axis) + + +@keras_export("keras.ops.shape") +def shape(x): + """Gets the shape of the tensor input. + + Note: On the TensorFlow backend, when `x` is a `tf.Tensor` with dynamic + shape, dimensions which are dynamic in the context of a compiled function + will have a `tf.Tensor` value instead of a static integer value. + + Args: + x: A tensor. This function will try to access the `shape` attribute of + the input tensor. + + Returns: + A tuple of integers or None values, indicating the shape of the input + tensor. + + Example: + + >>> x = keras.ops.zeros((8, 12)) + >>> keras.ops.shape(x) + (8, 12) + """ + if any_symbolic_tensors((x,)): + return x.shape + return backend.core.shape(x) + + +@keras_export("keras.ops.dtype") +def dtype(x): + """Return the dtype of the tensor input as a standardized string. + + Note that due to the standardization, the dtype will not compare equal + to the backend-specific version of the dtype. + + Args: + x: A tensor. This function will try to access the `dtype` attribute of + the input tensor. + + Returns: + A string indicating the dtype of the input tensor, e.g. `"float32"`. + + Example: + + >>> x = keras.ops.zeros((8, 12)) + >>> keras.ops.dtype(x) + 'float32' + + """ + return backend.standardize_dtype(x.dtype) + + +class Cast(Operation): + def __init__(self, dtype, *, name=None): + super().__init__(name=name) + self.dtype = backend.standardize_dtype(dtype) + + def call(self, x): + return backend.core.cast(x, self.dtype) + + def compute_output_spec(self, x): + return backend.KerasTensor(shape=x.shape, dtype=self.dtype) + + +@keras_export("keras.ops.cast") +def cast(x, dtype): + """Cast a tensor to the desired dtype. + + Args: + x: A tensor or variable. + dtype: The target type. + + Returns: + A tensor of the specified `dtype`. + + Example: + + >>> x = keras.ops.arange(4) + >>> x = keras.ops.cast(x, dtype="float16") + """ + if any_symbolic_tensors((x,)): + return Cast(dtype=dtype)(x) + return backend.core.cast(x, dtype) + + +class SaturateCast(Operation): + def __init__(self, dtype, *, name=None): + super().__init__(name=name) + self.dtype = backend.standardize_dtype(dtype) + + def call(self, x): + return _saturate_cast(x, self.dtype) + + def compute_output_spec(self, x): + return backend.KerasTensor(shape=x.shape, dtype=self.dtype) + + +@keras_export("keras.ops.saturate_cast") +def saturate_cast(x, dtype): + """Performs a safe saturating cast to the desired dtype. + + Saturating cast prevents data type overflow when casting to `dtype` with + smaller values range. E.g. + `ops.cast(ops.cast([-1, 256], "float32"), "uint8")` returns `[255, 0]`, + but `ops.saturate_cast(ops.cast([-1, 256], "float32"), "uint8")` returns + `[0, 255]`. + + Args: + x: A tensor or variable. + dtype: The target type. + + Returns: + A safely casted tensor of the specified `dtype`. + + Example: + + Image resizing with bicubic interpolation may produce values outside + original range. + >>> image2x2 = np.array([0, 1, 254, 255], dtype="uint8").reshape(1, 2, 2, 1) + >>> image4x4 = tf.image.resize(image2x2, (4, 4), method="bicubic") + >>> print(image4x4.numpy().squeeze()) + >>> # [[-22.500004 -22.204624 -21.618908 -21.32353 ] + >>> # [ 52.526054 52.82143 53.407146 53.70253 ] + >>> # [201.29752 201.59288 202.17859 202.47395 ] + >>> # [276.32355 276.61893 277.20465 277.50006 ]] + + Casting this resized image back to `uint8` will cause overflow. + >>> image4x4_casted = ops.cast(image4x4, "uint8") + >>> print(image4x4_casted.numpy().squeeze()) + >>> # [[234 234 235 235] + >>> # [ 52 52 53 53] + >>> # [201 201 202 202] + >>> # [ 20 20 21 21]] + + Saturate casting to `uint8` will clip values to `uint8` range before + casting and will not cause overflow. + >>> image4x4_saturate_casted = ops.saturate_cast(image4x4, "uint8") + >>> print(image4x4_saturate_casted.numpy().squeeze()) + >>> # [[ 0 0 0 0] + >>> # [ 52 52 53 53] + >>> # [201 201 202 202] + >>> # [255 255 255 255]] + + """ + if any_symbolic_tensors((x,)): + return SaturateCast(dtype=dtype)(x) + return _saturate_cast(x, dtype) + + +def _saturate_cast(x, dtype, backend_module=None): + backend_module = backend_module or backend + + def get_dtype_min_max(dtype): + if "bool" == dtype: + dtype_min = 0 + dtype_max = 1 + elif "int" in dtype: + dtype_min = ml_dtypes.iinfo(dtype).min + dtype_max = ml_dtypes.iinfo(dtype).max + else: + dtype_min = ml_dtypes.finfo(dtype).min + dtype_max = ml_dtypes.finfo(dtype).max + return dtype_min, dtype_max + + dtype = backend.standardize_dtype(dtype) + in_dtype = backend.standardize_dtype(x.dtype) + in_min, in_max = get_dtype_min_max(in_dtype) + out_min, out_max = get_dtype_min_max(dtype) + + # The output min/max may not actually be representable in the + # in_dtype (e.g. casting float32 to uint32). This can lead to undefined + # behavior when trying to cast a value outside the valid range of the + # target type. We work around this by nudging the min/max to fall within + # the valid output range. The catch is that we may actually saturate + # to a value less than the true saturation limit, but this is the best we + # can do in order to avoid UB without backend op. + min_limit = np.maximum(in_min, out_min).astype(in_dtype) + if min_limit < out_min: + min_limit = np.nextafter(min_limit, 0, dtype=in_dtype) + max_limit = np.minimum(in_max, out_max).astype(in_dtype) + if max_limit > out_max: + max_limit = np.nextafter(max_limit, 0, dtype=in_dtype) + + # Unconditionally apply `clip` to fix `inf` behavior. + x = backend_module.numpy.clip(x, min_limit, max_limit) + + return backend_module.cast(x, dtype) + + +class ConvertToTensor(Operation): + def __init__(self, dtype=None, sparse=None, ragged=None, *, name=None): + super().__init__(name=name) + self.dtype = None if dtype is None else backend.standardize_dtype(dtype) + self.sparse = sparse + self.ragged = ragged + + def call(self, x): + return backend.core.convert_to_tensor( + x, dtype=self.dtype, sparse=self.sparse, ragged=self.ragged + ) + + def compute_output_spec(self, x): + dtype = ( + backend.standardize_dtype(x.dtype) + if self.dtype is None + else self.dtype + ) + sparse = ( + False if self.sparse is not None and not self.sparse else x.sparse + ) + ragged = ( + False if self.ragged is not None and not self.ragged else x.ragged + ) + return backend.KerasTensor( + shape=x.shape, dtype=dtype, sparse=sparse, ragged=ragged + ) + + +@keras_export("keras.ops.convert_to_tensor") +def convert_to_tensor(x, dtype=None, sparse=None, ragged=None): + """Convert a NumPy array or Python array to a tensor. + + Native tensors for the current backend or left unchanged unless the `dtype`, + `sparse` or `ragged` arguments are set. + + Args: + x: A NumPy array, Python array (can be nested) or a backend tensor. + dtype: The target type. If `None`, the type of `x` is used. + sparse: Whether to keep sparse tensors. `False` will cause sparse + tensors to be densified. The default value of `None` means that + sparse tensors are kept only if the backend supports them. + ragged: Whether to keep ragged tensors. `False` will cause ragged + tensors to be densified. The default value of `None` means that + ragged tensors are kept only if the backend supports them. + + Returns: + A backend tensor of the specified `dtype` and sparseness. + + Example: + + >>> x = np.array([1, 2, 3]) + >>> y = keras.ops.convert_to_tensor(x) + """ + if any_symbolic_tensors((x,)): + return ConvertToTensor(dtype=dtype, sparse=sparse, ragged=ragged)(x) + return backend.core.convert_to_tensor( + x, dtype=dtype, sparse=sparse, ragged=ragged + ) + + +@keras_export("keras.ops.convert_to_numpy") +def convert_to_numpy(x): + """Convert a tensor to a NumPy array. + + Args: + x: A tensor. + + Returns: + A NumPy array. + """ + if any_symbolic_tensors((x,)): + # This will raise a `ValueError` defined in the `KerasTensor` class. + # We trigger it rather than duplicate it here. + return np.array(x) + return backend.convert_to_numpy(x) + + +class Cond(Operation): + @traceback_utils.filter_traceback + def __call__(self, *args, **kwargs): + def call_fn(*args, **kwargs): + if any_symbolic_tensors(args, kwargs): + return self.symbolic_call(*args, **kwargs) + else: + return self.call(*args, **kwargs) + + if traceback_utils.is_traceback_filtering_enabled(): + # Wrap self.call to provide helpful info in case of exception + call_fn = traceback_utils.inject_argument_info_in_traceback( + call_fn, + object_name=(f"{self.__class__.__name__}.call()"), + ) + return call_fn(*args, **kwargs) + + # Plain flow. + return call_fn(*args, **kwargs) + + def call(self, pred, true_fn, false_fn): + return backend.core.cond(pred, true_fn, false_fn) + + def compute_output_spec(self, pred, true_fn, false_fn): + true_fn_spec = backend.compute_output_spec(true_fn) + false_fn_spec = backend.compute_output_spec(false_fn) + if not self._check_output_spec(true_fn_spec, false_fn_spec): + raise ValueError( + "`true_fn` and `false_fn` should return outputs " + "of the same kind (struct, dtype and shape). " + f"Got {true_fn_spec} and {false_fn_spec} instead." + ) + return true_fn_spec + + def _check_output_spec(self, true_fn_spec, false_fn_spec): + try: + tree.assert_same_structure(true_fn_spec, false_fn_spec) + except: + return False + + def check_leaf(t_spec, f_spec): + if t_spec is None or f_spec is None: + return t_spec is None and f_spec is None + return t_spec.shape == f_spec.shape and t_spec.dtype == f_spec.dtype + + same = tree.map_structure(check_leaf, true_fn_spec, false_fn_spec) + return all(tree.flatten(same)) + + +@keras_export("keras.ops.cond") +def cond(pred, true_fn, false_fn): + """Conditionally applies `true_fn` or `false_fn`. + + Args: + pred: Boolean scalar type + true_fn: Callable returning the output for the `pred == True` case. + false_fn: Callable returning the output for the `pred == False` case. + + Returns: + The output of either `true_fn` or `false_fn` depending on pred. + """ + return Cond()(pred, true_fn, false_fn) + + +class VectorizedMap(Operation): + def __init__(self, function, *, name=None): + super().__init__(name=name) + self.function = function + + def call(self, elements): + return backend.core.vectorized_map(self.function, elements) + + def compute_output_spec(self, elements): + x = tree.map_structure(lambda t: t[0], elements) + n = tree.flatten(elements)[0].shape[0] + y = backend.compute_output_spec(self.function, x) + + def append_batch_axis(t): + return KerasTensor( + shape=(n,) + t.shape, + dtype=t.dtype, + sparse=t.sparse, + ragged=t.ragged, + ) + + y = tree.map_structure(append_batch_axis, y) + return y + + def get_config(self): + config = super().get_config() + config.update({"function": self.function}) + return config + + @classmethod + def from_config(cls, config): + config = config.copy() + config["function"] = serialization_lib.deserialize_keras_object( + config["function"] + ) + return cls(**config) + + +@keras_export("keras.ops.vectorized_map") +def vectorized_map(function, elements): + """Parallel map of `function` on axis 0 of tensor(s) `elements`. + + Schematically, `vectorized_map` implements the following, + in the case of a single tensor input `elements`: + + ```python + def vectorized_map(function, elements): + outputs = [] + for e in elements: + outputs.append(function(e)) + return np.stack(outputs) + ``` + + In the case of an iterable of tensors `elements`, + it implements the following: + + ```python + def vectorized_map(function, elements): + batch_size = elements[0].shape[0] + outputs = [] + for index in range(batch_size): + outputs.append(function([e[index] for e in elements])) + return np.stack(outputs) + ``` + + In this case, `function` is expected to take as input + a single list of tensor arguments. + """ + if any_symbolic_tensors((elements,)): + return VectorizedMap(function)(elements) + return backend.core.vectorized_map(function, elements) + + +@keras_export("keras.ops.is_tensor") +def is_tensor(x): + """Check whether the given object is a tensor. + + Note: This checks for backend specific tensors so passing a TensorFlow + tensor would return `False` if your backend is PyTorch or JAX. + + Args: + x: A variable. + + Returns: + `True` if `x` is a tensor, otherwise `False`. + """ + return backend.core.is_tensor(x) + + +@keras_export("keras.ops.custom_gradient") +def custom_gradient(f): + """Decorator to define a function with a custom gradient. + + This decorator allows fine grained control over the gradients of a sequence + for operations. This may be useful for multiple reasons, including providing + a more efficient or numerically stable gradient for a sequence of + operations. + + Args: + f: Function `f(*args)` that returns a tuple + `(output, grad_fn)`, where: + - `args` is a sequence of (nested structures of) tensor inputs to + the function. + - `output` is a (nested structure of) tensor outputs of applying + operations in `forward_fn` to `args`. + - `grad_fn` is a function with the signature `grad_fn(*args, + upstream)` which returns a tuple of tensors the same size as + (flattened) `args`: the derivatives of tensors in `output` with + respect to the tensors in `args`. `upstream` is a tensor or + sequence of tensors holding the initial value gradients for each + tensor in `output`. + + Returns: + A function `h(*args)` which returns the same value as + `f(*args)[0]` and whose gradient is determined by + `f(*args)[1]`. + + + Examples: + + 1. Backend-agnostic example. + + ```python + @ops.custom_gradient + def log1pexp(x): + e = ops.exp(x) + + def grad(*args, upstream=None): + if upstream is None: + (upstream,) = args + return ops.multiply(upstream, 1.0 - 1.0 / ops.add(1, e)) + + return ops.log(1 + e), grad + ``` + + Note that the grad function that returns gradient computation + requires `args` as well as an `upstream` keyword argument, depending + on the backend being set. With the JAX and TensorFlow backends, + it requires only one argument, whereas it might use the `upstream` + argument in the case of the PyTorch backend. + + When working with TensorFlow/JAX backend, `grad(upstream)` + is sufficient. With PyTorch, the `grad` function requires + `*args` as well as `upstream`, e.g. `def grad(*args, upstream)`. + Follow the previous example to use `@ops.custom_gradient` in + a way that is compatible with all backends. + + 2. Here's JAX & TensorFlow-specific example: + + ```python + @ops.custom_gradient + def log1pexp(x): + e = ops.exp(x) + def grad(upstream): + return ops.multiply(upstream, 1.0 - 1.0 / ops.add(1, e)) + return ops.log(1 + e), grad + ``` + + 3. Lastly, here's a PyTorch-specific example, + using `*args` & `upstream`: + + ```python + @ops.custom_gradient + def log1pexp(x): + e = ops.exp(x) + def grad(*args, upstream): + return ops.multiply(upstream, 1.0 - 1.0 / ops.add(1, e)) + return ops.log(1 + e), grad + ``` + """ + return backend.core.custom_gradient(f) diff --git a/keras/src/ops/core_test.py b/keras/src/ops/core_test.py new file mode 100644 index 000000000000..ff49a4d34e05 --- /dev/null +++ b/keras/src/ops/core_test.py @@ -0,0 +1,1649 @@ +import operator +from unittest.mock import Mock + +import numpy as np +import pytest +from absl.testing import parameterized + +from keras.src import backend +from keras.src import layers +from keras.src import losses +from keras.src import models +from keras.src import ops +from keras.src import optimizers +from keras.src import testing +from keras.src import tree +from keras.src.backend.common import dtypes +from keras.src.backend.common.keras_tensor import KerasTensor +from keras.src.layers.core import input_layer +from keras.src.ops import core +from keras.src.saving import object_registration +from keras.src.testing.test_utils import named_product + + +class CoreOpsDynamicShapeTest(testing.TestCase): + def test_associative_scan(self): + xs = (KerasTensor((5, None)), KerasTensor((5, None))) + ys = core.associative_scan( + f=lambda x, y: (x[0] + y[0], x[1] + y[1]), elems=xs, axis=0 + ) + self.assertEqual(ys[0].shape, (5, None)) + + # sum two tuples of unknown (but same) length at axis + def _fn(x, y): + return tuple([x[i] + y[i] for i in range(len(x))]) + + ys = core.associative_scan(f=_fn, elems=xs, axis=1) + self.assertEqual(ys[0].shape, (5, None)) + + def test_cast(self): + x = KerasTensor((3, 5, None), dtype="float32") + self.assertEqual(core.cast(x, "float16").shape, (3, 5, None)) + + def test_convert_to_tensor(self): + x = KerasTensor((2, None)) + self.assertEqual(core.convert_to_tensor(x).shape, (2, None)) + + def test_fori_loop(self): + def body_fun(i, x): + return x + i + + initial_value = KerasTensor((3, 5, None)) + self.assertEqual( + core.fori_loop(0, 10, body_fun, initial_value).shape, (3, 5, None) + ) + + def test_map(self): + def f(x): + return x**2 + + xs = KerasTensor((None, 5)) + self.assertEqual(core.map(f, xs).shape, (None, 5)) + + # Test nested output + def f2(x): + return {"a": x**2, "b": x * 10} + + xs = KerasTensor((None, 5)) + ys = core.map(f2, xs) + self.assertEqual(ys["a"].shape, (None, 5)) + self.assertEqual(ys["b"].shape, (None, 5)) + + # Test nested input + def f3(x): + return x[0] + x[1] + + xs = (KerasTensor((None, 5)), KerasTensor((None, 5))) + self.assertEqual(core.map(f3, xs).shape, (None, 5)) + + def test_saturate_cast(self): + x = KerasTensor((3, 5, None), dtype="float32") + self.assertEqual(core.saturate_cast(x, "float16").shape, (3, 5, None)) + + def test_scan(self): + def f(carry, xs): + xs = xs + carry + return carry, carry + + init = KerasTensor((None,)) + xs = KerasTensor((6, None)) + carry, result = core.scan(f, init, xs) + self.assertEqual(carry.shape, (None,)) + self.assertEqual(result.shape, (6, None)) + + def f2(carry, _): + return carry, carry + + carry, result = core.scan(f2, init, xs=None, length=3) + self.assertEqual(carry.shape, (None,)) + self.assertEqual(result.shape, (3, None)) + + # Scatter doesn't support dynamic shape. + + def test_scatter_update(self): + inputs = KerasTensor((4, None)) + indices = KerasTensor((5, 2)) + updates = KerasTensor((5,)) + self.assertEqual( + core.scatter_update(inputs, indices, updates).shape, (4, None) + ) + + # Slice doesn't support dynamic shape. + + def test_slice_update(self): + inputs = KerasTensor((4, None)) + start_indices = KerasTensor((2,)) + updates = KerasTensor((2, 2)) + self.assertEqual( + core.slice_update(inputs, start_indices, updates).shape, (4, None) + ) + + def test_stop_gradient(self): + variable = KerasTensor(shape=(3, None), dtype="float32") + self.assertEqual(core.stop_gradient(variable).shape, (3, None)) + + def test_switch(self): + def fn(x, y): + return x[:, 0], y[0, :] + + index = KerasTensor(()) + x = KerasTensor((None, 2)) + y = KerasTensor((5, None)) + result = core.switch(index, [fn], x, y) + self.assertEqual(result[0].shape, (None,)) + self.assertEqual(result[1].shape, (None,)) + + def test_vectorized_map(self): + def f(x): + return x**2 + + xs = KerasTensor((None, 5)) + self.assertEqual(core.vectorized_map(f, xs).shape, (None, 5)) + + # Test nested output + def f2(x): + return {"a": x**2, "b": x * 10} + + xs = KerasTensor((None, 5)) + ys = core.vectorized_map(f2, xs) + self.assertEqual(ys["a"].shape, (None, 5)) + self.assertEqual(ys["b"].shape, (None, 5)) + + # Test nested input + def f3(x): + return x[0] + x[1] + + xs = (KerasTensor((None, 5)), KerasTensor((None, 5))) + self.assertEqual(core.vectorized_map(f3, xs).shape, (None, 5)) + + def test_while_loop(self): + def cond(args): + return tree.flatten(args)[0] < 10 + + def body(args): + return tree.map_structure(lambda x: x + 1, args) + + loop_vars = KerasTensor((None,)) + self.assertEqual(core.while_loop(cond, body, loop_vars).shape, (None,)) + + def test_unstack(self): + x = KerasTensor((2, None, None)) + axis, num = 1, 3 + out = core.unstack(x, num=num, axis=axis) + self.assertEqual(len(out), 3) + for o in out: + self.assertEqual(o.shape, (2, None)) + + +class CoreOpsStaticShapeTest(testing.TestCase): + def test_associative_scan(self): + xs = (KerasTensor((5, 10)), KerasTensor((5, 10))) + ys = core.associative_scan( + f=lambda x, y: (x[0] + y[0], x[1] + y[1]), elems=xs, axis=0 + ) + self.assertEqual(ys[0].shape, (5, 10)) + + # sum two tuples of unknown (but same) length at axis + def _fn(x, y): + return tuple([x[i] + y[i] for i in range(len(x))]) + + ys = core.associative_scan(f=_fn, elems=xs, axis=1) + self.assertEqual(ys[0].shape, (5, 10)) + + def test_cast(self): + x = KerasTensor((3, 5, 7), dtype="float32") + self.assertEqual(core.cast(x, "float16").shape, (3, 5, 7)) + + def test_cond(self): + pred = KerasTensor((), dtype="bool") + self.assertEqual( + ops.cond( + pred, lambda: ops.ones((1, 3)), lambda: ops.zeros((1, 3)) + ).shape, + (1, 3), + ) + + def test_convert_to_tensor(self): + x = KerasTensor((2, 3)) + out = core.convert_to_tensor(x) + self.assertEqual(out.shape, x.shape) + self.assertFalse(out.sparse) + + out = core.convert_to_tensor(x, sparse=True) + self.assertFalse(out.sparse) + + x = KerasTensor((2, 3), sparse=True) + out = core.convert_to_tensor(x) + self.assertTrue(out.sparse) + + out = core.convert_to_tensor(x, sparse=True) + self.assertTrue(out.sparse) + + out = core.convert_to_tensor(x, sparse=False) + self.assertFalse(out.sparse) + + def test_fori_loop(self): + def body_fun(i, x): + return x + i + + initial_value = KerasTensor((3, 5, 7)) + result = core.fori_loop(0, 10, body_fun, initial_value) + self.assertEqual(result.shape, (3, 5, 7)) + + def test_map(self): + def f(x): + return x**2 + + xs = KerasTensor((6, 5)) + ys = core.map(f, xs) + self.assertEqual(ys.shape, (6, 5)) + + # Test nested output + def f2(x): + return {"a": x**2, "b": x * 10} + + xs = KerasTensor((6, 5)) + ys = core.map(f2, xs) + self.assertEqual(ys["a"].shape, (6, 5)) + self.assertEqual(ys["b"].shape, (6, 5)) + + # Test nested input + def f3(x): + return x[0] + x[1] + + xs = (KerasTensor((6, 5)), KerasTensor((6, 5))) + self.assertEqual(core.map(f3, xs).shape, (6, 5)) + + def test_saturate_cast(self): + x = KerasTensor((3, 5, 7), dtype="float32") + self.assertEqual(core.saturate_cast(x, "float16").shape, (3, 5, 7)) + + def test_scan(self): + def f(carry, xs): + xs = xs + carry + return carry, carry + + init = KerasTensor(()) + xs = KerasTensor((6,)) + carry, result = core.scan(f, init, xs) + self.assertEqual(carry.shape, ()) + self.assertEqual(result.shape, (6,)) + + def f2(carry, _): + return carry, carry + + carry, result = core.scan(f2, init, xs=None, length=3) + self.assertEqual(carry.shape, ()) + self.assertEqual(result.shape, (3,)) + + def test_scatter(self): + indices = KerasTensor((5, 2)) + values = KerasTensor((5,)) + shape = (4, 4) + self.assertEqual(core.scatter(indices, values, shape).shape, (4, 4)) + + def test_scatter_update(self): + inputs = KerasTensor((4, 4)) + indices = KerasTensor((5, 2)) + updates = KerasTensor((5,)) + self.assertEqual( + core.scatter_update(inputs, indices, updates).shape, (4, 4) + ) + + inputs = KerasTensor((4, 4, 4)) + indices = KerasTensor((5, 2)) + updates = KerasTensor((5, 4)) + self.assertEqual( + core.scatter_update(inputs, indices, updates).shape, (4, 4, 4) + ) + + def test_slice(self): + inputs = KerasTensor(shape=(3, 3), dtype="float32") + start_indices = KerasTensor(shape=(2,), dtype="int32") + shape = (2, 2) + self.assertEqual(core.slice(inputs, start_indices, shape).shape, (2, 2)) + + def test_slice_negative_one_shape(self): + inputs = KerasTensor(shape=(3, 3), dtype="float32") + start_indices = (1, 1) + shape = (-1, -1) + self.assertEqual(core.slice(inputs, start_indices, shape).shape, (2, 2)) + + def test_slice_negative_one_shape_raises(self): + inputs = KerasTensor(shape=(3, 3), dtype="float32") + start_indices = KerasTensor(shape=(2,), dtype="int32") + shape = (-1, -1) + with self.assertRaises(ValueError): + core.slice(inputs, start_indices, shape) + + def test_slice_update(self): + inputs = KerasTensor((4, 4)) + start_indices = KerasTensor((2,)) + updates = KerasTensor((2, 2)) + self.assertEqual( + core.slice_update(inputs, start_indices, updates).shape, (4, 4) + ) + + inputs = KerasTensor((4, 4, 4)) + start_indices = KerasTensor((3,)) + updates = KerasTensor((2, 2, 2)) + self.assertEqual( + core.slice_update(inputs, start_indices, updates).shape, (4, 4, 4) + ) + + def test_stop_gradient(self): + variable = KerasTensor(shape=(3, 3), dtype="float32") + self.assertEqual(core.stop_gradient(variable).shape, (3, 3)) + + def test_switch(self): + def fn(x, y): + return x[:, 0], y[0, :] + + index = KerasTensor(()) + x = KerasTensor((5, 2)) + y = KerasTensor((5, 2)) + self.assertEqual(core.switch(index, [fn], x, y)[0].shape, (5,)) + self.assertEqual(core.switch(index, [fn], x, y)[1].shape, (2,)) + + def test_vectorized_map(self): + def f(x): + return x**2 + + xs = KerasTensor((6, 5)) + ys = core.vectorized_map(f, xs) + self.assertEqual(ys.shape, (6, 5)) + + # Test nested output + def f2(x): + return {"a": x**2, "b": x * 10} + + xs = KerasTensor((6, 5)) + ys = core.vectorized_map(f2, xs) + self.assertEqual(ys["a"].shape, (6, 5)) + self.assertEqual(ys["b"].shape, (6, 5)) + + # Test nested input + def f3(x): + return x[0] + x[1] + + xs = (KerasTensor((6, 5)), KerasTensor((6, 5))) + self.assertEqual(core.vectorized_map(f3, xs).shape, (6, 5)) + + def test_while_loop(self): + def cond(args): + return tree.flatten(args)[0] < 10 + + def body(args): + return tree.map_structure(lambda x: x + 1, args) + + loop_vars = KerasTensor((10,)) + self.assertEqual(core.while_loop(cond, body, loop_vars).shape, (10,)) + + def test_unstack(self): + x = KerasTensor((2, 3, 4)) + axis = 1 + out = core.unstack(x, axis=axis) + self.assertEqual(len(out), 3) + for o in out: + self.assertEqual(o.shape, (2, 4)) + + +class CoreOpsCorrectnessTest(testing.TestCase): + def test_associative_scan(self): + # Test prefix sum + arr = np.arange(5) + result = core.associative_scan(f=operator.add, elems=arr) + self.assertAllEqual(result, [0, 1, 3, 6, 10]) + # Test reverse + result = core.associative_scan(f=operator.add, elems=arr, reverse=True) + self.assertAllEqual(result, [10, 10, 9, 7, 4]) + + # Test multiple dimensions, across different axes + batched_arr = np.stack([arr, arr + 1, arr + 2]) + result = core.associative_scan( + f=operator.add, elems=batched_arr, axis=1 + ) + self.assertAllEqual(result[2], [2, 5, 9, 14, 20]) + result = core.associative_scan( + f=operator.add, elems=batched_arr, axis=0 + ) + self.assertAllEqual(result[:, 0], [0, 1, 3]) + + # Test structured input + elems = { + "a": np.array([[0, 1, 2], [3, 4, 5]]), + "b": np.array([[6, 7, 8], [9, 10, 11]]), + } + + def _dict_add(x, y): + return {"a": x["a"] + y["b"], "b": x["b"] + y["b"]} + + ax0 = core.associative_scan(f=_dict_add, elems=elems, axis=0) + self.assertAllEqual( + ax0["b"], + [[6, 7, 8], [15, 17, 19]], + ) + + # Test parallel scan op used in mamba + b, l, d, n = 1, 2, 3, 4 + DB = np.random.rand(b, l, d, n) + DA = np.random.rand(b, l, d, n) + + H_seq = np.zeros((b, d, n)) + for i in range(l): + H_seq = DA[:, i] * H_seq + DB[:, i] + + def scan_op(ci, cj): + a = cj[0] * ci[0] + b = cj[0] * ci[1] + cj[1] + return (a, b) + + inputs = (DA.transpose(1, 0, 2, 3), DB.transpose(1, 0, 2, 3)) + H_par = core.associative_scan(f=scan_op, elems=inputs)[-1][-1] + + self.assertAllClose(H_seq, H_par) + + # Test Operation call. + xs = np.arange(5, dtype="float32") + self.assertAllClose( + core.AssociativeScan()(operator.add, xs), ops.cumsum(xs) + ) + + def test_cast(self): + x = ops.ones((2,), dtype="float32") + y = ops.cast(x, "float16") + self.assertIn("float16", str(y.dtype)) + + x = ops.KerasTensor((2,), dtype="float32") + y = ops.cast(x, "float16") + self.assertEqual("float16", y.dtype) + self.assertEqual(x.shape, y.shape) + self.assertTrue(hasattr(y, "_keras_history")) + + # Test Operation call. + x = ops.ones((2,), dtype="float32") + self.assertDType(core.Cast("float16")(x), "float16") + + @parameterized.named_parameters( + ("float8_e4m3fn", "float8_e4m3fn"), ("float8_e5m2", "float8_e5m2") + ) + def test_cast_float8(self, float8_dtype): + # Cast to float8 and cast back + x = ops.ones((2,), dtype="float32") + y = ops.cast(x, float8_dtype) + self.assertIn(float8_dtype, str(y.dtype)) + x = ops.cast(y, "float32") + self.assertIn("float32", str(x.dtype)) + + x = ops.KerasTensor((2,), dtype="float32") + y = ops.cast(x, float8_dtype) + self.assertEqual(float8_dtype, y.dtype) + self.assertEqual(x.shape, y.shape) + self.assertTrue(hasattr(y, "_keras_history")) + x = ops.cast(y, "float32") + self.assertEqual("float32", x.dtype) + self.assertEqual(x.shape, y.shape) + self.assertTrue(hasattr(x, "_keras_history")) + + def test_cond(self): + t = ops.cond(True, lambda: 0, lambda: 1) + self.assertEqual(t, 0) + f = ops.cond(False, lambda: 0, lambda: 1) + self.assertEqual(f, 1) + f = ops.cond(False, lambda: None, lambda: None) + self.assertEqual(f, None) + + out = ops.cond( + ops.convert_to_tensor(True), + lambda: ops.ones((1, 3)), + lambda: ops.zeros((1, 3)), + ) + self.assertAllClose(out, ops.ones((1, 3))) + + out = ops.cond( + ops.convert_to_tensor(False), + lambda: ops.ones((3,)), + lambda: ops.zeros((3,)), + ) + self.assertAllClose(out, ops.zeros((3,))) + + with self.assertRaises(ValueError): + ops.cond( + KerasTensor((), dtype="bool"), + lambda: ops.ones((3,)), + lambda: ops.zeros((4,)), + ) + + def test_convert_to_tensor(self): + x = np.ones((2,)) + x = ops.convert_to_tensor(x) + x = ops.convert_to_numpy(x) + self.assertAllEqual(x, (1, 1)) + self.assertIsInstance(x, np.ndarray) + + # Empty lists should give an empty array. + x = ops.convert_to_tensor([]) + np_x = ops.convert_to_numpy(x) + self.assertTrue(ops.is_tensor(x)) + self.assertAllEqual(x, []) + self.assertIsInstance(np_x, np.ndarray) + + # Partially converted. + x = ops.convert_to_tensor((1, ops.array(2), 3)) + self.assertAllEqual(x, (1, 2, 3)) + + @pytest.mark.skipif( + not backend.SUPPORTS_SPARSE_TENSORS, + reason=f"{backend.backend()} backend doesn't support sparse tensors.", + ) + def test_convert_to_tensor_sparse(self): + if backend.backend() == "tensorflow": + import tensorflow as tf + + x = tf.SparseTensor([[0, 0], [1, 2]], [1.0, 2.0], (2, 3)) + elif backend.backend() == "jax": + import jax.experimental.sparse as jax_sparse + + x = jax_sparse.BCOO(([1.0, 2.0], [[0, 0], [1, 2]]), shape=(2, 3)) + else: + self.fail(f"Sparse is unsupported with backend {backend.backend()}") + + x_default = ops.convert_to_tensor(x) + self.assertSparse(x_default) + self.assertAllClose(x, x_default) + x_sparse = ops.convert_to_tensor(x, sparse=True) + self.assertSparse(x_sparse) + self.assertAllClose(x, x_sparse) + x_dense = ops.convert_to_tensor(x, sparse=False) + self.assertSparse(x_dense, False) + self.assertAllClose(x, x_dense) + + x_numpy = ops.convert_to_numpy(x) + self.assertIsInstance(x_numpy, np.ndarray) + self.assertAllClose(x_numpy, x_dense) + + @pytest.mark.skipif( + not backend.SUPPORTS_RAGGED_TENSORS, + reason=f"{backend.backend()} backend doesn't support ragged tensors.", + ) + def test_convert_to_tensor_ragged(self): + import tensorflow as tf + + x = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []]) + + x_default = ops.convert_to_tensor(x) + self.assertIsInstance(x_default, tf.RaggedTensor) + self.assertAllClose(x, x_default) + x_ragged = ops.convert_to_tensor(x, ragged=True) + self.assertIsInstance(x_ragged, tf.RaggedTensor) + self.assertAllClose(x, x_ragged) + x_dense = ops.convert_to_tensor(x, ragged=False) + self.assertNotIsInstance(x_dense, tf.RaggedTensor) + self.assertAllClose(x, x_dense) + + x_numpy = ops.convert_to_numpy(x) + self.assertIsInstance(x_numpy, np.ndarray) + self.assertAllClose(x_numpy, x_dense) + + @pytest.mark.skipif( + backend.backend() not in ("tensorflow", "jax", "torch"), + reason=( + f"{backend.backend()} backend doesn't support `custom_gradient`." + ), + ) + def test_custom_gradient(self): + # function to test custom_gradient on + @ops.custom_gradient + def log1pexp(x): + e = ops.exp(x) + + def grad(*args, upstream=None): + if upstream is None: + (upstream,) = args + return ops.multiply(upstream, 1.0 - 1.0 / ops.add(1, e)) + + return ops.log(1 + e), grad + + def log1pexp_nan(x): + return ops.log(1 + ops.exp(x)) + + x = ops.convert_to_tensor(100.0) + if backend.backend() == "tensorflow": + import tensorflow as tf + + with tf.GradientTape() as tape1: + tape1.watch(x) + y = log1pexp(x) + with tf.GradientTape() as tape2: + tape2.watch(x) + z = log1pexp_nan(x) + dy_dx = tape1.gradient(y, x) + dz_dx = tape2.gradient(z, x) + self.assertEqual(ops.convert_to_numpy(dy_dx), 1.0) + elif backend.backend() == "jax": + import jax + + dy_dx = jax.grad(log1pexp)(x) + dz_dx = jax.grad(log1pexp_nan)(x) + self.assertEqual(ops.convert_to_numpy(dy_dx), 1.0) + self.assertTrue(ops.isnan(dz_dx)) + elif backend.backend() == "torch": + import torch + + x = torch.tensor(100.0, requires_grad=True) + z = log1pexp(x) + z.sum().backward() + self.assertEqual(ops.convert_to_numpy(x.grad), 1.0) + + def test_dynamic_slice(self): + def cond(index, inputs, sum): + return index < 10 + + def body(index, inputs, sum): + sum = sum + core.slice(inputs, [index], [1]) + index = index + 1 + return index, inputs, sum + + index, inputs, sum = 0, np.arange(10), np.array([0]) + index, inputs, sum = core.while_loop(cond, body, (index, inputs, sum)) + self.assertEqual(sum.shape, (1,)) + self.assertAllClose(sum, [45]) + + def test_fori_loop(self): + def body_fun(i, x): + return x + i + + initial_value = np.array(0) + result = core.fori_loop(0, 10, body_fun, initial_value) + self.assertAllClose(result, 45) + + # Test Operation call. + self.assertAllClose(core.ForiLoop(0, 10, body_fun)(initial_value), 45) + + def test_getitem(self): + np_tensor = np.arange(24).reshape(2, 3, 4) + tensor = ops.convert_to_tensor(np_tensor) + + t = tensor[1] + n = np_tensor[1] + self.assertEqual(t.shape, n.shape) + self.assertAllClose(t, n) + + t = tensor[1, 2, 3] + n = np_tensor[1, 2, 3] + self.assertEqual(t.shape, n.shape) + self.assertAllClose(t, n) + + t = tensor[1:2] + n = np_tensor[1:2] + self.assertEqual(t.shape, n.shape) + self.assertAllClose(t, n) + + t = tensor[1:2, 2:3, 3:4] + n = np_tensor[1:2, 2:3, 3:4] + self.assertEqual(t.shape, n.shape) + self.assertAllClose(t, n) + + t = tensor[1:2, None] + n = np_tensor[1:2, None] + self.assertEqual(t.shape, n.shape) + self.assertAllClose(t, n) + + t = tensor[1:2, 2:3, ...] + n = np_tensor[1:2, 2:3, ...] + self.assertEqual(t.shape, n.shape) + self.assertAllClose(t, n) + + t = tensor[1:2, ..., 3:4] + n = np_tensor[1:2, ..., 3:4] + self.assertEqual(t.shape, n.shape) + self.assertAllClose(t, n) + + t = tensor[None, ..., 3:4, None] + n = np_tensor[None, ..., 3:4, None] + self.assertEqual(t.shape, n.shape) + self.assertAllClose(t, n) + + t = tensor[1:2:None] + n = np_tensor[1:2:None] + self.assertEqual(t.shape, n.shape) + self.assertAllClose(t, n) + + t = tensor[:, 2] + n = np_tensor[:, 2] + self.assertEqual(t.shape, n.shape) + self.assertAllClose(t, n) + + t = tensor[None] + n = np_tensor[None] + self.assertEqual(t.shape, n.shape) + self.assertAllClose(t, n) + + t = tensor[None, None] + n = np_tensor[None, None] + self.assertEqual(t.shape, n.shape) + self.assertAllClose(t, n) + + t = tensor[...] + n = np_tensor[...] + self.assertEqual(t.shape, n.shape) + self.assertAllClose(t, n) + + t = tensor[..., 1] + n = np_tensor[..., 1] + self.assertEqual(t.shape, n.shape) + self.assertAllClose(t, n) + + t = tensor[..., 1, 2] + n = np_tensor[..., 1, 2] + self.assertEqual(t.shape, n.shape) + self.assertAllClose(t, n) + + t = tensor[..., -1, 2] + n = np_tensor[..., -1, 2] + self.assertEqual(t.shape, n.shape) + self.assertAllClose(t, n) + + t = tensor[..., -1:-2, 2] + n = np_tensor[..., -1:-2, 2] + self.assertEqual(t.shape, n.shape) + self.assertAllClose(t, n) + + t = tensor[..., None, None] + n = np_tensor[..., None, None] + self.assertEqual(t.shape, n.shape) + self.assertAllClose(t, n) + + t = tensor[None, ..., None] + n = np_tensor[None, ..., None] + self.assertEqual(t.shape, n.shape) + self.assertAllClose(t, n) + + t = tensor[1, 2, None, ..., None] + n = np_tensor[1, 2, None, ..., None] + self.assertEqual(t.shape, n.shape) + self.assertAllClose(t, n) + + t = tensor[None, ..., 1, 2] + n = np_tensor[None, ..., 1, 2] + self.assertEqual(t.shape, n.shape) + self.assertAllClose(t, n) + + t = tensor[1, None, 2] + n = np_tensor[1, None, 2] + self.assertEqual(t.shape, n.shape) + self.assertAllClose(t, n) + + index_tensor = ops.convert_to_tensor(np.array(1, dtype=np.int32)) + t = tensor[index_tensor] + n = np_tensor[ops.convert_to_numpy(index_tensor)] + self.assertEqual(t.shape, n.shape) + self.assertAllClose(t, n) + + index_tensor = ops.convert_to_tensor(np.array(1, dtype=np.int32)) + t = tensor[index_tensor, 2, None] + n = np_tensor[ops.convert_to_numpy(index_tensor), 2, None] + self.assertEqual(t.shape, n.shape) + self.assertAllClose(t, n) + + index_tensor = ops.convert_to_tensor(np.array(-2, dtype=np.int32)) + t = tensor[index_tensor, 1] + n = np_tensor[ops.convert_to_numpy(index_tensor), 1] + self.assertEqual(t.shape, n.shape) + self.assertAllClose(t, n) + + index_tensor = ops.convert_to_tensor(np.array(-1, dtype=np.int32)) + t = tensor[-2, index_tensor] + n = np_tensor[-2, ops.convert_to_numpy(index_tensor)] + self.assertEqual(t.shape, n.shape) + self.assertAllClose(t, n) + + # Negative indexing + t = tensor[-1] + n = np_tensor[-1] + self.assertEqual(t.shape, n.shape) + self.assertAllClose(t, n) + + t = tensor[1, -1, -2] + n = np_tensor[1, -1, -2] + self.assertEqual(t.shape, n.shape) + self.assertAllClose(t, n) + + # Slicing with step + t = tensor[::2] + n = np_tensor[::2] + self.assertEqual(t.shape, n.shape) + self.assertAllClose(t, n) + + # Mixed slices and integers + t = tensor[1, :, 1:4] + n = np_tensor[1, :, 1:4] + self.assertEqual(t.shape, n.shape) + self.assertAllClose(t, n) + + t = tensor[:, 1:2, 3] + n = np_tensor[:, 1:2, 3] + self.assertEqual(t.shape, n.shape) + self.assertAllClose(t, n) + + def test_is_tensor(self): + np_x = np.array([[1, 2, 3], [3, 2, 1]]) + x = backend.convert_to_tensor(np_x) + if backend.backend() != "numpy": + self.assertFalse(ops.is_tensor(np_x)) + self.assertTrue(ops.is_tensor(x)) + self.assertFalse(ops.is_tensor([1, 2, 3])) + + def test_map(self): + def f(x): + return x**2 + + xs = np.arange(10) + self.assertAllClose(ops.map(f, xs), xs**2) + + # Test nested output + def f2(x): + return {"a": x**2, "b": x * 10} + + xs = np.random.rand(2, 3, 4).astype("float32") + outputs = ops.map(f2, xs) + self.assertAllClose(outputs["a"], xs**2) + self.assertAllClose(outputs["b"], xs * 10) + + # Test with nested structures + def dict_input_fn(inputs): + x = inputs["x"][:, 0] + y = inputs["y"] + 1 + return {"x": x, "y": y} + + def list_input_fn(inputs): + return [x**2 for x in inputs] + + xs = { + "x": ops.convert_to_tensor( + np.random.rand(4, 100, 3), dtype="float32" + ), + "y": ops.convert_to_tensor( + np.random.randint(0, 10, size=(4, 1)), dtype="int32" + ), + } + xs1 = [ + ops.convert_to_tensor(np.random.rand(4, 100, 3), dtype="float32"), + ops.convert_to_tensor( + np.random.randint(0, 10, size=(4, 1)), dtype="int32" + ), + ] + ys = ops.map(dict_input_fn, xs) + self.assertEqual(ys["x"].shape, (4, 100)) + self.assertEqual( + ops.convert_to_numpy(ys["y"]).all(), + ops.convert_to_numpy(xs["y"] + 1).all(), + ) + ys = ops.map(list_input_fn, xs1) + for x, y in zip(xs1, ys): + self.assertEqual( + (ops.convert_to_numpy(y)).all(), + (ops.convert_to_numpy(x) ** 2).all(), + ) + + # Test Operation call. + xs = np.arange(10) + self.assertAllClose(ops.Map()(f, xs), xs**2) + + def test_saturate_cast(self): + x = ops.ones((2,), dtype="float32") + y = ops.saturate_cast(x, "float16") + self.assertIn("float16", str(y.dtype)) + + x = ops.KerasTensor((2,), dtype="float32") + y = ops.saturate_cast(x, "float16") + self.assertEqual("float16", y.dtype) + self.assertEqual(x.shape, y.shape) + self.assertTrue(hasattr(y, "_keras_history")) + + # Test Operation call. + x = np.array([-256, 1.0, 257.0], dtype="float32") + y = core.SaturateCast("uint8")(x) + self.assertDType(y, "uint8") + # Check that the values are the same + self.assertAllClose(y, np.clip(x, 0, 255).astype("uint8")) + + def test_scan(self): + # Test cumsum + def cumsum(carry, xs): + carry = carry + xs + return carry, carry + + init = np.array(0, dtype="float32") + xs = np.array([1, 2, 3, 4, 10, 20], dtype="float32") + carry, result = core.scan(cumsum, init, xs) + self.assertAllClose(carry, 40.0) + self.assertAllClose(result, ops.cumsum(xs)) + + # Test reverse=True + carry, result = core.scan(cumsum, init, xs, reverse=True) + self.assertAllClose(carry, 40.0) + self.assertAllClose(result, [40, 39, 37, 34, 30, 20]) + + # Test unroll + for unroll in (True, False, 2): + carry, result = core.scan(cumsum, init, xs, unroll=unroll) + self.assertAllClose(carry, 40.0) + self.assertAllClose(result, ops.cumsum(xs)) + + # Test xs is None + def fibonaccis(carry, _): + return (carry[1], carry[0] + carry[1]), None + + init = (np.array(0, dtype="float32"), np.array(1, dtype="float32")) + carry, _ = core.scan(fibonaccis, init, length=6) + self.assertAllClose(carry, [8, 13]) + + # Test nested init + if backend.backend() != "tensorflow": + # tensorflow doesn't support arbitrary shape/dtype of the output of + # `f`. It must be the same as `init`. + def multiply_two(carry, _): + value1 = carry["value1"] + value2 = carry["value2"] + return ( + {"value1": value1 * 2, "value2": value2 * 2}, + value1 * 2 + value2 * 2, + ) + + init = {"value1": 2.0, "value2": 3.0} + carry, result = core.scan(multiply_two, init, length=3) + self.assertAllClose(carry["value1"], 16) + self.assertAllClose(carry["value2"], 24) + self.assertAllClose(result, [10, 20, 40]) + + # Test nested xs + def reduce_add(carry, xs): + value1 = xs["value1"] + value2 = xs["value2"] + return carry, value1 + value2 + + init = np.array(0, dtype="float32") + xs = { + "value1": np.array([1, 2, 3], dtype="float32"), + "value2": np.array([10, 20, 30], dtype="float32"), + } + _, result = core.scan(reduce_add, init, xs) + self.assertAllClose(result, [11, 22, 33]) + + # Test Operation call. + init = np.array(0, dtype="float32") + xs = np.array([1, 2, 3, 4, 10, 20], dtype="float32") + carry, result = core.Scan()(cumsum, init, xs) + self.assertAllClose(carry, 40.0) + self.assertAllClose(result, ops.cumsum(xs)) + + def test_scatter(self): + # Test 1D + indices = np.array([[1], [3], [4], [7]]) + values = np.array([9, 10, 11, 12]) + self.assertAllClose( + core.scatter(indices, values, (8,)), + [0, 9, 0, 10, 11, 0, 0, 12], + ) + # Test 2D + indices = np.array([[0, 1], [2, 0]]) + values = np.array([5, 10]) + self.assertAllClose( + core.scatter(indices, values, (3, 2)), [[0, 5], [0, 0], [10, 0]] + ) + # Test 3D + indices = np.array([[1], [3]]) + values = np.array( + [ + [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]], + [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]], + ] + ) + self.assertAllClose( + core.scatter(indices, values, (4, 4, 4)), + [ + [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], + [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]], + [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], + [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]], + ], + ) + # Test slices + indices = np.array([[2], [4]]) + values = np.array([[1, 2, 3], [4, 5, 6]]) + self.assertAllClose( + core.scatter(indices, values, (6, 3)), + [[0, 0, 0], [0, 0, 0], [1, 2, 3], [0, 0, 0], [4, 5, 6], [0, 0, 0]], + ) + # Duplicate indices + indices = np.array([[0], [0]]) + values = np.array([1, 1]) + self.assertAllClose(core.scatter(indices, values, (1,)), [2]) + + # Test Operation call. + indices = np.array([[1, 0], [0, 1]]) + values = np.array([10, 20]) + shape = (2, 2) + self.assertAllClose( + core.Scatter(shape)(indices, values), np.array([[0, 20], [10, 0]]) + ) + + def test_scatter_update(self): + # Test 1D. + inputs = np.array([0, 0, 0, 0, 0, 0, 0, 0]) + indices = [[1], [3], [4], [7]] + updates = np.array([9, 10, 11, 12]) + self.assertAllClose( + core.scatter_update(inputs, indices, updates), + [0, 9, 0, 10, 11, 0, 0, 12], + ) + + # Test 2D. + inputs = np.array([[1, 1], [1, 1], [1, 1]]) + indices = [[0, 1], [2, 0]] + updates = np.array([5, 10]) + self.assertAllClose( + core.scatter_update(inputs, indices, updates), + [[1, 5], [1, 1], [10, 1]], + ) + + # Test updates has multiple dimension. + inputs = np.ones([4, 4, 4]) + indices = [[1, 1], [2, 2]] + updates = np.array([[0, 1, 2, 3], [3, 2, 1, 0]], dtype="float32") + outputs = core.scatter_update(inputs, indices, updates) + self.assertTrue(ops.is_tensor(outputs)) + self.assertAllClose(outputs[1, 1, :], [0, 1, 2, 3]) + self.assertAllClose(outputs[2, 2, :], [3, 2, 1, 0]) + + # Test Operation call. + inputs = np.array([[0, 0], [0, 0]]) + indices = np.array([[1, 0], [0, 1]]) + updates = np.array([10, 20]) + self.assertAllClose( + core.ScatterUpdate()(inputs, indices, updates), + np.array([[0, 20], [10, 0]]), + ) + + def test_shape(self): + x = ops.ones((2, 3, 7, 1)) + self.assertEqual(core.shape(x).__class__, tuple) + self.assertAllEqual(core.shape(x), (2, 3, 7, 1)) + + x = KerasTensor((None, 3, None, 1)) + self.assertEqual(core.shape(x).__class__, tuple) + self.assertAllEqual(core.shape(x), (None, 3, None, 1)) + + @pytest.mark.skipif( + not backend.SUPPORTS_SPARSE_TENSORS, + reason=f"{backend.backend()} backend doesn't support sparse tensors.", + ) + def test_shape_sparse(self): + if backend.backend() == "tensorflow": + import tensorflow as tf + + x = tf.SparseTensor([[0, 0], [1, 2]], [1.0, 2.0], (2, 3)) + elif backend.backend() == "jax": + import jax.experimental.sparse as jax_sparse + + x = jax_sparse.BCOO(([1.0, 2.0], [[0, 0], [1, 2]]), shape=(2, 3)) + else: + self.fail(f"Sparse is unsupported with backend {backend.backend()}") + + self.assertAllEqual(core.shape(x), (2, 3)) + + @pytest.mark.skipif( + not backend.SUPPORTS_SPARSE_TENSORS, + reason=f"{backend.backend()} backend doesn't support ragged tensors.", + ) + def test_shape_ragged(self): + import tensorflow as tf + + x = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []]) + self.assertAllEqual(core.shape(x), (5, None)) + + x = tf.RaggedTensor.from_row_lengths(tf.zeros([15, 2]), [4, 5, 6]) + self.assertAllEqual(core.shape(x), (3, None, 2)) + + def test_slice(self): + # Test 1D. + inputs = np.arange(10) + start_indices = np.array([1]) + shape = np.array([4]) + self.assertAllClose( + core.slice(inputs, start_indices, shape), + [1, 2, 3, 4], + ) + + # Test 2D. + inputs = np.broadcast_to(np.arange(10), (4, 10)) + start_indices = np.array([1, 1]) + shape = np.array([2, 4]) + self.assertAllClose( + core.slice(inputs, start_indices, shape), + [[1, 2, 3, 4], [1, 2, 3, 4]], + ) + + # Test N-D. + inputs = np.broadcast_to(np.arange(10), (4, 4, 4, 10)) + start_indices = np.array([1, 1, 1, 1]) + shape = np.array([1, 2, 3, 4]) + outputs = core.slice(inputs, start_indices, shape) + expected = np.broadcast_to(np.arange(1, 5), (1, 2, 3, 4)) + self.assertAllClose(outputs, expected) + + # Test Operation call. + inputs = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + start_indices = np.array([1, 1]) + shape = (2, 2) + self.assertAllClose( + core.Slice(shape)(inputs, start_indices), np.array([[5, 6], [8, 9]]) + ) + + def test_slice_update(self): + # Test 1D. + inputs = np.array([0, 0, 0, 0, 0, 0, 0, 0]) + start_indices = np.array([1]) + updates = np.array([9, 10, 11, 12]) + self.assertAllClose( + core.slice_update(inputs, start_indices, updates), + [0, 9, 10, 11, 12, 0, 0, 0], + ) + + # Test 2D. + inputs = np.array([[1, 1], [1, 1], [1, 1]]) + start_indices = [1, 0] + updates = np.array([[2, 2], [2, 2]]) + self.assertAllClose( + core.slice_update(inputs, start_indices, updates), + [[1, 1], [2, 2], [2, 2]], + ) + + # Test N-D. + inputs = np.ones([4, 4, 4, 4]) + start_indices = [1, 1, 2, 2] + updates = np.zeros([2, 2, 2, 2]) + outputs = core.slice_update(inputs, start_indices, updates) + self.assertAllClose(outputs[1:3, 1:3, 2:4, 2:4], np.zeros([2, 2, 2, 2])) + + # Test Operation call. + inputs = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + start_indices = np.array([1, 1]) + updates = np.array([[10, 11], [12, 13]]) + self.assertAllClose( + core.SliceUpdate()(inputs, start_indices, updates), + np.array([[1, 2, 3], [4, 10, 11], [7, 12, 13]]), + ) + + @pytest.mark.requires_trainable_backend + def test_stop_gradient(self): + class ExampleLayer(layers.Layer): + def __init__(self): + super().__init__() + self.w = self.add_weight(shape=(1,), initializer="zeros") + self.b = self.add_weight(shape=(1,), initializer="zeros") + + def call(self, x, training=False): + return ops.add( + ops.multiply(x, ops.stop_gradient(self.w)), self.b + ) + + model = models.Sequential([ExampleLayer()]) + model.compile( + optimizer=optimizers.SGD(), loss=losses.MeanSquaredError() + ) + rng = np.random.default_rng(0) + x = np.ones((2, 4), dtype="float32") + y = rng.standard_normal((2, 4), dtype="float32") + model.fit(x, y, epochs=1, batch_size=2) + self.assertEqual(model.layers[0].w.numpy(), 0.0) + self.assertNotEqual(model.layers[0].b.numpy(), 0.0) + + def test_stop_gradient_no_fit(self): + x = ops.random.uniform(shape=(2, 4), dtype="float32") + y = ops.stop_gradient(x) + self.assertAllClose(x, y) + + # Functional. + a = layers.Input(shape=(2,)) + b = layers.Dense(4, kernel_initializer="ones", use_bias=False)(a) + c = layers.Dense(4, kernel_initializer="ones", use_bias=False)(b) + d = ops.stop_gradient(b) + c + model = models.Model(inputs=a, outputs=d) + output = model(ops.convert_to_tensor([[1.0, 2.0]])) + self.assertAllClose(output, 15.0) + + # Test Operation call. + variable = ops.convert_to_tensor( + np.array([1.0, 2.0, 3.0], dtype="float32") + ) + self.assertAllClose(core.StopGradient()(variable), variable) + + def test_switch(self): + def fn1(x, y): + return x + y + + def fn2(x, y): + return x - y + + x = np.random.rand(2, 3, 4).astype("float32") + y = np.random.rand(2, 3, 4).astype("float32") + branches = [fn1, fn2] + self.assertAllClose(core.switch(0, branches, x, y), x + y) + self.assertAllClose(core.switch(1, branches, x, y), x - y) + + # Test out-of-bound index + self.assertAllClose(core.switch(-100, branches, x, y), x + y) + self.assertAllClose(core.switch(100, branches, x, y), x - y) + + # Test Operation call. + self.assertAllClose(core.Switch()(0, branches, x, y), x + y) + self.assertAllClose(core.Switch()(1, branches, x, y), x - y) + + def test_vectorized_map(self): + def fn(x): + return x + 1 + + output = ops.vectorized_map(fn, ops.zeros((2, 3), dtype="float32")) + self.assertAllClose(backend.convert_to_numpy(output), np.ones((2, 3))) + + def fn(x): + return ops.stack([x, x]) + + output = ops.vectorized_map(fn, ops.zeros((2, 3), dtype="float32")) + self.assertAllClose( + backend.convert_to_numpy(output), np.zeros((2, 2, 3)) + ) + + # Case: multiple args + def fn(elems): + x, y = elems + return x + y + + output = ops.vectorized_map(fn, [ops.ones((2, 3)), ops.ones((2, 3))]) + self.assertAllClose(output, 2 * np.ones((2, 3))) + + @parameterized.named_parameters( + [ + { + "testcase_name": "scalar_data_with_max", + "loop_vars": np.array(0), + "expected_output": np.array(5), + "maximum_iterations": 5, + }, + { + "testcase_name": "scalar_data_no_max", + "loop_vars": np.array(0), + "expected_output": np.array(10), + "maximum_iterations": None, + }, + { + "testcase_name": "nested_data_with_max", + "loop_vars": { + "a": np.array(0), + "b": (np.array(1), np.array(2)), + }, + "expected_output": { + "a": np.array(5), + "b": (np.array(6), np.array(7)), + }, + "maximum_iterations": 5, + }, + { + "testcase_name": "nested_data_no_max", + "loop_vars": { + "a": np.array(0), + "b": (np.array(1), np.array(2)), + }, + "expected_output": { + "a": np.array(10), + "b": (np.array(11), np.array(12)), + }, + "maximum_iterations": None, + }, + ] + ) + def test_while_loop(self, loop_vars, expected_output, maximum_iterations): + def cond(args): + return tree.flatten(args)[0] < 10 + + def body(args): + return tree.map_structure(lambda x: x + 1, args) + + output = core.while_loop( + cond, body, loop_vars, maximum_iterations=maximum_iterations + ) + tree.map_structure(self.assertAllClose, output, expected_output) + + # Test Operation call. + output = core.WhileLoop( + cond, body, maximum_iterations=maximum_iterations + )(loop_vars) + tree.map_structure(self.assertAllClose, output, expected_output) + + @parameterized.named_parameters( + [ + { + "testcase_name": "with_max", + "state": (np.array(0), np.array(1)), + "output": (np.array(5), np.array(6)), + "maximum_iterations": 5, + }, + { + "testcase_name": "no_max", + "state": (np.array(0), np.array(1)), + "output": (np.array(10), np.array(11)), + "maximum_iterations": None, + }, + ] + ) + def test_while_loop_list_data(self, state, output, maximum_iterations): + def cond(*args): + return tree.flatten(args)[0] < 10 + + def body(*args): + return tree.map_structure(lambda x: x + 1, args) + + state = core.while_loop( + cond, body, state, maximum_iterations=maximum_iterations + ) + tree.map_structure(self.assertAllClose, state, output) + + def test_unstack(self): + rng = np.random.default_rng(0) + x = rng.uniform(size=(2, 3, 4)) + x_tensor = ops.convert_to_tensor(x) + axis = 1 + out = ops.unstack(x_tensor, axis=axis) + out_ex = [x[:, i, :] for i in range(x.shape[axis])] + self.assertEqual(len(out), len(out_ex)) + for o, o_e in zip(out, out_ex): + o = ops.convert_to_numpy(o) + self.assertAllClose(o, o_e) + + # Test Operation call. + out = ops.Unstack(axis=axis)(x_tensor) + self.assertEqual(len(out), len(out_ex)) + for o, o_e in zip(out, out_ex): + o = ops.convert_to_numpy(o) + self.assertAllClose(o, o_e) + + +class CoreOpsDtypeTest(testing.TestCase): + """Test the dtype to verify that the behavior matches JAX.""" + + ALL_DTYPES = [ + x + for x in dtypes.ALLOWED_DTYPES + if x + not in ( + "string", + "complex64", + "complex128", + # Remove 64-bit dtypes. + "float64", + "uint64", + "int64", + ) + + dtypes.FLOAT8_TYPES # Remove float8 dtypes for the following tests + ] + [None] + INT_DTYPES = [x for x in dtypes.INT_TYPES if x not in ("uint64", "int64")] + FLOAT_DTYPES = [x for x in dtypes.FLOAT_TYPES if x not in ("float64",)] + + if backend.backend() == "torch": + ALL_DTYPES = [x for x in ALL_DTYPES if x not in ("uint16", "uint32")] + INT_DTYPES = [x for x in INT_DTYPES if x not in ("uint16", "uint32")] + + @parameterized.named_parameters( + named_product( + dtype=[dtype for dtype in ALL_DTYPES if dtype is not None] + ) + ) + def test_cast(self, dtype): + x = np.ones((1,)) + + self.assertDType(core.cast(x, dtype), dtype) + self.assertDType(core.Cast(dtype).symbolic_call(x), dtype) + + @parameterized.parameters( + ((), None, backend.floatx()), + ([], None, backend.floatx()), + (bool(0), None, "bool"), + (int(0), None, "int32"), + (float(0), None, backend.floatx()), + (1, "bool", "bool"), + (1.0, "int32", "int32"), + (1.0, "float32", "float32"), + ([False, True, False], None, "bool"), + ([1, 2, 3], None, "int32"), + ([1.0, 2.0, 3.0], None, backend.floatx()), + ([1, 2.0, 3], None, backend.floatx()), + ([[False], [True], [False]], None, "bool"), + ([[1], [2], [3]], None, "int32"), + ([[1], [2.0], [3]], None, backend.floatx()), + *[ + (np.array(0, dtype=dtype), None, dtype) + for dtype in ALL_DTYPES + if dtype is not None + ], + *[ + ([[1, 0, 1], [1, 1, 0]], dtype, dtype) + for dtype in ALL_DTYPES + if dtype is not None + ], + ) + def test_convert_to_tensor(self, x, dtype, expected_dtype): + self.assertDType(ops.convert_to_tensor(x, dtype=dtype), expected_dtype) + + @parameterized.named_parameters( + named_product( + dtype=[dtype for dtype in ALL_DTYPES if dtype is not None] + ) + ) + def test_convert_to_tensor_with_tensor(self, dtype): + x = ops.convert_to_tensor(np.ones((2, 3), dtype="float32")) + + self.assertDType(ops.convert_to_tensor(x, dtype=dtype), dtype) + + @parameterized.named_parameters( + named_product( + dtype=[dtype for dtype in ALL_DTYPES if dtype is not None] + ) + ) + def test_convert_to_tensor_with_variable(self, dtype): + x = backend.Variable(np.ones((2, 3), dtype="float32")) + + self.assertDType(ops.convert_to_tensor(x, dtype=dtype), dtype) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_saturate_cast(self, dtype): + x = np.ones((1,)) + + self.assertDType(core.saturate_cast(x, dtype), dtype) + self.assertDType(core.SaturateCast(dtype).symbolic_call(x), dtype) + + +class CoreOpsBehaviorTests(testing.TestCase): + def test_associative_scan_invalid_arguments(self): + # varying dimension at scan axis + x = (np.array([1, 2]), np.array([3, 4]), np.array([5, 6, 7])) + with self.assertRaisesRegex(ValueError, " first dimension"): + core.associative_scan(lambda x, y: (x[0] + y[0], x[1] + y[1]), x) + + # same error, symbolic + x = ( + KerasTensor((None, 5)), + KerasTensor((None, 4)), + ) + with self.assertRaisesRegex(ValueError, " first dimension"): + core.associative_scan( + lambda x, y: (x[0] + y[0], x[1] + y[1]), x, axis=1 + ) + + def test_cond_check_output_spec(self): + mock_spec = Mock(dtype="float32", shape=(2, 2)) + mock_spec_different = Mock(dtype="int32", shape=(3, 3)) + + # List & tuple. + self.assertTrue( + core.Cond()._check_output_spec( + [mock_spec, mock_spec], [mock_spec, mock_spec] + ) + ) + self.assertTrue( + core.Cond()._check_output_spec([mock_spec], [mock_spec]) + ) + self.assertFalse( + core.Cond()._check_output_spec( + [mock_spec], [mock_spec, mock_spec_different] + ) + ) + self.assertTrue( + core.Cond()._check_output_spec((mock_spec,), (mock_spec,)) + ) + self.assertFalse( + core.Cond()._check_output_spec( + (mock_spec,), (mock_spec, mock_spec_different) + ) + ) + + # Dict. + self.assertTrue( + core.Cond()._check_output_spec({"a": mock_spec}, {"a": mock_spec}) + ) + self.assertFalse( + core.Cond()._check_output_spec({"a": mock_spec}, {"b": mock_spec}) + ) + self.assertFalse( + core.Cond()._check_output_spec( + {"a": mock_spec}, {"a": mock_spec, "b": mock_spec} + ) + ) + + # None. + self.assertTrue(core.Cond()._check_output_spec(None, None)) + self.assertFalse( + core.Cond()._check_output_spec( + None, Mock(dtype="float32", shape=(2, 2)) + ) + ) + self.assertFalse( + core.Cond()._check_output_spec( + Mock(dtype="float32", shape=(2, 2)), None + ) + ) + + # KerasTensor. + mock_spec1 = KerasTensor(shape=(2, 2), dtype="float32") + mock_spec2 = KerasTensor(shape=(2, 2), dtype="float32") + self.assertTrue(core.Cond()._check_output_spec(mock_spec1, mock_spec2)) + + @pytest.mark.requires_trainable_backend + def test_cond_raw_bool_compile(self): + class ExampleLayer(layers.Layer): + def call(self, x, training=False): + return ops.cond(training, lambda: x, lambda: x * 2.0) + + model = models.Sequential([ExampleLayer()]) + model.compile( + optimizer=optimizers.SGD(), loss=losses.MeanSquaredError() + ) + x = np.ones((2, 4), dtype="float32") + y = np.zeros((2, 4), dtype="float32") + model.evaluate(x, y, batch_size=2) + + def test_convert_to_numpy(self): + x = ops.array([1, 2, 3], dtype="float32") + y = ops.convert_to_numpy(x) + self.assertIsInstance(y, np.ndarray) + # Test assignment -- should not fail. + y[0] = 1.0 + + with self.assertRaises(ValueError): + ops.convert_to_numpy(KerasTensor((2,))) + + def test_scan_invalid_arguments(self): + def cumsum(carry, xs): + carry = carry + xs + return carry, carry + + init = np.array(0, dtype="float32") + xs = np.array([1, 2, 3, 4, 10, 20], dtype="float32") + + # Test non-callable + with self.assertRaisesRegex(TypeError, "should be a callable."): + core.scan(123, init, xs) + + # Test bad unroll + with self.assertRaisesRegex( + ValueError, "must be an positive integer or boolean." + ): + core.scan(cumsum, init, xs, unroll=-1) + + # Test both xs and length are None + with self.assertRaisesRegex(ValueError, "to scan over and"): + core.scan(cumsum, init, xs=None, length=None) + + def test_slice_compute_output_spec(self): + inputs = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype="float32") + start_indices = np.array([1, 1]) + shape = (2, 2) + output_spec = core.Slice(shape).compute_output_spec( + inputs, start_indices + ) + self.assertEqual(output_spec.shape, shape) + self.assertEqual(output_spec.dtype, inputs.dtype) + + def test_stop_gradient_compute_output_spec(self): + variable = KerasTensor(shape=(3,), dtype="float32") + stop_gradient = core.StopGradient() + output_spec = stop_gradient.compute_output_spec(variable) + self.assertEqual(output_spec.shape, variable.shape) + self.assertEqual(output_spec.dtype, variable.dtype) + + def test_vectorized_map_serialization(self): + @object_registration.register_keras_serializable() + def f(x): + return x + x + + inputs = input_layer.Input((10,), dtype="float32") + outputs = core.vectorized_map(f, inputs) + model = models.Functional(inputs, outputs) + reloaded_model = model.from_config(model.get_config()) + x = np.random.rand(5, 10).astype("float32") + self.assertAllClose(model(x), reloaded_model(x)) + + def test_while_loop_output_spec(self): + # Define dummy cond and body functions + def cond(x): + return True + + def body(x): + return (x,) + + while_loop = core.WhileLoop(cond, body, maximum_iterations=None) + loop_vars = (KerasTensor(shape=(10,), dtype="float32"),) + output_spec = while_loop.compute_output_spec(loop_vars) + self.assertEqual(output_spec[0].shape, loop_vars[0].shape) + self.assertEqual(output_spec[0].dtype, loop_vars[0].dtype) + + # Test with KerasTensor. + loop_vars = (np.random.rand(5, 5), np.random.randint(10, size=(3, 7))) + keras_loop_vars = [ + KerasTensor(v.shape, dtype=v.dtype) for v in loop_vars + ] + while_loop = core.WhileLoop(cond, body, maximum_iterations=None) + output_specs = while_loop.compute_output_spec(keras_loop_vars) + self.assertEqual(output_specs[0].shape, keras_loop_vars[0].shape) + self.assertEqual(output_specs[0].dtype, keras_loop_vars[0].dtype) + self.assertEqual(output_specs[1].shape, keras_loop_vars[1].shape) + self.assertEqual(output_specs[1].dtype, keras_loop_vars[1].dtype) + + def test_unstack_unknown_axis_num(self): + x = KerasTensor((2, None, None)) + axis = 1 + with self.assertRaisesRegex( + ValueError, r"Cannot infer argument `num` from shape" + ): + core.unstack(x, axis=axis) diff --git a/keras/src/ops/einops.py b/keras/src/ops/einops.py new file mode 100644 index 000000000000..5c84ae8cc2b7 --- /dev/null +++ b/keras/src/ops/einops.py @@ -0,0 +1,189 @@ +import re + +from keras.src.api_export import keras_export +from keras.src.backend import KerasTensor +from keras.src.backend import any_symbolic_tensors +from keras.src.ops.core import shape +from keras.src.ops.numpy import prod +from keras.src.ops.numpy import reshape +from keras.src.ops.numpy import transpose +from keras.src.ops.operation import Operation + + +def _create_axes_map(axes, input_shape, axes_lengths): + axes_map = {} + + for axis, dim in zip(axes, input_shape): + # Check for grouped axes pattern, e.g., "(h1 h)" + grouped_axes = re.match(r"\(([\w\s]+)\)", axis) + + if grouped_axes: + inner_axes = grouped_axes.group(1).split() + known_axes = [a for a in inner_axes if a in axes_lengths] + inferred_axes = [a for a in inner_axes if a not in axes_lengths] + + if inferred_axes: + inferred_axis = inferred_axes[0] + known_product = prod([axes_lengths[a] for a in known_axes]) + axes_lengths[inferred_axis] = dim // known_product + + axes_map.update({a: axes_lengths[a] for a in inner_axes}) + else: + axes_map[axis] = dim + + return axes_map + + +def _create_grouped_axes(axes): + grouped_output_axes = [] + for axis in axes: + grouped_axes = re.match(r"\(([\w\s]+)\)", axis) + + if grouped_axes: + inner_axes = grouped_axes.group(1).split() + grouped_output_axes.append(inner_axes) + else: + grouped_output_axes.append([axis]) + + return grouped_output_axes + + +def _flatten_group(axes): + return [x for xs in axes for x in xs] + + +def _get_transpose_order(from_shape, to_shape): + flattened_from_shape = _flatten_group(_create_grouped_axes(from_shape)) + + return [flattened_from_shape.index(dim) for dim in to_shape] + + +def _compute_output_shape(axes_map, grouped_axes): + output_shape = [] + for group in grouped_axes: + size = 1 + for axis in group: + size *= axes_map[axis] + output_shape.append(size) + + return tuple(output_shape) + + +def _compute_decomposed_shape(input_axes, axes_lengths, axes_map): + reshaped_input_axes = [] + reshaped_sizes = [] + + for axis in input_axes: + if "(" in axis: # Decomposed axis + inner_axes = re.findall(r"\w+", axis) + sizes = [axes_lengths[a] for a in inner_axes] + reshaped_input_axes.extend(inner_axes) + reshaped_sizes.extend(sizes) + else: + reshaped_input_axes.append(axis) + reshaped_sizes.append(axes_map[axis]) + + return reshaped_sizes + + +class Rearrange(Operation): + def call(self, tensor, pattern, **axes_lengths): + return rearrange(tensor, pattern, **axes_lengths) + + def compute_output_spec(self, tensor, pattern, **axes_lengths): + input_pattern, output_pattern = re.split(r"\s*->\s*", pattern) + input_axes = re.findall(r"\w+|\(.*?\)", input_pattern) + output_axes = re.findall(r"\w+|\(.*?\)", output_pattern) + input_shape = shape(tensor) + + axes_map = _create_axes_map(input_axes, input_shape, axes_lengths) + grouped_output_axes = _create_grouped_axes(output_axes) + output_shape = _compute_output_shape(axes_map, grouped_output_axes) + + return KerasTensor(shape=output_shape, dtype=tensor.dtype) + + +@keras_export("keras.ops.rearrange") +def rearrange(tensor, pattern, **axes_lengths): + """Rearranges the axes of a Keras tensor according to a specified pattern, + einops-style. + + Args: + tensor: Input Keras tensor. + pattern: String describing the rearrangement in einops notation. + **axes_lengths: Keyword arguments specifying lengths of axes + when axes decomposition is used. + + Returns: + Tensor: A Keras tensor with rearranged axes. + + Follows the logic of: + + 1. If decomposition is needed, reshape to match decomposed dimensions. + 2. Permute known and inferred axes to match the form of the output. + 3. Reshape to match the desired output shape. + + + Example Usage: + + ``` + >>> import numpy as np + >>> from keras.ops import rearrange + >>> images = np.random.rand(32, 30, 40, 3) # BHWC format + + # Reordering to BCHW + >>> rearrange(images, 'b h w c -> b c h w').shape + TensorShape([32, 3, 30, 40]) + + # "Merge" along first axis - concat images from a batch + >>> rearrange(images, 'b h w c -> (b h) w c').shape + TensorShape([960, 40, 3]) + + # "Merge" along second axis - concat images horizontally + >>> rearrange(images, 'b h w c -> h (b w) c').shape + TensorShape([30, 1280, 3]) + + # Flatten images into a CHW vector + >>> rearrange(images, 'b h w c -> b (c h w)').shape + TensorShape([32, 3600]) + + # Decompose H and W axes into 4 smaller patches + >>> rearrange(images, 'b (h1 h) (w1 w) c -> (b h1 w1) h w c', h1=2, w1=2).shape + TensorShape([128, 15, 20, 3]) + + # Space-to-depth decomposition of input axes + >>> rearrange(images, 'b (h h1) (w w1) c -> b h w (c h1 w1)', h1=2, w1=2).shape + TensorShape([32, 15, 20, 12]) + ``` + """ # noqa: E501 + + if any_symbolic_tensors((tensor,)): + return Rearrange().symbolic_call(tensor, pattern, **axes_lengths) + + # Split the input and output patterns + input_pattern, output_pattern = re.split(r"\s*->\s*", pattern) + input_axes = re.findall(r"\w+|\(.*?\)", input_pattern) + output_axes = re.findall(r"\w+|\(.*?\)", output_pattern) + input_shape = shape(tensor) + + # Create axes map, and flattened output group + axes_map = _create_axes_map(input_axes, input_shape, axes_lengths) + grouped_output_axes = _create_grouped_axes(output_axes) + flattened_output_axes = _flatten_group(grouped_output_axes) + + # 1. Axes decomposition + decomposed_shapes = _compute_decomposed_shape( + input_axes, axes_lengths, axes_map + ) + if decomposed_shapes != tensor.shape: + tensor = reshape(tensor, decomposed_shapes) + + # 2. Transpose to match target shape + permute_order = _get_transpose_order(input_axes, flattened_output_axes) + tensor = transpose(tensor, permute_order) + + # 3. Reshape to final target shape + output_shape = _compute_output_shape(axes_map, grouped_output_axes) + tensor = reshape(tensor, output_shape) + + return tensor diff --git a/keras/src/ops/einops_test.py b/keras/src/ops/einops_test.py new file mode 100644 index 000000000000..c7963e9c35ec --- /dev/null +++ b/keras/src/ops/einops_test.py @@ -0,0 +1,51 @@ +from conftest import skip_if_backend +from keras.src import ops +from keras.src import testing +from keras.src.backend.common import keras_tensor +from keras.src.ops.einops import rearrange + + +class RearrangeTest(testing.TestCase): + def test_basic_rearrangement_symbolic(self): + x = keras_tensor.KerasTensor((2, 3, 4)) + y = rearrange(x, "b c h -> b h c") + self.assertIsInstance(y, keras_tensor.KerasTensor) + self.assertEqual(y.shape, (2, 4, 3)) + + @skip_if_backend("openvino", "Test operation not supported by openvino") + def test_basic_rearrangement(self): + x = ops.random.uniform((2, 3, 4)) + y = rearrange(x, "b c h -> b h c") + self.assertEqual(y.shape, (2, 4, 3)) + self.assertTrue(ops.all(ops.equal(y, ops.transpose(x, (0, 2, 1))))) + + @skip_if_backend("openvino", "Test operation not supported by openvino") + def test_output_composition(self): + x = ops.random.uniform((2, 4, 4, 3)) + y = rearrange(x, "b h w c -> (b h) w c") + target_shape = (8, 4, 3) + self.assertEqual(y.shape, target_shape) + self.assertTrue(ops.all(ops.equal(y, ops.reshape(x, (8, 4, 3))))) + + def test_basic_decomposition_and_rearrangement_symbolic(self): + x = keras_tensor.KerasTensor((6, 8)) + y = rearrange(x, "(h w) c -> h w c", h=2, w=3) + self.assertIsInstance(y, keras_tensor.KerasTensor) + self.assertEqual(y.shape, (2, 3, 8)) + + def test_basic_decomposition_and_rearrangement(self): + x = ops.random.uniform((6, 8)) + y = rearrange(x, "(h w) c -> h w c", h=2, w=3) + self.assertEqual(y.shape, (2, 3, 8)) + + @skip_if_backend("openvino", "Test operation not supported by openvino") + def test_unchanged_shape(self): + x = ops.ones([2, 3, 4]) + y = rearrange(x, "b h c -> b h c") + self.assertTrue(ops.all(ops.equal(y, x))) + self.assertTrue(x.shape, y.shape) + + def test_unchanged_shape_symbolic(self): + x = keras_tensor.KerasTensor((2, 3, 4)) + y = rearrange(x, "b h c -> b h c") + self.assertTrue(x.shape, y.shape) diff --git a/keras/src/ops/function.py b/keras/src/ops/function.py new file mode 100644 index 000000000000..abac0820644f --- /dev/null +++ b/keras/src/ops/function.py @@ -0,0 +1,458 @@ +import collections + +from keras.src import tree +from keras.src.api_export import keras_export +from keras.src.backend import KerasTensor +from keras.src.backend.config import backend +from keras.src.backend.config import is_nnx_enabled +from keras.src.ops.operation import Operation + + +@keras_export("keras.Function") +class Function(Operation): + """Class that encapsulates a computation graph of Keras operations. + + You can use a `Function` to capture the computation graph linking + some input tensors to some output tensors, and reapply the same + computation on new inputs. + + A `Function` is similar to a Functional Model, with the difference + that it is stateless (it does not track state variables) + and does not implement the `Layer` API. + + Example: + + ```python + input_1 = keras.KerasTensor(shape=(None, 2, 3)) + input_2 = keras.KerasTensor(shape=(None, 2, 3)) + x = input_1 + input_2 + output = keras.ops.sigmoid(x) + fn = keras.Function(inputs=[input_1, input_2], outputs=output) + + input_1_val = np.random.random((4, 2, 3)) + input_2_val = np.random.random((4, 2, 3)) + output_val = fn([input_1_val, input_2_val]) + ``` + + Args: + inputs: `KerasTensor` instance or nested structured of + `KerasTensor` instances. + outputs: `KerasTensor` instance or nested structured of + `KerasTensor` instances. They should be computable + given only the values of `inputs`. + name: String. The name of the function. + """ + + def __init__(self, inputs, outputs, name=None): + super().__init__(name=name) + + if backend() == "tensorflow": + # Temporary work around for + # https://github.com/keras-team/keras/issues/931 + # This stop tensorflow from wrapping tf.function output in a + # _DictWrapper object. + _self_setattr_tracking = getattr( + self, "_self_setattr_tracking", True + ) + self._self_setattr_tracking = False + self._inputs_struct = tree.map_structure(lambda x: x, inputs) + self._outputs_struct = tree.map_structure(lambda x: x, outputs) + self._inputs = tree.flatten(inputs) + self._outputs = tree.flatten(outputs) + if not self._inputs: + raise ValueError( + "`inputs` argument cannot be empty. Received:\n" + f"inputs={inputs}\n" + f"outputs={outputs}" + ) + if not self._outputs: + raise ValueError( + "`outputs` argument cannot be empty. Received:\n" + f"inputs={inputs}\n" + f"outputs={outputs}" + ) + + if backend() == "tensorflow": + self._self_setattr_tracking = _self_setattr_tracking + + (nodes, nodes_by_depth, operations, operations_by_depth) = map_graph( + self._inputs, self._outputs + ) + self._nodes = nodes + self._nodes_by_depth = nodes_by_depth + self._operations = operations + self._operations_by_depth = operations_by_depth + for input in self._inputs: + if ( + input._keras_history.operation + and not input._keras_history.operation._outbound_nodes + ): + raise ValueError("`inputs` not connected to `outputs`") + + # Special handling for NNX to ensure consistent operation instance usage + if is_nnx_enabled(): + self._setup_nnx_op_mapping() + + @property + def operations(self): + return self._operations[:] + + @property + def inputs(self): + """Flat list of the symbolic inputs of the Function.""" + return self._inputs + + @property + def outputs(self): + """Flat list of the symbolic outputs of the Function.""" + return self._outputs + + def _setup_nnx_op_mapping(self): + """Setup operation mapping for NNX""" + # Create a mapping from operation id to operation instance + self._nnx_op_mapping = {} + + # Assign the list of operations to a single attribute for NNX traversal + self.nnx_operations = self._operations[:] + for operation in self._operations: + # Map the operation id to this operation instance + self._nnx_op_mapping[id(operation)] = operation + + def _get_operation_for_node(self, node): + """Get the operation for a node, using NNX mapping if enabled.""" + operation = node.operation + if hasattr(self, "_nnx_op_mapping") and id(operation) in getattr( + self, "_nnx_op_mapping", {} + ): + return self._nnx_op_mapping[id(operation)] + return operation + + def compute_output_spec(self, inputs): + self._assert_input_compatibility(inputs) + # Check if input shapes are identical to ref input shapes, + # if so take a shortcut. + shortcut = True + for x, x_ref in zip(tree.flatten(inputs), self._inputs): + if x.shape != x_ref.shape: + shortcut = False + break + if shortcut: + return tree.map_structure( + lambda x: KerasTensor(shape=x.shape, dtype=x.dtype), + self._outputs_struct, + ) + # No luck; take the long road through the graph. + # Original Keras used a cache to avoid recomputing all this + # when known input shapes where seen again. Perhaps a good + # idea to bring that back. + return self._run_through_graph( + inputs, operation_fn=lambda op: op.compute_output_spec + ) + + def compute_output_shape(self, input_shape): + # Wrap `input_shape` into the structure of KerasTensor to utilize + # `compute_output_spec`. + input_shape_struct = tree.map_shape_structure( + lambda x: KerasTensor(shape=x), input_shape + ) + # Ensure that dtype and sparse settings are the same as self._inputs, + # because we only care about the shape in this function. + for x, x_ref in zip(tree.flatten(input_shape_struct), self._inputs): + x._dtype = x_ref.dtype + x._sparse = x_ref.sparse + output_spec = self.compute_output_spec(input_shape_struct) + return tree.map_structure(lambda x: x.shape, output_spec) + + def call(self, inputs): + """Computes output tensors for new inputs.""" + self._assert_input_compatibility(inputs) + return self._run_through_graph(inputs, operation_fn=lambda op: op) + + def _run_through_graph(self, inputs, operation_fn, call_fn=None): + """Execute the graph. + + At each node we compute outputs via + `operation_fn(node.operation)(*args, **kwargs)`. + """ + inputs = tree.flatten(inputs) + + # Dictionary mapping reference tensors to computed tensors. + tensor_dict = {} + for x, y in zip(self.inputs, inputs): + tensor_dict[id(x)] = y + + nodes_by_depth = self._nodes_by_depth + depth_keys = list(nodes_by_depth.keys()) + depth_keys.sort(reverse=True) + + for depth in depth_keys: + nodes = nodes_by_depth[depth] + for node in nodes: + if not node.operation or node.is_input: + continue # Input tensors already exist. + + if any(id(x) not in tensor_dict for x in node.input_tensors): + continue # Node is not computable, try skipping. + + args, kwargs = node.arguments.fill_in(tensor_dict) + if call_fn is not None: + # Use call_fn if provided (e.g., for symbolic execution) + op = operation_fn(node.operation) + outputs = call_fn(op, *args, **kwargs) + else: + # Use NNX operation mapping + operation = self._get_operation_for_node(node) + op = operation_fn(operation) + outputs = op(*args, **kwargs) + + # Update tensor_dict. + for x, y in zip(node.outputs, tree.flatten(outputs)): + tensor_dict[id(x)] = y + + output_tensors = [] + for x in self.outputs: + output_tensors.append(tensor_dict[id(x)]) + + return tree.pack_sequence_as(self._outputs_struct, output_tensors) + + def _assert_input_compatibility(self, inputs): + try: + tree.assert_same_structure(inputs, self._inputs_struct) + except ValueError: + raise ValueError( + "Function was called with an invalid input structure. " + f"Expected input structure: {self._inputs_struct}\n" + f"Received input structure: {inputs}" + ) + for x, x_ref in zip(tree.flatten(inputs), self._inputs): + if len(x.shape) != len(x_ref.shape): + raise ValueError( + f"{self.__class__.__name__} was passed " + f"incompatible inputs. For input '{x_ref.name}', " + f"expected shape {x_ref.shape}, but received " + f"instead a tensor with shape {x.shape}." + ) + for dim, ref_dim in zip(x.shape, x_ref.shape): + if ref_dim is not None and dim is not None: + if dim != ref_dim: + raise ValueError( + f"{self.__class__.__name__} was passed " + f"incompatible inputs. For input '{x_ref.name}', " + f"expected shape {x_ref.shape}, but received " + f"instead a tensor with shape {x.shape}." + ) + + +def make_node_key(op, node_index): + return f"{id(op)}_ib-{node_index}" + + +def map_graph(inputs, outputs): + """Validates a graph's topology and gather its operations and nodes. + + Args: + inputs: List of input tensors. + outputs: List of outputs tensors. + + Returns: + A tuple `(nodes, nodes_by_depth, operations, operations_by_depth)`. + - nodes: set of Node instances + - nodes_by_depth: dict mapping ints (depth) to lists of node instances. + - operations: list of Operation instances. + - operations_by_depth: dict mapping ints (depth) to lists of Operation + instances. + """ + # "depth" is number of operations between output Node and the Node. + # Nodes are ordered from inputs -> outputs. + nodes_in_decreasing_depth, operation_indices = _build_map(inputs, outputs) + network_nodes = { + make_node_key(node.operation, node.operation._inbound_nodes.index(node)) + for node in nodes_in_decreasing_depth + } + + nodes_depths = {} # dict {node: depth value} + operations_depths = {} # dict {operation: depth value} + + for node in reversed(nodes_in_decreasing_depth): + # If the depth is not set, the node has no outbound nodes (depth 0). + depth = nodes_depths.setdefault(node, 0) + + # Update the depth of the corresponding operation + previous_depth = operations_depths.get(node.operation, 0) + # If we've seen this operation before at a higher depth, + # we should use that depth instead of the node depth. + # This is necessary for shared operations that have inputs at different + # depth levels in the graph. + depth = max(depth, previous_depth) + operations_depths[node.operation] = depth + nodes_depths[node] = depth + + # Update the depth of inbound nodes. + # The "depth" of a node is the max of the depths + # of all nodes it is connected to + 1. + for node_dep in node.parent_nodes: + previous_depth = nodes_depths.get(node_dep, 0) + nodes_depths[node_dep] = max(depth + 1, previous_depth) + + # Handle inputs that are not connected to outputs. + # We do not error out here because the inputs may be used to compute losses + # and metrics. + for input_t in inputs: + input_operation = input_t._keras_history[0] + if input_operation and input_operation not in operations_depths: + operations_depths[input_operation] = 0 + operation_indices[input_operation] = -1 + nodes_depths[input_operation._inbound_nodes[0]] = 0 + network_nodes.add(make_node_key(input_operation, 0)) + + # Build a dict {depth: list of nodes with this depth} + nodes_by_depth = collections.defaultdict(list) + for node, depth in nodes_depths.items(): + nodes_by_depth[depth].append(node) + + # Build a dict {depth: list of operations with this depth} + operations_by_depth = collections.defaultdict(list) + for operation, depth in operations_depths.items(): + operations_by_depth[depth].append(operation) + + # Get sorted list of operation depths. + depth_keys = list(operations_by_depth.keys()) + depth_keys.sort(reverse=True) + + # Set self.operations ordered by depth. + operations = [] + for depth in depth_keys: + operations_for_depth = operations_by_depth[depth] + # Network.operations needs to have a deterministic order: + # here we order them by traversal order. + operations_for_depth.sort(key=lambda x: operation_indices[x]) + operations.extend(operations_for_depth) + + # Get sorted list of node depths. + depth_keys = list(nodes_by_depth.keys()) + depth_keys.sort(reverse=True) + + # Check that all tensors required are computable. + # computable_tensors: all tensors in the graph + # that can be computed from the inputs provided. + computable_tensors = set() + for x in inputs: + computable_tensors.add(x) + + operations_with_complete_input = [] # To provide a better error msg. + for depth in depth_keys: + for node in nodes_by_depth[depth]: + for x in tree.flatten(node.input_tensors): + if x not in computable_tensors: + operation = node.operation + raise ValueError( + "Graph disconnected: cannot find parent for " + f"tensor {x} at operation '{operation}'. " + "The following previous operations were accessed " + f"without issue: {operations_with_complete_input}" + ) + operations_with_complete_input.append(node.operation.name) + + for x in tree.flatten(node.outputs): + computable_tensors.add(x) + + # Ensure name unicity, which will be crucial for serialization + # (since serialized nodes refer to operations by their name). + all_names = [operation.name for operation in operations] + for name in all_names: + if all_names.count(name) != 1: + raise ValueError( + f'The name "{name}" is used {all_names.count(name)} ' + "times in the model. All operation names should be unique." + ) + return network_nodes, nodes_by_depth, operations, operations_by_depth + + +def _build_map(inputs, outputs): + """Topologically sort nodes in order from inputs to outputs. + + It uses a depth-first search to topologically sort nodes that appear in the + _keras_history connectivity metadata of `outputs`. + + Args: + outputs: the output tensors whose _keras_history metadata should be + walked. This may be an arbitrary nested structure. + + Returns: + A tuple like (ordered_nodes, operation_to_first_traversal_index) + ordered_nodes: list of nodes appearing in the keras history, + topologically sorted from original inputs to the `outputs`. + (If outputs have different sets of ancestors, the inputs to one + output may appear after a different output). + operation_to_first_traversal_index: + A dict mapping operation to the traversal index in the DFS where it + is seen. Note: if a operation is shared by several nodes, the dict + will onlystore the index corresponding to the *first* time the + operation seen. + """ + finished_nodes = set() + nodes_in_progress = set() + nodes_in_decreasing_depth = [] # nodes from inputs -> outputs. + operation_indices = {} # operation -> in traversal order. + for output in tree.flatten(outputs): + _build_map_helper( + inputs, + output, + finished_nodes, + nodes_in_progress, + nodes_in_decreasing_depth, + operation_indices, + ) + return nodes_in_decreasing_depth, operation_indices + + +def _build_map_helper( + inputs, + tensor, + finished_nodes, + nodes_in_progress, + nodes_in_decreasing_depth, + operation_indices, +): + """Recursive helper for `_build_map`.""" + ( + operation, + node_index, + _, + ) = tensor._keras_history + if not operation: + return + + node = operation._inbound_nodes[node_index] + + # Don't repeat work for shared subgraphs + if node in finished_nodes: + return + + # Prevent cycles. + if node in nodes_in_progress: + raise ValueError( + f"Tensor {tensor} from operation '{operation.name}' is part of a " + "cycle." + ) + + # Store the traversal order for operation sorting. + if operation not in operation_indices: + operation_indices[operation] = len(operation_indices) + + # Propagate to all previous tensors connected to this node. + nodes_in_progress.add(node) + if not node.is_input and tensor not in tree.flatten(inputs): + for tensor in node.input_tensors: + _build_map_helper( + inputs, + tensor, + finished_nodes, + nodes_in_progress, + nodes_in_decreasing_depth, + operation_indices, + ) + + finished_nodes.add(node) + nodes_in_progress.remove(node) + nodes_in_decreasing_depth.append(node) diff --git a/keras/src/ops/function_test.py b/keras/src/ops/function_test.py new file mode 100644 index 000000000000..ea6c3dcf8d79 --- /dev/null +++ b/keras/src/ops/function_test.py @@ -0,0 +1,168 @@ +import json + +import numpy as np + +from keras.src import testing +from keras.src.backend.common import keras_tensor +from keras.src.layers import Dense +from keras.src.layers import Input +from keras.src.models import Model +from keras.src.models import Sequential +from keras.src.ops import function +from keras.src.ops import numpy as knp + + +class FunctionTest(testing.TestCase): + def test_define_and_call(self): + x1 = keras_tensor.KerasTensor((2, 3)) + x2 = keras_tensor.KerasTensor((2, 3)) + x = knp.add(x1, x2) + y1 = x * 3 + y2 = x**2 + fn = function.Function( + inputs=[x1, x2], outputs=[y1, y2], name="test_function" + ) + self.assertEqual(fn.name, "test_function") + + # Eager call + y_val = fn([np.ones((2, 3)), np.ones((2, 3))]) + self.assertIsInstance(y_val, list) + self.assertAllClose(y_val[0], np.ones((2, 3)) * 6) + self.assertAllClose(y_val[1], np.ones((2, 3)) * 4) + + # Symbolic call + x1_alt = keras_tensor.KerasTensor((2, 3)) + x2_alt = keras_tensor.KerasTensor((2, 3)) + y_val = fn([x1_alt, x2_alt]) + self.assertIsInstance(y_val[0], keras_tensor.KerasTensor) + self.assertEqual(y_val[0].shape, (2, 3)) + self.assertIsInstance(y_val[1], keras_tensor.KerasTensor) + self.assertEqual(y_val[1].shape, (2, 3)) + + # Recursion + fn = function.Function(inputs=[x1_alt, x2_alt], outputs=y_val) + y_val = fn([np.ones((2, 3)), np.ones((2, 3))]) + self.assertIsInstance(y_val, list) + self.assertAllClose(y_val[0], np.ones((2, 3)) * 6) + self.assertAllClose(y_val[1], np.ones((2, 3)) * 4) + + def test_dynamic_shape_inference(self): + x = keras_tensor.KerasTensor((None, 3)) + y = x**2 + fn = function.Function(x, y) + + # Test with compute_output_spec + out = fn.compute_output_spec(keras_tensor.KerasTensor((4, 3))) + self.assertIsInstance(out, keras_tensor.KerasTensor) + self.assertEqual(out.shape, (4, 3)) + + # Test with compute_output_shape + out = fn.compute_output_shape((None, 3)) + self.assertIsInstance(out, tuple) + self.assertEqual(out, (None, 3)) + + # Test with call + out = fn(keras_tensor.KerasTensor((4, 3))) + self.assertIsInstance(out, keras_tensor.KerasTensor) + self.assertEqual(out.shape, (4, 3)) + + def test_dict_io(self): + x1 = keras_tensor.KerasTensor((2, 3)) + x2 = keras_tensor.KerasTensor((2, 3)) + x = knp.add(x1, x2) + y1 = x * 3 + y2 = x**2 + fn = function.Function( + inputs={"x1": x1, "x2": x2}, outputs={"y1": y1, "y2": y2} + ) + + # Eager call + y_val = fn({"x1": np.ones((2, 3)), "x2": np.ones((2, 3))}) + self.assertIsInstance(y_val, dict) + self.assertAllClose(y_val["y1"], np.ones((2, 3)) * 6) + self.assertAllClose(y_val["y2"], np.ones((2, 3)) * 4) + + # Symbolic call + x1_alt = keras_tensor.KerasTensor((2, 3)) + x2_alt = keras_tensor.KerasTensor((2, 3)) + y_val = fn({"x1": x1_alt, "x2": x2_alt}) + self.assertIsInstance(y_val["y1"], keras_tensor.KerasTensor) + self.assertEqual(y_val["y1"].shape, (2, 3)) + self.assertIsInstance(y_val["y2"], keras_tensor.KerasTensor) + self.assertEqual(y_val["y2"].shape, (2, 3)) + + def test_invalid_inputs_error(self): + x1 = keras_tensor.KerasTensor((2, 3)) + x2 = keras_tensor.KerasTensor((2, 3)) + x = knp.add(x1, x2) + y1 = x * 3 + y2 = x**2 + fn = function.Function( + inputs=[x1, x2], outputs=[y1, y2], name="test_function" + ) + self.assertEqual(fn.name, "test_function") + + # Bad structure + with self.assertRaisesRegex(ValueError, "invalid input structure"): + _ = fn(np.ones((2, 3))) + + # Bad rank + with self.assertRaisesRegex(ValueError, "incompatible inputs"): + _ = fn([np.ones((2, 3, 3)), np.ones((2, 3))]) + + # Bad shape + with self.assertRaisesRegex(ValueError, "incompatible inputs"): + _ = fn([np.ones((4, 3)), np.ones((2, 3))]) + + def test_graph_disconnected_error(self): + # TODO + pass + + def test_serialization(self): + inputs = Input(shape=(10,)) + outputs = Dense(1)(inputs) + model = Model(inputs=inputs, outputs=outputs) + + config = model.get_config() + new_model = Model.from_config(config) + + self.assertEqual( + json.dumps(model.get_config()), json.dumps(new_model.get_config()) + ) + + def test_function_with_empty_outputs(self): + x = keras_tensor.KerasTensor((None, 3)) + with self.assertRaisesRegex( + ValueError, "`outputs` argument cannot be empty" + ): + _ = function.Function(inputs=x, outputs=[]) + + def test_function_with_empty_inputs(self): + x = keras_tensor.KerasTensor((None, 3)) + with self.assertRaisesRegex( + ValueError, "`inputs` argument cannot be empty" + ): + _ = function.Function(inputs=[], outputs=x) + + def test_function_with_unconnected_inputs(self): + model_1 = Sequential( + [ + Input(shape=(6,)), + Dense(3, activation="sigmoid"), + ] + ) + model_2 = Sequential( + [ + Input(shape=(3,)), + Dense(2, activation="sigmoid"), + ], + ) + with self.assertRaisesRegex( + ValueError, "`inputs` not connected to `outputs`" + ): + _ = Model(Input(shape=(6,)), model_2(model_1(Input(shape=(6,))))) + + with self.assertRaisesRegex( + ValueError, "`inputs` not connected to `outputs`" + ): + _ = Model(model_1(Input(shape=(6,))), model_2(Input(shape=(3,)))) diff --git a/keras/src/ops/image.py b/keras/src/ops/image.py new file mode 100644 index 000000000000..61993aa54873 --- /dev/null +++ b/keras/src/ops/image.py @@ -0,0 +1,1714 @@ +from keras.src import backend +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.backend import KerasTensor +from keras.src.backend import any_symbolic_tensors +from keras.src.ops.operation import Operation +from keras.src.ops.operation_utils import compute_conv_output_shape + + +class RGBToGrayscale(Operation): + def __init__(self, data_format=None, *, name=None): + super().__init__(name=name) + self.data_format = backend.standardize_data_format(data_format) + + def call(self, images): + return backend.image.rgb_to_grayscale( + images, data_format=self.data_format + ) + + def compute_output_spec(self, images): + images_shape = list(images.shape) + if len(images_shape) not in (3, 4): + raise ValueError( + "Invalid images rank: expected rank 3 (single image) " + "or rank 4 (batch of images). " + f"Received: images.shape={images_shape}" + ) + if self.data_format == "channels_last": + images_shape[-1] = 1 + else: + images_shape[-3] = 1 + return KerasTensor(shape=images_shape, dtype=images.dtype) + + +@keras_export("keras.ops.image.rgb_to_grayscale") +def rgb_to_grayscale(images, data_format=None): + """Convert RGB images to grayscale. + + This function converts RGB images to grayscale images. It supports both + 3D and 4D tensors. + + Args: + images: Input image or batch of images. Must be 3D or 4D. + data_format: A string specifying the data format of the input tensor. + It can be either `"channels_last"` or `"channels_first"`. + `"channels_last"` corresponds to inputs with shape + `(batch, height, width, channels)`, while `"channels_first"` + corresponds to inputs with shape `(batch, channels, height, width)`. + If not specified, the value will default to + `keras.config.image_data_format`. + + Returns: + Grayscale image or batch of grayscale images. + + Examples: + + >>> import numpy as np + >>> from keras import ops + >>> x = np.random.random((2, 4, 4, 3)) + >>> y = ops.image.rgb_to_grayscale(x) + >>> y.shape + (2, 4, 4, 1) + + >>> x = np.random.random((4, 4, 3)) # Single RGB image + >>> y = ops.image.rgb_to_grayscale(x) + >>> y.shape + (4, 4, 1) + + >>> x = np.random.random((2, 3, 4, 4)) + >>> y = ops.image.rgb_to_grayscale(x, data_format="channels_first") + >>> y.shape + (2, 1, 4, 4) + """ + if any_symbolic_tensors((images,)): + return RGBToGrayscale(data_format=data_format).symbolic_call(images) + return backend.image.rgb_to_grayscale(images, data_format=data_format) + + +class RGBToHSV(Operation): + def __init__(self, data_format=None, *, name=None): + super().__init__(name=name) + self.data_format = backend.standardize_data_format(data_format) + + def call(self, images): + return backend.image.rgb_to_hsv(images, data_format=self.data_format) + + def compute_output_spec(self, images): + images_shape = list(images.shape) + dtype = images.dtype + if len(images_shape) not in (3, 4): + raise ValueError( + "Invalid images rank: expected rank 3 (single image) " + "or rank 4 (batch of images). " + f"Received: images.shape={images_shape}" + ) + if not backend.is_float_dtype(dtype): + raise ValueError( + "Invalid images dtype: expected float dtype. " + f"Received: images.dtype={dtype}" + ) + return KerasTensor(shape=images_shape, dtype=images.dtype) + + +@keras_export("keras.ops.image.rgb_to_hsv") +def rgb_to_hsv(images, data_format=None): + """Convert RGB images to HSV. + + `images` must be of float dtype, and the output is only well defined if the + values in `images` are in `[0, 1]`. + + All HSV values are in `[0, 1]`. A hue of `0` corresponds to pure red, `1/3` + is pure green, and `2/3` is pure blue. + + Args: + images: Input image or batch of images. Must be 3D or 4D. + data_format: A string specifying the data format of the input tensor. + It can be either `"channels_last"` or `"channels_first"`. + `"channels_last"` corresponds to inputs with shape + `(batch, height, width, channels)`, while `"channels_first"` + corresponds to inputs with shape `(batch, channels, height, width)`. + If not specified, the value will default to + `keras.config.image_data_format`. + + Returns: + HSV image or batch of HSV images. + + Examples: + + >>> import numpy as np + >>> from keras import ops + >>> x = np.random.random((2, 4, 4, 3)) + >>> y = ops.image.rgb_to_hsv(x) + >>> y.shape + (2, 4, 4, 3) + + >>> x = np.random.random((4, 4, 3)) # Single RGB image + >>> y = ops.image.rgb_to_hsv(x) + >>> y.shape + (4, 4, 3) + + >>> x = np.random.random((2, 3, 4, 4)) + >>> y = ops.image.rgb_to_hsv(x, data_format="channels_first") + >>> y.shape + (2, 3, 4, 4) + """ + if any_symbolic_tensors((images,)): + return RGBToHSV(data_format=data_format).symbolic_call(images) + return backend.image.rgb_to_hsv(images, data_format=data_format) + + +class HSVToRGB(Operation): + def __init__(self, data_format=None, *, name=None): + super().__init__(name=name) + self.data_format = backend.standardize_data_format(data_format) + + def call(self, images): + return backend.image.hsv_to_rgb(images, data_format=self.data_format) + + def compute_output_spec(self, images): + images_shape = list(images.shape) + dtype = images.dtype + if len(images_shape) not in (3, 4): + raise ValueError( + "Invalid images rank: expected rank 3 (single image) " + "or rank 4 (batch of images). " + f"Received: images.shape={images_shape}" + ) + if not backend.is_float_dtype(dtype): + raise ValueError( + "Invalid images dtype: expected float dtype. " + f"Received: images.dtype={dtype}" + ) + return KerasTensor(shape=images_shape, dtype=images.dtype) + + +@keras_export("keras.ops.image.hsv_to_rgb") +def hsv_to_rgb(images, data_format=None): + """Convert HSV images to RGB. + + `images` must be of float dtype, and the output is only well defined if the + values in `images` are in `[0, 1]`. + + Args: + images: Input image or batch of images. Must be 3D or 4D. + data_format: A string specifying the data format of the input tensor. + It can be either `"channels_last"` or `"channels_first"`. + `"channels_last"` corresponds to inputs with shape + `(batch, height, width, channels)`, while `"channels_first"` + corresponds to inputs with shape `(batch, channels, height, width)`. + If not specified, the value will default to + `keras.config.image_data_format`. + + Returns: + RGB image or batch of RGB images. + + Examples: + + >>> import numpy as np + >>> from keras import ops + >>> x = np.random.random((2, 4, 4, 3)) + >>> y = ops.image.hsv_to_rgb(x) + >>> y.shape + (2, 4, 4, 3) + + >>> x = np.random.random((4, 4, 3)) # Single HSV image + >>> y = ops.image.hsv_to_rgb(x) + >>> y.shape + (4, 4, 3) + + >>> x = np.random.random((2, 3, 4, 4)) + >>> y = ops.image.hsv_to_rgb(x, data_format="channels_first") + >>> y.shape + (2, 3, 4, 4) + """ + if any_symbolic_tensors((images,)): + return HSVToRGB(data_format=data_format).symbolic_call(images) + return backend.image.hsv_to_rgb(images, data_format=data_format) + + +class Resize(Operation): + def __init__( + self, + size, + interpolation="bilinear", + antialias=False, + crop_to_aspect_ratio=False, + pad_to_aspect_ratio=False, + fill_mode="constant", + fill_value=0.0, + data_format=None, + *, + name=None, + ): + super().__init__(name=name) + self.size = tuple(size) + self.interpolation = interpolation + self.antialias = antialias + self.crop_to_aspect_ratio = crop_to_aspect_ratio + self.pad_to_aspect_ratio = pad_to_aspect_ratio + self.fill_mode = fill_mode + self.fill_value = fill_value + self.data_format = backend.standardize_data_format(data_format) + + def call(self, images): + return _resize( + images, + self.size, + interpolation=self.interpolation, + antialias=self.antialias, + data_format=self.data_format, + crop_to_aspect_ratio=self.crop_to_aspect_ratio, + pad_to_aspect_ratio=self.pad_to_aspect_ratio, + fill_mode=self.fill_mode, + fill_value=self.fill_value, + ) + + def compute_output_spec(self, images): + images_shape = list(images.shape) + if len(images_shape) not in (3, 4): + raise ValueError( + "Invalid images rank: expected rank 3 (single image) " + "or rank 4 (batch of images). Received input with shape: " + f"images.shape={images.shape}" + ) + if self.data_format == "channels_last": + height_axis, width_axis = -3, -2 + else: + height_axis, width_axis = -2, -1 + images_shape[height_axis] = self.size[0] + images_shape[width_axis] = self.size[1] + return KerasTensor(shape=images_shape, dtype=images.dtype) + + +@keras_export("keras.ops.image.resize") +def resize( + images, + size, + interpolation="bilinear", + antialias=False, + crop_to_aspect_ratio=False, + pad_to_aspect_ratio=False, + fill_mode="constant", + fill_value=0.0, + data_format=None, +): + """Resize images to size using the specified interpolation method. + + Args: + images: Input image or batch of images. Must be 3D or 4D. + size: Size of output image in `(height, width)` format. + interpolation: Interpolation method. Available methods are `"nearest"`, + `"bilinear"`, and `"bicubic"`. Defaults to `"bilinear"`. + antialias: Whether to use an antialiasing filter when downsampling an + image. Defaults to `False`. + crop_to_aspect_ratio: If `True`, resize the images without aspect + ratio distortion. When the original aspect ratio differs + from the target aspect ratio, the output image will be + cropped so as to return the + largest possible window in the image (of size `(height, width)`) + that matches the target aspect ratio. By default + (`crop_to_aspect_ratio=False`), aspect ratio may not be preserved. + pad_to_aspect_ratio: If `True`, pad the images without aspect + ratio distortion. When the original aspect ratio differs + from the target aspect ratio, the output image will be + evenly padded on the short side. + fill_mode: When using `pad_to_aspect_ratio=True`, padded areas + are filled according to the given mode. Only `"constant"` is + supported at this time + (fill with constant value, equal to `fill_value`). + fill_value: Float. Padding value to use when `pad_to_aspect_ratio=True`. + data_format: A string specifying the data format of the input tensor. + It can be either `"channels_last"` or `"channels_first"`. + `"channels_last"` corresponds to inputs with shape + `(batch, height, width, channels)`, while `"channels_first"` + corresponds to inputs with shape `(batch, channels, height, width)`. + If not specified, the value will default to + `keras.config.image_data_format`. + + Returns: + Resized image or batch of images. + + Examples: + + >>> x = np.random.random((2, 4, 4, 3)) # batch of 2 RGB images + >>> y = keras.ops.image.resize(x, (2, 2)) + >>> y.shape + (2, 2, 2, 3) + + >>> x = np.random.random((4, 4, 3)) # single RGB image + >>> y = keras.ops.image.resize(x, (2, 2)) + >>> y.shape + (2, 2, 3) + + >>> x = np.random.random((2, 3, 4, 4)) # batch of 2 RGB images + >>> y = keras.ops.image.resize(x, (2, 2), + ... data_format="channels_first") + >>> y.shape + (2, 3, 2, 2) + """ + if len(size) != 2: + raise ValueError( + "Expected `size` to be a tuple of 2 integers. " + f"Received: size={size}" + ) + if len(images.shape) < 3 or len(images.shape) > 4: + raise ValueError( + "Invalid images rank: expected rank 3 (single image) " + "or rank 4 (batch of images). Received input with shape: " + f"images.shape={images.shape}" + ) + if pad_to_aspect_ratio and crop_to_aspect_ratio: + raise ValueError( + "Only one of `pad_to_aspect_ratio` & `crop_to_aspect_ratio` " + "can be `True`." + ) + if any_symbolic_tensors((images,)): + return Resize( + size, + interpolation=interpolation, + antialias=antialias, + data_format=data_format, + crop_to_aspect_ratio=crop_to_aspect_ratio, + pad_to_aspect_ratio=pad_to_aspect_ratio, + fill_mode=fill_mode, + fill_value=fill_value, + ).symbolic_call(images) + return _resize( + images, + size, + interpolation=interpolation, + antialias=antialias, + crop_to_aspect_ratio=crop_to_aspect_ratio, + data_format=data_format, + pad_to_aspect_ratio=pad_to_aspect_ratio, + fill_mode=fill_mode, + fill_value=fill_value, + ) + + +def _resize( + images, + size, + interpolation="bilinear", + antialias=False, + crop_to_aspect_ratio=False, + pad_to_aspect_ratio=False, + fill_mode="constant", + fill_value=0.0, + data_format=None, +): + resized = backend.image.resize( + images, + size, + interpolation=interpolation, + antialias=antialias, + crop_to_aspect_ratio=crop_to_aspect_ratio, + data_format=data_format, + pad_to_aspect_ratio=pad_to_aspect_ratio, + fill_mode=fill_mode, + fill_value=fill_value, + ) + if resized.dtype == images.dtype: + # Only `torch` backend will cast result to original dtype with + # correct rounding and without dtype overflow + return resized + if backend.is_int_dtype(images.dtype): + resized = ops.round(resized) + return ops.saturate_cast(resized, images.dtype) + + +class AffineTransform(Operation): + def __init__( + self, + interpolation="bilinear", + fill_mode="constant", + fill_value=0, + data_format=None, + *, + name=None, + ): + super().__init__(name=name) + self.interpolation = interpolation + self.fill_mode = fill_mode + self.fill_value = fill_value + self.data_format = backend.standardize_data_format(data_format) + + def call(self, images, transform): + return backend.image.affine_transform( + images, + transform, + interpolation=self.interpolation, + fill_mode=self.fill_mode, + fill_value=self.fill_value, + data_format=self.data_format, + ) + + def compute_output_spec(self, images, transform): + if len(images.shape) not in (3, 4): + raise ValueError( + "Invalid images rank: expected rank 3 (single image) " + "or rank 4 (batch of images). Received input with shape: " + f"images.shape={images.shape}" + ) + if len(transform.shape) not in (1, 2): + raise ValueError( + "Invalid transform rank: expected rank 1 (single transform) " + "or rank 2 (batch of transforms). Received input with shape: " + f"transform.shape={transform.shape}" + ) + return KerasTensor(images.shape, dtype=images.dtype) + + +@keras_export("keras.ops.image.affine_transform") +def affine_transform( + images, + transform, + interpolation="bilinear", + fill_mode="constant", + fill_value=0, + data_format=None, +): + """Applies the given transform(s) to the image(s). + + Args: + images: Input image or batch of images. Must be 3D or 4D. + transform: Projective transform matrix/matrices. A vector of length 8 or + tensor of size N x 8. If one row of transform is + `[a0, a1, a2, b0, b1, b2, c0, c1]`, then it maps the output point + `(x, y)` to a transformed input point + `(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`, + where `k = c0 x + c1 y + 1`. The transform is inverted compared to + the transform mapping input points to output points. Note that + gradients are not backpropagated into transformation parameters. + Note that `c0` and `c1` are only effective when using TensorFlow + backend and will be considered as `0` when using other backends. + interpolation: Interpolation method. Available methods are `"nearest"`, + and `"bilinear"`. Defaults to `"bilinear"`. + fill_mode: Points outside the boundaries of the input are filled + according to the given mode. Available methods are `"constant"`, + `"nearest"`, `"wrap"` and `"reflect"`. Defaults to `"constant"`. + - `"reflect"`: `(d c b a | a b c d | d c b a)` + The input is extended by reflecting about the edge of the last + pixel. + - `"constant"`: `(k k k k | a b c d | k k k k)` + The input is extended by filling all values beyond + the edge with the same constant value k specified by + `fill_value`. + - `"wrap"`: `(a b c d | a b c d | a b c d)` + The input is extended by wrapping around to the opposite edge. + - `"nearest"`: `(a a a a | a b c d | d d d d)` + The input is extended by the nearest pixel. + fill_value: Value used for points outside the boundaries of the input if + `fill_mode="constant"`. Defaults to `0`. + data_format: A string specifying the data format of the input tensor. + It can be either `"channels_last"` or `"channels_first"`. + `"channels_last"` corresponds to inputs with shape + `(batch, height, width, channels)`, while `"channels_first"` + corresponds to inputs with shape `(batch, channels, height, width)`. + If not specified, the value will default to + `keras.config.image_data_format`. + + Returns: + Applied affine transform image or batch of images. + + Examples: + + >>> x = np.random.random((2, 64, 80, 3)) # batch of 2 RGB images + >>> transform = np.array( + ... [ + ... [1.5, 0, -20, 0, 1.5, -16, 0, 0], # zoom + ... [1, 0, -20, 0, 1, -16, 0, 0], # translation + ... ] + ... ) + >>> y = keras.ops.image.affine_transform(x, transform) + >>> y.shape + (2, 64, 80, 3) + + >>> x = np.random.random((64, 80, 3)) # single RGB image + >>> transform = np.array([1.0, 0.5, -20, 0.5, 1.0, -16, 0, 0]) # shear + >>> y = keras.ops.image.affine_transform(x, transform) + >>> y.shape + (64, 80, 3) + + >>> x = np.random.random((2, 3, 64, 80)) # batch of 2 RGB images + >>> transform = np.array( + ... [ + ... [1.5, 0, -20, 0, 1.5, -16, 0, 0], # zoom + ... [1, 0, -20, 0, 1, -16, 0, 0], # translation + ... ] + ... ) + >>> y = keras.ops.image.affine_transform(x, transform, + ... data_format="channels_first") + >>> y.shape + (2, 3, 64, 80) + """ + if any_symbolic_tensors((images, transform)): + return AffineTransform( + interpolation=interpolation, + fill_mode=fill_mode, + fill_value=fill_value, + data_format=data_format, + ).symbolic_call(images, transform) + return backend.image.affine_transform( + images, + transform, + interpolation=interpolation, + fill_mode=fill_mode, + fill_value=fill_value, + data_format=data_format, + ) + + +class ExtractPatches(Operation): + def __init__( + self, + size, + strides=None, + dilation_rate=1, + padding="valid", + data_format=None, + *, + name=None, + ): + super().__init__(name=name) + if isinstance(size, int): + size = (size, size) + self.size = size + self.strides = strides + self.dilation_rate = dilation_rate + self.padding = padding + self.data_format = backend.standardize_data_format(data_format) + + def call(self, images): + return _extract_patches( + images=images, + size=self.size, + strides=self.strides, + dilation_rate=self.dilation_rate, + padding=self.padding, + data_format=self.data_format, + ) + + def compute_output_spec(self, images): + images_shape = list(images.shape) + original_ndim = len(images_shape) + if not self.strides: + strides = (self.size[0], self.size[1]) + if self.data_format == "channels_last": + channels_in = images_shape[-1] + else: + channels_in = images_shape[-3] + if original_ndim == 3: + images_shape = [1] + images_shape + filters = self.size[0] * self.size[1] * channels_in + kernel_size = (self.size[0], self.size[1]) + out_shape = compute_conv_output_shape( + images_shape, + filters, + kernel_size, + strides=strides, + padding=self.padding, + data_format=self.data_format, + dilation_rate=self.dilation_rate, + ) + if original_ndim == 3: + out_shape = out_shape[1:] + return KerasTensor(shape=out_shape, dtype=images.dtype) + + +@keras_export("keras.ops.image.extract_patches") +def extract_patches( + images, + size, + strides=None, + dilation_rate=1, + padding="valid", + data_format=None, +): + """Extracts patches from the image(s). + + Args: + images: Input image or batch of images. Must be 3D or 4D. + size: Patch size int or tuple (patch_height, patch_width) + strides: strides along height and width. If not specified, or + if `None`, it defaults to the same value as `size`. + dilation_rate: This is the input stride, specifying how far two + consecutive patch samples are in the input. For value other than 1, + strides must be 1. NOTE: `strides > 1` is not supported in + conjunction with `dilation_rate > 1` + padding: The type of padding algorithm to use: `"same"` or `"valid"`. + data_format: A string specifying the data format of the input tensor. + It can be either `"channels_last"` or `"channels_first"`. + `"channels_last"` corresponds to inputs with shape + `(batch, height, width, channels)`, while `"channels_first"` + corresponds to inputs with shape `(batch, channels, height, width)`. + If not specified, the value will default to + `keras.config.image_data_format`. + + Returns: + Extracted patches 3D (if not batched) or 4D (if batched) + + Examples: + + >>> image = np.random.random( + ... (2, 20, 20, 3) + ... ).astype("float32") # batch of 2 RGB images + >>> patches = keras.ops.image.extract_patches(image, (5, 5)) + >>> patches.shape + (2, 4, 4, 75) + >>> image = np.random.random((20, 20, 3)).astype("float32") # 1 RGB image + >>> patches = keras.ops.image.extract_patches(image, (3, 3), (1, 1)) + >>> patches.shape + (18, 18, 27) + """ + if any_symbolic_tensors((images,)): + return ExtractPatches( + size=size, + strides=strides, + dilation_rate=dilation_rate, + padding=padding, + data_format=data_format, + ).symbolic_call(images) + + return _extract_patches( + images, size, strides, dilation_rate, padding, data_format=data_format + ) + + +def _extract_patches( + images, + size, + strides=None, + dilation_rate=1, + padding="valid", + data_format=None, +): + if isinstance(size, int): + patch_h = patch_w = size + elif len(size) == 2: + patch_h, patch_w = size[0], size[1] + else: + raise TypeError( + "Invalid `size` argument. Expected an " + f"int or a tuple of length 2. Received: size={size}" + ) + data_format = backend.standardize_data_format(data_format) + if data_format == "channels_last": + channels_in = images.shape[-1] + elif data_format == "channels_first": + channels_in = images.shape[-3] + if not strides: + strides = size + out_dim = patch_h * patch_w * channels_in + kernel = backend.numpy.eye(out_dim, dtype=images.dtype) + kernel = backend.numpy.reshape( + kernel, (patch_h, patch_w, channels_in, out_dim) + ) + _unbatched = False + if len(images.shape) == 3: + _unbatched = True + images = backend.numpy.expand_dims(images, axis=0) + patches = backend.nn.conv( + inputs=images, + kernel=kernel, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + ) + if _unbatched: + patches = backend.numpy.squeeze(patches, axis=0) + return patches + + +class MapCoordinates(Operation): + def __init__(self, order, fill_mode="constant", fill_value=0, *, name=None): + super().__init__(name=name) + self.order = order + self.fill_mode = fill_mode + self.fill_value = fill_value + + def call(self, inputs, coordinates): + return backend.image.map_coordinates( + inputs, + coordinates, + order=self.order, + fill_mode=self.fill_mode, + fill_value=self.fill_value, + ) + + def compute_output_spec(self, inputs, coordinates): + if coordinates.shape[0] != len(inputs.shape): + raise ValueError( + "First dim of `coordinates` must be the same as the rank of " + "`inputs`. " + f"Received inputs with shape: {inputs.shape} and coordinate " + f"leading dim of {coordinates.shape[0]}" + ) + if len(coordinates.shape) < 2: + raise ValueError( + "Invalid coordinates rank: expected at least rank 2." + f" Received input with shape: {coordinates.shape}" + ) + return KerasTensor(coordinates.shape[1:], dtype=inputs.dtype) + + +@keras_export("keras.ops.image.map_coordinates") +def map_coordinates( + inputs, coordinates, order, fill_mode="constant", fill_value=0 +): + """Map the input array to new coordinates by interpolation. + + Note that interpolation near boundaries differs from the scipy function, + because we fixed an outstanding bug + [scipy/issues/2640](https://github.com/scipy/scipy/issues/2640). + + Args: + inputs: The input array. + coordinates: The coordinates at which inputs is evaluated. + order: The order of the spline interpolation. The order must be `0` or + `1`. `0` indicates the nearest neighbor and `1` indicates the linear + interpolation. + fill_mode: Points outside the boundaries of the inputs are filled + according to the given mode. Available methods are `"constant"`, + `"nearest"`, `"wrap"` and `"mirror"` and `"reflect"`. Defaults to + `"constant"`. + - `"constant"`: `(k k k k | a b c d | k k k k)` + The inputs is extended by filling all values beyond + the edge with the same constant value k specified by + `fill_value`. + - `"nearest"`: `(a a a a | a b c d | d d d d)` + The inputs is extended by the nearest pixel. + - `"wrap"`: `(a b c d | a b c d | a b c d)` + The inputs is extended by wrapping around to the opposite edge. + - `"mirror"`: `(c d c b | a b c d | c b a b)` + The inputs is extended by mirroring about the edge. + - `"reflect"`: `(d c b a | a b c d | d c b a)` + The inputs is extended by reflecting about the edge of the last + pixel. + fill_value: Value used for points outside the boundaries of the inputs + if `fill_mode="constant"`. Defaults to `0`. + + Returns: + Output input or batch of inputs. + + """ + if any_symbolic_tensors((inputs, coordinates)): + return MapCoordinates( + order, + fill_mode, + fill_value, + ).symbolic_call(inputs, coordinates) + return backend.image.map_coordinates( + inputs, + coordinates, + order, + fill_mode, + fill_value, + ) + + +class PadImages(Operation): + def __init__( + self, + top_padding=None, + left_padding=None, + bottom_padding=None, + right_padding=None, + target_height=None, + target_width=None, + data_format=None, + *, + name=None, + ): + super().__init__(name=name) + self.top_padding = top_padding + self.left_padding = left_padding + self.bottom_padding = bottom_padding + self.right_padding = right_padding + self.target_height = target_height + self.target_width = target_width + self.data_format = backend.standardize_data_format(data_format) + + def call(self, images): + return _pad_images( + images, + self.top_padding, + self.left_padding, + self.bottom_padding, + self.right_padding, + self.target_height, + self.target_width, + self.data_format, + ) + + def compute_output_spec(self, images): + images_shape = list(images.shape) + + if self.data_format == "channels_last": + height_axis, width_axis = -3, -2 + height, width = images_shape[height_axis], images_shape[width_axis] + else: + height_axis, width_axis = -2, -1 + height, width = images_shape[height_axis], images_shape[width_axis] + + target_height = self.target_height + if target_height is None and height is not None: + target_height = self.top_padding + height + self.bottom_padding + target_width = self.target_width + if target_width is None and width is not None: + target_width = self.left_padding + width + self.right_padding + + images_shape[height_axis] = target_height + images_shape[width_axis] = target_width + return KerasTensor(shape=images_shape, dtype=images.dtype) + + +@keras_export("keras.ops.image.pad_images") +def pad_images( + images, + top_padding=None, + left_padding=None, + bottom_padding=None, + right_padding=None, + target_height=None, + target_width=None, + data_format=None, +): + """Pad `images` with zeros to the specified `height` and `width`. + + Args: + images: Input image or batch of images. Must be 3D or 4D. + top_padding: Number of rows of zeros to add on top. + left_padding: Number of columns of zeros to add on the left. + bottom_padding: Number of rows of zeros to add at the bottom. + right_padding: Number of columns of zeros to add on the right. + target_height: Height of output images. + target_width: Width of output images. + data_format: A string specifying the data format of the input tensor. + It can be either `"channels_last"` or `"channels_first"`. + `"channels_last"` corresponds to inputs with shape + `(batch, height, width, channels)`, while `"channels_first"` + corresponds to inputs with shape `(batch, channels, height, width)`. + If not specified, the value will default to + `keras.config.image_data_format`. + + Returns: + Padded image or batch of images. + + Example: + + >>> images = np.random.random((15, 25, 3)) + >>> padded_images = keras.ops.image.pad_images( + ... images, 2, 3, target_height=20, target_width=30 + ... ) + >>> padded_images.shape + (20, 30, 3) + + >>> batch_images = np.random.random((2, 15, 25, 3)) + >>> padded_batch = keras.ops.image.pad_images( + ... batch_images, 2, 3, target_height=20, target_width=30 + ... ) + >>> padded_batch.shape + (2, 20, 30, 3)""" + + if any_symbolic_tensors((images,)): + return PadImages( + top_padding, + left_padding, + bottom_padding, + right_padding, + target_height, + target_width, + data_format, + ).symbolic_call(images) + + return _pad_images( + images, + top_padding, + left_padding, + bottom_padding, + right_padding, + target_height, + target_width, + data_format, + ) + + +def _pad_images( + images, + top_padding, + left_padding, + bottom_padding, + right_padding, + target_height, + target_width, + data_format=None, +): + data_format = backend.standardize_data_format(data_format) + images = backend.convert_to_tensor(images) + images_shape = ops.shape(images) + + # Check + if len(images_shape) not in (3, 4): + raise ValueError( + f"Invalid shape for argument `images`: " + "it must have rank 3 or 4. " + f"Received: images.shape={images_shape}" + ) + if [top_padding, bottom_padding, target_height].count(None) != 1: + raise ValueError( + "Must specify exactly two of " + "top_padding, bottom_padding, target_height. " + f"Received: top_padding={top_padding}, " + f"bottom_padding={bottom_padding}, " + f"target_height={target_height}" + ) + if [left_padding, right_padding, target_width].count(None) != 1: + raise ValueError( + "Must specify exactly two of " + "left_padding, right_padding, target_width. " + f"Received: left_padding={left_padding}, " + f"right_padding={right_padding}, " + f"target_width={target_width}" + ) + + is_batch = False if len(images_shape) == 3 else True + if data_format == "channels_last": + height, width = images_shape[-3], images_shape[-2] + else: + height, width = images_shape[-2], images_shape[-1] + + # Infer padding + if top_padding is None: + top_padding = target_height - bottom_padding - height + if bottom_padding is None: + bottom_padding = target_height - top_padding - height + if left_padding is None: + left_padding = target_width - right_padding - width + if right_padding is None: + right_padding = target_width - left_padding - width + + if top_padding < 0: + raise ValueError( + f"top_padding must be >= 0. Received: top_padding={top_padding}" + ) + if left_padding < 0: + raise ValueError( + f"left_padding must be >= 0. Received: left_padding={left_padding}" + ) + if right_padding < 0: + raise ValueError( + "right_padding must be >= 0. " + f"Received: right_padding={right_padding}" + ) + if bottom_padding < 0: + raise ValueError( + "bottom_padding must be >= 0. " + f"Received: bottom_padding={bottom_padding}" + ) + + # Compute pad_width + pad_width = [[top_padding, bottom_padding], [left_padding, right_padding]] + if data_format == "channels_last": + pad_width = pad_width + [[0, 0]] + else: + pad_width = [[0, 0]] + pad_width + if is_batch: + pad_width = [[0, 0]] + pad_width + + padded_images = backend.numpy.pad(images, pad_width) + return padded_images + + +class CropImages(Operation): + def __init__( + self, + top_cropping=None, + left_cropping=None, + bottom_cropping=None, + right_cropping=None, + target_height=None, + target_width=None, + data_format=None, + *, + name=None, + ): + super().__init__(name=name) + self.top_cropping = top_cropping + self.bottom_cropping = bottom_cropping + self.left_cropping = left_cropping + self.right_cropping = right_cropping + self.target_height = target_height + self.target_width = target_width + self.data_format = backend.standardize_data_format(data_format) + + def call(self, images): + return _crop_images( + images, + self.top_cropping, + self.left_cropping, + self.bottom_cropping, + self.right_cropping, + self.target_height, + self.target_width, + self.data_format, + ) + + def compute_output_spec(self, images): + images_shape = list(images.shape) + + if self.data_format == "channels_last": + height_axis, width_axis = -3, -2 + else: + height_axis, width_axis = -2, -1 + height, width = images_shape[height_axis], images_shape[width_axis] + + if height is None and self.target_height is None: + raise ValueError( + "When the height of the images is unknown, `target_height` " + "must be specified." + f"Received images.shape={images_shape} and " + f"target_height={self.target_height}" + ) + if width is None and self.target_width is None: + raise ValueError( + "When the width of the images is unknown, `target_width` " + "must be specified." + f"Received images.shape={images_shape} and " + f"target_width={self.target_width}" + ) + + target_height = self.target_height + if target_height is None: + target_height = height - self.top_cropping - self.bottom_cropping + target_width = self.target_width + if target_width is None: + target_width = width - self.left_cropping - self.right_cropping + + images_shape[height_axis] = target_height + images_shape[width_axis] = target_width + return KerasTensor(shape=images_shape, dtype=images.dtype) + + +@keras_export("keras.ops.image.crop_images") +def crop_images( + images, + top_cropping=None, + left_cropping=None, + bottom_cropping=None, + right_cropping=None, + target_height=None, + target_width=None, + data_format=None, +): + """Crop `images` to a specified `height` and `width`. + + Args: + images: Input image or batch of images. Must be 3D or 4D. + top_cropping: Number of columns to crop from the top. + left_cropping: Number of columns to crop from the left. + bottom_cropping: Number of columns to crop from the bottom. + right_cropping: Number of columns to crop from the right. + target_height: Height of the output images. + target_width: Width of the output images. + data_format: A string specifying the data format of the input tensor. + It can be either `"channels_last"` or `"channels_first"`. + `"channels_last"` corresponds to inputs with shape + `(batch, height, width, channels)`, while `"channels_first"` + corresponds to inputs with shape `(batch, channels, height, width)`. + If not specified, the value will default to + `keras.config.image_data_format`. + + Returns: + Cropped image or batch of images. + + Example: + + >>> images = np.reshape(np.arange(1, 28, dtype="float32"), [3, 3, 3]) + >>> images[:,:,0] # print the first channel of the images + array([[ 1., 4., 7.], + [10., 13., 16.], + [19., 22., 25.]], dtype=float32) + >>> cropped_images = keras.image.crop_images(images, 0, 0, 2, 2) + >>> cropped_images[:,:,0] # print the first channel of the cropped images + array([[ 1., 4.], + [10., 13.]], dtype=float32)""" + + if any_symbolic_tensors((images,)): + return CropImages( + top_cropping, + left_cropping, + bottom_cropping, + right_cropping, + target_height, + target_width, + data_format, + ).symbolic_call(images) + + return _crop_images( + images, + top_cropping, + left_cropping, + bottom_cropping, + right_cropping, + target_height, + target_width, + data_format, + ) + + +def _crop_images( + images, + top_cropping, + left_cropping, + bottom_cropping, + right_cropping, + target_height, + target_width, + data_format=None, +): + data_format = backend.standardize_data_format(data_format) + images = backend.convert_to_tensor(images) + images_shape = ops.shape(images) + + # Check + if len(images_shape) not in (3, 4): + raise ValueError( + f"Invalid shape for argument `images`: " + "it must have rank 3 or 4. " + f"Received: images.shape={images_shape}" + ) + if [top_cropping, bottom_cropping, target_height].count(None) != 1: + raise ValueError( + "Must specify exactly two of " + "top_cropping, bottom_cropping, target_height. " + f"Received: top_cropping={top_cropping}, " + f"bottom_cropping={bottom_cropping}, " + f"target_height={target_height}" + ) + if [left_cropping, right_cropping, target_width].count(None) != 1: + raise ValueError( + "Must specify exactly two of " + "left_cropping, right_cropping, target_width. " + f"Received: left_cropping={left_cropping}, " + f"right_cropping={right_cropping}, " + f"target_width={target_width}" + ) + + is_batch = False if len(images_shape) == 3 else True + if data_format == "channels_last": + height, width = images_shape[-3], images_shape[-2] + channels = images_shape[-1] + else: + height, width = images_shape[-2], images_shape[-1] + channels = images_shape[-3] + + # Infer padding + if top_cropping is None: + top_cropping = height - target_height - bottom_cropping + if target_height is None: + target_height = height - bottom_cropping - top_cropping + if left_cropping is None: + left_cropping = width - target_width - right_cropping + if target_width is None: + target_width = width - right_cropping - left_cropping + + if top_cropping < 0: + raise ValueError( + f"top_cropping must be >= 0. Received: top_cropping={top_cropping}" + ) + if target_height < 0: + raise ValueError( + "target_height must be >= 0. " + f"Received: target_height={target_height}" + ) + if left_cropping < 0: + raise ValueError( + "left_cropping must be >= 0. " + f"Received: left_cropping={left_cropping}" + ) + if target_width < 0: + raise ValueError( + f"target_width must be >= 0. Received: target_width={target_width}" + ) + + # Compute start_indices and shape + start_indices = [top_cropping, left_cropping] + shape = [target_height, target_width] + if data_format == "channels_last": + start_indices = start_indices + [0] + shape = shape + [channels] + else: + start_indices = [0] + start_indices + shape = [channels] + shape + if is_batch: + batch_size = images_shape[0] + start_indices = [0] + start_indices + shape = [batch_size] + shape + + cropped_images = ops.slice(images, start_indices, shape) + return cropped_images + + +class PerspectiveTransform(Operation): + def __init__( + self, + interpolation="bilinear", + fill_value=0, + data_format=None, + *, + name=None, + ): + super().__init__(name=name) + self.interpolation = interpolation + self.fill_value = fill_value + self.data_format = backend.standardize_data_format(data_format) + + def call(self, images, start_points, end_points): + return backend.image.perspective_transform( + images, + start_points, + end_points, + interpolation=self.interpolation, + fill_value=self.fill_value, + data_format=self.data_format, + ) + + def compute_output_spec(self, images, start_points, end_points): + if len(images.shape) not in (3, 4): + raise ValueError( + "Invalid images rank: expected rank 3 (single image) " + "or rank 4 (batch of images). Received input with shape: " + f"images.shape={images.shape}" + ) + if start_points.shape[-2:] != (4, 2) or start_points.ndim not in (2, 3): + raise ValueError( + "Invalid start_points shape: expected (4,2) for a single image" + f" or (N,4,2) for a batch. Received shape: {start_points.shape}" + ) + if end_points.shape[-2:] != (4, 2) or end_points.ndim not in (2, 3): + raise ValueError( + "Invalid end_points shape: expected (4,2) for a single image" + f" or (N,4,2) for a batch. Received shape: {end_points.shape}" + ) + if start_points.shape != end_points.shape: + raise ValueError( + "start_points and end_points must have the same shape." + f" Received start_points.shape={start_points.shape}, " + f"end_points.shape={end_points.shape}" + ) + return KerasTensor(images.shape, dtype=images.dtype) + + +@keras_export("keras.ops.image.perspective_transform") +def perspective_transform( + images, + start_points, + end_points, + interpolation="bilinear", + fill_value=0, + data_format=None, +): + """Applies a perspective transformation to the image(s). + + Args: + images: Input image or batch of images. Must be 3D or 4D. + start_points: A tensor of shape `(N, 4, 2)` or `(4, 2)`, + representing the source points in the original image + that define the transformation. + end_points: A tensor of shape `(N, 4, 2)` or `(4, 2)`, + representing the target points in the output image + after transformation. + interpolation: Interpolation method. Available methods are `"nearest"`, + and `"bilinear"`. Defaults to `"bilinear"`. + fill_value: Value used for points outside the boundaries of the input if + extrapolation is needed. Defaults to `0`. + data_format: A string specifying the data format of the input tensor. + It can be either `"channels_last"` or `"channels_first"`. + `"channels_last"` corresponds to inputs with shape + `(batch, height, width, channels)`, while `"channels_first"` + corresponds to inputs with shape `(batch, channels, height, width)`. + If not specified, the value will default to + `keras.config.image_data_format`. + + Returns: + Applied perspective transform image or batch of images. + + Examples: + + >>> x = np.random.random((2, 64, 80, 3)) # batch of 2 RGB images + >>> start_points = np.array( + ... [ + ... [[0, 0], [0, 64], [80, 0], [80, 64]], + ... [[0, 0], [0, 64], [80, 0], [80, 64]], + ... ] + ... ) + >>> end_points = np.array( + ... [ + ... [[3, 5], [7, 64], [76, -10], [84, 61]], + ... [[8, 10], [10, 61], [65, 3], [88, 43]], + ... ] + ... ) + >>> y = keras.ops.image.perspective_transform(x, start_points, end_points) + >>> y.shape + (2, 64, 80, 3) + + >>> x = np.random.random((64, 80, 3)) # single RGB image + >>> start_points = np.array([[0, 0], [0, 64], [80, 0], [80, 64]]) + >>> end_points = np.array([[3, 5], [7, 64], [76, -10], [84, 61]]) + >>> y = keras.ops.image.perspective_transform(x, start_points, end_points) + >>> y.shape + (64, 80, 3) + + >>> x = np.random.random((2, 3, 64, 80)) # batch of 2 RGB images + >>> start_points = np.array( + ... [ + ... [[0, 0], [0, 64], [80, 0], [80, 64]], + ... [[0, 0], [0, 64], [80, 0], [80, 64]], + ... ] + ... ) + >>> end_points = np.array( + ... [ + ... [[3, 5], [7, 64], [76, -10], [84, 61]], + ... [[8, 10], [10, 61], [65, 3], [88, 43]], + ... ] + ... ) + >>> y = keras.ops.image.perspective_transform( + ... x, start_points, end_points, data_format="channels_first" + ... ) + >>> y.shape + (2, 3, 64, 80) + """ + if any_symbolic_tensors((images, start_points, end_points)): + return PerspectiveTransform( + interpolation=interpolation, + fill_value=fill_value, + data_format=data_format, + ).symbolic_call(images, start_points, end_points) + return backend.image.perspective_transform( + images, + start_points, + end_points, + interpolation=interpolation, + fill_value=fill_value, + data_format=data_format, + ) + + +class GaussianBlur(Operation): + def __init__( + self, + kernel_size=(3, 3), + sigma=(1.0, 1.0), + data_format=None, + *, + name=None, + ): + super().__init__(name=name) + self.kernel_size = kernel_size + self.sigma = sigma + self.data_format = backend.standardize_data_format(data_format) + + def call(self, images): + return backend.image.gaussian_blur( + images, + kernel_size=self.kernel_size, + sigma=self.sigma, + data_format=self.data_format, + ) + + def compute_output_spec(self, images): + if len(images.shape) not in (3, 4): + raise ValueError( + "Invalid images rank: expected rank 3 (single image) " + "or rank 4 (batch of images). Received input with shape: " + f"images.shape={images.shape}" + ) + return KerasTensor(images.shape, dtype=images.dtype) + + +@keras_export("keras.ops.image.gaussian_blur") +def gaussian_blur( + images, kernel_size=(3, 3), sigma=(1.0, 1.0), data_format=None +): + """Applies a Gaussian blur to the image(s). + + Args: + images: Input image or batch of images. Must be 3D or 4D. + kernel_size: A tuple of two integers, specifying the height and width + of the Gaussian kernel. + sigma: A tuple of two floats, specifying the standard deviation of + the Gaussian kernel along height and width. + data_format: A string specifying the data format of the input tensor. + It can be either `"channels_last"` or `"channels_first"`. + `"channels_last"` corresponds to inputs with shape + `(batch, height, width, channels)`, while `"channels_first"` + corresponds to inputs with shape `(batch, channels, height, width)`. + If not specified, the value will default to + `keras.config.image_data_format`. + + Returns: + Blurred image or batch of images. + + Examples: + + >>> x = np.random.random((2, 64, 80, 3)) # batch of 2 RGB images + >>> y = keras.ops.image.gaussian_blur(x) + >>> y.shape + (2, 64, 80, 3) + + >>> x = np.random.random((64, 80, 3)) # single RGB image + >>> y = keras.ops.image.gaussian_blur(x) + >>> y.shape + (64, 80, 3) + + >>> x = np.random.random((2, 3, 64, 80)) # batch of 2 RGB images + >>> y = keras.ops.image.gaussian_blur( + ... x, data_format="channels_first") + >>> y.shape + (2, 3, 64, 80) + """ + if any_symbolic_tensors((images,)): + return GaussianBlur( + kernel_size=kernel_size, + sigma=sigma, + data_format=data_format, + ).symbolic_call(images) + return backend.image.gaussian_blur( + images, + kernel_size=kernel_size, + sigma=sigma, + data_format=data_format, + ) + + +class ElasticTransform(Operation): + def __init__( + self, + alpha=20.0, + sigma=5.0, + interpolation="bilinear", + fill_mode="reflect", + fill_value=0.0, + seed=None, + data_format=None, + *, + name=None, + ): + super().__init__(name=name) + self.alpha = alpha + self.sigma = sigma + self.interpolation = interpolation + self.fill_mode = fill_mode + self.fill_value = fill_value + self.seed = seed + self.data_format = backend.standardize_data_format(data_format) + + def call(self, images): + return backend.image.elastic_transform( + images, + alpha=self.alpha, + sigma=self.sigma, + interpolation=self.interpolation, + fill_mode=self.fill_mode, + fill_value=self.fill_value, + seed=self.seed, + data_format=self.data_format, + ) + + def compute_output_spec(self, images): + if len(images.shape) not in (3, 4): + raise ValueError( + "Invalid images rank: expected rank 3 (single image) " + "or rank 4 (batch of images). Received input with shape: " + f"images.shape={images.shape}" + ) + return KerasTensor(images.shape, dtype=images.dtype) + + +@keras_export("keras.ops.image.elastic_transform") +def elastic_transform( + images, + alpha=20.0, + sigma=5.0, + interpolation="bilinear", + fill_mode="reflect", + fill_value=0.0, + seed=None, + data_format=None, +): + """Applies elastic deformation to the image(s). + + Args: + images: Input image or batch of images. Must be 3D or 4D. + alpha: Scaling factor that controls the intensity of the deformation. + sigma: Standard deviation of the Gaussian filter used for + smoothing the displacement fields. + interpolation: Interpolation method. Available methods are `"nearest"`, + and `"bilinear"`. Defaults to `"bilinear"`. + fill_mode: Points outside the boundaries of the input are filled + according to the given mode. Available methods are `"constant"`, + `"nearest"`, `"wrap"` and `"reflect"`. Defaults to `"constant"`. + - `"reflect"`: `(d c b a | a b c d | d c b a)` + The input is extended by reflecting about the edge of the last + pixel. + - `"constant"`: `(k k k k | a b c d | k k k k)` + The input is extended by filling all values beyond + the edge with the same constant value k specified by + `fill_value`. + - `"wrap"`: `(a b c d | a b c d | a b c d)` + The input is extended by wrapping around to the opposite edge. + - `"nearest"`: `(a a a a | a b c d | d d d d)` + The input is extended by the nearest pixel. + fill_value: Value used for points outside the boundaries of the input if + `fill_mode="constant"`. Defaults to `0`. + data_format: A string specifying the data format of the input tensor. + It can be either `"channels_last"` or `"channels_first"`. + `"channels_last"` corresponds to inputs with shape + `(batch, height, width, channels)`, while `"channels_first"` + corresponds to inputs with shape `(batch, channels, height, width)`. + If not specified, the value will default to + `keras.config.image_data_format`. + + Returns: + Transformed image or batch of images with elastic deformation. + + Examples: + + >>> x = np.random.random((2, 64, 80, 3)) # batch of 2 RGB images + >>> y = keras.ops.image.elastic_transform(x) + >>> y.shape + (2, 64, 80, 3) + + >>> x = np.random.random((64, 80, 3)) # single RGB image + >>> y = keras.ops.image.elastic_transform(x) + >>> y.shape + (64, 80, 3) + + >>> x = np.random.random((2, 3, 64, 80)) # batch of 2 RGB images + >>> y = keras.ops.image.elastic_transform( + ... x, data_format="channels_first") + >>> y.shape + (2, 3, 64, 80) + """ + if any_symbolic_tensors((images,)): + return ElasticTransform( + alpha=alpha, + sigma=sigma, + interpolation=interpolation, + fill_mode=fill_mode, + fill_value=fill_value, + seed=seed, + data_format=data_format, + ).symbolic_call(images) + return backend.image.elastic_transform( + images, + alpha=alpha, + sigma=sigma, + interpolation=interpolation, + fill_mode=fill_mode, + fill_value=fill_value, + seed=seed, + data_format=data_format, + ) + + +class ScaleAndTranslate(Operation): + def __init__(self, spatial_dims, method, antialias=True, *, name=None): + super().__init__(name=name) + self.spatial_dims = spatial_dims + self.method = method + self.antialias = antialias + + def call(self, images, output_shape, scale, translation): + return backend.image.scale_and_translate( + images, + output_shape=output_shape, + scale=scale, + translation=translation, + spatial_dims=self.spatial_dims, + method=self.method, + antialias=self.antialias, + ) + + def compute_output_spec(self, images, output_shape, scale, translation): + return KerasTensor(output_shape, dtype=images.dtype) + + +@keras_export("keras.ops.image.scale_and_translate") +def scale_and_translate( + images, + output_shape, + scale, + translation, + spatial_dims, + method, + antialias=True, +): + """Apply a scale and translation to the images. + + Generates a new image of `output_shape` by resampling from the input image + using the sampling method corresponding to method. For 2D images, this + operation transforms a location in the input images, (x, y), to a location + in the output image according to: + + `(x * scale[1] + translation[1], y * scale[0] + translation[0])`. + + (Note the inverse warp is used to generate the sample locations.) Assumes + half-centered pixels, i.e the pixel at integer location row, col has + coordinates y, x = row + 0.5, col + 0.5, and similarly for other input image + dimensions. + + If an output location(pixel) maps to an input sample location that is + outside the input boundaries then the value for the output location will be + set to zero. + + The `method` argument expects one of the following resize methods: + + - `"linear"`, `"bilinear"`, `"trilinear"`, `"triangle"`: Linear + interpolation. If `antialias` is True, uses a triangular filter when + downsampling. + - `"cubic"`, `"bicubic"`, `"tricubic"`: Cubic interpolation, using the Keys + cubic kernel. + - `"lanczos3"`: Lanczos resampling, using a kernel of radius 3. + - `"lanczos5"`: Lanczos resampling, using a kernel of radius 5. + + Args: + images: The input array. + output_shape: The output shape, as a sequence of integers with length + equal to the number of dimensions of image. + scale: A [K] array with the same number of dimensions as `images`, + containing the scale to apply in each dimension. + translation: A [K] array with the same number of dimensions as `images`, + containing the translation to apply in each dimension. + spatial_dims: A length K tuple specifying the spatial dimensions that + the passed `scale` and `translation` should be applied to. + method: A string specifying the resizing method to use. Available + methods are `"linear"`, `"bilinear"`, `"trilinear"`, `"triangle"`, + `"cubic"`, `"bicubic"`, `"tricubic"`, `"lanczos3"` and `"lanczos5"`. + antialias: Whether an antialiasing filter should be applied when + downsampling. Has no effect when upsampling. Defaults to `True`. + + Returns: + The scale and translated images. + + Example: + + >>> images = np.arange(9, dtype="float32").reshape((3, 3)) + >>> scale = np.array([2.0, 2.0]).astype("float32") + >>> translation = -(scale / 2.0 - 0.5) + >>> resized_images = keras.image.scale_and_translate( + ... images, (5, 5), scale, translation, (0, 1), "linear" + ... ) + >>> resized_images + array([[0.0 0.5 1.0 1.5 2.0] + [1.5 2.0 2.5 3.0 3.5] + [3.0 3.5 4.0 4.5 5.0] + [4.5 5.0 5.5 6.0 6.5] + [6.0 6.5 7.0 7.5 8.0]], dtype=float32) + """ + if any_symbolic_tensors((images, scale, translation)): + return ScaleAndTranslate(spatial_dims, method, antialias).symbolic_call( + images, output_shape, scale, translation + ) + return backend.image.scale_and_translate( + images, + output_shape, + scale, + translation, + spatial_dims, + method, + antialias, + ) diff --git a/keras/src/ops/image_test.py b/keras/src/ops/image_test.py new file mode 100644 index 000000000000..fae6129b35b0 --- /dev/null +++ b/keras/src/ops/image_test.py @@ -0,0 +1,2524 @@ +import math + +import jax +import numpy as np +import pytest +import scipy.ndimage +import tensorflow as tf +from absl.testing import parameterized + +from keras.src import backend +from keras.src import testing +from keras.src.backend.common import dtypes +from keras.src.backend.common.keras_tensor import KerasTensor +from keras.src.ops import image as kimage +from keras.src.ops import numpy as knp +from keras.src.ops import random as krandom +from keras.src.testing.test_utils import named_product + + +class ImageOpsDynamicShapeTest(testing.TestCase): + def setUp(self): + # Defaults to channels_last + self.data_format = backend.image_data_format() + backend.set_image_data_format("channels_last") + return super().setUp() + + def tearDown(self): + backend.set_image_data_format(self.data_format) + return super().tearDown() + + def test_rgb_to_grayscale(self): + # Test channels_last + x = KerasTensor([None, 20, 20, 3]) + out = kimage.rgb_to_grayscale(x) + self.assertEqual(out.shape, (None, 20, 20, 1)) + + # Test channels_first + backend.set_image_data_format("channels_first") + x = KerasTensor([None, 3, 20, 20]) + out = kimage.rgb_to_grayscale(x) + self.assertEqual(out.shape, (None, 1, 20, 20)) + + def test_rgb_to_hsv(self): + # Test channels_last + x = KerasTensor([None, 20, 20, 3]) + out = kimage.rgb_to_hsv(x) + self.assertEqual(out.shape, (None, 20, 20, 3)) + + # Test channels_first + backend.set_image_data_format("channels_first") + x = KerasTensor([None, 3, 20, 20]) + out = kimage.rgb_to_hsv(x) + self.assertEqual(out.shape, (None, 3, 20, 20)) + + def test_hsv_to_rgb(self): + # Test channels_last + x = KerasTensor([None, 20, 20, 3]) + out = kimage.hsv_to_rgb(x) + self.assertEqual(out.shape, (None, 20, 20, 3)) + + # Test channels_first + backend.set_image_data_format("channels_first") + x = KerasTensor([None, 3, 20, 20]) + out = kimage.hsv_to_rgb(x) + self.assertEqual(out.shape, (None, 3, 20, 20)) + + def test_resize(self): + # Test channels_last + x = KerasTensor([None, 20, 20, 3]) + out = kimage.resize(x, size=(15, 15)) + self.assertEqual(out.shape, (None, 15, 15, 3)) + + x = KerasTensor([None, None, 3]) + out = kimage.resize(x, size=(15, 15)) + self.assertEqual(out.shape, (15, 15, 3)) + + # Test channels_first + backend.set_image_data_format("channels_first") + x = KerasTensor([None, 3, 20, 20]) + out = kimage.resize(x, size=(15, 15)) + self.assertEqual(out.shape, (None, 3, 15, 15)) + + x = KerasTensor([3, None, None]) + out = kimage.resize(x, size=(15, 15)) + self.assertEqual(out.shape, (3, 15, 15)) + + def test_affine_transform(self): + # Test channels_last + x = KerasTensor([None, 20, 20, 3]) + transform = KerasTensor([None, 8]) + out = kimage.affine_transform(x, transform) + self.assertEqual(out.shape, (None, 20, 20, 3)) + + # Test channels_first + backend.set_image_data_format("channels_first") + x = KerasTensor([None, 3, 20, 20]) + transform = KerasTensor([None, 8]) + out = kimage.affine_transform(x, transform) + self.assertEqual(out.shape, (None, 3, 20, 20)) + + def test_extract_patches(self): + # Test channels_last + x = KerasTensor([None, 20, 20, 3]) + p_h, p_w = 5, 5 + out = kimage.extract_patches(x, (p_h, p_w)) + self.assertEqual(out.shape, (None, 4, 4, 75)) + out = kimage.extract_patches(x, 5) + self.assertEqual(out.shape, (None, 4, 4, 75)) + + # Test channels_first + backend.set_image_data_format("channels_first") + x = KerasTensor([None, 3, 20, 20]) + p_h, p_w = 5, 5 + out = kimage.extract_patches(x, (p_h, p_w)) + self.assertEqual(out.shape, (None, 75, 4, 4)) + out = kimage.extract_patches(x, 5) + self.assertEqual(out.shape, (None, 75, 4, 4)) + + def test_map_coordinates(self): + input = KerasTensor([20, 20, None]) + coordinates = KerasTensor([3, 15, 15, None]) + out = kimage.map_coordinates(input, coordinates, 0) + self.assertEqual(out.shape, coordinates.shape[1:]) + + def test_pad_images(self): + # Test channels_last + x = KerasTensor([None, 15, 25, 3]) + out = kimage.pad_images(x, 2, 3, target_height=20, target_width=30) + self.assertEqual(out.shape, (None, 20, 30, 3)) + + x = KerasTensor([None, None, 3]) + out = kimage.pad_images(x, 2, 3, target_height=20, target_width=30) + self.assertEqual(out.shape, (20, 30, 3)) + + # Test unknown shape + x = KerasTensor([None, None, 3]) + out = kimage.pad_images(x, 2, 3, 2, 3) + self.assertEqual(out.shape, (None, None, 3)) + + # Test channels_first + backend.set_image_data_format("channels_first") + x = KerasTensor([None, 3, 15, 25]) + out = kimage.pad_images(x, 2, 3, target_height=20, target_width=30) + self.assertEqual(out.shape, (None, 3, 20, 30)) + + x = KerasTensor([3, None, None]) + out = kimage.pad_images(x, 2, 3, target_height=20, target_width=30) + self.assertEqual(out.shape, (3, 20, 30)) + + def test_crop_images(self): + # Test channels_last + x = KerasTensor([None, 15, 25, 3]) + out = kimage.crop_images(x, 2, 3, target_height=10, target_width=20) + self.assertEqual(out.shape, (None, 10, 20, 3)) + + x = KerasTensor([None, None, 3]) + out = kimage.crop_images(x, 2, 3, target_height=10, target_width=20) + self.assertEqual(out.shape, (10, 20, 3)) + + # Test channels_first + backend.set_image_data_format("channels_first") + x = KerasTensor([None, 3, 15, 25]) + out = kimage.crop_images(x, 2, 3, target_height=10, target_width=20) + self.assertEqual(out.shape, (None, 3, 10, 20)) + + x = KerasTensor([3, None, None]) + out = kimage.crop_images(x, 2, 3, target_height=10, target_width=20) + self.assertEqual(out.shape, (3, 10, 20)) + + def test_perspective_transform(self): + # Test channels_last + x = KerasTensor([None, 20, 20, 3]) + start_points = KerasTensor([None, 4, 2]) + end_points = KerasTensor([None, 4, 2]) + out = kimage.perspective_transform(x, start_points, end_points) + self.assertEqual(out.shape, (None, 20, 20, 3)) + + # Test channels_first + backend.set_image_data_format("channels_first") + x = KerasTensor([None, 3, 20, 20]) + start_points = KerasTensor([None, 4, 2]) + end_points = KerasTensor([None, 4, 2]) + out = kimage.perspective_transform(x, start_points, end_points) + self.assertEqual(out.shape, (None, 3, 20, 20)) + + def test_gaussian_blur(self): + # Test channels_last + x = KerasTensor([None, 20, 20, 3]) + out = kimage.gaussian_blur(x) + self.assertEqual(out.shape, (None, 20, 20, 3)) + + # Test channels_first + backend.set_image_data_format("channels_first") + x = KerasTensor([None, 3, 20, 20]) + out = kimage.gaussian_blur(x) + self.assertEqual(out.shape, (None, 3, 20, 20)) + + def test_elastic_transform(self): + # Test channels_last + x = KerasTensor([None, 20, 20, 3]) + out = kimage.elastic_transform(x) + self.assertEqual(out.shape, (None, 20, 20, 3)) + + # Test channels_first + backend.set_image_data_format("channels_first") + x = KerasTensor([None, 3, 20, 20]) + out = kimage.elastic_transform(x) + self.assertEqual(out.shape, (None, 3, 20, 20)) + + def test_scale_and_translate(self): + images = KerasTensor([None, 20, 20, 3]) + output_shape = (None, 25, 25, 3) + scale = KerasTensor([2]) + translation = KerasTensor([2]) + out = kimage.scale_and_translate( + images, + output_shape=output_shape, + scale=scale, + translation=translation, + spatial_dims=(1, 2), + method="linear", + ) + self.assertEqual(out.shape, output_shape) + + +class ImageOpsStaticShapeTest(testing.TestCase): + def setUp(self): + # Defaults to channels_last + self.data_format = backend.image_data_format() + backend.set_image_data_format("channels_last") + return super().setUp() + + def tearDown(self): + backend.set_image_data_format(self.data_format) + return super().tearDown() + + def test_rgb_to_grayscale(self): + # Test channels_last + x = KerasTensor([20, 20, 3]) + out = kimage.rgb_to_grayscale(x) + self.assertEqual(out.shape, (20, 20, 1)) + + # Test channels_first + backend.set_image_data_format("channels_first") + x = KerasTensor([3, 20, 20]) + out = kimage.rgb_to_grayscale(x) + self.assertEqual(out.shape, (1, 20, 20)) + + def test_rgb_to_hsv(self): + # Test channels_last + x = KerasTensor([20, 20, 3]) + out = kimage.rgb_to_hsv(x) + self.assertEqual(out.shape, (20, 20, 3)) + + # Test channels_first + backend.set_image_data_format("channels_first") + x = KerasTensor([3, 20, 20]) + out = kimage.rgb_to_hsv(x) + self.assertEqual(out.shape, (3, 20, 20)) + + def test_hsv_to_rgb(self): + # Test channels_last + x = KerasTensor([20, 20, 3]) + out = kimage.hsv_to_rgb(x) + self.assertEqual(out.shape, (20, 20, 3)) + + # Test channels_first + backend.set_image_data_format("channels_first") + x = KerasTensor([3, 20, 20]) + out = kimage.hsv_to_rgb(x) + self.assertEqual(out.shape, (3, 20, 20)) + + def test_resize(self): + # Test channels_last + x = KerasTensor([20, 20, 3]) + out = kimage.resize(x, size=(15, 15)) + self.assertEqual(out.shape, (15, 15, 3)) + + # Test channels_first + backend.set_image_data_format("channels_first") + x = KerasTensor([3, 20, 20]) + out = kimage.resize(x, size=(15, 15)) + self.assertEqual(out.shape, (3, 15, 15)) + + def test_affine_transform(self): + # Test channels_last + x = KerasTensor([20, 20, 3]) + transform = KerasTensor([8]) + out = kimage.affine_transform(x, transform) + self.assertEqual(out.shape, (20, 20, 3)) + + # Test channels_first + backend.set_image_data_format("channels_first") + x = KerasTensor([3, 20, 20]) + transform = KerasTensor([8]) + out = kimage.affine_transform(x, transform) + self.assertEqual(out.shape, (3, 20, 20)) + + def test_extract_patches(self): + # Test channels_last + x = KerasTensor([20, 20, 3]) + p_h, p_w = 5, 5 + out = kimage.extract_patches(x, (p_h, p_w)) + self.assertEqual(out.shape, (4, 4, 75)) + out = kimage.extract_patches(x, 5) + self.assertEqual(out.shape, (4, 4, 75)) + + # Test channels_first + backend.set_image_data_format("channels_first") + x = KerasTensor([3, 20, 20]) + p_h, p_w = 5, 5 + out = kimage.extract_patches(x, (p_h, p_w)) + self.assertEqual(out.shape, (75, 4, 4)) + out = kimage.extract_patches(x, 5) + self.assertEqual(out.shape, (75, 4, 4)) + + def test_map_coordinates(self): + input = KerasTensor([20, 20, 3]) + coordinates = KerasTensor([3, 15, 15, 3]) + out = kimage.map_coordinates(input, coordinates, 0) + self.assertEqual(out.shape, coordinates.shape[1:]) + + def test_map_coordinates_uint8(self): + image_uint8 = tf.ones((1, 1, 3), dtype=tf.uint8) + coordinates = tf.convert_to_tensor([-1.0, 0.0, 0.0])[..., None, None] + + if backend.backend() != "tensorflow": + pytest.skip("Skipping test because the backend is not TensorFlow.") + + out = kimage.map_coordinates( + image_uint8, coordinates, order=1, fill_mode="constant" + ) + assert out.shape == coordinates.shape[1:] + + def test_map_coordinates_float32(self): + image_float32 = tf.ones((1, 1, 3), dtype=tf.float32) + coordinates = tf.convert_to_tensor([-1.0, 0.0, 0.0])[..., None, None] + + if backend.backend() != "tensorflow": + pytest.skip("Skipping test because the backend is not TensorFlow.") + + out = kimage.map_coordinates( + image_float32, coordinates, order=1, fill_mode="constant" + ) + assert out.shape == coordinates.shape[1:] + + def test_map_coordinates_nearest(self): + image_uint8 = tf.ones((1, 1, 3), dtype=tf.uint8) + coordinates = tf.convert_to_tensor([-1.0, 0.0, 0.0])[..., None, None] + + if backend.backend() != "tensorflow": + pytest.skip("Skipping test because the backend is not TensorFlow.") + + out = kimage.map_coordinates( + image_uint8, coordinates, order=1, fill_mode="nearest" + ) + assert out.shape == coordinates.shape[1:] + + def test_map_coordinates_manual_cast(self): + image_uint8 = tf.ones((1, 1, 3), dtype=tf.uint8) + coordinates = tf.convert_to_tensor([-1.0, 0.0, 0.0])[..., None, None] + image_uint8_casted = tf.cast(image_uint8, dtype=tf.float32) + + if backend.backend() != "tensorflow": + pytest.skip("Skipping test because the backend is not TensorFlow.") + + out = tf.cast( + kimage.map_coordinates( + image_uint8_casted, coordinates, order=1, fill_mode="constant" + ), + dtype=tf.uint8, + ) + assert out.shape == coordinates.shape[1:] + + def test_pad_images(self): + # Test channels_last + x = KerasTensor([15, 25, 3]) + out = kimage.pad_images(x, 2, 3, target_height=20, target_width=30) + self.assertEqual(out.shape, (20, 30, 3)) + + x_batch = KerasTensor([2, 15, 25, 3]) + out_batch = kimage.pad_images( + x_batch, 2, 3, target_height=20, target_width=30 + ) + self.assertEqual(out_batch.shape, (2, 20, 30, 3)) + + # Test channels_first + backend.set_image_data_format("channels_first") + x = KerasTensor([3, 15, 25]) + out = kimage.pad_images(x, 2, 3, target_height=20, target_width=30) + self.assertEqual(out.shape, (3, 20, 30)) + + x_batch = KerasTensor([2, 3, 15, 25]) + out_batch = kimage.pad_images( + x_batch, 2, 3, target_height=20, target_width=30 + ) + self.assertEqual(out_batch.shape, (2, 3, 20, 30)) + + def test_crop_images(self): + # Test channels_last + x = KerasTensor([15, 25, 3]) + out = kimage.crop_images(x, 2, 3, target_height=10, target_width=20) + self.assertEqual(out.shape, (10, 20, 3)) + + x_batch = KerasTensor([2, 15, 25, 3]) + out_batch = kimage.crop_images( + x_batch, 2, 3, target_height=10, target_width=20 + ) + self.assertEqual(out_batch.shape, (2, 10, 20, 3)) + + # Test channels_first + backend.set_image_data_format("channels_first") + x = KerasTensor([3, 15, 25]) + out = kimage.crop_images(x, 2, 3, target_height=10, target_width=20) + self.assertEqual(out.shape, (3, 10, 20)) + + # Test channels_first and batched + x_batch = KerasTensor([2, 3, 15, 25]) + out_batch = kimage.crop_images( + x_batch, 2, 3, target_height=10, target_width=20 + ) + self.assertEqual(out_batch.shape, (2, 3, 10, 20)) + + def test_perspective_transform(self): + # Test channels_last + x = KerasTensor([20, 20, 3]) + start_points = KerasTensor([4, 2]) + end_points = KerasTensor([4, 2]) + out = kimage.perspective_transform(x, start_points, end_points) + self.assertEqual(out.shape, (20, 20, 3)) + + # Test channels_first + backend.set_image_data_format("channels_first") + x = KerasTensor([3, 20, 20]) + start_points = KerasTensor([4, 2]) + end_points = KerasTensor([4, 2]) + out = kimage.perspective_transform(x, start_points, end_points) + self.assertEqual(out.shape, (3, 20, 20)) + + def test_gaussian_blur(self): + # Test channels_last + x = KerasTensor([20, 20, 3]) + kernel_size = KerasTensor( + [ + 2, + ] + ) + sigma = KerasTensor( + [ + 2, + ] + ) + out = kimage.gaussian_blur(x, kernel_size, sigma) + self.assertEqual(out.shape, (20, 20, 3)) + + # Test channels_first + backend.set_image_data_format("channels_first") + x = KerasTensor([3, 20, 20]) + kernel_size = KerasTensor( + [ + 2, + ] + ) + sigma = KerasTensor( + [ + 2, + ] + ) + out = kimage.gaussian_blur(x, kernel_size, sigma) + self.assertEqual(out.shape, (3, 20, 20)) + + def test_elastic_transform(self): + # Test channels_last + x = KerasTensor([20, 20, 3]) + out = kimage.elastic_transform(x) + self.assertEqual(out.shape, (20, 20, 3)) + + # Test channels_first + backend.set_image_data_format("channels_first") + x = KerasTensor([3, 20, 20]) + out = kimage.elastic_transform(x) + self.assertEqual(out.shape, (3, 20, 20)) + + def test_scale_and_translate(self): + images = KerasTensor([20, 20, 3]) + output_shape = (25, 25, 3) + scale = KerasTensor([2]) + translation = KerasTensor([2]) + out = kimage.scale_and_translate( + images, + output_shape=output_shape, + scale=scale, + translation=translation, + spatial_dims=(0, 1), + method="linear", + ) + self.assertEqual(out.shape, output_shape) + + +AFFINE_TRANSFORM_INTERPOLATIONS = { # map to order + "nearest": 0, + "bilinear": 1, +} + + +def _compute_affine_transform_coordinates(image, transform): + image = image.copy() + transform = transform.copy() + need_squeeze = False + if len(image.shape) == 3: # unbatched + need_squeeze = True + image = np.expand_dims(image, axis=0) + transform = np.expand_dims(transform, axis=0) + batch_size = image.shape[0] + # get indices + meshgrid = np.meshgrid( + *[np.arange(size) for size in image.shape[1:]], indexing="ij" + ) + indices = np.concatenate( + [np.expand_dims(x, axis=-1) for x in meshgrid], axis=-1 + ) + indices = np.tile(indices, (batch_size, 1, 1, 1, 1)) + # swap the values + transform[:, 4], transform[:, 0] = ( + transform[:, 0].copy(), + transform[:, 4].copy(), + ) + transform[:, 5], transform[:, 2] = ( + transform[:, 2].copy(), + transform[:, 5].copy(), + ) + # deal with transform + transform = np.pad(transform, pad_width=[[0, 0], [0, 1]], constant_values=1) + transform = np.reshape(transform, (batch_size, 3, 3)) + offset = np.pad(transform[:, 0:2, 2], pad_width=[[0, 0], [0, 1]]) + transform[:, 0:2, 2] = 0 + # transform the indices + coordinates = np.einsum("Bhwij, Bjk -> Bhwik", indices, transform) + coordinates = np.moveaxis(coordinates, source=-1, destination=1) + coordinates += np.reshape(offset, newshape=(*offset.shape, 1, 1, 1)) + if need_squeeze: + coordinates = np.squeeze(coordinates, axis=0) + return coordinates + + +def _fixed_map_coordinates( + input, coordinates, order, fill_mode="constant", fill_value=0.0 +): + # SciPy's implementation of map_coordinates handles boundaries incorrectly, + # unless mode='reflect'. For order=1, this only affects interpolation + # outside the bounds of the original array. + # https://github.com/scipy/scipy/issues/2640 + padding = [ + ( + max(-np.floor(c.min()).astype(int) + 1, 0), + max(np.ceil(c.max()).astype(int) + 1 - size, 0), + ) + for c, size in zip(coordinates, input.shape) + ] + shifted_coords = [c + p[0] for p, c in zip(padding, coordinates)] + pad_mode = { + "nearest": "edge", + "mirror": "reflect", + "reflect": "symmetric", + }.get(fill_mode, fill_mode) + if fill_mode == "constant": + padded = np.pad( + input, padding, mode=pad_mode, constant_values=fill_value + ) + else: + padded = np.pad(input, padding, mode=pad_mode) + result = scipy.ndimage.map_coordinates( + padded, shifted_coords, order=order, mode=fill_mode, cval=fill_value + ) + return result + + +def _perspective_transform_numpy( + images, + start_points, + end_points, + interpolation="bilinear", + fill_value=0, + data_format=None, +): + data_format = backend.standardize_data_format(data_format) + + need_squeeze = False + if len(images.shape) == 3: + images = np.expand_dims(images, axis=0) + need_squeeze = True + + if len(start_points.shape) == 2: + start_points = np.expand_dims(start_points, axis=0) + if len(end_points.shape) == 2: + end_points = np.expand_dims(end_points, axis=0) + + if data_format == "channels_first": + images = np.transpose(images, (0, 2, 3, 1)) + + batch_size, height, width, channels = images.shape + + transforms = _compute_homography_matrix(start_points, end_points) + + if len(transforms.shape) == 1: + transforms = np.expand_dims(transforms, axis=0) + if transforms.shape[0] == 1 and batch_size > 1: + transforms = np.tile(transforms, (batch_size, 1)) + + x, y = np.meshgrid( + np.arange(width, dtype=np.float32), + np.arange(height, dtype=np.float32), + indexing="xy", + ) + + output = np.empty((batch_size, height, width, channels)) + + for i in range(batch_size): + a0, a1, a2, a3, a4, a5, a6, a7 = transforms[i] + denom = a6 * x + a7 * y + 1.0 + x_in = (a0 * x + a1 * y + a2) / denom + y_in = (a3 * x + a4 * y + a5) / denom + + coords = np.stack([y_in.ravel(), x_in.ravel()], axis=0) + + mapped_channels = [] + for channel in range(channels): + channel_img = images[i, :, :, channel] + + mapped_channel = _fixed_map_coordinates( + channel_img, + coords, + order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation], + fill_mode="constant", + fill_value=fill_value, + ) + mapped_channels.append(mapped_channel.reshape(height, width)) + + output[i] = np.stack(mapped_channels, axis=-1) + + if data_format == "channels_first": + output = np.transpose(output, (0, 3, 1, 2)) + if need_squeeze: + output = np.squeeze(output, axis=0) + + return output + + +def gaussian_blur_np( + images, + kernel_size, + sigma, + data_format=None, +): + def _create_gaussian_kernel(kernel_size, sigma, num_channels, dtype): + def _get_gaussian_kernel1d(size, sigma): + x = np.arange(size, dtype=dtype) - (size - 1) / 2 + kernel1d = np.exp(-0.5 * (x / sigma) ** 2) + return kernel1d / np.sum(kernel1d) + + def _get_gaussian_kernel2d(size, sigma): + kernel1d_x = _get_gaussian_kernel1d(size[0], sigma[0]) + kernel1d_y = _get_gaussian_kernel1d(size[1], sigma[1]) + return np.outer(kernel1d_y, kernel1d_x) + + kernel = _get_gaussian_kernel2d(kernel_size, sigma) + kernel = kernel[:, :, np.newaxis] + kernel = np.tile(kernel, (1, 1, num_channels)) + return kernel.astype(dtype) + + images = np.asarray(images) + input_dtype = images.dtype + kernel_size = np.asarray(kernel_size) + + if len(images.shape) not in (3, 4): + raise ValueError( + "Invalid images rank: expected rank 3 (single image) " + "or rank 4 (batch of images). Received input with shape: " + f"images.shape={images.shape}" + ) + + need_squeeze = False + if len(images.shape) == 3: + images = np.expand_dims(images, axis=0) + need_squeeze = True + + if data_format == "channels_first": + images = np.transpose(images, (0, 2, 3, 1)) + + num_channels = images.shape[-1] + kernel = _create_gaussian_kernel( + kernel_size, sigma, num_channels, input_dtype + ) + batch_size, height, width, _ = images.shape + padded_images = np.pad( + images, + ( + (0, 0), + (kernel_size[0] // 2, kernel_size[0] // 2), + (kernel_size[1] // 2, kernel_size[1] // 2), + (0, 0), + ), + mode="constant", + ) + + blurred_images = np.zeros_like(images) + kernel_reshaped = kernel.reshape( + (1, kernel.shape[0], kernel.shape[1], num_channels) + ) + + for b in range(batch_size): + image_patch = padded_images[b : b + 1, :, :, :] + + for i in range(height): + for j in range(width): + patch = image_patch[ + :, i : i + kernel_size[0], j : j + kernel_size[1], : + ] + blurred_images[b, i, j, :] = np.sum( + patch * kernel_reshaped, axis=(1, 2) + ) + + if data_format == "channels_first": + blurred_images = np.transpose(blurred_images, (0, 3, 1, 2)) + if need_squeeze: + blurred_images = np.squeeze(blurred_images, axis=0) + + return blurred_images + + +def elastic_transform_np( + images, + alpha=20.0, + sigma=5.0, + interpolation="bilinear", + fill_mode="reflect", + fill_value=0.0, + seed=None, + data_format=None, +): + data_format = backend.standardize_data_format(data_format) + + images = np.asarray(images) + input_dtype = images.dtype + + alpha = np.asarray(alpha, dtype=input_dtype) + sigma = np.asarray(sigma, dtype=input_dtype) + + kernel_size = (int(6 * sigma) | 1, int(6 * sigma) | 1) + + need_squeeze = False + if len(images.shape) == 3: + images = np.expand_dims(images, axis=0) + need_squeeze = True + + if data_format == "channels_last": + batch_size, height, width, channels = images.shape + channel_axis = -1 + else: + batch_size, channels, height, width = images.shape + channel_axis = 1 + + rng = np.random.default_rng([seed, 0]) + dx = ( + rng.normal(size=(batch_size, height, width), loc=0.0, scale=1.0).astype( + input_dtype + ) + * sigma + ) + dy = ( + rng.normal(size=(batch_size, height, width), loc=0.0, scale=1.0).astype( + input_dtype + ) + * sigma + ) + + dx = gaussian_blur_np( + np.expand_dims(dx, axis=channel_axis), + kernel_size=kernel_size, + sigma=(sigma, sigma), + data_format=data_format, + ) + dy = gaussian_blur_np( + np.expand_dims(dy, axis=channel_axis), + kernel_size=kernel_size, + sigma=(sigma, sigma), + data_format=data_format, + ) + + dx = np.squeeze(dx) + dy = np.squeeze(dy) + + x, y = np.meshgrid(np.arange(width), np.arange(height)) + x, y = x[None, :, :], y[None, :, :] + + distorted_x = x + alpha * dx + distorted_y = y + alpha * dy + + transformed_images = np.zeros_like(images) + + if data_format == "channels_last": + for i in range(channels): + transformed_images[..., i] = np.stack( + [ + _fixed_map_coordinates( + images[b, ..., i], + [distorted_y[b], distorted_x[b]], + order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation], + fill_mode=fill_mode, + fill_value=fill_value, + ) + for b in range(batch_size) + ] + ) + else: + for i in range(channels): + transformed_images[:, i, :, :] = np.stack( + [ + _fixed_map_coordinates( + images[b, i, ...], + [distorted_y[b], distorted_x[b]], + order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation], + fill_mode=fill_mode, + fill_value=fill_value, + ) + for b in range(batch_size) + ] + ) + + if need_squeeze: + transformed_images = np.squeeze(transformed_images, axis=0) + transformed_images = transformed_images.astype(input_dtype) + + return transformed_images + + +def _compute_homography_matrix(start_points, end_points): + start_x1, start_y1 = start_points[:, 0, 0], start_points[:, 0, 1] + start_x2, start_y2 = start_points[:, 1, 0], start_points[:, 1, 1] + start_x3, start_y3 = start_points[:, 2, 0], start_points[:, 2, 1] + start_x4, start_y4 = start_points[:, 3, 0], start_points[:, 3, 1] + + end_x1, end_y1 = end_points[:, 0, 0], end_points[:, 0, 1] + end_x2, end_y2 = end_points[:, 1, 0], end_points[:, 1, 1] + end_x3, end_y3 = end_points[:, 2, 0], end_points[:, 2, 1] + end_x4, end_y4 = end_points[:, 3, 0], end_points[:, 3, 1] + + coefficient_matrix = np.stack( + [ + np.stack( + [ + end_x1, + end_y1, + np.ones_like(end_x1), + np.zeros_like(end_x1), + np.zeros_like(end_x1), + np.zeros_like(end_x1), + -start_x1 * end_x1, + -start_x1 * end_y1, + ], + axis=-1, + ), + np.stack( + [ + np.zeros_like(end_x1), + np.zeros_like(end_x1), + np.zeros_like(end_x1), + end_x1, + end_y1, + np.ones_like(end_x1), + -start_y1 * end_x1, + -start_y1 * end_y1, + ], + axis=-1, + ), + np.stack( + [ + end_x2, + end_y2, + np.ones_like(end_x2), + np.zeros_like(end_x2), + np.zeros_like(end_x2), + np.zeros_like(end_x2), + -start_x2 * end_x2, + -start_x2 * end_y2, + ], + axis=-1, + ), + np.stack( + [ + np.zeros_like(end_x2), + np.zeros_like(end_x2), + np.zeros_like(end_x2), + end_x2, + end_y2, + np.ones_like(end_x2), + -start_y2 * end_x2, + -start_y2 * end_y2, + ], + axis=-1, + ), + np.stack( + [ + end_x3, + end_y3, + np.ones_like(end_x3), + np.zeros_like(end_x3), + np.zeros_like(end_x3), + np.zeros_like(end_x3), + -start_x3 * end_x3, + -start_x3 * end_y3, + ], + axis=-1, + ), + np.stack( + [ + np.zeros_like(end_x3), + np.zeros_like(end_x3), + np.zeros_like(end_x3), + end_x3, + end_y3, + np.ones_like(end_x3), + -start_y3 * end_x3, + -start_y3 * end_y3, + ], + axis=-1, + ), + np.stack( + [ + end_x4, + end_y4, + np.ones_like(end_x4), + np.zeros_like(end_x4), + np.zeros_like(end_x4), + np.zeros_like(end_x4), + -start_x4 * end_x4, + -start_x4 * end_y4, + ], + axis=-1, + ), + np.stack( + [ + np.zeros_like(end_x4), + np.zeros_like(end_x4), + np.zeros_like(end_x4), + end_x4, + end_y4, + np.ones_like(end_x4), + -start_y4 * end_x4, + -start_y4 * end_y4, + ], + axis=-1, + ), + ], + axis=1, + ) + + target_vector = np.stack( + [ + start_x1, + start_y1, + start_x2, + start_y2, + start_x3, + start_y3, + start_x4, + start_y4, + ], + axis=-1, + ) + target_vector = np.expand_dims(target_vector, axis=-1) + + homography_matrix = np.linalg.solve(coefficient_matrix, target_vector) + homography_matrix = np.reshape(homography_matrix, [-1, 8]) + + return homography_matrix + + +class ImageOpsCorrectnessTest(testing.TestCase): + def setUp(self): + # Defaults to channels_last + self.data_format = backend.image_data_format() + backend.set_image_data_format("channels_last") + return super().setUp() + + def tearDown(self): + backend.set_image_data_format(self.data_format) + return super().tearDown() + + def test_rgb_to_grayscale(self): + # Test channels_last + x = np.random.random((50, 50, 3)).astype("float32") * 255 + out = kimage.rgb_to_grayscale(x) + ref_out = tf.image.rgb_to_grayscale(x) + self.assertEqual(tuple(out.shape), tuple(ref_out.shape)) + self.assertAllClose(ref_out, out) + + x = np.random.random((2, 50, 50, 3)).astype("float32") * 255 + out = kimage.rgb_to_grayscale(x) + ref_out = tf.image.rgb_to_grayscale(x) + self.assertEqual(tuple(out.shape), tuple(ref_out.shape)) + self.assertAllClose(ref_out, out) + + # Test channels_first + backend.set_image_data_format("channels_first") + x = np.random.random((3, 50, 50)).astype("float32") * 255 + out = kimage.rgb_to_grayscale(x) + ref_out = tf.image.rgb_to_grayscale(np.transpose(x, [1, 2, 0])) + ref_out = tf.transpose(ref_out, [2, 0, 1]) + self.assertEqual(tuple(out.shape), tuple(ref_out.shape)) + self.assertAllClose(ref_out, out) + + x = np.random.random((2, 3, 50, 50)).astype("float32") * 255 + out = kimage.rgb_to_grayscale(x) + ref_out = tf.image.rgb_to_grayscale(np.transpose(x, [0, 2, 3, 1])) + ref_out = tf.transpose(ref_out, [0, 3, 1, 2]) + self.assertEqual(tuple(out.shape), tuple(ref_out.shape)) + self.assertAllClose(ref_out, out) + + # Test class + out = kimage.RGBToGrayscale()(x) + self.assertAllClose(ref_out, out) + + def test_rgb_to_hsv(self): + # Test channels_last + x = np.random.random((50, 50, 3)).astype("float32") + out = kimage.rgb_to_hsv(x) + ref_out = tf.image.rgb_to_hsv(x) + self.assertEqual(tuple(out.shape), tuple(ref_out.shape)) + self.assertAllClose(ref_out, out) + + x = np.random.random((2, 50, 50, 3)).astype("float32") + out = kimage.rgb_to_hsv(x) + ref_out = tf.image.rgb_to_hsv(x) + self.assertEqual(tuple(out.shape), tuple(ref_out.shape)) + self.assertAllClose(ref_out, out) + + # Test channels_first + backend.set_image_data_format("channels_first") + x = np.random.random((3, 50, 50)).astype("float32") + out = kimage.rgb_to_hsv(x) + ref_out = tf.image.rgb_to_hsv(np.transpose(x, [1, 2, 0])) + ref_out = tf.transpose(ref_out, [2, 0, 1]) + self.assertEqual(tuple(out.shape), tuple(ref_out.shape)) + self.assertAllClose(ref_out, out) + + x = np.random.random((2, 3, 50, 50)).astype("float32") + out = kimage.rgb_to_hsv(x) + ref_out = tf.image.rgb_to_hsv(np.transpose(x, [0, 2, 3, 1])) + ref_out = tf.transpose(ref_out, [0, 3, 1, 2]) + self.assertEqual(tuple(out.shape), tuple(ref_out.shape)) + self.assertAllClose(ref_out, out) + + # Test class + out = kimage.RGBToHSV()(x) + self.assertAllClose(ref_out, out) + + def test_hsv_to_rgb(self): + # Test channels_last + x = np.random.random((50, 50, 3)).astype("float32") + out = kimage.hsv_to_rgb(x) + ref_out = tf.image.hsv_to_rgb(x) + self.assertEqual(tuple(out.shape), tuple(ref_out.shape)) + self.assertAllClose(ref_out, out) + + x = np.random.random((2, 50, 50, 3)).astype("float32") + out = kimage.hsv_to_rgb(x) + ref_out = tf.image.hsv_to_rgb(x) + self.assertEqual(tuple(out.shape), tuple(ref_out.shape)) + self.assertAllClose(ref_out, out) + + # Test channels_first + backend.set_image_data_format("channels_first") + x = np.random.random((3, 50, 50)).astype("float32") + out = kimage.hsv_to_rgb(x) + ref_out = tf.image.hsv_to_rgb(np.transpose(x, [1, 2, 0])) + ref_out = tf.transpose(ref_out, [2, 0, 1]) + self.assertEqual(tuple(out.shape), tuple(ref_out.shape)) + self.assertAllClose(ref_out, out) + + x = np.random.random((2, 3, 50, 50)).astype("float32") + out = kimage.hsv_to_rgb(x) + ref_out = tf.image.hsv_to_rgb(np.transpose(x, [0, 2, 3, 1])) + ref_out = tf.transpose(ref_out, [0, 3, 1, 2]) + self.assertEqual(tuple(out.shape), tuple(ref_out.shape)) + self.assertAllClose(ref_out, out) + + # Test class + out = kimage.HSVToRGB()(x) + self.assertAllClose(ref_out, out) + + @parameterized.named_parameters( + named_product( + interpolation=[ + "bilinear", + "nearest", + "lanczos3", + "lanczos5", + "bicubic", + ], + antialias=[True, False], + ) + ) + def test_resize(self, interpolation, antialias): + if backend.backend() == "torch": + if "lanczos" in interpolation: + self.skipTest( + "Resizing with Lanczos interpolation is " + "not supported by the PyTorch backend. " + f"Received: interpolation={interpolation}." + ) + if interpolation == "bicubic" and antialias is False: + self.skipTest( + "Resizing with Bicubic interpolation in " + "PyTorch backend produces noise. Please " + "turn on anti-aliasing. " + f"Received: interpolation={interpolation}, " + f"antialias={antialias}." + ) + # Test channels_last + x = np.random.random((30, 30, 3)).astype("float32") * 255 + out = kimage.resize( + x, + size=(15, 15), + interpolation=interpolation, + antialias=antialias, + ) + ref_out = tf.image.resize( + x, + size=(15, 15), + method=interpolation, + antialias=antialias, + ) + self.assertEqual(tuple(out.shape), tuple(ref_out.shape)) + self.assertAllClose(ref_out, out, atol=1e-4) + + x = np.random.random((2, 30, 30, 3)).astype("float32") * 255 + out = kimage.resize( + x, + size=(15, 15), + interpolation=interpolation, + antialias=antialias, + ) + ref_out = tf.image.resize( + x, + size=(15, 15), + method=interpolation, + antialias=antialias, + ) + self.assertEqual(tuple(out.shape), tuple(ref_out.shape)) + self.assertAllClose(ref_out, out, atol=1e-4) + + # Test channels_first + backend.set_image_data_format("channels_first") + x = np.random.random((3, 30, 30)).astype("float32") * 255 + out = kimage.resize( + x, + size=(15, 15), + interpolation=interpolation, + antialias=antialias, + ) + ref_out = tf.image.resize( + np.transpose(x, [1, 2, 0]), + size=(15, 15), + method=interpolation, + antialias=antialias, + ) + ref_out = tf.transpose(ref_out, [2, 0, 1]) + self.assertEqual(tuple(out.shape), tuple(ref_out.shape)) + self.assertAllClose(ref_out, out, atol=1e-4) + + x = np.random.random((2, 3, 30, 30)).astype("float32") * 255 + out = kimage.resize( + x, + size=(15, 15), + interpolation=interpolation, + antialias=antialias, + ) + ref_out = tf.image.resize( + np.transpose(x, [0, 2, 3, 1]), + size=(15, 15), + method=interpolation, + antialias=antialias, + ) + ref_out = tf.transpose(ref_out, [0, 3, 1, 2]) + self.assertEqual(tuple(out.shape), tuple(ref_out.shape)) + self.assertAllClose(ref_out, out, atol=1e-4) + + # Test class + out = kimage.Resize( + size=(15, 15), + interpolation=interpolation, + antialias=antialias, + )(x) + self.assertAllClose(ref_out, out, atol=1e-4) + + def test_resize_uint8_round(self): + x = np.array([0, 1, 254, 255], dtype="uint8").reshape(1, 2, 2, 1) + expected = np.array( + # OpenCV as gold standard. + # [ + # [0, 0, 1, 1], + # [64, 64, 64, 65], + # [191, 191, 191, 192], + # [254, 254, 255, 255], + # ] + # + # Resize without `round` - differences in 8 points + # [ + # [0, 0, 0, 1], + # [63, 63, 64, 64], + # [190, 190, 191, 191], + # [254, 254, 254, 255], + # ] + # + # Resize with `round` - differences in 2 points + [ + [0, 0, 1, 1], + [64, 64, 64, 64], + [190, 191, 191, 192], + [254, 254, 255, 255], + ], + dtype="uint8", + ).reshape(1, 4, 4, 1) + out = kimage.resize( + x, + size=(4, 4), + interpolation="bilinear", + antialias=False, + ) + self.assertEqual(tuple(out.shape), tuple(expected.shape)) + self.assertEqual(backend.standardize_dtype(out.dtype), "uint8") + self.assertAllClose(out, expected, atol=1e-4) + + def test_resize_uint8_round_saturate(self): + x = np.array([0, 1, 254, 255], dtype="uint8").reshape(1, 2, 2, 1) + expected = np.array( + # OpenCV as gold standard. Same for `torch` backend. + ( + [ + [0, 0, 0, 0], + [57, 58, 58, 59], + [196, 197, 197, 198], + [255, 255, 255, 255], + ] + if "torch" == backend.backend() + # Resize without `round` and `saturate_cast` - differences in + # 16 points + # [ + # [234, 234, 235, 235], + # [-5, -6, -5, -6], + # [5, 4, 5, 4], + # [-235, -235, -234, -234], + # ] + # + # Resize with `round` and `saturate_cast` - differences in + # 8 points + else [ + [0, 0, 0, 0], + [53, 53, 53, 54], + [201, 202, 202, 202], + [255, 255, 255, 255], + ] + ), + dtype="uint8", + ).reshape(1, 4, 4, 1) + out = kimage.resize( + x, + size=(4, 4), + interpolation="bicubic", + antialias=False, + ) + self.assertEqual(tuple(out.shape), tuple(expected.shape)) + self.assertEqual(backend.standardize_dtype(out.dtype), "uint8") + self.assertAllClose(out, expected, atol=1e-4) + + def test_resize_with_crop(self): + # Test channels_last + x = np.random.random((60, 50, 3)).astype("float32") * 255 + out = kimage.resize(x, size=(25, 25), crop_to_aspect_ratio=True) + self.assertEqual(out.shape, (25, 25, 3)) + + x = np.random.random((2, 50, 60, 3)).astype("float32") * 255 + out = kimage.resize(x, size=(25, 25), crop_to_aspect_ratio=True) + self.assertEqual(out.shape, (2, 25, 25, 3)) + + # Test channels_first + backend.set_image_data_format("channels_first") + x = np.random.random((3, 60, 50)).astype("float32") * 255 + out = kimage.resize(x, size=(25, 25), crop_to_aspect_ratio=True) + self.assertEqual(out.shape, (3, 25, 25)) + + x = np.random.random((2, 3, 50, 60)).astype("float32") * 255 + out = kimage.resize(x, size=(25, 25), crop_to_aspect_ratio=True) + self.assertEqual(out.shape, (2, 3, 25, 25)) + + @parameterized.named_parameters(named_product(fill_value=[1.0, 2.0])) + def test_resize_with_pad(self, fill_value): + # Test channels_last + x = np.random.random((60, 50, 3)).astype("float32") * 255 + out = kimage.resize( + x, + size=(25, 25), + pad_to_aspect_ratio=True, + fill_value=fill_value, + ) + self.assertEqual(out.shape, (25, 25, 3)) + + x = np.random.random((2, 50, 60, 3)).astype("float32") * 255 + out = kimage.resize( + x, size=(25, 25), pad_to_aspect_ratio=True, fill_value=fill_value + ) + self.assertEqual(out.shape, (2, 25, 25, 3)) + + # Test channels_first + backend.set_image_data_format("channels_first") + x = np.random.random((3, 60, 50)).astype("float32") * 255 + out = kimage.resize( + x, size=(25, 25), pad_to_aspect_ratio=True, fill_value=fill_value + ) + self.assertEqual(out.shape, (3, 25, 25)) + + x = np.random.random((2, 3, 50, 60)).astype("float32") * 255 + out = kimage.resize( + x, size=(25, 25), pad_to_aspect_ratio=True, fill_value=fill_value + ) + self.assertEqual(out.shape, (2, 3, 25, 25)) + + x = np.ones((2, 3, 10, 10)) * 128 + out = kimage.resize( + x, size=(4, 4), pad_to_aspect_ratio=True, fill_value=fill_value + ) + self.assertEqual(out.shape, (2, 3, 4, 4)) + self.assertAllClose(out[:, 0, :, :], np.ones((2, 4, 4)) * 128) + + x = np.ones((2, 3, 10, 8)) * 128 + out = kimage.resize( + x, size=(4, 4), pad_to_aspect_ratio=True, fill_value=fill_value + ) + self.assertEqual(out.shape, (2, 3, 4, 4)) + self.assertAllClose( + out, + np.concatenate( + [ + np.ones((2, 3, 4, 1)) * 96.25, + np.ones((2, 3, 4, 2)) * 128.0, + np.ones((2, 3, 4, 1)) * 96.25, + ], + axis=3, + ), + atol=1.0, + ) + + @parameterized.named_parameters( + named_product( + interpolation=["bilinear", "nearest"], + fill_mode=["constant", "nearest", "wrap", "mirror", "reflect"], + ) + ) + def test_affine_transform(self, interpolation, fill_mode): + if backend.backend() == "tensorflow" and fill_mode == "mirror": + self.skipTest( + "In tensorflow backend, applying affine_transform with " + "fill_mode=mirror is not supported" + ) + if backend.backend() == "tensorflow" and fill_mode == "wrap": + self.skipTest( + "In tensorflow backend, the numerical results of applying " + "affine_transform with fill_mode=wrap is inconsistent with" + "scipy" + ) + # TODO: `nearest` interpolation in jax and torch causes random index + # shifting, resulting in significant differences in output which leads + # to failure + if backend.backend() in ("jax", "torch") and interpolation == "nearest": + self.skipTest( + f"In {backend.backend()} backend, " + f"interpolation={interpolation} causes index shifting and " + "leads test failure" + ) + + # Test channels_last + np.random.seed(42) + x = np.random.uniform(size=(50, 50, 3)).astype("float32") * 255 + transform = np.random.uniform(size=(6)).astype("float32") + transform = np.pad(transform, (0, 2)) # makes c0, c1 always 0 + out = kimage.affine_transform( + x, transform, interpolation=interpolation, fill_mode=fill_mode + ) + coordinates = _compute_affine_transform_coordinates(x, transform) + ref_out = _fixed_map_coordinates( + x, + coordinates, + order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation], + fill_mode=fill_mode, + ) + self.assertEqual(tuple(out.shape), tuple(ref_out.shape)) + self.assertAllClose(ref_out, out, atol=1e-2) + + x = np.random.uniform(size=(2, 50, 50, 3)).astype("float32") * 255 + transform = np.random.uniform(size=(2, 6)).astype("float32") + transform = np.pad(transform, [(0, 0), (0, 2)]) # makes c0, c1 always 0 + out = kimage.affine_transform( + x, + transform, + interpolation=interpolation, + fill_mode=fill_mode, + ) + coordinates = _compute_affine_transform_coordinates(x, transform) + ref_out = np.stack( + [ + _fixed_map_coordinates( + x[i], + coordinates[i], + order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation], + fill_mode=fill_mode, + ) + for i in range(x.shape[0]) + ], + axis=0, + ) + self.assertEqual(tuple(out.shape), tuple(ref_out.shape)) + self.assertAllClose(ref_out, out, atol=1e-2) + + # Test channels_first + backend.set_image_data_format("channels_first") + x = np.random.uniform(size=(3, 50, 50)).astype("float32") * 255 + transform = np.random.uniform(size=(6)).astype("float32") + transform = np.pad(transform, (0, 2)) # makes c0, c1 always 0 + out = kimage.affine_transform( + x, transform, interpolation=interpolation, fill_mode=fill_mode + ) + coordinates = _compute_affine_transform_coordinates( + np.transpose(x, [1, 2, 0]), transform + ) + ref_out = _fixed_map_coordinates( + np.transpose(x, [1, 2, 0]), + coordinates, + order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation], + fill_mode=fill_mode, + ) + ref_out = np.transpose(ref_out, [2, 0, 1]) + self.assertEqual(tuple(out.shape), tuple(ref_out.shape)) + self.assertAllClose(ref_out, out, atol=1e-2) + + x = np.random.uniform(size=(2, 3, 50, 50)).astype("float32") * 255 + transform = np.random.uniform(size=(2, 6)).astype("float32") + transform = np.pad(transform, [(0, 0), (0, 2)]) # makes c0, c1 always 0 + out = kimage.affine_transform( + x, + transform, + interpolation=interpolation, + fill_mode=fill_mode, + ) + coordinates = _compute_affine_transform_coordinates( + np.transpose(x, [0, 2, 3, 1]), transform + ) + ref_out = np.stack( + [ + _fixed_map_coordinates( + np.transpose(x[i], [1, 2, 0]), + coordinates[i], + order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation], + fill_mode=fill_mode, + ) + for i in range(x.shape[0]) + ], + axis=0, + ) + ref_out = np.transpose(ref_out, [0, 3, 1, 2]) + self.assertEqual(tuple(out.shape), tuple(ref_out.shape)) + self.assertAllClose(ref_out, out, atol=1e-2) + + # Test class + out = kimage.AffineTransform( + interpolation=interpolation, fill_mode=fill_mode + )(x, transform) + self.assertAllClose(ref_out, out, atol=1e-2) + + @parameterized.named_parameters( + named_product( + size=[(3, 3), (5, 5)], + strides=[None, (1, 1), (2, 2)], + dilation_rate=[1, 3], + padding=["valid", "same"], + ) + ) + def test_extract_patches(self, size, strides, dilation_rate, padding): + patch_h, patch_w = size[0], size[1] + if strides is None: + strides_h, strides_w = patch_h, patch_w + else: + strides_h, strides_w = strides[0], strides[1] + if ( + backend.backend() == "tensorflow" + and strides_h > 1 + or strides_w > 1 + and dilation_rate > 1 + ): + pytest.skip("dilation_rate>1 with strides>1 not supported with TF") + + # Test channels_last + image = np.random.uniform(size=(1, 20, 20, 3)).astype("float32") + patches_out = kimage.extract_patches( + image, + size=size, + strides=strides, + dilation_rate=dilation_rate, + padding=padding, + ) + patches_ref = tf.image.extract_patches( + image, + sizes=(1, patch_h, patch_w, 1), + strides=(1, strides_h, strides_w, 1), + rates=(1, dilation_rate, dilation_rate, 1), + padding=padding.upper(), + ) + self.assertEqual(tuple(patches_out.shape), tuple(patches_ref.shape)) + self.assertAllClose(patches_ref, patches_out, atol=1e-2) + + # Test channels_first + if backend.backend() == "tensorflow": + # tensorflow doesn't support channels_first in + # `kimage.extract_patches` + return + backend.set_image_data_format("channels_first") + image = np.random.uniform(size=(1, 3, 20, 20)).astype("float32") + patches_out = kimage.extract_patches( + image, + size=size, + strides=strides, + dilation_rate=dilation_rate, + padding=padding, + ) + patches_ref = tf.image.extract_patches( + np.transpose(image, [0, 2, 3, 1]), + sizes=(1, patch_h, patch_w, 1), + strides=(1, strides_h, strides_w, 1), + rates=(1, dilation_rate, dilation_rate, 1), + padding=padding.upper(), + ) + patches_ref = tf.transpose(patches_ref, [0, 3, 1, 2]) + self.assertEqual(tuple(patches_out.shape), tuple(patches_ref.shape)) + self.assertAllClose(patches_ref, patches_out, atol=1e-2) + + # Test class + patches_out = kimage.ExtractPatches( + size=size, + strides=strides, + dilation_rate=dilation_rate, + padding=padding, + )(image) + self.assertAllClose(patches_ref, patches_out, atol=1e-2) + + @parameterized.named_parameters( + named_product( + # (input_shape, coordinates_shape) + shape=[((5,), (7,)), ((3, 4, 5), (2, 3, 4))], + # TODO: scipy.ndimage.map_coordinates does not support float16 + # TODO: torch cpu does not support round & floor for float16 + dtype=["uint8", "int32", "float32"], + order=[0, 1], + fill_mode=["constant", "nearest", "wrap", "mirror", "reflect"], + ) + ) + def test_map_coordinates(self, shape, dtype, order, fill_mode): + input_shape, coordinates_shape = shape + input = np.arange(math.prod(input_shape), dtype=dtype).reshape( + input_shape + ) + coordinates_dtype = "float32" if "int" in dtype else dtype + coordinates = [ + (size - 1) + * np.random.uniform(size=coordinates_shape).astype( + coordinates_dtype + ) + for size in input_shape + ] + output = kimage.map_coordinates(input, coordinates, order, fill_mode) + expected = _fixed_map_coordinates(input, coordinates, order, fill_mode) + self.assertAllClose(output, expected) + + # Test class + output = kimage.MapCoordinates(order, fill_mode)(input, coordinates) + self.assertAllClose(output, expected) + + @parameterized.parameters( + [ + (0, 0, 3, 3, None, None), + (1, 0, 4, 3, None, None), + (0, 1, 3, 4, None, None), + (0, 0, 4, 3, None, None), + (0, 0, 3, 4, None, None), + (0, 0, None, None, 0, 1), + (0, 0, None, None, 1, 0), + (1, 2, None, None, 3, 4), + ] + ) + def test_pad_images( + self, + top_padding, + left_padding, + target_height, + target_width, + bottom_padding, + right_padding, + ): + # Test channels_last + image = np.random.uniform(size=(3, 3, 1)).astype("float32") + _target_height = target_height # For `tf.image.pad_to_bounding_box` + _target_width = target_width # For `tf.image.pad_to_bounding_box` + if _target_height is None: + _target_height = image.shape[0] + top_padding + bottom_padding + if _target_width is None: + _target_width = image.shape[1] + left_padding + right_padding + padded_image = kimage.pad_images( + image, + top_padding, + left_padding, + bottom_padding, + right_padding, + target_height, + target_width, + ) + ref_padded_image = tf.image.pad_to_bounding_box( + image, top_padding, left_padding, _target_height, _target_width + ) + self.assertEqual( + tuple(padded_image.shape), tuple(ref_padded_image.shape) + ) + self.assertAllClose(ref_padded_image, padded_image) + + # Test channels_first + backend.set_image_data_format("channels_first") + image = np.random.uniform(size=(1, 3, 3)).astype("float32") + padded_image = kimage.pad_images( + image, + top_padding, + left_padding, + bottom_padding, + right_padding, + target_height, + target_width, + ) + ref_padded_image = tf.image.pad_to_bounding_box( + np.transpose(image, [1, 2, 0]), + top_padding, + left_padding, + _target_height, + _target_width, + ) + ref_padded_image = tf.transpose(ref_padded_image, [2, 0, 1]) + self.assertEqual( + tuple(padded_image.shape), tuple(ref_padded_image.shape) + ) + self.assertAllClose(ref_padded_image, padded_image) + + # Test class + padded_image = kimage.PadImages( + top_padding, + left_padding, + bottom_padding, + right_padding, + target_height, + target_width, + )(image) + self.assertAllClose(ref_padded_image, padded_image) + + @parameterized.parameters( + [ + (0, 0, 3, 3, None, None), + (1, 0, 4, 3, None, None), + (0, 1, 3, 4, None, None), + (0, 0, 4, 3, None, None), + (0, 0, 3, 4, None, None), + (0, 0, None, None, 0, 1), + (0, 0, None, None, 1, 0), + (1, 2, None, None, 3, 4), + ] + ) + def test_crop_images( + self, + top_cropping, + left_cropping, + target_height, + target_width, + bottom_cropping, + right_cropping, + ): + # Test channels_last + image = np.random.uniform(size=(10, 10, 1)).astype("float32") + _target_height = target_height # For `tf.image.pad_to_bounding_box` + _target_width = target_width # For `tf.image.pad_to_bounding_box` + if _target_height is None: + _target_height = image.shape[0] - top_cropping - bottom_cropping + if _target_width is None: + _target_width = image.shape[1] - left_cropping - right_cropping + cropped_image = kimage.crop_images( + image, + top_cropping, + left_cropping, + bottom_cropping, + right_cropping, + target_height, + target_width, + ) + ref_cropped_image = tf.image.crop_to_bounding_box( + image, top_cropping, left_cropping, _target_height, _target_width + ) + self.assertEqual( + tuple(cropped_image.shape), tuple(ref_cropped_image.shape) + ) + self.assertAllClose(ref_cropped_image, cropped_image) + + # Test channels_first + backend.set_image_data_format("channels_first") + image = np.random.uniform(size=(1, 10, 10)).astype("float32") + cropped_image = kimage.crop_images( + image, + top_cropping, + left_cropping, + bottom_cropping, + right_cropping, + target_height, + target_width, + ) + ref_cropped_image = tf.image.crop_to_bounding_box( + np.transpose(image, [1, 2, 0]), + top_cropping, + left_cropping, + _target_height, + _target_width, + ) + ref_cropped_image = tf.transpose(ref_cropped_image, [2, 0, 1]) + self.assertEqual( + tuple(cropped_image.shape), tuple(ref_cropped_image.shape) + ) + self.assertAllClose(ref_cropped_image, cropped_image) + + # Test class + cropped_image = kimage.CropImages( + top_cropping, + left_cropping, + bottom_cropping, + right_cropping, + target_height, + target_width, + )(image) + self.assertAllClose(ref_cropped_image, cropped_image) + + @parameterized.named_parameters( + named_product( + interpolation=["bilinear", "nearest"], + ) + ) + def test_perspective_transform(self, interpolation): + # Test channels_last + np.random.seed(42) + x = np.random.uniform(size=(50, 50, 3)).astype("float32") + start_points = np.random.uniform(size=(1, 4, 2)).astype("float32") + end_points = np.random.uniform(size=(1, 4, 2)).astype("float32") + + out = kimage.perspective_transform( + x, start_points, end_points, interpolation=interpolation + ) + + ref_out = _perspective_transform_numpy( + x, start_points, end_points, interpolation=interpolation + ) + + self.assertEqual(tuple(out.shape), tuple(ref_out.shape)) + self.assertAllClose(ref_out, out, atol=1e-2, rtol=1e-2) + + # Test channels_first + backend.set_image_data_format("channels_first") + x = np.random.uniform(size=(3, 50, 50)).astype("float32") + start_points = np.random.uniform(size=(1, 4, 2)).astype("float32") + end_points = np.random.uniform(size=(1, 4, 2)).astype("float32") + + out = kimage.perspective_transform( + x, start_points, end_points, interpolation=interpolation + ) + + ref_out = _perspective_transform_numpy( + x, + start_points, + end_points, + interpolation=interpolation, + data_format="channels_first", + ) + + self.assertEqual(tuple(out.shape), tuple(ref_out.shape)) + self.assertAllClose(ref_out, out, atol=1e-2, rtol=1e-2) + + def test_gaussian_blur(self): + # Test channels_last + backend.set_image_data_format("channels_last") + np.random.seed(42) + x = np.random.uniform(size=(50, 50, 3)).astype("float32") + kernel_size = np.array([3, 3]) + sigma = np.random.uniform(size=(2,)).astype("float32") + + out = kimage.gaussian_blur( + x, + kernel_size, + sigma, + data_format="channels_last", + ) + + ref_out = gaussian_blur_np( + x, + kernel_size, + sigma, + data_format="channels_last", + ) + + self.assertEqual(tuple(out.shape), tuple(ref_out.shape)) + self.assertAllClose(ref_out, out, atol=1e-2, rtol=1e-2) + + # Test channels_first + backend.set_image_data_format("channels_first") + x = np.random.uniform(size=(3, 50, 50)).astype("float32") + kernel_size = np.array([3, 3]) + sigma = np.random.uniform(size=(2,)).astype("float32") + + out = kimage.gaussian_blur( + x, + kernel_size, + sigma, + data_format="channels_first", + ) + + ref_out = gaussian_blur_np( + x, + kernel_size, + sigma, + data_format="channels_first", + ) + + self.assertEqual(tuple(out.shape), tuple(ref_out.shape)) + self.assertAllClose(ref_out, out, atol=1e-2, rtol=1e-2) + + def test_elastic_transform(self): + # Test channels_last + backend.set_image_data_format("channels_last") + np.random.seed(42) + x = np.random.uniform(size=(50, 50, 3)).astype("float32") + alpha, sigma, seed = 20.0, 5.0, 42 + + out = kimage.elastic_transform( + x, + alpha=alpha, + sigma=sigma, + seed=seed, + data_format="channels_last", + ) + + ref_out = elastic_transform_np( + x, + alpha=alpha, + sigma=sigma, + seed=seed, + data_format="channels_last", + ) + + out = backend.convert_to_numpy(out) + + self.assertEqual(tuple(out.shape), tuple(ref_out.shape)) + self.assertAllClose( + np.mean(ref_out), np.mean(out), atol=1e-2, rtol=1e-2 + ) + self.assertAllClose(np.var(ref_out), np.var(out), atol=1e-2, rtol=1e-2) + + # Test channels_first + backend.set_image_data_format("channels_first") + x = np.random.uniform(size=(3, 50, 50)).astype("float32") + alpha, sigma, seed = 20.0, 5.0, 42 + + ref_out = elastic_transform_np( + x, + alpha=alpha, + sigma=sigma, + seed=seed, + data_format="channels_first", + ) + + out = kimage.elastic_transform( + x, + alpha=alpha, + sigma=sigma, + seed=seed, + data_format="channels_first", + ) + out = backend.convert_to_numpy(out) + + self.assertEqual(tuple(out.shape), tuple(ref_out.shape)) + self.assertAllClose( + np.mean(ref_out), np.mean(out), atol=1e-2, rtol=1e-2 + ) + self.assertAllClose(np.var(ref_out), np.var(out), atol=1e-2, rtol=1e-2) + + def test_map_coordinates_constant_padding(self): + input_img = tf.ones((2, 2), dtype=tf.uint8) + # one pixel outside of the input space around the edges + grid = tf.stack( + tf.meshgrid( + tf.range(-1, 3, dtype=tf.float32), + tf.range(-1, 3, dtype=tf.float32), + indexing="ij", + ), + axis=0, + ) + out = backend.convert_to_numpy( + kimage.map_coordinates( + input_img, grid, order=0, fill_mode="constant", fill_value=0 + ) + ) + + # check for ones in the middle and zeros around the edges + self.assertTrue(np.all(out[:1] == 0)) + self.assertTrue(np.all(out[-1:] == 0)) + self.assertTrue(np.all(out[:, :1] == 0)) + self.assertTrue(np.all(out[:, -1:] == 0)) + self.assertTrue(np.all(out[1:3, 1:3] == 1)) + + @parameterized.named_parameters( + named_product( + method=["linear", "cubic", "lanczos3", "lanczos5"], + antialias=[True, False], + ) + ) + def test_scale_and_translate(self, method, antialias): + images = np.random.random((30, 30, 3)).astype("float32") * 255 + scale = np.array([2.0, 2.0]).astype("float32") + translation = -(scale / 2.0 - 0.5) + out = kimage.scale_and_translate( + images, + output_shape=(15, 15, 3), + scale=scale, + translation=translation, + spatial_dims=(0, 1), + method=method, + antialias=antialias, + ) + ref_out = jax.image.scale_and_translate( + images, + shape=(15, 15, 3), + spatial_dims=(0, 1), + scale=scale, + translation=translation, + method=method, + antialias=antialias, + ) + self.assertEqual(tuple(out.shape), tuple(ref_out.shape)) + self.assertAllClose(ref_out, out, atol=1e-4) + + +class ImageOpsDtypeTest(testing.TestCase): + """Test the dtype to verify that the behavior matches JAX.""" + + INT_DTYPES = [x for x in dtypes.INT_TYPES if x not in ("uint64", "int64")] + FLOAT_DTYPES = [x for x in dtypes.FLOAT_TYPES if x not in ("float64",)] + + if backend.backend() == "torch": + INT_DTYPES = [x for x in INT_DTYPES if x not in ("uint16", "uint32")] + + def setUp(self): + # Defaults to channels_last + self.data_format = backend.image_data_format() + backend.set_image_data_format("channels_last") + return super().setUp() + + def tearDown(self): + backend.set_image_data_format(self.data_format) + return super().tearDown() + + @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES)) + def test_affine_transform(self, dtype): + images = knp.ones((50, 50, 3), dtype=dtype) + transform = knp.ones((8,), dtype=dtype) + expected_dtype = dtype + + self.assertDType( + kimage.affine_transform(images, transform), expected_dtype + ) + self.assertDType( + kimage.AffineTransform().symbolic_call(images, transform), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES)) + def test_crop_images(self, dtype): + images = knp.ones((10, 10, 3), dtype=dtype) + expected_dtype = dtype + + self.assertDType(kimage.crop_images(images, 0, 0, 3, 3), expected_dtype) + self.assertDType( + kimage.CropImages(0, 0, 3, 3).symbolic_call(images), expected_dtype + ) + + @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES)) + def test_elastic_transform(self, dtype): + images = knp.ones((10, 10, 3), dtype=dtype) + expected_dtype = dtype + + self.assertDType(kimage.elastic_transform(images), expected_dtype) + self.assertDType( + kimage.ElasticTransform().symbolic_call(images), expected_dtype + ) + + @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES)) + def test_extract_patches(self, dtype): + images = knp.ones((10, 10, 3), dtype=dtype) + expected_dtype = dtype + + self.assertDType(kimage.extract_patches(images, (3, 3)), expected_dtype) + self.assertDType( + kimage.ExtractPatches((3, 3)).symbolic_call(images), expected_dtype + ) + + @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES)) + def test_gaussian_blur(self, dtype): + images = knp.ones((10, 10, 3), dtype=dtype) + expected_dtype = dtype + + self.assertDType(kimage.gaussian_blur(images), expected_dtype) + self.assertDType( + kimage.GaussianBlur().symbolic_call(images), expected_dtype + ) + + @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES)) + def test_hsv_to_rgb(self, dtype): + images = knp.ones((10, 10, 3), dtype=dtype) + expected_dtype = dtype + + self.assertDType(kimage.hsv_to_rgb(images), expected_dtype) + self.assertDType( + kimage.HSVToRGB().symbolic_call(images), expected_dtype + ) + + @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES)) + def test_map_coordinates(self, dtype): + inputs = knp.ones((3, 4, 5), dtype=dtype) + coordinates = knp.stack([knp.ones((2, 3, 4), dtype=dtype)] * 3) + expected_dtype = dtype + + self.assertDType( + kimage.map_coordinates(inputs, coordinates, 0), expected_dtype + ) + self.assertDType( + kimage.MapCoordinates(0).symbolic_call(inputs, coordinates), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES)) + def test_pad_images(self, dtype): + images = knp.ones((10, 10, 3), dtype=dtype) + expected_dtype = dtype + + self.assertDType(kimage.pad_images(images, 0, 0, 3, 3), expected_dtype) + self.assertDType( + kimage.PadImages(0, 0, 3, 3).symbolic_call(images), expected_dtype + ) + + @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES)) + def test_perspective_transform(self, dtype): + images = knp.ones((10, 10, 3), dtype=dtype) + start_points = krandom.uniform((1, 4, 2), dtype=dtype) + end_points = krandom.uniform((1, 4, 2), dtype=dtype) + expected_dtype = dtype + + self.assertDType( + kimage.perspective_transform(images, start_points, end_points), + expected_dtype, + ) + self.assertDType( + kimage.PerspectiveTransform().symbolic_call( + images, start_points, end_points + ), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES)) + def test_resize(self, dtype): + images = knp.ones((10, 10, 3), dtype=dtype) + expected_dtype = dtype + + self.assertDType(kimage.resize(images, (5, 5)), expected_dtype) + self.assertDType( + kimage.Resize((5, 5)).symbolic_call(images), expected_dtype + ) + + @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES)) + def test_rgb_to_grayscale(self, dtype): + images = knp.ones((10, 10, 3), dtype=dtype) + expected_dtype = dtype + + self.assertDType(kimage.rgb_to_grayscale(images), expected_dtype) + self.assertDType( + kimage.RGBToGrayscale().symbolic_call(images), expected_dtype + ) + + @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES)) + def test_rgb_to_hsv(self, dtype): + images = knp.ones((10, 10, 3), dtype=dtype) + expected_dtype = dtype + + self.assertDType(kimage.rgb_to_hsv(images), expected_dtype) + self.assertDType( + kimage.RGBToHSV().symbolic_call(images), expected_dtype + ) + + @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES)) + def test_scale_and_translate(self, dtype): + images = knp.ones((10, 10, 3), dtype=dtype) + scale = knp.ones((2,), dtype=dtype) + translation = knp.ones((2,), dtype=dtype) + expected_dtype = dtype + + self.assertDType( + kimage.scale_and_translate( + images, + output_shape=(15, 15, 3), + scale=scale, + translation=translation, + spatial_dims=(0, 1), + method="linear", + ), + expected_dtype, + ) + self.assertDType( + kimage.ScaleAndTranslate( + spatial_dims=(0, 1), method="linear" + ).symbolic_call(images, (15, 15, 3), scale, translation), + expected_dtype, + ) + + +class ImageOpsBehaviorTests(testing.TestCase): + def setUp(self): + # Defaults to channels_last + self.data_format = backend.image_data_format() + backend.set_image_data_format("channels_last") + return super().setUp() + + def tearDown(self): + backend.set_image_data_format(self.data_format) + return super().tearDown() + + @parameterized.named_parameters(named_product(rank=[2, 5])) + def test_rgb_to_grayscale_invalid_rank(self, rank): + shape = [3] * rank + invalid_image = np.random.uniform(size=shape) + with self.assertRaisesRegex( + ValueError, + "Invalid images rank: expected rank 3", + ): + kimage.rgb_to_grayscale(invalid_image) + with self.assertRaisesRegex( + ValueError, + "Invalid images rank: expected rank 3", + ): + kimage.RGBToGrayscale()(invalid_image) + invalid_image = KerasTensor(shape=shape) + with self.assertRaisesRegex( + ValueError, + "Invalid images rank: expected rank 3", + ): + kimage.rgb_to_grayscale(invalid_image) + + @parameterized.named_parameters(named_product(rank=[2, 5])) + def test_rgb_to_hsv_invalid_rank(self, rank): + shape = [3] * rank + invalid_image = np.random.uniform(size=shape) + with self.assertRaisesRegex( + ValueError, "Invalid images rank: expected rank 3" + ): + kimage.rgb_to_hsv(invalid_image) + with self.assertRaisesRegex( + ValueError, "Invalid images rank: expected rank 3" + ): + kimage.RGBToHSV()(invalid_image) + invalid_image = KerasTensor(shape=shape) + with self.assertRaisesRegex( + ValueError, "Invalid images rank: expected rank 3" + ): + kimage.rgb_to_hsv(invalid_image) + + def test_rgb_to_hsv_invalid_dtype(self): + invalid_image = np.random.uniform(size=(10, 10, 3)).astype("int32") + with self.assertRaisesRegex( + ValueError, "Invalid images dtype: expected float dtype." + ): + kimage.rgb_to_hsv(invalid_image) + with self.assertRaisesRegex( + ValueError, "Invalid images dtype: expected float dtype." + ): + kimage.RGBToHSV()(invalid_image) + invalid_image = KerasTensor(shape=(10, 10, 3), dtype="int32") + with self.assertRaisesRegex( + ValueError, "Invalid images dtype: expected float dtype." + ): + kimage.rgb_to_hsv(invalid_image) + + @parameterized.named_parameters(named_product(rank=[2, 5])) + def test_hsv_to_rgb_invalid_rank(self, rank): + shape = [3] * rank + invalid_image = np.random.uniform(size=shape) + with self.assertRaisesRegex( + ValueError, "Invalid images rank: expected rank 3" + ): + kimage.hsv_to_rgb(invalid_image) + with self.assertRaisesRegex( + ValueError, "Invalid images rank: expected rank 3" + ): + kimage.HSVToRGB()(invalid_image) + invalid_image = KerasTensor(shape=shape) + with self.assertRaisesRegex( + ValueError, "Invalid images rank: expected rank 3" + ): + kimage.hsv_to_rgb(invalid_image) + + def test_hsv_to_rgb_invalid_dtype(self): + invalid_image = np.random.uniform(size=(10, 10, 3)).astype("int32") + with self.assertRaisesRegex( + ValueError, "Invalid images dtype: expected float dtype." + ): + kimage.hsv_to_rgb(invalid_image) + with self.assertRaisesRegex( + ValueError, "Invalid images dtype: expected float dtype." + ): + kimage.HSVToRGB()(invalid_image) + invalid_image = KerasTensor(shape=(10, 10, 3), dtype="int32") + with self.assertRaisesRegex( + ValueError, "Invalid images dtype: expected float dtype." + ): + kimage.hsv_to_rgb(invalid_image) + + def test_resize_invalid_rank(self): + # Test rank=2 + invalid_image = np.random.uniform(size=(10, 10)) + with self.assertRaisesRegex( + ValueError, "Invalid images rank: expected rank 3" + ): + kimage.resize(invalid_image, (5, 5)) + with self.assertRaisesRegex( + ValueError, "Invalid images rank: expected rank 3" + ): + kimage.Resize((5, 5))(invalid_image) + + # Test rank=2, symbolic tensor + invalid_image = KerasTensor(shape=(10, 10)) + with self.assertRaisesRegex( + ValueError, "Invalid images rank: expected rank 3" + ): + kimage.resize(invalid_image, (5, 5)) + + def test_affine_transform_invalid_images_rank(self): + # Test rank=2 + invalid_image = np.random.uniform(size=(10, 10)) + transform = np.random.uniform(size=(6,)) + with self.assertRaisesRegex( + ValueError, "Invalid images rank: expected rank 3" + ): + kimage.affine_transform(invalid_image, transform) + with self.assertRaisesRegex( + ValueError, "Invalid images rank: expected rank 3" + ): + kimage.AffineTransform()(invalid_image, transform) + + # Test rank=5 + invalid_image = np.random.uniform(size=(2, 10, 10, 3, 1)) + transform = np.random.uniform(size=(6,)) + with self.assertRaisesRegex( + ValueError, "Invalid images rank: expected rank 3" + ): + kimage.affine_transform(invalid_image, transform) + with self.assertRaisesRegex( + ValueError, "Invalid images rank: expected rank 3" + ): + kimage.AffineTransform()(invalid_image, transform) + + # Test rank=2, symbolic tensor + invalid_image = KerasTensor(shape=(10, 10)) + transform = KerasTensor(shape=(6,)) + with self.assertRaisesRegex( + ValueError, "Invalid images rank: expected rank 3" + ): + kimage.affine_transform(invalid_image, transform) + + def test_affine_transform_invalid_transform_rank(self): + # Test rank=3 + images = np.random.uniform(size=(10, 10, 3)) + invalid_transform = np.random.uniform(size=(2, 3, 2)) + with self.assertRaisesRegex( + ValueError, "Invalid transform rank: expected rank 1" + ): + kimage.affine_transform(images, invalid_transform) + with self.assertRaisesRegex( + ValueError, "Invalid transform rank: expected rank 1" + ): + kimage.AffineTransform()(images, invalid_transform) + + # Test rank=0 + invalid_transform = np.random.uniform(size=()) + with self.assertRaisesRegex( + ValueError, "Invalid transform rank: expected rank 1" + ): + kimage.affine_transform(images, invalid_transform) + with self.assertRaisesRegex( + ValueError, "Invalid transform rank: expected rank 1" + ): + kimage.AffineTransform()(images, invalid_transform) + + # Test rank=3, symbolic tensor + images = KerasTensor(shape=(10, 10, 3)) + invalid_transform = KerasTensor(shape=(2, 3, 2)) + with self.assertRaisesRegex( + ValueError, "Invalid transform rank: expected rank 1" + ): + kimage.affine_transform(images, invalid_transform) + + def test_extract_patches_invalid_size(self): + size = (3, 3, 3) # Invalid size, too many dimensions + image = np.random.uniform(size=(2, 20, 20, 3)) + with self.assertRaisesRegex( + TypeError, "Expected an int or a tuple of length 2" + ): + kimage.extract_patches(image, size) + + size = "5" # Invalid size type + with self.assertRaisesRegex( + TypeError, "Expected an int or a tuple of length 2" + ): + kimage.extract_patches(image, size) + + def test_map_coordinates_invalid_coordinates_rank(self): + # Test mismatched dim of coordinates + image = np.random.uniform(size=(10, 10, 3)) + coordinates = np.random.uniform(size=(2, 10, 10)) + with self.assertRaisesRegex( + ValueError, "must be the same as the rank of `inputs`" + ): + kimage.map_coordinates(image, coordinates, 0) + with self.assertRaisesRegex( + ValueError, "must be the same as the rank of `inputs`" + ): + kimage.MapCoordinates(0)(image, coordinates) + + # Test rank=1 + coordinates = np.random.uniform(size=(3,)) + with self.assertRaisesRegex(ValueError, "expected at least rank 2"): + kimage.map_coordinates(image, coordinates, 0) + with self.assertRaisesRegex(ValueError, "expected at least rank 2"): + kimage.MapCoordinates(0)(image, coordinates) + + def test_crop_images_unknown_shape(self): + # Test unknown height and target_height + x = KerasTensor([None, 10, 3]) + with self.assertRaisesRegex( + ValueError, "When the height of the images is unknown" + ): + kimage.crop_images(x, 2, 3, 4, 5) + + # Test unknown width and target_width + x = KerasTensor([10, None, 3]) + with self.assertRaisesRegex( + ValueError, "When the width of the images is unknown" + ): + kimage.crop_images(x, 2, 3, 4, 5) + + def test_perspective_transform_invalid_images_rank(self): + # Test rank=2 + invalid_image = np.random.uniform(size=(10, 10)) + start_points = np.random.uniform(size=(6,)) + end_points = np.random.uniform(size=(6,)) + with self.assertRaisesRegex( + ValueError, "Invalid images rank: expected rank 3" + ): + kimage.perspective_transform( + invalid_image, start_points, end_points + ) + with self.assertRaisesRegex( + ValueError, "Invalid images rank: expected rank 3" + ): + kimage.PerspectiveTransform()( + invalid_image, start_points, end_points + ) + + # Test rank=5 + invalid_image = np.random.uniform(size=(2, 10, 10, 3, 1)) + start_points = np.random.uniform(size=(6,)) + end_points = np.random.uniform(size=(6,)) + with self.assertRaisesRegex( + ValueError, "Invalid images rank: expected rank 3" + ): + kimage.perspective_transform( + invalid_image, start_points, end_points + ) + with self.assertRaisesRegex( + ValueError, "Invalid images rank: expected rank 3" + ): + kimage.PerspectiveTransform()( + invalid_image, start_points, end_points + ) + + # Test rank=2, symbolic tensor + invalid_image = KerasTensor(shape=(10, 10)) + start_points = KerasTensor(shape=(6,)) + end_points = np.random.uniform(size=(6,)) + with self.assertRaisesRegex( + ValueError, "Invalid images rank: expected rank 3" + ): + kimage.perspective_transform( + invalid_image, start_points, end_points + ) + + def test_perspective_transform_invalid_points_rank(self): + # Test rank=3 + images = np.random.uniform(size=(10, 10, 3)) + start_points = np.random.uniform(size=(2, 2, 4, 2)) + end_points = np.random.uniform(size=(2, 2, 4, 2)) + with self.assertRaisesRegex( + ValueError, "Invalid start_points shape: expected" + ): + kimage.perspective_transform(images, start_points, end_points) + with self.assertRaisesRegex( + ValueError, "Invalid start_points shape: expected" + ): + kimage.PerspectiveTransform()(images, start_points, end_points) + + # Test rank=0 + start_points = np.random.uniform(size=()) + end_points = np.random.uniform(size=()) + with self.assertRaisesRegex( + ValueError, "Invalid start_points shape: expected" + ): + kimage.perspective_transform(images, start_points, end_points) + with self.assertRaisesRegex( + ValueError, "Invalid start_points shape: expected" + ): + kimage.PerspectiveTransform()(images, start_points, end_points) + + # Test rank=3, symbolic tensor + images = KerasTensor(shape=(10, 10, 3)) + start_points = KerasTensor(shape=(2, 3, 2)) + with self.assertRaisesRegex( + ValueError, "Invalid start_points shape: expected" + ): + kimage.perspective_transform(images, start_points, end_points) + + def test_gaussian_blur_invalid_images_rank(self): + # Test rank=2 + invalid_image = np.random.uniform(size=(10, 10)) + kernel_size = (3, 3) + sigma = (0.1, 0.1) + with self.assertRaisesRegex( + ValueError, "Invalid images rank: expected rank 3" + ): + kimage.gaussian_blur( + invalid_image, kernel_size=kernel_size, sigma=sigma + ) + with self.assertRaisesRegex( + ValueError, "Invalid images rank: expected rank 3" + ): + kimage.GaussianBlur(kernel_size=kernel_size, sigma=sigma)( + invalid_image + ) + + # Test rank=5 + invalid_image = np.random.uniform(size=(2, 10, 10, 3, 1)) + with self.assertRaisesRegex( + ValueError, "Invalid images rank: expected rank 3" + ): + kimage.gaussian_blur( + invalid_image, kernel_size=kernel_size, sigma=sigma + ) + with self.assertRaisesRegex( + ValueError, "Invalid images rank: expected rank 3" + ): + kimage.GaussianBlur(kernel_size=kernel_size, sigma=sigma)( + invalid_image + ) + + # Test rank=2, symbolic tensor + invalid_image = KerasTensor(shape=(10, 10)) + with self.assertRaisesRegex( + ValueError, "Invalid images rank: expected rank 3" + ): + kimage.gaussian_blur( + invalid_image, kernel_size=kernel_size, sigma=sigma + ) + + def test_elastic_transform_invalid_images_rank(self): + # Test rank=2 + invalid_image = np.random.uniform(size=(10, 10)) + with self.assertRaisesRegex( + ValueError, "Invalid images rank: expected rank 3" + ): + kimage.elastic_transform( + invalid_image, + ) + with self.assertRaisesRegex( + ValueError, "Invalid images rank: expected rank 3" + ): + kimage.ElasticTransform()(invalid_image) + + # Test rank=5 + invalid_image = np.random.uniform(size=(2, 10, 10, 3, 1)) + with self.assertRaisesRegex( + ValueError, "Invalid images rank: expected rank 3" + ): + kimage.elastic_transform(invalid_image) + with self.assertRaisesRegex( + ValueError, "Invalid images rank: expected rank 3" + ): + kimage.ElasticTransform()(invalid_image) + + # Test rank=2, symbolic tensor + invalid_image = KerasTensor(shape=(10, 10)) + with self.assertRaisesRegex( + ValueError, "Invalid images rank: expected rank 3" + ): + kimage.elastic_transform(invalid_image) diff --git a/keras/src/ops/linalg.py b/keras/src/ops/linalg.py new file mode 100644 index 000000000000..dee781f49852 --- /dev/null +++ b/keras/src/ops/linalg.py @@ -0,0 +1,827 @@ +from keras.src import backend +from keras.src import tree +from keras.src.api_export import keras_export +from keras.src.backend import KerasTensor +from keras.src.backend import any_symbolic_tensors +from keras.src.ops.operation import Operation +from keras.src.ops.operation_utils import reduce_shape + + +class Cholesky(Operation): + def __init__(self, upper=False, *, name=None): + super().__init__(name=name) + self.upper = upper + + def call(self, x): + return _cholesky(x, self.upper) + + def compute_output_spec(self, x): + _assert_2d(x) + _assert_square(x) + return KerasTensor(x.shape, x.dtype) + + +@keras_export(["keras.ops.cholesky", "keras.ops.linalg.cholesky"]) +def cholesky(x, upper=False): + """Computes the Cholesky decomposition of a positive semi-definite matrix. + + Args: + x: Input tensor of shape `(..., M, M)`. + upper (bool): If True, returns the upper-triangular Cholesky factor. + If False (default), returns the lower-triangular Cholesky factor. + + Returns: + A tensor of shape `(..., M, M)` representing the Cholesky factor of `x`. + """ + if any_symbolic_tensors((x,)): + return Cholesky(upper=upper).symbolic_call(x) + return _cholesky(x, upper=upper) + + +def _cholesky(x, upper=False): + x = backend.convert_to_tensor(x) + _assert_2d(x) + _assert_square(x) + try: + return backend.linalg.cholesky(x, upper=upper) + except Exception as e: + raise ValueError(f"Cholesky decomposition failed: {e}") + + +class CholeskyInverse(Operation): + def __init__(self, upper=False, *, name=None): + super().__init__(name=name) + self.upper = upper + + def call(self, x): + return _cholesky_inverse(x, self.upper) + + def compute_output_spec(self, x): + _assert_2d(x) + _assert_square(x) + return KerasTensor(x.shape, x.dtype) + + +@keras_export( + ["keras.ops.cholesky_inverse", "keras.ops.linalg.cholesky_inverse"] +) +def cholesky_inverse(x, upper=False): + """Computes the inverse of a symmetric positive-definite matrix. + + Args: + x: Input tensor of shape `(..., M, M)`. + upper (bool): Determines whether to use the upper- or lower-triangular + factor for the internal computation. Defaults to False. + + Returns: + A tensor of shape `(..., M, M)` representing the inverse of `x`. + + Raises: + ValueError: If `x` is not a symmetric positive-definite matrix. + """ + if any_symbolic_tensors((x,)): + return CholeskyInverse(upper=upper).symbolic_call(x) + return _cholesky_inverse(x, upper=upper) + + +def _cholesky_inverse(x, upper=False): + x = backend.convert_to_tensor(x) + _assert_2d(x) + _assert_square(x) + try: + return backend.linalg.cholesky_inverse(x, upper=upper) + except Exception as e: + raise ValueError(f"Cholesky inverse failed: {e}") + + +class Det(Operation): + def call(self, x): + return _det(x) + + def compute_output_spec(self, x): + _assert_2d(x) + _assert_square(x) + return KerasTensor(x.shape[:-2], x.dtype) + + +@keras_export(["keras.ops.det", "keras.ops.linalg.det"]) +def det(x): + """Computes the determinant of a square tensor. + + Args: + x: Input tensor of shape `(..., M, M)`. + + Returns: + A tensor of shape `(...,)` representing the determinant of `x`. + + """ + if any_symbolic_tensors((x,)): + return Det().symbolic_call(x) + return _det(x) + + +def _det(x): + x = backend.convert_to_tensor(x) + _assert_2d(x) + _assert_square(x) + return backend.linalg.det(x) + + +class Eig(Operation): + def call(self, x): + return _eig(x) + + def compute_output_spec(self, x): + _assert_square(x) + _assert_2d(x) + return ( + KerasTensor(x.shape[:-1], x.dtype), + KerasTensor(x.shape, x.dtype), + ) + + +@keras_export(["keras.ops.eig", "keras.ops.linalg.eig"]) +def eig(x): + """Computes the eigenvalues and eigenvectors of a square matrix. + + Args: + x: Input tensor of shape `(..., M, M)`. + + Returns: + A tuple of two tensors: a tensor of shape `(..., M)` containing + eigenvalues and a tensor of shape `(..., M, M)` containing eigenvectors. + """ + if any_symbolic_tensors((x,)): + return Eig().symbolic_call(x) + return _eig(x) + + +def _eig(x): + x = backend.convert_to_tensor(x) + _assert_square(x) + _assert_2d(x) + return backend.linalg.eig(x) + + +class Eigh(Operation): + def call(self, x): + return _eigh(x) + + def compute_output_spec(self, x): + _assert_square(x) + _assert_2d(x) + return ( + KerasTensor(x.shape[:-1], x.dtype), + KerasTensor(x.shape, x.dtype), + ) + + +@keras_export(["keras.ops.eigh", "keras.ops.linalg.eigh"]) +def eigh(x): + """Computes the eigenvalues and eigenvectors of a complex Hermitian. + + Args: + x: Input tensor of shape `(..., M, M)`. + + Returns: + A tuple of two tensors: a tensor of shape `(..., M)` containing + eigenvalues and a tensor of shape `(..., M, M)` containing eigenvectors. + + """ + if any_symbolic_tensors((x,)): + return Eigh().symbolic_call(x) + return _eigh(x) + + +def _eigh(x): + x = backend.convert_to_tensor(x) + _assert_square(x) + _assert_2d(x) + return backend.linalg.eigh(x) + + +class Inv(Operation): + def call(self, x): + return _inv(x) + + def compute_output_spec(self, x): + _assert_2d(x) + _assert_square(x) + return KerasTensor(x.shape, x.dtype) + + +@keras_export(["keras.ops.inv", "keras.ops.linalg.inv"]) +def inv(x): + """Computes the inverse of a square tensor. + + Args: + x: Input tensor of shape `(..., M, M)`. + + Returns: + A tensor of shape `(..., M, M)` representing the inverse of `x`. + + """ + if any_symbolic_tensors((x,)): + return Inv().symbolic_call(x) + return _inv(x) + + +def _inv(x): + x = backend.convert_to_tensor(x) + _assert_2d(x) + _assert_square(x) + return backend.linalg.inv(x) + + +class LuFactor(Operation): + def call(self, x): + return _lu_factor(x) + + def compute_output_spec(self, x): + _assert_2d(x) + batch_shape = x.shape[:-2] + m, n = x.shape[-2:] + k = min(m, n) + return ( + KerasTensor(batch_shape + (m, n), x.dtype), + KerasTensor(batch_shape + (k,), x.dtype), + ) + + +@keras_export(["keras.ops.lu_factor", "keras.ops.linalg.lu_factor"]) +def lu_factor(x): + """Computes the lower-upper decomposition of a square matrix. + + Args: + x: A tensor of shape `(..., M, M)`. + + Returns: + A tuple of two tensors: a tensor of shape `(..., M, M)` containing the + lower and upper triangular matrices and a tensor of shape `(..., M)` + containing the pivots. + + """ + if any_symbolic_tensors((x,)): + return LuFactor().symbolic_call(x) + return _lu_factor(x) + + +def _lu_factor(x): + x = backend.convert_to_tensor(x) + _assert_2d(x) + if backend.backend() == "tensorflow": + try: + _assert_square(x) + except ValueError as e: + raise ValueError( + f"LU decomposition failed: {e}. LU decomposition is only " + "supported for square matrices in Tensorflow." + ) + return backend.linalg.lu_factor(x) + + +class Norm(Operation): + def __init__(self, ord=None, axis=None, keepdims=False, *, name=None): + super().__init__(name=name) + if isinstance(ord, str): + if ord not in ("fro", "nuc"): + raise ValueError( + "Invalid `ord` argument. " + "Expected one of {'fro', 'nuc'} when using string. " + f"Received: ord={ord}" + ) + if isinstance(axis, int): + axis = [axis] + self.ord = ord + self.axis = axis + self.keepdims = keepdims + + def compute_output_spec(self, x): + output_dtype = backend.standardize_dtype(x.dtype) + if "int" in output_dtype or output_dtype == "bool": + output_dtype = backend.floatx() + if self.axis is None: + axis = tuple(range(len(x.shape))) + else: + axis = self.axis + num_axes = len(axis) + if num_axes == 1 and isinstance(self.ord, str): + raise ValueError( + "Invalid `ord` argument for vector norm. " + f"Received: ord={self.ord}" + ) + elif num_axes == 2 and self.ord not in ( + None, + "fro", + "nuc", + float("inf"), + float("-inf"), + 1, + -1, + 2, + -2, + ): + raise ValueError( + "Invalid `ord` argument for matrix norm. " + f"Received: ord={self.ord}" + ) + return KerasTensor( + reduce_shape(x.shape, axis=self.axis, keepdims=self.keepdims), + dtype=output_dtype, + ) + + def call(self, x): + x = backend.convert_to_tensor(x) + return backend.linalg.norm( + x, ord=self.ord, axis=self.axis, keepdims=self.keepdims + ) + + +@keras_export(["keras.ops.norm", "keras.ops.linalg.norm"]) +def norm(x, ord=None, axis=None, keepdims=False): + """Matrix or vector norm. + + This function is able to return one of eight different matrix norms, or one + of an infinite number of vector norms (described below), depending on the + value of the `ord` parameter. + + Args: + x: Input tensor. + ord: Order of the norm (see table under Notes). The default is `None`. + axis: If `axis` is an integer, it specifies the axis of `x` along which + to compute the vector norms. If `axis` is a 2-tuple, it specifies + the axes that hold 2-D matrices, and the matrix norms of these + matrices are computed. + keepdims: If this is set to `True`, the axes which are reduced are left + in the result as dimensions with size one. + + Note: + For values of `ord < 1`, the result is, strictly speaking, not a + mathematical 'norm', but it may still be useful for various numerical + purposes. The following norms can be calculated: + - For matrices: + - `ord=None`: Frobenius norm + - `ord="fro"`: Frobenius norm + - `ord="nuc"`: nuclear norm + - `ord=np.inf`: `max(sum(abs(x), axis=1))` + - `ord=-np.inf`: `min(sum(abs(x), axis=1))` + - `ord=0`: not supported + - `ord=1`: `max(sum(abs(x), axis=0))` + - `ord=-1`: `min(sum(abs(x), axis=0))` + - `ord=2`: 2-norm (largest sing. value) + - `ord=-2`: smallest singular value + - other: not supported + - For vectors: + - `ord=None`: 2-norm + - `ord="fro"`: not supported + - `ord="nuc"`: not supported + - `ord=np.inf`: `max(abs(x))` + - `ord=-np.inf`: `min(abs(x))` + - `ord=0`: `sum(x != 0)` + - `ord=1`: as below + - `ord=-1`: as below + - `ord=2`: as below + - `ord=-2`: as below + - other: `sum(abs(x)**ord)**(1./ord)` + + Returns: + Norm of the matrix or vector(s). + + Example: + + >>> x = keras.ops.reshape(keras.ops.arange(9, dtype="float32") - 4, (3, 3)) + >>> keras.ops.linalg.norm(x) + 7.7459664 + """ + if any_symbolic_tensors((x,)): + return Norm(ord=ord, axis=axis, keepdims=keepdims).symbolic_call(x) + x = backend.convert_to_tensor(x) + return backend.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims) + + +class Qr(Operation): + def __init__(self, mode="reduced", *, name=None): + super().__init__(name=name) + if mode not in {"reduced", "complete"}: + raise ValueError( + "`mode` argument value not supported. " + "Expected one of {'reduced', 'complete'}. " + f"Received: mode={mode}" + ) + self.mode = mode + + def compute_output_spec(self, x): + if len(x.shape) < 2: + raise ValueError( + "Input should have rank >= 2. Received: " + f"input.shape = {x.shape}" + ) + m = x.shape[-2] + n = x.shape[-1] + if m is None or n is None: + raise ValueError( + "Input should have its last 2 dimensions " + "fully-defined. Received: " + f"input.shape = {x.shape}" + ) + k = min(m, n) + base = tuple(x.shape[:-2]) + if self.mode == "reduced": + return ( + KerasTensor(shape=base + (m, k), dtype=x.dtype), + KerasTensor(shape=base + (k, n), dtype=x.dtype), + ) + # 'complete' mode. + return ( + KerasTensor(shape=base + (m, m), dtype=x.dtype), + KerasTensor(shape=base + (m, n), dtype=x.dtype), + ) + + def call(self, x): + x = backend.convert_to_tensor(x) + return backend.linalg.qr(x, mode=self.mode) + + +@keras_export(["keras.ops.qr", "keras.ops.linalg.qr"]) +def qr(x, mode="reduced"): + """Computes the QR decomposition of a tensor. + + Args: + x: Input tensor of shape `(..., M, N)`. + mode: A string specifying the mode of the QR decomposition. + - 'reduced': Returns the reduced QR decomposition. (default) + - 'complete': Returns the complete QR decomposition. + + Returns: + A tuple containing two tensors. The first tensor of shape `(..., M, K)` + is the orthogonal matrix `q` and the second tensor of shape + `(..., K, N)` is the upper triangular matrix `r`, where `K = min(M, N)`. + + Example: + + >>> x = keras.ops.convert_to_tensor([[1., 2.], [3., 4.], [5., 6.]]) + >>> q, r = qr(x) + >>> print(q) + array([[-0.16903079 0.897085] + [-0.5070925 0.2760267 ] + [-0.8451542 -0.34503305]], shape=(3, 2), dtype=float32) + """ + if any_symbolic_tensors((x,)): + return Qr(mode=mode).symbolic_call(x) + x = backend.convert_to_tensor(x) + return backend.linalg.qr(x, mode=mode) + + +class Solve(Operation): + def call(self, a, b): + return _solve(a, b) + + def compute_output_spec(self, a, b): + _assert_2d(a) + _assert_square(a) + _assert_1d(b) + _assert_a_b_compat(a, b) + return KerasTensor(b.shape, b.dtype) + + +@keras_export(["keras.ops.solve", "keras.ops.linalg.solve"]) +def solve(a, b): + """Solves a linear system of equations given by `a x = b`. + + Args: + a: A tensor of shape `(..., M, M)` representing the coefficients matrix. + b: A tensor of shape `(..., M)` or `(..., M, N)` representing the + right-hand side or "dependent variable" matrix. + + Returns: + A tensor of shape `(..., M)` or `(..., M, N)` representing the solution + of the linear system. Returned shape is identical to `b`. + + """ + if any_symbolic_tensors((a, b)): + return Solve().symbolic_call(a, b) + return _solve(a, b) + + +def _solve(a, b): + a = backend.convert_to_tensor(a) + b = backend.convert_to_tensor(b) + _assert_2d(a) + _assert_square(a) + _assert_1d(b) + _assert_a_b_compat(a, b) + return backend.linalg.solve(a, b) + + +class SolveTriangular(Operation): + def __init__(self, lower=False, *, name=None): + super().__init__(name=name) + self.lower = lower + + def call(self, a, b): + return _solve_triangular(a, b, self.lower) + + def compute_output_spec(self, a, b): + _assert_2d(a) + _assert_square(a) + _assert_1d(b) + _assert_a_b_compat(a, b) + return KerasTensor(b.shape, b.dtype) + + +@keras_export( + ["keras.ops.solve_triangular", "keras.ops.linalg.solve_triangular"] +) +def solve_triangular(a, b, lower=False): + """Solves a linear system of equations given by `a x = b`. + + Args: + a: A tensor of shape `(..., M, M)` representing the coefficients matrix. + b: A tensor of shape `(..., M)` or `(..., M, N)` representing the + right-hand side or "dependent variable" matrix. + + Returns: + A tensor of shape `(..., M)` or `(..., M, N)` representing the solution + of the linear system. Returned shape is identical to `b`. + + """ + if any_symbolic_tensors((a, b)): + return SolveTriangular(lower).symbolic_call(a, b) + return _solve_triangular(a, b, lower) + + +def _solve_triangular(a, b, lower=False): + a = backend.convert_to_tensor(a) + b = backend.convert_to_tensor(b) + _assert_2d(a) + _assert_square(a) + _assert_1d(b) + _assert_a_b_compat(a, b) + return backend.linalg.solve_triangular(a, b, lower) + + +class SVD(Operation): + def __init__(self, full_matrices=True, compute_uv=True, *, name=None): + super().__init__(name=name) + self.full_matrices = full_matrices + self.compute_uv = compute_uv + + def call(self, x): + return _svd(x, self.full_matrices, self.compute_uv) + + def compute_output_spec(self, x): + _assert_2d(x) + rows, columns = x.shape[-2:] + batches = x.shape[:-2] + s_shape = batches + (min(rows, columns),) + if self.full_matrices: + u_shape = batches + (rows, rows) + v_shape = batches + (columns, columns) + else: + u_shape = batches + (rows, min(rows, columns)) + v_shape = batches + (min(rows, columns), columns) + + if self.compute_uv: + return ( + KerasTensor(u_shape, x.dtype), + KerasTensor(s_shape, x.dtype), + KerasTensor(v_shape, x.dtype), + ) + return KerasTensor(s_shape, x.dtype) + + +@keras_export(["keras.ops.svd", "keras.ops.linalg.svd"]) +def svd(x, full_matrices=True, compute_uv=True): + """Computes the singular value decomposition of a matrix. + + Args: + x: Input tensor of shape `(..., M, N)`. + + Returns: + A tuple of three tensors: a tensor of shape `(..., M, M)` containing the + left singular vectors, a tensor of shape `(..., M, N)` containing the + singular values and a tensor of shape `(..., N, N)` containing the + right singular vectors. + + """ + if any_symbolic_tensors((x,)): + return SVD(full_matrices, compute_uv).symbolic_call(x) + return _svd(x, full_matrices, compute_uv) + + +def _svd(x, full_matrices=True, compute_uv=True): + x = backend.convert_to_tensor(x) + _assert_2d(x) + return backend.linalg.svd(x, full_matrices, compute_uv) + + +class Lstsq(Operation): + def __init__(self, rcond=None, *, name=None): + super().__init__(name=name) + self.rcond = rcond + + def call(self, a, b): + return backend.linalg.lstsq(a, b, rcond=self.rcond) + + def compute_output_spec(self, a, b): + if len(a.shape) != 2: + raise ValueError( + f"Expected a to have rank 2. Received: a.shape={a.shape}" + ) + if len(b.shape) not in (1, 2): + raise ValueError( + f"Expected b to have rank 1 or 2. Received: b.shape={b.shape}" + ) + m, n = a.shape + if b.shape[0] != m: + raise ValueError( + "Expected b.shape[0] to be equal to " + "a.shape[0]. Received: " + f"a.shape={a.shape}, b.shape={b.shape}" + ) + if len(b.shape) == 2: + k = b.shape[1] + x = KerasTensor((n, k), dtype=a.dtype) + else: + x = KerasTensor((n,), dtype=a.dtype) + return x + + +@keras_export(["keras.ops.lstsq", "keras.ops.linalg.lstsq"]) +def lstsq(a, b, rcond=None): + """Return the least-squares solution to a linear matrix equation. + + Computes the vector x that approximately solves the equation + `a @ x = b`. The equation may be under-, well-, or over-determined + (i.e., the number of linearly independent rows of a can be less than, + equal to, or greater than its number of linearly independent columns). + If a is square and of full rank, then `x` (but for round-off error) + is the exact solution of the equation. Else, `x` minimizes the + L2 norm of `b - a * x`. + + If there are multiple minimizing solutions, + the one with the smallest L2 norm is returned. + + Args: + a: "Coefficient" matrix of shape `(M, N)`. + b: Ordinate or "dependent variable" values, + of shape `(M,)` or `(M, K)`. + If `b` is two-dimensional, the least-squares solution + is calculated for each of the K columns of `b`. + rcond: Cut-off ratio for small singular values of `a`. + For the purposes of rank determination, + singular values are treated as zero if they are + smaller than rcond times the largest + singular value of `a`. + + Returns: + Tensor with shape `(N,)` or `(N, K)` containing + the least-squares solutions. + + **NOTE:** The output differs from `numpy.linalg.lstsq`. + NumPy returns a tuple with four elements, the first of which + being the least-squares solutions and the others + being essentially never used. + Keras only returns the first value. This is done both + to ensure consistency across backends (which cannot be achieved + for the other values) and to simplify the API. + """ + if any_symbolic_tensors((a, b)): + return Lstsq(rcond=rcond).symbolic_call(a, b) + return backend.linalg.lstsq(a, b, rcond=rcond) + + +def _assert_1d(*arrays): + for a in arrays: + if a.ndim < 1: + raise ValueError( + f"Expected input to have rank >= 1. Received scalar input {a}." + ) + + +def _assert_2d(*arrays): + for a in arrays: + if a.ndim < 2: + raise ValueError( + "Expected input to have rank >= 2. " + f"Received input with shape {a.shape}." + ) + + +def _assert_square(*arrays): + for a in arrays: + m, n = a.shape[-2:] + if m != n: + raise ValueError( + "Expected a square matrix. " + f"Received non-square input with shape {a.shape}" + ) + + +def _assert_a_b_compat(a, b): + if a.ndim == b.ndim: + if a.shape[-2] != b.shape[-2]: + raise ValueError( + "Incompatible shapes between `a` and `b`. " + "Expected `a.shape[-2] == b.shape[-2]`. " + f"Received: a.shape={a.shape}, b.shape={b.shape}" + ) + elif a.ndim == b.ndim - 1: + if a.shape[-1] != b.shape[-1]: + raise ValueError( + "Incompatible shapes between `a` and `b`. " + "Expected `a.shape[-1] == b.shape[-1]`. " + f"Received: a.shape={a.shape}, b.shape={b.shape}" + ) + + +class JVP(Operation): + def __init__(self, has_aux=False, *, name=None): + super().__init__(name=name) + self.has_aux = has_aux + + def call(self, fun, primals, tangents): + """Computes the JVP of `fun` at `primals` along `tangents`. + + Args: + fun: A callable that takes tensors (or nested structures) as input + and returns a tensor (or nested structure) as output. + primals: Input tensors (or nested structures) at which the Jacobian + of `fun` is evaluated. + tangents: Tensors (or nested structures) representing the direction + vectors for the JVP. Must have the same structure as + `primals`. + + Returns: + If `has_aux` is False: + A tuple (primals_out, tangents_out) where: + - primals_out: Output of `fun(*primals)` + - tangents_out: JVP of `fun` at `primals` along `tangents` + If `has_aux` is True: + A tuple (primals_out, tangents_out, aux) where: + - aux: Auxiliary data returned by `fun` + """ + return backend.linalg.jvp(fun, primals, tangents, has_aux=self.has_aux) + + def compute_output_spec(self, fun, primals, tangents): + # Infer primal output spec + if self.has_aux: + primals_out_spec, aux_spec = backend.compute_output_spec( + fun, *primals + ) + else: + primals_out_spec = backend.compute_output_spec(fun, *primals) + + # Tangents output should match primals output in structure and shape + tangents_out_spec = tree.map_structure( + lambda x: KerasTensor(x.shape, x.dtype), primals_out_spec + ) + + if self.has_aux: + return primals_out_spec, tangents_out_spec, aux_spec + return primals_out_spec, tangents_out_spec + + +@keras_export(["keras.ops.jvp", "keras.ops.linalg.jvp"]) +def jvp(fun, primals, tangents, has_aux=False): + """Computes a (forward-mode) Jacobian-vector product of `fun`. + Args: + fun: Function to be differentiated. Its arguments should be arrays, + scalars, or standard Python containers of arrays or scalars. It + should return an array, scalar, or standard Python container of + arrays or scalars. + primals: The primal values at which the Jacobian of `fun` should be + evaluated. Should be either a tuple or a list of arguments, + and its length should be equal to the number of positional + parameters of `fun`. + tangents: The tangent vector for which the Jacobian-vector product + should be evaluated. Should be either a tuple or a list of + tangents, with the same tree structure and array shapes as + `primals`. + has_aux: Optional, bool. Indicates whether `fun` returns a pair where + the first element is considered the output of the mathematical + function to be differentiated and the second element is + auxiliary data. Default is False. + + Returns: + If `has_aux` is False, returns a (`primals_out`, `tangents_out`) pair, + where `primals_out` is `fun(*primals)`, and `tangents_out` is the + Jacobian-vector product of `fun` evaluated at `primals` with + `tangents`. The `tangents_out` value has the same Python tree + structure and shapes as `primals_out`. + + If `has_aux` is True, returns a (`primals_out`, `tangents_out`, `aux`) + tuple where `aux` is the auxiliary data returned by `fun`. + + Example: + >>> from keras import ops + >>> a1, a2 = ops.convert_to_tensor(0.1), ops.convert_to_tensor(0.2) + >>> primals, tangents = ops.jvp(ops.sin, (a1,), (a2,)) + >>> primals + 0.09983342 + >>> tangents + 0.19900084 + """ + if any_symbolic_tensors((primals, tangents)): + return JVP(has_aux=has_aux).symbolic_call(fun, primals, tangents) + return backend.linalg.jvp(fun, primals, tangents, has_aux=has_aux) diff --git a/keras/src/ops/linalg_test.py b/keras/src/ops/linalg_test.py new file mode 100644 index 000000000000..0be61d5bb7f9 --- /dev/null +++ b/keras/src/ops/linalg_test.py @@ -0,0 +1,713 @@ +import numpy as np +import pytest +from absl.testing import parameterized + +from keras.src import backend +from keras.src import ops +from keras.src import testing +from keras.src.backend.common.keras_tensor import KerasTensor +from keras.src.ops import linalg +from keras.src.testing.test_utils import named_product + + +class LinalgOpsDynamicShapeTest(testing.TestCase): + def test_cholesky(self): + x = KerasTensor([None, 20, 20]) + out = linalg.cholesky(x) + self.assertEqual(out.shape, (None, 20, 20)) + + x = KerasTensor([None, None, 20]) + with self.assertRaises(ValueError): + linalg.cholesky(x) + + x = KerasTensor([None, 20, 15]) + with self.assertRaises(ValueError): + linalg.cholesky(x) + + def test_cholesky_inverse(self): + x = KerasTensor([None, 20, 20]) + out = linalg.cholesky_inverse(x) + self.assertEqual(out.shape, (None, 20, 20)) + + x = KerasTensor([None, None, 20]) + with self.assertRaises(ValueError): + linalg.cholesky_inverse(x) + + x = KerasTensor([None, 20, 15]) + with self.assertRaises(ValueError): + linalg.cholesky_inverse(x) + + def test_det(self): + x = KerasTensor([None, 20, 20]) + out = linalg.det(x) + self.assertEqual(out.shape, (None,)) + + x = KerasTensor([None, None, 20]) + with self.assertRaises(ValueError): + linalg.det(x) + + x = KerasTensor([None, 20, 15]) + with self.assertRaises(ValueError): + linalg.det(x) + + def test_eig(self): + x = KerasTensor([None, 20, 20]) + w, v = linalg.eig(x) + self.assertEqual(w.shape, (None, 20)) + self.assertEqual(v.shape, (None, 20, 20)) + + def test_eigh(self): + x = KerasTensor([None, 20, 20]) + w, v = linalg.eigh(x) + self.assertEqual(w.shape, (None, 20)) + self.assertEqual(v.shape, (None, 20, 20)) + + def test_inv(self): + x = KerasTensor([None, 20, 20]) + out = linalg.inv(x) + self.assertEqual(out.shape, (None, 20, 20)) + + x = KerasTensor([None, None, 20]) + with self.assertRaises(ValueError): + linalg.inv(x) + + x = KerasTensor([None, 20, 15]) + with self.assertRaises(ValueError): + linalg.inv(x) + + def test_lu_factor(self): + if testing.jax_uses_gpu(): + self.skipTest("Skipping test with JAX + GPU due to temporary error") + + x = KerasTensor([None, 4, 3]) + lu, p = linalg.lu_factor(x) + self.assertEqual(lu.shape, (None, 4, 3)) + self.assertEqual(p.shape, (None, 3)) + + x = KerasTensor([None, 2, 3]) + lu, p = linalg.lu_factor(x) + self.assertEqual(lu.shape, (None, 2, 3)) + self.assertEqual(p.shape, (None, 2)) + + def test_norm(self): + x = KerasTensor((None, 3)) + self.assertEqual(linalg.norm(x).shape, ()) + + x = KerasTensor((None, 3, 3)) + self.assertEqual(linalg.norm(x, axis=1).shape, (None, 3)) + self.assertEqual( + linalg.norm(x, axis=1, keepdims=True).shape, (None, 1, 3) + ) + + def test_qr(self): + x = KerasTensor((None, 4, 3), dtype="float32") + q, r = linalg.qr(x, mode="reduced") + qref, rref = np.linalg.qr(np.ones((2, 4, 3)), mode="reduced") + qref_shape = (None,) + qref.shape[1:] + rref_shape = (None,) + rref.shape[1:] + self.assertEqual(q.shape, qref_shape) + self.assertEqual(r.shape, rref_shape) + + q, r = linalg.qr(x, mode="complete") + qref, rref = np.linalg.qr(np.ones((2, 4, 3)), mode="complete") + qref_shape = (None,) + qref.shape[1:] + rref_shape = (None,) + rref.shape[1:] + self.assertEqual(q.shape, qref_shape) + self.assertEqual(r.shape, rref_shape) + + def test_qr_invalid_mode(self): + # backend agnostic error message + x = np.array([[1, 2], [3, 4]]) + invalid_mode = "invalid_mode" + with self.assertRaisesRegex( + ValueError, "Expected one of {'reduced', 'complete'}." + ): + linalg.qr(x, mode=invalid_mode) + + def test_solve(self): + a = KerasTensor([None, 20, 20]) + b = KerasTensor([None, 20, 5]) + out = linalg.solve(a, b) + self.assertEqual(out.shape, (None, 20, 5)) + + a = KerasTensor([None, 20, 20]) + b = KerasTensor([None, 20]) + out = linalg.solve(a, b) + self.assertEqual(out.shape, (None, 20)) + + a = KerasTensor([None, None, 20]) + b = KerasTensor([None, 20, 5]) + with self.assertRaises(ValueError): + linalg.solve(a, b) + + a = KerasTensor([None, 20, 15]) + b = KerasTensor([None, 20, 5]) + with self.assertRaises(ValueError): + linalg.solve(a, b) + + a = KerasTensor([None, 20, 20]) + b = KerasTensor([None, None, 5]) + with self.assertRaises(ValueError): + linalg.solve(a, b) + + def test_solve_triangular(self): + if testing.jax_uses_gpu(): + self.skipTest("Skipping test with JAX + GPU due to temporary error") + + a = KerasTensor([None, 20, 20]) + b = KerasTensor([None, 20, 5]) + out = linalg.solve_triangular(a, b) + self.assertEqual(out.shape, (None, 20, 5)) + + a = KerasTensor([None, 20, 20]) + b = KerasTensor([None, 20]) + out = linalg.solve_triangular(a, b) + self.assertEqual(out.shape, (None, 20)) + + a = KerasTensor([None, 20, 20]) + b = KerasTensor([None, 20, 5]) + out = linalg.solve_triangular(a, b, lower=True) + self.assertEqual(out.shape, (None, 20, 5)) + + a = KerasTensor([None, 20, 20]) + b = KerasTensor([None, 20]) + out = linalg.solve_triangular(a, b, lower=True) + self.assertEqual(out.shape, (None, 20)) + + a = KerasTensor([None, 20, 15]) + b = KerasTensor([None, 20, 5]) + with self.assertRaises(ValueError): + linalg.solve_triangular(a, b) + + a = KerasTensor([None, 20, 20]) + b = KerasTensor([None, None, 5]) + with self.assertRaises(ValueError): + linalg.solve_triangular(a, b) + + def test_svd(self): + x = KerasTensor((None, 3, 2)) + u, s, v = linalg.svd(x) + self.assertEqual(u.shape, (None, 3, 3)) + self.assertEqual(s.shape, (None, 2)) + self.assertEqual(v.shape, (None, 2, 2)) + + u, s, v = linalg.svd(x, full_matrices=False) + self.assertEqual(u.shape, (None, 3, 2)) + self.assertEqual(s.shape, (None, 2)) + self.assertEqual(v.shape, (None, 2, 2)) + + s = linalg.svd(x, compute_uv=False) + self.assertEqual(s.shape, (None, 2)) + + +class LinalgOpsStaticShapeTest(testing.TestCase): + def test_cholesky(self): + x = KerasTensor([4, 3, 3]) + out = linalg.cholesky(x) + self.assertEqual(out.shape, (4, 3, 3)) + + x = KerasTensor([10, 20, 15]) + with self.assertRaises(ValueError): + linalg.cholesky(x) + + def test_cholesky_inverse(self): + x = KerasTensor([4, 3, 3]) + out = linalg.cholesky_inverse(x) + self.assertEqual(out.shape, (4, 3, 3)) + + x = KerasTensor([10, 20, 15]) + with self.assertRaises(ValueError): + linalg.cholesky_inverse(x) + + def test_det(self): + x = KerasTensor([4, 3, 3]) + out = linalg.det(x) + self.assertEqual(out.shape, (4,)) + + x = KerasTensor([10, 20, 15]) + with self.assertRaises(ValueError): + linalg.det(x) + + def test_eig(self): + x = KerasTensor([4, 3, 3]) + w, v = linalg.eig(x) + self.assertEqual(w.shape, (4, 3)) + self.assertEqual(v.shape, (4, 3, 3)) + + x = KerasTensor([10, 20, 15]) + with self.assertRaises(ValueError): + linalg.eig(x) + + def test_eigh(self): + x = KerasTensor([4, 3, 3]) + w, v = linalg.eigh(x) + self.assertEqual(w.shape, (4, 3)) + self.assertEqual(v.shape, (4, 3, 3)) + + x = KerasTensor([10, 20, 15]) + with self.assertRaises(ValueError): + linalg.eigh(x) + + def test_inv(self): + x = KerasTensor([4, 3, 3]) + out = linalg.inv(x) + self.assertEqual(out.shape, (4, 3, 3)) + + x = KerasTensor([10, 20, 15]) + with self.assertRaises(ValueError): + linalg.inv(x) + + def test_lu_factor(self): + if testing.jax_uses_gpu(): + self.skipTest("Skipping test with JAX + GPU due to temporary error") + + x = KerasTensor([10, 4, 3]) + lu, p = linalg.lu_factor(x) + self.assertEqual(lu.shape, (10, 4, 3)) + self.assertEqual(p.shape, (10, 3)) + + x = KerasTensor([10, 2, 3]) + lu, p = linalg.lu_factor(x) + self.assertEqual(lu.shape, (10, 2, 3)) + self.assertEqual(p.shape, (10, 2)) + + def test_norm(self): + x = KerasTensor((10, 3)) + self.assertEqual(linalg.norm(x).shape, ()) + + x = KerasTensor((10, 3, 3)) + self.assertEqual(linalg.norm(x, axis=1).shape, (10, 3)) + self.assertEqual( + linalg.norm(x, axis=1, keepdims=True).shape, (10, 1, 3) + ) + + def test_qr(self): + x = KerasTensor((4, 3), dtype="float32") + q, r = linalg.qr(x, mode="reduced") + qref, rref = np.linalg.qr(np.ones((4, 3)), mode="reduced") + self.assertEqual(q.shape, qref.shape) + self.assertEqual(r.shape, rref.shape) + + q, r = linalg.qr(x, mode="complete") + qref, rref = np.linalg.qr(np.ones((4, 3)), mode="complete") + self.assertEqual(q.shape, qref.shape) + self.assertEqual(r.shape, rref.shape) + + with self.assertRaises(ValueError): + linalg.qr(x, mode="invalid") + + def test_solve(self): + a = KerasTensor([4, 3, 3]) + b = KerasTensor([4, 3, 5]) + out = linalg.solve(a, b) + self.assertEqual(out.shape, (4, 3, 5)) + + a = KerasTensor([4, 3, 3]) + b = KerasTensor([4, 3]) + out = linalg.solve(a, b) + self.assertEqual(out.shape, (4, 3)) + + a = KerasTensor([10, 20, 15]) + b = KerasTensor([10, 20, 5]) + with self.assertRaises(ValueError): + linalg.solve(a, b) + + a = KerasTensor([20, 20]) + b = KerasTensor([]) + with self.assertRaises(ValueError): + linalg.solve(a, b) + + def test_solve_triangular(self): + if testing.jax_uses_gpu(): + self.skipTest("Skipping test with JAX + GPU due to temporary error") + + a = KerasTensor([4, 3, 3]) + b = KerasTensor([4, 3, 5]) + out = linalg.solve_triangular(a, b) + self.assertEqual(out.shape, (4, 3, 5)) + + a = KerasTensor([4, 3, 3]) + b = KerasTensor([4, 3]) + out = linalg.solve_triangular(a, b) + self.assertEqual(out.shape, (4, 3)) + + a = KerasTensor([10, 20, 15]) + b = KerasTensor([10, 20, 5]) + with self.assertRaises(ValueError): + linalg.solve_triangular(a, b) + + def test_svd(self): + x = KerasTensor((10, 3, 2)) + u, s, v = linalg.svd(x) + self.assertEqual(u.shape, (10, 3, 3)) + self.assertEqual(s.shape, (10, 2)) + self.assertEqual(v.shape, (10, 2, 2)) + + u, s, v = linalg.svd(x, full_matrices=False) + self.assertEqual(u.shape, (10, 3, 2)) + self.assertEqual(s.shape, (10, 2)) + self.assertEqual(v.shape, (10, 2, 2)) + + s = linalg.svd(x, compute_uv=False) + self.assertEqual(s.shape, (10, 2)) + + +class LinalgOpsCorrectnessTest(testing.TestCase): + def test_cholesky(self): + x_non_psd = np.random.rand(4, 3, 3).astype("float32") + with self.assertRaises(ValueError): + linalg.cholesky(x_non_psd) + + x = np.random.rand(4, 3, 3).astype("float32") + x_psd = np.matmul(x, x.transpose((0, 2, 1))) + 1e-5 * np.eye( + 3, dtype="float32" + ) + + l_out = linalg.cholesky(x_psd, upper=False) + l_expected = np.linalg.cholesky(x_psd) + self.assertAllClose(l_out, l_expected, atol=1e-4) + + u_out = linalg.cholesky(x_psd, upper=True) + u_expected = l_expected.transpose((0, 2, 1)) + self.assertAllClose(u_out, u_expected, atol=1e-4) + + @parameterized.named_parameters( + {"testcase_name": "lower", "upper": False}, + {"testcase_name": "upper", "upper": True}, + ) + def test_cholesky_inverse(self, upper): + A = np.array( + [ + [4.0, 12.0, -16.0], + [12.0, 37.0, -43.0], + [-16.0, -43.0, 98.0], + ], + dtype="float32", + ) + if upper: + factor = np.linalg.cholesky(A, upper=True) + else: + factor = np.linalg.cholesky(A) + + expected_inverse = np.array( + [ + [49.36111, -13.555555, 2.111111], + [-13.555555, 3.777778, -0.555556], + [2.111111, -0.555556, 0.111111], + ], + dtype="float32", + ) + + output_inverse = linalg.cholesky_inverse(factor, upper=upper) + self.assertAllClose(output_inverse, expected_inverse, atol=1e-5) + + def test_det(self): + x = np.random.rand(4, 3, 3) + out = linalg.det(x) + self.assertAllClose(out, np.linalg.det(x), atol=1e-5) + + with self.assertRaises(ValueError): + x = np.random.rand(4, 3, 4) + linalg.det(x) + + def test_eig(self): + x = np.random.rand(2, 3, 3) + x = x @ x.transpose((0, 2, 1)) + w, v = map(ops.convert_to_numpy, linalg.eig(x)) + x_reconstructed = (v * w[..., None, :]) @ v.transpose((0, 2, 1)) + self.assertAllClose(x_reconstructed, x, atol=1e-4) + + def test_eigh(self): + x = np.random.rand(2, 3, 3) + x = x @ x.transpose((0, 2, 1)) + w, v = map(ops.convert_to_numpy, linalg.eigh(x)) + x_reconstructed = (v * w[..., None, :]) @ v.transpose((0, 2, 1)) + self.assertAllClose(x_reconstructed, x, atol=1e-4) + + def test_inv(self): + x = np.random.rand(4, 3, 3) + x_inv = ops.convert_to_numpy(linalg.inv(x)) + x_reconstructed = x @ x_inv + # high tolerance due to numerical instability + self.assertAllClose( + x_reconstructed, np.repeat(np.eye(3)[None], 4, 0), atol=1e-3 + ) + + def test_lu_factor(self): + if testing.jax_uses_gpu(): + self.skipTest("Skipping test with JAX + GPU due to temporary error") + + def _pivot_matrix(pivots, n): + p_matrix = np.eye(n) + for i, p in enumerate(pivots): + identity = np.eye(n, n) + q = identity[i, :].copy() + identity[i, :] = identity[p, :] + identity[p, :] = q + p_matrix = np.dot(p_matrix, identity) + return p_matrix + + def _reconstruct(lu, pivots, m, n): + lower = np.tril(lu[:, : min(m, n)], -1) + np.eye(m, min(m, n)) + upper = np.triu(lu[: min(m, n)]) + + # pivots are defined differently in tensorflow + # compared to the other backends + if backend.backend() == "tensorflow": + p_matrix = np.eye(m)[pivots] + else: + p_matrix = _pivot_matrix(pivots, m) + out = p_matrix @ lower @ upper + return out + + m, n = 4, 4 + x = np.random.rand(m, n) + lu, pivots = map(ops.convert_to_numpy, linalg.lu_factor(x)) + x_reconstructed = _reconstruct(lu, pivots, m, n) + self.assertAllClose(x_reconstructed, x, atol=1e-5) + + m, n = 4, 3 + x = np.random.rand(m, n) + if backend.backend() == "tensorflow": + with self.assertRaises(ValueError): + linalg.lu_factor(x) + else: + lu, pivots = map(ops.convert_to_numpy, linalg.lu_factor(x)) + x_reconstructed = _reconstruct(lu, pivots, m, n) + self.assertAllClose(x_reconstructed, x, atol=1e-5) + + # batched case + m, n = 3, 4 + x = np.random.rand(2, m, n) + if backend.backend() == "tensorflow": + with self.assertRaises(ValueError): + linalg.lu_factor(x) + else: + lu, pivots = map(ops.convert_to_numpy, linalg.lu_factor(x)) + for i in range(2): + self.assertAllClose( + _reconstruct(lu[i], pivots[i], m, n), x[i], atol=1e-5 + ) + + @parameterized.named_parameters( + named_product( + ndim=[1, 2], + ord=[None, "fro", "nuc", -np.inf, -2, -1, 0, 1, 2, np.inf, 3], + axis=[None, 1, -1, (0, 1)], + keepdims=[False, True], + ) + ) + def test_norm(self, ndim, ord, axis, keepdims): + if ndim == 1: + x = np.random.random((5,)).astype("float32") + else: + x = np.random.random((5, 6)).astype("float32") + + vector_norm = (ndim == 1) or isinstance(axis, int) + + axis_out_of_bounds = ndim == 1 and ( + axis == 1 or isinstance(axis, tuple) + ) + expected_error = None + # when an out of bounds axis triggers an IndexError on torch is complex + if ( + axis_out_of_bounds + and (not isinstance(axis, tuple) or ord is None) + and ord not in ("fro", "nuc") + ): + expected_error = IndexError + elif ( + axis_out_of_bounds + or (vector_norm and isinstance(axis, tuple)) # inv. axis for vector + or (vector_norm and ord in ("fro", "nuc")) # invalid ord for vector + or (not vector_norm and ord in (0, 3)) # invalid ord for matrix + ): + expected_error = RuntimeError + + if expected_error is not None: + # Non-torch backends always throw a ValueError + expected_error = ( + expected_error if backend.backend() == "torch" else ValueError + ) + with self.assertRaises(expected_error): + linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims) + return + output = linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims) + expected_result = np.linalg.norm( + x, ord=ord, axis=axis, keepdims=keepdims + ) + self.assertAllClose(output, expected_result, atol=1e-5) + + def test_qr(self): + x = np.random.random((4, 5)) + q, r = linalg.qr(x, mode="reduced") + qref, rref = np.linalg.qr(x, mode="reduced") + self.assertAllClose(qref, q) + self.assertAllClose(rref, r) + + q, r = linalg.qr(x, mode="complete") + qref, rref = np.linalg.qr(x, mode="complete") + self.assertAllClose(qref, q) + self.assertAllClose(rref, r) + + def test_solve(self): + x1 = np.array([[1, 2], [4, 5]], dtype="float32") + x2 = np.array([[2, 4], [8, 10]], dtype="float32") + output = linalg.solve(x1, x2) + expected_result = np.array([[2, 0], [0, 2]], dtype="float32") + self.assertAllClose(output, expected_result) + + def test_solve_triangular(self): + if testing.jax_uses_gpu(): + self.skipTest("Skipping test with JAX + GPU due to temporary error") + + # 2d-case + x1 = np.array([[1, 2], [0, 5]], dtype="float32") + x2 = np.array([2, 10], dtype="float32") + output = linalg.solve_triangular(x1, x2, lower=True) + expected_result = np.array([2, 2], dtype="float32") + self.assertAllClose(output, expected_result) + + output = linalg.solve_triangular(x1, x2, lower=False) + expected_result = np.array([-2, 2], dtype="float32") + self.assertAllClose(output, expected_result) + + # batched case + x1 = np.array([[[1, 2], [0, 5]], [[1, 2], [0, 5]]], dtype="float32") + x2 = np.array([[2, 10], [2, 10]], dtype="float32") + output = linalg.solve_triangular(x1, x2, lower=True) + expected_result = np.array([[2, 2], [2, 2]], dtype="float32") + self.assertAllClose(output, expected_result) + + def test_svd(self): + x = np.random.rand(4, 30, 20).astype("float32") + u, s, vh = linalg.svd(x) + x_reconstructed = (u[..., :, : s.shape[-1]] * s[..., None, :]) @ vh[ + ..., : s.shape[-1], : + ] + # High tolerance due to numerical instability + self.assertAllClose(x_reconstructed, x, atol=1e-3) + + # Test `compute_uv=False` + s_no_uv = linalg.svd(x, compute_uv=False) + self.assertAllClose(s_no_uv, s, atol=1e-5, rtol=1e-5) + + @parameterized.named_parameters( + ("b_rank_1", 1, None), + ("b_rank_2", 2, None), + ("rcond", 1, 1e-3), + ) + def test_lstsq(self, b_rank, rcond): + a = np.random.random((5, 7)).astype("float32") + a_symb = backend.KerasTensor((5, 7)) + if b_rank == 1: + b = np.random.random((5,)).astype("float32") + b_symb = backend.KerasTensor((5,)) + else: + b = np.random.random((5, 4)).astype("float32") + b_symb = backend.KerasTensor((5, 4)) + out = linalg.lstsq(a, b, rcond=rcond) + ref_out = np.linalg.lstsq(a, b, rcond=rcond)[0] + self.assertAllClose(out, ref_out, atol=1e-5) + + out_symb = linalg.lstsq(a_symb, b_symb) + self.assertEqual(out_symb.shape, out.shape) + + +class QrOpTest(testing.TestCase): + def test_qr_init_mode_reduced(self): + qr_op = linalg.Qr(mode="reduced") + self.assertIsNotNone(qr_op) + + def test_qr_init_mode_complete(self): + qr_op = linalg.Qr(mode="complete") + self.assertIsNotNone(qr_op) + + def test_qr_init_invalid_mode(self): + invalid_mode = "invalid_mode" + expected_error = ( + r"`mode` argument value not supported. " + r"Expected one of \{'reduced', 'complete'\}. " + f"Received: mode={invalid_mode}" + ) + with self.assertRaisesRegex(ValueError, expected_error): + linalg.Qr(mode=invalid_mode) + + def test_compute_output_spec_low_rank(self): + qr_op = linalg.Qr(mode="reduced") + low_rank_input = np.random.rand(3) + with self.assertRaisesRegex( + ValueError, r"Input should have rank >= 2. Received: .*" + ): + qr_op.compute_output_spec(low_rank_input) + + def test_compute_output_spec_undefined_dimensions(self): + qr_op = linalg.Qr(mode="reduced") + undefined_dim_input = KerasTensor(shape=(None, 4), dtype="float32") + with self.assertRaisesRegex( + ValueError, + r"Input should have its last 2 dimensions " + r"fully-defined. Received: .*", + ): + qr_op.compute_output_spec(undefined_dim_input) + + def test_qr_call_mode_reduced(self): + qr_op = linalg.Qr(mode="reduced") + test_input = np.random.rand(10, 10) + q, r = qr_op.call(test_input) + self.assertEqual(q.shape, (10, 10)) + self.assertEqual(r.shape, (10, 10)) + + def test_qr_call_mode_complete(self): + qr_op = linalg.Qr(mode="complete") + test_input = np.random.rand(10, 10) + q, r = qr_op.call(test_input) + self.assertEqual(q.shape, (10, 10)) + self.assertEqual(r.shape, (10, 10)) + + def test_jvp(self): + if backend.backend() in ["openvino", "numpy"]: + pytest.skip("Backend does not support jvp operation") + a1, a2 = ops.convert_to_tensor(0.1), ops.convert_to_tensor(0.2) + primals, tangents = linalg.jvp(backend.numpy.sin, (a1,), (a2,)) + self.assertAllClose(primals, 0.0998, atol=1e-4) + self.assertAllClose(tangents, 0.1990, atol=1e-4) + + def f(x): + return backend.numpy.sin(x), x**2 + + primals_out, tangents_out, aux = linalg.jvp( + f, (a1,), (a2,), has_aux=True + ) + self.assertAllClose(primals_out, 0.0998, atol=1e-4) + self.assertAllClose(tangents_out, 0.1990, atol=1e-4) + self.assertAllClose(aux, 0.01, atol=1e-4) + + def test_jvp_symbolic_has_aux_false(self): + primals = KerasTensor((None, 7)) + tangents = KerasTensor((None, 7)) + + def fun(x): + # simple non-linear transformation + return ops.sin(x) + ops.cos(x) + + primals_out, tangents_out = linalg.jvp(fun, (primals,), (tangents,)) + # output shapes must match input shapes + self.assertEqual(primals_out.shape, primals.shape) + self.assertEqual(tangents_out.shape, tangents.shape) + + """Symbolic JVP test – has_aux=True.""" + + def fun(x): + y = ops.exp(x) + aux = ops.mean(y, axis=-1, keepdims=True) # auxiliary output + return y, aux + + primals_out, tangents_out, aux = linalg.jvp( + fun, (primals,), (tangents,), has_aux=True + ) + # main output shapes + self.assertEqual(primals_out.shape, primals.shape) + self.assertEqual(tangents_out.shape, tangents.shape) + # auxiliary shape: (batch, 1) + self.assertEqual(aux.shape, (None, 1)) diff --git a/keras/src/ops/math.py b/keras/src/ops/math.py new file mode 100644 index 000000000000..e0da72d6f292 --- /dev/null +++ b/keras/src/ops/math.py @@ -0,0 +1,1135 @@ +"""Commonly used math operations not included in NumPy.""" + +from keras.src import backend +from keras.src.api_export import keras_export +from keras.src.backend import KerasTensor +from keras.src.backend import any_symbolic_tensors +from keras.src.ops.operation import Operation +from keras.src.ops.operation_utils import reduce_shape + + +def _segment_reduce_validation(data, segment_ids): + data_shape = data.shape + segment_ids_shape = segment_ids.shape + if len(segment_ids_shape) > 1: + raise ValueError( + "Argument `segment_ids` should be an 1-D vector, got shape: " + f"{len(segment_ids_shape)}. Consider either flatten input with " + "segment_ids.reshape((-1)) and " + "data.reshape((-1, ) + data.shape[len(segment_ids.shape):]) or " + "vectorize with vmap." + ) + if ( + segment_ids_shape[0] is not None + and data_shape[0] is not None + and segment_ids_shape[0] != data_shape[0] + ): + raise ValueError( + "Argument `segment_ids` and `data` should have same leading " + f"dimension. Got {segment_ids_shape} v.s. " + f"{data_shape}." + ) + + +class SegmentReduction(Operation): + def __init__(self, num_segments=None, sorted=False, *, name=None): + super().__init__(name=name) + self.num_segments = num_segments + self.sorted = sorted + + def compute_output_spec(self, data, _): + output_shape = (self.num_segments,) + tuple(data.shape[1:]) + return KerasTensor(shape=output_shape, dtype=data.dtype) + + +class SegmentSum(SegmentReduction): + def call(self, data, segment_ids): + _segment_reduce_validation(data, segment_ids) + return backend.math.segment_sum( + data, + segment_ids, + num_segments=self.num_segments, + sorted=self.sorted, + ) + + +@keras_export("keras.ops.segment_sum") +def segment_sum(data, segment_ids, num_segments=None, sorted=False): + """Computes the sum of segments in a tensor. + + Args: + data: Input tensor. + segment_ids: A N-D tensor containing segment indices for each + element in `data`. Num dims for segment ids should be strictly + smaller or equal to number of dims in data. + num_segments: An integer representing the total number of + segments. If not specified, it is inferred from the maximum + value in `segment_ids`. + sorted: A boolean indicating whether `segment_ids` is sorted. + Defaults to `False`. + + Returns: + A tensor containing the sum of segments, where each element + represents the sum of the corresponding segment in `data`. + + Example: + + >>> data = keras.ops.convert_to_tensor([1, 2, 10, 20, 100, 200]) + >>> segment_ids = keras.ops.convert_to_tensor([0, 0, 1, 1, 2, 2]) + >>> num_segments = 3 + >>> keras.ops.segment_sum(data, segment_ids,num_segments) + array([3, 30, 300], dtype=int32) + """ + _segment_reduce_validation(data, segment_ids) + if any_symbolic_tensors((data,)): + return SegmentSum(num_segments, sorted).symbolic_call(data, segment_ids) + return backend.math.segment_sum( + data, segment_ids, num_segments=num_segments, sorted=sorted + ) + + +class SegmentMax(SegmentReduction): + def call(self, data, segment_ids): + _segment_reduce_validation(data, segment_ids) + return backend.math.segment_max( + data, + segment_ids, + num_segments=self.num_segments, + sorted=self.sorted, + ) + + +@keras_export("keras.ops.segment_max") +def segment_max(data, segment_ids, num_segments=None, sorted=False): + """Computes the max of segments in a tensor. + + Args: + data: Input tensor. + segment_ids: A N-D tensor containing segment indices for each + element in `data`. data.shape[:len(segment_ids.shape)] should match. + num_segments: An integer representing the total number of + segments. If not specified, it is inferred from the maximum + value in `segment_ids`. + sorted: A boolean indicating whether `segment_ids` is sorted. + Defaults to `False`. + + Returns: + A tensor containing the max of segments, where each element + represents the max of the corresponding segment in `data`. + + Example: + + >>> data = keras.ops.convert_to_tensor([1, 2, 10, 20, 100, 200]) + >>> segment_ids = keras.ops.convert_to_tensor([0, 0, 1, 1, 2, 2]) + >>> num_segments = 3 + >>> keras.ops.segment_max(data, segment_ids, num_segments) + array([2, 20, 200], dtype=int32) + """ + _segment_reduce_validation(data, segment_ids) + if any_symbolic_tensors((data,)): + return SegmentMax(num_segments, sorted).symbolic_call(data, segment_ids) + return backend.math.segment_max( + data, segment_ids, num_segments=num_segments, sorted=sorted + ) + + +class TopK(Operation): + def __init__(self, k, sorted=True, *, name=None): + super().__init__(name=name) + self.k = k + self.sorted = sorted + + def compute_output_spec(self, x): + output_shape = list(x.shape) + output_shape[-1] = self.k + # Return a tuple (values, indices). + return ( + KerasTensor(shape=output_shape, dtype=x.dtype), + KerasTensor(shape=output_shape, dtype="int32"), + ) + + def call(self, x): + return backend.math.top_k(x, self.k, self.sorted) + + +@keras_export("keras.ops.top_k") +def top_k(x, k, sorted=True): + """Finds the top-k values and their indices in a tensor. + + Args: + x: Input tensor. + k: An integer representing the number of top elements to retrieve. + sorted: A boolean indicating whether to sort the output in + descending order. Defaults to `True`. + + Returns: + A tuple containing two tensors. The first tensor contains the + top-k values, and the second tensor contains the indices of the + top-k values in the input tensor. + + Example: + + >>> x = keras.ops.convert_to_tensor([5, 2, 7, 1, 9, 3]) + >>> values, indices = top_k(x, k=3) + >>> print(values) + array([9 7 5], shape=(3,), dtype=int32) + >>> print(indices) + array([4 2 0], shape=(3,), dtype=int32) + + """ + if any_symbolic_tensors((x,)): + return TopK(k, sorted).symbolic_call(x) + return backend.math.top_k(x, k, sorted) + + +class InTopK(Operation): + def __init__(self, k, *, name=None): + super().__init__(name=name) + self.k = k + + def compute_output_spec(self, targets, predictions): + return KerasTensor(shape=targets.shape, dtype="bool") + + def call(self, targets, predictions): + return backend.math.in_top_k(targets, predictions, self.k) + + +@keras_export("keras.ops.in_top_k") +def in_top_k(targets, predictions, k): + """Checks if the targets are in the top-k predictions. + + Args: + targets: A tensor of true labels. + predictions: A tensor of predicted labels. + k: An integer representing the number of predictions to consider. + + Returns: + A boolean tensor of the same shape as `targets`, where each element + indicates whether the corresponding target is in the top-k predictions. + + Example: + + >>> targets = keras.ops.convert_to_tensor([2, 5, 3]) + >>> predictions = keras.ops.convert_to_tensor( + ... [[0.1, 0.4, 0.6, 0.9, 0.5], + ... [0.1, 0.7, 0.9, 0.8, 0.3], + ... [0.1, 0.6, 0.9, 0.9, 0.5]]) + >>> in_top_k(targets, predictions, k=3) + array([ True False True], shape=(3,), dtype=bool) + """ + if any_symbolic_tensors((targets, predictions)): + return InTopK(k).symbolic_call(targets, predictions) + return backend.math.in_top_k(targets, predictions, k) + + +class Logsumexp(Operation): + def __init__(self, axis=None, keepdims=False, *, name=None): + super().__init__(name=name) + self.axis = axis + self.keepdims = keepdims + + def compute_output_spec(self, x): + output_shape = reduce_shape(x.shape, self.axis, self.keepdims) + return KerasTensor(shape=output_shape) + + def call(self, x): + return backend.math.logsumexp(x, axis=self.axis, keepdims=self.keepdims) + + +@keras_export("keras.ops.logsumexp") +def logsumexp(x, axis=None, keepdims=False): + """Computes the logarithm of sum of exponentials of elements in a tensor. + + Args: + x: Input tensor. + axis: An integer or a tuple of integers specifying the axis/axes + along which to compute the sum. If `None`, the sum is computed + over all elements. Defaults to `None`. + keepdims: A boolean indicating whether to keep the dimensions of + the input tensor when computing the sum. Defaults to `False`. + + Returns: + A tensor containing the logarithm of the sum of exponentials of + elements in `x`. + + Example: + + >>> x = keras.ops.convert_to_tensor([1., 2., 3.]) + >>> logsumexp(x) + 3.407606 + """ + if any_symbolic_tensors((x,)): + return Logsumexp(axis, keepdims).symbolic_call(x) + return backend.math.logsumexp(x, axis=axis, keepdims=keepdims) + + +class ExtractSequences(Operation): + def __init__(self, sequence_length, sequence_stride, *, name=None): + super().__init__(name=name) + self.sequence_length = sequence_length + self.sequence_stride = sequence_stride + + def compute_output_spec(self, x): + if len(x.shape) < 1: + raise ValueError( + f"Input should have rank >= 1. " + f"Received: input.shape = {x.shape}" + ) + if x.shape[-1] is not None: + num_sequences = ( + 1 + (x.shape[-1] - self.sequence_length) // self.sequence_stride + ) + else: + num_sequences = None + new_shape = x.shape[:-1] + (num_sequences, self.sequence_length) + return KerasTensor(shape=new_shape, dtype=x.dtype) + + def call(self, x): + return backend.math.extract_sequences( + x, + sequence_length=self.sequence_length, + sequence_stride=self.sequence_stride, + ) + + +@keras_export("keras.ops.extract_sequences") +def extract_sequences(x, sequence_length, sequence_stride): + """Expands the dimension of last axis into sequences of `sequence_length`. + + Slides a window of size `sequence_length` over the last axis of the input + with a stride of `sequence_stride`, replacing the last axis with + `[num_sequences, sequence_length]` sequences. + + If the dimension along the last axis is N, the number of sequences can be + computed by: + + `num_sequences = 1 + (N - sequence_length) // sequence_stride` + + Args: + x: Input tensor. + sequence_length: An integer representing the sequences length. + sequence_stride: An integer representing the sequences hop size. + + Returns: + A tensor of sequences with shape [..., num_sequences, sequence_length]. + + Example: + + >>> x = keras.ops.convert_to_tensor([1, 2, 3, 4, 5, 6]) + >>> extract_sequences(x, 3, 2) + array([[1, 2, 3], + [3, 4, 5]]) + """ + if any_symbolic_tensors((x,)): + return ExtractSequences(sequence_length, sequence_stride).symbolic_call( + x + ) + return backend.math.extract_sequences(x, sequence_length, sequence_stride) + + +class FFT(Operation): + def compute_output_spec(self, x): + if not isinstance(x, (tuple, list)) or len(x) != 2: + raise ValueError( + "Input `x` should be a tuple of two tensors - real and " + f"imaginary. Received: x={x}" + ) + + real, imag = x + # Both real and imaginary parts should have the same shape. + if real.shape != imag.shape: + raise ValueError( + "Input `x` should be a tuple of two tensors - real and " + "imaginary. Both the real and imaginary parts should have the " + f"same shape. Received: x[0].shape = {real.shape}, " + f"x[1].shape = {imag.shape}" + ) + + # We are calculating 1D FFT. Hence, rank >= 1. + if len(real.shape) < 1: + raise ValueError( + f"Input should have rank >= 1. " + f"Received: input.shape = {real.shape}" + ) + + # The axis along which we are calculating FFT should be fully-defined. + m = real.shape[-1] + if m is None: + raise ValueError( + f"Input should have its last dimension fully-defined. " + f"Received: input.shape = {real.shape}" + ) + + return ( + KerasTensor(shape=real.shape, dtype=real.dtype), + KerasTensor(shape=imag.shape, dtype=imag.dtype), + ) + + def call(self, x): + return backend.math.fft(x) + + +@keras_export("keras.ops.fft") +def fft(x): + """Computes the Fast Fourier Transform along last axis of input. + + Args: + x: Tuple of the real and imaginary parts of the input tensor. Both + tensors in the tuple should be of floating type. + + Returns: + A tuple containing two tensors - the real and imaginary parts of the + output tensor. + + Example: + + >>> x = ( + ... keras.ops.convert_to_tensor([1., 2.]), + ... keras.ops.convert_to_tensor([0., 1.]), + ... ) + >>> fft(x) + (array([ 3., -1.], dtype=float32), array([ 1., -1.], dtype=float32)) + """ + if any_symbolic_tensors(x): + return FFT().symbolic_call(x) + return backend.math.fft(x) + + +class FFT2(Operation): + def compute_output_spec(self, x): + axes = (-2, -1) + if not isinstance(x, (tuple, list)) or len(x) != 2: + raise ValueError( + "Input `x` should be a tuple of two tensors - real and " + f"imaginary. Received: x={x}" + ) + + real, imag = x + # Both real and imaginary parts should have the same shape. + if real.shape != imag.shape: + raise ValueError( + "Input `x` should be a tuple of two tensors - real and " + "imaginary. Both the real and imaginary parts should have the " + f"same shape. Received: x[0].shape = {real.shape}, " + f"x[1].shape = {imag.shape}" + ) + # We are calculating 2D FFT. Hence, rank >= 2. + if len(real.shape) < 2: + raise ValueError( + f"Input should have rank >= 2. " + f"Received: input.shape = {real.shape}" + ) + + # The axes along which we are calculating FFT should be fully-defined. + m = real.shape[axes[0]] + n = real.shape[axes[1]] + if m is None or n is None: + raise ValueError( + f"Input should have its {axes} axes fully-defined. " + f"Received: input.shape = {real.shape}" + ) + + return ( + KerasTensor(shape=real.shape, dtype=real.dtype), + KerasTensor(shape=imag.shape, dtype=imag.dtype), + ) + + def call(self, x): + return backend.math.fft2(x) + + +@keras_export("keras.ops.fft2") +def fft2(x): + """Computes the 2D Fast Fourier Transform along the last two axes of input. + + Args: + x: Tuple of the real and imaginary parts of the input tensor. Both + tensors in the tuple should be of floating type. + + Returns: + A tuple containing two tensors - the real and imaginary parts of the + output. + + Example: + + >>> x = ( + ... keras.ops.convert_to_tensor([[1., 2.], [2., 1.]]), + ... keras.ops.convert_to_tensor([[0., 1.], [1., 0.]]), + ... ) + >>> fft2(x) + (array([[ 6., 0.], + [ 0., -2.]], dtype=float32), array([[ 2., 0.], + [ 0., -2.]], dtype=float32)) + """ + if any_symbolic_tensors(x): + return FFT2().symbolic_call(x) + return backend.math.fft2(x) + + +class IFFT2(Operation): + def compute_output_spec(self, x): + axes = (-2, -1) + if not isinstance(x, (tuple, list)) or len(x) != 2: + raise ValueError( + "Input `x` should be a tuple of two tensors - real and " + f"imaginary. Received: x={x}" + ) + + real, imag = x + # Both real and imaginary parts should have the same shape. + if real.shape != imag.shape: + raise ValueError( + "Input `x` should be a tuple of two tensors - real and " + "imaginary. Both the real and imaginary parts should have the " + f"same shape. Received: x[0].shape = {real.shape}, " + f"x[1].shape = {imag.shape}" + ) + # We are calculating 2D IFFT. Hence, rank >= 2. + if len(real.shape) < 2: + raise ValueError( + f"Input should have rank >= 2. " + f"Received: input.shape = {real.shape}" + ) + + # The axes along which we are calculating IFFT should be fully-defined. + m = real.shape[axes[0]] + n = real.shape[axes[1]] + if m is None or n is None: + raise ValueError( + f"Input should have its {axes} axes fully-defined. " + f"Received: input.shape = {real.shape}" + ) + + return ( + KerasTensor(shape=real.shape, dtype=real.dtype), + KerasTensor(shape=imag.shape, dtype=imag.dtype), + ) + + def call(self, x): + return backend.math.ifft2(x) + + +@keras_export("keras.ops.ifft2") +def ifft2(x): + """Computes the 2D Inverse Fast Fourier Transform along the last two axes of + input. + + Args: + x: Tuple of the real and imaginary parts of the input tensor. Both + tensors in the tuple should be of floating type. + + Returns: + A tuple containing two tensors - the real and imaginary parts of the + output. + + Example: + + >>> x = ( + ... keras.ops.convert_to_tensor([[1., 2.], [2., 1.]]), + ... keras.ops.convert_to_tensor([[0., 1.], [1., 0.]]), + ... ) + >>> ifft2(x) + (array([[ 6., 0.], + [ 0., -2.]], dtype=float32), array([[ 2., 0.], + [ 0., -2.]], dtype=float32)) + """ + if any_symbolic_tensors(x): + return IFFT2().symbolic_call(x) + return backend.math.ifft2(x) + + +class RFFT(Operation): + def __init__(self, fft_length=None, *, name=None): + super().__init__(name=name) + self.fft_length = fft_length + + def compute_output_spec(self, x): + # We are calculating 1D RFFT. Hence, rank >= 1. + if len(x.shape) < 1: + raise ValueError( + f"Input should have rank >= 1. " + f"Received: input.shape = {x.shape}" + ) + + if self.fft_length is not None: + new_last_dimension = self.fft_length // 2 + 1 + else: + if x.shape[-1] is not None: + new_last_dimension = x.shape[-1] // 2 + 1 + else: + new_last_dimension = None + new_shape = x.shape[:-1] + (new_last_dimension,) + + return ( + KerasTensor(shape=new_shape, dtype=x.dtype), + KerasTensor(shape=new_shape, dtype=x.dtype), + ) + + def call(self, x): + return backend.math.rfft(x, fft_length=self.fft_length) + + +@keras_export("keras.ops.rfft") +def rfft(x, fft_length=None): + """Real-valued Fast Fourier Transform along the last axis of the input. + + Computes the 1D Discrete Fourier Transform of a real-valued signal over the + inner-most dimension of input. + + Since the Discrete Fourier Transform of a real-valued signal is + Hermitian-symmetric, RFFT only returns the `fft_length / 2 + 1` unique + components of the FFT: the zero-frequency term, followed by the + `fft_length / 2` positive-frequency terms. + + Along the axis RFFT is computed on, if `fft_length` is smaller than the + corresponding dimension of the input, the dimension is cropped. If it is + larger, the dimension is padded with zeros. + + Args: + x: Input tensor. + fft_length: An integer representing the number of the fft length. If not + specified, it is inferred from the length of the last axis of `x`. + Defaults to `None`. + + Returns: + A tuple containing two tensors - the real and imaginary parts of the + output. + + Examples: + + >>> x = keras.ops.convert_to_tensor([0.0, 1.0, 2.0, 3.0, 4.0]) + >>> rfft(x) + (array([10.0, -2.5, -2.5]), array([0.0, 3.4409548, 0.81229924])) + + >>> rfft(x, 3) + (array([3.0, -1.5]), array([0.0, 0.8660254])) + """ + if any_symbolic_tensors((x,)): + return RFFT(fft_length).symbolic_call(x) + return backend.math.rfft(x, fft_length) + + +class IRFFT(Operation): + def __init__(self, fft_length=None, *, name=None): + super().__init__(name=name) + self.fft_length = fft_length + + def compute_output_spec(self, x): + if not isinstance(x, (tuple, list)) or len(x) != 2: + raise ValueError( + "Input `x` should be a tuple of two tensors - real and " + f"imaginary. Received: x={x}" + ) + real, imag = x + # Both real and imaginary parts should have the same shape. + if real.shape != imag.shape: + raise ValueError( + "Input `x` should be a tuple of two tensors - real and " + "imaginary. Both the real and imaginary parts should have the " + f"same shape. Received: x[0].shape = {real.shape}, " + f"x[1].shape = {imag.shape}" + ) + # We are calculating 1D IRFFT. Hence, rank >= 1. + if len(real.shape) < 1: + raise ValueError( + f"Input should have rank >= 1. " + f"Received: input.shape = {real.shape}" + ) + + if self.fft_length is not None: + new_last_dimension = self.fft_length + else: + if real.shape[-1] is not None: + new_last_dimension = 2 * (real.shape[-1] - 1) + else: + new_last_dimension = None + new_shape = real.shape[:-1] + (new_last_dimension,) + return KerasTensor(shape=new_shape, dtype=real.dtype) + + def call(self, x): + return backend.math.irfft(x, fft_length=self.fft_length) + + +@keras_export("keras.ops.irfft") +def irfft(x, fft_length=None): + """Inverse real-valued Fast Fourier transform along the last axis. + + Computes the inverse 1D Discrete Fourier Transform of a real-valued signal + over the inner-most dimension of input. + + The inner-most dimension of the input is assumed to be the result of RFFT: + the `fft_length / 2 + 1` unique components of the DFT of a real-valued + signal. If `fft_length` is not provided, it is computed from the size of the + inner-most dimension of the input `(fft_length = 2 * (inner - 1))`. If the + FFT length used to compute is odd, it should be provided since it cannot + be inferred properly. + + Along the axis IRFFT is computed on, if `fft_length / 2 + 1` is smaller than + the corresponding dimension of the input, the dimension is cropped. If it is + larger, the dimension is padded with zeros. + + Args: + x: Tuple of the real and imaginary parts of the input tensor. Both + tensors in the tuple should be of floating type. + fft_length: An integer representing the number of the fft length. If not + specified, it is inferred from the length of the last axis of `x`. + Defaults to `None`. + + Returns: + A tensor containing the inverse real-valued Fast Fourier Transform + along the last axis of `x`. + + Examples: + + >>> real = keras.ops.convert_to_tensor([0.0, 1.0, 2.0, 3.0, 4.0]) + >>> imag = keras.ops.convert_to_tensor([0.0, 1.0, 2.0, 3.0, 4.0]) + >>> irfft((real, imag)) + array([0.66666667, -0.9106836, 0.24401694]) + + >>> irfft(rfft(real, 5), 5) + array([0.0, 1.0, 2.0, 3.0, 4.0]) + """ + if any_symbolic_tensors(x): + return IRFFT(fft_length).symbolic_call(x) + return backend.math.irfft(x, fft_length) + + +class STFT(Operation): + def __init__( + self, + sequence_length, + sequence_stride, + fft_length, + window="hann", + center=True, + *, + name=None, + ): + super().__init__(name=name) + self.sequence_length = sequence_length + self.sequence_stride = sequence_stride + self.fft_length = fft_length + self.window = window + self.center = center + + def compute_output_spec(self, x): + if x.shape[-1] is not None: + padded = 0 if self.center is False else (self.fft_length // 2) * 2 + num_sequences = ( + 1 + + (x.shape[-1] + padded - self.fft_length) + // self.sequence_stride + ) + else: + num_sequences = None + new_shape = x.shape[:-1] + (num_sequences, self.fft_length // 2 + 1) + return ( + KerasTensor(shape=new_shape, dtype=x.dtype), + KerasTensor(shape=new_shape, dtype=x.dtype), + ) + + def call(self, x): + return backend.math.stft( + x, + sequence_length=self.sequence_length, + sequence_stride=self.sequence_stride, + fft_length=self.fft_length, + window=self.window, + center=self.center, + ) + + +@keras_export("keras.ops.stft") +def stft( + x, sequence_length, sequence_stride, fft_length, window="hann", center=True +): + """Short-Time Fourier Transform along the last axis of the input. + + The STFT computes the Fourier transform of short overlapping windows of the + input. This giving frequency components of the signal as they change over + time. + + Args: + x: Input tensor. + sequence_length: An integer representing the sequence length. + sequence_stride: An integer representing the sequence hop size. + fft_length: An integer representing the size of the FFT to apply. If not + specified, uses the smallest power of 2 enclosing `sequence_length`. + window: A string, a tensor of the window or `None`. If `window` is a + string, available values are `"hann"` and `"hamming"`. If `window` + is a tensor, it will be used directly as the window and its length + must be `sequence_length`. If `window` is `None`, no windowing is + used. Defaults to `"hann"`. + center: Whether to pad `x` on both sides so that the t-th sequence is + centered at time `t * sequence_stride`. Otherwise, the t-th sequence + begins at time `t * sequence_stride`. Defaults to `True`. + + Returns: + A tuple containing two tensors - the real and imaginary parts of the + STFT output. + + Example: + + >>> x = keras.ops.convert_to_tensor([0.0, 1.0, 2.0, 3.0, 4.0]) + >>> stft(x, 3, 2, 3) + (array([[0.75, -0.375], + [3.75, -1.875], + [5.25, -2.625]]), array([[0.0, 0.64951905], + [0.0, 0.64951905], + [0.0, -0.64951905]])) + """ + if any_symbolic_tensors((x,)): + return STFT( + sequence_length=sequence_length, + sequence_stride=sequence_stride, + fft_length=fft_length, + window=window, + center=center, + ).symbolic_call(x) + return backend.math.stft( + x, + sequence_length=sequence_length, + sequence_stride=sequence_stride, + fft_length=fft_length, + window=window, + center=center, + ) + + +class ISTFT(Operation): + def __init__( + self, + sequence_length, + sequence_stride, + fft_length, + length=None, + window="hann", + center=True, + *, + name=None, + ): + super().__init__(name=name) + self.sequence_length = sequence_length + self.sequence_stride = sequence_stride + self.fft_length = fft_length + self.length = length + self.window = window + self.center = center + + def compute_output_spec(self, x): + if not isinstance(x, (tuple, list)) or len(x) != 2: + raise ValueError( + "Input `x` should be a tuple of two tensors - real and " + f"imaginary. Received: x={x}" + ) + real, imag = x + # Both real and imaginary parts should have the same shape. + if real.shape != imag.shape: + raise ValueError( + "Input `x` should be a tuple of two tensors - real and " + "imaginary. Both the real and imaginary parts should have the " + f"same shape. Received: x[0].shape = {real.shape}, " + f"x[1].shape = {imag.shape}" + ) + if len(real.shape) < 2: + raise ValueError( + f"Input should have rank >= 2. " + f"Received: input.shape = {real.shape}" + ) + if real.shape[-2] is not None: + output_size = ( + real.shape[-2] - 1 + ) * self.sequence_stride + self.fft_length + if self.length is not None: + output_size = self.length + elif self.center: + output_size = output_size - (self.fft_length // 2) * 2 + else: + output_size = None + new_shape = real.shape[:-2] + (output_size,) + return KerasTensor(shape=new_shape, dtype=real.dtype) + + def call(self, x): + return backend.math.istft( + x, + sequence_length=self.sequence_length, + sequence_stride=self.sequence_stride, + fft_length=self.fft_length, + length=self.length, + window=self.window, + center=self.center, + ) + + +@keras_export("keras.ops.istft") +def istft( + x, + sequence_length, + sequence_stride, + fft_length, + length=None, + window="hann", + center=True, +): + """Inverse Short-Time Fourier Transform along the last axis of the input. + + To reconstruct an original waveform, the parameters should be the same in + `stft`. + + Args: + x: Tuple of the real and imaginary parts of the input tensor. Both + tensors in the tuple should be of floating type. + sequence_length: An integer representing the sequence length. + sequence_stride: An integer representing the sequence hop size. + fft_length: An integer representing the size of the FFT that produced + `stft`. Should be of type `int32`. + length: An integer representing the output is clipped to exactly length. + If not specified, no padding or clipping take place. Defaults to + `None`. + window: A string, a tensor of the window or `None`. If `window` is a + string, available values are `"hann"` and `"hamming"`. If `window` + is a tensor, it will be used directly as the window and its length + must be `sequence_length`. If `window` is `None`, no windowing is + used. Defaults to `"hann"`. + center: Whether `x` was padded on both sides so that the t-th sequence + is centered at time `t * sequence_stride`. Defaults to `True`. + + Returns: + A tensor containing the inverse Short-Time Fourier Transform along the + last axis of `x`. + + Example: + + >>> x = keras.ops.convert_to_tensor([0.0, 1.0, 2.0, 3.0, 4.0]) + >>> istft(stft(x, 1, 1, 1), 1, 1, 1) + array([0.0, 1.0, 2.0, 3.0, 4.0]) + """ + if any_symbolic_tensors(x): + return ISTFT( + sequence_length=sequence_length, + sequence_stride=sequence_stride, + fft_length=fft_length, + window=window, + center=center, + ).symbolic_call(x) + return backend.math.istft( + x, + sequence_length=sequence_length, + sequence_stride=sequence_stride, + fft_length=fft_length, + length=length, + window=window, + center=center, + ) + + +class Rsqrt(Operation): + def call(self, x): + x = backend.convert_to_tensor(x) + return backend.math.rsqrt(x) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype=x.dtype) + + +@keras_export("keras.ops.rsqrt") +def rsqrt(x): + """Computes reciprocal of square root of x element-wise. + + Args: + x: input tensor + + Returns: + A tensor with the same dtype as `x`. + + Example: + + >>> x = keras.ops.convert_to_tensor([1.0, 10.0, 100.0]) + >>> keras.ops.rsqrt(x) + array([1.0, 0.31622776, 0.1], dtype=float32) + """ + if any_symbolic_tensors((x,)): + return Rsqrt().symbolic_call(x) + x = backend.convert_to_tensor(x) + return backend.math.rsqrt(x) + + +class Erf(Operation): + def compute_output_spec(self, x): + return KerasTensor(shape=x.shape, dtype=x.dtype) + + def call(self, x): + return backend.math.erf(x) + + +@keras_export("keras.ops.erf") +def erf(x): + """Computes the error function of `x`, element-wise. + + Args: + x: Input tensor. + + Returns: + A tensor with the same dtype as `x`. + + Example: + + >>> x = np.array([-3.0, -2.0, -1.0, 0.0, 1.0]) + >>> keras.ops.erf(x) + array([-0.99998 , -0.99532, -0.842701, 0., 0.842701], dtype=float32) + """ + if any_symbolic_tensors((x,)): + return Erf().symbolic_call(x) + x = backend.convert_to_tensor(x) + return backend.math.erf(x) + + +class Erfinv(Operation): + def compute_output_spec(self, x): + return KerasTensor(shape=x.shape, dtype=x.dtype) + + def call(self, x): + return backend.math.erfinv(x) + + +@keras_export("keras.ops.erfinv") +def erfinv(x): + """Computes the inverse error function of `x`, element-wise. + + Args: + x: Input tensor. + + Returns: + A tensor with the same dtype as `x`. + + Example: + + >>> x = np.array([-0.5, -0.2, -0.1, 0.0, 0.3]) + >>> keras.ops.erfinv(x) + array([-0.47694, -0.17914, -0.08886, 0. , 0.27246], dtype=float32) + """ + if any_symbolic_tensors((x,)): + return Erfinv().symbolic_call(x) + x = backend.convert_to_tensor(x) + return backend.math.erfinv(x) + + +class Logdet(Operation): + def call(self, x): + return backend.math.logdet(x) + + def compute_output_spec(self, x): + return KerasTensor(x.shape[:-2], dtype=x.dtype) + + +@keras_export(["keras.ops.logdet"]) +def logdet(x): + """Computes log of the determinant of a hermitian positive definite matrix. + + Args: + x: Input matrix. It must 2D and square. + + Returns: + The natural log of the determinant of matrix. + """ + if any_symbolic_tensors((x,)): + return Logdet().symbolic_call(x) + return backend.math.logdet(x) + + +class ViewAsComplex(Operation): + def call(self, x): + x = backend.convert_to_tensor(x) + if len(x.shape) < 1 or x.shape[-1] != 2: + raise ValueError( + "Input tensor's last dimension must be 2 (real and imaginary)." + ) + return x[..., 0] + 1j * x[..., 1] + + def compute_output_spec(self, x): + return KerasTensor(shape=x.shape[:-1], dtype="complex64") + + +class ViewAsReal(Operation): + def call(self, x): + x = backend.convert_to_tensor(x) + real_part = backend.numpy.real(x) + imag_part = backend.numpy.imag(x) + return backend.numpy.stack((real_part, imag_part), axis=-1) + + def compute_output_spec(self, x): + return KerasTensor(shape=x.shape + (2,), dtype="float32") + + +@keras_export("keras.ops.view_as_complex") +def view_as_complex(x): + """Converts a real tensor with shape `(..., 2)` to a complex tensor, + where the last dimension represents the real and imaginary components + of a complex tensor. + + Args: + x: A real tensor with last dimension of size 2. + + Returns: + A complex tensor with shape `x.shape[:-1]`. + + Example: + + ``` + >>> import numpy as np + >>> from keras import ops + + >>> real_imag = np.array([[1.0, 2.0], [3.0, 4.0]]) + >>> complex_tensor = ops.view_as_complex(real_imag) + >>> complex_tensor + array([1.+2.j, 3.+4.j]) + ``` + """ + if any_symbolic_tensors((x,)): + return ViewAsComplex().symbolic_call(x) + + x = backend.convert_to_tensor(x) + if len(x.shape) < 1 or x.shape[-1] != 2: + raise ValueError( + "Last dimension of input must be size 2 (real and imaginary). " + f"Received shape: {x.shape}" + ) + real_part = x[..., 0] + imag_part = x[..., 1] + + return backend.cast(real_part, dtype="complex64") + 1j * backend.cast( + imag_part, dtype="complex64" + ) + + +@keras_export("keras.ops.view_as_real") +def view_as_real(x): + """Converts a complex tensor to a real tensor with shape `(..., 2)`, + where the last dimension represents the real and imaginary components. + + Args: + x: A complex tensor. + + Returns: + A real tensor where the last dimension contains the + real and imaginary parts. + + Example: + ``` + >>> import numpy as np + >>> from keras import ops + + >>> complex_tensor = np.array([1 + 2j, 3 + 4j]) + >>> real = ops.view_as_real(complex_tensor) + >>> real + array([[1., 2.], + [3., 4.]]) + ``` + """ + if any_symbolic_tensors((x,)): + return ViewAsReal().symbolic_call(x) + + x = backend.convert_to_tensor(x) + real_part = backend.numpy.real(x) + imag_part = backend.numpy.imag(x) + return backend.numpy.stack((real_part, imag_part), axis=-1) diff --git a/keras/src/ops/math_test.py b/keras/src/ops/math_test.py new file mode 100644 index 000000000000..bd5b17290f27 --- /dev/null +++ b/keras/src/ops/math_test.py @@ -0,0 +1,1558 @@ +import math + +import jax.numpy as jnp +import numpy as np +import pytest +import scipy.signal +from absl.testing import parameterized + +from keras.src import backend +from keras.src import testing +from keras.src.backend.common import dtypes +from keras.src.backend.common import standardize_dtype +from keras.src.backend.common.keras_tensor import KerasTensor +from keras.src.ops import math as kmath + + +def _stft( + x, sequence_length, sequence_stride, fft_length, window="hann", center=True +): + # pure numpy version of stft that matches librosa's implementation + x = np.array(x) + ori_dtype = x.dtype + + if center: + pad_width = [(0, 0) for _ in range(len(x.shape))] + pad_width[-1] = (fft_length // 2, fft_length // 2) + x = np.pad(x, pad_width, mode="reflect") + + l_pad = (fft_length - sequence_length) // 2 + r_pad = fft_length - sequence_length - l_pad + + if window is not None: + if isinstance(window, str): + window = scipy.signal.get_window(window, sequence_length) + win = np.array(window, dtype=x.dtype) + win = np.pad(win, [[l_pad, r_pad]]) + else: + win = np.ones((sequence_length + l_pad + r_pad), dtype=x.dtype) + + x = scipy.signal.stft( + x, + fs=1.0, + window=win, + nperseg=(sequence_length + l_pad + r_pad), + noverlap=(sequence_length + l_pad + r_pad - sequence_stride), + nfft=fft_length, + boundary=None, + padded=False, + )[-1] + + # scale and swap to (..., num_sequences, fft_bins) + x = x / np.sqrt(1.0 / win.sum() ** 2) + x = np.swapaxes(x, -2, -1) + return np.real(x).astype(ori_dtype), np.imag(x).astype(ori_dtype) + + +def _istft( + x, + sequence_length, + sequence_stride, + fft_length, + length=None, + window="hann", + center=True, +): + # pure numpy version of istft that matches librosa's implementation + complex_input = x[0] + 1j * x[1] + x = np.fft.irfft( + complex_input, n=fft_length, axis=-1, norm="backward" + ).astype(x[0].dtype) + + expected_output_len = fft_length + sequence_stride * (x.shape[-2] - 1) + + if window is not None: + if isinstance(window, str): + win = np.array( + scipy.signal.get_window(window, sequence_length), dtype=x.dtype + ) + else: + win = np.array(window, dtype=x.dtype) + l_pad = (fft_length - sequence_length) // 2 + r_pad = fft_length - sequence_length - l_pad + win = np.pad(win, [[l_pad, r_pad]]) + + # square and sum + _sequence_length = sequence_length + l_pad + r_pad + denom = np.square(win) + overlaps = -(-_sequence_length // sequence_stride) + denom = np.pad( + denom, [(0, overlaps * sequence_stride - _sequence_length)] + ) + denom = np.reshape(denom, [overlaps, sequence_stride]) + denom = np.sum(denom, 0, keepdims=True) + denom = np.tile(denom, [overlaps, 1]) + denom = np.reshape(denom, [overlaps * sequence_stride]) + win = np.divide(win, denom[:_sequence_length]) + x = np.multiply(x, win) + + # overlap_sequences + def _overlap_sequences(x, sequence_stride): + *batch_shape, num_sequences, sequence_length = x.shape + flat_batchsize = math.prod(batch_shape) + x = np.reshape(x, (flat_batchsize, num_sequences, sequence_length)) + output_size = sequence_stride * (num_sequences - 1) + sequence_length + nstep_per_segment = 1 + (sequence_length - 1) // sequence_stride + padded_segment_len = nstep_per_segment * sequence_stride + x = np.pad( + x, ((0, 0), (0, 0), (0, padded_segment_len - sequence_length)) + ) + x = np.reshape( + x, + (flat_batchsize, num_sequences, nstep_per_segment, sequence_stride), + ) + x = x.transpose((0, 2, 1, 3)) + x = np.pad(x, ((0, 0), (0, 0), (0, num_sequences), (0, 0))) + shrinked = x.shape[2] - 1 + x = np.reshape(x, (flat_batchsize, -1)) + x = x[:, : (nstep_per_segment * shrinked * sequence_stride)] + x = np.reshape( + x, (flat_batchsize, nstep_per_segment, shrinked * sequence_stride) + ) + x = np.sum(x, axis=1)[:, :output_size] + return np.reshape(x, tuple(batch_shape) + (-1,)) + + x = _overlap_sequences(x, sequence_stride) + + start = 0 if center is False else fft_length // 2 + if length is not None: + end = start + length + elif center: + end = -(fft_length // 2) + else: + end = expected_output_len + return x[..., start:end] + + +def _sum_reduce(left, right): + return left + right + + +def _max_reduce(left, right): + return np.max(np.stack([left, right]), axis=0) + + +class MathOpsDynamicShapeTest(testing.TestCase): + @parameterized.parameters([(kmath.segment_sum,), (kmath.segment_max,)]) + def test_segment_reduce(self, segment_reduce_op): + # 1D case + data = KerasTensor((None, 4), dtype="float32") + segment_ids = KerasTensor((10,), dtype="int32") + outputs = segment_reduce_op(data, segment_ids) + self.assertEqual(outputs.shape, (None, 4)) + + data = KerasTensor((None, 4), dtype="float32") + segment_ids = KerasTensor((10,), dtype="int32") + outputs = segment_reduce_op(data, segment_ids, num_segments=5) + self.assertEqual(outputs.shape, (5, 4)) + + data = KerasTensor((10,), dtype="float32") + segment_ids = KerasTensor( + (10,), + dtype="int32", + ) + outputs = segment_reduce_op(data, segment_ids) + self.assertEqual(outputs.shape, (None,)) + + def test_top_k(self): + x = KerasTensor((None, 2, 3)) + values, indices = kmath.top_k(x, k=1) + self.assertEqual(values.shape, (None, 2, 1)) + self.assertEqual(indices.shape, (None, 2, 1)) + + def test_in_top_k(self): + targets = KerasTensor((None,)) + predictions = KerasTensor((None, 10)) + self.assertEqual( + kmath.in_top_k(targets, predictions, k=1).shape, (None,) + ) + + def test_logsumexp(self): + x = KerasTensor((None, 2, 3), dtype="float32") + self.assertEqual(kmath.logsumexp(x).shape, ()) + self.assertEqual(kmath.logsumexp(x, axis=1).shape, (None, 3)) + self.assertEqual(kmath.logsumexp(x, axis=(1, 2)).shape, (None,)) + self.assertEqual(kmath.logsumexp(x, keepdims=True).shape, (1, 1, 1)) + + def test_extract_sequences(self): + # Defined dimension + x = KerasTensor((None, 32), dtype="float32") + sequence_length = 3 + sequence_stride = 2 + outputs = kmath.extract_sequences(x, sequence_length, sequence_stride) + num_sequences = 1 + (x.shape[-1] - sequence_length) // sequence_stride + self.assertEqual(outputs.shape, (None, num_sequences, sequence_length)) + + # Undefined dimension + x = KerasTensor((None, None), dtype="float32") + sequence_length = 3 + sequence_stride = 2 + outputs = kmath.extract_sequences(x, sequence_length, sequence_stride) + self.assertEqual(outputs.shape, (None, None, sequence_length)) + + def test_fft(self): + real = KerasTensor((None, 4, 3), dtype="float32") + imag = KerasTensor((None, 4, 3), dtype="float32") + real_output, imag_output = kmath.fft((real, imag)) + ref = np.fft.fft(np.ones((2, 4, 3))) + ref_shape = (None,) + ref.shape[1:] + self.assertEqual(real_output.shape, ref_shape) + self.assertEqual(imag_output.shape, ref_shape) + + def test_fft2(self): + real = KerasTensor((None, 4, 3), dtype="float32") + imag = KerasTensor((None, 4, 3), dtype="float32") + real_output, imag_output = kmath.fft2((real, imag)) + ref = np.fft.fft2(np.ones((2, 4, 3))) + ref_shape = (None,) + ref.shape[1:] + self.assertEqual(real_output.shape, ref_shape) + self.assertEqual(imag_output.shape, ref_shape) + + def test_ifft2(self): + real = KerasTensor((None, 4, 3), dtype="float32") + imag = KerasTensor((None, 4, 3), dtype="float32") + real_output, imag_output = kmath.ifft2((real, imag)) + ref = np.fft.ifft2(np.ones((2, 4, 3))) + ref_shape = (None,) + ref.shape[1:] + self.assertEqual(real_output.shape, ref_shape) + self.assertEqual(imag_output.shape, ref_shape) + + @parameterized.parameters([(None,), (1,), (5,)]) + def test_rfft(self, fft_length): + x = KerasTensor((None, 4, 3), dtype="float32") + real_output, imag_output = kmath.rfft(x, fft_length=fft_length) + ref = np.fft.rfft(np.ones((2, 4, 3)), n=fft_length) + ref_shape = (None,) + ref.shape[1:] + self.assertEqual(real_output.shape, ref_shape) + self.assertEqual(imag_output.shape, ref_shape) + + @parameterized.parameters([(None,), (1,), (5,)]) + def test_irfft(self, fft_length): + real = KerasTensor((None, 4, 3), dtype="float32") + imag = KerasTensor((None, 4, 3), dtype="float32") + output = kmath.irfft((real, imag), fft_length=fft_length) + ref = np.fft.irfft(np.ones((2, 4, 3)), n=fft_length) + ref_shape = (None,) + ref.shape[1:] + self.assertEqual(output.shape, ref_shape) + + def test_stft(self): + x = KerasTensor((None, 32), dtype="float32") + sequence_length = 10 + sequence_stride = 3 + fft_length = 15 + real_output, imag_output = kmath.stft( + x, sequence_length, sequence_stride, fft_length + ) + real_ref, imag_ref = _stft( + np.ones((2, 32)), sequence_length, sequence_stride, fft_length + ) + real_ref_shape = (None,) + real_ref.shape[1:] + imag_ref_shape = (None,) + imag_ref.shape[1:] + self.assertEqual(real_output.shape, real_ref_shape) + self.assertEqual(imag_output.shape, imag_ref_shape) + + def test_istft(self): + sequence_length = 10 + sequence_stride = 3 + fft_length = 15 + real = KerasTensor((None, 32), dtype="float32") + imag = KerasTensor((None, 32), dtype="float32") + output = kmath.istft( + (real, imag), sequence_length, sequence_stride, fft_length + ) + ref = _istft( + (np.ones((5, 32)), np.ones((5, 32))), + sequence_length, + sequence_stride, + fft_length, + ) + ref_shape = (None,) + ref.shape[1:] + self.assertEqual(output.shape, ref_shape) + + def test_rsqrt(self): + x = KerasTensor([None, 3]) + self.assertEqual(kmath.rsqrt(x).shape, (None, 3)) + + def test_logdet(self): + x = KerasTensor((None, 3, 3)) + out = kmath.logdet(x) + self.assertEqual(out.shape, (None,)) + + +class MathOpsStaticShapeTest(testing.TestCase): + @parameterized.parameters([(kmath.segment_sum,), (kmath.segment_max,)]) + @pytest.mark.skipif( + backend.backend() == "jax", + reason="JAX does not support `num_segments=None`.", + ) + def test_segment_reduce(self, segment_reduce_op): + # 1D case + data = KerasTensor((10, 4), dtype="float32") + segment_ids = KerasTensor((10,), dtype="int32") + outputs = segment_reduce_op(data, segment_ids) + self.assertEqual(outputs.shape, (None, 4)) + + data = KerasTensor((10,), dtype="float32") + segment_ids = KerasTensor((10,), dtype="int32") + outputs = segment_reduce_op(data, segment_ids) + self.assertEqual(outputs.shape, (None,)) + + @parameterized.parameters([(kmath.segment_sum,), (kmath.segment_max,)]) + def test_segment_reduce_explicit_num_segments(self, segment_reduce_op): + # 1D case + data = KerasTensor((10, 4), dtype="float32") + segment_ids = KerasTensor((10,), dtype="int32") + outputs = segment_reduce_op(data, segment_ids, num_segments=5) + self.assertEqual(outputs.shape, (5, 4)) + + data = KerasTensor((6,), dtype="float32") + segment_ids = KerasTensor( + (6,), + dtype="int32", + ) + outputs = segment_reduce_op(data, segment_ids, num_segments=5) + self.assertEqual(outputs.shape, (5,)) + + def test_topk(self): + x = KerasTensor((1, 2, 3)) + values, indices = kmath.top_k(x, k=1) + self.assertEqual(values.shape, (1, 2, 1)) + self.assertEqual(indices.shape, (1, 2, 1)) + + def test_in_top_k(self): + targets = KerasTensor((5,)) + predictions = KerasTensor((5, 10)) + self.assertEqual(kmath.in_top_k(targets, predictions, k=1).shape, (5,)) + + def test_logsumexp(self): + x = KerasTensor((1, 2, 3), dtype="float32") + result = kmath.logsumexp(x) + self.assertEqual(result.shape, ()) + + def test_extract_sequences(self): + x = KerasTensor((10, 16), dtype="float32") + sequence_length = 3 + sequence_stride = 2 + outputs = kmath.extract_sequences(x, sequence_length, sequence_stride) + num_sequences = 1 + (x.shape[-1] - sequence_length) // sequence_stride + self.assertEqual(outputs.shape, (10, num_sequences, sequence_length)) + + def test_fft(self): + real = KerasTensor((2, 4, 3), dtype="float32") + imag = KerasTensor((2, 4, 3), dtype="float32") + real_output, imag_output = kmath.fft((real, imag)) + ref = np.fft.fft(np.ones((2, 4, 3))) + self.assertEqual(real_output.shape, ref.shape) + self.assertEqual(imag_output.shape, ref.shape) + + def test_fft2(self): + real = KerasTensor((2, 4, 3), dtype="float32") + imag = KerasTensor((2, 4, 3), dtype="float32") + real_output, imag_output = kmath.fft2((real, imag)) + ref = np.fft.fft2(np.ones((2, 4, 3))) + self.assertEqual(real_output.shape, ref.shape) + self.assertEqual(imag_output.shape, ref.shape) + + def test_ifft2(self): + real = KerasTensor((2, 4, 3), dtype="float32") + imag = KerasTensor((2, 4, 3), dtype="float32") + real_output, imag_output = kmath.ifft2((real, imag)) + ref = np.fft.ifft2(np.ones((2, 4, 3))) + self.assertEqual(real_output.shape, ref.shape) + self.assertEqual(imag_output.shape, ref.shape) + + def test_rfft(self): + x = KerasTensor((2, 4, 3), dtype="float32") + real_output, imag_output = kmath.rfft(x) + ref = np.fft.rfft(np.ones((2, 4, 3))) + self.assertEqual(real_output.shape, ref.shape) + self.assertEqual(imag_output.shape, ref.shape) + + def test_irfft(self): + real = KerasTensor((2, 4, 3), dtype="float32") + imag = KerasTensor((2, 4, 3), dtype="float32") + output = kmath.irfft((real, imag)) + ref = np.fft.irfft(np.ones((2, 4, 3))) + self.assertEqual(output.shape, ref.shape) + + def test_rsqrt(self): + x = KerasTensor([4, 3], dtype="float32") + self.assertEqual(kmath.rsqrt(x).shape, (4, 3)) + + def test_stft(self): + x = KerasTensor((2, 32), dtype="float32") + sequence_length = 10 + sequence_stride = 3 + fft_length = 15 + real_output, imag_output = kmath.stft( + x, sequence_length, sequence_stride, fft_length + ) + real_ref, imag_ref = _stft( + np.ones((2, 32)), sequence_length, sequence_stride, fft_length + ) + self.assertEqual(real_output.shape, real_ref.shape) + self.assertEqual(imag_output.shape, imag_ref.shape) + + def test_istft(self): + # sequence_stride must <= x[0].shape[-1] + # sequence_stride must >= fft_length / num_sequences + sequence_length = 10 + sequence_stride = 3 + fft_length = 15 + num_sequences = fft_length // sequence_stride + 1 + real = KerasTensor((num_sequences, 32), dtype="float32") + imag = KerasTensor((num_sequences, 32), dtype="float32") + output = kmath.istft( + (real, imag), sequence_length, sequence_stride, fft_length + ) + ref = _istft( + (np.ones((num_sequences, 32)), np.ones((num_sequences, 32))), + sequence_length, + sequence_stride, + fft_length, + ) + self.assertEqual(output.shape, ref.shape) + + def test_logdet(self): + x = KerasTensor((3, 3)) + out = kmath.logdet(x) + self.assertEqual(out.shape, ()) + + x = KerasTensor((2, 4, 3, 3)) + out = kmath.logdet(x) + self.assertEqual(out.shape, (2, 4)) + + +class MathOpsCorrectnessTest(testing.TestCase): + def run_segment_reduce_test( + self, + segment_reduce_op, + element_wise_reduce_method, + num_indices, + indices_high, + data_dims=tuple(), + num_segments=None, + add_neg1_to_indices=False, + sorted_indices=False, + ): + if num_segments is not None and indices_high >= num_segments: + raise ValueError("Indices high cannot be more than num segments") + indices_dims = (num_indices,) + full_data_dims = indices_dims + data_dims + data = np.random.rand(*full_data_dims).astype(np.float32) + segment_ids = np.concatenate( + [ + np.arange(indices_high), + np.random.randint( + low=0, + high=indices_high, + size=(indices_dims[0] - indices_high), + ), + ] + ).astype(np.int32) + if sorted_indices: + segment_ids = np.sort(segment_ids, axis=-1) + if add_neg1_to_indices: + segment_ids[0] = -1 + outputs = segment_reduce_op( + data, segment_ids, num_segments, sorted=sorted_indices + ) + if num_segments is None: + num_segments = np.max(segment_ids).item() + 1 + expected_shape = (num_segments,) + data_dims + if segment_reduce_op == kmath.segment_max: + if backend.backend() == "tensorflow": + empty_fill_value = -np.finfo(np.float32).max + else: + empty_fill_value = -np.inf + expected = np.full(expected_shape, empty_fill_value) + else: + expected = np.zeros(expected_shape) + + for idx in range(num_indices): + segment_id = segment_ids[idx] + if segment_id == -1: + continue + expected[segment_id] = element_wise_reduce_method( + expected[segment_id], data[idx] + ) + self.assertAllClose(outputs, expected) + + @parameterized.product( + ( + dict( + segment_reduce_op=kmath.segment_sum, + element_wise_reduce_method=_sum_reduce, + ), + dict( + segment_reduce_op=kmath.segment_max, + element_wise_reduce_method=_max_reduce, + ), + ), + sorted_indices=(True, False), + ) + @pytest.mark.skipif( + backend.backend() == "jax", + reason="JAX does not support `num_segments=None`.", + ) + def test_segment_reduce( + self, + segment_reduce_op, + element_wise_reduce_method, + sorted_indices, + ): + # Test 1D case. + self.run_segment_reduce_test( + segment_reduce_op, + element_wise_reduce_method, + num_indices=9, + indices_high=3, + sorted_indices=sorted_indices, + ) + + # Test ND data case. + self.run_segment_reduce_test( + segment_reduce_op, + element_wise_reduce_method, + num_indices=9, + indices_high=3, + data_dims=( + 3, + 3, + ), + sorted_indices=sorted_indices, + ) + + @parameterized.product( + ( + dict( + segment_reduce_op=kmath.segment_sum, + element_wise_reduce_method=_sum_reduce, + ), + dict( + segment_reduce_op=kmath.segment_max, + element_wise_reduce_method=_max_reduce, + ), + ), + ( + dict( + contains_neg1_in_indices=True, + sorted_indices=False, + ), + dict( + contains_neg1_in_indices=False, + sorted_indices=False, + ), + dict( + contains_neg1_in_indices=False, + sorted_indices=True, + ), + ), + ) + def test_segment_reduce_explicit_num_segments( + self, + segment_reduce_op, + element_wise_reduce_method, + contains_neg1_in_indices, + sorted_indices, + ): + if backend.backend() == "tensorflow" and sorted_indices: + pytest.skip( + "Num segments and sorted_indices=True doesn't work for " + "tensorflow." + ) + # Test 1D case. + self.run_segment_reduce_test( + segment_reduce_op, + element_wise_reduce_method, + num_indices=9, + indices_high=3, + num_segments=4, + add_neg1_to_indices=contains_neg1_in_indices, + sorted_indices=sorted_indices, + ) + + # Test ND data case. + self.run_segment_reduce_test( + segment_reduce_op, + element_wise_reduce_method, + num_indices=9, + indices_high=3, + data_dims=( + 3, + 3, + ), + num_segments=4, + add_neg1_to_indices=contains_neg1_in_indices, + sorted_indices=sorted_indices, + ) + + def test_top_k(self): + x = np.array([0, 4, 2, 1, 3, -1], dtype=np.float32) + values, indices = kmath.top_k(x, k=2) + self.assertAllClose(values, [4, 3]) + self.assertAllClose(indices, [1, 4]) + + x = np.array([0, 4, 2, 1, 3, -1], dtype=np.float32) + values, indices = kmath.top_k(x, k=2, sorted=False) + # Any order ok when `sorted=False`. + self.assertEqual(set(backend.convert_to_numpy(values)), set([4, 3])) + self.assertEqual(set(backend.convert_to_numpy(indices)), set([1, 4])) + + x = np.random.rand(5, 5) + outputs = kmath.top_k(x, k=2) + + expected_values = np.zeros((5, 2)) + expected_indices = np.zeros((5, 2), dtype=np.int32) + + for i in range(x.shape[0]): + top_k_indices = np.argsort(x[i])[-2:][::-1] + expected_values[i] = x[i, top_k_indices] + expected_indices[i] = top_k_indices + + self.assertAllClose(outputs[0], expected_values) + self.assertAllClose(outputs[1], expected_indices) + + def test_in_top_k(self): + targets = np.array([1, 0, 2]) + predictions = np.array( + [ + [0.1, 0.9, 0.8, 0.8], + [0.05, 0.95, 0, 1], + [0.1, 0.8, 0.3, 1], + ] + ) + self.assertAllEqual( + kmath.in_top_k(targets, predictions, k=1), [True, False, False] + ) + self.assertAllEqual( + kmath.in_top_k(targets, predictions, k=2), [True, False, False] + ) + self.assertAllEqual( + kmath.in_top_k(targets, predictions, k=3), [True, True, True] + ) + + # Test tie cases. + targets = np.array([1, 0, 2]) + predictions = np.array( + [ + [0.1, 0.9, 0.8, 0.8], + [0.95, 0.95, 0, 0.95], + [0.1, 0.8, 0.8, 0.95], + ] + ) + self.assertAllEqual( + kmath.in_top_k(targets, predictions, k=1), [True, True, False] + ) + self.assertAllEqual( + kmath.in_top_k(targets, predictions, k=2), [True, True, True] + ) + self.assertAllEqual( + kmath.in_top_k(targets, predictions, k=3), [True, True, True] + ) + + # Test `nan` in predictions + # https://github.com/keras-team/keras/issues/19995 + targets = np.array([1, 0]) + predictions = np.array([[0.1, np.nan, 0.5], [0.3, 0.2, 0.5]]) + self.assertAllEqual( + kmath.in_top_k(targets, predictions, k=2), [False, True] + ) + + def test_logsumexp(self): + x = np.random.rand(5, 5) + outputs = kmath.logsumexp(x) + expected = np.log(np.sum(np.exp(x))) + self.assertAllClose(outputs, expected) + + outputs = kmath.logsumexp(x, axis=1) + expected = np.log(np.sum(np.exp(x), axis=1)) + self.assertAllClose(outputs, expected) + + def test_extract_sequences(self): + # Test 1D case. + x = np.random.random((10,)) + sequence_length = 3 + sequence_stride = 2 + output = kmath.extract_sequences(x, sequence_length, sequence_stride) + + num_sequences = 1 + (x.shape[-1] - sequence_length) // sequence_stride + expected = np.zeros(shape=(num_sequences, sequence_length)) + pos = 0 + for i in range(num_sequences): + expected[i] = x[pos : pos + sequence_length] + pos += sequence_stride + self.assertAllClose(output, expected) + + # Test N-D case. + x = np.random.random((4, 8)) + sequence_length = 3 + sequence_stride = 2 + output = kmath.extract_sequences(x, sequence_length, sequence_stride) + + num_sequences = 1 + (x.shape[-1] - sequence_length) // sequence_stride + expected = np.zeros(shape=(4, num_sequences, sequence_length)) + pos = 0 + for i in range(num_sequences): + expected[:, i] = x[:, pos : pos + sequence_length] + pos += sequence_stride + self.assertAllClose(output, expected) + + def test_fft(self): + real = np.random.random((2, 4, 3)) + imag = np.random.random((2, 4, 3)) + complex_arr = real + 1j * imag + + real_output, imag_output = kmath.fft((real, imag)) + ref = np.fft.fft(complex_arr) + real_ref = np.real(ref) + imag_ref = np.imag(ref) + self.assertAllClose(real_ref, real_output) + self.assertAllClose(imag_ref, imag_output) + + def test_fft2(self): + real = np.random.random((2, 4, 3)) + imag = np.random.random((2, 4, 3)) + complex_arr = real + 1j * imag + + real_output, imag_output = kmath.fft2((real, imag)) + ref = np.fft.fft2(complex_arr) + real_ref = np.real(ref) + imag_ref = np.imag(ref) + self.assertAllClose(real_ref, real_output) + self.assertAllClose(imag_ref, imag_output) + + def test_ifft2(self): + real = np.random.random((2, 4, 3)).astype(np.float32) + imag = np.random.random((2, 4, 3)).astype(np.float32) + complex_arr = real + 1j * imag + + real_output, imag_output = kmath.ifft2((real, imag)) + ref = np.fft.ifft2(complex_arr) + real_ref = np.real(ref) + imag_ref = np.imag(ref) + self.assertAllClose(real_ref, real_output) + self.assertAllClose(imag_ref, imag_output) + + @parameterized.parameters([(None,), (3,), (15,)]) + def test_rfft(self, n): + # Test 1D. + x = np.random.random((10,)) + real_output, imag_output = kmath.rfft(x, fft_length=n) + ref = np.fft.rfft(x, n=n) + real_ref = np.real(ref) + imag_ref = np.imag(ref) + self.assertAllClose(real_ref, real_output, atol=1e-5, rtol=1e-5) + self.assertAllClose(imag_ref, imag_output, atol=1e-5, rtol=1e-5) + + # Test N-D case. + x = np.random.random((2, 3, 10)) + real_output, imag_output = kmath.rfft(x, fft_length=n) + ref = np.fft.rfft(x, n=n) + real_ref = np.real(ref) + imag_ref = np.imag(ref) + self.assertAllClose(real_ref, real_output, atol=1e-5, rtol=1e-5) + self.assertAllClose(imag_ref, imag_output, atol=1e-5, rtol=1e-5) + + @parameterized.parameters([(None,), (3,), (15,)]) + def test_irfft(self, n): + # Test 1D. + real = np.random.random((10,)) + imag = np.random.random((10,)) + complex_arr = real + 1j * imag + output = kmath.irfft((real, imag), fft_length=n) + ref = np.fft.irfft(complex_arr, n=n) + self.assertAllClose(output, ref, atol=1e-5, rtol=1e-5) + + # Test N-D case. + real = np.random.random((2, 3, 10)) + imag = np.random.random((2, 3, 10)) + complex_arr = real + 1j * imag + output = kmath.irfft((real, imag), fft_length=n) + ref = np.fft.irfft(complex_arr, n=n) + self.assertAllClose(output, ref, atol=1e-5, rtol=1e-5) + + @parameterized.parameters( + [ + (32, 8, 32, "hann", True), + (8, 8, 16, "hann", True), + (4, 4, 7, "hann", True), + (32, 8, 32, "hamming", True), + (32, 8, 32, "hann", False), + (32, 8, 32, np.ones((32,)), True), + (32, 8, 32, None, True), + ] + ) + def test_stft( + self, sequence_length, sequence_stride, fft_length, window, center + ): + # Test 1D case. + x = np.random.random((32,)) + real_output, imag_output = kmath.stft( + x, sequence_length, sequence_stride, fft_length, window, center + ) + real_ref, imag_ref = _stft( + x, sequence_length, sequence_stride, fft_length, window, center + ) + self.assertAllClose(real_ref, real_output, atol=1e-5, rtol=1e-5) + self.assertAllClose(imag_ref, imag_output, atol=1e-5, rtol=1e-5) + + # Test N-D case. + x = np.random.random((2, 3, 32)) + real_output, imag_output = kmath.stft( + x, sequence_length, sequence_stride, fft_length, window, center + ) + real_ref, imag_ref = _stft( + x, sequence_length, sequence_stride, fft_length, window, center + ) + self.assertAllClose(real_ref, real_output, atol=1e-5, rtol=1e-5) + self.assertAllClose(imag_ref, imag_output, atol=1e-5, rtol=1e-5) + + @parameterized.parameters( + [ + (32, 8, 32, "hann", True), + (8, 8, 16, "hann", True), + (4, 4, 7, "hann", True), + (32, 8, 32, "hamming", True), + (8, 4, 8, "hann", False), + (32, 8, 32, np.ones((32,)), True), + (32, 8, 32, None, True), + ] + ) + def test_istft( + self, sequence_length, sequence_stride, fft_length, window, center + ): + # sequence_stride must <= x[0].shape[-1] + # sequence_stride must >= fft_length / num_sequences + # Test 1D case. + x = np.random.random((256,)) + real_x, imag_x = _stft( + x, sequence_length, sequence_stride, fft_length, window, center + ) + output = kmath.istft( + (real_x, imag_x), + sequence_length, + sequence_stride, + fft_length, + window=window, + center=center, + ) + ref = _istft( + (real_x, imag_x), + sequence_length, + sequence_stride, + fft_length, + window=window, + center=center, + ) + if backend.backend() in ("numpy", "jax", "torch"): + # these backends have different implementation for the boundary of + # the output, so we need to truncate 5% before assertAllClose + truncated_len = int(output.shape[-1] * 0.05) + output = output[..., truncated_len:-truncated_len] + ref = ref[..., truncated_len:-truncated_len] + # Nans are handled differently in different backends, so zero them out. + output = np.nan_to_num(backend.convert_to_numpy(output), nan=0.0) + ref = np.nan_to_num(ref, nan=0.0) + self.assertAllClose(output, ref, atol=1e-5, rtol=1e-5) + + # Test N-D case. + x = np.random.random((2, 3, 256)) + real_x, imag_x = _stft( + x, sequence_length, sequence_stride, fft_length, window, center + ) + output = kmath.istft( + (real_x, imag_x), + sequence_length, + sequence_stride, + fft_length, + window=window, + center=center, + ) + ref = _istft( + (real_x, imag_x), + sequence_length, + sequence_stride, + fft_length, + window=window, + center=center, + ) + if backend.backend() in ("numpy", "jax", "torch"): + # these backends have different implementation for the boundary of + # the output, so we need to truncate 5% before assertAllClose + truncated_len = int(output.shape[-1] * 0.05) + output = output[..., truncated_len:-truncated_len] + ref = ref[..., truncated_len:-truncated_len] + # Nans are handled differently in different backends, so zero them out. + output = np.nan_to_num(backend.convert_to_numpy(output), nan=0.0) + ref = np.nan_to_num(ref, nan=0.0) + self.assertAllClose(output, ref, atol=1e-5, rtol=1e-5) + + def test_rsqrt(self): + x = np.array([[1, 4, 9], [16, 25, 36]], dtype="float32") + self.assertAllClose(kmath.rsqrt(x), 1 / np.sqrt(x)) + self.assertAllClose(kmath.Rsqrt()(x), 1 / np.sqrt(x)) + + def test_erf_operation_basic(self): + # Sample values for testing + sample_values = np.array([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0]) + + # Expected output using numpy's approximation of the error function + expected_output = scipy.special.erf(sample_values) + + # Output from the erf operation in keras_core + output_from_erf_op = kmath.erf(sample_values) + + # Assert that the outputs are close + self.assertAllClose(expected_output, output_from_erf_op, atol=1e-4) + + def test_erf_operation_dtype(self): + # Test for float32 and float64 data types + for dtype in ("float32", "float64"): + sample_values = np.array( + [-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0], dtype=dtype + ) + expected_output = scipy.special.erf(sample_values) + output_from_erf_op = kmath.erf(sample_values) + self.assertAllClose(expected_output, output_from_erf_op, atol=1e-4) + + def test_erf_operation_edge_cases(self): + # Test for edge cases + edge_values = np.array([1e5, -1e5, 1e-5, -1e-5], dtype=np.float64) + expected_output = scipy.special.erf(edge_values) + output_from_edge_erf_op = kmath.erf(edge_values) + self.assertAllClose(expected_output, output_from_edge_erf_op, atol=1e-4) + + def test_erfinv_operation_basic(self): + # Sample values for testing + sample_values = np.array([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0]) + + # Expected output using numpy's approximation of the error function + expected_output = scipy.special.erfinv(sample_values) + + # Output from the erf operation in keras_core + output_from_erfinv_op = kmath.erfinv(sample_values) + + # Assert that the outputs are close + self.assertAllClose(expected_output, output_from_erfinv_op, atol=1e-4) + + def test_erfinv_operation_dtype(self): + # Test for float32 and float64 data types + for dtype in ("float32", "float64"): + sample_values = np.array( + [-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0], dtype=dtype + ) + expected_output = scipy.special.erfinv(sample_values) + output_from_erfinv_op = kmath.erfinv(sample_values) + self.assertAllClose( + expected_output, output_from_erfinv_op, atol=1e-4 + ) + + def test_erfinv_operation_edge_cases(self): + # Test for edge cases + edge_values = np.array([1e5, -1e5, 1e-5, -1e-5], dtype=np.float64) + expected_output = scipy.special.erfinv(edge_values) + output_from_edge_erfinv_op = kmath.erfinv(edge_values) + self.assertAllClose( + expected_output, output_from_edge_erfinv_op, atol=1e-4 + ) + + def test_logdet(self): + x = np.array( + [ + [4.42, -1.18, 0.06, 0.74], + [-1.18, 1.77, -0.84, -1.16], + [0.06, -0.84, 5.84, 0.55], + [0.74, -1.16, 0.55, 0.77], + ], + dtype="float32", + ) + out = kmath.logdet(x) + self.assertAllClose(out, -1.1178946, atol=1e-3) + + +class MathDtypeTest(testing.TestCase): + """Test the floating dtype to verify that the behavior matches JAX.""" + + ALL_DTYPES = [ + x + for x in dtypes.ALLOWED_DTYPES + if x + not in ( + "string", + "complex64", + "complex128", + # Remove 64-bit dtypes. + "float64", + "uint64", + "int64", + ) + + dtypes.FLOAT8_TYPES # Remove float8 dtypes for the following tests + ] + [None] + INT_DTYPES = [x for x in dtypes.INT_TYPES if x not in ("uint64", "int64")] + FLOAT_DTYPES = [x for x in dtypes.FLOAT_TYPES if x not in ("float64",)] + + if backend.backend() == "torch": + ALL_DTYPES = [x for x in ALL_DTYPES if x not in ("uint16", "uint32")] + INT_DTYPES = [x for x in INT_DTYPES if x not in ("uint16", "uint32")] + + +class ExtractSequencesOpTest(testing.TestCase): + def test_extract_sequences_init_length_1_stride_1(self): + extract_op = kmath.ExtractSequences( + sequence_length=1, sequence_stride=1 + ) + self.assertIsNotNone(extract_op) + self.assertEqual(extract_op.sequence_length, 1) + self.assertEqual(extract_op.sequence_stride, 1) + + def test_extract_sequences_init_length_5_stride_2(self): + extract_op = kmath.ExtractSequences( + sequence_length=5, sequence_stride=2 + ) + self.assertIsNotNone(extract_op) + self.assertEqual(extract_op.sequence_length, 5) + self.assertEqual(extract_op.sequence_stride, 2) + + def test_compute_output_spec_low_rank(self): + extract_op = kmath.ExtractSequences( + sequence_length=5, sequence_stride=1 + ) + low_rank_input = np.array(42) + error_message = r"Input should have rank >= 1. Received: .*" + with self.assertRaisesRegex(ValueError, error_message): + extract_op.compute_output_spec(low_rank_input) + + def test_extract_sequences_call(self): + sequence_length, sequence_stride = 5, 2 + extract_op = kmath.ExtractSequences(sequence_length, sequence_stride) + test_input = np.random.rand(10, 20) + result = extract_op.call(test_input) + + expected_shape = self.calculate_expected_shape( + test_input.shape, sequence_length, sequence_stride + ) + self.assertEqual(result.shape, expected_shape) + + def calculate_expected_shape( + self, input_shape, sequence_length, sequence_stride + ): + num_sequences = ( + (input_shape[1] - sequence_length) // sequence_stride + ) + 1 + return (input_shape[0], num_sequences, sequence_length) + + +class SegmentSumTest(testing.TestCase): + def test_segment_sum_call(self): + data = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32) + segment_ids = np.array([0, 0, 1], dtype=np.int32) + num_segments = 2 + sorted_segments = False + segment_sum_op = kmath.SegmentSum( + num_segments=num_segments, sorted=sorted_segments + ) + output = segment_sum_op.call(data, segment_ids) + expected_output = np.array([[5, 7, 9], [7, 8, 9]], dtype=np.float32) + self.assertAllClose(output, expected_output) + + +class SegmentMaxTest(testing.TestCase): + def test_segment_max_call(self): + data = np.array([[1, 4, 7], [2, 5, 8], [3, 6, 9]], dtype=np.float32) + segment_ids = np.array([0, 0, 1], dtype=np.int32) + num_segments = 2 + sorted_segments = False + segment_max_op = kmath.SegmentMax( + num_segments=num_segments, sorted=sorted_segments + ) + output = segment_max_op.call(data, segment_ids) + expected_output = np.array([[2, 5, 8], [3, 6, 9]], dtype=np.float32) + self.assertAllClose(output, expected_output) + + +class TopKTest(testing.TestCase): + def test_top_k_call_values(self): + data = np.array([[1, 3, 2], [4, 6, 5]], dtype=np.float32) + k = 2 + sorted_flag = True + top_k_op = kmath.TopK(k=k, sorted=sorted_flag) + values, _ = top_k_op.call(data) + expected_values = np.array([[3, 2], [6, 5]], dtype=np.float32) + self.assertAllClose(values, expected_values) + + def test_top_k_call_indices(self): + data = np.array([[1, 3, 2], [4, 6, 5]], dtype=np.float32) + k = 2 + sorted_flag = True + top_k_op = kmath.TopK(k=k, sorted=sorted_flag) + _, indices = top_k_op.call(data) + expected_indices = np.array([[1, 2], [1, 2]], dtype=np.int32) + self.assertAllClose(indices, expected_indices) + + +class InTopKTest(testing.TestCase): + def test_in_top_k_call(self): + targets = np.array([2, 0, 1], dtype=np.int32) + predictions = np.array( + [[0.1, 0.2, 0.7], [1.0, 0.2, 0.3], [0.2, 0.6, 0.2]], + dtype=np.float32, + ) + k = 2 + in_top_k_op = kmath.InTopK(k=k) + output = in_top_k_op.call(targets, predictions) + expected_output = np.array([True, True, True], dtype=bool) + self.assertAllEqual(output, expected_output) + + +class LogsumexpTest(testing.TestCase): + def test_logsumexp_call(self): + x = np.array([[1, 2], [3, 4]], dtype=np.float32) + axis = 0 + keepdims = True + logsumexp_op = kmath.Logsumexp(axis=axis, keepdims=keepdims) + output = logsumexp_op.call(x) + expected_output = np.log( + np.sum(np.exp(x), axis=axis, keepdims=keepdims) + ) + self.assertAllClose(output, expected_output) + + +class FFTTest(testing.TestCase): + def test_fft_input_not_tuple_or_list(self): + fft_op = kmath.FFT() + with self.assertRaisesRegex( + ValueError, "Input `x` should be a tuple of two tensors" + ): + fft_op.compute_output_spec(np.array([1, 2, 3])) + + def test_fft_input_parts_different_shapes(self): + fft_op = kmath.FFT() + real = np.array([1, 2, 3]) + imag = np.array([1, 2]) + with self.assertRaisesRegex( + ValueError, + "Both the real and imaginary parts should have the same shape", + ): + fft_op.compute_output_spec((real, imag)) + + def test_fft_input_not_1d(self): + fft_op = kmath.FFT() + real = np.array(1) + imag = np.array(1) + with self.assertRaisesRegex(ValueError, "Input should have rank >= 1"): + fft_op.compute_output_spec((real, imag)) + + def test_fft_last_axis_not_fully_defined(self): + fft_op = kmath.FFT() + real = KerasTensor(shape=(None,), dtype="float32") + imag = KerasTensor(shape=(None,), dtype="float32") + with self.assertRaisesRegex( + ValueError, "Input should have its last dimension fully-defined" + ): + fft_op.compute_output_spec((real, imag)) + + +class FFT2Test(testing.TestCase): + def test_fft2_correct_input(self): + fft2_op = kmath.FFT2() + real_part = np.random.rand(2, 3, 4) + imag_part = np.random.rand(2, 3, 4) + # This should not raise any errors + fft2_op.compute_output_spec((real_part, imag_part)) + + def test_fft2_incorrect_input_type(self): + fft2_op = kmath.FFT2() + incorrect_input = np.array([1, 2, 3]) # Not a tuple or list + with self.assertRaisesRegex( + ValueError, "should be a tuple of two tensors" + ): + fft2_op.compute_output_spec(incorrect_input) + + def test_fft2_mismatched_shapes(self): + fft2_op = kmath.FFT2() + real_part = np.random.rand(2, 3, 4) + imag_part = np.random.rand(2, 3) # Mismatched shape + with self.assertRaisesRegex( + ValueError, + "Both the real and imaginary parts should have the same shape", + ): + fft2_op.compute_output_spec((real_part, imag_part)) + + def test_fft2_low_rank(self): + fft2_op = kmath.FFT2() + low_rank_input = np.random.rand(3) # Rank of 1 + with self.assertRaisesRegex(ValueError, "Input should have rank >= 2"): + fft2_op.compute_output_spec((low_rank_input, low_rank_input)) + + def test_fft2_undefined_dimensions(self): + fft2_op = kmath.FFT2() + real_part = KerasTensor(shape=(None, None, 3), dtype="float32") + imag_part = KerasTensor(shape=(None, None, 3), dtype="float32") + with self.assertRaisesRegex( + ValueError, "Input should have its .* axes fully-defined" + ): + fft2_op.compute_output_spec((real_part, imag_part)) + + +class RFFTTest(testing.TestCase): + def test_rfft_low_rank_input(self): + rfft_op = kmath.RFFT() + low_rank_input = np.array(5) + with self.assertRaisesRegex(ValueError, "Input should have rank >= 1"): + rfft_op.compute_output_spec(low_rank_input) + + def test_rfft_defined_fft_length(self): + fft_length = 10 + rfft_op = kmath.RFFT(fft_length=fft_length) + input_tensor = np.random.rand(3, 8) + + expected_last_dimension = fft_length // 2 + 1 + expected_shape = input_tensor.shape[:-1] + (expected_last_dimension,) + + output_tensors = rfft_op.compute_output_spec(input_tensor) + for output_tensor in output_tensors: + self.assertEqual(output_tensor.shape, expected_shape) + + def test_rfft_undefined_fft_length_defined_last_dim(self): + rfft_op = kmath.RFFT() + input_tensor = np.random.rand(3, 8) + expected_last_dimension = input_tensor.shape[-1] // 2 + 1 + expected_shape = input_tensor.shape[:-1] + ( + expected_last_dimension, + ) + output_tensors = rfft_op.compute_output_spec(input_tensor) + for output_tensor in output_tensors: + self.assertEqual(output_tensor.shape, expected_shape) + + def test_rfft_undefined_fft_length_undefined_last_dim(self): + rfft_op = kmath.RFFT() + input_tensor = KerasTensor(shape=(None, None), dtype="float32") + expected_shape = input_tensor.shape[:-1] + (None,) + output_tensors = rfft_op.compute_output_spec(input_tensor) + for output_tensor in output_tensors: + self.assertEqual(output_tensor.shape, expected_shape) + + +class ISTFTTest(testing.TestCase): + def test_istft_incorrect_input_type(self): + istft_op = kmath.ISTFT( + sequence_length=5, sequence_stride=2, fft_length=10 + ) + incorrect_input = np.array([1, 2, 3]) + with self.assertRaisesRegex( + ValueError, "should be a tuple of two tensors" + ): + istft_op.compute_output_spec(incorrect_input) + + def test_istft_mismatched_shapes(self): + istft_op = kmath.ISTFT( + sequence_length=5, sequence_stride=2, fft_length=10 + ) + real_part = np.random.rand(2, 3, 4) + imag_part = np.random.rand(2, 3) + with self.assertRaisesRegex( + ValueError, + "Both the real and imaginary parts should have the same shape", + ): + istft_op.compute_output_spec((real_part, imag_part)) + + def test_istft_low_rank_input(self): + istft_op = kmath.ISTFT( + sequence_length=5, sequence_stride=2, fft_length=10 + ) + low_rank_input = np.random.rand(3) + with self.assertRaisesRegex(ValueError, "Input should have rank >= 2"): + istft_op.compute_output_spec((low_rank_input, low_rank_input)) + + def test_input_not_tuple_or_list_raises_error(self): + irfft_op = kmath.IRFFT() + invalid_input = np.array([1, 2, 3]) + with self.assertRaisesRegex( + ValueError, "Input `x` should be a tuple of two tensors" + ): + irfft_op.compute_output_spec(invalid_input) + + def test_input_tuple_with_less_than_two_elements_raises_error(self): + irfft_op = kmath.IRFFT() + too_short_input = (np.array([1, 2, 3]),) + with self.assertRaisesRegex( + ValueError, "Input `x` should be a tuple of two tensors" + ): + irfft_op.compute_output_spec(too_short_input) + + def test_input_tuple_with_more_than_two_elements_raises_error(self): + irfft_op = kmath.IRFFT() + too_long_input = ( + np.array([1, 2, 3]), + np.array([4, 5, 6]), + np.array([7, 8, 9]), + ) + with self.assertRaisesRegex( + ValueError, "Input `x` should be a tuple of two tensors" + ): + irfft_op.compute_output_spec(too_long_input) + + def test_mismatched_shapes_input_validation(self): + irfft_op = kmath.IRFFT() + + # Create real and imaginary parts with mismatched shapes + real_part = np.array([1, 2, 3]) + imag_part = np.array([[1, 2], [3, 4]]) + + with self.assertRaisesRegex( + ValueError, + "Both the real and imaginary parts should have the same shape", + ): + irfft_op.compute_output_spec((real_part, imag_part)) + + def test_insufficient_rank_input_validation(self): + irfft_op = kmath.IRFFT() + + # Create real and imaginary parts with insufficient rank (0D) + real_part = np.array(1) + imag_part = np.array(1) + + with self.assertRaisesRegex(ValueError, "Input should have rank >= 1"): + irfft_op.compute_output_spec((real_part, imag_part)) + + def test_with_specified_fft_length(self): + fft_length = 10 + irfft_op = kmath.IRFFT(fft_length=fft_length) + + real_part = np.random.rand(4, 8) + imag_part = np.random.rand(4, 8) + + expected_shape = real_part.shape[:-1] + (fft_length,) + output_shape = irfft_op.compute_output_spec( + (real_part, imag_part) + ).shape + + self.assertEqual(output_shape, expected_shape) + + def test_inferred_fft_length_with_defined_last_dimension(self): + irfft_op = kmath.IRFFT() + + real_part = np.random.rand(4, 8) + imag_part = np.random.rand(4, 8) + + inferred_fft_length = 2 * (real_part.shape[-1] - 1) + expected_shape = real_part.shape[:-1] + (inferred_fft_length,) + output_shape = irfft_op.compute_output_spec( + (real_part, imag_part) + ).shape + + self.assertEqual(output_shape, expected_shape) + + def test_undefined_fft_length_and_last_dimension(self): + irfft_op = kmath.IRFFT() + + real_part = KerasTensor(shape=(4, None), dtype="float32") + imag_part = KerasTensor(shape=(4, None), dtype="float32") + + output_spec = irfft_op.compute_output_spec((real_part, imag_part)) + expected_shape = real_part.shape[:-1] + (None,) + + self.assertEqual(output_spec.shape, expected_shape) + + +class TestMathErrors(testing.TestCase): + @parameterized.parameters([(kmath.segment_sum,), (kmath.segment_max,)]) + @pytest.mark.skipif( + backend.backend() != "jax", reason="Testing Jax errors only" + ) + def test_segment_reduce_no_num_segments(self, segment_reduce_op): + data = jnp.array([1, 2, 3, 4]) + segment_ids = jnp.array([0, 0, 1, 1]) + with self.assertRaisesRegex( + ValueError, + "Argument `num_segments` must be set when using the JAX backend.", + ): + segment_reduce_op(data, segment_ids) + + @parameterized.parameters([(kmath.segment_sum,), (kmath.segment_max,)]) + @pytest.mark.skipif( + backend.backend() != "tensorflow", reason="Tensorflow error only" + ) + def test_segment_reduce_sort_and_num_segments(self, segment_reduce_op): + data = np.array([1, 2, 3, 4]) + segment_ids = np.array([0, 0, 1, 1]) + with self.assertRaisesRegex( + ValueError, + "Argument `num_segments` cannot be set when sorted is True when " + "using the tensorflow backend.", + ): + segment_reduce_op(data, segment_ids, num_segments=2, sorted=True) + + @parameterized.parameters([(kmath.segment_sum,), (kmath.segment_max,)]) + def test_segment_reduce_multi_dim_segment_ids(self, segment_reduce_op): + data = np.array([1, 2, 3, 4]) + segment_ids = np.array([0, 0, 1, 1]).reshape((2, 2)) + with self.assertRaisesRegex( + ValueError, + "Argument `segment_ids` should be an 1-D vector,", + ): + segment_reduce_op(data, segment_ids) + + @parameterized.parameters([(kmath.segment_sum,), (kmath.segment_max,)]) + def test_segment_reduce_leading_not_match(self, segment_reduce_op): + data = np.array([]) + segment_ids = np.array([0, 0, 1, 1]) + with self.assertRaisesRegex( + ValueError, + "Argument `segment_ids` and `data` should have same leading " + "dimension.", + ): + segment_reduce_op(data, segment_ids) + + output_tensor = segment_reduce_op( + KerasTensor(shape=(None, 4)), KerasTensor(shape=(5,)) + ) + self.assertEqual(output_tensor.shape, (None, 4)) + + output_tensor = segment_reduce_op( + KerasTensor(shape=(5, 4)), KerasTensor(shape=(None,)) + ) + self.assertEqual(output_tensor.shape, (None, 4)) + + output_tensor = segment_reduce_op( + KerasTensor(shape=(None, 4)), KerasTensor(shape=(None,)) + ) + self.assertEqual(output_tensor.shape, (None, 4)) + + def test_stft_invalid_input_type(self): + # backend agnostic error message + x = np.array([1, 2, 3, 4]) + sequence_length = 2 + sequence_stride = 1 + fft_length = 4 + with self.assertRaisesRegex(TypeError, "`float32` or `float64`"): + kmath.stft(x, sequence_length, sequence_stride, fft_length) + + def test_invalid_fft_length(self): + # backend agnostic error message + x = np.array([1.0, 2.0, 3.0, 4.0]) + sequence_length = 4 + sequence_stride = 1 + fft_length = 2 + with self.assertRaisesRegex(ValueError, "`fft_length` must equal or"): + kmath.stft(x, sequence_length, sequence_stride, fft_length) + + def test_stft_invalid_window(self): + # backend agnostic error message + x = np.array([1.0, 2.0, 3.0, 4.0]) + sequence_length = 2 + sequence_stride = 1 + fft_length = 4 + window = "invalid_window" + with self.assertRaisesRegex(ValueError, "If a string is passed to"): + kmath.stft( + x, sequence_length, sequence_stride, fft_length, window=window + ) + + def test_stft_invalid_window_shape(self): + # backend agnostic error message + x = np.array([1.0, 2.0, 3.0, 4.0]) + sequence_length = 2 + sequence_stride = 1 + fft_length = 4 + window = np.ones((sequence_length + 1)) + with self.assertRaisesRegex(ValueError, "The shape of `window` must"): + kmath.stft( + x, sequence_length, sequence_stride, fft_length, window=window + ) + + def test_istft_invalid_window_shape_2D_inputs(self): + # backend agnostic error message + x = (np.array([[1.0, 2.0]]), np.array([[3.0, 4.0]])) + sequence_length = 2 + sequence_stride = 1 + fft_length = 4 + incorrect_window = np.ones((sequence_length + 1,)) + with self.assertRaisesRegex( + ValueError, "The shape of `window` must be equal to" + ): + kmath.istft( + x, + sequence_length, + sequence_stride, + fft_length, + window=incorrect_window, + ) + + +@pytest.mark.skipif( + backend.backend() == "openvino", + reason="Complex dtype is not supported on OpenVINO backend.", +) +class ViewAsComplexRealTest(testing.TestCase): + def test_view_as_complex_basic(self): + real_imag = np.array([[1.0, 2.0], [3.0, 4.0]]) + expected = np.array([1.0 + 2.0j, 3.0 + 4.0j], dtype=np.complex64) + + result = kmath.view_as_complex(real_imag) + + self.assertEqual(result.shape, expected.shape) + self.assertEqual(standardize_dtype(result.dtype), expected.dtype) + self.assertAllClose(result, expected) + + def test_view_as_real_basic(self): + complex_tensor = np.array([1 + 2j, 3 + 4j], dtype=np.complex64) + expected = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32) + + result = kmath.view_as_real(complex_tensor) + + self.assertEqual(result.shape, expected.shape) + self.assertEqual(standardize_dtype(result.dtype), expected.dtype) + self.assertAllClose(result, expected) + + def test_view_as_complex_invalid_shape(self): + bad_input = np.array([1.0, 2.0, 3.0]) # Last dimension not size 2 + with self.assertRaisesRegex( + ValueError, "Last dimension of input must be size 2" + ): + kmath.view_as_complex(bad_input) + + def test_view_as_complex_symbolic_input(self): + x = KerasTensor(shape=(None, 2), dtype="float32") + result = kmath.view_as_complex(x) + + self.assertEqual(result.shape, (None,)) + self.assertEqual(standardize_dtype(result.dtype), "complex64") + + def test_view_as_real_symbolic_input(self): + x = KerasTensor(shape=(None,), dtype="complex64") + result = kmath.view_as_real(x) + + self.assertEqual(result.shape, (None, 2)) + self.assertEqual(standardize_dtype(result.dtype), "float32") + + def test_view_as_complex_multi_dimensional(self): + x = np.array([[[1.0, 2.0], [3.0, 4.0]]], dtype=np.float32) + expected = np.array([[1 + 2j, 3 + 4j]], dtype=np.complex64) + + result = kmath.view_as_complex(x) + + self.assertEqual(result.shape, expected.shape) + self.assertEqual(standardize_dtype(result.dtype), expected.dtype) + self.assertAllClose(result, expected) + + def test_view_as_real_multi_dimensional(self): + x = np.array([[1 + 2j, 3 + 4j]], dtype=np.complex64) + expected = np.array([[[1.0, 2.0], [3.0, 4.0]]], dtype=np.float32) + + result = kmath.view_as_real(x) + + self.assertEqual(result.shape, expected.shape) + self.assertEqual(standardize_dtype(result.dtype), expected.dtype) + self.assertAllClose(result, expected) diff --git a/keras/src/ops/nn.py b/keras/src/ops/nn.py new file mode 100644 index 000000000000..23792400ae4e --- /dev/null +++ b/keras/src/ops/nn.py @@ -0,0 +1,3147 @@ +"""Commonly-used neural network operations not included in NumPy.""" + +import warnings + +from keras.src import backend +from keras.src.api_export import keras_export +from keras.src.backend import KerasTensor +from keras.src.backend import any_symbolic_tensors +from keras.src.backend import standardize_data_format +from keras.src.backend.common.backend_utils import ( + compute_conv_transpose_output_shape, +) +from keras.src.ops import operation_utils +from keras.src.ops.operation import Operation +from keras.src.ops.operation_utils import reduce_shape +from keras.src.utils.python_utils import is_continuous_axis + + +class Relu(Operation): + def call(self, x): + return backend.nn.relu(x) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype=x.dtype) + + +@keras_export(["keras.ops.relu", "keras.ops.nn.relu"]) +def relu(x): + """Rectified linear unit activation function. + + It is defined as `f(x) = max(0, x)`. + + Args: + x: Input tensor. + + Returns: + A tensor with the same shape as `x`. + + Example: + + >>> x1 = keras.ops.convert_to_tensor([-1.0, 0.0, 1.0, 0.2]) + >>> keras.ops.relu(x1) + array([0.0, 0.0, 1.0, 0.2], dtype=float32) + """ + if any_symbolic_tensors((x,)): + return Relu().symbolic_call(x) + return backend.nn.relu(x) + + +class Relu6(Operation): + def call(self, x): + return backend.nn.relu6(x) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype=x.dtype) + + +@keras_export(["keras.ops.relu6", "keras.ops.nn.relu6"]) +def relu6(x): + """Rectified linear unit activation function with upper bound of 6. + + It is defined as `f(x) = np.clip(x, 0, 6)`. + + Args: + x: Input tensor. + + Returns: + A tensor with the same shape as `x`. + + Example: + + >>> x = keras.ops.convert_to_tensor([-3.0, -2.0, 0.1, 0.2, 6.0, 8.0]) + >>> keras.ops.relu6(x) + array([0.0, 0.0, 0.1, 0.2, 6.0, 6.0], dtype=float32) + """ + if any_symbolic_tensors((x,)): + return Relu6().symbolic_call(x) + return backend.nn.relu6(x) + + +class Sigmoid(Operation): + def call(self, x): + return backend.nn.sigmoid(x) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype=x.dtype) + + +@keras_export(["keras.ops.sigmoid", "keras.ops.nn.sigmoid"]) +def sigmoid(x): + """Sigmoid activation function. + + It is defined as `f(x) = 1 / (1 + exp(-x))`. + + Args: + x: Input tensor. + + Returns: + A tensor with the same shape as `x`. + + Example: + + >>> x = keras.ops.convert_to_tensor([-6.0, 1.0, 0.0, 1.0, 6.0]) + >>> keras.ops.sigmoid(x) + array([0.00247262, 0.7310586, 0.5, 0.7310586, 0.9975274], dtype=float32) + + """ + if any_symbolic_tensors((x,)): + return Sigmoid().symbolic_call(x) + return backend.nn.sigmoid(x) + + +class SparseSigmoid(Operation): + def call(self, x): + return backend.nn.sparse_sigmoid(x) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype=x.dtype) + + +@keras_export(["keras.ops.sparse_sigmoid", "keras.ops.nn.sparse_sigmoid"]) +def sparse_sigmoid(x): + """Sparse sigmoid activation function. + + It is defined as + + `f(x) = 0` for `x <= -1`, + `f(x) = 0.5 * (x + 1)` for `-1 < x < 1`, + `f(x) = 1` for `x >= 1`. + + Args: + x: Input tensor. + + Returns: + A tensor with the same shape as `x`. + + Example: + + >>> x = keras.ops.convert_to_tensor([-6.0, 1.0, 0.0, 1.0, 6.0]) + >>> keras.ops.sparse_sigmoid(x) + array([0. , 1. , 0.5, 1. , 1. ], dtype=float32) + + """ + if any_symbolic_tensors((x,)): + return SparseSigmoid().symbolic_call(x) + return backend.nn.sparse_sigmoid(x) + + +class Softplus(Operation): + def call(self, x): + return backend.nn.softplus(x) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype=x.dtype) + + +@keras_export(["keras.ops.softplus", "keras.ops.nn.softplus"]) +def softplus(x): + """Softplus activation function. + + It is defined as `f(x) = log(exp(x) + 1)`, where `log` is the natural + logarithm and `exp` is the exponential function. + + Args: + x: Input tensor. + + Returns: + A tensor with the same shape as `x`. + + Example: + + >>> x = keras.ops.convert_to_tensor([-0.555, 0.0, 0.555]) + >>> keras.ops.softplus(x) + array([0.45366603, 0.6931472, 1.008666], dtype=float32) + + """ + if any_symbolic_tensors((x,)): + return Softplus().symbolic_call(x) + return backend.nn.softplus(x) + + +class Softsign(Operation): + def call(self, x): + return backend.nn.softsign(x) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype=x.dtype) + + +@keras_export(["keras.ops.softsign", "keras.ops.nn.softsign"]) +def softsign(x): + """Softsign activation function. + + It is defined as `f(x) = x / (abs(x) + 1)`. + + Args: + x: Input tensor. + + Returns: + A tensor with the same shape as `x`. + + Example: + + >>> x = keras.ops.convert_to_tensor([-0.100, -10.0, 1.0, 0.0, 100.0]) + >>> keras.ops.softsign(x) + Array([-0.09090909, -0.90909094, 0.5, 0.0, 0.990099], dtype=float32) + + """ + if any_symbolic_tensors((x,)): + return Softsign().symbolic_call(x) + return backend.nn.softsign(x) + + +class SoftShrink(Operation): + def __init__(self, threshold=0.5, *, name=None): + super().__init__(name=name) + self.threshold = threshold + + def call(self, x): + return backend.nn.soft_shrink(x, self.threshold) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype=x.dtype) + + +@keras_export(["keras.ops.soft_shrink", "keras.ops.nn.soft_shrink"]) +def soft_shrink(x, threshold=0.5): + """Soft Shrink activation function. + + It is defined as + + `f(x) = x - threshold` if `x > threshold`, + `f(x) = x + threshold` if `x < -threshold`, + `f(x) = 0` otherwise. + + Args: + x: Input tensor. + threshold: Threshold value. Defaults to 0.5. + + Returns: + A tensor with the same shape as `x`. + + Example: + + >>> x = np.array([-1.0, 0.0, 1.0]) + >>> x_soft_shrink = keras.ops.soft_shrink(x) + >>> print(x_soft_shrink) + array([-0.5 0. 0.5], shape=(3,), dtype=float64) + + """ + if any_symbolic_tensors((x,)): + return SoftShrink(threshold).symbolic_call(x) + return backend.nn.soft_shrink(x, threshold) + + +class SparsePlus(Operation): + def call(self, x): + return backend.nn.sparse_plus(x) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype=x.dtype) + + +@keras_export(["keras.ops.sparse_plus", "keras.ops.nn.sparse_plus"]) +def sparse_plus(x): + """SparsePlus activation function. + + It is defined as + + `f(x) = 0` for `x <= -1`. + `f(x) = (1/4) * (x + 1)^2` for `-1 < x < 1`. + `f(x) = x` for `x >= 1`. + + + Args: + x: Input tensor. + + Returns: + A tensor with the same shape as `x`. + + Example: + + >>> x = np.array([-1.0, 0.0, 1.0]) + >>> x_sparse_plus = keras.ops.sparse_plus(x) + >>> print(x_sparse_plus) + Array([0. 0.25 1. ], shape=(3,), dtype=float32) + + """ + if any_symbolic_tensors((x,)): + return SparsePlus().symbolic_call(x) + return backend.nn.sparse_plus(x) + + +class Silu(Operation): + def call(self, x): + return backend.nn.silu(x) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype=x.dtype) + + +@keras_export( + [ + "keras.ops.silu", + "keras.ops.nn.silu", + "keras.ops.swish", + "keras.ops.nn.swish", + ] +) +def silu(x): + """Sigmoid Linear Unit (SiLU) activation function, also known as Swish. + + The SiLU activation function is computed by the sigmoid function multiplied + by its input. It is defined as `f(x) = x * sigmoid(x)`. + + Args: + x: Input tensor. + + Returns: + A tensor with the same shape as `x`. + + Example: + + >>> x = keras.ops.convert_to_tensor([-6.0, 1.0, 0.0, 1.0, 6.0]) + >>> keras.ops.sigmoid(x) + array([0.00247262, 0.7310586, 0.5, 0.7310586, 0.9975274], dtype=float32) + >>> keras.ops.silu(x) + array([-0.0148357, 0.7310586, 0.0, 0.7310586, 5.9851646], dtype=float32) + + """ + if any_symbolic_tensors((x,)): + return Silu().symbolic_call(x) + return backend.nn.silu(x) + + +class Squareplus(Operation): + def __init__(self, b=4, *, name=None): + super().__init__(name=name) + self.b = b + + def call(self, x): + return backend.nn.squareplus(x, self.b) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype=x.dtype) + + +@keras_export(["keras.ops.squareplus", "keras.ops.nn.squareplus"]) +def squareplus(x, b=4): + """Squareplus activation function. + + The Squareplus activation function is defined as: + + `f(x) = (x + sqrt(x^2 + b)) / 2` + + Args: + x: Input tensor. + b: Smoothness parameter. Defaults to 4. + + Returns: + A tensor with the same shape as `x`. + + Example: + + >>> x = np.array([-1.0, 0.0, 1.0]) + >>> x_squareplus = keras.ops.squareplus(x) + >>> print(x_squareplus) + array([0.6180, 1.0000, 1.6180], dtype=float32) + + """ + if any_symbolic_tensors((x,)): + return Squareplus(b).symbolic_call(x) + return backend.nn.squareplus(x, b) + + +class LogSigmoid(Operation): + def call(self, x): + return backend.nn.log_sigmoid(x) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype=x.dtype) + + +@keras_export( + [ + "keras.ops.log_sigmoid", + "keras.ops.nn.log_sigmoid", + ] +) +def log_sigmoid(x): + """Logarithm of the sigmoid activation function. + + It is defined as `f(x) = log(1 / (1 + exp(-x)))`. + + Args: + x: Input tensor. + + Returns: + A tensor with the same shape as `x`. + + Example: + + >>> x = keras.ops.convert_to_tensor([-0.541391, 0.0, 0.50, 5.0]) + >>> keras.ops.log_sigmoid(x) + array([-1.0000418, -0.6931472, -0.474077, -0.00671535], dtype=float32) + + """ + if any_symbolic_tensors((x,)): + return LogSigmoid().symbolic_call(x) + return backend.nn.log_sigmoid(x) + + +class LeakyRelu(Operation): + def __init__(self, negative_slope=0.2, *, name=None): + super().__init__(name=name) + self.negative_slope = negative_slope + + def call(self, x): + return backend.nn.leaky_relu(x, self.negative_slope) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype=x.dtype) + + +@keras_export(["keras.ops.leaky_relu", "keras.ops.nn.leaky_relu"]) +def leaky_relu(x, negative_slope=0.2): + """Leaky version of a Rectified Linear Unit activation function. + + It allows a small gradient when the unit is not active, it is defined as: + + `f(x) = alpha * x for x < 0` or `f(x) = x for x >= 0`. + + Args: + x: Input tensor. + negative_slope: Slope of the activation function at x < 0. + Defaults to `0.2`. + + Returns: + A tensor with the same shape as `x`. + + Example: + + >>> x = np.array([-1., 0., 1.]) + >>> x_leaky_relu = keras.ops.leaky_relu(x) + >>> print(x_leaky_relu) + array([-0.2, 0. , 1. ], shape=(3,), dtype=float64) + + """ + if any_symbolic_tensors((x,)): + return LeakyRelu(negative_slope).symbolic_call(x) + return backend.nn.leaky_relu(x, negative_slope=negative_slope) + + +class HardSigmoid(Operation): + def call(self, x): + return backend.nn.hard_sigmoid(x) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype=x.dtype) + + +@keras_export( + [ + "keras.ops.hard_sigmoid", + "keras.ops.nn.hard_sigmoid", + ] +) +def hard_sigmoid(x): + """Hard sigmoid activation function. + + It is defined as: + + `0 if x < -2.5`, `1 if x > 2.5`, `(0.2 * x) + 0.5 if -2.5 <= x <= 2.5`. + + Args: + x: Input tensor. + + Returns: + A tensor with the same shape as `x`. + + Example: + + >>> x = np.array([-1., 0., 1.]) + >>> x_hard_sigmoid = keras.ops.hard_sigmoid(x) + >>> print(x_hard_sigmoid) + array([0.3, 0.5, 0.7], shape=(3,), dtype=float64) + + """ + if any_symbolic_tensors((x,)): + return HardSigmoid().symbolic_call(x) + return backend.nn.hard_sigmoid(x) + + +class HardSilu(Operation): + def call(self, x): + return backend.nn.hard_silu(x) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype=x.dtype) + + +@keras_export( + [ + "keras.ops.hard_silu", + "keras.ops.nn.hard_silu", + "keras.ops.hard_swish", + "keras.ops.nn.hard_swish", + ] +) +def hard_silu(x): + """Hard SiLU activation function, also known as Hard Swish. + + It is defined as: + + - `0` if `if x < -3` + - `x` if `x > 3` + - `x * (x + 3) / 6` if `-3 <= x <= 3` + + It's a faster, piecewise linear approximation of the silu activation. + + Args: + x: Input tensor. + + Returns: + A tensor with the same shape as `x`. + + Example: + + >>> x = keras.ops.convert_to_tensor([-3.0, -1.0, 0.0, 1.0, 3.0]) + >>> keras.ops.hard_silu(x) + array([-0.0, -0.3333333, 0.0, 0.6666667, 3.0], shape=(5,), dtype=float32) + + """ + if any_symbolic_tensors((x,)): + return HardSilu().symbolic_call(x) + return backend.nn.hard_silu(x) + + +class Elu(Operation): + def __init__(self, alpha=1.0, *, name=None): + super().__init__(name=name) + self.alpha = alpha + + def call(self, x): + return backend.nn.elu(x, alpha=self.alpha) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype=x.dtype) + + +@keras_export(["keras.ops.elu", "keras.ops.nn.elu"]) +def elu(x, alpha=1.0): + """Exponential Linear Unit activation function. + + It is defined as: + + `f(x) = alpha * (exp(x) - 1.) for x < 0`, `f(x) = x for x >= 0`. + + Args: + x: Input tensor. + alpha: A scalar, slope of positive section. Defaults to `1.0`. + + Returns: + A tensor with the same shape as `x`. + + Example: + + >>> x = np.array([-1., 0., 1.]) + >>> x_elu = keras.ops.elu(x) + >>> print(x_elu) + array([-0.63212055, 0., 1.], shape=(3,), dtype=float64) + + """ + if any_symbolic_tensors((x,)): + return Elu(alpha).symbolic_call(x) + return backend.nn.elu(x, alpha=alpha) + + +class Selu(Operation): + def call(self, x): + return backend.nn.selu(x) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype=x.dtype) + + +@keras_export(["keras.ops.selu", "keras.ops.nn.selu"]) +def selu(x): + """Scaled Exponential Linear Unit (SELU) activation function. + + It is defined as: + + `f(x) = scale * alpha * (exp(x) - 1.) for x < 0`, + `f(x) = scale * x for x >= 0`. + + Args: + x: Input tensor. + + Returns: + A tensor with the same shape as `x`. + + Example: + + >>> x = np.array([-1., 0., 1.]) + >>> x_selu = keras.ops.selu(x) + >>> print(x_selu) + array([-1.11133055, 0., 1.05070098], shape=(3,), dtype=float64) + + """ + if any_symbolic_tensors((x,)): + return Selu().symbolic_call(x) + return backend.nn.selu(x) + + +class Gelu(Operation): + def __init__(self, approximate=True, *, name=None): + super().__init__(name=name) + self.approximate = approximate + + def call(self, x): + return backend.nn.gelu(x, self.approximate) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype=x.dtype) + + +@keras_export(["keras.ops.gelu", "keras.ops.nn.gelu"]) +def gelu(x, approximate=True): + """Gaussian Error Linear Unit (GELU) activation function. + + If `approximate` is `True`, it is defined as: + `f(x) = 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))` + + Or if `approximate` is `False`, it is defined as: + `f(x) = x * P(X <= x) = 0.5 * x * (1 + erf(x / sqrt(2)))`, + where `P(X) ~ N(0, 1)`. + + Args: + x: Input tensor. + approximate: Approximate version of GELU activation. Defaults to `True`. + + Returns: + A tensor with the same shape as `x`. + + Example: + + >>> x = np.array([-1., 0., 1.]) + >>> x_gelu = keras.ops.gelu(x) + >>> print(x_gelu) + array([-0.15865525, 0., 0.84134475], shape=(3,), dtype=float64) + + """ + if any_symbolic_tensors((x,)): + return Gelu(approximate).symbolic_call(x) + return backend.nn.gelu(x, approximate) + + +class Celu(Operation): + def __init__(self, alpha=1.0, *, name=None): + super().__init__(name=name) + self.alpha = alpha + + def call(self, x): + return backend.nn.celu(x, self.alpha) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype=x.dtype) + + +@keras_export(["keras.ops.celu", "keras.ops.nn.celu"]) +def celu(x, alpha=1.0): + """Continuously-differentiable exponential linear unit. + + It is defined as: + + `f(x) = alpha * (exp(x / alpha) - 1) for x < 0`, `f(x) = x for x >= 0`. + + Args: + x: Input tensor. + alpha: the α value for the CELU formulation. Defaults to `1.0`. + + Returns: + A tensor with the same shape as `x`. + + Example: + + >>> x = np.array([-1., 0., 1.]) + >>> x_celu = keras.ops.celu(x) + >>> print(x_celu) + array([-0.63212056, 0. , 1. ], shape=(3,), dtype=float64) + + """ + if any_symbolic_tensors((x,)): + return Celu(alpha).symbolic_call(x) + return backend.nn.celu(x, alpha) + + +class Glu(Operation): + def __init__(self, axis=-1, *, name=None): + super().__init__(name=name) + self.axis = axis + + def call(self, x): + return backend.nn.glu(x, axis=self.axis) + + def compute_output_spec(self, x): + output_shape = list(x.shape) + if output_shape[self.axis] is not None: + if output_shape[self.axis] % 2 != 0: + raise ValueError( + "axis size must be divisible by 2. " + f"Received: x.shape={x.shape} with axis={self.axis}" + ) + output_shape[self.axis] = output_shape[self.axis] // 2 + return KerasTensor(output_shape, dtype=x.dtype) + + +@keras_export(["keras.ops.glu", "keras.ops.nn.glu"]) +def glu(x, axis=-1): + """Gated Linear Unit (GLU) activation function. + + It is defined as: + + `f(x) = a * sigmoid(b)` + where `x` is split into `a` and `b` along the given axis. + + Args: + x: Input tensor. + axis: The axis along which to split the input tensor. Defaults to `-1`. + + Returns: + A tensor with the same shape as half of the input. + + Example: + + >>> x = np.array([-1., 0., 1. , 1.]) + >>> x_glu = keras.ops.glu(x) + >>> print(x_glu) + array([-0.73105858, 0. ], shape=(2,), dtype=float64) + + """ + if any_symbolic_tensors((x,)): + return Glu(axis).symbolic_call(x) + return backend.nn.glu(x, axis=axis) + + +class TanhShrink(Operation): + def call(self, x): + return backend.nn.tanh_shrink(x) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype=x.dtype) + + +@keras_export(["keras.ops.tanh_shrink", "keras.ops.nn.tanh_shrink"]) +def tanh_shrink(x): + """Applies the tanh shrink function element-wise. + + It is defined as: + + `f(x) = x - tanh(x)`. + + Args: + x: Input tensor. + + Returns: + Output tensor of the same shape as `x`, where each element is + transformed according to the tanh shrink operation. + + Example: + + >>> x = np.array([ -1., 0., 1.]) + >>> x_tanh_shrink = keras.ops.tanh_shrink(x) + >>> print(x_tanh_shrink) + array([-0.23840584 0. 0.23840584], shape=(3,), dtype=float64) + + """ + if any_symbolic_tensors((x,)): + return TanhShrink().symbolic_call(x) + return backend.nn.tanh_shrink(x) + + +class HardTanh(Operation): + def call(self, x): + return backend.nn.hard_tanh(x) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype=x.dtype) + + +@keras_export(["keras.ops.hard_tanh", "keras.ops.nn.hard_tanh"]) +def hard_tanh(x): + """Applies the HardTanh function element-wise. + + It is defined as: + + `f(x) = -1 for x < -1`, `f(x) = x for -1 <= x <= 1`, `f(x) = 1 for x > 1`. + + Args: + x: Input tensor. + + Returns: + Output tensor of same shape as `x` + where values are clamped between -1 and 1. + + Example: + + >>> x = np.array([-2., -1., 0., 1., 2.]) + >>> x_hard_tanh = keras.ops.hard_tanh(x) + >>> print(x_hard_tanh) + array([-1. -1. 0. 1. 1.], shape=(5,), dtype=float64) + + """ + if any_symbolic_tensors((x,)): + return HardTanh().symbolic_call(x) + return backend.nn.hard_tanh(x) + + +class HardShrink(Operation): + def __init__(self, threshold=0.5, *, name=None): + super().__init__(name=name) + self.threshold = threshold + + def call(self, x): + return backend.nn.hard_shrink(x, self.threshold) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype=x.dtype) + + +@keras_export(["keras.ops.hard_shrink", "keras.ops.nn.hard_shrink"]) +def hard_shrink(x, threshold=0.5): + """Hard Shrink activation function. + + The Hard Shrink function is a thresholding operation defined as: + + `f(x) = x` if `|x| > threshold`, + `f(x) = 0` otherwise. + + Args: + x: Input tensor. + threshold: Threshold value. Defaults to 0.5. + + Returns: + A tensor with the same shape as `x`. + + Example: + + >>> x = np.array([-0.5, 0., 1.]) + >>> x_hard_shrink = keras.ops.hard_shrink(x) + >>> print(x_hard_shrink) + array([0. 0. 1.], shape=(3,), dtype=float64) + + """ + if any_symbolic_tensors((x,)): + return HardShrink(threshold).symbolic_call(x) + return backend.nn.hard_shrink(x, threshold) + + +class Threshold(Operation): + def __init__(self, threshold, default_value, *, name=None): + super().__init__(name=name) + self.threshold = threshold + self.default_value = default_value + + def call(self, x): + return backend.nn.threshold(x, self.threshold, self.default_value) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype=x.dtype) + + +@keras_export(["keras.ops.threshold", "keras.ops.nn.threshold"]) +def threshold(x, threshold, default_value): + """Threshold activation function. + + The function thresholds the input `x` as follows: + `f(x) = x` if `x > threshold`, + `f(x) = default_value` otherwise. + + Args: + x: Input tensor. + threshold: The value that decides when to retain or replace x. + default_value: Value to assign when `x <= threshold`. + + Returns: + A tensor with the same shape as `x`. + + Example: + + >>> x = np.array([-1.0, 0.0, 1.0, 2.0]) + >>> x_threshold = keras.ops.threshold(x, 1, 0) + >>> print(x_threshold) + array([0., 0., 0., 2.], shape=(4,), dtype=float64) + + """ + if any_symbolic_tensors((x,)): + return Threshold(threshold, default_value).symbolic_call(x) + return backend.nn.threshold(x, threshold, default_value) + + +class Softmax(Operation): + def __init__(self, axis=-1, *, name=None): + super().__init__(name=name) + self.axis = axis + + def call(self, x): + return backend.nn.softmax(x, axis=self.axis) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype=x.dtype) + + +@keras_export(["keras.ops.softmax", "keras.ops.nn.softmax"]) +def softmax(x, axis=-1): + """Softmax activation function. + + The elements of the output vector lie within the range `(0, 1)`, and their + total sum is exactly 1 (excluding the floating point rounding error). + + Each vector is processed independently. The `axis` argument specifies the + axis along which the function is applied within the input. + + It is defined as: + `f(x) = exp(x) / sum(exp(x))` + + Args: + x: Input tensor. + axis: Integer, axis along which the softmax is applied. + + Returns: + A tensor with the same shape as `x`. + + Example: + + >>> x = np.array([-1., 0., 1.]) + >>> x_softmax = keras.ops.softmax(x) + >>> print(x_softmax) + array([0.09003057, 0.24472847, 0.66524096], shape=(3,), dtype=float64) + + """ + # Don't use `backend.shape` since TensorFlow returns + # symbolic tensors for unknown shape which can trigger + # an error in TensorFlow graph execution. + if isinstance(axis, int) and x.shape[axis] == 1: + warnings.warn( + f"You are using a softmax over axis {axis} " + f"of a tensor of shape {x.shape}. This axis " + "has size 1. The softmax operation will always return " + "the value 1, which is likely not what you intended. " + "Did you mean to use a sigmoid instead?" + ) + if any_symbolic_tensors((x,)): + return Softmax(axis).symbolic_call(x) + if isinstance(axis, tuple): + axis_to_keep = [v for v in range(len(x.shape)) if v not in axis] + + x_transposed = backend.numpy.transpose(x, axes=(*axis_to_keep, *axis)) + x_reshaped = backend.numpy.reshape( + x_transposed, (*[x.shape[v] for v in axis_to_keep], -1) + ) + + x = backend.nn.softmax(x_reshaped, axis=-1) + + x = backend.numpy.reshape(x, x_transposed.shape) + x = backend.numpy.transpose( + x, axes=list(backend.numpy.argsort([*axis_to_keep, *axis])) + ) + return x + else: + return backend.nn.softmax(x, axis=axis) + + +class LogSoftmax(Operation): + def __init__(self, axis=-1, *, name=None): + super().__init__(name=name) + self.axis = axis + + def call(self, x): + return backend.nn.log_softmax(x, axis=self.axis) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype=x.dtype) + + +@keras_export( + [ + "keras.ops.log_softmax", + "keras.ops.nn.log_softmax", + ] +) +def log_softmax(x, axis=-1): + """Log-softmax activation function. + + It is defined as: + `f(x) = x - max(x) - log(sum(exp(x - max(x))))` + + Args: + x: Input tensor. + axis: Integer, axis along which the log-softmax is applied. + Defaults to `-1`. + + Returns: + A tensor with the same shape as `x`. + + Example: + + >>> x = np.array([-1., 0., 1.]) + >>> x_log_softmax = keras.ops.log_softmax(x) + >>> print(x_log_softmax) + array([-2.40760596, -1.40760596, -0.40760596], shape=(3,), dtype=float64) + + """ + if any_symbolic_tensors((x,)): + return LogSoftmax(axis).symbolic_call(x) + if isinstance(axis, tuple): + axis_to_keep = [v for v in range(len(x.shape)) if v not in axis] + + x_transposed = backend.numpy.transpose(x, axes=(*axis_to_keep, *axis)) + x_reshaped = backend.numpy.reshape( + x_transposed, (*[x.shape[v] for v in axis_to_keep], -1) + ) + + x = backend.nn.log_softmax(x_reshaped, axis=-1) + + x = backend.numpy.reshape(x, x_transposed.shape) + x = backend.numpy.transpose( + x, axes=list(backend.numpy.argsort([*axis_to_keep, *axis])) + ) + return x + else: + return backend.nn.log_softmax(x, axis=axis) + + +class Sparsemax(Operation): + def __init__(self, axis=-1, *, name=None): + super().__init__(name=name) + self.axis = axis + + def call(self, x): + return backend.nn.sparsemax(x, axis=self.axis) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype=x.dtype) + + +@keras_export(["keras.ops.sparsemax", "keras.ops.nn.sparsemax"]) +def sparsemax(x, axis=-1): + """Sparsemax activation function. + + For each batch `i`, and class `j`, + sparsemax activation function is defined as: + + `sparsemax(x)[i, j] = max(x[i, j] - τ(x[i, :]), 0).` + + Args: + x: Input tensor. + axis: `int`, axis along which the sparsemax operation is applied. + + Returns: + A tensor, output of sparsemax transformation. Has the same type and + shape as `x`. + + Example: + + >>> x = np.array([-1., 0., 1.]) + >>> x_sparsemax = keras.ops.sparsemax(x) + >>> print(x_sparsemax) + array([0., 0., 1.], shape=(3,), dtype=float64) + + """ + if any_symbolic_tensors((x,)): + return Sparsemax(axis).symbolic_call(x) + return backend.nn.sparsemax(x, axis=axis) + + +class MaxPool(Operation): + def __init__( + self, + pool_size, + strides=None, + padding="valid", + data_format=None, + *, + name=None, + ): + super().__init__(name=name) + self.pool_size = pool_size + self.strides = strides + self.padding = padding.lower() + self.data_format = data_format + + def call(self, inputs): + return backend.nn.max_pool( + inputs, + self.pool_size, + self.strides, + self.padding, + self.data_format, + ) + + def compute_output_spec(self, inputs): + output_shape = operation_utils.compute_pooling_output_shape( + inputs.shape, + self.pool_size, + self.strides, + self.padding, + self.data_format, + ) + return KerasTensor(output_shape, dtype=inputs.dtype) + + +@keras_export(["keras.ops.max_pool", "keras.ops.nn.max_pool"]) +def max_pool( + inputs, + pool_size, + strides=None, + padding="valid", + data_format=None, +): + """Max pooling operation. + + Args: + inputs: Tensor of rank N+2. `inputs` has shape + `(batch_size,) + inputs_spatial_shape + (num_channels,)` if + `data_format="channels_last"`, or + `(batch_size, num_channels) + inputs_spatial_shape` if + `data_format="channels_first"`. Pooling happens over the spatial + dimensions only. + pool_size: int or tuple/list of integers of size + `len(inputs_spatial_shape)`, specifying the size of the pooling + window for each spatial dimension of the input tensor. If + `pool_size` is int, then every spatial dimension shares the same + `pool_size`. + strides: int or tuple/list of integers of size + `len(inputs_spatial_shape)`. The stride of the sliding window for + each spatial dimension of the input tensor. If `strides` is int, + then every spatial dimension shares the same `strides`. + padding: string, either `"valid"` or `"same"`. `"valid"` means no + padding is applied, and `"same"` results in padding evenly to the + left/right or up/down of the input such that output has the + same height/width dimension as the input when `strides=1`. + data_format: A string, either `"channels_last"` or `"channels_first"`. + `data_format` determines the ordering of the dimensions in the + inputs. If `data_format="channels_last"`, `inputs` is of shape + `(batch_size, ..., channels)` while if + `data_format="channels_first"`, `inputs` is of shape + `(batch_size, channels, ...)`. + + Returns: + A tensor of rank N+2, the result of the max pooling operation. + """ + data_format = standardize_data_format(data_format) + padding = padding.lower() + if any_symbolic_tensors((inputs,)): + return MaxPool( + pool_size, + strides, + padding, + data_format, + ).symbolic_call(inputs) + return backend.nn.max_pool(inputs, pool_size, strides, padding, data_format) + + +class AveragePool(Operation): + def __init__( + self, + pool_size, + strides=None, + padding="valid", + data_format=None, + *, + name=None, + ): + super().__init__(name=name) + self.pool_size = pool_size + self.strides = strides + self.padding = padding.lower() + self.data_format = data_format + + def call(self, inputs): + return backend.nn.average_pool( + inputs, + self.pool_size, + self.strides, + self.padding, + self.data_format, + ) + + def compute_output_spec(self, inputs): + output_shape = operation_utils.compute_pooling_output_shape( + inputs.shape, + self.pool_size, + self.strides, + self.padding, + self.data_format, + ) + return KerasTensor(output_shape, dtype=inputs.dtype) + + +@keras_export( + [ + "keras.ops.average_pool", + "keras.ops.nn.average_pool", + ] +) +def average_pool( + inputs, + pool_size, + strides=None, + padding="valid", + data_format=None, +): + """Average pooling operation. + + Args: + inputs: Tensor of rank N+2. `inputs` has shape + `(batch_size,) + inputs_spatial_shape + (num_channels,)` if + `data_format="channels_last"`, or + `(batch_size, num_channels) + inputs_spatial_shape` if + `data_format="channels_first"`. Pooling happens over the spatial + dimensions only. + pool_size: int or tuple/list of integers of size + `len(inputs_spatial_shape)`, specifying the size of the pooling + window for each spatial dimension of the input tensor. If + `pool_size` is int, then every spatial dimension shares the same + `pool_size`. + strides: int or tuple/list of integers of size + `len(inputs_spatial_shape)`. The stride of the sliding window for + each spatial dimension of the input tensor. If `strides` is int, + then every spatial dimension shares the same `strides`. + padding: string, either `"valid"` or `"same"`. `"valid"` means no + padding is applied, and `"same"` results in padding evenly to the + left/right or up/down of the input such that output has the + same height/width dimension as the input when `strides=1`. + data_format: A string, either `"channels_last"` or `"channels_first"`. + `data_format` determines the ordering of the dimensions in the + inputs. If `data_format="channels_last"`, `inputs` is of shape + `(batch_size, ..., channels)` while if + `data_format="channels_first"`, `inputs` is of shape + `(batch_size, channels, ...)`. + + Returns: + A tensor of rank N+2, the result of the average pooling operation. + """ + data_format = standardize_data_format(data_format) + padding = padding.lower() + if any_symbolic_tensors((inputs,)): + return AveragePool( + pool_size, + strides, + padding, + data_format, + ).symbolic_call(inputs) + return backend.nn.average_pool( + inputs, pool_size, strides, padding, data_format + ) + + +class Conv(Operation): + def __init__( + self, + strides=1, + padding="valid", + data_format=None, + dilation_rate=1, + *, + name=None, + ): + super().__init__(name=name) + self.strides = strides + self.padding = padding.lower() + self.data_format = data_format + self.dilation_rate = dilation_rate + + def call(self, inputs, kernel): + return backend.nn.conv( + inputs, + kernel, + strides=self.strides, + padding=self.padding, + data_format=self.data_format, + dilation_rate=self.dilation_rate, + ) + + def compute_output_spec(self, inputs, kernel): + output_shape = operation_utils.compute_conv_output_shape( + inputs.shape, + kernel.shape[-1], + kernel.shape[:-2], + self.strides, + self.padding, + self.data_format, + self.dilation_rate, + ) + return KerasTensor(output_shape, dtype=inputs.dtype) + + +@keras_export(["keras.ops.conv", "keras.ops.nn.conv"]) +def conv( + inputs, + kernel, + strides=1, + padding="valid", + data_format=None, + dilation_rate=1, +): + """General N-D convolution. + + This ops supports 1D, 2D and 3D convolution. + + Args: + inputs: Tensor of rank N+2. `inputs` has shape + `(batch_size,) + inputs_spatial_shape + (num_channels,)` if + `data_format="channels_last"`, or + `(batch_size, num_channels) + inputs_spatial_shape` if + `data_format="channels_first"`. + kernel: Tensor of rank N+2. `kernel` has shape + `(kernel_spatial_shape, num_input_channels, num_output_channels)`. + `num_input_channels` should match the number of channels in + `inputs`. + strides: int or int tuple/list of `len(inputs_spatial_shape)`, + specifying the strides of the convolution along each spatial + dimension. If `strides` is int, then every spatial dimension shares + the same `strides`. + padding: string, either `"valid"` or `"same"`. `"valid"` means no + padding is applied, and `"same"` results in padding evenly to the + left/right or up/down of the input such that output has the + same height/width dimension as the input when `strides=1`. + data_format: A string, either `"channels_last"` or `"channels_first"`. + `data_format` determines the ordering of the dimensions in the + inputs. If `data_format="channels_last"`, `inputs` is of shape + `(batch_size, ..., channels)` while if + `data_format="channels_first"`, `inputs` is of shape + `(batch_size, channels, ...)`. + dilation_rate: int or int tuple/list of `len(inputs_spatial_shape)`, + specifying the dilation rate to use for dilated convolution. If + `dilation_rate` is int, then every spatial dimension shares + the same `dilation_rate`. + + Returns: + A tensor of rank N+2, the result of the conv operation. + """ + data_format = standardize_data_format(data_format) + padding = padding.lower() + if any_symbolic_tensors((inputs,)): + return Conv(strides, padding, data_format, dilation_rate).symbolic_call( + inputs, kernel + ) + return backend.nn.conv( + inputs, kernel, strides, padding, data_format, dilation_rate + ) + + +class DepthwiseConv(Operation): + def __init__( + self, + strides=1, + padding="valid", + data_format=None, + dilation_rate=1, + *, + name=None, + ): + super().__init__(name=name) + self.strides = strides + self.padding = padding.lower() + self.data_format = data_format + self.dilation_rate = dilation_rate + + def call(self, inputs, kernel): + return backend.nn.depthwise_conv( + inputs, + kernel, + self.strides, + self.padding, + self.data_format, + self.dilation_rate, + ) + + def compute_output_spec(self, inputs, kernel): + output_shape = operation_utils.compute_conv_output_shape( + inputs.shape, + kernel.shape[-1] * kernel.shape[-2], + kernel.shape[:-2], + self.strides, + self.padding, + self.data_format, + self.dilation_rate, + ) + return KerasTensor(output_shape, dtype=inputs.dtype) + + +@keras_export( + [ + "keras.ops.depthwise_conv", + "keras.ops.nn.depthwise_conv", + ] +) +def depthwise_conv( + inputs, + kernel, + strides=1, + padding="valid", + data_format=None, + dilation_rate=1, +): + """General N-D depthwise convolution. + + This ops supports 1D and 2D depthwise convolution. + + Args: + inputs: Tensor of rank N+2. `inputs` has shape + `(batch_size,) + inputs_spatial_shape + (num_channels,)` if + `data_format="channels_last"`, or + `(batch_size, num_channels) + inputs_spatial_shape` if + `data_format="channels_first"`. + kernel: Tensor of rank N+2. `kernel` has shape + [kernel_spatial_shape, num_input_channels, num_channels_multiplier], + `num_input_channels` should match the number of channels in + `inputs`. + strides: int or int tuple/list of `len(inputs_spatial_shape)`, + specifying the strides of the convolution along each spatial + dimension. If `strides` is int, then every spatial dimension shares + the same `strides`. + padding: string, either `"valid"` or `"same"`. `"valid"` means no + padding is applied, and `"same"` results in padding evenly to the + left/right or up/down of the input such that output has the + same height/width dimension as the input when `strides=1`. + data_format: A string, either `"channels_last"` or `"channels_first"`. + `data_format` determines the ordering of the dimensions in the + inputs. If `data_format="channels_last"`, `inputs` is of shape + `(batch_size, ..., channels)` while if + `data_format="channels_first"`, `inputs` is of shape + `(batch_size, channels, ...)`. + dilation_rate: int or int tuple/list of `len(inputs_spatial_shape)`, + specifying the dilation rate to use for dilated convolution. If + `dilation_rate` is int, then every spatial dimension shares + the same `dilation_rate`. + + Returns: + A tensor of rank N+2, the result of the depthwise conv operation. + """ + data_format = standardize_data_format(data_format) + padding = padding.lower() + if any_symbolic_tensors((inputs, kernel)): + return DepthwiseConv( + strides, padding, data_format, dilation_rate + ).symbolic_call(inputs, kernel) + return backend.nn.depthwise_conv( + inputs, + kernel, + strides, + padding, + data_format, + dilation_rate, + ) + + +class SeparableConv(Operation): + def __init__( + self, + strides=1, + padding="valid", + data_format=None, + dilation_rate=1, + *, + name=None, + ): + super().__init__(name=name) + self.strides = strides + self.padding = padding.lower() + self.data_format = data_format + self.dilation_rate = dilation_rate + + def call(self, inputs, depthwise_kernel, pointwise_kernel): + return backend.nn.separable_conv( + inputs, + depthwise_kernel, + pointwise_kernel, + self.strides, + self.padding, + self.data_format, + self.dilation_rate, + ) + + def compute_output_spec(self, inputs, depthwise_kernel, pointwise_kernel): + output_shape = list( + depthwise_conv( + inputs, + depthwise_kernel, + self.strides, + self.padding, + self.data_format, + self.dilation_rate, + ).shape + ) + if self.data_format == "channels_last": + output_shape[-1] = pointwise_kernel.shape[-1] + else: + output_shape[1] = pointwise_kernel.shape[-1] + return KerasTensor(output_shape, dtype=inputs.dtype) + + +@keras_export( + [ + "keras.ops.separable_conv", + "keras.ops.nn.separable_conv", + ] +) +def separable_conv( + inputs, + depthwise_kernel, + pointwise_kernel, + strides=1, + padding="valid", + data_format=None, + dilation_rate=1, +): + """General N-D separable convolution. + + This ops supports 1D and 2D separable convolution. `separable_conv` is + a depthwise conv followed by a pointwise conv. + + Args: + inputs: Tensor of rank N+2. `inputs` has shape + `(batch_size,) + inputs_spatial_shape + (num_channels,)` if + `data_format="channels_last"`, or + `(batch_size, num_channels) + inputs_spatial_shape` if + `data_format="channels_first"`. + depthwise_kernel: Tensor of rank N+2. `depthwise_kernel` has shape + [kernel_spatial_shape, num_input_channels, num_channels_multiplier], + `num_input_channels` should match the number of channels in + `inputs`. + pointwise_kernel: Tensor of rank N+2. `pointwise_kernel` has shape + `(*ones_like(kernel_spatial_shape), + num_input_channels * num_channels_multiplier, num_output_channels)`. + strides: int or int tuple/list of `len(inputs_spatial_shape)`, + specifying the strides of the convolution along each spatial + dimension. If `strides` is int, then every spatial dimension shares + the same `strides`. + padding: string, either `"valid"` or `"same"`. `"valid"` means no + padding is applied, and `"same"` results in padding evenly to the + left/right or up/down of the input such that output has the + same height/width dimension as the input when `strides=1`. + data_format: A string, either `"channels_last"` or `"channels_first"`. + `data_format` determines the ordering of the dimensions in the + inputs. If `data_format="channels_last"`, `inputs` is of shape + `(batch_size, ..., channels)` while if + `data_format="channels_first"`, `inputs` is of shape + `(batch_size, channels, ...)`. + dilation_rate: int or int tuple/list of `len(inputs_spatial_shape)`, + specifying the dilation rate to use for dilated convolution. If + `dilation_rate` is int, then every spatial dimension shares + the same `dilation_rate`. + + Returns: + A tensor of rank N+2, the result of the depthwise conv operation. + """ + data_format = standardize_data_format(data_format) + padding = padding.lower() + if any_symbolic_tensors((inputs,)): + return SeparableConv( + strides, + padding, + data_format, + dilation_rate, + ).symbolic_call(inputs, depthwise_kernel, pointwise_kernel) + return backend.nn.separable_conv( + inputs, + depthwise_kernel, + pointwise_kernel, + strides, + padding, + data_format, + dilation_rate, + ) + + +class ConvTranspose(Operation): + def __init__( + self, + strides=1, + padding="valid", + output_padding=None, + data_format=None, + dilation_rate=1, + *, + name=None, + ): + super().__init__(name=name) + self.strides = strides + self.output_padding = output_padding + self.padding = padding.lower() + self.data_format = data_format + self.dilation_rate = dilation_rate + + def call( + self, + inputs, + kernel, + ): + return backend.nn.conv_transpose( + inputs, + kernel, + self.strides, + self.output_padding, + self.padding, + self.data_format, + self.dilation_rate, + ) + + def compute_output_spec(self, inputs, kernel): + kernel_size = kernel.shape[:-2] + filters = kernel.shape[-2] + output_shape = compute_conv_transpose_output_shape( + inputs.shape, + kernel_size, + filters, + self.strides, + self.padding, + self.output_padding, + self.data_format, + self.dilation_rate, + ) + return KerasTensor(output_shape, dtype=inputs.dtype) + + +@keras_export( + [ + "keras.ops.conv_transpose", + "keras.ops.nn.conv_transpose", + ] +) +def conv_transpose( + inputs, + kernel, + strides=1, + padding="valid", + output_padding=None, + data_format=None, + dilation_rate=1, +): + """General N-D convolution transpose. + + Also known as de-convolution. This ops supports 1D, 2D and 3D convolution. + + Args: + inputs: Tensor of rank N+2. `inputs` has shape + `(batch_size,) + inputs_spatial_shape + (num_channels,)` if + `data_format="channels_last"`, or + `(batch_size, num_channels) + inputs_spatial_shape` if + `data_format="channels_first"`. + kernel: Tensor of rank N+2. `kernel` has shape + [kernel_spatial_shape, num_output_channels, num_input_channels], + `num_input_channels` should match the number of channels in + `inputs`. + strides: int or int tuple/list of `len(inputs_spatial_shape)`, + specifying the strides of the convolution along each spatial + dimension. If `strides` is int, then every spatial dimension shares + the same `strides`. + padding: string, either `"valid"` or `"same"`. `"valid"` means no + padding is applied, and `"same"` results in padding evenly to the + left/right or up/down of the input such that output has the + same height/width dimension as the input when `strides=1`. + output_padding: int or int tuple/list of `len(inputs_spatial_shape)`, + specifying the amount of padding along the height and width of + the output tensor. Can be a single integer to specify the same + value for all spatial dimensions. The amount of output padding + along a given dimension must be lower than the stride along that + same dimension. If set to `None` (default), the output shape is + inferred. + data_format: A string, either `"channels_last"` or `"channels_first"`. + `data_format` determines the ordering of the dimensions in the + inputs. If `data_format="channels_last"`, `inputs` is of shape + `(batch_size, ..., channels)` while if + `data_format="channels_first"`, `inputs` is of shape + `(batch_size, channels, ...)`. + dilation_rate: int or int tuple/list of `len(inputs_spatial_shape)`, + specifying the dilation rate to use for dilated convolution. If + `dilation_rate` is int, then every spatial dimension shares + the same `dilation_rate`. + + Returns: + A tensor of rank N+2, the result of the conv operation. + """ + data_format = standardize_data_format(data_format) + padding = padding.lower() + if any_symbolic_tensors((inputs,)): + return ConvTranspose( + strides, padding, output_padding, data_format, dilation_rate + ).symbolic_call(inputs, kernel) + return backend.nn.conv_transpose( + inputs, + kernel, + strides, + padding, + output_padding, + data_format, + dilation_rate, + ) + + +class OneHot(Operation): + def __init__( + self, num_classes, axis=-1, dtype=None, sparse=False, *, name=None + ): + super().__init__(name=name) + self.num_classes = num_classes + self.axis = axis + self.dtype = backend.standardize_dtype(dtype) + self.sparse = sparse + + def call(self, x): + return backend.nn.one_hot( + x, + self.num_classes, + axis=self.axis, + dtype=self.dtype, + sparse=self.sparse, + ) + + def compute_output_spec(self, x): + x_shape = list(getattr(x, "shape", [])) + if self.axis == -1: + x_shape.append(self.num_classes) + elif self.axis >= 0 and self.axis < len(x_shape): + x_shape.insert(self.axis, self.num_classes) + else: + raise ValueError( + f"axis must be -1 or between [0, {len(x.shape)}), but " + f"received {self.axis}." + ) + return KerasTensor(x_shape, dtype=self.dtype, sparse=self.sparse) + + +@keras_export(["keras.ops.one_hot", "keras.ops.nn.one_hot"]) +def one_hot(x, num_classes, axis=-1, dtype=None, sparse=False): + """Converts integer tensor `x` into a one-hot tensor. + + The one-hot encoding is a representation where each integer value is + converted into a binary vector with a length equal to `num_classes`, + and the index corresponding to the integer value is marked as 1, while + all other indices are marked as 0. + + Args: + x: Integer tensor to be encoded. The shape can be + arbitrary, but the dtype should be integer. + num_classes: Number of classes for the one-hot encoding. + axis: Axis along which the encoding is performed. + `-1` represents the last axis. Defaults to `-1`. + dtype: (Optional) Data type of the output tensor. If not + provided, it defaults to the default data type of the backend. + sparse: Whether to return a sparse tensor; for backends that support + sparse tensors. + + Returns: + Integer tensor: One-hot encoded tensor with the same shape as `x` + except for the specified `axis` dimension, which will have + a length of `num_classes`. The dtype of the output tensor + is determined by `dtype` or the default data type of the backend. + + Example: + + >>> x = keras.ops.convert_to_tensor([1, 3, 2, 0]) + >>> one_hot(x, num_classes=4) + array([[0. 1. 0. 0.] + [0. 0. 0. 1.] + [0. 0. 1. 0.] + [1. 0. 0. 0.]], shape=(4, 4), dtype=float32) + """ + if any_symbolic_tensors((x,)): + return OneHot( + num_classes, axis=axis, dtype=dtype, sparse=sparse + ).symbolic_call(x) + return backend.nn.one_hot( + x, + num_classes, + axis=axis, + dtype=dtype or backend.floatx(), + sparse=sparse, + ) + + +class BinaryCrossentropy(Operation): + def __init__(self, from_logits=False, *, name=None): + super().__init__(name=name) + self.from_logits = from_logits + + def call(self, target, output): + return backend.nn.binary_crossentropy( + target, output, from_logits=self.from_logits + ) + + def compute_output_spec(self, target, output): + if target.shape != output.shape: + raise ValueError( + "Arguments `target` and `output` must have the same shape. " + "Received: " + f"target.shape={target.shape}, output.shape={output.shape}" + ) + return KerasTensor(output.shape, dtype=output.dtype) + + +@keras_export( + [ + "keras.ops.binary_crossentropy", + "keras.ops.nn.binary_crossentropy", + ] +) +def binary_crossentropy(target, output, from_logits=False): + """Computes binary cross-entropy loss between target and output tensor. + + The binary cross-entropy loss is commonly used in binary + classification tasks where each input sample belongs to one + of the two classes. It measures the dissimilarity between the + target and output probabilities or logits. + + Args: + target: The target tensor representing the true binary labels. + Its shape should match the shape of the `output` tensor. + output: The output tensor representing the predicted probabilities + or logits. Its shape should match the shape of the + `target` tensor. + from_logits: (optional) Whether `output` is a tensor of logits or + probabilities. + Set it to `True` if `output` represents logits; otherwise, + set it to `False` if `output` represents probabilities. + Defaults to `False`. + + Returns: + Integer tensor: The computed binary cross-entropy loss between + `target` and `output`. + + Example: + + >>> target = keras.ops.convert_to_tensor([0, 1, 1, 0]) + >>> output = keras.ops.convert_to_tensor([0.1, 0.9, 0.8, 0.2]) + >>> binary_crossentropy(target, output) + array([0.10536054 0.10536054 0.22314355 0.22314355], + shape=(4,), dtype=float32) + """ + if any_symbolic_tensors((target, output)): + return BinaryCrossentropy(from_logits=from_logits).symbolic_call( + target, output + ) + return backend.nn.binary_crossentropy( + target, output, from_logits=from_logits + ) + + +class CategoricalCrossentropy(Operation): + def __init__(self, from_logits=False, axis=-1, *, name=None): + super().__init__(name=name) + self.from_logits = from_logits + self.axis = axis + + def call(self, target, output): + return backend.nn.categorical_crossentropy( + target, output, from_logits=self.from_logits, axis=self.axis + ) + + def compute_output_spec(self, target, output): + if target.shape != output.shape: + raise ValueError( + "Arguments `target` and `output` must have the same shape. " + "Received: " + f"target.shape={target.shape}, output.shape={output.shape}" + ) + if len(target.shape) < 1: + raise ValueError( + "Arguments `target` and `output` must be at least rank 1. " + "Received: " + f"target.shape={target.shape}, output.shape={output.shape}" + ) + return KerasTensor(output.shape[:-1], dtype=output.dtype) + + +@keras_export( + [ + "keras.ops.categorical_crossentropy", + "keras.ops.nn.categorical_crossentropy", + ] +) +def categorical_crossentropy(target, output, from_logits=False, axis=-1): + """Computes categorical cross-entropy loss between target and output tensor. + + The categorical cross-entropy loss is commonly used in multi-class + classification tasks where each input sample can belong to one of + multiple classes. It measures the dissimilarity + between the target and output probabilities or logits. + + Args: + target: The target tensor representing the true categorical labels. + Its shape should match the shape of the `output` tensor + except for the last dimension. + output: The output tensor representing the predicted probabilities + or logits. Its shape should match the shape of the `target` + tensor except for the last dimension. + from_logits: (optional) Whether `output` is a tensor of logits or + probabilities. + Set it to `True` if `output` represents logits; otherwise, + set it to `False` if `output` represents probabilities. + Defaults to `False`. + axis: (optional) The axis along which the categorical cross-entropy + is computed. + Defaults to `-1`, which corresponds to the last dimension of + the tensors. + + Returns: + Integer tensor: The computed categorical cross-entropy loss between + `target` and `output`. + + Example: + + >>> target = keras.ops.convert_to_tensor( + ... [[1, 0, 0], + ... [0, 1, 0], + ... [0, 0, 1]]) + >>> output = keras.ops.convert_to_tensor( + ... [[0.9, 0.05, 0.05], + ... [0.1, 0.8, 0.1], + ... [0.2, 0.3, 0.5]]) + >>> categorical_crossentropy(target, output) + array([0.10536054 0.22314355 0.6931472 ], shape=(3,), dtype=float32) + """ + if any_symbolic_tensors((target, output)): + return CategoricalCrossentropy( + from_logits=from_logits, axis=axis + ).symbolic_call(target, output) + return backend.nn.categorical_crossentropy( + target, output, from_logits=from_logits, axis=axis + ) + + +class SparseCategoricalCrossentropy(Operation): + def __init__(self, from_logits=False, axis=-1, *, name=None): + super().__init__(name=name) + self.from_logits = from_logits + self.axis = axis + + def call(self, target, output): + return backend.nn.sparse_categorical_crossentropy( + target, output, from_logits=self.from_logits, axis=self.axis + ) + + def compute_output_spec(self, target, output): + if len(output.shape) < 1: + raise ValueError( + "Argument `output` must be at least rank 1. " + "Received: " + f"output.shape={output.shape}" + ) + target_shape = target.shape + if len(target_shape) == len(output.shape) and target_shape[-1] == 1: + target_shape = target_shape[:-1] + if target_shape != output.shape[:-1]: + raise ValueError( + "Arguments `target` and `output` must have the same shape " + "up until the last dimension: " + f"target.shape={target.shape}, output.shape={output.shape}" + ) + return KerasTensor(output.shape[:-1], dtype=output.dtype) + + +@keras_export( + [ + "keras.ops.sparse_categorical_crossentropy", + "keras.ops.nn.sparse_categorical_crossentropy", + ] +) +def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1): + """Computes sparse categorical cross-entropy loss. + + The sparse categorical cross-entropy loss is similar to categorical + cross-entropy, but it is used when the target tensor contains integer + class labels instead of one-hot encoded vectors. It measures the + dissimilarity between the target and output probabilities or logits. + + Args: + target: The target tensor representing the true class labels as + integers. Its shape should match the shape of the `output` + tensor except for the last dimension. + output: The output tensor representing the predicted probabilities + or logits. + Its shape should match the shape of the `target` tensor except + for the last dimension. + from_logits: (optional) Whether `output` is a tensor of logits + or probabilities. + Set it to `True` if `output` represents logits; otherwise, + set it to `False` if `output` represents probabilities. + Defaults to `False`. + axis: (optional) The axis along which the sparse categorical + cross-entropy is computed. + Defaults to `-1`, which corresponds to the last dimension + of the tensors. + + Returns: + Integer tensor: The computed sparse categorical cross-entropy + loss between `target` and `output`. + + Example: + + >>> target = keras.ops.convert_to_tensor([0, 1, 2], dtype=int32) + >>> output = keras.ops.convert_to_tensor( + ... [[0.9, 0.05, 0.05], + ... [0.1, 0.8, 0.1], + ... [0.2, 0.3, 0.5]]) + >>> sparse_categorical_crossentropy(target, output) + array([0.10536056 0.22314355 0.6931472 ], shape=(3,), dtype=float32) + """ + if any_symbolic_tensors((target, output)): + return SparseCategoricalCrossentropy( + from_logits=from_logits, axis=axis + ).symbolic_call(target, output) + return backend.nn.sparse_categorical_crossentropy( + target, output, from_logits=from_logits, axis=axis + ) + + +class MultiHot(Operation): + def __init__( + self, + num_classes=None, + axis=-1, + dtype=None, + sparse=False, + *, + name=None, + **kwargs, + ): + if num_classes is None and "num_tokens" in kwargs: + num_classes = kwargs.pop("num_tokens") + if num_classes is None: + raise ValueError("Argument `num_classes` must be specified.") + super().__init__(name=name) + self.num_classes = num_classes + self.axis = axis + self.dtype = dtype or backend.floatx() + self.sparse = sparse + + def call(self, inputs): + return backend.nn.multi_hot( + inputs, + num_classes=self.num_classes, + axis=self.axis, + dtype=self.dtype, + ) + + def compute_output_spec(self, inputs): + x_shape = list(getattr(inputs, "shape", [])) + if self.axis == -1: + x_shape.append(self.num_classes) + elif self.axis >= 0 and self.axis < len(x_shape): + x_shape.insert(self.axis, self.num_classes) + else: + raise ValueError( + f"axis must be -1 or between [0, {len(inputs.shape)}), but " + f"received {self.axis}." + ) + + if len(x_shape) == 2: + x_shape = [x_shape[-1]] + else: + x_shape = [x_shape[0]] + x_shape[2:] + + return KerasTensor(x_shape, dtype=inputs.dtype, sparse=self.sparse) + + +@keras_export( + [ + "keras.ops.multi_hot", + "keras.ops.nn.multi_hot", + ] +) +def multi_hot( + inputs, num_classes=None, axis=-1, dtype=None, sparse=False, **kwargs +): + """Encodes integer labels as multi-hot vectors. + + This function encodes integer labels as multi-hot vectors, where each label + is mapped to a binary value in the resulting vector. + + Args: + inputs: Tensor of integer labels to be converted to multi-hot vectors. + num_classes: Integer, the total number of unique classes. + axis: (optional) Axis along which the multi-hot encoding should be + added. Defaults to `-1`, which corresponds to the last dimension. + dtype: (optional) The data type of the resulting tensor. Default + is backend's float type. + sparse: Whether to return a sparse tensor; for backends that support + sparse tensors. + + Returns: + Tensor: The multi-hot encoded tensor. + + Example: + + >>> data = keras.ops.convert_to_tensor([0, 4]) + >>> keras.ops.multi_hot(data, num_classes=5) + array([1.0, 0.0, 0.0, 0.0, 1.0], dtype=float32) + + """ + if num_classes is None and "num_tokens" in kwargs: + num_classes = kwargs.pop("num_tokens") + if num_classes is None: + raise ValueError("Argument `num_classes` must be specified.") + + if any_symbolic_tensors((inputs,)): + return MultiHot(num_classes, axis, dtype, sparse).symbolic_call(inputs) + + return backend.nn.multi_hot(inputs, num_classes, axis, dtype, sparse) + + +class Moments(Operation): + def __init__(self, axes, keepdims=False, synchronized=False, *, name=None): + super().__init__(name=name) + self.axes = axes + self.keepdims = keepdims + self.synchronized = synchronized + + def call(self, x): + return backend.nn.moments( + x, + axes=self.axes, + keepdims=self.keepdims, + synchronized=self.synchronized, + ) + + def compute_output_spec(self, x): + return ( + KerasTensor( + reduce_shape(x.shape, axis=self.axes, keepdims=self.keepdims), + dtype=x.dtype, + ), + KerasTensor( + reduce_shape(x.shape, axis=self.axes, keepdims=self.keepdims), + dtype=x.dtype, + ), + ) + + +@keras_export( + [ + "keras.ops.moments", + "keras.ops.nn.moments", + ] +) +def moments(x, axes, keepdims=False, synchronized=False): + """Calculates the mean and variance of `x`. + + The mean and variance are calculated by aggregating the contents of `x` + across `axes`. If `x` is 1-D and `axes = [0]` this is just the mean and + variance of a vector. + + Args: + x: Input tensor. + axes: A list of axes which to compute mean and variance. + keepdims: If this is set to `True`, the axes which are reduced are left + in the result as dimensions with size one. + synchronized: Only applicable with the TensorFlow backend. + If `True`, synchronizes the global batch statistics (mean and + variance) across all devices at each training step in a + distributed training strategy. If `False`, each replica uses its own + local batch statistics. + + Returns: + A tuple containing two tensors - mean and variance. + + Example: + + >>> x = keras.ops.convert_to_tensor([0, 1, 2, 3, 100], dtype="float32") + >>> keras.ops.moments(x, axes=[0]) + (array(21.2, dtype=float32), array(1553.3601, dtype=float32)) + + """ + if any_symbolic_tensors((x,)): + return Moments(axes, keepdims, synchronized=synchronized).symbolic_call( + x + ) + + return backend.nn.moments(x, axes, keepdims, synchronized=synchronized) + + +class BatchNorm(Operation): + def __init__(self, axis, epsilon=1e-3, *, name=None): + super().__init__(name=name) + self.axis = axis + self.epsilon = epsilon + + def call(self, x, mean, variance, offset=None, scale=None): + return backend.nn.batch_normalization( + x, + mean, + variance, + axis=self.axis, + offset=offset, + scale=scale, + epsilon=self.epsilon, + ) + + def _check_shape(self, name, shape, expected_shape): + if shape != expected_shape: + raise ValueError( + f"Arguments `{name}` must be a vector of length " + f"`x.shape[axis]`. Expected: `{expected_shape}`. " + f"Received: `{shape}." + ) + + def compute_output_spec(self, x, mean, variance, offset, scale): + shape = (x.shape[self.axis],) + self._check_shape("mean", tuple(mean.shape), shape) + self._check_shape("variance", tuple(variance.shape), shape) + if offset is not None: + self._check_shape("offset", tuple(offset.shape), shape) + if offset is not scale: + self._check_shape("scale", tuple(scale.shape), shape) + return KerasTensor(x.shape, dtype=x.dtype) + + +@keras_export( + [ + "keras.ops.batch_normalization", + "keras.ops.nn.batch_normalization", + ] +) +def batch_normalization( + x, mean, variance, axis, offset=None, scale=None, epsilon=1e-3 +): + """Normalizes `x` by `mean` and `variance`. + + This op is typically used by the batch normalization step in a neural + network. It normalizes the input tensor along the given axis. + + Args: + x: Input tensor. + mean: A mean vector of the same length as the `axis` dimension of the + input thensor. + variance: A variance vector of the same length as the `axis` dimension + of the input tensor. + axis: Integer, the axis that should be normalized. + offset: An offset vector of the same length as the `axis` dimension of + the input tensor. If not `None`, `offset` is added to the normalized + tensor. Defaults to `None`. + scale: A scale vector of the same length as the `axis` dimension of the + input tensor. If not `None`, the normalized tensor is multiplied by + `scale`. Defaults to `None`. + epsilon: Small float added to variance to avoid dividing by zero. + Defaults to 1e-3. + + Returns: + The normalized tensor. + + Example: + + >>> x = keras.ops.convert_to_tensor( + ... [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]] + ... ) + >>> keras.ops.batch_normalization( + ... x, + ... mean=[0.4, 0.5, 0.6], + ... variance=[0.67, 0.67, 0.67], + ... axis=-1 + ... ) + array([[-3.6624e-01, -3.6624e-01, -3.6624e-01], + [-4.6445e-09, 0.0000e+00, -1.8578e-08], + [ 3.6624e-01, 3.6624e-01, 3.6624e-01]]) + + """ + if any_symbolic_tensors((x, mean, variance, offset, scale)): + return BatchNorm(axis, epsilon).symbolic_call( + x, mean, variance, offset, scale + ) + + return backend.nn.batch_normalization( + x, mean, variance, axis, offset, scale, epsilon + ) + + +class CTCLoss(Operation): + def __init__(self, mask_index=0, *, name=None): + super().__init__(name=name) + self.mask_index = mask_index + + def call(self, target, output, target_length, output_length): + return backend.nn.ctc_loss( + target, output, target_length, output_length, self.mask_index + ) + + def _check_shape_first_dim(self, name1, shape1, name2, shape2): + if shape1[0] != shape2[0]: + raise ValueError( + f"Arguments `{name1}` and `{name2}` must have the same " + "first dimension. " + f"Received shapes: `{shape1}` and `{shape2}`." + ) + + def compute_output_spec(self, target, output, target_length, output_length): + self._check_shape_first_dim( + "target", target.shape, "output", output.shape + ) + self._check_shape_first_dim( + "target_length", target_length.shape, "target", target.shape + ) + self._check_shape_first_dim( + "output_length", output_length.shape, "output", output.shape + ) + dtype = backend.result_type(output.dtype, "float32") + return KerasTensor((target.shape[0],), dtype=dtype) + + +@keras_export( + [ + "keras.ops.ctc_loss", + "keras.ops.nn.ctc_loss", + ] +) +def ctc_loss(target, output, target_length, output_length, mask_index=0): + """CTC (Connectionist Temporal Classification) loss. + + Args: + target: A tensor of shape `(batch_size, max_length)` containing + the true labels in integer format. + output: A tensor of shape `(batch_size, max_length, num_classes)` + containing logits (the output of your model). + target_length: A tensor of shape `(batch_size,)` containing the + true label lengths. + output_length: A tensor of shape `(batch_size,)` containing the + output lengths. + mask_index: The index of the mask character in the vocabulary. + Defaults to `0`. + """ + + if any_symbolic_tensors((target, output, target_length, output_length)): + return CTCLoss(mask_index).symbolic_call( + target, output, target_length, output_length + ) + return backend.nn.ctc_loss( + target, output, target_length, output_length, mask_index + ) + + +class CTCDecode(Operation): + def __init__( + self, + strategy="greedy", + beam_width=100, + top_paths=1, + merge_repeated=True, + mask_index=0, + *, + name=None, + ): + super().__init__(name=name) + self.strategy = strategy + self.beam_width = beam_width + self.top_paths = top_paths + self.merge_repeated = merge_repeated + self.mask_index = mask_index + + def call(self, inputs, sequence_lengths): + return backend.nn.ctc_decode( + inputs, + sequence_lengths, + strategy=self.strategy, + beam_width=self.beam_width, + top_paths=self.top_paths, + merge_repeated=self.merge_repeated, + mask_index=self.mask_index, + ) + + def compute_output_spec(self, inputs, sequence_lengths): + inputs_shape = inputs.shape + if self.strategy == "greedy": + top_paths = 1 + else: + top_paths = self.top_paths + dtype = backend.result_type(inputs.dtype, "float32") + return ( + KerasTensor( + (top_paths, inputs_shape[0], inputs_shape[1]), dtype="int32" + ), + KerasTensor((inputs_shape[0], top_paths), dtype=dtype), + ) + + +@keras_export( + [ + "keras.ops.ctc_decode", + "keras.ops.nn.ctc_decode", + ] +) +def ctc_decode( + inputs, + sequence_lengths, + strategy="greedy", + beam_width=100, + top_paths=1, + merge_repeated=True, + mask_index=0, +): + """Decodes the output of a CTC model. + + Args: + inputs: A tensor of shape `(batch_size, max_length, num_classes)` + containing the logits (the output of the model). + They should *not* be normalized via softmax. + sequence_lengths: A tensor of shape `(batch_size,)` containing the + sequence lengths for the batch. + strategy: A string for the decoding strategy. Supported values are + `"greedy"` and `"beam_search"`. + beam_width: An integer scalar beam width used in beam search. + Defaults to 100. + top_paths: An integer scalar, the number of top paths to return. + Defaults to 1. + merge_repeated: A boolean scalar, whether to merge repeated + labels in the output. Defaults to `True`. + mask_index: An integer scalar, the index of the mask character in + the vocabulary. Defaults to `0`. + + Returns: + A tuple containing: + - The tensor representing the list of decoded sequences. If + `strategy="greedy"`, the shape is `(1, batch_size, max_length)`. If + `strategy="beam_search"`, the shape is + `(top_paths, batch_size, max_length)`. Note that: `-1` indicates the + blank label. + - If `strategy="greedy"`, a tensor of shape `(batch_size, 1)` + representing the negative of the sum of the probability logits for + each sequence. If `strategy="beam_seatch"`, a tensor of shape + `(batch_size, top_paths)` representing the log probability for each + sequence. + """ + + if any_symbolic_tensors((inputs, sequence_lengths)): + return CTCDecode( + strategy=strategy, + beam_width=beam_width, + top_paths=top_paths, + merge_repeated=merge_repeated, + mask_index=mask_index, + ).symbolic_call(inputs, sequence_lengths) + return backend.nn.ctc_decode( + inputs=inputs, + sequence_lengths=sequence_lengths, + strategy=strategy, + beam_width=beam_width, + top_paths=top_paths, + merge_repeated=merge_repeated, + mask_index=mask_index, + ) + + +class Normalize(Operation): + def __init__(self, axis=-1, order=2, epsilon=None, *, name=None): + super().__init__(name=name) + self.axis = axis + self.order = order + self.epsilon = epsilon + + def compute_output_spec(self, x): + return KerasTensor(shape=x.shape) + + def call(self, x): + return _normalize( + x, axis=self.axis, order=self.order, epsilon=self.epsilon + ) + + +@keras_export( + [ + "keras.ops.normalize", + "keras.ops.nn.normalize", + ] +) +def normalize(x, axis=-1, order=2, epsilon=None): + """Normalizes `x` over the specified axis. + + It is defined as: `normalize(x) = x / max(norm(x), epsilon)`. + + Args: + x: Input tensor. + axis: The axis or axes along which to perform normalization. + Default to -1. + order: The exponent value in the norm formulation. + Defaults to 2. + epsilon: A lower bound value for the norm. + Defaults to `backend.epsilon()`. + + Returns: + The normalized array. + + Example: + + >>> x = keras.ops.convert_to_tensor([[1, 2, 3], [4, 5, 6]]) + >>> x_norm = keras.ops.math.normalize(x) + >>> print(x_norm) + array([[0.26726124 0.5345225 0.8017837 ] + [0.45584232 0.5698029 0.68376344]], shape=(2, 3), dtype=float32) + + """ + if any_symbolic_tensors((x,)): + return Normalize(axis=axis, order=order, epsilon=epsilon).symbolic_call( + x + ) + return _normalize(x, axis=axis, order=order, epsilon=epsilon) + + +def _normalize(x, axis=-1, order=2, epsilon=None): + if not isinstance(order, int) or not order >= 1: + raise ValueError( + f"Argument `order` must be an int >= 1. Received: order={order}" + ) + x = backend.convert_to_tensor(x) + if len(x.shape) == 0: + x = backend.numpy.expand_dims(x, axis=0) + if epsilon is None: + epsilon = backend.epsilon() + if 2 == order: + # A special case: L2 normalization with `x * rsqrt(...)` + # instead of `x / sqrt(...)` + square_sum = backend.numpy.sum( + backend.numpy.square(x), axis=axis, keepdims=True + ) + inv_norm = backend.math.rsqrt(square_sum) + inv_norm = backend.numpy.minimum(inv_norm, 1.0 / epsilon) + return x * inv_norm + norm = backend.linalg.norm(x, ord=order, axis=axis, keepdims=True) + denom = backend.numpy.maximum(norm, epsilon) + return backend.numpy.divide(x, denom) + + +class PSNR(Operation): + def __init__( + self, + max_val, + *, + name=None, + ): + super().__init__(name=name) + self.max_val = max_val + + def call(self, x1, x2): + return backend.nn.psnr( + x1=x1, + x2=x2, + max_val=self.max_val, + ) + + def compute_output_spec(self, x1, x2): + if len(x1.shape) != len(x2.shape): + raise ValueError("Inputs must have the same rank") + + return KerasTensor(shape=()) + + +@keras_export( + [ + "keras.ops.psnr", + "keras.ops.nn.psnr", + ] +) +def psnr( + x1, + x2, + max_val, +): + """Peak Signal-to-Noise Ratio (PSNR) function. + + This function computes the Peak Signal-to-Noise Ratio between two signals, + `x1` and `x2`. PSNR is a measure of the quality of a reconstructed signal. + The higher the PSNR, the closer the reconstructed signal is to the original + signal. Note that it can become negative when the signal power is + smaller that the noise power. + + Args: + x1: The first input signal. + x2: The second input signal. Must have the same shape as `x1`. + max_val: The maximum possible value in the signals. + + Returns: + float: The PSNR value between `x1` and `x2`. + + Examples: + + >>> x1 = keras.random.normal((2, 4, 4, 3)) + >>> x2 = keras.random.normal((2, 4, 4, 3)) + >>> max_val = 1.0 + >>> keras.ops.nn.psnr(x1, x2, max_val) + -3.1697404 + """ + if any_symbolic_tensors( + ( + x1, + x2, + ) + ): + return PSNR( + max_val, + ).symbolic_call(x1, x2) + return backend.nn.psnr( + x1, + x2, + max_val, + ) + + +class DotProductAttention(Operation): + def __init__( + self, + is_causal=False, + flash_attention=None, + attn_logits_soft_cap=None, + *, + name=None, + ): + super().__init__(name=name) + self.is_causal = is_causal + self.flash_attention = flash_attention + self.attn_logits_soft_cap = attn_logits_soft_cap + + def call( + self, + query, + key, + value, + bias=None, + mask=None, + scale=None, + ): + return backend.nn.dot_product_attention( + query, + key, + value, + bias=bias, + mask=mask, + scale=scale, + is_causal=self.is_causal, + flash_attention=self.flash_attention, + attn_logits_soft_cap=self.attn_logits_soft_cap, + ) + + def compute_output_spec( + self, + query, + key, + value, + bias=None, + mask=None, + scale=None, + ): + dtype = backend.result_type(query.dtype, key.dtype, value.dtype) + return KerasTensor(query.shape, dtype=dtype) + + +@keras_export( + ["keras.ops.dot_product_attention", "keras.ops.nn.dot_product_attention"] +) +def dot_product_attention( + query, + key, + value, + bias=None, + mask=None, + scale=None, + is_causal=False, + flash_attention=None, + attn_logits_soft_cap=None, +): + """Scaled dot product attention function. + + Computes the attention function on Q (`query`), K (`key`), and V(`value`): + `attention(Q, K, V) = softmax(Q * K / sqrt(d)) * V`. If we define `logits` + as the output of `Q * K` and the `probs` as the output of `softmax`. + + Throughout this function, we utilize the following notation to represent the + shape of array: + - B: batch size + - S: length of the key/value + - T: length of the query + - N: number of attention heads + - H: dimensions of each attention head + - K: number of key/value heads + - G: number of groups, which equals to `N // K` + + Args: + query: The query array with the shape of `(B, T, N, H)`. + key: The key array with the shape of `(B, S, K, H)`. When `K` equals + `N`, multi-headed attention (MHA) is performed. Otherwise, grouped + query attention (GQA) is performed if `N` is a multiple of `K`. and + multi-query attention (MQA) is performed if `K==1` (a special case + of GQA). + value: The value array with the same shape of `key`. + bias: Optional bias array to be added to logits. The shape must be + broadcastable to `(B, N, T, S)`. + mask: Optional mask array used to filter out logits. It is a boolean + mask where `True` indicates the element should take part in + attention. For an additive mask, users should pass it to bias. The + shape must be broadcastable to `(B, N, T, S)`. + scale: Optional scale for the logits. If `None`, the scale will be set + to `1.0 / sqrt(H)`. + is_causal: Whether to apply causal mask. + flash_attention: Whether to use flash attention. If `None`, it will + attempt to use flash attention if the required conditions are met. + Typically, the inputs must be in float16 and bfloat16 dtype and the + input layout requirements may vary depending on the backend. + attn_logits_soft_cap: The value limit for maximum value of the + attention logits before the softmax function is applied. This is + only supported in JAX TPU backend. Defaults to None. + + Returns: + An array of the attention output with the same shape of `query`. + + Example: + + >>> query = keras.random.normal((2, 4, 8, 16)) + >>> key = keras.random.normal((2, 6, 8, 16)) + >>> value = keras.random.normal((2, 6, 8, 16)) + >>> keras.ops.nn.dot_product_attention(query, key, value).shape + (2, 4, 8, 16) + """ + if attn_logits_soft_cap is not None: + if backend.backend() == "jax": + import jax + + if jax.devices()[0].platform != "tpu": + raise ValueError( + "attn_logits_soft_cap is only supported for JAX on TPU. " + "Set attn_logits_soft_cap=None when not using JAX on TPU." + ) + else: + raise ValueError( + "attn_logits_soft_cap is only supported for JAX on TPU. " + "Set attn_logits_soft_cap=None when not using JAX on TPU." + ) + + if any_symbolic_tensors((query, key, value)): + return DotProductAttention( + is_causal=is_causal, + flash_attention=flash_attention, + attn_logits_soft_cap=attn_logits_soft_cap, + ).symbolic_call( + query, + key, + value, + bias=bias, + mask=mask, + scale=scale, + ) + return backend.nn.dot_product_attention( + query, + key, + value, + bias=bias, + mask=mask, + scale=scale, + is_causal=is_causal, + flash_attention=flash_attention, + attn_logits_soft_cap=attn_logits_soft_cap, + ) + + +class RMSNorm(Operation): + def __init__(self, axis=-1, epsilon=None, *, name=None): + super().__init__(name=name) + self.axis = axis + self.epsilon = epsilon + + def compute_output_spec(self, x, scale): + return KerasTensor(shape=x.shape, dtype=x.dtype) + + def call(self, x, scale=None): + return _rms_normalization( + x, scale=scale, axis=self.axis, epsilon=self.epsilon + ) + + +@keras_export( + [ + "keras.ops.rms_normalization", + "keras.ops.nn.rms_normalization", + ] +) +def rms_normalization(x, scale=None, axis=-1, epsilon=None): + """Performs Root Mean Square (RMS) normalization on `x`. + + The Keras operation implements the operation as described in + [Root Mean Square Layer Normalization](https://arxiv.org/pdf/1910.07467) + by Biao Zhang et al. + + The operation is different from LayerNormalization with RMS scaling. + + It is defined as `rms_normalization(x) = x * rsqrt(mean(square(x))) * scale` + + Args: + x: Input tensor. + scale: Optional scaling factor for the normalization. + axis: The axis or axes along which to perform normalization. Defaults + to `-1`. + epsilon: A lower bound value for the norm. Defaults to + `backend.epsilon()`. + + Returns: + The normalized array. + + Example: + + >>> x = keras.random.normal((1, 10)) + >>> keras.ops.rms_normalization(x) + array([[0.69384296, 0.94444374, 0.16551171, 0.05749961, 1.11008865, + 0.52475186, 1.57686807, 1.69893307, 1.27292764, 0.30819128]]) + """ + if any_symbolic_tensors((x, scale)): + return RMSNorm(axis=axis, epsilon=epsilon).symbolic_call(x, scale=scale) + return _rms_normalization(x, scale=scale, axis=axis, epsilon=epsilon) + + +def _rms_normalization(x, scale=None, axis=-1, epsilon=None): + if epsilon is None: + epsilon = backend.epsilon() + original_dtype = backend.standardize_dtype(x.dtype) + # Computes in at least float32 precision for stability in half precision + # training. + compute_dtype = backend.result_type(x.dtype, "float32") + + x = backend.convert_to_tensor(x, dtype=compute_dtype) + if scale is not None: + scale = backend.convert_to_tensor(scale, x.dtype) + + if backend.backend() == "torch" and is_continuous_axis(axis): + import torch.nn.functional as F + + if isinstance(axis, (tuple, list)): + normalized_shape = tuple([x.shape[dim] for dim in axis]) + else: + normalized_shape = (x.shape[axis],) + outputs = F.rms_norm(x, normalized_shape, scale, epsilon) + else: + if len(x.shape) == 0: + x = backend.numpy.expand_dims(x, axis=0) + rrms = backend.math.rsqrt( + backend.numpy.mean( + backend.numpy.square(x), axis=axis, keepdims=True + ) + + epsilon + ) + outputs = backend.numpy.multiply(x, rrms) + if scale is not None: + outputs = backend.numpy.multiply(outputs, scale) + return backend.cast(outputs, original_dtype) + + +class LayerNorm(Operation): + def __init__(self, axis=-1, epsilon=None, rms_scaling=False, *, name=None): + super().__init__(name=name) + self.axis = axis + self.epsilon = epsilon + self.rms_scaling = rms_scaling + + def compute_output_spec(self, x, gamma, beta): + return KerasTensor(shape=x.shape, dtype=x.dtype) + + def call(self, x, gamma=None, beta=None): + return _layer_normalization( + x, + gamma=gamma, + beta=beta, + axis=self.axis, + epsilon=self.epsilon, + rms_scaling=self.rms_scaling, + ) + + +@keras_export( + [ + "keras.ops.layer_normalization", + "keras.ops.nn.layer_normalization", + ] +) +def layer_normalization( + x, gamma=None, beta=None, axis=-1, epsilon=None, **kwargs +): + """Layer normalization layer (Ba et al., 2016). + + Normalize the activations of the previous layer for each given example in a + batch independently, rather than across a batch like Batch Normalization. + i.e. applies a transformation that maintains the mean activation within each + example close to 0 and the activation standard deviation close to 1. + + Args: + x: Input tensor. + gamma: Optional scaling factor for the normalization. + beta: Optional add offset for the normalized tensor. + axis: The axis or axes along which to perform normalization. Default to + `-1`. + epsilon: A lower bound value for the norm. + Defaults to `backend.epsilon()`. + + Returns: + The normalized array. + + Example: + + >>> x = keras.ops.arange(5, dtype="float32") + >>> keras.ops.layer_normalization(x) + array([-1.4142135, -0.70710677, 0.0, 0.7071067, 1.4142135]) + """ + rms_scaling = kwargs.pop("rms_scaling", False) + if rms_scaling: + warnings.warn( + "You passed `rms_scaling=True`, which is deprecated. This argument " + "incorrectly scales the input by the variance, not the root mean " + "square. To correctly use RMS Normalization, please use " + "`keras.ops.rms_normalization` / `keras.ops.nn.rms_normalization` " + "instead." + ) + + if any_symbolic_tensors((x, gamma, beta)): + return LayerNorm( + axis=axis, epsilon=epsilon, rms_scaling=rms_scaling + ).symbolic_call(x, gamma, beta) + return _layer_normalization( + x, + gamma=gamma, + beta=beta, + axis=axis, + epsilon=epsilon, + rms_scaling=rms_scaling, + ) + + +def _layer_normalization( + x, gamma=None, beta=None, axis=-1, epsilon=None, rms_scaling=False +): + if epsilon is None: + epsilon = backend.epsilon() + original_dtype = backend.standardize_dtype(x.dtype) + # Computes in at least float32 precision for stability in half precision + # training. + compute_dtype = backend.result_type(x.dtype, "float32") + + x = backend.convert_to_tensor(x, dtype=compute_dtype) + if gamma is not None: + gamma = backend.convert_to_tensor(gamma, x.dtype) + if beta is not None: + beta = backend.convert_to_tensor(beta, x.dtype) + + # Compute the axes along which to reduce the mean / variance + input_shape = x.shape + ndims = len(input_shape) + + # Broadcasting only necessary for norm when the axis is not just + # the last dimension + broadcast_shape = [1] * ndims + if isinstance(axis, int): + axis = [axis] + for dim in axis: + broadcast_shape[dim] = input_shape[dim] + + def _broadcast(v): + if v is not None and len(v.shape) != ndims and axis != [ndims - 1]: + return backend.numpy.reshape(v, broadcast_shape) + return v + + if rms_scaling: + variance = backend.numpy.var(x, axis=axis, keepdims=True) + inv = backend.math.rsqrt(variance + epsilon) + outputs = outputs = x * inv + if gamma is not None: + outputs = outputs * backend.cast(_broadcast(gamma), x.dtype) + elif backend.config.backend() == "torch" and is_continuous_axis(axis): + # when using torch backend,use kernel to improve performance + import torch.nn.functional as F + + normalized_shape = tuple([input_shape[dim] for dim in axis]) + outputs = F.layer_norm(x, normalized_shape, gamma, beta, epsilon) + else: + # Calculate the mean & variance along self.axis (layer activations). + mean, variance = moments(x, axes=axis, keepdims=True) + gamma, beta = _broadcast(gamma), _broadcast(beta) + inv = backend.math.rsqrt(variance + epsilon) + if gamma is not None: + inv = inv * gamma + + res = -mean * inv + if beta is not None: + res = res + beta + + outputs = x * inv + res + return backend.cast(outputs, original_dtype) + + +class Polar(Operation): + def compute_output_spec(self, abs_, angle): + return KerasTensor(shape=abs_.shape) + + def call(self, abs_, angle): + return _polar(abs_, angle) + + +@keras_export(["keras.ops.polar", "keras.ops.nn.polar"]) +def polar(abs_, angle): + """Constructs a complex tensor whose elements are Cartesian + coordinates corresponding to the polar coordinates + with absolute value `abs` and angle `angle`. + + The operation is numerically equivalent to `torch.polar()`. + It is not equivalent to `scipy.lingalg.polar()` which performs + Singular Value Decomposition. + + Given the magnitude (`abs_`) and angle (`angle`), this function computes the + corresponding complex number in the form of `real + imaginary * 1j`, where: + - `real = abs_ * cos(angle)` + - `imaginary = abs_ * sin(angle)` + + Args: + abs_: The magnitude (absolute value) of the complex number. + angle: The angle (in radians) of the complex number. + + Returns: + A complex number (or array of complex numbers) with the same shape as + `abs_` and `angle`. + + Example: + + >>> abs_ = keras.random.normal((1, 2)) + >>> angle = keras.random.normal((1, 2)) + >>> keras.ops.nn.polar(abs_, angle).shape + (1, 2) + >>> keras.ops.nn.polar(abs_, angle) + Array([[0.63185346-0.59370506j, 0.48960376-0.31677645j]], dtype=complex64) + """ + if any_symbolic_tensors((abs_, angle)): + return Polar().symbolic_call(abs_, angle) + return _polar(abs_, angle) + + +def _polar(abs_, angle): + """Internal implementation of the polar function. + + Args: + abs_: The magnitude (absolute value) of the complex number. + angle: The angle (in radians) of the complex number. + + Returns: + A complex number (or array of complex numbers) with the same shape as + `abs_` and `angle`. + """ + abs_ = backend.convert_to_tensor(abs_) + angle = backend.convert_to_tensor(angle) + + real = abs_ * backend.numpy.cos(angle) + imaginary = abs_ * backend.numpy.sin(angle) + + result = backend.math._get_complex_tensor_from_tuple((real, imaginary)) + + return result + + +class Unfold(Operation): + def __init__( + self, kernel_size, dilation=1, padding=0, stride=1, *, name=None + ): + super().__init__(name=name) + self.kernel_size = kernel_size + self.dilation = dilation + self.padding = padding + self.stride = stride + + def compute_output_spec(self, x): + N, C, H, W = x.shape + + def _pair(x): + return (x, x) if isinstance(x, int) else x + + kH, kW = _pair(self.kernel_size) + dH, dW = _pair(self.dilation) + pH, pW = _pair(self.padding) + sH, sW = _pair(self.stride) + + def out_size(L, k, d, p, s): + return (L + 2 * p - d * (k - 1) - 1) // s + 1 + + outH = out_size(H, kH, dH, pH, sH) + outW = out_size(W, kW, dW, pW, sW) + return KerasTensor(shape=(N, C * kH * kW, outH * outW), dtype=x.dtype) + + def call(self, x): + return _unfold( + x, self.kernel_size, self.dilation, self.padding, self.stride + ) + + +@keras_export(["keras.ops.unfold", "keras.ops.nn.unfold"]) +def unfold(x, kernel_size, dilation=1, padding=0, stride=1): + """Extract sliding local blocks from a 4-D input (batched image). + + This operation is known as **im2col** when used with convolution. + It rearranges the image into overlapping or non-overlapping patches + and returns a tensor whose *depth* (last axis) contains the flattened + patches. + + Args: + x: A 4-D tensor of shape `(N, C, H, W)` (**channels-first** format). + kernel_size: int or tuple of two ints, the size of the sliding window + `(kH, kW)`. If a single int is given, it is used for both + dimensions. + dilation: int or tuple of two ints, the spacing between kernel points + (a.k.a. **dilation** or **atrous** convolution). Default: 1. + padding: int or tuple of two ints, the amount of zero-padding to apply + to both spatial dimensions. Default: 0. + stride: int or tuple of two ints, the step size of the sliding window. + Default: 1. + + Returns: + A 3-D tensor of shape `(N, C * kH * kW, L)` where + `L = num_patches_H * num_patches_W` is the total number of patches + extracted. + + Example: + + >>> x = keras.ops.ones((1, 2, 4, 4)) + >>> patches = keras.ops.unfold(x, kernel_size=2, stride=2) + >>> patches.shape + (1, 8, 4) + + """ + input_shape = x.shape + ndims = len(input_shape) + if ndims != 4: + raise ValueError( + f"Input must be a 4D tensor. Received: input.shape={input_shape}" + ) + if any_symbolic_tensors((x,)): + return Unfold(kernel_size, dilation, padding, stride).symbolic_call(x) + return _unfold(x, kernel_size, dilation, padding, stride) + + +def _unfold(x, kernel_size, dilation=1, padding=0, stride=1): + """Internal implementation of unfold.""" + return backend.nn.unfold( + x, + kernel_size=kernel_size, + dilation=dilation, + padding=padding, + stride=stride, + ) diff --git a/keras/src/ops/nn_test.py b/keras/src/ops/nn_test.py new file mode 100644 index 000000000000..f4718c495337 --- /dev/null +++ b/keras/src/ops/nn_test.py @@ -0,0 +1,3446 @@ +import math +from itertools import combinations + +import numpy as np +import pytest +from absl.testing import parameterized + +import keras +from keras.src import backend +from keras.src import layers +from keras.src import losses +from keras.src import models +from keras.src import ops +from keras.src import testing +from keras.src.backend.common import dtypes +from keras.src.backend.common import standardize_dtype +from keras.src.backend.common.keras_tensor import KerasTensor +from keras.src.layers.convolutional.conv_test import np_conv1d +from keras.src.layers.convolutional.conv_test import np_conv2d +from keras.src.layers.convolutional.conv_test import np_conv3d +from keras.src.layers.convolutional.conv_transpose_test import ( + np_conv1d_transpose, +) +from keras.src.layers.convolutional.conv_transpose_test import ( + np_conv2d_transpose, +) +from keras.src.layers.convolutional.depthwise_conv_test import ( + np_depthwise_conv2d, +) +from keras.src.layers.pooling.average_pooling_test import np_avgpool1d +from keras.src.layers.pooling.average_pooling_test import np_avgpool2d +from keras.src.layers.pooling.max_pooling_test import np_maxpool1d +from keras.src.layers.pooling.max_pooling_test import np_maxpool2d +from keras.src.ops import nn as knn +from keras.src.ops import numpy as knp +from keras.src.testing.test_utils import named_product + + +def _dot_product_attention( + query, key, value, bias=None, mask=None, scale=None, is_causal=False +): + # A pure and simplified numpy version of `dot_product_attention` + # Ref: jax.nn.dot_product_attention + # https://github.com/jax-ml/jax/blob/jax-v0.4.32/jax/_src/nn/functions.py#L828 + # Not support `query_seq_lengths` and `key_value_seq_lengths` args + + def _apply_masks(logits, mask, is_causal): + def _get_large_negative(dtype): + dtype = backend.standardize_dtype(dtype) + if dtype == "float16": + val = 65500.0 + else: + val = 3.38953e38 + return np.asarray(val * -0.7, dtype=dtype) + + def _get_causal_mask(query_length, key_length): + mask = np.tril(np.ones((query_length, key_length), dtype=np.bool_)) + return mask[None, None, :, :] + + if mask is None and not is_causal: + return logits + combined_mask = np.ones_like(logits, dtype=np.bool_) + if mask is not None: + combined_mask = np.logical_and(combined_mask, mask) + if is_causal: + T, S = logits.shape[2], logits.shape[3] + mask = _get_causal_mask(T, S) + combined_mask = np.logical_and(combined_mask, mask) + padded_logits = np.where( + combined_mask, logits, _get_large_negative(logits.dtype) + ) + return padded_logits + + def softmax(x, axis=None): + exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True)) + return exp_x / np.sum(exp_x, axis=axis, keepdims=True) + + _, _, _, H = key.shape + scale = (1.0 / np.sqrt(H)) if scale is None else scale + logits = np.einsum("BTNH,BSNH->BNTS", query, key) + logits *= np.array(scale, dtype=logits.dtype) + if bias is not None: + logits = (logits + bias).astype(logits.dtype) + padded_logits = _apply_masks(logits, mask, is_causal) + padded_logits = padded_logits.astype(np.float32) + probs = softmax(padded_logits, axis=-1).astype(key.dtype) + return np.einsum("BNTS,BSNH->BTNH", probs, value) + + +class NNOpsDynamicShapeTest(testing.TestCase): + def test_relu(self): + x = KerasTensor([None, 2, 3]) + self.assertEqual(knn.relu(x).shape, (None, 2, 3)) + + def test_relu6(self): + x = KerasTensor([None, 2, 3]) + self.assertEqual(knn.relu6(x).shape, (None, 2, 3)) + + def test_sigmoid(self): + x = KerasTensor([None, 2, 3]) + self.assertEqual(knn.sigmoid(x).shape, (None, 2, 3)) + + def test_sparse_sigmoid(self): + x = KerasTensor([None, 2, 3]) + self.assertEqual(knn.sparse_sigmoid(x).shape, (None, 2, 3)) + + def test_softplus(self): + x = KerasTensor([None, 2, 3]) + self.assertEqual(knn.softplus(x).shape, (None, 2, 3)) + + def test_softsign(self): + x = KerasTensor([None, 2, 3]) + self.assertEqual(knn.softsign(x).shape, (None, 2, 3)) + + def test_silu(self): + x = KerasTensor([None, 2, 3]) + self.assertEqual(knn.silu(x).shape, (None, 2, 3)) + + def test_log_sigmoid(self): + x = KerasTensor([None, 2, 3]) + self.assertEqual(knn.log_sigmoid(x).shape, (None, 2, 3)) + + def test_leaky_relu(self): + x = KerasTensor([None, 2, 3]) + self.assertEqual(knn.leaky_relu(x).shape, (None, 2, 3)) + + def test_hard_sigmoid(self): + x = KerasTensor([None, 2, 3]) + self.assertEqual(knn.hard_sigmoid(x).shape, (None, 2, 3)) + + def test_hard_silu(self): + x = KerasTensor([None, 2, 3]) + self.assertEqual(knn.hard_silu(x).shape, (None, 2, 3)) + + def test_elu(self): + x = KerasTensor([None, 2, 3]) + self.assertEqual(knn.elu(x).shape, (None, 2, 3)) + + def test_selu(self): + x = KerasTensor([None, 2, 3]) + self.assertEqual(knn.selu(x).shape, (None, 2, 3)) + + def test_gelu(self): + x = KerasTensor([None, 2, 3]) + self.assertEqual(knn.gelu(x).shape, (None, 2, 3)) + + def test_celu(self): + x = KerasTensor([None, 2, 3]) + self.assertEqual(knn.celu(x).shape, (None, 2, 3)) + + def test_glu(self): + x = KerasTensor([None, 2, 4]) + self.assertEqual(knn.glu(x).shape, (None, 2, 2)) + + def test_tanh_shrink(self): + x = KerasTensor([None, 2, 3]) + self.assertEqual(knn.tanh_shrink(x).shape, (None, 2, 3)) + + def test_hard_tanh(self): + x = KerasTensor([None, 2, 3]) + self.assertEqual(knn.hard_tanh(x).shape, (None, 2, 3)) + + def test_hard_shrink(self): + x = KerasTensor([None, 2, 3]) + self.assertEqual(knn.hard_shrink(x).shape, (None, 2, 3)) + + def test_threshld(self): + x = KerasTensor([None, 2, 3]) + self.assertEqual(knn.threshold(x, 0, 0).shape, (None, 2, 3)) + + def test_squareplus(self): + x = KerasTensor([None, 2, 3]) + self.assertEqual(knn.squareplus(x).shape, (None, 2, 3)) + + def test_soft_shrink(self): + x = KerasTensor([None, 2, 3]) + self.assertEqual(knn.soft_shrink(x).shape, (None, 2, 3)) + + def test_sparse_plus(self): + x = KerasTensor([None, 2, 3]) + self.assertEqual(knn.sparse_plus(x).shape, (None, 2, 3)) + + def test_softmax(self): + x = KerasTensor([None, 2, 3]) + self.assertEqual(knn.softmax(x).shape, (None, 2, 3)) + self.assertEqual(knn.softmax(x, axis=1).shape, (None, 2, 3)) + self.assertEqual(knn.softmax(x, axis=-1).shape, (None, 2, 3)) + + def test_softmax_in_graph(self): + class SoftmaxLayer(keras.Layer): + def call(self, x): + return ops.softmax(x, axis=-1) + + class Model(keras.Model): + def __init__(self): + x = keras.Input(shape=(None,)) + y = SoftmaxLayer()(x) + super().__init__(inputs=x, outputs=y) + + # Make sure Keras is able to compile the model graph + model = Model() + x = ops.array([[1.0, 2.0, 3.0, 4.0]]) + model.predict(x) + + def test_log_softmax(self): + x = KerasTensor([None, 2, 3]) + self.assertEqual(knn.log_softmax(x).shape, (None, 2, 3)) + self.assertEqual(knn.log_softmax(x, axis=1).shape, (None, 2, 3)) + self.assertEqual(knn.log_softmax(x, axis=-1).shape, (None, 2, 3)) + + def test_sparsemax(self): + x = KerasTensor([None, 2, 3]) + self.assertEqual(knn.sparsemax(x).shape, (None, 2, 3)) + + def test_max_pool(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + input_shape = (None, 8, 3) + else: + input_shape = (None, 3, 8) + x = KerasTensor(input_shape) + self.assertEqual( + knn.max_pool(x, 2, 1).shape, + (None, 7, 3) if data_format == "channels_last" else (None, 3, 7), + ) + self.assertEqual( + knn.max_pool(x, 2, 2, padding="same").shape, + (None, 4, 3) if data_format == "channels_last" else (None, 3, 4), + ) + + if data_format == "channels_last": + input_shape = (None, 8, None, 3) + else: + input_shape = (None, 3, 8, None) + x = KerasTensor(input_shape) + ( + self.assertEqual(knn.max_pool(x, 2, 1).shape, (None, 7, None, 3)) + if data_format == "channels_last" + else (None, 3, 7, None) + ) + self.assertEqual( + knn.max_pool(x, 2, 2, padding="same").shape, + ( + (None, 4, None, 3) + if data_format == "channels_last" + else (None, 3, 4, None) + ), + ) + self.assertEqual( + knn.max_pool(x, (2, 2), (2, 2), padding="same").shape, + ( + (None, 4, None, 3) + if data_format == "channels_last" + else (None, 3, 4, None) + ), + ) + + def test_average_pool(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + input_shape = (None, 8, 3) + else: + input_shape = (None, 3, 8) + x = KerasTensor(input_shape) + self.assertEqual( + knn.average_pool(x, 2, 1).shape, + (None, 7, 3) if data_format == "channels_last" else (None, 3, 7), + ) + self.assertEqual( + knn.average_pool(x, 2, 2, padding="same").shape, + (None, 4, 3) if data_format == "channels_last" else (None, 3, 4), + ) + + if data_format == "channels_last": + input_shape = (None, 8, None, 3) + else: + input_shape = (None, 3, 8, None) + x = KerasTensor(input_shape) + self.assertEqual( + knn.average_pool(x, 2, 1).shape, + ( + (None, 7, None, 3) + if data_format == "channels_last" + else (None, 3, 7, None) + ), + ) + self.assertEqual( + knn.average_pool(x, 2, 2, padding="same").shape, + ( + (None, 4, None, 3) + if data_format == "channels_last" + else (None, 3, 4, None) + ), + ) + self.assertEqual( + knn.average_pool(x, (2, 2), (2, 2), padding="same").shape, + ( + (None, 4, None, 3) + if data_format == "channels_last" + else (None, 3, 4, None) + ), + ) + + def test_multi_hot(self): + x = KerasTensor([None, 3, 1]) + self.assertEqual(knn.multi_hot(x, 5).shape, (None, 1, 5)) + self.assertEqual(knn.multi_hot(x, 5, 1).shape, (None, 3, 1)) + self.assertEqual(knn.multi_hot(x, 5, 2).shape, (None, 5, 1)) + self.assertSparse(knn.multi_hot(x, 5, sparse=True)) + + @parameterized.named_parameters( + named_product(dtype=["float32", "int32", "bool"], sparse=[False, True]) + ) + def test_multi_hot_dtype(self, dtype, sparse): + if sparse and not backend.SUPPORTS_SPARSE_TENSORS: + pytest.skip("Backend does not support sparse tensors") + + x = np.arange(5) + out = knn.multi_hot(x, 5, axis=0, dtype=dtype, sparse=sparse) + self.assertEqual(backend.standardize_dtype(out.dtype), dtype) + self.assertSparse(out, sparse) + + def test_conv(self): + data_format = backend.config.image_data_format() + # Test 1D conv. + if data_format == "channels_last": + input_shape = (None, 20, 3) + else: + input_shape = (None, 3, 20) + inputs_1d = KerasTensor(input_shape) + kernel = KerasTensor([4, 3, 2]) + for padding in ["valid", "VALID"]: + self.assertEqual( + knn.conv(inputs_1d, kernel, 1, padding=padding).shape, + ( + (None, 17, 2) + if data_format == "channels_last" + else (None, 2, 17) + ), + ) + for padding in ["same", "SAME"]: + self.assertEqual( + knn.conv(inputs_1d, kernel, 1, padding=padding).shape, + ( + (None, 20, 2) + if data_format == "channels_last" + else (None, 2, 20) + ), + ) + self.assertEqual( + knn.conv(inputs_1d, kernel, (2,), dilation_rate=2).shape, + (None, 7, 2) if data_format == "channels_last" else (None, 2, 7), + ) + + # Test 2D conv. + if data_format == "channels_last": + input_shape = (None, 10, None, 3) + else: + input_shape = (None, 3, 10, None) + inputs_2d = KerasTensor(input_shape) + kernel = KerasTensor([2, 2, 3, 2]) + for padding in ["valid", "VALID"]: + self.assertEqual( + knn.conv(inputs_2d, kernel, 1, padding=padding).shape, + ( + (None, 9, None, 2) + if data_format == "channels_last" + else (None, 2, 9, None) + ), + ) + for padding in ["same", "SAME"]: + self.assertEqual( + knn.conv(inputs_2d, kernel, 1, padding=padding).shape, + ( + (None, 10, None, 2) + if data_format == "channels_last" + else (None, 2, 10, None) + ), + ) + self.assertEqual( + knn.conv(inputs_2d, kernel, (2, 1), dilation_rate=(2, 1)).shape, + ( + (None, 4, None, 2) + if data_format == "channels_last" + else (None, 2, 4, None) + ), + ) + + # Test 2D conv - H, W specified + if data_format == "channels_last": + input_shape = (None, 10, 10, 3) + else: + input_shape = (None, 3, 10, 10) + inputs_2d = KerasTensor(input_shape) + kernel = KerasTensor([2, 2, 3, 2]) + for padding in ["valid", "VALID"]: + self.assertEqual( + knn.conv(inputs_2d, kernel, 1, padding=padding).shape, + ( + (None, 9, 9, 2) + if data_format == "channels_last" + else (None, 2, 9, 9) + ), + ) + for padding in ["same", "SAME"]: + self.assertEqual( + knn.conv(inputs_2d, kernel, 1, padding=padding).shape, + ( + (None, 10, 10, 2) + if data_format == "channels_last" + else (None, 2, 10, 10) + ), + ) + self.assertEqual( + knn.conv(inputs_2d, kernel, (2, 1), dilation_rate=(2, 1)).shape, + ( + (None, 4, 9, 2) + if data_format == "channels_last" + else (None, 2, 4, 9) + ), + ) + + # Test 3D conv. + if data_format == "channels_last": + input_shape = (None, 8, None, 8, 3) + else: + input_shape = (None, 3, 8, None, 8) + inputs_3d = KerasTensor(input_shape) + kernel = KerasTensor([3, 3, 3, 3, 2]) + for padding in ["valid", "VALID"]: + self.assertEqual( + knn.conv(inputs_3d, kernel, 1, padding=padding).shape, + ( + (None, 6, None, 6, 2) + if data_format == "channels_last" + else (None, 2, 6, None, 6) + ), + ) + for padding in ["same", "SAME"]: + self.assertEqual( + knn.conv(inputs_3d, kernel, (2, 1, 2), padding=padding).shape, + ( + (None, 4, None, 4, 2) + if data_format == "channels_last" + else (None, 2, 4, None, 4) + ), + ) + self.assertEqual( + knn.conv( + inputs_3d, kernel, 1, padding="valid", dilation_rate=(1, 2, 2) + ).shape, + ( + (None, 6, None, 4, 2) + if data_format == "channels_last" + else (None, 2, 6, None, 4) + ), + ) + + def test_depthwise_conv(self): + data_format = backend.config.image_data_format() + # Test 1D depthwise conv. + if data_format == "channels_last": + input_shape = (None, 20, 3) + else: + input_shape = (None, 3, 20) + inputs_1d = KerasTensor(input_shape) + kernel = KerasTensor([4, 3, 1]) + for padding in ["valid", "VALID"]: + self.assertEqual( + knn.depthwise_conv(inputs_1d, kernel, 1, padding=padding).shape, + ( + (None, 17, 3) + if data_format == "channels_last" + else (None, 3, 17) + ), + ) + for padding in ["same", "SAME"]: + self.assertEqual( + knn.depthwise_conv( + inputs_1d, kernel, (1,), padding=padding + ).shape, + ( + (None, 20, 3) + if data_format == "channels_last" + else (None, 3, 20) + ), + ) + self.assertEqual( + knn.depthwise_conv(inputs_1d, kernel, 2, dilation_rate=2).shape, + (None, 7, 3) if data_format == "channels_last" else (None, 3, 7), + ) + + # Test 2D depthwise conv. + if data_format == "channels_last": + input_shape = (None, 10, 10, 3) + else: + input_shape = (None, 3, 10, 10) + inputs_2d = KerasTensor(input_shape) + kernel = KerasTensor([2, 2, 3, 1]) + for padding in ["valid", "VALID"]: + self.assertEqual( + knn.depthwise_conv(inputs_2d, kernel, 1, padding=padding).shape, + ( + (None, 9, 9, 3) + if data_format == "channels_last" + else (None, 3, 9, 9) + ), + ) + for padding in ["same", "SAME"]: + self.assertEqual( + knn.depthwise_conv( + inputs_2d, kernel, (1, 2), padding=padding + ).shape, + ( + (None, 10, 5, 3) + if data_format == "channels_last" + else (None, 3, 10, 5) + ), + ) + self.assertEqual( + knn.depthwise_conv(inputs_2d, kernel, 2, dilation_rate=2).shape, + ( + (None, 4, 4, 3) + if data_format == "channels_last" + else (None, 3, 4, 4) + ), + ) + self.assertEqual( + knn.depthwise_conv( + inputs_2d, kernel, 2, dilation_rate=(2, 1) + ).shape, + ( + (None, 4, 5, 3) + if data_format == "channels_last" + else (None, 3, 4, 5) + ), + ) + + def test_separable_conv(self): + data_format = backend.config.image_data_format() + # Test 1D separable conv. + if data_format == "channels_last": + input_shape = (None, 20, 3) + else: + input_shape = (None, 3, 20) + inputs_1d = KerasTensor(input_shape) + kernel = KerasTensor([4, 3, 2]) + pointwise_kernel = KerasTensor([1, 6, 5]) + self.assertEqual( + knn.separable_conv( + inputs_1d, kernel, pointwise_kernel, 1, padding="valid" + ).shape, + (None, 17, 5) if data_format == "channels_last" else (None, 5, 17), + ) + self.assertEqual( + knn.separable_conv( + inputs_1d, kernel, pointwise_kernel, 1, padding="same" + ).shape, + (None, 20, 5) if data_format == "channels_last" else (None, 5, 20), + ) + self.assertEqual( + knn.separable_conv( + inputs_1d, kernel, pointwise_kernel, 2, dilation_rate=2 + ).shape, + (None, 7, 5) if data_format == "channels_last" else (None, 5, 7), + ) + + # Test 2D separable conv. + if data_format == "channels_last": + input_shape = (None, 10, 10, 3) + else: + input_shape = (None, 3, 10, 10) + inputs_2d = KerasTensor(input_shape) + kernel = KerasTensor([2, 2, 3, 2]) + pointwise_kernel = KerasTensor([1, 1, 6, 5]) + self.assertEqual( + knn.separable_conv( + inputs_2d, kernel, pointwise_kernel, 1, padding="valid" + ).shape, + ( + (None, 9, 9, 5) + if data_format == "channels_last" + else (None, 5, 9, 9) + ), + ) + self.assertEqual( + knn.separable_conv( + inputs_2d, kernel, pointwise_kernel, (1, 2), padding="same" + ).shape, + ( + (None, 10, 5, 5) + if data_format == "channels_last" + else (None, 5, 10, 5) + ), + ) + self.assertEqual( + knn.separable_conv( + inputs_2d, kernel, pointwise_kernel, 2, dilation_rate=(2, 1) + ).shape, + ( + (None, 4, 5, 5) + if data_format == "channels_last" + else (None, 5, 4, 5) + ), + ) + + def test_conv_transpose(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + input_shape = (None, 4, 3) + else: + input_shape = (None, 3, 4) + inputs_1d = KerasTensor(input_shape) + kernel = KerasTensor([2, 5, 3]) + self.assertEqual( + knn.conv_transpose(inputs_1d, kernel, 2).shape, + (None, 8, 5) if data_format == "channels_last" else (None, 5, 8), + ) + self.assertEqual( + knn.conv_transpose(inputs_1d, kernel, 2, padding="same").shape, + (None, 8, 5) if data_format == "channels_last" else (None, 5, 8), + ) + self.assertEqual( + knn.conv_transpose( + inputs_1d, kernel, 5, padding="valid", output_padding=4 + ).shape, + (None, 21, 5) if data_format == "channels_last" else (None, 5, 21), + ) + + if data_format == "channels_last": + input_shape = (None, 4, 4, 3) + else: + input_shape = (None, 3, 4, 4) + inputs_2d = KerasTensor(input_shape) + kernel = KerasTensor([2, 2, 5, 3]) + self.assertEqual( + knn.conv_transpose(inputs_2d, kernel, 2).shape, + ( + (None, 8, 8, 5) + if data_format == "channels_last" + else (None, 5, 8, 8) + ), + ) + self.assertEqual( + knn.conv_transpose(inputs_2d, kernel, (2, 2), padding="same").shape, + ( + (None, 8, 8, 5) + if data_format == "channels_last" + else (None, 5, 8, 8) + ), + ) + self.assertEqual( + knn.conv_transpose( + inputs_2d, kernel, (5, 5), padding="valid", output_padding=4 + ).shape, + ( + (None, 21, 21, 5) + if data_format == "channels_last" + else (None, 5, 21, 21) + ), + ) + + def test_one_hot(self): + x = KerasTensor([None, 3, 1]) + self.assertEqual(knn.one_hot(x, 5).shape, (None, 3, 1, 5)) + self.assertEqual(knn.one_hot(x, 5, 1).shape, (None, 5, 3, 1)) + self.assertEqual(knn.one_hot(x, 5, 2).shape, (None, 3, 5, 1)) + self.assertSparse(knn.one_hot(x, 5, sparse=True)) + + @parameterized.named_parameters( + named_product(dtype=["float32", "int32", "bool"], sparse=[False, True]) + ) + def test_one_hot_dtype(self, dtype, sparse): + if sparse and not backend.SUPPORTS_SPARSE_TENSORS: + pytest.skip("Backend does not support sparse tensors") + + x = np.arange(5) + out = knn.one_hot(x, 5, axis=0, dtype=dtype, sparse=sparse) + self.assertEqual(backend.standardize_dtype(out.dtype), dtype) + self.assertSparse(out, sparse) + + def test_moments(self): + x = KerasTensor([None, 3, 4]) + self.assertEqual(knn.moments(x, axes=[0])[0].shape, (3, 4)) + self.assertEqual(knn.moments(x, axes=[0, 1])[0].shape, (4,)) + self.assertEqual( + knn.moments(x, axes=[0, 1], keepdims=True)[0].shape, (1, 1, 4) + ) + + self.assertEqual(knn.moments(x, axes=[1])[0].shape, (None, 4)) + self.assertEqual(knn.moments(x, axes=[1, 2])[0].shape, (None,)) + self.assertEqual( + knn.moments(x, axes=[1, 2], keepdims=True)[0].shape, (None, 1, 1) + ) + + def test_batch_normalization(self): + x = KerasTensor([None, 3, 4]) + mean = KerasTensor([4]) + variance = KerasTensor([4]) + self.assertEqual( + knn.batch_normalization(x, mean, variance, axis=-1).shape, + (None, 3, 4), + ) + + x = KerasTensor([None, 3, 4, 5]) + self.assertEqual( + knn.batch_normalization(x, mean, variance, axis=2).shape, + (None, 3, 4, 5), + ) + + mean = KerasTensor([3]) + variance = KerasTensor([3]) + self.assertEqual( + knn.batch_normalization(x, mean, variance, axis=1).shape, + (None, 3, 4, 5), + ) + + # Test wrong offset shape + self.assertRaisesRegex( + ValueError, + "`offset` must be a vector of length", + knn.batch_normalization, + KerasTensor([None, 3, 4, 5]), + KerasTensor([5]), + KerasTensor([5]), + axis=-1, + offset=KerasTensor([3]), + scale=KerasTensor([5]), + ) + + # Test wrong scale shape + self.assertRaisesRegex( + ValueError, + "`scale` must be a vector of length", + knn.batch_normalization, + KerasTensor([None, 3, 4, 5]), + KerasTensor([5]), + KerasTensor([5]), + axis=-1, + offset=KerasTensor([5]), + scale=KerasTensor([3]), + ) + + def test_ctc_decode(self): + # Test strategy="greedy" + inputs = KerasTensor([None, 2, 3]) + sequence_lengths = KerasTensor([None]) + decoded, scores = knn.ctc_decode(inputs, sequence_lengths) + self.assertEqual(decoded.shape, (1, None, 2)) + self.assertEqual(scores.shape, (None, 1)) + + # Test strategy="beam_search" + inputs = KerasTensor([None, 2, 3]) + sequence_lengths = KerasTensor([None]) + decoded, scores = knn.ctc_decode( + inputs, sequence_lengths, strategy="beam_search", top_paths=2 + ) + self.assertEqual(decoded.shape, (2, None, 2)) + self.assertEqual(scores.shape, (None, 2)) + + def test_normalize(self): + x = KerasTensor([None, 2, 3]) + self.assertEqual(knn.normalize(x).shape, (None, 2, 3)) + + def test_psnr(self): + x1 = KerasTensor([None, 2, 3]) + x2 = KerasTensor([None, 5, 6]) + out = knn.psnr(x1, x2, max_val=224) + self.assertEqual(out.shape, ()) + + def test_dot_product_attention(self): + query = KerasTensor([None, None, 8, 16]) + key = KerasTensor([None, None, 6, 16]) + value = KerasTensor([None, None, 6, 16]) + out = knn.dot_product_attention(query, key, value) + self.assertEqual(out.shape, query.shape) + + def test_rms_normalization(self): + x = KerasTensor([None, 8, 16]) + scale = KerasTensor([None, 8, 16]) + out = knn.rms_normalization(x, scale) + self.assertEqual(out.shape, x.shape) + + def test_layer_normalization(self): + x = KerasTensor([None, 8, 16]) + gamma = KerasTensor([None, 16]) + beta = KerasTensor([None, 16]) + out = knn.layer_normalization(x, gamma, beta) + self.assertEqual(out.shape, x.shape) + + +class NNOpsStaticShapeTest(testing.TestCase): + def test_relu(self): + x = KerasTensor([1, 2, 3]) + self.assertEqual(knn.relu(x).shape, (1, 2, 3)) + + def test_relu6(self): + x = KerasTensor([1, 2, 3]) + self.assertEqual(knn.relu6(x).shape, (1, 2, 3)) + + def test_sigmoid(self): + x = KerasTensor([1, 2, 3]) + self.assertEqual(knn.sigmoid(x).shape, (1, 2, 3)) + + def test_sparse_sigmoid(self): + x = KerasTensor([1, 2, 3]) + self.assertEqual(knn.sparse_sigmoid(x).shape, (1, 2, 3)) + + def test_softplus(self): + x = KerasTensor([1, 2, 3]) + self.assertEqual(knn.softplus(x).shape, (1, 2, 3)) + + def test_softsign(self): + x = KerasTensor([1, 2, 3]) + self.assertEqual(knn.softsign(x).shape, (1, 2, 3)) + + def test_silu(self): + x = KerasTensor([1, 2, 3]) + self.assertEqual(knn.silu(x).shape, (1, 2, 3)) + + def test_log_sigmoid(self): + x = KerasTensor([1, 2, 3]) + self.assertEqual(knn.log_sigmoid(x).shape, (1, 2, 3)) + + def test_leaky_relu(self): + x = KerasTensor([1, 2, 3]) + self.assertEqual(knn.leaky_relu(x).shape, (1, 2, 3)) + + def test_hard_sigmoid(self): + x = KerasTensor([1, 2, 3]) + self.assertEqual(knn.hard_sigmoid(x).shape, (1, 2, 3)) + + def test_hard_silu(self): + x = KerasTensor([1, 2, 3]) + self.assertEqual(knn.hard_silu(x).shape, (1, 2, 3)) + + def test_elu(self): + x = KerasTensor([1, 2, 3]) + self.assertEqual(knn.elu(x).shape, (1, 2, 3)) + + def test_selu(self): + x = KerasTensor([1, 2, 3]) + self.assertEqual(knn.selu(x).shape, (1, 2, 3)) + + def test_gelu(self): + x = KerasTensor([1, 2, 3]) + self.assertEqual(knn.gelu(x).shape, (1, 2, 3)) + + def test_celu(self): + x = KerasTensor([1, 2, 3]) + self.assertEqual(knn.celu(x).shape, (1, 2, 3)) + + def test_glu(self): + x = KerasTensor([1, 2, 4]) + self.assertEqual(knn.glu(x).shape, (1, 2, 2)) + + def test_tanh_shrink(self): + x = KerasTensor([1, 2, 3]) + self.assertEqual(knn.tanh_shrink(x).shape, (1, 2, 3)) + + def test_hard_tanh(self): + x = KerasTensor([1, 2, 3]) + self.assertEqual(knn.hard_tanh(x).shape, (1, 2, 3)) + + def test_hard_shrink(self): + x = KerasTensor([1, 2, 3]) + self.assertEqual(knn.hard_shrink(x).shape, (1, 2, 3)) + + def test_threshold(self): + x = KerasTensor([1, 2, 3]) + self.assertEqual(knn.threshold(x, 0, 0).shape, (1, 2, 3)) + + def test_squareplus(self): + x = KerasTensor([1, 2, 3]) + self.assertEqual(knn.squareplus(x).shape, (1, 2, 3)) + + def test_soft_shrink(self): + x = KerasTensor([1, 2, 3]) + self.assertEqual(knn.soft_shrink(x).shape, (1, 2, 3)) + + def test_sparse_plus(self): + x = KerasTensor([1, 2, 3]) + self.assertEqual(knn.sparse_plus(x).shape, (1, 2, 3)) + + def test_softmax(self): + x = KerasTensor([1, 2, 3]) + self.assertEqual(knn.softmax(x).shape, (1, 2, 3)) + self.assertEqual(knn.softmax(x, axis=1).shape, (1, 2, 3)) + self.assertEqual(knn.softmax(x, axis=-1).shape, (1, 2, 3)) + + def test_log_softmax(self): + x = KerasTensor([1, 2, 3]) + self.assertEqual(knn.log_softmax(x).shape, (1, 2, 3)) + self.assertEqual(knn.log_softmax(x, axis=1).shape, (1, 2, 3)) + self.assertEqual(knn.log_softmax(x, axis=-1).shape, (1, 2, 3)) + + def test_sparsemax(self): + x = KerasTensor([1, 2, 3]) + self.assertEqual(knn.sparsemax(x).shape, (1, 2, 3)) + + def test_max_pool(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + input_shape = (1, 8, 3) + else: + input_shape = (1, 3, 8) + x = KerasTensor(input_shape) + self.assertEqual( + knn.max_pool(x, 2, 1).shape, + (1, 7, 3) if data_format == "channels_last" else (1, 3, 7), + ) + self.assertEqual( + knn.max_pool(x, 2, 2, padding="same").shape, + (1, 4, 3) if data_format == "channels_last" else (1, 3, 4), + ) + + if data_format == "channels_last": + input_shape = (1, 8, 8, 3) + else: + input_shape = (1, 3, 8, 8) + x = KerasTensor(input_shape) + self.assertEqual( + knn.max_pool(x, 2, 1).shape, + (1, 7, 7, 3) if data_format == "channels_last" else (1, 3, 7, 7), + ) + self.assertEqual( + knn.max_pool(x, 2, 2, padding="same").shape, + (1, 4, 4, 3) if data_format == "channels_last" else (1, 3, 4, 4), + ) + self.assertEqual( + knn.max_pool(x, (2, 2), (2, 2), padding="same").shape, + (1, 4, 4, 3) if data_format == "channels_last" else (1, 3, 4, 4), + ) + + def test_average_pool(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + input_shape = (1, 8, 3) + else: + input_shape = (1, 3, 8) + x = KerasTensor(input_shape) + self.assertEqual( + knn.average_pool(x, 2, 1).shape, + (1, 7, 3) if data_format == "channels_last" else (1, 3, 7), + ) + self.assertEqual( + knn.average_pool(x, 2, 2, padding="same").shape, + (1, 4, 3) if data_format == "channels_last" else (1, 3, 4), + ) + + if data_format == "channels_last": + input_shape = (1, 8, 8, 3) + else: + input_shape = (1, 3, 8, 8) + x = KerasTensor(input_shape) + self.assertEqual( + knn.average_pool(x, 2, 1).shape, + (1, 7, 7, 3) if data_format == "channels_last" else (1, 3, 7, 7), + ) + self.assertEqual( + knn.average_pool(x, 2, 2, padding="same").shape, + (1, 4, 4, 3) if data_format == "channels_last" else (1, 3, 4, 4), + ) + self.assertEqual( + knn.average_pool(x, (2, 2), (2, 2), padding="same").shape, + (1, 4, 4, 3) if data_format == "channels_last" else (1, 3, 4, 4), + ) + + def test_conv(self): + data_format = backend.config.image_data_format() + # Test 1D conv. + if data_format == "channels_last": + input_shape = (2, 20, 3) + else: + input_shape = (2, 3, 20) + inputs_1d = KerasTensor(input_shape) + kernel = KerasTensor([4, 3, 2]) + self.assertEqual( + knn.conv(inputs_1d, kernel, 1, padding="valid").shape, + (2, 17, 2) if data_format == "channels_last" else (2, 2, 17), + ) + self.assertEqual( + knn.conv(inputs_1d, kernel, 1, padding="same").shape, + (2, 20, 2) if data_format == "channels_last" else (2, 2, 20), + ) + self.assertEqual( + knn.conv(inputs_1d, kernel, (2,), dilation_rate=2).shape, + (2, 7, 2) if data_format == "channels_last" else (2, 2, 7), + ) + + # Test 2D conv. + if data_format == "channels_last": + input_shape = (2, 10, 10, 3) + else: + input_shape = (2, 3, 10, 10) + inputs_2d = KerasTensor(input_shape) + kernel = KerasTensor([2, 2, 3, 2]) + self.assertEqual( + knn.conv(inputs_2d, kernel, 1, padding="valid").shape, + (2, 9, 9, 2) if data_format == "channels_last" else (2, 2, 9, 9), + ) + self.assertEqual( + knn.conv(inputs_2d, kernel, 1, padding="same").shape, + ( + (2, 10, 10, 2) + if data_format == "channels_last" + else (2, 2, 10, 10) + ), + ) + self.assertEqual( + knn.conv(inputs_2d, kernel, (2, 1), dilation_rate=(2, 1)).shape, + (2, 4, 9, 2) if data_format == "channels_last" else (2, 2, 4, 9), + ) + + # Test 3D conv. + if data_format == "channels_last": + input_shape = (2, 8, 8, 8, 3) + else: + input_shape = (2, 3, 8, 8, 8) + inputs_3d = KerasTensor(input_shape) + kernel = KerasTensor([3, 3, 3, 3, 2]) + self.assertEqual( + knn.conv(inputs_3d, kernel, 1, padding="valid").shape, + ( + (2, 6, 6, 6, 2) + if data_format == "channels_last" + else (2, 2, 6, 6, 6) + ), + ) + self.assertEqual( + knn.conv(inputs_3d, kernel, (2, 1, 2), padding="same").shape, + ( + (2, 4, 8, 4, 2) + if data_format == "channels_last" + else (2, 2, 4, 8, 4) + ), + ) + self.assertEqual( + knn.conv( + inputs_3d, kernel, 1, padding="valid", dilation_rate=(1, 2, 2) + ).shape, + ( + (2, 6, 4, 4, 2) + if data_format == "channels_last" + else (2, 2, 6, 4, 4) + ), + ) + + def test_depthwise_conv(self): + data_format = backend.config.image_data_format() + # Test 1D depthwise conv. + if data_format == "channels_last": + input_shape = (2, 20, 3) + else: + input_shape = (2, 3, 20) + inputs_1d = KerasTensor(input_shape) + kernel = KerasTensor([4, 3, 1]) + self.assertEqual( + knn.depthwise_conv(inputs_1d, kernel, 1, padding="valid").shape, + (2, 17, 3) if data_format == "channels_last" else (2, 3, 17), + ) + self.assertEqual( + knn.depthwise_conv(inputs_1d, kernel, (1,), padding="same").shape, + (2, 20, 3) if data_format == "channels_last" else (2, 3, 20), + ) + self.assertEqual( + knn.depthwise_conv(inputs_1d, kernel, 2, dilation_rate=2).shape, + (2, 7, 3) if data_format == "channels_last" else (2, 3, 7), + ) + + # Test 2D depthwise conv. + if data_format == "channels_last": + input_shape = (2, 10, 10, 3) + else: + input_shape = (2, 3, 10, 10) + inputs_2d = KerasTensor(input_shape) + kernel = KerasTensor([2, 2, 3, 1]) + self.assertEqual( + knn.depthwise_conv(inputs_2d, kernel, 1, padding="valid").shape, + (2, 9, 9, 3) if data_format == "channels_last" else (2, 3, 9, 9), + ) + self.assertEqual( + knn.depthwise_conv(inputs_2d, kernel, (1, 2), padding="same").shape, + (2, 10, 5, 3) if data_format == "channels_last" else (2, 3, 10, 5), + ) + self.assertEqual( + knn.depthwise_conv(inputs_2d, kernel, 2, dilation_rate=2).shape, + (2, 4, 4, 3) if data_format == "channels_last" else (2, 3, 4, 4), + ) + self.assertEqual( + knn.depthwise_conv( + inputs_2d, kernel, 2, dilation_rate=(2, 1) + ).shape, + (2, 4, 5, 3) if data_format == "channels_last" else (2, 3, 4, 5), + ) + + def test_separable_conv(self): + data_format = backend.config.image_data_format() + # Test 1D max pooling. + if data_format == "channels_last": + input_shape = (2, 20, 3) + else: + input_shape = (2, 3, 20) + inputs_1d = KerasTensor(input_shape) + kernel = KerasTensor([4, 3, 2]) + pointwise_kernel = KerasTensor([1, 6, 5]) + self.assertEqual( + knn.separable_conv( + inputs_1d, kernel, pointwise_kernel, 1, padding="valid" + ).shape, + (2, 17, 5) if data_format == "channels_last" else (2, 5, 17), + ) + self.assertEqual( + knn.separable_conv( + inputs_1d, kernel, pointwise_kernel, 1, padding="same" + ).shape, + (2, 20, 5) if data_format == "channels_last" else (2, 5, 20), + ) + self.assertEqual( + knn.separable_conv( + inputs_1d, kernel, pointwise_kernel, 2, dilation_rate=2 + ).shape, + (2, 7, 5) if data_format == "channels_last" else (2, 5, 7), + ) + + # Test 2D separable conv. + if data_format == "channels_last": + input_shape = (2, 10, 10, 3) + else: + input_shape = (2, 3, 10, 10) + inputs_2d = KerasTensor(input_shape) + kernel = KerasTensor([2, 2, 3, 2]) + pointwise_kernel = KerasTensor([1, 1, 6, 5]) + self.assertEqual( + knn.separable_conv( + inputs_2d, kernel, pointwise_kernel, 1, padding="valid" + ).shape, + (2, 9, 9, 5) if data_format == "channels_last" else (2, 5, 9, 9), + ) + self.assertEqual( + knn.separable_conv( + inputs_2d, kernel, pointwise_kernel, (1, 2), padding="same" + ).shape, + (2, 10, 5, 5) if data_format == "channels_last" else (2, 5, 10, 5), + ) + self.assertEqual( + knn.separable_conv( + inputs_2d, kernel, pointwise_kernel, 2, dilation_rate=(2, 1) + ).shape, + (2, 4, 5, 5) if data_format == "channels_last" else (2, 5, 4, 5), + ) + + def test_conv_transpose(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + input_shape = (2, 4, 3) + else: + input_shape = (2, 3, 4) + inputs_1d = KerasTensor(input_shape) + kernel = KerasTensor([2, 5, 3]) + self.assertEqual( + knn.conv_transpose(inputs_1d, kernel, 2).shape, + (2, 8, 5) if data_format == "channels_last" else (2, 5, 8), + ) + self.assertEqual( + knn.conv_transpose(inputs_1d, kernel, 2, padding="same").shape, + (2, 8, 5) if data_format == "channels_last" else (2, 5, 8), + ) + self.assertEqual( + knn.conv_transpose( + inputs_1d, kernel, 5, padding="valid", output_padding=4 + ).shape, + (2, 21, 5) if data_format == "channels_last" else (2, 5, 21), + ) + + if data_format == "channels_last": + input_shape = (2, 4, 4, 3) + else: + input_shape = (2, 3, 4, 4) + inputs_2d = KerasTensor(input_shape) + kernel = KerasTensor([2, 2, 5, 3]) + self.assertEqual( + knn.conv_transpose(inputs_2d, kernel, 2).shape, + (2, 8, 8, 5) if data_format == "channels_last" else (2, 5, 8, 8), + ) + self.assertEqual( + knn.conv_transpose(inputs_2d, kernel, (2, 2), padding="same").shape, + (2, 8, 8, 5) if data_format == "channels_last" else (2, 5, 8, 8), + ) + self.assertEqual( + knn.conv_transpose( + inputs_2d, kernel, (5, 5), padding="valid", output_padding=4 + ).shape, + ( + (2, 21, 21, 5) + if data_format == "channels_last" + else (2, 5, 21, 21) + ), + ) + + def test_batched_and_unbatched_inputs_multi_hot(self): + x = KerasTensor([2, 3, 1]) + unbatched_input = KerasTensor( + [ + 5, + ] + ) + self.assertEqual(knn.multi_hot(unbatched_input, 5, -1).shape, (5,)) + self.assertEqual(knn.multi_hot(x, 5).shape, (2, 1, 5)) + self.assertEqual(knn.multi_hot(x, 5, 1).shape, (2, 3, 1)) + self.assertEqual(knn.multi_hot(x, 5, 2).shape, (2, 5, 1)) + + def test_one_hot(self): + x = KerasTensor([2, 3, 1]) + self.assertEqual(knn.one_hot(x, 5).shape, (2, 3, 1, 5)) + self.assertEqual(knn.one_hot(x, 5, 1).shape, (2, 5, 3, 1)) + self.assertEqual(knn.one_hot(x, 5, 2).shape, (2, 3, 5, 1)) + self.assertSparse(knn.one_hot(x, 5, sparse=True)) + + def test_binary_crossentropy(self): + x1 = KerasTensor([2, 3, 1]) + x2 = KerasTensor([2, 3, 1]) + self.assertEqual(knn.binary_crossentropy(x1, x2).shape, (2, 3, 1)) + + def test_categorical_crossentropy(self): + x1 = KerasTensor([2, 3, 4]) + x2 = KerasTensor([2, 3, 4]) + self.assertEqual(knn.categorical_crossentropy(x1, x2).shape, (2, 3)) + + def test_sparse_categorical_crossentropy(self): + x1 = KerasTensor([2, 3], dtype="int32") + x2 = KerasTensor([2, 3, 4]) + self.assertEqual( + knn.sparse_categorical_crossentropy(x1, x2).shape, (2, 3) + ) + + def test_moments(self): + x = KerasTensor([2, 3, 4]) + self.assertEqual(knn.moments(x, axes=[0])[0].shape, (3, 4)) + self.assertEqual(knn.moments(x, axes=[0, 1])[0].shape, (4,)) + self.assertEqual( + knn.moments(x, axes=[0, 1], keepdims=True)[0].shape, (1, 1, 4) + ) + + def test_batch_normalization(self): + x = KerasTensor([10, 3, 4]) + mean = KerasTensor([4]) + variance = KerasTensor([4]) + self.assertEqual( + knn.batch_normalization(x, mean, variance, axis=-1).shape, + (10, 3, 4), + ) + + x = KerasTensor([10, 3, 4, 5]) + self.assertEqual( + knn.batch_normalization(x, mean, variance, axis=2).shape, + (10, 3, 4, 5), + ) + + mean = KerasTensor([3]) + variance = KerasTensor([3]) + self.assertEqual( + knn.batch_normalization(x, mean, variance, axis=1).shape, + (10, 3, 4, 5), + ) + + def test_ctc_loss(self): + x = KerasTensor([10, 3, 4]) + y = KerasTensor([10, 3], dtype="int32") + x_lengths = KerasTensor([10], dtype="int32") + y_lengths = KerasTensor([10], dtype="int32") + self.assertEqual(knn.ctc_loss(x, y, x_lengths, y_lengths).shape, (10,)) + + def test_ctc_decode(self): + # Test strategy="greedy" + inputs = KerasTensor([10, 2, 3]) + sequence_lengths = KerasTensor([10]) + decoded, scores = knn.ctc_decode(inputs, sequence_lengths) + self.assertEqual(decoded.shape, (1, 10, 2)) + self.assertEqual(scores.shape, (10, 1)) + + # Test strategy="beam_search" + inputs = KerasTensor([10, 2, 3]) + sequence_lengths = KerasTensor([10]) + decoded, scores = knn.ctc_decode( + inputs, sequence_lengths, strategy="beam_search", top_paths=2 + ) + self.assertEqual(decoded.shape, (2, 10, 2)) + self.assertEqual(scores.shape, (10, 2)) + + def test_normalize(self): + x = KerasTensor([1, 2, 3]) + self.assertEqual(knn.normalize(x).shape, (1, 2, 3)) + + def test_psnr(self): + x1 = KerasTensor([1, 2, 3]) + x2 = KerasTensor([5, 6, 7]) + out = knn.psnr(x1, x2, max_val=224) + self.assertEqual(out.shape, ()) + + def test_dot_product_attention(self): + query = KerasTensor([2, 3, 8, 16]) + key = KerasTensor([2, 4, 6, 16]) + value = KerasTensor([2, 4, 6, 16]) + out = knn.dot_product_attention(query, key, value) + self.assertEqual(out.shape, query.shape) + + def test_rms_normalization(self): + x = KerasTensor([2, 8, 16]) + scale = KerasTensor([2, 8, 16]) + self.assertEqual(knn.rms_normalization(x, scale).shape, x.shape) + + def test_layer_normalization(self): + x = KerasTensor([2, 8, 16]) + gamma = KerasTensor([2, 16]) + beta = KerasTensor([2, 16]) + self.assertEqual(knn.layer_normalization(x, gamma, beta).shape, x.shape) + + def test_polar(self): + abs_ = KerasTensor([1, 2]) + angle = KerasTensor([3, 4]) + out = knn.polar(abs_, angle) + self.assertEqual(out.shape, abs_.shape) + + +class NNOpsCorrectnessTest(testing.TestCase): + def test_relu(self): + x = np.array([-1, 0, 1, 2, 3], dtype=np.float32) + self.assertAllClose(knn.relu(x), [0, 0, 1, 2, 3]) + + def test_relu6(self): + x = np.array([-1, 0, 1, 2, 3, 4, 5, 6, 7], dtype=np.float32) + self.assertAllClose(knn.relu6(x), [0, 0, 1, 2, 3, 4, 5, 6, 6]) + + def test_sigmoid(self): + x = np.array([-1, 0, 1, 2, 3], dtype=np.float32) + self.assertAllClose( + knn.sigmoid(x), [0.26894143, 0.5, 0.7310586, 0.880797, 0.95257413] + ) + + def test_sparse_sigmoid(self): + x = np.array([-1, 0, 1, 2, 3], dtype=np.float32) + self.assertAllClose(knn.sparse_sigmoid(x), [0.0, 0.5, 1.0, 1.0, 1.0]) + + def test_softplus(self): + x = np.array([-1, 0, 1, 2, 3], dtype=np.float32) + self.assertAllClose( + knn.softplus(x), + [0.31326166, 0.6931472, 1.3132616, 2.126928, 3.0485873], + ) + + def test_softsign(self): + x = np.array([-1, 0, 1, 2, 3], dtype=np.float32) + self.assertAllClose(knn.softsign(x), [-0.5, 0, 0.5, 0.6666667, 0.75]) + + def test_silu(self): + x = np.array([-1, 0, 1, 2, 3], dtype=np.float32) + self.assertAllClose( + knn.silu(x), + [-0.26894143, 0, 0.7310586, 1.7615942, 2.8577223], + ) + + def test_log_sigmoid(self): + x = np.array([-1, 0, 1, 2, 3], dtype=np.float32) + self.assertAllClose( + knn.log_sigmoid(x), + [-1.3132616, -0.6931472, -0.31326166, -0.126928, -0.04858732], + ) + + def test_leaky_relu(self): + x = np.array([-1, 0, 1, 2, 3], dtype=np.float32) + self.assertAllClose( + knn.leaky_relu(x), + [-0.2, 0, 1, 2, 3], + ) + + def test_hard_sigmoid(self): + x = np.array([-1, 0, 1, 2, 3], dtype=np.float32) + self.assertAllClose( + knn.hard_sigmoid(x), + [0.33333334, 0.5, 0.6666667, 0.8333334, 1.0], + ) + + def test_hard_silu(self): + x = np.array([-3, -2, -1, 0, 1, 2, 3], dtype=np.float32) + self.assertAllClose( + knn.hard_silu(x), + [-0.0, -0.333333, -0.333333, 0.0, 0.6666667, 1.6666667, 3.0], + ) + + def test_elu(self): + x = np.array([-1, 0, 1, 2, 3], dtype=np.float32) + self.assertAllClose( + knn.elu(x), + [-0.63212055, 0, 1, 2, 3], + ) + self.assertAllClose( + knn.elu(x, alpha=0.5), + [-0.31606027, 0, 1, 2, 3], + ) + + def test_selu(self): + x = np.array([-1, 0, 1, 2, 3], dtype=np.float32) + self.assertAllClose( + knn.selu(x), + [-1.1113307, 0.0, 1.050701, 2.101402, 3.152103], + ) + + def test_gelu(self): + x = np.array([-1, 0, 1, 2, 3], dtype=np.float32) + self.assertAllClose( + knn.gelu(x), + [-0.15880796, 0.0, 0.841192, 1.9545977, 2.9963627], + ) + + def test_celu(self): + x = np.array([-1, 0, 1, 2, 3], dtype=np.float32) + self.assertAllClose( + knn.celu(x), + [-0.63212055, 0.0, 1.0, 2.0, 3.0], + ) + + def test_glu(self): + x = np.array([-1, 0, 1, 2, 3, 4], dtype=np.float32) + self.assertAllClose( + knn.glu(x), + [-0.8807971, 0.0, 0.98201376], + ) + + def test_tanh_shrink(self): + x = np.array([-1, 0, 1, 2, 3], dtype=np.float32) + self.assertAllClose( + knn.tanh_shrink(x), + [-0.238406, 0.0, 0.238406, 1.035972, 2.004945], + ) + + def test_hard_tanh(self): + x = np.array([-1, 0, 1, 2, 3], dtype=np.float32) + self.assertAllClose( + knn.hard_tanh(x), + [-1.0, 0.0, 1.0, 1.0, 1.0], + ) + + def test_hard_shrink(self): + x = np.array([-0.5, 0, 1, 2, 3], dtype=np.float32) + self.assertAllClose( + knn.hard_shrink(x), + [0.0, 0.0, 1.0, 2.0, 3.0], + ) + + def test_threshold(self): + x = np.array([-0.5, 0, 1, 2, 3], dtype=np.float32) + self.assertAllClose( + knn.threshold(x, 0, 0), + [0.0, 0.0, 1.0, 2.0, 3.0], + ) + + def test_squareplus(self): + x = np.array([-0.5, 0, 1, 2, 3], dtype=np.float32) + self.assertAllClose( + knn.squareplus(x), + [0.780776, 1.0, 1.618034, 2.414214, 3.302776], + ) + + def test_soft_shrink(self): + x = np.array([-0.5, 0, 1, 2, 3], dtype=np.float32) + self.assertAllClose( + knn.soft_shrink(x), + [0.0, 0.0, 0.5, 1.5, 2.5], + ) + + def test_sparse_plus(self): + x = np.array([-0.5, 0, 1, 2, 3], dtype=np.float32) + self.assertAllClose( + knn.sparse_plus(x), + [0.0625, 0.25, 1.0, 2.0, 3.0], + ) + + def test_softmax(self): + x = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.float32) + self.assertAllClose( + knn.softmax(x, axis=None), # Reduce on all axes. + [[0.045015, 0.122364, 0.33262], [0.045015, 0.122364, 0.33262]], + ) + self.assertAllClose( + knn.softmax(x, axis=0), + [[0.5, 0.5, 0.5], [0.5, 0.5, 0.5]], + ) + self.assertAllClose( + knn.softmax(x, axis=-1), + [ + [0.09003057, 0.24472848, 0.66524094], + [0.09003057, 0.24472848, 0.66524094], + ], + ) + self.assertAllClose( + knn.softmax(x), # Default axis should be -1. + [ + [0.09003057, 0.24472848, 0.66524094], + [0.09003057, 0.24472848, 0.66524094], + ], + ) + + def test_softmax_correctness_with_axis_tuple(self): + input = np.array([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]) + combination = combinations(range(3), 2) + for axis in list(combination): + result = keras.ops.nn.softmax(input, axis=axis) + normalized_sum_by_axis = np.sum( + ops.convert_to_numpy(result), axis=axis + ) + self.assertAllClose(normalized_sum_by_axis, 1.0) + + def test_log_softmax(self): + x = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.float32) + self.assertAllClose( + knn.log_softmax(x, axis=None), # Reduce on all axes. + [ + [-3.100753, -2.100753, -1.100753], + [-3.100753, -2.100753, -1.100753], + ], + ) + self.assertAllClose( + knn.log_softmax(x, axis=0), + [ + [-0.693147, -0.693147, -0.693147], + [-0.693147, -0.693147, -0.693147], + ], + ) + self.assertAllClose( + knn.log_softmax(x, axis=-1), + [ + [-2.407606, -1.407606, -0.407606], + [-2.407606, -1.407606, -0.407606], + ], + ) + self.assertAllClose( + knn.log_softmax(x), # Default axis should be -1. + [ + [-2.407606, -1.407606, -0.407606], + [-2.407606, -1.407606, -0.407606], + ], + ) + + def test_log_softmax_correctness_with_axis_tuple(self): + input = np.array([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]) + combination = combinations(range(3), 2) + for axis in list(combination): + result = keras.ops.nn.log_softmax(input, axis=axis) + normalized_sum_by_axis = np.sum( + np.exp(ops.convert_to_numpy(result)), axis=axis + ) + self.assertAllClose(normalized_sum_by_axis, 1.0) + + def test_polar_corectness(self): + abs_ = np.array([1, 2], dtype="float32") + angle = np.array([2, 3], dtype="float32") + out = knn.polar(abs_, angle) + self.assertAllClose( + out, [-0.41614684 + 0.9092974j, -1.979985 + 0.28224j], atol=1e-3 + ) + + def test_sparsemax(self): + x = np.array([-0.5, 0, 1, 2, 3], dtype=np.float32) + self.assertAllClose( + knn.sparsemax(x), + [0.0, 0.0, 0.0, 0.0, 1.0], + ) + + def test_max_pool(self): + data_format = backend.config.image_data_format() + # Test 1D max pooling. + if data_format == "channels_last": + input_shape = (2, 20, 3) + else: + input_shape = (2, 3, 20) + x = np.arange(120, dtype=float).reshape(input_shape) + self.assertAllClose( + knn.max_pool(x, 2, 1, padding="valid"), + np_maxpool1d(x, 2, 1, padding="valid", data_format=data_format), + ) + self.assertAllClose( + knn.max_pool(x, 2, 2, padding="same"), + np_maxpool1d(x, 2, 2, padding="same", data_format=data_format), + ) + + # Test 2D max pooling. + if data_format == "channels_last": + input_shape = (2, 10, 9, 3) + else: + input_shape = (2, 3, 10, 9) + x = np.arange(540, dtype=float).reshape(input_shape) + self.assertAllClose( + knn.max_pool(x, 2, 1, padding="valid"), + np_maxpool2d(x, 2, 1, padding="valid", data_format=data_format), + ) + self.assertAllClose( + knn.max_pool(x, 2, (2, 1), padding="same"), + np_maxpool2d(x, 2, (2, 1), padding="same", data_format=data_format), + ) + + def test_average_pool_valid_padding(self): + data_format = backend.config.image_data_format() + # Test 1D average pooling. + if data_format == "channels_last": + input_shape = (2, 20, 3) + else: + input_shape = (2, 3, 20) + x = np.arange(120, dtype=float).reshape(input_shape) + self.assertAllClose( + knn.average_pool(x, 2, 1, padding="valid"), + np_avgpool1d(x, 2, 1, padding="valid", data_format=data_format), + ) + + # Test 2D average pooling. + if data_format == "channels_last": + input_shape = (2, 10, 9, 3) + else: + input_shape = (2, 3, 10, 9) + x = np.arange(540, dtype=float).reshape(input_shape) + self.assertAllClose( + knn.average_pool(x, 2, 1, padding="valid"), + np_avgpool2d(x, 2, 1, padding="valid", data_format=data_format), + ) + + def test_average_pool_same_padding(self): + data_format = backend.config.image_data_format() + # Test 1D average pooling. + if data_format == "channels_last": + input_shape = (2, 20, 3) + else: + input_shape = (2, 3, 20) + x = np.arange(120, dtype=float).reshape(input_shape) + + self.assertAllClose( + knn.average_pool(x, 2, 2, padding="same"), + np_avgpool1d(x, 2, 2, padding="same", data_format=data_format), + ) + + # Test 2D average pooling. + if data_format == "channels_last": + input_shape = (2, 10, 9, 3) + else: + input_shape = (2, 3, 10, 9) + x = np.arange(540, dtype=float).reshape(input_shape) + self.assertAllClose( + knn.average_pool(x, 2, (2, 1), padding="same"), + np_avgpool2d(x, 2, (2, 1), padding="same", data_format=data_format), + ) + # Test 2D average pooling with different pool size. + if data_format == "channels_last": + input_shape = (2, 10, 9, 3) + else: + input_shape = (2, 3, 10, 9) + x = np.arange(540, dtype=float).reshape(input_shape) + self.assertAllClose( + knn.average_pool(x, (2, 3), (3, 3), padding="same"), + np_avgpool2d( + x, (2, 3), (3, 3), padding="same", data_format=data_format + ), + ) + + @parameterized.product( + strides=(1, 2, 3), + padding=("valid", "same"), + dilation_rate=(1, 2), + ) + def test_conv_1d(self, strides, padding, dilation_rate): + if strides > 1 and dilation_rate > 1: + pytest.skip("Unsupported configuration") + + if backend.config.image_data_format() == "channels_last": + input_shape = (2, 20, 3) + else: + input_shape = (2, 3, 20) + inputs_1d = np.arange(120, dtype=float).reshape(input_shape) + kernel = np.arange(24, dtype=float).reshape([4, 3, 2]) + + outputs = knn.conv( + inputs_1d, + kernel, + strides=strides, + padding=padding, + dilation_rate=dilation_rate, + ) + expected = np_conv1d( + inputs_1d, + kernel, + bias_weights=np.zeros((2,)), + strides=strides, + padding=padding.lower(), + data_format=backend.config.image_data_format(), + dilation_rate=dilation_rate, + groups=1, + ) + self.assertAllClose(outputs, expected) + + @parameterized.product(strides=(1, 2, (1, 2)), padding=("valid", "same")) + def test_conv_2d(self, strides, padding): + if backend.config.image_data_format() == "channels_last": + input_shape = (2, 10, 10, 3) + else: + input_shape = (2, 3, 10, 10) + inputs_2d = np.arange(600, dtype=float).reshape(input_shape) + kernel = np.arange(24, dtype=float).reshape([2, 2, 3, 2]) + + outputs = knn.conv(inputs_2d, kernel, strides, padding=padding) + expected = np_conv2d( + inputs_2d, + kernel, + bias_weights=np.zeros((2,)), + strides=strides, + padding=padding, + data_format=backend.config.image_data_format(), + dilation_rate=1, + groups=1, + ) + self.assertAllClose(outputs, expected) + + @parameterized.product(strides=(1, 2), dilation_rate=(1, (2, 1))) + def test_conv_2d_group_2(self, strides, dilation_rate): + if ( + backend.backend() == "tensorflow" + and strides == 2 + and dilation_rate == (2, 1) + ): + # This case is not supported by the TF backend. + return + if backend.config.image_data_format() == "channels_last": + input_shape = (2, 10, 10, 4) + else: + input_shape = (2, 4, 10, 10) + inputs_2d = np.ones(input_shape) + kernel = np.ones([2, 2, 2, 6]) + outputs = knn.conv( + inputs_2d, + kernel, + strides, + padding="same", + dilation_rate=dilation_rate, + ) + expected = np_conv2d( + inputs_2d, + kernel, + bias_weights=np.zeros((6,)), + strides=strides, + padding="same", + data_format=backend.config.image_data_format(), + dilation_rate=dilation_rate, + groups=1, + ) + self.assertAllClose(outputs, expected) + + @parameterized.product( + strides=(1, (1, 1, 1), 2), + padding=("valid", "same"), + data_format=("channels_first", "channels_last"), + ) + def test_conv_3d(self, strides, padding, data_format): + if data_format == "channels_last": + input_shape = (2, 8, 8, 8, 3) + else: + input_shape = (2, 3, 8, 8, 8) + inputs_3d = np.arange(3072, dtype=float).reshape(input_shape) + kernel = np.arange(162, dtype=float).reshape([3, 3, 3, 3, 2]) + + outputs = knn.conv( + inputs_3d, kernel, strides, padding=padding, data_format=data_format + ) + expected = np_conv3d( + inputs_3d, + kernel, + bias_weights=np.zeros((2,)), + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=1, + groups=1, + ) + self.assertAllClose(outputs, expected, rtol=1e-5, atol=1e-5) + + # Test for tracing error on tensorflow backend. + if backend.backend() == "tensorflow": + import tensorflow as tf + + @tf.function + def conv(x): + return knn.conv( + x, kernel, strides, padding=padding, data_format=data_format + ) + + outputs = conv(inputs_3d) + self.assertAllClose(outputs, expected, rtol=1e-5, atol=1e-5) + + @parameterized.product( + strides=(1, (1, 1), (2, 2)), + padding=("valid", "same"), + dilation_rate=(1, (2, 2)), + ) + def test_depthwise_conv_2d(self, strides, padding, dilation_rate): + if ( + backend.backend() == "tensorflow" + and strides == (2, 2) + and dilation_rate == (2, 2) + ): + # This case is not supported by the TF backend. + return + print(strides, padding, dilation_rate) + if backend.config.image_data_format() == "channels_last": + input_shape = (2, 10, 10, 3) + else: + input_shape = (2, 3, 10, 10) + inputs_2d = np.arange(600, dtype=float).reshape(input_shape) + kernel = np.arange(24, dtype=float).reshape([2, 2, 3, 2]) + + outputs = knn.depthwise_conv( + inputs_2d, + kernel, + strides, + padding=padding, + dilation_rate=dilation_rate, + ) + expected = np_depthwise_conv2d( + inputs_2d, + kernel, + bias_weights=np.zeros((6,)), + strides=strides, + padding=padding, + data_format=backend.config.image_data_format(), + dilation_rate=dilation_rate, + ) + self.assertAllClose(outputs, expected) + + @parameterized.product( + strides=(1, 2), + padding=("valid", "same"), + dilation_rate=(1, (2, 2)), + ) + def test_separable_conv_2d(self, strides, padding, dilation_rate): + if ( + backend.backend() == "tensorflow" + and strides == 2 + and dilation_rate == (2, 2) + ): + # This case is not supported by the TF backend. + return + # Test 2D conv. + if backend.config.image_data_format() == "channels_last": + input_shape = (2, 10, 10, 3) + else: + input_shape = (2, 3, 10, 10) + inputs_2d = np.arange(600, dtype=float).reshape(input_shape) + depthwise_kernel = np.arange(24, dtype=float).reshape([2, 2, 3, 2]) + pointwise_kernel = np.arange(72, dtype=float).reshape([1, 1, 6, 12]) + + outputs = knn.separable_conv( + inputs_2d, + depthwise_kernel, + pointwise_kernel, + strides, + padding=padding, + dilation_rate=dilation_rate, + ) + # Depthwise followed by pointwise conv + expected_depthwise = np_depthwise_conv2d( + inputs_2d, + depthwise_kernel, + np.zeros(6), + strides=strides, + padding=padding, + data_format=backend.config.image_data_format(), + dilation_rate=dilation_rate, + ) + expected = np_conv2d( + expected_depthwise, + pointwise_kernel, + np.zeros(6 * 12), + strides=1, + padding=padding, + data_format=backend.config.image_data_format(), + dilation_rate=dilation_rate, + groups=1, + ) + self.assertAllClose(outputs, expected) + + @parameterized.product(padding=("valid", "same")) + def test_conv_transpose_1d(self, padding): + if backend.config.image_data_format() == "channels_last": + input_shape = (2, 4, 3) + else: + input_shape = (2, 3, 4) + inputs_1d = np.arange(24, dtype=float).reshape(input_shape) + kernel = np.arange(30, dtype=float).reshape([2, 5, 3]) + outputs = knn.conv_transpose(inputs_1d, kernel, 2, padding=padding) + expected = np_conv1d_transpose( + inputs_1d, + kernel, + bias_weights=np.zeros(5), + strides=2, + output_padding=None, + padding=padding, + data_format=backend.config.image_data_format(), + dilation_rate=1, + ) + self.assertAllClose(outputs, expected) + + @parameterized.product(strides=(2, (2, 2)), padding=("valid", "same")) + def test_conv_transpose_2d(self, strides, padding): + if backend.config.image_data_format() == "channels_last": + input_shape = (2, 4, 4, 3) + else: + input_shape = (2, 3, 4, 4) + inputs_2d = np.arange(96, dtype=float).reshape(input_shape) + kernel = np.arange(60, dtype=float).reshape([2, 2, 5, 3]) + + outputs = knn.conv_transpose( + inputs_2d, kernel, strides, padding=padding + ) + expected = np_conv2d_transpose( + inputs_2d, + kernel, + bias_weights=np.zeros(5), + strides=strides, + output_padding=None, + padding=padding, + data_format=backend.config.image_data_format(), + dilation_rate=1, + ) + self.assertAllClose(outputs, expected) + + @parameterized.named_parameters( + [ + {"testcase_name": "dense", "sparse": False}, + {"testcase_name": "sparse", "sparse": True}, + ] + ) + def test_one_hot(self, sparse): + if sparse and not backend.SUPPORTS_SPARSE_TENSORS: + pytest.skip("Backend does not support sparse tensors") + # Test 1D one-hot. + indices_1d = np.array([0, 1, 2, 3]) + output_1d = knn.one_hot(indices_1d, 4, sparse=sparse) + self.assertAllClose(output_1d, np.eye(4)[indices_1d]) + self.assertSparse(output_1d, sparse) + output_1d = knn.one_hot(indices_1d, 4, axis=0, sparse=sparse) + self.assertAllClose(output_1d, np.eye(4)[indices_1d]) + self.assertSparse(output_1d, sparse) + + # Test 1D list one-hot. + indices_1d = [0, 1, 2, 3] + output_1d = knn.one_hot(indices_1d, 4, sparse=sparse) + self.assertAllClose(output_1d, np.eye(4)[indices_1d]) + self.assertSparse(output_1d, sparse) + output_1d = knn.one_hot(indices_1d, 4, axis=0, sparse=sparse) + self.assertAllClose(output_1d, np.eye(4)[indices_1d]) + self.assertSparse(output_1d, sparse) + + # Test 2D one-hot. + indices_2d = np.array([[0, 1], [2, 3]]) + output_2d = knn.one_hot(indices_2d, 4, sparse=sparse) + self.assertAllClose(output_2d, np.eye(4)[indices_2d]) + self.assertSparse(output_2d, sparse) + output_2d = knn.one_hot(indices_2d, 4, axis=2, sparse=sparse) + self.assertAllClose(output_2d, np.eye(4)[indices_2d]) + self.assertSparse(output_2d, sparse) + output_2d = knn.one_hot(indices_2d, 4, axis=1, sparse=sparse) + self.assertAllClose( + output_2d, np.transpose(np.eye(4)[indices_2d], (0, 2, 1)) + ) + self.assertSparse(output_2d, sparse) + + # Test 1D one-hot with 1 extra dimension. + indices_1d = np.array([[0], [1], [2], [3]]) + output_1d = knn.one_hot(indices_1d, 4, sparse=sparse) + self.assertAllClose(output_1d, np.eye(4)[indices_1d]) + self.assertSparse(output_1d, sparse) + output_1d = knn.one_hot(indices_1d, 4, axis=0, sparse=sparse) + self.assertAllClose(output_1d, np.eye(4)[indices_1d].swapaxes(1, 2)) + self.assertSparse(output_1d, sparse) + + # Test 1D one-hot with negative inputs + indices_1d = np.array([0, -1, -1, 3]) + output_1d = knn.one_hot(indices_1d, 4, sparse=sparse) + self.assertAllClose( + output_1d, + np.array( + [ + [1, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 1], + ], + dtype=np.float32, + ), + ) + self.assertSparse(output_1d, sparse) + + def test_binary_crossentropy(self): + # Test with from_logits=False + target = np.array([[0.1], [0.9], [0.2], [1.0]]) + output = np.array([[0.1], [0.2], [0.3], [0.4]]) + result = knn.binary_crossentropy(target, output, from_logits=False) + self.assertAllClose( + result, + np.array([[0.32508277], [1.47080801], [0.52613434], [0.91629048]]), + ) + + # Test with from_logits=True + target = np.array([[0.1], [0.9], [0.2], [1.0]]) + output = np.array([[0.1], [0.2], [0.3], [0.4]]) + result = knn.binary_crossentropy(target, output, from_logits=True) + self.assertAllClose( + result, + np.array([[0.73439666], [0.61813887], [0.79435524], [0.51301525]]), + ) + + # Test with output clipping + target = np.array([[0.1], [0.9], [0.2], [1.0]]) + output = np.array([[0.99], [-0.2], [0.9], [-0.4]]) + result = knn.binary_crossentropy(target, output, from_logits=True) + self.assertAllClose( + result, + np.array([[1.206961], [0.778139], [1.061154], [0.913015]]), + ) + + def test_categorical_crossentropy(self): + target = np.array( + [ + [0.33008796, 0.0391289, 0.9503603], + [0.80376694, 0.92363342, 0.19147756], + ] + ) + output = np.array( + [ + [0.23446431, 0.35822914, 0.06683268], + [0.3413979, 0.05420256, 0.81619654], + ] + ) + + # Test from_logits=False + result = knn.categorical_crossentropy( + target, output, from_logits=False, axis=-1 + ) + self.assertAllClose(result, np.array([2.54095299, 3.96374412])) + + # Test axis + result = knn.categorical_crossentropy( + target, output, from_logits=False, axis=0 + ) + self.assertAllClose( + result, np.array([0.71683073, 1.87988172, 2.46810762]) + ) + + # Test from_logits=True + result = knn.categorical_crossentropy( + target, output, from_logits=True, axis=-1 + ) + self.assertAllClose(result, np.array([1.59419954, 2.49880593])) + + # Test with output clipping + output = np.array( + [ + [1.23446431, -0.35822914, 1.06683268], + [0.3413979, -0.05420256, 0.81619654], + ] + ) + result = knn.categorical_crossentropy( + target, output, from_logits=True, axis=-1 + ) + self.assertAllClose(result, np.array([1.16825923, 2.55436813])) + + def test_sparse_categorical_crossentropy(self): + target = np.array([0, 1, 2]) + output = np.array( + [[0.9, 0.05, 0.05], [0.05, 0.89, 0.06], [0.05, 0.01, 0.94]] + ) + result = knn.sparse_categorical_crossentropy(target, output) + self.assertAllClose(result, [0.105361, 0.116534, 0.061875]) + + output = np.array([[8.0, 1.0, 1.0], [0.0, 9.0, 1.0], [2.0, 3.0, 5.0]]) + result = knn.sparse_categorical_crossentropy( + target, output, from_logits=True + ) + self.assertAllClose(result, [0.001822, 0.000459, 0.169846]) + + @parameterized.named_parameters( + [ + {"testcase_name": "dense", "sparse": False}, + {"testcase_name": "sparse", "sparse": True}, + ] + ) + def test_multi_hot(self, sparse): + if sparse and not backend.SUPPORTS_SPARSE_TENSORS: + pytest.skip("Backend does not support sparse tensors") + + # Test 1D multi-hot. + indices_1d = np.array([0, 1, 2, 3]) + expected_output_1d = np.array([1, 1, 1, 1]) + output_1d = knn.multi_hot(indices_1d, 4, sparse=sparse) + self.assertAllClose(output_1d, expected_output_1d) + self.assertSparse(output_1d, sparse) + + # Test 2D multi-hot. + indices_2d = np.array([[0, 1], [2, 3]]) + expected_output_2d = np.array([[1, 1, 0, 0], [0, 0, 1, 1]]) + output_2d = knn.multi_hot(indices_2d, 4, sparse=sparse) + self.assertAllClose(output_2d, expected_output_2d) + self.assertSparse(output_2d, sparse) + + # Test 1D multi-hot with negative inputs + indices_1d = np.array([0, -1, -1, 3]) + expected_output_1d = np.array([1, 0, 0, 1]) + output_1d = knn.multi_hot(indices_1d, 4, sparse=sparse) + self.assertAllClose(output_1d, expected_output_1d) + self.assertSparse(output_1d, sparse) + + def test_moments(self): + # Test 1D moments + x = np.array([0, 1, 2, 3, 4, 100, -200]).astype(np.float32) + mean, variance = knn.moments(x, axes=[0]) + self.assertAllClose(mean, np.mean(x), atol=1e-5, rtol=1e-5) + self.assertAllClose(variance, np.var(x), atol=1e-5, rtol=1e-5) + + # Test batch statistics for 4D moments (batch, height, width, channels) + x = np.random.uniform(size=(2, 28, 28, 3)).astype(np.float32) + mean, variance = knn.moments(x, axes=[0]) + self.assertAllClose(mean, np.mean(x, axis=0), atol=1e-5, rtol=1e-5) + self.assertAllClose(variance, np.var(x, axis=0), atol=1e-5, rtol=1e-5) + + # Test global statistics for 4D moments (batch, height, width, channels) + x = np.random.uniform(size=(2, 28, 28, 3)).astype(np.float32) + mean, variance = knn.moments(x, axes=[0, 1, 2]) + expected_mean = np.mean(x, axis=(0, 1, 2)) + expected_variance = np.var(x, axis=(0, 1, 2)) + self.assertAllClose(mean, expected_mean, atol=1e-5, rtol=1e-5) + self.assertAllClose(variance, expected_variance, atol=1e-5, rtol=1e-5) + + # Test keepdims + x = np.random.uniform(size=(2, 28, 28, 3)).astype(np.float32) + mean, variance = knn.moments(x, axes=[0, 1, 2], keepdims=True) + expected_mean = np.mean(x, axis=(0, 1, 2), keepdims=True) + expected_variance = np.var(x, axis=(0, 1, 2), keepdims=True) + self.assertAllClose(mean, expected_mean, atol=1e-5, rtol=1e-5) + self.assertAllClose(variance, expected_variance, atol=1e-5, rtol=1e-5) + + # Test float16 which causes overflow + x = np.array( + [-741.0, 353.2, 1099.0, -1807.0, 502.8, -83.4, 333.5, -130.9], + dtype=np.float16, + ) + mean, variance = knn.moments(x, axes=[0]) + expected_mean = np.mean(x.astype(np.float32)).astype(np.float16) + # the output variance is clipped to the max value of np.float16 because + # it is overflowed + expected_variance = np.finfo(np.float16).max + self.assertAllClose(mean, expected_mean, atol=1e-5, rtol=1e-5) + self.assertAllClose(variance, expected_variance, atol=1e-5, rtol=1e-5) + + @pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="synchronized=True only implemented for TF backend", + ) + def test_moments_sync(self): + # Test batch statistics for 4D moments (batch, height, width, channels) + x = np.random.uniform(size=(2, 28, 28, 3)).astype(np.float32) + mean, variance = knn.moments(x, axes=[0], synchronized=True) + self.assertAllClose(mean, np.mean(x, axis=0), atol=1e-5, rtol=1e-5) + self.assertAllClose(variance, np.var(x, axis=0), atol=1e-5, rtol=1e-5) + + # Test global statistics for 4D moments (batch, height, width, channels) + x = np.random.uniform(size=(2, 28, 28, 3)).astype(np.float32) + mean, variance = knn.moments(x, axes=[0, 1, 2], synchronized=True) + expected_mean = np.mean(x, axis=(0, 1, 2)) + expected_variance = np.var(x, axis=(0, 1, 2)) + self.assertAllClose(mean, expected_mean, atol=1e-5, rtol=1e-5) + self.assertAllClose(variance, expected_variance, atol=1e-5, rtol=1e-5) + + # Test keepdims + x = np.random.uniform(size=(2, 28, 28, 3)).astype(np.float32) + mean, variance = knn.moments( + x, axes=[0, 1, 2], keepdims=True, synchronized=True + ) + expected_mean = np.mean(x, axis=(0, 1, 2), keepdims=True) + expected_variance = np.var(x, axis=(0, 1, 2), keepdims=True) + self.assertAllClose(mean, expected_mean, atol=1e-5, rtol=1e-5) + self.assertAllClose(variance, expected_variance, atol=1e-5, rtol=1e-5) + + @parameterized.product(dtype=["float16", "float32"]) + @pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="synchronized=True only implemented for TF backend", + ) + def test_moments_sync_with_distribution_strategy(self, dtype): + from tensorflow.python.eager import context + + from keras.src.utils.module_utils import tensorflow as tf + + context._reset_context() + + # Config 2 CPUs for testing. + logical_cpus = tf.config.list_logical_devices("CPU") + if len(logical_cpus) == 1: + from tensorflow.python.eager import context + + context._reset_context() + tf.config.set_logical_device_configuration( + tf.config.list_physical_devices("CPU")[0], + [ + tf.config.LogicalDeviceConfiguration(), + tf.config.LogicalDeviceConfiguration(), + ], + ) + + @tf.function() + def test_on_moments(inputs): + return knn.moments( + inputs, axes=-1, keepdims=True, synchronized=True + ) + + # Test output of moments. + inputs = tf.constant([5.0, 9.0, 1.0, 3.0], dtype=dtype) + strategy = tf.distribute.MirroredStrategy(["CPU:0", "CPU:1"]) + with strategy.scope(): + mean, variance = strategy.run(test_on_moments, args=(inputs,)) + self.assertEqual(mean.values[0], 4.5) + self.assertEqual(variance.values[0], 8.75) + self.assertEqual(variance.values[0], 8.75) + + context._reset_context() + + def test_batch_normalization(self): + x = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]) + mean = np.array([0.2, 0.3, 0.4]) + variance = np.array([4.0, 16.0, 64.0]) + output = knn.batch_normalization( + x, + mean, + variance, + axis=-1, + offset=np.array([5.0, 10.0, 15.0]), + scale=np.array([10.0, 20.0, 30.0]), + epsilon=1e-7, + ) + expected_output = np.array([[4.5, 9.5, 14.625], [6.0, 11.0, 15.75]]) + self.assertAllClose(output, expected_output) + + output = knn.batch_normalization( + x, + mean, + variance, + axis=1, + epsilon=1e-7, + ) + expected_output = np.array( + [[-0.05, -0.025, -0.0125], [0.1, 0.05, 0.025]] + ) + self.assertAllClose(output, expected_output) + + output = knn.batch_normalization( + np.random.uniform(size=[2, 3, 3, 5]), + np.random.uniform(size=[5]), + np.random.uniform(size=[5]), + axis=3, + offset=np.random.uniform(size=[5]), + scale=np.random.uniform(size=[5]), + ) + self.assertEqual(tuple(output.shape), (2, 3, 3, 5)) + + def test_ctc_loss(self): + labels = np.array([[1, 2, 1], [1, 2, 2]]) + outputs = np.array( + [ + [[0.4, 0.8, 0.4], [0.2, 0.8, 0.3], [0.9, 0.4, 0.5]], + [[0.4, 0.8, 0.4], [0.2, 0.3, 0.3], [0.4, 0.3, 0.2]], + ] + ) + + label_length = np.array([3, 2]) + output_length = np.array([3, 2]) + + result = knn.ctc_loss(labels, outputs, label_length, output_length) + self.assertAllClose(result, np.array([3.4411672, 1.91680186])) + + def test_ctc_decode(self): + inputs = np.array( + [ + [ + [0.1, 0.4, 0.2, 0.4], + [0.3, -0.3, 0.4, 0.2], + [0.3, 0.2, 0.4, 0.3], + ], + [ + [0.7, 0.4, 0.3, 0.2], + [0.3, 0.3, 0.4, 0.1], + [0.6, -0.1, 0.1, 0.5], + ], + [ + [0.1, 0.4, 0.2, 0.7], + [0.3, 0.3, -0.2, 0.7], + [0.3, 0.2, 0.4, 0.1], + ], + ] + ) + labels = np.array([[1, 2, -1], [2, -1, -1], [3, -1, -1]]) + score_labels = np.array([[-1.2], [-1.7], [-0.7]]) + repeated_labels = np.array([[1, 2, 2], [2, -1, -1], [3, -1, -1]]) + + # Test strategy="greedy" and merge_repeated=True + (decoded,), scores = knn.ctc_decode( + inputs, + sequence_lengths=[3, 3, 1], + strategy="greedy", + mask_index=0, + ) + self.assertAllClose(decoded, labels) + self.assertAllClose(scores, score_labels) + + # Test strategy="greedy" and merge_repeated=False + (decoded,), scores = knn.ctc_decode( + inputs, + sequence_lengths=[3, 3, 1], + strategy="greedy", + merge_repeated=False, + mask_index=0, + ) + self.assertAllClose(decoded, repeated_labels) + self.assertAllClose(scores, score_labels) + + if backend.backend() == "torch": + self.skipTest("torch doesn't support 'beam_search' strategy") + + labels = np.array( + [ + [[1, 2, -1], [2, -1, -1], [3, -1, -1]], + [[2, -1, -1], [3, -1, -1], [1, -1, -1]], + ] + ) + score_labels = np.array( + [ + [-2.426537, -2.435596], + [-2.127681, -2.182338], + [-1.063386, -1.363386], + ] + ) + beam_width = 4 + top_paths = 2 + + # Test strategy="beam_search" + decoded, scores = knn.ctc_decode( + inputs, + sequence_lengths=[3, 3, 1], + strategy="beam_search", + beam_width=beam_width, + top_paths=top_paths, + mask_index=0, + ) + self.assertAllClose(decoded, labels) + self.assertAllClose(scores, score_labels) + + def test_normalize(self): + x = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.float32) + self.assertAllClose( + knn.normalize(x, axis=None), + [ + [0.18898225, 0.3779645, 0.56694674], + [0.18898225, 0.3779645, 0.56694674], + ], + ) + self.assertAllClose( + knn.normalize(x, axis=0), + [ + [0.70710677, 0.70710677, 0.70710677], + [0.70710677, 0.70710677, 0.70710677], + ], + ) + self.assertAllClose( + knn.normalize(x, axis=-1), + [ + [0.26726124, 0.53452247, 0.8017837], + [0.26726124, 0.53452247, 0.8017837], + ], + ) + self.assertAllClose( + knn.normalize(x, order=3), + [ + [0.30285344, 0.6057069, 0.9085603], + [0.30285344, 0.6057069, 0.9085603], + ], + ) + + # linalg.norm(x, ...) < epsilon + x = np.array([[1e-6, 1e-8]], dtype=np.float32) + self.assertAllClose( + knn.normalize(x, axis=-1, order=2, epsilon=1e-5), + [[1e-1, 1e-3]], + ) + + def test_psnr(self): + x1 = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]) + x2 = np.array([[0.2, 0.2, 0.3], [0.4, 0.6, 0.6]]) + max_val = 1.0 + expected_psnr_1 = 20 * np.log10(max_val) - 10 * np.log10( + np.mean(np.square(x1 - x2)) + ) + psnr_1 = knn.psnr(x1, x2, max_val) + self.assertAlmostEqual(psnr_1, expected_psnr_1) + + x3 = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]) + x4 = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]) + max_val = 1.0 + expected_psnr_2 = 20 * np.log10(max_val) - 10 * np.log10( + np.mean(np.square(x3 - x4)) + ) + psnr_2 = knn.psnr(x3, x4, max_val) + self.assertAlmostEqual(psnr_2, expected_psnr_2) + + @parameterized.named_parameters( + named_product( + bias=(None, True), + scale=(None, 1.0), + mask_and_is_causal=((None, False), (True, False), (None, True)), + flash_attention=(None, True, False), + ) + ) + def test_dot_product_attention( + self, bias, scale, mask_and_is_causal, flash_attention + ): + mask, is_causal = mask_and_is_causal + query_shape = (2, 3, 4, 8) + key_shape = (2, 3, 4, 8) + bias_shape = (2, 4, 3, 3) + query = np.arange(math.prod(query_shape), dtype=float).reshape( + query_shape + ) + key = np.arange(math.prod(key_shape), dtype=float).reshape(key_shape) + value = np.arange(math.prod(key_shape), dtype=float).reshape(key_shape) + if mask is not None: + mask = np.tril(np.ones((3, 3))).astype("bool") + mask = mask[None, None, ...] + mask = np.tile(mask, (2, 4, 1, 1)) + if bias is not None: + if backend.backend() == "torch": + self.skipTest( + "torch does not support `bias` with `dot_product_attention`" + ) + bias = np.arange(math.prod(bias_shape), dtype=float).reshape( + bias_shape + ) + + if flash_attention: + if backend.backend() in ("tensorflow", "numpy"): + self.skipTest( + "Flash attention is not supported in tensorflow and numpy " + "backends." + ) + elif backend.backend() == "torch": + import torch + + if mask is not None: + self.skipTest( + "Flash attention doesn't support `mask=None` in torch " + "backend." + ) + if not torch.cuda.is_available(): + self.skipTest( + "Flash attention must be run on CUDA in torch backend." + ) + cuda_compute_capability = tuple( + int(x) for x in torch.cuda.get_device_capability() + ) + if cuda_compute_capability < (8, 0): + self.skipTest( + "Flash attention must be run on CUDA compute " + "capability >= 8.0 in torch backend." + ) + elif backend.backend() == "jax": + import jax + from jax._src import xla_bridge + + if "cuda" not in xla_bridge.get_backend().platform_version: + self.skipTest( + "Flash attention must be run on CUDA in jax backend." + ) + d, *_ = jax.local_devices(backend="gpu") + cuda_compute_capability = tuple( + int(x) for x in d.compute_capability.split(".") + ) + if cuda_compute_capability < (8, 0): + self.skipTest( + "Flash attention must be run on CUDA compute " + "capability >= 8.0 in jax backend." + ) + + # Flash attention only supports float16 and bfloat16. We multiply + # 0.1 to avoid overflow. + query = (query * 0.1).astype("float16") + key = (key * 0.1).astype("float16") + value = (value * 0.1).astype("float16") + if bias is not None: + bias = (bias * 0.1).astype("float16") + + outputs = knn.dot_product_attention( + query, + key, + value, + bias=bias, + mask=mask, + scale=scale, + is_causal=is_causal, + flash_attention=flash_attention, + ) + + expected = _dot_product_attention( + query, + key, + value, + bias=bias, + mask=mask, + scale=scale, + is_causal=is_causal, + ) + self.assertAllClose( + outputs, expected, atol=1e-3 if flash_attention else 1e-6 + ) + + @parameterized.named_parameters(named_product(scale=(1.0, 10.0))) + def test_rms_normalization(self, scale): + x = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype="float32") + scale = np.array([scale] * x.shape[-1], dtype="float32") + expected_output = ( + np.array([[0.46291, 0.92582, 1.38873], [0.78954, 0.98693, 1.18431]]) + * scale + ) + + self.assertAllClose( + knn.rms_normalization(x, scale), expected_output, atol=1e-3 + ) + self.assertAllClose(knn.RMSNorm()(x, scale), expected_output, atol=1e-3) + + def test_layer_normalization(self): + x = np.arange(5, dtype="float32") + expected_output = np.array( + [-1.4142135, -0.70710677, 0.0, 0.7071067, 1.4142135] + ) + + self.assertAllClose( + knn.layer_normalization(x), expected_output, atol=1e-3 + ) + self.assertAllClose(knn.LayerNorm()(x), expected_output, atol=1e-3) + + +class NNOpsDtypeTest(testing.TestCase): + """Test the floating dtype to verify that the behavior matches JAX.""" + + FLOAT_DTYPES = [x for x in dtypes.FLOAT_TYPES if x not in ("float64",)] + + @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES)) + def test_elu(self, dtype): + import jax.nn as jnn + import jax.numpy as jnp + + x = knp.ones((), dtype=dtype) + x_jax = jnp.ones((), dtype=dtype) + expected_dtype = standardize_dtype(jnn.elu(x_jax).dtype) + + self.assertEqual( + standardize_dtype(knn.elu(x).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype(knn.Elu().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES)) + def test_gelu(self, dtype): + import jax.nn as jnn + import jax.numpy as jnp + + x = knp.ones((), dtype=dtype) + x_jax = jnp.ones((), dtype=dtype) + + # approximate = True + expected_dtype = standardize_dtype(jnn.gelu(x_jax).dtype) + + self.assertEqual( + standardize_dtype(knn.gelu(x).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype(knn.Gelu().symbolic_call(x).dtype), + expected_dtype, + ) + # approximate = False + expected_dtype = standardize_dtype(jnn.gelu(x_jax, False).dtype) + + self.assertEqual( + standardize_dtype(knn.gelu(x, False).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype(knn.Gelu(False).symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES)) + def test_celu(self, dtype): + import jax.nn as jnn + import jax.numpy as jnp + + x = knp.ones((), dtype=dtype) + x_jax = jnp.ones((), dtype=dtype) + expected_dtype = standardize_dtype(jnn.celu(x_jax).dtype) + + self.assertEqual( + standardize_dtype(knn.celu(x).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype(knn.Celu().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES)) + def test_tanh_shrink(self, dtype): + import torch + import torch.nn.functional as tnn + + x = knp.ones((1), dtype=dtype) + x_torch = torch.ones(1, dtype=getattr(torch, dtype)) + expected_dtype = standardize_dtype(tnn.tanhshrink(x_torch).dtype) + + self.assertEqual( + standardize_dtype(knn.tanh_shrink(x).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype(knn.TanhShrink().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES)) + def test_hard_tanh(self, dtype): + import jax.nn as jnn + import jax.numpy as jnp + + x = knp.ones((), dtype=dtype) + x_jax = jnp.ones((), dtype=dtype) + expected_dtype = standardize_dtype(jnn.hard_tanh(x_jax).dtype) + + self.assertEqual( + standardize_dtype(knn.hard_tanh(x).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype(knn.HardTanh().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES)) + def test_hard_shrink(self, dtype): + import torch + import torch.nn.functional as tnn + + x = knp.ones((1), dtype=dtype) + x_torch = torch.ones(1, dtype=getattr(torch, dtype)) + expected_dtype = standardize_dtype(tnn.hardshrink(x_torch).dtype) + + self.assertEqual( + standardize_dtype(knn.hard_shrink(x).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype(knn.HardShrink().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES)) + def test_threshold(self, dtype): + import torch + import torch.nn.functional as tnn + + x = knp.ones((1), dtype=dtype) + x_torch = torch.ones(1, dtype=getattr(torch, dtype)) + expected_dtype = standardize_dtype(tnn.threshold(x_torch, 0, 0).dtype) + + self.assertEqual( + standardize_dtype(knn.threshold(x, 0, 0).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype(knn.Threshold(0, 0).symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES)) + def test_soft_shrink(self, dtype): + import torch + import torch.nn.functional as tnn + + x = knp.ones((1), dtype=dtype) + x_torch = torch.ones(1, dtype=getattr(torch, dtype)) + expected_dtype = standardize_dtype(tnn.softshrink(x_torch).dtype) + + self.assertEqual( + standardize_dtype(knn.soft_shrink(x).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype(knn.SoftShrink().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES)) + def test_sparse_plus(self, dtype): + import jax.nn as jnn + import jax.numpy as jnp + + x = knp.ones((), dtype=dtype) + x_jax = jnp.ones((), dtype=dtype) + expected_dtype = standardize_dtype(jnn.sparse_plus(x_jax).dtype) + + self.assertEqual( + standardize_dtype(knn.sparse_plus(x).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype(knn.SparsePlus().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES)) + def test_glu(self, dtype): + import jax.nn as jnn + import jax.numpy as jnp + + x = knp.ones((2), dtype=dtype) + x_jax = jnp.ones((2), dtype=dtype) + expected_dtype = standardize_dtype(jnn.glu(x_jax).dtype) + + self.assertEqual( + standardize_dtype(knn.glu(x).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype(knn.Glu().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES)) + def test_squareplus(self, dtype): + import jax.nn as jnn + import jax.numpy as jnp + + if dtype == "bfloat16": + self.skipTest("Weirdness with numpy") + + x = knp.ones((2), dtype=dtype) + x_jax = jnp.ones((2), dtype=dtype) + expected_dtype = standardize_dtype(jnn.squareplus(x_jax).dtype) + + self.assertEqual( + standardize_dtype(knn.squareplus(x).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype(knn.Squareplus().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES)) + def test_hard_sigmoid(self, dtype): + import jax.nn as jnn + import jax.numpy as jnp + + x = knp.ones((), dtype=dtype) + x_jax = jnp.ones((), dtype=dtype) + expected_dtype = standardize_dtype(jnn.hard_sigmoid(x_jax).dtype) + + self.assertEqual( + standardize_dtype(knn.hard_sigmoid(x).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype(knn.HardSigmoid().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES)) + def test_hard_silu(self, dtype): + import jax.nn as jnn + import jax.numpy as jnp + + x = knp.ones((), dtype=dtype) + x_jax = jnp.ones((), dtype=dtype) + expected_dtype = standardize_dtype(jnn.hard_silu(x_jax).dtype) + + self.assertEqual( + standardize_dtype(knn.hard_silu(x).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype(knn.HardSilu().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES)) + def test_leaky_relu(self, dtype): + import jax.nn as jnn + import jax.numpy as jnp + + x = knp.ones((), dtype=dtype) + x_jax = jnp.ones((), dtype=dtype) + expected_dtype = standardize_dtype(jnn.leaky_relu(x_jax).dtype) + + self.assertEqual( + standardize_dtype(knn.leaky_relu(x).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype(knn.LeakyRelu().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES)) + def test_log_sigmoid(self, dtype): + import jax.nn as jnn + import jax.numpy as jnp + + x = knp.ones((), dtype=dtype) + x_jax = jnp.ones((), dtype=dtype) + expected_dtype = standardize_dtype(jnn.log_sigmoid(x_jax).dtype) + + self.assertEqual( + standardize_dtype(knn.log_sigmoid(x).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype(knn.LogSigmoid().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES)) + def test_log_softmax(self, dtype): + import jax.nn as jnn + import jax.numpy as jnp + + x = knp.ones((10,), dtype=dtype) + x_jax = jnp.ones((10,), dtype=dtype) + expected_dtype = standardize_dtype(jnn.log_softmax(x_jax).dtype) + + self.assertEqual( + standardize_dtype(knn.log_softmax(x).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype(knn.LogSoftmax().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES)) + def test_relu(self, dtype): + import jax.nn as jnn + import jax.numpy as jnp + + x = knp.ones((), dtype=dtype) + x_jax = jnp.ones((), dtype=dtype) + expected_dtype = standardize_dtype(jnn.relu(x_jax).dtype) + + self.assertEqual( + standardize_dtype(knn.relu(x).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype(knn.Relu().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES)) + def test_relu6(self, dtype): + import jax.nn as jnn + import jax.numpy as jnp + + x = knp.ones((), dtype=dtype) + x_jax = jnp.ones((), dtype=dtype) + expected_dtype = standardize_dtype(jnn.relu6(x_jax).dtype) + + self.assertEqual( + standardize_dtype(knn.relu6(x).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype(knn.Relu6().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES)) + def test_selu(self, dtype): + import jax.nn as jnn + import jax.numpy as jnp + + x = knp.ones((), dtype=dtype) + x_jax = jnp.ones((), dtype=dtype) + expected_dtype = standardize_dtype(jnn.selu(x_jax).dtype) + + self.assertEqual( + standardize_dtype(knn.selu(x).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype(knn.Selu().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES)) + def test_sigmoid(self, dtype): + import jax.nn as jnn + import jax.numpy as jnp + + x = knp.ones((), dtype=dtype) + x_jax = jnp.ones((), dtype=dtype) + expected_dtype = standardize_dtype(jnn.sigmoid(x_jax).dtype) + + self.assertEqual( + standardize_dtype(knn.sigmoid(x).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype(knn.Sigmoid().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES)) + def test_sparse_sigmoid(self, dtype): + import jax.nn as jnn + import jax.numpy as jnp + + x = knp.ones((), dtype=dtype) + x_jax = jnp.ones((), dtype=dtype) + expected_dtype = standardize_dtype(jnn.sparse_sigmoid(x_jax).dtype) + + self.assertEqual( + standardize_dtype(knn.sparse_sigmoid(x).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype(knn.SparseSigmoid().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES)) + def test_silu(self, dtype): + import jax.nn as jnn + import jax.numpy as jnp + + x = knp.ones((), dtype=dtype) + x_jax = jnp.ones((), dtype=dtype) + expected_dtype = standardize_dtype(jnn.silu(x_jax).dtype) + + self.assertEqual( + standardize_dtype(knn.silu(x).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype(knn.Silu().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES)) + def test_softplus(self, dtype): + import jax.nn as jnn + import jax.numpy as jnp + + x = knp.ones((), dtype=dtype) + x_jax = jnp.ones((), dtype=dtype) + expected_dtype = standardize_dtype(jnn.softplus(x_jax).dtype) + + self.assertEqual( + standardize_dtype(knn.softplus(x).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype(knn.Softplus().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES)) + def test_softmax(self, dtype): + import jax.nn as jnn + import jax.numpy as jnp + + x = knp.ones((10,), dtype=dtype) + x_jax = jnp.ones((10,), dtype=dtype) + expected_dtype = standardize_dtype(jnn.softmax(x_jax).dtype) + + self.assertEqual( + standardize_dtype(knn.softmax(x).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype(knn.Softmax().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES)) + def test_softsign(self, dtype): + import jax.nn as jnn + import jax.numpy as jnp + + x = knp.ones((), dtype=dtype) + x_jax = jnp.ones((), dtype=dtype) + expected_dtype = standardize_dtype(jnn.soft_sign(x_jax).dtype) + + self.assertEqual( + standardize_dtype(knn.softsign(x).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype(knn.Softsign().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES)) + def test_polar(self, dtype): + import jax.nn as jnn + import jax.numpy as jnp + + x = knp.ones((), dtype=dtype) + x_jax = jnp.ones((), dtype=dtype) + expected_dtype = standardize_dtype(jnn.hard_tanh(x_jax).dtype) + + self.assertEqual( + standardize_dtype(knn.hard_tanh(x).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype(knn.HardTanh().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES)) + def test_ctc_loss(self, dtype): + labels = knp.array([[1, 2, 1]], dtype="int32") + outputs = knp.array( + [[[0.4, 0.8, 0.4], [0.2, 0.8, 0.3], [0.9, 0.4, 0.5]]], dtype=dtype + ) + label_length = knp.array([3]) + output_length = knp.array([3]) + expected_dtype = ( + "float32" if dtype in ("float16", "bfloat16") else dtype + ) + + self.assertEqual( + standardize_dtype( + knn.ctc_loss(labels, outputs, label_length, output_length).dtype + ), + expected_dtype, + ) + self.assertEqual( + standardize_dtype( + knn.CTCLoss() + .symbolic_call(labels, outputs, label_length, output_length) + .dtype + ), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES)) + def test_ctc_decode(self, dtype): + inputs = knp.array( + [[[0.4, 0.8, 0.4], [0.2, 0.8, 0.3], [0.9, 0.4, 0.5]]], dtype=dtype + ) + sequence_length = knp.array([3]) + expected_dtype = backend.result_type(dtype, "float32") + + # Test strategy="greedy" + decoded, scores = knn.ctc_decode( + inputs, sequence_length, strategy="greedy" + ) + self.assertEqual(standardize_dtype(decoded.dtype), "int32") + self.assertEqual(standardize_dtype(scores.dtype), expected_dtype) + decoded, scores = knn.CTCDecode(strategy="greedy").symbolic_call( + inputs, sequence_length + ) + self.assertEqual(standardize_dtype(decoded.dtype), "int32") + self.assertEqual(standardize_dtype(scores.dtype), expected_dtype) + + if backend.backend() == "torch": + self.skipTest("torch doesn't support 'beam_search' strategy") + + # Test strategy="beam_search" + decoded, scores = knn.ctc_decode( + inputs, sequence_length, strategy="beam_search" + ) + self.assertEqual(standardize_dtype(decoded.dtype), "int32") + self.assertEqual(standardize_dtype(scores.dtype), expected_dtype) + decoded, scores = knn.CTCDecode(strategy="beam_search").symbolic_call( + inputs, sequence_length + ) + self.assertEqual(standardize_dtype(decoded.dtype), "int32") + self.assertEqual(standardize_dtype(scores.dtype), expected_dtype) + + @parameterized.named_parameters( + named_product( + dtypes=list(combinations(FLOAT_DTYPES, 2)) + + [(dtype, dtype) for dtype in FLOAT_DTYPES] + ) + ) + def test_dot_product_attention(self, dtypes): + # TODO: Get expected output from jax if `jax.nn.dot_product_attention` + # is available. + query_dtype, key_value_dtype = dtypes + query = knp.ones((2, 3, 3, 8), dtype=query_dtype) + key = knp.ones((2, 3, 3, 8), dtype=key_value_dtype) + value = knp.ones((2, 3, 3, 8), dtype=key_value_dtype) + expected_dtype = backend.result_type(*dtypes) + + self.assertDType( + knn.dot_product_attention(query, key, value), expected_dtype + ) + self.assertDType( + knn.DotProductAttention().symbolic_call(query, key, value), + expected_dtype, + ) + + @parameterized.named_parameters( + named_product(dtypes=combinations(FLOAT_DTYPES, 2)) + ) + def test_rms_normalization(self, dtypes): + input_dtype, weight_dtype = dtypes + inputs = knp.ones((2, 8), dtype=input_dtype) + scale = backend.Variable(knp.ones((8,), dtype=weight_dtype)) + expected_dtype = input_dtype + + self.assertDType(knn.rms_normalization(inputs, scale), expected_dtype) + self.assertDType( + knn.RMSNorm().symbolic_call(inputs, scale), expected_dtype + ) + + @parameterized.named_parameters( + named_product(dtypes=combinations(FLOAT_DTYPES, 2)) + ) + def test_layer_normalization(self, dtypes): + input_dtype, weight_dtype = dtypes + inputs = knp.ones((2, 8), dtype=input_dtype) + gamma = backend.Variable(knp.ones((8,), dtype=weight_dtype)) + beta = backend.Variable(knp.ones((8,), dtype=weight_dtype)) + expected_dtype = input_dtype + + self.assertDType( + knn.layer_normalization(inputs, gamma, beta), expected_dtype + ) + self.assertDType( + knn.LayerNorm().symbolic_call(inputs, gamma, beta), expected_dtype + ) + + +class NNOpsBehaviorTest(testing.TestCase): + def test_logit_recovery_binary_crossentropy(self): + layer = layers.Dense( + 4, activation="sigmoid", use_bias=False, kernel_initializer="ones" + ) + loss = losses.BinaryCrossentropy() + x = np.array([[1.4, 1.6, 0.8]]) + y = np.array([[0.2, 0.6, 0.1, 0.3]]) + loss_value = loss(y, layer(x)) + self.assertAllClose(loss_value, 2.682124) + + model = models.Sequential([layer]) + model.compile(loss="binary_crossentropy", optimizer="sgd") + out = model.evaluate(x, y) + self.assertAllClose(out, 2.682124) + + def test_softmax_on_axis_with_size_one_warns(self): + x = np.array([[1.0]]) + # Applying softmax on the second axis, which has size 1 + axis = 1 + + # Expected warning message + expected_warning_regex = ( + r"You are using a softmax over axis 1 " + r"of a tensor of shape \(1, 1\)\. This axis " + r"has size 1\. The softmax operation will always return " + r"the value 1, which is likely not what you intended\. " + r"Did you mean to use a sigmoid instead\?" + ) + + with self.assertWarnsRegex(UserWarning, expected_warning_regex): + knn.softmax(x, axis) + + def test_normalize_order_validation(self): + # Test with a non-integer order + with self.assertRaisesRegex( + ValueError, "Argument `order` must be an int >= 1" + ): + knn.normalize(np.array([1, 2, 3]), order="a") + + # Test with a negative integer + with self.assertRaisesRegex( + ValueError, "Argument `order` must be an int >= 1" + ): + knn.normalize(np.array([1, 2, 3]), order=-1) + + # Test with zero + with self.assertRaisesRegex( + ValueError, "Argument `order` must be an int >= 1" + ): + knn.normalize(np.array([1, 2, 3]), order=0) + + # Test with a floating-point number + with self.assertRaisesRegex( + ValueError, "Argument `order` must be an int >= 1" + ): + knn.normalize(np.array([1, 2, 3]), order=2.5) + + def test_check_shape_first_dim_mismatch(self): + name1, shape1 = "labels", (2, 3) + name2, shape2 = "logits", (3, 4, 5) + ctc_loss_instance = knn.CTCLoss(mask_index=-1) + with self.assertRaisesRegex( + ValueError, "must have the same first dimension" + ): + ctc_loss_instance._check_shape_first_dim( + name1, shape1, name2, shape2 + ) + + def test_invalid_strategy_ctc_decode(self): + inputs = np.array( + [ + [ + [0.1, 0.4, 0.2, 0.4], + [0.3, 0.3, 0.4, 0.2], + [0.3, 0.2, 0.4, 0.3], + ] + ] + ) + beam_width = 4 + top_paths = 2 + with self.assertRaisesRegex(ValueError, "Invalid strategy"): + knn.ctc_decode( + inputs, + sequence_lengths=[3, 3, 1], + strategy="invalid", + beam_width=beam_width, + top_paths=top_paths, + ) + + def test_layer_normalization_rms_scaling_warning(self): + x = np.arange(5, dtype="float32") + with self.assertWarnsRegex( + UserWarning, r"You passed `rms_scaling=True`, which is deprecated" + ): + knn.layer_normalization(x, rms_scaling=True) + + def test_unfold(self): + if keras.config.backend() in ["openvino"]: + pytest.skip("Backend does not support unfold operation") + # test 1 kernel_size=2 + x = ops.arange(8, dtype="float32") + x = ops.reshape(x, [1, 1, 2, 4]) + unfold_result = knn.unfold(x, 2) + except_result = ops.convert_to_tensor( + [ + [ + [0.0, 1.0, 2.0], + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + [5.0, 6.0, 7.0], + ] + ] + ) + self.assertAllClose(unfold_result, except_result) + + # test 2 kernel_size=[2,4] + x = ops.arange(16, dtype="float32") + x = ops.reshape(x, [1, 1, 4, 4]) + unfold_result = knn.unfold(x, [2, 4]) + except_result = ops.convert_to_tensor( + [ + [ + [0.0, 4.0, 8.0], + [1.0, 5.0, 9.0], + [2.0, 6.0, 10.0], + [3.0, 7.0, 11.0], + [4.0, 8.0, 12.0], + [5.0, 9.0, 13.0], + [6.0, 10.0, 14.0], + [7.0, 11.0, 15.0], + ] + ], + dtype="float32", + ) + self.assertAllClose(unfold_result, except_result) + + # test 3 kernel_size=[3,2],stride=[3,2] + x = ops.arange(12, dtype="float32") + x = ops.reshape(x, [1, 1, 3, 4]) + unfold_result = knn.unfold(x, [3, 2], stride=[3, 2]) + except_result = ops.convert_to_tensor( + [ + [ + [0.0, 2.0], + [1.0, 3.0], + [4.0, 6.0], + [5.0, 7.0], + [8.0, 10.0], + [9.0, 11.0], + ] + ] + ) + self.assertAllClose(unfold_result, except_result) + + # test 4 kernel_size=2,dilation=2,stride=2 + x = ops.arange(16, dtype="float32") + x = ops.reshape(x, [1, 1, 4, 4]) + unfold_result = knn.unfold(x, 2, 2, stride=2) + except_result = ops.convert_to_tensor([0, 2, 8, 10], dtype="float32") + except_result = ops.reshape(except_result, [1, 4, 1]) + self.assertAllClose(unfold_result, except_result) + + # test 5 kernel_size=2,padding=1 + x = ops.arange(4, dtype="float32") + x = ops.reshape(x, [1, 1, 2, 2]) + unfold_result = knn.unfold(x, 1, padding=1) + except_result = ops.convert_to_tensor( + [ + [ + [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 1.0, + 0.0, + 0.0, + 2.0, + 3.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ] + ] + ] + ) + self.assertAllClose(unfold_result, except_result) + + # test 6 multi channal and kernel_size=2 + x = ops.arange(8, dtype="float32") + x = ops.reshape(x, [1, 2, 2, 2]) + unfold_result = knn.unfold(x, 2) + except_result = ops.convert_to_tensor( + [[[0.0], [1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0]]] + ) + self.assertAllClose(unfold_result, except_result) + + # test 7 multi channal and kernel_size=[2,3] + x = ops.arange(12, dtype="float32") + x = ops.reshape(x, [1, 2, 2, 3]) + unfold_result = knn.unfold(x, [2, 3]) + except_result = ops.convert_to_tensor( + [ + [ + [0.0], + [1.0], + [2.0], + [3.0], + [4.0], + [5.0], + [6.0], + [7.0], + [8.0], + [9.0], + [10.0], + [11.0], + ] + ] + ) + self.assertAllClose(unfold_result, except_result) + + # test 8 multi channal and kernel_size=[2,3],stride=[2,3] + x = ops.arange(12, dtype="float32") + x = ops.reshape(x, [1, 2, 2, 3]) + unfold_result = knn.unfold(x, [2, 3], stride=[2, 3]) + except_result = ops.convert_to_tensor( + [ + [ + [0.0], + [1.0], + [2.0], + [3.0], + [4.0], + [5.0], + [6.0], + [7.0], + [8.0], + [9.0], + [10.0], + [11.0], + ] + ] + ) + self.assertAllClose(unfold_result, except_result) + + # test 9 multi channal and kernel_size=2,dilation=2 + x = ops.arange(32, dtype="float32") + x = ops.reshape(x, [1, 2, 4, 4]) + unfold_result = knn.unfold(x, 2, dilation=2) + except_result = ops.convert_to_tensor( + [ + [ + [0.0, 1.0, 4.0, 5.0], + [2.0, 3.0, 6.0, 7.0], + [8.0, 9.0, 12.0, 13.0], + [10.0, 11.0, 14.0, 15.0], + [16.0, 17.0, 20.0, 21.0], + [18.0, 19.0, 22.0, 23.0], + [24.0, 25.0, 28.0, 29.0], + [26.0, 27.0, 30.0, 31.0], + ] + ] + ) + self.assertAllClose(unfold_result, except_result) + + # test 10 multi channal and kernel_size=2,padding=1 + x = ops.arange(8, dtype="float32") + x = ops.reshape(x, [1, 2, 2, 2]) + unfold_result = knn.unfold(x, 2, padding=1) + except_result = ops.convert_to_tensor( + [ + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 2.0, 3.0], + [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 2.0, 3.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 2.0, 3.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 2.0, 3.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 4.0, 5.0, 0.0, 6.0, 7.0], + [0.0, 0.0, 0.0, 4.0, 5.0, 0.0, 6.0, 7.0, 0.0], + [0.0, 4.0, 5.0, 0.0, 6.0, 7.0, 0.0, 0.0, 0.0], + [4.0, 5.0, 0.0, 6.0, 7.0, 0.0, 0.0, 0.0, 0.0], + ] + ] + ) + self.assertAllClose(unfold_result, except_result) diff --git a/keras/src/ops/node.py b/keras/src/ops/node.py new file mode 100644 index 000000000000..7e05de88fcf5 --- /dev/null +++ b/keras/src/ops/node.py @@ -0,0 +1,143 @@ +import collections + +from keras.src import tree +from keras.src.backend import KerasTensor +from keras.src.ops.symbolic_arguments import SymbolicArguments + + +class Node: + """A `Node` describes an operation `__call__()` event. + + A Keras Function is a DAG with `Node` instances as nodes, and + `KerasTensor` instances as edges. Nodes aren't `Operation` instances, + because a single operation could be called multiple times, which would + result in graph cycles. + + A `__call__()` event involves input tensors (and other input arguments), + the operation that was called, and the resulting output tensors. + A `Node` will include all this information. + + Since a single `Operation` could be called multiple times, + the `Node` instances are stored on operations as a list. + Each time an operation is called, a node is added to `op._inbound_nodes`. + Each time the output of an operation is used by another operation, + a node is added to `op._outbound_nodes`. + + Every `KerasTensor` instance has a `KerasHistory` object attached, + which tracks the `Node` that records the `__call__()` event that created + the tensor. By recursively walking through `Node` instances + via the `KerasHistory` metadata of `KerasTensor` instances, once can + retrieve the entire DAG of a Keras Function. + + Args: + operation: The Operation that was called in the `op.__call__()` + event that this node represents. + call_args: The positional arguments the operation was called with. + call_kwargs: The keyword arguments the operation was called with. + outputs: The output tensors of the `op.__call__()` call. + """ + + def __init__( + self, operation, call_args=None, call_kwargs=None, outputs=None + ): + self.operation = operation + self.arguments = SymbolicArguments(*call_args, **call_kwargs) + self.outputs = [] if outputs is None else tree.flatten(outputs) + for x in self.outputs: + if not isinstance(x, KerasTensor): + raise ValueError( + "All operation outputs must be tensors. " + f"Operation {operation} returned a non-tensor. " + f"Non-tensor received: {x}" + ) + + zero_history = any( + not x.record_history for x in self.arguments.keras_tensors + ) + + # If inputs don't have metadata yet, add it. + if not zero_history: + for tensor in self.arguments.keras_tensors: + if not hasattr(tensor, "_keras_history"): + tensor._keras_history = KerasHistory( + operation=None, node_index=0, tensor_index=0 + ) + + # Wire up Node to Operations. + self.operation._inbound_nodes.append(self) + for kt in self.arguments.keras_tensors: + inbound_op = kt._keras_history.operation + if inbound_op is not None: # It's a graph entry point. + inbound_op._outbound_nodes.append(self) + + # Set metadata on outputs. + if not zero_history: + node_index = len(self.operation._inbound_nodes) - 1 + for i, tensor in enumerate(self.outputs): + tensor._keras_history = KerasHistory( + operation=operation, node_index=node_index, tensor_index=i + ) + + # Whether this is a root node. + self.is_input = not self.arguments.keras_tensors + + def __repr__(self): + return f"" + + @property + def input_tensors(self): + return self.arguments.keras_tensors + + @property + def output_tensors(self): + return self.outputs + + @property + def parent_nodes(self): + """The parent `Node`s. + + Returns: + all the `Node`s whose output this node immediately depends on. + """ + node_deps = [] + for kt in self.arguments.keras_tensors: + op = kt._keras_history.operation + node_index = kt._keras_history.node_index + if op is not None: # `None` for `Input` tensors. + node_deps.append(op._inbound_nodes[node_index]) + return node_deps + + +class KerasHistory( + collections.namedtuple( + "KerasHistory", ["operation", "node_index", "tensor_index"] + ) +): + """Tracks the Operation call that created a Tensor. + + During construction of Keras Functions, this metadata is added to + each Tensor produced as the output of an Operation. + This allows Keras to track how each Tensor was produced, and + this information is later retraced by the `Function` class to + reconstruct the Operations graph. + + Attributes: + operation: The Operation instance that produced the Tensor. + node_index: The specific call to the Operation that produced this Tensor. + Operations can be called multiple times in order to share weights. A new + node is created every time an Operation is called. The corresponding + node that represents the call event that produced the Tensor can be + found at `op._inbound_nodes[node_index]`. + tensor_index: The output index for this Tensor. + Always zero if the Operation that produced this Tensor + only has one output. Nested structures of + Tensors are deterministically assigned an index via `nest.flatten`. + """ + + # Added to maintain memory and performance characteristics of `namedtuple` + # while subclassing. + __slots__ = () + + +def is_keras_tensor(obj): + return hasattr(obj, "_keras_history") diff --git a/keras/src/ops/node_test.py b/keras/src/ops/node_test.py new file mode 100644 index 000000000000..7ed8227b3c2f --- /dev/null +++ b/keras/src/ops/node_test.py @@ -0,0 +1,66 @@ +import numpy as np + +from keras.src import Layer +from keras.src import testing +from keras.src.backend import KerasTensor +from keras.src.ops.node import Node + + +class DummyLayer(Layer): + pass + + +class NodeTest(testing.TestCase): + # Testing a simple node and layer combination **a** + def test_simple_case(self): + shape = (2, 3, 4) + a = KerasTensor(shape=shape) + a_layer = DummyLayer() + node = Node(a_layer, outputs=a, call_args=(), call_kwargs={}) + + self.assertEqual(node.is_input, True) + + self.assertEqual(node.output_tensors[0], a) + self.assertEqual(node.output_tensors[0].shape, shape) + + # Testing a simple node connection with args and kwargs **a** --> **b** + def test_single_wired_layers(self): + shape = (2, 3, 4) + a = KerasTensor(shape=shape) + a_layer = DummyLayer() + node1 = Node(a_layer, outputs=a, call_args=(), call_kwargs={}) + + b = KerasTensor(shape=shape) + x = KerasTensor(shape=shape) + kwargs = {"x": x} + args = (a,) + b_layer = DummyLayer() + node2 = Node(b_layer, outputs=b, call_args=args, call_kwargs=kwargs) + + self.assertEqual(node1.is_input, True) + self.assertEqual(node2.is_input, False) + + self.assertEqual(node1.operation, a_layer) + self.assertEqual(node2.operation, b_layer) + + self.assertEqual(node1.output_tensors[0], a) + self.assertEqual(node1.output_tensors[0].shape, shape) + + self.assertEqual(a_layer._inbound_nodes[0], node1) + self.assertEqual(a_layer._outbound_nodes[0], node2) + + self.assertEqual(b_layer._inbound_nodes[0], node2) + self.assertEqual(node2.parent_nodes[0], node1) + + self.assertEqual(node2.input_tensors, [a, x]) + self.assertEqual(node2.arguments.kwargs, kwargs) + self.assertEqual(node2.arguments.args, args) + + # Testing when output tensor is not Keras Tensor + def test_output_tensor_error(self): + a = np.random.rand(2, 3, 4) + a_layer = DummyLayer() + with self.assertRaisesRegex( + ValueError, "operation outputs must be tensors." + ): + Node(a_layer, outputs=a, call_args=(), call_kwargs={}) diff --git a/keras/src/ops/numpy.py b/keras/src/ops/numpy.py new file mode 100644 index 000000000000..cbc07c9c3e3c --- /dev/null +++ b/keras/src/ops/numpy.py @@ -0,0 +1,7613 @@ +import builtins +import re + +import numpy as np + +from keras.src import backend +from keras.src.api_export import keras_export +from keras.src.backend import KerasTensor +from keras.src.backend import any_symbolic_tensors +from keras.src.backend.common import dtypes +from keras.src.backend.common.backend_utils import canonicalize_axis +from keras.src.backend.common.backend_utils import to_tuple_or_list +from keras.src.ops import operation_utils +from keras.src.ops.operation import Operation +from keras.src.ops.operation_utils import broadcast_shapes +from keras.src.ops.operation_utils import reduce_shape + + +class Rot90(Operation): + def __init__(self, k=1, axes=(0, 1), *, name=None): + super().__init__(name=name) + self.k = k + self.axes = axes + + def call(self, array): + return backend.numpy.rot90(array, k=self.k, axes=self.axes) + + def compute_output_spec(self, array): + array_shape = list(array.shape) + if len(array_shape) < 2: + raise ValueError( + "Input array must have at least 2 dimensions. " + f"Received: array.shape={array_shape}" + ) + if len(self.axes) != 2 or self.axes[0] == self.axes[1]: + raise ValueError( + f"Invalid axes: {self.axes}. " + "Axes must be a tuple of two different dimensions." + ) + axis1, axis2 = self.axes + array_shape[axis1], array_shape[axis2] = ( + array_shape[axis2], + array_shape[axis1], + ) + return KerasTensor(shape=array_shape, dtype=array.dtype) + + +@keras_export(["keras.ops.rot90", "keras.ops.numpy.rot90"]) +def rot90(array, k=1, axes=(0, 1)): + """Rotate an array by 90 degrees in the plane specified by axes. + + This function rotates an array counterclockwise + by 90 degrees `k` times in the plane specified by `axes`. + Supports arrays of two or more dimensions. + + Args: + array: Input array to rotate. + k: Number of times the array is rotated by 90 degrees. + axes: A tuple of two integers specifying the + plane of rotation (defaults to `(0, 1)`). + + Returns: + Rotated array. + + Examples: + + >>> import numpy as np + >>> from keras import ops + >>> m = np.array([[1, 2], [3, 4]]) + >>> rotated = ops.rot90(m) + >>> rotated + array([[2, 4], + [1, 3]]) + + >>> m = np.arange(8).reshape((2, 2, 2)) + >>> rotated = ops.rot90(m, k=1, axes=(1, 2)) + >>> rotated + array([[[1, 3], + [0, 2]], + [[5, 7], + [4, 6]]]) + """ + if any_symbolic_tensors((array,)): + return Rot90(k=k, axes=axes).symbolic_call(array) + return backend.numpy.rot90(array, k=k, axes=axes) + + +def shape_equal(shape1, shape2, axis=None, allow_none=True): + """Check if two shapes are equal. + + Args: + shape1: A list or tuple of integers for first shape to be compared. + shape2: A list or tuple of integers for second shape to be compared. + axis: An integer, list, or tuple of integers (optional): + Axes to ignore during comparison. Defaults to `None`. + allow_none (bool, optional): If `True`, allows `None` in a shape + to match any value in the corresponding position of the other shape. + Defaults to `True`. + + Returns: + bool: `True` if shapes are considered equal based on the criteria, + `False` otherwise. + + Examples: + + >>> shape_equal((32, 64, 128), (32, 64, 128)) + True + >>> shape_equal((32, 64, 128), (32, 64, 127)) + False + >>> shape_equal((32, 64, None), (32, 64, 128), allow_none=True) + True + >>> shape_equal((32, 64, None), (32, 64, 128), allow_none=False) + False + >>> shape_equal((32, 64, 128), (32, 63, 128), axis=1) + True + >>> shape_equal((32, 64, 128), (32, 63, 127), axis=(1, 2)) + True + >>> shape_equal((32, 64, 128), (32, 63, 127), axis=[1,2]) + True + >>> shape_equal((32, 64), (32, 64, 128)) + False + """ + if len(shape1) != len(shape2): + return False + + shape1 = list(shape1) + shape2 = list(shape2) + + if axis is not None: + if isinstance(axis, int): + axis = [axis] + for ax in axis: + shape1[ax] = -1 + shape2[ax] = -1 + + if allow_none: + for i in range(len(shape1)): + if shape1[i] is None: + shape1[i] = shape2[i] + if shape2[i] is None: + shape2[i] = shape1[i] + + return shape1 == shape2 + + +class Absolute(Operation): + def call(self, x): + return backend.numpy.absolute(x) + + def compute_output_spec(self, x): + sparse = getattr(x, "sparse", False) + return KerasTensor(x.shape, dtype=x.dtype, sparse=sparse) + + +@keras_export(["keras.ops.absolute", "keras.ops.numpy.absolute"]) +def absolute(x): + """Compute the absolute value element-wise. + + `keras.ops.abs` is a shorthand for this function. + + Args: + x: Input tensor. + + Returns: + An array containing the absolute value of each element in `x`. + + Example: + + >>> x = keras.ops.convert_to_tensor([-1.2, 1.2]) + >>> keras.ops.absolute(x) + array([1.2, 1.2], dtype=float32) + """ + if any_symbolic_tensors((x,)): + return Absolute().symbolic_call(x) + return backend.numpy.absolute(x) + + +class Abs(Absolute): + pass + + +@keras_export(["keras.ops.abs", "keras.ops.numpy.abs"]) +def abs(x): + """Shorthand for `keras.ops.absolute`.""" + return absolute(x) + + +class Add(Operation): + def call(self, x1, x2): + return backend.numpy.add(x1, x2) + + def compute_output_spec(self, x1, x2): + x1_shape = getattr(x1, "shape", []) + x2_shape = getattr(x2, "shape", []) + output_shape = broadcast_shapes(x1_shape, x2_shape) + output_dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + ) + x1_sparse = getattr(x1, "sparse", False) + x2_sparse = getattr(x2, "sparse", False) + output_sparse = x1_sparse and x2_sparse + return KerasTensor( + output_shape, dtype=output_dtype, sparse=output_sparse + ) + + +@keras_export(["keras.ops.add", "keras.ops.numpy.add"]) +def add(x1, x2): + """Add arguments element-wise. + + Args: + x1: First input tensor. + x2: Second input tensor. + + Returns: + The tensor containing the element-wise sum of `x1` and `x2`. + + Examples: + >>> x1 = keras.ops.convert_to_tensor([1, 4]) + >>> x2 = keras.ops.convert_to_tensor([5, 6]) + >>> keras.ops.add(x1, x2) + array([6, 10], dtype=int32) + + `keras.ops.add` also broadcasts shapes: + >>> x1 = keras.ops.convert_to_tensor( + ... [[5, 4], + ... [5, 6]] + ... ) + >>> x2 = keras.ops.convert_to_tensor([5, 6]) + >>> keras.ops.add(x1, x2) + array([[10 10] + [10 12]], shape=(2, 2), dtype=int32) + """ + if any_symbolic_tensors((x1, x2)): + return Add().symbolic_call(x1, x2) + return backend.numpy.add(x1, x2) + + +class All(Operation): + def __init__(self, axis=None, keepdims=False, *, name=None): + super().__init__(name=name) + if isinstance(axis, int): + self.axis = [axis] + else: + self.axis = axis + self.keepdims = keepdims + + def call(self, x): + return backend.numpy.all( + x, + axis=self.axis, + keepdims=self.keepdims, + ) + + def compute_output_spec(self, x): + return KerasTensor( + reduce_shape( + x.shape, + axis=self.axis, + keepdims=self.keepdims, + ), + dtype="bool", + ) + + +@keras_export(["keras.ops.all", "keras.ops.numpy.all"]) +def all(x, axis=None, keepdims=False): + """Test whether all array elements along a given axis evaluate to `True`. + + Args: + x: Input tensor. + axis: An integer or tuple of integers that represent the axis along + which a logical AND reduction is performed. The default + (`axis=None`) is to perform a logical AND over all the dimensions + of the input array. `axis` may be negative, in which case it counts + for the last to the first axis. + keepdims: If `True`, axes which are reduced are left in the result as + dimensions with size one. With this option, the result will + broadcast correctly against the input array. Defaults to `False`. + + Returns: + The tensor containing the logical AND reduction over the `axis`. + + Examples: + >>> x = keras.ops.convert_to_tensor([True, False]) + >>> keras.ops.all(x) + array(False, shape=(), dtype=bool) + + >>> x = keras.ops.convert_to_tensor([[True, False], [True, True]]) + >>> keras.ops.all(x, axis=0) + array([ True False], shape=(2,), dtype=bool) + + `keepdims=True` outputs a tensor with dimensions reduced to one. + >>> x = keras.ops.convert_to_tensor([[True, False], [True, True]]) + >>> keras.ops.all(x, keepdims=True) + array([[False]], shape=(1, 1), dtype=bool) + """ + if any_symbolic_tensors((x,)): + return All(axis=axis, keepdims=keepdims).symbolic_call(x) + return backend.numpy.all(x, axis=axis, keepdims=keepdims) + + +class Any(Operation): + def __init__(self, axis=None, keepdims=False, *, name=None): + super().__init__(name=name) + if isinstance(axis, int): + self.axis = [axis] + else: + self.axis = axis + self.keepdims = keepdims + + def call(self, x): + return backend.numpy.any( + x, + axis=self.axis, + keepdims=self.keepdims, + ) + + def compute_output_spec(self, x): + return KerasTensor( + reduce_shape( + x.shape, + axis=self.axis, + keepdims=self.keepdims, + ), + dtype="bool", + ) + + +class Angle(Operation): + def call(self, x): + return backend.numpy.angle(x) + + def compute_output_spec(self, x): + dtype = backend.standardize_dtype(getattr(x, "dtype", backend.floatx())) + if dtype == "int64": + dtype = backend.floatx() + else: + dtype = dtypes.result_type(dtype, float) + return KerasTensor(x.shape, dtype=dtype) + + +@keras_export(["keras.ops.angle", "keras.ops.numpy.angle"]) +def angle(x): + """Element-wise angle of a complex tensor. + + Arguments: + x: Input tensor. Can be real or complex. + + Returns: + Output tensor of same shape as x. containing the angle of each element + (in radians). + + Example: + >>> x = keras.ops.convert_to_tensor([[1 + 3j, 2 - 5j], [4 - 3j, 3 + 2j]]) + >>> keras.ops.angle(x) + array([[ 1.2490457, -1.19029 ], + [-0.6435011, 0.5880026]], dtype=float32) + """ + if any_symbolic_tensors((x,)): + return Angle().symbolic_call(x) + return backend.numpy.angle(x) + + +@keras_export(["keras.ops.any", "keras.ops.numpy.any"]) +def any(x, axis=None, keepdims=False): + """Test whether any array element along a given axis evaluates to `True`. + + Args: + x: Input tensor. + axis: An integer or tuple of integers that represent the axis along + which a logical OR reduction is performed. The default + (`axis=None`) is to perform a logical OR over all the dimensions + of the input array. `axis` may be negative, in which case it counts + for the last to the first axis. + keepdims: If `True`, axes which are reduced are left in the result as + dimensions with size one. With this option, the result will + broadcast correctly against the input array. Defaults to `False`. + + Returns: + The tensor containing the logical OR reduction over the `axis`. + + Examples: + >>> x = keras.ops.convert_to_tensor([True, False]) + >>> keras.ops.any(x) + array(True, shape=(), dtype=bool) + + >>> x = keras.ops.convert_to_tensor([[True, False], [True, True]]) + >>> keras.ops.any(x, axis=0) + array([ True True], shape=(2,), dtype=bool) + + `keepdims=True` outputs a tensor with dimensions reduced to one. + >>> x = keras.ops.convert_to_tensor([[True, False], [True, True]]) + >>> keras.ops.all(x, keepdims=True) + array([[False]], shape=(1, 1), dtype=bool) + """ + if any_symbolic_tensors((x,)): + return Any(axis=axis, keepdims=keepdims).symbolic_call(x) + return backend.numpy.any(x, axis=axis, keepdims=keepdims) + + +class Amax(Operation): + def __init__(self, axis=None, keepdims=False, *, name=None): + super().__init__(name=name) + if isinstance(axis, int): + axis = [axis] + self.axis = axis + self.keepdims = keepdims + + def call(self, x): + return backend.numpy.amax( + x, + axis=self.axis, + keepdims=self.keepdims, + ) + + def compute_output_spec(self, x): + return KerasTensor( + reduce_shape(x.shape, axis=self.axis, keepdims=self.keepdims), + dtype=x.dtype, + ) + + +@keras_export(["keras.ops.amax", "keras.ops.numpy.amax"]) +def amax(x, axis=None, keepdims=False): + """Returns the maximum of an array or maximum value along an axis. + + Args: + x: Input tensor. + axis: Axis along which to compute the maximum. + By default (`axis=None`), find the maximum value in all the + dimensions of the input array. + keepdims: If `True`, axes which are reduced are left in the result as + dimensions that are broadcast to the size of the original + input tensor. Defaults to `False`. + + Returns: + An array with the maximum value. If `axis=None`, the result is a scalar + value representing the maximum element in the entire array. If `axis` is + given, the result is an array with the maximum values along + the specified axis. + + Examples: + >>> x = keras.ops.convert_to_tensor([[1, 3, 5], [2, 3, 6]]) + >>> keras.ops.amax(x) + array(6, dtype=int32) + + >>> x = keras.ops.convert_to_tensor([[1, 6, 8], [1, 5, 2]]) + >>> keras.ops.amax(x, axis=0) + array([1, 6, 8], dtype=int32) + + >>> x = keras.ops.convert_to_tensor([[1, 6, 8], [1, 5, 2]]) + >>> keras.ops.amax(x, axis=1, keepdims=True) + array([[8], [5]], dtype=int32) + """ + if any_symbolic_tensors((x,)): + return Amax(axis=axis, keepdims=keepdims).symbolic_call(x) + return backend.numpy.amax(x, axis=axis, keepdims=keepdims) + + +class Amin(Operation): + def __init__(self, axis=None, keepdims=False, *, name=None): + super().__init__(name=name) + if isinstance(axis, int): + axis = [axis] + self.axis = axis + self.keepdims = keepdims + + def call(self, x): + return backend.numpy.amin(x, axis=self.axis, keepdims=self.keepdims) + + def compute_output_spec(self, x): + return KerasTensor( + reduce_shape(x.shape, axis=self.axis, keepdims=self.keepdims), + dtype=x.dtype, + ) + + +@keras_export(["keras.ops.amin", "keras.ops.numpy.amin"]) +def amin(x, axis=None, keepdims=False): + """Returns the minimum of an array or minimum value along an axis. + + Args: + x: Input tensor. + axis: Axis along which to compute the minimum. + By default (`axis=None`), find the minimum value in all the + dimensions of the input array. + keepdims: If `True`, axes which are reduced are left in the result as + dimensions that are broadcast to the size of the original + input tensor. Defaults to `False`. + + Returns: + An array with the minimum value. If `axis=None`, the result is a scalar + value representing the minimum element in the entire array. If `axis` is + given, the result is an array with the minimum values along + the specified axis. + + Examples: + >>> x = keras.ops.convert_to_tensor([1, 3, 5, 2, 3, 6]) + >>> keras.ops.amin(x) + array(1, dtype=int32) + + >>> x = keras.ops.convert_to_tensor([[1, 6, 8], [7, 5, 3]]) + >>> keras.ops.amin(x, axis=0) + array([1,5,3], dtype=int32) + + >>> x = keras.ops.convert_to_tensor([[1, 6, 8], [7, 5, 3]]) + >>> keras.ops.amin(x, axis=1, keepdims=True) + array([[1],[3]], dtype=int32) + """ + if any_symbolic_tensors((x,)): + return Amin(axis=axis, keepdims=keepdims).symbolic_call(x) + return backend.numpy.amin(x, axis=axis, keepdims=keepdims) + + +class Append(Operation): + def __init__(self, axis=None, *, name=None): + super().__init__(name=name) + self.axis = axis + + def call(self, x1, x2): + return backend.numpy.append(x1, x2, axis=self.axis) + + def compute_output_spec(self, x1, x2): + x1_shape = x1.shape + x2_shape = x2.shape + dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + ) + if self.axis is None: + if None in x1_shape or None in x2_shape: + output_shape = [None] + else: + output_shape = [int(np.prod(x1_shape) + np.prod(x2_shape))] + return KerasTensor(output_shape, dtype=dtype) + + if not shape_equal(x1_shape, x2_shape, [self.axis]): + raise ValueError( + "`append` requires inputs to have the same shape except the " + f"`axis={self.axis}`, but received shape {x1_shape} and " + f"{x2_shape}." + ) + + output_shape = list(x1_shape) + output_shape[self.axis] = x1_shape[self.axis] + x2_shape[self.axis] + return KerasTensor(output_shape, dtype=dtype) + + +@keras_export(["keras.ops.append", "keras.ops.numpy.append"]) +def append( + x1, + x2, + axis=None, +): + """Append tensor `x2` to the end of tensor `x1`. + + Args: + x1: First input tensor. + x2: Second input tensor. + axis: Axis along which tensor `x2` is appended to tensor `x1`. + If `None`, both tensors are flattened before use. + + Returns: + A tensor with the values of `x2` appended to `x1`. + + Examples: + >>> x1 = keras.ops.convert_to_tensor([1, 2, 3]) + >>> x2 = keras.ops.convert_to_tensor([[4, 5, 6], [7, 8, 9]]) + >>> keras.ops.append(x1, x2) + array([1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32) + + When `axis` is specified, `x1` and `x2` must have compatible shapes. + >>> x1 = keras.ops.convert_to_tensor([[1, 2, 3], [4, 5, 6]]) + >>> x2 = keras.ops.convert_to_tensor([[7, 8, 9]]) + >>> keras.ops.append(x1, x2, axis=0) + array([[1, 2, 3], + [4, 5, 6], + [7, 8, 9]], dtype=int32) + >>> x3 = keras.ops.convert_to_tensor([7, 8, 9]) + >>> keras.ops.append(x1, x3, axis=0) + Traceback (most recent call last): + ... + TypeError: Cannot concatenate arrays with different numbers of + dimensions: got (2, 3), (3,). + """ + if any_symbolic_tensors((x1, x2)): + return Append(axis=axis).symbolic_call(x1, x2) + return backend.numpy.append(x1, x2, axis=axis) + + +class Arange(Operation): + def __init__(self, dtype=None, *, name=None): + super().__init__(name=name) + self.dtype = None if dtype is None else backend.standardize_dtype(dtype) + + def call(self, start, stop=None, step=None): + return backend.numpy.arange(start, stop, step=step, dtype=self.dtype) + + def compute_output_spec(self, start, stop=None, step=None): + if stop is None: + start, stop = 0, start + if step is None: + step = 1 + output_shape = [int(np.ceil((stop - start) / step))] + dtype = self.dtype + if dtype is None: + dtypes_to_resolve = [getattr(start, "dtype", type(start))] + if stop is not None: + dtypes_to_resolve.append(getattr(stop, "dtype", type(stop))) + if step is not None: + dtypes_to_resolve.append(getattr(step, "dtype", type(step))) + dtype = dtypes.result_type(*dtypes_to_resolve) + return KerasTensor(output_shape, dtype=dtype) + + +@keras_export(["keras.ops.arange", "keras.ops.numpy.arange"]) +def arange(start, stop=None, step=None, dtype=None): + """Return evenly spaced values within a given interval. + + `arange` can be called with a varying number of positional arguments: + * `arange(stop)`: Values are generated within the half-open interval + `[0, stop)` (in other words, the interval including start but excluding + stop). + * `arange(start, stop)`: Values are generated within the half-open interval + `[start, stop)`. + * `arange(start, stop, step)`: Values are generated within the half-open + interval `[start, stop)`, with spacing between values given by step. + + Args: + start: Integer or real, representing the start of the interval. The + interval includes this value. + stop: Integer or real, representing the end of the interval. The + interval does not include this value, except in some cases where + `step` is not an integer and floating point round-off affects the + length of `out`. Defaults to `None`. + step: Integer or real, represent the spacing between values. For any + output `out`, this is the distance between two adjacent values, + `out[i+1] - out[i]`. The default step size is 1. If `step` is + specified as a position argument, `start` must also be given. + dtype: The type of the output array. If `dtype` is not given, infer the + data type from the other input arguments. + + Returns: + Tensor of evenly spaced values. + For floating point arguments, the length of the result is + `ceil((stop - start)/step)`. Because of floating point overflow, this + rule may result in the last element of out being greater than stop. + + Examples: + >>> keras.ops.arange(3) + array([0, 1, 2], dtype=int32) + + >>> keras.ops.arange(3.0) + array([0., 1., 2.], dtype=float32) + + >>> keras.ops.arange(3, 7) + array([3, 4, 5, 6], dtype=int32) + + >>> keras.ops.arange(3, 7, 2) + array([3, 5], dtype=int32) + """ + if any_symbolic_tensors((start, stop, step)): + return Arange(dtype=dtype).symbolic_call(start, stop, step=step) + return backend.numpy.arange(start, stop, step=step, dtype=dtype) + + +class Arccos(Operation): + def call(self, x): + return backend.numpy.arccos(x) + + def compute_output_spec(self, x): + dtype = backend.standardize_dtype(getattr(x, "dtype", backend.floatx())) + if dtype == "int64": + dtype = backend.floatx() + else: + dtype = dtypes.result_type(dtype, float) + return KerasTensor(x.shape, dtype=dtype) + + +@keras_export(["keras.ops.arccos", "keras.ops.numpy.arccos"]) +def arccos(x): + """Trigonometric inverse cosine, element-wise. + + The inverse of `cos` so that, if `y = cos(x)`, then `x = arccos(y)`. + + Args: + x: Input tensor. + + Returns: + Tensor of the angle of the ray intersecting the unit circle at the given + x-coordinate in radians `[0, pi]`. + + Example: + >>> x = keras.ops.convert_to_tensor([1, -1]) + >>> keras.ops.arccos(x) + array([0.0, 3.1415927], dtype=float32) + """ + if any_symbolic_tensors((x,)): + return Arccos().symbolic_call(x) + return backend.numpy.arccos(x) + + +class Arccosh(Operation): + def call(self, x): + return backend.numpy.arccosh(x) + + def compute_output_spec(self, x): + dtype = backend.standardize_dtype(getattr(x, "dtype", backend.floatx())) + if dtype == "int64": + dtype = backend.floatx() + else: + dtype = dtypes.result_type(dtype, float) + return KerasTensor(x.shape, dtype=dtype) + + +@keras_export(["keras.ops.arccosh", "keras.ops.numpy.arccosh"]) +def arccosh(x): + """Inverse hyperbolic cosine, element-wise. + + Arguments: + x: Input tensor. + + Returns: + Output tensor of same shape as x. + + Example: + >>> x = keras.ops.convert_to_tensor([10, 100]) + >>> keras.ops.arccosh(x) + array([2.993223, 5.298292], dtype=float32) + """ + if any_symbolic_tensors((x,)): + return Arccosh().symbolic_call(x) + return backend.numpy.arccosh(x) + + +class Arcsin(Operation): + def call(self, x): + return backend.numpy.arcsin(x) + + def compute_output_spec(self, x): + dtype = backend.standardize_dtype(getattr(x, "dtype", backend.floatx())) + if dtype == "int64": + dtype = backend.floatx() + else: + dtype = dtypes.result_type(dtype, float) + sparse = getattr(x, "sparse", False) + return KerasTensor(x.shape, dtype=dtype, sparse=sparse) + + +@keras_export(["keras.ops.arcsin", "keras.ops.numpy.arcsin"]) +def arcsin(x): + """Inverse sine, element-wise. + + Args: + x: Input tensor. + + Returns: + Tensor of the inverse sine of each element in `x`, in radians and in + the closed interval `[-pi/2, pi/2]`. + + Example: + >>> x = keras.ops.convert_to_tensor([1, -1, 0]) + >>> keras.ops.arcsin(x) + array([ 1.5707964, -1.5707964, 0.], dtype=float32) + """ + if any_symbolic_tensors((x,)): + return Arcsin().symbolic_call(x) + return backend.numpy.arcsin(x) + + +class Arcsinh(Operation): + def call(self, x): + return backend.numpy.arcsinh(x) + + def compute_output_spec(self, x): + dtype = backend.standardize_dtype(getattr(x, "dtype", backend.floatx())) + if dtype == "int64": + dtype = backend.floatx() + else: + dtype = dtypes.result_type(dtype, float) + sparse = getattr(x, "sparse", False) + return KerasTensor(x.shape, dtype=dtype, sparse=sparse) + + +@keras_export(["keras.ops.arcsinh", "keras.ops.numpy.arcsinh"]) +def arcsinh(x): + """Inverse hyperbolic sine, element-wise. + + Arguments: + x: Input tensor. + + Returns: + Output tensor of same shape as `x`. + + Example: + >>> x = keras.ops.convert_to_tensor([1, -1, 0]) + >>> keras.ops.arcsinh(x) + array([0.88137364, -0.88137364, 0.0], dtype=float32) + """ + if any_symbolic_tensors((x,)): + return Arcsinh().symbolic_call(x) + return backend.numpy.arcsinh(x) + + +class Arctan(Operation): + def call(self, x): + return backend.numpy.arctan(x) + + def compute_output_spec(self, x): + dtype = backend.standardize_dtype(getattr(x, "dtype", backend.floatx())) + if dtype == "int64": + dtype = backend.floatx() + else: + dtype = dtypes.result_type(dtype, float) + sparse = getattr(x, "sparse", False) + return KerasTensor(x.shape, dtype=dtype, sparse=sparse) + + +@keras_export(["keras.ops.arctan", "keras.ops.numpy.arctan"]) +def arctan(x): + """Trigonometric inverse tangent, element-wise. + + Args: + x: Input tensor. + + Returns: + Tensor of the inverse tangent of each element in `x`, in the interval + `[-pi/2, pi/2]`. + + Example: + >>> x = keras.ops.convert_to_tensor([0, 1]) + >>> keras.ops.arctan(x) + array([0., 0.7853982], dtype=float32) + """ + if any_symbolic_tensors((x,)): + return Arctan().symbolic_call(x) + return backend.numpy.arctan(x) + + +class Arctan2(Operation): + def call(self, x1, x2): + return backend.numpy.arctan2(x1, x2) + + def compute_output_spec(self, x1, x2): + x1_shape = getattr(x1, "shape", []) + x2_shape = getattr(x2, "shape", []) + outputs_shape = broadcast_shapes(x1_shape, x2_shape) + x1_dtype = backend.standardize_dtype( + getattr(x1, "dtype", backend.floatx()) + ) + x2_dtype = backend.standardize_dtype( + getattr(x2, "dtype", backend.floatx()) + ) + dtype = dtypes.result_type(x1_dtype, x2_dtype, float) + return KerasTensor(outputs_shape, dtype=dtype) + + +@keras_export(["keras.ops.arctan2", "keras.ops.numpy.arctan2"]) +def arctan2(x1, x2): + """Element-wise arc tangent of `x1/x2` choosing the quadrant correctly. + + The quadrant (i.e., branch) is chosen so that `arctan2(x1, x2)` is the + signed angle in radians between the ray ending at the origin and passing + through the point `(1, 0)`, and the ray ending at the origin and passing + through the point `(x2, x1)`. (Note the role reversal: the "y-coordinate" + is the first function parameter, the "x-coordinate" is the second.) By IEEE + convention, this function is defined for `x2 = +/-0` and for either or both + of `x1` and `x2` `= +/-inf`. + + Args: + x1: First input tensor. + x2: Second input tensor. + + Returns: + Tensor of angles in radians, in the range `[-pi, pi]`. + + Examples: + Consider four points in different quadrants: + >>> x = keras.ops.convert_to_tensor([-1, +1, +1, -1]) + >>> y = keras.ops.convert_to_tensor([-1, -1, +1, +1]) + >>> keras.ops.arctan2(y, x) * 180 / numpy.pi + array([-135., -45., 45., 135.], dtype=float32) + + Note the order of the parameters. `arctan2` is defined also when x2=0 and + at several other points, obtaining values in the range `[-pi, pi]`: + >>> keras.ops.arctan2( + ... keras.ops.array([1., -1.]), + ... keras.ops.array([0., 0.]), + ... ) + array([ 1.5707964, -1.5707964], dtype=float32) + >>> keras.ops.arctan2( + ... keras.ops.array([0., 0., numpy.inf]), + ... keras.ops.array([+0., -0., numpy.inf]), + ... ) + array([0., 3.1415925, 0.7853982], dtype=float32) + """ + if any_symbolic_tensors((x1, x2)): + return Arctan2().symbolic_call(x1, x2) + return backend.numpy.arctan2(x1, x2) + + +class Arctanh(Operation): + def call(self, x): + return backend.numpy.arctanh(x) + + def compute_output_spec(self, x): + dtype = backend.standardize_dtype(getattr(x, "dtype", backend.floatx())) + if dtype == "int64": + dtype = backend.floatx() + else: + dtype = dtypes.result_type(dtype, float) + sparse = getattr(x, "sparse", False) + return KerasTensor(x.shape, dtype=dtype, sparse=sparse) + + +@keras_export(["keras.ops.arctanh", "keras.ops.numpy.arctanh"]) +def arctanh(x): + """Inverse hyperbolic tangent, element-wise. + + Arguments: + x: Input tensor. + + Returns: + Output tensor of same shape as `x`. + """ + if any_symbolic_tensors((x,)): + return Arctanh().symbolic_call(x) + return backend.numpy.arctanh(x) + + +class Argmax(Operation): + def __init__(self, axis=None, keepdims=False, *, name=None): + super().__init__(name=name) + self.axis = axis + self.keepdims = keepdims + + def call(self, x): + return backend.numpy.argmax(x, axis=self.axis, keepdims=self.keepdims) + + def compute_output_spec(self, x): + if self.keepdims: + return KerasTensor(x.shape, dtype="int32") + if self.axis is None: + return KerasTensor([], dtype="int32") + return KerasTensor( + reduce_shape(x.shape, axis=[self.axis]), dtype="int32" + ) + + +@keras_export(["keras.ops.argmax", "keras.ops.numpy.argmax"]) +def argmax(x, axis=None, keepdims=False): + """Returns the indices of the maximum values along an axis. + + Args: + x: Input tensor. + axis: By default, the index is into the flattened tensor, otherwise + along the specified axis. + keepdims: If this is set to `True`, the axes which are reduced are left + in the result as dimensions with size one. Defaults to `False`. + + Returns: + Tensor of indices. It has the same shape as `x`, with the dimension + along `axis` removed. + + Example: + >>> x = keras.ops.arange(6).reshape(2, 3) + 10 + >>> x + array([[10, 11, 12], + [13, 14, 15]], dtype=int32) + >>> keras.ops.argmax(x) + array(5, dtype=int32) + >>> keras.ops.argmax(x, axis=0) + array([1, 1, 1], dtype=int32) + >>> keras.ops.argmax(x, axis=1) + array([2, 2], dtype=int32) + """ + if any_symbolic_tensors((x,)): + return Argmax(axis=axis, keepdims=keepdims).symbolic_call(x) + return backend.numpy.argmax(x, axis=axis, keepdims=keepdims) + + +class Argmin(Operation): + def __init__(self, axis=None, keepdims=False, *, name=None): + super().__init__(name=name) + self.axis = axis + self.keepdims = keepdims + + def call(self, x): + return backend.numpy.argmin(x, axis=self.axis, keepdims=self.keepdims) + + def compute_output_spec(self, x): + if self.keepdims: + return KerasTensor(x.shape, dtype="int32") + if self.axis is None: + return KerasTensor([], dtype="int32") + return KerasTensor( + reduce_shape(x.shape, axis=[self.axis]), dtype="int32" + ) + + +@keras_export(["keras.ops.argmin", "keras.ops.numpy.argmin"]) +def argmin(x, axis=None, keepdims=False): + """Returns the indices of the minimum values along an axis. + + Args: + x: Input tensor. + axis: By default, the index is into the flattened tensor, otherwise + along the specified axis. + keepdims: If this is set to `True`, the axes which are reduced are left + in the result as dimensions with size one. Defaults to `False`. + + Returns: + Tensor of indices. It has the same shape as `x`, with the dimension + along `axis` removed. + + Example: + >>> x = keras.ops.arange(6).reshape(2, 3) + 10 + >>> x + array([[10, 11, 12], + [13, 14, 15]], dtype=int32) + >>> keras.ops.argmin(x) + array(0, dtype=int32) + >>> keras.ops.argmin(x, axis=0) + array([0, 0, 0], dtype=int32) + >>> keras.ops.argmin(x, axis=1) + array([0, 0], dtype=int32) + """ + if any_symbolic_tensors((x,)): + return Argmin(axis=axis, keepdims=keepdims).symbolic_call(x) + return backend.numpy.argmin(x, axis=axis, keepdims=keepdims) + + +class Argsort(Operation): + def __init__(self, axis=-1, *, name=None): + super().__init__(name=name) + self.axis = axis + + def call(self, x): + return backend.numpy.argsort(x, axis=self.axis) + + def compute_output_spec(self, x): + if self.axis is None: + return KerasTensor([int(np.prod(x.shape))], dtype="int32") + return KerasTensor(x.shape, dtype="int32") + + +@keras_export(["keras.ops.argsort", "keras.ops.numpy.argsort"]) +def argsort(x, axis=-1): + """Returns the indices that would sort a tensor. + + Args: + x: Input tensor. + axis: Axis along which to sort. Defaults to `-1` (the last axis). If + `None`, the flattened tensor is used. + + Returns: + Tensor of indices that sort `x` along the specified `axis`. + + Examples: + One dimensional array: + >>> x = keras.ops.array([3, 1, 2]) + >>> keras.ops.argsort(x) + array([1, 2, 0], dtype=int32) + + Two-dimensional array: + >>> x = keras.ops.array([[0, 3], [3, 2], [4, 5]]) + >>> x + array([[0, 3], + [3, 2], + [4, 5]], dtype=int32) + >>> keras.ops.argsort(x, axis=0) + array([[0, 1], + [1, 0], + [2, 2]], dtype=int32) + >>> keras.ops.argsort(x, axis=1) + array([[0, 1], + [1, 0], + [0, 1]], dtype=int32) + """ + if any_symbolic_tensors((x,)): + return Argsort(axis=axis).symbolic_call(x) + return backend.numpy.argsort(x, axis=axis) + + +class Array(Operation): + def __init__(self, dtype=None, *, name=None): + super().__init__(name=name) + self.dtype = None if dtype is None else backend.standardize_dtype(dtype) + + def call(self, x): + return backend.numpy.array(x, dtype=self.dtype) + + def compute_output_spec(self, x, dtype=None): + dtype = ( + backend.standardize_dtype(x.dtype) + if self.dtype is None + else self.dtype + ) + return KerasTensor(x.shape, dtype=dtype) + + +@keras_export(["keras.ops.array", "keras.ops.numpy.array"]) +def array(x, dtype=None): + """Create a tensor. + + Args: + x: Input tensor. + dtype: The desired data-type for the tensor. + + Returns: + A tensor. + + Examples: + >>> keras.ops.array([1, 2, 3]) + array([1, 2, 3], dtype=int32) + + >>> keras.ops.array([1, 2, 3], dtype="float32") + array([1., 2., 3.], dtype=float32) + """ + if any_symbolic_tensors((x,)): + return Array(dtype=dtype).symbolic_call(x) + return backend.numpy.array(x, dtype=dtype) + + +class Average(Operation): + def __init__(self, axis=None, *, name=None): + super().__init__(name=name) + # np.average() does not support axis as tuple as declared by the + # docstring, it only supports int or None. + self.axis = axis + + def call(self, x, weights=None): + return backend.numpy.average(x, weights=weights, axis=self.axis) + + def compute_output_spec(self, x, weights=None): + dtypes_to_resolve = [getattr(x, "dtype", type(x)), float] + if weights is not None: + shape_match = shape_equal(x.shape, weights.shape, allow_none=True) + if self.axis is not None: + shape_match_on_axis = shape_equal( + [x.shape[self.axis]], weights.shape, allow_none=True + ) + dtypes_to_resolve.append(getattr(weights, "dtype", type(weights))) + dtype = dtypes.result_type(*dtypes_to_resolve) + if self.axis is None: + if weights is None or shape_match: + return KerasTensor([], dtype=dtype) + else: + raise ValueError( + "`weights` must have the same shape as `x` when " + f"`axis=None`, but received `weights.shape={weights.shape}`" + f" and `x.shape={x.shape}`." + ) + + if weights is None or shape_match_on_axis or shape_match: + return KerasTensor( + reduce_shape(x.shape, axis=[self.axis]), dtype=dtype + ) + else: + # `weights` can either be a 1D array of length `x.shape[axis]` or + # of the same shape as `x`. + raise ValueError( + "`weights` must have the same size as `x` at " + f"`axis={self.axis}` but received " + f"`weights.shape={weights.shape}` while x.shape at " + f"`{self.axis}` is `{x.shape[self.axis]}`." + ) + + +@keras_export(["keras.ops.average", "keras.ops.numpy.average"]) +def average(x, axis=None, weights=None): + """Compute the weighted average along the specified axis. + + Args: + x: Input tensor. + axis: Integer along which to average `x`. The default, `axis=None`, + will average over all of the elements of the input tensor. If axis + is negative it counts from the last to the first axis. + weights: Tensor of weights associated with the values in `x`. Each + value in `x` contributes to the average according to its + associated weight. The weights array can either be 1-D (in which + case its length must be the size of a along the given axis) or of + the same shape as `x`. If `weights=None` (default), then all data + in `x` are assumed to have a weight equal to one. + + The 1-D calculation is: `avg = sum(a * weights) / sum(weights)`. + The only constraint on weights is that `sum(weights)` must not be 0. + + Returns: + Return the average along the specified axis. + + Examples: + >>> data = keras.ops.arange(1, 5) + >>> data + array([1, 2, 3, 4], dtype=int32) + >>> keras.ops.average(data) + array(2.5, dtype=float32) + >>> keras.ops.average( + ... keras.ops.arange(1, 11), + ... weights=keras.ops.arange(10, 0, -1) + ... ) + array(4., dtype=float32) + + >>> data = keras.ops.arange(6).reshape((3, 2)) + >>> data + array([[0, 1], + [2, 3], + [4, 5]], dtype=int32) + >>> keras.ops.average( + ... data, + ... axis=1, + ... weights=keras.ops.array([1./4, 3./4]) + ... ) + array([0.75, 2.75, 4.75], dtype=float32) + >>> keras.ops.average( + ... data, + ... weights=keras.ops.array([1./4, 3./4]) + ... ) + Traceback (most recent call last): + ... + ValueError: Axis must be specified when shapes of a and weights differ. + """ + if any_symbolic_tensors((x,)): + return Average(axis=axis).symbolic_call(x, weights=weights) + return backend.numpy.average(x, axis=axis, weights=weights) + + +class Bartlett(Operation): + def call(self, x): + return backend.numpy.bartlett(x) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype=backend.floatx()) + + +@keras_export(["keras.ops.bartlett", "keras.ops.numpy.bartlett"]) +def bartlett(x): + """Bartlett window function. + The Bartlett window is a triangular window that rises then falls linearly. + + Args: + x: Scalar or 1D Tensor. Window length. + + Returns: + A 1D tensor containing the Bartlett window values. + + Example: + >>> x = keras.ops.convert_to_tensor(5) + >>> keras.ops.bartlett(x) + array([0. , 0.5, 1. , 0.5, 0. ], dtype=float32) + """ + if any_symbolic_tensors((x,)): + return Bartlett().symbolic_call(x) + return backend.numpy.bartlett(x) + + +class Hamming(Operation): + def call(self, x): + return backend.numpy.hamming(x) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype=backend.floatx()) + + +@keras_export(["keras.ops.hamming", "keras.ops.numpy.hamming"]) +def hamming(x): + """Hamming window function. + + The Hamming window is defined as: + `w[n] = 0.54 - 0.46 * cos(2 * pi * n / (N - 1))` for `0 <= n <= N - 1`. + + Args: + x: Scalar or 1D Tensor. The window length. + + Returns: + A 1D tensor containing the Hamming window values. + + Example: + >>> x = keras.ops.convert_to_tensor(5) + >>> keras.ops.hamming(x) + array([0.08, 0.54, 1. , 0.54, 0.08], dtype=float32) + """ + if any_symbolic_tensors((x,)): + return Hamming().symbolic_call(x) + return backend.numpy.hamming(x) + + +class Hanning(Operation): + def call(self, x): + return backend.numpy.hanning(x) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype=backend.floatx()) + + +@keras_export(["keras.ops.hanning", "keras.ops.numpy.hanning"]) +def hanning(x): + """Hanning window function. + + The Hanning window is defined as: + `w[n] = 0.5 - 0.5 * cos(2 * pi * n / (N - 1))` for `0 <= n <= N - 1`. + + Args: + x: Scalar or 1D Tensor. The window length. + + Returns: + A 1D tensor containing the Hanning window values. + + Example: + >>> x = keras.ops.convert_to_tensor(5) + >>> keras.ops.hanning(x) + array([0. , 0.5, 1. , 0.5, 0. ], dtype=float32) + """ + if any_symbolic_tensors((x,)): + return Hanning().symbolic_call(x) + return backend.numpy.hanning(x) + + +class Heaviside(Operation): + def call(self, x1, x2): + return backend.numpy.heaviside(x1, x2) + + def compute_output_spec(self, x1, x2): + dtype = dtypes.result_type(x1.dtype, x2.dtype) + if dtype in ["int8", "int16", "int32", "uint8", "uint16", "uint32"]: + dtype = backend.floatx() + elif dtype == "int64": + dtype = "float64" + return KerasTensor(broadcast_shapes(x1.shape, x2.shape), dtype=dtype) + + +@keras_export(["keras.ops.heaviside", "keras.ops.numpy.heaviside"]) +def heaviside(x1, x2): + """Heaviside step function. + + The Heaviside step function is defined as: + `heaviside(x1, x2) = 0 if x1 < 0, 1 if x1 > 0, x2 if x1 == 0` + + Args: + x1: A tensor input. + x2: A scalar or tensor, the value to return when `x1 == 0`. + + Returns: + A tensor with a shape determined by broadcasting `x1` and `x2`. + + Example: + >>> x1 = keras.ops.convert_to_tensor([-2.0, 0.0, 3.0]) + >>> x2 = 0.5 + >>> keras.ops.heaviside(x1, x2) + array([0. , 0.5, 1. ], dtype=float32) + """ + if any_symbolic_tensors((x1, x2)): + return Heaviside().symbolic_call(x1, x2) + return backend.numpy.heaviside(x1, x2) + + +class Kaiser(Operation): + def __init__(self, beta, *, name=None): + super().__init__(name=name) + self.beta = beta + + def call(self, x): + return backend.numpy.kaiser(x, self.beta) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype=backend.floatx()) + + +@keras_export(["keras.ops.kaiser", "keras.ops.numpy.kaiser"]) +def kaiser(x, beta): + """Kaiser window function. + + The Kaiser window is defined as: + `w[n] = I0(beta * sqrt(1 - (2n / (N - 1) - 1)^2)) / I0(beta)` + where I0 is the modified zeroth-order Bessel function of the first kind. + + Args: + x: Scalar or 1D Tensor. The window length. + beta: Float. Shape parameter for the Kaiser window. + + Returns: + A 1D tensor containing the Kaiser window values. + + Example: + >>> x = keras.ops.convert_to_tensor(5) + >>> keras.ops.kaiser(x, beta=14.0) + array([7.7268669e-06, 1.6493219e-01, 1.0000000e+00, 1.6493219e-01, + 7.7268669e-06], dtype=float32) + """ + if any_symbolic_tensors((x,)): + return Kaiser(beta).symbolic_call(x) + return backend.numpy.kaiser(x, beta) + + +class Bincount(Operation): + def __init__(self, weights=None, minlength=0, sparse=False, *, name=None): + super().__init__(name=name) + self.weights = weights + self.minlength = minlength + self.sparse = sparse + + def call(self, x): + return backend.numpy.bincount( + x, + weights=self.weights, + minlength=self.minlength, + sparse=self.sparse, + ) + + def compute_output_spec(self, x): + dtypes_to_resolve = [x.dtype] + if self.weights is not None: + weights = backend.convert_to_tensor(self.weights) + dtypes_to_resolve.append(weights.dtype) + dtype = dtypes.result_type(*dtypes_to_resolve) + else: + dtype = "int32" + x_sparse = getattr(x, "sparse", False) + return KerasTensor( + list(x.shape[:-1]) + [None], + dtype=dtype, + sparse=x_sparse or self.sparse, + ) + + +@keras_export(["keras.ops.bincount", "keras.ops.numpy.bincount"]) +def bincount(x, weights=None, minlength=0, sparse=False): + """Count the number of occurrences of each value in a tensor of integers. + + Args: + x: Input tensor. + It must be of dimension 1, and it must only contain non-negative + integer(s). + weights: Weight tensor. + It must have the same length as `x`. The default value is `None`. + If specified, `x` is weighted by it, i.e. if `n = x[i]`, + `out[n] += weight[i]` instead of the default behavior `out[n] += 1`. + minlength: An integer. + The default value is 0. If specified, there will be at least + this number of bins in the output tensor. If greater than + `max(x) + 1`, each value of the output at an index higher than + `max(x)` is set to 0. + sparse: Whether to return a sparse tensor; for backends that support + sparse tensors. + + Returns: + 1D tensor where each element gives the number of occurrence(s) of its + index value in x. Its length is the maximum between `max(x) + 1` and + minlength. + + Examples: + >>> x = keras.ops.array([1, 2, 2, 3], dtype="uint8") + >>> keras.ops.bincount(x) + array([0, 1, 2, 1], dtype=int32) + >>> weights = x / 2 + >>> weights + array([0.5, 1., 1., 1.5], dtype=float64) + >>> keras.ops.bincount(x, weights=weights) + array([0., 0.5, 2., 1.5], dtype=float64) + >>> minlength = (keras.ops.max(x).numpy() + 1) + 2 # 6 + >>> keras.ops.bincount(x, minlength=minlength) + array([0, 1, 2, 1, 0, 0], dtype=int32) + """ + if any_symbolic_tensors((x,)): + return Bincount( + weights=weights, minlength=minlength, sparse=sparse + ).symbolic_call(x) + return backend.numpy.bincount( + x, weights=weights, minlength=minlength, sparse=sparse + ) + + +class BitwiseAnd(Operation): + def call(self, x, y): + return backend.numpy.bitwise_and(x, y) + + def compute_output_spec(self, x, y): + dtype = dtypes.result_type(x.dtype, y.dtype) + return KerasTensor(x.shape, dtype=dtype) + + +@keras_export(["keras.ops.bitwise_and", "keras.ops.numpy.bitwise_and"]) +def bitwise_and(x, y): + """Compute the bit-wise AND of two arrays element-wise. + + Computes the bit-wise AND of the underlying binary representation of the + integers in the input arrays. This ufunc implements the C/Python operator + `&`. + + Args: + x: Input integer tensor. + y: Input integer tensor. + + Returns: + Result tensor. + """ + if any_symbolic_tensors((x, y)): + return BitwiseAnd().symbolic_call(x, y) + return backend.numpy.bitwise_and(x, y) + + +class BitwiseInvert(Operation): + def call(self, x): + return backend.numpy.bitwise_invert(x) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype=x.dtype) + + +@keras_export(["keras.ops.bitwise_invert", "keras.ops.numpy.bitwise_invert"]) +def bitwise_invert(x): + """Compute bit-wise inversion, or bit-wise NOT, element-wise. + + Computes the bit-wise NOT of the underlying binary representation of the + integers in the input arrays. This ufunc implements the C/Python operator + `~`. + + Args: + x: Input integer tensor. + + Returns: + Result tensor. + """ + if any_symbolic_tensors((x,)): + return BitwiseInvert().symbolic_call(x) + return backend.numpy.bitwise_invert(x) + + +class BitwiseNot(Operation): + def call(self, x): + return backend.numpy.bitwise_not(x) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype=x.dtype) + + +@keras_export(["keras.ops.bitwise_not", "keras.ops.numpy.bitwise_not"]) +def bitwise_not(x): + """Compute bit-wise inversion, or bit-wise NOT, element-wise. + + Computes the bit-wise NOT of the underlying binary representation of the + integers in the input arrays. This ufunc implements the C/Python operator + `~`. + + Args: + x: Input integer tensor. + + Returns: + Result tensor. + """ + if any_symbolic_tensors((x,)): + return BitwiseNot().symbolic_call(x) + return backend.numpy.bitwise_not(x) + + +class BitwiseOr(Operation): + def call(self, x, y): + return backend.numpy.bitwise_or(x, y) + + def compute_output_spec(self, x, y): + dtype = dtypes.result_type(x.dtype, y.dtype) + return KerasTensor(x.shape, dtype=dtype) + + +@keras_export(["keras.ops.bitwise_or", "keras.ops.numpy.bitwise_or"]) +def bitwise_or(x, y): + """Compute the bit-wise OR of two arrays element-wise. + + Computes the bit-wise OR of the underlying binary representation of the + integers in the input arrays. This ufunc implements the C/Python operator + `|`. + + Args: + x: Input integer tensor. + y: Input integer tensor. + + Returns: + Result tensor. + """ + if any_symbolic_tensors((x, y)): + return BitwiseOr().symbolic_call(x, y) + return backend.numpy.bitwise_or(x, y) + + +class BitwiseXor(Operation): + def call(self, x, y): + return backend.numpy.bitwise_xor(x, y) + + def compute_output_spec(self, x, y): + dtype = dtypes.result_type(x.dtype, y.dtype) + return KerasTensor(x.shape, dtype=dtype) + + +@keras_export(["keras.ops.bitwise_xor", "keras.ops.numpy.bitwise_xor"]) +def bitwise_xor(x, y): + """Compute the bit-wise XOR of two arrays element-wise. + + Computes the bit-wise XOR of the underlying binary representation of the + integers in the input arrays. This ufunc implements the C/Python operator + `^`. + + Args: + x: Input integer tensor. + y: Input integer tensor. + + Returns: + Result tensor. + """ + if any_symbolic_tensors((x, y)): + return BitwiseXor().symbolic_call(x, y) + return backend.numpy.bitwise_xor(x, y) + + +class BitwiseLeftShift(Operation): + def call(self, x, y): + return backend.numpy.bitwise_left_shift(x, y) + + def compute_output_spec(self, x, y): + if isinstance(y, int): + dtype = x.dtype + else: + dtype = dtypes.result_type(x.dtype, y.dtype) + return KerasTensor(x.shape, dtype=dtype) + + +@keras_export( + ["keras.ops.bitwise_left_shift", "keras.ops.numpy.bitwise_left_shift"] +) +def bitwise_left_shift(x, y): + """Shift the bits of an integer to the left. + + Bits are shifted to the left by appending `y` 0s at the right of `x`. + Since the internal representation of numbers is in binary format, this + operation is equivalent to multiplying `x` by `2**y`. + + Args: + x: Input integer tensor. + y: Input integer tensor. + + Returns: + Result tensor. + """ + if any_symbolic_tensors((x, y)): + return BitwiseLeftShift().symbolic_call(x, y) + return backend.numpy.bitwise_left_shift(x, y) + + +class LeftShift(Operation): + def call(self, x, y): + return backend.numpy.left_shift(x, y) + + def compute_output_spec(self, x, y): + if isinstance(y, int): + dtype = x.dtype + else: + dtype = dtypes.result_type(x.dtype, y.dtype) + return KerasTensor(x.shape, dtype=dtype) + + +@keras_export(["keras.ops.left_shift", "keras.ops.numpy.left_shift"]) +def left_shift(x, y): + """Shift the bits of an integer to the left. + + Bits are shifted to the left by appending `y` 0s at the right of `x`. + Since the internal representation of numbers is in binary format, this + operation is equivalent to multiplying `x` by `2**y`. + + Args: + x: Input integer tensor. + y: Input integer tensor. + + Returns: + Result tensor. + """ + if any_symbolic_tensors((x, y)): + return LeftShift().symbolic_call(x, y) + return backend.numpy.left_shift(x, y) + + +class BitwiseRightShift(Operation): + def call(self, x, y): + return backend.numpy.bitwise_right_shift(x, y) + + def compute_output_spec(self, x, y): + if isinstance(y, int): + dtype = x.dtype + else: + dtype = dtypes.result_type(x.dtype, y.dtype) + return KerasTensor(x.shape, dtype=dtype) + + +@keras_export( + ["keras.ops.bitwise_right_shift", "keras.ops.numpy.bitwise_right_shift"] +) +def bitwise_right_shift(x, y): + """Shift the bits of an integer to the right. + + Bits are shifted to the right `y`. Because the internal representation of + numbers is in binary format, this operation is equivalent to dividing `x` by + `2**y`. + + Args: + x: Input integer tensor. + y: Input integer tensor. + + Returns: + Result tensor. + """ + if any_symbolic_tensors((x, y)): + return BitwiseRightShift().symbolic_call(x, y) + return backend.numpy.bitwise_right_shift(x, y) + + +class RightShift(Operation): + def call(self, x, y): + return backend.numpy.right_shift(x, y) + + def compute_output_spec(self, x, y): + if isinstance(y, int): + dtype = x.dtype + else: + dtype = dtypes.result_type(x.dtype, y.dtype) + return KerasTensor(x.shape, dtype=dtype) + + +@keras_export(["keras.ops.right_shift", "keras.ops.numpy.right_shift"]) +def right_shift(x, y): + """Shift the bits of an integer to the right. + + Bits are shifted to the right `y`. Because the internal representation of + numbers is in binary format, this operation is equivalent to dividing `x` by + `2**y`. + + Args: + x: Input integer tensor. + y: Input integer tensor. + + Returns: + Result tensor. + """ + if any_symbolic_tensors((x, y)): + return RightShift().symbolic_call(x, y) + return backend.numpy.right_shift(x, y) + + +class Blackman(Operation): + def call(self, x): + return backend.numpy.blackman(x) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype=backend.floatx()) + + +@keras_export(["keras.ops.blackman", "keras.ops.numpy.blackman"]) +def blackman(x): + """Blackman window function. + The Blackman window is a taper formed by using a weighted cosine. + + Args: + x: Scalar or 1D Tensor. Window length. + + Returns: + A 1D tensor containing the Blackman window values. + + Example: + >>> x = keras.ops.convert_to_tensor(5) + >>> keras.ops.blackman(x) + array([-1.3877788e-17, 3.4000000e-01, 1.0000000e+00, 3.4000000e-01, + -1.3877788e-17], dtype=float32) + """ + if any_symbolic_tensors((x,)): + return Blackman().symbolic_call(x) + return backend.numpy.blackman(x) + + +class BroadcastTo(Operation): + def __init__(self, shape, *, name=None): + super().__init__(name=name) + self.shape = shape + + def call(self, x): + return backend.numpy.broadcast_to(x, self.shape) + + def compute_output_spec(self, x): + # Catch broadcasting errors for clear error messages. + broadcast_shapes(x.shape, self.shape) + return KerasTensor(self.shape, dtype=x.dtype) + + +@keras_export( + [ + "keras.ops.broadcast_to", + "keras.ops.numpy.broadcast_to", + ] +) +def broadcast_to(x, shape): + """Broadcast a tensor to a new shape. + + Args: + x: The tensor to broadcast. + shape: The shape of the desired tensor. A single integer `i` is + interpreted as `(i,)`. + + Returns: + A tensor with the desired shape. + + Examples: + >>> x = keras.ops.array([1, 2, 3]) + >>> keras.ops.broadcast_to(x, (3, 3)) + array([[1, 2, 3], + [1, 2, 3], + [1, 2, 3]]) + """ + if any_symbolic_tensors((x,)): + return BroadcastTo(shape=shape).symbolic_call(x) + return backend.numpy.broadcast_to(x, shape) + + +class Cbrt(Operation): + def call(self, x): + return backend.numpy.cbrt(x) + + def compute_output_spec(self, x): + dtype = backend.standardize_dtype(x.dtype) + if dtype in [ + "bool", + "int8", + "int16", + "int32", + "uint8", + "uint16", + "uint32", + ]: + dtype = backend.floatx() + elif dtype == "int64": + dtype = "float64" + + return KerasTensor(x.shape, dtype=dtype) + + +@keras_export(["keras.ops.cbrt", "keras.ops.numpy.cbrt"]) +def cbrt(x): + """Computes the cube root of the input tensor, element-wise. + + This operation returns the real-valued cube root of `x`, handling + negative numbers properly in the real domain. + + Args: + x: Input tensor. + + Returns: + A tensor containing the cube root of each element in `x`. + """ + if any_symbolic_tensors((x,)): + return Cbrt().symbolic_call(x) + return backend.numpy.cbrt(x) + + +class Ceil(Operation): + def call(self, x): + return backend.numpy.ceil(x) + + def compute_output_spec(self, x): + if backend.standardize_dtype(x.dtype) == "int64": + dtype = backend.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + sparse = getattr(x, "sparse", False) + return KerasTensor(x.shape, dtype=dtype, sparse=sparse) + + +@keras_export(["keras.ops.ceil", "keras.ops.numpy.ceil"]) +def ceil(x): + """Return the ceiling of the input, element-wise. + + The ceil of the scalar `x` is the smallest integer `i`, such that + `i >= x`. + + Args: + x: Input tensor. + + Returns: + The ceiling of each element in `x`, with float dtype. + """ + if any_symbolic_tensors((x,)): + return Ceil().symbolic_call(x) + return backend.numpy.ceil(x) + + +class Clip(Operation): + def __init__(self, x_min, x_max, *, name=None): + super().__init__(name=name) + self.x_min = x_min + self.x_max = x_max + + def call(self, x): + return backend.numpy.clip(x, self.x_min, self.x_max) + + def compute_output_spec(self, x): + dtype = backend.standardize_dtype(x.dtype) + if dtype == "bool": + dtype = "int32" + return KerasTensor(x.shape, dtype=dtype) + + +@keras_export(["keras.ops.clip", "keras.ops.numpy.clip"]) +def clip(x, x_min, x_max): + """Clip (limit) the values in a tensor. + + Given an interval, values outside the interval are clipped to the + interval edges. For example, if an interval of `[0, 1]` is specified, + values smaller than 0 become 0, and values larger than 1 become 1. + + Args: + x: Input tensor. + x_min: Minimum value. + x_max: Maximum value. + Returns: + The clipped tensor. + """ + if any_symbolic_tensors((x,)): + return Clip(x_min, x_max).symbolic_call(x) + return backend.numpy.clip(x, x_min, x_max) + + +class Concatenate(Operation): + def __init__(self, axis=0, *, name=None): + super().__init__(name=name) + if axis is None: + raise ValueError("`axis` cannot be None for `concatenate`.") + self.axis = axis + + def call(self, xs): + return backend.numpy.concatenate(xs, axis=self.axis) + + def compute_output_spec(self, xs): + first_shape = xs[0].shape + total_size_on_axis = 0 + all_sparse = True + dtypes_to_resolve = [] + for x in xs: + if not shape_equal( + x.shape, first_shape, axis=[self.axis], allow_none=True + ): + raise ValueError( + "Every value in `xs` must have the same shape except on " + f"the `axis` dim. But found element of shape {x.shape}, " + f"which is different from the first element's " + f"shape {first_shape}." + ) + if total_size_on_axis is None or x.shape[self.axis] is None: + total_size_on_axis = None + else: + total_size_on_axis += x.shape[self.axis] + all_sparse = all_sparse and getattr(x, "sparse", False) + dtypes_to_resolve.append(getattr(x, "dtype", type(x))) + output_shape = list(first_shape) + output_shape[self.axis] = total_size_on_axis + dtype = dtypes.result_type(*dtypes_to_resolve) + return KerasTensor(output_shape, dtype=dtype, sparse=all_sparse) + + +@keras_export( + [ + "keras.ops.concatenate", + "keras.ops.numpy.concatenate", + ] +) +def concatenate(xs, axis=0): + """Join a sequence of tensors along an existing axis. + + Args: + xs: The sequence of tensors to concatenate. + axis: The axis along which the tensors will be joined. Defaults to `0`. + + Returns: + The concatenated tensor. + """ + if any_symbolic_tensors(xs): + return Concatenate(axis=axis).symbolic_call(xs) + return backend.numpy.concatenate(xs, axis=axis) + + +class Conjugate(Operation): + def call(self, x): + return backend.numpy.conjugate(x) + + def compute_output_spec(self, x): + sparse = getattr(x, "sparse", False) + return KerasTensor(x.shape, dtype=x.dtype, sparse=sparse) + + +@keras_export(["keras.ops.conjugate", "keras.ops.numpy.conjugate"]) +def conjugate(x): + """Returns the complex conjugate, element-wise. + + The complex conjugate of a complex number is obtained by changing the sign + of its imaginary part. + + `keras.ops.conj` is a shorthand for this function. + + Args: + x: Input tensor. + + Returns: + The complex conjugate of each element in `x`. + """ + if any_symbolic_tensors((x,)): + return Conjugate().symbolic_call(x) + return backend.numpy.conjugate(x) + + +class Conj(Conjugate): + pass + + +@keras_export(["keras.ops.conj", "keras.ops.numpy.conj"]) +def conj(x): + """Shorthand for `keras.ops.conjugate`.""" + return conjugate(x) + + +class Copy(Operation): + def call(self, x): + return backend.numpy.copy(x) + + def compute_output_spec(self, x): + sparse = getattr(x, "sparse", False) + return KerasTensor(x.shape, dtype=x.dtype, sparse=sparse) + + +@keras_export(["keras.ops.copy", "keras.ops.numpy.copy"]) +def copy(x): + """Returns a copy of `x`. + + Args: + x: Input tensor. + + Returns: + A copy of `x`. + """ + if any_symbolic_tensors((x,)): + return Copy().symbolic_call(x) + return backend.numpy.copy(x) + + +class Cos(Operation): + def call(self, x): + return backend.numpy.cos(x) + + def compute_output_spec(self, x): + dtype = backend.standardize_dtype(getattr(x, "dtype", backend.floatx())) + if dtype == "int64": + dtype = backend.floatx() + else: + dtype = dtypes.result_type(dtype, float) + return KerasTensor(x.shape, dtype=dtype) + + +@keras_export(["keras.ops.cos", "keras.ops.numpy.cos"]) +def cos(x): + """Cosine, element-wise. + + Args: + x: Input tensor. + + Returns: + The corresponding cosine values. + """ + if any_symbolic_tensors((x,)): + return Cos().symbolic_call(x) + return backend.numpy.cos(x) + + +class Cosh(Operation): + def call(self, x): + return backend.numpy.cosh(x) + + def compute_output_spec(self, x): + dtype = backend.standardize_dtype(getattr(x, "dtype", backend.floatx())) + if dtype == "int64": + dtype = backend.floatx() + else: + dtype = dtypes.result_type(dtype, float) + return KerasTensor(x.shape, dtype=dtype) + + +@keras_export(["keras.ops.cosh", "keras.ops.numpy.cosh"]) +def cosh(x): + """Hyperbolic cosine, element-wise. + + Arguments: + x: Input tensor. + + Returns: + Output tensor of same shape as `x`. + """ + if any_symbolic_tensors((x,)): + return Cosh().symbolic_call(x) + return backend.numpy.cosh(x) + + +class CountNonzero(Operation): + def __init__(self, axis=None, *, name=None): + super().__init__(name=name) + if isinstance(axis, int): + self.axis = (axis,) + else: + self.axis = axis + + def call(self, x): + return backend.numpy.count_nonzero(x, axis=self.axis) + + def compute_output_spec(self, x): + return KerasTensor( + reduce_shape(x.shape, axis=self.axis), + dtype="int32", + ) + + +@keras_export( + [ + "keras.ops.count_nonzero", + "keras.ops.numpy.count_nonzero", + ] +) +def count_nonzero(x, axis=None): + """Counts the number of non-zero values in `x` along the given `axis`. + + If no axis is specified then all non-zeros in the tensor are counted. + + Args: + x: Input tensor. + axis: Axis or tuple of axes along which to count the number of + non-zeros. Defaults to `None`. + + Returns: + int or tensor of ints. + + Examples: + >>> x = keras.ops.array([[0, 1, 7, 0], [3, 0, 2, 19]]) + >>> keras.ops.count_nonzero(x) + 5 + >>> keras.ops.count_nonzero(x, axis=0) + array([1, 1, 2, 1], dtype=int64) + >>> keras.ops.count_nonzero(x, axis=1) + array([2, 3], dtype=int64) + """ + if any_symbolic_tensors((x,)): + return CountNonzero(axis=axis).symbolic_call(x) + return backend.numpy.count_nonzero(x, axis=axis) + + +class Cross(Operation): + def __init__(self, axisa=-1, axisb=-1, axisc=-1, axis=None, *, name=None): + super().__init__(name=name) + if axis is not None: + self.axisa = axis + self.axisb = axis + self.axisc = axis + else: + self.axisa = axisa + self.axisb = axisb + self.axisc = axisc + + def call(self, x1, x2): + return backend.numpy.cross(x1, x2, self.axisa, self.axisb, self.axisc) + + def compute_output_spec(self, x1, x2): + x1_shape = list(x1.shape) + x2_shape = list(x2.shape) + + x1_value_size = x1_shape[self.axisa] + x2_value_size = x2_shape[self.axisa] + del x1_shape[self.axisa] + del x2_shape[self.axisb] + output_shape = broadcast_shapes(x1_shape, x2_shape) + + if x1_value_size is not None and x1_value_size not in (2, 3): + raise ValueError( + "`x1`'s dim on `axis={axisa}` must be either 2 or 3, but " + f"received: {x1_value_size}" + ) + if x2_value_size is not None and x2_value_size not in (2, 3): + raise ValueError( + "`x2`'s dim on `axis={axisb}` must be either 2 or 3, but " + f"received: {x2_value_size}" + ) + + if x1_value_size == 3 or x2_value_size == 3: + value_size = [3] + else: + value_size = [] + + output_shape = ( + output_shape[: self.axisc] + value_size + output_shape[self.axisc :] + ) + + dtype = dtypes.result_type(x1.dtype, x2.dtype) + return KerasTensor(output_shape, dtype=dtype) + + +@keras_export(["keras.ops.cross", "keras.ops.numpy.cross"]) +def cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=None): + """Returns the cross product of two (arrays of) vectors. + + The cross product of `x1` and `x2` in R^3 is a vector + perpendicular to both `x1` and `x2`. If `x1` and `x2` are arrays of + vectors, the vectors are defined by the last axis of `x1` and `x2` + by default, and these axes can have dimensions 2 or 3. + + Where the dimension of either `x1` or `x2` is 2, the third component of + the input vector is assumed to be zero and the cross product calculated + accordingly. + + In cases where both input vectors have dimension 2, the z-component of + the cross product is returned. + + Args: + x1: Components of the first vector(s). + x2: Components of the second vector(s). + axisa: Axis of `x1` that defines the vector(s). Defaults to `-1`. + axisb: Axis of `x2` that defines the vector(s). Defaults to `-1`. + axisc: Axis of the result containing the cross product vector(s). + Ignored if both input vectors have dimension 2, as the return is + scalar. By default, the last axis. + axis: If defined, the axis of `x1`, `x2` and the result that + defines the vector(s) and cross product(s). Overrides `axisa`, + `axisb` and `axisc`. + + Note: + Torch backend does not support two dimensional vectors, or the + arguments `axisa`, `axisb` and `axisc`. Use `axis` instead. + + Returns: + Vector cross product(s). + """ + if any_symbolic_tensors((x1, x2)): + return Cross( + axisa=axisa, axisb=axisb, axisc=axisc, axis=axis + ).symbolic_call(x1, x2) + return backend.numpy.cross( + x1, + x2, + axisa=axisa, + axisb=axisb, + axisc=axisc, + axis=axis, + ) + + +class Cumprod(Operation): + def __init__(self, axis=None, dtype=None, *, name=None): + super().__init__(name=name) + self.axis = axis + self.dtype = None if dtype is None else backend.standardize_dtype(dtype) + + def call(self, x): + return backend.numpy.cumprod(x, axis=self.axis, dtype=self.dtype) + + def compute_output_spec(self, x): + if self.axis is None: + if None in x.shape: + output_shape = (None,) + else: + output_shape = (int(np.prod(x.shape)),) + else: + output_shape = x.shape + output_dtype = ( + backend.standardize_dtype(x.dtype) + if self.dtype is None + else self.dtype + ) + if output_dtype == "bool": + output_dtype = "int32" + return KerasTensor(output_shape, output_dtype) + + +@keras_export(["keras.ops.cumprod", "keras.ops.numpy.cumprod"]) +def cumprod(x, axis=None, dtype=None): + """Return the cumulative product of elements along a given axis. + + Args: + x: Input tensor. + axis: Axis along which the cumulative product is computed. + By default the input is flattened. + dtype: dtype of returned tensor. Defaults to x.dtype. + + Returns: + Output tensor. + """ + return Cumprod(axis=axis, dtype=dtype)(x) + + +class Cumsum(Operation): + def __init__(self, axis=None, dtype=None, *, name=None): + super().__init__(name=name) + self.axis = axis + self.dtype = None if dtype is None else backend.standardize_dtype(dtype) + + def call(self, x): + return backend.numpy.cumsum(x, axis=self.axis, dtype=self.dtype) + + def compute_output_spec(self, x): + if self.axis is None: + if None in x.shape: + output_shape = (None,) + else: + output_shape = (int(np.prod(x.shape)),) + else: + output_shape = x.shape + output_dtype = ( + backend.standardize_dtype(x.dtype) + if self.dtype is None + else self.dtype + ) + if output_dtype == "bool": + output_dtype = "int32" + return KerasTensor(output_shape, output_dtype) + + +@keras_export(["keras.ops.cumsum", "keras.ops.numpy.cumsum"]) +def cumsum(x, axis=None, dtype=None): + """Returns the cumulative sum of elements along a given axis. + + Args: + x: Input tensor. + axis: Axis along which the cumulative sum is computed. + By default the input is flattened. + dtype: dtype of returned tensor. Defaults to x.dtype. + + Returns: + Output tensor. + """ + return Cumsum(axis=axis, dtype=dtype)(x) + + +class Deg2rad(Operation): + def call(self, x): + return backend.numpy.deg2rad(x) + + def compute_output_spec(self, x): + dtype = backend.standardize_dtype(x.dtype) + if dtype in ["int64", "float64"]: + dtype = "float64" + elif dtype not in ["bfloat16", "float16"]: + dtype = backend.floatx() + return KerasTensor(x.shape, dtype) + + +@keras_export(["keras.ops.deg2rad", "keras.ops.numpy.deg2rad"]) +def deg2rad(x): + """Convert angles from degrees to radians. + + The conversion is defined as: + `rad = deg * (π / 180)` + + Args: + x: Input tensor of angles in degrees. + + Returns: + A tensor containing angles converted to radians. + + Examples: + >>> from keras import ops + >>> ops.deg2rad(180.0) + 3.141592653589793 + >>> ops.deg2rad([0.0, 90.0, 180.0]) + array([0., 1.57079633, 3.14159265]) + """ + if any_symbolic_tensors((x,)): + return Deg2rad().symbolic_call(x) + return backend.numpy.deg2rad(x) + + +class Diag(Operation): + def __init__(self, k=0, *, name=None): + super().__init__(name=name) + self.k = k + + def call(self, x): + return backend.numpy.diag(x, k=self.k) + + def compute_output_spec(self, x): + x_shape = x.shape + if len(x_shape) == 1: + if x_shape[0] is None: + output_shape = [None, None] + else: + output_shape = [ + x_shape[0] + int(np.abs(self.k)), + x_shape[0] + int(np.abs(self.k)), + ] + elif len(x_shape) == 2: + if None in x_shape: + output_shape = [None] + else: + shorter_side = np.minimum(x_shape[0], x_shape[1]) + if self.k > 0: + remaining = x_shape[1] - self.k + else: + remaining = x_shape[0] + self.k + output_shape = [ + int(np.maximum(0, np.minimum(remaining, shorter_side))) + ] + else: + raise ValueError( + f"`x` must be 1-D or 2-D, but received shape {x.shape}." + ) + return KerasTensor(output_shape, dtype=x.dtype) + + +@keras_export(["keras.ops.diag", "keras.ops.numpy.diag"]) +def diag(x, k=0): + """Extract a diagonal or construct a diagonal array. + + Args: + x: Input tensor. If `x` is 2-D, returns the k-th diagonal of `x`. + If `x` is 1-D, return a 2-D tensor with `x` on the k-th diagonal. + k: The diagonal to consider. Defaults to `0`. Use `k > 0` for diagonals + above the main diagonal, and `k < 0` for diagonals below + the main diagonal. + + Returns: + The extracted diagonal or constructed diagonal tensor. + + Examples: + >>> from keras.src import ops + >>> x = ops.arange(9).reshape((3, 3)) + >>> x + array([[0, 1, 2], + [3, 4, 5], + [6, 7, 8]]) + + >>> ops.diag(x) + array([0, 4, 8]) + >>> ops.diag(x, k=1) + array([1, 5]) + >>> ops.diag(x, k=-1) + array([3, 7]) + + >>> ops.diag(ops.diag(x))) + array([[0, 0, 0], + [0, 4, 0], + [0, 0, 8]]) + """ + if any_symbolic_tensors((x,)): + return Diag(k=k).symbolic_call(x) + return backend.numpy.diag(x, k=k) + + +class Diagflat(Operation): + def __init__(self, k=0, *, name=None): + super().__init__(name=name) + self.k = k + + def call(self, x): + return backend.numpy.diagflat(x, k=self.k) + + def compute_output_spec(self, x): + x_shape = x.shape + + if len(x_shape) == 0: + flat_size = 1 + elif len(x_shape) == 1: + flat_size = x_shape[0] if x_shape[0] is not None else None + else: + flat_size = None + for s in x_shape: + if s is None: + flat_size = None + break + elif flat_size is None: + flat_size = s + else: + flat_size *= s + + if flat_size is None: + output_shape = [None, None] + else: + output_shape = [ + flat_size + int(np.abs(self.k)), + flat_size + int(np.abs(self.k)), + ] + + return KerasTensor(output_shape, dtype=x.dtype) + + +@keras_export(["keras.ops.diagflat", "keras.ops.numpy.diagflat"]) +def diagflat(x, k=0): + """Create a two-dimensional array with the flattened input on + the k-th diagonal. + + Args: + x: Input tensor to be flattened and placed on the diagonal. + k: The diagonal to place the flattened input. Defaults to `0`. + Use `k > 0` for diagonals above the main diagonal, + and `k < 0` for diagonals below the main diagonal. + + Returns: + A 2-D tensor with the flattened input on the specified diagonal. + """ + if any_symbolic_tensors((x,)): + return Diagflat(k=k).symbolic_call(x) + return backend.numpy.diagflat(x, k=k) + + +class Diagonal(Operation): + def __init__(self, offset=0, axis1=0, axis2=1, *, name=None): + super().__init__(name=name) + self.offset = offset + self.axis1 = axis1 + self.axis2 = axis2 + + def call(self, x): + return backend.numpy.diagonal( + x, + offset=self.offset, + axis1=self.axis1, + axis2=self.axis2, + ) + + def compute_output_spec(self, x): + x_shape = list(x.shape) + if len(x_shape) < 2: + raise ValueError( + "`diagonal` requires an array of at least two dimensions, but " + f"`x` is of shape {x.shape}." + ) + + shape_2d = [x_shape[self.axis1], x_shape[self.axis2]] + x_shape[self.axis1] = -1 + x_shape[self.axis2] = -1 + output_shape = list(filter((-1).__ne__, x_shape)) + if None in shape_2d: + diag_shape = [None] + else: + shorter_side = np.minimum(shape_2d[0], shape_2d[1]) + if self.offset > 0: + remaining = shape_2d[1] - self.offset + else: + remaining = shape_2d[0] + self.offset + diag_shape = [ + int(np.maximum(0, np.minimum(remaining, shorter_side))) + ] + output_shape = output_shape + diag_shape + return KerasTensor(output_shape, dtype=x.dtype) + + +@keras_export(["keras.ops.diagonal", "keras.ops.numpy.diagonal"]) +def diagonal(x, offset=0, axis1=0, axis2=1): + """Return specified diagonals. + + If `x` is 2-D, returns the diagonal of `x` with the given offset, i.e., the + collection of elements of the form `x[i, i+offset]`. + + If `x` has more than two dimensions, the axes specified by `axis1` + and `axis2` are used to determine the 2-D sub-array whose diagonal + is returned. + + The shape of the resulting array can be determined by removing `axis1` + and `axis2` and appending an index to the right equal to the size of + the resulting diagonals. + + Args: + x: Input tensor. + offset: Offset of the diagonal from the main diagonal. + Can be positive or negative. Defaults to `0`.(main diagonal). + axis1: Axis to be used as the first axis of the 2-D sub-arrays. + Defaults to `0`.(first axis). + axis2: Axis to be used as the second axis of the 2-D sub-arrays. + Defaults to `1` (second axis). + + Returns: + Tensor of diagonals. + + Examples: + >>> from keras.src import ops + >>> x = ops.arange(4).reshape((2, 2)) + >>> x + array([[0, 1], + [2, 3]]) + >>> x.diagonal() + array([0, 3]) + >>> x.diagonal(1) + array([1]) + + >>> x = ops.arange(8).reshape((2, 2, 2)) + >>> x + array([[[0, 1], + [2, 3]], + [[4, 5], + [6, 7]]]) + >>> x.diagonal(0, 0, 1) + array([[0, 6], + [1, 7]]) + """ + if any_symbolic_tensors((x,)): + return Diagonal( + offset=offset, + axis1=axis1, + axis2=axis2, + ).symbolic_call(x) + return backend.numpy.diagonal( + x, + offset=offset, + axis1=axis1, + axis2=axis2, + ) + + +class Diff(Operation): + def __init__(self, n=1, axis=-1, *, name=None): + super().__init__(name=name) + self.n = n + self.axis = axis + + def call(self, a): + return backend.numpy.diff(a, n=self.n, axis=self.axis) + + def compute_output_spec(self, a): + shape = list(a.shape) + size = shape[self.axis] + if size is not None: + shape[self.axis] = builtins.max(size - self.n, 0) + return KerasTensor(shape, dtype=a.dtype) + + +@keras_export(["keras.ops.diff", "keras.ops.numpy.diff"]) +def diff(a, n=1, axis=-1): + """Calculate the n-th discrete difference along the given axis. + + The first difference is given by `out[i] = a[i+1] - a[i]` along + the given axis, higher differences are calculated by using `diff` + recursively. + + Args: + a: Input tensor. + n: The number of times values are differenced. Defaults to `1`. + axis: Axis to compute discrete difference(s) along. + Defaults to `-1`.(last axis). + + Returns: + Tensor of diagonals. + + Examples: + >>> from keras.src import ops + >>> x = ops.convert_to_tensor([1, 2, 4, 7, 0]) + >>> ops.diff(x) + array([ 1, 2, 3, -7]) + >>> ops.diff(x, n=2) + array([ 1, 1, -10]) + + >>> x = ops.convert_to_tensor([[1, 3, 6, 10], [0, 5, 6, 8]]) + >>> ops.diff(x) + array([[2, 3, 4], + [5, 1, 2]]) + >>> ops.diff(x, axis=0) + array([[-1, 2, 0, -2]]) + """ + return Diff(n=n, axis=axis)(a) + + +class Digitize(Operation): + def call(self, x, bins): + return backend.numpy.digitize(x, bins) + + def compute_output_spec(self, x, bins): + bins_shape = bins.shape + if len(bins_shape) > 1: + raise ValueError( + f"`bins` must be a 1D array. Received: bins={bins} " + f"with shape bins.shape={bins_shape}" + ) + sparse = getattr(x, "sparse", False) + return KerasTensor(x.shape, dtype="int32", sparse=sparse) + + +@keras_export(["keras.ops.digitize", "keras.ops.numpy.digitize"]) +def digitize(x, bins): + """Returns the indices of the bins to which each value in `x` belongs. + + Args: + x: Input array to be binned. + bins: Array of bins. It has to be one-dimensional and monotonically + increasing. + + Returns: + Output array of indices, of same shape as `x`. + + Example: + >>> x = np.array([0.0, 1.0, 3.0, 1.6]) + >>> bins = np.array([0.0, 3.0, 4.5, 7.0]) + >>> keras.ops.digitize(x, bins) + array([1, 1, 2, 1]) + """ + if any_symbolic_tensors((x, bins)): + return Digitize().symbolic_call(x, bins) + return backend.numpy.digitize(x, bins) + + +class Dot(Operation): + def call(self, x1, x2): + return backend.numpy.dot(x1, x2) + + def compute_output_spec(self, x1, x2): + x1_shape = list(getattr(x1, "shape", [])) + x2_shape = list(getattr(x2, "shape", [])) + dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + ) + if x1_shape == [] or x2_shape == []: + return multiply(x1, x2) + if len(x1_shape) == 1 and len(x2_shape) == 1: + return KerasTensor([], dtype=dtype) + if len(x2_shape) == 1: + if x1_shape[-1] != x2_shape[0]: + raise ValueError( + "Shape must match on the last axis of `x1` and `x2` when " + "`x1` is N-d array while `x2` is 1-D, but receive shape " + f"`x1.shape={x1.shape}` and x2.shape=`{x2.shape}`." + ) + return KerasTensor(x1_shape[:-1], dtype=dtype) + + if ( + x1_shape[-1] is None + or x2_shape[-2] is None + or x1_shape[-1] == x2_shape[-2] + ): + del x1_shape[-1] + del x2_shape[-2] + return KerasTensor(x1_shape + x2_shape, dtype=dtype) + + raise ValueError( + "Shape must match on the last axis of `x1` and second last " + "axis of `x2` when `x1` is N-d array while `x2` is M-D, but " + f"received `x1.shape={x1.shape}` and x2.shape=`{x2.shape}`." + ) + + +@keras_export(["keras.ops.dot", "keras.ops.numpy.dot"]) +def dot(x1, x2): + """Dot product of two tensors. + + - If both `x1` and `x2` are 1-D tensors, it is inner product of vectors + (without complex conjugation). + - If both `x1` and `x2` are 2-D tensors, it is matrix multiplication. + - If either `x1` or `x2` is 0-D (scalar), it is equivalent to `x1 * x2`. + - If `x1` is an N-D tensor and `x2` is a 1-D tensor, it is a sum product + over the last axis of `x1` and `x2`. + - If `x1` is an N-D tensor and `x2` is an M-D tensor (where `M>=2`), + it is a sum product over the last axis of `x1` and the second-to-last + axis of `x2`: `dot(x1, x2)[i,j,k,m] = sum(a[i,j,:] * b[k,:,m])`. + + Args: + x1: First argument. + x2: Second argument. + + Note: + Torch backend does not accept 0-D tensors as arguments. + + Returns: + Dot product of `x1` and `x2`. + """ + if any_symbolic_tensors((x1, x2)): + return Dot().symbolic_call(x1, x2) + return backend.numpy.dot(x1, x2) + + +class Einsum(Operation): + def __init__(self, subscripts, *, name=None): + super().__init__(name=name) + self.subscripts = subscripts + + def call(self, *operands, **kwargs): + return backend.numpy.einsum(self.subscripts, *operands, **kwargs) + + def compute_output_spec(self, *operands): + """Compute the output shape of `einsum`. + + The shape computation follows the steps below: + 1. Find all letters in the input specs (left part of "->"), and + break them into two categories: letters appearing more than once + go to `reduced_dims`, otherwise go to `kept_dims`. + 2. Adjust `reduced_dims` and `kept_dims` based on the output spec + (right part of "->"). The rule is if the letter appears in the + output spec, then move it to `kept_dims`, otherwise move it to + `reduced_dims`. + 3. Compute the target output shape. If no output spec is set, then + the target output shape will be "...{kept_dims}", e.g., "...ijk", + else it will be the same as output spec. "..." is a wildcard that + could map shape of arbitrary length. + 4. For each operand in `operands`, map the shape specified in the input + spec to the output target, e.g, if operand is of shape [2,3,4], + input spec is "i..." and output target is "i...jk", then 2 will go + the index 0. For dims not represented by any letter, insert to the + wildcard part. For each letter in output target not appearing in + input spec, the dim will be 1 for broadcasting. After 4, each + operand should have a target shape containing only number and + `None`. + 5. Broadcast all shapes computed from 4, and the result is the output + shape. + + Let's take an example to illustrate the steps above. Let's define: + ```python + x = KerasTensor([None, 3, 4]) + y = KerasTensor(2, 4, 3) + z = knp.einsum("...ij, kji->...k", x, y) + ``` + + 1. `reduced_dims` is {"i", "j"}, `kept_dims` is {"k"}. + 2. `reduced_dims` is still {"i", "j"}, and `kept_dims` is {"k"}. + 3. Output target is "...k". + 4. For `x`, the input spec is "...ij", and the output target is "...k". + "i" and "j" do not appear in the output target, so no replacement + happens, and [None] goes to wildcard. Afterwards, "k" is replaced + by 1, so we get shape [None, 1]. Applying the same logic to `y`, we + get shape [2]. + 5. Broadcast [None, 1] and [2], and we get [None, 2], which is the + output shape. + """ + split_subscripts = self.subscripts.split("->") + if len(split_subscripts) > 2: + raise ValueError( + "At most one '->' is supported in `einsum` subscripts, but " + f"received {self.subscripts}." + ) + if len(split_subscripts) == 2: + subscripts = split_subscripts[0] + output_spec = split_subscripts[1] + else: + subscripts = self.subscripts + output_spec = None + input_specs = subscripts.split(",") + if len(input_specs) != len(operands): + raise ValueError( + f"Number of operands ({len(operands)}) does not match the " + f"number of input specs ({len(input_specs)}) in `einsum`, " + f"received subscripts={self.subscripts}." + ) + reduced_dims = set() + kept_dims = set() + for s in subscripts: + if not s.isalpha(): + continue + if s not in reduced_dims and s not in kept_dims: + kept_dims.add(s) + elif s in kept_dims: + kept_dims.remove(s) + reduced_dims.add(s) + + if output_spec is not None: + # The output spec changes the rule of kept_dims and reduced_dims. + # In short, dims appearing in the output spec will be kept, and + # dims not appearing in the output spec will be reduced. + kept_dims_copy = kept_dims.copy() + reduced_dims_copy = reduced_dims.copy() + for dim in kept_dims: + if dim not in output_spec: + kept_dims_copy.remove(dim) + reduced_dims_copy.add(dim) + for dim in reduced_dims: + if dim in output_spec: + reduced_dims_copy.remove(dim) + kept_dims_copy.add(dim) + kept_dims = kept_dims_copy + reduced_dims = reduced_dims_copy + + reduced_dims = sorted(reduced_dims) + kept_dims = sorted(kept_dims) + + if output_spec is None: + target_broadcast_spec = f"...{''.join(kept_dims)}" + else: + target_broadcast_spec = output_spec + + expanded_operands_shapes = [] + for x, spec in zip(operands, input_specs): + x_shape = getattr(x, "shape", []) + x_shape = [-1 if size is None else size for size in x_shape] + split_spec = spec.split("...") + expanded_shape = target_broadcast_spec + if len(split_spec) == 1: + # In this case, the input spec is just a string of letters, + # e.g., "ijk". + if len(x_shape) != len(split_spec[0]): + raise ValueError( + "Number of dimensions in the subscript does not " + "match the number of dimensions in the operand, " + f"received subscript `{spec}` and operand of shape " + f"{x_shape}." + ) + for size, s in zip(x_shape, split_spec[0]): + # Replace the letter with the right shape. + expanded_shape = expanded_shape.replace(s, f"{str(size)} ") + expanded_shape = expanded_shape.replace("...", "") + else: + # In this case, the input spec has "...", e.g., "i...j", "i...", + # or "...j". + for i in range(len(split_spec[0])): + expanded_shape = expanded_shape.replace( + split_spec[0][i], f"{x_shape[i]} " + ) + for i in range(len(split_spec[1])): + expanded_shape = expanded_shape.replace( + split_spec[1][-i - 1], f"{x_shape[-i - 1]} " + ) + # Shape matched by "..." will be inserted to the position of + # "...". + wildcard_shape_start_index = len(split_spec[0]) + wildcard_shape_end_index = ( + len(x_shape) + if len(split_spec[1]) == 0 + else -len(split_spec[1]) + ) + wildcard_shape = x_shape[ + wildcard_shape_start_index:wildcard_shape_end_index + ] + wildcard_shape_str = ( + f"{' '.join([str(size) for size in wildcard_shape])} " + ) + expanded_shape = expanded_shape.replace( + "...", wildcard_shape_str + ) + # Replace all letters not yet handled with "1" for broadcasting. + expanded_shape = re.sub("[a-z]", "1 ", expanded_shape) + expanded_shape = expanded_shape.split() + expanded_shape = [ + None if size == "-1" else int(size) for size in expanded_shape + ] + expanded_operands_shapes.append(expanded_shape) + + output_shape = expanded_operands_shapes[0] + for shape in expanded_operands_shapes[1:]: + output_shape = broadcast_shapes(output_shape, shape) + dtypes_to_resolve = list( + set( + backend.standardize_dtype(getattr(x, "dtype", type(x))) + for x in operands + ) + ) + if len(dtypes_to_resolve) == 1 and dtypes_to_resolve[0] == "int8": + dtype = "int32" + else: + dtype = dtypes.result_type(*dtypes_to_resolve) + return KerasTensor(output_shape, dtype=dtype) + + +@keras_export(["keras.ops.einsum", "keras.ops.numpy.einsum"]) +def einsum(subscripts, *operands, **kwargs): + """Evaluates the Einstein summation convention on the operands. + + Args: + subscripts: Specifies the subscripts for summation as comma separated + list of subscript labels. An implicit (classical Einstein + summation) calculation is performed unless the explicit indicator + `->` is included as well as subscript labels of the precise + output form. + operands: The operands to compute the Einstein sum of. + + Returns: + The calculation based on the Einstein summation convention. + + Example: + >>> from keras.src import ops + >>> a = ops.arange(25).reshape(5, 5) + >>> b = ops.arange(5) + >>> c = ops.arange(6).reshape(2, 3) + + Trace of a matrix: + + >>> ops.einsum("ii", a) + 60 + >>> ops.einsum(a, [0, 0]) + 60 + >>> ops.trace(a) + 60 + + Extract the diagonal: + + >>> ops.einsum("ii -> i", a) + array([ 0, 6, 12, 18, 24]) + >>> ops.einsum(a, [0, 0], [0]) + array([ 0, 6, 12, 18, 24]) + >>> ops.diag(a) + array([ 0, 6, 12, 18, 24]) + + Sum over an axis: + + >>> ops.einsum("ij -> i", a) + array([ 10, 35, 60, 85, 110]) + >>> ops.einsum(a, [0, 1], [0]) + array([ 10, 35, 60, 85, 110]) + >>> ops.sum(a, axis=1) + array([ 10, 35, 60, 85, 110]) + + For higher dimensional tensors summing a single axis can be done + with ellipsis: + + >>> ops.einsum("...j -> ...", a) + array([ 10, 35, 60, 85, 110]) + >>> np.einsum(a, [..., 1], [...]) + array([ 10, 35, 60, 85, 110]) + + Compute a matrix transpose or reorder any number of axes: + + >>> ops.einsum("ji", c) + array([[0, 3], + [1, 4], + [2, 5]]) + >>> ops.einsum("ij -> ji", c) + array([[0, 3], + [1, 4], + [2, 5]]) + >>> ops.einsum(c, [1, 0]) + array([[0, 3], + [1, 4], + [2, 5]]) + >>> ops.transpose(c) + array([[0, 3], + [1, 4], + [2, 5]]) + + Matrix vector multiplication: + + >>> ops.einsum("ij, j", a, b) + array([ 30, 80, 130, 180, 230]) + >>> ops.einsum(a, [0, 1], b, [1]) + array([ 30, 80, 130, 180, 230]) + >>> ops.einsum("...j, j", a, b) + array([ 30, 80, 130, 180, 230]) + """ + if any_symbolic_tensors(operands): + return Einsum(subscripts).symbolic_call(*operands, **kwargs) + return backend.numpy.einsum(subscripts, *operands, **kwargs) + + +@keras_export(["keras.ops.empty", "keras.ops.numpy.empty"]) +def empty(shape, dtype=None): + """Return a tensor of given shape and type filled with uninitialized data. + + Args: + shape: Shape of the empty tensor. + dtype: Desired data type of the empty tensor. + + Returns: + The empty tensor. + """ + return backend.numpy.empty(shape, dtype=dtype) + + +class Equal(Operation): + def call(self, x1, x2): + return backend.numpy.equal(x1, x2) + + def compute_output_spec(self, x1, x2): + x1_shape = getattr(x1, "shape", []) + x2_shape = getattr(x2, "shape", []) + output_shape = broadcast_shapes(x1_shape, x2_shape) + return KerasTensor(output_shape, dtype="bool") + + +@keras_export(["keras.ops.equal", "keras.ops.numpy.equal"]) +def equal(x1, x2): + """Returns `(x1 == x2)` element-wise. + + Args: + x1: Tensor to compare. + x2: Tensor to compare. + + Returns: + Output tensor, element-wise comparison of `x1` and `x2`. + """ + if any_symbolic_tensors((x1, x2)): + return Equal().symbolic_call(x1, x2) + return backend.numpy.equal(x1, x2) + + +class Exp(Operation): + def call(self, x): + return backend.numpy.exp(x) + + def compute_output_spec(self, x): + dtype = backend.standardize_dtype(x.dtype) + if "int" in dtype or dtype == "bool": + dtype = backend.floatx() + return KerasTensor(x.shape, dtype=dtype) + + +@keras_export(["keras.ops.exp", "keras.ops.numpy.exp"]) +def exp(x): + """Calculate the exponential of all elements in the input tensor. + + Args: + x: Input tensor. + + Returns: + Output tensor, element-wise exponential of `x`. + """ + if any_symbolic_tensors((x,)): + return Exp().symbolic_call(x) + return backend.numpy.exp(x) + + +class Exp2(Operation): + def call(self, x): + return backend.numpy.exp2(x) + + def compute_output_spec(self, x): + dtype = backend.standardize_dtype(x.dtype) + if "int" in dtype or dtype == "bool": + dtype = backend.floatx() + return KerasTensor(x.shape, dtype=dtype) + + +@keras_export(["keras.ops.exp2", "keras.ops.numpy.exp2"]) +def exp2(x): + """Calculate the base-2 exponential of all elements in the input tensor. + + Args: + x: Input tensor. + + Returns: + Output tensor, element-wise base-2 exponential of `x`. + """ + if any_symbolic_tensors((x,)): + return Exp2().symbolic_call(x) + return backend.numpy.exp2(x) + + +class ExpandDims(Operation): + def __init__(self, axis, *, name=None): + super().__init__(name=name) + if not isinstance(axis, (int, tuple, list)): + raise ValueError( + "The `axis` argument to `expand_dims` should be an integer, " + f"tuple or list. Received axis={axis}" + ) + self.axis = axis + + def call(self, x): + return backend.numpy.expand_dims(x, self.axis) + + def compute_output_spec(self, x): + output_shape = operation_utils.compute_expand_dims_output_shape( + x.shape, self.axis + ) + sparse = getattr(x, "sparse", False) + return KerasTensor(output_shape, dtype=x.dtype, sparse=sparse) + + +@keras_export( + [ + "keras.ops.expand_dims", + "keras.ops.numpy.expand_dims", + ] +) +def expand_dims(x, axis): + """Expand the shape of a tensor. + + Insert a new axis at the `axis` position in the expanded tensor shape. + + Args: + x: Input tensor. + axis: Position in the expanded axes where the new axis + (or axes) is placed. + + Returns: + Output tensor with the number of dimensions increased. + """ + if any_symbolic_tensors((x,)): + return ExpandDims(axis=axis).symbolic_call(x) + return backend.numpy.expand_dims(x, axis) + + +class Expm1(Operation): + def call(self, x): + return backend.numpy.expm1(x) + + def compute_output_spec(self, x): + dtype = backend.standardize_dtype(x.dtype) + if "int" in dtype or dtype == "bool": + dtype = backend.floatx() + sparse = getattr(x, "sparse", False) + return KerasTensor(x.shape, dtype=dtype, sparse=sparse) + + +@keras_export(["keras.ops.expm1", "keras.ops.numpy.expm1"]) +def expm1(x): + """Calculate `exp(x) - 1` for all elements in the tensor. + + Args: + x: Input values. + + Returns: + Output tensor, element-wise exponential minus one. + """ + if any_symbolic_tensors((x,)): + return Expm1().symbolic_call(x) + return backend.numpy.expm1(x) + + +class Flip(Operation): + def __init__(self, axis=None, *, name=None): + super().__init__(name=name) + self.axis = axis + + def call(self, x): + return backend.numpy.flip(x, axis=self.axis) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype=x.dtype) + + +@keras_export(["keras.ops.flip", "keras.ops.numpy.flip"]) +def flip(x, axis=None): + """Reverse the order of elements in the tensor along the given axis. + + The shape of the tensor is preserved, but the elements are reordered. + + Args: + x: Input tensor. + axis: Axis or axes along which to flip the tensor. The default, + `axis=None`, will flip over all of the axes of the input tensor. + + Returns: + Output tensor with entries of `axis` reversed. + """ + if any_symbolic_tensors((x,)): + return Flip(axis=axis).symbolic_call(x) + return backend.numpy.flip(x, axis=axis) + + +class Floor(Operation): + def call(self, x): + return backend.numpy.floor(x) + + def compute_output_spec(self, x): + sparse = getattr(x, "sparse", False) + dtype = ( + backend.floatx() + if backend.standardize_dtype(x.dtype) == "int64" + else dtypes.result_type(x.dtype, float) + ) + return KerasTensor(x.shape, dtype=dtype, sparse=sparse) + + +@keras_export(["keras.ops.floor", "keras.ops.numpy.floor"]) +def floor(x): + """Return the floor of the input, element-wise. + + The floor of the scalar `x` is the largest integer `i`, such that `i <= x`. + + Args: + x: Input tensor. + + Returns: + Output tensor, element-wise floor of `x`. + """ + if any_symbolic_tensors((x,)): + return Floor().symbolic_call(x) + return backend.numpy.floor(x) + + +class Full(Operation): + def __init__(self, shape, dtype=None, *, name=None): + super().__init__(name=name) + self.shape = shape + self.dtype = None if dtype is None else backend.standardize_dtype(dtype) + + def call(self, fill_value): + return backend.numpy.full(self.shape, fill_value, dtype=self.dtype) + + def compute_output_spec(self, fill_value): + dtype = backend.floatx() if self.dtype is None else self.dtype + return KerasTensor(self.shape, dtype=dtype) + + +@keras_export(["keras.ops.full", "keras.ops.numpy.full"]) +def full(shape, fill_value, dtype=None): + """Return a new tensor of given shape and type, filled with `fill_value`. + + Args: + shape: Shape of the new tensor. + fill_value: Fill value. + dtype: Desired data type of the tensor. + + Returns: + Output tensor. + """ + if any_symbolic_tensors((fill_value,)): + return Full(shape=shape, dtype=dtype).symbolic_call(fill_value) + return backend.numpy.full(shape, fill_value, dtype=dtype) + + +class FullLike(Operation): + def __init__(self, dtype=None, *, name=None): + super().__init__(name=name) + self.dtype = None if dtype is None else backend.standardize_dtype(dtype) + + def call(self, x, fill_value): + return backend.numpy.full_like(x, fill_value, dtype=self.dtype) + + def compute_output_spec(self, x, fill_value): + dtype = ( + backend.standardize_dtype(x.dtype) + if self.dtype is None + else self.dtype + ) + return KerasTensor(x.shape, dtype=dtype) + + +@keras_export(["keras.ops.full_like", "keras.ops.numpy.full_like"]) +def full_like(x, fill_value, dtype=None): + """Return a full tensor with the same shape and type as the given tensor. + + Args: + x: Input tensor. + fill_value: Fill value. + dtype: Overrides data type of the result. + + Returns: + Tensor of `fill_value` with the same shape and type as `x`. + """ + if any_symbolic_tensors((x, fill_value)): + return FullLike(dtype=dtype).symbolic_call(x, fill_value) + return backend.numpy.full_like(x, fill_value, dtype=dtype) + + +class Gcd(Operation): + def call(self, x1, x2): + return backend.numpy.gcd(x1, x2) + + def compute_output_spec(self, x1, x2): + x1_shape = getattr(x1, "shape", []) + x2_shape = getattr(x2, "shape", []) + output_shape = broadcast_shapes(x1_shape, x2_shape) + + x1_type = backend.standardize_dtype(getattr(x1, "dtype", type(x1))) + x2_type = backend.standardize_dtype(getattr(x2, "dtype", type(x2))) + dtype = dtypes.result_type(x1_type, x2_type) + return KerasTensor(output_shape, dtype=dtype) + + +@keras_export(["keras.ops.gcd", "keras.ops.numpy.gcd"]) +def gcd(x1, x2): + """Greatest common divisor of `x1` and `x2`, element-wise. + + Args: + x1: First input tensor (integer type). + x2: Second input tensor (integer type). + + Returns: + Output tensor, element-wise greatest common divisor of `x1` and `x2`. + """ + if any_symbolic_tensors((x1, x2)): + return Gcd().symbolic_call(x1, x2) + return backend.numpy.gcd(x1, x2) + + +class GetItem(Operation): + def call(self, x, key): + if isinstance(key, list): + key = tuple(key) + return x[key] + + def compute_output_spec(self, x, key): + remaining_shape = list(x.shape) + new_shape = [] + if isinstance(key, int): + remaining_key = [key] + elif isinstance(key, tuple): + remaining_key = list(key) + elif isinstance(key, list): + remaining_key = key.copy() + else: + raise ValueError( + f"Unsupported key type for array slice. Received: `{key}`" + ) + num_ellipses = remaining_key.count(Ellipsis) + if num_ellipses > 1: + raise ValueError( + f"Slice should only have one ellipsis. Received: `{key}`" + ) + elif num_ellipses == 0: + # Add an implicit final ellipsis. + remaining_key.append(Ellipsis) + # Consume slice key element by element. + while True: + if not remaining_key: + break + subkey = remaining_key.pop(0) + # Check for `newaxis` and `Ellipsis`. + if subkey == Ellipsis: + # Keep as many slices remain in our key, omitting `newaxis`. + needed = len(remaining_key) - remaining_key.count(np.newaxis) + consumed = len(remaining_shape) - needed + new_shape += remaining_shape[:consumed] + remaining_shape = remaining_shape[consumed:] + continue + # All frameworks follow numpy for newaxis. `np.newaxis == None`. + if subkey == np.newaxis: + new_shape.append(1) + continue + # At this point, we need to consume a new axis from the shape. + if not remaining_shape: + raise ValueError( + f"Array has shape {x.shape} but slice " + f"has to many indices. Received: `{key}`" + ) + length = remaining_shape.pop(0) + if isinstance(subkey, int): + if length is not None: + index = subkey if subkey >= 0 else subkey + length + if index < 0 or index >= length: + raise ValueError( + f"Array has shape {x.shape} but out-of-bounds " + f"index {key} was requested." + ) + elif isinstance(subkey, slice): + if length is not None: + # python3 friendly way to compute a slice length. + new_length = len(range(*subkey.indices(length))) + new_shape.append(new_length) + else: + new_shape.append(length) + else: + raise ValueError( + f"Unsupported key type for array slice. Received: `{key}`" + ) + return KerasTensor(tuple(new_shape), dtype=x.dtype) + + +@keras_export(["keras.ops.get_item", "keras.ops.numpy.get_item"]) +def get_item(x, key): + """Return `x[key]`.""" + if any_symbolic_tensors((x,)): + return GetItem().symbolic_call(x, key) + return x[key] + + +class Greater(Operation): + def call(self, x1, x2): + return backend.numpy.greater(x1, x2) + + def compute_output_spec(self, x1, x2): + x1_shape = getattr(x1, "shape", []) + x2_shape = getattr(x2, "shape", []) + output_shape = broadcast_shapes(x1_shape, x2_shape) + return KerasTensor(output_shape, dtype="bool") + + +@keras_export(["keras.ops.greater", "keras.ops.numpy.greater"]) +def greater(x1, x2): + """Return the truth value of `x1 > x2` element-wise. + + Args: + x1: First input tensor. + x2: Second input tensor. + + Returns: + Output tensor, element-wise comparison of `x1` and `x2`. + """ + if any_symbolic_tensors((x1, x2)): + return Greater().symbolic_call(x1, x2) + return backend.numpy.greater(x1, x2) + + +class GreaterEqual(Operation): + def call(self, x1, x2): + return backend.numpy.greater_equal(x1, x2) + + def compute_output_spec(self, x1, x2): + x1_shape = getattr(x1, "shape", []) + x2_shape = getattr(x2, "shape", []) + output_shape = broadcast_shapes(x1_shape, x2_shape) + return KerasTensor(output_shape, dtype="bool") + + +@keras_export( + [ + "keras.ops.greater_equal", + "keras.ops.numpy.greater_equal", + ] +) +def greater_equal(x1, x2): + """Return the truth value of `x1 >= x2` element-wise. + + Args: + x1: First input tensor. + x2: Second input tensor. + + Returns: + Output tensor, element-wise comparison of `x1` and `x2`. + """ + if any_symbolic_tensors((x1, x2)): + return GreaterEqual().symbolic_call(x1, x2) + return backend.numpy.greater_equal(x1, x2) + + +class Hstack(Operation): + def call(self, xs): + return backend.numpy.hstack(xs) + + def compute_output_spec(self, xs): + first_shape = xs[0].shape + total_size_on_axis = 0 + dtypes_to_resolve = [] + for x in xs: + if not shape_equal(x.shape, first_shape, axis=[1], allow_none=True): + raise ValueError( + "Every value in `xs` must have the same shape except on " + f"the `axis` dim. But found element of shape {x.shape}, " + f"which is different from the first element's " + f"shape {first_shape}." + ) + if total_size_on_axis is None or x.shape[1] is None: + total_size_on_axis = None + else: + total_size_on_axis += x.shape[1] + dtypes_to_resolve.append(getattr(x, "dtype", type(x))) + output_shape = list(first_shape) + output_shape[1] = total_size_on_axis + dtype = dtypes.result_type(*dtypes_to_resolve) + return KerasTensor(output_shape, dtype=dtype) + + +@keras_export(["keras.ops.hstack", "keras.ops.numpy.hstack"]) +def hstack(xs): + """Stack tensors in sequence horizontally (column wise). + + This is equivalent to concatenation along the first axis for 1-D tensors, + and along the second axis for all other tensors. + + Args: + xs: Sequence of tensors. + + Returns: + The tensor formed by stacking the given tensors. + """ + if any_symbolic_tensors((xs,)): + return Hstack().symbolic_call(xs) + return backend.numpy.hstack(xs) + + +class Hypot(Operation): + def call(self, x1, x2): + return backend.numpy.hypot(x1, x2) + + def compute_output_spec(self, x1, x2): + dtype = dtypes.result_type(x1.dtype, x2.dtype) + if dtype in ["int8", "int16", "int32", "uint8", "uint16", "uint32"]: + dtype = backend.floatx() + elif dtype == "int64": + dtype = "float64" + return KerasTensor(broadcast_shapes(x1.shape, x2.shape), dtype=dtype) + + +@keras_export(["keras.ops.hypot", "keras.ops.numpy.hypot"]) +def hypot(x1, x2): + """Element-wise hypotenuse of right triangles with legs `x1` and `x2`. + + This is equivalent to computing `sqrt(x1**2 + x2**2)` element-wise, + with shape determined by broadcasting. + + Args: + x1: A tensor, representing the first leg of the right triangle. + x2: A tensor, representing the second leg of the right triangle. + + Returns: + A tensor with a shape determined by broadcasting `x1` and `x2`. + + Example: + >>> x1 = keras.ops.convert_to_tensor([3.0, 4.0, 5.0]) + >>> x2 = keras.ops.convert_to_tensor([4.0, 3.0, 12.0]) + >>> keras.ops.hypot(x1, x2) + array([5., 5., 13.], dtype=float32) + + >>> x1 = keras.ops.convert_to_tensor([[1, 2], [3, 4]]) + >>> x2 = keras.ops.convert_to_tensor([1, 1]) + >>> keras.ops.hypot(x1, x2) + array([[1.41421356 2.23606798], + [3.16227766 4.12310563]], dtype=float32) + """ + if any_symbolic_tensors((x1, x2)): + return Hypot().symbolic_call(x1, x2) + return backend.numpy.hypot(x1, x2) + + +@keras_export(["keras.ops.identity", "keras.ops.numpy.identity"]) +def identity(n, dtype=None): + """Return the identity tensor. + + The identity tensor is a square tensor with ones on the main diagonal and + zeros elsewhere. + + Args: + n: Number of rows (and columns) in the `n x n` output tensor. + dtype: Data type of the output tensor. + + Returns: + The identity tensor. + """ + return backend.numpy.identity(n, dtype=dtype) + + +class Imag(Operation): + def call(self, x): + return backend.numpy.imag(x) + + def compute_output_spec(self, x): + sparse = getattr(x, "sparse", False) + return KerasTensor(x.shape, dtype=x.dtype, sparse=sparse) + + +@keras_export(["keras.ops.imag", "keras.ops.numpy.imag"]) +def imag(x): + """Return the imaginary part of the complex argument. + + Args: + x: Input tensor. + + Returns: + The imaginary component of the complex argument. + """ + if any_symbolic_tensors((x,)): + return Imag().symbolic_call(x) + return backend.numpy.imag(x) + + +class Isclose(Operation): + def __init__(self, equal_nan=False, *, name=None): + super().__init__(name=name) + self.equal_nan = equal_nan + + def call(self, x1, x2, rtol=1e-5, atol=1e-8): + return backend.numpy.isclose(x1, x2, rtol, atol, self.equal_nan) + + def compute_output_spec(self, x1, x2, rtol=1e-5, atol=1e-8): + x1_shape = getattr(x1, "shape", []) + x2_shape = getattr(x2, "shape", []) + output_shape = broadcast_shapes(x1_shape, x2_shape) + return KerasTensor(output_shape, dtype="bool") + + +@keras_export(["keras.ops.isclose", "keras.ops.numpy.isclose"]) +def isclose(x1, x2, rtol=1e-5, atol=1e-8, equal_nan=False): + """Return whether two tensors are element-wise almost equal. + + Args: + x1: First input tensor. + x2: Second input tensor. + rtol: Relative tolerance. + atol: Absolute tolerance. + equal_nan: If `True`, element-wise NaNs are considered equal. + + Returns: + Output boolean tensor. + """ + if any_symbolic_tensors((x1, x2)): + return Isclose(equal_nan=equal_nan).symbolic_call(x1, x2, rtol, atol) + return backend.numpy.isclose(x1, x2, rtol, atol, equal_nan) + + +class Isfinite(Operation): + def call(self, x): + return backend.numpy.isfinite(x) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype="bool") + + +@keras_export(["keras.ops.isfinite", "keras.ops.numpy.isfinite"]) +def isfinite(x): + """Return whether a tensor is finite, element-wise. + + Real values are finite when they are not NaN, not positive infinity, and + not negative infinity. Complex values are finite when both their real + and imaginary parts are finite. + + Args: + x: Input tensor. + + Returns: + Output boolean tensor. + """ + if any_symbolic_tensors((x,)): + return Isfinite().symbolic_call(x) + return backend.numpy.isfinite(x) + + +class IsIn(Operation): + def __init__( + self, + assume_unique=False, + invert=False, + *, + name=None, + ): + super().__init__(name=name) + self.assume_unique = assume_unique + self.invert = invert + + def call(self, x1, x2): + return backend.numpy.isin( + x1, x2, assume_unique=self.assume_unique, invert=self.invert + ) + + def compute_output_spec(self, x1, x2): + return KerasTensor(x1.shape, dtype="bool") + + +@keras_export(["keras.ops.isin", "keras.ops.numpy.isin"]) +def isin(x1, x2, assume_unique=False, invert=False): + """Test whether each element of `x1` is present in `x2`. + + This operation performs element-wise checks to determine if each value + in `x1` is contained within `x2`. The result is a boolean tensor with + the same shape as `x1`, where each entry is `True` if the corresponding + element in `x1` is in `x2`, and `False` otherwise. + + Args: + x1: Input tensor or array-like structure to test. + x2: Values against which each element of `x1` is tested. + Can be a tensor, list, or scalar. + assume_unique: Boolean (default: False). + If True, assumes both `x1` and `x2` contain only unique elements. + This can speed up the computation. If False, duplicates will be + handled correctly but may impact performance. + invert: A boolean (default: False). + If True, inverts the result. Entries will be `True` + where `x1` elements are not in `x2`. + + Returns: + A boolean tensor of the same shape as `x1` indicating element-wise + membership in `x2`. + + Example: + >>> from keras import ops + >>> x1 = ops.array([0, 1, 2, 5]) + >>> x2 = ops.array([0, 2]) + >>> result = ops.isin(x1, x2) + array([ True, False, True, False]) + """ + if any_symbolic_tensors((x1, x2)): + return IsIn(assume_unique=assume_unique, invert=invert).symbolic_call( + x1, x2 + ) + return backend.numpy.isin( + x1, x2, assume_unique=assume_unique, invert=invert + ) + + +class Isinf(Operation): + def call(self, x): + return backend.numpy.isinf(x) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype="bool") + + +@keras_export(["keras.ops.isinf", "keras.ops.numpy.isinf"]) +def isinf(x): + """Test element-wise for positive or negative infinity. + + Args: + x: Input tensor. + + Returns: + Output boolean tensor. + """ + if any_symbolic_tensors((x,)): + return Isinf().symbolic_call(x) + return backend.numpy.isinf(x) + + +class Isnan(Operation): + def call(self, x): + return backend.numpy.isnan(x) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype="bool") + + +@keras_export(["keras.ops.isnan", "keras.ops.numpy.isnan"]) +def isnan(x): + """Test element-wise for NaN and return result as a boolean tensor. + + Args: + x: Input tensor. + + Returns: + Output boolean tensor. + """ + if any_symbolic_tensors((x,)): + return Isnan().symbolic_call(x) + return backend.numpy.isnan(x) + + +class Isneginf(Operation): + def call(self, x): + return backend.numpy.isneginf(x) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype="bool") + + +@keras_export(["keras.ops.isneginf", "keras.ops.numpy.isneginf"]) +def isneginf(x): + """Test element-wise for negative infinity. + + Args: + x: Input tensor. + + Returns: + Output boolean tensor. + """ + if any_symbolic_tensors((x,)): + return Isneginf().symbolic_call(x) + return backend.numpy.isneginf(x) + + +class Isposinf(Operation): + def call(self, x): + return backend.numpy.isposinf(x) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype="bool") + + +@keras_export(["keras.ops.isposinf", "keras.ops.numpy.isposinf"]) +def isposinf(x): + """Test element-wise for positive infinity. + + Args: + x: Input tensor. + + Returns: + Output boolean tensor. + """ + if any_symbolic_tensors((x,)): + return Isposinf().symbolic_call(x) + return backend.numpy.isposinf(x) + + +class Kron(Operation): + def call(self, x1, x2): + return backend.numpy.kron(x1, x2) + + def compute_output_spec(self, x1, x2): + x1_shape = getattr(x1, "shape", []) + x2_shape = getattr(x2, "shape", []) + + def _mul_shape_dim(a, b): + if a is None or b is None: + return None + return a * b + + output_shape = tuple( + _mul_shape_dim(a, b) for a, b in zip(x1_shape, x2_shape) + ) + + x1_type = backend.standardize_dtype(getattr(x1, "dtype", type(x1))) + x2_type = backend.standardize_dtype(getattr(x2, "dtype", type(x2))) + dtype = dtypes.result_type(x1_type, x2_type) + return KerasTensor(output_shape, dtype=dtype) + + +@keras_export(["keras.ops.kron", "keras.ops.numpy.kron"]) +def kron(x1, x2): + """Kronecker product of `x1` and `x2`. + + Computes the Kronecker product of two input tensors. If `x1` has shape + `(a0, a1, ..., an)` and `x2` has shape `(b0, b1, ..., bn)`, then the + output will have shape `(a0*b0, a1*b1, ..., an*bn)`. + + Args: + x1: First input tensor. + x2: Second input tensor. + + Returns: + A tensor representing the Kronecker product of `x1` and `x2`. + """ + if any_symbolic_tensors((x1, x2)): + return Kron().symbolic_call(x1, x2) + return backend.numpy.kron(x1, x2) + + +class Lcm(Operation): + def call(self, x1, x2): + return backend.numpy.lcm(x1, x2) + + def compute_output_spec(self, x1, x2): + x1_shape = getattr(x1, "shape", []) + x2_shape = getattr(x2, "shape", []) + output_shape = broadcast_shapes(x1_shape, x2_shape) + + x1_type = backend.standardize_dtype(getattr(x1, "dtype", type(x1))) + x2_type = backend.standardize_dtype(getattr(x2, "dtype", type(x2))) + dtype = dtypes.result_type(x1_type, x2_type) + return KerasTensor(output_shape, dtype=dtype) + + +@keras_export(["keras.ops.lcm", "keras.ops.numpy.lcm"]) +def lcm(x1, x2): + """Least common multiple of `x1` and `x2`, element-wise. + + Args: + x1: First input tensor (integer type). + x2: Second input tensor (integer type). + + Returns: + Output tensor, element-wise least common multiple of `x1` and `x2`. + + Example: + >>> x1 = keras.ops.convert_to_tensor([2, 3, 4]) + >>> x2 = keras.ops.convert_to_tensor([5, 6, 7]) + >>> keras.ops.lcm(x1, x2) + array([10, 6, 28], dtype=int32) + """ + if any_symbolic_tensors((x1, x2)): + return Lcm().symbolic_call(x1, x2) + return backend.numpy.lcm(x1, x2) + + +class Less(Operation): + def call(self, x1, x2): + return backend.numpy.less(x1, x2) + + def compute_output_spec(self, x1, x2): + x1_shape = getattr(x1, "shape", []) + x2_shape = getattr(x2, "shape", []) + output_shape = broadcast_shapes(x1_shape, x2_shape) + return KerasTensor(output_shape, dtype="bool") + + +@keras_export(["keras.ops.less", "keras.ops.numpy.less"]) +def less(x1, x2): + """Return the truth value of `x1 < x2` element-wise. + + Args: + x1: First input tensor. + x2: Second input tensor. + + Returns: + Output tensor, element-wise comparison of `x1` and `x2`. + """ + if any_symbolic_tensors((x1, x2)): + return Less().symbolic_call(x1, x2) + return backend.numpy.less(x1, x2) + + +class LessEqual(Operation): + def call(self, x1, x2): + return backend.numpy.less_equal(x1, x2) + + def compute_output_spec(self, x1, x2): + x1_shape = getattr(x1, "shape", []) + x2_shape = getattr(x2, "shape", []) + output_shape = broadcast_shapes(x1_shape, x2_shape) + return KerasTensor(output_shape, dtype="bool") + + +@keras_export( + [ + "keras.ops.less_equal", + "keras.ops.numpy.less_equal", + ] +) +def less_equal(x1, x2): + """Return the truth value of `x1 <= x2` element-wise. + + Args: + x1: First input tensor. + x2: Second input tensor. + + Returns: + Output tensor, element-wise comparison of `x1` and `x2`. + """ + if any_symbolic_tensors((x1, x2)): + return LessEqual().symbolic_call(x1, x2) + return backend.numpy.less_equal(x1, x2) + + +class Linspace(Operation): + def __init__( + self, + num=50, + endpoint=True, + retstep=False, + dtype=None, + axis=0, + *, + name=None, + ): + super().__init__(name=name) + self.num = num + self.endpoint = endpoint + self.retstep = retstep + self.dtype = dtype + self.axis = axis + + def call(self, start, stop): + return backend.numpy.linspace( + start, + stop, + num=self.num, + endpoint=self.endpoint, + retstep=self.retstep, + dtype=self.dtype, + axis=self.axis, + ) + + def compute_output_spec(self, start, stop): + start_shape = getattr(start, "shape", []) + stop_shape = getattr(stop, "shape", []) + output_shape = broadcast_shapes(start_shape, stop_shape) + if self.axis == -1: + output_shape = output_shape + [self.num] + elif self.axis >= 0: + output_shape = ( + output_shape[: self.axis] + + [self.num] + + output_shape[self.axis :] + ) + else: + output_shape = ( + output_shape[: self.axis + 1] + + [self.num] + + output_shape[self.axis + 1 :] + ) + + dtype = ( + self.dtype + if self.dtype is not None + else backend.standardize_dtype(getattr(start, "dtype", type(start))) + ) + dtype = backend.result_type(dtype, float) + if self.retstep: + return (KerasTensor(output_shape, dtype=dtype), None) + return KerasTensor(output_shape, dtype=dtype) + + +@keras_export(["keras.ops.linspace", "keras.ops.numpy.linspace"]) +def linspace( + start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0 +): + """Return evenly spaced numbers over a specified interval. + + Returns `num` evenly spaced samples, calculated over the interval + `[start, stop]`. + + The endpoint of the interval can optionally be excluded. + + Args: + start: The starting value of the sequence. + stop: The end value of the sequence, unless `endpoint` is set to + `False`. In that case, the sequence consists of all but the last + of `num + 1` evenly spaced samples, so that `stop` is excluded. + Note that the step size changes when `endpoint` is `False`. + num: Number of samples to generate. Defaults to `50`. Must be + non-negative. + endpoint: If `True`, `stop` is the last sample. Otherwise, it is + not included. Defaults to `True`. + retstep: If `True`, return `(samples, step)`, where `step` is the + spacing between samples. + dtype: The type of the output tensor. + axis: The axis in the result to store the samples. Relevant only if + start or stop are array-like. Defaults to `0`. + + Note: + Torch backend does not support `axis` argument. + + Returns: + A tensor of evenly spaced numbers. + If `retstep` is `True`, returns `(samples, step)` + """ + if any_symbolic_tensors((start, stop)): + return Linspace(num, endpoint, retstep, dtype, axis)(start, stop) + return backend.numpy.linspace( + start, + stop, + num=num, + endpoint=endpoint, + retstep=retstep, + dtype=dtype, + axis=axis, + ) + + +class Log(Operation): + def call(self, x): + return backend.numpy.log(x) + + def compute_output_spec(self, x): + dtype = ( + backend.floatx() + if backend.standardize_dtype(x.dtype) == "int64" + else dtypes.result_type(x.dtype, float) + ) + return KerasTensor(x.shape, dtype=dtype) + + +@keras_export(["keras.ops.log", "keras.ops.numpy.log"]) +def log(x): + """Natural logarithm, element-wise. + + Args: + x: Input tensor. + + Returns: + Output tensor, element-wise natural logarithm of `x`. + """ + if any_symbolic_tensors((x,)): + return Log().symbolic_call(x) + return backend.numpy.log(x) + + +class Log10(Operation): + def call(self, x): + return backend.numpy.log10(x) + + def compute_output_spec(self, x): + dtype = ( + backend.floatx() + if backend.standardize_dtype(x.dtype) == "int64" + else dtypes.result_type(x.dtype, float) + ) + return KerasTensor(x.shape, dtype=dtype) + + +@keras_export(["keras.ops.log10", "keras.ops.numpy.log10"]) +def log10(x): + """Return the base 10 logarithm of the input tensor, element-wise. + + Args: + x: Input tensor. + + Returns: + Output tensor, element-wise base 10 logarithm of `x`. + """ + if any_symbolic_tensors((x,)): + return Log10().symbolic_call(x) + return backend.numpy.log10(x) + + +class Log1p(Operation): + def call(self, x): + return backend.numpy.log1p(x) + + def compute_output_spec(self, x): + dtype = ( + backend.floatx() + if backend.standardize_dtype(x.dtype) == "int64" + else dtypes.result_type(x.dtype, float) + ) + sparse = getattr(x, "sparse", False) + return KerasTensor(x.shape, dtype=dtype, sparse=sparse) + + +@keras_export(["keras.ops.log1p", "keras.ops.numpy.log1p"]) +def log1p(x): + """Returns the natural logarithm of one plus the `x`, element-wise. + + Calculates `log(1 + x)`. + + Args: + x: Input tensor. + + Returns: + Output tensor, element-wise natural logarithm of `1 + x`. + """ + if any_symbolic_tensors((x,)): + return Log1p().symbolic_call(x) + return backend.numpy.log1p(x) + + +class Log2(Operation): + def call(self, x): + return backend.numpy.log2(x) + + def compute_output_spec(self, x): + dtype = ( + backend.floatx() + if backend.standardize_dtype(x.dtype) == "int64" + else dtypes.result_type(x.dtype, float) + ) + return KerasTensor(x.shape, dtype=dtype) + + +@keras_export(["keras.ops.log2", "keras.ops.numpy.log2"]) +def log2(x): + """Base-2 logarithm of `x`, element-wise. + + Args: + x: Input tensor. + + Returns: + Output tensor, element-wise base-2 logarithm of `x`. + """ + if any_symbolic_tensors((x,)): + return Log2().symbolic_call(x) + return backend.numpy.log2(x) + + +class Logaddexp(Operation): + def call(self, x1, x2): + return backend.numpy.logaddexp(x1, x2) + + def compute_output_spec(self, x1, x2): + x1_shape = getattr(x1, "shape", []) + x2_shape = getattr(x2, "shape", []) + output_shape = broadcast_shapes(x1_shape, x2_shape) + dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + float, + ) + return KerasTensor(output_shape, dtype=dtype) + + +@keras_export(["keras.ops.logaddexp", "keras.ops.numpy.logaddexp"]) +def logaddexp(x1, x2): + """Logarithm of the sum of exponentiations of the inputs. + + Calculates `log(exp(x1) + exp(x2))`. + + Args: + x1: Input tensor. + x2: Input tensor. + + Returns: + Output tensor, element-wise logarithm of the sum of exponentiations + of the inputs. + """ + if any_symbolic_tensors((x1, x2)): + return Logaddexp().symbolic_call(x1, x2) + return backend.numpy.logaddexp(x1, x2) + + +class Logaddexp2(Operation): + def call(self, x1, x2): + return backend.numpy.logaddexp2(x1, x2) + + def compute_output_spec(self, x1, x2): + x1_shape = getattr(x1, "shape", []) + x2_shape = getattr(x2, "shape", []) + output_shape = broadcast_shapes(x1_shape, x2_shape) + dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + float, + ) + return KerasTensor(output_shape, dtype=dtype) + + +@keras_export(["keras.ops.logaddexp2", "keras.ops.numpy.logaddexp2"]) +def logaddexp2(x1, x2): + """Base-2 logarithm of the sum of exponentiations of the inputs. + + Calculates `log2(2**x1 + 2**x2)`. + + Args: + x1: Input tensor. + x2: Input tensor. + + Returns: + Output tensor, element-wise log base 2 of the sum of 2**x1 and 2**x2. + + Example: + >>> from keras import ops + >>> x1 = ops.array([1, 2, 3]) + >>> x2 = ops.array([1, 2, 3]) + >>> ops.logaddexp2(x1, x2) + array([2., 3., 4.], dtype=float32) + """ + if any_symbolic_tensors((x1, x2)): + return Logaddexp2().symbolic_call(x1, x2) + return backend.numpy.logaddexp2(x1, x2) + + +class LogicalAnd(Operation): + def call(self, x1, x2): + return backend.numpy.logical_and(x1, x2) + + def compute_output_spec(self, x1, x2): + x1_shape = getattr(x1, "shape", []) + x2_shape = getattr(x2, "shape", []) + output_shape = broadcast_shapes(x1_shape, x2_shape) + return KerasTensor(output_shape, dtype="bool") + + +@keras_export( + [ + "keras.ops.logical_and", + "keras.ops.numpy.logical_and", + ] +) +def logical_and(x1, x2): + """Computes the element-wise logical AND of the given input tensors. + + Zeros are treated as `False` and non-zeros are treated as `True`. + + Args: + x1: Input tensor. + x2: Input tensor. + + Returns: + Output tensor, element-wise logical AND of the inputs. + """ + if any_symbolic_tensors((x1, x2)): + return LogicalAnd().symbolic_call(x1, x2) + return backend.numpy.logical_and(x1, x2) + + +class LogicalNot(Operation): + def call(self, x): + return backend.numpy.logical_not(x) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype="bool") + + +@keras_export( + [ + "keras.ops.logical_not", + "keras.ops.numpy.logical_not", + ] +) +def logical_not(x): + """Computes the element-wise NOT of the given input tensor. + + Zeros are treated as `False` and non-zeros are treated as `True`. + + Args: + x: Input tensor. + + Returns: + Output tensor, element-wise logical NOT of the input. + """ + if any_symbolic_tensors((x,)): + return LogicalNot().symbolic_call(x) + return backend.numpy.logical_not(x) + + +class LogicalOr(Operation): + def call(self, x1, x2): + return backend.numpy.logical_or(x1, x2) + + def compute_output_spec(self, x1, x2): + x1_shape = getattr(x1, "shape", []) + x2_shape = getattr(x2, "shape", []) + output_shape = broadcast_shapes(x1_shape, x2_shape) + return KerasTensor(output_shape, dtype="bool") + + +@keras_export( + [ + "keras.ops.logical_or", + "keras.ops.numpy.logical_or", + ] +) +def logical_or(x1, x2): + """Computes the element-wise logical OR of the given input tensors. + + Zeros are treated as `False` and non-zeros are treated as `True`. + + Args: + x1: Input tensor. + x2: Input tensor. + + Returns: + Output tensor, element-wise logical OR of the inputs. + """ + if any_symbolic_tensors((x1, x2)): + return LogicalOr().symbolic_call(x1, x2) + return backend.numpy.logical_or(x1, x2) + + +class Logspace(Operation): + def __init__( + self, num=50, endpoint=True, base=10, dtype=None, axis=0, *, name=None + ): + super().__init__(name=name) + self.num = num + self.endpoint = endpoint + self.base = base + self.dtype = None if dtype is None else backend.standardize_dtype(dtype) + self.axis = axis + + def call(self, start, stop): + return backend.numpy.logspace( + start, + stop, + num=self.num, + endpoint=self.endpoint, + base=self.base, + dtype=self.dtype, + axis=self.axis, + ) + + def compute_output_spec(self, start, stop): + start_shape = getattr(start, "shape", []) + stop_shape = getattr(stop, "shape", []) + output_shape = broadcast_shapes(start_shape, stop_shape) + if self.axis == -1: + output_shape = output_shape + [self.num] + elif self.axis >= 0: + output_shape = ( + output_shape[: self.axis] + + [self.num] + + output_shape[self.axis :] + ) + else: + output_shape = ( + output_shape[: self.axis + 1] + + [self.num] + + output_shape[self.axis + 1 :] + ) + dtype = ( + self.dtype + if self.dtype is not None + else backend.standardize_dtype(getattr(start, "dtype", type(start))) + ) + dtype = backend.result_type(dtype, float) + return KerasTensor(output_shape, dtype=dtype) + + +@keras_export(["keras.ops.logspace", "keras.ops.numpy.logspace"]) +def logspace(start, stop, num=50, endpoint=True, base=10, dtype=None, axis=0): + """Returns numbers spaced evenly on a log scale. + + In linear space, the sequence starts at `base ** start` and ends with + `base ** stop` (see `endpoint` below). + + Args: + start: The starting value of the sequence. + stop: The final value of the sequence, unless `endpoint` is `False`. + In that case, `num + 1` values are spaced over the interval in + log-space, of which all but the last (a sequence of length `num`) + are returned. + num: Number of samples to generate. Defaults to `50`. + endpoint: If `True`, `stop` is the last sample. Otherwise, it is not + included. Defaults to `True`. + base: The base of the log space. Defaults to `10`. + dtype: The type of the output tensor. + axis: The axis in the result to store the samples. Relevant only + if start or stop are array-like. + + Note: + Torch backend does not support `axis` argument. + + Returns: + A tensor of evenly spaced samples on a log scale. + """ + if any_symbolic_tensors((start, stop)): + return Logspace(num, endpoint, base, dtype, axis)(start, stop) + return backend.numpy.logspace( + start, + stop, + num=num, + endpoint=endpoint, + base=base, + dtype=dtype, + axis=axis, + ) + + +class Matmul(Operation): + def call(self, x1, x2): + return backend.numpy.matmul(x1, x2) + + def compute_output_spec(self, x1, x2): + x1_shape = getattr(x1, "shape", []) + x2_shape = getattr(x2, "shape", []) + output_shape = operation_utils.compute_matmul_output_shape( + x1_shape, x2_shape + ) + x1_sparse = getattr(x1, "sparse", True) + x2_sparse = getattr(x2, "sparse", True) + output_sparse = x1_sparse and x2_sparse + x1_dtype = backend.standardize_dtype(getattr(x1, "dtype", type(x1))) + x2_dtype = backend.standardize_dtype(getattr(x2, "dtype", type(x2))) + if x1_dtype == "int8" and x2_dtype == "int8": + dtype = "int32" + else: + dtype = dtypes.result_type(x1_dtype, x2_dtype) + return KerasTensor(output_shape, dtype=dtype, sparse=output_sparse) + + +@keras_export(["keras.ops.matmul", "keras.ops.numpy.matmul"]) +def matmul(x1, x2): + """Matrix product of two tensors. + + - If both tensors are 1-dimensional, the dot product (scalar) is returned. + - If either tensor is N-D, N > 2, it is treated as a stack of matrices + residing in the last two indexes and broadcast accordingly. + - If the first tensor is 1-D, it is promoted to a matrix by prepending + a 1 to its dimensions. After matrix multiplication the prepended + 1 is removed. + - If the second tensor is 1-D, it is promoted to a matrix by appending a 1 + to its dimensions. After matrix multiplication the appended 1 is removed. + + Args: + x1: First tensor. + x2: Second tensor. + + Returns: + Output tensor, matrix product of the inputs. + """ + if any_symbolic_tensors((x1, x2)): + return Matmul().symbolic_call(x1, x2) + return backend.numpy.matmul(x1, x2) + + +class Max(Operation): + def __init__(self, axis=None, keepdims=False, initial=None, *, name=None): + super().__init__(name=name) + if isinstance(axis, int): + self.axis = [axis] + else: + self.axis = axis + self.keepdims = keepdims + self.initial = initial + + def call(self, x): + return backend.numpy.max( + x, axis=self.axis, keepdims=self.keepdims, initial=self.initial + ) + + def compute_output_spec(self, x): + return KerasTensor( + reduce_shape(x.shape, axis=self.axis, keepdims=self.keepdims), + dtype=x.dtype, + ) + + +@keras_export(["keras.ops.max", "keras.ops.numpy.max"]) +def max(x, axis=None, keepdims=False, initial=None): + """Return the maximum of a tensor or maximum along an axis. + + Args: + x: Input tensor. + axis: Axis or axes along which to operate. By default, flattened input + is used. + keepdims: If this is set to `True`, the axes which are reduced are left + in the result as dimensions with size one. Defaults to `False`. + initial: The minimum value of an output element. Defaults to `None`. + + Returns: + Maximum of `x`. + """ + if any_symbolic_tensors((x,)): + return Max(axis=axis, keepdims=keepdims, initial=initial).symbolic_call( + x + ) + return backend.numpy.max(x, axis=axis, keepdims=keepdims, initial=initial) + + +class Maximum(Operation): + def call(self, x1, x2): + return backend.numpy.maximum(x1, x2) + + def compute_output_spec(self, x1, x2): + x1_shape = getattr(x1, "shape", []) + x2_shape = getattr(x2, "shape", []) + output_shape = broadcast_shapes(x1_shape, x2_shape) + output_dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + ) + x1_sparse = getattr(x1, "sparse", False) + x2_sparse = getattr(x2, "sparse", False) + output_sparse = x1_sparse and x2_sparse + return KerasTensor( + output_shape, dtype=output_dtype, sparse=output_sparse + ) + + +@keras_export(["keras.ops.maximum", "keras.ops.numpy.maximum"]) +def maximum(x1, x2): + """Element-wise maximum of `x1` and `x2`. + + Args: + x1: First tensor. + x2: Second tensor. + + Returns: + Output tensor, element-wise maximum of `x1` and `x2`. + """ + if any_symbolic_tensors((x1, x2)): + return Maximum().symbolic_call(x1, x2) + return backend.numpy.maximum(x1, x2) + + +class Median(Operation): + def __init__(self, axis=None, keepdims=False, *, name=None): + super().__init__(name=name) + if isinstance(axis, int): + axis = [axis] + self.axis = axis + self.keepdims = keepdims + + def call(self, x): + return backend.numpy.median(x, axis=self.axis, keepdims=self.keepdims) + + def compute_output_spec(self, x): + output_shape = reduce_shape( + x.shape, axis=self.axis, keepdims=self.keepdims + ) + if backend.standardize_dtype(x.dtype) == "int64": + dtype = backend.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + return KerasTensor(output_shape, dtype=dtype) + + +@keras_export(["keras.ops.median", "keras.ops.numpy.median"]) +def median(x, axis=None, keepdims=False): + """Compute the median along the specified axis. + + Args: + x: Input tensor. + axis: Axis or axes along which the medians are computed. Defaults to + `axis=None` which is to compute the median(s) along a flattened + version of the array. + keepdims: If this is set to `True`, the axes which are reduce + are left in the result as dimensions with size one. + + Returns: + The output tensor. + """ + if any_symbolic_tensors((x,)): + return Median(axis=axis, keepdims=keepdims).symbolic_call(x) + return backend.numpy.median(x, axis=axis, keepdims=keepdims) + + +class Meshgrid(Operation): + def __init__(self, indexing="xy", *, name=None): + super().__init__(name=name) + if indexing not in ("xy", "ij"): + raise ValueError( + "Valid values for `indexing` are 'xy' and 'ij', " + "but received {index}." + ) + self.indexing = indexing + + def call(self, *x): + return backend.numpy.meshgrid(*x, indexing=self.indexing) + + def compute_output_spec(self, *x): + output_shape = [] + for xi in x: + if len(xi.shape) == 0: + size = 1 + else: + if None in xi.shape: + size = None + else: + size = int(np.prod(xi.shape)) + output_shape.append(size) + if self.indexing == "ij": + return [KerasTensor(output_shape) for _ in range(len(x))] + tmp = output_shape[0] + output_shape[0] = output_shape[1] + output_shape[1] = tmp + return [ + KerasTensor(output_shape, dtype=xi.dtype) for _ in range(len(x)) + ] + + +@keras_export(["keras.ops.meshgrid", "keras.ops.numpy.meshgrid"]) +def meshgrid(*x, indexing="xy"): + """Creates grids of coordinates from coordinate vectors. + + Given `N` 1-D tensors `T0, T1, ..., TN-1` as inputs with corresponding + lengths `S0, S1, ..., SN-1`, this creates an `N` N-dimensional tensors + `G0, G1, ..., GN-1` each with shape `(S0, ..., SN-1)` where the output + `Gi` is constructed by expanding `Ti` to the result shape. + + Args: + x: 1-D tensors representing the coordinates of a grid. + indexing: `"xy"` or `"ij"`. "xy" is cartesian; `"ij"` is matrix + indexing of output. Defaults to `"xy"`. + + Returns: + Sequence of N tensors. + + Example: + >>> from keras.src import ops + >>> x = ops.array([1, 2, 3]) + >>> y = ops.array([4, 5, 6]) + + >>> grid_x, grid_y = ops.meshgrid(x, y, indexing="ij") + >>> grid_x + array([[1, 1, 1], + [2, 2, 2], + [3, 3, 3]]) + >>> grid_y + array([[4, 5, 6], + [4, 5, 6], + [4, 5, 6]]) + """ + if any_symbolic_tensors(x): + return Meshgrid(indexing=indexing).symbolic_call(*x) + return backend.numpy.meshgrid(*x, indexing=indexing) + + +class Min(Operation): + def __init__(self, axis=None, keepdims=False, initial=None, *, name=None): + super().__init__(name=name) + if isinstance(axis, int): + self.axis = [axis] + else: + self.axis = axis + self.keepdims = keepdims + self.initial = initial + + def call(self, x): + return backend.numpy.min( + x, axis=self.axis, keepdims=self.keepdims, initial=self.initial + ) + + def compute_output_spec(self, x): + return KerasTensor( + reduce_shape(x.shape, axis=self.axis, keepdims=self.keepdims), + dtype=x.dtype, + ) + + +@keras_export(["keras.ops.min", "keras.ops.numpy.min"]) +def min(x, axis=None, keepdims=False, initial=None): + """Return the minimum of a tensor or minimum along an axis. + + Args: + x: Input tensor. + axis: Axis or axes along which to operate. By default, flattened input + is used. + keepdims: If this is set to `True`, the axes which are reduced are left + in the result as dimensions with size one. Defaults to `False`. + initial: The maximum value of an output element. Defaults to `None`. + + Returns: + Minimum of `x`. + """ + if any_symbolic_tensors((x,)): + return Min(axis=axis, keepdims=keepdims, initial=initial).symbolic_call( + x + ) + return backend.numpy.min(x, axis=axis, keepdims=keepdims, initial=initial) + + +class Minimum(Operation): + def call(self, x1, x2): + return backend.numpy.minimum(x1, x2) + + def compute_output_spec(self, x1, x2): + x1_shape = getattr(x1, "shape", []) + x2_shape = getattr(x2, "shape", []) + output_shape = broadcast_shapes(x1_shape, x2_shape) + output_dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + ) + x1_sparse = getattr(x1, "sparse", False) + x2_sparse = getattr(x2, "sparse", False) + output_sparse = x1_sparse and x2_sparse + return KerasTensor( + output_shape, dtype=output_dtype, sparse=output_sparse + ) + + +@keras_export(["keras.ops.minimum", "keras.ops.numpy.minimum"]) +def minimum(x1, x2): + """Element-wise minimum of `x1` and `x2`. + + Args: + x1: First tensor. + x2: Second tensor. + + Returns: + Output tensor, element-wise minimum of `x1` and `x2`. + """ + if any_symbolic_tensors((x1, x2)): + return Minimum().symbolic_call(x1, x2) + return backend.numpy.minimum(x1, x2) + + +class Mod(Operation): + def call(self, x1, x2): + return backend.numpy.mod(x1, x2) + + def compute_output_spec(self, x1, x2): + x1_shape = getattr(x1, "shape", []) + x2_shape = getattr(x2, "shape", []) + output_shape = broadcast_shapes(x1_shape, x2_shape) + output_dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + ) + if output_dtype == "bool": + output_dtype = "int32" + return KerasTensor(output_shape, dtype=output_dtype) + + +@keras_export(["keras.ops.mod", "keras.ops.numpy.mod"]) +def mod(x1, x2): + """Returns the element-wise remainder of division. + + Args: + x1: First tensor. + x2: Second tensor. + + Returns: + Output tensor, element-wise remainder of division. + """ + if any_symbolic_tensors((x1, x2)): + return Mod().symbolic_call(x1, x2) + return backend.numpy.mod(x1, x2) + + +class Moveaxis(Operation): + def __init__(self, source, destination, *, name=None): + super().__init__(name=name) + if isinstance(source, int): + self.source = [source] + else: + self.source = source + if isinstance(destination, int): + self.destination = [destination] + else: + self.destination = destination + + if len(self.source) != len(self.destination): + raise ValueError( + "`source` and `destination` arguments must have the same " + f"number of elements, but received `source={source}` and " + f"`destination={destination}`." + ) + + def call(self, x): + return backend.numpy.moveaxis(x, self.source, self.destination) + + def compute_output_spec(self, x): + x_shape = list(x.shape) + output_shape = [-1 for _ in range(len(x.shape))] + for sc, dst in zip(self.source, self.destination): + output_shape[dst] = x_shape[sc] + x_shape[sc] = -1 + i, j = 0, 0 + while i < len(output_shape): + while i < len(output_shape) and output_shape[i] != -1: + # Find the first dim unset. + i += 1 + while j < len(output_shape) and x_shape[j] == -1: + # Find the first dim not being passed. + j += 1 + if i == len(output_shape): + break + output_shape[i] = x_shape[j] + i += 1 + j += 1 + return KerasTensor(output_shape, dtype=x.dtype) + + +@keras_export(["keras.ops.moveaxis", "keras.ops.numpy.moveaxis"]) +def moveaxis(x, source, destination): + """Move axes of a tensor to new positions. + + Other axes remain in their original order. + + Args: + x: Tensor whose axes should be reordered. + source: Original positions of the axes to move. These must be unique. + destination: Destinations positions for each of the original axes. + These must also be unique. + + Returns: + Tensor with moved axes. + """ + if any_symbolic_tensors((x,)): + return Moveaxis(source, destination).symbolic_call(x) + return backend.numpy.moveaxis(x, source=source, destination=destination) + + +class NanToNum(Operation): + def __init__(self, nan=0.0, posinf=None, neginf=None, *, name=None): + super().__init__(name=name) + self.nan = nan + self.posinf = posinf + self.neginf = neginf + + def call(self, x): + return backend.numpy.nan_to_num( + x, nan=self.nan, posinf=self.posinf, neginf=self.neginf + ) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype=x.dtype) + + +@keras_export( + [ + "keras.ops.nan_to_num", + "keras.ops.numpy.nan_to_num", + ] +) +def nan_to_num(x, nan=0.0, posinf=None, neginf=None): + """Replace NaN with zero and infinity with large finite numbers. + + Args: + x: Input data. + nan: Optional float or int. Value to replace `NaN` entries with. + posinf: Optional float or int. + Value to replace positive infinity with. + neginf: Optional float or int. + Value to replace negative infinity with. + + Returns: + `x`, with non-finite values replaced. + """ + if any_symbolic_tensors((x,)): + return NanToNum(nan=nan, posinf=posinf, neginf=neginf).symbolic_call(x) + return backend.numpy.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf) + + +class Ndim(Operation): + def call(self, x): + return backend.numpy.ndim( + x, + ) + + def compute_output_spec(self, x): + return KerasTensor([len(x.shape)]) + + +@keras_export(["keras.ops.ndim", "keras.ops.numpy.ndim"]) +def ndim(x): + """Return the number of dimensions of a tensor. + + Args: + x: Input tensor. + + Returns: + The number of dimensions in `x`. + """ + if any_symbolic_tensors((x,)): + return Ndim().symbolic_call(x) + return backend.numpy.ndim(x) + + +class Nonzero(Operation): + def call(self, x): + return backend.numpy.nonzero(x) + + def compute_output_spec(self, x): + return tuple( + [KerasTensor((None,), dtype="int32") for _ in range(len(x.shape))] + ) + + +@keras_export(["keras.ops.nonzero", "keras.ops.numpy.nonzero"]) +def nonzero(x): + """Return the indices of the elements that are non-zero. + + Args: + x: Input tensor. + + Returns: + Indices of elements that are non-zero. + """ + if any_symbolic_tensors((x,)): + return Nonzero().symbolic_call(x) + return backend.numpy.nonzero(x) + + +class NotEqual(Operation): + def call(self, x1, x2): + return backend.numpy.not_equal(x1, x2) + + def compute_output_spec(self, x1, x2): + x1_shape = getattr(x1, "shape", []) + x2_shape = getattr(x2, "shape", []) + output_shape = broadcast_shapes(x1_shape, x2_shape) + return KerasTensor(output_shape, dtype="bool") + + +@keras_export(["keras.ops.not_equal", "keras.ops.numpy.not_equal"]) +def not_equal(x1, x2): + """Return `(x1 != x2)` element-wise. + + Args: + x1: First input tensor. + x2: Second input tensor. + + Returns: + Output tensor, element-wise comparison of `x1` and `x2`. + """ + if any_symbolic_tensors((x1, x2)): + return NotEqual().symbolic_call(x1, x2) + return backend.numpy.not_equal(x1, x2) + + +class OnesLike(Operation): + def __init__(self, dtype=None, *, name=None): + super().__init__(name=name) + self.dtype = None if dtype is None else backend.standardize_dtype(dtype) + + def call(self, x): + return backend.numpy.ones_like(x, dtype=self.dtype) + + def compute_output_spec(self, x): + dtype = ( + backend.standardize_dtype(x.dtype) + if self.dtype is None + else self.dtype + ) + sparse = getattr(x, "sparse", False) + return KerasTensor(x.shape, dtype=dtype, sparse=sparse) + + +@keras_export(["keras.ops.ones_like", "keras.ops.numpy.ones_like"]) +def ones_like(x, dtype=None): + """Return a tensor of ones with the same shape and type of `x`. + + Args: + x: Input tensor. + dtype: Overrides the data type of the result. + + Returns: + A tensor of ones with the same shape and type as `x`. + """ + if any_symbolic_tensors((x,)): + return OnesLike(dtype=dtype).symbolic_call(x) + return backend.numpy.ones_like(x, dtype=dtype) + + +class ZerosLike(Operation): + def __init__(self, dtype=None, *, name=None): + super().__init__(name=name) + self.dtype = None if dtype is None else backend.standardize_dtype(dtype) + + def call(self, x): + return backend.numpy.zeros_like(x, dtype=self.dtype) + + def compute_output_spec(self, x, dtype=None): + dtype = ( + backend.standardize_dtype(x.dtype) + if self.dtype is None + else self.dtype + ) + sparse = getattr(x, "sparse", False) + return KerasTensor(x.shape, dtype=dtype, sparse=sparse) + + +@keras_export( + [ + "keras.ops.zeros_like", + "keras.ops.numpy.zeros_like", + ] +) +def zeros_like(x, dtype=None): + """Return a tensor of zeros with the same shape and type as `x`. + + Args: + x: Input tensor. + dtype: Overrides the data type of the result. + + Returns: + A tensor of zeros with the same shape and type as `x`. + """ + if any_symbolic_tensors((x,)): + return ZerosLike(dtype=dtype).symbolic_call(x) + return backend.numpy.zeros_like(x, dtype=dtype) + + +class Outer(Operation): + def call(self, x1, x2): + return backend.numpy.outer(x1, x2) + + def compute_output_spec(self, x1, x2): + x1_shape = getattr(x1, "shape", [1]) + x2_shape = getattr(x2, "shape", [1]) + if None in x1_shape: + x1_flatten_shape = None + else: + x1_flatten_shape = int(np.prod(x1_shape)) + if None in x2_shape: + x2_flatten_shape = None + else: + x2_flatten_shape = int(np.prod(x2_shape)) + output_shape = [x1_flatten_shape, x2_flatten_shape] + output_dtype = backend.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + ) + return KerasTensor(output_shape, dtype=output_dtype) + + +@keras_export(["keras.ops.outer", "keras.ops.numpy.outer"]) +def outer(x1, x2): + """Compute the outer product of two vectors. + + Given two vectors `x1` and `x2`, the outer product is: + + ``` + out[i, j] = x1[i] * x2[j] + ``` + + Args: + x1: First input tensor. + x2: Second input tensor. + + Returns: + Outer product of `x1` and `x2`. + """ + if any_symbolic_tensors((x1, x2)): + return Outer().symbolic_call(x1, x2) + return backend.numpy.outer(x1, x2) + + +class Pad(Operation): + def __init__(self, pad_width, mode="constant", *, name=None): + super().__init__(name=name) + self.pad_width = self._process_pad_width(pad_width) + self.mode = mode + + def _process_pad_width(self, pad_width): + if isinstance(pad_width, int): + return ((pad_width, pad_width),) + if isinstance(pad_width, (tuple, list)) and isinstance( + pad_width[0], int + ): + return (pad_width,) + first_len = len(pad_width[0]) + for i, pw in enumerate(pad_width): + if len(pw) != first_len: + raise ValueError( + "`pad_width` should be a list of tuples of length " + f"1 or 2. Received: pad_width={pad_width}" + ) + if len(pw) == 1: + pad_width[i] = (pw[0], pw[0]) + return pad_width + + def call(self, x, constant_values=None): + if len(self.pad_width) > 1 and len(self.pad_width) != len(x.shape): + raise ValueError( + "`pad_width` must have the same length as `x.shape`. " + f"Received: pad_width={self.pad_width} " + f"(of length {len(self.pad_width)}) and x.shape={x.shape} " + f"(of length {len(x.shape)})" + ) + return backend.numpy.pad( + x, + pad_width=self.pad_width, + mode=self.mode, + constant_values=constant_values, + ) + + def compute_output_spec(self, x, constant_values=None): + output_shape = list(x.shape) + if len(self.pad_width) == 1: + pad_width = [self.pad_width[0] for _ in range(len(output_shape))] + elif len(self.pad_width) == len(output_shape): + pad_width = self.pad_width + else: + raise ValueError( + "`pad_width` must have the same length as `x.shape`. " + f"Received: pad_width={self.pad_width} " + f"(of length {len(self.pad_width)}) and x.shape={x.shape} " + f"(of length {len(x.shape)})" + ) + + for i in range(len(output_shape)): + if output_shape[i] is None: + output_shape[i] = None + else: + output_shape[i] += pad_width[i][0] + pad_width[i][1] + return KerasTensor(output_shape, dtype=x.dtype) + + +@keras_export(["keras.ops.pad", "keras.ops.numpy.pad"]) +def pad(x, pad_width, mode="constant", constant_values=None): + """Pad a tensor. + + Args: + x: Tensor to pad. + pad_width: Number of values padded to the edges of each axis. + `((before_1, after_1), ...(before_N, after_N))` unique pad + widths for each axis. + `((before, after),)` yields same before and after pad for + each axis. + `(pad,)` or `int` is a shortcut for `before = after = pad` + width for all axes. + mode: One of `"constant"`, `"edge"`, `"linear_ramp"`, + `"maximum"`, `"mean"`, `"median"`, `"minimum"`, + `"reflect"`, `"symmetric"`, `"wrap"`, `"empty"`, + `"circular"`. Defaults to `"constant"`. + constant_values: value to pad with if `mode == "constant"`. + Defaults to `0`. A `ValueError` is raised if not None and + `mode != "constant"`. + + Note: + Torch backend only supports modes `"constant"`, `"reflect"`, + `"symmetric"` and `"circular"`. + Only Torch backend supports `"circular"` mode. + + Note: + Tensorflow backend only supports modes `"constant"`, `"reflect"` + and `"symmetric"`. + + Returns: + Padded tensor. + """ + return Pad(pad_width, mode=mode)(x, constant_values=constant_values) + + +class Prod(Operation): + def __init__(self, axis=None, keepdims=False, dtype=None, *, name=None): + super().__init__(name=name) + if isinstance(axis, int): + self.axis = [axis] + else: + self.axis = axis + self.keepdims = keepdims + self.dtype = None if dtype is None else backend.standardize_dtype(dtype) + + def call(self, x): + return backend.numpy.prod( + x, + axis=self.axis, + keepdims=self.keepdims, + dtype=self.dtype, + ) + + def compute_output_spec(self, x): + if self.dtype is not None: + dtype = self.dtype + else: + dtype = backend.standardize_dtype(x.dtype) + if dtype == "bool": + dtype = "int32" + elif dtype in ("int8", "int16"): + dtype = "int32" + elif dtype in ("uint8", "uint16"): + dtype = "uint32" + # TODO: torch doesn't support uint32 + if backend.backend() == "torch" and dtype == "uint32": + dtype = "int32" + return KerasTensor( + reduce_shape(x.shape, axis=self.axis, keepdims=self.keepdims), + dtype=dtype, + ) + + +@keras_export(["keras.ops.prod", "keras.ops.numpy.prod"]) +def prod(x, axis=None, keepdims=False, dtype=None): + """Return the product of tensor elements over a given axis. + + Args: + x: Input tensor. + axis: Axis or axes along which a product is performed. The default, + `axis=None`, will compute the product of all elements + in the input tensor. + keepdims: If this is set to `True`, the axes which are reduce + are left in the result as dimensions with size one. + dtype: Data type of the returned tensor. + + Returns: + Product of elements of `x` over the given axis or axes. + """ + if any_symbolic_tensors((x,)): + return Prod(axis=axis, keepdims=keepdims, dtype=dtype).symbolic_call(x) + return backend.numpy.prod(x, axis=axis, keepdims=keepdims, dtype=dtype) + + +class Quantile(Operation): + def __init__( + self, axis=None, method="linear", keepdims=False, *, name=None + ): + super().__init__(name=name) + if isinstance(axis, int): + axis = [axis] + self.axis = axis + self.method = method + self.keepdims = keepdims + + def call(self, x, q): + return backend.numpy.quantile( + x, q, axis=self.axis, keepdims=self.keepdims + ) + + def compute_output_spec(self, x, q): + output_shape = reduce_shape( + x.shape, axis=self.axis, keepdims=self.keepdims + ) + if hasattr(q, "shape"): + if len(q.shape) > 0: + output_shape = (q.shape[0],) + output_shape + if backend.standardize_dtype(x.dtype) == "int64": + dtype = backend.floatx() + else: + dtype = dtypes.result_type(x.dtype, float) + return KerasTensor(output_shape, dtype=dtype) + + +@keras_export(["keras.ops.quantile", "keras.ops.numpy.quantile"]) +def quantile(x, q, axis=None, method="linear", keepdims=False): + """Compute the q-th quantile(s) of the data along the specified axis. + + Args: + x: Input tensor. + q: Probability or sequence of probabilities for the quantiles to + compute. Values must be between 0 and 1 inclusive. + axis: Axis or axes along which the quantiles are computed. Defaults to + `axis=None` which is to compute the quantile(s) along a flattened + version of the array. + method: A string specifies the method to use for estimating the + quantile. Available methods are `"linear"`, `"lower"`, `"higher"`, + `"midpoint"`, and `"nearest"`. Defaults to `"linear"`. + If the desired quantile lies between two data points `i < j`: + - `"linear"`: `i + (j - i) * fraction`, where fraction is the + fractional part of the index surrounded by `i` and `j`. + - `"lower"`: `i`. + - `"higher"`: `j`. + - `"midpoint"`: `(i + j) / 2` + - `"nearest"`: `i` or `j`, whichever is nearest. + keepdims: If this is set to `True`, the axes which are reduce + are left in the result as dimensions with size one. + + Returns: + The quantile(s). If `q` is a single probability and `axis=None`, then + the result is a scalar. If multiple probabilities levels are given, + first axis of the result corresponds to the quantiles. The other axes + are the axes that remain after the reduction of `x`. + """ + if any_symbolic_tensors((x, q)): + return Quantile( + axis=axis, method=method, keepdims=keepdims + ).symbolic_call(x, q) + return backend.numpy.quantile( + x, q, axis=axis, method=method, keepdims=keepdims + ) + + +class Ravel(Operation): + def call(self, x): + return backend.numpy.ravel(x) + + def compute_output_spec(self, x): + if None in x.shape: + output_shape = [ + None, + ] + else: + output_shape = [int(np.prod(x.shape))] + return KerasTensor(output_shape, dtype=x.dtype) + + +@keras_export(["keras.ops.ravel", "keras.ops.numpy.ravel"]) +def ravel(x): + """Return a contiguous flattened tensor. + + A 1-D tensor, containing the elements of the input, is returned. + + Args: + x: Input tensor. + + Returns: + Output tensor. + """ + if any_symbolic_tensors((x,)): + return Ravel().symbolic_call(x) + return backend.numpy.ravel(x) + + +class UnravelIndex(Operation): + def __init__(self, shape, *, name=None): + super().__init__(name=name) + self.shape = shape + + def call(self, indices): + return backend.numpy.unravel_index(indices, self.shape) + + def compute_output_spec(self, indices): + if None in self.shape: + output_shapes = [[None] for _ in self.shape] + else: + if isinstance(indices, int): + output_shapes = [[1] for _ in self.shape] + elif hasattr(indices, "shape"): + output_shapes = [list(indices.shape) for _ in self.shape] + else: + try: + indices_shape = np.shape(indices) + output_shapes = [list(indices_shape) for _ in self.shape] + except Exception: + output_shapes = [[None] for _ in self.shape] + + return [ + KerasTensor(shape, dtype=indices.dtype) for shape in output_shapes + ] + + +@keras_export(["keras.ops.unravel_index", "keras.ops.numpy.unravel_index"]) +def unravel_index(indices, shape): + """Convert flat indices to coordinate arrays in a given array shape. + + Args: + indices: An integer or array of integers representing flat indices. + shape: The shape of the array to unravel into. + + Returns: + Tuple of arrays for each dimension with unraveled indices. + + Example: + >>> indices = 5 + >>> shape = (3, 3) + >>> unravel_index(indices, shape) + (1, 2) # 5 is at row 1, column 2 in a 3x3 array + """ + if any_symbolic_tensors((indices,)): + return UnravelIndex(shape).symbolic_call(indices) + + return backend.numpy.unravel_index(indices, shape) + + +class Real(Operation): + def call(self, x): + return backend.numpy.real(x) + + def compute_output_spec(self, x): + sparse = getattr(x, "sparse", False) + return KerasTensor(x.shape, dtype=x.dtype, sparse=sparse) + + +@keras_export(["keras.ops.real", "keras.ops.numpy.real"]) +def real(x): + """Return the real part of the complex argument. + + Args: + x: Input tensor. + + Returns: + The real component of the complex argument. + """ + if any_symbolic_tensors((x,)): + return Real().symbolic_call(x) + return backend.numpy.real(x) + + +class Reciprocal(Operation): + def call(self, x): + return backend.numpy.reciprocal(x) + + def compute_output_spec(self, x): + return KerasTensor(x.shape) + + +@keras_export( + [ + "keras.ops.reciprocal", + "keras.ops.numpy.reciprocal", + ] +) +def reciprocal(x): + """Return the reciprocal of the argument, element-wise. + + Calculates `1/x`. + + Args: + x: Input tensor. + + Returns: + Output tensor, element-wise reciprocal of `x`. + """ + if any_symbolic_tensors((x,)): + return Reciprocal().symbolic_call(x) + return backend.numpy.reciprocal(x) + + +class Repeat(Operation): + def __init__(self, repeats, axis=None, *, name=None): + super().__init__(name=name) + self.axis = axis + self.repeats = repeats + + def call(self, x): + return backend.numpy.repeat(x, self.repeats, axis=self.axis) + + def compute_output_spec(self, x): + x_shape = list(x.shape) + repeats = self.repeats + if isinstance(repeats, int): + repeats = [repeats] + repeats_size = len(repeats) + broadcast = repeats_size == 1 + + if self.axis is None: + if None in x_shape: + return KerasTensor([None], dtype=x.dtype) + + x_flatten_size = int(np.prod(x_shape)) + if broadcast: + output_shape = [x_flatten_size * repeats[0]] + elif repeats_size != x_flatten_size: + raise ValueError( + "Size of `repeats` and " + "dimensions of `x` after flattening should be compatible. " + f"Received: {repeats_size} and {x_flatten_size}" + ) + else: + output_shape = [int(np.sum(repeats))] + return KerasTensor(output_shape, dtype=x.dtype) + + size_on_ax = x_shape[self.axis] + if size_on_ax is None: + return KerasTensor(x_shape, dtype=x.dtype) + + output_shape = x_shape + if broadcast: + output_shape[self.axis] = size_on_ax * repeats[0] + elif size_on_ax != repeats_size: + raise ValueError( + "Size of `repeats` and " + f"dimensions of `axis {self.axis} of x` should be compatible. " + f"Received: {repeats_size} and {x_shape}" + ) + else: + output_shape[self.axis] = int(np.sum(repeats)) + return KerasTensor(output_shape, dtype=x.dtype) + + +@keras_export(["keras.ops.repeat", "keras.ops.numpy.repeat"]) +def repeat(x, repeats, axis=None): + """Repeat each element of a tensor after themselves. + + Args: + x: Input tensor. + repeats: The number of repetitions for each element. + axis: The axis along which to repeat values. By default, use + the flattened input array, and return a flat output array. + + Returns: + Output tensor. + """ + if any_symbolic_tensors((x,)): + return Repeat(repeats, axis=axis).symbolic_call(x) + return backend.numpy.repeat(x, repeats, axis=axis) + + +class Reshape(Operation): + def __init__(self, newshape, *, name=None): + super().__init__(name=name) + self.newshape = newshape + + def call(self, x): + return backend.numpy.reshape(x, self.newshape) + + def compute_output_spec(self, x): + output_shape = operation_utils.compute_reshape_output_shape( + x.shape, self.newshape, "newshape" + ) + sparse = getattr(x, "sparse", False) + return KerasTensor(output_shape, dtype=x.dtype, sparse=sparse) + + +@keras_export(["keras.ops.reshape", "keras.ops.numpy.reshape"]) +def reshape(x, newshape): + """Gives a new shape to a tensor without changing its data. + + Args: + x: Input tensor. + newshape: The new shape should be compatible with the original shape. + One shape dimension can be -1 in which case the value is + inferred from the length of the array and remaining dimensions. + + Returns: + The reshaped tensor. + """ + if any_symbolic_tensors((x,)): + return Reshape(newshape).symbolic_call(x) + return backend.numpy.reshape(x, newshape) + + +class Roll(Operation): + def __init__(self, shift, axis=None, *, name=None): + super().__init__(name=name) + self.shift = shift + self.axis = axis + + def call(self, x): + return backend.numpy.roll(x, self.shift, self.axis) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype=x.dtype) + + +@keras_export(["keras.ops.roll", "keras.ops.numpy.roll"]) +def roll(x, shift, axis=None): + """Roll tensor elements along a given axis. + + Elements that roll beyond the last position are re-introduced at the first. + + Args: + x: Input tensor. + shift: The number of places by which elements are shifted. + axis: The axis along which elements are shifted. By default, the + array is flattened before shifting, after which the original + shape is restored. + + Returns: + Output tensor. + """ + if any_symbolic_tensors((x,)): + return Roll(shift, axis=axis).symbolic_call(x) + return backend.numpy.roll(x, shift, axis=axis) + + +class Round(Operation): + def __init__(self, decimals=0, *, name=None): + super().__init__(name=name) + self.decimals = decimals + + def call(self, x): + return backend.numpy.round(x, self.decimals) + + def compute_output_spec(self, x): + sparse = getattr(x, "sparse", False) + return KerasTensor(x.shape, dtype=x.dtype, sparse=sparse) + + +@keras_export(["keras.ops.round", "keras.ops.numpy.round"]) +def round(x, decimals=0): + """Evenly round to the given number of decimals. + + Args: + x: Input tensor. + decimals: Number of decimal places to round to. Defaults to `0`. + + Returns: + Output tensor. + """ + if any_symbolic_tensors((x,)): + return Round(decimals).symbolic_call(x) + return backend.numpy.round(x, decimals) + + +class SearchSorted(Operation): + def __init__(self, side="left", *, name=None): + super().__init__(name=name) + self.side = side + + def call(self, sorted_sequence, values): + sorted_sequence = backend.convert_to_tensor(sorted_sequence) + values = backend.convert_to_tensor(values) + return backend.numpy.searchsorted( + sorted_sequence, values, side=self.side + ) + + def compute_output_spec(self, sorted_sequence, values): + if len(sorted_sequence.shape) != 1: + raise ValueError( + "searchsorted only supports 1-D sorted sequences. Use" + "keras.ops.vectorized_map to extend to N-D sequences." + ) + sequence_len = sorted_sequence.shape[0] + out_type = ( + "int32" + if sequence_len is not None + and sequence_len <= np.iinfo(np.int32).max + else "int64" + ) + return KerasTensor(values.shape, dtype=out_type) + + +@keras_export(["keras.ops.searchsorted", "keras.ops.numpy.searchsorted"]) +def searchsorted(sorted_sequence, values, side="left"): + """Perform a binary search, returning indices for insertion of `values` + into `sorted_sequence` that maintain the sorting order. + + Args: + sorted_sequence: 1-D input tensor, sorted along the innermost + dimension. + values: N-D tensor of query insertion values. + side: 'left' or 'right', specifying the direction in which to insert + for the equality case (tie-breaker). + + Returns: + Tensor of insertion indices of same shape as `values`. + """ + if any_symbolic_tensors((sorted_sequence, values)): + return SearchSorted(side=side).symbolic_call(sorted_sequence, values) + + sorted_sequence = backend.convert_to_tensor(sorted_sequence) + values = backend.convert_to_tensor(values) + return backend.numpy.searchsorted(sorted_sequence, values, side=side) + + +class Sign(Operation): + def call(self, x): + return backend.numpy.sign(x) + + def compute_output_spec(self, x): + sparse = getattr(x, "sparse", False) + return KerasTensor(x.shape, dtype=x.dtype, sparse=sparse) + + +@keras_export(["keras.ops.sign", "keras.ops.numpy.sign"]) +def sign(x): + """Returns a tensor with the signs of the elements of `x`. + + Args: + x: Input tensor. + + Returns: + Output tensor of same shape as `x`. + """ + if any_symbolic_tensors((x,)): + return Sign().symbolic_call(x) + return backend.numpy.sign(x) + + +class Signbit(Operation): + def call(self, x): + return backend.numpy.signbit(x) + + def compute_output_spec(self, x): + sparse = getattr(x, "sparse", False) + return KerasTensor(x.shape, dtype="bool", sparse=sparse) + + +@keras_export(["keras.ops.signbit", "keras.ops.numpy.signbit"]) +def signbit(x): + """Return the sign bit of the elements of `x`. + + The output boolean tensor contains `True` where the sign of `x` is negative, + and `False` otherwise. + + Args: + x: Input tensor. + + Returns: + Output boolean tensor of same shape as `x`. + """ + if any_symbolic_tensors((x,)): + return Signbit().symbolic_call(x) + return backend.numpy.signbit(x) + + +class Sin(Operation): + def call(self, x): + return backend.numpy.sin(x) + + def compute_output_spec(self, x): + dtype = backend.standardize_dtype(getattr(x, "dtype", backend.floatx())) + if dtype == "int64": + dtype = backend.floatx() + else: + dtype = dtypes.result_type(dtype, float) + sparse = getattr(x, "sparse", False) + return KerasTensor(x.shape, dtype=dtype, sparse=sparse) + + +@keras_export(["keras.ops.sin", "keras.ops.numpy.sin"]) +def sin(x): + """Trigonometric sine, element-wise. + + Arguments: + x: Input tensor. + + Returns: + Output tensor of same shape as `x`. + """ + if any_symbolic_tensors((x,)): + return Sin().symbolic_call(x) + return backend.numpy.sin(x) + + +class Sinh(Operation): + def call(self, x): + return backend.numpy.sinh(x) + + def compute_output_spec(self, x): + dtype = backend.standardize_dtype(getattr(x, "dtype", backend.floatx())) + if dtype == "int64": + dtype = backend.floatx() + else: + dtype = dtypes.result_type(dtype, float) + sparse = getattr(x, "sparse", False) + return KerasTensor(x.shape, dtype=dtype, sparse=sparse) + + +@keras_export(["keras.ops.sinh", "keras.ops.numpy.sinh"]) +def sinh(x): + """Hyperbolic sine, element-wise. + + Arguments: + x: Input tensor. + + Returns: + Output tensor of same shape as `x`. + """ + if any_symbolic_tensors((x,)): + return Sinh().symbolic_call(x) + return backend.numpy.sinh(x) + + +class Size(Operation): + def call(self, x): + return backend.numpy.size(x) + + def compute_output_spec(self, x): + return KerasTensor([], dtype="int32") + + +@keras_export(["keras.ops.size", "keras.ops.numpy.size"]) +def size(x): + """Return the number of elements in a tensor. + + Args: + x: Input tensor. + + Returns: + Number of elements in `x`. + """ + if any_symbolic_tensors((x,)): + return Size().symbolic_call(x) + return backend.numpy.size(x) + + +class Sort(Operation): + def __init__(self, axis=-1, *, name=None): + super().__init__(name=name) + self.axis = axis + + def call(self, x): + return backend.numpy.sort(x, axis=self.axis) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, x.dtype) + + +@keras_export(["keras.ops.sort", "keras.ops.numpy.sort"]) +def sort(x, axis=-1): + """Sorts the elements of `x` along a given axis in ascending order. + + Args: + x: Input tensor. + axis: Axis along which to sort. If `None`, the tensor is flattened + before sorting. Defaults to `-1`; the last axis. + + Returns: + Sorted tensor. + """ + if any_symbolic_tensors((x,)): + return Sort(axis=axis).symbolic_call(x) + return backend.numpy.sort(x, axis=axis) + + +class Split(Operation): + def __init__(self, indices_or_sections, axis=0, *, name=None): + super().__init__(name=name) + if not isinstance(indices_or_sections, int): + indices_or_sections = tuple(indices_or_sections) + self.indices_or_sections = indices_or_sections + self.axis = axis + + def call(self, x): + return backend.numpy.split(x, self.indices_or_sections, axis=self.axis) + + def compute_output_spec(self, x): + x_shape = list(x.shape) + x_size_on_axis = x_shape[self.axis] + if isinstance(self.indices_or_sections, int): + if x_size_on_axis is None: + x_shape[self.axis] = None + return [ + KerasTensor(x_shape, dtype=x.dtype) + for _ in range(self.indices_or_sections) + ] + if np.mod(x_size_on_axis, self.indices_or_sections) != 0: + raise ValueError( + "`x` size on given `axis` must be dividible by " + "`indices_or_sections` when `indices_or_sections` is an " + f"int. But received {x_size_on_axis} and " + f"{self.indices_or_sections}." + ) + size = x_size_on_axis // self.indices_or_sections + x_shape[self.axis] = size + return [ + KerasTensor(x_shape, dtype=x.dtype) + for _ in range(self.indices_or_sections) + ] + + indices_or_sections = (0, *self.indices_or_sections, x_size_on_axis) + output_size = np.diff(indices_or_sections) + outputs = [] + for i in range(len(output_size)): + output_shape = list(x_shape) + output_shape[self.axis] = int(output_size[i]) + outputs.append(KerasTensor(output_shape, dtype=x.dtype)) + return outputs + + +@keras_export(["keras.ops.split", "keras.ops.numpy.split"]) +def split(x, indices_or_sections, axis=0): + """Split a tensor into chunks. + + Args: + x: Input tensor. + indices_or_sections: If an integer, N, the tensor will be split into N + equal sections along `axis`. If a 1-D array of sorted integers, + the entries indicate indices at which the tensor will be split + along `axis`. + axis: Axis along which to split. Defaults to `0`. + + Note: + A split does not have to result in equal division when using + Torch backend. + + Returns: + A list of tensors. + """ + if any_symbolic_tensors((x,)): + return Split(indices_or_sections, axis=axis).symbolic_call(x) + return backend.numpy.split(x, indices_or_sections, axis=axis) + + +class Stack(Operation): + def __init__(self, axis=0, *, name=None): + super().__init__(name=name) + self.axis = axis + + def call(self, x): + return backend.numpy.stack(x, axis=self.axis) + + def compute_output_spec(self, x): + first_shape = x[0].shape + dtypes_to_resolve = [] + for a in x: + if not shape_equal(a.shape, first_shape, axis=[], allow_none=True): + raise ValueError( + "Every value in `x` must have the same shape. But found " + f"element of shape {a.shape}, which is different from the " + f"first element's shape {first_shape}." + ) + dtypes_to_resolve.append(getattr(a, "dtype", type(a))) + + size_on_axis = len(x) + output_shape = list(first_shape) + if self.axis == -1: + output_shape = output_shape + [size_on_axis] + elif self.axis >= 0: + output_shape.insert(self.axis, size_on_axis) + else: + output_shape.insert(self.axis + 1, size_on_axis) + output_dtype = dtypes.result_type(*dtypes_to_resolve) + return KerasTensor(output_shape, dtype=output_dtype) + + +@keras_export(["keras.ops.stack", "keras.ops.numpy.stack"]) +def stack(x, axis=0): + """Join a sequence of tensors along a new axis. + + The `axis` parameter specifies the index of the new axis in the + dimensions of the result. + + Args: + x: A sequence of tensors. + axis: Axis along which to stack. Defaults to `0`. + + Returns: + The stacked tensor. + """ + if any_symbolic_tensors((x,)): + return Stack(axis=axis).symbolic_call(x) + return backend.numpy.stack(x, axis=axis) + + +class Std(Operation): + def __init__(self, axis=None, keepdims=False, *, name=None): + super().__init__(name=name) + if isinstance(axis, int): + self.axis = [axis] + else: + self.axis = axis + self.keepdims = keepdims + + def call(self, x): + return backend.numpy.std(x, axis=self.axis, keepdims=self.keepdims) + + def compute_output_spec(self, x): + output_dtype = backend.standardize_dtype(x.dtype) + if "int" in output_dtype or output_dtype == "bool": + output_dtype = backend.floatx() + return KerasTensor( + reduce_shape(x.shape, axis=self.axis, keepdims=self.keepdims), + dtype=output_dtype, + ) + + +@keras_export(["keras.ops.std", "keras.ops.numpy.std"]) +def std(x, axis=None, keepdims=False): + """Compute the standard deviation along the specified axis. + + Args: + x: Input tensor. + axis: Axis along which to compute standard deviation. + Default is to compute the standard deviation of the + flattened tensor. + keepdims: If this is set to `True`, the axes which are reduced are left + in the result as dimensions with size one. + + Returns: + Output tensor containing the standard deviation values. + """ + if any_symbolic_tensors((x,)): + return Std(axis=axis, keepdims=keepdims).symbolic_call(x) + return backend.numpy.std(x, axis=axis, keepdims=keepdims) + + +class Swapaxes(Operation): + def __init__(self, axis1, axis2, *, name=None): + super().__init__(name=name) + + self.axis1 = axis1 + self.axis2 = axis2 + + def call(self, x): + return backend.numpy.swapaxes(x, self.axis1, self.axis2) + + def compute_output_spec(self, x): + x_shape = list(x.shape) + tmp = x_shape[self.axis1] + x_shape[self.axis1] = x_shape[self.axis2] + x_shape[self.axis2] = tmp + return KerasTensor(x_shape, dtype=x.dtype) + + +@keras_export(["keras.ops.swapaxes", "keras.ops.numpy.swapaxes"]) +def swapaxes(x, axis1, axis2): + """Interchange two axes of a tensor. + + Args: + x: Input tensor. + axis1: First axis. + axis2: Second axis. + + Returns: + A tensor with the axes swapped. + """ + if any_symbolic_tensors((x,)): + return Swapaxes(axis1, axis2).symbolic_call(x) + return backend.numpy.swapaxes(x, axis1=axis1, axis2=axis2) + + +class Take(Operation): + def __init__(self, axis=None, *, name=None): + super().__init__(name=name) + self.axis = axis + + def call(self, x, indices): + return backend.numpy.take(x, indices, axis=self.axis) + + def compute_output_spec(self, x, indices): + x_shape = list(x.shape) + if isinstance(indices, KerasTensor): + indices_shape = list(indices.shape) + ragged = indices.ragged + else: + indices_shape = list(getattr(np.array(indices), "shape", [])) + ragged = False + if self.axis is None: + return KerasTensor(indices_shape, dtype=x.dtype) + + # make sure axis is non-negative + axis = len(x_shape) + self.axis if self.axis < 0 else self.axis + output_shape = x_shape[:axis] + indices_shape + x_shape[axis + 1 :] + return KerasTensor(output_shape, dtype=x.dtype, ragged=ragged) + + +@keras_export(["keras.ops.take", "keras.ops.numpy.take"]) +def take(x, indices, axis=None): + """Take elements from a tensor along an axis. + + Args: + x: Source tensor. + indices: The indices of the values to extract. + axis: The axis over which to select values. By default, the + flattened input tensor is used. + + Returns: + The corresponding tensor of values. + """ + if any_symbolic_tensors((x, indices)): + return Take(axis=axis).symbolic_call(x, indices) + return backend.numpy.take(x, indices, axis=axis) + + +class TakeAlongAxis(Operation): + def __init__(self, axis=None, *, name=None): + super().__init__(name=name) + self.axis = axis + + def call(self, x, indices): + return backend.numpy.take_along_axis(x, indices, axis=self.axis) + + def compute_output_spec(self, x, indices): + output_shape = operation_utils.compute_take_along_axis_output_shape( + x.shape, indices.shape, self.axis + ) + return KerasTensor(output_shape, dtype=x.dtype) + + +@keras_export( + [ + "keras.ops.take_along_axis", + "keras.ops.numpy.take_along_axis", + ] +) +def take_along_axis(x, indices, axis=None): + """Select values from `x` at the 1-D `indices` along the given axis. + + Args: + x: Source tensor. + indices: The indices of the values to extract. + axis: The axis over which to select values. By default, the flattened + input tensor is used. + + Returns: + The corresponding tensor of values. + """ + if any_symbolic_tensors((x, indices)): + return TakeAlongAxis(axis=axis).symbolic_call(x, indices) + return backend.numpy.take_along_axis(x, indices, axis=axis) + + +class Tan(Operation): + def call(self, x): + return backend.numpy.tan(x) + + def compute_output_spec(self, x): + dtype = backend.standardize_dtype(getattr(x, "dtype", backend.floatx())) + if dtype == "int64": + dtype = backend.floatx() + else: + dtype = dtypes.result_type(dtype, float) + sparse = getattr(x, "sparse", False) + return KerasTensor(x.shape, dtype=dtype, sparse=sparse) + + +@keras_export(["keras.ops.tan", "keras.ops.numpy.tan"]) +def tan(x): + """Compute tangent, element-wise. + + Args: + x: Input tensor. + + Returns: + Output tensor of same shape as `x`. + """ + if any_symbolic_tensors((x,)): + return Tan().symbolic_call(x) + return backend.numpy.tan(x) + + +class Tanh(Operation): + def call(self, x): + return backend.numpy.tanh(x) + + def compute_output_spec(self, x): + dtype = backend.standardize_dtype(getattr(x, "dtype", backend.floatx())) + if dtype == "int64": + dtype = backend.floatx() + else: + dtype = dtypes.result_type(dtype, float) + sparse = getattr(x, "sparse", False) + return KerasTensor(x.shape, dtype=dtype, sparse=sparse) + + +@keras_export(["keras.ops.tanh", "keras.ops.numpy.tanh"]) +def tanh(x): + """Hyperbolic tangent, element-wise. + + Arguments: + x: Input tensor. + + Returns: + Output tensor of same shape as `x`. + """ + if any_symbolic_tensors((x,)): + return Tanh().symbolic_call(x) + return backend.numpy.tanh(x) + + +class Tensordot(Operation): + def __init__(self, axes=2, *, name=None): + super().__init__(name=name) + self.axes = axes + + def call(self, x1, x2): + return backend.numpy.tensordot(x1, x2, axes=self.axes) + + def compute_output_spec(self, x1, x2): + x1_shape = list(getattr(x1, "shape", [])) + x2_shape = list(getattr(x2, "shape", [])) + dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + ) + if not isinstance(self.axes, int): + x1_select_shape = [x1_shape[ax] for ax in self.axes[0]] + x2_select_shape = [x2_shape[ax] for ax in self.axes[1]] + if not shape_equal( + x1_select_shape, x2_select_shape, allow_none=True + ): + raise ValueError( + "Shape mismatch on `x1[axes[0]]` and `x2[axes[1]]`, " + f"received {x1_select_shape} and {x2_select_shape}." + ) + + for ax in self.axes[0]: + x1_shape[ax] = -1 + for ax in self.axes[1]: + x2_shape[ax] = -1 + + x1_shape = list(filter((-1).__ne__, x1_shape)) + x2_shape = list(filter((-1).__ne__, x2_shape)) + + output_shape = x1_shape + x2_shape + return KerasTensor(output_shape, dtype=dtype) + + if self.axes <= 0: + output_shape = x1_shape + x2_shape + else: + output_shape = x1_shape[: -self.axes] + x2_shape[self.axes :] + + return KerasTensor(output_shape, dtype=dtype) + + +@keras_export(["keras.ops.tensordot", "keras.ops.numpy.tensordot"]) +def tensordot(x1, x2, axes=2): + """Compute the tensor dot product along specified axes. + + Args: + x1: First tensor. + x2: Second tensor. + axes: - If an integer, N, sum over the last N axes of `x1` and the + first N axes of `x2` in order. The sizes of the corresponding + axes must match. + - Or, a list of axes to be summed over, first sequence applying + to `x1`, second to `x2`. Both sequences must be of the + same length. + + Returns: + The tensor dot product of the inputs. + """ + if any_symbolic_tensors((x1, x2)): + return Tensordot(axes=axes).symbolic_call(x1, x2) + return backend.numpy.tensordot(x1, x2, axes=axes) + + +class Tile(Operation): + def __init__(self, repeats, *, name=None): + super().__init__(name=name) + self.repeats = repeats + + def call(self, x): + return backend.numpy.tile(x, self.repeats) + + def compute_output_spec(self, x): + x_shape = list(x.shape) + repeats = self.repeats + if isinstance(repeats, int): + repeats = [repeats] + if len(x_shape) > len(repeats): + repeats = [1] * (len(x_shape) - len(repeats)) + repeats + else: + x_shape = [1] * (len(repeats) - len(x_shape)) + x_shape + + output_shape = [] + for x_size, repeat in zip(x_shape, repeats): + if x_size is None: + output_shape.append(None) + else: + output_shape.append(x_size * repeat) + return KerasTensor(output_shape, dtype=x.dtype) + + +@keras_export(["keras.ops.tile", "keras.ops.numpy.tile"]) +def tile(x, repeats): + """Repeat `x` the number of times given by `repeats`. + + If `repeats` has length `d`, the result will have dimension of + `max(d, x.ndim)`. + + If `x.ndim < d`, `x` is promoted to be d-dimensional by prepending + new axes. + + If `x.ndim > d`, `repeats` is promoted to `x.ndim` by prepending 1's to it. + + Args: + x: Input tensor. + repeats: The number of repetitions of `x` along each axis. + + Returns: + The tiled output tensor. + """ + if any_symbolic_tensors((x,)): + return Tile( + repeats, + ).symbolic_call(x) + return backend.numpy.tile(x, repeats) + + +class Trace(Operation): + def __init__(self, offset=0, axis1=0, axis2=1, *, name=None): + super().__init__(name=name) + self.offset = offset + self.axis1 = axis1 + self.axis2 = axis2 + + def call(self, x): + return backend.numpy.trace( + x, offset=self.offset, axis1=self.axis1, axis2=self.axis2 + ) + + def compute_output_spec(self, x): + x_shape = list(x.shape) + x_shape[self.axis1] = -1 + x_shape[self.axis2] = -1 + output_shape = list(filter((-1).__ne__, x_shape)) + output_dtype = backend.standardize_dtype(x.dtype) + if output_dtype not in ("int64", "uint32", "uint64"): + output_dtype = dtypes.result_type(output_dtype, "int32") + return KerasTensor(output_shape, dtype=output_dtype) + + +@keras_export(["keras.ops.trace", "keras.ops.numpy.trace"]) +def trace(x, offset=0, axis1=0, axis2=1): + """Return the sum along diagonals of the tensor. + + If `x` is 2-D, the sum along its diagonal with the given offset is + returned, i.e., the sum of elements `x[i, i+offset]` for all `i`. + + If a has more than two dimensions, then the axes specified by `axis1` + and `axis2` are used to determine the 2-D sub-arrays whose traces are + returned. + + The shape of the resulting tensor is the same as that of `x` with `axis1` + and `axis2` removed. + + Args: + x: Input tensor. + offset: Offset of the diagonal from the main diagonal. Can be + both positive and negative. Defaults to `0`. + axis1: Axis to be used as the first axis of the 2-D sub-arrays. + Defaults to `0`.(first axis). + axis2: Axis to be used as the second axis of the 2-D sub-arrays. + Defaults to `1` (second axis). + + Returns: + If `x` is 2-D, the sum of the diagonal is returned. If `x` has + larger dimensions, then a tensor of sums along diagonals is + returned. + """ + if any_symbolic_tensors((x,)): + return Trace(offset, axis1, axis2).symbolic_call(x) + return backend.numpy.trace(x, offset=offset, axis1=axis1, axis2=axis2) + + +@keras_export(["keras.ops.tri", "keras.ops.numpy.tri"]) +def tri(N, M=None, k=0, dtype=None): + """Return a tensor with ones at and below a diagonal and zeros elsewhere. + + Args: + N: Number of rows in the tensor. + M: Number of columns in the tensor. + k: The sub-diagonal at and below which the array is filled. + `k = 0` is the main diagonal, while `k < 0` is below it, and + `k > 0` is above. The default is 0. + dtype: Data type of the returned tensor. The default is "float32". + + Returns: + Tensor with its lower triangle filled with ones and zeros elsewhere. + `T[i, j] == 1` for `j <= i + k`, 0 otherwise. + """ + return backend.numpy.tri(N, M=M, k=k, dtype=dtype) + + +class Tril(Operation): + def __init__(self, k=0, *, name=None): + super().__init__(name=name) + self.k = k + + def call(self, x): + return backend.numpy.tril(x, k=self.k) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype=x.dtype) + + +@keras_export(["keras.ops.tril", "keras.ops.numpy.tril"]) +def tril(x, k=0): + """Return lower triangle of a tensor. + + For tensors with `ndim` exceeding 2, `tril` will apply to the + final two axes. + + Args: + x: Input tensor. + k: Diagonal above which to zero elements. Defaults to `0`. the + main diagonal. `k < 0` is below it, and `k > 0` is above it. + + Returns: + Lower triangle of `x`, of same shape and data type as `x`. + """ + if any_symbolic_tensors((x,)): + return Tril(k=k).symbolic_call(x) + return backend.numpy.tril(x, k=k) + + +class Triu(Operation): + def __init__(self, k=0, *, name=None): + super().__init__(name=name) + self.k = k + + def call(self, x): + return backend.numpy.triu(x, k=self.k) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype=x.dtype) + + +@keras_export(["keras.ops.triu", "keras.ops.numpy.triu"]) +def triu(x, k=0): + """Return upper triangle of a tensor. + + For tensors with `ndim` exceeding 2, `triu` will apply to the + final two axes. + + Args: + x: Input tensor. + k: Diagonal below which to zero elements. Defaults to `0`. the + main diagonal. `k < 0` is below it, and `k > 0` is above it. + + Returns: + Upper triangle of `x`, of same shape and data type as `x`. + """ + if any_symbolic_tensors((x,)): + return Triu(k=k).symbolic_call(x) + return backend.numpy.triu(x, k=k) + + +class Trunc(Operation): + def call(self, x): + return backend.numpy.trunc(x) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype=x.dtype) + + +@keras_export(["keras.ops.trunc", "keras.ops.numpy.trunc"]) +def trunc(x): + """Return the truncated value of the input, element-wise. + + The truncated value of the scalar `x` is the nearest integer `i` which is + closer to zero than `x` is. In short, the fractional part of the signed + number `x` is discarded. + + Args: + x: Input tensor. + + Returns: + The truncated value of each element in `x`. + + Example: + >>> x = ops.array([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]) + >>> ops.trunc(x) + array([-1.0, -1.0, -0.0, 0.0, 1.0, 1.0, 2.0]) + """ + if any_symbolic_tensors((x,)): + return Trunc().symbolic_call(x) + return backend.numpy.trunc(x) + + +class Vdot(Operation): + def call(self, x1, x2): + return backend.numpy.vdot(x1, x2) + + def compute_output_spec(self, x1, x2): + dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + ) + return KerasTensor([], dtype=dtype) + + +@keras_export(["keras.ops.vdot", "keras.ops.numpy.vdot"]) +def vdot(x1, x2): + """Return the dot product of two vectors. + + If the first argument is complex, the complex conjugate of the first + argument is used for the calculation of the dot product. + + Multidimensional tensors are flattened before the dot product is taken. + + Args: + x1: First input tensor. If complex, its complex conjugate is taken + before calculation of the dot product. + x2: Second input tensor. + + Returns: + Output tensor. + """ + if any_symbolic_tensors((x1, x2)): + return Vdot().symbolic_call(x1, x2) + return backend.numpy.vdot(x1, x2) + + +class Inner(Operation): + def call(self, x1, x2): + return backend.numpy.inner(x1, x2) + + def compute_output_spec(self, x1, x2): + dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + ) + return KerasTensor([], dtype=dtype) + + +@keras_export(["keras.ops.inner", "keras.ops.numpy.inner"]) +def inner(x1, x2): + """Return the inner product of two tensors. + + Ordinary inner product of vectors for 1-D tensors + (without complex conjugation), in higher dimensions + a sum product over the last axes. + + Multidimensional arrays are treated as vectors by flattening + all but their last axes. The resulting dot product is performed + over their last axes. + + Args: + x1: First input tensor. + x2: Second input tensor. The last dimension of `x1` and `x2` + must match. + + Returns: + Output tensor. The shape of the output is determined by + broadcasting the shapes of `x1` and `x2` after removing + their last axes. + """ + if any_symbolic_tensors((x1, x2)): + return Inner().symbolic_call(x1, x2) + return backend.numpy.inner(x1, x2) + + +@keras_export(["keras.ops.vectorize", "keras.ops.numpy.vectorize"]) +def vectorize(pyfunc, *, excluded=None, signature=None): + """Turn a function into a vectorized function. + + Example: + + ```python + def myfunc(a, b): + return a + b + + vfunc = keras.ops.vectorize(myfunc) + y = vfunc([1, 2, 3, 4], 2) # Returns Tensor([3, 4, 5, 6]) + ``` + + Args: + pyfunc: Callable of a single tensor argument. + excluded: Optional set of integers representing + positional arguments for which the function + will not be vectorized. + These will be passed directly to `pyfunc` unmodified. + signature: Optional generalized universal function signature, + e.g., `"(m,n),(n)->(m)"` for vectorized + matrix-vector multiplication. If provided, + `pyfunc` will be called with (and expected to return) + arrays with shapes given by the size of corresponding + core dimensions. By default, `pyfunc` is assumed + to take scalars tensors as input and output. + + Returns: + A new function that applies `pyfunc` to every element + of its input along axis 0 (the batch axis). + """ + if not callable(pyfunc): + raise ValueError( + "Expected argument `pyfunc` to be a callable. " + f"Received: pyfunc={pyfunc}" + ) + return backend.numpy.vectorize( + pyfunc, excluded=excluded, signature=signature + ) + + +class Vstack(Operation): + def call(self, xs): + return backend.numpy.vstack(xs) + + def compute_output_spec(self, xs): + first_shape = xs[0].shape + total_size_on_axis = 0 + dtypes_to_resolve = [] + for x in xs: + if not shape_equal(x.shape, first_shape, axis=[0], allow_none=True): + raise ValueError( + "Every value in `xs` must have the same shape except on " + f"the `axis` dim. But found element of shape {x.shape}, " + f"which is different from the first element's " + f"shape {first_shape}." + ) + if total_size_on_axis is None or x.shape[0] is None: + total_size_on_axis = None + else: + total_size_on_axis += x.shape[0] + dtypes_to_resolve.append(getattr(x, "dtype", type(x))) + output_shape = list(first_shape) + output_shape[0] = total_size_on_axis + output_dtype = dtypes.result_type(*dtypes_to_resolve) + return KerasTensor(output_shape, output_dtype) + + +@keras_export(["keras.ops.vstack", "keras.ops.numpy.vstack"]) +def vstack(xs): + """Stack tensors in sequence vertically (row wise). + + Args: + xs: Sequence of tensors. + + Returns: + Tensor formed by stacking the given tensors. + """ + if any_symbolic_tensors((xs,)): + return Vstack().symbolic_call(xs) + return backend.numpy.vstack(xs) + + +class Where(Operation): + def call(self, condition, x1=None, x2=None): + return backend.numpy.where(condition, x1, x2) + + def compute_output_spec(self, condition, x1, x2): + condition_shape = getattr(condition, "shape", []) + x1_shape = getattr(x1, "shape", []) + x2_shape = getattr(x2, "shape", []) + output_shape = broadcast_shapes(condition_shape, x1_shape) + output_shape = broadcast_shapes(output_shape, x2_shape) + output_dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1) if x1 is not None else "int"), + getattr(x2, "dtype", type(x2) if x2 is not None else "int"), + ) + return KerasTensor(output_shape, dtype=output_dtype) + + +@keras_export(["keras.ops.where", "keras.ops.numpy.where"]) +def where(condition, x1=None, x2=None): + """Return elements chosen from `x1` or `x2` depending on `condition`. + + Args: + condition: Where `True`, yield `x1`, otherwise yield `x2`. + x1: Values from which to choose when `condition` is `True`. + x2: Values from which to choose when `condition` is `False`. + + Returns: + A tensor with elements from `x1` where `condition` is `True`, and + elements from `x2` where `condition` is `False`. + """ + if (x1 is None and x2 is not None) or (x1 is not None and x2 is None): + raise ValueError( + "`x1` and `x2` either both should be `None`" + " or both should have non-None value." + ) + if any_symbolic_tensors((condition, x1, x2)): + return Where().symbolic_call(condition, x1, x2) + return backend.numpy.where(condition, x1, x2) + + +class Subtract(Operation): + def call(self, x1, x2): + return backend.numpy.subtract(x1, x2) + + def compute_output_spec(self, x1, x2): + x1_shape = getattr(x1, "shape", []) + x2_shape = getattr(x2, "shape", []) + output_shape = broadcast_shapes(x1_shape, x2_shape) + x1_sparse = getattr(x1, "sparse", False) + x2_sparse = getattr(x2, "sparse", False) + output_sparse = x1_sparse and x2_sparse + dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + ) + return KerasTensor(output_shape, dtype=dtype, sparse=output_sparse) + + +@keras_export(["keras.ops.subtract", "keras.ops.numpy.subtract"]) +def subtract(x1, x2): + """Subtract arguments element-wise. + + Args: + x1: First input tensor. + x2: Second input tensor. + + Returns: + Output tensor, element-wise difference of `x1` and `x2`. + """ + if any_symbolic_tensors((x1, x2)): + return Subtract().symbolic_call(x1, x2) + return backend.numpy.subtract(x1, x2) + + +class Multiply(Operation): + def call(self, x1, x2): + return backend.numpy.multiply(x1, x2) + + def compute_output_spec(self, x1, x2): + x1_shape = getattr(x1, "shape", []) + x2_shape = getattr(x2, "shape", []) + output_shape = broadcast_shapes(x1_shape, x2_shape) + x1_sparse = getattr(x1, "sparse", True) + x2_sparse = getattr(x2, "sparse", True) + output_sparse = x1_sparse or x2_sparse + dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + ) + return KerasTensor(output_shape, dtype=dtype, sparse=output_sparse) + + +@keras_export(["keras.ops.multiply", "keras.ops.numpy.multiply"]) +def multiply(x1, x2): + """Multiply arguments element-wise. + + Args: + x1: First input tensor. + x2: Second input tensor. + + Returns: + Output tensor, element-wise product of `x1` and `x2`. + """ + if any_symbolic_tensors((x1, x2)): + return Multiply().symbolic_call(x1, x2) + return backend.numpy.multiply(x1, x2) + + +class Divide(Operation): + def call(self, x1, x2): + return backend.numpy.divide(x1, x2) + + def compute_output_spec(self, x1, x2): + x1_shape = getattr(x1, "shape", []) + x2_shape = getattr(x2, "shape", []) + output_shape = broadcast_shapes(x1_shape, x2_shape) + output_dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + float, + ) + x1_sparse = getattr(x1, "sparse", False) + x2_sparse = getattr(x2, "sparse", False) + output_sparse = x1_sparse and not x2_sparse + return KerasTensor( + output_shape, dtype=output_dtype, sparse=output_sparse + ) + + +@keras_export(["keras.ops.divide", "keras.ops.numpy.divide"]) +def divide(x1, x2): + """Divide arguments element-wise. + + `keras.ops.true_divide` is an alias for this function. + + Args: + x1: First input tensor. + x2: Second input tensor. + + Returns: + Output tensor, the quotient `x1/x2`, element-wise. + """ + if any_symbolic_tensors((x1, x2)): + return Divide().symbolic_call(x1, x2) + return backend.numpy.divide(x1, x2) + + +class DivideNoNan(Operation): + def call(self, x1, x2): + return backend.numpy.divide_no_nan(x1, x2) + + def compute_output_spec(self, x1, x2): + x1_shape = getattr(x1, "shape", []) + x2_shape = getattr(x2, "shape", []) + output_shape = broadcast_shapes(x1_shape, x2_shape) + output_dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + float, + ) + x1_sparse = getattr(x1, "sparse", False) + x2_sparse = getattr(x2, "sparse", False) + output_sparse = x1_sparse and not x2_sparse + return KerasTensor( + output_shape, dtype=output_dtype, sparse=output_sparse + ) + + +@keras_export(["keras.ops.divide_no_nan", "keras.ops.numpy.divide_no_nan"]) +def divide_no_nan(x1, x2): + """Safe element-wise division which returns 0 where the denominator is 0. + + Args: + x1: First input tensor. + x2: Second input tensor. + + Returns: + The quotient `x1/x2`, element-wise, with zero where x2 is zero. + """ + if any_symbolic_tensors((x1, x2)): + return DivideNoNan().symbolic_call(x1, x2) + return backend.numpy.divide_no_nan(x1, x2) + + +class TrueDivide(Operation): + def call(self, x1, x2): + return backend.numpy.true_divide(x1, x2) + + def compute_output_spec(self, x1, x2): + x1_shape = getattr(x1, "shape", []) + x2_shape = getattr(x2, "shape", []) + output_shape = broadcast_shapes(x1_shape, x2_shape) + output_dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + float, + ) + x1_sparse = getattr(x1, "sparse", False) + x2_sparse = getattr(x2, "sparse", False) + output_sparse = x1_sparse and not x2_sparse + return KerasTensor( + output_shape, dtype=output_dtype, sparse=output_sparse + ) + + +@keras_export( + [ + "keras.ops.true_divide", + "keras.ops.numpy.true_divide", + ] +) +def true_divide(x1, x2): + """Alias for `keras.ops.divide`.""" + if any_symbolic_tensors((x1, x2)): + return TrueDivide().symbolic_call(x1, x2) + return backend.numpy.true_divide(x1, x2) + + +class Power(Operation): + def call(self, x1, x2): + return backend.numpy.power(x1, x2) + + def compute_output_spec(self, x1, x2): + x1_shape = getattr(x1, "shape", []) + x2_shape = getattr(x2, "shape", []) + output_shape = broadcast_shapes(x1_shape, x2_shape) + output_dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), getattr(x2, "dtype", type(x2)) + ) + return KerasTensor(output_shape, dtype=output_dtype) + + +@keras_export(["keras.ops.power", "keras.ops.numpy.power"]) +def power(x1, x2): + """First tensor elements raised to powers from second tensor, element-wise. + + Args: + x1: The bases. + x2: The exponents. + + Returns: + Output tensor, the bases in `x1` raised to the exponents in `x2`. + """ + if any_symbolic_tensors((x1, x2)): + return Power().symbolic_call(x1, x2) + return backend.numpy.power(x1, x2) + + +class Negative(Operation): + def call(self, x): + return backend.numpy.negative(x) + + def compute_output_spec(self, x): + sparse = getattr(x, "sparse", False) + return KerasTensor(x.shape, dtype=x.dtype, sparse=sparse) + + +@keras_export(["keras.ops.negative", "keras.ops.numpy.negative"]) +def negative(x): + """Numerical negative, element-wise. + + Args: + x: Input tensor. + + Returns: + Output tensor, `y = -x`. + """ + if any_symbolic_tensors((x,)): + return Negative().symbolic_call(x) + return backend.numpy.negative(x) + + +class Square(Operation): + def call(self, x): + return backend.numpy.square(x) + + def compute_output_spec(self, x): + sparse = getattr(x, "sparse", False) + dtype = backend.standardize_dtype(x.dtype) + if dtype == "bool": + dtype = "int32" + return KerasTensor(x.shape, dtype=dtype, sparse=sparse) + + +@keras_export(["keras.ops.square", "keras.ops.numpy.square"]) +def square(x): + """Return the element-wise square of the input. + + Args: + x: Input tensor. + + Returns: + Output tensor, the square of `x`. + """ + if any_symbolic_tensors((x,)): + return Square().symbolic_call(x) + return backend.numpy.square(x) + + +class Sqrt(Operation): + def call(self, x): + x = backend.convert_to_tensor(x) + return backend.numpy.sqrt(x) + + def compute_output_spec(self, x): + dtype = ( + backend.floatx() + if backend.standardize_dtype(x.dtype) == "int64" + else dtypes.result_type(x.dtype, float) + ) + sparse = getattr(x, "sparse", False) + return KerasTensor(x.shape, dtype=dtype, sparse=sparse) + + +@keras_export(["keras.ops.sqrt", "keras.ops.numpy.sqrt"]) +def sqrt(x): + """Return the non-negative square root of a tensor, element-wise. + + Args: + x: Input tensor. + + Returns: + Output tensor, the non-negative square root of `x`. + """ + if any_symbolic_tensors((x,)): + return Sqrt().symbolic_call(x) + x = backend.convert_to_tensor(x) + return backend.numpy.sqrt(x) + + +class Squeeze(Operation): + def __init__(self, axis=None, *, name=None): + super().__init__(name=name) + self.axis = axis + + def call(self, x): + return backend.numpy.squeeze(x, axis=self.axis) + + def compute_output_spec(self, x): + input_shape = list(x.shape) + sparse = getattr(x, "sparse", False) + axis = to_tuple_or_list(self.axis) + if axis is None: + output_shape = list(filter((1).__ne__, input_shape)) + return KerasTensor(output_shape, dtype=x.dtype, sparse=sparse) + else: + for a in axis: + if input_shape[a] != 1: + raise ValueError( + f"Cannot squeeze axis {a}, because the dimension " + "is not 1." + ) + axis = [canonicalize_axis(a, len(input_shape)) for a in axis] + for a in sorted(axis, reverse=True): + del input_shape[a] + return KerasTensor(input_shape, dtype=x.dtype, sparse=sparse) + + +@keras_export(["keras.ops.squeeze", "keras.ops.numpy.squeeze"]) +def squeeze(x, axis=None): + """Remove axes of length one from `x`. + + Args: + x: Input tensor. + axis: Select a subset of the entries of length one in the shape. + + Returns: + The input tensor with all or a subset of the dimensions of + length 1 removed. + """ + if any_symbolic_tensors((x,)): + return Squeeze(axis=axis).symbolic_call(x) + return backend.numpy.squeeze(x, axis=axis) + + +class Transpose(Operation): + def __init__(self, axes=None, *, name=None): + super().__init__(name=name) + self.axes = axes + + def call(self, x): + return backend.numpy.transpose(x, axes=self.axes) + + def compute_output_spec(self, x): + output_shape = operation_utils.compute_transpose_output_shape( + x.shape, self.axes + ) + sparse = getattr(x, "sparse", False) + return KerasTensor(output_shape, dtype=x.dtype, sparse=sparse) + + +@keras_export(["keras.ops.transpose", "keras.ops.numpy.transpose"]) +def transpose(x, axes=None): + """Returns a tensor with `axes` transposed. + + Args: + x: Input tensor. + axes: Sequence of integers. Permutation of the dimensions of `x`. + By default, the order of the axes are reversed. + + Returns: + `x` with its axes permuted. + """ + if any_symbolic_tensors((x,)): + return Transpose(axes=axes).symbolic_call(x) + return backend.numpy.transpose(x, axes=axes) + + +class Mean(Operation): + def __init__(self, axis=None, keepdims=False, *, name=None): + super().__init__(name=name) + if isinstance(axis, int): + axis = [axis] + self.axis = axis + self.keepdims = keepdims + + def call(self, x): + return backend.numpy.mean(x, axis=self.axis, keepdims=self.keepdims) + + def compute_output_spec(self, x): + ori_dtype = backend.standardize_dtype(x.dtype) + compute_dtype = dtypes.result_type(x.dtype, "float32") + if "int" in ori_dtype or ori_dtype == "bool": + result_dtype = compute_dtype + else: + result_dtype = ori_dtype + sparse = getattr(x, "sparse", False) + return KerasTensor( + reduce_shape(x.shape, axis=self.axis, keepdims=self.keepdims), + dtype=result_dtype, + sparse=sparse, + ) + + +@keras_export(["keras.ops.mean", "keras.ops.numpy.mean"]) +def mean(x, axis=None, keepdims=False): + """Compute the arithmetic mean along the specified axes. + + Args: + x: Input tensor. + axis: Axis or axes along which the means are computed. The default + is to compute the mean of the flattened tensor. + keepdims: If this is set to `True`, the axes which are reduced are left + in the result as dimensions with size one. + + Returns: + Output tensor containing the mean values. + """ + if any_symbolic_tensors((x,)): + return Mean(axis=axis, keepdims=keepdims).symbolic_call(x) + return backend.numpy.mean(x, axis=axis, keepdims=keepdims) + + +class Var(Operation): + def __init__(self, axis=None, keepdims=False, *, name=None): + super().__init__(name=name) + if isinstance(axis, int): + axis = [axis] + self.axis = axis + self.keepdims = keepdims + + def call(self, x): + return backend.numpy.var(x, axis=self.axis, keepdims=self.keepdims) + + def compute_output_spec(self, x): + output_dtype = backend.result_type(getattr(x, "dtype", type(x)), float) + return KerasTensor( + reduce_shape(x.shape, axis=self.axis, keepdims=self.keepdims), + dtype=output_dtype, + ) + + +@keras_export(["keras.ops.var", "keras.ops.numpy.var"]) +def var(x, axis=None, keepdims=False): + """Compute the variance along the specified axes. + + Args: + x: Input tensor. + axis: Axis or axes along which the variance is computed. The default + is to compute the variance of the flattened tensor. + keepdims: If this is set to `True`, the axes which are reduced are left + in the result as dimensions with size one. + + Returns: + Output tensor containing the variance. + """ + if any_symbolic_tensors((x,)): + return Var(axis=axis, keepdims=keepdims).symbolic_call(x) + return backend.numpy.var(x, axis=axis, keepdims=keepdims) + + +class Sum(Operation): + def __init__(self, axis=None, keepdims=False, *, name=None): + super().__init__(name=name) + if isinstance(axis, int): + axis = [axis] + self.axis = axis + self.keepdims = keepdims + + def call(self, x): + return backend.numpy.sum(x, axis=self.axis, keepdims=self.keepdims) + + def compute_output_spec(self, x): + dtype = dtypes.result_type(getattr(x, "dtype", backend.floatx())) + # follow jax's rule + if dtype in ("bool", "int8", "int16"): + dtype = "int32" + elif dtype in ("uint8", "uint16"): + dtype = "uint32" + # TODO: torch doesn't support uint32 + if backend.backend() == "torch" and dtype == "uint32": + dtype = "int32" + sparse = getattr(x, "sparse", False) + return KerasTensor( + reduce_shape(x.shape, axis=self.axis, keepdims=self.keepdims), + dtype=dtype, + sparse=sparse, + ) + + +@keras_export(["keras.ops.sum", "keras.ops.numpy.sum"]) +def sum(x, axis=None, keepdims=False): + """Sum of a tensor over the given axes. + + Args: + x: Input tensor. + axis: Axis or axes along which the sum is computed. The default is to + compute the sum of the flattened tensor. + keepdims: If this is set to `True`, the axes which are reduced are left + in the result as dimensions with size one. + + Returns: + Output tensor containing the sum. + """ + if any_symbolic_tensors((x,)): + return Sum(axis=axis, keepdims=keepdims).symbolic_call(x) + return backend.numpy.sum(x, axis=axis, keepdims=keepdims) + + +@keras_export(["keras.ops.zeros", "keras.ops.numpy.zeros"]) +def zeros(shape, dtype=None): + """Return a new tensor of given shape and type, filled with zeros. + + Args: + shape: Shape of the new tensor. + dtype: Desired data type of the tensor. + + Returns: + Tensor of zeros with the given shape and dtype. + """ + return backend.numpy.zeros(shape, dtype=dtype) + + +@keras_export(["keras.ops.ones", "keras.ops.numpy.ones"]) +def ones(shape, dtype=None): + """Return a new tensor of given shape and type, filled with ones. + + Args: + shape: Shape of the new tensor. + dtype: Desired data type of the tensor. + + Returns: + Tensor of ones with the given shape and dtype. + """ + return backend.numpy.ones(shape, dtype=dtype) + + +@keras_export(["keras.ops.eye", "keras.ops.numpy.eye"]) +def eye(N, M=None, k=0, dtype=None): + """Return a 2-D tensor with ones on the diagonal and zeros elsewhere. + + Args: + N: Number of rows in the output. + M: Number of columns in the output. If `None`, defaults to `N`. + k: Index of the diagonal: 0 (the default) refers to the main + diagonal, a positive value refers to an upper diagonal, + and a negative value to a lower diagonal. + dtype: Data type of the returned tensor. + + Returns: + Tensor with ones on the k-th diagonal and zeros elsewhere. + """ + return backend.numpy.eye(N, M=M, k=k, dtype=dtype) + + +class FloorDivide(Operation): + def call(self, x1, x2): + return backend.numpy.floor_divide(x1, x2) + + def compute_output_spec(self, x1, x2): + x1_shape = getattr(x1, "shape", []) + x2_shape = getattr(x2, "shape", []) + output_shape = broadcast_shapes(x1_shape, x2_shape) + output_dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + ) + return KerasTensor(output_shape, dtype=output_dtype) + + +@keras_export(["keras.ops.floor_divide", "keras.ops.numpy.floor_divide"]) +def floor_divide(x1, x2): + """Returns the largest integer smaller or equal to the division of inputs. + + Args: + x1: Numerator. + x2: Denominator. + + Returns: + Output tensor, `y = floor(x1/x2)` + """ + if any_symbolic_tensors((x1, x2)): + return FloorDivide().symbolic_call(x1, x2) + return backend.numpy.floor_divide(x1, x2) + + +class LogicalXor(Operation): + def call(self, x1, x2): + return backend.numpy.logical_xor(x1, x2) + + def compute_output_spec(self, x1, x2): + x1_shape = getattr(x1, "shape", []) + x2_shape = getattr(x2, "shape", []) + output_shape = broadcast_shapes(x1_shape, x2_shape) + return KerasTensor(output_shape, dtype="bool") + + +@keras_export(["keras.ops.logical_xor", "keras.ops.numpy.logical_xor"]) +def logical_xor(x1, x2): + """Compute the truth value of `x1 XOR x2`, element-wise. + + Args: + x1: First input tensor. + x2: Second input tensor. + + Returns: + Output boolean tensor. + """ + if any_symbolic_tensors((x1, x2)): + return LogicalXor().symbolic_call(x1, x2) + return backend.numpy.logical_xor(x1, x2) + + +class Corrcoef(Operation): + def call(self, x): + return backend.numpy.corrcoef(x) + + def compute_output_spec(self, x): + dtype = backend.standardize_dtype(getattr(x, "dtype", backend.floatx())) + if dtype == "int64": + dtype = "float64" + else: + dtype = dtypes.result_type(dtype, float) + return KerasTensor(x.shape, dtype=dtype) + + +@keras_export(["keras.ops.corrcoef", "keras.ops.numpy.corrcoef"]) +def corrcoef(x): + """Compute the Pearson correlation coefficient matrix. + + Args: + x: A 2D tensor of shape `(N, D)`, where N is the number of variables + and D is the number of observations. + + Returns: + A tensor of shape `(N, N)` representing the correlation matrix. + """ + if any_symbolic_tensors((x,)): + return Corrcoef().symbolic_call(x) + return backend.numpy.corrcoef(x) + + +class Correlate(Operation): + def __init__(self, mode="valid", *, name=None): + super().__init__(name=name) + self.mode = mode + + def call(self, x1, x2): + return backend.numpy.correlate(x1, x2, mode=self.mode) + + def compute_output_spec(self, x1, x2): + x1_shape = getattr(x1, "shape", []) + x2_shape = getattr(x2, "shape", []) + if len(x1_shape) != 1: + raise ValueError( + "`x1` must be a 1-dimensional tensor, but received" + + f"shape {x1_shape}" + ) + if len(x2_shape) != 1: + raise ValueError( + "`x2` must be a 1-dimensional tensor, but received" + + f"shape {x2_shape}" + ) + x1_len, x2_len = x1_shape[0], x2_shape[0] + output_shape = ( + np.maximum(x1_len, x2_len) - np.minimum(x1_len, x2_len) + 1, + ) + if self.mode == "same": + output_shape = (np.maximum(x1_len, x2_len),) + elif self.mode == "full": + output_shape = (x1_len + x2_len - 1,) + if self.mode not in ("valid", "same", "full"): + raise ValueError( + "`mode` must be either `valid`, `same`, or `full`, but" + f"received: {self.mode}" + ) + output_dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + ) + if output_dtype == "int64": + output_dtype = "float64" + elif output_dtype not in ["bfloat16", "float16", "float64"]: + output_dtype = "float32" + return KerasTensor(output_shape, dtype=output_dtype) + + +@keras_export(["keras.ops.correlate", "keras.ops.numpy.correlate"]) +def correlate(x1, x2, mode="valid"): + """Compute the cross-correlation of two 1-dimensional tensors. + + Args: + x1: First 1-dimensional input tensor of length M. + x2: Second 1-dimensional input tensor of length N. + mode: Either `valid`, `same` or `full`. + By default the mode is set to `valid`, which returns + an output of length max(M, N) - min(M, N) + 1. + `same` returns an output of length max(M, N). + `full` mode returns the convolution at each point of + overlap, with an output length of N+M-1 + + Returns: + Output tensor, cross-correlation of `x1` and `x2`. + """ + if any_symbolic_tensors((x1, x2)): + return Correlate(mode=mode).symbolic_call(x1, x2) + return backend.numpy.correlate(x1, x2, mode=mode) + + +class Select(Operation): + def call(self, condlist, choicelist, default=0): + return backend.numpy.select(condlist, choicelist, default) + + def compute_output_spec(self, condlist, choicelist, default=0): + first_element = choicelist[0] + return KerasTensor(first_element.shape, dtype=first_element.dtype) + + +@keras_export(["keras.ops.select", "keras.ops.numpy.select"]) +def select(condlist, choicelist, default=0): + """Return elements from `choicelist`, based on conditions in `condlist`. + + Args: + condlist: List of boolean tensors. + The list of conditions which determine from which array + in choicelist the output elements are taken. + When multiple conditions are satisfied, + the first one encountered in condlist is used. + choicelist: List of tensors. + The list of tensors from which the output elements are taken. + This list has to be of the same length as `condlist`. + defaults: Optional scalar value. + The element inserted in the output + when all conditions evaluate to `False`. + + Returns: + Tensor where the output at position `m` is the `m`-th element + of the tensor in `choicelist` where the `m`-th element of the + corresponding tensor in `condlist` is `True`. + + Example: + + ```python + from keras import ops + + x = ops.arange(6) + condlist = [x<3, x>3] + choicelist = [x, x**2] + ops.select(condlist, choicelist, 42) + # Returns: tensor([0, 1, 2, 42, 16, 25]) + ``` + """ + if not isinstance(condlist, (list, tuple)) or not isinstance( + choicelist, (list, tuple) + ): + raise ValueError( + "condlist and choicelist must be lists. Received: " + f"type(condlist) = {type(condlist)}, " + f"type(choicelist) = {type(choicelist)}" + ) + condlist = list(condlist) + choicelist = list(choicelist) + if not condlist or not choicelist: + raise ValueError( + "condlist and choicelist must not be empty. Received: " + f"condlist = {condlist}, " + f"choicelist = {choicelist}" + ) + if any_symbolic_tensors(condlist + choicelist + [default]): + return Select().symbolic_call(condlist, choicelist, default) + return backend.numpy.select(condlist, choicelist, default) + + +class Slogdet(Operation): + def call(self, x): + return backend.numpy.slogdet(x) + + def compute_output_spec(self, x): + sign = KerasTensor((), dtype=x.dtype) + logabsdet = KerasTensor(x.shape[:-2], dtype=x.dtype) + return (sign, logabsdet) + + +@keras_export(["keras.ops.slogdet", "keras.ops.numpy.slogdet"]) +def slogdet(x): + """Compute the sign and natural logarithm of the determinant of a matrix. + + Args: + x: Input matrix. It must 2D and square. + + Returns: + A tuple `(sign, logabsdet)`. `sign` is a number representing + the sign of the determinant. For a real matrix, this is 1, 0, or -1. + For a complex matrix, this is a complex number with absolute value 1 + (i.e., it is on the unit circle), or else 0. + `logabsdet` is the natural log of the absolute value of the determinant. + """ + if any_symbolic_tensors((x,)): + return Slogdet().symbolic_call(x) + return backend.numpy.slogdet(x) + + +class Argpartition(Operation): + def __init__(self, kth, axis=-1, *, name=None): + super().__init__(name=name) + if not isinstance(kth, int): + raise ValueError(f"kth must be an integer. Received:kth = {kth}") + self.kth = kth + self.axis = axis + + def call(self, x): + return backend.numpy.argpartition(x, kth=self.kth, axis=self.axis) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype="int32") + + +@keras_export(["keras.ops.argpartition", "keras.ops.numpy.argpartition"]) +def argpartition(x, kth, axis=-1): + """Performs an indirect partition along the given axis. + + It returns an array + of indices of the same shape as `x` that index data along the given axis + in partitioned order. + + Args: + a: Array to sort. + kth: Element index to partition by. + The k-th element will be in its final sorted position and all + smaller elements will be moved before it and all larger elements + behind it. The order of all elements in the partitions is undefined. + If provided with a sequence of k-th it will partition all of them + into their sorted position at once. + axis: Axis along which to sort. The default is -1 (the last axis). + If `None`, the flattened array is used. + + Returns: + Array of indices that partition `x` along the specified `axis`. + """ + if any_symbolic_tensors((x,)): + return Argpartition(kth, axis).symbolic_call(x) + return backend.numpy.argpartition(x, kth, axis) + + +class Histogram(Operation): + def __init__(self, bins=10, range=None, *, name=None): + super().__init__(name=name) + + if not isinstance(bins, int): + raise TypeError("bins must be of type `int`") + if bins < 0: + raise ValueError("`bins` should be a non-negative integer") + + if range: + if len(range) < 2 or not isinstance(range, tuple): + raise ValueError("range must be a tuple of two elements") + + if range[1] < range[0]: + raise ValueError( + "The second element of range must be greater than the first" + ) + + self.bins = bins + self.range = range + + def call(self, x): + x = backend.convert_to_tensor(x) + if len(x.shape) > 1: + raise ValueError("Input tensor must be 1-dimensional") + return backend.math.histogram(x, bins=self.bins, range=self.range) + + def compute_output_spec(self, x): + return ( + KerasTensor(shape=(self.bins,), dtype=x.dtype), + KerasTensor(shape=(self.bins + 1,), dtype=x.dtype), + ) + + +@keras_export(["keras.ops.histogram", "keras.ops.numpy.histogram"]) +def histogram(x, bins=10, range=None): + """Computes a histogram of the data tensor `x`. + + Args: + x: Input tensor. + bins: An integer representing the number of histogram bins. + Defaults to 10. + range: A tuple representing the lower and upper range of the bins. + If not specified, it will use the min and max of `x`. + + Returns: + A tuple containing: + - A tensor representing the counts of elements in each bin. + - A tensor representing the bin edges. + + Example: + >>> input_tensor = np.random.rand(8) + >>> keras.ops.histogram(input_tensor) + (array([1, 1, 1, 0, 0, 1, 2, 1, 0, 1], dtype=int32), + array([0.0189519 , 0.10294958, 0.18694726, 0.27094494, 0.35494262, + 0.43894029, 0.52293797, 0.60693565, 0.69093333, 0.77493101, + 0.85892869])) + """ + if not isinstance(bins, int): + raise TypeError( + f"Argument `bins` must be of type `int`. Received: bins={bins}" + ) + if bins < 0: + raise ValueError( + "Argument `bins` should be a non-negative integer. " + f"Received: bins={bins}" + ) + + if range: + if len(range) < 2 or not isinstance(range, tuple): + raise ValueError( + "Argument `range` must be a tuple of two elements. " + f"Received: range={range}" + ) + + if range[1] < range[0]: + raise ValueError( + "The second element of `range` must be greater than the first. " + f"Received: range={range}" + ) + + if any_symbolic_tensors((x,)): + return Histogram(bins=bins, range=range).symbolic_call(x) + + x = backend.convert_to_tensor(x) + if len(x.shape) > 1: + raise ValueError( + "Input tensor must be 1-dimensional. " + f"Received: input.shape={x.shape}" + ) + return backend.numpy.histogram(x, bins=bins, range=range) diff --git a/keras/src/ops/numpy_test.py b/keras/src/ops/numpy_test.py new file mode 100644 index 000000000000..998d18bd4b73 --- /dev/null +++ b/keras/src/ops/numpy_test.py @@ -0,0 +1,9499 @@ +import contextlib +import functools +import itertools +import math +import warnings + +import numpy as np +import pytest +from absl.testing import parameterized + +import keras +from keras.src import backend +from keras.src import ops +from keras.src import testing +from keras.src.backend.common import dtypes +from keras.src.backend.common import is_int_dtype +from keras.src.backend.common import standardize_dtype +from keras.src.backend.common.keras_tensor import KerasTensor +from keras.src.ops import numpy as knp +from keras.src.testing.test_utils import named_product + + +@contextlib.contextmanager +def jax_disable_x64_context(): + try: + # JAX v0.8.0 and newer + from jax import enable_x64 + except ImportError: + # JAX v0.7.2 and older + from jax.experimental import enable_x64 + with enable_x64(False): + yield + + +class NumPyTestRot90(testing.TestCase): + def test_basic_rotation(self): + array = np.array([[1, 2, 3], [4, 5, 6]]) + rotated = knp.rot90(array) + expected = np.rot90(array) + self.assertAllClose(rotated, expected) + + @parameterized.named_parameters( + ("k_0", 0, [[1, 2], [3, 4]]), + ("k_1", 1, [[2, 4], [1, 3]]), + ("k_2", 2, [[4, 3], [2, 1]]), + ("k_neg1", -1, [[3, 1], [4, 2]]), + ("k_5", 5, [[2, 4], [1, 3]]), # k=5 ≡ k=1 (mod 4) + ("k_6", 6, [[4, 3], [2, 1]]), # k=6 ≡ k=2 (mod 4) + ) + def test_k_parameter_variations(self, k, expected): + array = np.array([[1, 2], [3, 4]]) + rotated = knp.rot90(array, k=k) + expected = np.array(expected) + self.assertAllClose(rotated, expected) + + @parameterized.named_parameters( + ("axes_0_1", (0, 1)), ("axes_1_2", (1, 2)), ("axes_0_2", (0, 2)) + ) + def test_3d_operations(self, axes): + array_3d = np.arange(12).reshape(3, 2, 2) + rotated = knp.rot90(array_3d, axes=axes) + expected = np.rot90(array_3d, axes=axes) + self.assertAllClose(rotated, expected) + + @parameterized.named_parameters( + ("single_image", np.random.random((4, 4, 3))), + ("batch_images", np.random.random((2, 4, 4, 3))), + ) + def test_image_processing(self, array): + np.random.seed(0) + rotated = knp.rot90(array, axes=(0, 1)) + expected = np.rot90(array, axes=(0, 1)) + self.assertAllClose(rotated, expected) + + @parameterized.named_parameters( + ("single_row", [[1, 2, 3]]), + ("single_column", [[1], [2], [3]]), + ("negative_values", [[-1, 0], [1, -2]]), + ) + def test_edge_conditions(self, array): + numpy_array = np.array(array) + rotated = knp.rot90(numpy_array) + expected = np.rot90(numpy_array) + self.assertAllClose(rotated, expected) + + @parameterized.named_parameters( + ("1D_array", np.array([1, 2, 3]), None), + ("duplicate_axes", np.array([[1, 2], [3, 4]]), (0, 0)), + ) + def test_error_conditions(self, array, axes): + if axes is None: + with self.assertRaises(ValueError): + knp.rot90(array) + else: + with self.assertRaises(ValueError): + knp.rot90(array, axes=axes) + + +class NumpyTwoInputOpsDynamicShapeTest(testing.TestCase): + def test_add(self): + x = KerasTensor((None, 3)) + y = KerasTensor((2, None)) + self.assertEqual(knp.add(x, y).shape, (2, 3)) + + def test_heaviside(self): + x = KerasTensor((None, 3)) + y = KerasTensor((None, 3)) + self.assertEqual(knp.heaviside(x, y).shape, (None, 3)) + + def test_hypot(self): + x = KerasTensor((None, 3)) + y = KerasTensor((None, 3)) + self.assertEqual(knp.hypot(x, y).shape, (None, 3)) + + def test_subtract(self): + x = KerasTensor((None, 3)) + y = KerasTensor((2, None)) + self.assertEqual(knp.subtract(x, y).shape, (2, 3)) + + def test_multiply(self): + x = KerasTensor((None, 3)) + y = KerasTensor((2, None)) + self.assertEqual(knp.multiply(x, y).shape, (2, 3)) + + def test_matmul(self): + x = KerasTensor((None, 3, 4)) + y = KerasTensor((3, None, 4, 5)) + self.assertEqual(knp.matmul(x, y).shape, (3, None, 3, 5)) + + def test_power(self): + x = KerasTensor((None, 3)) + y = KerasTensor((2, None)) + self.assertEqual(knp.power(x, y).shape, (2, 3)) + + def test_divide(self): + x = KerasTensor((None, 3)) + y = KerasTensor((2, None)) + self.assertEqual(knp.divide(x, y).shape, (2, 3)) + + def test_divide_no_nan(self): + x = KerasTensor((None, 3)) + y = KerasTensor((2, None)) + self.assertEqual(knp.divide_no_nan(x, y).shape, (2, 3)) + + def test_true_divide(self): + x = KerasTensor((None, 3)) + y = KerasTensor((2, None)) + self.assertEqual(knp.true_divide(x, y).shape, (2, 3)) + + def test_append(self): + x = KerasTensor((None, 3)) + y = KerasTensor((2, None)) + self.assertEqual(knp.append(x, y).shape, (None,)) + + def test_arctan2(self): + x = KerasTensor((None, 3)) + y = KerasTensor((2, None)) + self.assertEqual(knp.arctan2(x, y).shape, (2, 3)) + + def test_bitwise_and(self): + x = KerasTensor((None, 3)) + y = KerasTensor((None, 3)) + self.assertEqual(knp.bitwise_and(x, y).shape, (None, 3)) + + def test_bitwise_or(self): + x = KerasTensor((None, 3)) + y = KerasTensor((None, 3)) + self.assertEqual(knp.bitwise_or(x, y).shape, (None, 3)) + + def test_bitwise_xor(self): + x = KerasTensor((None, 3)) + y = KerasTensor((None, 3)) + self.assertEqual(knp.bitwise_xor(x, y).shape, (None, 3)) + + def test_bitwise_left_shift(self): + x = KerasTensor((None, 3)) + y = KerasTensor((None, 3)) + self.assertEqual(knp.bitwise_left_shift(x, y).shape, (None, 3)) + + # left_shift is same as bitwise_left_shift + + def test_bitwise_right_shift(self): + x = KerasTensor((None, 3)) + y = KerasTensor((None, 3)) + self.assertEqual(knp.bitwise_right_shift(x, y).shape, (None, 3)) + + # right_shift is same as bitwise_right_shift + + def test_cross(self): + x1 = KerasTensor((2, 3, 3)) + x2 = KerasTensor((1, 3, 2)) + y = KerasTensor((None, 1, 2)) + self.assertEqual(knp.cross(x1, y).shape, (2, 3, 3)) + self.assertEqual(knp.cross(x2, y).shape, (None, 3)) + + def test_einsum(self): + x = KerasTensor((None, 3)) + y = KerasTensor((3, 4)) + self.assertEqual(knp.einsum("ij,jk->ik", x, y).shape, (None, 4)) + self.assertEqual(knp.einsum("ij,jk->ikj", x, y).shape, (None, 4, 3)) + self.assertEqual(knp.einsum("ii", x).shape, ()) + self.assertEqual(knp.einsum(",ij", 5, x).shape, (None, 3)) + + x = KerasTensor((None, 3, 4)) + y = KerasTensor((None, 4, 5)) + z = KerasTensor((1, 1, 1, 9)) + self.assertEqual(knp.einsum("ijk,jkl->li", x, y).shape, (5, None)) + self.assertEqual(knp.einsum("ijk,jkl->lij", x, y).shape, (5, None, 3)) + self.assertEqual( + knp.einsum("...,...j->...j", x, y).shape, (None, 3, 4, 5) + ) + self.assertEqual( + knp.einsum("i...,...j->i...j", x, y).shape, (None, 3, 4, 5) + ) + self.assertEqual(knp.einsum("i...,...j", x, y).shape, (3, 4, None, 5)) + self.assertEqual( + knp.einsum("i...,...j,...k", x, y, z).shape, (1, 3, 4, None, 5, 9) + ) + self.assertEqual( + knp.einsum("mij,ijk,...", x, y, z).shape, (1, 1, 1, 9, 5, None) + ) + + with self.assertRaises(ValueError): + x = KerasTensor((None, 3)) + y = KerasTensor((3, 4)) + knp.einsum("ijk,jk->ik", x, y) + + def test_full_like(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.full_like(x, KerasTensor((1, 3))).shape, (None, 3)) + + x = KerasTensor((None, 3, 3)) + self.assertEqual(knp.full_like(x, 2).shape, (None, 3, 3)) + + def test_gcd(self): + x = KerasTensor((None, 3)) + y = KerasTensor((2, None)) + self.assertEqual(knp.gcd(x, y).shape, (2, 3)) + + def test_greater(self): + x = KerasTensor((None, 3)) + y = KerasTensor((2, None)) + self.assertEqual(knp.greater(x, y).shape, (2, 3)) + + def test_greater_equal(self): + x = KerasTensor((None, 3)) + y = KerasTensor((2, None)) + self.assertEqual(knp.greater_equal(x, y).shape, (2, 3)) + + def test_isclose(self): + x = KerasTensor((None, 3)) + y = KerasTensor((2, None)) + self.assertEqual(knp.isclose(x, y).shape, (2, 3)) + + def test_isin(self): + x = KerasTensor((None, 3)) + y = KerasTensor((2, None)) + self.assertEqual(knp.isin(x, y).shape, (None, 3)) + + def test_kron(self): + x = KerasTensor((None, 3)) + y = KerasTensor((2, None)) + self.assertEqual(knp.kron(x, y).shape, (None, None)) + + def test_lcm(self): + x = KerasTensor((None, 3)) + y = KerasTensor((2, None)) + self.assertEqual(knp.lcm(x, y).shape, (2, 3)) + + def test_less(self): + x = KerasTensor((None, 3)) + y = KerasTensor((2, None)) + self.assertEqual(knp.less(x, y).shape, (2, 3)) + + def test_less_equal(self): + x = KerasTensor((None, 3)) + y = KerasTensor((2, None)) + self.assertEqual(knp.less_equal(x, y).shape, (2, 3)) + + def test_linspace(self): + start = KerasTensor((None, 3, 4)) + stop = KerasTensor((2, 3, 4)) + self.assertEqual( + knp.linspace(start, stop, 10, axis=1).shape, (2, 10, 3, 4) + ) + + start = KerasTensor((None, 3)) + stop = 2 + self.assertEqual( + knp.linspace(start, stop, 10, axis=1).shape, (None, 10, 3) + ) + + def test_logical_and(self): + x = KerasTensor((None, 3)) + y = KerasTensor((2, None)) + self.assertEqual(knp.logical_and(x, y).shape, (2, 3)) + + def test_logical_or(self): + x = KerasTensor((None, 3)) + y = KerasTensor((2, None)) + self.assertEqual(knp.logical_or(x, y).shape, (2, 3)) + + def test_logspace(self): + start = KerasTensor((None, 3, 4)) + stop = KerasTensor((2, 3, 4)) + self.assertEqual( + knp.logspace(start, stop, 10, axis=1).shape, (2, 10, 3, 4) + ) + + start = KerasTensor((None, 3)) + stop = 2 + self.assertEqual( + knp.logspace(start, stop, 10, axis=1).shape, (None, 10, 3) + ) + + def test_maximum(self): + x = KerasTensor((None, 3)) + y = KerasTensor((2, None)) + self.assertEqual(knp.maximum(x, y).shape, (2, 3)) + + def test_minimum(self): + x = KerasTensor((None, 3)) + y = KerasTensor((2, None)) + self.assertEqual(knp.minimum(x, y).shape, (2, 3)) + + def test_mod(self): + x = KerasTensor((None, 3)) + y = KerasTensor((2, None)) + self.assertEqual(knp.mod(x, y).shape, (2, 3)) + + def test_not_equal(self): + x = KerasTensor((None, 3)) + y = KerasTensor((2, None)) + self.assertEqual(knp.not_equal(x, y).shape, (2, 3)) + + def test_outer(self): + x = KerasTensor((None, 3)) + y = KerasTensor((2, None)) + self.assertEqual(knp.outer(x, y).shape, (None, None)) + + def test_quantile(self): + x = KerasTensor((None, 3)) + + # q as scalar + q = KerasTensor(()) + self.assertEqual(knp.quantile(x, q).shape, ()) + + # q as 1D tensor + q = KerasTensor((2,)) + self.assertEqual(knp.quantile(x, q).shape, (2,)) + self.assertEqual(knp.quantile(x, q, axis=1).shape, (2, None)) + self.assertEqual( + knp.quantile(x, q, axis=1, keepdims=True).shape, + (2, None, 1), + ) + + def test_searchsorted(self): + a = KerasTensor((None,)) + v = KerasTensor((2, 3)) + + output = knp.searchsorted(a, v) + self.assertEqual(output.shape, v.shape) + self.assertEqual(output.dtype, "int64") + + def test_take(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.take(x, 1).shape, ()) + self.assertEqual(knp.take(x, [1, 2]).shape, (2,)) + self.assertEqual( + knp.take(x, [[1, 2], [1, 2]], axis=1).shape, (None, 2, 2) + ) + + x = KerasTensor((None, 3, 3)) + self.assertEqual(knp.take(x, 1, axis=1).shape, (None, 3)) + self.assertEqual(knp.take(x, [1, 2]).shape, (2,)) + self.assertEqual( + knp.take(x, [[1, 2], [1, 2]], axis=1).shape, (None, 2, 2, 3) + ) + + # test with negative axis + self.assertEqual(knp.take(x, 1, axis=-2).shape, (None, 3)) + + # test with multi-dimensional indices + x = KerasTensor((None, 3, None, 5)) + indices = KerasTensor((6, 7)) + self.assertEqual(knp.take(x, indices, axis=2).shape, (None, 3, 6, 7, 5)) + + def test_take_along_axis(self): + x = KerasTensor((None, 3)) + indices = KerasTensor((1, 3)) + self.assertEqual(knp.take_along_axis(x, indices, axis=0).shape, (1, 3)) + self.assertEqual( + knp.take_along_axis(x, indices, axis=1).shape, (None, 3) + ) + + x = KerasTensor((None, 3, 3)) + indices = KerasTensor((1, 3, None)) + self.assertEqual( + knp.take_along_axis(x, indices, axis=1).shape, (None, 3, 3) + ) + + def test_tensordot(self): + x = KerasTensor((None, 3, 4)) + y = KerasTensor((3, 4)) + self.assertEqual(knp.tensordot(x, y, axes=1).shape, (None, 3, 4)) + self.assertEqual(knp.tensordot(x, y, axes=[[0, 1], [1, 0]]).shape, (4,)) + + def test_vdot(self): + x = KerasTensor((None, 3)) + y = KerasTensor((None, 3)) + self.assertEqual(knp.vdot(x, y).shape, ()) + + x = KerasTensor((None, 3, 3)) + y = KerasTensor((None, 3, 3)) + self.assertEqual(knp.vdot(x, y).shape, ()) + + def test_inner(self): + x = KerasTensor((None,)) + y = KerasTensor((3,)) + self.assertEqual(knp.inner(x, y).shape, ()) + + def test_where(self): + condition = KerasTensor((2, None, 1)) + x = KerasTensor((None, 1)) + y = KerasTensor((None, 3)) + self.assertEqual(knp.where(condition, x, y).shape, (2, None, 3)) + self.assertEqual(knp.where(condition).shape, (2, None, 1)) + + def test_floor_divide(self): + x = KerasTensor((None, 3)) + y = KerasTensor((2, None)) + self.assertEqual(knp.floor_divide(x, y).shape, (2, 3)) + + def test_xor(self): + x = KerasTensor((None, 3)) + y = KerasTensor((2, None)) + self.assertEqual(knp.logical_xor(x, y).shape, (2, 3)) + + def test_shape_equal_basic_equality(self): + x = KerasTensor((3, 4)).shape + y = KerasTensor((3, 4)).shape + self.assertTrue(knp.shape_equal(x, y)) + y = KerasTensor((3, 5)).shape + self.assertFalse(knp.shape_equal(x, y)) + + def test_shape_equal_allow_none(self): + x = KerasTensor((3, 4, None)).shape + y = KerasTensor((3, 4, 5)).shape + self.assertTrue(knp.shape_equal(x, y, allow_none=True)) + self.assertFalse(knp.shape_equal(x, y, allow_none=False)) + + def test_shape_equal_different_shape_lengths(self): + x = KerasTensor((3, 4)).shape + y = KerasTensor((3, 4, 5)).shape + self.assertFalse(knp.shape_equal(x, y)) + + def test_shape_equal_ignore_axes(self): + x = KerasTensor((3, 4, 5)).shape + y = KerasTensor((3, 6, 5)).shape + self.assertTrue(knp.shape_equal(x, y, axis=1)) + y = KerasTensor((3, 6, 7)).shape + self.assertTrue(knp.shape_equal(x, y, axis=(1, 2))) + self.assertFalse(knp.shape_equal(x, y, axis=1)) + + def test_shape_equal_only_none(self): + x = KerasTensor((None, None)).shape + y = KerasTensor((5, 6)).shape + self.assertTrue(knp.shape_equal(x, y, allow_none=True)) + + def test_shape_equal_axis_as_list(self): + x = KerasTensor((3, 4, 5)).shape + y = KerasTensor((3, 6, 5)).shape + self.assertTrue(knp.shape_equal(x, y, axis=[1])) + + def test_shape_non_equal_with_negative_axis(self): + x = KerasTensor((3, 4, 5)).shape + y = KerasTensor((3, 4, 6)).shape + self.assertFalse(knp.shape_equal(x, y, axis=-2)) + + def test_shape_equal_with_negative_axis(self): + x = KerasTensor((3, 4, 5)).shape + y = KerasTensor((3, 4, 5)).shape + self.assertTrue(knp.shape_equal(x, y, axis=-1)) + + def test_shape_equal_zeros(self): + x = KerasTensor((0, 4)).shape + y = KerasTensor((0, 4)).shape + self.assertTrue(knp.shape_equal(x, y)) + y = KerasTensor((0, 5)).shape + self.assertFalse(knp.shape_equal(x, y)) + + def test_broadcast_shapes_conversion_to_list(self): + shape1 = KerasTensor((1, 2)).shape + shape2 = KerasTensor((3, 1)).shape + expected_output = [3, 2] + self.assertEqual(knp.broadcast_shapes(shape1, shape2), expected_output) + + def test_broadcast_shapes_shape1_longer_than_shape2(self): + shape1 = KerasTensor((5, 3, 2)).shape + shape2 = KerasTensor((1, 3)).shape + with self.assertRaisesRegex(ValueError, "Cannot broadcast shape"): + knp.broadcast_shapes(shape1, shape2) + + def test_broadcast_shapes_shape2_longer_than_shape1(self): + shape1 = KerasTensor((5, 3)).shape + shape2 = KerasTensor((2, 5, 3)).shape + expected_output = [2, 5, 3] + self.assertEqual(knp.broadcast_shapes(shape1, shape2), expected_output) + + def test_broadcast_shapes_broadcasting_shape1_is_1(self): + shape1 = KerasTensor((1, 3)).shape + shape2 = KerasTensor((5, 1)).shape + expected_output = [5, 3] + self.assertEqual(knp.broadcast_shapes(shape1, shape2), expected_output) + + def test_broadcast_shapes_broadcasting_shape1_is_none(self): + shape1 = KerasTensor((None, 3)).shape + shape2 = KerasTensor((5, 1)).shape + expected_output = [5, 3] + self.assertEqual(knp.broadcast_shapes(shape1, shape2), expected_output) + + shape1 = KerasTensor((None, 3)).shape + shape2 = KerasTensor((5, 3)).shape + expected_output = [5, 3] + self.assertEqual(knp.broadcast_shapes(shape1, shape2), expected_output) + + def test_broadcast_shapes_broadcasting_shape2_conditions(self): + shape1 = KerasTensor((5, 3, 2)).shape + shape2 = KerasTensor((1, 3, 2)).shape + expected_output = [5, 3, 2] + self.assertEqual(knp.broadcast_shapes(shape1, shape2), expected_output) + + shape1 = KerasTensor((5, 3, 2)).shape + shape2 = KerasTensor((1, None, 2)).shape + expected_output = [5, 3, 2] + self.assertEqual(knp.broadcast_shapes(shape1, shape2), expected_output) + + +class NumpyTwoInputOpsStaticShapeTest(testing.TestCase): + def test_add(self): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3)) + self.assertEqual(knp.add(x, y).shape, (2, 3)) + + with self.assertRaises(ValueError): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3, 4)) + knp.add(x, y) + + def test_heaviside(self): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3)) + self.assertEqual(knp.heaviside(x, y).shape, (2, 3)) + + x = KerasTensor((2, 3)) + y = KerasTensor((3,)) + self.assertEqual(knp.heaviside(x, y).shape, (2, 3)) + + x = KerasTensor((2, 3)) + y = KerasTensor((1, 3)) + self.assertEqual(knp.heaviside(x, y).shape, (2, 3)) + + def test_hypot(self): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3)) + self.assertEqual(knp.hypot(x, y).shape, (2, 3)) + + def test_subtract(self): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3)) + self.assertEqual(knp.subtract(x, y).shape, (2, 3)) + + x = KerasTensor((2, 3)) + self.assertEqual(knp.subtract(x, 2).shape, (2, 3)) + + with self.assertRaises(ValueError): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3, 4)) + knp.subtract(x, y) + + def test_multiply(self): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3)) + self.assertEqual(knp.multiply(x, y).shape, (2, 3)) + + x = KerasTensor((2, 3)) + self.assertEqual(knp.multiply(x, 2).shape, (2, 3)) + + with self.assertRaises(ValueError): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3, 4)) + knp.multiply(x, y) + + def test_matmul(self): + x = KerasTensor((2, 3)) + y = KerasTensor((3, 2)) + self.assertEqual(knp.matmul(x, y).shape, (2, 2)) + + with self.assertRaises(ValueError): + x = KerasTensor((3, 4)) + y = KerasTensor((2, 3, 4)) + knp.matmul(x, y) + + @pytest.mark.skipif(testing.tensorflow_uses_gpu(), reason="Segfault") + def test_matmul_sparse(self): + x = KerasTensor((2, 3), sparse=True) + y = KerasTensor((3, 2)) + result = knp.matmul(x, y) + self.assertEqual(result.shape, (2, 2)) + + x = KerasTensor((2, 3)) + y = KerasTensor((3, 2), sparse=True) + result = knp.matmul(x, y) + self.assertEqual(result.shape, (2, 2)) + + x = KerasTensor((2, 3), sparse=True) + y = KerasTensor((3, 2), sparse=True) + result = knp.matmul(x, y) + self.assertEqual(result.shape, (2, 2)) + self.assertTrue(result.sparse) + + def test_power(self): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3)) + self.assertEqual(knp.power(x, y).shape, (2, 3)) + + x = KerasTensor((2, 3)) + self.assertEqual(knp.power(x, 2).shape, (2, 3)) + + with self.assertRaises(ValueError): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3, 4)) + knp.power(x, y) + + def test_divide(self): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3)) + self.assertEqual(knp.divide(x, y).shape, (2, 3)) + + x = KerasTensor((2, 3)) + self.assertEqual(knp.divide(x, 2).shape, (2, 3)) + + with self.assertRaises(ValueError): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3, 4)) + knp.divide(x, y) + + def test_divide_no_nan(self): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3)) + self.assertEqual(knp.divide_no_nan(x, y).shape, (2, 3)) + + x = KerasTensor((2, 3)) + self.assertEqual(knp.divide_no_nan(x, 2).shape, (2, 3)) + + with self.assertRaises(ValueError): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3, 4)) + knp.divide_no_nan(x, y) + + def test_true_divide(self): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3)) + self.assertEqual(knp.true_divide(x, y).shape, (2, 3)) + + x = KerasTensor((2, 3)) + self.assertEqual(knp.true_divide(x, 2).shape, (2, 3)) + + with self.assertRaises(ValueError): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3, 4)) + knp.true_divide(x, y) + + def test_append(self): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3)) + self.assertEqual(knp.append(x, y).shape, (12,)) + + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3)) + self.assertEqual(knp.append(x, y, axis=0).shape, (4, 3)) + + with self.assertRaises(ValueError): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3, 4)) + knp.append(x, y, axis=2) + + def test_arctan2(self): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3)) + self.assertEqual(knp.arctan2(x, y).shape, (2, 3)) + + x = KerasTensor((2, 3)) + self.assertEqual(knp.arctan2(x, 2).shape, (2, 3)) + + with self.assertRaises(ValueError): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3, 4)) + knp.arctan2(x, y) + + def test_bitwise_and(self): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3)) + self.assertEqual(knp.bitwise_and(x, y).shape, (2, 3)) + + def test_bitwise_or(self): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3)) + self.assertEqual(knp.bitwise_or(x, y).shape, (2, 3)) + + def test_bitwise_xor(self): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3)) + self.assertEqual(knp.bitwise_xor(x, y).shape, (2, 3)) + + def test_bitwise_left_shift(self): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3)) + self.assertEqual(knp.bitwise_left_shift(x, y).shape, (2, 3)) + + # left_shift is same as bitwise_left_shift + + def test_bitwise_right_shift(self): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3)) + self.assertEqual(knp.bitwise_right_shift(x, y).shape, (2, 3)) + + # right_shift is same as bitwise_right_shift + + def test_cross(self): + x1 = KerasTensor((2, 3, 3)) + x2 = KerasTensor((1, 3, 2)) + y1 = KerasTensor((2, 3, 3)) + y2 = KerasTensor((2, 3, 2)) + self.assertEqual(knp.cross(x1, y1).shape, (2, 3, 3)) + self.assertEqual(knp.cross(x2, y2).shape, (2, 3)) + + with self.assertRaises(ValueError): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3, 4)) + knp.cross(x, y) + + with self.assertRaises(ValueError): + x = KerasTensor((4, 3, 3)) + y = KerasTensor((2, 3, 3)) + knp.cross(x, y) + + def test_einsum(self): + x = KerasTensor((2, 3)) + y = KerasTensor((3, 4)) + self.assertEqual(knp.einsum("ij,jk->ik", x, y).shape, (2, 4)) + self.assertEqual(knp.einsum("ij,jk->ikj", x, y).shape, (2, 4, 3)) + self.assertEqual(knp.einsum("ii", x).shape, ()) + self.assertEqual(knp.einsum(",ij", 5, x).shape, (2, 3)) + + x = KerasTensor((2, 3, 4)) + y = KerasTensor((3, 4, 5)) + z = KerasTensor((1, 1, 1, 9)) + self.assertEqual(knp.einsum("ijk,jkl->li", x, y).shape, (5, 2)) + self.assertEqual(knp.einsum("ijk,jkl->lij", x, y).shape, (5, 2, 3)) + self.assertEqual(knp.einsum("...,...j->...j", x, y).shape, (2, 3, 4, 5)) + self.assertEqual( + knp.einsum("i...,...j->i...j", x, y).shape, (2, 3, 4, 5) + ) + self.assertEqual(knp.einsum("i...,...j", x, y).shape, (3, 4, 2, 5)) + self.assertEqual(knp.einsum("i...,...j", x, y).shape, (3, 4, 2, 5)) + self.assertEqual( + knp.einsum("i...,...j,...k", x, y, z).shape, (1, 3, 4, 2, 5, 9) + ) + self.assertEqual( + knp.einsum("mij,ijk,...", x, y, z).shape, (1, 1, 1, 9, 5, 2) + ) + + with self.assertRaises(ValueError): + x = KerasTensor((2, 3)) + y = KerasTensor((3, 4)) + knp.einsum("ijk,jk->ik", x, y) + + def test_full_like(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.full_like(x, 2).shape, (2, 3)) + + def test_gcd(self): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3)) + self.assertEqual(knp.gcd(x, y).shape, (2, 3)) + + x = KerasTensor((2, 3)) + self.assertEqual(knp.gcd(x, 2).shape, (2, 3)) + + with self.assertRaises(ValueError): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3, 4)) + knp.gcd(x, y) + + def test_greater(self): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3)) + self.assertEqual(knp.greater(x, y).shape, (2, 3)) + + x = KerasTensor((2, 3)) + self.assertEqual(knp.greater(x, 2).shape, (2, 3)) + + with self.assertRaises(ValueError): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3, 4)) + knp.greater(x, y) + + def test_greater_equal(self): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3)) + self.assertEqual(knp.greater_equal(x, y).shape, (2, 3)) + + x = KerasTensor((2, 3)) + self.assertEqual(knp.greater_equal(x, 2).shape, (2, 3)) + + with self.assertRaises(ValueError): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3, 4)) + knp.greater_equal(x, y) + + def test_isclose(self): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3)) + self.assertEqual(knp.isclose(x, y).shape, (2, 3)) + + x = KerasTensor((2, 3)) + self.assertEqual(knp.isclose(x, 2).shape, (2, 3)) + + with self.assertRaises(ValueError): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3, 4)) + knp.isclose(x, y) + + def test_isin(self): + x = KerasTensor((2, 3)) + y = KerasTensor((3, 3)) + self.assertEqual(knp.isin(x, y).shape, (2, 3)) + + x = KerasTensor((2, 3)) + self.assertEqual(knp.isin(x, 2).shape, (2, 3)) + + def test_kron(self): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3)) + self.assertEqual(knp.kron(x, y).shape, (4, 9)) + + def test_lcm(self): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3)) + self.assertEqual(knp.lcm(x, y).shape, (2, 3)) + + def test_less(self): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3)) + self.assertEqual(knp.less(x, y).shape, (2, 3)) + + x = KerasTensor((2, 3)) + self.assertEqual(knp.less(x, 2).shape, (2, 3)) + + with self.assertRaises(ValueError): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3, 4)) + knp.less(x, y) + + def test_less_equal(self): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3)) + self.assertEqual(knp.less_equal(x, y).shape, (2, 3)) + + x = KerasTensor((2, 3)) + self.assertEqual(knp.less_equal(x, 2).shape, (2, 3)) + + with self.assertRaises(ValueError): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3, 4)) + knp.less_equal(x, y) + + def test_linspace(self): + start = KerasTensor((2, 3, 4)) + stop = KerasTensor((2, 3, 4)) + self.assertEqual(knp.linspace(start, stop, 10).shape, (10, 2, 3, 4)) + + with self.assertRaises(ValueError): + start = KerasTensor((2, 3)) + stop = KerasTensor((2, 3, 4)) + knp.linspace(start, stop) + + def test_logical_and(self): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3)) + self.assertEqual(knp.logical_and(x, y).shape, (2, 3)) + + x = KerasTensor((2, 3)) + self.assertEqual(knp.logical_and(x, 2).shape, (2, 3)) + + with self.assertRaises(ValueError): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3, 4)) + knp.logical_and(x, y) + + def test_logical_or(self): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3)) + self.assertEqual(knp.logical_or(x, y).shape, (2, 3)) + + x = KerasTensor((2, 3)) + self.assertEqual(knp.logical_or(x, 2).shape, (2, 3)) + + with self.assertRaises(ValueError): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3, 4)) + knp.logical_or(x, y) + + def test_logspace(self): + start = KerasTensor((2, 3, 4)) + stop = KerasTensor((2, 3, 4)) + self.assertEqual(knp.logspace(start, stop, 10).shape, (10, 2, 3, 4)) + + with self.assertRaises(ValueError): + start = KerasTensor((2, 3)) + stop = KerasTensor((2, 3, 4)) + knp.logspace(start, stop) + + def test_maximum(self): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3)) + self.assertEqual(knp.maximum(x, y).shape, (2, 3)) + + x = KerasTensor((2, 3)) + self.assertEqual(knp.maximum(x, 2).shape, (2, 3)) + + with self.assertRaises(ValueError): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3, 4)) + knp.maximum(x, y) + + def test_minimum(self): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3)) + self.assertEqual(knp.minimum(x, y).shape, (2, 3)) + + x = KerasTensor((2, 3)) + self.assertEqual(knp.minimum(x, 2).shape, (2, 3)) + + with self.assertRaises(ValueError): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3, 4)) + knp.minimum(x, y) + + def test_mod(self): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3)) + self.assertEqual(knp.mod(x, y).shape, (2, 3)) + + x = KerasTensor((2, 3)) + self.assertEqual(knp.mod(x, 2).shape, (2, 3)) + + with self.assertRaises(ValueError): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3, 4)) + knp.mod(x, y) + + def test_not_equal(self): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3)) + self.assertEqual(knp.not_equal(x, y).shape, (2, 3)) + + x = KerasTensor((2, 3)) + self.assertEqual(knp.not_equal(x, 2).shape, (2, 3)) + + with self.assertRaises(ValueError): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3, 4)) + knp.not_equal(x, y) + + def test_outer(self): + x = KerasTensor((3,)) + y = KerasTensor((4,)) + self.assertEqual(knp.outer(x, y).shape, (3, 4)) + + x = KerasTensor((2, 3)) + y = KerasTensor((4, 5)) + self.assertEqual(knp.outer(x, y).shape, (6, 20)) + + x = KerasTensor((2, 3)) + self.assertEqual(knp.outer(x, 2).shape, (6, 1)) + + def test_quantile(self): + x = KerasTensor((3, 3)) + + # q as scalar + q = KerasTensor(()) + self.assertEqual(knp.quantile(x, q).shape, ()) + + # q as 1D tensor + q = KerasTensor((2,)) + self.assertEqual(knp.quantile(x, q).shape, (2,)) + self.assertEqual(knp.quantile(x, q, axis=1).shape, (2, 3)) + self.assertEqual( + knp.quantile(x, q, axis=1, keepdims=True).shape, + (2, 3, 1), + ) + + def test_searchsorted(self): + a = KerasTensor((3,)) + v = KerasTensor((2, 3)) + + self.assertEqual(knp.searchsorted(a, v).shape, v.shape) + + def test_take(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.take(x, 1).shape, ()) + self.assertEqual(knp.take(x, [1, 2]).shape, (2,)) + self.assertEqual(knp.take(x, [[1, 2], [1, 2]], axis=1).shape, (2, 2, 2)) + + # test with multi-dimensional indices + x = KerasTensor((2, 3, 4, 5)) + indices = KerasTensor((6, 7)) + self.assertEqual(knp.take(x, indices, axis=2).shape, (2, 3, 6, 7, 5)) + + def test_take_along_axis(self): + x = KerasTensor((2, 3)) + indices = KerasTensor((1, 3)) + self.assertEqual(knp.take_along_axis(x, indices, axis=0).shape, (1, 3)) + self.assertEqual(knp.take_along_axis(x, indices, axis=1).shape, (2, 3)) + + with self.assertRaises(ValueError): + x = KerasTensor((2, 3)) + indices = KerasTensor((1, 4)) + knp.take_along_axis(x, indices, axis=0) + + def test_tensordot(self): + x = KerasTensor((2, 3, 3)) + y = KerasTensor((3, 3, 4)) + self.assertEqual(knp.tensordot(x, y, axes=1).shape, (2, 3, 3, 4)) + self.assertEqual(knp.tensordot(x, y, axes=2).shape, (2, 4)) + self.assertEqual( + knp.tensordot(x, y, axes=[[1, 2], [0, 1]]).shape, (2, 4) + ) + + def test_vdot(self): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3)) + self.assertEqual(knp.vdot(x, y).shape, ()) + + def test_inner(self): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3)) + self.assertEqual(knp.inner(x, y).shape, ()) + + def test_where(self): + condition = KerasTensor((2, 3)) + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3)) + self.assertEqual(knp.where(condition, x, y).shape, (2, 3)) + self.assertAllEqual(knp.where(condition).shape, (2, 3)) + + def test_floor_divide(self): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3)) + self.assertEqual(knp.floor_divide(x, y).shape, (2, 3)) + + with self.assertRaises(ValueError): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3, 4)) + knp.floor_divide(x, y) + + def test_xor(self): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3)) + self.assertEqual(knp.logical_xor(x, y).shape, (2, 3)) + + with self.assertRaises(ValueError): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3, 4)) + knp.logical_xor(x, y) + + def test_digitize(self): + x = KerasTensor((2, 3)) + bins = KerasTensor((3,)) + self.assertEqual(knp.digitize(x, bins).shape, (2, 3)) + self.assertTrue(knp.digitize(x, bins).dtype == "int32") + + with self.assertRaises(ValueError): + x = KerasTensor((2, 3)) + bins = KerasTensor((2, 3, 4)) + knp.digitize(x, bins) + + def test_correlate_mode_valid(self): + x = KerasTensor((3,)) + y = KerasTensor((3,)) + self.assertEqual(knp.correlate(x, y).shape, (1,)) + self.assertTrue(knp.correlate(x, y).dtype == "float32") + + with self.assertRaises(ValueError): + x = KerasTensor((3,)) + y = KerasTensor((3, 4)) + knp.correlate(x, y) + + def test_correlate_mode_same(self): + x = KerasTensor((3,)) + y = KerasTensor((3,)) + self.assertEqual(knp.correlate(x, y, mode="same").shape, (3,)) + self.assertTrue(knp.correlate(x, y, mode="same").dtype == "float32") + + with self.assertRaises(ValueError): + x = KerasTensor((3,)) + y = KerasTensor((3, 4)) + knp.correlate(x, y, mode="same") + + def test_correlate_mode_full(self): + x = KerasTensor((3,)) + y = KerasTensor((3,)) + self.assertEqual(knp.correlate(x, y, mode="full").shape, (5,)) + self.assertTrue(knp.correlate(x, y, mode="full").dtype == "float32") + + with self.assertRaises(ValueError): + x = KerasTensor((3)) + y = KerasTensor((3, 4)) + knp.correlate(x, y, mode="full") + + +class NumpyOneInputOpsDynamicShapeTest(testing.TestCase): + def test_mean(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.mean(x).shape, ()) + + x = KerasTensor((None, 3, 3)) + self.assertEqual(knp.mean(x, axis=1).shape, (None, 3)) + self.assertEqual(knp.mean(x, axis=1, keepdims=True).shape, (None, 1, 3)) + + def test_all(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.all(x).shape, ()) + + x = KerasTensor((None, 3, 3)) + self.assertEqual(knp.all(x, axis=1).shape, (None, 3)) + self.assertEqual(knp.all(x, axis=1, keepdims=True).shape, (None, 1, 3)) + + def test_any(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.any(x).shape, ()) + + x = KerasTensor((None, 3, 3)) + self.assertEqual(knp.any(x, axis=1).shape, (None, 3)) + self.assertEqual(knp.any(x, axis=1, keepdims=True).shape, (None, 1, 3)) + + def test_var(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.var(x).shape, ()) + + x = KerasTensor((None, 3, 3)) + self.assertEqual(knp.var(x, axis=1).shape, (None, 3)) + self.assertEqual(knp.var(x, axis=1, keepdims=True).shape, (None, 1, 3)) + + def test_sum(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.sum(x).shape, ()) + + x = KerasTensor((None, 3, 3)) + self.assertEqual(knp.sum(x, axis=1).shape, (None, 3)) + self.assertEqual(knp.sum(x, axis=1, keepdims=True).shape, (None, 1, 3)) + + def test_amax(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.amax(x).shape, ()) + + x = KerasTensor((None, 3, 3)) + self.assertEqual(knp.amax(x, axis=1).shape, (None, 3)) + self.assertEqual(knp.amax(x, axis=1, keepdims=True).shape, (None, 1, 3)) + + def test_amin(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.amin(x).shape, ()) + + x = KerasTensor((None, 3, 3)) + self.assertEqual(knp.amin(x, axis=1).shape, (None, 3)) + self.assertEqual(knp.amin(x, axis=1, keepdims=True).shape, (None, 1, 3)) + + def test_square(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.square(x).shape, (None, 3)) + + def test_negative(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.negative(x).shape, (None, 3)) + + def test_abs(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.abs(x).shape, (None, 3)) + + def test_absolute(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.absolute(x).shape, (None, 3)) + + def test_squeeze(self): + x = KerasTensor((None, 1)) + self.assertEqual(knp.squeeze(x).shape, (None,)) + self.assertEqual(knp.squeeze(x, axis=1).shape, (None,)) + + with self.assertRaises(ValueError): + x = KerasTensor((None, 1)) + knp.squeeze(x, axis=0) + + # Multiple axes + x = KerasTensor((None, 1, 1, 1)) + self.assertEqual(knp.squeeze(x, (1, 2)).shape, (None, 1)) + self.assertEqual(knp.squeeze(x, (-1, -2)).shape, (None, 1)) + self.assertEqual(knp.squeeze(x, (1, 2, 3)).shape, (None,)) + self.assertEqual(knp.squeeze(x, (-1, 1)).shape, (None, 1)) + + def test_transpose(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.transpose(x).shape, (3, None)) + + x = KerasTensor((None, 3, 3)) + self.assertEqual(knp.transpose(x, (2, 0, 1)).shape, (3, None, 3)) + + def test_arccos(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.arccos(x).shape, (None, 3)) + + def test_arccosh(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.arccosh(x).shape, (None, 3)) + + def test_arcsin(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.arcsin(x).shape, (None, 3)) + + def test_arcsinh(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.arcsinh(x).shape, (None, 3)) + + def test_arctan(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.arctan(x).shape, (None, 3)) + + def test_arctanh(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.arctanh(x).shape, (None, 3)) + + def test_argmax(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.argmax(x).shape, ()) + self.assertEqual(knp.argmax(x, keepdims=True).shape, (None, 3)) + + x = KerasTensor((None, 3, 3)) + self.assertEqual(knp.argmax(x, axis=1).shape, (None, 3)) + self.assertEqual(knp.argmax(x, keepdims=True).shape, (None, 3, 3)) + + @pytest.mark.skipif( + keras.config.backend() == "openvino", + reason="OpenVINO doesn't support this change", + ) + def test_argmax_negative_zero(self): + input_data = np.array( + [-1.0, -0.0, 1.401298464324817e-45], dtype=np.float32 + ) + self.assertEqual(knp.argmax(input_data), 2) + + @pytest.mark.skipif( + keras.config.backend() == "openvino" + or keras.config.backend() == "tensorflow", + reason=""" + OpenVINO and TensorFlow don't support this + change, TensorFlow behavior for this case is under + evaluation and may change within this PR + """, + ) + def test_argmin_negative_zero(self): + input_data = np.array( + [ + 0.0, + 1.1754943508222875e-38, + -1.401298464324817e-45, + 0.0, + 459367.0, + ], + dtype=np.float32, + ) + self.assertEqual(knp.argmin(input_data), 2) + + def test_argmin(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.argmin(x).shape, ()) + self.assertEqual(knp.argmin(x, keepdims=True).shape, (None, 3)) + + x = KerasTensor((None, 3, 3)) + self.assertEqual(knp.argmin(x, axis=1).shape, (None, 3)) + self.assertEqual(knp.argmin(x, keepdims=True).shape, (None, 3, 3)) + + def test_argsort(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.argsort(x).shape, (None, 3)) + + x = KerasTensor((None, 3, 3)) + self.assertEqual(knp.argsort(x, axis=1).shape, (None, 3, 3)) + + def test_array(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.array(x).shape, (None, 3)) + + def test_average(self): + x = KerasTensor((None, 3)) + weights = KerasTensor((None, 3)) + self.assertEqual(knp.average(x, weights=weights).shape, ()) + + x = KerasTensor((None, 3)) + weights = KerasTensor((3,)) + self.assertEqual(knp.average(x, axis=1, weights=weights).shape, (None,)) + + x = KerasTensor((None, 3, 3)) + self.assertEqual(knp.average(x, axis=1).shape, (None, 3)) + + with self.assertRaises(ValueError): + x = KerasTensor((None, 3, 3)) + weights = KerasTensor((None, 4)) + knp.average(x, weights=weights) + + def test_bartlett(self): + x = np.random.randint(1, 100 + 1) + self.assertEqual(knp.bartlett(x).shape[0], x) + + def test_blackman(self): + x = np.random.randint(1, 100 + 1) + self.assertEqual(knp.blackman(x).shape[0], x) + + def test_hamming(self): + x = np.random.randint(1, 100 + 1) + self.assertEqual(knp.hamming(x).shape[0], x) + + def test_hanning(self): + x = np.random.randint(1, 100 + 1) + self.assertEqual(knp.hanning(x).shape[0], x) + + def test_kaiser(self): + x = np.random.randint(1, 100 + 1) + beta = float(np.random.randint(10, 20 + 1)) + self.assertEqual(knp.kaiser(x, beta).shape[0], x) + + def test_bitwise_invert(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.bitwise_invert(x).shape, (None, 3)) + + # bitwise_not is same as bitwise_invert + + def test_broadcast_to(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.broadcast_to(x, (2, 3, 3)).shape, (2, 3, 3)) + + with self.assertRaises(ValueError): + x = KerasTensor((3, 3)) + knp.broadcast_to(x, (2, 2, 3)) + + def test_cbrt(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.cbrt(x).shape, (None, 3)) + + def test_ceil(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.ceil(x).shape, (None, 3)) + + def test_clip(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.clip(x, 1, 2).shape, (None, 3)) + + def test_concatenate(self): + x = KerasTensor((None, 3)) + y = KerasTensor((None, 3)) + self.assertEqual( + knp.concatenate( + [x, y], + ).shape, + (None, 3), + ) + self.assertEqual(knp.concatenate([x, y], axis=1).shape, (None, 6)) + + with self.assertRaises(ValueError): + self.assertEqual(knp.concatenate([x, y], axis=None).shape, (None,)) + + with self.assertRaises(ValueError): + x = KerasTensor((None, 3, 5)) + y = KerasTensor((None, 4, 6)) + knp.concatenate([x, y], axis=1) + + def test_concatenate_sparse(self): + x = KerasTensor((2, 3), sparse=True) + y = KerasTensor((2, 3)) + result = knp.concatenate([x, y], axis=1) + self.assertEqual(result.shape, (2, 6)) + self.assertFalse(result.sparse) + + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3), sparse=True) + result = knp.concatenate([x, y], axis=1) + self.assertEqual(result.shape, (2, 6)) + self.assertFalse(result.sparse) + + x = KerasTensor((2, 3), sparse=True) + y = KerasTensor((2, 3), sparse=True) + result = knp.concatenate([x, y], axis=1) + self.assertEqual(result.shape, (2, 6)) + self.assertTrue(result.sparse) + + def test_conjugate(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.conjugate(x).shape, (None, 3)) + + def test_conj(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.conj(x).shape, (None, 3)) + + def test_copy(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.copy(x).shape, (None, 3)) + + def test_corrcoef(self): + x = KerasTensor((3, None)) + self.assertEqual(knp.corrcoef(x).shape, (3, None)) + + def test_cos(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.cos(x).shape, (None, 3)) + + def test_cosh(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.cosh(x).shape, (None, 3)) + + def test_count_nonzero(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.count_nonzero(x).shape, ()) + + x = KerasTensor((None, 3, 3)) + self.assertEqual(knp.count_nonzero(x, axis=1).shape, (None, 3)) + + def test_cumprod(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.cumprod(x).shape, (None,)) + + x = KerasTensor((None, 3, 3)) + self.assertEqual(knp.cumprod(x, axis=1).shape, (None, 3, 3)) + + def test_cumsum(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.cumsum(x).shape, (None,)) + + x = KerasTensor((None, 3, 3)) + self.assertEqual(knp.cumsum(x, axis=1).shape, (None, 3, 3)) + + def test_deg2rad(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.deg2rad(x).shape, (None, 3)) + + def test_diag(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.diag(x).shape, (None,)) + self.assertEqual(knp.diag(x, k=3).shape, (None,)) + + with self.assertRaises(ValueError): + x = KerasTensor((2, 3, 4)) + knp.diag(x) + + def test_diagflat(self): + x = KerasTensor((3,)) + self.assertEqual(knp.diagflat(x).shape, (3, 3)) + self.assertEqual(knp.diagflat(x, k=1).shape, (4, 4)) + self.assertEqual(knp.diagflat(x, k=-1).shape, (4, 4)) + + x = KerasTensor((2, 3)) + self.assertEqual(knp.diagflat(x).shape, (6, 6)) + self.assertEqual(knp.diagflat(x, k=2).shape, (8, 8)) + + x = KerasTensor((None, 3)) + self.assertEqual(knp.diagflat(x).shape, (None, None)) + + def test_diagonal(self): + x = KerasTensor((None, 3, 3)) + self.assertEqual(knp.diagonal(x).shape, (3, None)) + + def test_diff(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.diff(x).shape, (None, 2)) + self.assertEqual(knp.diff(x, n=2).shape, (None, 1)) + self.assertEqual(knp.diff(x, n=3).shape, (None, 0)) + self.assertEqual(knp.diff(x, n=4).shape, (None, 0)) + + self.assertEqual(knp.diff(x, axis=0).shape, (None, 3)) + self.assertEqual(knp.diff(x, n=2, axis=0).shape, (None, 3)) + + def test_dot(self): + x = KerasTensor((None, 3)) + y = KerasTensor((3, 2)) + z = KerasTensor((None, None, 2)) + self.assertEqual(knp.dot(x, y).shape, (None, 2)) + self.assertEqual(knp.dot(x, 2).shape, (None, 3)) + self.assertEqual(knp.dot(x, z).shape, (None, None, 2)) + + x = KerasTensor((None,)) + y = KerasTensor((5,)) + self.assertEqual(knp.dot(x, y).shape, ()) + + def test_exp(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.exp(x).shape, (None, 3)) + + def test_exp2(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.exp2(x).shape, (None, 3)) + + def test_expand_dims(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.expand_dims(x, -1).shape, (None, 3, 1)) + self.assertEqual(knp.expand_dims(x, 0).shape, (1, None, 3)) + self.assertEqual(knp.expand_dims(x, 1).shape, (None, 1, 3)) + self.assertEqual(knp.expand_dims(x, -2).shape, (None, 1, 3)) + + # Multiple axes + self.assertEqual(knp.expand_dims(x, (1, 2)).shape, (None, 1, 1, 3)) + self.assertEqual(knp.expand_dims(x, (-1, -2)).shape, (None, 3, 1, 1)) + self.assertEqual(knp.expand_dims(x, (-1, 1)).shape, (None, 1, 3, 1)) + + def test_expm1(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.expm1(x).shape, (None, 3)) + + def test_flip(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.flip(x).shape, (None, 3)) + + def test_floor(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.floor(x).shape, (None, 3)) + + def test_get_item(self): + x = KerasTensor((None, 5, 16)) + # Simple slice. + sliced = knp.get_item(x, 5) + self.assertEqual(sliced.shape, (5, 16)) + # Ellipsis slice. + sliced = knp.get_item(x, np.s_[..., -1]) + self.assertEqual(sliced.shape, (None, 5)) + # `newaxis` slice. + sliced = knp.get_item(x, np.s_[:, np.newaxis, ...]) + self.assertEqual(sliced.shape, (None, 1, 5, 16)) + # Strided slice. + sliced = knp.get_item(x, np.s_[:5, 3:, 3:12:2]) + self.assertEqual(sliced.shape, (None, 2, 5)) + # Error states. + with self.assertRaises(ValueError): + sliced = knp.get_item(x, np.s_[:, 17, :]) + with self.assertRaises(ValueError): + sliced = knp.get_item(x, np.s_[..., 5, ...]) + with self.assertRaises(ValueError): + sliced = knp.get_item(x, np.s_[:, :, :, :]) + + def test_hstack(self): + x = KerasTensor((None, 3)) + y = KerasTensor((None, 3)) + self.assertEqual(knp.hstack([x, y]).shape, (None, 6)) + + x = KerasTensor((None, 3)) + y = KerasTensor((None, None)) + self.assertEqual(knp.hstack([x, y]).shape, (None, None)) + + def test_imag(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.imag(x).shape, (None, 3)) + + def test_isfinite(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.isfinite(x).shape, (None, 3)) + + def test_isinf(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.isinf(x).shape, (None, 3)) + + def test_isnan(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.isnan(x).shape, (None, 3)) + + def test_isneginf(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.isneginf(x).shape, (None, 3)) + + def test_isposinf(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.isposinf(x).shape, (None, 3)) + + def test_log(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.log(x).shape, (None, 3)) + + def test_log10(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.log10(x).shape, (None, 3)) + + def test_log1p(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.log1p(x).shape, (None, 3)) + + def test_log2(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.log2(x).shape, (None, 3)) + + def test_logaddexp(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.logaddexp(x, x).shape, (None, 3)) + + def test_logaddexp2(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.logaddexp2(x, x).shape, (None, 3)) + + def test_logical_not(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.logical_not(x).shape, (None, 3)) + + def test_max(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.max(x).shape, ()) + + def test_median(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.median(x).shape, ()) + + x = KerasTensor((None, 3, 3)) + self.assertEqual(knp.median(x, axis=1).shape, (None, 3)) + self.assertEqual( + knp.median(x, axis=1, keepdims=True).shape, (None, 1, 3) + ) + + def test_meshgrid(self): + x = KerasTensor((None, 3)) + y = KerasTensor((None, 3)) + self.assertEqual(knp.meshgrid(x, y)[0].shape, (None, None)) + self.assertEqual(knp.meshgrid(x, y)[1].shape, (None, None)) + + with self.assertRaises(ValueError): + knp.meshgrid(x, y, indexing="kk") + + def test_moveaxis(self): + x = KerasTensor((None, 3, 4, 5)) + self.assertEqual(knp.moveaxis(x, 0, -1).shape, (3, 4, 5, None)) + self.assertEqual(knp.moveaxis(x, -1, 0).shape, (5, None, 3, 4)) + self.assertEqual( + knp.moveaxis(x, [0, 1], [-1, -2]).shape, (4, 5, 3, None) + ) + self.assertEqual(knp.moveaxis(x, [0, 1], [1, 0]).shape, (3, None, 4, 5)) + self.assertEqual( + knp.moveaxis(x, [0, 1], [-2, -1]).shape, (4, 5, None, 3) + ) + + def test_ndim(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.ndim(x).shape, (2,)) + + def test_nonzero(self): + x = KerasTensor((None, 5, 6)) + result = knp.nonzero(x) + self.assertLen(result, 3) + self.assertEqual(result[0].shape, (None,)) + self.assertEqual(result[1].shape, (None,)) + self.assertEqual(result[2].shape, (None,)) + + def test_ones_like(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.ones_like(x).shape, (None, 3)) + self.assertEqual(knp.ones_like(x).dtype, x.dtype) + + def test_zeros_like(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.zeros_like(x).shape, (None, 3)) + self.assertEqual(knp.zeros_like(x).dtype, x.dtype) + + def test_pad(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.pad(x, 1).shape, (None, 5)) + self.assertEqual(knp.pad(x, (1, 2)).shape, (None, 6)) + self.assertEqual(knp.pad(x, ((1, 2), (3, 4))).shape, (None, 10)) + + x = KerasTensor((None, 3, 3)) + self.assertEqual(knp.pad(x, 1).shape, (None, 5, 5)) + self.assertEqual(knp.pad(x, (1, 2)).shape, (None, 6, 6)) + self.assertEqual( + knp.pad(x, ((1, 2), (3, 4), (5, 6))).shape, (None, 10, 14) + ) + + def test_prod(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.prod(x).shape, ()) + self.assertEqual(knp.prod(x, axis=0).shape, (3,)) + self.assertEqual(knp.prod(x, axis=1, keepdims=True).shape, (None, 1)) + + def test_ravel(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.ravel(x).shape, (None,)) + + def test_unravel_index(self): + x = KerasTensor((None,)) + indices = knp.unravel_index(x, (2, 3)) + self.assertEqual(len(indices), 2) + self.assertEqual(indices[0].shape, (None,)) + self.assertEqual(indices[1].shape, (None,)) + + x = KerasTensor((None, 4)) + indices = knp.unravel_index(x, (3, 4)) + self.assertEqual(len(indices), 2) + self.assertEqual(indices[0].shape, (None, 4)) + self.assertEqual(indices[1].shape, (None, 4)) + + x = KerasTensor((None, 3, 2)) + indices = knp.unravel_index(x, (5, 6, 4)) + self.assertEqual(len(indices), 3) + self.assertEqual(indices[0].shape, (None, 3, 2)) + self.assertEqual(indices[1].shape, (None, 3, 2)) + self.assertEqual(indices[2].shape, (None, 3, 2)) + + def test_real(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.real(x).shape, (None, 3)) + + def test_reciprocal(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.reciprocal(x).shape, (None, 3)) + + def test_repeat(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.repeat(x, 2).shape, (None,)) + self.assertEqual(knp.repeat(x, 3, axis=1).shape, (None, 9)) + self.assertEqual(knp.repeat(x, [1, 2], axis=0).shape, (None, 3)) + self.assertEqual(knp.repeat(x, 2, axis=0).shape, (None, 3)) + + def test_reshape(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.reshape(x, (3, 2)).shape, (3, 2)) + self.assertEqual(knp.reshape(x, (3, -1)).shape, (3, None)) + + def test_roll(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.roll(x, 1).shape, (None, 3)) + self.assertEqual(knp.roll(x, 1, axis=1).shape, (None, 3)) + self.assertEqual(knp.roll(x, 1, axis=0).shape, (None, 3)) + + def test_round(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.round(x).shape, (None, 3)) + + def test_sign(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.sign(x).shape, (None, 3)) + + def test_signbit(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.signbit(x).shape, (None, 3)) + + def test_sin(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.sin(x).shape, (None, 3)) + + def test_sinh(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.sinh(x).shape, (None, 3)) + + def test_size(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.size(x).shape, ()) + + def test_sort(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.sort(x).shape, (None, 3)) + self.assertEqual(knp.sort(x, axis=1).shape, (None, 3)) + self.assertEqual(knp.sort(x, axis=0).shape, (None, 3)) + + def test_split(self): + x = KerasTensor((None, 3, 3)) + self.assertEqual(knp.split(x, 2)[0].shape, (None, 3, 3)) + self.assertEqual(knp.split(x, 3, axis=1)[0].shape, (None, 1, 3)) + self.assertEqual(len(knp.split(x, [1, 3], axis=1)), 3) + self.assertEqual(knp.split(x, [1, 3], axis=1)[0].shape, (None, 1, 3)) + self.assertEqual(knp.split(x, [1, 3], axis=1)[1].shape, (None, 2, 3)) + self.assertEqual(knp.split(x, [1, 3], axis=1)[2].shape, (None, 0, 3)) + + def test_sqrt(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.sqrt(x).shape, (None, 3)) + + def test_stack(self): + x = KerasTensor((None, 3)) + y = KerasTensor((None, 3)) + self.assertEqual(knp.stack([x, y]).shape, (2, None, 3)) + self.assertEqual(knp.stack([x, y], axis=-1).shape, (None, 3, 2)) + + def test_std(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.std(x).shape, ()) + + x = KerasTensor((None, 3, 3)) + self.assertEqual(knp.std(x, axis=1).shape, (None, 3)) + self.assertEqual(knp.std(x, axis=1, keepdims=True).shape, (None, 1, 3)) + + def test_swapaxes(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.swapaxes(x, 0, 1).shape, (3, None)) + + def test_tan(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.tan(x).shape, (None, 3)) + + def test_tanh(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.tanh(x).shape, (None, 3)) + + def test_tile(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.tile(x, 2).shape, (None, 6)) + self.assertEqual(knp.tile(x, [2]).shape, (None, 6)) + self.assertEqual(knp.tile(x, [1, 2]).shape, (None, 6)) + self.assertEqual(knp.tile(x, [2, 1, 2]).shape, (2, None, 6)) + + def test_trace(self): + x = KerasTensor((None, 3, None, 5)) + self.assertEqual(knp.trace(x).shape, (None, 5)) + self.assertEqual(knp.trace(x, axis1=2, axis2=3).shape, (None, 3)) + + def test_tril(self): + x = KerasTensor((None, 3, None, 5)) + self.assertEqual(knp.tril(x).shape, (None, 3, None, 5)) + self.assertEqual(knp.tril(x, k=1).shape, (None, 3, None, 5)) + self.assertEqual(knp.tril(x, k=-1).shape, (None, 3, None, 5)) + + def test_triu(self): + x = KerasTensor((None, 3, None, 5)) + self.assertEqual(knp.triu(x).shape, (None, 3, None, 5)) + self.assertEqual(knp.triu(x, k=1).shape, (None, 3, None, 5)) + self.assertEqual(knp.triu(x, k=-1).shape, (None, 3, None, 5)) + + def test_trunc(self): + x = KerasTensor((None, 3, None, 5)) + self.assertEqual(knp.trunc(x).shape, (None, 3, None, 5)) + + def test_vstack(self): + x = KerasTensor((None, 3)) + y = KerasTensor((None, 3)) + self.assertEqual(knp.vstack([x, y]).shape, (None, 3)) + + x = KerasTensor((None, 3)) + y = KerasTensor((None, None)) + self.assertEqual(knp.vstack([x, y]).shape, (None, 3)) + + def test_argpartition(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.argpartition(x, 3).shape, (None, 3)) + self.assertEqual(knp.argpartition(x, 1, axis=1).shape, (None, 3)) + + with self.assertRaises(ValueError): + knp.argpartition(x, (1, 3)) + + def test_angle(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.angle(x).shape, (None, 3)) + + +class NumpyOneInputOpsStaticShapeTest(testing.TestCase): + def test_mean(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.mean(x).shape, ()) + + def test_all(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.all(x).shape, ()) + + def test_any(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.any(x).shape, ()) + + def test_var(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.var(x).shape, ()) + + def test_sum(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.sum(x).shape, ()) + + def test_amax(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.amax(x).shape, ()) + + def test_amin(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.amin(x).shape, ()) + + def test_square(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.square(x).shape, (2, 3)) + + def test_negative(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.negative(x).shape, (2, 3)) + + def test_abs(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.abs(x).shape, (2, 3)) + + def test_absolute(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.absolute(x).shape, (2, 3)) + + def test_squeeze(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.squeeze(x).shape, (2, 3)) + + x = KerasTensor((2, 1, 3)) + self.assertEqual(knp.squeeze(x).shape, (2, 3)) + self.assertEqual(knp.squeeze(x, axis=1).shape, (2, 3)) + self.assertEqual(knp.squeeze(x, axis=-2).shape, (2, 3)) + + with self.assertRaises(ValueError): + knp.squeeze(x, axis=0) + + # Multiple axes + x = KerasTensor((2, 1, 1, 1)) + self.assertEqual(knp.squeeze(x, (1, 2)).shape, (2, 1)) + self.assertEqual(knp.squeeze(x, (-1, -2)).shape, (2, 1)) + self.assertEqual(knp.squeeze(x, (1, 2, 3)).shape, (2,)) + self.assertEqual(knp.squeeze(x, (-1, 1)).shape, (2, 1)) + + def test_transpose(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.transpose(x).shape, (3, 2)) + + def test_arccos(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.arccos(x).shape, (2, 3)) + + def test_arccosh(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.arccosh(x).shape, (2, 3)) + + def test_arcsin(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.arcsin(x).shape, (2, 3)) + + def test_arcsinh(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.arcsinh(x).shape, (2, 3)) + + def test_arctan(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.arctan(x).shape, (2, 3)) + + def test_arctanh(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.arctanh(x).shape, (2, 3)) + + def test_argmax(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.argmax(x).shape, ()) + self.assertEqual(knp.argmax(x, keepdims=True).shape, (2, 3)) + + def test_argmin(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.argmin(x).shape, ()) + self.assertEqual(knp.argmin(x, keepdims=True).shape, (2, 3)) + + def test_argsort(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.argsort(x).shape, (2, 3)) + self.assertEqual(knp.argsort(x, axis=None).shape, (6,)) + + def test_array(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.array(x).shape, (2, 3)) + + def test_average(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.average(x).shape, ()) + + def test_bitwise_invert(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.bitwise_invert(x).shape, (2, 3)) + + # bitwise_not is same as bitwise_invert + + def test_broadcast_to(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.broadcast_to(x, (2, 2, 3)).shape, (2, 2, 3)) + + with self.assertRaises(ValueError): + x = KerasTensor((3, 3)) + knp.broadcast_to(x, (2, 2, 3)) + + def test_cbrt(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.cbrt(x).shape, (2, 3)) + + def test_ceil(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.ceil(x).shape, (2, 3)) + + def test_clip(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.clip(x, 1, 2).shape, (2, 3)) + + def test_concatenate(self): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3)) + self.assertEqual(knp.concatenate([x, y]).shape, (4, 3)) + self.assertEqual(knp.concatenate([x, y], axis=1).shape, (2, 6)) + + with self.assertRaises(ValueError): + self.assertEqual(knp.concatenate([x, y], axis=None).shape, (None,)) + + def test_conjugate(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.conjugate(x).shape, (2, 3)) + + def test_conj(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.conj(x).shape, (2, 3)) + + def test_copy(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.copy(x).shape, (2, 3)) + + def test_cos(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.cos(x).shape, (2, 3)) + + def test_cosh(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.cosh(x).shape, (2, 3)) + + def test_count_nonzero(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.count_nonzero(x).shape, ()) + + def test_cumprod(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.cumprod(x).shape, (6,)) + + def test_cumsum(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.cumsum(x).shape, (6,)) + + def test_deg2rad(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.deg2rad(x).shape, (2, 3)) + + def test_diag(self): + x = KerasTensor((3,)) + self.assertEqual(knp.diag(x).shape, (3, 3)) + self.assertEqual(knp.diag(x, k=3).shape, (6, 6)) + self.assertEqual(knp.diag(x, k=-2).shape, (5, 5)) + + x = KerasTensor((3, 5)) + self.assertEqual(knp.diag(x).shape, (3,)) + self.assertEqual(knp.diag(x, k=3).shape, (2,)) + self.assertEqual(knp.diag(x, k=-2).shape, (1,)) + + with self.assertRaises(ValueError): + x = KerasTensor((2, 3, 4)) + knp.diag(x) + + def test_diagflat(self): + x = KerasTensor((3,)) + self.assertEqual(knp.diagflat(x).shape, (3, 3)) + self.assertEqual(knp.diagflat(x, k=1).shape, (4, 4)) + self.assertEqual(knp.diagflat(x, k=-1).shape, (4, 4)) + + x = KerasTensor((2, 3)) + self.assertEqual(knp.diagflat(x).shape, (6, 6)) + self.assertEqual(knp.diagflat(x, k=1).shape, (7, 7)) + self.assertEqual(knp.diagflat(x, k=-1).shape, (7, 7)) + + x = KerasTensor((None, 3)) + self.assertEqual(knp.diagflat(x).shape, (None, None)) + + x = KerasTensor(()) + self.assertEqual(knp.diagflat(x).shape, (1, 1)) + + def test_diagonal(self): + x = KerasTensor((3, 3)) + self.assertEqual(knp.diagonal(x).shape, (3,)) + self.assertEqual(knp.diagonal(x, offset=1).shape, (2,)) + + x = KerasTensor((3, 5, 5)) + self.assertEqual(knp.diagonal(x).shape, (5, 3)) + + with self.assertRaises(ValueError): + x = KerasTensor((3,)) + knp.diagonal(x) + + def test_diff(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.diff(x).shape, (2, 2)) + self.assertEqual(knp.diff(x, n=2).shape, (2, 1)) + self.assertEqual(knp.diff(x, n=3).shape, (2, 0)) + self.assertEqual(knp.diff(x, n=4).shape, (2, 0)) + + self.assertEqual(knp.diff(x, axis=0).shape, (1, 3)) + self.assertEqual(knp.diff(x, n=2, axis=0).shape, (0, 3)) + self.assertEqual(knp.diff(x, n=3, axis=0).shape, (0, 3)) + + def test_dot(self): + x = KerasTensor((2, 3)) + y = KerasTensor((3, 2)) + z = KerasTensor((4, 3, 2)) + self.assertEqual(knp.dot(x, y).shape, (2, 2)) + self.assertEqual(knp.dot(x, 2).shape, (2, 3)) + self.assertEqual(knp.dot(x, z).shape, (2, 4, 2)) + + x = KerasTensor((5,)) + y = KerasTensor((5,)) + self.assertEqual(knp.dot(x, y).shape, ()) + + with self.assertRaises(ValueError): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3)) + knp.dot(x, y) + + def test_exp(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.exp(x).shape, (2, 3)) + + def test_exp2(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.exp2(x).shape, (2, 3)) + + def test_expand_dims(self): + x = KerasTensor((2, 3, 4)) + self.assertEqual(knp.expand_dims(x, 0).shape, (1, 2, 3, 4)) + self.assertEqual(knp.expand_dims(x, 1).shape, (2, 1, 3, 4)) + self.assertEqual(knp.expand_dims(x, -2).shape, (2, 3, 1, 4)) + + # Multiple axes + self.assertEqual(knp.expand_dims(x, (1, 2)).shape, (2, 1, 1, 3, 4)) + self.assertEqual(knp.expand_dims(x, (-1, -2)).shape, (2, 3, 4, 1, 1)) + self.assertEqual(knp.expand_dims(x, (-1, 1)).shape, (2, 1, 3, 4, 1)) + + def test_expm1(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.expm1(x).shape, (2, 3)) + + def test_flip(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.flip(x).shape, (2, 3)) + + def test_floor(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.floor(x).shape, (2, 3)) + + def test_get_item(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.get_item(x, 1).shape, (3,)) + + x = KerasTensor((5, 3, 2)) + self.assertEqual(knp.get_item(x, 3).shape, (3, 2)) + + x = KerasTensor( + [ + 2, + ] + ) + self.assertEqual(knp.get_item(x, 0).shape, ()) + + def test_hstack(self): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3)) + self.assertEqual(knp.hstack([x, y]).shape, (2, 6)) + + def test_imag(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.imag(x).shape, (2, 3)) + + def test_isfinite(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.isfinite(x).shape, (2, 3)) + + def test_isinf(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.isinf(x).shape, (2, 3)) + + def test_isnan(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.isnan(x).shape, (2, 3)) + + def test_isneginf(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.isneginf(x).shape, (2, 3)) + + def test_isposinf(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.isposinf(x).shape, (2, 3)) + + def test_log(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.log(x).shape, (2, 3)) + + def test_log10(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.log10(x).shape, (2, 3)) + + def test_log1p(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.log1p(x).shape, (2, 3)) + + def test_log2(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.log2(x).shape, (2, 3)) + + def test_logaddexp(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.logaddexp(x, x).shape, (2, 3)) + + def test_logaddexp2(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.logaddexp2(x, x).shape, (2, 3)) + + def test_logical_not(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.logical_not(x).shape, (2, 3)) + + def test_max(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.max(x).shape, ()) + + def test_median(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.median(x).shape, ()) + + x = KerasTensor((2, 3, 3)) + self.assertEqual(knp.median(x, axis=1).shape, (2, 3)) + self.assertEqual(knp.median(x, axis=1, keepdims=True).shape, (2, 1, 3)) + + def test_meshgrid(self): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3, 4)) + z = KerasTensor((2, 3, 4, 5)) + self.assertEqual(knp.meshgrid(x, y)[0].shape, (24, 6)) + self.assertEqual(knp.meshgrid(x, y)[1].shape, (24, 6)) + self.assertEqual(knp.meshgrid(x, y, indexing="ij")[0].shape, (6, 24)) + self.assertEqual( + knp.meshgrid(x, y, z, indexing="ij")[0].shape, (6, 24, 120) + ) + with self.assertRaises(ValueError): + knp.meshgrid(x, y, indexing="kk") + + def test_moveaxis(self): + x = KerasTensor((2, 3, 4, 5)) + self.assertEqual(knp.moveaxis(x, 0, -1).shape, (3, 4, 5, 2)) + self.assertEqual(knp.moveaxis(x, -1, 0).shape, (5, 2, 3, 4)) + self.assertEqual(knp.moveaxis(x, [0, 1], [-1, -2]).shape, (4, 5, 3, 2)) + self.assertEqual(knp.moveaxis(x, [0, 1], [1, 0]).shape, (3, 2, 4, 5)) + self.assertEqual(knp.moveaxis(x, [0, 1], [-2, -1]).shape, (4, 5, 2, 3)) + + def test_ndim(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.ndim(x).shape, (2,)) + + def test_ones_like(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.ones_like(x).shape, (2, 3)) + self.assertEqual(knp.ones_like(x).dtype, x.dtype) + + def test_zeros_like(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.zeros_like(x).shape, (2, 3)) + self.assertEqual(knp.zeros_like(x).dtype, x.dtype) + + def test_pad(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.pad(x, 1).shape, (4, 5)) + self.assertEqual(knp.pad(x, (1, 2)).shape, (5, 6)) + self.assertEqual(knp.pad(x, ((1, 2), (3, 4))).shape, (5, 10)) + + with self.assertRaises(ValueError): + x = KerasTensor((2, 3)) + knp.pad(x, ((1, 2), (3, 4), (5, 6))) + + def test_prod(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.prod(x).shape, ()) + self.assertEqual(knp.prod(x, axis=0).shape, (3,)) + self.assertEqual(knp.prod(x, axis=1).shape, (2,)) + + def test_ravel(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.ravel(x).shape, (6,)) + + def test_unravel_index(self): + x = KerasTensor((6,)) + indices = knp.unravel_index(x, (2, 3)) + self.assertEqual(len(indices), 2) + self.assertEqual(indices[0].shape, (6,)) + self.assertEqual(indices[1].shape, (6,)) + + x = KerasTensor((2, 3)) + indices = knp.unravel_index(x, (3, 4)) + self.assertEqual(len(indices), 2) + self.assertEqual(indices[0].shape, (2, 3)) + self.assertEqual(indices[1].shape, (2, 3)) + + def test_real(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.real(x).shape, (2, 3)) + + def test_reciprocal(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.reciprocal(x).shape, (2, 3)) + + def test_repeat(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.repeat(x, 2).shape, (12,)) + self.assertEqual(knp.repeat(x, [2]).shape, (12,)) + self.assertEqual(knp.repeat(x, 3, axis=1).shape, (2, 9)) + self.assertEqual(knp.repeat(x, [1, 2], axis=0).shape, (3, 3)) + + with self.assertRaises(ValueError): + knp.repeat(x, [1, 1]) + with self.assertRaises(ValueError): + knp.repeat(x, [1, 1, 1], axis=0) + + def test_reshape(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.reshape(x, (3, 2)).shape, (3, 2)) + self.assertEqual(knp.reshape(x, (3, -1)).shape, (3, 2)) + self.assertEqual(knp.reshape(x, (6,)).shape, (6,)) + self.assertEqual(knp.reshape(x, (-1,)).shape, (6,)) + + def test_roll(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.roll(x, 1).shape, (2, 3)) + self.assertEqual(knp.roll(x, 1, axis=1).shape, (2, 3)) + self.assertEqual(knp.roll(x, 1, axis=0).shape, (2, 3)) + + def test_round(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.round(x).shape, (2, 3)) + + def test_sign(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.sign(x).shape, (2, 3)) + + def test_signbit(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.signbit(x).shape, (2, 3)) + + def test_sin(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.sin(x).shape, (2, 3)) + + def test_sinh(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.sinh(x).shape, (2, 3)) + + def test_size(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.size(x).shape, ()) + + def test_sort(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.sort(x).shape, (2, 3)) + self.assertEqual(knp.sort(x, axis=1).shape, (2, 3)) + self.assertEqual(knp.sort(x, axis=0).shape, (2, 3)) + + def test_split(self): + x = KerasTensor((2, 3)) + self.assertEqual(len(knp.split(x, 2)), 2) + self.assertEqual(knp.split(x, 2)[0].shape, (1, 3)) + self.assertEqual(knp.split(x, 3, axis=1)[0].shape, (2, 1)) + self.assertEqual(len(knp.split(x, [1, 3], axis=1)), 3) + self.assertEqual(knp.split(x, [1, 3], axis=1)[0].shape, (2, 1)) + self.assertEqual(knp.split(x, [1, 3], axis=1)[1].shape, (2, 2)) + self.assertEqual(knp.split(x, [1, 3], axis=1)[2].shape, (2, 0)) + + with self.assertRaises(ValueError): + knp.split(x, 2, axis=1) + + def test_sqrt(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.sqrt(x).shape, (2, 3)) + + def test_stack(self): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3)) + self.assertEqual(knp.stack([x, y]).shape, (2, 2, 3)) + self.assertEqual(knp.stack([x, y], axis=-1).shape, (2, 3, 2)) + + with self.assertRaises(ValueError): + x = KerasTensor((2, 3)) + y = KerasTensor((3, 3)) + knp.stack([x, y]) + + def test_std(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.std(x).shape, ()) + + def test_swapaxes(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.swapaxes(x, 0, 1).shape, (3, 2)) + + def test_tan(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.tan(x).shape, (2, 3)) + + def test_tanh(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.tanh(x).shape, (2, 3)) + + def test_tile(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.tile(x, 2).shape, (2, 6)) + self.assertEqual(knp.tile(x, [2]).shape, (2, 6)) + self.assertEqual(knp.tile(x, [1, 2]).shape, (2, 6)) + self.assertEqual(knp.tile(x, [2, 1, 2]).shape, (2, 2, 6)) + + def test_trace(self): + x = KerasTensor((2, 3, 4, 5)) + self.assertEqual(knp.trace(x).shape, (4, 5)) + self.assertEqual(knp.trace(x, axis1=2, axis2=3).shape, (2, 3)) + + def test_tril(self): + x = KerasTensor((2, 3, 4, 5)) + self.assertEqual(knp.tril(x).shape, (2, 3, 4, 5)) + self.assertEqual(knp.tril(x, k=1).shape, (2, 3, 4, 5)) + self.assertEqual(knp.tril(x, k=-1).shape, (2, 3, 4, 5)) + + def test_triu(self): + x = KerasTensor((2, 3, 4, 5)) + self.assertEqual(knp.triu(x).shape, (2, 3, 4, 5)) + self.assertEqual(knp.triu(x, k=1).shape, (2, 3, 4, 5)) + self.assertEqual(knp.triu(x, k=-1).shape, (2, 3, 4, 5)) + + def test_trunc(self): + x = KerasTensor((2, 3, 4, 5)) + self.assertEqual(knp.trunc(x).shape, (2, 3, 4, 5)) + + def test_vstack(self): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3)) + self.assertEqual(knp.vstack([x, y]).shape, (4, 3)) + + def test_argpartition(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.argpartition(x, 3).shape, (2, 3)) + self.assertEqual(knp.argpartition(x, 1, axis=1).shape, (2, 3)) + + with self.assertRaises(ValueError): + knp.argpartition(x, (1, 3)) + + def test_angle(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.angle(x).shape, (2, 3)) + + +class NumpyTwoInputOpsCorrectnessTest(testing.TestCase): + def test_add(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + y = np.array([[4, 5, 6], [3, 2, 1]]) + z = np.array([[[1, 2, 3], [3, 2, 1]]]) + self.assertAllClose(knp.add(x, y), np.add(x, y)) + self.assertAllClose(knp.add(x, z), np.add(x, z)) + + self.assertAllClose(knp.Add()(x, y), np.add(x, y)) + self.assertAllClose(knp.Add()(x, z), np.add(x, z)) + + def test_heaviside(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + y = np.array([[4, 5, 6], [6, 5, 4]]) + self.assertAllClose(knp.heaviside(x, y), np.heaviside(x, y)) + self.assertAllClose(knp.Heaviside()(x, y), np.heaviside(x, y)) + + x = np.array([[1, 2, 3], [3, 2, 1]]) + y = np.array(4) + self.assertAllClose(knp.heaviside(x, y), np.heaviside(x, y)) + self.assertAllClose(knp.Heaviside()(x, y), np.heaviside(x, y)) + + def test_hypot(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + y = np.array([[4, 5, 6], [6, 5, 4]]) + self.assertAllClose(knp.hypot(x, y), np.hypot(x, y)) + self.assertAllClose(knp.Hypot()(x, y), np.hypot(x, y)) + + x = np.array([[1, 2, 3], [3, 2, 1]]) + y = np.array(4) + self.assertAllClose(knp.hypot(x, y), np.hypot(x, y)) + self.assertAllClose(knp.Hypot()(x, y), np.hypot(x, y)) + + def test_subtract(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + y = np.array([[4, 5, 6], [3, 2, 1]]) + z = np.array([[[1, 2, 3], [3, 2, 1]]]) + self.assertAllClose(knp.subtract(x, y), np.subtract(x, y)) + self.assertAllClose(knp.subtract(x, z), np.subtract(x, z)) + + self.assertAllClose(knp.Subtract()(x, y), np.subtract(x, y)) + self.assertAllClose(knp.Subtract()(x, z), np.subtract(x, z)) + + def test_multiply(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + y = np.array([[4, 5, 6], [3, 2, 1]]) + z = np.array([[[1, 2, 3], [3, 2, 1]]]) + self.assertAllClose(knp.multiply(x, y), np.multiply(x, y)) + self.assertAllClose(knp.multiply(x, z), np.multiply(x, z)) + + self.assertAllClose(knp.Multiply()(x, y), np.multiply(x, y)) + self.assertAllClose(knp.Multiply()(x, z), np.multiply(x, z)) + + def test_matmul(self): + x = np.ones([2, 3, 4, 5]) + y = np.ones([2, 3, 5, 6]) + z = np.ones([5, 6]) + p = np.ones([4]) + self.assertAllClose(knp.matmul(x, y), np.matmul(x, y)) + self.assertAllClose(knp.matmul(x, z), np.matmul(x, z)) + self.assertAllClose(knp.matmul(p, x), np.matmul(p, x)) + + self.assertAllClose(knp.Matmul()(x, y), np.matmul(x, y)) + self.assertAllClose(knp.Matmul()(x, z), np.matmul(x, z)) + self.assertAllClose(knp.Matmul()(p, x), np.matmul(p, x)) + + @parameterized.named_parameters( + named_product( + ( + { + "testcase_name": "rank2", + "x_shape": (5, 3), + "y_shape": (3, 4), + }, + { + "testcase_name": "rank3", + "x_shape": (2, 5, 3), + "y_shape": (2, 3, 4), + }, + { + "testcase_name": "rank4", + "x_shape": (2, 2, 5, 3), + "y_shape": (2, 2, 3, 4), + }, + ), + dtype=["float16", "float32", "float64", "int32"], + x_sparse=[False, True], + y_sparse=[False, True], + ) + ) + @pytest.mark.skipif( + not backend.SUPPORTS_SPARSE_TENSORS, + reason="Backend does not support sparse tensors.", + ) + @pytest.mark.skipif(testing.tensorflow_uses_gpu(), reason="Segfault") + def test_matmul_sparse(self, dtype, x_shape, y_shape, x_sparse, y_sparse): + if backend.backend() == "tensorflow": + import tensorflow as tf + + if x_sparse and y_sparse and dtype in ("float16", "int32"): + pytest.skip( + f"Sparse sparse matmul unsupported for {dtype}" + " with TensorFlow backend" + ) + + dense_to_sparse = tf.sparse.from_dense + elif backend.backend() == "jax": + import jax.experimental.sparse as jax_sparse + + dense_to_sparse = functools.partial( + jax_sparse.BCOO.fromdense, n_batch=len(x_shape) - 2 + ) + + rng = np.random.default_rng(0) + + x = x_np = (4 * rng.standard_normal(x_shape)).astype(dtype) + if x_sparse: + x_np = np.multiply(x_np, rng.random(x_shape) < 0.7) + x = dense_to_sparse(x_np) + + y = y_np = (4 * rng.standard_normal(y_shape)).astype(dtype) + if y_sparse: + y_np = np.multiply(y_np, rng.random(y_shape) < 0.7) + y = dense_to_sparse(y_np) + + atol = 0.1 if dtype == "float16" else 1e-4 + self.assertAllClose(knp.matmul(x, y), np.matmul(x_np, y_np), atol=atol) + self.assertSparse(knp.matmul(x, y), x_sparse and y_sparse) + + def test_power(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + y = np.array([[4, 5, 6], [3, 2, 1]]) + z = np.array([[[1, 2, 3], [3, 2, 1]]]) + self.assertAllClose(knp.power(x, y), np.power(x, y)) + self.assertAllClose(knp.power(x, z), np.power(x, z)) + + self.assertAllClose(knp.Power()(x, y), np.power(x, y)) + self.assertAllClose(knp.Power()(x, z), np.power(x, z)) + + def test_divide(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + y = np.array([[4, 5, 6], [3, 2, 1]]) + z = np.array([[[1, 2, 3], [3, 2, 1]]]) + self.assertAllClose(knp.divide(x, y), np.divide(x, y)) + self.assertAllClose(knp.divide(x, z), np.divide(x, z)) + + self.assertAllClose(knp.Divide()(x, y), np.divide(x, y)) + self.assertAllClose(knp.Divide()(x, z), np.divide(x, z)) + + def test_divide_no_nan(self): + x = np.array( + [[2, 1, 0], [np.inf, -np.inf, np.nan], [np.inf, -np.inf, np.nan]] + ) + y = np.array([[2, 0, 0], [0, 0, 0], [3, 2, 1]]) + expected_result = np.array( + [[1, 0, 0], [0, 0, 0], [np.inf, -np.inf, np.nan]] + ) + self.assertAllClose(knp.divide_no_nan(x, y), expected_result) + self.assertAllClose(knp.DivideNoNan()(x, y), expected_result) + + def test_true_divide(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + y = np.array([[4, 5, 6], [3, 2, 1]]) + z = np.array([[[1, 2, 3], [3, 2, 1]]]) + self.assertAllClose(knp.true_divide(x, y), np.true_divide(x, y)) + self.assertAllClose(knp.true_divide(x, z), np.true_divide(x, z)) + + self.assertAllClose(knp.TrueDivide()(x, y), np.true_divide(x, y)) + self.assertAllClose(knp.TrueDivide()(x, z), np.true_divide(x, z)) + + def test_append(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + y = np.array([[4, 5, 6], [3, 2, 1]]) + z = np.array([[[1, 2, 3], [3, 2, 1]], [[4, 5, 6], [3, 2, 1]]]) + self.assertAllClose(knp.append(x, y), np.append(x, y)) + self.assertAllClose(knp.append(x, y, axis=1), np.append(x, y, axis=1)) + self.assertAllClose(knp.append(x, z), np.append(x, z)) + + self.assertAllClose(knp.Append()(x, y), np.append(x, y)) + self.assertAllClose(knp.Append(axis=1)(x, y), np.append(x, y, axis=1)) + self.assertAllClose(knp.Append()(x, z), np.append(x, z)) + + def test_arctan2(self): + x = np.array([[1.0, 2.0, 3.0], [3.0, 2.0, 1.0]]) + y = np.array([[4.0, 5.0, 6.0], [3.0, 2.0, 1.0]]) + self.assertAllClose(knp.arctan2(x, y), np.arctan2(x, y)) + + self.assertAllClose(knp.Arctan2()(x, y), np.arctan2(x, y)) + + a = np.array([0.0, 0.0, 0.0, 1.0, 1.0, -1.0, -1.0, 1.0, -1.0]) + b = np.array([0.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 0.0, 0.0]) + + self.assertAllClose(knp.arctan2(a, b), np.arctan2(a, b)) + self.assertAllClose(knp.Arctan2()(a, b), np.arctan2(a, b)) + + m = np.array([[3, 4], [7, 8]], dtype=np.int8) + n = np.array([[1, 2], [3, 4]], dtype=float) + + self.assertAllClose(knp.arctan2(m, n), np.arctan2(m, n)) + self.assertAllClose(knp.Arctan2()(m, n), np.arctan2(m, n)) + + def test_bitwise_and(self): + x = np.array([2, 5, 255]) + y = np.array([3, 14, 16]) + self.assertAllClose(knp.bitwise_and(x, y), np.bitwise_and(x, y)) + self.assertAllClose(knp.BitwiseAnd()(x, y), np.bitwise_and(x, y)) + + def test_bitwise_or(self): + x = np.array([2, 5, 255]) + y = np.array([3, 14, 16]) + self.assertAllClose(knp.bitwise_or(x, y), np.bitwise_or(x, y)) + self.assertAllClose(knp.BitwiseOr()(x, y), np.bitwise_or(x, y)) + + def test_bitwise_xor(self): + x = np.array([2, 5, 255]) + y = np.array([3, 14, 16]) + self.assertAllClose(knp.bitwise_xor(x, y), np.bitwise_xor(x, y)) + self.assertAllClose(knp.BitwiseXor()(x, y), np.bitwise_xor(x, y)) + + def test_bitwise_left_shift(self): + x = np.array([50, 60, 70]) + y = np.array([1, 2, 3]) + self.assertAllClose(knp.bitwise_left_shift(x, y), np.left_shift(x, y)) + self.assertAllClose(knp.BitwiseLeftShift()(x, y), np.left_shift(x, y)) + + # left_shift is same as bitwise_left_shift + + def test_bitwise_right_shift(self): + x = np.array([5, 6, 7]) + y = np.array([1, 2, 3]) + self.assertAllClose(knp.bitwise_right_shift(x, y), np.right_shift(x, y)) + self.assertAllClose(knp.BitwiseRightShift()(x, y), np.right_shift(x, y)) + + # right_shift is same as bitwise_right_shift + + def test_cross(self): + x1 = np.ones([2, 1, 4, 3]) + x2 = np.ones([2, 1, 4, 2]) + y1 = np.ones([2, 1, 4, 3]) + y2 = np.ones([1, 5, 4, 3]) + y3 = np.ones([1, 5, 4, 2]) + self.assertAllClose(knp.cross(x1, y1), np.cross(x1, y1)) + self.assertAllClose(knp.cross(x1, y2), np.cross(x1, y2)) + if backend.backend() != "torch": + # API divergence between `torch.cross` and `np.cross` + # `torch.cross` only allows dim 3, `np.cross` allows dim 2 or 3 + self.assertAllClose(knp.cross(x1, y3), np.cross(x1, y3)) + self.assertAllClose(knp.cross(x2, y3), np.cross(x2, y3)) + + self.assertAllClose(knp.Cross()(x1, y1), np.cross(x1, y1)) + self.assertAllClose(knp.Cross()(x1, y2), np.cross(x1, y2)) + if backend.backend() != "torch": + # API divergence between `torch.cross` and `np.cross` + # `torch.cross` only allows dim 3, `np.cross` allows dim 2 or 3 + self.assertAllClose(knp.Cross()(x1, y3), np.cross(x1, y3)) + self.assertAllClose(knp.Cross()(x2, y3), np.cross(x2, y3)) + + # Test axis is not None + self.assertAllClose( + knp.cross(x1, y1, axis=-1), np.cross(x1, y1, axis=-1) + ) + self.assertAllClose( + knp.Cross(axis=-1)(x1, y1), np.cross(x1, y1, axis=-1) + ) + + def test_einsum(self): + x = np.arange(24).reshape([2, 3, 4]).astype("float32") + y = np.arange(24).reshape([2, 4, 3]).astype("float32") + self.assertAllClose( + knp.einsum("ijk,lkj->il", x, y), + np.einsum("ijk,lkj->il", x, y), + ) + self.assertAllClose( + knp.einsum("ijk,ikj->i", x, y), + np.einsum("ijk,ikj->i", x, y), + ) + self.assertAllClose( + knp.einsum("i...,j...k->...ijk", x, y), + np.einsum("i..., j...k->...ijk", x, y), + ) + self.assertAllClose(knp.einsum(",ijk", 5, y), np.einsum(",ijk", 5, y)) + + self.assertAllClose( + knp.Einsum("ijk,lkj->il")(x, y), + np.einsum("ijk,lkj->il", x, y), + ) + self.assertAllClose( + knp.Einsum("ijk,ikj->i")(x, y), + np.einsum("ijk,ikj->i", x, y), + ) + self.assertAllClose( + knp.Einsum("i...,j...k->...ijk")(x, y), + np.einsum("i...,j...k->...ijk", x, y), + ) + self.assertAllClose(knp.Einsum(",ijk")(5, y), np.einsum(",ijk", 5, y)) + + @pytest.mark.skipif( + backend.backend() != "tensorflow", + reason=f"{backend.backend()} doesn't implement custom ops for einsum.", + ) + def test_einsum_custom_ops_for_tensorflow(self): + subscripts = "a,b->ab" + x = np.arange(2).reshape([2]).astype("float32") + y = np.arange(3).reshape([3]).astype("float32") + self.assertAllClose( + knp.einsum(subscripts, x, y), np.einsum(subscripts, x, y) + ) + + subscripts = "ab,b->a" + x = np.arange(6).reshape([2, 3]).astype("float32") + y = np.arange(3).reshape([3]).astype("float32") + self.assertAllClose( + knp.einsum(subscripts, x, y), np.einsum(subscripts, x, y) + ) + + subscripts = "ab,bc->ac" + x = np.arange(6).reshape([2, 3]).astype("float32") + y = np.arange(12).reshape([3, 4]).astype("float32") + self.assertAllClose( + knp.einsum(subscripts, x, y), np.einsum(subscripts, x, y) + ) + + subscripts = "ab,cb->ac" + x = np.arange(6).reshape([2, 3]).astype("float32") + y = np.arange(12).reshape([4, 3]).astype("float32") + self.assertAllClose( + knp.einsum(subscripts, x, y), np.einsum(subscripts, x, y) + ) + + subscripts = "abc,cd->abd" + x = np.arange(24).reshape([2, 3, 4]).astype("float32") + y = np.arange(20).reshape([4, 5]).astype("float32") + self.assertAllClose( + knp.einsum(subscripts, x, y), np.einsum(subscripts, x, y) + ) + + subscripts = "abc,cde->abde" + x = np.arange(24).reshape([2, 3, 4]).astype("float32") + y = np.arange(120).reshape([4, 5, 6]).astype("float32") + self.assertAllClose( + knp.einsum(subscripts, x, y), np.einsum(subscripts, x, y) + ) + + subscripts = "abc,dc->abd" + x = np.arange(24).reshape([2, 3, 4]).astype("float32") + y = np.arange(20).reshape([5, 4]).astype("float32") + self.assertAllClose( + knp.einsum(subscripts, x, y), np.einsum(subscripts, x, y) + ) + + subscripts = "abc,dce->abde" + x = np.arange(24).reshape([2, 3, 4]).astype("float32") + y = np.arange(120).reshape([5, 4, 6]).astype("float32") + self.assertAllClose( + knp.einsum(subscripts, x, y), np.einsum(subscripts, x, y) + ) + + subscripts = "abc,dec->abde" + x = np.arange(24).reshape([2, 3, 4]).astype("float32") + y = np.arange(120).reshape([5, 6, 4]).astype("float32") + self.assertAllClose( + knp.einsum(subscripts, x, y), np.einsum(subscripts, x, y) + ) + + subscripts = "abcd,abde->abce" + x = np.arange(120).reshape([2, 3, 4, 5]).astype("float32") + y = np.arange(180).reshape([2, 3, 5, 6]).astype("float32") + self.assertAllClose( + knp.einsum(subscripts, x, y), np.einsum(subscripts, x, y) + ) + + subscripts = "abcd,abed->abce" + x = np.arange(120).reshape([2, 3, 4, 5]).astype("float32") + y = np.arange(180).reshape([2, 3, 6, 5]).astype("float32") + self.assertAllClose( + knp.einsum(subscripts, x, y), np.einsum(subscripts, x, y) + ) + + subscripts = "abcd,acbe->adbe" + x = np.arange(120).reshape([2, 3, 4, 5]).astype("float32") + y = np.arange(144).reshape([2, 4, 3, 6]).astype("float32") + self.assertAllClose( + knp.einsum(subscripts, x, y), np.einsum(subscripts, x, y) + ) + + subscripts = "abcd,adbe->acbe" + x = np.arange(120).reshape([2, 3, 4, 5]).astype("float32") + y = np.arange(180).reshape([2, 5, 3, 6]).astype("float32") + self.assertAllClose( + knp.einsum(subscripts, x, y), np.einsum(subscripts, x, y) + ) + + subscripts = "abcd,aecd->acbe" + x = np.arange(120).reshape([2, 3, 4, 5]).astype("float32") + y = np.arange(240).reshape([2, 6, 4, 5]).astype("float32") + self.assertAllClose( + knp.einsum(subscripts, x, y), np.einsum(subscripts, x, y) + ) + + subscripts = "abcd,aecd->aceb" + x = np.arange(120).reshape([2, 3, 4, 5]).astype("float32") + y = np.arange(240).reshape([2, 6, 4, 5]).astype("float32") + self.assertAllClose( + knp.einsum(subscripts, x, y), np.einsum(subscripts, x, y) + ) + + subscripts = "abcd,cde->abe" + x = np.arange(120).reshape([2, 3, 4, 5]).astype("float32") + y = np.arange(120).reshape([4, 5, 6]).astype("float32") + self.assertAllClose( + knp.einsum(subscripts, x, y), np.einsum(subscripts, x, y) + ) + + subscripts = "abcd,ced->abe" + x = np.arange(120).reshape([2, 3, 4, 5]).astype("float32") + y = np.arange(120).reshape([4, 6, 5]).astype("float32") + self.assertAllClose( + knp.einsum(subscripts, x, y), np.einsum(subscripts, x, y) + ) + + subscripts = "abcd,ecd->abe" + x = np.arange(120).reshape([2, 3, 4, 5]).astype("float32") + y = np.arange(120).reshape([6, 4, 5]).astype("float32") + self.assertAllClose( + knp.einsum(subscripts, x, y), np.einsum(subscripts, x, y) + ) + + subscripts = "abcde,aebf->adbcf" + x = np.arange(720).reshape([2, 3, 4, 5, 6]).astype("float32") + y = np.arange(252).reshape([2, 6, 3, 7]).astype("float32") + self.assertAllClose( + knp.einsum(subscripts, x, y), np.einsum(subscripts, x, y) + ) + + subscripts = "abcde,afce->acdbf" + x = np.arange(720).reshape([2, 3, 4, 5, 6]).astype("float32") + y = np.arange(336).reshape([2, 7, 4, 6]).astype("float32") + self.assertAllClose( + knp.einsum(subscripts, x, y), np.einsum(subscripts, x, y) + ) + + def test_full_like(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + self.assertAllClose(knp.full_like(x, 2), np.full_like(x, 2)) + self.assertAllClose( + knp.full_like(x, 2, dtype="float32"), + np.full_like(x, 2, dtype="float32"), + ) + self.assertAllClose( + knp.full_like(x, np.ones([2, 3])), + np.full_like(x, np.ones([2, 3])), + ) + + self.assertAllClose(knp.FullLike()(x, 2), np.full_like(x, 2)) + self.assertAllClose( + knp.FullLike(dtype="float32")(x, 2), + np.full_like(x, 2, dtype="float32"), + ) + self.assertAllClose( + knp.FullLike()(x, np.ones([2, 3])), + np.full_like(x, np.ones([2, 3])), + ) + + def test_gcd(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + y = np.array([[4, 5, 6], [3, 2, 1]]) + self.assertAllClose(knp.gcd(x, y), np.gcd(x, y)) + self.assertAllClose(knp.gcd(x, 2), np.gcd(x, 2)) + self.assertAllClose(knp.gcd(2, x), np.gcd(2, x)) + + self.assertAllClose(knp.Gcd()(x, y), np.gcd(x, y)) + self.assertAllClose(knp.Gcd()(x, 2), np.gcd(x, 2)) + self.assertAllClose(knp.Gcd()(2, x), np.gcd(2, x)) + + def test_greater(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + y = np.array([[4, 5, 6], [3, 2, 1]]) + self.assertAllClose(knp.greater(x, y), np.greater(x, y)) + self.assertAllClose(knp.greater(x, 2), np.greater(x, 2)) + self.assertAllClose(knp.greater(2, x), np.greater(2, x)) + + self.assertAllClose(knp.Greater()(x, y), np.greater(x, y)) + self.assertAllClose(knp.Greater()(x, 2), np.greater(x, 2)) + self.assertAllClose(knp.Greater()(2, x), np.greater(2, x)) + + def test_greater_equal(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + y = np.array([[4, 5, 6], [3, 2, 1]]) + self.assertAllClose( + knp.greater_equal(x, y), + np.greater_equal(x, y), + ) + self.assertAllClose( + knp.greater_equal(x, 2), + np.greater_equal(x, 2), + ) + self.assertAllClose( + knp.greater_equal(2, x), + np.greater_equal(2, x), + ) + + self.assertAllClose( + knp.GreaterEqual()(x, y), + np.greater_equal(x, y), + ) + self.assertAllClose( + knp.GreaterEqual()(x, 2), + np.greater_equal(x, 2), + ) + self.assertAllClose( + knp.GreaterEqual()(2, x), + np.greater_equal(2, x), + ) + + def test_isclose(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + y = np.array([[4, 5, 6], [3, 2, 1]]) + self.assertAllClose(knp.isclose(x, y), np.isclose(x, y)) + self.assertAllClose(knp.isclose(x, 2), np.isclose(x, 2)) + self.assertAllClose(knp.isclose(2, x), np.isclose(2, x)) + + self.assertAllClose(knp.Isclose()(x, y), np.isclose(x, y)) + self.assertAllClose(knp.Isclose()(x, 2), np.isclose(x, 2)) + self.assertAllClose(knp.Isclose()(2, x), np.isclose(2, x)) + + def test_isin(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + y = np.array([[4, 5, 6], [3, 2, 1]]) + self.assertAllClose(knp.isin(x, y), np.isin(x, y)) + self.assertAllClose(knp.isin(x, 2), np.isin(x, 2)) + self.assertAllClose(knp.isin(2, x), np.isin(2, x)) + + self.assertAllClose( + knp.isin(x, y, assume_unique=True), + np.isin(x, y, assume_unique=True), + ) + self.assertAllClose( + knp.isin(x, 2, assume_unique=True), + np.isin(x, 2, assume_unique=True), + ) + self.assertAllClose( + knp.isin(2, x, assume_unique=True), + np.isin(2, x, assume_unique=True), + ) + + self.assertAllClose( + knp.isin(x, y, invert=True), np.isin(x, y, invert=True) + ) + self.assertAllClose( + knp.isin(x, 2, invert=True), np.isin(x, 2, invert=True) + ) + self.assertAllClose( + knp.isin(2, x, invert=True), np.isin(2, x, invert=True) + ) + + self.assertAllClose( + knp.isin(x, y, assume_unique=True, invert=True), + np.isin(x, y, assume_unique=True, invert=True), + ) + self.assertAllClose( + knp.isin(x, 2, assume_unique=True, invert=True), + np.isin(x, 2, assume_unique=True, invert=True), + ) + self.assertAllClose( + knp.isin(2, x, assume_unique=True, invert=True), + np.isin(2, x, assume_unique=True, invert=True), + ) + + self.assertAllClose(knp.IsIn()(x, y), np.isin(x, y)) + self.assertAllClose(knp.IsIn()(x, 2), np.isin(x, 2)) + self.assertAllClose(knp.IsIn()(2, x), np.isin(2, x)) + + self.assertAllClose( + knp.IsIn(assume_unique=True)(x, y), + np.isin(x, y, assume_unique=True), + ) + self.assertAllClose( + knp.IsIn(invert=True)(x, y), + np.isin(x, y, invert=True), + ) + self.assertAllClose( + knp.IsIn(assume_unique=True, invert=True)(x, y), + np.isin(x, y, assume_unique=True, invert=True), + ) + + def test_kron(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + y = np.array([[4, 5, 6], [3, 2, 1]]) + self.assertAllClose(knp.kron(x, y), np.kron(x, y)) + self.assertAllClose(knp.Kron()(x, y), np.kron(x, y)) + + def test_lcm(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + y = np.array([[4, 5, 6], [3, 2, 1]]) + self.assertAllClose(knp.lcm(x, y), np.lcm(x, y)) + self.assertAllClose(knp.Lcm()(x, y), np.lcm(x, y)) + + x = np.array([[1, 2, 3], [3, 2, 1]]) + y = np.array(4) + self.assertAllClose(knp.lcm(x, y), np.lcm(x, y)) + self.assertAllClose(knp.Lcm()(x, y), np.lcm(x, y)) + + x = np.array([[1, 2, 3], [3, 2, 1]]) + y = np.array([4]) + self.assertAllClose(knp.lcm(x, y), np.lcm(x, y)) + self.assertAllClose(knp.Lcm()(x, y), np.lcm(x, y)) + + def test_less(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + y = np.array([[4, 5, 6], [3, 2, 1]]) + self.assertAllClose(knp.less(x, y), np.less(x, y)) + self.assertAllClose(knp.less(x, 2), np.less(x, 2)) + self.assertAllClose(knp.less(2, x), np.less(2, x)) + + self.assertAllClose(knp.Less()(x, y), np.less(x, y)) + self.assertAllClose(knp.Less()(x, 2), np.less(x, 2)) + self.assertAllClose(knp.Less()(2, x), np.less(2, x)) + + def test_less_equal(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + y = np.array([[4, 5, 6], [3, 2, 1]]) + self.assertAllClose(knp.less_equal(x, y), np.less_equal(x, y)) + self.assertAllClose(knp.less_equal(x, 2), np.less_equal(x, 2)) + self.assertAllClose(knp.less_equal(2, x), np.less_equal(2, x)) + + self.assertAllClose(knp.LessEqual()(x, y), np.less_equal(x, y)) + self.assertAllClose(knp.LessEqual()(x, 2), np.less_equal(x, 2)) + self.assertAllClose(knp.LessEqual()(2, x), np.less_equal(2, x)) + + def test_linspace(self): + self.assertAllClose(knp.linspace(0, 10, 5), np.linspace(0, 10, 5)) + self.assertAllClose( + knp.linspace(0, 10, 5, endpoint=False), + np.linspace(0, 10, 5, endpoint=False), + ) + self.assertAllClose(knp.Linspace(num=5)(0, 10), np.linspace(0, 10, 5)) + self.assertAllClose( + knp.Linspace(num=5, endpoint=False)(0, 10), + np.linspace(0, 10, 5, endpoint=False), + ) + self.assertAllClose( + knp.Linspace(num=0, endpoint=False)(0, 10), + np.linspace(0, 10, 0, endpoint=False), + ) + + start = np.zeros([2, 3, 4]) + stop = np.ones([2, 3, 4]) + self.assertAllClose( + knp.linspace(start, stop, 5, retstep=True)[0], + np.linspace(start, stop, 5, retstep=True)[0], + ) + self.assertAllClose( + knp.linspace(start, stop, 5, endpoint=False, retstep=True)[0], + np.linspace(start, stop, 5, endpoint=False, retstep=True)[0], + ) + self.assertAllClose( + knp.linspace( + start, stop, 5, endpoint=False, retstep=True, dtype="int32" + )[0], + np.linspace( + start, stop, 5, endpoint=False, retstep=True, dtype="int32" + )[0], + ) + + self.assertAllClose( + knp.Linspace(5, retstep=True)(start, stop)[0], + np.linspace(start, stop, 5, retstep=True)[0], + ) + self.assertAllClose( + knp.Linspace(5, endpoint=False, retstep=True)(start, stop)[0], + np.linspace(start, stop, 5, endpoint=False, retstep=True)[0], + ) + self.assertAllClose( + knp.Linspace(5, endpoint=False, retstep=True, dtype="int32")( + start, stop + )[0], + np.linspace( + start, stop, 5, endpoint=False, retstep=True, dtype="int32" + )[0], + ) + + # Test `num` as a tensor + # https://github.com/keras-team/keras/issues/19772 + self.assertAllClose( + knp.linspace(0, 10, backend.convert_to_tensor(5)), + np.linspace(0, 10, 5), + ) + self.assertAllClose( + knp.linspace(0, 10, backend.convert_to_tensor(5), endpoint=False), + np.linspace(0, 10, 5, endpoint=False), + ) + + def test_logical_and(self): + x = np.array([[True, False], [True, True]]) + y = np.array([[False, False], [True, False]]) + self.assertAllClose(knp.logical_and(x, y), np.logical_and(x, y)) + self.assertAllClose(knp.logical_and(x, True), np.logical_and(x, True)) + self.assertAllClose(knp.logical_and(True, x), np.logical_and(True, x)) + + self.assertAllClose(knp.LogicalAnd()(x, y), np.logical_and(x, y)) + self.assertAllClose(knp.LogicalAnd()(x, True), np.logical_and(x, True)) + self.assertAllClose(knp.LogicalAnd()(True, x), np.logical_and(True, x)) + + def test_logical_or(self): + x = np.array([[True, False], [True, True]]) + y = np.array([[False, False], [True, False]]) + self.assertAllClose(knp.logical_or(x, y), np.logical_or(x, y)) + self.assertAllClose(knp.logical_or(x, True), np.logical_or(x, True)) + self.assertAllClose(knp.logical_or(True, x), np.logical_or(True, x)) + + self.assertAllClose(knp.LogicalOr()(x, y), np.logical_or(x, y)) + self.assertAllClose(knp.LogicalOr()(x, True), np.logical_or(x, True)) + self.assertAllClose(knp.LogicalOr()(True, x), np.logical_or(True, x)) + + def test_logspace(self): + self.assertAllClose(knp.logspace(0, 10, 5), np.logspace(0, 10, 5)) + self.assertAllClose( + knp.logspace(0, 10, 5, endpoint=False), + np.logspace(0, 10, 5, endpoint=False), + ) + self.assertAllClose(knp.Logspace(num=5)(0, 10), np.logspace(0, 10, 5)) + self.assertAllClose( + knp.Logspace(num=5, endpoint=False)(0, 10), + np.logspace(0, 10, 5, endpoint=False), + ) + + start = np.zeros([2, 3, 4]) + stop = np.ones([2, 3, 4]) + + self.assertAllClose( + knp.logspace(start, stop, 5, base=10), + np.logspace(start, stop, 5, base=10), + ) + self.assertAllClose( + knp.logspace(start, stop, 5, endpoint=False, base=10), + np.logspace(start, stop, 5, endpoint=False, base=10), + ) + + self.assertAllClose( + knp.Logspace(5, base=10)(start, stop), + np.logspace(start, stop, 5, base=10), + ) + self.assertAllClose( + knp.Logspace(5, endpoint=False, base=10)(start, stop), + np.logspace(start, stop, 5, endpoint=False, base=10), + ) + + def test_maximum(self): + x = np.array([[1, 2], [3, 4]]) + y = np.array([[5, 6], [7, 8]]) + self.assertAllClose(knp.maximum(x, y), np.maximum(x, y)) + self.assertAllClose(knp.maximum(x, 1), np.maximum(x, 1)) + self.assertAllClose(knp.maximum(1, x), np.maximum(1, x)) + + self.assertAllClose(knp.Maximum()(x, y), np.maximum(x, y)) + self.assertAllClose(knp.Maximum()(x, 1), np.maximum(x, 1)) + self.assertAllClose(knp.Maximum()(1, x), np.maximum(1, x)) + + def test_minimum(self): + x = np.array([[1, 2], [3, 4]]) + y = np.array([[5, 6], [7, 8]]) + self.assertAllClose(knp.minimum(x, y), np.minimum(x, y)) + self.assertAllClose(knp.minimum(x, 1), np.minimum(x, 1)) + self.assertAllClose(knp.minimum(1, x), np.minimum(1, x)) + + self.assertAllClose(knp.Minimum()(x, y), np.minimum(x, y)) + self.assertAllClose(knp.Minimum()(x, 1), np.minimum(x, 1)) + self.assertAllClose(knp.Minimum()(1, x), np.minimum(1, x)) + + def test_mod(self): + x = np.array([[1, 2], [3, 4]]) + y = np.array([[5, 6], [7, 8]]) + self.assertAllClose(knp.mod(x, y), np.mod(x, y)) + self.assertAllClose(knp.mod(x, 1), np.mod(x, 1)) + self.assertAllClose(knp.mod(1, x), np.mod(1, x)) + + self.assertAllClose(knp.Mod()(x, y), np.mod(x, y)) + self.assertAllClose(knp.Mod()(x, 1), np.mod(x, 1)) + self.assertAllClose(knp.Mod()(1, x), np.mod(1, x)) + + def test_not_equal(self): + x = np.array([[1, 2], [3, 4]]) + y = np.array([[5, 6], [7, 8]]) + self.assertAllClose(knp.not_equal(x, y), np.not_equal(x, y)) + self.assertAllClose(knp.not_equal(x, 1), np.not_equal(x, 1)) + self.assertAllClose(knp.not_equal(1, x), np.not_equal(1, x)) + + self.assertAllClose(knp.NotEqual()(x, y), np.not_equal(x, y)) + self.assertAllClose(knp.NotEqual()(x, 1), np.not_equal(x, 1)) + self.assertAllClose(knp.NotEqual()(1, x), np.not_equal(1, x)) + + def test_outer(self): + x = np.array([1, 2, 3]) + y = np.array([4, 5, 6]) + self.assertAllClose(knp.outer(x, y), np.outer(x, y)) + self.assertAllClose(knp.Outer()(x, y), np.outer(x, y)) + + x = np.ones([2, 3, 4]) + y = np.ones([2, 3, 4, 5, 6]) + self.assertAllClose(knp.outer(x, y), np.outer(x, y)) + self.assertAllClose(knp.Outer()(x, y), np.outer(x, y)) + + def test_quantile(self): + x = np.arange(24).reshape([2, 3, 4]).astype("float32") + + # q as scalar + q = np.array(0.5, dtype="float32") + self.assertAllClose(knp.quantile(x, q), np.quantile(x, q)) + self.assertAllClose( + knp.quantile(x, q, keepdims=True), np.quantile(x, q, keepdims=True) + ) + + # q as 1D tensor + q = np.array([0.5, 1.0], dtype="float32") + self.assertAllClose(knp.quantile(x, q), np.quantile(x, q)) + self.assertAllClose( + knp.quantile(x, q, keepdims=True), np.quantile(x, q, keepdims=True) + ) + self.assertAllClose( + knp.quantile(x, q, axis=1), np.quantile(x, q, axis=1) + ) + self.assertAllClose( + knp.quantile(x, q, axis=1, keepdims=True), + np.quantile(x, q, axis=1, keepdims=True), + ) + + # multiple axes + self.assertAllClose( + knp.quantile(x, q, axis=(1, 2)), np.quantile(x, q, axis=(1, 2)) + ) + + # test all supported methods + q = np.array([0.501, 1.0], dtype="float32") + for method in ["linear", "lower", "higher", "midpoint", "nearest"]: + self.assertAllClose( + knp.quantile(x, q, method=method), + np.quantile(x, q, method=method), + ) + self.assertAllClose( + knp.quantile(x, q, axis=1, method=method), + np.quantile(x, q, axis=1, method=method), + ) + + def test_take(self): + x = np.arange(24).reshape([1, 2, 3, 4]) + indices = np.array([0, 1]) + self.assertAllClose(knp.take(x, indices), np.take(x, indices)) + self.assertAllClose(knp.take(x, 0), np.take(x, 0)) + self.assertAllClose(knp.take(x, 0, axis=1), np.take(x, 0, axis=1)) + + self.assertAllClose(knp.Take()(x, indices), np.take(x, indices)) + self.assertAllClose(knp.Take()(x, 0), np.take(x, 0)) + self.assertAllClose(knp.Take(axis=1)(x, 0), np.take(x, 0, axis=1)) + + # Test with multi-dimensional indices + rng = np.random.default_rng(0) + x = rng.standard_normal((2, 3, 4, 5)) + indices = rng.integers(0, 4, (6, 7)) + self.assertAllClose( + knp.take(x, indices, axis=2), np.take(x, indices, axis=2) + ) + + # Test with negative axis + self.assertAllClose( + knp.take(x, indices, axis=-2), np.take(x, indices, axis=-2) + ) + + # Test with axis=None & x.ndim=2 + x = np.array(([1, 2], [3, 4])) + indices = np.array([2, 3]) + self.assertAllClose( + knp.take(x, indices, axis=None), np.take(x, indices, axis=None) + ) + + # Test with negative indices + x = rng.standard_normal((2, 3, 4, 5)) + indices = rng.integers(-3, 0, (6, 7)) + self.assertAllClose( + knp.take(x, indices, axis=2), np.take(x, indices, axis=2) + ) + + @parameterized.named_parameters( + named_product( + [ + {"testcase_name": "axis_none", "axis": None}, + {"testcase_name": "axis_0", "axis": 0}, + {"testcase_name": "axis_1", "axis": 1}, + {"testcase_name": "axis_minus1", "axis": -1}, + ], + dtype=[ + "float16", + "float32", + "float64", + "uint8", + "int8", + "int16", + "int32", + ], + ) + ) + @pytest.mark.skipif( + not backend.SUPPORTS_SPARSE_TENSORS, + reason="Backend does not support sparse tensors.", + ) + def test_take_sparse(self, dtype, axis): + rng = np.random.default_rng(0) + x = (4 * rng.standard_normal((3, 4, 5))).astype(dtype) + + if backend.backend() == "tensorflow": + import tensorflow as tf + + indices = tf.SparseTensor([[0, 0], [1, 2]], [-1, 2], (2, 3)) + elif backend.backend() == "jax": + import jax.experimental.sparse as jax_sparse + + indices = jax_sparse.BCOO(([-1, 2], [[0, 0], [1, 2]]), shape=(2, 3)) + + self.assertAllClose( + knp.take(x, indices, axis=axis), + np.take(x, backend.convert_to_numpy(indices), axis=axis), + ) + + @parameterized.named_parameters( + named_product( + [ + {"testcase_name": "axis_none", "axis": None}, + {"testcase_name": "axis_0", "axis": 0}, + {"testcase_name": "axis_1", "axis": 1}, + {"testcase_name": "axis_minus1", "axis": -1}, + ], + dtype=[ + "float16", + "float32", + "float64", + "uint8", + "int8", + "int16", + "int32", + ], + ) + ) + @pytest.mark.skipif( + not backend.SUPPORTS_RAGGED_TENSORS, + reason="Backend does not support ragged tensors.", + ) + def test_take_ragged(self, dtype, axis): + rng = np.random.default_rng(0) + x = (4 * rng.standard_normal((3, 4, 5))).astype(dtype) + + if backend.backend() == "tensorflow": + import tensorflow as tf + + indices = tf.ragged.constant([[2], [0, -1, 1]]) + mask = backend.convert_to_numpy(tf.ones_like(indices)) + + if axis == 0: + mask = np.expand_dims(mask, (2, 3)) + elif axis == 1: + mask = np.expand_dims(mask, (2,)) + + self.assertAllClose( + knp.take(x, indices, axis=axis), + np.take(x, backend.convert_to_numpy(indices), axis=axis) + * mask.astype(dtype), + ) + + def test_take_along_axis(self): + x = np.arange(24).reshape([1, 2, 3, 4]) + indices = np.ones([1, 4, 1, 1], dtype=np.int32) + self.assertAllClose( + knp.take_along_axis(x, indices, axis=1), + np.take_along_axis(x, indices, axis=1), + ) + self.assertAllClose( + knp.TakeAlongAxis(axis=1)(x, indices), + np.take_along_axis(x, indices, axis=1), + ) + + x = np.arange(12).reshape([1, 1, 3, 4]) + indices = np.ones([1, 4, 1, 1], dtype=np.int32) + self.assertAllClose( + knp.take_along_axis(x, indices, axis=2), + np.take_along_axis(x, indices, axis=2), + ) + self.assertAllClose( + knp.TakeAlongAxis(axis=2)(x, indices), + np.take_along_axis(x, indices, axis=2), + ) + + # Test with axis=None + x = np.arange(12).reshape([1, 1, 3, 4]) + indices = np.array([1, 2, 3], dtype=np.int32) + self.assertAllClose( + knp.take_along_axis(x, indices, axis=None), + np.take_along_axis(x, indices, axis=None), + ) + self.assertAllClose( + knp.TakeAlongAxis(axis=None)(x, indices), + np.take_along_axis(x, indices, axis=None), + ) + + # Test with negative indices + x = np.arange(12).reshape([1, 1, 3, 4]) + indices = np.full([1, 4, 1, 1], -1, dtype=np.int32) + self.assertAllClose( + knp.take_along_axis(x, indices, axis=2), + np.take_along_axis(x, indices, axis=2), + ) + self.assertAllClose( + knp.TakeAlongAxis(axis=2)(x, indices), + np.take_along_axis(x, indices, axis=2), + ) + + def test_tensordot(self): + x = np.arange(24).reshape([1, 2, 3, 4]).astype("float32") + y = np.arange(24).reshape([3, 4, 1, 2]).astype("float32") + self.assertAllClose( + knp.tensordot(x, y, axes=2), np.tensordot(x, y, axes=2) + ) + self.assertAllClose( + knp.tensordot(x, y, axes=([0, 1], [2, 3])), + np.tensordot(x, y, axes=([0, 1], [2, 3])), + ) + self.assertAllClose( + knp.Tensordot(axes=2)(x, y), + np.tensordot(x, y, axes=2), + ) + self.assertAllClose( + knp.Tensordot(axes=([0, 1], [2, 3]))(x, y), + np.tensordot(x, y, axes=([0, 1], [2, 3])), + ) + self.assertAllClose( + knp.Tensordot(axes=[0, 2])(x, y), + np.tensordot(x, y, axes=[0, 2]), + ) + + def test_vdot(self): + x = np.array([1.0, 2.0, 3.0]) + y = np.array([4.0, 5.0, 6.0]) + self.assertAllClose(knp.vdot(x, y), np.vdot(x, y)) + self.assertAllClose(knp.Vdot()(x, y), np.vdot(x, y)) + + def test_inner(self): + x = np.array([1.0, 2.0, 3.0]) + y = np.array([4.0, 5.0, 6.0]) + self.assertAllClose(knp.inner(x, y), np.inner(x, y)) + self.assertAllClose(knp.Inner()(x, y), np.inner(x, y)) + + def test_where(self): + x = np.array([1, 2, 3]) + y = np.array([4, 5, 6]) + self.assertAllClose(knp.where(x > 1, x, y), np.where(x > 1, x, y)) + self.assertAllClose(knp.Where()(x > 1, x, y), np.where(x > 1, x, y)) + self.assertAllClose(knp.where(x > 1), np.where(x > 1)) + self.assertAllClose(knp.Where()(x > 1), np.where(x > 1)) + + with self.assertRaisesRegex( + ValueError, "`x1` and `x2` either both should be `None`" + ): + knp.where(x > 1, x, None) + + def test_digitize(self): + x = np.array([0.0, 1.0, 3.0, 1.6]) + bins = np.array([0.0, 3.0, 4.5, 7.0]) + self.assertAllClose(knp.digitize(x, bins), np.digitize(x, bins)) + self.assertAllClose(knp.Digitize()(x, bins), np.digitize(x, bins)) + self.assertTrue( + standardize_dtype(knp.digitize(x, bins).dtype) == "int32" + ) + self.assertTrue( + standardize_dtype(knp.Digitize()(x, bins).dtype) == "int32" + ) + + x = np.array([0.2, 6.4, 3.0, 1.6]) + bins = np.array([0.0, 1.0, 2.5, 4.0, 10.0]) + self.assertAllClose(knp.digitize(x, bins), np.digitize(x, bins)) + self.assertAllClose(knp.Digitize()(x, bins), np.digitize(x, bins)) + self.assertTrue( + standardize_dtype(knp.digitize(x, bins).dtype) == "int32" + ) + self.assertTrue( + standardize_dtype(knp.Digitize()(x, bins).dtype) == "int32" + ) + + x = np.array([1, 4, 10, 15]) + bins = np.array([4, 10, 14, 15]) + self.assertAllClose(knp.digitize(x, bins), np.digitize(x, bins)) + self.assertAllClose(knp.Digitize()(x, bins), np.digitize(x, bins)) + self.assertTrue( + standardize_dtype(knp.digitize(x, bins).dtype) == "int32" + ) + self.assertTrue( + standardize_dtype(knp.Digitize()(x, bins).dtype) == "int32" + ) + + +class NumpyOneInputOpsCorrectnessTest(testing.TestCase): + def test_mean(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + self.assertAllClose(knp.mean(x), np.mean(x)) + self.assertAllClose(knp.mean(x, axis=()), np.mean(x, axis=())) + self.assertAllClose(knp.mean(x, axis=1), np.mean(x, axis=1)) + self.assertAllClose(knp.mean(x, axis=(1,)), np.mean(x, axis=(1,))) + self.assertAllClose( + knp.mean(x, axis=1, keepdims=True), + np.mean(x, axis=1, keepdims=True), + ) + + self.assertAllClose(knp.Mean()(x), np.mean(x)) + self.assertAllClose(knp.Mean(axis=1)(x), np.mean(x, axis=1)) + self.assertAllClose( + knp.Mean(axis=1, keepdims=True)(x), + np.mean(x, axis=1, keepdims=True), + ) + + # test overflow + x = np.array([65504, 65504, 65504], dtype="float16") + self.assertAllClose(knp.mean(x), np.mean(x)) + + def test_all(self): + x = np.array([[True, False, True], [True, True, True]]) + self.assertAllClose(knp.all(x), np.all(x)) + self.assertAllClose(knp.all(x, axis=()), np.all(x, axis=())) + self.assertAllClose(knp.all(x, axis=1), np.all(x, axis=1)) + self.assertAllClose(knp.all(x, axis=(1,)), np.all(x, axis=(1,))) + self.assertAllClose( + knp.all(x, axis=1, keepdims=True), + np.all(x, axis=1, keepdims=True), + ) + + self.assertAllClose(knp.All()(x), np.all(x)) + self.assertAllClose(knp.All(axis=1)(x), np.all(x, axis=1)) + self.assertAllClose( + knp.All(axis=1, keepdims=True)(x), + np.all(x, axis=1, keepdims=True), + ) + + def test_any(self): + x = np.array([[True, False, True], [True, True, True]]) + self.assertAllClose(knp.any(x), np.any(x)) + self.assertAllClose(knp.any(x, axis=()), np.any(x, axis=())) + self.assertAllClose(knp.any(x, axis=1), np.any(x, axis=1)) + self.assertAllClose(knp.any(x, axis=(1,)), np.any(x, axis=(1,))) + self.assertAllClose( + knp.any(x, axis=1, keepdims=True), + np.any(x, axis=1, keepdims=True), + ) + + self.assertAllClose(knp.Any()(x), np.any(x)) + self.assertAllClose(knp.Any(axis=1)(x), np.any(x, axis=1)) + self.assertAllClose( + knp.Any(axis=1, keepdims=True)(x), + np.any(x, axis=1, keepdims=True), + ) + + def test_var(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + self.assertAllClose(knp.var(x), np.var(x)) + self.assertAllClose(knp.var(x, axis=()), np.var(x, axis=())) + self.assertAllClose(knp.var(x, axis=1), np.var(x, axis=1)) + self.assertAllClose(knp.var(x, axis=(1,)), np.var(x, axis=(1,))) + self.assertAllClose( + knp.var(x, axis=1, keepdims=True), + np.var(x, axis=1, keepdims=True), + ) + + self.assertAllClose(knp.Var()(x), np.var(x)) + self.assertAllClose(knp.Var(axis=1)(x), np.var(x, axis=1)) + self.assertAllClose( + knp.Var(axis=1, keepdims=True)(x), + np.var(x, axis=1, keepdims=True), + ) + + def test_sum(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + self.assertAllClose(knp.sum(x), np.sum(x)) + self.assertAllClose(knp.sum(x, axis=()), np.sum(x, axis=())) + self.assertAllClose(knp.sum(x, axis=1), np.sum(x, axis=1)) + self.assertAllClose(knp.sum(x, axis=(1,)), np.sum(x, axis=(1,))) + self.assertAllClose( + knp.sum(x, axis=1, keepdims=True), + np.sum(x, axis=1, keepdims=True), + ) + + self.assertAllClose(knp.Sum()(x), np.sum(x)) + self.assertAllClose(knp.Sum(axis=1)(x), np.sum(x, axis=1)) + self.assertAllClose( + knp.Sum(axis=1, keepdims=True)(x), + np.sum(x, axis=1, keepdims=True), + ) + + def test_amax(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + self.assertAllClose(knp.amax(x), np.amax(x)) + self.assertAllClose(knp.amax(x, axis=()), np.amax(x, axis=())) + self.assertAllClose(knp.amax(x, axis=1), np.amax(x, axis=1)) + self.assertAllClose(knp.amax(x, axis=(1,)), np.amax(x, axis=(1,))) + self.assertAllClose( + knp.amax(x, axis=1, keepdims=True), + np.amax(x, axis=1, keepdims=True), + ) + + self.assertAllClose(knp.Amax()(x), np.amax(x)) + self.assertAllClose(knp.Amax(axis=1)(x), np.amax(x, axis=1)) + self.assertAllClose( + knp.Amax(axis=1, keepdims=True)(x), + np.amax(x, axis=1, keepdims=True), + ) + + def test_amin(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + self.assertAllClose(knp.amin(x), np.amin(x)) + self.assertAllClose(knp.amin(x, axis=()), np.amin(x, axis=())) + self.assertAllClose(knp.amin(x, axis=1), np.amin(x, axis=1)) + self.assertAllClose(knp.amin(x, axis=(1,)), np.amin(x, axis=(1,))) + self.assertAllClose( + knp.amin(x, axis=1, keepdims=True), + np.amin(x, axis=1, keepdims=True), + ) + + self.assertAllClose(knp.Amin()(x), np.amin(x)) + self.assertAllClose(knp.Amin(axis=1)(x), np.amin(x, axis=1)) + self.assertAllClose( + knp.Amin(axis=1, keepdims=True)(x), + np.amin(x, axis=1, keepdims=True), + ) + + def test_square(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + self.assertAllClose(knp.square(x), np.square(x)) + + self.assertAllClose(knp.Square()(x), np.square(x)) + + def test_negative(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + self.assertAllClose(knp.negative(x), np.negative(x)) + + self.assertAllClose(knp.Negative()(x), np.negative(x)) + + def test_abs(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + self.assertAllClose(knp.abs(x), np.abs(x)) + + self.assertAllClose(knp.Abs()(x), np.abs(x)) + + def test_absolute(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + self.assertAllClose(knp.absolute(x), np.absolute(x)) + + self.assertAllClose(knp.Absolute()(x), np.absolute(x)) + + def test_squeeze(self): + x = np.ones([1, 3, 1, 5]) + self.assertAllClose(knp.squeeze(x), np.squeeze(x)) + self.assertAllClose(knp.squeeze(x, axis=0), np.squeeze(x, axis=0)) + + self.assertAllClose(knp.Squeeze()(x), np.squeeze(x)) + self.assertAllClose(knp.Squeeze(axis=0)(x), np.squeeze(x, axis=0)) + + # Multiple axes + x = np.ones([2, 1, 1, 1]) + self.assertAllClose(knp.squeeze(x, (1, 2)), np.squeeze(x, (1, 2))) + self.assertAllClose(knp.squeeze(x, (-1, -2)), np.squeeze(x, (-1, -2))) + self.assertAllClose(knp.squeeze(x, (1, 2, 3)), np.squeeze(x, (1, 2, 3))) + self.assertAllClose(knp.squeeze(x, (-1, 1)), np.squeeze(x, (-1, 1))) + + self.assertAllClose(knp.Squeeze((1, 2))(x), np.squeeze(x, (1, 2))) + self.assertAllClose(knp.Squeeze((-1, -2))(x), np.squeeze(x, (-1, -2))) + self.assertAllClose(knp.Squeeze((1, 2, 3))(x), np.squeeze(x, (1, 2, 3))) + self.assertAllClose(knp.Squeeze((-1, 1))(x), np.squeeze(x, (-1, 1))) + + def test_transpose(self): + x = np.ones([1, 2, 3, 4, 5]) + self.assertAllClose(knp.transpose(x), np.transpose(x)) + self.assertAllClose( + knp.transpose(x, axes=(1, 0, 3, 2, 4)), + np.transpose(x, axes=(1, 0, 3, 2, 4)), + ) + + self.assertAllClose(knp.Transpose()(x), np.transpose(x)) + self.assertAllClose( + knp.Transpose(axes=(1, 0, 3, 2, 4))(x), + np.transpose(x, axes=(1, 0, 3, 2, 4)), + ) + + def test_arccos(self): + x = np.array([[1, 0.5, -0.7], [0.9, 0.2, -1]]) + self.assertAllClose(knp.arccos(x), np.arccos(x)) + + self.assertAllClose(knp.Arccos()(x), np.arccos(x)) + + def test_arccosh(self): + x = np.array([[1, 0.5, -0.7], [0.9, 0.2, -1]]) + self.assertAllClose(knp.arccosh(x), np.arccosh(x)) + + self.assertAllClose(knp.Arccosh()(x), np.arccosh(x)) + + def test_arcsin(self): + x = np.array([[1, 0.5, -0.7], [0.9, 0.2, -1]]) + self.assertAllClose(knp.arcsin(x), np.arcsin(x)) + + self.assertAllClose(knp.Arcsin()(x), np.arcsin(x)) + + def test_arcsinh(self): + x = np.array([[1, 0.5, -0.7], [0.9, 0.2, -1]]) + self.assertAllClose(knp.arcsinh(x), np.arcsinh(x)) + + self.assertAllClose(knp.Arcsinh()(x), np.arcsinh(x)) + + def test_arctan(self): + x = np.array([[1, 0.5, -0.7], [0.9, 0.2, -1]]) + self.assertAllClose(knp.arctan(x), np.arctan(x)) + + self.assertAllClose(knp.Arctan()(x), np.arctan(x)) + + def test_arctanh(self): + x = np.array([[1, 0.5, -0.7], [0.9, 0.2, -1]]) + self.assertAllClose(knp.arctanh(x), np.arctanh(x)) + + self.assertAllClose(knp.Arctanh()(x), np.arctanh(x)) + + def test_argmax(self): + x = np.array([[1, 2, 3], [3, 2, 1], [4, 5, 6]]) + self.assertAllClose(knp.argmax(x), np.argmax(x)) + self.assertAllClose(knp.argmax(x, axis=1), np.argmax(x, axis=1)) + self.assertAllClose( + knp.argmax(x, axis=1, keepdims=True), + np.argmax(x, axis=1, keepdims=True), + ) + self.assertAllClose( + knp.argmax(x, keepdims=True), np.argmax(x, keepdims=True) + ) + + self.assertAllClose(knp.Argmax()(x), np.argmax(x)) + self.assertAllClose(knp.Argmax(axis=1)(x), np.argmax(x, axis=1)) + + self.assertAllClose(knp.Argmax()(x), np.argmax(x)) + self.assertAllClose( + knp.Argmax(keepdims=True)(x), np.argmax(x, keepdims=True) + ) + + def test_argmin(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + self.assertAllClose(knp.argmin(x), np.argmin(x)) + self.assertAllClose(knp.argmin(x, axis=1), np.argmin(x, axis=1)) + self.assertAllClose( + knp.argmin(x, keepdims=True), np.argmin(x, keepdims=True) + ) + + self.assertAllClose(knp.Argmin()(x), np.argmin(x)) + self.assertAllClose(knp.Argmin(axis=1)(x), np.argmin(x, axis=1)) + self.assertAllClose( + knp.Argmin(keepdims=True)(x), np.argmin(x, keepdims=True) + ) + + def test_argsort(self): + x = np.array([[1, 2, 3], [4, 5, 6]]) + self.assertAllClose(knp.argsort(x), np.argsort(x)) + self.assertAllClose(knp.argsort(x, axis=1), np.argsort(x, axis=1)) + self.assertAllClose(knp.argsort(x, axis=None), np.argsort(x, axis=None)) + + self.assertAllClose(knp.Argsort()(x), np.argsort(x)) + self.assertAllClose(knp.Argsort(axis=1)(x), np.argsort(x, axis=1)) + self.assertAllClose(knp.Argsort(axis=None)(x), np.argsort(x, axis=None)) + + x = np.array(1) # rank == 0 + self.assertAllClose(knp.argsort(x), np.argsort(x)) + self.assertAllClose(knp.Argsort()(x), np.argsort(x)) + + def test_array(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + self.assertAllClose(knp.array(x), np.array(x)) + self.assertAllClose(knp.Array()(x), np.array(x)) + self.assertTrue(backend.is_tensor(knp.array(x))) + self.assertTrue(backend.is_tensor(knp.Array()(x))) + + # Check dtype conversion. + x = [[1, 0, 1], [1, 1, 0]] + output = knp.array(x, dtype="int32") + self.assertEqual(standardize_dtype(output.dtype), "int32") + x = [[1, 0, 1], [1, 1, 0]] + output = knp.array(x, dtype="float32") + self.assertEqual(standardize_dtype(output.dtype), "float32") + x = [[1, 0, 1], [1, 1, 0]] + output = knp.array(x, dtype="bool") + self.assertEqual(standardize_dtype(output.dtype), "bool") + + def test_average(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + weights = np.ones([2, 3]) + weights_1d = np.ones([3]) + self.assertAllClose(knp.average(x), np.average(x)) + self.assertAllClose(knp.average(x, axis=()), np.average(x, axis=())) + self.assertAllClose(knp.average(x, axis=1), np.average(x, axis=1)) + self.assertAllClose(knp.average(x, axis=(1,)), np.average(x, axis=(1,))) + self.assertAllClose( + knp.average(x, axis=1, weights=weights), + np.average(x, axis=1, weights=weights), + ) + self.assertAllClose( + knp.average(x, axis=1, weights=weights_1d), + np.average(x, axis=1, weights=weights_1d), + ) + + self.assertAllClose(knp.Average()(x), np.average(x)) + self.assertAllClose(knp.Average(axis=1)(x), np.average(x, axis=1)) + self.assertAllClose( + knp.Average(axis=1)(x, weights=weights), + np.average(x, axis=1, weights=weights), + ) + self.assertAllClose( + knp.Average(axis=1)(x, weights=weights_1d), + np.average(x, axis=1, weights=weights_1d), + ) + + def test_bartlett(self): + x = np.random.randint(1, 100 + 1) + self.assertAllClose(knp.bartlett(x), np.bartlett(x)) + + self.assertAllClose(knp.Bartlett()(x), np.bartlett(x)) + + def test_blackman(self): + x = np.random.randint(1, 100 + 1) + self.assertAllClose(knp.blackman(x), np.blackman(x)) + + self.assertAllClose(knp.Blackman()(x), np.blackman(x)) + + def test_hamming(self): + x = np.random.randint(1, 100 + 1) + self.assertAllClose(knp.hamming(x), np.hamming(x)) + + self.assertAllClose(knp.Hamming()(x), np.hamming(x)) + + def test_hanning(self): + x = np.random.randint(1, 100 + 1) + self.assertAllClose(knp.hanning(x), np.hanning(x)) + + self.assertAllClose(knp.Hanning()(x), np.hanning(x)) + + def test_kaiser(self): + x = np.random.randint(1, 100 + 1) + beta = float(np.random.randint(10, 20 + 1)) + self.assertAllClose(knp.kaiser(x, beta), np.kaiser(x, beta)) + + self.assertAllClose(knp.Kaiser(beta)(x), np.kaiser(x, beta)) + + @parameterized.named_parameters( + named_product(sparse_input=(False, True), sparse_arg=(False, True)) + ) + def test_bincount(self, sparse_input, sparse_arg): + if (sparse_input or sparse_arg) and not backend.SUPPORTS_SPARSE_TENSORS: + pytest.skip("Backend does not support sparse tensors") + if testing.tensorflow_uses_gpu(): + self.skipTest("bincount does not work in tensorflow gpu") + + x = x_np = np.array([1, 1, 2, 3, 2, 4, 4, 6]) + weights = weights_np = np.array([0, 0, 3, 2, 1, 1, 4, 2]) + if sparse_input: + indices = np.array([[1], [3], [5], [7], [9], [11], [13], [15]]) + + if backend.backend() == "tensorflow": + import tensorflow as tf + + x = tf.SparseTensor(indices, x, (16,)) + weights = tf.SparseTensor(indices, weights, (16,)) + elif backend.backend() == "jax": + from jax.experimental import sparse as jax_sparse + + x = jax_sparse.BCOO((x, indices), shape=(16,)) + weights = jax_sparse.BCOO((weights, indices), shape=(16,)) + + minlength = 3 + output = knp.bincount( + x, weights=weights, minlength=minlength, sparse=sparse_arg + ) + self.assertAllClose( + output, np.bincount(x_np, weights=weights_np, minlength=minlength) + ) + self.assertSparse(output, sparse_input or sparse_arg) + output = knp.Bincount( + weights=weights, minlength=minlength, sparse=sparse_arg + )(x) + self.assertAllClose( + output, np.bincount(x_np, weights=weights_np, minlength=minlength) + ) + self.assertSparse(output, sparse_input or sparse_arg) + + x = knp.expand_dims(x, 0) + weights = knp.expand_dims(weights, 0) + + expected_output = np.array([[0, 0, 4, 2, 5, 0, 2]]) + output = knp.bincount( + x, weights=weights, minlength=minlength, sparse=sparse_arg + ) + self.assertAllClose(output, expected_output) + self.assertSparse(output, sparse_input or sparse_arg) + output = knp.Bincount( + weights=weights, minlength=minlength, sparse=sparse_arg + )(x) + self.assertAllClose(output, expected_output) + self.assertSparse(output, sparse_input or sparse_arg) + + # test with weights=None + expected_output = np.array([[0, 2, 2, 1, 2, 0, 1]]) + output = knp.Bincount( + weights=None, minlength=minlength, sparse=sparse_arg + )(x) + self.assertAllClose(output, expected_output) + self.assertSparse(output, sparse_input or sparse_arg) + + def test_bitwise_invert(self): + x = np.array([2, 5, 255]) + self.assertAllClose(knp.bitwise_invert(x), np.bitwise_not(x)) + self.assertAllClose(knp.BitwiseInvert()(x), np.bitwise_not(x)) + + # bitwise_not is same as bitwise_invert + + def test_broadcast_to(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + self.assertAllClose( + knp.broadcast_to(x, [2, 2, 3]), + np.broadcast_to(x, [2, 2, 3]), + ) + + self.assertAllClose( + knp.BroadcastTo([2, 2, 3])(x), + np.broadcast_to(x, [2, 2, 3]), + ) + + def test_cbrt(self): + x = np.array([[-8, -1, 0], [1, 8, 27]], dtype="float32") + ref_y = np.sign(x) * np.abs(x) ** (1.0 / 3.0) + y = knp.cbrt(x) + self.assertEqual(standardize_dtype(y.dtype), "float32") + self.assertAllClose(y, ref_y) + + y = knp.Cbrt()(x) + self.assertEqual(standardize_dtype(y.dtype), "float32") + self.assertAllClose(y, ref_y) + + def test_ceil(self): + x = np.array([[1.2, 2.1, -2.5], [2.4, -11.9, -5.5]]) + self.assertAllClose(knp.ceil(x), np.ceil(x)) + self.assertAllClose(knp.Ceil()(x), np.ceil(x)) + + def test_clip(self): + x = np.array([[1.2, 2.1, 0.5], [2.4, 11.9, 0.5]]) + self.assertAllClose(knp.clip(x, 1, 2), np.clip(x, 1, 2)) + self.assertAllClose(knp.clip(x, 1, 2), np.clip(x, 1, 2)) + + self.assertAllClose(knp.Clip(0, 1)(x), np.clip(x, 0, 1)) + self.assertAllClose(knp.Clip(0, 1)(x), np.clip(x, 0, 1)) + + def test_concatenate(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + y = np.array([[4, 5, 6], [6, 5, 4]]) + z = np.array([[7, 8, 9], [9, 8, 7]]) + self.assertAllClose( + knp.concatenate([x, y], axis=0), + np.concatenate([x, y], axis=0), + ) + self.assertAllClose( + knp.concatenate([x, y, z], axis=0), + np.concatenate([x, y, z], axis=0), + ) + self.assertAllClose( + knp.concatenate([x, y], axis=1), + np.concatenate([x, y], axis=1), + ) + + self.assertAllClose( + knp.Concatenate(axis=0)([x, y]), + np.concatenate([x, y], axis=0), + ) + self.assertAllClose( + knp.Concatenate(axis=0)([x, y, z]), + np.concatenate([x, y, z], axis=0), + ) + self.assertAllClose( + knp.Concatenate(axis=1)([x, y]), + np.concatenate([x, y], axis=1), + ) + + @parameterized.named_parameters( + [ + {"testcase_name": "axis_0", "axis": 0}, + {"testcase_name": "axis_1", "axis": 1}, + ] + ) + @pytest.mark.skipif( + not backend.SUPPORTS_SPARSE_TENSORS, + reason="Backend does not support sparse tensors.", + ) + def test_concatenate_sparse(self, axis): + if backend.backend() == "tensorflow": + import tensorflow as tf + + x = tf.SparseTensor([[0, 0], [1, 2]], [1.0, 2.0], (2, 3)) + y = tf.SparseTensor([[0, 0], [1, 1]], [4.0, 5.0], (2, 3)) + elif backend.backend() == "jax": + import jax.experimental.sparse as jax_sparse + + x = jax_sparse.BCOO(([1.0, 2.0], [[0, 0], [1, 2]]), shape=(2, 3)) + y = jax_sparse.BCOO(([4.0, 5.0], [[0, 0], [1, 1]]), shape=(2, 3)) + + x_np = backend.convert_to_numpy(x) + y_np = backend.convert_to_numpy(y) + z = np.random.rand(2, 3).astype("float32") + + self.assertAllClose( + knp.concatenate([x, z], axis=axis), + np.concatenate([x_np, z], axis=axis), + ) + self.assertAllClose( + knp.concatenate([z, x], axis=axis), + np.concatenate([z, x_np], axis=axis), + ) + self.assertAllClose( + knp.concatenate([x, y], axis=axis), + np.concatenate([x_np, y_np], axis=axis), + ) + + self.assertAllClose( + knp.Concatenate(axis=axis)([x, z]), + np.concatenate([x_np, z], axis=axis), + ) + self.assertAllClose( + knp.Concatenate(axis=axis)([z, x]), + np.concatenate([z, x_np], axis=axis), + ) + self.assertAllClose( + knp.Concatenate(axis=axis)([x, y]), + np.concatenate([x_np, y_np], axis=axis), + ) + + self.assertSparse(knp.concatenate([x, y], axis=axis)) + self.assertSparse(knp.Concatenate(axis=axis)([x, y])) + + def test_conjugate(self): + x = np.array([[1 + 2j, 2 + 3j], [3 + 4j, 4 + 5j]]) + self.assertAllClose(knp.conjugate(x), np.conjugate(x)) + self.assertAllClose(knp.Conjugate()(x), np.conjugate(x)) + + def test_conj(self): + x = np.array([[1 + 2j, 2 + 3j], [3 + 4j, 4 + 5j]]) + self.assertAllClose(knp.conj(x), np.conj(x)) + self.assertAllClose(knp.Conj()(x), np.conj(x)) + + def test_copy(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + self.assertAllClose(knp.copy(x), np.copy(x)) + self.assertAllClose(knp.Copy()(x), np.copy(x)) + + def test_corrcoef(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + self.assertAllClose(knp.corrcoef(x), np.corrcoef(x)) + self.assertAllClose(knp.Corrcoef()(x), np.corrcoef(x)) + + def test_cos(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + self.assertAllClose(knp.cos(x), np.cos(x)) + self.assertAllClose(knp.Cos()(x), np.cos(x)) + + def test_cosh(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + self.assertAllClose(knp.cosh(x), np.cosh(x)) + self.assertAllClose(knp.Cosh()(x), np.cosh(x)) + + def test_count_nonzero(self): + x = np.array([[0, 2, 3], [3, 2, 0]]) + self.assertAllClose(knp.count_nonzero(x), np.count_nonzero(x)) + self.assertAllClose( + knp.count_nonzero(x, axis=()), np.count_nonzero(x, axis=()) + ) + self.assertAllClose( + knp.count_nonzero(x, axis=1), + np.count_nonzero(x, axis=1), + ) + self.assertAllClose( + knp.count_nonzero(x, axis=(1,)), + np.count_nonzero(x, axis=(1,)), + ) + + self.assertAllClose( + knp.CountNonzero()(x), + np.count_nonzero(x), + ) + self.assertAllClose( + knp.CountNonzero(axis=1)(x), + np.count_nonzero(x, axis=1), + ) + + @parameterized.product( + axis=[None, 0, 1, -1], + dtype=[None, "int32", "float32"], + ) + def test_cumprod(self, axis, dtype): + x = np.array([[1, 2, 3], [3, 2, 1]]) + self.assertAllClose( + knp.cumprod(x, axis=axis, dtype=dtype), + np.cumprod(x, axis=axis, dtype=dtype or x.dtype), + ) + self.assertAllClose( + knp.Cumprod(axis=axis, dtype=dtype)(x), + np.cumprod(x, axis=axis, dtype=dtype or x.dtype), + ) + + @parameterized.product( + axis=[None, 0, 1, -1], + dtype=[None, "int32", "float32"], + ) + def test_cumsum(self, axis, dtype): + x = np.array([[1, 2, 3], [3, 2, 1]]) + self.assertAllClose( + knp.cumsum(x, axis=axis, dtype=dtype), + np.cumsum(x, axis=axis, dtype=dtype or x.dtype), + ) + self.assertAllClose( + knp.Cumsum(axis=axis, dtype=dtype)(x), + np.cumsum(x, axis=axis, dtype=dtype or x.dtype), + ) + + def test_deg2rad(self): + x = np.random.uniform(-360, 360, size=(3, 3)) + self.assertAllClose(knp.deg2rad(x), np.deg2rad(x)) + self.assertAllClose(knp.Deg2rad()(x), np.deg2rad(x)) + + def test_diag(self): + x = np.array([1, 2, 3]) + self.assertAllClose(knp.diag(x), np.diag(x)) + self.assertAllClose(knp.diag(x, k=1), np.diag(x, k=1)) + self.assertAllClose(knp.diag(x, k=-1), np.diag(x, k=-1)) + + self.assertAllClose(knp.Diag()(x), np.diag(x)) + self.assertAllClose(knp.Diag(k=1)(x), np.diag(x, k=1)) + self.assertAllClose(knp.Diag(k=-1)(x), np.diag(x, k=-1)) + + x = np.array([[1, 2, 3], [3, 2, 1]]) + self.assertAllClose(knp.diag(x), np.diag(x)) + self.assertAllClose(knp.diag(x, k=1), np.diag(x, k=1)) + self.assertAllClose(knp.diag(x, k=-1), np.diag(x, k=-1)) + + self.assertAllClose(knp.Diag()(x), np.diag(x)) + self.assertAllClose(knp.Diag(k=1)(x), np.diag(x, k=1)) + self.assertAllClose(knp.Diag(k=-1)(x), np.diag(x, k=-1)) + + def test_diagflat(self): + x = np.array([1, 2, 3]) + self.assertAllClose(knp.diagflat(x), np.diagflat(x)) + self.assertAllClose(knp.diagflat(x, k=1), np.diagflat(x, k=1)) + self.assertAllClose(knp.diagflat(x, k=-1), np.diagflat(x, k=-1)) + + x = np.array([[1, 2], [3, 4]]) + self.assertAllClose(knp.diagflat(x), np.diagflat(x)) + self.assertAllClose(knp.diagflat(x, k=1), np.diagflat(x, k=1)) + self.assertAllClose(knp.diagflat(x, k=-1), np.diagflat(x, k=-1)) + + x = np.array([1, 2, 3, 4]) + self.assertAllClose(knp.diagflat(x), np.diagflat(x)) + self.assertAllClose(knp.diagflat(x, k=2), np.diagflat(x, k=2)) + self.assertAllClose(knp.diagflat(x, k=-2), np.diagflat(x, k=-2)) + + x_float = np.array([1.1, 2.2, 3.3]) + self.assertAllClose(knp.diagflat(x_float), np.diagflat(x_float)) + + x_complex = np.array([1 + 1j, 2 + 2j, 3 + 3j]) + self.assertAllClose(knp.diagflat(x_complex), np.diagflat(x_complex)) + + x = np.array([1, 2, 3]) + self.assertAllClose(knp.Diagflat()(x), np.diagflat(x)) + self.assertAllClose(knp.Diagflat(k=1)(x), np.diagflat(x, k=1)) + self.assertAllClose(knp.Diagflat(k=-1)(x), np.diagflat(x, k=-1)) + + def test_diagonal(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + self.assertAllClose(knp.diagonal(x), np.diagonal(x)) + self.assertAllClose( + knp.diagonal(x, offset=1), + np.diagonal(x, offset=1), + ) + self.assertAllClose( + knp.diagonal(x, offset=-1), np.diagonal(x, offset=-1) + ) + + self.assertAllClose(knp.Diagonal()(x), np.diagonal(x)) + self.assertAllClose(knp.Diagonal(offset=1)(x), np.diagonal(x, offset=1)) + self.assertAllClose( + knp.Diagonal(offset=-1)(x), np.diagonal(x, offset=-1) + ) + + x = np.ones([2, 3, 4, 5]) + self.assertAllClose(knp.diagonal(x), np.diagonal(x)) + self.assertAllClose( + knp.diagonal(x, offset=1, axis1=2, axis2=3), + np.diagonal(x, offset=1, axis1=2, axis2=3), + ) + self.assertAllClose( + knp.diagonal(x, offset=-1, axis1=2, axis2=3), + np.diagonal(x, offset=-1, axis1=2, axis2=3), + ) + + def test_diff(self): + x = np.array([1, 2, 4, 7, 0]) + self.assertAllClose(knp.diff(x), np.diff(x)) + self.assertAllClose(knp.diff(x, n=2), np.diff(x, n=2)) + self.assertAllClose(knp.diff(x, n=3), np.diff(x, n=3)) + + x = np.array([[1, 3, 6, 10], [0, 5, 6, 8]]) + self.assertAllClose(knp.diff(x), np.diff(x)) + self.assertAllClose(knp.diff(x, axis=0), np.diff(x, axis=0)) + self.assertAllClose(knp.diff(x, n=2, axis=0), np.diff(x, n=2, axis=0)) + self.assertAllClose(knp.diff(x, n=2, axis=1), np.diff(x, n=2, axis=1)) + + # Test n=0 + x = np.array([1, 2, 4, 7, 0]) + self.assertAllClose(knp.diff(x, n=0), np.diff(x, n=0)) + + def test_dot(self): + x = np.arange(24).reshape([2, 3, 4]).astype("float32") + y = np.arange(12).reshape([4, 3]).astype("float32") + z = np.arange(4).astype("float32") + self.assertAllClose(knp.dot(x, y), np.dot(x, y)) + self.assertAllClose(knp.dot(x, z), np.dot(x, z)) + self.assertAllClose(knp.dot(x, 2), np.dot(x, 2)) + + self.assertAllClose(knp.Dot()(x, y), np.dot(x, y)) + self.assertAllClose(knp.Dot()(x, z), np.dot(x, z)) + self.assertAllClose(knp.Dot()(x, 2), np.dot(x, 2)) + + def test_exp(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + self.assertAllClose(knp.exp(x), np.exp(x)) + self.assertAllClose(knp.Exp()(x), np.exp(x)) + + def test_exp2(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + self.assertAllClose(knp.exp2(x), np.exp2(x)) + self.assertAllClose(knp.Exp2()(x), np.exp2(x)) + + def test_expand_dims(self): + x = np.ones([2, 3, 4]) + self.assertAllClose(knp.expand_dims(x, 0), np.expand_dims(x, 0)) + self.assertAllClose(knp.expand_dims(x, 1), np.expand_dims(x, 1)) + self.assertAllClose(knp.expand_dims(x, -2), np.expand_dims(x, -2)) + + self.assertAllClose(knp.ExpandDims(0)(x), np.expand_dims(x, 0)) + self.assertAllClose(knp.ExpandDims(1)(x), np.expand_dims(x, 1)) + self.assertAllClose(knp.ExpandDims(-2)(x), np.expand_dims(x, -2)) + + # Multiple axes + self.assertAllClose( + knp.expand_dims(x, (1, 2)), np.expand_dims(x, (1, 2)) + ) + self.assertAllClose( + knp.expand_dims(x, (-1, -2)), np.expand_dims(x, (-1, -2)) + ) + self.assertAllClose( + knp.expand_dims(x, (-1, 1)), np.expand_dims(x, (-1, 1)) + ) + + self.assertAllClose( + knp.ExpandDims((1, 2))(x), np.expand_dims(x, (1, 2)) + ) + self.assertAllClose( + knp.ExpandDims((-1, -2))(x), np.expand_dims(x, (-1, -2)) + ) + self.assertAllClose( + knp.ExpandDims((-1, 1))(x), np.expand_dims(x, (-1, 1)) + ) + + def test_expm1(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + self.assertAllClose(knp.expm1(x), np.expm1(x)) + self.assertAllClose(knp.Expm1()(x), np.expm1(x)) + + def test_flip(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + self.assertAllClose(knp.flip(x), np.flip(x)) + self.assertAllClose(knp.flip(x, 0), np.flip(x, 0)) + self.assertAllClose(knp.flip(x, 1), np.flip(x, 1)) + + self.assertAllClose(knp.Flip()(x), np.flip(x)) + self.assertAllClose(knp.Flip(0)(x), np.flip(x, 0)) + self.assertAllClose(knp.Flip(1)(x), np.flip(x, 1)) + + def test_floor(self): + x = np.array([[1.1, 2.2, -3.3], [3.3, 2.2, -1.1]]) + self.assertAllClose(knp.floor(x), np.floor(x)) + self.assertAllClose(knp.Floor()(x), np.floor(x)) + + def test_hstack(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + y = np.array([[4, 5, 6], [6, 5, 4]]) + self.assertAllClose(knp.hstack([x, y]), np.hstack([x, y])) + self.assertAllClose(knp.Hstack()([x, y]), np.hstack([x, y])) + + x = np.ones([2, 3, 4]) + y = np.ones([2, 5, 4]) + self.assertAllClose(knp.hstack([x, y]), np.hstack([x, y])) + self.assertAllClose(knp.Hstack()([x, y]), np.hstack([x, y])) + + def test_imag(self): + x = np.array([[1 + 1j, 2 + 2j, 3 + 3j], [3 + 3j, 2 + 2j, 1 + 1j]]) + self.assertAllClose(knp.imag(x), np.imag(x)) + self.assertAllClose(knp.Imag()(x), np.imag(x)) + + def test_isfinite(self): + x = np.array([[1, 2, np.inf], [np.nan, np.nan, np.nan]]) + self.assertAllClose(knp.isfinite(x), np.isfinite(x)) + self.assertAllClose(knp.Isfinite()(x), np.isfinite(x)) + + def test_isinf(self): + x = np.array([[1, 2, np.inf], [np.nan, np.nan, np.nan]]) + self.assertAllClose(knp.isinf(x), np.isinf(x)) + self.assertAllClose(knp.Isinf()(x), np.isinf(x)) + + def test_isnan(self): + x = np.array([[1, 2, np.inf], [np.nan, np.nan, np.nan]]) + self.assertAllClose(knp.isnan(x), np.isnan(x)) + self.assertAllClose(knp.Isnan()(x), np.isnan(x)) + + def test_isneginf(self): + x = np.array( + [[1, 2, np.inf, -np.inf], [np.nan, np.nan, np.nan, np.nan]] + ) + self.assertAllClose(knp.isneginf(x), np.isneginf(x)) + self.assertAllClose(knp.Isneginf()(x), np.isneginf(x)) + + def test_isposinf(self): + x = np.array( + [[1, 2, np.inf, -np.inf], [np.nan, np.nan, np.nan, np.nan]] + ) + self.assertAllClose(knp.isposinf(x), np.isposinf(x)) + self.assertAllClose(knp.Isposinf()(x), np.isposinf(x)) + + def test_log(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + self.assertAllClose(knp.log(x), np.log(x)) + self.assertAllClose(knp.Log()(x), np.log(x)) + + def test_log10(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + self.assertAllClose(knp.log10(x), np.log10(x)) + self.assertAllClose(knp.Log10()(x), np.log10(x)) + + def test_log1p(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + self.assertAllClose(knp.log1p(x), np.log1p(x)) + self.assertAllClose(knp.Log1p()(x), np.log1p(x)) + + def test_log2(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + self.assertAllClose(knp.log2(x), np.log2(x)) + self.assertAllClose(knp.Log2()(x), np.log2(x)) + + def test_logaddexp(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + y = np.array([[1, 2, 3], [3, 2, 1]]) + self.assertAllClose(knp.logaddexp(x, y), np.logaddexp(x, y)) + self.assertAllClose(knp.Logaddexp()(x, y), np.logaddexp(x, y)) + + def test_logaddexp2(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + y = np.array([[1, 2, 3], [3, 2, 1]]) + self.assertAllClose(knp.logaddexp2(x, y), np.logaddexp2(x, y)) + self.assertAllClose(knp.Logaddexp2()(x, y), np.logaddexp2(x, y)) + + def test_logical_not(self): + x = np.array([[True, False], [False, True]]) + self.assertAllClose(knp.logical_not(x), np.logical_not(x)) + self.assertAllClose(knp.LogicalNot()(x), np.logical_not(x)) + + def test_max(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + self.assertAllClose(knp.max(x), np.max(x)) + self.assertAllClose(knp.Max()(x), np.max(x)) + + self.assertAllClose(knp.max(x, 0), np.max(x, 0)) + self.assertAllClose(knp.Max(0)(x), np.max(x, 0)) + + self.assertAllClose(knp.max(x, 1), np.max(x, 1)) + self.assertAllClose(knp.Max(1)(x), np.max(x, 1)) + + # test max with initial + self.assertAllClose(knp.max(x, initial=4), 4) + + # test empty tensor + x = np.array([[]]) + self.assertAllClose(knp.max(x, initial=1), np.max(x, initial=1)) + self.assertAllClose( + knp.max(x, initial=1, keepdims=True), + np.max(x, initial=1, keepdims=True), + ) + + def test_min(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + self.assertAllClose(knp.min(x), np.min(x)) + self.assertAllClose(knp.Min()(x), np.min(x)) + + self.assertAllClose(knp.min(x, axis=(0, 1)), np.min(x, (0, 1))) + self.assertAllClose(knp.Min((0, 1))(x), np.min(x, (0, 1))) + + self.assertAllClose(knp.min(x, axis=()), np.min(x, axis=())) + self.assertAllClose(knp.Min(())(x), np.min(x, axis=())) + + self.assertAllClose(knp.min(x, 0), np.min(x, 0)) + self.assertAllClose(knp.Min(0)(x), np.min(x, 0)) + + self.assertAllClose(knp.min(x, 1), np.min(x, 1)) + self.assertAllClose(knp.Min(1)(x), np.min(x, 1)) + + # test min with initial + self.assertAllClose(knp.min(x, initial=0), 0) + + # test empty tensor + x = np.array([[]]) + self.assertAllClose(knp.min(x, initial=1), np.min(x, initial=1)) + self.assertAllClose( + knp.min(x, initial=1, keepdims=True), + np.min(x, initial=1, keepdims=True), + ) + + def test_median(self): + x = np.array([[1, 2, 3], [3, 2, 1]]).astype("float32") + self.assertAllClose(knp.median(x), np.median(x)) + self.assertAllClose( + knp.median(x, keepdims=True), np.median(x, keepdims=True) + ) + self.assertAllClose(knp.median(x, axis=1), np.median(x, axis=1)) + self.assertAllClose(knp.median(x, axis=(1,)), np.median(x, axis=(1,))) + self.assertAllClose( + knp.median(x, axis=1, keepdims=True), + np.median(x, axis=1, keepdims=True), + ) + + self.assertAllClose(knp.Median()(x), np.median(x)) + self.assertAllClose(knp.Median(axis=1)(x), np.median(x, axis=1)) + self.assertAllClose( + knp.Median(axis=1, keepdims=True)(x), + np.median(x, axis=1, keepdims=True), + ) + + def test_meshgrid(self): + x = np.array([1, 2, 3]) + y = np.array([4, 5, 6]) + z = np.array([7, 8, 9]) + self.assertAllClose(knp.meshgrid(x, y), np.meshgrid(x, y)) + self.assertAllClose(knp.meshgrid(x, z), np.meshgrid(x, z)) + self.assertAllClose( + knp.meshgrid(x, y, z, indexing="ij"), + np.meshgrid(x, y, z, indexing="ij"), + ) + self.assertAllClose(knp.Meshgrid()(x, y), np.meshgrid(x, y)) + self.assertAllClose(knp.Meshgrid()(x, z), np.meshgrid(x, z)) + self.assertAllClose( + knp.Meshgrid(indexing="ij")(x, y, z), + np.meshgrid(x, y, z, indexing="ij"), + ) + + if backend.backend() == "tensorflow": + # Arguments to `jax.numpy.meshgrid` must be 1D now. + x = np.ones([1, 2, 3]) + y = np.ones([4, 5, 6, 6]) + z = np.ones([7, 8]) + self.assertAllClose(knp.meshgrid(x, y), np.meshgrid(x, y)) + self.assertAllClose(knp.meshgrid(x, z), np.meshgrid(x, z)) + self.assertAllClose( + knp.meshgrid(x, y, z, indexing="ij"), + np.meshgrid(x, y, z, indexing="ij"), + ) + self.assertAllClose(knp.Meshgrid()(x, y), np.meshgrid(x, y)) + self.assertAllClose(knp.Meshgrid()(x, z), np.meshgrid(x, z)) + self.assertAllClose( + knp.Meshgrid(indexing="ij")(x, y, z), + np.meshgrid(x, y, z, indexing="ij"), + ) + + def test_moveaxis(self): + x = np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]]) + self.assertAllClose(knp.moveaxis(x, 0, -1), np.moveaxis(x, 0, -1)) + self.assertAllClose(knp.moveaxis(x, -1, 0), np.moveaxis(x, -1, 0)) + self.assertAllClose( + knp.moveaxis(x, (0, 1), (1, 0)), + np.moveaxis(x, (0, 1), (1, 0)), + ) + self.assertAllClose( + knp.moveaxis(x, [0, 1, 2], [2, 0, 1]), + np.moveaxis(x, [0, 1, 2], [2, 0, 1]), + ) + self.assertAllClose(knp.Moveaxis(-1, 0)(x), np.moveaxis(x, -1, 0)) + self.assertAllClose( + knp.Moveaxis((0, 1), (1, 0))(x), + np.moveaxis(x, (0, 1), (1, 0)), + ) + + self.assertAllClose( + knp.Moveaxis([0, 1, 2], [2, 0, 1])(x), + np.moveaxis(x, [0, 1, 2], [2, 0, 1]), + ) + + def test_ndim(self): + x = np.array([1, 2, 3]) + self.assertEqual(knp.ndim(x), np.ndim(x)) + self.assertEqual(knp.Ndim()(x), np.ndim(x)) + + def test_nonzero(self): + x = np.array([[0, 0, 3], [3, 0, 0]]) + self.assertAllClose(knp.nonzero(x), np.nonzero(x)) + self.assertAllClose(knp.Nonzero()(x), np.nonzero(x)) + + def test_ones_like(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + self.assertAllClose(knp.ones_like(x), np.ones_like(x)) + self.assertAllClose(knp.OnesLike()(x), np.ones_like(x)) + + @parameterized.named_parameters( + named_product( + dtype=[ + "float16", + "float32", + "float64", + "uint8", + "int8", + "int16", + "int32", + ], + mode=["constant", "reflect", "symmetric"], + constant_values=[None, 0, 2], + ) + ) + def test_pad(self, dtype, mode, constant_values): + # 2D + x = np.ones([2, 3], dtype=dtype) + pad_width = ((1, 1), (1, 1)) + + if mode != "constant": + if constant_values is not None: + with self.assertRaisesRegex( + ValueError, + "Argument `constant_values` can only be " + "provided when `mode == 'constant'`", + ): + knp.pad( + x, pad_width, mode=mode, constant_values=constant_values + ) + return + # constant_values is None + kwargs = {} + else: + # mode is constant + kwargs = {"constant_values": constant_values or 0} + + self.assertAllClose( + knp.pad(x, pad_width, mode=mode, constant_values=constant_values), + np.pad(x, pad_width, mode=mode, **kwargs), + ) + self.assertAllClose( + knp.Pad(pad_width, mode=mode)(x, constant_values=constant_values), + np.pad(x, pad_width, mode=mode, **kwargs), + ) + + # 5D (pad last 3D) + x = np.ones([2, 3, 4, 5, 6], dtype=dtype) + pad_width = ((0, 0), (0, 0), (2, 3), (1, 1), (1, 1)) + self.assertAllClose( + knp.pad(x, pad_width, mode=mode, constant_values=constant_values), + np.pad(x, pad_width, mode=mode, **kwargs), + ) + self.assertAllClose( + knp.Pad(pad_width, mode=mode)(x, constant_values=constant_values), + np.pad(x, pad_width, mode=mode, **kwargs), + ) + + # 5D (pad arbitrary dimensions) + if backend.backend() == "torch" and mode != "constant": + self.skipTest( + "reflect and symmetric padding for arbitrary dimensions " + "are not supported by torch" + ) + x = np.ones([2, 3, 4, 5, 6], dtype=dtype) + pad_width = ((1, 1), (2, 1), (3, 2), (4, 3), (5, 4)) + self.assertAllClose( + knp.pad(x, pad_width, mode=mode, constant_values=constant_values), + np.pad(x, pad_width, mode=mode, **kwargs), + ) + self.assertAllClose( + knp.Pad(pad_width, mode=mode)(x, constant_values=constant_values), + np.pad(x, pad_width, mode=mode, **kwargs), + ) + + def test_prod(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + self.assertAllClose(knp.prod(x), np.prod(x)) + self.assertAllClose(knp.prod(x, axis=()), np.prod(x, axis=())) + self.assertAllClose(knp.prod(x, axis=1), np.prod(x, axis=1)) + self.assertAllClose(knp.prod(x, axis=(1,)), np.prod(x, axis=(1,))) + self.assertAllClose( + knp.prod(x, axis=1, keepdims=True), + np.prod(x, axis=1, keepdims=True), + ) + + self.assertAllClose(knp.Prod()(x), np.prod(x)) + self.assertAllClose(knp.Prod(axis=1)(x), np.prod(x, axis=1)) + self.assertAllClose( + knp.Prod(axis=1, keepdims=True)(x), + np.prod(x, axis=1, keepdims=True), + ) + + def test_ravel(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + self.assertAllClose(knp.ravel(x), np.ravel(x)) + self.assertAllClose(knp.Ravel()(x), np.ravel(x)) + + def test_unravel_index(self): + x = np.array([0, 1, 2, 3]) + shape = (2, 2) + self.assertAllClose( + knp.unravel_index(x, shape), np.unravel_index(x, shape) + ) + + x = np.array([[0, 1], [2, 3]]) + shape = (2, 2) + self.assertAllClose( + knp.unravel_index(x, shape), np.unravel_index(x, shape) + ) + + def test_real(self): + x = np.array([[1, 2, 3 - 3j], [3, 2, 1 + 5j]]) + self.assertAllClose(knp.real(x), np.real(x)) + self.assertAllClose(knp.Real()(x), np.real(x)) + + def test_reciprocal(self): + x = np.array([[1.0, 2.0, 3.0], [3.0, 2.0, 1.0]]) + self.assertAllClose(knp.reciprocal(x), np.reciprocal(x)) + self.assertAllClose(knp.Reciprocal()(x), np.reciprocal(x)) + + def test_repeat(self): + x = np.array([[1, 2], [3, 4]]) + self.assertAllClose(knp.repeat(x, 2), np.repeat(x, 2)) + self.assertAllClose( + knp.Repeat(np.array([2]))(x), + np.repeat(x, np.array([2])), + ) + self.assertAllClose(knp.repeat(x, 3, axis=1), np.repeat(x, 3, axis=1)) + self.assertAllClose( + knp.repeat(x, np.array([1, 2]), axis=-1), + np.repeat(x, np.array([1, 2]), axis=-1), + ) + self.assertAllClose(knp.Repeat(2)(x), np.repeat(x, 2)) + self.assertAllClose(knp.Repeat(3, axis=1)(x), np.repeat(x, 3, axis=1)) + self.assertAllClose( + knp.Repeat(np.array([1, 2]), axis=0)(x), + np.repeat(x, np.array([1, 2]), axis=0), + ) + + def test_reshape(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + self.assertAllClose(knp.reshape(x, [3, 2]), np.reshape(x, [3, 2])) + self.assertAllClose(knp.Reshape([3, 2])(x), np.reshape(x, [3, 2])) + self.assertAllClose(knp.Reshape(-1)(x), np.reshape(x, -1)) + + def test_roll(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + self.assertAllClose(knp.roll(x, 1), np.roll(x, 1)) + self.assertAllClose(knp.roll(x, 1, axis=1), np.roll(x, 1, axis=1)) + self.assertAllClose(knp.roll(x, -1, axis=0), np.roll(x, -1, axis=0)) + self.assertAllClose(knp.Roll(1)(x), np.roll(x, 1)) + self.assertAllClose(knp.Roll(1, axis=1)(x), np.roll(x, 1, axis=1)) + self.assertAllClose(knp.Roll(-1, axis=0)(x), np.roll(x, -1, axis=0)) + + def test_round(self): + x = np.array([[1.1, 2.5, 3.9], [3.2, 2.3, 1.8]]) + self.assertAllClose(knp.round(x), np.round(x)) + self.assertAllClose(knp.Round()(x), np.round(x)) + + # Test with decimal=1 + self.assertAllClose(knp.round(x, decimals=1), np.round(x, decimals=1)) + self.assertAllClose(knp.Round(decimals=1)(x), np.round(x, decimals=1)) + + # Test with integers + x = np.array([[1, 2, 3], [3, 2, 1]], dtype="int32") + self.assertAllClose(knp.round(x, decimals=1), np.round(x, decimals=1)) + self.assertAllClose(knp.Round(decimals=1)(x), np.round(x, decimals=1)) + + # Test with integers and decimal < 0 + x = np.array([[123, 234, 345], [345, 234, 123]], dtype="int32") + self.assertAllClose(knp.round(x, decimals=-1), np.round(x, decimals=-1)) + self.assertAllClose(knp.Round(decimals=-1)(x), np.round(x, decimals=-1)) + + def test_searchsorted(self): + a = np.array([1, 2, 2, 3, 4, 5, 5]) + v = np.array([4, 3, 5, 1, 2]) + expected = np.searchsorted(a, v).astype("int32") + self.assertAllEqual(knp.searchsorted(a, v), expected) + self.assertAllEqual(knp.SearchSorted()(a, v), expected) + + def test_sign(self): + x = np.array([[1, -2, 3], [-3, 2, -1]]) + self.assertAllClose(knp.sign(x), np.sign(x)) + self.assertAllClose(knp.Sign()(x), np.sign(x)) + + def test_signbit(self): + x = np.array([[0.0, -0.0, -1.1e-45], [1.1e-38, 2, -1]]) + self.assertAllClose(knp.signbit(x), np.signbit(x)) + self.assertAllClose(knp.Signbit()(x), np.signbit(x)) + + def test_sin(self): + x = np.array([[1, -2, 3], [-3, 2, -1]]) + self.assertAllClose(knp.sin(x), np.sin(x)) + self.assertAllClose(knp.Sin()(x), np.sin(x)) + + def test_sinh(self): + x = np.array([[1, -2, 3], [-3, 2, -1]]) + self.assertAllClose(knp.sinh(x), np.sinh(x)) + self.assertAllClose(knp.Sinh()(x), np.sinh(x)) + + def test_size(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + self.assertAllClose(knp.size(x), np.size(x)) + self.assertAllClose(knp.Size()(x), np.size(x)) + + def test_sort(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + self.assertAllClose(knp.sort(x), np.sort(x)) + self.assertAllClose(knp.Sort()(x), np.sort(x)) + self.assertAllClose(knp.sort(x, axis=0), np.sort(x, axis=0)) + self.assertAllClose(knp.Sort(axis=0)(x), np.sort(x, axis=0)) + + def test_split(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + self.assertIsInstance(knp.split(x, 2), list) + self.assertAllClose(knp.split(x, 2), np.split(x, 2)) + self.assertAllClose(knp.Split(2)(x), np.split(x, 2)) + self.assertAllClose( + knp.split(x, [1, 2], axis=1), + np.split(x, [1, 2], axis=1), + ) + self.assertAllClose( + knp.Split([1, 2], axis=1)(x), + np.split(x, [1, 2], axis=1), + ) + + # test invalid indices_or_sections + with self.assertRaises(Exception): + knp.split(x, 3) + + # test zero dimension + x = np.ones(shape=(0,)) + self.assertEqual(len(knp.split(x, 2)), 2) + self.assertEqual(len(knp.Split(2)(x)), 2) + + # test indices_or_sections as tensor + x = knp.array([[1, 2, 3], [3, 2, 1]]) + indices_or_sections = knp.array([1, 2]) + x_np = np.array([[1, 2, 3], [3, 2, 1]]) + indices_or_sections_np = np.array([1, 2]) + self.assertAllClose( + knp.split(x, indices_or_sections, axis=1), + np.split(x_np, indices_or_sections_np, axis=1), + ) + + @pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="Only test tensorflow backend", + ) + def test_split_with_jit_in_tf(self): + import tensorflow as tf + + x = knp.array([[1, 2, 3], [3, 2, 1]]) + indices = knp.array([1, 2]) + x_np = np.array([[1, 2, 3], [3, 2, 1]]) + indices_np = np.array([1, 2]) + + @tf.function(jit_compile=True) + def fn(x, indices, axis): + return knp.split(x, indices, axis=axis) + + self.assertAllClose( + fn(x, indices, axis=1), + np.split(x_np, indices_np, axis=1), + ) + + def test_sqrt(self): + x = np.array([[1, 4, 9], [16, 25, 36]], dtype="float32") + ref_y = np.sqrt(x) + y = knp.sqrt(x) + self.assertEqual(standardize_dtype(y.dtype), "float32") + self.assertAllClose(y, ref_y) + y = knp.Sqrt()(x) + self.assertEqual(standardize_dtype(y.dtype), "float32") + self.assertAllClose(y, ref_y) + + def test_sqrt_int32(self): + x = np.array([[1, 4, 9], [16, 25, 36]], dtype="int32") + ref_y = np.sqrt(x) + y = knp.sqrt(x) + self.assertEqual(standardize_dtype(y.dtype), "float32") + self.assertAllClose(y, ref_y) + y = knp.Sqrt()(x) + self.assertEqual(standardize_dtype(y.dtype), "float32") + self.assertAllClose(y, ref_y) + + def test_stack(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + y = np.array([[4, 5, 6], [6, 5, 4]]) + self.assertAllClose(knp.stack([x, y]), np.stack([x, y])) + self.assertAllClose(knp.stack([x, y], axis=1), np.stack([x, y], axis=1)) + self.assertAllClose(knp.Stack()([x, y]), np.stack([x, y])) + self.assertAllClose(knp.Stack(axis=1)([x, y]), np.stack([x, y], axis=1)) + + def test_std(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + self.assertAllClose(knp.std(x), np.std(x)) + self.assertAllClose(knp.std(x, axis=1), np.std(x, axis=1)) + self.assertAllClose( + knp.std(x, axis=1, keepdims=True), + np.std(x, axis=1, keepdims=True), + ) + + self.assertAllClose(knp.Std()(x), np.std(x)) + self.assertAllClose(knp.Std(axis=1)(x), np.std(x, axis=1)) + self.assertAllClose( + knp.Std(axis=1, keepdims=True)(x), + np.std(x, axis=1, keepdims=True), + ) + + def test_swapaxes(self): + x = np.arange(24).reshape([1, 2, 3, 4]) + self.assertAllClose( + knp.swapaxes(x, 0, 1), + np.swapaxes(x, 0, 1), + ) + self.assertAllClose( + knp.Swapaxes(0, 1)(x), + np.swapaxes(x, 0, 1), + ) + + def test_tan(self): + x = np.array([[1, -2, 3], [-3, 2, -1]]) + self.assertAllClose(knp.tan(x), np.tan(x)) + self.assertAllClose(knp.Tan()(x), np.tan(x)) + + def test_tanh(self): + x = np.array([[1, -2, 3], [-3, 2, -1]]) + self.assertAllClose(knp.tanh(x), np.tanh(x)) + self.assertAllClose(knp.Tanh()(x), np.tanh(x)) + + def test_tile(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + self.assertAllClose(knp.tile(x, 2), np.tile(x, 2)) + self.assertAllClose(knp.tile(x, [2, 3]), np.tile(x, [2, 3])) + self.assertAllClose(knp.Tile([2, 3])(x), np.tile(x, [2, 3])) + + # If repeats.ndim > x.ndim + self.assertAllClose(knp.tile(x, [2, 3, 4]), np.tile(x, [2, 3, 4])) + self.assertAllClose(knp.Tile([2, 3, 4])(x), np.tile(x, [2, 3, 4])) + + # If repeats.ndim < x.ndim + self.assertAllClose(knp.tile(x, [2]), np.tile(x, [2])) + self.assertAllClose(knp.Tile([2])(x), np.tile(x, [2])) + + def test_trace(self): + x = np.arange(24).reshape([1, 2, 3, 4]) + self.assertAllClose(knp.trace(x), np.trace(x)) + self.assertAllClose( + knp.trace(x, axis1=2, axis2=3), + np.trace(x, axis1=2, axis2=3), + ) + self.assertAllClose( + knp.Trace(axis1=2, axis2=3)(x), + np.trace(x, axis1=2, axis2=3), + ) + + def test_tril(self): + x = np.arange(24).reshape([1, 2, 3, 4]) + self.assertAllClose(knp.tril(x), np.tril(x)) + self.assertAllClose(knp.tril(x, -1), np.tril(x, -1)) + self.assertAllClose(knp.Tril(-1)(x), np.tril(x, -1)) + + x = np.ones([5, 5]) + self.assertAllClose(knp.tril(x), np.tril(x)) + self.assertAllClose(knp.tril(x, -1), np.tril(x, -1)) + self.assertAllClose(knp.Tril(-1)(x), np.tril(x, -1)) + + def test_tril_in_layer(self): + # https://github.com/keras-team/keras/issues/18890 + x = keras.Input((None, 3)) + y1 = keras.layers.Lambda( + lambda x: keras.ops.tril( + keras.ops.ones((keras.ops.shape(x)[1], keras.ops.shape(x)[1])) + ), + output_shape=(None, None, 3), + )(x) + y2 = keras.layers.Lambda( + lambda x: keras.ops.tril( + keras.ops.ones((keras.ops.shape(x)[1], keras.ops.shape(x)[1])), + k=-1, + ), + output_shape=(None, None, 3), + )(x) + model = keras.Model(x, [y1, y2]) + + result = model(np.ones((1, 2, 3), "float32")) + self.assertAllClose( + result, [np.tril(np.ones((2, 2))), np.tril(np.ones((2, 2)), k=-1)] + ) + + @pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="Only test tensorflow backend", + ) + def test_tril_with_jit_in_tf(self): + import tensorflow as tf + + x = knp.reshape(knp.arange(24), [1, 2, 3, 4]) + k = knp.array(0) + x_np = np.reshape(np.arange(24), [1, 2, 3, 4]) + k_np = np.array(0) + + @tf.function(jit_compile=True) + def fn(x, k): + return knp.tril(x, k=k) + + self.assertAllClose(fn(x, k), np.tril(x_np, k_np)) + + def test_triu(self): + x = np.arange(24).reshape([1, 2, 3, 4]) + self.assertAllClose(knp.triu(x), np.triu(x)) + self.assertAllClose(knp.triu(x, -1), np.triu(x, -1)) + self.assertAllClose(knp.Triu(-1)(x), np.triu(x, -1)) + + x = np.ones([5, 5]) + self.assertAllClose(knp.triu(x), np.triu(x)) + self.assertAllClose(knp.triu(x, -1), np.triu(x, -1)) + self.assertAllClose(knp.Triu(-1)(x), np.triu(x, -1)) + + def test_triu_in_layer(self): + # https://github.com/keras-team/keras/issues/18890 + x = keras.Input((None, 3)) + y1 = keras.layers.Lambda( + lambda x: keras.ops.triu( + keras.ops.ones((keras.ops.shape(x)[1], keras.ops.shape(x)[1])) + ), + output_shape=(None, None, 3), + )(x) + y2 = keras.layers.Lambda( + lambda x: keras.ops.triu( + keras.ops.ones((keras.ops.shape(x)[1], keras.ops.shape(x)[1])), + k=-1, + ), + output_shape=(None, None, 3), + )(x) + model = keras.Model(x, [y1, y2]) + + result = model(np.ones((1, 2, 3), "float32")) + self.assertAllClose( + result, [np.triu(np.ones((2, 2))), np.triu(np.ones((2, 2)), k=-1)] + ) + + @pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="Only test tensorflow backend", + ) + def test_triu_with_jit_in_tf(self): + import tensorflow as tf + + x = knp.reshape(knp.arange(24), [1, 2, 3, 4]) + k = knp.array(0) + x_np = np.reshape(np.arange(24), [1, 2, 3, 4]) + k_np = np.array(0) + + @tf.function(jit_compile=True) + def fn(x, k): + return knp.triu(x, k=k) + + self.assertAllClose(fn(x, k), np.triu(x_np, k_np)) + + def test_trunc(self): + x = np.array([-1.7, -2.5, -0.2, 0.2, 1.5, 1.7, 2.0]) + self.assertAllClose(knp.trunc(x), np.trunc(x)) + self.assertAllClose(knp.Trunc()(x), np.trunc(x)) + + x = np.array([-1, -2, -0, 0, 1, 1, 2], dtype="int32") + self.assertAllClose(knp.trunc(x), np.trunc(x)) + self.assertAllClose(knp.Trunc()(x), np.trunc(x)) + + def test_vstack(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + y = np.array([[4, 5, 6], [6, 5, 4]]) + self.assertAllClose(knp.vstack([x, y]), np.vstack([x, y])) + self.assertAllClose(knp.Vstack()([x, y]), np.vstack([x, y])) + + def test_floor_divide(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + y = np.array([[4, 5, 6], [3, 2, 1]]) + z = np.array([[[1, 2, 3], [3, 2, 1]]]) + self.assertAllClose(knp.floor_divide(x, y), np.floor_divide(x, y)) + self.assertAllClose(knp.floor_divide(x, z), np.floor_divide(x, z)) + + self.assertAllClose(knp.FloorDivide()(x, y), np.floor_divide(x, y)) + self.assertAllClose(knp.FloorDivide()(x, z), np.floor_divide(x, z)) + + def test_xor(self): + x = np.array([[True, False], [True, True]]) + y = np.array([[False, False], [True, False]]) + self.assertAllClose(knp.logical_xor(x, y), np.logical_xor(x, y)) + self.assertAllClose(knp.logical_xor(x, True), np.logical_xor(x, True)) + self.assertAllClose(knp.logical_xor(True, x), np.logical_xor(True, x)) + + self.assertAllClose(knp.LogicalXor()(x, y), np.logical_xor(x, y)) + self.assertAllClose(knp.LogicalXor()(x, True), np.logical_xor(x, True)) + self.assertAllClose(knp.LogicalXor()(True, x), np.logical_xor(True, x)) + + def test_correlate(self): + x = np.array([1, 2, 3]) + y = np.array([0, 1, 0.5]) + self.assertAllClose(knp.correlate(x, y), np.correlate(x, y)) + self.assertAllClose( + knp.correlate(x, y, mode="same"), np.correlate(x, y, mode="same") + ) + self.assertAllClose( + knp.correlate(x, y, mode="full"), np.correlate(x, y, mode="full") + ) + + self.assertAllClose(knp.Correlate()(x, y), np.correlate(x, y)) + self.assertAllClose( + knp.Correlate(mode="same")(x, y), np.correlate(x, y, mode="same") + ) + self.assertAllClose( + knp.Correlate(mode="full")(x, y), np.correlate(x, y, mode="full") + ) + + def test_correlate_different_size(self): + x = np.array([1, 2, 3, 4, 5, 6]) + y = np.array([0, 1, 0.5]) + self.assertAllClose(knp.correlate(x, y), np.correlate(x, y)) + self.assertAllClose( + knp.correlate(x, y, mode="same"), np.correlate(x, y, mode="same") + ) + self.assertAllClose( + knp.correlate(x, y, mode="full"), np.correlate(x, y, mode="full") + ) + + self.assertAllClose(knp.Correlate()(x, y), np.correlate(x, y)) + self.assertAllClose( + knp.Correlate(mode="same")(x, y), np.correlate(x, y, mode="same") + ) + self.assertAllClose( + knp.Correlate(mode="full")(x, y), np.correlate(x, y, mode="full") + ) + + def test_select(self): + x = np.arange(6) + condlist = [x < 3, x > 3] + choicelist = [x, x**2] + y = knp.select(condlist, choicelist, 42) + self.assertAllClose(y, [0, 1, 2, 42, 16, 25]) + + # Test with tuples + condlist = (x < 3, x > 3) + choicelist = (x, x**2) + y = knp.select(condlist, choicelist, 42) + self.assertAllClose(y, [0, 1, 2, 42, 16, 25]) + + # Test with symbolic tensors + x = backend.KerasTensor((6,)) + condlist = [x < 3, x > 3] + choicelist = [x, x**2] + y = knp.select(condlist, choicelist, 42) + self.assertEqual(y.shape, (6,)) + + def test_slogdet(self): + x = np.ones((4, 4)) * 2.0 + out = knp.slogdet(x) + self.assertAllClose(out[0], 0) + self.assertAllClose(out[0], 0) + + x = backend.KerasTensor((3, 3)) + out = knp.slogdet(x) + self.assertEqual(out[0].shape, ()) + self.assertEqual(out[1].shape, ()) + + x = backend.KerasTensor((2, 4, 3, 3)) + out = knp.slogdet(x) + self.assertEqual(out[0].shape, ()) + self.assertEqual(out[1].shape, (2, 4)) + + def test_nan_to_num(self): + x = knp.array([1.0, np.nan, np.inf, -np.inf]) + self.assertAllClose( + knp.nan_to_num(x), [1.0, 0.0, 3.402823e38, -3.402823e38] + ) + self.assertAllClose( + knp.NanToNum()(x), [1.0, 0.0, 3.402823e38, -3.402823e38] + ) + self.assertAllClose( + knp.nan_to_num(x, nan=2, posinf=3, neginf=4), [1.0, 2.0, 3.0, 4.0] + ) + self.assertAllClose( + knp.NanToNum(nan=2, posinf=3, neginf=4)(x), [1.0, 2.0, 3.0, 4.0] + ) + + x = backend.KerasTensor((3, 4)) + self.assertEqual( + knp.NanToNum(nan=2, posinf=3, neginf=4)(x).shape, (3, 4) + ) + + def test_vectorize(self): + # Basic functionality + def myfunc(a, b): + return a + b + + vfunc = np.vectorize(myfunc) + y = vfunc([1, 2, 3, 4], 2) + self.assertAllClose(y, [3, 4, 5, 6]) + + # Test signature arg + vfunc = knp.vectorize(knp.trace, signature="(d,d)->()") + out = vfunc(np.eye(4)) + self.assertAllClose( + out, np.vectorize(np.trace, signature="(d,d)->()")(np.eye(4)) + ) + + vfunc = knp.vectorize(knp.diag, signature="(d,d)->(d)") + out = vfunc(np.eye(4)) + self.assertAllClose( + out, np.vectorize(np.diag, signature="(d,d)->(d)")(np.eye(4)) + ) + + def test_argpartition(self): + x = np.array([3, 4, 2, 1]) + self.assertAllClose(knp.argpartition(x, 2), np.argpartition(x, 2)) + self.assertAllClose(knp.Argpartition(2)(x), np.argpartition(x, 2)) + + x = np.array([[3, 4, 2], [1, 3, 4]]) + self.assertAllClose(knp.argpartition(x, 1), np.argpartition(x, 1)) + self.assertAllClose(knp.Argpartition(1)(x), np.argpartition(x, 1)) + + x = np.array([[[3, 4], [2, 3]], [[1, 2], [0, 1]]]) + self.assertAllClose(knp.argpartition(x, 1), np.argpartition(x, 1)) + self.assertAllClose(knp.Argpartition(1)(x), np.argpartition(x, 1)) + + def test_angle(self): + x = np.array([[1, 0.5, -0.7], [0.9, 0.2, -1]]) + self.assertAllClose(knp.angle(x), np.angle(x)) + + self.assertAllClose(knp.Angle()(x), np.angle(x)) + + +class NumpyArrayCreateOpsCorrectnessTest(testing.TestCase): + def test_ones(self): + self.assertAllClose(knp.ones([2, 3]), np.ones([2, 3])) + + def test_zeros(self): + self.assertAllClose(knp.zeros([2, 3]), np.zeros([2, 3])) + + def test_eye(self): + self.assertAllClose(knp.eye(3), np.eye(3)) + self.assertAllClose(knp.eye(3, 4), np.eye(3, 4)) + self.assertAllClose(knp.eye(3, 4, 1), np.eye(3, 4, 1)) + + # Test k >= N + self.assertAllClose(knp.eye(3, k=3), np.eye(3, k=3)) + + # Test k > 0 and N >= M + self.assertAllClose(knp.eye(3, k=1), np.eye(3, k=1)) + + # Test k > 0 and N < M and N + k > M + self.assertAllClose(knp.eye(3, 4, k=2), np.eye(3, 4, k=2)) + + # Test k < 0 and M >= N + self.assertAllClose(knp.eye(3, k=-1), np.eye(3, k=-1)) + + # Test k < 0 and M < N and M - k > N + self.assertAllClose(knp.eye(4, 3, k=-2), np.eye(4, 3, k=-2)) + + def test_arange(self): + self.assertAllClose(knp.arange(3), np.arange(3)) + self.assertAllClose(knp.arange(3, 7), np.arange(3, 7)) + self.assertAllClose(knp.arange(3, 7, 2), np.arange(3, 7, 2)) + + self.assertAllClose(knp.Arange()(3), np.arange(3)) + self.assertAllClose(knp.Arange()(3, 7), np.arange(3, 7)) + self.assertAllClose(knp.Arange()(3, 7, 2), np.arange(3, 7, 2)) + + self.assertEqual(standardize_dtype(knp.arange(3).dtype), "int32") + with warnings.catch_warnings(record=True) as record: + knp.arange(3, dtype="int") + self.assertEqual(len(record), 0) + + def test_full(self): + self.assertAllClose(knp.full([2, 3], 0), np.full([2, 3], 0)) + self.assertAllClose(knp.full([2, 3], 0.1), np.full([2, 3], 0.1)) + self.assertAllClose( + knp.full([2, 3], np.array([1, 4, 5])), + np.full([2, 3], np.array([1, 4, 5])), + ) + + self.assertAllClose(knp.Full([2, 3])(0), np.full([2, 3], 0)) + self.assertAllClose(knp.Full([2, 3])(0.1), np.full([2, 3], 0.1)) + self.assertAllClose( + knp.Full([2, 3])(np.array([1, 4, 5])), + np.full([2, 3], np.array([1, 4, 5])), + ) + + def test_identity(self): + self.assertAllClose(knp.identity(3), np.identity(3)) + + def test_tri(self): + self.assertAllClose(knp.tri(3), np.tri(3)) + self.assertAllClose(knp.tri(3, 4), np.tri(3, 4)) + self.assertAllClose(knp.tri(3, 4, 1), np.tri(3, 4, 1)) + + # Test k < 0 + self.assertAllClose(knp.tri(3, k=-1), np.tri(3, k=-1)) + + # Test -k-1 > N + self.assertAllClose(knp.tri(3, k=-5), np.tri(3, k=-5)) + + # Test k > M + self.assertAllClose(knp.tri(3, k=4), np.tri(3, k=4)) + + +def create_sparse_tensor(x, indices_from=None, start=0, delta=2): + if indices_from is not None: + indices = indices_from.indices + else: + size = math.prod(x.shape) + flat_indices = np.arange(start, size, delta) + indices = np.stack(np.where(np.ones_like(x)), axis=1)[flat_indices] + + if backend.backend() == "tensorflow": + import tensorflow as tf + + return tf.SparseTensor(indices, tf.gather_nd(x, indices), x.shape) + elif backend.backend() == "jax": + import jax + import jax.experimental.sparse as jax_sparse + + values = x[tuple(jax.numpy.moveaxis(indices, -1, 0))] + return jax_sparse.BCOO((values, indices), shape=x.shape) + + +def create_indexed_slices(x, indices_from=None, start=0, delta=2): + indices = np.arange(start, x.shape[0], delta) + + if backend.backend() == "tensorflow": + import tensorflow as tf + + if indices_from is not None: + indices = indices_from.indices + return tf.IndexedSlices(tf.gather(x, indices), indices, x.shape) + elif backend.backend() == "jax": + import jax + import jax.experimental.sparse as jax_sparse + + if indices_from is not None: + indices = indices_from.indices + else: + indices = jax.numpy.expand_dims(indices, axis=1) + values = jax.numpy.take(x, jax.numpy.squeeze(indices, axis=1), axis=0) + return jax_sparse.BCOO((values, indices), shape=x.shape) + + +def get_sparseness_combinations(dense_to_sparse_fn): + x = np.array([[1, 2, 3], [3, 2, 1]]) + y = np.array([[4, 5, 6], [3, 2, 1]]) + scalar = backend.convert_to_tensor(2) + x_sp = dense_to_sparse_fn(x) + y_sp = dense_to_sparse_fn(y, indices_from=x_sp) + x_sp_sup = dense_to_sparse_fn(x, start=0, delta=1) + y_sp_dis = dense_to_sparse_fn(y, start=1) + y_sp_sup = dense_to_sparse_fn(y, start=0, delta=1) + x = backend.convert_to_tensor(x) + y = backend.convert_to_tensor(y) + return [ + {"testcase_name": "sparse_dense", "x": x_sp, "y": y}, + {"testcase_name": "dense_sparse", "x": x, "y": y_sp}, + {"testcase_name": "sparse_scalar", "x": x_sp, "y": scalar}, + {"testcase_name": "scalar_sparse", "x": scalar, "y": y_sp}, + {"testcase_name": "sparse_sparse_same", "x": x_sp, "y": y_sp}, + {"testcase_name": "sparse_sparse_disjoint", "x": x_sp, "y": y_sp_dis}, + {"testcase_name": "sparse_sparse_superset", "x": x_sp, "y": y_sp_sup}, + {"testcase_name": "sparse_sparse_subset", "x": x_sp_sup, "y": y_sp}, + ] + + +def sparseness(x): + if isinstance(x, KerasTensor): + return "sparse" if x.sparse else "dense" + elif x.__class__.__name__ == "BCOO": + if x.n_dense > 0: + return "slices" + else: + return "sparse" + elif x.__class__.__name__ == "SparseTensor": + return "sparse" + elif x.__class__.__name__ == "IndexedSlices": + return "slices" + elif not hasattr(x, "shape") or not x.shape: + return "scalar" + else: + return "dense" + + +def union_sparseness(x1, x2): + x1_sparseness = sparseness(x1) + x2_sparseness = sparseness(x2) + if any(s in ("scalar", "dense") for s in (x1_sparseness, x2_sparseness)): + return "dense" + if x1_sparseness != x2_sparseness: + raise ValueError(f"Illegal combination of operands: {x1} {x2}") + return x1_sparseness + + +def intersection_sparseness(x1, x2): + x1_sparseness = sparseness(x1) + x2_sparseness = sparseness(x2) + if x1_sparseness == "scalar": + return x2_sparseness + if x2_sparseness in ("scalar", "dense"): + return x1_sparseness + if x1_sparseness == "dense": + return x2_sparseness + if x1_sparseness != x2_sparseness: + raise ValueError(f"Illegal combination of operands: {x1} {x2}") + return x1_sparseness + + +def division_sparseness(x1, x2): + x1_sparseness = sparseness(x1) + x2_sparseness = sparseness(x2) + if x2_sparseness in ("sparse", "slices"): + return "dense" + return "dense" if x1_sparseness == "scalar" else x1_sparseness + + +def snake_to_pascal_case(name): + return "".join(w.capitalize() for w in name.split("_")) + + +@pytest.mark.skipif( + not backend.SUPPORTS_SPARSE_TENSORS, + reason="Backend does not support sparse tensors.", +) +class SparseTest(testing.TestCase): + DTYPES = ["int32", "float32"] + DENSIFYING_UNARY_OPS = [ + "arccos", + "arccosh", + "cos", + "cosh", + "exp", + "isfinite", + "log", + "log10", + "log2", + "reciprocal", + ] + DENSIFYING_UNARY_OPS_TESTS = [ + { + "testcase_name": op, + "op_function": getattr(knp, op), + "op_class": getattr(knp, op.capitalize()), + "np_op": getattr(np, op), + } + for op in DENSIFYING_UNARY_OPS + ] + ELEMENTWISE_UNARY_OPS = [ + "abs", + "absolute", + "arcsin", + "arcsinh", + "arctan", + "arctanh", + "ceil", + "conj", + "conjugate", + "copy", + "expm1", + "floor", + "imag", + "log1p", + "negative", + "real", + "round", + "sign", + "sin", + "sinh", + "sqrt", + "square", + "tan", + "tanh", + ] + ELEMENTWISE_UNARY_OPS_TESTS = [ + { + "testcase_name": op, + "op_function": getattr(knp, op), + "op_class": getattr(knp, snake_to_pascal_case(op)), + "np_op": getattr(np, op), + } + for op in ELEMENTWISE_UNARY_OPS + ] + OTHER_UNARY_OPS_ARGS = [ + ("digitize", "", {}, {"bins": np.array([0.1, 0.2, 1.0])}, (4, 2, 3)), + ("mean", "none", {"axis": None}, {}, (4, 2, 3)), + ("mean", "none_k", {"axis": None, "keepdims": True}, {}, (4, 2, 3)), + ("mean", "empty", {"axis": ()}, {}, (4, 2, 3)), + ("mean", "empty_k", {"axis": (), "keepdims": True}, {}, (4, 2, 3)), + ("mean", "0", {"axis": 0}, {}, (4, 2, 3)), + ("mean", "0_k", {"axis": 0, "keepdims": True}, {}, (4, 2, 3)), + ("mean", "1", {"axis": 1}, {}, (4, 2, 3)), + ("mean", "1_k", {"axis": 1, "keepdims": True}, {}, (4, 2, 3)), + ("mean", "01", {"axis": (0, 1)}, {}, (4, 2, 3)), + ("mean", "01_k", {"axis": (0, 1), "keepdims": True}, {}, (4, 2, 3)), + ("mean", "02", {"axis": (1, 2)}, {}, (4, 2, 3)), + ("mean", "02_k", {"axis": (1, 2), "keepdims": True}, {}, (4, 2, 3)), + ("mean", "all", {"axis": (0, 1, 2)}, {}, (4, 2, 3)), + ("mean", "all_k", {"axis": (0, 1, 2), "keepdims": True}, {}, (4, 2, 3)), + ("sum", "none", {"axis": None}, {}, (4, 2, 3)), + ("sum", "none_k", {"axis": None, "keepdims": True}, {}, (4, 2, 3)), + ("sum", "empty", {"axis": ()}, {}, (4, 2, 3)), + ("sum", "empty_k", {"axis": (), "keepdims": True}, {}, (4, 2, 3)), + ("sum", "0", {"axis": 0}, {}, (4, 2, 3)), + ("sum", "0_k", {"axis": 0, "keepdims": True}, {}, (4, 2, 3)), + ("sum", "1", {"axis": 1}, {}, (4, 2, 3)), + ("sum", "1_k", {"axis": 1, "keepdims": True}, {}, (4, 2, 3)), + ("sum", "01", {"axis": (0, 1)}, {}, (4, 2, 3)), + ("sum", "01_k", {"axis": (0, 1), "keepdims": True}, {}, (4, 2, 3)), + ("sum", "02", {"axis": (1, 2)}, {}, (4, 2, 3)), + ("sum", "02_k", {"axis": (1, 2), "keepdims": True}, {}, (4, 2, 3)), + ("sum", "all", {"axis": (0, 1, 2)}, {}, (4, 2, 3)), + ("sum", "all_k", {"axis": (0, 1, 2), "keepdims": True}, {}, (4, 2, 3)), + ("expand_dims", "zero", {"axis": 0}, {}, (2, 3)), + ("expand_dims", "one", {"axis": 1}, {}, (2, 3)), + ("expand_dims", "minus_two", {"axis": -2}, {}, (2, 3)), + ("reshape", "basic", {"newshape": (4, 3, 2)}, {}, (4, 2, 3)), + ("reshape", "minus_one", {"newshape": (4, 3, -1)}, {}, (4, 2, 3)), + ("reshape", "fewer_dims", {"newshape": (4, 6)}, {}, (4, 2, 3)), + ("squeeze", "no_axis_no_op", {}, {}, (2, 3)), + ("squeeze", "one", {"axis": 1}, {}, (2, 1, 3)), + ("squeeze", "minus_two", {"axis": -2}, {}, (2, 1, 3)), + ("squeeze", "no_axis", {}, {}, (2, 1, 3)), + ("transpose", "no_axes", {}, {}, (1, 2, 3, 4)), + ("transpose", "axes", {"axes": (0, 3, 2, 1)}, {}, (1, 2, 3, 4)), + ] + OTHER_UNARY_OPS_TESTS = [ + { + "testcase_name": "_".join([op, testcase_name]), + "op_function": getattr(knp, op), + "op_class": getattr(knp, snake_to_pascal_case(op)), + "np_op": getattr(np, op), + "init_kwargs": init_kwargs, + "op_kwargs": op_kwargs, + "input_shape": input_shape, + } + for op, testcase_name, init_kwargs, op_kwargs, input_shape in ( + OTHER_UNARY_OPS_ARGS + ) + ] + + BINARY_OPS = [ + ("add", union_sparseness), + ("subtract", union_sparseness), + ("maximum", union_sparseness), + ("minimum", union_sparseness), + ("multiply", intersection_sparseness), + ("divide", division_sparseness), + ("true_divide", division_sparseness), + ] + BINARY_OPS_TESTS = [ + { + "testcase_name": op, + "op_function": getattr(knp, op), + "op_class": getattr(knp, snake_to_pascal_case(op)), + "np_op": getattr(np, op), + "op_sparseness": op_sparseness, + } + for op, op_sparseness in BINARY_OPS + ] + + def assertSameSparseness(self, x, y): + self.assertEqual(sparseness(x), sparseness(y)) + + def assertSparseness(self, x, expected_sparseness): + self.assertEqual(sparseness(x), expected_sparseness) + + @parameterized.named_parameters(ELEMENTWISE_UNARY_OPS_TESTS) + def test_elementwise_unary_symbolic_static_shape( + self, op_function, op_class, np_op + ): + x = KerasTensor([2, 3], sparse=True) + self.assertEqual(op_function(x).shape, (2, 3)) + self.assertTrue(op_function(x).sparse) + self.assertEqual(op_class()(x).shape, (2, 3)) + self.assertTrue(op_class()(x).sparse) + + @parameterized.named_parameters(ELEMENTWISE_UNARY_OPS_TESTS) + def test_elementwise_unary_symbolic_dynamic_shape( + self, op_function, op_class, np_op + ): + x = KerasTensor([None, 3], sparse=True) + self.assertEqual(op_function(x).shape, (None, 3)) + self.assertTrue(op_function(x).sparse) + self.assertEqual(op_class()(x).shape, (None, 3)) + self.assertTrue(op_class()(x).sparse) + + @parameterized.named_parameters(OTHER_UNARY_OPS_TESTS) + def test_other_unary_symbolic_static_shape( + self, op_function, op_class, np_op, init_kwargs, op_kwargs, input_shape + ): + expected_shape = op_function( + KerasTensor(input_shape), **init_kwargs, **op_kwargs + ).shape + x = KerasTensor(input_shape, sparse=True) + self.assertEqual( + op_function(x, **init_kwargs, **op_kwargs).shape, expected_shape + ) + self.assertTrue(op_function(x, **init_kwargs, **op_kwargs).sparse) + self.assertEqual( + op_class(**init_kwargs)(x, **op_kwargs).shape, expected_shape + ) + self.assertTrue(op_class(**init_kwargs)(x, **op_kwargs).sparse) + + @parameterized.named_parameters(OTHER_UNARY_OPS_TESTS) + def test_other_unary_symbolic_dynamic_shape( + self, op_function, op_class, np_op, init_kwargs, op_kwargs, input_shape + ): + input_shape = (None,) + input_shape[1:] + expected_shape = op_function( + KerasTensor(input_shape), **init_kwargs, **op_kwargs + ).shape + x = KerasTensor(input_shape, sparse=True) + self.assertEqual( + op_function(x, **init_kwargs, **op_kwargs).shape, expected_shape + ) + self.assertTrue(op_function(x, **init_kwargs, **op_kwargs).sparse) + self.assertEqual( + op_class(**init_kwargs)(x, **op_kwargs).shape, expected_shape + ) + self.assertTrue(op_class(**init_kwargs)(x, **op_kwargs).sparse) + + @parameterized.named_parameters(DENSIFYING_UNARY_OPS_TESTS) + def test_densifying_unary_sparse_correctness( + self, op_function, op_class, np_op + ): + x = np.array([[1, 0.5, -0.7], [0.9, 0.2, -1]]) + x = create_sparse_tensor(x) + x_np = backend.convert_to_numpy(x) + + self.assertAllClose(op_function(x), np_op(x_np)) + self.assertAllClose(op_class()(x), np_op(x_np)) + + @parameterized.named_parameters(DENSIFYING_UNARY_OPS_TESTS) + def test_densifying_unary_indexed_slices_correctness( + self, op_function, op_class, np_op + ): + x = np.array([[1, 0.5, -0.7], [0.9, 0.2, -1]]) + x = create_indexed_slices(x) + x_np = backend.convert_to_numpy(x) + + self.assertAllClose(op_function(x), np_op(x_np)) + self.assertAllClose(op_class()(x), np_op(x_np)) + + @parameterized.named_parameters(ELEMENTWISE_UNARY_OPS_TESTS) + def test_elementwise_unary_sparse_correctness( + self, op_function, op_class, np_op + ): + if op_function.__name__ in ("conj", "conjugate", "imag", "real"): + x = np.array([[1 + 1j, 2 + 2j, 3 + 3j], [3 + 3j, 2 + 2j, 1 + 1j]]) + else: + x = np.array([[1, 0.5, -0.7], [0.9, 0.2, -1]]) + x = create_sparse_tensor(x) + x_np = backend.convert_to_numpy(x) + + self.assertAllClose(op_function(x), np_op(x_np)) + self.assertSameSparseness(op_function(x), x) + self.assertAllClose(op_class()(x), np_op(x_np)) + self.assertSameSparseness(op_class()(x), x) + + @parameterized.named_parameters(ELEMENTWISE_UNARY_OPS_TESTS) + def test_elementwise_unary_indexed_slices_correctness( + self, op_function, op_class, np_op + ): + if op_function.__name__ in ("conj", "conjugate", "imag", "real"): + x = np.array([[1 + 1j, 2 + 2j, 3 + 3j], [3 + 3j, 2 + 2j, 1 + 1j]]) + else: + x = np.array([[1, 0.5, -0.7], [0.9, 0.2, -1]]) + x = create_indexed_slices(x) + x_np = backend.convert_to_numpy(x) + + self.assertAllClose(op_function(x), np_op(x_np)) + self.assertSameSparseness(op_function(x), x) + self.assertAllClose(op_class()(x), np_op(x_np)) + self.assertSameSparseness(op_class()(x), x) + + @parameterized.named_parameters(OTHER_UNARY_OPS_TESTS) + def test_other_unary_sparse_correctness( + self, op_function, op_class, np_op, init_kwargs, op_kwargs, input_shape + ): + x = np.random.random(input_shape) + if op_function is knp.mean: + x = create_indexed_slices(x) + else: + x = create_sparse_tensor(x) + x_np = backend.convert_to_numpy(x) + + self.assertAllClose( + op_function(x, **init_kwargs, **op_kwargs), + np_op(x_np, **init_kwargs, **op_kwargs), + ) + self.assertAllClose( + op_class(**init_kwargs)(x, **op_kwargs), + np_op(x_np, **init_kwargs, **op_kwargs), + ) + # Reduction operations have complex and backend dependent rules about + # when the result is sparse and it is dense. + if op_function is not knp.mean: + self.assertSameSparseness( + op_function(x, **init_kwargs, **op_kwargs), x + ) + self.assertSameSparseness( + op_class(**init_kwargs)(x, **op_kwargs), x + ) + + @parameterized.named_parameters( + named_product( + BINARY_OPS_TESTS, x_sparse=[True, False], y_sparse=[True, False] + ) + ) + def test_binary_symbolic_static_shape( + self, x_sparse, y_sparse, op_function, op_class, np_op, op_sparseness + ): + x = KerasTensor([2, 3], sparse=x_sparse) + y = KerasTensor([2, 3], sparse=y_sparse) + self.assertEqual(op_function(x, y).shape, (2, 3)) + self.assertSparseness(op_function(x, y), op_sparseness(x, y)) + self.assertEqual(op_class()(x, y).shape, (2, 3)) + self.assertSparseness(op_class()(x, y), op_sparseness(x, y)) + + @parameterized.named_parameters( + named_product( + BINARY_OPS_TESTS, x_sparse=[True, False], y_sparse=[True, False] + ) + ) + def test_binary_symbolic_dynamic_shape( + self, x_sparse, y_sparse, op_function, op_class, np_op, op_sparseness + ): + x = KerasTensor([None, 3], sparse=x_sparse) + y = KerasTensor([2, None], sparse=y_sparse) + self.assertEqual(op_function(x, y).shape, (2, 3)) + self.assertSparseness(op_function(x, y), op_sparseness(x, y)) + self.assertEqual(op_class()(x, y).shape, (2, 3)) + self.assertSparseness(op_class()(x, y), op_sparseness(x, y)) + + @parameterized.named_parameters( + named_product( + BINARY_OPS_TESTS, + get_sparseness_combinations(create_sparse_tensor), + dtype=DTYPES, + ) + ) + def test_binary_correctness_sparse_tensor( + self, x, y, op_function, op_class, np_op, op_sparseness, dtype + ): + x = backend.cast(x, dtype) + y = backend.cast(y, dtype) + expected_result = np_op( + backend.convert_to_numpy(x), backend.convert_to_numpy(y) + ) + + self.assertAllClose(op_function(x, y), expected_result) + self.assertSparseness(op_function(x, y), op_sparseness(x, y)) + self.assertAllClose(op_class()(x, y), expected_result) + self.assertSparseness(op_class()(x, y), op_sparseness(x, y)) + + @parameterized.named_parameters( + named_product( + BINARY_OPS_TESTS, + get_sparseness_combinations(create_indexed_slices), + dtype=DTYPES, + ) + ) + def test_binary_correctness_indexed_slices( + self, x, y, op_function, op_class, np_op, op_sparseness, dtype + ): + x = backend.cast(x, dtype) + y = backend.cast(y, dtype) + expected_result = np_op( + backend.convert_to_numpy(x), backend.convert_to_numpy(y) + ) + + self.assertAllClose(op_function(x, y), expected_result) + self.assertSparseness(op_function(x, y), op_sparseness(x, y)) + self.assertAllClose(op_class()(x, y), expected_result) + self.assertSparseness(op_class()(x, y), op_sparseness(x, y)) + + @parameterized.named_parameters( + named_product( + sparse_type=["sparse_tensor", "indexed_slices"], + dtype=["int32", "float32"], + ) + ) + def test_divide_with_zeros_nans(self, sparse_type, dtype): + x = backend.convert_to_tensor([[0, 2, 3], [3, 2, 1]], dtype=dtype) + if sparse_type == "indexed_slices": + x = create_indexed_slices(x, start=0, delta=2) + else: + x = create_sparse_tensor(x, start=0, delta=2) + if dtype.startswith("int"): + y = [[0, 0, 3], [0, 0, 1]] + else: + y = [[np.nan, np.nan, 3], [0, 0, 1]] + y = backend.convert_to_tensor(y, dtype=dtype) + expected_result = np.divide( + backend.convert_to_numpy(x), backend.convert_to_numpy(y) + ) + + self.assertAllClose(knp.divide(x, y), expected_result) + self.assertAllClose(knp.Divide()(x, y), expected_result) + + +class NumpyDtypeTest(testing.TestCase): + """Test the dtype to verify that the behavior matches JAX.""" + + ALL_DTYPES = [ + x + for x in dtypes.ALLOWED_DTYPES + if x + not in ( + "string", + "complex64", + "complex128", + # Remove 64-bit dtypes. + "float64", + "uint64", + "int64", + ) + + dtypes.FLOAT8_TYPES # Remove float8 dtypes for the following tests + ] + [None] + INT_DTYPES = [x for x in dtypes.INT_TYPES if x not in ("uint64", "int64")] + FLOAT_DTYPES = [x for x in dtypes.FLOAT_TYPES if x not in ("float64",)] + + if backend.backend() == "torch": + ALL_DTYPES = [x for x in ALL_DTYPES if x not in ("uint16", "uint32")] + INT_DTYPES = [x for x in INT_DTYPES if x not in ("uint16", "uint32")] + elif backend.backend() == "tensorflow": + # TODO(hongyu): Re-enable uint32 tests once we determine how to handle + # dtypes.result_type(uint32, int*) -> int64 promotion. + # Since TF variables require int64 to be placed on the GPU, we + # exclusively enable the int64 dtype for TF. However, JAX does not + # natively support int64, which prevents us from comparing the dtypes. + ALL_DTYPES = [x for x in ALL_DTYPES if x not in ("uint32",)] + INT_DTYPES = [x for x in INT_DTYPES if x not in ("uint32",)] + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_add(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = knp.ones((1,), dtype=dtype1) + x2 = knp.ones((1,), dtype=dtype2) + x1_jax = jnp.ones((1,), dtype=dtype1) + x2_jax = jnp.ones((1,), dtype=dtype2) + expected_dtype = standardize_dtype(jnp.add(x1_jax, x2_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.add(x1, x2).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype(knp.Add().symbolic_call(x1, x2).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_add_python_types(self, dtype): + import jax.numpy as jnp + + # We have to disable x64 for jax since jnp.add doesn't respect + # JAX_DEFAULT_DTYPE_BITS=32 in `./conftest.py`. We also need to downcast + # the expected dtype from 64 bit to 32 bit when using jax backend. + with jax_disable_x64_context(): + x = knp.ones((1,), dtype=dtype) + x_jax = jnp.ones((1,), dtype=dtype) + + # python int + expected_dtype = standardize_dtype(jnp.add(x_jax, 1).dtype) + if dtype == "float64": + expected_dtype = "float64" + elif dtype == "int64": + expected_dtype = "int64" + if backend.backend() == "jax": + expected_dtype = expected_dtype.replace("64", "32") + + self.assertEqual( + standardize_dtype(knp.add(x, 1).dtype), expected_dtype + ) + self.assertEqual( + knp.Add().symbolic_call(x, 1).dtype, expected_dtype + ) + + # python float + expected_dtype = standardize_dtype(jnp.add(x_jax, 1.0).dtype) + if dtype == "float64": + expected_dtype = "float64" + if backend.backend() == "jax": + expected_dtype = expected_dtype.replace("64", "32") + + self.assertEqual( + standardize_dtype(knp.add(x, 1.0).dtype), expected_dtype + ) + self.assertEqual( + knp.Add().symbolic_call(x, 1.0).dtype, expected_dtype + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_bartlett(self, dtype): + x = knp.ones((), dtype=dtype) + expected_dtype = backend.floatx() + + self.assertEqual( + standardize_dtype(knp.bartlett(x).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.Bartlett().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_blackman(self, dtype): + x = knp.ones((), dtype=dtype) + expected_dtype = backend.floatx() + + self.assertEqual( + standardize_dtype(knp.blackman(x).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.Blackman().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_hamming(self, dtype): + x = knp.ones((), dtype=dtype) + expected_dtype = backend.floatx() + + self.assertEqual( + standardize_dtype(knp.hamming(x).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.Hamming().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_hanning(self, dtype): + x = knp.ones((), dtype=dtype) + expected_dtype = backend.floatx() + + self.assertEqual( + standardize_dtype(knp.hanning(x).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.Hanning().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_kaiser(self, dtype): + x = knp.ones((), dtype=dtype) + beta = knp.ones((), dtype=dtype) + expected_dtype = backend.floatx() + + self.assertEqual( + standardize_dtype(knp.kaiser(x, beta).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.Kaiser(beta).symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=INT_DTYPES)) + def test_bincount(self, dtype): + import jax.numpy as jnp + + if backend.backend() == "tensorflow": + import tensorflow as tf + + if tf.test.is_gpu_available(): + self.skipTest("bincount does not work in tensorflow gpu") + + x = np.array([1, 1, 2, 3, 2, 4, 4, 5], dtype=dtype) + weights = np.array([0, 0, 3, 2, 1, 1, 4, 2], dtype=dtype) + minlength = 3 + self.assertEqual( + standardize_dtype( + knp.bincount(x, weights=weights, minlength=minlength).dtype + ), + standardize_dtype( + jnp.bincount(x, weights=weights, minlength=minlength).dtype + ), + ) + self.assertEqual( + knp.Bincount(weights=weights, minlength=minlength) + .symbolic_call(x) + .dtype, + standardize_dtype( + jnp.bincount(x, weights=weights, minlength=minlength).dtype + ), + ) + + # test float32 weights + weights = np.array([0, 0, 3, 2, 1, 1, 4, 2], dtype="float32") + self.assertEqual( + standardize_dtype(knp.bincount(x, weights=weights).dtype), + standardize_dtype(jnp.bincount(x, weights=weights).dtype), + ) + self.assertEqual( + knp.Bincount(weights=weights).symbolic_call(x).dtype, + standardize_dtype(jnp.bincount(x, weights=weights).dtype), + ) + + # test float16 weights + weights = np.array([0, 0, 3, 2, 1, 1, 4, 2], dtype="float16") + self.assertEqual( + standardize_dtype(knp.bincount(x, weights=weights).dtype), + standardize_dtype(jnp.bincount(x, weights=weights).dtype), + ) + self.assertEqual( + knp.Bincount(weights=weights).symbolic_call(x).dtype, + standardize_dtype(jnp.bincount(x, weights=weights).dtype), + ) + + # test weights=None + self.assertEqual( + standardize_dtype(knp.bincount(x).dtype), + standardize_dtype(jnp.bincount(x).dtype), + ) + self.assertEqual( + knp.Bincount().symbolic_call(x).dtype, + standardize_dtype(jnp.bincount(x).dtype), + ) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_subtract(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + if dtype1 == "bool" and dtype2 == "bool": + self.skipTest("subtract does not support bool") + + x1 = knp.ones((1,), dtype=dtype1) + x2 = knp.ones((1,), dtype=dtype2) + x1_jax = jnp.ones((1,), dtype=dtype1) + x2_jax = jnp.ones((1,), dtype=dtype2) + expected_dtype = standardize_dtype(jnp.subtract(x1_jax, x2_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.subtract(x1, x2).dtype), expected_dtype + ) + self.assertEqual( + knp.Subtract().symbolic_call(x1, x2).dtype, expected_dtype + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_subtract_python_types(self, dtype): + import jax.numpy as jnp + + # We have to disable x64 for jax since jnp.subtract doesn't respect + # JAX_DEFAULT_DTYPE_BITS=32 in `./conftest.py`. We also need to downcast + # the expected dtype from 64 bit to 32 bit when using jax backend. + with jax_disable_x64_context(): + x = knp.ones((1,), dtype=dtype) + x_jax = jnp.ones((1,), dtype=dtype) + + # python int + expected_dtype = standardize_dtype(jnp.subtract(x_jax, 1).dtype) + if dtype == "float64": + expected_dtype = "float64" + elif dtype == "int64": + expected_dtype = "int64" + if backend.backend() == "jax": + expected_dtype = expected_dtype.replace("64", "32") + + self.assertEqual( + standardize_dtype(knp.subtract(x, 1).dtype), expected_dtype + ) + self.assertEqual( + knp.Subtract().symbolic_call(x, 1).dtype, expected_dtype + ) + + # python float + expected_dtype = standardize_dtype(jnp.subtract(x_jax, 1.0).dtype) + if dtype == "float64": + expected_dtype = "float64" + if backend.backend() == "jax": + expected_dtype = expected_dtype.replace("64", "32") + + self.assertEqual( + standardize_dtype(knp.subtract(x, 1.0).dtype), expected_dtype + ) + self.assertEqual( + knp.Subtract().symbolic_call(x, 1.0).dtype, expected_dtype + ) + + @parameterized.named_parameters( + named_product( + dtypes=list(itertools.combinations(ALL_DTYPES, 2)) + + [("int8", "int8")] + ) + ) + def test_matmul(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + # The shape of the matrix needs to meet the requirements of + # torch._int_mm to test hardware-accelerated matmul + x1 = knp.ones((17, 16), dtype=dtype1) + x2 = knp.ones((16, 8), dtype=dtype2) + x1_jax = jnp.ones((17, 16), dtype=dtype1) + x2_jax = jnp.ones((16, 8), dtype=dtype2) + if dtype1 == "int8" and dtype2 == "int8": + preferred_element_type = "int32" + else: + preferred_element_type = None + expected_dtype = standardize_dtype( + jnp.matmul( + x1_jax, x2_jax, preferred_element_type=preferred_element_type + ).dtype + ) + + self.assertEqual( + standardize_dtype(knp.matmul(x1, x2).dtype), expected_dtype + ) + self.assertEqual( + knp.Matmul().symbolic_call(x1, x2).dtype, expected_dtype + ) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_multiply(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = knp.ones((1,), dtype=dtype1) + x2 = knp.ones((1,), dtype=dtype2) + x1_jax = jnp.ones((1,), dtype=dtype1) + x2_jax = jnp.ones((1,), dtype=dtype2) + expected_dtype = standardize_dtype(jnp.multiply(x1_jax, x2_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.multiply(x1, x2).dtype), expected_dtype + ) + self.assertEqual( + knp.Multiply().symbolic_call(x1, x2).dtype, expected_dtype + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_multiply_python_types(self, dtype): + import jax.numpy as jnp + + # We have to disable x64 for jax since jnp.multiply doesn't respect + # JAX_DEFAULT_DTYPE_BITS=32 in `./conftest.py`. We also need to downcast + # the expected dtype from 64 bit to 32 bit when using jax backend. + with jax_disable_x64_context(): + x = knp.ones((1,), dtype=dtype) + x_jax = jnp.ones((1,), dtype=dtype) + + # python int + expected_dtype = standardize_dtype(jnp.multiply(x_jax, 1).dtype) + if dtype == "float64": + expected_dtype = "float64" + elif dtype == "int64": + expected_dtype = "int64" + if backend.backend() == "jax": + expected_dtype = expected_dtype.replace("64", "32") + + self.assertEqual( + standardize_dtype(knp.multiply(x, 1).dtype), expected_dtype + ) + self.assertEqual( + knp.Multiply().symbolic_call(x, 1).dtype, expected_dtype + ) + + # python float + expected_dtype = standardize_dtype(jnp.multiply(x_jax, 1.0).dtype) + if dtype == "float64": + expected_dtype = "float64" + if backend.backend() == "jax": + expected_dtype = expected_dtype.replace("64", "32") + + self.assertEqual( + standardize_dtype(knp.multiply(x, 1.0).dtype), expected_dtype + ) + self.assertEqual( + knp.Multiply().symbolic_call(x, 1.0).dtype, expected_dtype + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_mean(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1,), dtype=dtype) + x_jax = jnp.ones((1,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.mean(x_jax).dtype) + if dtype == "int64": + expected_dtype = "float32" + + self.assertEqual(standardize_dtype(knp.mean(x).dtype), expected_dtype) + self.assertEqual(knp.Mean().symbolic_call(x).dtype, expected_dtype) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_max(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1,), dtype=dtype) + x_jax = jnp.ones((1,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.max(x_jax).dtype) + + self.assertEqual(standardize_dtype(knp.max(x).dtype), expected_dtype) + self.assertEqual(knp.Max().symbolic_call(x).dtype, expected_dtype) + + # Test with initial + initial = 1 + expected_dtype = standardize_dtype( + jnp.max(x_jax, initial=initial).dtype + ) + self.assertEqual( + standardize_dtype(knp.max(x, initial=initial).dtype), expected_dtype + ) + self.assertEqual( + knp.Max(initial=initial).symbolic_call(x).dtype, expected_dtype + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_ones(self, dtype): + import jax.numpy as jnp + + expected_dtype = standardize_dtype(jnp.ones([2, 3], dtype=dtype).dtype) + + self.assertEqual( + standardize_dtype(knp.ones([2, 3], dtype=dtype).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_zeros(self, dtype): + import jax.numpy as jnp + + expected_dtype = standardize_dtype(jnp.zeros([2, 3], dtype=dtype).dtype) + + self.assertEqual( + standardize_dtype(knp.zeros([2, 3], dtype=dtype).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_absolute(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1,), dtype=dtype) + x_jax = jnp.ones((1,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.absolute(x_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.absolute(x).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.Absolute().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_all(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1,), dtype=dtype) + x_jax = jnp.ones((1,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.all(x_jax).dtype) + + self.assertEqual(standardize_dtype(knp.all(x).dtype), expected_dtype) + self.assertEqual( + standardize_dtype(knp.All().symbolic_call(x).dtype), expected_dtype + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_amax(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1,), dtype=dtype) + x_jax = jnp.ones((1,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.amax(x_jax).dtype) + + self.assertEqual(standardize_dtype(knp.amax(x).dtype), expected_dtype) + self.assertEqual( + standardize_dtype(knp.Amax().symbolic_call(x).dtype), expected_dtype + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_amin(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1,), dtype=dtype) + x_jax = jnp.ones((1,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.amin(x_jax).dtype) + + self.assertEqual(standardize_dtype(knp.amin(x).dtype), expected_dtype) + self.assertEqual( + standardize_dtype(knp.Amin().symbolic_call(x).dtype), expected_dtype + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_any(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1,), dtype=dtype) + x_jax = jnp.ones((1,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.any(x_jax).dtype) + + self.assertEqual(standardize_dtype(knp.any(x).dtype), expected_dtype) + self.assertEqual( + standardize_dtype(knp.Any().symbolic_call(x).dtype), expected_dtype + ) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_append(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = knp.ones((1,), dtype=dtype1) + x2 = knp.ones((1,), dtype=dtype2) + x1_jax = jnp.ones((1,), dtype=dtype1) + x2_jax = jnp.ones((1,), dtype=dtype2) + expected_dtype = standardize_dtype(jnp.append(x1_jax, x2_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.append(x1, x2).dtype), expected_dtype + ) + self.assertEqual( + knp.Append().symbolic_call(x1, x2).dtype, expected_dtype + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_argmax(self, dtype): + import jax.numpy as jnp + + if dtype == "bool": + value = [[True, False, True], [False, True, False]] + else: + value = [[1, 2, 3], [3, 2, 1]] + x = knp.array(value, dtype=dtype) + x_jax = jnp.array(value, dtype=dtype) + expected_dtype = standardize_dtype(jnp.argmax(x_jax).dtype) + + self.assertEqual(standardize_dtype(knp.argmax(x).dtype), expected_dtype) + self.assertEqual( + standardize_dtype(knp.Argmax().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_argmin(self, dtype): + import jax.numpy as jnp + + if dtype == "bool": + value = [[True, False, True], [False, True, False]] + else: + value = [[1, 2, 3], [3, 2, 1]] + x = knp.array(value, dtype=dtype) + x_jax = jnp.array(value, dtype=dtype) + expected_dtype = standardize_dtype(jnp.argmin(x_jax).dtype) + + self.assertEqual(standardize_dtype(knp.argmin(x).dtype), expected_dtype) + self.assertEqual( + standardize_dtype(knp.Argmin().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_argpartition(self, dtype): + import jax.numpy as jnp + + if dtype == "bool": + self.skipTest("argpartition doesn't support bool dtype") + + x = knp.array([1, 2, 3], dtype=dtype) + x_jax = jnp.array([1, 2, 3], dtype=dtype) + expected_dtype = standardize_dtype(jnp.argpartition(x_jax, 1).dtype) + + self.assertEqual( + standardize_dtype(knp.argpartition(x, 1).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.Argpartition(1).symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_argsort(self, dtype): + import jax.numpy as jnp + + if dtype == "bool": + value = [[True, False, True], [False, True, False]] + else: + value = [[1, 2, 3], [4, 5, 6]] + x = knp.array(value, dtype=dtype) + x_jax = jnp.array(value, dtype=dtype) + expected_dtype = standardize_dtype(jnp.argsort(x_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.argsort(x).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.Argsort().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.parameters( + (10, None, None, None), # stop + (2, 10, None, None), # start, stop + (10, None, 2, None), # stop, step + (0, 10, 2, None), # start, stop, step + (0, 10, 0.5, None), + (10.0, None, 1, None), + (0, 10.0, 1, None), + (0.0, 10, 1, None), + (10, None, 1, "float32"), + (10, None, 1, "int32"), + (10, None, 1, "int16"), + (10, None, 1, "float16"), + ) + def test_arange(self, start, stop, step, dtype): + import jax.numpy as jnp + + expected_dtype = standardize_dtype( + jnp.arange(start, stop, step, dtype).dtype + ) + + self.assertEqual( + standardize_dtype(knp.arange(start, stop, step, dtype).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype( + knp.Arange(dtype).symbolic_call(start, stop, step).dtype + ), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_arccos(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1,), dtype=dtype) + x_jax = jnp.ones((1,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.arccos(x_jax).dtype) + if dtype == "int64": + expected_dtype = backend.floatx() + + self.assertEqual(standardize_dtype(knp.arccos(x).dtype), expected_dtype) + self.assertEqual( + standardize_dtype(knp.Arccos().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_arccosh(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1,), dtype=dtype) + x_jax = jnp.ones((1,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.arccosh(x_jax).dtype) + if dtype == "int64": + expected_dtype = backend.floatx() + + self.assertEqual( + standardize_dtype(knp.arccosh(x).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.Arccosh().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_arcsin(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1,), dtype=dtype) + x_jax = jnp.ones((1,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.arcsin(x_jax).dtype) + if dtype == "int64": + expected_dtype = backend.floatx() + + self.assertEqual(standardize_dtype(knp.arcsin(x).dtype), expected_dtype) + self.assertEqual( + standardize_dtype(knp.Arcsin().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_arcsinh(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1,), dtype=dtype) + x_jax = jnp.ones((1,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.arcsinh(x_jax).dtype) + if dtype == "int64": + expected_dtype = backend.floatx() + + self.assertEqual( + standardize_dtype(knp.arcsinh(x).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.Arcsinh().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_arctan(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1,), dtype=dtype) + x_jax = jnp.ones((1,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.arctan(x_jax).dtype) + if dtype == "int64": + expected_dtype = backend.floatx() + + self.assertEqual(standardize_dtype(knp.arctan(x).dtype), expected_dtype) + self.assertEqual( + standardize_dtype(knp.Arctan().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_arctan2(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = knp.ones((1,), dtype=dtype1) + x2 = knp.ones((1,), dtype=dtype2) + x1_jax = jnp.ones((1,), dtype=dtype1) + x2_jax = jnp.ones((1,), dtype=dtype2) + expected_dtype = standardize_dtype(jnp.arctan2(x1_jax, x2_jax).dtype) + if dtype1 is not None and "float" not in dtype1: + if dtype2 is not None and "float" not in dtype2: + if "int64" in (dtype1, dtype2) or "uint32" in (dtype1, dtype2): + expected_dtype = backend.floatx() + + self.assertEqual( + standardize_dtype(knp.arctan2(x1, x2).dtype), expected_dtype + ) + self.assertEqual( + knp.Arctan2().symbolic_call(x1, x2).dtype, expected_dtype + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_arctanh(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1,), dtype=dtype) + x_jax = jnp.ones((1,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.arctanh(x_jax).dtype) + if dtype == "int64": + expected_dtype = backend.floatx() + + self.assertEqual( + standardize_dtype(knp.arctanh(x).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.Arctanh().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.parameters( + (bool(0), "bool"), + (int(0), "int32"), + (float(0), backend.floatx()), + ([False, True, False], "bool"), + ([1, 2, 3], "int32"), + ([1.0, 2.0, 3.0], backend.floatx()), + ([1, 2.0, 3], backend.floatx()), + ([[False], [True], [False]], "bool"), + ([[1], [2], [3]], "int32"), + ([[1], [2.0], [3]], backend.floatx()), + *[ + (np.array(0, dtype=dtype), dtype) + for dtype in ALL_DTYPES + if dtype is not None + ], + ) + def test_array(self, x, expected_dtype): + # We have to disable x64 for jax backend since jnp.array doesn't respect + # JAX_DEFAULT_DTYPE_BITS=32 in `./conftest.py`. We also need to downcast + # the expected dtype from 64 bit to 32 bit. + if backend.backend() == "jax": + jax_disable_x64 = jax_disable_x64_context() + expected_dtype = expected_dtype.replace("64", "32") + else: + jax_disable_x64 = contextlib.nullcontext() + + with jax_disable_x64: + self.assertEqual( + standardize_dtype(knp.array(x).dtype), expected_dtype + ) + # TODO: support the assertion of knp.Array + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_average(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = knp.ones((1,), dtype=dtype1) + x2 = knp.ones((1,), dtype=dtype2) + x1_jax = jnp.ones((1,), dtype=dtype1) + x2_jax = jnp.ones((1,), dtype=dtype2) + expected_dtype = standardize_dtype( + jnp.average(x1_jax, weights=x2_jax).dtype + ) + if dtype1 is not None and "float" not in dtype1: + if dtype2 is not None and "float" not in dtype2: + if "int64" in (dtype1, dtype2) or "uint32" in (dtype1, dtype2): + expected_dtype = backend.floatx() + + self.assertEqual( + standardize_dtype(knp.average(x1, weights=x2).dtype), expected_dtype + ) + self.assertEqual( + knp.Average().symbolic_call(x1, weights=x2).dtype, expected_dtype + ) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(INT_DTYPES, 2)) + ) + def test_bitwise_and(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = knp.ones((1,), dtype=dtype1) + x2 = knp.ones((1,), dtype=dtype2) + x1_jax = jnp.ones((1,), dtype=dtype1) + x2_jax = jnp.ones((1,), dtype=dtype2) + expected_dtype = standardize_dtype( + jnp.bitwise_and(x1_jax, x2_jax).dtype + ) + + self.assertDType(knp.bitwise_and(x1, x2), expected_dtype) + self.assertDType(knp.BitwiseAnd().symbolic_call(x1, x2), expected_dtype) + + @parameterized.named_parameters(named_product(dtype=INT_DTYPES)) + def test_bitwise_invert(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1,), dtype=dtype) + x_jax = jnp.ones((1,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.invert(x_jax).dtype) + + self.assertDType(knp.bitwise_invert(x), expected_dtype) + self.assertDType(knp.BitwiseInvert().symbolic_call(x), expected_dtype) + + # bitwise_not is same as bitwise_invert + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(INT_DTYPES, 2)) + ) + def test_bitwise_or(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = knp.ones((1,), dtype=dtype1) + x2 = knp.ones((1,), dtype=dtype2) + x1_jax = jnp.ones((1,), dtype=dtype1) + x2_jax = jnp.ones((1,), dtype=dtype2) + expected_dtype = standardize_dtype(jnp.bitwise_or(x1_jax, x2_jax).dtype) + + self.assertDType(knp.bitwise_or(x1, x2), expected_dtype) + self.assertDType(knp.BitwiseOr().symbolic_call(x1, x2), expected_dtype) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(INT_DTYPES, 2)) + ) + def test_bitwise_xor(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = knp.ones((1,), dtype=dtype1) + x2 = knp.ones((1,), dtype=dtype2) + x1_jax = jnp.ones((1,), dtype=dtype1) + x2_jax = jnp.ones((1,), dtype=dtype2) + expected_dtype = standardize_dtype( + jnp.bitwise_xor(x1_jax, x2_jax).dtype + ) + + self.assertDType(knp.bitwise_xor(x1, x2), expected_dtype) + self.assertDType(knp.BitwiseXor().symbolic_call(x1, x2), expected_dtype) + + @parameterized.named_parameters( + named_product(dtypes=itertools.product(INT_DTYPES, INT_DTYPES + [None])) + ) + def test_bitwise_left_shift(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = knp.ones((1,), dtype=dtype1) + x2 = knp.ones((1,), dtype=dtype2) if dtype2 else 1 + x1_jax = jnp.ones((1,), dtype=dtype1) + x2_jax = jnp.ones((1,), dtype=dtype2) if dtype2 else 1 + expected_dtype = standardize_dtype(jnp.left_shift(x1_jax, x2_jax).dtype) + + self.assertDType(knp.bitwise_left_shift(x1, x2), expected_dtype) + self.assertDType( + knp.BitwiseLeftShift().symbolic_call(x1, x2), expected_dtype + ) + + # left_shift is same as bitwise_left_shift + + @parameterized.named_parameters( + named_product(dtypes=itertools.product(INT_DTYPES, INT_DTYPES + [None])) + ) + def test_bitwise_right_shift(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = knp.ones((1,), dtype=dtype1) + x2 = knp.ones((1,), dtype=dtype2) if dtype2 else 1 + x1_jax = jnp.ones((1,), dtype=dtype1) + x2_jax = jnp.ones((1,), dtype=dtype2) if dtype2 else 1 + expected_dtype = standardize_dtype( + jnp.right_shift(x1_jax, x2_jax).dtype + ) + + self.assertDType(knp.bitwise_right_shift(x1, x2), expected_dtype) + self.assertDType( + knp.BitwiseRightShift().symbolic_call(x1, x2), expected_dtype + ) + + # right_shift is same as bitwise_right_shift + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_broadcast_to(self, dtype): + import jax.numpy as jnp + + x = knp.ones((3,), dtype=dtype) + x_jax = jnp.ones((3,), dtype=dtype) + expected_dtype = standardize_dtype( + jnp.broadcast_to(x_jax, (3, 3)).dtype + ) + + self.assertEqual( + standardize_dtype(knp.broadcast_to(x, (3, 3)).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.BroadcastTo((3, 3)).symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_cbrt(self, dtype): + import jax.numpy as jnp + + x1 = knp.ones((1,), dtype=dtype) + x1_jax = jnp.ones((1,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.cbrt(x1_jax).dtype) + + self.assertEqual(standardize_dtype(knp.cbrt(x1).dtype), expected_dtype) + self.assertEqual( + standardize_dtype(knp.Cbrt().symbolic_call(x1).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_ceil(self, dtype): + import jax.numpy as jnp + + if dtype is None: + dtype = backend.floatx() + if dtype == "bool": + value = [[True, False, True], [True, False, True]] + elif "int" in dtype: + value = [[1, 2, 2], [2, 11, 5]] + else: + value = [[1.2, 2.1, 2.5], [2.4, 11.9, 5.5]] + x = knp.array(value, dtype=dtype) + x_jax = jnp.array(value, dtype=dtype) + expected_dtype = standardize_dtype(jnp.ceil(x_jax).dtype) + # Here, we follow Numpy's rule, not JAX's; ints are promoted to floats. + if dtype == "bool" or is_int_dtype(dtype): + expected_dtype = backend.floatx() + + self.assertEqual(standardize_dtype(knp.ceil(x).dtype), expected_dtype) + self.assertEqual( + standardize_dtype(knp.Ceil().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_clip(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1,), dtype=dtype) + x_jax = jnp.ones((1,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.clip(x_jax, 1, 2).dtype) + if dtype == "bool": + expected_dtype = "int32" + + self.assertEqual( + standardize_dtype(knp.clip(x, 1, 2).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.Clip(1, 2).symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_concatenate(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = knp.ones((1,), dtype=dtype1) + x2 = knp.ones((1,), dtype=dtype2) + x1_jax = jnp.ones((1,), dtype=dtype1) + x2_jax = jnp.ones((1,), dtype=dtype2) + expected_dtype = standardize_dtype( + jnp.concatenate([x1_jax, x2_jax]).dtype + ) + + self.assertEqual( + standardize_dtype(knp.concatenate([x1, x2]).dtype), expected_dtype + ) + self.assertEqual( + knp.Concatenate().symbolic_call([x1, x2]).dtype, expected_dtype + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_cos(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1,), dtype=dtype) + x_jax = jnp.ones((1,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.cos(x_jax).dtype) + if dtype == "int64": + expected_dtype = backend.floatx() + + self.assertEqual(standardize_dtype(knp.cos(x).dtype), expected_dtype) + self.assertEqual( + standardize_dtype(knp.Cos().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_cosh(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1,), dtype=dtype) + x_jax = jnp.ones((1,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.cosh(x_jax).dtype) + if dtype == "int64": + expected_dtype = backend.floatx() + + self.assertEqual(standardize_dtype(knp.cosh(x).dtype), expected_dtype) + self.assertEqual( + standardize_dtype(knp.Cosh().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_copy(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1,), dtype=dtype) + x_jax = jnp.ones((1,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.copy(x_jax).dtype) + + self.assertEqual(standardize_dtype(knp.copy(x).dtype), expected_dtype) + self.assertEqual( + standardize_dtype(knp.Copy().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_corrcoef(self, dtype): + import jax.numpy as jnp + + x = knp.ones((2, 4), dtype=dtype) + x_jax = jnp.ones((2, 4), dtype=dtype) + expected_dtype = standardize_dtype(jnp.corrcoef(x_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.corrcoef(x).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.Corrcoef().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_correlate(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = knp.ones((3,), dtype=dtype1) + x2 = knp.ones((3,), dtype=dtype2) + x1_jax = jnp.ones((3,), dtype=dtype1) + x2_jax = jnp.ones((3,), dtype=dtype2) + expected_dtype = standardize_dtype(jnp.correlate(x1_jax, x2_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.correlate(x1, x2).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.Correlate().symbolic_call(x1, x2).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_count_nonzero(self, dtype): + x = knp.ones((1,), dtype=dtype) + expected_dtype = "int32" + + self.assertEqual( + standardize_dtype(knp.count_nonzero(x).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.CountNonzero().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_cross(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = knp.ones((1, 1, 3), dtype=dtype1) + x2 = knp.ones((1, 1, 3), dtype=dtype2) + x1_jax = jnp.ones((1, 1, 3), dtype=dtype1) + x2_jax = jnp.ones((1, 1, 3), dtype=dtype2) + expected_dtype = standardize_dtype(jnp.cross(x1_jax, x2_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.cross(x1, x2).dtype), expected_dtype + ) + self.assertEqual( + knp.Cross().symbolic_call(x1, x2).dtype, expected_dtype + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_cumprod(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1,), dtype=dtype) + x_jax = jnp.ones((1,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.cumprod(x_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.cumprod(x).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.Cumprod().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_cumsum(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1,), dtype=dtype) + x_jax = jnp.ones((1,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.cumsum(x_jax).dtype) + + self.assertEqual(standardize_dtype(knp.cumsum(x).dtype), expected_dtype) + self.assertEqual( + standardize_dtype(knp.Cumsum().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_deg2rad(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1,), dtype=dtype) + x_jax = jnp.ones((1,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.deg2rad(x_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.deg2rad(x).dtype), expected_dtype + ) + + self.assertEqual( + standardize_dtype(knp.Deg2rad().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_diag(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1,), dtype=dtype) + x_jax = jnp.ones((1,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.diag(x_jax).dtype) + + self.assertEqual(standardize_dtype(knp.diag(x).dtype), expected_dtype) + self.assertEqual( + standardize_dtype(knp.Diag().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_diagflat(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1,), dtype=dtype) + x_jax = jnp.ones((1,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.diagflat(x_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.diagflat(x).dtype), expected_dtype + ) + + self.assertEqual( + standardize_dtype(knp.Diagflat().symbolic_call(x).dtype), + expected_dtype, + ) + + x_2d = knp.ones((1, 1), dtype=dtype) + x_jax_2d = jnp.ones((1, 1), dtype=dtype) + expected_dtype_2d = standardize_dtype(jnp.diagflat(x_jax_2d).dtype) + + self.assertEqual( + standardize_dtype(knp.diagflat(x_2d).dtype), expected_dtype_2d + ) + self.assertEqual( + standardize_dtype(knp.Diagflat().symbolic_call(x_2d).dtype), + expected_dtype_2d, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_diagonal(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1, 1, 1), dtype=dtype) + x_jax = jnp.ones((1, 1, 1), dtype=dtype) + expected_dtype = standardize_dtype(jnp.diagonal(x_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.diagonal(x).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.Diagonal().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_diff(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1,), dtype=dtype) + x_jax = jnp.ones((1,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.diff(x_jax).dtype) + + self.assertEqual(standardize_dtype(knp.diff(x).dtype), expected_dtype) + self.assertEqual( + standardize_dtype(knp.Diff().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_digitize(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1,), dtype=dtype) + bins = knp.ones((1,), dtype=dtype) + x_jax = jnp.ones((1,), dtype=dtype) + x_bins = jnp.ones((1,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.digitize(x_jax, x_bins).dtype) + + self.assertEqual( + standardize_dtype(knp.digitize(x, bins).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.Digitize().symbolic_call(x, bins).dtype), + expected_dtype, + ) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_divide(self, dtypes): + import jax.numpy as jnp + + # We have to disable x64 for jax since jnp.divide doesn't respect + # JAX_DEFAULT_DTYPE_BITS=32 in `./conftest.py`. We also need to downcast + # the expected dtype from 64 bit to 32 bit when using jax backend. + with jax_disable_x64_context(): + dtype1, dtype2 = dtypes + x1 = knp.ones((1,), dtype=dtype1) + x2 = knp.ones((1,), dtype=dtype2) + x1_jax = jnp.ones((1,), dtype=dtype1) + x2_jax = jnp.ones((1,), dtype=dtype2) + expected_dtype = standardize_dtype(jnp.divide(x1_jax, x2_jax).dtype) + if "float64" in (dtype1, dtype2): + expected_dtype = "float64" + if backend.backend() == "jax": + expected_dtype = expected_dtype.replace("64", "32") + + self.assertEqual( + standardize_dtype(knp.divide(x1, x2).dtype), expected_dtype + ) + self.assertEqual( + knp.Divide().symbolic_call(x1, x2).dtype, expected_dtype + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_divide_python_types(self, dtype): + import jax.numpy as jnp + + # We have to disable x64 for jax since jnp.divide doesn't respect + # JAX_DEFAULT_DTYPE_BITS=32 in `./conftest.py`. We also need to downcast + # the expected dtype from 64 bit to 32 bit when using jax backend. + with jax_disable_x64_context(): + x = knp.ones((), dtype=dtype) + x_jax = jnp.ones((), dtype=dtype) + + # python int + expected_dtype = standardize_dtype(jnp.divide(x_jax, 1).dtype) + if dtype == "float64": + expected_dtype = "float64" + if backend.backend() == "jax": + expected_dtype = expected_dtype.replace("64", "32") + + self.assertEqual( + standardize_dtype(knp.divide(x, 1).dtype), expected_dtype + ) + self.assertEqual( + knp.Divide().symbolic_call(x, 1).dtype, expected_dtype + ) + + # python float + expected_dtype = standardize_dtype(jnp.divide(x_jax, 1.0).dtype) + if dtype == "float64": + expected_dtype = "float64" + if backend.backend() == "jax": + expected_dtype = expected_dtype.replace("64", "32") + + self.assertEqual( + standardize_dtype(knp.divide(x, 1.0).dtype), expected_dtype + ) + self.assertEqual( + knp.Divide().symbolic_call(x, 1.0).dtype, expected_dtype + ) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_dot(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = knp.ones((2, 3, 4), dtype=dtype1) + x2 = knp.ones((4, 3), dtype=dtype2) + x1_jax = jnp.ones((2, 3, 4), dtype=dtype1) + x2_jax = jnp.ones((4, 3), dtype=dtype2) + expected_dtype = standardize_dtype(jnp.dot(x1_jax, x2_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.dot(x1, x2).dtype), expected_dtype + ) + self.assertEqual(knp.Dot().symbolic_call(x1, x2).dtype, expected_dtype) + + @parameterized.named_parameters( + named_product( + dtypes=list(itertools.combinations(ALL_DTYPES, 2)) + + [("int8", "int8")] + ) + ) + def test_einsum(self, dtypes): + import jax.numpy as jnp + + def get_input_shapes(subscripts): + x1_labels = subscripts.split(",")[0] + x2_labels = subscripts.split("->")[0][len(x1_labels) + 1 :] + x1_shape = [1] * len(x1_labels) + x2_shape = [1] * len(x2_labels) + return x1_shape, x2_shape + + dtype1, dtype2 = dtypes + subscripts = "ijk,lkj->il" + x1_shape, x2_shape = get_input_shapes(subscripts) + x1 = knp.ones(x1_shape, dtype=dtype1) + x2 = knp.ones(x2_shape, dtype=dtype2) + x1_jax = jnp.ones(x1_shape, dtype=dtype1) + x2_jax = jnp.ones(x2_shape, dtype=dtype2) + if dtype1 == "int8" and dtype2 == "int8": + preferred_element_type = "int32" + else: + preferred_element_type = None + expected_dtype = standardize_dtype( + jnp.einsum( + subscripts, + x1_jax, + x2_jax, + preferred_element_type=preferred_element_type, + ).dtype + ) + + self.assertEqual( + standardize_dtype(knp.einsum(subscripts, x1, x2).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype( + knp.Einsum(subscripts).symbolic_call(x1, x2).dtype + ), + expected_dtype, + ) + + @parameterized.named_parameters( + named_product( + dtypes=list(itertools.combinations(ALL_DTYPES, 2)) + + [("int8", "int8")] + ) + ) + @pytest.mark.skipif( + backend.backend() != "tensorflow", + reason=f"{backend.backend()} doesn't implement custom ops for einsum.", + ) + def test_einsum_custom_ops_for_tensorflow(self, dtypes): + import jax.numpy as jnp + + def get_input_shapes(subscripts): + x1_labels = subscripts.split(",")[0] + x2_labels = subscripts.split("->")[0][len(x1_labels) + 1 :] + x1_shape = [1] * len(x1_labels) + x2_shape = [1] * len(x2_labels) + return x1_shape, x2_shape + + dtype1, dtype2 = dtypes + for subscripts in [ + "a,b->ab", + "ab,b->a", + "ab,bc->ac", + "ab,cb->ac", + "abc,cd->abd", + "abc,cde->abde", + "abc,dc->abd", + "abc,dce->abde", + "abc,dec->abde", + "abcd,abde->abce", + "abcd,abed->abce", + "abcd,acbe->adbe", + "abcd,adbe->acbe", + "abcd,aecd->acbe", + "abcd,aecd->aceb", + "abcd,cde->abe", + "abcd,ced->abe", + "abcd,ecd->abe", + "abcde,aebf->adbcf", + "abcde,afce->acdbf", + ]: + x1_shape, x2_shape = get_input_shapes(subscripts) + x1 = knp.ones(x1_shape, dtype=dtype1) + x2 = knp.ones(x2_shape, dtype=dtype2) + x1_jax = jnp.ones(x1_shape, dtype=dtype1) + x2_jax = jnp.ones(x2_shape, dtype=dtype2) + if dtype1 == "int8" and dtype2 == "int8": + preferred_element_type = "int32" + else: + preferred_element_type = None + expected_dtype = standardize_dtype( + jnp.einsum( + subscripts, + x1_jax, + x2_jax, + preferred_element_type=preferred_element_type, + ).dtype + ) + + self.assertEqual( + standardize_dtype(knp.einsum(subscripts, x1, x2).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype( + knp.Einsum(subscripts).symbolic_call(x1, x2).dtype + ), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_empty(self, dtype): + import jax.numpy as jnp + + expected_dtype = standardize_dtype(jnp.empty([2, 3], dtype=dtype).dtype) + + self.assertEqual( + standardize_dtype(knp.empty([2, 3], dtype=dtype).dtype), + expected_dtype, + ) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_equal(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = knp.ones((), dtype=dtype1) + x2 = knp.ones((), dtype=dtype2) + x1_jax = jnp.ones((), dtype=dtype1) + x2_jax = jnp.ones((), dtype=dtype2) + expected_dtype = standardize_dtype(jnp.equal(x1_jax, x2_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.equal(x1, x2).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.Equal().symbolic_call(x1, x2).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_exp(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1,), dtype=dtype) + x_jax = jnp.ones((1,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.exp(x_jax).dtype) + if dtype == "int64": + expected_dtype = backend.floatx() + + self.assertEqual(standardize_dtype(knp.exp(x).dtype), expected_dtype) + self.assertEqual( + standardize_dtype(knp.Exp().symbolic_call(x).dtype), expected_dtype + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_exp2(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1,), dtype=dtype) + x_jax = jnp.ones((1,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.exp2(x_jax).dtype) + if dtype == "int64": + expected_dtype = backend.floatx() + + self.assertEqual(standardize_dtype(knp.exp2(x).dtype), expected_dtype) + self.assertEqual( + standardize_dtype(knp.Exp2().symbolic_call(x).dtype), expected_dtype + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_expand_dims(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1,), dtype=dtype) + x_jax = jnp.ones((1,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.expand_dims(x_jax, -1).dtype) + + self.assertEqual( + standardize_dtype(knp.expand_dims(x, -1).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.ExpandDims(-1).symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_expm1(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1,), dtype=dtype) + x_jax = jnp.ones((1,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.expm1(x_jax).dtype) + if dtype == "int64": + expected_dtype = backend.floatx() + + self.assertEqual(standardize_dtype(knp.expm1(x).dtype), expected_dtype) + self.assertEqual( + standardize_dtype(knp.Expm1().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_eye(self, dtype): + import jax.numpy as jnp + + expected_dtype = standardize_dtype(jnp.eye(3, dtype=dtype).dtype) + if dtype is None: + expected_dtype = backend.floatx() + + self.assertEqual( + standardize_dtype(knp.eye(3, dtype=dtype).dtype), + expected_dtype, + ) + + expected_dtype = standardize_dtype(jnp.eye(3, 4, 1, dtype=dtype).dtype) + if dtype is None: + expected_dtype = backend.floatx() + + self.assertEqual( + standardize_dtype(knp.eye(3, 4, k=1, dtype=dtype).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_flip(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1,), dtype=dtype) + x_jax = jnp.ones((1,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.flip(x_jax, -1).dtype) + + self.assertEqual( + standardize_dtype(knp.flip(x, -1).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.Flip(-1).symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_floor(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1,), dtype=dtype) + x_jax = jnp.ones((1,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.floor(x_jax).dtype) + # Here, we follow Numpy's rule, not JAX's; ints are promoted to floats. + if dtype == "bool" or is_int_dtype(dtype): + expected_dtype = backend.floatx() + + self.assertEqual(standardize_dtype(knp.floor(x).dtype), expected_dtype) + self.assertEqual( + standardize_dtype(knp.Floor().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_floor_divide(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = knp.ones((), dtype=dtype1) + x2 = knp.ones((), dtype=dtype2) + x1_jax = jnp.ones((), dtype=dtype1) + x2_jax = jnp.ones((), dtype=dtype2) + expected_dtype = standardize_dtype( + jnp.floor_divide(x1_jax, x2_jax).dtype + ) + + self.assertEqual( + standardize_dtype(knp.floor_divide(x1, x2).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.FloorDivide().symbolic_call(x1, x2).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_floor_divide_python_types(self, dtype): + import jax.numpy as jnp + + # We have to disable x64 for jax since jnp.floor_divide doesn't respect + # JAX_DEFAULT_DTYPE_BITS=32 in `./conftest.py`. We also need to downcast + # the expected dtype from 64 bit to 32 bit when using jax backend. + with jax_disable_x64_context(): + x = knp.ones((), dtype=dtype) + x_jax = jnp.ones((), dtype=dtype) + + # python int + expected_dtype = standardize_dtype(jnp.floor_divide(x_jax, 1).dtype) + if dtype == "float64": + expected_dtype = "float64" + elif dtype == "int64": + expected_dtype = "int64" + if backend.backend() == "jax": + expected_dtype = expected_dtype.replace("64", "32") + + self.assertEqual( + standardize_dtype(knp.floor_divide(x, 1).dtype), expected_dtype + ) + self.assertEqual( + knp.FloorDivide().symbolic_call(x, 1).dtype, expected_dtype + ) + + # python float + expected_dtype = standardize_dtype( + jnp.floor_divide(x_jax, 1.0).dtype + ) + if dtype == "float64": + expected_dtype = "float64" + if backend.backend() == "jax": + expected_dtype = expected_dtype.replace("64", "32") + + self.assertEqual( + standardize_dtype(knp.floor_divide(x, 1.0).dtype), + expected_dtype, + ) + self.assertEqual( + knp.FloorDivide().symbolic_call(x, 1.0).dtype, expected_dtype + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_full(self, dtype): + import jax.numpy as jnp + + expected_dtype = standardize_dtype(jnp.full((), 0, dtype=dtype).dtype) + if dtype is None: + expected_dtype = backend.floatx() + + self.assertEqual( + standardize_dtype(knp.full((), 0, dtype=dtype).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype(knp.Full((), dtype=dtype).symbolic_call(0).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_full_like(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1,), dtype=dtype) + x_jax = jnp.ones((1,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.full_like(x_jax, 0).dtype) + + self.assertEqual( + standardize_dtype(knp.full_like(x, 0).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.FullLike().symbolic_call(x, 0).dtype), + expected_dtype, + ) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(INT_DTYPES, 2)) + ) + def test_gcd(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = knp.ones((), dtype=dtype1) + x2 = knp.ones((), dtype=dtype2) + x1_jax = jnp.ones((), dtype=dtype1) + x2_jax = jnp.ones((), dtype=dtype2) + expected_dtype = standardize_dtype(jnp.gcd(x1_jax, x2_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.gcd(x1, x2).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.Gcd().symbolic_call(x1, x2).dtype), + expected_dtype, + ) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_greater(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = knp.ones((), dtype=dtype1) + x2 = knp.ones((), dtype=dtype2) + x1_jax = jnp.ones((), dtype=dtype1) + x2_jax = jnp.ones((), dtype=dtype2) + expected_dtype = standardize_dtype(jnp.greater(x1_jax, x2_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.greater(x1, x2).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.Greater().symbolic_call(x1, x2).dtype), + expected_dtype, + ) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_greater_equal(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = knp.ones((), dtype=dtype1) + x2 = knp.ones((), dtype=dtype2) + x1_jax = jnp.ones((), dtype=dtype1) + x2_jax = jnp.ones((), dtype=dtype2) + expected_dtype = standardize_dtype( + jnp.greater_equal(x1_jax, x2_jax).dtype + ) + + self.assertEqual( + standardize_dtype(knp.greater_equal(x1, x2).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.GreaterEqual().symbolic_call(x1, x2).dtype), + expected_dtype, + ) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_heaviside(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = knp.ones((1, 1), dtype=dtype1) + x2 = knp.ones((1, 1), dtype=dtype2) + x1_jax = jnp.ones((1, 1), dtype=dtype1) + x2_jax = jnp.ones((1, 1), dtype=dtype2) + expected_dtype = standardize_dtype(jnp.heaviside(x1_jax, x2_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.heaviside(x1, x2).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.Heaviside().symbolic_call(x1, x2).dtype), + expected_dtype, + ) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_hstack(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = knp.ones((1, 1), dtype=dtype1) + x2 = knp.ones((1, 1), dtype=dtype2) + x1_jax = jnp.ones((1, 1), dtype=dtype1) + x2_jax = jnp.ones((1, 1), dtype=dtype2) + expected_dtype = standardize_dtype(jnp.hstack([x1_jax, x2_jax]).dtype) + + self.assertEqual( + standardize_dtype(knp.hstack([x1, x2]).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.Hstack().symbolic_call([x1, x2]).dtype), + expected_dtype, + ) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_hypot(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = knp.ones((1, 1), dtype=dtype1) + x2 = knp.ones((1, 1), dtype=dtype2) + x1_jax = jnp.ones((1, 1), dtype=dtype1) + x2_jax = jnp.ones((1, 1), dtype=dtype2) + expected_dtype = standardize_dtype(jnp.hypot(x1_jax, x2_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.hypot(x1, x2).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.Hypot().symbolic_call(x1, x2).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_identity(self, dtype): + import jax.numpy as jnp + + expected_dtype = standardize_dtype(jnp.identity(3, dtype=dtype).dtype) + if dtype is None: + expected_dtype = backend.floatx() + + self.assertEqual( + standardize_dtype(knp.identity(3, dtype=dtype).dtype), + expected_dtype, + ) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_isclose(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = knp.ones((), dtype=dtype1) + x2 = knp.ones((), dtype=dtype2) + x1_jax = jnp.ones((), dtype=dtype1) + x2_jax = jnp.ones((), dtype=dtype2) + expected_dtype = standardize_dtype(jnp.isclose(x1_jax, x2_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.isclose(x1, x2).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.Isclose().symbolic_call(x1, x2).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_isfinite(self, dtype): + import jax.numpy as jnp + + x = knp.ones((), dtype=dtype) + x_jax = jnp.ones((), dtype=dtype) + expected_dtype = standardize_dtype(jnp.isfinite(x_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.isfinite(x).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.Isfinite().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_isin(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = knp.ones((), dtype=dtype1) + x2 = knp.ones((), dtype=dtype2) + x1_jax = jnp.ones((), dtype=dtype1) + x2_jax = jnp.ones((), dtype=dtype2) + expected_dtype = standardize_dtype(jnp.isin(x1_jax, x2_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.isin(x1, x2).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.IsIn().symbolic_call(x1, x2).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_isinf(self, dtype): + import jax.numpy as jnp + + x = knp.ones((), dtype=dtype) + x_jax = jnp.ones((), dtype=dtype) + expected_dtype = standardize_dtype(jnp.isinf(x_jax).dtype) + + self.assertEqual(standardize_dtype(knp.isinf(x).dtype), expected_dtype) + self.assertEqual( + standardize_dtype(knp.Isinf().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_isnan(self, dtype): + import jax.numpy as jnp + + x = knp.ones((), dtype=dtype) + x_jax = jnp.ones((), dtype=dtype) + expected_dtype = standardize_dtype(jnp.isnan(x_jax).dtype) + + self.assertEqual(standardize_dtype(knp.isnan(x).dtype), expected_dtype) + self.assertEqual( + standardize_dtype(knp.Isnan().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_isneginf(self, dtype): + import jax.numpy as jnp + + x = knp.ones((), dtype=dtype) + x_jax = jnp.ones((), dtype=dtype) + expected_dtype = standardize_dtype(jnp.isneginf(x_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.isneginf(x).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.Isneginf().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_isposinf(self, dtype): + import jax.numpy as jnp + + x = knp.ones((), dtype=dtype) + x_jax = jnp.ones((), dtype=dtype) + expected_dtype = standardize_dtype(jnp.isposinf(x_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.isposinf(x).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.Isposinf().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(INT_DTYPES, 2)) + ) + def test_kron(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = knp.ones((), dtype=dtype1) + x2 = knp.ones((), dtype=dtype2) + x1_jax = jnp.ones((), dtype=dtype1) + x2_jax = jnp.ones((), dtype=dtype2) + expected_dtype = standardize_dtype(jnp.kron(x1_jax, x2_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.kron(x1, x2).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.Kron().symbolic_call(x1, x2).dtype), + expected_dtype, + ) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(INT_DTYPES, 2)) + ) + def test_lcm(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = knp.ones((), dtype=dtype1) + x2 = knp.ones((), dtype=dtype2) + x1_jax = jnp.ones((), dtype=dtype1) + x2_jax = jnp.ones((), dtype=dtype2) + expected_dtype = standardize_dtype(jnp.lcm(x1_jax, x2_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.lcm(x1, x2).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.Lcm().symbolic_call(x1, x2).dtype), + expected_dtype, + ) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_less(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = knp.ones((), dtype=dtype1) + x2 = knp.ones((), dtype=dtype2) + x1_jax = jnp.ones((), dtype=dtype1) + x2_jax = jnp.ones((), dtype=dtype2) + expected_dtype = standardize_dtype(jnp.less(x1_jax, x2_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.less(x1, x2).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.Less().symbolic_call(x1, x2).dtype), + expected_dtype, + ) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_less_equal(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = knp.ones((), dtype=dtype1) + x2 = knp.ones((), dtype=dtype2) + x1_jax = jnp.ones((), dtype=dtype1) + x2_jax = jnp.ones((), dtype=dtype2) + expected_dtype = standardize_dtype(jnp.less_equal(x1_jax, x2_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.less_equal(x1, x2).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.LessEqual().symbolic_call(x1, x2).dtype), + expected_dtype, + ) + + @parameterized.named_parameters( + named_product( + start_and_stop=[ + [0, 10], + [0.5, 10.5], + [np.array([0, 1], "int32"), np.array([10, 20], "int32")], + [np.array([0, 1], "float32"), np.array([10, 20], "float32")], + ], + num=[0, 1, 5], + dtype=FLOAT_DTYPES + [None], + ) + ) + def test_linspace(self, start_and_stop, num, dtype): + import jax.numpy as jnp + + start, stop = start_and_stop + expected_dtype = standardize_dtype( + jnp.linspace(start, stop, num, dtype=dtype).dtype + ) + + self.assertEqual( + standardize_dtype( + knp.linspace(start, stop, num, dtype=dtype).dtype + ), + expected_dtype, + ) + self.assertEqual( + standardize_dtype( + knp.Linspace(num, dtype=dtype).symbolic_call(start, stop).dtype + ), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_log(self, dtype): + import jax.numpy as jnp + + x = knp.ones((3, 3), dtype=dtype) + x_jax = jnp.ones((3, 3), dtype=dtype) + expected_dtype = standardize_dtype(jnp.log(x_jax).dtype) + if dtype == "int64": + expected_dtype = backend.floatx() + + self.assertEqual(standardize_dtype(knp.log(x).dtype), expected_dtype) + self.assertEqual( + standardize_dtype(knp.Log().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_log10(self, dtype): + import jax.numpy as jnp + + x = knp.ones((3, 3), dtype=dtype) + x_jax = jnp.ones((3, 3), dtype=dtype) + expected_dtype = standardize_dtype(jnp.log10(x_jax).dtype) + if dtype == "int64": + expected_dtype = backend.floatx() + + self.assertEqual(standardize_dtype(knp.log10(x).dtype), expected_dtype) + self.assertEqual( + standardize_dtype(knp.Log10().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_log1p(self, dtype): + import jax.numpy as jnp + + x = knp.ones((3, 3), dtype=dtype) + x_jax = jnp.ones((3, 3), dtype=dtype) + expected_dtype = standardize_dtype(jnp.log1p(x_jax).dtype) + if dtype == "int64": + expected_dtype = backend.floatx() + + self.assertEqual(standardize_dtype(knp.log1p(x).dtype), expected_dtype) + self.assertEqual( + standardize_dtype(knp.Log1p().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_log2(self, dtype): + import jax.numpy as jnp + + x = knp.ones((3, 3), dtype=dtype) + x_jax = jnp.ones((3, 3), dtype=dtype) + expected_dtype = standardize_dtype(jnp.log2(x_jax).dtype) + if dtype == "int64": + expected_dtype = backend.floatx() + + self.assertEqual(standardize_dtype(knp.log2(x).dtype), expected_dtype) + self.assertEqual( + standardize_dtype(knp.Log2().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_logaddexp(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = knp.ones((3, 3), dtype=dtype1) + x2 = knp.ones((3, 3), dtype=dtype2) + x1_jax = jnp.ones((3, 3), dtype=dtype1) + x2_jax = jnp.ones((3, 3), dtype=dtype2) + expected_dtype = standardize_dtype(jnp.logaddexp(x1_jax, x2_jax).dtype) + # jnp.logaddexp will promote "int64" and "uint32" to "float64" + # force the promotion to `backend.floatx()` + if dtype1 is not None and "float" not in dtype1: + if dtype2 is not None and "float" not in dtype2: + if "int64" in (dtype1, dtype2) or "uint32" in (dtype1, dtype2): + expected_dtype = backend.floatx() + + self.assertEqual( + standardize_dtype(knp.logaddexp(x1, x2).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.Logaddexp().symbolic_call(x1, x2).dtype), + expected_dtype, + ) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_logaddexp2(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = knp.ones((3, 3), dtype=dtype1) + x2 = knp.ones((3, 3), dtype=dtype2) + x1_jax = jnp.ones((3, 3), dtype=dtype1) + x2_jax = jnp.ones((3, 3), dtype=dtype2) + expected_dtype = standardize_dtype(jnp.logaddexp2(x1_jax, x2_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.logaddexp2(x1, x2).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.Logaddexp2().symbolic_call(x1, x2).dtype), + expected_dtype, + ) + + @parameterized.named_parameters( + named_product( + start_and_stop=[ + [0, 10], + [0.5, 10.5], + [np.array([0, 1], "int32"), np.array([10, 20], "int32")], + [np.array([0, 1], "float32"), np.array([10, 20], "float32")], + ], + num=[0, 1, 5], + dtype=FLOAT_DTYPES + [None], + ) + ) + def test_logspace(self, start_and_stop, num, dtype): + import jax.numpy as jnp + + start, stop = start_and_stop + expected_dtype = standardize_dtype( + jnp.logspace(start, stop, num, dtype=dtype).dtype + ) + + self.assertEqual( + standardize_dtype( + knp.logspace(start, stop, num, dtype=dtype).dtype + ), + expected_dtype, + ) + self.assertEqual( + standardize_dtype( + knp.Logspace(num, dtype=dtype).symbolic_call(start, stop).dtype + ), + expected_dtype, + ) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_logical_and(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = knp.ones((), dtype=dtype1) + x2 = knp.ones((), dtype=dtype2) + x1_jax = jnp.ones((), dtype=dtype1) + x2_jax = jnp.ones((), dtype=dtype2) + expected_dtype = standardize_dtype( + jnp.logical_and(x1_jax, x2_jax).dtype + ) + + self.assertEqual( + standardize_dtype(knp.logical_and(x1, x2).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.LogicalAnd().symbolic_call(x1, x2).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_logical_not(self, dtype): + import jax.numpy as jnp + + x = knp.ones((), dtype=dtype) + x_jax = jnp.ones((), dtype=dtype) + expected_dtype = standardize_dtype(jnp.logical_not(x_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.logical_not(x).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.LogicalNot().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_logical_or(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = knp.ones((), dtype=dtype1) + x2 = knp.ones((), dtype=dtype2) + x1_jax = jnp.ones((), dtype=dtype1) + x2_jax = jnp.ones((), dtype=dtype2) + expected_dtype = standardize_dtype(jnp.logical_or(x1_jax, x2_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.logical_or(x1, x2).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.LogicalOr().symbolic_call(x1, x2).dtype), + expected_dtype, + ) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_logical_xor(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = knp.ones((), dtype=dtype1) + x2 = knp.ones((), dtype=dtype2) + x1_jax = jnp.ones((), dtype=dtype1) + x2_jax = jnp.ones((), dtype=dtype2) + expected_dtype = standardize_dtype( + jnp.logical_xor(x1_jax, x2_jax).dtype + ) + + self.assertEqual( + standardize_dtype(knp.logical_xor(x1, x2).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.LogicalXor().symbolic_call(x1, x2).dtype), + expected_dtype, + ) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_maximum(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = knp.ones((), dtype=dtype1) + x2 = knp.ones((), dtype=dtype2) + x1_jax = jnp.ones((), dtype=dtype1) + x2_jax = jnp.ones((), dtype=dtype2) + expected_dtype = standardize_dtype(jnp.maximum(x1_jax, x2_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.maximum(x1, x2).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.Maximum().symbolic_call(x1, x2).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_maximum_python_types(self, dtype): + import jax.numpy as jnp + + # We have to disable x64 for jax since jnp.maximum doesn't respect + # JAX_DEFAULT_DTYPE_BITS=32 in `./conftest.py`. + with jax_disable_x64_context(): + x = knp.ones((), dtype=dtype) + x_jax = jnp.ones((), dtype=dtype) + + # python int + expected_dtype = standardize_dtype(jnp.maximum(x_jax, 1).dtype) + if dtype == "float64": + expected_dtype = "float64" + elif dtype == "int64": + expected_dtype = "int64" + if backend.backend() == "jax": + expected_dtype = expected_dtype.replace("64", "32") + + self.assertEqual( + standardize_dtype(knp.maximum(x, 1).dtype), expected_dtype + ) + self.assertEqual( + knp.Maximum().symbolic_call(x, 1).dtype, expected_dtype + ) + + # python float + expected_dtype = standardize_dtype(jnp.maximum(x_jax, 1.0).dtype) + if dtype == "float64": + expected_dtype = "float64" + if backend.backend() == "jax": + expected_dtype = expected_dtype.replace("64", "32") + + self.assertEqual( + standardize_dtype(knp.maximum(x, 1.0).dtype), expected_dtype + ) + self.assertEqual( + knp.Maximum().symbolic_call(x, 1.0).dtype, expected_dtype + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_median(self, dtype): + import jax.numpy as jnp + + x = knp.ones((3, 3), dtype=dtype) + x_jax = jnp.ones((3, 3), dtype=dtype) + expected_dtype = standardize_dtype(jnp.median(x_jax).dtype) + if dtype == "int64": + expected_dtype = backend.floatx() + + self.assertEqual(standardize_dtype(knp.median(x).dtype), expected_dtype) + self.assertEqual( + standardize_dtype(knp.Median().symbolic_call(x).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype(knp.median(x, axis=1).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.Median(axis=1).symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_meshgrid(self, dtype): + import jax.numpy as jnp + + if dtype == "bool": + self.skipTest("meshgrid doesn't support bool dtype") + elif dtype is None: + dtype = backend.floatx() + x = knp.array([1, 2, 3], dtype=dtype) + y = knp.array([4, 5, 6], dtype=dtype) + x_jax = jnp.array([1, 2, 3], dtype=dtype) + y_jax = jnp.array([4, 5, 6], dtype=dtype) + expected_dtype = standardize_dtype(jnp.meshgrid(x_jax, y_jax)[0].dtype) + + self.assertEqual( + standardize_dtype(knp.meshgrid(x, y)[0].dtype), expected_dtype + ) + self.assertEqual( + knp.Meshgrid().symbolic_call(x, y)[0].dtype, expected_dtype + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_min(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1,), dtype=dtype) + x_jax = jnp.ones((1,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.min(x_jax).dtype) + + self.assertEqual(standardize_dtype(knp.min(x).dtype), expected_dtype) + self.assertEqual(knp.Min().symbolic_call(x).dtype, expected_dtype) + + # Test with initial + initial = 0 + expected_dtype = standardize_dtype( + jnp.min(x_jax, initial=initial).dtype + ) + self.assertEqual( + standardize_dtype(knp.min(x, initial=initial).dtype), expected_dtype + ) + self.assertEqual( + knp.Min(initial=initial).symbolic_call(x).dtype, expected_dtype + ) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_minimum(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = knp.ones((), dtype=dtype1) + x2 = knp.ones((), dtype=dtype2) + x1_jax = jnp.ones((), dtype=dtype1) + x2_jax = jnp.ones((), dtype=dtype2) + expected_dtype = standardize_dtype(jnp.minimum(x1_jax, x2_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.minimum(x1, x2).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.Minimum().symbolic_call(x1, x2).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_minimum_python_types(self, dtype): + import jax.numpy as jnp + + # We have to disable x64 for jax since jnp.minimum doesn't respect + # JAX_DEFAULT_DTYPE_BITS=32 in `./conftest.py`. + with jax_disable_x64_context(): + x = knp.ones((), dtype=dtype) + x_jax = jnp.ones((), dtype=dtype) + + # python int + expected_dtype = standardize_dtype(jnp.minimum(x_jax, 1).dtype) + if dtype == "float64": + expected_dtype = "float64" + elif dtype == "int64": + expected_dtype = "int64" + if backend.backend() == "jax": + expected_dtype = expected_dtype.replace("64", "32") + + self.assertEqual( + standardize_dtype(knp.minimum(x, 1).dtype), expected_dtype + ) + self.assertEqual( + knp.Minimum().symbolic_call(x, 1).dtype, expected_dtype + ) + + # python float + expected_dtype = standardize_dtype(jnp.minimum(x_jax, 1.0).dtype) + if dtype == "float64": + expected_dtype = "float64" + if backend.backend() == "jax": + expected_dtype = expected_dtype.replace("64", "32") + + self.assertEqual( + standardize_dtype(knp.minimum(x, 1.0).dtype), expected_dtype + ) + self.assertEqual( + knp.Minimum().symbolic_call(x, 1.0).dtype, expected_dtype + ) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_mod(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = knp.ones((), dtype=dtype1) + x2 = knp.ones((), dtype=dtype2) + x1_jax = jnp.ones((), dtype=dtype1) + x2_jax = jnp.ones((), dtype=dtype2) + expected_dtype = standardize_dtype(jnp.mod(x1_jax, x2_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.mod(x1, x2).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.Mod().symbolic_call(x1, x2).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_moveaxis(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1, 1, 1), dtype=dtype) + x_jax = jnp.ones((1, 1, 1), dtype=dtype) + expected_dtype = standardize_dtype(jnp.moveaxis(x_jax, -2, -1).dtype) + + self.assertEqual( + standardize_dtype(knp.moveaxis(x, -2, -1).dtype), expected_dtype + ) + self.assertEqual( + knp.Moveaxis(-2, -1).symbolic_call(x).dtype, expected_dtype + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_nan_to_num(self, dtype): + import jax.numpy as jnp + + x = knp.ones((), dtype=dtype) + x_jax = jnp.ones((), dtype=dtype) + expected_dtype = standardize_dtype(jnp.nan_to_num(x_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.nan_to_num(x).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.NanToNum().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_nonzero(self, dtype): + import jax.numpy as jnp + + x = knp.zeros((1,), dtype=dtype) + x_jax = jnp.zeros((1,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.nonzero(x_jax)[0].dtype) + + self.assertEqual( + standardize_dtype(knp.nonzero(x)[0].dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.Nonzero().symbolic_call(x)[0].dtype), + expected_dtype, + ) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_not_equal(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = knp.ones((), dtype=dtype1) + x2 = knp.ones((), dtype=dtype2) + x1_jax = jnp.ones((), dtype=dtype1) + x2_jax = jnp.ones((), dtype=dtype2) + expected_dtype = standardize_dtype(jnp.not_equal(x1_jax, x2_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.not_equal(x1, x2).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.NotEqual().symbolic_call(x1, x2).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_ones_like(self, dtype): + import jax.numpy as jnp + + x = knp.ones((), dtype=dtype) + x_jax = jnp.ones((), dtype=dtype) + expected_dtype = standardize_dtype(jnp.ones_like(x_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.ones_like(x).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.OnesLike().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_outer(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = knp.ones((1, 2), dtype=dtype1) + x2 = knp.ones((3, 4), dtype=dtype2) + x1_jax = jnp.ones((1, 2), dtype=dtype1) + x2_jax = jnp.ones((3, 4), dtype=dtype2) + expected_dtype = standardize_dtype(jnp.outer(x1_jax, x2_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.outer(x1, x2).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.Outer().symbolic_call(x1, x2).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_pad(self, dtype): + import jax.numpy as jnp + + x = knp.ones((2, 2, 2, 2), dtype=dtype) + x_jax = jnp.ones((2, 2, 2, 2), dtype=dtype) + pad_width = ((0, 0), (1, 1), (1, 1), (1, 1)) + + for mode in ("constant", "symmetric", "reflect"): + expected_dtype = standardize_dtype( + jnp.pad(x_jax, pad_width, mode).dtype + ) + + self.assertEqual( + standardize_dtype(knp.pad(x, pad_width, mode).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype( + knp.Pad(pad_width, mode).symbolic_call(x).dtype + ), + expected_dtype, + ) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_power(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x = knp.ones((1,), dtype=dtype1) + power = knp.ones((1,), dtype2) + x_jax = jnp.ones((1,), dtype=dtype1) + power_jax = jnp.ones((1,), dtype2) + expected_dtype = standardize_dtype(jnp.power(x_jax, power_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.power(x, power).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype(knp.Power().symbolic_call(x, power).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_power_python_types(self, dtype): + import jax.numpy as jnp + + # We have to disable x64 for jax since jnp.power doesn't respect + # JAX_DEFAULT_DTYPE_BITS=32 in `./conftest.py`. We also need to downcast + # the expected dtype from 64 bit to 32 bit when using jax backend. + with jax_disable_x64_context(): + x = knp.ones((1,), dtype=dtype) + x_jax = jnp.ones((1,), dtype=dtype) + + # python int + expected_dtype = standardize_dtype(jnp.power(x_jax, 1).dtype) + if dtype == "float64": + expected_dtype = "float64" + elif dtype == "int64": + expected_dtype = "int64" + if backend.backend() == "jax": + expected_dtype = expected_dtype.replace("64", "32") + + self.assertEqual( + standardize_dtype(knp.power(x, 1).dtype), expected_dtype + ) + self.assertEqual( + knp.Power().symbolic_call(x, 1).dtype, expected_dtype + ) + + # python float + expected_dtype = standardize_dtype(jnp.power(x_jax, 1.0).dtype) + if dtype == "float64": + expected_dtype = "float64" + if backend.backend() == "jax": + expected_dtype = expected_dtype.replace("64", "32") + + self.assertEqual( + standardize_dtype(knp.power(x, 1.0).dtype), expected_dtype + ) + self.assertEqual( + knp.Power().symbolic_call(x, 1.0).dtype, expected_dtype + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_prod(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1, 1, 1), dtype=dtype) + x_jax = jnp.ones((1, 1, 1), dtype=dtype) + expected_dtype = standardize_dtype(jnp.prod(x_jax).dtype) + # TODO: torch doesn't support uint32 + if backend.backend() == "torch" and expected_dtype == "uint32": + expected_dtype = "int32" + + self.assertEqual( + standardize_dtype(knp.prod(x).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype(knp.Prod().symbolic_call(x).dtype), expected_dtype + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_quantile(self, dtype): + import jax.numpy as jnp + + x = knp.ones((3,), dtype=dtype) + x_jax = jnp.ones((3,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.quantile(x_jax, 0.5).dtype) + if dtype == "int64": + expected_dtype = backend.floatx() + + self.assertEqual( + standardize_dtype(knp.quantile(x, 0.5).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype(knp.Quantile().symbolic_call(x, 0.5).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_searchsorted(self, dtype): + import jax.numpy as jnp + + if dtype == "bool": + self.skipTest("searchsorted doesn't support bool dtype") + + a = knp.ones((3,), dtype=dtype) + v = knp.ones((3,), dtype=dtype) + + a_jax = jnp.ones((3,), dtype=dtype) + v_jax = jnp.ones((3,), dtype=dtype) + + expected_dtype = standardize_dtype(jnp.searchsorted(a_jax, v_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.searchsorted(a, v).dtype), expected_dtype + ) + + self.assertEqual( + standardize_dtype(knp.SearchSorted().symbolic_call(a, v).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_ravel(self, dtype): + import jax.numpy as jnp + + x = knp.ones((), dtype=dtype) + x_jax = jnp.ones((), dtype=dtype) + expected_dtype = standardize_dtype(jnp.ravel(x_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.ravel(x).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype(knp.Ravel().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=INT_DTYPES)) + def test_unravel_index(self, dtype): + import jax.numpy as jnp + + x = knp.ones((3,), dtype=dtype) + x_jax = jnp.ones((3,), dtype=dtype) + + indices = knp.array([2, 0], dtype=dtype) + indices_jax = jnp.array([2, 0], dtype=dtype) + + unravel_result_knp = knp.unravel_index(indices, x.shape) + unravel_result_jax = jnp.unravel_index(indices_jax, x_jax.shape) + + expected_dtype_knp = standardize_dtype(unravel_result_knp[0].dtype) + expected_dtype_jax = standardize_dtype(unravel_result_jax[0].dtype) + + self.assertEqual(expected_dtype_knp, expected_dtype_jax) + + unravel_result_knp_symbolic = knp.UnravelIndex(x.shape).symbolic_call( + indices + ) + expected_dtype_symbolic = standardize_dtype( + unravel_result_knp_symbolic[0].dtype + ) + + self.assertEqual(expected_dtype_symbolic, expected_dtype_jax) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_repeat(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1, 1), dtype=dtype) + x_jax = jnp.ones((1, 1), dtype=dtype) + expected_dtype = standardize_dtype(jnp.repeat(x_jax, 2).dtype) + + self.assertEqual( + standardize_dtype(knp.repeat(x, 2).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype(knp.Repeat(2).symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_reshape(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1, 1), dtype=dtype) + x_jax = jnp.ones((1, 1), dtype=dtype) + expected_dtype = standardize_dtype(jnp.reshape(x_jax, [1]).dtype) + + self.assertEqual( + standardize_dtype(knp.reshape(x, [1]).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype(knp.Reshape([1]).symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_roll(self, dtype): + import jax.numpy as jnp + + x = knp.ones((5,), dtype=dtype) + x_jax = jnp.ones((5,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.roll(x_jax, 2).dtype) + + self.assertEqual( + standardize_dtype(knp.roll(x, 2).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype(knp.Roll(2).symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_round(self, dtype): + import jax.numpy as jnp + + if dtype == "bool": + self.skipTest("round doesn't support bool dtype") + x = knp.ones((), dtype=dtype) + x_jax = jnp.ones((), dtype=dtype) + expected_dtype = standardize_dtype(jnp.round(x_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.round(x).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype(knp.Round().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_sign(self, dtype): + import jax.numpy as jnp + + if dtype == "bool": + self.skipTest("sign doesn't support bool dtype") + x = knp.ones((), dtype=dtype) + x_jax = jnp.ones((), dtype=dtype) + expected_dtype = standardize_dtype(jnp.sign(x_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.sign(x).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype(knp.Sign().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_signbit(self, dtype): + import jax.numpy as jnp + + x = knp.ones((), dtype=dtype) + x_jax = jnp.ones((), dtype=dtype) + expected_dtype = standardize_dtype(jnp.signbit(x_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.signbit(x).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype(knp.Signbit().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_sin(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1,), dtype=dtype) + x_jax = jnp.ones((1,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.sin(x_jax).dtype) + if dtype == "int64": + expected_dtype = backend.floatx() + + self.assertEqual(standardize_dtype(knp.sin(x).dtype), expected_dtype) + self.assertEqual( + standardize_dtype(knp.Sin().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_sinh(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1,), dtype=dtype) + x_jax = jnp.ones((1,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.sinh(x_jax).dtype) + if dtype == "int64": + expected_dtype = backend.floatx() + + self.assertEqual(standardize_dtype(knp.sinh(x).dtype), expected_dtype) + self.assertEqual( + standardize_dtype(knp.Sinh().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_sort(self, dtype): + import jax.numpy as jnp + + x = knp.ones((2,), dtype=dtype) + x_jax = jnp.ones((2,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.sort(x_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.sort(x).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype(knp.Sort().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_split(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1, 2), dtype=dtype) + x_jax = jnp.ones((1, 2), dtype=dtype) + expected_dtype = standardize_dtype(jnp.split(x_jax, 2, -1)[0].dtype) + + self.assertEqual( + standardize_dtype(knp.split(x, 2, -1)[0].dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype(knp.Split(2, -1).symbolic_call(x)[0].dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_sqrt(self, dtype): + import jax.numpy as jnp + + x1 = knp.ones((1,), dtype=dtype) + x1_jax = jnp.ones((1,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.sqrt(x1_jax).dtype) + if dtype == "int64": + expected_dtype = backend.floatx() + + self.assertEqual(standardize_dtype(knp.sqrt(x1).dtype), expected_dtype) + self.assertEqual( + standardize_dtype(knp.Sqrt().symbolic_call(x1).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_square(self, dtype): + import jax.numpy as jnp + + x = knp.ones((), dtype=dtype) + x_jax = jnp.ones((), dtype=dtype) + expected_dtype = standardize_dtype(jnp.square(x_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.square(x).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype(knp.Square().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_squeeze(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1, 1), dtype=dtype) + x_jax = jnp.ones((1, 1), dtype=dtype) + expected_dtype = standardize_dtype(jnp.squeeze(x_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.squeeze(x).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype(knp.Squeeze().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_stack(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = knp.ones((1,), dtype=dtype1) + x2 = knp.ones((1,), dtype=dtype2) + x1_jax = jnp.ones((1,), dtype=dtype1) + x2_jax = jnp.ones((1,), dtype=dtype2) + expected_dtype = standardize_dtype(jnp.stack([x1_jax, x2_jax]).dtype) + + self.assertEqual( + standardize_dtype(knp.stack([x1, x2]).dtype), expected_dtype + ) + self.assertEqual( + knp.Stack().symbolic_call([x1, x2]).dtype, expected_dtype + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_std(self, dtype): + import jax.numpy as jnp + + x = knp.ones((), dtype=dtype) + x_jax = jnp.ones((), dtype=dtype) + expected_dtype = standardize_dtype(jnp.std(x_jax).dtype) + if dtype == "int64": + expected_dtype = backend.floatx() + + self.assertEqual( + standardize_dtype(knp.std(x).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype(knp.Std().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_sum(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1,), dtype=dtype) + x_jax = jnp.ones((1,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.sum(x_jax).dtype) + + # TODO: torch doesn't support uint32 + if backend.backend() == "torch" and expected_dtype == "uint32": + expected_dtype = "int32" + + self.assertEqual(standardize_dtype(knp.sum(x).dtype), expected_dtype) + self.assertEqual( + standardize_dtype(knp.Sum().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_swapaxes(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1, 1), dtype=dtype) + x_jax = jnp.ones((1, 1), dtype=dtype) + expected_dtype = standardize_dtype(jnp.swapaxes(x_jax, -1, -2).dtype) + + self.assertEqual( + standardize_dtype(knp.swapaxes(x, -1, -2).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype(knp.Swapaxes(-1, -2).symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_take(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1,), dtype=dtype) + x_jax = jnp.ones((1,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.take(x_jax, 0).dtype) + + self.assertEqual( + standardize_dtype(knp.take(x, 0).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype(knp.Take().symbolic_call(x, 0).dtype), + expected_dtype, + ) + + @parameterized.named_parameters( + named_product(dtype=ALL_DTYPES, indices_dtype=INT_DTYPES) + ) + def test_take_along_axis(self, dtype, indices_dtype): + import jax.numpy as jnp + + x = knp.ones((1,), dtype=dtype) + indices = knp.zeros((1,), dtype=indices_dtype) + x_jax = jnp.ones((1,), dtype=dtype) + indices_jax = jnp.zeros((1,), dtype=indices_dtype) + expected_dtype = standardize_dtype( + jnp.take_along_axis(x_jax, indices_jax, 0).dtype + ) + + self.assertEqual( + standardize_dtype(knp.take_along_axis(x, indices, 0).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype( + knp.TakeAlongAxis(0).symbolic_call(x, indices).dtype + ), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_tan(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1,), dtype=dtype) + x_jax = jnp.ones((1,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.tan(x_jax).dtype) + if dtype == "int64": + expected_dtype = backend.floatx() + + self.assertEqual(standardize_dtype(knp.tan(x).dtype), expected_dtype) + self.assertEqual( + standardize_dtype(knp.Tan().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_tanh(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1,), dtype=dtype) + x_jax = jnp.ones((1,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.tanh(x_jax).dtype) + if dtype == "int64": + expected_dtype = backend.floatx() + + self.assertEqual(standardize_dtype(knp.tanh(x).dtype), expected_dtype) + self.assertEqual( + standardize_dtype(knp.Tanh().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_tensordot(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = knp.ones((1, 1), dtype=dtype1) + x2 = knp.ones((1, 1), dtype=dtype2) + x1_jax = jnp.ones((1, 1), dtype=dtype1) + x2_jax = jnp.ones((1, 1), dtype=dtype2) + expected_dtype = standardize_dtype( + jnp.tensordot(x1_jax, x2_jax, 2).dtype + ) + + self.assertEqual( + standardize_dtype(knp.tensordot(x1, x2, 2).dtype), expected_dtype + ) + self.assertEqual( + knp.Tensordot(2).symbolic_call(x1, x2).dtype, expected_dtype + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_tile(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1,), dtype=dtype) + x_jax = jnp.ones((1,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.tile(x_jax, [1]).dtype) + + self.assertEqual( + standardize_dtype(knp.tile(x, [1]).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.Tile([1]).symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_trace(self, dtype): + import jax.numpy as jnp + + # We have to disable x64 for jax since jnp.trace doesn't respect + # JAX_DEFAULT_DTYPE_BITS=32 in `./conftest.py`. We also need to downcast + # the expected dtype from 64 bit to 32 bit when using jax backend. + with jax_disable_x64_context(): + x = knp.ones((1, 1, 1), dtype=dtype) + x_jax = jnp.ones((1, 1, 1), dtype=dtype) + expected_dtype = standardize_dtype(jnp.trace(x_jax).dtype) + # jnp.trace is buggy with bool. We set the expected_dtype to int32 + # for bool inputs + if dtype == "bool": + expected_dtype = "int32" + elif dtype == "float64": + expected_dtype = "float64" + elif dtype == "int64": + expected_dtype = "int64" + # TODO: Remove the condition of uint8 and uint16 once we have + # jax>=0.4.27 for both CPU & GPU environments. + # uint8 and uint16 will be casted to uint32 when jax>=0.4.27 but to + # int32 otherwise. + elif dtype in ("uint8", "uint16"): + expected_dtype = "int32" + if backend.backend() == "jax": + expected_dtype = expected_dtype.replace("64", "32") + + self.assertDType(knp.trace(x), expected_dtype) + self.assertDType(knp.Trace().symbolic_call(x), expected_dtype) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_transpose(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1, 1), dtype=dtype) + x_jax = jnp.ones((1, 1), dtype=dtype) + expected_dtype = standardize_dtype(jnp.transpose(x_jax, [1, 0]).dtype) + + self.assertEqual( + standardize_dtype(knp.transpose(x, [1, 0]).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.Transpose([1, 0]).symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_tri(self, dtype): + import jax.numpy as jnp + + expected_dtype = standardize_dtype(jnp.tri(3, dtype=dtype).dtype) + + self.assertEqual( + standardize_dtype(knp.tri(3, dtype=dtype).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_tril(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1, 1), dtype=dtype) + x_jax = jnp.ones((1, 1), dtype=dtype) + expected_dtype = standardize_dtype(jnp.tril(x_jax, 0).dtype) + + self.assertEqual( + standardize_dtype(knp.tril(x, 0).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.Tril(0).symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_triu(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1, 1), dtype=dtype) + x_jax = jnp.ones((1, 1), dtype=dtype) + expected_dtype = standardize_dtype(jnp.triu(x_jax, 0).dtype) + + self.assertEqual( + standardize_dtype(knp.triu(x, 0).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.Triu(0).symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_true_divide(self, dtypes): + import jax.numpy as jnp + + # We have to disable x64 for jax since jnp.true_divide doesn't respect + # JAX_DEFAULT_DTYPE_BITS=32 in `./conftest.py`. We also need to downcast + # the expected dtype from 64 bit to 32 bit when using jax backend. + with jax_disable_x64_context(): + dtype1, dtype2 = dtypes + x1 = knp.ones((1,), dtype=dtype1) + x2 = knp.ones((1,), dtype=dtype2) + x1_jax = jnp.ones((1,), dtype=dtype1) + x2_jax = jnp.ones((1,), dtype=dtype2) + expected_dtype = standardize_dtype( + jnp.true_divide(x1_jax, x2_jax).dtype + ) + if "float64" in (dtype1, dtype2): + expected_dtype = "float64" + if backend.backend() == "jax": + expected_dtype = expected_dtype.replace("64", "32") + + self.assertEqual( + standardize_dtype(knp.true_divide(x1, x2).dtype), expected_dtype + ) + self.assertEqual( + knp.TrueDivide().symbolic_call(x1, x2).dtype, expected_dtype + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_trunc(self, dtype): + x = knp.ones((1, 1), dtype=dtype) + # TODO: jax <= 0.30.0 doesn't preserve the original dtype. + expected_dtype = dtype or backend.floatx() + + self.assertEqual(standardize_dtype(knp.trunc(x).dtype), expected_dtype) + self.assertEqual( + standardize_dtype(knp.Trunc().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_var(self, dtype): + import jax.numpy as jnp + + x = knp.ones((2,), dtype=dtype) + x_jax = jnp.ones((2,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.var(x_jax).dtype) + if dtype == "int64": + expected_dtype = backend.floatx() + + self.assertEqual(standardize_dtype(knp.var(x).dtype), expected_dtype) + self.assertEqual( + standardize_dtype(knp.Var().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_vdot(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = knp.ones((1,), dtype=dtype1) + x2 = knp.ones((1,), dtype=dtype2) + x1_jax = jnp.ones((1,), dtype=dtype1) + x2_jax = jnp.ones((1,), dtype=dtype2) + expected_dtype = standardize_dtype(jnp.vdot(x1_jax, x2_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.vdot(x1, x2).dtype), expected_dtype + ) + self.assertEqual(knp.Vdot().symbolic_call(x1, x2).dtype, expected_dtype) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_inner(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = knp.ones((1,), dtype=dtype1) + x2 = knp.ones((1,), dtype=dtype2) + x1_jax = jnp.ones((1,), dtype=dtype1) + x2_jax = jnp.ones((1,), dtype=dtype2) + expected_dtype = standardize_dtype(jnp.inner(x1_jax, x2_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.inner(x1, x2).dtype), expected_dtype + ) + self.assertEqual( + knp.Inner().symbolic_call(x1, x2).dtype, expected_dtype + ) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_vstack(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = knp.ones((1,), dtype=dtype1) + x2 = knp.ones((1,), dtype=dtype2) + x1_jax = jnp.ones((1,), dtype=dtype1) + x2_jax = jnp.ones((1,), dtype=dtype2) + expected_dtype = standardize_dtype(jnp.vstack([x1_jax, x2_jax]).dtype) + + self.assertEqual( + standardize_dtype(knp.vstack([x1, x2]).dtype), expected_dtype + ) + self.assertEqual( + knp.Vstack().symbolic_call([x1, x2]).dtype, expected_dtype + ) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_where(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + condition = knp.ones((10,), dtype="bool") + x1 = knp.ones((10,), dtype=dtype1) + x2 = knp.ones((10,), dtype=dtype2) + condition_jax = jnp.ones((10,), dtype="bool") + x1_jax = jnp.ones((10,), dtype=dtype1) + x2_jax = jnp.ones((10,), dtype=dtype2) + expected_dtype = standardize_dtype( + jnp.where(condition_jax, x1_jax, x2_jax).dtype + ) + + self.assertEqual( + standardize_dtype(knp.where(condition, x1, x2).dtype), + expected_dtype, + ) + self.assertEqual( + knp.Where().symbolic_call(condition, x1, x2).dtype, expected_dtype + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_where_python_types(self, dtype): + import jax.numpy as jnp + + # We have to disable x64 for jax since jnp.power doesn't respect + # JAX_DEFAULT_DTYPE_BITS=32 in `./conftest.py`. We also need to downcast + # the expected dtype from 64 bit to 32 bit when using jax backend. + with jax_disable_x64_context(): + condition = knp.ones((10,), dtype="bool") + x = knp.ones((10,), dtype=dtype) + condition_jax = jnp.ones((10,), dtype="bool") + x_jax = jnp.ones((10,), dtype=dtype) + + # python int + expected_dtype = standardize_dtype( + jnp.where(condition_jax, x_jax, 1).dtype + ) + if dtype == "float64": + expected_dtype = "float64" + elif dtype == "int64": + expected_dtype = "int64" + if backend.backend() == "jax": + expected_dtype = expected_dtype.replace("64", "32") + + self.assertEqual( + standardize_dtype(knp.where(condition, x, 1).dtype), + expected_dtype, + ) + self.assertEqual( + knp.Where().symbolic_call(condition, x, 1).dtype, expected_dtype + ) + + # python float + expected_dtype = standardize_dtype( + jnp.where(condition_jax, x_jax, 1.0).dtype + ) + if dtype == "float64": + expected_dtype = "float64" + if backend.backend() == "jax": + expected_dtype = expected_dtype.replace("64", "32") + + self.assertEqual( + standardize_dtype(knp.where(condition, x, 1.0).dtype), + expected_dtype, + ) + self.assertEqual( + knp.Where().symbolic_call(condition, x, 1.0).dtype, + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_zeros_like(self, dtype): + import jax.numpy as jnp + + x = knp.ones((), dtype=dtype) + x_jax = jnp.ones((), dtype=dtype) + expected_dtype = standardize_dtype(jnp.ones_like(x_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.zeros_like(x).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.ZerosLike().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_angle(self, dtype): + import jax.numpy as jnp + + if dtype == "bfloat16": + self.skipTest("Weirdness with numpy") + + x = knp.ones((1,), dtype=dtype) + x_jax = jnp.ones((1,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.angle(x_jax).dtype) + if dtype == "bool" or is_int_dtype(dtype): + expected_dtype = backend.floatx() + + self.assertEqual(standardize_dtype(knp.angle(x).dtype), expected_dtype) + self.assertEqual( + standardize_dtype(knp.Angle().symbolic_call(x).dtype), + expected_dtype, + ) + + +@pytest.mark.skipif( + testing.torch_uses_gpu(), + reason="histogram op not implemented for torch on gpu", +) +class HistogramTest(testing.TestCase): + def test_histogram_default_args(self): + hist_op = knp.histogram + input_tensor = np.random.rand(8) + + # Expected output + expected_counts, expected_edges = np.histogram(input_tensor) + + counts, edges = hist_op(input_tensor) + + self.assertEqual(counts.shape, expected_counts.shape) + self.assertAllClose(counts, expected_counts) + self.assertEqual(edges.shape, expected_edges.shape) + self.assertAllClose(edges, expected_edges) + + def test_histogram_custom_bins(self): + hist_op = knp.histogram + input_tensor = np.random.rand(8) + bins = 5 + + # Expected output + expected_counts, expected_edges = np.histogram(input_tensor, bins=bins) + + counts, edges = hist_op(input_tensor, bins=bins) + + self.assertEqual(counts.shape, expected_counts.shape) + self.assertAllClose(counts, expected_counts) + self.assertEqual(edges.shape, expected_edges.shape) + self.assertAllClose(edges, expected_edges) + + def test_histogram_custom_range(self): + hist_op = knp.histogram + input_tensor = np.random.rand(10) + range_specified = (2, 8) + + # Expected output + expected_counts, expected_edges = np.histogram( + input_tensor, range=range_specified + ) + + counts, edges = hist_op(input_tensor, range=range_specified) + + self.assertEqual(counts.shape, expected_counts.shape) + self.assertAllClose(counts, expected_counts) + self.assertEqual(edges.shape, expected_edges.shape) + self.assertAllClose(edges, expected_edges) + + def test_histogram_symbolic_input(self): + hist_op = knp.histogram + input_tensor = KerasTensor(shape=(None,), dtype="float32") + + counts, edges = hist_op(input_tensor) + + self.assertEqual(counts.shape, (10,)) + self.assertEqual(edges.shape, (11,)) + + def test_histogram_non_integer_bins_raises_error(self): + hist_op = knp.histogram + input_tensor = np.random.rand(8) + + with self.assertRaisesRegex( + ValueError, "Argument `bins` should be a non-negative integer" + ): + hist_op(input_tensor, bins=-5) + + def test_histogram_range_validation(self): + hist_op = knp.histogram + input_tensor = np.random.rand(8) + + with self.assertRaisesRegex( + ValueError, "Argument `range` must be a tuple of two elements" + ): + hist_op(input_tensor, range=(1,)) + + with self.assertRaisesRegex( + ValueError, + "The second element of `range` must be greater than the first", + ): + hist_op(input_tensor, range=(5, 1)) + + def test_histogram_large_values(self): + hist_op = knp.histogram + input_tensor = np.array([1e10, 2e10, 3e10, 4e10, 5e10]) + + counts, edges = hist_op(input_tensor, bins=5) + + expected_counts, expected_edges = np.histogram(input_tensor, bins=5) + + self.assertAllClose(counts, expected_counts) + self.assertAllClose(edges, expected_edges) + + def test_histogram_float_input(self): + hist_op = knp.histogram + input_tensor = np.random.rand(8) + + counts, edges = hist_op(input_tensor, bins=5) + + expected_counts, expected_edges = np.histogram(input_tensor, bins=5) + + self.assertAllClose(counts, expected_counts) + self.assertAllClose(edges, expected_edges) + + def test_histogram_high_dimensional_input(self): + hist_op = knp.histogram + input_tensor = np.random.rand(3, 4, 5) + + with self.assertRaisesRegex( + ValueError, "Input tensor must be 1-dimensional" + ): + hist_op(input_tensor) + + def test_histogram_values_on_edges(self): + hist_op = knp.histogram + input_tensor = np.array([0.0, 2.0, 4.0, 8.0, 10.0]) + bins = 5 + + expected_counts, expected_edges = np.histogram(input_tensor, bins=bins) + counts, edges = hist_op(input_tensor, bins=bins) + + self.assertAllClose(counts, expected_counts) + self.assertAllClose(edges, expected_edges) + + # TODO: Fix predict for NumPy. + @parameterized.named_parameters( + ("jit_compile_false", False), + ("jit_compile_true", True), + ) + @pytest.mark.skipif( + backend.backend() == "numpy", + reason=( + "`predict` errors out with 'autodetected range of [nan, nan] is " + "not finite' on the NumPy backend. To be fixed." + ), + ) + def test_histogram_predict(self, jit_compile): + class HistogramLayer(keras.layers.Layer): + def call(self, x): + shape = ops.shape(x) + + # Flatten, because the op does not work with >1-dim inputs. + x = ops.reshape(x, (shape[0] * shape[1],)) + return knp.histogram(x, bins=5) + + inputs = keras.Input(shape=(8,)) + counts, edges = HistogramLayer()(inputs) + model = keras.Model(inputs, (counts, edges)) + model.compile(jit_compile=jit_compile) + + model.predict(np.random.randn(1, 8)) diff --git a/keras/src/ops/operation.py b/keras/src/ops/operation.py new file mode 100644 index 000000000000..570ce5e27c9a --- /dev/null +++ b/keras/src/ops/operation.py @@ -0,0 +1,402 @@ +import inspect +import textwrap + +from keras.src import backend +from keras.src import dtype_policies +from keras.src import tree +from keras.src.api_export import keras_export +from keras.src.backend.common.keras_tensor import any_symbolic_tensors +from keras.src.backend.config import is_nnx_enabled +from keras.src.ops.node import Node +from keras.src.saving.keras_saveable import KerasSaveable +from keras.src.utils import python_utils +from keras.src.utils import traceback_utils +from keras.src.utils.naming import auto_name + + +@keras_export("keras.Operation") +class Operation(KerasSaveable): + def __init__(self, name=None): + if name is None: + name = auto_name(self.__class__.__name__) + if not isinstance(name, str) or "/" in name: + raise ValueError( + "Argument `name` must be a string and " + f"cannot contain character `/`. " + f"Received: name={name} (of type {type(name)})" + ) + self.name = name + self._inbound_nodes = [] + self._outbound_nodes = [] + + @traceback_utils.filter_traceback + def __call__(self, *args, **kwargs): + if traceback_utils.is_traceback_filtering_enabled(): + # Wrap self.call to provide helpful info in case of exception + if any_symbolic_tensors(args, kwargs): + call_fn = self.symbolic_call + else: + if getattr(self, "_remat_mode", None) is not None: + if getattr(self, "quantization_mode", None) is not None: + call_fn = self.rematerialized_call( + self.quantized_call, + *args, + **kwargs, + ) + else: + call_fn = self.rematerialized_call( + self.call, *args, **kwargs + ) + else: + if getattr(self, "quantization_mode", None) is not None: + call_fn = self.quantized_call + else: + call_fn = self.call + call_fn = traceback_utils.inject_argument_info_in_traceback( + call_fn, + object_name=(f"{self.__class__.__name__}.call()"), + ) + return call_fn(*args, **kwargs) + + # Plain flow. + if any_symbolic_tensors(args, kwargs): + return self.symbolic_call(*args, **kwargs) + elif getattr(self, "_remat_mode", None) is not None: + if getattr(self, "quantization_mode", None) is not None: + return self.rematerialized_call( + self.quantized_call, *args, **kwargs + )(*args, **kwargs) + else: + return self.rematerialized_call(self.call, *args, **kwargs)( + *args, **kwargs + ) + else: + if getattr(self, "quantization_mode", None) is not None: + return self.quantized_call(*args, **kwargs) + else: + return self.call(*args, **kwargs) + + def symbolic_call(self, *args, **kwargs): + # Perform shape/dtype inference. + outputs = self.compute_output_spec(*args, **kwargs) + # Record a new node in the operations graph. + # The Node wires itself to inbound and outbound ops. The + # Node constructor updates this op's self._inbound_nodes, + # sets _keras_history on the outputs, and adds itself to the + # `_outbound_nodes` of the ops that produced the inputs to this + # call. + Node( + operation=self, call_args=args, call_kwargs=kwargs, outputs=outputs + ) + return outputs + + def call(self, *args, **kwargs): + raise NotImplementedError + + def quantized_call(self, *args, **kwargs): + raise NotImplementedError + + def compute_output_spec(self, *args, **kwargs): + try: + return backend.compute_output_spec(self.call, *args, **kwargs) + except Exception as e: + new_e = e.__class__( + "Could not automatically infer the output shape / dtype of " + f"'{self.name}' (of type {self.__class__.__name__}). " + f"Either the `{self.__class__.__name__}.call()` method " + f"is incorrect, or you need to implement the " + f"`{self.__class__.__name__}.compute_output_spec() / " + "compute_output_shape()` method. " + f"Error encountered:\n\n{e}" + ) + raise new_e.with_traceback(e.__traceback__) from None + + def __new__(cls, *args, **kwargs): + """We override __new__ to saving serializable constructor arguments. + + These arguments are used to auto-generate an object serialization + config, which enables user-created subclasses to be serializable + out of the box in most cases without forcing the user + to manually implement `get_config()`. + """ + instance = super(Operation, cls).__new__(cls) + if backend.backend() == "jax" and is_nnx_enabled(): + from flax import nnx + + try: + vars(instance)["_pytree__state"] = nnx.pytreelib.PytreeState() + except AttributeError: + vars(instance)["_object__state"] = nnx.object.ObjectState() + + # Generate a config to be returned by default by `get_config()`. + auto_config = True + + signature = inspect.signature(cls.__init__) + argspec = inspect.getfullargspec(cls.__init__) + + try: + bound_parameters = signature.bind(None, *args, **kwargs) + except TypeError: + # Raised by signature.bind when the supplied args and kwargs + # do not match the signature. + auto_config = False + + if auto_config and any( + [ + param.kind == inspect.Parameter.POSITIONAL_ONLY + for name, param in signature.parameters.items() + if name != argspec.args[0] + ] + ): + # cls.__init__ takes positional only arguments, which + # cannot be restored via cls(**config) + auto_config = False + # Create variable to show appropriate warning in get_config. + instance._auto_config_error_args = True + + if auto_config: + # Include default values in the config. + bound_parameters.apply_defaults() + # Extract all arguments as a dictionary. + kwargs = bound_parameters.arguments + # Expand variable kwargs argument. + kwargs |= kwargs.pop(argspec.varkw, {}) + # Remove first positional argument, self. + kwargs.pop(argspec.args[0]) + # Remove argument "name", as it is provided by get_config. + kwargs.pop("name", None) + if argspec.varargs is not None: + # Varargs cannot be meaningfully converted to a dictionary. + varargs = kwargs.pop(argspec.varargs) + if len(varargs) > 0: + auto_config = False + # Store variable to show appropriate warning in get_config. + instance._auto_config_error_args = True + + # For safety, we only rely on auto-configs for a small set of + # serializable types. + supported_types = (str, int, float, bool, type(None)) + try: + flat_arg_values = tree.flatten(kwargs) + for value in flat_arg_values: + if not isinstance(value, supported_types): + auto_config = False + break + except TypeError: + auto_config = False + try: + instance._lock = False + if auto_config: + from keras.src.saving import serialization_lib + + instance._auto_config = serialization_lib.SerializableDict( + **kwargs + ) + else: + instance._auto_config = None + instance._lock = True + except RecursionError: + # Setting an instance attribute in __new__ has the potential + # to trigger an infinite recursion if a subclass overrides + # setattr in an unsafe way. + pass + return instance + + @python_utils.default + def get_config(self): + """Returns the config of the object. + + An object config is a Python dictionary (serializable) + containing the information needed to re-instantiate it. + """ + config = { + "name": self.name, + } + + if not python_utils.is_default(self.get_config): + # In this case the subclass implements get_config() + return config + + # In this case the subclass doesn't implement get_config(): + # Let's see if we can autogenerate it. + if getattr(self, "_auto_config", None) is not None: + config.update(self._auto_config.config) + init_params = inspect.signature(self.__init__).parameters + init_has_name = "name" in init_params + init_has_kwargs = ( + "kwargs" in init_params + and init_params["kwargs"].kind == inspect.Parameter.VAR_KEYWORD + ) + if not init_has_name and not init_has_kwargs: + # We can't pass `name` back to `__init__`, remove it. + config.pop("name", None) + return config + else: + example_str = """ + class CustomLayer(keras.layers.Layer): + def __init__(self, arg1, arg2, **kwargs): + super().__init__(**kwargs) + self.arg1 = arg1 + self.arg2 = arg2 + + def get_config(self): + config = super().get_config() + config.update({ + "arg1": self.arg1, + "arg2": self.arg2, + }) + return config + """ + if getattr(self, "_auto_config_error_args", False): + raise NotImplementedError( + textwrap.dedent( + f""" + Object {self.__class__.__name__} was created by passing + positional only or variadic positional arguments (e.g., + `*args`) to `__init__()`, which is not supported by the + automatic config generation. Please remove all positional + only and variadic arguments from `__init__()` + or override `get_config()` and `from_config()` to make + the object serializatble. + + Example: + + {example_str}""" + ) + ) + else: + raise NotImplementedError( + textwrap.dedent( + f""" + Object {self.__class__.__name__} was created by passing + non-serializable argument values in `__init__()`, + and therefore the object must override `get_config()` in + order to be serializable. Please implement `get_config()`. + + Example: + + {example_str}""" + ) + ) + + @classmethod + def from_config(cls, config): + """Creates an operation from its config. + + This method is the reverse of `get_config`, capable of instantiating the + same operation from the config dictionary. + + Note: If you override this method, you might receive a serialized dtype + config, which is a `dict`. You can deserialize it as follows: + + ```python + if "dtype" in config and isinstance(config["dtype"], dict): + policy = dtype_policies.deserialize(config["dtype"]) + ``` + + Args: + config: A Python dictionary, typically the output of `get_config`. + + Returns: + An operation instance. + """ + # Explicitly deserialize dtype config if needed. This enables users to + # directly interact with the instance of `DTypePolicy`. + if "dtype" in config and isinstance(config["dtype"], dict): + config = config.copy() + policy = dtype_policies.deserialize(config["dtype"]) + if ( + not isinstance(policy, dtype_policies.DTypePolicyMap) + and policy.quantization_mode is None + ): + # For backward compatibility, we use a str (`name`) for + # `DTypePolicy` + policy = policy.name + config["dtype"] = policy + try: + return cls(**config) + except Exception as e: + raise TypeError( + f"Error when deserializing class '{cls.__name__}' using " + f"config={config}.\n\nException encountered: {e}" + ) + + def __repr__(self): + return f"" + + @property + def input(self): + """Retrieves the input tensor(s) of a symbolic operation. + + Only returns the tensor(s) corresponding to the *first time* + the operation was called. + + Returns: + Input tensor or list of input tensors. + """ + return self._get_node_attribute_at_index(0, "input_tensors", "input") + + @property + def output(self): + """Retrieves the output tensor(s) of a layer. + + Only returns the tensor(s) corresponding to the *first time* + the operation was called. + + Returns: + Output tensor or list of output tensors. + """ + return self._get_node_attribute_at_index(0, "output_tensors", "output") + + def _get_node_attribute_at_index(self, node_index, attr, attr_name): + """Private utility to retrieves an attribute (e.g. inputs) from a node. + + This is used to implement the properties: + - output + - input + + Args: + node_index: Integer index of the node from which + to retrieve the attribute. + attr: Exact node attribute name. + attr_name: Human-readable attribute name, for error messages. + + Returns: + The operation's attribute `attr` at the node of index `node_index`. + """ + if not self._inbound_nodes: + raise AttributeError( + f"The layer {self.name} has never been called " + f"and thus has no defined {attr_name}." + ) + if not len(self._inbound_nodes) > node_index: + raise ValueError( + f"Asked to get {attr_name} at node " + f"{node_index}, but the operation has only " + f"{len(self._inbound_nodes)} inbound nodes." + ) + values = getattr(self._inbound_nodes[node_index], attr) + if isinstance(values, list) and len(values) == 1: + return values[0] + else: + return values + + def _obj_type(self): + return "Operation" + + # Hooks for backend layer classes + def _post_build(self): + """Can be overridden for per backend post build actions.""" + pass + + def _setattr_hook(self, name, value): + """Can be overridden for per backend post build actions.""" + return name, value + + def _post_track_variable(self, variable): + """Can be overridden for per backend post track actions.""" + pass + + def _post_untrack_variable(self, variable): + """Can be overridden for per backend post untrack actions.""" + pass diff --git a/keras/src/ops/operation_test.py b/keras/src/ops/operation_test.py new file mode 100644 index 000000000000..0a039edad841 --- /dev/null +++ b/keras/src/ops/operation_test.py @@ -0,0 +1,333 @@ +import numpy as np + +from conftest import skip_if_backend +from keras.src import backend +from keras.src import testing +from keras.src.backend.common import keras_tensor +from keras.src.ops import numpy as knp +from keras.src.ops import operation + + +class OpWithMultipleInputs(operation.Operation): + def call(self, x, y, z=None): + # `z` has to be put first due to the order of operations issue with + # torch backend. + return 3 * z + x + 2 * y + + def compute_output_spec(self, x, y, z=None): + return keras_tensor.KerasTensor(x.shape, x.dtype) + + +class OpWithMultipleOutputs(operation.Operation): + def call(self, x): + return (x, x + 1) + + def compute_output_spec(self, x): + return ( + keras_tensor.KerasTensor(x.shape, x.dtype), + keras_tensor.KerasTensor(x.shape, x.dtype), + ) + + +class OpWithCustomConstructor(operation.Operation): + def __init__(self, alpha, *, beta=1.0, name=None): + super().__init__(name=name) + self.alpha = alpha + self.beta = beta + + def call(self, x): + return self.alpha * x + self.beta + + def compute_output_spec(self, x): + return keras_tensor.KerasTensor(x.shape, x.dtype) + + +class OpWithCustomConstructorNoName(operation.Operation): + def __init__(self, alpha, beta=1.0): + super().__init__() + self.alpha = alpha + self.beta = beta + + def call(self, x): + return self.alpha * x + self.beta + + def compute_output_spec(self, x): + return keras_tensor.KerasTensor(x.shape, x.dtype) + + +class OpWithKwargsInConstructor(operation.Operation): + def __init__(self, alpha, beta=1.0, **kwargs): + super().__init__(**kwargs) + self.alpha = alpha + self.beta = beta + + def call(self, x): + return self.alpha * x + self.beta + + def compute_output_spec(self, x): + return keras_tensor.KerasTensor(x.shape, x.dtype) + + +class OpWithArgsInConstructor(operation.Operation): + def __init__(self, alpha, *args, name=None): + super().__init__(name=name) + self.alpha = alpha + + def call(self, x): + return self.alpha * x + self.beta + + def compute_output_spec(self, x): + return keras_tensor.KerasTensor(x.shape, x.dtype) + + +class OpWithCustomConstructorGetConfig(operation.Operation): + def __init__(self, alpha, *, name=None): + super().__init__(name=name) + self.alpha = alpha + + def call(self, x): + return self.alpha * x + + def compute_output_spec(self, x): + return keras_tensor.KerasTensor(x.shape, x.dtype) + + def get_config(self): + return {**super().get_config(), "alpha": self.alpha} + + +class OpWithKwargsInConstructorGetConfig(operation.Operation): + def __init__(self, alpha, **kwargs): + super().__init__(**kwargs) + self.alpha = alpha + + def call(self, x): + return self.alpha * x + + def compute_output_spec(self, x): + return keras_tensor.KerasTensor(x.shape, x.dtype) + + def get_config(self): + return {**super().get_config(), "alpha": self.alpha} + + +class OperationTest(testing.TestCase): + def test_symbolic_call(self): + x = keras_tensor.KerasTensor(shape=(2, 3), name="x") + y = keras_tensor.KerasTensor(shape=(2, 3), name="y") + z = keras_tensor.KerasTensor(shape=(2, 3), name="z") + + # Positional arguments + op = OpWithMultipleInputs(name="test_op") + self.assertEqual(op.name, "test_op") + out = op(x, y, z) + self.assertIsInstance(out, keras_tensor.KerasTensor) + self.assertEqual(out.shape, (2, 3)) + self.assertEqual(len(op._inbound_nodes), 1) + self.assertEqual(op.input, [x, y, z]) + self.assertEqual(op.output, out) + + # Keyword arguments + op = OpWithMultipleInputs(name="test_op") + out = op(x=x, y=y, z=z) + self.assertIsInstance(out, keras_tensor.KerasTensor) + self.assertEqual(out.shape, (2, 3)) + self.assertEqual(len(op._inbound_nodes), 1) + self.assertEqual(op.input, [x, y, z]) + self.assertEqual(op.output, out) + + # Mix + op = OpWithMultipleInputs(name="test_op") + out = op(x, y=y, z=z) + self.assertIsInstance(out, keras_tensor.KerasTensor) + self.assertEqual(out.shape, (2, 3)) + self.assertEqual(len(op._inbound_nodes), 1) + self.assertEqual(op.input, [x, y, z]) + self.assertEqual(op.output, out) + + # Test op reuse + prev_out = out + out = op(x, y=y, z=z) + self.assertIsInstance(out, keras_tensor.KerasTensor) + self.assertEqual(out.shape, (2, 3)) + self.assertEqual(len(op._inbound_nodes), 2) + self.assertEqual(op.output, prev_out) + + # Test multiple outputs + op = OpWithMultipleOutputs() + out = op(x) + self.assertIsInstance(out, tuple) + self.assertEqual(len(out), 2) + self.assertIsInstance(out[0], keras_tensor.KerasTensor) + self.assertIsInstance(out[1], keras_tensor.KerasTensor) + self.assertEqual(out[0].shape, (2, 3)) + self.assertEqual(out[1].shape, (2, 3)) + self.assertEqual(len(op._inbound_nodes), 1) + self.assertEqual(op.output, list(out)) + + def test_eager_call(self): + x = knp.ones((2, 3)) + y = knp.ones((2, 3)) + z = knp.ones((2, 3)) + op = OpWithMultipleInputs(name="test_op") + self.assertEqual(op.name, "test_op") + + # Positional arguments + out = op(x, y, z) + self.assertTrue(backend.is_tensor(out)) + self.assertAllClose(out, 6 * np.ones((2, 3))) + + # Keyword arguments + out = op(x=x, y=y, z=z) + self.assertTrue(backend.is_tensor(out)) + self.assertAllClose(out, 6 * np.ones((2, 3))) + + # Mixed arguments + out = op(x, y=y, z=z) + self.assertTrue(backend.is_tensor(out)) + self.assertAllClose(out, 6 * np.ones((2, 3))) + + # Test multiple outputs + op = OpWithMultipleOutputs() + out = op(x) + self.assertEqual(len(out), 2) + self.assertTrue(backend.is_tensor(out[0])) + self.assertTrue(backend.is_tensor(out[1])) + self.assertAllClose(out[0], np.ones((2, 3))) + self.assertAllClose(out[1], np.ones((2, 3)) + 1) + + def test_serialization_with_default_init_and_get_config(self): + # Explicit name passed in constructor is serialized and deserialized. + op = OpWithMultipleInputs(name="test_op") + config = op.get_config() + self.assertEqual(config, {"name": "test_op"}) + revived = OpWithMultipleInputs.from_config(config) + self.assertEqual(revived.get_config(), config) + self.assertEqual(revived.name, op.name) + + # Auto generated name is serialized and deserialized. + op = OpWithMultipleInputs() + config = op.get_config() + self.assertEqual(config, {"name": op.name}) + revived = OpWithMultipleInputs.from_config(config) + self.assertEqual(revived.get_config(), config) + self.assertEqual(revived.name, op.name) + + def test_serialization_custom_constructor_with_name_auto_config(self): + # Explicit name passed in constructor is serialized and deserialized. + op = OpWithCustomConstructor(alpha=0.2, name="test_op") + config = op.get_config() + self.assertEqual(config, {"alpha": 0.2, "beta": 1.0, "name": "test_op"}) + revived = OpWithCustomConstructor.from_config(config) + self.assertEqual(revived.get_config(), config) + self.assertEqual(revived.name, op.name) + + # Auto generated name is serialized and deserialized. + op = OpWithCustomConstructor(alpha=0.2, beta=0.0) + config = op.get_config() + self.assertEqual(config, {"alpha": 0.2, "beta": 0.0, "name": op.name}) + revived = OpWithCustomConstructor.from_config(config) + self.assertEqual(revived.get_config(), config) + self.assertEqual(revived.name, op.name) + + def test_serialization_custom_constructor_with_no_name_auto_config(self): + # Auto generated name is not serialized. + op = OpWithCustomConstructorNoName(alpha=0.2) + config = op.get_config() + self.assertEqual(config, {"alpha": 0.2, "beta": 1.0}) + revived = OpWithCustomConstructorNoName.from_config(config) + self.assertEqual(revived.get_config(), config) + + def test_serialization_custom_constructor_with_kwargs_auto_config(self): + # Explicit name passed in constructor is serialized and deserialized. + op = OpWithKwargsInConstructor(alpha=0.2, name="test_op") + config = op.get_config() + self.assertEqual(config, {"alpha": 0.2, "beta": 1.0, "name": "test_op"}) + revived = OpWithKwargsInConstructor.from_config(config) + self.assertEqual(revived.get_config(), config) + self.assertEqual(revived.name, op.name) + + # Auto generated name is serialized and deserialized. + op = OpWithKwargsInConstructor(alpha=0.2, beta=0.0) + config = op.get_config() + self.assertEqual(config, {"alpha": 0.2, "beta": 0.0, "name": op.name}) + revived = OpWithKwargsInConstructor.from_config(config) + self.assertEqual(revived.get_config(), config) + self.assertEqual(revived.name, op.name) + + def test_failing_serialization_non_serializable_auto_config( + self, + ): + class NonSerializable: + pass + + # Custom class cannot be automatically serialized. + op = OpWithCustomConstructor(alpha=NonSerializable(), name="test_op") + with self.assertRaises(NotImplementedError): + _ = op.get_config() + + def test_failing_serialization_custom_constructor_with_args_auto_config( + self, + ): + # Custom constructor with variadic args cannot be automatically + # serialized. + op = OpWithArgsInConstructor(0.2, "a", "b", "c", name="test_op") + with self.assertRaises(NotImplementedError): + _ = op.get_config() + + def test_serialization_custom_constructor_custom_get_config(self): + # Explicit name passed in constructor is serialized and deserialized. + op = OpWithCustomConstructorGetConfig(alpha=0.2, name="test_op") + config = op.get_config() + self.assertEqual(config, {"alpha": 0.2, "name": "test_op"}) + revived = OpWithCustomConstructorGetConfig.from_config(config) + self.assertEqual(revived.get_config(), config) + self.assertEqual(revived.name, op.name) + + # Auto generated name is serialized and deserialized. + op = OpWithCustomConstructorGetConfig(alpha=0.2) + config = op.get_config() + self.assertEqual(config, {"alpha": 0.2, "name": op.name}) + revived = OpWithCustomConstructorGetConfig.from_config(config) + self.assertEqual(revived.get_config(), config) + self.assertEqual(revived.name, op.name) + + def test_serialization_custom_constructor_with_kwargs_custom_get_config( + self, + ): + # Explicit name passed in constructor is serialized and deserialized. + op = OpWithKwargsInConstructorGetConfig(alpha=0.2, name="test_op") + config = op.get_config() + self.assertEqual(config, {"alpha": 0.2, "name": "test_op"}) + revived = OpWithKwargsInConstructorGetConfig.from_config(config) + self.assertEqual(revived.get_config(), config) + self.assertEqual(revived.name, op.name) + + # Auto generated name is serialized and deserialized. + op = OpWithKwargsInConstructorGetConfig(alpha=0.2) + config = op.get_config() + self.assertEqual(config, {"alpha": 0.2, "name": op.name}) + revived = OpWithKwargsInConstructorGetConfig.from_config(config) + self.assertEqual(revived.get_config(), config) + self.assertEqual(revived.name, op.name) + + @skip_if_backend( + "openvino", "Can not constant fold eltwise node by CPU plugin" + ) + def test_input_conversion(self): + x = np.ones((2,)) + y = np.ones((2,)) + z = knp.ones((2,)) # mix + if backend.backend() == "torch": + z = z.cpu() + op = OpWithMultipleInputs() + out = op(x, y, z) + self.assertTrue(backend.is_tensor(out)) + self.assertAllClose(out, 6 * np.ones((2,))) + + def test_valid_naming(self): + OpWithMultipleOutputs(name="test_op") + + with self.assertRaisesRegex( + ValueError, "must be a string and cannot contain character `/`." + ): + OpWithMultipleOutputs(name="test/op") diff --git a/keras/src/ops/operation_utils.py b/keras/src/ops/operation_utils.py new file mode 100644 index 000000000000..b1ac2621de0a --- /dev/null +++ b/keras/src/ops/operation_utils.py @@ -0,0 +1,425 @@ +import math + +import numpy as np + +from keras.src import tree +from keras.src.api_export import keras_export +from keras.src.backend.common.backend_utils import canonicalize_axis +from keras.src.backend.common.backend_utils import to_tuple_or_list + + +def broadcast_shapes(shape1, shape2): + """Broadcast input shapes to a unified shape. + + Convert to list for mutability. + + Args: + shape1: A tuple or list of integers. + shape2: A tuple or list of integers. + + Returns: + output_shape (list of integers or `None`): The broadcasted shape. + + Example: + >>> broadcast_shapes((5, 3), (1, 3)) + [5, 3] + """ + shape1 = list(shape1) + shape2 = list(shape2) + origin_shape1 = shape1 + origin_shape2 = shape2 + + if len(shape1) > len(shape2): + shape2 = [1] * (len(shape1) - len(shape2)) + shape2 + if len(shape1) < len(shape2): + shape1 = [1] * (len(shape2) - len(shape1)) + shape1 + output_shape = list(shape1) + for i in range(len(shape1)): + if shape1[i] == 1: + output_shape[i] = shape2[i] + elif shape1[i] is None: + output_shape[i] = None if shape2[i] == 1 else shape2[i] + else: + if shape2[i] == 1 or shape2[i] is None or shape2[i] == shape1[i]: + output_shape[i] = shape1[i] + else: + raise ValueError( + "Cannot broadcast shape, the failure dim has value " + f"{shape1[i]}, which cannot be broadcasted to {shape2[i]}. " + f"Input shapes are: {origin_shape1} and {origin_shape2}." + ) + + return output_shape + + +def compute_expand_dims_output_shape(input_shape, axis): + """Compute the output shape for the `expand_dims` operation. + + Args: + input_shape: Input shape. + axis: int or sequence of ints for the axis to expand. + + Returns: + Tuple of ints: The output shape after the `expand_dims` operation. + """ + input_shape = list(input_shape) + if axis is None: + axis = len(input_shape) + axis = to_tuple_or_list(axis) + out_ndim = len(axis) + len(input_shape) + axis = [canonicalize_axis(a, out_ndim) for a in axis] + shape_iter = iter(input_shape) + new_shape = [ + 1 if ax in axis else next(shape_iter) for ax in range(out_ndim) + ] + return tuple(new_shape) + + +def compute_pooling_output_shape( + input_shape, + pool_size, + strides, + padding="valid", + data_format="channels_last", +): + """Computes the output shape of pooling operations. + + Args: + input_shape: Input shape. Must be a tuple of integers. + pool_size: Size of the pooling operation. Must be a tuple of integers. + strides: Stride of the pooling operation. Must be a tuple of integers. + Defaults to `pool_size`. + padding: Padding method. Available methods are `"valid"` or `"same"`. + Defaults to `"valid"`. + data_format: String, either `"channels_last"` or `"channels_first"`. + The ordering of the dimensions in the inputs. `"channels_last"` + corresponds to inputs with shape `(batch, height, width, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch, channels, height, weight)`. Defaults to `"channels_last"`. + + Returns: + Tuple of ints: The output shape of the pooling operation. + + Examples: + + # Basic usage with square pooling on a single image + >>> compute_pooling_output_shape((1, 4, 4, 1), (2, 2)) + (1, 2, 2, 1) + + # Strided pooling on a single image with strides different from pool_size + >>> compute_pooling_output_shape((1, 4, 4, 1), (2, 2), strides=(1, 1)) + (1, 3, 3, 1) + + # Pooling on a batch of images + >>> compute_pooling_output_shape((32, 4, 4, 3), (2, 2)) + (32, 2, 2, 3) + """ + strides = pool_size if strides is None else strides + input_shape_origin = list(input_shape) + input_shape = np.array(input_shape) + if data_format == "channels_last": + spatial_shape = input_shape[1:-1] + else: + spatial_shape = input_shape[2:] + none_dims = [] + for i in range(len(spatial_shape)): + if spatial_shape[i] is None: + # Set `None` shape to a manual value so that we can run numpy + # computation on `spatial_shape`. + spatial_shape[i] = -1 + none_dims.append(i) + pool_size = np.array(pool_size) + if padding == "valid": + output_spatial_shape = ( + np.floor((spatial_shape - pool_size) / strides) + 1 + ) + for i in range(len(output_spatial_shape)): + if i not in none_dims and output_spatial_shape[i] < 0: + raise ValueError( + "Computed output size would be negative. Received: " + f"`inputs.shape={input_shape}` and `pool_size={pool_size}`." + ) + elif padding == "same": + output_spatial_shape = np.floor((spatial_shape - 1) / strides) + 1 + else: + raise ValueError( + "Argument `padding` must be either 'valid' or 'same'. Received: " + f"padding={padding}" + ) + output_spatial_shape = [int(i) for i in output_spatial_shape] + for i in none_dims: + output_spatial_shape[i] = None + output_spatial_shape = tuple(output_spatial_shape) + if data_format == "channels_last": + output_shape = ( + (input_shape_origin[0],) + + output_spatial_shape + + (input_shape_origin[-1],) + ) + else: + output_shape = ( + input_shape_origin[0], + input_shape_origin[1], + ) + output_spatial_shape + return output_shape + + +def compute_conv_output_shape( + input_shape, + filters, + kernel_size, + strides=1, + padding="valid", + data_format="channels_last", + dilation_rate=1, +): + """Compute the output shape of conv ops.""" + if data_format == "channels_last": + spatial_shape = input_shape[1:-1] + kernel_shape = kernel_size + (input_shape[-1], filters) + else: + spatial_shape = input_shape[2:] + kernel_shape = kernel_size + (input_shape[1], filters) + if len(kernel_shape) != len(input_shape): + raise ValueError( + "Kernel shape must have the same length as input, but received " + f"kernel of shape {kernel_shape} and " + f"input of shape {input_shape}." + ) + if isinstance(dilation_rate, int): + dilation_rate = (dilation_rate,) * len(spatial_shape) + if isinstance(strides, int): + strides = (strides,) * len(spatial_shape) + if len(dilation_rate) != len(spatial_shape): + raise ValueError( + "Dilation must be None, scalar or tuple/list of length of " + "inputs' spatial shape, but received " + f"`dilation_rate={dilation_rate}` and " + f"input of shape {input_shape}." + ) + none_dims = [] + spatial_shape = np.array(spatial_shape) + for i in range(len(spatial_shape)): + if spatial_shape[i] is None: + # Set `None` shape to a manual value so that we can run numpy + # computation on `spatial_shape`. + spatial_shape[i] = -1 + none_dims.append(i) + + kernel_spatial_shape = np.array(kernel_shape[:-2]) + dilation_rate = np.array(dilation_rate) + if padding == "valid": + output_spatial_shape = ( + np.floor( + (spatial_shape - dilation_rate * (kernel_spatial_shape - 1) - 1) + / strides + ) + + 1 + ) + for i in range(len(output_spatial_shape)): + if i not in none_dims and output_spatial_shape[i] < 0: + raise ValueError( + "Computed output size would be negative. Received " + f"`inputs shape={input_shape}`, " + f"`kernel shape={kernel_shape}`, " + f"`dilation_rate={dilation_rate}`." + ) + elif padding == "same" or padding == "causal": + output_spatial_shape = np.floor((spatial_shape - 1) / strides) + 1 + else: + raise ValueError( + "`padding` must be either `'valid'` or `'same'`. Received " + f"{padding}." + ) + output_spatial_shape = [int(i) for i in output_spatial_shape] + for i in none_dims: + output_spatial_shape[i] = None + output_spatial_shape = tuple(output_spatial_shape) + if data_format == "channels_last": + output_shape = ( + (input_shape[0],) + output_spatial_shape + (kernel_shape[-1],) + ) + else: + output_shape = (input_shape[0], kernel_shape[-1]) + output_spatial_shape + return output_shape + + +def compute_matmul_output_shape(shape1, shape2): + """Compute the output shape of a `matmul` operation. + + Args: + shape1: Shape of the left operand. + shape2: Shape of the right operand. + + Returns: + Tuple of ints: The output shape for the `matmul` operation. + """ + if len(shape1) == 1: + shape1 = (1, shape1[0]) + if len(shape2) == 1: + shape2 = (shape2[0], 1) + if ( + shape1[-1] is not None + and shape2[-2] is not None + and shape1[-1] != shape2[-2] + ): + raise ValueError( + "Inner dimensions (`x1.shape[-1]` and `x2.shape[-2]`) must be " + f"equal, but received `x1.shape={shape1}` and " + f"`x2.shape={shape2}`." + ) + + leading_shape = broadcast_shapes(shape1[:-2], shape2[:-2]) + last_2_dims_shape = [shape1[-2], shape2[-1]] + output_shape = leading_shape + last_2_dims_shape + if len(shape1) == 1: + del output_shape[-2] + if len(shape2) == 1: + del output_shape[-1] + return tuple(output_shape) + + +def compute_reshape_output_shape(input_shape, newshape, newshape_arg_name): + """Converts `-1` in `newshape` to either an actual dimension or `None`. + + This utility does not special case the 0th dimension (batch size). + """ + unknown_dim_count = newshape.count(-1) + if unknown_dim_count > 1: + raise ValueError( + "There must be at most one unknown dimension (-1) in " + f"{newshape_arg_name}. Received: {newshape_arg_name}={newshape}." + ) + + # If there is a None in input_shape, we can't infer what the -1 is + if None in input_shape: + return tuple(dim if dim != -1 else None for dim in newshape) + + input_size = math.prod(input_shape) + # If the `newshape` is fully defined, return it + if unknown_dim_count == 0: + if input_size != math.prod(newshape): + raise ValueError( + "The total size of the tensor must be unchanged. Received: " + f"input_shape={input_shape}, {newshape_arg_name}={newshape}" + ) + return newshape + + # We have one -1 in `newshape`, compute the actual value + known_output_size = 1 + unknown_dim_index = None + for index, dim in enumerate(newshape): + if dim == -1: + unknown_dim_index = index + else: + known_output_size *= dim + + if known_output_size == 0 or input_size % known_output_size != 0: + raise ValueError( + "The total size of the tensor must be unchanged, however, the " + "input size cannot by divided by the specified dimensions in " + f"{newshape_arg_name}. Received: input_shape={input_shape}, " + f"{newshape_arg_name}={newshape}" + ) + + output_shape = list(newshape) + output_shape[unknown_dim_index] = input_size // known_output_size + return tuple(output_shape) + + +def compute_transpose_output_shape(input_shape, axes): + """Compute the output shape for the `transpose` operation. + + Args: + input_shape: Input shape. + axes: Permutation of the dimensions for the `transpose` operation. + + Returns: + Tuple of ints: The output shape after the `transpose` operation. + """ + input_shape = list(input_shape) + if axes is None: + return tuple(input_shape[::-1]) + + if len(axes) != len(input_shape): + raise ValueError( + "axis must be a list of the same length as the input shape, " + f"expected {len(input_shape)}, but received {len(axes)}." + ) + return tuple(input_shape[ax] for ax in axes) + + +def compute_take_along_axis_output_shape(input_shape, indices_shape, axis): + input_shape = list(input_shape) + indices_shape = list(indices_shape) + if axis is None: + input_shape = ( + [None] if None in input_shape else [int(np.prod(input_shape))] + ) + + if len(input_shape) != len(indices_shape): + raise ValueError( + "`x` and `indices` must have the same number of dimensions, " + f"but receive shape {input_shape} and {indices_shape}." + ) + + input_shape[axis] = indices_shape[axis] + output_shape = broadcast_shapes(input_shape, indices_shape) + return output_shape + + +def reduce_shape(shape, axis=None, keepdims=False): + shape = list(shape) + if axis is None: + if keepdims: + return tuple([1 for _ in shape]) + else: + return tuple([]) + elif isinstance(axis, int): + axis = (axis,) + + axis = tuple(canonicalize_axis(a, len(shape)) for a in axis) + + if keepdims: + for ax in axis: + shape[ax] = 1 + return tuple(shape) + else: + for ax in sorted(axis, reverse=True): + del shape[ax] + return tuple(shape) + + +@keras_export("keras.utils.get_source_inputs") +def get_source_inputs(tensor): + """Returns the list of input tensors necessary to compute `tensor`. + + Output will always be a list of tensors + (potentially with 1 element). + + Args: + tensor: The tensor to start from. + + Returns: + List of input tensors. + """ + if not hasattr(tensor, "_keras_history"): + return tensor + + operation, node_index, _ = tensor._keras_history + if not operation or not operation._inbound_nodes: + return [tensor] + else: + node = operation._inbound_nodes[node_index] + if node.is_input: + # Reached input node, stop recursion. + return tree.flatten(node.output_tensors) + else: + source_tensors = [] + for tensor in node.input_tensors: + previous_sources = get_source_inputs(tensor) + # Avoid input redundancy. + for x in previous_sources: + if all(x is not t for t in source_tensors): + source_tensors.append(x) + return source_tensors diff --git a/keras/src/ops/operation_utils_test.py b/keras/src/ops/operation_utils_test.py new file mode 100644 index 000000000000..2ac2e5b0fa30 --- /dev/null +++ b/keras/src/ops/operation_utils_test.py @@ -0,0 +1,210 @@ +from keras.src import backend +from keras.src import ops +from keras.src import testing +from keras.src.layers.core import input_layer +from keras.src.ops import operation_utils + + +class OperationUtilsTest(testing.TestCase): + def test_get_source_inputs(self): + x1 = backend.KerasTensor(shape=(2,)) + x2 = backend.KerasTensor(shape=(2,)) + x = x1 + x2 + x += 2 + x = ops.square(x) + self.assertEqual(operation_utils.get_source_inputs(x), [x1, x2]) + + def test_get_source_inputs_return_input_tensor(self): + inputs = input_layer.Input(shape=(10,)) + self.assertIs(operation_utils.get_source_inputs(inputs)[0], inputs) + + def test_compute_expand_dims_output_shape(self): + input_shape = (2, 3, 4) + axis = -1 + output_shape = operation_utils.compute_expand_dims_output_shape( + input_shape, axis + ) + expected_output_shape = (2, 3, 4, 1) + self.assertEqual(output_shape, expected_output_shape) + + input_shape = (2, 3, 4) + axis = (1, -1) + output_shape = operation_utils.compute_expand_dims_output_shape( + input_shape, axis + ) + expected_output_shape = (2, 1, 3, 4, 1) + self.assertEqual(output_shape, expected_output_shape) + + def test_compute_pooling_output_shape(self): + input_shape = (1, 4, 4, 1) + pool_size = (2, 2) + strides = (2, 2) + output_shape = operation_utils.compute_pooling_output_shape( + input_shape, pool_size, strides + ) + expected_output_shape = (1, 2, 2, 1) + self.assertEqual(output_shape, expected_output_shape) + + def test_compute_pooling_output_shape_with_none(self): + input_shape = (None, 4, 4, 1) + pool_size = (2, 2) + strides = (2, 2) + output_shape = operation_utils.compute_pooling_output_shape( + input_shape, pool_size, strides + ) + expected_output_shape = (None, 2, 2, 1) + self.assertEqual(output_shape, expected_output_shape) + + def test_compute_pooling_output_shape_valid_padding(self): + input_shape = (1, 4, 4, 1) + pool_size = (2, 2) + strides = (2, 2) + output_shape = operation_utils.compute_pooling_output_shape( + input_shape, pool_size, strides, padding="valid" + ) + self.assertEqual(output_shape, (1, 2, 2, 1)) + + def test_compute_pooling_output_shape_channels_last(self): + input_shape = (1, 4, 4, 3) + pool_size = (2, 2) + strides = (2, 2) + output_shape = operation_utils.compute_pooling_output_shape( + input_shape, + pool_size, + strides, + padding="valid", + data_format="channels_last", + ) + self.assertEqual(output_shape, (1, 2, 2, 3)) + + def test_compute_pooling_output_shape_same_padding_stride1(self): + input_shape = (1, 4, 4, 3) + pool_size = (2, 2) + strides = (1, 1) + output_shape = operation_utils.compute_pooling_output_shape( + input_shape, + pool_size, + strides, + padding="same", + data_format="channels_last", + ) + self.assertEqual(output_shape, (1, 4, 4, 3)) + + def test_compute_conv_output_shape(self): + input_shape = (1, 4, 4, 1) + filters = 1 + kernel_size = (3, 3) + strides = (1, 1) + output_shape = operation_utils.compute_conv_output_shape( + input_shape, filters, kernel_size, strides + ) + expected_output_shape = (1, 2, 2, 1) + self.assertEqual(output_shape, expected_output_shape) + + def test_compute_conv_output_shape_with_none(self): + input_shape = (None, 4, 4, 1) + kernel_size = (3, 3) + filters = 1 + strides = (1, 1) + output_shape = operation_utils.compute_conv_output_shape( + input_shape, filters, kernel_size, strides + ) + expected_output_shape = (None, 2, 2, 1) + self.assertEqual(output_shape, expected_output_shape) + + def test_compute_conv_output_shape_valid_padding(self): + input_shape = (1, 4, 4, 1) + kernel_size = (3, 3) + filters = 1 + strides = (2, 2) + output_shape = operation_utils.compute_conv_output_shape( + input_shape, filters, kernel_size, strides, padding="valid" + ) + self.assertEqual(output_shape, (1, 1, 1, 1)) + + def test_compute_conv_output_shape_channels_last(self): + input_shape = (1, 4, 4, 3) + kernel_size = (3, 3) + filters = 3 + strides = (2, 2) + output_shape = operation_utils.compute_conv_output_shape( + input_shape, + filters, + kernel_size, + strides, + padding="valid", + data_format="channels_last", + ) + self.assertEqual(output_shape, (1, 1, 1, 3)) + + def test_compute_conv_output_shape_same_padding_stride1(self): + input_shape = (1, 4, 4, 3) + kernel_size = (3, 3) + filters = 3 + strides = (1, 1) + output_shape = operation_utils.compute_conv_output_shape( + input_shape, + filters, + kernel_size, + strides, + padding="same", + data_format="channels_last", + ) + self.assertEqual(output_shape, (1, 4, 4, 3)) + + def test_compute_reshape_output_shape(self): + input_shape = (1, 4, 4, 1) + target_shape = (16, 1) + output_shape = operation_utils.compute_reshape_output_shape( + input_shape, newshape=target_shape, newshape_arg_name="New shape" + ) + self.assertEqual(output_shape, target_shape) + + def test_reduce_shape_no_axes_no_keepdims(self): + input_shape = (1, 4, 4, 1) + output_shape = operation_utils.reduce_shape(input_shape) + expected_output_shape = () + self.assertEqual(output_shape, expected_output_shape) + + def test_reduce_shape_no_axes_with_keepdims(self): + input_shape = (1, 4, 4, 1) + output_shape = operation_utils.reduce_shape(input_shape, keepdims=True) + expected_output_shape = (1, 1, 1, 1) + self.assertEqual(output_shape, expected_output_shape) + + def test_reduce_shape_single_axis_no_keepdims(self): + input_shape = (1, 4, 4, 1) + axes = [1] + output_shape = operation_utils.reduce_shape(input_shape, axes) + expected_output_shape = (1, 4, 1) + self.assertEqual(output_shape, expected_output_shape) + + def test_reduce_shape_single_axis_with_keepdims(self): + input_shape = (1, 4, 4, 1) + axes = [1] + output_shape = operation_utils.reduce_shape( + input_shape, axes, keepdims=True + ) + expected_output_shape = (1, 1, 4, 1) + self.assertEqual(output_shape, expected_output_shape) + + def test_reduce_shape_multiple_axes_no_keepdims(self): + input_shape = (1, 4, 4, 1) + axes = [1, 2] + output_shape = operation_utils.reduce_shape(input_shape, axes) + expected_output_shape = (1, 1) + self.assertEqual(output_shape, expected_output_shape) + + def test_reduce_shape_out_of_order_axes_no_keepdims(self): + input_shape = (1, 4, 4, 1) + axes = [2, 1] + output_shape = operation_utils.reduce_shape(input_shape, axes) + expected_output_shape = (1, 1) + self.assertEqual(output_shape, expected_output_shape) + + def test_reduce_shape_negative_axes_no_keepdims(self): + input_shape = (1, 4, 4, 1) + axes = [-2, -3] + output_shape = operation_utils.reduce_shape(input_shape, axes) + expected_output_shape = (1, 1) + self.assertEqual(output_shape, expected_output_shape) diff --git a/keras/src/ops/ops_test.py b/keras/src/ops/ops_test.py new file mode 100644 index 000000000000..724dd573400b --- /dev/null +++ b/keras/src/ops/ops_test.py @@ -0,0 +1,278 @@ +import inspect + +from absl.testing import parameterized + +try: + from keras.api import ops as api_ops_root +except ImportError: + from keras import ops as api_ops_root + +from keras.src import backend +from keras.src import ops +from keras.src import testing +from keras.src.ops.operation import Operation +from keras.src.testing.test_utils import named_product +from keras.src.utils.naming import to_snake_case + +OPS_MODULES = ("core", "image", "linalg", "math", "nn", "numpy") + +SELF_PARAMETER = inspect.Parameter( + "self", inspect.Parameter.POSITIONAL_OR_KEYWORD +) +NAME_PARAMETER = inspect.Parameter( + "name", inspect.Parameter.KEYWORD_ONLY, default=None +) + +# Parameters with these names are known to always be static (non-tensors). +STATIC_PARAMETER_NAMES = frozenset( + {"axis", "axes", "dtype", "shape", "newshape", "sparse", "ragged"} +) + + +def op_functions_and_classes(ops_module): + """Enumerate pairs of op function and op classes in a module. + + Will return for instance `(ExpandDims, expand_dims)`, `(Sum, sum)`, ... + + Args: + ops_module: the module to explore. + + Returns: + iterable returning tuples with function and class pairs. + """ + # Go through all symbols. + for op_class_name in dir(ops_module): + op_class = getattr(ops_module, op_class_name) + # Find the ones that are classes that extend `Operation`. + if isinstance(op_class, type) and Operation in op_class.__mro__: + # Infer what the corresponding op function name should be. + op_function_name = to_snake_case(op_class_name) + # With some exceptions. + op_function_name = { + "batch_norm": "batch_normalization", + "rms_norm": "rms_normalization", + "search_sorted": "searchsorted", + }.get(op_function_name, op_function_name) + # Check if that function exist. Some classes are abstract super + # classes for multiple operations and should be ignored. + op_function = getattr(ops_module, op_function_name, None) + if op_function is not None: + # We have a pair, return it. + yield op_function, op_class + + +class OperationTest(testing.TestCase): + @parameterized.named_parameters(named_product(module_name=OPS_MODULES)) + def test_class_function_consistency(self, module_name): + ops_module = getattr(ops, module_name) + if module_name in ("core", "math"): + # `core` and `math` are not exported as their own module. + api_ops_module = None + else: + api_ops_module = getattr(api_ops_root, module_name) + + for op_function, op_class in op_functions_and_classes(ops_module): + name = op_function.__name__ + + # ==== Check exports ==== + # - op should be exported as e.g. `keras.ops.numpy.sum` + # - op should also be exported as e.g. `keras.ops.sum` + + if module_name != "image": + # `image` ops are not exported at the top-level. + self.assertIsNotNone( + getattr(api_ops_root, name, None), + f"Not exported as `keras.ops.{name}`", + ) + if api_ops_module is not None: + # `core` and `math` are not exported as their own module. + self.assertIsNotNone( + getattr(api_ops_module, name, None), + f"Not exported as `keras.ops.{module_name}.{name}`", + ) + + # ==== Check handling of name in __init__ ==== + # - op class `__init__` should have a `name` parameter at the end, + # which should be keyword only and with a default value of `None` + # - op class `__init__` should call `super().__init__(name=name)` + + if op_class.__init__ is Operation.__init__: + # `name` is not keyword only in `Operation`, use this instead. + class_init_signature = inspect.Signature( + [SELF_PARAMETER, NAME_PARAMETER] + ) + else: + class_init_signature = inspect.signature(op_class.__init__) + + # Check call to super. + self.assertContainsSubsequence( + inspect.getsource(op_class.__init__), + "super().__init__(name=name)", + f"`{op_class.__name__}.__init__` is not calling " + "`super().__init__(name=name)`", + ) + + static_parameters = list(class_init_signature.parameters.values()) + # Remove `self`. + static_parameters = static_parameters[1:] + name_index = -1 + if static_parameters[-1].kind == inspect.Parameter.VAR_KEYWORD: + # When there is a `**kwargs`, `name` appears before. + name_index = -2 + # Verify `name` parameter is as expected. + self.assertEqual( + static_parameters[name_index], + NAME_PARAMETER, + f"The last parameter of `{op_class.__name__}.__init__` " + "should be `name`, should be a keyword only, and should " + "have a default value of `None`", + ) + # Remove `name`, it's not part of the op signature. + static_parameters.pop(name_index) + + # ==== Check static parameters ==== + # Static parameters are declared in the class' `__init__`. + # Dynamic parameters are declared in the class' `call` method. + # - they should all appear in the op signature with the same name + # - they should have the same default value + # - they should appear in the same order and usually with the + # dynamic parameters first, and the static parameters last. + + dynamic_parameters = list( + inspect.signature(op_class.call).parameters.values() + )[1:] # Remove self + + op_signature = inspect.signature(op_function) + + for p in dynamic_parameters + static_parameters: + # Check the same name appears in the op signature + self.assertIn( + p.name, + op_signature.parameters, + f"Op function `{name}` is missing a parameter that is in " + f"op class `{op_class.__name__}`", + ) + # Check default values are the same + self.assertEqual( + p.default, + op_signature.parameters[p.name].default, + f"Default mismatch for parameter `{p.name}` between op " + f"function `{name}` and op class `{op_class.__name__}`", + ) + + dynamic_parameter_names = [p.name for p in dynamic_parameters] + static_parameter_names = [p.name for p in static_parameters] + + # Check for obvious mistakes in parameters that were made dynamic + # but should be static. + for p in dynamic_parameters: + self.assertNotIn( + p.name, + STATIC_PARAMETER_NAMES, + f"`{p.name}` should not be a dynamic parameter in op class " + f"`{op_class.__name__}` based on its name.", + ) + self.assertNotIsInstance( + p.default, + (bool, str), + f"`{p.name}` should not be a dynamic parameter in op class " + f"`{op_class.__name__}` based on default `{p.default}`.", + ) + + # Check order of parameters. + if name in ( + "fori_loop", + "vectorized_map", + "while_loop", + "batch_normalization", + "dot_product_attention", + "average", + "einsum", + "full", + "pad", + ): + # Loose case: + # order of of parameters is preserved but they are interspersed. + op_dynamic_parameter_names = [ + name + for name in op_signature.parameters.keys() + if name in dynamic_parameter_names + ] + self.assertEqual( + op_dynamic_parameter_names, + dynamic_parameter_names, + "Inconsistent dynamic parameter order for op " + f"function `{name}` and op class `{op_class.__name__}`", + ) + op_static_parameter_names = [ + name + for name in op_signature.parameters.keys() + if name in static_parameter_names + ] + self.assertEqual( + op_static_parameter_names, + static_parameter_names, + "Inconsistent static parameter order for op " + f"function `{name}` and op class `{op_class.__name__}`", + ) + else: + # Strict case: + # dynamic parameters first and static parameters at the end. + self.assertEqual( + list(op_signature.parameters.keys()), + dynamic_parameter_names + static_parameter_names, + "Inconsistent static parameter position for op " + f"function `{name}` and op class `{op_class.__name__}`", + ) + + # ==== Check compute_output_spec is implement ==== + # - op class should override Operation's `compute_output_spec` + self.assertTrue( + hasattr(op_class, "compute_output_spec") + and op_class.compute_output_spec + is not Operation.compute_output_spec, + f"Op class `{op_class.__name__}` should override " + "`compute_output_spec`", + ) + + @parameterized.named_parameters(named_product(module_name=OPS_MODULES)) + def test_backend_consistency(self, module_name): + ops_module = getattr(ops, module_name) + backend_ops_module = getattr(backend, module_name) + + for op_function, _ in op_functions_and_classes(ops_module): + name = op_function.__name__ + + if hasattr(ops_module, f"_{name}"): + # For an op function `foo`, if there is a function named `_foo`, + # that means we have a backend independent implementation. + continue + if name in ("view_as_complex", "view_as_real", "get_item"): + # These ops have an inlined backend independent implementation. + continue + + # ==== Check backend implementation ==== + # - op should have an implementation in every backend + # - op implementation should have the same signature (same + # parameters, same order, same defaults) + + backend_op_function = getattr(backend_ops_module, name, None) + + if backend.backend() == "openvino" and backend_op_function is None: + # Openvino is still missing a number of ops. + continue + + self.assertIsNotNone(backend_op_function, f"Missing op `{name}`") + + if name == "multi_hot": + # multi_hot has code to massage the input parameters before + # calling the backend implementation, so the signature is + # different on purpose. + continue + + # Signature should match in every way. + self.assertEqual( + inspect.signature(backend_op_function), + inspect.signature(op_function), + f"Signature mismatch for `{name}`", + ) diff --git a/keras/src/ops/symbolic_arguments.py b/keras/src/ops/symbolic_arguments.py new file mode 100644 index 000000000000..c71e04e7b145 --- /dev/null +++ b/keras/src/ops/symbolic_arguments.py @@ -0,0 +1,46 @@ +from keras.src import tree +from keras.src.backend import KerasTensor + + +class SymbolicArguments: + def __init__(self, *args, **kwargs): + self.args = tree.map_structure(lambda x: x, args) + self.kwargs = tree.map_structure(lambda x: x, kwargs) + self._flat_arguments = tree.flatten((self.args, self.kwargs)) + + # Used to avoid expensive `tree` operations in the most common case. + if ( + not self.kwargs + and len(self.args) == 1 + and isinstance(self.args[0], KerasTensor) + ): + self._single_positional_tensor = self.args[0] + else: + self._single_positional_tensor = None + + self.keras_tensors = [] + for arg in self._flat_arguments: + if isinstance(arg, KerasTensor): + self.keras_tensors.append(arg) + + def convert(self, conversion_fn): + args = tree.map_structure(conversion_fn, self.args) + kwargs = tree.map_structure(conversion_fn, self.kwargs) + return args, kwargs + + def fill_in(self, tensor_dict): + """Maps KerasTensors to computed values using `tensor_dict`. + + `tensor_dict` maps `KerasTensor` instances to their current values. + """ + if self._single_positional_tensor is not None: + # Performance optimization for most common case. + # Approx. 70x faster. + return (tensor_dict[id(self._single_positional_tensor)],), {} + + def switch_fn(x): + if isinstance(x, KerasTensor): + return tensor_dict.get(id(x), None) + return x + + return self.convert(switch_fn) diff --git a/keras/src/ops/symbolic_arguments_test.py b/keras/src/ops/symbolic_arguments_test.py new file mode 100644 index 000000000000..034e56779c2c --- /dev/null +++ b/keras/src/ops/symbolic_arguments_test.py @@ -0,0 +1,121 @@ +from keras.src import testing +from keras.src import tree +from keras.src.backend import KerasTensor +from keras.src.ops.symbolic_arguments import SymbolicArguments + + +class SymbolicArgumentsTest(testing.TestCase): + # Testing multiple args and empty kwargs + def test_args(self): + shape = (2, 3, 4) + a = KerasTensor(shape=shape) + b = KerasTensor(shape=shape) + args = SymbolicArguments( + ( + a, + b, + ), + {}, + ) + + self.assertEqual(args.keras_tensors, [a, b]) + self.assertEqual(args._flat_arguments, [a, b]) + self.assertEqual(args._single_positional_tensor, None) + + # Testing single arg and single position tensor + def test_args_single_arg(self): + shape = (2, 3, 4) + a = KerasTensor(shape=shape) + args = SymbolicArguments((a)) + + self.assertEqual(args.keras_tensors, [a]) + self.assertEqual(args._flat_arguments, [a]) + self.assertEqual(len(args.kwargs), 0) + self.assertEqual(isinstance(args.args[0], KerasTensor), True) + self.assertEqual(args._single_positional_tensor, a) + + # Testing kwargs + def test_kwargs(self): + shape = (2, 3, 4) + a = KerasTensor(shape=shape) + b = KerasTensor(shape=shape) + c = KerasTensor(shape=shape) + args = SymbolicArguments( + ( + a, + b, + ), + {1: c}, + ) + + self.assertEqual(args.keras_tensors, [a, b, c]) + self.assertEqual(args._flat_arguments, [a, b, c]) + self.assertEqual(args._single_positional_tensor, None) + + # Testing conversion function with args and kwargs + def test_conversion_fn(self): + shape = (2, 3, 4) + a = KerasTensor(shape=shape) + b = KerasTensor(shape=shape) + c = KerasTensor(shape=shape) + sym_args = SymbolicArguments( + ( + a, + b, + ), + {1: c}, + ) + + (value, _) = sym_args.convert(lambda x: x**2) + args1 = value[0][0] + + self.assertIsInstance(args1, KerasTensor) + + mapped_value = tree.map_structure(lambda x: x**2, a) + self.assertEqual(mapped_value.shape, args1.shape) + self.assertEqual(mapped_value.dtype, args1.dtype) + + # Testing fill in function with single args only + def test_fill_in_single_arg(self): + shape = (2, 3, 4) + a = KerasTensor(shape=shape) + + tensor_dict = {id(a): 3} + sym_args = SymbolicArguments((a)) + + # Call the method to be tested + result, _ = sym_args.fill_in(tensor_dict) + + self.assertEqual(result, (3,)) + + # Testing fill in function with multiple args + def test_fill_in_multiple_arg(self): + shape = (2, 3, 4) + a = KerasTensor(shape=shape) + b = KerasTensor(shape=shape) + + tensor_dict = {id(b): 2} + sym_args = SymbolicArguments((a, b)) + + # Call the method to be tested + result, _ = sym_args.fill_in(tensor_dict) + self.assertEqual(result, ((None, 2),)) + + # Testing fill in function for args and kwargs + def test_fill_in(self): + shape1 = (2, 3, 4) + shape2 = (3, 2, 4) + a = KerasTensor(shape=shape1) + b = KerasTensor(shape=shape2) + c = KerasTensor(shape=shape2) + dictionary = {id(a): 3, id(c): 2} + sym_args = SymbolicArguments( + ( + a, + b, + ), + {"1": c}, + ) + + (values, _) = sym_args.fill_in(dictionary) + self.assertEqual(values, ((3, None), {"1": 2})) diff --git a/keras/src/optimizers/__init__.py b/keras/src/optimizers/__init__.py new file mode 100644 index 000000000000..4db5319793ea --- /dev/null +++ b/keras/src/optimizers/__init__.py @@ -0,0 +1,122 @@ +from keras.src.api_export import keras_export +from keras.src.optimizers.adadelta import Adadelta +from keras.src.optimizers.adafactor import Adafactor +from keras.src.optimizers.adagrad import Adagrad +from keras.src.optimizers.adam import Adam +from keras.src.optimizers.adamax import Adamax +from keras.src.optimizers.adamw import AdamW +from keras.src.optimizers.ftrl import Ftrl +from keras.src.optimizers.lion import Lion +from keras.src.optimizers.loss_scale_optimizer import LossScaleOptimizer +from keras.src.optimizers.muon import Muon +from keras.src.optimizers.nadam import Nadam +from keras.src.optimizers.optimizer import Optimizer +from keras.src.optimizers.rmsprop import RMSprop +from keras.src.optimizers.sgd import SGD +from keras.src.saving import serialization_lib + +ALL_OBJECTS = { + Optimizer, + Adam, + SGD, + RMSprop, + Adadelta, + AdamW, + Adagrad, + Adamax, + Adafactor, + Nadam, + Ftrl, + Lion, + LossScaleOptimizer, +} +ALL_OBJECTS_DICT = {cls.__name__.lower(): cls for cls in ALL_OBJECTS} + + +@keras_export("keras.optimizers.serialize") +def serialize(optimizer): + """Returns the optimizer configuration as a Python dict. + + Args: + optimizer: An `Optimizer` instance to serialize. + + Returns: + Python dict which contains the configuration of the optimizer. + """ + return serialization_lib.serialize_keras_object(optimizer) + + +@keras_export("keras.optimizers.deserialize") +def deserialize(config, custom_objects=None): + """Returns a Keras optimizer object via its configuration. + + Args: + config: Optimizer configuration dictionary. + custom_objects: Optional dictionary mapping names (strings) to custom + objects (classes and functions) to be considered during + deserialization. + + Returns: + A Keras Optimizer instance. + """ + # Make deserialization case-insensitive for built-in optimizers. + if config["class_name"].lower() in ALL_OBJECTS_DICT: + config["class_name"] = config["class_name"].lower() + + return serialization_lib.deserialize_keras_object( + config, + module_objects=ALL_OBJECTS_DICT, + custom_objects=custom_objects, + ) + + +@keras_export("keras.optimizers.get") +def get(identifier): + """Retrieves a Keras Optimizer instance. + + Args: + identifier: Optimizer identifier, one of: + - String: name of an optimizer + - Dictionary: configuration dictionary. + - Keras Optimizer instance (it will be returned unchanged). + + Returns: + A Keras Optimizer instance. + """ + if identifier is None: + return None + elif isinstance(identifier, dict): + obj = deserialize(identifier) + elif isinstance(identifier, str): + config = {"class_name": identifier, "config": {}} + obj = deserialize(config) + else: + obj = identifier + + if isinstance(obj, Optimizer): + return obj + raise ValueError(f"Could not interpret optimizer identifier: {identifier}") + + +# We will add this temporarily so that tensorflow packages that depend on +# estimators will continue to import (there are a large number). Note that +# Keras 3 will not work with the estimators API. +@keras_export( + [ + "keras.optimizers.legacy.Adagrad", + "keras.optimizers.legacy.Adam", + "keras.optimizers.legacy.Ftrl", + "keras.optimizers.legacy.RMSprop", + "keras.optimizers.legacy.SGD", + "keras.optimizers.legacy.Optimizer", + ] +) +class LegacyOptimizerWarning: + def __init__(self, *args, **kwargs): + raise ImportError( + "`keras.optimizers.legacy` is not supported in Keras 3. When using " + "`tf.keras`, to continue using a `tf.keras.optimizers.legacy` " + "optimizer, you can install the `tf_keras` package (Keras 2) and " + "set the environment variable `TF_USE_LEGACY_KERAS=True` to " + "configure TensorFlow to use `tf_keras` when accessing `tf.keras`." + ) diff --git a/keras/src/optimizers/adadelta.py b/keras/src/optimizers/adadelta.py new file mode 100644 index 000000000000..7e5a450ecbfa --- /dev/null +++ b/keras/src/optimizers/adadelta.py @@ -0,0 +1,135 @@ +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.optimizers import optimizer + + +@keras_export(["keras.optimizers.Adadelta"]) +class Adadelta(optimizer.Optimizer): + """Optimizer that implements the Adadelta algorithm. + + Adadelta optimization is a stochastic gradient descent method that is based + on adaptive learning rate per dimension to address two drawbacks: + + - The continual decay of learning rates throughout training. + - The need for a manually selected global learning rate. + + Adadelta is a more robust extension of Adagrad that adapts learning rates + based on a moving window of gradient updates, instead of accumulating all + past gradients. This way, Adadelta continues learning even when many updates + have been done. Compared to Adagrad, in the original version of Adadelta you + don't have to set an initial learning rate. In this version, the initial + learning rate can be set, as in most other Keras optimizers. + + Args: + learning_rate: A float, a + `keras.optimizers.schedules.LearningRateSchedule` instance, or + a callable that takes no arguments and returns the actual value to + use. The learning rate. Defaults to `0.001`. Note that `Adadelta` + tends to benefit from higher initial learning rate values compared + to other optimizers. To match the exact form in the original paper, + use 1.0. + rho: A floating point value. The decay rate. Defaults to `0.95`. + epsilon: Small floating point value for maintaining numerical stability. + {{base_optimizer_keyword_args}} + + Reference: + + - [Zeiler, 2012](http://arxiv.org/abs/1212.5701) + """ + + def __init__( + self, + learning_rate=0.001, + rho=0.95, + epsilon=1e-7, + weight_decay=None, + clipnorm=None, + clipvalue=None, + global_clipnorm=None, + use_ema=False, + ema_momentum=0.99, + ema_overwrite_frequency=None, + loss_scale_factor=None, + gradient_accumulation_steps=None, + name="adadelta", + **kwargs, + ): + super().__init__( + learning_rate=learning_rate, + weight_decay=weight_decay, + clipnorm=clipnorm, + clipvalue=clipvalue, + global_clipnorm=global_clipnorm, + use_ema=use_ema, + ema_momentum=ema_momentum, + ema_overwrite_frequency=ema_overwrite_frequency, + loss_scale_factor=loss_scale_factor, + gradient_accumulation_steps=gradient_accumulation_steps, + name=name, + **kwargs, + ) + self.rho = rho + self.epsilon = epsilon + + def build(self, var_list): + if self.built: + return + super().build(var_list) + self._accumulated_grads, self._accumulated_delta_vars = ( + self.add_optimizer_variables( + var_list, ["accumulated_grad", "accumulated_delta_var"] + ) + ) + + def update_step(self, grad, variable, learning_rate): + """Update step given gradient and the associated model variable.""" + lr = ops.cast(learning_rate, variable.dtype) + grad = ops.cast(grad, variable.dtype) + + rho = self.rho + accumulated_grad = self._accumulated_grads[ + self._get_variable_index(variable) + ] + accumulated_delta_var = self._accumulated_delta_vars[ + self._get_variable_index(variable) + ] + + def rms(x): + return ops.sqrt(ops.add(x, self.epsilon)) + + self.assign( + accumulated_grad, + ops.add( + rho * accumulated_grad, ops.multiply(1 - rho, ops.square(grad)) + ), + ) + delta_var = ops.negative( + ops.divide( + ops.multiply(rms(accumulated_delta_var), grad), + rms(accumulated_grad), + ) + ) + self.assign( + accumulated_delta_var, + ops.add( + ops.multiply(rho, accumulated_delta_var), + ops.multiply(1 - rho, ops.square(delta_var)), + ), + ) + self.assign_add(variable, ops.multiply(lr, delta_var)) + + def get_config(self): + config = super().get_config() + + config.update( + { + "rho": self.rho, + "epsilon": self.epsilon, + } + ) + return config + + +Adadelta.__doc__ = Adadelta.__doc__.replace( + "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args +) diff --git a/keras/src/optimizers/adadelta_test.py b/keras/src/optimizers/adadelta_test.py new file mode 100644 index 000000000000..9da72612fc87 --- /dev/null +++ b/keras/src/optimizers/adadelta_test.py @@ -0,0 +1,75 @@ +import numpy as np + +from keras.src import backend +from keras.src import ops +from keras.src import testing +from keras.src.optimizers.adadelta import Adadelta + + +class AdadeltaTest(testing.TestCase): + def test_config(self): + optimizer = Adadelta( + learning_rate=0.5, + rho=0.9, + epsilon=1e-5, + ) + self.run_class_serialization_test(optimizer) + + def test_single_step(self): + optimizer = Adadelta(learning_rate=0.5) + grads = ops.array([1.0, 6.0, 7.0, 2.0]) + vars = backend.Variable([1.0, 2.0, 3.0, 4.0]) + optimizer.apply_gradients(zip([grads], [vars])) + self.assertAllClose( + vars, [0.9993, 1.9993, 2.9993, 3.9993], rtol=1e-4, atol=1e-4 + ) + + def test_weight_decay(self): + grads, var1, var2, var3 = ( + ops.zeros(()), + backend.Variable(2.0), + backend.Variable(2.0, name="exclude"), + backend.Variable(2.0), + ) + optimizer_1 = Adadelta(learning_rate=1.0, weight_decay=0.004) + optimizer_1.apply_gradients(zip([grads], [var1])) + + optimizer_2 = Adadelta(learning_rate=1.0, weight_decay=0.004) + optimizer_2.exclude_from_weight_decay(var_names=["exclude"]) + optimizer_2.apply_gradients(zip([grads, grads], [var1, var2])) + + optimizer_3 = Adadelta(learning_rate=1.0, weight_decay=0.004) + optimizer_3.exclude_from_weight_decay(var_list=[var3]) + optimizer_3.apply_gradients(zip([grads, grads], [var1, var3])) + + self.assertAlmostEqual(var1.numpy(), 1.9760959, decimal=6) + self.assertAlmostEqual(var2.numpy(), 2.0, decimal=6) + self.assertAlmostEqual(var3.numpy(), 2.0, decimal=6) + + def test_correctness_with_golden(self): + optimizer = Adadelta(learning_rate=1.0, rho=0.8, epsilon=1e-6) + + x = backend.Variable(np.ones([10])) + grads = ops.arange(0.1, 1.1, 0.1) + first_grads = ops.full((10,), 0.01) + + golden = np.tile( + [[0.9978], [0.9947], [0.9915], [0.9882], [0.9849]], (1, 10) + ) + + optimizer.apply_gradients(zip([first_grads], [x])) + for i in range(5): + self.assertAllClose(x, golden[i], rtol=5e-4, atol=5e-4) + optimizer.apply_gradients(zip([grads], [x])) + + def test_clip_norm(self): + optimizer = Adadelta(clipnorm=1) + grad = [np.array([100.0, 100.0])] + clipped_grad = optimizer._clip_gradients(grad) + self.assertAllClose(clipped_grad[0], [2**0.5 / 2, 2**0.5 / 2]) + + def test_clip_value(self): + optimizer = Adadelta(clipvalue=1) + grad = [np.array([100.0, 100.0])] + clipped_grad = optimizer._clip_gradients(grad) + self.assertAllClose(clipped_grad[0], [1.0, 1.0]) diff --git a/keras/src/optimizers/adafactor.py b/keras/src/optimizers/adafactor.py new file mode 100644 index 000000000000..6c406043353e --- /dev/null +++ b/keras/src/optimizers/adafactor.py @@ -0,0 +1,233 @@ +from keras.src import backend +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.optimizers import optimizer + + +@keras_export(["keras.optimizers.Adafactor"]) +class Adafactor(optimizer.Optimizer): + """Optimizer that implements the Adafactor algorithm. + + Adafactor is commonly used in NLP tasks, and has the advantage + of taking less memory because it only saves partial information of previous + gradients. + + The default argument setup is based on the original paper (see reference). + When gradients are of dimension > 2, Adafactor optimizer will delete the + last 2 dimensions separately in its accumulator variables. + + Args: + learning_rate: A float, a + `keras.optimizers.schedules.LearningRateSchedule` instance, or + a callable that takes no arguments and returns the actual value to + use. The learning rate. Defaults to `0.001`. + beta_2_decay: float, defaults to -0.8. The decay rate of `beta_2`. + epsilon_1: float, defaults to 1e-30. A small offset to keep denominator + away from 0. + epsilon_2: float, defaults to 1e-3. A small offset to avoid learning + rate becoming too small by time. + clip_threshold: float, defaults to 1.0. Clipping threshold. This is a + part of Adafactor algorithm, independent from `clipnorm`, + `clipvalue`, and `global_clipnorm`. + relative_step: bool, defaults to `True`. If `learning_rate` is a + constant and `relative_step=True`, learning rate will be adjusted + based on current iterations. This is a default learning rate decay + in Adafactor. + {{base_optimizer_keyword_args}} + + Reference: + + - [Shazeer, Noam et al., 2018](https://arxiv.org/abs/1804.04235). + + """ + + def __init__( + self, + learning_rate=0.001, + beta_2_decay=-0.8, + epsilon_1=1e-30, + epsilon_2=1e-3, + clip_threshold=1.0, + relative_step=True, + weight_decay=None, + clipnorm=None, + clipvalue=None, + global_clipnorm=None, + use_ema=False, + ema_momentum=0.99, + ema_overwrite_frequency=None, + loss_scale_factor=None, + gradient_accumulation_steps=None, + name="adafactor", + **kwargs, + ): + super().__init__( + learning_rate=learning_rate, + name=name, + weight_decay=weight_decay, + clipnorm=clipnorm, + clipvalue=clipvalue, + global_clipnorm=global_clipnorm, + use_ema=use_ema, + ema_momentum=ema_momentum, + ema_overwrite_frequency=ema_overwrite_frequency, + loss_scale_factor=loss_scale_factor, + gradient_accumulation_steps=gradient_accumulation_steps, + **kwargs, + ) + self.beta_2_decay = beta_2_decay + self.epsilon_1 = epsilon_1 + self.epsilon_2 = epsilon_2 + self.clip_threshold = clip_threshold + self.relative_step = relative_step + + def build(self, var_list): + """Initialize optimizer variables. + + Adam optimizer has 3 types of variables: momentums, velocities and + velocity_hat (only set when amsgrad is applied), + + Args: + var_list: list of model variables to build Adam variables on. + """ + if self.built: + return + super().build(var_list) + self._r = [] + self._c = [] + self._v = [] + for var in var_list: + if len(var.shape) < 2: + # Don't factor if variable is of dimension < 2, but we still + # need to create dummy variables as placeholder. + self._r.append( + backend.Variable(0, name=var.name, trainable=False) + ) + self._c.append( + backend.Variable(0, name=var.name, trainable=False) + ) + elif self._overwrite_variable_with_gradient(var): + self._r.append(None) + self._c.append(None) + else: + # Always factor the last 2 dimensions. + r_shape = var.shape[:-1] + c_shape = var.shape[:-2] + (var.shape[-1],) + self._r.append( + self.add_variable( + shape=r_shape, + dtype=var.dtype, + name=var.name, + ) + ) + self._c.append( + self.add_variable( + shape=c_shape, + dtype=var.dtype, + name=var.name, + ) + ) + + if self._overwrite_variable_with_gradient(var): + self._v.append(None) + else: + self._v.append( + self.add_variable_from_reference( + reference_variable=var, name="velocity" + ) + ) + + def _rms(self, x): + return ops.sqrt(ops.mean(ops.square(x))) + + def update_step(self, gradient, variable, learning_rate): + """Update step given gradient and the associated model variable.""" + + lr = ops.cast(learning_rate, variable.dtype) + gradient = ops.cast(gradient, variable.dtype) + epsilon_2 = ops.cast(self.epsilon_2, variable.dtype) + one = ops.cast(1.0, variable.dtype) + local_step = ops.cast(self.iterations + 1, variable.dtype) + if not callable(self._learning_rate) and self.relative_step: + lr = ops.minimum(lr, 1 / ops.sqrt(local_step)) + + r = self._r[self._get_variable_index(variable)] + c = self._c[self._get_variable_index(variable)] + v = self._v[self._get_variable_index(variable)] + + rho_t = ops.minimum(lr, 1 / ops.sqrt(local_step)) + alpha_t = ops.maximum(epsilon_2, self._rms(variable)) * rho_t + regulated_grad_square = ops.add(ops.square(gradient), self.epsilon_1) + beta_2_t = ops.subtract(1, ops.power(local_step, self.beta_2_decay)) + + if len(variable.shape) >= 2: + # `r` deletes the last dimension of gradient, so it is of shape + # `gradient.shape[:-1]`. + self.assign( + r, + ops.add( + ops.multiply(beta_2_t, r), + ops.multiply( + ops.subtract(1, beta_2_t), + ops.mean(regulated_grad_square, axis=-1), + ), + ), + ) + # `c` deletes the second last dimension of gradient, so it is of + # shape `gradient.shape[:-2] + gradient.shape[-1]`. + self.assign( + c, + ops.add( + ops.multiply(beta_2_t, c), + ops.multiply( + ops.subtract(1, beta_2_t), + ops.mean(regulated_grad_square, axis=-2), + ), + ), + ) + self.assign( + v, + ops.multiply( + ops.expand_dims( + ops.divide(r, ops.mean(r, axis=-1, keepdims=True)), + axis=-1, + ), + ops.expand_dims(c, -2), + ), + ) + else: + self.assign( + v, + ops.add( + ops.multiply(beta_2_t, v), + ops.multiply( + ops.subtract(1, beta_2_t), regulated_grad_square + ), + ), + ) + + u_t = ops.divide(gradient, ops.sqrt(v)) + u_t_hat = ops.divide( + u_t, + ops.maximum(one, ops.divide(self._rms(u_t), self.clip_threshold)), + ) + self.assign_sub(variable, ops.multiply(alpha_t, u_t_hat)) + + def get_config(self): + config = super().get_config() + + config.update( + { + "beta_2_decay": self.beta_2_decay, + "epsilon_1": self.epsilon_1, + "epsilon_2": self.epsilon_2, + "clip_threshold": self.clip_threshold, + "relative_step": self.relative_step, + } + ) + return config + + +Adafactor.__doc__ = Adafactor.__doc__.replace( + "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args +) diff --git a/keras/src/optimizers/adafactor_test.py b/keras/src/optimizers/adafactor_test.py new file mode 100644 index 000000000000..b928400c34e4 --- /dev/null +++ b/keras/src/optimizers/adafactor_test.py @@ -0,0 +1,102 @@ +# flake8: noqa + + +import numpy as np + +from keras.src import backend +from keras.src import testing +from keras.src.optimizers.adafactor import Adafactor + + +class AdafactorTest(testing.TestCase): + def test_config(self): + optimizer = Adafactor( + learning_rate=0.5, + beta_2_decay=-0.65, + epsilon_1=1e-15, + epsilon_2=1e-4, + clip_threshold=0.9, + relative_step=False, + ) + self.run_class_serialization_test(optimizer) + + def test_single_step_1d(self): + optimizer = Adafactor(learning_rate=0.5) + grads = np.array([1.0, 6.0, 7.0, 2.0]) + vars = backend.Variable([1.0, 2.0, 3.0, 4.0]) + optimizer.apply_gradients(zip([grads], [vars])) + self.assertAllClose( + vars, [-0.3693, 0.6307, 1.6307, 2.6307], rtol=1e-4, atol=1e-4 + ) + + def test_single_step_2d(self): + optimizer = Adafactor(learning_rate=0.5) + grads = np.array([[1.0, 6.0], [7.0, 2.0]]) + vars = backend.Variable([[1.0, 2.0], [3.0, 4.0]]) + optimizer.apply_gradients(zip([grads], [vars])) + self.assertAllClose( + vars, [[0.7007, -0.0081], [1.2492, 3.4407]], rtol=1e-4, atol=1e-4 + ) + + def test_weight_decay(self): + grads, var1, var2, var3 = ( + np.zeros(()), + backend.Variable(2.0), + backend.Variable(2.0, name="exclude"), + backend.Variable(2.0), + ) + optimizer_1 = Adafactor(learning_rate=1.0, weight_decay=0.004) + optimizer_1.apply_gradients(zip([grads], [var1])) + + optimizer_2 = Adafactor(learning_rate=1.0, weight_decay=0.004) + optimizer_2.exclude_from_weight_decay(var_names=["exclude"]) + optimizer_2.apply_gradients(zip([grads, grads], [var1, var2])) + + optimizer_3 = Adafactor(learning_rate=1.0, weight_decay=0.004) + optimizer_3.exclude_from_weight_decay(var_list=[var3]) + optimizer_3.apply_gradients(zip([grads, grads], [var1, var3])) + + self.assertAlmostEqual(var1.numpy(), 1.9760959, decimal=6) + self.assertAlmostEqual(var2.numpy(), 2.0, decimal=6) + self.assertAlmostEqual(var3.numpy(), 2.0, decimal=6) + + def test_correctness_with_golden(self): + optimizer = Adafactor( + learning_rate=0.5, + beta_2_decay=-0.65, + epsilon_1=1e-15, + epsilon_2=1e-4, + clip_threshold=0.9, + relative_step=False, + ) + + x = backend.Variable(np.ones([10])) + grads = np.arange(0.1, 1.1, 0.1) + first_grads = np.full((10,), 0.01) + + # fmt: off + golden = np.array( + [[0.55, 0.55, 0.55, 0.55, 0.55, 0.55, 0.55, 0.55, 0.55, 0.55], + [0.3031, 0.3026, 0.3025, 0.3024, 0.3024, 0.3024, 0.3024, 0.3024, 0.3024, 0.3024], + [0.1671, 0.1665, 0.1663, 0.1663, 0.1663, 0.1663, 0.1663, 0.1663, 0.1663, 0.1663], + [0.0923, 0.0916, 0.0915, 0.0914, 0.0914, 0.0914, 0.0914, 0.0914, 0.0914, 0.0914], + [0.0554, 0.0548, 0.0546, 0.0546, 0.0546, 0.0546, 0.0546, 0.0545, 0.0545, 0.0545]] + ) + # fmt: on + + optimizer.apply_gradients(zip([first_grads], [x])) + for i in range(5): + self.assertAllClose(x, golden[i], rtol=5e-4, atol=5e-4) + optimizer.apply_gradients(zip([grads], [x])) + + def test_clip_norm(self): + optimizer = Adafactor(clipnorm=1) + grad = [np.array([100.0, 100.0])] + clipped_grad = optimizer._clip_gradients(grad) + self.assertAllClose(clipped_grad[0], [2**0.5 / 2, 2**0.5 / 2]) + + def test_clip_value(self): + optimizer = Adafactor(clipvalue=1) + grad = [np.array([100.0, 100.0])] + clipped_grad = optimizer._clip_gradients(grad) + self.assertAllClose(clipped_grad[0], [1.0, 1.0]) diff --git a/keras/src/optimizers/adagrad.py b/keras/src/optimizers/adagrad.py new file mode 100644 index 000000000000..1323bc1027ea --- /dev/null +++ b/keras/src/optimizers/adagrad.py @@ -0,0 +1,108 @@ +from keras.src import initializers +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.optimizers import optimizer + + +@keras_export(["keras.optimizers.Adagrad"]) +class Adagrad(optimizer.Optimizer): + """Optimizer that implements the Adagrad algorithm. + + Adagrad is an optimizer with parameter-specific learning rates, + which are adapted relative to how frequently a parameter gets + updated during training. The more updates a parameter receives, + the smaller the updates. + + Args: + learning_rate: A float, a + `keras.optimizers.schedules.LearningRateSchedule` instance, or + a callable that takes no arguments and returns the actual value to + use. The learning rate. Defaults to `0.001`. Note that `Adagrad` + tends to benefit from higher initial learning rate values compared + to other optimizers. To match the exact form in the original paper, + use `1.0`. + initial_accumulator_value: Floating point value. Starting value for the + accumulators (per-parameter momentum values). Must be non-negative. + epsilon: Small floating point value for maintaining numerical stability. + {{base_optimizer_keyword_args}} + + Reference: + + - [Duchi et al., 2011]( + http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf). + """ + + def __init__( + self, + learning_rate=0.001, + initial_accumulator_value=0.1, + epsilon=1e-7, + weight_decay=None, + clipnorm=None, + clipvalue=None, + global_clipnorm=None, + use_ema=False, + ema_momentum=0.99, + ema_overwrite_frequency=None, + loss_scale_factor=None, + gradient_accumulation_steps=None, + name="adagrad", + **kwargs, + ): + super().__init__( + learning_rate=learning_rate, + weight_decay=weight_decay, + clipnorm=clipnorm, + clipvalue=clipvalue, + global_clipnorm=global_clipnorm, + use_ema=use_ema, + ema_momentum=ema_momentum, + ema_overwrite_frequency=ema_overwrite_frequency, + loss_scale_factor=loss_scale_factor, + gradient_accumulation_steps=gradient_accumulation_steps, + name=name, + **kwargs, + ) + self.initial_accumulator_value = initial_accumulator_value + self.epsilon = epsilon + + def build(self, var_list): + if self.built: + return + super().build(var_list) + initializer = initializers.Constant(self.initial_accumulator_value) + self._accumulators = self.add_optimizer_variables( + var_list, "accumulator", initializer=initializer + ) + + def update_step(self, gradient, variable, learning_rate): + """Update step given gradient and the associated model variable.""" + lr = ops.cast(learning_rate, variable.dtype) + gradient = ops.cast(gradient, variable.dtype) + + accumulator = self._accumulators[self._get_variable_index(variable)] + + self.assign_add(accumulator, ops.square(gradient)) + self.assign_sub( + variable, + ops.divide( + ops.multiply(lr, gradient), + ops.sqrt(ops.add(accumulator, self.epsilon)), + ), + ) + + def get_config(self): + config = super().get_config() + + config.update( + { + "initial_accumulator_value": self.initial_accumulator_value, + "epsilon": self.epsilon, + } + ) + return config + + +Adagrad.__doc__ = Adagrad.__doc__.replace( + "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args +) diff --git a/keras/src/optimizers/adagrad_test.py b/keras/src/optimizers/adagrad_test.py new file mode 100644 index 000000000000..43d2bcbd7afa --- /dev/null +++ b/keras/src/optimizers/adagrad_test.py @@ -0,0 +1,86 @@ +# flake8: noqa + + +import numpy as np + +from keras.src import backend +from keras.src import ops +from keras.src import testing +from keras.src.optimizers.adagrad import Adagrad + + +class AdagradTest(testing.TestCase): + def test_config(self): + optimizer = Adagrad( + learning_rate=0.5, + initial_accumulator_value=0.2, + epsilon=1e-5, + ) + self.run_class_serialization_test(optimizer) + + def test_single_step(self): + optimizer = Adagrad(learning_rate=0.5) + grads = ops.array([1.0, 6.0, 7.0, 2.0]) + vars = backend.Variable([1.0, 2.0, 3.0, 4.0]) + optimizer.apply_gradients(zip([grads], [vars])) + self.assertAllClose( + vars, [0.5233, 1.5007, 2.5005, 3.5061], rtol=1e-4, atol=1e-4 + ) + + def test_weight_decay(self): + grads, var1, var2, var3 = ( + ops.zeros(()), + backend.Variable(2.0), + backend.Variable(2.0, name="exclude"), + backend.Variable(2.0), + ) + optimizer_1 = Adagrad(learning_rate=1.0, weight_decay=0.004) + optimizer_1.apply_gradients(zip([grads], [var1])) + + optimizer_2 = Adagrad(learning_rate=1.0, weight_decay=0.004) + optimizer_2.exclude_from_weight_decay(var_names=["exclude"]) + optimizer_2.apply_gradients(zip([grads, grads], [var1, var2])) + + optimizer_3 = Adagrad(learning_rate=1.0, weight_decay=0.004) + optimizer_3.exclude_from_weight_decay(var_list=[var3]) + optimizer_3.apply_gradients(zip([grads, grads], [var1, var3])) + + self.assertAlmostEqual(var1.numpy(), 1.9760959, decimal=6) + self.assertAlmostEqual(var2.numpy(), 2.0, decimal=6) + self.assertAlmostEqual(var3.numpy(), 2.0, decimal=6) + + def test_correctness_with_golden(self): + optimizer = Adagrad( + learning_rate=0.2, initial_accumulator_value=0.3, epsilon=1e-6 + ) + + x = backend.Variable(np.ones([10])) + grads = ops.arange(0.1, 1.1, 0.1) + first_grads = ops.full((10,), 0.01) + + # fmt: off + golden = np.array( + [[0.9963, 0.9963, 0.9963, 0.9963, 0.9963, 0.9963, 0.9963, 0.9963, 0.9963, 0.9963], + [0.9604, 0.9278, 0.9003, 0.8784, 0.8615, 0.8487, 0.8388, 0.8313, 0.8255, 0.8209], + [0.9251, 0.8629, 0.8137, 0.7768, 0.7497, 0.7298, 0.7151, 0.704, 0.6956, 0.6891], + [0.8903, 0.8012, 0.7342, 0.6862, 0.6521, 0.6277, 0.6099, 0.5967, 0.5867, 0.579], + [0.856, 0.7422, 0.6604, 0.6037, 0.5644, 0.5367, 0.5168, 0.5021, 0.491, 0.4825]] + ) + # fmt: on + + optimizer.apply_gradients(zip([first_grads], [x])) + for i in range(5): + self.assertAllClose(x, golden[i], rtol=5e-4, atol=5e-4) + optimizer.apply_gradients(zip([grads], [x])) + + def test_clip_norm(self): + optimizer = Adagrad(clipnorm=1) + grad = [np.array([100.0, 100.0])] + clipped_grad = optimizer._clip_gradients(grad) + self.assertAllClose(clipped_grad[0], [2**0.5 / 2, 2**0.5 / 2]) + + def test_clip_value(self): + optimizer = Adagrad(clipvalue=1) + grad = [np.array([100.0, 100.0])] + clipped_grad = optimizer._clip_gradients(grad) + self.assertAllClose(clipped_grad[0], [1.0, 1.0]) diff --git a/keras/src/optimizers/adam.py b/keras/src/optimizers/adam.py new file mode 100644 index 000000000000..2c3970e97aa4 --- /dev/null +++ b/keras/src/optimizers/adam.py @@ -0,0 +1,154 @@ +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.optimizers import optimizer + + +@keras_export(["keras.optimizers.Adam"]) +class Adam(optimizer.Optimizer): + """Optimizer that implements the Adam algorithm. + + Adam optimization is a stochastic gradient descent method that is based on + adaptive estimation of first-order and second-order moments. + + According to + [Kingma et al., 2014](http://arxiv.org/abs/1412.6980), + the method is "*computationally + efficient, has little memory requirement, invariant to diagonal rescaling of + gradients, and is well suited for problems that are large in terms of + data/parameters*". + + Args: + learning_rate: A float, a + `keras.optimizers.schedules.LearningRateSchedule` instance, or + a callable that takes no arguments and returns the actual value to + use. The learning rate. Defaults to `0.001`. + beta_1: A float value or a constant float tensor, or a callable + that takes no arguments and returns the actual value to use. The + exponential decay rate for the 1st moment estimates. Defaults to + `0.9`. + beta_2: A float value or a constant float tensor, or a callable + that takes no arguments and returns the actual value to use. The + exponential decay rate for the 2nd moment estimates. Defaults to + `0.999`. + epsilon: A small constant for numerical stability. This epsilon is + "epsilon hat" in the Kingma and Ba paper (in the formula just before + Section 2.1), not the epsilon in Algorithm 1 of the paper. Defaults + to `1e-7`. + amsgrad: Boolean. Whether to apply AMSGrad variant of this algorithm + from the paper "On the Convergence of Adam and beyond". Defaults + to `False`. + {{base_optimizer_keyword_args}} + """ + + def __init__( + self, + learning_rate=0.001, + beta_1=0.9, + beta_2=0.999, + epsilon=1e-7, + amsgrad=False, + weight_decay=None, + clipnorm=None, + clipvalue=None, + global_clipnorm=None, + use_ema=False, + ema_momentum=0.99, + ema_overwrite_frequency=None, + loss_scale_factor=None, + gradient_accumulation_steps=None, + name="adam", + **kwargs, + ): + super().__init__( + learning_rate=learning_rate, + name=name, + weight_decay=weight_decay, + clipnorm=clipnorm, + clipvalue=clipvalue, + global_clipnorm=global_clipnorm, + use_ema=use_ema, + ema_momentum=ema_momentum, + ema_overwrite_frequency=ema_overwrite_frequency, + loss_scale_factor=loss_scale_factor, + gradient_accumulation_steps=gradient_accumulation_steps, + **kwargs, + ) + self.beta_1 = beta_1 + self.beta_2 = beta_2 + self.epsilon = epsilon + self.amsgrad = amsgrad + + def build(self, var_list): + """Initialize optimizer variables. + + Adam optimizer has 3 types of variables: momentums, velocities and + velocity_hat (only set when amsgrad is applied), + + Args: + var_list: list of model variables to build Adam variables on. + """ + if self.built: + return + super().build(var_list) + self._momentums, self._velocities = self.add_optimizer_variables( + var_list, ["momentum", "velocity"] + ) + + if self.amsgrad: + self._velocity_hats = self.add_optimizer_variables( + var_list, "velocity_hat" + ) + + def update_step(self, gradient, variable, learning_rate): + """Update step given gradient and the associated model variable.""" + lr = ops.cast(learning_rate, variable.dtype) + gradient = ops.cast(gradient, variable.dtype) + local_step = ops.cast(self.iterations + 1, variable.dtype) + beta_1_power = ops.power( + ops.cast(self.beta_1, variable.dtype), local_step + ) + beta_2_power = ops.power( + ops.cast(self.beta_2, variable.dtype), local_step + ) + + m = self._momentums[self._get_variable_index(variable)] + v = self._velocities[self._get_variable_index(variable)] + + alpha = lr * ops.sqrt(1 - beta_2_power) / (1 - beta_1_power) + + self.assign_add( + m, ops.multiply(ops.subtract(gradient, m), 1 - self.beta_1) + ) + self.assign_add( + v, + ops.multiply( + ops.subtract(ops.square(gradient), v), 1 - self.beta_2 + ), + ) + if self.amsgrad: + v_hat = self._velocity_hats[self._get_variable_index(variable)] + self.assign(v_hat, ops.maximum(v_hat, v)) + v = v_hat + self.assign_sub( + variable, + ops.divide( + ops.multiply(m, alpha), ops.add(ops.sqrt(v), self.epsilon) + ), + ) + + def get_config(self): + config = super().get_config() + config.update( + { + "beta_1": self.beta_1, + "beta_2": self.beta_2, + "epsilon": self.epsilon, + "amsgrad": self.amsgrad, + } + ) + return config + + +Adam.__doc__ = Adam.__doc__.replace( + "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args +) diff --git a/keras/src/optimizers/adam_test.py b/keras/src/optimizers/adam_test.py new file mode 100644 index 000000000000..4cc029ad9d30 --- /dev/null +++ b/keras/src/optimizers/adam_test.py @@ -0,0 +1,104 @@ +import numpy as np +import pytest + +import keras +from keras.src import backend +from keras.src import ops +from keras.src import testing +from keras.src.optimizers.adam import Adam + + +class AdamTest(testing.TestCase): + def test_config(self): + optimizer = Adam( + learning_rate=0.5, + beta_1=0.5, + beta_2=0.67, + epsilon=1e-5, + amsgrad=True, + ) + self.run_class_serialization_test(optimizer) + + def test_single_step(self): + optimizer = Adam(learning_rate=0.5) + grads = ops.array([1.0, 6.0, 7.0, 2.0]) + vars = backend.Variable([1.0, 2.0, 3.0, 4.0]) + optimizer.apply_gradients(zip([grads], [vars])) + self.assertAllClose(vars, [0.5, 1.5, 2.5, 3.5], rtol=1e-4, atol=1e-4) + + def test_weight_decay(self): + grads, var1, var2, var3 = ( + ops.zeros(()), + backend.Variable(2.0), + backend.Variable(2.0, name="exclude"), + backend.Variable(2.0), + ) + optimizer_1 = Adam(learning_rate=1.0, weight_decay=0.004) + optimizer_1.apply_gradients(zip([grads], [var1])) + + optimizer_2 = Adam(learning_rate=1.0, weight_decay=0.004) + optimizer_2.exclude_from_weight_decay(var_names=["exclude"]) + optimizer_2.apply_gradients(zip([grads, grads], [var1, var2])) + + optimizer_3 = Adam(learning_rate=1.0, weight_decay=0.004) + optimizer_3.exclude_from_weight_decay(var_list=[var3]) + optimizer_3.apply_gradients(zip([grads, grads], [var1, var3])) + + self.assertAlmostEqual(var1.numpy(), 1.9760959, decimal=6) + self.assertAlmostEqual(var2.numpy(), 2.0, decimal=6) + self.assertAlmostEqual(var3.numpy(), 2.0, decimal=6) + + def test_correctness_with_golden(self): + optimizer = Adam(amsgrad=True) + + x = backend.Variable(np.ones([10], dtype="float32")) + grads = ops.arange(0.1, 1.1, 0.1) + first_grads = ops.full((10,), 0.01) + + golden = np.tile( + [[0.999], [0.9982], [0.9974], [0.9965], [0.9955]], (1, 10) + ) + + optimizer.apply_gradients(zip([first_grads], [x])) + for i in range(5): + self.assertAllClose(x, golden[i], rtol=5e-4, atol=5e-4) + optimizer.apply_gradients(zip([grads], [x])) + + def test_clip_norm(self): + optimizer = Adam(clipnorm=1) + grad = [np.array([100.0, 100.0])] + clipped_grad = optimizer._clip_gradients(grad) + self.assertAllClose(clipped_grad[0], [2**0.5 / 2, 2**0.5 / 2]) + + def test_clip_value(self): + optimizer = Adam(clipvalue=1) + grad = [np.array([100.0, 100.0])] + clipped_grad = optimizer._clip_gradients(grad) + self.assertAllClose(clipped_grad[0], [1.0, 1.0]) + + @pytest.mark.requires_trainable_backend + def test_ema(self): + # TODO: test correctness + model = keras.Sequential([keras.layers.Dense(10)]) + model.compile(optimizer=Adam(use_ema=True), loss="mse") + x = keras.ops.zeros((1, 5)) + y = keras.ops.zeros((1, 10)) + model.fit(x, y) + + @pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="The IndexedSlices test can only run with TF backend.", + ) + def test_clipnorm_indexed_slices(self): + # https://github.com/keras-team/keras/issues/18985 + model = keras.Sequential( + [ + keras.layers.Embedding(10, 4), + keras.layers.Flatten(), + keras.layers.Dense(2), + ] + ) + model.compile(optimizer=Adam(clipnorm=100), loss="mse") + x = keras.ops.ones((8, 5)) + y = keras.ops.zeros((8, 2)) + model.fit(x, y, verbose=0) diff --git a/keras/src/optimizers/adamax.py b/keras/src/optimizers/adamax.py new file mode 100644 index 000000000000..661fe1cb5310 --- /dev/null +++ b/keras/src/optimizers/adamax.py @@ -0,0 +1,146 @@ +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.optimizers import optimizer + + +@keras_export(["keras.optimizers.Adamax"]) +class Adamax(optimizer.Optimizer): + """Optimizer that implements the Adamax algorithm. + + Adamax, a variant of Adam based on the infinity norm, is a first-order + gradient-based optimization method. Due to its capability of adjusting the + learning rate based on data characteristics, it is suited to learn + time-variant process, e.g., speech data with dynamically changed noise + conditions. Default parameters follow those provided in the paper (see + references below). + + Initialization: + + ```python + m = 0 # Initialize initial 1st moment vector + u = 0 # Initialize the exponentially weighted infinity norm + t = 0 # Initialize timestep + ``` + + The update rule for parameter `w` with gradient `g` is described at the end + of section 7.1 of the paper (see the reference section): + + ```python + t += 1 + m = beta1 * m + (1 - beta) * g + u = max(beta2 * u, abs(g)) + current_lr = learning_rate / (1 - beta1 ** t) + w = w - current_lr * m / (u + epsilon) + ``` + + Args: + learning_rate: A float, a + `keras.optimizers.schedules.LearningRateSchedule` instance, or + a callable that takes no arguments and returns the actual value to + use. The learning rate. Defaults to `0.001`. + beta_1: A float value or a constant float tensor. The exponential decay + rate for the 1st moment estimates. + beta_2: A float value or a constant float tensor. The exponential decay + rate for the exponentially weighted infinity norm. + epsilon: A small constant for numerical stability. + {{base_optimizer_keyword_args}} + + Reference: + + - [Kingma et al., 2014](http://arxiv.org/abs/1412.6980) + """ + + def __init__( + self, + learning_rate=0.001, + beta_1=0.9, + beta_2=0.999, + epsilon=1e-7, + weight_decay=None, + clipnorm=None, + clipvalue=None, + global_clipnorm=None, + use_ema=False, + ema_momentum=0.99, + ema_overwrite_frequency=None, + loss_scale_factor=None, + gradient_accumulation_steps=None, + name="adamax", + **kwargs, + ): + super().__init__( + learning_rate=learning_rate, + name=name, + weight_decay=weight_decay, + clipnorm=clipnorm, + clipvalue=clipvalue, + global_clipnorm=global_clipnorm, + use_ema=use_ema, + ema_momentum=ema_momentum, + ema_overwrite_frequency=ema_overwrite_frequency, + loss_scale_factor=loss_scale_factor, + gradient_accumulation_steps=gradient_accumulation_steps, + **kwargs, + ) + self.beta_1 = beta_1 + self.beta_2 = beta_2 + self.epsilon = epsilon + + def build(self, var_list): + """Initialize optimizer variables. + + Adamax optimizer has 2 types of variables: momentums (denoted as m), + exponentially weighted infinity norm (denoted as u). + + Args: + var_list: list of model variables to build Adamax variables on. + """ + if self.built: + return + super().build(var_list) + self._m, self._u = self.add_optimizer_variables( + var_list, ["momentum", "norm"] + ) + + def update_step(self, gradient, variable, learning_rate): + """Update step given gradient and the associated model variable.""" + lr = ops.cast(learning_rate, variable.dtype) + gradient = ops.cast(gradient, variable.dtype) + local_step = ops.cast(self.iterations + 1, variable.dtype) + beta_1_power = ops.power( + ops.cast(self.beta_1, variable.dtype), local_step + ) + + m = self._m[self._get_variable_index(variable)] + u = self._u[self._get_variable_index(variable)] + + self.assign_add( + m, ops.multiply(ops.subtract(gradient, m), (1 - self.beta_1)) + ) + self.assign( + u, ops.maximum(ops.multiply(self.beta_2, u), ops.abs(gradient)) + ) + self.assign_sub( + variable, + ops.divide( + ops.multiply(lr, m), + ops.multiply((1 - beta_1_power), ops.add(u, self.epsilon)), + ), + ) + + def get_config(self): + config = super().get_config() + + config.update( + { + "beta_1": self.beta_1, + "beta_2": self.beta_2, + "epsilon": self.epsilon, + } + ) + return config + + +Adamax.__doc__ = Adamax.__doc__.replace( + "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args +) diff --git a/keras/src/optimizers/adamax_test.py b/keras/src/optimizers/adamax_test.py new file mode 100644 index 000000000000..50ca00383698 --- /dev/null +++ b/keras/src/optimizers/adamax_test.py @@ -0,0 +1,85 @@ +# flake8: noqa + + +import numpy as np + +from keras.src import backend +from keras.src import ops +from keras.src import testing +from keras.src.optimizers.adamax import Adamax + + +class AdamaxTest(testing.TestCase): + def test_config(self): + optimizer = Adamax( + learning_rate=0.5, + beta_1=0.8, + beta_2=0.95, + epsilon=1e-5, + ) + self.run_class_serialization_test(optimizer) + + def test_single_step(self): + optimizer = Adamax(learning_rate=0.5) + grads = ops.array([1.0, 6.0, 7.0, 2.0]) + vars = backend.Variable([1.0, 2.0, 3.0, 4.0]) + optimizer.apply_gradients(zip([grads], [vars])) + self.assertAllClose(vars, [0.5, 1.5, 2.5, 3.5], rtol=1e-4, atol=1e-4) + + def test_weight_decay(self): + grads, var1, var2, var3 = ( + ops.zeros(()), + backend.Variable(2.0), + backend.Variable(2.0, name="exclude"), + backend.Variable(2.0), + ) + optimizer_1 = Adamax(learning_rate=1.0, weight_decay=0.004) + optimizer_1.apply_gradients(zip([grads], [var1])) + + optimizer_2 = Adamax(learning_rate=1.0, weight_decay=0.004) + optimizer_2.exclude_from_weight_decay(var_names=["exclude"]) + optimizer_2.apply_gradients(zip([grads, grads], [var1, var2])) + + optimizer_3 = Adamax(learning_rate=1.0, weight_decay=0.004) + optimizer_3.exclude_from_weight_decay(var_list=[var3]) + optimizer_3.apply_gradients(zip([grads, grads], [var1, var3])) + + self.assertAlmostEqual(var1.numpy(), 1.9760959, decimal=6) + self.assertAlmostEqual(var2.numpy(), 2.0, decimal=6) + self.assertAlmostEqual(var3.numpy(), 2.0, decimal=6) + + def test_correctness_with_golden(self): + optimizer = Adamax( + learning_rate=0.2, beta_1=0.85, beta_2=0.95, epsilon=1e-6 + ) + + x = backend.Variable(np.ones([10], dtype="float32")) + grads = ops.arange(0.1, 1.1, 0.1) + first_grads = ops.full((10,), 0.01) + + # fmt: off + golden = np.array( + [[0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8], + [0.6827, 0.6873, 0.6888, 0.6896, 0.6901, 0.6904, 0.6906, 0.6908, 0.6909, 0.691], + [0.5333, 0.5407, 0.5431, 0.5444, 0.5451, 0.5456, 0.546, 0.5462, 0.5464, 0.5466], + [0.368, 0.3773, 0.3804, 0.382, 0.3829, 0.3835, 0.384, 0.3843, 0.3846, 0.3848], + [0.1933, 0.204, 0.2076, 0.2094, 0.2105, 0.2112, 0.2117, 0.2121, 0.2124, 0.2126]] + ) + # fmt: on + + optimizer.apply_gradients(zip([first_grads], [x])) + for i in range(5): + self.assertAllClose(x, golden[i], rtol=5e-4, atol=5e-4) + optimizer.apply_gradients(zip([grads], [x])) + + def test_clip_norm(self): + optimizer = Adamax(clipnorm=1) + grad = [np.array([100.0, 100.0])] + clipped_grad = optimizer._clip_gradients(grad) + self.assertAllClose(clipped_grad[0], [2**0.5 / 2, 2**0.5 / 2]) + + def test_clip_value(self): + optimizer = Adamax(clipvalue=1) + grad = [np.array([100.0, 100.0])] + clipped_grad = optimizer._clip_gradients(grad) + self.assertAllClose(clipped_grad[0], [1.0, 1.0]) diff --git a/keras/src/optimizers/adamw.py b/keras/src/optimizers/adamw.py new file mode 100644 index 000000000000..9db4a30094ab --- /dev/null +++ b/keras/src/optimizers/adamw.py @@ -0,0 +1,100 @@ +from keras.src.api_export import keras_export +from keras.src.optimizers import adam +from keras.src.optimizers import optimizer + + +@keras_export(["keras.optimizers.AdamW"]) +class AdamW(adam.Adam): + """Optimizer that implements the AdamW algorithm. + + AdamW optimization is a stochastic gradient descent method that is based on + adaptive estimation of first-order and second-order moments with an added + method to decay weights per the techniques discussed in the paper, + 'Decoupled Weight Decay Regularization' by + [Loshchilov, Hutter et al., 2019](https://arxiv.org/abs/1711.05101). + + According to + [Kingma et al., 2014](http://arxiv.org/abs/1412.6980), + the underlying Adam method is "*computationally + efficient, has little memory requirement, invariant to diagonal rescaling of + gradients, and is well suited for problems that are large in terms of + data/parameters*". + + Args: + learning_rate: A float, a + `keras.optimizers.schedules.LearningRateSchedule` instance, or + a callable that takes no arguments and returns the actual value to + use. The learning rate. Defaults to `0.001`. + beta_1: A float value or a constant float tensor, or a callable + that takes no arguments and returns the actual value to use. The + exponential decay rate for the 1st moment estimates. + Defaults to `0.9`. + beta_2: A float value or a constant float tensor, or a callable + that takes no arguments and returns the actual value to use. The + exponential decay rate for the 2nd moment estimates. + Defaults to `0.999`. + epsilon: A small constant for numerical stability. This epsilon is + "epsilon hat" in the Kingma and Ba paper (in the formula just + before Section 2.1), not the epsilon in Algorithm 1 of the paper. + Defaults to 1e-7. + amsgrad: Boolean. Whether to apply AMSGrad variant of this algorithm + from the paper "On the Convergence of Adam and beyond". + Defaults to `False`. + {{base_optimizer_keyword_args}} + + References: + + - [Loshchilov et al., 2019](https://arxiv.org/abs/1711.05101) + - [Kingma et al., 2014](http://arxiv.org/abs/1412.6980) for `adam` + - [Reddi et al., 2018]( + https://openreview.net/pdf?id=ryQu7f-RZ) for `amsgrad`. + """ + + def __init__( + self, + learning_rate=0.001, + weight_decay=0.004, + beta_1=0.9, + beta_2=0.999, + epsilon=1e-7, + amsgrad=False, + clipnorm=None, + clipvalue=None, + global_clipnorm=None, + use_ema=False, + ema_momentum=0.99, + ema_overwrite_frequency=None, + loss_scale_factor=None, + gradient_accumulation_steps=None, + name="adamw", + **kwargs, + ): + super().__init__( + learning_rate=learning_rate, + beta_1=beta_1, + beta_2=beta_2, + epsilon=epsilon, + amsgrad=amsgrad, + name=name, + weight_decay=weight_decay, + clipnorm=clipnorm, + clipvalue=clipvalue, + global_clipnorm=global_clipnorm, + use_ema=use_ema, + ema_momentum=ema_momentum, + ema_overwrite_frequency=ema_overwrite_frequency, + loss_scale_factor=loss_scale_factor, + gradient_accumulation_steps=gradient_accumulation_steps, + **kwargs, + ) + + if self.weight_decay is None: + raise ValueError( + "Argument `weight_decay` must be a float. Received: " + "weight_decay=None" + ) + + +AdamW.__doc__ = AdamW.__doc__.replace( + "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args +) diff --git a/keras/src/optimizers/adamw_test.py b/keras/src/optimizers/adamw_test.py new file mode 100644 index 000000000000..e2d620c7c3e7 --- /dev/null +++ b/keras/src/optimizers/adamw_test.py @@ -0,0 +1,95 @@ +# flake8: noqa + + +import numpy as np + +from keras.src import backend +from keras.src import ops +from keras.src import testing +from keras.src.optimizers.adamw import AdamW + + +class AdamWTest(testing.TestCase): + def test_config(self): + optimizer = AdamW( + learning_rate=0.5, + weight_decay=0.008, + beta_1=0.5, + beta_2=0.67, + epsilon=1e-5, + amsgrad=True, + ) + self.run_class_serialization_test(optimizer) + + def test_single_step(self): + optimizer = AdamW(learning_rate=0.5) + grads = ops.array([1.0, 6.0, 7.0, 2.0]) + vars = backend.Variable([1.0, 2.0, 3.0, 4.0]) + optimizer.apply_gradients(zip([grads], [vars])) + self.assertAllClose( + vars, [0.4980, 1.4960, 2.494, 3.492], rtol=1e-4, atol=1e-4 + ) + + def test_weight_decay(self): + grads, var1, var2, var3 = ( + ops.zeros(()), + backend.Variable(2.0), + backend.Variable(2.0, name="exclude"), + backend.Variable(2.0), + ) + optimizer_1 = AdamW(learning_rate=1.0, weight_decay=0.004) + optimizer_1.apply_gradients(zip([grads], [var1])) + + optimizer_2 = AdamW(learning_rate=1.0, weight_decay=0.004) + optimizer_2.exclude_from_weight_decay(var_names=["exclude"]) + optimizer_2.apply_gradients(zip([grads, grads], [var1, var2])) + + optimizer_3 = AdamW(learning_rate=1.0, weight_decay=0.004) + optimizer_3.exclude_from_weight_decay(var_list=[var3]) + optimizer_3.apply_gradients(zip([grads, grads], [var1, var3])) + + self.assertAlmostEqual(var1.numpy(), 1.9760959, decimal=6) + self.assertAlmostEqual(var2.numpy(), 2.0, decimal=6) + self.assertAlmostEqual(var3.numpy(), 2.0, decimal=6) + + def test_weight_decay_is_none(self): + with self.assertRaisesRegex( + ValueError, + "Argument `weight_decay` must be a float. " + "Received: weight_decay=None", + ): + AdamW(learning_rate=1.0, weight_decay=None) + + def test_correctness_with_golden(self): + optimizer = AdamW(learning_rate=1.0, weight_decay=0.5, epsilon=2) + + x = backend.Variable(np.ones([10], dtype="float32")) + grads = ops.arange(0.1, 1.1, 0.1) + first_grads = ops.full((10,), 0.01) + + # fmt: off + golden = np.array( + [[0.4998, 0.4998, 0.4998, 0.4998, 0.4998, 0.4998, 0.4998, 0.4998, 0.4998, 0.4998], + [0.2486, 0.2475, 0.2463, 0.2451, 0.244, 0.2428, 0.2417, 0.2405, 0.2394, 0.2382], + [0.1223, 0.1198, 0.1174, 0.1149, 0.1124, 0.11, 0.1075, 0.1051, 0.1027, 0.1003], + [0.0586, 0.0549, 0.0512, 0.0475, 0.0439, 0.0402, 0.0366, 0.033, 0.0294, 0.0258], + [0.0263, 0.0215, 0.0167, 0.012, 0.0073, 0.0026, -0.0021, -0.0067, -0.0113, -0.0159]] + ) + # fmt: on + + optimizer.apply_gradients(zip([first_grads], [x])) + for i in range(5): + self.assertAllClose(x, golden[i], rtol=5e-4, atol=5e-4) + optimizer.apply_gradients(zip([grads], [x])) + + def test_clip_norm(self): + optimizer = AdamW(clipnorm=1) + grad = [np.array([100.0, 100.0])] + clipped_grad = optimizer._clip_gradients(grad) + self.assertAllClose(clipped_grad[0], [2**0.5 / 2, 2**0.5 / 2]) + + def test_clip_value(self): + optimizer = AdamW(clipvalue=1) + grad = [np.array([100.0, 100.0])] + clipped_grad = optimizer._clip_gradients(grad) + self.assertAllClose(clipped_grad[0], [1.0, 1.0]) diff --git a/keras/src/optimizers/base_optimizer.py b/keras/src/optimizers/base_optimizer.py new file mode 100644 index 000000000000..4cae1d0b4f7d --- /dev/null +++ b/keras/src/optimizers/base_optimizer.py @@ -0,0 +1,1205 @@ +import re +import warnings + +from keras.src import backend +from keras.src import initializers +from keras.src import ops +from keras.src.optimizers.schedules import learning_rate_schedule +from keras.src.saving import serialization_lib +from keras.src.saving.keras_saveable import KerasSaveable +from keras.src.utils import tracking +from keras.src.utils.naming import auto_name + + +class BaseOptimizer(KerasSaveable): + """Abstract optimizer base class. + + If you intend to create your own optimization algorithm, please inherit from + this class and override the following methods: + + - `build`: Create your optimizer-related variables, such as momentum + variables in the SGD optimizer. + - `update_step`: Implement your optimizer's variable updating logic. + - `get_config`: serialization of the optimizer. + + Example: + + ```python + class SGD(Optimizer): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.momentum = 0.9 + + def build(self, variables): + super().build(variables) + self.momentums = [] + for variable in variables: + self.momentums.append( + self.add_variable_from_reference( + reference_variable=variable, name="momentum" + ) + ) + + def update_step(self, gradient, variable, learning_rate): + learning_rate = ops.cast(learning_rate, variable.dtype) + gradient = ops.cast(gradient, variable.dtype) + m = self.momentums[self._get_variable_index(variable)] + self.assign( + m, + ops.subtract( + ops.multiply(m, ops.cast(self.momentum, variable.dtype)), + ops.multiply(gradient, learning_rate), + ), + ) + self.assign_add(variable, m) + + def get_config(self): + config = super().get_config() + config.update( + { + "momentum": self.momentum, + "nesterov": self.nesterov, + } + ) + return config + ``` + """ + + def __init__( + self, + learning_rate, + weight_decay=None, + clipnorm=None, + clipvalue=None, + global_clipnorm=None, + use_ema=False, + ema_momentum=0.99, + ema_overwrite_frequency=None, + loss_scale_factor=None, + gradient_accumulation_steps=None, + name=None, + **kwargs, + ): + self._lock = False + + if kwargs.pop("decay", None) is not None: + warnings.warn( + "Argument `decay` is no longer supported and will be ignored." + ) + if kwargs: + raise ValueError(f"Argument(s) not recognized: {kwargs}") + + if name is None: + name = auto_name(self.__class__.__name__) + self.name = name + self.weight_decay = weight_decay + self.clipnorm = clipnorm + self.global_clipnorm = global_clipnorm + self.clipvalue = clipvalue + self.use_ema = use_ema + self.loss_scale_factor = loss_scale_factor + self.gradient_accumulation_steps = gradient_accumulation_steps + + if gradient_accumulation_steps: + if not gradient_accumulation_steps >= 2: + raise ValueError( + "`gradient_accumulation_steps` must be an integer >= 2. " + "Received: gradient_accumulation_steps=" + f"{gradient_accumulation_steps}" + ) + + if use_ema: + # Verify the arguments related to EMA. + if ema_momentum > 1 or ema_momentum < 0: + raise ValueError( + "`ema_momentum` must be in the range [0, 1]. " + f"Received: ema_momentum={ema_momentum}" + ) + if ema_overwrite_frequency and ( + not isinstance(ema_overwrite_frequency, int) + or ema_overwrite_frequency < 1 + ): + raise ValueError( + "`ema_overwrite_frequency` must be an integer >= 1 or " + "None. Received: ema_overwrite_frequency=" + f"{ema_overwrite_frequency}" + ) + self.ema_momentum = ema_momentum + self.ema_overwrite_frequency = ema_overwrite_frequency + + clip_args_sum = sum( + a is not None for a in [clipnorm, clipvalue, global_clipnorm] + ) + if clip_args_sum > 1: + raise ValueError( + "Only one of `clipnorm`, `clipvalue` and `global_clipnorm` can " + f"be set. Received: clipnorm={clipnorm}, " + f"clipvalue={clipvalue}, global_clipnorm={global_clipnorm}" + ) + self.built = False + + # Set up variable tracking. + self._variables = [] + self._trainable_variables = [] + self._tracker = tracking.Tracker( + { + "variables": ( + lambda x: isinstance(x, backend.Variable), + self._variables, + ), + } + ) + self._trainable_variables_indices = {} + + # Create iteration variable + # Note: dtype="int" will resolve to int32 in JAX + # (since int64 is disallowed in JAX) and to int64 in TF. + with backend.name_scope(self.name, caller=self): + iterations = backend.Variable( + 0, + name="iteration", + dtype="int", + trainable=False, + aggregation="only_first_replica", + ) + self._track_variable(iterations) + self._iterations = iterations + + # Create learning rate (schedule or variable) + if isinstance( + learning_rate, learning_rate_schedule.LearningRateSchedule + ): + self._learning_rate = learning_rate + elif callable(learning_rate): + self._learning_rate = learning_rate + else: + if not isinstance(learning_rate, float): + raise ValueError( + "Argument `learning_rate` should be float, or an instance " + "of LearningRateSchedule, or a callable " + "(that takes in the current iteration value " + "and returns the corresponding learning rate value). " + f"Received instead: learning_rate={learning_rate}" + ) + with backend.name_scope(self.name, caller=self): + learning_rate = backend.Variable( + learning_rate, + name="learning_rate", + dtype=backend.floatx(), + trainable=False, + aggregation="only_first_replica", + ) + self._track_variable(learning_rate) + self._learning_rate = learning_rate + + @property + def iterations(self): + if self.gradient_accumulation_steps: + return ops.floor_divide( + self._iterations, self.gradient_accumulation_steps + ) + + return self._iterations + + def _track_variable(self, variable): + self._tracker.add_to_store("variables", variable) + + def _overwrite_variable_with_gradient(self, variable): + return getattr(variable, "overwrite_with_gradient", False) + + @tracking.no_automatic_dependency_tracking + def build(self, variables): + if self.use_ema: + self._model_variables_moving_average = self.add_optimizer_variables( + variables, "average" + ) + if self.gradient_accumulation_steps: + self._accumulated_gradients = [] + for i, variable in enumerate(variables): + self._trainable_variables_indices[self._var_key(variable)] = i + if self.gradient_accumulation_steps: + self._accumulated_gradients.append( + self.add_variable_from_reference( + variable, + name="gradient_accumulator", + ) + ) + self._trainable_variables = variables[:] + self.built = True + + def _var_key(self, variable): + # Helper function to get a stable ID and the variable instance mapping. + return id(variable) + + @property + def variables(self): + return self._variables[:] + + def _get_variable_index(self, variable): + return self._trainable_variables_indices[self._var_key(variable)] + + def add_variable( + self, + shape, + initializer="zeros", + dtype=None, + aggregation="none", + layout=None, + name=None, + ): + """Add a variable to the optimizer. + + Args: + shape: Shape tuple for the variable. Must be fully-defined + (no `None` entries). + initializer: Initializer object to use to populate the initial + variable value, or string name of a built-in initializer + (e.g. `"random_normal"`). Defaults to `"zeros"`. + dtype: Dtype of the variable to create, e.g. `"float32"`. If + unspecified, defaults to the `keras.backend.floatx()`. + aggregation: Optional string, one of `None`, `"none"`, `"mean"`, + `"sum"` or `"only_first_replica"`. Annotates the variable with + the type of multi-replica aggregation to be used for this + variable when writing custom data parallel training loops. + Defaults to `"none"`. + layout: Optional tensor layout. Defaults to `None`. + name: String name of the variable. Useful for debugging purposes. + + Returns: + An optimizer variable, in the format of `keras.Variable`. + """ + self._check_super_called() + initializer = initializers.get(initializer) + with backend.name_scope(self.name, caller=self): + variable = backend.Variable( + initializer=initializer, + shape=shape, + dtype=dtype, + trainable=False, + aggregation=aggregation, + layout=layout, + name=name, + ) + self._track_variable(variable) + return variable + + def add_variable_from_reference( + self, reference_variable, name=None, initializer="zeros" + ): + """Add an optimizer variable from the model variable. + + Create an optimizer variable based on the information of model variable. + For example, in SGD optimizer momemtum, for each model variable, a + corresponding momemtum variable is created of the same shape and dtype. + + Args: + reference_variable: `keras.Variable`. The corresponding model + variable to the optimizer variable to be created. + name: Optional string. The name prefix of the optimizer variable to + be created. If not provided, it will be set to `"var"`. The + variable name will follow the pattern + `{variable_name}_{reference_variable.name}`, + e.g., `momemtum/dense_1`. Defaults to `None`. + initializer: Initializer object to use to populate the initial + variable value, or string name of a built-in initializer + (e.g. `"random_normal"`). If unspecified, defaults to + `"zeros"`. + + Returns: + An optimizer variable, in the format of `keras.Variable`. + """ + name = name or "var" + if hasattr(reference_variable, "path"): + name = f"{reference_variable.path.replace('/', '_')}_{name}" + else: + sanitised_ref_name = ( + str(reference_variable.name).replace("/", "_").replace(":", "_") + ) + name = f"{sanitised_ref_name}_{name}" + return self.add_variable( + shape=reference_variable.shape, + initializer=initializer, + dtype=reference_variable.dtype, + name=name, + layout=getattr(reference_variable, "_layout", None), + ) + + def add_optimizer_variables( + self, trainable_variables, name, initializer="zeros" + ): + """Add optimizer variables from the list of trainable model variables. + + Create an optimizer variable based on the information of the supplied + model variables. For example, in SGD optimizer momemtum, for each model + variable, a corresponding momemtum variable is created of the same shape + and dtype. + + Note that trainable variables with `v.overwrite_with_gradient == True` + will insert `None`, into the output list, since the optimizer variable + will not be used anyways, and could be wasteful. + + Args: + trainable_variables: `keras.Variable`, the corresponding model + variable to the optimizer variable to be created. + name: The name prefix(es) of the optimizer variable(s) to be + created. Can be a single string or list of strings. If a + list of strings, will create an optimizer variable for each + prefix. The variable name will follow the pattern + `{variable_name}_{trainable_variable.name}`, e.g., + `momemtum/dense_1`. + initializer: Initializer object(s) to use to populate the initial + variable value(s), or string name of a built-in initializer + (e.g. `"random_normal"`). If unspecified, defaults to + `"zeros"`. + + Returns: + A list of optimizer variables, in the format of `keras.Variable`s. + If multiple names are provide, returns a tuple of lists. + """ + name_list = name + initializer_list = initializer + if isinstance(name, str): + # Single name/initializer. + name_list = [name] + initializer_list = [initializer] + else: + # Multiple names/initializers. + # If there is only one initializer, use it for all names. + if isinstance(initializer, str) or isinstance( + initializer, initializers.Initializer + ): + initializer_list = [initializer] * len(name_list) + + if len(name_list) != len(initializer_list): + raise ValueError( + f"The number of provided names must match the number of " + f"provided initializers. Received name='{name}', " + f"initializer='{initializer}'" + ) + + # Build up lists of optimizer variables. + optimizer_variables = tuple([] for _ in name_list) + for variable in trainable_variables: + # Interleaves adding variables for backward-compatibility. + if not self._overwrite_variable_with_gradient(variable): + for i, (var_name, var_init) in enumerate( + zip(name_list, initializer_list) + ): + optimizer_variables[i].append( + self.add_variable_from_reference( + variable, + name=var_name, + initializer=var_init, + ) + ) + else: + for i in range(len(name_list)): + optimizer_variables[i].append(None) + + # If single input name, return the single list. + if isinstance(name, str): + return optimizer_variables[0] + + return optimizer_variables + + def _check_variables_are_known(self, variables): + for v in variables: + if self._var_key(v) not in self._trainable_variables_indices: + raise ValueError( + f"Unknown variable: {v}. This optimizer can only " + "be called for the variables it was originally built with. " + "When working with a new set of variables, you should " + "recreate a new optimizer instance." + ) + + def assign(self, variable, value): + """Assign a value to a variable. + + This should be used in optimizers instead of `variable.assign(value)` to + support backend specific optimizations. + Note that the variable can be a model variable or an optimizer variable; + it can be a backend native variable or a Keras variable. + + Args: + variable: The variable to update. + value: The value to add to the variable. + """ + variable.assign(value) + + def assign_add(self, variable, value): + """Add a value to a variable. + + This should be used in optimizers instead of + `variable.assign_add(value)` to support backend specific optimizations. + Note that the variable can be a model variable or an optimizer variable; + it can be a backend native variable or a Keras variable. + + Args: + variable: The variable to update. + value: The value to add to the variable. + """ + variable.assign_add(value) + + def assign_sub(self, variable, value): + """Subtract a value from a variable. + + This should be used in optimizers instead of + `variable.assign_sub(value)` to support backend specific optimizations. + Note that the variable can be a model variable or an optimizer variable; + it can be a backend native variable or a Keras variable. + + Args: + variable: The variable to update. + value: The value to add to the variable. + """ + variable.assign_sub(value) + + def update_step(self, gradient, variable, learning_rate): + raise NotImplementedError + + def apply_gradients(self, grads_and_vars): + grads, trainable_variables = zip(*grads_and_vars) + self.apply(grads, trainable_variables) + # Return iterations for compat with tf.keras. + return self._iterations + + def apply(self, grads, trainable_variables=None): + """Update traininable variables according to provided gradient values. + + `grads` should be a list of gradient tensors + with 1:1 mapping to the list of variables the optimizer was built with. + + `trainable_variables` can be provided + on the first call to build the optimizer. + """ + if len(grads) == 0: + # It is possible that the grad is empty. In this case, + # `apply_gradients` is a no-op. + return + + if trainable_variables is None: + if not self.built: + raise ValueError( + "When passing `grads` without `variables`, the optimizer " + "must already be built on a list of variables. " + "Call `optimizer.build(trainable_variables)` first. " + ) + if len(grads) != len(self._trainable_variables_indices): + raise ValueError( + "When passing `grads` as a list of gradient tensors, the " + f"gradients must match `optimizer.variables` one-to-on. " + f"Received a list of {len(grads)} gradients, but the " + f"optimizer is tracking {len(self._trainable_variables)} " + "trainable variables." + ) + trainable_variables = self._trainable_variables + else: + trainable_variables = list(trainable_variables) + # Optionally build optimizer. + if not self.built: + with backend.name_scope(self.name, caller=self): + self.build(trainable_variables) + self.built = True + self._check_variables_are_known(trainable_variables) + + with backend.name_scope(self.name, caller=self): + # Filter empty gradients. + grads, trainable_variables = self._filter_empty_gradients( + grads, trainable_variables + ) + + # Overwrite targeted variables directly with their gradients if + # their `overwrite_with_gradient` is set. + grads, trainable_variables = ( + self._overwrite_variables_directly_with_gradients( + grads, trainable_variables + ) + ) + + if len(list(grads)) > 0: + # Unscale gradients. + scale = self.loss_scale_factor + if scale is not None: + grads = [g if g is None else g / scale for g in grads] + + # Apply gradient updates. + self._backend_apply_gradients(grads, trainable_variables) + # Apply variable constraints after applying gradients. + for variable in trainable_variables: + if variable.constraint is not None: + variable.assign(variable.constraint(variable)) + + # Update iteration counter. + self._iterations.assign_add(1) + + def _backend_apply_gradients(self, grads, trainable_variables): + """Apply method that can be overridden by different backends. + + JAX overrides it in order to deal with statelessness in gradient + accumulation and EMA handling. + + The below implementation is intended to be generally backend-agnostic, + but may not work with all backends. + + This method does 4 things: + - Call the optimizer's update_step() to update trainable variables + and optimizer variables. + - Update EMA variables, if EMA is configured. + - Update gradient accumulators, if gradient accumulation is configured. + - Update the iteration counter. + """ + if self.gradient_accumulation_steps: + is_update_step = ( + self._iterations + 1 + ) % self.gradient_accumulation_steps == 0 + # `trainable_variables` might have been filtered in previous + # processing steps, so we need to ensure the correct mapping between + # `self._accumulated_gradients` and `trainable_variables` + acc_grads = [ + self._accumulated_gradients[self._get_variable_index(v)] + for v in trainable_variables + ] + + def _update_step_fn(grads, trainable_variables): + # Run update step with accumulated grads + reset accumulators + steps = self.gradient_accumulation_steps + grads = [ + (g + acc_g) / steps for g, acc_g in zip(grads, acc_grads) + ] + + # Apply clipping and weight decay. + grads = self._clip_gradients(grads) + self._apply_weight_decay(trainable_variables) + + self._backend_update_step( + grads, trainable_variables, self.learning_rate + ) + self._backend_reset_gradient_accumulators() + + ops.cond( + is_update_step, + lambda: _update_step_fn(grads, trainable_variables), + lambda: self._backend_increment_gradient_accumulators( + grads, acc_grads + ), + ) + else: + # Apply clipping and weight decay. + grads = self._clip_gradients(grads) + self._apply_weight_decay(trainable_variables) + + # Run update step. + self._backend_update_step( + grads, trainable_variables, self.learning_rate + ) + + if self.use_ema: + self._update_model_variables_moving_average( + self._trainable_variables + ) + if self.ema_overwrite_frequency: + # Only when self.ema_overwrite_frequency is not None, we + # overwrite the model variables. + should_overwrite_model_vars = ( + self.iterations + 1 + ) % self.ema_overwrite_frequency == 0 + ops.cond( + should_overwrite_model_vars, + lambda: self._overwrite_model_variables_with_average_value( + self._trainable_variables + ), + lambda: None, + ) + + def _backend_update_step(self, grads, trainable_variables, learning_rate): + """Collective update_step that can be overridden by the backend. + + It is overridden by torch for performance reasons, and + by TF to support tf.distribute. + """ + for grad, var in zip(grads, trainable_variables): + self.update_step(grad, var, learning_rate) + + def _backend_reset_gradient_accumulators(self): + for g_acc in self._accumulated_gradients: + if g_acc is not None: + g_acc.assign(ops.zeros(g_acc.shape, dtype=g_acc.dtype)) + + def _backend_increment_gradient_accumulators(self, grads, acc_grads): + new_g_accs = [(g + acc_g) for g, acc_g in zip(grads, acc_grads)] + for n_g_acc, g_acc in zip(new_g_accs, acc_grads): + g_acc.assign(n_g_acc) + + def stateless_apply(self, optimizer_variables, grads, trainable_variables): + """Stateless version of `apply` that returns modified variables. + + Args: + optimizer_variables: list of tensors containing the current values + for the optimizer variables. These are native tensors and not + `keras.Variable`s. + grads: list of gradients to apply. + trainable_variables: list of tensors containing the current values + for the model variables. These are native tensors and not + `keras.Variable`s. + + Returns: A tuple containing two list of tensors, the updated + `trainable_variables` and the updated `optimizer_variables`. + """ + self._check_super_called() + + if not self.built: + raise ValueError( + f"To call `stateless_apply`, {self.__class__.__name__} " + "must be built (i.e. its variables must have been created). " + "You can build it via `optimizer.build(trainable_variables)`." + ) + if len(optimizer_variables) != len(self.variables): + raise ValueError( + "Argument `optimizer_variables` must be a list of tensors " + f"corresponding 1:1 to {self.__class__.__name__}().variables. " + f"Received list with length {len(optimizer_variables)}, but " + f"expected {len(self.variables)} variables." + ) + if len(trainable_variables) != len(self._trainable_variables): + raise ValueError( + "Argument `optimizer_variables` must be a list of tensors " + "corresponding 1:1 to the trainable variables list that " + "the optimizer was built with. Received " + f"len(trainable_variables) == {len(trainable_variables)} " + "whereas the optimizer was built with " + f"{len(self._trainable_variables)} variables." + ) + + # Gather variable mapping + mapping = list( + zip(self._trainable_variables, trainable_variables) + ) + list(zip(self.variables, optimizer_variables)) + + # Call in stateless scope + with backend.StatelessScope(state_mapping=mapping) as scope: + self.apply(grads) + + # Gather updated variables + trainable_variables = [] + for v in self._trainable_variables: + new_v = scope.get_current_value(v) + if new_v is not None: + trainable_variables.append(new_v) + else: + trainable_variables.append(v) + optimizer_variables = [] + for v in self.variables: + new_v = scope.get_current_value(v) + if new_v is not None: + optimizer_variables.append(new_v) + else: + optimizer_variables.append(v) + return trainable_variables, optimizer_variables + + def scale_loss(self, loss): + """Scale the loss before computing gradients. + + Scales the loss before gradients are computed in a `train_step`. This + is primarily useful during mixed precision training to prevent numeric + underflow. + """ + if self.loss_scale_factor is not None: + return loss * self.loss_scale_factor + return loss + + @property + def learning_rate(self): + return self._get_current_learning_rate() + + @learning_rate.setter + def learning_rate(self, learning_rate): + if isinstance(self._learning_rate, backend.Variable): + prev_lr_var = self._learning_rate + else: + prev_lr_var = None + if isinstance( + learning_rate, learning_rate_schedule.LearningRateSchedule + ): + self._learning_rate = learning_rate + elif callable(learning_rate): + self._learning_rate = learning_rate + else: + if isinstance( + self._learning_rate, learning_rate_schedule.LearningRateSchedule + ): + raise TypeError( + "This optimizer was created with a `LearningRateSchedule`" + " object as its `learning_rate` constructor argument, " + "hence its learning rate is not settable. If you need the" + " learning rate to be settable, you should instantiate " + "the optimizer with a float `learning_rate` argument." + ) + self._learning_rate.assign(learning_rate) + if prev_lr_var is not None and not isinstance( + self._learning_rate, backend.Variable + ): + # Untrack learning rate variable + self._untrack_variable(prev_lr_var) + + def set_weights(self, weights): + """Set the weights of the optimizer.""" + if not self.built: + raise ValueError( + "You are calling `set_weights()` on an optimizer that has not " + "yet been built. Please call " + "`optimizer.build(trainable_variables)` to create the " + "optimizer weights before calling `set_weights()`." + ) + for variable, weight in zip(self._variables, weights): + if variable.shape != weight.shape: + raise ValueError( + f"Optimizer variable {self._var_key(variable)} has shape " + f"{str(variable.shape)} not compatible with provided " + f"weight shape {str(weight.shape)}." + ) + variable.assign(weight) + + def save_own_variables(self, store): + """Get the state of this optimizer object.""" + for i, variable in enumerate(self.variables): + store[str(i)] = variable.numpy() + + def load_own_variables(self, store): + """Set the state of this optimizer object.""" + if len(store.keys()) != len(self.variables): + msg = ( + f"Skipping variable loading for optimizer '{self.name}', " + f"because it has {len(self.variables)} variables whereas " + f"the saved optimizer has {len(store.keys())} variables. " + ) + if len(self.variables) == 0: + msg += ( + "This is likely because the optimizer has not been " + "called/built yet." + ) + warnings.warn(msg, stacklevel=2) + return + for i, variable in enumerate(self.variables): + variable.assign(store[str(i)]) + + def _get_current_learning_rate(self): + if isinstance( + self._learning_rate, learning_rate_schedule.LearningRateSchedule + ): + return self._learning_rate(self._iterations) + elif isinstance(self._learning_rate, backend.Variable): + return self._learning_rate + elif callable(self._learning_rate): + return self._learning_rate() + return self._learning_rate + + def _overwrite_variables_directly_with_gradients(self, grads, vars): + """Overwrite the variables directly by their gradients. + + This method is designed for a special case where we want to overwrite + the variable directly with its computed gradient. For example, in float8 + training, new `scale` and `amax_history` are computed as gradients, and + we want to overwrite them directly instead of following the typical + procedure such as gradient descent with a learning rate, gradient + clipping and weight decaying. + + After the update, the processed pairs will be filtered out. + """ + # Shortcut for `tf.Variable` because it doesn't have a + # `overwrite_with_gradient` attr. + if not any(self._overwrite_variable_with_gradient(v) for v in vars): + return grads, vars + + # Shallow copies + filtered_grads = list(grads) + filtered_vars = list(vars) + + # Iterate from right to left for safe popping + for i in range(len(filtered_grads) - 1, -1, -1): + g, v = filtered_grads[i], filtered_vars[i] + if self._overwrite_variable_with_gradient(v): + if self.gradient_accumulation_steps: + # Utilize a stateless manner for JAX compatibility + steps = self.gradient_accumulation_steps + is_update_step = (self._iterations + 1) % steps == 0 + acc_g = self._accumulated_gradients[ + self._get_variable_index(v) + ] + # `ops.maximum` is utilized for gradient accumulation for + # `overwrite_with_gradient=True` variables + new_g_acc = ops.cond( + is_update_step, + lambda: ops.zeros(g.shape, dtype=g.dtype), + lambda: ops.maximum(g, acc_g), + ) + new_g = ops.cond( + is_update_step, + lambda: ops.maximum(g, acc_g), + lambda: g, + ) + new_v = ops.cond( + is_update_step, lambda: new_g, lambda: v.value + ) + v.assign(new_v) + acc_g.assign(new_g_acc) + else: + v.assign(g) + filtered_grads.pop(i) + filtered_vars.pop(i) + return filtered_grads, filtered_vars + + def _filter_empty_gradients(self, grads, vars): + filtered_grads = list(grads) + filtered_vars = list(vars) + missing_grad_vars = [] + + # Iterate from right to left for safe popping + for i in range(len(filtered_grads) - 1, -1, -1): + if filtered_grads[i] is None: + filtered_grads.pop(i) + v = filtered_vars.pop(i) + try: + missing_grad_vars.append(v.path) + except AttributeError: + # `tf.Variable` doesn't have `path` attr. + missing_grad_vars.append(v.name) + + if not filtered_grads: + raise ValueError("No gradients provided for any variable.") + if missing_grad_vars: + warnings.warn( + "Gradients do not exist for variables " + f"{list(reversed(missing_grad_vars))} when minimizing the loss." + " If using `model.compile()`, did you forget to provide a " + "`loss` argument?" + ) + return filtered_grads, filtered_vars + + def _clip_gradients(self, grads): + if self.clipnorm and self.clipnorm > 0: + return [ + self._clip_by_norm(g) if g is not None else g for g in grads + ] + elif self.global_clipnorm and self.global_clipnorm > 0: + return clip_by_global_norm(grads, self.global_clipnorm) + elif self.clipvalue and self.clipvalue > 0: + v = self.clipvalue + return [ops.clip(g, -v, v) if g is not None else g for g in grads] + else: + return grads + + def exclude_from_weight_decay(self, var_list=None, var_names=None): + """Exclude variables from weight decay. + + This method must be called before the optimizer's `build` method is + called. You can set specific variables to exclude out, or set a list of + strings as the anchor words, if any of which appear in a variable's + name, then the variable is excluded. + + Args: + var_list: A list of `Variable`s to exclude from weight decay. + var_names: A list of strings. If any string in `var_names` appear + in the model variable's name, then this model variable is + excluded from weight decay. For example, `var_names=['bias']` + excludes all bias variables from weight decay. + """ + if hasattr(self, "_built") and self._built: + raise ValueError( + "`exclude_from_weight_decay()` can only be configured before " + "the optimizer is built." + ) + + # Use a `set` for the ids of `var_list` to speed up the searching + if var_list: + self._exclude_from_weight_decay = set( + self._var_key(variable) for variable in var_list + ) + else: + self._exclude_from_weight_decay = set() + + # Precompile the pattern for `var_names` to speed up the searching + if var_names and len(var_names) > 0: + self._exclude_from_weight_decay_pattern = re.compile( + "|".join(set(var_names)) + ) + else: + self._exclude_from_weight_decay_pattern = None + + # Reset cache + self._exclude_from_weight_decay_cache = dict() + + def _use_weight_decay(self, variable): + variable_id = self._var_key(variable) + + # Immediately return the value if `variable_id` hits the cache + if not hasattr(self, "_exclude_from_weight_decay_cache"): + self._exclude_from_weight_decay_cache = dict() + if variable_id in self._exclude_from_weight_decay_cache: + return self._exclude_from_weight_decay_cache[variable_id] + + # Determine whether the variable should apply weight decay or not + exclude_from_weight_decay = getattr( + self, "_exclude_from_weight_decay", set() + ) + exclude_from_weight_decay_pattern = getattr( + self, "_exclude_from_weight_decay_pattern", None + ) + if variable_id in exclude_from_weight_decay: + self._exclude_from_weight_decay_cache[variable_id] = False + return False + if exclude_from_weight_decay_pattern is not None: + if ( + re.search(exclude_from_weight_decay_pattern, variable.name) + is not None + ): + self._exclude_from_weight_decay_cache[variable_id] = False + return False + self._exclude_from_weight_decay_cache[variable_id] = True + return True + + def _apply_weight_decay(self, variables): + if self.weight_decay is None: + return + for variable in variables: + if self._use_weight_decay(variable): + lr = ops.cast(self.learning_rate, variable.dtype) + wd = ops.cast(self.weight_decay, variable.dtype) + variable.assign(variable - variable * wd * lr) + + def _check_super_called(self): + if not hasattr(self, "_lock"): + raise RuntimeError( + f"In optimizer '{self.__class__.__name__}', you forgot to call " + "`super().__init__()` as the first statement " + "in the `__init__()` method. " + "Go add it!" + ) + + def _update_model_variables_moving_average(self, trainable_variables): + """Update the stored moving average using the latest value.""" + if self.use_ema: + for var, average in zip( + trainable_variables, self._model_variables_moving_average + ): + if average is not None: + not_first_step = ops.not_equal(self.iterations, 0) + momentum = ops.multiply( + ops.cast(not_first_step, var.dtype), self.ema_momentum + ) + average.assign( + ops.add( + ops.multiply(momentum, average), + ops.multiply(ops.subtract(1, momentum), var), + ) + ) + + def _overwrite_model_variables_with_average_value( + self, trainable_variables + ): + """Overwrite model variables with its moving average.""" + if len(trainable_variables) != len( + self._model_variables_moving_average + ): + raise ValueError( + f"The length of model variables ({len(trainable_variables)}) " + "to override does not match the length of model variables " + "stored in the optimizer " + f"({len(self._model_variables_moving_average)}). Please " + "check if the optimizer was called on your model." + ) + for var, average_var in zip( + trainable_variables, self._model_variables_moving_average + ): + if average_var is not None: + var.assign(average_var) + + def finalize_variable_values(self, var_list): + """Set the final value of model's trainable variables. + + Sometimes there are some extra steps before ending the variable updates, + such as overriding the model variables with its average value. + + Args: + var_list: list of model variables. + """ + if self.use_ema: + # If the optimizer uses EMA, then when finalizing, we replace the + # model variable value with its moving average stored inside + # optimizer. + self._overwrite_model_variables_with_average_value(var_list) + + def _obj_type(self): + return "Optimizer" + + def get_config(self): + """Returns the config of the optimizer. + + An optimizer config is a Python dictionary (serializable) + containing the configuration of an optimizer. + The same optimizer can be reinstantiated later + (without any saved state) from this configuration. + + Subclass optimizer should override this method to include other + hyperparameters. + + Returns: + Python dictionary. + """ + + if isinstance( + self._learning_rate, learning_rate_schedule.LearningRateSchedule + ): + learning_rate = learning_rate_schedule.serialize( + self._learning_rate + ) + elif isinstance(self._learning_rate, backend.Variable): + learning_rate = float(self._learning_rate.numpy()) + elif ops.is_tensor(self._learning_rate): + learning_rate = float(self._learning_rate) + elif callable(self._learning_rate): + learning_rate = serialization_lib.serialize_keras_object( + self._learning_rate + ) + else: + learning_rate = 0.5 + + config = { + "name": self.name, + "learning_rate": learning_rate, + "weight_decay": self.weight_decay, + "clipnorm": self.clipnorm, + "global_clipnorm": self.global_clipnorm, + "clipvalue": self.clipvalue, + "use_ema": self.use_ema, + "ema_momentum": self.ema_momentum, + "ema_overwrite_frequency": self.ema_overwrite_frequency, + "loss_scale_factor": self.loss_scale_factor, + "gradient_accumulation_steps": self.gradient_accumulation_steps, + } + return config + + @classmethod + def from_config(cls, config, custom_objects=None): + """Creates an optimizer from its config. + + This method is the reverse of `get_config`, capable of instantiating the + same optimizer from the config dictionary. + + Args: + config: A Python dictionary, typically the output of get_config. + custom_objects: A Python dictionary mapping names to additional + user-defined Python objects needed to recreate this optimizer. + + Returns: + An optimizer instance. + """ + if "learning_rate" in config: + if isinstance(config["learning_rate"], dict): + config["learning_rate"] = ( + serialization_lib.deserialize_keras_object( + config["learning_rate"], custom_objects=custom_objects + ) + ) + return cls(**config) + + def __setattr__(self, name, value): + # Prevent users from attaching state to the + # layer before `super()` is called -- since that + # state would silently not be tracked. + if name != "_lock": + self._check_super_called() + # Track Variables. + if hasattr(self, "_tracker"): + value = self._tracker.track(value) + return super().__setattr__(name, value) + + def _clip_by_norm(self, values, axes=None): + # Calculate L2-norm, clip elements by ratio of clip_norm to L2-norm + l2sum = ops.sum(ops.square(values), axes, keepdims=True) + pred = l2sum > 0 + # Two-tap tf.where trick to bypass NaN gradients + l2sum_safe = ops.where(pred, l2sum, ops.ones_like(l2sum)) + l2norm = ops.where(pred, ops.sqrt(l2sum_safe), l2sum) + intermediate = ops.multiply(values, self.clipnorm) + values_clip = ops.convert_to_tensor(intermediate) / ops.maximum( + l2norm, self.clipnorm + ) + return values_clip + + def _untrack_variable(self, variable): + previous_lock_state = self._tracker.locked + self._tracker.unlock() + self._tracker.untrack(variable) + if previous_lock_state is True: + self._tracker.lock() + + +base_optimizer_keyword_args = """name: String. The name to use + for momentum accumulator weights created by + the optimizer. + weight_decay: Float. If set, weight decay is applied. + clipnorm: Float. If set, the gradient of each weight is individually + clipped so that its norm is no higher than this value. + clipvalue: Float. If set, the gradient of each weight is clipped to be + no higher than this value. + global_clipnorm: Float. If set, the gradient of all weights is clipped + so that their global norm is no higher than this value. + use_ema: Boolean, defaults to `False`. + If `True`, exponential moving average + (EMA) is applied. EMA consists of computing an exponential moving + average of the weights of the model (as the weight values change + after each training batch), and periodically overwriting the + weights with their moving average. + ema_momentum: Float, defaults to 0.99. Only used if `use_ema=True`. + This is the momentum to use when computing + the EMA of the model's weights: + `new_average = ema_momentum * old_average + (1 - ema_momentum) * + current_variable_value`. + ema_overwrite_frequency: Int or None, defaults to None. Only used if + `use_ema=True`. Every `ema_overwrite_frequency` steps of iterations, + we overwrite the model variable by its moving average. + If None, the optimizer + does not overwrite model variables in the middle of training, + and you need to explicitly overwrite the variables + at the end of training by calling + `optimizer.finalize_variable_values()` (which updates the model + variables in-place). When using the built-in `fit()` training loop, + this happens automatically after the last epoch, + and you don't need to do anything. + loss_scale_factor: Float or `None`. If a float, the scale factor will + be multiplied the loss before computing gradients, and the inverse + of the scale factor will be multiplied by the gradients before + updating variables. Useful for preventing underflow during + mixed precision training. Alternately, + `keras.optimizers.LossScaleOptimizer` will + automatically set a loss scale factor. + gradient_accumulation_steps: Int or `None`. If an int, model & optimizer + variables will not be updated at every step; instead they will be + updated every `gradient_accumulation_steps` steps, using the average + value of the gradients since the last update. This is known as + "gradient accumulation". This can be useful + when your batch size is very small, in order to reduce gradient + noise at each update step. EMA frequency will look at "accumulated" + iterations value (optimizer steps // gradient_accumulation_steps). + Learning rate schedules will look at "real" iterations value + (optimizer steps). +""" + + +def global_norm(value_list): + """Computes the global norm of multiple tensors.""" + squared_norms = [ + ops.sum(ops.square(v)) for v in value_list if v is not None + ] + squared_norm = ops.sum(ops.stack(squared_norms)) + return ops.sqrt(squared_norm) + + +def clip_by_global_norm(value_list, clip_norm): + use_norm = global_norm(value_list) + # Calculate L2-norm, clip elements by ratio of clip_norm to L2-norm + scale_for_finite = clip_norm * ops.minimum(1.0 / use_norm, 1.0 / clip_norm) + # If use_norm is any finite number, this is a no-op. For inf/-inf/NaN, + # this will make scale NaN. + scale = scale_for_finite + (use_norm - use_norm) + return [v * scale if v is not None else v for v in value_list] diff --git a/keras/src/optimizers/ftrl.py b/keras/src/optimizers/ftrl.py new file mode 100644 index 000000000000..6bef848a905b --- /dev/null +++ b/keras/src/optimizers/ftrl.py @@ -0,0 +1,239 @@ +from keras.src import initializers +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.optimizers import optimizer + + +@keras_export(["keras.optimizers.Ftrl"]) +class Ftrl(optimizer.Optimizer): + r"""Optimizer that implements the FTRL algorithm. + + "Follow The Regularized Leader" (FTRL) is an optimization algorithm + developed at Google for click-through rate prediction in the early 2010s. It + is most suitable for shallow models with large and sparse feature spaces. + The algorithm is described by + [McMahan et al., 2013](https://research.google.com/pubs/archive/41159.pdf). + The Keras version has support for both online L2 regularization + (the L2 regularization described in the paper + above) and shrinkage-type L2 regularization + (which is the addition of an L2 penalty to the loss function). + + Initialization: + + ```python + n = 0 + sigma = 0 + z = 0 + ``` + + Update rule for one variable `w`: + + ```python + prev_n = n + n = n + g ** 2 + sigma = (n ** -lr_power - prev_n ** -lr_power) / lr + z = z + g - sigma * w + if abs(z) < lambda_1: + w = 0 + else: + w = (sgn(z) * lambda_1 - z) / ((beta + sqrt(n)) / alpha + lambda_2) + ``` + + Notation: + + - `lr` is the learning rate + - `g` is the gradient for the variable + - `lambda_1` is the L1 regularization strength + - `lambda_2` is the L2 regularization strength + - `lr_power` is the power to scale n. + + Check the documentation for the `l2_shrinkage_regularization_strength` + parameter for more details when shrinkage is enabled, in which case gradient + is replaced with a gradient with shrinkage. + + Args: + learning_rate: A float, a + `keras.optimizers.schedules.LearningRateSchedule` instance, or + a callable that takes no arguments and returns the actual value to + use. The learning rate. Defaults to `0.001`. + learning_rate_power: A float value, must be less or equal to zero. + Controls how the learning rate decreases during training. Use zero + for a fixed learning rate. + initial_accumulator_value: The starting value for accumulators. Only + zero or positive values are allowed. + l1_regularization_strength: A float value, must be greater than or equal + to zero. Defaults to `0.0`. + l2_regularization_strength: A float value, must be greater than or equal + to zero. Defaults to `0.0`. + l2_shrinkage_regularization_strength: A float value, must be greater + than or equal to zero. This differs from L2 above in that the L2 + above is a stabilization penalty, whereas this L2 shrinkage is a + magnitude penalty. When input is sparse shrinkage will only happen + on the active weights. + beta: A float value, representing the beta value from the paper. + Defaults to `0.0`. + {{base_optimizer_keyword_args}} + """ + + def __init__( + self, + learning_rate=0.001, + learning_rate_power=-0.5, + initial_accumulator_value=0.1, + l1_regularization_strength=0.0, + l2_regularization_strength=0.0, + l2_shrinkage_regularization_strength=0.0, + beta=0.0, + weight_decay=None, + clipnorm=None, + clipvalue=None, + global_clipnorm=None, + use_ema=False, + ema_momentum=0.99, + ema_overwrite_frequency=None, + loss_scale_factor=None, + gradient_accumulation_steps=None, + name="ftrl", + **kwargs, + ): + super().__init__( + learning_rate=learning_rate, + name=name, + weight_decay=weight_decay, + clipnorm=clipnorm, + clipvalue=clipvalue, + global_clipnorm=global_clipnorm, + use_ema=use_ema, + ema_momentum=ema_momentum, + ema_overwrite_frequency=ema_overwrite_frequency, + loss_scale_factor=loss_scale_factor, + gradient_accumulation_steps=gradient_accumulation_steps, + **kwargs, + ) + + if initial_accumulator_value < 0.0: + raise ValueError( + "`initial_accumulator_value` needs to be positive or zero. " + "Received: initial_accumulator_value=" + f"{initial_accumulator_value}." + ) + if learning_rate_power > 0.0: + raise ValueError( + "`learning_rate_power` needs to be negative or zero. Received: " + f"learning_rate_power={learning_rate_power}." + ) + if l1_regularization_strength < 0.0: + raise ValueError( + "`l1_regularization_strength` needs to be positive or zero. " + "Received: l1_regularization_strength=" + f"{l1_regularization_strength}." + ) + if l2_regularization_strength < 0.0: + raise ValueError( + "`l2_regularization_strength` needs to be positive or zero. " + "Received: l2_regularization_strength=" + f"{l2_regularization_strength}." + ) + if l2_shrinkage_regularization_strength < 0.0: + raise ValueError( + "`l2_shrinkage_regularization_strength` needs to be positive " + "or zero. Received: l2_shrinkage_regularization_strength" + f"={l2_shrinkage_regularization_strength}." + ) + + self.learning_rate_power = learning_rate_power + self.initial_accumulator_value = initial_accumulator_value + self.l1_regularization_strength = l1_regularization_strength + self.l2_regularization_strength = l2_regularization_strength + self.l2_shrinkage_regularization_strength = ( + l2_shrinkage_regularization_strength + ) + self.beta = beta + + def build(self, var_list): + """Initialize optimizer variables. + + Args: + var_list: list of model variables to build Ftrl variables on. + """ + if self.built: + return + super().build(var_list) + accumulator_initializer = initializers.Constant( + self.initial_accumulator_value, + ) + self._accumulators, self._linears = self.add_optimizer_variables( + var_list, + ["accumulator", "linear"], + initializer=[accumulator_initializer, "zeros"], + ) + + def update_step(self, gradient, variable, learning_rate): + """Update step given gradient and the associated model variable.""" + + lr = ops.cast(learning_rate, variable.dtype) + gradient = ops.cast(gradient, variable.dtype) + + accum = self._accumulators[self._get_variable_index(variable)] + linear = self._linears[self._get_variable_index(variable)] + + lr_power = self.learning_rate_power + l2_reg = self.l2_regularization_strength + l2_reg = l2_reg + self.beta / (2.0 * lr) + + grad_to_use = ops.add( + gradient, + ops.multiply( + 2 * self.l2_shrinkage_regularization_strength, variable + ), + ) + new_accum = ops.add(accum, ops.square(gradient)) + self.assign_add( + linear, + ops.subtract( + grad_to_use, + ops.multiply( + ops.divide( + ops.subtract( + ops.power(new_accum, -lr_power), + ops.power(accum, -lr_power), + ), + lr, + ), + variable, + ), + ), + ) + quadratic = ops.add( + ops.divide(ops.power(new_accum, (-lr_power)), lr), 2 * l2_reg + ) + linear_clipped = ops.clip( + linear, + -self.l1_regularization_strength, + self.l1_regularization_strength, + ) + self.assign( + variable, + ops.divide(ops.subtract(linear_clipped, linear), quadratic), + ) + self.assign(accum, new_accum) + + def get_config(self): + config = super().get_config() + + config.update( + { + "learning_rate_power": self.learning_rate_power, + "initial_accumulator_value": self.initial_accumulator_value, + "l1_regularization_strength": self.l1_regularization_strength, + "l2_regularization_strength": self.l2_regularization_strength, + "l2_shrinkage_regularization_strength": self.l2_shrinkage_regularization_strength, # noqa: E501 + "beta": self.beta, + } + ) + return config + + +Ftrl.__doc__ = Ftrl.__doc__.replace( + "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args +) diff --git a/keras/src/optimizers/ftrl_test.py b/keras/src/optimizers/ftrl_test.py new file mode 100644 index 000000000000..379ecd97d82f --- /dev/null +++ b/keras/src/optimizers/ftrl_test.py @@ -0,0 +1,114 @@ +# flake8: noqa + + +import numpy as np +from unittest import mock + +from keras.src import backend +from keras.src import testing +from keras.src.optimizers.ftrl import Ftrl + + +class FtrlTest(testing.TestCase): + def test_config(self): + optimizer = Ftrl( + learning_rate=0.05, + learning_rate_power=-0.2, + initial_accumulator_value=0.4, + l1_regularization_strength=0.05, + l2_regularization_strength=0.15, + l2_shrinkage_regularization_strength=0.01, + beta=0.3, + ) + self.run_class_serialization_test(optimizer) + + def test_single_step(self): + optimizer = Ftrl(learning_rate=0.5) + grads = np.array([1.0, 6.0, 7.0, 2.0]) + vars = backend.Variable([1.0, 2.0, 3.0, 4.0]) + optimizer.apply_gradients(zip([grads], [vars])) + self.assertAllClose( + vars, [0.2218, 1.3954, 2.3651, 2.8814], rtol=1e-4, atol=1e-4 + ) + + def test_correctness_with_golden(self): + optimizer = Ftrl( + learning_rate=0.05, + learning_rate_power=-0.2, + initial_accumulator_value=0.4, + l1_regularization_strength=0.05, + l2_regularization_strength=0.15, + l2_shrinkage_regularization_strength=0.01, + beta=0.3, + ) + + x = backend.Variable(np.ones([10])) + grads = np.arange(0.1, 1.1, 0.1) + first_grads = np.full((10,), 0.01) + + # fmt: off + golden = np.array( + [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [-0.0034, -0.0077, -0.0118, -0.0157, -0.0194, -0.023, -0.0263, -0.0294, -0.0325, -0.0354], + [-0.0078, -0.0162, -0.0242, -0.0317, -0.0387, -0.0454, -0.0516, -0.0575, -0.0631, -0.0685], + [-0.0121, -0.0246, -0.0363, -0.0472, -0.0573, -0.0668, -0.0757, -0.0842, -0.0922, -0.0999], + [-0.0164, -0.0328, -0.0481, -0.0623, -0.0753, -0.0875, -0.099, -0.1098, -0.1201, -0.1299]] + ) + # fmt: on + + optimizer.apply_gradients(zip([first_grads], [x])) + for i in range(5): + self.assertAllClose(x, golden[i], rtol=5e-4, atol=5e-4) + optimizer.apply_gradients(zip([grads], [x])) + + def test_clip_norm(self): + optimizer = Ftrl(clipnorm=1) + grad = [np.array([100.0, 100.0])] + clipped_grad = optimizer._clip_gradients(grad) + self.assertAllClose(clipped_grad[0], [2**0.5 / 2, 2**0.5 / 2]) + + def test_clip_value(self): + optimizer = Ftrl(clipvalue=1) + grad = [np.array([100.0, 100.0])] + clipped_grad = optimizer._clip_gradients(grad) + self.assertAllClose(clipped_grad[0], [1.0, 1.0]) + + def test_invalid_initial_accumulator_value(self): + invalid_value = -0.1 + with self.assertRaisesRegex( + ValueError, + f"^`initial_accumulator_value` needs to be positive or zero. Received: initial_accumulator_value={invalid_value}.$", + ): + Ftrl(initial_accumulator_value=invalid_value) + + def test_invalid_learning_rate_power(self): + invalid_value = 0.1 + with self.assertRaisesRegex( + ValueError, + f"^`learning_rate_power` needs to be negative or zero. Received: learning_rate_power={invalid_value}.$", + ): + Ftrl(learning_rate_power=invalid_value) + + def test_invalid_l1_regularization_strength(self): + invalid_value = -0.1 + with self.assertRaisesRegex( + ValueError, + f"^`l1_regularization_strength` needs to be positive or zero. Received: l1_regularization_strength={invalid_value}.$", + ): + Ftrl(l1_regularization_strength=invalid_value) + + def test_invalid_l2_regularization_strength(self): + invalid_value = -0.1 + with self.assertRaisesRegex( + ValueError, + f"^`l2_regularization_strength` needs to be positive or zero. Received: l2_regularization_strength={invalid_value}.$", + ): + Ftrl(l2_regularization_strength=invalid_value) + + def test_invalid_l2_shrinkage_regularization_strength(self): + invalid_value = -0.1 + with self.assertRaisesRegex( + ValueError, + f"^`l2_shrinkage_regularization_strength` needs to be positive or zero. Received: l2_shrinkage_regularization_strength={invalid_value}.$", + ): + Ftrl(l2_shrinkage_regularization_strength=invalid_value) diff --git a/keras/src/optimizers/lamb.py b/keras/src/optimizers/lamb.py new file mode 100644 index 000000000000..5a4e1f3958d5 --- /dev/null +++ b/keras/src/optimizers/lamb.py @@ -0,0 +1,148 @@ +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.optimizers import optimizer + + +@keras_export("keras.optimizers.Lamb") +class Lamb(optimizer.Optimizer): + """Optimizer that implements the Lamb algorithm. + + Lamb is a stochastic gradient descent method that + uses layer-wise adaptive moments to adjusts the + learning rate for each parameter based on the ratio of the + norm of the weight to the norm of the gradient + This helps to stabilize the training process and improves convergence + especially for large batch sizes. + + Args: + learning_rate: A float, a + `keras.optimizers.schedules.LearningRateSchedule` instance, or + a callable that takes no arguments and returns the actual value to + use. The learning rate. Defaults to `0.001`. + beta_1: A float value or a constant float tensor, or a callable + that takes no arguments and returns the actual value to use. The + exponential decay rate for the 1st moment estimates. Defaults to + `0.9`. + beta_2: A float value or a constant float tensor, or a callable + that takes no arguments and returns the actual value to use. The + exponential decay rate for the 2nd moment estimates. Defaults to + `0.999`. + epsilon: A small constant for numerical stability. + Defaults to `1e-7`. + {{base_optimizer_keyword_args}} + + References: + - [Yang et al.](https://arxiv.org/pdf/1904.00962) + """ + + def __init__( + self, + learning_rate=0.001, + beta_1=0.9, + beta_2=0.999, + epsilon=1e-7, + weight_decay=None, + clipnorm=None, + clipvalue=None, + global_clipnorm=None, + use_ema=False, + ema_momentum=0.99, + ema_overwrite_frequency=None, + loss_scale_factor=None, + gradient_accumulation_steps=None, + name="lamb", + **kwargs, + ): + super().__init__( + learning_rate=learning_rate, + name=name, + weight_decay=weight_decay, + clipnorm=clipnorm, + clipvalue=clipvalue, + global_clipnorm=global_clipnorm, + use_ema=use_ema, + ema_momentum=ema_momentum, + ema_overwrite_frequency=ema_overwrite_frequency, + loss_scale_factor=loss_scale_factor, + gradient_accumulation_steps=gradient_accumulation_steps, + **kwargs, + ) + self.beta_1 = beta_1 + self.beta_2 = beta_2 + self.epsilon = epsilon + + def build(self, var_list): + """Initialize optimizer variables. + + Lamb optimizer has 2 types of variables: momentums and velocities + + Args: + var_list: list of model variables to build Lamb variables on. + """ + if self.built: + return + super().build(var_list) + self._momentums, self._velocities = self.add_optimizer_variables( + var_list, ["momentum", "velocity"] + ) + + def update_step(self, gradient, variable, learning_rate): + """Update step given gradient and the associated model variable.""" + lr = ops.cast(learning_rate, variable.dtype) + gradient = ops.cast(gradient, variable.dtype) + local_step = ops.cast(self.iterations + 1, variable.dtype) + + beta_1_power = ops.power( + ops.cast(self.beta_1, variable.dtype), local_step + ) + beta_2_power = ops.power( + ops.cast(self.beta_2, variable.dtype), local_step + ) + + m = self._momentums[self._get_variable_index(variable)] + v = self._velocities[self._get_variable_index(variable)] + + self.assign_add( + m, ops.multiply(ops.subtract(gradient, m), 1 - self.beta_1) + ) + + self.assign_add( + v, + ops.multiply( + ops.subtract(ops.square(gradient), v), 1 - self.beta_2 + ), + ) + + m_t_hat = ops.divide(m, (1.0 - beta_1_power)) + v_sqrt = ops.add( + ops.sqrt(ops.divide(v, (1.0 - beta_2_power))), self.epsilon + ) + + update = ops.divide(m_t_hat, v_sqrt) + w_norm = ops.sqrt(ops.sum(ops.power(variable, 2))) + g_norm = ops.sqrt(ops.sum(ops.power(update, 2))) + + # ratio = w_norm / g_norm if w_norm > 0 and g_norm > 0 else 1 + ratio = ops.where( + ops.greater(w_norm, 0), + ops.where(ops.greater(g_norm, 0), (w_norm / g_norm), 1.0), + 1.0, + ) + + self.assign_sub(variable, ratio * lr * update) + + def get_config(self): + config = super().get_config() + config.update( + { + "beta_1": self.beta_1, + "beta_2": self.beta_2, + "epsilon": self.epsilon, + } + ) + return config + + +Lamb.__doc__ = Lamb.__doc__.replace( + "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args +) diff --git a/keras/src/optimizers/lamb_test.py b/keras/src/optimizers/lamb_test.py new file mode 100644 index 000000000000..682c2aeadbbb --- /dev/null +++ b/keras/src/optimizers/lamb_test.py @@ -0,0 +1,76 @@ +import numpy as np + +from keras.src import backend +from keras.src import ops +from keras.src import testing +from keras.src.optimizers.lamb import Lamb + + +class LambTest(testing.TestCase): + def test_config(self): + optimizer = Lamb( + learning_rate=0.5, + beta_1=0.5, + beta_2=0.67, + epsilon=1e-5, + ) + self.run_class_serialization_test(optimizer) + + def test_single_step(self): + optimizer = Lamb(learning_rate=0.5) + grads = ops.array([1.0, 6.0, 7.0, 2.0]) + vars = backend.Variable([1.0, 2.0, 3.0, 4.0]) + optimizer.apply_gradients(zip([grads], [vars])) + self.assertAllClose( + vars, [-0.3693, 0.6306, 1.6306, 2.6306], rtol=1e-4, atol=1e-4 + ) + + def test_weight_decay(self): + grads, var1, var2, var3 = ( + ops.zeros(()), + backend.Variable(2.0), + backend.Variable(2.0, name="exclude"), + backend.Variable(2.0), + ) + optimizer_1 = Lamb(learning_rate=1.0, weight_decay=0.004) + optimizer_1.apply_gradients(zip([grads], [var1])) + + optimizer_2 = Lamb(learning_rate=1.0, weight_decay=0.004) + optimizer_2.exclude_from_weight_decay(var_names=["exclude"]) + optimizer_2.apply_gradients(zip([grads, grads], [var1, var2])) + + optimizer_3 = Lamb(learning_rate=1.0, weight_decay=0.004) + optimizer_3.exclude_from_weight_decay(var_list=[var3]) + optimizer_3.apply_gradients(zip([grads, grads], [var1, var3])) + + self.assertAlmostEqual(var1.numpy(), 1.9760959, decimal=6) + self.assertAlmostEqual(var2.numpy(), 2.0, decimal=6) + self.assertAlmostEqual(var3.numpy(), 2.0, decimal=6) + + def test_correctness_with_golden(self): + optimizer = Lamb() + + x = backend.Variable(np.ones([10], dtype="float32")) + grads = ops.arange(0.1, 1.1, 0.1) + first_grads = ops.full((10,), 0.01) + + golden = np.tile( + [[0.999], [0.9982], [0.9974], [0.9965], [0.9955]], (1, 10) + ) + + optimizer.apply_gradients(zip([first_grads], [x])) + for i in range(5): + self.assertAllClose(x, golden[i], rtol=5e-4, atol=5e-4) + optimizer.apply_gradients(zip([grads], [x])) + + def test_clip_norm(self): + optimizer = Lamb(clipnorm=1) + grad = [np.array([100.0, 100.0])] + clipped_grad = optimizer._clip_gradients(grad) + self.assertAllClose(clipped_grad[0], [2**0.5 / 2, 2**0.5 / 2]) + + def test_clip_value(self): + optimizer = Lamb(clipvalue=1) + grad = [np.array([100.0, 100.0])] + clipped_grad = optimizer._clip_gradients(grad) + self.assertAllClose(clipped_grad[0], [1.0, 1.0]) diff --git a/keras/src/optimizers/lion.py b/keras/src/optimizers/lion.py new file mode 100644 index 000000000000..5c798eb71355 --- /dev/null +++ b/keras/src/optimizers/lion.py @@ -0,0 +1,136 @@ +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.optimizers import optimizer + + +@keras_export(["keras.optimizers.Lion"]) +class Lion(optimizer.Optimizer): + """Optimizer that implements the Lion algorithm. + + The Lion optimizer is a stochastic-gradient-descent method that uses the + sign operator to control the magnitude of the update, unlike other adaptive + optimizers such as Adam that rely on second-order moments. This makes + Lion more memory-efficient as it only keeps track of the momentum. According + to the authors (see reference), its performance gain over Adam grows with + the batch size. Because the update of Lion is produced through the sign + operation, resulting in a larger norm, a suitable learning rate for Lion is + typically 3-10x smaller than that for AdamW. The weight decay for Lion + should in turn be 3-10x larger than that for AdamW to maintain a + similar strength (lr * wd). + + Args: + learning_rate: A float, a + `keras.optimizers.schedules.LearningRateSchedule` instance, or + a callable that takes no arguments and returns the actual value to + use. The learning rate. Defaults to `0.001`. + beta_1: A float value or a constant float tensor, or a callable + that takes no arguments and returns the actual value to use. The + rate to combine the current gradient and the 1st moment estimate. + Defaults to `0.9`. + beta_2: A float value or a constant float tensor, or a callable + that takes no arguments and returns the actual value to use. The + exponential decay rate for the 1st moment estimate. Defaults to + `0.99`. + {{base_optimizer_keyword_args}} + + References: + + - [Chen et al., 2023](http://arxiv.org/abs/2302.06675) + - [Authors' implementation]( + http://github.com/google/automl/tree/master/lion) + + """ + + def __init__( + self, + learning_rate=0.001, + beta_1=0.9, + beta_2=0.99, + weight_decay=None, + clipnorm=None, + clipvalue=None, + global_clipnorm=None, + use_ema=False, + ema_momentum=0.99, + ema_overwrite_frequency=None, + loss_scale_factor=None, + gradient_accumulation_steps=None, + name="lion", + **kwargs, + ): + super().__init__( + learning_rate=learning_rate, + name=name, + weight_decay=weight_decay, + clipnorm=clipnorm, + clipvalue=clipvalue, + global_clipnorm=global_clipnorm, + use_ema=use_ema, + ema_momentum=ema_momentum, + ema_overwrite_frequency=ema_overwrite_frequency, + loss_scale_factor=loss_scale_factor, + gradient_accumulation_steps=gradient_accumulation_steps, + **kwargs, + ) + self.beta_1 = beta_1 + self.beta_2 = beta_2 + if beta_1 <= 0 or beta_1 > 1: + raise ValueError( + "Argument `beta_1` must be in the [0, 1] range. Otherwise, the " + f"optimizer degenerates to SignSGD. Received: beta_1={beta_1}." + ) + + def build(self, var_list): + """Initialize optimizer variables. + + Lion optimizer has one variable `momentums`. + + Args: + var_list: list of model variables to build Lion variables on. + """ + if self.built: + return + super().build(var_list) + self._momentums = self.add_optimizer_variables(var_list, "momentum") + + def update_step(self, gradient, variable, learning_rate): + """Update step given gradient and the associated model variable.""" + lr = ops.cast(learning_rate, variable.dtype) + gradient = ops.cast(gradient, variable.dtype) + beta_1 = ops.cast(self.beta_1, variable.dtype) + beta_2 = ops.cast(self.beta_2, variable.dtype) + m = self._momentums[self._get_variable_index(variable)] + + self.assign_sub( + variable, + ops.multiply( + lr, + ops.sign( + ops.add( + ops.multiply(m, beta_1), + ops.multiply(gradient, (1.0 - beta_1)), + ) + ), + ), + ) + self.assign( + m, + ops.add( + ops.multiply(m, beta_2), ops.multiply(gradient, (1.0 - beta_2)) + ), + ) + + def get_config(self): + config = super().get_config() + config.update( + { + "beta_1": self.beta_1, + "beta_2": self.beta_2, + } + ) + return config + + +Lion.__doc__ = Lion.__doc__.replace( + "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args +) diff --git a/keras/src/optimizers/lion_test.py b/keras/src/optimizers/lion_test.py new file mode 100644 index 000000000000..49ffb0124fd8 --- /dev/null +++ b/keras/src/optimizers/lion_test.py @@ -0,0 +1,105 @@ +import numpy as np +import pytest + +import keras +from keras.src import backend +from keras.src import ops +from keras.src import testing +from keras.src.optimizers.lion import Lion + + +class LionTest(testing.TestCase): + def test_invalid_beta_1(self): + with self.assertRaisesRegex( + ValueError, + "Argument `beta_1` must be in the \\[0, 1\\] range. Otherwise, the " + "optimizer degenerates to SignSGD. Received: beta_1=-0.1.", + ): + Lion(beta_1=-0.1) + with self.assertRaisesRegex( + ValueError, + "Argument `beta_1` must be in the \\[0, 1\\] range. Otherwise, the " + "optimizer degenerates to SignSGD. Received: beta_1=0.0.", + ): + Lion(beta_1=0.0) + with self.assertRaisesRegex( + ValueError, + "Argument `beta_1` must be in the \\[0, 1\\] range. Otherwise, the " + "optimizer degenerates to SignSGD. Received: beta_1=1.1.", + ): + Lion(beta_1=1.1) + + def test_config(self): + optimizer = Lion( + learning_rate=0.5, + beta_1=0.5, + beta_2=0.67, + ) + self.run_class_serialization_test(optimizer) + + def test_single_step(self): + optimizer = Lion(learning_rate=0.5) + grads = ops.array([1.0, 6.0, 7.0, 2.0]) + vars = backend.Variable([1.0, 2.0, 3.0, 4.0]) + optimizer.apply_gradients(zip([grads], [vars])) + self.assertAllClose(vars, [0.5, 1.5, 2.5, 3.5], rtol=1e-4, atol=1e-4) + + def test_weight_decay(self): + grads, var1, var2, var3 = ( + ops.zeros(()), + backend.Variable(2.0), + backend.Variable(2.0, name="exclude"), + backend.Variable(2.0), + ) + optimizer_1 = Lion(learning_rate=1.0, weight_decay=0.004) + optimizer_1.apply_gradients(zip([grads], [var1])) + + optimizer_2 = Lion(learning_rate=1.0, weight_decay=0.004) + optimizer_2.exclude_from_weight_decay(var_names=["exclude"]) + optimizer_2.apply_gradients(zip([grads, grads], [var1, var2])) + + optimizer_3 = Lion(learning_rate=1.0, weight_decay=0.004) + optimizer_3.exclude_from_weight_decay(var_list=[var3]) + optimizer_3.apply_gradients(zip([grads, grads], [var1, var3])) + + self.assertAlmostEqual(var1.numpy(), 1.9760959, decimal=6) + self.assertAlmostEqual(var2.numpy(), 2.0, decimal=6) + self.assertAlmostEqual(var3.numpy(), 2.0, decimal=6) + + def test_correctness_with_golden(self): + optimizer = Lion() + + x = backend.Variable(np.ones([10])) + grads = ops.arange(0.1, 1.1, 0.1) + first_grads = ops.full((10,), 0.01) + + golden = np.tile( + [[0.999], [0.998], [0.997], [0.996], [0.995]], + (1, 10), + ) + + optimizer.apply_gradients(zip([first_grads], [x])) + for i in range(5): + self.assertAllClose(x, golden[i], rtol=5e-4, atol=5e-4) + optimizer.apply_gradients(zip([grads], [x])) + + def test_clip_norm(self): + optimizer = Lion(clipnorm=1) + grad = [np.array([100.0, 100.0])] + clipped_grad = optimizer._clip_gradients(grad) + self.assertAllClose(clipped_grad[0], [2**0.5 / 2, 2**0.5 / 2]) + + def test_clip_value(self): + optimizer = Lion(clipvalue=1) + grad = [np.array([100.0, 100.0])] + clipped_grad = optimizer._clip_gradients(grad) + self.assertAllClose(clipped_grad[0], [1.0, 1.0]) + + @pytest.mark.requires_trainable_backend + def test_ema(self): + # TODO: test correctness + model = keras.Sequential([keras.layers.Dense(10)]) + model.compile(optimizer=Lion(use_ema=True), loss="mse") + x = keras.ops.zeros((1, 5)) + y = keras.ops.zeros((1, 10)) + model.fit(x, y) diff --git a/keras/src/optimizers/loss_scale_optimizer.py b/keras/src/optimizers/loss_scale_optimizer.py new file mode 100644 index 000000000000..d0f1cb062d85 --- /dev/null +++ b/keras/src/optimizers/loss_scale_optimizer.py @@ -0,0 +1,345 @@ +from keras.src import backend +from keras.src import initializers +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.optimizers import optimizer +from keras.src.saving import serialization_lib +from keras.src.utils import tracking + + +@keras_export( + [ + "keras.optimizers.LossScaleOptimizer", + "keras.mixed_precision.LossScaleOptimizer", + ] +) +class LossScaleOptimizer(optimizer.Optimizer): + """An optimizer that dynamically scales the loss to prevent underflow. + + Loss scaling is a technique to prevent numeric underflow in intermediate + gradients when float16 is used. To prevent underflow, the loss is multiplied + (or "scaled") by a certain factor called the "loss scale", which causes + intermediate gradients to be scaled by the loss scale as well. The final + gradients are divided (or "unscaled") by the loss scale to bring them back + to their original value. + + `LossScaleOptimizer` wraps another optimizer and applies dynamic loss + scaling to it. This loss scale is dynamically updated over time as follows: + - On any train step, if a nonfinite gradient is encountered, the loss scale + is halved, and the train step is skipped. + - If `dynamic_growth_steps` have occurred since the last time the loss scale + was updated, and no nonfinite gradients have occurred, the loss scale + is doubled. + + Args: + inner_optimizer: The `keras.optimizers.Optimizer` instance to wrap. + initial_scale: Float. The initial loss scale. This scale will be updated + during training. It is recommended for this to be a very high + number, because a loss scale that is too high gets lowered far more + quickly than a loss scale that is too low gets raised. + dynamic_growth_steps: Int. How often to update the scale upwards. After + every `dynamic_growth_steps` steps with finite gradients, the + loss scale is doubled. + {{base_optimizer_keyword_args}} + """ + + def __init__( + self, + inner_optimizer, + initial_scale=2.0**15, + dynamic_growth_steps=2000, + name=None, + **kwargs, + ): + if not kwargs.pop("dynamic", True): + raise ValueError( + "LossScaleOptimizer no longer supports `dynamic=False`. " + "Instead, simply set `loss_scale_factor` directly on the " + "`inner_optimizer`." + ) + + # Backwards compatibility code for deserialization. + # LossScaleOptimizer used to return all these parameters in `get_config` + # from `super.get_config` even though they are all non-functional. We + # no longer let user set them, but we have to allow the default values + # to be passed during deserialization to support older models. + base_optimizer_defaults = { + "weight_decay": None, + "clipnorm": None, + "global_clipnorm": None, + "clipvalue": None, + "use_ema": False, + "ema_momentum": 0.99, + "ema_overwrite_frequency": None, + "loss_scale_factor": None, + "gradient_accumulation_steps": None, + } + for arg_name, default_value in base_optimizer_defaults.items(): + if arg_name not in kwargs: + continue + arg_value = kwargs.pop(arg_name) + if ( + default_value is None and arg_value is not None + ) or arg_value != default_value: + raise ValueError( + f"LossScaleOptimizer does not support `{arg_name}`. " + f"Instead, set `{arg_name}` on the `inner_optimizer`." + ) + + if kwargs: + raise ValueError( + "LossScaleOptimizer does not support arguments: " + f"`{'`, `'.join(kwargs.keys())}`." + ) + + super().__init__(learning_rate=0.0, name=name) + self.inner_optimizer = inner_optimizer + self.initial_scale = initial_scale + self.dynamic_growth_steps = dynamic_growth_steps + # Disable the inner optimizer's loss scaling, otherwise + # gradients will be scaled twice. + self.inner_optimizer.loss_scale_factor = None + + @tracking.no_automatic_dependency_tracking + def build(self, var_list): + self.step_counter = self.add_variable( + shape=(), + dtype="int", + initializer=initializers.Zeros(), + aggregation="none", + name="step_counter", + ) + self.dynamic_scale = self.add_variable( + shape=(), + dtype="float32", + initializer=initializers.Constant(self.initial_scale), + aggregation="none", + name="dynamic_scale", + ) + self.inner_optimizer.build(var_list) + super().build(var_list) + + @property + def variables(self): + return self._variables + self.inner_optimizer.variables + + def stateless_apply(self, optimizer_variables, grads, trainable_variables): + if not self.built: + raise ValueError( + f"To call `stateless_apply`, {self.__class__.__name__} " + "must be built (i.e. its variables must have been created). " + "You can build it via `optimizer.build(trainable_variables)`." + ) + finite = self.check_finite(grads) + return ops.cond( + finite, + lambda: self._stateless_handle_finite_grads( + optimizer_variables, grads, trainable_variables + ), + lambda: self._stateless_handle_non_finite_grads( + optimizer_variables, trainable_variables + ), + ) + + def _stateless_handle_finite_grads( + self, optimizer_variables, grads, trainable_variables + ): + def upscale(): + mapping = list(zip(self.variables, optimizer_variables)) + with backend.StatelessScope(state_mapping=mapping) as scope: + self.step_counter.assign(0) + self.dynamic_scale.assign(ops.multiply(self.dynamic_scale, 2.0)) + return [scope.get_current_value(v) for v in self._variables] + + def increment(): + mapping = list(zip(self.variables, optimizer_variables)) + with backend.StatelessScope(state_mapping=mapping) as scope: + self.step_counter.assign_add(1) + return [scope.get_current_value(v) for v in self._variables] + + mapping = list(zip(self.variables, optimizer_variables)) + with backend.StatelessScope(state_mapping=mapping): + # Potentially upscale loss and reset counter. + own_variables = ops.cond( + ops.equal(self.step_counter, self.dynamic_growth_steps - 1), + upscale, + increment, + ) + + # Unscale gradients. + scale = self.dynamic_scale + unscaled_grads = [ + g + if g is None or self._overwrite_variable_with_gradient(v) + else ops.divide(g, scale) + for g, v in zip(grads, self._trainable_variables) + ] + ( + new_trainable_variables, + new_inner_variables, + ) = self.inner_optimizer.stateless_apply( + self.inner_optimizer.variables, + unscaled_grads, + trainable_variables, + ) + + new_optimizer_variables = own_variables + new_inner_variables + return new_trainable_variables, new_optimizer_variables + + def _stateless_handle_non_finite_grads( + self, optimizer_variables, trainable_variables + ): + mapping = list(zip(self.variables, optimizer_variables)) + with backend.StatelessScope(state_mapping=mapping) as scope: + self.step_counter.assign(0) + self.dynamic_scale.assign(ops.multiply(self.dynamic_scale, 0.5)) + new_optimizer_variables = [] + for v in self.variables: + new_optimizer_variables.append(scope.get_current_value(v)) + return trainable_variables, new_optimizer_variables + + def apply(self, grads, trainable_variables=None): + # Optionally build optimizer. + if not self.built: + with backend.name_scope(self.name, caller=self): + self.build(trainable_variables) + self.built = True + + if backend.backend() == "tensorflow": + self._tf_apply(grads, trainable_variables) + else: + self._common_apply(grads, trainable_variables) + + def _stateful_handle_finite_grads(self, grads, trainable_variables): + scale = self.dynamic_scale + # Unscale gradients. + tvs = trainable_variables or self._trainable_variables + unscaled_grads = [ + g + if g is None or self._overwrite_variable_with_gradient(v) + else ops.divide(g, scale) + for g, v in zip(grads, tvs) + ] + self.inner_optimizer.apply( + unscaled_grads, trainable_variables=trainable_variables + ) + + def upscale(): + self.step_counter.assign(0) + self.dynamic_scale.assign(ops.multiply(self.dynamic_scale, 2.0)) + + def increment(): + self.step_counter.assign_add(1) + + # Potentially upscale loss and reset counter. + ops.cond( + ops.equal(self.step_counter, self.dynamic_growth_steps - 1), + upscale, + increment, + ) + + def _stateful_handle_non_finite_grads(self): + # If any inf or nan in grads, downscale loss and reset counter. + self.step_counter.assign(0) + self.dynamic_scale.assign(ops.multiply(self.dynamic_scale, 0.5)) + + def _common_apply(self, grads, trainable_variables=None): + finite = self.check_finite(grads) + ops.cond( + finite, + lambda: self._stateful_handle_finite_grads( + grads, trainable_variables + ), + self._stateful_handle_non_finite_grads, + ) + + def _tf_apply(self, grads, trainable_variables=None): + """Tensorflow specific logic for apply, which handles distribution.""" + from keras.src.utils.module_utils import tensorflow as tf + + if tf.distribute.in_cross_replica_context(): + raise ValueError("apply() must be called in a replica context.") + + if tf.__internal__.distribute.strategy_supports_no_merge_call(): + self._common_apply(grads, trainable_variables=trainable_variables) + else: + + def _handle_cross_replica(distribution, grads, trainable_variables): + finite_per_replica = ( + distribution.extended.call_for_each_replica( + self.check_finite, args=(grads,) + ) + ) + # Each replica computed the same `finite` value, since + # `grads` is all-reduced across replicas. Arbitrarily take + # `finite` from the first replica. + finite = distribution.experimental_local_results( + finite_per_replica + )[0] + + def apply_fn(): + distribution.extended.call_for_each_replica( + self._stateful_handle_finite_grads, + args=(grads, trainable_variables), + ) + + # Note: We must call this cond() in a cross-replica context. + # DistributionStrategy does not support having a cond in a + # replica context with a branch that calls `merge_call`, and + # self._optimizer.apply_gradients calls `merge_call`. + ops.cond( + finite, apply_fn, self._stateful_handle_non_finite_grads + ) + + tf.distribute.get_replica_context().merge_call( + _handle_cross_replica, args=(grads, trainable_variables) + ) + + def check_finite(self, grads): + tensor_grads = [g for g in grads if g is not None] + finite_grads = [ops.all(ops.isfinite(g)) for g in tensor_grads] + return ops.all(ops.convert_to_tensor(finite_grads)) + + @property + def learning_rate(self): + return self.inner_optimizer.learning_rate + + @learning_rate.setter + def learning_rate(self, learning_rate): + self.inner_optimizer.learning_rate = learning_rate + + @property + def iterations(self): + return self.inner_optimizer.iterations + + def scale_loss(self, loss): + scale = self.dynamic_scale if self.built else self.initial_scale + return ops.multiply(loss, scale) + + def finalize_variable_values(self, var_list): + self.inner_optimizer.finalize_variable_values(var_list) + + def get_config(self): + # Do not use super().get_config() as only "name" is supported. + inner_optimizer_config = serialization_lib.serialize_keras_object( + self.inner_optimizer + ) + return { + "name": self.name, + "inner_optimizer": inner_optimizer_config, + "initial_scale": self.initial_scale, + "dynamic_growth_steps": self.dynamic_growth_steps, + } + + @classmethod + def from_config(cls, config, custom_objects=None): + inner_optimizer = serialization_lib.deserialize_keras_object( + config.pop("inner_optimizer"), + custom_objects=custom_objects, + ) + return cls(inner_optimizer, **config) + + +LossScaleOptimizer.__doc__ = LossScaleOptimizer.__doc__.replace( + "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args +) diff --git a/keras/src/optimizers/loss_scale_optimizer_test.py b/keras/src/optimizers/loss_scale_optimizer_test.py new file mode 100644 index 000000000000..d707ad765f33 --- /dev/null +++ b/keras/src/optimizers/loss_scale_optimizer_test.py @@ -0,0 +1,282 @@ +import numpy as np +from absl.testing import parameterized + +from keras.src import backend +from keras.src import ops +from keras.src import testing +from keras.src.optimizers.loss_scale_optimizer import LossScaleOptimizer +from keras.src.optimizers.sgd import SGD + + +class LossScaleOptimizerTest(testing.TestCase): + def _skip_test_for_stateless(self, stateless): + if not stateless and backend.backend() == "jax": + self.skipTest( + "LossScaleOptimizer must use stateless_apply with JAX." + ) + if stateless and backend.backend() == "tensorflow": + self.skipTest( + "stateless_apply is not supported with the TF backend." + ) + + def test_config(self): + inner_optimizer = SGD( + learning_rate=0.5, + momentum=0.06, + nesterov=True, + weight_decay=0.004, + ) + optimizer = LossScaleOptimizer(inner_optimizer) + self.run_class_serialization_test(optimizer) + + def test_apply_with_no_vars(self): + self._skip_test_for_stateless(False) + + inner_optimizer = SGD(learning_rate=0.5) + optimizer = LossScaleOptimizer(inner_optimizer) + grads = [ops.array([1.0, 6.0, 7.0, 2.0]) * optimizer.initial_scale] + vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])] + optimizer.build(vars) + optimizer.apply(grads) + self.assertAllClose( + vars, [[0.5, -1.0, -0.5, 3.0]], rtol=1e-4, atol=1e-4 + ) + + @parameterized.named_parameters(("stateless", True), ("stateful", False)) + def test_finite_step(self, stateless): + self._skip_test_for_stateless(stateless) + + inner_optimizer = SGD(learning_rate=0.5) + optimizer = LossScaleOptimizer(inner_optimizer) + grads = [ops.array([1.0, 6.0, 7.0, 2.0]) * optimizer.initial_scale] + vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])] + if stateless: + optimizer.build(vars) + vars, _ = optimizer.stateless_apply( + [v.value for v in optimizer.variables], + grads, + [v.value for v in vars], + ) + else: + optimizer.apply(grads, vars) + self.assertAllClose( + vars, [[0.5, -1.0, -0.5, 3.0]], rtol=1e-4, atol=1e-4 + ) + + @parameterized.named_parameters(("stateless", True), ("stateful", False)) + def test_finite_step_with_inner_loss_scale(self, stateless): + self._skip_test_for_stateless(stateless) + + # Ensure that the inner loss scale does not interfere with the update. + inner_optimizer = SGD(learning_rate=0.5, loss_scale_factor=100) + optimizer = LossScaleOptimizer(inner_optimizer) + grads = [ops.array([1.0, 6.0, 7.0, 2.0]) * optimizer.initial_scale] + vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])] + if stateless: + optimizer.build(vars) + vars, _ = optimizer.stateless_apply( + [v.value for v in optimizer.variables], + grads, + [v.value for v in vars], + ) + else: + optimizer.apply(grads, vars) + self.assertAllClose( + vars, [[0.5, -1.0, -0.5, 3.0]], rtol=1e-4, atol=1e-4 + ) + + @parameterized.named_parameters(("stateless", True), ("stateful", False)) + def test_infinite_step(self, stateless): + self._skip_test_for_stateless(stateless) + + inner_optimizer = SGD(learning_rate=0.5) + optimizer = LossScaleOptimizer(inner_optimizer) + grads = [ops.array([np.inf, np.inf, np.inf, np.inf])] + vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])] + if stateless: + optimizer.build(vars) + vars, _ = optimizer.stateless_apply( + [v.value for v in optimizer.variables], + grads, + [v.value for v in vars], + ) + else: + optimizer.apply(grads, vars) + self.assertAllClose(vars, [[1.0, 2.0, 3.0, 4.0]], rtol=1e-4, atol=1e-4) + + @parameterized.named_parameters(("stateless", True), ("stateful", False)) + def test_finite_step_with_overwrite(self, stateless): + self._skip_test_for_stateless(stateless) + + inner_optimizer = SGD(learning_rate=0.5) + optimizer = LossScaleOptimizer(inner_optimizer) + grads = [ops.array([1.0, 6.0, 7.0, 2.0])] + vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])] + vars[0].overwrite_with_gradient = True + + if stateless: + optimizer.build(vars) + vars, _ = optimizer.stateless_apply( + [v.value for v in optimizer.variables], + grads, + [v.value for v in vars], + ) + else: + optimizer.apply(grads, vars) + self.assertAllClose(vars, grads) + + @parameterized.named_parameters(("stateless", True), ("stateful", False)) + def test_downscaling(self, stateless): + self._skip_test_for_stateless(stateless) + + inner_optimizer = SGD(learning_rate=0.5) + optimizer = LossScaleOptimizer(inner_optimizer, initial_scale=400.0) + vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])] + optimizer.build(vars) + opt_var_values = [v.value for v in optimizer.variables] + grads = [ops.array([np.inf, np.inf, np.inf, np.inf])] + for _ in range(4): + if stateless: + _, opt_var_values = optimizer.stateless_apply( + opt_var_values, grads, [v.value for v in vars] + ) + for ref_v, v in zip(optimizer.variables, opt_var_values): + ref_v.assign(v) + else: + optimizer.apply(grads, vars) + self.assertAllClose(optimizer.scale_loss(1.0), 25.0) + + @parameterized.named_parameters(("stateless", True), ("stateful", False)) + def test_upscaling(self, stateless): + self._skip_test_for_stateless(stateless) + + inner_optimizer = SGD(learning_rate=0.5) + optimizer = LossScaleOptimizer( + inner_optimizer, + initial_scale=2.0, + dynamic_growth_steps=2, + ) + vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])] + optimizer.build(vars) + opt_var_values = [v.value for v in optimizer.variables] + grads = [ops.array([1.0, 6.0, 7.0, 2.0])] + for _ in range(8): + if stateless: + _, opt_var_values = optimizer.stateless_apply( + opt_var_values, grads, [v.value for v in vars] + ) + for ref_v, v in zip(optimizer.variables, opt_var_values): + ref_v.assign(v) + else: + optimizer.apply(grads, vars) + self.assertAllClose(optimizer.scale_loss(1.0), 32.0) + + @parameterized.named_parameters(("stateless", True), ("stateful", False)) + def test_iterations_update(self, stateless): + self._skip_test_for_stateless(stateless) + + inner_optimizer = SGD(learning_rate=0.5) + optimizer = LossScaleOptimizer(inner_optimizer) + vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])] + optimizer.build(vars) + opt_var_values = [v.value for v in optimizer.variables] + grads = [ops.array([1.0, 6.0, 7.0, 2.0])] + + self.assertEqual(optimizer.iterations.value, 0) + + for i in range(3): + if stateless: + _, opt_var_values = optimizer.stateless_apply( + opt_var_values, grads, [v.value for v in vars] + ) + for ref_v, v in zip(optimizer.variables, opt_var_values): + ref_v.assign(v) + else: + optimizer.apply(grads, vars) + self.assertEqual(optimizer.iterations.value, i + 1) + + def test_serialization(self): + inner_optimizer = SGD(learning_rate=0.5) + optimizer = LossScaleOptimizer( + inner_optimizer, + initial_scale=3.0, + dynamic_growth_steps=2, + name="test_opt", + ) + config = optimizer.get_config() + self.assertLen(config, 4) + self.assertEqual(config["name"], "test_opt") + self.assertEqual(config["initial_scale"], 3.0) + self.assertEqual(config["dynamic_growth_steps"], 2) + self.assertIn("inner_optimizer", config) + LossScaleOptimizer.from_config(config) + + def test_init_dynamic_arg(self): + inner_optimizer = SGD(learning_rate=0.5) + + # dynamic=True is supported + LossScaleOptimizer(inner_optimizer, dynamic=True) + + # dynamic=False is not supported + with self.assertRaisesRegex(ValueError, "set `loss_scale_factor`"): + LossScaleOptimizer(inner_optimizer, dynamic=False) + + def test_init_unsupported_arg(self): + inner_optimizer = SGD(learning_rate=0.5) + with self.assertRaisesRegex(ValueError, "arguments: `foo`, `bar`"): + LossScaleOptimizer(inner_optimizer, foo=True, bar=3) + + @parameterized.named_parameters( + ("weight_decay", "weight_decay", 0.5), + ("clipnorm", "clipnorm", 0.5), + ("global_clipnorm", "global_clipnorm", 0.5), + ("clipvalue", "clipvalue", 0.5), + ("use_ema", "use_ema", True), + ("ema_momentum", "ema_momentum", 0.5), + ("ema_overwrite_frequency", "ema_overwrite_frequency", 2), + ("loss_scale_factor", "loss_scale_factor", 0.5), + ("gradient_accumulation_steps", "gradient_accumulation_steps", 2), + ) + def test_init_base_optimizer_unsupported_args(self, arg_name, arg_value): + inner_optimizer = SGD(learning_rate=0.5) + with self.assertRaisesRegex(ValueError, "on the `inner_optimizer`"): + LossScaleOptimizer(inner_optimizer, **{arg_name: arg_value}) + + def test_deserialization_backwards_compatibility(self): + # Test deserializing with a config that has all the unsupported + # arguments from the base optimizer (which are no longer serialized) + config = { + "name": "loss_scale_optimizer", + "weight_decay": None, + "clipnorm": None, + "global_clipnorm": None, + "clipvalue": None, + "use_ema": False, + "ema_momentum": 0.99, + "ema_overwrite_frequency": None, + "loss_scale_factor": None, + "gradient_accumulation_steps": None, + "inner_optimizer": { + "module": "keras.optimizers", + "class_name": "SGD", + "config": { + "name": "SGD", + "learning_rate": 0.5, + "weight_decay": None, + "clipnorm": None, + "global_clipnorm": None, + "clipvalue": None, + "use_ema": False, + "ema_momentum": 0.99, + "ema_overwrite_frequency": None, + "loss_scale_factor": None, + "gradient_accumulation_steps": None, + "momentum": 0.0, + "nesterov": False, + }, + "registered_name": None, + }, + "initial_scale": 2.0, + "dynamic_growth_steps": 2, + } + LossScaleOptimizer.from_config(config) diff --git a/keras/src/optimizers/muon.py b/keras/src/optimizers/muon.py new file mode 100644 index 000000000000..88d0dde3ee92 --- /dev/null +++ b/keras/src/optimizers/muon.py @@ -0,0 +1,289 @@ +import re + +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.optimizers import optimizer + + +@keras_export(["keras.optimizers.Muon"]) +class Muon(optimizer.Optimizer): + """Optimizer that implements the Muon algorithm. + + Note that this optimizer should not be used in the following layers: + + 1. Embedding layer + 2. Final output fully connected layer + 3. Any {0,1}-D variables + + These should all be optimized using AdamW. + + The Muon optimizer can use both the Muon update step or the + AdamW update step based on the following: + + - For any variable that isn't 2D, 3D or 4D, the AdamW step + will be used. This is not configurable. + - If the argument `exclude_embeddings` (defaults to `True`) is set + to `True`, the AdamW step will be used. + - For any variablewith a name that matches an expression + listed in the argument `exclude_layers` (a list), the + AdamW step will be used. + - Any other variable uses the Muon step. + + Typically, you only need to pass the name of your densely-connected + output layer to `exclude_layers`, e.g. + `exclude_layers=["output_dense"]`. + + References: + - [Original implementation](https://github.com/KellerJordan/Muon) + - [Liu et al, 2025](https://arxiv.org/abs/2502.16982) + + Args: + learning_rate: A float, + `keras.optimizers.schedules.LearningRateSchedule` instance, or + a callable that takes no arguments and returns the actual value to + use. The learning rate. Defaults to `0.001`. + adam_beta_1: A float value or a constant float tensor, or a callable + that takes no arguments and returns the actual value to use. + The exponential decay rate for the 1st moment estimates. Defaults to + `0.9`. + adam_beta_2: A float value or a constant float tensor, ora callable + that takes no arguments and returns the actual value to use. + The exponential decay rate for the 2nd moment estimates. Defaults to + `0.999`. + epsilon: A small constant for numerical stability. This is + "epsilon hat" in the Kingma and Ba paper + (in the formula just before Section 2.1), + not the epsilon in Algorithm 1 of the paper. + It be used at Adamw.Defaults to `1e-7`. + exclude_layers: List of strings, keywords of layer names to exclude. + All layers with keywords in their path will use adamw. + exclude_embeddings: Boolean value + If True, embedding layers will use adamw. + muon_a: Float, parameter a of the muon algorithm. + It is recommended to use the default value + muon_b: Float, parameter b of the muon algorithm. + It is recommended to use the default value + muon_c: Float, parameter c of the muon algorithm. + It is recommended to use the default value + adam_lr_ratio: Float, the ratio of the learning rate when + using Adam to the main learning rate. + it is recommended to set it to 0.1 + momentum: Float, momentum used by internal SGD. + ns_steps: Integer, number of Newton-Schulz iterations to run. + nesterov: Boolean, whether to use Nesterov-style momentum + {{base_optimizer_keyword_args}} + """ + + def __init__( + self, + learning_rate=0.001, + adam_beta_1=0.9, + adam_beta_2=0.999, + epsilon=1e-7, + weight_decay=0.1, + clipnorm=None, + clipvalue=None, + global_clipnorm=None, + use_ema=False, + ema_momentum=0.99, + ema_overwrite_frequency=None, + loss_scale_factor=None, + gradient_accumulation_steps=None, + name="muon", + exclude_layers=None, + exclude_embeddings=True, + muon_a=3.4445, + muon_b=-4.7750, + muon_c=2.0315, + adam_lr_ratio=0.1, + momentum=0.95, + ns_steps=6, + nesterov=True, + **kwargs, + ): + super().__init__( + learning_rate=learning_rate, + name=name, + weight_decay=weight_decay, + clipnorm=clipnorm, + clipvalue=clipvalue, + global_clipnorm=global_clipnorm, + use_ema=use_ema, + ema_momentum=ema_momentum, + ema_overwrite_frequency=ema_overwrite_frequency, + loss_scale_factor=loss_scale_factor, + gradient_accumulation_steps=gradient_accumulation_steps, + **kwargs, + ) + self.adam_beta_1 = adam_beta_1 + self.adam_beta_2 = adam_beta_2 + self.epsilon = epsilon + self.muon_a = muon_a + self.muon_b = muon_b + self.muon_c = muon_c + self.adam_lr_ratio = adam_lr_ratio + self.momentum = momentum + self.ns_steps = ns_steps + self.nesterov = nesterov + self.exclude_embeddings = exclude_embeddings + self.exclude_layers = exclude_layers or [] + + def _should_use_adamw(self, variable): + # To use it with 4D convolutional filters, + # it works well to just flatten their last 3 dimensions. + # any {0,1}-D parameters should all be optimized by adam + if not 1 < len(variable.shape) < 4: + return True + if self.exclude_embeddings and "embedding" in variable.path.lower(): + return True + for keyword in self.exclude_layers: + if re.search(keyword, variable.path): + return True + return False + + def build(self, var_list): + """Initialize optimizer variables. + + Adam optimizer has 3 types of variables: momentums, velocities and + velocity_hat (only set when amsgrad is applied), + + Args: + var_list: list of model variables to build Adam variables on. + """ + if self.built: + return + super().build(var_list) + self.adam_momentums = {} + self.adam_velocities = {} + + self.muon_momentums = {} + self.muon_velocities = {} + + for var in var_list: + if not self._overwrite_variable_with_gradient(var): + self.adam_momentums[var.path] = ( + self.add_variable_from_reference( + reference_variable=var, name="momentum" + ) + ) + if self._should_use_adamw(var): + self.adam_velocities[var.path] = ( + self.add_variable_from_reference( + reference_variable=var, name="velocity" + ) + ) + + def update_step(self, gradient, variable, learning_rate): + if self._should_use_adamw(variable): + # It should be noted that lr is one-tenth when using adamw. + self._adamw_update_step( + gradient, variable, learning_rate * self.adam_lr_ratio + ) + else: + self._muon_update_step(gradient, variable, learning_rate) + + def _muon_update_step(self, gradient, variable, lr): + m = self.adam_momentums[variable.path] + self.assign_add(m, ops.add(gradient, m * (self.momentum - 1))) + shape = variable.shape + if self.nesterov: + g = ops.add(gradient, self.momentum * m) + else: + g = m + + self.assign_sub( + variable, + lr + * self.zeropower_via_newtonschulz5(g, self.ns_steps) + * max(1, shape[0] / shape[1]) ** 0.5, + ) + + def _adamw_update_step(self, gradient, variable, learning_rate): + """Update step given gradient and the associated model variable.""" + lr = ops.cast(learning_rate, variable.dtype) + gradient = ops.cast(gradient, variable.dtype) + local_step = ops.cast(self.iterations + 1, variable.dtype) + adam_beta_1_power = ops.power( + ops.cast(self.adam_beta_1, variable.dtype), local_step + ) + adam_beta_2_power = ops.power( + ops.cast(self.adam_beta_2, variable.dtype), local_step + ) + + m = self.adam_momentums[variable.path] + v = self.adam_velocities[variable.path] + + alpha = lr * ops.sqrt(1 - adam_beta_2_power) / (1 - adam_beta_1_power) + + self.assign_add( + m, ops.multiply(ops.subtract(gradient, m), 1 - self.adam_beta_1) + ) + self.assign_add( + v, + ops.multiply( + ops.subtract(ops.square(gradient), v), 1 - self.adam_beta_2 + ), + ) + self.assign_sub( + variable, + ops.divide( + ops.multiply(m, alpha), ops.add(ops.sqrt(v), self.epsilon) + ), + ) + + def transpose_last_axis(self, X): + shape = ops.shape(X) + temp_order = list(range(len(shape))) + temp_order[-2] = temp_order[-1] + temp_order[-1] = len(shape) - 2 + X = ops.transpose(X, temp_order) + return X + + def zeropower_via_newtonschulz5(self, x, steps: int): + """We apply the Newton-Schulz iteration to compute matrix G. + + We select a quintic iteration that maximizes the slope at zero. This + approach helps minimize steps, even if the iteration doesn't fully + converge across the interval. The result isn't exactly UV^T (from the + SVD of G), but rather an approximation like US'V^T. Despite this + approximation, model performance remains unaffected compared to using + the exact UV^T from the SVD. + """ + shape = ops.shape(x) + assert len(shape) >= 2 + + a, b, c = self.muon_a, self.muon_b, self.muon_c + if shape[-2] > shape[-1]: + x = self.transpose_last_axis(x) + + # Ensure spectral norm is at most 1 + x = x / (ops.norm(x, axis=(-2, -1), keepdims=True) + 1e-7) + # Perform the NS iterations + for _ in range(steps): + temp_a = x @ self.transpose_last_axis(x) + temp_b = b * temp_a + c * temp_a @ temp_a + x = a * x + temp_b @ x + + if shape[-2] > shape[-1]: + x = self.transpose_last_axis(x) + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "adam_beta_1": self.adam_beta_1, + "adam_beta_2": self.adam_beta_2, + "epsilon": self.epsilon, + "exclude_layers": self.exclude_layers, + "muon_a": self.muon_a, + "muon_b": self.muon_b, + "muon_c": self.muon_c, + "adam_lr_ratio": self.adam_lr_ratio, + "momentum": self.momentum, + "ns_steps": self.ns_steps, + "nesterov": self.nesterov, + "exclude_embeddings": self.exclude_embeddings, + } + ) + return config diff --git a/keras/src/optimizers/muon_test.py b/keras/src/optimizers/muon_test.py new file mode 100644 index 000000000000..f22423c34aae --- /dev/null +++ b/keras/src/optimizers/muon_test.py @@ -0,0 +1,83 @@ +import numpy as np + +from keras.src import backend +from keras.src import ops +from keras.src import testing +from keras.src.layers import Dense +from keras.src.layers import Embedding +from keras.src.optimizers.muon import Muon + + +class MuonTest(testing.TestCase): + def test_config(self): + optimizer = Muon( + learning_rate=0.5, + epsilon=1e-5, + ) + self.run_class_serialization_test(optimizer) + + def test_Newton_Schulz(self): + optimizer = Muon() + tensor_input = ops.array([[0.2499, 0.9105], [0.2655, 0.8824]]) + except_output = ops.array([[-0.4422, 0.6457], [0.7285, 0.2968]]) + output = optimizer.zeropower_via_newtonschulz5(tensor_input, 5) + self.assertAllClose(output, except_output, rtol=1e-3, atol=1e-3) + + def test_adamw_single_step(self): + optimizer = Muon() + grads = ops.array([1.0, 6.0, 7.0, 2.0]) + vars = backend.Variable([1.0, 2.0, 3.0, 4.0], name="test_vars") + optimizer.build([vars]) + optimizer._adamw_update_step(grads, vars, 0.5) + self.assertAllClose(vars, [0.5, 1.5, 2.5, 3.5], rtol=1e-4, atol=1e-4) + + def test_should_use_adamw(self): + vars = backend.Variable([[1.0, 2.0], [3.0, 4.0]]) + optimizer = Muon(exclude_layers=["var"]) + self.assertAllClose( + True, + optimizer._should_use_adamw(vars), + ) + embeding = Embedding(2, 2) + embeding.build() + self.assertAllClose( + True, + optimizer._should_use_adamw(embeding.weights[0]), + ) + vars = backend.Variable([[1.0, 2.0], [3.0, 4.0]]) + optimizer = Muon() + self.assertAllClose( + False, + optimizer._should_use_adamw(vars), + ) + dense = Dense(2) + dense.build([None, 2]) + self.assertAllClose( + False, + optimizer._should_use_adamw(dense.weights[0]), + ) + + def test_muon_single_step(self): + optimizer = Muon( + learning_rate=0.5, + weight_decay=0, + ) + grads = ops.array([[1.0, 6.0], [7.0, 2.0]]) + vars = backend.Variable([[1.0, 2.0], [3.0, 4.0]]) + optimizer.build([vars]) + optimizer._muon_update_step(grads, vars, 0.5) + self.assertAllClose( + vars, [[1.13, 1.51], [2.57, 4.06]], rtol=1e-2, atol=1e-2 + ) + + def test_clip_norm(self): + optimizer = Muon(clipnorm=1) + grad = [np.array([100.0, 100.0])] + clipped_grad = optimizer._clip_gradients(grad) + self.assertAllClose(clipped_grad[0], [2**0.5 / 2, 2**0.5 / 2]) + + def test_clip_value(self): + optimizer = Muon(clipvalue=1) + grad = [np.array([100.0, 100.0])] + clipped_grad = optimizer._clip_gradients(grad) + self.assertAllClose(clipped_grad[0], [1.0, 1.0]) diff --git a/keras/src/optimizers/nadam.py b/keras/src/optimizers/nadam.py new file mode 100644 index 000000000000..4b0fddb83b19 --- /dev/null +++ b/keras/src/optimizers/nadam.py @@ -0,0 +1,163 @@ +from keras.src import backend +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.optimizers import optimizer + + +@keras_export(["keras.optimizers.Nadam"]) +class Nadam(optimizer.Optimizer): + """Optimizer that implements the Nadam algorithm. + + Much like Adam is essentially RMSprop with momentum, Nadam is Adam with + Nesterov momentum. + + Args: + learning_rate: A float, a + `keras.optimizers.schedules.LearningRateSchedule` instance, or + a callable that takes no arguments and returns the actual value to + use. The learning rate. Defaults to `0.001`. + beta_1: A float value or a constant float tensor, or a callable + that takes no arguments and returns the actual value to use. The + exponential decay rate for the 1st moment estimates. + Defaults to `0.9`. + beta_2: A float value or a constant float tensor, or a callable + that takes no arguments and returns the actual value to use. The + exponential decay rate for the 2nd moment estimates. Defaults to + `0.999`. + epsilon: A small constant for numerical stability. This epsilon is + "epsilon hat" in the Kingma and Ba paper (in the formula just before + Section 2.1), not the epsilon in Algorithm 1 of the paper. + Defaults to `1e-7`. + {{base_optimizer_keyword_args}} + + Reference: + + - [Dozat, 2015](http://cs229.stanford.edu/proj2015/054_report.pdf). + + """ + + def __init__( + self, + learning_rate=0.001, + beta_1=0.9, + beta_2=0.999, + epsilon=1e-7, + weight_decay=None, + clipnorm=None, + clipvalue=None, + global_clipnorm=None, + use_ema=False, + ema_momentum=0.99, + ema_overwrite_frequency=None, + loss_scale_factor=None, + gradient_accumulation_steps=None, + name="nadam", + **kwargs, + ): + super().__init__( + learning_rate=learning_rate, + name=name, + weight_decay=weight_decay, + clipnorm=clipnorm, + clipvalue=clipvalue, + global_clipnorm=global_clipnorm, + use_ema=use_ema, + ema_momentum=ema_momentum, + ema_overwrite_frequency=ema_overwrite_frequency, + loss_scale_factor=loss_scale_factor, + gradient_accumulation_steps=gradient_accumulation_steps, + **kwargs, + ) + self.beta_1 = beta_1 + self.beta_2 = beta_2 + self.epsilon = epsilon + + def build(self, var_list): + """Initialize optimizer variables. + + Nadam optimizer has 2 types of variables: momentums and velocities. + + Args: + var_list: list of model variables to build Nadam variables on. + """ + if self.built: + return + if var_list: + dtype = var_list[0].dtype + else: + dtype = backend.floatx() + super().build(var_list) + self._momentums, self._velocities = self.add_optimizer_variables( + var_list, ["momentum", "velocity"] + ) + self._u_product = backend.Variable(1.0, dtype=dtype) + + def _backend_update_step(self, grads, trainable_variables, learning_rate): + dtype = self._u_product.dtype + self.assign( + self._u_product, + self._u_product + * self.beta_1 + * ( + 1.0 + - 0.5 * ops.power(0.96, ops.cast(self.iterations + 1, dtype)) + ), + ) + super()._backend_update_step(grads, trainable_variables, learning_rate) + + def update_step(self, gradient, variable, learning_rate): + """Update step given gradient and the associated model variable.""" + var_dtype = variable.dtype + lr = ops.cast(learning_rate, var_dtype) + gradient = ops.cast(gradient, var_dtype) + + local_step = ops.cast(self.iterations + 1, var_dtype) + next_step = ops.cast(self.iterations + 2, var_dtype) + decay = ops.cast(0.96, var_dtype) + beta_1 = ops.cast(self.beta_1, var_dtype) + beta_2 = ops.cast(self.beta_2, var_dtype) + u_t = beta_1 * (1.0 - 0.5 * (ops.power(decay, local_step))) + u_t_1 = beta_1 * (1.0 - 0.5 * (ops.power(decay, next_step))) + u_product_t = ops.cast(self._u_product, var_dtype) + + u_product_t_1 = u_product_t * u_t_1 + beta_2_power = ops.power(beta_2, local_step) + + m = self._momentums[self._get_variable_index(variable)] + v = self._velocities[self._get_variable_index(variable)] + + self.assign_add( + m, ops.multiply(ops.subtract(gradient, m), (1 - beta_1)) + ) + self.assign_add( + v, ops.multiply(ops.subtract(ops.square(gradient), v), (1 - beta_2)) + ) + m_hat = ops.add( + ops.divide(ops.multiply(u_t_1, m), 1 - u_product_t_1), + ops.divide(ops.multiply(1 - u_t, gradient), 1 - u_product_t), + ) + v_hat = ops.divide(v, (1 - beta_2_power)) + + self.assign_sub( + variable, + ops.divide( + ops.multiply(m_hat, lr), ops.add(ops.sqrt(v_hat), self.epsilon) + ), + ) + + def get_config(self): + config = super().get_config() + + config.update( + { + "beta_1": self.beta_1, + "beta_2": self.beta_2, + "epsilon": self.epsilon, + } + ) + return config + + +Nadam.__doc__ = Nadam.__doc__.replace( + "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args +) diff --git a/keras/src/optimizers/nadam_test.py b/keras/src/optimizers/nadam_test.py new file mode 100644 index 000000000000..b6d5f67c2ae3 --- /dev/null +++ b/keras/src/optimizers/nadam_test.py @@ -0,0 +1,95 @@ +# flake8: noqa + + +import numpy as np + +from keras.src import backend +from keras.src import ops +from keras.src import testing +from keras.src.optimizers.nadam import Nadam + + +class NadamTest(testing.TestCase): + def test_config(self): + optimizer = Nadam( + learning_rate=0.5, + beta_1=0.5, + beta_2=0.67, + epsilon=1e-5, + ) + self.run_class_serialization_test(optimizer) + + def test_build_with_empty_var_list(self): + optimizer = Nadam() + optimizer.build([]) + self.assertEqual(optimizer._u_product.dtype, backend.floatx()) + + def test_single_step(self): + optimizer = Nadam(learning_rate=0.5) + grads = ops.array([1.0, 6.0, 7.0, 2.0]) + vars = backend.Variable([1.0, 2.0, 3.0, 4.0]) + optimizer.apply_gradients(zip([grads], [vars])) + self.assertAllClose( + vars, [0.4686, 1.4686, 2.4686, 3.4686], rtol=1e-4, atol=1e-4 + ) + + def test_weight_decay(self): + grads, var1, var2, var3 = ( + ops.zeros(()), + backend.Variable(2.0), + backend.Variable(2.0, name="exclude"), + backend.Variable(2.0), + ) + optimizer_1 = Nadam(learning_rate=1.0, weight_decay=0.004) + optimizer_1.apply_gradients(zip([grads], [var1])) + + optimizer_2 = Nadam(learning_rate=1.0, weight_decay=0.004) + optimizer_2.exclude_from_weight_decay(var_names=["exclude"]) + optimizer_2.apply_gradients(zip([grads, grads], [var1, var2])) + + optimizer_3 = Nadam(learning_rate=1.0, weight_decay=0.004) + optimizer_3.exclude_from_weight_decay(var_list=[var3]) + optimizer_3.apply_gradients(zip([grads, grads], [var1, var3])) + + self.assertAlmostEqual(var1.numpy(), 1.9760959, decimal=6) + self.assertAlmostEqual(var2.numpy(), 2.0, decimal=6) + self.assertAlmostEqual(var3.numpy(), 2.0, decimal=6) + + def test_correctness_with_golden(self): + optimizer = Nadam( + learning_rate=0.5, + beta_1=0.5, + beta_2=0.67, + epsilon=1e-5, + ) + + x = backend.Variable(np.ones([10], dtype="float32")) + grads = ops.arange(0.1, 1.1, 0.1) + first_grads = ops.full((10,), 0.01) + + # fmt: off + golden = np.array( + [[0.4281, 0.4281, 0.4281, 0.4281, 0.4281, 0.4281, 0.4281, 0.4281, 0.4281, 0.4281], + [-0.1738, -0.1731, -0.1726, -0.1723, -0.1721, -0.172, -0.1719, -0.1718, -0.1718, -0.1717], + [-0.7115, -0.7103, -0.7096, -0.7092, -0.709, -0.7088, -0.7086, -0.7085, -0.7085, -0.7084], + [-1.2335, -1.2322, -1.2313, -1.2309, -1.2306, -1.2304, -1.2302, -1.2301, -1.23, -1.2299], + [-1.7492, -1.7478, -1.7469, -1.7464, -1.7461, -1.7459, -1.7457, -1.7456, -1.7455, -1.7454]] + ) + # fmt: on + + optimizer.apply_gradients(zip([first_grads], [x])) + for i in range(5): + self.assertAllClose(x, golden[i], rtol=5e-4, atol=5e-4) + optimizer.apply_gradients(zip([grads], [x])) + + def test_clip_norm(self): + optimizer = Nadam(clipnorm=1) + grad = [np.array([100.0, 100.0])] + clipped_grad = optimizer._clip_gradients(grad) + self.assertAllClose(clipped_grad[0], [2**0.5 / 2, 2**0.5 / 2]) + + def test_clip_value(self): + optimizer = Nadam(clipvalue=1) + grad = [np.array([100.0, 100.0])] + clipped_grad = optimizer._clip_gradients(grad) + self.assertAllClose(clipped_grad[0], [1.0, 1.0]) diff --git a/keras/src/optimizers/optimizer.py b/keras/src/optimizers/optimizer.py new file mode 100644 index 000000000000..c285b814ba74 --- /dev/null +++ b/keras/src/optimizers/optimizer.py @@ -0,0 +1,27 @@ +from keras.src import backend +from keras.src.api_export import keras_export +from keras.src.optimizers import base_optimizer + +if backend.backend() == "tensorflow": + from keras.src.backend.tensorflow.optimizer import ( + TFOptimizer as BackendOptimizer, + ) +elif backend.backend() == "torch": + from keras.src.backend.torch.optimizers import ( + TorchOptimizer as BackendOptimizer, + ) +elif backend.backend() == "jax": + from keras.src.backend.jax.optimizer import JaxOptimizer as BackendOptimizer +else: + + class BackendOptimizer(base_optimizer.BaseOptimizer): + pass + + +@keras_export(["keras.Optimizer", "keras.optimizers.Optimizer"]) +class Optimizer(BackendOptimizer, base_optimizer.BaseOptimizer): + pass + + +Optimizer.__doc__ = base_optimizer.BaseOptimizer.__doc__ +base_optimizer_keyword_args = base_optimizer.base_optimizer_keyword_args diff --git a/keras/src/optimizers/optimizer_sparse_test.py b/keras/src/optimizers/optimizer_sparse_test.py new file mode 100644 index 000000000000..1d1f73ebaa45 --- /dev/null +++ b/keras/src/optimizers/optimizer_sparse_test.py @@ -0,0 +1,310 @@ +from unittest import mock + +import pytest +from absl.testing import parameterized + +from keras.src import backend +from keras.src import ops +from keras.src import optimizers +from keras.src import testing + + +class ScatterUpdateOptimizer(optimizers.Optimizer): + def __init__(self): + super().__init__(learning_rate=0.001) + + def build(self, variables): + if self.built: + return + super().build(variables) + self.momentums = [ + self.add_variable_from_reference(v, name="momentum") + for v in variables + ] + + def update_step(self, grad, variable, learning_rate): + momentum = self.momentums[self._get_variable_index(variable)] + self.assign(momentum, ops.cast(grad, momentum.dtype)) + self.assign(variable, ops.cast(grad, variable.dtype)) + + +TEST_CASES = [ + { + "testcase_name": "adadelta", + "optimizer_class": optimizers.Adadelta, + "expect_model_sparse_variable_updates": True, + }, + { + "testcase_name": "adafactor", + "optimizer_class": optimizers.Adafactor, + "init_kwargs": {"clip_threshold": 0.5}, + "expect_model_sparse_variable_updates": True, + }, + { + "testcase_name": "adagrad", + "optimizer_class": optimizers.Adagrad, + "expect_model_sparse_variable_updates": True, + "expect_optimizer_sparse_variable_updates": True, + }, + { + "testcase_name": "adam", + "optimizer_class": optimizers.Adam, + }, + { + "testcase_name": "adam_amsgrad", + "optimizer_class": optimizers.Adam, + "init_kwargs": {"amsgrad": True}, + }, + { + "testcase_name": "adamax", + "optimizer_class": optimizers.Adamax, + }, + { + "testcase_name": "adamw", + "optimizer_class": optimizers.AdamW, + }, + { + "testcase_name": "adamw_amsgrad", + "optimizer_class": optimizers.AdamW, + "init_kwargs": {"amsgrad": True}, + }, + { + "testcase_name": "ftrl", + "optimizer_class": optimizers.Ftrl, + }, + { + "testcase_name": "lion", + "optimizer_class": optimizers.Lion, + }, + { + "testcase_name": "loss_scale_optimizer_sgd", + "optimizer_class": lambda: optimizers.LossScaleOptimizer( + optimizers.SGD(learning_rate=0.5) + ), + "expect_model_sparse_variable_updates": True, + }, + { + "testcase_name": "nadam", + "optimizer_class": optimizers.Nadam, + }, + { + "testcase_name": "rmsprop", + "optimizer_class": optimizers.RMSprop, + "expect_model_sparse_variable_updates": True, + }, + { + "testcase_name": "rmsprop_momentum", + "optimizer_class": optimizers.RMSprop, + "init_kwargs": {"momentum": 0.05}, + }, + { + "testcase_name": "rmsprop_momentum_centered", + "optimizer_class": optimizers.RMSprop, + "init_kwargs": {"momentum": 0.05, "centered": True}, + }, + { + "testcase_name": "sgd", + "optimizer_class": optimizers.SGD, + "expect_model_sparse_variable_updates": True, + }, + { + "testcase_name": "sgd_momentum", + "optimizer_class": optimizers.SGD, + "init_kwargs": {"momentum": 0.05}, + }, + { + "testcase_name": "sgd_momentum_nesterov", + "optimizer_class": optimizers.SGD, + "init_kwargs": {"momentum": 0.05, "nesterov": True}, + }, + { + "testcase_name": "scatter_update", + "optimizer_class": ScatterUpdateOptimizer, + "expect_model_sparse_variable_updates": True, + "expect_optimizer_sparse_variable_updates": True, + }, +] + + +@pytest.mark.skipif( + not backend.SUPPORTS_SPARSE_TENSORS, + reason="Backend does not support sparse tensors.", +) +class OptimizerSparseTest(testing.TestCase): + @parameterized.named_parameters(TEST_CASES) + def test_sparse_gradients( + self, + optimizer_class, + init_kwargs={}, + expect_model_sparse_variable_updates=False, + expect_optimizer_sparse_variable_updates=False, + ): + # This test verifies that: + # - Optimizers use Keras ops everywhere instead of native operators + # (e.g. `ops.add()` instead of `+`) where sparse gradients are handled + # - The used ops handle sparse gradients + # - Optimizers use `self.assign/assign_add/assign_sub` instead of + # calling the method on the variable directly. Otherwise, the sparse + # updates are densified before being applied. + # - For some optimizers, a sparse gradient actually results in a sparse + # variable update as per `expect_model_sparse_variable_updates` and + # `expect_optimizer_sparse_variable_updates` + + model_variable = backend.Variable(initializer="ones", shape=(5, 10)) + optimizer = optimizer_class(**init_kwargs) + + # Mocking "tensorflow.Variable" won't work as it gets substituted with + # the resource variable class. + + if backend.backend() == "tensorflow": + import tensorflow as tf + + grad = tf.IndexedSlices(0.5 * ops.ones((3, 10)), (0, 2, 4), (5, 10)) + sparse_class = tf.IndexedSlices + variable_class = model_variable._value.__class__ + elif backend.backend() == "jax": + import jax.experimental.sparse as jax_sparse + + grad = jax_sparse.BCOO( + (0.5 * ops.ones((3, 10)), ((0,), (2,), (4,))), shape=(5, 10) + ) + sparse_class = jax_sparse.JAXSparse + variable_class = model_variable.__class__ + else: + self.fail(f"Sparse is unsupported with backend {backend.backend()}") + + optimizer_to_patch = ( + optimizer.inner_optimizer + if isinstance(optimizer, optimizers.LossScaleOptimizer) + else optimizer + ) + + model_sparse_variable_updates = False + optimizer_sparse_variable_updates = False + + def mock_optimizer_assign(variable, value): + nonlocal model_sparse_variable_updates + nonlocal optimizer_sparse_variable_updates + if isinstance(variable, backend.Variable): + variable = variable._value + if isinstance(value, sparse_class): + if variable is model_variable._value: + model_sparse_variable_updates = True + elif any(variable is v._value for v in optimizer.variables): + optimizer_sparse_variable_updates = True + + def mock_variable_assign(variable, value): + # Make an exception for scalar variables + if len(variable.shape): + pytest.fail( + "Optimizer is calling `assign`, `assign_add` or " + "`assign_sub` directly on a variable. Use " + "`self.assign/assign_add/assign_sub(variable, value)` " + "instead to support sparse updates." + ) + + # patch "_apply_weight_decay" to exclude this special case. + # patch the optimizer "assign" methods to detect sparse updates. + # patch the tf.Variable "assign" methods to detect direct assign calls. + with ( + mock.patch.object( + optimizer_to_patch, "_apply_weight_decay", autospec=True + ), + mock.patch.object( + optimizer_to_patch, "assign", autospec=True + ) as optimizer_assign, + mock.patch.object( + optimizer_to_patch, "assign_add", autospec=True + ) as optimizer_assign_add, + mock.patch.object( + optimizer_to_patch, "assign_sub", autospec=True + ) as optimizer_assign_sub, + mock.patch.object( + variable_class, "assign", autospec=True + ) as variable_assign, + mock.patch.object( + variable_class, "assign_add", autospec=True + ) as variable_assign_add, + mock.patch.object( + variable_class, "assign_sub", autospec=True + ) as variable_assign_sub, + ): + optimizer_assign.side_effect = mock_optimizer_assign + optimizer_assign_add.side_effect = mock_optimizer_assign + optimizer_assign_sub.side_effect = mock_optimizer_assign + variable_assign.side_effect = mock_variable_assign + variable_assign_add.side_effect = mock_variable_assign + variable_assign_sub.side_effect = mock_variable_assign + + optimizer.apply([grad], [model_variable]) + + self.assertEqual( + model_sparse_variable_updates, expect_model_sparse_variable_updates + ) + self.assertEqual( + optimizer_sparse_variable_updates, + expect_optimizer_sparse_variable_updates, + ) + + @parameterized.named_parameters(TEST_CASES) + def test_sparse_correctness( + self, optimizer_class, init_kwargs={}, **kwargs + ): + # This test verifies that applying a sparse gradient gives the same + # numerical results as the same dense gradient. + + optimizer_sparse = optimizer_class(**init_kwargs) + optimizer_dense = optimizer_class(**init_kwargs) + var_sparse = backend.Variable(initializer="ones", shape=(5, 3, 2)) + var_dense = backend.Variable(initializer="ones", shape=(5, 3, 2)) + stateless = backend.backend() == "jax" + if stateless: + optimizer_sparse.build([var_sparse]) + optimizer_dense.build([var_dense]) + + optimizer_sparse_vars = optimizer_sparse.variables + optimizer_dense_vars = optimizer_dense.variables + var_sparse_values = [var_sparse.value] + var_dense_values = [var_dense.value] + + for i in range(5): + if backend.backend() == "tensorflow": + import tensorflow as tf + + grad_sparse = tf.IndexedSlices( + values=ops.ones((3, 3, 2)) * (10.0 - i), + indices=(0, 2, 4), + dense_shape=(5, 3, 2), + ) + elif backend.backend() == "jax": + import jax.experimental.sparse as jax_sparse + + grad_sparse = jax_sparse.BCOO( + (ops.ones((3, 3, 2)) * (10.0 - i), ((0,), (2,), (4,))), + shape=(5, 3, 2), + ) + else: + self.fail( + f"Sparse is unsupported with backend {backend.backend()}" + ) + + grad_dense = ops.convert_to_tensor(grad_sparse, sparse=False) + if stateless: + ( + var_sparse_values, + optimizer_sparse_vars, + ) = optimizer_sparse.stateless_apply( + optimizer_sparse_vars, [grad_sparse], var_sparse_values + ) + ( + var_dense_values, + optimizer_dense_vars, + ) = optimizer_dense.stateless_apply( + optimizer_dense_vars, [grad_dense], var_dense_values + ) + self.assertAllClose(var_sparse_values[0], var_dense_values[0]) + + else: + optimizer_sparse.apply([grad_sparse], [var_sparse]) + optimizer_dense.apply([grad_dense], [var_dense]) + self.assertAllClose(var_sparse.value, var_dense.value) diff --git a/keras/src/optimizers/optimizer_test.py b/keras/src/optimizers/optimizer_test.py new file mode 100644 index 000000000000..7d661df9a3c0 --- /dev/null +++ b/keras/src/optimizers/optimizer_test.py @@ -0,0 +1,425 @@ +import os +import pickle + +import numpy as np +import pytest +from absl.testing import parameterized + +from keras.src import backend +from keras.src import constraints +from keras.src import layers +from keras.src import models +from keras.src import optimizers +from keras.src import testing + + +class OptimizerTest(testing.TestCase): + def test_iterations_counter(self): + v = backend.Variable([[1.0, 2.0], [3.0, 4.0]]) + grads = backend.convert_to_tensor([[1.0, 1.0], [1.0, 1.0]]) + optimizer = optimizers.Adam(learning_rate=1.0) + self.assertAllClose(optimizer.iterations, 0) + optimizer.apply_gradients([(grads, v)]) + self.assertAllClose(optimizer.iterations, 1) + optimizer.apply_gradients([(grads, v)]) + self.assertAllClose(optimizer.iterations, 2) + + def test_empty_gradients(self): + # Test no valid gradient + v = backend.Variable([[3.0, 4.0], [5.0, 6.0]]) + grads = None + optimizer = optimizers.SGD(learning_rate=1.0) + with self.assertRaisesRegex( + ValueError, "No gradients provided for any variable." + ): + optimizer.apply_gradients([(grads, v)]) + + # Test filtering of empty gradients + v2 = backend.Variable([[3.0, 4.0], [5.0, 6.0]]) + grads2 = backend.convert_to_tensor([[1.0, 1.0], [1.0, 1.0]]) + optimizer = optimizers.SGD(learning_rate=1.0) + with self.assertWarns(Warning): + optimizer.apply_gradients([(grads, v), (grads2, v2)]) + self.assertAllClose(v, [[3.0, 4.0], [5.0, 6.0]]) + self.assertAllClose(v2, [[2.0, 3.0], [4.0, 5.0]]) + + def test_clip_args(self): + optimizer = optimizers.SGD(learning_rate=1.0, clipnorm=0.1) + self.assertEqual(optimizer.clipnorm, 0.1) + optimizer = optimizers.SGD(learning_rate=1.0, clipvalue=0.1) + self.assertEqual(optimizer.clipvalue, 0.1) + optimizer = optimizers.SGD(learning_rate=1.0, global_clipnorm=0.1) + self.assertEqual(optimizer.global_clipnorm, 0.1) + + # Test invalid arguments + with self.assertRaisesRegex( + ValueError, + "Only one of `clipnorm`, `clipvalue` and `global_clipnorm` can", + ): + optimizers.SGD( + learning_rate=1.0, + clipnorm=0.1, + clipvalue=0.1, + ) + with self.assertRaisesRegex( + ValueError, + "Only one of `clipnorm`, `clipvalue` and `global_clipnorm` can", + ): + optimizers.SGD( + learning_rate=1.0, + clipnorm=0.1, + global_clipnorm=0.1, + ) + + def test_clip_norm(self): + optimizer = optimizers.SGD(clipnorm=1) + grad = backend.convert_to_tensor([100.0, 100.0]) + clipped_grad = optimizer._clip_gradients([grad]) + self.assertAllClose(clipped_grad[0], [2**0.5 / 2, 2**0.5 / 2]) + + def test_clip_value(self): + optimizer = optimizers.SGD(clipvalue=1) + grad = backend.convert_to_tensor([100.0, 100.0]) + clipped_grad = optimizer._clip_gradients([grad]) + self.assertAllClose(clipped_grad[0], [1.0, 1.0]) + + def test_global_clip_norm(self): + optimizer = optimizers.SGD(global_clipnorm=1) + grad = np.array([50.0, 100.0], dtype="float32") + global_norm = np.linalg.norm(grad) + clipped_grad = optimizer._clip_gradients( + [backend.convert_to_tensor(grad)] + ) + self.assertAllClose(clipped_grad[0], grad / global_norm) + + def test_ema(self): + v = backend.Variable([[3.0, 4.0], [5.0, 6.0]]) + grads = backend.convert_to_tensor([[1.0, 1.0], [1.0, 1.0]]) + optimizer = optimizers.SGD( + learning_rate=1.0, + use_ema=True, + ema_momentum=0.9, + ema_overwrite_frequency=3, + ) + optimizer.apply_gradients([(grads, v)]) + self.assertAllClose(v, [[2.0, 3.0], [4.0, 5.0]]) + self.assertAllClose( + optimizer._model_variables_moving_average[0], + [[2.0, 3.0], [4.0, 5.0]], # initialized after first step + ) + optimizer.apply_gradients([(grads, v)]) + self.assertAllClose(v, [[1.0, 2.0], [3.0, 4.0]]) + self.assertAllClose( + optimizer._model_variables_moving_average[0], + [[1.9, 2.9], [3.9, 4.9]], + ) + optimizer.apply_gradients([(grads, v)]) + # Variables were overwritten with EMA + self.assertAllClose(v, [[1.71, 2.71], [3.71, 4.71]]) + self.assertAllClose( + optimizer._model_variables_moving_average[0], + [[1.71, 2.71], [3.71, 4.71]], + ) + + @pytest.mark.requires_trainable_backend + def test_ema_with_model_fit(self): + x_train = np.ones((1, 1)).astype("float32") + y_train = np.zeros((1, 1)).astype("float32") + optimizer = optimizers.SGD( + learning_rate=0.1, use_ema=True, ema_momentum=0.9 + ) + model = models.Sequential( + [layers.Dense(2, kernel_initializer="ones", use_bias=False)] + ) + model.compile(loss="mse", optimizer=optimizer, run_eagerly=True) + model.fit(x_train, y_train, batch_size=1, epochs=2) + self.assertAllClose( + optimizer._model_variables_moving_average[0].numpy(), + [[0.891, 0.891]], + atol=1e-5, + ) + self.assertAllClose( + model.trainable_variables[0].numpy(), + [[0.891, 0.891]], + atol=1e-5, + ) + + def test_constraints_are_applied(self): + v = backend.Variable(np.random.random((2, 2)) - 1.0) + v.constraint = constraints.NonNeg() + optimizer = optimizers.SGD(learning_rate=0.0001) + grad = backend.numpy.zeros((2, 2)) + optimizer.apply_gradients([(grad, v)]) + self.assertAlmostEqual(np.min(v), 0.0) + + def test_get_method(self): + obj = optimizers.get("sgd") + self.assertIsInstance(obj, optimizers.SGD) + obj = optimizers.get("adamw") + self.assertIsInstance(obj, optimizers.AdamW) + + obj = optimizers.get(None) + self.assertEqual(obj, None) + + with self.assertRaises(ValueError): + optimizers.get("typo") + + def test_static_loss_scaling(self): + v = backend.Variable([[1.0, 2.0], [3.0, 4.0]]) + grads = backend.convert_to_tensor([[1.0, 2.0], [3.0, 4.0]]) * 1024.0 + optimizer = optimizers.SGD(learning_rate=1.0, loss_scale_factor=1024.0) + optimizer.apply_gradients([(grads, v)]) + self.assertEqual(optimizer.scale_loss(1.0), 1024.0) + self.assertAllClose(v, [[0.0, 0.0], [0.0, 0.0]]) + + def test_set_weights(self): + x = backend.Variable([[1.0, 2.0], [3.0, 4.0]]) + optimizer_1 = optimizers.Adam() + grads = backend.convert_to_tensor([[1.0, 2.0], [3.0, 4.0]]) + optimizer_1.apply_gradients(zip([grads], [x])) + optimizer_2 = optimizers.Adam() + with self.assertRaisesRegex(ValueError, "You are calling*"): + optimizer_2.set_weights(optimizer_1.variables) + optimizer_2.build([x]) + optimizer_2.set_weights(optimizer_1.variables) + for i in range(len(optimizer_1.variables)): + self.assertAllClose( + optimizer_1.variables[i], + optimizer_2.variables[i], + ) + + def test_gradient_accumulation(self): + v = backend.Variable([[1.0, 2.0], [3.0, 4.0]]) + grads = backend.convert_to_tensor([[1.0, 1.0], [1.0, 1.0]]) + optimizer = optimizers.SGD( + learning_rate=1.0, gradient_accumulation_steps=3 + ) + self.assertEqual(optimizer.gradient_accumulation_steps, 3) + + # Iteration 1 + optimizer.apply_gradients([(grads, v)]) + self.assertAllClose(v, [[1.0, 2.0], [3.0, 4.0]]) + self.assertAllClose( + optimizer._accumulated_gradients[0], [[1.0, 1.0], [1.0, 1.0]] + ) + self.assertAllClose(optimizer._iterations, 1) + self.assertAllClose(optimizer.iterations, 0) + + # Iteration 2 + optimizer.apply_gradients([(grads, v)]) + self.assertAllClose(v, [[1.0, 2.0], [3.0, 4.0]]) + self.assertAllClose( + optimizer._accumulated_gradients[0], [[2.0, 2.0], [2.0, 2.0]] + ) + self.assertAllClose(optimizer._iterations, 2) + self.assertAllClose(optimizer.iterations, 0) + + # Iteration 3 + optimizer.apply_gradients([(grads, v)]) + self.assertAllClose(v, [[0.0, 1.0], [2.0, 3.0]]) + self.assertAllClose( + optimizer._accumulated_gradients[0], [[0.0, 0.0], [0.0, 0.0]] + ) + self.assertAllClose(optimizer._iterations, 3) + self.assertAllClose(optimizer.iterations, 1) + + # Iteration 4 + optimizer.apply_gradients([(grads, v)]) + self.assertAllClose(v, [[0.0, 1.0], [2.0, 3.0]]) + self.assertAllClose( + optimizer._accumulated_gradients[0], [[1.0, 1.0], [1.0, 1.0]] + ) + self.assertAllClose(optimizer._iterations, 4) + self.assertAllClose(optimizer.iterations, 1) + + @pytest.mark.skipif(backend.backend() != "tensorflow", reason="Requires TF") + def test_tf_checkpointing(self): + import tensorflow as tf + + model = models.Sequential([layers.Dense(2)]) + optimizer = optimizers.Adam() + x, y = np.random.random((1, 2)), np.random.random((1, 2)) + model.compile(optimizer, "mse") + model.train_on_batch(x, y) + ref_pred = model.predict(x) + + # Both model and optimizer are Trackables + checkpoint = tf.train.Checkpoint(model, optimizer=optimizer) + temp_filepath = os.path.join(self.get_temp_dir(), "tf_ckpt") + save_path = checkpoint.save(temp_filepath) + + # Keep training the model (predictions now differ) + model.train_on_batch(x, y) + pred = model.predict(x) + self.assertNotAllClose(pred, ref_pred, atol=1e-3) + + # Restore the model and check prediction correctness + checkpoint.restore(save_path) + pred = model.predict(x) + self.assertAllClose(pred, ref_pred, atol=1e-5) + + def test_callable_learning_rate(self): + v = backend.Variable([[1.0, 2.0], [3.0, 4.0]]) + grads = backend.convert_to_tensor([[1.0, 1.0], [1.0, 1.0]]) + optimizer = optimizers.SGD(learning_rate=lambda: 0.1) + self.assertAllClose(optimizer.iterations, 0) + optimizer.apply_gradients([(grads, v)]) + self.assertAllClose(v, [[0.9, 1.9], [2.9, 3.9]]) + self.assertAllClose(optimizer.iterations, 1) + + def test_overwrite_with_gradient(self): + v = backend.Variable([[1.0, 2.0], [3.0, 4.0]]) + v.overwrite_with_gradient = True + v2 = backend.Variable([[1.0, 2.0], [3.0, 4.0]]) + grads = backend.convert_to_tensor([[1.0, 1.0], [1.0, 1.0]]) + grads2 = backend.convert_to_tensor([[1.0, 1.0], [1.0, 1.0]]) + + optimizer = optimizers.SGD(learning_rate=1.0) + optimizer.apply_gradients([(grads, v), (grads2, v2)]) + + # `v` is overwritten by its gradient but `v2` is updated normally + self.assertAllClose(v, [[1.0, 1.0], [1.0, 1.0]]) + self.assertAllClose(v2, [[0.0, 1.0], [2.0, 3.0]]) + + def test_overwrite_with_gradient_with_gradient_accumulation(self): + v = backend.Variable([[1.0, 2.0], [3.0, 4.0]]) + v.overwrite_with_gradient = True + v2 = backend.Variable([[1.0, 2.0], [3.0, 4.0]]) + grad_ones = backend.convert_to_tensor([[1.0, 1.0], [1.0, 1.0]]) + grad_twos = backend.convert_to_tensor([[2.0, 2.0], [2.0, 2.0]]) + optimizer = optimizers.SGD( + learning_rate=1.0, gradient_accumulation_steps=2 + ) + + # Iteration 1 + optimizer.apply_gradients([(grad_ones, v), (grad_ones, v2)]) + self.assertAllClose(optimizer._iterations, 1) + self.assertAllClose(optimizer.iterations, 0) + self.assertAllClose(v, [[1.0, 2.0], [3.0, 4.0]]) + self.assertAllClose(v2, [[1.0, 2.0], [3.0, 4.0]]) + self.assertAllClose( + optimizer._accumulated_gradients[0], [[1.0, 1.0], [1.0, 1.0]] + ) + self.assertAllClose( + optimizer._accumulated_gradients[1], [[1.0, 1.0], [1.0, 1.0]] + ) + + # Iteration 2 + optimizer.apply_gradients([(grad_twos, v), (grad_twos, v2)]) + self.assertAllClose(optimizer._iterations, 2) + self.assertAllClose(optimizer.iterations, 1) + self.assertAllClose(v, [[2.0, 2.0], [2.0, 2.0]]) + self.assertAllClose(v2, [[-0.5, 0.5], [1.5, 2.5]]) + self.assertAllClose( + optimizer._accumulated_gradients[0], [[0.0, 0.0], [0.0, 0.0]] + ) + self.assertAllClose( + optimizer._accumulated_gradients[1], [[0.0, 0.0], [0.0, 0.0]] + ) + + # Iteration 3 + optimizer.apply_gradients([(grad_ones, v), (grad_ones, v2)]) + self.assertAllClose(optimizer._iterations, 3) + self.assertAllClose(optimizer.iterations, 1) + self.assertAllClose(v, [[2.0, 2.0], [2.0, 2.0]]) + self.assertAllClose(v2, [[-0.5, 0.5], [1.5, 2.5]]) + self.assertAllClose( + optimizer._accumulated_gradients[0], [[1.0, 1.0], [1.0, 1.0]] + ) + self.assertAllClose( + optimizer._accumulated_gradients[1], [[1.0, 1.0], [1.0, 1.0]] + ) + + @parameterized.parameters( + [ + ("adam",), + ("sgd",), + ("adamw",), + ("adagrad",), + ("rmsprop",), + ("adadelta",), + ("adamax",), + ("lion",), + ("nadam",), + ("ftrl",), + ("adafactor",), + ] + ) + def test_gradient_accumulation_with_weigth_decay(self, optimizer): + optimizer1 = optimizers.get( + {"class_name": optimizer, "config": {"weight_decay": 0.05}} + ) + optimizer3 = optimizers.get( + { + "class_name": optimizer, + "config": { + "weight_decay": 0.05, + "gradient_accumulation_steps": 3, + }, + } + ) + variable1 = backend.Variable([[0.9], [0.5]]) + variable3 = backend.Variable([[0.9], [0.5]]) + + for epoch in range(8): + grads3 = np.random.random([3, 2, 1]).astype("float32") + + grads1 = backend.convert_to_tensor(grads3.mean(axis=0)) + optimizer1.apply_gradients([(grads1, variable1)]) + + for batch in range(3): + grads3_ = backend.convert_to_tensor(grads3[batch]) + optimizer3.apply_gradients([(grads3_, variable3)]) + + self.assertAllClose(variable1, variable3) + + def test_setting_lr_to_callable_untracks_lr_var(self): + adam = optimizers.Adam(learning_rate=0.001) + self.assertLen(adam.variables, 2) + adam.learning_rate = optimizers.schedules.PolynomialDecay( + adam.learning_rate, 4 + ) + self.assertLen(adam.variables, 1) + + @parameterized.parameters( + [ + ("adam",), + ("sgd",), + ("adamw",), + ("adagrad",), + ("rmsprop",), + ("adadelta",), + ("adamax",), + ("lion",), + ("nadam",), + ("ftrl",), + ("adafactor",), + ] + ) + def test_pickleable_optimizers(self, optimizer): + optimizer = optimizers.get(optimizer) + reloaded = pickle.loads(pickle.dumps(optimizer)) + + self.assertEqual(optimizer.get_config(), reloaded.get_config()) + + @pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="The tf.Variable test can only run with TensorFlow backend.", + ) + def test_mixed_with_tf_variables(self): + import tensorflow as tf + + v = backend.Variable([[1.0, 2.0], [3.0, 4.0]]) + grads = backend.convert_to_tensor([[1.0, 1.0], [1.0, 1.0]]) + tf_v = tf.Variable([[1.0, 2.0], [3.0, 4.0]]) + tf_grads = backend.convert_to_tensor([[1.0, 1.0], [1.0, 1.0]]) + optimizer = optimizers.Adam(learning_rate=1.0) + optimizer.apply_gradients([(grads, v), (tf_grads, tf_v)]) + self.assertAllClose(optimizer.iterations, 1) + + # Test with no grads + with self.assertWarnsRegex( + UserWarning, "Gradients do not exist for variables" + ): + optimizer.apply_gradients([(grads, v), (None, tf_v)]) + self.assertAllClose(optimizer.iterations, 2) diff --git a/keras/src/optimizers/rmsprop.py b/keras/src/optimizers/rmsprop.py new file mode 100644 index 000000000000..b32b5b61d6b9 --- /dev/null +++ b/keras/src/optimizers/rmsprop.py @@ -0,0 +1,172 @@ +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.optimizers import optimizer + + +@keras_export(["keras.optimizers.RMSprop"]) +class RMSprop(optimizer.Optimizer): + """Optimizer that implements the RMSprop algorithm. + + The gist of RMSprop is to: + + - Maintain a moving (discounted) average of the square of gradients + - Divide the gradient by the root of this average + + This implementation of RMSprop uses plain momentum, not Nesterov momentum. + + The centered version additionally maintains a moving average of the + gradients, and uses that average to estimate the variance. + + Args: + learning_rate: A float, a + `keras.optimizers.schedules.LearningRateSchedule` instance, or + a callable that takes no arguments and returns the actual value to + use. The learning rate. Defaults to `0.001`. + rho: float, defaults to 0.9. Discounting factor for the old gradients. + momentum: float, defaults to 0.0. If not 0.0., the optimizer tracks the + momentum value, with a decay rate equals to `1 - momentum`. + epsilon: A small constant for numerical stability. This epsilon is + "epsilon hat" in the Kingma and Ba paper (in the formula just before + Section 2.1), not the epsilon in Algorithm 1 of the paper. Defaults + to 1e-7. + centered: Boolean. If `True`, gradients are normalized by the estimated + variance of the gradient; if False, by the uncentered second moment. + Setting this to `True` may help with training, but is slightly more + expensive in terms of computation and memory. Defaults to `False`. + {{base_optimizer_keyword_args}} + + Example: + + >>> opt = keras.optimizers.RMSprop(learning_rate=0.1) + >>> var1 = keras.backend.Variable(10.0) + >>> loss = lambda: (var1 ** 2) / 2.0 # d(loss) / d(var1) = var1 + >>> opt.minimize(loss, [var1]) + >>> var1 + 9.683772 + + Reference: + + - [Hinton, 2012]( + http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) + """ + + def __init__( + self, + learning_rate=0.001, + rho=0.9, + momentum=0.0, + epsilon=1e-7, + centered=False, + weight_decay=None, + clipnorm=None, + clipvalue=None, + global_clipnorm=None, + use_ema=False, + ema_momentum=0.99, + ema_overwrite_frequency=None, + loss_scale_factor=None, + gradient_accumulation_steps=None, + name="rmsprop", + **kwargs, + ): + super().__init__( + learning_rate=learning_rate, + weight_decay=weight_decay, + clipnorm=clipnorm, + clipvalue=clipvalue, + global_clipnorm=global_clipnorm, + use_ema=use_ema, + ema_momentum=ema_momentum, + ema_overwrite_frequency=ema_overwrite_frequency, + loss_scale_factor=loss_scale_factor, + gradient_accumulation_steps=gradient_accumulation_steps, + name=name, + **kwargs, + ) + self.rho = rho + self.momentum = momentum + self.epsilon = epsilon + self.centered = centered + + def build(self, var_list): + if self.built: + return + + super().build(var_list) + + self._velocities = self.add_optimizer_variables(var_list, "velocity") + + self._momentums = [] + if self.momentum > 0: + self._momentums = self.add_optimizer_variables(var_list, "momentum") + + self._average_gradients = [] + if self.centered: + self._average_gradients = self.add_optimizer_variables( + var_list, "average_gradient" + ) + + def update_step(self, gradient, variable, learning_rate): + """Update step given gradient and the associated model variable.""" + lr = ops.cast(learning_rate, variable.dtype) + gradient = ops.cast(gradient, variable.dtype) + + velocity = self._velocities[self._get_variable_index(variable)] + momentum = None + if self.momentum > 0: + momentum = self._momentums[self._get_variable_index(variable)] + average_grad = None + if self.centered: + average_grad = self._average_gradients[ + self._get_variable_index(variable) + ] + + rho = self.rho + + self.assign( + velocity, + ops.add( + ops.multiply(rho, velocity), + ops.multiply(1 - rho, ops.square(gradient)), + ), + ) + if self.centered: + self.assign( + average_grad, + ops.add( + ops.multiply(rho, average_grad), + ops.multiply(1 - rho, gradient), + ), + ) + denominator = velocity - ops.square(average_grad) + self.epsilon + else: + denominator = ops.add(velocity, self.epsilon) + increment = ops.divide( + ops.multiply(lr, gradient), ops.sqrt(denominator) + ) + if self.momentum > 0: + self.assign( + momentum, + ops.add(ops.multiply(self.momentum, momentum), increment), + ) + self.assign_sub(variable, momentum) + else: + self.assign_sub(variable, increment) + + def get_config(self): + config = super().get_config() + + config.update( + { + "rho": self.rho, + "momentum": self.momentum, + "epsilon": self.epsilon, + "centered": self.centered, + } + ) + return config + + +RMSprop.__doc__ = RMSprop.__doc__.replace( + "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args +) diff --git a/keras/src/optimizers/rmsprop_test.py b/keras/src/optimizers/rmsprop_test.py new file mode 100644 index 000000000000..f22dc82801bc --- /dev/null +++ b/keras/src/optimizers/rmsprop_test.py @@ -0,0 +1,77 @@ +import numpy as np + +from keras.src import backend +from keras.src import ops +from keras.src import testing +from keras.src.optimizers.rmsprop import RMSprop + + +class RMSpropTest(testing.TestCase): + def test_config(self): + optimizer = RMSprop( + learning_rate=0.5, + rho=0.8, + momentum=0.05, + epsilon=1e-6, + centered=True, + ) + self.run_class_serialization_test(optimizer) + + def test_single_step(self): + optimizer = RMSprop(learning_rate=0.5) + grads = ops.array([1.0, 6.0, 7.0, 2.0]) + vars = backend.Variable([1.0, 2.0, 3.0, 4.0]) + optimizer.apply_gradients(zip([grads], [vars])) + self.assertAllClose( + vars, [-0.5811, 0.4189, 1.4189, 2.4189], rtol=1e-4, atol=1e-4 + ) + + def test_weight_decay(self): + grads, var1, var2, var3 = ( + ops.zeros(()), + backend.Variable(2.0), + backend.Variable(2.0, name="exclude"), + backend.Variable(2.0), + ) + optimizer_1 = RMSprop(learning_rate=1.0, weight_decay=0.004) + optimizer_1.apply_gradients(zip([grads], [var1])) + + optimizer_2 = RMSprop(learning_rate=1.0, weight_decay=0.004) + optimizer_2.exclude_from_weight_decay(var_names=["exclude"]) + optimizer_2.apply_gradients(zip([grads, grads], [var1, var2])) + + optimizer_3 = RMSprop(learning_rate=1.0, weight_decay=0.004) + optimizer_3.exclude_from_weight_decay(var_list=[var3]) + optimizer_3.apply_gradients(zip([grads, grads], [var1, var3])) + + self.assertAlmostEqual(var1.numpy(), 1.9760959, decimal=6) + self.assertAlmostEqual(var2.numpy(), 2.0, decimal=6) + self.assertAlmostEqual(var3.numpy(), 2.0, decimal=6) + + def test_correctness_with_golden(self): + optimizer = RMSprop(centered=True) + + x = backend.Variable(np.ones([10])) + grads = ops.arange(0.1, 1.1, 0.1) + first_grads = ops.full((10,), 0.01) + + golden = np.tile( + [[0.9967], [0.9933], [0.9908], [0.9885], [0.9864]], (1, 10) + ) + + optimizer.apply_gradients(zip([first_grads], [x])) + for i in range(5): + self.assertAllClose(x, golden[i], rtol=5e-4, atol=5e-4) + optimizer.apply_gradients(zip([grads], [x])) + + def test_clip_norm(self): + optimizer = RMSprop(clipnorm=1) + grad = [np.array([100.0, 100.0])] + clipped_grad = optimizer._clip_gradients(grad) + self.assertAllClose(clipped_grad[0], [2**0.5 / 2, 2**0.5 / 2]) + + def test_clip_value(self): + optimizer = RMSprop(clipvalue=1) + grad = [np.array([100.0, 100.0])] + clipped_grad = optimizer._clip_gradients(grad) + self.assertAllClose(clipped_grad[0], [1.0, 1.0]) diff --git a/keras/src/optimizers/schedules/__init__.py b/keras/src/optimizers/schedules/__init__.py new file mode 100644 index 000000000000..a6812ebb0827 --- /dev/null +++ b/keras/src/optimizers/schedules/__init__.py @@ -0,0 +1,16 @@ +from keras.src.optimizers.schedules.learning_rate_schedule import CosineDecay +from keras.src.optimizers.schedules.learning_rate_schedule import ( + CosineDecayRestarts, +) +from keras.src.optimizers.schedules.learning_rate_schedule import ( + ExponentialDecay, +) +from keras.src.optimizers.schedules.learning_rate_schedule import ( + InverseTimeDecay, +) +from keras.src.optimizers.schedules.learning_rate_schedule import ( + PiecewiseConstantDecay, +) +from keras.src.optimizers.schedules.learning_rate_schedule import ( + PolynomialDecay, +) diff --git a/keras/src/optimizers/schedules/learning_rate_schedule.py b/keras/src/optimizers/schedules/learning_rate_schedule.py new file mode 100644 index 000000000000..9f2df3398dfe --- /dev/null +++ b/keras/src/optimizers/schedules/learning_rate_schedule.py @@ -0,0 +1,976 @@ +"""Various learning rate schedule functions.""" + +import math + +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.saving import serialization_lib + + +@keras_export("keras.optimizers.schedules.LearningRateSchedule") +class LearningRateSchedule: + """The learning rate schedule base class. + + You can use a learning rate schedule to modulate how the learning rate + of your optimizer changes over time. + + Several built-in learning rate schedules are available, such as + `keras.optimizers.schedules.ExponentialDecay` or + `keras.optimizers.schedules.PiecewiseConstantDecay`: + + ```python + lr_schedule = keras.optimizers.schedules.ExponentialDecay( + initial_learning_rate=1e-2, + decay_steps=10000, + decay_rate=0.9) + optimizer = keras.optimizers.SGD(learning_rate=lr_schedule) + ``` + + A `LearningRateSchedule` instance can be passed in as the `learning_rate` + argument of any optimizer. + + To implement your own schedule object, you should implement the `__call__` + method, which takes a `step` argument (scalar integer tensor, the + current training step count). + Like for any other Keras object, you can also optionally + make your object serializable by implementing the `get_config` + and `from_config` methods. + + Example: + + ```python + class MyLRSchedule(keras.optimizers.schedules.LearningRateSchedule): + + def __init__(self, initial_learning_rate): + self.initial_learning_rate = initial_learning_rate + + def __call__(self, step): + return self.initial_learning_rate / (step + 1) + + optimizer = keras.optimizers.SGD(learning_rate=MyLRSchedule(0.1)) + ``` + """ + + def __call__(self, step): + raise NotImplementedError( + f"Learning rate schedule '{self.__class__.__name__}' " + "must override `__call__(self, step)`." + ) + + def get_config(self): + raise NotImplementedError( + f"Learning rate schedule '{self.__class__.__name__}' " + "must override `get_config()` in order to be serializable." + ) + + @classmethod + def from_config(cls, config): + """Instantiates a `LearningRateSchedule` from its config. + + Args: + config: Output of `get_config()`. + + Returns: + A `LearningRateSchedule` instance. + """ + return cls(**config) + + +@keras_export("keras.optimizers.schedules.ExponentialDecay") +class ExponentialDecay(LearningRateSchedule): + """A `LearningRateSchedule` that uses an exponential decay schedule. + + When training a model, it is often useful to lower the learning rate as + the training progresses. This schedule applies an exponential decay function + to an optimizer step, given a provided initial learning rate. + + The schedule is a 1-arg callable that produces a decayed learning + rate when passed the current optimizer step. This can be useful for changing + the learning rate value across different invocations of optimizer functions. + It is computed as: + + ```python + def decayed_learning_rate(step): + return initial_learning_rate * decay_rate ^ (step / decay_steps) + ``` + + If the argument `staircase` is `True`, then `step / decay_steps` is + an integer division and the decayed learning rate follows a + staircase function. + + You can pass this schedule directly into a `keras.optimizers.Optimizer` + as the learning rate. + Example: When fitting a Keras model, decay every 100000 steps with a base + of 0.96: + + ```python + initial_learning_rate = 0.1 + lr_schedule = keras.optimizers.schedules.ExponentialDecay( + initial_learning_rate, + decay_steps=100000, + decay_rate=0.96, + staircase=True) + + model.compile(optimizer=keras.optimizers.SGD(learning_rate=lr_schedule), + loss='sparse_categorical_crossentropy', + metrics=['accuracy']) + + model.fit(data, labels, epochs=5) + ``` + + The learning rate schedule is also serializable and deserializable using + `keras.optimizers.schedules.serialize` and + `keras.optimizers.schedules.deserialize`. + + Args: + initial_learning_rate: A Python float. The initial learning rate. + decay_steps: A Python integer. Must be positive. See the decay + computation above. + decay_rate: A Python float. The decay rate. + staircase: Boolean. If `True` decay the learning rate at discrete + intervals. + name: String. Optional name of the operation. Defaults to + `"ExponentialDecay`". + + Returns: + A 1-arg callable learning rate schedule that takes the current optimizer + step and outputs the decayed learning rate, a scalar tensor of the + same type as `initial_learning_rate`. + """ + + def __init__( + self, + initial_learning_rate, + decay_steps, + decay_rate, + staircase=False, + name="ExponentialDecay", + ): + super().__init__() + self.initial_learning_rate = initial_learning_rate + self.decay_steps = decay_steps + self.decay_rate = decay_rate + self.staircase = staircase + self.name = name + + if self.decay_steps <= 0: + raise ValueError( + "Argument `decay_steps` must be > 0. " + f"Received: decay_steps={self.decay_steps}" + ) + + def __call__(self, step): + with ops.name_scope(self.name): + initial_learning_rate = ops.convert_to_tensor( + self.initial_learning_rate + ) + dtype = initial_learning_rate.dtype + decay_steps = ops.cast(self.decay_steps, dtype) + decay_rate = ops.cast(self.decay_rate, dtype) + + global_step_recomp = ops.cast(step, dtype) + p = global_step_recomp / decay_steps + if self.staircase: + p = ops.floor(p) + return ops.multiply(initial_learning_rate, ops.power(decay_rate, p)) + + def get_config(self): + return { + "initial_learning_rate": self.initial_learning_rate, + "decay_steps": self.decay_steps, + "decay_rate": self.decay_rate, + "staircase": self.staircase, + "name": self.name, + } + + +@keras_export("keras.optimizers.schedules.PiecewiseConstantDecay") +class PiecewiseConstantDecay(LearningRateSchedule): + """A `LearningRateSchedule` that uses a piecewise constant decay schedule. + + The function returns a 1-arg callable to compute the piecewise constant + when passed the current optimizer step. This can be useful for changing the + learning rate value across different invocations of optimizer functions. + + Example: use a learning rate that's 1.0 for the first 100001 steps, 0.5 + for the next 10000 steps, and 0.1 for any additional steps. + + ```python + step = ops.array(0) + boundaries = [100000, 110000] + values = [1.0, 0.5, 0.1] + learning_rate_fn = keras.optimizers.schedules.PiecewiseConstantDecay( + boundaries, values) + + # Later, whenever we perform an optimization step, we pass in the step. + learning_rate = learning_rate_fn(step) + ``` + + You can pass this schedule directly into a `keras.optimizers.Optimizer` + as the learning rate. The learning rate schedule is also serializable and + deserializable using `keras.optimizers.schedules.serialize` and + `keras.optimizers.schedules.deserialize`. + + Args: + boundaries: A list of Python numbers with strictly increasing + entries, and with all elements having the same type as the + optimizer step. + values: A list of Python numbers that specifies the values for the + intervals defined by `boundaries`. It should have one more + element than `boundaries`, and all elements should have the same + type. + name: A string. Optional name of the operation. Defaults to + `"PiecewiseConstant"`. + + Returns: + A 1-arg callable learning rate schedule that takes the current optimizer + step and outputs the decayed learning rate, a scalar tensor of the + same type as the boundary tensors. + + The output of the 1-arg function that takes the `step` + is `values[0]` when `step <= boundaries[0]`, + `values[1]` when `step > boundaries[0]` and `step <= boundaries[1]`, + ..., and `values[-1]` when `step > boundaries[-1]`. + + + Raises: + ValueError: if the number of elements in the `boundaries` and `values` + lists do not match. + """ + + def __init__(self, boundaries, values, name="PiecewiseConstant"): + super().__init__() + + if len(boundaries) != len(values) - 1: + raise ValueError( + "The length of boundaries should be 1 less than the length of " + f"values. Received: boundaries={boundaries} of length " + f"{len(boundaries)}, and values={values} " + f"of length {len(values)}." + ) + + self.boundaries = boundaries + self.values = values + self.name = name + + def __call__(self, step): + with ops.name_scope(self.name): + boundaries = [ops.convert_to_tensor(x) for x in self.boundaries] + values = [ops.convert_to_tensor(x) for x in self.values] + step = ops.convert_to_tensor(step) + + for i, b in enumerate(boundaries): + if b.dtype != step.dtype: + # We cast the boundaries to have the same type as the step + b = ops.cast(b, step.dtype) + boundaries[i] = b + + result_dtype = values[0].dtype + result_value = ops.array(0, dtype=result_dtype) + + # For each range between boundaries, we check whether the step is + # within that range, cast the resulting boolean to a number, + # and multiply the result by the corresponding value for the range. + # Taking the sum of these yields a piecewise constant function. + step_less_than_first_boundary = ops.cast( + step <= boundaries[0], result_dtype + ) + result_value += step_less_than_first_boundary * values[0] + + step_greater_than_last_boundary = ops.cast( + step > boundaries[-1], result_dtype + ) + result_value += step_greater_than_last_boundary * values[-1] + + for low, high, value in zip( + boundaries[:-1], boundaries[1:], values[1:-1] + ): + # Need to bind v here; can do this with lambda v=v: ... + step_in_range = ops.cast( + (step > low) & (step <= high), result_dtype + ) + result_value += step_in_range * value + + return result_value + + def get_config(self): + return { + "boundaries": self.boundaries, + "values": self.values, + "name": self.name, + } + + +@keras_export("keras.optimizers.schedules.PolynomialDecay") +class PolynomialDecay(LearningRateSchedule): + """A `LearningRateSchedule` that uses a polynomial decay schedule. + + It is commonly observed that a monotonically decreasing learning rate, whose + degree of change is carefully chosen, results in a better performing model. + This schedule applies a polynomial decay function to an optimizer step, + given a provided `initial_learning_rate`, to reach an `end_learning_rate` + in the given `decay_steps`. + + It requires a `step` value to compute the decayed learning rate. You + can just pass a backend variable that you increment at each training + step. + + The schedule is a 1-arg callable that produces a decayed learning rate + when passed the current optimizer step. This can be useful for changing the + learning rate value across different invocations of optimizer functions. + It is computed as: + + ```python + def decayed_learning_rate(step): + step = min(step, decay_steps) + return ((initial_learning_rate - end_learning_rate) * + (1 - step / decay_steps) ^ (power) + ) + end_learning_rate + ``` + + If `cycle` is True then a multiple of `decay_steps` is used, the first one + that is bigger than `step`. + + ```python + def decayed_learning_rate(step): + decay_steps = decay_steps * ceil(step / decay_steps) + return ((initial_learning_rate - end_learning_rate) * + (1 - step / decay_steps) ^ (power) + ) + end_learning_rate + ``` + + You can pass this schedule directly into a `keras.optimizers.Optimizer` + as the learning rate. + Example: Fit a model while decaying from 0.1 to 0.01 in 10000 steps using + sqrt (i.e. power=0.5): + + ```python + ... + starter_learning_rate = 0.1 + end_learning_rate = 0.01 + decay_steps = 10000 + learning_rate_fn = keras.optimizers.schedules.PolynomialDecay( + starter_learning_rate, + decay_steps, + end_learning_rate, + power=0.5) + + model.compile(optimizer=keras.optimizers.SGD( + learning_rate=learning_rate_fn), + loss='sparse_categorical_crossentropy', + metrics=['accuracy']) + + model.fit(data, labels, epochs=5) + ``` + + The learning rate schedule is also serializable and deserializable using + `keras.optimizers.schedules.serialize` and + `keras.optimizers.schedules.deserialize`. + + Args: + initial_learning_rate: A Python float. The initial learning rate. + decay_steps: A Python integer. Must be positive. See the decay + computation above. + end_learning_rate: A Python float. The minimal end learning rate. + power: A Python float. The power of the polynomial. Defaults to + `1.0`. + cycle: A boolean, whether it should cycle beyond decay_steps. + name: String. Optional name of the operation. Defaults to + `"PolynomialDecay"`. + + Returns: + A 1-arg callable learning rate schedule that takes the current optimizer + step and outputs the decayed learning rate, a scalar tensor of the + same type as `initial_learning_rate`. + """ + + def __init__( + self, + initial_learning_rate, + decay_steps, + end_learning_rate=0.0001, + power=1.0, + cycle=False, + name="PolynomialDecay", + ): + super().__init__() + + self.initial_learning_rate = initial_learning_rate + self.decay_steps = decay_steps + self.end_learning_rate = end_learning_rate + self.power = power + self.cycle = cycle + self.name = name + + if self.decay_steps <= 0: + raise ValueError( + "Argument `decay_steps` must be > 0. " + f"Received: decay_steps={self.decay_steps}" + ) + + def __call__(self, step): + with ops.name_scope(self.name): + initial_learning_rate = ops.convert_to_tensor( + self.initial_learning_rate + ) + dtype = initial_learning_rate.dtype + end_learning_rate = ops.cast(self.end_learning_rate, dtype) + power = ops.cast(self.power, dtype) + + global_step_recomp = ops.cast(step, dtype) + decay_steps_recomp = ops.cast(self.decay_steps, dtype) + if self.cycle: + # Find the first multiple of decay_steps that is bigger than + # global_step. If global_step is zero set the multiplier to 1 + multiplier = ops.where( + ops.equal(global_step_recomp, 0), + 1.0, + ops.ceil(global_step_recomp / self.decay_steps), + ) + decay_steps_recomp = ops.multiply( + decay_steps_recomp, multiplier + ) + else: + # Make sure that the global_step used is not bigger than + # decay_steps. + global_step_recomp = ops.minimum( + global_step_recomp, decay_steps_recomp + ) + + p = ops.divide(global_step_recomp, decay_steps_recomp) + return ops.add( + ops.multiply( + initial_learning_rate - end_learning_rate, + ops.power(1 - p, power), + ), + end_learning_rate, + ) + + def get_config(self): + return { + "initial_learning_rate": self.initial_learning_rate, + "decay_steps": self.decay_steps, + "end_learning_rate": self.end_learning_rate, + "power": self.power, + "cycle": self.cycle, + "name": self.name, + } + + +@keras_export("keras.optimizers.schedules.InverseTimeDecay") +class InverseTimeDecay(LearningRateSchedule): + """A `LearningRateSchedule` that uses an inverse time decay schedule. + + When training a model, it is often useful to lower the learning rate as + the training progresses. This schedule applies the inverse decay function + to an optimizer step, given a provided initial learning rate. + It requires a `step` value to compute the decayed learning rate. You can + just pass a backend variable that you increment at each training step. + + The schedule is a 1-arg callable that produces a decayed learning + rate when passed the current optimizer step. This can be useful for changing + the learning rate value across different invocations of optimizer functions. + It is computed as: + + ```python + def decayed_learning_rate(step): + return initial_learning_rate / (1 + decay_rate * step / decay_step) + ``` + + or, if `staircase` is `True`, as: + + ```python + def decayed_learning_rate(step): + return initial_learning_rate / + (1 + decay_rate * floor(step / decay_step)) + ``` + + You can pass this schedule directly into a `keras.optimizers.Optimizer` + as the learning rate. + Example: Fit a Keras model when decaying 1/t with a rate of 0.5: + + ```python + ... + initial_learning_rate = 0.1 + decay_steps = 1.0 + decay_rate = 0.5 + learning_rate_fn = keras.optimizers.schedules.InverseTimeDecay( + initial_learning_rate, decay_steps, decay_rate) + + model.compile(optimizer=keras.optimizers.SGD( + learning_rate=learning_rate_fn), + loss='sparse_categorical_crossentropy', + metrics=['accuracy']) + + model.fit(data, labels, epochs=5) + ``` + + Args: + initial_learning_rate: A Python float. The initial learning rate. + decay_steps: How often to apply decay. + decay_rate: A Python number. The decay rate. + staircase: Whether to apply decay in a discrete staircase, as o + pposed to continuous, fashion. + name: String. Optional name of the operation. Defaults to + `"InverseTimeDecay"`. + + Returns: + A 1-arg callable learning rate schedule that takes the current optimizer + step and outputs the decayed learning rate, a scalar tensor of the + same type as `initial_learning_rate`. + """ + + def __init__( + self, + initial_learning_rate, + decay_steps, + decay_rate, + staircase=False, + name="InverseTimeDecay", + ): + super().__init__() + + self.initial_learning_rate = initial_learning_rate + self.decay_steps = decay_steps + self.decay_rate = decay_rate + self.staircase = staircase + self.name = name + + if self.decay_steps <= 0: + raise ValueError( + "Argument `decay_steps` must be > 0. " + f"Received: decay_steps={self.decay_steps}" + ) + + def __call__(self, step): + with ops.name_scope(self.name): + initial_learning_rate = ops.convert_to_tensor( + self.initial_learning_rate + ) + dtype = initial_learning_rate.dtype + decay_steps = ops.cast(self.decay_steps, dtype) + decay_rate = ops.cast(self.decay_rate, dtype) + + global_step_recomp = ops.cast(step, dtype) + p = global_step_recomp / decay_steps + if self.staircase: + p = ops.floor(p) + const = ops.cast(ops.array(1), dtype) + denom = ops.add(const, ops.multiply(decay_rate, p)) + return ops.divide(initial_learning_rate, denom) + + def get_config(self): + return { + "initial_learning_rate": self.initial_learning_rate, + "decay_steps": self.decay_steps, + "decay_rate": self.decay_rate, + "staircase": self.staircase, + "name": self.name, + } + + +@keras_export("keras.optimizers.schedules.CosineDecay") +class CosineDecay(LearningRateSchedule): + """A `LearningRateSchedule` that uses a cosine decay with optional warmup. + + See [Loshchilov & Hutter, ICLR2016](https://arxiv.org/abs/1608.03983), + SGDR: Stochastic Gradient Descent with Warm Restarts. + + For the idea of a linear warmup of our learning rate, + see [Goyal et al.](https://arxiv.org/pdf/1706.02677.pdf). + + When we begin training a model, we often want an initial increase in our + learning rate followed by a decay. If `warmup_target` is an int, this + schedule applies a linear increase per optimizer step to our learning rate + from `initial_learning_rate` to `warmup_target` for a duration of + `warmup_steps`. Afterwards, it applies a cosine decay function taking our + learning rate from `warmup_target` to `alpha` for a duration of + `decay_steps`. If `warmup_target` is None we skip warmup and our decay + will take our learning rate from `initial_learning_rate` to `alpha`. + It requires a `step` value to compute the learning rate. You can + just pass a backend variable that you increment at each training step. + + The schedule is a 1-arg callable that produces a warmup followed by a + decayed learning rate when passed the current optimizer step. This can be + useful for changing the learning rate value across different invocations of + optimizer functions. + + Our warmup is computed as: + + ```python + def warmup_learning_rate(step): + completed_fraction = step / warmup_steps + total_delta = target_warmup - initial_learning_rate + return completed_fraction * total_delta + ``` + + And our decay is computed as: + + ```python + if warmup_target is None: + initial_decay_lr = initial_learning_rate + else: + initial_decay_lr = warmup_target + + def decayed_learning_rate(step): + step = min(step, decay_steps) + cosine_decay = 0.5 * (1 + cos(pi * step / decay_steps)) + decayed = (1 - alpha) * cosine_decay + alpha + return initial_decay_lr * decayed + ``` + + Example usage without warmup: + + ```python + decay_steps = 1000 + initial_learning_rate = 0.1 + lr_decayed_fn = keras.optimizers.schedules.CosineDecay( + initial_learning_rate, decay_steps) + ``` + + Example usage with warmup: + + ```python + decay_steps = 1000 + initial_learning_rate = 0 + warmup_steps = 1000 + target_learning_rate = 0.1 + lr_warmup_decayed_fn = keras.optimizers.schedules.CosineDecay( + initial_learning_rate, decay_steps, warmup_target=target_learning_rate, + warmup_steps=warmup_steps + ) + ``` + + You can pass this schedule directly into a `keras.optimizers.Optimizer` + as the learning rate. The learning rate schedule is also serializable and + deserializable using `keras.optimizers.schedules.serialize` and + `keras.optimizers.schedules.deserialize`. + + Args: + initial_learning_rate: A Python float. The initial learning rate. + decay_steps: A Python int. Number of steps to decay over. + alpha: A Python float. Minimum learning rate value for decay as a + fraction of `initial_learning_rate`. + name: String. Optional name of the operation. Defaults to + `"CosineDecay"`. + warmup_target: A Python float. The target learning rate for our + warmup phase. Will cast to the `initial_learning_rate` datatype. + Setting to `None` will skip warmup and begins decay phase from + `initial_learning_rate`. Otherwise scheduler will warmup from + `initial_learning_rate` to `warmup_target`. + warmup_steps: A Python int. Number of steps to warmup over. + + Returns: + A 1-arg callable learning rate schedule that takes the current optimizer + step and outputs the decayed learning rate, a scalar tensor of the + same type as `initial_learning_rate`. + """ + + def __init__( + self, + initial_learning_rate, + decay_steps, + alpha=0.0, + name="CosineDecay", + warmup_target=None, + warmup_steps=0, + ): + super().__init__() + + self.initial_learning_rate = initial_learning_rate + self.decay_steps = decay_steps + self.alpha = alpha + self.name = name + self.warmup_steps = warmup_steps + self.warmup_target = warmup_target + + if self.decay_steps <= 0: + raise ValueError( + "Argument `decay_steps` must be > 0. " + f"Received: decay_steps={self.decay_steps}" + ) + + def _decay_function(self, step, decay_steps, decay_from_lr, dtype): + with ops.name_scope(self.name): + completed_fraction = ops.divide(step, decay_steps) + pi = ops.array(math.pi, dtype=dtype) + cosine_decayed = 0.5 * ( + 1.0 + ops.cos(ops.multiply(pi, completed_fraction)) + ) + decayed = (1 - self.alpha) * cosine_decayed + self.alpha + return ops.multiply(decay_from_lr, decayed) + + def _warmup_function( + self, step, warmup_steps, warmup_target, initial_learning_rate + ): + with ops.name_scope(self.name): + completed_fraction = step / warmup_steps + total_step_delta = warmup_target - initial_learning_rate + return total_step_delta * completed_fraction + initial_learning_rate + + def __call__(self, step): + with ops.name_scope(self.name): + initial_learning_rate = ops.convert_to_tensor( + self.initial_learning_rate + ) + dtype = initial_learning_rate.dtype + decay_steps = ops.cast(self.decay_steps, dtype) + global_step_recomp = ops.cast(step, dtype) + + if self.warmup_target is None: + global_step_recomp = ops.minimum( + global_step_recomp, decay_steps + ) + return self._decay_function( + global_step_recomp, + decay_steps, + initial_learning_rate, + dtype, + ) + + warmup_target = ops.cast(self.warmup_target, dtype) + warmup_steps = ops.cast(self.warmup_steps, dtype) + + global_step_recomp = ops.minimum( + global_step_recomp, decay_steps + warmup_steps + ) + + return ops.cond( + global_step_recomp < warmup_steps, + lambda: self._warmup_function( + global_step_recomp, + warmup_steps, + warmup_target, + initial_learning_rate, + ), + lambda: self._decay_function( + global_step_recomp - warmup_steps, + decay_steps, + warmup_target, + dtype, + ), + ) + + def get_config(self): + return { + "initial_learning_rate": self.initial_learning_rate, + "decay_steps": self.decay_steps, + "alpha": self.alpha, + "name": self.name, + "warmup_target": self.warmup_target, + "warmup_steps": self.warmup_steps, + } + + +@keras_export("keras.optimizers.schedules.CosineDecayRestarts") +class CosineDecayRestarts(LearningRateSchedule): + """A `LearningRateSchedule` that uses a cosine decay schedule with restarts. + + See [Loshchilov & Hutter, ICLR2016](https://arxiv.org/abs/1608.03983), + SGDR: Stochastic Gradient Descent with Warm Restarts. + + When training a model, it is often useful to lower the learning rate as + the training progresses. This schedule applies a cosine decay function with + restarts to an optimizer step, given a provided initial learning rate. + It requires a `step` value to compute the decayed learning rate. You can + just pass a backend variable that you increment at each training step. + + The schedule is a 1-arg callable that produces a decayed learning + rate when passed the current optimizer step. This can be useful for changing + the learning rate value across different invocations of optimizer functions. + + The learning rate multiplier first decays + from 1 to `alpha` for `first_decay_steps` steps. Then, a warm + restart is performed. Each new warm restart runs for `t_mul` times more + steps and with `m_mul` times initial learning rate as the new learning rate. + + Example: + ```python + first_decay_steps = 1000 + lr_decayed_fn = ( + keras.optimizers.schedules.CosineDecayRestarts( + initial_learning_rate, + first_decay_steps)) + ``` + + You can pass this schedule directly into a `keras.optimizers.Optimizer` + as the learning rate. The learning rate schedule is also serializable and + deserializable using `keras.optimizers.schedules.serialize` and + `keras.optimizers.schedules.deserialize`. + + Args: + initial_learning_rate: A Python float. The initial learning rate. + first_decay_steps: A Python integer. Number of steps to decay over. + t_mul: A Python float. Used to derive the number of iterations in + the i-th period. + m_mul: A Python float. Used to derive the initial learning rate of + the i-th period. + alpha: A Python float. Minimum learning rate value as a fraction of + the `initial_learning_rate`. + name: String. Optional name of the operation. Defaults to + `"SGDRDecay"`. + + Returns: + A 1-arg callable learning rate schedule that takes the current optimizer + step and outputs the decayed learning rate, a scalar tensor of the + same type as `initial_learning_rate`. + """ + + def __init__( + self, + initial_learning_rate, + first_decay_steps, + t_mul=2.0, + m_mul=1.0, + alpha=0.0, + name="SGDRDecay", + ): + super().__init__() + + self.initial_learning_rate = initial_learning_rate + self.first_decay_steps = first_decay_steps + self._t_mul = t_mul + self._m_mul = m_mul + self.alpha = alpha + self.name = name + + if self.first_decay_steps <= 0: + raise ValueError( + "Argument `first_decay_steps` must be > 0. " + f"Received: first_decay_steps={self.first_decay_steps}" + ) + + def __call__(self, step): + with ops.name_scope(self.name): + initial_learning_rate = ops.convert_to_tensor( + self.initial_learning_rate + ) + dtype = initial_learning_rate.dtype + first_decay_steps = ops.cast(self.first_decay_steps, dtype) + alpha = ops.cast(self.alpha, dtype) + t_mul = ops.cast(self._t_mul, dtype) + m_mul = ops.cast(self._m_mul, dtype) + + global_step_recomp = ops.cast(step, dtype) + completed_fraction = global_step_recomp / first_decay_steps + + def compute_step(completed_fraction, geometric=False): + """Helper for `cond` operation.""" + if geometric: + # ops.log is sensitive to the precision of dtype, so we need + # the additional casting + i_restart = ops.floor( + ops.log( + ops.cast( + 1.0 - completed_fraction * (1.0 - t_mul), dtype + ) + ) + / ops.log(t_mul) + ) + + sum_r = ops.divide( + 1.0 - ops.power(t_mul, i_restart), (1.0 - t_mul) + ) + completed_fraction = ops.divide( + ops.subtract(completed_fraction, sum_r), + ops.power(t_mul, i_restart), + ) + + else: + i_restart = ops.floor(completed_fraction) + completed_fraction -= i_restart + + return i_restart, completed_fraction + + i_restart, completed_fraction = ops.cond( + ops.equal(t_mul, 1.0), + lambda: compute_step(completed_fraction, geometric=False), + lambda: compute_step(completed_fraction, geometric=True), + ) + + m_fac = ops.power(m_mul, i_restart) + cosine_decayed = ( + 0.5 + * m_fac + * ( + 1.0 + + ops.cos( + ops.multiply( + ops.array(math.pi, dtype=dtype), completed_fraction + ) + ) + ) + ) + decayed = ops.add(ops.multiply((1 - alpha), cosine_decayed), alpha) + + return ops.multiply(initial_learning_rate, decayed) + + def get_config(self): + return { + "initial_learning_rate": self.initial_learning_rate, + "first_decay_steps": self.first_decay_steps, + "t_mul": self._t_mul, + "m_mul": self._m_mul, + "alpha": self.alpha, + "name": self.name, + } + + +@keras_export("keras.optimizers.schedules.serialize") +def serialize(learning_rate_schedule): + """Serializes a `LearningRateSchedule` into a JSON-compatible dict. + + Args: + learning_rate_schedule: The `LearningRateSchedule` object to serialize. + + Returns: + A JSON-serializable dict representing the object's config. + + Example: + + >>> lr_schedule = keras.optimizers.schedules.ExponentialDecay( + ... 0.1, decay_steps=100000, decay_rate=0.96, staircase=True) + >>> keras.optimizers.schedules.serialize(lr_schedule) + {'module': 'keras.optimizers.schedules', + 'class_name': 'ExponentialDecay', 'config': {...}, + 'registered_name': None} + """ + return serialization_lib.serialize_keras_object(learning_rate_schedule) + + +@keras_export("keras.optimizers.schedules.deserialize") +def deserialize(config, custom_objects=None): + """Instantiates a `LearningRateSchedule` object from a serialized form. + + Args: + config: The serialized form of the `LearningRateSchedule`. Dictionary of + the form {'class_name': str, 'config': dict}. + custom_objects: A dictionary mapping class names (or function names) of + custom (non-Keras) objects to class/functions. + + Returns: + A `LearningRateSchedule` object. + + Example: + + ```python + # Configuration for PolynomialDecay + config = { + 'class_name': 'PolynomialDecay', + 'config': {'cycle': False, + 'decay_steps': 10000, + 'end_learning_rate': 0.01, + 'initial_learning_rate': 0.1, + 'name': None, + 'power': 0.5 + } + } + lr_schedule = keras.optimizers.schedules.deserialize(config) + ``` + """ + return serialization_lib.deserialize_keras_object( + config, + module_objects=globals(), + custom_objects=custom_objects, + printable_module_name="decay", + ) diff --git a/keras/src/optimizers/schedules/learning_rate_schedule_test.py b/keras/src/optimizers/schedules/learning_rate_schedule_test.py new file mode 100644 index 000000000000..052db9e93945 --- /dev/null +++ b/keras/src/optimizers/schedules/learning_rate_schedule_test.py @@ -0,0 +1,462 @@ +"""Tests for learning rate schedule API.""" + +import math + +import numpy as np +import pytest + +from keras.src import backend +from keras.src import layers +from keras.src import optimizers +from keras.src import testing +from keras.src.models import Sequential +from keras.src.optimizers import schedules + + +class TestFitLRSchedulesFlow(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_fit_lr_correctness(self): + model = Sequential( + [ + layers.Dense( + 2, kernel_initializer="ones", bias_initializer="ones" + ) + ] + ) + optimizer = optimizers.Adam( + learning_rate=schedules.ExponentialDecay( + initial_learning_rate=0.05, decay_steps=1, decay_rate=0.9 + ) + ) + self.assertEqual(len(optimizer.variables), 1) + self.assertEqual(optimizer.variables[0], 0) + + model.compile(optimizer=optimizer, loss="mse") + x = np.arange(32).reshape((16, 2)) + y = np.arange(32).reshape((16, 2)) + history = model.fit(x, y, epochs=3, batch_size=4, shuffle=False) + self.assertEqual(optimizer.variables[0], 4 * 3) + self.assertAllClose( + history.history["loss"], + [230.79457092285156, 128.30319213867188, 79.33648681640625], + rtol=5e-5, + ) + + +class ExponentialDecayTest(testing.TestCase): + def test_config(self): + self.run_class_serialization_test( + schedules.ExponentialDecay( + initial_learning_rate=0.05, + decay_steps=10, + decay_rate=0.96, + staircase=True, + name="my_ed", + ) + ) + + def test_continuous(self): + step = 5 + decayed_lr = schedules.ExponentialDecay(0.05, 10, 0.96) + expected = 0.05 * 0.96 ** (5.0 / 10.0) + self.assertAllClose(decayed_lr(step), expected, 1e-6) + + def test_staircase(self): + step = backend.Variable(1.0) + decayed_lr = schedules.ExponentialDecay(0.1, 3, 0.96, staircase=True) + + # No change to learning rate due to staircase + expected = 0.1 + self.assertAllClose(decayed_lr(step), expected, 1e-6) + + expected = 0.1 + step.assign(2) + self.assertAllClose(decayed_lr(step), expected, 1e-6) + + # Decayed learning rate + expected = 0.1 * 0.96 ** (100 // 3) + step.assign(100) + self.assertAllClose(decayed_lr(step), expected, 1e-6) + + def test_variables(self): + step = backend.Variable(1.0) + decayed_lr = schedules.ExponentialDecay(0.1, 3, 0.96, staircase=True) + + # No change to learning rate + step.assign(1) + self.assertAllClose(decayed_lr(step), 0.1, 1e-6) + step.assign(2) + self.assertAllClose(decayed_lr(step), 0.1, 1e-6) + # Decayed learning rate + step.assign(100) + expected = 0.1 * 0.96 ** (100 // 3) + self.assertAllClose(decayed_lr(step), expected, 1e-6) + + +class PiecewiseConstantDecayTest(testing.TestCase): + def test_config(self): + self.run_class_serialization_test( + schedules.PiecewiseConstantDecay( + boundaries=[10, 20], values=[1, 2, 3], name="my_pcd" + ) + ) + + def test_piecewise_values(self): + x = backend.Variable(-999.0) + decayed_lr = schedules.PiecewiseConstantDecay( + [100, 110, 120], [1.0, 0.1, 0.01, 0.001] + ) + + self.assertAllClose(decayed_lr(x), 1.0, 1e-6) + x.assign(100) + self.assertAllClose(decayed_lr(x), 1.0, 1e-6) + x.assign(105) + self.assertAllClose(decayed_lr(x), 0.1, 1e-6) + x.assign(110) + self.assertAllClose(decayed_lr(x), 0.1, 1e-6) + x.assign(120) + self.assertAllClose(decayed_lr(x), 0.01, 1e-6) + x.assign(999) + self.assertAllClose(decayed_lr(x), 0.001, 1e-6) + + def test_boundary_values(self): + # Test casting boundaries from int32 to int64. + x_int64 = backend.Variable(0, dtype="int64", trainable=False) + boundaries, values = [1, 2, 3], [0.4, 0.5, 0.6, 0.7] + decayed_lr = schedules.PiecewiseConstantDecay(boundaries, values) + + self.assertAllClose(decayed_lr(x_int64), 0.4, 1e-6) + x_int64.assign(1) + self.assertAllClose(decayed_lr(x_int64), 0.4, 1e-6) + x_int64.assign(2) + self.assertAllClose(decayed_lr(x_int64), 0.5, 1e-6) + x_int64.assign(3) + self.assertAllClose(decayed_lr(x_int64), 0.6, 1e-6) + x_int64.assign(4) + self.assertAllClose(decayed_lr(x_int64), 0.7, 1e-6) + + +class LinearDecayTest(testing.TestCase): + def test_config(self): + self.run_class_serialization_test( + schedules.PolynomialDecay( + initial_learning_rate=0.1, + decay_steps=100, + end_learning_rate=0.005, + power=1.0, + cycle=False, + name="my_ld", + ) + ) + + def test_halfway(self): + step = 5 + lr = 0.05 + end_lr = 0.0 + decayed_lr = schedules.PolynomialDecay(lr, 10, end_lr) + expected = lr * 0.5 + self.assertAllClose(decayed_lr(step), expected, 1e-6) + + def test_end(self): + step = 10 + lr = 0.05 + end_lr = 0.001 + decayed_lr = schedules.PolynomialDecay(lr, 10, end_lr) + expected = end_lr + self.assertAllClose(decayed_lr(step), expected, 1e-6) + + def test_halfway_with_end(self): + step = 5 + lr = 0.05 + end_lr = 0.001 + decayed_lr = schedules.PolynomialDecay(lr, 10, end_lr) + expected = (lr + end_lr) * 0.5 + self.assertAllClose(decayed_lr(step), expected, 1e-6) + + def test_beyond_end(self): + step = 15 + lr = 0.05 + end_lr = 0.001 + decayed_lr = schedules.PolynomialDecay(lr, 10, end_lr) + expected = end_lr + self.assertAllClose(decayed_lr(step), expected, 1e-6) + + def test_beyond_end_with_cycle(self): + step = 15 + lr = 0.05 + end_lr = 0.001 + decayed_lr = schedules.PolynomialDecay(lr, 10, end_lr, cycle=True) + expected = (lr - end_lr) * 0.25 + end_lr + self.assertAllClose(decayed_lr(step), expected, 1e-6) + + +class SqrtDecayTest(testing.TestCase): + def test_halfway(self): + step = 5 + lr = 0.05 + end_lr = 0.0 + power = 0.5 + decayed_lr = schedules.PolynomialDecay(lr, 10, end_lr, power=power) + expected = lr * 0.5**power + self.assertAllClose(decayed_lr(step), expected, 1e-6) + + def test_end(self): + step = 10 + lr = 0.05 + end_lr = 0.001 + power = 0.5 + decayed_lr = schedules.PolynomialDecay(lr, 10, end_lr, power=power) + expected = end_lr + self.assertAllClose(decayed_lr(step), expected, 1e-6) + + def test_halfway_with_end(self): + step = 5 + lr = 0.05 + end_lr = 0.001 + power = 0.5 + decayed_lr = schedules.PolynomialDecay(lr, 10, end_lr, power=power) + expected = (lr - end_lr) * 0.5**power + end_lr + self.assertAllClose(decayed_lr(step), expected, 1e-6) + + def test_beyond_end(self): + step = 15 + lr = 0.05 + end_lr = 0.001 + power = 0.5 + decayed_lr = schedules.PolynomialDecay(lr, 10, end_lr, power=power) + expected = end_lr + self.assertAllClose(decayed_lr(step), expected, 1e-6) + + def test_beyond_end_with_cycle(self): + step = 15 + lr = 0.05 + end_lr = 0.001 + power = 0.5 + decayed_lr = schedules.PolynomialDecay( + lr, 10, end_lr, power=power, cycle=True + ) + expected = (lr - end_lr) * 0.25**power + end_lr + self.assertAllClose(decayed_lr(step), expected, 1e-6) + + def test_begin_with_cycle(self): + lr = 0.001 + decay_steps = 10 + step = 0 + decayed_lr = schedules.PolynomialDecay(lr, decay_steps, cycle=True) + expected = lr + self.assertAllClose(decayed_lr(step), expected, 1e-6) + + +class InverseTimeDecayTest(testing.TestCase): + def test_config(self): + self.run_class_serialization_test( + schedules.InverseTimeDecay( + initial_learning_rate=0.05, + decay_steps=10, + decay_rate=0.96, + staircase=True, + name="my_itd", + ) + ) + + def test_decay(self): + initial_lr = 0.1 + k = 10 + decay_rate = 0.96 + step = backend.Variable(0.0) + decayed_lr = schedules.InverseTimeDecay(initial_lr, k, decay_rate) + + for i in range(k + 1): + expected = initial_lr / (1 + i / k * decay_rate) + self.assertAllClose(decayed_lr(step), expected, 1e-6) + step.assign(step + 1) + + def test_staircase(self): + initial_lr = 0.1 + k = 10 + decay_rate = 0.96 + step = backend.Variable(0.0) + decayed_lr = schedules.InverseTimeDecay( + initial_lr, k, decay_rate, staircase=True + ) + + for i in range(k + 1): + expected = initial_lr / (1 + decay_rate * (i // k)) + self.assertAllClose(decayed_lr(step), expected, 1e-6) + step.assign(step + 1) + + +class CosineDecayTest(testing.TestCase): + def test_config(self): + self.run_class_serialization_test( + schedules.CosineDecay( + initial_learning_rate=0.05, + decay_steps=10, + alpha=0.1, + warmup_target=0.2, + warmup_steps=2, + name="my_cd", + ) + ) + + def np_cosine_decay(self, step, decay_steps, alpha=0.0): + step = min(step, decay_steps) + completed_fraction = step / decay_steps + decay = 0.5 * (1.0 + math.cos(math.pi * completed_fraction)) + return (1.0 - alpha) * decay + alpha + + def test_decay(self): + num_training_steps = 1000 + initial_lr = 1.0 + for step in range(0, 1500, 250): + decayed_lr = schedules.CosineDecay(initial_lr, num_training_steps) + expected = self.np_cosine_decay(step, num_training_steps) + self.assertAllClose(decayed_lr(step), expected, 1e-6) + + def linear_warmup(self, step, warmup_steps, initial_lr, target_lr): + completed_fraction = step / warmup_steps + total_delta = target_lr - initial_lr + return completed_fraction * total_delta + + def test_warmup(self): + warmup_steps = 1500 + initial_lr = 0.0 + target_lr = 10.0 + for step in range(0, 1500, 250): + lr = schedules.CosineDecay( + initial_lr, + 10, + warmup_target=target_lr, + warmup_steps=warmup_steps, + ) + expected = self.linear_warmup( + step, warmup_steps, initial_lr, target_lr + ) + self.assertAllClose(lr(step), expected) + + def test_alpha(self): + num_training_steps = 1000 + initial_lr = 1.0 + alpha = 0.1 + for step in range(0, 1500, 250): + decayed_lr = schedules.CosineDecay( + initial_lr, num_training_steps, alpha + ) + expected = self.np_cosine_decay(step, num_training_steps, alpha) + self.assertAllClose(decayed_lr(step), expected, 1e-6) + + def test_float64(self): + num_training_steps = 1000 + initial_lr = np.float64(1.0) + for step in range(0, 1500, 250): + decayed_lr = schedules.CosineDecay(initial_lr, num_training_steps) + expected = self.np_cosine_decay(step, num_training_steps) + self.assertAllClose(decayed_lr(step), expected, 1e-6) + + def test_warmup_decay(self): + warmup_steps = 2000 + decay_steps = 1000 + initial_lr = 0.0 + target_lr = 10.0 + for step in range(0, 3000, 250): + lr = schedules.CosineDecay( + initial_lr, + decay_steps, + warmup_target=target_lr, + warmup_steps=warmup_steps, + ) + if step < warmup_steps + 1: + expected = self.linear_warmup( + step, warmup_steps, initial_lr, target_lr + ) + else: + expected = target_lr * self.np_cosine_decay( + step - warmup_steps, decay_steps + ) + self.assertAllClose(lr(step), expected) + + +class CosineDecayRestartsTest(testing.TestCase): + def test_config(self): + self.run_class_serialization_test( + schedules.CosineDecayRestarts( + initial_learning_rate=0.05, + first_decay_steps=10, + alpha=0.1, + t_mul=3.0, + m_mul=4.0, + name="my_cdr", + ) + ) + + def np_cosine_decay_restarts( + self, step, decay_steps, t_mul=2.0, m_mul=1.0, alpha=0.0 + ): + fac = 1.0 + while step >= decay_steps: + step -= decay_steps + decay_steps *= t_mul + fac *= m_mul + + completed_fraction = step / decay_steps + decay = fac * 0.5 * (1.0 + math.cos(math.pi * completed_fraction)) + return (1.0 - alpha) * decay + alpha + + def test_decay(self): + num_training_steps = 1000 + initial_lr = 1.0 + for step in range(0, 1500, 250): + decayed_lr = schedules.CosineDecayRestarts( + initial_lr, num_training_steps + ) + expected = self.np_cosine_decay_restarts(step, num_training_steps) + self.assertAllClose(decayed_lr(step), expected, 1e-6) + + def test_float64(self): + num_training_steps = 1000 + initial_lr = np.float64(1.0) + for step in range(0, 1500, 250): + decayed_lr = schedules.CosineDecayRestarts( + initial_lr, num_training_steps + ) + expected = self.np_cosine_decay_restarts(step, num_training_steps) + self.assertAllClose(decayed_lr(step), expected, 1e-6) + + def test_alpha(self): + num_training_steps = 1000 + initial_lr = 1.0 + alpha = 0.1 + for step in range(0, 1500, 250): + decayed_lr = schedules.CosineDecayRestarts( + initial_lr, num_training_steps, alpha=alpha + ) + expected = self.np_cosine_decay_restarts( + step, num_training_steps, alpha=alpha + ) + self.assertAllClose(decayed_lr(step), expected, 1e-6) + + def test_mmul(self): + num_training_steps = 1000 + initial_lr = 1.0 + m_mul = 0.9 + for step in range(0, 1500, 250): + decayed_lr = schedules.CosineDecayRestarts( + initial_lr, num_training_steps, m_mul=m_mul + ) + expected = self.np_cosine_decay_restarts( + step, num_training_steps, m_mul=m_mul + ) + self.assertAllClose(decayed_lr(step), expected, 1e-6) + + def test_tmul(self): + num_training_steps = 1000 + initial_lr = 1.0 + t_mul = 1.0 + for step in range(0, 1500, 250): + decayed_lr = schedules.CosineDecayRestarts( + initial_lr, num_training_steps, t_mul=t_mul + ) + expected = self.np_cosine_decay_restarts( + step, num_training_steps, t_mul=t_mul + ) + self.assertAllClose(decayed_lr(step), expected, 1e-6) diff --git a/keras/src/optimizers/sgd.py b/keras/src/optimizers/sgd.py new file mode 100644 index 000000000000..15c951ed8d06 --- /dev/null +++ b/keras/src/optimizers/sgd.py @@ -0,0 +1,138 @@ +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.optimizers import optimizer + + +@keras_export("keras.optimizers.SGD") +class SGD(optimizer.Optimizer): + """Gradient descent (with momentum) optimizer. + + Update rule for parameter `w` with gradient `g` when `momentum` is 0: + + ```python + w = w - learning_rate * g + ``` + + Update rule when `momentum` is larger than 0: + + ```python + velocity = momentum * velocity - learning_rate * g + w = w + velocity + ``` + + When `nesterov=True`, this rule becomes: + + ```python + velocity = momentum * velocity - learning_rate * g + w = w + momentum * velocity - learning_rate * g + ``` + + Args: + learning_rate: A float, a + `keras.optimizers.schedules.LearningRateSchedule` instance, or + a callable that takes no arguments and returns the actual value to + use. The learning rate. Defaults to `0.01`. + momentum: float hyperparameter >= 0 that accelerates gradient descent in + the relevant direction and dampens oscillations. 0 is vanilla + gradient descent. Defaults to `0.0`. + nesterov: boolean. Whether to apply Nesterov momentum. + Defaults to `False`. + {{base_optimizer_keyword_args}} + """ + + def __init__( + self, + learning_rate=0.01, + momentum=0.0, + nesterov=False, + weight_decay=None, + clipnorm=None, + clipvalue=None, + global_clipnorm=None, + use_ema=False, + ema_momentum=0.99, + ema_overwrite_frequency=None, + loss_scale_factor=None, + gradient_accumulation_steps=None, + name="SGD", + **kwargs, + ): + super().__init__( + learning_rate=learning_rate, + name=name, + weight_decay=weight_decay, + clipnorm=clipnorm, + clipvalue=clipvalue, + global_clipnorm=global_clipnorm, + use_ema=use_ema, + ema_momentum=ema_momentum, + ema_overwrite_frequency=ema_overwrite_frequency, + loss_scale_factor=loss_scale_factor, + gradient_accumulation_steps=gradient_accumulation_steps, + **kwargs, + ) + if not isinstance(momentum, float) or momentum < 0 or momentum > 1: + raise ValueError("`momentum` must be a float between [0, 1].") + self.momentum = momentum + self.nesterov = nesterov + + def build(self, variables): + """Initialize optimizer variables. + + SGD optimizer has one variable `momentums`, only set if `self.momentum` + is not 0. + + Args: + var_list: list of model variables to build SGD variables on. + """ + if self.built: + return + super().build(variables) + self.momentums = [] + if self.momentum != 0: + self.momentums = self.add_optimizer_variables(variables, "momentum") + + def update_step(self, gradient, variable, learning_rate): + """Update step given gradient and the associated model variable.""" + learning_rate = ops.cast(learning_rate, variable.dtype) + gradient = ops.cast(gradient, variable.dtype) + m = None + if self.momentum != 0: + m = self.momentums[self._get_variable_index(variable)] + + if m is not None: + momentum = ops.cast(self.momentum, variable.dtype) + self.assign( + m, + ops.subtract( + ops.multiply(m, momentum), + ops.multiply(gradient, learning_rate), + ), + ) + if self.nesterov: + self.assign_add( + variable, + ops.subtract( + ops.multiply(m, momentum), + ops.multiply(gradient, learning_rate), + ), + ) + else: + self.assign_add(variable, m) + else: + self.assign_sub(variable, ops.multiply(gradient, learning_rate)) + + def get_config(self): + config = super().get_config() + config.update( + { + "momentum": self.momentum, + "nesterov": self.nesterov, + } + ) + return config + + +SGD.__doc__ = SGD.__doc__.replace( + "{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args +) diff --git a/keras/src/optimizers/sgd_test.py b/keras/src/optimizers/sgd_test.py new file mode 100644 index 000000000000..31961e3bf1ff --- /dev/null +++ b/keras/src/optimizers/sgd_test.py @@ -0,0 +1,100 @@ +# flake8: noqa + +import numpy as np + +from keras.src import backend +from keras.src import ops +from keras.src import testing +from keras.src.optimizers.sgd import SGD + + +class SGDTest(testing.TestCase): + def test_config(self): + optimizer = SGD( + learning_rate=0.5, + momentum=0.06, + nesterov=True, + weight_decay=0.004, + ) + self.run_class_serialization_test(optimizer) + + def test_single_step(self): + optimizer = SGD(learning_rate=0.5) + self.assertEqual(len(optimizer.variables), 2) + grads = ops.array([1.0, 6.0, 7.0, 2.0]) + vars = backend.Variable([1.0, 2.0, 3.0, 4.0]) + optimizer.build([vars]) + optimizer.apply_gradients(zip([grads], [vars])) + self.assertAllClose(vars, [0.5, -1.0, -0.5, 3.0], rtol=1e-4, atol=1e-4) + self.assertEqual(len(optimizer.variables), 2) + self.assertEqual(optimizer.variables[0], 1) + self.assertEqual(optimizer.variables[1], 0.5) + + def test_invalid_momentum(self): + with self.assertRaisesRegex( + ValueError, "`momentum` must be a float between \\[0, 1\\]." + ): + SGD(momentum=-1.0) + + with self.assertRaisesRegex( + ValueError, "`momentum` must be a float between \\[0, 1\\]." + ): + SGD(momentum=2.0) + + def test_weight_decay(self): + grads, var1, var2, var3 = ( + ops.zeros(()), + backend.Variable(2.0), + backend.Variable(2.0, name="exclude"), + backend.Variable(2.0), + ) + optimizer_1 = SGD(learning_rate=1.0, weight_decay=0.004) + optimizer_1.apply_gradients(zip([grads], [var1])) + + optimizer_2 = SGD(learning_rate=1.0, weight_decay=0.004) + optimizer_2.exclude_from_weight_decay(var_names=["exclude"]) + optimizer_2.apply_gradients(zip([grads, grads], [var1, var2])) + + optimizer_3 = SGD(learning_rate=1.0, weight_decay=0.004) + optimizer_3.exclude_from_weight_decay(var_list=[var3]) + optimizer_3.apply_gradients(zip([grads, grads], [var1, var3])) + + self.assertAlmostEqual(var1.numpy(), 1.9760959, decimal=6) + self.assertAlmostEqual(var2.numpy(), 2.0, decimal=6) + self.assertAlmostEqual(var3.numpy(), 2.0, decimal=6) + + def test_correctness_with_golden(self): + optimizer = SGD(nesterov=True) + + x = backend.Variable(np.ones([10])) + grads = ops.arange(0.1, 1.1, 0.1) + first_grads = ops.full((10,), 0.01) + + # fmt: off + golden = np.array( + [[0.9999, 0.9999, 0.9999, 0.9999, 0.9999, 0.9999, 0.9999, 0.9999, + 0.9999, 0.9999], [0.9989, 0.9979, 0.9969, 0.9959, 0.9949, 0.9939, + 0.9929, 0.9919, 0.9909, 0.9899], [0.9979, 0.9959, 0.9939, 0.9919, + 0.9899, 0.9879, 0.9859, 0.9839, 0.9819, 0.9799], [0.9969, 0.9939, + 0.9909, 0.9879, 0.9849, 0.9819, 0.9789, 0.9759, 0.9729, 0.9699], + [0.9959, 0.9919, 0.9879, 0.9839, 0.9799, 0.9759, 0.9719, 0.9679, + 0.9639, 0.9599]] + ) + # fmt: on + + optimizer.apply_gradients(zip([first_grads], [x])) + for i in range(5): + self.assertAllClose(x, golden[i], rtol=5e-4, atol=5e-4) + optimizer.apply_gradients(zip([grads], [x])) + + def test_clip_norm(self): + optimizer = SGD(clipnorm=1) + grad = [np.array([100.0, 100.0])] + clipped_grad = optimizer._clip_gradients(grad) + self.assertAllClose(clipped_grad[0], [2**0.5 / 2, 2**0.5 / 2]) + + def test_clip_value(self): + optimizer = SGD(clipvalue=1) + grad = [np.array([100.0, 100.0])] + clipped_grad = optimizer._clip_gradients(grad) + self.assertAllClose(clipped_grad[0], [1.0, 1.0]) diff --git a/keras/src/quantizers/__init__.py b/keras/src/quantizers/__init__.py new file mode 100644 index 000000000000..586530204588 --- /dev/null +++ b/keras/src/quantizers/__init__.py @@ -0,0 +1,57 @@ +import inspect + +from keras.src.api_export import keras_export +from keras.src.quantizers.quantizers import AbsMaxQuantizer +from keras.src.quantizers.quantizers import Quantizer +from keras.src.quantizers.quantizers import abs_max_quantize +from keras.src.quantizers.quantizers import compute_float8_amax_history +from keras.src.quantizers.quantizers import compute_float8_scale +from keras.src.quantizers.quantizers import fake_quant_with_min_max_vars +from keras.src.quantizers.quantizers import pack_int4 +from keras.src.quantizers.quantizers import quantize_and_dequantize +from keras.src.quantizers.quantizers import unpack_int4 +from keras.src.saving import serialization_lib +from keras.src.utils.naming import to_snake_case + +ALL_OBJECTS = {Quantizer, AbsMaxQuantizer} +ALL_OBJECTS_DICT = {cls.__name__: cls for cls in ALL_OBJECTS} +ALL_OBJECTS_DICT.update( + {to_snake_case(cls.__name__): cls for cls in ALL_OBJECTS} +) + + +@keras_export("keras.quantizers.serialize") +def serialize(initializer): + return serialization_lib.serialize_keras_object(initializer) + + +@keras_export("keras.quantizers.deserialize") +def deserialize(config, custom_objects=None): + """Return a Keras quantizer object via its config.""" + return serialization_lib.deserialize_keras_object( + config, + module_objects=ALL_OBJECTS_DICT, + custom_objects=custom_objects, + ) + + +@keras_export("keras.quantizers.get") +def get(identifier, **kwargs): + """Retrieve a Keras quantizer object via an identifier.""" + if identifier is None: + return None + if isinstance(identifier, dict): + obj = deserialize(identifier) + elif isinstance(identifier, str): + obj = ALL_OBJECTS_DICT.get(identifier, None) + else: + obj = identifier + + if callable(obj): + if inspect.isclass(obj): + obj = obj(kwargs) + return obj + else: + raise ValueError( + f"Could not interpret quantizer identifier: {identifier}" + ) diff --git a/keras/src/quantizers/gptq.py b/keras/src/quantizers/gptq.py new file mode 100644 index 000000000000..d1a04039afa5 --- /dev/null +++ b/keras/src/quantizers/gptq.py @@ -0,0 +1,495 @@ +import types +from functools import partial + +from keras.src import ops +from keras.src import quantizers +from keras.src.layers import Dense +from keras.src.layers import EinsumDense +from keras.src.ops import linalg +from keras.src.quantizers.gptq_config import GPTQConfig +from keras.src.quantizers.quantizers import GPTQQuantizer +from keras.src.quantizers.quantizers import compute_quantization_parameters +from keras.src.quantizers.quantizers import dequantize_with_zero_point +from keras.src.quantizers.quantizers import quantize_with_zero_point + + +def _stable_permutation(metric): + """Return a stable permutation that sorts `metric` in descending order. + Uses an index-based jitter to break ties deterministically.""" + n = ops.shape(metric)[0] + idx = ops.arange(0, n, dtype="int32") + # tiny jitter = (idx / n) * 1e-12 so it never flips a real strict ordering + jitter = ops.divide(ops.cast(idx, "float32"), ops.cast(n, "float32")) + metric_jittered = ops.add(metric, ops.multiply(jitter, 1e-12)) + # argsort by negative to get descending + return ops.argsort(ops.negative(metric_jittered)) + + +def gptq_quantize_matrix( + weights_transpose, + inv_hessian, + *, + blocksize=128, + group_size=-1, + activation_order=False, + order_metric=None, + compute_scale_zero=compute_quantization_parameters, +): + """ + Implements the GPTQ error correction updates. + + For a single column update (column j): + e = invH[j, j] * (w_j - q_j) + W[:, j+1:] -= e * invH[j, j+1:] + where: + - w_j is the original column, + - q_j is the quantized column, + - invH is the inverse Hessian, + - e is the propagated error term. + + Across entire blocks: + W[:, future] -= E_block * invH[block, future] + where: + - E_block is the quantization error accumulated for the current block, + - invH[block, future] denotes the cross-block slice of the inverse Hessian, + - W[:, future] are the columns yet to be quantized. + + Args: + weights_transpose: Transposed weight matrix [out_features, in_features] + to quantize. + inv_hessian: Inverse Hessian matrix [in_features, in_features] for + error propagation. + blocksize: Size of the blocks to process (default: 128). + group_size: Size of the groups for parameter reuse + (default: -1, no grouping). + activation_order: Whether to apply activation-order permutation + (default: False). + order_metric: Metric for ordering features + (default: None, uses 1 / diag(invH)). + compute_scale_zero: Function to compute scale and zero for + quantization. + + Returns: + quantized_weights: Quantized weight matrix [out_features, in_features]. + scale: float32. Scale parameters for quantization + [out_features, num_groups]. + zero: Zero-point parameters for quantization [out_features, num_groups]. + g_idx: int32. Group indices for each feature [in_features]. + """ + in_features = ops.shape(weights_transpose)[1] + + if activation_order: + # Use 1 / diag(inverse_hessian) as importance proxy by default. + if order_metric is None: + order_metric = ops.reciprocal( + ops.add(ops.diagonal(inv_hessian), 1e-12) + ) + else: + # sanitize provided metric + order_metric = ops.cast(order_metric, "float32") + order_metric = ops.where( + ops.isfinite(order_metric), + order_metric, + ops.zeros_like(order_metric), + ) + # Sort in descending order by importance + perm = _stable_permutation(order_metric) + inv_perm = ops.argsort(perm) + + weights_transpose = ops.take(weights_transpose, perm, axis=1) + inv_hessian = ops.take( + ops.take(inv_hessian, perm, axis=0), perm, axis=1 + ) + else: + perm = inv_perm = None + + # weights_buffer: [out_features, in_features] + weights_buffer = weights_transpose + # Buffer for the final quantized matrix: [out_features, in_features] + quantized_weights_buffer = ops.zeros_like(weights_transpose, dtype="int32") + + scale_chunks = [] + zero_chunks = [] + + # Compute effective group size + effective_group = in_features if group_size == -1 else group_size + + # Process features in blocks + for block_start in range(0, in_features, blocksize): + block_end = min(block_start + blocksize, in_features) + block_size = block_end - block_start + + # Block views + # block_weights: [out_features, block_size] + block_weights = weights_buffer[:, block_start:block_end] + # block_error: [out_features, block_size] + block_error = ops.zeros_like(block_weights) + # block_inv_hessian: [block_size, block_size] + block_inv_hessian = inv_hessian[ + block_start:block_end, block_start:block_end + ] + + # Per-group cached params for reuse within the group + cached_scale = None + cached_zero = None + cached_maxq = None + cached_group_start = -1 + + for block_idx in range(block_size): + # Current global column index, represents the original column + # in the weight matrix + global_idx = block_start + block_idx + # weight_column: [out_features,] + weight_column = block_weights[:, block_idx] + # Group-wise parameter reuse (compute once per group) + if not effective_group == in_features: # group_size != -1 + # Determine the group start index for the current column + group_start = (global_idx // effective_group) * effective_group + if group_start != cached_group_start: + # New group encountered, compute & cache params + # for this group + group_end = min(group_start + effective_group, in_features) + group_slice = weights_buffer[:, group_start:group_end] + cached_scale, cached_zero, cached_maxq = compute_scale_zero( + group_slice + ) + # Store params once per group (in the order encountered). + scale_chunks.append(cached_scale) + zero_chunks.append(cached_zero) + cached_group_start = group_start + scale, zero, maxq = cached_scale, cached_zero, cached_maxq + else: + # Single global group covering all columns. + if cached_scale is None: + cached_scale, cached_zero, cached_maxq = compute_scale_zero( + weights_buffer + ) + scale_chunks.append(cached_scale) + zero_chunks.append(cached_zero) + cached_group_start = 0 + scale, zero, maxq = cached_scale, cached_zero, cached_maxq + + # Quantize column and store it. + # quantized_column: [out_features, 1] + quantized_column = quantize_with_zero_point( + ops.expand_dims(weight_column, 1), scale, zero, maxq + ) + + # Store quantized column in the buffer. + quantized_weights_buffer = ops.slice_update( + quantized_weights_buffer, + (0, global_idx), + ops.cast(quantized_column, "int32"), + ) + # Dequantize column to compute error. + # dequantized_col: [out_features,] + dequantized_col = dequantize_with_zero_point( + quantized_column, scale, zero + )[:, 0] + # Error feedback for remaining columns within the block + # block_inv_hessian_diag: scalar + current_block_influence = block_inv_hessian[block_idx, block_idx] + # We divide by current_block_influence to get the + # correct scaling of the error term. + err = ops.divide( + ops.subtract(weight_column, dequantized_col), + current_block_influence, + ) + # Record error for propagation to future blocks + block_error = ops.slice_update( + block_error, (0, block_idx), ops.expand_dims(err, 1) + ) + + # Update remaining columns in the current block + # (those before the current column have already been quantized) + # Propagate error to remaining columns in the block. + if block_idx < block_size - 1: + # update: [out_features, block_size - block_idx - 1] + update = ops.matmul( + ops.expand_dims(err, 1), + ops.expand_dims( + block_inv_hessian[block_idx, block_idx + 1 :], 0 + ), + ) + # tail is a view of the remaining columns in the block + # to be updated + # tail: [out_features, block_size - block_idx - 1] + tail = block_weights[:, block_idx + 1 :] + block_weights = ops.slice_update( + block_weights, + (0, block_idx + 1), + ops.subtract(tail, update), + ) + + # Propagate block errors to future features (beyond the block) + if block_end < in_features: + # Total update for all future columns, based on the + # accumulated error in this block. This is calculated + # as the matrix product of the block_error and the + # relevant slice of the inverse Hessian. + # total_update: [out_features, in_features - block_end] + total_update = ops.matmul( + block_error, inv_hessian[block_start:block_end, block_end:] + ) + # Update the remaining weights in the buffer. This is done + # by subtracting the total_update from the remaining columns. + weights_buffer = ops.concatenate( + [ + weights_buffer[:, :block_end], + ops.subtract(weights_buffer[:, block_end:], total_update), + ], + axis=1, + ) + + # Build group indices for each (possibly permuted) column + # base_group = effective_group (int) + base_group = effective_group + + # g_idx in permuted domain + g_idx = ops.arange(0, in_features, dtype="int32") + g_idx = ops.divide(g_idx, base_group) + g_idx = ops.cast(g_idx, "float32") + + # Map group indices and quantized weights back to original column order + if activation_order: + g_idx = ops.take(g_idx, inv_perm, axis=0) + quantized_weights_buffer = ops.take( + quantized_weights_buffer, inv_perm, axis=1 + ) + + # Concatenate recorded group params + if len(scale_chunks) == 0: + # Edge case: no groups recorded (empty input); fall back to whole matrix + s, z, _ = compute_scale_zero(weights_transpose) + scale = s + zero = z + else: + scale = ops.concatenate(scale_chunks, axis=1) + zero = ops.concatenate(zero_chunks, axis=1) + + return quantized_weights_buffer, scale, zero, g_idx + + +class GPTQ: + def __init__(self, layer, config=GPTQConfig(tokenizer=None, dataset=None)): + self.original_layer = layer + self.num_samples = 0 + self.config = config + self.quantizer = GPTQQuantizer( + config, compute_dtype=layer.variable_dtype + ) + + # Explicitly handle each supported layer type + if isinstance(layer, Dense) or ( + isinstance(layer, EinsumDense) and layer.kernel.ndim == 2 + ): + # For a standard Dense layer, the dimensions are straightforward. + self.kernel_shape = layer.kernel.shape + # rows: [input_features] + self.rows = self.kernel_shape[0] + # columns: [output_features] + self.columns = self.kernel_shape[1] + self.layer = layer + + # Handle 3D EinsumDense layers (typically from attention blocks). + elif isinstance(layer, EinsumDense) and layer.kernel.ndim == 3: + # For EinsumDense, we determine the effective 2D dimensions. + self.kernel_shape = layer.kernel.shape + shape = list(self.kernel_shape) + try: + d_model_dim_index = shape.index(max(shape)) + except ValueError: + raise TypeError( + f"Could not determine hidden dimension from shape {shape}" + ) + + if d_model_dim_index == 0: # QKV projection case + in_features, heads, head_dim = shape + self.rows, self.columns = ( + in_features, + ops.multiply(heads, head_dim), + ) + elif d_model_dim_index in [1, 2]: # Attention Output case + heads, head_dim, out_features = shape + self.rows, self.columns = ( + ops.multiply(heads, head_dim), + out_features, + ) + + # Create a temporary object that holds a reshaped + # 2D version of the kernel. + self.layer = types.SimpleNamespace( + kernel=ops.reshape(layer.kernel, (self.rows, self.columns)), + ) + else: + # Raise an error if the layer is not supported. + raise TypeError(f"Unsupported layer type for GPTQ: {type(layer)}") + self.hessian = ops.zeros((self.rows, self.rows), dtype="float32") + + def update_hessian_with_batch(self, input_batch): + """ + Updates the running average of the Hessian matrix with a new batch. + + This method computes the Hessian matrix for a given batch of input + activations and updates the accumulated Hessian (`self.hessian`) using a + numerically stable running average. This allows the Hessian to be + computed over a large dataset without loading all samples into memory + at once. + + The input tensor is first reshaped into a 2D matrix [num_samples, + num_features] before the Hessian is calculated. + + Args: + input_batch: A 2D or higher-dimensional tensor of input activations + from a calibration batch. + + Raises: + ValueError: If the feature dimension of the input tensor + `input_batch` does not match the dimensions of the + pre-initialized Hessian matrix `self.hessian`. + """ + if input_batch is None: + raise ValueError("Input tensor cannot be None.") + + if len(input_batch.shape) < 2: + raise ValueError( + "Input tensor must have rank >= 2 " + f"(got rank {len(input_batch.shape)})." + ) + if ops.size(input_batch) == 0: + raise ValueError("Input tensor cannot be empty.") + if len(input_batch.shape) > 2: + # [batch, features] + input_batch = ops.reshape(input_batch, (-1, input_batch.shape[-1])) + x = ops.cast(input_batch, "float32") + + num_new_samples = ops.shape(x)[0] + num_prev_samples = self.num_samples + total_samples = ops.add(num_prev_samples, num_new_samples) + + if ops.shape(self.hessian)[0] != ops.shape(x)[-1]: + raise ValueError( + f"Hessian dimensions ({ops.shape(self.hessian)[0]}) do not " + f"match input features ({ops.shape(x)[-1]})." + ) + + # gram_matrix: [features, features] + gram_matrix = ops.matmul(ops.transpose(x), x) + # Ensures numerical stability and symmetry in case of large floating + # point activations. + gram_matrix = ops.divide( + ops.add(gram_matrix, ops.transpose(gram_matrix)), 2.0 + ) + + # Decay previous mean and add current per-sample contribution + # (factor 2/N) + if self.num_samples > 0: + self.hessian = ops.multiply( + self.hessian, ops.divide(num_prev_samples, total_samples) + ) + + self.hessian = ops.add( + self.hessian, + ops.multiply(ops.divide(2.0, total_samples), gram_matrix), + ) + + self.num_samples = self.num_samples + ops.shape(x)[0] or 0 + + def quantize_and_correct_layer( + self, + blocksize=128, + ): + """ + Performs GPTQ quantization and correction on the layer's weights. + + This method implements the core logic of the "Optimal Brain Quant" + (OBQ) method, as applied by GPTQ, to quantize the weights of a single + layer. It iteratively quantizes blocks of weights and corrects for the + quantization error by updating the remaining weights. + + The algorithm follows these main steps: + 1. Initialization: It optionally reorders the weight columns based + on activation magnitudes (`activation_order=True`) to protect more + salient + weights. + 2. Hessian Modification: The Hessian matrix, pre-computed from + calibration data, is dampened to ensure its invertibility and + stability. + 3. Iterative Quantization: The function iterates through the + weight columns in blocks (`blocksize`). In each iteration, it: + a. Quantizes one column. + b. Calculates the quantization error. + c. Updates the remaining weights in the *current* block by + distributing the error, using the inverse Hessian. + 4. Block-wise Correction: After a block is quantized, the total + error from that block is propagated to the *next* block of weights + to be processed. + 5. Finalization: The quantized weights are reordered back if + `activation_order` was used, and the layer's weights are updated. + This implementation is based on the official GPTQ paper and repository. + For more details, see: + - Paper: https://arxiv.org/abs/2210.17323 + - Original Code: https://github.com/IST-DASLab/gptq + + + Args: + blocksize: (int, optional) The size of the weight block to process + at a time. Defaults to 128. + """ + weights_matrix = ops.transpose(self.layer.kernel) + + # Dampen the Hessian for Stability + hessian_diagonal = ops.diagonal(self.hessian) + dead_diagonal = ops.equal(hessian_diagonal, 0.0) + hessian_diagonal = ops.where(dead_diagonal, 1.0, hessian_diagonal) + hessian_matrix = ops.add( + self.hessian, + ops.diag( + ops.where(dead_diagonal, 1.0, ops.zeros_like(hessian_diagonal)) + ), + ) + + # Add dampening factor to the Hessian diagonal + damping_factor = ops.multiply( + self.config.hessian_damping, ops.mean(hessian_diagonal) + ) + hessian_diagonal = ops.add(hessian_diagonal, damping_factor) + hessian_matrix = ops.add( + ops.subtract( + hessian_matrix, ops.diag(ops.diagonal(hessian_matrix)) + ), + ops.diag(hessian_diagonal), + ) + + # Compute the inverse Hessian, which is used for error correction + inverse_hessian = linalg.inv(hessian_matrix) + + quantized, scale, zero, g_idx = gptq_quantize_matrix( + weights_matrix, + inv_hessian=inverse_hessian, + blocksize=blocksize, + group_size=self.config.group_size, + activation_order=self.config.activation_order, + order_metric=ops.diagonal(hessian_matrix), + compute_scale_zero=partial(self.quantizer.find_params, weight=True), + ) + quantized = ops.cast( + quantized, self.original_layer.quantized_kernel.dtype + ) + + if self.config.weight_bits == 4: + # For 4-bit weights, we need to pack them into bytes + quantized, _, _ = quantizers.pack_int4( + quantized, axis=0, dtype="uint8" + ) + + del self.original_layer._kernel + self.original_layer.quantized_kernel.assign(quantized) + self.original_layer.kernel_scale.assign(scale) + self.original_layer.kernel_zero.assign(zero) + self.original_layer.g_idx.assign(g_idx) + self.original_layer.is_gptq_calibrated = True + + def free(self): + del self.hessian + del self.layer diff --git a/keras/src/quantizers/gptq_config.py b/keras/src/quantizers/gptq_config.py new file mode 100644 index 000000000000..eaf9434ee192 --- /dev/null +++ b/keras/src/quantizers/gptq_config.py @@ -0,0 +1,184 @@ +from keras.src.api_export import keras_export + + +@keras_export("keras.quantizers.GPTQConfig") +class GPTQConfig: + """Configuration class for the GPTQ (Gradient-based Post-Training + Quantization) algorithm. + + GPTQ is a post-training quantization method that quantizes neural network + weights to lower precision (e.g., 4-bit) while minimizing the impact on + model accuracy. It works by analyzing the Hessian matrix of the loss + function with respect to the weights and applying optimal quantization + that preserves the most important weight values. + + **When to use GPTQ:** + - You want to reduce model size and memory usage + - You need faster inference on hardware that supports low-precision + operations + - You want to maintain model accuracy as much as possible + - You have a pre-trained model that you want to quantize without + retraining + + **How it works:** + 1. Uses calibration data to compute the Hessian matrix for each layer + 2. Applies iterative quantization with error correction + 3. Reorders weights based on activation importance (optional) + 4. Quantizes weights while minimizing quantization error + + **Example usage:** + ```python + from keras.quantizers import GPTQConfig + from keras import Model + + # Create configuration for 4-bit quantization + config = GPTQConfig( + dataset=calibration_data, # Your calibration dataset + tokenizer=your_tokenizer, # Tokenizer for text data + weight_bits=4, # Quantize to 4 bits + num_samples=128, # Number of calibration samples + sequence_length=512, # Sequence length for each sample + hessian_damping=0.01, # Hessian stabilization factor + group_size=128, # Weight grouping for quantization + symmetric=False, # Use asymmetric quantization + activation_order=True # Reorder weights by importance + ) + + # Apply quantization to your model + model = Model(...) # Your pre-trained model + model.quantize("gptq", config=config) + + # The model now has quantized weights and can be used for inference + ``` + + **Benefits:** + - **Memory reduction**: 4-bit quantization reduces memory by ~8x compared + to float32 + - **Faster inference**: Lower precision operations are faster on supported + hardware + - **Accuracy preservation**: Minimizes accuracy loss through optimal + quantization + - **No retraining required**: Works with pre-trained models + + **Advanced usage examples:** + + **Per-channel quantization (recommended for most cases):** + ```python + config = GPTQConfig( + dataset=calibration_data, + tokenizer=tokenizer, + weight_bits=4, + group_size=-1, # -1 enables per-channel quantization + symmetric=False + ) + ``` + + **Grouped quantization (for specific hardware requirements):** + ```python + config = GPTQConfig( + dataset=calibration_data, + tokenizer=tokenizer, + weight_bits=4, + group_size=64, # 64 weights share the same scale factor + symmetric=True # Use symmetric quantization + ) + ``` + + **High-accuracy quantization with activation ordering:** + ```python + config = GPTQConfig( + dataset=calibration_data, + tokenizer=tokenizer, + weight_bits=4, + activation_order=True, # Reorder weights by importance + hessian_damping=0.005, # Lower damping for more precise + # quantization + num_samples=256 # More samples for better accuracy + ) + ``` + + **References:** + - Original GPTQ paper: "GPTQ: Accurate Post-Training Quantization + for Generative Pre-trained Transformers" + - Implementation based on: https://github.com/IST-DASLab/gptq + - Suitable for: Transformer models, large language models, and other + deep neural networks + + **Note:** The quality of quantization depends heavily on the calibration + dataset. Use representative data that covers the expected input + distribution for best results. + + Args: + dataset: The calibration dataset. It can be an iterable that yields + strings or pre-tokenized numerical tensors (e.g., a list of + strings, a generator, or a NumPy array). This data is used to + analyze the model's activations. + tokenizer: A `keras_nlp.Tokenizer` instance (or a similar callable) + that is used to process the `dataset` if it contains strings. + weight_bits: (int, optional) The number of bits to quantize weights to. + Defaults to 4. + num_samples: (int, optional) The number of calibration data samples to + use from the dataset. Defaults to 128. + sequence_length: (int, optional) The sequence length to use for each + calibration sample. Defaults to 512. + hessian_damping: (float, optional) The % of Hessian damping to use for + stabilization during inverse calculation. Defaults to 0.01. + group_size: (int, optional) The size of weight groups to quantize + together. A `group_size` of -1 indicates per-channel quantization. + Defaults to 128. + symmetric: (bool, optional) If `True`, uses symmetric quantization. + If `False`, uses asymmetric quantization. Defaults to `False`. + activation_order: (bool, optional) If `True`, reorders weight columns + based on activation magnitude, which can improve quantization + accuracy. Defaults to `False`. + """ + + def __init__( + self, + dataset, + tokenizer, + *, + weight_bits: int = 4, + num_samples: int = 128, + per_channel: bool = True, + sequence_length: int = 512, + hessian_damping: float = 0.01, + group_size: int = 128, + symmetric: bool = False, + activation_order: bool = False, + ): + if weight_bits not in [2, 3, 4, 8]: + raise ValueError( + f"Unsupported weight_bits {weight_bits}. " + "Supported values are 2, 3, 4, and 8." + ) + if num_samples <= 0: + raise ValueError("num_samples must be a positive integer.") + if sequence_length <= 0: + raise ValueError("sequence_length must be a positive integer.") + if hessian_damping < 0 or hessian_damping > 1: + raise ValueError("hessian_damping must be between 0 and 1.") + if group_size < -1 or group_size == 0: + raise ValueError( + "Invalid group_size. Supported values are -1 (whole-tensor) " + "or a positive integer, " + f"but got {group_size}." + ) + self.dataset = dataset + self.tokenizer = tokenizer + self.num_samples = num_samples + self.per_channel = per_channel + self.sequence_length = sequence_length + self.hessian_damping = hessian_damping + self.weight_bits = weight_bits + self.group_size = group_size + self.symmetric = symmetric + self.activation_order = activation_order + + def dtype_policy_string(self): + """Returns the dtype policy string for this configuration. + + Returns: + A string representing the dtype policy, e.g. "gptq_4bit". + """ + return f"gptq/{self.weight_bits}/{self.group_size}" diff --git a/keras/src/quantizers/gptq_config_test.py b/keras/src/quantizers/gptq_config_test.py new file mode 100644 index 000000000000..0bdd4607cd0f --- /dev/null +++ b/keras/src/quantizers/gptq_config_test.py @@ -0,0 +1,52 @@ +from keras.src import testing +from keras.src.quantizers.gptq_config import GPTQConfig + + +class TestGPTQConfig(testing.TestCase): + def test_invalid_weight_bits(self): + with self.assertRaisesRegex(ValueError, "Unsupported weight_bits"): + GPTQConfig(dataset=None, tokenizer=None, weight_bits=1) + with self.assertRaisesRegex(ValueError, "Unsupported weight_bits"): + GPTQConfig(dataset=None, tokenizer=None, weight_bits=5) + + def test_invalid_num_samples(self): + with self.assertRaisesRegex( + ValueError, "num_samples must be a positive" + ): + GPTQConfig(dataset=None, tokenizer=None, num_samples=0) + with self.assertRaisesRegex( + ValueError, "num_samples must be a positive" + ): + GPTQConfig(dataset=None, tokenizer=None, num_samples=-1) + + def test_invalid_sequence_length(self): + with self.assertRaisesRegex( + ValueError, "sequence_length must be a positive" + ): + GPTQConfig(dataset=None, tokenizer=None, sequence_length=0) + with self.assertRaisesRegex( + ValueError, "sequence_length must be a positive" + ): + GPTQConfig(dataset=None, tokenizer=None, sequence_length=-10) + + def test_invalid_hessian_damping(self): + with self.assertRaisesRegex( + ValueError, "hessian_damping must be between" + ): + GPTQConfig(dataset=None, tokenizer=None, hessian_damping=-0.1) + with self.assertRaisesRegex( + ValueError, "hessian_damping must be between" + ): + GPTQConfig(dataset=None, tokenizer=None, hessian_damping=1.1) + + def test_invalid_group_size(self): + with self.assertRaisesRegex(ValueError, "Invalid group_size"): + GPTQConfig(dataset=None, tokenizer=None, group_size=0) + with self.assertRaisesRegex(ValueError, "Invalid group_size"): + GPTQConfig(dataset=None, tokenizer=None, group_size=-2) + + def test_dtype_policy_string(self): + config = GPTQConfig( + dataset=None, tokenizer=None, weight_bits=4, group_size=64 + ) + assert config.dtype_policy_string() == "gptq/4/64" diff --git a/keras/src/quantizers/gptq_core.py b/keras/src/quantizers/gptq_core.py new file mode 100644 index 000000000000..b97e929e37d2 --- /dev/null +++ b/keras/src/quantizers/gptq_core.py @@ -0,0 +1,462 @@ +import math +from contextlib import contextmanager + +import numpy as np +from absl import logging + +from keras.src import ops +from keras.src import utils as keras_utils +from keras.src.dtype_policies.dtype_policy import GPTQDTypePolicy +from keras.src.dtype_policies.dtype_policy_map import DTypePolicyMap +from keras.src.layers import Dense +from keras.src.layers import EinsumDense +from keras.src.layers import Embedding +from keras.src.quantizers.gptq import GPTQ +from keras.src.quantizers.gptq_config import GPTQConfig + + +@contextmanager +def stream_hessians(layers_map, gptq_objects): + """ + Temporarily monkey-patch each target layer's `call` method so + that input activations are streamed into the GPTQ instance + running Hessian estimate at capture time. + + On `__enter__`: For every (name, layer) in `layers_map`, replaces + `layer.call` with a wrapper that: + 1) extracts the layer input from `*args`/`**kwargs`, + 2) reshapes it to 2D `[-1, rows]` where + `rows = gptq_objects[name].rows`, + 3) calls `gptq_objects[name].update_hessian_with_batch(x2d)` + 4) delegates to the original `layer.call` and returns its + output. + + On `__exit__`: All original `layer.call` methods are restored even if an + exception occurs. + + * Space complexity: O(d**2) per layer (for the Hessian). + * No weights are modified; only GPTQ statistics are updated. + + Args: + layers_map: Dict[str, Layer]. Mapping from logical layer names to + the Keras layers that should be patched during calibration. Keys must + match `gptq_objects`. + gptq_objects: Dict[str, GPTQ]. Mapping from names to GPTQ instances. + + Yields: + None: The patched state is active only within the `with` block. After + exit, all layers are unpatched and safe to use normally. + + Example: + ```python + >>> with stream_hessians(layers_map, gptq_objects): + ... for sample in calibration_inputs: + ... if len(sample.shape) == 2: + ... sample = ops.expand_dims(sample, 0) + ... _ = block(sample) # hooks update Hessians on-the-fly + >>> # <- original layer.call methods restored here + ``` + """ + original_calls = {} + + def create_hook(name, original_call_func): + def hook(*args, **kwargs): + inp = args[0] if args else kwargs["inputs"] + # Explicitly reshape the input tensor to be 2D, with the + # second dimension matching the number of input features + # expected by the layer's kernel. + # This correctly handles inputs of any dimensionality + # (e.g., 3D or 4D). + num_features = gptq_objects[name].rows + input_2d = ops.reshape(inp, (-1, num_features)) + gptq_objects[name].update_hessian_with_batch(input_2d) + return original_call_func(*args, **kwargs) + + return hook + + try: + for name, layer in layers_map.items(): + original_calls[name] = layer.call + layer.call = create_hook(name, layer.call) + yield + finally: + for name, layer in layers_map.items(): + layer.call = original_calls[name] + + +def get_dataloader( + tokenizer, + sequence_length, + dataset, + num_samples=128, + *, + strategy="strided", + seed=42, + stride=None, + eos_id=None, +): + """ + Prepares and chunks the calibration dataloader, repeating short datasets. + All processing happens on the CPU. + + Args: + tokenizer: The tokenizer to use for text splitting. + sequence_length: The length of each input sequence. + dataset: The dataset to sample from. + num_samples: The number of samples to generate. + strategy: The sampling strategy to use. Possible values are + 1. "strided": Samples are taken at regular intervals. + 2. "linspace": Samples are taken at evenly spaced intervals. + 3. "random": Samples are taken at random positions. + seed: The random seed for reproducibility. Used only if + strategy="random" + stride: The stride length for "strided" sampling. + eos_id: The end-of-sequence token ID. + + Returns: + np.ndarray of shape (num_samples, 1, sequence_length), dtype int32. + """ + if not hasattr(dataset, "__iter__") or isinstance(dataset, (str, bytes)): + raise TypeError( + "The `dataset` argument must be an iterable (e.g., a list of " + "strings, a generator, or a NumPy array). Got type: " + f"{type(dataset).__name__}. Please pass the loaded dataset " + "directly." + ) + + dataset_list = list(dataset) + if not dataset_list: + raise ValueError("Provided dataset is empty.") + + pieces = [] + if isinstance(dataset_list[0], str): + for i, s in enumerate(dataset_list): + toks = np.asarray(tokenizer.tokenize(s)).reshape(-1) + pieces.append(toks) + # avoid windows that span document boundaries + if eos_id is not None and i < len(dataset_list) - 1: + pieces.append(np.array([eos_id], dtype=np.int32)) + else: + for s in dataset_list: + toks = ops.convert_to_numpy(s).reshape(-1) + pieces.append(toks.astype(np.int32, copy=False)) + + all_tokens = ( + pieces[0].astype(np.int32, copy=False) + if len(pieces) == 1 + else np.concatenate(pieces, axis=0).astype(np.int32, copy=False) + ) + + required_tokens = num_samples * sequence_length + if all_tokens.size < required_tokens: + repeats = math.ceil(required_tokens / max(1, all_tokens.size)) + all_tokens = np.tile(all_tokens, repeats) + + max_start = all_tokens.size - sequence_length + if max_start < 0: + raise ValueError( + f"Not enough tokens to form one sample of length {sequence_length} " + f"(have {all_tokens.size})." + ) + + # Choose deterministic, well-spread starts by default + if strategy == "random": + rng = np.random.default_rng(seed) + starts = rng.integers( + 0, max_start + 1, size=num_samples, dtype=np.int64 + ) + elif strategy == "linspace": + # even coverage with no RNG + starts = np.linspace(0, max_start, num_samples, dtype=np.int64) + elif strategy == "strided": + # stride chosen to cover the space roughly uniformly + if stride is None: + stride = max(1, (max_start + 1) // num_samples) + # offset derived deterministically from seed + offset = ( + (abs(hash(("gptq-calib", seed))) % (max_start + 1)) + if max_start > 0 + else 0 + ) + starts = (offset + np.arange(num_samples, dtype=np.int64) * stride) % ( + max_start + 1 + ) + else: + raise ValueError(f"Unknown strategy: {strategy}") + + # Gather contiguous windows + # sliding_window_view avoids building a big index matrix + windows = np.lib.stride_tricks.sliding_window_view( + all_tokens, sequence_length + ) + samples = windows[starts] # (num_samples, sequence_length) + return samples.astype(np.int32)[:, None, :] + + +def _get_backbone_layers(model): + """Extract embedding and transformer layers from a KerasHub model.""" + backbone = model.backbone + if not hasattr(backbone, "transformer_layers"): + raise ValueError( + "The model's backbone does not have a 'transformer_layers' " + "attribute. Please ensure you are using a standard KerasHub " + "transformer model." + ) + transformer_blocks = backbone.transformer_layers + + embedding_layer = None + if hasattr(backbone, "token_embedding"): + embedding_layer = backbone.token_embedding + elif hasattr(backbone, "embedding"): + embedding_layer = backbone.embedding + return embedding_layer, transformer_blocks + + +def _get_custom_layers(model): + """Heuristic for extracting embedding + transformer blocks from a custom + model.""" + embedding_layer = None + transformer_blocks = [] + for layer in model.layers: + if isinstance(layer, Embedding) and embedding_layer is None: + embedding_layer = layer + elif getattr(layer, "_layers", None): # container-like block + transformer_blocks.append(layer) + return embedding_layer, transformer_blocks + + +def find_layers_in_block(block): + """ + Finds all Dense and EinsumDense layers in a transformer block. + + Args: + block: A Keras layer representing a transformer block. + Returns: + A dict mapping layer paths to the corresponding Dense or EinsumDense + """ + found_layers = {} + for sub_layer in block._flatten_layers(): + if len(list(sub_layer._flatten_layers())) == 1: + if isinstance(sub_layer, (Dense, EinsumDense)): + found_layers[sub_layer.path] = sub_layer + return found_layers + + +def apply_gptq_layerwise(model, dataloader, config): + """Applies GPTQ quantization layer-by-layer to a Keras model. + + This function is designed to work with common transformer architectures, + like those provided by KerasHub. It automatically discovers the model's + structure by first looking for the standard format: a `model.backbone` + attribute that contains a `transformer_layers` list. + + If a standard backbone is not found, it falls back to a heuristic for + custom models, where it assumes the first `keras.layers.Embedding` layer + is the input embedding and any subsequent container layers are the + transformer blocks to be quantized. + + The core logic operates as follows: + 1. It automatically detects the model's structure, identifying the main + embedding layer and a sequence of transformer blocks. + 2. It processes the model sequentially, one block at a time. For each + block, it uses temporary hooks to capture the input activations of + each target layer during a forward pass with the calibration data. + 3. These captured activations are used to compute the Hessian matrix for + each layer's weights. + 4. The GPTQ algorithm is then applied to each layer to find the optimal + quantized weights that minimize the error introduced. + 5. The output activations from the current block are then used as the + input for the next block, ensuring that quantization errors are + accounted for throughout the model. + + Args: + model: The Keras model instance to be quantized. The function will + attempt to automatically discover its structure. + dataloader: An iterable providing calibration data. Each item should + be a batch of token IDs suitable for the model's embedding layer. + config: A GPTQConfiguration object. + + Raises: + ValueError: If the function cannot automatically find an embedding + layer or any transformer-like blocks to quantize within the model. + """ + + num_samples = config.num_samples + + logging.info("Starting model quantization...") + embedding_layer = None + transformer_blocks = [] + if hasattr(model, "backbone"): + logging.info("Detected KerasHub model structure.") + embedding_layer, transformer_blocks = _get_backbone_layers(model) + else: + logging.info("Detected custom model structure.") + embedding_layer, transformer_blocks = _get_custom_layers(model) + + if embedding_layer is None: + raise ValueError( + "Could not automatically find an embedding layer in the model." + ) + if not transformer_blocks: + raise ValueError( + "Could not automatically find any transformer-like blocks to " + "quantize." + ) + + # Initial inputs are the outputs of the token embedding layer + inputs = [ + embedding_layer(ops.convert_to_tensor(batch, dtype="int32")) + for batch in dataloader + ] + num_samples = min(num_samples, len(inputs)) + + progbar = keras_utils.Progbar(target=len(transformer_blocks)) + + for block_idx, block in enumerate(transformer_blocks): + logging.info(f"Quantizing Block {block_idx}") + sub_layers_map = find_layers_in_block(block) + + if not sub_layers_map: + logging.info( + f" No Dense or EinsumDense layers found in block {block_idx}. " + "Skipping." + ) + else: + logging.info(f"Found layers: {list(sub_layers_map.keys())}") + gptq_objects = { + name: GPTQ(layer, config) + for name, layer in sub_layers_map.items() + } + + with stream_hessians(sub_layers_map, gptq_objects): + for sample_idx in range(num_samples): + current_input = inputs[sample_idx] + if len(current_input.shape) == 2: + current_input = ops.expand_dims(current_input, axis=0) + _ = block(current_input) + + for name, gptq_object in gptq_objects.items(): + logging.info(f"Quantizing {name}...") + gptq_object.quantize_and_correct_layer() + gptq_object.free() + + del gptq_objects + + if block_idx < len(transformer_blocks) - 1: + logging.info(f"Generating inputs for block {block_idx + 1}...") + next_block_inputs = [] + for sample_idx in range(num_samples): + current_input = inputs[sample_idx] + if len(current_input.shape) == 2: + current_input = ops.expand_dims(current_input, axis=0) + output = block(current_input)[0] + next_block_inputs.append(output) + inputs = next_block_inputs + progbar.update(current=block_idx + 1) + + logging.info("Quantization process complete.") + + +def gptq_quantize(model, config): + """ + Top-level function to quantize a Keras model using GPTQ. + """ + logging.info("Starting GPTQ quantization process...") + + # Load all data needed from the generator/source in a single call. + total_samples_to_request = config.num_samples + dataloader = get_dataloader( + config.tokenizer, + config.sequence_length, + config.dataset, + num_samples=total_samples_to_request, + ) + + # Split the materialized data. This works because dataloader + # is now a NumPy array, which can be sliced and reused. + calibration_dataloader = dataloader[: config.num_samples] + + apply_gptq_layerwise(model, calibration_dataloader, config) + + +def get_group_size_for_layer(layer, config): + """Determine the group size for GPTQ quantization. + + The group size can be specified either through the `config` argument + or through the `dtype_policy` if it is of type `GPTQDTypePolicy`. + + The config argument is usually available when quantizing the layer + via the `quantize` method. If the layer was deserialized from a + saved model, the group size should be specified in the `dtype_policy`. + + Args: + config: An optional configuration object that may contain the + `group_size` attribute. + Returns: + int. The determined group size for GPTQ quantization. + Raises: + ValueError: If the group size is not specified in either the + `config` or the `dtype_policy`. + """ + if config and isinstance(config, GPTQConfig): + return config.group_size + elif isinstance(layer.dtype_policy, GPTQDTypePolicy): + return layer.dtype_policy.group_size + elif isinstance(layer.dtype_policy, DTypePolicyMap): + policy = layer.dtype_policy[layer.path] + if not isinstance(policy, GPTQDTypePolicy): + # This should never happen based on how we set the + # quantization mode, but we check just in case. + raise ValueError( + "Expected a `dtype_policy` of type `GPTQDTypePolicy`." + f"Got: {type(policy)}" + ) + return policy.group_size + else: + raise ValueError( + "For GPTQ quantization, the group_size must be specified" + "either through a `dtype_policy` of type " + "`GPTQDTypePolicy` or the `config` argument." + ) + + +def get_weight_bits_for_layer(layer, config): + """Determine the number of weight bits for GPTQ quantization. + + The number of weight bits can be specified either through the `config` + argument or through the `dtype_policy` if it is of type + `GPTQDTypePolicy`. + + The config argument is usually available when quantizing the layer + via the `quantize` method. If the layer was deserialized from a + saved model, the weight bits should be specified in the `dtype_policy`. + + Args: + config: An optional configuration object that may contain the + `weight_bits` attribute. + Returns: + int. The determined number of weight bits for GPTQ quantization. + Raises: + ValueError: If the weight bits is not specified in either the + `config` or the `dtype_policy`. + """ + if config and isinstance(config, GPTQConfig): + return config.weight_bits + elif isinstance(layer.dtype_policy, GPTQDTypePolicy): + return layer.dtype_policy.weight_bits + elif isinstance(layer.dtype_policy, DTypePolicyMap): + policy = layer.dtype_policy[layer.path] + if not isinstance(policy, GPTQDTypePolicy): + # This should never happen based on how we set the + # quantization mode, but we check just in case. + raise ValueError( + "Expected a `dtype_policy` of type `GPTQDTypePolicy`." + f"Got: {type(policy)}" + ) + return policy.weight_bits + else: + raise ValueError( + "For GPTQ quantization, the weight_bits must be specified" + "either through a `dtype_policy` of type " + "`GPTQDTypePolicy` or the `config` argument." + ) diff --git a/keras/src/quantizers/gptq_core_test.py b/keras/src/quantizers/gptq_core_test.py new file mode 100644 index 000000000000..5ac0ecba3787 --- /dev/null +++ b/keras/src/quantizers/gptq_core_test.py @@ -0,0 +1,311 @@ +import numpy as np +import pytest +from absl.testing import parameterized + +from keras.src import layers +from keras.src import models +from keras.src import ops +from keras.src import testing +from keras.src.quantizers.gptq_config import GPTQConfig +from keras.src.quantizers.gptq_core import get_dataloader + +VOCAB_SIZE = 100 + + +class MockTokenizer: + """A mock tokenizer that mimics the real API for testing.""" + + def tokenize(self, text): + return [ord(c) % VOCAB_SIZE for c in "".join(text)] + + def __call__(self, text): + return self.tokenize(text) + + +class EmptyBlock(layers.Layer): + """A block that contains no quantizable layers.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.ln = layers.LayerNormalization() + + def call(self, inputs): + return self.ln(inputs) + + +class TransformerBlock(layers.Layer): + """A toy transformer block with a quantizable Dense layer.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.dense = layers.Dense(128) + + def call(self, inputs): + return self.dense(inputs) + + +def _get_model_with_backbone( + has_transformer_layers=True, embedding_name="embedding" +): + """Creates a KerasHub-style model with a backbone.""" + + class Backbone(layers.Layer): + def __init__(self, vocab_size, embedding_dim=128, **kwargs): + super().__init__(**kwargs) + # Use direct assignment + setattr( + self, + embedding_name, + layers.Embedding(vocab_size, embedding_dim), + ) + + # Keep track of layers in a list for the call method + self.transformer_layers = [] + if has_transformer_layers: + self.transformer_layers.append(TransformerBlock()) + + def call(self, inputs): + x = getattr(self, embedding_name)(inputs) + for layer in self.transformer_layers: + x = layer(x) + return x + + class Model(models.Model): + def __init__(self, vocab_size, **kwargs): + super().__init__(**kwargs) + # Pass configuration directly + self.backbone = Backbone(vocab_size=vocab_size) + self.classifier = layers.Dense(1, activation="sigmoid") + + def call(self, inputs): + x = self.backbone(inputs) + x = layers.GlobalAveragePooling1D()(x) + return self.classifier(x) + + model = Model(vocab_size=VOCAB_SIZE) + rng = np.random.default_rng(seed=42) + dummy_input = rng.normal(loc=0, scale=1, size=(2, 64)).astype(np.float32) + + _ = model(dummy_input) + return model + + +def build_all_tokens_strings(dataset, tokenizer, eos_id=None): + pieces = [] + for i, s in enumerate(dataset): + toks = np.asarray(tokenizer.tokenize(s), dtype=np.int32).reshape(-1) + pieces.append(toks) + if eos_id is not None and i < len(dataset) - 1: + pieces.append(np.array([eos_id], dtype=np.int32)) + return np.concatenate(pieces, axis=0).astype(np.int32, copy=False) + + +def sliding_windows(x, L): + return np.lib.stride_tricks.sliding_window_view(x, L) + + +@pytest.mark.requires_trainable_backend +class TestGPTQCore(testing.TestCase): + @parameterized.named_parameters( + [("strided", "strided"), ("linspace", "linspace"), ("random", "random")] + ) + def test_shape_and_dtype_strings(self, strategy): + """Test the shape and dtype of the output for string inputs.""" + tok = MockTokenizer() + dataset = ["a b c d e f g", "h i j k"] + seq_len, n = 5, 7 + + out = get_dataloader( + tok, seq_len, dataset, num_samples=n, strategy=strategy, seed=123 + ) + self.assertEqual(out.shape, (n, 1, seq_len)) + self.assertEqual(out.dtype, np.int32) + + @parameterized.named_parameters( + [("strided", "strided"), ("linspace", "linspace"), ("random", "random")] + ) + def test_shape_and_dtype_pretokenized(self, strategy): + """Test the shape and dtype of the output for pre-tokenized inputs.""" + tok = MockTokenizer() + # Pre-tokenized inputs; mixed shapes (1, L) and (L,) + seqs = [ + np.array([[1, 2, 3, 4]], dtype=np.int64), + np.array([5, 6], dtype=np.int64), + ] + tok = MockTokenizer() + seq_len, n = 3, 4 + + out = get_dataloader( + tok, seq_len, seqs, num_samples=n, strategy=strategy, seed=7 + ) + self.assertEqual(out.shape, (n, 1, seq_len)) + self.assertEqual(out.dtype, np.int32) + + def test_strided_is_deterministic_for_same_args(self): + tok = MockTokenizer() + dataset = ["a b c d e", "f g h i j k"] + out1 = get_dataloader( + tok, 4, dataset, num_samples=6, strategy="strided", seed=99 + ) + out2 = get_dataloader( + tok, 4, dataset, num_samples=6, strategy="strided", seed=99 + ) + self.assertTrue(ops.all(ops.equal(out1, out2))) + + def test_random_reproducibility_by_seed(self): + tok = MockTokenizer() + dataset = ["a b c d e", "f g h i j k"] + a = get_dataloader( + tok, 4, dataset, num_samples=6, strategy="random", seed=123 + ) + b = get_dataloader( + tok, 4, dataset, num_samples=6, strategy="random", seed=123 + ) + c = get_dataloader( + tok, 4, dataset, num_samples=6, strategy="random", seed=124 + ) + self.assertTrue(ops.all(ops.equal(a, b))) + self.assertFalse(ops.all(ops.equal(a, c))) + + def test_linspace_windows_match_expected(self): + tok = MockTokenizer() + dataset = ["aa bb cc dd", "ee ff gg"] + seq_len, n = 3, 5 + eos_id = None + + all_tokens = build_all_tokens_strings(dataset, tok, eos_id=eos_id) + max_start = all_tokens.size - seq_len + expected_starts = np.linspace(0, max_start, n, dtype=np.int64) + + expected = sliding_windows(all_tokens, seq_len)[expected_starts] + got = get_dataloader( + tok, seq_len, dataset, num_samples=n, strategy="linspace" + ) + self.assertTrue( + ops.all(ops.equal(got[:, 0, :], expected.astype(np.int32))) + ) + + def test_strided_override_respected(self): + """Tests that strided windows are disjoint and cover the input.""" + tok = MockTokenizer() + # 20 tokens total + # with seq_len=4 and stride=4, we expect disjoint chunks + # in order (modulo offset) + dataset = [" ".join([f"t{i}" for i in range(20)])] + seq_len, n, stride = 4, 5, 4 + + out = get_dataloader( + tok, + seq_len, + dataset, + num_samples=n, + strategy="strided", + stride=stride, + seed=0, + ) + + # Validate that each sample is a contiguous run + # of length seq_len from the flattened stream + flat = build_all_tokens_strings(dataset, tok) + for s in out[:, 0, :]: + # Each window should appear as a slice in the flat stream + # (This is a soft check; exact start positions depend on offset.) + joined = " ".join(map(str, s.tolist())) + self.assertIn(joined, " ".join(map(str, flat.tolist()))) + + def test_eos_insertion_is_present_in_some_window_with_linspace(self): + tok = MockTokenizer() + dataset = ["aa aa", "bb bb"] # len = 5 + 1(EOS) + 5 = 11 + eos = 9999 + seq_len = 3 + n = 3 + + out = get_dataloader( + tok, + seq_len, + dataset, + num_samples=n, + strategy="linspace", + eos_id=eos, + ) + + # linspace starts -> [0, 4, 8]; the middle window [4:7] + # includes EOS at 5 + windows = out[:, 0, :] + self.assertTrue( + np.any(np.any(windows == eos, axis=1)), + "Expected EOS to appear in at least one sampled window with " + "linspace.", + ) + + def test_get_dataloader_error_scenarios(self): + """Tests error cases for get_dataloader.""" + with pytest.raises(ValueError, match="Provided dataset is empty"): + get_dataloader( + tokenizer=MockTokenizer(), + sequence_length=10, + dataset=[], + num_samples=10, + ) + with self.assertRaisesRegex( + TypeError, + "The `dataset` argument must be an iterable.*Got type: str.*" + "Please pass the loaded dataset directly.", + ): + get_dataloader( + tokenizer=MockTokenizer(), + sequence_length=10, + dataset="wikitext2", + num_samples=10, + ) + + def test_apply_gptq_on_multi_block_model(self): + """Tests quantization on a model with multiple blocks.""" + model = models.Sequential( + [ + layers.Embedding(VOCAB_SIZE, 128), + TransformerBlock(), + TransformerBlock(), + ] + ) + model.build(input_shape=(None, 10)) + config = GPTQConfig( + dataset=["test data"], tokenizer=MockTokenizer(), group_size=32 + ) + model.quantize("gptq", config=config) + + @parameterized.named_parameters( + ( + "no_embedding_layer", + models.Sequential([layers.Dense(10)]), + "Could not automatically find an embedding layer", + ), + ( + "no_transformer_blocks", + models.Sequential( + [layers.Embedding(VOCAB_SIZE, 10), layers.Dense(10)] + ), + "Could not automatically find any transformer-like blocks", + ), + ( + "backbone_no_layers", + _get_model_with_backbone(has_transformer_layers=False), + "Could not automatically find any transformer-like blocks", + ), + ( + "backbone_no_embedding", + _get_model_with_backbone(embedding_name="wrong_name"), + "Could not automatically find an embedding layer in the model", + ), + ) + def test_apply_gptq_with_unsupported_architectures( + self, model, error_message + ): + """Tests that quantize fails correctly for various unsupported + model architectures.""" + if not model.built: + model.build(input_shape=(None, 10)) + + config = GPTQConfig(dataset=["test"], tokenizer=MockTokenizer()) + with self.assertRaisesRegex(ValueError, error_message): + model.quantize("gptq", config=config) diff --git a/keras/src/quantizers/gptq_test.py b/keras/src/quantizers/gptq_test.py new file mode 100644 index 000000000000..e0f4dd8c9744 --- /dev/null +++ b/keras/src/quantizers/gptq_test.py @@ -0,0 +1,634 @@ +from collections.abc import Callable + +import numpy as np +import pytest +from absl.testing import parameterized + +import keras +from keras.src import backend +from keras.src import layers +from keras.src import models +from keras.src import ops +from keras.src import testing +from keras.src.quantizers.gptq import GPTQ +from keras.src.quantizers.gptq import _stable_permutation +from keras.src.quantizers.gptq import gptq_quantize_matrix +from keras.src.quantizers.gptq_config import GPTQConfig +from keras.src.quantizers.quantizers import dequantize_with_sz_map +from keras.src.quantizers.quantizers import dequantize_with_zero_point +from keras.src.quantizers.quantizers import quantize_with_zero_point +from keras.src.testing.test_utils import named_product + +VOCAB_SIZE = 1000 +SEQ_LEN = 128 +NUM_SAMPLES = 16 +W_BITS = 4 +NUM_CLASSES = 32 + +CALIBRATION_TEXT = """ +GPTQ (Generative Pre-trained Transformer Quantization) is an advanced +post-training quantization (PTQ) algorithm designed to compress large +language models with minimal accuracy degradation. It addresses the +challenge of reducing model size from high-precision formats like +FP16 to low-bit integers (e.g., INT4, INT3) without the need for +expensive retraining. The algorithm operates on a layer-by-layer basis, +treating the quantization of each weight matrix $W$ as a +reconstruction problem. Its objective is to find a quantized weight +matrix $\hat{W}$ that minimizes the mean squared error of the layer's +output, formulated as $\arg\min_{\hat{W}} \|WX - \hat{W}X\|_F^2$, +where $X$ is a set of calibration inputs. GPTQ's primary innovation +is its greedy, error-compensating quantization process, based on the +Optimal Brain Quantizer (OBQ) framework. It quantizes weights one by +one (or in small groups). After quantizing a single weight $w_q$ to +its discrete value $\hat{w}_q$, it introduces a quantization error of +$\delta = w_q - \hat{w}_q$. This error is then immediately compensated +for by updating all remaining, unquantized weights in the layer. +The update step is guided by second-order information, specifically +the inverse of the Hessian matrix ($\mathbf{H}^{-1}$) of the layer's +reconstruction loss. This inverse Hessian provides a measure of weight +saliency and inter-dependencies. The update applied to the remaining +weights is calculated based on $\delta$ and the corresponding entries +in $\mathbf{H}^{-1}$, effectively propagating the error to less +sensitive weights. This sequential compensation minimizes the +cumulative error across the entire layer, allowing GPTQ to maintain +high model fidelity, as measured by perplexity, even at aggressive +bit-rates. +""" + + +def _get_test_layer(layer_type, kernel_shape): + if layer_type == "Dense": + layer = layers.Dense(units=kernel_shape[1]) + layer.build(input_shape=(None, kernel_shape[0])) + elif layer_type == "EinsumDense": + output_shape = (kernel_shape[1], kernel_shape[2]) + layer = layers.EinsumDense( + equation="...h,hio->...io", output_shape=output_shape + ) + layer.build(input_shape=(None, kernel_shape[0])) + else: + layer = layers.Layer() + return layer + + +@pytest.mark.requires_trainable_backend +class GPTQTest(testing.TestCase): + def test_initialization_with_dense_layer(self): + mock_layer = _get_test_layer("Dense", kernel_shape=(64, 128)) + + gptq_instance = GPTQ(mock_layer) + self.assertEqual(gptq_instance.rows, 64) + self.assertEqual(gptq_instance.columns, 128) + self.assertEqual(gptq_instance.hessian.shape, (64, 64)) + + def test_initialization_with_einsumdense_3d(self): + mock_layer = _get_test_layer("EinsumDense", kernel_shape=(64, 4, 32)) + gptq_instance = GPTQ(mock_layer) + self.assertEqual(gptq_instance.rows, 64) + self.assertEqual(gptq_instance.columns, 4 * 32) + self.assertEqual(gptq_instance.hessian.shape, (64, 64)) + + def test_update_hessian(self): + dense = _get_test_layer("Dense", kernel_shape=(16, 32)) + dense_gptq = GPTQ(dense) + + rng = np.random.default_rng(seed=42) + batch1 = rng.standard_normal(size=(8, 16)).astype("float32") + + dense_gptq.update_hessian_with_batch(batch1) + self.assertEqual(dense_gptq.num_samples, 8) + H1 = dense_gptq.hessian + + batch2 = rng.standard_normal(size=(4, 16)).astype("float32") + + dense_gptq.update_hessian_with_batch(batch2) + self.assertEqual(dense_gptq.num_samples, 12) + + H2 = dense_gptq.hessian + + self.assertNotAllClose(H1, H2) + + def test_gptq_on_single_layer(self): + rng = np.random.default_rng(seed=42) + dense = _get_test_layer("Dense", kernel_shape=(16, 32)) + + config = GPTQConfig( + dataset=None, + tokenizer=None, + weight_bits=4, + symmetric=False, + group_size=-1, + ) + + dense.quantize("gptq", config=config) + dense_gptq = GPTQ( + dense, + config, + ) + + calibration_data = rng.standard_normal(size=(128, 16)).astype("float32") + + dense_gptq.update_hessian_with_batch(calibration_data) + dense_gptq.quantize_and_correct_layer() + + self.assertEqual(backend.standardize_dtype(dense.kernel.dtype), "uint8") + + dense_gptq.free() + self.assertIsNone(getattr(dense_gptq, "hessian", None)) + self.assertIsNone(getattr(dense_gptq, "layer", None)) + + def test_unsupported_layer_error(self): + unsupported_layer = _get_test_layer("Unsupported", kernel_shape=None) + with self.assertRaisesRegex(TypeError, "Unsupported layer type"): + GPTQ(unsupported_layer) + + def test_update_hessian_invalid_input(self): + rng = np.random.default_rng(seed=42) + dense = _get_test_layer("Dense", kernel_shape=(16, 32)) + gptq_instance = GPTQ(dense) + with self.assertRaisesRegex(ValueError, "cannot be None"): + gptq_instance.update_hessian_with_batch(None) + with self.assertRaisesRegex(ValueError, "cannot be empty"): + gptq_instance.update_hessian_with_batch(np.empty((0, 16))) + with self.assertRaisesRegex(ValueError, "match input features"): + bad_input = rng.standard_normal(size=(8, 99)) + gptq_instance.update_hessian_with_batch(bad_input) + + def test_streaming_equals_big_batch(self): + """Tests that streaming updates match big batch updates.""" + # dummy inputs + x = ops.array(np.random.randn(100, 7), "float32") + + # One-shot hessian update + layer_1 = layers.Dense(5, use_bias=False) + layer_1.build(input_shape=(None, 7)) + + g1 = GPTQ(layer_1) + g1.update_hessian_with_batch(x) + + # Streamed hessian update + layer_2 = layers.Dense(5, use_bias=False) + layer_2.build(input_shape=(None, 7)) + g2 = GPTQ(layer_2) + g2.update_hessian_with_batch(x[:50]) + g2.update_hessian_with_batch(x[50:]) + + # Both the one-shot and streamed hessian updates should match + self.assertAllClose(g1.hessian, g2.hessian, rtol=1e-6, atol=1e-6) + + def test_hessian_matches_closed_form(self): + """Tests that the Hessian matches the closed-form solution.""" + x = ops.array(np.random.randn(128, 7), "float32") + layer = layers.Dense(5, use_bias=False) + layer.build((None, 7)) + g = GPTQ(layer) + g.update_hessian_with_batch(x) + + expected = ops.multiply( + ops.divide(2.0, x.shape[0]), ops.matmul(ops.transpose(x), x) + ) + self.assertAllClose(g.hessian, expected, rtol=1e-6, atol=1e-6) + + def test_higher_rank_inputs_are_reshaped(self): + """Tests that higher-rank inputs are reshaped correctly.""" + # x: [batch, time, feat] + x = ops.array(np.random.randn(10, 4, 7), "float32") + x_flat = ops.reshape(x, (-1, ops.shape(x)[-1])) + + layer1 = layers.Dense(5, use_bias=False) + layer1.build((None, 7)) + g1 = GPTQ(layer1) + g1.update_hessian_with_batch(x) + + layer2 = layers.Dense(5, use_bias=False) + layer2.build((None, 7)) + g2 = GPTQ(layer2) + g2.update_hessian_with_batch(x_flat) + + self.assertAllClose(g1.hessian, g2.hessian, rtol=1e-6, atol=1e-6) + + def test_raises_on_feature_mismatch(self): + x = ops.array(np.random.randn(8, 7), "float32") + layer = layers.Dense(5, use_bias=False) + layer.build((None, 6)) # wrong in_features + g = GPTQ(layer) + + with self.assertRaisesRegex(ValueError, "do not match input features"): + g.update_hessian_with_batch(x) + + with self.assertRaisesRegex(ValueError, "cannot be None"): + g.update_hessian_with_batch(None) + with self.assertRaisesRegex(ValueError, "cannot be empty"): + g.update_hessian_with_batch( + ops.array(np.empty((0, 7), dtype="float32")) + ) + + def test_num_samples_accumulates_correctly(self): + """Tests that the number of samples is accumulated correctly when + streaming updates are used.""" + x = ops.array(np.random.randn(64, 7), "float32") + layer = layers.Dense(5, use_bias=False) + layer.build((None, 7)) + g = GPTQ(layer) + + g.update_hessian_with_batch(x[:5]) + g.update_hessian_with_batch(x[5:30]) + g.update_hessian_with_batch(x[30:]) + + self.assertEqual(g.num_samples, 64) + + def test_numeric_stability_large_values(self): + """Tests numeric stability of hessian update with large input values.""" + x = ops.multiply(ops.array(np.random.randn(32, 7), "float32"), 1e6) + layer = layers.Dense(5, use_bias=False) + layer.build((None, 7)) + + g = GPTQ(layer) + g.update_hessian_with_batch(x) + + # Should be finite and symmetric + self.assertTrue(ops.all(ops.isfinite(g.hessian))) + self.assertTrue(ops.all(ops.equal(g.hessian, ops.transpose(g.hessian)))) + + def test_einsumdense_2d_kernel_hessian_shape(self): + x = layers.Input((7,)) + y = layers.EinsumDense("ab,bc->ac", output_shape=(5,))(x) + model = keras.Model(x, y) + einsum_dense_layer = next( + l for l in model.layers if isinstance(l, layers.EinsumDense) + ) + + g = GPTQ(einsum_dense_layer) + + # should infer rows==7 + self.assertEqual(ops.shape(g.hessian), (7, 7)) + + def test_einsumdense_3d_kernel_streaming_equals_big_batch(self): + """Tests that streaming updates to the Hessian are equivalent to a big + batch update.""" + # Construct a tiny attention-like einsum with 3D kernel + x = layers.Input((7,)) + qkv = layers.EinsumDense("bf,fhk->bhk", output_shape=(2, 3))( + x + ) # heads=2, head_dim=3 + model = keras.Model(x, qkv) + einsum_dense_layer = next( + l for l in model.layers if isinstance(l, layers.EinsumDense) + ) + + x = ops.array(np.random.randn(50, 7), "float32") + + g1 = GPTQ(einsum_dense_layer) + g1.update_hessian_with_batch(x) + + g2 = GPTQ(einsum_dense_layer) + g2.update_hessian_with_batch(x[:20]) + g2.update_hessian_with_batch(x[20:]) + + self.assertAllClose(g1.hessian, g2.hessian, rtol=1e-6, atol=1e-6) + + def test_identity_inv_hessian_matches_direct_quantization(self): + """Tests that the matrix quantization without error correction + matches the direct implementation.""" + in_features, out_features = 16, 8 + weights = ops.reshape( + ops.linspace( + -0.9, 1.1, in_features * out_features, dtype="float32" + ), + (in_features, out_features), + ) + weights_transpose = ops.transpose(weights) + + # inverse_hessian = identity; no cross-feature correction + # (since all off-diagonal elements are zero), which means + # there is no interaction between different features + inverse_hessian = ops.eye(in_features, dtype="float32") + + quantized_weights, scale_map, zero_map, g_idx = gptq_quantize_matrix( + weights_transpose, + inverse_hessian, + blocksize=128, + group_size=1, # per-column quantization + activation_order=False, + compute_scale_zero=_compute_scale_zero, + ) + + dequantized_weights = dequantize_with_sz_map( + quantized_weights, scale_map, zero_map, g_idx + ) + + # Compare function output with columnwise direct application + # of quantization. + out = ops.zeros_like(weights_transpose) + for j in range(ops.shape(weights_transpose)[1]): + column = weights_transpose[:, j : j + 1] + scale, zero, maxq = _compute_scale_zero(column) + quantized_col = quantize_with_zero_point(column, scale, zero, maxq) + dequantized = dequantize_with_zero_point(quantized_col, scale, zero) + out = ops.slice_update( + out, (0, j), ops.expand_dims(dequantized[:, 0], 1) + ) + + self.assertAllClose(dequantized_weights, out, atol=1e-6) + + def test_activation_order_produces_equivalent_weights(self): + """ + Tests that quantizing with `activation_order=True` yields the same + final weights as `activation_order=False`, because the internal + permutation should be undone. + """ + # Set up shared inputs and a non-trivial permutation. + in_features, out_features = 8, 6 + initial_weights = ops.array( + np.random.randn(in_features, out_features), "float32" + ) + + # Generate a Hessian that creates a non-trivial permutation. + hessian_diag = ops.random.shuffle( + ops.linspace(10.0, 1.0, in_features, dtype="float32") + ) + hessian_matrix = ops.diag(hessian_diag) + + # Sanity check: ensure the permutation is not the identity. + perm = _stable_permutation(hessian_diag) + self.assertFalse(ops.all(ops.equal(perm, ops.arange(in_features)))) + + def create_and_quantize(use_activation_order): + layer = layers.Dense(out_features, use_bias=False) + layer.build((None, in_features)) + layer.set_weights([ops.copy(initial_weights)]) + + config = GPTQConfig( + dataset=None, + tokenizer=None, + group_size=-1, + activation_order=use_activation_order, + ) + layer.quantize("gptq", config=config) + + quantizer = GPTQ(layer, config) + quantizer.hessian = hessian_matrix + quantizer.quantize_and_correct_layer() + return layer + + # Quantize two layers, one with and one without activation ordering. + ordered_layer = create_and_quantize(use_activation_order=True) + unordered_layer = create_and_quantize(use_activation_order=False) + + self.assertAllClose( + ordered_layer.get_weights()[0], + unordered_layer.get_weights()[0], + msg="Weights should be identical as the permutation is undone.", + ) + + +def _compute_scale_zero(x, **_): + # Per-column asymmetric int4 example + # scale = (max-min)/maxq, zero = round(-min/scale) + maxq = 15.0 + xmin = ops.min(x, axis=0, keepdims=True) + xmax = ops.max(x, axis=0, keepdims=True) + scale = ops.divide(ops.subtract(xmax, xmin), ops.add(maxq, 1e-8)) + zero = ops.round(ops.divide(ops.negative(xmin), ops.add(scale, 1e-8))) + return scale, zero, maxq + + +def _get_sequence_classifier(): + """Transformer-based sequence classifier + + tokens -> Embedding -> Transformer -> GAP -> Dense(num_classes). + """ + embed_dim = 32 + num_heads = 4 + ff_dim = 32 + + class SimpleTransformerBlock(layers.Layer): + def __init__(self, embed_dim, num_heads, ff_dim, **kwargs): + super().__init__(**kwargs) + + self.att = layers.MultiHeadAttention( + num_heads=num_heads, key_dim=embed_dim // num_heads + ) + self.ffn = models.Sequential( + [ + layers.Dense(ff_dim, activation="relu"), + layers.Dense(embed_dim), + ] + ) + self.layernorm1 = layers.LayerNormalization(epsilon=1e-6) + self.layernorm2 = layers.LayerNormalization(epsilon=1e-6) + + def call(self, inputs): + attention_output = self.att(inputs, inputs) + out1 = self.layernorm1(inputs + attention_output) + ffn_output = self.ffn(out1) + return self.layernorm2(out1 + ffn_output) + + inputs = layers.Input(shape=(SEQ_LEN,), dtype="int32") + x = layers.Embedding(VOCAB_SIZE, embed_dim)(inputs) + x = SimpleTransformerBlock(embed_dim, num_heads, ff_dim)(x) + x = layers.GlobalAveragePooling1D()(x) + outputs = layers.Dense(NUM_CLASSES)(x) + return models.Model(inputs, outputs) + + +def _get_simple_model(): + return models.Sequential([layers.Dense(10, input_shape=(5,))]) + + +def _mean_kl(p, q): + # Add small epsilon for numerical stability + eps = 1e-8 + p = ops.clip(p, eps, 1.0) + q = ops.clip(q, eps, 1.0) + # Compute KL divergence + # D_KL(P || Q) = sum(P * log(P / Q)) + return ops.mean( + ops.sum(ops.multiply(p, ops.subtract(ops.log(p), ops.log(q))), axis=-1) + ) + + +def _top1_match_rate(a_logits, b_logits): + """Calculates the top-1 match rate between two sets of logits. + + Formula: T = 1/N * sum(1{argmax(a_i) == argmax(b_i)}) + """ + return ops.mean( + ops.equal(ops.argmax(a_logits, axis=-1), ops.argmax(b_logits, axis=-1)) + ) + + +DATASETS = { + "string_dataset": lambda: _string_dataset( + CALIBRATION_TEXT, NUM_SAMPLES, SEQ_LEN + ), + "token_dataset": lambda: _token_dataset(NUM_SAMPLES, SEQ_LEN), +} + +CONFIGS = { + "default": {}, + "per_channel": {"group_size": -1, "per_channel": True}, + "act_order": {"activation_order": True}, + "symmetric": {"symmetric": True}, + "group_wise": {"group_size": 8}, + "group_wise_act_order": {"group_size": 8, "activation_order": True}, + "symmetric_act_order": {"symmetric": True, "activation_order": True}, + "symmetric_per_channel": {"symmetric": True, "per_channel": True}, + "group_wise_symmetric_8bit": { + "group_size": 8, + "symmetric": True, + "weight_bits": 8, + }, +} + + +def _pad_or_trim_1d(ids, length): + """Pads or trims a 1D array to a specified length.""" + ids = ops.ravel(ops.array(ids, "int64")) + if len(ids) < length: + ids = ops.concatenate( + [ids, ops.zeros(length - len(ids), dtype=ids.dtype)] + ) + else: + ids = ids[:length] + return ids + + +def _char_tokenizer(vocab_size=VOCAB_SIZE, seq_len=SEQ_LEN): + """Tokenizes strings to char-IDs or passes through int arrays; + outputs shape (1, seq_len).""" + + def _tok(x): + if isinstance(x, str): + ids = ops.convert_to_tensor( + np.fromiter((ord(c) % vocab_size for c in x), dtype=np.int64) + ) + else: + ids = np.asarray(x, dtype=np.int64) + ids = _pad_or_trim_1d(ids, seq_len) + return ids[None, :] + + _tok.tokenize = _tok + return _tok + + +def _string_dataset( + long_text, num_samples=NUM_SAMPLES, sequence_length=SEQ_LEN +): + """Yields string slices""" + rng = np.random.default_rng(seed=0) + L = max(1, len(long_text) - sequence_length) + for _ in range(num_samples): + start = rng.integers(0, L) if L > 1 else 0 + yield long_text[start : start + sequence_length] + + +def _token_dataset( + num_samples=NUM_SAMPLES, sequence_length=SEQ_LEN, vocab_size=VOCAB_SIZE +): + """Yields tokenized samples.""" + rng = np.random.default_rng(seed=0) + for _ in range(num_samples): + yield rng.integers( + low=0, high=vocab_size, size=(1, sequence_length), dtype=np.int64 + ) + + +@pytest.mark.requires_trainable_backend +@pytest.mark.skipif( + backend.backend() == "torch", + reason="torch gives low accuracy on CI, but works well locally", +) +class TestModelQuantization(testing.TestCase): + @parameterized.named_parameters( + named_product( + [ + {"testcase_name": dataset_id, "dataset": dataset} + for dataset_id, dataset in DATASETS.items() + ], + [ + {"testcase_name": config_id, "config": config} + for config_id, config in CONFIGS.items() + ], + ) + ) + def test_quantize_gptq_combinations(self, dataset, config): + """Tests GPTQ quantization on a tiny transformer classifier. + + Validates classification performance of the quantized model + with respect to the full-precision baseline. + """ + rng = np.random.default_rng(seed=321) + keras.utils.set_random_seed(123) + + # Build the calibration set. + calibration_set = list( + dataset() if isinstance(dataset, Callable) else dataset + ) + self.assertNotEmpty(calibration_set) + + # Build classifier and tokenizer + model = _get_sequence_classifier() + tokenizer = _char_tokenizer(vocab_size=VOCAB_SIZE, seq_len=SEQ_LEN) + + # Build an eval batch drawn from the SAME distribution as calibration + batch_size = min(8, len(calibration_set)) + eval_samples = [ + calibration_set[rng.integers(0, len(calibration_set))] + for _ in range(batch_size) + ] + x_eval = ops.concatenate([tokenizer(s) for s in eval_samples], axis=0) + + # Baseline logits + y_ref = model.predict(x_eval) + + base_cfg = dict( + dataset=calibration_set, + tokenizer=tokenizer, + weight_bits=W_BITS, + num_samples=NUM_SAMPLES, + sequence_length=SEQ_LEN, + group_size=32, + symmetric=False, + activation_order=False, + ) + gptq_cfg = GPTQConfig(**{**base_cfg, **config}) + + # Quantize + model.quantize("gptq", config=gptq_cfg) + + # Post-quant logits + y_q = model.predict(x_eval) + + top1_match = _top1_match_rate(y_ref, y_q) + + p_ref, p_q = ops.softmax(y_ref), ops.softmax(y_q) + kl = _mean_kl(p_ref, p_q) + + self.assertGreaterEqual( + top1_match, 0.5, f"Top-1 agreement too low: {top1_match:.3f}" + ) + self.assertLessEqual(kl, 0.30, f"KL divergence too high: {kl:.3f}") + + @parameterized.named_parameters( + { + "testcase_name": "gptq_with_invalid_config", + "mode": "gptq", + "config": {"weight_bits": 4}, + "expected_exception": ValueError, + "error_msg": "Mode 'gptq' requires a valid `config`", + }, + { + "testcase_name": "non_gptq_with_unsupported_config", + "mode": "int8", + "config": GPTQConfig(dataset=["a"], tokenizer=lambda x: x), + "expected_exception": ValueError, + "error_msg": "only supported for 'gptq'", + }, + ) + def test_quantize_scenarios( + self, mode, config, expected_exception, error_msg + ): + model = _get_simple_model() + with self.assertRaisesRegex(expected_exception, error_msg): + model.quantize(mode, config=config) diff --git a/keras/src/quantizers/quantizers.py b/keras/src/quantizers/quantizers.py new file mode 100644 index 000000000000..d9ef671b6fc9 --- /dev/null +++ b/keras/src/quantizers/quantizers.py @@ -0,0 +1,939 @@ +import ml_dtypes +import numpy as np + +from keras.src import backend +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.backend import KerasTensor +from keras.src.backend import any_symbolic_tensors +from keras.src.backend.common.backend_utils import canonicalize_axis +from keras.src.backend.common.backend_utils import standardize_axis_for_numpy +from keras.src.ops.operation import Operation +from keras.src.quantizers.gptq_config import GPTQConfig + +"""Int8-related classes and methods""" + + +@keras_export(["keras.Quantizer", "keras.quantizers.Quantizer"]) +class Quantizer: + def __init__(self, output_dtype="int8"): + self.output_dtype = output_dtype + + def __call__(self, x): + """Compute a quantized output from an input tensor.""" + return x + + @classmethod + def from_config(cls, config): + """Creates a quantizer from its config. + + This method is the reverse of `get_config`, + capable of instantiating the same quantizer from the config + dictionary. + + This method is used by Keras `model_to_estimator`, saving and + loading models to HDF5 formats, Keras model cloning, some visualization + utilities, and exporting models to and from JSON. + + Args: + config: A Python dictionary, typically the output of get_config. + + Returns: + A quantizer instance. + """ + return cls(**config) + + def get_config(self): + """Returns the config of the quantizer. + + A quantizer config is a Python dictionary (serializable) + containing all configuration parameters of the quantizer. + The same quantizer can be reinstantiated later + (without any saved state) from this configuration. + + This method is optional if you are just training and executing models, + exporting to and from SavedModels, or using weight checkpoints. + + This method is required for Keras `model_to_estimator`, saving and + loading models to HDF5 formats, Keras model cloning, some visualization + utilities, and exporting models to and from JSON. + + Returns: + Python dictionary. + """ + raise NotImplementedError(f"{self} does not implement get_config()") + + +@keras_export("keras.quantizers.abs_max_quantize") +def abs_max_quantize( + inputs, + axis, + value_range=(-127, 127), + dtype="int8", + epsilon=backend.epsilon(), + to_numpy=False, +): + if to_numpy: + # Save memory on the device using numpy + original_dtype = backend.standardize_dtype(inputs.dtype) + inputs = ops.convert_to_numpy(inputs) + axis = standardize_axis_for_numpy(axis) + scale = np.divide( + value_range[1], + np.add(np.max(np.abs(inputs), axis=axis, keepdims=True), epsilon), + ) + outputs = np.multiply(inputs, scale) + outputs = np.clip(np.round(outputs), value_range[0], value_range[1]) + outputs = outputs.astype(dtype) + return ops.convert_to_tensor(outputs), ops.convert_to_tensor( + scale, dtype=original_dtype + ) + + inputs = ops.convert_to_tensor(inputs) + scale = ops.divide( + value_range[1], + ops.add(ops.max(ops.abs(inputs), axis=axis, keepdims=True), epsilon), + ) + scale = ops.cast(scale, backend.standardize_dtype(inputs.dtype)) + outputs = ops.multiply(inputs, scale) + outputs = ops.clip(ops.round(outputs), value_range[0], value_range[1]) + outputs = ops.cast(outputs, dtype) + return outputs, scale + + +@keras_export("keras.quantizers.AbsMaxQuantizer") +class AbsMaxQuantizer(Quantizer): + def __init__( + self, + axis, + value_range=(-127, 127), + epsilon=backend.epsilon(), + output_dtype="int8", + ): + Quantizer.__init__(self, output_dtype=output_dtype) + if isinstance(axis, int): + axis = (axis,) + self.axis = tuple(axis) + self.value_range = value_range + self.epsilon = epsilon + + def __call__(self, x): + quantized_x, scale = abs_max_quantize( + x, self.axis, self.value_range, self.output_dtype, self.epsilon + ) + return quantized_x, scale + + def get_config(self): + return { + "axis": self.axis, + "value_range": self.value_range, + "epsilon": self.epsilon, + "output_dtype": self.output_dtype, + } + + +def adjust_and_nudge(min_range, max_range, num_bits, narrow_range): + """Adjusts and nudges the quantization range for better accuracy.""" + # Use higher precision for the computation. + compute_dtype = backend.result_type(min_range.dtype, "float32") + min_range = ops.cast(min_range, compute_dtype) + max_range = ops.cast(max_range, compute_dtype) + + quant_max = (1 << num_bits) - 1 + quant_min = 0 if not narrow_range else 1 + diff_range = ops.subtract(max_range, min_range) + + # Calculate the scale and ensure it's positive + scale = ops.divide(diff_range, quant_max - quant_min) + + # Re-calculate the inverse to avoid loss of precision + inv_scale = ops.divide(quant_max - quant_min, diff_range) + + # Calculate the zero point from the min range + zero_point_from_min = quant_min - ops.divide(min_range, scale) + + # Ensure zero point is within valid range [0, quant_max] + zero_point = ops.clip(zero_point_from_min, quant_min, quant_max) + + # Nudge zero point if it's very close to an integer + nudged_zero_point = ops.round(zero_point) + + # Calculate nudged limits + nudged_min = ops.multiply(ops.subtract(quant_min, nudged_zero_point), scale) + nudged_max = ops.multiply(ops.subtract(quant_max, nudged_zero_point), scale) + + return nudged_min, nudged_max, scale, inv_scale + + +class FakeQuantWithMinMaxVars(Operation): + def __init__(self, num_bits=8, narrow_range=False, axis=None): + super().__init__() + self.num_bits = num_bits + self.narrow_range = narrow_range + self.axis = axis + + def call(self, inputs, min_vals, max_vals): + return fake_quant_with_min_max_vars( + inputs, + min_vals, + max_vals, + num_bits=self.num_bits, + narrow_range=self.narrow_range, + axis=self.axis, + ) + + def compute_output_spec(self, inputs, min_vals, max_vals): + return KerasTensor(inputs.shape, dtype=inputs.dtype) + + +@keras_export("keras.quantizers.fake_quant_with_min_max_vars") +def fake_quant_with_min_max_vars( + inputs, + min_vals, + max_vals, + num_bits=8, + narrow_range=False, + axis=None, +): + """Perform per-tensor or per-channel fake quantization. + + `[min_vals, max_vals]` define the clamping range for the `inputs`. + + The `inputs` are quantized into the quantization range: + - `[0, 2^num_bits - 1]` when `narrow_range=False` + - `[1, 2^num_bits - 1]` when `narrow_range=True` + + After quantization, the values are dequantized and output as floats within + the `[min_vals, max_vals]` interval. + + This operation supports gradient computation, allowing `min_vals` and + `max_vals` to be trained. + + Args: + inputs: Input Keras tensor of float dtype. + min_vals: A global minimum scalar or a per-channel minimum tensor. + max_vals: A global maximum scalar or a per-channel maximum tensor. + num_bits: Quantization bit width (e.g., `8` for int8). Defaults to `8`. + narrow_range: Whether to use narrow quantization range. Defaults to + `False`. + axis: Axis along which to perform per-channel quantization. If `None`, + per-tensor quantization is performed. Defaults to `None`. + + + Returns: + Tensor: A Keras tensor with fake quantization applied. + """ + if any_symbolic_tensors((inputs,)): + return FakeQuantWithMinMaxVars().symbolic_call( + inputs, min_vals, max_vals + ) + + inputs = ops.convert_to_tensor(inputs) + min_vals = ops.convert_to_tensor(min_vals) + max_vals = ops.convert_to_tensor(max_vals) + num_bits = int(num_bits) + + if axis is not None: + axis = canonicalize_axis(axis, inputs.ndim) + + # Shortcut for TensorFlow backend by using `tf.quantization.fake_quant_*` + # apis. This is necessary to be recognizable for the TFLite converter. + if backend.backend() == "tensorflow": + import tensorflow as tf + + # `tf.quantization.fake_quant_*` only supports float32. + dtype = backend.standardize_dtype(inputs.dtype) + if axis is None: + outputs = tf.quantization.fake_quant_with_min_max_vars( + ops.cast(inputs, "float32"), + ops.cast(ops.reshape(min_vals, ()), "float32"), + ops.cast(ops.reshape(max_vals, ()), "float32"), + num_bits=num_bits, + narrow_range=narrow_range, + ) + return ops.cast(outputs, dtype=dtype) + else: + # `tf.quantization.fake_quant_with_min_max_vars_per_channel` only + # supports the last channel for the per-channel quantization. We + # use `ops.swapaxes` for the pre- and post-processing. + last_axis = inputs.ndim - 1 + inputs = ops.swapaxes(inputs, axis, last_axis) + outputs = tf.quantization.fake_quant_with_min_max_vars_per_channel( + ops.cast(inputs, "float32"), + ops.cast(min_vals, "float32"), + ops.cast(max_vals, "float32"), + num_bits=num_bits, + narrow_range=narrow_range, + ) + outputs = ops.cast(outputs, dtype=dtype) + return ops.swapaxes(outputs, last_axis, axis) + + @ops.custom_gradient + def _fake_quant_with_min_max_vars_per_channel(x, min_val, max_val): + dtype = backend.standardize_dtype(x.dtype) + + # Calculate quantization parameters for all channels at once + nudged_min, nudged_max, scale, inv_scale = adjust_and_nudge( + min_val, max_val, num_bits, narrow_range + ) + + quant_zero = ops.floor( + ops.add(ops.multiply(-nudged_min, inv_scale), 0.5) + ) + x_clamped = ops.clip( + x, ops.cast(nudged_min, x.dtype), ops.cast(nudged_max, x.dtype) + ) + x_clamped_shifted = ops.subtract(x_clamped, nudged_min) + result = ops.multiply( + ops.floor( + ops.add( + ops.subtract( + ops.multiply(x_clamped_shifted, inv_scale), quant_zero + ), + 0.5, + ) + ), + scale, + ) + result = ops.cast(result, dtype=dtype) + + # Create gradient mask for all channels + masks = ops.logical_and( + ops.greater_equal(x, nudged_min), ops.less_equal(x, nudged_max) + ) + + def grad(*args, upstream=None): + if upstream is None: + (upstream,) = args + + # Gradient for x + dx = ops.where(masks, upstream, 0.0) + axes = [i for i in range(len(dx.shape)) if i != axis] + + # Gradient for min_val + # When x is clipped to min, the gradient flows to min_val + min_mask = ops.less_equal(x, nudged_min) + grad_min = ops.where(min_mask, upstream, 0.0) + if axis is not None: + grad_min = ops.sum(grad_min, axis=axes) + else: + grad_min = ops.sum(grad_min) + + # Gradient for max_val + # When x is clipped to max, the gradient flows to max_val + max_mask = ops.greater_equal(x, nudged_max) + grad_max = ops.where(max_mask, upstream, 0.0) + if axis is not None: + grad_max = ops.sum(grad_max, axis=axes) + else: + grad_max = ops.sum(grad_max) + + return dx, grad_min, grad_max + + return result, grad + + return _fake_quant_with_min_max_vars_per_channel(inputs, min_vals, max_vals) + + +"""Float8-related methods""" + + +@keras_export("keras.quantizers.compute_float8_scale") +def compute_float8_scale(amax, scale, dtype_max, margin=0): + # The algorithm for computing the new scale is sourced from + # https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/jax.html#transformer_engine.jax.update_fp8_metas + # wherein the `original_scale` corresponds to the reciprocal of the + # `scale` passed in this function. + scale = ops.reciprocal(scale) + sf = ops.divide(ops.divide(dtype_max, amax), 2**margin) + sf = ops.where(amax > 0.0, sf, scale) + sf = ops.where(ops.isfinite(amax), sf, scale) + return ops.reciprocal(sf) + + +@keras_export("keras.quantizers.compute_float8_amax_history") +def compute_float8_amax_history(x, amax_history): + amax_update = ops.cast(ops.max(ops.abs(x)), amax_history.dtype) + new_amax_history = ops.scatter_update( + ops.roll(amax_history, shift=-1), + [[0]], + ops.reshape(amax_update, [1]), + ) + return new_amax_history + + +@keras_export("keras.quantizers.quantize_and_dequantize") +def quantize_and_dequantize(inputs, scale, quantized_dtype, compute_dtype): + # Quantize + quantized_dtype_max = ops.cast( + float(ml_dtypes.finfo(quantized_dtype).max), compute_dtype + ) + x = ops.divide(inputs, ops.cast(scale, compute_dtype)) + x = ops.clip(x, -quantized_dtype_max, quantized_dtype_max) + x = ops.cast(x, quantized_dtype) + + # Dequantize + x = ops.multiply(ops.cast(x, compute_dtype), ops.cast(scale, compute_dtype)) + return x + + +@keras_export("keras.quantizers.pack_int4") +def pack_int4(arr, axis=0, dtype="int8"): + """Pack an int4 tensor into an int8 tensor with packed nibbles. + + The input values must already be int8 in the signed range `[-8, 7]` and + represent the desired int4 values. Packing is performed along the specified + axis (default is 0). + + For every two consecutive rows, the **low nibble** of the output byte + stores the value from the first row, and the **high nibble** stores + the value from the second row. + + Args: + arr: An `int8` or `uint8` tensor containing int4 values in the range + `[-8, 7]`. + axis: The axis along which to pack the tensor. Defaults to 0. + dtype: The data type of the input and packed tensor. Can be + `"int8"` or `"uint8"`. Defaults to `"int8"`. + + Returns: + tuple: A tuple `(packed, packed_shape, orig_rows)` where `packed` is + the packed int8 tensor with int4 values stored in nibbles, + `packed_shape` is the shape of the packed tensor, and `orig_rows` + is the original (unpacked) row count prior to any padding that may + have been inserted when an odd number of rows is supplied. + + Example: + + ```python + >>> import numpy as np + >>> from keras.quantizers import pack_int4, unpack_int4 + + # Example with axis=0 + # Original array has shape (3, 2) + >>> original_array = np.array([[-3, 7], [2, -8], [1, 0]], dtype=np.int8) + + # Pack the array along axis 0. Since the length of axis 0 (3) is + # odd, it will be padded to a length of 4. The packed array will + # have a shape of (ceil(3/2), 2) = (2, 2). + >>> packed, packed_shape, orig_len = pack_int4(original_array, axis=0) + >>> print("Packed array:\n", packed) + Packed array: + [[ 45 -121] + [ 1 0]] + + # Now, unpack the array back to its original form + >>> unpacked = unpack_int4(packed, orig_len, axis=0) + >>> print("Unpacked array:\n", unpacked) + Unpacked array: + [[-3 7] + [ 2 -8] + [ 1 0]] + >>> np.allclose(original_array, unpacked) + True + + # Example with axis=1 + # Original array has shape (2, 3) + >>> original_array = np.array([[-3, 7, 2], [-8, 1, 0]], dtype=np.int8) + + # Pack along axis 1. Length of axis 1 (3) is padded to 4. + # The new shape is (2, ceil(3/2)) = (2, 2). + >>> packed, packed_shape, orig_len = pack_int4(original_array, axis=1) + >>> print("Packed array:\n", packed) + Packed array: + [[ 125 2] + [ 24 0]] + + # Unpack the array + >>> unpacked = unpack_int4(packed, orig_len, axis=1) + >>> print("Unpacked array:\n", unpacked) + Unpacked array: + [[-3 7 2] + [-8 1 0]] + >>> np.allclose(original_array, unpacked) + True + ``` + """ + if dtype not in ("int8", "uint8"): + raise ValueError( + f"Expected dtype to be 'int8' or 'uint8', but got '{dtype}'." + ) + if backend.standardize_dtype(arr.dtype) != dtype: + raise TypeError( + f"Expected {dtype} tensor for packing, got " + f"{backend.standardize_dtype(arr.dtype)}." + ) + + rank = getattr(arr.shape, "rank", None) or len(arr.shape) + + if axis < 0: + axis += rank + + # 1. Bring `axis` to the front. + perm = [axis] + [i for i in range(rank) if i != axis] + inv_perm = [perm.index(i) for i in range(rank)] + transposed = ops.transpose(arr, perm) + + # 2. Pad to even length. + rows = ops.shape(transposed)[0] + needs_pad = ops.equal(ops.mod(rows, 2), 1) + + # Always append one zero row so the tensor shape is static for JAX. If no + # padding is actually needed, we'll slice it away later. + zero_row = transposed[:1, ...] * 0 # same dtype/shape (1, ...) + padded_full = ops.concatenate([transposed, zero_row], axis=0) + + # Number of valid rows after (possible) padding: + # rows + (1 if needs_pad else 0) + rows_packed = rows + ops.cast(needs_pad, "int32") + + # Slice to keep only the valid rows. This keeps the shape rank static while + # allowing the row count to be dynamic. + padded = padded_full[:rows_packed, ...] + + # 3-4. Group in pairs and pack. + low = padded[::2, ...] + high = padded[1::2, ...] + + mask = ops.array(0x0F, dtype=dtype) + low_u = ops.bitwise_and(low, mask) + high_u = ops.bitwise_and(high, mask) + + packed = ops.bitwise_or(low_u, ops.left_shift(high_u, 4)) + packed = ops.cast(packed, dtype) + + # 5-6. Restore shape. + packed = ops.transpose(packed, inv_perm) # back to original order + orig_len = rows # number of slices before padding + return packed, ops.shape(packed), orig_len + + +@keras_export("keras.quantizers.unpack_int4") +def unpack_int4(packed, orig_len, axis=0, dtype="int8"): + """Unpack a packed int4 back to an int8 tensor in the range [-8, 7]. + + This function reverses the packing performed by `pack_int4`, restoring + the original int8 tensor (values in the range [-8, 7]) from a packed int8 + tensor where each element contains two int4 values (one in the lower nibble, + one in the upper nibble). + + The function restores the original axis order and removes any + padding that was added during packing. + + Args: + packed: An int8 tensor containing packed int4 values along the + specified axis. Each int8 value encodes two int4 values. + orig_len: The original (unpadded) length of the axis that was + packed. This is used to remove any padding that may have + been added during packing to ensure an even number of rows. + axis: The axis along which the tensor was packed. Defaults to 0. + dtype: The data type of the input and unpacked tensor. Can be + `"int8"` or `"uint8"`. Defaults to `"int8"`. + + Returns: + unpacked: An int8 tensor with the same shape as the original + (unpacked) tensor, with values in the range [-8, 7]. + + Example: + + ```python + >>> import numpy as np + >>> from keras.quantizers import pack_int4, unpack_int4 + + # Example with axis=0 + # Original array has shape (3, 2) + >>> original_array = np.array([[-3, 7], [2, -8], [1, 0]], dtype=np.int8) + + # Pack the array along axis 0. Since the length of axis 0 (3) is + # odd, it will be padded to a length of 4. The packed array will + # have a shape of (ceil(3/2), 2) = (2, 2). + >>> packed, packed_shape, orig_len = pack_int4(original_array, axis=0) + >>> print("Packed array:\n", packed) + Packed array: + [[ 45 -121] + [ 1 0]] + + # Now, unpack the array back to its original form + >>> unpacked = unpack_int4(packed, orig_len, axis=0) + >>> print("Unpacked array:\n", unpacked) + Unpacked array: + [[-3 7] + [ 2 -8] + [ 1 0]] + >>> np.allclose(original_array, unpacked) + True + + # Example with axis=1 + # Original array has shape (2, 3) + >>> original_array = np.array([[-3, 7, 2], [-8, 1, 0]], dtype=np.int8) + + # Pack along axis 1. Length of axis 1 (3) is padded to 4. + # The new shape is (2, ceil(3/2)) = (2, 2). + >>> packed, packed_shape, orig_len = pack_int4(original_array, axis=1) + >>> print("Packed array:\n", packed) + Packed array: + [[ 125 2] + [ 24 0]] + + # Unpack the array + >>> unpacked = unpack_int4(packed, orig_len, axis=1) + >>> print("Unpacked array:\n", unpacked) + Unpacked array: + [[-3 7 2] + [-8 1 0]] + >>> np.allclose(original_array, unpacked) + True + ``` + """ + if dtype not in ("int8", "uint8"): + raise ValueError( + f"Expected dtype to be 'int8' or 'uint8', but got '{dtype}'." + ) + + if backend.standardize_dtype(packed.dtype) not in ("int8", "uint8"): + raise TypeError( + f"Expected int8 or uint8 tensor for unpacking, got {packed.dtype}" + ) + + def to_signed(x): + """Converts unpacked nibbles [0, 15] to signed int4 [-8, 7].""" + dtype_x = backend.standardize_dtype(x.dtype) + eight = ops.cast(8, dtype_x) + sixteen = ops.cast(16, dtype_x) + return ops.where(x < eight, x, x - sixteen) + + rank = getattr(packed.shape, "rank", None) or len(packed.shape) + if axis < 0: + axis += rank + + # Fast path for the most common case in Dense layers + if axis == 0 and rank == 2: + # The result of the bitwise op is a wider dtype (e.g., int32). + mask = ops.array(0x0F, dtype=packed.dtype) + low_unpacked = ops.bitwise_and(packed, mask) + high_unpacked = ops.bitwise_and(ops.right_shift(packed, 4), mask) + + if dtype == "int8": + low_unpacked = to_signed(low_unpacked) + high_unpacked = to_signed(high_unpacked) + + low_final = ops.cast(low_unpacked, dtype) + high_final = ops.cast(high_unpacked, dtype) + + # Interleave and reshape + stacked = ops.stack([low_final, high_final], axis=1) + unpacked = ops.reshape(stacked, (-1,) + tuple(ops.shape(packed)[1:])) + + # Remove padding and return + return unpacked[:orig_len, ...] + + # General case + perm = [axis] + [i for i in range(rank) if i != axis] + inv_perm = [perm.index(i) for i in range(rank)] + transposed = ops.transpose(packed, perm) + + # 1. Split nibbles. + mask = ops.array(0x0F, dtype=packed.dtype) + low = ops.bitwise_and(transposed, mask) + high = ops.bitwise_and(ops.right_shift(transposed, 4), mask) + + # 2. Conditionally convert to signed. + if dtype == "int8": + low = to_signed(low) + high = to_signed(high) + + low = ops.cast(low, dtype) + high = ops.cast(high, dtype) + + # 3. Interleave and reshape. + stacked = ops.stack([low, high], axis=1) + unpacked = ops.reshape(stacked, (-1,) + tuple(ops.shape(transposed)[1:])) + + # 4. Remove padding and restore original layout. + unpacked = unpacked[:orig_len, ...] + unpacked = ops.transpose(unpacked, inv_perm) + + return unpacked + + +class GPTQQuantizer(Quantizer): + """A class that handles the quantization of weights using GPTQ method. + + This class provides methods to find quantization parameters (scale and zero) + for a given tensor and can be used to quantize weights in a GPTQ context. + + Args: + weight_bits: (int) The number of bits to quantize to (e.g., 4). + per_channel: (bool) A flag indicating whether quantization is + applied per-channel (`True`) or per-tensor (`False`). + Defaults to `False`. + symmetric: (bool) A flag indicating whether symmetric (`True`) or + asymmetric (`False`) quantization is used. Defaults to `False`. + group_size: (int) The size of weight groups for quantization. A + value of -1 indicates that grouping is not used. + Defaults to -1. + """ + + def __init__( + self, + config=GPTQConfig(tokenizer=None, dataset=None), + compute_dtype="float32", + ): + Quantizer.__init__(self) + self.weight_bits = config.weight_bits + self.per_channel = config.per_channel + self.symmetric = config.symmetric + self.group_size = config.group_size + self.compute_dtype = compute_dtype + + # These are now determined later by `find_params` + self.scale = None + self.zero = None + self.maxq = None + + def find_params(self, input_tensor, weight=True): + """Finds quantization parameters (scale and zero) for a given tensor.""" + self.scale, self.zero, self.maxq = compute_quantization_parameters( + input_tensor, + bits=self.weight_bits, + symmetric=self.symmetric, + per_channel=self.per_channel, + group_size=self.group_size, + weight=weight, + compute_dtype=self.compute_dtype, + ) + return self.scale, self.zero, self.maxq + + def get_config(self): + config = super().get_config() + config.update( + { + "weight_bits": self.weight_bits, + "per_channel": self.per_channel, + "symmetric": self.symmetric, + "group_size": self.group_size, + } + ) + return config + + @classmethod + def from_config(cls, config): + gptq = GPTQConfig( + tokenizer=None, + dataset=None, + weight_bits=config["weight_bits"], + per_channel=config["per_channel"], + symmetric=config["symmetric"], + group_size=config["group_size"], + ) + return cls(gptq) + + +def compute_quantization_parameters( + x, + *, + bits, + symmetric=False, + per_channel=False, + group_size=-1, + weight=False, + compute_dtype="float32", +): + """ + Computes the scale and zero-point for quantization. + + This function calculates the scale and zero-point required for quantizing + a given tensor `x` based on the specified parameters. It supports grouped, + per-channel, per-tensor, symmetric, and asymmetric quantization - along + with any combinations of these. + + Args: + x: KerasTensor. The input tensor to quantize. + bits: int. The number of bits to quantize to (e.g., 4). + symmetric: bool. Whether to use symmetric quantization. + per_channel: bool. Whether to quantize per channel. + group_size: int. The group size for quantization. + weight: bool. Whether the input tensor is a weight tensor. + + Returns: + scale: KerasTensor. The scale tensor for quantization. + zero: KerasTensor. The zero tensor for quantization. + maxq: scalar. The maximum quantization value. + """ + if x is None: + raise ValueError(f"Input tensor {x} cannot be None.") + + # For weights, we typically expect at least a 2D tensor. + if weight and len(x.shape) < 2: + raise ValueError( + f"Input weight tensor {x} must have a rank of at " + f"least 2, but got rank {len(x.shape)}." + ) + + if ops.size(x) == 0: + raise ValueError("Input tensor 'x' cannot be empty.") + + original_shape = x.shape + + if per_channel: + if weight: + if group_size != -1: + input_reshaped = ops.reshape(x, [-1, group_size]) + else: + input_reshaped = ops.reshape(x, [original_shape[0], -1]) + else: # per-tensor + input_reshaped = ops.reshape(x, [1, -1]) + + # Find min/max values + min_values = ops.min(input_reshaped, axis=1) + max_values = ops.max(input_reshaped, axis=1) + + # Apply symmetric quantization logic if enabled + if symmetric: + max_values = ops.maximum(ops.abs(min_values), max_values) + min_values = ops.where( + ops.less(min_values, 0), ops.negative(max_values), min_values + ) + + # Ensure range is not zero to avoid division errors + zero_range = ops.equal(min_values, max_values) + min_values = ops.where(zero_range, ops.subtract(min_values, 1), min_values) + max_values = ops.where(zero_range, ops.add(max_values, 1), max_values) + + maxq = ops.cast(ops.subtract(ops.power(2, bits), 1), compute_dtype) + + # Calculate scale and zero-point + scale = ops.divide(ops.subtract(max_values, min_values), maxq) + if symmetric: + zero = ops.full_like(scale, ops.divide(ops.add(maxq, 1), 2)) + else: + zero = ops.round(ops.divide(ops.negative(min_values), scale)) + + # Ensure scale is non-zero + scale = ops.where(ops.less_equal(scale, 0), 1e-8, scale) + + if weight: + # Per-channel, non-grouped case: simple reshape is correct. + if per_channel and group_size == -1: + scale = ops.reshape(scale, [-1, 1]) + zero = ops.reshape(zero, [-1, 1]) + elif not per_channel: + num_rows = original_shape[0] + scale = ops.tile(ops.reshape(scale, (1, 1)), (num_rows, 1)) + zero = ops.tile(ops.reshape(zero, (1, 1)), (num_rows, 1)) + if per_channel: + scale = ops.reshape(scale, [-1, 1]) + zero = ops.reshape(zero, [-1, 1]) + + zero = ops.cast(zero, "uint8") + + return scale, zero, maxq + + +def quantize_with_zero_point(input_tensor, scale, zero, maxq): + """Quantize a float tensor into discrete levels [0, maxq] using + per-tensor/per-channel/grouped scaling. + + Returns `q` (same dtype as inputs/scales; float is fine) where values are in + [0, maxq]. + + Args: + input_tensor: KerasTensor. The input tensor to quantize. + scale: KerasTensor. The scale tensor for quantization. + zero: KerasTensor. The zero tensor for quantization. + maxq: KerasTensor. The maximum quantization value. + + Returns: + KerasTensor. The quantized tensor. + """ + # Guard against divide-by-zero + epsilon = ops.cast(1e-8, dtype=scale.dtype) + safe_scale = ops.where(ops.equal(scale, 0), epsilon, scale) + + quantized_tensor = ops.round( + ops.add( + ops.divide(input_tensor, safe_scale), ops.cast(zero, scale.dtype) + ) + ) + quantized_tensor = ops.clip(quantized_tensor, 0, maxq) + return quantized_tensor + + +def dequantize_with_zero_point(input_tensor, scale, zero): + """ + Dequantizes a quantized tensor using the provided scale and zero tensors. + + Args: + input_tensor: KerasTensor. The quantized tensor to dequantize. + scale: KerasTensor. The scale tensor for dequantization. + zero: KerasTensor. The zero tensor for dequantization. + + Returns: + KerasTensor. The dequantized tensor. + """ + return ops.multiply( + scale, ops.subtract(input_tensor, ops.cast(zero, scale.dtype)) + ) + + +def quantize_with_sz_map(weights_matrix, scale, zero, g_idx, maxq): + """Quantize the weight matrix from group params. + + This function uses the provided scale and zero tensors to quantize the + input weights_matrix according to the group indices. It maps each column + of the weights_matrix to its corresponding group parameters and performs + the quantization operation. + + Args: + weights_matrix: 2D tensor of shape [out_features, in_features]. + scale: Per-group scale tensor of shape [out_features, n_groups]. + zero: Per-group zero-point tensor of shape [out_features, n_groups]. + g_idx: Integer tensor of shape [in_features,] mapping each column to + its group index. + maxq: Scalar (float) representing the maximum integer quantization + level (e.g., 2^bits - 1). + + Returns: + A tensor with the same shape as `weights_matrix` containing the + quantized weights produced using the provided group parameters. + """ + groups = ops.cast(g_idx, "int32") + scale_cols = ops.take(scale, groups, axis=1) # [out_features, in_features] + zero_cols = ops.take(zero, groups, axis=1) # [out_features, in_features] + + # Quantize elementwise, then cast to int + return quantize_with_zero_point(weights_matrix, scale_cols, zero_cols, maxq) + + +def dequantize_with_sz_map(weights_matrix, scale, zero, g_idx): + """Rebuild a dequantized weight matrix from group params. + + This function uses the provided scale and zero tensors to dequantize the + input weights_matrix according to the group indices. It maps each column + of the weights_matrix to its corresponding group parameters and performs + the dequantization operation. + + Args: + weights_matrix: 2D tensor of shape [out_features, in_features]. + scale: Per-group scale tensor of shape [out_features, n_groups]. + zero: Per-group zero-point tensor of shape [out_features, n_groups]. + g_idx: Integer tensor of shape [in_features,] mapping each column to + its group index. + maxq: Scalar (float) representing the maximum integer quantization + level (e.g., 2^bits - 1). + + Returns: + A tensor with the same shape as `weights_matrix` containing the + dequantized weights produced using the provided group parameters. + """ + # Map group indices to scales and zeros + groups = ops.cast(g_idx, "int32") + scales_mapped = ops.take(scale, groups, axis=1) + zeros_mapped = ops.take(zero, groups, axis=1) + zeros_mapped = ops.cast(zeros_mapped, scales_mapped.dtype) + + quantized = ops.multiply( + ops.subtract(weights_matrix, zeros_mapped), scales_mapped + ) + + return quantized diff --git a/keras/src/quantizers/quantizers_test.py b/keras/src/quantizers/quantizers_test.py new file mode 100644 index 000000000000..1f0e82177789 --- /dev/null +++ b/keras/src/quantizers/quantizers_test.py @@ -0,0 +1,932 @@ +import sys + +import numpy as np +import pytest +from absl.testing import parameterized + +from keras.src import backend +from keras.src import ops +from keras.src import quantizers +from keras.src import random +from keras.src import testing +from keras.src.quantizers.quantizers import compute_quantization_parameters +from keras.src.quantizers.quantizers import dequantize_with_sz_map +from keras.src.quantizers.quantizers import dequantize_with_zero_point +from keras.src.quantizers.quantizers import quantize_with_sz_map +from keras.src.quantizers.quantizers import quantize_with_zero_point +from keras.src.testing.test_utils import named_product + + +class QuantizersTest(testing.TestCase): + def test_get_method(self): + quantizer = quantizers.get("abs_max_quantizer", axis=-1) + self.assertTrue(quantizer, quantizers.AbsMaxQuantizer) + + quantizer = quantizers.get(None) + self.assertEqual(quantizer, None) + + with self.assertRaises(ValueError): + quantizers.get("typo") + + def test_abs_max_quantizer(self): + values = random.uniform([3, 4, 5], minval=-1, maxval=1, dtype="float32") + quantizer = quantizers.AbsMaxQuantizer(axis=-1) + + # Test quantizing + quantized_values, scale = quantizer(values) + self.assertDType(quantized_values, "int8") + self.assertDType(scale, "float32") + self.assertEqual(tuple(quantized_values.shape), (3, 4, 5)) + self.assertEqual(tuple(scale.shape), (3, 4, 1)) + self.assertLessEqual(ops.max(quantized_values), 127) + self.assertGreaterEqual(ops.min(quantized_values), -127) + + # Test dequantizing + dequantized_values = ops.divide(quantized_values, scale) + rmse = ops.sqrt( + ops.mean(ops.square(ops.subtract(values, dequantized_values))) + ) + self.assertLess(rmse, 1e-1) # loose assertion + + # Test serialization + self.run_class_serialization_test(quantizer) + + # Test bfloat16 & float16 dtype + values = random.uniform( + [3, 4, 5], minval=-1, maxval=1, dtype="bfloat16" + ) + quantized_values, scale = quantizer(values) + self.assertDType(quantized_values, "int8") + self.assertDType(scale, "bfloat16") + values = random.uniform([3, 4, 5], minval=-1, maxval=1, dtype="float16") + quantized_values, scale = quantizer(values) + self.assertDType(quantized_values, "int8") + self.assertDType(scale, "float16") + + def test_abs_max_quantizer_to_numpy(self): + values = random.uniform([3, 4, 5], minval=-1, maxval=1, dtype="float32") + quantized_values, scale = quantizers.abs_max_quantize( + values, axis=-1, to_numpy=True + ) + ref_quantized_values, ref_scale = quantizers.abs_max_quantize( + values, axis=-1 + ) + self.assertAllClose(quantized_values, ref_quantized_values) + self.assertAllClose(scale, ref_scale) + + def test_compute_float8_scale(self): + amax = 3.0 + scale = 4.0 + dtype_max = 448.0 # float8_e4m3fn + # The algorithm for computing the new scale is sourced from + # https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/jax.html#transformer_engine.jax.update_fp8_metas + expected_scale = 1.0 / (dtype_max / amax) / (2**0) + + computed_scale = quantizers.compute_float8_scale(amax, scale, dtype_max) + self.assertAllClose(computed_scale, expected_scale) + + def test_compute_float8_amax_history(self): + values = random.uniform([3, 4, 5], minval=-1, maxval=1) + amax_history = random.uniform([123]) + amax_from_values = ops.max(ops.abs(values)) + + computed_amax_history = quantizers.compute_float8_amax_history( + values, amax_history + ) + self.assertAllClose(computed_amax_history[0], amax_from_values) + # Shift to left with 1 step + self.assertAllClose( + computed_amax_history[1:], ops.roll(amax_history, -1)[1:] + ) + + def test_quantize_and_dequantize(self): + scale = 1.0 / 100.0 + values = random.uniform([3, 4, 5], minval=-1, maxval=1) + qdq_values = quantizers.quantize_and_dequantize( + values, scale, "float8_e4m3fn", "float32" + ) + # A loose assertion due to an expected quantization error + self.assertAllClose(qdq_values, values, atol=1e-1) + + qdq_values = quantizers.quantize_and_dequantize( + values, scale, "float8_e5m2", "float32" + ) + # A loose assertion due to an expected quantization error + self.assertAllClose(qdq_values, values, atol=5e-1) + + SHAPE_AXIS_SCENARIOS = [ + # 1. 2D Tensors + # Covers the unpack fast path (rank=2, axis=0) for both parities + {"testcase_name": "2d_axis0_odd", "shape": (5, 8), "axis": 0}, + {"testcase_name": "2d_axis0_even", "shape": (4, 8), "axis": 0}, + # Covers the general path and a negative axis for 2D tensors + {"testcase_name": "2d_axis1_odd", "shape": (8, 7), "axis": 1}, + {"testcase_name": "2d_axis_neg1_even", "shape": (8, 6), "axis": -1}, + # 2. Higher-Rank Tensors + # Covers a middle axis for a complex shape with both parities + {"testcase_name": "4d_axis1_odd", "shape": (2, 5, 4, 6), "axis": 1}, + {"testcase_name": "4d_axis2_even", "shape": (2, 4, 8, 6), "axis": 2}, + # Covers the last axis of a complex shape with a negative index + { + "testcase_name": "4d_axis_neg1_odd", + "shape": (2, 4, 6, 7), + "axis": -1, + }, + ] + + DTYPE_PARAMS = [ + {"testcase_name": "int8", "dtype": "int8", "minval": -8, "maxval": 8}, + {"testcase_name": "uint8", "dtype": "uint8", "minval": 0, "maxval": 16}, + ] + + @parameterized.named_parameters( + named_product(SHAPE_AXIS_SCENARIOS, DTYPE_PARAMS) + ) + def test_pack_unpack_int4(self, shape, axis, dtype, minval, maxval): + # Create a random tensor with int4 values in the specified range and + # dtype + arr = ops.cast( + ops.floor(random.uniform(shape, minval=minval, maxval=maxval)), + dtype, + ) + + # Pack the tensor using the specified dtype + packed, packed_shape, orig_len = quantizers.pack_int4( + arr, axis=axis, dtype=dtype + ) + + # Unpack the tensor using the specified dtype + unpacked = quantizers.unpack_int4( + packed, orig_len, axis=axis, dtype=dtype + ) + + # Verify that the packed tensor has the correct dtype + self.assertDType(packed, dtype) + + # Verify that the unpacked tensor has the correct dtype + self.assertDType(unpacked, dtype) + + # The unpacked tensor should be the same as the original tensor + self.assertAllClose(unpacked, arr) + + # Test the packed shape + expected_packed_shape = list(shape) + expected_packed_shape[axis] = (expected_packed_shape[axis] + 1) // 2 + self.assertEqual( + list(ops.convert_to_numpy(packed_shape)), expected_packed_shape + ) + + @parameterized.named_parameters( + ("per_tensor", None), + ("per_channel", -1), + ) + def test_fake_quant_with_min_max_vars_symbolic(self, axis): + x = backend.KerasTensor((2, 3, 4)) + y = quantizers.fake_quant_with_min_max_vars(x, -3.0, 3.0, axis=axis) + + self.assertIsInstance(y, backend.KerasTensor) + self.assertEqual(y.shape, (2, 3, 4)) + + @parameterized.named_parameters( + [ + { + "testcase_name": "wide_8bits_input_mins_0.0_input_maxs_255.0", + "narrow_range": False, + "input_mins": [0.0], + "input_maxs": [255.0], + "num_bits": 8, + "expected_nudged_input_mins": [0.0], + "expected_nudged_input_maxs": [255.0], + "expected_steps": [1.0], + "axis": None, + }, + { + "testcase_name": "wide_8bits_input_mins_0.5_input_maxs_128.0", + "narrow_range": False, + "input_mins": [0.5], + "input_maxs": [128.0], + "num_bits": 8, + "expected_nudged_input_mins": [0.0], + "expected_nudged_input_maxs": [127.5], + "expected_steps": [0.5], + "axis": None, + }, + { + "testcase_name": "wide_8bits_input_mins_-128.0_input_maxs_-0.5", + "narrow_range": False, + "input_mins": [-128.0], + "input_maxs": [-0.5], + "num_bits": 8, + "expected_nudged_input_mins": [-127.5], + "expected_nudged_input_maxs": [0.0], + "expected_steps": [0.5], + "axis": None, + }, + { + "testcase_name": "wide_8bits_input_mins_-0.1_input_maxs_127.4", + "narrow_range": False, + "input_mins": [-0.1], + "input_maxs": [127.4], + "num_bits": 8, + "expected_nudged_input_mins": [0.0], + "expected_nudged_input_maxs": [127.5], + "expected_steps": [0.5], + "axis": None, + }, + { + "testcase_name": "narrow_8bits_input_mins_0.0_input_maxs_254.0", + "narrow_range": True, + "input_mins": [0.0], + "input_maxs": [254.0], + "num_bits": 8, + "expected_nudged_input_mins": [0.0], + "expected_nudged_input_maxs": [254.0], + "expected_steps": [1.0], + "axis": None, + }, + { + "testcase_name": "narrow_8bits_input_mins_0.1_input_maxs_127.1", + "narrow_range": True, + "input_mins": [0.1], + "input_maxs": [127.1], + "num_bits": 8, + "expected_nudged_input_mins": [0.0], + "expected_nudged_input_maxs": [127.0], + "expected_steps": [0.5], + "axis": None, + }, + { + "testcase_name": ( + "narrow_8bits_input_mins_-127.1_input_maxs_-0.1" + ), + "narrow_range": True, + "input_mins": [-127.1], + "input_maxs": [-0.1], + "num_bits": 8, + "expected_nudged_input_mins": [-127.0], + "expected_nudged_input_maxs": [0.0], + "expected_steps": [0.5], + "axis": None, + }, + { + "testcase_name": ( + "narrow_8bits_input_mins_-0.1_input_maxs_126.9" + ), + "narrow_range": True, + "input_mins": [-0.1], + "input_maxs": [126.9], + "num_bits": 8, + "expected_nudged_input_mins": [0.0], + "expected_nudged_input_maxs": [127.0], + "expected_steps": [0.5], + "axis": None, + }, + { + "testcase_name": "wide_7bits_input_mins_0.0_input_maxs_127.0", + "narrow_range": False, + "input_mins": [0.0], + "input_maxs": [127.0], + "num_bits": 7, + "expected_nudged_input_mins": [0.0], + "expected_nudged_input_maxs": [127.0], + "expected_steps": [1.0], + "axis": None, + }, + { + "testcase_name": "wide_7bits_input_mins_0.5_input_maxs_64.0", + "narrow_range": False, + "input_mins": [0.5], + "input_maxs": [64.0], + "num_bits": 7, + "expected_nudged_input_mins": [0.0], + "expected_nudged_input_maxs": [63.5], + "expected_steps": [0.5], + "axis": None, + }, + { + "testcase_name": "wide_7bits_input_mins_-64.0_input_maxs_-0.5", + "narrow_range": False, + "input_mins": [-64.0], + "input_maxs": [-0.5], + "num_bits": 7, + "expected_nudged_input_mins": [-63.5], + "expected_nudged_input_maxs": [0.0], + "expected_steps": [0.5], + "axis": None, + }, + { + "testcase_name": "wide_7bits_input_mins_-0.1_input_maxs_63.4", + "narrow_range": False, + "input_mins": [-0.1], + "input_maxs": [63.4], + "num_bits": 7, + "expected_nudged_input_mins": [0.0], + "expected_nudged_input_maxs": [63.5], + "expected_steps": [0.5], + "axis": None, + }, + { + "testcase_name": "narrow_7bits_input_mins_0.0_input_maxs_126.0", + "narrow_range": True, + "input_mins": [0.0], + "input_maxs": [126.0], + "num_bits": 7, + "expected_nudged_input_mins": [0.0], + "expected_nudged_input_maxs": [126.0], + "expected_steps": [1.0], + "axis": None, + }, + { + "testcase_name": "narrow_7bits_input_mins_0.1_input_maxs_63.1", + "narrow_range": True, + "input_mins": [0.1], + "input_maxs": [63.1], + "num_bits": 7, + "expected_nudged_input_mins": [0.0], + "expected_nudged_input_maxs": [63.0], + "expected_steps": [0.5], + "axis": None, + }, + { + "testcase_name": ( + "narrow_7bits_input_mins_-63.1_input_maxs_-0.1" + ), + "narrow_range": True, + "input_mins": [-63.1], + "input_maxs": [-0.1], + "num_bits": 7, + "expected_nudged_input_mins": [-63.0], + "expected_nudged_input_maxs": [0.0], + "expected_steps": [0.5], + "axis": None, + }, + { + "testcase_name": "narrow_7bits_input_mins_-0.1_input_maxs_62.9", + "narrow_range": True, + "input_mins": [-0.1], + "input_maxs": [62.9], + "num_bits": 7, + "expected_nudged_input_mins": [0.0], + "expected_nudged_input_maxs": [63.0], + "expected_steps": [0.5], + "axis": None, + }, + { + "testcase_name": "wide_8bits_multi_channel", + "narrow_range": False, + "input_mins": [0.0, 0.5, -128.0, -0.1], + "input_maxs": [255.0, 128.0, -0.5, 127.4], + "num_bits": 8, + "expected_nudged_input_mins": [0.0, 0.0, -127.5, 0.0], + "expected_nudged_input_maxs": [255.0, 127.5, 0.0, 127.5], + "expected_steps": [1.0, 0.5, 0.5, 0.5], + "axis": 1, + }, + { + "testcase_name": "narrow_8bits_multi_channel", + "narrow_range": True, + "input_mins": [0.0, 0.1, -127.1, -0.1], + "input_maxs": [254.0, 127.1, -0.1, 126.9], + "num_bits": 8, + "expected_nudged_input_mins": [0.0, 0.0, -127.0, 0.0], + "expected_nudged_input_maxs": [254.0, 127.0, 0.0, 127.0], + "expected_steps": [1.0, 0.5, 0.5, 0.5], + "axis": 1, + }, + { + "testcase_name": "wide_7bits_multi_channel", + "narrow_range": False, + "input_mins": [0.0, 0.5, -64.0, -0.1], + "input_maxs": [127.0, 64.0, -0.5, 63.4], + "num_bits": 7, + "expected_nudged_input_mins": [0.0, 0.0, -63.5, 0.0], + "expected_nudged_input_maxs": [127.0, 63.5, 0.0, 63.5], + "expected_steps": [1.0, 0.5, 0.5, 0.5], + "axis": 1, + }, + { + "testcase_name": "narrow_7bits_multi_channel", + "narrow_range": True, + "input_mins": [0.0, 0.1, -63.1, -0.1], + "input_maxs": [126.0, 63.1, -0.1, 62.9], + "num_bits": 7, + "expected_nudged_input_mins": [0.0, 0.0, -63.0, 0.0], + "expected_nudged_input_maxs": [126.0, 63.0, 0.0, 63.0], + "expected_steps": [1.0, 0.5, 0.5, 0.5], + "axis": 1, + }, + ] + ) + @pytest.mark.skipif( + backend.backend() not in ("tensorflow", "jax", "torch"), + reason=f"{backend.backend()} doesn't support `custom_gradient`.", + ) + def test_fake_quant_with_min_max_vars( + self, + input_mins, + input_maxs, + num_bits, + narrow_range, + axis, + expected_nudged_input_mins, + expected_nudged_input_maxs, + expected_steps, + ): + num_channels = len(input_mins) + inputs_list = [] + expected_list = [] + initial_gradients_list = [] + expected_backprops_wrt_input_list = [] + for i in range(num_channels): + expected_nudged_input_min = expected_nudged_input_mins[i] + expected_nudged_input_max = expected_nudged_input_maxs[i] + expected_step = expected_steps[i] + + inputs_list.append( + [ + expected_nudged_input_min - expected_step, + expected_nudged_input_min - 0.01, + expected_nudged_input_min, + expected_nudged_input_min + 0.01, + expected_nudged_input_min + expected_step - 0.01, + expected_nudged_input_min + expected_step, + expected_nudged_input_min + expected_step + 0.01, + expected_nudged_input_max - 0.01, + expected_nudged_input_max, + expected_nudged_input_max + 0.01, + expected_nudged_input_max + expected_step, + ] + ) + expected_list.append( + [ + expected_nudged_input_min, + expected_nudged_input_min, + expected_nudged_input_min, + expected_nudged_input_min, + expected_nudged_input_min + expected_step, + expected_nudged_input_min + expected_step, + expected_nudged_input_min + expected_step, + expected_nudged_input_max, + expected_nudged_input_max, + expected_nudged_input_max, + expected_nudged_input_max, + ] + ) + initial_gradients_list.append( + list(range(1, len(inputs_list[-1]) + 1)) + ) + expected_backprops_wrt_input_list.append( + [0.0, 0.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 0.0, 0.0] + ) + inputs = ops.transpose(ops.array(inputs_list, dtype="float32")) + expected = ops.transpose(ops.array(expected_list, dtype="float32")) + expected_backprops_wrt_input = ops.transpose( + ops.array(expected_backprops_wrt_input_list, dtype="float32") + ) + input_min = ops.array(input_mins, dtype="float32") + input_max = ops.array(input_maxs, dtype="float32") + initial_gradients = ops.transpose( + ops.array(initial_gradients_list, dtype="float32") + ) + + # Test gradients. + if backend.backend() == "tensorflow": + import tensorflow as tf + + @tf.function(jit_compile=True) + def test_op( + inputs, input_mins, input_maxs, num_bits, narrow_range, axis + ): + with tf.GradientTape() as tape: + tape.watch(inputs) + result = quantizers.fake_quant_with_min_max_vars( + inputs, + input_mins, + input_maxs, + num_bits, + narrow_range, + axis, + ) + return initial_gradients * tape.gradient(result, inputs) + + if backend.backend() == "torch": + import torch + + def test_op( + inputs, input_mins, input_maxs, num_bits, narrow_range, axis + ): + # Create tensor and enable gradient tracking + inputs = torch.tensor( + inputs, dtype=torch.float32, requires_grad=True + ) + + # Apply the quantization operation + result = quantizers.fake_quant_with_min_max_vars( + inputs, input_mins, input_maxs, num_bits, narrow_range, axis + ) + + # Compute gradients + result.backward(torch.ones_like(result)) + + return initial_gradients * inputs.grad + + if backend.backend() == "jax": + import jax + + def test_op( + inputs, input_mins, input_maxs, num_bits, narrow_range, axis + ): + # Define the function to compute gradients for + def quantize_fn(x): + return quantizers.fake_quant_with_min_max_vars( + x, input_mins, input_maxs, num_bits, narrow_range, axis + ) + + _, f_vjp = jax.vjp(quantize_fn, inputs) + + # NOTE: When python version >= 3.10, the gradients are at + # `f_vjp.args[0].args[0][0]`. Otherwise, they are at + # `f_vjp.args[0].args[0][1]`. + if sys.version_info >= (3, 10): + input_gradients = f_vjp.args[0].args[0][0] + else: + input_gradients = f_vjp.args[0].args[0][1] + + return ops.multiply(initial_gradients, input_gradients) + + gradients = test_op( + inputs, input_min, input_max, num_bits, narrow_range, axis + ) + if backend.backend() != "jax" or not testing.jax_uses_gpu(): + # JAX GPU produces less precise numbers, causing the CI to fail. + # For example, 127.5 / 255.0 results in 0.49999997 instead of 0.5. + self.assertAllClose(gradients, expected_backprops_wrt_input) + + # Test outputs. + outputs = quantizers.fake_quant_with_min_max_vars( + inputs, + input_min, + input_max, + num_bits=num_bits, + narrow_range=narrow_range, + axis=axis, + ) + self.assertAllClose(outputs, expected) + + # Test bfloat16 & float16 dtype + outputs = quantizers.fake_quant_with_min_max_vars( + ops.cast(inputs, "bfloat16"), + input_min, + input_max, + num_bits=num_bits, + narrow_range=narrow_range, + axis=axis, + ) + self.assertDType(outputs, "bfloat16") + self.assertAllClose(outputs, expected) + + outputs = quantizers.fake_quant_with_min_max_vars( + ops.cast(inputs, "float16"), + input_min, + input_max, + num_bits=num_bits, + narrow_range=narrow_range, + axis=axis, + ) + self.assertDType(outputs, "float16") + self.assertAllClose(outputs, expected) + + +class GPTQQuantizerTest(testing.TestCase): + @parameterized.named_parameters( + ("bits_2_sym_False", 2, False), + ("bits_4_sym_False", 4, False), + ("bits_8_sym_False", 8, False), + ("bits_2_sym_True", 2, True), + ("bits_4_sym_True", 4, True), + ("bits_8_sym_True", 8, True), + ) + def test_quantize_dequantize_roundtrip_error_bound_per_tensor( + self, bits, symmetric + ): + """ + For finite inputs and positive scales, the reconstruction error + |x_hat - clip(x)| is bounded by 0.5 * scale elementwise. + """ + rng = np.random.default_rng(0) + x = ops.array(rng.standard_normal((64, 32)), "float32") + scale = ops.array(0.05) # per-tensor scale + maxq = ops.array(ops.subtract(ops.power(2, bits), 1), "float32") + zero = ops.array(maxq / 2.0 if symmetric else 3.0, "float32") + + quantized = quantize_with_zero_point(x, scale, zero, maxq) + dequantized = dequantize_with_zero_point(quantized, scale, zero) + + # Representable dequantization range: + # [scale*(0 - zero), scale*(maxq - zero)] + lo = ops.multiply(scale, ops.subtract(ops.array(0.0), zero)) + hi = ops.multiply(scale, ops.subtract(maxq, zero)) + x_clipped = ops.clip(x, lo, hi) + + err = ops.abs(dequantized - x_clipped) + self.assertTrue( + ops.all(err <= (ops.add(ops.multiply(0.5, scale), 1e-7))) + ) + + def test_quantize_clipping_behavior_extremes(self): + """ + Very negative q == 0 ; very positive q == maxq. + """ + maxq = ops.array(15.0) + scale = ops.array(0.1) + zero = ops.array(7.0) + + x = ops.array([[-1e6, 1e6]], "float32") + quantized = quantize_with_zero_point(x, scale, zero, maxq) + + self.assertEqual(quantized.shape, (1, 2)) + self.assertEqual(quantized[0, 0], 0.0) + self.assertEqual(quantized[0, 1], maxq) + + def test_zero_scale_guard_no_nans_for_finite_inputs(self): + """ + If scale == 0, quantize should not produce NaNs (uses epsilon + replacement). + """ + x = ops.array([[0.0, 1.0, -2.0]]) + scale = ops.array(0.0) # triggers epsilon path + zero = ops.array(5.0) + maxq = ops.array(15.0) + + q = quantize_with_zero_point(x, scale, zero, maxq) + self.assertFalse(ops.any(ops.isnan(q))) + + # Dequantize should also be finite + x_hat = dequantize_with_zero_point(q, scale, zero) + self.assertTrue(ops.all(ops.isfinite(x_hat))) + + @parameterized.parameters(4, 8) + def test_idempotent_quantize_when_input_is_already_levels(self, bits): + """ + If input is already exactly on representable dequantized grid, + quantize→dequantize should return the same values (within float eps). + """ + scale = ops.array(0.125) + maxq = ops.array(ops.subtract(ops.power(2, bits), 1), "float32") + zero = ops.array(ops.divide(maxq, 2.0)) + + # Build dequantized grid points: x = scale * (k - zero), k in [0..maxq] + ks = ops.arange(0, ops.add(maxq, 1)) + x_vals = ops.multiply(scale, ops.subtract(ks, zero)) + x = ops.reshape(x_vals, (1, -1)) + + q = quantize_with_zero_point(x, scale, zero, maxq) + x_hat = dequantize_with_zero_point(q, scale, zero) + self.assertAllClose(x_hat, x, rtol=0, atol=1e-6) + + +class ComputeScaleZeroTest(testing.TestCase): + def test_error_when_x_is_none(self): + with self.assertRaisesRegex(ValueError, "cannot be None"): + compute_quantization_parameters(None, bits=4) + + def test_error_when_x_is_empty(self): + x = ops.array([], "float32") + with self.assertRaisesRegex(ValueError, "cannot be empty"): + compute_quantization_parameters(x, bits=4) + + def test_error_when_weight_rank_too_low(self): + x = ops.array([1.0, 2.0], "float32") # rank-1 + with self.assertRaisesRegex(ValueError, "rank of at least 2"): + compute_quantization_parameters(x, bits=4, weight=True) + + @parameterized.named_parameters( + ("bits2_asym", 2, False), + ("bits4_asym", 4, False), + ("bits8_asym", 8, False), + ("bits2_sym", 2, True), + ("bits4_sym", 4, True), + ("bits8_sym", 8, True), + ) + def test_per_tensor_shapes_and_basic_invariants(self, bits, symmetric): + """Test per-tensor shapes and basic invariants.""" + x = ops.array( + np.random.default_rng(0).standard_normal((7, 5), dtype="float32") + ) + scale, zero, maxq = compute_quantization_parameters( + x, bits=bits, symmetric=symmetric, per_channel=False, weight=False + ) + + # Shapes (per-tensor): (1,) for scale/zero + self.assertEqual(scale.shape, (1,)) + self.assertEqual(zero.shape, (1,)) + + # Scale must be strictly positive + self.assertTrue(ops.all(scale > 0.0)) + + if symmetric: + # zero should be (maxq + 1)/2 for symmetric + expected_zero = ops.divide(ops.add(maxq, 1.0), 2.0) + self.assertAllClose(zero, expected_zero) + else: + # Asymmetric: zero ~ round(-min/scale) on the flattened input + flat = ops.reshape(x, (1, -1)) + min_val = ops.min(flat, axis=1) + expected_zero = ops.round(ops.divide(ops.negative(min_val), scale)) + self.assertAllClose(zero, expected_zero) + + def test_per_tensor_symmetric_on_constant_input_uses_safe_range(self): + """Ensures safe range adjustment if entries are equal""" + x = ops.array(np.full((3, 4), 0.0, dtype=np.float32)) + scale, zero, maxq = compute_quantization_parameters( + x, bits=4, symmetric=True, per_channel=False, weight=False + ) + # With symmetric=True and constant input, zero = (maxq+1)/2 + self.assertAllClose(zero, ops.array((float(maxq) + 1.0) / 2.0)) + self.assertTrue(ops.all(ops.greater(scale, 0.0))) + + def test_weight_per_tensor_tiles_rows(self): + """Tests that scales/zeros tensors are properly tiled when + per-channel quantization is not used.""" + x = ops.array( + np.random.default_rng(1).standard_normal((8, 16)), "float32" + ) + scale, zero, _ = compute_quantization_parameters( + x, bits=4, symmetric=False, per_channel=False, weight=True + ) + # When weight=True and per_channel=False, shapes are (rows, 1) + self.assertEqual(scale.shape, (8, 1)) + self.assertEqual(zero.shape, (8, 1)) + + # All elements in the scale and zero tensors must be equal due to + # tiling. + self.assertTrue(ops.all(scale == scale[0, 0])) + self.assertTrue(ops.all(zero == zero[0, 0])) + + def test_weight_per_channel_ungrouped_shapes(self): + """Tests that scales/zeros tensors have the correct shape when + per-channel quantization is used without grouping.""" + x = ops.array( + np.random.default_rng(2).standard_normal((6, 10)), "float32" + ) + scale, zero, _ = compute_quantization_parameters( + x, + bits=4, + symmetric=False, + per_channel=True, + group_size=-1, + weight=True, + ) + # Per-channel (ungrouped): one scale per output row -> (rows, 1) + self.assertEqual(scale.shape, (6, 1)) + self.assertEqual(zero.shape, (6, 1)) + self.assertTrue(ops.all(ops.greater(scale, 0.0))) + + # Each channel should have roughly unique scales and zeros + self.assertFalse(ops.all(scale == scale[0, 0])) + self.assertFalse(ops.all(zero == zero[0, 0])) + + def test_weight_per_channel_grouped_shapes_and_count(self): + """Tests that scales/zeros have the correct shape and count when + per-channel quantization is used with grouping.""" + rows, cols, groups = 8, 16, 4 + x = ops.array( + np.random.default_rng(3).standard_normal((rows, cols)), "float32" + ) + scale, zero, _ = compute_quantization_parameters( + x, + bits=4, + symmetric=False, + per_channel=True, + group_size=groups, + weight=True, + ) + # Grouped path reshapes to [-1, group_size] + # number of groups = rows*cols / groups + num_groups = (rows * cols) // groups + self.assertEqual(scale.shape, (num_groups, 1)) + self.assertEqual(zero.shape, (num_groups, 1)) + self.assertTrue(ops.all(ops.greater(scale, 0.0))) + + @parameterized.named_parameters( + ("sym_true", True), + ("sym_false", False), + ) + def test_dtype_and_finiteness(self, symmetric): + x = ops.array( + np.random.default_rng(4).standard_normal((5, 7)).astype("float32") + ) + scale, zero, maxq = compute_quantization_parameters( + x, + bits=8, + symmetric=symmetric, + per_channel=True, + group_size=-1, + weight=True, + ) + # All outputs should be all finite + self.assertTrue(ops.all(ops.isfinite(scale))) + self.assertTrue(ops.all(ops.isfinite(zero))) + self.assertTrue(ops.all(ops.isfinite(maxq))) + + def test_dequantize_with_sz_map_logic(self): + """Validates the vectorized dequantization logic against a + manual implementation.""" + out_features, in_features, group_size = 4, 16, 4 + n_groups = in_features // group_size + + # Create dummy quantized weights + q_weights = ops.cast( + ops.array( + np.random.randint(0, 15, size=(out_features, in_features)) + ), + "uint8", + ) + + # Create dummy scales and zeros + scale = ops.abs( + ops.array( + np.random.random((out_features, n_groups)).astype("float32") + ) + ) + zero = ops.cast( + ops.array(np.random.randint(0, 15, size=(out_features, n_groups))), + "uint8", + ) + + # Create group index mapping + g_idx = ops.array(np.arange(in_features) // group_size, dtype="int32") + + # Get the result from the function under test + dequantized_result = dequantize_with_sz_map( + q_weights, scale, zero, g_idx + ) + + # Manually compute the expected result + expected_dequantized = np.zeros( + (out_features, in_features), dtype="float32" + ) + + for i in range(out_features): + for j in range(in_features): + group = g_idx[j] + s = scale[i, group] + z = zero[i, group] + # Dequantization formula: (q_val - z) * s + expected_dequantized[i, j] = ops.multiply( + ops.subtract(q_weights[i, j], ops.cast(z, "float32")), s + ) + + self.assertAllClose(dequantized_result, expected_dequantized) + + def test_quantize_with_sz_map_logic(self): + """Validates the vectorized quantization logic against a + manual implementation.""" + out_features, in_features, group_size = 4, 16, 4 + n_groups = in_features // group_size + + # Create dummy float weights + weights = ops.array( + np.random.default_rng(5).standard_normal( + (out_features, in_features) + ), + "float32", + ) + + # Create dummy scales and zeros + scale = ops.abs( + ops.array( + np.random.random((out_features, n_groups)).astype("float32") + ) + ) + zero = ops.cast( + ops.array(np.random.randint(0, 15, size=(out_features, n_groups))), + "uint8", + ) + + maxq = ops.array(15.0) + + # Create group index mapping + g_idx = ops.array(np.arange(in_features) // group_size, dtype="int32") + + # Get the result from the function under test + quantized_result = quantize_with_sz_map( + weights, scale, zero, g_idx, maxq + ) + + # Manually compute the expected result + expected_quantized = np.zeros( + (out_features, in_features), dtype="uint8" + ) + + for i in range(out_features): + for j in range(in_features): + group = g_idx[j] + s = scale[i, group] + z = zero[i, group] + # Quantization formula: clip(round(x/s + z), 0, maxq) + q_val = ops.round(ops.add(ops.divide(weights[i, j], s), z)) + q_val_clipped = ops.clip(q_val, 0.0, maxq) + expected_quantized[i, j] = ops.cast(q_val_clipped, "uint8") + + self.assertAllClose(quantized_result, expected_quantized) diff --git a/keras/src/random/__init__.py b/keras/src/random/__init__.py new file mode 100644 index 000000000000..4ba54c78837c --- /dev/null +++ b/keras/src/random/__init__.py @@ -0,0 +1,9 @@ +from keras.src.random.random import categorical +from keras.src.random.random import dropout +from keras.src.random.random import gamma +from keras.src.random.random import normal +from keras.src.random.random import randint +from keras.src.random.random import shuffle +from keras.src.random.random import truncated_normal +from keras.src.random.random import uniform +from keras.src.random.seed_generator import SeedGenerator diff --git a/keras/src/random/random.py b/keras/src/random/random.py new file mode 100644 index 000000000000..6b65c12ac4b4 --- /dev/null +++ b/keras/src/random/random.py @@ -0,0 +1,345 @@ +from keras.src import backend +from keras.src.api_export import keras_export + + +@keras_export("keras.random.normal") +def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): + """Draw random samples from a normal (Gaussian) distribution. + + Args: + shape: The shape of the random values to generate. + mean: Float, defaults to 0. Mean of the random values to generate. + stddev: Float, defaults to 1. Standard deviation of the random values + to generate. + dtype: Optional dtype of the tensor. Only floating point types are + supported. If not specified, `keras.config.floatx()` is used, + which defaults to `float32` unless you configured it otherwise (via + `keras.config.set_floatx(float_dtype)`). + seed: Optional Python integer or instance of + `keras.random.SeedGenerator`. + By default, the `seed` argument is `None`, and an internal global + `keras.random.SeedGenerator` is used. The `seed` argument can be + used to ensure deterministic (repeatable) random number generation. + Note that passing an integer as the `seed` value will produce the + same random values for each call. To generate different random + values for repeated calls, an instance of + `keras.random.SeedGenerator` must be provided as the `seed` value. + Remark concerning the JAX backend: When tracing functions with the + JAX backend the global `keras.random.SeedGenerator` is not + supported. Therefore, during tracing the default value `seed=None` + will produce an error, and a `seed` argument must be provided. + """ + return backend.random.normal( + shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed + ) + + +@keras_export("keras.random.categorical") +def categorical(logits, num_samples, dtype="int32", seed=None): + """Draws samples from a categorical distribution. + + This function takes as input `logits`, a 2-D input tensor with shape + (batch_size, num_classes). Each row of the input represents a categorical + distribution, with each column index containing the log-probability for a + given class. + + The function will output a 2-D tensor with shape (batch_size, num_samples), + where each row contains samples from the corresponding row in `logits`. + Each column index contains an independent samples drawn from the input + distribution. + + Args: + logits: 2-D Tensor with shape (batch_size, num_classes). Each row + should define a categorical distribution with the unnormalized + log-probabilities for all classes. + num_samples: Int, the number of independent samples to draw for each + row of the input. This will be the second dimension of the output + tensor's shape. + dtype: Optional dtype of the output tensor. + seed: Optional Python integer or instance of + `keras.random.SeedGenerator`. + By default, the `seed` argument is `None`, and an internal global + `keras.random.SeedGenerator` is used. The `seed` argument can be + used to ensure deterministic (repeatable) random number generation. + Note that passing an integer as the `seed` value will produce the + same random values for each call. To generate different random + values for repeated calls, an instance of + `keras.random.SeedGenerator` must be provided as the `seed` value. + Remark concerning the JAX backend: When tracing functions with the + JAX backend the global `keras.random.SeedGenerator` is not + supported. Therefore, during tracing the default value seed=None + will produce an error, and a `seed` argument must be provided. + + Returns: + A 2-D tensor with (batch_size, num_samples). + """ + logits_shape = list(backend.convert_to_tensor(logits).shape) + if len(logits_shape) != 2: + raise ValueError( + "`logits` should be a 2-D tensor with shape " + f"[batch_size, num_classes]. Received: logits={logits}" + ) + return backend.random.categorical( + logits, num_samples, dtype=dtype, seed=seed + ) + + +@keras_export("keras.random.uniform") +def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None): + """Draw samples from a uniform distribution. + + The generated values follow a uniform distribution in the range + `[minval, maxval)`. The lower bound `minval` is included in the range, + while the upper bound `maxval` is excluded. + + `dtype` must be a floating point type, the default range is `[0, 1)`. + + Args: + shape: The shape of the random values to generate. + minval: Float, defaults to 0. Lower bound of the range of + random values to generate (inclusive). + maxval: Float, defaults to 1. Upper bound of the range of + random values to generate (exclusive). + dtype: Optional dtype of the tensor. Only floating point types are + supported. If not specified, `keras.config.floatx()` is used, + which defaults to `float32` unless you configured it otherwise (via + `keras.config.set_floatx(float_dtype)`) + seed: Optional Python integer or instance of + `keras.random.SeedGenerator`. + By default, the `seed` argument is `None`, and an internal global + `keras.random.SeedGenerator` is used. The `seed` argument can be + used to ensure deterministic (repeatable) random number generation. + Note that passing an integer as the `seed` value will produce the + same random values for each call. To generate different random + values for repeated calls, an instance of + `keras.random.SeedGenerator` must be provided as the `seed` value. + Remark concerning the JAX backend: When tracing functions with the + JAX backend the global `keras.random.SeedGenerator` is not + supported. Therefore, during tracing the default value seed=None + will produce an error, and a `seed` argument must be provided. + """ + if dtype and not backend.is_float_dtype(dtype): + raise ValueError( + "`keras.random.uniform` requires a floating point `dtype`. " + f"Received: dtype={dtype} " + ) + return backend.random.uniform( + shape, minval=minval, maxval=maxval, dtype=dtype, seed=seed + ) + + +@keras_export("keras.random.randint") +def randint(shape, minval, maxval, dtype="int32", seed=None): + """Draw random integers from a uniform distribution. + + The generated values follow a uniform distribution in the range + `[minval, maxval)`. The lower bound `minval` is included in the range, + while the upper bound `maxval` is excluded. + + `dtype` must be an integer type. + + Args: + shape: The shape of the random values to generate. + minval: Float, defaults to 0. Lower bound of the range of + random values to generate (inclusive). + maxval: Float, defaults to 1. Upper bound of the range of + random values to generate (exclusive). + dtype: Optional dtype of the tensor. Only integer types are + supported. If not specified, `keras.config.floatx()` is used, + which defaults to `float32` unless you configured it otherwise (via + `keras.config.set_floatx(float_dtype)`) + seed: Optional Python integer or instance of + `keras.random.SeedGenerator`. + By default, the `seed` argument is `None`, and an internal global + `keras.random.SeedGenerator` is used. The `seed` argument can be + used to ensure deterministic (repeatable) random number generation. + Note that passing an integer as the `seed` value will produce the + same random values for each call. To generate different random + values for repeated calls, an instance of + `keras.random.SeedGenerator` must be provided as the `seed` value. + Remark concerning the JAX backend: When tracing functions with the + JAX backend the global `keras.random.SeedGenerator` is not + supported. Therefore, during tracing the default value seed=None + will produce an error, and a `seed` argument must be provided. + """ + if dtype and not backend.is_int_dtype(dtype): + raise ValueError( + "`keras.random.randint` requires an integer `dtype`. " + f"Received: dtype={dtype} " + ) + return backend.random.randint( + shape, minval=minval, maxval=maxval, dtype=dtype, seed=seed + ) + + +@keras_export("keras.random.truncated_normal") +def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): + """Draw samples from a truncated normal distribution. + + The values are drawn from a normal distribution with specified mean and + standard deviation, discarding and re-drawing any samples that are more + than two standard deviations from the mean. + + Args: + shape: The shape of the random values to generate. + mean: Float, defaults to 0. Mean of the random values to generate. + stddev: Float, defaults to 1. Standard deviation of the random values + to generate. + dtype: Optional dtype of the tensor. Only floating point types are + supported. If not specified, `keras.config.floatx()` is used, + which defaults to `float32` unless you configured it otherwise (via + `keras.config.set_floatx(float_dtype)`) + seed: Optional Python integer or instance of + `keras.random.SeedGenerator`. + By default, the `seed` argument is `None`, and an internal global + `keras.random.SeedGenerator` is used. The `seed` argument can be + used to ensure deterministic (repeatable) random number generation. + Note that passing an integer as the `seed` value will produce the + same random values for each call. To generate different random + values for repeated calls, an instance of + `keras.random.SeedGenerator` must be provided as the `seed` value. + Remark concerning the JAX backend: When tracing functions with the + JAX backend the global `keras.random.SeedGenerator` is not + supported. Therefore, during tracing the default value seed=None + will produce an error, and a `seed` argument must be provided. + """ + return backend.random.truncated_normal( + shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed + ) + + +@keras_export("keras.random.dropout") +def dropout(inputs, rate, noise_shape=None, seed=None): + return backend.random.dropout( + inputs, rate, noise_shape=noise_shape, seed=seed + ) + + +@keras_export("keras.random.shuffle") +def shuffle(x, axis=0, seed=None): + """Shuffle the elements of a tensor uniformly at random along an axis. + + Args: + x: The tensor to be shuffled. + axis: An integer specifying the axis along which to shuffle. Defaults to + `0`. + seed: Optional Python integer or instance of + `keras.random.SeedGenerator`. + By default, the `seed` argument is `None`, and an internal global + `keras.random.SeedGenerator` is used. The `seed` argument can be + used to ensure deterministic (repeatable) random number generation. + Note that passing an integer as the `seed` value will produce the + same random values for each call. To generate different random + values for repeated calls, an instance of + `keras.random.SeedGenerator` must be provided as the `seed` value. + Remark concerning the JAX backend: When tracing functions with the + JAX backend the global `keras.random.SeedGenerator` is not + supported. Therefore, during tracing the default value seed=None + will produce an error, and a `seed` argument must be provided. + """ + return backend.random.shuffle(x, axis=axis, seed=seed) + + +@keras_export("keras.random.gamma") +def gamma(shape, alpha, dtype=None, seed=None): + """Draw random samples from the Gamma distribution. + + Args: + shape: The shape of the random values to generate. + alpha: Float, the parameter of the distribution. + dtype: Optional dtype of the tensor. Only floating point types are + supported. If not specified, `keras.config.floatx()` is used, + which defaults to `float32` unless you configured it otherwise (via + `keras.config.set_floatx(float_dtype)`). + seed: Optional Python integer or instance of + `keras.random.SeedGenerator`. + By default, the `seed` argument is `None`, and an internal global + `keras.random.SeedGenerator` is used. The `seed` argument can be + used to ensure deterministic (repeatable) random number generation. + Note that passing an integer as the `seed` value will produce the + same random values for each call. To generate different random + values for repeated calls, an instance of + `keras.random.SeedGenerator` must be provided as the `seed` value. + Remark concerning the JAX backend: When tracing functions with the + JAX backend the global `keras.random.SeedGenerator` is not + supported. Therefore, during tracing the default value seed=None + will produce an error, and a `seed` argument must be provided. + """ + return backend.random.gamma(shape, alpha=alpha, dtype=dtype, seed=seed) + + +@keras_export("keras.random.binomial") +def binomial(shape, counts, probabilities, dtype=None, seed=None): + """Draw samples from a Binomial distribution. + + The values are drawn from a Binomial distribution with + specified trial count and probability of success. + + Args: + shape: The shape of the random values to generate. + counts: A number or array of numbers representing the + number of trials. It must be broadcastable with `probabilities`. + probabilities: A float or array of floats representing the + probability of success of an individual event. + It must be broadcastable with `counts`. + dtype: Optional dtype of the tensor. Only floating point types are + supported. If not specified, `keras.config.floatx()` is used, + which defaults to `float32` unless you configured it otherwise (via + `keras.config.set_floatx(float_dtype)`). + seed: Optional Python integer or instance of + `keras.random.SeedGenerator`. + By default, the `seed` argument is `None`, and an internal global + `keras.random.SeedGenerator` is used. The `seed` argument can be + used to ensure deterministic (repeatable) random number generation. + Note that passing an integer as the `seed` value will produce the + same random values for each call. To generate different random + values for repeated calls, an instance of + `keras.random.SeedGenerator` must be provided as the `seed` value. + Remark concerning the JAX backend: When tracing functions with the + JAX backend the global `keras.random.SeedGenerator` is not + supported. Therefore, during tracing the default value seed=None + will produce an error, and a `seed` argument must be provided. + """ + return backend.random.binomial( + shape, + counts=counts, + probabilities=probabilities, + dtype=dtype, + seed=seed, + ) + + +@keras_export("keras.random.beta") +def beta(shape, alpha, beta, dtype=None, seed=None): + """Draw samples from a Beta distribution. + + The values are drawn from a Beta distribution parametrized + by alpha and beta. + + Args: + shape: The shape of the random values to generate. + alpha: Float or an array of floats representing the first + parameter alpha. Must be broadcastable with `beta` and `shape`. + beta: Float or an array of floats representing the second + parameter beta. Must be broadcastable with `alpha` and `shape`. + dtype: Optional dtype of the tensor. Only floating point types are + supported. If not specified, `keras.config.floatx()` is used, + which defaults to `float32` unless you configured it otherwise (via + `keras.config.set_floatx(float_dtype)`). + seed: Optional Python integer or instance of + `keras.random.SeedGenerator`. + By default, the `seed` argument is `None`, and an internal global + `keras.random.SeedGenerator` is used. The `seed` argument can be + used to ensure deterministic (repeatable) random number generation. + Note that passing an integer as the `seed` value will produce the + same random values for each call. To generate different random + values for repeated calls, an instance of + `keras.random.SeedGenerator` must be provided as the `seed` value. + Remark concerning the JAX backend: When tracing functions with the + JAX backend the global `keras.random.SeedGenerator` is not + supported. Therefore, during tracing the default value seed=None + will produce an error, and a `seed` argument must be provided. + """ + return backend.random.beta( + shape=shape, alpha=alpha, beta=beta, dtype=dtype, seed=seed + ) diff --git a/keras/src/random/random_test.py b/keras/src/random/random_test.py new file mode 100644 index 000000000000..327227db3a54 --- /dev/null +++ b/keras/src/random/random_test.py @@ -0,0 +1,495 @@ +import numpy as np +import pytest +from absl.testing import parameterized + +import keras +from keras.src import backend +from keras.src import ops +from keras.src import testing +from keras.src.backend.common import dtypes +from keras.src.backend.common import standardize_dtype +from keras.src.random import random +from keras.src.random import seed_generator +from keras.src.testing.test_utils import named_product +from keras.src.utils.rng_utils import set_random_seed + + +class RandomCorrectnessTest(testing.TestCase): + @parameterized.parameters( + {"seed": 10, "shape": (5,), "mean": 0, "stddev": 1}, + {"seed": 10, "shape": (2, 3), "mean": 0, "stddev": 1}, + {"seed": 10, "shape": (2, 3, 4), "mean": 0, "stddev": 1}, + {"seed": 10, "shape": (2, 3), "mean": 10, "stddev": 1}, + {"seed": 10, "shape": (2, 3), "mean": 10, "stddev": 3}, + ) + def test_normal(self, seed, shape, mean, stddev): + np.random.seed(seed) + np_res = np.random.normal(loc=mean, scale=stddev, size=shape) + res = random.normal(shape, mean=mean, stddev=stddev, seed=seed) + self.assertEqual(res.shape, shape) + self.assertEqual(res.shape, np_res.shape) + + @parameterized.parameters( + {"seed": 10, "shape": (5,), "minval": 0, "maxval": 1}, + {"seed": 10, "shape": (2, 3), "minval": 0, "maxval": 1}, + {"seed": 10, "shape": (2, 3, 4), "minval": 0, "maxval": 2}, + {"seed": 10, "shape": (2, 3), "minval": -1, "maxval": 1}, + {"seed": 10, "shape": (2, 3), "minval": 1, "maxval": 3}, + ) + def test_uniform(self, seed, shape, minval, maxval): + np.random.seed(seed) + np_res = np.random.uniform(low=minval, high=maxval, size=shape) + res = random.uniform(shape, minval=minval, maxval=maxval, seed=seed) + self.assertEqual(res.shape, shape) + self.assertEqual(res.shape, np_res.shape) + self.assertLessEqual(ops.max(res), maxval) + self.assertGreaterEqual(ops.max(res), minval) + + @parameterized.parameters( + {"seed": 10, "num_samples": 1, "batch_size": 1}, + {"seed": 10, "num_samples": 5, "batch_size": 2}, + {"seed": 10, "num_samples": 10, "batch_size": 4}, + {"seed": 10, "num_samples": 15, "batch_size": 8}, + ) + def test_categorical(self, seed, num_samples, batch_size): + np.random.seed(seed) + # Create logits that definitely favors the batch index after a softmax + # is applied. Without a softmax, this would be close to random. + logits = np.eye(batch_size) * 1e5 + 1e6 + res = random.categorical(logits, num_samples, seed=seed) + # Outputs should have shape `(batch_size, num_samples)`, where each + # output index matches the batch index. + self.assertEqual(res.shape, (batch_size, num_samples)) + expected = np.tile(np.arange(batch_size)[:, None], (1, num_samples)) + self.assertAllClose(res, expected) + + @parameterized.parameters( + {"seed": 10, "shape": (5,), "min": 0, "max": 10, "dtype": "uint16"}, + {"seed": 10, "shape": (2, 3), "min": 0, "max": 10, "dtype": "uint32"}, + {"seed": 10, "shape": (2, 3, 4), "min": 0, "max": 2, "dtype": "int8"}, + {"seed": 10, "shape": (2, 3), "min": -1, "max": 1, "dtype": "int16"}, + {"seed": 10, "shape": (2, 3), "min": 1, "max": 3, "dtype": "int32"}, + ) + def test_randint(self, seed, shape, min, max, dtype): + np.random.seed(seed) + np_res = np.random.randint(low=min, high=max, size=shape) + res = random.randint( + shape, minval=min, maxval=max, seed=seed, dtype=dtype + ) + self.assertEqual(res.shape, shape) + self.assertEqual(res.shape, np_res.shape) + self.assertLessEqual(ops.max(res), max) + self.assertGreaterEqual(ops.max(res), min) + # Torch has incomplete dtype support for uints; will remap some dtypes. + if keras.backend.backend() != "torch": + self.assertEqual(backend.standardize_dtype(res.dtype), dtype) + + @parameterized.parameters( + {"seed": 10, "shape": (5,), "mean": 0, "stddev": 1}, + {"seed": 10, "shape": (2, 3), "mean": 0, "stddev": 1}, + {"seed": 10, "shape": (2, 3, 4), "mean": 0, "stddev": 1}, + {"seed": 10, "shape": (2, 3), "mean": 10, "stddev": 1}, + {"seed": 10, "shape": (2, 3), "mean": 10, "stddev": 3}, + # Test list shapes. + {"seed": 10, "shape": [2, 3], "mean": 10, "stddev": 3}, + ) + def test_truncated_normal(self, seed, shape, mean, stddev): + np.random.seed(seed) + np_res = np.random.normal(loc=mean, scale=stddev, size=shape) + res = random.truncated_normal( + shape, mean=mean, stddev=stddev, seed=seed + ) + self.assertEqual(res.shape, tuple(shape)) + self.assertEqual(res.shape, np_res.shape) + self.assertLessEqual(ops.max(res), mean + 2 * stddev) + self.assertGreaterEqual(ops.max(res), mean - 2 * stddev) + + def test_dropout(self): + x = ops.ones((3, 5)) + self.assertAllClose(random.dropout(x, rate=0, seed=0), x) + x_res = random.dropout(x, rate=0.8, seed=0) + self.assertGreater(ops.max(x_res), ops.max(x)) + self.assertGreater(ops.sum(x_res == 0), 2) + + def test_dropout_noise_shape(self): + inputs = ops.ones((2, 3, 5, 7)) + x = random.dropout( + inputs, rate=0.3, noise_shape=[None, 3, 5, None], seed=0 + ) + self.assertEqual(x.shape, (2, 3, 5, 7)) + + def test_global_seed_generator(self): + # Check that unseeded RNG calls use and update global_rng_state() + + def random_numbers(seed): + rng_state = seed_generator.global_seed_generator().state + rng_state.assign(seed) + x = random.normal((), seed=None) + y = random.normal((), seed=None) + return x, y, rng_state.value + + if backend.backend() == "tensorflow": + import tensorflow as tf + + random_numbers = tf.function(jit_compile=True)(random_numbers) + + seed = ops.zeros((2,)) + seed0 = ops.convert_to_numpy(seed) + x1, y1, seed = random_numbers(seed) + x1 = ops.convert_to_numpy(x1) + y1 = ops.convert_to_numpy(y1) + seed1 = ops.convert_to_numpy(seed) + x2, y2, seed = random_numbers(seed) + x2 = ops.convert_to_numpy(x2) + y2 = ops.convert_to_numpy(y2) + seed2 = ops.convert_to_numpy(seed) + x3, y3, seed = random_numbers(seed) + x3 = ops.convert_to_numpy(x3) + y3 = ops.convert_to_numpy(y3) + seed3 = ops.convert_to_numpy(seed) + + self.assertNotEqual(seed0[1], seed1[1]) + self.assertNotEqual(seed1[1], seed2[1]) + self.assertNotEqual(seed2[1], seed3[1]) + + self.assertGreater(np.abs(x1 - y1), 1e-4) + self.assertGreater(np.abs(x1 - y1), 1e-4) + self.assertGreater(np.abs(x2 - y2), 1e-4) + self.assertGreater(np.abs(x3 - y3), 1e-4) + self.assertGreater(np.abs(x1 - x2), 1e-4) + self.assertGreater(np.abs(x1 - x3), 1e-4) + self.assertGreater(np.abs(x2 - x3), 1e-4) + self.assertGreater(np.abs(y1 - y2), 1e-4) + self.assertGreater(np.abs(y1 - y3), 1e-4) + self.assertGreater(np.abs(y2 - y3), 1e-4) + + seed_generator.global_seed_generator().state.assign(seed) + + def test_shuffle(self): + x = np.arange(100).reshape(10, 10) + + # Test axis=0 + y = random.shuffle(x, seed=0) + + self.assertFalse(np.all(x == ops.convert_to_numpy(y))) + self.assertAllClose(np.sum(x, axis=0), ops.sum(y, axis=0)) + self.assertNotAllClose(np.sum(x, axis=1), ops.sum(y, axis=1)) + + # Test axis=1 + y = random.shuffle(x, axis=1, seed=0) + + self.assertFalse(np.all(x == ops.convert_to_numpy(y))) + self.assertAllClose(np.sum(x, axis=1), ops.sum(y, axis=1)) + self.assertNotAllClose(np.sum(x, axis=0), ops.sum(y, axis=0)) + + @parameterized.parameters( + {"seed": 10, "shape": (5, 2), "alpha": 2.0, "dtype": "float16"}, + {"seed": 10, "shape": (2,), "alpha": 1.5, "dtype": "float32"}, + {"seed": 10, "shape": (2, 3), "alpha": 0.5, "dtype": "float32"}, + ) + def test_gamma(self, seed, shape, alpha, dtype): + values = random.gamma(shape, alpha=alpha, seed=seed, dtype=dtype) + self.assertEqual(ops.shape(values), shape) + self.assertEqual(backend.standardize_dtype(values.dtype), dtype) + self.assertGreater(np.min(ops.convert_to_numpy(values)), 0.0) + + @parameterized.parameters( + { + "seed": 10, + "shape": (5, 2), + "counts": 5e4, + "probabilities": 0.5, + "dtype": "float16", + }, + { + "seed": 10, + "shape": (2,), + "counts": 1e5, + "probabilities": 0.5, + "dtype": "float32", + }, + { + "seed": 10, + "shape": (2, 3), + "counts": [[1e5, 2e5, 3e5], [4e5, 5e5, 6e5]], + "probabilities": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], + "dtype": "float32", + }, + ) + def test_binomial(self, seed, shape, counts, probabilities, dtype): + set_random_seed(1337) + values = random.binomial( + shape=shape, + counts=counts, + probabilities=probabilities, + seed=seed, + dtype=dtype, + ) + self.assertEqual(ops.shape(values), shape) + self.assertEqual(backend.standardize_dtype(values.dtype), dtype) + + # The following test that ensures that the number of time + # each event occurs doesn't exceed the total input count specified + # by the user for that event. + # Hence, we do an element wise comparison between `counts` array + # and the (generated) `values` array. + values_np = ops.convert_to_numpy(values) + assert np.greater_equal(np.array(counts), values_np).all() + + # Following test computes the probabilities of each event + # by dividing number of times an event occurs (which is the generated + # value) by the corresponding value in the (total) counts array. + # and then makes sure that the computed probabilities approximate + # the input probabilities + generated_probabilities = values_np / np.array(counts) + probabilities = np.ones(shape) * np.array(probabilities) + self.assertAllClose( + probabilities, generated_probabilities, rtol=0.005, atol=0.005 + ) + + @parameterized.parameters( + { + "seed": 10, + "shape": (10000,), + "alpha": 3.0, + "beta": 2.0, + "dtype": "float16", + }, + { + "seed": 10, + "shape": (10000, 3), + "alpha": [[7.0, 0.5, 1.5]], + "beta": [[15.0, 0.9, 4.5]], + "dtype": "float32", + }, + { + "seed": 10, + "shape": (10000, 30), + "alpha": 1.0, + "beta": 1.0, + "dtype": "float32", + }, + ) + def test_beta(self, seed, shape, alpha, beta, dtype): + set_random_seed(1337) + values = random.beta( + shape=shape, alpha=alpha, beta=beta, seed=seed, dtype=dtype + ) + self.assertEqual(ops.shape(values), shape) + self.assertEqual(backend.standardize_dtype(values.dtype), dtype) + values_np = ops.convert_to_numpy(values) + self.assertGreaterEqual(np.min(values_np), b=0.0) + self.assertLessEqual(np.max(values_np), b=1.0) + + _alpha_is_an_array = False + if isinstance(alpha, list): + alpha = np.array(alpha) + beta = np.array(beta) + _alpha_is_an_array = True + + # Mean check: + # For a beta distributed random variable, + # mean = alpha / (alpha + beta) + expected_mean = alpha / (alpha + beta) + + if _alpha_is_an_array: + actual_mean = np.mean(values_np, axis=0) + self.assertAllClose( + expected_mean.flatten(), actual_mean, atol=0.005, rtol=0.005 + ) + else: + actual_mean = np.mean(values_np.flatten()) + self.assertAlmostEqual(expected_mean, actual_mean, decimal=2) + + # Variance check: + # For a beta distributed random variable, + # variance = (alpha * beta) / ((alpha + beta)^2)(alpha + beta + 1) + expected_variance = (alpha * beta) / ( + np.square(alpha + beta) * (alpha + beta + 1) + ) + if _alpha_is_an_array: + actual_variance = np.var(values_np, axis=0) + self.assertAllClose( + expected_variance.flatten(), + actual_variance, + atol=0.005, + rtol=0.005, + ) + else: + actual_variance = np.var(values_np.flatten()) + self.assertAlmostEqual( + expected_variance, actual_variance, decimal=2 + ) + + +class RandomBehaviorTest(testing.TestCase): + def test_beta_tf_data_compatibility(self): + import tensorflow as tf + + from keras.src.layers.preprocessing.data_layer import DataLayer + from keras.src.random.seed_generator import SeedGenerator + + class BetaLayer(DataLayer): + def __init__(self, seed=None, **kwargs): + super().__init__(**kwargs) + self.seed = seed + self.generator = SeedGenerator(seed) + + def compute_output_shape(self, input_shape): + return input_shape + + def call(self, inputs): + seed_generator = self._get_seed_generator(self.backend._backend) + noise = self.backend.random.beta( + self.backend.shape(inputs), + alpha=0.5, + beta=0.5, + seed=seed_generator, + ) + inputs = inputs + noise + return inputs + + layer = BetaLayer() + input_data = np.random.random([2, 4, 4, 3]) + ds = tf.data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) + for output in ds.take(1): + output = ops.convert_to_numpy(output) + self.assertEqual(output.shape, (2, 4, 4, 3)) + + def test_categorical_errors(self): + with self.assertRaises(ValueError): + random.categorical(np.ones((5,)), 5) + with self.assertRaises(ValueError): + random.categorical(np.ones((5, 5, 5)), 5) + + def test_randint_dtype_validation(self): + with self.assertRaisesRegex( + ValueError, "`keras.random.randint` requires an integer `dtype`." + ): + random.randint((3, 4), minval=0, maxval=10, dtype="float64") + + def test_uniform_dtype_validation(self): + with self.assertRaisesRegex( + ValueError, + "`keras.random.uniform` requires a floating point `dtype`.", + ): + random.uniform((3, 4), minval=0, maxval=10, dtype="int64") + + @pytest.mark.skipif( + keras.backend.backend() != "jax", + reason="This test requires `jax` as the backend.", + ) + def test_dropout_jax_jit_stateless(self): + import jax + import jax.numpy as jnp + + x = ops.ones(3) + + @jax.jit + def train_step(x): + with keras.src.backend.StatelessScope(): + x = keras.layers.Dropout(rate=0.1)(x, training=True) + return x + + x = train_step(x) + self.assertIsInstance(x, jnp.ndarray) + + @pytest.mark.skipif( + keras.backend.backend() != "jax", + reason="This test requires `jax` as the backend.", + ) + def test_jax_rngkey_seed(self): + import jax + import jax.numpy as jnp + + seed = 1234 + rng = jax.random.PRNGKey(seed) + self.assertEqual(rng.shape, (2,)) + self.assertEqual(rng.dtype, jnp.uint32) + x = random.randint((3, 5), 0, 10, seed=rng) + self.assertIsInstance(x, jnp.ndarray) + + @pytest.mark.skipif( + keras.backend.backend() != "jax", + reason="This test requires `jax` as the backend.", + ) + def test_jax_unseed_disallowed_during_tracing(self): + import jax + + @jax.jit + def jit_fn(): + return random.randint((2, 2), 0, 10, seed=None) + + with self.assertRaisesRegex( + ValueError, "you should only use seeded random ops" + ): + jit_fn() + + @pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="This test requires `tensorflow` as the backend.", + ) + def test_tf_cast_seed(self): + import tensorflow as tf + + inputs = tf.ones([2, 3], dtype="float32") + seed = tf.int32.max + 1000 # Test floormod operation + outputs_mod = random.categorical(inputs, 2, seed=seed) + outputs_nomod = random.categorical(inputs, 2, seed=1001) + self.assertAllClose(outputs_mod, outputs_nomod) + + +class RandomDTypeTest(testing.TestCase): + """Test the dtype to verify that the behavior matches JAX.""" + + INT_DTYPES = [x for x in dtypes.INT_TYPES if x not in ("uint64", "int64")] + FLOAT_DTYPES = [x for x in dtypes.FLOAT_TYPES if x not in ("float64",)] + if backend.backend() == "torch": + INT_DTYPES = [x for x in INT_DTYPES if x not in ("uint16", "uint32")] + + @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES)) + def test_normal(self, dtype): + res = random.normal((2, 3), dtype=dtype) + self.assertEqual(standardize_dtype(res.dtype), dtype) + + @parameterized.named_parameters(named_product(dtype=INT_DTYPES)) + def test_categorical(self, dtype): + logits = np.eye(4) * 1e5 + 1e6 + res = random.categorical(logits, 10, dtype=dtype) + self.assertEqual(standardize_dtype(res.dtype), dtype) + + @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES)) + def test_uniform(self, dtype): + res = random.uniform((2, 3), dtype=dtype) + self.assertEqual(standardize_dtype(res.dtype), dtype) + + @parameterized.named_parameters(named_product(dtype=INT_DTYPES)) + def test_randint(self, dtype): + res = random.randint((2, 3), 0, 10, dtype=dtype) + self.assertEqual(standardize_dtype(res.dtype), dtype) + + @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES)) + def test_truncated_normal(self, dtype): + res = random.truncated_normal((2, 3), dtype=dtype) + self.assertEqual(standardize_dtype(res.dtype), dtype) + + @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES)) + def test_dropout(self, dtype): + x = ops.ones((3, 5), dtype=dtype) + res = random.dropout(x, rate=0.8, seed=0) + self.assertEqual(standardize_dtype(res.dtype), dtype) + + @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES)) + def test_gamma(self, dtype): + res = random.gamma((2, 3), 2.0, dtype=dtype) + self.assertEqual(standardize_dtype(res.dtype), dtype) + + @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES)) + def test_binomial(self, dtype): + res = random.binomial((2,), 1e5, 0.5, dtype=dtype) + self.assertEqual(standardize_dtype(res.dtype), dtype) + + @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES)) + def test_beta(self, dtype): + res = random.beta((2, 3), 2.0, 3.0, dtype=dtype) + self.assertEqual(standardize_dtype(res.dtype), dtype) diff --git a/keras/src/random/seed_generator.py b/keras/src/random/seed_generator.py new file mode 100644 index 000000000000..dd2adbc13bbe --- /dev/null +++ b/keras/src/random/seed_generator.py @@ -0,0 +1,161 @@ +import random as python_random + +import numpy as np + +from keras.src import backend +from keras.src.api_export import keras_export +from keras.src.backend.common import global_state +from keras.src.utils import jax_utils +from keras.src.utils.naming import auto_name + + +@keras_export("keras.random.SeedGenerator") +class SeedGenerator: + """Generates variable seeds upon each call to a function generating + random numbers. + + In Keras, all random number generators (such as + `keras.random.normal()`) are stateless, meaning that if you pass an + integer seed to them (such as `seed=42`), they will return the same + values for repeated calls. To get different values for each + call, a `SeedGenerator` providing the state of the random generator + has to be used. + + Note that all the random number generators have a default seed of None, + which implies that an internal global SeedGenerator is used. + If you need to decouple the RNG from the global state you can provide + a local `StateGenerator` with either a deterministic or random initial + state. + + Remark concerning the JAX backen: Note that the use of a local + `StateGenerator` as seed argument is required for JIT compilation of + RNG with the JAX backend, because the use of global state is not + supported. + + Example: + + ```python + seed_gen = keras.random.SeedGenerator(seed=42) + values = keras.random.normal(shape=(2, 3), seed=seed_gen) + new_values = keras.random.normal(shape=(2, 3), seed=seed_gen) + ``` + + Usage in a layer: + + ```python + class Dropout(keras.Layer): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.seed_generator = keras.random.SeedGenerator(1337) + + def call(self, x, training=False): + if training: + return keras.random.dropout( + x, rate=0.5, seed=self.seed_generator + ) + return x + ``` + """ + + def __init__(self, seed=None, name=None, **kwargs): + if name is None: + name = auto_name(self.__class__.__name__) + self.name = name + + custom_backend = kwargs.pop("backend", None) + if kwargs: + raise ValueError(f"Unrecognized keyword arguments: {kwargs}") + if custom_backend is not None: + self.backend = custom_backend + else: + self.backend = backend + + self._initial_seed = seed + if seed is None: + seed = make_default_seed() + + if not isinstance(seed, int): + raise ValueError( + f"Argument `seed` must be an integer. Received: seed={seed}" + ) + + def seed_initializer(*args, **kwargs): + dtype = kwargs.get("dtype", None) + return self.backend.convert_to_tensor([seed, 0], dtype=dtype) + + with self.backend.name_scope(self.name, caller=self): + self.state = self.backend.Variable( + seed_initializer, + shape=(2,), + dtype=self.backend.random_seed_dtype(), + trainable=False, + aggregation="none", + name="seed_generator_state", + ) + + def next(self, ordered=True): + seed_state = self.state + # Use * 1 to create a copy + new_seed_value = seed_state.value * 1 + if ordered: + increment = self.backend.convert_to_tensor( + np.array([0, 1]), dtype=seed_state.dtype + ) + self.state.assign(self.backend.numpy.add(seed_state, increment)) + else: + # This produces a sequence of near-unique numbers + # between 0 and 1M + self.state.assign((seed_state + 1) * 5387 % 933199) + return new_seed_value + + def get_config(self): + return {"seed": self._initial_seed} + + @classmethod + def from_config(cls, config): + return cls(**config) + + +def global_seed_generator(): + if jax_utils.is_in_jax_tracing_scope(): + raise ValueError( + "[JAX RNG] When tracing a JAX function, " + "you should only use seeded random ops, e.g. " + "you should create a `SeedGenerator` instance, attach it " + "to your layer/model, and pass the instance as the `seed` " + "argument when calling random ops. Unseeded random ops " + "would get incorrectly traced by JAX and would become constant " + "after tracing. Example:\n\n" + "```\n" + "# Make sure to set the seed generator as a layer attribute\n" + "self.seed_generator = keras.random.SeedGenerator(seed=1337)\n" + "...\n" + "out = keras.random.normal(shape=(1,), seed=self.seed_generator)\n" + "```" + ) + gen = global_state.get_global_attribute("global_seed_generator") + if gen is None: + gen = SeedGenerator() + global_state.set_global_attribute("global_seed_generator", gen) + return gen + + +def make_default_seed(): + return python_random.randint(1, int(1e9)) + + +def draw_seed(seed): + from keras.src.backend import convert_to_tensor + from keras.src.backend import random_seed_dtype + + if isinstance(seed, SeedGenerator): + return seed.next() + elif isinstance(seed, int): + return convert_to_tensor([seed, 0], dtype=random_seed_dtype()) + elif seed is None: + return global_seed_generator().next(ordered=False) + raise ValueError( + "Argument `seed` must be either an integer " + "or an instance of `SeedGenerator`. " + f"Received: seed={seed} (of type {type(seed)})" + ) diff --git a/keras/src/random/seed_generator_test.py b/keras/src/random/seed_generator_test.py new file mode 100644 index 000000000000..d1101e0a871a --- /dev/null +++ b/keras/src/random/seed_generator_test.py @@ -0,0 +1,95 @@ +import numpy as np +import pytest + +from keras.src import backend +from keras.src import ops +from keras.src import testing +from keras.src.random import seed_generator + + +class SeedGeneratorTest(testing.TestCase): + def test_seed_generator_initialization(self): + gen = seed_generator.SeedGenerator() + self.assertIsNotNone(gen.state) + + seed = 12345 + gen = seed_generator.SeedGenerator(seed=seed) + self.assertEqual(ops.convert_to_numpy(gen.state)[0], seed) + + with self.assertRaisesRegex( + ValueError, "Argument `seed` must be an integer" + ): + seed_generator.SeedGenerator(seed="invalid_seed") + + def test_seed_generator_next(self): + gen = seed_generator.SeedGenerator(seed=42) + seed1 = ops.convert_to_numpy(gen.next()) + seed2 = ops.convert_to_numpy(gen.next()) + self.assertFalse(np.array_equal(seed1, seed2)) + + def test_global_seed_generator(self): + gen1 = seed_generator.global_seed_generator() + gen2 = seed_generator.global_seed_generator() + self.assertEqual(gen1, gen2) + + def test_make_default_seed(self): + seed1 = seed_generator.make_default_seed() + seed2 = seed_generator.make_default_seed() + self.assertNotEqual(seed1, seed2) + + def test_seed_generator_dtype(self): + gen = seed_generator.SeedGenerator(seed=42) + self.assertEqual(gen.state.dtype, backend.random_seed_dtype()) + seed = gen.next() + self.assertEqual(gen.state.dtype, backend.random_seed_dtype()) + self.assertEqual( + backend.standardize_dtype(seed.dtype), backend.random_seed_dtype() + ) + + def test_draw_seed_from_seed_generator(self): + gen = seed_generator.SeedGenerator(seed=42) + seed1 = seed_generator.draw_seed(gen) + self.assertTrue(backend.is_tensor(seed1)) + + def test_draw_seed_from_integer(self): + seed2 = seed_generator.draw_seed(12345) + self.assertTrue(backend.is_tensor(seed2)) + self.assertEqual( + backend.standardize_dtype(seed2.dtype), backend.random_seed_dtype() + ) + + def test_draw_seed_from_none(self): + seed3 = seed_generator.draw_seed(None) + self.assertTrue(backend.is_tensor(seed3)) + + def test_draw_seed_invalid(self): + with self.assertRaisesRegex( + ValueError, "Argument `seed` must be either an integer" + ): + seed_generator.draw_seed("invalid_seed") + + def test_seed_generator_unexpected_kwargs(self): + with self.assertRaisesRegex( + ValueError, "Unrecognized keyword arguments" + ): + seed_generator.SeedGenerator(invalid_arg="unexpected_value") + + @pytest.mark.skipif( + backend.backend() != "jax", reason="This test requires the JAX backend" + ) + def test_jax_tracing_with_global_seed_generator(self): + import jax + + @jax.jit + def traced_function(): + return seed_generator.global_seed_generator().next() + + with self.assertRaisesRegex( + ValueError, + "When tracing a JAX function, you should only use seeded random", + ): + traced_function() + + def test_seed_generator_serialization(self): + random_generator = seed_generator.SeedGenerator(seed=42) + self.run_class_serialization_test(random_generator) diff --git a/keras/src/regularizers/__init__.py b/keras/src/regularizers/__init__.py new file mode 100644 index 000000000000..c40bd6ab4549 --- /dev/null +++ b/keras/src/regularizers/__init__.py @@ -0,0 +1,60 @@ +import inspect + +from keras.src.api_export import keras_export +from keras.src.regularizers.regularizers import L1 +from keras.src.regularizers.regularizers import L1L2 +from keras.src.regularizers.regularizers import L2 +from keras.src.regularizers.regularizers import OrthogonalRegularizer +from keras.src.regularizers.regularizers import Regularizer +from keras.src.saving import serialization_lib +from keras.src.utils.naming import to_snake_case + +ALL_OBJECTS = { + Regularizer, + L1, + L2, + L1L2, + OrthogonalRegularizer, +} + +ALL_OBJECTS_DICT = {cls.__name__: cls for cls in ALL_OBJECTS} +ALL_OBJECTS_DICT.update( + {to_snake_case(cls.__name__): cls for cls in ALL_OBJECTS} +) + + +@keras_export("keras.regularizers.serialize") +def serialize(regularizer): + return serialization_lib.serialize_keras_object(regularizer) + + +@keras_export("keras.regularizers.deserialize") +def deserialize(config, custom_objects=None): + """Return a Keras regularizer object via its config.""" + return serialization_lib.deserialize_keras_object( + config, + module_objects=ALL_OBJECTS_DICT, + custom_objects=custom_objects, + ) + + +@keras_export("keras.regularizers.get") +def get(identifier): + """Retrieve a Keras regularizer object via an identifier.""" + if identifier is None: + return None + if isinstance(identifier, dict): + obj = deserialize(identifier) + elif isinstance(identifier, str): + obj = ALL_OBJECTS_DICT.get(identifier, None) + else: + obj = identifier + + if callable(obj): + if inspect.isclass(obj): + obj = obj() + return obj + else: + raise ValueError( + f"Could not interpret regularizer identifier: {identifier}" + ) diff --git a/keras/src/regularizers/regularizers.py b/keras/src/regularizers/regularizers.py new file mode 100644 index 000000000000..99459fe32fb7 --- /dev/null +++ b/keras/src/regularizers/regularizers.py @@ -0,0 +1,352 @@ +import math + +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.utils.numerical_utils import normalize + + +@keras_export(["keras.Regularizer", "keras.regularizers.Regularizer"]) +class Regularizer: + """Regularizer base class. + + Regularizers allow you to apply penalties on layer parameters or layer + activity during optimization. These penalties are summed into the loss + function that the network optimizes. + + Regularization penalties are applied on a per-layer basis. The exact API + will depend on the layer, but many layers (e.g. `Dense`, `Conv1D`, `Conv2D` + and `Conv3D`) have a unified API. + + These layers expose 3 keyword arguments: + + - `kernel_regularizer`: Regularizer to apply a penalty on the layer's kernel + - `bias_regularizer`: Regularizer to apply a penalty on the layer's bias + - `activity_regularizer`: Regularizer to apply a penalty on the layer's + output + + All layers (including custom layers) expose `activity_regularizer` as a + settable property, whether or not it is in the constructor arguments. + + The value returned by the `activity_regularizer` is divided by the input + batch size so that the relative weighting between the weight regularizers + and the activity regularizers does not change with the batch size. + + You can access a layer's regularization penalties by calling `layer.losses` + after calling the layer on inputs. + + ## Example + + >>> layer = Dense( + ... 5, input_dim=5, + ... kernel_initializer='ones', + ... kernel_regularizer=L1(0.01), + ... activity_regularizer=L2(0.01)) + >>> tensor = ops.ones(shape=(5, 5)) * 2.0 + >>> out = layer(tensor) + + >>> # The kernel regularization term is 0.25 + >>> # The activity regularization term (after dividing by the batch size) + >>> # is 5 + >>> ops.sum(layer.losses) + 5.25 + + ## Available penalties + + ```python + L1(0.3) # L1 Regularization Penalty + L2(0.1) # L2 Regularization Penalty + L1L2(l1=0.01, l2=0.01) # L1 + L2 penalties + ``` + + ## Directly calling a regularizer + + Compute a regularization loss on a tensor by directly calling a regularizer + as if it is a one-argument function. + + E.g. + + >>> regularizer = L2(2.) + >>> tensor = ops.ones(shape=(5, 5)) + >>> regularizer(tensor) + 50.0 + + ## Developing new regularizers + + Any function that takes in a weight matrix and returns a scalar + tensor can be used as a regularizer, e.g.: + + >>> def l1_reg(weight_matrix): + ... return 0.01 * ops.sum(ops.absolute(weight_matrix)) + ... + >>> layer = Dense(5, input_dim=5, + ... kernel_initializer='ones', kernel_regularizer=l1_reg) + >>> tensor = ops.ones(shape=(5, 5)) + >>> out = layer(tensor) + >>> layer.losses + 0.25 + + Alternatively, you can write your custom regularizers in an + object-oriented way by extending this regularizer base class, e.g.: + + >>> class L2Regularizer(Regularizer): + ... def __init__(self, l2=0.): + ... self.l2 = l2 + ... + ... def __call__(self, x): + ... return self.l2 * ops.sum(ops.square(x)) + ... + ... def get_config(self): + ... return {'l2': float(self.l2)} + ... + >>> layer = Dense( + ... 5, input_dim=5, kernel_initializer='ones', + ... kernel_regularizer=L2Regularizer(l2=0.5)) + + >>> tensor = ops.ones(shape=(5, 5)) + >>> out = layer(tensor) + >>> layer.losses + 12.5 + + ### A note on serialization and deserialization: + + Registering the regularizers as serializable is optional if you are just + training and executing models, exporting to and from SavedModels, or saving + and loading weight checkpoints. + + Registration is required for saving and + loading models to HDF5 format, Keras model cloning, some visualization + utilities, and exporting models to and from JSON. If using this + functionality, you must make sure any python process running your model has + also defined and registered your custom regularizer. + """ + + def __call__(self, x): + """Compute a regularization penalty from an input tensor.""" + return 0.0 + + @classmethod + def from_config(cls, config): + """Creates a regularizer from its config. + + This method is the reverse of `get_config`, + capable of instantiating the same regularizer from the config + dictionary. + + This method is used by Keras `model_to_estimator`, saving and + loading models to HDF5 formats, Keras model cloning, some visualization + utilities, and exporting models to and from JSON. + + Args: + config: A Python dictionary, typically the output of get_config. + + Returns: + A regularizer instance. + """ + return cls(**config) + + def get_config(self): + """Returns the config of the regularizer. + + An regularizer config is a Python dictionary (serializable) + containing all configuration parameters of the regularizer. + The same regularizer can be reinstantiated later + (without any saved state) from this configuration. + + This method is optional if you are just training and executing models, + exporting to and from SavedModels, or using weight checkpoints. + + This method is required for Keras `model_to_estimator`, saving and + loading models to HDF5 formats, Keras model cloning, some visualization + utilities, and exporting models to and from JSON. + + Returns: + Python dictionary. + """ + raise NotImplementedError(f"{self} does not implement get_config()") + + +@keras_export(["keras.regularizers.L1L2", "keras.regularizers.l1_l2"]) +class L1L2(Regularizer): + """A regularizer that applies both L1 and L2 regularization penalties. + + The L1 regularization penalty is computed as: + `loss = l1 * reduce_sum(abs(x))` + + The L2 regularization penalty is computed as + `loss = l2 * reduce_sum(square(x))` + + L1L2 may be passed to a layer as a string identifier: + + >>> dense = Dense(3, kernel_regularizer='l1_l2') + + In this case, the default values used are `l1=0.01` and `l2=0.01`. + + Arguments: + l1: float, L1 regularization factor. + l2: float, L2 regularization factor. + """ + + def __init__(self, l1=0.0, l2=0.0): + # The default value for l1 and l2 are different from the value in l1_l2 + # for backward compatibility reason. Eg, L1L2(l2=0.1) will only have l2 + # and no l1 penalty. + l1 = 0.0 if l1 is None else l1 + l2 = 0.0 if l2 is None else l2 + validate_float_arg(l1, name="l1") + validate_float_arg(l2, name="l2") + + self.l1 = l1 + self.l2 = l2 + + def __call__(self, x): + regularization = ops.convert_to_tensor(0.0, dtype=x.dtype) + if self.l1: + regularization += self.l1 * ops.sum(ops.absolute(x)) + if self.l2: + regularization += self.l2 * ops.sum(ops.square(x)) + return regularization + + def get_config(self): + return {"l1": float(self.l1), "l2": float(self.l2)} + + +@keras_export(["keras.regularizers.L1", "keras.regularizers.l1"]) +class L1(Regularizer): + """A regularizer that applies a L1 regularization penalty. + + The L1 regularization penalty is computed as: + `loss = l1 * reduce_sum(abs(x))` + + L1 may be passed to a layer as a string identifier: + + >>> dense = Dense(3, kernel_regularizer='l1') + + In this case, the default value used is `l1=0.01`. + + Arguments: + l1: float, L1 regularization factor. + """ + + def __init__(self, l1=0.01): + l1 = 0.01 if l1 is None else l1 + validate_float_arg(l1, name="l1") + self.l1 = ops.convert_to_tensor(l1) + + def __call__(self, x): + return self.l1 * ops.sum(ops.absolute(x)) + + def get_config(self): + return {"l1": float(self.l1)} + + +@keras_export(["keras.regularizers.L2", "keras.regularizers.l2"]) +class L2(Regularizer): + """A regularizer that applies a L2 regularization penalty. + + The L2 regularization penalty is computed as: + `loss = l2 * reduce_sum(square(x))` + + L2 may be passed to a layer as a string identifier: + + >>> dense = Dense(3, kernel_regularizer='l2') + + In this case, the default value used is `l2=0.01`. + + Arguments: + l2: float, L2 regularization factor. + """ + + def __init__(self, l2=0.01): + l2 = 0.01 if l2 is None else l2 + validate_float_arg(l2, name="l2") + self.l2 = l2 + + def __call__(self, x): + return self.l2 * ops.sum(ops.square(x)) + + def get_config(self): + return {"l2": float(self.l2)} + + +@keras_export( + [ + "keras.regularizers.OrthogonalRegularizer", + "keras.regularizers.orthogonal_regularizer", + ] +) +class OrthogonalRegularizer(Regularizer): + """Regularizer that encourages input vectors to be orthogonal to each other. + + It can be applied to either the rows of a matrix (`mode="rows"`) or its + columns (`mode="columns"`). When applied to a `Dense` kernel of shape + `(input_dim, units)`, rows mode will seek to make the feature vectors + (i.e. the basis of the output space) orthogonal to each other. + + Arguments: + factor: Float. The regularization factor. The regularization penalty + will be proportional to `factor` times the mean of the dot products + between the L2-normalized rows (if `mode="rows"`, or columns if + `mode="columns"`) of the inputs, excluding the product of each + row/column with itself. Defaults to `0.01`. + mode: String, one of `{"rows", "columns"}`. Defaults to `"rows"`. In + rows mode, the regularization effect seeks to make the rows of the + input orthogonal to each other. In columns mode, it seeks to make + the columns of the input orthogonal to each other. + + Example: + + >>> regularizer = OrthogonalRegularizer(factor=0.01) + >>> layer = Dense(units=4, kernel_regularizer=regularizer) + """ + + def __init__(self, factor=0.01, mode="rows"): + validate_float_arg(factor, name="factor") + self.factor = ops.convert_to_tensor(factor) + if mode not in {"rows", "columns"}: + raise ValueError( + "Invalid value for argument `mode`. Expected one of " + f'{{"rows", "columns"}}. Received: mode={mode}' + ) + self.mode = mode + + def __call__(self, inputs): + if len(inputs.shape) != 2: + raise ValueError( + "Inputs to OrthogonalRegularizer must have rank 2. Received: " + f"inputs.shape={inputs.shape}" + ) + if self.mode == "rows": + inputs = normalize(inputs, axis=1) + product = ops.matmul(inputs, ops.transpose(inputs)) + size = inputs.shape[0] + else: + inputs = normalize(inputs, axis=0) + product = ops.matmul(ops.transpose(inputs), inputs) + size = inputs.shape[1] + product_no_diagonal = product * ( + 1.0 - ops.eye(size, dtype=inputs.dtype) + ) + num_pairs = size * (size - 1.0) / 2.0 + return ( + self.factor + * 0.5 + * ops.sum(ops.absolute(product_no_diagonal)) + / num_pairs + ) + + def get_config(self): + return {"factor": float(self.factor), "mode": self.mode} + + +def validate_float_arg(value, name): + """check penalty number availability, raise ValueError if failed.""" + if ( + not isinstance(value, (float, int)) + or (math.isinf(value) or math.isnan(value)) + or value < 0 + ): + raise ValueError( + f"Invalid value for argument {name}: expected a non-negative float." + f"Received: {name}={value}" + ) + return float(value) diff --git a/keras/src/regularizers/regularizers_test.py b/keras/src/regularizers/regularizers_test.py new file mode 100644 index 000000000000..36141f54f772 --- /dev/null +++ b/keras/src/regularizers/regularizers_test.py @@ -0,0 +1,165 @@ +import numpy as np + +from keras.src import backend +from keras.src import regularizers +from keras.src import testing +from keras.src.regularizers.regularizers import validate_float_arg + + +class RegularizersTest(testing.TestCase): + def test_config(self): + reg = regularizers.L1(0.1) + self.run_class_serialization_test(reg) + + reg = regularizers.L2(0.1) + self.run_class_serialization_test(reg) + + reg = regularizers.L1L2(l1=0.1, l2=0.2) + self.run_class_serialization_test(reg) + + reg = regularizers.OrthogonalRegularizer(factor=0.1, mode="rows") + self.run_class_serialization_test(reg) + + def test_l1(self): + value = np.random.random((4, 4)).astype(np.float32) + x = backend.Variable(value) + y = regularizers.L1(0.1)(x) + self.assertAllClose(y, 0.1 * np.sum(np.abs(value))) + + def test_l2(self): + value = np.random.random((4, 4)).astype(np.float32) + x = backend.Variable(value) + y = regularizers.L2(0.1)(x) + self.assertAllClose(y, 0.1 * np.sum(np.square(value))) + + def test_l1_l2(self): + value = np.random.random((4, 4)).astype(np.float32) + x = backend.Variable(value) + y = regularizers.L1L2(l1=0.1, l2=0.2)(x) + self.assertAllClose( + y, 0.1 * np.sum(np.abs(value)) + 0.2 * np.sum(np.square(value)) + ) + + def test_orthogonal_regularizer(self): + value = np.random.random((4, 4)).astype(np.float32) + x = backend.Variable(value) + y = regularizers.OrthogonalRegularizer(factor=0.1, mode="rows")(x) + + l2_norm = np.linalg.norm(value, axis=1, keepdims=True) + inputs = value / l2_norm + self.assertAllClose( + y, + 0.1 + * 0.5 + * np.sum( + np.abs(np.dot(inputs, np.transpose(inputs)) * (1.0 - np.eye(4))) + ) + / (4.0 * (4.0 - 1.0) / 2.0), + ) + + def test_get_method(self): + obj = regularizers.get("l1l2") + self.assertIsInstance(obj, regularizers.L1L2) + + obj = regularizers.get("l1") + self.assertIsInstance(obj, regularizers.L1) + + obj = regularizers.get("l2") + self.assertIsInstance(obj, regularizers.L2) + + obj = regularizers.get("orthogonal_regularizer") + self.assertIsInstance(obj, regularizers.OrthogonalRegularizer) + + obj = regularizers.get(None) + self.assertEqual(obj, None) + + with self.assertRaises(ValueError): + regularizers.get("typo") + + def test_l1l2_get_config(self): + l1 = 0.01 + l2 = 0.02 + reg = regularizers.L1L2(l1=l1, l2=l2) + config = reg.get_config() + + self.assertEqual(config, {"l1": l1, "l2": l2}) + + reg_from_config = regularizers.L1L2.from_config(config) + config_from_config = reg_from_config.get_config() + + self.assertDictEqual(config, config_from_config) + self.assertEqual(reg_from_config.l1, l1) + self.assertEqual(reg_from_config.l2, l2) + + def test_orthogonal_regularizer_mode_validation(self): + with self.assertRaises(ValueError) as context: + regularizers.OrthogonalRegularizer(factor=0.01, mode="invalid_mode") + + expected_message = ( + 'Invalid value for argument `mode`. Expected one of {"rows", ' + '"columns"}. Received: mode=invalid_mode' + ) + self.assertEqual(str(context.exception), expected_message) + + def test_orthogonal_regularizer_input_rank_validation(self): + with self.assertRaises(ValueError) as context: + value = np.random.random((4, 4, 4)).astype(np.float32) + x = backend.Variable(value) + regularizers.OrthogonalRegularizer(factor=0.1)(x) + + expected_message = ( + "Inputs to OrthogonalRegularizer must have rank 2. " + f"Received: inputs.shape={(4, 4, 4)}" + ) + self.assertEqual(str(context.exception), expected_message) + + def test_orthogonal_regularizer_get_config(self): + factor = 0.01 + mode = "columns" + regularizer = regularizers.OrthogonalRegularizer( + factor=factor, mode=mode + ) + config = regularizer.get_config() + + self.assertAlmostEqual(config["factor"], factor, 7) + self.assertEqual(config["mode"], mode) + + reg_from_config = regularizers.OrthogonalRegularizer.from_config(config) + config_from_config = reg_from_config.get_config() + + self.assertAlmostEqual(config_from_config["factor"], factor, 7) + self.assertEqual(config_from_config["mode"], mode) + + +class ValidateFloatArgTest(testing.TestCase): + def test_validate_float_with_valid_args(self): + self.assertEqual(validate_float_arg(1, "test"), 1.0) + self.assertEqual(validate_float_arg(1.0, "test"), 1.0) + + def test_validate_float_with_invalid_types(self): + with self.assertRaisesRegex( + ValueError, "expected a non-negative float" + ): + validate_float_arg("not_a_number", "test") + + def test_validate_float_with_nan(self): + with self.assertRaisesRegex( + ValueError, "expected a non-negative float" + ): + validate_float_arg(float("nan"), "test") + + def test_validate_float_with_inf(self): + with self.assertRaisesRegex( + ValueError, "expected a non-negative float" + ): + validate_float_arg(float("inf"), "test") + with self.assertRaisesRegex( + ValueError, "expected a non-negative float" + ): + validate_float_arg(-float("inf"), "test") + + def test_validate_float_with_negative_number(self): + with self.assertRaisesRegex( + ValueError, "expected a non-negative float" + ): + validate_float_arg(-1, "test") diff --git a/keras/src/saving/__init__.py b/keras/src/saving/__init__.py new file mode 100644 index 000000000000..3af25ce633af --- /dev/null +++ b/keras/src/saving/__init__.py @@ -0,0 +1,9 @@ +from keras.src.saving.object_registration import CustomObjectScope +from keras.src.saving.object_registration import custom_object_scope +from keras.src.saving.object_registration import get_custom_objects +from keras.src.saving.object_registration import get_registered_name +from keras.src.saving.object_registration import get_registered_object +from keras.src.saving.object_registration import register_keras_serializable +from keras.src.saving.saving_api import load_model +from keras.src.saving.serialization_lib import deserialize_keras_object +from keras.src.saving.serialization_lib import serialize_keras_object diff --git a/keras/src/saving/file_editor.py b/keras/src/saving/file_editor.py new file mode 100644 index 000000000000..b486590f2132 --- /dev/null +++ b/keras/src/saving/file_editor.py @@ -0,0 +1,819 @@ +import collections +import json +import os.path +import pprint +import zipfile + +import h5py +import numpy as np +import rich.console + +from keras.src import backend +from keras.src.api_export import keras_export +from keras.src.saving import saving_lib +from keras.src.saving.saving_lib import H5IOStore +from keras.src.utils import naming +from keras.src.utils import summary_utils + +try: + import IPython as ipython +except ImportError: + ipython = None + + +def is_ipython_notebook(): + """Checks if the code is being executed in a notebook.""" + try: + from IPython import get_ipython + + # Check if an active IPython shell exists. + if get_ipython() is not None: + return True + return False + except ImportError: + return False + + +@keras_export("keras.saving.KerasFileEditor") +class KerasFileEditor: + """Utility to inspect, edit, and resave Keras weights files. + + You will find this class useful when adapting + an old saved weights file after having made + architecture changes to a model. + + Args: + filepath: The path to a local file to inspect and edit. + + Examples: + + ```python + editor = KerasFileEditor("my_model.weights.h5") + + # Displays current contents + editor.summary() + + # Remove the weights of an existing layer + editor.delete_object("layers/dense_2") + + # Add the weights of a new layer + editor.add_object("layers/einsum_dense", weights={"0": ..., "1": ...}) + + # Save the weights of the edited model + editor.resave_weights("edited_model.weights.h5") + ``` + """ + + def __init__( + self, + filepath, + ): + self.filepath = filepath + self.metadata = None + self.config = None + self.model = None + self.console = rich.console.Console(highlight=False) + + if filepath.endswith(".keras"): + zf = zipfile.ZipFile(filepath, "r") + weights_store = H5IOStore( + f"{saving_lib._VARS_FNAME}.h5", + archive=zf, + mode="r", + ) + with zf.open(saving_lib._CONFIG_FILENAME, "r") as f: + config_json = f.read() + with zf.open(saving_lib._METADATA_FILENAME, "r") as f: + metadata_json = f.read() + self.config = json.loads(config_json) + self.metadata = json.loads(metadata_json) + + elif filepath.endswith(".weights.h5"): + weights_store = H5IOStore(filepath, mode="r") + else: + raise ValueError( + "Invalid filename: " + "expected a `.keras` `.weights.h5` extension. " + f"Received: filepath={filepath}" + ) + + weights_dict, object_metadata = self._extract_weights_from_store( + weights_store.h5_file + ) + weights_store.close() + self.weights_dict = weights_dict + self.object_metadata = object_metadata # {path: object_name} + self.console.print(self._generate_filepath_info(rich_style=True)) + + if self.metadata is not None: + self.console.print(self._generate_metadata_info(rich_style=True)) + + def summary(self): + """Prints the weight structure of the opened file.""" + self._weights_summary_cli() + + def compare(self, reference_model): + """Compares the opened file to a reference model. + + This method will list all mismatches between the + currently opened file and the provided reference model. + + Args: + reference_model: Model instance to compare to. + + Returns: + Dict with the following keys: + `'status'`, `'error_count'`, `'match_count'`. + Status can be `'success'` or `'error'`. + `'error_count'` is the number of mismatches found. + `'match_count'` is the number of matching weights found. + """ + self.console.print("Running comparison") + ref_spec = {} + get_weight_spec_of_saveable(reference_model, ref_spec) + + def _compare( + target, + ref_spec, + inner_path, + target_name, + ref_name, + error_count, + match_count, + checked_paths, + ): + base_inner_path = inner_path + for ref_key, ref_val in ref_spec.items(): + inner_path = f"{base_inner_path}/{ref_key}" + if inner_path in checked_paths: + continue + + if ref_key not in target: + error_count += 1 + checked_paths.add(inner_path) + if isinstance(ref_val, dict): + self.console.print( + f"[color(160)]...Object [bold]{inner_path}[/] " + f"present in {ref_name}, " + f"missing from {target_name}[/]" + ) + self.console.print( + f" In {ref_name}, {inner_path} contains " + f"the following keys: {list(ref_val.keys())}" + ) + else: + self.console.print( + f"[color(160)]...Weight [bold]{inner_path}[/] " + f"present in {ref_name}, " + f"missing from {target_name}[/]" + ) + elif isinstance(ref_val, dict): + _error_count, _match_count = _compare( + target[ref_key], + ref_spec[ref_key], + inner_path, + target_name, + ref_name, + error_count=error_count, + match_count=match_count, + checked_paths=checked_paths, + ) + error_count += _error_count + match_count += _match_count + else: + if target[ref_key].shape != ref_val.shape: + error_count += 1 + checked_paths.add(inner_path) + self.console.print( + f"[color(160)]...Weight shape mismatch " + f"for [bold]{inner_path}[/][/]\n" + f" In {ref_name}: " + f"shape={ref_val.shape}\n" + f" In {target_name}: " + f"shape={target[ref_key].shape}" + ) + else: + match_count += 1 + return error_count, match_count + + checked_paths = set() + error_count, match_count = _compare( + self.weights_dict, + ref_spec, + inner_path="", + target_name="saved file", + ref_name="reference model", + error_count=0, + match_count=0, + checked_paths=checked_paths, + ) + _error_count, _ = _compare( + ref_spec, + self.weights_dict, + inner_path="", + target_name="reference model", + ref_name="saved file", + error_count=0, + match_count=0, + checked_paths=checked_paths, + ) + error_count += _error_count + self.console.print("─────────────────────") + if error_count == 0: + status = "success" + self.console.print( + "[color(28)][bold]Comparison successful:[/] " + "saved file is compatible with the reference model[/]" + ) + if match_count == 1: + plural = "" + else: + plural = "s" + self.console.print( + f" Found {match_count} matching weight{plural}" + ) + else: + status = "error" + if error_count == 1: + plural = "" + else: + plural = "s" + self.console.print( + f"[color(160)][bold]Found {error_count} error{plural}:[/] " + "saved file is not compatible with the reference model[/]" + ) + return { + "status": status, + "error_count": error_count, + "match_count": match_count, + } + + def _edit_object(self, edit_fn, source_name, target_name=None): + if target_name is not None and "/" in target_name: + raise ValueError( + "Argument `target_name` should be a leaf name, " + "not a full path name. " + f"Received: target_name='{target_name}'" + ) + if "/" in source_name: + # It's a path + elements = source_name.split("/") + weights_dict = self.weights_dict + for e in elements[:-1]: + if e not in weights_dict: + raise ValueError( + f"Path '{source_name}' not found in model." + ) + weights_dict = weights_dict[e] + if elements[-1] not in weights_dict: + raise ValueError(f"Path '{source_name}' not found in model.") + edit_fn( + weights_dict, source_name=elements[-1], target_name=target_name + ) + else: + # Ensure unicity + def count_occurences(d, name, count=0): + for k in d: + if isinstance(d[k], dict): + count += count_occurences(d[k], name, count) + if name in d: + count += 1 + return count + + occurrences = count_occurences(self.weights_dict, source_name) + if occurrences > 1: + raise ValueError( + f"Name '{source_name}' occurs more than once in the model; " + "try passing a complete path" + ) + if occurrences == 0: + raise ValueError( + f"Source name '{source_name}' does not appear in the " + "model. Use `editor.weights_summary()` " + "to list all objects." + ) + + def _edit(d): + for k in d: + if isinstance(d[k], dict): + _edit(d[k]) + if source_name in d: + edit_fn(d, source_name=source_name, target_name=target_name) + + _edit(self.weights_dict) + + def rename_object(self, object_name, new_name): + """Rename an object in the file (e.g. a layer). + + Args: + object_name: String, name or path of the + object to rename (e.g. `"dense_2"` or + `"layers/dense_2"`). + new_name: String, new name of the object. + """ + + def rename_fn(weights_dict, source_name, target_name): + weights_dict[target_name] = weights_dict[source_name] + weights_dict.pop(source_name) + + self._edit_object(rename_fn, object_name, new_name) + + def delete_object(self, object_name): + """Removes an object from the file (e.g. a layer). + + Args: + object_name: String, name or path of the + object to delete (e.g. `"dense_2"` or + `"layers/dense_2"`). + """ + + def delete_fn(weights_dict, source_name, target_name=None): + weights_dict.pop(source_name) + + self._edit_object(delete_fn, object_name) + + def add_object(self, object_path, weights): + """Add a new object to the file (e.g. a layer). + + Args: + object_path: String, full path of the + object to add (e.g. `"layers/dense_2"`). + weights: Dict mapping weight names to weight + values (arrays), + e.g. `{"0": kernel_value, "1": bias_value}`. + """ + if not isinstance(weights, dict): + raise ValueError( + "Argument `weights` should be a dict " + "where keys are weight names (usually '0', '1', etc.) " + "and values are NumPy arrays. " + f"Received: type(weights)={type(weights)}" + ) + + if "/" in object_path: + # It's a path + elements = object_path.split("/") + partial_path = "/".join(elements[:-1]) + weights_dict = self.weights_dict + for e in elements[:-1]: + if e not in weights_dict: + raise ValueError( + f"Path '{partial_path}' not found in model." + ) + weights_dict = weights_dict[e] + weights_dict[elements[-1]] = weights + else: + self.weights_dict[object_path] = weights + + def delete_weight(self, object_name, weight_name): + """Removes a weight from an existing object. + + Args: + object_name: String, name or path of the + object from which to remove the weight + (e.g. `"dense_2"` or `"layers/dense_2"`). + weight_name: String, name of the weight to + delete (e.g. `"0"`). + """ + + def delete_weight_fn(weights_dict, source_name, target_name=None): + if weight_name not in weights_dict[source_name]: + raise ValueError( + f"Weight {weight_name} not found " + f"in object {object_name}. " + "Weights found: " + f"{list(weights_dict[source_name].keys())}" + ) + weights_dict[source_name].pop(weight_name) + + self._edit_object(delete_weight_fn, object_name) + + def add_weights(self, object_name, weights): + """Add one or more new weights to an existing object. + + Args: + object_name: String, name or path of the + object to add the weights to + (e.g. `"dense_2"` or `"layers/dense_2"`). + weights: Dict mapping weight names to weight + values (arrays), + e.g. `{"0": kernel_value, "1": bias_value}`. + """ + if not isinstance(weights, dict): + raise ValueError( + "Argument `weights` should be a dict " + "where keys are weight names (usually '0', '1', etc.) " + "and values are NumPy arrays. " + f"Received: type(weights)={type(weights)}" + ) + + def add_weight_fn(weights_dict, source_name, target_name=None): + weights_dict[source_name].update(weights) + + self._edit_object(add_weight_fn, object_name) + + def save(self, filepath): + """Save the edited weights file. + + Args: + filepath: Path to save the file to. + Must be a `.weights.h5` file. + """ + filepath = str(filepath) + if not filepath.endswith(".weights.h5"): + raise ValueError( + "Invalid `filepath` argument: " + "expected a `.weights.h5` extension. " + f"Received: filepath={filepath}" + ) + weights_store = H5IOStore(filepath, mode="w") + + def _save(weights_dict, weights_store, inner_path): + vars_to_create = {} + for name, value in weights_dict.items(): + if isinstance(value, dict): + if value: + _save( + weights_dict[name], + weights_store, + inner_path=os.path.join(inner_path, name), + ) + else: + # e.g. name="0", value=HDF5Dataset + vars_to_create[name] = value + if vars_to_create: + var_store = weights_store.make(inner_path) + for name, value in vars_to_create.items(): + var_store[name] = value + + _save(self.weights_dict, weights_store, inner_path="") + weights_store.close() + + def resave_weights(self, filepath): + self.save(filepath) + + def _extract_weights_from_store(self, data, metadata=None, inner_path=""): + metadata = metadata or {} + + object_metadata = {} + for k, v in data.attrs.items(): + object_metadata[k] = v + if object_metadata: + metadata[inner_path] = object_metadata + + result = collections.OrderedDict() + for key in data.keys(): + inner_path = f"{inner_path}/{key}" + value = data[key] + if isinstance(value, h5py.Group): + if len(value) == 0: + continue + if "vars" in value.keys() and len(value["vars"]) == 0: + continue + + if hasattr(value, "keys"): + if "vars" in value.keys(): + result[key], metadata = self._extract_weights_from_store( + value["vars"], metadata=metadata, inner_path=inner_path + ) + else: + result[key], metadata = self._extract_weights_from_store( + value, metadata=metadata, inner_path=inner_path + ) + else: + result[key] = value[()] + return result, metadata + + def _generate_filepath_info(self, rich_style=False): + if rich_style: + filepath = f"'{self.filepath}'" + filepath = f"{summary_utils.highlight_symbol(filepath)}" + else: + filepath = f"'{self.filepath}'" + return f"Keras model file {filepath}" + + def _generate_config_info(self, rich_style=False): + return pprint.pformat(self.config) + + def _generate_metadata_info(self, rich_style=False): + version = self.metadata["keras_version"] + date = self.metadata["date_saved"] + if rich_style: + version = f"{summary_utils.highlight_symbol(version)}" + date = f"{summary_utils.highlight_symbol(date)}" + return f"Saved with Keras {version} - date: {date}" + + def _print_weights_structure( + self, weights_dict, indent=0, is_first=True, prefix="", inner_path="" + ): + for idx, (key, value) in enumerate(weights_dict.items()): + inner_path = os.path.join(inner_path, key) + is_last = idx == len(weights_dict) - 1 + if is_first: + is_first = False + connector = "> " + elif is_last: + connector = "└─ " + else: + connector = "├─ " + + if isinstance(value, dict): + bold_key = summary_utils.bold_text(key) + object_label = f"{prefix}{connector}{bold_key}" + if inner_path in self.object_metadata: + metadata = self.object_metadata[inner_path] + if "name" in metadata: + name = metadata["name"] + object_label += f" ('{name}')" + self.console.print(object_label) + if is_last: + appended = " " + else: + appended = "│ " + new_prefix = prefix + appended + self._print_weights_structure( + value, + indent + 1, + is_first=is_first, + prefix=new_prefix, + inner_path=inner_path, + ) + else: + if hasattr(value, "shape"): + bold_key = summary_utils.bold_text(key) + self.console.print( + f"{prefix}{connector}{bold_key}:" + + f" shape={value.shape}, dtype={value.dtype}" + ) + else: + self.console.print(f"{prefix}{connector}{key}: {value}") + + def _weights_summary_cli(self): + self.console.print("Weights structure") + self._print_weights_structure(self.weights_dict, prefix=" " * 2) + + def _weights_summary_interactive(self): + def _generate_html_weights(dictionary, margin_left=0, font_size=1): + html = "" + for key, value in dictionary.items(): + if isinstance(value, dict) and value: + weights_html = _generate_html_weights( + value, margin_left + 20, font_size - 1 + ) + html += ( + f'
' + '{key}' + f"{weights_html}" + "
" + ) + else: + html += ( + f'
' + f'' + f"{key} : shape={value.shape}" + f", dtype={value.dtype}" + f"
' + f"{display_weight(value)}" + "
" + "
" + ) + return html + + output = "Weights structure" + + initialize_id_counter() + output += _generate_html_weights(self.weights_dict) + ipython.display.display(ipython.display.HTML(output)) + + +def get_weight_spec_of_saveable(saveable, spec, visited_saveables=None): + from keras.src.saving.keras_saveable import KerasSaveable + + visited_saveables = visited_saveables or set() + + # If the saveable has already been saved, skip it. + if id(saveable) in visited_saveables: + return + + if hasattr(saveable, "save_own_variables"): + store = {} + saveable.save_own_variables(store) + if store: + keys = sorted(store.keys()) + for k in keys: + val = store[k] + spec[k] = backend.KerasTensor(shape=val.shape, dtype=val.dtype) + + visited_saveables.add(id(saveable)) + + for child_attr, child_obj in saving_lib._walk_saveable(saveable): + if isinstance(child_obj, KerasSaveable): + sub_spec = {} + get_weight_spec_of_saveable( + child_obj, + sub_spec, + visited_saveables=visited_saveables, + ) + if sub_spec: + spec[child_attr] = sub_spec + elif isinstance(child_obj, (list, dict, tuple, set)): + sub_spec = {} + get_weight_spec_of_container( + child_obj, + sub_spec, + visited_saveables=visited_saveables, + ) + if sub_spec: + spec[child_attr] = sub_spec + + +def get_weight_spec_of_container(container, spec, visited_saveables): + from keras.src.saving.keras_saveable import KerasSaveable + + used_names = {} + if isinstance(container, dict): + container = list(container.values()) + + for saveable in container: + if isinstance(saveable, KerasSaveable): + name = naming.to_snake_case(saveable.__class__.__name__) + if name in used_names: + used_names[name] += 1 + name = f"{name}_{used_names[name]}" + else: + used_names[name] = 0 + sub_spec = {} + get_weight_spec_of_saveable( + saveable, + sub_spec, + visited_saveables=visited_saveables, + ) + if sub_spec: + spec[name] = sub_spec + + +def initialize_id_counter(): + global div_id_counter + div_id_counter = 0 + + +def increment_id_counter(): + global div_id_counter + div_id_counter += 1 + + +def get_id_counter(): + return div_id_counter + + +def display_weight(weight, axis=-1, threshold=16): + def _find_factors_closest_to_sqrt(num): + sqrt_num = int(np.sqrt(num)) + + for i in range(sqrt_num, 0, -1): + if num % i == 0: + M = i + N = num // i + + if M > N: + return N, M + return M, N + + def _color_from_rbg(value): + return f"rgba({value[0]}, {value[1]}, {value[2]}, 1)" + + def _reduce_3d_array_by_mean(arr, n, axis): + if axis == 2: + trimmed_arr = arr[:, :, : arr.shape[2] - (arr.shape[2] % n)] + reshaped = np.reshape( + trimmed_arr, (arr.shape[0], arr.shape[1], -1, n) + ) + mean_values = np.mean(reshaped, axis=3) + + elif axis == 1: + trimmed_arr = arr[:, : arr.shape[1] - (arr.shape[1] % n), :] + reshaped = np.reshape( + trimmed_arr, (arr.shape[0], -1, n, arr.shape[2]) + ) + mean_values = np.mean(reshaped, axis=2) + + elif axis == 0: + trimmed_arr = arr[: arr.shape[0] - (arr.shape[0] % n), :, :] + reshaped = np.reshape( + trimmed_arr, (-1, n, arr.shape[1], arr.shape[2]) + ) + mean_values = np.mean(reshaped, axis=1) + + else: + raise ValueError("Axis must be 0, 1, or 2.") + + return mean_values + + def _create_matrix_html(matrix, subplot_size=840): + rows, cols, num_slices = matrix.shape + + M, N = _find_factors_closest_to_sqrt(num_slices) + + try: + from matplotlib import cm + except ImportError: + cm = None + if cm: + rgb_matrix = cm.jet(matrix) + else: + rgb_matrix = (matrix - np.min(matrix)) / ( + np.max(matrix) - np.min(matrix) + ) + rgb_matrix = np.stack([rgb_matrix, rgb_matrix, rgb_matrix], axis=-1) + rgb_matrix = (rgb_matrix[..., :3] * 255).astype("uint8") + + subplot_html = "" + for i in range(num_slices): + cell_html = "" + for row in rgb_matrix[..., i, :]: + for rgb in row: + color = _color_from_rbg(rgb) + cell_html += ( + f'
' + f"
" + ) + subplot_html += f""" +
+ {cell_html} +
+ """ + + cell_size = subplot_size // (N * cols) + + increment_id_counter() + div_id = get_id_counter() + + html_code = f""" +
+ +
+ {subplot_html} +
+
+ """ + + return html_code + + if weight.ndim == 1: + weight = weight[..., np.newaxis] + + weight = np.swapaxes(weight, axis, -1) + weight = weight.reshape(-1, weight.shape[-1]) + + M, N = _find_factors_closest_to_sqrt(weight.shape[0]) + weight = weight.reshape(M, N, weight.shape[-1]) + + for reduce_axis in [0, 1, 2]: + if weight.shape[reduce_axis] > threshold: + weight = _reduce_3d_array_by_mean( + weight, + weight.shape[reduce_axis] // threshold, + axis=reduce_axis, + ) + + weight = (weight - weight.min()) / (weight.max() - weight.min() + 1e-5) + + html_code = _create_matrix_html(weight) + return html_code diff --git a/keras/src/saving/file_editor_test.py b/keras/src/saving/file_editor_test.py new file mode 100644 index 000000000000..f02ca11516b1 --- /dev/null +++ b/keras/src/saving/file_editor_test.py @@ -0,0 +1,112 @@ +import os + +import numpy as np +import pytest + +import keras +from keras.src import testing +from keras.src.saving.file_editor import KerasFileEditor + + +def get_source_model(): + inputs = keras.Input((2,)) + x = keras.layers.Dense(3, name="mydense")(inputs) + outputs = keras.layers.Dense(3, name="output_layer")(x) + model = keras.Model(inputs, outputs) + return model + + +def get_target_model(): + inputs = keras.Input((2,)) + x = keras.layers.Dense(3, name="mydense")(inputs) + x = keras.layers.Dense(3, name="myotherdense")(x) + outputs = keras.layers.Dense(3, name="output_layer")(x) + model = keras.Model(inputs, outputs) + return model + + +class SavingTest(testing.TestCase): + def test_basics(self): + temp_filepath = os.path.join(self.get_temp_dir(), "my_model.keras") + + model = get_source_model() + model.save(temp_filepath) + + editor = KerasFileEditor(temp_filepath) + editor.summary() + + target_model = get_target_model() + + out = editor.compare(model) # Succeeds + self.assertEqual(out["status"], "success") + out = editor.compare(target_model) # Fails + + editor.add_object( + "layers/dense_3", weights={"kernel": np.random.random((3, 3))} + ) + out = editor.compare(target_model) # Fails + self.assertEqual(out["status"], "error") + self.assertEqual(out["error_count"], 2) + + editor.rename_object("dense_3", "dense_4") + editor.rename_object("layers/dense_4", "dense_2") + editor.add_weights("dense_2", weights={"bias": np.random.random((3,))}) + out = editor.compare(target_model) # Succeeds + self.assertEqual(out["status"], "success") + + editor.add_object( + "layers/dense_3", weights={"0": np.random.random((3, 3))} + ) + out = editor.compare(target_model) # Fails + self.assertEqual(out["status"], "error") + self.assertEqual(out["error_count"], 1) + + editor.delete_object("layers/dense_3") + out = editor.compare(target_model) # Succeeds + self.assertEqual(out["status"], "success") + editor.summary() + + temp_filepath = os.path.join(self.get_temp_dir(), "resaved.weights.h5") + editor.save(temp_filepath) + target_model.load_weights(temp_filepath) + + editor = KerasFileEditor(temp_filepath) + editor.summary() + out = editor.compare(target_model) # Succeeds + self.assertEqual(out["status"], "success") + + editor.delete_weight("dense_2", "bias") + out = editor.compare(target_model) # Fails + self.assertEqual(out["status"], "error") + self.assertEqual(out["error_count"], 1) + + editor.add_weights("dense_2", {"bias": np.zeros((7,))}) + out = editor.compare(target_model) # Fails + self.assertEqual(out["status"], "error") + self.assertEqual(out["error_count"], 1) + + editor.delete_weight("dense_2", "bias") + editor.add_weights("dense_2", {"bias": np.zeros((3,))}) + out = editor.compare(target_model) # Succeeds + self.assertEqual(out["status"], "success") + + @pytest.mark.requires_trainable_backend + def test_scalar_weight(self): + model = keras.Sequential(name="my_sequential") + model.add(keras.Input(shape=(1,), name="my_input")) + model.add(keras.layers.Dense(1, activation="sigmoid", name="my_dense")) + model.compile(optimizer="adam", loss="mse", metrics=["mae"]) + model.fit(np.array([[1]]), np.array([[1]]), verbose=0) + model_fpath = os.path.join(self.get_temp_dir(), "model.keras") + weights_fpath = os.path.join(self.get_temp_dir(), "model.weights.h5") + model.save(model_fpath) + model.save_weights(weights_fpath) + + model_editor = KerasFileEditor(model_fpath) + self.assertEqual( + len(keras.src.tree.flatten(model_editor.weights_dict)), 8 + ) + model_weights_editor = KerasFileEditor(weights_fpath) + self.assertEqual( + len(keras.src.tree.flatten(model_weights_editor.weights_dict)), 8 + ) diff --git a/keras/src/saving/keras_saveable.py b/keras/src/saving/keras_saveable.py new file mode 100644 index 000000000000..7fc536b470cb --- /dev/null +++ b/keras/src/saving/keras_saveable.py @@ -0,0 +1,38 @@ +import io + + +class KerasSaveable: + # Note: renaming this function will cause old pickles to be broken. + # This is probably not a huge deal, as pickle should not be a recommended + # saving format -- it should only be supported for use with distributed + # computing frameworks. + + def _obj_type(self): + raise NotImplementedError( + "KerasSaveable subclases must provide an " + "implementation for `obj_type()`" + ) + + @classmethod + def _unpickle_model(cls, bytesio): + import keras.src.saving.saving_lib as saving_lib + + # pickle is not safe regardless of what you do. + return saving_lib._load_model_from_fileobj( + bytesio, custom_objects=None, compile=True, safe_mode=False + ) + + def __reduce__(self): + """__reduce__ is used to customize the behavior of `pickle.pickle()`. + + The method returns a tuple of two elements: a function, and a list of + arguments to pass to that function. In this case we just leverage the + keras saving library.""" + import keras.src.saving.saving_lib as saving_lib + + buf = io.BytesIO() + saving_lib._save_model_to_fileobj(self, buf, "h5") + return ( + self._unpickle_model, + (buf,), + ) diff --git a/keras/src/saving/object_registration.py b/keras/src/saving/object_registration.py new file mode 100644 index 000000000000..2b1ac1df803d --- /dev/null +++ b/keras/src/saving/object_registration.py @@ -0,0 +1,230 @@ +import inspect + +from keras.src.api_export import keras_export +from keras.src.backend.common import global_state + +GLOBAL_CUSTOM_OBJECTS = {} +GLOBAL_CUSTOM_NAMES = {} + + +@keras_export( + [ + "keras.saving.CustomObjectScope", + "keras.saving.custom_object_scope", + "keras.utils.CustomObjectScope", + "keras.utils.custom_object_scope", + ] +) +class CustomObjectScope: + """Exposes custom classes/functions to Keras deserialization internals. + + Under a scope `with custom_object_scope(objects_dict)`, Keras methods such + as `keras.models.load_model()` or + `keras.models.model_from_config()` will be able to deserialize any + custom object referenced by a saved config (e.g. a custom layer or metric). + + Example: + + Consider a custom regularizer `my_regularizer`: + + ```python + layer = Dense(3, kernel_regularizer=my_regularizer) + # Config contains a reference to `my_regularizer` + config = layer.get_config() + ... + # Later: + with custom_object_scope({'my_regularizer': my_regularizer}): + layer = Dense.from_config(config) + ``` + + Args: + custom_objects: Dictionary of `{str: object}` pairs, + where the `str` key is the object name. + """ + + def __init__(self, custom_objects): + self.custom_objects = custom_objects or {} + self.backup = None + + def __enter__(self): + self.backup = global_state.get_global_attribute( + "custom_objects_scope_dict", {} + ).copy() + global_state.set_global_attribute( + "custom_objects_scope_dict", self.custom_objects.copy() + ) + return self + + def __exit__(self, *args, **kwargs): + global_state.set_global_attribute( + "custom_objects_scope_dict", self.backup.copy() + ) + + +# Alias. +custom_object_scope = CustomObjectScope + + +@keras_export( + [ + "keras.saving.get_custom_objects", + "keras.utils.get_custom_objects", + ] +) +def get_custom_objects(): + """Retrieves a live reference to the global dictionary of custom objects. + + Custom objects set using `custom_object_scope()` are not added to the + global dictionary of custom objects, and will not appear in the returned + dictionary. + + Example: + + ```python + get_custom_objects().clear() + get_custom_objects()['MyObject'] = MyObject + ``` + + Returns: + Global dictionary mapping registered class names to classes. + """ + return GLOBAL_CUSTOM_OBJECTS + + +@keras_export( + [ + "keras.saving.register_keras_serializable", + "keras.utils.register_keras_serializable", + ] +) +def register_keras_serializable(package="Custom", name=None): + """Registers an object with the Keras serialization framework. + + This decorator injects the decorated class or function into the Keras custom + object dictionary, so that it can be serialized and deserialized without + needing an entry in the user-provided custom object dict. It also injects a + function that Keras will call to get the object's serializable string key. + + Note that to be serialized and deserialized, classes must implement the + `get_config()` method. Functions do not have this requirement. + + The object will be registered under the key `'package>name'` where `name`, + defaults to the object name if not passed. + + Example: + + ```python + # Note that `'my_package'` is used as the `package` argument here, and since + # the `name` argument is not provided, `'MyDense'` is used as the `name`. + @register_keras_serializable('my_package') + class MyDense(keras.layers.Dense): + pass + + assert get_registered_object('my_package>MyDense') == MyDense + assert get_registered_name(MyDense) == 'my_package>MyDense' + ``` + + Args: + package: The package that this class belongs to. This is used for the + `key` (which is `"package>name"`) to identify the class. Note that + this is the first argument passed into the decorator. + name: The name to serialize this class under in this package. If not + provided or `None`, the class' name will be used (note that this is + the case when the decorator is used with only one argument, which + becomes the `package`). + + Returns: + A decorator that registers the decorated class with the passed names. + """ + + def decorator(arg): + """Registers a class with the Keras serialization framework.""" + class_name = name if name is not None else arg.__name__ + registered_name = f"{package}>{class_name}" + + if inspect.isclass(arg) and not hasattr(arg, "get_config"): + raise ValueError( + "Cannot register a class that does not have a " + "get_config() method." + ) + + GLOBAL_CUSTOM_OBJECTS[registered_name] = arg + GLOBAL_CUSTOM_NAMES[arg] = registered_name + + return arg + + return decorator + + +@keras_export( + [ + "keras.saving.get_registered_name", + "keras.utils.get_registered_name", + ] +) +def get_registered_name(obj): + """Returns the name registered to an object within the Keras framework. + + This function is part of the Keras serialization and deserialization + framework. It maps objects to the string names associated with those objects + for serialization/deserialization. + + Args: + obj: The object to look up. + + Returns: + The name associated with the object, or the default Python name if the + object is not registered. + """ + if obj in GLOBAL_CUSTOM_NAMES: + return GLOBAL_CUSTOM_NAMES[obj] + else: + return obj.__name__ + + +@keras_export( + [ + "keras.saving.get_registered_object", + "keras.utils.get_registered_object", + ] +) +def get_registered_object(name, custom_objects=None, module_objects=None): + """Returns the class associated with `name` if it is registered with Keras. + + This function is part of the Keras serialization and deserialization + framework. It maps strings to the objects associated with them for + serialization/deserialization. + + Example: + + ```python + def from_config(cls, config, custom_objects=None): + if 'my_custom_object_name' in config: + config['hidden_cls'] = tf.keras.saving.get_registered_object( + config['my_custom_object_name'], custom_objects=custom_objects) + ``` + + Args: + name: The name to look up. + custom_objects: A dictionary of custom objects to look the name up in. + Generally, custom_objects is provided by the user. + module_objects: A dictionary of custom objects to look the name up in. + Generally, module_objects is provided by midlevel library + implementers. + + Returns: + An instantiable class associated with `name`, or `None` if no such class + exists. + """ + custom_objects_scope_dict = global_state.get_global_attribute( + "custom_objects_scope_dict", {} + ) + if name in custom_objects_scope_dict: + return custom_objects_scope_dict[name] + elif name in GLOBAL_CUSTOM_OBJECTS: + return GLOBAL_CUSTOM_OBJECTS[name] + elif custom_objects and name in custom_objects: + return custom_objects[name] + elif module_objects and name in module_objects: + return module_objects[name] + return None diff --git a/keras/src/saving/object_registration_test.py b/keras/src/saving/object_registration_test.py new file mode 100644 index 000000000000..ece59e7e208a --- /dev/null +++ b/keras/src/saving/object_registration_test.py @@ -0,0 +1,121 @@ +import keras +from keras.src import testing +from keras.src.saving import object_registration +from keras.src.saving import serialization_lib + + +class TestObjectRegistration(testing.TestCase): + def test_custom_object_scope(self): + def custom_fn(): + pass + + class CustomClass: + pass + + def check_get_in_thread(): + with object_registration.custom_object_scope( + {"CustomClass": CustomClass, "custom_fn": custom_fn} + ): + actual_custom_fn = keras.activations.get("custom_fn") + self.assertEqual(actual_custom_fn, custom_fn) + actual_custom_class = keras.regularizers.get("CustomClass") + self.assertEqual(actual_custom_class.__class__, CustomClass) + + with object_registration.custom_object_scope( + {"CustomClass": CustomClass, "custom_fn": custom_fn} + ): + actual_custom_fn = keras.activations.get("custom_fn") + self.assertEqual(actual_custom_fn, custom_fn) + actual_custom_class = keras.regularizers.get("CustomClass") + self.assertEqual(actual_custom_class.__class__, CustomClass) + checked_thread = self.checkedThread(check_get_in_thread) + checked_thread.start() + checked_thread.join() + + def test_serialize_custom_class_with_default_name(self): + @object_registration.register_keras_serializable() + class TestClass: + def __init__(self, value): + self._value = value + + def get_config(self): + return {"value": self._value} + + @classmethod + def from_config(cls, config): + return cls(**config) + + serialized_name = "Custom>TestClass" + inst = TestClass(value=10) + class_name = object_registration.GLOBAL_CUSTOM_NAMES[TestClass] + self.assertEqual(serialized_name, class_name) + + config = serialization_lib.serialize_keras_object(inst) + self.assertEqual("TestClass", config["class_name"]) + new_inst = serialization_lib.deserialize_keras_object(config) + self.assertIsNot(inst, new_inst) + self.assertIsInstance(new_inst, TestClass) + self.assertEqual(10, new_inst._value) + + def test_serialize_custom_class_with_custom_name(self): + @object_registration.register_keras_serializable( + "TestPackage", "CustomName" + ) + class OtherTestClass: + def __init__(self, val): + self._val = val + + def get_config(self): + return {"val": self._val} + + @classmethod + def from_config(cls, config): + return cls(**config) + + serialized_name = "TestPackage>CustomName" + inst = OtherTestClass(val=5) + class_name = object_registration.GLOBAL_CUSTOM_NAMES[OtherTestClass] + self.assertEqual(serialized_name, class_name) + fn_class_name = object_registration.get_registered_name(OtherTestClass) + self.assertEqual(fn_class_name, class_name) + + cls = object_registration.get_registered_object(fn_class_name) + self.assertEqual(OtherTestClass, cls) + + config = keras.saving.serialize_keras_object(inst) + self.assertEqual("OtherTestClass", config["class_name"]) + new_inst = keras.saving.deserialize_keras_object(config) + self.assertIsNot(inst, new_inst) + self.assertIsInstance(new_inst, OtherTestClass) + self.assertEqual(5, new_inst._val) + + def test_serialize_custom_function(self): + @object_registration.register_keras_serializable() + def my_fn(): + return 42 + + serialized_name = "Custom>my_fn" + class_name = object_registration.GLOBAL_CUSTOM_NAMES[my_fn] + self.assertEqual(serialized_name, class_name) + fn_class_name = object_registration.get_registered_name(my_fn) + self.assertEqual(fn_class_name, class_name) + + config = keras.saving.serialize_keras_object(my_fn) + fn = keras.saving.deserialize_keras_object(config) + self.assertEqual(42, fn()) + + fn_2 = object_registration.get_registered_object(fn_class_name) + self.assertEqual(42, fn_2()) + + def test_serialize_custom_class_without_get_config_fails(self): + with self.assertRaisesRegex( + ValueError, + "Cannot register a class that does not have a get_config.*", + ): + + @object_registration.register_keras_serializable( + "TestPackage", "TestClass" + ) + class TestClass: + def __init__(self, value): + self._value = value diff --git a/keras/src/saving/saving_api.py b/keras/src/saving/saving_api.py new file mode 100644 index 000000000000..3a45f35f5a4b --- /dev/null +++ b/keras/src/saving/saving_api.py @@ -0,0 +1,316 @@ +import os +import zipfile + +from absl import logging + +from keras.src.api_export import keras_export +from keras.src.legacy.saving import legacy_h5_format +from keras.src.saving import saving_lib +from keras.src.utils import file_utils +from keras.src.utils import io_utils + +try: + import h5py +except ImportError: + h5py = None + + +@keras_export(["keras.saving.save_model", "keras.models.save_model"]) +def save_model(model, filepath, overwrite=True, zipped=None, **kwargs): + """Saves a model as a `.keras` file. + + Args: + model: Keras model instance to be saved. + filepath: `str` or `pathlib.Path` object. Path where to save the model. + overwrite: Whether we should overwrite any existing model at the target + location, or instead ask the user via an interactive prompt. + zipped: Whether to save the model as a zipped `.keras` + archive (default when saving locally), or as an unzipped directory + (default when saving on the Hugging Face Hub). + + Example: + + ```python + model = keras.Sequential( + [ + keras.layers.Dense(5, input_shape=(3,)), + keras.layers.Softmax(), + ], + ) + model.save("model.keras") + loaded_model = keras.saving.load_model("model.keras") + x = keras.random.uniform((10, 3)) + assert np.allclose(model.predict(x), loaded_model.predict(x)) + ``` + + Note that `model.save()` is an alias for `keras.saving.save_model()`. + + The saved `.keras` file is a `zip` archive that contains: + + - The model's configuration (architecture) + - The model's weights + - The model's optimizer's state (if any) + + Thus models can be reinstantiated in the exact same state. + """ + include_optimizer = kwargs.pop("include_optimizer", True) + save_format = kwargs.pop("save_format", False) + if save_format: + if str(filepath).endswith((".h5", ".hdf5")) or str(filepath).endswith( + ".keras" + ): + logging.warning( + "The `save_format` argument is deprecated in Keras 3. " + "We recommend removing this argument as it can be inferred " + "from the file path. " + f"Received: save_format={save_format}" + ) + else: + raise ValueError( + "The `save_format` argument is deprecated in Keras 3. " + "Please remove this argument and pass a file path with " + "either `.keras` or `.h5` extension." + f"Received: save_format={save_format}" + ) + if kwargs: + raise ValueError( + "The following argument(s) are not supported: " + f"{list(kwargs.keys())}" + ) + + # Deprecation warnings + if str(filepath).endswith((".h5", ".hdf5")): + logging.warning( + "You are saving your model as an HDF5 file via " + "`model.save()` or `keras.saving.save_model(model)`. " + "This file format is considered legacy. " + "We recommend using instead the native Keras format, " + "e.g. `model.save('my_model.keras')` or " + "`keras.saving.save_model(model, 'my_model.keras')`. " + ) + + is_hf = str(filepath).startswith("hf://") + if zipped is None: + zipped = not is_hf # default behavior depends on destination + + # If file exists and should not be overwritten. + try: + exists = (not is_hf) and os.path.exists(filepath) + except TypeError: + exists = False + if exists and not overwrite: + proceed = io_utils.ask_to_proceed_with_overwrite(filepath) + if not proceed: + return + + if zipped and str(filepath).endswith(".keras"): + return saving_lib.save_model(model, filepath) + if not zipped: + return saving_lib.save_model(model, filepath, zipped=False) + if str(filepath).endswith((".h5", ".hdf5")): + return legacy_h5_format.save_model_to_hdf5( + model, filepath, overwrite, include_optimizer + ) + raise ValueError( + "Invalid filepath extension for saving. " + "Please add either a `.keras` extension for the native Keras " + f"format (recommended) or a `.h5` extension. " + "Use `model.export(filepath)` if you want to export a SavedModel " + "for use with TFLite/TFServing/etc. " + f"Received: filepath={filepath}." + ) + + +@keras_export(["keras.saving.load_model", "keras.models.load_model"]) +def load_model(filepath, custom_objects=None, compile=True, safe_mode=True): + """Loads a model saved via `model.save()`. + + Args: + filepath: `str` or `pathlib.Path` object, path to the saved model file. + custom_objects: Optional dictionary mapping names + (strings) to custom classes or functions to be + considered during deserialization. + compile: Boolean, whether to compile the model after loading. + safe_mode: Boolean, whether to disallow unsafe `lambda` deserialization. + When `safe_mode=False`, loading an object has the potential to + trigger arbitrary code execution. This argument is only + applicable to the Keras v3 model format. Defaults to `True`. + + Returns: + A Keras model instance. If the original model was compiled, + and the argument `compile=True` is set, then the returned model + will be compiled. Otherwise, the model will be left uncompiled. + + Example: + + ```python + model = keras.Sequential([ + keras.layers.Dense(5, input_shape=(3,)), + keras.layers.Softmax()]) + model.save("model.keras") + loaded_model = keras.saving.load_model("model.keras") + x = np.random.random((10, 3)) + assert np.allclose(model.predict(x), loaded_model.predict(x)) + ``` + + Note that the model variables may have different name values + (`var.name` property, e.g. `"dense_1/kernel:0"`) after being reloaded. + It is recommended that you use layer attributes to + access specific variables, e.g. `model.get_layer("dense_1").kernel`. + """ + is_keras_zip = str(filepath).endswith(".keras") and zipfile.is_zipfile( + filepath + ) + is_keras_dir = file_utils.isdir(filepath) and file_utils.exists( + file_utils.join(filepath, "config.json") + ) + is_hf = str(filepath).startswith("hf://") + + # Support for remote zip files + if ( + file_utils.is_remote_path(filepath) + and not file_utils.isdir(filepath) + and not is_keras_zip + and not is_hf + ): + local_path = file_utils.join( + saving_lib.get_temp_dir(), os.path.basename(filepath) + ) + + # Copy from remote to temporary local directory + file_utils.copy(filepath, local_path) + + # Switch filepath to local zipfile for loading model + if zipfile.is_zipfile(local_path): + filepath = local_path + is_keras_zip = True + + if is_keras_zip or is_keras_dir or is_hf: + return saving_lib.load_model( + filepath, + custom_objects=custom_objects, + compile=compile, + safe_mode=safe_mode, + ) + if str(filepath).endswith((".h5", ".hdf5")): + return legacy_h5_format.load_model_from_hdf5( + filepath, + custom_objects=custom_objects, + compile=compile, + safe_mode=safe_mode, + ) + elif str(filepath).endswith(".keras"): + raise ValueError( + f"File not found: filepath={filepath}. " + "Please ensure the file is an accessible `.keras` " + "zip file." + ) + else: + raise ValueError( + f"File format not supported: filepath={filepath}. " + "Keras 3 only supports V3 `.keras` files and " + "legacy H5 format files (`.h5` extension). " + "Note that the legacy SavedModel format is not " + "supported by `load_model()` in Keras 3. In " + "order to reload a TensorFlow SavedModel as an " + "inference-only layer in Keras 3, use " + "`keras.layers.TFSMLayer(" + f"{filepath}, call_endpoint='serving_default')` " + "(note that your `call_endpoint` " + "might have a different name)." + ) + + +@keras_export("keras.saving.save_weights") +def save_weights( + model, filepath, overwrite=True, max_shard_size=None, **kwargs +): + filepath_str = str(filepath) + if max_shard_size is None and not filepath_str.endswith(".weights.h5"): + raise ValueError( + "The filename must end in `.weights.h5`. " + f"Received: filepath={filepath_str}" + ) + elif max_shard_size is not None and not filepath_str.endswith( + ("weights.h5", "weights.json") + ): + raise ValueError( + "The filename must end in `.weights.json` when `max_shard_size` is " + f"specified. Received: filepath={filepath_str}" + ) + try: + exists = os.path.exists(filepath) + except TypeError: + exists = False + if exists and not overwrite: + proceed = io_utils.ask_to_proceed_with_overwrite(filepath_str) + if not proceed: + return + saving_lib.save_weights_only(model, filepath, max_shard_size, **kwargs) + + +@keras_export("keras.saving.load_weights") +def load_weights(model, filepath, skip_mismatch=False, **kwargs): + filepath_str = str(filepath) + + # Get the legacy kwargs. + objects_to_skip = kwargs.pop("objects_to_skip", None) + by_name = kwargs.pop("by_name", None) + if kwargs: + raise ValueError(f"Invalid keyword arguments: {kwargs}") + + if filepath_str.endswith(".keras"): + if objects_to_skip is not None: + raise ValueError( + "`objects_to_skip` only supports loading '.weights.h5' files." + f"Received: {filepath}" + ) + if by_name is not None: + raise ValueError( + "`by_name` only supports loading legacy '.h5' or '.hdf5' " + f"files. Received: {filepath}" + ) + saving_lib.load_weights_only( + model, filepath, skip_mismatch=skip_mismatch + ) + elif filepath_str.endswith(".weights.h5") or filepath_str.endswith( + ".weights.json" + ): + if by_name is not None: + raise ValueError( + "`by_name` only supports loading legacy '.h5' or '.hdf5' " + f"files. Received: {filepath}" + ) + saving_lib.load_weights_only( + model, + filepath, + skip_mismatch=skip_mismatch, + objects_to_skip=objects_to_skip, + ) + elif filepath_str.endswith(".h5") or filepath_str.endswith(".hdf5"): + if not h5py: + raise ImportError( + "Loading a H5 file requires `h5py` to be installed." + ) + if objects_to_skip is not None: + raise ValueError( + "`objects_to_skip` only supports loading '.weights.h5' files." + f"Received: {filepath}" + ) + with h5py.File(filepath, "r") as f: + if "layer_names" not in f.attrs and "model_weights" in f: + f = f["model_weights"] + if by_name: + legacy_h5_format.load_weights_from_hdf5_group_by_name( + f, model, skip_mismatch + ) + else: + legacy_h5_format.load_weights_from_hdf5_group( + f, model, skip_mismatch + ) + else: + raise ValueError( + f"File format not supported: filepath={filepath}. " + "Keras 3 only supports V3 `.keras` and `.weights.h5` " + "files, or legacy V1/V2 `.h5` files." + ) diff --git a/keras/src/saving/saving_api_test.py b/keras/src/saving/saving_api_test.py new file mode 100644 index 000000000000..638528eaac7b --- /dev/null +++ b/keras/src/saving/saving_api_test.py @@ -0,0 +1,311 @@ +import os +import pathlib +import unittest.mock as mock + +import numpy as np +from absl import logging +from absl.testing import parameterized + +from keras.src import layers +from keras.src.legacy.saving.legacy_h5_format import save_model_to_hdf5 +from keras.src.models import Sequential +from keras.src.saving import saving_api +from keras.src.testing import test_case +from keras.src.testing.test_utils import named_product + + +class SaveModelTests(test_case.TestCase): + def get_model(self): + return Sequential( + [ + layers.Dense(5, input_shape=(3,)), + layers.Softmax(), + ] + ) + + def test_basic_saving(self): + """Test basic model saving and loading.""" + model = self.get_model() + filepath = os.path.join(self.get_temp_dir(), "test_model.keras") + saving_api.save_model(model, filepath) + + loaded_model = saving_api.load_model(filepath) + x = np.random.uniform(size=(10, 3)) + self.assertTrue(np.allclose(model.predict(x), loaded_model.predict(x))) + + def test_invalid_save_format(self): + """Test deprecated save_format argument.""" + model = self.get_model() + with self.assertRaisesRegex( + ValueError, "The `save_format` argument is deprecated" + ): + saving_api.save_model(model, "model.txt", save_format=True) + + def test_unsupported_arguments(self): + """Test unsupported argument during model save.""" + model = self.get_model() + filepath = os.path.join(self.get_temp_dir(), "test_model.keras") + with self.assertRaisesRegex( + ValueError, r"The following argument\(s\) are not supported" + ): + saving_api.save_model(model, filepath, random_arg=True) + + def test_save_h5_format(self): + """Test saving model in h5 format.""" + model = self.get_model() + filepath_h5 = os.path.join(self.get_temp_dir(), "test_model.h5") + + # Verify the warning. + with mock.patch.object(logging, "warning") as mock_warn: + saving_api.save_model(model, filepath_h5) + mock_warn.assert_called_once_with( + "You are saving your model as an HDF5 file via " + "`model.save()` or `keras.saving.save_model(model)`. " + "This file format is considered legacy. " + "We recommend using instead the native Keras format, " + "e.g. `model.save('my_model.keras')` or " + "`keras.saving.save_model(model, 'my_model.keras')`. " + ) + self.assertTrue(os.path.exists(filepath_h5)) + os.remove(filepath_h5) + + def test_save_unsupported_extension(self): + """Test saving model with unsupported extension.""" + model = self.get_model() + with self.assertRaisesRegex( + ValueError, "Invalid filepath extension for saving" + ): + saving_api.save_model(model, "model.png") + + def test_objects_to_skip(self): + model = Sequential( + [ + layers.Input((3,)), + layers.Dense(5), + layers.Dense(5), + ] + ) + skip = model.layers[0] + filepath = os.path.join(self.get_temp_dir(), "test_model.weights.h5") + saving_api.save_weights(model, filepath, objects_to_skip=[skip]) + new_model = Sequential( + [ + layers.Input((3,)), + layers.Dense(5), + layers.Dense(5), + ] + ) + new_model.load_weights(filepath, objects_to_skip=[new_model.layers[0]]) + self.assertNotAllClose( + new_model.layers[0].get_weights()[0], + model.layers[0].get_weights()[0], + ) + self.assertAllClose( + new_model.layers[0].get_weights()[1], + model.layers[0].get_weights()[1], + ) + saving_api.save_weights(model, filepath) + new_model.load_weights(filepath, objects_to_skip=[new_model.layers[0]]) + self.assertNotAllClose( + new_model.layers[0].get_weights()[0], + model.layers[0].get_weights()[0], + ) + self.assertAllClose( + new_model.layers[0].get_weights()[1], + model.layers[0].get_weights()[1], + ) + + +class LoadModelTests(test_case.TestCase): + def get_model(self, dtype=None): + return Sequential( + [ + layers.Dense(5, input_shape=(3,), dtype=dtype), + layers.Softmax(), + ] + ) + + @parameterized.named_parameters( + [ + {"testcase_name": "bfloat16", "dtype": "bfloat16"}, + {"testcase_name": "float16", "dtype": "float16"}, + {"testcase_name": "float32", "dtype": "float32"}, + {"testcase_name": "float64", "dtype": "float64"}, + ] + ) + def test_basic_load(self, dtype): + """Test basic model loading.""" + model = self.get_model(dtype) + filepath = os.path.join(self.get_temp_dir(), "test_model.keras") + saving_api.save_model(model, filepath) + + loaded_model = saving_api.load_model(filepath) + x = np.random.uniform(size=(10, 3)) + self.assertEqual(loaded_model.weights[0].dtype, dtype) + self.assertTrue(np.allclose(model.predict(x), loaded_model.predict(x))) + + def test_load_unsupported_format(self): + """Test loading model with unsupported format.""" + with self.assertRaisesRegex(ValueError, "File format not supported"): + saving_api.load_model("model.pkl") + + def test_load_keras_not_zip(self): + """Test loading keras file that's not a zip.""" + with self.assertRaisesRegex(ValueError, "File not found"): + saving_api.load_model("not_a_zip.keras") + + def test_load_h5_format(self): + """Test loading model in h5 format.""" + model = self.get_model() + filepath_h5 = os.path.join(self.get_temp_dir(), "test_model.h5") + saving_api.save_model(model, filepath_h5) + loaded_model = saving_api.load_model(filepath_h5) + x = np.random.uniform(size=(10, 3)) + self.assertTrue(np.allclose(model.predict(x), loaded_model.predict(x))) + os.remove(filepath_h5) + + def test_load_model_with_custom_objects(self): + """Test loading model with custom objects.""" + + class CustomLayer(layers.Layer): + def call(self, inputs): + return inputs + + model = Sequential([CustomLayer(input_shape=(3,))]) + filepath = os.path.join(self.get_temp_dir(), "custom_model.keras") + model.save(filepath) + loaded_model = saving_api.load_model( + filepath, custom_objects={"CustomLayer": CustomLayer} + ) + self.assertIsInstance(loaded_model.layers[0], CustomLayer) + os.remove(filepath) + + def test_save_unzipped(self): + """Test saving/loading an unzipped model dir.""" + model = self.get_model() + + # Test error with keras extension + bad_filepath = os.path.join(self.get_temp_dir(), "test_model.keras") + with self.assertRaisesRegex(ValueError, "should not end in"): + saving_api.save_model(model, bad_filepath, zipped=False) + + filepath = os.path.join(self.get_temp_dir(), "test_model_dir") + saving_api.save_model(model, filepath, zipped=False) + + self.assertTrue(os.path.exists(filepath)) + self.assertTrue(os.path.isdir(filepath)) + config_filepath = os.path.join(filepath, "config.json") + weights_filepath = os.path.join(filepath, "model.weights.h5") + self.assertTrue(os.path.exists(config_filepath)) + self.assertTrue(os.path.exists(weights_filepath)) + + loaded_model = saving_api.load_model(filepath) + x = np.random.uniform(size=(10, 3)) + self.assertTrue(np.allclose(model.predict(x), loaded_model.predict(x))) + + +class LoadWeightsTests(test_case.TestCase): + def get_model(self, dtype=None): + return Sequential( + [ + layers.Dense(5, input_shape=(3,), dtype=dtype), + layers.Softmax(), + ] + ) + + @parameterized.named_parameters( + named_product( + save_format=["keras", "weights.h5", "h5"], + source_dtype=["float64", "float32", "float16", "bfloat16"], + dest_dtype=["float64", "float32", "float16", "bfloat16"], + ) + ) + def test_load_weights(self, save_format, source_dtype, dest_dtype): + """Test loading keras weights.""" + src_model = self.get_model(dtype=source_dtype) + if save_format == "keras": + filepath = os.path.join(self.get_temp_dir(), "test_weights.keras") + src_model.save(filepath) + elif save_format == "weights.h5": + filepath = os.path.join( + self.get_temp_dir(), "test_weights.weights.h5" + ) + src_model.save_weights(filepath) + elif save_format == "h5": + if "bfloat16" in (source_dtype, dest_dtype): + raise self.skipTest( + "bfloat16 dtype is not supported in legacy h5 format." + ) + filepath = os.path.join(self.get_temp_dir(), "test_weights.h5") + save_model_to_hdf5(src_model, filepath) + else: + raise ValueError(f"Unsupported save format: {save_format}") + + dest_model = self.get_model(dtype=dest_dtype) + dest_model.load_weights(filepath) + + src_weights = src_model.get_weights() + dest_weights = dest_model.get_weights() + for orig, loaded in zip(src_weights, dest_weights): + self.assertAllClose( + orig.astype("float32"), + loaded.astype("float32"), + atol=0.001, + rtol=0.01, + ) + + def test_load_weights_invalid_kwargs(self): + src_model = self.get_model() + keras_filepath = os.path.join(self.get_temp_dir(), "test_weights.keras") + weight_h5_filepath = os.path.join( + self.get_temp_dir(), "test_weights.weights.h5" + ) + legacy_h5_filepath = os.path.join( + self.get_temp_dir(), "test_weights.h5" + ) + src_model.save(keras_filepath) + src_model.save_weights(weight_h5_filepath) + save_model_to_hdf5(src_model, legacy_h5_filepath) + + dest_model = self.get_model() + # Test keras file. + with self.assertRaisesRegex( + ValueError, r"only supports loading '.weights.h5' files." + ): + dest_model.load_weights(keras_filepath, objects_to_skip=[]) + with self.assertRaisesRegex( + ValueError, r"only supports loading legacy '.h5' or '.hdf5' files." + ): + dest_model.load_weights(keras_filepath, by_name=True) + with self.assertRaisesRegex(ValueError, r"Invalid keyword arguments"): + dest_model.load_weights(keras_filepath, bad_kwarg=None) + # Test weights.h5 file. + with self.assertRaisesRegex( + ValueError, r"only supports loading legacy '.h5' or '.hdf5' files." + ): + dest_model.load_weights(weight_h5_filepath, by_name=True) + # Test h5 file. + with self.assertRaisesRegex( + ValueError, r"only supports loading '.weights.h5' files." + ): + dest_model.load_weights(legacy_h5_filepath, objects_to_skip=[]) + + def test_load_weights_invalid_extension(self): + """Test loading weights with unsupported extension.""" + model = self.get_model() + with self.assertRaisesRegex(ValueError, "File format not supported"): + model.load_weights("invalid_extension.pkl") + + def test_load_sharded_weights(self): + src_model = self.get_model() + temp_filepath = pathlib.Path( + os.path.join(self.get_temp_dir(), "test_weights.weights.json") + ) + src_model.save_weights(temp_filepath, max_shard_size=1) + self.assertLen(os.listdir(temp_filepath.parent), 2) + src_weights = src_model.get_weights() + dest_model = self.get_model() + dest_model.load_weights(temp_filepath) + dest_weights = dest_model.get_weights() + for orig, loaded in zip(src_weights, dest_weights): + self.assertAllClose(orig, loaded) diff --git a/keras/src/saving/saving_lib.py b/keras/src/saving/saving_lib.py new file mode 100644 index 000000000000..c2d7b69d9f2e --- /dev/null +++ b/keras/src/saving/saving_lib.py @@ -0,0 +1,1658 @@ +"""Python-based idempotent model-saving functionality.""" + +import datetime +import io +import json +import math +import os +import pathlib +import shutil +import tempfile +import warnings +import zipfile + +import ml_dtypes +import numpy as np + +from keras.src import backend +from keras.src.backend.common import global_state +from keras.src.saving.serialization_lib import ObjectSharingScope +from keras.src.saving.serialization_lib import deserialize_keras_object +from keras.src.saving.serialization_lib import serialize_keras_object +from keras.src.utils import dtype_utils +from keras.src.utils import file_utils +from keras.src.utils import io_utils +from keras.src.utils import naming +from keras.src.utils import plot_model +from keras.src.utils.model_visualization import check_pydot +from keras.src.utils.summary_utils import readable_memory_size +from keras.src.utils.summary_utils import weight_memory_size +from keras.src.version import __version__ as keras_version + +try: + import h5py +except ImportError: + h5py = None +try: + import psutil +except ImportError: + psutil = None +try: + import huggingface_hub +except ImportError: + huggingface_hub = None + + +_CONFIG_FILENAME = "config.json" +_METADATA_FILENAME = "metadata.json" +_VARS_FNAME = "model.weights" # Will become e.g. "model.weights.h5" +_VARS_FNAME_H5 = f"{_VARS_FNAME}.h5" +_VARS_FNAME_NPZ = f"{_VARS_FNAME}.npz" +_ASSETS_DIRNAME = "assets" +_MEMORY_UPPER_BOUND = 0.5 # 50% + + +_MODEL_CARD_TEMPLATE = """ +--- +library_name: keras +--- + +This model has been uploaded using the Keras library and can be used with JAX, +TensorFlow, and PyTorch backends. + +This model card has been generated automatically and should be completed by the +model author. +See [Model Cards documentation](https://huggingface.co/docs/hub/model-cards) for +more information. + +For more details about the model architecture, check out +[config.json](./config.json).""" + + +def save_model(model, filepath, weights_format="h5", zipped=True): + """Save a zip-archive representing a Keras model to the given file or path. + + The zip-based archive contains the following structure: + + - JSON-based configuration file (config.json): Records of model, layer, and + other saveables' configuration. + - H5-based saveable state files, found in respective directories, such as + model/states.npz, model/dense_layer/states.npz, etc. + - Metadata file. + + The states of Keras saveables (layers, optimizers, loss, and metrics) are + automatically saved as long as they can be discovered through the attributes + returned by `dir(Model)`. Typically, the state includes the variables + associated with the saveable, but some specially purposed layers may + contain more such as the vocabularies stored in the hashmaps. The saveables + define how their states are saved by exposing `save_state()` and + `load_state()` APIs. + + For the case of layer states, the variables will be visited as long as + they are either 1) referenced via layer attributes, or 2) referenced via a + container (list, tuple, or dict), and the container is referenced via a + layer attribute. + """ + if weights_format == "h5" and h5py is None: + raise ImportError("h5py must be installed in order to save a model.") + + if not model.built: + warnings.warn( + "You are saving a model that has not yet been built. " + "It might not contain any weights yet. " + "Consider building the model first by calling it " + "on some data.", + stacklevel=2, + ) + + if isinstance(filepath, io.IOBase): + _save_model_to_fileobj(model, filepath, weights_format) + return + + filepath = str(filepath) + is_hf = filepath.startswith("hf://") + if zipped and not filepath.endswith(".keras"): + raise ValueError( + "Invalid `filepath` argument: expected a `.keras` extension. " + f"Received: filepath={filepath}" + ) + if not zipped and filepath.endswith(".keras"): + raise ValueError( + "When using `zipped=False`, the `filepath` argument should not " + f"end in `.keras`. Received: filepath={filepath}" + ) + if zipped and is_hf: + raise ValueError( + "When saving to the Hugging Face Hub, you should not save the " + f"model as zipped. Received: filepath={filepath}, zipped={zipped}" + ) + if is_hf: + _upload_model_to_hf(model, filepath, weights_format) + elif not zipped: + _save_model_to_dir(model, filepath, weights_format) + else: + if file_utils.is_remote_path(filepath): + # Remote path. Zip to local memory byte io and copy to remote + zip_filepath = io.BytesIO() + _save_model_to_fileobj(model, zip_filepath, weights_format) + with file_utils.File(filepath, "wb") as f: + f.write(zip_filepath.getvalue()) + else: + with open(filepath, "wb") as f: + _save_model_to_fileobj(model, f, weights_format) + + +def _serialize_model_as_json(model): + with ObjectSharingScope(): + serialized_model_dict = serialize_keras_object(model) + config_json = json.dumps(serialized_model_dict) + metadata_json = json.dumps( + { + "keras_version": keras_version, + "date_saved": datetime.datetime.now().strftime("%Y-%m-%d@%H:%M:%S"), + } + ) + return config_json, metadata_json + + +def _save_model_to_dir(model, dirpath, weights_format): + if not file_utils.exists(dirpath): + file_utils.makedirs(dirpath) + config_json, metadata_json = _serialize_model_as_json(model) + with open(file_utils.join(dirpath, _METADATA_FILENAME), "w") as f: + f.write(metadata_json) + with open(file_utils.join(dirpath, _CONFIG_FILENAME), "w") as f: + f.write(config_json) + weights_filepath = file_utils.join(dirpath, _VARS_FNAME_H5) + assert_dirpath = file_utils.join(dirpath, _ASSETS_DIRNAME) + try: + if weights_format == "h5": + weights_store = H5IOStore(weights_filepath, mode="w") + elif weights_format == "npz": + weights_store = NpzIOStore(weights_filepath, mode="w") + else: + raise ValueError( + "Unknown `weights_format` argument. " + "Expected 'h5' or 'npz'. " + f"Received: weights_format={weights_format}" + ) + asset_store = DiskIOStore(assert_dirpath, mode="w") + _save_state( + model, + weights_store=weights_store, + assets_store=asset_store, + inner_path="", + visited_saveables=set(), + ) + finally: + weights_store.close() + asset_store.close() + + +def _save_model_to_fileobj(model, fileobj, weights_format): + config_json, metadata_json = _serialize_model_as_json(model) + + with zipfile.ZipFile(fileobj, "w") as zf: + with zf.open(_METADATA_FILENAME, "w") as f: + f.write(metadata_json.encode()) + with zf.open(_CONFIG_FILENAME, "w") as f: + f.write(config_json.encode()) + + weights_file_path = None + weights_store = None + asset_store = None + write_zf = False + try: + if weights_format == "h5": + try: + if is_memory_sufficient(model): + # Load the model weights into memory before writing + # .keras if the system memory is sufficient. + weights_store = H5IOStore( + _VARS_FNAME_H5, archive=zf, mode="w" + ) + else: + # Try opening the .h5 file, then writing it to `zf` at + # the end of the function call. This is more memory + # efficient than writing the weights into memory first. + working_dir = pathlib.Path(fileobj.name).parent + weights_file_path = tempfile.NamedTemporaryFile( + dir=working_dir + ) + weights_store = H5IOStore( + weights_file_path.name, mode="w" + ) + write_zf = True + except: + # If we can't use the local disk for any reason, write the + # weights into memory first, which consumes more memory. + weights_store = H5IOStore( + _VARS_FNAME_H5, archive=zf, mode="w" + ) + elif weights_format == "npz": + weights_store = NpzIOStore( + _VARS_FNAME_NPZ, archive=zf, mode="w" + ) + else: + raise ValueError( + "Unknown `weights_format` argument. " + "Expected 'h5' or 'npz'. " + f"Received: weights_format={weights_format}" + ) + + asset_store = DiskIOStore(_ASSETS_DIRNAME, archive=zf, mode="w") + + _save_state( + model, + weights_store=weights_store, + assets_store=asset_store, + inner_path="", + visited_saveables=set(), + ) + except: + # Skip the final `zf.write` if any exception is raised + write_zf = False + if weights_store: + weights_store.archive = None + raise + finally: + if weights_store: + weights_store.close() + if asset_store: + asset_store.close() + if write_zf and weights_file_path: + zf.write(weights_file_path.name, _VARS_FNAME_H5) + if weights_file_path: + weights_file_path.close() + + +def _upload_model_to_hf(model, hf_path, weights_format): + if huggingface_hub is None: + raise ImportError( + "To save models to the Hugging Face Hub, " + "you must install the `huggingface_hub` package." + ) + + original_hf_path = hf_path + if hf_path.startswith("hf://"): + hf_path = hf_path[5:] + if hf_path.count("/") > 1: + raise ValueError( + "Invalid `hf_path` argument: expected `namespace/model_name`" + f" format. Received: hf_path={original_hf_path}" + ) + + api = huggingface_hub.HfApi( + library_name="keras", library_version=keras_version + ) + repo_url = api.create_repo(hf_path, exist_ok=True) + repo_id = repo_url.repo_id + + with tempfile.TemporaryDirectory() as tmp_dir: + _save_model_to_dir(model, tmp_dir, weights_format) + + model_card = _MODEL_CARD_TEMPLATE + + if check_pydot(): + plot_path = file_utils.join(tmp_dir, "assets", "summary_plot.png") + plot_model( + model, + to_file=plot_path, + show_layer_names=True, + show_shapes=True, + show_dtype=True, + ) + if len(model.layers) <= 10: + model_card += "\n\n![](./assets/summary_plot.png)" + else: + model_card += ( + "A plot of the model can be found " + "[here](./assets/summary_plot.png)." + ) + + with open(file_utils.join(tmp_dir, "README.md"), "w") as f: + f.write(model_card) + + api.upload_folder( + repo_id=repo_id, + folder_path=tmp_dir, + commit_message="Save model using Keras.", + ) + io_utils.print_msg( + f"Model saved to the Hugging Face Hub: {repo_url}\n" + "To load back the model, use " + f"`keras.saving.load_model('hf://{repo_id}')`" + ) + + +def load_model(filepath, custom_objects=None, compile=True, safe_mode=True): + """Load a zip archive representing a Keras model.""" + if isinstance(filepath, io.IOBase): + return _load_model_from_fileobj( + filepath, custom_objects, compile, safe_mode + ) + elif str(filepath).startswith("hf://"): + if huggingface_hub is None: + raise ImportError( + "To load models from the Hugging Face Hub, " + "you must install the `huggingface_hub` package." + ) + + repo_id = filepath[5:] + folder_path = huggingface_hub.snapshot_download( + repo_id=repo_id, + library_name="keras", + library_version=keras_version, + ) + return _load_model_from_dir( + folder_path, custom_objects, compile, safe_mode + ) + else: + filepath = str(filepath) + if not filepath.endswith(".keras"): + is_keras_dir = file_utils.isdir(filepath) and file_utils.exists( + file_utils.join(filepath, "config.json") + ) + if is_keras_dir: + return _load_model_from_dir( + filepath, custom_objects, compile, safe_mode + ) + raise ValueError( + "Invalid filename: expected a `.keras` extension. " + f"Received: filepath={filepath}" + ) + with open(filepath, "rb") as f: + return _load_model_from_fileobj( + f, custom_objects, compile, safe_mode + ) + + +def _load_model_from_dir(dirpath, custom_objects, compile, safe_mode): + if not file_utils.exists(dirpath): + raise ValueError(f"Directory doesn't exist: {dirpath}") + if not file_utils.isdir(dirpath): + raise ValueError(f"Path isn't a directory: {dirpath}") + + with open(file_utils.join(dirpath, _CONFIG_FILENAME), "r") as f: + config_json = f.read() + model = _model_from_config(config_json, custom_objects, compile, safe_mode) + + all_filenames = file_utils.listdir(dirpath) + try: + if _VARS_FNAME_H5 in all_filenames: + weights_file_path = file_utils.join(dirpath, _VARS_FNAME_H5) + weights_store = H5IOStore(weights_file_path, mode="r") + elif _VARS_FNAME_NPZ in all_filenames: + weights_file_path = file_utils.join(dirpath, _VARS_FNAME_NPZ) + weights_store = NpzIOStore(weights_file_path, mode="r") + else: + raise ValueError( + f"Expected a {_VARS_FNAME_H5} or {_VARS_FNAME_NPZ} file." + ) + if len(all_filenames) > 3: + asset_store = DiskIOStore( + file_utils.join(dirpath, _ASSETS_DIRNAME), mode="r" + ) + + else: + asset_store = None + + failed_saveables = set() + error_msgs = {} + _load_state( + model, + weights_store=weights_store, + assets_store=asset_store, + inner_path="", + visited_saveables=set(), + failed_saveables=failed_saveables, + error_msgs=error_msgs, + ) + + finally: + weights_store.close() + if asset_store: + asset_store.close() + + if failed_saveables: + _raise_loading_failure(error_msgs) + return model + + +def _model_from_config(config_json, custom_objects, compile, safe_mode): + # Note: we should NOT use a custom JSON decoder. Anything that + # needs custom decoding must be handled in deserialize_keras_object. + config_dict = json.loads(config_json) + if not compile: + # Disable compilation + config_dict["compile_config"] = None + # Construct the model from the configuration file in the archive. + with ObjectSharingScope(): + model = deserialize_keras_object( + config_dict, custom_objects, safe_mode=safe_mode + ) + return model + + +def _load_model_from_fileobj(fileobj, custom_objects, compile, safe_mode): + with zipfile.ZipFile(fileobj, "r") as zf: + with zf.open(_CONFIG_FILENAME, "r") as f: + config_json = f.read() + + model = _model_from_config( + config_json, custom_objects, compile, safe_mode + ) + + all_filenames = zf.namelist() + extract_dir = None + weights_store = None + asset_store = None + try: + if _VARS_FNAME_H5 in all_filenames: + try: + if is_memory_sufficient(model): + # Load the entire file into memory if the system memory + # is sufficient. + io_file = io.BytesIO( + zf.open(_VARS_FNAME_H5, "r").read() + ) + weights_store = H5IOStore(io_file, mode="r") + else: + # Try extracting the model.weights.h5 file, and then + # loading it using using h5py. This is significantly + # faster than reading from the zip archive on the fly. + extract_dir = tempfile.TemporaryDirectory( + dir=pathlib.Path(fileobj.name).parent + ) + zf.extract(_VARS_FNAME_H5, extract_dir.name) + weights_store = H5IOStore( + pathlib.Path(extract_dir.name, _VARS_FNAME_H5), + mode="r", + ) + except: + # If we can't use the local disk for any reason, read the + # weights from the zip archive on the fly, which is less + # efficient. + weights_store = H5IOStore(_VARS_FNAME_H5, zf, mode="r") + elif _VARS_FNAME_NPZ in all_filenames: + weights_store = NpzIOStore(_VARS_FNAME_NPZ, zf, mode="r") + else: + raise ValueError( + f"Expected a {_VARS_FNAME_H5} or {_VARS_FNAME_NPZ} file." + ) + + if len(all_filenames) > 3: + asset_store = DiskIOStore(_ASSETS_DIRNAME, archive=zf, mode="r") + + failed_saveables = set() + error_msgs = {} + _load_state( + model, + weights_store=weights_store, + assets_store=asset_store, + inner_path="", + visited_saveables=set(), + failed_saveables=failed_saveables, + error_msgs=error_msgs, + ) + finally: + if weights_store: + weights_store.close() + if asset_store: + asset_store.close() + if extract_dir: + extract_dir.cleanup() + + if failed_saveables: + _raise_loading_failure(error_msgs) + return model + + +def save_weights_only( + model, filepath, max_shard_size=None, objects_to_skip=None +): + """Save only the weights of a model to a target filepath. + + Supports both `.weights.h5` and `.keras`. + """ + if not model.built: + raise ValueError( + "You are saving a model that has not yet been built. " + "Try building the model first by calling it on some data or " + "by using `build()`." + ) + + filepath_str = str(filepath) + tmp_dir = None + remote_filepath = None + if max_shard_size is None and not filepath_str.endswith(".weights.h5"): + raise ValueError( + "The filename must end in `.weights.h5`. " + f"Received: filepath={filepath_str}" + ) + elif max_shard_size is not None and not filepath_str.endswith( + ("weights.h5", "weights.json") + ): + raise ValueError( + "The filename must end in `.weights.json` when `max_shard_size` is " + f"specified. Received: filepath={filepath_str}" + ) + try: + if file_utils.is_remote_path(filepath): + tmp_dir = get_temp_dir() + local_filepath = os.path.join(tmp_dir, os.path.basename(filepath)) + remote_filepath = filepath + filepath = local_filepath + + if max_shard_size is not None: + weights_store = ShardedH5IOStore(filepath, max_shard_size, mode="w") + else: + weights_store = H5IOStore(filepath, mode="w") + if objects_to_skip is not None: + visited_saveables = set(id(o) for o in objects_to_skip) + else: + visited_saveables = set() + _save_state( + model, + weights_store=weights_store, + assets_store=None, + inner_path="", + visited_saveables=visited_saveables, + ) + weights_store.close() + finally: + if tmp_dir is not None: + file_utils.copy(filepath, remote_filepath) + shutil.rmtree(tmp_dir) + + +def load_weights_only( + model, filepath, skip_mismatch=False, objects_to_skip=None +): + """Load the weights of a model from a filepath (.keras or .weights.h5). + + Note: only supports h5 for now. + """ + if not model.built: + raise ValueError( + "You are loading weights into a model that has not yet been built. " + "Try building the model first by calling it on some data or " + "by using `build()`." + ) + + archive = None + tmp_dir = None + filepath_str = str(filepath) + + try: + if file_utils.is_remote_path(filepath_str): + tmp_dir = get_temp_dir() + local_filepath = os.path.join( + tmp_dir, os.path.basename(filepath_str) + ) + file_utils.copy(filepath_str, local_filepath) + filepath_str = filepath = local_filepath + + if filepath_str.endswith("weights.h5"): + weights_store = H5IOStore(filepath, mode="r") + elif filepath_str.endswith("weights.json"): + weights_store = ShardedH5IOStore(filepath, mode="r") + elif filepath_str.endswith(".keras"): + archive = zipfile.ZipFile(filepath, "r") + weights_store = H5IOStore(_VARS_FNAME_H5, archive=archive, mode="r") + + failed_saveables = set() + if objects_to_skip is not None: + visited_saveables = set(id(o) for o in objects_to_skip) + else: + visited_saveables = set() + error_msgs = {} + _load_state( + model, + weights_store=weights_store, + assets_store=None, + inner_path="", + skip_mismatch=skip_mismatch, + visited_saveables=visited_saveables, + failed_saveables=failed_saveables, + error_msgs=error_msgs, + ) + weights_store.close() + if archive: + archive.close() + + if failed_saveables: + _raise_loading_failure(error_msgs, warn_only=skip_mismatch) + finally: + if tmp_dir is not None: + shutil.rmtree(tmp_dir) + + +def _raise_loading_failure(error_msgs, warn_only=False): + first_key = list(error_msgs.keys())[0] + ex_saveable, ex_error = error_msgs[first_key] + msg = ( + f"A total of {len(error_msgs)} objects could not " + "be loaded. Example error message for " + f"object {ex_saveable}:\n\n" + f"{ex_error}\n\n" + "List of objects that could not be loaded:\n" + f"{[x[0] for x in error_msgs.values()]}" + ) + if warn_only: + warnings.warn(msg) + else: + raise ValueError(msg) + + +def _write_to_zip_recursively(zipfile_to_save, system_path, zip_path): + if not file_utils.isdir(system_path): + zipfile_to_save.write(system_path, zip_path) + else: + for file_name in file_utils.listdir(system_path): + system_file_path = file_utils.join(system_path, file_name).replace( + "\\", "/" + ) + zip_file_path = file_utils.join(zip_path, file_name).replace( + "\\", "/" + ) + _write_to_zip_recursively( + zipfile_to_save, system_file_path, zip_file_path + ) + + +def _name_key(name): + """Make sure that private attributes are visited last.""" + if name.startswith("_"): + return f"~{name}" + return name + + +def _walk_saveable(saveable): + from keras.src.saving.keras_saveable import KerasSaveable + + if not isinstance(saveable, KerasSaveable): + raise ValueError( + "Expected object to be an " + "instance of `KerasSaveable`, but " + f"got {saveable} of type {type(saveable)}" + ) + + obj_type = saveable._obj_type() + attr_skipset = get_attr_skipset(obj_type) + + # Save all layers directly tracked by Sequential and Functional first. + # This helps avoid ordering concerns for subclassed Sequential or Functional + # models with extra attributes--the internal Keras state take precedence. + if obj_type in ("Sequential", "Functional"): + yield "layers", saveable.layers + + for child_attr in sorted(dir(saveable), key=lambda x: _name_key(x)): + if child_attr.startswith("__") or child_attr in attr_skipset: + continue + try: + child_obj = getattr(saveable, child_attr) + except Exception: + # Avoid raising the exception when visiting the attributes. + continue + yield child_attr, child_obj + + +def _save_state( + saveable, + weights_store, + assets_store, + inner_path, + visited_saveables, +): + from keras.src.saving.keras_saveable import KerasSaveable + + if not isinstance(weights_store, (H5IOStore, ShardedH5IOStore, NpzIOStore)): + raise ValueError( + "Expected `weights_store` to be an instance of " + "`H5IOStore`, `ShardedH5IOStore` or `NpzIOStore`. " + f"Received: {weights_store} of type {type(weights_store)}" + ) + if not isinstance(assets_store, (DiskIOStore, type(None))): + raise ValueError( + "Expected `assets_store` to be an instance of " + "`DiskIOStore` or `None`. " + f"Received: {assets_store} of type {type(assets_store)}" + ) + + # If the saveable has already been saved, skip it. + if id(saveable) in visited_saveables: + return + + if hasattr(saveable, "save_own_variables") and weights_store: + if hasattr(saveable, "name") and isinstance(saveable.name, str): + metadata = {"name": saveable.name} + else: + metadata = None + saveable.save_own_variables( + weights_store.make(inner_path, metadata=metadata) + ) + if hasattr(saveable, "save_assets") and assets_store: + saveable.save_assets(assets_store.make(inner_path)) + + visited_saveables.add(id(saveable)) + + # Recursively save state of children saveables (layers, optimizers, etc.) + for child_attr, child_obj in _walk_saveable(saveable): + if isinstance(child_obj, KerasSaveable): + _save_state( + child_obj, + weights_store, + assets_store, + inner_path=file_utils.join(inner_path, child_attr).replace( + "\\", "/" + ), + visited_saveables=visited_saveables, + ) + elif isinstance(child_obj, (list, dict, tuple, set)): + _save_container_state( + child_obj, + weights_store, + assets_store, + inner_path=file_utils.join(inner_path, child_attr).replace( + "\\", "/" + ), + visited_saveables=visited_saveables, + ) + + +def _load_state( + saveable, + weights_store, + assets_store, + inner_path, + skip_mismatch=False, + visited_saveables=None, + failed_saveables=None, + error_msgs=None, +): + from keras.src.saving.keras_saveable import KerasSaveable + + if not isinstance(weights_store, (H5IOStore, ShardedH5IOStore, NpzIOStore)): + raise ValueError( + "Expected `weights_store` to be an instance of " + "`H5IOStore`, `ShardedH5IOStore` or `NpzIOStore`. " + f"Received: {weights_store} of type {type(weights_store)}" + ) + if not isinstance(assets_store, (DiskIOStore, type(None))): + raise ValueError( + "Expected `assets_store` to be an instance of " + "`DiskIOStore` or `None`. " + f"Received: {assets_store} of type {type(assets_store)}" + ) + + if visited_saveables and id(saveable) in visited_saveables: + return + + failure = False + + if hasattr(saveable, "load_own_variables") and weights_store: + if skip_mismatch or failed_saveables is not None: + try: + saveable.load_own_variables(weights_store.get(inner_path)) + except Exception as e: + failed_saveables.add(id(saveable)) + error_msgs[id(saveable)] = saveable, e + failure = True + else: + saveable.load_own_variables(weights_store.get(inner_path)) + + if hasattr(saveable, "load_assets") and assets_store: + if skip_mismatch or failed_saveables is not None: + try: + saveable.load_assets(assets_store.get(inner_path)) + except Exception as e: + failed_saveables.add(id(saveable)) + error_msgs[id(saveable)] = saveable, e + failure = True + else: + saveable.load_assets(assets_store.get(inner_path)) + + if failed_saveables is not None: + currently_failed = len(failed_saveables) + else: + currently_failed = 0 + + # Recursively load states for Keras saveables such as layers/optimizers. + for child_attr, child_obj in _walk_saveable(saveable): + if isinstance(child_obj, KerasSaveable): + _load_state( + child_obj, + weights_store, + assets_store, + inner_path=file_utils.join(inner_path, child_attr).replace( + "\\", "/" + ), + skip_mismatch=skip_mismatch, + visited_saveables=visited_saveables, + failed_saveables=failed_saveables, + error_msgs=error_msgs, + ) + elif isinstance(child_obj, (list, dict, tuple, set)): + _load_container_state( + child_obj, + weights_store, + assets_store, + inner_path=file_utils.join(inner_path, child_attr).replace( + "\\", "/" + ), + skip_mismatch=skip_mismatch, + visited_saveables=visited_saveables, + failed_saveables=failed_saveables, + error_msgs=error_msgs, + ) + + if failed_saveables is not None: + newly_failed = len(failed_saveables) - currently_failed + else: + newly_failed = 0 + + if not failure: + if visited_saveables is not None and newly_failed <= 0: + visited_saveables.add(id(saveable)) + if id(saveable) in failed_saveables: + failed_saveables.remove(id(saveable)) + error_msgs.pop(id(saveable)) + + +def _save_container_state( + container, weights_store, assets_store, inner_path, visited_saveables +): + from keras.src.saving.keras_saveable import KerasSaveable + + used_names = {} + if isinstance(container, dict): + container = list(container.values()) + + for saveable in container: + if isinstance(saveable, KerasSaveable): + # Do NOT address the saveable via `saveable.name`, since + # names are usually autogenerated and thus not reproducible + # (i.e. they may vary across two instances of the same model). + name = naming.to_snake_case(saveable.__class__.__name__) + if name in used_names: + used_names[name] += 1 + name = f"{name}_{used_names[name]}" + else: + used_names[name] = 0 + _save_state( + saveable, + weights_store, + assets_store, + inner_path=file_utils.join(inner_path, name).replace("\\", "/"), + visited_saveables=visited_saveables, + ) + + +def _load_container_state( + container, + weights_store, + assets_store, + inner_path, + skip_mismatch, + visited_saveables, + failed_saveables, + error_msgs, +): + from keras.src.saving.keras_saveable import KerasSaveable + + used_names = {} + if isinstance(container, dict): + container = list(container.values()) + + for saveable in container: + if isinstance(saveable, KerasSaveable): + name = naming.to_snake_case(saveable.__class__.__name__) + if name in used_names: + used_names[name] += 1 + name = f"{name}_{used_names[name]}" + else: + used_names[name] = 0 + _load_state( + saveable, + weights_store, + assets_store, + inner_path=file_utils.join(inner_path, name).replace("\\", "/"), + skip_mismatch=skip_mismatch, + visited_saveables=visited_saveables, + failed_saveables=failed_saveables, + error_msgs=error_msgs, + ) + + +class DiskIOStore: + """Asset store backed by disk storage. + + If `archive` is specified, then `root_path` refers to the filename + inside the archive. + + If `archive` is not specified, then `root_path` refers to the full path of + the target directory. + """ + + def __init__(self, root_path, archive=None, mode=None): + self.mode = mode + self.root_path = root_path + self.archive = archive + self.tmp_dir = None + if self.archive: + self.tmp_dir = get_temp_dir() + if self.mode == "r": + self.archive.extractall(path=self.tmp_dir) + self.working_dir = file_utils.join( + self.tmp_dir, self.root_path + ).replace("\\", "/") + if self.mode == "w": + file_utils.makedirs(self.working_dir) + else: + if mode == "r": + self.working_dir = root_path + else: + self.tmp_dir = get_temp_dir() + self.working_dir = file_utils.join( + self.tmp_dir, self.root_path + ).replace("\\", "/") + file_utils.makedirs(self.working_dir) + + def make(self, path): + if not path: + return self.working_dir + path = file_utils.join(self.working_dir, path).replace("\\", "/") + if not file_utils.exists(path): + file_utils.makedirs(path) + return path + + def get(self, path): + if not path: + return self.working_dir + path = file_utils.join(self.working_dir, path).replace("\\", "/") + if file_utils.exists(path): + return path + return None + + def close(self): + if self.mode == "w" and self.archive: + _write_to_zip_recursively( + self.archive, self.working_dir, self.root_path + ) + if self.tmp_dir and file_utils.exists(self.tmp_dir): + file_utils.rmtree(self.tmp_dir) + + +class H5IOStore: + """Numerical variable store backed by HDF5. + + Args: + path_or_io: `str`, `pathlib.Path` or `io.BytesIO` object. The path where + to save the model. + archive: Optional `zipfile.ZipFile` object. If specified, the h5 file + will be saved inside the archive and `path_or_io` will be used as + the filename. + mode: `str`. One of {`"r"`, `"w"`}. The mode to open the h5 file. + Defaults to `"r"`. + """ + + def __init__(self, path_or_io, archive=None, mode="r"): + if mode not in ("w", "r"): + raise ValueError( + f"`mode` should be either 'w' or 'r'. Received: {mode}" + ) + if isinstance(path_or_io, (str, pathlib.Path)): + self.path_or_io = pathlib.Path(path_or_io) + elif isinstance(path_or_io, io.BytesIO): + if archive is not None: + raise ValueError( + "When `path_or_io` is an `io.BytesIO` object, `archive` " + "should be `None`." + ) + self.path_or_io = path_or_io + else: + raise TypeError( + "`path_or_io` should be a `str`, `pathlib.Path` or " + f"`io.BytesIO` object. Received: path_or_io={path_or_io} of " + f"type {type(path_or_io)}." + ) + self.mode = mode + self.archive = archive + self.io_file = None + + # Init H5 file. + self.h5_file = self._get_h5_file(self.path_or_io) + + # Init H5 entry group. + self._h5_entry_path = None + self._h5_entry_group = {} + self._h5_entry_metadata = None + self._h5_entry_initialized = False + + def __bool__(self): + # Delegate `__bool__` to the underlying `h5_file`. Otherwise, Python + # will mistakenly using `__len__` to determine the value. + return self.h5_file.__bool__() + + def _get_h5_file(self, path_or_io, mode=None): + mode = mode or self.mode + if mode not in ("r", "w", "a"): + raise ValueError( + f"`mode` should be either 'r', 'w' or 'a'. Received: {mode}" + ) + if self.archive: + if mode == "w": + self.io_file = io.BytesIO() + else: + self.io_file = self.archive.open(str(path_or_io), "r") + return h5py.File(self.io_file, mode=mode) + else: + return h5py.File(path_or_io, mode=mode) + + def make(self, path, metadata=None): + """Make a new H5 entry group. + + This method is only available in write mode. It defers the creation of + the H5 entry group until `__setitem__` is called, preventing the + creation of empty groups. + + Args: + path: `str`. The variable path. + metadata: Optional `dict`. The metadata to save with the H5 entry + group. Defaults to `None`. + """ + if self.mode != "w": + raise ValueError("`make` is only allowed in write mode.") + if not isinstance(metadata, (dict, type(None))): + raise ValueError( + f"`metadata` should be a dict or `None`. Received: {metadata}" + ) + + self._h5_entry_path = path + if metadata: + self._create_h5_group(path, metadata=metadata) + else: + # Defer to `__setitem__` for H5 group creation to prevent the + # creation of empty groups when the store is unused. + self._h5_entry_group = {} + self._h5_entry_initialized = False + return self + + def get(self, path): + """Get the H5 entry group. + + This method is only available in read mode. + + Args: + path: `str`. The variable path. + """ + if self.mode != "r": + raise ValueError("`get` is only allowed in read mode.") + + self._h5_entry_path = path + self._h5_entry_group = {} # Defaults to an empty dict if not found. + if not path: + if "vars" in self.h5_file: + self._h5_entry_group = self.h5_file["vars"] + elif path in self.h5_file and "vars" in self.h5_file[path]: + self._h5_entry_group = self.h5_file[path]["vars"] + else: + # No hit. Fix for 2.13 compatibility. + if "_layer_checkpoint_dependencies" in self.h5_file: + path = path.replace("layers", "_layer_checkpoint_dependencies") + if path in self.h5_file and "vars" in self.h5_file[path]: + self._h5_entry_group = self.h5_file[path]["vars"] + self._h5_entry_initialized = True + return self + + def close(self): + self.h5_file.close() + if self.mode == "w" and self.archive: + self.archive.writestr(str(self.path_or_io), self.io_file.getvalue()) + if self.io_file: + self.io_file.close() + + # H5 entry level methods. + + def _create_h5_group(self, path, metadata=None): + if not path: + self._h5_entry_group = self.h5_file.create_group("vars") + else: + self._h5_entry_group = self.h5_file.create_group(path).create_group( + "vars" + ) + if metadata: + for k, v in metadata.items(): + self._h5_entry_group.attrs[k] = v + + self._h5_entry_initialized = True + + def __len__(self): + return self._h5_entry_group.__len__() + + def keys(self): + return self._h5_entry_group.keys() + + def items(self): + return self._h5_entry_group.items() + + def values(self): + return self._h5_entry_group.values() + + def __getitem__(self, key): + value = self._h5_entry_group[key] + if ( + hasattr(value, "attrs") + and "dtype" in value.attrs + and value.attrs["dtype"] == "bfloat16" + ): + value = np.array(value, dtype=ml_dtypes.bfloat16) + elif ( + hasattr(value, "shape") + and hasattr(value, "dtype") + and not isinstance(value, np.ndarray) + ): + value = np.array(value) + return value + + def __setitem__(self, key, value): + if self.mode not in ("w", "a"): + raise ValueError("Setting a value is only allowed in write mode.") + if not self._h5_entry_initialized: + self._create_h5_group(self._h5_entry_path) + + value = backend.convert_to_numpy(value) + if backend.standardize_dtype(value.dtype) == "bfloat16": + ds = self._h5_entry_group.create_dataset(key, data=value) + ds.attrs["dtype"] = "bfloat16" + else: + self._h5_entry_group[key] = value + + def __delitem__(self, key): + if self.mode not in ("w", "a"): + raise ValueError("Deleting a value is only allowed in write mode.") + del self._h5_entry_group[key] + + def __contains__(self, item): + return item in self._h5_entry_group + + +class ShardedH5IOStore(H5IOStore): + """Sharded numerical variable store backed by HDF5. + + Args: + path_or_io: `str` or `pathlib.Path` object. The path where to save the + model. + max_shard_size: `int` or `float`. Maximum size in GB for each sharded + file. If `None`, no sharding will be done. Defaults to `None`. + archive: Optional `zipfile.ZipFile` object. If specified, the h5 file + will be saved inside the archive and `path_or_io` will be used as + the filename. + mode: `str`. One of {'r', 'w'}. The mode to open the h5 file. Defaults + to `"r"`. + """ + + def __init__(self, path_or_io, max_shard_size=5, archive=None, mode="r"): + if mode not in ("w", "r"): + raise ValueError( + f"`mode` should be either 'w' or 'r'. Received: {mode}" + ) + if not isinstance(path_or_io, (str, pathlib.Path)): + raise TypeError( + "`path_or_io` should be a `str`, `pathlib.Path` object. " + f"Received: path_or_io={path_or_io} of type {type(path_or_io)}." + ) + self.path = pathlib.Path(path_or_io) + self.mode = mode + self.archive = archive + self.io_file = None + + self.max_shard_size = float(max_shard_size) * 1024**3 # To bytes. + self.base_name = self.path.stem.replace(".weights", "") + + if self.path.suffix != ".json": + method = "Saving" if self.mode == "w" else "Loading" + new_path = self.path.with_suffix(".json") + warnings.warn( + f"{method} sharded weights requires `*.json` as the " + f"extension. The original path: {str(self.path)} will be " + f"renamed to {str(new_path)}." + ) + self.path = new_path + + # Init H5 entry group. + self._h5_entry_path = None + self._h5_entry_group = {} + self._h5_entry_metadata = None + self._h5_entry_initialized = False + + # Init shard parameters. + self.current_shard_index = 0 + self.current_shard_size = 0 + self.total_shard_size = 0 # In bytes. + self.current_shard_path = None + self.current_shard_filenames = [] + if self.mode == "w": + self.sharding_config = { + "metadata": { + "total_size": 0, + }, + "weight_map": {}, + } + else: + if self.archive: + self.sharding_config = json.loads( + self.archive.open(str(self.path), "r").read() + ) + else: + with open(self.path, "r") as map_file: + self.sharding_config = json.load(map_file) + self.h5_file = self._create_new_shard_file() + + def make(self, path, metadata=None): + """Make a new H5 entry group. + + This method is only available in write mode. It defers the creation of + the H5 entry group until `__setitem__` is called, preventing the + creation of empty groups. + + The information about the current shard is reset. + + Args: + path: `str`. The variable path. + metadata: Optional `dict`. The metadata to save with the H5 entry + group. Defaults to `None`. + """ + self.current_shard_filenames = [] + if self.h5_file is not None: + self.current_shard_filenames.append( + pathlib.Path(self.h5_file.filename).name + ) + return super().make(path, metadata) + + def get(self, path): + """Get the H5 entry group. + + This method is only available in read mode. If the path is not found in + the current shard, it will switch to the correct shard. + + Args: + path: `str`. The variable path. + """ + if not path: + parsed_path = "/vars" + else: + parsed_path = path + + # If not found, check shard map and switch files. + weight_map = self.sharding_config["weight_map"] + filenames = weight_map.get(parsed_path) or weight_map.get( + f"/{parsed_path}/vars" + ) + if filenames is not None: + if not isinstance(filenames, list): + filenames = [filenames] + self.current_shard_filenames = filenames + filename = filenames[0] + else: + self.current_shard_filenames = [] + filename = None + + if filename is not None and filename != self.current_shard_path.name: + self.close() + self.h5_file = self._get_h5_file(self.path.with_name(filename)) + return super().get(path) + + def close(self): + if self.h5_file is not None: + self.h5_file.close() + self.h5_file = None + if self.mode == "w": + self.sharding_config["metadata"]["total_size"] = ( + self.total_shard_size + ) + json_str = json.dumps(self.sharding_config, indent=4) + if self.archive: + self.archive.writestr(str(self.path), json_str) + self.archive.writestr( + str(self.current_shard_path), self.io_file.getvalue() + ) + else: + with open(self.path, "w") as f: + f.write(json_str) + if self.io_file: + self.io_file.close() + + # Shard-specific methods. + + def _create_new_shard_file(self): + """Create a new shard file and return the H5 file object.""" + new_shard_path = ( + f"{self.base_name}_{self.current_shard_index:05}.weights.h5" + ) + self.current_shard_index += 1 + self.current_shard_path = self.path.with_name(new_shard_path) + h5_file = self._get_h5_file(self.current_shard_path) + self.current_shard_filenames.append(pathlib.Path(h5_file.filename).name) + self._h5_entry_initialized = False + return h5_file + + def _switch_h5_file(self, filename, mode): + """Switch to a different H5 file with the specified mode. + + This is useful for retrieving information from all shards, such as the + total length, keys, and items. + """ + if mode not in ("r", "a"): + raise ValueError( + f"`mode` should be either 'r' or 'a'. Received: {mode}" + ) + self.close() + self.h5_file = self._get_h5_file( + self.path.with_name(filename), mode=mode + ) + self._get_h5_group(self._h5_entry_path) + + def _restore_h5_file(self): + """Ensure the current shard is the last one created. + + We use mode="a" to avoid truncating the file during the switching. + """ + if ( + pathlib.Path(self.h5_file.filename).name + != self.current_shard_path.name + ): + self._switch_h5_file(self.current_shard_path.name, mode="a") + + # H5 entry level methods. + + def _get_h5_group(self, path): + """Get the H5 entry group. If it doesn't exist, return an empty dict.""" + try: + if not path: + self._h5_entry_group = self.h5_file["vars"] + else: + self._h5_entry_group = self.h5_file[path]["vars"] + self._h5_entry_initialized = True + except KeyError: + self._h5_entry_group = {} + self._h5_entry_initialized = False + + # Dict methods. + + def __len__(self): + total_len = self._h5_entry_group.__len__() + for filename in self.current_shard_filenames: + if filename == self.current_shard_path.name: + continue + self._switch_h5_file(filename, mode="r") + total_len += self._h5_entry_group.__len__() + self._restore_h5_file() + return total_len + + def keys(self): + keys = set(self._h5_entry_group.keys()) + for filename in self.current_shard_filenames: + if filename == self.current_shard_path.name: + continue + self._switch_h5_file(filename, mode="r") + keys.update(self._h5_entry_group.keys()) + self._restore_h5_file() + return keys + + def items(self): + yield from self._h5_entry_group.items() + for filename in self.current_shard_filenames: + if filename == self.current_shard_path.name: + continue + self._switch_h5_file(filename, mode="r") + yield from self._h5_entry_group.items() + self._restore_h5_file() + + def values(self): + yield from self._h5_entry_group.values() + for filename in self.current_shard_filenames: + if filename == self.current_shard_path.name: + continue + self._switch_h5_file(filename, mode="r") + yield from self._h5_entry_group.values() + self._restore_h5_file() + + def __getitem__(self, key): + if key in self._h5_entry_group: + return super().__getitem__(key) + + for filename in self.current_shard_filenames: + if filename == self.current_shard_path.name: + continue + self._switch_h5_file(filename, mode="r") + if key in self._h5_entry_group: + item = super().__getitem__(key) + self._restore_h5_file() + return item + raise KeyError( + f"Key '{key}' not found in any of the shards: " + f"{self.current_shard_filenames}" + ) + + def __setitem__(self, key, value): + self._restore_h5_file() + + # Accumulate `current_shard_size`. + value = backend.convert_to_numpy(value) + dtype = backend.standardize_dtype(value.dtype) + weight_counts = math.prod(value.shape) + per_param_size = dtype_utils.dtype_size(dtype) + value_size = weight_counts * per_param_size / 8 # In bytes. + self.total_shard_size += value_size + if value_size > self.max_shard_size: + value_size_str = readable_memory_size(value_size) + max_shard_size_str = readable_memory_size(self.max_shard_size) + raise ValueError( + f"The size of {key} is {value_size_str} which " + f"exceeds the maximum shard size {max_shard_size_str}. You " + "can increase the `max_shard_size` parameter to accommodate " + "the size." + ) + + # Create a new shard if the current shard is full. + self.current_shard_size += value_size + if self.current_shard_size > self.max_shard_size: + self.close() + self.h5_file = self._create_new_shard_file() + self.current_shard_size = value_size + + super().__setitem__(key, value) + + # Update the weight map. + variable_path = self._h5_entry_group.name + shard_filename = self.current_shard_path.name + weight_map = self.sharding_config["weight_map"] + if variable_path not in weight_map: + weight_map[variable_path] = shard_filename + else: + if not isinstance(weight_map[variable_path], list): + weight_map[variable_path] = [weight_map[variable_path]] + if shard_filename not in weight_map[variable_path]: + weight_map[variable_path].append(shard_filename) + + def __delitem__(self, key): + if key in self._h5_entry_group: + super().__delitem__(key) + return + + for filename in self.current_shard_filenames: + if filename == self.current_shard_path.name: + continue + self._switch_h5_file(filename, mode="a") + if key in self._h5_entry_group: + super().__delitem__(key) + self._restore_h5_file() + return + raise KeyError( + f"Key '{key}' not found in any of the shards: " + f"{self.current_shard_filenames}" + ) + + def __contains__(self, item): + if item in self._h5_entry_group: + return True + + for filename in self.current_shard_filenames: + if filename == self.current_shard_path.name: + continue + self._switch_h5_file(filename, mode="r") + if item in self._h5_entry_group: + self._restore_h5_file() + return True + self._restore_h5_file() + return False + + +class NpzIOStore: + def __init__(self, root_path, archive=None, mode="r"): + """Numerical variable store backed by NumPy.savez/load. + + If `archive` is specified, then `root_path` refers to the filename + inside the archive. + + If `archive` is not specified, then `root_path` refers to the path of + the npz file on disk. + """ + self.root_path = root_path + self.mode = mode + self.archive = archive + if mode == "w": + self.contents = {} + else: + if self.archive: + self.f = archive.open(root_path, mode="r") + else: + self.f = open(root_path, mode="rb") + self.contents = np.load(self.f) + + def make(self, path, metadata=None): + if not path: + self.contents["__root__"] = {} + return self.contents["__root__"] + self.contents[path] = {} + return self.contents[path] + + def get(self, path): + if not path: + if "__root__" in self.contents: + return dict(self.contents["__root__"]) + return {} + if path in self.contents: + return self.contents[path].tolist() + return {} + + def close(self): + if self.mode == "w": + if self.archive: + self.f = self.archive.open( + self.root_path, mode="w", force_zip64=True + ) + else: + self.f = open(self.root_path, mode="wb") + np.savez(self.f, **self.contents) + self.f.close() + + +def get_temp_dir(): + temp_dir = tempfile.mkdtemp() + testfile = tempfile.TemporaryFile(dir=temp_dir) + testfile.close() + return temp_dir + + +def get_attr_skipset(obj_type): + skipset = global_state.get_global_attribute( + f"saving_attr_skiplist_{obj_type}", None + ) + if skipset is not None: + return skipset + + skipset = set( + [ + "_self_unconditional_dependency_names", + ] + ) + if obj_type == "Operation": + from keras.src.ops.operation import Operation + + ref_obj = Operation() + skipset.update(dir(ref_obj)) + elif obj_type == "Layer": + from keras.src.layers.layer import Layer + + ref_obj = Layer() + skipset.update(dir(ref_obj)) + elif obj_type == "Functional": + from keras.src.layers.layer import Layer + + ref_obj = Layer() + skipset.update(dir(ref_obj) + ["operations", "_operations"]) + elif obj_type == "Sequential": + from keras.src.layers.layer import Layer + + ref_obj = Layer() + skipset.update(dir(ref_obj) + ["_functional"]) + elif obj_type == "Metric": + from keras.src.metrics.metric import Metric + from keras.src.trainers.compile_utils import CompileMetrics + + ref_obj_a = Metric() + ref_obj_b = CompileMetrics([], []) + skipset.update(dir(ref_obj_a) + dir(ref_obj_b)) + elif obj_type == "Optimizer": + from keras.src.optimizers.optimizer import Optimizer + + ref_obj = Optimizer(1.0) + skipset.update(dir(ref_obj)) + skipset.remove("variables") + elif obj_type == "Loss": + from keras.src.losses.loss import Loss + + ref_obj = Loss() + skipset.update(dir(ref_obj)) + elif obj_type == "Cross": + from keras.src.layers.preprocessing.feature_space import Cross + + ref_obj = Cross((), 1) + skipset.update(dir(ref_obj)) + elif obj_type == "Feature": + from keras.src.layers.preprocessing.feature_space import Feature + + ref_obj = Feature("int32", lambda x: x, "int") + skipset.update(dir(ref_obj)) + else: + raise ValueError( + f"get_attr_skipset got invalid {obj_type=}. " + "Accepted values for `obj_type` are " + "['Operation', 'Layer', 'Functional', 'Sequential', 'Metric', " + "'Optimizer', 'Loss', 'Cross', 'Feature']" + ) + + global_state.set_global_attribute( + f"saving_attr_skipset_{obj_type}", skipset + ) + return skipset + + +def is_memory_sufficient(model): + """Check if there is sufficient memory to load the model into memory. + + If psutil is installed, we can use it to determine whether the memory is + sufficient. Otherwise, we use a predefined value of 1 GB for available + memory. + """ + if psutil is None: + available_memory = 1024 * 1024 * 1024 # 1 GB in bytes + else: + available_memory = psutil.virtual_memory().available # In bytes + return ( + weight_memory_size(model.variables) + < available_memory * _MEMORY_UPPER_BOUND + ) diff --git a/keras/src/saving/saving_lib_test.py b/keras/src/saving/saving_lib_test.py new file mode 100644 index 000000000000..2aef81f66ea3 --- /dev/null +++ b/keras/src/saving/saving_lib_test.py @@ -0,0 +1,1381 @@ +"""Tests for Keras python-based idempotent saving functions.""" + +import json +import os +import warnings +import zipfile +from io import BytesIO +from pathlib import Path +from unittest import mock + +import numpy as np +import pytest +from absl.testing import parameterized + +import keras +from keras.src import backend +from keras.src import ops +from keras.src import testing +from keras.src.saving import saving_lib + + +@keras.saving.register_keras_serializable(package="my_custom_package") +class MyDense(keras.layers.Layer): + def __init__(self, units, **kwargs): + super().__init__(**kwargs) + self.units = units + self.nested_layer = keras.layers.Dense(self.units, name="dense") + + def build(self, input_shape): + self.additional_weights = [ + self.add_weight( + shape=(), + name="my_additional_weight", + initializer="ones", + trainable=True, + ), + self.add_weight( + shape=(), + name="my_additional_weight_2", + initializer="ones", + trainable=True, + ), + ] + self.weights_in_dict = { + "my_weight": self.add_weight( + shape=(), + name="my_dict_weight", + initializer="ones", + trainable=True, + ), + } + self.nested_layer.build(input_shape) + + def call(self, inputs): + return self.nested_layer(inputs) + + def two(self): + return 2 + + +ASSETS_DATA = "These are my assets" +VARIABLES_DATA = np.random.random((10,)) + + +@keras.saving.register_keras_serializable(package="my_custom_package") +class LayerWithCustomSaving(MyDense): + def build(self, input_shape): + self.assets = ASSETS_DATA + self.stored_variables = VARIABLES_DATA + return super().build(input_shape) + + def save_assets(self, inner_path): + with open(os.path.join(inner_path, "assets.txt"), "w") as f: + f.write(self.assets) + + def save_own_variables(self, store): + store["variables"] = self.stored_variables + + def load_assets(self, inner_path): + with open(os.path.join(inner_path, "assets.txt"), "r") as f: + text = f.read() + self.assets = text + + def load_own_variables(self, store): + self.stored_variables = np.array(store["variables"]) + + +@keras.saving.register_keras_serializable(package="my_custom_package") +class CustomModelX(keras.Model): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.dense1 = MyDense(1, name="my_dense_1") + self.dense2 = MyDense(1, name="my_dense_2") + + def call(self, inputs): + out = self.dense1(inputs) + return self.dense2(out) + + def one(self): + return 1 + + +@keras.saving.register_keras_serializable(package="my_custom_package") +class ModelWithCustomSaving(keras.Model): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.custom_dense = LayerWithCustomSaving(1) + + def call(self, inputs): + return self.custom_dense(inputs) + + +@keras.saving.register_keras_serializable(package="my_custom_package") +class CompileOverridingModel(keras.Model): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.dense1 = MyDense(1) + + def compile(self, *args, **kwargs): + super().compile(*args, **kwargs) + + def call(self, inputs): + return self.dense1(inputs) + + +@keras.saving.register_keras_serializable(package="my_custom_package") +class CompileOverridingSequential(keras.Sequential): + def compile(self, *args, **kwargs): + super().compile(*args, **kwargs) + + +@keras.saving.register_keras_serializable(package="my_custom_package") +class SubclassFunctional(keras.Model): + """Subclassed functional identical to `_get_basic_functional_model`.""" + + def __init__(self, **kwargs): + inputs = keras.Input(shape=(4,), batch_size=2) + dense = keras.layers.Dense(1, name="first_dense") + x = dense(inputs) + outputs = keras.layers.Dense(1, name="second_dense")(x) + super().__init__(inputs=inputs, outputs=outputs, **kwargs) + # Attrs for layers in the functional graph should not affect saving + self.layer_attr = dense + + @property + def layer_property(self): + # Properties for layers in the functional graph should not affect saving + return self.layer_attr + + def get_config(self): + return {} + + @classmethod + def from_config(cls, config): + return cls(**config) + + +@keras.saving.register_keras_serializable(package="my_custom_package") +def my_mean_squared_error(y_true, y_pred): + """Identical to built-in `mean_squared_error`, but as a custom fn.""" + return ops.mean(ops.square(y_pred - y_true), axis=-1) + + +def _get_subclassed_model(compile=True): + subclassed_model = CustomModelX(name="custom_model_x") + if compile: + subclassed_model.compile( + optimizer="adam", + loss=my_mean_squared_error, + metrics=[keras.metrics.Hinge(), "mse"], + ) + return subclassed_model + + +def _get_custom_sequential_model(compile=True): + sequential_model = keras.Sequential( + [MyDense(1), MyDense(1)], name="sequential" + ) + if compile: + sequential_model.compile( + optimizer="adam", + loss=my_mean_squared_error, + metrics=[keras.metrics.Hinge(), "mse"], + ) + return sequential_model + + +def _get_basic_sequential_model(compile=True): + sequential_model = keras.Sequential( + [ + keras.layers.Dense(1, name="dense_1"), + keras.layers.Dense(1, name="dense_2"), + ], + name="sequential", + ) + if compile: + sequential_model.compile( + optimizer="adam", + loss=my_mean_squared_error, + metrics=[keras.metrics.Hinge(), "mse"], + ) + return sequential_model + + +def _get_custom_functional_model(compile=True): + inputs = keras.Input(shape=(4,), batch_size=2) + x = MyDense(1, name="first_dense")(inputs) + outputs = MyDense(1, name="second_dense")(x) + functional_model = keras.Model(inputs, outputs) + if compile: + functional_model.compile( + optimizer="adam", + loss=my_mean_squared_error, + metrics=[keras.metrics.Hinge(), "mse"], + ) + return functional_model + + +def _get_basic_functional_model(compile=True): + inputs = keras.Input(shape=(4,), batch_size=2) + x = keras.layers.Dense(1, name="first_dense")(inputs) + outputs = keras.layers.Dense(1, name="second_dense")(x) + functional_model = keras.Model(inputs, outputs) + if compile: + functional_model.compile( + optimizer="adam", + loss=my_mean_squared_error, + metrics=[keras.metrics.Hinge(), "mse"], + ) + return functional_model + + +def _get_subclassed_functional_model(compile=True): + functional_model = SubclassFunctional() + if compile: + functional_model.compile( + optimizer="adam", + loss=my_mean_squared_error, + metrics=[keras.metrics.Hinge(), "mse"], + ) + return functional_model + + +# We need a global function for `Pool.apply_async` +def _load_model_fn(filepath): + saving_lib.load_model(filepath) + + +class SavingTest(testing.TestCase): + def setUp(self): + # Set `_MEMORY_UPPER_BOUND` to zero for testing purpose. + self.original_value = saving_lib._MEMORY_UPPER_BOUND + saving_lib._MEMORY_UPPER_BOUND = 0 + return super().setUp() + + def tearDown(self): + saving_lib._MEMORY_UPPER_BOUND = self.original_value + return super().tearDown() + + def _test_inference_after_instantiation(self, model): + x_ref = np.random.random((2, 4)) + y_ref = model(x_ref) + temp_filepath = os.path.join(self.get_temp_dir(), "my_model.keras") + model.save(temp_filepath) + + loaded_model = saving_lib.load_model(temp_filepath) + self.assertFalse(model.compiled) + for w_ref, w in zip(model.variables, loaded_model.variables): + self.assertAllClose(w_ref, w) + self.assertAllClose(y_ref, loaded_model(x_ref)) + + @parameterized.named_parameters( + ("subclassed", _get_subclassed_model), + ("basic_sequential", _get_basic_sequential_model), + ("basic_functional", _get_basic_functional_model), + ("custom_sequential", _get_custom_sequential_model), + ("custom_functional", _get_custom_functional_model), + ("subclassed_functional", _get_subclassed_functional_model), + ) + def test_inference_after_instantiation(self, model_fn): + model = model_fn(compile=False) + self._test_inference_after_instantiation(model) + + # Test small model path + saving_lib._MEMORY_UPPER_BOUND = 1.0 + self._test_inference_after_instantiation(model) + + def _test_compile_preserved(self, model): + x_ref = np.random.random((2, 4)) + y_ref = np.random.random((2, 1)) + + model.fit(x_ref, y_ref) + out_ref = model(x_ref) + ref_metrics = model.evaluate(x_ref, y_ref) + temp_filepath = os.path.join(self.get_temp_dir(), "my_model.keras") + model.save(temp_filepath) + + loaded_model = saving_lib.load_model(temp_filepath) + self.assertTrue(model.compiled) + self.assertTrue(loaded_model.built) + for w_ref, w in zip(model.variables, loaded_model.variables): + self.assertAllClose(w_ref, w) + self.assertAllClose(out_ref, loaded_model(x_ref)) + + self.assertEqual( + model.optimizer.__class__, loaded_model.optimizer.__class__ + ) + self.assertEqual( + model.optimizer.get_config(), loaded_model.optimizer.get_config() + ) + for w_ref, w in zip( + model.optimizer.variables, loaded_model.optimizer.variables + ): + self.assertAllClose(w_ref, w) + + new_metrics = loaded_model.evaluate(x_ref, y_ref) + for ref_m, m in zip(ref_metrics, new_metrics): + self.assertAllClose(ref_m, m) + + @parameterized.named_parameters( + ("subclassed", _get_subclassed_model), + ("basic_sequential", _get_basic_sequential_model), + ("basic_functional", _get_basic_functional_model), + ("custom_sequential", _get_custom_sequential_model), + ("custom_functional", _get_custom_functional_model), + ("subclassed_functional", _get_subclassed_functional_model), + ) + @pytest.mark.requires_trainable_backend + def test_compile_preserved(self, model_fn): + model = model_fn(compile=True) + self._test_compile_preserved(model) + + # Test small model path + saving_lib._MEMORY_UPPER_BOUND = 1.0 + self._test_compile_preserved(model) + + def test_saving_preserve_unbuilt_state(self): + temp_filepath = os.path.join(self.get_temp_dir(), "my_model.keras") + subclassed_model = CustomModelX() + subclassed_model.save(temp_filepath) + loaded_model = saving_lib.load_model(temp_filepath) + self.assertEqual(subclassed_model.compiled, loaded_model.compiled) + self.assertFalse(subclassed_model.built) + self.assertFalse(loaded_model.built) + + @pytest.mark.requires_trainable_backend + def test_saved_module_paths_and_class_names(self): + temp_filepath = os.path.join(self.get_temp_dir(), "my_model.keras") + subclassed_model = _get_subclassed_model() + x = np.random.random((100, 32)) + y = np.random.random((100, 1)) + subclassed_model.fit(x, y, epochs=1) + subclassed_model.save(temp_filepath) + + with zipfile.ZipFile(temp_filepath, "r") as z: + with z.open(saving_lib._CONFIG_FILENAME, "r") as c: + config_json = c.read() + config_dict = json.loads(config_json) + self.assertEqual( + config_dict["registered_name"], "my_custom_package>CustomModelX" + ) + self.assertEqual( + config_dict["compile_config"]["optimizer"], + keras.src.saving.serialize_keras_object( + keras.src.optimizers.get("adam") + ), + ) + self.assertEqual( + config_dict["compile_config"]["loss"]["config"], + "my_custom_package>my_mean_squared_error", + ) + + @pytest.mark.requires_trainable_backend + def test_saving_custom_assets_and_variables(self): + temp_filepath = os.path.join(self.get_temp_dir(), "my_model.keras") + model = ModelWithCustomSaving() + model.compile( + optimizer="adam", + loss="mse", + ) + x = np.random.random((100, 32)) + y = np.random.random((100, 1)) + model.fit(x, y, epochs=1) + + # Assert that the archive has not been saved. + self.assertFalse(os.path.exists(temp_filepath)) + + model.save(temp_filepath) + + loaded_model = saving_lib.load_model(temp_filepath) + self.assertEqual(loaded_model.custom_dense.assets, ASSETS_DATA) + self.assertEqual( + loaded_model.custom_dense.stored_variables.tolist(), + VARIABLES_DATA.tolist(), + ) + + def _test_compile_overridden_warnings(self, model_type): + temp_filepath = os.path.join(self.get_temp_dir(), "my_model.keras") + model = ( + CompileOverridingModel() + if model_type == "subclassed" + else CompileOverridingSequential( + [keras.layers.Embedding(4, 1), MyDense(1), MyDense(1)] + ) + ) + model.compile("sgd", "mse") + model.save(temp_filepath) + + with mock.patch.object(warnings, "warn") as mock_warn: + saving_lib.load_model(temp_filepath) + if not mock_warn.call_args_list: + raise AssertionError("Did not warn.") + self.assertIn( + "`compile()` was not called as part of model loading " + "because the model's `compile()` method is custom. ", + mock_warn.call_args_list[0][0][0], + ) + + def test_compile_overridden_warnings_sequential(self): + self._test_compile_overridden_warnings("sequential") + + def test_compile_overridden_warnings_subclassed(self): + self._test_compile_overridden_warnings("subclassed") + + def test_metadata(self): + temp_filepath = Path( + os.path.join(self.get_temp_dir(), "my_model.keras") + ) + model = CompileOverridingModel() + model.save(temp_filepath) + with zipfile.ZipFile(temp_filepath, "r") as z: + with z.open(saving_lib._METADATA_FILENAME, "r") as c: + metadata_json = c.read() + metadata = json.loads(metadata_json) + self.assertIn("keras_version", metadata) + self.assertIn("date_saved", metadata) + + # def test_gfile_copy_local_called(self): + # temp_filepath = Path( + # os.path.join(self.get_temp_dir(), "my_model.keras") + # ) + # model = CompileOverridingModel() + # with mock.patch( + # "re.match", autospec=True + # ) as mock_re_match, mock.patch( + # "tensorflow.compat.v2.io.file_utils.copy", autospec=True + # ) as mock_copy: + # # Mock Remote Path check to true to test gfile copy logic + # mock_re_match.return_value = True + # model.save(temp_filepath) + # mock_re_match.assert_called() + # mock_copy.assert_called() + # self.assertIn(str(temp_filepath), mock_re_match.call_args.args) + # self.assertIn(str(temp_filepath), mock_copy.call_args.args) + + def test_save_load_weights_only(self): + temp_filepath = Path( + os.path.join(self.get_temp_dir(), "mymodel.weights.h5") + ) + model = _get_basic_functional_model() + ref_input = np.random.random((2, 4)) + ref_output = model.predict(ref_input) + saving_lib.save_weights_only(model, temp_filepath) + model = _get_basic_functional_model() + saving_lib.load_weights_only(model, temp_filepath) + self.assertAllClose(model.predict(ref_input), ref_output, atol=1e-6) + # Test with Model method + model = _get_basic_functional_model() + model.load_weights(temp_filepath) + self.assertAllClose(model.predict(ref_input), ref_output, atol=1e-6) + + def test_save_weights_only_with_unbuilt_model(self): + temp_filepath = Path( + os.path.join(self.get_temp_dir(), "mymodel.weights.h5") + ) + model = _get_subclassed_model() + with self.assertRaisesRegex( + ValueError, "You are saving a model that has not yet been built." + ): + saving_lib.save_weights_only(model, temp_filepath) + + def test_load_weights_only_with_unbuilt_model(self): + temp_filepath = Path( + os.path.join(self.get_temp_dir(), "mymodel.weights.h5") + ) + model = _get_subclassed_model() + x = np.random.random((100, 32)) + _ = model.predict(x) # Build the model by calling it on some data + saving_lib.save_weights_only(model, temp_filepath) + saving_lib.load_weights_only(model, temp_filepath) + + new_model = _get_subclassed_model() + with self.assertRaisesRegex( + ValueError, + "You are loading weights into a model that has not yet been built.", + ): + saving_lib.load_weights_only(new_model, temp_filepath) + + def test_load_weights_only_with_keras_file(self): + # Test loading weights from whole saved model + temp_filepath = Path(os.path.join(self.get_temp_dir(), "mymodel.keras")) + model = _get_basic_functional_model() + ref_input = np.random.random((2, 4)) + ref_output = model.predict(ref_input) + saving_lib.save_model(model, temp_filepath) + model = _get_basic_functional_model() + saving_lib.load_weights_only(model, temp_filepath) + self.assertAllClose(model.predict(ref_input), ref_output, atol=1e-6) + # Test with Model method + model = _get_basic_functional_model() + model.load_weights(temp_filepath) + self.assertAllClose(model.predict(ref_input), ref_output, atol=1e-6) + + def test_save_weights_subclassed_functional(self): + # The subclassed and basic functional model should have the same + # weights structure. + temp_filepath = Path( + os.path.join(self.get_temp_dir(), "mymodel.weights.h5") + ) + model = _get_basic_functional_model() + ref_input = np.random.random((2, 4)) + ref_output = model.predict(ref_input) + # Test saving basic, loading subclassed. + saving_lib.save_weights_only(model, temp_filepath) + model = _get_subclassed_functional_model() + saving_lib.load_weights_only(model, temp_filepath) + self.assertAllClose(model.predict(ref_input), ref_output, atol=1e-6) + # Test saving subclassed, loading basic. + saving_lib.save_weights_only(model, temp_filepath) + model = _get_basic_functional_model() + saving_lib.load_weights_only(model, temp_filepath) + self.assertAllClose(model.predict(ref_input), ref_output, atol=1e-6) + + @pytest.mark.requires_trainable_backend + def test_compile_arg(self): + temp_filepath = os.path.join(self.get_temp_dir(), "mymodel.keras") + model = _get_basic_functional_model() + model.compile("sgd", "mse") + model.fit(np.random.random((2, 4)), np.random.random((2, 1))) + saving_lib.save_model(model, temp_filepath) + + model = saving_lib.load_model(temp_filepath) + self.assertEqual(model.compiled, True) + model = saving_lib.load_model(temp_filepath, compile=False) + self.assertEqual(model.compiled, False) + + # def test_overwrite(self): + # temp_filepath = os.path.join(self.get_temp_dir(), "mymodel.keras") + # model = _get_basic_functional_model() + # model.save(temp_filepath) + # model.save(temp_filepath, overwrite=True) + # with self.assertRaises(EOFError): + # model.save(temp_filepath, overwrite=False) + + # temp_filepath = os.path.join( + # self.get_temp_dir(), "mymodel.weights.h5" + # ) + # model = _get_basic_functional_model() + # model.save_weights(temp_filepath) + # model.save_weights(temp_filepath, overwrite=True) + # with self.assertRaises(EOFError): + # model.save_weights(temp_filepath, overwrite=False) + + def test_partial_load(self): + temp_filepath = os.path.join(self.get_temp_dir(), "mymodel.keras") + original_model = keras.Sequential( + [ + keras.Input(shape=(3,), batch_size=2), + keras.layers.Dense(4), + keras.layers.Dense(5), + ] + ) + original_model.save(temp_filepath) + + # Test with a model that has a differently shaped layer + new_model = keras.Sequential( + [ + keras.Input(shape=(3,), batch_size=2), + keras.layers.Dense(4), + keras.layers.Dense(6), + ] + ) + new_layer_kernel_value = np.array(new_model.layers[1].kernel) + with self.assertRaisesRegex(ValueError, "must match"): + # Doesn't work by default + new_model.load_weights(temp_filepath) + # Now it works + new_model.load_weights(temp_filepath, skip_mismatch=True) + ref_weights = original_model.layers[0].get_weights() + new_weights = new_model.layers[0].get_weights() + self.assertEqual(len(ref_weights), len(new_weights)) + for ref_w, w in zip(ref_weights, new_weights): + self.assertAllClose(ref_w, w) + self.assertAllClose( + np.array(new_model.layers[1].kernel), new_layer_kernel_value + ) + + # Test with a model that has a new layer at the end + new_model = keras.Sequential( + [ + keras.Input(shape=(3,), batch_size=2), + keras.layers.Dense(4), + keras.layers.Dense(5), + keras.layers.Dense(5), + ] + ) + new_layer_kernel_value = np.array(new_model.layers[2].kernel) + with self.assertRaisesRegex(ValueError, "received 0 variables"): + # Doesn't work by default + new_model.load_weights(temp_filepath) + # Now it works + new_model.load_weights(temp_filepath, skip_mismatch=True) + for layer_index in [0, 1]: + ref_weights = original_model.layers[layer_index].get_weights() + new_weights = new_model.layers[layer_index].get_weights() + self.assertEqual(len(ref_weights), len(new_weights)) + for ref_w, w in zip(ref_weights, new_weights): + self.assertAllClose(ref_w, w) + self.assertAllClose( + np.array(new_model.layers[2].kernel), new_layer_kernel_value + ) + + @pytest.mark.requires_trainable_backend + def test_save_to_fileobj(self): + model = keras.Sequential( + [keras.layers.Dense(1, input_shape=(1,)), keras.layers.Dense(1)] + ) + model.compile(optimizer="adam", loss="mse") + + out = BytesIO() + saving_lib.save_model(model, out) + out.seek(0) + model = saving_lib.load_model(out) + + model.fit(np.array([1, 2]), np.array([1, 2])) + pred1 = model.predict(np.array([1, 2])) + + out = BytesIO() + saving_lib.save_model(model, out) + out.seek(0) + new_model = saving_lib.load_model(out) + + pred2 = new_model.predict(np.array([1, 2])) + + self.assertAllClose(pred1, pred2, atol=1e-5) + + @parameterized.named_parameters( + ("high_memory_config", True), + ("low_memory_config", False), + ) + def test_save_model_exception_raised(self, is_memory_sufficient): + if is_memory_sufficient: + saving_lib._MEMORY_UPPER_BOUND = 0.5 # 50% + + # Assume we have an error in `save_own_variables`. + class RaiseErrorLayer(keras.layers.Layer): + def __init__(self, units, **kwargs): + super().__init__(**kwargs) + self.dense = keras.layers.Dense(units) + + def call(self, inputs): + return self.dense(inputs) + + def save_own_variables(self, store): + raise ValueError + + model = keras.Sequential([keras.Input([1]), RaiseErrorLayer(1)]) + filepath = f"{self.get_temp_dir()}/model.keras" + with self.assertRaises(ValueError): + saving_lib.save_model(model, filepath) + + # Ensure we don't have a bad "model.weights.h5" inside the zip file. + self.assertTrue(Path(filepath).exists()) + with zipfile.ZipFile(filepath) as zf: + all_filenames = zf.namelist() + self.assertNotIn("model.weights.h5", all_filenames) + + # Ensure we don't have any temporary files left. + self.assertLen(os.listdir(Path(filepath).parent), 1) + self.assertIn("model.keras", os.listdir(Path(filepath).parent)) + + @parameterized.named_parameters( + ("high_memory_config", True), + ("low_memory_config", False), + ) + def test_load_model_exception_raised(self, is_memory_sufficient): + if is_memory_sufficient: + saving_lib._MEMORY_UPPER_BOUND = 0.5 # 50% + + # Assume we have an error in `load_own_variables`. + class RaiseErrorLayer(keras.layers.Layer): + def __init__(self, units, **kwargs): + super().__init__(**kwargs) + self.dense = keras.layers.Dense(units) + + def call(self, inputs): + return self.dense(inputs) + + def load_own_variables(self, store): + raise ValueError + + model = keras.Sequential([keras.Input([1]), RaiseErrorLayer(1)]) + filepath = f"{self.get_temp_dir()}/model.keras" + saving_lib.save_model(model, filepath) + with self.assertRaises(ValueError): + saving_lib.load_model( + filepath, custom_objects={"RaiseErrorLayer": RaiseErrorLayer} + ) + + # Ensure we don't have any temporary files left. + self.assertLen(os.listdir(Path(filepath).parent), 1) + self.assertIn("model.keras", os.listdir(Path(filepath).parent)) + + def test_load_model_read_only_system(self): + model = keras.Sequential([keras.Input([1]), keras.layers.Dense(32)]) + filepath = f"{self.get_temp_dir()}/model.keras" + saving_lib.save_model(model, filepath) + + # Load the model correctly, regardless of whether an OSError occurs. + original_mode = os.stat(Path(filepath).parent).st_mode + os.chmod(Path(filepath).parent, mode=0o555) + model = saving_lib.load_model(filepath) + os.chmod(Path(filepath).parent, mode=original_mode) + + # Ensure we don't have any temporary files left. + self.assertLen(os.listdir(Path(filepath).parent), 1) + self.assertIn("model.keras", os.listdir(Path(filepath).parent)) + + @pytest.mark.skipif( + backend.backend() == "jax", + reason="JAX backend doesn't support Python's multiprocessing", + ) + @pytest.mark.skipif( + testing.tensorflow_uses_gpu() or testing.torch_uses_gpu(), + reason="This test doesn't support GPU", + ) + def test_load_model_concurrently(self): + import multiprocessing as mp + + model = keras.Sequential([keras.Input([1]), keras.layers.Dense(2)]) + filepath = f"{self.get_temp_dir()}/model.keras" + saving_lib.save_model(model, filepath) + + # Load the model concurrently. + results = [] + with mp.Pool(4) as pool: + for i in range(4): + results.append(pool.apply_async(_load_model_fn, (filepath,))) + pool.close() + pool.join() + [r.get() for r in results] # No error occurs here + + def test_load_model_containing_reused_layer(self): + # https://github.com/keras-team/keras/issues/20307 + inputs = keras.Input((4,)) + reused_layer = keras.layers.Dense(4) + x = reused_layer(inputs) + x = keras.layers.Dense(4)(x) + outputs = reused_layer(x) + model = keras.Model(inputs, outputs) + + self.assertLen(model.layers, 3) # Input + 2 Dense layers + self._test_inference_after_instantiation(model) + + @parameterized.named_parameters( + ("efficientnet_b0_512", "efficientnet_b0", 1), # Only 1 sharded file. + ("efficientnet_b0_10", "efficientnet_b0", 0.01), + ) + def test_weights_sharding(self, model_name, max_shard_size): + from keras.src.applications import efficientnet + + if backend.image_data_format() == "channels_last": + shape = (224, 224, 3) + else: + shape = (3, 224, 224) + + if model_name == "efficientnet_b0": + model_fn = efficientnet.EfficientNetB0 + + temp_filepath = Path( + os.path.join(self.get_temp_dir(), "mymodel.weights.json") + ) + model = model_fn(weights=None, input_shape=shape) + ref_input = np.random.random((1, *shape)).astype("float32") + ref_output = model.predict(ref_input) + + # Save the sharded files. + saving_lib.save_weights_only( + model, temp_filepath, max_shard_size=max_shard_size + ) + self.assertIn("mymodel.weights.json", os.listdir(temp_filepath.parent)) + if max_shard_size == 1: + # 1 sharded file + 1 config file = 2. + self.assertLen(os.listdir(temp_filepath.parent), 2) + elif max_shard_size == 0.01: + # 3 sharded file + 1 config file = 4. + self.assertLen(os.listdir(temp_filepath.parent), 4) + + with open(temp_filepath, "r") as f: + sharding_config = json.load(f) + self.assertIn("metadata", sharding_config) + self.assertIn("weight_map", sharding_config) + + # Instantiate new model and load the sharded files. + model = model_fn(weights=None, input_shape=shape) + saving_lib.load_weights_only(model, temp_filepath) + self.assertAllClose(model.predict(ref_input), ref_output, atol=1e-6) + + +class SavingAPITest(testing.TestCase): + def test_saving_api_errors(self): + from keras.src.saving import saving_api + + model = _get_basic_functional_model() + + # Saving API errors + temp_filepath = os.path.join(self.get_temp_dir(), "mymodel") + with self.assertRaisesRegex(ValueError, "argument is deprecated"): + saving_api.save_model(model, temp_filepath, save_format="keras") + + temp_filepath = os.path.join(self.get_temp_dir(), "mymodel.notkeras") + with self.assertRaisesRegex(ValueError, "Invalid filepath extension"): + saving_api.save_model(model, temp_filepath) + + temp_filepath = os.path.join(self.get_temp_dir(), "mymodel.keras") + with self.assertRaisesRegex(ValueError, "are not supported"): + saving_api.save_model(model, temp_filepath, invalid_arg="hello") + + # Loading API errors + temp_filepath = os.path.join(self.get_temp_dir(), "non_existent.keras") + with self.assertRaisesRegex( + ValueError, "Please ensure the file is an accessible" + ): + _ = saving_api.load_model(temp_filepath) + + temp_filepath = os.path.join(self.get_temp_dir(), "my_saved_model") + with self.assertRaisesRegex(ValueError, "File format not supported"): + _ = saving_api.load_model(temp_filepath) + + def test_model_api_endpoint(self): + temp_filepath = Path(os.path.join(self.get_temp_dir(), "mymodel.keras")) + model = _get_basic_functional_model() + ref_input = np.random.random((2, 4)) + ref_output = model.predict(ref_input) + model.save(temp_filepath) + model = keras.saving.load_model(temp_filepath) + self.assertAllClose(model.predict(ref_input), ref_output, atol=1e-6) + + def test_model_api_endpoint_h5(self): + temp_filepath = Path(os.path.join(self.get_temp_dir(), "mymodel.h5")) + model = _get_basic_functional_model() + ref_input = np.random.random((2, 4)) + ref_output = model.predict(ref_input) + model.save(temp_filepath) + model = keras.saving.load_model(temp_filepath) + self.assertAllClose(model.predict(ref_input), ref_output, atol=1e-6) + + def test_model_api_errors(self): + model = _get_basic_functional_model() + + # Saving API errors + temp_filepath = os.path.join(self.get_temp_dir(), "mymodel") + with self.assertRaisesRegex(ValueError, "argument is deprecated"): + model.save(temp_filepath, save_format="keras") + + temp_filepath = os.path.join(self.get_temp_dir(), "mymodel.notkeras") + with self.assertRaisesRegex(ValueError, "Invalid filepath extension"): + model.save(temp_filepath) + + temp_filepath = os.path.join(self.get_temp_dir(), "mymodel.keras") + with self.assertRaisesRegex(ValueError, "are not supported"): + model.save(temp_filepath, invalid_arg="hello") + + def test_safe_mode(self): + temp_filepath = os.path.join(self.get_temp_dir(), "unsafe_model.keras") + model = keras.Sequential( + [ + keras.Input(shape=(3,)), + keras.layers.Lambda(lambda x: x * 2), + ] + ) + model.save(temp_filepath) + with self.assertRaisesRegex(ValueError, "arbitrary code execution"): + model = saving_lib.load_model(temp_filepath) + model = saving_lib.load_model(temp_filepath, safe_mode=False) + + def test_normalization_kpl(self): + # With adapt + temp_filepath = os.path.join(self.get_temp_dir(), "norm_model.keras") + model = keras.Sequential( + [ + keras.Input(shape=(3,)), + keras.layers.Normalization(), + ] + ) + data = np.random.random((3, 3)) + model.layers[0].adapt(data) + ref_out = model(data) + model.save(temp_filepath) + model = saving_lib.load_model(temp_filepath) + out = model(data) + self.assertAllClose(ref_out, out, atol=1e-6) + + # Without adapt + model = keras.Sequential( + [ + keras.Input(shape=(3,)), + keras.layers.Normalization( + mean=np.random.random((3,)), + variance=np.random.random((3,)), + ), + ] + ) + ref_out = model(data) + model.save(temp_filepath) + model = saving_lib.load_model(temp_filepath) + out = model(data) + self.assertAllClose(ref_out, out, atol=1e-6) + + +# This class is properly registered with a `get_config()` method. +# However, since it does not subclass keras.layers.Layer, it lacks +# `from_config()` for deserialization. +@keras.saving.register_keras_serializable() +class GrowthFactor: + def __init__(self, factor): + self.factor = factor + + def __call__(self, inputs): + return inputs * self.factor + + def get_config(self): + return {"factor": self.factor} + + +@keras.saving.register_keras_serializable(package="Complex") +class FactorLayer(keras.layers.Layer): + def __init__(self, factor, **kwargs): + super().__init__(**kwargs) + self.factor = factor + + def call(self, x): + return x * self.factor + + def get_config(self): + return {"factor": self.factor} + + +# This custom model does not explicitly deserialize the layers it includes +# in its `get_config`. Explicit deserialization in a `from_config` override +# or `__init__` is needed here, or an error will be thrown at loading time. +@keras.saving.register_keras_serializable(package="Complex") +class ComplexModel(keras.layers.Layer): + def __init__(self, first_layer, second_layer=None, **kwargs): + super().__init__(**kwargs) + self.first_layer = first_layer + if second_layer is not None: + self.second_layer = second_layer + else: + self.second_layer = keras.layers.Dense(8) + + def get_config(self): + config = super().get_config() + config.update( + { + "first_layer": self.first_layer, + "second_layer": self.second_layer, + } + ) + return config + + def call(self, inputs): + return self.first_layer(self.second_layer(inputs)) + + +class SavingBattleTest(testing.TestCase): + def test_custom_object_without_from_config(self): + temp_filepath = os.path.join( + self.get_temp_dir(), "custom_fn_model.keras" + ) + + inputs = keras.Input(shape=(4, 4)) + outputs = keras.layers.Dense(1, activation=GrowthFactor(0.5))(inputs) + model = keras.Model(inputs, outputs) + + model.save(temp_filepath) + + with self.assertRaisesRegex( + TypeError, "Unable to reconstruct an instance" + ): + _ = saving_lib.load_model(temp_filepath) + + def test_complex_model_without_explicit_deserialization(self): + temp_filepath = os.path.join(self.get_temp_dir(), "complex_model.keras") + + inputs = keras.Input((32,)) + outputs = ComplexModel(first_layer=FactorLayer(0.5))(inputs) + model = keras.Model(inputs, outputs) + + model.save(temp_filepath) + + with self.assertRaisesRegex(TypeError, "are explicitly deserialized"): + _ = saving_lib.load_model(temp_filepath) + + def test_redefinition_of_trackable(self): + """Test that a trackable can be aliased under a new name.""" + + class NormalModel(keras.Model): + def __init__(self): + super().__init__() + self.dense = keras.layers.Dense(3) + + def call(self, x): + return self.dense(x) + + class WeirdModel(keras.Model): + def __init__(self): + super().__init__() + # This property will be traversed first, + # but "_dense" isn't in the saved file + # generated by NormalModel. + self.a_dense = keras.layers.Dense(3) + + @property + def dense(self): + return self.a_dense + + def call(self, x): + return self.dense(x) + + temp_filepath = os.path.join( + self.get_temp_dir(), "normal_model.weights.h5" + ) + model_a = NormalModel() + model_a(np.random.random((2, 2))) + model_a.save_weights(temp_filepath) + model_b = WeirdModel() + model_b(np.random.random((2, 2))) + model_b.load_weights(temp_filepath) + self.assertAllClose( + model_a.dense.kernel.numpy(), model_b.dense.kernel.numpy() + ) + + def test_normalization_legacy_h5_format(self): + temp_filepath = os.path.join(self.get_temp_dir(), "custom_model.h5") + + inputs = keras.Input((32,)) + normalization = keras.layers.Normalization() + outputs = normalization(inputs) + + model = keras.Model(inputs, outputs) + + x = np.random.random((1, 32)) + normalization.adapt(x) + ref_out = model(x) + + model.save(temp_filepath) + new_model = keras.saving.load_model(temp_filepath) + out = new_model(x) + self.assertAllClose(ref_out, out, atol=1e-6) + + def test_legacy_h5_format(self): + temp_filepath = os.path.join(self.get_temp_dir(), "custom_model.h5") + + inputs = keras.Input((32,)) + x = MyDense(2)(inputs) + outputs = CustomModelX()(x) + model = keras.Model(inputs, outputs) + + x = np.random.random((1, 32)) + ref_out = model(x) + + model.save(temp_filepath) + new_model = keras.saving.load_model(temp_filepath) + out = new_model(x) + self.assertAllClose(ref_out, out, atol=1e-6) + + def test_nested_functional_model_saving(self): + def func(in_size=4, out_size=2, name=None): + inputs = keras.layers.Input(shape=(in_size,)) + outputs = keras.layers.Dense(out_size)((inputs)) + return keras.Model(inputs, outputs=outputs, name=name) + + input_a, input_b = keras.Input((4,)), keras.Input((4,)) + out_a = func(out_size=2, name="func_a")(input_a) + out_b = func(out_size=3, name="func_b")(input_b) + model = keras.Model([input_a, input_b], outputs=[out_a, out_b]) + + temp_filepath = os.path.join(self.get_temp_dir(), "nested_func.keras") + model.save(temp_filepath) + new_model = keras.saving.load_model(temp_filepath) + x = [np.random.random((2, 4))], np.random.random((2, 4)) + ref_out = model(x) + out = new_model(x) + self.assertAllClose(ref_out[0], out[0]) + self.assertAllClose(ref_out[1], out[1]) + + def test_nested_shared_functional_model_saving(self): + def func(in_size=4, out_size=2, name=None): + inputs = keras.layers.Input(shape=(in_size,)) + outputs = keras.layers.Dense(out_size)((inputs)) + return keras.Model(inputs, outputs=outputs, name=name) + + inputs = [keras.Input((4,)), keras.Input((4,))] + func_shared = func(out_size=4, name="func_shared") + shared_a = func_shared(inputs[0]) + shared_b = func_shared(inputs[1]) + out_a = keras.layers.Dense(2)(shared_a) + out_b = keras.layers.Dense(2)(shared_b) + model = keras.Model(inputs, outputs=[out_a, out_b]) + + temp_filepath = os.path.join( + self.get_temp_dir(), "nested_shared_func.keras" + ) + model.save(temp_filepath) + new_model = keras.saving.load_model(temp_filepath) + x = [np.random.random((2, 4))], np.random.random((2, 4)) + ref_out = model(x) + out = new_model(x) + self.assertAllClose(ref_out[0], out[0]) + self.assertAllClose(ref_out[1], out[1]) + + def test_bidirectional_lstm_saving(self): + inputs = keras.Input((3, 2)) + outputs = keras.layers.Bidirectional(keras.layers.LSTM(64))(inputs) + model = keras.Model(inputs, outputs) + temp_filepath = os.path.join(self.get_temp_dir(), "bidir_lstm.keras") + model.save(temp_filepath) + new_model = keras.saving.load_model(temp_filepath) + x = np.random.random((1, 3, 2)) + ref_out = model(x) + out = new_model(x) + self.assertAllClose(ref_out, out) + + def test_remove_weights_only_saving_and_loading(self): + def is_remote_path(path): + return True + + temp_filepath = os.path.join(self.get_temp_dir(), "model.weights.h5") + + with mock.patch( + "keras.src.utils.file_utils.is_remote_path", is_remote_path + ): + model = _get_basic_functional_model() + model.save_weights(temp_filepath) + model.load_weights(temp_filepath) + + +class SavingH5IOStoreTest(testing.TestCase): + def test_h5_io_store_basics(self): + temp_filepath = Path(os.path.join(self.get_temp_dir(), "store.h5")) + + # Pre-defined data. + a = np.random.random((2, 4)).astype("float32") + b = np.random.random((4, 8)).astype("int32") + + # Set. + store = saving_lib.H5IOStore(temp_filepath, mode="w") + vars_store = store.make("vars") + vars_store["a"] = a + vars_store["b"] = b + vars_store["c"] = 42 + self.assertAllClose(vars_store["a"], a) + self.assertAllClose(vars_store["b"], b) + self.assertEqual(int(vars_store["c"][()]), 42) + + # Delete. + del vars_store["c"] + + # Contain. + self.assertNotIn("c", vars_store) + + store.close() + self.assertTrue(os.path.exists(temp_filepath)) + + # Get. + store = saving_lib.H5IOStore(temp_filepath, mode="r") + vars_store = store.get("vars") + self.assertAllClose(vars_store["a"], a) + self.assertAllClose(vars_store["b"], b) + self.assertNotIn("c", vars_store) + + def test_h5_io_store_lora(self): + # For `keras_hub.models.backbone.save_lora_weights` and + # `keras_hub.models.backbone.load_lora_weights` + temp_filepath = Path(os.path.join(self.get_temp_dir(), "layer.lora.h5")) + layer = keras.layers.Dense(units=16) + layer.build((None, 8)) + layer.enable_lora(4) + + ref_input = np.random.random((1, 8)).astype("float32") + ref_output = layer(ref_input) + + # Save the LoRA weights. + store = saving_lib.H5IOStore(temp_filepath, mode="w") + lora_store = store.make("lora") + lora_store["rank"] = layer.lora_rank + inner_store = store.make("lora/0") + inner_store["lora_kernel_a"] = layer.lora_kernel_a + inner_store["lora_kernel_b"] = layer.lora_kernel_b + store.close() + + # Load the LoRA weights. + revived_layer = keras.layers.Dense(units=16) + revived_layer.build((None, 8)) + store = saving_lib.H5IOStore(temp_filepath, mode="r") + lora_store = store.get("lora") + revived_layer.enable_lora(int(lora_store["rank"][()])) + lora_kernel_a = store.get("lora/0")["lora_kernel_a"] + lora_kernel_b = store.get("lora/0")["lora_kernel_b"] + revived_layer._kernel.assign(layer._kernel) + revived_layer.bias.assign(layer.bias) + revived_layer.lora_kernel_a.assign(lora_kernel_a) + revived_layer.lora_kernel_b.assign(lora_kernel_b) + self.assertAllClose(revived_layer(ref_input), ref_output, atol=1e-6) + + def test_h5_io_store_exception_raised(self): + temp_filepath = Path(os.path.join(self.get_temp_dir(), "store.h5")) + + # Bad `path_or_io`. + with self.assertRaisesRegex( + TypeError, + ( + r"`path_or_io` should be a `str`, `pathlib.Path` or " + r"`io.BytesIO` object." + ), + ): + saving_lib.H5IOStore(None, mode="w") + + # Bad `mode`. + with self.assertRaisesRegex( + ValueError, r"`mode` should be either 'w' or 'r'." + ): + saving_lib.H5IOStore(temp_filepath, mode="x") + + # No archive when using `io.BytesIO` as `path_or_io`. + with self.assertRaisesRegex( + ValueError, + ( + r"When `path_or_io` is an `io.BytesIO` object, `archive` " + r"should be `None`." + ), + ): + saving_lib.H5IOStore(BytesIO(), archive="archive", mode="w") + + store = saving_lib.H5IOStore(temp_filepath, mode="w") + + # Bad `metadata`. + with self.assertRaisesRegex( + ValueError, r"`metadata` should be a dict or `None`." + ): + store.make("vars", metadata="metadata") + + store.close() + + store = saving_lib.H5IOStore(temp_filepath, mode="r") + vars_store = store.get("vars") + + # Set in read mode. + with self.assertRaisesRegex( + ValueError, r"Setting a value is only allowed in write mode." + ): + vars_store["weights"] = np.random.random((2, 4)).astype("float32") + + # Delete in read mode. + with self.assertRaisesRegex( + ValueError, r"Deleting a value is only allowed in write mode." + ): + del vars_store["weights"] + + def test_sharded_h5_io_store_basics(self): + name = "sharded_store" + temp_filepath = Path(os.path.join(self.get_temp_dir(), f"{name}.json")) + + # Pre-defined data. Each has about 0.0037GB. + a = np.random.random((1000, 1000)).astype("float32") + b = np.random.random((1000, 1000)).astype("int32") + + # Set. + store = saving_lib.ShardedH5IOStore( + temp_filepath, max_shard_size=0.005, mode="w" + ) + vars_store = store.make("vars") + vars_store["a"] = a + vars_store["b"] = b + vars_store["c"] = 42 + self.assertLen(store.sharding_config["weight_map"]["/vars/vars"], 2) + self.assertLen(vars_store, 3) + self.assertAllClose(vars_store["a"], a) + self.assertAllClose(vars_store["b"], b) + self.assertEqual(int(vars_store["c"][()]), 42) + + # Delete. + del vars_store["c"] + self.assertLen(vars_store, 2) + del vars_store["a"] # Delete from an older shard. + self.assertLen(vars_store, 1) + vars_store["a"] = a + + # Contain. + self.assertIn("a", vars_store) + self.assertNotIn("c", vars_store) + + store.close() + self.assertTrue(os.path.exists(temp_filepath)) + self.assertTrue( + os.path.exists(temp_filepath.with_name(f"{name}_00000.weights.h5")) + ) + + # Get. + store = saving_lib.ShardedH5IOStore(temp_filepath, mode="r") + vars_store = store.get("vars") + self.assertLen(vars_store, 2) + self.assertAllClose(vars_store["a"], a) + self.assertAllClose(vars_store["b"], b) + self.assertNotIn("c", vars_store) + + # Keys. + for key in ["a", "b"]: + self.assertIn(key, vars_store.keys()) + + # Items. + for key, value in vars_store.items(): + if key == "a": + self.assertAllClose(value, a) + elif key == "b": + self.assertAllClose(value, b) + else: + raise ValueError(f"Unexpected key: {key}") + + # Values. + for value in vars_store.values(): + if backend.standardize_dtype(value.dtype) == "float32": + self.assertAllClose(value, a) + elif backend.standardize_dtype(value.dtype) == "int32": + self.assertAllClose(value, b) + else: + raise ValueError(f"Unexpected value: {value}") + + def test_sharded_h5_io_store_exception_raised(self): + temp_filepath = Path(os.path.join(self.get_temp_dir(), "store.h5")) + + # Bad `path_or_io`. + with self.assertRaisesRegex( + TypeError, + r"`path_or_io` should be a `str`, `pathlib.Path` object. ", + ): + saving_lib.ShardedH5IOStore(None, mode="w") + + # Bad `mode`. + with self.assertRaisesRegex( + ValueError, r"`mode` should be either 'w' or 'r'." + ): + saving_lib.ShardedH5IOStore(temp_filepath, mode="x") + + store = saving_lib.ShardedH5IOStore( + temp_filepath, max_shard_size=0.00001, mode="w" + ) + vars_store = store.make("vars") + + # Too large data. + with self.assertRaisesRegex( + ValueError, r"exceeds the maximum shard size" + ): + vars_store["weights"] = np.random.random((100, 100)).astype( + "float32" + ) + + # Bad `get`. + with self.assertRaisesRegex( + KeyError, r"Key 'abc' not found in any of the shards:" + ): + vars_store["abc"] + + # Bad `del`. + with self.assertRaisesRegex( + KeyError, r"Key 'abc' not found in any of the shards:" + ): + del vars_store["abc"] + + store.close() diff --git a/keras/src/saving/serialization_lib.py b/keras/src/saving/serialization_lib.py new file mode 100644 index 000000000000..da943a6c6096 --- /dev/null +++ b/keras/src/saving/serialization_lib.py @@ -0,0 +1,840 @@ +"""Object config serialization and deserialization logic.""" + +import importlib +import inspect +import types +import warnings + +import numpy as np + +from keras.src import api_export +from keras.src import backend +from keras.src.api_export import keras_export +from keras.src.backend.common import global_state +from keras.src.saving import object_registration +from keras.src.saving.keras_saveable import KerasSaveable +from keras.src.utils import python_utils +from keras.src.utils.module_utils import tensorflow as tf + +PLAIN_TYPES = (str, int, float, bool) + +# List of Keras modules with built-in string representations for Keras defaults +BUILTIN_MODULES = frozenset( + { + "activations", + "constraints", + "initializers", + "losses", + "metrics", + "optimizers", + "regularizers", + } +) + +LOADING_APIS = frozenset( + { + "keras.config.enable_unsafe_deserialization", + "keras.models.load_model", + "keras.preprocessing.image.load_img", + "keras.saving.load_model", + "keras.saving.load_weights", + "keras.utils.get_file", + "keras.utils.load_img", + } +) + + +class SerializableDict: + def __init__(self, **config): + self.config = config + + def serialize(self): + return serialize_keras_object(self.config) + + +class SafeModeScope: + """Scope to propagate safe mode flag to nested deserialization calls.""" + + def __init__(self, safe_mode=True): + self.safe_mode = safe_mode + + def __enter__(self): + self.original_value = in_safe_mode() + global_state.set_global_attribute("safe_mode_saving", self.safe_mode) + + def __exit__(self, *args, **kwargs): + global_state.set_global_attribute( + "safe_mode_saving", self.original_value + ) + + +@keras_export("keras.config.enable_unsafe_deserialization") +def enable_unsafe_deserialization(): + """Disables safe mode globally, allowing deserialization of lambdas.""" + global_state.set_global_attribute("safe_mode_saving", False) + + +def in_safe_mode(): + return global_state.get_global_attribute("safe_mode_saving") + + +class ObjectSharingScope: + """Scope to enable detection and reuse of previously seen objects.""" + + def __enter__(self): + global_state.set_global_attribute("shared_objects/id_to_obj_map", {}) + global_state.set_global_attribute("shared_objects/id_to_config_map", {}) + + def __exit__(self, *args, **kwargs): + global_state.set_global_attribute("shared_objects/id_to_obj_map", None) + global_state.set_global_attribute( + "shared_objects/id_to_config_map", None + ) + + +def get_shared_object(obj_id): + """Retrieve an object previously seen during deserialization.""" + id_to_obj_map = global_state.get_global_attribute( + "shared_objects/id_to_obj_map" + ) + if id_to_obj_map is not None: + return id_to_obj_map.get(obj_id, None) + + +def record_object_after_serialization(obj, config): + """Call after serializing an object, to keep track of its config.""" + if config["module"] == "__main__": + config["module"] = None # Ensures module is None when no module found + id_to_config_map = global_state.get_global_attribute( + "shared_objects/id_to_config_map" + ) + if id_to_config_map is None: + return # Not in a sharing scope + obj_id = int(id(obj)) + if obj_id not in id_to_config_map: + id_to_config_map[obj_id] = config + else: + config["shared_object_id"] = obj_id + prev_config = id_to_config_map[obj_id] + prev_config["shared_object_id"] = obj_id + + +def record_object_after_deserialization(obj, obj_id): + """Call after deserializing an object, to keep track of it in the future.""" + id_to_obj_map = global_state.get_global_attribute( + "shared_objects/id_to_obj_map" + ) + if id_to_obj_map is None: + return # Not in a sharing scope + id_to_obj_map[obj_id] = obj + + +@keras_export( + [ + "keras.saving.serialize_keras_object", + "keras.utils.serialize_keras_object", + ] +) +def serialize_keras_object(obj): + """Retrieve the config dict by serializing the Keras object. + + `serialize_keras_object()` serializes a Keras object to a python dictionary + that represents the object, and is a reciprocal function of + `deserialize_keras_object()`. See `deserialize_keras_object()` for more + information about the config format. + + Args: + obj: the Keras object to serialize. + + Returns: + A python dict that represents the object. The python dict can be + deserialized via `deserialize_keras_object()`. + """ + if obj is None: + return obj + + if isinstance(obj, PLAIN_TYPES): + return obj + + if isinstance(obj, (list, tuple)): + config_arr = [serialize_keras_object(x) for x in obj] + return tuple(config_arr) if isinstance(obj, tuple) else config_arr + if isinstance(obj, dict): + return serialize_dict(obj) + + # Special cases: + if isinstance(obj, bytes): + return { + "class_name": "__bytes__", + "config": {"value": obj.decode("utf-8")}, + } + if isinstance(obj, slice): + return { + "class_name": "__slice__", + "config": { + "start": serialize_keras_object(obj.start), + "stop": serialize_keras_object(obj.stop), + "step": serialize_keras_object(obj.step), + }, + } + # Ellipsis is an instance, and ellipsis class is not in global scope. + # checking equality also fails elsewhere in the library, so we have + # to dynamically get the type. + if isinstance(obj, type(Ellipsis)): + return {"class_name": "__ellipsis__", "config": {}} + if isinstance(obj, backend.KerasTensor): + history = getattr(obj, "_keras_history", None) + if history: + history = list(history) + history[0] = history[0].name + return { + "class_name": "__keras_tensor__", + "config": { + "shape": obj.shape, + "dtype": obj.dtype, + "keras_history": history, + }, + } + if tf.available and isinstance(obj, tf.TensorShape): + return obj.as_list() if obj._dims is not None else None + if backend.is_tensor(obj): + return { + "class_name": "__tensor__", + "config": { + "value": backend.convert_to_numpy(obj).tolist(), + "dtype": backend.standardize_dtype(obj.dtype), + }, + } + if type(obj).__module__ == np.__name__: + if isinstance(obj, np.ndarray) and obj.ndim > 0: + return { + "class_name": "__numpy__", + "config": { + "value": obj.tolist(), + "dtype": backend.standardize_dtype(obj.dtype), + }, + } + else: + # Treat numpy floats / etc as plain types. + return obj.item() + if tf.available and isinstance(obj, tf.DType): + return obj.name + if isinstance(obj, types.FunctionType) and obj.__name__ == "": + warnings.warn( + "The object being serialized includes a `lambda`. This is unsafe. " + "In order to reload the object, you will have to pass " + "`safe_mode=False` to the loading function. " + "Please avoid using `lambda` in the " + "future, and use named Python functions instead. " + f"This is the `lambda` being serialized: {inspect.getsource(obj)}", + stacklevel=2, + ) + return { + "class_name": "__lambda__", + "config": { + "value": python_utils.func_dump(obj), + }, + } + if tf.available and isinstance(obj, tf.TypeSpec): + ts_config = obj._serialize() + # TensorShape and tf.DType conversion + ts_config = list( + map( + lambda x: ( + x.as_list() + if isinstance(x, tf.TensorShape) + else (x.name if isinstance(x, tf.DType) else x) + ), + ts_config, + ) + ) + return { + "class_name": "__typespec__", + "spec_name": obj.__class__.__name__, + "module": obj.__class__.__module__, + "config": ts_config, + "registered_name": None, + } + + inner_config = _get_class_or_fn_config(obj) + config_with_public_class = serialize_with_public_class( + obj.__class__, inner_config + ) + + if config_with_public_class is not None: + get_build_and_compile_config(obj, config_with_public_class) + record_object_after_serialization(obj, config_with_public_class) + return config_with_public_class + + # Any custom object or otherwise non-exported object + if isinstance(obj, types.FunctionType): + module = obj.__module__ + else: + module = obj.__class__.__module__ + class_name = obj.__class__.__name__ + + if module == "builtins": + registered_name = None + else: + if isinstance(obj, types.FunctionType): + registered_name = object_registration.get_registered_name(obj) + else: + registered_name = object_registration.get_registered_name( + obj.__class__ + ) + + config = { + "module": module, + "class_name": class_name, + "config": inner_config, + "registered_name": registered_name, + } + get_build_and_compile_config(obj, config) + record_object_after_serialization(obj, config) + return config + + +def get_build_and_compile_config(obj, config): + if hasattr(obj, "get_build_config"): + build_config = obj.get_build_config() + if build_config is not None: + config["build_config"] = serialize_dict(build_config) + if hasattr(obj, "get_compile_config"): + compile_config = obj.get_compile_config() + if compile_config is not None: + config["compile_config"] = serialize_dict(compile_config) + return + + +def serialize_with_public_class(cls, inner_config=None): + """Serializes classes from public Keras API or object registration. + + Called to check and retrieve the config of any class that has a public + Keras API or has been registered as serializable via + `keras.saving.register_keras_serializable()`. + """ + # This gets the `keras.*` exported name, such as + # "keras.optimizers.Adam". + keras_api_name = api_export.get_name_from_symbol(cls) + + # Case of custom or unknown class object + if keras_api_name is None: + registered_name = object_registration.get_registered_name(cls) + if registered_name is None: + return None + + # Return custom object config with corresponding registration name + return { + "module": cls.__module__, + "class_name": cls.__name__, + "config": inner_config, + "registered_name": registered_name, + } + + # Split the canonical Keras API name into a Keras module and class name. + parts = keras_api_name.split(".") + return { + "module": ".".join(parts[:-1]), + "class_name": parts[-1], + "config": inner_config, + "registered_name": None, + } + + +def serialize_with_public_fn(fn, config, fn_module_name=None): + """Serializes functions from public Keras API or object registration. + + Called to check and retrieve the config of any function that has a public + Keras API or has been registered as serializable via + `keras.saving.register_keras_serializable()`. If function's module name + is already known, returns corresponding config. + """ + if fn_module_name: + return { + "module": fn_module_name, + "class_name": "function", + "config": config, + "registered_name": config, + } + keras_api_name = api_export.get_name_from_symbol(fn) + if keras_api_name: + parts = keras_api_name.split(".") + return { + "module": ".".join(parts[:-1]), + "class_name": "function", + "config": config, + "registered_name": config, + } + else: + registered_name = object_registration.get_registered_name(fn) + if not registered_name and not fn.__module__ == "builtins": + return None + return { + "module": fn.__module__, + "class_name": "function", + "config": config, + "registered_name": registered_name, + } + + +def _get_class_or_fn_config(obj): + """Return the object's config depending on its type.""" + # Functions / lambdas: + if isinstance(obj, types.FunctionType): + return object_registration.get_registered_name(obj) + # All classes: + if hasattr(obj, "get_config"): + config = obj.get_config() + if not isinstance(config, dict): + raise TypeError( + f"The `get_config()` method of {obj} should return " + f"a dict. It returned: {config}" + ) + return serialize_dict(config) + elif hasattr(obj, "__name__"): + return object_registration.get_registered_name(obj) + else: + raise TypeError( + f"Cannot serialize object {obj} of type {type(obj)}. " + "To be serializable, " + "a class must implement the `get_config()` method." + ) + + +def serialize_dict(obj): + return {key: serialize_keras_object(value) for key, value in obj.items()} + + +@keras_export( + [ + "keras.saving.deserialize_keras_object", + "keras.utils.deserialize_keras_object", + ] +) +def deserialize_keras_object( + config, custom_objects=None, safe_mode=True, **kwargs +): + """Retrieve the object by deserializing the config dict. + + The config dict is a Python dictionary that consists of a set of key-value + pairs, and represents a Keras object, such as an `Optimizer`, `Layer`, + `Metrics`, etc. The saving and loading library uses the following keys to + record information of a Keras object: + + - `class_name`: String. This is the name of the class, + as exactly defined in the source + code, such as "LossesContainer". + - `config`: Dict. Library-defined or user-defined key-value pairs that store + the configuration of the object, as obtained by `object.get_config()`. + - `module`: String. The path of the python module. Built-in Keras classes + expect to have prefix `keras`. + - `registered_name`: String. The key the class is registered under via + `keras.saving.register_keras_serializable(package, name)` API. The + key has the format of '{package}>{name}', where `package` and `name` are + the arguments passed to `register_keras_serializable()`. If `name` is not + provided, it uses the class name. If `registered_name` successfully + resolves to a class (that was registered), the `class_name` and `config` + values in the dict will not be used. `registered_name` is only used for + non-built-in classes. + + For example, the following dictionary represents the built-in Adam optimizer + with the relevant config: + + ```python + dict_structure = { + "class_name": "Adam", + "config": { + "amsgrad": false, + "beta_1": 0.8999999761581421, + "beta_2": 0.9990000128746033, + "decay": 0.0, + "epsilon": 1e-07, + "learning_rate": 0.0010000000474974513, + "name": "Adam" + }, + "module": "keras.optimizers", + "registered_name": None + } + # Returns an `Adam` instance identical to the original one. + deserialize_keras_object(dict_structure) + ``` + + If the class does not have an exported Keras namespace, the library tracks + it by its `module` and `class_name`. For example: + + ```python + dict_structure = { + "class_name": "MetricsList", + "config": { + ... + }, + "module": "keras.trainers.compile_utils", + "registered_name": "MetricsList" + } + + # Returns a `MetricsList` instance identical to the original one. + deserialize_keras_object(dict_structure) + ``` + + And the following dictionary represents a user-customized `MeanSquaredError` + loss: + + ```python + @keras.saving.register_keras_serializable(package='my_package') + class ModifiedMeanSquaredError(keras.losses.MeanSquaredError): + ... + + dict_structure = { + "class_name": "ModifiedMeanSquaredError", + "config": { + "fn": "mean_squared_error", + "name": "mean_squared_error", + "reduction": "auto" + }, + "registered_name": "my_package>ModifiedMeanSquaredError" + } + # Returns the `ModifiedMeanSquaredError` object + deserialize_keras_object(dict_structure) + ``` + + Args: + config: Python dict describing the object. + custom_objects: Python dict containing a mapping between custom + object names the corresponding classes or functions. + safe_mode: Boolean, whether to disallow unsafe `lambda` deserialization. + When `safe_mode=False`, loading an object has the potential to + trigger arbitrary code execution. This argument is only + applicable to the Keras v3 model format. Defaults to `True`. + + Returns: + The object described by the `config` dictionary. + """ + safe_scope_arg = in_safe_mode() # Enforces SafeModeScope + safe_mode = safe_scope_arg if safe_scope_arg is not None else safe_mode + + module_objects = kwargs.pop("module_objects", None) + custom_objects = custom_objects or {} + tlco = global_state.get_global_attribute("custom_objects_scope_dict", {}) + gco = object_registration.GLOBAL_CUSTOM_OBJECTS + custom_objects = {**custom_objects, **tlco, **gco} + + if config is None: + return None + + if ( + isinstance(config, str) + and custom_objects + and custom_objects.get(config) is not None + ): + # This is to deserialize plain functions which are serialized as + # string names by legacy saving formats. + return custom_objects[config] + + if isinstance(config, (list, tuple)): + return [ + deserialize_keras_object( + x, custom_objects=custom_objects, safe_mode=safe_mode + ) + for x in config + ] + + if module_objects is not None: + inner_config, fn_module_name, has_custom_object = None, None, False + + if isinstance(config, dict): + if "config" in config: + inner_config = config["config"] + if "class_name" not in config: + raise ValueError( + f"Unknown `config` as a `dict`, config={config}" + ) + + # Check case where config is function or class and in custom objects + if custom_objects and ( + config["class_name"] in custom_objects + or config.get("registered_name") in custom_objects + or ( + isinstance(inner_config, str) + and inner_config in custom_objects + ) + ): + has_custom_object = True + + # Case where config is function but not in custom objects + elif config["class_name"] == "function": + fn_module_name = config["module"] + if fn_module_name == "builtins": + config = config["config"] + else: + config = config["registered_name"] + + # Case where config is class but not in custom objects + else: + if config.get("module", "_") is None: + raise TypeError( + "Cannot deserialize object of type " + f"`{config['class_name']}`. If " + f"`{config['class_name']}` is a custom class, please " + "register it using the " + "`@keras.saving.register_keras_serializable()` " + "decorator." + ) + config = config["class_name"] + + if not has_custom_object: + # Return if not found in either module objects or custom objects + if config not in module_objects: + # Object has already been deserialized + return config + if isinstance(module_objects[config], types.FunctionType): + return deserialize_keras_object( + serialize_with_public_fn( + module_objects[config], config, fn_module_name + ), + custom_objects=custom_objects, + ) + return deserialize_keras_object( + serialize_with_public_class( + module_objects[config], inner_config=inner_config + ), + custom_objects=custom_objects, + ) + + if isinstance(config, PLAIN_TYPES): + return config + if not isinstance(config, dict): + raise TypeError(f"Could not parse config: {config}") + + if "class_name" not in config or "config" not in config: + return { + key: deserialize_keras_object( + value, custom_objects=custom_objects, safe_mode=safe_mode + ) + for key, value in config.items() + } + + class_name = config["class_name"] + inner_config = config["config"] or {} + custom_objects = custom_objects or {} + + # Special cases: + if class_name == "__keras_tensor__": + obj = backend.KerasTensor( + inner_config["shape"], dtype=inner_config["dtype"] + ) + obj._pre_serialization_keras_history = inner_config["keras_history"] + return obj + + if class_name == "__tensor__": + return backend.convert_to_tensor( + inner_config["value"], dtype=inner_config["dtype"] + ) + if class_name == "__numpy__": + return np.array(inner_config["value"], dtype=inner_config["dtype"]) + if config["class_name"] == "__bytes__": + return inner_config["value"].encode("utf-8") + if config["class_name"] == "__ellipsis__": + return Ellipsis + if config["class_name"] == "__slice__": + return slice( + deserialize_keras_object( + inner_config["start"], + custom_objects=custom_objects, + safe_mode=safe_mode, + ), + deserialize_keras_object( + inner_config["stop"], + custom_objects=custom_objects, + safe_mode=safe_mode, + ), + deserialize_keras_object( + inner_config["step"], + custom_objects=custom_objects, + safe_mode=safe_mode, + ), + ) + if config["class_name"] == "__lambda__": + if safe_mode: + raise ValueError( + "Requested the deserialization of a Python lambda. This " + "carries a potential risk of arbitrary code execution and thus " + "it is disallowed by default. If you trust the source of the " + "artifact, you can override this error by passing " + "`safe_mode=False` to the loading function, or calling " + "`keras.config.enable_unsafe_deserialization()." + ) + return python_utils.func_load(inner_config["value"]) + if tf is not None and config["class_name"] == "__typespec__": + obj = _retrieve_class_or_fn( + config["spec_name"], + config["registered_name"], + config["module"], + obj_type="class", + full_config=config, + custom_objects=custom_objects, + ) + # Conversion to TensorShape and DType + inner_config = map( + lambda x: ( + tf.TensorShape(x) + if isinstance(x, list) + else (getattr(tf, x) if hasattr(tf.dtypes, str(x)) else x) + ), + inner_config, + ) + return obj._deserialize(tuple(inner_config)) + + # Below: classes and functions. + module = config.get("module", None) + registered_name = config.get("registered_name", class_name) + + if class_name == "function": + fn_name = inner_config + return _retrieve_class_or_fn( + fn_name, + registered_name, + module, + obj_type="function", + full_config=config, + custom_objects=custom_objects, + ) + + # Below, handling of all classes. + # First, is it a shared object? + if "shared_object_id" in config: + obj = get_shared_object(config["shared_object_id"]) + if obj is not None: + return obj + + cls = _retrieve_class_or_fn( + class_name, + registered_name, + module, + obj_type="class", + full_config=config, + custom_objects=custom_objects, + ) + + if isinstance(cls, types.FunctionType): + return cls + if not hasattr(cls, "from_config"): + raise TypeError( + f"Unable to reconstruct an instance of '{class_name}' because " + f"the class is missing a `from_config()` method. " + f"Full object config: {config}" + ) + + # Instantiate the class from its config inside a custom object scope + # so that we can catch any custom objects that the config refers to. + custom_obj_scope = object_registration.CustomObjectScope(custom_objects) + safe_mode_scope = SafeModeScope(safe_mode) + with custom_obj_scope, safe_mode_scope: + try: + instance = cls.from_config(inner_config) + except TypeError as e: + raise TypeError( + f"{cls} could not be deserialized properly. Please" + " ensure that components that are Python object" + " instances (layers, models, etc.) returned by" + " `get_config()` are explicitly deserialized in the" + " model's `from_config()` method." + f"\n\nconfig={config}.\n\nException encountered: {e}" + ) + build_config = config.get("build_config", None) + if build_config and not instance.built: + instance.build_from_config(build_config) + instance.built = True + compile_config = config.get("compile_config", None) + if compile_config: + instance.compile_from_config(compile_config) + instance.compiled = True + + if "shared_object_id" in config: + record_object_after_deserialization( + instance, config["shared_object_id"] + ) + return instance + + +def _retrieve_class_or_fn( + name, registered_name, module, obj_type, full_config, custom_objects=None +): + # If there is a custom object registered via + # `register_keras_serializable()`, that takes precedence. + if obj_type == "function": + custom_obj = object_registration.get_registered_object( + name, custom_objects=custom_objects + ) + else: + custom_obj = object_registration.get_registered_object( + registered_name, custom_objects=custom_objects + ) + if custom_obj is not None: + return custom_obj + + if module: + # If it's a Keras built-in object, + # we cannot always use direct import, because the exported + # module name might not match the package structure + # (e.g. experimental symbols). + if module == "keras" or module.startswith("keras."): + api_name = f"{module}.{name}" + + if api_name in LOADING_APIS: + raise ValueError( + f"Cannot deserialize `{api_name}`, loading functions are " + "not allowed during deserialization" + ) + + obj = api_export.get_symbol_from_name(api_name) + if obj is not None: + return obj + + # Configs of Keras built-in functions do not contain identifying + # information other than their name (e.g. 'acc' or 'tanh'). This special + # case searches the Keras modules that contain built-ins to retrieve + # the corresponding function from the identifying string. + if obj_type == "function" and module == "builtins": + for mod in BUILTIN_MODULES: + obj = api_export.get_symbol_from_name(f"keras.{mod}.{name}") + if obj is not None: + return obj + + # Workaround for serialization bug in Keras <= 3.6 whereby custom + # functions would only be saved by name instead of registered name, + # i.e. "name" instead of "package>name". This allows recent versions + # of Keras to reload models saved with 3.6 and lower. + if ">" not in name: + separated_name = f">{name}" + for custom_name, custom_object in custom_objects.items(): + if custom_name.endswith(separated_name): + return custom_object + + # Otherwise, attempt to retrieve the class object given the `module` + # and `class_name`. Import the module, find the class. + package = module.split(".", maxsplit=1)[0] + if package in {"keras", "keras_hub", "keras_cv", "keras_nlp"}: + try: + mod = importlib.import_module(module) + obj = vars(mod).get(name, None) + if isinstance(obj, type) and issubclass(obj, KerasSaveable): + return obj + else: + raise ValueError( + f"Could not deserialize '{module}.{name}' because " + "it is not a KerasSaveable subclass" + ) + except ModuleNotFoundError: + raise TypeError( + f"Could not deserialize {obj_type} '{name}' because " + f"its parent module {module} cannot be imported. " + f"Full object config: {full_config}" + ) + + raise TypeError( + f"Could not locate {obj_type} '{name}'. Make sure custom classes and " + "functions are decorated with " + "`@keras.saving.register_keras_serializable()`. If they are already " + "decorated, make sure they are all imported so that the decorator is " + f"run before trying to load them. Full object config: {full_config}" + ) diff --git a/keras/src/saving/serialization_lib_test.py b/keras/src/saving/serialization_lib_test.py new file mode 100644 index 000000000000..8ff0d8cf6fe1 --- /dev/null +++ b/keras/src/saving/serialization_lib_test.py @@ -0,0 +1,478 @@ +"""Tests for serialization_lib.""" + +import json + +import numpy as np +import pytest + +import keras +from keras.src import ops +from keras.src import testing +from keras.src.saving import object_registration +from keras.src.saving import serialization_lib + + +def custom_fn(x): + return x**2 + + +class CustomLayer(keras.layers.Layer): + def __init__(self, factor): + super().__init__() + self.factor = factor + + def call(self, x): + return x * self.factor + + def get_config(self): + return {"factor": self.factor} + + +class NestedCustomLayer(keras.layers.Layer): + def __init__(self, factor, dense=None, activation=None): + super().__init__() + self.factor = factor + + if dense is None: + self.dense = keras.layers.Dense(1, activation=custom_fn) + else: + self.dense = serialization_lib.deserialize_keras_object(dense) + self.activation = serialization_lib.deserialize_keras_object(activation) + + def call(self, x): + return self.dense(x * self.factor) + + def get_config(self): + return { + "factor": self.factor, + "dense": self.dense, + "activation": self.activation, + } + + +class WrapperLayer(keras.layers.Wrapper): + def call(self, x): + return self.layer(x) + + +class SerializationLibTest(testing.TestCase): + def roundtrip(self, obj, custom_objects=None, safe_mode=True): + serialized = serialization_lib.serialize_keras_object(obj) + json_data = json.dumps(serialized) + json_data = json.loads(json_data) + deserialized = serialization_lib.deserialize_keras_object( + json_data, custom_objects=custom_objects, safe_mode=safe_mode + ) + reserialized = serialization_lib.serialize_keras_object(deserialized) + return serialized, deserialized, reserialized + + def test_simple_objects(self): + for obj in [ + "hello", + b"hello", + np.array([0, 1]), + np.array([0.0, 1.0]), + np.float32(1.0), + ["hello", 0, "world", 1.0, True], + {"1": "hello", "2": 0, "3": True}, + {"1": "hello", "2": [True, False]}, + slice(None, 20, 1), + slice(None, np.array([0, 1]), 1), + ]: + serialized, _, reserialized = self.roundtrip(obj) + self.assertEqual(serialized, reserialized) + + def test_builtin_layers(self): + layer = keras.layers.Dense( + 3, + name="foo", + trainable=False, + dtype="float16", + ) + serialized, restored, reserialized = self.roundtrip(layer) + self.assertEqual(serialized, reserialized) + self.assertEqual(layer.name, restored.name) + self.assertEqual(layer.trainable, restored.trainable) + self.assertEqual(layer.compute_dtype, restored.compute_dtype) + + def test_numpy_get_item_layer(self): + def tuples_to_lists_str(x): + return str(x).replace("(", "[").replace(")", "]") + + input = keras.layers.Input(shape=(2,)) + layer = input[:, 1] + model = keras.Model(input, layer) + serialized, _, reserialized = self.roundtrip(model) + # Anticipate JSON roundtrip mapping tuples to lists: + serialized_str = tuples_to_lists_str(serialized) + reserialized_str = tuples_to_lists_str(reserialized) + self.assertEqual(serialized_str, reserialized_str) + + def test_serialize_ellipsis(self): + _, deserialized, _ = self.roundtrip(Ellipsis) + self.assertEqual(..., deserialized) + + def test_tensors_and_shapes(self): + x = ops.random.normal((2, 2), dtype="float64") + obj = {"x": x} + _, new_obj, _ = self.roundtrip(obj) + self.assertAllClose(x, new_obj["x"], atol=1e-5) + + obj = {"x.shape": x.shape} + _, new_obj, _ = self.roundtrip(obj) + self.assertEqual(tuple(x.shape), tuple(new_obj["x.shape"])) + + def test_custom_fn(self): + obj = {"activation": custom_fn} + serialized, _, reserialized = self.roundtrip( + obj, custom_objects={"custom_fn": custom_fn} + ) + self.assertEqual(serialized, reserialized) + + # Test inside layer + dense = keras.layers.Dense(1, activation=custom_fn) + dense.build((None, 2)) + _, new_dense, _ = self.roundtrip( + dense, custom_objects={"custom_fn": custom_fn} + ) + x = ops.random.normal((2, 2)) + y1 = dense(x) + _ = new_dense(x) + new_dense.set_weights(dense.get_weights()) + y2 = new_dense(x) + self.assertAllClose(y1, y2, atol=1e-5) + + def test_custom_layer(self): + layer = CustomLayer(factor=2) + x = ops.random.normal((2, 2)) + y1 = layer(x) + _, new_layer, _ = self.roundtrip( + layer, custom_objects={"CustomLayer": CustomLayer} + ) + y2 = new_layer(x) + self.assertAllClose(y1, y2, atol=1e-5) + + layer = NestedCustomLayer(factor=2) + x = ops.random.normal((2, 2)) + y1 = layer(x) + _, new_layer, _ = self.roundtrip( + layer, + custom_objects={ + "NestedCustomLayer": NestedCustomLayer, + "custom_fn": custom_fn, + }, + ) + _ = new_layer(x) + new_layer.set_weights(layer.get_weights()) + y2 = new_layer(x) + self.assertAllClose(y1, y2, atol=1e-5) + + def test_lambda_fn(self): + obj = {"activation": lambda x: x**2} + with self.assertRaisesRegex(ValueError, "arbitrary code execution"): + self.roundtrip(obj, safe_mode=True) + + _, new_obj, _ = self.roundtrip(obj, safe_mode=False) + self.assertEqual(obj["activation"](3), new_obj["activation"](3)) + + def test_lambda_layer(self): + lmbda = keras.layers.Lambda(lambda x: x**2) + with self.assertRaisesRegex(ValueError, "arbitrary code execution"): + self.roundtrip(lmbda, safe_mode=True) + + _, new_lmbda, _ = self.roundtrip(lmbda, safe_mode=False) + x = ops.random.normal((2, 2)) + y1 = lmbda(x) + y2 = new_lmbda(x) + self.assertAllClose(y1, y2, atol=1e-5) + + def test_safe_mode_scope(self): + lmbda = keras.layers.Lambda(lambda x: x**2) + with serialization_lib.SafeModeScope(safe_mode=True): + with self.assertRaisesRegex(ValueError, "arbitrary code execution"): + self.roundtrip(lmbda) + with serialization_lib.SafeModeScope(safe_mode=False): + _, new_lmbda, _ = self.roundtrip(lmbda) + x = ops.random.normal((2, 2)) + y1 = lmbda(x) + y2 = new_lmbda(x) + self.assertAllClose(y1, y2, atol=1e-5) + + @pytest.mark.requires_trainable_backend + def test_dict_inputs_outputs(self): + input_foo = keras.Input((2,), name="foo") + input_bar = keras.Input((2,), name="bar") + dense = keras.layers.Dense(1) + output_foo = dense(input_foo) + output_bar = dense(input_bar) + model = keras.Model( + {"foo": input_foo, "bar": input_bar}, + {"foo": output_foo, "bar": output_bar}, + ) + _, new_model, _ = self.roundtrip(model) + original_output = model( + {"foo": np.zeros((2, 2)), "bar": np.zeros((2, 2))} + ) + restored_output = model( + {"foo": np.zeros((2, 2)), "bar": np.zeros((2, 2))} + ) + self.assertAllClose(original_output["foo"], restored_output["foo"]) + self.assertAllClose(original_output["bar"], restored_output["bar"]) + + @pytest.mark.requires_trainable_backend + def test_shared_inner_layer(self): + with serialization_lib.ObjectSharingScope(): + input_1 = keras.Input((2,)) + input_2 = keras.Input((2,)) + shared_layer = keras.layers.Dense(1) + output_1 = shared_layer(input_1) + wrapper_layer = WrapperLayer(shared_layer) + output_2 = wrapper_layer(input_2) + model = keras.Model([input_1, input_2], [output_1, output_2]) + _, new_model, _ = self.roundtrip( + model, custom_objects={"WrapperLayer": WrapperLayer} + ) + + self.assertIs(model.layers[2], model.layers[3].layer) + self.assertIs(new_model.layers[2], new_model.layers[3].layer) + + @pytest.mark.requires_trainable_backend + def test_functional_subclass(self): + class PlainFunctionalSubclass(keras.Model): + pass + + inputs = keras.Input((2,), batch_size=3) + outputs = keras.layers.Dense(1)(inputs) + model = PlainFunctionalSubclass(inputs, outputs) + x = ops.random.normal((2, 2)) + y1 = model(x) + _, new_model, _ = self.roundtrip( + model, + custom_objects={"PlainFunctionalSubclass": PlainFunctionalSubclass}, + ) + new_model.set_weights(model.get_weights()) + y2 = new_model(x) + self.assertAllClose(y1, y2, atol=1e-5) + self.assertIsInstance(new_model, PlainFunctionalSubclass) + + class FunctionalSubclassWCustomInit(keras.Model): + def __init__(self, num_units=2): + inputs = keras.Input((2,), batch_size=3) + outputs = keras.layers.Dense(num_units)(inputs) + super().__init__(inputs, outputs) + self.num_units = num_units + + def get_config(self): + return {"num_units": self.num_units} + + model = FunctionalSubclassWCustomInit(num_units=3) + x = ops.random.normal((2, 2)) + y1 = model(x) + _, new_model, _ = self.roundtrip( + model, + custom_objects={ + "FunctionalSubclassWCustomInit": FunctionalSubclassWCustomInit + }, + ) + new_model.set_weights(model.get_weights()) + y2 = new_model(x) + self.assertAllClose(y1, y2, atol=1e-5) + self.assertIsInstance(new_model, FunctionalSubclassWCustomInit) + + def test_shared_object(self): + class MyLayer(keras.layers.Layer): + def __init__(self, activation, **kwargs): + super().__init__(**kwargs) + if isinstance(activation, dict): + self.activation = ( + serialization_lib.deserialize_keras_object(activation) + ) + else: + self.activation = activation + + def call(self, x): + return self.activation(x) + + def get_config(self): + config = super().get_config() + config["activation"] = self.activation + return config + + class SharedActivation: + def __call__(self, x): + return x**2 + + def get_config(self): + return {} + + @classmethod + def from_config(cls, config): + return cls() + + shared_act = SharedActivation() + layer_1 = MyLayer(activation=shared_act) + layer_2 = MyLayer(activation=shared_act) + layers = [layer_1, layer_2] + + with serialization_lib.ObjectSharingScope(): + serialized, new_layers, reserialized = self.roundtrip( + layers, + custom_objects={ + "MyLayer": MyLayer, + "SharedActivation": SharedActivation, + }, + ) + self.assertIn("shared_object_id", serialized[0]["config"]["activation"]) + obj_id = serialized[0]["config"]["activation"] + self.assertIn("shared_object_id", serialized[1]["config"]["activation"]) + self.assertEqual(obj_id, serialized[1]["config"]["activation"]) + self.assertIs(layers[0].activation, layers[1].activation) + self.assertIs(new_layers[0].activation, new_layers[1].activation) + + def test_layer_sharing(self): + seq = keras.Sequential( + [ + keras.Input(shape=(3,)), + keras.layers.Dense(5), + keras.layers.Softmax(), + ], + ) + func = keras.Model(inputs=seq.inputs, outputs=seq.outputs) + serialized, deserialized, reserialized = self.roundtrip(func) + self.assertLen(deserialized.layers, 3) + + def test_keras36_custom_function_reloading(self): + @object_registration.register_keras_serializable(package="serial_test") + def custom_registered_fn(x): + return x**2 + + config36 = { + "module": "builtins", + "class_name": "function", + "config": "custom_registered_fn", + "registered_name": "function", + } + obj = serialization_lib.deserialize_keras_object(config36) + self.assertIs(obj, custom_registered_fn) + + config = { + "module": "builtins", + "class_name": "function", + "config": "serial_test>custom_registered_fn", + "registered_name": "function", + } + obj = serialization_lib.deserialize_keras_object(config) + self.assertIs(obj, custom_registered_fn) + + def test_layer_instance_as_activation(self): + """Tests serialization when activation is a Layer instance.""" + + # Dense layer with ReLU layer as activation + layer_dense_relu = keras.layers.Dense( + units=4, activation=keras.layers.ReLU(name="my_relu") + ) + # Build the layer to ensure weights/state are initialized if needed + layer_dense_relu.build(input_shape=(None, 8)) + _, restored_dense_relu, _ = self.roundtrip(layer_dense_relu) + + # Verify the activation is correctly deserialized as a ReLU layer + self.assertIsInstance(restored_dense_relu.activation, keras.layers.ReLU) + # Verify properties are preserved + self.assertEqual(restored_dense_relu.activation.name, "my_relu") + + def test_layer_instance_with_config_as_activation(self): + """ + Tests serialization when activation is a Layer instance with config. + """ + + # Conv1D layer with LeakyReLU layer (with config) as activation + leaky_activation = keras.layers.LeakyReLU( + negative_slope=0.15, name="my_leaky" + ) + layer_conv_leaky = keras.layers.Conv1D( + filters=2, kernel_size=3, activation=leaky_activation + ) + # Build the layer + layer_conv_leaky.build(input_shape=(None, 10, 4)) + _, restored_conv_leaky, _ = self.roundtrip(layer_conv_leaky) + + # Verify the activation is correctly deserialized as LeakyReLU + self.assertIsInstance( + restored_conv_leaky.activation, keras.layers.LeakyReLU + ) + # Verify configuration of the activation layer is preserved + self.assertEqual(restored_conv_leaky.activation.negative_slope, 0.15) + self.assertEqual(restored_conv_leaky.activation.name, "my_leaky") + + def test_layer_string_as_activation(self): + """Tests serialization when activation is a string.""" + + layer_dense_relu_string = keras.layers.Dense(units=4, activation="relu") + layer_dense_relu_string.build(input_shape=(None, 8)) + _, restored_dense_relu_string, _ = self.roundtrip( + layer_dense_relu_string + ) + + # Verify the activation is correctly deserialized to the relu function + self.assertTrue(callable(restored_dense_relu_string.activation)) + # Check if it resolves to the canonical keras activation function + self.assertEqual( + restored_dense_relu_string.activation, keras.activations.relu + ) + + +@keras.saving.register_keras_serializable() +class MyDense(keras.layers.Layer): + def __init__( + self, + units, + *, + kernel_regularizer=None, + kernel_initializer=None, + **kwargs, + ): + super().__init__(**kwargs) + self._units = units + self._kernel_regularizer = kernel_regularizer + self._kernel_initializer = kernel_initializer + + def get_config(self): + return dict( + units=self._units, + kernel_initializer=self._kernel_initializer, + kernel_regularizer=self._kernel_regularizer, + **super().get_config(), + ) + + def build(self, input_shape): + _, input_units = input_shape + self._kernel = self.add_weight( + name="kernel", + shape=[input_units, self._units], + dtype="float32", + regularizer=self._kernel_regularizer, + initializer=self._kernel_initializer, + ) + + def call(self, inputs): + return ops.matmul(inputs, self._kernel) + + +@keras.saving.register_keras_serializable() +class MyWrapper(keras.layers.Layer): + def __init__(self, wrapped, **kwargs): + super().__init__(**kwargs) + self._wrapped = wrapped + + def get_config(self): + return dict(wrapped=self._wrapped, **super().get_config()) + + @classmethod + def from_config(cls, config): + config["wrapped"] = keras.saving.deserialize_keras_object( + config["wrapped"] + ) + return cls(**config) + + def call(self, inputs): + return self._wrapped(inputs) diff --git a/keras/src/testing/__init__.py b/keras/src/testing/__init__.py new file mode 100644 index 000000000000..ae554ff85857 --- /dev/null +++ b/keras/src/testing/__init__.py @@ -0,0 +1,5 @@ +from keras.src.testing.test_case import TestCase +from keras.src.testing.test_case import jax_uses_gpu +from keras.src.testing.test_case import tensorflow_uses_gpu +from keras.src.testing.test_case import torch_uses_gpu +from keras.src.testing.test_case import uses_gpu diff --git a/keras/src/testing/test_case.py b/keras/src/testing/test_case.py new file mode 100644 index 000000000000..1b7ceddfdb78 --- /dev/null +++ b/keras/src/testing/test_case.py @@ -0,0 +1,792 @@ +import json +import shutil +import tempfile +import unittest +from pathlib import Path + +import numpy as np +from absl.testing import parameterized + +from keras.src import backend +from keras.src import distribution +from keras.src import ops +from keras.src import tree +from keras.src import utils +from keras.src.backend.common import is_float_dtype +from keras.src.backend.common import standardize_dtype +from keras.src.backend.common.global_state import clear_session +from keras.src.backend.common.keras_tensor import KerasTensor +from keras.src.losses.loss import Loss +from keras.src.models import Model +from keras.src.utils import traceback_utils + + +class TestCase(parameterized.TestCase, unittest.TestCase): + maxDiff = None + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def setUp(self): + # clear global state so that test cases are independent + # required for the jit enabled torch tests since dynamo has + # a global cache for guards, compiled fn, etc + clear_session(free_memory=False) + if traceback_utils.is_traceback_filtering_enabled(): + traceback_utils.disable_traceback_filtering() + + def get_temp_dir(self): + temp_dir = tempfile.mkdtemp() + self.addCleanup(lambda: shutil.rmtree(temp_dir)) + return temp_dir + + def assertAllClose(self, x1, x2, atol=1e-6, rtol=1e-6, msg=None): + if not isinstance(x1, np.ndarray): + x1 = backend.convert_to_numpy(x1) + if not isinstance(x2, np.ndarray): + x2 = backend.convert_to_numpy(x2) + np.testing.assert_allclose(x1, x2, atol=atol, rtol=rtol, err_msg=msg) + + def assertNotAllClose(self, x1, x2, atol=1e-6, rtol=1e-6, msg=None): + try: + self.assertAllClose(x1, x2, atol=atol, rtol=rtol, msg=msg) + except AssertionError: + return + msg = msg or "" + raise AssertionError( + f"The two values are close at all elements. \n{msg}.\nValues: {x1}" + ) + + def assertAlmostEqual(self, x1, x2, decimal=3, msg=None): + msg = msg or "" + if not isinstance(x1, np.ndarray): + x1 = backend.convert_to_numpy(x1) + if not isinstance(x2, np.ndarray): + x2 = backend.convert_to_numpy(x2) + np.testing.assert_almost_equal(x1, x2, decimal=decimal, err_msg=msg) + + def assertAllEqual(self, x1, x2, msg=None): + self.assertEqual(len(x1), len(x2), msg=msg) + for e1, e2 in zip(x1, x2): + if isinstance(e1, (list, tuple)) or isinstance(e2, (list, tuple)): + self.assertAllEqual(e1, e2, msg=msg) + else: + e1 = backend.convert_to_numpy(e1) + e2 = backend.convert_to_numpy(e2) + self.assertEqual(e1, e2, msg=msg) + + def assertLen(self, iterable, expected_len, msg=None): + self.assertEqual(len(iterable), expected_len, msg=msg) + + def assertSparse(self, x, sparse=True): + if isinstance(x, KerasTensor): + self.assertEqual(x.sparse, sparse) + elif backend.backend() == "tensorflow": + import tensorflow as tf + + if sparse: + self.assertIsInstance(x, tf.SparseTensor) + else: + self.assertNotIsInstance(x, tf.SparseTensor) + elif backend.backend() == "jax": + import jax.experimental.sparse as jax_sparse + + if sparse: + self.assertIsInstance(x, jax_sparse.JAXSparse) + else: + self.assertNotIsInstance(x, jax_sparse.JAXSparse) + else: + self.assertFalse( + sparse, + f"Backend {backend.backend()} does not support sparse tensors", + ) + + def assertRagged(self, x, ragged=True): + if isinstance(x, KerasTensor): + self.assertEqual(x.ragged, ragged) + elif backend.backend() == "tensorflow": + import tensorflow as tf + + if ragged: + self.assertIsInstance(x, tf.RaggedTensor) + else: + self.assertNotIsInstance(x, tf.RaggedTensor) + else: + self.assertFalse( + ragged, + f"Backend {backend.backend()} does not support ragged tensors", + ) + + def assertDType(self, x, dtype, msg=None): + if hasattr(x, "dtype"): + x_dtype = backend.standardize_dtype(x.dtype) + else: + # If x is a python number + x_dtype = backend.standardize_dtype(type(x)) + standardized_dtype = backend.standardize_dtype(dtype) + default_msg = ( + "The dtype of x does not match the expected one. " + f"Received: x.dtype={x_dtype} and dtype={dtype}" + ) + msg = msg or default_msg + self.assertEqual(x_dtype, standardized_dtype, msg=msg) + + def assertFileExists(self, path): + if not Path(path).is_file(): + raise AssertionError(f"File {path} does not exist") + + def run_class_serialization_test(self, instance, custom_objects=None): + from keras.src.saving import custom_object_scope + from keras.src.saving import deserialize_keras_object + from keras.src.saving import serialize_keras_object + + # get_config roundtrip + cls = instance.__class__ + config = instance.get_config() + config_json = to_json_with_tuples(config) + ref_dir = dir(instance)[:] + with custom_object_scope(custom_objects): + revived_instance = cls.from_config(config) + revived_config = revived_instance.get_config() + revived_config_json = to_json_with_tuples(revived_config) + self.assertEqual(config_json, revived_config_json) + self.assertEqual(set(ref_dir), set(dir(revived_instance))) + + # serialization roundtrip + serialized = serialize_keras_object(instance) + serialized_json = to_json_with_tuples(serialized) + with custom_object_scope(custom_objects): + revived_instance = deserialize_keras_object( + from_json_with_tuples(serialized_json) + ) + revived_config = revived_instance.get_config() + revived_config_json = to_json_with_tuples(revived_config) + self.assertEqual(config_json, revived_config_json) + new_dir = dir(revived_instance)[:] + for lst in [ref_dir, new_dir]: + if "__annotations__" in lst: + lst.remove("__annotations__") + self.assertEqual(set(ref_dir), set(new_dir)) + return revived_instance + + def run_layer_test( + self, + layer_cls, + init_kwargs, + input_shape=None, + input_dtype=None, + input_sparse=False, + input_ragged=False, + input_data=None, + call_kwargs=None, + expected_output_shape=None, + expected_output_dtype=None, + expected_output_sparse=False, + expected_output_ragged=False, + expected_output=None, + expected_num_trainable_weights=None, + expected_num_non_trainable_weights=None, + expected_num_non_trainable_variables=None, + expected_num_seed_generators=None, + expected_num_losses=None, + supports_masking=None, + expected_mask_shape=None, + custom_objects=None, + run_training_check=True, + run_mixed_precision_check=True, + assert_built_after_instantiation=False, + ): + """Run basic checks on a layer. + + Args: + layer_cls: The class of the layer to test. + init_kwargs: Dict of arguments to be used to + instantiate the layer. + input_shape: Shape tuple (or list/dict of shape tuples) + to call the layer on. + input_dtype: Corresponding input dtype. + input_sparse: Whether the input is a sparse tensor (this requires + the backend to support sparse tensors). + input_ragged: Whether the input is a ragged tensor (this requires + the backend to support ragged tensors). + input_data: Tensor (or list/dict of tensors) + to call the layer on. + call_kwargs: Dict of arguments to use when calling the + layer (does not include the first input tensor argument) + expected_output_shape: Shape tuple + (or list/dict of shape tuples) + expected as output. + expected_output_dtype: dtype expected as output. + expected_output_sparse: Whether the output is expected to be sparse + (this requires the backend to support sparse tensors). + expected_output_ragged: Whether the output is expected to be ragged + (this requires the backend to support ragged tensors). + expected_output: Expected output tensor -- only + to be specified if input_data is provided. + expected_num_trainable_weights: Expected number + of trainable weights of the layer once built. + expected_num_non_trainable_weights: Expected number + of non-trainable weights of the layer once built. + expected_num_seed_generators: Expected number of + SeedGenerators objects of the layer once built. + expected_num_losses: Expected number of loss tensors + produced when calling the layer. + supports_masking: If True, will check that the layer + supports masking. + expected_mask_shape: Expected mask shape tuple + returned by compute_mask() (only supports 1 shape). + custom_objects: Dict of any custom objects to be + considered during deserialization. + run_training_check: Whether to attempt to train the layer + (if an input shape or input data was provided). + run_mixed_precision_check: Whether to test the layer with a mixed + precision dtype policy. + assert_built_after_instantiation: Whether to assert `built=True` + after the layer's instantiation. + """ + if input_shape is not None and input_data is not None: + raise ValueError( + "input_shape and input_data cannot be passed at the same time." + ) + if expected_output_shape is not None and expected_output is not None: + raise ValueError( + "expected_output_shape and expected_output cannot be passed " + "at the same time." + ) + if expected_output is not None and input_data is None: + raise ValueError( + "In order to use expected_output, input_data must be provided." + ) + if expected_mask_shape is not None and supports_masking is not True: + raise ValueError( + "In order to use expected_mask_shape, supports_masking " + "must be True." + ) + + init_kwargs = init_kwargs or {} + call_kwargs = call_kwargs or {} + + if input_shape is not None and input_dtype is not None: + if isinstance(input_shape, tuple) and is_shape_tuple( + input_shape[0] + ): + self.assertIsInstance(input_dtype, tuple) + self.assertEqual( + len(input_shape), + len(input_dtype), + msg="The number of input shapes and dtypes does not match", + ) + elif isinstance(input_shape, dict): + self.assertIsInstance(input_dtype, dict) + self.assertEqual( + set(input_shape.keys()), + set(input_dtype.keys()), + msg="The number of input shapes and dtypes does not match", + ) + elif isinstance(input_shape, list): + self.assertIsInstance(input_dtype, list) + self.assertEqual( + len(input_shape), + len(input_dtype), + msg="The number of input shapes and dtypes does not match", + ) + elif not isinstance(input_shape, tuple): + raise ValueError("The type of input_shape is not supported") + if input_shape is not None and input_dtype is None: + input_dtype = tree.map_shape_structure( + lambda _: "float32", input_shape + ) + + # Estimate actual number of weights, variables, seed generators if + # expected ones not set. When using layers uses composition it should + # build each sublayer manually. + if input_data is not None or input_shape is not None: + if input_data is None: + input_data = create_eager_tensors( + input_shape, input_dtype, input_sparse, input_ragged + ) + layer = layer_cls(**init_kwargs) + if isinstance(input_data, dict): + layer(**input_data, **call_kwargs) + else: + layer(input_data, **call_kwargs) + + if expected_num_trainable_weights is None: + expected_num_trainable_weights = len(layer.trainable_weights) + if expected_num_non_trainable_weights is None: + expected_num_non_trainable_weights = len( + layer.non_trainable_weights + ) + if expected_num_non_trainable_variables is None: + expected_num_non_trainable_variables = len( + layer.non_trainable_variables + ) + if expected_num_seed_generators is None: + expected_num_seed_generators = len(get_seed_generators(layer)) + + # Serialization test. + layer = layer_cls(**init_kwargs) + self.run_class_serialization_test(layer, custom_objects) + + # Basic masking test. + if supports_masking is not None: + self.assertEqual( + layer.supports_masking, + supports_masking, + msg="Unexpected supports_masking value", + ) + + def run_build_asserts(layer): + self.assertTrue(layer.built) + if expected_num_trainable_weights is not None: + self.assertLen( + layer.trainable_weights, + expected_num_trainable_weights, + msg="Unexpected number of trainable_weights", + ) + if expected_num_non_trainable_weights is not None: + self.assertLen( + layer.non_trainable_weights, + expected_num_non_trainable_weights, + msg="Unexpected number of non_trainable_weights", + ) + if expected_num_non_trainable_variables is not None: + self.assertLen( + layer.non_trainable_variables, + expected_num_non_trainable_variables, + msg="Unexpected number of non_trainable_variables", + ) + if expected_num_seed_generators is not None: + self.assertLen( + get_seed_generators(layer), + expected_num_seed_generators, + msg="Unexpected number of seed_generators", + ) + if ( + backend.backend() == "torch" + and expected_num_trainable_weights is not None + and expected_num_non_trainable_weights is not None + and expected_num_seed_generators is not None + ): + self.assertLen( + layer.torch_params, + expected_num_trainable_weights + + expected_num_non_trainable_weights + + expected_num_seed_generators, + msg="Unexpected number of torch_params", + ) + + def run_output_asserts(layer, output, eager=False): + if expected_output_shape is not None: + + def verify_shape(expected_shape, x): + shape = x.shape + if len(shape) != len(expected_shape): + return False + for expected_dim, dim in zip(expected_shape, shape): + if expected_dim is not None and expected_dim != dim: + return False + return True + + shapes_match = tree.map_structure_up_to( + output, verify_shape, expected_output_shape, output + ) + self.assertTrue( + all(tree.flatten(shapes_match)), + msg=f"Expected output shapes {expected_output_shape} but " + f"received {tree.map_structure(lambda x: x.shape, output)}", + ) + if expected_output_dtype is not None: + + def verify_dtype(expected_dtype, x): + return expected_dtype == backend.standardize_dtype(x.dtype) + + dtypes_match = tree.map_structure( + verify_dtype, expected_output_dtype, output + ) + self.assertTrue( + all(tree.flatten(dtypes_match)), + msg=f"Expected output dtypes {expected_output_dtype} but " + f"received {tree.map_structure(lambda x: x.dtype, output)}", + ) + if expected_output_sparse: + for x in tree.flatten(output): + self.assertSparse(x) + if expected_output_ragged: + for x in tree.flatten(output): + self.assertRagged(x) + if eager: + if expected_output is not None: + self.assertEqual(type(expected_output), type(output)) + for ref_v, v in zip( + tree.flatten(expected_output), tree.flatten(output) + ): + self.assertAllClose( + ref_v, v, msg="Unexpected output value" + ) + if expected_num_losses is not None: + self.assertLen(layer.losses, expected_num_losses) + + def run_training_step(layer, input_data, output_data): + class TestModel(Model): + def __init__(self, layer): + super().__init__() + self.layer = layer + + def call(self, x, training=False): + return self.layer(x, training=training) + + model = TestModel(layer) + + data = (input_data, output_data) + if backend.backend() == "torch": + data = tree.map_structure(backend.convert_to_numpy, data) + + def data_generator(): + while True: + yield data + + # Single op loss to avoid compilation issues with ragged / sparse. + class TestLoss(Loss): + def __call__(self, y_true, y_pred, sample_weight=None): + return ops.sum(y_pred) + + # test the "default" path for each backend by setting + # jit_compile="auto". + # for tensorflow and jax backends auto is jitted + # Note that tensorflow cannot be jitted with sparse tensors + # for torch backend auto is eager + # + # NB: for torch, jit_compile=True turns on torchdynamo + # which may not always succeed in tracing depending + # on the model. Run your program with these env vars + # to get debug traces of dynamo: + # TORCH_LOGS="+dynamo" + # TORCHDYNAMO_VERBOSE=1 + # TORCHDYNAMO_REPORT_GUARD_FAILURES=1 + jit_compile = "auto" + if backend.backend() == "tensorflow" and input_sparse: + jit_compile = False + model.compile( + optimizer="sgd", loss=TestLoss(), jit_compile=jit_compile + ) + model.fit(data_generator(), steps_per_epoch=1, verbose=0) + + # Build test. + if input_data is not None or input_shape is not None: + if input_shape is None: + build_shape = tree.map_structure( + lambda x: ops.shape(x), input_data + ) + else: + build_shape = input_shape + layer = layer_cls(**init_kwargs) + if isinstance(build_shape, dict): + layer.build(**build_shape) + else: + layer.build(build_shape) + run_build_asserts(layer) + + # Symbolic call test. + if input_shape is None: + keras_tensor_inputs = tree.map_structure( + lambda x: create_keras_tensors( + ops.shape(x), x.dtype, input_sparse, input_ragged + ), + input_data, + ) + else: + keras_tensor_inputs = create_keras_tensors( + input_shape, input_dtype, input_sparse, input_ragged + ) + layer = layer_cls(**init_kwargs) + if isinstance(keras_tensor_inputs, dict): + keras_tensor_outputs = layer( + **keras_tensor_inputs, **call_kwargs + ) + else: + keras_tensor_outputs = layer(keras_tensor_inputs, **call_kwargs) + run_build_asserts(layer) + run_output_asserts(layer, keras_tensor_outputs, eager=False) + + if expected_mask_shape is not None: + output_mask = layer.compute_mask(keras_tensor_inputs) + self.assertEqual(expected_mask_shape, output_mask.shape) + + # The stateless layers should be built after instantiation. + if assert_built_after_instantiation: + layer = layer_cls(**init_kwargs) + self.assertTrue( + layer.built, + msg=( + f"{type(layer)} is stateless, so it should be built " + "after instantiation." + ), + ) + + # Ensure that the subclass layer doesn't mark itself as built + # when `build` is overridden. + + class ModifiedBuildLayer(layer_cls): + def build(self, *args, **kwargs): + pass + + layer = ModifiedBuildLayer(**init_kwargs) + self.assertFalse( + layer.built, + msg=( + f"The `build` of {type(layer)} is overriden, so it " + "should not be built after instantiation." + ), + ) + + # Eager call test and compiled training test. + if input_data is not None or input_shape is not None: + if input_data is None: + input_data = create_eager_tensors( + input_shape, input_dtype, input_sparse + ) + layer = layer_cls(**init_kwargs) + if isinstance(input_data, dict): + output_data = layer(**input_data, **call_kwargs) + else: + output_data = layer(input_data, **call_kwargs) + run_output_asserts(layer, output_data, eager=True) + + if run_training_check: + run_training_step(layer, input_data, output_data) + + # Never test mixed precision on torch CPU. Torch lacks support. + if run_mixed_precision_check and backend.backend() == "torch": + import torch + + run_mixed_precision_check = torch.cuda.is_available() + + if run_mixed_precision_check: + layer = layer_cls(**{**init_kwargs, "dtype": "mixed_float16"}) + input_spec = tree.map_structure( + lambda spec: KerasTensor( + spec.shape, + dtype=( + layer.compute_dtype + if layer.autocast + and backend.is_float_dtype(spec.dtype) + else spec.dtype + ), + ), + keras_tensor_inputs, + ) + if isinstance(input_data, dict): + output_data = layer(**input_data, **call_kwargs) + output_spec = layer.compute_output_spec(**input_spec) + else: + output_data = layer(input_data, **call_kwargs) + output_spec = layer.compute_output_spec(input_spec) + for tensor, spec in zip( + tree.flatten(output_data), tree.flatten(output_spec) + ): + dtype = standardize_dtype(tensor.dtype) + self.assertEqual( + dtype, + spec.dtype, + f"expected output dtype {spec.dtype}, got {dtype}", + ) + for weight in layer.weights: + dtype = standardize_dtype(weight.dtype) + if is_float_dtype(dtype): + self.assertEqual(dtype, "float32") + + +def tensorflow_uses_gpu(): + return backend.backend() == "tensorflow" and uses_gpu() + + +def jax_uses_gpu(): + return backend.backend() == "jax" and uses_gpu() + + +def torch_uses_gpu(): + if backend.backend() != "torch": + return False + from keras.src.backend.torch.core import get_device + + return get_device() == "cuda" + + +def uses_gpu(): + # Condition used to skip tests when using the GPU + devices = distribution.list_devices() + if any(d.startswith("gpu") for d in devices): + return True + return False + + +def uses_cpu(): + devices = distribution.list_devices() + if any(d.startswith("cpu") for d in devices): + return True + return False + + +def create_keras_tensors(input_shape, dtype, sparse, ragged): + if isinstance(input_shape, dict): + return { + utils.removesuffix(k, "_shape"): KerasTensor( + v, dtype=dtype[k], sparse=sparse, ragged=ragged + ) + for k, v in input_shape.items() + } + return map_shape_dtype_structure( + lambda shape, dt: KerasTensor( + shape, dtype=dt, sparse=sparse, ragged=ragged + ), + input_shape, + dtype, + ) + + +def create_eager_tensors(input_shape, dtype, sparse, ragged): + from keras.src.backend import random + + if set(tree.flatten(dtype)).difference( + [ + "float16", + "float32", + "float64", + "int8", + "uint8", + "int16", + "uint16", + "int32", + "uint32", + "int64", + "uint64", + ] + ): + raise ValueError( + "dtype must be a standard float or int dtype. " + f"Received: dtype={dtype}" + ) + + if sparse: + if backend.backend() == "tensorflow": + import tensorflow as tf + + def create_fn(shape, dt): + rng = np.random.default_rng(0) + x = (4 * rng.standard_normal(shape)).astype(dt) + x = np.multiply(x, rng.random(shape) < 0.7) + return tf.sparse.from_dense(x) + + elif backend.backend() == "jax": + import jax.experimental.sparse as jax_sparse + + def create_fn(shape, dt): + rng = np.random.default_rng(0) + x = (4 * rng.standard_normal(shape)).astype(dt) + x = np.multiply(x, rng.random(shape) < 0.7) + return jax_sparse.BCOO.fromdense(x, n_batch=1) + + else: + raise ValueError( + f"Sparse is unsupported with backend {backend.backend()}" + ) + + elif ragged: + if backend.backend() == "tensorflow": + import tensorflow as tf + + def create_fn(shape, dt): + rng = np.random.default_rng(0) + x = (4 * rng.standard_normal(shape)).astype(dt) + x = np.multiply(x, rng.random(shape) < 0.7) + return tf.RaggedTensor.from_tensor(x, padding=0) + + else: + raise ValueError( + f"Ragged is unsupported with backend {backend.backend()}" + ) + + else: + + def create_fn(shape, dt): + return ops.cast( + random.uniform(shape, dtype="float32") * 3, dtype=dt + ) + + if isinstance(input_shape, dict): + return { + utils.removesuffix(k, "_shape"): create_fn(v, dtype[k]) + for k, v in input_shape.items() + } + return map_shape_dtype_structure(create_fn, input_shape, dtype) + + +def is_shape_tuple(x): + return isinstance(x, (list, tuple)) and all( + isinstance(e, (int, type(None))) for e in x + ) + + +def map_shape_dtype_structure(fn, shape, dtype): + """Variant of tree.map_structure that operates on shape tuples.""" + if is_shape_tuple(shape): + return fn(tuple(shape), dtype) + if isinstance(shape, list): + return [ + map_shape_dtype_structure(fn, s, d) for s, d in zip(shape, dtype) + ] + if isinstance(shape, tuple): + return tuple( + map_shape_dtype_structure(fn, s, d) for s, d in zip(shape, dtype) + ) + if isinstance(shape, dict): + return { + k: map_shape_dtype_structure(fn, v, dtype[k]) + for k, v in shape.items() + } + else: + raise ValueError( + f"Cannot map function to unknown objects {shape} and {dtype}" + ) + + +def get_seed_generators(layer): + """Get a List of all seed generators in the layer recursively.""" + seed_generators = [] + seen_ids = set() + for sublayer in layer._flatten_layers(True, True): + for sg in sublayer._seed_generators: + if id(sg) not in seen_ids: + seed_generators.append(sg) + seen_ids.add(id(sg)) + return seed_generators + + +def to_json_with_tuples(value): + def _tuple_encode(obj): + if isinstance(obj, tuple): + return {"__class__": "tuple", "__value__": list(obj)} + if isinstance(obj, list): + return [_tuple_encode(e) for e in obj] + if isinstance(obj, dict): + return {key: _tuple_encode(value) for key, value in obj.items()} + return obj + + class _PreserveTupleJsonEncoder(json.JSONEncoder): + def encode(self, obj): + obj = _tuple_encode(obj) + return super().encode(obj) + + return _PreserveTupleJsonEncoder(sort_keys=True, indent=4).encode(value) + + +def from_json_with_tuples(value): + def _tuple_decode(obj): + if not isinstance(obj, dict): + return obj + if "__class__" not in obj or "__value__" not in obj: + return obj + return tuple(obj["__value__"]) + + return json.loads(value, object_hook=_tuple_decode) diff --git a/keras/src/testing/test_utils.py b/keras/src/testing/test_utils.py new file mode 100644 index 000000000000..0df3645ff6df --- /dev/null +++ b/keras/src/testing/test_utils.py @@ -0,0 +1,163 @@ +import numpy as np + + +def get_test_data( + train_samples, test_samples, input_shape, num_classes, random_seed=None +): + """Generates balanced, stratified synthetic test data to train a model on. + + Args: + train_samples: Integer, how many training samples to generate. + test_samples: Integer, how many test samples to generate. + input_shape: Tuple of integers, shape of the inputs. + num_classes: Integer, number of classes for the data and targets. + random_seed: Integer, random seed used by Numpy to generate data. + + Returns: + A tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`. + """ + np.random.seed(random_seed) + + # Total samples + total_samples = train_samples + test_samples + + # Ensure that we generate a balanced dataset + samples_per_class = total_samples // num_classes + y = np.array( + [i for i in range(num_classes) for _ in range(samples_per_class)], + dtype=np.int32, + ) + + # Generate extra samples in a deterministic manner + extra_samples = total_samples - len(y) + y_extra = np.array( + [i % num_classes for i in range(extra_samples)], dtype=np.int64 + ) + y = np.concatenate([y, y_extra]) + + # Generate data + templates = 2 * num_classes * np.random.random((num_classes,) + input_shape) + x = np.zeros((total_samples,) + input_shape, dtype=np.float32) + for i in range(total_samples): + x[i] = templates[y[i]] + np.random.normal( + loc=0, scale=1.0, size=input_shape + ) + + # Shuffle the entire dataset to ensure randomness based on seed + indices = np.arange(total_samples) + np.random.shuffle(indices) + x, y = x[indices], y[indices] + + # Stratified Shuffle Split + x_train, y_train, x_test, y_test = [], [], [], [] + for cls in range(num_classes): + cls_indices = np.where(y == cls)[0] + np.random.shuffle(cls_indices) + train_count = int(train_samples / num_classes) + + x_train.extend(x[cls_indices[:train_count]]) + y_train.extend(y[cls_indices[:train_count]]) + + x_test.extend(x[cls_indices[train_count:]]) + y_test.extend(y[cls_indices[train_count:]]) + + # Convert to numpy arrays + x_train, y_train = np.array(x_train), np.array(y_train) + x_test, y_test = np.array(x_test), np.array(y_test) + + # Shuffle training and test sets after stratified split + train_indices = np.arange(len(x_train)) + test_indices = np.arange(len(x_test)) + np.random.shuffle(train_indices) + np.random.shuffle(test_indices) + + x_train, y_train = x_train[train_indices], y_train[train_indices] + x_test, y_test = x_test[test_indices], y_test[test_indices] + + return (x_train, y_train), (x_test, y_test) + + +def named_product(*args, **kwargs): + """Utility to generate the cartesian product of parameters values and + generate a test case names for each combination. + + The result of this function is to be used with the + `@parameterized.named_parameters` decorator. It is a replacement for + `@parameterized.product` which adds explicit test case names. + + For example, this code: + ``` + class NamedExample(parameterized.TestCase): + @parameterized.named_parameters( + named_product( + [ + {'testcase_name': 'negative', 'x': -1}, + {'testcase_name': 'positive', 'x': 1}, + {'testcase_name': 'zero', 'x': 0}, + ], + numeral_type=[float, int], + ) + ) + def test_conversion(self, x, numeral_type): + self.assertEqual(numeral_type(x), x) + ``` + produces six tests (note that absl will reorder them by name): + - `NamedExample::test_conversion_negative_float` + - `NamedExample::test_conversion_positive_float` + - `NamedExample::test_conversion_zero_float` + - `NamedExample::test_conversion_negative_int` + - `NamedExample::test_conversion_positive_int` + - `NamedExample::test_conversion_zero_int` + + This function is also useful in the case where there is no product to + generate test case names for one argument: + ``` + @parameterized.named_parameters(named_product(numeral_type=[float, int])) + ``` + + Args: + *args: Each positional parameter is a sequence of keyword arg dicts. + Every test case generated will include exactly one dict from each + positional parameter. These will then be merged to form an overall + list of arguments for the test case. Each dict must contain a + `"testcase_name"` key whose value is combined with others to + generate the test case name. + **kwargs: A mapping of parameter names and their possible values. + Possible values should given as either a list or a tuple. A string + representation of each value is used to generate the test case name. + + Returns: + A list of maps for the test parameters combinations to pass to + `@parameterized.named_parameters`. + """ + + def value_to_str(value): + if hasattr(value, "__name__"): + return value.__name__.lower() + return str(value).lower() + + # Convert the keyword arguments in the same dict format as the args + all_test_dicts = args + tuple( + tuple({"testcase_name": value_to_str(v), key: v} for v in values) + for key, values in kwargs.items() + ) + + # The current list of tests, start with one empty test + tests = [{}] + for test_dicts in all_test_dicts: + new_tests = [] + for test_dict in test_dicts: + for test in tests: + # Augment the testcase name by appending + testcase_name = test.get("testcase_name", "") + testcase_name += "_" if testcase_name else "" + testcase_name += test_dict["testcase_name"] + new_test = test.copy() + # Augment the test by adding all the parameters + new_test.update(test_dict) + new_test["testcase_name"] = testcase_name + new_tests.append(new_test) + # Overwrite the list of tests with the product obtained so far + tests = new_tests + + return tests diff --git a/keras/src/testing/test_utils_test.py b/keras/src/testing/test_utils_test.py new file mode 100644 index 000000000000..f0b6591c79de --- /dev/null +++ b/keras/src/testing/test_utils_test.py @@ -0,0 +1,290 @@ +import numpy as np +from absl.testing import parameterized + +from keras.src.testing import test_case +from keras.src.testing import test_utils + + +class GetTestDataTest(test_case.TestCase): + def setUp(self): + self.train_samples = 100 + self.test_samples = 50 + self.input_shape = (28, 28) + self.num_classes = 10 + + def test_labels_within_range(self): + """Check if labels are within valid range.""" + (_, y_train), (_, y_test) = test_utils.get_test_data( + self.train_samples, + self.test_samples, + self.input_shape, + self.num_classes, + ) + self.assertTrue(np.all(y_train < self.num_classes)) + self.assertTrue(np.all(y_train >= 0)) + self.assertTrue(np.all(y_test < self.num_classes)) + self.assertTrue(np.all(y_test >= 0)) + + def test_edge_cases_for_zero_samples(self): + """Test when train or test samples are zero.""" + (x_train, _), (x_test, _) = test_utils.get_test_data( + 0, self.test_samples, self.input_shape, self.num_classes + ) + self.assertEqual(len(x_train), 0) + + (x_train, _), (x_test, _) = test_utils.get_test_data( + self.train_samples, 0, self.input_shape, self.num_classes + ) + self.assertEqual(len(x_test), 0) + + def test_get_test_data_returns_correct_number_of_samples(self): + """Check if returned samples count is correct.""" + (x_train, y_train), (x_test, y_test) = test_utils.get_test_data( + self.train_samples, + self.test_samples, + self.input_shape, + self.num_classes, + ) + self.assertEqual(len(x_train), self.train_samples) + self.assertEqual(len(y_train), self.train_samples) + self.assertEqual(len(x_test), self.test_samples) + self.assertEqual(len(y_test), self.test_samples) + + def test_get_test_data_returns_correct_shape_of_data(self): + """Check if returned data shape is correct.""" + (x_train, y_train), (x_test, y_test) = test_utils.get_test_data( + self.train_samples, + self.test_samples, + self.input_shape, + self.num_classes, + ) + self.assertEqual( + x_train.shape, (self.train_samples,) + self.input_shape + ) + self.assertEqual(y_train.shape, (self.train_samples,)) + self.assertEqual(x_test.shape, (self.test_samples,) + self.input_shape) + self.assertEqual(y_test.shape, (self.test_samples,)) + + def test_get_test_data_returns_different_data_for_different_seeds(self): + """Test variability with different seeds.""" + (x_train_1, y_train_1), (x_test_1, y_test_1) = test_utils.get_test_data( + self.train_samples, + self.test_samples, + self.input_shape, + self.num_classes, + random_seed=1, + ) + (x_train_2, y_train_2), (x_test_2, y_test_2) = test_utils.get_test_data( + self.train_samples, + self.test_samples, + self.input_shape, + self.num_classes, + random_seed=2, + ) + self.assertFalse(np.array_equal(x_train_1, x_train_2)) + self.assertFalse(np.array_equal(y_train_1, y_train_2)) + self.assertFalse(np.array_equal(x_test_1, x_test_2)) + self.assertFalse(np.array_equal(y_test_1, y_test_2)) + + def test_get_test_data_returns_consistent_data_for_same_seed(self): + """Test consistency with the same seed.""" + (x_train_1, y_train_1), (x_test_1, y_test_1) = test_utils.get_test_data( + self.train_samples, + self.test_samples, + self.input_shape, + self.num_classes, + random_seed=1, + ) + (x_train_2, y_train_2), (x_test_2, y_test_2) = test_utils.get_test_data( + self.train_samples, + self.test_samples, + self.input_shape, + self.num_classes, + random_seed=1, + ) + self.assertTrue(np.array_equal(x_train_1, x_train_2)) + self.assertTrue(np.array_equal(y_train_1, y_train_2)) + self.assertTrue(np.array_equal(x_test_1, x_test_2)) + self.assertTrue(np.array_equal(y_test_1, y_test_2)) + + def test_input_shape_variations(self): + """Check function for different input shapes.""" + input_shape_3d = (28, 28, 3) + (x_train_3d, _), (_, _) = test_utils.get_test_data( + self.train_samples, + self.test_samples, + input_shape_3d, + self.num_classes, + ) + self.assertEqual( + x_train_3d.shape, (self.train_samples,) + input_shape_3d + ) + + def test_all_classes_represented(self): + """Ensure all classes are represented in the data.""" + (_, y_train), (_, y_test) = test_utils.get_test_data( + self.train_samples, + self.test_samples, + self.input_shape, + self.num_classes, + ) + self.assertEqual(len(np.unique(y_train)), self.num_classes) + self.assertEqual(len(np.unique(y_test)), self.num_classes) + + def test_data_type(self): + """Validate the type of the generated data.""" + (x_train, _), (x_test, _) = test_utils.get_test_data( + self.train_samples, + self.test_samples, + self.input_shape, + self.num_classes, + ) + self.assertEqual(x_train.dtype, np.float32) + self.assertEqual(x_test.dtype, np.float32) + + def test_label_type(self): + """Validate label type of the generated labels.""" + (_, y_train), (_, y_test) = test_utils.get_test_data( + self.train_samples, + self.test_samples, + self.input_shape, + self.num_classes, + ) + self.assertEqual(y_train.dtype, np.int64) + self.assertEqual(y_test.dtype, np.int64) + + +class ClassDistributionTests(test_case.TestCase): + def setUp(self): + self.train_samples = 100 + self.test_samples = 50 + self.input_shape = (28, 28) + self.num_classes = 10 + + def test_equal_class_distribution(self): + """Verify equal class distribution in train and test sets.""" + (_, y_train), (_, y_test) = test_utils.get_test_data( + self.train_samples, + self.test_samples, + self.input_shape, + self.num_classes, + ) + _, counts_train = np.unique(y_train, return_counts=True) + _, counts_test = np.unique(y_test, return_counts=True) + + self.assertTrue( + np.all(counts_train == self.train_samples // self.num_classes) + ) + self.assertTrue( + np.all(counts_test == self.test_samples // self.num_classes) + ) + + def test_uneven_samples_class_distribution(self): + """Check class distribution with uneven samples.""" + train_samples = 103 + test_samples = 52 + (_, y_train), (_, y_test) = test_utils.get_test_data( + train_samples, + test_samples, + self.input_shape, + self.num_classes, + ) + _, counts_train = np.unique(y_train, return_counts=True) + _, counts_test = np.unique(y_test, return_counts=True) + + self.assertTrue(np.max(counts_train) - np.min(counts_train) <= 1) + self.assertTrue(np.max(counts_test) - np.min(counts_test) <= 1) + + def test_randomness_in_class_distribution(self): + """Ensure class distribution isn't too deterministic.""" + (_, y_train_1), (_, y_test_1) = test_utils.get_test_data( + self.train_samples, + self.test_samples, + self.input_shape, + self.num_classes, + ) + (_, y_train_2), (_, y_test_2) = test_utils.get_test_data( + self.train_samples, + self.test_samples, + self.input_shape, + self.num_classes, + ) + self.assertFalse(np.array_equal(y_train_1, y_train_2)) + self.assertFalse(np.array_equal(y_test_1, y_test_2)) + + def test_large_number_of_classes(self): + """Validate function with a large number of classes.""" + num_classes = 150 + train_samples = ( + num_classes * 10 + ) # 10 samples for each class in training + test_samples = num_classes * 5 # 5 samples for each class in testing + (_, y_train), (_, y_test) = test_utils.get_test_data( + train_samples, + test_samples, + self.input_shape, + num_classes, + ) + self.assertEqual(len(np.unique(y_train)), num_classes) + self.assertEqual(len(np.unique(y_test)), num_classes) + + def test_single_class(self): + """Test with a single class.""" + num_classes = 1 + (_, y_train), (_, y_test) = test_utils.get_test_data( + self.train_samples, + self.test_samples, + self.input_shape, + num_classes, + ) + self.assertTrue(np.all(y_train == 0)) + self.assertTrue(np.all(y_test == 0)) + + +class NamedProductTest(parameterized.TestCase): + def test_test_cases(self): + all_tests = test_utils.named_product( + [ + {"testcase_name": "negative", "x": -1}, + {"testcase_name": "positive", "x": 1}, + {"testcase_name": "zero", "x": 0}, + ], + numeral_type=[float, int], + ) + names = [test["testcase_name"] for test in all_tests] + self.assertListEqual( + names, + [ + "negative_float", + "positive_float", + "zero_float", + "negative_int", + "positive_int", + "zero_int", + ], + ) + + def test_test_cases_no_product(self): + all_tests = test_utils.named_product(numeral_type=[float, int]) + names = [test["testcase_name"] for test in all_tests] + self.assertListEqual(names, ["float", "int"]) + + @parameterized.named_parameters( + test_utils.named_product( + [ + {"testcase_name": "negative", "x": -1}, + {"testcase_name": "positive", "x": 1}, + {"testcase_name": "zero", "x": 0}, + ], + numeral_type=[float, int], + ) + ) + def test_via_decorator(self, x, numeral_type): + self.assertIn(x, (-1, 1, 0)) + self.assertIn(numeral_type, (float, int)) + + @parameterized.named_parameters( + test_utils.named_product(numeral_type=[float, int]) + ) + def test_via_decorator_no_product(self, numeral_type): + self.assertIn(numeral_type, (float, int)) diff --git a/keras/src/trainers/__init__.py b/keras/src/trainers/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/keras/src/trainers/compile_utils.py b/keras/src/trainers/compile_utils.py new file mode 100644 index 000000000000..d911aa805ca0 --- /dev/null +++ b/keras/src/trainers/compile_utils.py @@ -0,0 +1,840 @@ +from collections import namedtuple + +from keras.src import losses as losses_module +from keras.src import metrics as metrics_module +from keras.src import ops +from keras.src import tree +from keras.src.backend.common.keras_tensor import KerasTensor +from keras.src.losses import loss as loss_module +from keras.src.utils.naming import get_object_name +from keras.src.utils.tracking import Tracker + + +class MetricsList(metrics_module.Metric): + def __init__(self, metrics, name="metrics_list", output_name=None): + super().__init__(name=name) + self.metrics = metrics + self.output_name = output_name + + def update_state(self, y_true, y_pred, sample_weight=None): + for m in self.metrics: + m.update_state(y_true, y_pred, sample_weight=sample_weight) + + def reset_state(self): + for m in self.metrics: + m.reset_state() + + def get_result(self): + return {m.name: m.result() for m in self.metrics} + + def get_config(self): + raise NotImplementedError + + @classmethod + def from_config(cls, config): + raise NotImplementedError + + +def is_function_like(value): + if value is None: + return True + if isinstance(value, str): + return True + if callable(value): + return True + return False + + +def is_binary_or_sparse_categorical(y_true, y_pred): + y_t_rank = len(y_true.shape) + y_p_rank = len(y_pred.shape) + y_t_last_dim = y_true.shape[-1] + y_p_last_dim = y_pred.shape[-1] + + is_binary = y_p_last_dim == 1 + is_sparse_categorical = ( + y_t_rank < y_p_rank or y_t_last_dim == 1 and y_p_last_dim > 1 + ) + return is_binary, is_sparse_categorical + + +def get_metric(identifier, y_true, y_pred): + if identifier is None: + return None # Ok to have no metric for an output. + + # Convenience feature for selecting b/t binary, categorical, + # and sparse categorical. + if str(identifier).lower() not in ["accuracy", "acc"]: + metric_obj = metrics_module.get(identifier) + else: + is_binary, is_sparse_categorical = is_binary_or_sparse_categorical( + y_true, y_pred + ) + if is_binary: + metric_obj = metrics_module.BinaryAccuracy(name=str(identifier)) + elif is_sparse_categorical: + metric_obj = metrics_module.SparseCategoricalAccuracy( + name=str(identifier) + ) + else: + metric_obj = metrics_module.CategoricalAccuracy( + name=str(identifier) + ) + + if isinstance(identifier, str): + metric_name = identifier + else: + metric_name = get_object_name(metric_obj) + + if not isinstance(metric_obj, metrics_module.Metric): + metric_obj = metrics_module.MeanMetricWrapper(metric_obj) + + metric_obj.name = metric_name + return metric_obj + + +def get_loss(identifier, y_true, y_pred): + if identifier is None: + return None # Ok to have no loss for an output. + + # Convenience feature for selecting b/t binary, categorical, + # and sparse categorical. + if str(identifier).lower() not in ["crossentropy", "ce"]: + loss_obj = losses_module.get(identifier) + else: + is_binary, is_sparse_categorical = is_binary_or_sparse_categorical( + y_true, y_pred + ) + if is_binary: + loss_obj = losses_module.binary_crossentropy + elif is_sparse_categorical: + loss_obj = losses_module.sparse_categorical_crossentropy + else: + loss_obj = losses_module.categorical_crossentropy + + if not isinstance(loss_obj, losses_module.Loss): + if isinstance(identifier, str): + loss_name = identifier + else: + loss_name = get_object_name(loss_obj) + loss_obj = losses_module.LossFunctionWrapper(loss_obj, name=loss_name) + return loss_obj + + +class CompileMetrics(metrics_module.Metric): + def __init__( + self, + metrics, + weighted_metrics, + name="compile_metric", + output_names=None, + ): + super().__init__(name=name) + if metrics and not isinstance(metrics, (list, tuple, dict)): + raise ValueError( + "Expected `metrics` argument to be a list, tuple, or dict. " + f"Received instead: metrics={metrics} of type {type(metrics)}" + ) + if weighted_metrics and not isinstance( + weighted_metrics, (list, tuple, dict) + ): + raise ValueError( + "Expected `weighted_metrics` argument to be a list, tuple, or " + f"dict. Received instead: weighted_metrics={weighted_metrics} " + f"of type {type(weighted_metrics)}" + ) + self._user_metrics = metrics + self._user_weighted_metrics = weighted_metrics + self.built = False + self.name = "compile_metrics" + self.output_names = output_names + self._resolved_output_names = None + + @property + def metrics(self): + if not self.built: + return [] + metrics = [] + for m in self._flat_metrics + self._flat_weighted_metrics: + if isinstance(m, MetricsList): + metrics.extend(m.metrics) + elif m is not None: + metrics.append(m) + return metrics + + @property + def variables(self): + # Avoiding relying on implicit tracking since + # CompileMetrics may be instantiated or built in a no tracking scope. + if not self.built: + return [] + vars = [] + for m in self.metrics: + if m is not None: + vars.extend(m.variables) + return vars + + def build(self, y_true, y_pred): + num_outputs = 1 # default + # Resolve output names. If y_pred is a dict, prefer its keys. + if isinstance(y_pred, dict): + keys = sorted(list(y_pred.keys())) + if self.output_names and set(self.output_names) == set(keys): + # If there is a perfect match, use the user-provided order. + output_names = self.output_names + else: + output_names = keys + elif self.output_names: + output_names = self.output_names + elif isinstance(y_pred, (list, tuple)): + num_outputs = len(y_pred) + if all(hasattr(x, "_keras_history") for x in y_pred): + output_names = [x._keras_history.operation.name for x in y_pred] + else: + output_names = None + else: + output_names = None + self._resolved_output_names = output_names + if output_names: + num_outputs = len(output_names) + + y_pred = self._flatten_y(y_pred) + y_true = self._flatten_y(y_true) + + metrics = self._user_metrics + weighted_metrics = self._user_weighted_metrics + self._flat_metrics = self._build_metrics_set( + metrics, + num_outputs, + output_names, + y_true, + y_pred, + argument_name="metrics", + ) + self._flat_weighted_metrics = self._build_metrics_set( + weighted_metrics, + num_outputs, + output_names, + y_true, + y_pred, + argument_name="weighted_metrics", + ) + self.built = True + + def _build_metrics_set( + self, metrics, num_outputs, output_names, y_true, y_pred, argument_name + ): + flat_metrics = [] + if isinstance(metrics, dict): + for name in metrics.keys(): + if name not in output_names: + raise ValueError( + f"In the dict argument `{argument_name}`, key " + f"'{name}' does not correspond to any model " + f"output. Received:\n{argument_name}={metrics}" + ) + if num_outputs == 1: + if not metrics: + flat_metrics.append(None) + else: + if isinstance(metrics, dict): + metrics = tree.flatten(metrics) + if not isinstance(metrics, list): + metrics = [metrics] + if not all(is_function_like(m) for m in metrics): + raise ValueError( + f"Expected all entries in the `{argument_name}` list " + f"to be metric objects. Received instead:\n" + f"{argument_name}={metrics}" + ) + flat_metrics.append( + MetricsList( + [ + get_metric(m, y_true[0], y_pred[0]) + for m in metrics + if m is not None + ] + ) + ) + else: + if isinstance(metrics, (list, tuple)): + if len(metrics) != len(y_pred): + raise ValueError( + "For a model with multiple outputs, " + f"when providing the `{argument_name}` argument as a " + "list, it should have as many entries as the model has " + f"outputs. Received:\n{argument_name}={metrics}\nof " + f"length {len(metrics)} whereas the model has " + f"{len(y_pred)} outputs." + ) + for idx, (mls, yt, yp) in enumerate( + zip(metrics, y_true, y_pred) + ): + if not isinstance(mls, list): + mls = [mls] + name = output_names[idx] if output_names else None + if not all(is_function_like(e) for e in mls): + raise ValueError( + f"All entries in the sublists of the " + f"`{argument_name}` list should be metric objects. " + f"Found the following sublist with unknown " + f"types: {mls}" + ) + flat_metrics.append( + MetricsList( + [ + get_metric(m, yt, yp) + for m in mls + if m is not None + ], + output_name=name, + ) + ) + elif isinstance(metrics, dict): + if output_names is None: + raise ValueError( + f"Argument `{argument_name}` can only be provided as a " + "dict when the model also returns a dict of outputs. " + f"Received {argument_name}={metrics}" + ) + for name in metrics.keys(): + if not isinstance(metrics[name], list): + metrics[name] = [metrics[name]] + if not all(is_function_like(e) for e in metrics[name]): + raise ValueError( + f"All entries in the sublists of the " + f"`{argument_name}` dict should be metric objects. " + f"At key '{name}', found the following sublist " + f"with unknown types: {metrics[name]}" + ) + for name, yt, yp in zip(output_names, y_true, y_pred): + if name in metrics: + flat_metrics.append( + MetricsList( + [ + get_metric(m, yt, yp) + for m in metrics[name] + if m is not None + ], + output_name=name, + ) + ) + else: + flat_metrics.append(None) + return flat_metrics + + def _flatten_y(self, y): + names = self._resolved_output_names + if isinstance(y, dict) and names: + result = [] + for name in names: + if name in y: + result.append(y[name]) + return result + return tree.flatten(y) + + def update_state(self, y_true, y_pred, sample_weight=None): + if not self.built: + self.build(y_true, y_pred) + y_true = self._flatten_y(y_true) + y_pred = self._flatten_y(y_pred) + for m, y_t, y_p in zip(self._flat_metrics, y_true, y_pred): + if m: + m.update_state(y_t, y_p) + if sample_weight is not None: + sample_weight = self._flatten_y(sample_weight) + # For multi-outputs, repeat sample weights for n outputs. + if len(sample_weight) < len(y_true): + sample_weight = [sample_weight[0] for _ in range(len(y_true))] + else: + sample_weight = [None for _ in range(len(y_true))] + for m, y_t, y_p, s_w in zip( + self._flat_weighted_metrics, y_true, y_pred, sample_weight + ): + if m: + m.update_state(y_t, y_p, s_w) + + def reset_state(self): + if not self.built: + return + for m in self._flat_metrics: + if m: + m.reset_state() + for m in self._flat_weighted_metrics: + if m: + m.reset_state() + + def result(self): + if not self.built: + raise ValueError( + "Cannot get result() since the metric has not yet been built." + ) + results = {} + unique_name_counters = {} + for mls in self._flat_metrics: + if not mls: + continue + for m in mls.metrics: + name = m.name + if mls.output_name: + name = f"{mls.output_name}_{name}" + if name not in unique_name_counters: + results[name] = m.result() + unique_name_counters[name] = 1 + else: + index = unique_name_counters[name] + unique_name_counters[name] += 1 + name = f"{name}_{index}" + results[name] = m.result() + + for mls in self._flat_weighted_metrics: + if not mls: + continue + for m in mls.metrics: + name = m.name + if mls.output_name: + name = f"{mls.output_name}_{name}" + if name not in unique_name_counters: + results[name] = m.result() + unique_name_counters[name] = 1 + else: + name = f"weighted_{m.name}" + if mls.output_name: + name = f"{mls.output_name}_{name}" + if name not in unique_name_counters: + unique_name_counters[name] = 1 + else: + index = unique_name_counters[name] + unique_name_counters[name] += 1 + name = f"{name}_{index}" + results[name] = m.result() + return results + + def get_config(self): + raise NotImplementedError + + @classmethod + def from_config(cls, config): + raise NotImplementedError + + +class CompileLoss(losses_module.Loss): + Loss = namedtuple("Loss", ["path", "loss", "loss_weights", "name"]) + + def __init__( + self, + loss, + loss_weights=None, + reduction="sum_over_batch_size", + output_names=None, + ): + if loss_weights and not isinstance( + loss_weights, (list, tuple, dict, float) + ): + raise ValueError( + "Expected `loss_weights` argument to be a float " + "(single output case) or a list, tuple, or " + "dict (multiple output case). " + f"Received instead: loss_weights={loss_weights} " + f"of type {type(loss_weights)}" + ) + self._user_loss = loss + self._user_loss_weights = loss_weights + self.built = False + self.output_names = output_names + super().__init__(name="compile_loss", reduction=reduction) + + # Use `Tracker` to track metrics for individual losses. + self._metrics = [] + self._tracker = Tracker( + { + "metrics": ( + lambda x: isinstance(x, metrics_module.Metric), + self._metrics, + ) + } + ) + self._flat_losses = None + self._y_pred_build_structure = None + self._y_true_build_structure = None + + @property + def metrics(self): + return self._metrics + + @property + def variables(self): + vars = [] + for m in self.metrics: + vars.extend(m.variables) + return vars + + def _build_nested(self, y_true, y_pred, loss, output_names, current_path): + flat_y_pred = tree.flatten(y_pred) + if not tree.is_nested(loss): + _loss = loss.loss + if _loss is None: + return + loss_weight = loss.weight + resolved_loss = get_loss(_loss, y_true, y_pred) + name_path = current_path + if not tree.is_nested(output_names): + if output_names is not None: + output_name = output_names + else: + output_name = resolved_loss.name + if len(name_path) == 0: + name_path = (output_name,) + elif isinstance(name_path[-1], int): + name_path = name_path[:-1] + (output_name,) + name = "/".join([str(path) for path in name_path]) + if name == "": + if isinstance(output_names, dict): + flat_output_names = list(output_names.keys()) + else: + flat_output_names = tree.flatten(output_names) + name = "_".join(flat_output_names) + self._flat_losses.append( + CompileLoss.Loss(current_path, resolved_loss, loss_weight, name) + ) + return + elif ( + issubclass(type(loss), (list, tuple)) + and all([not tree.is_nested(_loss) for _loss in loss]) + and len(loss) == len(flat_y_pred) + ): + loss = tree.pack_sequence_as(y_pred, loss) + elif issubclass(type(loss), (list, tuple)) and not isinstance( + y_pred, type(loss) + ): + for _loss in loss: + self._build_nested( + y_true, + y_pred, + _loss, + output_names, + current_path, + ) + return + + if not tree.is_nested(loss): + return self._build_nested( + y_true, y_pred, loss, output_names, current_path + ) + + if not isinstance(loss, type(y_pred)): + raise KeyError( + f"The path: {current_path} in " + "the `loss` argument, can't be found in " + "the model's output (`y_pred`)." + ) + + # shallow traverse the loss config + if isinstance(loss, dict): + iterator = loss.items() + + def key_check_fn(key, objs): + return all( + [isinstance(obj, dict) and key in obj for obj in objs] + ) + + elif issubclass(type(loss), (list, tuple)): + iterator = enumerate(loss) + + def key_check_fn(key, objs): + return all( + [ + issubclass(type(obj), (list, tuple)) and key < len(obj) + for obj in objs + ] + ) + + else: + raise TypeError( + f"Unsupported type {type(loss)} in the `loss` configuration." + ) + + for key, _loss in iterator: + if _loss is None: + continue + if not key_check_fn(key, (y_true, y_pred)): + raise KeyError( + f"The path: {current_path + (key,)} in " + "the `loss` argument, can't be found in " + "either the model's output (`y_pred`) or in the " + "labels (`y_true`)." + ) + + self._build_nested( + y_true[key], + y_pred[key], + _loss, + output_names[key], + current_path + (key,), + ) + + def build(self, y_true, y_pred): + loss = self._user_loss + loss_weights = self._user_loss_weights + flat_output_names = self.output_names + if ( + self.output_names + and isinstance(self._user_loss, dict) + and not isinstance(y_pred, dict) + ): + if set(self.output_names) == set(self._user_loss.keys()): + loss = [self._user_loss[name] for name in self.output_names] + if isinstance(self._user_loss_weights, dict): + loss_weights = [ + self._user_loss_weights[name] + for name in self.output_names + ] + else: + raise ValueError( + f"Expected keys {self.output_names} in loss dict, but " + f"found loss.keys()={list(self._user_loss.keys())}" + ) + + # Pytree leaf container + class WeightedLoss: + def __new__(cls, loss, weight): + if loss is None: + return None + return object.__new__(cls) + + def __init__(self, loss, weight): + self.loss = loss + self.weight = weight + + # pack the losses and the weights together + if loss_weights is not None: + try: + tree.assert_same_structure(loss, loss_weights) + except ValueError: + flat_loss_weights = tree.flatten(loss_weights) + if len(tree.flatten(loss)) != len(flat_loss_weights): + raise ValueError( + f"`loss_weights` must match the number of losses, " + f"got {len(tree.flatten(loss))} losses " + f"and {len(loss_weights)} weights." + ) + loss_weights = tree.pack_sequence_as(loss, flat_loss_weights) + loss = tree.map_structure( + lambda _loss, _weight: WeightedLoss(_loss, _weight), + loss, + loss_weights, + ) + else: + loss = tree.map_structure( + lambda _loss: WeightedLoss(_loss, None), loss + ) + + self._flat_losses = [] + + if ( + isinstance(loss, dict) + and issubclass(type(y_pred), (list, tuple)) + and set(loss.keys()) == set(flat_output_names) + and len(y_pred) == len(flat_output_names) + ): + y_pred = {name: y_p for name, y_p in zip(flat_output_names, y_pred)} + y_true = {name: y_t for name, y_t in zip(flat_output_names, y_true)} + elif ( + isinstance(loss, dict) + and not tree.is_nested(y_pred) + and set(loss.keys()) == set(flat_output_names) + and len(flat_output_names) == 1 + ): + y_pred = { + name: y_p for name, y_p in zip(flat_output_names, [y_pred]) + } + y_true = { + name: y_t for name, y_t in zip(flat_output_names, [y_true]) + } + + try: + output_names = tree.pack_sequence_as(y_pred, flat_output_names) + except: + inferred_flat_output_names = self._get_y_pred_output_names(y_pred) + output_names = tree.pack_sequence_as( + y_pred, inferred_flat_output_names + ) + + if not tree.is_nested(loss): + loss = tree.map_structure(lambda x: loss, y_pred) + + self._build_nested(y_true, y_pred, loss, output_names, ()) + + # Add `Mean` metric to the tracker for each loss. + if len(self._flat_losses) > 1: + for _loss in self._flat_losses: + name = f"{_loss.name}_loss" + self._tracker.add_to_store( + "metrics", metrics_module.Mean(name=name) + ) + + self._y_pred_build_structure = tree.map_structure( + lambda x: None, y_pred + ) + self._y_true_build_structure = tree.map_structure( + lambda x: None, y_true + ) + self.built = True + + def _get_y_pred_output_names(self, y_pred): + flat_y_pred = tree.flatten(y_pred) + if all((isinstance(x, KerasTensor) for x in flat_y_pred)): + output_names = [] + for tensor in flat_y_pred: + if hasattr(tensor, "_keras_history"): + output_names.append(tensor._keras_history.operation.name) + else: + output_names.append(tensor.name) + else: + output_names = [None] * len(flat_y_pred) + return output_names + + def __call__(self, y_true, y_pred, sample_weight=None): + with ops.name_scope(self.name): + return self.call(y_true, y_pred, sample_weight) + + def call(self, y_true, y_pred, sample_weight=None): + def resolve_path(path, object): + for _path in path: + object = object[_path] + return object + + if not tree.is_nested(y_true) and not tree.is_nested(y_pred): + # Fast path: single output case / no loss-tracking metric. + if not self.built: + self.build(y_true, y_pred) + # Although we are in the fast path, we still need to iterate + # through the losses to prevent the torch compiler from failing. + loss_values = [] + for path, loss_fn, loss_weight, _ in self._flat_losses: + y_t, y_p = ( + resolve_path(path, y_true), + resolve_path(path, y_pred), + ) + if sample_weight is not None and tree.is_nested(sample_weight): + _sample_weight = resolve_path(path, sample_weight) + else: + _sample_weight = sample_weight + value = ops.cast( + loss_fn(y_t, y_p, _sample_weight), dtype=self.dtype + ) + if loss_weight is not None: + value = ops.multiply(value, loss_weight) + loss_values.append(value) + return loss_values[0] + + try: + tree.assert_same_structure(y_pred, y_true) + except ValueError: + # Check case where y_true is either flat or leaf + if ( + not tree.is_nested(y_true) + and hasattr(y_pred, "__len__") + and len(y_pred) == 1 + ): + y_true = [y_true] + + # Check case where y_pred is list/tuple and y_true is dict + elif isinstance(y_pred, (list, tuple)) and isinstance(y_true, dict): + if set(self.output_names) == set(y_true.keys()): + y_true = [y_true[name] for name in self.output_names] + + try: + y_true = tree.pack_sequence_as(y_pred, y_true) + except: + # Check case where y_true has the same structure but uses + # different (but reconcilable) container types, + # e.g `list` vs `tuple`. + try: + tree.assert_same_paths(y_true, y_pred) + y_true = tree.pack_sequence_as(y_pred, tree.flatten(y_true)) + except: + try: + # Check case where loss is partially defined over y_pred + flat_y_true = tree.flatten(y_true) + flat_loss = tree.flatten(self._user_loss) + flat_loss_non_nones = [ + (i, loss) + for i, loss in enumerate(flat_loss) + if loss is not None + ] + assert len(flat_y_true) == len(flat_loss_non_nones) + y_true = [None] * len(flat_loss) + for y_t, (i, loss) in zip( + flat_y_true, flat_loss_non_nones + ): + y_true[i] = y_t + y_true = tree.pack_sequence_as(self._user_loss, y_true) + except: + y_true_struct = tree.map_structure( + lambda _: "*", y_true + ) + y_pred_struct = tree.map_structure( + lambda _: "*", y_pred + ) + raise ValueError( + "y_true and y_pred have different structures.\n" + f"y_true: {y_true_struct}\n" + f"y_pred: {y_pred_struct}\n" + ) + + if not self.built: + self.build(y_true, y_pred) + + try: + tree.assert_same_structure(self._y_pred_build_structure, y_pred) + except ValueError: + y_pred = tree.pack_sequence_as( + self._y_pred_build_structure, tree.flatten(y_pred) + ) + try: + tree.assert_same_structure(self._y_true_build_structure, y_true) + except ValueError: + y_true = tree.pack_sequence_as( + self._y_true_build_structure, tree.flatten(y_true) + ) + + # We need to add a dummy `None` if the model has only a single output. + metrics = [None] if len(self.metrics) == 0 else self.metrics + + # Iterate all losses in flat form. + loss_values = [] + + for (path, loss_fn, loss_weight, _), metric in zip( + self._flat_losses, metrics + ): + y_t, y_p = resolve_path(path, y_true), resolve_path(path, y_pred) + if sample_weight is not None and tree.is_nested(sample_weight): + _sample_weight = resolve_path(path, sample_weight) + else: + _sample_weight = sample_weight + + value = ops.cast( + loss_fn(y_t, y_p, _sample_weight), dtype=self.dtype + ) + # Record *unweighted* individual losses. + if metric: + metric.update_state( + loss_module.unscale_loss_for_distribution(value), + sample_weight=tree.flatten(y_p)[0].shape[0], + ) + if loss_weight is not None: + value = ops.multiply(value, loss_weight) + loss_values.append(value) + + if loss_values: + total_loss = sum(loss_values) + return total_loss + return None + + def get_config(self): + raise NotImplementedError + + @classmethod + def from_config(cls, config): + raise NotImplementedError diff --git a/keras/src/trainers/compile_utils_test.py b/keras/src/trainers/compile_utils_test.py new file mode 100644 index 000000000000..d27c5292b63d --- /dev/null +++ b/keras/src/trainers/compile_utils_test.py @@ -0,0 +1,622 @@ +from collections import namedtuple + +import numpy as np +from absl.testing import parameterized + +from keras.src import backend +from keras.src import metrics as losses_module +from keras.src import metrics as metrics_module +from keras.src import ops +from keras.src import testing +from keras.src import tree +from keras.src.trainers.compile_utils import CompileLoss +from keras.src.trainers.compile_utils import CompileMetrics + + +class TestCompileMetrics(testing.TestCase): + def test_single_output_case(self): + compile_metrics = CompileMetrics( + metrics=[metrics_module.MeanSquaredError()], + weighted_metrics=[metrics_module.MeanSquaredError()], + ) + # Test symbolic build + y_true = backend.KerasTensor((3, 4)) + y_pred = backend.KerasTensor((3, 4)) + compile_metrics.build(y_true, y_pred) + # Test eager build + y_true = np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]) + y_pred = np.array([[0.4, 0.1], [0.2, 0.6], [0.6, 0.1]]) + sample_weight = np.array([1, 0.0, 1]) + compile_metrics.build(y_true, y_pred) + + # Test update / result / reset flow + compile_metrics.update_state( + y_true, y_pred, sample_weight=sample_weight + ) + y_pred = np.array([[0.3, 0.2], [0.1, 0.4], [0.2, 0.3]]) + compile_metrics.update_state( + y_true, y_pred, sample_weight=sample_weight + ) + result = compile_metrics.result() + self.assertIsInstance(result, dict) + self.assertEqual(len(result), 2) + self.assertAllClose(result["mean_squared_error"], 0.055833336) + self.assertAllClose(result["weighted_mean_squared_error"], 0.0725) + + compile_metrics.reset_state() + result = compile_metrics.result() + self.assertIsInstance(result, dict) + self.assertEqual(len(result), 2) + self.assertAllClose(result["mean_squared_error"], 0.0) + self.assertAllClose(result["weighted_mean_squared_error"], 0.0) + + def test_list_output_case(self): + compile_metrics = CompileMetrics( + metrics=[ + [ + metrics_module.MeanSquaredError(), + metrics_module.MeanSquaredError(), + ], + [ + metrics_module.MeanSquaredError(), + metrics_module.MeanSquaredError(), + ], + ], + weighted_metrics=[ + [ + metrics_module.MeanSquaredError(), + metrics_module.MeanSquaredError(), + ], + [ + metrics_module.MeanSquaredError(), + metrics_module.MeanSquaredError(), + ], + ], + ) + # Test symbolic build + y_true = [backend.KerasTensor((3, 4)), backend.KerasTensor((3, 4))] + y_pred = [backend.KerasTensor((3, 4)), backend.KerasTensor((3, 4))] + compile_metrics.build(y_true, y_pred) + self.assertEqual(len(compile_metrics.metrics), 8) + + # Test eager build + y_true = [ + np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]), + np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]), + ] + y_pred = [ + np.array([[0.4, 0.1], [0.2, 0.6], [0.6, 0.1]]), + np.array([[0.4, 0.1], [0.2, 0.6], [0.6, 0.1]]), + ] + sample_weight = np.array([1, 0.0, 1]) + compile_metrics.build(y_true, y_pred) + self.assertEqual(len(compile_metrics.metrics), 8) + + # Test update / result / reset flow + compile_metrics.update_state( + y_true, y_pred, sample_weight=sample_weight + ) + y_pred = [ + np.array([[0.3, 0.2], [0.1, 0.4], [0.2, 0.3]]), + np.array([[0.3, 0.2], [0.1, 0.4], [0.2, 0.3]]), + ] + compile_metrics.update_state( + y_true, y_pred, sample_weight=sample_weight + ) + result = compile_metrics.result() + self.assertIsInstance(result, dict) + self.assertEqual(len(result), 8) + self.assertAllClose(result["mean_squared_error"], 0.055833336) + self.assertAllClose(result["weighted_mean_squared_error"], 0.0725) + + compile_metrics.reset_state() + result = compile_metrics.result() + self.assertIsInstance(result, dict) + self.assertEqual(len(result), 8) + self.assertAllClose(result["mean_squared_error"], 0.0) + self.assertAllClose(result["weighted_mean_squared_error"], 0.0) + + def test_dict_output_case(self): + compile_metrics = CompileMetrics( + metrics={ + "output_1": [ + metrics_module.MeanSquaredError(), + metrics_module.MeanSquaredError(name="mse"), + ], + "output_2": [ + metrics_module.MeanSquaredError(), + metrics_module.MeanSquaredError(name="mse"), + ], + }, + weighted_metrics={ + "output_1": [ + metrics_module.MeanSquaredError(), + metrics_module.MeanSquaredError(name="mse"), + ], + "output_2": [ + metrics_module.MeanSquaredError(), + metrics_module.MeanSquaredError(name="mse"), + ], + }, + ) + # Test symbolic build + y_true = { + "output_1": backend.KerasTensor((3, 4)), + "output_2": backend.KerasTensor((3, 4)), + } + y_pred = { + "output_1": backend.KerasTensor((3, 4)), + "output_2": backend.KerasTensor((3, 4)), + } + compile_metrics.build(y_true, y_pred) + # Test eager build + y_true = { + "output_1": np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]), + "output_2": np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]), + } + y_pred = { + "output_1": np.array([[0.4, 0.1], [0.2, 0.6], [0.6, 0.1]]), + "output_2": np.array([[0.4, 0.1], [0.2, 0.6], [0.6, 0.1]]), + } + sample_weight = np.array([1, 0.0, 1]) + compile_metrics.build(y_true, y_pred) + + # Test update / result / reset flow + compile_metrics.update_state( + y_true, y_pred, sample_weight=sample_weight + ) + y_pred = { + "output_1": np.array([[0.3, 0.2], [0.1, 0.4], [0.2, 0.3]]), + "output_2": np.array([[0.3, 0.2], [0.1, 0.4], [0.2, 0.3]]), + } + compile_metrics.update_state( + y_true, y_pred, sample_weight=sample_weight + ) + result = compile_metrics.result() + self.assertIsInstance(result, dict) + self.assertEqual(len(result), 8) + # Result values obtained from `tf.keras` + # m = tf.keras.metrics.MeanSquaredError() + # m.update_state(y_true, y_pred1, sample_weight=weight) + # m.update_state(y_true, y_pred2, sample_weight=weight) + # m.result().numpy() + self.assertAllClose(result["output_1_mean_squared_error"], 0.055833336) + self.assertAllClose(result["output_2_mean_squared_error"], 0.055833336) + self.assertAllClose(result["output_1_mse"], 0.055833336) + self.assertAllClose(result["output_2_mse"], 0.055833336) + self.assertAllClose( + result["output_1_weighted_mean_squared_error"], 0.0725 + ) + self.assertAllClose( + result["output_2_weighted_mean_squared_error"], 0.0725 + ) + self.assertAllClose(result["output_1_weighted_mse"], 0.0725) + self.assertAllClose(result["output_2_weighted_mse"], 0.0725) + + compile_metrics.reset_state() + result = compile_metrics.result() + self.assertIsInstance(result, dict) + self.assertEqual(len(result), 8) + self.assertAllClose(result["output_1_mean_squared_error"], 0.0) + self.assertAllClose(result["output_2_mean_squared_error"], 0.0) + self.assertAllClose(result["output_1_weighted_mean_squared_error"], 0.0) + self.assertAllClose(result["output_2_weighted_mean_squared_error"], 0.0) + + def test_name_conversions(self): + compile_metrics = CompileMetrics( + metrics=["acc", "accuracy", "mse"], + weighted_metrics=[], + ) + y_true = np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]) + y_pred = np.array([[0.4, 0.1], [0.2, 0.6], [0.6, 0.1]]) + compile_metrics.build(y_true, y_pred) + compile_metrics.update_state(y_true, y_pred, sample_weight=None) + result = compile_metrics.result() + self.assertIsInstance(result, dict) + self.assertEqual(len(result), 3) + self.assertAllClose(result["acc"], 0.333333) + self.assertAllClose(result["accuracy"], 0.333333) + self.assertTrue("mse" in result) + + def test_custom_metric_function(self): + def my_custom_metric(y_true, y_pred): + return ops.mean(ops.square(y_true - y_pred), axis=-1) + + compile_metrics = CompileMetrics( + metrics=[my_custom_metric], + weighted_metrics=[], + ) + y_true = np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]) + y_pred = np.array([[0.4, 0.1], [0.2, 0.6], [0.6, 0.1]]) + compile_metrics.build(y_true, y_pred) + compile_metrics.update_state(y_true, y_pred, sample_weight=None) + result = compile_metrics.result() + self.assertIsInstance(result, dict) + self.assertEqual(len(result), 1) + self.assertTrue("my_custom_metric" in result) + + def test_dict_outputs_ignore_mismatched_output_names(self): + """Tests that when output_names does not match dict keys, the correct + keys are used.""" + + # output_names represent internal op names that do not match dict keys. + compile_metrics = CompileMetrics( + metrics={ + "a": metrics_module.MeanSquaredError(), + "b": metrics_module.MeanSquaredError(), + }, + weighted_metrics=None, + output_names=["dense", "dense_1"], + ) + + # Symbolic build with dict outputs keyed by user-facing names. + y_true = { + "a": backend.KerasTensor((3, 2)), + "b": backend.KerasTensor((3, 2)), + } + y_pred = { + "a": backend.KerasTensor((3, 2)), + "b": backend.KerasTensor((3, 2)), + } + + # The build method should correctly map metrics for outputs 'a' and 'b', + # even when the op names do not match. + compile_metrics.build(y_true, y_pred) + + # Make the two outputs produce different MSEs to verify mapping. + y_true = { + "a": np.zeros((3, 2), dtype="float32"), + "b": np.zeros((3, 2), dtype="float32"), + } + y_pred = { + # MSE(a) = 0.0 + "a": np.zeros((3, 2), dtype="float32"), + # MSE(b) = 1.0 + "b": np.ones((3, 2), dtype="float32"), + } + compile_metrics.update_state(y_true, y_pred) + + result = compile_metrics.result() + self.assertIsInstance(result, dict) + + # Should expose metrics under the dict keys ('a', 'b'), + # and not the internal names. + self.assertIn("a_mean_squared_error", result) + self.assertIn("b_mean_squared_error", result) + self.assertAllClose(result["a_mean_squared_error"], 0.0) + self.assertAllClose(result["b_mean_squared_error"], 1.0, atol=1e-6) + + +class TestCompileLoss(testing.TestCase): + def test_single_output_case(self): + compile_loss = CompileLoss( + loss=losses_module.MeanSquaredError(), + ) + # Test symbolic build + y_true = backend.KerasTensor((3, 4)) + y_pred = backend.KerasTensor((3, 4)) + compile_loss.build(y_true, y_pred) + # Test eager build + y_true = np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]) + y_pred = np.array([[0.4, 0.1], [0.2, 0.6], [0.6, 0.1]]) + compile_loss.build(y_true, y_pred) + value = compile_loss(y_true, y_pred) + self.assertAllClose(value, 0.068333, atol=1e-5) + + def test_single_output_case_with_crossentropy_loss(self): + compile_loss = CompileLoss(loss="crossentropy") + + # Test symbolic build + y_true = backend.KerasTensor((3, 4)) + y_pred = backend.KerasTensor((3, 4)) + compile_loss.build(y_true, y_pred) + # Test eager build + y_true = np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]) + y_pred = np.array([[0.4, 0.1], [0.2, 0.6], [0.6, 0.1]]) + compile_loss.build(y_true, y_pred) + value = compile_loss(y_true, y_pred) + self.assertAllClose(value, 0.706595, atol=1e-5) + + @parameterized.parameters(True, False) + def test_list_output_case(self, broadcast): + if broadcast: + # Test broadcasting single loss to all outputs + compile_loss = CompileLoss( + loss="mse", + ) + else: + compile_loss = CompileLoss( + loss=["mse", "mse"], + ) + # Test symbolic build + y_true = [backend.KerasTensor((3, 4)), backend.KerasTensor((3, 4))] + y_pred = [backend.KerasTensor((3, 4)), backend.KerasTensor((3, 4))] + compile_loss.build(y_true, y_pred) + # Test eager build + y_true = [ + np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]), + np.array([[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]]), + ] + y_pred = [ + np.array([[1.2, 1.1], [1.0, 0.9], [0.8, 0.7]]), + np.array([[0.6, 0.5], [0.4, 0.3], [0.2, 0.1]]), + ] + compile_loss.build(y_true, y_pred) + value = compile_loss(y_true, y_pred) + self.assertAllClose(value, 0.953333, atol=1e-5) + + @parameterized.parameters(True, False) + def test_dict_output_case(self, broadcast): + if broadcast: + # Test broadcasting single loss to all outputs + compile_loss = CompileLoss( + loss="mse", + ) + else: + compile_loss = CompileLoss( + loss={"a": "mse", "b": "mse"}, + ) + # Test symbolic build + y_true = { + "a": backend.KerasTensor((3, 4)), + "b": backend.KerasTensor((3, 4)), + } + y_pred = { + "a": backend.KerasTensor((3, 4)), + "b": backend.KerasTensor((3, 4)), + } + compile_loss.build(y_true, y_pred) + # Test eager build + y_true = { + "a": np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]), + "b": np.array([[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]]), + } + y_pred = { + "a": np.array([[1.2, 1.1], [1.0, 0.9], [0.8, 0.7]]), + "b": np.array([[0.6, 0.5], [0.4, 0.3], [0.2, 0.1]]), + } + sample_weight = { + "a": np.array([1.0, 2.0, 3.0]), + "b": np.array([3.0, 2.0, 1.0]), + } + compile_loss.build(y_true, y_pred) + value = compile_loss(y_true, y_pred, sample_weight) + self.assertAllClose(value, 1.266666, atol=1e-5) + + def test_list_loss_dict_data(self): + compile_loss = CompileLoss(loss=["mse", "mae"], output_names=["b", "a"]) + y_true = [backend.KerasTensor((3, 4)), backend.KerasTensor((3, 4))] + y_pred = [backend.KerasTensor((3, 4)), backend.KerasTensor((3, 4))] + compile_loss.build(y_true, y_pred) + y_true = { + "a": np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]), + "b": np.array([[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]]), + } + y_pred = { + "a": np.array([[1.2, 1.1], [1.0, 0.9], [0.8, 0.7]]), + "b": np.array([[0.6, 0.5], [0.4, 0.3], [0.2, 0.1]]), + } + value = compile_loss(y_true, y_pred) + self.assertAllClose(value, 1.07666, atol=1e-5) + + def test_struct_loss(self): + y_true = { + "a": { + "c": np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]), + "d": np.array([[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]]), + }, + "b": np.array([[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]]), + } + y_pred = { + "a": { + "c": np.array([[1.2, 1.1], [1.0, 0.9], [0.8, 0.7]]), + "d": np.array([[0.6, 0.5], [0.4, 0.3], [0.2, 0.1]]), + }, + "b": np.array([[0.6, 0.5], [0.4, 0.3], [0.2, 0.1]]), + } + loss = {"a": {"c": "mse", "d": "mae"}} + compile_loss = CompileLoss(loss=loss, output_names=["c", "d", "b"]) + y_true_symb = tree.map_structure( + lambda _: backend.KerasTensor((3, 4)), y_true + ) + y_pred_symb = tree.map_structure( + lambda _: backend.KerasTensor((3, 4)), y_pred + ) + compile_loss.build(y_true_symb, y_pred_symb) + value = compile_loss(y_true, y_pred) + self.assertAllClose(value, 1.07666, atol=1e-5) + + def test_struct_loss_valid_weights(self): + y_true = { + "a": np.array([1, 2]), + "b": np.array([1, 2]), + } + y_pred = { + "a": np.array([3, 4]), + "b": np.array([3, 4]), + } + loss = {"a": "mse", "b": "mse"} + compile_loss = CompileLoss( + loss=loss, + output_names=["a", "b"], + loss_weights={ + "a": np.ones((2,)), + "b": np.zeros((2,)), + }, + ) + compile_loss.build(y_true, y_pred) + value = compile_loss(y_true, y_pred) + self.assertAllClose(value, 4) + + # Metrics still report unweighted loss. + a_loss_mean, b_loss_mean = compile_loss.metrics + self.assertEqual(a_loss_mean.result(), 4) + self.assertEqual(b_loss_mean.result(), 4) + + def test_struct_loss_invalid_weights(self): + y_true = { + "a": { + "c": np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]), + "d": np.array([[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]]), + }, + "b": np.array([[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]]), + } + y_pred = { + "a": { + "c": np.array([[1.2, 1.1], [1.0, 0.9], [0.8, 0.7]]), + "d": np.array([[0.6, 0.5], [0.4, 0.3], [0.2, 0.1]]), + }, + "b": np.array([[0.6, 0.5], [0.4, 0.3], [0.2, 0.1]]), + } + loss = {"a": {"c": "mse", "d": "mae"}} + compile_loss = CompileLoss( + loss=loss, output_names=["c", "d", "b"], loss_weights=[1] + ) + y_true_symb = tree.map_structure( + lambda _: backend.KerasTensor((3, 4)), y_true + ) + y_pred_symb = tree.map_structure( + lambda _: backend.KerasTensor((3, 4)), y_pred + ) + with self.assertRaisesRegex( + ValueError, "must match the number of losses" + ): + compile_loss.build(y_true_symb, y_pred_symb) + + def test_struct_loss_indice_path(self): + y_true = { + "a": ( + np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]), + np.array([[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]]), + ), + "b": np.array([[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]]), + } + y_pred = { + "a": ( + np.array([[1.2, 1.1], [1.0, 0.9], [0.8, 0.7]]), + np.array([[0.6, 0.5], [0.4, 0.3], [0.2, 0.1]]), + ), + "b": np.array([[0.6, 0.5], [0.4, 0.3], [0.2, 0.1]]), + } + loss = {"a": ["mse", "mae"]} + compile_loss = CompileLoss(loss=loss, output_names=["c", "d", "b"]) + y_true_symb = tree.map_structure( + lambda _: backend.KerasTensor((3, 4)), y_true + ) + y_pred_symb = tree.map_structure( + lambda _: backend.KerasTensor((3, 4)), y_pred + ) + compile_loss.build(y_true_symb, y_pred_symb) + value = compile_loss(y_true, y_pred) + self.assertAllClose(value, 1.07666, atol=1e-5) + + def test_struct_loss_namedtuple(self): + Point = namedtuple("Point", ["x", "y"]) + y_true = { + "a": Point( + np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]), + np.array([[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]]), + ), + "b": np.array([[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]]), + } + y_pred = { + "a": Point( + np.array([[1.2, 1.1], [1.0, 0.9], [0.8, 0.7]]), + np.array([[0.6, 0.5], [0.4, 0.3], [0.2, 0.1]]), + ), + "b": np.array([[0.6, 0.5], [0.4, 0.3], [0.2, 0.1]]), + } + loss = {"a": Point("mse", "mae")} + compile_loss = CompileLoss(loss=loss, output_names=["c", "d", "b"]) + y_true_symb = tree.map_structure( + lambda _: backend.KerasTensor((3, 4)), y_true + ) + y_pred_symb = tree.map_structure( + lambda _: backend.KerasTensor((3, 4)), y_pred + ) + compile_loss.build(y_true_symb, y_pred_symb) + value = compile_loss(y_true, y_pred) + self.assertAllClose(value, 1.07666, atol=1e-5) + + def test_struct_loss_invalid_path(self): + y_true = { + "a": { + "c": np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]), + "d": np.array([[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]]), + }, + "b": np.array([[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]]), + } + y_pred = { + "a": { + "c": np.array([[1.2, 1.1], [1.0, 0.9], [0.8, 0.7]]), + "d": np.array([[0.6, 0.5], [0.4, 0.3], [0.2, 0.1]]), + }, + "b": np.array([[0.6, 0.5], [0.4, 0.3], [0.2, 0.1]]), + } + loss = {"a": {"c": "mse"}, "b": {"d": "mae"}} + compile_loss = CompileLoss(loss=loss, output_names=["c", "d", "b"]) + y_true_symb = tree.map_structure( + lambda _: backend.KerasTensor((3, 4)), y_true + ) + y_pred_symb = tree.map_structure( + lambda _: backend.KerasTensor((3, 4)), y_pred + ) + with self.assertRaisesRegex( + KeyError, "can't be found in the model's output" + ): + compile_loss.build(y_true_symb, y_pred_symb) + + def test_different_container_types(self): + y1, y2, y3 = np.array([[1]]), np.array([[2]]), np.array([[3]]) + y_true = ([{"a": y1}, {"b": ([y2], y3)}],) + y_pred = [({"a": y1}, {"b": [(y2,), y3]})] + loss = "mse" + compile_loss = CompileLoss(loss=loss, output_names=["a", "b", "c"]) + y_true_symb = tree.map_structure( + lambda _: backend.KerasTensor((1, 1)), y_true + ) + y_pred_symb = tree.map_structure( + lambda _: backend.KerasTensor((1, 1)), y_pred + ) + compile_loss.build(y_true_symb, y_pred_symb) + compile_loss(y_true, y_pred) + + def test_structure_mismatch(self): + y_true = [np.array([[1]]), np.array([[1]])] + y_pred = [np.array([[1]]), np.array([[1]])] + loss = ["mse", "mse"] + compile_loss = CompileLoss(loss=loss, output_names=["a", "b"]) + y_true_symb = tree.map_structure( + lambda _: backend.KerasTensor((1, 1)), y_true + ) + y_pred_symb = tree.map_structure( + lambda _: backend.KerasTensor((1, 1)), y_pred + ) + compile_loss.build(y_true_symb, y_pred_symb) + with self.assertRaisesRegex( + ValueError, "y_true and y_pred have different structures." + ): + wrong_struc_y_true = [np.array([[1]])] + compile_loss(wrong_struc_y_true, y_pred) + + @parameterized.parameters( + ["mse", None, None], + [None, "mse", None], + [None, None, "mse"], + [None, "mse", "mse"], + ["mse", None, "mse"], + ) + def test_y_true_partial_y_pred_span(self, *loss_conf): + loss_conf = list(loss_conf) + ones = np.ones((320, 3)) + zeros = np.zeros((320, 3)) + twos = np.ones((320, 3)) * 2 + y_pred = [zeros, ones, twos] + y_true = [y for y, loss in zip(y_pred, loss_conf) if loss is not None] + y_true = y_true[0] if len(y_true) == 1 else y_true + compile_loss = CompileLoss(loss=loss_conf, output_names=["a", "b", "c"]) + # build call + compile_loss(y_true, y_pred) + # built call + loss = compile_loss(y_true, y_pred) + self.assertEqual(loss, 0.0) diff --git a/keras/src/trainers/data_adapters/__init__.py b/keras/src/trainers/data_adapters/__init__.py new file mode 100644 index 000000000000..f0932d36730e --- /dev/null +++ b/keras/src/trainers/data_adapters/__init__.py @@ -0,0 +1,205 @@ +import types + +from keras.src.distribution import distribution_lib +from keras.src.trainers.data_adapters import array_data_adapter +from keras.src.trainers.data_adapters import data_adapter +from keras.src.trainers.data_adapters import py_dataset_adapter +from keras.src.trainers.data_adapters.array_data_adapter import ArrayDataAdapter +from keras.src.trainers.data_adapters.generator_data_adapter import ( + GeneratorDataAdapter, +) +from keras.src.trainers.data_adapters.grain_dataset_adapter import ( + GrainDatasetAdapter, +) +from keras.src.trainers.data_adapters.py_dataset_adapter import PyDatasetAdapter +from keras.src.trainers.data_adapters.tf_dataset_adapter import TFDatasetAdapter +from keras.src.trainers.data_adapters.torch_data_loader_adapter import ( + TorchDataLoaderAdapter, +) + + +def get_data_adapter( + x, + y=None, + sample_weight=None, + batch_size=None, + steps_per_epoch=None, + shuffle=False, + class_weight=None, +): + # Allow passing a custom data adapter. + if isinstance(x, data_adapter.DataAdapter): + return x + + # Check for multi-process/worker distribution. + distribution = distribution_lib.distribution() + if ( + distribution is not None + and getattr(distribution, "_is_multi_process", False) + and getattr(distribution, "auto_shard_dataset", False) + and not is_tf_dataset(x) + ): + raise ValueError( + "When using a multi-worker distribution with auto-sharding enabled, " + "the data must be provided as a `tf.data.Dataset` instance. " + f"Received: type(x)={type(x)}. " + "If the dataset is already sharded across workers, then set " + "`distribution.auto_shard_dataset = False`." + ) + + if array_data_adapter.can_convert_arrays((x, y, sample_weight)): + return ArrayDataAdapter( + x, + y, + sample_weight=sample_weight, + class_weight=class_weight, + shuffle=shuffle, + batch_size=batch_size, + steps=steps_per_epoch, + ) + elif is_tf_dataset(x): + # Unsupported args: y, sample_weight, shuffle + if y is not None: + raise_unsupported_arg("y", "the targets", "tf.data.Dataset") + if sample_weight is not None: + raise_unsupported_arg( + "sample_weights", "the sample weights", "tf.data.Dataset" + ) + return TFDatasetAdapter( + x, class_weight=class_weight, distribution=distribution + ) + # TODO: should we warn or not? + # warnings.warn( + # "`shuffle=True` was passed, but will be ignored since the " + # "data `x` was provided as a tf.data.Dataset. The Dataset is " + # "expected to already be shuffled " + # "(via `.shuffle(tf.data.AUTOTUNE)`)" + # ) + elif isinstance(x, py_dataset_adapter.PyDataset): + if y is not None: + raise_unsupported_arg("y", "the targets", "PyDataset") + if sample_weight is not None: + raise_unsupported_arg( + "sample_weights", "the sample weights", "PyDataset" + ) + return PyDatasetAdapter(x, class_weight=class_weight, shuffle=shuffle) + # TODO: should we warn or not? + # if x.num_batches is None and shuffle: + # warnings.warn( + # "`shuffle=True` was passed, but will be ignored since the " + # "data `x` was provided as a infinite PyDataset. The " + # "PyDataset is expected to already be shuffled." + # ) + elif is_torch_dataloader(x): + if y is not None: + raise_unsupported_arg("y", "the targets", "torch DataLoader") + if sample_weight is not None: + raise_unsupported_arg( + "sample_weights", "the sample weights", "torch DataLoader" + ) + if class_weight is not None: + raise ValueError( + "Argument `class_weight` is not supported for torch " + f"DataLoader inputs. You can modify your `__getitem__ ` method" + " to return input tensor, label and class_weight. " + "Alternatively, use a custom training loop. See the User Guide " + "https://keras.io/guides/custom_train_step_in_torch/" + "#supporting-sampleweight-amp-classweight for more details. " + f"Received: class_weight={class_weight}" + ) + return TorchDataLoaderAdapter(x) + # TODO: should we warn or not? + # warnings.warn( + # "`shuffle=True` was passed, but will be ignored since the " + # "data `x` was provided as a torch DataLoader. The DataLoader " + # "is expected to already be shuffled." + # ) + elif is_grain_dataset(x): + if y is not None: + raise_unsupported_arg( + "y", "the targets", "grain.Dataset and grain.DataLoader" + ) + if sample_weight is not None: + raise_unsupported_arg( + "sample_weights", + "the sample weights", + "grain.Dataset and grain.DataLoader", + ) + if class_weight is not None: + raise ValueError( + "Argument `class_weight` is not supported for grain.Dataset " + f"and grain.DataLoader inputs. You can modify your " + "`__getitem__ ` method to return input tensor, label and " + "class_weight. " + f"Received: class_weight={class_weight}" + ) + return GrainDatasetAdapter(x) + # TODO: should we warn or not? + # warnings.warn( + # "`shuffle=True` was passed, but will be ignored since the " + # "data `x` was provided as a grain dataset. The grain dataset " + # "is expected to already be shuffled." + # ) + elif isinstance(x, types.GeneratorType): + if y is not None: + raise_unsupported_arg("y", "the targets", "PyDataset") + if sample_weight is not None: + raise_unsupported_arg( + "sample_weights", "the sample weights", "PyDataset" + ) + if class_weight is not None: + raise ValueError( + "Argument `class_weight` is not supported for Python " + f"generator inputs. Received: class_weight={class_weight}" + ) + return GeneratorDataAdapter(x) + # TODO: should we warn or not? + # warnings.warn( + # "`shuffle=True` was passed, but will be ignored since the " + # "data `x` was provided as a generator. The generator " + # "is expected to yield already-shuffled data." + # ) + else: + raise ValueError(f"Unrecognized data type: x={x} (of type {type(x)})") + + +def raise_unsupported_arg(arg_name, arg_description, input_type): + raise ValueError( + f"When providing `x` as a {input_type}, `{arg_name}` " + f"should not be passed. Instead, {arg_description} should " + f"be included as part of the {input_type}." + ) + + +def is_tf_dataset(x): + if hasattr(x, "__class__"): + for parent in x.__class__.__mro__: + if parent.__name__ in ( + "DatasetV2", + "DistributedDataset", + "DistributedDatasetsFromFunction", + ) and "tensorflow.python." in str(parent.__module__): + return True + return False + + +def is_torch_dataloader(x): + if hasattr(x, "__class__"): + for parent in x.__class__.__mro__: + if parent.__name__ == "DataLoader" and "torch.utils.data" in str( + parent.__module__ + ): + return True + return False + + +def is_grain_dataset(x): + if hasattr(x, "__class__"): + for parent in x.__class__.__mro__: + if parent.__name__ in ( + "MapDataset", + "IterDataset", + "DataLoader", + ) and "grain" in str(parent.__module__): + return True + return False diff --git a/keras/src/trainers/data_adapters/array_data_adapter.py b/keras/src/trainers/data_adapters/array_data_adapter.py new file mode 100644 index 000000000000..87db9aac7032 --- /dev/null +++ b/keras/src/trainers/data_adapters/array_data_adapter.py @@ -0,0 +1,378 @@ +import functools +import math + +import numpy as np + +from keras.src import tree +from keras.src.trainers.data_adapters import array_slicing +from keras.src.trainers.data_adapters import data_adapter_utils +from keras.src.trainers.data_adapters.data_adapter import DataAdapter + + +class ArrayDataAdapter(DataAdapter): + """Adapter for array-like objects, e.g. TF/JAX Tensors, NumPy arrays.""" + + def __init__( + self, + x, + y=None, + sample_weight=None, + batch_size=None, + steps=None, + shuffle=False, + class_weight=None, + ): + if not can_convert_arrays((x, y, sample_weight)): + raise ValueError( + "Expected all elements of `x` to be array-like. " + f"Received invalid types: x={x}" + ) + + if sample_weight is not None: + if class_weight is not None: + raise ValueError( + "You cannot `class_weight` and `sample_weight` " + "at the same time." + ) + if tree.is_nested(y): + if isinstance(sample_weight, (list, tuple, dict)): + try: + tree.assert_same_structure(y, sample_weight) + except ValueError: + raise ValueError( + "You should provide one `sample_weight` array per " + "output in `y`. The two structures did not match:\n" + f"- y: {y}\n" + f"- sample_weight: {sample_weight}\n" + ) + else: + is_samplewise = len(sample_weight.shape) == 1 or ( + len(sample_weight.shape) == 2 + and sample_weight.shape[1] == 1 + ) + if not is_samplewise: + raise ValueError( + "For a model with multiple outputs, when providing " + "a single `sample_weight` array, it should only " + "have one scalar score per sample " + "(i.e. shape `(num_samples,)`). If you want to use " + "non-scalar sample weights, pass a `sample_weight` " + "argument with one array per model output." + ) + # Replicate the same sample_weight array on all outputs. + sample_weight = tree.map_structure( + lambda _: sample_weight, y + ) + if class_weight is not None: + if tree.is_nested(y): + raise ValueError( + "`class_weight` is only supported for Models with a single " + "output." + ) + sample_weight = data_adapter_utils.class_weight_to_sample_weights( + y, class_weight + ) + + inputs = data_adapter_utils.pack_x_y_sample_weight(x, y, sample_weight) + + data_adapter_utils.check_data_cardinality(inputs) + num_samples = set( + i.shape[0] for i in tree.flatten(inputs) if i is not None + ).pop() + self._num_samples = num_samples + self._inputs = inputs + + # If batch_size is not passed but steps is, calculate from the input + # data. Defaults to `32` for backwards compatibility. + if not batch_size: + batch_size = int(math.ceil(num_samples / steps)) if steps else 32 + + self._size = int(math.ceil(num_samples / batch_size)) + self._batch_size = batch_size + self._partial_batch_size = num_samples % batch_size + self._shuffle = shuffle + + def get_numpy_iterator(self): + inputs = array_slicing.convert_to_sliceable( + self._inputs, target_backend="numpy" + ) + + def slice_and_convert_to_numpy(sliceable, indices=None): + x = sliceable[indices] + x = sliceable.convert_to_numpy(x) + return x + + return self._get_iterator(slice_and_convert_to_numpy, inputs) + + def get_tf_dataset(self): + from keras.src.utils.module_utils import tensorflow as tf + + shuffle = self._shuffle + batch_size = self._batch_size + num_samples = self._num_samples + num_full_batches = int(self._num_samples // batch_size) + + # Vectorized version of shuffle. + # This is a performance improvement over using `from_tensor_slices`. + # The indices of the data are shuffled and batched, and these indices + # are then zipped with the data and used to extract a batch of the data + # at each step. The performance improvements here come from: + # 1. vectorized batch using gather + # 2. parallelized map + # 3. pipelined permutation generation + # 4. optimized permutation batching + # 5. disabled static optimizations + + indices_dataset = tf.data.Dataset.range(1) + + def permutation(_): + # It turns out to be more performant to make a new set of indices + # rather than reusing the same range Tensor. (presumably because of + # buffer forwarding.) + indices = tf.range(num_samples, dtype=tf.int64) + if shuffle and shuffle != "batch": + indices = tf.random.shuffle(indices) + return indices + + # We prefetch a single element. Computing large permutations can take + # quite a while so we don't want to wait for prefetching over an epoch + # boundary to trigger the next permutation. On the other hand, too many + # simultaneous shuffles can contend on a hardware level and degrade all + # performance. + indices_dataset = indices_dataset.map(permutation).prefetch(1) + + def slice_batch_indices(indices): + """Convert a Tensor of indices into a dataset of batched indices. + + This step can be accomplished in several ways. The most natural is + to slice the Tensor in a Dataset map. (With a condition on the upper + index to handle the partial batch.) However it turns out that + coercing the Tensor into a shape which is divisible by the batch + size (and handling the last partial batch separately) allows for a + much more favorable memory access pattern and improved performance. + + Args: + indices: Tensor which determines the data order for an entire + epoch. + + Returns: + A Dataset of batched indices. + """ + num_in_full_batch = num_full_batches * batch_size + first_k_indices = tf.slice(indices, [0], [num_in_full_batch]) + first_k_indices = tf.reshape( + first_k_indices, [num_full_batches, batch_size] + ) + + flat_dataset = tf.data.Dataset.from_tensor_slices(first_k_indices) + if self._partial_batch_size: + index_remainder = tf.data.Dataset.from_tensors( + tf.slice( + indices, [num_in_full_batch], [self._partial_batch_size] + ) + ) + flat_dataset = flat_dataset.concatenate(index_remainder) + + return flat_dataset + + def slice_inputs(indices_dataset, inputs): + """Slice inputs into a Dataset of batches. + + Given a Dataset of batch indices and the unsliced inputs, + this step slices the inputs in a parallelized fashion + and produces a dataset of input batches. + + Args: + indices_dataset: A Dataset of batched indices. + inputs: A python data structure that contains the inputs, + targets, and possibly sample weights. + + Returns: + A Dataset of input batches matching the batch indices. + """ + inputs = array_slicing.convert_to_sliceable( + self._inputs, target_backend="tensorflow" + ) + inputs = tree.lists_to_tuples(inputs) + + dataset = tf.data.Dataset.zip( + (indices_dataset, tf.data.Dataset.from_tensors(inputs).repeat()) + ) + + def grab_batch(i, data): + def grab_one(x): + if isinstance(x, array_slicing.TensorflowSparseWrapper): + return array_slicing.slice_tensorflow_sparse_wrapper( + x, i + ) + if isinstance(x, (list, tuple, dict)): + return None + if tf.is_tensor(x): + return tf.gather(x, i, axis=0) + return x + + return tree.traverse(grab_one, data) + + dataset = dataset.map( + grab_batch, num_parallel_calls=tf.data.AUTOTUNE + ) + + # Default optimizations are disabled to avoid the overhead of + # (unnecessary) input pipeline graph serialization & deserialization + options = tf.data.Options() + options.experimental_optimization.apply_default_optimizations = ( + False + ) + if self._shuffle: + options.experimental_external_state_policy = ( + tf.data.experimental.ExternalStatePolicy.IGNORE + ) + dataset = dataset.with_options(options) + return dataset + + indices_dataset = indices_dataset.flat_map(slice_batch_indices) + if shuffle == "batch": + indices_dataset = indices_dataset.map(tf.random.shuffle) + + dataset = slice_inputs(indices_dataset, self._inputs) + + options = tf.data.Options() + options.experimental_distribute.auto_shard_policy = ( + tf.data.experimental.AutoShardPolicy.DATA + ) + dataset = dataset.with_options(options) + return dataset.prefetch(tf.data.AUTOTUNE) + + def get_jax_iterator(self): + inputs = array_slicing.convert_to_sliceable( + self._inputs, target_backend="jax" + ) + + def slice_and_convert_to_jax(sliceable, indices=None): + x = sliceable[indices] + x = sliceable.convert_to_jax_compatible(x) + return x + + return self._get_iterator(slice_and_convert_to_jax, inputs) + + def get_torch_dataloader(self): + import torch + + from keras.src.backend.torch.core import convert_to_tensor + + class ArrayDataset(torch.utils.data.Dataset): + def __init__(self, array): + self.array = array + + def __getitems__(self, indices): + def slice_and_convert(sliceable): + x = sliceable[indices] + x = sliceable.convert_to_torch_compatible(x) + x = convert_to_tensor(x) + return x + + return tree.map_structure( + slice_and_convert, self.array, none_is_leaf=False + ) + + def __len__(self): + return len(self.array[0]) + + class RandomBatchSampler(torch.utils.data.Sampler): + def __init__(self, sampler): + self.sampler = sampler + + def __iter__(self): + for batch in self.sampler: + yield [batch[i] for i in torch.randperm(len(batch))] + + def __len__(self): + return len(self.sampler) + + if self._shuffle == "batch": + batch_sampler = RandomBatchSampler( + torch.utils.data.BatchSampler( + range(self._num_samples), + batch_size=self._batch_size, + drop_last=False, + ) + ) + elif self._shuffle: + batch_sampler = torch.utils.data.BatchSampler( + torch.utils.data.RandomSampler(range(self._num_samples)), + batch_size=self._batch_size, + drop_last=False, + ) + else: + batch_sampler = torch.utils.data.BatchSampler( + torch.utils.data.SequentialSampler(range(self._num_samples)), + batch_size=self._batch_size, + drop_last=False, + ) + + # Because ArrayDataset.__getitems__ returns full batches organized in + # the expected structure, there is nothing to collate. + def no_op_collate(batch): + return batch + + inputs = array_slicing.convert_to_sliceable( + self._inputs, target_backend="torch" + ) + dataset = ArrayDataset(inputs) + return torch.utils.data.DataLoader( + dataset, batch_sampler=batch_sampler, collate_fn=no_op_collate + ) + + def _get_iterator(self, slice_and_convert_fn, inputs): + global_permutation = None + if self._shuffle and self._shuffle != "batch": + global_permutation = np.random.permutation(self._num_samples) + + for i in range(self._size): + start = i * self._batch_size + stop = min((i + 1) * self._batch_size, self._num_samples) + if self._shuffle == "batch": + indices = np.random.permutation(stop - start) + start + elif self._shuffle: + indices = global_permutation[start:stop] + else: + indices = slice(start, stop) + + slice_indices_and_convert_fn = functools.partial( + slice_and_convert_fn, indices=indices + ) + yield tree.map_structure( + slice_indices_and_convert_fn, inputs, none_is_leaf=False + ) + + @property + def num_batches(self): + return self._size + + @property + def batch_size(self): + return self._batch_size + + @property + def has_partial_batch(self): + return self._partial_batch_size > 0 + + @property + def partial_batch_size(self): + return self._partial_batch_size or None + + +def can_convert_arrays(arrays): + """Check if array like-inputs can be handled by `ArrayDataAdapter` + + Args: + inputs: Structure of `Tensor`s, NumPy arrays, or tensor-like. + + Returns: + `True` if `arrays` can be handled by `ArrayDataAdapter`, `False` + otherwise. + """ + + return all( + tree.flatten(tree.map_structure(array_slicing.can_slice_array, arrays)) + ) diff --git a/keras/src/trainers/data_adapters/array_data_adapter_test.py b/keras/src/trainers/data_adapters/array_data_adapter_test.py new file mode 100644 index 000000000000..dc26c2fc277b --- /dev/null +++ b/keras/src/trainers/data_adapters/array_data_adapter_test.py @@ -0,0 +1,301 @@ +import jax +import jax.experimental.sparse as jax_sparse +import numpy as np +import pandas +import scipy +import tensorflow as tf +import torch +from absl.testing import parameterized + +from keras.src import backend +from keras.src import testing +from keras.src.testing.test_utils import named_product +from keras.src.trainers.data_adapters import array_data_adapter + + +class TestArrayDataAdapter(testing.TestCase): + def make_array(self, array_type, shape, dtype): + x = np.array([[i] * shape[1] for i in range(shape[0])], dtype=dtype) + if array_type == "np": + return x + elif array_type == "tf": + return tf.constant(x) + elif array_type == "tf_ragged": + return tf.RaggedTensor.from_tensor(x) + elif array_type == "tf_sparse": + return tf.sparse.from_dense(x) + elif array_type == "jax": + return jax.numpy.array(x) + elif array_type == "jax_sparse": + return jax_sparse.BCOO.fromdense(x) + elif array_type == "torch": + return torch.as_tensor(x) + elif array_type == "pandas_data_frame": + return pandas.DataFrame(x) + elif array_type == "pandas_series": + return pandas.Series(x[:, 0]) + elif array_type == "scipy_sparse": + return scipy.sparse.coo_matrix(x) + + @parameterized.named_parameters( + named_product( + array_type=[ + "np", + "tf", + "tf_ragged", + "tf_sparse", + "jax", + "jax_sparse", + "torch", + "pandas_data_frame", + "pandas_series", + "scipy_sparse", + ], + array_dtype=["float32", "float64"], + shuffle=[False, "batch", True], + ) + ) + def test_basic_flow(self, array_type, array_dtype, shuffle): + x = self.make_array(array_type, (34, 4), array_dtype) + y = self.make_array(array_type, (34, 2), "int32") + xdim1 = 1 if array_type == "pandas_series" else 4 + ydim1 = 1 if array_type == "pandas_series" else 2 + + adapter = array_data_adapter.ArrayDataAdapter( + x, + y=y, + sample_weight=None, + batch_size=16, + steps=None, + shuffle=shuffle, + ) + self.assertEqual(adapter.num_batches, 3) + self.assertEqual(adapter.batch_size, 16) + self.assertEqual(adapter.has_partial_batch, True) + self.assertEqual(adapter.partial_batch_size, 2) + + if backend.backend() == "numpy": + it = adapter.get_numpy_iterator() + expected_class = np.ndarray + elif backend.backend() == "tensorflow": + it = adapter.get_tf_dataset() + if array_type == "tf_ragged": + expected_class = tf.RaggedTensor + xdim1 = None + ydim1 = None + elif array_type in ("tf_sparse", "jax_sparse", "scipy_sparse"): + expected_class = tf.SparseTensor + else: + expected_class = tf.Tensor + elif backend.backend() == "jax": + it = adapter.get_jax_iterator() + if array_type in ("tf_sparse", "jax_sparse", "scipy_sparse"): + expected_class = jax_sparse.JAXSparse + else: + expected_class = np.ndarray + elif backend.backend() == "torch": + it = adapter.get_torch_dataloader() + expected_class = torch.Tensor + + x_order = [] + y_order = [] + for i, batch in enumerate(it): + self.assertEqual(len(batch), 2) + bx, by = batch + self.assertIsInstance(bx, expected_class) + self.assertIsInstance(by, expected_class) + self.assertEqual( + backend.standardize_dtype(bx.dtype), backend.floatx() + ) + self.assertEqual(backend.standardize_dtype(by.dtype), "int32") + if i < 2: + self.assertEqual(bx.shape, (16, xdim1)) + self.assertEqual(by.shape, (16, ydim1)) + else: + self.assertEqual(bx.shape, (2, xdim1)) + self.assertEqual(by.shape, (2, ydim1)) + + if isinstance(bx, tf.SparseTensor): + bx = tf.sparse.to_dense(bx) + by = tf.sparse.to_dense(by) + if isinstance(bx, jax_sparse.JAXSparse): + bx = bx.todense() + by = by.todense() + x_batch_order = [float(bx[j, 0]) for j in range(bx.shape[0])] + y_batch_order = [float(by[j, 0]) for j in range(by.shape[0])] + x_order.extend(x_batch_order) + y_order.extend(y_batch_order) + + if shuffle == "batch": + self.assertAllClose( + sorted(x_batch_order), range(i * 16, i * 16 + bx.shape[0]) + ) + + self.assertAllClose(x_order, y_order) + if shuffle: + self.assertNotAllClose(x_order, list(range(34))) + else: + self.assertAllClose(x_order, list(range(34))) + + def test_multi_inputs_and_outputs(self): + x1 = np.random.random((34, 1)) + x2 = np.random.random((34, 2)) + y1 = np.random.random((34, 3)) + y2 = np.random.random((34, 4)) + sw = np.random.random((34,)) + adapter = array_data_adapter.ArrayDataAdapter( + x={"x1": x1, "x2": x2}, + y=[y1, y2], + sample_weight=sw, + batch_size=16, + steps=None, + shuffle=False, + ) + gen = adapter.get_numpy_iterator() + for i, batch in enumerate(gen): + self.assertEqual(len(batch), 3) + bx, by, bw = batch + self.assertIsInstance(bx, dict) + self.assertIsInstance(by, list) + self.assertIsInstance(bw, list) + + self.assertIsInstance(bx["x1"], np.ndarray) + self.assertIsInstance(bx["x2"], np.ndarray) + self.assertIsInstance(by[0], np.ndarray) + self.assertIsInstance(by[1], np.ndarray) + self.assertIsInstance(bw[0], np.ndarray) + self.assertIsInstance(bw[1], np.ndarray) + + self.assertEqual(bx["x1"].dtype, by[0].dtype) + self.assertEqual(bx["x1"].dtype, backend.floatx()) + if i < 2: + self.assertEqual(bx["x1"].shape, (16, 1)) + self.assertEqual(bx["x2"].shape, (16, 2)) + self.assertEqual(by[0].shape, (16, 3)) + self.assertEqual(by[1].shape, (16, 4)) + self.assertEqual(bw[0].shape, (16,)) + self.assertEqual(bw[1].shape, (16,)) + else: + self.assertEqual(bx["x1"].shape, (2, 1)) + self.assertEqual(by[0].shape, (2, 3)) + self.assertEqual(bw[0].shape, (2,)) + self.assertEqual(bw[1].shape, (2,)) + ds = adapter.get_tf_dataset() + for i, batch in enumerate(ds): + self.assertEqual(len(batch), 3) + bx, by, bw = batch + self.assertIsInstance(bx, dict) + # NOTE: the y list was converted to a tuple for tf.data + # compatibility. + self.assertIsInstance(by, tuple) + self.assertIsInstance(bw, tuple) + + self.assertIsInstance(bx["x1"], tf.Tensor) + self.assertIsInstance(bx["x2"], tf.Tensor) + self.assertIsInstance(by[0], tf.Tensor) + self.assertIsInstance(by[1], tf.Tensor) + self.assertIsInstance(bw[0], tf.Tensor) + self.assertIsInstance(bw[1], tf.Tensor) + + self.assertEqual(bx["x1"].dtype, by[0].dtype) + self.assertEqual(bx["x1"].dtype, backend.floatx()) + if i < 2: + self.assertEqual(tuple(bx["x1"].shape), (16, 1)) + self.assertEqual(tuple(bx["x2"].shape), (16, 2)) + self.assertEqual(tuple(by[0].shape), (16, 3)) + self.assertEqual(tuple(by[1].shape), (16, 4)) + self.assertEqual(tuple(bw[0].shape), (16,)) + self.assertEqual(tuple(bw[1].shape), (16,)) + else: + self.assertEqual(tuple(bx["x1"].shape), (2, 1)) + self.assertEqual(tuple(by[0].shape), (2, 3)) + self.assertEqual(tuple(bw[0].shape), (2,)) + self.assertEqual(tuple(bw[1].shape), (2,)) + + @parameterized.named_parameters( + named_product(target_encoding=["int", "categorical"]) + ) + def test_class_weights(self, target_encoding): + x = np.random.random((4, 2)) + if target_encoding == "int": + y = np.array([[0], [1], [2], [3]], dtype="int32") + else: + y = np.array( + [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]], + dtype="float32", + ) + + class_weight = { + 0: 0.1, + 1: 0.2, + 2: 0.3, + 3: 0.4, + } + adapter = array_data_adapter.ArrayDataAdapter( + x, + y=y, + class_weight=class_weight, + batch_size=16, + ) + gen = adapter.get_numpy_iterator() + for batch in gen: + self.assertEqual(len(batch), 3) + _, _, bw = batch + self.assertAllClose(bw, [0.1, 0.2, 0.3, 0.4]) + + def test_errors(self): + x = np.random.random((34, 1)) + y = np.random.random((34, 3)) + sw = np.random.random((34,)) + cw = { + 0: 0.1, + 1: 0.2, + 2: 0.3, + 3: 0.4, + } + + with self.assertRaisesRegex( + ValueError, "Expected all elements of `x` to be array-like" + ): + array_data_adapter.ArrayDataAdapter(x="Invalid") + with self.assertRaisesRegex( + ValueError, "Expected all elements of `x` to be array-like" + ): + array_data_adapter.ArrayDataAdapter(x=x, y="Invalid") + with self.assertRaisesRegex( + ValueError, "Expected all elements of `x` to be array-like" + ): + array_data_adapter.ArrayDataAdapter( + x=x, y=y, sample_weight="Invalid" + ) + + with self.assertRaisesRegex( + ValueError, "You cannot `class_weight` and `sample_weight`" + ): + array_data_adapter.ArrayDataAdapter( + x=x, y=y, sample_weight=sw, class_weight=cw + ) + + nested_y = ({"x": x, "y": y},) + with self.assertRaisesRegex( + ValueError, "You should provide one `sample_weight` array per" + ): + array_data_adapter.ArrayDataAdapter( + x=x, y=nested_y, sample_weight=[] + ) + + tensor_sw = self.make_array("tf", (34, 2), "int32") + with self.assertRaisesRegex( + ValueError, "For a model with multiple outputs, when providing" + ): + array_data_adapter.ArrayDataAdapter( + x=x, y=nested_y, sample_weight=tensor_sw + ) + + with self.assertRaisesRegex( + ValueError, + "`class_weight` is only supported for Models with a single", + ): + array_data_adapter.ArrayDataAdapter( + x=x, y=nested_y, class_weight=cw + ) diff --git a/keras/src/trainers/data_adapters/array_slicing.py b/keras/src/trainers/data_adapters/array_slicing.py new file mode 100644 index 000000000000..74622ebb4aee --- /dev/null +++ b/keras/src/trainers/data_adapters/array_slicing.py @@ -0,0 +1,520 @@ +import collections +import math + +import numpy as np + +from keras.src import backend +from keras.src import tree +from keras.src.trainers.data_adapters import data_adapter_utils +from keras.src.utils.module_utils import tensorflow as tf + +try: + import pandas +except ImportError: + pandas = None + + +# Leave jax, tf, and torch arrays off this list. Instead we will use +# `__array__` to detect these types. Doing so allows us to avoid importing a +# backend framework we are not currently using just to do type-checking. +ARRAY_TYPES = (np.ndarray,) +if pandas: + ARRAY_TYPES = ARRAY_TYPES + (pandas.Series, pandas.DataFrame) + + +class Sliceable: + """`Sliceable` wrapping a tensor. + + A `Sliceable` implements the subscript operator to slice or index against + the first dimension of the array. It also has conversion methods for each + one of the backends. + + Args: + array: the native array or tensor to wrap. + + Attributes: + shape: the shape of the full dense native array. + """ + + def __init__(self, array): + self.array = array + + def __getitem__(self, indices): + """Select elements in the 0th dimension. + + Args: + indices: the indices to select. Only needs to support one dimension, + the 0th dimension. Should support a `slice` or a list, tuple, + `np.array` or 1D tensor. + Returns: A slice of `self.array`. + """ + return self.array[indices] + + @classmethod + def cast(cls, x, dtype): + """Cast a tensor to a different dtype. + + Only called on a full array as provided by the user. + + Args: + x: the tensor to cast. + Returns: the cast tensor. + """ + return x.astype(dtype) + + @classmethod + def convert_to_numpy(cls, x): + """Convert a tensor to a NumPy array. + + Only called after slicing using `__getitem__`. + + Args: + x: the tensor to convert. + Returns: the converted tensor. + """ + return x + + @classmethod + def convert_to_tf_dataset_compatible(cls, x): + """Convert a tensor to something compatible with `tf.data.Dataset`. + + This can be a NumPy array, `tf.Tensor` or any other type of tensor that + `tf.data.Dataset.from_tensors` can consume. + Only called on a full array as provided by the user. + + Args: + x: the tensor to convert. + Returns: converted version tensor. + """ + return x + + @classmethod + def convert_to_jax_compatible(cls, x): + """Convert a tensor to something that the JAX backend can consume. + + This can be a `JAX` array, `JAXSparse` or a NumPy array. + Only called after slicing using `__getitem__`. + Used to convert sparse tensors and densify ragged tensors. + + Args: + x: the tensor to convert. + Returns: the converted tensor. + """ + return x + + @classmethod + def convert_to_torch_compatible(cls, x): + """Convert a tensor to something that the Torch backend can consume. + + This can be a Torch tensor, NumPy array or any other type of tensor that + `keras.backend.torch.core.convert_to_tensor()` can consume. + Only called after slicing using `__getitem__`. + Used to densify sparse tensors and ragged tensors. + + Args: + x: the tensor to convert. + Returns: the converted tensor. + """ + return x + + +class NumpySliceable(Sliceable): + pass + + +class TensorflowSliceable(Sliceable): + def __getitem__(self, indices): + from keras.src.utils.module_utils import tensorflow as tf + + if isinstance(indices, slice): + return self.array[indices] + else: + return tf.gather(self.array, indices, axis=0) + + @classmethod + def cast(cls, x, dtype): + from keras.src.backend.tensorflow.core import cast + + return cast(x, dtype) + + @classmethod + def convert_to_numpy(cls, x): + from keras.src.backend.tensorflow.core import convert_to_numpy + + return convert_to_numpy(x) + + +class TensorflowRaggedSliceable(TensorflowSliceable): + @classmethod + def convert_to_jax_compatible(cls, x): + return cls.convert_to_numpy(x) + + @classmethod + def convert_to_torch_compatible(cls, x): + return x.to_tensor() + + +class TensorflowSparseSliceable(TensorflowSliceable): + def __init__(self, array): + super().__init__(to_tensorflow_sparse_wrapper(array)) + + @property + def shape(self): + return self.array.sparse.shape + + def __getitem__(self, indices): + return slice_tensorflow_sparse_wrapper(self.array, indices) + + @classmethod + def convert_to_tf_dataset_compatible(cls, x): + return to_tensorflow_sparse_wrapper(x) + + @classmethod + def convert_to_jax_compatible(cls, x): + return data_adapter_utils.tf_sparse_to_jax_sparse(x) + + @classmethod + def convert_to_torch_compatible(cls, x): + from keras.src.backend.tensorflow import sparse as tf_sparse + + return tf_sparse.sparse_to_dense(x) + + +class JaxSparseSliceable(Sliceable): + def __getitem__(self, indices): + return self.array[indices, ...] + + @classmethod + def convert_to_numpy(cls, x): + from keras.src.backend.jax.core import convert_to_numpy + + return convert_to_numpy(x) + + @classmethod + def convert_to_tf_dataset_compatible(cls, array): + return to_tensorflow_sparse_wrapper( + data_adapter_utils.jax_sparse_to_tf_sparse(array) + ) + + @classmethod + def convert_to_torch_compatible(cls, x): + return x.todense() + + +class TorchSliceable(Sliceable): + @classmethod + def cast(cls, x, dtype): + from keras.src.backend.torch.core import cast + + return cast(x, dtype) + + @classmethod + def convert_to_numpy(cls, x): + from keras.src.backend.torch.core import convert_to_numpy + + return convert_to_numpy(x) + + +class PandasSliceable(Sliceable): + def __getitem__(self, indices): + return self.array.iloc[indices] + + @classmethod + def convert_to_numpy(cls, x): + return x.to_numpy() + + @classmethod + def convert_to_tf_dataset_compatible(cls, x): + return cls.convert_to_numpy(x) + + @classmethod + def convert_to_jax_compatible(cls, x): + return cls.convert_to_numpy(x) + + @classmethod + def convert_to_torch_compatible(cls, x): + return cls.convert_to_numpy(x) + + +class PandasDataFrameSliceable(PandasSliceable): + pass + + +class PandasSeriesSliceable(PandasSliceable): + @classmethod + def convert_to_numpy(cls, x): + return np.expand_dims(x.to_numpy(), axis=-1) + + +class ScipySparseSliceable(Sliceable): + def __init__(self, array): + # The COO representation is not indexable / sliceable and does not lend + # itself to it. Use the CSR representation instead, which is sliceable. + super().__init__(array.tocsr()) + + @classmethod + def convert_to_numpy(cls, x): + return x.todense() + + @classmethod + def convert_to_tf_dataset_compatible(cls, x): + return to_tensorflow_sparse_wrapper( + data_adapter_utils.scipy_sparse_to_tf_sparse(x) + ) + + @classmethod + def convert_to_jax_compatible(cls, x): + return data_adapter_utils.scipy_sparse_to_jax_sparse(x) + + @classmethod + def convert_to_torch_compatible(cls, x): + return x.todense() + + +# `tf.SparseTensor` does not support indexing or `tf.gather`. The COO +# representation it uses does not lend itself to indexing. We add some +# intermediary tensors to ease the indexing and slicing. We put both indices and +# values in `RaggedTensor`s where each row corresponds to a row in the sparse +# tensor. This is because the number of values per row is not fixed. +# `RaggedTensor`s do support indexing and `tf.gather`, although on CPU only. +# We then reconstruct a `SparseTensor` from extracted rows. In theory, there is +# no duplication of data for the indices and values, only the addition of row +# splits for the ragged representation. +# `TensorflowSparseWrapper` is a named tuple which combines the original +# `SparseTensor` (used for the shape) and the ragged representations of indices +# and values for indexing / slicing. We use a named tuple and not a `Sliceable` +# to be able to ingest it in `tf.data.Dataset.from_tensors()` and map it. + +TensorflowSparseWrapper = collections.namedtuple( + "TensorflowSparseWrapper", ["sparse", "ragged_indices", "ragged_values"] +) + + +def to_tensorflow_sparse_wrapper(sparse): + from keras.src.utils.module_utils import tensorflow as tf + + row_ids = sparse.indices[:, 0] + row_splits = tf.experimental.RowPartition.from_value_rowids( + row_ids + ).row_splits() + + ragged_indices = tf.cast( + tf.RaggedTensor.from_row_splits(sparse.indices, row_splits), tf.int64 + ) + ragged_values = tf.RaggedTensor.from_row_splits(sparse.values, row_splits) + return TensorflowSparseWrapper(sparse, ragged_indices, ragged_values) + + +def slice_tensorflow_sparse_wrapper(sparse_wrapper, indices): + from keras.src.utils.module_utils import tensorflow as tf + + if isinstance(indices, slice): + sparse_indices = sparse_wrapper.ragged_indices[indices] + sparse_values = sparse_wrapper.ragged_values[indices] + batch_dim = indices.stop - indices.start + else: + sparse_indices = tf.gather(sparse_wrapper.ragged_indices, indices) + sparse_values = tf.gather(sparse_wrapper.ragged_values, indices) + if isinstance(indices, list): + batch_dim = len(indices) + else: + batch_dim = indices.shape[0] + if batch_dim is None: + batch_dim = tf.shape(indices)[0] + + row_ids = sparse_indices.value_rowids() + sparse_indices = sparse_indices.flat_values[:, 1:] # remove first value + sparse_indices = tf.concat( + [tf.expand_dims(row_ids, -1), sparse_indices], axis=1 + ) + + sparse_values = sparse_values.flat_values + sparse_shape = (batch_dim,) + tuple( + sparse_wrapper.sparse.shape.as_list()[1:] + ) + return tf.SparseTensor(sparse_indices, sparse_values, sparse_shape) + + +def can_slice_array(x): + return ( + x is None + or isinstance(x, ARRAY_TYPES) + or data_adapter_utils.is_tensorflow_tensor(x) + or data_adapter_utils.is_jax_array(x) + or data_adapter_utils.is_torch_tensor(x) + or data_adapter_utils.is_scipy_sparse(x) + or hasattr(x, "__array__") + ) + + +def convert_to_sliceable(arrays, target_backend=None): + """Convert a structure of arrays into `Sliceable` instances + + Args: + arrays: the arrays to convert. + target_backend: the target backend for the output: + - `None` indicates that `arrays` will be wrapped into `Sliceable`s + as-is without using a different representation. This is used by + `train_validation_split()`. + - `tensorflow` indicates that + `Sliceable.convert_to_tf_dataset_compatible` will be called. The + returned structure therefore contains arrays, not `Sliceable`s. + - `numpy`, `jax` or `torch` indices that the arrays will eventually + be converted to this backend type after slicing. In this case, + the intermediary `Sliceable`s may use a different representation + from the input `arrays` for better performance. + Returns: the same structure with `Sliceable` instances or arrays. + """ + + def convert_single_array(x): + if x is None: + return x + + # Special case: handle np "object" arrays containing strings + if ( + isinstance(x, np.ndarray) + and str(x.dtype) == "object" + and backend.backend() == "tensorflow" + and all(isinstance(e, str) for e in x) + ): + x = tf.convert_to_tensor(x, dtype="string") + + # Step 1. Determine which Sliceable class to use. + if isinstance(x, np.ndarray): + sliceable_class = NumpySliceable + elif data_adapter_utils.is_tensorflow_tensor(x): + if data_adapter_utils.is_tensorflow_ragged(x): + sliceable_class = TensorflowRaggedSliceable + elif data_adapter_utils.is_tensorflow_sparse(x): + sliceable_class = TensorflowSparseSliceable + else: + sliceable_class = TensorflowSliceable + elif data_adapter_utils.is_jax_array(x): + if data_adapter_utils.is_jax_sparse(x): + sliceable_class = JaxSparseSliceable + else: + x = np.asarray(x) + sliceable_class = NumpySliceable + elif data_adapter_utils.is_torch_tensor(x): + sliceable_class = TorchSliceable + elif pandas is not None and isinstance(x, pandas.DataFrame): + sliceable_class = PandasDataFrameSliceable + elif pandas is not None and isinstance(x, pandas.Series): + sliceable_class = PandasSeriesSliceable + elif data_adapter_utils.is_scipy_sparse(x): + sliceable_class = ScipySparseSliceable + elif hasattr(x, "__array__"): + x = np.asarray(x) + sliceable_class = NumpySliceable + else: + raise ValueError( + "Expected a NumPy array, tf.Tensor, tf.RaggedTensor, " + "tf.SparseTensor, jax.np.ndarray, " + "jax.experimental.sparse.JAXSparse, torch.Tensor, " + "Pandas Dataframe, or Pandas Series. Received invalid input: " + f"{x} (of type {type(x)})" + ) + + # Step 2. Normalize floats to floatx. + def is_non_floatx_float(dtype): + return ( + dtype is not object + and backend.is_float_dtype(dtype) + and not backend.standardize_dtype(dtype) == backend.floatx() + ) + + cast_dtype = None + if pandas is not None and isinstance(x, pandas.DataFrame): + if any(is_non_floatx_float(d) for d in x.dtypes.values): + cast_dtype = backend.floatx() + else: + if is_non_floatx_float(x.dtype): + cast_dtype = backend.floatx() + + if cast_dtype is not None: + x = sliceable_class.cast(x, cast_dtype) + + # Step 3. Apply target backend specific logic and optimizations. + if target_backend is None: + return sliceable_class(x) + + if target_backend == "tensorflow": + return sliceable_class.convert_to_tf_dataset_compatible(x) + + # With dense arrays and JAX as output, it is faster to use NumPy as an + # intermediary representation, so wrap input array in a NumPy array, + # which should not use extra memory. + # See https://github.com/google/jax/issues/1276 for an explanation of + # why slicing a NumPy array is faster than slicing a JAX array. + if target_backend == "jax" and sliceable_class in ( + TensorflowSliceable, + TorchSliceable, + ): + x = np.asarray(x) + sliceable_class = NumpySliceable + + return sliceable_class(x) + + return tree.map_structure(convert_single_array, arrays) + + +def train_validation_split(arrays, validation_split): + """Split arrays into train and validation subsets in deterministic order. + + The last part of data will become validation data. + + Args: + arrays: Tensors to split. Allowed inputs are arbitrarily nested + structures of Tensors and NumPy arrays. + validation_split: Float between 0 and 1. The proportion of the dataset + to include in the validation split. The rest of the dataset will be + included in the training split. + + Returns: + `(train_arrays, validation_arrays)` + """ + + flat_arrays = tree.flatten(arrays) + unsplitable = [type(t) for t in flat_arrays if not can_slice_array(t)] + if unsplitable: + raise ValueError( + "Argument `validation_split` is only supported " + "for tensors or NumPy arrays." + f"Found incompatible type in the input: {unsplitable}" + ) + + if all(t is None for t in flat_arrays): + return arrays, arrays + + first_non_none = None + for t in flat_arrays: + if t is not None: + first_non_none = t + break + + # Assumes all arrays have the same batch shape or are `None`. + batch_dim = int(first_non_none.shape[0]) + split_at = int(math.floor(batch_dim * (1.0 - validation_split))) + + if split_at == 0 or split_at == batch_dim: + raise ValueError( + f"Training data contains {batch_dim} samples, which is not " + "sufficient to split it into a validation and training set as " + f"specified by `validation_split={validation_split}`. Either " + "provide more data, or a different value for the " + "`validation_split` argument." + ) + + def _split(t, start, end): + if t is None: + return t + return t[start:end] + + sliceables = convert_to_sliceable(arrays) + train_arrays = tree.map_structure( + lambda x: _split(x, start=0, end=split_at), sliceables + ) + val_arrays = tree.map_structure( + lambda x: _split(x, start=split_at, end=batch_dim), sliceables + ) + return train_arrays, val_arrays diff --git a/keras/src/trainers/data_adapters/data_adapter.py b/keras/src/trainers/data_adapters/data_adapter.py new file mode 100644 index 000000000000..17e2c1784b8d --- /dev/null +++ b/keras/src/trainers/data_adapters/data_adapter.py @@ -0,0 +1,112 @@ +class DataAdapter: + """Base class for input data adapters. + + The purpose of a DataAdapter is to provide a unified interface to + iterate over input data provided in a variety of formats -- such as + NumPy arrays, tf.Tensors, tf.data.Datasets, Keras PyDatasets, etc. + """ + + def get_numpy_iterator(self): + """Get a Python iterable for the `DataAdapter`, that yields NumPy + arrays. + + Returns: + A Python iterator. + """ + raise NotImplementedError + + def get_tf_dataset(self): + """Get a `tf.data.Dataset` instance for the DataAdapter. + + Note that the dataset returned does not repeat for epoch, so caller + might need to create new iterator for the same dataset at the beginning + of the epoch. This behavior might change in the future. + + Returns: + A `tf.data.Dataset`. Caller might use the dataset in different + context, e.g. iter(dataset) in eager to get the value directly, or + in graph mode, provide the iterator tensor to Keras model function. + """ + raise NotImplementedError + + def get_jax_iterator(self): + """Get a Python iterable for the `DataAdapter`, that yields arrays that + that can be fed to JAX. NumPy arrays are preferred for performance. + + Returns: + A Python iterator. + """ + raise NotImplementedError + + def get_torch_dataloader(self): + """Get a Torch `DataLoader` for the `DataAdapter`. + + Returns: + A Torch `DataLoader`. + """ + raise NotImplementedError + + @property + def builtin_prefetch(self): + """Whether the DataAdapter has built-in prefetching capabilities. + + Prefetching is an optimization technique where data is loaded and + prepared in advance while the model is processing the current batch, + reducing training time by overlapping data loading with computation. + + Returns: + bool: True if the DataAdapter implements its own prefetching + mechanism and handles data loading asynchronously. False if the + caller should implement prefetching externally. + """ + return False + + @property + def num_batches(self): + """Return the size (number of batches) for the dataset created. + + For certain type of the data input, the number of batches is known, eg + for Numpy data, the size is same as (number_of_element / batch_size). + Whereas for dataset or python generator, the size is unknown since it + may or may not have an end state. + + Returns: + int, the number of batches for the dataset, or None if it is + unknown. The caller could use this to control the loop of training, + show progress bar, or handle unexpected StopIteration error. + """ + raise NotImplementedError + + @property + def batch_size(self): + """Return the batch size of the dataset created. + + For certain type of the data input, the batch size is known, and even + required, like numpy array. Whereas for dataset, the batch is unknown + unless we take a peek. + + Returns: + int, the batch size of the dataset, or None if it is unknown. + """ + raise NotImplementedError + + @property + def has_partial_batch(self): + """Whether the dataset has partial batch at the end.""" + raise NotImplementedError + + @property + def partial_batch_size(self): + """The size of the final partial batch for dataset. + + Will return None if has_partial_batch is False or batch_size is None. + """ + raise NotImplementedError + + def on_epoch_begin(self): + """A hook called before each epoch.""" + pass + + def on_epoch_end(self): + """A hook called after each epoch.""" + pass diff --git a/keras/src/trainers/data_adapters/data_adapter_utils.py b/keras/src/trainers/data_adapters/data_adapter_utils.py new file mode 100644 index 000000000000..6cad232ada98 --- /dev/null +++ b/keras/src/trainers/data_adapters/data_adapter_utils.py @@ -0,0 +1,395 @@ +import numpy as np + +from keras.src import backend +from keras.src import ops +from keras.src import tree +from keras.src.api_export import keras_export + +NUM_BATCHES_FOR_TENSOR_SPEC = 2 + + +@keras_export("keras.utils.unpack_x_y_sample_weight") +def unpack_x_y_sample_weight(data): + """Unpacks user-provided data tuple. + + This is a convenience utility to be used when overriding + `Model.train_step`, `Model.test_step`, or `Model.predict_step`. + This utility makes it easy to support data of the form `(x,)`, + `(x, y)`, or `(x, y, sample_weight)`. + + Example: + + >>> features_batch = ops.ones((10, 5)) + >>> labels_batch = ops.zeros((10, 5)) + >>> data = (features_batch, labels_batch) + >>> # `y` and `sample_weight` will default to `None` if not provided. + >>> x, y, sample_weight = unpack_x_y_sample_weight(data) + >>> sample_weight is None + True + + Args: + data: A tuple of the form `(x,)`, `(x, y)`, or `(x, y, sample_weight)`. + + Returns: + The unpacked tuple, with `None`s for `y` and `sample_weight` if they are + not provided. + """ + if isinstance(data, list): + data = tuple(data) + if not isinstance(data, tuple): + return (data, None, None) + elif len(data) == 1: + return (data[0], None, None) + elif len(data) == 2: + return (data[0], data[1], None) + elif len(data) == 3: + return (data[0], data[1], data[2]) + error_msg = ( + "Data is expected to be in format `x`, `(x,)`, `(x, y)`, " + f"or `(x, y, sample_weight)`, found: {data}" + ) + raise ValueError(error_msg) + + +@keras_export("keras.utils.pack_x_y_sample_weight") +def pack_x_y_sample_weight(x, y=None, sample_weight=None): + """Packs user-provided data into a tuple. + + This is a convenience utility for packing data into the tuple formats + that `Model.fit()` uses. + + Example: + + >>> x = ops.ones((10, 1)) + >>> data = pack_x_y_sample_weight(x) + >>> isinstance(data, ops.Tensor) + True + >>> y = ops.ones((10, 1)) + >>> data = pack_x_y_sample_weight(x, y) + >>> isinstance(data, tuple) + True + >>> x, y = data + + Args: + x: Features to pass to `Model`. + y: Ground-truth targets to pass to `Model`. + sample_weight: Sample weight for each element. + + Returns: + Tuple in the format used in `Model.fit()`. + """ + if y is None: + # For single x-input, we do no tuple wrapping since in this case + # there is no ambiguity. This also makes NumPy and Dataset + # consistent in that the user does not have to wrap their Dataset + # data in an unnecessary tuple. + if not isinstance(x, (tuple, list)): + return x + else: + return (x,) + elif sample_weight is None: + return (x, y) + else: + return (x, y, sample_weight) + + +def list_to_tuple(maybe_list): + """Datasets will stack any list of tensors, so we convert them to tuples.""" + if isinstance(maybe_list, list): + return tuple(maybe_list) + return maybe_list + + +def check_data_cardinality(data): + num_samples = set( + int(i.shape[0]) for i in tree.flatten(data) if i is not None + ) + if len(num_samples) > 1: + msg = ( + "Data cardinality is ambiguous. " + "Make sure all arrays contain the same number of samples." + ) + for label, single_data in zip(["x", "y", "sample_weight"], data): + sizes = ", ".join( + str(i.shape[0]) for i in tree.flatten(single_data) + ) + msg += f"'{label}' sizes: {sizes}\n" + raise ValueError(msg) + + +def class_weight_to_sample_weights(y, class_weight): + # Convert to numpy to ensure consistent handling of operations + # (e.g., np.round()) across frameworks like TensorFlow, JAX, and PyTorch + + y_numpy = ops.convert_to_numpy(y) + sample_weight = np.ones(shape=(y_numpy.shape[0],), dtype=backend.floatx()) + if len(y_numpy.shape) > 1: + if y_numpy.shape[-1] != 1: + y_numpy = np.argmax(y_numpy, axis=-1) + else: + y_numpy = np.squeeze(y_numpy, axis=-1) + y_numpy = np.round(y_numpy).astype("int32") + + for i in range(y_numpy.shape[0]): + sample_weight[i] = class_weight.get(int(y_numpy[i]), 1.0) + return sample_weight + + +def get_keras_tensor_spec(batches): + """Return the KerasTensor spec for a list of batches. + + The spec is represented using `KerasTensor` which could handle dense, sparse + or ragged tensors. + + Args: + batches: list of structures of tensors. The structures must be + identical, but the shape at each leaf may be different. + + Returns: + A nested structure of `KerasTensor`. + """ + + def get_single_tensor_spec(*tensors): + x = tensors[0] + if not hasattr(x, "shape"): + # Try to convert to a numpy array. + x = np.array(x) + rank = len(x.shape) + if rank < 1: + raise ValueError( + "When passing a dataset to a Keras model, the arrays must " + f"be at least rank 1. Received: {x} of rank {len(x.shape)}." + ) + for t in tensors: + if len(t.shape) != rank: + raise ValueError( + "When passing a dataset to a Keras model, the " + "corresponding arrays in each batch must have the same " + f"rank. Received: {x} and {t}" + ) + shape = [] + # Merge shapes: go through each dimension one by one and keep the + # common values + for dims in zip(*[list(x.shape) for x in tensors]): + dims_set = set(dims) + shape.append(dims_set.pop() if len(dims_set) == 1 else None) + + dtype = backend.standardize_dtype(x.dtype) + if is_tensorflow_ragged(x): + return backend.KerasTensor( + shape=shape, + dtype=dtype, + ragged=True, + ragged_rank=x.ragged_rank, + row_splits_dtype=x.row_splits.dtype, + ) + if is_tensorflow_sparse(x) or is_scipy_sparse(x) or is_jax_sparse(x): + return backend.KerasTensor(shape=shape, dtype=dtype, sparse=True) + else: + return backend.KerasTensor(shape=shape, dtype=dtype) + + return tree.map_structure( + get_single_tensor_spec, *batches, none_is_leaf=False + ) + + +def convert_to_tf_tensor_spec(keras_tensor, batch_axis_to_none=True): + """Convert a KerasTensor to a TensorSpec. + + Args: + keras_tensor: A KerasTensor instance. + batch_axis_to_none: If `True`, the batch axis of the returned + tensor spec will be set to None. Defaults to `True`. + """ + from keras.src.utils.module_utils import tensorflow as tf + + if keras_tensor is None: + return tf.OptionalSpec(None) + if not isinstance(keras_tensor, backend.KerasTensor): + raise TypeError( + f"Expected a KerasTensor, but got {keras_tensor} of type " + f"{type(keras_tensor)}." + ) + shape = list(keras_tensor.shape) + if batch_axis_to_none: + shape[0] = None + if keras_tensor.ragged: + return tf.RaggedTensorSpec( + shape=shape, + dtype=keras_tensor.dtype, + ragged_rank=keras_tensor.ragged_rank, + row_splits_dtype=keras_tensor.row_splits_dtype, + ) + elif keras_tensor.sparse: + return tf.SparseTensorSpec(shape=shape, dtype=keras_tensor.dtype) + else: + return tf.TensorSpec(shape=shape, dtype=keras_tensor.dtype) + + +def get_tensor_spec(batches): + """Return the common tensor spec for a list of batches. + + The spec is represented using `tf.TensorSpec`, `tf.SparseTensorSpec` and + `tf.RaggedTensorSpec`. + + Args: + batches: list of structures of tensors. The structures must be + identical, but the shape at each leaf may be different. + + Returns: + A common tensor spec. + """ + tensor_specs = get_keras_tensor_spec(batches) + return tree.map_structure(convert_to_tf_tensor_spec, tensor_specs) + + +def get_jax_iterator(iterable): + import jax + import jax.experimental.sparse as jax_sparse + + def convert_to_jax_compatible(x): + if isinstance(x, (jax.Array, jax_sparse.JAXSparse, np.ndarray)): + return x + elif is_scipy_sparse(x): + return scipy_sparse_to_jax_sparse(x) + elif is_tensorflow_sparse(x): + return tf_sparse_to_jax_sparse(x) + else: + return np.asarray(x) + + for batch in iterable: + yield tree.map_structure( + convert_to_jax_compatible, batch, none_is_leaf=False + ) + + +def get_numpy_iterator(iterable): + def convert_to_numpy(x): + if not isinstance(x, np.ndarray): + # Using `__array__` should handle `tf.Tensor`, `jax.np.ndarray`, + # `torch.Tensor`, as well as any other tensor-like object that + # has added numpy support. + if hasattr(x, "__array__"): + if is_torch_tensor(x): + x = x.cpu() + x = np.asarray(x) + return x + + for batch in iterable: + yield tree.map_structure(convert_to_numpy, batch, none_is_leaf=False) + + +def get_torch_dataloader(iterable): + import torch.utils.data as torch_data + + from keras.src.backend.torch.core import convert_to_tensor + + class ConverterIterableDataset(torch_data.IterableDataset): + def __init__(self, iterable): + self.iterable = iterable + + def __iter__(self): + for batch in self.iterable: + yield tree.map_structure( + convert_to_tensor, batch, none_is_leaf=False + ) + + dataset = ConverterIterableDataset(iterable) + # `batch_size=None` indicates that we should not re-batch + return torch_data.DataLoader(dataset, batch_size=None) + + +def is_tensorflow_tensor(value): + if hasattr(value, "__class__"): + if value.__class__.__name__ in ("RaggedTensor", "SparseTensor"): + return "tensorflow.python." in str(value.__class__.__module__) + for parent in value.__class__.__mro__: + if parent.__name__ in ("Tensor") and "tensorflow.python." in str( + parent.__module__ + ): + return True + return False + + +def is_tensorflow_ragged(value): + if hasattr(value, "__class__"): + return ( + value.__class__.__name__ == "RaggedTensor" + and "tensorflow.python." in str(value.__class__.__module__) + ) + return False + + +def is_tensorflow_sparse(value): + if hasattr(value, "__class__"): + return ( + value.__class__.__name__ == "SparseTensor" + and "tensorflow.python." in str(value.__class__.__module__) + ) + return False + + +def is_jax_array(value): + if hasattr(value, "__class__"): + for parent in value.__class__.__mro__: + if parent.__name__ == "Array" and str(parent.__module__) == "jax": + return True + return is_jax_sparse(value) # JAX sparse arrays do not extend jax.Array + + +def is_jax_sparse(value): + if hasattr(value, "__class__"): + return str(value.__class__.__module__).startswith( + "jax.experimental.sparse" + ) + return False + + +def is_torch_tensor(value): + if hasattr(value, "__class__"): + for parent in value.__class__.__mro__: + if parent.__name__ == "Tensor" and str(parent.__module__).endswith( + "torch" + ): + return True + return False + + +def is_scipy_sparse(x): + return str(x.__class__.__module__).startswith("scipy.sparse") and hasattr( + x, "tocoo" + ) + + +def scipy_sparse_to_tf_sparse(x): + from keras.src.utils.module_utils import tensorflow as tf + + coo = x.tocoo() + indices = np.concatenate( + (np.expand_dims(coo.row, 1), np.expand_dims(coo.col, 1)), axis=1 + ) + return tf.SparseTensor(indices, coo.data, coo.shape) + + +def scipy_sparse_to_jax_sparse(x): + import jax + import jax.experimental.sparse as jax_sparse + + with jax.default_device(jax.local_devices(backend="cpu")[0]): + return jax_sparse.BCOO.from_scipy_sparse(x) + + +def tf_sparse_to_jax_sparse(x): + import jax + import jax.experimental.sparse as jax_sparse + + values = np.asarray(x.values) + indices = np.asarray(x.indices) + with jax.default_device(jax.local_devices(backend="cpu")[0]): + return jax_sparse.BCOO((values, indices), shape=x.shape) + + +def jax_sparse_to_tf_sparse(x): + from keras.src.utils.module_utils import tensorflow as tf + + return tf.SparseTensor(x.indices, x.data, x.shape) diff --git a/keras/src/trainers/data_adapters/data_adapter_utils_test.py b/keras/src/trainers/data_adapters/data_adapter_utils_test.py new file mode 100644 index 000000000000..01d62eeaa581 --- /dev/null +++ b/keras/src/trainers/data_adapters/data_adapter_utils_test.py @@ -0,0 +1,107 @@ +import numpy as np +import pytest +from absl.testing import parameterized + +from keras.src import backend +from keras.src import testing +from keras.src.trainers.data_adapters.data_adapter_utils import ( + class_weight_to_sample_weights, +) + + +class TestClassWeightToSampleWeights(testing.TestCase): + @parameterized.named_parameters( + [ + # Simple case, where y is flat + ( + "simple_class_labels", + np.array([0, 1, 0, 2]), + {0: 1.0, 1: 2.0, 2: 3.0}, + np.array([1.0, 2.0, 1.0, 3.0]), + ), + # Testing with one-hot encoded labels, + # so basically the argmax statement + ( + "one_hot_encoded_labels", + np.array([[1, 0, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1]]), + {0: 1.0, 1: 2.0, 2: 3.0}, + np.array([1.0, 2.0, 1.0, 3.0]), + ), + # 3 is not mapped, so it's assigned the default weight (1) + ( + "unmapped_class", + np.array([0, 3, 0, 2]), + {0: 1.0, 1: 2.0, 2: 3.0}, + np.array([1.0, 1.0, 1.0, 3.0]), + ), + ( + "multi_dimensional_input", + np.array([[0], [1], [0], [2]]), + {0: 1.0, 1: 2.0, 2: 3.0}, + np.array([1.0, 2.0, 1.0, 3.0]), + ), + ( + "all_unmapped", + np.array([0, 1, 0, 2]), + {}, + np.array([1.0, 1.0, 1.0, 1.0]), + ), + ] + ) + def test_class_weight_to_sample_weights(self, y, class_weight, expected): + self.assertAllClose( + class_weight_to_sample_weights(y, class_weight), expected + ) + + @pytest.mark.skipif(backend.backend() != "torch", reason="torch only") + def test_class_weight_to_sample_weights_torch_specific(self): + import torch + + y = torch.from_numpy(np.array([0, 1, 0, 2])) + self.assertAllClose( + class_weight_to_sample_weights(y, {0: 1.0, 1: 2.0, 2: 3.0}), + np.array([1.0, 2.0, 1.0, 3.0]), + ) + y_one_hot = torch.from_numpy( + np.array([[1, 0, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1]]) + ) + self.assertAllClose( + class_weight_to_sample_weights(y_one_hot, {0: 1.0, 1: 2.0, 2: 3.0}), + np.array([1.0, 2.0, 1.0, 3.0]), + ) + + @pytest.mark.skipif(backend.backend() != "jax", reason="jax only") + def test_class_weight_to_sample_weights_jax_specific(self): + import jax + + y = jax.numpy.asarray(np.array([0, 1, 0, 2])) + self.assertAllClose( + class_weight_to_sample_weights(y, {0: 1.0, 1: 2.0, 2: 3.0}), + np.array([1.0, 2.0, 1.0, 3.0]), + ) + y_one_hot = jax.numpy.asarray( + np.array([[1, 0, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1]]) + ) + self.assertAllClose( + class_weight_to_sample_weights(y_one_hot, {0: 1.0, 1: 2.0, 2: 3.0}), + np.array([1.0, 2.0, 1.0, 3.0]), + ) + + @pytest.mark.skipif( + backend.backend() != "tensorflow", reason="tensorflow only" + ) + def test_class_weight_to_sample_weights_tf_specific(self): + import tensorflow as tf + + y = tf.convert_to_tensor(np.array([0, 1, 0, 2])) + self.assertAllClose( + class_weight_to_sample_weights(y, {0: 1.0, 1: 2.0, 2: 3.0}), + np.array([1.0, 2.0, 1.0, 3.0]), + ) + y_one_hot = tf.convert_to_tensor( + np.array([[1, 0, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1]]) + ) + self.assertAllClose( + class_weight_to_sample_weights(y_one_hot, {0: 1.0, 1: 2.0, 2: 3.0}), + np.array([1.0, 2.0, 1.0, 3.0]), + ) diff --git a/keras/src/trainers/data_adapters/generator_data_adapter.py b/keras/src/trainers/data_adapters/generator_data_adapter.py new file mode 100644 index 000000000000..186e45da93de --- /dev/null +++ b/keras/src/trainers/data_adapters/generator_data_adapter.py @@ -0,0 +1,89 @@ +import itertools + +from keras.src import tree +from keras.src.trainers.data_adapters import data_adapter_utils +from keras.src.trainers.data_adapters.data_adapter import DataAdapter + + +class GeneratorDataAdapter(DataAdapter): + """Adapter for Python generators.""" + + def __init__(self, generator): + first_batches, generator = peek_and_restore(generator) + self.generator = generator + self._first_batches = first_batches + self._output_signature = None + if not isinstance(first_batches[0], tuple): + raise ValueError( + "When passing a Python generator to a Keras model, " + "the generator must return a tuple, either " + "(input,) or (inputs, targets) or " + "(inputs, targets, sample_weights). " + f"Received: {first_batches[0]}" + ) + + def get_numpy_iterator(self): + return data_adapter_utils.get_numpy_iterator(self.generator()) + + def get_jax_iterator(self): + return data_adapter_utils.get_jax_iterator(self.generator()) + + def get_tf_dataset(self): + from keras.src.utils.module_utils import tensorflow as tf + + def convert_to_tf(x, spec): + if x is None: + return tf.experimental.Optional.empty(None) + if data_adapter_utils.is_scipy_sparse(x): + x = data_adapter_utils.scipy_sparse_to_tf_sparse(x) + elif data_adapter_utils.is_jax_sparse(x): + x = data_adapter_utils.jax_sparse_to_tf_sparse(x) + if not spec.shape.is_compatible_with(x.shape): + raise TypeError( + f"Generator yielded an element of shape {x.shape} where " + f"an element of shape {spec.shape} was expected. Your " + "generator provides tensors with variable input " + "dimensions other than the batch size. Make sure that the " + "generator's first two batches do not have the same " + "dimension value wherever there is a variable input " + "dimension." + ) + return x + + def get_tf_iterator(): + for batch in self.generator(): + batch = tree.map_structure( + convert_to_tf, batch, self._output_signature + ) + yield batch + + if self._output_signature is None: + self._output_signature = data_adapter_utils.get_tensor_spec( + self._first_batches + ) + ds = tf.data.Dataset.from_generator( + get_tf_iterator, + output_signature=self._output_signature, + ) + ds = ds.prefetch(tf.data.AUTOTUNE) + return ds + + def get_torch_dataloader(self): + return data_adapter_utils.get_torch_dataloader(self.generator()) + + @property + def num_batches(self): + return None + + @property + def batch_size(self): + return None + + +def peek_and_restore(generator): + batches = list( + itertools.islice( + generator, data_adapter_utils.NUM_BATCHES_FOR_TENSOR_SPEC + ) + ) + return batches, lambda: itertools.chain(batches, generator) diff --git a/keras/src/trainers/data_adapters/generator_data_adapter_test.py b/keras/src/trainers/data_adapters/generator_data_adapter_test.py new file mode 100644 index 000000000000..35a129be1e85 --- /dev/null +++ b/keras/src/trainers/data_adapters/generator_data_adapter_test.py @@ -0,0 +1,229 @@ +import math + +import jax +import jax.experimental.sparse as jax_sparse +import numpy as np +import pytest +import scipy +import tensorflow as tf +import torch +from absl.testing import parameterized +from jax import numpy as jnp + +from keras.src import backend +from keras.src import testing +from keras.src.testing.test_utils import named_product +from keras.src.trainers.data_adapters import generator_data_adapter + + +def example_generator(x, y, sample_weight=None, batch_size=32): + def make(): + for i in range(math.ceil(len(x) / batch_size)): + low = i * batch_size + high = min(low + batch_size, len(x)) + batch_x = x[low:high] + batch_y = y[low:high] + if sample_weight is not None: + yield batch_x, batch_y, sample_weight[low:high] + else: + yield batch_x, batch_y + + return make + + +class GeneratorDataAdapterTest(testing.TestCase): + @parameterized.named_parameters( + named_product( + [ + {"testcase_name": "use_weight", "use_sample_weight": True}, + {"testcase_name": "no_weight", "use_sample_weight": False}, + ], + generator_type=["np", "tf", "jax", "torch"], + ) + ) + def test_basic_flow(self, use_sample_weight, generator_type): + x = np.random.random((34, 4)).astype("float32") + y = np.array([[i, i] for i in range(34)], dtype="float32") + sw = np.random.random((34,)).astype("float32") + if generator_type == "tf": + x, y, sw = tf.constant(x), tf.constant(y), tf.constant(sw) + elif generator_type == "jax": + x, y, sw = jnp.array(x), jnp.array(y), jnp.array(sw) + elif generator_type == "torch": + x, y, sw = ( + torch.as_tensor(x), + torch.as_tensor(y), + torch.as_tensor(sw), + ) + if not use_sample_weight: + sw = None + make_generator = example_generator( + x, + y, + sample_weight=sw, + batch_size=16, + ) + + adapter = generator_data_adapter.GeneratorDataAdapter(make_generator()) + if backend.backend() == "numpy": + it = adapter.get_numpy_iterator() + expected_class = np.ndarray + elif backend.backend() == "tensorflow": + it = adapter.get_tf_dataset() + expected_class = tf.Tensor + elif backend.backend() == "jax": + it = adapter.get_jax_iterator() + expected_class = ( + jax.Array if generator_type == "jax" else np.ndarray + ) + elif backend.backend() == "torch": + it = adapter.get_torch_dataloader() + expected_class = torch.Tensor + + sample_order = [] + for i, batch in enumerate(it): + if use_sample_weight: + self.assertEqual(len(batch), 3) + bx, by, bsw = batch + else: + self.assertEqual(len(batch), 2) + bx, by = batch + self.assertIsInstance(bx, expected_class) + self.assertIsInstance(by, expected_class) + self.assertEqual(bx.dtype, by.dtype) + self.assertContainsExactSubsequence(str(bx.dtype), "float32") + if i < 2: + self.assertEqual(bx.shape, (16, 4)) + self.assertEqual(by.shape, (16, 2)) + else: + self.assertEqual(bx.shape, (2, 4)) + self.assertEqual(by.shape, (2, 2)) + if use_sample_weight: + self.assertIsInstance(bsw, expected_class) + for j in range(by.shape[0]): + sample_order.append(by[j, 0]) + self.assertAllClose(sample_order, list(range(34))) + + def test_with_different_shapes(self): + def generator(): + yield np.ones([16, 4], "float32"), np.ones([16, 2], "float32") + yield np.ones([16, 5], "float32"), np.ones([16, 2], "float32") + yield np.ones([2, 6], "float32"), np.ones([2, 2], "float32") + + adapter = generator_data_adapter.GeneratorDataAdapter(generator()) + + if backend.backend() == "numpy": + it = adapter.get_numpy_iterator() + elif backend.backend() == "tensorflow": + it = adapter.get_tf_dataset() + elif backend.backend() == "jax": + it = adapter.get_jax_iterator() + elif backend.backend() == "torch": + it = adapter.get_torch_dataloader() + + for i, batch in enumerate(it): + self.assertEqual(len(batch), 2) + bx, by = batch + self.assertEqual(bx.dtype, by.dtype) + self.assertContainsExactSubsequence(str(bx.dtype), "float32") + if i == 0: + self.assertEqual(bx.shape, (16, 4)) + self.assertEqual(by.shape, (16, 2)) + elif i == 1: + self.assertEqual(bx.shape, (16, 5)) + self.assertEqual(by.shape, (16, 2)) + else: + self.assertEqual(bx.shape, (2, 6)) + self.assertEqual(by.shape, (2, 2)) + + @pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="tf.data.Dataset specific behavior", + ) + def test_with_unexpected_shapes(self): + def generator(): + yield np.ones([16, 4], "float32"), np.ones([16, 2], "float32") + yield np.ones([16, 5], "float32"), np.ones([16, 2], "float32") + yield np.ones([16, 6], "float32"), np.ones([16, 3], "float32") + + adapter = generator_data_adapter.GeneratorDataAdapter(generator()) + + it = iter(adapter.get_tf_dataset()) + next(it) + next(it) + # note that Tensorflow wraps the TypeError in an InvalidArgumentError. + with self.assertRaisesRegex( + tf.errors.InvalidArgumentError, + "TypeError:.* shape \\(16, 3\\).* shape \\(None, 2\\) was expected" + ".*first two batches", + ): + next(it) + + @parameterized.named_parameters( + named_product(generator_type=["tf", "jax", "scipy"]) + ) + @pytest.mark.skipif( + not backend.SUPPORTS_SPARSE_TENSORS, + reason="Backend does not support sparse tensors", + ) + def test_sparse_tensors(self, generator_type): + if generator_type == "tf": + x = tf.SparseTensor([[0, 0], [1, 2]], [1.0, 2.0], (2, 4)) + y = tf.SparseTensor([[0, 0], [1, 1]], [3.0, 4.0], (2, 2)) + elif generator_type == "jax": + x = jax_sparse.BCOO(([1.0, 2.0], [[0, 0], [1, 2]]), shape=(2, 4)) + y = jax_sparse.BCOO(([3.0, 4.0], [[0, 0], [1, 1]]), shape=(2, 2)) + elif generator_type == "scipy": + x = scipy.sparse.coo_matrix(([1.0, 2.0], ([0, 1], [0, 2])), (2, 4)) + y = scipy.sparse.coo_matrix(([3.0, 4.0], ([0, 1], [0, 1])), (2, 2)) + + def generate(): + for _ in range(4): + yield x, y + + adapter = generator_data_adapter.GeneratorDataAdapter(generate()) + + if backend.backend() == "tensorflow": + it = adapter.get_tf_dataset() + expected_class = tf.SparseTensor + elif backend.backend() == "jax": + it = adapter.get_jax_iterator() + expected_class = jax_sparse.BCOO + + for batch in it: + self.assertEqual(len(batch), 2) + bx, by = batch + self.assertIsInstance(bx, expected_class) + self.assertIsInstance(by, expected_class) + self.assertEqual(bx.shape, (2, 4)) + self.assertEqual(by.shape, (2, 2)) + + @pytest.mark.skipif( + not backend.SUPPORTS_RAGGED_TENSORS, + reason="Backend does not support ragged tensors", + ) + def test_ragged_tensors(self): + x = tf.ragged.constant( + [[[0.0, 1.0]], [[2.0, 3.0], [4.0, 5.0]]], ragged_rank=1 + ) + y = tf.ragged.constant( + [[[0.0, 1.0]], [[0.0, 1.0], [0.0, 1.0]]], ragged_rank=1 + ) + + def generate(): + for _ in range(4): + yield x, y + + adapter = generator_data_adapter.GeneratorDataAdapter(generate()) + + if backend.backend() == "tensorflow": + it = adapter.get_tf_dataset() + expected_class = tf.RaggedTensor + + for batch in it: + self.assertEqual(len(batch), 2) + bx, by = batch + self.assertIsInstance(bx, expected_class) + self.assertIsInstance(by, expected_class) + self.assertEqual(bx.shape, (2, None, 2)) + self.assertEqual(by.shape, (2, None, 2)) diff --git a/keras/src/trainers/data_adapters/grain_dataset_adapter.py b/keras/src/trainers/data_adapters/grain_dataset_adapter.py new file mode 100644 index 000000000000..de62f962caf4 --- /dev/null +++ b/keras/src/trainers/data_adapters/grain_dataset_adapter.py @@ -0,0 +1,214 @@ +import itertools + +import numpy as np + +from keras.src import tree +from keras.src.trainers.data_adapters import data_adapter_utils +from keras.src.trainers.data_adapters.data_adapter import DataAdapter +from keras.src.utils.module_utils import grain +from keras.src.utils.module_utils import tensorflow as tf + + +class GrainDatasetAdapter(DataAdapter): + """Adapter that handles `grain.DataLoader`, `grain.MapDataset` and + `grain.IterDataset`. + """ + + def __init__(self, dataset): + """Initialize the GrainDatasetAdapter. + + Args: + dataset: A Grain dataset instance. Must be one of + `grain.DataLoader`, `grain.MapDataset`, or `grain.IterDataset`. + """ + + if not isinstance( + dataset, (grain.MapDataset, grain.IterDataset, grain.DataLoader) + ): + raise ValueError( + "Expected `dataset` to be a grain.MapDataset, " + "grain.IterDataset or grain.DataLoader. " + f"Received: {dataset} of type {type(dataset)}" + ) + + self._dataset = dataset + + batch_size, output_signature = self._get_dataset_info(dataset) + self._batch_size = batch_size + self._output_signature = output_signature + self._output_tf_signature = None + + def _get_dataset_info(self, dataset): + """Get the `batch_size` and `output_signature` from the dataset. + + We use a small list of batches to infer the `batch_size` and + `output_signature`. + """ + batches = list( + itertools.islice( + dataset, data_adapter_utils.NUM_BATCHES_FOR_TENSOR_SPEC + ) + ) + output_signature = data_adapter_utils.get_keras_tensor_spec(batches) + flat_output_signature = tree.flatten(output_signature) + batch_size = flat_output_signature[0].shape[0] + if batch_size is not None: + batch_size = int(batch_size) + return batch_size, output_signature + + def get_numpy_iterator(self): + from grain._src.python.shared_memory_array import ( + SharedMemoryArrayMetadata, + ) + + def convert_to_numpy(x): + if isinstance(x, (np.ndarray, SharedMemoryArrayMetadata)): + return x + else: + # Using `__array__` should handle `tf.Tensor`, `jax.np.ndarray`, + # `torch.Tensor`, as well as any other tensor-like object that + # has added numpy support. + if hasattr(x, "__array__"): + if data_adapter_utils.is_torch_tensor(x): + x = x.cpu() + x = np.asarray(x) + return x + + class ConvertToNumpy(grain.transforms.Map): + def map(self, x): + return tree.map_structure( + convert_to_numpy, x, none_is_leaf=False + ) + + if isinstance(self._dataset, (grain.MapDataset, grain.IterDataset)): + dataset = self._dataset.map(ConvertToNumpy()) + else: + # Instantiate a new `DataLoader`. + dataset = grain.DataLoader( + data_source=self._dataset._data_source, + sampler=self._dataset._sampler, + # Append `ConvertToNumpy`. + operations=list(self._dataset._operations) + [ConvertToNumpy()], + worker_count=self._dataset._multiprocessing_options.num_workers, + worker_buffer_size=self._dataset._multiprocessing_options.per_worker_buffer_size, + shard_options=self._dataset._shard_options, + read_options=self._dataset._read_options, + enable_profiling=self._dataset._multiprocessing_options.enable_profiling, + ) + return dataset + + def get_jax_iterator(self): + def convert_to_jax_compatible(x): + if data_adapter_utils.is_scipy_sparse(x): + x = data_adapter_utils.scipy_sparse_to_jax_sparse(x) + elif data_adapter_utils.is_tensorflow_sparse(x): + x = data_adapter_utils.tf_sparse_to_jax_sparse(x) + return x + + class ConvertToJaxCompatible(grain.transforms.Map): + def map(self, x): + return tree.map_structure( + convert_to_jax_compatible, x, none_is_leaf=False + ) + + if isinstance(self._dataset, (grain.MapDataset, grain.IterDataset)): + dataset = self._dataset.map(ConvertToJaxCompatible()) + else: + # Instantiate a new `DataLoader`. + dataset = grain.DataLoader( + data_source=self._dataset._data_source, + sampler=self._dataset._sampler, + # Append `ConvertToJaxCompatible`. + operations=list(self._dataset._operations) + + [ConvertToJaxCompatible()], + worker_count=self._dataset._multiprocessing_options.num_workers, + worker_buffer_size=self._dataset._multiprocessing_options.per_worker_buffer_size, + shard_options=self._dataset._shard_options, + read_options=self._dataset._read_options, + enable_profiling=self._dataset._multiprocessing_options.enable_profiling, + ) + return dataset + + def get_tf_dataset(self): + def convert_to_tf(x): + if x is None: + return tf.experimental.Optional.empty(None) + if data_adapter_utils.is_scipy_sparse(x): + x = data_adapter_utils.scipy_sparse_to_tf_sparse(x) + elif data_adapter_utils.is_jax_sparse(x): + x = data_adapter_utils.jax_sparse_to_tf_sparse(x) + return x + + class ConvertToTF(grain.transforms.Map): + def map(self, x): + return tree.map_structure(convert_to_tf, x) + + # `tf.data.Dataset.from_generator` does not support lists as output. + # We convert lists to tuples. + class ListToTuple(grain.transforms.Map): + def map(self, x): + return tree.lists_to_tuples(x) + + if isinstance(self._dataset, (grain.MapDataset, grain.IterDataset)): + dataset = self._dataset.map(ConvertToTF()) + dataset = dataset.map(ListToTuple()) + else: + # Instantiate a new `DataLoader`. + dataset = grain.DataLoader( + data_source=self._dataset._data_source, + sampler=self._dataset._sampler, + # Append `ConvertToTF` and `ListToTuple`. + operations=list(self._dataset._operations) + + [ConvertToTF(), ListToTuple()], + worker_count=self._dataset._multiprocessing_options.num_workers, + worker_buffer_size=self._dataset._multiprocessing_options.per_worker_buffer_size, + shard_options=self._dataset._shard_options, + read_options=self._dataset._read_options, + enable_profiling=self._dataset._multiprocessing_options.enable_profiling, + ) + + if self._output_tf_signature is None: + self._output_tf_signature = tree.map_structure( + data_adapter_utils.convert_to_tf_tensor_spec, + self._output_signature, + ) + + return tf.data.Dataset.from_generator( + lambda: dataset, output_signature=self._output_tf_signature + ) + + def get_torch_dataloader(self): + import torch.utils.data as torch_data + + class ConverterIterableDataset(torch_data.IterableDataset): + def __init__(self, iterable): + super().__init__() + self.iterable = iterable + + def __iter__(self): + return iter(self.iterable) + + # `batch_size=None` indicates that we should not re-batch + return torch_data.DataLoader( + ConverterIterableDataset(self._dataset), batch_size=None + ) + + @property + def builtin_prefetch(self): + return True + + @property + def num_batches(self): + return None + + @property + def batch_size(self): + return self._batch_size + + @property + def has_partial_batch(self): + return None + + @property + def partial_batch_size(self): + return None diff --git a/keras/src/trainers/data_adapters/grain_dataset_adapter_test.py b/keras/src/trainers/data_adapters/grain_dataset_adapter_test.py new file mode 100644 index 000000000000..cb9dc870b807 --- /dev/null +++ b/keras/src/trainers/data_adapters/grain_dataset_adapter_test.py @@ -0,0 +1,219 @@ +import grain +import numpy as np +import tensorflow as tf +import torch +from absl.testing import parameterized + +from keras.src import backend +from keras.src import testing +from keras.src.testing.test_utils import named_product +from keras.src.trainers.data_adapters import grain_dataset_adapter + + +class Range2DSource(grain.sources.RandomAccessDataSource): + def __init__(self, start, stop): + self.start = start + self.stop = stop + + def __getitem__(self, idx): + return np.expand_dims(np.array([self.start + idx]), axis=0) + + def __len__(self): + return self.stop - self.start + + +class GrainDatasetAdapterTest(testing.TestCase): + def _get_dataset(self, dataset_type, worker_count=0, num_threads=0): + x = np.random.normal(size=(34, 4)).astype("float32") + y = np.random.normal(size=(34, 2)).astype("float32") + + class MySource(grain.sources.RandomAccessDataSource): + def __init__(self, x, y): + self.x = x + self.y = y + + def __getitem__(self, idx): + return self.x[idx], self.y[idx] + + def __len__(self): + return len(self.x) + + if dataset_type == "map_dataset": + dataset = grain.MapDataset.source(MySource(x, y)).batch( + batch_size=16 + ) + elif dataset_type == "iter_dataset": + dataset = ( + grain.MapDataset.source(MySource(x, y)) + .to_iter_dataset() + .batch(batch_size=16) + ) + else: + source = MySource(x, y) + dataset = grain.DataLoader( + data_source=source, + operations=[grain.transforms.Batch(batch_size=16)], + shard_options=grain.sharding.NoSharding(), + sampler=grain.samplers.IndexSampler( + num_records=len(source), num_epochs=1 + ), + worker_count=worker_count, + read_options=grain.ReadOptions(num_threads=num_threads), + ) + return dataset + + @parameterized.named_parameters( + named_product( + dataset_type=["map_dataset", "iter_dataset", "data_loader"] + ) + ) + def test_basic_flow(self, dataset_type): + dataset = self._get_dataset(dataset_type) + adapter = grain_dataset_adapter.GrainDatasetAdapter(dataset) + + self.assertEqual(adapter.num_batches, None) + self.assertEqual(adapter.batch_size, 16) + self.assertEqual(adapter.has_partial_batch, None) + self.assertEqual(adapter.partial_batch_size, None) + + if backend.backend() == "tensorflow": + it = adapter.get_tf_dataset() + expected_class = tf.Tensor + elif backend.backend() == "jax": + it = adapter.get_jax_iterator() + expected_class = np.ndarray + elif backend.backend() == "torch": + it = adapter.get_torch_dataloader() + expected_class = torch.Tensor + else: + it = adapter.get_numpy_iterator() + expected_class = np.ndarray + + for i, batch in enumerate(it): + self.assertEqual(len(batch), 2) + bx, by = batch + self.assertIsInstance(bx, expected_class) + self.assertIsInstance(by, expected_class) + self.assertEqual(bx.dtype, by.dtype) + self.assertContainsExactSubsequence(str(bx.dtype), "float32") + if i < 2: + self.assertEqual(bx.shape, (16, 4)) + self.assertEqual(by.shape, (16, 2)) + else: + self.assertEqual(bx.shape, (2, 4)) + self.assertEqual(by.shape, (2, 2)) + + @parameterized.named_parameters( + named_product(data_type=["list", "dict", "nested_list", "nested_dict"]) + ) + def test_nested_data(self, data_type): + if data_type not in ("list", "dict", "nested_list", "nested_dict"): + raise ValueError( + "data_type must be one of 'list', 'dict', 'nested_list' or " + f"'nested_dict'. Received: {data_type}" + ) + + class NestedSource(grain.sources.RandomAccessDataSource): + def __init__(self, data_type): + self.x = np.random.random((40, 4)).astype("float32") + self.y = np.random.random((40, 2)).astype("float32") + self.data_type = data_type + + def __len__(self): + return len(self.x) + + def __getitem__(self, idx): + x = self.x[idx] + y = self.y[idx] + if self.data_type == "list": + return x, y + elif self.data_type == "dict": + return {"x": x, "y": y} + elif self.data_type == "nested_list": + return x, (x, y) + elif self.data_type == "nested_dict": + return {"data": {"x": x, "y": y}} + + dataset = grain.MapDataset.source(NestedSource(data_type)).batch( + batch_size=4 + ) + adapter = grain_dataset_adapter.GrainDatasetAdapter(dataset) + + if backend.backend() == "tensorflow": + it = adapter.get_tf_dataset() + expected_class = tf.Tensor + elif backend.backend() == "jax": + it = adapter.get_jax_iterator() + expected_class = np.ndarray + elif backend.backend() == "torch": + it = adapter.get_torch_dataloader() + expected_class = torch.Tensor + else: + it = adapter.get_numpy_iterator() + expected_class = np.ndarray + + for batch in it: + if data_type == "list": + self.assertEqual(len(batch), 2) + bx, by = batch + elif data_type == "dict": + self.assertEqual(len(batch), 2) + bx, by = batch["x"], batch["y"] + elif data_type == "nested_list": + self.assertEqual(len(batch), 2) + bx, (_, by) = batch + elif data_type == "nested_dict": + self.assertEqual(len(batch["data"]), 2) + bx, by = batch["data"]["x"], batch["data"]["y"] + self.assertIsInstance(bx, expected_class) + self.assertIsInstance(by, expected_class) + self.assertEqual(bx.dtype, by.dtype) + self.assertEqual(bx.shape, (4, 4)) + self.assertEqual(by.shape, (4, 2)) + + def test_multiple_calling_on_iterators(self): + dataset = self._get_dataset("iter_dataset") + adapter = grain_dataset_adapter.GrainDatasetAdapter(dataset) + + numpy_it = adapter.get_numpy_iterator() + jax_it = adapter.get_jax_iterator() + tf_it = adapter.get_tf_dataset() + torch_it = adapter.get_torch_dataloader() + for it in (numpy_it, jax_it, tf_it, torch_it): + for batch in it: + self.assertEqual(len(batch), 2) + bx, by = batch + self.assertEqual(bx.dtype, by.dtype) + + def test_builtin_prefetch(self): + dataset = grain.MapDataset.source(Range2DSource(0, 42)) + adapter = grain_dataset_adapter.GrainDatasetAdapter(dataset) + self.assertTrue(adapter.builtin_prefetch) + + def test_num_batches(self): + dataset = grain.MapDataset.source(Range2DSource(0, 42)) + adapter = grain_dataset_adapter.GrainDatasetAdapter(dataset) + self.assertEqual(adapter.num_batches, None) + + # Test for Infinite Cardinality + dataset = grain.MapDataset.source(Range2DSource(0, 42)) + dataset = dataset.repeat() + adapter = grain_dataset_adapter.GrainDatasetAdapter(dataset) + self.assertIsNone(adapter.num_batches) + + # Test for Unknown Cardinality + dataset = dataset.filter(lambda x: True) + adapter = grain_dataset_adapter.GrainDatasetAdapter(dataset) + self.assertIsNone(adapter.num_batches) + + def test_invalid_dataset_type(self): + with self.assertRaisesRegex( + ValueError, + ( + r"Expected `dataset` to be a grain.MapDataset, " + r"grain.IterDataset or grain.DataLoader. " + ), + ): + grain_dataset_adapter.GrainDatasetAdapter( + "This is not a grain.Dataset" + ) diff --git a/keras/src/trainers/data_adapters/py_dataset_adapter.py b/keras/src/trainers/data_adapters/py_dataset_adapter.py new file mode 100644 index 000000000000..18865af026cf --- /dev/null +++ b/keras/src/trainers/data_adapters/py_dataset_adapter.py @@ -0,0 +1,708 @@ +import itertools +import multiprocessing.dummy +import queue +import random +import threading +import warnings +import weakref +from contextlib import closing + +import numpy as np + +from keras.src.api_export import keras_export +from keras.src.trainers.data_adapters import data_adapter_utils +from keras.src.trainers.data_adapters.data_adapter import DataAdapter + + +@keras_export(["keras.utils.PyDataset", "keras.utils.Sequence"]) +class PyDataset: + """Base class for defining a parallel dataset using Python code. + + Every `PyDataset` must implement the `__getitem__()` and the `__len__()` + methods. If you want to modify your dataset between epochs, + you may additionally implement `on_epoch_end()`, + or `on_epoch_begin` to be called at the start of each epoch. + The `__getitem__()` method should return a complete batch + (not a single sample), and the `__len__` method should return + the number of batches in the dataset (rather than the number of samples). + + Args: + workers: Number of workers to use in multithreading or + multiprocessing. + use_multiprocessing: Whether to use Python multiprocessing for + parallelism. Setting this to `True` means that your + dataset will be replicated in multiple forked processes. + This is necessary to gain compute-level (rather than I/O level) + benefits from parallelism. However it can only be set to + `True` if your dataset can be safely pickled. + max_queue_size: Maximum number of batches to keep in the queue + when iterating over the dataset in a multithreaded or + multiprocessed setting. + Reduce this value to reduce the CPU memory consumption of + your dataset. Defaults to 10. + + Notes: + + - `PyDataset` is a safer way to do multiprocessing. + This structure guarantees that the model will only train + once on each sample per epoch, which is not the case + with Python generators. + - The arguments `workers`, `use_multiprocessing`, and `max_queue_size` + exist to configure how `fit()` uses parallelism to iterate + over the dataset. They are not being used by the `PyDataset` class + directly. When you are manually iterating over a `PyDataset`, + no parallelism is applied. + + Example: + + ```python + from skimage.io import imread + from skimage.transform import resize + import numpy as np + import math + + # Here, `x_set` is list of path to the images + # and `y_set` are the associated classes. + + class CIFAR10PyDataset(keras.utils.PyDataset): + + def __init__(self, x_set, y_set, batch_size, **kwargs): + super().__init__(**kwargs) + self.x, self.y = x_set, y_set + self.batch_size = batch_size + + def __len__(self): + # Return number of batches. + return math.ceil(len(self.x) / self.batch_size) + + def __getitem__(self, idx): + # Return x, y for batch idx. + low = idx * self.batch_size + # Cap upper bound at array length; the last batch may be smaller + # if the total number of items is not a multiple of batch size. + high = min(low + self.batch_size, len(self.x)) + batch_x = self.x[low:high] + batch_y = self.y[low:high] + + return np.array([ + resize(imread(file_name), (200, 200)) + for file_name in batch_x]), np.array(batch_y) + ``` + """ + + def __init__(self, workers=1, use_multiprocessing=False, max_queue_size=10): + self._workers = workers + self._use_multiprocessing = use_multiprocessing + self._max_queue_size = max_queue_size + + def _warn_if_super_not_called(self): + warn = False + if not hasattr(self, "_workers"): + self._workers = 1 + warn = True + if not hasattr(self, "_use_multiprocessing"): + self._use_multiprocessing = False + warn = True + if not hasattr(self, "_max_queue_size"): + self._max_queue_size = 10 + warn = True + if warn: + warnings.warn( + "Your `PyDataset` class should call " + "`super().__init__(**kwargs)` in its constructor. " + "`**kwargs` can include `workers`, " + "`use_multiprocessing`, `max_queue_size`. Do not pass " + "these arguments to `fit()`, as they will be ignored.", + stacklevel=2, + ) + + @property + def workers(self): + self._warn_if_super_not_called() + return self._workers + + @workers.setter + def workers(self, value): + self._workers = value + + @property + def use_multiprocessing(self): + self._warn_if_super_not_called() + return self._use_multiprocessing + + @use_multiprocessing.setter + def use_multiprocessing(self, value): + self._use_multiprocessing = value + + @property + def max_queue_size(self): + self._warn_if_super_not_called() + return self._max_queue_size + + @max_queue_size.setter + def max_queue_size(self, value): + self._max_queue_size = value + + def __getitem__(self, index): + """Gets batch at position `index`. + + Args: + index: position of the batch in the PyDataset. + + Returns: + A batch + """ + del index + raise NotImplementedError + + def __iter__(self): + index_range = None + try: + num_batches = self.num_batches + if num_batches is not None: + index_range = range(num_batches) + except NotImplementedError: + pass + + if index_range is None: + index_range = itertools.count() + + for index in index_range: + yield self[index] + + @property + def num_batches(self): + """Number of batches in the PyDataset. + + Returns: + The number of batches in the PyDataset or `None` to indicate that + the dataset is infinite. + """ + # For backwards compatibility, support `__len__`. + if hasattr(self, "__len__"): + return len(self) + raise NotImplementedError( + "You need to implement the `num_batches` property:\n\n" + "@property\ndef num_batches(self):\n return ..." + ) + + def on_epoch_begin(self): + """Method called at the beginning of every epoch.""" + pass + + def on_epoch_end(self): + """Method called at the end of every epoch.""" + pass + + +class PyDatasetAdapter(DataAdapter): + """Adapter for `keras.utils.PyDataset` instances.""" + + def __init__( + self, + x, + class_weight=None, + shuffle=False, + ): + self.py_dataset = x + self.class_weight = class_weight + self.enqueuer = None + self.shuffle = shuffle + self._output_signature = None + self._within_epoch = False + + workers = self.py_dataset.workers + use_multiprocessing = self.py_dataset.use_multiprocessing + if workers > 1 or (workers > 0 and use_multiprocessing): + self.enqueuer = OrderedEnqueuer( + self.py_dataset, + workers=workers, + use_multiprocessing=use_multiprocessing, + max_queue_size=self.py_dataset.max_queue_size, + shuffle=self.shuffle, + ) + + def _standardize_batch(self, batch): + if isinstance(batch, dict): + return batch + if isinstance(batch, np.ndarray): + batch = (batch,) + if isinstance(batch, list): + batch = tuple(batch) + if not isinstance(batch, tuple) or len(batch) not in {1, 2, 3}: + raise ValueError( + "PyDataset.__getitem__() must return a tuple or a dict. " + "If a tuple, it must be ordered either " + "(input,) or (inputs, targets) or " + "(inputs, targets, sample_weights). " + f"Received: {str(batch)[:100]}... of type {type(batch)}" + ) + if self.class_weight is not None: + if len(batch) == 3: + raise ValueError( + "You cannot specify `class_weight` " + "and `sample_weight` at the same time." + ) + if len(batch) == 2: + sw = data_adapter_utils.class_weight_to_sample_weights( + batch[1], self.class_weight + ) + batch = batch + (sw,) + return batch + + def _infinite_generator(self): + for i in itertools.count(): + yield self._standardize_batch(self.py_dataset[i]) + + def _finite_generator(self): + indices = range(self.py_dataset.num_batches) + if self.shuffle: + indices = list(indices) + random.shuffle(indices) + + for i in indices: + yield self._standardize_batch(self.py_dataset[i]) + + def _infinite_enqueuer_generator(self): + self.enqueuer.start() + for batch in self.enqueuer.get(): + yield self._standardize_batch(batch) + + def _finite_enqueuer_generator(self): + self.enqueuer.start() + num_batches = self.py_dataset.num_batches + for i, batch in enumerate(self.enqueuer.get()): + yield self._standardize_batch(batch) + if i >= num_batches - 1: + self.enqueuer.stop() + return + + def _get_iterator(self): + if self.enqueuer is None: + if self.py_dataset.num_batches is None: + return self._infinite_generator() + else: + return self._finite_generator() + else: + if self.py_dataset.num_batches is None: + return self._infinite_enqueuer_generator() + else: + return self._finite_enqueuer_generator() + + def get_numpy_iterator(self): + return data_adapter_utils.get_numpy_iterator(self._get_iterator()) + + def get_jax_iterator(self): + return data_adapter_utils.get_jax_iterator(self._get_iterator()) + + def get_tf_dataset(self): + from keras.src.utils.module_utils import tensorflow as tf + + num_batches = self.py_dataset.num_batches + if self._output_signature is None: + num_samples = data_adapter_utils.NUM_BATCHES_FOR_TENSOR_SPEC + if num_batches is not None: + num_samples = min(num_samples, num_batches) + batches = [ + self._standardize_batch(self.py_dataset[i]) + for i in range(num_samples) + ] + if len(batches) == 0: + raise ValueError("The PyDataset has length 0") + self._output_signature = data_adapter_utils.get_tensor_spec(batches) + + ds = tf.data.Dataset.from_generator( + self._get_iterator, + output_signature=self._output_signature, + ) + if self.enqueuer is not None: + # The enqueuer does its own multithreading / multiprocesssing to + # prefetch items. Disable the tf.data.Dataset prefetching and + # threading as it interferes. + options = tf.data.Options() + options.autotune.enabled = False + options.threading.private_threadpool_size = 1 + ds = ds.with_options(options) + else: + ds = ds.prefetch(tf.data.AUTOTUNE) + return ds + + def get_torch_dataloader(self): + return data_adapter_utils.get_torch_dataloader(self._get_iterator()) + + def on_epoch_begin(self): + if self._within_epoch: + raise ValueError( + "`on_epoch_begin` was called twice without `on_epoch_end` " + "having been called." + ) + self._within_epoch = True + if self.enqueuer: + self.enqueuer.start() + self.py_dataset.on_epoch_begin() + + def on_epoch_end(self): + if self.enqueuer: + self.enqueuer.stop() + self.py_dataset.on_epoch_end() + self._within_epoch = False + + @property + def num_batches(self): + return self.py_dataset.num_batches + + @property + def batch_size(self): + return None + + +# Global variables to be shared across processes +_SHARED_SEQUENCES = {} +# We use a Value to provide unique id to different processes. +_SEQUENCE_COUNTER = None + + +# Because multiprocessing pools are inherently unsafe, starting from a clean +# state can be essential to avoiding deadlocks. In order to accomplish this, we +# need to be able to check on the status of Pools that we create. +_DATA_POOLS = weakref.WeakSet() +_WORKER_ID_QUEUE = None # Only created if needed. +_FORCE_THREADPOOL = False + + +def get_pool_class(use_multiprocessing): + global _FORCE_THREADPOOL + if not use_multiprocessing or _FORCE_THREADPOOL: + return multiprocessing.dummy.Pool # ThreadPool + return multiprocessing.Pool + + +def get_worker_id_queue(): + """Lazily create the queue to track worker ids.""" + global _WORKER_ID_QUEUE + if _WORKER_ID_QUEUE is None: + _WORKER_ID_QUEUE = multiprocessing.Queue() + return _WORKER_ID_QUEUE + + +def get_index(uid, i): + """Get the value from the PyDataset `uid` at index `i`. + + To allow multiple PyDatasets to be used at the same time, we use `uid` to + get a specific one. A single PyDataset would cause the validation to + overwrite the training PyDataset. + + This methods is called from worker threads. + + Args: + uid: int, PyDataset identifier + i: index + + Returns: + The value at index `i`. + """ + return _SHARED_SEQUENCES[uid][i] + + +class PyDatasetEnqueuer: + """Base class to enqueue inputs. + + The task of an Enqueuer is to use parallelism to speed up preprocessing. + This is done with processes or threads. + + Example: + + ```python + enqueuer = PyDatasetEnqueuer(...) + enqueuer.start() + datas = enqueuer.get() + for data in datas: + # Use the inputs; training, evaluating, predicting. + # ... stop sometime. + enqueuer.stop() + ``` + + The `enqueuer.get()` should be an infinite stream of data. + """ + + def __init__( + self, + py_dataset, + workers=1, + use_multiprocessing=False, + max_queue_size=10, + ): + self.py_dataset = py_dataset + + global _SEQUENCE_COUNTER + if _SEQUENCE_COUNTER is None: + try: + _SEQUENCE_COUNTER = multiprocessing.Value("i", 0) + except OSError: + # In this case the OS does not allow us to use + # multiprocessing. We resort to an int + # for enqueuer indexing. + _SEQUENCE_COUNTER = 0 + + if isinstance(_SEQUENCE_COUNTER, int): + self.uid = _SEQUENCE_COUNTER + _SEQUENCE_COUNTER += 1 + else: + # Doing Multiprocessing.Value += x is not process-safe. + with _SEQUENCE_COUNTER.get_lock(): + self.uid = _SEQUENCE_COUNTER.value + _SEQUENCE_COUNTER.value += 1 + + self.ready_queue = queue.Queue() + self.future_queue = queue.Queue(max_queue_size) + self.running = False + self.start_stop_lock = threading.Lock() + self.run_thread = None + if use_multiprocessing: + self.executor_fn = self._get_executor_init(workers) + else: + # We do not need the init since it's threads. + self.executor_fn = lambda _: get_pool_class(False)(workers) + + def is_running(self): + """Whether the enqueuer is running. + + This method is thread safe and called from many threads. + + Returns: boolean indicating whether this enqueuer is running. + """ + return self.running + + def start(self): + """Starts the handler's workers. + + This method is thread safe but is called from the main thread. + It is safe to call this method multiple times, extra calls are ignored. + """ + with self.start_stop_lock: + if self.running: + return + self.running = True + self.run_thread = threading.Thread(target=self._run) + self.run_thread.name = f"Worker_{self.uid}" + self.run_thread.daemon = True + self.run_thread.start() + + def stop(self, drain_queue_and_join=True): + """Stops running threads and wait for them to exit, if necessary. + + This method is thread safe and is called from various threads. Note that + the `drain_queue_and_join` argument must be set correctly. + It is safe to call this method multiple times, extra calls are ignored. + + Args: + drain_queue_and_join: set to True to drain the queue of pending + items and wait for the worker thread to complete. Set to False + if invoked from a worker thread to avoid deadlocks. Note that + setting this to False means this enqueuer won't be reused. + """ + with self.start_stop_lock: + if not self.running: + return + self.running = False + + if drain_queue_and_join: + # Drain the `future_queue` and put items in `ready_queue` for + # the next run. + while True: + try: + value = self.future_queue.get(block=True, timeout=0.1) + if isinstance(value, Exception): + raise value # Propagate exception from other thread + inputs = value.get() + self.future_queue.task_done() + if inputs is not None: + self.ready_queue.put(inputs) + except queue.Empty: + break + self.run_thread.join() + + self.run_thread = None + _SHARED_SEQUENCES[self.uid] = None + + def _send_py_dataset(self): + """Sends current Iterable to all workers.""" + # For new processes that may spawn + _SHARED_SEQUENCES[self.uid] = self.py_dataset + + def __del__(self): + self.stop(drain_queue_and_join=False) + + def _run(self): + """Submits request to the executor and queue the `Future` objects.""" + raise NotImplementedError + + def _get_executor_init(self, workers): + """Gets the Pool initializer for multiprocessing. + + Args: + workers: Number of workers. + + Returns: + Function, a Function to initialize the pool + """ + raise NotImplementedError + + def get(self): + """Creates a generator to extract data from the queue. + + Skip the data if it is `None`. + + This method is called from the main thread. + + Yields: + The next element in the queue, i.e. a tuple + `(inputs, targets)` or + `(inputs, targets, sample_weights)`. + """ + raise NotImplementedError + + +class OrderedEnqueuer(PyDatasetEnqueuer): + """Builds a Enqueuer from a PyDataset. + + Args: + py_dataset: A `keras.utils.PyDataset` object. + use_multiprocessing: use multiprocessing if True, otherwise threading + shuffle: whether to shuffle the data at the beginning of each epoch + """ + + def __init__( + self, + py_dataset, + workers=1, + use_multiprocessing=False, + max_queue_size=10, + shuffle=False, + ): + super().__init__( + py_dataset, workers, use_multiprocessing, max_queue_size + ) + self.shuffle = shuffle + if self.py_dataset.num_batches is None: + # For infinite datasets, `self.indices` is created here once for all + # so that subsequent runs resume from where they stopped. + self.indices = itertools.count() + + def _get_executor_init(self, workers): + """Gets the Pool initializer for multiprocessing. + + Args: + workers: Number of workers. + + Returns: + Function, a Function to initialize the pool + """ + + def pool_fn(seqs): + pool = get_pool_class(True)( + workers, + initializer=init_pool_generator, + initargs=(seqs, None, get_worker_id_queue()), + ) + _DATA_POOLS.add(pool) + return pool + + return pool_fn + + def _run(self): + """Submits request to the executor and queue the `Future` objects. + + This method is the run method of worker threads. + """ + try: + if self.py_dataset.num_batches is not None: + # For finite datasets, `self.indices` is created here so that + # shuffling creates different a order each time. + indices = range(self.py_dataset.num_batches) + if self.shuffle: + indices = list(indices) + random.shuffle(indices) + self.indices = iter(indices) + self._send_py_dataset() # Share the initial py_dataset + + with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor: + while self.is_running(): + try: + i = next(self.indices) + self.future_queue.put( + executor.apply_async(get_index, (self.uid, i)), + block=True, + ) + except StopIteration: + break + except Exception as e: + self.future_queue.put(e) # Report exception + + def get(self): + """Creates a generator to extract data from the queue. + + Skip the data if it is `None`. + + This method is called from the main thread. + + Yields: + The next element in the queue, i.e. a tuple + `(inputs, targets)` or + `(inputs, targets, sample_weights)`. + """ + while self.is_running(): + try: + inputs = self.ready_queue.get(block=False) + yield inputs + continue # Retry the ready_queue + except queue.Empty: + pass + + try: + value = self.future_queue.get(block=True, timeout=5) + self.future_queue.task_done() + if isinstance(value, Exception): + raise value # Propagate exception from other thread + inputs = value.get() + if inputs is not None: + yield inputs + except queue.Empty: + pass + except Exception as e: + self.stop(drain_queue_and_join=True) + raise e + + # Note that it is ok to poll the iterator after the initial `start`, + # which may happen before the first `on_epoch_begin`. But it's not ok to + # poll after `on_epoch_end`. + raise ValueError( + "Iterator called after `on_epoch_end` or before `on_epoch_begin`." + ) + + +def init_pool_generator(gens, random_seed=None, id_queue=None): + """Initializer function for pool workers. + + Args: + gens: State which should be made available to worker processes. + random_seed: An optional value with which to seed child processes. + id_queue: A multiprocessing Queue of worker ids. + This is used to indicate that a worker process + was created by Keras. + """ + global _SHARED_SEQUENCES + _SHARED_SEQUENCES = gens + + worker_proc = multiprocessing.current_process() + + # name isn't used for anything, but setting a more descriptive name is + # helpful when diagnosing orphaned processes. + worker_proc.name = f"Keras_worker_{worker_proc.name}" + + if random_seed is not None: + np.random.seed(random_seed + worker_proc.ident) + + if id_queue is not None: + # If a worker dies during init, the pool will just create a replacement. + id_queue.put(worker_proc.ident, block=True, timeout=0.1) diff --git a/keras/src/trainers/data_adapters/py_dataset_adapter_test.py b/keras/src/trainers/data_adapters/py_dataset_adapter_test.py new file mode 100644 index 000000000000..8cdd5befb3a8 --- /dev/null +++ b/keras/src/trainers/data_adapters/py_dataset_adapter_test.py @@ -0,0 +1,453 @@ +import math +import time + +import jax +import numpy as np +import pytest +import tensorflow as tf +import torch +from absl.testing import parameterized + +from keras.src import backend +from keras.src import testing +from keras.src.testing.test_utils import named_product +from keras.src.trainers.data_adapters import py_dataset_adapter +from keras.src.utils.rng_utils import set_random_seed + + +class ExamplePyDataset(py_dataset_adapter.PyDataset): + def __init__( + self, + x_set, + y_set, + sample_weight=None, + batch_size=32, + delay=0, + infinite=False, + **kwargs, + ): + super().__init__(**kwargs) + self.x, self.y = x_set, y_set + self.batch_size = batch_size + self.sample_weight = sample_weight + self.delay = delay + self.infinite = infinite + + @property + def num_batches(self): + if self.infinite: + return None + return math.ceil(len(self.x) / self.batch_size) + + def __getitem__(self, idx): + # Create artificial delay to test multiprocessing + time.sleep(self.delay) + + if self.infinite: + idx = idx % math.ceil(len(self.x) / self.batch_size) + # Return x, y for batch idx. + low = idx * self.batch_size + # Cap upper bound at array length; the last batch may be smaller + # if the total number of items is not a multiple of batch size. + high = min(low + self.batch_size, len(self.x)) + batch_x = self.x[low:high] + batch_y = self.y[low:high] + if self.sample_weight is not None: + return batch_x, batch_y, self.sample_weight[low:high] + return batch_x, batch_y + + +class DictPyDataset(py_dataset_adapter.PyDataset): + def __init__(self, inputs, batch_size=32, **kwargs): + super().__init__(**kwargs) + self.inputs = inputs + self.batch_size = batch_size + + @property + def num_batches(self): + return math.ceil(len(self.inputs["x"]) / self.batch_size) + + def __getitem__(self, idx): + # Return x, y for batch idx. + low = idx * self.batch_size + # Cap upper bound at array length; the last batch may be smaller + # if the total number of items is not a multiple of batch size. + high = min(low + self.batch_size, len(self.inputs["x"])) + batch_x = self.inputs["x"][low:high] + batch_y = self.inputs["y"][low:high] + batch = {"x": batch_x, "y": batch_y} + return batch + + +class ExceptionPyDataset(py_dataset_adapter.PyDataset): + @property + def num_batches(self): + return 4 + + def __getitem__(self, index): + if index < 2: + return ( + np.random.random((8, 4)).astype("float32"), + np.random.random((8, 2)).astype("float32"), + ) + raise ValueError("Expected exception") + + +@pytest.mark.skipif(testing.tensorflow_uses_gpu(), reason="Flaky on GPU") +class PyDatasetAdapterTest(testing.TestCase): + @parameterized.named_parameters( + named_product( + [ + { + "testcase_name": "multiprocessing", + "workers": 2, + "use_multiprocessing": True, + "max_queue_size": 10, + "dataset_type": "np", + }, + { + "testcase_name": "multithreading", + "workers": 2, + "use_multiprocessing": False, + "max_queue_size": 10, + "dataset_type": "np", + }, + { + "testcase_name": "single_np", + "dataset_type": "np", + }, + { + "testcase_name": "single_tf", + "dataset_type": "tf", + }, + { + "testcase_name": "single_jax", + "dataset_type": "jax", + }, + { + "testcase_name": "single_torch", + "dataset_type": "torch", + }, + ], + infinite=[True, False], + shuffle=[True, False], + ) + ) + def test_basic_flow( + self, + shuffle, + dataset_type, + infinite, + workers=0, + use_multiprocessing=False, + max_queue_size=0, + ): + if use_multiprocessing and shuffle: + pytest.skip("Starting processes is slow, test fewer variants") + + set_random_seed(1337) + x = np.random.random((64, 4)).astype("float32") + y = np.array([[i, i] for i in range(64)], dtype="float32") + CPU_DEVICES = { + "tensorflow": "CPU:0", + "jax": "cpu:0", + "torch": "cpu", + "numpy": "cpu", + } + with backend.device(CPU_DEVICES[backend.backend()]): + if dataset_type == "tf": + x, y = tf.constant(x), tf.constant(y) + elif dataset_type == "jax": + x, y = jax.numpy.array(x), jax.numpy.array(y) + elif dataset_type == "torch": + x, y = torch.as_tensor(x), torch.as_tensor(y) + py_dataset = ExamplePyDataset( + x, + y, + batch_size=16, + workers=workers, + use_multiprocessing=use_multiprocessing, + max_queue_size=max_queue_size, + infinite=infinite, + ) + adapter = py_dataset_adapter.PyDatasetAdapter( + py_dataset, shuffle=shuffle + ) + + if backend.backend() == "numpy": + it = adapter.get_numpy_iterator() + expected_class = np.ndarray + elif backend.backend() == "tensorflow": + it = adapter.get_tf_dataset() + expected_class = tf.Tensor + elif backend.backend() == "jax": + it = adapter.get_jax_iterator() + expected_class = jax.Array if dataset_type == "jax" else np.ndarray + elif backend.backend() == "torch": + it = adapter.get_torch_dataloader() + expected_class = torch.Tensor + + sample_order = [] + adapter.on_epoch_begin() + for batch in it: + self.assertEqual(len(batch), 2) + bx, by = batch + self.assertIsInstance(bx, expected_class) + self.assertIsInstance(by, expected_class) + self.assertEqual(bx.dtype, by.dtype) + self.assertContainsExactSubsequence(str(bx.dtype), "float32") + self.assertEqual(bx.shape, (16, 4)) + self.assertEqual(by.shape, (16, 2)) + for i in range(by.shape[0]): + sample_order.append(by[i, 0]) + if infinite: + if len(sample_order) == 64: + adapter.on_epoch_end() + adapter.on_epoch_begin() + elif len(sample_order) >= 128: + break + adapter.on_epoch_end() + + expected_order = list(range(64)) + if infinite: + self.assertAllClose(sample_order, expected_order + expected_order) + elif shuffle: + self.assertNotAllClose(sample_order, expected_order) + self.assertAllClose(sorted(sample_order), expected_order) + else: + self.assertAllClose(sample_order, expected_order) + + # TODO: test sample weights + # TODO: test inference mode (single output) + + def test_class_weight(self): + x = np.random.randint(1, 100, (4, 5)) + y = np.array([0, 1, 2, 1]) + class_w = {0: 2, 1: 1, 2: 3} + py_dataset = ExamplePyDataset(x, y, batch_size=2) + adapter = py_dataset_adapter.PyDatasetAdapter( + py_dataset, shuffle=False, class_weight=class_w + ) + if backend.backend() == "numpy": + gen = adapter.get_numpy_iterator() + elif backend.backend() == "tensorflow": + gen = adapter.get_tf_dataset() + elif backend.backend() == "jax": + gen = adapter.get_jax_iterator() + elif backend.backend() == "torch": + gen = adapter.get_torch_dataloader() + + for index, batch in enumerate(gen): + # Batch is a tuple of (x, y, class_weight) + self.assertLen(batch, 3) + batch = [backend.convert_to_numpy(x) for x in batch] + # Let's verify the data and class weights match for each element + # of the batch (2 elements in each batch) + for sub_elem in range(2): + self.assertAllEqual(batch[0][sub_elem], x[index * 2 + sub_elem]) + self.assertEqual(batch[1][sub_elem], y[index * 2 + sub_elem]) + class_key = np.int32(batch[1][sub_elem]) + self.assertEqual(batch[2][sub_elem], class_w[class_key]) + + self.assertEqual(index, 1) # 2 batches + + def test_speedup(self): + x = np.random.random((40, 4)) + y = np.random.random((40, 2)) + + no_speedup_py_dataset = ExamplePyDataset( + x, + y, + batch_size=4, + delay=0.2, + ) + adapter = py_dataset_adapter.PyDatasetAdapter( + no_speedup_py_dataset, shuffle=False + ) + gen = adapter.get_numpy_iterator() + t0 = time.time() + for batch in gen: + pass + no_speedup_time = time.time() - t0 + + speedup_py_dataset = ExamplePyDataset( + x, + y, + batch_size=4, + workers=4, + # TODO: the github actions runner may have performance issue with + # multiprocessing + # use_multiprocessing=True, + max_queue_size=8, + delay=0.2, + ) + adapter = py_dataset_adapter.PyDatasetAdapter( + speedup_py_dataset, shuffle=False + ) + gen = adapter.get_numpy_iterator() + t0 = time.time() + for batch in gen: + pass + speedup_time = time.time() - t0 + + self.assertLess(speedup_time, no_speedup_time) + + def test_dict_inputs(self): + inputs = { + "x": np.random.random((40, 4)), + "y": np.random.random((40, 2)), + } + py_dataset = DictPyDataset(inputs, batch_size=4) + adapter = py_dataset_adapter.PyDatasetAdapter(py_dataset, shuffle=False) + gen = adapter.get_numpy_iterator() + for batch in gen: + self.assertEqual(len(batch), 2) + bx, by = batch["x"], batch["y"] + self.assertIsInstance(bx, np.ndarray) + self.assertIsInstance(by, np.ndarray) + self.assertEqual(bx.dtype, by.dtype) + self.assertEqual(bx.shape, (4, 4)) + self.assertEqual(by.shape, (4, 2)) + + ds = adapter.get_tf_dataset() + for batch in ds: + self.assertEqual(len(batch), 2) + bx, by = batch["x"], batch["y"] + self.assertIsInstance(bx, tf.Tensor) + self.assertIsInstance(by, tf.Tensor) + self.assertEqual(bx.dtype, by.dtype) + self.assertEqual(tuple(bx.shape), (4, 4)) + self.assertEqual(tuple(by.shape), (4, 2)) + + def test_with_different_shapes(self): + class TestPyDataset(py_dataset_adapter.PyDataset): + @property + def num_batches(self): + return 3 + + def __getitem__(self, idx): + if idx == 0: + return np.ones([16, 4], "float32"), np.ones( + [16, 2], "float32" + ) + if idx == 1: + return np.ones([16, 5], "float32"), np.ones( + [16, 2], "float32" + ) + else: + return np.ones([2, 6], "float32"), np.ones( + [2, 2], "float32" + ) + + adapter = py_dataset_adapter.PyDatasetAdapter( + TestPyDataset(), shuffle=False + ) + + if backend.backend() == "numpy": + it = adapter.get_numpy_iterator() + elif backend.backend() == "tensorflow": + it = adapter.get_tf_dataset() + elif backend.backend() == "jax": + it = adapter.get_jax_iterator() + elif backend.backend() == "torch": + it = adapter.get_torch_dataloader() + + for i, batch in enumerate(it): + self.assertEqual(len(batch), 2) + bx, by = batch + self.assertEqual(bx.dtype, by.dtype) + self.assertContainsExactSubsequence(str(bx.dtype), "float32") + if i == 0: + self.assertEqual(bx.shape, (16, 4)) + self.assertEqual(by.shape, (16, 2)) + elif i == 1: + self.assertEqual(bx.shape, (16, 5)) + self.assertEqual(by.shape, (16, 2)) + else: + self.assertEqual(bx.shape, (2, 6)) + self.assertEqual(by.shape, (2, 2)) + + @parameterized.named_parameters( + [ + { + "testcase_name": "multiprocessing", + "workers": 2, + "use_multiprocessing": True, + "max_queue_size": 10, + }, + { + "testcase_name": "multithreading", + "workers": 2, + "max_queue_size": 10, + }, + { + "testcase_name": "single", + }, + ] + ) + def test_exception_reported( + self, + workers=0, + use_multiprocessing=False, + max_queue_size=0, + ): + if backend.backend() == "jax" and use_multiprocessing is True: + self.skipTest( + "The CI failed for an unknown reason with " + "`use_multiprocessing=True` in the jax backend" + ) + dataset = ExceptionPyDataset( + workers=workers, + use_multiprocessing=use_multiprocessing, + max_queue_size=max_queue_size, + ) + adapter = py_dataset_adapter.PyDatasetAdapter(dataset, shuffle=False) + + expected_exception_class = ValueError + if backend.backend() == "numpy": + it = adapter.get_numpy_iterator() + elif backend.backend() == "tensorflow": + it = adapter.get_tf_dataset() + # tf.data wraps the exception + expected_exception_class = tf.errors.InvalidArgumentError + elif backend.backend() == "jax": + it = adapter.get_jax_iterator() + elif backend.backend() == "torch": + it = adapter.get_torch_dataloader() + + it = iter(it) + next(it) + next(it) + with self.assertRaisesRegex( + expected_exception_class, "Expected exception" + ): + next(it) + + def test_iterate_finite(self): + py_dataset = ExamplePyDataset( + np.ones((6, 11), dtype="int32"), + np.zeros((6, 11), dtype="int32"), + batch_size=2, + ) + batches = [batch for batch in py_dataset] + self.assertLen(batches, 3) + + def test_iterate_infinite_with_none_num_batches(self): + py_dataset = ExamplePyDataset( + np.ones((6, 11), dtype="int32"), + np.zeros((6, 11), dtype="int32"), + batch_size=2, + infinite=True, + ) + for index, _ in enumerate(py_dataset): + if index >= 10: + break + + def test_iterate_infinite_with_no_len(self): + class NoLenDataset(py_dataset_adapter.PyDataset): + def __getitem__(self, idx): + yield np.ones((2, 11), dtype="int32") + + for index, _ in enumerate(NoLenDataset()): + if index >= 10: + break diff --git a/keras/src/trainers/data_adapters/tf_dataset_adapter.py b/keras/src/trainers/data_adapters/tf_dataset_adapter.py new file mode 100644 index 000000000000..492deb764c3e --- /dev/null +++ b/keras/src/trainers/data_adapters/tf_dataset_adapter.py @@ -0,0 +1,147 @@ +from keras.src import tree +from keras.src.trainers.data_adapters import data_adapter_utils +from keras.src.trainers.data_adapters.data_adapter import DataAdapter + + +class TFDatasetAdapter(DataAdapter): + """Adapter that handles `tf.data.Dataset`.""" + + def __init__(self, dataset, class_weight=None, distribution=None): + """Initialize the TFDatasetAdapter. + + Args: + dataset: The input `tf.data.Dataset` instance. + class_weight: A map where the keys are integer class ids and values + are the class weights, e.g. `{0: 0.2, 1: 0.6, 2: 0.3}`. + distribution: A `keras.distribution.Distribution` instance. Used to + shard the input dataset into per worker/process dataset + instance. + """ + from keras.src.utils.module_utils import tensorflow as tf + + if not isinstance( + dataset, (tf.data.Dataset, tf.distribute.DistributedDataset) + ): + raise ValueError( + "Expected argument `dataset` to be a tf.data.Dataset. " + f"Received: {dataset}" + ) + if class_weight is not None: + dataset = dataset.map( + make_class_weight_map_fn(class_weight) + ).prefetch(tf.data.AUTOTUNE) + if distribution is not None: + dataset = distribution.distribute_dataset(dataset) + self._dataset = dataset + + def get_numpy_iterator(self): + from keras.src.backend.tensorflow.core import convert_to_numpy + + for batch in self._dataset: + yield tree.map_structure( + convert_to_numpy, batch, none_is_leaf=False + ) + + def get_jax_iterator(self): + from keras.src.backend.tensorflow.core import convert_to_numpy + from keras.src.utils.module_utils import tensorflow as tf + + def convert_to_jax(x): + if isinstance(x, tf.SparseTensor): + return data_adapter_utils.tf_sparse_to_jax_sparse(x) + else: + # We use numpy as an intermediary because it is faster. + return convert_to_numpy(x) + + for batch in self._dataset: + yield tree.map_structure(convert_to_jax, batch, none_is_leaf=False) + + def get_tf_dataset(self): + return self._dataset + + def get_torch_dataloader(self): + return data_adapter_utils.get_torch_dataloader(self._dataset) + + @property + def builtin_prefetch(self): + return True + + @property + def num_batches(self): + cardinality = self._dataset.cardinality + if callable(cardinality): + # `dataset.cardinality` is normally expected to be a callable. + cardinality = int(self._dataset.cardinality()) + else: + # However, in the case of `DistributedDataset`, it's a np.int64. + cardinality = int(cardinality) + # Return None for Unknown and Infinite cardinality datasets + if cardinality < 0: + return None + return cardinality + + @property + def batch_size(self): + first_element_spec = tree.flatten(self._dataset.element_spec)[0] + return first_element_spec.shape[0] + + @property + def has_partial_batch(self): + return None + + @property + def partial_batch_size(self): + return None + + +def make_class_weight_map_fn(class_weight): + """Applies class weighting to a `Dataset`. + + The `Dataset` is assumed to be in format `(x, y)` or `(x, y, sw)`, where + `y` must be a single `Tensor`. + + Args: + class_weight: A map where the keys are integer class ids and values are + the class weights, e.g. `{0: 0.2, 1: 0.6, 2: 0.3}` + + Returns: + A function that can be used with `tf.data.Dataset.map` to apply class + weighting. + """ + from keras.src.utils.module_utils import tensorflow as tf + + class_weight_tensor = tf.convert_to_tensor( + [ + class_weight.get(int(c), 1.0) + for c in range(max(class_weight.keys()) + 1) + ] + ) + + def class_weights_map_fn(*data): + """Convert `class_weight` to `sample_weight`.""" + x, y, sw = data_adapter_utils.unpack_x_y_sample_weight(data) + if sw is not None: + raise ValueError( + "You cannot `class_weight` and `sample_weight` " + "at the same time." + ) + if tree.is_nested(y): + raise ValueError( + "`class_weight` is only supported for Models with a single " + "output." + ) + + if y.shape.rank >= 2: + y_classes = tf.__internal__.smart_cond.smart_cond( + tf.shape(y)[-1] > 1, + lambda: tf.argmax(y, axis=-1, output_type=tf.int32), + lambda: tf.cast(tf.round(tf.squeeze(y, axis=-1)), tf.int32), + ) + else: + # Special casing for rank 1, where we can guarantee sparse encoding. + y_classes = tf.cast(tf.round(y), tf.int32) + + cw = tf.gather(class_weight_tensor, y_classes) + return x, y, cw + + return class_weights_map_fn diff --git a/keras/src/trainers/data_adapters/tf_dataset_adapter_test.py b/keras/src/trainers/data_adapters/tf_dataset_adapter_test.py new file mode 100644 index 000000000000..c4889f4677f0 --- /dev/null +++ b/keras/src/trainers/data_adapters/tf_dataset_adapter_test.py @@ -0,0 +1,357 @@ +from unittest import mock + +import jax +import numpy as np +import pytest +import tensorflow as tf +import torch + +from keras.src import Sequential +from keras.src import backend +from keras.src import layers +from keras.src import testing +from keras.src.trainers.data_adapters import tf_dataset_adapter + + +class TestTFDatasetAdapter(testing.TestCase): + def test_basic_flow(self): + x = tf.random.normal((34, 4)) + y = tf.random.normal((34, 2)) + base_ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(16) + adapter = tf_dataset_adapter.TFDatasetAdapter(base_ds) + + self.assertEqual(adapter.num_batches, 3) + self.assertEqual(adapter.batch_size, None) + self.assertEqual(adapter.has_partial_batch, None) + self.assertEqual(adapter.partial_batch_size, None) + + if backend.backend() == "numpy": + it = adapter.get_numpy_iterator() + expected_class = np.ndarray + elif backend.backend() == "tensorflow": + it = adapter.get_tf_dataset() + expected_class = tf.Tensor + elif backend.backend() == "jax": + it = adapter.get_jax_iterator() + expected_class = np.ndarray + elif backend.backend() == "torch": + it = adapter.get_torch_dataloader() + expected_class = torch.Tensor + + for i, batch in enumerate(it): + self.assertEqual(len(batch), 2) + bx, by = batch + self.assertIsInstance(bx, expected_class) + self.assertIsInstance(by, expected_class) + self.assertEqual(bx.dtype, by.dtype) + self.assertContainsExactSubsequence(str(bx.dtype), "float32") + if i < 2: + self.assertEqual(bx.shape, (16, 4)) + self.assertEqual(by.shape, (16, 2)) + else: + self.assertEqual(bx.shape, (2, 4)) + self.assertEqual(by.shape, (2, 2)) + + def _test_class_weights(self, target_encoding="int"): + x = np.random.random((4, 2)) + if target_encoding == "int": + y = np.array([[0], [1], [2], [3]], dtype="int64") + else: + y = np.array( + [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]], + dtype="float32", + ) + + class_weight = { + 0: 0.1, + 1: 0.2, + 2: 0.3, + 3: 0.4, + } + base_ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(16) + adapter = tf_dataset_adapter.TFDatasetAdapter( + base_ds, class_weight=class_weight + ) + gen = adapter.get_numpy_iterator() + for batch in gen: + self.assertEqual(len(batch), 3) + _, _, bw = batch + self.assertAllClose(bw, [0.1, 0.2, 0.3, 0.4]) + + def test_class_weights_int_targets(self): + self._test_class_weights(target_encoding="int") + + def test_class_weights_categorical_targets(self): + self._test_class_weights(target_encoding="categorical") + + def test_builtin_prefetch(self): + dataset = tf.data.Dataset.range(42) + adapter = tf_dataset_adapter.TFDatasetAdapter(dataset) + self.assertTrue(adapter.builtin_prefetch) + + def test_num_batches(self): + dataset = tf.data.Dataset.range(42) + cardinality = int(dataset.cardinality()) + self.assertEqual(cardinality, 42) + adapter = tf_dataset_adapter.TFDatasetAdapter(dataset) + self.assertEqual(adapter.num_batches, 42) + + # Test for Infinite Cardinality + dataset = tf.data.Dataset.range(42) + dataset = dataset.repeat() + cardinality = int(dataset.cardinality()) + self.assertEqual(cardinality, tf.data.INFINITE_CARDINALITY) + adapter = tf_dataset_adapter.TFDatasetAdapter(dataset) + self.assertIsNone(adapter.num_batches) + + # Test for Unknown Cardinality + dataset = dataset.filter(lambda x: True) + cardinality = int(dataset.cardinality()) + self.assertEqual(cardinality, tf.data.UNKNOWN_CARDINALITY) + adapter = tf_dataset_adapter.TFDatasetAdapter(dataset) + self.assertIsNone(adapter.num_batches) + + def test_invalid_dataset_type(self): + with self.assertRaisesRegex( + ValueError, "Expected argument `dataset` to be a tf.data.Dataset" + ): + invalid_data = "This is not a tf.data.Dataset" + tf_dataset_adapter.TFDatasetAdapter(invalid_data) + + def test_class_weight_and_sample_weight_together(self): + x = np.random.random((4, 2)) + y = np.array([[0], [1], [2], [3]], dtype="int64") + sw = np.array([0.5, 0.5, 0.5, 0.5]) + base_ds = tf.data.Dataset.from_tensor_slices((x, y, sw)).batch(16) + class_weight = {0: 0.1, 1: 0.2, 2: 0.3, 3: 0.4} + + with self.assertRaisesRegex( + ValueError, + "You cannot `class_weight` and `sample_weight` at the same time.", + ): + tf_dataset_adapter.TFDatasetAdapter( + base_ds, class_weight=class_weight + ) + + def test_different_y_shapes_with_class_weight(self): + x = np.random.random((4, 2)) + y = np.array( + [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]], + dtype="float32", + ) + base_ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(16) + class_weight = {0: 0.1, 1: 0.2, 2: 0.3, 3: 0.4} + adapter = tf_dataset_adapter.TFDatasetAdapter( + base_ds, class_weight=class_weight + ) + gen = adapter.get_numpy_iterator() + for batch in gen: + _, _, bw = batch + self.assertAllClose(bw, [0.1, 0.2, 0.3, 0.4]) + + y_sparse = np.array([0, 1, 2, 3], dtype="int64") + base_ds = tf.data.Dataset.from_tensor_slices((x, y_sparse)).batch(16) + adapter = tf_dataset_adapter.TFDatasetAdapter( + base_ds, class_weight=class_weight + ) + gen = adapter.get_numpy_iterator() + for batch in gen: + _, _, bw = batch + self.assertAllClose(bw, [0.1, 0.2, 0.3, 0.4]) + + def test_nested_y_with_class_weight(self): + x = np.random.random((4, 2)) + + # Define two target outputs, y1 and y2, for the dataset + y1 = np.array([0, 1, 2, 3], dtype="int64") + y2 = np.array([0, 1, 2, 3], dtype="int64") + + # Create a tf.data Dataset from the input data and two target outputs + base_ds = tf.data.Dataset.from_tensor_slices((x, (y1, y2))).batch(16) + + # Define class weights for potential classes in the output + class_weight = {0: 0.1, 1: 0.2, 2: 0.3, 3: 0.4} + + with self.assertRaisesRegex( + ValueError, + "`class_weight` is only supported for Models with a single output.", + ): + tf_dataset_adapter.TFDatasetAdapter( + base_ds, class_weight=class_weight + ) + + def test_class_weights_map_fn_with_sample_weight(self): + class_weight = {0: 0.1, 1: 0.2, 2: 0.3, 3: 0.4} + class_weights_map_fn = tf_dataset_adapter.make_class_weight_map_fn( + class_weight + ) + + x = np.array([[0.5, 0.5], [0.5, 0.5]]) + y = np.array([[1, 0], [0, 1]]) + sw = np.array([1.0, 1.0]) + + with self.assertRaisesRegex( + ValueError, + "You cannot `class_weight` and `sample_weight` at the same time.", + ): + class_weights_map_fn(x, y, sw) + + def test_class_weights_map_fn_nested_y(self): + class_weight = {0: 0.1, 1: 0.2, 2: 0.3, 3: 0.4} + class_weights_map_fn = tf_dataset_adapter.make_class_weight_map_fn( + class_weight + ) + + x = np.array([[0.5, 0.5]]) + y1 = np.array([1]) + y2 = np.array([0]) + + with self.assertRaisesRegex( + ValueError, + "`class_weight` is only supported for Models with a single output.", + ): + class_weights_map_fn(x, (y1, y2)) + + def test_distribute_dataset(self): + x = tf.random.normal((34, 4)) + y = tf.random.normal((34, 2)) + base_ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(16) + + data_distribution = mock.Mock() + # Mimic that there are 2 worker, and each of the worker will get batch + # size of 8 + data_distribution.distribute_dataset = mock.MagicMock( + return_value=base_ds.rebatch(8).shard(2, index=0) + ) + + adapter = tf_dataset_adapter.TFDatasetAdapter( + base_ds, distribution=data_distribution + ) + + self.assertEqual(adapter.num_batches, None) + self.assertEqual(adapter.batch_size, None) + self.assertEqual(adapter.has_partial_batch, None) + self.assertEqual(adapter.partial_batch_size, None) + + gen = adapter.get_numpy_iterator() + for i, batch in enumerate(gen): + self.assertEqual(len(batch), 2) + bx, by = batch + self.assertIsInstance(bx, np.ndarray) + self.assertIsInstance(by, np.ndarray) + self.assertEqual(bx.dtype, by.dtype) + self.assertEqual(bx.dtype, "float32") + if i < 2: + self.assertEqual(bx.shape, (8, 4)) + self.assertEqual(by.shape, (8, 2)) + else: + self.assertEqual(bx.shape, (2, 4)) + self.assertEqual(by.shape, (2, 2)) + ds = adapter.get_tf_dataset() + for i, batch in enumerate(ds): + self.assertEqual(len(batch), 2) + bx, by = batch + self.assertIsInstance(bx, tf.Tensor) + self.assertIsInstance(by, tf.Tensor) + self.assertEqual(bx.dtype, by.dtype) + self.assertEqual(bx.dtype, "float32") + if i < 2: + self.assertEqual(tuple(bx.shape), (8, 4)) + self.assertEqual(tuple(by.shape), (8, 2)) + else: + self.assertEqual(tuple(bx.shape), (2, 4)) + self.assertEqual(tuple(by.shape), (2, 2)) + + @pytest.mark.skipif( + not backend.SUPPORTS_SPARSE_TENSORS and backend.backend() != "numpy", + reason="Backend does not support sparse tensors", + ) + def test_tf_sparse_tensors(self): + x = tf.SparseTensor( + indices=[[0, 0], [1, 2]], values=[1.0, 2.0], dense_shape=(2, 4) + ) + y = tf.SparseTensor( + indices=[[0, 0], [1, 1]], values=[3.0, 4.0], dense_shape=(2, 2) + ) + base_ds = tf.data.Dataset.from_tensors((x, y)) + adapter = tf_dataset_adapter.TFDatasetAdapter(base_ds) + + if backend.backend() == "numpy": + it = adapter.get_numpy_iterator() + expected_class = np.ndarray + elif backend.backend() == "tensorflow": + it = adapter.get_tf_dataset() + expected_class = tf.SparseTensor + elif backend.backend() == "jax": + it = adapter.get_jax_iterator() + expected_class = jax.experimental.sparse.BCOO + + for batch in it: + self.assertEqual(len(batch), 2) + bx, by = batch + self.assertIsInstance(bx, expected_class) + self.assertIsInstance(by, expected_class) + self.assertEqual(bx.shape, (2, 4)) + self.assertEqual(by.shape, (2, 2)) + + def test_distributed_datasets_from_function_adapter_properties(self): + strategy = tf.distribute.MirroredStrategy(["CPU:0"]) + + def dataset_fn(input_context): + batch_size = input_context.get_per_replica_batch_size( + global_batch_size=2 + ) + x = tf.random.uniform((32, 4)) + y = tf.random.uniform((32, 2)) + return tf.data.Dataset.from_tensor_slices((x, y)).batch(batch_size) + + dist_dataset = strategy.distribute_datasets_from_function(dataset_fn) + adapter = tf_dataset_adapter.TFDatasetAdapter(dist_dataset) + self.assertEqual(adapter.num_batches, 16) + self.assertIsNone(adapter.batch_size) + self.assertIsNone(adapter.has_partial_batch) + self.assertIsNone(adapter.partial_batch_size) + + if backend.backend() == "numpy": + it = adapter.get_numpy_iterator() + expected_class = np.ndarray + elif backend.backend() == "tensorflow": + it = adapter.get_tf_dataset() + expected_class = tf.Tensor + elif backend.backend() == "jax": + it = adapter.get_jax_iterator() + expected_class = np.ndarray + elif backend.backend() == "torch": + it = adapter.get_torch_dataloader() + expected_class = torch.Tensor + + batch_count = 0 + for batch in it: + batch_count += 1 + self.assertEqual(len(batch), 2) + data, labels = batch + self.assertIsInstance(data, expected_class) + self.assertIsInstance(labels, expected_class) + self.assertEqual(data.shape, (2, 4)) + self.assertEqual(labels.shape, (2, 2)) + + self.assertEqual(batch_count, 16) + + @pytest.mark.requires_trainable_backend + def test_distributed_datasets_from_function_model_integration(self): + strategy = tf.distribute.MirroredStrategy(["CPU:0"]) + + def dataset_fn(input_context): + batch_size = input_context.get_per_replica_batch_size( + global_batch_size=2 + ) + x = tf.random.uniform((4, 1)) + y = tf.random.uniform((4, 2)) + return tf.data.Dataset.from_tensor_slices((x, y)).batch(batch_size) + + dist_dataset = strategy.distribute_datasets_from_function(dataset_fn) + + model = Sequential([layers.Dense(2, input_shape=(1,))]) + model.compile(optimizer="adam", loss="mse") + history = model.fit(dist_dataset, epochs=1) + self.assertIn("loss", history.history) diff --git a/keras/src/trainers/data_adapters/torch_data_loader_adapter.py b/keras/src/trainers/data_adapters/torch_data_loader_adapter.py new file mode 100644 index 000000000000..f0b2f524f4dd --- /dev/null +++ b/keras/src/trainers/data_adapters/torch_data_loader_adapter.py @@ -0,0 +1,93 @@ +import itertools + +import numpy as np + +from keras.src import tree +from keras.src.trainers.data_adapters import data_adapter_utils +from keras.src.trainers.data_adapters.data_adapter import DataAdapter + + +class TorchDataLoaderAdapter(DataAdapter): + """Adapter that handles `torch.utils.data.DataLoader`.""" + + def __init__(self, dataloader): + import torch + + if not isinstance(dataloader, torch.utils.data.DataLoader): + raise ValueError( + f"Expected argument `dataloader` to be an instance of" + f"`torch.utils.data.DataLoader`. Received: {dataloader}" + ) + + self._dataloader = dataloader + self._output_signature = None + self._batch_size = dataloader.batch_size + self._num_batches = None + self._partial_batch_size = None + if hasattr(dataloader.dataset, "__len__"): + self._num_batches = len(dataloader) + if self._batch_size is not None: + self._partial_batch_size = ( + len(dataloader.dataset) % self._batch_size + ) + + def get_numpy_iterator(self): + for batch in self._dataloader: + # shared memory using `np.asarray` + yield tuple( + tree.map_structure( + lambda x: np.asarray(x.cpu()), batch, none_is_leaf=False + ) + ) + + def get_jax_iterator(self): + # We use numpy as an intermediary because it is faster. + return self.get_numpy_iterator() + + def get_tf_dataset(self): + from keras.src.utils.module_utils import tensorflow as tf + + if self._output_signature is None: + batches = list( + itertools.islice( + self._dataloader, + data_adapter_utils.NUM_BATCHES_FOR_TENSOR_SPEC, + ) + ) + self._output_signature = tuple( + data_adapter_utils.get_tensor_spec(batches) + ) + return tf.data.Dataset.from_generator( + self.get_numpy_iterator, + output_signature=self._output_signature, + ) + + def get_torch_dataloader(self): + return self._dataloader + + @property + def builtin_prefetch(self): + prefetch_factor = self._dataloader.prefetch_factor + if prefetch_factor is not None and prefetch_factor > 0: + return True + else: + return False + + @property + def num_batches(self): + return self._num_batches + + @property + def batch_size(self): + return self._batch_size + + @property + def has_partial_batch(self): + if self._partial_batch_size: + return self._partial_batch_size > 0 + else: + return None + + @property + def partial_batch_size(self): + return self._partial_batch_size diff --git a/keras/src/trainers/data_adapters/torch_data_loader_adapter_test.py b/keras/src/trainers/data_adapters/torch_data_loader_adapter_test.py new file mode 100644 index 000000000000..32d6e8444841 --- /dev/null +++ b/keras/src/trainers/data_adapters/torch_data_loader_adapter_test.py @@ -0,0 +1,187 @@ +import math + +import numpy as np +import tensorflow as tf +import torch +from absl.testing import parameterized + +from keras.src import backend +from keras.src import testing +from keras.src.testing.test_utils import named_product +from keras.src.trainers.data_adapters.torch_data_loader_adapter import ( + TorchDataLoaderAdapter, +) + + +class TestTorchDataLoaderAdapter(testing.TestCase): + def test_basic_dataloader(self): + x = torch.normal(2, 3, size=(34, 4)) + y = torch.normal(1, 3, size=(34, 2)) + ds = torch.utils.data.TensorDataset(x, y) + dataloader = torch.utils.data.DataLoader(ds, batch_size=16) + adapter = TorchDataLoaderAdapter(dataloader) + + self.assertEqual(adapter.num_batches, 3) + self.assertEqual(adapter.batch_size, 16) + self.assertEqual(adapter.has_partial_batch, True) + self.assertEqual(adapter.partial_batch_size, 2) + + if backend.backend() == "numpy": + it = adapter.get_numpy_iterator() + expected_class = np.ndarray + elif backend.backend() == "tensorflow": + it = adapter.get_tf_dataset() + expected_class = tf.Tensor + elif backend.backend() == "jax": + it = adapter.get_jax_iterator() + expected_class = np.ndarray + elif backend.backend() == "torch": + it = adapter.get_torch_dataloader() + expected_class = torch.Tensor + + for i, batch in enumerate(it): + self.assertEqual(len(batch), 2) + bx, by = batch + self.assertIsInstance(bx, expected_class) + self.assertIsInstance(by, expected_class) + self.assertEqual(bx.dtype, by.dtype) + self.assertContainsExactSubsequence(str(bx.dtype), "float32") + if i < 2: + self.assertEqual(bx.shape, (16, 4)) + self.assertEqual(by.shape, (16, 2)) + else: + self.assertEqual(bx.shape, (2, 4)) + self.assertEqual(by.shape, (2, 2)) + + @parameterized.named_parameters( + named_product(batch_size=[None, 3], implements_len=[True, False]) + ) + def test_dataloader_iterable_dataset(self, batch_size, implements_len): + class TestIterableDataset(torch.utils.data.IterableDataset): + def __init__(self): + self.x = torch.normal(2, 3, size=(16, 4)) + self.y = torch.normal(1, 3, size=(16, 2)) + + def __iter__(self): + for _ in range(10): + yield (self.x, self.y) + + class TestIterableDatasetWithLen(TestIterableDataset): + def __len__(self): + return 10 + + ds = ( + TestIterableDatasetWithLen() + if implements_len + else TestIterableDataset() + ) + dataloader = torch.utils.data.DataLoader(ds, batch_size=batch_size) + adapter = TorchDataLoaderAdapter(dataloader) + + if implements_len and batch_size: + self.assertEqual(adapter.num_batches, math.ceil(10 / batch_size)) + self.assertEqual(adapter.batch_size, batch_size) + self.assertEqual(adapter.has_partial_batch, True) + self.assertEqual(adapter.partial_batch_size, 10 % batch_size) + elif implements_len: + self.assertEqual(adapter.num_batches, 10) + self.assertEqual(adapter.batch_size, None) + self.assertEqual(adapter.has_partial_batch, None) + self.assertEqual(adapter.partial_batch_size, None) + else: + self.assertIsNone(adapter.num_batches) + self.assertEqual(adapter.batch_size, batch_size) + self.assertIsNone(adapter.has_partial_batch) + self.assertIsNone(adapter.partial_batch_size) + + if backend.backend() == "numpy": + it = adapter.get_numpy_iterator() + expected_class = np.ndarray + elif backend.backend() == "tensorflow": + it = adapter.get_tf_dataset() + expected_class = tf.Tensor + elif backend.backend() == "jax": + it = adapter.get_jax_iterator() + expected_class = np.ndarray + elif backend.backend() == "torch": + it = adapter.get_torch_dataloader() + expected_class = torch.Tensor + + batch_count = 0 + for i, batch in enumerate(it): + batch_count += 1 + self.assertEqual(len(batch), 2) + bx, by = batch + self.assertIsInstance(bx, expected_class) + self.assertIsInstance(by, expected_class) + self.assertEqual(bx.dtype, by.dtype) + self.assertContainsExactSubsequence(str(bx.dtype), "float32") + if batch_size: + if i < 3: + self.assertEqual(bx.shape, (batch_size, 16, 4)) + self.assertEqual(by.shape, (batch_size, 16, 2)) + else: + self.assertEqual(bx.shape, (10 % batch_size, 16, 4)) + self.assertEqual(by.shape, (10 % batch_size, 16, 2)) + else: + self.assertEqual(bx.shape, (16, 4)) + self.assertEqual(by.shape, (16, 2)) + + if batch_size: + self.assertEqual(batch_count, math.ceil(10 / batch_size)) + else: + self.assertEqual(batch_count, 10) + + def test_with_different_shapes(self): + x = ( + [np.ones([4], "float32")] * 16 + + [np.ones([5], "float32")] * 16 + + [np.ones([6], "float32")] * 2 + ) + y = np.ones((34, 2), "float32") + ds = torch.utils.data.StackDataset(x, y) + dataloader = torch.utils.data.DataLoader(ds, batch_size=16) + adapter = TorchDataLoaderAdapter(dataloader) + + self.assertEqual(adapter.num_batches, 3) + self.assertEqual(adapter.batch_size, 16) + self.assertEqual(adapter.has_partial_batch, True) + self.assertEqual(adapter.partial_batch_size, 2) + + if backend.backend() == "numpy": + it = adapter.get_numpy_iterator() + elif backend.backend() == "tensorflow": + it = adapter.get_tf_dataset() + elif backend.backend() == "jax": + it = adapter.get_jax_iterator() + elif backend.backend() == "torch": + it = adapter.get_torch_dataloader() + + for i, batch in enumerate(it): + self.assertEqual(len(batch), 2) + bx, by = batch + self.assertEqual(bx.dtype, by.dtype) + self.assertContainsExactSubsequence(str(bx.dtype), "float32") + if i == 0: + self.assertEqual(bx.shape, (16, 4)) + self.assertEqual(by.shape, (16, 2)) + elif i == 1: + self.assertEqual(bx.shape, (16, 5)) + self.assertEqual(by.shape, (16, 2)) + else: + self.assertEqual(bx.shape, (2, 6)) + self.assertEqual(by.shape, (2, 2)) + + @parameterized.named_parameters(named_product(num_workers=[0, 2])) + def test_builtin_prefetch(self, num_workers): + x = torch.normal(2, 3, size=(34, 4)) + y = torch.normal(1, 3, size=(34, 2)) + ds = torch.utils.data.TensorDataset(x, y) + dataloader = torch.utils.data.DataLoader( + ds, batch_size=16, num_workers=num_workers + ) + adapter = TorchDataLoaderAdapter(dataloader) + if num_workers > 0: + self.assertTrue(adapter.builtin_prefetch) + else: + self.assertFalse(adapter.builtin_prefetch) diff --git a/keras/src/trainers/epoch_iterator.py b/keras/src/trainers/epoch_iterator.py new file mode 100644 index 000000000000..67a603093d8e --- /dev/null +++ b/keras/src/trainers/epoch_iterator.py @@ -0,0 +1,174 @@ +""" +Separation of concerns: + +DataAdapter: + - x, y + - sample_weight + - class_weight + - shuffle + - batch_size + - steps, as it relates to batch_size for array data + +EpochIterator: + - whether to yield numpy or tf data + - steps + - most argument validation + +Trainer: + - steps_per_execution + - validation_split + - validation_data + - callbacks + - validation_freq + - epochs + - initial_epoch + - any backend-specific concern such as distribution + +PyDataset: + - num_workers + - use_multiprocessing + - max_queue_size + +EpochIterator steps: + +1. Look at data type and select correct DataHandler +2. Instantiate DataHandler with correct arguments +3. Raise or warn on unused arguments +4. in __iter__, iterate, either for a fixed number of steps +or until there is no data + +""" + +import contextlib +import warnings + +from keras.src.backend import config +from keras.src.trainers import data_adapters + + +class EpochIterator: + def __init__( + self, + x, + y=None, + sample_weight=None, + batch_size=None, + steps_per_epoch=None, + shuffle=False, + class_weight=None, + steps_per_execution=1, + ): + # Possibly cap steps_per_epoch for debugging runs. + max_steps_per_epoch = config.max_steps_per_epoch() + if max_steps_per_epoch: + if not steps_per_epoch or max_steps_per_epoch < steps_per_epoch: + warnings.warn( + "Limiting steps_per_epoch to %d" % max_steps_per_epoch + ) + steps_per_epoch = max_steps_per_epoch + self.steps_per_epoch = steps_per_epoch + self.steps_per_execution = steps_per_execution + self._current_iterator = None + self._epoch_iterator = None + self._steps_seen = 0 + self.data_adapter = data_adapters.get_data_adapter( + x=x, + y=y, + sample_weight=sample_weight, + batch_size=batch_size, + steps_per_epoch=steps_per_epoch, + shuffle=shuffle, + class_weight=class_weight, + ) + self._num_batches = self.data_adapter.num_batches + + def _get_iterator(self): + return self.data_adapter.get_numpy_iterator() + + def _interrupted_warning(self): + warnings.warn( + "Your input ran out of data; interrupting training. " + "Make sure that your dataset or generator can generate " + "at least `steps_per_epoch * epochs` batches. " + "You may need to use the `.repeat()` " + "function when building your dataset.", + stacklevel=2, + ) + + def reset(self): + self._current_iterator = None + self._num_batches = self.data_adapter.num_batches + self._steps_seen = 0 + self._epoch_iterator = None + self.data_adapter.on_epoch_end() + + def _enumerate_iterator(self): + self.data_adapter.on_epoch_begin() + steps_per_epoch = self.steps_per_epoch or self._num_batches or -1 + + if steps_per_epoch > 0: + if self._current_iterator is None or self.steps_per_epoch is None: + self._current_iterator = iter(self._get_iterator()) + self._steps_seen = 0 + for step in range(0, steps_per_epoch, self.steps_per_execution): + if self._num_batches and self._steps_seen >= self._num_batches: + if self.steps_per_epoch: + self._interrupted_warning() + break + self._steps_seen += self.steps_per_execution + yield ( + step, + step + self.steps_per_execution - 1, + self._current_iterator, + ) + if self._num_batches and self._steps_seen >= self._num_batches: + self._current_iterator = iter(self._get_iterator()) + self._steps_seen = 0 + else: + iterator = iter(self._get_iterator()) + step = -self.steps_per_execution + while True: + step += self.steps_per_execution + self._steps_seen = step + self.steps_per_execution + yield step, step + self.steps_per_execution - 1, iterator + self.data_adapter.on_epoch_end() + + def __iter__(self): + self._epoch_iterator = self._enumerate_iterator() + return self + + def __next__(self): + buffer = [] + begin_step, end_step, iterator = next(self._epoch_iterator) + with self.catch_stop_iteration(): + for _ in range(self.steps_per_execution): + data = next(iterator) + buffer.append(data) + return begin_step, end_step, buffer + if buffer: + return begin_step, end_step, buffer + raise StopIteration + + def enumerate_epoch(self): + for begin_step, end_step, data in self: + yield begin_step, end_step, data + + @contextlib.contextmanager + def catch_stop_iteration(self): + """Catches errors when an iterator runs out of data.""" + try: + yield + except StopIteration: + if self._num_batches is None: + self._num_batches = self._steps_seen + self._interrupted_warning() + self._current_iterator = None + self.data_adapter.on_epoch_end() + + @property + def num_batches(self): + if self.steps_per_epoch: + return self.steps_per_epoch + # Either copied from the data_adapter, or + # inferred at the end of an iteration. + return self._num_batches diff --git a/keras/src/trainers/epoch_iterator_test.py b/keras/src/trainers/epoch_iterator_test.py new file mode 100644 index 000000000000..e674c3220a9b --- /dev/null +++ b/keras/src/trainers/epoch_iterator_test.py @@ -0,0 +1,233 @@ +import numpy as np +import pytest +import tensorflow as tf +from absl.testing import parameterized + +from keras.src import backend +from keras.src import testing +from keras.src.trainers import data_adapters +from keras.src.trainers import epoch_iterator + + +class TestEpochIterator(testing.TestCase): + @parameterized.named_parameters( + [("iterator", "iterator"), ("enumerate_epoch", "enumerate_epoch")] + ) + def test_basic_flow(self, call_type): + x = np.random.random((100, 16)) + y = np.random.random((100, 4)) + sample_weight = np.random.random((100,)) + batch_size = 16 + shuffle = True + iterator = epoch_iterator.EpochIterator( + x=x, + y=y, + sample_weight=sample_weight, + batch_size=batch_size, + shuffle=shuffle, + ) + steps_seen = [] + if call_type == "iterator": + generator = iterator + else: + generator = iterator.enumerate_epoch() + for begin_step, end_step, batch in generator: + batch = batch[0] + steps_seen.append(begin_step) + self.assertEqual(begin_step, end_step) + self.assertEqual(len(batch), 3) + self.assertIsInstance(batch[0], np.ndarray) + self.assertEqual(steps_seen, [0, 1, 2, 3, 4, 5, 6]) + + def test_insufficient_data(self): + batch_size = 8 + steps_per_epoch = 6 + dataset_size = batch_size * (steps_per_epoch - 2) + x = np.arange(dataset_size).reshape((dataset_size, 1)) + y = x * 2 + iterator = epoch_iterator.EpochIterator( + x=x, + y=y, + batch_size=batch_size, + steps_per_epoch=steps_per_epoch, + ) + steps_seen = [] + with pytest.warns(match="Your input ran out of data"): + for step, _, _ in iterator: + steps_seen.append(step) + self.assertLen(steps_seen, steps_per_epoch - 2) + + self.assertIsInstance(iterator, epoch_iterator.EpochIterator) + + def test_unsupported_y_arg_tfdata(self): + with self.assertRaisesRegex(ValueError, "`y` should not be passed"): + x = tf.data.Dataset.from_tensor_slices(np.random.random((100, 16))) + y = np.random.random((100, 4)) + _ = epoch_iterator.EpochIterator(x=x, y=y) + + def test_unsupported_sample_weights_arg_tfdata(self): + with self.assertRaisesRegex( + ValueError, "`sample_weights` should not be passed" + ): + x = tf.data.Dataset.from_tensor_slices(np.random.random((100, 16))) + sample_weights = np.random.random((100,)) + _ = epoch_iterator.EpochIterator(x=x, sample_weight=sample_weights) + + @pytest.mark.skipif( + backend.backend() != "torch", reason="Need to import torch" + ) + def test_torch_dataloader(self): + import torch + + class ExampleTorchDataset(torch.utils.data.Dataset): + def __init__(self, x, y): + self.x = x + self.y = y + + def __len__(self): + return len(self.x) + + def __getitem__(self, idx): + return self.x[idx], self.y[idx] + + torch_dataset = ExampleTorchDataset( + np.random.random((64, 2)), np.random.random((64, 1)) + ) + torch_dataloader = torch.utils.data.DataLoader( + torch_dataset, batch_size=8, shuffle=True + ) + iterator = epoch_iterator.EpochIterator(torch_dataloader) + for _, _, batch in iterator: + batch = batch[0] + self.assertEqual(batch[0].shape, (8, 2)) + self.assertEqual(batch[1].shape, (8, 1)) + + @pytest.mark.skipif( + backend.backend() != "torch", reason="Need to import torch" + ) + def test_unsupported_y_arg_torch_dataloader(self): + import torch + + class ExampleTorchDataset(torch.utils.data.Dataset): + def __init__(self, x, y): + self.x = x + self.y = y + + def __len__(self): + return len(self.x) + + def __getitem__(self, idx): + return self.x[idx], self.y[idx] + + torch_dataset = ExampleTorchDataset( + np.random.random((100, 16)), np.random.random((100, 4)) + ) + x = torch.utils.data.DataLoader( + torch_dataset, batch_size=8, shuffle=True + ) + y = np.random.random((100, 4)) + + with self.assertRaisesRegex( + ValueError, + "When providing `x` as a torch DataLoader, `y` should not", + ): + _ = epoch_iterator.EpochIterator(x=x, y=y) + + @pytest.mark.skipif( + backend.backend() != "torch", reason="Need to import torch" + ) + def test_unsupported_sample_weights_arg_torch_dataloader(self): + import torch + + class ExampleTorchDataset(torch.utils.data.Dataset): + def __init__(self, x, y): + self.x = x + self.y = y + + def __len__(self): + return len(self.x) + + def __getitem__(self, idx): + return self.x[idx], self.y[idx] + + torch_dataset = ExampleTorchDataset( + np.random.random((100, 16)), np.random.random((100, 4)) + ) + x = torch.utils.data.DataLoader( + torch_dataset, batch_size=8, shuffle=True + ) + sample_weights = np.random.random((100,)) + + with self.assertRaisesRegex( + ValueError, + "When providing `x` as a torch DataLoader, `sample_weights`", + ): + _ = epoch_iterator.EpochIterator(x=x, sample_weight=sample_weights) + + def test_python_generator_input(self): + def generator_example(): + for i in range(100): + yield (np.array([i]), np.array([i * 2])) + + x = generator_example() + epoch_iter = epoch_iterator.EpochIterator(x=x) + self.assertIsInstance( + epoch_iter.data_adapter, + data_adapters.GeneratorDataAdapter, + ) + + def test_unrecognized_data_type(self): + x = "unsupported_data" + with self.assertRaisesRegex(ValueError, "Unrecognized data type"): + _ = epoch_iterator.EpochIterator(x=x) + + @parameterized.named_parameters( + [ + {"testcase_name": "infinite", "infinite": True}, + {"testcase_name": "finite", "infinite": False}, + ] + ) + def test_epoch_callbacks(self, infinite): + class TestPyDataset(data_adapters.py_dataset_adapter.PyDataset): + def __init__( + self, + workers=1, + use_multiprocessing=False, + max_queue_size=10, + infinite=False, + ): + super().__init__(workers, use_multiprocessing, max_queue_size) + self.data = np.random.rand(64, 2) + self.batch_size = 16 + self.infinite = infinite + + # check that callbacks are called in the correct order + self.tracker = [] + + @property + def num_batches(self): + if self.infinite: + return None + return len(self.data) // self.batch_size + + def on_epoch_begin(self): + self.tracker.append(1) + + def __getitem__(self, index): + idx = index % 2 + return self.data[ + idx * self.batch_size : (idx + 1) * self.batch_size + ] + + def on_epoch_end(self): + self.tracker.append(2) + + ds = TestPyDataset(infinite=infinite) + epoch_iter = epoch_iterator.EpochIterator(x=ds, steps_per_epoch=10) + + num_epochs = 5 + for epoch in range(num_epochs): + for _, _, _ in epoch_iter: + pass + + self.assertAllEqual(ds.tracker, [1, 2] * num_epochs) diff --git a/keras/src/trainers/trainer.py b/keras/src/trainers/trainer.py new file mode 100644 index 000000000000..bac422db249c --- /dev/null +++ b/keras/src/trainers/trainer.py @@ -0,0 +1,1156 @@ +import inspect +import platform +import warnings + +from keras.src import backend +from keras.src import metrics as metrics_module +from keras.src import ops +from keras.src import optimizers +from keras.src import tree +from keras.src.optimizers.loss_scale_optimizer import LossScaleOptimizer +from keras.src.saving import serialization_lib +from keras.src.trainers.compile_utils import CompileLoss +from keras.src.trainers.compile_utils import CompileMetrics +from keras.src.trainers.data_adapters import data_adapter_utils +from keras.src.utils import python_utils +from keras.src.utils import traceback_utils +from keras.src.utils import tracking + + +class Trainer: + def __init__(self): + self._lock = False + self._run_eagerly = False + self._jit_compile = None + self.compiled = False + self.loss = None + self.steps_per_execution = 1 + # Can be set by callbacks in on_train_begin + self._initial_epoch = None + self._compute_loss_has_training_arg = ( + "training" in inspect.signature(self.compute_loss).parameters + ) + + # Placeholders used in `compile` + self._compile_loss = None + self._compile_metrics = None + self._loss_tracker = None + + @traceback_utils.filter_traceback + @tracking.no_automatic_dependency_tracking + def compile( + self, + optimizer="rmsprop", + loss=None, + loss_weights=None, + metrics=None, + weighted_metrics=None, + run_eagerly=False, + steps_per_execution=1, + jit_compile="auto", + auto_scale_loss=True, + ): + """Configures the model for training. + + Example: + + ```python + model.compile( + optimizer=keras.optimizers.Adam(learning_rate=1e-3), + loss=keras.losses.BinaryCrossentropy(), + metrics=[ + keras.metrics.BinaryAccuracy(), + keras.metrics.FalseNegatives(), + ], + ) + ``` + + Args: + optimizer: String (name of optimizer) or optimizer instance. See + `keras.optimizers`. + loss: Loss function. May be a string (name of loss function), or + a `keras.losses.Loss` instance. See `keras.losses`. A + loss function is any callable with the signature + `loss = fn(y_true, y_pred)`, where `y_true` are the ground truth + values, and `y_pred` are the model's predictions. + `y_true` should have shape `(batch_size, d0, .. dN)` + (except in the case of sparse loss functions such as + sparse categorical crossentropy which expects integer arrays of + shape `(batch_size, d0, .. dN-1)`). + `y_pred` should have shape `(batch_size, d0, .. dN)`. + The loss function should return a float tensor. + loss_weights: Optional list or dictionary specifying scalar + coefficients (Python floats) to weight the loss contributions of + different model outputs. The loss value that will be minimized + by the model will then be the *weighted sum* of all individual + losses, weighted by the `loss_weights` coefficients. If a list, + it is expected to have a 1:1 mapping to the model's outputs. If + a dict, it is expected to map output names (strings) to scalar + coefficients. + metrics: List of metrics to be evaluated by the model during + training and testing. Each of this can be a string (name of a + built-in function), function or a `keras.metrics.Metric` + instance. See `keras.metrics`. Typically you will use + `metrics=['accuracy']`. A function is any callable with the + signature `result = fn(y_true, _pred)`. To specify different + metrics for different outputs of a multi-output model, you could + also pass a dictionary, such as + `metrics={'a':'accuracy', 'b':['accuracy', 'mse']}`. + You can also pass a list to specify a metric or a list of + metrics for each output, such as + `metrics=[['accuracy'], ['accuracy', 'mse']]` + or `metrics=['accuracy', ['accuracy', 'mse']]`. When you pass + the strings 'accuracy' or 'acc', we convert this to one of + `keras.metrics.BinaryAccuracy`, + `keras.metrics.CategoricalAccuracy`, + `keras.metrics.SparseCategoricalAccuracy` based on the + shapes of the targets and of the model output. A similar + conversion is done for the strings `"crossentropy"` + and `"ce"` as well. + The metrics passed here are evaluated without sample weighting; + if you would like sample weighting to apply, you can specify + your metrics via the `weighted_metrics` argument instead. + weighted_metrics: List of metrics to be evaluated and weighted by + `sample_weight` or `class_weight` during training and testing. + run_eagerly: Bool. If `True`, this model's forward pass + will never be compiled. It is recommended to leave this + as `False` when training (for best performance), + and to set it to `True` when debugging. + steps_per_execution: Int. The number of batches to run + during each a single compiled function call. Running multiple + batches inside a single compiled function call can + greatly improve performance on TPUs or small models with a large + Python overhead. At most, one full epoch will be run each + execution. If a number larger than the size of the epoch is + passed, the execution will be truncated to the size of the + epoch. Note that if `steps_per_execution` is set to `N`, + `Callback.on_batch_begin` and `Callback.on_batch_end` methods + will only be called every `N` batches (i.e. before/after + each compiled function execution). + Not supported with the PyTorch backend. + jit_compile: Bool or `"auto"`. Whether to use XLA compilation when + compiling a model. For `jax` and `tensorflow` backends, + `jit_compile="auto"` enables XLA compilation if the model + supports it, and disabled otherwise. + For `torch` backend, `"auto"` will default to eager + execution and `jit_compile=True` will run with `torch.compile` + with the `"inductor"` backend. + auto_scale_loss: Bool. If `True` and the model dtype policy is + `"mixed_float16"`, the passed optimizer will be automatically + wrapped in a `LossScaleOptimizer`, which will dynamically + scale the loss to prevent underflow. + """ + optimizer = optimizers.get(optimizer) + self.optimizer = optimizer + if ( + auto_scale_loss + and self.dtype_policy.name == "mixed_float16" + and self.optimizer + and not isinstance(self.optimizer, LossScaleOptimizer) + ): + self.optimizer = LossScaleOptimizer( + self.optimizer, name="loss_scale_optimizer" + ) + if hasattr(self, "output_names"): + output_names = self.output_names + else: + output_names = None + if loss is not None: + self._compile_loss = CompileLoss( + loss, loss_weights, output_names=output_names + ) + self.loss = loss + if metrics is not None or weighted_metrics is not None: + self._compile_metrics = CompileMetrics( + metrics, weighted_metrics, output_names=output_names + ) + if jit_compile == "auto": + if run_eagerly: + jit_compile = False + else: + jit_compile = self._resolve_auto_jit_compile() + if jit_compile and run_eagerly: + jit_compile = False + warnings.warn( + "If `run_eagerly` is True, then `jit_compile` " + "cannot also be True. Disabling `jit_compile`.", + stacklevel=2, + ) + + self.jit_compile = jit_compile + self.run_eagerly = run_eagerly + self.stop_training = False + self.compiled = True + self._loss_tracker = metrics_module.Mean(name="loss") + self.steps_per_execution = steps_per_execution + + self.train_function = None + self.test_function = None + self.predict_function = None + + self._compile_config = serialization_lib.SerializableDict( + optimizer=optimizer, + loss=loss, + loss_weights=loss_weights, + metrics=metrics, + weighted_metrics=weighted_metrics, + run_eagerly=run_eagerly, + steps_per_execution=steps_per_execution, + jit_compile=jit_compile, + ) + + @property + def jit_compile(self): + if self._jit_compile is None: + # Value was never set. Resolve it now. + self._jit_compile = self._resolve_auto_jit_compile() + return self._jit_compile + + @jit_compile.setter + def jit_compile(self, value): + if value and not model_supports_jit(self): + warnings.warn( + "Model doesn't support `jit_compile=True`. " + "Proceeding with `jit_compile=False`." + ) + self._jit_compile = False + else: + self._jit_compile = value + + def _resolve_auto_jit_compile(self): + if backend.backend() == "torch": + # jit_compile = "auto" with the pytorch backend defaults to eager + return False + + if backend.backend() == "tensorflow": + import tensorflow as tf + + devices = tf.config.list_physical_devices() + if not list(filter(lambda x: x.device_type != "CPU", devices)): + # Disable XLA on CPU-only machines. + return False + + if self._distribute_strategy: + # Disable XLA with tf.distribute + return False + + if model_supports_jit(self): + return True + return False + + @property + def run_eagerly(self): + return self._run_eagerly + + @run_eagerly.setter + def run_eagerly(self, value): + self._run_eagerly = value + + @property + def metrics(self): + # Order: loss tracker, individual loss trackers, compiled metrics, + # custom metrics, sublayer metrics. + metrics = [] + if self.compiled: + if self._loss_tracker is not None: + metrics.append(self._loss_tracker) + if self._compile_metrics is not None: + metrics.append(self._compile_metrics) + if self._compile_loss is not None: + metrics.extend(self._compile_loss.metrics) + metrics.extend(self._metrics) + for layer in self._flatten_layers(include_self=False): + if isinstance(layer, Trainer): + # All Trainer-related metrics in sublayers should be ignored + # because a new Trainer has been instantiated. + continue + metrics.extend(layer.metrics) + return metrics + + @property + def metrics_names(self): + return [m.name for m in self.metrics] + + def reset_metrics(self): + for m in self.metrics: + m.reset_state() + + def _get_own_metrics(self): + metrics = [] + if self._loss_tracker is not None: + metrics.append(self._loss_tracker) + if self._compile_metrics is not None: + metrics.append(self._compile_metrics) + if self._compile_loss is not None: + metrics.extend(self._compile_loss.metrics) + metrics.extend(self._metrics) + return metrics + + def compute_loss( + self, + x=None, + y=None, + y_pred=None, + sample_weight=None, + training=True, + ): + """Compute the total loss, validate it, and return it. + + Subclasses can optionally override this method to provide custom loss + computation logic. + + Example: + + ```python + class MyModel(Model): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.loss_tracker = metrics.Mean(name='loss') + + def compute_loss(self, x, y, y_pred, sample_weight, training=True): + loss = ops.mean((y_pred - y) ** 2) + loss += ops.sum(self.losses) + self.loss_tracker.update_state(loss) + return loss + + def reset_metrics(self): + self.loss_tracker.reset_state() + + @property + def metrics(self): + return [self.loss_tracker] + + inputs = layers.Input(shape=(10,), name='my_input') + outputs = layers.Dense(10)(inputs) + model = MyModel(inputs, outputs) + model.add_loss(ops.sum(outputs)) + + optimizer = SGD() + model.compile(optimizer, loss='mse', steps_per_execution=10) + dataset = ... + model.fit(dataset, epochs=2, steps_per_epoch=10) + print(f"Custom loss: {model.loss_tracker.result()}") + ``` + + Args: + x: Input data. + y: Target data. + y_pred: Predictions returned by the model (output of `model(x)`) + sample_weight: Sample weights for weighting the loss function. + training: Whether we are training or evaluating the model. + + Returns: + The total loss as a scalar tensor, or `None` if no loss results + (which is the case when called by `Model.test_step`). + """ + # The default implementation does not use `x` or `training`. + del x + del training + losses = [] + if self._compile_loss is not None: + loss = self._compile_loss(y, y_pred, sample_weight) + if loss is not None: + losses.append(loss) + for loss in self.losses: + losses.append(self._aggregate_additional_loss(loss)) + if backend.backend() != "jax" and len(losses) == 0: + raise ValueError( + "No loss to compute. Provide a `loss` argument in `compile()`." + ) + if len(losses) == 1: + total_loss = losses[0] + elif len(losses) == 0: + total_loss = ops.zeros(()) + else: + total_loss = ops.sum(losses) + return total_loss + + def _compute_loss( + self, + x=None, + y=None, + y_pred=None, + sample_weight=None, + training=True, + ): + """Backwards compatibility wrapper for `compute_loss`. + + This should be used instead `compute_loss` within `train_step` and + `test_step` to support overrides of `compute_loss` that may not have + the `training` argument, as this argument was added in Keras 3.3. + """ + if self._compute_loss_has_training_arg: + return self.compute_loss( + x, y, y_pred, sample_weight, training=training + ) + else: + return self.compute_loss(x, y, y_pred, sample_weight) + + def _aggregate_additional_loss(self, loss): + """Aggregates losses from `add_loss`, regularizers and sublayers. + + Args: + loss: A tensor representing the additional loss to aggregate. + + Returns: + A tensor representing the summed loss, cast to the `floatx()` if + necessary. + """ + if not backend.is_float_dtype(loss.dtype): + loss = ops.cast(loss, dtype=backend.floatx()) + return ops.sum(loss) + + def stateless_compute_loss( + self, + trainable_variables, + non_trainable_variables, + metrics_variables, + x=None, + y=None, + y_pred=None, + sample_weight=None, + training=True, + ): + var_mapping = list(zip(self.trainable_variables, trainable_variables)) + var_mapping.extend( + zip(self.non_trainable_variables, non_trainable_variables) + ) + var_mapping.extend(zip(self.metrics_variables, metrics_variables)) + with backend.StatelessScope(state_mapping=var_mapping) as scope: + # Note that this is needed for the regularization loss, which need + # the latest value of train/non-trainable variables. + loss = self._compute_loss( + x, + y, + y_pred, + sample_weight=sample_weight, + training=training, + ) + + # Update non trainable vars (may have been updated in compute_loss) + non_trainable_variables = [] + for v in self.non_trainable_variables: + new_v = scope.get_current_value(v) + non_trainable_variables.append(new_v) + + # Update metrics vars (may have been updated in compute_loss) + metrics_variables = [] + for v in self.metrics_variables: + new_v = scope.get_current_value(v) + metrics_variables.append(new_v) + return loss, ( + trainable_variables, + non_trainable_variables, + metrics_variables, + ) + + def compute_metrics(self, x, y, y_pred, sample_weight=None): + """Update metric states and collect all metrics to be returned. + + Subclasses can optionally override this method to provide custom metric + updating and collection logic. Custom metrics are not passed in + `compile()`, they can be created in `__init__` or `build`. They are + automatically tracked and returned by `self.metrics`. + + Example: + + ```python + class MyModel(Sequential): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.custom_metric = MyMetric(name="custom_metric") + + def compute_metrics(self, x, y, y_pred, sample_weight): + # This super call updates metrics from `compile` and returns + # results for all metrics listed in `self.metrics`. + metric_results = super().compute_metrics( + x, y, y_pred, sample_weight) + + # `metric_results` contains the previous result for + # `custom_metric`, this is where we update it. + self.custom_metric.update_state(x, y, y_pred, sample_weight) + metric_results['custom_metric'] = self.custom_metric.result() + return metric_results + ``` + + Args: + x: Input data. + y: Target data. + y_pred: Predictions returned by the model output of `model.call(x)`. + sample_weight: Sample weights for weighting the loss function. + + Returns: + A `dict` containing values that will be passed to + `keras.callbacks.CallbackList.on_train_batch_end()`. Typically, + the values of the metrics listed in `self.metrics` are returned. + Example: `{'loss': 0.2, 'accuracy': 0.7}`. + """ + del x # The default implementation does not use `x`. + if self._compile_metrics is not None: + self._compile_metrics.update_state(y, y_pred, sample_weight) + return self.get_metrics_result() + + def get_metrics_result(self): + """Returns the model's metrics values as a dict. + + If any of the metric result is a dict (containing multiple metrics), + each of them gets added to the top level returned dict of this method. + + Returns: + A `dict` containing values of the metrics listed in `self.metrics`. + Example: `{'loss': 0.2, 'accuracy': 0.7}`. + """ + return_metrics = {} + for metric in self.metrics: + result = metric.result() + if isinstance(result, dict): + return_metrics.update(result) + else: + return_metrics[metric.name] = result + return python_utils.pythonify_logs(return_metrics) + + def fit( + self, + x=None, + y=None, + batch_size=None, + epochs=1, + verbose="auto", + callbacks=None, + validation_split=0.0, + validation_data=None, + shuffle=True, + class_weight=None, + sample_weight=None, + initial_epoch=0, + steps_per_epoch=None, + validation_steps=None, + validation_batch_size=None, + validation_freq=1, + ): + """Trains the model for a fixed number of epochs (dataset iterations). + + Args: + x: Input data. It can be: + - A NumPy array (or array-like), or a list of arrays + (in case the model has multiple inputs). + - A backend-native tensor, or a list of tensors + (in case the model has multiple inputs). + - A dict mapping input names to the corresponding array/tensors, + if the model has named inputs. + - A `keras.utils.PyDataset` returning `(inputs, targets)` or + `(inputs, targets, sample_weights)`. + - A `tf.data.Dataset` yielding `(inputs, targets)` or + `(inputs, targets, sample_weights)`. + - A `torch.utils.data.DataLoader` yielding `(inputs, targets)` + or `(inputs, targets, sample_weights)`. + - A Python generator function yielding `(inputs, targets)` or + `(inputs, targets, sample_weights)`. + y: Target data. Like the input data `x`, it can be either NumPy + array(s) or backend-native tensor(s). If `x` is a + `keras.utils.PyDataset`, `tf.data.Dataset`, + `torch.utils.data.DataLoader` or a Python generator function, + `y` should not be specified since targets will be obtained from + `x`. + batch_size: Integer or `None`. + Number of samples per gradient update. + If unspecified, `batch_size` will default to 32. + Do not specify the `batch_size` if your input data `x` is a + `keras.utils.PyDataset`, `tf.data.Dataset`, + `torch.utils.data.DataLoader` or Python generator function + since they generate batches. + epochs: Integer. Number of epochs to train the model. + An epoch is an iteration over the entire `x` and `y` + data provided + (unless the `steps_per_epoch` flag is set to + something other than None). + Note that in conjunction with `initial_epoch`, + `epochs` is to be understood as "final epoch". + The model is not trained for a number of iterations + given by `epochs`, but merely until the epoch + of index `epochs` is reached. + verbose: `"auto"`, 0, 1, or 2. Verbosity mode. + 0 = silent, 1 = progress bar, 2 = one line per epoch. + "auto" becomes 1 for most cases. + Note that the progress bar is not + particularly useful when logged to a file, + so `verbose=2` is recommended when not running interactively + (e.g., in a production environment). Defaults to `"auto"`. + callbacks: List of `keras.callbacks.Callback` instances. + List of callbacks to apply during training. + See `keras.callbacks`. Note + `keras.callbacks.ProgbarLogger` and + `keras.callbacks.History` callbacks are created + automatically and need not be passed to `model.fit()`. + `keras.callbacks.ProgbarLogger` is created + or not based on the `verbose` argument in `model.fit()`. + validation_split: Float between 0 and 1. + Fraction of the training data to be used as validation data. + The model will set apart this fraction of the training data, + will not train on it, and will evaluate the loss and any model + metrics on this data at the end of each epoch. The validation + data is selected from the last samples in the `x` and `y` data + provided, before shuffling. + This argument is only supported when `x` and `y` are made of + NumPy arrays or tensors. + If both `validation_data` and `validation_split` are provided, + `validation_data` will override `validation_split`. + validation_data: Data on which to evaluate + the loss and any model metrics at the end of each epoch. + The model will not be trained on this data. Thus, note the fact + that the validation loss of data provided using + `validation_split` or `validation_data` is not affected by + regularization layers like noise and dropout. + `validation_data` will override `validation_split`. + It can be: + - A tuple `(x_val, y_val)` of NumPy arrays or tensors. + - A tuple `(x_val, y_val, val_sample_weights)` of NumPy + arrays. + - A `keras.utils.PyDataset`, a `tf.data.Dataset`, a + `torch.utils.data.DataLoader` yielding `(inputs, targets)` or a + Python generator function yielding `(x_val, y_val)` or + `(inputs, targets, sample_weights)`. + shuffle: Boolean, whether to shuffle the training data before each + epoch. This argument is ignored when `x` is a + `keras.utils.PyDataset`, `tf.data.Dataset`, + `torch.utils.data.DataLoader` or Python generator function. + class_weight: Optional dictionary mapping class indices (integers) + to a weight (float) value, used for weighting the loss function + (during training only). + This can be useful to tell the model to + "pay more attention" to samples from + an under-represented class. When `class_weight` is specified + and targets have a rank of 2 or greater, either `y` must be + one-hot encoded, or an explicit final dimension of `1` must + be included for sparse class labels. + sample_weight: Optional NumPy array or tensor of weights for + the training samples, used for weighting the loss function + (during training only). You can either pass a flat (1D) + NumPy array or tensor with the same length as the input samples + (1:1 mapping between weights and samples), or in the case of + temporal data, you can pass a 2D NumPy array or tensor with + shape `(samples, sequence_length)` to apply a different weight + to every timestep of every sample. + This argument is not supported when `x` is a + `keras.utils.PyDataset`, `tf.data.Dataset`, + `torch.utils.data.DataLoader` or Python generator function. + Instead, provide `sample_weights` as the third element of `x`. + Note that sample weighting does not apply to metrics specified + via the `metrics` argument in `compile()`. To apply sample + weighting to your metrics, you can specify them via the + `weighted_metrics` in `compile()` instead. + initial_epoch: Integer. + Epoch at which to start training + (useful for resuming a previous training run). + steps_per_epoch: Integer or `None`. + Total number of steps (batches of samples) before declaring one + epoch finished and starting the next epoch. When training with + input tensors or NumPy arrays, the default `None` means that the + value used is the number of samples in your dataset divided by + the batch size, or 1 if that cannot be determined. + If `x` is a `keras.utils.PyDataset`, `tf.data.Dataset`, + `torch.utils.data.DataLoader` or Python generator function, the + epoch will run until the input dataset is exhausted. When + passing an infinitely repeating dataset, you must specify the + `steps_per_epoch` argument, otherwise the training will run + indefinitely. + validation_steps: Integer or `None`. + Only relevant if `validation_data` is provided. + Total number of steps (batches of samples) to draw before + stopping when performing validation at the end of every epoch. + If `validation_steps` is `None`, validation will run until the + `validation_data` dataset is exhausted. In the case of an + infinitely repeating dataset, it will run indefinitely. If + `validation_steps` is specified and only part of the dataset + is consumed, the evaluation will start from the beginning of the + dataset at each epoch. This ensures that the same validation + samples are used every time. + validation_batch_size: Integer or `None`. + Number of samples per validation batch. + If unspecified, will default to `batch_size`. + Do not specify the `validation_batch_size` if your data is a + `keras.utils.PyDataset`, `tf.data.Dataset`, + `torch.utils.data.DataLoader` or Python generator function + since they generate batches. + validation_freq: Only relevant if validation data is provided. + Specifies how many training epochs to run + before a new validation run is performed, + e.g. `validation_freq=2` runs validation every 2 epochs. + + Unpacking behavior for iterator-like inputs: + A common pattern is to pass an iterator like object such as a + `tf.data.Dataset` or a `keras.utils.PyDataset` to `fit()`, + which will in fact yield not only features (`x`) + but optionally targets (`y`) and sample weights (`sample_weight`). + Keras requires that the output of such iterator-likes be + unambiguous. The iterator should return a tuple + of length 1, 2, or 3, where the optional second and third elements + will be used for `y` and `sample_weight` respectively. + Any other type provided will be wrapped in + a length-one tuple, effectively treating everything as `x`. When + yielding dicts, they should still adhere to the top-level tuple + structure, + e.g. `({"x0": x0, "x1": x1}, y)`. Keras will not attempt to separate + features, targets, and weights from the keys of a single dict. + A notable unsupported data type is the `namedtuple`. The reason is + that it behaves like both an ordered datatype (tuple) and a mapping + datatype (dict). So given a namedtuple of the form: + `namedtuple("example_tuple", ["y", "x"])` + it is ambiguous whether to reverse the order of the elements when + interpreting the value. Even worse is a tuple of the form: + `namedtuple("other_tuple", ["x", "y", "z"])` + where it is unclear if the tuple was intended to be unpacked + into `x`, `y`, and `sample_weight` or passed through + as a single element to `x`. + + Returns: + A `History` object. Its `History.history` attribute is + a record of training loss values and metrics values + at successive epochs, as well as validation loss values + and validation metrics values (if applicable). + """ + raise NotImplementedError + + def evaluate( + self, + x=None, + y=None, + batch_size=None, + verbose="auto", + sample_weight=None, + steps=None, + callbacks=None, + return_dict=False, + **kwargs, + ): + """Returns the loss value & metrics values for the model in test mode. + + Computation is done in batches (see the `batch_size` arg.) + + Args: + x: Input data. It can be: + - A NumPy array (or array-like), or a list of arrays + (in case the model has multiple inputs). + - A backend-native tensor, or a list of tensors + (in case the model has multiple inputs). + - A dict mapping input names to the corresponding array/tensors, + if the model has named inputs. + - A `keras.utils.PyDataset` returning `(inputs, targets)` or + `(inputs, targets, sample_weights)`. + - A `tf.data.Dataset` yielding `(inputs, targets)` or + `(inputs, targets, sample_weights)`. + - A `torch.utils.data.DataLoader` yielding `(inputs, targets)` + or `(inputs, targets, sample_weights)`. + - A Python generator function yielding `(inputs, targets)` or + `(inputs, targets, sample_weights)`. + y: Target data. Like the input data `x`, it can be either NumPy + array(s) or backend-native tensor(s). If `x` is a + `keras.utils.PyDataset`, `tf.data.Dataset`, + `torch.utils.data.DataLoader` or a Python generator function, + `y` should not be specified since targets will be obtained from + `x`. + batch_size: Integer or `None`. + Number of samples per batch of computation. + If unspecified, `batch_size` will default to 32. + Do not specify the `batch_size` if your input data `x` is a + `keras.utils.PyDataset`, `tf.data.Dataset`, + `torch.utils.data.DataLoader` or Python generator function + since they generate batches. + verbose: `"auto"`, 0, 1, or 2. Verbosity mode. + 0 = silent, 1 = progress bar, 2 = single line. + `"auto"` becomes 1 for most cases. + Note that the progress bar is not + particularly useful when logged to a file, so `verbose=2` is + recommended when not running interactively + (e.g. in a production environment). Defaults to `"auto"`. + sample_weight: Optional NumPy array or tensor of weights for + the training samples, used for weighting the loss function + (during training only). You can either pass a flat (1D) + NumPy array or tensor with the same length as the input samples + (1:1 mapping between weights and samples), or in the case of + temporal data, you can pass a 2D NumPy array or tensor with + shape `(samples, sequence_length)` to apply a different weight + to every timestep of every sample. + This argument is not supported when `x` is a + `keras.utils.PyDataset`, `tf.data.Dataset`, + `torch.utils.data.DataLoader` or Python generator function. + Instead, provide `sample_weights` as the third element of `x`. + Note that sample weighting does not apply to metrics specified + via the `metrics` argument in `compile()`. To apply sample + weighting to your metrics, you can specify them via the + `weighted_metrics` in `compile()` instead. + steps: Integer or `None`. + Total number of steps (batches of samples) to draw before + declaring the evaluation round finished. If `steps` is `None`, + it will run until `x` is exhausted. In the case of an infinitely + repeating dataset, it will run indefinitely. + callbacks: List of `keras.callbacks.Callback` instances. + List of callbacks to apply during evaluation. + return_dict: If `True`, loss and metric results are returned as a + dict, with each key being the name of the metric. + If `False`, they are returned as a list. + + Returns: + Scalar test loss (if the model has a single output and no metrics) + or list of scalars (if the model has multiple outputs + and/or metrics). + + Note: When using compiled metrics, `evaluate()` may return multiple + submetric values, while `model.metrics_names` often lists only + top-level names (e.g., 'loss', 'compile_metrics'), leading to a + length mismatch. The order of the `evaluate()` output corresponds + to the order of metrics specified during `model.compile()`. You can + use this order to map the `evaluate()` results to the intended + metric. `model.metrics_names` itself will still return only the + top-level names. + """ + raise NotImplementedError + + def predict( + self, x, batch_size=None, verbose="auto", steps=None, callbacks=None + ): + """Generates output predictions for the input samples. + + Computation is done in batches. This method is designed for batch + processing of large numbers of inputs. It is not intended for use inside + of loops that iterate over your data and process small numbers of inputs + at a time. + + For small numbers of inputs that fit in one batch, + directly use `__call__()` for faster execution, e.g., + `model(x)`, or `model(x, training=False)` if you have layers such as + `BatchNormalization` that behave differently during + inference. + + Note: See [this FAQ entry]( + https://keras.io/getting_started/faq/#whats-the-difference-between-model-methods-predict-and-call) + for more details about the difference between `Model` methods + `predict()` and `__call__()`. + + Args: + x: Input data. It can be: + - A NumPy array (or array-like), or a list of arrays + (in case the model has multiple inputs). + - A backend-native tensor, or a list of tensors + (in case the model has multiple inputs). + - A dict mapping input names to the corresponding array/tensors, + if the model has named inputs. + - A `keras.utils.PyDataset`. + - A `tf.data.Dataset`. + - A `torch.utils.data.DataLoader`. + - A Python generator function. + batch_size: Integer or `None`. + Number of samples per batch of computation. + If unspecified, `batch_size` will default to 32. + Do not specify the `batch_size` if your input data `x` is a + `keras.utils.PyDataset`, `tf.data.Dataset`, + `torch.utils.data.DataLoader` or Python generator function + since they generate batches. + verbose: `"auto"`, 0, 1, or 2. Verbosity mode. + 0 = silent, 1 = progress bar, 2 = single line. + `"auto"` becomes 1 for most cases. Note that the progress bar + is not particularly useful when logged to a file, + so `verbose=2` is recommended when not running interactively + (e.g. in a production environment). Defaults to `"auto"`. + steps: Total number of steps (batches of samples) to draw before + declaring the prediction round finished. If `steps` is `None`, + it will run until `x` is exhausted. In the case of an infinitely + repeating dataset, it will run indefinitely. + callbacks: List of `keras.callbacks.Callback` instances. + List of callbacks to apply during prediction. + + Returns: + NumPy array(s) of predictions. + """ + raise NotImplementedError + + def train_on_batch( + self, + x, + y=None, + sample_weight=None, + class_weight=None, + return_dict=False, + ): + """Runs a single gradient update on a single batch of data. + + Args: + x: Input data. Must be array-like. + y: Target data. Must be array-like. + sample_weight: Optional array of the same length as x, containing + weights to apply to the model's loss for each sample. + In the case of temporal data, you can pass a 2D array + with shape `(samples, sequence_length)`, to apply a different + weight to every timestep of every sample. + class_weight: Optional dictionary mapping class indices (integers) + to a weight (float) to apply to the model's loss for the samples + from this class during training. This can be useful to tell the + model to "pay more attention" to samples from an + under-represented class. When `class_weight` is specified + and targets have a rank of 2 or greater, either `y` must + be one-hot encoded, or an explicit final dimension of 1 + must be included for sparse class labels. + return_dict: If `True`, loss and metric results are returned as a + dict, with each key being the name of the metric. If `False`, + they are returned as a list. + + Returns: + A scalar loss value (when no metrics and `return_dict=False`), + a list of loss and metric values + (if there are metrics and `return_dict=False`), or a dict of + metric and loss values (if `return_dict=True`). + """ + raise NotImplementedError + + def test_on_batch( + self, + x, + y=None, + sample_weight=None, + return_dict=False, + ): + """Test the model on a single batch of samples. + + Args: + x: Input data. Must be array-like. + y: Target data. Must be array-like. + sample_weight: Optional array of the same length as x, containing + weights to apply to the model's loss for each sample. + In the case of temporal data, you can pass a 2D array + with shape `(samples, sequence_length)`, to apply a different + weight to every timestep of every sample. + return_dict: If `True`, loss and metric results are returned as a + dict, with each key being the name of the metric. If `False`, + they are returned as a list. + + Returns: + A scalar loss value (when no metrics and `return_dict=False`), + a list of loss and metric values + (if there are metrics and `return_dict=False`), or a dict of + metric and loss values (if `return_dict=True`). + """ + raise NotImplementedError + + def predict_on_batch(self, x): + """Returns predictions for a single batch of samples. + + Args: + x: Input data. It must be array-like. + + Returns: + NumPy array(s) of predictions. + """ + raise NotImplementedError + + def get_compile_config(self): + """Returns a serialized config with information for compiling the model. + + This method returns a config dictionary containing all the information + (optimizer, loss, metrics, etc.) with which the model was compiled. + + Returns: + A dict containing information for compiling the model. + """ + if self.compiled and hasattr(self, "_compile_config"): + return self._compile_config.serialize() + return {} + + def compile_from_config(self, config): + """Compiles the model with the information given in config. + + This method uses the information in the config (optimizer, loss, + metrics, etc.) to compile the model. + + Args: + config: Dict containing information for compiling the model. + """ + has_overridden_compile = self.__class__.compile != Trainer.compile + if has_overridden_compile: + warnings.warn( + "`compile()` was not called as part of model loading " + "because the model's `compile()` method is custom. " + "All subclassed Models that have `compile()` " + "overridden should also override " + "`get_compile_config()` and `compile_from_config(config)`. " + "Alternatively, you can " + "call `compile()` manually after loading.", + stacklevel=2, + ) + return + config = serialization_lib.deserialize_keras_object(config) + self.compile(**config) + if hasattr(self, "optimizer") and self.built: + # Create optimizer variables. + self.optimizer.build(self.trainable_variables) + + def _should_eval(self, epoch, validation_freq): + epoch = epoch + 1 # one-index the user-facing epoch. + if isinstance(validation_freq, int): + return epoch % validation_freq == 0 + elif isinstance(validation_freq, list): + return epoch in validation_freq + else: + raise ValueError( + "Expected `validation_freq` to be a list or int. " + f"Received: validation_freq={validation_freq} of the " + f"type {type(validation_freq)}." + ) + + def _get_metrics_result_or_logs(self, logs): + """Returns model metrics as a dict if the keys match with input logs. + + When the training / evaluation is performed with an asynchronous steps, + the last scheduled `train / test_step` may not give the latest metrics + because it is not guaranteed to be executed the last. This method gets + metrics from the model directly instead of relying on the return from + last step function. + + When the user has custom train / test step functions, the metrics + returned may be different from `Model.metrics`. In those instances, + this function will be no-op and return the logs passed in. + + Args: + logs: A `dict` of metrics returned by train / test step function. + + Returns: + A `dict` containing values of the metrics listed in `self.metrics` + when logs and model metrics keys match. Otherwise it returns input + `logs`. + """ + metric_logs = self.get_metrics_result() + # Verify that train / test step logs passed and metric logs have + # matching keys. It could be different when using custom step functions, + # in which case we return the logs from the last step. + if isinstance(logs, dict) and set(logs.keys()) == set( + metric_logs.keys() + ): + return metric_logs + return logs + + def _flatten_metrics_in_order(self, logs): + """Turns `logs` dict into a list as per key order of `metrics_names`.""" + metric_names = [] + for metric in self.metrics: + if isinstance(metric, CompileMetrics): + metric_names += [ + sub_metric.name for sub_metric in metric.metrics + ] + else: + metric_names.append(metric.name) + results = [] + for name in metric_names: + if name in logs: + results.append(logs[name]) + for key in sorted(logs.keys()): + if key not in metric_names: + results.append(logs[key]) + if len(results) == 1: + return results[0] + return results + + def _assert_compile_called(self, method_name=None): + if not self.compiled: + msg = "You must call `compile()` before " + if metrics_module: + msg += "using the model." + else: + msg += f"calling `{method_name}()`." + raise ValueError(msg) + + def _symbolic_build(self, iterator=None, data_batch=None): + model_unbuilt = not all(layer.built for layer in self._flatten_layers()) + compile_metrics_unbuilt = ( + self._compile_metrics is not None + and not self._compile_metrics.built + ) + compile_loss_unbuilt = ( + self._compile_loss is not None and not self._compile_loss.built + ) + optimizer_unbuilt = ( + self.optimizer is not None and not self.optimizer.built + ) + if model_unbuilt or compile_metrics_unbuilt or compile_loss_unbuilt: + # Create symbolic tensors matching an input batch. + + def to_symbolic_input(v): + if v is None: + return None + return backend.KerasTensor( + v.shape, backend.standardize_dtype(v.dtype) + ) + + if data_batch is None: + for _, _, data_or_iterator in iterator: + if isinstance(data_or_iterator, (list, tuple)): + data_batch = data_or_iterator[0] + else: + data_batch = next(data_or_iterator) + break + data_batch = tree.map_structure(to_symbolic_input, data_batch) + ( + x, + y, + sample_weight, + ) = data_adapter_utils.unpack_x_y_sample_weight(data_batch) + + # Build all model state with `backend.compute_output_spec`. + try: + y_pred = backend.compute_output_spec(self, x, training=False) + except Exception as e: + raise RuntimeError( + "Unable to automatically build the model. " + "Please build it yourself before calling " + "fit/evaluate/predict. " + "A model is 'built' when its variables have " + "been created and its `self.built` attribute " + "is True. Usually, calling the model on a batch " + "of data is the right way to build it.\n" + "Exception encountered:\n" + f"'{e}'" + ) + if compile_metrics_unbuilt: + # Build all metric state with `backend.compute_output_spec`. + backend.compute_output_spec( + self.compute_metrics, + x, + y, + y_pred, + sample_weight=sample_weight, + ) + if compile_loss_unbuilt: + # Build `CompileLoss` state with `backend.compute_output_spec`. + backend.compute_output_spec( + self._compute_loss, + x, + y, + y_pred, + sample_weight=sample_weight, + training=False, + ) + if optimizer_unbuilt: + # Build optimizer + self.optimizer.build(self.trainable_variables) + self._post_build() + + +def model_supports_jit(model): + # XLA not supported with TF on MacOS GPU + if platform.system() == "Darwin" and "arm" in platform.processor().lower(): + if backend.backend() == "tensorflow": + from keras.src.utils.module_utils import tensorflow as tf + + if tf.config.list_physical_devices("GPU"): + return False + # XLA not supported by some layers + if all(x.supports_jit for x in model._flatten_layers()): + if backend.backend() == "tensorflow": + from tensorflow.python.framework.config import ( + is_op_determinism_enabled, + ) + + if is_op_determinism_enabled(): + # disable XLA with determinism enabled since not all ops are + # supported by XLA with determinism enabled. + return False + return True + return False diff --git a/keras/src/trainers/trainer_test.py b/keras/src/trainers/trainer_test.py new file mode 100644 index 000000000000..51833cb55fcc --- /dev/null +++ b/keras/src/trainers/trainer_test.py @@ -0,0 +1,2925 @@ +from unittest import mock + +import jax +import numpy as np +import pytest +from absl.testing import parameterized + +import keras +from keras.src import backend +from keras.src import initializers +from keras.src import layers +from keras.src import losses +from keras.src import metrics +from keras.src import models +from keras.src import ops +from keras.src import optimizers +from keras.src import testing +from keras.src.backend import config +from keras.src.backend.common.symbolic_scope import in_symbolic_scope +from keras.src.callbacks.callback import Callback +from keras.src.distribution.distribution_lib import DataParallel +from keras.src.distribution.distribution_lib import DeviceMesh +from keras.src.optimizers.rmsprop import RMSprop +from keras.src.testing import test_case +from keras.src.testing.test_utils import named_product +from keras.src.trainers.data_adapters import py_dataset_adapter + +if backend.backend() == "jax": + from keras.src.backend.jax.trainer import JAXTrainer as Trainer + from keras.src.distribution import DataParallel + from keras.src.distribution import DeviceMesh +elif backend.backend() == "torch": + from keras.src.backend.torch.trainer import TorchTrainer as Trainer +elif backend.backend() == "tensorflow": + from keras.src.backend.tensorflow.trainer import ( + TensorFlowTrainer as Trainer, + ) +elif backend.backend() == "numpy": + from keras.src.backend.numpy.trainer import NumpyTrainer as Trainer +elif backend.backend() == "openvino": + from keras.src.backend.openvino.trainer import OpenVINOTrainer as Trainer +else: + raise ImportError(f"Invalid backend: {backend.backend()}") + + +# A model is just a layer mixed in with a Trainer. +class ExampleModel(Trainer, layers.Dense): + def __init__(self, units): + layers.Dense.__init__( + self, + units=units, + use_bias=False, + kernel_initializer=initializers.Ones(), + ) + Trainer.__init__(self) + + +class CustomTrainTestStepModel(ExampleModel): + def train_step(self, data): + logs = super().train_step(data) + logs["my_custom_metric"] = 10.0 + return logs + + def test_step(self, data): + logs = super().test_step(data) + logs["my_custom_metric"] = 5.0 + return logs + + +class JaxCustomTrainTestStepModel(ExampleModel): + def train_step(self, state, data): + logs, state = super().train_step(state, data) + logs["my_custom_metric"] = 10.0 + return logs, state + + def test_step(self, state, data): + logs, state = super().test_step(state, data) + logs["my_custom_metric"] = 5.0 + return logs, state + + +class StructModel(Trainer, layers.Layer): + def __init__(self, units): + layers.Layer.__init__(self) + Trainer.__init__(self) + self.dense_1 = layers.Dense( + units, + use_bias=False, + kernel_initializer=initializers.Ones(), + ) + self.dense_2 = layers.Dense( + units, + use_bias=False, + kernel_initializer=initializers.Ones(), + ) + + def call(self, x): + return { + "y_one": self.dense_1(x["x_one"]), + "y_two": self.dense_2(x["x_two"]), + } + + +class ListInputModel(Trainer, layers.Layer): + def __init__(self, units): + layers.Layer.__init__(self) + Trainer.__init__(self) + self.dense_1 = layers.Dense( + units, + use_bias=False, + kernel_initializer=initializers.Ones(), + ) + self.dense_2 = layers.Dense( + units, + use_bias=False, + kernel_initializer=initializers.Ones(), + ) + + def call(self, x): + assert isinstance(x, (list, tuple)) + return self.dense_1(x[0]) + self.dense_2(x[1]) + + +class ListOutputModel(Trainer, layers.Layer): + def __init__(self, units): + layers.Layer.__init__(self) + Trainer.__init__(self) + self.dense_1 = layers.Dense( + units, + use_bias=False, + kernel_initializer=initializers.Ones(), + ) + self.dense_2 = layers.Dense( + units, + use_bias=False, + kernel_initializer=initializers.Ones(), + ) + + def call(self, x): + return [self.dense_1(x), self.dense_2(x)] + + +class TrainingTestingLayer(Trainer, layers.Layer): + def __init__(self, **kwargs): + layers.Layer.__init__(self, **kwargs) + Trainer.__init__(self) + + def call(self, x, training=False): + if training: + return x + return x * 0 + + +class TestPyDataset(py_dataset_adapter.PyDataset): + def __init__(self, infinite=False, **kwargs): + super().__init__(**kwargs) + self.infinite = infinite + + @property + def num_batches(self): + return None if self.infinite else 20 + + def __getitem__(self, idx): + CPU_DEVICES = { + "tensorflow": "CPU:0", + "jax": "cpu:0", + "torch": "cpu", + } + with backend.device(CPU_DEVICES[backend.backend()]): + return ops.ones((5, 4)), ops.zeros((5, 3)) + + +def create_dataset(dataset_type, dataset_kwargs): + if dataset_type == "np_array": + return np.ones((100, 4)), np.zeros((100, 3)) + elif dataset_type == "native_array": + return ops.ones((100, 4)), ops.zeros((100, 3)) + elif dataset_type == "py_dataset": + return TestPyDataset(**dataset_kwargs), None + elif dataset_type == "tf_dataset": + import tensorflow as tf + + dataset = tf.data.Dataset.from_tensor_slices( + (tf.ones((100, 4)), tf.zeros((100, 3))) + ).batch(5) + if dataset_kwargs.get("infinite", False): + dataset = dataset.repeat() + return dataset, None + elif dataset_type == "torch_dataloader": + import torch + + class TestIterableDataset(torch.utils.data.IterableDataset): + def __iter__(self): + for _ in range(20): + yield torch.ones((5, 4)), torch.zeros((5, 3)) + + class TestIterableDatasetWithLen(TestIterableDataset): + def __len__(self): + return 20 + + if dataset_kwargs.get("iterable", False): + if dataset_kwargs.get("has_len", False): + dataset = TestIterableDatasetWithLen() + else: + dataset = TestIterableDataset() + return torch.utils.data.DataLoader(dataset), None + else: + dataset = torch.utils.data.TensorDataset( + torch.ones((100, 4)), torch.zeros((100, 3)) + ) + return torch.utils.data.DataLoader(dataset, batch_size=5), None + elif dataset_type == "generator": + + def generate_finite(): + for _ in range(20): + yield ops.ones((5, 4)), ops.zeros((5, 3)) + + def generate_infinite(): + while True: + yield ops.ones((5, 4)), ops.zeros((5, 3)) + + if dataset_kwargs.get("infinite", False): + return generate_infinite(), None + else: + return generate_finite(), None + elif dataset_type == "grain_datast": + import grain + + class TestIterableDataset(grain.sources.RandomAccessDataSource): + def __init__(self): + super().__init__() + self.x = np.ones((100, 4)).astype("float32") + self.y = np.zeros((100, 3)).astype("float32") + + def __len__(self): + return len(self.x) + + def __getitem__(self, idx): + return self.x[idx], self.y[idx] + + if dataset_kwargs.get("use_dataloader", False): + source = TestIterableDataset() + dataloader = grain.DataLoader( + data_source=source, + sampler=grain.samplers.IndexSampler(len(source), num_epochs=1), + operations=[grain.transforms.Batch(batch_size=5)], + ) + return dataloader, None + else: + dataset = grain.MapDataset.source(TestIterableDataset()) + if dataset_kwargs.get("has_len", False): + dataset = dataset.to_iter_dataset() + dataset = dataset.batch(5) + return dataset, None + else: + raise ValueError(f"Invalid dataset type {dataset_type}") + + +def sparse_generator(generator_type): + if generator_type == "scipy": + import scipy + + for _ in range(4): + x = scipy.sparse.random(2, 4, density=0.25, dtype="float32") + y = np.random.rand(2, 3).astype("float32") + yield x, y + elif generator_type == "tf": + import tensorflow as tf + + for _ in range(4): + x = tf.random.uniform((2, 4), dtype="float32") + x = tf.sparse.from_dense(tf.nn.dropout(x, 0.25)) + y = tf.random.uniform((2, 3), dtype="float32") + yield x, y + elif generator_type == "jax": + import jax + import jax.experimental.sparse as jax_sparse + + for _ in range(4): + seed = jax.random.PRNGKey(0) + x = jax_sparse.random_bcoo(seed, (2, 4), dtype="float32", nse=0.25) + y = jax.random.uniform(seed, (2, 3), dtype="float32") + yield x, y + else: + raise ValueError(f"Invalid generator type {generator_type}") + + +class EpochAgnosticMeanSquaredError(metrics.MeanSquaredError): + def __init__(self): + super().__init__(name="mse") + super().reset_state() + + def reset_state(self): + # prevent reset at each starting epoch + pass + + +class StepObserver(Callback): + def __init__(self): + super().__init__() + self.begin_count = 0 + self.end_count = 0 + self.epoch_begin_count = 0 + self.epoch_end_count = 0 + self.batch_loss_history = [] + + def on_epoch_begin(self, epoch, logs=None): + self.epoch_begin_count += 1 + + def on_epoch_end(self, epoch, logs=None): + self.epoch_end_count += 1 + + def on_batch_begin(self, batch, logs=None): + self.begin_count += 1 + + def on_batch_end(self, batch, logs=None): + self.end_count += 1 + self.batch_loss_history.append(logs["mse"]) + + +class StepCount(Callback): + def __init__(self, steps_per_execution=1): + super().__init__() + self.begin_count = 0 + self.end_count = 0 + self.epoch_begin_count = 0 + self.epoch_end_count = 0 + self.steps_per_execution = steps_per_execution + + def on_epoch_begin(self, epoch, logs=None): + self.begin_count = 0 + self.end_count = 0 + self.epoch_begin_count += 1 + + def on_epoch_end(self, epoch, logs=None): + self.epoch_end_count += 1 + + def on_batch_begin(self, batch, logs=None): + assert batch == self.begin_count * self.steps_per_execution + self.begin_count += 1 + + def on_batch_end(self, batch, logs=None): + self.end_count += 1 + assert batch == self.end_count * self.steps_per_execution - 1 + + +class TestTrainer(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_metric_tracking(self): + class ModelWithMetric(Trainer, layers.Dense): + def __init__(self, units): + layers.Dense.__init__( + self, + units=units, + use_bias=False, + kernel_initializer=initializers.Ones(), + ) + Trainer.__init__(self) + self.my_metric = metrics.MeanSquaredError(name="my_metric") + + model = ModelWithMetric(units=3) + model.compile( + optimizer=optimizers.SGD(), + loss=losses.MeanSquaredError(), + metrics=[metrics.MeanSquaredError()], + ) + x = np.ones((2, 4)) + y = np.zeros((2, 3)) + # Fit the model to make sure compile_metrics are built + model.fit(x, y, batch_size=2, epochs=1) + + # The model should have 3 metrics: loss_tracker, compile_metrics, + # my_metric. + self.assertEqual(len(model.metrics), 3) + self.assertEqual(model.metrics[0], model._loss_tracker) + self.assertEqual(model.metrics[1], model._compile_metrics) + self.assertEqual(model.metrics[2], model.my_metric) + + # All metrics should have their weights created + self.assertEqual(len(model._loss_tracker.variables), 2) + self.assertEqual(len(model._compile_metrics.variables), 2) + self.assertEqual(len(model.my_metric.variables), 2) + + # And those weights are tracked at the model level + self.assertEqual(len(model.metrics_variables), 6) + self.assertLen(model.non_trainable_variables, 0) + + # Models with only weighted_metrics should have the same 3 metrics + model_weighted = ModelWithMetric(units=3) + model_weighted.compile( + optimizer=optimizers.SGD(), + loss=losses.MeanSquaredError(), + weighted_metrics=[metrics.MeanSquaredError()], + ) + model_weighted.fit( + x, + y, + batch_size=2, + epochs=1, + sample_weight=np.ones(2), + ) + self.assertEqual(len(model_weighted.metrics), 3) + + def test_nested_trainer_metrics(self): + # https://github.com/keras-team/keras/issues/20188 + model = ExampleModel(units=3) + model.compile( + optimizer=optimizers.SGD(), + loss=losses.MeanSquaredError(), + metrics=[metrics.MeanSquaredError()], + ) + self.assertLen(model.metrics, 2) + self.assertEqual(model.metrics[0], model._loss_tracker) + self.assertEqual(model.metrics[1], model._compile_metrics) + + inputs = keras.Input((4,)) + outputs = model(inputs) + outputs = layers.Dense(8)(outputs) + new_model = models.Model(inputs, outputs) + new_model.compile( + optimizer=optimizers.SGD(), + loss=losses.MeanSquaredError(), + metrics=[metrics.MeanSquaredError()], + ) + self.assertLen(new_model.metrics, 2) + self.assertEqual(new_model.metrics[0], new_model._loss_tracker) + self.assertEqual(new_model.metrics[1], new_model._compile_metrics) + + def test_nested_trainer_metrics_without_compile(self): + model = ExampleModel(units=3) + self.assertLen(model.metrics, 0) + + inputs = keras.Input((4,)) + outputs = model(inputs) + outputs = layers.Dense(8)(outputs) + new_model = models.Model(inputs, outputs) + new_model.compile( + optimizer=optimizers.SGD(), + loss=losses.MeanSquaredError(), + metrics=[metrics.MeanSquaredError()], + ) + self.assertLen(new_model.metrics, 2) + self.assertEqual(new_model.metrics[0], new_model._loss_tracker) + self.assertEqual(new_model.metrics[1], new_model._compile_metrics) + + def test_multiple_compiles(self): + # https://github.com/keras-team/keras/issues/20474 + model1 = ExampleModel(units=3) + model2 = ExampleModel(units=3) + model1.compile( + optimizer=optimizers.SGD(), + loss=losses.MeanSquaredError(), + metrics=[metrics.MeanSquaredError()], + ) + + # Combine these 2 models into `combined`. + inputs = keras.Input(shape=(4,)) + x = model1(inputs) + outputs = model2(x) + combined = models.Model(inputs, outputs) + combined.compile( + optimizer=optimizers.SGD(), + loss=losses.MeanSquaredError(), + metrics=[metrics.MeanSquaredError()], + ) + + self.assertLen(model1.metrics, 2) + self.assertIsNotNone(model1._loss_tracker) + self.assertEqual(model1.metrics[0], model1._loss_tracker) + self.assertEqual(model1.metrics[1], model1._compile_metrics) + + # `combined.metrics` will not include `model1.metrics`. + self.assertLen(combined.metrics, 2) + self.assertIsNotNone(combined._loss_tracker) + self.assertEqual(combined.metrics[0], combined._loss_tracker) + self.assertEqual(combined.metrics[1], combined._compile_metrics) + + @pytest.mark.skipif( + backend.backend() != "torch", + reason="torch backend runs in eager mode for jit_compile='auto'", + ) + def test_compile_eager_vs_jit_torch(self): + model = ExampleModel(units=3) + model.compile(jit_compile="auto") + # torch trainer en/disables torch.compile only based on the value of + # model.jit_compile (not model.run_eagerly) + self.assertFalse(model.run_eagerly) + self.assertFalse(model.jit_compile) + + @parameterized.named_parameters( + [ + ("eager", True, False, False), + ("graph_fn", False, False, False), + ("jit", False, True, False), + ("steps_per_epoch_eager", True, False, True), + ("steps_per_epoch_graph_fn", False, False, True), + ("steps_per_epoch_jit", False, True, True), + ] + ) + @pytest.mark.requires_trainable_backend + def test_fit_flow(self, run_eagerly, jit_compile, use_steps_per_epoch): + if not run_eagerly and not jit_compile and use_steps_per_epoch: + if False and backend.backend() == "tensorflow": + self.skipTest( + "TODO: Graph mode without XLA in TF backend leads to " + "unexpected logs, need further checks." + ) + if jit_compile and backend.backend() == "torch": + self.skipTest( + "TODO: compilation with torch backend leads to " + "unexpected logs, need further checks." + ) + + model = ExampleModel(units=3) + epochs = 3 + batch_size = 20 + steps_per_epoch = 7 + dataset_size = batch_size * (steps_per_epoch - 2) + x = np.ones((dataset_size, 4)) + y = np.zeros((dataset_size, 3)) + + model.compile( + optimizer=optimizers.SGD(), + loss=losses.MeanSquaredError(), + metrics=[metrics.MeanSquaredError()], + run_eagerly=run_eagerly, + jit_compile=jit_compile, + ) + history = model.fit( + x, + y, + batch_size=batch_size, + steps_per_epoch=steps_per_epoch if use_steps_per_epoch else None, + epochs=epochs, + ) + history = history.history + self.assertIn("loss", history) + self.assertIn("mean_squared_error", history) + self.assertAllClose( + history["mean_squared_error"], + [14.5, 11.5, 8.5], + atol=1.0, # TODO: results vary across backends + ) + + @parameterized.named_parameters( + [ + { + "testcase_name": "np_array", + "dataset_type": "np_array", + "fit_kwargs": {"batch_size": 5}, + }, + { + "testcase_name": "native_array", + "dataset_type": "native_array", + "fit_kwargs": {"batch_size": 5}, + }, + { + "testcase_name": "py_dataset", + "dataset_type": "py_dataset", + }, + { + "testcase_name": "py_dataset_cw", + "dataset_type": "py_dataset", + "fit_kwargs": {"class_weight": {0: 1, 1: 2}}, + }, + { + "testcase_name": "py_dataset_infinite", + "dataset_type": "py_dataset", + "dataset_kwargs": {"infinite": True}, + "fit_kwargs": {"steps_per_epoch": 20}, + }, + { + "testcase_name": "py_dataset_infinite_cw", + "dataset_type": "py_dataset", + "dataset_kwargs": {"infinite": True}, + "fit_kwargs": { + "steps_per_epoch": 20, + "class_weight": {0: 1, 1: 2}, + }, + }, + { + "testcase_name": "py_dataset_multithreading", + "dataset_type": "py_dataset", + "dataset_kwargs": {"workers": 2}, + }, + { + "testcase_name": "py_dataset_multithreading_cw", + "dataset_type": "py_dataset", + "dataset_kwargs": {"workers": 2}, + "fit_kwargs": {"class_weight": {0: 1, 1: 2}}, + }, + { + "testcase_name": "py_dataset_multithreading_infinite", + "dataset_type": "py_dataset", + "dataset_kwargs": {"infinite": True, "workers": 2}, + "fit_kwargs": {"steps_per_epoch": 20}, + }, + { + "testcase_name": "py_dataset_multiprocessing", + "dataset_type": "py_dataset", + "dataset_kwargs": {"workers": 2, "use_multiprocessing": True}, + }, + { + "testcase_name": "py_dataset_multiprocessing_cw", + "dataset_type": "py_dataset", + "dataset_kwargs": {"workers": 2, "use_multiprocessing": True}, + "fit_kwargs": {"class_weight": {0: 1, 1: 2}}, + }, + { + "testcase_name": "py_dataset_multiprocessing_infinite", + "dataset_type": "py_dataset", + "dataset_kwargs": { + "infinite": True, + "workers": 2, + "use_multiprocessing": True, + }, + "fit_kwargs": {"steps_per_epoch": 20}, + }, + { + "testcase_name": "tf_dataset", + "dataset_type": "tf_dataset", + }, + { + "testcase_name": "tf_dataset_infinite", + "dataset_type": "tf_dataset", + "dataset_kwargs": {"infinite": True}, + "fit_kwargs": {"steps_per_epoch": 20}, + }, + { + "testcase_name": "torch_dataloader_tensor", + "dataset_type": "torch_dataloader", + }, + { + "testcase_name": "torch_dataloader_iterable", + "dataset_type": "torch_dataloader", + "dataset_kwargs": {"iterable": True, "has_len": False}, + }, + { + "testcase_name": "torch_dataloader_iterable_with_len", + "dataset_type": "torch_dataloader", + "dataset_kwargs": {"iterable": True, "has_len": True}, + }, + { + "testcase_name": "generator", + "dataset_type": "generator", + }, + { + "testcase_name": "generator_infinite", + "dataset_type": "generator", + "dataset_kwargs": {"infinite": True}, + "fit_kwargs": {"steps_per_epoch": 20}, + }, + { + "testcase_name": "grain_datast", + "dataset_type": "grain_datast", + "dataset_kwargs": {"has_len": False}, + }, + { + "testcase_name": "grain_datast_with_len", + "dataset_type": "grain_datast", + "dataset_kwargs": {"has_len": True}, + }, + { + "testcase_name": "grain_dataloader", + "dataset_type": "grain_datast", + "dataset_kwargs": {"use_dataloader": True}, + }, + ] + ) + @pytest.mark.requires_trainable_backend + def test_fit_with_data_adapter( + self, dataset_type, dataset_kwargs={}, fit_kwargs={} + ): + jit_compile = True + if ( + dataset_kwargs.get("use_multiprocessing", False) + and backend.backend() == "jax" + ): + pytest.skip("Multiprocessing not supported with JAX backend") + if dataset_type == "grain_datast" and backend.backend() == "torch": + # Grain datasets are not supported with torch + jit_compile. + jit_compile = False + + model = ExampleModel(units=3) + optimizer = optimizers.Adagrad() + model.compile( + optimizer=optimizer, + loss=losses.MeanSquaredError(), + metrics=[metrics.MeanSquaredError()], + jit_compile=jit_compile, + ) + x, y = create_dataset(dataset_type, dataset_kwargs) + model.fit(x, y, epochs=3, **fit_kwargs) + + @parameterized.named_parameters( + [ + ("eager", True, False, False), + ("graph_fn", False, False, False), + ("jit", False, True, False), + ("steps_per_epoch_eager", True, False, True), + ("steps_per_epoch_graph_fn", False, False, True), + ("steps_per_epoch_jit", False, True, True), + ] + ) + @pytest.mark.requires_trainable_backend + def test_fit_with_val_split( + self, run_eagerly, jit_compile, use_steps_per_epoch + ): + if not run_eagerly and not jit_compile and use_steps_per_epoch: + if backend.backend() == "tensorflow": + self.skipTest( + "TODO: Graph mode without XLA in TF backend leads to " + "unexpected logs, need further checks." + ) + + model = ExampleModel(units=3) + epochs = 3 + batch_size = 20 + steps_per_epoch = 7 + dataset_size = batch_size * (steps_per_epoch - 2) + x = np.ones((dataset_size, 4)) + y = np.zeros((dataset_size, 3)) + + model.compile( + optimizer=optimizers.SGD(), + loss=losses.MeanSquaredError(), + metrics=[metrics.MeanSquaredError()], + run_eagerly=run_eagerly, + jit_compile=jit_compile, + ) + history = model.fit( + x, + y, + batch_size=batch_size, + steps_per_epoch=steps_per_epoch if use_steps_per_epoch else None, + epochs=epochs, + validation_split=0.2, + ) + history = history.history + self.assertIn("loss", history) + self.assertIn("val_loss", history) + + # Test with backend-native tensors. + x = ops.ones((dataset_size, 4)) + y = ops.zeros((dataset_size, 3)) + history = model.fit( + x, + y, + batch_size=batch_size, + steps_per_epoch=steps_per_epoch if use_steps_per_epoch else None, + epochs=epochs, + validation_split=0.2, + ) + history = history.history + self.assertIn("loss", history) + self.assertIn("val_loss", history) + + @pytest.mark.requires_trainable_backend + def test_fit_with_custom_train_step(self): + if backend.backend() == "jax": + model = JaxCustomTrainTestStepModel(units=3) + else: + model = CustomTrainTestStepModel(units=3) + x = np.ones((100, 4)) + y = np.zeros((100, 3)) + batch_size = 16 + + model.compile( + optimizer=optimizers.SGD(), + loss=losses.MeanSquaredError(), + metrics=[metrics.MeanSquaredError()], + ) + history = model.fit(x, y, batch_size=batch_size) + history = history.history + self.assertIn("loss", history) + self.assertIn("mean_squared_error", history) + self.assertAllClose(history["my_custom_metric"], 10.0) + + @parameterized.named_parameters( + named_product( + generator_type=["tf", "jax", "scipy"], mode=["eager", "graph"] + ) + ) + @pytest.mark.skipif( + not backend.SUPPORTS_SPARSE_TENSORS, + reason="Backend does not support sparse tensors.", + ) + def test_fit_sparse(self, generator_type, mode): + model = ExampleModel(units=3) + optimizer = optimizers.Adagrad() + model.compile( + optimizer=optimizer, + loss=losses.MeanSquaredError(), + metrics=[metrics.MeanSquaredError()], + run_eagerly=(mode == "eager"), + jit_compile=False, + ) + dataset = sparse_generator(generator_type) + + sparse_variable_updates = False + + def mock_optimizer_assign(variable, value): + nonlocal sparse_variable_updates + if value.__class__.__name__ == "IndexedSlices": + sparse_variable_updates = True + + with mock.patch.object( + optimizer, "assign_sub", autospec=True + ) as optimizer_assign_sub: + optimizer_assign_sub.side_effect = mock_optimizer_assign + model.fit(dataset) + + # JAX does not produce sparse gradients the way we use it. + if backend.backend() != "jax": + # Verify tensors did not get densified along the way. + self.assertTrue(sparse_variable_updates) + + @parameterized.named_parameters( + [ + ("eager", True, False), + ("graph_fn", False, False), + ("jit", False, True), + ] + ) + def test_evaluate_flow(self, run_eagerly, jit_compile): + model = ExampleModel(units=3) + x = np.ones((100, 4)) + y = np.zeros((100, 3)) + batch_size = 16 + + model.compile( + optimizer=optimizers.SGD(), + loss=losses.MeanSquaredError(), + metrics=[metrics.MeanSquaredError()], + run_eagerly=run_eagerly, + jit_compile=jit_compile, + ) + output = model.evaluate(x, y, batch_size=batch_size) + self.assertAllClose(output, [16.0, 16.0]) + output = model.evaluate(x, y, batch_size=batch_size, return_dict=True) + self.assertIsInstance(output, dict) + self.assertIn("loss", output) + self.assertIn("mean_squared_error", output) + self.assertAllClose(output["mean_squared_error"], 16.0) + + @parameterized.named_parameters([("flat", False), ("dict", True)]) + @pytest.mark.requires_trainable_backend + def test_evaluate_with_custom_test_step(self, return_dict): + if backend.backend() == "jax": + model = JaxCustomTrainTestStepModel(units=3) + else: + model = CustomTrainTestStepModel(units=3) + x = np.ones((100, 4)) + y = np.zeros((100, 3)) + batch_size = 16 + + model.compile( + optimizer=optimizers.SGD(), + loss=losses.MeanSquaredError(), + metrics=[metrics.MeanSquaredError()], + ) + output = model.evaluate( + x, y, batch_size=batch_size, return_dict=return_dict + ) + self.assertLen(output, 3) + if return_dict: + self.assertAllClose(output["my_custom_metric"], 5.0) + else: + self.assertAllClose(output[-1], 5.0) # Custom metrics go last. + + @parameterized.named_parameters( + named_product( + generator_type=["tf", "jax", "scipy"], mode=["eager", "graph"] + ) + ) + @pytest.mark.skipif( + not backend.SUPPORTS_SPARSE_TENSORS, + reason="Backend does not support sparse tensors.", + ) + def test_evaluate_sparse(self, generator_type, mode): + model = ExampleModel(units=3) + model.compile( + optimizer=optimizers.Adagrad(), + loss=losses.MeanSquaredError(), + metrics=[metrics.MeanSquaredError()], + run_eagerly=(mode == "eager"), + jit_compile=False, + ) + dataset = sparse_generator(generator_type) + model.evaluate(dataset) + + @parameterized.named_parameters( + [ + ("eager", True, False), + ("graph_fn", False, False), + ("jit", False, True), + ] + ) + def test_predict_flow(self, run_eagerly, jit_compile): + # Test basic example + model = ExampleModel(units=3) + model.run_eagerly = run_eagerly + model.jit_compile = jit_compile + + x = np.ones((100, 4)) + batch_size = 16 + outputs = model.predict(x, batch_size=batch_size) + self.assertAllClose(outputs, 4 * np.ones((100, 3))) + + @parameterized.named_parameters( + [ + ("eager", True, False), + ("graph_fn", False, False), + ("jit", False, True), + ] + ) + def test_predict_flow_struct(self, run_eagerly, jit_compile): + # Test with input/output structs + model = StructModel(units=3) + model.run_eagerly = run_eagerly + model.jit_compile = jit_compile + + x = { + "x_one": np.ones((100, 4)), + "x_two": np.ones((100, 4)), + } + batch_size = 16 + outputs = model.predict(x, batch_size=batch_size) + self.assertIsInstance(outputs, dict) + self.assertEqual(len(outputs), 2) + self.assertAllClose(outputs["y_one"], 4 * np.ones((100, 3))) + self.assertAllClose(outputs["y_two"], 4 * np.ones((100, 3))) + + @parameterized.named_parameters( + named_product( + generator_type=["tf", "jax", "scipy"], mode=["eager", "graph"] + ) + ) + @pytest.mark.skipif( + not backend.SUPPORTS_SPARSE_TENSORS, + reason="Backend does not support sparse tensors.", + ) + def test_predict_sparse(self, generator_type, mode): + model = ExampleModel(units=3) + model.compile( + optimizer=optimizers.Adagrad(), + loss=losses.MeanSquaredError(), + metrics=[metrics.MeanSquaredError()], + run_eagerly=(mode == "eager"), + jit_compile=False, + ) + dataset = sparse_generator(generator_type) + dataset_size = sum( + [batch[1].shape[0] for batch in sparse_generator(generator_type)] + ) + y = model.predict(dataset) + self.assertEqual(len(y), dataset_size) + + @pytest.mark.skipif( + backend.backend() != "jax", + reason="Memory optimization is only implemented in JAX", + ) + def test_fit_eval_flow_for_jax_model_weights(self): + model = ExampleModel(units=3) + epochs = 3 + batch_size = 20 + steps_per_epoch = 7 + dataset_size = batch_size * (steps_per_epoch - 2) + x = np.ones((dataset_size, 4)) + y = np.zeros((dataset_size, 3)) + + class ModelWeightCheck(Callback): + def __init__(self): + super().__init__() + + # Note that we access model via self._model since self.model + # will trigger a sync of the jax training state back to the model. + def on_train_batch_end(self, batch, logs=None): + for v in self._model.trainable_variables: + assert v._value is None + for v in self._model.non_trainable_variables: + assert v._value is None + for v in self._model.optimizer.variables: + assert v._value is None + for v in self._model.metrics_variables: + assert v._value is None + + def on_test_batch_end(self, batch, logs=None): + for v in self._model.non_trainable_variables: + assert v._value is None + for v in self._model.metrics_variables: + assert v._value is None + + model.compile( + optimizer=optimizers.SGD(), + loss=losses.MeanSquaredError(), + metrics=[metrics.MeanSquaredError()], + ) + + model.fit( + x, + y, + batch_size=batch_size, + steps_per_epoch=steps_per_epoch, + epochs=epochs, + callbacks=[ModelWeightCheck()], + ) + + model.evaluate( + x, + y, + batch_size=batch_size, + callbacks=[ModelWeightCheck()], + ) + + @parameterized.named_parameters( + named_product( + steps_per_execution=[3, 101], mode=["eager", "non_jit", "jit"] + ) + ) + @pytest.mark.requires_trainable_backend + @pytest.mark.skipif( + backend.backend() == "torch", + reason="`steps_per_execution` not implemented for torch yet", + ) + def test_steps_per_execution_steps_count(self, steps_per_execution, mode): + data_size = 100 + batch_size = 16 + epochs = 2 + + x = np.ones((data_size, 4)) + y = np.ones((data_size, 1)) + + model = ExampleModel(units=1) + model.compile( + loss="mse", + optimizer="sgd", + steps_per_execution=steps_per_execution, + run_eagerly=(mode == "eager"), + jit_compile=(mode == "jit"), + ) + step_count = StepCount(steps_per_execution) + + history = model.fit( + x=x, + y=y, + batch_size=batch_size, + epochs=epochs, + callbacks=[step_count], + verbose=0, + ) + + self.assertEqual( + step_count.begin_count, + 1 + (data_size - 1) // (steps_per_execution * batch_size), + ) + self.assertEqual(step_count.end_count, step_count.begin_count) + self.assertEqual(step_count.epoch_begin_count, epochs) + self.assertEqual( + step_count.epoch_end_count, step_count.epoch_begin_count + ) + + model_2 = ExampleModel(units=1) + model_2.compile( + loss="mse", + optimizer="sgd", + steps_per_execution=1, + run_eagerly=(mode == "eager"), + jit_compile=(mode == "jit"), + ) + history_2 = model_2.fit( + x=x, y=y, batch_size=batch_size, epochs=epochs, verbose=0 + ) + + self.assertAllClose(history.history["loss"], history_2.history["loss"]) + self.assertAllClose(model.get_weights(), model_2.get_weights()) + self.assertAllClose( + model.predict(x, batch_size=batch_size), + model_2.predict(x, batch_size=batch_size), + ) + self.assertAllClose(model.evaluate(x, y), model_2.evaluate(x, y)) + + @parameterized.named_parameters( + named_product(steps_per_execution=[3, 8, 32]) + ) + @pytest.mark.requires_trainable_backend + @pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="`unrolled_steps_per_execution` is only " + "available with the tensorflow backend.", + ) + def test_steps_per_execution_unrolled_steps_steps_count( + self, steps_per_execution + ): + data_size = 100 + batch_size = 16 + epochs = 2 + unrolled_steps_per_execution = 8 + + x = np.ones((data_size, 4)) + y = np.ones((data_size, 1)) + + model = ExampleModel(units=1) + model.compile( + loss="mse", + optimizer="sgd", + steps_per_execution=steps_per_execution, + jit_compile=True, + ) + step_count = StepCount(steps_per_execution) + model.unrolled_steps_per_execution = unrolled_steps_per_execution + history = model.fit( + x=x, + y=y, + batch_size=batch_size, + epochs=epochs, + callbacks=[step_count], + verbose=0, + ) + + self.assertEqual( + step_count.begin_count, + 1 + (data_size - 1) // (steps_per_execution * batch_size), + ) + self.assertEqual(step_count.end_count, step_count.begin_count) + self.assertEqual(step_count.epoch_begin_count, epochs) + self.assertEqual( + step_count.epoch_end_count, step_count.epoch_begin_count + ) + + model_2 = ExampleModel(units=1) + model_2.compile( + loss="mse", + optimizer="sgd", + steps_per_execution=steps_per_execution, + jit_compile=True, + ) + model_2.unrolled_steps_per_execution = 1 + history_2 = model_2.fit( + x=x, y=y, batch_size=batch_size, epochs=epochs, verbose=0 + ) + + self.assertAllClose(history.history["loss"], history_2.history["loss"]) + self.assertAllClose(model.get_weights(), model_2.get_weights()) + self.assertAllClose( + model.predict(x, batch_size=batch_size), + model_2.predict(x, batch_size=batch_size), + ) + self.assertAllClose(model.evaluate(x, y), model_2.evaluate(x, y)) + + @parameterized.named_parameters( + named_product( + steps_per_execution=[1, 50], mode=["eager", "non_jit", "jit"] + ) + ) + def test_predict_preserve_order(self, steps_per_execution, mode): + if steps_per_execution > 1 and backend.backend() == "torch": + self.skipTest("`steps_per_execution` not implemented for torch yet") + + def generate_uneven_batches(): + batch_sizes = [2, 3, 4] + + def gen_i(): + for i in range(100): + yield i + + iterator = iter(gen_i()) + j = 0 + while True: + batch_size = batch_sizes[j % len(batch_sizes)] + try: + batch = np.array( + [next(iterator) for _ in range(batch_size)] + ) + except StopIteration: + break + j += 1 + yield batch + + from keras.src.utils.module_utils import tensorflow as tf + + dataset = tf.data.Dataset.from_generator( + generate_uneven_batches, + output_signature=tf.TensorSpec((None,), dtype=tf.int32), + ) + x = keras.layers.Input(shape=()) + y = keras.layers.Identity()(x) + model = keras.Model(x, y) + model.compile( + loss="mse", + optimizer="sgd", + steps_per_execution=steps_per_execution, + run_eagerly=(mode == "eager"), + jit_compile=(mode == "jit"), + ) + + preds = model.predict(x=dataset, verbose=0) + + self.assertAllEqual(preds, np.arange(len(preds), dtype=np.float32)) + + @parameterized.named_parameters( + named_product( + steps_per_execution=[1, 50], mode=["eager", "non_jit", "jit"] + ) + ) + def test_predict_generator(self, steps_per_execution, mode): + if steps_per_execution > 1 and backend.backend() == "torch": + self.skipTest("`steps_per_execution` not implemented for torch yet") + + batch_size = 2 + + def generate_batches(): + def gen_i(): + for i in range(10): + yield i + + iterator = iter(gen_i()) + j = 0 + while True: + try: + batch = np.array( + [next(iterator) for _ in range(batch_size)] + ) + except StopIteration: + break + j += 1 + yield (batch,) + + model = keras.Sequential( + [keras.layers.InputLayer(shape=()), keras.layers.Identity()] + ) + model.compile( + loss="mse", + optimizer="sgd", + steps_per_execution=steps_per_execution, + run_eagerly=(mode == "eager"), + jit_compile=(mode == "jit"), + ) + + preds = model.predict(x=generate_batches(), verbose=0) + self.assertAllEqual( + preds, np.concatenate(list(generate_batches()), axis=1)[0] + ) + + @parameterized.named_parameters( + named_product( + steps_per_execution=[3, 101], mode=["eager", "non_jit", "jit"] + ) + ) + @pytest.mark.requires_trainable_backend + @pytest.mark.skipif( + backend.backend() == "torch", + reason="`steps_per_execution` not implemented for torch yet", + ) + def test_steps_per_execution_steps_count_unknown_dataset_size( + self, steps_per_execution, mode + ): + data_size = 100 + batch_size = 16 + epochs = 2 + + def data_generator(): + x = np.ones((data_size, 4), dtype=np.float32) + y = np.ones((data_size, 1), dtype=np.float32) + for _x, _y in zip(x, y): + yield _x, _y + + import tensorflow as tf + + dataset = tf.data.Dataset.from_generator( + data_generator, + output_signature=( + tf.TensorSpec(shape=(4,), dtype=tf.float32), + tf.TensorSpec(shape=(1,), dtype=tf.float32), + ), + ) + dataset = dataset.batch(batch_size) + + model = ExampleModel(units=1) + model.compile( + loss="mse", + optimizer="sgd", + steps_per_execution=steps_per_execution, + run_eagerly=(mode == "eager"), + jit_compile=(mode == "jit"), + ) + step_count = StepCount(steps_per_execution) + + history = model.fit( + dataset, + epochs=epochs, + callbacks=[step_count], + verbose=0, + ) + + batch_count = 1 + (data_size - 1) // (steps_per_execution * batch_size) + self.assertGreaterEqual(step_count.begin_count, batch_count) + self.assertEqual(step_count.end_count, batch_count) + self.assertEqual(step_count.epoch_begin_count, epochs) + self.assertEqual( + step_count.epoch_end_count, step_count.epoch_begin_count + ) + + model_2 = ExampleModel(units=1) + model_2.compile( + loss="mse", + optimizer="sgd", + steps_per_execution=1, + run_eagerly=(mode == "eager"), + jit_compile=(mode == "jit"), + ) + history_2 = model_2.fit(dataset, epochs=epochs, verbose=0) + + self.assertAllClose(history.history["loss"], history_2.history["loss"]) + self.assertAllClose(model.get_weights(), model_2.get_weights()) + self.assertAllClose( + model.predict(dataset), + model_2.predict(dataset), + ) + self.assertAllClose(model.evaluate(dataset), model_2.evaluate(dataset)) + + @parameterized.named_parameters( + named_product( + steps_per_epoch_test=[ + "match_one_epoch", + "match_multi_epoch", + "not_match_too_low", + "not_match_but_high_enough", + ], + mode=["eager", "non_jit", "jit"], + ) + ) + @pytest.mark.requires_trainable_backend + @pytest.mark.skipif( + backend.backend() == "torch", + reason="`steps_per_execution` not implemented for torch yet", + ) + def test_steps_per_execution_steps_per_epoch( + self, steps_per_epoch_test, mode + ): + batch_size = 8 + epochs = 2 + steps_per_execution = 2 + num_batches = 5 * steps_per_execution + data_size = num_batches * batch_size + + if steps_per_epoch_test == "match_one_epoch": + steps_per_epoch = num_batches + elif steps_per_epoch_test == "match_multi_epoch": + steps_per_epoch = num_batches // steps_per_execution + elif steps_per_epoch_test == "not_match_too_low": + steps_per_epoch = num_batches - steps_per_execution + elif steps_per_epoch_test == "not_match_but_high_enough": + steps_per_epoch = num_batches + steps_per_execution + + x = np.ones((data_size, 4)) + y = np.ones((data_size, 1)) + + model = ExampleModel(units=1) + model.compile( + loss="mse", + optimizer="sgd", + metrics=[EpochAgnosticMeanSquaredError()], + steps_per_execution=steps_per_execution, + run_eagerly=(mode == "eager"), + jit_compile=(mode == "jit"), + ) + step_observer = StepObserver() + + model.fit( + x=x, + y=y, + batch_size=batch_size, + epochs=epochs, + steps_per_epoch=steps_per_epoch, + callbacks=[step_observer], + verbose=0, + ) + if steps_per_epoch_test != "not_match_too_low": + training_batch_count = ( + epochs + * min(steps_per_epoch, num_batches) + // steps_per_execution + ) + else: + complete_epochs = (num_batches // steps_per_execution) // ( + steps_per_epoch // steps_per_execution + ) + remaining_steps = (num_batches // steps_per_execution) % ( + steps_per_epoch // steps_per_execution + ) + steps_cycles = [ + complete_epochs * steps_per_epoch // steps_per_execution, + remaining_steps, + ] * epochs + steps_per_epochs = steps_cycles[:epochs] + training_batch_count = sum(steps_per_epochs) + + self.assertEqual(step_observer.begin_count, training_batch_count) + self.assertEqual(step_observer.end_count, step_observer.begin_count) + self.assertEqual(step_observer.epoch_begin_count, epochs) + self.assertEqual( + step_observer.epoch_end_count, step_observer.epoch_begin_count + ) + + if steps_per_epoch_test != "not_match_too_low": + model_2 = ExampleModel(units=1) + model_2.compile( + loss="mse", + optimizer="sgd", + metrics=[EpochAgnosticMeanSquaredError()], + steps_per_execution=1, + run_eagerly=(mode == "eager"), + jit_compile=(mode == "jit"), + ) + step_observer_2 = StepObserver() + + if steps_per_epoch_test in ( + "not_match_but_high_enough", + "match_one_epoch", + ): + model_2_epochs = epochs + else: + model_2_epochs = 1 + + model_2.fit( + x=x, + y=y, + batch_size=batch_size, + epochs=model_2_epochs, + callbacks=[step_observer_2], + verbose=0, + ) + + losses = step_observer.batch_loss_history + losses_2 = step_observer_2.batch_loss_history[ + steps_per_execution - 1 :: steps_per_execution + ] + self.assertAllClose(losses, losses_2) + self.assertAllClose(model.get_weights(), model_2.get_weights()) + self.assertAllClose( + model.predict(x, batch_size=batch_size), + model_2.predict(x, batch_size=batch_size), + ) + self.assertAllClose(model.evaluate(x, y), model_2.evaluate(x, y)) + + @parameterized.named_parameters( + named_product( + steps_per_epoch_test=[ + "match_one_epoch", + "match_multi_epoch", + "not_match_too_low", + "not_match_but_high_enough", + ], + mode=["eager", "non_jit", "jit"], + ) + ) + @pytest.mark.requires_trainable_backend + def test_steps_per_epoch(self, steps_per_epoch_test, mode): + batch_size = 8 + epochs = 4 + num_batches = 10 + data_size = num_batches * batch_size + + if steps_per_epoch_test == "match_one_epoch": + steps_per_epoch = num_batches + elif steps_per_epoch_test == "match_multi_epoch": + steps_per_epoch = num_batches // (epochs // 2) + elif steps_per_epoch_test == "not_match_too_low": + steps_per_epoch = num_batches - 1 + elif steps_per_epoch_test == "not_match_but_high_enough": + steps_per_epoch = num_batches + 1 + + x = np.ones((data_size, 4)) + y = np.ones((data_size, 1)) + + model = ExampleModel(units=1) + model.compile( + loss="mse", + optimizer="sgd", + metrics=[EpochAgnosticMeanSquaredError()], + run_eagerly=(mode == "eager"), + jit_compile=(mode == "jit"), + ) + step_observer = StepObserver() + + model.fit( + x=x, + y=y, + batch_size=batch_size, + epochs=epochs, + steps_per_epoch=steps_per_epoch, + callbacks=[step_observer], + verbose=0, + ) + if steps_per_epoch_test != "not_match_too_low": + training_batch_count = epochs * min(steps_per_epoch, num_batches) + else: + complete_epochs = num_batches // steps_per_epoch + remaining_steps = num_batches % steps_per_epoch + steps_cycles = [ + complete_epochs * steps_per_epoch, + remaining_steps, + ] * epochs + steps_per_epochs = steps_cycles[:epochs] + training_batch_count = sum(steps_per_epochs) + + self.assertEqual(step_observer.begin_count, training_batch_count) + self.assertEqual(step_observer.end_count, step_observer.begin_count) + self.assertEqual(step_observer.epoch_begin_count, epochs) + self.assertEqual( + step_observer.epoch_end_count, step_observer.epoch_begin_count + ) + + if steps_per_epoch_test != "not_match_too_low": + model_2 = ExampleModel(units=1) + model_2.compile( + loss="mse", + optimizer="sgd", + metrics=[EpochAgnosticMeanSquaredError()], + steps_per_execution=1, + run_eagerly=(mode == "eager"), + jit_compile=(mode == "jit"), + ) + step_observer_2 = StepObserver() + + if steps_per_epoch_test in ( + "not_match_but_high_enough", + "match_one_epoch", + ): + model_2_epochs = epochs + elif steps_per_epoch_test == "match_multi_epoch": + model_2_epochs = epochs // (num_batches // steps_per_epoch) + else: + model_2_epochs = 1 + + model_2.fit( + x=x, + y=y, + batch_size=batch_size, + epochs=model_2_epochs, + callbacks=[step_observer_2], + verbose=0, + ) + + losses = step_observer.batch_loss_history + losses_2 = step_observer_2.batch_loss_history + + self.assertAllClose(losses, losses_2) + self.assertAllClose(model.get_weights(), model_2.get_weights()) + self.assertAllClose( + model.predict(x, batch_size=batch_size), + model_2.predict(x, batch_size=batch_size), + ) + self.assertAllClose(model.evaluate(x, y), model_2.evaluate(x, y)) + + @pytest.mark.requires_trainable_backend + def test_max_epochs_and_steps(self): + batch_size = 8 + epochs = 4 + num_batches = 10 + data_size = num_batches * batch_size + x, y = np.ones((data_size, 4)), np.ones((data_size, 1)) + model = ExampleModel(units=1) + model.compile( + loss="mse", + optimizer="sgd", + metrics=[EpochAgnosticMeanSquaredError()], + ) + step_observer = StepObserver() + model.fit( + x=x, + y=y, + batch_size=batch_size, + epochs=epochs, + callbacks=[step_observer], + verbose=0, + ) + self.assertEqual(step_observer.epoch_begin_count, epochs) + self.assertEqual(step_observer.begin_count, num_batches * epochs) + try: + config.set_max_epochs(2) + config.set_max_steps_per_epoch(3) + step_observer = StepObserver() + model.fit( + x=x, + y=y, + batch_size=batch_size, + epochs=epochs, + callbacks=[step_observer], + verbose=0, + ) + self.assertEqual(step_observer.epoch_begin_count, 2) + self.assertEqual(step_observer.begin_count, 6) + finally: + config.set_max_epochs(None) + config.set_max_steps_per_epoch(None) + + @parameterized.named_parameters( + named_product( + steps_per_epoch_test=[ + "match", + "not_match_too_low", + "not_match_but_high_enough", + ], + mode=["eager", "non_jit", "jit"], + ) + ) + @pytest.mark.requires_trainable_backend + @pytest.mark.skipif( + backend.backend() == "torch", + reason="`steps_per_execution` not implemented for torch yet", + ) + def test_steps_per_execution_steps_per_epoch_unknown_data_size( + self, steps_per_epoch_test, mode + ): + batch_size = 8 + epochs = 2 + steps_per_execution = 2 + num_batches = 5 * epochs * steps_per_execution + data_size = num_batches * batch_size + + if steps_per_epoch_test == "match": + steps_per_epoch = num_batches // epochs + elif steps_per_epoch_test == "not_match_too_low": + steps_per_epoch = num_batches - steps_per_execution + elif steps_per_epoch_test == "not_match_but_high_enough": + steps_per_epoch = num_batches + steps_per_execution + + def data_generator(): + x = np.ones((data_size, 4), dtype=np.float32) + y = np.ones((data_size, 1), dtype=np.float32) + for _x, _y in zip(x, y): + yield _x, _y + + import tensorflow as tf + + dataset = tf.data.Dataset.from_generator( + data_generator, + output_signature=( + tf.TensorSpec(shape=(4,), dtype=tf.float32), + tf.TensorSpec(shape=(1,), dtype=tf.float32), + ), + ) + dataset = dataset.batch(batch_size) + + model = ExampleModel(units=1) + model.compile( + loss="mse", + optimizer="sgd", + metrics=[EpochAgnosticMeanSquaredError()], + steps_per_execution=steps_per_execution, + run_eagerly=(mode == "eager"), + jit_compile=(mode == "jit"), + ) + step_observer = StepObserver() + + model.fit( + dataset, + epochs=epochs, + steps_per_epoch=steps_per_epoch, + callbacks=[step_observer], + verbose=0, + ) + if steps_per_epoch_test != "not_match_too_low": + training_batch_count = ( + epochs + * min(steps_per_epoch, num_batches) + // steps_per_execution + ) + else: + complete_epochs = (num_batches // steps_per_execution) // ( + steps_per_epoch // steps_per_execution + ) + remaining_steps = (num_batches // steps_per_execution) % ( + steps_per_epoch // steps_per_execution + ) + steps_cycles = [ + complete_epochs * steps_per_epoch // steps_per_execution, + remaining_steps, + ] * epochs + steps_per_epochs = steps_cycles[:epochs] + training_batch_count = sum(steps_per_epochs) + + self.assertGreaterEqual(step_observer.begin_count, training_batch_count) + self.assertEqual(step_observer.end_count, training_batch_count) + self.assertEqual(step_observer.epoch_begin_count, epochs) + self.assertEqual( + step_observer.epoch_end_count, step_observer.epoch_begin_count + ) + + if steps_per_epoch_test != "not_match_too_low": + model_2 = ExampleModel(units=1) + model_2.compile( + loss="mse", + optimizer="sgd", + metrics=[EpochAgnosticMeanSquaredError()], + steps_per_execution=1, + run_eagerly=(mode == "eager"), + jit_compile=(mode == "jit"), + ) + step_observer_2 = StepObserver() + + if steps_per_epoch_test == "not_match_but_high_enough": + model_2_epochs = epochs + else: + model_2_epochs = 1 + + model_2.fit( + dataset, + epochs=model_2_epochs, + callbacks=[step_observer_2], + verbose=0, + ) + + losses = step_observer.batch_loss_history + losses_2 = step_observer_2.batch_loss_history[ + steps_per_execution - 1 :: steps_per_execution + ] + self.assertAllClose(losses, losses_2) + self.assertAllClose(model.get_weights(), model_2.get_weights()) + self.assertAllClose( + model.predict(dataset), model_2.predict(dataset) + ) + self.assertAllClose( + model.evaluate(dataset), model_2.evaluate(dataset) + ) + + @pytest.mark.skipif( + backend.backend() == "torch", + reason="`steps_per_execution` not implemented for torch yet", + ) + def test_steps_per_execution_steps_count_without_training(self): + class StepCount(Callback): + def __init__(self): + super().__init__() + self.test_count = 0 + self.predict_count = 0 + self.batches = [0, 3, 6] + + def on_test_batch_begin(self, batch, logs=None): + assert batch == self.batches[self.test_count] + self.test_count += 1 + + def on_predict_batch_begin(self, batch, logs=None): + assert batch == self.batches[self.predict_count] + self.predict_count += 1 + + x = np.ones((100, 4)) + y = np.ones((100, 1)) + batch_size = 16 + model = ExampleModel(units=1) + model.compile(loss="mse", steps_per_execution=3) + step_count = StepCount() + model.predict(x, batch_size=batch_size, callbacks=[step_count]) + self.assertEqual(step_count.predict_count, 3) + model.evaluate(x, y, batch_size=batch_size, callbacks=[step_count]) + self.assertEqual(step_count.test_count, 3) + + @pytest.mark.requires_trainable_backend + def test_fit_with_different_batch_size_same_loss(self): + x = np.random.rand(100, 4) + y = np.ones((100, 1)) + model = ExampleModel(units=1) + model.trainable = False + model.compile(loss="mse") + loss1 = model.fit(x, y, batch_size=80).history["loss"] + loss2 = model.fit(x, y, batch_size=100).history["loss"] + self.assertAllClose(loss1, loss2) + + def test_evaluate_with_different_batch_size_same_loss(self): + x = np.random.rand(100, 4) + y = np.ones((100, 1)) + model = ExampleModel(units=1) + model.compile(loss="mse") + loss1 = model.evaluate(x, y, batch_size=80) + loss2 = model.evaluate(x, y, batch_size=100) + self.assertAllClose(loss1, loss2) + + @pytest.mark.requires_trainable_backend + def test_adds_loss_scaling_optimizer(self): + model = TrainingTestingLayer(dtype="mixed_float16") + model.compile(optimizer="rmsprop", loss="mse") + x = np.ones((128, 1)) + y = np.zeros((128, 1)) + model.fit(x, y, batch_size=32) + self.assertIsInstance(model.optimizer, optimizers.LossScaleOptimizer) + + model = TrainingTestingLayer(dtype="mixed_float16") + model.compile(optimizer="rmsprop", loss="mse", auto_scale_loss=False) + x = np.ones((128, 1)) + y = np.zeros((128, 1)) + model.fit(x, y, batch_size=32) + self.assertIsInstance(model.optimizer, RMSprop) + + model = TrainingTestingLayer(dtype="mixed_bfloat16") + model.compile(optimizer="rmsprop", loss="mse") + x = np.ones((128, 1)) + y = np.zeros((128, 1)) + model.fit(x, y, batch_size=32) + self.assertIsInstance(model.optimizer, RMSprop) + + @pytest.mark.requires_trainable_backend + @pytest.mark.skipif( + backend.backend() == "torch", + reason="half precision unsupported on torch CPU.", + ) + def test_loss_scaling_prevents_underflow(self): + class DeepModel(Trainer, layers.Layer): + def __init__(self): + layers.Layer.__init__(self, dtype="mixed_float16") + Trainer.__init__(self) + self.layers = [] + for _ in range(15): + # Sigmoid has a small gradient, will eventually underflow. + self.layers.append( + layers.Dense( + 1, + use_bias=False, + kernel_initializer="ones", + activation="sigmoid", + dtype="mixed_float16", + ) + ) + + def call(self, x): + for layer in self.layers: + x = layer(x) + return x + + loss = losses.MeanSquaredError() + # Blow up any gradient updates, so underflow is obvious. + optimizer = optimizers.SGD(learning_rate=1e9) + model = DeepModel() + model.compile(optimizer, loss=loss, auto_scale_loss=False) + model.fit(np.ones((1, 1)), np.ones((1, 1)), batch_size=1) + first_kernel = model.layers[0].kernel + # Without autoscaling, the first dense will not update. + self.assertEqual(first_kernel, np.ones_like(first_kernel)) + + # Blow up any gradient updates, so underflow is obvious. + optimizer = optimizers.SGD(learning_rate=1e9) + model = DeepModel() + model.compile(optimizer, loss=loss, auto_scale_loss=True) + model.fit(np.ones((1, 1)), np.ones((1, 1)), batch_size=1) + first_kernel = model.layers[0].kernel + # With autoscaling, the first dense will update. + self.assertNotEqual(first_kernel, np.ones_like(first_kernel)) + + @pytest.mark.requires_trainable_backend + def test_training_arg(self): + model = TrainingTestingLayer() + model.compile(optimizer="rmsprop", loss="mse") + x = np.ones((128, 1)) + y = np.zeros((128, 1)) + history = model.fit(x, y, batch_size=32) + self.assertAllClose(history.history["loss"], [1.0]) + val_loss = model.evaluate(x, y, batch_size=32) + self.assertAllClose(val_loss, 0.0) + preds = model.predict(x) + self.assertAllClose(preds, np.zeros((128, 1))) + + @parameterized.named_parameters( + [ + ("eager", True, False), + ("graph_fn", False, False), + ("jit", False, True), + ] + ) + @pytest.mark.requires_trainable_backend + def test_on_batch_methods(self, run_eagerly, jit_compile): + if backend.backend() == "torch" and jit_compile: + self.skipTest( + "test_on_batch with jit_compile=True not supported in torch " + "backend yet." + ) + model = ExampleModel(units=3) + x = np.ones((100, 4)) + y = np.zeros((100, 3)) + sw = np.arange(100).reshape((100,)).astype("float32") / 50.0 + + model.compile( + optimizer=optimizers.SGD(), + loss=losses.MeanSquaredError(), + metrics=[metrics.MeanSquaredError()], + run_eagerly=run_eagerly, + jit_compile=jit_compile, + ) + logs = model.train_on_batch(x, y) + self.assertIsInstance(logs, list) + self.assertEqual(len(logs), 2) + self.assertAlmostEqual(logs[0], 16.0) + + logs = model.train_on_batch(x, y, return_dict=True) + self.assertIsInstance(logs, dict) + self.assertEqual(len(logs), 2) + self.assertAlmostEqual(logs["loss"], 15.579) + + logs = model.test_on_batch(x, y) + self.assertIsInstance(logs, list) + self.assertEqual(len(logs), 2) + self.assertAlmostEqual(logs[0], 15.173) + + logs = model.test_on_batch(x, y, return_dict=True) + self.assertIsInstance(logs, dict) + self.assertEqual(len(logs), 2) + self.assertAlmostEqual(logs["loss"], 14.97) + + output = model.predict_on_batch(x) + self.assertIsInstance(output, np.ndarray) + self.assertAllClose(output[0], np.array([3.789511, 3.789511, 3.789511])) + + # With sample weights + logs = model.train_on_batch(x, y, sw) + self.assertAlmostEqual(logs[0], 14.819) + logs = model.test_on_batch(x, y, sw) + self.assertAlmostEqual(logs[0], 14.595) + output = model.predict_on_batch(x) + self.assertAllClose(output[0], np.array([3.689468, 3.689468, 3.689468])) + + # With class weights + logs = model.train_on_batch(x, y, class_weight={1: 0.3, 0: 0.2}) + self.assertAlmostEqual(logs[0], 12.899) + + @parameterized.named_parameters( + [ + ("eager", True, False), + ("graph_fn", False, False), + ("jit", False, True), + ] + ) + def test_on_batch_methods_without_training(self, run_eagerly, jit_compile): + if backend.backend() == "torch" and jit_compile: + self.skipTest( + "test_on_batch with jit_compile=True not supported in torch " + "backend yet." + ) + model = ExampleModel(units=3) + x = np.ones((100, 4)) + y = np.zeros((100, 3)) + + model.compile( + loss=losses.MeanSquaredError(), + metrics=[metrics.MeanSquaredError()], + run_eagerly=run_eagerly, + jit_compile=jit_compile, + ) + output = model.predict_on_batch(x) + self.assertIsInstance(output, np.ndarray) + self.assertAllClose(output[0], np.array([4.0, 4.0, 4.0])) + + logs = model.test_on_batch(x, y) + self.assertIsInstance(logs, list) + self.assertEqual(len(logs), 2) + self.assertAlmostEqual(logs[0], 16.0) + + logs = model.test_on_batch(x, y, return_dict=True) + self.assertIsInstance(logs, dict) + self.assertEqual(len(logs), 2) + self.assertAlmostEqual(logs["loss"], 16.0) + + def test_nested_input_predict(self): + # https://github.com/keras-team/keras/issues/325 + + class TupleInputModel(keras.Model): + def call(self, inputs): + a, b = inputs + return a + b + + model = TupleInputModel() + x1, x2 = np.random.rand(2, 3, 4) + out = model.predict((x1, x2)) + self.assertEqual(out.shape, (3, 4)) + + class DictInputModel(keras.Model): + def call(self, inputs): + return inputs["a"] + inputs["b"] + + model = DictInputModel() + x1, x2 = np.random.rand(2, 3, 4) + out = model.predict({"a": x1, "b": x2}) + self.assertEqual(out.shape, (3, 4)) + + @pytest.mark.requires_trainable_backend + def test_for_eval_epoch_iterator(self): + model = ExampleModel(units=3) + model.compile( + optimizer="adam", loss="mse", metrics=["mean_absolute_error"] + ) + x = np.ones((16, 4)) + y = np.zeros((16, 3)) + x_test = np.ones((16, 4)) + y_test = np.zeros((16, 3)) + model.fit( + x, + y, + batch_size=4, + validation_data=(x_test, y_test), + ) + assert getattr(model, "_eval_epoch_iterator", None) is None + + # Try model.fit with reshaped validation_data + # This will throw an exception which is intended + try: + model.fit( + x, + y, + batch_size=4, + validation_data=( + x_test.reshape((-1, 16, 4)), + y_test.reshape((-1, 16, 3)), + ), + ) + except: + pass + + # Try model.fit with correct validation_data this should work. + # After successful training `_eval_epoch_iterator` should be None + model.fit( + x, + y, + batch_size=4, + validation_data=(x_test, y_test), + ) + assert getattr(model, "_eval_epoch_iterator", None) is None + + @pytest.mark.requires_trainable_backend + def test_callback_methods_keys(self): + class CustomCallback(Callback): + def on_train_begin(self, logs=None): + keys = sorted(list(logs.keys())) + assert keys == [] + + def on_train_end(self, logs=None): + keys = sorted(list(logs.keys())) + assert keys == [ + "loss", + "mean_absolute_error", + "val_loss", + "val_mean_absolute_error", + ] + + def on_epoch_begin(self, epoch, logs=None): + keys = sorted(list(logs.keys())) + assert keys == [] + + def on_epoch_end(self, epoch, logs=None): + keys = sorted(list(logs.keys())) + assert keys == [ + "loss", + "mean_absolute_error", + "val_loss", + "val_mean_absolute_error", + ] + + def on_test_begin(self, logs=None): + keys = sorted(list(logs.keys())) + assert keys == [] + + def on_test_end(self, logs=None): + keys = sorted(list(logs.keys())) + assert keys == ["loss", "mean_absolute_error"] + + def on_predict_begin(self, logs=None): + keys = sorted(list(logs.keys())) + assert keys == [] + + def on_predict_end(self, logs=None): + keys = sorted(list(logs.keys())) + assert keys == [] + + def on_train_batch_begin(self, batch, logs=None): + keys = sorted(list(logs.keys())) + assert keys == [] + + def on_train_batch_end(self, batch, logs=None): + keys = sorted(list(logs.keys())) + assert keys == ["loss", "mean_absolute_error"] + + def on_test_batch_begin(self, batch, logs=None): + keys = sorted(list(logs.keys())) + assert keys == [] + + def on_test_batch_end(self, batch, logs=None): + keys = sorted(list(logs.keys())) + assert keys == ["loss", "mean_absolute_error"] + + def on_predict_batch_begin(self, batch, logs=None): + keys = sorted(list(logs.keys())) + assert keys == [] + + def on_predict_batch_end(self, batch, logs=None): + keys = sorted(list(logs.keys())) + assert keys == ["outputs"] + + model = ExampleModel(units=3) + model.compile( + optimizer="adam", loss="mse", metrics=["mean_absolute_error"] + ) + x = np.ones((16, 4)) + y = np.zeros((16, 3)) + x_test = np.ones((16, 4)) + y_test = np.zeros((16, 3)) + model.fit( + x, + y, + callbacks=[CustomCallback()], + batch_size=4, + validation_data=(x_test, y_test), + ) + model.evaluate(x_test, y_test, batch_size=4) + model.predict(x_test, batch_size=4) + + @pytest.mark.requires_trainable_backend + def test_internal_only_loss(self): + class LossLayer(layers.Layer): + def call(self, x): + self.add_loss(ops.sum(x)) + return x + + model = keras.Sequential( + [ + layers.Dense(2), + LossLayer(), + layers.Dense(1), + ] + ) + model.compile(optimizer="adam") + x = np.ones((16, 2)) + y = np.zeros((16, 1)) + model.fit(x, y, batch_size=4) + + def get_layer(self): + class ExampleLayer(keras.Layer): + def call(self, x): + return x * 2 + + return ExampleLayer + + def get_model(self): + class ExampleModel(keras.Model): + def call(self, x): + return x * 2 + + return ExampleModel + + def get_functional(self): + ExampleLayer = self.get_layer() + + class ExampleFunctional(keras.src.Functional): + def __init__(self, input_shape=(None,)): + inputs = keras.Input(input_shape) + outputs = ExampleLayer()(inputs) + super().__init__(inputs=inputs, outputs=outputs) + + return ExampleFunctional + + @parameterized.named_parameters( + [ + { + "testcase_name": "model", + "model_class": "get_model", + }, + { + "testcase_name": "layer", + "model_class": "get_layer", + }, + { + "testcase_name": "functional", + "model_class": "get_functional", + }, + ] + ) + @pytest.mark.requires_trainable_backend + @pytest.mark.skipif( + keras.backend.backend() != "tensorflow", + reason="Only tensorflow supports raggeds", + ) + def test_trainer_with_raggeds(self, model_class): + from keras.src.utils.module_utils import tensorflow as tf + + def loss_fn(y, y_pred, sample_weight=None): + return 0 + + model = getattr(self, model_class)()() + x = tf.ragged.constant([[1], [2, 3]]) + + # test forward pass + y = model(x) + self.assertEqual(type(y), tf.RaggedTensor) + + # test training + if model_class in ["get_model", "get_functional"]: + model.compile(optimizer="adam", loss=loss_fn) + model.fit(x, x) + y = model.predict(x) + self.assertEqual(type(y), tf.RaggedTensor) + + # test if everything works with the sequential model + model = keras.Sequential([model]) + model.compile(optimizer="adam", loss=loss_fn) + model.fit(x, x) + y = model.predict(x) + self.assertEqual(type(y), tf.RaggedTensor) + + def test_predict_dropout(self): + # Test that `predict` with a dropout op + # has nondeterministic behavior across batches. + + inputs = layers.Input((20,)) + outputs = layers.Dropout(0.5, seed=1337)(inputs, training=True) + model = keras.Model(inputs, outputs) + out1 = model.predict(np.ones((4, 20)), batch_size=2) + self.assertGreater(5, np.sum(np.abs(out1[:2, :] - out1[2:4, :]))) + + out2 = model.predict_on_batch(np.ones((2, 20))) + out3 = model.predict_on_batch(np.ones((2, 20))) + self.assertGreater(5, np.sum(np.abs(out2 - out3))) + + @pytest.mark.requires_trainable_backend + def test_recompile(self): + model = ExampleModel(units=3) + model.compile( + optimizer="sgd", loss="mse", metrics=["mean_squared_error"] + ) + history_1 = model.fit(np.ones((3, 2)), np.ones((3, 3))).history + eval_out_1 = model.evaluate( + np.ones((3, 2)), np.ones((3, 3)), return_dict=True + ) + model.compile( + optimizer="sgd", loss="mse", metrics=["mean_absolute_error"] + ) + history_2 = model.fit(np.ones((3, 2)), np.ones((3, 3))).history + eval_out_2 = model.evaluate( + np.ones((3, 2)), np.ones((3, 3)), return_dict=True + ) + self.assertEqual( + sorted(list(history_1.keys())), ["loss", "mean_squared_error"] + ) + self.assertEqual( + sorted(list(eval_out_1.keys())), ["loss", "mean_squared_error"] + ) + self.assertEqual( + sorted(list(history_2.keys())), ["loss", "mean_absolute_error"] + ) + self.assertEqual( + sorted(list(eval_out_2.keys())), ["loss", "mean_absolute_error"] + ) + + def test_evaluate_return_list_respect_metrics_order(self): + def metrics_zero(y_true, y_pred): + return 0.0 + + def metrics_one(y_true, y_pred): + return 1.0 + + model = ExampleModel(units=3) + model.compile( + optimizer="sgd", loss="mse", metrics=[metrics_zero, metrics_one] + ) + eval_out = model.evaluate(np.ones((3, 2)), np.ones((3, 3))) + self.assertLen(eval_out, 3) + self.assertEqual(eval_out[1], 0.0) + self.assertEqual(eval_out[2], 1.0) + + model.compile( + optimizer="sgd", loss="mse", metrics=[metrics_one, metrics_zero] + ) + eval_out = model.evaluate(np.ones((3, 2)), np.ones((3, 3))) + self.assertLen(eval_out, 3) + self.assertEqual(eval_out[1], 1.0) + self.assertEqual(eval_out[2], 0.0) + + @pytest.mark.requires_trainable_backend + def test_nested_inputs(self): + model = ListInputModel(units=2) + out = model([np.ones((3, 2)), np.ones((3, 3))]) + self.assertEqual(tuple(out.shape), (3, 2)) + model.compile(optimizer="sgd", loss="mse", metrics=["mse"]) + history = model.fit( + [np.ones((3, 2)), np.ones((3, 3))], np.ones((3, 2)) + ).history + self.assertAllClose(history["loss"], 16.0) + train_out = model.train_on_batch( + [np.ones((3, 2)), np.ones((3, 3))], np.ones((3, 2)) + ) + self.assertAllClose(train_out[0], 15.2200) + eval_out = model.evaluate( + [np.ones((3, 2)), np.ones((3, 3))], np.ones((3, 2)) + ) + self.assertAllClose(eval_out[0], 13.0321) + eval_out = model.test_on_batch( + [np.ones((3, 2)), np.ones((3, 3))], np.ones((3, 2)) + ) + self.assertAllClose(eval_out[0], 13.0321) + predict_out = model.predict([np.ones((3, 2)), np.ones((3, 3))]) + self.assertEqual(predict_out.shape, (3, 2)) + predict_out = model.predict_on_batch([np.ones((3, 2)), np.ones((3, 3))]) + self.assertEqual(predict_out.shape, (3, 2)) + + @pytest.mark.requires_trainable_backend + def test_validation_data_infinite_generator(self): + # Test that you can pass an infinite generator to `validation_data` + # arg of fit() as well as a `validation_steps` argument and that + # validation only runs for the correct number of steps. + model = ExampleModel(units=3) + model.compile(optimizer="sgd", loss="mse", metrics=["mse"]) + + class Recorder(keras.callbacks.Callback): + def __init__(self): + self.train_counter = 0 + self.val_counter = 0 + + def on_train_batch_end(self, *args, **kwargs): + self.train_counter += 1 + + def on_test_batch_end(self, *args, **kwargs): + self.val_counter += 1 + + def infinite_gen(): + while True: + yield np.ones((2, 2)), np.ones((2, 3)) + + recorder = Recorder() + + model.fit( + infinite_gen(), + validation_data=infinite_gen(), + steps_per_epoch=3, + validation_steps=4, + epochs=1, + shuffle=False, + callbacks=[recorder], + ) + self.assertEqual(recorder.train_counter, 3) + self.assertEqual(recorder.val_counter, 4) + + @parameterized.named_parameters( + [ + ("fit", "fit", "training", "train"), + ("evaluate", "evaluate", "evaluating", "test"), + ("predict", "predict", "predicting", "predict"), + ] + ) + @pytest.mark.requires_trainable_backend + def test_stop_loop(self, method, method_gerund, on_end_name): + model = ExampleModel(units=3) + model.compile(optimizer="sgd", loss="mse", metrics=["mse"]) + + class Stopper(keras.callbacks.Callback): + def __init__(self, stop_count): + self.stop_count = stop_count + self.counter = 0 + setattr(self, f"on_{on_end_name}_batch_end", self.batch_end) + + def batch_end(self, *args, **kwargs): + self.counter += 1 + if self.counter == self.stop_count: + setattr(self.model, f"stop_{method_gerund}", True) + + def infinite_gen(): + while True: + x = np.ones((2, 2)) + y = np.ones((2, 3)) + yield (x,) if method == "predict" else (x, y) + + stop_count = 5 + stopper = Stopper(stop_count) + + getattr(model, method)( + infinite_gen(), + callbacks=[stopper], + ) + self.assertEqual(stopper.counter, stop_count) + + @pytest.mark.requires_trainable_backend + def test_constraints_are_applied(self): + model = models.Sequential( + [layers.Dense(2, kernel_constraint="non_neg")] + ) + x = np.ones((2, 3)) + y = np.ones((2, 2)) + model.compile(optimizer="rmsprop", loss="mse") + model.fit(x, y) + self.assertGreaterEqual( + np.min(backend.convert_to_numpy(model.layers[0].kernel)), 0.0 + ) + + @pytest.mark.requires_trainable_backend + def test_rng_updated_during_predict(self): + class TestTimeDropout(layers.Layer): + def __init__(self): + super().__init__() + self.random_generator = keras.random.SeedGenerator() + + def call(self, x): + return keras.random.dropout( + x, rate=0.5, seed=self.random_generator + ) + + inputs = layers.Input((20,)) + outputs = TestTimeDropout()(inputs) + model = keras.Model(inputs, outputs) + model.compile(optimizer="rmsprop", loss="mse") + + x = np.ones((32, 20)) + out_1 = model.predict(x) + out_2 = model.predict(x) + self.assertGreater(np.mean(np.abs(out_1 - out_2)), 0.01) + + @pytest.mark.requires_trainable_backend + def test_callbacks_can_update_state_at_batch_boundary(self): + class CounterModel(keras.Model): + def __init__(self): + super().__init__() + self.train_counter = self.add_weight( + shape=(), + initializer="zeros", + ) + self.test_counter = self.add_weight( + shape=(), + initializer="zeros", + ) + self.predict_counter = self.add_weight( + shape=(), + initializer="zeros", + ) + self.dense = layers.Dense(3) + + def call(self, x): + return self.dense(x) + + class CounterCallback(keras.callbacks.Callback): + def __init__(self): + self.eager_call_counter_train = 0 + self.eager_call_counter_test = 0 + self.eager_call_counter_predict = 0 + + def on_train_batch_end(self, *args, **kwargs): + self.model.train_counter.assign_add(1) + self.eager_call_counter_train += 1 + + def on_test_batch_end(self, *args, **kwargs): + self.model.test_counter.assign_add(1) + self.eager_call_counter_test += 1 + + def on_predict_batch_end(self, *args, **kwargs): + self.model.predict_counter.assign_add(1) + self.eager_call_counter_predict += 1 + + model = CounterModel() + model.compile( + optimizer="sgd", loss="mse", metrics=["mse"], run_eagerly=True + ) + cbk = CounterCallback() + model.fit( + np.ones((4, 3)), + np.ones((4, 3)), + callbacks=[cbk], + epochs=3, + batch_size=1, + verbose=0, + validation_data=(np.ones((2, 3)), np.ones((2, 3))), + ) + self.assertAlmostEqual(cbk.eager_call_counter_train, 12) + self.assertAlmostEqual(model.train_counter.numpy(), 12) + self.assertAlmostEqual(cbk.eager_call_counter_test, 6) + self.assertAlmostEqual(model.test_counter.numpy(), 6) + model.predict( + np.ones((4, 3)), + callbacks=[cbk], + batch_size=1, + ) + self.assertAlmostEqual(cbk.eager_call_counter_predict, 4) + self.assertAlmostEqual(model.predict_counter.numpy(), 4) + + @pytest.mark.requires_trainable_backend + def test_metric_update_in_compute_loss(self): + test_self = self + + class MyModel(keras.Model): + def __init__(self): + super().__init__() + self.custom_metric = keras.metrics.Mean(name="custom") + self.dense = keras.layers.Dense(2) + + def call(self, x): + return self.dense(x) + + def compute_loss( + self, + x=None, + y=None, + y_pred=None, + sample_weight=None, + training=True, + ): + if not in_symbolic_scope(): + test_self.assertTrue(training) + loss = super().compute_loss( + x, y, y_pred, sample_weight, training + ) + self.custom_metric.update_state(loss * 4) + return loss + + model = MyModel() + model.compile(optimizer="sgd", loss="mse") + x = np.ones((32, 4)) + y = np.ones((32, 2)) * 2 + history = model.fit(x, y) + self.assertAlmostEqual( + history.history["custom"][0], history.history["loss"][0] * 4 + ) + + @pytest.mark.requires_trainable_backend + def test_fwd_pass_loss_presence_in_compute_loss(self): + test_self = self + + class MyModel(keras.Model): + def __init__(self): + super().__init__() + self.custom_metric = keras.metrics.Mean(name="custom") + self.dense = keras.layers.Dense(2, activity_regularizer="l2") + + def call(self, x): + return self.dense(x) + + def compute_loss( + self, + x=None, + y=None, + y_pred=None, + sample_weight=None, + training=True, + ): + if not in_symbolic_scope(): + test_self.assertTrue(training) + loss = super().compute_loss( + x, y, y_pred, sample_weight, training + ) + self.custom_metric.update_state(sum(self.losses)) + return loss + + model = MyModel() + model.compile(optimizer="sgd", loss="mse") + x = np.ones((32, 4)) + y = np.ones((32, 2)) * 2 + history = model.fit(x, y) + self.assertGreater(history.history["custom"][0], 0.0) + + @pytest.mark.requires_trainable_backend + def test_evaluate_with_custom_compute_loss(self): + test_self = self + + class MyModel(keras.Model): + def __init__(self): + super().__init__() + self.custom_metric = keras.metrics.Mean(name="custom") + self.dense = keras.layers.Dense(2, activity_regularizer="l2") + + def call(self, x): + return self.dense(x) + + def compute_loss( + self, + x=None, + y=None, + y_pred=None, + sample_weight=None, + training=True, + ): + if not in_symbolic_scope(): + test_self.assertFalse(training) + loss = super().compute_loss( + x, y, y_pred, sample_weight, training + ) + self.custom_metric.update_state(loss * 4) + return loss + + model = MyModel() + model.compile(optimizer="sgd", loss="mse") + x = np.ones((32, 4)) + y = np.ones((32, 2)) * 2 + logs = model.evaluate(x, y, return_dict=True) + self.assertAlmostEqual(logs["custom"], logs["loss"] * 4) + + @pytest.mark.requires_trainable_backend + def test_compute_loss_no_training_backwards_compatibility(self): + class MyModel(keras.Model): + def __init__(self): + super().__init__() + self.custom_metric = keras.metrics.Mean(name="custom") + self.dense = keras.layers.Dense(2, activity_regularizer="l2") + + def call(self, x): + return self.dense(x) + + def compute_loss( + self, + x=None, + y=None, + y_pred=None, + sample_weight=None, + ): + loss = super().compute_loss(x, y, y_pred, sample_weight) + self.custom_metric.update_state(loss * 4) + return loss + + model = MyModel() + model.compile(optimizer="sgd", loss="mse") + x = np.ones((32, 4)) + y = np.ones((32, 2)) * 2 + logs = model.evaluate(x, y, return_dict=True) + self.assertAlmostEqual(logs["custom"], logs["loss"] * 4) + history = model.fit(x, y) + self.assertAlmostEqual( + history.history["custom"][0], history.history["loss"][0] * 4 + ) + + @pytest.mark.requires_trainable_backend + def test_loss_weights(self): + epochs = 3 + batch_size = 20 + dataset_size = batch_size * 2 + + # Single output case. + model = ExampleModel(units=3) + model.compile( + optimizer=optimizers.SGD(), + loss=losses.MeanSquaredError(), + metrics=[metrics.MeanSquaredError()], + loss_weights=0.2, + ) + x = np.ones((dataset_size, 4)) + y = np.zeros((dataset_size, 3)) + history = model.fit( + x, + y, + batch_size=batch_size, + epochs=epochs, + ) + history = history.history + self.assertIn("loss", history) + self.assertAllClose( + history["loss"], + [3.182979, 3.115617, 3.049681], + atol=1e-3, + ) + + # Dict output case. + model = StructModel(units=3) + model.compile( + optimizer=optimizers.SGD(), + loss={ + "y_one": losses.MeanSquaredError(), + "y_two": losses.MeanSquaredError(), + }, + metrics={ + "y_one": metrics.MeanSquaredError(), + "y_two": metrics.MeanSquaredError(), + }, + loss_weights={"y_one": 0.1, "y_two": 0.2}, + ) + x1 = np.ones((dataset_size, 4)) + x2 = np.ones((dataset_size, 4)) + y1 = np.zeros((dataset_size, 3)) + y2 = np.zeros((dataset_size, 3)) + history = model.fit( + {"x_one": x1, "x_two": x2}, + {"y_one": y1, "y_two": y2}, + batch_size=batch_size, + epochs=epochs, + ) + history = history.history + self.assertIn("loss", history) + self.assertAllClose( + history["loss"], + [4.778718, 4.694403, 4.611693], + atol=1e-3, + ) + + # List output case. + model = ListOutputModel(units=3) + model.compile( + optimizer=optimizers.SGD(), + loss=[losses.MeanSquaredError(), losses.MeanSquaredError()], + metrics=[metrics.MeanSquaredError(), metrics.MeanSquaredError()], + loss_weights=[0.1, 0.2], + ) + x = np.ones((dataset_size, 4)) + y1 = np.zeros((dataset_size, 3)) + y2 = np.zeros((dataset_size, 3)) + history = model.fit( + x, + [y1, y2], + batch_size=batch_size, + epochs=epochs, + ) + history = history.history + self.assertIn("loss", history) + self.assertAllClose( + history["loss"], + [4.778718, 4.694403, 4.611693], + atol=1e-3, + ) + + @pytest.mark.requires_trainable_backend + def test_partial_loss_partial_label(self): + inputs = keras.Input((2,)) + x = keras.layers.Dense(3, kernel_initializer="ones")(inputs) + partial_model = keras.Model(inputs, [x, x, x]) + partial_model.compile(loss=["mse", None, None]) + full_model = keras.Model(inputs, [x, x, x]) + full_model.compile(loss=["mse", "mse", "mse"]) + + eval_x = np.ones((32, 2)) + eval_y = np.ones((32, 3)) + + partial_logs = partial_model.evaluate(eval_x, eval_y, return_dict=True) + logs = full_model.evaluate(eval_x, [eval_y] * 3, return_dict=True) + + self.assertAlmostEqual(partial_logs["loss"] * 3, logs["loss"]) + + def test_symbolic_build(self): + class ExampleModelWithTrainingArgs(Trainer, layers.Layer): + def __init__(self, units): + layers.Layer.__init__(self) + Trainer.__init__(self) + self.dense = layers.Dense(units) + self.bn = layers.BatchNormalization(axis=-1) + + def build(self, input_shape): + self.dense.build(input_shape) + input_shape = self.dense.compute_output_shape(input_shape) + self.bn.build(input_shape) + + def call(self, x, training=None): + outputs = self.bn(self.dense(x), training=training) + return [outputs, outputs] + + model = ExampleModelWithTrainingArgs(units=3) + model.compile( + optimizer=optimizers.SGD(), + loss=[losses.MeanSquaredError(), losses.MeanSquaredError()], + metrics=[metrics.MeanSquaredError(), metrics.MeanSquaredError()], + ) + x = np.ones((4, 4)) + y = np.zeros((4, 3)) + model(x) # Eager call to build model weights + ref_weights = model.get_weights() + + # Before `_symbolic_build` + self.assertTrue(model.built) + self.assertFalse(model._compile_metrics.built) + self.assertFalse(model._compile_loss.built) + self.assertLen(model._compile_loss.metrics, 0) + self.assertLen(model.metrics, 2) + + model._symbolic_build(data_batch=(x, (y, y))) + weights = model.get_weights() + + # Ensure weights are intact + self.assertEqual(len(weights), len(ref_weights)) + for w, ref_w in zip(weights, ref_weights): + self.assertAllClose(w, ref_w) + + # Ensure `built` + self.assertTrue(model.built) + self.assertTrue(model._compile_metrics.built) + self.assertTrue(model._compile_loss.built) + + # Ensure the len of metrics (original metrics + loss trackers) + self.assertLen(model._compile_metrics.metrics, 2) + self.assertLen(model._compile_loss.metrics, 2) + self.assertLen(model.metrics, 4) + + # Ensure no values in metrics + for v in model._compile_metrics.variables: + self.assertAllClose(v, 0.0) + for v in model._compile_loss.variables: + self.assertAllClose(v, 0.0) + + @pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="This test is only applicable to TensorFlow.", + ) + @pytest.mark.requires_trainable_backend + def test_jit_compile_with_tf_determinism(self): + from tensorflow.python.framework.config import disable_op_determinism + from tensorflow.python.framework.config import enable_op_determinism + + enable_op_determinism() + + model = ExampleModel(units=3) + model.compile( + optimizer=optimizers.SGD(), + loss=losses.MeanSquaredError(), + metrics=[metrics.MeanSquaredError()], + ) + + self.assertFalse(model.jit_compile) + disable_op_determinism() + + @pytest.mark.requires_trainable_backend + @pytest.mark.skipif( + backend.backend() == "torch", + reason="`steps_per_execution` not implemented for torch yet", + ) + def test_retracing(self): + x = np.ones((100, 4)) + y = np.ones((100, 1)) + + input = keras.Input(shape=[4]) + output = keras.layers.Dense(1, activation="relu")(input) + + tracing_count = [0] + + class TracingCounterModel(keras.Model): + def train_step(self, *args): + tracing_count[0] = tracing_count[0] + 1 + return super().train_step(*args) + + model = TracingCounterModel(inputs=input, outputs=output) + model.compile( + loss="mse", + optimizer="adam", + steps_per_execution=20, + ) + + epochs = 1 + model.fit( + x=x, + y=y, + batch_size=1, + epochs=epochs, + verbose=0, + ) + self.assertLessEqual(tracing_count[0], 2) + + @pytest.mark.requires_trainable_backend + @pytest.mark.skipif( + backend.backend() == "torch", + reason="`steps_per_execution` not implemented for torch yet", + ) + @pytest.mark.skipif( + backend.backend() == "tensorflow", + reason="`predict_function` with `steps_per_execution` is not " + "optimized for tensorflow yet", + ) + def test_retracing_predict(self): + x = np.ones((100, 4)) + + input = keras.Input(shape=[4]) + output = keras.layers.Dense(1, activation="relu")(input) + + tracing_count = [0] + + class TracingCounterModel(keras.Model): + def predict_step(self, *args): + tracing_count[0] = tracing_count[0] + 1 + return super().predict_step(*args) + + model = TracingCounterModel(inputs=input, outputs=output) + model.compile( + loss="mse", + optimizer="adam", + steps_per_execution=20, + ) + + model.predict( + x=x, + batch_size=1, + verbose=0, + ) + self.assertLessEqual(tracing_count[0], 2) + + +class JAXTrainerCorrectnessTest(test_case.TestCase, parameterized.TestCase): + @parameterized.named_parameters( + ("single_device", False), + ("distributed", True), + ) + def test_jit_fit_with_out_shardings_logic(self, distributed): + if keras.backend.backend() != "jax": + self.skipTest("This test requires the JAX backend.") + x = np.random.rand(64, 8).astype("float32") + y = np.random.rand(64, 1).astype("float32") + + distribution = None + if distributed: + if len(jax.local_devices()) < 2: + self.skipTest( + "Distributed test requires at least 2 JAX devices." + ) + + devices = jax.local_devices() + mesh = DeviceMesh( + shape=(len(devices),), axis_names=("batch",), devices=devices + ) + distribution = DataParallel(mesh) + + scope = distribution.scope() if distribution else mock.MagicMock() + + with scope: + model = models.Sequential( + [ + layers.Dense(4, activation="relu", input_shape=(8,)), + layers.Dense(1), + ] + ) + model.compile(optimizer="adam", loss="mse", jit_compile=True) + + if distribution: + expected_shardings = [ + v.value.sharding for v in model.trainable_variables + ] + self.assertNotEqual(len(set(expected_shardings)), 1) + + model.fit(x, y, epochs=2, batch_size=32, verbose=0) + + if distribution: + actual_shardings = [ + v.value.sharding for v in model.trainable_variables + ] + self.assertListEqual(actual_shardings, expected_shardings) diff --git a/keras/src/tree/__init__.py b/keras/src/tree/__init__.py new file mode 100644 index 000000000000..a719378ef350 --- /dev/null +++ b/keras/src/tree/__init__.py @@ -0,0 +1,12 @@ +from keras.src.tree.tree_api import assert_same_paths +from keras.src.tree.tree_api import assert_same_structure +from keras.src.tree.tree_api import flatten +from keras.src.tree.tree_api import flatten_with_path +from keras.src.tree.tree_api import is_nested +from keras.src.tree.tree_api import lists_to_tuples +from keras.src.tree.tree_api import map_shape_structure +from keras.src.tree.tree_api import map_structure +from keras.src.tree.tree_api import map_structure_up_to +from keras.src.tree.tree_api import pack_sequence_as +from keras.src.tree.tree_api import register_tree_node_class +from keras.src.tree.tree_api import traverse diff --git a/keras/src/tree/dmtree_impl.py b/keras/src/tree/dmtree_impl.py new file mode 100644 index 000000000000..5e4132d419a9 --- /dev/null +++ b/keras/src/tree/dmtree_impl.py @@ -0,0 +1,410 @@ +import collections +import collections.abc +import itertools + +from keras.src.backend.config import backend +from keras.src.utils.module_utils import dmtree + +# NOTE: There are two known discrepancies between this `dmtree` implementation +# of the tree API and the `optree` implementation: +# +# 1. `map_structure` with *multiple* structures and `map_structure_up_to` do not +# use the object registration (they use the raw `dmtree.map_structure` and +# `dmtree.map_structure_up_to`). This only has consequences with two types of +# structures: +# - `TrackedSet` will not explored (considered as a leaf). +# - `OrderedDict` will be traversed in the order of sorted keys, not the +# order of the items. This is typically inconsequential because functions +# used with `map_structure` and `map_structure_up_to` are typically not +# order dependent and are, in fact, stateless. +# +# 2. The handling of non-sortable keys in dictionaries in inconsistent. `optree` +# uses the iteration order while `dmtree` raises an error. This is not an +# issue as keys are always strings. But this is the reason why we document +# non-sortable keys as unsupported (meaning behavior is undefined). + +REGISTERED_CLASSES = {} + +ClassRegistration = collections.namedtuple( + "ClassRegistration", ["flatten", "unflatten"] +) + + +class TypeErrorRemapping: + def __enter__(self): + pass + + def __exit__(self, exc_type, exc_value, traceback): + if exc_type is TypeError: + raise ValueError(exc_value).with_traceback(traceback) + return False + + +def register_tree_node( + cls, + flatten_func=None, + unflatten_func=None, +): + if flatten_func is None: + flatten_func = lambda x: x.tree_flatten() + if unflatten_func is None: + unflatten_func = cls.tree_unflatten + REGISTERED_CLASSES[cls] = ClassRegistration(flatten_func, unflatten_func) + + +def register_tree_node_class(cls): + register_tree_node(cls) + return cls + + +register_tree_node( + collections.OrderedDict, + lambda d: (d.values(), list(d.keys()), d.keys()), + lambda metadata, children: collections.OrderedDict(zip(metadata, children)), +) + +if backend() == "tensorflow": + from tensorflow.python.trackable.data_structures import ListWrapper + from tensorflow.python.trackable.data_structures import _DictWrapper + + register_tree_node( + ListWrapper, + lambda x: (x, None), + lambda metadata, children: ListWrapper(list(children)), + ) + + def sorted_keys_and_values(d): + keys = sorted(list(d.keys())) + values = [d[k] for k in keys] + return values, keys, keys + + register_tree_node( + _DictWrapper, + sorted_keys_and_values, + lambda metadata, children: _DictWrapper( + {key: child for key, child in zip(metadata, children)} + ), + ) + + +def is_nested(structure): + return type(structure) in REGISTERED_CLASSES or dmtree.is_nested(structure) + + +def traverse(func, structure, top_down=True): + if not callable(func): + raise TypeError( + f"`func` must be callable, got {func} of type {type(func)}" + ) + + def remap_map_to_none(value, new_value): + if isinstance(value, type) and value.__name__ == "MAP_TO_NONE": + return new_value + return value + + def traverse_top_down(s): + ret = func(s) + if ret is not None: + return remap_map_to_none(ret, dmtree.MAP_TO_NONE) + registration = REGISTERED_CLASSES.get(type(s), None) + if registration is None: + return None + flat_meta_s = registration.flatten(s) + flat_s = [ + dmtree.traverse(traverse_top_down, x, top_down=True) + for x in list(flat_meta_s[0]) + ] + return registration.unflatten(flat_meta_s[1], flat_s) + + def traverse_bottom_up(s): + registration = REGISTERED_CLASSES.get(type(s), None) + if registration is not None: + flat_meta_s = registration.flatten(s) + ret = [traverse_bottom_up(x) for x in list(flat_meta_s[0])] + ret = registration.unflatten(flat_meta_s[1], ret) + elif not dmtree.is_nested(s): + ret = s + elif isinstance(s, collections.abc.Mapping): + ret = [traverse_bottom_up(s[key]) for key in sorted(s)] + ret = dmtree._sequence_like(s, ret) + else: + ret = [traverse_bottom_up(x) for x in s] + ret = dmtree._sequence_like(s, ret) + func_ret = func(ret) + return ret if func_ret is None else remap_map_to_none(func_ret, None) + + if top_down: + return dmtree.traverse(traverse_top_down, structure, top_down=True) + else: + return traverse_bottom_up(structure) + + +def flatten(structure): + if not is_nested(structure): + return [structure] + + flattened = [] + + def flatten_func(s): + registration = REGISTERED_CLASSES.get(type(s), None) + if registration is not None: + flat_s = list(registration.flatten(s)[0]) + return dmtree.traverse(flatten_func, flat_s, top_down=True) + if not is_nested(s): + flattened.append(s) + return dmtree.MAP_TO_NONE if s is None else s + return None + + dmtree.traverse(flatten_func, structure, top_down=True) + return flattened + + +def _recursive_flatten_with_path(path, structure, flattened): + registration = REGISTERED_CLASSES.get(type(structure), None) + if registration is not None: + flat_meta_paths = registration.flatten(structure) + flat = flat_meta_paths[0] + paths = ( + flat_meta_paths[2] + if len(flat_meta_paths) >= 3 + else itertools.count() + ) + for key, value in zip(paths, flat): + _recursive_flatten_with_path(path + (key,), value, flattened) + elif not dmtree.is_nested(structure): + flattened.append((path, structure)) + elif isinstance(structure, collections.abc.Mapping): + for key in sorted(structure): + _recursive_flatten_with_path( + path + (key,), structure[key], flattened + ) + else: + for key, value in enumerate(structure): + _recursive_flatten_with_path(path + (key,), value, flattened) + + +def flatten_with_path(structure): + if not is_nested(structure): + return [((), structure)] + + # Fully reimplemented in Python to handle registered classes, OrderedDict + # and namedtuples the same way as optree. + flattened = [] + _recursive_flatten_with_path((), structure, flattened) + return flattened + + +def map_structure(func, *structures, none_is_leaf=True): + if not callable(func): + raise TypeError( + f"`func` must be callable, got {func} of type {type(func)}" + ) + + map_func = func + if not none_is_leaf: + + def func_skipping_none(*args): + # Check if the reference entry (first one) is None + if args[0] is None: + if not all(s is None for s in args): + raise ValueError( + "Structure mismatch: some arguments are None, others " + f"are not. Received arguments: {args}." + ) + return None + return func(*args) + + map_func = func_skipping_none + + def func_traverse_wrapper(s): + if is_nested(s): + return None + ret = map_func(s) + if ret is None: + return dmtree.MAP_TO_NONE + return ret + + if len(structures) == 1: + return traverse(func_traverse_wrapper, structures[0]) + + with TypeErrorRemapping(): + return dmtree.map_structure(map_func, *structures) + + +def map_structure_up_to(shallow_structure, func, *structures): + if not callable(func): + raise TypeError( + f"`func` must be callable, got {func} of type {type(func)}" + ) + + with TypeErrorRemapping(): + return dmtree.map_structure_up_to(shallow_structure, func, *structures) + + +def assert_same_structure(a, b): + # Fully reimplemented in Python to handle registered classes. + + # Don't handle OrderedDict as a registered class, use the normal dict path + # so that OrderedDict is equivalent to dict per optree behavior. + a_registration = REGISTERED_CLASSES.get(type(a), None) + if isinstance(a, collections.OrderedDict): + a_registration = None + + b_registration = REGISTERED_CLASSES.get(type(b), None) + if isinstance(b, collections.OrderedDict): + b_registration = None + + if a_registration != b_registration: + raise ValueError( + f"Custom node type mismatch; " + f"expected type: {type(a)}, got type: {type(b)} " + f"while comparing {a} and {b}." + ) + if a_registration is not None: + a_flat_meta = a_registration.flatten(a) + b_flat_meta = b_registration.flatten(b) + a_flat = list(a_flat_meta[0]) + b_flat = list(b_flat_meta[0]) + if not a_flat_meta[1] == b_flat_meta[1]: + raise ValueError( + f"Mismatch custom node data; " + f"expected: {a_flat_meta[1]}, got: {b_flat_meta[1]} " + f"while comparing {a} and {b}." + ) + if len(a_flat) != len(b_flat): + raise ValueError( + f"Arity mismatch; expected: {len(a)}, got: {len(b)} " + f"while comparing {a} and {b}." + ) + for sub_a, sub_b in zip(a_flat, b_flat): + assert_same_structure(sub_a, sub_b) + elif not dmtree.is_nested(a): + if dmtree.is_nested(b): + raise ValueError( + f"Structures don't have the same nested structure: {a}, {b}." + ) + elif isinstance( + a, (dict, collections.OrderedDict, collections.defaultdict) + ): + if not isinstance( + b, (dict, collections.OrderedDict, collections.defaultdict) + ): + raise ValueError( + f"Expected an instance of dict, collections.OrderedDict, or " + f"collections.defaultdict, got {type(b)} " + f"while comparing {a} and {b}." + ) + a_keys = sorted(a) + b_keys = sorted(b) + if not a_keys == b_keys: + raise ValueError( + f"Dictionary key mismatch; " + f"expected key(s): {a_keys}, got key(s): {b_keys} " + f"while comparing {a} and {b}." + ) + for key in a_keys: + assert_same_structure(a[key], b[key]) + elif isinstance(a, collections.abc.Mapping): + raise ValueError( + f"Encountered unregistered collections.abc.Mapping type: {type(a)} " + f"while comparing {a} and {b}." + ) + else: + if type(a) is not type(b): + raise ValueError( + f"Expected an instance of {type(a)}, got {type(b)} " + f"while comparing {a} and {b}." + ) + if not len(a) == len(b): + raise ValueError( + f"Arity mismatch; expected: {len(a)}, got: {len(b)} " + f"while comparing {a} and {b}." + ) + for sub_a, sub_b in zip(a, b): + assert_same_structure(sub_a, sub_b) + + +def assert_same_paths(a, b): + a_paths = set([path for path, _ in flatten_with_path(a)]) + b_paths = set([path for path, _ in flatten_with_path(b)]) + + if a_paths != b_paths: + msg = "`a` and `b` don't have the same paths." + a_diff = a_paths.difference(b_paths) + if a_diff: + msg += f"\nPaths in `a` missing in `b`:\n{a_diff}" + b_diff = b_paths.difference(a_paths) + if b_diff: + msg += f"\nPaths in `b` missing in `a`:\n{b_diff}" + raise ValueError(msg) + + +def pack_sequence_as(structure, flat_sequence): + # This is not just an optimization for the case when structure is a leaf. + # This is required to avoid Torch Dynamo failures. + if not is_nested(structure): + if len(flat_sequence) == 1: + return flat_sequence[0] + else: + raise ValueError( + "Incorrect number of leaves provided by `flat_sequence` for " + f"`structure`; expected: 1, got {len(flat_sequence)}." + ) + + flat_sequence_it = enumerate(flat_sequence) + + def unflatten_func(s): + registration = REGISTERED_CLASSES.get(type(s), None) + if registration is not None: + flat_meta_s = registration.flatten(s) + flat_s = dmtree.traverse( + unflatten_func, list(flat_meta_s[0]), top_down=True + ) + return registration.unflatten(flat_meta_s[1], flat_s) + elif not dmtree.is_nested(s): + try: + _, value = next(flat_sequence_it) + return dmtree.MAP_TO_NONE if value is None else value + except StopIteration: + raise ValueError( + "Too few leaves provided by `flat_sequence` for " + f"`structure`. Got {len(flat_sequence)}." + ) + return None + + ret = dmtree.traverse(unflatten_func, structure, top_down=True) + try: + index, _ = next(flat_sequence_it) + raise ValueError( + "Too many leaves provided by `flat_sequence` for `structure`; " + f"expected: {index}, got {len(flat_sequence)}." + ) + except StopIteration: + return ret + + +def lists_to_tuples(structure): + def list_to_tuple(instance): + return tuple(instance) if isinstance(instance, list) else None + + return traverse(list_to_tuple, structure, top_down=False) + + +def map_shape_structure(func, structure): + if not callable(func): + raise TypeError( + f"`func` must be callable, got {func} of type {type(func)}" + ) + + def map_shape_func(x): + if isinstance(x, (list, tuple)) and all( + isinstance(e, (int, type(None))) for e in x + ): + ret = func(x) + elif is_nested(x): + return None + else: + ret = func(x) + return ret if ret is not None else dmtree.MAP_TO_NONE + + return traverse(map_shape_func, structure, top_down=True) diff --git a/keras/src/tree/optree_impl.py b/keras/src/tree/optree_impl.py new file mode 100644 index 000000000000..1134d8338048 --- /dev/null +++ b/keras/src/tree/optree_impl.py @@ -0,0 +1,190 @@ +import optree +import optree.utils + +from keras.src.backend.config import backend + + +def register_tree_node_class(cls): + return optree.register_pytree_node_class(cls, namespace="keras") + + +# Register backend-specific node classes +if backend() == "tensorflow": + from tensorflow.python.trackable.data_structures import ListWrapper + from tensorflow.python.trackable.data_structures import _DictWrapper + + try: + optree.register_pytree_node( + ListWrapper, + lambda x: (x, None), + lambda metadata, children: ListWrapper(list(children)), + namespace="keras", + ) + + def sorted_keys_and_values(d): + keys = sorted(list(d.keys())) + values = [d[k] for k in keys] + return values, keys, keys + + optree.register_pytree_node( + _DictWrapper, + sorted_keys_and_values, + lambda metadata, children: _DictWrapper( + {key: child for key, child in zip(metadata, children)} + ), + namespace="keras", + ) + except ValueError: + pass # We may have already registered if we are reimporting keras. + + +def is_nested(structure): + return not optree.tree_is_leaf( + structure, none_is_leaf=True, namespace="keras" + ) + + +def traverse(func, structure, top_down=True): + # From https://github.com/google/jax/pull/19695 + def traverse_children(): + children, treedef = optree.tree_flatten( + structure, + is_leaf=lambda x: x is not structure, + none_is_leaf=True, + namespace="keras", + ) + if treedef.num_nodes == 1 and treedef.num_leaves == 1: + return structure + else: + return optree.tree_unflatten( + treedef, + [traverse(func, c, top_down=top_down) for c in children], + ) + + if top_down: + ret = func(structure) + if ret is None: + return traverse_children() + else: + traversed_structure = traverse_children() + ret = func(traversed_structure) + if ret is None: + return traversed_structure + # Detect MAP_TO_NONE without tree_api import to avoid circular import. + if isinstance(ret, type) and ret.__name__ == "MAP_TO_NONE": + return None + return ret + + +def flatten(structure): + # optree.tree_flatten returns a pair (leaves, treespec) where the first + # element is a list of leaf values and the second element is a treespec + # representing the structure of the pytree. + leaves, _ = optree.tree_flatten( + structure, none_is_leaf=True, namespace="keras" + ) + return leaves + + +def flatten_with_path(structure): + paths, leaves, _ = optree.tree_flatten_with_path( + structure, none_is_leaf=True, namespace="keras" + ) + return list(zip(paths, leaves)) + + +def map_structure(func, *structures, none_is_leaf=True): + if not structures: + raise ValueError("Must provide at least one structure") + + # Add check for same structures, otherwise optree just maps to shallowest. + def func_with_check(*args): + if not all( + optree.tree_is_leaf(s, none_is_leaf=none_is_leaf, namespace="keras") + for s in args + ): + raise ValueError("Structures don't have the same nested structure.") + return func(*args) + + map_func = func_with_check if len(structures) > 1 else func + + return optree.tree_map( + map_func, *structures, none_is_leaf=none_is_leaf, namespace="keras" + ) + + +def map_structure_up_to(shallow_structure, func, *structures): + if not structures: + raise ValueError("Must provide at least one structure") + + # Add check that `shallow_structure` really is the shallowest. + # Also only call `func` on `structures` and not `shallow_structure`. + def func_with_check_without_shallow_structure(shallow, *args): + if not optree.tree_is_leaf(shallow): + raise ValueError("Structures don't have the same nested structure.") + return func(*args) + + return optree.tree_map( + func_with_check_without_shallow_structure, + shallow_structure, + *structures, + none_is_leaf=True, + namespace="keras", + ) + + +def assert_same_structure(a, b): + def check(a_leaf, b_leaf): + if not optree.tree_is_leaf( + a_leaf, none_is_leaf=True, namespace="keras" + ) or not optree.tree_is_leaf( + b_leaf, none_is_leaf=True, namespace="keras" + ): + raise ValueError("Structures don't have the same nested structure.") + return None + + optree.tree_map(check, a, b, none_is_leaf=True, namespace="keras") + + +def assert_same_paths(a, b): + a_paths = set(optree.tree_paths(a, none_is_leaf=True, namespace="keras")) + b_paths = set(optree.tree_paths(b, none_is_leaf=True, namespace="keras")) + + if a_paths != b_paths: + msg = "`a` and `b` don't have the same paths." + a_diff = a_paths.difference(b_paths) + if a_diff: + msg += f"\nPaths in `a` missing in `b`:\n{a_diff}" + b_diff = b_paths.difference(a_paths) + if b_diff: + msg += f"\nPaths in `b` missing in `a`:\n{b_diff}" + raise ValueError(msg) + + +def pack_sequence_as(structure, flat_sequence): + _, treespec = optree.tree_flatten( + structure, none_is_leaf=True, namespace="keras" + ) + return optree.tree_unflatten(treespec, flat_sequence) + + +def lists_to_tuples(structure): + def list_to_tuple(instance): + return tuple(instance) if isinstance(instance, list) else None + + return traverse(list_to_tuple, structure, top_down=False) + + +def map_shape_structure(func, structure): + def is_shape_tuple(x): + return isinstance(x, (list, tuple)) and all( + isinstance(e, (int, type(None))) for e in x + ) + + return optree.tree_map( + func, + structure, + is_leaf=is_shape_tuple, + none_is_leaf=True, + namespace="keras", + ) diff --git a/keras/src/tree/torchtree_impl.py b/keras/src/tree/torchtree_impl.py new file mode 100644 index 000000000000..f7c5c9817cae --- /dev/null +++ b/keras/src/tree/torchtree_impl.py @@ -0,0 +1,215 @@ +from collections import defaultdict + +from torch.utils import _pytree as torch_tree + + +def register_tree_node_class(cls): + torch_tree.register_pytree_node( + cls, + flatten_fn=lambda x: x.torchtree_flatten(), + unflatten_fn=cls.torchtree_unflatten, + serialized_type_name=f"{cls.__name__}", + flatten_with_keys_fn=lambda x: x.torchtree_flatten_with_keys(), + ) + return cls + + +def _tree_is_leaf(tree, is_leaf=None): + if is_leaf is not None and is_leaf(tree): + return True + return torch_tree._get_node_type(tree) not in torch_tree.SUPPORTED_NODES + + +def _dict_to_ordered_dict(structure): + # We need to sort dict and defaultdict to ensure a deterministic order that + # that is consistent with other tree implementations. + def func(x): + if type(x) is dict: + return {k: x[k] for k in sorted(x.keys())} + elif type(x) is defaultdict: + return defaultdict( + x.default_factory, + {k: x[k] for k in sorted(x.keys())}, + ) + return None + + def traverse_children(): + children, treedef = torch_tree.tree_flatten( + structure, + is_leaf=lambda x: x is not structure, + ) + if treedef.num_nodes == 1 and treedef.num_leaves == 1: + return structure + else: + return torch_tree.tree_unflatten( + [_dict_to_ordered_dict(c) for c in children], + treedef, + ) + + ret = func(structure) + if ret is None: + return traverse_children() + if isinstance(ret, type) and ret.__name__ == "MAP_TO_NONE": + return None + return ret + + +def is_nested(structure): + return not _tree_is_leaf(structure) + + +def traverse(func, structure, top_down=True): + def traverse_children(): + children, treedef = torch_tree.tree_flatten( + structure, + is_leaf=lambda x: x is not structure, + ) + if treedef.num_nodes == 1 and treedef.num_leaves == 1: + return structure + else: + return torch_tree.tree_unflatten( + [traverse(func, c, top_down=top_down) for c in children], + treedef, + ) + + structure = _dict_to_ordered_dict(structure) + if top_down: + ret = func(structure) + if ret is None: + return traverse_children() + else: + traversed_structure = traverse_children() + ret = func(traversed_structure) + if ret is None: + return traversed_structure + # Detect MAP_TO_NONE without tree_api import to avoid circular import. + if isinstance(ret, type) and ret.__name__ == "MAP_TO_NONE": + return None + return ret + + +def flatten(structure): + # We need to first sort dicts to ensure a deterministic order that is + # consistent with other tree implementations. + structure = _dict_to_ordered_dict(structure) + leaves, _ = torch_tree.tree_flatten(structure) + return leaves + + +def flatten_with_path(structure): + # We need to first sort dicts to ensure a deterministic order that is + # consistent with other tree implementations. + structure = _dict_to_ordered_dict(structure) + leaves_with_path, _ = torch_tree.tree_flatten_with_path(structure) + results = [] + fields = [] + for key, leaf in leaves_with_path: + for k in key: + if isinstance(k, torch_tree.GetAttrKey) and k.name not in fields: + fields.append(k.name) + fields = sorted(fields) + field_to_idx = {f: i for i, f in enumerate(fields)} + for key, leaf in leaves_with_path: + # Convert to a tuple of keys. + path = [] + for k in key: + if isinstance(k, torch_tree.SequenceKey): + path.append(k.idx) + elif isinstance(k, torch_tree.MappingKey): + path.append(k.key) + elif isinstance(k, torch_tree.GetAttrKey): + path.append(field_to_idx[k.name]) + results.append((tuple(path), leaf)) + return results + + +def map_structure(func, *structures, none_is_leaf=True): + if not structures: + raise ValueError("Must provide at least one structure") + + map_func = func + if not none_is_leaf: + + def func_skipping_none(*args): + # Check if the reference entry (first one) is None + if args[0] is None: + if not all(s is None for s in args): + raise ValueError( + "Structure mismatch: some arguments are None, others " + f"are not. Received arguments: {args}." + ) + return None + return func(*args) + + map_func = func_skipping_none + + return torch_tree.tree_map(map_func, *structures) + + +def map_structure_up_to(shallow_structure, func, *structures): + if not structures: + raise ValueError("Must provide at least one structure") + + # Add check that `shallow_structure` really is the shallowest. + # Also only call `func` on `structures` and not `shallow_structure`. + def func_with_check_without_shallow_structure(shallow, *args): + if not _tree_is_leaf(shallow): + raise ValueError("Structures don't have the same nested structure.") + return func(*args) + + return torch_tree.tree_map( + func_with_check_without_shallow_structure, + shallow_structure, + *structures, + ) + + +def assert_same_structure(a, b): + def check(a_leaf, b_leaf): + if not _tree_is_leaf(a_leaf) or not _tree_is_leaf(b_leaf): + raise ValueError("Structures don't have the same nested structure.") + return None + + torch_tree.tree_map(check, a, b) + + +def assert_same_paths(a, b): + a_paths = set([path for path, _ in flatten_with_path(a)]) + b_paths = set([path for path, _ in flatten_with_path(b)]) + + if a_paths != b_paths: + msg = "`a` and `b` don't have the same paths." + a_diff = a_paths.difference(b_paths) + if a_diff: + msg += f"\nPaths in `a` missing in `b`:\n{a_diff}" + b_diff = b_paths.difference(a_paths) + if b_diff: + msg += f"\nPaths in `b` missing in `a`:\n{b_diff}" + raise ValueError(msg) + + +def pack_sequence_as(structure, flat_sequence): + # We need to first sort dicts to ensure a deterministic order that is + # consistent with other tree implementations. + structure = _dict_to_ordered_dict(structure) + _, treespec = torch_tree.tree_flatten(structure) + return torch_tree.tree_unflatten(flat_sequence, treespec) + + +def lists_to_tuples(structure): + def list_to_tuple(instance): + return tuple(instance) if isinstance(instance, list) else None + + return traverse(list_to_tuple, structure, top_down=False) + + +def map_shape_structure(func, structure): + def is_shape_tuple(x): + return isinstance(x, (list, tuple)) and all( + isinstance(e, (int, type(None))) for e in x + ) + + # We need to first sort dicts to ensure a deterministic order that is + # consistent with other tree implementations. + structure = _dict_to_ordered_dict(structure) + return torch_tree.tree_map(func, structure, is_leaf=is_shape_tuple) diff --git a/keras/src/tree/tree_api.py b/keras/src/tree/tree_api.py new file mode 100644 index 000000000000..d4e476de5e45 --- /dev/null +++ b/keras/src/tree/tree_api.py @@ -0,0 +1,412 @@ +import warnings + +from keras.src.api_export import keras_export +from keras.src.backend.config import backend +from keras.src.utils.module_utils import dmtree +from keras.src.utils.module_utils import optree + +if backend() == "torch": + # torchtree_impl is especially used for Torch backend, as it works better + # with torch.compile. + from keras.src.tree import torchtree_impl as tree_impl +elif optree.available: + from keras.src.tree import optree_impl as tree_impl +elif dmtree.available: + from keras.src.tree import dmtree_impl as tree_impl +else: + raise ImportError( + "To use Keras, you need to have `optree` installed. " + "Install it via `pip install optree`" + ) + + +def register_tree_node_class(cls): + return tree_impl.register_tree_node_class(cls) + + +@keras_export("keras.tree.MAP_TO_NONE") +class MAP_TO_NONE: + """Special value for use with `traverse()`.""" + + pass + + +@keras_export("keras.tree.is_nested") +def is_nested(structure): + """Checks if a given structure is nested. + + Examples: + + >>> keras.tree.is_nested(42) + False + >>> keras.tree.is_nested({"foo": 42}) + True + + Args: + structure: A structure to check. + + Returns: + `True` if a given structure is nested, i.e. is a sequence, a mapping, + or a namedtuple, and `False` otherwise. + """ + return tree_impl.is_nested(structure) + + +@keras_export("keras.tree.traverse") +def traverse(func, structure, top_down=True): + """Traverses the given nested structure, applying the given function. + + The traversal is depth-first. If `top_down` is True (default), parents + are returned before their children (giving the option to avoid traversing + into a sub-tree). + + Examples: + + >>> v = [] + >>> keras.tree.traverse(v.append, [(1, 2), [3], {"a": 4}], top_down=True) + [(1, 2), [3], {'a': 4}] + >>> v + [[(1, 2), [3], {'a': 4}], (1, 2), 1, 2, [3], 3, {'a': 4}, 4] + + >>> v = [] + >>> keras.tree.traverse(v.append, [(1, 2), [3], {"a": 4}], top_down=False) + [(1, 2), [3], {'a': 4}] + >>> v + [1, 2, (1, 2), 3, [3], 4, {'a': 4}, [(1, 2), [3], {'a': 4}]] + + Args: + func: The function to be applied to each sub-nest of the structure. + + When traversing top-down: + If `func(subtree) is None` the traversal continues into the + sub-tree. + If `func(subtree) is not None` the traversal does not continue + into the sub-tree. The sub-tree will be replaced by `func(subtree)` + in the returned structure (to replace the sub-tree with `None`, use + the special value `MAP_TO_NONE`). + + When traversing bottom-up: + If `func(subtree) is None` the traversed sub-tree is returned + unaltered. + If `func(subtree) is not None` the sub-tree will be replaced by + `func(subtree)` in the returned structure (to replace the sub-tree + with None, use the special value `MAP_TO_NONE`). + + structure: The structure to traverse. + top_down: If True, parent structures will be visited before their + children. + + Returns: + The structured output from the traversal. + + Raises: + TypeError: If `func` is not callable. + """ + return tree_impl.traverse(func, structure, top_down=top_down) + + +@keras_export("keras.tree.flatten") +def flatten(structure): + """Flattens a possibly nested structure into a list. + + In the case of dict instances, the sequence consists of the values, + sorted by key to ensure deterministic behavior. However, instances of + `collections.OrderedDict` are handled differently: their sequence order is + used instead of the sorted keys. The same convention is followed in + `pack_sequence_as`. This correctly unflattens dicts and `OrderedDict` after + they have been flattened, or vice-versa. + + Dictionaries with non-sortable keys are not supported. + + Examples: + + >>> keras.tree.flatten([[1, 2, 3], [4, [5], [[6]]]]) + [1, 2, 3, 4, 5, 6] + >>> keras.tree.flatten(None) + [None] + >>> keras.tree.flatten(1) + [1] + >>> keras.tree.flatten({100: 'world!', 6: 'Hello'}) + ['Hello', 'world!'] + + Args: + structure: An arbitrarily nested structure. + + Returns: + A list, the flattened version of the input `structure`. + """ + return tree_impl.flatten(structure) + + +@keras_export("keras.tree.flatten_with_path") +def flatten_with_path(structure): + """Flattens a possibly nested structure into a list. + + This is a variant of flattens() which produces a + list of pairs: `(path, item)`. A path is a tuple of indices and/or keys + which uniquely identifies the position of the corresponding item. + + Dictionaries with non-sortable keys are not supported. + + Examples: + + >>> keras.flatten_with_path([{"foo": 42}]) + [((0, 'foo'), 42)] + + + Args: + structure: An arbitrarily nested structure. + + Returns: + A list of `(path, item)` pairs corresponding to the flattened + version of the input `structure`. + """ + return tree_impl.flatten_with_path(structure) + + +@keras_export("keras.tree.map_structure") +def map_structure(func, *structures, none_is_leaf=True): + """Maps `func` through given structures. + + Examples: + + >>> structure = [[1], [2], [3]] + >>> keras.tree.map_structure(lambda v: v**2, structure) + [[1], [4], [9]] + >>> keras.tree.map_structure(lambda x, y: x * y, structure, structure) + [[1], [4], [9]] + + >>> Foo = collections.namedtuple('Foo', ['a', 'b']) + >>> structure = Foo(a=1, b=2) + >>> keras.tree.map_structure(lambda v: v * 2, structure) + Foo(a=2, b=4) + + Args: + func: A callable that accepts as many arguments as there are structures. + *structures: Arbitrarily nested structures of the same layout. + none_is_leaf: If True, `func` will be called on `None` leaves. If False, + `None` values are not passed to `func` and are returned in the + output directly. + + Returns: + A new structure with the same layout as the given ones. + + Raises: + TypeError: If `structures` is empty or `func` is not callable. + ValueError: If there is more than one items in `structures` and some of + the nested structures don't match according to the rules of + `assert_same_structure`. + """ + return tree_impl.map_structure(func, *structures, none_is_leaf=none_is_leaf) + + +@keras_export("keras.tree.map_structure_up_to") +def map_structure_up_to(shallow_structure, func, *structures): + """Maps `func` through given structures up to `shallow_structure`. + + This is a variant of `map_structure` which only maps the given structures + up to `shallow_structure`. All further nested components are retained as-is. + + Examples: + + >>> shallow_structure = [None, None] + >>> structure = [[1, 1], [2, 2]] + >>> keras.tree.map_structure_up_to(shallow_structure, len, structure) + [2, 2] + + >>> shallow_structure = [None, [None, None]] + >>> keras.tree.map_structure_up_to(shallow_structure, str, structure) + ['[1, 1]', ['2', '2']] + + Args: + shallow_structure: A structure with layout common to all `structures`. + func: A callable that accepts as many arguments as there are structures. + *structures: Arbitrarily nested structures of the same layout. + + Returns: + A new structure with the same layout as `shallow_structure`. + + Raises: + TypeError: If `structures` is empty or `func` is not callable. + ValueError: If one of the items in `structures` doesn't match the + nested structure of `shallow_structure` according to the rules of + `assert_same_structure`. Items in `structures` are allowed to be + nested deeper than `shallow_structure`, but they cannot be + shallower. + """ + return tree_impl.map_structure_up_to(shallow_structure, func, *structures) + + +@keras_export("keras.tree.assert_same_structure") +def assert_same_structure(a, b, check_types=None): + """Asserts that two structures are nested in the same way. + + This function verifies that the nested structures match. The leafs can be of + any type. At each level, the structures must be of the same type and have + the same number of elements. Instances of `dict`, `OrderedDict` and + `defaultdict` are all considered the same as long as they have the same set + of keys. However, `list`, `tuple`, `namedtuple` and `deque` are not the same + structures. Two namedtuples with identical fields and even identical names + are not the same structures. + + Examples: + + >>> keras.tree.assert_same_structure([(0, 1)], [(2, 3)]) + + >>> Foo = collections.namedtuple('Foo', ['a', 'b']) + >>> AlsoFoo = collections.namedtuple('Foo', ['a', 'b']) + >>> keras.tree.assert_same_structure(Foo(0, 1), Foo(2, 3)) + >>> keras.tree.assert_same_structure(Foo(0, 1), AlsoFoo(2, 3)) + Traceback (most recent call last): + ... + ValueError: The two structures don't have the same nested structure. + ... + + Args: + a: an arbitrarily nested structure. + b: an arbitrarily nested structure. + check_types: Deprecated. The behavior of this flag was inconsistent, it + no longer has any effect. For a looser check, use + `assert_same_paths` instead, which considers `list`, `tuple`, + `namedtuple` and `deque` as matching structures. + + Raises: + ValueError: If the two structures `a` and `b` don't match. + """ + if check_types is not None: + if check_types: + warnings.warn( + "The `check_types` argument is deprecated and no longer has " + "any effect, please remove.", + DeprecationWarning, + stacklevel=2, + ) + else: + warnings.warn( + "The `check_types` argument is deprecated and no longer has " + "any effect. For a looser check, use " + "`keras.tree.assert_same_paths()`, which considers `list`, " + "`tuple`, `namedtuple` and `deque` as matching", + DeprecationWarning, + stacklevel=2, + ) + return tree_impl.assert_same_structure(a, b) + + +@keras_export("keras.tree.assert_same_paths") +def assert_same_paths(a, b): + """Asserts that two structures have identical paths in their tree structure. + + This function verifies that two nested structures have the same paths. + Unlike `assert_same_structure`, this function only checks the paths + and ignores the collection types. + For Sequences, to path is the index: 0, 1, 2, etc. For Mappings, the path is + the key, for instance "a", "b", "c". Note that namedtuples also use indices + and not field names for the path. + + Examples: + >>> keras.tree.assert_same_paths([0, 1], (2, 3)) + >>> Point1 = collections.namedtuple('Point1', ['x', 'y']) + >>> Point2 = collections.namedtuple('Point2', ['x', 'y']) + >>> keras.tree.assert_same_paths(Point1(0, 1), Point2(2, 3)) + + Args: + a: an arbitrarily nested structure. + b: an arbitrarily nested structure. + + Raises: + ValueError: If the paths in structure `a` don't match the paths in + structure `b`. The error message will include the specific paths + that differ. + """ + return tree_impl.assert_same_paths(a, b) + + +@keras_export("keras.tree.pack_sequence_as") +def pack_sequence_as(structure, flat_sequence): + """Returns a given flattened sequence packed into a given structure. + + If `structure` is an atom, `flat_sequence` must be a single-item list; in + this case the return value is `flat_sequence[0]`. + + If `structure` is or contains a dict instance, the keys will be sorted to + pack the flat sequence in deterministic order. However, instances of + `collections.OrderedDict` are handled differently: their sequence order is + used instead of the sorted keys. The same convention is followed in + `flatten`. This correctly repacks dicts and `OrderedDicts` after they have + been flattened, or vice-versa. + + Dictionaries with non-sortable keys are not supported. + + Examples: + + >>> structure = {"key3": "", "key1": "", "key2": ""} + >>> flat_sequence = ["value1", "value2", "value3"] + >>> keras.tree.pack_sequence_as(structure, flat_sequence) + {"key3": "value3", "key1": "value1", "key2": "value2"} + + >>> structure = (("a", "b"), ("c", "d", "e"), "f") + >>> flat_sequence = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] + >>> keras.tree.pack_sequence_as(structure, flat_sequence) + ((1.0, 2.0), (3.0, 4.0, 5.0), 6.0) + + >>> structure = {"key3": {"c": ("alpha", "beta"), "a": ("gamma")}, + ... "key1": {"e": "val1", "d": "val2"}} + >>> flat_sequence = ["val2", "val1", 3.0, 1.0, 2.0] + >>> keras.tree.pack_sequence_as(structure, flat_sequence) + {'key3': {'c': (1.0, 2.0), 'a': 3.0}, 'key1': {'e': 'val1', 'd': 'val2'}} + + >>> structure = ["a"] + >>> flat_sequence = [np.array([[1, 2], [3, 4]])] + >>> keras.tree.pack_sequence_as(structure, flat_sequence) + [array([[1, 2], + [3, 4]])] + + >>> structure = ["a"] + >>> flat_sequence = [keras.ops.ones([2, 2])] + >>> keras.tree.pack_sequence_as(structure, flat_sequence) + [array([[1., 1.], + [1., 1.]]] + + Args: + structure: Arbitrarily nested structure. + flat_sequence: Flat sequence to pack. + + Returns: + `flat_sequence` converted to have the same recursive structure as + `structure`. + + Raises: + TypeError: If `flat_sequence` is not iterable. + ValueError: If `flat_sequence` cannot be repacked as `structure`; for + instance, if `flat_sequence` has too few or too many elements. + """ + return tree_impl.pack_sequence_as(structure, flat_sequence) + + +@keras_export("keras.tree.lists_to_tuples") +def lists_to_tuples(structure): + """Returns the structure with list instances changed to tuples. + + Args: + structure: Arbitrarily nested structure. + + Returns: + The same structure but with tuples instead of lists. + """ + return tree_impl.lists_to_tuples(structure) + + +@keras_export("keras.tree.map_shape_structure") +def map_shape_structure(func, structure): + """Variant of keras.tree.map_structure that operates on shape tuples. + + Tuples containing ints and Nones are considered shapes and passed to `func`. + + Args: + structure: Arbitrarily nested structure. + + Returns: + The same structure with `func` applied. + """ + return tree_impl.map_shape_structure(func, structure) diff --git a/keras/src/tree/tree_test.py b/keras/src/tree/tree_test.py new file mode 100644 index 000000000000..fa026dc0c764 --- /dev/null +++ b/keras/src/tree/tree_test.py @@ -0,0 +1,2332 @@ +import functools +from collections import OrderedDict +from collections import defaultdict +from collections import deque +from collections import namedtuple + +import numpy as np +import pytest +from absl.testing import parameterized + +from keras.src import backend +from keras.src import ops +from keras.src import testing +from keras.src.tree.tree_api import MAP_TO_NONE +from keras.src.utils.module_utils import dmtree +from keras.src.utils.module_utils import optree +from keras.src.utils.tracking import TrackedDict +from keras.src.utils.tracking import TrackedList +from keras.src.utils.tracking import TrackedSet + +TEST_CASES = [] +if dmtree.available: + from keras.src.tree import dmtree_impl + + TEST_CASES += [ + { + "testcase_name": "dmtree", + "t": dmtree_impl, + } + ] +if backend.backend() != "torch" and optree.available: + from keras.src.tree import optree_impl + + TEST_CASES += [ + { + "testcase_name": "optree", + "t": optree_impl, + }, + ] +if backend.backend() == "torch": + from keras.src.tree import torchtree_impl + + TEST_CASES += [ + { + "testcase_name": "torchtree", + "t": torchtree_impl, + }, + ] + + +Empty = namedtuple("Empty", []) +Point = namedtuple("Point", ["x", "y"]) +OtherPoint = namedtuple("OtherPoint", ["x", "y"]) + + +def default_value(): + return None + + +class Visitor: + def __init__(self, func): + self.func = func + self.visited_list = [] + + def __call__(self, x): + self.visited_list.append(x) + return self.func(x) + + def visited(self): + ret = self.visited_list + self.visited_list = [] + return ret + + +@parameterized.named_parameters(TEST_CASES) +class TreeTest(testing.TestCase): + def setUp(self): + if dmtree.available and optree.available: + # If both are available, the annotation on the Keras tracking + # wrappers will have used optree. For testing purposes, we need to + # also register them with dm-tree. + from keras.src.tree import dmtree_impl + + dmtree_impl.register_tree_node_class(TrackedList) + dmtree_impl.register_tree_node_class(TrackedSet) + dmtree_impl.register_tree_node_class(TrackedDict) + super().setUp() + + def assertEqualStrict(self, a, b): + self.assertEqual(a, b) + self.assertEqual(type(a), type(b)) + if isinstance(a, OrderedDict): + # Verify order. + self.assertEqual(a.items(), b.items()) + elif isinstance(a, defaultdict): + self.assertEqual(a.default_factory, b.default_factory) + # Recurse + if isinstance(a, (tuple, list, deque)): + for sub_a, sub_b in zip(a, b): + self.assertEqualStrict(sub_a, sub_b) + elif isinstance(a, dict): + for k in a: + self.assertEqualStrict(a[k], b[k]) + + def is_dmtree(self, tree_impl): + if dmtree.available: + from keras.src.tree import dmtree_impl + + return tree_impl is dmtree_impl + return False + + def test_is_nested(self, t): + # Non-nested. + self.assertFalse(t.is_nested(1)) + self.assertFalse(t.is_nested("1234")) + self.assertFalse(t.is_nested(b"1234")) + self.assertFalse(t.is_nested(bytearray("1234", "ascii"))) + self.assertFalse(t.is_nested(np.ones((4, 5)))) + self.assertFalse(t.is_nested(ops.ones((4, 5)))) + self.assertFalse(t.is_nested(set([1, 2]))) + + # Standard structures. + self.assertTrue(t.is_nested(())) + self.assertTrue(t.is_nested((1,))) + self.assertTrue(t.is_nested((1, 2))) + self.assertTrue(t.is_nested([])) + self.assertTrue(t.is_nested([1])) + self.assertTrue(t.is_nested([1, 2])) + self.assertTrue(t.is_nested(deque([]))) + self.assertTrue(t.is_nested(deque([1]))) + self.assertTrue(t.is_nested(deque([1, 2]))) + self.assertTrue(t.is_nested(Empty())) + self.assertTrue(t.is_nested(Point(x=1, y=2))) + self.assertTrue(t.is_nested({})) + self.assertTrue(t.is_nested({"a": 1})) + self.assertTrue(t.is_nested({"b": 2, "a": 1})) + self.assertTrue(t.is_nested(OrderedDict())) + self.assertTrue(t.is_nested(OrderedDict([("a", 1)]))) + self.assertTrue(t.is_nested(OrderedDict([("b", 2), ("a", 1)]))) + self.assertTrue(t.is_nested(defaultdict(default_value))) + self.assertTrue(t.is_nested(defaultdict(default_value, [("a", 1)]))) + self.assertTrue( + t.is_nested(defaultdict(default_value, [("b", 2), ("a", 1)])) + ) + + # Keras tracking wrappers. + self.assertTrue(t.is_nested(TrackedList([]))) + self.assertTrue(t.is_nested(TrackedList([1]))) + self.assertTrue(t.is_nested(TrackedList([1, 2]))) + self.assertTrue(t.is_nested(TrackedSet([]))) + self.assertTrue(t.is_nested(TrackedSet([1]))) + self.assertTrue(t.is_nested(TrackedSet([1, 2]))) + self.assertTrue(t.is_nested(TrackedDict({}))) + self.assertTrue(t.is_nested(TrackedDict({"a": 1}))) + self.assertTrue(t.is_nested(TrackedDict({"b": 2, "a": 1}))) + + @pytest.mark.skipif(backend.backend() != "tensorflow", reason="tf only") + def test_is_nested_tf_wrappers(self, t): + from tensorflow.python.trackable.data_structures import ListWrapper + from tensorflow.python.trackable.data_structures import _DictWrapper + + self.assertTrue(t.is_nested(ListWrapper([]))) + self.assertTrue(t.is_nested(ListWrapper([1]))) + self.assertTrue(t.is_nested(ListWrapper([1, 2]))) + self.assertTrue(t.is_nested(_DictWrapper({}))) + self.assertTrue(t.is_nested(_DictWrapper({"a": 1}))) + self.assertTrue(t.is_nested(_DictWrapper({"b": 2, "a": 1}))) + + def test_flatten(self, t): + # Non-nested. + self.assertEqualStrict(t.flatten(1), [1]) + + # Standard structures. + self.assertEqualStrict(t.flatten(()), []) + self.assertEqualStrict(t.flatten((1,)), [1]) + self.assertEqualStrict(t.flatten((1, 2)), [1, 2]) + self.assertEqualStrict(t.flatten([]), []) + self.assertEqualStrict(t.flatten([1]), [1]) + self.assertEqualStrict(t.flatten([1, 2]), [1, 2]) + self.assertEqualStrict(t.flatten(deque([])), []) + self.assertEqualStrict(t.flatten(deque([1])), [1]) + self.assertEqualStrict(t.flatten(deque([1, 2])), [1, 2]) + self.assertEqualStrict(t.flatten(Empty()), []) + self.assertEqualStrict(t.flatten(Point(y=2, x=1)), [1, 2]) + self.assertEqualStrict(t.flatten({}), []) + self.assertEqualStrict(t.flatten({"a": 1}), [1]) + self.assertEqualStrict(t.flatten({"b": 2, "a": 1}), [1, 2]) + self.assertEqualStrict( + t.flatten(OrderedDict()), + [], + ) + self.assertEqualStrict( + t.flatten(OrderedDict([("a", 1)])), + [1], + ) + self.assertEqualStrict( + t.flatten(OrderedDict([("b", 2), ("a", 1)])), + [2, 1], + ) + self.assertEqualStrict( + t.flatten(defaultdict(default_value)), + [], + ) + self.assertEqualStrict( + t.flatten(defaultdict(default_value, [("a", 1)])), + [1], + ) + self.assertEqualStrict( + t.flatten(defaultdict(default_value, [("b", 2), ("a", 1)])), + [1, 2], + ) + + # Keras tracking wrappers. + self.assertEqualStrict(t.flatten(TrackedList([])), []) + self.assertEqualStrict(t.flatten(TrackedList([1])), [1]) + self.assertEqualStrict(t.flatten(TrackedList([1, 2])), [1, 2]) + self.assertEqualStrict(t.flatten(TrackedSet([])), []) + self.assertEqualStrict(t.flatten(TrackedSet([1])), [1]) + self.assertEqualStrict(sorted(t.flatten(TrackedSet([1, 2]))), [1, 2]) + self.assertEqualStrict(t.flatten(TrackedDict({})), []) + self.assertEqualStrict(t.flatten(TrackedDict({"a": 1})), [1]) + self.assertEqualStrict(t.flatten(TrackedDict({"b": 2, "a": 1})), [1, 2]) + + # Deeper nested structures. + self.assertEqualStrict( + t.flatten( + ( + {"b": [2, 3], "a": (1,)}, + TrackedDict({"x": 4, "y": TrackedList([5, 6])}), + TrackedSet([7]), + Point(y=9, x=8), + np.array([10]), + ) + ), + [1, 2, 3, 4, 5, 6, 7, 8, 9, np.array([10])], + ) + + @pytest.mark.skipif(backend.backend() != "tensorflow", reason="tf only") + def test_flatten_tf_wrappers(self, t): + from tensorflow.python.trackable.data_structures import ListWrapper + from tensorflow.python.trackable.data_structures import _DictWrapper + + self.assertEqualStrict(t.flatten(ListWrapper([])), []) + self.assertEqualStrict(t.flatten(ListWrapper([1])), [1]) + self.assertEqualStrict(t.flatten(ListWrapper([1, 2])), [1, 2]) + self.assertEqualStrict(t.flatten(_DictWrapper({})), []) + self.assertEqualStrict(t.flatten(_DictWrapper({"a": 1})), [1]) + self.assertEqualStrict( + t.flatten(_DictWrapper({"b": 2, "a": 1})), [1, 2] + ) + + def test_flatten_with_path(self, t): + # Non-nested. + self.assertEqualStrict( + t.flatten_with_path(1), + [((), 1)], + ) + + # Standard structures. + self.assertEqualStrict( + t.flatten_with_path(()), + [], + ) + self.assertEqualStrict( + t.flatten_with_path((1,)), + [((0,), 1)], + ) + self.assertEqualStrict( + t.flatten_with_path((1, 2)), + [((0,), 1), ((1,), 2)], + ) + self.assertEqualStrict( + t.flatten_with_path([]), + [], + ) + self.assertEqualStrict( + t.flatten_with_path([1]), + [((0,), 1)], + ) + self.assertEqualStrict( + t.flatten_with_path([1, 2]), + [((0,), 1), ((1,), 2)], + ) + self.assertEqualStrict( + t.flatten_with_path(deque([])), + [], + ) + self.assertEqualStrict( + t.flatten_with_path(deque([1])), + [((0,), 1)], + ) + self.assertEqualStrict( + t.flatten_with_path(deque([1, 2])), + [((0,), 1), ((1,), 2)], + ) + self.assertEqualStrict( + t.flatten_with_path(Empty()), + [], + ) + self.assertEqualStrict( + t.flatten_with_path(Point(y=2, x=1)), + [((0,), 1), ((1,), 2)], + ) + self.assertEqualStrict( + t.flatten_with_path({}), + [], + ) + self.assertEqualStrict( + t.flatten_with_path({"a": 1}), + [(("a",), 1)], + ) + self.assertEqualStrict( + t.flatten_with_path({"b": 2, "a": 1}), + [(("a",), 1), (("b",), 2)], + ) + self.assertEqualStrict( + t.flatten_with_path(OrderedDict()), + [], + ) + self.assertEqualStrict( + t.flatten_with_path(OrderedDict([("a", 1)])), + [(("a",), 1)], + ) + self.assertEqualStrict( + t.flatten_with_path(OrderedDict([("b", 2), ("a", 1)])), + [(("b",), 2), (("a",), 1)], + ) + self.assertEqualStrict( + t.flatten_with_path(defaultdict(default_value)), + [], + ) + self.assertEqualStrict( + t.flatten_with_path(defaultdict(default_value, [("a", 1)])), + [(("a",), 1)], + ) + self.assertEqualStrict( + t.flatten_with_path( + defaultdict(default_value, [("b", 2), ("a", 1)]) + ), + [(("a",), 1), (("b",), 2)], + ) + + # Keras tracking wrappers. + self.assertEqualStrict( + t.flatten_with_path(TrackedList([])), + [], + ) + self.assertEqualStrict( + t.flatten_with_path(TrackedList([1])), + [((0,), 1)], + ) + self.assertEqualStrict( + t.flatten_with_path(TrackedList([1, 2])), + [((0,), 1), ((1,), 2)], + ) + self.assertEqualStrict( + t.flatten_with_path(TrackedSet([])), + [], + ) + self.assertEqualStrict( + t.flatten_with_path(TrackedSet([1])), + [((0,), 1)], + ) + flat = t.flatten_with_path(TrackedSet([1, 2])) + if flat[0][1] == 1: + self.assertEqualStrict(flat, [((0,), 1), ((1,), 2)]) + else: + self.assertEqualStrict(flat, [((0,), 2), ((1,), 1)]) + self.assertEqualStrict( + t.flatten_with_path(TrackedDict({})), + [], + ) + self.assertEqualStrict( + t.flatten_with_path(TrackedDict({"a": 1})), + [(("a",), 1)], + ) + self.assertEqualStrict( + t.flatten_with_path(TrackedDict({"b": 2, "a": 1})), + [(("a",), 1), (("b",), 2)], + ) + + # Deeper nested structures. + self.assertEqualStrict( + t.flatten_with_path( + ( + {"b": [2, 3], "a": (1,)}, + TrackedDict({"x": 4, "y": TrackedList([5, 6])}), + TrackedSet([7]), + Point(y=9, x=8), + np.array([10]), + ) + ), + [ + ((0, "a", 0), 1), + ((0, "b", 0), 2), + ((0, "b", 1), 3), + ((1, "x"), 4), + ((1, "y", 0), 5), + ((1, "y", 1), 6), + ((2, 0), 7), + ((3, 0), 8), + ((3, 1), 9), + ((4,), np.array([10])), + ], + ) + + @pytest.mark.skipif(backend.backend() != "tensorflow", reason="tf only") + def test_flatten_with_path_tf_wrappers(self, t): + from tensorflow.python.trackable.data_structures import ListWrapper + from tensorflow.python.trackable.data_structures import _DictWrapper + + self.assertEqualStrict( + t.flatten_with_path(ListWrapper([])), + [], + ) + self.assertEqualStrict( + t.flatten_with_path(ListWrapper([1])), + [((0,), 1)], + ) + self.assertEqualStrict( + t.flatten_with_path(ListWrapper([1, 2])), + [((0,), 1), ((1,), 2)], + ) + self.assertEqualStrict( + t.flatten_with_path(_DictWrapper({})), + [], + ) + self.assertEqualStrict( + t.flatten_with_path(_DictWrapper({"a": 1})), + [(("a",), 1)], + ) + self.assertEqualStrict( + t.flatten_with_path(_DictWrapper({"b": 2, "a": 1})), + [(("a",), 1), (("b",), 2)], + ) + + def test_pack_sequence_as(self, t): + # Non-nested. + self.assertEqualStrict(t.pack_sequence_as(10, [1]), 1) + + # Standard structures. + self.assertEqualStrict(t.pack_sequence_as((), []), ()) + self.assertEqualStrict(t.pack_sequence_as((10,), [1]), (1,)) + self.assertEqualStrict(t.pack_sequence_as((10, 20), [1, 2]), (1, 2)) + self.assertEqualStrict(t.pack_sequence_as([], []), []) + self.assertEqualStrict(t.pack_sequence_as([10], [1]), [1]) + self.assertEqualStrict(t.pack_sequence_as([10, 20], [1, 2]), [1, 2]) + self.assertEqualStrict(t.pack_sequence_as(deque([]), []), deque([])) + self.assertEqualStrict(t.pack_sequence_as(deque([10]), [1]), deque([1])) + self.assertEqualStrict( + t.pack_sequence_as(deque([10, 20]), [1, 2]), deque([1, 2]) + ) + self.assertEqualStrict(t.pack_sequence_as(Empty(), []), Empty()) + self.assertEqualStrict( + t.pack_sequence_as(Point(y=20, x=10), [1, 2]), Point(x=1, y=2) + ) + self.assertEqualStrict(t.pack_sequence_as({}, []), {}) + self.assertEqualStrict(t.pack_sequence_as({"a": 10}, [1]), {"a": 1}) + self.assertEqualStrict( + t.pack_sequence_as({"b": 20, "a": 10}, [1, 2]), {"a": 1, "b": 2} + ) + self.assertEqualStrict( + t.pack_sequence_as(OrderedDict(), []), OrderedDict() + ) + self.assertEqualStrict( + t.pack_sequence_as(OrderedDict([("a", 10)]), [1]), + OrderedDict([("a", 1)]), + ) + self.assertEqualStrict( + t.pack_sequence_as(OrderedDict([("b", 20), ("a", 10)]), [2, 1]), + OrderedDict([("b", 2), ("a", 1)]), + ) + self.assertEqualStrict( + t.pack_sequence_as(defaultdict(default_value), []), + defaultdict(default_value), + ) + self.assertEqualStrict( + t.pack_sequence_as(defaultdict(default_value, [("a", 10)]), [1]), + defaultdict(default_value, [("a", 1)]), + ) + self.assertEqualStrict( + t.pack_sequence_as( + defaultdict(default_value, [("b", 20), ("a", 10)]), [1, 2] + ), + defaultdict(default_value, [("a", 1), ("b", 2)]), + ) + + # Keras tracking wrappers. + self.assertEqualStrict( + t.pack_sequence_as(TrackedList([]), []), TrackedList([]) + ) + self.assertEqualStrict( + t.pack_sequence_as(TrackedList([10]), [1]), TrackedList([1]) + ) + self.assertEqualStrict( + t.pack_sequence_as(TrackedList([10, 20]), [1, 2]), + TrackedList([1, 2]), + ) + self.assertEqualStrict( + t.pack_sequence_as(TrackedSet([]), []), TrackedSet([]) + ) + self.assertEqualStrict( + t.pack_sequence_as(TrackedSet([10]), [1]), TrackedSet([1]) + ) + self.assertEqualStrict( + t.pack_sequence_as(TrackedSet([10, 20]), [1, 2]), TrackedSet([1, 2]) + ) + self.assertEqualStrict( + t.pack_sequence_as(TrackedDict({}), []), TrackedDict({}) + ) + self.assertEqualStrict( + t.pack_sequence_as(TrackedDict({"a": 10}), [1]), + TrackedDict({"a": 1}), + ) + self.assertEqualStrict( + t.pack_sequence_as(TrackedDict({"b": 20, "a": 10}), [1, 2]), + TrackedDict({"a": 1, "b": 2}), + ) + + # Deeper nested structures. + self.assertEqualStrict( + t.pack_sequence_as( + ( + {"b": [20, 30], "a": (10,)}, + TrackedDict({"x": 40, "y": TrackedList([50, 60])}), + TrackedSet([70]), + Point(y=90, x=80), + 100, + ), + [1, 2, 3, 4, 5, 6, 7, 8, 9, np.array([10])], + ), + ( + {"b": [2, 3], "a": (1,)}, + TrackedDict({"x": 4, "y": TrackedList([5, 6])}), + TrackedSet([7]), + Point(x=8, y=9), + np.array([10]), + ), + ) + + # Error cases. + with self.assertRaisesRegex(TypeError, "[Ii]terable"): + t.pack_sequence_as([10, 20], 1) + with self.assertRaisesRegex(ValueError, "leaves.*[expected:|holds] 1"): + t.pack_sequence_as(10, []) + with self.assertRaisesRegex(ValueError, "leaves.*[expected:|holds] 1"): + t.pack_sequence_as(10, [1, 2]) + with self.assertRaisesRegex(ValueError, "[Too few leaves|holds 2]"): + t.pack_sequence_as([10, 20], [1]) + with self.assertRaisesRegex(ValueError, "[Too many leaves|holds 3]"): + t.pack_sequence_as([10, 20], [1, 2, 3]) + + @pytest.mark.skipif(backend.backend() != "tensorflow", reason="tf only") + def test_pack_sequence_as_tf_wrappers(self, t): + from tensorflow.python.trackable.data_structures import ListWrapper + from tensorflow.python.trackable.data_structures import _DictWrapper + + self.assertEqualStrict( + t.pack_sequence_as(ListWrapper([]), []), ListWrapper([]) + ) + self.assertEqualStrict( + t.pack_sequence_as(ListWrapper([10]), [1]), ListWrapper([1]) + ) + self.assertEqualStrict( + t.pack_sequence_as(ListWrapper([10, 20]), [1, 2]), + ListWrapper([1, 2]), + ) + self.assertEqualStrict( + t.pack_sequence_as(_DictWrapper({}), []), _DictWrapper({}) + ) + self.assertEqualStrict( + t.pack_sequence_as(_DictWrapper({"a": 10}), [1]), + _DictWrapper({"a": 1}), + ) + self.assertEqualStrict( + t.pack_sequence_as(_DictWrapper({"b": 20, "a": 10}), [1, 2]), + _DictWrapper({"b": 2, "a": 1}), + ) + + def test_map_structure_with_one_structure(self, t): + def f1(x): + return x + 10 if isinstance(x, int) else None + + # Non-nested. + self.assertEqualStrict(t.map_structure(f1, 1), 11) + + # Standard structures. + self.assertEqualStrict(t.map_structure(f1, ()), ()) + self.assertEqualStrict(t.map_structure(f1, (1,)), (11,)) + self.assertEqualStrict(t.map_structure(f1, (1, 2)), (11, 12)) + self.assertEqualStrict(t.map_structure(f1, []), []) + self.assertEqualStrict(t.map_structure(f1, [1]), [11]) + self.assertEqualStrict(t.map_structure(f1, [1, 2]), [11, 12]) + self.assertEqualStrict(t.map_structure(f1, deque([])), deque([])) + self.assertEqualStrict(t.map_structure(f1, deque([1])), deque([11])) + self.assertEqualStrict( + t.map_structure(f1, deque([1, 2])), deque([11, 12]) + ) + self.assertEqualStrict(t.map_structure(f1, Empty()), Empty()) + self.assertEqualStrict( + t.map_structure(f1, Point(y=2, x=1)), Point(x=11, y=12) + ) + self.assertEqualStrict( + t.map_structure(f1, {}), + {}, + ) + self.assertEqualStrict( + t.map_structure(f1, {"a": 1}), + {"a": 11}, + ) + self.assertEqualStrict( + t.map_structure(f1, {"b": 2, "a": 1}), + {"a": 11, "b": 12}, + ) + self.assertEqualStrict( + t.map_structure(f1, OrderedDict()), + OrderedDict(), + ) + self.assertEqualStrict( + t.map_structure(f1, OrderedDict([("a", 1)])), + OrderedDict([("a", 11)]), + ) + self.assertEqualStrict( + t.map_structure(f1, OrderedDict([("b", 2), ("a", 1)])), + OrderedDict([("b", 12), ("a", 11)]), + ) + self.assertEqualStrict( + t.map_structure(f1, defaultdict(default_value)), + defaultdict(default_value), + ) + self.assertEqualStrict( + t.map_structure(f1, defaultdict(default_value, [("a", 1)])), + defaultdict(default_value, [("a", 11)]), + ) + self.assertEqualStrict( + t.map_structure( + f1, defaultdict(default_value, [("b", 2), ("a", 1)]) + ), + defaultdict(default_value, [("a", 11), ("b", 12)]), + ) + + # Keras tracking wrappers. + self.assertEqualStrict( + t.map_structure(f1, TrackedList([])), TrackedList([]) + ) + self.assertEqualStrict( + t.map_structure(f1, TrackedList([1])), TrackedList([11]) + ) + self.assertEqualStrict( + t.map_structure(f1, TrackedList([1, 2])), TrackedList([11, 12]) + ) + self.assertEqualStrict( + t.map_structure(f1, TrackedSet([])), TrackedSet([]) + ) + self.assertEqualStrict( + t.map_structure(f1, TrackedSet([1])), TrackedSet([11]) + ) + self.assertEqualStrict( + t.map_structure(f1, TrackedSet([1, 2])), TrackedSet([11, 12]) + ) + self.assertEqualStrict( + t.map_structure(f1, TrackedDict()), + TrackedDict(), + ) + self.assertEqualStrict( + t.map_structure(f1, TrackedDict({"a": 1})), + TrackedDict({"a": 11}), + ) + self.assertEqualStrict( + t.map_structure(f1, TrackedDict({"b": 2, "a": 1})), + TrackedDict({"a": 11, "b": 12}), + ) + + # Deeper nested structures. + self.assertEqualStrict( + t.map_structure( + f1, + ( + {"b": [2, 3], "a": (1,)}, + TrackedDict({"x": 4, "y": TrackedList([5, 6])}), + TrackedSet([7]), + Point(y=9, x=8), + np.array([10]), + ), + ), + ( + {"b": [12, 13], "a": (11,)}, + TrackedDict({"x": 14, "y": TrackedList([15, 16])}), + TrackedSet([17]), + Point(y=19, x=18), + None, + ), + ) + + # Error cases. + with self.assertRaisesRegex(TypeError, "callable"): + t.map_structure("bad", [1, 2]) + with self.assertRaisesRegex(ValueError, "at least one structure"): + t.map_structure(f1) + + @pytest.mark.skipif(backend.backend() != "tensorflow", reason="tf only") + def test_map_structure_with_one_structure_tf_wrappers(self, t): + from tensorflow.python.trackable.data_structures import ListWrapper + from tensorflow.python.trackable.data_structures import _DictWrapper + + def f1(x): + return x + 10 + + self.assertEqualStrict( + t.map_structure(f1, ListWrapper([])), ListWrapper([]) + ) + self.assertEqualStrict( + t.map_structure(f1, ListWrapper([1])), ListWrapper([11]) + ) + self.assertEqualStrict( + t.map_structure(f1, ListWrapper([1, 2])), ListWrapper([11, 12]) + ) + self.assertEqualStrict( + t.map_structure(f1, _DictWrapper()), + _DictWrapper(), + ) + self.assertEqualStrict( + t.map_structure(f1, _DictWrapper({"a": 1})), + _DictWrapper({"a": 11}), + ) + self.assertEqualStrict( + t.map_structure(f1, _DictWrapper({"b": 2, "a": 1})), + _DictWrapper({"a": 11, "b": 12}), + ) + + def test_map_structure_with_multiple_structures(self, t): + def f2(x, y): + return x + y if isinstance(x, int) and isinstance(y, int) else None + + # Non-nested. + self.assertEqualStrict(t.map_structure(f2, 1, 10), 11) + + # Standard structures. + self.assertEqualStrict(t.map_structure(f2, ()), ()) + self.assertEqualStrict(t.map_structure(f2, (1,), (10,)), (11,)) + self.assertEqualStrict(t.map_structure(f2, (1, 2), (10, 20)), (11, 22)) + self.assertEqualStrict(t.map_structure(f2, []), []) + self.assertEqualStrict(t.map_structure(f2, [1], [10]), [11]) + self.assertEqualStrict(t.map_structure(f2, [1, 2], [10, 20]), [11, 22]) + self.assertEqualStrict(t.map_structure(f2, deque([])), deque([])) + self.assertEqualStrict( + t.map_structure(f2, deque([1]), deque([10])), deque([11]) + ) + self.assertEqualStrict( + t.map_structure(f2, deque([1, 2]), deque([10, 20])), deque([11, 22]) + ) + self.assertEqualStrict(t.map_structure(f2, Empty()), Empty()) + self.assertEqualStrict( + t.map_structure(f2, Point(y=2, x=1), Point(x=10, y=20)), + Point(x=11, y=22), + ) + self.assertEqualStrict(t.map_structure(f2, {}), {}) + self.assertEqualStrict( + t.map_structure(f2, {"a": 1}, {"a": 10}), {"a": 11} + ) + self.assertEqualStrict( + t.map_structure(f2, {"b": 2, "a": 1}, {"a": 10, "b": 20}), + {"a": 11, "b": 22}, + ) + self.assertEqualStrict( + t.map_structure(f2, OrderedDict()), + OrderedDict(), + ) + self.assertEqualStrict( + t.map_structure( + f2, OrderedDict([("a", 1)]), OrderedDict([("a", 10)]) + ), + OrderedDict([("a", 11)]), + ) + self.assertEqualStrict( + t.map_structure( + f2, + OrderedDict([("b", 2), ("a", 1)]), + OrderedDict([("b", 20), ("a", 10)]), + ), + OrderedDict([("b", 22), ("a", 11)]), + ) + self.assertEqualStrict( + t.map_structure( + f2, defaultdict(default_value), defaultdict(default_value) + ), + defaultdict(default_value), + ) + self.assertEqualStrict( + t.map_structure( + f2, + defaultdict(default_value, [("a", 1)]), + defaultdict(default_value, [("a", 10)]), + ), + defaultdict(default_value, [("a", 11)]), + ) + self.assertEqualStrict( + t.map_structure( + f2, + defaultdict(default_value, [("b", 2), ("a", 1)]), + defaultdict(default_value, [("a", 10), ("b", 20)]), + ), + defaultdict(default_value, [("a", 11), ("b", 22)]), + ) + + # Keras tracking wrappers. + self.assertEqualStrict( + t.map_structure( + f2, + TrackedList([]), + TrackedList([]), + ), + TrackedList([]), + ) + self.assertEqualStrict( + t.map_structure( + f2, + TrackedList([1]), + TrackedList([10]), + ), + TrackedList([11]), + ) + self.assertEqualStrict( + t.map_structure( + f2, + TrackedList([1, 2]), + TrackedList([10, 20]), + ), + TrackedList([11, 22]), + ) + + # Known limitation of the dm-tree implementation: + # Registered classes are not handled when mapping multiple + # structures at once. TrackedSet is the only problematic one. + if not self.is_dmtree(t): + self.assertEqualStrict( + t.map_structure( + f2, + TrackedSet([]), + TrackedSet([]), + ), + TrackedSet([]), + ) + self.assertEqualStrict( + t.map_structure( + f2, + TrackedSet([1]), + TrackedSet([10]), + ), + TrackedSet([11]), + ) + self.assertEqualStrict( + t.map_structure( + f2, + TrackedSet([1, 2]), + TrackedSet([10, 20]), + ), + TrackedSet([11, 22]), + ) + + self.assertEqualStrict( + t.map_structure( + f2, + TrackedDict({}), + TrackedDict({}), + ), + TrackedDict({}), + ) + self.assertEqualStrict( + t.map_structure( + f2, + TrackedDict({"a": 1}), + TrackedDict({"a": 10}), + ), + TrackedDict({"a": 11}), + ) + self.assertEqualStrict( + t.map_structure( + f2, + TrackedDict({"b": 2, "a": 1}), + TrackedDict({"a": 10, "b": 20}), + ), + TrackedDict({"a": 11, "b": 22}), + ) + + # Deeper nested structures. + self.assertEqualStrict( + t.map_structure( + f2, + ( + {"b": [2, 3], "a": (1,)}, + TrackedDict({"x": 4, "y": TrackedList([5, 6])}), + TrackedSet([7]), + Point(y=9, x=8), + np.array([10]), + ), + ( + {"b": [20, 30], "a": (10,)}, + TrackedDict({"x": 40, "y": TrackedList([50, 60])}), + TrackedSet([70]), + Point(y=90, x=80), + np.array([100]), + ), + ), + ( + {"b": [22, 33], "a": (11,)}, + TrackedDict({"x": 44, "y": TrackedList([55, 66])}), + # Known limitation of the dm-tree implementation: + # Registered classes are not handled when mapping multiple + # structures at once. TrackedSet is the only problematic one. + None if self.is_dmtree(t) else TrackedSet([77]), + Point(y=99, x=88), + None, + ), + ) + + # Error cases. + + # list, tuple, deque and namedtuple are not considered equivalent. + # Test all 6 combinations: + # tuple, list. + with self.assertRaisesRegex(ValueError, "tuple"): + t.map_structure(f2, (), []) + # tuple, deque. + with self.assertRaisesRegex(ValueError, "tuple"): + t.map_structure(f2, (), deque()) + # tuple, namedtuple. + with self.assertRaisesRegex(ValueError, "tuple"): + t.map_structure(f2, (), Empty()) + # list, deque. + with self.assertRaisesRegex(ValueError, "list"): + t.map_structure(f2, [], deque()) + # list, namedtuple. + with self.assertRaisesRegex(ValueError, "list"): + t.map_structure(f2, [], Empty()) + # deque, namedtuple. + with self.assertRaisesRegex(ValueError, "deque"): + t.map_structure(f2, deque(), Empty()) + + # Equivalent namedtuples don't match. + with self.assertRaisesRegex(ValueError, "namedtuple"): + t.map_structure(f2, Point(x=1, y=2), OtherPoint(x=10, y=20)) + + # Mismatched counts. + with self.assertRaisesRegex(ValueError, "(number|[Aa]rity)"): + t.map_structure(f2, (1, 2), (1,)) + with self.assertRaisesRegex(ValueError, "(number|[Aa]rity)"): + t.map_structure(f2, [1, 2], [1]) + with self.assertRaisesRegex(ValueError, "(number|[Aa]rity)"): + t.map_structure(f2, deque([1, 2]), deque([1])) + + # dict, OrderedDict, defaultdict are considered equivalent, but the + # returned type is the first one. Test all 6 combinations (3 type + # combinations plus the order). + # dict, OrderedDict yields dict. + self.assertEqualStrict( + t.map_structure( + f2, {"a": 1, "b": 2}, OrderedDict([("b", 20), ("a", 10)]) + ), + {"a": 11, "b": 22}, + ) + # OrderedDict, dict yields OrderedDict with same order. + self.assertEqualStrict( + t.map_structure( + f2, + OrderedDict([("b", 2), ("a", 1)]), + {"a": 10, "b": 20}, + ), + OrderedDict([("b", 22), ("a", 11)]), + ) + # dict, defaultdict yields dict. + self.assertEqualStrict( + t.map_structure( + f2, + {"a": 1, "b": 2}, + defaultdict(default_value, [("b", 20), ("a", 10)]), + ), + {"a": 11, "b": 22}, + ) + # defaultdict, dict yields defaultdict. + self.assertEqualStrict( + t.map_structure( + f2, + defaultdict(default_value, [("b", 2), ("a", 1)]), + {"a": 10, "b": 20}, + ), + defaultdict(default_value, [("a", 11), ("b", 22)]), + ) + # defaultdict, OrderedDict yields defaultdict. + self.assertEqualStrict( + t.map_structure( + f2, + defaultdict(default_value, [("a", 1), ("b", 2)]), + OrderedDict([("b", 20), ("a", 10)]), + ), + defaultdict(default_value, [("a", 11), ("b", 22)]), + ) + # OrderedDict, defaultdict yields OrderedDict with same order. + self.assertEqualStrict( + t.map_structure( + f2, + OrderedDict([("b", 2), ("a", 1)]), + defaultdict(default_value, [("a", 10), ("b", 20)]), + ), + OrderedDict([("b", 22), ("a", 11)]), + ) + + # Multiple OrderedDicts with same keys but different orders, the order + # of the first one prevails. + self.assertEqualStrict( + t.map_structure( + f2, + OrderedDict([("b", 2), ("a", 1)]), + OrderedDict([("a", 10), ("b", 20)]), + ), + OrderedDict([("b", 22), ("a", 11)]), + ) + + # Mismatched keys + with self.assertRaisesRegex(ValueError, "[key|Node arity mismatch]"): + t.map_structure(f2, {"a": 1, "b": 2}, {"a": 1}) + with self.assertRaisesRegex(ValueError, "[key|Node arity mismatch]"): + t.map_structure( + f2, + defaultdict(default_value, [("a", 1), ("b", 2)]), + defaultdict(default_value, [("a", 10)]), + ) + with self.assertRaisesRegex(ValueError, "[key|Node arity mismatch]"): + t.map_structure( + f2, OrderedDict([("a", 1), ("b", 2)]), OrderedDict([("a", 10)]) + ) + + @pytest.mark.skipif(backend.backend() != "tensorflow", reason="tf only") + def test_map_structure_with_multiple_structures_tf_wrappers(self, t): + from tensorflow.python.trackable.data_structures import ListWrapper + from tensorflow.python.trackable.data_structures import _DictWrapper + + def f2(x, y): + return x + y + + self.assertEqualStrict( + t.map_structure( + f2, + ListWrapper([]), + ListWrapper([]), + ), + ListWrapper([]), + ) + self.assertEqualStrict( + t.map_structure( + f2, + ListWrapper([1]), + ListWrapper([10]), + ), + ListWrapper([11]), + ) + self.assertEqualStrict( + t.map_structure( + f2, + ListWrapper([1, 2]), + ListWrapper([10, 20]), + ), + ListWrapper([11, 22]), + ) + self.assertEqualStrict( + t.map_structure( + f2, + _DictWrapper({}), + _DictWrapper({}), + ), + _DictWrapper({}), + ) + self.assertEqualStrict( + t.map_structure( + f2, + _DictWrapper({"a": 1}), + _DictWrapper({"a": 10}), + ), + _DictWrapper({"a": 11}), + ) + self.assertEqualStrict( + t.map_structure( + f2, + _DictWrapper({"b": 2, "a": 1}), + _DictWrapper({"a": 10, "b": 20}), + ), + _DictWrapper({"a": 11, "b": 22}), + ) + + def test_map_structure_up_to(self, t): + # Named tuples. + shallow = OtherPoint(x=2, y=3) + deep = OtherPoint(x=Point(x=1, y=2), y=Point(x=2, y=3)) + out = t.map_structure_up_to( + shallow, + lambda a, b: (a + b.x) * b.y, + shallow, + deep, + ) + self.assertEqual(out.x, 6) + self.assertEqual(out.y, 15) + + # Lists. + data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]] + name_list = ["evens", ["odds", "primes"]] + out = t.map_structure_up_to( + name_list, + lambda name, sec: "first_{}_{}".format(len(sec), name), + name_list, + data_list, + ) + self.assertEqual( + out, ["first_4_evens", ["first_5_odds", "first_3_primes"]] + ) + + def test_assert_same_structure(self, t): + # Non-nested. + t.assert_same_structure(1, 10) + + # Standard structures. + t.assert_same_structure((), ()) + t.assert_same_structure((1,), (10,)) + t.assert_same_structure((1, 2), (10, 20)) + t.assert_same_structure([], []) + t.assert_same_structure([1], [10]) + t.assert_same_structure([1, 2], [10, 20]) + t.assert_same_structure(deque([]), deque([])) + t.assert_same_structure(deque([1]), deque([1])) + t.assert_same_structure(deque([1, 2]), deque([10, 20])) + t.assert_same_structure(Empty(), Empty()) + t.assert_same_structure(Point(y=1, x=2), Point(x=10, y=20)) + t.assert_same_structure({}, {}) + t.assert_same_structure({"a": 1}, {"a": 10}) + t.assert_same_structure({"b": 2, "a": 1}, {"a": 10, "b": 20}) + t.assert_same_structure(OrderedDict(), OrderedDict()) + t.assert_same_structure( + OrderedDict([("a", 1)]), OrderedDict([("a", 10)]) + ) + t.assert_same_structure( + OrderedDict([("b", 1), ("a", 2)]), + OrderedDict([("b", 10), ("a", 20)]), + ) + t.assert_same_structure( + defaultdict(default_value), defaultdict(default_value) + ) + t.assert_same_paths( + defaultdict(default_value, [("a", 1)]), + defaultdict(default_value, [("a", 10)]), + ) + t.assert_same_paths( + defaultdict(default_value, [("b", 1), ("a", 2)]), + defaultdict(default_value, [("a", 10), ("b", 20)]), + ) + + # Keras tracking wrappers. + t.assert_same_structure( + TrackedList([]), + TrackedList([]), + ) + t.assert_same_structure( + TrackedList([1]), + TrackedList([10]), + ) + t.assert_same_structure( + TrackedList([1, 2]), + TrackedList([10, 20]), + ) + t.assert_same_structure( + TrackedSet([]), + TrackedSet([]), + ) + t.assert_same_structure( + TrackedSet([1]), + TrackedSet([10]), + ) + t.assert_same_structure( + TrackedSet([1, 2]), + TrackedSet([10, 20]), + ) + t.assert_same_structure( + TrackedDict({}), + TrackedDict({}), + ) + t.assert_same_structure( + TrackedDict({"a": 1}), + TrackedDict({"a": 10}), + ) + t.assert_same_structure( + TrackedDict({"b": 2, "a": 1}), + TrackedDict({"a": 10, "b": 20}), + ) + + # Deeper nested structures. + t.assert_same_structure( + ( + {"b": [2, 3], "a": (1,)}, + TrackedDict({"x": 4, "y": TrackedList([5, 6])}), + TrackedSet([7]), + Point(y=9, x=8), + np.array([10]), + ), + ( + {"b": [20, 30], "a": (10,)}, + TrackedDict({"x": 40, "y": TrackedList([50, 60])}), + TrackedSet([70]), + Point(y=90, x=80), + np.array([100]), + ), + ) + + # Error cases. + + # Non-nested vs. nested. + with self.assertRaisesRegex(ValueError, "don't have the same nested"): + t.assert_same_structure(1, ()) + with self.assertRaisesRegex(ValueError, "[Ee]xpected.*tuple"): + t.assert_same_structure((), 1) + with self.assertRaisesRegex(ValueError, "don't have the same nested"): + t.assert_same_structure(1, []) + with self.assertRaisesRegex(ValueError, "[Ee]xpected.*list"): + t.assert_same_structure([], 1) + with self.assertRaisesRegex(ValueError, "don't have the same nested"): + t.assert_same_structure(1, deque([])) + with self.assertRaisesRegex(ValueError, "[Ee]xpected.*deque"): + t.assert_same_structure(deque([]), 1) + with self.assertRaisesRegex(ValueError, "don't have the same nested"): + t.assert_same_structure(1, Empty()) + with self.assertRaisesRegex(ValueError, "[Ee]xpected.*(Empty|tuple)"): + t.assert_same_structure(Empty(), 1) + with self.assertRaisesRegex(ValueError, "don't have the same nested"): + t.assert_same_structure(1, Point(x=1, y=2)) + with self.assertRaisesRegex(ValueError, "[Ee]xpected.*(Point|tuple)"): + t.assert_same_structure(Point(x=1, y=2), 1) + with self.assertRaisesRegex(ValueError, "don't have the same nested"): + t.assert_same_structure(1, {}) + with self.assertRaisesRegex(ValueError, "[Ee]xpected.*dict"): + t.assert_same_structure({}, 1) + with self.assertRaisesRegex(ValueError, "don't have the same nested"): + t.assert_same_structure(1, OrderedDict()) + with self.assertRaisesRegex(ValueError, "[Ee]xpected.*OrderedDict"): + t.assert_same_structure(OrderedDict(), 1) + with self.assertRaisesRegex(ValueError, "don't have the same nested"): + t.assert_same_structure(1, defaultdict(default_value)) + with self.assertRaisesRegex(ValueError, "[Ee]xpected.*defaultdict"): + t.assert_same_structure(defaultdict(default_value), 1) + + # Non-nested vs. Keras tracking wrappers. + with self.assertRaisesRegex(ValueError, "(nested|TrackedList)"): + t.assert_same_structure(1, TrackedList([])) + with self.assertRaisesRegex(ValueError, "[Ee]xpected.*TrackedList"): + t.assert_same_structure(TrackedList([]), 1) + with self.assertRaisesRegex(ValueError, "(nested|TrackedSet)"): + t.assert_same_structure(1, TrackedSet([])) + with self.assertRaisesRegex(ValueError, "[Ee]xpected.*TrackedSet"): + t.assert_same_structure(TrackedSet([]), 1) + with self.assertRaisesRegex(ValueError, "(nested|TrackedDict)"): + t.assert_same_structure(1, TrackedDict([])) + with self.assertRaisesRegex(ValueError, "[Ee]xpected.*TrackedDict"): + t.assert_same_structure(TrackedDict([]), 1) + + # list, tuple, deque and namedtuple are not considered equivalent. + # Test all 6 combinations: + # tuple, list. + with self.assertRaisesRegex(ValueError, "[Ee]xpected.*tuple"): + t.assert_same_structure((), []) + # tuple, deque. + with self.assertRaisesRegex(ValueError, "[Ee]xpected.*tuple"): + t.assert_same_structure((), deque()) + # tuple, namedtuple. + with self.assertRaisesRegex(ValueError, "[Ee]xpected.*tuple"): + t.assert_same_structure((), Empty()) + # list, deque. + with self.assertRaisesRegex(ValueError, "list"): + t.assert_same_structure([], deque()) + # list, namedtuple. + with self.assertRaisesRegex(ValueError, "list"): + t.assert_same_structure([], Empty()) + # deque, namedtuple. + with self.assertRaisesRegex(ValueError, "deque"): + t.assert_same_structure(deque(), Empty()) + + # Equivalent namedtuples don't match. + with self.assertRaisesRegex(ValueError, "[Ee]xpected.*[. ]Point"): + t.assert_same_structure(Point(x=1, y=2), OtherPoint(x=10, y=20)) + + # Mismatched counts. + with self.assertRaisesRegex(ValueError, "[Aa]rity mismatch"): + t.assert_same_structure((1, 2), (1,)) + with self.assertRaisesRegex(ValueError, "[Aa]rity mismatch"): + t.assert_same_structure([1, 2], [1]) + with self.assertRaisesRegex(ValueError, "[Aa]rity mismatch"): + t.assert_same_structure(deque([1, 2]), deque([1])) + + # Mismatched counts with Keras tracking wrappers. + with self.assertRaisesRegex(ValueError, "[Aa]rity mismatch"): + t.assert_same_structure(TrackedList([1, 2]), TrackedList([1])) + with self.assertRaisesRegex(ValueError, "[Aa]rity mismatch"): + t.assert_same_structure(TrackedSet([1, 2]), TrackedSet([1])) + + # dict, OrderedDict, defaultdict are considered equivalent. + # Test all 6 combinations (3 type combinations plus the order). + # dict, OrderedDict. + t.assert_same_structure( + {"a": 1, "b": 2}, OrderedDict([("b", 20), ("a", 10)]) + ) + # OrderedDict, dict. + t.assert_same_structure( + OrderedDict([("b", 20), ("a", 10)]), {"a": 1, "b": 2} + ) + # dict, defaultdict. + t.assert_same_structure( + {"a": 1, "b": 2}, + defaultdict(default_value, [("b", 20), ("a", 10)]), + ) + # defaultdict, dict. + t.assert_same_structure( + defaultdict(default_value, [("b", 20), ("a", 10)]), + {"a": 1, "b": 2}, + ) + # defaultdict, OrderedDict. + t.assert_same_structure( + defaultdict(default_value, [("a", 1), ("b", 2)]), + OrderedDict([("b", 20), ("a", 10)]), + ) + # OrderedDict, defaultdict. + t.assert_same_structure( + OrderedDict([("b", 2), ("a", 1)]), + defaultdict(default_value, [("a", 10), ("b", 20)]), + ) + + # Two OrderedDicts with same keys but different orders. + t.assert_same_structure( + OrderedDict([("b", 2), ("a", 1)]), + OrderedDict([("a", 10), ("b", 20)]), + ) + + # Keras tracking wrappers are not equivalent to the raw structures. + with self.assertRaisesRegex(ValueError, "[Ee]xpected.*TrackedList"): + t.assert_same_structure(TrackedList([1, 2]), list([10, 20])) + with self.assertRaisesRegex(ValueError, "[Ee]xpected.*list"): + t.assert_same_structure(list([1, 2]), TrackedList([10, 20])) + with self.assertRaisesRegex(ValueError, "[Ee]xpected.*TrackedSet"): + t.assert_same_structure(TrackedSet([1, 2]), list([10, 20])) + with self.assertRaisesRegex(ValueError, "[Ee]xpected.*list"): + t.assert_same_structure(list([1, 2]), TrackedSet([10, 20])) + with self.assertRaisesRegex(ValueError, "[Ee]xpected.*TrackedDict"): + t.assert_same_structure( + TrackedDict({"b": 2, "a": 1}), {"a": 10, "b": 20} + ) + with self.assertRaisesRegex(ValueError, "[Ee]xpected.*dict"): + t.assert_same_structure( + {"b": 2, "a": 1}, TrackedDict({"a": 10, "b": 20}) + ) + + # Mismatched key count. + with self.assertRaisesRegex( + ValueError, "[Dd]ictionary key mismatch|Node arity mismatch" + ): + t.assert_same_structure( + {"a": 1, "b": 2}, + {"a": 1}, + ) + with self.assertRaisesRegex( + ValueError, "[Dd]ictionary key mismatch|Node arity mismatch" + ): + t.assert_same_structure( + defaultdict(default_value, [("a", 1), ("b", 2)]), + defaultdict(default_value, [("a", 10)]), + ) + with self.assertRaisesRegex( + ValueError, "[Dd]ictionary key mismatch|Node arity mismatch" + ): + t.assert_same_structure( + OrderedDict([("a", 1), ("b", 2)]), + OrderedDict([("a", 10)]), + ) + + # Mismatched keys. + with self.assertRaisesRegex( + ValueError, "[Dd]ictionary key mismatch|Node keys mismatch" + ): + t.assert_same_structure( + {"a": 1}, + {"b": 2}, + ) + with self.assertRaisesRegex( + ValueError, "[Dd]ictionary key mismatch|Node keys mismatch" + ): + t.assert_same_structure( + defaultdict(default_value, [("a", 1)]), + defaultdict(default_value, [("b", 2)]), + ) + with self.assertRaisesRegex( + ValueError, "[Dd]ictionary key mismatch|Node keys mismatch" + ): + t.assert_same_structure( + OrderedDict([("a", 1)]), + OrderedDict([("b", 2)]), + ) + + # Mismatched key count and keys with TrackedDict. + with self.assertRaisesRegex( + ValueError, "Mismatch custom node data|Node arity mismatch" + ): + t.assert_same_structure( + TrackedDict({"a": 1, "b": 2}), + TrackedDict({"a": 1}), + ) + with self.assertRaisesRegex( + ValueError, "Mismatch custom node data|Node context mismatch" + ): + t.assert_same_structure( + TrackedDict({"a": 1}), + TrackedDict({"b": 2}), + ) + + @pytest.mark.skipif(backend.backend() != "tensorflow", reason="tf only") + def test_assert_same_structure_tf_wrappers(self, t): + from tensorflow.python.trackable.data_structures import ListWrapper + from tensorflow.python.trackable.data_structures import _DictWrapper + + t.assert_same_structure(ListWrapper([]), ListWrapper([])) + t.assert_same_structure(ListWrapper([1]), ListWrapper([10])) + t.assert_same_structure(ListWrapper([1, 2]), ListWrapper([10, 20])) + t.assert_same_structure(_DictWrapper(), _DictWrapper()) + t.assert_same_structure(_DictWrapper({"a": 1}), _DictWrapper({"a": 11})) + t.assert_same_structure( + _DictWrapper({"b": 2, "a": 1}), _DictWrapper({"a": 11, "b": 12}) + ) + + # Count and key mismatch + with self.assertRaisesRegex(ValueError, "[Aa]rity mismatch"): + t.assert_same_structure(ListWrapper([1, 2]), ListWrapper([1])) + with self.assertRaisesRegex(ValueError, "Mismatch custom node data"): + t.assert_same_structure( + _DictWrapper({"a": 1, "b": 2}), + _DictWrapper({"a": 1}), + ) + with self.assertRaisesRegex(ValueError, "Mismatch custom node data"): + t.assert_same_structure( + _DictWrapper({"a": 1}), + _DictWrapper({"b": 2}), + ) + + # Tensorflow wrappers are not equivalent to the raw structures. + with self.assertRaisesRegex(ValueError, "[Ee]xpected.*ListWrapper"): + t.assert_same_structure(ListWrapper([1, 2]), list([10, 20])) + with self.assertRaisesRegex(ValueError, "[Ee]xpected.*list"): + t.assert_same_structure(list([1, 2]), ListWrapper([10, 20])) + with self.assertRaisesRegex(ValueError, "[Ee]xpected.*_DictWrapper"): + t.assert_same_structure( + _DictWrapper({"b": 2, "a": 1}), {"a": 10, "b": 20} + ) + with self.assertRaisesRegex(ValueError, "[Ee]xpected.*dict"): + t.assert_same_structure( + {"b": 2, "a": 1}, _DictWrapper({"a": 10, "b": 20}) + ) + + def test_assert_same_paths(self, t): + # Non-nested. + t.assert_same_paths(1, 10) + + # Standard structures. + t.assert_same_paths((), ()) + t.assert_same_paths((1,), (10,)) + t.assert_same_paths((1, 2), (10, 20)) + t.assert_same_paths([], []) + t.assert_same_paths([1], [10]) + t.assert_same_paths([1, 2], [10, 20]) + t.assert_same_paths(deque([]), deque([])) + t.assert_same_paths(deque([1]), deque([10])) + t.assert_same_paths(deque([1, 2]), deque([10, 20])) + t.assert_same_paths(Empty(), Empty()) + t.assert_same_paths(Point(y=2, x=1), Point(x=10, y=20)) + t.assert_same_paths({}, {}) + t.assert_same_paths({"a": 1}, {"a": 10}) + t.assert_same_paths({"b": None, "a": None}, {"a": 10, "b": 20}) + t.assert_same_paths(OrderedDict(), OrderedDict()) + t.assert_same_paths(OrderedDict([("a", 1)]), OrderedDict([("a", 10)])) + t.assert_same_paths( + OrderedDict([("b", 2), ("a", 1)]), + OrderedDict([("a", 10), ("b", 20)]), + ) + t.assert_same_paths( + defaultdict(default_value), defaultdict(default_value) + ) + t.assert_same_paths( + defaultdict(default_value, [("a", 1)]), + defaultdict(default_value, [("a", 10)]), + ) + t.assert_same_paths( + defaultdict(default_value, [("b", 2), ("a", 1)]), + defaultdict(default_value, [("a", 1), ("b", 2)]), + ) + + # Keras tracking wrappers. + t.assert_same_paths( + TrackedList([]), + TrackedList([]), + ) + t.assert_same_paths( + TrackedList([1]), + TrackedList([10]), + ) + t.assert_same_paths( + TrackedList([1, 2]), + TrackedList([10, 20]), + ) + t.assert_same_paths( + TrackedSet([]), + TrackedSet([]), + ) + t.assert_same_paths( + TrackedSet([1]), + TrackedSet([10]), + ) + t.assert_same_paths( + TrackedSet([1, 2]), + TrackedSet([10, 20]), + ) + t.assert_same_paths( + TrackedDict({}), + TrackedDict({}), + ) + t.assert_same_paths( + TrackedDict({"a": 1}), + TrackedDict({"a": 10}), + ) + t.assert_same_paths( + TrackedDict({"b": 2, "a": 1}), + TrackedDict({"a": 10, "b": 20}), + ) + + # Deeper nested structures. + t.assert_same_paths( + ( + {"b": [2, 3], "a": (1,)}, + TrackedDict({"x": 4, "y": TrackedList([5, 6])}), + TrackedSet([7]), + Point(y=9, x=8), + np.array([10]), + ), + ( + {"b": [20, 30], "a": (10,)}, + TrackedDict({"x": 40, "y": TrackedList([50, 60])}), + TrackedSet([70]), + Point(y=90, x=80), + np.array([100]), + ), + ) + + # list, tuple, deque and namedtuple have the same paths. + # Test all 6 combinations: + # tuple, list. + t.assert_same_paths((), []) + t.assert_same_paths([1, 2], (10, 20)) + # tuple, deque. + t.assert_same_paths((), deque()) + t.assert_same_paths(deque([1, 2]), (10, 20)) + # tuple, namedtuple. + t.assert_same_paths((), Empty()) + t.assert_same_paths(Point(x=1, y=2), (10, 20)) + # list, deque. + t.assert_same_paths([], deque()) + t.assert_same_paths(deque([1, 2]), [10, 20]) + # list, namedtuple. + t.assert_same_paths([], Empty()) + t.assert_same_paths(Point(x=None, y=20), [1, 2]) + # deque, namedtuple. + t.assert_same_paths(deque(), Empty()) + t.assert_same_paths(Point(x=None, y=20), deque([1, 2])) + + # Equivalent namedtuples. + t.assert_same_paths(Point(x=1, y=2), OtherPoint(x=None, y=20)) + + # Mismatched counts. + with self.assertRaisesRegex(ValueError, "don't have the same paths"): + t.assert_same_paths((1, 2), (1,)) + with self.assertRaisesRegex(ValueError, "don't have the same paths"): + t.assert_same_paths([1, 2], [1]) + with self.assertRaisesRegex(ValueError, "don't have the same paths"): + t.assert_same_paths(deque([1, 2]), deque([1])) + + # dict, OrderedDict, defaultdict are considered equivalent. Test all 6 + # combinations (3 type combinations plus the order). + # dict, OrderedDict. + t.assert_same_paths( + {"a": 1, "b": 2}, OrderedDict([("b", 20), ("a", 10)]) + ) + # OrderedDict, dict. + t.assert_same_paths( + OrderedDict([("b", 20), ("a", 10)]), {"a": 1, "b": 2} + ) + # dict, defaultdict. + t.assert_same_paths( + {"a": 1, "b": 2}, + defaultdict(default_value, [("b", 20), ("a", 10)]), + ) + # defaultdict, dict. + t.assert_same_paths( + defaultdict(default_value, [("b", 20), ("a", 10)]), + {"a": 1, "b": 2}, + ) + # defaultdict, OrderedDict. + t.assert_same_paths( + defaultdict(default_value, [("a", 1), ("b", 2)]), + OrderedDict([("b", 20), ("a", 10)]), + ) + # OrderedDict, defaultdict. + t.assert_same_paths( + OrderedDict([("b", 2), ("a", 1)]), + defaultdict(default_value, [("a", 10), ("b", 20)]), + ) + + # Two OrderedDicts with same keys but different orders. + t.assert_same_paths( + OrderedDict([("b", 2), ("a", 1)]), + OrderedDict([("a", 10), ("b", 20)]), + ) + + # Keras tracking wrappers are equivalent to the raw structures. + t.assert_same_paths(TrackedList([1, 2]), list([10, 20])) + t.assert_same_paths(list([1, 2]), TrackedList([10, 20])) + t.assert_same_paths(TrackedSet([1, 2]), list([10, 20])) + t.assert_same_paths(list([1, 2]), TrackedSet([10, 20])) + t.assert_same_paths(TrackedDict({"b": 2, "a": 1}), {"a": 10, "b": 20}) + t.assert_same_paths({"b": 2, "a": 1}, TrackedDict({"a": 10, "b": 20})) + + # Mismatched keys + with self.assertRaisesRegex(ValueError, "don't have the same paths"): + t.assert_same_paths({"a": 1, "b": 2}, {"a": 1}) + with self.assertRaisesRegex(ValueError, "don't have the same paths"): + t.assert_same_paths( + defaultdict(default_value, [("a", 1), ("b", 2)]), + defaultdict(default_value, [("a", 10)]), + ) + with self.assertRaisesRegex(ValueError, "don't have the same paths"): + t.assert_same_paths( + OrderedDict([("a", 1), ("b", 2)]), OrderedDict([("a", 10)]) + ) + + @pytest.mark.skipif(backend.backend() != "tensorflow", reason="tf only") + def test_assert_same_paths_tf_wrappers(self, t): + from tensorflow.python.trackable.data_structures import ListWrapper + from tensorflow.python.trackable.data_structures import _DictWrapper + + t.assert_same_paths(ListWrapper([]), ListWrapper([])) + t.assert_same_paths(ListWrapper([1]), ListWrapper([10])) + t.assert_same_paths(ListWrapper([1, 2]), ListWrapper([10, 20])) + t.assert_same_paths(_DictWrapper(), _DictWrapper()) + t.assert_same_paths(_DictWrapper({"a": 1}), _DictWrapper({"a": 11})) + t.assert_same_paths( + _DictWrapper({"b": 2, "a": 1}), _DictWrapper({"a": 11, "b": 12}) + ) + + # Tensorflow wrappers are equivalent to the raw structures. + t.assert_same_paths(ListWrapper([1, 2]), list([10, 20])) + t.assert_same_paths(list([1, 2]), ListWrapper([10, 20])) + t.assert_same_paths(_DictWrapper({"b": 2, "a": 1}), {"a": 10, "b": 20}) + t.assert_same_paths({"b": 2, "a": 1}, _DictWrapper({"a": 10, "b": 20})) + + def test_traverse_top_down(self, t): + v = Visitor(lambda x: None if t.is_nested(x) else x + 10) + + # Non-nested. + self.assertEqualStrict(t.traverse(v, 1), 11) + self.assertEqualStrict(v.visited(), [1]) + + # Standard structures. + self.assertEqualStrict(t.traverse(v, ()), ()) + self.assertEqualStrict(v.visited(), [()]) + + self.assertEqualStrict(t.traverse(v, (1,)), (11,)) + self.assertEqualStrict(v.visited(), [(1,), 1]) + + self.assertEqualStrict(t.traverse(v, (1, 2)), (11, 12)) + self.assertEqualStrict(v.visited(), [(1, 2), 1, 2]) + + self.assertEqualStrict(t.traverse(v, []), []) + self.assertEqualStrict(v.visited(), [[]]) + + self.assertEqualStrict(t.traverse(v, [1]), [11]) + self.assertEqualStrict(v.visited(), [[1], 1]) + + self.assertEqualStrict(t.traverse(v, [1, 2]), [11, 12]) + self.assertEqualStrict(v.visited(), [[1, 2], 1, 2]) + + self.assertEqualStrict(t.traverse(v, deque([])), deque([])) + self.assertEqualStrict(v.visited(), [deque([])]) + + self.assertEqualStrict(t.traverse(v, deque([1])), deque([11])) + self.assertEqualStrict(v.visited(), [deque([1]), 1]) + + self.assertEqualStrict(t.traverse(v, deque([1, 2])), deque([11, 12])) + self.assertEqualStrict(v.visited(), [deque([1, 2]), 1, 2]) + + self.assertEqualStrict(t.traverse(v, Empty()), Empty()) + self.assertEqualStrict(v.visited(), [Empty()]) + + self.assertEqualStrict( + t.traverse(v, Point(y=2, x=1)), Point(x=11, y=12) + ) + self.assertEqualStrict(v.visited(), [Point(x=1, y=2), 1, 2]) + + self.assertEqualStrict(t.traverse(v, {}), {}) + self.assertEqualStrict(v.visited(), [{}]) + + self.assertEqualStrict(t.traverse(v, {"a": 1}), {"a": 11}) + self.assertEqualStrict(v.visited(), [{"a": 1}, 1]) + + self.assertEqualStrict( + t.traverse(v, {"b": 2, "a": 1}), {"a": 11, "b": 12} + ) + self.assertEqualStrict(v.visited(), [{"a": 1, "b": 2}, 1, 2]) + + self.assertEqualStrict(t.traverse(v, OrderedDict()), OrderedDict()) + self.assertEqualStrict(v.visited(), [OrderedDict()]) + + self.assertEqualStrict( + t.traverse(v, OrderedDict([("a", 1)])), OrderedDict([("a", 11)]) + ) + self.assertEqualStrict(v.visited(), [OrderedDict([("a", 1)]), 1]) + + self.assertEqualStrict( + t.traverse(v, OrderedDict([("b", 2), ("a", 1)])), + OrderedDict([("b", 12), ("a", 11)]), + ) + self.assertEqualStrict( + v.visited(), [OrderedDict([("b", 2), ("a", 1)]), 2, 1] + ) + + self.assertEqualStrict( + t.traverse(v, defaultdict(default_value)), + defaultdict(default_value), + ) + self.assertEqualStrict(v.visited(), [defaultdict(default_value)]) + + self.assertEqualStrict( + t.traverse(v, defaultdict(default_value, [("a", 1)])), + defaultdict(default_value, [("a", 11)]), + ) + self.assertEqualStrict( + v.visited(), [defaultdict(default_value, [("a", 1)]), 1] + ) + + self.assertEqualStrict( + t.traverse(v, defaultdict(default_value, [("b", 2), ("a", 1)])), + defaultdict(default_value, [("a", 11), ("b", 12)]), + ) + self.assertEqualStrict( + v.visited(), + [defaultdict(default_value, [("a", 1), ("b", 2)]), 1, 2], + ) + + # Keras tracking wrappers. + self.assertEqualStrict(t.traverse(v, TrackedList([])), TrackedList([])) + self.assertEqualStrict(v.visited(), [TrackedList([])]) + + self.assertEqualStrict( + t.traverse(v, TrackedList([1])), TrackedList([11]) + ) + self.assertEqualStrict(v.visited(), [TrackedList([1]), 1]) + + self.assertEqualStrict( + t.traverse(v, TrackedList([1, 2])), TrackedList([11, 12]) + ) + self.assertEqualStrict(v.visited(), [TrackedList([1, 2]), 1, 2]) + + self.assertEqualStrict(t.traverse(v, TrackedSet([])), TrackedSet([])) + self.assertEqualStrict(v.visited(), [TrackedSet([])]) + + self.assertEqualStrict(t.traverse(v, TrackedSet([1])), TrackedSet([11])) + self.assertEqualStrict(v.visited(), [TrackedSet([1]), 1]) + + self.assertEqualStrict( + t.traverse(v, TrackedSet([1, 2])), TrackedSet([11, 12]) + ) + visited = v.visited() + self.assertEqualStrict(visited[0], TrackedSet([1, 2])) + self.assertEqualStrict(sorted(visited[1:]), [1, 2]) + + self.assertEqualStrict( + t.traverse(v, TrackedDict()), + TrackedDict(), + ) + self.assertEqualStrict(v.visited(), [TrackedDict()]) + + self.assertEqualStrict( + t.traverse(v, TrackedDict({"a": 1})), + TrackedDict({"a": 11}), + ) + self.assertEqualStrict(v.visited(), [TrackedDict({"a": 1}), 1]) + + self.assertEqualStrict( + t.traverse(v, TrackedDict({"b": 2, "a": 1})), + TrackedDict({"a": 11, "b": 12}), + ) + self.assertEqualStrict( + v.visited(), [TrackedDict({"a": 1, "b": 2}), 1, 2] + ) + + # Deeper nested structures. + self.assertEqualStrict( + t.traverse( + v, + ( + {"b": [2, 3], "a": (1,)}, + TrackedDict({"x": 4, "y": TrackedList([5, 6])}), + TrackedSet([7]), + Point(y=9, x=8), + np.array([10]), + ), + ), + ( + {"b": [12, 13], "a": (11,)}, + TrackedDict({"x": 14, "y": TrackedList([15, 16])}), + TrackedSet([17]), + Point(y=19, x=18), + np.array([20]), + ), + ) + self.assertEqualStrict( + v.visited(), + [ + ( + {"b": [2, 3], "a": (1,)}, + TrackedDict({"x": 4, "y": TrackedList([5, 6])}), + TrackedSet([7]), + Point(y=9, x=8), + np.array([10]), + ), + {"b": [2, 3], "a": (1,)}, + (1,), + 1, + [2, 3], + 2, + 3, + TrackedDict({"x": 4, "y": TrackedList([5, 6])}), + 4, + TrackedList([5, 6]), + 5, + 6, + TrackedSet([7]), + 7, + Point(x=8, y=9), + 8, + 9, + np.array([10]), + ], + ) + + # Error cases. + with self.assertRaisesRegex(TypeError, "callable"): + t.traverse("bad", [1, 2]) + + # Children are not explored if structure is replaced with a leaf. + v = Visitor(lambda x: "X" if isinstance(x, tuple) else None) + self.assertEqualStrict( + t.traverse(v, [(1, [2]), [3, (4, 5, 6)]]), + ["X", [3, "X"]], + ) + self.assertEqualStrict( + v.visited(), + [ + [(1, [2]), [3, (4, 5, 6)]], + (1, [2]), + [3, (4, 5, 6)], + 3, + (4, 5, 6), + ], + ) + + # Children are not explored if structure is replaced with structure. + v = Visitor(lambda x: ("a", "b") if isinstance(x, tuple) else None) + self.assertEqualStrict( + t.traverse(v, [(1, [2]), [3, (4, 5, 6)]]), + [("a", "b"), [3, ("a", "b")]], + ) + self.assertEqualStrict( + v.visited(), + [ + [(1, [2]), [3, (4, 5, 6)]], + (1, [2]), + [3, (4, 5, 6)], + 3, + (4, 5, 6), + ], + ) + + # MAP_TO_NONE. + v = Visitor(lambda x: MAP_TO_NONE if isinstance(x, tuple) else None) + self.assertEqualStrict( + t.traverse(v, [(1, [2]), [3, (4, 5, 6)]]), + [None, [3, None]], + ) + self.assertEqualStrict( + v.visited(), + [ + [(1, [2]), [3, (4, 5, 6)]], + (1, [2]), + [3, (4, 5, 6)], + 3, + (4, 5, 6), + ], + ) + + @pytest.mark.skipif(backend.backend() != "tensorflow", reason="tf only") + def test_traverse_top_down_tf_wrappers(self, t): + from tensorflow.python.trackable.data_structures import ListWrapper + from tensorflow.python.trackable.data_structures import _DictWrapper + + v = Visitor(lambda x: None if t.is_nested(x) else x + 10) + + self.assertEqualStrict(t.traverse(v, ListWrapper([])), ListWrapper([])) + self.assertEqualStrict(v.visited(), [ListWrapper([])]) + + self.assertEqualStrict( + t.traverse(v, ListWrapper([1])), ListWrapper([11]) + ) + self.assertEqualStrict(v.visited(), [ListWrapper([1]), 1]) + + self.assertEqualStrict( + t.traverse(v, ListWrapper([1, 2])), ListWrapper([11, 12]) + ) + self.assertEqualStrict(v.visited(), [ListWrapper([1, 2]), 1, 2]) + + self.assertEqualStrict( + t.traverse(v, _DictWrapper()), + _DictWrapper(), + ) + self.assertEqualStrict(v.visited(), [_DictWrapper()]) + + self.assertEqualStrict( + t.traverse(v, _DictWrapper({"a": 1})), + _DictWrapper({"a": 11}), + ) + self.assertEqualStrict(v.visited(), [_DictWrapper({"a": 1}), 1]) + + self.assertEqualStrict( + t.traverse(v, _DictWrapper({"b": 2, "a": 1})), + _DictWrapper({"a": 11, "b": 12}), + ) + self.assertEqualStrict( + v.visited(), [_DictWrapper({"a": 1, "b": 2}), 1, 2] + ) + + def test_traverse_bottom_up(self, t): + v = Visitor(lambda x: None if t.is_nested(x) else x + 10) + traverse_u = functools.partial(t.traverse, top_down=False) + + # Non-nested. + self.assertEqualStrict(traverse_u(v, 1), 11) + self.assertEqualStrict(v.visited(), [1]) + + # Standard structures. + self.assertEqualStrict(traverse_u(v, ()), ()) + self.assertEqualStrict(v.visited(), [()]) + + self.assertEqualStrict(traverse_u(v, (1,)), (11,)) + self.assertEqualStrict(v.visited(), [1, (11,)]) + + self.assertEqualStrict(traverse_u(v, (1, 2)), (11, 12)) + self.assertEqualStrict(v.visited(), [1, 2, (11, 12)]) + + self.assertEqualStrict(traverse_u(v, []), []) + self.assertEqualStrict(v.visited(), [[]]) + + self.assertEqualStrict(traverse_u(v, [1]), [11]) + self.assertEqualStrict(v.visited(), [1, [11]]) + + self.assertEqualStrict(traverse_u(v, [1, 2]), [11, 12]) + self.assertEqualStrict(v.visited(), [1, 2, [11, 12]]) + + self.assertEqualStrict(traverse_u(v, deque([])), deque([])) + self.assertEqualStrict(v.visited(), [deque([])]) + + self.assertEqualStrict(traverse_u(v, deque([1])), deque([11])) + self.assertEqualStrict(v.visited(), [1, deque([11])]) + + self.assertEqualStrict(traverse_u(v, deque([1, 2])), deque([11, 12])) + self.assertEqualStrict(v.visited(), [1, 2, deque([11, 12])]) + + self.assertEqualStrict(traverse_u(v, Empty()), Empty()) + self.assertEqualStrict(v.visited(), [Empty()]) + + self.assertEqualStrict( + traverse_u(v, Point(y=2, x=1)), Point(x=11, y=12) + ) + self.assertEqualStrict(v.visited(), [1, 2, Point(x=11, y=12)]) + + self.assertEqualStrict(traverse_u(v, {}), {}) + self.assertEqualStrict(v.visited(), [{}]) + + self.assertEqualStrict(traverse_u(v, {"a": 1}), {"a": 11}) + self.assertEqualStrict(v.visited(), [1, {"a": 11}]) + + self.assertEqualStrict( + traverse_u(v, {"b": 2, "a": 1}), {"a": 11, "b": 12} + ) + self.assertEqualStrict(v.visited(), [1, 2, {"a": 11, "b": 12}]) + + self.assertEqualStrict(traverse_u(v, OrderedDict()), OrderedDict()) + self.assertEqualStrict(v.visited(), [OrderedDict()]) + + self.assertEqualStrict( + traverse_u(v, OrderedDict([("a", 1)])), OrderedDict([("a", 11)]) + ) + self.assertEqualStrict(v.visited(), [1, OrderedDict([("a", 11)])]) + + self.assertEqualStrict( + traverse_u(v, OrderedDict([("b", 2), ("a", 1)])), + OrderedDict([("b", 12), ("a", 11)]), + ) + self.assertEqualStrict( + v.visited(), [2, 1, OrderedDict([("b", 12), ("a", 11)])] + ) + + self.assertEqualStrict( + traverse_u(v, defaultdict(default_value)), + defaultdict(default_value), + ) + self.assertEqualStrict(v.visited(), [defaultdict(default_value)]) + + self.assertEqualStrict( + traverse_u(v, defaultdict(default_value, [("a", 1)])), + defaultdict(default_value, [("a", 11)]), + ) + self.assertEqualStrict( + v.visited(), [1, defaultdict(default_value, [("a", 11)])] + ) + + self.assertEqualStrict( + traverse_u(v, defaultdict(default_value, [("b", 2), ("a", 1)])), + defaultdict(default_value, [("a", 11), ("b", 12)]), + ) + self.assertEqualStrict( + v.visited(), + [1, 2, defaultdict(default_value, [("a", 11), ("b", 12)])], + ) + + # Keras tracking wrappers. + self.assertEqualStrict(traverse_u(v, TrackedList([])), TrackedList([])) + self.assertEqualStrict(v.visited(), [TrackedList([])]) + + self.assertEqualStrict( + traverse_u(v, TrackedList([1])), TrackedList([11]) + ) + self.assertEqualStrict(v.visited(), [1, TrackedList([11])]) + + self.assertEqualStrict( + traverse_u(v, TrackedList([1, 2])), TrackedList([11, 12]) + ) + self.assertEqualStrict(v.visited(), [1, 2, TrackedList([11, 12])]) + + self.assertEqualStrict(traverse_u(v, TrackedSet([])), TrackedSet([])) + self.assertEqualStrict(v.visited(), [TrackedSet([])]) + + self.assertEqualStrict(traverse_u(v, TrackedSet([1])), TrackedSet([11])) + self.assertEqualStrict(v.visited(), [1, TrackedSet([11])]) + + self.assertEqualStrict( + traverse_u(v, TrackedSet([1, 2])), TrackedSet([11, 12]) + ) + visited = v.visited() + self.assertEqualStrict(visited[-1], TrackedSet([11, 12])) + self.assertEqualStrict(sorted(visited[:-1]), [1, 2]) + + self.assertEqualStrict( + traverse_u(v, TrackedDict()), + TrackedDict(), + ) + self.assertEqualStrict(v.visited(), [TrackedDict()]) + + self.assertEqualStrict( + traverse_u(v, TrackedDict({"a": 1})), + TrackedDict({"a": 11}), + ) + self.assertEqualStrict(v.visited(), [1, TrackedDict({"a": 11})]) + + self.assertEqualStrict( + traverse_u(v, TrackedDict({"b": 2, "a": 1})), + TrackedDict({"a": 11, "b": 12}), + ) + self.assertEqualStrict( + v.visited(), [1, 2, TrackedDict({"a": 11, "b": 12})] + ) + + # Deeper nested structures. + self.assertEqualStrict( + traverse_u( + v, + ( + {"b": [2, 3], "a": (1,)}, + TrackedDict({"x": 4, "y": TrackedList([5, 6])}), + TrackedSet([7]), + Point(y=9, x=8), + np.array([10]), + ), + ), + ( + {"b": [12, 13], "a": (11,)}, + TrackedDict({"x": 14, "y": TrackedList([15, 16])}), + TrackedSet([17]), + Point(y=19, x=18), + np.array([20]), + ), + ) + self.assertEqualStrict( + v.visited(), + [ + 1, + (11,), + 2, + 3, + [12, 13], + {"b": [12, 13], "a": (11,)}, + 4, + 5, + 6, + TrackedList([15, 16]), + TrackedDict({"x": 14, "y": TrackedList([15, 16])}), + 7, + TrackedSet([17]), + 8, + 9, + Point(x=18, y=19), + np.array([10]), + ( + {"b": [12, 13], "a": (11,)}, + TrackedDict({"x": 14, "y": TrackedList([15, 16])}), + TrackedSet([17]), + Point(y=19, x=18), + np.array([20]), + ), + ], + ) + + # Error cases. + with self.assertRaisesRegex(TypeError, "callable"): + traverse_u("bad", [1, 2]) + + # Children are not explored if structure is replaced with a leaf. + v = Visitor(lambda x: "X" if isinstance(x, tuple) else None) + self.assertEqualStrict( + traverse_u(v, [(1, [2]), [3, (4, 5, 6)]]), + ["X", [3, "X"]], + ) + self.assertEqualStrict( + v.visited(), + [ + 1, + 2, + [2], + (1, [2]), + 3, + 4, + 5, + 6, + (4, 5, 6), + [3, "X"], + ["X", [3, "X"]], + ], + ) + + # Children are not explored if structure is replaced with structure. + v = Visitor(lambda x: ("a", "b") if isinstance(x, tuple) else None) + self.assertEqualStrict( + traverse_u(v, [(1, [2]), [3, (4, 5, 6)]]), + [("a", "b"), [3, ("a", "b")]], + ) + self.assertEqualStrict( + v.visited(), + [ + 1, + 2, + [2], + (1, [2]), + 3, + 4, + 5, + 6, + (4, 5, 6), + [3, ("a", "b")], + [("a", "b"), [3, ("a", "b")]], + ], + ) + + # MAP_TO_NONE. + v = Visitor(lambda x: MAP_TO_NONE if isinstance(x, tuple) else None) + self.assertEqualStrict( + traverse_u(v, [(1, [2]), [3, (4, 5, 6)]]), + [None, [3, None]], + ) + self.assertEqualStrict( + v.visited(), + [ + 1, + 2, + [2], + (1, [2]), + 3, + 4, + 5, + 6, + (4, 5, 6), + [3, None], + [None, [3, None]], + ], + ) + + @pytest.mark.skipif(backend.backend() != "tensorflow", reason="tf only") + def test_traverse_bottom_up_tf_wrappers(self, t): + from tensorflow.python.trackable.data_structures import ListWrapper + from tensorflow.python.trackable.data_structures import _DictWrapper + + v = Visitor(lambda x: None if t.is_nested(x) else x + 10) + traverse_u = functools.partial(t.traverse, top_down=False) + + self.assertEqualStrict(traverse_u(v, ListWrapper([])), ListWrapper([])) + self.assertEqualStrict(v.visited(), [ListWrapper([])]) + + self.assertEqualStrict( + traverse_u(v, ListWrapper([1])), ListWrapper([11]) + ) + self.assertEqualStrict(v.visited(), [1, ListWrapper([11])]) + + self.assertEqualStrict( + traverse_u(v, ListWrapper([1, 2])), ListWrapper([11, 12]) + ) + self.assertEqualStrict(v.visited(), [1, 2, ListWrapper([11, 12])]) + + self.assertEqualStrict( + traverse_u(v, _DictWrapper()), + _DictWrapper(), + ) + self.assertEqualStrict(v.visited(), [_DictWrapper()]) + + self.assertEqualStrict( + traverse_u(v, _DictWrapper({"a": 1})), + _DictWrapper({"a": 11}), + ) + self.assertEqualStrict(v.visited(), [1, _DictWrapper({"a": 11})]) + + self.assertEqualStrict( + traverse_u(v, _DictWrapper({"b": 2, "a": 1})), + _DictWrapper({"a": 11, "b": 12}), + ) + self.assertEqualStrict( + v.visited(), [1, 2, _DictWrapper({"a": 11, "b": 12})] + ) + + def test_lists_to_tuples(self, t): + self.assertEqualStrict( + t.lists_to_tuples([1, 2, 3]), + (1, 2, 3), + ) + self.assertEqualStrict( + t.lists_to_tuples([[1], [2, 3]]), + ((1,), (2, 3)), + ) + + # Deeper nested structures. + self.assertEqualStrict( + t.lists_to_tuples( + ( + {"b": [2, 3], "a": (1,)}, + TrackedDict({"x": 4, "y": TrackedList([5, 6])}), + TrackedSet([(7, 8, 9)]), + ), + ), + ( + {"b": (2, 3), "a": (1,)}, + TrackedDict({"x": 4, "y": (5, 6)}), + TrackedSet([(7, 8, 9)]), + ), + ) + + def test_map_shape_structure(self, t): + v = Visitor( + lambda x: tuple(x) + (10,) if isinstance(x, (tuple, list)) else None + ) + + self.assertEqualStrict( + t.map_shape_structure(v, (1, 2, 3)), + (1, 2, 3, 10), + ) + self.assertEqualStrict( + v.visited(), + [ + (1, 2, 3), + ], + ) + + self.assertEqualStrict( + t.map_shape_structure(v, {"a": [1, 2, None], "b": (5,), "c": "hi"}), + {"a": (1, 2, None, 10), "b": (5, 10), "c": None}, + ) + self.assertEqualStrict( + v.visited(), + [ + [1, 2, None], + (5,), + "hi", + ], + ) + + # Deeper nested structures. + self.assertEqualStrict( + t.map_shape_structure( + v, + ( + {"b": [2, 3], "a": (None,)}, + TrackedDict({"x": 4, "y": TrackedList([5, 6])}), + TrackedSet([(7, None, 9)]), + ), + ), + ( + {"b": (2, 3, 10), "a": (None, 10)}, + TrackedDict({"x": None, "y": (5, 6, 10)}), + TrackedSet([(7, None, 9, 10)]), + ), + ) + self.assertEqualStrict( + v.visited(), + [ + (None,), + [2, 3], + 4, + TrackedList([5, 6]), + (7, None, 9), + ], + ) diff --git a/keras/src/utils/__init__.py b/keras/src/utils/__init__.py new file mode 100644 index 000000000000..c503a2043776 --- /dev/null +++ b/keras/src/utils/__init__.py @@ -0,0 +1,26 @@ +from keras.src.utils.audio_dataset_utils import audio_dataset_from_directory +from keras.src.utils.dataset_utils import split_dataset +from keras.src.utils.file_utils import get_file +from keras.src.utils.image_dataset_utils import image_dataset_from_directory +from keras.src.utils.image_utils import array_to_img +from keras.src.utils.image_utils import img_to_array +from keras.src.utils.image_utils import load_img +from keras.src.utils.image_utils import save_img +from keras.src.utils.io_utils import disable_interactive_logging +from keras.src.utils.io_utils import enable_interactive_logging +from keras.src.utils.io_utils import is_interactive_logging_enabled +from keras.src.utils.model_visualization import model_to_dot +from keras.src.utils.model_visualization import plot_model +from keras.src.utils.numerical_utils import normalize +from keras.src.utils.numerical_utils import to_categorical +from keras.src.utils.progbar import Progbar +from keras.src.utils.python_utils import default +from keras.src.utils.python_utils import is_default +from keras.src.utils.python_utils import removeprefix +from keras.src.utils.python_utils import removesuffix +from keras.src.utils.rng_utils import set_random_seed +from keras.src.utils.sequence_utils import pad_sequences +from keras.src.utils.text_dataset_utils import text_dataset_from_directory +from keras.src.utils.timeseries_dataset_utils import ( + timeseries_dataset_from_array, +) diff --git a/keras/src/utils/argument_validation.py b/keras/src/utils/argument_validation.py new file mode 100644 index 000000000000..8f772b11e5c0 --- /dev/null +++ b/keras/src/utils/argument_validation.py @@ -0,0 +1,92 @@ +def standardize_tuple(value, n, name, allow_zero=False): + """Transforms non-negative/positive integer/integers into an integer tuple. + + Args: + value: int or iterable of ints. The value to validate and convert. + n: int. The size of the tuple to be returned. + name: string. The name of the argument being validated, e.g. "strides" + or "kernel_size". This is only used to format error messages. + allow_zero: bool, defaults to `False`. A `ValueError` will raised + if zero is received and this argument is `False`. + + Returns: + A tuple of n integers. + """ + error_msg = ( + f"The `{name}` argument must be a tuple of {n} integers. " + f"Received {name}={value}" + ) + + if isinstance(value, int): + value_tuple = (value,) * n + else: + try: + value_tuple = tuple(value) + except TypeError: + raise ValueError(error_msg) + if len(value_tuple) != n: + raise ValueError(error_msg) + for single_value in value_tuple: + try: + int(single_value) + except (ValueError, TypeError): + error_msg += ( + f"including element {single_value} of " + f"type {type(single_value)}" + ) + raise ValueError(error_msg) + + if allow_zero: + unqualified_values = {v for v in value_tuple if v < 0} + req_msg = ">= 0" + else: + unqualified_values = {v for v in value_tuple if v <= 0} + req_msg = "> 0" + + if unqualified_values: + error_msg += ( + f", including values {unqualified_values}" + f" that do not satisfy `value {req_msg}`" + ) + raise ValueError(error_msg) + + return value_tuple + + +def standardize_padding(value, allow_causal=False): + if isinstance(value, (list, tuple)): + return value + padding = value.lower() + if allow_causal: + allowed_values = {"valid", "same", "causal"} + else: + allowed_values = {"valid", "same"} + if padding not in allowed_values: + raise ValueError( + "The `padding` argument must be a list/tuple or one of " + f"{allowed_values}. " + f"Received: {padding}" + ) + return padding + + +def validate_string_arg( + value, + allowable_strings, + caller_name, + arg_name, + allow_none=False, + allow_callables=False, +): + """Validates the correctness of a string-based arg.""" + if allow_none and value is None: + return + elif allow_callables and callable(value): + return + elif isinstance(value, str) and value in allowable_strings: + return + raise ValueError( + f"Unknown value for `{arg_name}` argument of {caller_name}. " + f"Allowed values are: {allowable_strings}. Received: " + f"{arg_name}={value}" + ) diff --git a/keras/src/utils/audio_dataset_utils.py b/keras/src/utils/audio_dataset_utils.py new file mode 100644 index 000000000000..ad2fb4e7f565 --- /dev/null +++ b/keras/src/utils/audio_dataset_utils.py @@ -0,0 +1,453 @@ +import numpy as np + +from keras.src.api_export import keras_export +from keras.src.utils import dataset_utils +from keras.src.utils.module_utils import tensorflow as tf +from keras.src.utils.module_utils import tensorflow_io as tfio + +ALLOWED_FORMATS = (".wav",) + + +@keras_export("keras.utils.audio_dataset_from_directory") +def audio_dataset_from_directory( + directory, + labels="inferred", + label_mode="int", + class_names=None, + batch_size=32, + sampling_rate=None, + output_sequence_length=None, + ragged=False, + shuffle=True, + seed=None, + validation_split=None, + subset=None, + follow_links=False, + verbose=True, +): + """Generates a `tf.data.Dataset` from audio files in a directory. + + If your directory structure is: + + ``` + main_directory/ + ...class_a/ + ......a_audio_1.wav + ......a_audio_2.wav + ...class_b/ + ......b_audio_1.wav + ......b_audio_2.wav + ``` + + Then calling `audio_dataset_from_directory(main_directory, + labels='inferred')` + will return a `tf.data.Dataset` that yields batches of audio files from + the subdirectories `class_a` and `class_b`, together with labels + 0 and 1 (0 corresponding to `class_a` and 1 corresponding to `class_b`). + + Only `.wav` files are supported at this time. + + Args: + directory: Directory where the data is located. + If `labels` is `"inferred"`, it should contain subdirectories, + each containing audio files for a class. Otherwise, the directory + structure is ignored. + labels: Either "inferred" (labels are generated from the directory + structure), `None` (no labels), or a list/tuple of integer labels + of the same size as the number of audio files found in + the directory. Labels should be sorted according to the + alphanumeric order of the audio file paths + (obtained via `os.walk(directory)` in Python). + label_mode: String describing the encoding of `labels`. Options are: + - `"int"`: means that the labels are encoded as integers (e.g. for + `sparse_categorical_crossentropy` loss). + - `"categorical"` means that the labels are encoded as a categorical + vector (e.g. for `categorical_crossentropy` loss) + - `"binary"` means that the labels (there can be only 2) + are encoded as `float32` scalars with values 0 + or 1 (e.g. for `binary_crossentropy`). + - `None` (no labels). + class_names: Only valid if "labels" is `"inferred"`. + This is the explicit list of class names + (must match names of subdirectories). Used to control the order + of the classes (otherwise alphanumerical order is used). + batch_size: Size of the batches of data. Default: 32. If `None`, + the data will not be batched + (the dataset will yield individual samples). + sampling_rate: Audio sampling rate (in samples per second). + output_sequence_length: Maximum length of an audio sequence. Audio files + longer than this will be truncated to `output_sequence_length`. + If set to `None`, then all sequences in the same batch will + be padded to the + length of the longest sequence in the batch. + ragged: Whether to return a Ragged dataset (where each sequence has its + own length). Defaults to `False`. + shuffle: Whether to shuffle the data. + If set to `False`, sorts the data in alphanumeric order. + Defaults to `True`. + seed: Optional random seed for shuffling and transformations. + validation_split: Optional float between 0 and 1, fraction of data to + reserve for validation. + subset: Subset of the data to return. One of `"training"`, + `"validation"` or `"both"`. Only used if `validation_split` is set. + follow_links: Whether to visits subdirectories pointed to by symlinks. + Defaults to `False`. + verbose: Whether to display number information on classes and + number of files found. Defaults to `True`. + + Returns: + + A `tf.data.Dataset` object. + + - If `label_mode` is `None`, it yields `string` tensors of shape + `(batch_size,)`, containing the contents of a batch of audio files. + - Otherwise, it yields a tuple `(audio, labels)`, where `audio` + has shape `(batch_size, sequence_length, num_channels)` and `labels` + follows the format described + below. + + Rules regarding labels format: + + - if `label_mode` is `int`, the labels are an `int32` tensor of shape + `(batch_size,)`. + - if `label_mode` is `binary`, the labels are a `float32` tensor of + 1s and 0s of shape `(batch_size, 1)`. + - if `label_mode` is `categorical`, the labels are a `float32` tensor + of shape `(batch_size, num_classes)`, representing a one-hot + encoding of the class index. + """ + if labels not in ("inferred", None): + if not isinstance(labels, (list, tuple)): + raise ValueError( + "The `labels` argument should be a list/tuple of integer " + "labels, of the same size as the number of audio files in " + "the target directory. If you wish to infer the labels from " + "the subdirectory names in the target directory," + ' pass `labels="inferred"`. ' + "If you wish to get a dataset that only contains audio samples " + f"(no labels), pass `labels=None`. Received: labels={labels}" + ) + if class_names: + raise ValueError( + "You can only pass `class_names` if " + f'`labels="inferred"`. Received: labels={labels}, and ' + f"class_names={class_names}" + ) + if label_mode not in {"int", "categorical", "binary", None}: + raise ValueError( + '`label_mode` argument must be one of "int", "categorical", ' + '"binary", ' + f"or None. Received: label_mode={label_mode}" + ) + + if ragged and output_sequence_length is not None: + raise ValueError( + "Cannot set both `ragged` and `output_sequence_length`" + ) + + if sampling_rate is not None: + if not isinstance(sampling_rate, int): + raise ValueError( + "`sampling_rate` should have an integer value. " + f"Received: sampling_rate={sampling_rate}" + ) + + if sampling_rate <= 0: + raise ValueError( + "`sampling_rate` should be higher than 0. " + f"Received: sampling_rate={sampling_rate}" + ) + + if not tfio.available: + raise ImportError( + "To use the argument `sampling_rate`, you should install " + "tensorflow_io. You can install it via `pip install " + "tensorflow-io`." + ) + + if labels is None or label_mode is None: + labels = None + label_mode = None + + dataset_utils.check_validation_split_arg( + validation_split, subset, shuffle, seed + ) + + if seed is None: + seed = np.random.randint(1e6) + if batch_size is not None: + shuffle_buffer_size = batch_size * 8 + else: + shuffle_buffer_size = 1024 + + file_paths, labels, class_names = dataset_utils.index_directory( + directory, + labels, + formats=ALLOWED_FORMATS, + class_names=class_names, + shuffle=shuffle, + seed=seed, + follow_links=follow_links, + verbose=verbose, + ) + + if label_mode == "binary" and len(class_names) != 2: + raise ValueError( + 'When passing `label_mode="binary"`, there must be exactly 2 ' + f"class_names. Received: class_names={class_names}" + ) + + if subset == "both": + train_dataset, val_dataset = get_training_and_validation_dataset( + file_paths=file_paths, + labels=labels, + validation_split=validation_split, + directory=directory, + label_mode=label_mode, + class_names=class_names, + sampling_rate=sampling_rate, + output_sequence_length=output_sequence_length, + ragged=ragged, + shuffle=shuffle, + shuffle_buffer_size=shuffle_buffer_size, + seed=seed, + ) + train_dataset = prepare_dataset( + dataset=train_dataset, + batch_size=batch_size, + class_names=class_names, + output_sequence_length=output_sequence_length, + ragged=ragged, + ) + val_dataset = prepare_dataset( + dataset=val_dataset, + batch_size=batch_size, + class_names=class_names, + output_sequence_length=output_sequence_length, + ragged=ragged, + ) + return train_dataset, val_dataset + + else: + dataset = get_dataset( + file_paths=file_paths, + labels=labels, + directory=directory, + validation_split=validation_split, + subset=subset, + label_mode=label_mode, + class_names=class_names, + sampling_rate=sampling_rate, + output_sequence_length=output_sequence_length, + ragged=ragged, + shuffle=shuffle, + shuffle_buffer_size=shuffle_buffer_size, + seed=seed, + ) + dataset = prepare_dataset( + dataset=dataset, + batch_size=batch_size, + class_names=class_names, + output_sequence_length=output_sequence_length, + ragged=ragged, + ) + return dataset + + +def prepare_dataset( + dataset, + batch_size, + class_names, + output_sequence_length, + ragged, +): + dataset = dataset.prefetch(tf.data.AUTOTUNE) + if batch_size is not None: + if output_sequence_length is None and not ragged: + dataset = dataset.padded_batch( + batch_size, padded_shapes=([None, None], []) + ) + else: + dataset = dataset.batch(batch_size) + + # Users may need to reference `class_names`. + dataset.class_names = class_names + return dataset + + +def get_training_and_validation_dataset( + file_paths, + labels, + validation_split, + directory, + label_mode, + class_names, + sampling_rate, + output_sequence_length, + ragged, + shuffle=False, + shuffle_buffer_size=None, + seed=None, +): + ( + file_paths_train, + labels_train, + ) = dataset_utils.get_training_or_validation_split( + file_paths, labels, validation_split, "training" + ) + if not file_paths_train: + raise ValueError( + f"No training audio files found in directory {directory}. " + f"Allowed format(s): {ALLOWED_FORMATS}" + ) + + file_paths_val, labels_val = dataset_utils.get_training_or_validation_split( + file_paths, labels, validation_split, "validation" + ) + if not file_paths_val: + raise ValueError( + f"No validation audio files found in directory {directory}. " + f"Allowed format(s): {ALLOWED_FORMATS}" + ) + + train_dataset = paths_and_labels_to_dataset( + file_paths=file_paths_train, + labels=labels_train, + label_mode=label_mode, + num_classes=len(class_names) if class_names else 0, + sampling_rate=sampling_rate, + output_sequence_length=output_sequence_length, + ragged=ragged, + shuffle=shuffle, + shuffle_buffer_size=shuffle_buffer_size, + seed=seed, + ) + + val_dataset = paths_and_labels_to_dataset( + file_paths=file_paths_val, + labels=labels_val, + label_mode=label_mode, + num_classes=len(class_names) if class_names else 0, + sampling_rate=sampling_rate, + output_sequence_length=output_sequence_length, + ragged=ragged, + shuffle=False, + ) + + return train_dataset, val_dataset + + +def get_dataset( + file_paths, + labels, + directory, + validation_split, + subset, + label_mode, + class_names, + sampling_rate, + output_sequence_length, + ragged, + shuffle=False, + shuffle_buffer_size=None, + seed=None, +): + file_paths, labels = dataset_utils.get_training_or_validation_split( + file_paths, labels, validation_split, subset + ) + if not file_paths: + raise ValueError( + f"No audio files found in directory {directory}. " + f"Allowed format(s): {ALLOWED_FORMATS}" + ) + + return paths_and_labels_to_dataset( + file_paths=file_paths, + labels=labels, + label_mode=label_mode, + num_classes=len(class_names) if class_names else 0, + sampling_rate=sampling_rate, + output_sequence_length=output_sequence_length, + ragged=ragged, + shuffle=shuffle, + shuffle_buffer_size=shuffle_buffer_size, + seed=seed, + ) + + +def read_and_decode_audio( + path, sampling_rate=None, output_sequence_length=None +): + """Reads and decodes audio file.""" + audio = tf.io.read_file(path) + + if output_sequence_length is None: + output_sequence_length = -1 + + audio, default_audio_rate = tf.audio.decode_wav( + contents=audio, desired_samples=output_sequence_length + ) + if sampling_rate is not None: + # default_audio_rate should have dtype=int64 + default_audio_rate = tf.cast(default_audio_rate, tf.int64) + audio = tfio.audio.resample( + input=audio, rate_in=default_audio_rate, rate_out=sampling_rate + ) + return audio + + +def paths_and_labels_to_dataset( + file_paths, + labels, + label_mode, + num_classes, + sampling_rate, + output_sequence_length, + ragged, + shuffle=False, + shuffle_buffer_size=None, + seed=None, +): + """Constructs a fixed-size dataset of audio and labels.""" + path_ds = tf.data.Dataset.from_tensor_slices(file_paths) + if label_mode: + label_ds = dataset_utils.labels_to_dataset_tf( + labels, label_mode, num_classes + ) + ds = tf.data.Dataset.zip((path_ds, label_ds)) + else: + ds = path_ds + + if shuffle: + ds = ds.shuffle(buffer_size=shuffle_buffer_size or 1024, seed=seed) + + if label_mode: + ds = ds.map( + lambda x, y: ( + read_and_decode_audio(x, sampling_rate, output_sequence_length), + y, + ), + num_parallel_calls=tf.data.AUTOTUNE, + ) + + if ragged: + ds = ds.map( + lambda x, y: (tf.RaggedTensor.from_tensor(x), y), + num_parallel_calls=tf.data.AUTOTUNE, + ) + + else: + ds = ds.map( + lambda x: read_and_decode_audio( + x, sampling_rate, output_sequence_length + ), + num_parallel_calls=tf.data.AUTOTUNE, + ) + + if ragged: + ds = ds.map( + lambda x: tf.RaggedTensor.from_tensor(x), + num_parallel_calls=tf.data.AUTOTUNE, + ) + + return ds diff --git a/keras/src/utils/audio_dataset_utils_test.py b/keras/src/utils/audio_dataset_utils_test.py new file mode 100644 index 000000000000..a6ad7b00b815 --- /dev/null +++ b/keras/src/utils/audio_dataset_utils_test.py @@ -0,0 +1,433 @@ +import os + +import numpy as np + +from keras.src import testing +from keras.src.utils import audio_dataset_utils +from keras.src.utils.module_utils import tensorflow as tf + + +class AudioDatasetFromDirectoryTest(testing.TestCase): + def _get_audio_samples(self, count=16, different_sequence_lengths=False): + sequence_length = 30 + num_channels = 1 + audio_samples = [] + for _ in range(count): + if different_sequence_lengths: + random_sequence_length = np.random.randint( + 10, sequence_length + 1 + ) + audio = np.random.random((random_sequence_length, num_channels)) + else: + audio = np.random.random((sequence_length, num_channels)) + audio_samples.append(tf.audio.encode_wav(audio, 1000)) + return audio_samples + + def _prepare_directory( + self, + num_classes=2, + nested_dirs=False, + count=16, + different_sequence_lengths=False, + ): + # Get a unique temp directory + temp_dir = self.get_temp_dir() + + # Generate paths to class subdirectories + paths = [] + for class_index in range(num_classes): + class_directory = f"class_{class_index}" + if nested_dirs: + class_paths = [ + class_directory, + os.path.join(class_directory, "subfolder_1"), + os.path.join(class_directory, "subfolder_2"), + os.path.join( + class_directory, "subfolder_1", "sub-subfolder" + ), + ] + else: + class_paths = [class_directory] + for path in class_paths: + os.mkdir(os.path.join(temp_dir, path)) + paths += class_paths + + # Save audio samples to the paths + i = 0 + for audio in self._get_audio_samples( + count=count, different_sequence_lengths=different_sequence_lengths + ): + path = paths[i % len(paths)] + ext = "wav" + filename = os.path.join(path, f"audio_{i}.{ext}") + with open(os.path.join(temp_dir, filename), "wb") as f: + f.write(audio.numpy()) + i += 1 + return temp_dir + + def test_audio_dataset_from_directory_standalone(self): + # Test retrieving audio samples without labels from a directory and its + # subdirs. + # Save a few extra audio in the parent directory. + directory = self._prepare_directory(count=7, num_classes=2) + for i, audio in enumerate(self._get_audio_samples(3)): + filename = f"audio_{i}.wav" + with open(os.path.join(directory, filename), "wb") as f: + f.write(audio.numpy()) + + dataset = audio_dataset_utils.audio_dataset_from_directory( + directory, batch_size=5, output_sequence_length=30, labels=None + ) + batch = next(iter(dataset)) + # We return plain audio + self.assertEqual(batch.shape, (5, 30, 1)) + self.assertEqual(batch.dtype.name, "float32") + # Count samples + batch_count = 0 + sample_count = 0 + for batch in dataset: + batch_count += 1 + sample_count += batch.shape[0] + self.assertEqual(batch_count, 2) + self.assertEqual(sample_count, 10) + + def test_audio_dataset_from_directory_binary(self): + directory = self._prepare_directory(num_classes=2) + dataset = audio_dataset_utils.audio_dataset_from_directory( + directory, batch_size=8, output_sequence_length=30, label_mode="int" + ) + batch = next(iter(dataset)) + self.assertLen(batch, 2) + self.assertEqual(batch[0].shape, (8, 30, 1)) + self.assertEqual(batch[0].dtype.name, "float32") + self.assertEqual(batch[1].shape, (8,)) + self.assertEqual(batch[1].dtype.name, "int32") + + dataset = audio_dataset_utils.audio_dataset_from_directory( + directory, + batch_size=8, + output_sequence_length=30, + label_mode="binary", + ) + batch = next(iter(dataset)) + self.assertLen(batch, 2) + self.assertEqual(batch[0].shape, (8, 30, 1)) + self.assertEqual(batch[0].dtype.name, "float32") + self.assertEqual(batch[1].shape, (8, 1)) + self.assertEqual(batch[1].dtype.name, "float32") + + dataset = audio_dataset_utils.audio_dataset_from_directory( + directory, + batch_size=8, + output_sequence_length=30, + label_mode="categorical", + ) + batch = next(iter(dataset)) + self.assertLen(batch, 2) + self.assertEqual(batch[0].shape, (8, 30, 1)) + self.assertEqual(batch[0].dtype.name, "float32") + self.assertEqual(batch[1].shape, (8, 2)) + self.assertEqual(batch[1].dtype.name, "float32") + + def test_static_shape_in_graph(self): + directory = self._prepare_directory(num_classes=2) + dataset = audio_dataset_utils.audio_dataset_from_directory( + directory, batch_size=8, output_sequence_length=30, label_mode="int" + ) + test_case = self + + @tf.function + def symbolic_fn(ds): + for x, _ in ds.take(1): + test_case.assertListEqual(x.shape.as_list(), [None, 30, None]) + + symbolic_fn(dataset) + + def test_sample_count(self): + directory = self._prepare_directory(num_classes=4, count=15) + dataset = audio_dataset_utils.audio_dataset_from_directory( + directory, batch_size=8, output_sequence_length=30, label_mode=None + ) + sample_count = 0 + for batch in dataset: + sample_count += batch.shape[0] + self.assertEqual(sample_count, 15) + + def test_audio_dataset_from_directory_multiclass(self): + directory = self._prepare_directory(num_classes=4, count=15) + + dataset = audio_dataset_utils.audio_dataset_from_directory( + directory, batch_size=8, output_sequence_length=30, label_mode=None + ) + batch = next(iter(dataset)) + self.assertEqual(batch.shape, (8, 30, 1)) + + dataset = audio_dataset_utils.audio_dataset_from_directory( + directory, batch_size=8, output_sequence_length=30, label_mode=None + ) + sample_count = 0 + iterator = iter(dataset) + for batch in dataset: + sample_count += next(iterator).shape[0] + self.assertEqual(sample_count, 15) + + dataset = audio_dataset_utils.audio_dataset_from_directory( + directory, batch_size=8, output_sequence_length=30, label_mode="int" + ) + batch = next(iter(dataset)) + self.assertLen(batch, 2) + self.assertEqual(batch[0].shape, (8, 30, 1)) + self.assertEqual(batch[0].dtype.name, "float32") + self.assertEqual(batch[1].shape, (8,)) + self.assertEqual(batch[1].dtype.name, "int32") + + dataset = audio_dataset_utils.audio_dataset_from_directory( + directory, + batch_size=8, + output_sequence_length=30, + label_mode="categorical", + ) + batch = next(iter(dataset)) + self.assertLen(batch, 2) + self.assertEqual(batch[0].shape, (8, 30, 1)) + self.assertEqual(batch[0].dtype.name, "float32") + self.assertEqual(batch[1].shape, (8, 4)) + self.assertEqual(batch[1].dtype.name, "float32") + + def test_audio_dataset_from_directory_validation_split(self): + directory = self._prepare_directory(num_classes=2, count=10) + dataset = audio_dataset_utils.audio_dataset_from_directory( + directory, + batch_size=10, + output_sequence_length=30, + validation_split=0.2, + subset="training", + seed=1337, + ) + batch = next(iter(dataset)) + self.assertLen(batch, 2) + self.assertEqual(batch[0].shape, (8, 30, 1)) + dataset = audio_dataset_utils.audio_dataset_from_directory( + directory, + batch_size=10, + output_sequence_length=30, + validation_split=0.2, + subset="validation", + seed=1337, + ) + batch = next(iter(dataset)) + self.assertLen(batch, 2) + self.assertEqual(batch[0].shape, (2, 30, 1)) + + def test_audio_dataset_from_directory_manual_labels(self): + directory = self._prepare_directory(num_classes=2, count=2) + dataset = audio_dataset_utils.audio_dataset_from_directory( + directory, + batch_size=8, + output_sequence_length=30, + labels=[0, 1], + shuffle=False, + ) + batch = next(iter(dataset)) + self.assertLen(batch, 2) + self.assertAllClose(batch[1], [0, 1]) + + def test_audio_dataset_from_directory_follow_links(self): + directory = self._prepare_directory( + num_classes=2, count=25, nested_dirs=True + ) + dataset = audio_dataset_utils.audio_dataset_from_directory( + directory, + batch_size=8, + output_sequence_length=30, + label_mode=None, + follow_links=True, + ) + sample_count = 0 + for batch in dataset: + sample_count += batch.shape[0] + self.assertEqual(sample_count, 25) + + def test_audio_dataset_from_directory_no_audio(self): + directory = self._prepare_directory(num_classes=2, count=0) + with self.assertRaisesRegex( + ValueError, "No audio files found in directory" + ): + _ = audio_dataset_utils.audio_dataset_from_directory(directory) + + def test_audio_dataset_from_directory_ragged(self): + directory = self._prepare_directory( + num_classes=2, count=16, different_sequence_lengths=True + ) + dataset = audio_dataset_utils.audio_dataset_from_directory( + directory, ragged=True, batch_size=8 + ) + batch = next(iter(dataset)) + + self.assertEqual(batch[0].shape.as_list(), [8, None, None]) + + def test_audio_dataset_from_directory_no_output_sequence_length_no_ragged( + self, + ): + # This test case tests `audio_dataset_from_directory` when `ragged` and + # `output_sequence_length` are not passed while the input sequence + # lengths are different. + directory = self._prepare_directory( + num_classes=2, count=16, different_sequence_lengths=True + ) + # The tensor shapes are different and output_sequence_length is None + # should work fine and pad each sequence to the length of the longest + # sequence in it's batch + min_sequence_length, max_sequence_length = 10, 30 + possible_sequence_lengths = [ + i for i in range(min_sequence_length, max_sequence_length + 1) + ] + dataset = audio_dataset_utils.audio_dataset_from_directory( + directory, batch_size=2 + ) + sequence_lengths = list(set([b.shape[1] for b, _ in dataset])) + for seq_len in sequence_lengths: + self.assertIn(seq_len, possible_sequence_lengths) + + def test_audio_dataset_from_directory_no_output_sequence_length_same_lengths( # noqa: E501 + self, + ): + # This test case tests `audio_dataset_from_directory` when `ragged` and + # `output_sequence_length` are not passed while the input sequence + # lengths are the same + directory = self._prepare_directory( + num_classes=2, count=16, different_sequence_lengths=False + ) + # The tensor shapes are different and output_sequence_length is None + # should work fine and pad each sequence to the length of the longest + # sequence in it's batch + dataset = audio_dataset_utils.audio_dataset_from_directory( + directory, batch_size=2 + ) + sequence_lengths = list(set([batch[0].shape[1] for batch in dataset])) + self.assertEqual(len(sequence_lengths), 1) + + def test_audio_dataset_from_directory_errors(self): + directory = self._prepare_directory(num_classes=3, count=5) + + with self.assertRaisesRegex( + ValueError, "`sampling_rate` should be higher than 0. Received:" + ): + _ = audio_dataset_utils.audio_dataset_from_directory( + directory, + ragged=False, + output_sequence_length=10, + sampling_rate=-1, + ) + + with self.assertRaisesRegex( + ValueError, + "`sampling_rate` should have an integer value. Received:", + ): + _ = audio_dataset_utils.audio_dataset_from_directory( + directory, + ragged=False, + output_sequence_length=10, + sampling_rate=1.2, + ) + + # Only run this test case when we don't have tensorflow_io. + try: + import tensorflow_io # noqa: F401 + except ImportError: + with self.assertRaisesRegex( + ImportError, + "To use the argument `sampling_rate`.*tensorflow_io.*", + ): + _ = audio_dataset_utils.audio_dataset_from_directory( + directory, + ragged=False, + output_sequence_length=10, + sampling_rate=44100, + ) + + with self.assertRaisesRegex( + ValueError, "Cannot set both `ragged` and `output_sequence_length`" + ): + _ = audio_dataset_utils.audio_dataset_from_directory( + directory, ragged=True, output_sequence_length=30 + ) + + with self.assertRaisesRegex(ValueError, "`labels` argument should be"): + _ = audio_dataset_utils.audio_dataset_from_directory( + directory, labels="other" + ) + + with self.assertRaisesRegex( + ValueError, "`label_mode` argument must be" + ): + _ = audio_dataset_utils.audio_dataset_from_directory( + directory, label_mode="other" + ) + + with self.assertRaisesRegex( + ValueError, 'only pass `class_names` if `labels="inferred"`' + ): + _ = audio_dataset_utils.audio_dataset_from_directory( + directory, + labels=[0, 0, 1, 1, 1], + class_names=["class_0", "class_1", "class_2"], + ) + + with self.assertRaisesRegex( + ValueError, + "Expected the lengths of `labels` to match the number of files", + ): + _ = audio_dataset_utils.audio_dataset_from_directory( + directory, labels=[0, 0, 1, 1] + ) + + with self.assertRaisesRegex( + ValueError, "`class_names` passed did not match" + ): + _ = audio_dataset_utils.audio_dataset_from_directory( + directory, class_names=["class_0", "wrong_class"] + ) + + with self.assertRaisesRegex(ValueError, "there must be exactly 2"): + _ = audio_dataset_utils.audio_dataset_from_directory( + directory, label_mode="binary" + ) + + with self.assertRaisesRegex( + ValueError, "`validation_split` must be between 0 and 1" + ): + _ = audio_dataset_utils.audio_dataset_from_directory( + directory, validation_split=2 + ) + + with self.assertRaisesRegex( + ValueError, '`subset` must be either "training",' + ): + _ = audio_dataset_utils.audio_dataset_from_directory( + directory, validation_split=0.2, subset="other" + ) + + with self.assertRaisesRegex( + ValueError, "`validation_split` must be set" + ): + _ = audio_dataset_utils.audio_dataset_from_directory( + directory, validation_split=0.0, subset="training" + ) + + with self.assertRaisesRegex(ValueError, "must provide a `seed`"): + _ = audio_dataset_utils.audio_dataset_from_directory( + directory, validation_split=0.2, subset="training" + ) + + def test_audio_dataset_from_directory_not_batched(self): + directory = self._prepare_directory(num_classes=2, count=2) + dataset = audio_dataset_utils.audio_dataset_from_directory( + directory, + batch_size=None, + output_sequence_length=30, + label_mode=None, + shuffle=False, + ) + sample = next(iter(dataset)) + self.assertEqual(len(sample.shape), 2) diff --git a/keras/src/utils/backend_utils.py b/keras/src/utils/backend_utils.py new file mode 100644 index 000000000000..ea8c4b4d097a --- /dev/null +++ b/keras/src/utils/backend_utils.py @@ -0,0 +1,161 @@ +import copy +import importlib +import inspect +import os +import sys + +from keras.src import backend as backend_module +from keras.src.api_export import keras_export +from keras.src.backend.common import global_state + + +def in_tf_graph(): + if global_state.get_global_attribute("in_tf_graph_scope", False): + return True + + if "tensorflow" in sys.modules: + from keras.src.utils.module_utils import tensorflow as tf + + return not tf.executing_eagerly() + return False + + +def convert_tf_tensor(outputs, dtype=None): + if backend_module.backend() != "tensorflow" and not in_tf_graph(): + outputs = backend_module.convert_to_tensor(outputs, dtype=dtype) + return outputs + + +class TFGraphScope: + def __init__(self): + self._original_value = global_state.get_global_attribute( + "in_tf_graph_scope", False + ) + + def __enter__(self): + global_state.set_global_attribute("in_tf_graph_scope", True) + + def __exit__(self, *args, **kwargs): + global_state.set_global_attribute( + "in_tf_graph_scope", self._original_value + ) + + +def in_grain_data_pipeline(): + if "grain" not in sys.modules: + # Fast path to check if grain is not imported. + return False + + # We use a lightweight version of `inspect.stack` to detect execution within + # grain. + current_frame = inspect.currentframe() + while current_frame: + if ( + os.path.join("grain", "_src", "python", "dataset") + in current_frame.f_code.co_filename + or os.path.join("grain", "_src", "python", "data_loader") + in current_frame.f_code.co_filename + ): + return True + current_frame = current_frame.f_back + return False + + +class DynamicBackend: + """A class that can be used to switch from one backend to another. + + Example: + + ```python + backend = DynamicBackend("tensorflow") + y = backend.square(tf.constant(...)) + backend.set_backend("jax") + y = backend.square(jax.numpy.array(...)) + ``` + + Args: + backend: Initial backend to use (string). + """ + + def __init__(self, backend=None): + self._backend = backend or backend_module.backend() + + def set_backend(self, backend): + if backend not in ("tensorflow", "jax", "torch", "numpy", "openvino"): + raise ValueError( + "Available backends are ('tensorflow', 'jax', 'torch', " + f"'numpy' and 'openvino'). Received: backend={backend}" + ) + self._backend = backend + + def reset(self): + self._backend = backend_module.backend() + + @property + def name(self): + return self._backend + + def __getattr__(self, name): + if self._backend == "tensorflow": + module = importlib.import_module("keras.src.backend.tensorflow") + return getattr(module, name) + if self._backend == "jax": + module = importlib.import_module("keras.src.backend.jax") + return getattr(module, name) + if self._backend == "torch": + module = importlib.import_module("keras.src.backend.torch") + return getattr(module, name) + if self._backend == "numpy": + if backend_module.backend() == "numpy": + return getattr(backend_module, name) + else: + raise NotImplementedError( + "Currently, we cannot dynamically import the numpy backend " + "because it would disrupt the namespace of the import." + ) + if self._backend == "openvino": + module = importlib.import_module("keras.src.backend.openvino") + return getattr(module, name) + + +@keras_export("keras.config.set_backend") +def set_backend(backend): + """Reload the backend (and the Keras package). + + Example: + + ```python + keras.config.set_backend("jax") + ``` + + ⚠️ WARNING ⚠️: Using this function is dangerous and should be done + carefully. Changing the backend will **NOT** convert + the type of any already-instantiated objects. + Thus, any layers / tensors / etc. already created will no + longer be usable without errors. It is strongly recommended **not** + to keep around **any** Keras-originated objects instances created + before calling `set_backend()`. + + This includes any function or class instance that uses any Keras + functionality. All such code needs to be re-executed after calling + `set_backend()`. + """ + os.environ["KERAS_BACKEND"] = backend + # Clear module cache. + loaded_modules = [ + key for key in sys.modules.keys() if key.startswith("keras") + ] + for key in loaded_modules: + del sys.modules[key] + # Reimport Keras with the new backend (set via KERAS_BACKEND). + import keras + + # Finally: refresh all imported Keras submodules. + globs = copy.copy(globals()) + for key, value in globs.items(): + if value.__class__ == keras.__class__: + if str(value).startswith("" + + def __iter__(self): + keys = sorted(self._config.keys()) + for k in keys: + yield k + + def __len__(self): + return len(self._config) + + def __delitem__(self, key): + self._raise_if_frozen() + del self._config[key] + + def __contains__(self, item): + return item in self._config diff --git a/keras/src/utils/dataset_utils.py b/keras/src/utils/dataset_utils.py new file mode 100644 index 000000000000..85fe677fb3d9 --- /dev/null +++ b/keras/src/utils/dataset_utils.py @@ -0,0 +1,838 @@ +import os +import random +import time +import warnings +from multiprocessing.pool import ThreadPool + +import numpy as np + +from keras.src import tree +from keras.src.api_export import keras_export +from keras.src.utils import file_utils +from keras.src.utils import io_utils +from keras.src.utils.module_utils import grain +from keras.src.utils.module_utils import tensorflow as tf + + +@keras_export("keras.utils.split_dataset") +def split_dataset( + dataset, left_size=None, right_size=None, shuffle=False, seed=None +): + """Splits a dataset into a left half and a right half (e.g. train / test). + + Args: + dataset: + A `tf.data.Dataset`, a `torch.utils.data.Dataset` object, + or a list/tuple of arrays with the same length. + left_size: If float (in the range `[0, 1]`), it signifies + the fraction of the data to pack in the left dataset. If integer, it + signifies the number of samples to pack in the left dataset. If + `None`, defaults to the complement to `right_size`. + Defaults to `None`. + right_size: If float (in the range `[0, 1]`), it signifies + the fraction of the data to pack in the right dataset. + If integer, it signifies the number of samples to pack + in the right dataset. + If `None`, defaults to the complement to `left_size`. + Defaults to `None`. + shuffle: Boolean, whether to shuffle the data before splitting it. + seed: A random seed for shuffling. + + Returns: + A tuple of two `tf.data.Dataset` objects: + the left and right splits. + + Example: + + >>> data = np.random.random(size=(1000, 4)) + >>> left_ds, right_ds = keras.utils.split_dataset(data, left_size=0.8) + >>> int(left_ds.cardinality()) + 800 + >>> int(right_ds.cardinality()) + 200 + """ + dataset_type_spec = _get_type_spec(dataset) + + if dataset_type_spec is None: + raise TypeError( + "The `dataset` argument must be either" + "a `tf.data.Dataset`, a `torch.utils.data.Dataset`" + "object, or a list/tuple of arrays. " + f"Received: dataset={dataset} of type {type(dataset)}" + ) + + if right_size is None and left_size is None: + raise ValueError( + "At least one of the `left_size` or `right_size` " + "must be specified. Received: left_size=None and " + "right_size=None" + ) + + dataset_as_list = _convert_dataset_to_list(dataset, dataset_type_spec) + + if shuffle: + if seed is None: + seed = random.randint(0, int(1e6)) + random.seed(seed) + random.shuffle(dataset_as_list) + + total_length = len(dataset_as_list) + + left_size, right_size = _rescale_dataset_split_sizes( + left_size, right_size, total_length + ) + left_split = list(dataset_as_list[:left_size]) + right_split = list(dataset_as_list[-right_size:]) + + left_split = _restore_dataset_from_list( + left_split, dataset_type_spec, dataset + ) + right_split = _restore_dataset_from_list( + right_split, dataset_type_spec, dataset + ) + + left_split = tf.data.Dataset.from_tensor_slices(left_split) + right_split = tf.data.Dataset.from_tensor_slices(right_split) + + # apply batching to the splits if the dataset is batched + if dataset_type_spec is tf.data.Dataset and is_batched(dataset): + batch_size = get_batch_size(dataset) + if batch_size is not None: + left_split = left_split.batch(batch_size) + right_split = right_split.batch(batch_size) + + left_split = left_split.prefetch(tf.data.AUTOTUNE) + right_split = right_split.prefetch(tf.data.AUTOTUNE) + return left_split, right_split + + +def _convert_dataset_to_list( + dataset, + dataset_type_spec, + data_size_warning_flag=True, + ensure_shape_similarity=True, +): + """Convert `dataset` object to a list of samples. + + Args: + dataset: A `tf.data.Dataset`, a `torch.utils.data.Dataset` object, + or a list/tuple of arrays. + dataset_type_spec: the type of the dataset. + data_size_warning_flag: If set to `True`, a warning will + be issued if the dataset takes longer than 10 seconds to iterate. + Defaults to `True`. + ensure_shape_similarity: If set to `True`, the shape of + the first sample will be used to validate the shape of rest of the + samples. Defaults to `True`. + + Returns: + List: A list of samples. + """ + dataset_iterator = _get_data_iterator_from_dataset( + dataset, dataset_type_spec + ) + dataset_as_list = [] + + start_time = time.time() + for sample in _get_next_sample( + dataset_iterator, + ensure_shape_similarity, + data_size_warning_flag, + start_time, + ): + dataset_as_list.append(sample) + + return dataset_as_list + + +def _get_data_iterator_from_dataset(dataset, dataset_type_spec): + """Get the iterator from a dataset. + + Args: + dataset: A `tf.data.Dataset`, a `torch.utils.data.Dataset` object, + or a list/tuple of arrays. + dataset_type_spec: The type of the dataset. + + Returns: + iterator: An `iterator` object. + """ + if dataset_type_spec is list: + if len(dataset) == 0: + raise ValueError( + "Received an empty list dataset. " + "Please provide a non-empty list of arrays." + ) + + expected_shape = None + for i, element in enumerate(dataset): + if not isinstance(element, np.ndarray): + raise ValueError( + "Expected a list of `numpy.ndarray` objects," + f"Received: {type(element)} at index {i}." + ) + if expected_shape is None: + expected_shape = element.shape + elif element.shape[0] != expected_shape[0]: + raise ValueError( + "Received a list of NumPy arrays with different lengths." + f"Mismatch found at index {i}, " + f"Expected shape={expected_shape} " + f"Received shape={np.array(element).shape}." + "Please provide a list of NumPy arrays of the same length." + ) + + return iter(zip(*dataset)) + elif dataset_type_spec is tuple: + if len(dataset) == 0: + raise ValueError( + "Received an empty list dataset." + "Please provide a non-empty tuple of arrays." + ) + + expected_shape = None + for i, element in enumerate(dataset): + if not isinstance(element, np.ndarray): + raise ValueError( + "Expected a tuple of `numpy.ndarray` objects," + f"Received: {type(element)} at index {i}." + ) + if expected_shape is None: + expected_shape = element.shape + elif element.shape[0] != expected_shape[0]: + raise ValueError( + "Received a tuple of NumPy arrays with different lengths." + f"Mismatch found at index {i}, " + f"Expected shape={expected_shape} " + f"Received shape={np.array(element).shape}." + "Please provide a tuple of NumPy arrays of the same length." + ) + + return iter(zip(*dataset)) + elif dataset_type_spec is tf.data.Dataset: + if is_batched(dataset): + dataset = dataset.unbatch() + return iter(dataset) + + elif is_torch_dataset(dataset): + return iter(dataset) + elif dataset_type_spec is np.ndarray: + return iter(dataset) + raise ValueError(f"Invalid dataset_type_spec: {dataset_type_spec}") + + +def _get_next_sample( + dataset_iterator, + ensure_shape_similarity, + data_size_warning_flag, + start_time, +): + """Yield data samples from the `dataset_iterator`. + + Args: + dataset_iterator: An `iterator` object. + ensure_shape_similarity: If set to `True`, the shape of + the first sample will be used to validate the shape of rest of the + samples. Defaults to `True`. + data_size_warning_flag: If set to `True`, a warning will + be issued if the dataset takes longer than 10 seconds to iterate. + Defaults to `True`. + start_time (float): the start time of the dataset iteration. this is + used only if `data_size_warning_flag` is set to true. + + Yields: + data_sample: The next sample. + """ + from keras.src.trainers.data_adapters.data_adapter_utils import ( + is_torch_tensor, + ) + + try: + dataset_iterator = iter(dataset_iterator) + first_sample = next(dataset_iterator) + if isinstance(first_sample, (tf.Tensor, np.ndarray)) or is_torch_tensor( + first_sample + ): + first_sample_shape = np.array(first_sample).shape + else: + first_sample_shape = None + ensure_shape_similarity = False + yield first_sample + except StopIteration: + raise ValueError( + "Received an empty dataset. Argument `dataset` must " + "be a non-empty list/tuple of `numpy.ndarray` objects " + "or `tf.data.Dataset` objects." + ) + + for i, sample in enumerate(dataset_iterator): + if ensure_shape_similarity: + if first_sample_shape != np.array(sample).shape: + raise ValueError( + "All `dataset` samples must have same shape, " + f"Expected shape: {np.array(first_sample).shape} " + f"Received shape: {np.array(sample).shape} at index " + f"{i}." + ) + if data_size_warning_flag: + if i % 10 == 0: + cur_time = time.time() + # warns user if the dataset is too large to iterate within 10s + if int(cur_time - start_time) > 10 and data_size_warning_flag: + warnings.warn( + "The dataset is taking longer than 10 seconds to " + "iterate over. This may be due to the size of the " + "dataset. Keep in mind that the `split_dataset` " + "utility is only for small in-memory dataset " + "(e.g. < 10,000 samples).", + category=ResourceWarning, + source="split_dataset", + ) + data_size_warning_flag = False + yield sample + + +def is_torch_dataset(dataset): + if hasattr(dataset, "__class__"): + for parent in dataset.__class__.__mro__: + if parent.__name__ == "Dataset" and str( + parent.__module__ + ).startswith("torch.utils.data"): + return True + return False + + +def is_grain_dataset(dataset): + if hasattr(dataset, "__class__"): + for parent in dataset.__class__.__mro__: + if parent.__name__ in ( + "MapDataset", + "IterDataset", + ) and str(parent.__module__).startswith("grain._src.python"): + return True + return False + + +def _rescale_dataset_split_sizes(left_size, right_size, total_length): + """Rescale the dataset split sizes. + + We want to ensure that the sum of + the split sizes is equal to the total length of the dataset. + + Args: + left_size: The size of the left dataset split. + right_size: The size of the right dataset split. + total_length: The total length of the dataset. + + Returns: + tuple: A tuple of rescaled `left_size` and `right_size` integers. + """ + left_size_type = type(left_size) + right_size_type = type(right_size) + + # check both left_size and right_size are integers or floats + if (left_size is not None and left_size_type not in [int, float]) and ( + right_size is not None and right_size_type not in [int, float] + ): + raise TypeError( + "Invalid `left_size` and `right_size` Types. Expected: " + "integer or float or None, Received: type(left_size)=" + f"{left_size_type} and type(right_size)={right_size_type}" + ) + + # check left_size is a integer or float + if left_size is not None and left_size_type not in [int, float]: + raise TypeError( + "Invalid `left_size` Type. Expected: int or float or None, " + f"Received: type(left_size)={left_size_type}. " + ) + + # check right_size is a integer or float + if right_size is not None and right_size_type not in [int, float]: + raise TypeError( + "Invalid `right_size` Type. " + "Expected: int or float or None," + f"Received: type(right_size)={right_size_type}." + ) + + # check left_size and right_size are non-zero + if left_size == 0 and right_size == 0: + raise ValueError( + "Both `left_size` and `right_size` are zero. " + "At least one of the split sizes must be non-zero." + ) + + # check left_size is non-negative and less than 1 and less than total_length + if ( + left_size_type is int + and (left_size <= 0 or left_size >= total_length) + or left_size_type is float + and (left_size <= 0 or left_size >= 1) + ): + raise ValueError( + "`left_size` should be either a positive integer " + f"smaller than {total_length}, or a float " + "within the range `[0, 1]`. Received: left_size=" + f"{left_size}" + ) + + # check right_size is non-negative and less than 1 and less than + # total_length + if ( + right_size_type is int + and (right_size <= 0 or right_size >= total_length) + or right_size_type is float + and (right_size <= 0 or right_size >= 1) + ): + raise ValueError( + "`right_size` should be either a positive integer " + f"and smaller than {total_length} or a float " + "within the range `[0, 1]`. Received: right_size=" + f"{right_size}" + ) + + # check sum of left_size and right_size is less than or equal to + # total_length + if ( + right_size_type is left_size_type is float + and right_size + left_size > 1 + ): + raise ValueError( + "The sum of `left_size` and `right_size` is greater " + "than 1. It must be less than or equal to 1." + ) + + if left_size_type is float: + left_size = round(left_size * total_length) + elif left_size_type is int: + left_size = float(left_size) + + if right_size_type is float: + right_size = round(right_size * total_length) + elif right_size_type is int: + right_size = float(right_size) + + if left_size is None: + left_size = total_length - right_size + elif right_size is None: + right_size = total_length - left_size + + if left_size + right_size > total_length: + raise ValueError( + "The sum of `left_size` and `right_size` should " + f"be smaller than the {total_length}. " + f"Received: left_size + right_size = {left_size + right_size}" + f"and total_length = {total_length}" + ) + + for split, side in [(left_size, "left"), (right_size, "right")]: + if split == 0: + raise ValueError( + f"With `dataset` of length={total_length}, `left_size`=" + f"{left_size} and `right_size`={right_size}." + f"Resulting {side} side dataset split will be empty. " + "Adjust any of the aforementioned parameters" + ) + + left_size, right_size = int(left_size), int(right_size) + return left_size, right_size + + +def _restore_dataset_from_list( + dataset_as_list, dataset_type_spec, original_dataset +): + """Restore the dataset from the list of arrays.""" + if dataset_type_spec in [tuple, list, tf.data.Dataset] or is_torch_dataset( + original_dataset + ): + # Save structure by taking the first element. + element_spec = dataset_as_list[0] + # Flatten each element. + dataset_as_list = [tree.flatten(sample) for sample in dataset_as_list] + # Combine respective elements at all indices. + dataset_as_list = [np.array(sample) for sample in zip(*dataset_as_list)] + # Recreate the original structure of elements. + dataset_as_list = tree.pack_sequence_as(element_spec, dataset_as_list) + # Turn lists to tuples as tf.data will fail on lists. + return tree.traverse( + lambda x: tuple(x) if isinstance(x, list) else x, + dataset_as_list, + top_down=False, + ) + + return dataset_as_list + + +def is_batched(dataset): + """Check if the `tf.data.Dataset` is batched.""" + return hasattr(dataset, "_batch_size") + + +def get_batch_size(dataset): + """Get the batch size of the dataset.""" + if is_batched(dataset): + return dataset._batch_size + else: + return None + + +def _get_type_spec(dataset): + """Get the type spec of the dataset.""" + if isinstance(dataset, tuple): + return tuple + elif isinstance(dataset, list): + return list + elif isinstance(dataset, np.ndarray): + return np.ndarray + elif isinstance(dataset, tf.data.Dataset): + return tf.data.Dataset + elif is_torch_dataset(dataset): + from torch.utils.data import Dataset as TorchDataset + + return TorchDataset + elif is_grain_dataset(dataset): + from grain import MapDataset + + return MapDataset + else: + return None + + +def index_directory( + directory, + labels, + formats, + class_names=None, + shuffle=True, + seed=None, + follow_links=False, + verbose=True, +): + """List all files in `directory`, with their labels. + + Args: + directory: Directory where the data is located. + If `labels` is `"inferred"`, it should contain + subdirectories, each containing files for a class. + Otherwise, the directory structure is ignored. + labels: Either `"inferred"` + (labels are generated from the directory structure), + `None` (no labels), + or a list/tuple of integer labels of the same size as the number + of valid files found in the directory. + Labels should be sorted according + to the alphanumeric order of the image file paths + (obtained via `os.walk(directory)` in Python). + formats: Allowlist of file extensions to index + (e.g. `".jpg"`, `".txt"`). + class_names: Only valid if `labels="inferred"`. This is the explicit + list of class names (must match names of subdirectories). Used + to control the order of the classes + (otherwise alphanumerical order is used). + shuffle: Whether to shuffle the data. Defaults to `True`. + If set to `False`, sorts the data in alphanumeric order. + seed: Optional random seed for shuffling. + follow_links: Whether to visits subdirectories pointed to by symlinks. + verbose: Whether the function prints number of files found and classes. + Defaults to `True`. + + Returns: + tuple (file_paths, labels, class_names). + - file_paths: list of file paths (strings). + - labels: list of matching integer labels (same length as file_paths) + - class_names: names of the classes corresponding to these labels, in + order. + """ + if file_utils.is_remote_path(directory): + os_module = tf.io.gfile + path_module = tf.io.gfile + else: + os_module = os + path_module = os.path + + if labels == "inferred": + subdirs = [] + for subdir in sorted(os_module.listdir(directory)): + if path_module.isdir(path_module.join(directory, subdir)): + if not subdir.startswith("."): + if subdir.endswith("/"): + subdir = subdir[:-1] + subdirs.append(subdir) + if class_names is not None: + if not set(class_names).issubset(set(subdirs)): + raise ValueError( + "The `class_names` passed did not match the " + "names of the subdirectories of the target directory. " + f"Expected: {subdirs} (or a subset of it), " + f"but received: class_names={class_names}" + ) + subdirs = class_names # Keep provided order. + else: + # In the explicit/no-label cases, index from the parent directory down. + subdirs = [""] + if class_names is not None: + if labels is None: + raise ValueError( + "When `labels=None` (no labels), argument `class_names` " + "cannot be specified." + ) + else: + raise ValueError( + "When argument `labels` is specified, argument " + "`class_names` cannot be specified (the `class_names` " + "will be the sorted list of labels)." + ) + class_names = subdirs + class_indices = dict(zip(class_names, range(len(class_names)))) + + # Build an index of the files + # in the different class subfolders. + pool = ThreadPool() + results = [] + filenames = [] + + for dirpath in (path_module.join(directory, subdir) for subdir in subdirs): + results.append( + pool.apply_async( + index_subdirectory, + (dirpath, class_indices, follow_links, formats), + ) + ) + labels_list = [] + for res in results: + partial_filenames, partial_labels = res.get() + labels_list.append(partial_labels) + filenames += partial_filenames + + if labels == "inferred": + # Inferred labels. + i = 0 + labels = np.zeros((len(filenames),), dtype="int32") + for partial_labels in labels_list: + labels[i : i + len(partial_labels)] = partial_labels + i += len(partial_labels) + elif labels is None: + class_names = None + else: + # Manual labels. + if len(labels) != len(filenames): + raise ValueError( + "Expected the lengths of `labels` to match the number " + "of files in the target directory. len(labels) is " + f"{len(labels)} while we found {len(filenames)} files " + f"in directory {directory}." + ) + class_names = [str(label) for label in sorted(set(labels))] + if verbose: + if labels is None: + io_utils.print_msg(f"Found {len(filenames)} files.") + else: + io_utils.print_msg( + f"Found {len(filenames)} files belonging " + f"to {len(class_names)} classes." + ) + pool.close() + pool.join() + file_paths = [path_module.join(directory, fname) for fname in filenames] + + if shuffle: + # Shuffle globally to erase macro-structure + if seed is None: + seed = np.random.randint(1e6) + rng = np.random.RandomState(seed) + rng.shuffle(file_paths) + if labels is not None: + rng = np.random.RandomState(seed) + rng.shuffle(labels) + return file_paths, labels, class_names + + +def iter_valid_files(directory, follow_links, formats): + io_module = tf.io.gfile if file_utils.is_remote_path(directory) else os + + if not follow_links: + walk = io_module.walk(directory) + else: + walk = os.walk(directory, followlinks=follow_links) + for root, _, files in sorted(walk, key=lambda x: x[0]): + for fname in sorted(files): + if fname.lower().endswith(formats): + yield root, fname + + +def index_subdirectory(directory, class_indices, follow_links, formats): + """Recursively walks directory and list image paths and their class index. + + Args: + directory: string, target directory. + class_indices: dict mapping class names to their index. + follow_links: boolean, whether to recursively follow subdirectories + (if False, we only list top-level images in `directory`). + formats: Allowlist of file extensions to index (e.g. ".jpg", ".txt"). + + Returns: + tuple `(filenames, labels)`. `filenames` is a list of relative file + paths, and `labels` is a list of integer labels corresponding + to these files. + """ + path_module = ( + tf.io.gfile if file_utils.is_remote_path(directory) else os.path + ) + + dirname = os.path.basename(directory) + valid_files = iter_valid_files(directory, follow_links, formats) + labels = [] + filenames = [] + for root, fname in valid_files: + labels.append(class_indices[dirname]) + absolute_path = path_module.join(root, fname) + relative_path = path_module.join( + dirname, os.path.relpath(absolute_path, directory) + ) + filenames.append(relative_path) + return filenames, labels + + +def get_training_or_validation_split(samples, labels, validation_split, subset): + """Potentially restrict samples & labels to a training or validation split. + + Args: + samples: List of elements. + labels: List of corresponding labels. + validation_split: Float, fraction of data to reserve for validation. + subset: Subset of the data to return. + Either `"training"`, `"validation"`, or `None`. + If `None`, we return all of the data. + + Returns: + tuple (samples, labels), potentially restricted to the specified subset. + """ + if not validation_split: + return samples, labels + + num_val_samples = int(validation_split * len(samples)) + if subset == "training": + io_utils.print_msg( + f"Using {len(samples) - num_val_samples} files for training." + ) + samples = samples[:-num_val_samples] + if labels is not None: + labels = labels[:-num_val_samples] + elif subset == "validation": + io_utils.print_msg(f"Using {num_val_samples} files for validation.") + samples = samples[-num_val_samples:] + if labels is not None: + labels = labels[-num_val_samples:] + else: + raise ValueError( + '`subset` must be either "training" ' + f'or "validation", received: {subset}' + ) + return samples, labels + + +def labels_to_dataset_tf(labels, label_mode, num_classes): + """Create a `tf.data.Dataset` from the list/tuple of labels. + + Args: + labels: list/tuple of labels to be converted into a `tf.data.Dataset`. + label_mode: String describing the encoding of `labels`. Options are: + - `"binary"` indicates that the labels (there can be only 2) are encoded + as `float32` scalars with values 0 or 1 + (e.g. for `binary_crossentropy`). + - `"categorical"` means that the labels are mapped into a categorical + vector. (e.g. for `categorical_crossentropy` loss). + num_classes: number of classes of labels. + + Returns: + A `tf.data.Dataset` instance. + """ + label_ds = tf.data.Dataset.from_tensor_slices(labels) + if label_mode == "binary": + label_ds = label_ds.map( + lambda x: tf.expand_dims(tf.cast(x, "float32"), axis=-1), + num_parallel_calls=tf.data.AUTOTUNE, + ) + elif label_mode == "categorical": + label_ds = label_ds.map( + lambda x: tf.one_hot(x, num_classes), + num_parallel_calls=tf.data.AUTOTUNE, + ) + return label_ds + + +def labels_to_dataset_grain(labels, label_mode, num_classes): + """Create a `grain.MapDataset` from the list/tuple of labels. + + Args: + labels: list/tuple of labels to be converted into a `grain.MapDataset`. + label_mode: String describing the encoding of `labels`. Options are: + - `"binary"` indicates that the labels (there can be only 2) are encoded + as `float32` scalars with values 0 or 1 + (e.g. for `binary_crossentropy`). + - `"categorical"` means that the labels are mapped into a categorical + vector. (e.g. for `categorical_crossentropy` loss). + num_classes: number of classes of labels. + + Returns: + A `grain.MapDataset` instance. + """ + from keras.src import backend + from keras.src import ops + + if label_mode not in ("binary", "categorical", "int"): + raise ValueError( + f"Invalid `label_mode`: {label_mode}. " + "Expected one of: 'binary', 'categorical', 'int'." + ) + + def preprocess_labels_in_cpu(label_mode, x, num_classes): + with backend.device_scope("cpu"): + if label_mode == "binary": + return ops.expand_dims( + ops.convert_to_tensor(x, dtype="float32"), axis=-1 + ) + elif label_mode == "categorical": + return ops.one_hot( + ops.convert_to_tensor(x, dtype="int32"), num_classes + ) + else: + return ops.convert_to_tensor(x, dtype="int32") + + label_ds = grain.MapDataset.source(labels) + label_ds = label_ds.map( + lambda x: preprocess_labels_in_cpu(label_mode, x, num_classes), + ) + return label_ds + + +def check_validation_split_arg(validation_split, subset, shuffle, seed): + """Raise errors in case of invalid argument values. + + Args: + validation_split: float between 0 and 1, fraction of data to reserve for + validation. + subset: One of `"training"`, `"validation"`, or `"both"`. Only used if + `validation_split` is set. + shuffle: Whether to shuffle the data. Either `True` or `False`. + seed: random seed for shuffling and transformations. + """ + if validation_split and not 0 < validation_split < 1: + raise ValueError( + "`validation_split` must be between 0 and 1, " + f"received: {validation_split}" + ) + if (validation_split or subset) and not (validation_split and subset): + raise ValueError( + "If `subset` is set, `validation_split` must be set, and inversely." + ) + if subset not in ("training", "validation", "both", None): + raise ValueError( + '`subset` must be either "training", ' + f'"validation" or "both", received: {subset}' + ) + if validation_split and shuffle and seed is None: + raise ValueError( + "If using `validation_split` and shuffling the data, you must " + "provide a `seed` argument, to make sure that there is no " + "overlap between the training and validation subset." + ) diff --git a/keras/src/utils/dataset_utils_test.py b/keras/src/utils/dataset_utils_test.py new file mode 100644 index 000000000000..c907736cc0ba --- /dev/null +++ b/keras/src/utils/dataset_utils_test.py @@ -0,0 +1,100 @@ +import collections +import itertools + +import numpy as np +from absl.testing import parameterized +from torch.utils.data import Dataset as TorchDataset + +from keras.src.testing import test_case +from keras.src.testing.test_utils import named_product +from keras.src.utils.dataset_utils import split_dataset +from keras.src.utils.module_utils import tensorflow as tf + + +class MyTorchDataset(TorchDataset): + def __init__(self, x, y): + self.x = x + self.y = y + + def __len__(self): + return len(self.x) + + def __getitem__(self, index): + return self.x[index], self.y[index] + + +class DatasetUtilsTest(test_case.TestCase): + @parameterized.named_parameters( + named_product( + dataset_type=["list", "tuple", "tensorflow", "torch"], + features_shape=[(2,), (100, 2), (10, 10, 2)], + ) + ) + def test_split_dataset(self, dataset_type, features_shape): + n_sample, left_size, right_size = 100, 0.2, 0.8 + features = np.random.sample((n_sample,) + features_shape) + labels = np.random.sample((n_sample, 1)) + + if dataset_type == "list": + dataset = [features, labels] + elif dataset_type == "tuple": + dataset = (features, labels) + elif dataset_type == "tensorflow": + dataset = tf.data.Dataset.from_tensor_slices((features, labels)) + elif dataset_type == "torch": + dataset = MyTorchDataset(features, labels) + + dataset_left, dataset_right = split_dataset( + dataset, left_size=left_size, right_size=right_size + ) + self.assertEqual( + int(dataset_left.cardinality()), int(n_sample * left_size) + ) + self.assertEqual( + int(dataset_right.cardinality()), int(n_sample * right_size) + ) + for sample in itertools.chain(dataset_left, dataset_right): + self.assertEqual(sample[0].shape, features_shape) + self.assertEqual(sample[1].shape, (1,)) + + @parameterized.named_parameters( + named_product(structure_type=["tuple", "dict", "OrderedDict"]) + ) + def test_split_dataset_nested_structures(self, structure_type): + n_sample, left_size, right_size = 100, 0.2, 0.8 + features1 = np.random.sample((n_sample, 2)) + features2 = np.random.sample((n_sample, 10, 2)) + labels = np.random.sample((n_sample, 1)) + + if structure_type == "tuple": + dataset = tf.data.Dataset.from_tensor_slices( + ((features1, features2), labels) + ) + if structure_type == "dict": + dataset = tf.data.Dataset.from_tensor_slices( + {"y": features2, "x": features1, "labels": labels} + ) + if structure_type == "OrderedDict": + dataset = tf.data.Dataset.from_tensor_slices( + collections.OrderedDict( + [("y", features2), ("x", features1), ("labels", labels)] + ) + ) + + dataset_left, dataset_right = split_dataset( + dataset, left_size=left_size, right_size=right_size + ) + self.assertEqual( + int(dataset_left.cardinality()), int(n_sample * left_size) + ) + self.assertEqual( + int(dataset_right.cardinality()), int(n_sample * right_size) + ) + for sample in itertools.chain(dataset_left, dataset_right): + if structure_type in ("dict", "OrderedDict"): + x, y, labels = sample["x"], sample["y"], sample["labels"] + elif structure_type == "tuple": + (x, y), labels = sample + self.assertEqual(x.shape, (2,)) + self.assertEqual(y.shape, (10, 2)) + self.assertEqual(labels.shape, (1,)) diff --git a/keras/src/utils/dtype_utils.py b/keras/src/utils/dtype_utils.py new file mode 100644 index 000000000000..44ac7d4f65a3 --- /dev/null +++ b/keras/src/utils/dtype_utils.py @@ -0,0 +1,51 @@ +from keras.src import backend +from keras.src import ops + +DTYPE_TO_SIZE = { + **{f"float{i}": i for i in (16, 32, 64)}, + **{f"int{i}": i for i in (8, 16, 32, 64)}, + **{f"uint{i}": i for i in (8, 16, 32, 64)}, + "bfloat16": 16, + "bool": 1, +} + + +def dtype_size(dtype): + size = DTYPE_TO_SIZE.get(dtype, None) + if size is None: + raise ValueError(f"Invalid dtype: {dtype}") + return size + + +def is_float(dtype): + return "float" in dtype + + +def cast_to_common_dtype(tensors): + """Cast a list of tensors to a common dtype. + + If any tensor is floating-point, they will all be casted to the most-precise + floating-point dtype. Otherwise the tensors are not casted. + + Args: + tensors: A list of tensors. + + Returns: + Same list, casted to a common dtype. + """ + highest_float = None + highest_float_size = ( + -1 + ) # Initially set to an impossible value for comparison + for x in tensors: + dtype = backend.standardize_dtype(x.dtype) + if is_float(dtype): + if highest_float is None or dtype_size(dtype) > highest_float_size: + highest_float = dtype + highest_float_size = dtype_size(dtype) + elif dtype == "float16" and highest_float == "bfloat16": + highest_float = "float32" + highest_float_size = dtype_size(highest_float) + if highest_float: + tensors = [ops.cast(x, highest_float) for x in tensors] + return tensors diff --git a/keras/src/utils/dtype_utils_test.py b/keras/src/utils/dtype_utils_test.py new file mode 100644 index 000000000000..390db6fd72d7 --- /dev/null +++ b/keras/src/utils/dtype_utils_test.py @@ -0,0 +1,146 @@ +from keras.src.backend.common.keras_tensor import KerasTensor +from keras.src.testing import test_case +from keras.src.utils import dtype_utils + + +class DtypeSizeTests(test_case.TestCase): + def test_bfloat16_dtype_size(self): + self.assertEqual(dtype_utils.dtype_size("bfloat16"), 16) + + def test_float16_dtype_size(self): + self.assertEqual(dtype_utils.dtype_size("float16"), 16) + + def test_float32_dtype_size(self): + self.assertEqual(dtype_utils.dtype_size("float32"), 32) + + def test_int32_dtype_size(self): + self.assertEqual(dtype_utils.dtype_size("int32"), 32) + + def test_float64_dtype_size(self): + self.assertEqual(dtype_utils.dtype_size("float64"), 64) + + def test_int64_dtype_size(self): + self.assertEqual(dtype_utils.dtype_size("int64"), 64) + + def test_uint8_dtype_size(self): + self.assertEqual(dtype_utils.dtype_size("uint8"), 8) + + def test_bool_dtype_size(self): + self.assertEqual(dtype_utils.dtype_size("bool"), 1) + + def test_invalid_dtype_size(self): + with self.assertRaises(ValueError): + dtype_utils.dtype_size("unknown_dtype") + + +class IsFloatTests(test_case.TestCase): + def test_is_float_float16(self): + self.assertTrue(dtype_utils.is_float("float16")) + + def test_is_float_float32(self): + self.assertTrue(dtype_utils.is_float("float32")) + + def test_is_float_float64(self): + self.assertTrue(dtype_utils.is_float("float64")) + + def test_is_float_int32(self): + self.assertFalse(dtype_utils.is_float("int32")) + + def test_is_float_bool(self): + self.assertFalse(dtype_utils.is_float("bool")) + + def test_is_float_uint8(self): + self.assertFalse(dtype_utils.is_float("uint8")) + + def test_is_float_containing_float(self): + self.assertTrue(dtype_utils.is_float("floating")) + + def test_is_float_empty_string(self): + self.assertFalse(dtype_utils.is_float("")) + + +class CastToCommonDtype(test_case.TestCase): + def test_cast_to_common_dtype_float32_float64(self): + tensor1 = KerasTensor([1, 2, 3], dtype="float32") + tensor2 = KerasTensor([4, 5, 6], dtype="float64") + casted_tensors = dtype_utils.cast_to_common_dtype([tensor1, tensor2]) + for tensor in casted_tensors: + self.assertEqual(tensor.dtype, "float64") + + def test_cast_to_common_dtype_float16_float32_float64(self): + tensor1 = KerasTensor([1, 2, 3], dtype="float16") + tensor2 = KerasTensor([4, 5, 6], dtype="float32") + tensor3 = KerasTensor([7, 8, 9], dtype="float64") + casted_tensors = dtype_utils.cast_to_common_dtype( + [tensor1, tensor2, tensor3] + ) + for tensor in casted_tensors: + self.assertEqual(tensor.dtype, "float64") + + def test_cast_to_common_dtype_float16_int16_float32(self): + tensor1 = KerasTensor([1, 2, 3], dtype="float16") + tensor2 = KerasTensor([4, 5, 6], dtype="int16") + tensor3 = KerasTensor([7, 8, 9], dtype="float32") + casted_tensors = dtype_utils.cast_to_common_dtype( + [tensor1, tensor2, tensor3] + ) + for tensor in casted_tensors: + self.assertEqual(tensor.dtype, "float32") + + def test_cast_to_common_dtype_all_float32(self): + tensor1 = KerasTensor([1, 2, 3], dtype="float32") + tensor2 = KerasTensor([4, 5, 6], dtype="float32") + tensor3 = KerasTensor([7, 8, 9], dtype="float32") + casted_tensors = dtype_utils.cast_to_common_dtype( + [tensor1, tensor2, tensor3] + ) + for tensor in casted_tensors: + self.assertEqual(tensor.dtype, "float32") + + def test_cast_to_common_dtype_float16_bfloat16(self): + tensor1 = KerasTensor([1, 2, 3], dtype="float16") + tensor2 = KerasTensor([4, 5, 6], dtype="bfloat16") + casted_tensors = dtype_utils.cast_to_common_dtype([tensor1, tensor2]) + for tensor in casted_tensors: + self.assertEqual(tensor.dtype, "float16") + + def test_cast_to_common_dtype_float16_uint8(self): + tensor1 = KerasTensor([1, 2, 3], dtype="float16") + tensor2 = KerasTensor([4, 5, 6], dtype="uint8") + casted_tensors = dtype_utils.cast_to_common_dtype([tensor1, tensor2]) + for tensor in casted_tensors: + self.assertEqual(tensor.dtype, "float16") + + def test_cast_to_common_dtype_mixed_types(self): + tensor1 = KerasTensor([1, 2, 3], dtype="float32") + tensor2 = KerasTensor([4, 5, 6], dtype="int32") + tensor3 = KerasTensor([7, 8, 9], dtype="bool") + casted_tensors = dtype_utils.cast_to_common_dtype( + [tensor1, tensor2, tensor3] + ) + for tensor in casted_tensors: + self.assertEqual(tensor.dtype, "float32") + + def test_cast_to_common_dtype_no_float(self): + tensor1 = KerasTensor([1, 2, 3], dtype="int32") + tensor2 = KerasTensor([4, 5, 6], dtype="uint8") + casted_tensors = dtype_utils.cast_to_common_dtype([tensor1, tensor2]) + self.assertEqual(casted_tensors[0].dtype, "int32") + self.assertEqual(casted_tensors[1].dtype, "uint8") + + def test_cast_to_common_dtype_float16_bfloat16_promotion(self): + tensor1 = KerasTensor([4, 5, 6], dtype="bfloat16") + tensor2 = KerasTensor([1, 2, 3], dtype="float16") + casted_tensors = dtype_utils.cast_to_common_dtype([tensor1, tensor2]) + for tensor in casted_tensors: + self.assertEqual(tensor.dtype, "float32") + + # TODO failed AssertionError: 'float16' != 'float32' + # The order of the tensors matters in the current logic + # of the cast_to_common_dtype function + # def test_cast_to_common_dtype_bfloat16_float16_promotion(self): + # tensor1 = KerasTensor([1, 2, 3], dtype="float16") + # tensor2 = KerasTensor([4, 5, 6], dtype="bfloat16") + # casted_tensors = dtype_utils.cast_to_common_dtype([tensor1, tensor2]) + # for tensor in casted_tensors: + # self.assertEqual(tensor.dtype, "float32") diff --git a/keras/src/utils/file_utils.py b/keras/src/utils/file_utils.py new file mode 100644 index 000000000000..55d0a8d7b76e --- /dev/null +++ b/keras/src/utils/file_utils.py @@ -0,0 +1,530 @@ +import hashlib +import os +import re +import shutil +import tarfile +import tempfile +import urllib +import urllib.error +import urllib.parse +import warnings +import zipfile +from urllib.request import urlretrieve + +from keras.src.api_export import keras_export +from keras.src.backend import config +from keras.src.utils import io_utils +from keras.src.utils.module_utils import gfile +from keras.src.utils.progbar import Progbar + + +def path_to_string(path): + """Convert `PathLike` objects to their string representation. + + If given a non-string typed path object, converts it to its string + representation. + + If the object passed to `path` is not among the above, then it is + returned unchanged. This allows e.g. passthrough of file objects + through this function. + + Args: + path: `PathLike` object that represents a path + + Returns: + A string representation of the path argument, if Python support exists. + """ + if isinstance(path, os.PathLike): + return os.fspath(path) + return path + + +def resolve_path(path): + return os.path.realpath(os.path.abspath(path)) + + +def is_path_in_dir(path, base_dir): + return resolve_path(os.path.join(base_dir, path)).startswith(base_dir) + + +def is_link_in_dir(info, base): + tip = resolve_path(os.path.join(base, os.path.dirname(info.name))) + return is_path_in_dir(info.linkname, base_dir=tip) + + +def filter_safe_paths(members): + base_dir = resolve_path(".") + for finfo in members: + valid_path = False + if is_path_in_dir(finfo.name, base_dir): + valid_path = True + yield finfo + elif finfo.issym() or finfo.islnk(): + if is_link_in_dir(finfo, base_dir): + valid_path = True + yield finfo + if not valid_path: + warnings.warn( + "Skipping invalid path during archive extraction: " + f"'{finfo.name}'.", + stacklevel=2, + ) + + +def extract_archive(file_path, path=".", archive_format="auto"): + """Extracts an archive if it matches a support format. + + Supports `.tar`, `.tar.gz`, `.tar.bz`, and `.zip` formats. + + Args: + file_path: Path to the archive file. + path: Where to extract the archive file. + archive_format: Archive format to try for extracting the file. + Options are `"auto"`, `"tar"`, `"zip"`, and `None`. + `"tar"` includes `.tar`, `.tar.gz`, and `.tar.bz` files. + The default `"auto"` uses `["tar", "zip"]`. + `None` or an empty list will return no matches found. + + Returns: + `True` if a match was found and an archive extraction was completed, + `False` otherwise. + """ + if archive_format is None: + return False + if archive_format == "auto": + archive_format = ["tar", "zip"] + if isinstance(archive_format, str): + archive_format = [archive_format] + + file_path = path_to_string(file_path) + path = path_to_string(path) + + for archive_type in archive_format: + if archive_type == "tar": + open_fn = tarfile.open + is_match_fn = tarfile.is_tarfile + elif archive_type == "zip": + open_fn = zipfile.ZipFile + is_match_fn = zipfile.is_zipfile + else: + raise NotImplementedError(archive_type) + + if is_match_fn(file_path): + with open_fn(file_path) as archive: + try: + if zipfile.is_zipfile(file_path): + # Zip archive. + archive.extractall(path) + else: + # Tar archive, perhaps unsafe. Filter paths. + archive.extractall( + path, members=filter_safe_paths(archive) + ) + except (tarfile.TarError, RuntimeError, KeyboardInterrupt): + if os.path.exists(path): + if os.path.isfile(path): + os.remove(path) + else: + shutil.rmtree(path) + raise + return True + return False + + +@keras_export("keras.utils.get_file") +def get_file( + fname=None, + origin=None, + untar=False, + md5_hash=None, + file_hash=None, + cache_subdir="datasets", + hash_algorithm="auto", + extract=False, + archive_format="auto", + cache_dir=None, + force_download=False, +): + """Downloads a file from a URL if it not already in the cache. + + By default the file at the url `origin` is downloaded to the + cache_dir `~/.keras`, placed in the cache_subdir `datasets`, + and given the filename `fname`. The final location of a file + `example.txt` would therefore be `~/.keras/datasets/example.txt`. + Files in `.tar`, `.tar.gz`, `.tar.bz`, and `.zip` formats can + also be extracted. + + Passing a hash will verify the file after download. The command line + programs `shasum` and `sha256sum` can compute the hash. + + Example: + + ```python + path_to_downloaded_file = get_file( + origin="https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz", + extract=True + ) + ``` + + Args: + fname: If the target is a single file, this is your desired + local name for the file. + If `None`, the name of the file at `origin` will be used. + If downloading and extracting a directory archive, + the provided `fname` will be used as extraction directory + name (only if it doesn't have an extension). + origin: Original URL of the file. + untar: Deprecated in favor of `extract` argument. + Boolean, whether the file is a tar archive that should + be extracted. + md5_hash: Deprecated in favor of `file_hash` argument. + md5 hash of the file for file integrity verification. + file_hash: The expected hash string of the file after download. + The sha256 and md5 hash algorithms are both supported. + cache_subdir: Subdirectory under the Keras cache dir where the file is + saved. If an absolute path, e.g. `"/path/to/folder"` is + specified, the file will be saved at that location. + hash_algorithm: Select the hash algorithm to verify the file. + options are `"md5'`, `"sha256'`, and `"auto'`. + The default 'auto' detects the hash algorithm in use. + extract: If `True`, extracts the archive. Only applicable to compressed + archive files like tar or zip. + archive_format: Archive format to try for extracting the file. + Options are `"auto'`, `"tar'`, `"zip'`, and `None`. + `"tar"` includes tar, tar.gz, and tar.bz files. + The default `"auto"` corresponds to `["tar", "zip"]`. + None or an empty list will return no matches found. + cache_dir: Location to store cached files, when None it + defaults ether `$KERAS_HOME` if the `KERAS_HOME` environment + variable is set or `~/.keras/`. + force_download: If `True`, the file will always be re-downloaded + regardless of the cache state. + + Returns: + Path to the downloaded file. + + **⚠️ Warning on malicious downloads ⚠️** + + Downloading something from the Internet carries a risk. + NEVER download a file/archive if you do not trust the source. + We recommend that you specify the `file_hash` argument + (if the hash of the source file is known) to make sure that the file you + are getting is the one you expect. + """ + if origin is None: + raise ValueError( + 'Please specify the "origin" argument (URL of the file ' + "to download)." + ) + + if cache_dir is None: + cache_dir = config.keras_home() + if md5_hash is not None and file_hash is None: + file_hash = md5_hash + hash_algorithm = "md5" + datadir_base = os.path.expanduser(cache_dir) + if not os.access(datadir_base, os.W_OK): + datadir_base = os.path.join( + "/tmp" if os.path.isdir("/tmp") else tempfile.gettempdir(), ".keras" + ) + datadir = os.path.join(datadir_base, cache_subdir) + os.makedirs(datadir, exist_ok=True) + + provided_fname = fname + fname = path_to_string(fname) + + if not fname: + fname = os.path.basename(urllib.parse.urlsplit(origin).path) + if not fname: + raise ValueError( + "Can't parse the file name from the origin provided: " + f"'{origin}'." + "Please specify the `fname` argument." + ) + else: + if os.sep in fname: + raise ValueError( + "Paths are no longer accepted as the `fname` argument. " + "To specify the file's parent directory, use " + f"the `cache_dir` argument. Received: fname={fname}" + ) + + if extract or untar: + if provided_fname: + if "." in fname: + download_target = os.path.join(datadir, fname) + fname = fname[: fname.find(".")] + extraction_dir = os.path.join(datadir, f"{fname}_extracted") + else: + extraction_dir = os.path.join(datadir, fname) + download_target = os.path.join(datadir, f"{fname}_archive") + else: + extraction_dir = os.path.join(datadir, fname) + download_target = os.path.join(datadir, f"{fname}_archive") + else: + download_target = os.path.join(datadir, fname) + + if force_download: + download = True + elif os.path.exists(download_target): + # File found in cache. + download = False + # Verify integrity if a hash was provided. + if file_hash is not None: + if not validate_file( + download_target, file_hash, algorithm=hash_algorithm + ): + io_utils.print_msg( + "A local file was found, but it seems to be " + f"incomplete or outdated because the {hash_algorithm} " + "file hash does not match the original value of " + f"{file_hash} so we will re-download the data." + ) + download = True + else: + download = True + + if download: + io_utils.print_msg(f"Downloading data from {origin}") + + class DLProgbar: + """Manage progress bar state for use in urlretrieve.""" + + def __init__(self): + self.progbar = None + self.finished = False + + def __call__(self, block_num, block_size, total_size): + if total_size == -1: + total_size = None + if not self.progbar: + self.progbar = Progbar(total_size) + current = block_num * block_size + + if total_size is None: + self.progbar.update(current) + else: + if current < total_size: + self.progbar.update(current) + elif not self.finished: + self.progbar.update(self.progbar.target) + self.finished = True + + error_msg = "URL fetch failure on {}: {} -- {}" + try: + try: + urlretrieve(origin, download_target, DLProgbar()) + except urllib.error.HTTPError as e: + raise Exception(error_msg.format(origin, e.code, e.msg)) + except urllib.error.URLError as e: + raise Exception(error_msg.format(origin, e.errno, e.reason)) + except (Exception, KeyboardInterrupt): + if os.path.exists(download_target): + os.remove(download_target) + raise + + # Validate download if succeeded and user provided an expected hash + # Security conscious users would get the hash of the file from a + # separate channel and pass it to this API to prevent MITM / corruption: + if os.path.exists(download_target) and file_hash is not None: + if not validate_file( + download_target, file_hash, algorithm=hash_algorithm + ): + raise ValueError( + "Incomplete or corrupted file detected. " + f"The {hash_algorithm} " + "file hash does not match the provided value " + f"of {file_hash}." + ) + + if extract or untar: + if untar: + archive_format = "tar" + + status = extract_archive( + download_target, extraction_dir, archive_format + ) + if not status: + warnings.warn("Could not extract archive.", stacklevel=2) + return extraction_dir + + return download_target + + +def resolve_hasher(algorithm, file_hash=None): + """Returns hash algorithm as hashlib function.""" + if algorithm == "sha256": + return hashlib.sha256() + + if algorithm == "auto" and file_hash is not None and len(file_hash) == 64: + return hashlib.sha256() + + # This is used only for legacy purposes. + return hashlib.md5() + + +def hash_file(fpath, algorithm="sha256", chunk_size=65535): + """Calculates a file sha256 or md5 hash. + + Example: + + >>> hash_file('/path/to/file.zip') + 'e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855' + + Args: + fpath: Path to the file being validated. + algorithm: Hash algorithm, one of `"auto"`, `"sha256"`, or `"md5"`. + The default `"auto"` detects the hash algorithm in use. + chunk_size: Bytes to read at a time, important for large files. + + Returns: + The file hash. + """ + if isinstance(algorithm, str): + hasher = resolve_hasher(algorithm) + else: + hasher = algorithm + + with open(fpath, "rb") as fpath_file: + for chunk in iter(lambda: fpath_file.read(chunk_size), b""): + hasher.update(chunk) + + return hasher.hexdigest() + + +def validate_file(fpath, file_hash, algorithm="auto", chunk_size=65535): + """Validates a file against a sha256 or md5 hash. + + Args: + fpath: path to the file being validated + file_hash: The expected hash string of the file. + The sha256 and md5 hash algorithms are both supported. + algorithm: Hash algorithm, one of `"auto"`, `"sha256"`, or `"md5"`. + The default `"auto"` detects the hash algorithm in use. + chunk_size: Bytes to read at a time, important for large files. + + Returns: + Boolean, whether the file is valid. + """ + hasher = resolve_hasher(algorithm, file_hash) + + if str(hash_file(fpath, hasher, chunk_size)) == str(file_hash): + return True + else: + return False + + +def is_remote_path(filepath): + """ + Determines if a given filepath indicates a remote location. + + This function checks if the filepath represents a known remote pattern + such as GCS (`/gcs`), CNS (`/cns`), CFS (`/cfs`), HDFS (`/hdfs`), Placer + (`/placer`), TFHub (`/tfhub`), or a URL (`.*://`). + + Args: + filepath (str): The path to be checked. + + Returns: + bool: True if the filepath is a recognized remote path, otherwise False + """ + if re.match( + r"^(/cns|/cfs|/gcs|/hdfs|/readahead|/placer|/tfhub|.*://).*$", + str(filepath), + ): + return True + return False + + +# Below are gfile-replacement utils. + + +def _raise_if_no_gfile(path): + raise ValueError( + "Handling remote paths requires installing TensorFlow " + f"(in order to use gfile). Received path: {path}" + ) + + +def exists(path): + if is_remote_path(path): + if gfile.available: + return gfile.exists(path) + else: + _raise_if_no_gfile(path) + return os.path.exists(path) + + +def File(path, mode="r"): + if is_remote_path(path): + if gfile.available: + return gfile.GFile(path, mode=mode) + else: + _raise_if_no_gfile(path) + return open(path, mode=mode) + + +def join(path, *paths): + if is_remote_path(path): + if gfile.available: + return gfile.join(path, *paths) + else: + _raise_if_no_gfile(path) + return os.path.join(path, *paths) + + +def isdir(path): + if is_remote_path(path): + if gfile.available: + return gfile.isdir(path) + else: + _raise_if_no_gfile(path) + return os.path.isdir(path) + + +def remove(path): + if is_remote_path(path): + if gfile.available: + return gfile.remove(path) + else: + _raise_if_no_gfile(path) + return os.remove(path) + + +def rmtree(path): + if is_remote_path(path): + if gfile.available: + return gfile.rmtree(path) + else: + _raise_if_no_gfile(path) + return shutil.rmtree(path) + + +def listdir(path): + if is_remote_path(path): + if gfile.available: + return gfile.listdir(path) + else: + _raise_if_no_gfile(path) + return os.listdir(path) + + +def copy(src, dst): + if is_remote_path(src) or is_remote_path(dst): + if gfile.available: + return gfile.copy(src, dst, overwrite=True) + else: + _raise_if_no_gfile(f"src={src} dst={dst}") + return shutil.copy(src, dst) + + +def makedirs(path): + if is_remote_path(path): + if gfile.available: + return gfile.makedirs(path) + else: + _raise_if_no_gfile(path) + return os.makedirs(path) + + +"/fo" diff --git a/keras/src/utils/file_utils_test.py b/keras/src/utils/file_utils_test.py new file mode 100644 index 000000000000..da99ed627576 --- /dev/null +++ b/keras/src/utils/file_utils_test.py @@ -0,0 +1,758 @@ +import hashlib +import os +import shutil +import tarfile +import tempfile +import urllib +import urllib.parse +import urllib.request +import zipfile +from unittest.mock import patch + +from keras.src.testing import test_case +from keras.src.utils import file_utils + + +class PathToStringTest(test_case.TestCase): + def test_path_to_string_with_string_path(self): + path = os.path.join(os.path.sep, "path", "to", "file.txt") + string_path = file_utils.path_to_string(path) + self.assertEqual(string_path, path) + + def test_path_to_string_with_PathLike_object(self): + path = os.path.join(os.path.sep, "path", "to", "file.txt") + string_path = file_utils.path_to_string(path) + self.assertEqual(string_path, str(path)) + + def test_path_to_string_with_non_string_typed_path_object(self): + class NonStringTypedPathObject: + def __fspath__(self): + return os.path.join(os.path.sep, "path", "to", "file.txt") + + path = NonStringTypedPathObject() + string_path = file_utils.path_to_string(path) + self.assertEqual( + string_path, os.path.join(os.path.sep, "path", "to", "file.txt") + ) + + def test_path_to_string_with_none_path(self): + string_path = file_utils.path_to_string(None) + self.assertEqual(string_path, None) + + +class ResolvePathTest(test_case.TestCase): + def test_resolve_path_with_absolute_path(self): + path = os.path.join(os.path.sep, "path", "to", "file.txt") + resolved_path = file_utils.resolve_path(path) + self.assertEqual(resolved_path, os.path.realpath(os.path.abspath(path))) + + def test_resolve_path_with_relative_path(self): + path = os.path.join(".", "file.txt") + resolved_path = file_utils.resolve_path(path) + self.assertEqual(resolved_path, os.path.realpath(os.path.abspath(path))) + + +class IsPathInDirTest(test_case.TestCase): + def test_is_path_in_dir_with_absolute_paths(self): + base_dir = os.path.join(os.path.sep, "path", "to", "base_dir") + path = os.path.join(base_dir, "file.txt") + self.assertTrue(file_utils.is_path_in_dir(path, base_dir)) + + +class IsLinkInDirTest(test_case.TestCase): + def setUp(self): + self._cleanup(os.path.join("test_path", "to", "base_dir")) + self._cleanup(os.path.join(".", "base_dir")) + + def _cleanup(self, base_dir): + if os.path.exists(base_dir): + shutil.rmtree(base_dir) + + def test_is_link_in_dir_with_absolute_paths(self): + base_dir = os.path.join("test_path", "to", "base_dir") + link_path = os.path.join(base_dir, "symlink") + target_path = os.path.join(base_dir, "file.txt") + + # Create the base_dir directory if it does not exist. + os.makedirs(base_dir, exist_ok=True) + + # Create the file.txt file. + with open(target_path, "w") as f: + f.write("Hello, world!") + + os.symlink(target_path, link_path) + + # Creating a stat_result-like object with a name attribute + info = os.lstat(link_path) + info = type( + "stat_with_name", + (object,), + { + "name": os.path.basename(link_path), + "linkname": os.readlink(link_path), + }, + ) + + self.assertTrue(file_utils.is_link_in_dir(info, base_dir)) + + def test_is_link_in_dir_with_relative_paths(self): + base_dir = os.path.join(".", "base_dir") + link_path = os.path.join(base_dir, "symlink") + target_path = os.path.join(base_dir, "file.txt") + + # Create the base_dir directory if it does not exist. + os.makedirs(base_dir, exist_ok=True) + + # Create the file.txt file. + with open(target_path, "w") as f: + f.write("Hello, world!") + + os.symlink(target_path, link_path) + + # Creating a stat_result-like object with a name attribute + info = os.lstat(link_path) + info = type( + "stat_with_name", + (object,), + { + "name": os.path.basename(link_path), + "linkname": os.readlink(link_path), + }, + ) + + self.assertTrue(file_utils.is_link_in_dir(info, base_dir)) + + def tearDown(self): + self._cleanup(os.path.join("test_path", "to", "base_dir")) + self._cleanup(os.path.join(".", "base_dir")) + + +class FilterSafePathsTest(test_case.TestCase): + def setUp(self): + self.base_dir = os.path.join(os.getcwd(), "temp_dir") + os.makedirs(self.base_dir, exist_ok=True) + self.tar_path = os.path.join(self.base_dir, "test.tar") + + def tearDown(self): + os.remove(self.tar_path) + shutil.rmtree(self.base_dir) + + def test_member_within_base_dir(self): + """Test a member within the base directory.""" + with tarfile.open(self.tar_path, "w") as tar: + tar.add(__file__, arcname="safe_path.txt") + with tarfile.open(self.tar_path, "r") as tar: + members = list(file_utils.filter_safe_paths(tar.getmembers())) + self.assertEqual(len(members), 1) + self.assertEqual(members[0].name, "safe_path.txt") + + def test_symlink_within_base_dir(self): + """Test a symlink pointing within the base directory.""" + symlink_path = os.path.join(self.base_dir, "symlink.txt") + target_path = os.path.join(self.base_dir, "target.txt") + with open(target_path, "w") as f: + f.write("target") + os.symlink(target_path, symlink_path) + with tarfile.open(self.tar_path, "w") as tar: + tar.add(symlink_path, arcname="symlink.txt") + with tarfile.open(self.tar_path, "r") as tar: + members = list(file_utils.filter_safe_paths(tar.getmembers())) + self.assertEqual(len(members), 1) + self.assertEqual(members[0].name, "symlink.txt") + os.remove(symlink_path) + os.remove(target_path) + + def test_invalid_path_warning(self): + """Test warning for an invalid path during archive extraction.""" + invalid_path = os.path.join(os.getcwd(), "invalid.txt") + with open(invalid_path, "w") as f: + f.write("invalid") + with tarfile.open(self.tar_path, "w") as tar: + tar.add( + invalid_path, arcname="../../invalid.txt" + ) # Path intended to be outside of base dir + with tarfile.open(self.tar_path, "r") as tar: + with patch("warnings.warn") as mock_warn: + _ = list(file_utils.filter_safe_paths(tar.getmembers())) + warning_msg = ( + "Skipping invalid path during archive extraction: " + "'../../invalid.txt'." + ) + mock_warn.assert_called_with(warning_msg, stacklevel=2) + os.remove(invalid_path) + + def test_symbolic_link_in_base_dir(self): + """symbolic link within the base directory is correctly processed.""" + symlink_path = os.path.join(self.base_dir, "symlink.txt") + target_path = os.path.join(self.base_dir, "target.txt") + + # Create a target file and then a symbolic link pointing to it. + with open(target_path, "w") as f: + f.write("target") + os.symlink(target_path, symlink_path) + + # Add the symbolic link to the tar archive. + with tarfile.open(self.tar_path, "w") as tar: + tar.add(symlink_path, arcname="symlink.txt") + + with tarfile.open(self.tar_path, "r") as tar: + members = list(file_utils.filter_safe_paths(tar.getmembers())) + self.assertEqual(len(members), 1) + self.assertEqual(members[0].name, "symlink.txt") + self.assertTrue( + members[0].issym() + ) # Explicitly assert it's a symbolic link. + + os.remove(symlink_path) + os.remove(target_path) + + +class ExtractArchiveTest(test_case.TestCase): + def setUp(self): + """Create temporary directories and files for testing.""" + self.temp_dir = tempfile.mkdtemp() + self.file_content = "Hello, world!" + + # Create sample files to be archived + with open(os.path.join(self.temp_dir, "sample.txt"), "w") as f: + f.write(self.file_content) + + def tearDown(self): + """Clean up temporary directories.""" + shutil.rmtree(self.temp_dir) + + def create_tar(self): + archive_path = os.path.join(self.temp_dir, "sample.tar") + with tarfile.open(archive_path, "w") as archive: + archive.add( + os.path.join(self.temp_dir, "sample.txt"), arcname="sample.txt" + ) + return archive_path + + def create_zip(self): + archive_path = os.path.join(self.temp_dir, "sample.zip") + with zipfile.ZipFile(archive_path, "w") as archive: + archive.write( + os.path.join(self.temp_dir, "sample.txt"), arcname="sample.txt" + ) + return archive_path + + def test_extract_tar(self): + archive_path = self.create_tar() + extract_path = os.path.join(self.temp_dir, "extract_tar") + result = file_utils.extract_archive(archive_path, extract_path, "tar") + self.assertTrue(result) + with open(os.path.join(extract_path, "sample.txt"), "r") as f: + self.assertEqual(f.read(), self.file_content) + + def test_extract_zip(self): + archive_path = self.create_zip() + extract_path = os.path.join(self.temp_dir, "extract_zip") + result = file_utils.extract_archive(archive_path, extract_path, "zip") + self.assertTrue(result) + with open(os.path.join(extract_path, "sample.txt"), "r") as f: + self.assertEqual(f.read(), self.file_content) + + def test_extract_auto(self): + # This will test the 'auto' functionality + tar_archive_path = self.create_tar() + zip_archive_path = self.create_zip() + + extract_tar_path = os.path.join(self.temp_dir, "extract_auto_tar") + extract_zip_path = os.path.join(self.temp_dir, "extract_auto_zip") + + self.assertTrue( + file_utils.extract_archive(tar_archive_path, extract_tar_path) + ) + self.assertTrue( + file_utils.extract_archive(zip_archive_path, extract_zip_path) + ) + + with open(os.path.join(extract_tar_path, "sample.txt"), "r") as f: + self.assertEqual(f.read(), self.file_content) + + with open(os.path.join(extract_zip_path, "sample.txt"), "r") as f: + self.assertEqual(f.read(), self.file_content) + + def test_non_existent_file(self): + extract_path = os.path.join(self.temp_dir, "non_existent") + with self.assertRaises(FileNotFoundError): + file_utils.extract_archive("non_existent.tar", extract_path) + + def test_archive_format_none(self): + archive_path = self.create_tar() + extract_path = os.path.join(self.temp_dir, "none_format") + result = file_utils.extract_archive(archive_path, extract_path, None) + self.assertFalse(result) + + def test_runtime_error_during_extraction(self): + tar_path = self.create_tar() + extract_path = os.path.join(self.temp_dir, "runtime_error_extraction") + + with patch.object( + tarfile.TarFile, "extractall", side_effect=RuntimeError + ): + with self.assertRaises(RuntimeError): + file_utils.extract_archive(tar_path, extract_path, "tar") + self.assertFalse(os.path.exists(extract_path)) + + def test_keyboard_interrupt_during_extraction(self): + tar_path = self.create_tar() + extract_path = os.path.join( + self.temp_dir, "keyboard_interrupt_extraction" + ) + + with patch.object( + tarfile.TarFile, "extractall", side_effect=KeyboardInterrupt + ): + with self.assertRaises(KeyboardInterrupt): + file_utils.extract_archive(tar_path, extract_path, "tar") + self.assertFalse(os.path.exists(extract_path)) + + +class GetFileTest(test_case.TestCase): + def setUp(self): + """Set up temporary directories and sample files.""" + self.temp_dir = self.get_temp_dir() + self.file_path = os.path.join(self.temp_dir, "sample_file.txt") + with open(self.file_path, "w") as f: + f.write("Sample content") + + def test_valid_tar_extraction(self): + """Test valid tar.gz extraction and hash validation.""" + dest_dir = self.get_temp_dir() + orig_dir = self.get_temp_dir() + _, tar_file_path = self._create_tar_file(orig_dir) + self._test_file_extraction_and_validation( + dest_dir, tar_file_path, "tar.gz" + ) + + def test_valid_zip_extraction(self): + """Test valid zip extraction and hash validation.""" + dest_dir = self.get_temp_dir() + orig_dir = self.get_temp_dir() + _, zip_file_path = self._create_zip_file(orig_dir) + self._test_file_extraction_and_validation( + dest_dir, zip_file_path, "zip" + ) + + def test_valid_text_file_download(self): + """Test valid text file download and hash validation.""" + dest_dir = self.get_temp_dir() + orig_dir = self.get_temp_dir() + text_file_path = os.path.join(orig_dir, "test.txt") + with open(text_file_path, "w") as text_file: + text_file.write("Float like a butterfly, sting like a bee.") + self._test_file_extraction_and_validation( + dest_dir, text_file_path, None + ) + + def test_get_file_with_tgz_extension(self): + """Test extraction of file with .tar.gz extension.""" + dest_dir = self.get_temp_dir() + orig_dir = dest_dir + _, tar_file_path = self._create_tar_file(orig_dir) + + origin = urllib.parse.urljoin( + "file://", + urllib.request.pathname2url(os.path.abspath(tar_file_path)), + ) + + path = file_utils.get_file( + "test.txt.tar.gz", origin, untar=True, cache_subdir=dest_dir + ) + self.assertTrue(os.path.exists(path)) + self.assertTrue(os.path.exists(os.path.join(path, "test.txt"))) + + def test_get_file_with_integrity_check(self): + """Test file download with integrity check.""" + orig_dir = self.get_temp_dir() + file_path = os.path.join(orig_dir, "test.txt") + + with open(file_path, "w") as text_file: + text_file.write("Float like a butterfly, sting like a bee.") + + hashval = file_utils.hash_file(file_path) + + origin = urllib.parse.urljoin( + "file://", urllib.request.pathname2url(os.path.abspath(file_path)) + ) + + path = file_utils.get_file("test.txt", origin, file_hash=hashval) + self.assertTrue(os.path.exists(path)) + + def test_cache_invalidation(self): + """Test using a hash to force cache invalidation.""" + cache_dir = self.get_temp_dir() + src_path = os.path.join(self.get_temp_dir(), "test.txt") + with open(src_path, "w") as text_file: + text_file.write("Float like a butterfly, sting like a bee.") + orig_hash = file_utils.hash_file(src_path) + origin = urllib.parse.urljoin( + "file://", urllib.request.pathname2url(os.path.abspath(src_path)) + ) + # Download into the cache. + dest_path = file_utils.get_file( + "test.txt", origin, file_hash=orig_hash, cache_dir=cache_dir + ) + self.assertEqual(orig_hash, file_utils.hash_file(dest_path)) + + with open(src_path, "w") as text_file: + text_file.write("Float like a zeppelin, sting like a jellyfish.") + new_hash = file_utils.hash_file(src_path) + # Without a hash, we should get the cached version. + dest_path = file_utils.get_file("test.txt", origin, cache_dir=cache_dir) + self.assertEqual(orig_hash, file_utils.hash_file(dest_path)) + # Without the new hash, we should re-download. + dest_path = file_utils.get_file( + "test.txt", origin, file_hash=new_hash, cache_dir=cache_dir + ) + self.assertEqual(new_hash, file_utils.hash_file(dest_path)) + + def test_force_download(self): + """Test using a hash to force cache invalidation.""" + cache_dir = self.get_temp_dir() + src_path = os.path.join(self.get_temp_dir(), "test.txt") + with open(src_path, "w") as text_file: + text_file.write("Float like a butterfly, sting like a bee.") + orig_hash = file_utils.hash_file(src_path) + origin = urllib.parse.urljoin( + "file://", urllib.request.pathname2url(os.path.abspath(src_path)) + ) + # Download into the cache. + dest_path = file_utils.get_file("test.txt", origin, cache_dir=cache_dir) + self.assertEqual(orig_hash, file_utils.hash_file(dest_path)) + + with open(src_path, "w") as text_file: + text_file.write("Float like a zeppelin, sting like a jellyfish.") + new_hash = file_utils.hash_file(src_path) + # Get cached version. + dest_path = file_utils.get_file("test.txt", origin, cache_dir=cache_dir) + self.assertEqual(orig_hash, file_utils.hash_file(dest_path)) + # Force download. + dest_path = file_utils.get_file( + "test.txt", origin, force_download=True, cache_dir=cache_dir + ) + self.assertEqual(new_hash, file_utils.hash_file(dest_path)) + + def test_get_file_with_failed_integrity_check(self): + """Test file download with failed integrity check.""" + orig_dir = self.get_temp_dir() + file_path = os.path.join(orig_dir, "test.txt") + + with open(file_path, "w") as text_file: + text_file.write("Float like a butterfly, sting like a bee.") + + hashval = "0" * 64 + + origin = urllib.parse.urljoin( + "file://", urllib.request.pathname2url(os.path.abspath(file_path)) + ) + + with self.assertRaisesRegex( + ValueError, "Incomplete or corrupted file.*" + ): + _ = file_utils.get_file("test.txt", origin, file_hash=hashval) + + def _create_tar_file(self, directory): + """Helper function to create a tar file.""" + text_file_path = os.path.join(directory, "test.txt") + tar_file_path = os.path.join(directory, "test.tar.gz") + with open(text_file_path, "w") as text_file: + text_file.write("Float like a butterfly, sting like a bee.") + + with tarfile.open(tar_file_path, "w:gz") as tar_file: + tar_file.add(text_file_path, arcname="test.txt") + + return text_file_path, tar_file_path + + def _create_zip_file(self, directory): + """Helper function to create a zip file.""" + text_file_path = os.path.join(directory, "test.txt") + zip_file_path = os.path.join(directory, "test.zip") + with open(text_file_path, "w") as text_file: + text_file.write("Float like a butterfly, sting like a bee.") + + with zipfile.ZipFile(zip_file_path, "w") as zip_file: + zip_file.write(text_file_path, arcname="test.txt") + + return text_file_path, zip_file_path + + def _test_file_extraction_and_validation( + self, dest_dir, file_path, archive_type + ): + """Helper function for file extraction and validation.""" + origin = urllib.parse.urljoin( + "file://", + urllib.request.pathname2url(os.path.abspath(file_path)), + ) + + hashval_md5 = file_utils.hash_file(file_path, algorithm="md5") + + extract = bool(archive_type) + + path = file_utils.get_file( + "test", + origin, + md5_hash=hashval_md5, + extract=extract, + cache_subdir=dest_dir, + ) + if extract: + fpath = f"{path}_archive" + else: + fpath = path + + self.assertTrue(os.path.exists(path)) + self.assertTrue(file_utils.validate_file(fpath, hashval_md5)) + if extract: + self.assertTrue(os.path.exists(os.path.join(path, "test.txt"))) + + def test_exists(self): + temp_dir = self.get_temp_dir() + file_path = os.path.join(temp_dir, "test_exists.txt") + + with open(file_path, "w") as f: + f.write("test") + + self.assertTrue(file_utils.exists(file_path)) + self.assertFalse( + file_utils.exists(os.path.join(temp_dir, "non_existent.txt")) + ) + + def test_file_open_read(self): + temp_dir = self.get_temp_dir() + file_path = os.path.join(temp_dir, "test_file.txt") + content = "test content" + + with open(file_path, "w") as f: + f.write(content) + + with file_utils.File(file_path, "r") as f: + self.assertEqual(f.read(), content) + + def test_file_open_write(self): + temp_dir = self.get_temp_dir() + file_path = os.path.join(temp_dir, "test_file_write.txt") + content = "test write content" + + with file_utils.File(file_path, "w") as f: + f.write(content) + + with open(file_path, "r") as f: + self.assertEqual(f.read(), content) + + def test_isdir(self): + temp_dir = self.get_temp_dir() + self.assertTrue(file_utils.isdir(temp_dir)) + + file_path = os.path.join(temp_dir, "test_isdir.txt") + with open(file_path, "w") as f: + f.write("test") + self.assertFalse(file_utils.isdir(file_path)) + + def test_join_simple(self): + self.assertEqual(file_utils.join("/path", "to", "dir"), "/path/to/dir") + + def test_join_single_directory(self): + self.assertEqual(file_utils.join("/path"), "/path") + + def test_listdir(self): + content = file_utils.listdir(self.temp_dir) + self.assertIn("sample_file.txt", content) + + def test_makedirs_and_rmtree(self): + new_dir = os.path.join(self.temp_dir, "new_directory") + file_utils.makedirs(new_dir) + self.assertTrue(os.path.isdir(new_dir)) + file_utils.rmtree(new_dir) + self.assertFalse(os.path.exists(new_dir)) + + def test_copy(self): + dest_path = os.path.join(self.temp_dir, "copy_sample_file.txt") + file_utils.copy(self.file_path, dest_path) + self.assertTrue(os.path.exists(dest_path)) + with open(dest_path, "r") as f: + content = f.read() + self.assertEqual(content, "Sample content") + + def test_remove_sub_directory(self): + parent_dir = os.path.join(self.get_temp_dir(), "parent_directory") + child_dir = os.path.join(parent_dir, "child_directory") + file_utils.makedirs(child_dir) + file_utils.rmtree(parent_dir) + self.assertFalse(os.path.exists(parent_dir)) + self.assertFalse(os.path.exists(child_dir)) + + def test_remove_files_inside_directory(self): + dir_path = os.path.join(self.get_temp_dir(), "test_directory") + file_path = os.path.join(dir_path, "test.txt") + file_utils.makedirs(dir_path) + with open(file_path, "w") as f: + f.write("Test content") + file_utils.rmtree(dir_path) + self.assertFalse(os.path.exists(dir_path)) + self.assertFalse(os.path.exists(file_path)) + + def test_handle_complex_paths(self): + complex_dir = os.path.join(self.get_temp_dir(), "complex dir@#%&!") + file_utils.makedirs(complex_dir) + file_utils.rmtree(complex_dir) + self.assertFalse(os.path.exists(complex_dir)) + + +class HashFileTest(test_case.TestCase): + def setUp(self): + self.test_content = b"Hello, World!" + self.temp_file = tempfile.NamedTemporaryFile(delete=False) + self.temp_file.write(self.test_content) + self.temp_file.close() + + def tearDown(self): + os.remove(self.temp_file.name) + + def test_hash_file_sha256(self): + """Test SHA256 hashing of a file.""" + expected_sha256 = ( + "dffd6021bb2bd5b0af676290809ec3a53191dd81c7f70a4b28688a362182986f" + ) + calculated_sha256 = file_utils.hash_file( + self.temp_file.name, algorithm="sha256" + ) + self.assertEqual(expected_sha256, calculated_sha256) + + def test_hash_file_md5(self): + """Test MD5 hashing of a file.""" + expected_md5 = "65a8e27d8879283831b664bd8b7f0ad4" + calculated_md5 = file_utils.hash_file( + self.temp_file.name, algorithm="md5" + ) + self.assertEqual(expected_md5, calculated_md5) + + +class TestValidateFile(test_case.TestCase): + def setUp(self): + self.tmp_file = tempfile.NamedTemporaryFile(delete=False) + self.tmp_file.write(b"Hello, World!") + self.tmp_file.close() + + self.sha256_hash = ( + "dffd6021bb2bd5b0af676290809ec3a53191dd81c7f70a4b28688a362182986f" + ) + self.md5_hash = "65a8e27d8879283831b664bd8b7f0ad4" + + def test_validate_file_sha256(self): + """Validate SHA256 hash of a file.""" + self.assertTrue( + file_utils.validate_file( + self.tmp_file.name, self.sha256_hash, "sha256" + ) + ) + + def test_validate_file_md5(self): + """Validate MD5 hash of a file.""" + self.assertTrue( + file_utils.validate_file(self.tmp_file.name, self.md5_hash, "md5") + ) + + def test_validate_file_auto_sha256(self): + """Auto-detect and validate SHA256 hash.""" + self.assertTrue( + file_utils.validate_file( + self.tmp_file.name, self.sha256_hash, "auto" + ) + ) + + def test_validate_file_auto_md5(self): + """Auto-detect and validate MD5 hash.""" + self.assertTrue( + file_utils.validate_file(self.tmp_file.name, self.md5_hash, "auto") + ) + + def test_validate_file_wrong_hash(self): + """Test validation with incorrect hash.""" + wrong_hash = "deadbeef" * 8 + self.assertFalse( + file_utils.validate_file(self.tmp_file.name, wrong_hash, "sha256") + ) + + def tearDown(self): + os.remove(self.tmp_file.name) + + +class ResolveHasherTest(test_case.TestCase): + def test_resolve_hasher_sha256(self): + """Test resolving hasher for sha256 algorithm.""" + hasher = file_utils.resolve_hasher("sha256") + self.assertIsInstance(hasher, type(hashlib.sha256())) + + def test_resolve_hasher_auto_sha256(self): + """Auto-detect and resolve hasher for sha256.""" + hasher = file_utils.resolve_hasher("auto", file_hash="a" * 64) + self.assertIsInstance(hasher, type(hashlib.sha256())) + + def test_resolve_hasher_auto_md5(self): + """Auto-detect and resolve hasher for md5.""" + hasher = file_utils.resolve_hasher("auto", file_hash="a" * 32) + self.assertIsInstance(hasher, type(hashlib.md5())) + + def test_resolve_hasher_default(self): + """Resolve hasher with a random algorithm value.""" + hasher = file_utils.resolve_hasher("random_value") + self.assertIsInstance(hasher, type(hashlib.md5())) + + +class IsRemotePathTest(test_case.TestCase): + def test_gcs_remote_path(self): + self.assertTrue(file_utils.is_remote_path("/gcs/some/path/to/file.txt")) + self.assertTrue(file_utils.is_remote_path("/gcs/another/directory/")) + self.assertTrue(file_utils.is_remote_path("gcs://bucket/some/file.txt")) + + def test_hdfs_remote_path(self): + self.assertTrue(file_utils.is_remote_path("hdfs://some/path/on/hdfs")) + self.assertTrue(file_utils.is_remote_path("/hdfs/some/local/path")) + + def test_cns_remote_path(self): + self.assertTrue(file_utils.is_remote_path("/cns/some/path")) + + def test_placer_remote_path(self): + self.assertTrue( + file_utils.is_remote_path("/placer/prod/home/some/path") + ) + self.assertTrue( + file_utils.is_remote_path("/placer/test/home/some/path") + ) + self.assertTrue( + file_utils.is_remote_path("/placer/prod/scratch/home/some/path") + ) + + def test_tfhub_remote_path(self): + self.assertTrue(file_utils.is_remote_path("/tfhub/some/path")) + + def test_cfs_remote_path(self): + self.assertTrue(file_utils.is_remote_path("/cfs/some/path")) + + def test_readahead_remote_path(self): + self.assertTrue(file_utils.is_remote_path("/readahead/some/path")) + + def test_non_remote_paths(self): + self.assertFalse(file_utils.is_remote_path("/local/path/to/file.txt")) + self.assertFalse( + file_utils.is_remote_path("C:\\local\\path\\on\\windows\\file.txt") + ) + self.assertFalse(file_utils.is_remote_path("~/relative/path/")) + self.assertFalse(file_utils.is_remote_path("./another/relative/path")) + self.assertFalse(file_utils.is_remote_path("/local/path")) + self.assertFalse(file_utils.is_remote_path("./relative/path")) + self.assertFalse(file_utils.is_remote_path("~/relative/path")) + + +class TestRaiseIfNoGFile(test_case.TestCase): + def test_raise_if_no_gfile_raises_correct_message(self): + path = "gs://bucket/some/file.txt" + expected_error_msg = ( + "Handling remote paths requires installing TensorFlow " + f".*Received path: {path}" + ) + with self.assertRaisesRegex(ValueError, expected_error_msg): + file_utils._raise_if_no_gfile(path) diff --git a/keras/src/utils/grain_utils.py b/keras/src/utils/grain_utils.py new file mode 100644 index 000000000000..f0a562505dd6 --- /dev/null +++ b/keras/src/utils/grain_utils.py @@ -0,0 +1,33 @@ +from keras.src import backend +from keras.src import tree + + +def make_batch(values): + from keras.src import ops + + if not values: + raise ValueError("Cannot batch 0 values. Please file a bug.") + + with backend.device_scope("cpu"): + return tree.map_structure(lambda *xs: ops.stack(xs), *values) + + +def make_string_batch(values): + from keras.src import ops + + if not values: + raise ValueError("Cannot batch 0 values. Please file a bug.") + + def batch_fn(*xs): + if isinstance(xs[0], str): + if backend.backend() == "tensorflow": + import tensorflow as tf + + xs = [tf.convert_to_tensor(x, dtype=tf.string) for x in xs] + xs = tf.stack(xs) + return xs + else: + return ops.stack(xs) + + with backend.device_scope("cpu"): + return tree.map_structure(batch_fn, *values) diff --git a/keras/src/utils/image_dataset_utils.py b/keras/src/utils/image_dataset_utils.py new file mode 100755 index 000000000000..a9fe50050187 --- /dev/null +++ b/keras/src/utils/image_dataset_utils.py @@ -0,0 +1,681 @@ +import io +import pathlib + +import numpy as np + +from keras.src.api_export import keras_export +from keras.src.backend.config import standardize_data_format +from keras.src.utils import dataset_utils +from keras.src.utils import image_utils +from keras.src.utils.grain_utils import make_batch +from keras.src.utils.module_utils import grain +from keras.src.utils.module_utils import tensorflow as tf + +try: + from PIL import Image as pil_image + + try: + pil_image_resampling = pil_image.Resampling + except AttributeError: + pil_image_resampling = pil_image +except ImportError: + pil_image = None + pil_image_resampling = None + +ALLOWLIST_FORMATS = (".bmp", ".gif", ".jpeg", ".jpg", ".png") + + +@keras_export( + [ + "keras.utils.image_dataset_from_directory", + "keras.preprocessing.image_dataset_from_directory", + ] +) +def image_dataset_from_directory( + directory, + labels="inferred", + label_mode="int", + class_names=None, + color_mode="rgb", + batch_size=32, + image_size=(256, 256), + shuffle=True, + seed=None, + validation_split=None, + subset=None, + interpolation="bilinear", + follow_links=False, + crop_to_aspect_ratio=False, + pad_to_aspect_ratio=False, + data_format=None, + format="tf", + verbose=True, +): + """Generates a dataset from image files in a directory. + + If your directory structure is: + + ``` + main_directory/ + ...class_a/ + ......a_image_1.jpg + ......a_image_2.jpg + ...class_b/ + ......b_image_1.jpg + ......b_image_2.jpg + ``` + + Then calling `image_dataset_from_directory(main_directory, + labels='inferred')` will return a dataset that yields batches of + images from the subdirectories `class_a` and `class_b`, together with labels + 0 and 1 (0 corresponding to `class_a` and 1 corresponding to `class_b`). + + Supported image formats: `.jpeg`, `.jpg`, `.png`, `.bmp`, `.gif`. + Animated gifs are truncated to the first frame. + + By default, this function will return a `tf.data.Dataset` object. You can + set `format="grain"` to return a `grain.IterDataset` object instead, which + removes the TensorFlow dependency. + + Args: + directory: Directory where the data is located. + If `labels` is `"inferred"`, it should contain + subdirectories, each containing images for a class. + Otherwise, the directory structure is ignored. + labels: Either `"inferred"` + (labels are generated from the directory structure), + `None` (no labels), + or a list/tuple of integer labels of the same size as the number of + image files found in the directory. Labels should be sorted + according to the alphanumeric order of the image file paths + (obtained via `os.walk(directory)` in Python). + label_mode: String describing the encoding of `labels`. Options are: + - `"int"`: means that the labels are encoded as integers + (e.g. for `sparse_categorical_crossentropy` loss). + - `"categorical"` means that the labels are + encoded as a categorical vector + (e.g. for `categorical_crossentropy` loss). + - `"binary"` means that the labels (there can be only 2) + are encoded as `float32` scalars with values 0 or 1 + (e.g. for `binary_crossentropy`). + - `None` (no labels). + class_names: Only valid if `labels` is `"inferred"`. + This is the explicit list of class names + (must match names of subdirectories). Used to control the order + of the classes (otherwise alphanumerical order is used). + color_mode: One of `"grayscale"`, `"rgb"`, `"rgba"`. + Whether the images will be converted to + have 1, 3, or 4 channels. Defaults to `"rgb"`. + batch_size: Size of the batches of data. Defaults to 32. + If `None`, the data will not be batched + (the dataset will yield individual samples). + image_size: Size to resize images to after they are read from disk, + specified as `(height, width)`. + Since the pipeline processes batches of images that must all have + the same size, this must be provided. Defaults to `(256, 256)`. + shuffle: Whether to shuffle the data. Defaults to `True`. + If set to `False`, sorts the data in alphanumeric order. + seed: Optional random seed for shuffling and transformations. + validation_split: Optional float between 0 and 1, + fraction of data to reserve for validation. + subset: Subset of the data to return. + One of `"training"`, `"validation"`, or `"both"`. + Only used if `validation_split` is set. + When `subset="both"`, the utility returns a tuple of two datasets + (the training and validation datasets respectively). + interpolation: String, the interpolation method used when + resizing images. + Supports `"bilinear"`, `"nearest"`, `"bicubic"`, `"area"`, + `"lanczos3"`, `"lanczos5"`, `"gaussian"`, `"mitchellcubic"`. + Defaults to `"bilinear"`. + follow_links: Whether to visit subdirectories pointed to by symlinks. + Defaults to `False`. + crop_to_aspect_ratio: If `True`, resize the images without aspect + ratio distortion. When the original aspect ratio differs from the + target aspect ratio, the output image will be cropped so as to + return the largest possible window in the image + (of size `image_size`) that matches the target aspect ratio. By + default (`crop_to_aspect_ratio=False`), aspect ratio may not be + preserved. + pad_to_aspect_ratio: If `True`, resize the images without aspect + ratio distortion. When the original aspect ratio differs from the + target aspect ratio, the output image will be padded so as to + return the largest possible window in the image + (of size `image_size`) that matches the target aspect ratio. By + default (`pad_to_aspect_ratio=False`), aspect ratio may not be + preserved. + data_format: If None uses keras.config.image_data_format() + otherwise either 'channel_last' or 'channel_first'. + format: The format of the return object. Defaults to `"tf"`. Available + options are: + - `"tf"`: returns a `tf.data.Dataset` object. Requires + TensorFlow to be installed. + - `"grain"`: returns a `grain.IterDataset` object. Requires + Grain to be installed. + verbose: Whether to display number information on classes and + number of files found. Defaults to `True`. + + Returns: + + A `tf.data.Dataset` (`format="tf"`) or `grain.IterDataset` + (`format="grain"`) object. + + - If `label_mode` is `None`, it yields `float32` tensors of shape + `(batch_size, image_size[0], image_size[1], num_channels)`, + encoding images (see below for rules regarding `num_channels`). + - Otherwise, it yields a tuple `(images, labels)`, where `images` has + shape `(batch_size, image_size[0], image_size[1], num_channels)`, + and `labels` follows the format described below. + + Rules regarding labels format: + + - if `label_mode` is `"int"`, the labels are an `int32` tensor of shape + `(batch_size,)`. + - if `label_mode` is `"binary"`, the labels are a `float32` tensor of + 1s and 0s of shape `(batch_size, 1)`. + - if `label_mode` is `"categorical"`, the labels are a `float32` tensor + of shape `(batch_size, num_classes)`, representing a one-hot + encoding of the class index. + + Rules regarding number of channels in the yielded images: + + - if `color_mode` is `"grayscale"`, + there's 1 channel in the image tensors. + - if `color_mode` is `"rgb"`, + there are 3 channels in the image tensors. + - if `color_mode` is `"rgba"`, + there are 4 channels in the image tensors. + """ + + if labels not in ("inferred", None): + if not isinstance(labels, (list, tuple)): + raise ValueError( + "`labels` argument should be a list/tuple of integer labels, " + "of the same size as the number of image files in the target " + "directory. If you wish to infer the labels from the " + "subdirectory " + 'names in the target directory, pass `labels="inferred"`. ' + "If you wish to get a dataset that only contains images " + f"(no labels), pass `labels=None`. Received: labels={labels}" + ) + if class_names: + raise ValueError( + "You can only pass `class_names` if " + f'`labels="inferred"`. Received: labels={labels}, and ' + f"class_names={class_names}" + ) + if label_mode not in {"int", "categorical", "binary", None}: + raise ValueError( + '`label_mode` argument must be one of "int", ' + '"categorical", "binary", ' + f"or None. Received: label_mode={label_mode}" + ) + if labels is None or label_mode is None: + labels = None + label_mode = None + if color_mode == "rgb": + num_channels = 3 + elif color_mode == "rgba": + num_channels = 4 + elif color_mode == "grayscale": + num_channels = 1 + else: + raise ValueError( + '`color_mode` must be one of {"rgb", "rgba", "grayscale"}. ' + f"Received: color_mode={color_mode}" + ) + + if isinstance(image_size, int): + image_size = (image_size, image_size) + elif not isinstance(image_size, (list, tuple)) or not len(image_size) == 2: + raise ValueError( + "Invalid `image_size` value. Expected a tuple of 2 integers. " + f"Received: image_size={image_size}" + ) + + interpolation = interpolation.lower() + supported_interpolations = ( + "bilinear", + "nearest", + "bicubic", + "area", + "lanczos3", + "lanczos5", + "gaussian", + "mitchellcubic", + ) + if interpolation not in supported_interpolations: + raise ValueError( + "Argument `interpolation` should be one of " + f"{supported_interpolations}. " + f"Received: interpolation={interpolation}" + ) + if format not in ("tf", "grain"): + raise ValueError( + '`format` should be either "tf" or "grain". ' + f"Received: format={format}" + ) + + dataset_utils.check_validation_split_arg( + validation_split, subset, shuffle, seed + ) + + if seed is None: + seed = np.random.randint(1e6) + image_paths, labels, class_names = dataset_utils.index_directory( + directory, + labels, + formats=ALLOWLIST_FORMATS, + class_names=class_names, + shuffle=shuffle, + seed=seed, + follow_links=follow_links, + verbose=verbose, + ) + + if label_mode == "binary" and len(class_names) != 2: + raise ValueError( + 'When passing `label_mode="binary"`, there must be exactly 2 ' + f"class_names. Received: class_names={class_names}" + ) + + data_format = standardize_data_format(data_format=data_format) + if batch_size is not None: + shuffle_buffer_size = batch_size * 8 + else: + shuffle_buffer_size = 1024 + + if subset == "both": + ( + image_paths_train, + labels_train, + ) = dataset_utils.get_training_or_validation_split( + image_paths, labels, validation_split, "training" + ) + ( + image_paths_val, + labels_val, + ) = dataset_utils.get_training_or_validation_split( + image_paths, labels, validation_split, "validation" + ) + if not image_paths_train: + raise ValueError( + f"No training images found in directory {directory}. " + f"Allowed formats: {ALLOWLIST_FORMATS}" + ) + if not image_paths_val: + raise ValueError( + f"No validation images found in directory {directory}. " + f"Allowed formats: {ALLOWLIST_FORMATS}" + ) + train_dataset = paths_and_labels_to_dataset( + image_paths=image_paths_train, + image_size=image_size, + num_channels=num_channels, + labels=labels_train, + label_mode=label_mode, + num_classes=len(class_names) if class_names else 0, + interpolation=interpolation, + crop_to_aspect_ratio=crop_to_aspect_ratio, + pad_to_aspect_ratio=pad_to_aspect_ratio, + data_format=data_format, + shuffle=shuffle, + shuffle_buffer_size=shuffle_buffer_size, + seed=seed, + format=format, + ) + + val_dataset = paths_and_labels_to_dataset( + image_paths=image_paths_val, + image_size=image_size, + num_channels=num_channels, + labels=labels_val, + label_mode=label_mode, + num_classes=len(class_names) if class_names else 0, + interpolation=interpolation, + crop_to_aspect_ratio=crop_to_aspect_ratio, + pad_to_aspect_ratio=pad_to_aspect_ratio, + data_format=data_format, + shuffle=False, + format=format, + ) + + if format == "tf": + if batch_size is not None: + train_dataset = train_dataset.batch(batch_size) + val_dataset = val_dataset.batch(batch_size) + train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE) + val_dataset = val_dataset.prefetch(tf.data.AUTOTUNE) + else: + train_dataset = train_dataset.to_iter_dataset() + val_dataset = val_dataset.to_iter_dataset() + if batch_size is not None: + train_dataset = train_dataset.batch( + batch_size, batch_fn=make_batch + ) + val_dataset = val_dataset.batch(batch_size, batch_fn=make_batch) + + # Users may need to reference `class_names`. + train_dataset.class_names = class_names + val_dataset.class_names = class_names + + # Include file paths for images as attribute. + train_dataset.file_paths = image_paths_train + val_dataset.file_paths = image_paths_val + + dataset = [train_dataset, val_dataset] + else: + image_paths, labels = dataset_utils.get_training_or_validation_split( + image_paths, labels, validation_split, subset + ) + if not image_paths: + raise ValueError( + f"No images found in directory {directory}. " + f"Allowed formats: {ALLOWLIST_FORMATS}" + ) + + dataset = paths_and_labels_to_dataset( + image_paths=image_paths, + image_size=image_size, + num_channels=num_channels, + labels=labels, + label_mode=label_mode, + num_classes=len(class_names) if class_names else 0, + interpolation=interpolation, + crop_to_aspect_ratio=crop_to_aspect_ratio, + pad_to_aspect_ratio=pad_to_aspect_ratio, + data_format=data_format, + shuffle=shuffle, + shuffle_buffer_size=shuffle_buffer_size, + seed=seed, + format=format, + ) + + if format == "tf": + if batch_size is not None: + dataset = dataset.batch(batch_size) + dataset = dataset.prefetch(tf.data.AUTOTUNE) + else: + dataset = dataset.to_iter_dataset() + if batch_size is not None: + dataset = dataset.batch(batch_size, batch_fn=make_batch) + + # Users may need to reference `class_names`. + dataset.class_names = class_names + + # Include file paths for images as attribute. + dataset.file_paths = image_paths + + return dataset + + +def paths_and_labels_to_dataset( + image_paths, + image_size, + num_channels, + labels, + label_mode, + num_classes, + interpolation, + data_format, + crop_to_aspect_ratio=False, + pad_to_aspect_ratio=False, + shuffle=False, + shuffle_buffer_size=None, + seed=None, + format="tf", +): + """Constructs a dataset of images and labels.""" + if format == "tf": + return _paths_and_labels_to_dataset_tf( + image_paths=image_paths, + image_size=image_size, + num_channels=num_channels, + labels=labels, + label_mode=label_mode, + num_classes=num_classes, + interpolation=interpolation, + data_format=data_format, + crop_to_aspect_ratio=crop_to_aspect_ratio, + pad_to_aspect_ratio=pad_to_aspect_ratio, + shuffle=shuffle, + shuffle_buffer_size=shuffle_buffer_size, + seed=seed, + ) + elif format == "grain": + return _paths_and_labels_to_dataset_grain( + image_paths=image_paths, + image_size=image_size, + num_channels=num_channels, + labels=labels, + label_mode=label_mode, + num_classes=num_classes, + interpolation=interpolation, + data_format=data_format, + crop_to_aspect_ratio=crop_to_aspect_ratio, + pad_to_aspect_ratio=pad_to_aspect_ratio, + shuffle=shuffle, + seed=seed, + ) + else: + raise ValueError( + '`format` should be either "tf" or "grain". ' + f"Received: format={format}" + ) + + +def _paths_and_labels_to_dataset_tf( + image_paths, + image_size, + num_channels, + labels, + label_mode, + num_classes, + interpolation, + data_format, + crop_to_aspect_ratio=False, + pad_to_aspect_ratio=False, + shuffle=False, + shuffle_buffer_size=None, + seed=None, +): + """Constructs a dataset of images and labels.""" + path_ds = tf.data.Dataset.from_tensor_slices(image_paths) + if label_mode: + label_ds = dataset_utils.labels_to_dataset_tf( + labels, label_mode, num_classes + ) + ds = tf.data.Dataset.zip((path_ds, label_ds)) + else: + ds = path_ds + + if shuffle: + ds = ds.shuffle(buffer_size=shuffle_buffer_size or 1024, seed=seed) + + args = ( + image_size, + num_channels, + interpolation, + data_format, + crop_to_aspect_ratio, + pad_to_aspect_ratio, + ) + if label_mode: + ds = ds.map( + lambda x, y: (_load_image_tf(x, *args), y), + num_parallel_calls=tf.data.AUTOTUNE, + ) + else: + ds = ds.map( + lambda x: _load_image_tf(x, *args), + num_parallel_calls=tf.data.AUTOTUNE, + ) + return ds + + +def _load_image_tf( + path, + image_size, + num_channels, + interpolation, + data_format, + crop_to_aspect_ratio=False, + pad_to_aspect_ratio=False, +): + """Load an image from a path and resize it.""" + img = tf.io.read_file(path) + img = tf.image.decode_image( + img, channels=num_channels, expand_animations=False + ) + + if pad_to_aspect_ratio and crop_to_aspect_ratio: + raise ValueError( + "Only one of `pad_to_aspect_ratio`, `crop_to_aspect_ratio`" + " can be set to `True`." + ) + + if crop_to_aspect_ratio: + from keras.src.backend import tensorflow as tf_backend + + if data_format == "channels_first": + img = tf.transpose(img, (2, 0, 1)) + img = image_utils.smart_resize( + img, + image_size, + interpolation=interpolation, + data_format=data_format, + backend_module=tf_backend, + ) + elif pad_to_aspect_ratio: + img = tf.image.resize_with_pad( + img, image_size[0], image_size[1], method=interpolation + ) + if data_format == "channels_first": + img = tf.transpose(img, (2, 0, 1)) + else: + img = tf.image.resize(img, image_size, method=interpolation) + if data_format == "channels_first": + img = tf.transpose(img, (2, 0, 1)) + + if data_format == "channels_last": + img.set_shape((image_size[0], image_size[1], num_channels)) + else: + img.set_shape((num_channels, image_size[0], image_size[1])) + return img + + +def _paths_and_labels_to_dataset_grain( + image_paths, + image_size, + num_channels, + labels, + label_mode, + num_classes, + interpolation, + data_format, + crop_to_aspect_ratio=False, + pad_to_aspect_ratio=False, + shuffle=False, + seed=None, +): + """Constructs a dataset of images and labels.""" + path_ds = grain.MapDataset.source(image_paths) + if label_mode: + label_ds = dataset_utils.labels_to_dataset_grain( + labels, label_mode, num_classes + ) + ds = grain.experimental.ZipMapDataset([path_ds, label_ds]) + else: + ds = path_ds + + if shuffle: + ds = ds.shuffle(seed=seed) + + args = ( + image_size, + num_channels, + interpolation, + data_format, + crop_to_aspect_ratio, + pad_to_aspect_ratio, + ) + if label_mode: + ds = ds.map(lambda data: (_load_image_grain(data[0], *args), data[1])) + else: + ds = ds.map(lambda x: _load_image_grain(x, *args)) + + return ds + + +def _load_image_grain( + path, + image_size, + num_channels, + interpolation, + data_format, + crop_to_aspect_ratio=False, + pad_to_aspect_ratio=False, +): + """Load an image from a path and resize it.""" + from keras.src import backend + from keras.src import ops + + if pil_image is None: + raise ImportError( + "Could not import PIL.Image. The use of `load_img` requires PIL." + ) + if pad_to_aspect_ratio and crop_to_aspect_ratio: + raise ValueError( + "Only one of `pad_to_aspect_ratio`, `crop_to_aspect_ratio`" + " can be set to `True`." + ) + + if isinstance(path, io.BytesIO): + img = pil_image.open(path) + elif isinstance(path, (pathlib.Path, bytes, str)): + if isinstance(path, pathlib.Path): + path = str(path.resolve()) + img = pil_image.open(path) + else: + raise TypeError( + f"path should be path-like or io.BytesIO, not {type(path)}" + ) + if num_channels == 1: + # if image is not already an 8-bit, 16-bit or 32-bit grayscale image + # convert it to an 8-bit grayscale image. + if img.mode not in ("L", "I;16", "I"): + img = img.convert("L") + elif num_channels == 4: + if img.mode != "RGBA": + img = img.convert("RGBA") + elif num_channels == 3: + if img.mode != "RGB": + img = img.convert("RGB") + else: + raise ValueError( + "num_channels must be 1, 3 or 4. " + f"Received: num_channels={num_channels}" + ) + + with backend.device_scope("cpu"): + img = ops.convert_to_tensor(np.array(img), dtype="float32") + if len(img.shape) == 2: + # If the image is grayscale, expand dims to add channel axis. + # The reason is that `ops.image.resize` expects 3D or 4D tensors. + img = ops.expand_dims(img, axis=-1) + if data_format == "channels_first": + img = ops.transpose(img, (2, 0, 1)) + img = ops.image.resize( + img, + size=image_size, + interpolation=interpolation, + crop_to_aspect_ratio=crop_to_aspect_ratio, + pad_to_aspect_ratio=pad_to_aspect_ratio, + data_format=data_format, + ) + if backend.backend() == "tensorflow": + if data_format == "channels_last": + img.set_shape((image_size[0], image_size[1], num_channels)) + else: + img.set_shape((num_channels, image_size[0], image_size[1])) + return img diff --git a/keras/src/utils/image_dataset_utils_test.py b/keras/src/utils/image_dataset_utils_test.py new file mode 100644 index 000000000000..31251228b86f --- /dev/null +++ b/keras/src/utils/image_dataset_utils_test.py @@ -0,0 +1,655 @@ +import os + +import numpy as np +from absl.testing import parameterized + +from keras.src import backend +from keras.src import ops +from keras.src import testing +from keras.src.utils import image_dataset_utils +from keras.src.utils import image_utils +from keras.src.utils.module_utils import tensorflow as tf + + +class ImageDatasetFromDirectoryTest(testing.TestCase): + def _get_images(self, count=16, color_mode="rgb"): + width = height = 24 + imgs = [] + for _ in range(count): + if color_mode == "grayscale": + img = np.random.randint(0, 256, size=(height, width, 1)) + elif color_mode == "rgba": + img = np.random.randint(0, 256, size=(height, width, 4)) + else: + img = np.random.randint(0, 256, size=(height, width, 3)) + if backend.config.image_data_format() == "channels_first": + img = np.transpose(img, (2, 0, 1)) + img = image_utils.array_to_img(img) + imgs.append(img) + return imgs + + def _prepare_directory( + self, + num_classes=2, + nested_dirs=False, + color_mode="rgb", + count=16, + ): + # Generate paths to class subdirectories + temp_dir = self.get_temp_dir() + paths = [] + for class_index in range(num_classes): + class_directory = f"class_{class_index}" + if nested_dirs: + class_paths = [ + class_directory, + os.path.join(class_directory, "subfolder_1"), + os.path.join(class_directory, "subfolder_2"), + os.path.join( + class_directory, "subfolder_1", "sub-subfolder" + ), + ] + else: + class_paths = [class_directory] + for path in class_paths: + os.mkdir(os.path.join(temp_dir, path)) + paths += class_paths + + # Save images to the paths + i = 0 + for img in self._get_images(color_mode=color_mode, count=count): + path = paths[i % len(paths)] + if color_mode == "rgb": + ext = "jpg" + else: + ext = "png" + filename = os.path.join(path, f"image_{i}.{ext}") + img.save(os.path.join(temp_dir, filename)) + i += 1 + return temp_dir + + @parameterized.named_parameters( + ("tf", "tf"), + ("grain", "grain"), + ) + def test_image_dataset_from_directory_no_labels(self, format): + # Test retrieving images without labels from a directory and its + # subdirs. + + # Save a few extra images in the parent directory. + directory = self._prepare_directory(count=7, num_classes=2) + for i, img in enumerate(self._get_images(3)): + filename = f"image_{i}.jpg" + img.save(os.path.join(directory, filename)) + + dataset = image_dataset_utils.image_dataset_from_directory( + directory, + batch_size=5, + image_size=(18, 18), + labels=None, + format=format, + ) + if backend.config.image_data_format() == "channels_last": + output_shape = [5, 18, 18, 3] + else: + output_shape = [5, 3, 18, 18] + self.assertEqual(dataset.class_names, None) + batch = next(iter(dataset)) + # We return plain images + self.assertEqual(list(batch.shape), output_shape) + self.assertDType(batch, "float32") + # Count samples + batch_count = 0 + sample_count = 0 + for batch in dataset: + batch_count += 1 + sample_count += batch.shape[0] + self.assertEqual(batch_count, 2) + self.assertEqual(sample_count, 10) + + @parameterized.named_parameters( + ("tf", "tf"), + ("grain", "grain"), + ) + def test_image_dataset_from_directory_binary(self, format): + directory = self._prepare_directory(num_classes=2) + dataset = image_dataset_utils.image_dataset_from_directory( + directory, + batch_size=8, + image_size=(18, 18), + label_mode="int", + format=format, + ) + if backend.config.image_data_format() == "channels_last": + output_shape = [8, 18, 18, 3] + else: + output_shape = [8, 3, 18, 18] + batch = next(iter(dataset)) + self.assertLen(batch, 2) + self.assertEqual(list(batch[0].shape), output_shape) + self.assertDType(batch[0], "float32") + self.assertEqual(list(batch[1].shape), [8]) + self.assertDType(batch[1], "int32") + + dataset = image_dataset_utils.image_dataset_from_directory( + directory, + batch_size=8, + image_size=(18, 18), + label_mode="binary", + format=format, + ) + batch = next(iter(dataset)) + self.assertLen(batch, 2) + self.assertEqual(list(batch[0].shape), output_shape) + self.assertDType(batch[0], "float32") + self.assertEqual(list(batch[1].shape), [8, 1]) + self.assertDType(batch[1], "float32") + + dataset = image_dataset_utils.image_dataset_from_directory( + directory, + batch_size=8, + image_size=(18, 18), + label_mode="categorical", + format=format, + ) + batch = next(iter(dataset)) + self.assertLen(batch, 2) + self.assertEqual(list(batch[0].shape), output_shape) + self.assertDType(batch[0], "float32") + self.assertEqual(list(batch[1].shape), [8, 2]) + self.assertDType(batch[1], "float32") + + def test_static_shape_in_graph(self): + directory = self._prepare_directory(num_classes=2) + dataset = image_dataset_utils.image_dataset_from_directory( + directory, batch_size=8, image_size=(18, 18), label_mode="int" + ) + test_case = self + if backend.config.image_data_format() == "channels_last": + output_shape = [None, 18, 18, 3] + else: + output_shape = [None, 3, 18, 18] + + @tf.function + def symbolic_fn(ds): + for x, _ in ds.take(1): + test_case.assertListEqual(x.shape.as_list(), output_shape) + + symbolic_fn(dataset) + + @parameterized.named_parameters( + ("tf", "tf"), + ("grain", "grain"), + ) + def test_sample_count(self, format): + directory = self._prepare_directory(num_classes=4, count=15) + dataset = image_dataset_utils.image_dataset_from_directory( + directory, + batch_size=8, + image_size=(18, 18), + label_mode=None, + format=format, + ) + sample_count = 0 + for batch in dataset: + sample_count += batch.shape[0] + self.assertEqual(sample_count, 15) + + @parameterized.named_parameters( + ("tf", "tf"), + ("grain", "grain"), + ) + def test_image_dataset_from_directory_multiclass(self, format): + directory = self._prepare_directory(num_classes=4, count=15) + + dataset = image_dataset_utils.image_dataset_from_directory( + directory, + batch_size=8, + image_size=(18, 18), + label_mode=None, + format=format, + ) + if backend.config.image_data_format() == "channels_last": + output_shape = [8, 18, 18, 3] + else: + output_shape = [8, 3, 18, 18] + batch = next(iter(dataset)) + self.assertEqual(list(batch.shape), output_shape) + + dataset = image_dataset_utils.image_dataset_from_directory( + directory, + batch_size=8, + image_size=(18, 18), + label_mode=None, + format=format, + ) + sample_count = 0 + iterator = iter(dataset) + for batch in dataset: + sample_count += next(iterator).shape[0] + self.assertEqual(sample_count, 15) + + dataset = image_dataset_utils.image_dataset_from_directory( + directory, + batch_size=8, + image_size=(18, 18), + label_mode="int", + format=format, + ) + batch = next(iter(dataset)) + self.assertLen(batch, 2) + self.assertEqual(list(batch[0].shape), output_shape) + self.assertDType(batch[0], "float32") + self.assertEqual(list(batch[1].shape), [8]) + self.assertDType(batch[1], "int32") + + dataset = image_dataset_utils.image_dataset_from_directory( + directory, + batch_size=8, + image_size=(18, 18), + label_mode="categorical", + format=format, + ) + batch = next(iter(dataset)) + self.assertLen(batch, 2) + self.assertEqual(list(batch[0].shape), output_shape) + self.assertDType(batch[0], "float32") + self.assertEqual(list(batch[1].shape), [8, 4]) + self.assertDType(batch[1], "float32") + + @parameterized.named_parameters( + ("tf", "tf"), + ("grain", "grain"), + ) + def test_image_dataset_from_directory_color_modes(self, format): + directory = self._prepare_directory(num_classes=4, color_mode="rgba") + dataset = image_dataset_utils.image_dataset_from_directory( + directory, + batch_size=8, + image_size=(18, 18), + color_mode="rgba", + format=format, + ) + if backend.config.image_data_format() == "channels_last": + output_shape = [8, 18, 18, 4] + else: + output_shape = [8, 4, 18, 18] + batch = next(iter(dataset)) + self.assertLen(batch, 2) + self.assertEqual(list(batch[0].shape), output_shape) + self.assertDType(batch[0], "float32") + + directory = self._prepare_directory( + num_classes=4, color_mode="grayscale" + ) + dataset = image_dataset_utils.image_dataset_from_directory( + directory, + batch_size=8, + image_size=(18, 18), + color_mode="grayscale", + format=format, + ) + if backend.config.image_data_format() == "channels_last": + output_shape = [8, 18, 18, 1] + else: + output_shape = [8, 1, 18, 18] + batch = next(iter(dataset)) + self.assertLen(batch, 2) + self.assertEqual(list(batch[0].shape), output_shape) + self.assertDType(batch[0], "float32") + + @parameterized.named_parameters( + ("tf", "tf"), + ("grain", "grain"), + ) + def test_image_dataset_from_directory_validation_split(self, format): + directory = self._prepare_directory(num_classes=2, count=10) + dataset = image_dataset_utils.image_dataset_from_directory( + directory, + batch_size=10, + image_size=(18, 18), + validation_split=0.2, + subset="training", + seed=1337, + format=format, + ) + batch = next(iter(dataset)) + self.assertLen(batch, 2) + if backend.config.image_data_format() == "channels_last": + train_output_shape = [8, 18, 18, 3] + val_output_shape = [2, 18, 18, 3] + else: + train_output_shape = [8, 3, 18, 18] + val_output_shape = [2, 3, 18, 18] + self.assertEqual(list(batch[0].shape), train_output_shape) + dataset = image_dataset_utils.image_dataset_from_directory( + directory, + batch_size=10, + image_size=(18, 18), + validation_split=0.2, + subset="validation", + seed=1337, + format=format, + ) + batch = next(iter(dataset)) + self.assertLen(batch, 2) + self.assertEqual(list(batch[0].shape), val_output_shape) + + ( + train_dataset, + val_dataset, + ) = image_dataset_utils.image_dataset_from_directory( + directory, + batch_size=10, + image_size=(18, 18), + validation_split=0.2, + subset="both", + seed=1337, + format=format, + ) + batch = next(iter(train_dataset)) + self.assertLen(batch, 2) + self.assertEqual(list(batch[0].shape), train_output_shape) + batch = next(iter(val_dataset)) + self.assertLen(batch, 2) + self.assertEqual(list(batch[0].shape), val_output_shape) + + @parameterized.named_parameters( + ("tf", "tf"), + ("grain", "grain"), + ) + def test_image_dataset_from_directory_manual_labels(self, format): + # Case: wrong number of labels + directory = self._prepare_directory(num_classes=1, count=4) + with self.assertRaisesRegex(ValueError, "match the number of files"): + image_dataset_utils.image_dataset_from_directory( + directory, + batch_size=8, + image_size=(18, 18), + labels=[0, 1, 0], + shuffle=False, + format=format, + ) + + # Case: single directory + directory = self._prepare_directory(num_classes=1, count=4) + dataset = image_dataset_utils.image_dataset_from_directory( + directory, + batch_size=8, + image_size=(18, 18), + labels=[0, 1, 0, 1], + shuffle=False, + format=format, + ) + if backend.config.image_data_format() == "channels_last": + output_shape = [18, 18, 3] + else: + output_shape = [3, 18, 18] + self.assertEqual(dataset.class_names, ["0", "1"]) + batch = next(iter(dataset)) + self.assertLen(batch, 2) + self.assertEqual(list(batch[0].shape), [4] + output_shape) + self.assertAllClose(batch[1], [0, 1, 0, 1]) + + # Case: multiple directories + directory = self._prepare_directory(num_classes=3, count=6) + dataset = image_dataset_utils.image_dataset_from_directory( + directory, + batch_size=8, + image_size=(18, 18), + labels=[0, 1, 0, 1, 1, 1], + shuffle=False, + format=format, + ) + self.assertEqual(dataset.class_names, ["0", "1"]) + batch = next(iter(dataset)) + self.assertLen(batch, 2) + self.assertEqual(list(batch[0].shape), [6] + output_shape) + self.assertAllClose(batch[1], [0, 1, 0, 1, 1, 1]) + + @parameterized.named_parameters( + ("tf", "tf"), + ("grain", "grain"), + ) + def test_image_dataset_from_directory_follow_links(self, format): + directory = self._prepare_directory( + num_classes=2, count=25, nested_dirs=True + ) + dataset = image_dataset_utils.image_dataset_from_directory( + directory, + batch_size=8, + image_size=(18, 18), + label_mode=None, + follow_links=True, + format=format, + ) + sample_count = 0 + for batch in dataset: + sample_count += batch.shape[0] + self.assertEqual(sample_count, 25) + + @parameterized.named_parameters( + ("tf", "tf"), + ("grain", "grain"), + ) + def test_image_dataset_from_directory_no_images(self, format): + directory = self._prepare_directory(num_classes=2, count=0) + with self.assertRaisesRegex(ValueError, "No images found."): + _ = image_dataset_utils.image_dataset_from_directory( + directory, format=format + ) + + @parameterized.named_parameters( + ("tf", "tf"), + ("grain", "grain"), + ) + def test_image_dataset_from_directory_crop_to_aspect_ratio(self, format): + directory = self._prepare_directory(num_classes=2, count=5) + dataset = image_dataset_utils.image_dataset_from_directory( + directory, + batch_size=5, + image_size=(18, 18), + crop_to_aspect_ratio=True, + format=format, + ) + if backend.config.image_data_format() == "channels_last": + output_shape = [5, 18, 18, 3] + else: + output_shape = [5, 3, 18, 18] + batch = next(iter(dataset)) + self.assertLen(batch, 2) + self.assertEqual(list(batch[0].shape), output_shape) + + @parameterized.named_parameters( + ("tf", "tf"), + ("grain", "grain"), + ) + def test_image_dataset_from_directory_pad_to_aspect_ratio(self, format): + directory = self._prepare_directory(num_classes=2, count=5) + dataset = image_dataset_utils.image_dataset_from_directory( + directory, + batch_size=5, + image_size=(18, 18), + pad_to_aspect_ratio=True, + format=format, + ) + if backend.config.image_data_format() == "channels_last": + output_shape = [5, 18, 18, 3] + else: + output_shape = [5, 3, 18, 18] + batch = next(iter(dataset)) + self.assertLen(batch, 2) + self.assertEqual(list(batch[0].shape), output_shape) + + @parameterized.named_parameters( + ("tf", "tf"), + ("grain", "grain"), + ) + def test_image_dataset_from_directory_errors(self, format): + directory = self._prepare_directory(num_classes=3, count=5) + + with self.assertRaisesRegex(ValueError, "`labels` argument should be"): + _ = image_dataset_utils.image_dataset_from_directory( + directory, labels="other", format=format + ) + + with self.assertRaisesRegex( + ValueError, "`label_mode` argument must be" + ): + _ = image_dataset_utils.image_dataset_from_directory( + directory, label_mode="other", format=format + ) + + with self.assertRaisesRegex(ValueError, "`color_mode` must be one of"): + _ = image_dataset_utils.image_dataset_from_directory( + directory, color_mode="other", format=format + ) + + with self.assertRaisesRegex( + ValueError, 'only pass `class_names` if `labels="inferred"`' + ): + _ = image_dataset_utils.image_dataset_from_directory( + directory, + labels=[0, 0, 1, 1, 1], + class_names=["class_0", "class_1", "class_2"], + format=format, + ) + + with self.assertRaisesRegex( + ValueError, + "Expected the lengths of `labels` to match the number of files", + ): + _ = image_dataset_utils.image_dataset_from_directory( + directory, labels=[0, 0, 1, 1], format=format + ) + + with self.assertRaisesRegex( + ValueError, "`class_names` passed did not match" + ): + _ = image_dataset_utils.image_dataset_from_directory( + directory, class_names=["class_0", "wrong_class"], format=format + ) + + with self.assertRaisesRegex(ValueError, "there must be exactly 2"): + _ = image_dataset_utils.image_dataset_from_directory( + directory, label_mode="binary", format=format + ) + + with self.assertRaisesRegex( + ValueError, "`validation_split` must be between 0 and 1" + ): + _ = image_dataset_utils.image_dataset_from_directory( + directory, validation_split=2, format=format + ) + + with self.assertRaisesRegex( + ValueError, + '`subset` must be either "training", "validation" or "both"', + ): + _ = image_dataset_utils.image_dataset_from_directory( + directory, validation_split=0.2, subset="other", format=format + ) + + with self.assertRaisesRegex( + ValueError, "`validation_split` must be set" + ): + _ = image_dataset_utils.image_dataset_from_directory( + directory, + validation_split=0.0, + subset="training", + format=format, + ) + + with self.assertRaisesRegex(ValueError, "must provide a `seed`"): + _ = image_dataset_utils.image_dataset_from_directory( + directory, + validation_split=0.2, + subset="training", + format=format, + ) + + @parameterized.named_parameters( + ("tf", "tf"), + ("grain", "grain"), + ) + def test_image_dataset_from_directory_not_batched(self, format): + directory = self._prepare_directory(num_classes=2, count=2) + dataset = image_dataset_utils.image_dataset_from_directory( + directory, + batch_size=None, + image_size=(18, 18), + label_mode=None, + shuffle=False, + format=format, + ) + sample = next(iter(dataset)) + self.assertEqual(len(sample.shape), 3) + + @parameterized.named_parameters( + ("tf", "tf"), + ("grain", "grain"), + ) + def test_image_dataset_from_directory_shuffle(self, format): + # TODO: add same test for train/val + directory = self._prepare_directory( + num_classes=2, count=25, nested_dirs=True + ) + dataset = image_dataset_utils.image_dataset_from_directory( + directory, + batch_size=8, + image_size=(18, 18), + label_mode=None, + follow_links=True, + shuffle=False, + format=format, + ) + batches_1 = [] + batches_2 = [] + for b in dataset: + batches_1.append(ops.convert_to_numpy(b)) + batches_1 = np.concatenate(batches_1, axis=0) + for b in dataset: + batches_2.append(ops.convert_to_numpy(b)) + batches_2 = np.concatenate(batches_2, axis=0) + self.assertAllClose(batches_1, batches_2, atol=1e-6) + + dataset = image_dataset_utils.image_dataset_from_directory( + directory, + batch_size=8, + image_size=(18, 18), + label_mode=None, + follow_links=True, + shuffle=True, + seed=1337, + format=format, + ) + batches_1 = [] + batches_2 = [] + for b in dataset: + batches_1.append(ops.convert_to_numpy(b)) + batches_1 = np.concatenate(batches_1, axis=0) + for b in dataset: + batches_2.append(ops.convert_to_numpy(b)) + batches_2 = np.concatenate(batches_2, axis=0) + if format == "tf": + self.assertNotAllClose(batches_1, batches_2, atol=1e-6) + else: + # Grain shuffles deterministically, so we expect the same batches. + self.assertAllClose(batches_1, batches_2, atol=1e-6) + + # Test random seed determinism + dataset = image_dataset_utils.image_dataset_from_directory( + directory, + batch_size=8, + image_size=(18, 18), + label_mode=None, + follow_links=True, + shuffle=True, + seed=1337, + format=format, + ) + batches_1_alt = [] + for b in dataset: + batches_1_alt.append(ops.convert_to_numpy(b)) + batches_1_alt = np.concatenate(batches_1_alt, axis=0) + self.assertAllClose(batches_1, batches_1_alt, atol=1e-6) diff --git a/keras/src/utils/image_utils.py b/keras/src/utils/image_utils.py new file mode 100644 index 000000000000..ca8289c9f9b7 --- /dev/null +++ b/keras/src/utils/image_utils.py @@ -0,0 +1,457 @@ +"""Utilities related to image handling.""" + +import io +import pathlib +import warnings + +import numpy as np + +from keras.src import backend +from keras.src.api_export import keras_export + +try: + from PIL import Image as pil_image + + try: + pil_image_resampling = pil_image.Resampling + except AttributeError: + pil_image_resampling = pil_image +except ImportError: + pil_image = None + pil_image_resampling = None + + +if pil_image_resampling is not None: + PIL_INTERPOLATION_METHODS = { + "nearest": pil_image_resampling.NEAREST, + "bilinear": pil_image_resampling.BILINEAR, + "bicubic": pil_image_resampling.BICUBIC, + "hamming": pil_image_resampling.HAMMING, + "box": pil_image_resampling.BOX, + "lanczos": pil_image_resampling.LANCZOS, + } + + +@keras_export( + [ + "keras.utils.array_to_img", + "keras.preprocessing.image.array_to_img", + ] +) +def array_to_img(x, data_format=None, scale=True, dtype=None): + """Converts a 3D NumPy array to a PIL Image instance. + + Example: + + ```python + from PIL import Image + img = np.random.random(size=(100, 100, 3)) + pil_img = keras.utils.array_to_img(img) + ``` + + Args: + x: Input data, in any form that can be converted to a NumPy array. + data_format: Image data format, can be either `"channels_first"` or + `"channels_last"`. Defaults to `None`, in which case the global + setting `keras.backend.image_data_format()` is used (unless you + changed it, it defaults to `"channels_last"`). + scale: Whether to rescale the image such that minimum and maximum values + are 0 and 255 respectively. Defaults to `True`. + dtype: Dtype to use. `None` means the global setting + `keras.backend.floatx()` is used (unless you changed it, it + defaults to `"float32"`). Defaults to `None`. + + Returns: + A PIL Image instance. + """ + + data_format = backend.standardize_data_format(data_format) + if dtype is None: + dtype = backend.floatx() + if pil_image is None: + raise ImportError( + "Could not import PIL.Image. " + "The use of `array_to_img` requires PIL." + ) + x = np.asarray(x, dtype=dtype) + if x.ndim != 3: + raise ValueError( + "Expected image array to have rank 3 (single image). " + f"Got array with shape: {x.shape}" + ) + + # Original NumPy array x has format (height, width, channel) + # or (channel, height, width) + # but target PIL image has format (width, height, channel) + if data_format == "channels_first": + x = x.transpose(1, 2, 0) + if scale: + x = x - np.min(x) + x_max = np.max(x) + if x_max != 0: + x /= x_max + x *= 255 + if x.shape[2] == 4: + # RGBA + return pil_image.fromarray(x.astype("uint8"), "RGBA") + elif x.shape[2] == 3: + # RGB + return pil_image.fromarray(x.astype("uint8"), "RGB") + elif x.shape[2] == 1: + # grayscale + if np.max(x) > 255: + # 32-bit signed integer grayscale image. PIL mode "I" + return pil_image.fromarray(x[:, :, 0].astype("int32"), "I") + return pil_image.fromarray(x[:, :, 0].astype("uint8"), "L") + else: + raise ValueError(f"Unsupported channel number: {x.shape[2]}") + + +@keras_export( + [ + "keras.utils.img_to_array", + "keras.preprocessing.image.img_to_array", + ] +) +def img_to_array(img, data_format=None, dtype=None): + """Converts a PIL Image instance to a NumPy array. + + Example: + + ```python + from PIL import Image + img_data = np.random.random(size=(100, 100, 3)) + img = keras.utils.array_to_img(img_data) + array = keras.utils.image.img_to_array(img) + ``` + + Args: + img: Input PIL Image instance. + data_format: Image data format, can be either `"channels_first"` or + `"channels_last"`. Defaults to `None`, in which case the global + setting `keras.backend.image_data_format()` is used (unless you + changed it, it defaults to `"channels_last"`). + dtype: Dtype to use. `None` means the global setting + `keras.backend.floatx()` is used (unless you changed it, it + defaults to `"float32"`). + + Returns: + A 3D NumPy array. + """ + + data_format = backend.standardize_data_format(data_format) + if dtype is None: + dtype = backend.floatx() + # NumPy array x has format (height, width, channel) + # or (channel, height, width) + # but original PIL image has format (width, height, channel) + x = np.asarray(img, dtype=dtype) + if len(x.shape) == 3: + if data_format == "channels_first": + x = x.transpose(2, 0, 1) + elif len(x.shape) == 2: + if data_format == "channels_first": + x = x.reshape((1, x.shape[0], x.shape[1])) + else: + x = x.reshape((x.shape[0], x.shape[1], 1)) + else: + raise ValueError(f"Unsupported image shape: {x.shape}") + return x + + +@keras_export(["keras.utils.save_img", "keras.preprocessing.image.save_img"]) +def save_img(path, x, data_format=None, file_format=None, scale=True, **kwargs): + """Saves an image stored as a NumPy array to a path or file object. + + Args: + path: Path or file object. + x: NumPy array. + data_format: Image data format, either `"channels_first"` or + `"channels_last"`. + file_format: Optional file format override. If omitted, the format to + use is determined from the filename extension. If a file object was + used instead of a filename, this parameter should always be used. + scale: Whether to rescale image values to be within `[0, 255]`. + **kwargs: Additional keyword arguments passed to `PIL.Image.save()`. + """ + data_format = backend.standardize_data_format(data_format) + img = array_to_img(x, data_format=data_format, scale=scale) + if img.mode == "RGBA" and (file_format == "jpg" or file_format == "jpeg"): + warnings.warn( + "The JPG format does not support RGBA images, converting to RGB." + ) + img = img.convert("RGB") + img.save(path, format=file_format, **kwargs) + + +@keras_export(["keras.utils.load_img", "keras.preprocessing.image.load_img"]) +def load_img( + path, + color_mode="rgb", + target_size=None, + interpolation="nearest", + keep_aspect_ratio=False, +): + """Loads an image into PIL format. + + Example: + + ```python + image = keras.utils.load_img(image_path) + input_arr = keras.utils.img_to_array(image) + input_arr = np.array([input_arr]) # Convert single image to a batch. + predictions = model.predict(input_arr) + ``` + + Args: + path: Path to image file. + color_mode: One of `"grayscale"`, `"rgb"`, `"rgba"`. Default: `"rgb"`. + The desired image format. + target_size: Either `None` (default to original size) or tuple of ints + `(img_height, img_width)`. + interpolation: Interpolation method used to resample the image if the + target size is different from that of the loaded image. Supported + methods are `"nearest"`, `"bilinear"`, and `"bicubic"`. + If PIL version 1.1.3 or newer is installed, `"lanczos"` + is also supported. If PIL version 3.4.0 or newer is installed, + `"box"` and `"hamming"` are also + supported. By default, `"nearest"` is used. + keep_aspect_ratio: Boolean, whether to resize images to a target + size without aspect ratio distortion. The image is cropped in + the center with target aspect ratio before resizing. + + Returns: + A PIL Image instance. + """ + if pil_image is None: + raise ImportError( + "Could not import PIL.Image. The use of `load_img` requires PIL." + ) + if isinstance(path, io.BytesIO): + img = pil_image.open(path) + elif isinstance(path, (pathlib.Path, bytes, str)): + if isinstance(path, pathlib.Path): + path = str(path.resolve()) + with open(path, "rb") as f: + img = pil_image.open(io.BytesIO(f.read())) + else: + raise TypeError( + f"path should be path-like or io.BytesIO, not {type(path)}" + ) + + if color_mode == "grayscale": + # if image is not already an 8-bit, 16-bit or 32-bit grayscale image + # convert it to an 8-bit grayscale image. + if img.mode not in ("L", "I;16", "I"): + img = img.convert("L") + elif color_mode == "rgba": + if img.mode != "RGBA": + img = img.convert("RGBA") + elif color_mode == "rgb": + if img.mode != "RGB": + img = img.convert("RGB") + else: + raise ValueError('color_mode must be "grayscale", "rgb", or "rgba"') + if target_size is not None: + width_height_tuple = (target_size[1], target_size[0]) + if img.size != width_height_tuple: + if interpolation not in PIL_INTERPOLATION_METHODS: + raise ValueError( + "Invalid interpolation method {} specified. Supported " + "methods are {}".format( + interpolation, + ", ".join(PIL_INTERPOLATION_METHODS.keys()), + ) + ) + resample = PIL_INTERPOLATION_METHODS[interpolation] + + if keep_aspect_ratio: + width, height = img.size + target_width, target_height = width_height_tuple + + crop_height = (width * target_height) // target_width + crop_width = (height * target_width) // target_height + + # Set back to input height / width + # if crop_height / crop_width is not smaller. + crop_height = min(height, crop_height) + crop_width = min(width, crop_width) + + crop_box_hstart = (height - crop_height) // 2 + crop_box_wstart = (width - crop_width) // 2 + crop_box_wend = crop_box_wstart + crop_width + crop_box_hend = crop_box_hstart + crop_height + crop_box = [ + crop_box_wstart, + crop_box_hstart, + crop_box_wend, + crop_box_hend, + ] + img = img.resize(width_height_tuple, resample, box=crop_box) + else: + img = img.resize(width_height_tuple, resample) + return img + + +@keras_export("keras.preprocessing.image.smart_resize") +def smart_resize( + x, + size, + interpolation="bilinear", + data_format="channels_last", + backend_module=None, +): + """Resize images to a target size without aspect ratio distortion. + + Image datasets typically yield images that have each a different + size. However, these images need to be batched before they can be + processed by Keras layers. To be batched, images need to share the same + height and width. + + You could simply do, in TF (or JAX equivalent): + + ```python + size = (200, 200) + ds = ds.map(lambda img: resize(img, size)) + ``` + + However, if you do this, you distort the aspect ratio of your images, since + in general they do not all have the same aspect ratio as `size`. This is + fine in many cases, but not always (e.g. for image generation models + this can be a problem). + + Note that passing the argument `preserve_aspect_ratio=True` to `resize` + will preserve the aspect ratio, but at the cost of no longer respecting the + provided target size. + + This calls for: + + ```python + size = (200, 200) + ds = ds.map(lambda img: smart_resize(img, size)) + ``` + + Your output images will actually be `(200, 200)`, and will not be distorted. + Instead, the parts of the image that do not fit within the target size + get cropped out. + + The resizing process is: + + 1. Take the largest centered crop of the image that has the same aspect + ratio as the target size. For instance, if `size=(200, 200)` and the input + image has size `(340, 500)`, we take a crop of `(340, 340)` centered along + the width. + 2. Resize the cropped image to the target size. In the example above, + we resize the `(340, 340)` crop to `(200, 200)`. + + Args: + x: Input image or batch of images (as a tensor or NumPy array). + Must be in format `(height, width, channels)` + or `(batch_size, height, width, channels)`. + size: Tuple of `(height, width)` integer. Target size. + interpolation: String, interpolation to use for resizing. + Supports `"bilinear"`, `"nearest"`, `"bicubic"`, + `"lanczos3"`, `"lanczos5"`. + Defaults to `"bilinear"`. + data_format: `"channels_last"` or `"channels_first"`. + backend_module: Backend module to use (if different from the default + backend). + + Returns: + Array with shape `(size[0], size[1], channels)`. + If the input image was a NumPy array, the output is a NumPy array, + and if it was a backend-native tensor, + the output is a backend-native tensor. + """ + backend_module = backend_module or backend + if len(size) != 2: + raise ValueError( + f"Expected `size` to be a tuple of 2 integers, but got: {size}." + ) + img = backend_module.convert_to_tensor(x) + if len(img.shape) is not None: + if len(img.shape) < 3 or len(img.shape) > 4: + raise ValueError( + "Expected an image array with shape `(height, width, " + "channels)`, or `(batch_size, height, width, channels)`, but " + f"got input with incorrect rank, of shape {img.shape}." + ) + shape = backend_module.shape(img) + if data_format == "channels_last": + height, width = shape[-3], shape[-2] + else: + height, width = shape[-2], shape[-1] + target_height, target_width = size + + # Set back to input height / width if crop_height / crop_width is not + # smaller. + if isinstance(height, int) and isinstance(width, int): + # For JAX, we need to keep the slice indices as static integers + crop_height = int(float(width * target_height) / target_width) + crop_height = max(min(height, crop_height), 1) + crop_width = int(float(height * target_width) / target_height) + crop_width = max(min(width, crop_width), 1) + crop_box_hstart = int(float(height - crop_height) / 2) + crop_box_wstart = int(float(width - crop_width) / 2) + else: + crop_height = backend_module.cast( + backend_module.cast(width * target_height, "float32") + / target_width, + "int32", + ) + crop_height = backend_module.numpy.minimum(height, crop_height) + crop_height = backend_module.numpy.maximum(crop_height, 1) + crop_height = backend_module.cast(crop_height, "int32") + + crop_width = backend_module.cast( + backend_module.cast(height * target_width, "float32") + / target_height, + "int32", + ) + crop_width = backend_module.numpy.minimum(width, crop_width) + crop_width = backend_module.numpy.maximum(crop_width, 1) + crop_width = backend_module.cast(crop_width, "int32") + + crop_box_hstart = backend_module.cast( + backend_module.cast(height - crop_height, "float32") / 2, "int32" + ) + crop_box_wstart = backend_module.cast( + backend_module.cast(width - crop_width, "float32") / 2, "int32" + ) + + if data_format == "channels_last": + if len(img.shape) == 4: + img = img[ + :, + crop_box_hstart : crop_box_hstart + crop_height, + crop_box_wstart : crop_box_wstart + crop_width, + :, + ] + else: + img = img[ + crop_box_hstart : crop_box_hstart + crop_height, + crop_box_wstart : crop_box_wstart + crop_width, + :, + ] + else: + if len(img.shape) == 4: + img = img[ + :, + :, + crop_box_hstart : crop_box_hstart + crop_height, + crop_box_wstart : crop_box_wstart + crop_width, + ] + else: + img = img[ + :, + crop_box_hstart : crop_box_hstart + crop_height, + crop_box_wstart : crop_box_wstart + crop_width, + ] + + img = backend_module.image.resize( + img, size=size, interpolation=interpolation, data_format=data_format + ) + + if isinstance(x, np.ndarray): + return np.array(img) + return img diff --git a/keras/src/utils/io_utils.py b/keras/src/utils/io_utils.py new file mode 100644 index 000000000000..f593099c3626 --- /dev/null +++ b/keras/src/utils/io_utils.py @@ -0,0 +1,142 @@ +import sys + +from absl import logging + +from keras.src.api_export import keras_export +from keras.src.backend.common import global_state + + +@keras_export( + [ + "keras.config.enable_interactive_logging", + "keras.utils.enable_interactive_logging", + ] +) +def enable_interactive_logging(): + """Turn on interactive logging. + + When interactive logging is enabled, Keras displays logs via stdout. + This provides the best experience when using Keras in an interactive + environment such as a shell or a notebook. + """ + global_state.set_global_attribute("interactive_logging", True) + + +@keras_export( + [ + "keras.config.disable_interactive_logging", + "keras.utils.disable_interactive_logging", + ] +) +def disable_interactive_logging(): + """Turn off interactive logging. + + When interactive logging is disabled, Keras sends logs to `absl.logging`. + This is the best option when using Keras in a non-interactive + way, such as running a training or inference job on a server. + """ + global_state.set_global_attribute("interactive_logging", False) + + +@keras_export( + [ + "keras.config.is_interactive_logging_enabled", + "keras.utils.is_interactive_logging_enabled", + ] +) +def is_interactive_logging_enabled(): + """Check if interactive logging is enabled. + + To switch between writing logs to stdout and `absl.logging`, you may use + `keras.config.enable_interactive_logging()` and + `keras.config.disable_interactive_logging()`. + + Returns: + Boolean, `True` if interactive logging is enabled, + and `False` otherwise. + """ + return global_state.get_global_attribute("interactive_logging", True) + + +def set_logging_verbosity(level): + """Sets the verbosity level for logging. + + Supported log levels are as follows: + + - `"FATAL"` (least verbose) + - `"ERROR"` + - `"WARNING"` + - `"INFO"` + - `"DEBUG"` (most verbose) + + Args: + level: A string corresponding to the level of verbosity for logging. + """ + valid_levels = { + "FATAL": logging.FATAL, + "ERROR": logging.ERROR, + "WARNING": logging.WARNING, + "INFO": logging.INFO, + "DEBUG": logging.DEBUG, + } + verbosity = valid_levels.get(level) + if verbosity is None: + raise ValueError( + "Please pass a valid level for logging verbosity. " + f"Expected one of: {set(valid_levels.keys())}. " + f"Received: {level}" + ) + logging.set_verbosity(verbosity) + + +def print_msg(message, line_break=True): + """Print the message to absl logging or stdout.""" + message = str(message) + if is_interactive_logging_enabled(): + message = f"{message}\n" if line_break else message + try: + sys.stdout.write(message) + except UnicodeEncodeError: + # If the encoding differs from UTF-8, `sys.stdout.write` may fail. + # To address this, replace special unicode characters in the + # message, and then encode and decode using the target encoding. + message = _replace_special_unicode_character(message) + # Fallback to UTF-8 when `sys.stdout.encoding` is `None` (e.g. when + # stdout is redirected). This prevents a `TypeError` that would be + # raised by `bytes.encode(None)` / `bytes.decode(None)`. + encoding = sys.stdout.encoding or "utf-8" + message_bytes = message.encode(encoding, errors="ignore") + message = message_bytes.decode(encoding) + sys.stdout.write(message) + sys.stdout.flush() + else: + logging.info(message) + + +def ask_to_proceed_with_overwrite(filepath): + """Produces a prompt asking about overwriting a file. + + Args: + filepath: the path to the file to be overwritten. + + Returns: + True if we can proceed with overwrite, False otherwise. + """ + overwrite = ( + input(f"[WARNING] {filepath} already exists - overwrite? [y/n]") + .strip() + .lower() + ) + while overwrite not in ("y", "n"): + overwrite = ( + input('Enter "y" (overwrite) or "n" (cancel).').strip().lower() + ) + if overwrite == "n": + return False + print_msg("[TIP] Next time specify overwrite=True!") + return True + + +def _replace_special_unicode_character(message): + message = str(message).replace("━", "=") # Fall back to Keras2 behavior. + return message diff --git a/keras/src/utils/io_utils_test.py b/keras/src/utils/io_utils_test.py new file mode 100644 index 000000000000..2fe1fbbea219 --- /dev/null +++ b/keras/src/utils/io_utils_test.py @@ -0,0 +1,69 @@ +import sys +import tempfile +from unittest.mock import patch + +from keras.src.testing import test_case +from keras.src.utils import io_utils + + +class TestIoUtils(test_case.TestCase): + def test_enable_interactive_logging(self): + io_utils.enable_interactive_logging() + self.assertTrue(io_utils.is_interactive_logging_enabled()) + + def test_disable_interactive_logging(self): + io_utils.disable_interactive_logging() + self.assertFalse(io_utils.is_interactive_logging_enabled()) + + def test_set_logging_verbosity_valid(self): + valid_levels = ["FATAL", "ERROR", "WARNING", "INFO", "DEBUG"] + for level in valid_levels: + io_utils.set_logging_verbosity(level) + + def test_set_logging_verbosity_invalid(self): + with self.assertRaises(ValueError): + io_utils.set_logging_verbosity("INVALID") + + @patch("builtins.input", side_effect=["y"]) + def test_ask_to_proceed_with_overwrite_yes(self, _): + self.assertTrue(io_utils.ask_to_proceed_with_overwrite("test_path")) + + @patch("builtins.input", side_effect=["n"]) + def test_ask_to_proceed_with_overwrite_no(self, _): + self.assertFalse(io_utils.ask_to_proceed_with_overwrite("test_path")) + + @patch("sys.stdout.write") + def test_print_msg_interactive_with_line_break(self, mock_write): + io_utils.enable_interactive_logging() + io_utils.print_msg("Hello", line_break=True) + mock_write.assert_called_once_with("Hello\n") + + @patch("sys.stdout.write") + def test_print_msg_interactive_without_line_break(self, mock_write): + io_utils.enable_interactive_logging() + io_utils.print_msg("Hello", line_break=False) + mock_write.assert_called_once_with("Hello") + + @patch("absl.logging.info") + def test_print_msg_non_interactive(self, mock_logging): + io_utils.disable_interactive_logging() + io_utils.print_msg("Hello") + mock_logging.assert_called_once_with("Hello") + + @patch("builtins.input", side_effect=["invalid", "invalid", "y"]) + def test_ask_to_proceed_with_overwrite_invalid_then_yes(self, _): + self.assertTrue(io_utils.ask_to_proceed_with_overwrite("test_path")) + + @patch("builtins.input", side_effect=["invalid", "n"]) + def test_ask_to_proceed_with_overwrite_invalid_then_no(self, _): + self.assertFalse(io_utils.ask_to_proceed_with_overwrite("test_path")) + + def test_print_msg_with_different_encoding(self): + # https://github.com/keras-team/keras/issues/19386 + io_utils.enable_interactive_logging() + self.assertTrue(io_utils.is_interactive_logging_enabled()) + ori_stdout = sys.stdout + with tempfile.TemporaryFile(mode="w", encoding="cp1251") as tmp: + sys.stdout = tmp + io_utils.print_msg("━") + sys.stdout = ori_stdout diff --git a/keras/src/utils/jax_layer.py b/keras/src/utils/jax_layer.py new file mode 100644 index 000000000000..a02af992778f --- /dev/null +++ b/keras/src/utils/jax_layer.py @@ -0,0 +1,690 @@ +import inspect + +import numpy as np + +from keras.src import backend +from keras.src import tree +from keras.src.api_export import keras_export +from keras.src.backend.common.variables import is_float_dtype +from keras.src.backend.common.variables import standardize_dtype +from keras.src.layers.layer import Layer +from keras.src.saving import serialization_lib +from keras.src.utils import jax_utils +from keras.src.utils import tracking +from keras.src.utils.module_utils import jax + + +@keras_export("keras.layers.JaxLayer") +class JaxLayer(Layer): + """Keras Layer that wraps a JAX model. + + This layer enables the use of JAX components within Keras when using JAX as + the backend for Keras. + + ## Model function + + This layer accepts JAX models in the form of a function, `call_fn`, which + must take the following arguments with these exact names: + + - `params`: trainable parameters of the model. + - `state` (*optional*): non-trainable state of the model. Can be omitted if + the model has no non-trainable state. + - `rng` (*optional*): a `jax.random.PRNGKey` instance. Can be omitted if the + model does not need RNGs, neither during training nor during inference. + - `inputs`: inputs to the model, a JAX array or a `PyTree` of arrays. + - `training` (*optional*): an argument specifying if we're in training mode + or inference mode, `True` is passed in training mode. Can be omitted if + the model behaves the same in training mode and inference mode. + + The `inputs` argument is mandatory. Inputs to the model must be provided via + a single argument. If the JAX model takes multiple inputs as separate + arguments, they must be combined into a single structure, for instance in a + `tuple` or a `dict`. + + ## Model weights initialization + + The initialization of the `params` and `state` of the model can be handled + by this layer, in which case the `init_fn` argument must be provided. This + allows the model to be initialized dynamically with the right shape. + Alternatively, and if the shape is known, the `params` argument and + optionally the `state` argument can be used to create an already initialized + model. + + The `init_fn` function, if provided, must take the following arguments with + these exact names: + + - `rng`: a `jax.random.PRNGKey` instance. + - `inputs`: a JAX array or a `PyTree` of arrays with placeholder values to + provide the shape of the inputs. + - `training` (*optional*): an argument specifying if we're in training mode + or inference mode. `True` is always passed to `init_fn`. Can be omitted + regardless of whether `call_fn` has a `training` argument. + + ## Models with non-trainable state + + For JAX models that have non-trainable state: + + - `call_fn` must have a `state` argument + - `call_fn` must return a `tuple` containing the outputs of the model and + the new non-trainable state of the model + - `init_fn` must return a `tuple` containing the initial trainable params of + the model and the initial non-trainable state of the model. + + This code shows a possible combination of `call_fn` and `init_fn` signatures + for a model with non-trainable state. In this example, the model has a + `training` argument and an `rng` argument in `call_fn`. + + ```python + def stateful_call(params, state, rng, inputs, training): + outputs = ... + new_state = ... + return outputs, new_state + + def stateful_init(rng, inputs): + initial_params = ... + initial_state = ... + return initial_params, initial_state + ``` + + ## Models without non-trainable state + + For JAX models with no non-trainable state: + + - `call_fn` must not have a `state` argument + - `call_fn` must return only the outputs of the model + - `init_fn` must return only the initial trainable params of the model. + + This code shows a possible combination of `call_fn` and `init_fn` signatures + for a model without non-trainable state. In this example, the model does not + have a `training` argument and does not have an `rng` argument in `call_fn`. + + ```python + def stateless_call(params, inputs): + outputs = ... + return outputs + + def stateless_init(rng, inputs): + initial_params = ... + return initial_params + ``` + + ## Conforming to the required signature + + If a model has a different signature than the one required by `JaxLayer`, + one can easily write a wrapper method to adapt the arguments. This example + shows a model that has multiple inputs as separate arguments, expects + multiple RNGs in a `dict`, and has a `deterministic` argument with the + opposite meaning of `training`. To conform, the inputs are combined in a + single structure using a `tuple`, the RNG is split and used the populate the + expected `dict`, and the Boolean flag is negated: + + ```python + def my_model_fn(params, rngs, input1, input2, deterministic): + ... + if not deterministic: + dropout_rng = rngs["dropout"] + keep = jax.random.bernoulli(dropout_rng, dropout_rate, x.shape) + x = jax.numpy.where(keep, x / dropout_rate, 0) + ... + ... + return outputs + + def my_model_wrapper_fn(params, rng, inputs, training): + input1, input2 = inputs + rng1, rng2 = jax.random.split(rng) + rngs = {"dropout": rng1, "preprocessing": rng2} + deterministic = not training + return my_model_fn(params, rngs, input1, input2, deterministic) + + keras_layer = JaxLayer(my_model_wrapper_fn, params=initial_params) + ``` + + ## Usage with Haiku modules + + `JaxLayer` enables the use of [Haiku](https://dm-haiku.readthedocs.io) + components in the form of + [`haiku.Module`](https://dm-haiku.readthedocs.io/en/latest/api.html#module). + This is achieved by transforming the module per the Haiku pattern and then + passing `module.apply` in the `call_fn` parameter and `module.init` in the + `init_fn` parameter if needed. + + If the model has non-trainable state, it should be transformed with + [`haiku.transform_with_state`]( + https://dm-haiku.readthedocs.io/en/latest/api.html#haiku.transform_with_state). + If the model has no non-trainable state, it should be transformed with + [`haiku.transform`]( + https://dm-haiku.readthedocs.io/en/latest/api.html#haiku.transform). + Additionally, and optionally, if the module does not use RNGs in "apply", it + can be transformed with + [`haiku.without_apply_rng`]( + https://dm-haiku.readthedocs.io/en/latest/api.html#without-apply-rng). + + The following example shows how to create a `JaxLayer` from a Haiku module + that uses random number generators via `hk.next_rng_key()` and takes a + training positional argument: + + ```python + class MyHaikuModule(hk.Module): + def __call__(self, x, training): + x = hk.Conv2D(32, (3, 3))(x) + x = jax.nn.relu(x) + x = hk.AvgPool((1, 2, 2, 1), (1, 2, 2, 1), "VALID")(x) + x = hk.Flatten()(x) + x = hk.Linear(200)(x) + if training: + x = hk.dropout(rng=hk.next_rng_key(), rate=0.3, x=x) + x = jax.nn.relu(x) + x = hk.Linear(10)(x) + x = jax.nn.softmax(x) + return x + + def my_haiku_module_fn(inputs, training): + module = MyHaikuModule() + return module(inputs, training) + + transformed_module = hk.transform(my_haiku_module_fn) + + keras_layer = JaxLayer( + call_fn=transformed_module.apply, + init_fn=transformed_module.init, + ) + ``` + + Args: + call_fn: The function to call the model. See description above for the + list of arguments it takes and the outputs it returns. + init_fn: the function to call to initialize the model. See description + above for the list of arguments it takes and the outputs it returns. + If `None`, then `params` and/or `state` must be provided. + params: A `PyTree` containing all the model trainable parameters. This + allows passing trained parameters or controlling the initialization. + If both `params` and `state` are `None`, `init_fn` is called at + build time to initialize the trainable parameters of the model. + state: A `PyTree` containing all the model non-trainable state. This + allows passing learned state or controlling the initialization. If + both `params` and `state` are `None`, and `call_fn` takes a `state` + argument, then `init_fn` is called at build time to initialize the + non-trainable state of the model. + seed: Seed for random number generator. Optional. + dtype: The dtype of the layer's computations and weights. Can also be a + `keras.DTypePolicy`. Optional. Defaults to the default policy. + """ + + def __init__( + self, + call_fn, + init_fn=None, + params=None, + state=None, + seed=None, + **kwargs, + ): + if backend.backend() != "jax": + raise ValueError( + "JaxLayer is only supported with the JAX backend. Current " + f"backend: {backend.backend()}" + ) + + if init_fn is None and params is None and state is None: + raise ValueError( + "`init_fn`, `params` and `state` cannot all be `None`." + ) + + super().__init__(**kwargs) + self.call_fn = call_fn + self.init_fn = init_fn + self.seed_generator = backend.random.SeedGenerator(seed) + self.tracked_params = self._create_variables(params, trainable=True) + self.tracked_state = self._create_variables(state, trainable=False) + if self.params is not None or self.state is not None: + self._build_at_init() + + self.call_fn_arguments = self._validate_signature( + call_fn, + "call_fn", + {"params", "state", "rng", "inputs", "training"}, + {"inputs"}, + ) + self.has_state = "state" in self.call_fn_arguments + + if init_fn: + self.init_fn_arguments = self._validate_signature( + init_fn, "init_fn", {"rng", "inputs", "training"}, {"inputs"} + ) + + def _validate_signature(self, fn, fn_name, allowed, required): + fn_parameters = inspect.signature(fn).parameters + for parameter_name in required: + if parameter_name not in fn_parameters: + raise ValueError( + f"Missing required argument in `{fn_name}`: " + f"`{parameter_name}`" + ) + + parameter_names = [] + for parameter in fn_parameters.values(): + if parameter.name not in allowed: + raise ValueError( + f"Unsupported argument in `{fn_name}`: `{parameter.name}`, " + f"supported arguments are `{'`, `'.join(allowed)}`" + ) + parameter_names.append(parameter.name) + + return parameter_names + + @tracking.no_automatic_dependency_tracking + def _create_variables(self, values, trainable): + """Create a structure of variables from a structure of JAX arrays. + + `values` is traversed via JAX's `tree_map`. When a leaf is a JAX array + or a tensor-like object, a corresponding variable is created with it as + the initial value. The resulting structure of variables is assigned to + `self.params` or `self.state` depending on `trainable`. Then, a + flattened version of the variables is returned for tracking. + `self.params` or `self.state` are intentionally not tracked because + structures like `TrackedList` interfere with `jax.tree_utils`. + Note that leaf objects that are not JAX arrays and not tensor-like are + left intact as they are assumed to be configuration used by the model. + + Args: + values: the structure of values to traverse. + trainable: whether to create trainable variables. + + Returns: + flat list of variables initialized with `values` for tracking. + """ + + def create_variable(value): + if backend.is_tensor(value) or isinstance( + value, (np.ndarray, np.generic) + ): + dtype = value.dtype + if is_float_dtype(dtype): + dtype = None # Use the layer dtype policy + return self.add_weight( + value.shape, + initializer=value, + dtype=dtype, + trainable=trainable, + ) + elif isinstance(value, (bool, int, float)): + dtype = standardize_dtype(type(value)) + if is_float_dtype(dtype): + dtype = None # Use the layer dtype policy + return self.add_weight( + (), + initializer=backend.convert_to_tensor(value), + dtype=dtype, + trainable=trainable, + ) + else: + return value + + # Use JAX's tree_map as it understands registered classes. + variables = jax.tree_util.tree_map(create_variable, values) + + if trainable: + self.params = variables + else: + self.state = variables + + flat_variables, _ = jax.tree_util.tree_flatten(variables) + return flat_variables + + def _get_init_rng(self): + """ + Returns a JAX `PRNGKey` or structure of `PRNGKey`s to pass to `init_fn`. + + By default, this returns a single `PRNGKey` retrieved by calling + `self.seed_generator.next()`. Override this to return a different + structure. + + Returns: + a JAX `PRNGKey` or structure of `PRNGKey`s that will be passed as + the `rng` argument of `init_fn`. + """ + return self.seed_generator.next() + + def _get_call_rng(self, training): + """ + Returns a JAX `PRNGKey` or structure of `PRNGKey`s to pass to `call_fn`. + + By default, this returns a single `PRNGKey` retrieved by calling + `self.seed_generator.next()` when `training` is `True`, and `None` when + `training` is `False`. Override this to return a different structure or + to pass RNGs in inference mode too. + + Returns: + a JAX `PRNGKey` or structure of `PRNGKey`s that will be passed as + the `rng` argument of `call_fn`. + """ + if training: + return self.seed_generator.next() + else: + return None + + def build(self, input_shape): + if self.params is not None or self.state is not None: + return + + if jax_utils.is_in_jax_tracing_scope(): + # This exception is not actually shown, it is caught and a detailed + # warning about calling 'build' is printed. + raise ValueError("'JaxLayer' cannot be built in tracing scope") + + # Initialize `params` and `state` if needed by calling `init_fn`. + def create_input(shape): + shape = [d if d is not None else 1 for d in shape] + return jax.numpy.ones(shape) + + init_inputs = tree.map_shape_structure(create_input, input_shape) + init_args = [] + for argument_name in self.init_fn_arguments: + if argument_name == "rng": + init_args.append(self._get_init_rng()) + elif argument_name == "inputs": + init_args.append(init_inputs) + elif argument_name == "training": + init_args.append(True) + + init_result = self.init_fn(*init_args) + if self.has_state: + init_params, init_state = init_result + else: + init_params, init_state = init_result, None + + self.tracked_params = self._create_variables( + init_params, trainable=True + ) + self.tracked_state = self._create_variables(init_state, trainable=False) + + def call(self, inputs, training=False): + def unwrap_variable(variable): + return None if variable is None else variable.value + + call_args = [] + for argument_name in self.call_fn_arguments: + if argument_name == "params": + call_args.append( + jax.tree_util.tree_map(unwrap_variable, self.params) + ) + elif argument_name == "state": + call_args.append( + jax.tree_util.tree_map(unwrap_variable, self.state) + ) + elif argument_name == "rng": + call_args.append(self._get_call_rng(training)) + elif argument_name == "inputs": + call_args.append(inputs) + elif argument_name == "training": + call_args.append(training) + + def assign_state_to_variable(value, variable): + # This exists only to make debugging this error case easier. + if not hasattr(variable, "assign"): + raise ValueError( + "Structure mismatch: the structure of the state returned " + "by `call` does not match the structure of the state at " + "initialization time." + ) + variable.assign(value) + + if self.has_state: + predictions, new_state = self.call_fn(*call_args) + jax.tree_util.tree_map( + assign_state_to_variable, new_state, self.state + ) + return predictions + else: + return self.call_fn(*call_args) + + def get_config(self): + config = { + "call_fn": serialization_lib.serialize_keras_object(self.call_fn), + "init_fn": serialization_lib.serialize_keras_object(self.init_fn), + } + base_config = super().get_config() + return dict(list(base_config.items()) + list(config.items())) + + @classmethod + def from_config(cls, config): + call_fn = serialization_lib.deserialize_keras_object(config["call_fn"]) + init_fn = serialization_lib.deserialize_keras_object(config["init_fn"]) + config["call_fn"] = call_fn + config["init_fn"] = init_fn + return super().from_config(config) + + +@keras_export("keras.layers.FlaxLayer") +class FlaxLayer(JaxLayer): + """Keras Layer that wraps a [Flax](https://flax.readthedocs.io) module. + + This layer enables the use of Flax components in the form of + [`flax.linen.Module`]( + https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) + instances within Keras when using JAX as the backend for Keras. + + The module method to use for the forward pass can be specified via the + `method` argument and is `__call__` by default. This method must take the + following arguments with these exact names: + + - `self` if the method is bound to the module, which is the case for the + default of `__call__`, and `module` otherwise to pass the module. + - `inputs`: the inputs to the model, a JAX array or a `PyTree` of arrays. + - `training` *(optional)*: an argument specifying if we're in training mode + or inference mode, `True` is passed in training mode. + + `FlaxLayer` handles the non-trainable state of your model and required RNGs + automatically. Note that the `mutable` parameter of + [`flax.linen.Module.apply()`]( + https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.apply) + is set to `DenyList(["params"])`, therefore making the assumption that all + the variables outside of the "params" collection are non-trainable weights. + + This example shows how to create a `FlaxLayer` from a Flax `Module` with + the default `__call__` method and no training argument: + + ```python + class MyFlaxModule(flax.linen.Module): + @flax.linen.compact + def __call__(self, inputs): + x = inputs + x = flax.linen.Conv(features=32, kernel_size=(3, 3))(x) + x = flax.linen.relu(x) + x = flax.linen.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) + x = x.reshape((x.shape[0], -1)) # flatten + x = flax.linen.Dense(features=200)(x) + x = flax.linen.relu(x) + x = flax.linen.Dense(features=10)(x) + x = flax.linen.softmax(x) + return x + + flax_module = MyFlaxModule() + keras_layer = FlaxLayer(flax_module) + ``` + + This example shows how to wrap the module method to conform to the required + signature. This allows having multiple input arguments and a training + argument that has a different name and values. This additionally shows how + to use a function that is not bound to the module. + + ```python + class MyFlaxModule(flax.linen.Module): + @flax.linen.compact + def forward(self, input1, input2, deterministic): + ... + return outputs + + def my_flax_module_wrapper(module, inputs, training): + input1, input2 = inputs + return module.forward(input1, input2, not training) + + flax_module = MyFlaxModule() + keras_layer = FlaxLayer( + module=flax_module, + method=my_flax_module_wrapper, + ) + ``` + + Args: + module: An instance of `flax.linen.Module` or subclass. + method: The method to call the model. This is generally a method in the + `Module`. If not provided, the `__call__` method is used. `method` + can also be a function not defined in the `Module`, in which case it + must take the `Module` as the first argument. It is used for both + `Module.init` and `Module.apply`. Details are documented in the + `method` argument of [`flax.linen.Module.apply()`]( + https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.apply). + variables: A `dict` containing all the variables of the module in the + same format as what is returned by [`flax.linen.Module.init()`]( + https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.init). + It should contain a "params" key and, if applicable, other keys for + collections of variables for non-trainable state. This allows + passing trained parameters and learned non-trainable state or + controlling the initialization. If `None` is passed, the module's + `init` function is called at build time to initialize the variables + of the model. + """ + + def __init__( + self, + module, + method=None, + variables=None, + **kwargs, + ): + # Late import to only require Flax when this is used. + from flax.core import scope as flax_scope + + if backend.backend() != "jax": + raise ValueError( + "FlaxLayer is only supported with the JAX backend. Current " + f"backend: {backend.backend()}" + ) + + self.module = module + self.method = method + + apply_mutable = flax_scope.DenyList(["params"]) + + def apply_with_training(params, state, rng, inputs, training): + return self.module.apply( + self._params_and_state_to_variables(params, state), + inputs, + rngs=rng, + method=self.method, + mutable=apply_mutable, + training=training, + ) + + def apply_without_training(params, state, rng, inputs): + return self.module.apply( + self._params_and_state_to_variables(params, state), + inputs, + rngs=rng, + method=self.method, + mutable=apply_mutable, + ) + + def init_with_training(rng, inputs, training): + return self._variables_to_params_and_state( + self.module.init( + rng, + inputs, + method=self.method, + training=training, + ) + ) + + def init_without_training(rng, inputs): + return self._variables_to_params_and_state( + self.module.init( + rng, + inputs, + method=self.method, + ) + ) + + if ( + "training" + in inspect.signature(method or module.__call__).parameters + ): + call_fn, init_fn = apply_with_training, init_with_training + else: + call_fn, init_fn = apply_without_training, init_without_training + + params, state = self._variables_to_params_and_state(variables) + + super().__init__( + call_fn=call_fn, + init_fn=init_fn, + params=params, + state=state, + **kwargs, + ) + + def _params_and_state_to_variables(self, params, state): + if params: + if state: + return {**params, **state} + else: + return params + elif state: + return state + return {} + + def _variables_to_params_and_state(self, variables): + # neither params nor state + if variables is None: + return None, None + # state only + if "params" not in variables: + return {}, variables + # params only + if len(variables) == 1: + return variables, {} + # both, we need to split + params = {"params": variables["params"]} + state = {k: v for k, v in variables.items() if k != "params"} + return params, state + + def _get_init_rng(self): + return { + "params": self.seed_generator.next(), + "dropout": self.seed_generator.next(), + } + + def _get_call_rng(self, training): + if training: + return {"dropout": self.seed_generator.next()} + else: + return {} + + def get_config(self): + config_method = self.method + if ( + hasattr(self.method, "__self__") + and self.method.__self__ == self.module + ): + # A method bound to the module is serialized by name. + config_method = self.method.__name__ + config = { + "module": serialization_lib.serialize_keras_object(self.module), + "method": serialization_lib.serialize_keras_object(config_method), + } + base_config = super().get_config() + # call_fn and init_fn come from module, do not save them. + base_config.pop("call_fn") + base_config.pop("init_fn") + return dict(list(base_config.items()) + list(config.items())) + + @classmethod + def from_config(cls, config): + module = serialization_lib.deserialize_keras_object(config["module"]) + method = serialization_lib.deserialize_keras_object(config["method"]) + if isinstance(config["method"], str): + # Deserialize bound method from the module. + method = getattr(module, method) + config["module"] = module + config["method"] = method + return cls(**config) diff --git a/keras/src/utils/jax_layer_test.py b/keras/src/utils/jax_layer_test.py new file mode 100644 index 000000000000..009ecd402e5f --- /dev/null +++ b/keras/src/utils/jax_layer_test.py @@ -0,0 +1,716 @@ +import os + +import jax +import jax.numpy as jnp +import numpy as np +import pytest +import tensorflow as tf +from absl.testing import parameterized + +from keras.src import backend +from keras.src import layers +from keras.src import metrics +from keras.src import models +from keras.src import saving +from keras.src import testing +from keras.src import tree +from keras.src import utils +from keras.src.dtype_policies.dtype_policy import DTypePolicy +from keras.src.saving import object_registration +from keras.src.utils.jax_layer import FlaxLayer +from keras.src.utils.jax_layer import JaxLayer + +try: + import flax +except ImportError: + flax = None + +num_classes = 10 +input_shape = (28, 28, 1) # Excluding batch_size + + +@object_registration.register_keras_serializable() +def jax_stateless_init(rng, inputs): + layer_sizes = [784, 300, 100, 10] + params = [] + w_init = jax.nn.initializers.glorot_normal() + b_init = jax.nn.initializers.normal(0.1) + for m, n in zip(layer_sizes[:-1], layer_sizes[1:]): + rng, w_rng = jax.random.split(rng) + rng, b_rng = jax.random.split(rng) + params.append([w_init(w_rng, (m, n)), b_init(b_rng, (n,))]) + return params + + +@object_registration.register_keras_serializable() +def jax_stateless_apply(params, inputs): + activations = inputs.reshape((inputs.shape[0], -1)) # flatten + for w, b in params[:-1]: + outputs = jnp.dot(activations, w) + b + activations = jnp.tanh(outputs) + + final_w, final_b = params[-1] + logits = jnp.dot(activations, final_w) + final_b + return jax.nn.softmax(logits, axis=-1) + + +@object_registration.register_keras_serializable() +def jax_stateful_init(rng, inputs, training): + params = jax_stateless_init(rng, inputs) + state = jnp.zeros([], jnp.int32) + return params, state + + +@object_registration.register_keras_serializable() +def jax_stateful_apply(params, state, inputs, training): + outputs = jax_stateless_apply(params, inputs) + if training: + state = state + 1 + return outputs, state + + +if flax is not None: + + @object_registration.register_keras_serializable() + class FlaxTrainingIndependentModel(flax.linen.Module): + @flax.linen.compact + def forward(self, inputs): + x = inputs + x = flax.linen.Conv(features=32, kernel_size=(3, 3))(x) + x = flax.linen.relu(x) + x = flax.linen.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) + x = flax.linen.Conv(features=64, kernel_size=(3, 3))(x) + x = flax.linen.relu(x) + x = flax.linen.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) + x = x.reshape((x.shape[0], -1)) # flatten + x = flax.linen.Dense(features=200)(x) + x = flax.linen.relu(x) + x = flax.linen.Dense(features=10)(x) + x = flax.linen.softmax(x) + return x + + def get_config(self): + return {} + + @classmethod + def from_config(cls, config): + return cls(**config) + + @object_registration.register_keras_serializable() + class FlaxDropoutModel(flax.linen.Module): + @flax.linen.compact + def my_apply(self, inputs, training): + x = inputs + x = flax.linen.Conv(features=32, kernel_size=(3, 3))(x) + x = flax.linen.relu(x) + x = flax.linen.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) + x = flax.linen.Conv(features=64, kernel_size=(3, 3))(x) + x = flax.linen.relu(x) + x = flax.linen.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) + x = x.reshape((x.shape[0], -1)) # flatten + x = flax.linen.Dense(features=200)(x) + x = flax.linen.Dropout(rate=0.3, deterministic=not training)(x) + x = flax.linen.relu(x) + x = flax.linen.Dense(features=10)(x) + x = flax.linen.softmax(x) + return x + + def get_config(self): + return {} + + @classmethod + def from_config(cls, config): + return cls(**config) + + @object_registration.register_keras_serializable() + def flax_dropout_wrapper(module, x, training): + return module.my_apply(x, training) + + @object_registration.register_keras_serializable() + class FlaxBatchNormModel(flax.linen.Module): + @flax.linen.compact + def __call__(self, inputs, training=False): + ura = not training + x = inputs + x = flax.linen.Conv( + features=12, kernel_size=(3, 3), use_bias=False + )(x) + x = flax.linen.BatchNorm(use_running_average=ura, use_scale=False)( + x + ) + x = flax.linen.relu(x) + x = flax.linen.Conv( + features=24, kernel_size=(6, 6), strides=(2, 2) + )(x) + x = flax.linen.BatchNorm(use_running_average=ura, use_scale=False)( + x + ) + x = flax.linen.relu(x) + x = flax.linen.Conv( + features=32, kernel_size=(6, 6), strides=(2, 2) + )(x) + x = flax.linen.BatchNorm(use_running_average=ura, use_scale=False)( + x + ) + x = x.reshape((x.shape[0], -1)) # flatten + x = flax.linen.Dense(features=200, use_bias=True)(x) + x = flax.linen.BatchNorm(use_running_average=ura, use_scale=False)( + x + ) + x = flax.linen.Dropout(rate=0.3, deterministic=not training)(x) + x = flax.linen.relu(x) + x = flax.linen.Dense(features=10)(x) + x = flax.linen.softmax(x) + return x + + def get_config(self): + return {} + + @classmethod + def from_config(cls, config): + return cls(**config) + + FLAX_OBJECTS = { + "FlaxTrainingIndependentModel": FlaxTrainingIndependentModel, + "FlaxBatchNormModel": FlaxBatchNormModel, + "FlaxDropoutModel": FlaxDropoutModel, + "flax_dropout_wrapper": flax_dropout_wrapper, + } + + +@pytest.mark.skipif( + backend.backend() != "jax", + reason="JaxLayer and FlaxLayer are only supported with JAX backend", +) +class TestJaxLayer(testing.TestCase): + def _test_layer( + self, + model_name, + layer_class, + layer_init_kwargs, + trainable_weights, + trainable_params, + non_trainable_weights, + non_trainable_params, + ): + # Fake MNIST data + x_train = np.random.uniform(size=(320, 28, 28, 1)) + y_train = np.eye(num_classes, dtype="int32")[ + (np.random.uniform(size=(320,)) * num_classes).astype("int32") + ] + x_test = np.random.uniform(size=(32, 28, 28, 1)) + + def _count_params(weights): + count = 0 + for weight in weights: + count = count + np.prod(weight.shape) + return count + + def verify_weights_and_params(layer): + self.assertEqual(trainable_weights, len(layer.trainable_weights)) + self.assertEqual( + trainable_params, + _count_params(layer.trainable_weights), + ) + self.assertEqual( + non_trainable_weights, len(layer.non_trainable_weights) + ) + self.assertEqual( + non_trainable_params, + _count_params(layer.non_trainable_weights), + ) + + # functional model + layer1 = layer_class(**layer_init_kwargs) + inputs1 = layers.Input(shape=input_shape) + outputs1 = layer1(inputs1) + model1 = models.Model( + inputs=inputs1, outputs=outputs1, name=f"{model_name}1" + ) + model1.summary() + + verify_weights_and_params(layer1) + + model1.compile( + loss="categorical_crossentropy", + optimizer="adam", + metrics=[metrics.CategoricalAccuracy()], + ) + + tw1_before_fit = tree.map_structure( + backend.convert_to_numpy, layer1.trainable_weights + ) + ntw1_before_fit = tree.map_structure( + backend.convert_to_numpy, layer1.non_trainable_weights + ) + model1.fit(x_train, y_train, epochs=1, steps_per_epoch=10) + tw1_after_fit = tree.map_structure( + backend.convert_to_numpy, layer1.trainable_weights + ) + ntw1_after_fit = tree.map_structure( + backend.convert_to_numpy, layer1.non_trainable_weights + ) + + # verify both trainable and non-trainable weights did change after fit + for before, after in zip(tw1_before_fit, tw1_after_fit): + self.assertNotAllClose(before, after) + for before, after in zip(ntw1_before_fit, ntw1_after_fit): + self.assertNotAllClose(before, after) + + expected_ouput_shape = (x_test.shape[0], num_classes) + output1 = model1(x_test) + self.assertEqual(output1.shape, expected_ouput_shape) + predict1 = model1.predict(x_test, steps=1) + self.assertEqual(predict1.shape, expected_ouput_shape) + + # verify both trainable and non-trainable weights did not change + tw1_after_call = tree.map_structure( + backend.convert_to_numpy, layer1.trainable_weights + ) + ntw1_after_call = tree.map_structure( + backend.convert_to_numpy, layer1.non_trainable_weights + ) + for after_fit, after_call in zip(tw1_after_fit, tw1_after_call): + self.assertAllClose(after_fit, after_call) + for after_fit, after_call in zip(ntw1_after_fit, ntw1_after_call): + self.assertAllClose(after_fit, after_call) + + exported_params = jax.tree_util.tree_map( + backend.convert_to_numpy, layer1.params + ) + if layer1.state is not None: + exported_state = jax.tree_util.tree_map( + backend.convert_to_numpy, layer1.state + ) + else: + exported_state = None + + def verify_identical_model(model): + output = model(x_test) + self.assertAllClose(output1, output) + + predict = model.predict(x_test, steps=1) + self.assertAllClose(predict1, predict) + + # sequential model to compare results + layer2 = layer_class( + params=exported_params, + state=exported_state, + input_shape=input_shape, + **layer_init_kwargs, + ) + model2 = models.Sequential([layer2], name=f"{model_name}2") + model2.summary() + verify_weights_and_params(layer2) + model2.compile( + loss="categorical_crossentropy", + optimizer="adam", + metrics=[metrics.CategoricalAccuracy()], + ) + verify_identical_model(model2) + + # save, load back and compare results + path = os.path.join(self.get_temp_dir(), "jax_layer_model.keras") + model2.save(path) + + model3 = saving.load_model(path) + layer3 = model3.layers[0] + model3.summary() + verify_weights_and_params(layer3) + verify_identical_model(model3) + + # export, load back and compare results + path = os.path.join(self.get_temp_dir(), "jax_layer_export") + model2.export(path, format="tf_saved_model") + model4 = tf.saved_model.load(path) + output4 = model4.serve(x_test) + # The output difference is greater when using the GPU or bfloat16 + lower_precision = testing.jax_uses_gpu() or "dtype" in layer_init_kwargs + self.assertAllClose( + output1, + output4, + atol=1e-2 if lower_precision else 1e-6, + rtol=1e-3 if lower_precision else 1e-6, + ) + + # test subclass model building without a build method + class TestModel(models.Model): + def __init__(self, layer): + super().__init__() + self._layer = layer + + def call(self, inputs): + return self._layer(inputs) + + layer5 = layer_class(**layer_init_kwargs) + model5 = TestModel(layer5) + output5 = model5(x_test) + self.assertNotAllClose(output5, 0.0) + + @parameterized.named_parameters( + { + "testcase_name": "training_independent", + "init_kwargs": { + "call_fn": jax_stateless_apply, + "init_fn": jax_stateless_init, + }, + "trainable_weights": 6, + "trainable_params": 266610, + "non_trainable_weights": 0, + "non_trainable_params": 0, + }, + { + "testcase_name": "training_state", + "init_kwargs": { + "call_fn": jax_stateful_apply, + "init_fn": jax_stateful_init, + }, + "trainable_weights": 6, + "trainable_params": 266610, + "non_trainable_weights": 1, + "non_trainable_params": 1, + }, + { + "testcase_name": "training_state_dtype_policy", + "init_kwargs": { + "call_fn": jax_stateful_apply, + "init_fn": jax_stateful_init, + "dtype": DTypePolicy("mixed_float16"), + }, + "trainable_weights": 6, + "trainable_params": 266610, + "non_trainable_weights": 1, + "non_trainable_params": 1, + }, + ) + def test_jax_layer( + self, + init_kwargs, + trainable_weights, + trainable_params, + non_trainable_weights, + non_trainable_params, + ): + self._test_layer( + init_kwargs["call_fn"].__name__, + JaxLayer, + init_kwargs, + trainable_weights, + trainable_params, + non_trainable_weights, + non_trainable_params, + ) + + @parameterized.named_parameters( + { + "testcase_name": "training_independent_bound_method", + "flax_model_class": "FlaxTrainingIndependentModel", + "flax_model_method": "forward", + "init_kwargs": {}, + "trainable_weights": 8, + "trainable_params": 648226, + "non_trainable_weights": 0, + "non_trainable_params": 0, + }, + { + "testcase_name": "training_rng_unbound_method", + "flax_model_class": "FlaxDropoutModel", + "flax_model_method": None, + "init_kwargs": { + "method": "flax_dropout_wrapper", + }, + "trainable_weights": 8, + "trainable_params": 648226, + "non_trainable_weights": 0, + "non_trainable_params": 0, + }, + { + "testcase_name": "training_rng_state_no_method", + "flax_model_class": "FlaxBatchNormModel", + "flax_model_method": None, + "init_kwargs": {}, + "trainable_weights": 13, + "trainable_params": 354258, + "non_trainable_weights": 8, + "non_trainable_params": 536, + }, + { + "testcase_name": "training_rng_unbound_method_dtype_policy", + "flax_model_class": "FlaxDropoutModel", + "flax_model_method": None, + "init_kwargs": { + "method": "flax_dropout_wrapper", + "dtype": DTypePolicy("mixed_float16"), + }, + "trainable_weights": 8, + "trainable_params": 648226, + "non_trainable_weights": 0, + "non_trainable_params": 0, + }, + ) + @pytest.mark.skipif(flax is None, reason="Flax library is not available.") + def test_flax_layer( + self, + flax_model_class, + flax_model_method, + init_kwargs, + trainable_weights, + trainable_params, + non_trainable_weights, + non_trainable_params, + ): + flax_model_class = FLAX_OBJECTS.get(flax_model_class) + if "method" in init_kwargs: + init_kwargs["method"] = FLAX_OBJECTS.get(init_kwargs["method"]) + + def create_wrapper(**kwargs): + params = kwargs.pop("params") if "params" in kwargs else None + state = kwargs.pop("state") if "state" in kwargs else None + if params and state: + variables = {**params, **state} + elif params: + variables = params + elif state: + variables = state + else: + variables = None + kwargs["variables"] = variables + flax_model = flax_model_class() + if flax_model_method: + kwargs["method"] = getattr(flax_model, flax_model_method) + return FlaxLayer(flax_model_class(), **kwargs) + + self._test_layer( + flax_model_class.__name__, + create_wrapper, + init_kwargs, + trainable_weights, + trainable_params, + non_trainable_weights, + non_trainable_params, + ) + + def test_with_no_init_fn_and_no_params(self): + def jax_fn(params, inputs): + return inputs + + with self.assertRaises(ValueError): + JaxLayer(jax_fn) + + def test_with_training_in_call_fn_but_not_init_fn(self): + def jax_call_fn(params, state, rng, inputs, training): + return inputs, {} + + def jax_init_fn(rng, inputs): + return {}, {} + + layer = JaxLayer(jax_call_fn, jax_init_fn) + layer(np.ones((1,))) + + def test_with_different_argument_order(self): + def jax_call_fn(training, inputs, rng, state, params): + return inputs, {} + + def jax_init_fn(training, inputs, rng): + return {}, {} + + layer = JaxLayer(jax_call_fn, jax_init_fn) + layer(np.ones((1,))) + + def test_with_minimal_arguments(self): + def jax_call_fn(inputs): + return inputs + + def jax_init_fn(inputs): + return {} + + layer = JaxLayer(jax_call_fn, jax_init_fn) + layer(np.ones((1,))) + + def test_with_missing_inputs_in_call_fn(self): + def jax_call_fn(params, rng, training): + return jnp.ones((1,)) + + def jax_init_fn(rng, inputs): + return {} + + with self.assertRaisesRegex(ValueError, "`call_fn`.*`inputs`"): + JaxLayer(jax_call_fn, jax_init_fn) + + def test_with_missing_inputs_in_init_fn(self): + def jax_call_fn(params, rng, inputs, training): + return jnp.ones((1,)) + + def jax_init_fn(rng, training): + return {} + + with self.assertRaisesRegex(ValueError, "`init_fn`.*`inputs`"): + JaxLayer(jax_call_fn, jax_init_fn) + + def test_with_unsupported_argument_in_call_fn(self): + def jax_call_fn(params, rng, inputs, mode): + return jnp.ones((1,)) + + def jax_init_fn(rng, inputs): + return {} + + with self.assertRaisesRegex(ValueError, "`call_fn`.*`mode`"): + JaxLayer(jax_call_fn, jax_init_fn) + + def test_with_unsupported_argument_in_init_fn(self): + def jax_call_fn(params, rng, inputs, training): + return inputs + + def jax_init_fn(rng, inputs, mode): + return {} + + with self.assertRaisesRegex(ValueError, "`init_fn`.*`mode`"): + JaxLayer(jax_call_fn, jax_init_fn) + + def test_with_structures_as_inputs_and_outputs(self): + def jax_fn(params, inputs): + a = inputs["a"] + b = inputs["b"] + output1 = jnp.concatenate([a, b], axis=1) + output2 = jnp.concatenate([b, a], axis=1) + return output1, output2 + + layer = JaxLayer(jax_fn, params={}) + inputs = { + "a": layers.Input((None, 3)), + "b": layers.Input((None, 3)), + } + outputs = layer(inputs) + model = models.Model(inputs, outputs) + + test_inputs = { + "a": np.ones((2, 6, 3)), + "b": np.ones((2, 7, 3)), + } + test_outputs = model(test_inputs) + self.assertAllClose(test_outputs[0], np.ones((2, 13, 3))) + self.assertAllClose(test_outputs[1], np.ones((2, 13, 3))) + + def test_with_polymorphic_shape_more_than_26_dimension_names(self): + def jax_fn(params, inputs): + return jnp.concatenate(inputs, axis=1) + + layer = JaxLayer(jax_fn, params=()) + inputs = [layers.Input((None, 3)) for _ in range(60)] + output = layer(inputs) + model = models.Model(inputs, output) + + test_inputs = [np.ones((2, 1, 3))] * 60 + test_output = model(test_inputs) + self.assertAllClose(test_output, np.ones((2, 60, 3))) + + @pytest.mark.skipif(flax is None, reason="Flax library is not available.") + def test_with_flax_state_no_params(self): + class MyFlaxLayer(flax.linen.Module): + @flax.linen.compact + def __call__(self, x): + def zeros_init(shape): + return jnp.zeros(shape, jnp.int32) + + count = self.variable("a", "b", zeros_init, []) + count.value = count.value + 1 + return x + + layer = FlaxLayer(MyFlaxLayer(), variables={"a": {"b": 0}}) + layer(np.ones((1,))) + self.assertLen(layer.params, 0) + self.assertEqual(layer.state["a"]["b"].value, 1) + + def test_with_state_none_leaves(self): + def jax_fn(params, state, inputs): + return inputs, state + + layer = JaxLayer(jax_fn, state={"foo": None}) + self.assertIsNone(layer.state["foo"]) + layer(np.ones((1,))) + + def test_with_state_non_tensor_leaves(self): + def jax_fn(params, state, inputs): + return inputs, state + + layer = JaxLayer(jax_fn, state={"foo": "bar"}) + self.assertEqual(layer.state["foo"], "bar") + # layer cannot be invoked as jax2tf will fail on strings + + def test_with_state_jax_registered_node_class(self): + @jax.tree_util.register_pytree_node_class + class NamedPoint: + def __init__(self, x, y, name): + self.x = x + self.y = y + self.name = name + + def tree_flatten(self): + return ((self.x, self.y), self.name) + + @classmethod + def tree_unflatten(cls, aux_data, children): + return cls(*children, aux_data) + + def jax_fn(params, state, inputs): + return inputs, state + + layer = JaxLayer(jax_fn, state=[NamedPoint(1.0, 2.0, "foo")]) + layer(np.ones((1,))) + + @parameterized.named_parameters( + { + "testcase_name": "sequence_instead_of_mapping", + "init_state": [0.0], + "error_regex": "Expected dict, got ", + }, + { + "testcase_name": "mapping_instead_of_sequence", + "init_state": {"state": {"foo": 0.0}}, + "error_regex": "Expected list, got ", + }, + { + "testcase_name": "sequence_instead_of_variable", + "init_state": {"state": [[0.0]]}, + "error_regex": "Structure mismatch", + }, + { + "testcase_name": "no_initial_state", + "init_state": None, + "error_regex": "Expected dict, got None", + }, + { + "testcase_name": "missing_dict_key", + "init_state": {"state": {}}, + "error_regex": "Expected list, got ", + }, + { + "testcase_name": "missing_variable_in_list", + "init_state": {"state": {"foo": [2.0]}}, + "error_regex": "Expected list, got ", + }, + ) + def test_state_mismatch_during_update(self, init_state, error_regex): + def jax_fn(params, state, inputs): + return inputs, {"state": [jnp.ones([])]} + + layer = JaxLayer(jax_fn, params={}, state=init_state) + with self.assertRaisesRegex(ValueError, error_regex): + layer(np.ones((1,))) + + def test_rng_seeding(self): + def jax_init(rng, inputs): + return [jax.nn.initializers.normal(1.0)(rng, inputs.shape)] + + def jax_apply(params, inputs): + return jnp.dot(inputs, params[0]) + + shape = (2, 2) + + utils.set_random_seed(0) + layer1 = JaxLayer(jax_apply, jax_init) + layer1.build(shape) + utils.set_random_seed(0) + layer2 = JaxLayer(jax_apply, jax_init) + layer2.build(shape) + self.assertAllClose(layer1.params[0], layer2.params[0]) diff --git a/keras/src/utils/jax_utils.py b/keras/src/utils/jax_utils.py new file mode 100644 index 000000000000..d5375785f762 --- /dev/null +++ b/keras/src/utils/jax_utils.py @@ -0,0 +1,11 @@ +from keras.src import backend + + +def is_in_jax_tracing_scope(x=None): + if backend.backend() == "jax": + if x is None: + x = backend.numpy.ones(()) + for c in x.__class__.__mro__: + if c.__name__ == "Tracer" and c.__module__.startswith("jax"): + return True + return False diff --git a/keras/src/utils/model_visualization.py b/keras/src/utils/model_visualization.py new file mode 100644 index 000000000000..fb5ec22ceaa4 --- /dev/null +++ b/keras/src/utils/model_visualization.py @@ -0,0 +1,520 @@ +"""Utilities related to model visualization.""" + +import os +import sys + +from keras.src import tree +from keras.src.api_export import keras_export +from keras.src.utils import io_utils + +try: + import pydot +except ImportError: + # pydot_ng and pydotplus are older forks of pydot + # which may still be used by some users + try: + import pydot_ng as pydot + except ImportError: + try: + import pydotplus as pydot + except ImportError: + pydot = None + + +def check_pydot(): + """Returns True if PyDot is available.""" + return pydot is not None + + +def check_graphviz(): + """Returns True if both PyDot and Graphviz are available.""" + if not check_pydot(): + return False + try: + # Attempt to create an image of a blank graph + # to check the pydot/graphviz installation. + pydot.Dot.create(pydot.Dot()) + return True + except (OSError, pydot.PydotException): + return False + + +def add_edge(dot, src, dst): + src_id = str(id(src)) + dst_id = str(id(dst)) + if not dot.get_edge(src_id, dst_id): + edge = pydot.Edge(src_id, dst_id) + edge.set("penwidth", "2") + dot.add_edge(edge) + + +def get_layer_activation_name(layer): + if hasattr(layer.activation, "name"): + activation_name = layer.activation.name + elif hasattr(layer.activation, "__name__"): + activation_name = layer.activation.__name__ + else: + activation_name = str(layer.activation) + return activation_name + + +def make_layer_label(layer, **kwargs): + class_name = layer.__class__.__name__ + + show_layer_names = kwargs.pop("show_layer_names") + show_layer_activations = kwargs.pop("show_layer_activations") + show_dtype = kwargs.pop("show_dtype") + show_shapes = kwargs.pop("show_shapes") + show_trainable = kwargs.pop("show_trainable") + if kwargs: + raise ValueError(f"Invalid kwargs: {kwargs}") + + table = ( + '<' + ) + + colspan_max = sum(int(x) for x in (show_dtype, show_trainable)) + if show_shapes: + colspan_max += 2 + colspan = max(1, colspan_max) + + if show_layer_names: + table += ( + f'" + ) + else: + table += ( + f'" + ) + if ( + show_layer_activations + and hasattr(layer, "activation") + and layer.activation is not None + ): + table += ( + f'" + ) + + cols = [] + if show_shapes: + input_shape = None + output_shape = None + try: + input_shape = tree.map_structure(lambda x: x.shape, layer.input) + output_shape = tree.map_structure(lambda x: x.shape, layer.output) + except (ValueError, AttributeError): + pass + + def format_shape(shape): + if shape is not None: + if isinstance(shape, dict): + shape_str = ", ".join( + [f"{k}: {v}" for k, v in shape.items()] + ) + else: + shape_str = f"{shape}" + shape_str = shape_str.replace("}", "").replace("{", "") + else: + shape_str = "?" + return shape_str + + if class_name != "InputLayer": + cols.append( + ( + '" + ) + ) + cols.append( + ( + '" + ) + ) + if show_dtype: + dtype = None + try: + dtype = tree.map_structure(lambda x: x.dtype, layer.output) + except (ValueError, AttributeError): + pass + cols.append( + ( + '" + ) + ) + if show_trainable and hasattr(layer, "trainable") and layer.weights: + if layer.trainable: + cols.append( + ( + '" + ) + ) + else: + cols.append( + ( + '" + ) + ) + if cols: + colspan = len(cols) + else: + colspan = 1 + + if cols: + table += f"{''.join(cols)}" + table += "
' + '' + f"{layer.name} ({class_name})" + "
' + '' + f"{class_name}" + "
' + '' + f"Activation: {get_layer_activation_name(layer)}" + "
' + f"Input shape: {format_shape(input_shape)}" + "' + f"Output shape: {format_shape(output_shape)}" + "' + f"Output dtype: {dtype or '?'}" + "' + '' + "Trainable' + '' + "Non-trainable
>" + return table + + +def make_node(layer, **kwargs): + node = pydot.Node(str(id(layer)), label=make_layer_label(layer, **kwargs)) + node.set("fontname", "Helvetica") + node.set("border", "0") + node.set("margin", "0") + return node + + +@keras_export("keras.utils.model_to_dot") +def model_to_dot( + model, + show_shapes=False, + show_dtype=False, + show_layer_names=True, + rankdir="TB", + expand_nested=False, + dpi=200, + subgraph=False, + show_layer_activations=False, + show_trainable=False, + **kwargs, +): + """Convert a Keras model to dot format. + + Args: + model: A Keras model instance. + show_shapes: whether to display shape information. + show_dtype: whether to display layer dtypes. + show_layer_names: whether to display layer names. + rankdir: `rankdir` argument passed to PyDot, + a string specifying the format of the plot: `"TB"` + creates a vertical plot; `"LR"` creates a horizontal plot. + expand_nested: whether to expand nested Functional models + into clusters. + dpi: Image resolution in dots per inch. + subgraph: whether to return a `pydot.Cluster` instance. + show_layer_activations: Display layer activations (only for layers that + have an `activation` property). + show_trainable: whether to display if a layer is trainable. + + Returns: + A `pydot.Dot` instance representing the Keras model or + a `pydot.Cluster` instance representing nested model if + `subgraph=True`. + """ + from keras.src.ops.function import make_node_key + + if not model.built: + raise ValueError( + "This model has not yet been built. " + "Build the model first by calling `build()` or by calling " + "the model on a batch of data." + ) + + from keras.src.models import functional + from keras.src.models import sequential + + # from keras.src.layers import Wrapper + + if not check_pydot(): + raise ImportError( + "You must install pydot (`pip install pydot`) for " + "model_to_dot to work." + ) + + if subgraph: + dot = pydot.Cluster(style="dashed", graph_name=model.name) + dot.set("label", model.name) + dot.set("labeljust", "l") + else: + dot = pydot.Dot() + dot.set("rankdir", rankdir) + dot.set("concentrate", True) + dot.set("dpi", dpi) + dot.set("splines", "ortho") + dot.set_node_defaults(shape="record") + + if kwargs.pop("layer_range", None) is not None: + raise ValueError("Argument `layer_range` is no longer supported.") + if kwargs: + raise ValueError(f"Unrecognized keyword arguments: {kwargs}") + + kwargs = { + "show_layer_names": show_layer_names, + "show_layer_activations": show_layer_activations, + "show_dtype": show_dtype, + "show_shapes": show_shapes, + "show_trainable": show_trainable, + } + + if isinstance(model, sequential.Sequential): + layers = model.layers + elif not isinstance(model, functional.Functional): + # We treat subclassed models as a single node. + node = make_node(model, **kwargs) + dot.add_node(node) + return dot + else: + layers = model._operations + + # Create graph nodes. + for i, layer in enumerate(layers): + # Process nested functional and sequential models. + if expand_nested and isinstance( + layer, (functional.Functional, sequential.Sequential) + ): + submodel = model_to_dot( + layer, + show_shapes, + show_dtype, + show_layer_names, + rankdir, + expand_nested, + subgraph=True, + show_layer_activations=show_layer_activations, + show_trainable=show_trainable, + ) + dot.add_subgraph(submodel) + + else: + node = make_node(layer, **kwargs) + dot.add_node(node) + + # Connect nodes with edges. + if isinstance(model, sequential.Sequential): + if not expand_nested: + # Single Sequential case. + for i in range(len(layers) - 1): + add_edge(dot, layers[i], layers[i + 1]) + return dot + else: + # The first layer is connected to the `InputLayer`, which is not + # represented for Sequential models, so we skip it. What will draw + # the incoming edge from outside of the sequential model is the + # edge connecting the Sequential model itself. + layers = model.layers[1:] + + # Functional and nested Sequential case. + for layer in layers: + # Go from current layer to input `Node`s. + for inbound_index, inbound_node in enumerate(layer._inbound_nodes): + # `inbound_node` is a `Node`. + if ( + isinstance(model, functional.Functional) + and make_node_key(layer, inbound_index) not in model._nodes + ): + continue + + # Go from input `Node` to `KerasTensor` representing that input. + for input_index, input_tensor in enumerate( + inbound_node.input_tensors + ): + # `input_tensor` is a `KerasTensor`. + # `input_history` is a `KerasHistory`. + input_history = input_tensor._keras_history + if input_history.operation is None: + # Operation is `None` for `Input` tensors. + continue + + # Go from input `KerasTensor` to the `Operation` that produced + # it as an output. + input_node = input_history.operation._inbound_nodes[ + input_history.node_index + ] + output_index = input_history.tensor_index + + # Tentative source and destination of the edge. + source = input_node.operation + destination = layer + + if not expand_nested: + # No nesting, connect directly. + add_edge(dot, source, layer) + continue + + # ==== Potentially nested models case ==== + + # ---- Resolve the source of the edge ---- + while isinstance( + source, + (functional.Functional, sequential.Sequential), + ): + # When `source` is a `Functional` or `Sequential` model, we + # need to connect to the correct box within that model. + # Functional and sequential models do not have explicit + # "output" boxes, so we need to find the correct layer that + # produces the output we're connecting to, which can be + # nested several levels deep in sub-models. Hence the while + # loop to continue going into nested models until we + # encounter a real layer that's not a `Functional` or + # `Sequential`. + source, _, output_index = source.outputs[ + output_index + ]._keras_history + + # ---- Resolve the destination of the edge ---- + while isinstance( + destination, + (functional.Functional, sequential.Sequential), + ): + if isinstance(destination, functional.Functional): + # When `destination` is a `Functional`, we point to the + # specific `InputLayer` in the model. + destination = destination.inputs[ + input_index + ]._keras_history.operation + else: + # When `destination` is a `Sequential`, there is no + # explicit "input" box, so we want to point to the first + # box in the model, but it may itself be another model. + # Hence the while loop to continue going into nested + # models until we encounter a real layer that's not a + # `Functional` or `Sequential`. + destination = destination.layers[0] + + add_edge(dot, source, destination) + return dot + + +@keras_export("keras.utils.plot_model") +def plot_model( + model, + to_file="model.png", + show_shapes=False, + show_dtype=False, + show_layer_names=False, + rankdir="TB", + expand_nested=False, + dpi=200, + show_layer_activations=False, + show_trainable=False, + **kwargs, +): + """Converts a Keras model to dot format and save to a file. + + Example: + + ```python + inputs = ... + outputs = ... + model = keras.Model(inputs=inputs, outputs=outputs) + + dot_img_file = '/tmp/model_1.png' + keras.utils.plot_model(model, to_file=dot_img_file, show_shapes=True) + ``` + + Args: + model: A Keras model instance + to_file: File name of the plot image. + show_shapes: whether to display shape information. + show_dtype: whether to display layer dtypes. + show_layer_names: whether to display layer names. + rankdir: `rankdir` argument passed to PyDot, + a string specifying the format of the plot: `"TB"` + creates a vertical plot; `"LR"` creates a horizontal plot. + expand_nested: whether to expand nested Functional models + into clusters. + dpi: Image resolution in dots per inch. + show_layer_activations: Display layer activations (only for layers that + have an `activation` property). + show_trainable: whether to display if a layer is trainable. + + Returns: + A Jupyter notebook Image object if Jupyter is installed. + This enables in-line display of the model plots in notebooks. + """ + + if not model.built: + raise ValueError( + "This model has not yet been built. " + "Build the model first by calling `build()` or by calling " + "the model on a batch of data." + ) + if not check_pydot(): + message = ( + "You must install pydot (`pip install pydot`) " + "for `plot_model` to work." + ) + if "IPython.core.magics.namespace" in sys.modules: + # We don't raise an exception here in order to avoid crashing + # notebook tests where graphviz is not available. + io_utils.print_msg(message) + return + else: + raise ImportError(message) + if not check_graphviz(): + message = ( + "You must install graphviz " + "(see instructions at https://graphviz.gitlab.io/download/) " + "for `plot_model` to work." + ) + if "IPython.core.magics.namespace" in sys.modules: + # We don't raise an exception here in order to avoid crashing + # notebook tests where graphviz is not available. + io_utils.print_msg(message) + return + else: + raise ImportError(message) + + if kwargs.pop("layer_range", None) is not None: + raise ValueError("Argument `layer_range` is no longer supported.") + if kwargs: + raise ValueError(f"Unrecognized keyword arguments: {kwargs}") + + dot = model_to_dot( + model, + show_shapes=show_shapes, + show_dtype=show_dtype, + show_layer_names=show_layer_names, + rankdir=rankdir, + expand_nested=expand_nested, + dpi=dpi, + show_layer_activations=show_layer_activations, + show_trainable=show_trainable, + ) + to_file = str(to_file) + if dot is None: + return + _, extension = os.path.splitext(to_file) + if not extension: + extension = "png" + else: + extension = extension[1:] + # Save image to disk. + dot.write(to_file, format=extension) + # Return the image as a Jupyter Image object, to be displayed in-line. + # Note that we cannot easily detect whether the code is running in a + # notebook, and thus we always return the Image if Jupyter is available. + if extension != "pdf": + try: + from IPython import display + + return display.Image(filename=to_file) + except ImportError: + pass diff --git a/keras/src/utils/module_utils.py b/keras/src/utils/module_utils.py new file mode 100644 index 000000000000..286394a99358 --- /dev/null +++ b/keras/src/utils/module_utils.py @@ -0,0 +1,61 @@ +import importlib + + +class LazyModule: + def __init__(self, name, pip_name=None, import_error_msg=None): + self.name = name + self.pip_name = pip_name or name + self.import_error_msg = import_error_msg or ( + f"This requires the {self.name} module. " + f"You can install it via `pip install {self.pip_name}`" + ) + self.module = None + self._available = None + + @property + def available(self): + if self._available is None: + try: + self.initialize() + self._available = True + except ImportError: + self._available = False + return self._available + + def initialize(self): + try: + self.module = importlib.import_module(self.name) + except ImportError: + raise ImportError(self.import_error_msg) + + def __getattr__(self, name): + if name == "_api_export_path": + raise AttributeError + if self.module is None: + self.initialize() + return getattr(self.module, name) + + def __repr__(self): + return f"LazyModule({self.name})" + + +tensorflow = LazyModule("tensorflow") +gfile = LazyModule("tensorflow.io.gfile", pip_name="tensorflow") +tensorflow_io = LazyModule("tensorflow_io") +scipy = LazyModule("scipy") +jax = LazyModule("jax") +torch_xla = LazyModule( + "torch_xla", + import_error_msg=( + "This requires the torch_xla module. You can install it via " + "`pip install torch-xla`. Additionally, you may need to update " + "LD_LIBRARY_PATH if necessary. Torch XLA builds a shared library, " + "_XLAC.so, which needs to link to the version of Python it was built " + "with. Use the following command to update LD_LIBRARY_PATH: " + "`export LD_LIBRARY_PATH=/lib:$LD_LIBRARY_PATH`" + ), +) +optree = LazyModule("optree") +dmtree = LazyModule("tree") +tf2onnx = LazyModule("tf2onnx") +grain = LazyModule("grain") diff --git a/keras/src/utils/naming.py b/keras/src/utils/naming.py new file mode 100644 index 000000000000..28107f0f30f4 --- /dev/null +++ b/keras/src/utils/naming.py @@ -0,0 +1,73 @@ +import collections +import re + +from keras.src.api_export import keras_export +from keras.src.backend.common import global_state + + +def auto_name(prefix): + prefix = to_snake_case(prefix) + return uniquify(prefix) + + +def uniquify(name): + object_name_uids = global_state.get_global_attribute( + "object_name_uids", + default=collections.defaultdict(int), + set_to_default=True, + ) + if name in object_name_uids: + unique_name = f"{name}_{object_name_uids[name]}" + else: + unique_name = name + object_name_uids[name] += 1 + return unique_name + + +def to_snake_case(name): + name = re.sub(r"\W+", "", name) + name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) + name = re.sub("([a-z])([A-Z])", r"\1_\2", name).lower() + return name + + +@keras_export("keras.backend.get_uid") +def get_uid(prefix=""): + """Associates a string prefix with an integer counter. + + Args: + prefix: String prefix to index. + + Returns: + Unique integer ID. + + Example: + + >>> get_uid('dense') + 1 + >>> get_uid('dense') + 2 + """ + object_name_uids = global_state.get_global_attribute( + "object_name_uids", + default=collections.defaultdict(int), + set_to_default=True, + ) + object_name_uids[prefix] += 1 + return object_name_uids[prefix] + + +def reset_uids(): + global_state.set_global_attribute( + "object_name_uids", collections.defaultdict(int) + ) + + +def get_object_name(obj): + if hasattr(obj, "name"): # Most Keras objects. + return obj.name + elif hasattr(obj, "__name__"): # Function. + return to_snake_case(obj.__name__) + elif hasattr(obj, "__class__"): # Class instance. + return to_snake_case(obj.__class__.__name__) + return to_snake_case(str(obj)) diff --git a/keras/src/utils/naming_test.py b/keras/src/utils/naming_test.py new file mode 100644 index 000000000000..25adc45885d5 --- /dev/null +++ b/keras/src/utils/naming_test.py @@ -0,0 +1,119 @@ +from keras.src.testing import test_case +from keras.src.utils import naming + + +class NamingUtilsTest(test_case.TestCase): + def test_uniquify_unique_name(self): + name = "the_unique_name" + unique_name = naming.uniquify(name) + self.assertEqual(unique_name, name) + + def test_auto_name(self): + self.assertEqual(naming.auto_name("unique_name"), "unique_name") + self.assertEqual(naming.auto_name("unique_name"), "unique_name_1") + self.assertEqual(naming.auto_name("unique_name"), "unique_name_2") + + def test_get_uid(self): + self.assertEqual(naming.get_uid("very_unique_name"), 1) + self.assertEqual(naming.get_uid("very_unique_name"), 2) + self.assertEqual(naming.get_uid("very_unique_name"), 3) + + def test_uniquify_non_unique_name(self): + name = "non_unique_name" + naming.uniquify(name) + unique_name = naming.uniquify(name) + self.assertEqual(unique_name, f"{name}_1") + + def test_to_snake_case_snake_case_name(self): + name = "snake_case_name" + snake_case_name = naming.to_snake_case(name) + self.assertEqual(snake_case_name, name) + + def test_get_uid_existing_prefix(self): + prefix = "existing_prefix" + naming.get_uid(prefix) + uid = naming.get_uid(prefix) + self.assertEqual(uid, 2) + + def test_reset_uids(self): + naming.get_uid("unique_name") + naming.reset_uids() + uid = naming.get_uid("unique_name") + self.assertEqual(uid, 1) + + def test_get_object_name_no_name_attribute(self): + class ObjectWithoutName: + __name__ = "ObjectWithoutName" + + obj = ObjectWithoutName() + object_name = naming.get_object_name(obj) + self.assertEqual(object_name, "object_without_name") + + def test_get_object_name_no_name_or_class_attribute(self): + class ObjectWithoutNameOrClass: + pass + + obj = ObjectWithoutNameOrClass() + object_name = naming.get_object_name(obj) + self.assertEqual(object_name, "object_without_name_or_class") + + def test_uniquify_already_uniquified_name(self): + name = "unique_name" + unique_name = naming.uniquify(name) + new_unique_name = naming.uniquify(unique_name) + + # first time `name` is uniquified so returns same name + self.assertEqual(name, unique_name) + + # second time `name` is uniquified should be different + # from the first output + self.assertNotEqual(new_unique_name, unique_name) + + def test_to_snake_case_capital_after_any_character(self): + name = "myVariableNameHere" + snake_case_name = naming.to_snake_case(name) + self.assertEqual(snake_case_name, "my_variable_name_here") + + def test_to_snake_case_lower_before_upper(self): + name = "convertTHIS" + snake_case_name = naming.to_snake_case(name) + self.assertEqual(snake_case_name, "convert_this") + + def test_to_snake_case_already_snake_cased(self): + name = "already_snake_cased" + snake_case_name = naming.to_snake_case(name) + self.assertEqual(snake_case_name, name) + + def test_to_snake_case_no_changes(self): + name = "lowercase" + snake_case_name = naming.to_snake_case(name) + self.assertEqual(snake_case_name, name) + + def test_to_snake_case_single_uppercase_word(self): + name = "UPPERCASE" + snake_case_name = naming.to_snake_case(name) + self.assertEqual(snake_case_name, "uppercase") + + def test_get_object_name_for_keras_objects(self): + class MockKerasObject: + name = "mock_object" + + obj = MockKerasObject() + result = naming.get_object_name(obj) + self.assertEqual( + result, "mock_object", f"Expected 'mock_object' but got {result}" + ) + + # Test for function objects that have a `__name__` attribute. + def test_get_object_name_for_functions(self): + def mock_function(): + pass + + result = naming.get_object_name(mock_function) + # Assumes to_snake_case works correctly. + expected_name = naming.to_snake_case(mock_function.__name__) + self.assertEqual( + result, + expected_name, + f"Expected '{expected_name}' but got {result}", + ) diff --git a/keras/src/utils/numerical_utils.py b/keras/src/utils/numerical_utils.py new file mode 100644 index 000000000000..7a04299f13c3 --- /dev/null +++ b/keras/src/utils/numerical_utils.py @@ -0,0 +1,224 @@ +import numpy as np + +from keras.src import backend +from keras.src.api_export import keras_export +from keras.src.utils import tf_utils + + +@keras_export("keras.utils.normalize") +def normalize(x, axis=-1, order=2): + """Normalizes an array. + + If the input is a NumPy array, a NumPy array will be returned. + If it's a backend tensor, a backend tensor will be returned. + + Args: + x: Array to normalize. + axis: axis along which to normalize. + order: Normalization order (e.g. `order=2` for L2 norm). + + Returns: + A normalized copy of the array. + """ + from keras.src import ops + + if isinstance(x, np.ndarray): + # NumPy input + norm = np.atleast_1d(np.linalg.norm(x, order, axis)) + norm[norm == 0] = 1 + + # axis cannot be `None` + axis = axis or -1 + return x / np.expand_dims(norm, axis) + + # Backend tensor input + return ops.nn.normalize(x, axis=axis, order=order) + + +@keras_export("keras.utils.to_categorical") +def to_categorical(x, num_classes=None): + """Converts a class vector (integers) to binary class matrix. + + E.g. for use with `categorical_crossentropy`. + + Args: + x: Array-like with class values to be converted into a matrix + (integers from 0 to `num_classes - 1`). + num_classes: Total number of classes. If `None`, this would be inferred + as `max(x) + 1`. Defaults to `None`. + + Returns: + A binary matrix representation of the input as a NumPy array. The class + axis is placed last. + + Example: + + >>> a = keras.utils.to_categorical([0, 1, 2, 3], num_classes=4) + >>> print(a) + [[1. 0. 0. 0.] + [0. 1. 0. 0.] + [0. 0. 1. 0.] + [0. 0. 0. 1.]] + + >>> b = np.array([.9, .04, .03, .03, + ... .3, .45, .15, .13, + ... .04, .01, .94, .05, + ... .12, .21, .5, .17]).reshape(4,4) + >>> loss = keras.ops.categorical_crossentropy(a, b) + >>> print(np.around(loss, 5)) + [0.10536 0.82807 0.1011 1.77196] + + >>> loss = keras.ops.categorical_crossentropy(a, a) + >>> print(np.around(loss, 5)) + [0. 0. 0. 0.] + """ + if backend.is_tensor(x): + input_shape = backend.core.shape(x) + # Shrink the last dimension if the shape is (..., 1). + if ( + input_shape is not None + and len(input_shape) > 1 + and input_shape[-1] == 1 + ): + newshape = tuple(input_shape[:-1]) + x = backend.numpy.reshape(x, newshape) + return backend.nn.one_hot(x, num_classes) + x = np.array(x, dtype="int64") + input_shape = x.shape + + # Shrink the last dimension if the shape is (..., 1). + if input_shape and input_shape[-1] == 1 and len(input_shape) > 1: + input_shape = tuple(input_shape[:-1]) + + x = x.reshape(-1) + if not num_classes: + num_classes = np.max(x) + 1 + batch_size = x.shape[0] + categorical = np.zeros((batch_size, num_classes)) + categorical[np.arange(batch_size), x] = 1 + output_shape = input_shape + (num_classes,) + categorical = np.reshape(categorical, output_shape) + return categorical + + +def encode_categorical_inputs( + inputs, + output_mode, + depth, + dtype, + sparse=False, + count_weights=None, + backend_module=None, +): + """Encodes categorical inputs according to output_mode. + + Args: + inputs: the inputs to encode. + output_mode: one of `"int"`, `"one_hot"`, `"multi_hot"`, or `"count"`. + depth: number of classes, this will be the last dimension of the output. + dtype: the dtype of the output, unless `count_weights` is not `None`. + sparse: whether the output should be sparse for backends supporting it. + count_weights: weights to apply if `output_mode` is `"count"`. + backend_module: the backend to use instead of the current one. + + Returns: the encoded inputs. + """ + backend_module = backend_module or backend + + if output_mode == "int": + return backend_module.cast(inputs, dtype=dtype) + + rank_of_inputs = len(backend_module.shape(inputs)) + + # In all cases, we should uprank scalar input to a single sample. + if rank_of_inputs == 0: + inputs = backend_module.numpy.expand_dims(inputs, -1) + rank_of_inputs = 1 + + if ( + backend_module.__name__.endswith("tensorflow") + and rank_of_inputs <= 2 + and output_mode in ("multi_hot", "count") + ): + # TF only fastpath. Uses bincount; faster. Doesn't work for rank 3+. + try: + return tf_utils.tf_encode_categorical_inputs( + inputs, + output_mode, + depth, + dtype=dtype, + sparse=sparse, + count_weights=count_weights, + ) + except ValueError: + pass + + if output_mode == "multi_hot": + return backend_module.nn.multi_hot( + inputs, depth, dtype=dtype, sparse=sparse + ) + elif output_mode == "one_hot": + input_shape = backend_module.core.shape(inputs) + # Shrink the last dimension if the shape is (..., 1). + if ( + input_shape is not None + and len(input_shape) > 1 + and input_shape[-1] == 1 + ): + newshape = tuple(input_shape[:-1]) + inputs = backend_module.numpy.reshape(inputs, newshape) + return backend_module.nn.one_hot( + inputs, depth, dtype=dtype, sparse=sparse + ) + elif output_mode == "count": + # We don't use `ops.bincount` because its output has a dynamic shape + # (last dimension is the highest value of `inputs`). We implement a + # narrower use case where `minlength` and `maxlength` (not supported by + # `ops.bincount`) are the same and static value: `depth`. We also don't + # need to support indices that are negative or greater than `depth`. + reduction_axis = 1 if len(inputs.shape) > 1 else 0 + + if count_weights is not None: + dtype = count_weights.dtype + one_hot_encoding = backend_module.nn.one_hot( + inputs, depth, dtype=dtype, sparse=sparse + ) + if count_weights is not None: + count_weights = backend_module.numpy.expand_dims(count_weights, -1) + one_hot_encoding = one_hot_encoding * count_weights + + outputs = backend_module.numpy.sum( + one_hot_encoding, + axis=reduction_axis, + ) + return outputs + + +def build_pos_neg_masks( + query_labels, + key_labels, + remove_diagonal=True, +): + from keras.src import ops + + if ops.ndim(query_labels) == 1: + query_labels = ops.reshape(query_labels, (-1, 1)) + + if ops.ndim(key_labels) == 1: + key_labels = ops.reshape(key_labels, (-1, 1)) + + positive_mask = ops.equal(query_labels, ops.transpose(key_labels)) + negative_mask = ops.logical_not(positive_mask) + + if remove_diagonal: + positive_mask = ops.logical_and( + positive_mask, + ~ops.eye( + ops.size(query_labels), + ops.size(key_labels), + k=0, + dtype="bool", + ), + ) + + return positive_mask, negative_mask diff --git a/keras/src/utils/numerical_utils_test.py b/keras/src/utils/numerical_utils_test.py new file mode 100644 index 000000000000..9b9520abc90e --- /dev/null +++ b/keras/src/utils/numerical_utils_test.py @@ -0,0 +1,151 @@ +import numpy as np +from absl.testing import parameterized + +from keras.src import backend +from keras.src import testing +from keras.src.utils import numerical_utils + +NUM_CLASSES = 5 + + +class TestNumericalUtils(testing.TestCase): + @parameterized.parameters( + [ + ((1,), (1, NUM_CLASSES)), + ((3,), (3, NUM_CLASSES)), + ((4, 3), (4, 3, NUM_CLASSES)), + ((5, 4, 3), (5, 4, 3, NUM_CLASSES)), + ((3, 1), (3, NUM_CLASSES)), + ((3, 2, 1), (3, 2, NUM_CLASSES)), + ] + ) + def test_to_categorical(self, shape, expected_shape): + label = np.random.randint(0, NUM_CLASSES, shape) + one_hot = numerical_utils.to_categorical(label, NUM_CLASSES) + # Check shape + self.assertEqual(one_hot.shape, expected_shape) + # Make sure there is only one 1 in a row + self.assertTrue(np.all(one_hot.sum(axis=-1) == 1)) + # Get original labels back from one hots + self.assertTrue( + np.all(np.argmax(one_hot, -1).reshape(label.shape) == label) + ) + + def test_to_categorical_without_num_classes(self): + label = [0, 2, 5] + one_hot = numerical_utils.to_categorical(label) + self.assertEqual(one_hot.shape, (3, 5 + 1)) + + def test_to_categorical_with_backend_tensor(self): + label = backend.convert_to_tensor(np.array([0, 2, 1, 3, 4])) + expected = backend.convert_to_tensor( + np.array( + [ + [1, 0, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 1], + ] + ) + ) + one_hot = numerical_utils.to_categorical(label, NUM_CLASSES) + assert backend.is_tensor(one_hot) + self.assertAllClose(one_hot, expected) + + @parameterized.parameters([1, 2, 3]) + def test_normalize(self, order): + xb = backend.random.uniform((3, 3), seed=1337) + xnp = backend.convert_to_numpy(xb) + + # Expected result + l2 = np.atleast_1d(np.linalg.norm(xnp, order, axis=-1)) + l2[l2 == 0] = 1 + expected = xnp / np.expand_dims(l2, axis=-1) + + # Test NumPy + out = numerical_utils.normalize(xnp, axis=-1, order=order) + self.assertIsInstance(out, np.ndarray) + self.assertAllClose(out, expected) + + # Test backend + out = numerical_utils.normalize(xb, axis=-1, order=order) + self.assertTrue(backend.is_tensor(out)) + self.assertAllClose(backend.convert_to_numpy(out), expected) + + def test_build_pos_neg_masks(self): + query_labels = np.array([0, 1, 2, 2, 0]) + key_labels = np.array([0, 1, 2, 0, 2]) + expected_shape = (len(query_labels), len(key_labels)) + + positive_mask, negative_mask = numerical_utils.build_pos_neg_masks( + query_labels, key_labels, remove_diagonal=False + ) + + positive_mask = backend.convert_to_numpy(positive_mask) + negative_mask = backend.convert_to_numpy(negative_mask) + self.assertEqual(positive_mask.shape, expected_shape) + self.assertEqual(negative_mask.shape, expected_shape) + self.assertTrue( + np.all(np.logical_not(np.logical_and(positive_mask, negative_mask))) + ) + + expected_positive_mask_keep_diag = np.array( + [ + [1, 0, 0, 1, 0], + [0, 1, 0, 0, 0], + [0, 0, 1, 0, 1], + [0, 0, 1, 0, 1], + [1, 0, 0, 1, 0], + ], + dtype="bool", + ) + self.assertTrue( + np.all(positive_mask == expected_positive_mask_keep_diag) + ) + self.assertTrue( + np.all( + negative_mask + == np.logical_not(expected_positive_mask_keep_diag) + ) + ) + + positive_mask, negative_mask = numerical_utils.build_pos_neg_masks( + query_labels, key_labels, remove_diagonal=True + ) + positive_mask = backend.convert_to_numpy(positive_mask) + negative_mask = backend.convert_to_numpy(negative_mask) + self.assertEqual(positive_mask.shape, expected_shape) + self.assertEqual(negative_mask.shape, expected_shape) + self.assertTrue( + np.all(np.logical_not(np.logical_and(positive_mask, negative_mask))) + ) + + expected_positive_mask_with_remove_diag = np.array( + [ + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 1], + [0, 0, 1, 0, 1], + [1, 0, 0, 1, 0], + ], + dtype="bool", + ) + self.assertTrue( + np.all(positive_mask == expected_positive_mask_with_remove_diag) + ) + + query_labels = np.array([1, 2, 3]) + key_labels = np.array([1, 2, 3, 1]) + + positive_mask, negative_mask = numerical_utils.build_pos_neg_masks( + query_labels, key_labels, remove_diagonal=True + ) + positive_mask = backend.convert_to_numpy(positive_mask) + negative_mask = backend.convert_to_numpy(negative_mask) + expected_shape_diff_sizes = (len(query_labels), len(key_labels)) + self.assertEqual(positive_mask.shape, expected_shape_diff_sizes) + self.assertEqual(negative_mask.shape, expected_shape_diff_sizes) + self.assertTrue( + np.all(np.logical_not(np.logical_and(positive_mask, negative_mask))) + ) diff --git a/keras/src/utils/progbar.py b/keras/src/utils/progbar.py new file mode 100644 index 000000000000..c340f4037b4b --- /dev/null +++ b/keras/src/utils/progbar.py @@ -0,0 +1,262 @@ +import math +import os +import sys +import time + +from keras.src.api_export import keras_export +from keras.src.utils import io_utils + + +@keras_export("keras.utils.Progbar") +class Progbar: + """Displays a progress bar. + + Args: + target: Total number of steps expected, None if unknown. + width: Progress bar width on screen. + verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose) + stateful_metrics: Iterable of string names of metrics that should *not* + be averaged over time. Metrics in this list will be displayed as-is. + All others will be averaged by the progbar before display. + interval: Minimum visual progress update interval (in seconds). + unit_name: Display name for step counts (usually "step" or "sample"). + """ + + def __init__( + self, + target, + width=20, + verbose=1, + interval=0.05, + stateful_metrics=None, + unit_name="step", + ): + self.target = target + self.width = width + self.verbose = verbose + self.interval = interval + self.unit_name = unit_name + if stateful_metrics: + self.stateful_metrics = set(stateful_metrics) + else: + self.stateful_metrics = set() + + self._dynamic_display = ( + (hasattr(sys.stdout, "isatty") and sys.stdout.isatty()) + or "ipykernel" in sys.modules + or "posix" in sys.modules + or "PYCHARM_HOSTED" in os.environ + ) + self._seen_so_far = 0 + # We use a dict + list to avoid garbage collection + # issues found in OrderedDict + self._values = {} + self._values_order = [] + self._start = time.time() + self._last_update = 0 + self._time_at_epoch_start = self._start + self._time_after_first_step = None + self._prev_total_width = 0 + + def update(self, current, values=None, finalize=None): + """Updates the progress bar. + + Args: + current: Index of current step. + values: List of tuples: `(name, value_for_last_step)`. If `name` is + in `stateful_metrics`, `value_for_last_step` will be displayed + as-is. Else, an average of the metric over time will be + displayed. + finalize: Whether this is the last update for the progress bar. If + `None`, defaults to `current >= self.target`. + """ + if finalize is None: + if self.target is None: + finalize = False + else: + finalize = current >= self.target + + values = values or [] + for k, v in values: + if k not in self._values_order: + self._values_order.append(k) + if k not in self.stateful_metrics: + # In the case that progress bar doesn't have a target value in + # the first epoch, both on_batch_end and on_epoch_end will be + # called, which will cause 'current' and 'self._seen_so_far' to + # have the same value. Force the minimal value to 1 here, + # otherwise stateful_metric will be 0s. + if finalize: + self._values[k] = [v, 1] + else: + value_base = max(current - self._seen_so_far, 1) + if k not in self._values: + self._values[k] = [v * value_base, value_base] + else: + self._values[k][0] += v * value_base + self._values[k][1] += value_base + else: + # Stateful metrics output a numeric value. This representation + # means "take an average from a single value" but keeps the + # numeric formatting. + self._values[k] = [v, 1] + self._seen_so_far = current + + message = "" + special_char_len = 0 + now = time.time() + time_per_unit = self._estimate_step_duration(current, now) + + if self.verbose == 1: + if now - self._last_update < self.interval and not finalize: + return + + if self._dynamic_display: + message += "\b" * self._prev_total_width + message += "\r" + else: + message += "\n" + + if self.target is not None: + numdigits = int(math.log10(self.target)) + 1 + bar = (f"%{numdigits}d/%d") % (current, self.target) + bar = f"\x1b[1m{bar}\x1b[0m " + special_char_len += 8 + prog = float(current) / self.target + prog_width = int(self.width * prog) + + if prog_width > 0: + bar += f"\33[32m{'━' * prog_width}\x1b[0m" + special_char_len += 9 + bar += f"\33[37m{'━' * (self.width - prog_width)}\x1b[0m" + special_char_len += 9 + + else: + bar = "%7d/Unknown" % current + message += bar + + # Add ETA if applicable + if self.target is not None and not finalize: + eta = time_per_unit * (self.target - current) + if eta > 3600: + eta_format = "%d:%02d:%02d" % ( + eta // 3600, + (eta % 3600) // 60, + eta % 60, + ) + elif eta > 60: + eta_format = "%d:%02d" % (eta // 60, eta % 60) + else: + eta_format = "%ds" % eta + info = f" \x1b[1m{eta_format}\x1b[0m" + else: + # Time elapsed since start, in seconds + info = f" \x1b[1m{now - self._start:.0f}s\x1b[0m" + special_char_len += 8 + + # Add time/step + info += self._format_time(time_per_unit, self.unit_name) + + # Add metrics + for k in self._values_order: + info += f" - {k}:" + if isinstance(self._values[k], list): + avg = self._values[k][0] / max(1, self._values[k][1]) + if abs(avg) > 1e-3: + info += f" {avg:.4f}" + else: + info += f" {avg:.4e}" + else: + info += f" {self._values[k]}" + message += info + + total_width = len(bar) + len(info) - special_char_len + if self._prev_total_width > total_width: + message += " " * (self._prev_total_width - total_width) + if finalize: + message += "\n" + + io_utils.print_msg(message, line_break=False) + self._prev_total_width = total_width + message = "" + + elif self.verbose == 2: + if finalize: + numdigits = int(math.log10(self.target)) + 1 + count = f"%{numdigits}d/%d" % (current, self.target) + info = f"{count} - {now - self._start:.0f}s" + info += f" -{self._format_time(time_per_unit, self.unit_name)}" + for k in self._values_order: + info += f" - {k}:" + avg = self._values[k][0] / max(1, self._values[k][1]) + if avg > 1e-3: + info += f" {avg:.4f}" + else: + info += f" {avg:.4e}" + info += "\n" + message += info + io_utils.print_msg(message, line_break=False) + message = "" + + self._last_update = now + + def add(self, n, values=None): + self.update(self._seen_so_far + n, values) + + def _format_time(self, time_per_unit, unit_name): + """format a given duration to display to the user. + + Given the duration, this function formats it in either milliseconds + or seconds and displays the unit (i.e. ms/step or s/epoch). + + Args: + time_per_unit: the duration to display + unit_name: the name of the unit to display + + Returns: + A string with the correctly formatted duration and units + """ + formatted = "" + if time_per_unit >= 1 or time_per_unit == 0: + formatted += f" {time_per_unit:.0f}s/{unit_name}" + elif time_per_unit >= 1e-3: + formatted += f" {time_per_unit * 1000.0:.0f}ms/{unit_name}" + else: + formatted += f" {time_per_unit * 1000000.0:.0f}us/{unit_name}" + return formatted + + def _estimate_step_duration(self, current, now): + """Estimate the duration of a single step. + + Given the step number `current` and the corresponding time `now` this + function returns an estimate for how long a single step takes. If this + is called before one step has been completed (i.e. `current == 0`) then + zero is given as an estimate. The duration estimate ignores the duration + of the (assumed to be non-representative) first step for estimates when + more steps are available (i.e. `current>1`). + + Args: + current: Index of current step. + now: The current time. + + Returns: Estimate of the duration of a single step. + """ + if current: + # there are a few special scenarios here: + # 1) somebody is calling the progress bar without ever supplying + # step 1 + # 2) somebody is calling the progress bar and supplies step one + # multiple times, e.g. as part of a finalizing call + # in these cases, we just fall back to the simple calculation + if self._time_after_first_step is not None and current > 1: + time_per_unit = (now - self._time_after_first_step) / ( + current - 1 + ) + else: + time_per_unit = (now - self._start) / current + + if current == 1: + self._time_after_first_step = now + return time_per_unit + else: + return 0 diff --git a/keras/src/utils/python_utils.py b/keras/src/utils/python_utils.py new file mode 100644 index 000000000000..28ebe95754cd --- /dev/null +++ b/keras/src/utils/python_utils.py @@ -0,0 +1,200 @@ +import binascii +import codecs +import marshal +import os +import types as python_types + + +def is_continuous_axis(axis): + # Used to determine whether the dimensions in an axis are continuous + if isinstance(axis, int) or len(axis) == 1: + return True + positive_order_flag = True + for i in range(len(axis) - 1): + if axis[i + 1] - axis[i] != 1: + positive_order_flag = False + break + + negative_order_flag = True + for i in range(len(axis) - 1): + if axis[i + 1] - axis[i] != 1: + negative_order_flag = False + break + return positive_order_flag or negative_order_flag + + +def default(method): + """Decorates a method to detect overrides in subclasses.""" + method._is_default = True + return method + + +def is_default(method): + """Check if a method is decorated with the `default` wrapper.""" + return getattr(method, "_is_default", False) + + +def func_dump(func): + """Serializes a user-defined function. + + Args: + func: the function to serialize. + + Returns: + A tuple `(code, defaults, closure)`. + """ + if os.name == "nt": + raw_code = marshal.dumps(func.__code__).replace(b"\\", b"/") + code = codecs.encode(raw_code, "base64").decode("ascii") + else: + raw_code = marshal.dumps(func.__code__) + code = codecs.encode(raw_code, "base64").decode("ascii") + defaults = func.__defaults__ + if func.__closure__: + closure = tuple(c.cell_contents for c in func.__closure__) + else: + closure = None + return code, defaults, closure + + +def func_load(code, defaults=None, closure=None, globs=None): + """Deserializes a user defined function. + + Args: + code: bytecode of the function. + defaults: defaults of the function. + closure: closure of the function. + globs: dictionary of global objects. + + Returns: + A function object. + """ + if isinstance(code, (tuple, list)): # unpack previous dump + code, defaults, closure = code + if isinstance(defaults, list): + defaults = tuple(defaults) + + def ensure_value_to_cell(value): + """Ensures that a value is converted to a python cell object. + + Args: + value: Any value that needs to be casted to the cell type + + Returns: + A value wrapped as a cell object (see function "func_load") + """ + + def dummy_fn(): + value # just access it so it gets captured in .__closure__ + + cell_value = dummy_fn.__closure__[0] + if not isinstance(value, type(cell_value)): + return cell_value + return value + + if closure is not None: + closure = tuple(ensure_value_to_cell(_) for _ in closure) + try: + raw_code = codecs.decode(code.encode("ascii"), "base64") + except (UnicodeEncodeError, binascii.Error): + raw_code = code.encode("raw_unicode_escape") + code = marshal.loads(raw_code) + if globs is None: + globs = globals() + return python_types.FunctionType( + code, globs, name=code.co_name, argdefs=defaults, closure=closure + ) + + +def to_list(x): + """Normalizes a list/tensor into a list. + + If a tensor is passed, we return + a list of size 1 containing the tensor. + + Args: + x: target object to be normalized. + + Returns: + A list. + """ + if isinstance(x, list): + return x + return [x] + + +def remove_long_seq(maxlen, seq, label): + """Removes sequences that exceed the maximum length. + + Args: + maxlen: Int, maximum length of the output sequences. + seq: List of lists, where each sublist is a sequence. + label: List where each element is an integer. + + Returns: + new_seq, new_label: shortened lists for `seq` and `label`. + """ + new_seq, new_label = [], [] + for x, y in zip(seq, label): + if len(x) < maxlen: + new_seq.append(x) + new_label.append(y) + return new_seq, new_label + + +def removeprefix(x, prefix): + """Backport of `removeprefix` from PEP-616 (Python 3.9+)""" + + if len(prefix) > 0 and x.startswith(prefix): + return x[len(prefix) :] + else: + return x + + +def removesuffix(x, suffix): + """Backport of `removesuffix` from PEP-616 (Python 3.9+)""" + + if len(suffix) > 0 and x.endswith(suffix): + return x[: -len(suffix)] + else: + return x + + +def remove_by_id(lst, value): + """Remove a value from a list by id.""" + for i, v in enumerate(lst): + if id(v) == id(value): + del lst[i] + return + + +def pythonify_logs(logs): + """Flatten and convert log values to Python-native types. + + This function attempts to convert dict value by `float(value)` and skips + the conversion if it fails. + + Args: + logs: A dict containing log values. + + Returns: + A flattened dict with values converted to Python-native types if + possible. + """ + from keras.src import backend + + logs = logs or {} + result = {} + for key, value in sorted(logs.items()): + if isinstance(value, dict): + result.update(pythonify_logs(value)) + else: + try: + # Prevent torch compiler from breaking the graph. + if backend.is_tensor(value): + value = backend.convert_to_numpy(value) + value = float(value) + except: + pass + result[key] = value + return result diff --git a/keras/src/utils/python_utils_test.py b/keras/src/utils/python_utils_test.py new file mode 100644 index 000000000000..2ca2a72d341c --- /dev/null +++ b/keras/src/utils/python_utils_test.py @@ -0,0 +1,101 @@ +import base64 +import marshal + +from keras.src import testing +from keras.src.utils import python_utils + + +class PythonUtilsTest(testing.TestCase): + def test_func_dump_and_load(self): + def my_function(x, y=1, **kwargs): + return x + y + + serialized = python_utils.func_dump(my_function) + deserialized = python_utils.func_load(serialized) + self.assertEqual(deserialized(2, y=3), 5) + + def test_removesuffix(self): + x = "model.keras" + self.assertEqual(python_utils.removesuffix(x, ".keras"), "model") + self.assertEqual(python_utils.removesuffix(x, "model"), x) + + def test_removeprefix(self): + x = "model.keras" + self.assertEqual(python_utils.removeprefix(x, "model"), ".keras") + self.assertEqual(python_utils.removeprefix(x, ".keras"), x) + + def test_func_load_defaults_as_tuple(self): + # Using tuple as a default argument + def dummy_function(x=(1, 2, 3)): + pass + + serialized = python_utils.func_dump(dummy_function) + deserialized = python_utils.func_load(serialized) + # Ensure that the defaults are still a tuple + self.assertIsInstance(deserialized.__defaults__[0], tuple) + # Ensure that the tuple default remains unchanged + self.assertEqual(deserialized.__defaults__[0], (1, 2, 3)) + + def test_remove_long_seq_standard_case(self): + sequences = [[1], [2, 2], [3, 3, 3], [4, 4, 4, 4]] + labels = [1, 2, 3, 4] + new_sequences, new_labels = python_utils.remove_long_seq( + 3, sequences, labels + ) + self.assertEqual(new_sequences, [[1], [2, 2]]) + self.assertEqual(new_labels, [1, 2]) + + def test_func_load_with_closure(self): + def outer_fn(x): + def inner_fn(y): + return x + y + + return inner_fn + + func_with_closure = outer_fn(10) + serialized = python_utils.func_dump(func_with_closure) + deserialized = python_utils.func_load(serialized) + self.assertEqual(deserialized(5), 15) + + def test_func_load_closure_conversion(self): + def my_function_with_closure(x): + return x + y + + y = 5 + serialized = python_utils.func_dump(my_function_with_closure) + deserialized = python_utils.func_load(serialized) + self.assertEqual(deserialized(5), 10) + + def test_ensure_value_to_cell(self): + value_to_test = "test_value" + + def dummy_fn(): + value_to_test + + cell_value = dummy_fn.__closure__[0].cell_contents + self.assertEqual(value_to_test, cell_value) + + def test_closure_processing(self): + def simple_function(x): + return x + 10 + + serialized = python_utils.func_dump(simple_function) + deserialized = python_utils.func_load(serialized) + self.assertEqual(deserialized(5), 15) + + def test_func_load_valid_encoded_code(self): + def another_simple_function(x): + return x * 2 + + raw_data = marshal.dumps(another_simple_function.__code__) + valid_encoded_code = base64.b64encode(raw_data).decode("utf-8") + + try: + python_utils.func_load(valid_encoded_code) + except (UnicodeEncodeError, ValueError): + self.fail("Expected no error for valid code, but got an error.") + + def test_func_load_bad_encoded_code(self): + bad_encoded_code = "This isn't valid base64!" + with self.assertRaises(AttributeError): + python_utils.func_load(bad_encoded_code) diff --git a/keras/src/utils/rng_utils.py b/keras/src/utils/rng_utils.py new file mode 100644 index 000000000000..dd45021d1c25 --- /dev/null +++ b/keras/src/utils/rng_utils.py @@ -0,0 +1,71 @@ +import random + +import numpy as np + +from keras.src import backend +from keras.src.api_export import keras_export +from keras.src.backend.common import global_state +from keras.src.utils.module_utils import tensorflow as tf + +GLOBAL_RANDOM_SEED = "global_random_seed" + + +@keras_export("keras.utils.set_random_seed") +def set_random_seed(seed): + """Sets all random seeds (Python, NumPy, and backend framework, e.g. TF). + + You can use this utility to make almost any Keras program fully + deterministic. Some limitations apply in cases where network communications + are involved (e.g. parameter server distribution), which creates additional + sources of randomness, or when certain non-deterministic cuDNN ops are + involved. + + Calling this utility is equivalent to the following: + + ```python + import random + random.seed(seed) + + import numpy as np + np.random.seed(seed) + + import tensorflow as tf # Only if TF is installed + tf.random.set_seed(seed) + + import torch # Only if the backend is 'torch' + torch.manual_seed(seed) + ``` + + Note that the TensorFlow seed is set even if you're not using TensorFlow + as your backend framework, since many workflows leverage `tf.data` + pipelines (which feature random shuffling). Likewise many workflows + might leverage NumPy APIs. + + Arguments: + seed: Integer, the random seed to use. + """ + if not isinstance(seed, int): + raise ValueError( + "Expected `seed` argument to be an integer. " + f"Received: seed={seed} (of type {type(seed)})" + ) + + # Store seed in global state so we can query it if set. + global_state.set_global_attribute(GLOBAL_RANDOM_SEED, seed) + random.seed(seed) + np.random.seed(seed) + if tf.available: + tf.random.set_seed(seed) + if backend.backend() == "torch": + import torch + + torch.manual_seed(seed) + + +def get_random_seed(): + """Returns the explicit integer random seed if set. + + If the seed has been explicitly set via `set_random_seed`, then + returns the seed. Otherwise, returns `None`. + """ + return global_state.get_global_attribute(GLOBAL_RANDOM_SEED) diff --git a/keras/src/utils/rng_utils_test.py b/keras/src/utils/rng_utils_test.py new file mode 100644 index 000000000000..aef96ddacc43 --- /dev/null +++ b/keras/src/utils/rng_utils_test.py @@ -0,0 +1,33 @@ +import numpy as np +import pytest +import tensorflow as tf + +import keras +from keras.src import backend +from keras.src.testing import test_case +from keras.src.utils import rng_utils + + +class TestRandomSeedSetting(test_case.TestCase): + @pytest.mark.skipif( + backend.backend() == "numpy", + reason="Numpy backend does not support random seed setting.", + ) + def test_set_random_seed(self): + def get_model_output(): + model = keras.Sequential( + [ + keras.layers.Dense(10), + keras.layers.Dropout(0.5), + keras.layers.Dense(10), + ] + ) + x = np.random.random((32, 10)).astype("float32") + ds = tf.data.Dataset.from_tensor_slices(x).shuffle(32).batch(16) + return model.predict(ds) + + rng_utils.set_random_seed(42) + y1 = get_model_output() + rng_utils.set_random_seed(42) + y2 = get_model_output() + self.assertAllClose(y1, y2) diff --git a/keras/src/utils/sequence_utils.py b/keras/src/utils/sequence_utils.py new file mode 100644 index 000000000000..cfb27ef25de6 --- /dev/null +++ b/keras/src/utils/sequence_utils.py @@ -0,0 +1,139 @@ +import numpy as np + +from keras.src.api_export import keras_export + + +@keras_export( + [ + "keras.utils.pad_sequences", + "keras.preprocessing.sequence.pad_sequences", + ] +) +def pad_sequences( + sequences, + maxlen=None, + dtype="int32", + padding="pre", + truncating="pre", + value=0.0, +): + """Pads sequences to the same length. + + This function transforms a list (of length `num_samples`) + of sequences (lists of integers) + into a 2D NumPy array of shape `(num_samples, num_timesteps)`. + `num_timesteps` is either the `maxlen` argument if provided, + or the length of the longest sequence in the list. + + Sequences that are shorter than `num_timesteps` + are padded with `value` until they are `num_timesteps` long. + + Sequences longer than `num_timesteps` are truncated + so that they fit the desired length. + + The position where padding or truncation happens is determined by + the arguments `padding` and `truncating`, respectively. + Pre-padding or removing values from the beginning of the sequence is the + default. + + >>> sequence = [[1], [2, 3], [4, 5, 6]] + >>> keras.utils.pad_sequences(sequence) + array([[0, 0, 1], + [0, 2, 3], + [4, 5, 6]], dtype=int32) + + >>> keras.utils.pad_sequences(sequence, value=-1) + array([[-1, -1, 1], + [-1, 2, 3], + [ 4, 5, 6]], dtype=int32) + + >>> keras.utils.pad_sequences(sequence, padding='post') + array([[1, 0, 0], + [2, 3, 0], + [4, 5, 6]], dtype=int32) + + >>> keras.utils.pad_sequences(sequence, maxlen=2) + array([[0, 1], + [2, 3], + [5, 6]], dtype=int32) + + Args: + sequences: List of sequences (each sequence is a list of integers). + maxlen: Optional Int, maximum length of all sequences. If not provided, + sequences will be padded to the length of the longest individual + sequence. + dtype: (Optional, defaults to `"int32"`). Type of the output sequences. + To pad sequences with variable length strings, you can use `object`. + padding: String, "pre" or "post" (optional, defaults to `"pre"`): + pad either before or after each sequence. + truncating: String, "pre" or "post" (optional, defaults to `"pre"`): + remove values from sequences larger than + `maxlen`, either at the beginning or at the end of the sequences. + value: Float or String, padding value. (Optional, defaults to `0.`) + + Returns: + NumPy array with shape `(len(sequences), maxlen)` + """ + if not hasattr(sequences, "__len__"): + raise ValueError("`sequences` must be iterable.") + num_samples = len(sequences) + + lengths = [] + sample_shape = () + flag = True + + # take the sample shape from the first non empty sequence + # checking for consistency in the main loop below. + + for x in sequences: + try: + lengths.append(len(x)) + if flag and len(x): + sample_shape = np.asarray(x).shape[1:] + flag = False + except TypeError as e: + raise ValueError( + "`sequences` must be a list of iterables. " + f"Found non-iterable: {str(x)}" + ) from e + + if maxlen is None: + maxlen = np.max(lengths) + + is_dtype_str = np.issubdtype(dtype, np.str_) or np.issubdtype( + dtype, np.str_ + ) + if isinstance(value, str) and dtype is not object and not is_dtype_str: + raise ValueError( + f"`dtype` {dtype} is not compatible with `value`'s type: " + f"{type(value)}\nYou should set `dtype=object` for variable length " + "strings." + ) + + x = np.full((num_samples, maxlen) + sample_shape, value, dtype=dtype) + for idx, s in enumerate(sequences): + if not len(s): + continue # empty list/array was found + if truncating == "pre": + trunc = s[-maxlen:] + elif truncating == "post": + trunc = s[:maxlen] + else: + raise ValueError(f'Truncating type "{truncating}" not understood') + + # check `trunc` has expected shape + trunc = np.asarray(trunc, dtype=dtype) + if trunc.shape[1:] != sample_shape: + raise ValueError( + f"Shape of sample {trunc.shape[1:]} of sequence at " + f"position {idx} is different from expected shape " + f"{sample_shape}" + ) + + if padding == "post": + x[idx, : len(trunc)] = trunc + elif padding == "pre": + x[idx, -len(trunc) :] = trunc + else: + raise ValueError(f'Padding type "{padding}" not understood') + return x diff --git a/keras/src/utils/sequence_utils_test.py b/keras/src/utils/sequence_utils_test.py new file mode 100644 index 000000000000..0714bd469a92 --- /dev/null +++ b/keras/src/utils/sequence_utils_test.py @@ -0,0 +1,129 @@ +from keras.src import testing +from keras.src.utils import sequence_utils + + +class PadSequencesTest(testing.TestCase): + def test_pad_sequences(self): + a = [[1], [1, 2], [1, 2, 3]] + + # test padding + b = sequence_utils.pad_sequences(a, maxlen=3, padding="pre") + self.assertAllClose(b, [[0, 0, 1], [0, 1, 2], [1, 2, 3]]) + b = sequence_utils.pad_sequences(a, maxlen=3, padding="post") + self.assertAllClose(b, [[1, 0, 0], [1, 2, 0], [1, 2, 3]]) + + # test truncating + b = sequence_utils.pad_sequences(a, maxlen=2, truncating="pre") + self.assertAllClose(b, [[0, 1], [1, 2], [2, 3]]) + b = sequence_utils.pad_sequences(a, maxlen=2, truncating="post") + self.assertAllClose(b, [[0, 1], [1, 2], [1, 2]]) + + # test value + b = sequence_utils.pad_sequences(a, maxlen=3, value=1) + self.assertAllClose(b, [[1, 1, 1], [1, 1, 2], [1, 2, 3]]) + + def test_pad_sequences_float(self): + a = [[1.2], [1.2, 2.3], [1.2, 2.3, 3.4]] + + # test padding + b = sequence_utils.pad_sequences( + a, maxlen=3, padding="pre", dtype="float32" + ) + self.assertAllClose(b, [[0, 0, 1.2], [0, 1.2, 2.3], [1.2, 2.3, 3.4]]) + b = sequence_utils.pad_sequences( + a, maxlen=3, padding="post", dtype="float32" + ) + self.assertAllClose(b, [[1.2, 0, 0], [1.2, 2.3, 0], [1.2, 2.3, 3.4]]) + + # test truncating + b = sequence_utils.pad_sequences( + a, maxlen=2, truncating="pre", dtype="float32" + ) + self.assertAllClose(b, [[0, 1.2], [1.2, 2.3], [2.3, 3.4]]) + b = sequence_utils.pad_sequences( + a, maxlen=2, truncating="post", dtype="float32" + ) + self.assertAllClose(b, [[0, 1.2], [1.2, 2.3], [1.2, 2.3]]) + + # test value + b = sequence_utils.pad_sequences(a, maxlen=3, value=1, dtype="float32") + self.assertAllClose(b, [[1, 1, 1.2], [1, 1.2, 2.3], [1.2, 2.3, 3.4]]) + + def test_pad_sequences_str(self): + a = [["1"], ["1", "2"], ["1", "2", "3"]] + + # test padding + b = sequence_utils.pad_sequences( + a, maxlen=3, padding="pre", value="pad", dtype=object + ) + self.assertAllEqual( + b, [["pad", "pad", "1"], ["pad", "1", "2"], ["1", "2", "3"]] + ) + b = sequence_utils.pad_sequences( + a, maxlen=3, padding="post", value="pad", dtype=" 0: + for i in range(len(layer._inbound_nodes)): + outputs = layer._inbound_nodes[i].output_tensors + output_shapes = tree.map_structure( + lambda x: format_shape(x.shape), outputs + ) + else: + try: + if hasattr(layer, "output_shape"): + output_shapes = format_shape(layer.output_shape) + else: + outputs = layer.compute_output_shape(**layer._build_shapes_dict) + output_shapes = tree.map_shape_structure( + lambda x: format_shape(x), outputs + ) + except NotImplementedError: + return "?" + if len(output_shapes) == 1: + return output_shapes[0] + out = str(output_shapes) + out = out.replace("'", "") + return out + + +def print_summary( + model, + line_length=None, + positions=None, + print_fn=None, + expand_nested=False, + show_trainable=False, + layer_range=None, +): + """Prints a summary of a model. + + Args: + model: Keras model instance. + line_length: Total length of printed lines + (e.g. set this to adapt the display to different + terminal window sizes). + positions: Relative or absolute positions of log elements in each line. + If not provided, defaults to `[0.3, 0.6, 0.70, 1.]`. + print_fn: Print function to use. + It will be called on each line of the summary. + You can set it to a custom function + in order to capture the string summary. + It defaults to `print` (prints to stdout). + expand_nested: Whether to expand the nested models. + If not provided, defaults to `False`. + show_trainable: Whether to show if a layer is trainable. + If not provided, defaults to `False`. + layer_range: List or tuple containing two strings, + the starting layer name and ending layer name (both inclusive), + indicating the range of layers to be printed in the summary. The + strings could also be regexes instead of an exact name. In this + case, the starting layer will be the first layer that matches + `layer_range[0]` and the ending layer will be the last element that + matches `layer_range[1]`. By default (`None`) all + layers in the model are included in the summary. + """ + from keras.src.models import Functional + from keras.src.models import Sequential + + if not print_fn and not io_utils.is_interactive_logging_enabled(): + print_fn = io_utils.print_msg + + if isinstance(model, Sequential): + sequential_like = True + layers = model.layers + elif not isinstance(model, Functional): + # We treat subclassed models as a simple sequence of layers, for logging + # purposes. + sequential_like = True + layers = model.layers + else: + layers = model._operations + sequential_like = True + nodes_by_depth = model._nodes_by_depth.values() + nodes = [] + for v in nodes_by_depth: + if (len(v) > 1) or ( + len(v) == 1 and len(tree.flatten(v[0].input_tensors)) > 1 + ): + # if the model has multiple nodes + # or if the nodes have multiple inbound_layers + # the model is no longer sequential + sequential_like = False + break + nodes += v + if sequential_like: + # search for shared layers + for layer in model.layers: + flag = False + for node in layer._inbound_nodes: + if node in nodes: + if flag: + sequential_like = False + break + else: + flag = True + if not sequential_like: + break + + if sequential_like: + default_line_length = 88 + positions = positions or [0.45, 0.80, 1.0] + # header names for the different log elements + header = ["Layer (type)", "Output Shape", "Param #"] + alignment = ["left", "left", "right"] + else: + default_line_length = 108 + positions = positions or [0.3, 0.56, 0.74, 1.0] + # header names for the different log elements + header = ["Layer (type)", "Output Shape", "Param #", "Connected to"] + alignment = ["left", "left", "right", "left"] + relevant_nodes = [] + for v in model._nodes_by_depth.values(): + relevant_nodes += v + + if show_trainable: + default_line_length += 12 + positions = [p * 0.90 for p in positions] + [1.0] + header.append("Trainable") + alignment.append("center") + + # Compute columns widths + default_line_length = min( + default_line_length, shutil.get_terminal_size().columns - 4 + ) + line_length = line_length or default_line_length + column_widths = [] + current = 0 + for pos in positions: + width = int(pos * line_length) - current + if width < 4: + raise ValueError("Insufficient console width to print summary.") + column_widths.append(width) + current += width + + # Render summary as a rich table. + columns = [] + # Right align parameter counts. + for i, name in enumerate(header): + column = rich.table.Column( + name, + justify=alignment[i], + width=column_widths[i], + ) + columns.append(column) + + table = rich.table.Table(*columns, width=line_length, show_lines=True) + + def get_connections(layer): + connections = "" + for node in layer._inbound_nodes: + if relevant_nodes and node not in relevant_nodes: + # node is not part of the current network + continue + for kt in node.input_tensors: + keras_history = kt._keras_history + inbound_layer = keras_history.operation + node_index = highlight_number(keras_history.node_index) + tensor_index = highlight_number(keras_history.tensor_index) + if connections: + connections += ", " + connections += ( + f"{inbound_layer.name}[{node_index}][{tensor_index}]" + ) + if not connections: + connections = "-" + return connections + + def get_layer_fields(layer, prefix=""): + output_shape = format_layer_shape(layer) + name = f"{prefix}{layer.name}" + cls_name = layer.__class__.__name__ + name = rich.markup.escape(name) + name += f" ({highlight_symbol(rich.markup.escape(cls_name))})" + + if not hasattr(layer, "built"): + params = highlight_number(0) + elif not layer.built: + params = f"{highlight_number(0)} (unbuilt)" + else: + params = highlight_number(f"{layer.count_params():,}") + + fields = [name, output_shape, params] + if not sequential_like: + fields.append(get_connections(layer)) + if show_trainable: + if hasattr(layer, "weights") and len(layer.weights) > 0: + fields.append( + bold_text("Y", color=34) + if layer.trainable + else bold_text("N", color=9) + ) + else: + fields.append(bold_text("-")) + return fields + + def print_layer(layer, nested_level=0): + if nested_level: + prefix = " " * nested_level + "└ " + else: + prefix = "" + + fields = get_layer_fields(layer, prefix=prefix) + + rows = [fields] + if expand_nested and hasattr(layer, "layers") and layer.layers: + nested_layers = layer.layers + nested_level += 1 + for i in range(len(nested_layers)): + rows.extend( + print_layer(nested_layers[i], nested_level=nested_level) + ) + return rows + + # Render all layers to the rich table. + layer_range = get_layer_index_bound_by_layer_name(layers, layer_range) + for layer in layers[layer_range[0] : layer_range[1]]: + for row in print_layer(layer): + table.add_row(*row) + + # After the table, append information about parameter count and size. + if hasattr(model, "_collected_trainable_weights"): + trainable_count = count_params(model._collected_trainable_weights) + trainable_memory_size = weight_memory_size( + model._collected_trainable_weights + ) + else: + trainable_count = count_params(model.trainable_weights) + trainable_memory_size = weight_memory_size(model.trainable_weights) + + non_trainable_count = count_params(model.non_trainable_weights) + non_trainable_memory_size = weight_memory_size(model.non_trainable_weights) + + if model.compiled and model.optimizer and model.optimizer.built: + optimizer_weight_count = count_params(model.optimizer.variables) + optimizer_memory_size = weight_memory_size(model.optimizer.variables) + optimizer_built = True + else: + optimizer_weight_count = 0 + optimizer_memory_size = 0 + optimizer_built = False + + total_count = trainable_count + non_trainable_count + optimizer_weight_count + total_memory_size = ( + trainable_memory_size + + non_trainable_memory_size + + optimizer_memory_size + ) + + # Create a rich console for printing. Capture for non-interactive logging. + if print_fn: + console = rich.console.Console( + highlight=False, force_terminal=False, color_system=None + ) + console.begin_capture() + else: + console = rich.console.Console(highlight=False) + + # Print the to the console. + console.print(bold_text(f'Model: "{rich.markup.escape(model.name)}"')) + console.print(table) + console.print( + bold_text(" Total params: ") + + highlight_number(f"{total_count:,}") + + f" ({readable_memory_size(total_memory_size)})" + ) + console.print( + bold_text(" Trainable params: ") + + highlight_number(f"{trainable_count:,}") + + f" ({readable_memory_size(trainable_memory_size)})" + ) + console.print( + bold_text(" Non-trainable params: ") + + highlight_number(f"{non_trainable_count:,}") + + f" ({readable_memory_size(non_trainable_memory_size)})" + ) + if optimizer_built: + console.print( + bold_text(" Optimizer params: ") + + highlight_number(f"{optimizer_weight_count:,}") + + f" ({readable_memory_size(optimizer_memory_size)})" + ) + + # Output captured summary for non-interactive logging. + if print_fn: + if print_fn is io_utils.print_msg: + print_fn(console.end_capture(), line_break=False) + else: + print_fn(console.end_capture()) + + +def get_layer_index_bound_by_layer_name(layers, layer_range=None): + """Get the layer indexes from the model based on layer names. + + The layer indexes can be used to slice the model into sub models for + display. + + Args: + model: `Model` instance. + layer_names: a list or tuple of 2 strings, the starting layer name and + ending layer name (both inclusive) for the result. All layers will + be included when `None` is provided. + + Returns: + The index value of layer based on its unique name (layer_names). + Output will be [first_layer_index, last_layer_index + 1]. + """ + if layer_range is not None: + if len(layer_range) != 2: + raise ValueError( + "layer_range must be a list or tuple of length 2. Received: " + f"layer_range = {layer_range} of length {len(layer_range)}" + ) + if not isinstance(layer_range[0], str) or not isinstance( + layer_range[1], str + ): + raise ValueError( + "layer_range should contain string type only. " + f"Received: {layer_range}" + ) + else: + return [0, len(layers)] + + lower_index = [ + idx + for idx, layer in enumerate(layers) + if re.match(layer_range[0], layer.name) + ] + upper_index = [ + idx + for idx, layer in enumerate(layers) + if re.match(layer_range[1], layer.name) + ] + + if not lower_index or not upper_index: + raise ValueError( + "Passed layer_names do not match the layer names in the model. " + f"Received: {layer_range}" + ) + + if min(lower_index) > max(upper_index): + return [min(upper_index), max(lower_index) + 1] + return [min(lower_index), max(upper_index) + 1] diff --git a/keras/src/utils/summary_utils_test.py b/keras/src/utils/summary_utils_test.py new file mode 100644 index 000000000000..bda3ed571260 --- /dev/null +++ b/keras/src/utils/summary_utils_test.py @@ -0,0 +1,126 @@ +import numpy as np +import pytest +from absl.testing import parameterized + +from keras.src import layers +from keras.src import models +from keras.src import ops +from keras.src import testing +from keras.src.utils import summary_utils + + +class SummaryUtilsTest(testing.TestCase): + @parameterized.parameters([("adam",), (None,)]) + @pytest.mark.requires_trainable_backend + def test_print_model_summary(self, optimizer): + inputs = layers.Input((2,)) + outputs = layers.Dense(3)(inputs) + model = models.Model(inputs, outputs) + model.compile(optimizer=optimizer, loss="mse", metrics=["mse"]) + if optimizer: + # Trigger the optimizer weights creation + model.fit(x=np.zeros([4, 2]), y=np.zeros([4, 3])) + + summary_content = [] + + def print_to_variable(text, line_break=False): + summary_content.append(text) + + try: + summary_utils.print_summary(model, print_fn=print_to_variable) + summary_content = "\n".join(summary_content) + if optimizer: + self.assertIn("Total params: 29", summary_content) + self.assertIn("Trainable params: 9", summary_content) + self.assertIn("Non-trainable params: 0", summary_content) + self.assertIn("Optimizer params: 20", summary_content) + else: + self.assertIn("Total params: 9", summary_content) + self.assertIn("Trainable params: 9", summary_content) + self.assertIn("Non-trainable params: 0", summary_content) + self.assertNotIn("Optimizer params", summary_content) + except ImportError: + pass + + def test_print_model_summary_custom_build(self): + class MyModel(models.Model): + def __init__(self): + super().__init__() + self.dense1 = layers.Dense(4, activation="relu") + self.dense2 = layers.Dense(2, activation="softmax") + self.unbuilt_dense = layers.Dense(1) + + def build(self, input_shape): + self.dense1.build(input_shape) + input_shape = self.dense1.compute_output_shape(input_shape) + self.dense2.build(input_shape) + + def call(self, inputs): + x = self.dense1(inputs) + return self.dense2(x) + + model = MyModel() + model.build((None, 2)) + + summary_content = [] + + def print_to_variable(text, line_break=False): + summary_content.append(text) + + summary_utils.print_summary(model, print_fn=print_to_variable) + summary_content = "\n".join(summary_content) + self.assertIn("(None, 4)", summary_content) # dense1 + self.assertIn("(None, 2)", summary_content) # dense2 + self.assertIn("?", summary_content) # unbuilt_dense + self.assertIn("Total params: 22", summary_content) + self.assertIn("Trainable params: 22", summary_content) + self.assertIn("Non-trainable params: 0", summary_content) + + def test_print_model_summary_op_as_layer(self): + inputs = layers.Input((2,)) + x = layers.Dense(4)(inputs) + outputs = ops.mean(x) + model = models.Model(inputs, outputs) + + summary_content = [] + + def print_to_variable(text, line_break=False): + summary_content.append(text) + + summary_utils.print_summary( + model, print_fn=print_to_variable, show_trainable=True + ) + summary_content = "\n".join(summary_content) + self.assertIn("(None, 4)", summary_content) # dense + self.assertIn("Y", summary_content) # dense + self.assertIn("()", summary_content) # mean + self.assertIn("-", summary_content) # mean + self.assertIn("Total params: 12", summary_content) + self.assertIn("Trainable params: 12", summary_content) + self.assertIn("Non-trainable params: 0", summary_content) + + def test_print_model_summary_with_mha(self): + # In Keras <= 3.6, MHA exposes `output_shape` property which breaks this + # test. + class MyModel(models.Model): + def __init__(self): + super().__init__() + self.mha = layers.MultiHeadAttention(2, 2, output_shape=(4,)) + + def call(self, inputs): + return self.mha(inputs, inputs, inputs) + + model = MyModel() + model(np.ones((1, 2, 2))) + + summary_content = [] + + def print_to_variable(text, line_break=False): + summary_content.append(text) + + summary_utils.print_summary(model, print_fn=print_to_variable) + summary_content = "\n".join(summary_content) + self.assertIn("(1, 2, 4)", summary_content) # mha + self.assertIn("Total params: 56", summary_content) + self.assertIn("Trainable params: 56", summary_content) + self.assertIn("Non-trainable params: 0", summary_content) diff --git a/keras/src/utils/text_dataset_utils.py b/keras/src/utils/text_dataset_utils.py new file mode 100644 index 000000000000..d329d6944540 --- /dev/null +++ b/keras/src/utils/text_dataset_utils.py @@ -0,0 +1,416 @@ +import numpy as np + +from keras.src.api_export import keras_export +from keras.src.utils import dataset_utils +from keras.src.utils.grain_utils import make_string_batch +from keras.src.utils.module_utils import grain +from keras.src.utils.module_utils import tensorflow as tf + + +@keras_export( + [ + "keras.utils.text_dataset_from_directory", + "keras.preprocessing.text_dataset_from_directory", + ] +) +def text_dataset_from_directory( + directory, + labels="inferred", + label_mode="int", + class_names=None, + batch_size=32, + max_length=None, + shuffle=True, + seed=None, + validation_split=None, + subset=None, + follow_links=False, + format="tf", + verbose=True, +): + """Generates a dataset from text files in a directory. + + If your directory structure is: + + ``` + main_directory/ + ...class_a/ + ......a_text_1.txt + ......a_text_2.txt + ...class_b/ + ......b_text_1.txt + ......b_text_2.txt + ``` + + Then calling `text_dataset_from_directory(main_directory, + labels='inferred')` will return a dataset that yields batches of + texts from the subdirectories `class_a` and `class_b`, together with labels + 0 and 1 (0 corresponding to `class_a` and 1 corresponding to `class_b`). + + Only `.txt` files are supported at this time. + + By default, this function will return a `tf.data.Dataset` object. You can + set `format="grain"` to return a `grain.IterDataset` object instead, which + removes the TensorFlow dependency. + + Args: + directory: Directory where the data is located. + If `labels` is `"inferred"`, it should contain + subdirectories, each containing text files for a class. + Otherwise, the directory structure is ignored. + labels: Either `"inferred"` + (labels are generated from the directory structure), + `None` (no labels), + or a list/tuple of integer labels of the same size as the number of + text files found in the directory. Labels should be sorted according + to the alphanumeric order of the text file paths + (obtained via `os.walk(directory)` in Python). + label_mode: String describing the encoding of `labels`. Options are: + - `"int"`: means that the labels are encoded as integers + (e.g. for `sparse_categorical_crossentropy` loss). + - `"categorical"` means that the labels are + encoded as a categorical vector + (e.g. for `categorical_crossentropy` loss). + - `"binary"` means that the labels (there can be only 2) + are encoded as `float32` scalars with values 0 or 1 + (e.g. for `binary_crossentropy`). + - `None` (no labels). + class_names: Only valid if `"labels"` is `"inferred"`. + This is the explicit list of class names + (must match names of subdirectories). Used to control the order + of the classes (otherwise alphanumerical order is used). + batch_size: Size of the batches of data. + If `None`, the data will not be batched + (the dataset will yield individual samples). + Defaults to `32`. + max_length: Maximum size of a text string. Texts longer than this will + be truncated to `max_length`. + shuffle: Whether to shuffle the data. + If set to `False`, sorts the data in alphanumeric order. + Defaults to `True`. + seed: Optional random seed for shuffling and transformations. + validation_split: Optional float between 0 and 1, + fraction of data to reserve for validation. + subset: Subset of the data to return. + One of `"training"`, `"validation"` or `"both"`. + Only used if `validation_split` is set. + When `subset="both"`, the utility returns a tuple of two datasets + (the training and validation datasets respectively). + follow_links: Whether to visits subdirectories pointed to by symlinks. + Defaults to `False`. + format: The format of the return object. Defaults to `"tf"`. Available + options are: + - `"tf"`: returns a `tf.data.Dataset` object. Requires + TensorFlow to be installed. + - `"grain"`: returns a `grain.IterDataset` object. Requires + Grain to be installed. + verbose: Whether to display number information on classes and + number of files found. Defaults to `True`. + + Returns: + + A `tf.data.Dataset` (`format="tf"`) or `grain.IterDataset` + (`format="grain"`) object. + + When `format="tf"`: + - If `label_mode` is `None`, it yields `string` tensors of shape + `(batch_size,)`, containing the contents of a batch of text files. + - Otherwise, it yields a tuple `(texts, labels)`, where `texts` + has shape `(batch_size,)` and `labels` follows the format described + below. + + When `format="grain"`: + - If `label_mode` is `None`, it yields a list of Python strings containing + the contents of a batch of text files. + - Otherwise, it yields a tuple `(texts, labels)`, where `texts` + is a list of Python strings and `labels` follows the format described + below. + + Rules regarding labels format: + + - if `label_mode` is `int`, the labels are an `int32` tensor of shape + `(batch_size,)`. + - if `label_mode` is `binary`, the labels are a `float32` tensor of + 1s and 0s of shape `(batch_size, 1)`. + - if `label_mode` is `categorical`, the labels are a `float32` tensor + of shape `(batch_size, num_classes)`, representing a one-hot + encoding of the class index. + """ + if labels not in ("inferred", None): + if not isinstance(labels, (list, tuple)): + raise ValueError( + "`labels` argument should be a list/tuple of integer labels, " + "of the same size as the number of text files in the target " + "directory. If you wish to infer the labels from the " + "subdirectory names in the target directory, " + 'pass `labels="inferred"`. ' + "If you wish to get a dataset that only contains text samples " + f"(no labels), pass `labels=None`. Received: labels={labels}" + ) + if class_names: + raise ValueError( + "You can only pass `class_names` if " + f'`labels="inferred"`. Received: labels={labels}, and ' + f"class_names={class_names}" + ) + if label_mode not in {"int", "categorical", "binary", None}: + raise ValueError( + '`label_mode` argument must be one of "int", ' + '"categorical", "binary", ' + f"or None. Received: label_mode={label_mode}" + ) + if format not in ("tf", "grain"): + raise ValueError( + '`format` should be either "tf" or "grain". ' + f"Received: format={format}" + ) + if labels is None or label_mode is None: + labels = None + label_mode = None + dataset_utils.check_validation_split_arg( + validation_split, subset, shuffle, seed + ) + + if seed is None: + seed = np.random.randint(1e6) + file_paths, labels, class_names = dataset_utils.index_directory( + directory, + labels, + formats=(".txt",), + class_names=class_names, + shuffle=shuffle, + seed=seed, + follow_links=follow_links, + verbose=verbose, + ) + + if label_mode == "binary" and len(class_names) != 2: + raise ValueError( + 'When passing `label_mode="binary"`, there must be exactly 2 ' + f"class_names. Received: class_names={class_names}" + ) + if batch_size is not None: + shuffle_buffer_size = batch_size * 8 + else: + shuffle_buffer_size = 1024 + + if subset == "both": + ( + file_paths_train, + labels_train, + ) = dataset_utils.get_training_or_validation_split( + file_paths, labels, validation_split, "training" + ) + ( + file_paths_val, + labels_val, + ) = dataset_utils.get_training_or_validation_split( + file_paths, labels, validation_split, "validation" + ) + if not file_paths_train: + raise ValueError( + f"No training text files found in directory {directory}. " + "Allowed format: .txt" + ) + if not file_paths_val: + raise ValueError( + f"No validation text files found in directory {directory}. " + "Allowed format: .txt" + ) + train_dataset = paths_and_labels_to_dataset( + file_paths=file_paths_train, + labels=labels_train, + label_mode=label_mode, + num_classes=len(class_names) if class_names else 0, + max_length=max_length, + shuffle=shuffle, + shuffle_buffer_size=shuffle_buffer_size, + seed=seed, + format=format, + ) + val_dataset = paths_and_labels_to_dataset( + file_paths=file_paths_val, + labels=labels_val, + label_mode=label_mode, + num_classes=len(class_names) if class_names else 0, + max_length=max_length, + shuffle=False, + format=format, + ) + + if format == "tf": + if batch_size is not None: + train_dataset = train_dataset.batch(batch_size) + val_dataset = val_dataset.batch(batch_size) + train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE) + val_dataset = val_dataset.prefetch(tf.data.AUTOTUNE) + else: + train_dataset = train_dataset.to_iter_dataset() + val_dataset = val_dataset.to_iter_dataset() + if batch_size is not None: + train_dataset = train_dataset.batch( + batch_size, batch_fn=make_string_batch + ) + val_dataset = val_dataset.batch( + batch_size, batch_fn=make_string_batch + ) + + # Users may need to reference `class_names`. + train_dataset.class_names = class_names + val_dataset.class_names = class_names + dataset = [train_dataset, val_dataset] + else: + file_paths, labels = dataset_utils.get_training_or_validation_split( + file_paths, labels, validation_split, subset + ) + if not file_paths: + raise ValueError( + f"No text files found in directory {directory}. " + "Allowed format: .txt" + ) + dataset = paths_and_labels_to_dataset( + file_paths=file_paths, + labels=labels, + label_mode=label_mode, + num_classes=len(class_names) if class_names else 0, + max_length=max_length, + shuffle=shuffle, + shuffle_buffer_size=shuffle_buffer_size, + seed=seed, + format=format, + ) + + if format == "tf": + if batch_size is not None: + dataset = dataset.batch(batch_size) + dataset = dataset.prefetch(tf.data.AUTOTUNE) + else: + dataset = dataset.to_iter_dataset() + if batch_size is not None: + dataset = dataset.batch(batch_size, batch_fn=make_string_batch) + + # Users may need to reference `class_names`. + dataset.class_names = class_names + return dataset + + +def paths_and_labels_to_dataset( + file_paths, + labels, + label_mode, + num_classes, + max_length, + shuffle=False, + shuffle_buffer_size=None, + seed=None, + format="tf", +): + """Constructs a dataset of text strings and labels.""" + if format == "tf": + return _paths_and_labels_to_dataset_tf( + file_paths, + labels, + label_mode, + num_classes, + max_length, + shuffle, + shuffle_buffer_size, + seed, + ) + elif format == "grain": + return _paths_and_labels_to_dataset_grain( + file_paths, + labels, + label_mode, + num_classes, + max_length, + shuffle, + shuffle_buffer_size, + seed, + ) + + +def _paths_and_labels_to_dataset_tf( + file_paths, + labels, + label_mode, + num_classes, + max_length, + shuffle=False, + shuffle_buffer_size=None, + seed=None, +): + """Constructs a dataset of text strings and labels.""" + path_ds = tf.data.Dataset.from_tensor_slices(file_paths) + if label_mode: + label_ds = dataset_utils.labels_to_dataset_tf( + labels, label_mode, num_classes + ) + ds = tf.data.Dataset.zip((path_ds, label_ds)) + else: + ds = path_ds + + if shuffle: + ds = ds.shuffle(buffer_size=shuffle_buffer_size or 1024, seed=seed) + + if label_mode: + ds = ds.map( + lambda x, y: (_path_to_string_content_tf(x, max_length), y), + num_parallel_calls=tf.data.AUTOTUNE, + ) + else: + ds = ds.map( + lambda x: _path_to_string_content_tf(x, max_length), + num_parallel_calls=tf.data.AUTOTUNE, + ) + return ds + + +def _path_to_string_content_tf(path, max_length): + txt = tf.io.read_file(path) + if max_length is not None: + txt = tf.strings.substr(txt, 0, max_length) + return txt + + +def _paths_and_labels_to_dataset_grain( + file_paths, + labels, + label_mode, + num_classes, + max_length, + shuffle=False, + shuffle_buffer_size=None, + seed=None, +): + """Constructs a dataset of text strings and labels.""" + path_ds = grain.MapDataset.source(file_paths) + if label_mode: + label_ds = dataset_utils.labels_to_dataset_grain( + labels, label_mode, num_classes + ) + ds = grain.experimental.ZipMapDataset([path_ds, label_ds]) + else: + ds = path_ds + + if shuffle: + ds = ds.shuffle(seed=seed) + + if label_mode: + ds = ds.map( + lambda data: ( + _path_to_string_content_grain(data[0], max_length), + data[1], + ), + ) + else: + ds = ds.map(lambda x: _path_to_string_content_grain(x, max_length)) + return ds + + +def _path_to_string_content_grain(path, max_length): + with open(path, "r") as f: + txt = f.read() + if max_length is not None: + txt = txt[:max_length] + return txt diff --git a/keras/src/utils/text_dataset_utils_test.py b/keras/src/utils/text_dataset_utils_test.py new file mode 100644 index 000000000000..cfa5d30b1878 --- /dev/null +++ b/keras/src/utils/text_dataset_utils_test.py @@ -0,0 +1,403 @@ +import os +import random +import string + +from absl.testing import parameterized + +from keras.src import backend +from keras.src import testing +from keras.src.utils import text_dataset_utils + + +class TextDatasetFromDirectoryTest(testing.TestCase): + def _prepare_directory( + self, num_classes=2, nested_dirs=False, count=16, length=20 + ): + # Get a unique temp directory + temp_dir = self.get_temp_dir() + + # Generate paths to class subdirectories + paths = [] + for class_index in range(num_classes): + class_directory = f"class_{class_index}" + if nested_dirs: + class_paths = [ + class_directory, + os.path.join(class_directory, "subfolder_1"), + os.path.join(class_directory, "subfolder_2"), + os.path.join( + class_directory, "subfolder_1", "sub-subfolder" + ), + ] + else: + class_paths = [class_directory] + for path in class_paths: + os.mkdir(os.path.join(temp_dir, path)) + paths += class_paths + + for i in range(count): + path = paths[i % len(paths)] + filename = os.path.join(path, f"text_{i}.txt") + with open(os.path.join(temp_dir, filename), "w") as f: + text = "".join( + [random.choice(string.printable) for _ in range(length)] + ) + f.write(text) + return temp_dir + + @parameterized.named_parameters( + ("tf", "tf"), + ("grain", "grain"), + ) + def test_text_dataset_from_directory_standalone(self, format): + # Test retrieving txt files without labels from a directory and its + # subdirs. Save a few extra files in the parent directory. + directory = self._prepare_directory(count=7, num_classes=2) + for i in range(3): + filename = f"text_{i}.txt" + with open(os.path.join(directory, filename), "w") as f: + text = "".join( + [random.choice(string.printable) for _ in range(20)] + ) + f.write(text) + + dataset = text_dataset_utils.text_dataset_from_directory( + directory, + batch_size=5, + label_mode=None, + max_length=10, + format=format, + ) + batch = next(iter(dataset)) + # We just return the texts, no labels + if format == "tf" or backend.backend() == "tensorflow": + self.assertEqual(list(batch.shape), [5]) + self.assertDType(batch, "string") + else: + self.assertLen(batch, 5) + self.assertIsInstance(batch[0], str) + # Count samples + batch_count = 0 + sample_count = 0 + for batch in dataset: + batch_count += 1 + sample_count += len(batch) + self.assertEqual(batch_count, 2) + self.assertEqual(sample_count, 10) + + @parameterized.named_parameters( + ("tf", "tf"), + ("grain", "grain"), + ) + def test_text_dataset_from_directory_binary(self, format=format): + directory = self._prepare_directory(num_classes=2) + dataset = text_dataset_utils.text_dataset_from_directory( + directory, + batch_size=8, + label_mode="int", + max_length=10, + format=format, + ) + batch = next(iter(dataset)) + self.assertLen(batch, 2) + if format == "tf" or backend.backend() == "tensorflow": + self.assertEqual(batch[0].shape, (8,)) + self.assertDType(batch[0], "string") + self.assertEqual(len(batch[0].numpy()[0]), 10) # Test max_length + else: + self.assertLen(batch[0], 8) + self.assertIsInstance(batch[0][0], str) + self.assertLen(batch[0][0], 10) # Test max_length + self.assertEqual(list(batch[1].shape), [8]) + self.assertDType(batch[1], "int32") + + dataset = text_dataset_utils.text_dataset_from_directory( + directory, + batch_size=8, + label_mode="binary", + format=format, + ) + batch = next(iter(dataset)) + self.assertLen(batch, 2) + if format == "tf" or backend.backend() == "tensorflow": + self.assertEqual(list(batch[0].shape), [8]) + self.assertEqual(batch[0].dtype.name, "string") + else: + self.assertLen(batch[0], 8) + self.assertIsInstance(batch[0][0], str) + self.assertEqual(list(batch[1].shape), [8, 1]) + self.assertDType(batch[1], "float32") + + dataset = text_dataset_utils.text_dataset_from_directory( + directory, + batch_size=8, + label_mode="categorical", + format=format, + ) + batch = next(iter(dataset)) + self.assertLen(batch, 2) + if format == "tf" or backend.backend() == "tensorflow": + self.assertEqual(list(batch[0].shape), [8]) + self.assertEqual(batch[0].dtype.name, "string") + else: + self.assertLen(batch[0], 8) + self.assertIsInstance(batch[0][0], str) + self.assertEqual(list(batch[1].shape), [8, 2]) + self.assertDType(batch[1], "float32") + + @parameterized.named_parameters( + ("tf", "tf"), + ("grain", "grain"), + ) + def test_sample_count(self, format): + directory = self._prepare_directory(num_classes=4, count=15) + dataset = text_dataset_utils.text_dataset_from_directory( + directory, batch_size=8, label_mode=None, format=format + ) + sample_count = 0 + for batch in dataset: + sample_count += len(batch) + self.assertEqual(sample_count, 15) + + @parameterized.named_parameters( + ("tf", "tf"), + ("grain", "grain"), + ) + def test_text_dataset_from_directory_multiclass(self, format): + directory = self._prepare_directory(num_classes=4, count=15) + + dataset = text_dataset_utils.text_dataset_from_directory( + directory, batch_size=8, label_mode=None, format=format + ) + batch = next(iter(dataset)) + self.assertLen(batch, 8) + + dataset = text_dataset_utils.text_dataset_from_directory( + directory, batch_size=8, label_mode=None, format=format + ) + sample_count = 0 + iterator = iter(dataset) + for batch in dataset: + sample_count += len(next(iterator)) + self.assertEqual(sample_count, 15) + + dataset = text_dataset_utils.text_dataset_from_directory( + directory, batch_size=8, label_mode="int", format=format + ) + batch = next(iter(dataset)) + self.assertLen(batch, 2) + if format == "tf" or backend.backend() == "tensorflow": + self.assertEqual(list(batch[0].shape), [8]) + self.assertEqual(batch[0].dtype.name, "string") + else: + self.assertLen(batch[0], 8) + self.assertIsInstance(batch[0][0], str) + self.assertEqual(list(batch[1].shape), [8]) + self.assertDType(batch[1], "int32") + + dataset = text_dataset_utils.text_dataset_from_directory( + directory, batch_size=8, label_mode="categorical", format=format + ) + batch = next(iter(dataset)) + self.assertLen(batch, 2) + if format == "tf" or backend.backend() == "tensorflow": + self.assertEqual(list(batch[0].shape), [8]) + self.assertEqual(batch[0].dtype.name, "string") + else: + self.assertLen(batch[0], 8) + self.assertIsInstance(batch[0][0], str) + self.assertEqual(list(batch[1].shape), [8, 4]) + self.assertDType(batch[1], "float32") + + @parameterized.named_parameters( + ("tf", "tf"), + ("grain", "grain"), + ) + def test_text_dataset_from_directory_validation_split(self, format): + directory = self._prepare_directory(num_classes=2, count=10) + dataset = text_dataset_utils.text_dataset_from_directory( + directory, + batch_size=10, + validation_split=0.2, + subset="training", + seed=1337, + format=format, + ) + batch = next(iter(dataset)) + self.assertLen(batch, 2) + self.assertLen(batch[0], 8) + dataset = text_dataset_utils.text_dataset_from_directory( + directory, + batch_size=10, + validation_split=0.2, + subset="validation", + seed=1337, + format=format, + ) + batch = next(iter(dataset)) + self.assertLen(batch, 2) + self.assertLen(batch[0], 2) + + ( + train_dataset, + val_dataset, + ) = text_dataset_utils.text_dataset_from_directory( + directory, + batch_size=10, + validation_split=0.2, + subset="both", + seed=1337, + format=format, + ) + batch = next(iter(train_dataset)) + self.assertLen(batch, 2) + self.assertLen(batch[0], 8) + batch = next(iter(val_dataset)) + self.assertLen(batch, 2) + self.assertLen(batch[0], 2) + + @parameterized.named_parameters( + ("tf", "tf"), + ("grain", "grain"), + ) + def test_text_dataset_from_directory_manual_labels(self, format): + directory = self._prepare_directory(num_classes=2, count=2) + dataset = text_dataset_utils.text_dataset_from_directory( + directory, batch_size=8, labels=[0, 1], shuffle=False, format=format + ) + batch = next(iter(dataset)) + self.assertLen(batch, 2) + self.assertAllClose(batch[1], [0, 1]) + + @parameterized.named_parameters( + ("tf", "tf"), + ("grain", "grain"), + ) + def test_text_dataset_from_directory_follow_links(self, format): + directory = self._prepare_directory( + num_classes=2, count=25, nested_dirs=True + ) + dataset = text_dataset_utils.text_dataset_from_directory( + directory, + batch_size=8, + label_mode=None, + follow_links=True, + format=format, + ) + sample_count = 0 + for batch in dataset: + sample_count += len(batch) + self.assertEqual(sample_count, 25) + + @parameterized.named_parameters( + ("tf", "tf"), + ("grain", "grain"), + ) + def test_text_dataset_from_directory_no_files(self, format): + directory = self._prepare_directory(num_classes=2, count=0) + with self.assertRaisesRegex(ValueError, "No text files found"): + _ = text_dataset_utils.text_dataset_from_directory( + directory, format=format + ) + + @parameterized.named_parameters( + ("tf", "tf"), + ("grain", "grain"), + ) + def test_text_dataset_from_directory_errors(self, format): + directory = self._prepare_directory(num_classes=3, count=5) + + with self.assertRaisesRegex(ValueError, "`labels` argument should be"): + _ = text_dataset_utils.text_dataset_from_directory( + directory, labels="other", format=format + ) + + with self.assertRaisesRegex( + ValueError, "`label_mode` argument must be" + ): + _ = text_dataset_utils.text_dataset_from_directory( + directory, label_mode="other", format=format + ) + + with self.assertRaisesRegex( + ValueError, 'only pass `class_names` if `labels="inferred"`' + ): + _ = text_dataset_utils.text_dataset_from_directory( + directory, + labels=[0, 0, 1, 1, 1], + class_names=["class_0", "class_1", "class_2"], + format=format, + ) + + with self.assertRaisesRegex( + ValueError, + "Expected the lengths of `labels` to match the number of files", + ): + _ = text_dataset_utils.text_dataset_from_directory( + directory, labels=[0, 0, 1, 1], format=format + ) + + with self.assertRaisesRegex( + ValueError, "`class_names` passed did not match" + ): + _ = text_dataset_utils.text_dataset_from_directory( + directory, class_names=["class_0", "wrong_class"], format=format + ) + + with self.assertRaisesRegex(ValueError, "there must be exactly 2"): + _ = text_dataset_utils.text_dataset_from_directory( + directory, label_mode="binary", format=format + ) + + with self.assertRaisesRegex( + ValueError, "`validation_split` must be between 0 and 1" + ): + _ = text_dataset_utils.text_dataset_from_directory( + directory, validation_split=2, format=format + ) + + with self.assertRaisesRegex( + ValueError, + '`subset` must be either "training", "validation" or "both"', + ): + _ = text_dataset_utils.text_dataset_from_directory( + directory, validation_split=0.2, subset="other", format=format + ) + + with self.assertRaisesRegex( + ValueError, "`validation_split` must be set" + ): + _ = text_dataset_utils.text_dataset_from_directory( + directory, + validation_split=0.0, + subset="training", + format=format, + ) + + with self.assertRaisesRegex(ValueError, "must provide a `seed`"): + _ = text_dataset_utils.text_dataset_from_directory( + directory, + validation_split=0.2, + subset="training", + format=format, + ) + + @parameterized.named_parameters( + ("tf", "tf"), + ("grain", "grain"), + ) + def test_text_dataset_from_directory_not_batched(self, format): + directory = self._prepare_directory() + dataset = text_dataset_utils.text_dataset_from_directory( + directory, + batch_size=None, + label_mode=None, + follow_links=True, + format=format, + ) + + sample = next(iter(dataset)) + if format == "tf": + self.assertEqual(len(sample.shape), 0) + else: + self.assertIsInstance(sample, str) diff --git a/keras/src/utils/tf_utils.py b/keras/src/utils/tf_utils.py new file mode 100644 index 000000000000..9589fe230f02 --- /dev/null +++ b/keras/src/utils/tf_utils.py @@ -0,0 +1,157 @@ +from keras.src import backend +from keras.src.utils.module_utils import tensorflow as tf + + +def get_tensor_spec(t, dynamic_batch=False, name=None): + """Returns a `TensorSpec` given a single `Tensor` or `TensorSpec`.""" + if isinstance(t, tf.TypeSpec): + spec = t + elif isinstance(t, tf.__internal__.CompositeTensor): + # Check for ExtensionTypes + spec = t._type_spec + elif hasattr(t, "shape") and hasattr(t, "dtype"): + spec = tf.TensorSpec(shape=t.shape, dtype=t.dtype, name=name) + else: + return None # Allow non-Tensors to pass through. + + if not dynamic_batch: + return spec + + shape = spec.shape + if shape.rank is None or shape.rank == 0: + return spec + + shape_list = shape.as_list() + shape_list[0] = None + shape = tf.TensorShape(shape_list) + spec._shape = shape + return spec + + +def ensure_tensor(inputs, dtype=None): + """Ensures the input is a Tensor, SparseTensor or RaggedTensor.""" + if not isinstance(inputs, (tf.Tensor, tf.SparseTensor, tf.RaggedTensor)): + if backend.backend() == "torch" and backend.is_tensor(inputs): + # Plain `np.asarray()` conversion fails with PyTorch. + inputs = backend.convert_to_numpy(inputs) + inputs = tf.convert_to_tensor(inputs, dtype) + if dtype is not None and inputs.dtype != dtype: + inputs = tf.cast(inputs, dtype) + return inputs + + +def is_ragged_tensor(x): + return "ragged_tensor.RaggedTensor" in str(type(x)) + + +def sparse_bincount(inputs, depth, binary_output, dtype, count_weights=None): + """Apply binary or count encoding to an input and return a sparse tensor.""" + result = tf.sparse.bincount( + inputs, + weights=count_weights, + minlength=depth, + maxlength=depth, + axis=-1, + binary_output=binary_output, + ) + result = tf.cast(result, dtype) + if inputs.shape.rank == 1: + output_shape = (depth,) + else: + batch_size = tf.shape(result)[0] + output_shape = (batch_size, depth) + result = tf.SparseTensor( + indices=result.indices, values=result.values, dense_shape=output_shape + ) + return result + + +def dense_bincount(inputs, depth, binary_output, dtype, count_weights=None): + """Apply binary or count encoding to an input.""" + result = tf.math.bincount( + inputs, + weights=count_weights, + minlength=depth, + maxlength=depth, + dtype=dtype, + axis=-1, + binary_output=binary_output, + ) + if inputs.shape.rank == 1: + result.set_shape(tf.TensorShape((depth,))) + else: + batch_size = inputs.shape.as_list()[0] + result.set_shape(tf.TensorShape((batch_size, depth))) + return result + + +def expand_dims(inputs, axis): + """Expand dims on sparse, ragged, or dense tensors.""" + if isinstance(inputs, tf.SparseTensor): + return tf.sparse.expand_dims(inputs, axis) + return tf.expand_dims(inputs, axis) + + +def tf_encode_categorical_inputs( + inputs, + output_mode, + depth, + dtype="float32", + sparse=False, + count_weights=None, + idf_weights=None, +): + """Encodes categorical inputs according to output_mode. + + Faster method that relies on bincount. + """ + + if output_mode == "int": + return tf.identity(tf.cast(inputs, dtype)) + + original_shape = inputs.shape + # In all cases, we should uprank scalar input to a single sample. + if inputs.shape.rank == 0: + inputs = expand_dims(inputs, -1) + # One hot will uprank only if the final output dimension is not already 1. + if output_mode == "one_hot": + if inputs.shape[-1] != 1: + inputs = expand_dims(inputs, -1) + + if inputs.shape.rank > 2: + raise ValueError( + "When output_mode is not `'int'`, maximum supported output rank " + f"is 2. Received output_mode {output_mode} and input shape " + f"{original_shape}, " + f"which would result in output rank {inputs.shape.rank}." + ) + + binary_output = output_mode in ("multi_hot", "one_hot") + if sparse: + bincounts = sparse_bincount( + inputs, depth, binary_output, dtype, count_weights + ) + else: + bincounts = dense_bincount( + inputs, depth, binary_output, dtype, count_weights + ) + + bincounts = tf.cast(bincounts, dtype) + if output_mode != "tf_idf": + return bincounts + + if idf_weights is None: + raise ValueError( + "When output mode is `'tf_idf'`, idf_weights must be provided. " + f"Received: output_mode={output_mode} and idf_weights={idf_weights}" + ) + + if sparse: + value_weights = tf.gather(idf_weights, bincounts.indices[:, -1]) + return tf.SparseTensor( + bincounts.indices, + value_weights * bincounts.values, + bincounts.dense_shape, + ) + else: + return tf.multiply(bincounts, idf_weights) diff --git a/keras/src/utils/timeseries_dataset_utils.py b/keras/src/utils/timeseries_dataset_utils.py new file mode 100644 index 000000000000..bf0997b98bbe --- /dev/null +++ b/keras/src/utils/timeseries_dataset_utils.py @@ -0,0 +1,261 @@ +import numpy as np + +from keras.src.api_export import keras_export +from keras.src.utils.module_utils import tensorflow as tf + + +@keras_export( + [ + "keras.utils.timeseries_dataset_from_array", + "keras.preprocessing.timeseries_dataset_from_array", + ] +) +def timeseries_dataset_from_array( + data, + targets, + sequence_length, + sequence_stride=1, + sampling_rate=1, + batch_size=128, + shuffle=False, + seed=None, + start_index=None, + end_index=None, +): + """Creates a dataset of sliding windows over a timeseries provided as array. + + This function takes in a sequence of data-points gathered at + equal intervals, along with time series parameters such as + length of the sequences/windows, spacing between two sequence/windows, etc., + to produce batches of timeseries inputs and targets. + + Args: + data: Numpy array or eager tensor + containing consecutive data points (timesteps). + Axis 0 is expected to be the time dimension. + targets: Targets corresponding to timesteps in `data`. + `targets[i]` should be the target + corresponding to the window that starts at index `i` + (see example 2 below). + Pass `None` if you don't have target data (in this case the dataset + will only yield the input data). + sequence_length: Length of the output sequences + (in number of timesteps). + sequence_stride: Period between successive output sequences. + For stride `s`, output samples would + start at index `data[i]`, `data[i + s]`, `data[i + 2 * s]`, etc. + sampling_rate: Period between successive individual timesteps + within sequences. For rate `r`, timesteps + `data[i], data[i + r], ... data[i + sequence_length]` + are used for creating a sample sequence. + batch_size: Number of timeseries samples in each batch + (except maybe the last one). If `None`, the data will not be batched + (the dataset will yield individual samples). + shuffle: Whether to shuffle output samples, + or instead draw them in chronological order. + seed: Optional int; random seed for shuffling. + start_index: Optional int; data points earlier (exclusive) + than `start_index` will not be used + in the output sequences. This is useful to reserve part of the + data for test or validation. + end_index: Optional int; data points later (exclusive) than `end_index` + will not be used in the output sequences. + This is useful to reserve part of the data for test or validation. + + Returns: + + A `tf.data.Dataset` instance. If `targets` was passed, the dataset yields + tuple `(batch_of_sequences, batch_of_targets)`. If not, the dataset yields + only `batch_of_sequences`. + + Example 1: + + Consider indices `[0, 1, ... 98]`. + With `sequence_length=10, sampling_rate=2, sequence_stride=3`, + `shuffle=False`, the dataset will yield batches of sequences + composed of the following indices: + + ``` + First sequence: [0 2 4 6 8 10 12 14 16 18] + Second sequence: [3 5 7 9 11 13 15 17 19 21] + Third sequence: [6 8 10 12 14 16 18 20 22 24] + ... + Last sequence: [78 80 82 84 86 88 90 92 94 96] + ``` + + In this case the last 2 data points are discarded since no full sequence + can be generated to include them (the next sequence would have started + at index 81, and thus its last step would have gone over 98). + + Example 2: Temporal regression. + + Consider an array `data` of scalar values, of shape `(steps,)`. + To generate a dataset that uses the past 10 + timesteps to predict the next timestep, you would use: + + ```python + input_data = data[:-10] + targets = data[10:] + dataset = timeseries_dataset_from_array( + input_data, targets, sequence_length=10) + for batch in dataset: + inputs, targets = batch + assert np.array_equal(inputs[0], data[:10]) # First sequence: steps [0-9] + # Corresponding target: step 10 + assert np.array_equal(targets[0], data[10]) + break + ``` + + Example 3: Temporal regression for many-to-many architectures. + + Consider two arrays of scalar values `X` and `Y`, + both of shape `(100,)`. The resulting dataset should consist samples with + 20 timestamps each. The samples should not overlap. + To generate a dataset that uses the current timestamp + to predict the corresponding target timestep, you would use: + + ```python + X = np.arange(100) + Y = X*2 + + sample_length = 20 + input_dataset = timeseries_dataset_from_array( + X, None, sequence_length=sample_length, sequence_stride=sample_length) + target_dataset = timeseries_dataset_from_array( + Y, None, sequence_length=sample_length, sequence_stride=sample_length) + + for batch in zip(input_dataset, target_dataset): + inputs, targets = batch + assert np.array_equal(inputs[0], X[:sample_length]) + + # second sample equals output timestamps 20-40 + assert np.array_equal(targets[1], Y[sample_length:2*sample_length]) + break + ``` + """ + if start_index: + if start_index < 0: + raise ValueError( + "`start_index` must be 0 or greater. Received: " + f"start_index={start_index}" + ) + if start_index >= len(data): + raise ValueError( + "`start_index` must be lower than the length of the " + f"data. Received: start_index={start_index}, for data " + f"of length {len(data)}" + ) + if end_index: + if start_index and end_index <= start_index: + raise ValueError( + "`end_index` must be higher than `start_index`. " + f"Received: start_index={start_index}, and " + f"end_index={end_index} " + ) + if end_index >= len(data): + raise ValueError( + "`end_index` must be lower than the length of the " + f"data. Received: end_index={end_index}, for data of " + f"length {len(data)}" + ) + if end_index <= 0: + raise ValueError( + "`end_index` must be higher than 0. " + f"Received: end_index={end_index}" + ) + + # Validate strides + if sampling_rate <= 0: + raise ValueError( + "`sampling_rate` must be higher than 0. Received: " + f"sampling_rate={sampling_rate}" + ) + if sampling_rate >= len(data): + raise ValueError( + "`sampling_rate` must be lower than the length of the " + f"data. Received: sampling_rate={sampling_rate}, for data " + f"of length {len(data)}" + ) + if sequence_stride <= 0: + raise ValueError( + "`sequence_stride` must be higher than 0. Received: " + f"sequence_stride={sequence_stride}" + ) + if sequence_stride >= len(data): + raise ValueError( + "`sequence_stride` must be lower than the length of the " + f"data. Received: sequence_stride={sequence_stride}, for " + f"data of length {len(data)}" + ) + + if start_index is None: + start_index = 0 + if end_index is None: + end_index = len(data) + + # Determine the lowest dtype to store start positions (to lower memory + # usage). + num_seqs = end_index - start_index - (sequence_length - 1) * sampling_rate + if targets is not None: + num_seqs = min(num_seqs, len(targets)) + if num_seqs < 2147483647: + index_dtype = "int32" + else: + index_dtype = "int64" + + # Generate start positions + start_positions = np.arange(0, num_seqs, sequence_stride, dtype=index_dtype) + if shuffle: + if seed is None: + seed = np.random.randint(1e6) + rng = np.random.RandomState(seed) + rng.shuffle(start_positions) + + sequence_length = tf.cast(sequence_length, dtype=index_dtype) + sampling_rate = tf.cast(sampling_rate, dtype=index_dtype) + + positions_ds = tf.data.Dataset.from_tensors(start_positions).repeat() + + # For each initial window position, generates indices of the window elements + indices = tf.data.Dataset.zip( + (tf.data.Dataset.range(len(start_positions)), positions_ds) + ).map( + lambda i, positions: tf.range( + positions[i], + positions[i] + sequence_length * sampling_rate, + sampling_rate, + ), + num_parallel_calls=tf.data.AUTOTUNE, + ) + + dataset = sequences_from_indices(data, indices, start_index, end_index) + if targets is not None: + indices = tf.data.Dataset.zip( + (tf.data.Dataset.range(len(start_positions)), positions_ds) + ).map( + lambda i, positions: positions[i], + num_parallel_calls=tf.data.AUTOTUNE, + ) + target_ds = sequences_from_indices( + targets, indices, start_index, end_index + ) + dataset = tf.data.Dataset.zip((dataset, target_ds)) + dataset = dataset.prefetch(tf.data.AUTOTUNE) + if batch_size is not None: + if shuffle: + # Shuffle locally at each iteration + dataset = dataset.shuffle(buffer_size=batch_size * 8, seed=seed) + dataset = dataset.batch(batch_size) + else: + if shuffle: + dataset = dataset.shuffle(buffer_size=1024, seed=seed) + return dataset + + +def sequences_from_indices(array, indices_ds, start_index, end_index): + dataset = tf.data.Dataset.from_tensors(array[start_index:end_index]) + dataset = tf.data.Dataset.zip((dataset.repeat(), indices_ds)).map( + lambda steps, inds: tf.gather(steps, inds), + num_parallel_calls=tf.data.AUTOTUNE, + ) + return dataset diff --git a/keras/src/utils/timeseries_dataset_utils_test.py b/keras/src/utils/timeseries_dataset_utils_test.py new file mode 100644 index 000000000000..251b81cd3589 --- /dev/null +++ b/keras/src/utils/timeseries_dataset_utils_test.py @@ -0,0 +1,204 @@ +import numpy as np + +from keras.src import testing +from keras.src.utils import timeseries_dataset_utils + + +class TimeseriesDatasetTest(testing.TestCase): + def test_basics(self): + # Test ordering, targets, sequence length, batch size + data = np.arange(100) + targets = data * 2 + dataset = timeseries_dataset_utils.timeseries_dataset_from_array( + data, targets, sequence_length=9, batch_size=5 + ) + # Expect 19 batches + for i, batch in enumerate(dataset): + self.assertLen(batch, 2) + inputs, targets = batch + if i < 18: + self.assertEqual(inputs.shape, (5, 9)) + if i == 18: + # Last batch: size 2 + self.assertEqual(inputs.shape, (2, 9)) + # Check target values + self.assertAllClose(targets, inputs[:, 0] * 2) + for j in range(min(5, len(inputs))): + # Check each sample in the batch + self.assertAllClose( + inputs[j], np.arange(i * 5 + j, i * 5 + j + 9) + ) + + def test_timeseries_regression(self): + # Test simple timeseries regression use case + data = np.arange(10) + offset = 3 + targets = data[offset:] + dataset = timeseries_dataset_utils.timeseries_dataset_from_array( + data, targets, sequence_length=offset, batch_size=1 + ) + i = 0 + for batch in dataset: + self.assertLen(batch, 2) + inputs, targets = batch + self.assertEqual(inputs.shape, (1, 3)) + # Check values + self.assertAllClose(targets[0], data[offset + i]) + self.assertAllClose(inputs[0], data[i : i + offset]) + i += 1 + self.assertEqual(i, 7) # Expect 7 batches + + def test_no_targets(self): + data = np.arange(50) + dataset = timeseries_dataset_utils.timeseries_dataset_from_array( + data, None, sequence_length=10, batch_size=5 + ) + # Expect 9 batches + i = None + for i, batch in enumerate(dataset): + if i < 8: + self.assertEqual(batch.shape, (5, 10)) + elif i == 8: + self.assertEqual(batch.shape, (1, 10)) + for j in range(min(5, len(batch))): + # Check each sample in the batch + self.assertAllClose( + batch[j], np.arange(i * 5 + j, i * 5 + j + 10) + ) + self.assertEqual(i, 8) + + def test_shuffle(self): + # Test cross-epoch random order and seed determinism + data = np.arange(10) + targets = data * 2 + dataset = timeseries_dataset_utils.timeseries_dataset_from_array( + data, + targets, + sequence_length=5, + batch_size=1, + shuffle=True, + seed=123, + ) + first_seq = None + for x, y in dataset.take(1): + self.assertNotAllClose(x, np.arange(0, 5)) + self.assertAllClose(x[:, 0] * 2, y) + first_seq = x + # Check that a new iteration with the same dataset yields different + # results + for x, _ in dataset.take(1): + self.assertNotAllClose(x, first_seq) + # Check determinism with same seed + dataset = timeseries_dataset_utils.timeseries_dataset_from_array( + data, + targets, + sequence_length=5, + batch_size=1, + shuffle=True, + seed=123, + ) + for x, _ in dataset.take(1): + self.assertAllClose(x, first_seq) + + def test_sampling_rate(self): + data = np.arange(100) + targets = data * 2 + dataset = timeseries_dataset_utils.timeseries_dataset_from_array( + data, targets, sequence_length=9, batch_size=5, sampling_rate=2 + ) + for i, batch in enumerate(dataset): + self.assertLen(batch, 2) + inputs, targets = batch + if i < 16: + self.assertEqual(inputs.shape, (5, 9)) + if i == 16: + # Last batch: size 4 + self.assertEqual(inputs.shape, (4, 9)) + # Check target values + self.assertAllClose(inputs[:, 0] * 2, targets) + for j in range(min(5, len(inputs))): + # Check each sample in the batch + start_index = i * 5 + j + end_index = start_index + 9 * 2 + self.assertAllClose( + inputs[j], np.arange(start_index, end_index, 2) + ) + + def test_sequence_stride(self): + data = np.arange(100) + targets = data * 2 + dataset = timeseries_dataset_utils.timeseries_dataset_from_array( + data, targets, sequence_length=9, batch_size=5, sequence_stride=3 + ) + for i, batch in enumerate(dataset): + self.assertLen(batch, 2) + inputs, targets = batch + if i < 6: + self.assertEqual(inputs.shape, (5, 9)) + if i == 6: + # Last batch: size 1 + self.assertEqual(inputs.shape, (1, 9)) + # Check target values + self.assertAllClose(inputs[:, 0] * 2, targets) + for j in range(min(5, len(inputs))): + # Check each sample in the batch + start_index = i * 5 * 3 + j * 3 + end_index = start_index + 9 + self.assertAllClose( + inputs[j], np.arange(start_index, end_index) + ) + + def test_start_and_end_index(self): + data = np.arange(100) + dataset = timeseries_dataset_utils.timeseries_dataset_from_array( + data, + None, + sequence_length=9, + batch_size=5, + sequence_stride=3, + sampling_rate=2, + start_index=10, + end_index=90, + ) + for batch in dataset: + self.assertLess(np.max(batch[0]), 90) + self.assertGreater(np.min(batch[0]), 9) + + def test_errors(self): + # bad start index + with self.assertRaisesRegex(ValueError, "`start_index` must be "): + _ = timeseries_dataset_utils.timeseries_dataset_from_array( + np.arange(10), None, 3, start_index=-1 + ) + with self.assertRaisesRegex(ValueError, "`start_index` must be "): + _ = timeseries_dataset_utils.timeseries_dataset_from_array( + np.arange(10), None, 3, start_index=11 + ) + # bad end index + with self.assertRaisesRegex(ValueError, "`end_index` must be "): + _ = timeseries_dataset_utils.timeseries_dataset_from_array( + np.arange(10), None, 3, end_index=-1 + ) + with self.assertRaisesRegex(ValueError, "`end_index` must be "): + _ = timeseries_dataset_utils.timeseries_dataset_from_array( + np.arange(10), None, 3, end_index=11 + ) + # bad sampling_rate + with self.assertRaisesRegex(ValueError, "`sampling_rate` must be "): + _ = timeseries_dataset_utils.timeseries_dataset_from_array( + np.arange(10), None, 3, sampling_rate=0 + ) + # bad sequence stride + with self.assertRaisesRegex(ValueError, "`sequence_stride` must be "): + _ = timeseries_dataset_utils.timeseries_dataset_from_array( + np.arange(10), None, 3, sequence_stride=0 + ) + + def test_not_batched(self): + data = np.arange(100) + + dataset = timeseries_dataset_utils.timeseries_dataset_from_array( + data, None, sequence_length=9, batch_size=None, shuffle=True + ) + sample = next(iter(dataset)) + self.assertEqual(len(sample.shape), 1) diff --git a/keras/src/utils/torch_utils.py b/keras/src/utils/torch_utils.py new file mode 100644 index 000000000000..f6ac7f034c5c --- /dev/null +++ b/keras/src/utils/torch_utils.py @@ -0,0 +1,194 @@ +import base64 +import io + +from packaging.version import parse + +from keras.src import backend +from keras.src.api_export import keras_export +from keras.src.layers import Layer +from keras.src.ops import convert_to_numpy +from keras.src.ops import convert_to_tensor +from keras.src.saving.serialization_lib import in_safe_mode + + +@keras_export("keras.layers.TorchModuleWrapper") +class TorchModuleWrapper(Layer): + """Torch module wrapper layer. + + `TorchModuleWrapper` is a wrapper class that can turn any + `torch.nn.Module` into a Keras layer, in particular by making its + parameters trackable by Keras. + + `TorchModuleWrapper` is only compatible with the PyTorch backend and + cannot be used with the TensorFlow or JAX backends. + + Args: + module: `torch.nn.Module` instance. If it's a `LazyModule` + instance, then its parameters must be initialized before + passing the instance to `TorchModuleWrapper` (e.g. by calling + it once). + output_shape :The shape of the output of this layer. It helps Keras + perform automatic shape inference. + name: The name of the layer (string). + + Example: + + Here's an example of how the `TorchModuleWrapper` can be used with vanilla + PyTorch modules. + + ```python + import torch + import torch.nn as nn + import torch.nn.functional as F + + import keras + from keras.layers import TorchModuleWrapper + + class Classifier(keras.Model): + def __init__(self, **kwargs): + super().__init__(**kwargs) + # Wrap `torch.nn.Module`s with `TorchModuleWrapper` + # if they contain parameters + self.conv1 = TorchModuleWrapper( + nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(3, 3)) + ) + self.conv2 = TorchModuleWrapper( + nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3)) + ) + self.pool = nn.MaxPool2d(kernel_size=(2, 2)) + self.flatten = nn.Flatten() + self.dropout = nn.Dropout(p=0.5) + self.fc = TorchModuleWrapper(nn.Linear(1600, 10)) + + def call(self, inputs): + x = F.relu(self.conv1(inputs)) + x = self.pool(x) + x = F.relu(self.conv2(x)) + x = self.pool(x) + x = self.flatten(x) + x = self.dropout(x) + x = self.fc(x) + return F.softmax(x, dim=1) + + + model = Classifier() + model.build((1, 28, 28)) + print("Output shape:", model(torch.ones(1, 1, 28, 28).to("cuda")).shape) + + model.compile( + loss="sparse_categorical_crossentropy", + optimizer="adam", + metrics=["accuracy"] + ) + model.fit(train_loader, epochs=5) + ``` + """ + + def __init__(self, module, name=None, output_shape=None, **kwargs): + super().__init__(name=name, **kwargs) + import torch.nn as nn + + from keras.src.backend.torch.core import get_device + + if ( + isinstance(module, nn.modules.lazy.LazyModuleMixin) + and module.has_uninitialized_params() + ): + raise ValueError( + "LazyModules are not supported unless they " + "are already initialized. " + f"Received uninitialized LazyModule: module={module}" + ) + + self.module = module.to(get_device()) + self._track_module_parameters() + self.output_shape = output_shape + + def parameters(self, recurse=True): + return self.module.parameters(recurse=recurse) + + def _track_module_parameters(self): + for param in self.module.parameters(): + # The Variable will reuse the raw `param` + # and simply wrap it. + variable = backend.Variable( + initializer=param, trainable=param.requires_grad + ) + self._track_variable(variable) + self.built = True + + def call(self, *args, training=None, **kwargs): + if training is False: + self.eval() + else: + self.train() + return self.module(*args, **kwargs) + + def save_own_variables(self, store): + """Saves model's state from `state_dict`. + `model.parameters` excludes some of model's state like + `BatchNorm` mean and variance. So, use `state_dict` to obtain + all of model's state. + """ + state_dict = self.module.state_dict() + for key in state_dict.keys(): + store[key] = convert_to_numpy(state_dict[key]) + + def load_own_variables(self, store): + """Loads model's state via `state_dict`.""" + state_dict = {} + for key in store.keys(): + if isinstance(key, bytes): + key = key.decode() + state_dict[key] = convert_to_tensor(store[key]) + self.module.load_state_dict(state_dict) + + def compute_output_shape(self, input_shape): + if self.output_shape is None: + return super().compute_output_shape(input_shape) + return self.output_shape + + def get_config(self): + base_config = super().get_config() + import torch + + buffer = io.BytesIO() + torch.save(self.module, buffer) + # Encode the buffer using base64 to ensure safe serialization + buffer_b64 = base64.b64encode(buffer.getvalue()).decode("ascii") + config = { + "module": buffer_b64, + "output_shape": self.output_shape, + } + return {**base_config, **config} + + @classmethod + def from_config(cls, config): + import torch + + if "module" in config: + if in_safe_mode(): + raise ValueError( + "Requested the deserialization of a `torch.nn.Module` " + "object via `torch.load()`. This carries a potential risk " + "of arbitrary code execution and thus it is disallowed by " + "default. If you trust the source of the artifact, you can " + "override this error by passing `safe_mode=False` to the " + "loading function, or calling " + "`keras.config.enable_unsafe_deserialization()." + ) + + # Decode the base64 string back to bytes + buffer_bytes = base64.b64decode(config["module"].encode("ascii")) + buffer = io.BytesIO(buffer_bytes) + config["module"] = torch.load(buffer, weights_only=False) + return cls(**config) + + +def no_grad(orig_func): + import torch + + if parse(torch.__version__) >= parse("2.1.0"): + return torch.no_grad(orig_func) + else: + return orig_func diff --git a/keras/src/utils/torch_utils_test.py b/keras/src/utils/torch_utils_test.py new file mode 100644 index 000000000000..c1f0cb78c534 --- /dev/null +++ b/keras/src/utils/torch_utils_test.py @@ -0,0 +1,291 @@ +import os + +import numpy as np +import pytest +import torch +from absl.testing import parameterized + +import keras +from keras.src import backend +from keras.src import layers +from keras.src import models +from keras.src import saving +from keras.src import testing +from keras.src.backend.torch.core import get_device +from keras.src.utils.torch_utils import TorchModuleWrapper + + +class Classifier(models.Model): + def __init__( + self, use_batch_norm=False, num_torch_layers=1, *args, **kwargs + ): + super().__init__(*args, **kwargs) + self.use_batch_norm = use_batch_norm + self.num_torch_layers = num_torch_layers + self.torch_wrappers = [] + for _ in range(num_torch_layers): + modules = [torch.nn.Linear(2, 2)] + if use_batch_norm: + modules.append(torch.nn.BatchNorm1d(2)) + torch_model = torch.nn.Sequential(*modules) + self.torch_wrappers.append(TorchModuleWrapper(torch_model)) + self.fc = layers.Dense(1) + + def call(self, x, training=None): + for wrapper in self.torch_wrappers: + x = wrapper(x, training=training) + return self.fc(x) + + def get_config(self): + config = super().get_config() + config["use_batch_norm"] = self.use_batch_norm + config["num_torch_layers"] = self.num_torch_layers + return config + + +class ClassifierWithNoSpecialCasing(models.Model): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.fc1 = torch.nn.Linear(2, 4) + self.bn1 = torch.nn.BatchNorm1d(4) + self.fc2 = torch.nn.Linear(4, 4) + self.fc3 = layers.Dense(2) + + def call(self, x, training=None): + return self.fc3(self.fc2(self.bn1(self.fc1(x)))) + + +@pytest.mark.skipif( + backend.backend() != "torch", reason="Requires torch backend" +) +class TorchUtilsTest(testing.TestCase): + @parameterized.parameters( + {"use_batch_norm": False, "num_torch_layers": 1}, + {"use_batch_norm": True, "num_torch_layers": 1}, + ) + def test_basic_usage(self, use_batch_norm, num_torch_layers): + model = Classifier(use_batch_norm, num_torch_layers) + self.assertEqual(len(model.layers), 2) + # Linear - Weights, bias, BN - beta, gamma + torch_trainable_count = 0 + for i, layer in zip(range(num_torch_layers), model.torch_wrappers): + layer_trainable_count = 2 + if use_batch_norm: + layer_trainable_count += 2 + self.assertEqual( + len(layer.trainable_weights), layer_trainable_count + ) + torch_trainable_count += layer_trainable_count + model(np.random.random((3, 2))) + self.assertEqual(len(model.layers), 2 * num_torch_layers) + self.assertEqual( + len(model.trainable_weights), torch_trainable_count + 2 + ) + model.compile(optimizer="sgd", loss="mse") + model.fit(np.random.random((3, 2)), np.random.random((3, 1))) + + @parameterized.named_parameters( + ( + "explicit_torch_wrapper", + Classifier, + {"use_batch_norm": True, "num_torch_layers": 1}, + ), + ("implicit_torch_wrapper", ClassifierWithNoSpecialCasing, {}), + ) + def test_training_args(self, cls, kwargs): + model = cls(**kwargs) + model(np.random.random((3, 2)), training=False) # Eager call to build + ref_weights = model.get_weights() + ref_running_mean = backend.convert_to_numpy( + model.torch_wrappers[0].module[-1].running_mean + if cls is Classifier + else model.bn1.module.running_mean + ) + + # Test training=False doesn't affect model weights + model(np.random.random((3, 2)), training=False) + weights = model.get_weights() + for w, ref_w in zip(weights, ref_weights): + self.assertAllClose(w, ref_w) + + # Test training=None affects BN's stats + model.set_weights(ref_weights) # Restore previous weights + model(np.random.random((3, 2))) + running_mean = backend.convert_to_numpy( + model.torch_wrappers[0].module[-1].running_mean + if cls is Classifier + else model.bn1.module.running_mean + ) + self.assertNotAllClose(running_mean, ref_running_mean) + + # Test training=True affects BN's stats + model.set_weights(ref_weights) # Restore previous weights + model(np.random.random((3, 2)), training=True) + running_mean = backend.convert_to_numpy( + model.torch_wrappers[0].module[-1].running_mean + if cls is Classifier + else model.bn1.module.running_mean + ) + self.assertNotAllClose(running_mean, ref_running_mean) + + def test_module_autowrapping(self): + model = ClassifierWithNoSpecialCasing() + self.assertIsInstance(model.fc1, TorchModuleWrapper) + self.assertIsInstance(model.bn1, TorchModuleWrapper) + self.assertIsInstance(model.fc2, TorchModuleWrapper) + self.assertFalse(isinstance(model.fc3, TorchModuleWrapper)) + self.assertEqual(len(model.fc1.trainable_weights), 2) + self.assertEqual(len(model.bn1.trainable_weights), 2) + self.assertEqual(len(model.fc2.trainable_weights), 2) + model(np.random.random((3, 2))) + self.assertEqual(len(model.layers), 4) + self.assertEqual(len(model.fc3.trainable_weights), 2) + self.assertEqual(len(model.trainable_weights), 8) + model.compile(optimizer="sgd", loss="mse") + model.fit(np.random.random((3, 2)), np.random.random((3, 2))) + + def test_load_weights_autowrapping(self): + # Test loading weights + temp_filepath = os.path.join(self.get_temp_dir(), "mymodel.weights.h5") + model = ClassifierWithNoSpecialCasing() + model.compile(optimizer="sgd", loss="mse") + x, y = np.random.random((3, 2)), np.random.random((3, 1)) + x_test, y_test = np.random.random((3, 2)), np.random.random((3, 1)) + model.fit(x, y) + ref_loss = model.evaluate(x_test, y_test) + model.save_weights(temp_filepath) + + new_model = ClassifierWithNoSpecialCasing() + new_model(np.random.random((3, 2))) + new_model.compile(optimizer="sgd", loss="mse") + new_model.load_weights(temp_filepath) + for ref_w, new_w in zip(model.get_weights(), new_model.get_weights()): + self.assertAllClose(ref_w, new_w, atol=1e-5) + loss = new_model.evaluate(x_test, y_test) + self.assertAllClose(ref_loss, loss, atol=1e-5) + + def test_serialize_model_autowrapping(self): + # Test loading saved model + temp_filepath = os.path.join(self.get_temp_dir(), "mymodel.keras") + model = ClassifierWithNoSpecialCasing() + model.compile(optimizer="sgd", loss="mse") + x, y = np.random.random((3, 2)), np.random.random((3, 1)) + x_test, y_test = np.random.random((3, 2)), np.random.random((3, 1)) + model.fit(x, y) + ref_loss = model.evaluate(x_test, y_test) + model.save(temp_filepath) + + new_model = saving.load_model(temp_filepath) + for ref_w, new_w in zip(model.get_weights(), new_model.get_weights()): + self.assertAllClose(ref_w, new_w, atol=1e-5) + loss = new_model.evaluate(x_test, y_test) + self.assertAllClose(ref_loss, loss, atol=1e-5) + + @parameterized.parameters( + {"use_batch_norm": False, "num_torch_layers": 1}, + {"use_batch_norm": True, "num_torch_layers": 1}, + {"use_batch_norm": False, "num_torch_layers": 2}, + {"use_batch_norm": True, "num_torch_layers": 2}, + ) + def test_load_weights(self, use_batch_norm, num_torch_layers): + # Test loading weights + temp_filepath = os.path.join(self.get_temp_dir(), "mymodel.weights.h5") + model = Classifier(use_batch_norm, num_torch_layers) + model.compile(optimizer="sgd", loss="mse") + x, y = np.random.random((3, 2)), np.random.random((3, 1)) + x_test, y_test = np.random.random((3, 2)), np.random.random((3, 1)) + model.fit(x, y) + ref_loss = model.evaluate(x_test, y_test) + model.save_weights(temp_filepath) + + new_model = Classifier(use_batch_norm, num_torch_layers) + new_model(np.random.random((3, 2))) + new_model.compile(optimizer="sgd", loss="mse") + new_model.load_weights(temp_filepath) + for ref_w, new_w in zip(model.get_weights(), new_model.get_weights()): + self.assertAllClose(ref_w, new_w, atol=1e-5) + loss = new_model.evaluate(x_test, y_test) + self.assertAllClose(ref_loss, loss, atol=1e-5) + + @parameterized.parameters( + {"use_batch_norm": False, "num_torch_layers": 1}, + {"use_batch_norm": True, "num_torch_layers": 1}, + {"use_batch_norm": False, "num_torch_layers": 2}, + {"use_batch_norm": True, "num_torch_layers": 2}, + ) + def test_serialize_model(self, use_batch_norm, num_torch_layers): + # Test loading saved model + temp_filepath = os.path.join(self.get_temp_dir(), "mymodel.keras") + model = Classifier(use_batch_norm, num_torch_layers) + model.compile(optimizer="sgd", loss="mse") + x, y = np.random.random((3, 2)), np.random.random((3, 1)) + x_test, y_test = np.random.random((3, 2)), np.random.random((3, 1)) + model.fit(x, y) + ref_loss = model.evaluate(x_test, y_test) + model.save(temp_filepath) + + new_model = saving.load_model(temp_filepath) + for ref_w, new_w in zip(model.get_weights(), new_model.get_weights()): + self.assertAllClose(ref_w, new_w, atol=1e-5) + loss = new_model.evaluate(x_test, y_test) + self.assertAllClose(ref_loss, loss, atol=1e-5) + + def test_from_config(self): + module = torch.nn.Sequential(torch.nn.Linear(2, 4)) + mw = TorchModuleWrapper(module) + config = mw.get_config() + new_mw = TorchModuleWrapper.from_config(config) + for ref_w, new_w in zip(mw.get_weights(), new_mw.get_weights()): + self.assertAllClose(ref_w, new_w, atol=1e-5) + + def test_build_model(self): + x = keras.Input([4]) + z = TorchModuleWrapper(torch.nn.Linear(4, 8), output_shape=[None, 8])(x) + y = TorchModuleWrapper(torch.nn.Linear(8, 16), output_shape=[None, 16])( + z + ) + model = keras.Model(x, y) + self.assertEqual(model.predict(np.zeros([5, 4])).shape, (5, 16)) + self.assertEqual(model(np.zeros([5, 4])).shape, (5, 16)) + + @parameterized.named_parameters( + ("safe_mode", True), + ("unsafe_mode", False), + ) + def test_save_load(self, safe_mode): + @keras.saving.register_keras_serializable() + class M(keras.Model): + def __init__(self, module, **kwargs): + super().__init__(**kwargs) + self.module = module + + def call(self, x): + return self.module(x) + + def get_config(self): + base_config = super().get_config() + config = {"module": self.module} + return {**base_config, **config} + + @classmethod + def from_config(cls, config): + config["module"] = saving.deserialize_keras_object( + config["module"] + ) + return cls(**config) + + m = M(torch.nn.Conv2d(1, 10, kernel_size=(3, 3))) + device = get_device() # Get the current device (e.g., "cuda" or "cpu") + x = torch.ones( + (10, 1, 28, 28), device=device + ) # Place input on the correct device + ref_output = m(x) + temp_filepath = os.path.join(self.get_temp_dir(), "mymodel.keras") + m.save(temp_filepath) + + if safe_mode: + with self.assertRaisesRegex(ValueError, "arbitrary code execution"): + saving.load_model(temp_filepath, safe_mode=safe_mode) + else: + new_model = saving.load_model(temp_filepath, safe_mode=safe_mode) + self.assertAllClose(new_model(x), ref_output) diff --git a/keras/src/utils/traceback_utils.py b/keras/src/utils/traceback_utils.py new file mode 100644 index 000000000000..88c3e9ac0ba2 --- /dev/null +++ b/keras/src/utils/traceback_utils.py @@ -0,0 +1,241 @@ +import inspect +import os +import traceback +import types +from functools import wraps + +from keras.src import backend +from keras.src import tree +from keras.src.api_export import keras_export +from keras.src.backend.common import global_state + +_EXCLUDED_PATHS = ( + os.path.abspath(os.path.join(__file__, "..", "..")), + os.path.join("tensorflow", "python"), +) + + +@keras_export("keras.config.enable_traceback_filtering") +def enable_traceback_filtering(): + """Turn on traceback filtering. + + Raw Keras tracebacks (also known as stack traces) + involve many internal frames, which can be + challenging to read through, while not being actionable for end users. + By default, Keras filters internal frames in most exceptions that it + raises, to keep traceback short, readable, and focused on what's + actionable for you (your own code). + + See also `keras.config.disable_traceback_filtering()` and + `keras.config.is_traceback_filtering_enabled()`. + + If you have previously disabled traceback filtering via + `keras.config.disable_traceback_filtering()`, you can re-enable it via + `keras.config.enable_traceback_filtering()`. + """ + global_state.set_global_attribute("traceback_filtering", True) + + +@keras_export("keras.config.disable_traceback_filtering") +def disable_traceback_filtering(): + """Turn off traceback filtering. + + Raw Keras tracebacks (also known as stack traces) + involve many internal frames, which can be + challenging to read through, while not being actionable for end users. + By default, Keras filters internal frames in most exceptions that it + raises, to keep traceback short, readable, and focused on what's + actionable for you (your own code). + + See also `keras.config.enable_traceback_filtering()` and + `keras.config.is_traceback_filtering_enabled()`. + + If you have previously disabled traceback filtering via + `keras.config.disable_traceback_filtering()`, you can re-enable it via + `keras.config.enable_traceback_filtering()`. + """ + global_state.set_global_attribute("traceback_filtering", False) + + +@keras_export("keras.config.is_traceback_filtering_enabled") +def is_traceback_filtering_enabled(): + """Check if traceback filtering is enabled. + + Raw Keras tracebacks (also known as stack traces) + involve many internal frames, which can be + challenging to read through, while not being actionable for end users. + By default, Keras filters internal frames in most exceptions that it + raises, to keep traceback short, readable, and focused on what's + actionable for you (your own code). + + See also `keras.config.enable_traceback_filtering()` and + `keras.config.disable_traceback_filtering()`. + + If you have previously disabled traceback filtering via + `keras.config.disable_traceback_filtering()`, you can re-enable it via + `keras.config.enable_traceback_filtering()`. + + Returns: + Boolean, `True` if traceback filtering is enabled, + and `False` otherwise. + """ + return global_state.get_global_attribute("traceback_filtering", True) + + +def include_frame(fname): + for exclusion in _EXCLUDED_PATHS: + if exclusion in fname: + return False + return True + + +def _process_traceback_frames(tb): + """Iterate through traceback frames and return a new, filtered traceback.""" + last_tb = None + tb_list = list(traceback.walk_tb(tb)) + for f, line_no in reversed(tb_list): + if include_frame(f.f_code.co_filename): + last_tb = types.TracebackType(last_tb, f, f.f_lasti, line_no) + if last_tb is None and tb_list: + # If no frames were kept during filtering, create a new traceback + # from the outermost function. + f, line_no = tb_list[-1] + last_tb = types.TracebackType(last_tb, f, f.f_lasti, line_no) + return last_tb + + +def filter_traceback(fn): + """Filter out Keras-internal traceback frames in exceptions raised by fn.""" + + @wraps(fn) + def error_handler(*args, **kwargs): + if not is_traceback_filtering_enabled(): + return fn(*args, **kwargs) + + filtered_tb = None + try: + return fn(*args, **kwargs) + except Exception as e: + filtered_tb = _process_traceback_frames(e.__traceback__) + # To get the full stack trace, call: + # `keras.config.disable_traceback_filtering()` + raise e.with_traceback(filtered_tb) from None + finally: + del filtered_tb + + return error_handler + + +def inject_argument_info_in_traceback(fn, object_name=None): + """Add information about call argument values to an error message. + + Arguments: + fn: Function to wrap. Exceptions raised by the this function will be + re-raised with additional information added to the error message, + displaying the values of the different arguments that the function + was called with. + object_name: String, display name of the class/function being called, + e.g. `'layer "layer_name" (LayerClass)'`. + + Returns: + A wrapped version of `fn`. + """ + if backend.backend() == "tensorflow": + from tensorflow import errors as tf_errors + else: + tf_errors = None + + @wraps(fn) + def error_handler(*args, **kwargs): + if not is_traceback_filtering_enabled(): + return fn(*args, **kwargs) + + signature = None + bound_signature = None + try: + return fn(*args, **kwargs) + except Exception as e: + if hasattr(e, "_keras_call_info_injected"): + # Only inject info for the innermost failing call + raise e + signature = inspect.signature(fn) + try: + # The first argument is `self`, so filter it out + bound_signature = signature.bind(*args, **kwargs) + except TypeError: + # Likely unbindable arguments + raise e + + # Add argument context + arguments_context = [] + for arg in list(signature.parameters.values()): + if arg.name in bound_signature.arguments: + value = tree.map_structure( + format_argument_value, + bound_signature.arguments[arg.name], + ) + else: + value = arg.default + arguments_context.append(f" • {arg.name}={value}") + if arguments_context: + arguments_context = "\n".join(arguments_context) + # Get original error message and append information to it. + if tf_errors is not None and isinstance(e, tf_errors.OpError): + message = e.message + elif e.args: + # Canonically, the 1st argument in an exception is the error + # message. This works for all built-in Python exceptions. + message = e.args[0] + else: + message = "" + display_name = f"{object_name if object_name else fn.__name__}" + message = ( + f"Exception encountered when calling {display_name}.\n\n" + f"\x1b[1m{message}\x1b[0m\n\n" + f"Arguments received by {display_name}:\n" + f"{arguments_context}" + ) + + # Reraise exception, with added context + if tf_errors is not None and isinstance(e, tf_errors.OpError): + new_e = e.__class__(e.node_def, e.op, message, e.error_code) + else: + try: + # For standard exceptions such as ValueError, TypeError, + # etc. + new_e = e.__class__(message) + except TypeError: + # For any custom error that doesn't have a standard + # signature. + new_e = RuntimeError(message) + new_e._keras_call_info_injected = True + else: + new_e = e + raise new_e.with_traceback(e.__traceback__) from None + finally: + del signature + del bound_signature + + return error_handler + + +def format_argument_value(value): + if backend.is_tensor(value): + # Simplified representation for eager / graph tensors + # to keep messages readable + if backend.backend() == "tensorflow": + tensor_cls = "tf.Tensor" + elif backend.backend() == "jax": + tensor_cls = "jnp.ndarray" + elif backend.backend() == "torch": + tensor_cls = "torch.Tensor" + elif backend.backend() == "numpy": + tensor_cls = "np.ndarray" + else: + tensor_cls = "array" + + return ( + f"{tensor_cls}(shape={value.shape}, " + f"dtype={backend.standardize_dtype(value.dtype)})" + ) + return repr(value) diff --git a/keras/src/utils/tracking.py b/keras/src/utils/tracking.py new file mode 100644 index 000000000000..0c8e1e8447ea --- /dev/null +++ b/keras/src/utils/tracking.py @@ -0,0 +1,355 @@ +from functools import wraps + +from keras.src import tree +from keras.src.backend.common.global_state import get_global_attribute +from keras.src.backend.common.global_state import set_global_attribute +from keras.src.utils import python_utils + + +class DotNotTrackScope: + def __enter__(self): + self.original_value = is_tracking_enabled() + set_global_attribute("tracking_on", False) + + def __exit__(self, *args, **kwargs): + set_global_attribute("tracking_on", self.original_value) + + +def is_tracking_enabled(): + return get_global_attribute("tracking_on", True) + + +def no_automatic_dependency_tracking(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + with DotNotTrackScope(): + return fn(*args, **kwargs) + + return wrapper + + +class Tracker: + """Attribute tracker, used for e.g. Variable tracking. + + Monitors certain attribute types + and put them in appropriate lists in case of a match. + + Also passively tracks certain mutable collections + (dict, list) so that items added to them later + still get tracked. This is done by wrapping these + collections into an equivalent, tracking-aware object. + + Example: + + ```python + def __init__(self): + self.tracker = Tracker( + # Format: `name: (test_fn, store)` + { + "variables": + (lambda x: isinstance(x, Variable), self._variables), + "metrics": (lambda x: isinstance(x, Metric), self._metrics), + "layers": (lambda x: isinstance(x, Layer), self._layers), + } + ) + + def __setattr__(self, name, value): + if hasattr(self, "_tracker"): + value = self._tracker.track(value) + return super().__setattr__(name, value) + ``` + """ + + def __init__(self, config, exclusions=None): + self.config = config + self.stored_ids = {name: set() for name in self.config.keys()} + self.locked = False + self._lock_violation_msg = None + self.exclusions = exclusions or {} + + def track(self, attr): + if not is_tracking_enabled(): + return attr + + for store_name, (is_attr_type, _) in self.config.items(): + if is_attr_type(attr): + if store_name in self.exclusions: + for excl in self.exclusions[store_name]: + if self.is_in_store(excl, attr): + return attr + if not self.is_in_store(store_name, attr): + self.add_to_store(store_name, attr) + return attr + if isinstance(attr, tuple) and hasattr(attr, "_fields"): + # Named tuple case. + wrapped_attr = {} + for name, e in attr._asdict().items(): + wrapped_attr[name] = self.track(e) + return attr.__class__(**wrapped_attr) + if isinstance(attr, tuple): + wrapped_attr = [] + for e in attr: + wrapped_attr.append(self.track(e)) + return attr.__class__(wrapped_attr) + elif isinstance(attr, list): + return TrackedList(attr, self) + elif isinstance(attr, dict): + # TODO: OrderedDict? + return TrackedDict(attr, self) + elif isinstance(attr, set): + return TrackedSet(attr, self) + return attr + + def untrack(self, value): + for store_name in self.stored_ids.keys(): + if id(value) in self.stored_ids[store_name]: + self.stored_ids[store_name].remove(id(value)) + python_utils.remove_by_id(self.config[store_name][1], value) + + def lock(self, msg=None): + self.locked = True + if msg is not None: + self._lock_violation_msg = msg + + def unlock(self): + self.locked = False + + def add_to_store(self, store_name, value): + if self.locked: + raise ValueError(self._lock_violation_msg) + self.config[store_name][1].append(value) + self.stored_ids[store_name].add(id(value)) + + def is_in_store(self, store_name, value): + return id(value) in self.stored_ids[store_name] + + def replace_tracked_value(self, store_name, old_value, new_value): + if not self.is_in_store(store_name, old_value): + raise ValueError(f"Unknown value: {old_value}") + store_list = self.config[store_name][1] + index = store_list.index(old_value) + store_list[index] = new_value + self.stored_ids[store_name].remove(id(old_value)) + self.stored_ids[store_name].add(id(new_value)) + + +@tree.register_tree_node_class +class TrackedList(list): + def __init__(self, values=None, tracker=None): + self.tracker = tracker + if tracker and values: + values = [tracker.track(v) for v in values] + super().__init__(values or []) + + def append(self, value): + if self.tracker: + self.tracker.track(value) + super().append(value) + + def insert(self, index, value): + if self.tracker: + self.tracker.track(value) + super().insert(index, value) + + def extend(self, values): + if self.tracker: + values = [self.tracker.track(v) for v in values] + super().extend(values) + + def remove(self, value): + if self.tracker: + self.tracker.untrack(value) + try: + super().remove(value) + except ValueError: + python_utils.remove_by_id(self, value) + + def pop(self, index=-1): + if self.tracker: + value = self[index] + self.tracker.untrack(value) + return super().pop(index) + else: + return super().pop(index) + + def clear(self): + if self.tracker: + for value in self: + self.tracker.untrack(value) + super().clear() + + def __delitem__(self, index): + value = self[index] # Get value before removing + super().__delitem__(index) + if self.tracker: + self.tracker.untrack(value) + + def tree_flatten(self): + # For optree / dmtree + return (self, None) + + @classmethod + def tree_unflatten(cls, metadata, children): + # For optree / dmtree + return cls(children) + + def torchtree_flatten(self): + # For torchtree + # Returns (values, metadata) + return (self, None) + + @classmethod + def torchtree_unflatten(cls, children, metadata): + # For torchtree + # Requires (children, metadata) + return cls(children) + + def torchtree_flatten_with_keys(self): + # For torchtree + # Returns (children, metadata) + from torch.utils import _pytree as torch_tree + + values, context = self.torchtree_flatten() + return [ + (torch_tree.SequenceKey(i), v) for i, v in enumerate(values) + ], context + + +@tree.register_tree_node_class +class TrackedDict(dict): + def __init__(self, values=None, tracker=None): + self.tracker = tracker + if tracker and values: + values = {k: tracker.track(v) for k, v in values.items()} + super().__init__(values or []) + + def __setitem__(self, key, value): + if self.tracker: + self.tracker.track(value) + super().__setitem__(key, value) + + def update(self, mapping): + if self.tracker: + mapping = {k: self.tracker.track(v) for k, v in mapping.items()} + super().update(mapping) + + def pop(self, key, default=None): + if self.tracker: + value = super().pop(key, default) + if value is not default: + self.tracker.untrack(value) + return value + else: + return super().pop(key, default) + + def popitem(self): + key, value = super().popitem() + if self.tracker: + self.tracker.untrack(value) + return key, value + + def clear(self): + if self.tracker: + for value in self.values(): + self.tracker.untrack(value) + super().clear() + + def tree_flatten(self): + # For optree / dmtree + keys = sorted(list(self.keys())) + values = [self[k] for k in keys] + return values, keys, keys + + @classmethod + def tree_unflatten(cls, keys, values): + # For optree / dmtree + return cls(zip(keys, values)) + + def torchtree_flatten(self): + # For torch_tree + # Returns (values, metadata) + keys = sorted(list(self.keys())) + values = [self[k] for k in keys] + return values, keys + + @classmethod + def torchtree_unflatten(cls, values, keys): + # For torch_tree + # Requires (children, metadata) + return cls(zip(keys, values)) + + def torchtree_flatten_with_keys(self): + # For torchtree + # Returns (children, metadata) + from torch.utils import _pytree as torch_tree + + values, context = self.torchtree_flatten() + return [ + (torch_tree.MappingKey(k), v) for k, v in zip(context, values) + ], context + + +@tree.register_tree_node_class +class TrackedSet(set): + def __init__(self, values=None, tracker=None): + self.tracker = tracker + if tracker and values: + values = {tracker.track(v) for v in values} + super().__init__(values or []) + + def add(self, value): + if self.tracker: + self.tracker.track(value) + super().add(value) + + def update(self, values): + if self.tracker: + values = [self.tracker.track(v) for v in values] + super().update(values) + + def remove(self, value): + if self.tracker: + self.tracker.untrack(value) + super().remove(value) + + def pop(self): + value = super().pop() + if self.tracker: + self.tracker.untrack(value) + return value + + def clear(self): + if self.tracker: + for value in self: + self.tracker.untrack(value) + super().clear() + + def tree_flatten(self): + # For optree / dmtree + return (self, None) + + @classmethod + def tree_unflatten(cls, metadata, children): + # For optree / dmtree + return cls(children) + + def torchtree_flatten(self): + # For torchtree + # Returns (values, metadata) + return (self, None) + + @classmethod + def torchtree_unflatten(cls, children, metadata): + # For torchtree + # Requires (values, metadata) + return cls(children) + + def torchtree_flatten_with_keys(self): + # For torchtree + # Returns (children, metadata) + from torch.utils import _pytree as torch_tree + + values, context = self.torchtree_flatten() + return [ + (torch_tree.SequenceKey(i), v) for i, v in enumerate(values) + ], context diff --git a/keras/src/utils/tracking_test.py b/keras/src/utils/tracking_test.py new file mode 100644 index 000000000000..961e7da89526 --- /dev/null +++ b/keras/src/utils/tracking_test.py @@ -0,0 +1,99 @@ +import collections + +from keras.src import backend +from keras.src import testing +from keras.src.utils import tracking + + +class TrackingTest(testing.TestCase): + def test_untracking_in_tracked_list(self): + tracked_variables = [] + tracker = tracking.Tracker( + { + "variables": ( + lambda x: isinstance(x, backend.Variable), + tracked_variables, + ), + } + ) + v1 = backend.Variable(1.0) + v2 = backend.Variable(2.0) + lst = tracking.TrackedList([], tracker) + lst.append(v1) + lst.append(float("nan")) + lst.append(v2) + lst.append(0) + + self.assertLen(tracked_variables, 2) + self.assertEqual(tracked_variables[0], v1) + self.assertEqual(tracked_variables[1], v2) + + lst.remove(v1) + self.assertLen(lst, 3) + self.assertLen(tracked_variables, 1) + + lst.remove(v2) + self.assertLen(lst, 2) + self.assertLen(tracked_variables, 0) + + lst2 = tracking.TrackedList([], tracker) + lst2.append(v1) + lst2.append(float("nan")) + lst2.append(v2) + lst2.append(0) + + popped_value = lst2.pop() + self.assertEqual(popped_value, 0) + self.assertLen(lst2, 3) + self.assertLen(tracked_variables, 2) + + lst2.clear() + self.assertLen(lst2, 0) + self.assertLen(tracked_variables, 0) + + lst2.append(v1) + lst2.append(v2) + del lst2[0] + self.assertLen(lst2, 1) + self.assertLen(tracked_variables, 1) + + def test_tuple_tracking(self): + tracked_variables = [] + tracker = tracking.Tracker( + { + "variables": ( + lambda x: isinstance(x, backend.Variable), + tracked_variables, + ), + } + ) + v1 = backend.Variable(1.0) + v2 = backend.Variable(2.0) + tup = (v1, v2) + tup = tracker.track(tup) + self.assertIsInstance(tup, tuple) + self.assertLen(tracked_variables, 2) + self.assertEqual(tracked_variables[0], v1) + self.assertEqual(tracked_variables[1], v2) + + def test_namedtuple_tracking(self): + tracked_variables = [] + tracker = tracking.Tracker( + { + "variables": ( + lambda x: isinstance(x, backend.Variable), + tracked_variables, + ), + } + ) + v1 = backend.Variable(1.0) + v2 = backend.Variable(2.0) + nt = collections.namedtuple("NT", ["x", "y"]) + tup = nt(x=v1, y=v2) + tup = tracker.track(tup) + self.assertIsInstance(tup, tuple) + self.assertEqual(tup.x, v1) + self.assertEqual(tup.y, v2) + self.assertLen(tracked_variables, 2) + self.assertEqual(tracked_variables[0], v1) + self.assertEqual(tracked_variables[1], v2) diff --git a/keras/src/version.py b/keras/src/version.py new file mode 100644 index 000000000000..380071698b67 --- /dev/null +++ b/keras/src/version.py @@ -0,0 +1,9 @@ +from keras.src.api_export import keras_export + +# Unique source of truth for the version number. +__version__ = "3.12.0" + + +@keras_export("keras.version") +def version(): + return __version__ diff --git a/keras/src/visualization/__init__.py b/keras/src/visualization/__init__.py new file mode 100644 index 000000000000..04524f857be5 --- /dev/null +++ b/keras/src/visualization/__init__.py @@ -0,0 +1,2 @@ +from keras.src.visualization import draw_bounding_boxes +from keras.src.visualization import plot_image_gallery diff --git a/keras/src/visualization/draw_bounding_boxes.py b/keras/src/visualization/draw_bounding_boxes.py new file mode 100644 index 000000000000..e5e93920d2e4 --- /dev/null +++ b/keras/src/visualization/draw_bounding_boxes.py @@ -0,0 +1,177 @@ +import numpy as np + +from keras.src import backend +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import ( # noqa: E501 + convert_format, +) + +try: + import cv2 +except ImportError: + cv2 = None + + +@keras_export("keras.visualization.draw_bounding_boxes") +def draw_bounding_boxes( + images, + bounding_boxes, + bounding_box_format, + class_mapping=None, + color=(128, 128, 128), + line_thickness=2, + text_thickness=1, + font_scale=1.0, + data_format=None, +): + """Draws bounding boxes on images. + + This function draws bounding boxes on a batch of images. It supports + different bounding box formats and can optionally display class labels + and confidences. + + Args: + images: A batch of images as a 4D tensor or NumPy array. Shape should be + `(batch_size, height, width, channels)`. + bounding_boxes: A dictionary containing bounding box data. Should have + the following keys: + - `boxes`: A tensor or array of shape `(batch_size, num_boxes, 4)` + containing the bounding box coordinates in the specified format. + - `labels`: A tensor or array of shape `(batch_size, num_boxes)` + containing the class labels for each bounding box. + - `confidences` (Optional): A tensor or array of shape + `(batch_size, num_boxes)` containing the confidence scores for + each bounding box. + bounding_box_format: A string specifying the format of the bounding + boxes. Refer [keras-io](TODO) + class_mapping: A dictionary mapping class IDs (integers) to class labels + (strings). Used to display class labels next to the bounding boxes. + Defaults to None (no labels displayed). + color: A tuple or list representing the RGB color of the bounding boxes. + For example, `(255, 0, 0)` for red. Defaults to `(128, 128, 128)`. + line_thickness: An integer specifying the thickness of the bounding box + lines. Defaults to `2`. + text_thickness: An integer specifying the thickness of the text labels. + Defaults to `1`. + font_scale: A float specifying the scale of the font used for text + labels. Defaults to `1.0`. + data_format: A string, either `"channels_last"` or `"channels_first"`, + specifying the order of dimensions in the input images. Defaults to + the `image_data_format` value found in your Keras config file at + `~/.keras/keras.json`. If you never set it, then it will be + "channels_last". + + Returns: + A NumPy array of the annotated images with the bounding boxes drawn. + The array will have the same shape as the input `images`. + + Raises: + ValueError: If `images` is not a 4D tensor/array, if `bounding_boxes` is + not a dictionary, or if `bounding_boxes` does not contain `"boxes"` + and `"labels"` keys. + TypeError: If `bounding_boxes` is not a dictionary. + ImportError: If `cv2` (OpenCV) is not installed. + """ + + if cv2 is None: + raise ImportError( + "The `draw_bounding_boxes` function requires the `cv2` package " + " (OpenCV). Please install it with `pip install opencv-python`." + ) + + class_mapping = class_mapping or {} + text_thickness = ( + text_thickness or line_thickness + ) # Default text_thickness if not provided. + data_format = data_format or backend.image_data_format() + images_shape = ops.shape(images) + if len(images_shape) != 4: + raise ValueError( + "`images` must be batched 4D tensor. " + f"Received: images.shape={images_shape}" + ) + if not isinstance(bounding_boxes, dict): + raise TypeError( + "`bounding_boxes` should be a dict. " + f"Received: bounding_boxes={bounding_boxes} of type " + f"{type(bounding_boxes)}" + ) + if "boxes" not in bounding_boxes or "labels" not in bounding_boxes: + raise ValueError( + "`bounding_boxes` should be a dict containing 'boxes' and " + f"'labels' keys. Received: bounding_boxes={bounding_boxes}" + ) + if data_format == "channels_last": + h_axis = -3 + w_axis = -2 + else: + h_axis = -2 + w_axis = -1 + height = images_shape[h_axis] + width = images_shape[w_axis] + bounding_boxes = bounding_boxes.copy() + bounding_boxes = convert_format( + bounding_boxes, bounding_box_format, "xyxy", height, width + ) + + # To numpy array + images = ops.convert_to_numpy(images).astype("uint8") + boxes = ops.convert_to_numpy(bounding_boxes["boxes"]) + labels = ops.convert_to_numpy(bounding_boxes["labels"]) + if "confidences" in bounding_boxes: + confidences = ops.convert_to_numpy(bounding_boxes["confidences"]) + else: + confidences = None + + result = [] + batch_size = images.shape[0] + for i in range(batch_size): + _image = images[i] + _box = boxes[i] + _class = labels[i] + for box_i in range(_box.shape[0]): + x1, y1, x2, y2 = _box[box_i].astype("int32") + c = _class[box_i].astype("int32") + if c == -1: + continue + x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) + c = int(c) + # Draw bounding box + cv2.rectangle(_image, (x1, y1), (x2, y2), color, line_thickness) + + if c in class_mapping: + label = class_mapping[c] + if confidences is not None: + conf = confidences[i][box_i] + label = f"{label} | {conf:.2f}" + + font_x1, font_y1 = _find_text_location( + x1, y1, font_scale, text_thickness + ) + cv2.putText( + img=_image, + text=label, + org=(font_x1, font_y1), + fontFace=cv2.FONT_HERSHEY_SIMPLEX, + fontScale=font_scale, + color=color, + thickness=text_thickness, + ) + result.append(_image) + return np.stack(result, axis=0) + + +def _find_text_location(x, y, font_scale, thickness): + font_height = int(font_scale * 12) + target_y = y - 8 + if target_y - (2 * font_height) > 0: + return x, y - 8 + + line_offset = thickness + static_offset = 3 + + return ( + x + static_offset, + y + (2 * font_height) + line_offset + static_offset, + ) diff --git a/keras/src/visualization/draw_segmentation_masks.py b/keras/src/visualization/draw_segmentation_masks.py new file mode 100644 index 000000000000..0fa8c6fbb7a1 --- /dev/null +++ b/keras/src/visualization/draw_segmentation_masks.py @@ -0,0 +1,109 @@ +import numpy as np + +from keras.src import backend +from keras.src import ops +from keras.src.api_export import keras_export + + +@keras_export("keras.visualization.draw_segmentation_masks") +def draw_segmentation_masks( + images, + segmentation_masks, + num_classes=None, + color_mapping=None, + alpha=0.8, + blend=True, + ignore_index=-1, + data_format=None, +): + """Draws segmentation masks on images. + + The function overlays segmentation masks on the input images. + The masks are blended with the images using the specified alpha value. + + Args: + images: A batch of images as a 4D tensor or NumPy array. Shape + should be (batch_size, height, width, channels). + segmentation_masks: A batch of segmentation masks as a 3D or 4D tensor + or NumPy array. Shape should be (batch_size, height, width) or + (batch_size, height, width, 1). The values represent class indices + starting from 1 up to `num_classes`. Class 0 is reserved for + the background and will be ignored if `ignore_index` is not 0. + num_classes: The number of segmentation classes. If `None`, it is + inferred from the maximum value in `segmentation_masks`. + color_mapping: A dictionary mapping class indices to RGB colors. + If `None`, a default color palette is generated. The keys should be + integers starting from 1 up to `num_classes`. + alpha: The opacity of the segmentation masks. Must be in the range + `[0, 1]`. + blend: Whether to blend the masks with the input image using the + `alpha` value. If `False`, the masks are drawn directly on the + images without blending. Defaults to `True`. + ignore_index: The class index to ignore. Mask pixels with this value + will not be drawn. Defaults to -1. + data_format: Image data format, either `"channels_last"` or + `"channels_first"`. Defaults to the `image_data_format` value found + in your Keras config file at `~/.keras/keras.json`. If you never + set it, then it will be `"channels_last"`. + + Returns: + A NumPy array of the images with the segmentation masks overlaid. + + Raises: + ValueError: If the input `images` is not a 4D tensor or NumPy array. + TypeError: If the input `segmentation_masks` is not an integer type. + """ + data_format = data_format or backend.image_data_format() + images_shape = ops.shape(images) + if len(images_shape) != 4: + raise ValueError( + "`images` must be batched 4D tensor. " + f"Received: images.shape={images_shape}" + ) + if data_format == "channels_first": + images = ops.transpose(images, (0, 2, 3, 1)) + segmentation_masks = ops.transpose(segmentation_masks, (0, 2, 3, 1)) + images = ops.convert_to_tensor(images, dtype="float32") + segmentation_masks = ops.convert_to_tensor(segmentation_masks) + + if not backend.is_int_dtype(segmentation_masks.dtype): + dtype = backend.standardize_dtype(segmentation_masks.dtype) + raise TypeError( + "`segmentation_masks` must be in integer dtype. " + f"Received: segmentation_masks.dtype={dtype}" + ) + + # Infer num_classes + if num_classes is None: + num_classes = int(ops.convert_to_numpy(ops.max(segmentation_masks))) + if color_mapping is None: + colors = _generate_color_palette(num_classes) + else: + colors = [color_mapping[i] for i in range(num_classes)] + valid_masks = ops.not_equal(segmentation_masks, ignore_index) + valid_masks = ops.squeeze(valid_masks, axis=-1) + segmentation_masks = ops.one_hot(segmentation_masks, num_classes) + segmentation_masks = segmentation_masks[..., 0, :] + segmentation_masks = ops.convert_to_numpy(segmentation_masks) + + # Replace class with color + masks = segmentation_masks + masks = np.transpose(masks, axes=(3, 0, 1, 2)).astype("bool") + images_to_draw = ops.convert_to_numpy(images).copy() + for mask, color in zip(masks, colors): + color = np.array(color, dtype=images_to_draw.dtype) + images_to_draw[mask, ...] = color[None, :] + images_to_draw = ops.convert_to_tensor(images_to_draw) + outputs = ops.cast(images_to_draw, dtype="float32") + + if blend: + outputs = images * (1 - alpha) + outputs * alpha + outputs = ops.where(valid_masks[..., None], outputs, images) + outputs = ops.cast(outputs, dtype="uint8") + outputs = ops.convert_to_numpy(outputs) + return outputs + + +def _generate_color_palette(num_classes): + palette = np.array([2**25 - 1, 2**15 - 1, 2**21 - 1]) + return [((i * palette) % 255).tolist() for i in range(num_classes)] diff --git a/keras/src/visualization/plot_bounding_box_gallery.py b/keras/src/visualization/plot_bounding_box_gallery.py new file mode 100644 index 000000000000..3fe3242f718c --- /dev/null +++ b/keras/src/visualization/plot_bounding_box_gallery.py @@ -0,0 +1,165 @@ +import functools + +import numpy as np + +from keras.src import backend +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.visualization.draw_bounding_boxes import draw_bounding_boxes +from keras.src.visualization.plot_image_gallery import plot_image_gallery + +try: + from matplotlib import patches # For legend patches +except ImportError: + patches = None + + +@keras_export("keras.visualization.plot_bounding_box_gallery") +def plot_bounding_box_gallery( + images, + bounding_box_format, + y_true=None, + y_pred=None, + value_range=(0, 255), + true_color=(0, 188, 212), + pred_color=(255, 235, 59), + line_thickness=2, + font_scale=1.0, + text_thickness=None, + class_mapping=None, + ground_truth_mapping=None, + prediction_mapping=None, + legend=False, + legend_handles=None, + rows=None, + cols=None, + data_format=None, + **kwargs, +): + """Plots a gallery of images with bounding boxes. + + This function can display both ground truth and predicted bounding boxes on + a set of images. It supports various bounding box formats and can include + class labels and a legend. + + Args: + images: A 4D tensor or NumPy array of images. Shape should be + `(batch_size, height, width, channels)`. + bounding_box_format: The format of the bounding boxes. + Refer [keras-io](TODO) + y_true: A dictionary containing the ground truth bounding boxes and + labels. Should have the same structure as the `bounding_boxes` + argument in `keras.visualization.draw_bounding_boxes`. + Defaults to `None`. + y_pred: A dictionary containing the predicted bounding boxes and labels. + Should have the same structure as `y_true`. Defaults to `None`. + value_range: A tuple specifying the value range of the images + (e.g., `(0, 255)` or `(0, 1)`). Defaults to `(0, 255)`. + true_color: A tuple of three integers representing the RGB color for the + ground truth bounding boxes. Defaults to `(0, 188, 212)`. + pred_color: A tuple of three integers representing the RGB color for the + predicted bounding boxes. Defaults to `(255, 235, 59)`. + line_thickness: The thickness of the bounding box lines. Defaults to 2. + font_scale: The scale of the font used for labels. Defaults to 1.0. + text_thickness: The thickness of the bounding box text. Defaults to + `line_thickness`. + class_mapping: A dictionary mapping class IDs to class names. Used f + or both ground truth and predicted boxes if `ground_truth_mapping` + and `prediction_mapping` are not provided. Defaults to `None`. + ground_truth_mapping: A dictionary mapping class IDs to class names + specifically for ground truth boxes. Overrides `class_mapping` + for ground truth. Defaults to `None`. + prediction_mapping: A dictionary mapping class IDs to class names + specifically for predicted boxes. Overrides `class_mapping` for + predictions. Defaults to `None`. + legend: A boolean indicating whether to show a legend. + Defaults to `False`. + legend_handles: A list of matplotlib `Patch` objects to use for the + legend. If this is provided, the `legend` argument will be ignored. + Defaults to `None`. + rows: The number of rows in the image gallery. Required if the images + are not batched. Defaults to `None`. + cols: The number of columns in the image gallery. Required if the images + are not batched. Defaults to `None`. + data_format: The image data format `"channels_last"` or + `"channels_first"`. Defaults to the Keras backend data format. + kwargs: Additional keyword arguments to be passed to + `keras.visualization.plot_image_gallery`. + + Returns: + The output of `keras.visualization.plot_image_gallery`. + + Raises: + ValueError: If `images` is not a 4D tensor/array or if both `legend` a + nd `legend_handles` are specified. + ImportError: if matplotlib is not installed + """ + if patches is None: + raise ImportError( + "The `plot_bounding_box_gallery` function requires the " + " `matplotlib` package. Please install it with " + " `pip install matplotlib`." + ) + + prediction_mapping = prediction_mapping or class_mapping + ground_truth_mapping = ground_truth_mapping or class_mapping + data_format = data_format or backend.image_data_format() + images_shape = ops.shape(images) + if len(images_shape) != 4: + raise ValueError( + "`images` must be batched 4D tensor. " + f"Received: images.shape={images_shape}" + ) + if data_format == "channels_first": # Ensure correct data format + images = ops.transpose(images, (0, 2, 3, 1)) + plotted_images = ops.convert_to_numpy(images) + + draw_fn = functools.partial( + draw_bounding_boxes, + bounding_box_format=bounding_box_format, + line_thickness=line_thickness, + text_thickness=text_thickness, + font_scale=font_scale, + ) + + if y_true is not None: + plotted_images = draw_fn( + plotted_images, + y_true, + color=true_color, + class_mapping=ground_truth_mapping, + ) + + if y_pred is not None: + plotted_images = draw_fn( + plotted_images, + y_pred, + color=pred_color, + class_mapping=prediction_mapping, + ) + + if legend: + if legend_handles: + raise ValueError( + "Only pass `legend` OR `legend_handles` to " + "`keras.visualization.plot_bounding_box_gallery()`." + ) + legend_handles = [ + patches.Patch( + color=np.array(true_color) / 255.0, # Normalize color + label="Ground Truth", + ), + patches.Patch( + color=np.array(pred_color) / 255.0, # Normalize color + label="Prediction", + ), + ] + + return plot_image_gallery( + plotted_images, + value_range=value_range, + legend_handles=legend_handles, + rows=rows, + cols=cols, + **kwargs, + ) diff --git a/keras/src/visualization/plot_image_gallery.py b/keras/src/visualization/plot_image_gallery.py new file mode 100644 index 000000000000..c0c57802d692 --- /dev/null +++ b/keras/src/visualization/plot_image_gallery.py @@ -0,0 +1,200 @@ +import math + +import numpy as np + +from keras.src import backend +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 + BaseImagePreprocessingLayer, +) + +try: + import matplotlib.pyplot as plt +except ImportError: + plt = None + + +def _extract_image_batch(images, num_images, batch_size): + """Extracts a batch of images for plotting. + + Args: + images: The 4D tensor or NumPy array of images. + num_images: The number of images to extract. + batch_size: The original batch size of the images. + + Returns: + A 4D tensor or NumPy array containing the extracted images. + + Raises: + ValueError: If `images` is not a 4D tensor/array. + """ + + if len(ops.shape(images)) != 4: + raise ValueError( + "`plot_images_gallery()` requires you to " + "batch your `np.array` samples together." + ) + num_samples = min(num_images, batch_size) + sample = images[:num_samples, ...] + + return sample + + +@keras_export("keras.visualization.plot_image_gallery") +def plot_image_gallery( + images, + y_true=None, + y_pred=None, + label_map=None, + rows=None, + cols=None, + value_range=(0, 255), + scale=2, + path=None, + show=None, + transparent=True, + dpi=60, + legend_handles=None, + data_format=None, +): + """Displays a gallery of images with optional labels and predictions. + + Args: + images: A 4D tensor or NumPy array of images. Shape should be + `(batch_size, height, width, channels)`. + y_true: A 1D tensor or NumPy array of true labels (class indices). + Defaults to `None`. + y_pred: A 1D tensor or NumPy array of predicted labels (class indices). + Defaults to `None`. + label_map: A dictionary mapping class indices to class names. + Required if `y_true` or `y_pred` are provided. + Defaults to `None`. + value_range: A tuple specifying the value range of the images + (e.g., `(0, 255)` or `(0, 1)`). Defaults to `(0, 255)`. + rows: The number of rows in the gallery. If `None`, it's calculated + based on the number of images and `cols`. Defaults to `None`. + cols: The number of columns in the gallery. If `None`, it's calculated + based on the number of images and `rows`. Defaults to `None`. + scale: A float controlling the size of the displayed images. The images + are scaled by this factor. Defaults to `2`. + path: The path to save the generated gallery image. If `None`, the + image is displayed using `plt.show()`. Defaults to `None`. + show: Whether to display the image using `plt.show()`. If `True`, the + image is displayed. If `False`, the image is not displayed. + Ignored if `path` is not `None`. Defaults to `True` if `path` + is `None`, `False` otherwise. + transparent: A boolean, whether to save the figure with a transparent + background. Defaults to `True`. + dpi: The DPI (dots per inch) for saving the figure. Defaults to 60. + legend_handles: A list of matplotlib `Patch` objects to use as legend + handles. Defaults to `None`. + data_format: The image data format `"channels_last"` or + `"channels_first"`. Defaults to the Keras backend data format. + + Raises: + ValueError: If both `path` and `show` are set to non-`None` values, + if `images` is not a 4D tensor or array, or if `y_true` or `y_pred` + are provided without a `label_map`. + ImportError: if matplotlib is not installed. + """ + if plt is None: + raise ImportError( + "The `plot_image_gallery` function requires the `matplotlib` " + "package. Please install it with `pip install matplotlib`." + ) + + if path is not None and show: + raise ValueError( + "plot_gallery() expects either `path` to be set, or `show` " + "to be true." + ) + + if (y_true is not None or y_pred is not None) and label_map is None: + raise ValueError( + "If `y_true` or `y_pred` are provided, a `label_map` must also be" + " provided." + ) + + show = show if show is not None else (path is None) + data_format = data_format or backend.image_data_format() + + batch_size = ops.shape(images)[0] if len(ops.shape(images)) == 4 else 1 + + rows = rows or int(math.ceil(math.sqrt(batch_size))) + cols = cols or int(math.ceil(batch_size // rows)) + num_images = rows * cols + + images = _extract_image_batch(images, num_images, batch_size) + if ( + data_format == "channels_first" + ): # Ensure correct data format for plotting + images = ops.transpose(images, (0, 2, 3, 1)) + + # Generate subplots + fig, axes = plt.subplots( + nrows=rows, + ncols=cols, + figsize=(cols * scale, rows * scale), + frameon=False, + layout="tight", + squeeze=True, + sharex="row", + sharey="col", + ) + fig.subplots_adjust(wspace=0, hspace=0) + + if isinstance(axes, np.ndarray) and len(axes.shape) == 1: + expand_axis = 0 if rows == 1 else -1 + axes = np.expand_dims(axes, expand_axis) + + if legend_handles is not None: + fig.legend(handles=legend_handles, loc="lower center") + + images = BaseImagePreprocessingLayer()._transform_value_range( + images=images, original_range=value_range, target_range=(0, 255) + ) + + images = ops.convert_to_numpy(images) + if data_format == "channels_first": + images = images.transpose(0, 2, 3, 1) + + if y_true is not None: + y_true = ops.convert_to_numpy(y_true) + if y_pred is not None: + y_pred = ops.convert_to_numpy(y_pred) + + for row in range(rows): + for col in range(cols): + index = row * cols + col + current_axis = ( + axes[row, col] if isinstance(axes, np.ndarray) else axes + ) + current_axis.imshow(images[index].astype("uint8")) + current_axis.margins(x=0, y=0) + current_axis.axis("off") + title_parts = [] + if y_true is not None and index < len(y_true): + title_parts.append( + f"Label: {label_map.get(y_true[index], 'Unknown')}" + ) + if y_pred is not None and index < len(y_pred): + title_parts.append( + f"Pred: {label_map.get(y_pred[index], 'Unknown')}" + ) + + if title_parts: + current_axis.set_title(" ".join(title_parts), fontsize=8) + + if path is not None: + plt.savefig( + fname=path, + pad_inches=0, + bbox_inches="tight", + transparent=transparent, + dpi=dpi, + ) + plt.close() + elif show: + plt.show() + plt.close() diff --git a/keras/src/visualization/plot_segmentation_mask_gallery.py b/keras/src/visualization/plot_segmentation_mask_gallery.py new file mode 100644 index 000000000000..1edf603ddf72 --- /dev/null +++ b/keras/src/visualization/plot_segmentation_mask_gallery.py @@ -0,0 +1,121 @@ +import functools + +import numpy as np + +from keras.src import backend +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.visualization.draw_segmentation_masks import ( + draw_segmentation_masks, +) +from keras.src.visualization.plot_image_gallery import plot_image_gallery + + +@keras_export("keras.visualization.plot_segmentation_mask_gallery") +def plot_segmentation_mask_gallery( + images, + num_classes, + value_range=(0, 255), + y_true=None, + y_pred=None, + color_mapping=None, + blend=True, + alpha=0.8, + ignore_index=-1, + data_format=None, + **kwargs, +): + """Plots a gallery of images with corresponding segmentation masks. + + Args: + images: A 4D tensor or NumPy array of images. Shape should be + `(batch_size, height, width, channels)`. + num_classes: The number of segmentation classes. Class indices should + start from `1`. Class `0` will be treated as background and + ignored if `ignore_index` is not 0. + value_range: A tuple specifying the value range of the images + (e.g., `(0, 255)` or `(0, 1)`). Defaults to `(0, 255)`. + y_true: A 3D/4D tensor or NumPy array representing the ground truth + segmentation masks. Shape should be `(batch_size, height, width)` or + `(batch_size, height, width, 1)`. Defaults to `None`. + y_pred: A 3D/4D tensor or NumPy array representing the predicted + segmentation masks. Shape should be the same as `y_true`. + Defaults to `None`. + color_mapping: A dictionary mapping class indices to RGB colors. + If `None`, a default color palette is used. Class indices start + from `1`. Defaults to `None`. + blend: Whether to blend the masks with the input image using the + `alpha` value. If `False`, the masks are drawn directly on the + images without blending. Defaults to `True`. + alpha: The opacity of the segmentation masks (a float between 0 and 1). + Defaults to `0.8`. + ignore_index: The class index to ignore when drawing masks. + Defaults to `-1`. + data_format: The image data format `"channels_last"` or + `"channels_first"`. Defaults to the Keras backend data format. + kwargs: Additional keyword arguments to be passed to + `keras.visualization.plot_image_gallery`. + + Returns: + The output of `keras.visualization.plot_image_gallery`. + + Raises: + ValueError: If `images` is not a 4D tensor/array. + """ + data_format = data_format or backend.image_data_format() + image_shape = ops.shape(images) + if len(image_shape) != 4: + raise ValueError( + "`images` must be batched 4D tensor. " + f"Received: images.shape={image_shape}" + ) + if data_format == "channels_first": + images = ops.transpose(images, (0, 2, 3, 1)) + + batch_size = image_shape[0] if len(image_shape) == 4 else 1 + + rows = batch_size + cols = 1 + + if y_true is not None: + cols += 1 + + if y_pred is not None: + cols += 1 + + images_np = ops.convert_to_numpy(images) + + draw_masks_fn = functools.partial( + draw_segmentation_masks, + num_classes=num_classes, + color_mapping=color_mapping, + alpha=alpha, + ignore_index=ignore_index, + blend=blend, + ) + + if y_true is not None: + if data_format == "channels_first": + y_true = ops.transpose(y_true, (0, 2, 3, 1)) + y_true = ops.cast(y_true, "int32") + true_masks_drawn = draw_masks_fn(images_np, y_true) + + if y_pred is not None: + if data_format == "channels_first": + y_pred = ops.transpose(y_pred, (0, 2, 3, 1)) + y_pred = ops.cast(y_pred, "int32") + predicted_masks_drawn = draw_masks_fn(images_np, y_pred) + + images_with_masks = [] + for i in range(batch_size): + images_with_masks.append(images_np[i]) + if y_true is not None: + images_with_masks.append(true_masks_drawn[i]) + if y_pred is not None: + images_with_masks.append(predicted_masks_drawn[i]) + + gallery_images = np.stack(images_with_masks, axis=0) + + return plot_image_gallery( + gallery_images, value_range=value_range, rows=rows, cols=cols, **kwargs + ) diff --git a/keras/src/wrappers/__init__.py b/keras/src/wrappers/__init__.py new file mode 100644 index 000000000000..8c55aa752f5c --- /dev/null +++ b/keras/src/wrappers/__init__.py @@ -0,0 +1,5 @@ +from keras.src.wrappers.sklearn_wrapper import SKLearnClassifier +from keras.src.wrappers.sklearn_wrapper import SKLearnRegressor +from keras.src.wrappers.sklearn_wrapper import SKLearnTransformer + +__all__ = ["SKLearnClassifier", "SKLearnRegressor", "SKLearnTransformer"] diff --git a/keras/src/wrappers/fixes.py b/keras/src/wrappers/fixes.py new file mode 100644 index 000000000000..b503e4e88e82 --- /dev/null +++ b/keras/src/wrappers/fixes.py @@ -0,0 +1,83 @@ +try: + import sklearn +except ImportError: + sklearn = None + + +def _validate_data(estimator, *args, **kwargs): + """Validate the input data. + + wrapper for sklearn.utils.validation.validate_data or + BaseEstimator._validate_data depending on the scikit-learn version. + + TODO: remove when minimum scikit-learn version is 1.6 + """ + try: + # scikit-learn >= 1.6 + from sklearn.utils.validation import validate_data + + return validate_data(estimator, *args, **kwargs) + except ImportError: + return estimator._validate_data(*args, **kwargs) + except: + raise + + +def type_of_target(y, input_name="", *, raise_unknown=False): + def _raise_or_return(target_type): + """Depending on the value of raise_unknown, either raise an error or + return 'unknown'. + """ + if raise_unknown and target_type == "unknown": + input = input_name if input_name else "data" + raise ValueError(f"Unknown label type for {input}: {y!r}") + else: + return target_type + + from sklearn.utils.multiclass import type_of_target as sk_type_of_target + + target_type = sk_type_of_target(y, input_name=input_name) + return _raise_or_return(target_type) + + +def _routing_enabled(): + """Return whether metadata routing is enabled. + + Returns: + enabled : bool + Whether metadata routing is enabled. If the config is not set, it + defaults to False. + + TODO: remove when the config key is no longer available in scikit-learn + """ + return sklearn.get_config().get("enable_metadata_routing", False) + + +def _raise_for_params(params, owner, method): + """Raise an error if metadata routing is not enabled and params are passed. + + Parameters: + params : dict + The metadata passed to a method. + owner : object + The object to which the method belongs. + method : str + The name of the method, e.g. "fit". + + Raises: + ValueError + If metadata routing is not enabled and params are passed. + """ + caller = ( + f"{owner.__class__.__name__}.{method}" + if method + else owner.__class__.__name__ + ) + if not _routing_enabled() and params: + raise ValueError( + f"Passing extra keyword arguments to {caller} is only supported if" + " enable_metadata_routing=True, which you can set using" + " `sklearn.set_config`. See the User Guide" + " for more" + f" details. Extra parameters passed are: {set(params)}" + ) diff --git a/keras/src/wrappers/sklearn_test.py b/keras/src/wrappers/sklearn_test.py new file mode 100644 index 000000000000..250b12c51274 --- /dev/null +++ b/keras/src/wrappers/sklearn_test.py @@ -0,0 +1,160 @@ +"""Tests using Scikit-Learn's bundled estimator_checks.""" + +from contextlib import contextmanager + +import pytest +import sklearn +from packaging.version import parse as parse_version +from sklearn.utils.estimator_checks import parametrize_with_checks + +import keras +from keras.src.backend import floatx +from keras.src.backend import set_floatx +from keras.src.layers import Dense +from keras.src.layers import Input +from keras.src.models import Model +from keras.src.wrappers import SKLearnClassifier +from keras.src.wrappers import SKLearnRegressor +from keras.src.wrappers import SKLearnTransformer + + +def wrapped_parametrize_with_checks( + estimators, + *, + legacy=True, + expected_failed_checks=None, +): + """Wrapped `parametrize_with_checks` handling backwards compat.""" + sklearn_version = parse_version( + parse_version(sklearn.__version__).base_version + ) + + if sklearn_version >= parse_version("1.6"): + return parametrize_with_checks( + estimators, + legacy=legacy, + expected_failed_checks=expected_failed_checks, + ) + + def patched_more_tags(estimator, expected_failed_checks): + import copy + + original_tags = copy.deepcopy(sklearn.utils._tags._safe_tags(estimator)) + + def patched_more_tags(self): + original_tags.update({"_xfail_checks": expected_failed_checks}) + return original_tags + + estimator.__class__._more_tags = patched_more_tags + return estimator + + estimators = [ + patched_more_tags(estimator, expected_failed_checks(estimator)) + for estimator in estimators + ] + + # legacy is not supported and ignored + return parametrize_with_checks(estimators) + + +def dynamic_model(X, y, loss, layers=[10]): + """Creates a basic MLP classifier dynamically choosing binary/multiclass + classification loss and ouput activations. + """ + n_features_in = X.shape[1] + inp = Input(shape=(n_features_in,)) + + hidden = inp + for layer_size in layers: + hidden = Dense(layer_size, activation="relu")(hidden) + + n_outputs = y.shape[1] if len(y.shape) > 1 else 1 + out = [Dense(n_outputs, activation="softmax")(hidden)] + model = Model(inp, out) + model.compile(loss=loss, optimizer="rmsprop") + + return model + + +@contextmanager +def use_floatx(x): + """Context manager to temporarily + set the keras backend precision. + """ + _floatx = floatx() + set_floatx(x) + try: + yield + finally: + set_floatx(_floatx) + + +EXPECTED_FAILED_CHECKS = { + "SKLearnClassifier": { + "check_classifiers_regression_target": "not an issue in sklearn>=1.6", + "check_parameters_default_constructible": ( + "not an issue in sklearn>=1.6" + ), + "check_classifiers_one_label_sample_weights": ( + "0 sample weight is not ignored" + ), + "check_classifiers_classes": ( + "with small test cases the estimator returns not all classes " + "sometimes" + ), + "check_classifier_data_not_an_array": ( + "This test assumes reproducibility in fit." + ), + "check_supervised_y_2d": "This test assumes reproducibility in fit.", + "check_fit_idempotent": "This test assumes reproducibility in fit.", + }, + "SKLearnRegressor": { + "check_parameters_default_constructible": ( + "not an issue in sklearn>=1.6" + ), + }, + "SKLearnTransformer": { + "check_parameters_default_constructible": ( + "not an issue in sklearn>=1.6" + ), + }, +} + + +@wrapped_parametrize_with_checks( + estimators=[ + SKLearnClassifier( + model=dynamic_model, + model_kwargs={ + "loss": "categorical_crossentropy", + "layers": [20, 20, 20], + }, + fit_kwargs={"epochs": 5}, + ), + SKLearnRegressor( + model=dynamic_model, + model_kwargs={"loss": "mse"}, + ), + SKLearnTransformer( + model=dynamic_model, + model_kwargs={"loss": "mse"}, + ), + ], + expected_failed_checks=lambda estimator: EXPECTED_FAILED_CHECKS[ + type(estimator).__name__ + ], +) +def test_sklearn_estimator_checks(estimator, check): + """Checks that can be passed with sklearn's default tolerances + and in a single epoch. + """ + try: + check(estimator) + except Exception as exc: + if keras.config.backend() in ["numpy", "openvino"] and ( + isinstance(exc, NotImplementedError) + or "NotImplementedError" in str(exc) + ): + pytest.xfail("Backend not implemented") + else: + raise diff --git a/keras/src/wrappers/sklearn_wrapper.py b/keras/src/wrappers/sklearn_wrapper.py new file mode 100644 index 000000000000..90d36c669792 --- /dev/null +++ b/keras/src/wrappers/sklearn_wrapper.py @@ -0,0 +1,494 @@ +import copy + +import numpy as np + +from keras.src.api_export import keras_export +from keras.src.models.cloning import clone_model +from keras.src.models.model import Model +from keras.src.wrappers.fixes import _routing_enabled +from keras.src.wrappers.fixes import _validate_data +from keras.src.wrappers.fixes import type_of_target +from keras.src.wrappers.utils import TargetReshaper +from keras.src.wrappers.utils import _check_model +from keras.src.wrappers.utils import assert_sklearn_installed + +try: + import sklearn + from sklearn.base import BaseEstimator + from sklearn.base import ClassifierMixin + from sklearn.base import RegressorMixin + from sklearn.base import TransformerMixin +except ImportError: + sklearn = None + + class BaseEstimator: + pass + + class ClassifierMixin: + pass + + class RegressorMixin: + pass + + class TransformerMixin: + pass + + +class SKLBase(BaseEstimator): + """Base class for scikit-learn wrappers. + + Note that there are sources of randomness in model initialization and + training. Refer to [Reproducibility in Keras Models]( + https://keras.io/examples/keras_recipes/reproducibility_recipes/) on how to + control randomness. + + Args: + model: `Model`. + An instance of `Model`, or a callable returning such an object. + Note that if input is a `Model`, it will be cloned using + `keras.models.clone_model` before being fitted, unless + `warm_start=True`. + The `Model` instance needs to be passed as already compiled. + If callable, it must accept at least `X` and `y` as keyword + arguments. Other arguments must be accepted if passed as + `model_kwargs` by the user. + warm_start: bool, defaults to `False`. + Whether to reuse the model weights from the previous fit. If `True`, + the given model won't be cloned and the weights from the previous + fit will be reused. + model_kwargs: dict, defaults to `None`. + Keyword arguments passed to `model`, if `model` is callable. + fit_kwargs: dict, defaults to `None`. + Keyword arguments passed to `model.fit`. These can also be passed + directly to the `fit` method of the scikit-learn wrapper. The + values passed directly to the `fit` method take precedence over + these. + + Attributes: + model_ : `Model` + The fitted model. + history_ : dict + The history of the fit, returned by `model.fit`. + """ + + def __init__( + self, + model, + warm_start=False, + model_kwargs=None, + fit_kwargs=None, + ): + assert_sklearn_installed(self.__class__.__name__) + self.model = model + self.warm_start = warm_start + self.model_kwargs = model_kwargs + self.fit_kwargs = fit_kwargs + + def _more_tags(self): + return {"non_deterministic": True} + + def __sklearn_tags__(self): + tags = super().__sklearn_tags__() + tags.non_deterministic = True + return tags + + def __sklearn_clone__(self): + """Return a deep copy of the model. + + This is used by the `sklearn.base.clone` function. + """ + model = ( + self.model if callable(self.model) else copy.deepcopy(self.model) + ) + return type(self)( + model=model, + warm_start=self.warm_start, + model_kwargs=self.model_kwargs, + ) + + @property + def epoch_(self): + """The current training epoch.""" + return getattr(self, "history_", {}).get("epoch", 0) + + def set_fit_request(self, **kwargs): + """Set requested parameters by the fit method. + + Please see [scikit-learn's metadata routing]( + https://scikit-learn.org/stable/metadata_routing.html) for more + details. + + + Arguments: + kwargs : dict + Arguments should be of the form `param_name=alias`, and `alias` + can be one of `{True, False, None, str}`. + + Returns: + self + """ + if not _routing_enabled(): + raise RuntimeError( + "This method is only available when metadata routing is " + "enabled. You can enable it using " + "sklearn.set_config(enable_metadata_routing=True)." + ) + + self._metadata_request = sklearn.utils.metadata_routing.MetadataRequest( + owner=self.__class__.__name__ + ) + for param, alias in kwargs.items(): + self._metadata_request.score.add_request(param=param, alias=alias) + return self + + def _get_model(self, X, y): + if isinstance(self.model, Model): + return clone_model(self.model) + else: + args = self.model_kwargs or {} + return self.model(X=X, y=y, **args) + + def fit(self, X, y, **kwargs): + """Fit the model. + + Args: + X: array-like, shape=(n_samples, n_features) + The input samples. + y: array-like, shape=(n_samples,) or (n_samples, n_outputs) + The targets. + **kwargs: keyword arguments passed to `model.fit` + """ + X, y = _validate_data(self, X, y) + y = self._process_target(y, reset=True) + model = self._get_model(X, y) + _check_model(model) + + fit_kwargs = self.fit_kwargs or {} + fit_kwargs.update(kwargs) + self.history_ = model.fit(X, y, **fit_kwargs) + + self.model_ = model + return self + + def predict(self, X): + """Predict using the model.""" + from sklearn.utils.validation import check_is_fitted + + check_is_fitted(self) + X = _validate_data(self, X, reset=False) + raw_output = self.model_.predict(X) + return self._reverse_process_target(raw_output) + + def _process_target(self, y, reset=False): + """Regressors are NOOP here, classifiers do OHE.""" + # This is here to raise the right error in case of invalid target + type_of_target(y, raise_unknown=True) + if reset: + self._target_encoder = TargetReshaper().fit(y) + return self._target_encoder.transform(y) + + def _reverse_process_target(self, y): + """Regressors are NOOP here, classifiers reverse OHE.""" + return self._target_encoder.inverse_transform(y) + + +@keras_export("keras.wrappers.SKLearnClassifier") +class SKLearnClassifier(ClassifierMixin, SKLBase): + """scikit-learn compatible classifier wrapper for Keras models. + + Note that there are sources of randomness in model initialization and + training. Refer to [Reproducibility in Keras Models]( + https://keras.io/examples/keras_recipes/reproducibility_recipes/) on how to + control randomness. + + Args: + model: `Model`. + An instance of `Model`, or a callable returning such an object. + Note that if input is a `Model`, it will be cloned using + `keras.models.clone_model` before being fitted, unless + `warm_start=True`. + The `Model` instance needs to be passed as already compiled. + If callable, it must accept at least `X` and `y` as keyword + arguments. Other arguments must be accepted if passed as + `model_kwargs` by the user. + warm_start: bool, defaults to `False`. + Whether to reuse the model weights from the previous fit. If `True`, + the given model won't be cloned and the weights from the previous + fit will be reused. + model_kwargs: dict, defaults to `None`. + Keyword arguments passed to `model`, if `model` is callable. + fit_kwargs: dict, defaults to `None`. + Keyword arguments passed to `model.fit`. These can also be passed + directly to the `fit` method of the scikit-learn wrapper. The + values passed directly to the `fit` method take precedence over + these. + + Attributes: + model_ : `Model` + The fitted model. + history_ : dict + The history of the fit, returned by `model.fit`. + classes_ : array-like, shape=(n_classes,) + The classes labels. + + Example: + Here we use a function which creates a basic MLP model dynamically + choosing the input and output shapes. We will use this to create our + scikit-learn model. + + ``` python + from keras.layers import Dense, Input + from keras.models import Model + + def dynamic_model(X, y, loss, layers=[10]): + # Creates a basic MLP model dynamically choosing the input and + # output shapes. + n_features_in = X.shape[1] + inp = Input(shape=(n_features_in,)) + + hidden = inp + for layer_size in layers: + hidden = Dense(layer_size, activation="relu")(hidden) + + n_outputs = y.shape[1] if len(y.shape) > 1 else 1 + out = Dense(n_outputs, activation="softmax")(hidden) + model = Model(inp, out) + model.compile(loss=loss, optimizer="rmsprop") + + return model + ``` + + You can then use this function to create a scikit-learn compatible model + and fit it on some data. + + ``` python + from sklearn.datasets import make_classification + from keras.wrappers import SKLearnClassifier + + X, y = make_classification(n_samples=1000, n_features=10) + est = SKLearnClassifier( + model=dynamic_model, + model_kwargs={ + "loss": "categorical_crossentropy", + "layers": [20, 20, 20], + }, + ) + + est.fit(X, y, epochs=5) + ``` + """ + + def _process_target(self, y, reset=False): + """Classifiers do OHE.""" + target_type = type_of_target(y, raise_unknown=True) + if target_type not in ["binary", "multiclass"]: + raise ValueError( + "Only binary and multiclass target types are supported." + f" Target type: {target_type}" + ) + if reset: + self._target_encoder = sklearn.pipeline.make_pipeline( + TargetReshaper(), + sklearn.preprocessing.OneHotEncoder(sparse_output=False), + ).fit(y) + self.classes_ = np.unique(y) + if len(self.classes_) == 1: + raise ValueError( + "Classifier can't train when only one class is present." + ) + return self._target_encoder.transform(y) + + def _more_tags(self): + # required to be compatible with scikit-learn<1.6 + return {"poor_score": True} + + def __sklearn_tags__(self): + tags = super().__sklearn_tags__() + tags.classifier_tags.poor_score = True + return tags + + +@keras_export("keras.wrappers.SKLearnRegressor") +class SKLearnRegressor(RegressorMixin, SKLBase): + """scikit-learn compatible regressor wrapper for Keras models. + + Note that there are sources of randomness in model initialization and + training. Refer to [Reproducibility in Keras Models]( + https://keras.io/examples/keras_recipes/reproducibility_recipes/) on how to + control randomness. + + Args: + model: `Model`. + An instance of `Model`, or a callable returning such an object. + Note that if input is a `Model`, it will be cloned using + `keras.models.clone_model` before being fitted, unless + `warm_start=True`. + The `Model` instance needs to be passed as already compiled. + If callable, it must accept at least `X` and `y` as keyword + arguments. Other arguments must be accepted if passed as + `model_kwargs` by the user. + warm_start: bool, defaults to `False`. + Whether to reuse the model weights from the previous fit. If `True`, + the given model won't be cloned and the weights from the previous + fit will be reused. + model_kwargs: dict, defaults to `None`. + Keyword arguments passed to `model`, if `model` is callable. + fit_kwargs: dict, defaults to `None`. + Keyword arguments passed to `model.fit`. These can also be passed + directly to the `fit` method of the scikit-learn wrapper. The + values passed directly to the `fit` method take precedence over + these. + + Attributes: + model_ : `Model` + The fitted model. + + Example: + Here we use a function which creates a basic MLP model dynamically + choosing the input and output shapes. We will use this to create our + scikit-learn model. + + ``` python + from keras.layers import Dense, Input + from keras.models import Model + + def dynamic_model(X, y, loss, layers=[10]): + # Creates a basic MLP model dynamically choosing the input and + # output shapes. + n_features_in = X.shape[1] + inp = Input(shape=(n_features_in,)) + + hidden = inp + for layer_size in layers: + hidden = Dense(layer_size, activation="relu")(hidden) + + n_outputs = y.shape[1] if len(y.shape) > 1 else 1 + out = Dense(n_outputs)(hidden) + model = Model(inp, out) + model.compile(loss=loss, optimizer="rmsprop") + + return model + ``` + + You can then use this function to create a scikit-learn compatible model + and fit it on some data. + + ``` python + from sklearn.datasets import make_regression + from keras.wrappers import SKLearnRegressor + + X, y = make_regression(n_samples=1000, n_features=10) + est = SKLearnRegressor( + model=dynamic_model, + model_kwargs={ + "loss": "mse", + "layers": [20, 20, 20], + }, + ) + + est.fit(X, y, epochs=5) + ``` + """ + + def _more_tags(self): + # required to be compatible with scikit-learn<1.6 + return {"poor_score": True} + + def __sklearn_tags__(self): + tags = super().__sklearn_tags__() + tags.regressor_tags.poor_score = True + return tags + + +@keras_export("keras.wrappers.SKLearnTransformer") +class SKLearnTransformer(TransformerMixin, SKLBase): + """scikit-learn compatible transformer wrapper for Keras models. + + Note that this is a scikit-learn compatible transformer, and not a + transformer in the deep learning sense. + + Also note that there are sources of randomness in model initialization and + training. Refer to [Reproducibility in Keras Models]( + https://keras.io/examples/keras_recipes/reproducibility_recipes/) on how to + control randomness. + + Args: + model: `Model`. + An instance of `Model`, or a callable returning such an object. + Note that if input is a `Model`, it will be cloned using + `keras.models.clone_model` before being fitted, unless + `warm_start=True`. + The `Model` instance needs to be passed as already compiled. + If callable, it must accept at least `X` and `y` as keyword + arguments. Other arguments must be accepted if passed as + `model_kwargs` by the user. + warm_start: bool, defaults to `False`. + Whether to reuse the model weights from the previous fit. If `True`, + the given model won't be cloned and the weights from the previous + fit will be reused. + model_kwargs: dict, defaults to `None`. + Keyword arguments passed to `model`, if `model` is callable. + fit_kwargs: dict, defaults to `None`. + Keyword arguments passed to `model.fit`. These can also be passed + directly to the `fit` method of the scikit-learn wrapper. The + values passed directly to the `fit` method take precedence over + these. + + Attributes: + model_ : `Model` + The fitted model. + history_ : dict + The history of the fit, returned by `model.fit`. + + Example: + A common use case for a scikit-learn transformer, is to have a step + which gives you the embedding of your data. Here we assume + `my_package.my_model` is a Keras model which takes the input and gives + embeddings of the data, and `my_package.my_data` is your dataset loader. + + ``` python + from my_package import my_model, my_data + from keras.wrappers import SKLearnTransformer + from sklearn.frozen import FrozenEstimator # requires scikit-learn>=1.6 + from sklearn.pipeline import make_pipeline + from sklearn.ensemble import HistGradientBoostingClassifier + + X, y = my_data() + + trs = FrozenEstimator(SKLearnTransformer(model=my_model)) + pipe = make_pipeline(trs, HistGradientBoostingClassifier()) + pipe.fit(X, y) + ``` + + Note that in the above example, `FrozenEstimator` prevents any further + training of the transformer step in the pipeline, which can be the case + if you don't want to change the embedding model at hand. + """ + + def transform(self, X): + """Transform the data. + + Args: + X: array-like, shape=(n_samples, n_features) + The input samples. + + Returns: + X_transformed: array-like, shape=(n_samples, n_features) + The transformed data. + """ + from sklearn.utils.validation import check_is_fitted + + check_is_fitted(self) + X = _validate_data(self, X, reset=False) + return self.model_.predict(X) + + def _more_tags(self): + # required to be compatible with scikit-learn<1.6 + return { + "preserves_dtype": [], + } + + def __sklearn_tags__(self): + tags = super().__sklearn_tags__() + tags.transformer_tags.preserves_dtype = [] + return tags diff --git a/keras/src/wrappers/utils.py b/keras/src/wrappers/utils.py new file mode 100644 index 000000000000..8c2954b055ad --- /dev/null +++ b/keras/src/wrappers/utils.py @@ -0,0 +1,90 @@ +import numpy as np + +try: + import sklearn + from sklearn.base import BaseEstimator + from sklearn.base import TransformerMixin +except ImportError: + sklearn = None + + class BaseEstimator: + pass + + class TransformerMixin: + pass + + +def assert_sklearn_installed(symbol_name): + if sklearn is None: + raise ImportError( + f"{symbol_name} requires `scikit-learn` to be installed. " + "Run `pip install scikit-learn` to install it." + ) + + +def _check_model(model): + """Check whether the model need sto be compiled.""" + # compile model if user gave us an un-compiled model + if not model.compiled or not model.loss or not model.optimizer: + raise RuntimeError( + "Given model needs to be compiled, and have a loss " + "and an optimizer." + ) + + +class TargetReshaper(TransformerMixin, BaseEstimator): + """Convert 1D targets to 2D and back. + + For use in pipelines with transformers that only accept + 2D inputs, like OneHotEncoder and OrdinalEncoder. + + Attributes: + ndim_ : int + Dimensions of y that the transformer was trained on. + """ + + def fit(self, y): + """Fit the transformer to a target y. + + Returns: + TargetReshaper + A reference to the current instance of TargetReshaper. + """ + self.ndim_ = y.ndim + return self + + def transform(self, y): + """Makes 1D y 2D. + + Args: + y : np.ndarray + Target y to be transformed. + + Returns: + np.ndarray + A numpy array, of dimension at least 2. + """ + if y.ndim == 1: + return y.reshape(-1, 1) + return y + + def inverse_transform(self, y): + """Revert the transformation of transform. + + Args: + y: np.ndarray + Transformed numpy array. + + Returns: + np.ndarray + If the transformer was fit to a 1D numpy array, + and a 2D numpy array with a singleton second dimension + is passed, it will be squeezed back to 1D. Otherwise, it + will eb left untouched. + """ + from sklearn.utils.validation import check_is_fitted + + check_is_fitted(self) + if self.ndim_ == 1 and y.ndim == 2: + return np.squeeze(y, axis=1) + return y diff --git a/keras/utils/generic_utils.py b/keras/utils/generic_utils.py deleted file mode 100644 index 96b8b1a0a889..000000000000 --- a/keras/utils/generic_utils.py +++ /dev/null @@ -1,89 +0,0 @@ -import numpy as np -import time -import sys - -def get_from_module(identifier, module_params, module_name, instantiate=False): - if type(identifier) is str: - res = module_params.get(identifier) - if not res: - raise Exception('Invalid', module_name, ': ' + identifier) - if instantiate: - return res() - else: - return res - return identifier - -def make_tuple(*args): - return args - -class Progbar(object): - def __init__(self, target, width=30): - ''' - @param target: total number of steps expected - ''' - self.width = width - self.target = target - self.sum_values = {} - self.unique_values = [] - self.start = time.time() - self.total_width = 0 - self.seen_so_far = 0 - - def update(self, current, values=[]): - ''' - @param current: index of current step - @param values: list of tuples (name, value_for_last_step). - The progress bar will display averages for these values. - ''' - for k, v in values: - if k not in self.sum_values: - self.sum_values[k] = [v, 1] - self.unique_values.append(k) - else: - self.sum_values[k][0] += v * (current-self.seen_so_far) - self.sum_values[k][1] += (current-self.seen_so_far) - - prev_total_width = self.total_width - sys.stdout.write("\b" * (self.total_width+1)) - - bar = '%d/%d [' % (current, self.target) - prog = float(current)/self.target - prog_width = int(self.width*prog) - if prog_width > 0: - bar += ('='*(prog_width-1)) - if current < self.target: - bar += '>' - else: - bar += '=' - bar += ('.'*(self.width-prog_width)) - bar += ']' - sys.stdout.write(bar) - self.total_width = len(bar) - - now = time.time() - if current: - time_per_unit = (now - self.start) / current - else: - time_per_unit = 0 - eta = time_per_unit*(self.target - current) - info = '' - if current < self.target: - info += ' - ETA: %ds' % eta - else: - info += ' - %ds' % (now - self.start) - for k in self.unique_values: - info += ' - %s: %.4f' % (k, self.sum_values[k][0]/self.sum_values[k][1]) - - self.total_width += len(info) - if prev_total_width > self.total_width: - info += ((prev_total_width-self.total_width) * " ") - - sys.stdout.write(info) - sys.stdout.flush() - self.seen_so_far = current - - if current >= self.target: - sys.stdout.write("\n") - - def add(self, n, values=[]): - self.update(self.seen_so_far+n, values) diff --git a/keras/utils/np_utils.py b/keras/utils/np_utils.py deleted file mode 100644 index 0c53f9a98365..000000000000 --- a/keras/utils/np_utils.py +++ /dev/null @@ -1,58 +0,0 @@ -import numpy as np -import scipy as sp - -def to_categorical(y, nb_classes=None): - '''Convert class vector (integers from 0 to nb_classes) - to binary class matrix, for use with categorical_crossentropy - ''' - y = np.asarray(y, dtype='int32') - if not nb_classes: - nb_classes = np.max(y)+1 - Y = np.zeros((len(y), nb_classes)) - for i in range(len(y)): - Y[i, y[i]] = 1. - return Y - - -def binary_logloss(p, y): - epsilon = 1e-15 - p = sp.maximum(epsilon, p) - p = sp.minimum(1-epsilon, p) - res = sum(y*sp.log(p) + sp.subtract(1,y)*sp.log(sp.subtract(1,p))) - res *= -1.0/len(y) - return res - -def multiclass_logloss(P, Y): - score = 0. - npreds = [P[i][Y[i]-1] for i in range(len(Y))] - score = -(1./len(Y)) * np.sum(np.log(npreds)) - return score - -def accuracy(p, y): - return np.mean([a==b for a, b in zip(p, y)]) - -def probas_to_classes(y_pred): - if len(y_pred.shape) > 1 and y_pred.shape[1] > 1: - return categorical_probas_to_classes(y_pred) - return np.array([1 if p > 0.5 else 0 for p in y_pred]) - -def categorical_probas_to_classes(p): - return np.argmax(p, axis=1) - - -def save_array(array, name): - import tables - f = tables.open_file(name, 'w') - atom = tables.Atom.from_dtype(array.dtype) - ds = f.createCArray(f.root, 'data', atom, array.shape) - ds[:] = array - f.close() - -def load_array(name): - import tables - f = tables.open_file(name) - array = f.root.data - a=np.empty(shape=array.shape, dtype=array.dtype) - a[:]=array[:] - f.close() - return a \ No newline at end of file diff --git a/keras/utils/theano_utils.py b/keras/utils/theano_utils.py deleted file mode 100644 index f90e7ae1fa1c..000000000000 --- a/keras/utils/theano_utils.py +++ /dev/null @@ -1,21 +0,0 @@ -import numpy as np -import theano -import theano.tensor as T - -def floatX(X): - return np.asarray(X, dtype=theano.config.floatX) - -def sharedX(X, dtype=theano.config.floatX, name=None): - return theano.shared(np.asarray(X, dtype=dtype), name=name) - -def shared_zeros(shape, dtype=theano.config.floatX, name=None): - return sharedX(np.zeros(shape), dtype=dtype, name=name) - -def shared_scalar(val=0., dtype=theano.config.floatX, name=None): - return theano.shared(np.cast[dtype](val)) - -def shared_ones(shape, dtype=theano.config.floatX, name=None): - return sharedX(np.ones(shape), dtype=dtype, name=name) - -def alloc_zeros_matrix(*dims): - return T.alloc(np.cast[theano.config.floatX](0.), *dims) diff --git a/pip_build.py b/pip_build.py new file mode 100644 index 000000000000..799e84d32797 --- /dev/null +++ b/pip_build.py @@ -0,0 +1,148 @@ +"""Script to create (and optionally install) a `.whl` archive for Keras 3. + +Usage: + +1. Create a `.whl` file in `dist/`: + +``` +python3 pip_build.py +``` + +2. Also install the new package immediately after: + +``` +python3 pip_build.py --install +``` +""" + +import argparse +import datetime +import glob +import os +import pathlib +import re +import shutil + +# Needed because importing torch after TF causes the runtime to crash +try: + import torch # noqa: F401 +except ImportError: + pass + +package = "keras" +build_directory = "tmp_build_dir" +dist_directory = "dist" +to_copy = ["pyproject.toml", "README.md"] + + +def export_version_string(version, is_nightly=False, rc_index=None): + """Export Version and Package Name.""" + if is_nightly: + date = datetime.datetime.now() + version += f".dev{date:%Y%m%d%H}" + # Update `name = "keras"` with "keras-nightly" + pyproj_pth = pathlib.Path("pyproject.toml") + pyproj_str = pyproj_pth.read_text().replace( + 'name = "keras"', 'name = "keras-nightly"' + ) + pyproj_pth.write_text(pyproj_str) + elif rc_index is not None: + version += f"rc{str(rc_index)}" + + # Make sure to export the __version__ string + with open(os.path.join(package, "src", "version.py")) as f: + init_contents = f.read() + with open(os.path.join(package, "src", "version.py"), "w") as f: + init_contents = re.sub( + "\n__version__ = .*\n", + f'\n__version__ = "{version}"\n', + init_contents, + ) + f.write(init_contents) + + +def ignore_files(_, filenames): + return [f for f in filenames if f.endswith("_test.py")] + + +def copy_source_to_build_directory(root_path): + # Copy sources (`keras/` directory and setup files) to build + # directory + os.chdir(root_path) + os.mkdir(build_directory) + shutil.copytree( + package, os.path.join(build_directory, package), ignore=ignore_files + ) + for fname in to_copy: + shutil.copy(fname, os.path.join(f"{build_directory}", fname)) + os.chdir(build_directory) + + +def build(root_path, is_nightly=False, rc_index=None): + if os.path.exists(build_directory): + raise ValueError(f"Directory already exists: {build_directory}") + + try: + copy_source_to_build_directory(root_path) + + from keras.src.version import __version__ # noqa: E402 + + export_version_string(__version__, is_nightly, rc_index) + return build_and_save_output(root_path, __version__) + finally: + # Clean up: remove the build directory (no longer needed) + shutil.rmtree(build_directory) + + +def build_and_save_output(root_path, __version__): + # Build the package + os.system("python3 -m build") + + # Save the dist files generated by the build process + os.chdir(root_path) + if not os.path.exists(dist_directory): + os.mkdir(dist_directory) + for fpath in glob.glob( + os.path.join(build_directory, dist_directory, "*.*") + ): + shutil.copy(fpath, dist_directory) + + # Find the .whl file path + whl_path = None + for fname in os.listdir(dist_directory): + if __version__ in fname and fname.endswith(".whl"): + whl_path = os.path.abspath(os.path.join(dist_directory, fname)) + if whl_path: + print(f"Build successful. Wheel file available at {whl_path}") + else: + print("Build failed.") + return whl_path + + +def install_whl(whl_fpath): + print(f"Installing wheel file: {whl_fpath}") + os.system(f"pip3 install {whl_fpath} --force-reinstall --no-dependencies") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--install", + action="store_true", + help="Whether to install the generated wheel file.", + ) + parser.add_argument( + "--nightly", + action="store_true", + help="Whether to generate nightly wheel file.", + ) + parser.add_argument( + "--rc", + type=int, + help="Specify `[0-9] when generating RC wheels.", + ) + args = parser.parse_args() + root_path = pathlib.Path(__file__).parent.resolve() + whl_path = build(root_path, args.nightly, args.rc) + if whl_path and args.install: + install_whl(whl_path) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000000..bd9e7c30f869 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,111 @@ +[build-system] +requires = ["setuptools >=61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "keras" +authors = [ + {name = "Keras team", email = "keras-users@googlegroups.com"}, +] +description = "Multi-backend Keras" +readme = "README.md" +requires-python = ">=3.10" +license = {text = "Apache License 2.0"} +dynamic = ["version"] +classifiers = [ + "Development Status :: 4 - Beta", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3 :: Only", + "Operating System :: Unix", + "Operating System :: MacOS", + "Intended Audience :: Science/Research", + "Topic :: Scientific/Engineering", + "Topic :: Software Development", +] +dependencies = [ + "absl-py", + "numpy", + "rich", + "namex", + "h5py", + "optree", + "ml-dtypes", + "packaging", +] +# Run also: pip install -r requirements.txt + +[project.urls] +Home = "https://keras.io/" +Repository = "https://github.com/keras-team/keras" + +[tool.setuptools.dynamic] +version = {attr = "keras.src.version.__version__"} + +[tool.setuptools.package-dir] +"" = "." +"keras" = "keras/api" # Remap api/ to the root of the package. +"keras.src" = "keras/src" + +[tool.ruff] +line-length = 80 +exclude = ["keras/src/namex"] + +[tool.ruff.lint] +select = [ + "E", # pycodestyle error + "F", # Pyflakes + "I", # isort +] +ignore = [ + "E722", # do not use bare 'except' + "E741", # ambiguous variable name + "E731", # do not assign a `lambda` expression, use a `def` +] + +[tool.ruff.lint.per-file-ignores] +"**/__init__.py" = ["E501", "F401"] # lines too long; imported but unused +"**/random.py" = ["F401"] # imported but unused +"examples/*" = ["I", "E"] +"guides/*" = ["I", "E", "F"] + +[tool.ruff.lint.isort] +force-single-line = true +known-first-party = ["keras"] + +[tool.pytest.ini_options] +filterwarnings = [ + "error", + "ignore::DeprecationWarning", + "ignore::ImportWarning", + "ignore::RuntimeWarning", + "ignore::PendingDeprecationWarning", + "ignore::FutureWarning", + "ignore::UserWarning", + # Ignore a spurious warning on tf-nightly related to save model changes. + "ignore:Custom mask layers require a config", +] +addopts = "-vv" + +# Do not run tests in the `build` folders +norecursedirs = ["build"] + +[tool.coverage.report] +exclude_lines = [ + "pragma: no cover", + "@abstract", + "raise NotImplementedError", +] +omit = [ + "*/*_test.py", + "keras/src/legacy/*", +] + +[tool.coverage.run] +branch = true +omit = [ + "*/*_test.py", + "keras/src/legacy/*", +] + diff --git a/requirements-common.txt b/requirements-common.txt new file mode 100644 index 000000000000..2fecef1d5946 --- /dev/null +++ b/requirements-common.txt @@ -0,0 +1,31 @@ +pre-commit +namex>=0.0.8 +ruff +pytest +numpy +scipy +scikit-learn +pillow +pandas +absl-py +requests +h5py +ml-dtypes +protobuf +tensorboard-plugin-profile +rich +build +optree +pytest-cov +packaging +# for tree_test.py +dm_tree +coverage +# for onnx_test.py +onnxruntime +# https://github.com/keras-team/keras/issues/21390 +# onnxscript==0.3.2 breaks LSTM model export. +onnxscript!=0.3.2 +openvino +# for grain_dataset_adapter_test.py +grain diff --git a/requirements-jax-cuda.txt b/requirements-jax-cuda.txt new file mode 100644 index 000000000000..f1ffb0f91933 --- /dev/null +++ b/requirements-jax-cuda.txt @@ -0,0 +1,14 @@ +# Tensorflow cpu-only version (needed for testing). +tensorflow-cpu~=2.18.1 +tf2onnx + +# Torch cpu-only version (needed for testing). +--extra-index-url https://download.pytorch.org/whl/cpu +torch==2.6.0 + +# Jax with cuda support. +--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +jax[cuda12]==0.6.2 +flax + +-r requirements-common.txt diff --git a/requirements-tensorflow-cuda.txt b/requirements-tensorflow-cuda.txt new file mode 100644 index 000000000000..f895f0224154 --- /dev/null +++ b/requirements-tensorflow-cuda.txt @@ -0,0 +1,12 @@ +# Tensorflow with cuda support. +tensorflow[and-cuda]~=2.18.1 +tf2onnx + +# Torch cpu-only version (needed for testing). +--extra-index-url https://download.pytorch.org/whl/cpu +torch==2.6.0 + +# Jax cpu-only version (needed for testing). +jax[cpu] + +-r requirements-common.txt diff --git a/requirements-torch-cuda.txt b/requirements-torch-cuda.txt new file mode 100644 index 000000000000..9bad75fb290b --- /dev/null +++ b/requirements-torch-cuda.txt @@ -0,0 +1,14 @@ +# Tensorflow cpu-only version (needed for testing). +tensorflow-cpu~=2.18.1 +tf2onnx + +# Torch with cuda support. +# - torch is pinned to a version that is compatible with torch-xla. +--extra-index-url https://download.pytorch.org/whl/cu121 +torch==2.6.0 +torch-xla==2.6.0;sys_platform != 'darwin' + +# Jax cpu-only version (needed for testing). +jax[cpu] + +-r requirements-common.txt diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 000000000000..e5a44501e6b4 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,19 @@ +# Tensorflow. +tensorflow-cpu~=2.18.1;sys_platform != 'darwin' +tensorflow~=2.18.1;sys_platform == 'darwin' +tf_keras +tf2onnx + +# Torch. +--extra-index-url https://download.pytorch.org/whl/cpu +torch==2.6.0;sys_platform != 'darwin' +torch==2.6.0;sys_platform == 'darwin' +torch-xla==2.6.0;sys_platform != 'darwin' + +# Jax. +# Pinned to 0.5.0 on CPU. JAX 0.5.1 requires Tensorflow 2.19 for saved_model_test. +# Note that we test against the latest JAX on GPU. +jax[cpu]==0.5.0 +flax +# Common deps. +-r requirements-common.txt diff --git a/setup.py b/setup.py deleted file mode 100644 index 42e47968fcd6..000000000000 --- a/setup.py +++ /dev/null @@ -1,20 +0,0 @@ -#!/usr/bin/env python - -from distutils.core import setup - -setup(name='Keras', - version='0.0.1', - description='Theano-based Deep Learning', - author='Francois Chollet', - author_email='francois.chollet@gmail.com', - url='https://github.com/fchollet/keras', - license='MIT', - packages=[ - 'keras', - 'keras.layers', - 'keras.preprocessing', - 'keras.datasets', - 'keras.utils', - ], - # TODO: dependencies -) \ No newline at end of file diff --git a/shell/api_gen.sh b/shell/api_gen.sh new file mode 100755 index 000000000000..db2f87c43b3b --- /dev/null +++ b/shell/api_gen.sh @@ -0,0 +1,13 @@ +#!/bin/bash +set -Eeuo pipefail + +base_dir=$(dirname $(dirname $0)) + +echo "Generating api directory with public APIs..." +# Generate API Files +python3 "${base_dir}"/api_gen.py + +# Format code because `api_gen.py` might order +# imports differently. +echo "Formatting api directory..." +(SKIP=api-gen pre-commit run --files $(find "${base_dir}"/keras/api -type f) --hook-stage pre-commit || true) > /dev/null diff --git a/shell/format.sh b/shell/format.sh new file mode 100755 index 000000000000..c4c36607b1d9 --- /dev/null +++ b/shell/format.sh @@ -0,0 +1,13 @@ +#!/bin/bash +set -Eeuo pipefail + +if ! command -v pre-commit 2>&1 >/dev/null +then + echo 'Please `pip install pre-commit` to run format.sh.' + exit 1 +fi + +base_dir=$(dirname $(dirname $0)) + +echo "Formatting all files..." +SKIP=api-gen pre-commit run --all-files